feat: add streaming support and num_gpu option to Ollama client

This commit is contained in:
William Valentin
2026-02-05 15:51:28 -08:00
parent a2e1f73493
commit dbf1acd822
3 changed files with 63 additions and 1 deletion
+1
View File
@@ -18,6 +18,7 @@ const modelConfigSchema = z.object({
api_key: z.string().optional(), api_key: z.string().optional(),
auth_token: z.string().optional(), auth_token: z.string().optional(),
for: z.array(z.string()).optional(), for: z.array(z.string()).optional(),
num_gpu: z.number().optional(),
}); });
const modelsSchema = z.object({ const modelsSchema = z.object({
+61 -1
View File
@@ -1,20 +1,23 @@
import { Ollama } from 'ollama'; import { Ollama } from 'ollama';
import type { ChatRequest, ChatResponse, ModelClient } from '../types.js'; import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js';
export interface OllamaClientConfig { export interface OllamaClientConfig {
host?: string; host?: string;
model: string; model: string;
numGpu?: number;
} }
export class OllamaClient implements ModelClient { export class OllamaClient implements ModelClient {
private client: Ollama; private client: Ollama;
private model: string; private model: string;
private numGpu: number;
constructor(config: OllamaClientConfig) { constructor(config: OllamaClientConfig) {
this.client = new Ollama({ this.client = new Ollama({
host: config.host ?? 'http://localhost:11434', host: config.host ?? 'http://localhost:11434',
}); });
this.model = config.model; this.model = config.model;
this.numGpu = config.numGpu ?? -1;
} }
async chat(request: ChatRequest): Promise<ChatResponse> { async chat(request: ChatRequest): Promise<ChatResponse> {
@@ -31,6 +34,9 @@ export class OllamaClient implements ModelClient {
const response = await this.client.chat({ const response = await this.client.chat({
model: this.model, model: this.model,
messages, messages,
options: {
num_gpu: this.numGpu,
},
}); });
return { return {
@@ -42,4 +48,58 @@ export class OllamaClient implements ModelClient {
}, },
}; };
} }
/**
 * Streams a chat completion from Ollama as a sequence of events.
 *
 * The outgoing conversation is the optional system prompt (first) followed by
 * every entry of `request.messages`, reduced to `role` and `content`.
 *
 * Events produced:
 * - `content` for each chunk that carries non-empty message text;
 * - `done`, with the most recently reported token counts, for each chunk
 *   flagged as done;
 * - `error` when the request or the iteration throws. Failures are reported
 *   through the stream instead of being rethrown to the consumer.
 *
 * @param request - Conversation to send, with an optional system prompt.
 * @returns An async iterable of stream events.
 */
async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
  type OutgoingMessage = { role: 'system' | 'user' | 'assistant'; content: string };

  // Seed the list with the system prompt when one is provided.
  const messages: OutgoingMessage[] = request.system
    ? [{ role: 'system', content: request.system }]
    : [];
  for (const { role, content } of request.messages) {
    messages.push({ role, content });
  }

  try {
    const stream = await this.client.chat({
      model: this.model,
      messages,
      stream: true,
      options: { num_gpu: this.numGpu },
    });

    // Latest non-zero counts seen so far; reported with the `done` event.
    const usage = { inputTokens: 0, outputTokens: 0 };

    for await (const chunk of stream) {
      const text = chunk.message?.content;
      if (text) {
        yield { type: 'content', content: text };
      }

      // Falsy (missing or zero) counts leave the previous value in place.
      usage.inputTokens = chunk.prompt_eval_count || usage.inputTokens;
      usage.outputTokens = chunk.eval_count || usage.outputTokens;

      if (!chunk.done) {
        continue;
      }
      yield { type: 'done', usage: { ...usage } };
    }
  } catch (error) {
    // Normalise non-Error throwables so consumers always receive an Error.
    const failure = error instanceof Error ? error : new Error(String(error));
    yield { type: 'error', error: failure };
  }
}
} }
+1
View File
@@ -1,6 +1,7 @@
export interface Message { export interface Message {
role: 'user' | 'assistant'; role: 'user' | 'assistant';
content: string; content: string;
timestamp?: number;
} }
export interface ChatRequest { export interface ChatRequest {