diff --git a/src/config/schema.ts b/src/config/schema.ts index 317cb02..1bc7760 100644 --- a/src/config/schema.ts +++ b/src/config/schema.ts @@ -18,6 +18,7 @@ const modelConfigSchema = z.object({ api_key: z.string().optional(), auth_token: z.string().optional(), for: z.array(z.string()).optional(), + num_gpu: z.number().optional(), }); const modelsSchema = z.object({ diff --git a/src/models/local/ollama.ts b/src/models/local/ollama.ts index 94361c1..74e5d0d 100644 --- a/src/models/local/ollama.ts +++ b/src/models/local/ollama.ts @@ -1,20 +1,23 @@ import { Ollama } from 'ollama'; -import type { ChatRequest, ChatResponse, ModelClient } from '../types.js'; +import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js'; export interface OllamaClientConfig { host?: string; model: string; + numGpu?: number; } export class OllamaClient implements ModelClient { private client: Ollama; private model: string; + private numGpu: number; constructor(config: OllamaClientConfig) { this.client = new Ollama({ host: config.host ?? 'http://localhost:11434', }); this.model = config.model; + this.numGpu = config.numGpu ?? -1; } async chat(request: ChatRequest): Promise { @@ -31,6 +34,9 @@ export class OllamaClient implements ModelClient { const response = await this.client.chat({ model: this.model, messages, + options: { + num_gpu: this.numGpu, + }, }); return { @@ -42,4 +48,58 @@ export class OllamaClient implements ModelClient { }, }; } + + async *chatStream(request: ChatRequest): AsyncIterable { + const messages: Array<{ role: 'system' | 'user' | 'assistant'; content: string }> = []; + + if (request.system) { + messages.push({ role: 'system', content: request.system }); + } + + for (const msg of request.messages) { + messages.push({ role: msg.role, content: msg.content }); + } + + try { + const stream = await this.client.chat({ + model: this.model, + messages, + stream: true, + options: { + num_gpu: this.numGpu, + }, + }); + + let inputTokens = 0; + let outputTokens = 0; + + for await (const chunk of stream) { + if (chunk.message?.content) { + yield { type: 'content', content: chunk.message.content }; + } + + if (chunk.prompt_eval_count) { + inputTokens = chunk.prompt_eval_count; + } + if (chunk.eval_count) { + outputTokens = chunk.eval_count; + } + + if (chunk.done) { + yield { + type: 'done', + usage: { + inputTokens, + outputTokens, + }, + }; + } + } + } catch (error) { + yield { + type: 'error', + error: error instanceof Error ? error : new Error(String(error)), + }; + } + } } diff --git a/src/models/types.ts b/src/models/types.ts index 95badf3..b3cc5a1 100644 --- a/src/models/types.ts +++ b/src/models/types.ts @@ -1,6 +1,7 @@ export interface Message { role: 'user' | 'assistant'; content: string; + timestamp?: number; } export interface ChatRequest {