diff --git a/src/models/local/llamacpp.test.ts b/src/models/local/llamacpp.test.ts
index fd3518b..1884a28 100644
--- a/src/models/local/llamacpp.test.ts
+++ b/src/models/local/llamacpp.test.ts
@@ -1,5 +1,6 @@
 import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
 import { LlamaCppClient } from './llamacpp.js';
+import type { ChatStreamEvent } from '../types.js';
 
 describe('LlamaCppClient', () => {
   const mockFetch = vi.fn();
@@ -33,4 +34,51 @@ describe('LlamaCppClient', () => {
     expect(response.usage.inputTokens).toBe(10);
     expect(response.usage.outputTokens).toBe(5);
   });
+
+  it('streams responses via SSE', async () => {
+    const chunks = [
+      'data: {"choices":[{"delta":{"content":"Hello"}}]}\n\n',
+      'data: {"choices":[{"delta":{"content":" world"}}]}\n\n',
+      'data: {"choices":[{}],"usage":{"prompt_tokens":5,"completion_tokens":2}}\n\n',
+      'data: [DONE]\n\n',
+    ];
+
+    const encoder = new TextEncoder();
+    let chunkIndex = 0;
+
+    const mockStream = new ReadableStream({
+      pull(controller) {
+        if (chunkIndex < chunks.length) {
+          controller.enqueue(encoder.encode(chunks[chunkIndex]));
+          chunkIndex++;
+        } else {
+          controller.close();
+        }
+      },
+    });
+
+    mockFetch.mockResolvedValue({
+      ok: true,
+      body: mockStream,
+    });
+
+    const client = new LlamaCppClient({
+      endpoint: 'http://localhost:8080',
+    });
+
+    const events: ChatStreamEvent[] = [];
+    for await (const event of client.chatStream({
+      messages: [{ role: 'user', content: 'Hi' }],
+    })) {
+      events.push(event);
+    }
+
+    expect(events).toHaveLength(3);
+    expect(events[0]).toEqual({ type: 'content', content: 'Hello' });
+    expect(events[1]).toEqual({ type: 'content', content: ' world' });
+    expect(events[2]).toEqual({
+      type: 'done',
+      usage: { inputTokens: 5, outputTokens: 2 },
+    });
+  });
 });
diff --git a/src/models/local/llamacpp.ts b/src/models/local/llamacpp.ts
index c9129eb..8aeac17 100644
--- a/src/models/local/llamacpp.ts
+++ b/src/models/local/llamacpp.ts
@@ -1,4 +1,4 @@
-import type { ChatRequest, ChatResponse, ModelClient } from '../types.js';
+import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js';
 
 export interface LlamaCppClientConfig {
   endpoint: string;
@@ -15,6 +15,12 @@ interface LlamaCppResponse {
   usage: { prompt_tokens: number; completion_tokens: number };
 }
 
+/** One parsed `data:` payload of a streamed chat completion. */
+interface LlamaCppStreamChunk {
+  choices: Array<{ delta?: { content?: string } }>;
+  usage?: { prompt_tokens: number; completion_tokens: number };
+}
+
 export class LlamaCppClient implements ModelClient {
   private endpoint: string;
   private authToken?: string;
@@ -68,4 +74,136 @@ export class LlamaCppClient implements ModelClient {
       },
     };
   }
+
+  /**
+   * Streams a chat completion from llama-server's
+   * `/v1/chat/completions` endpoint, requested with `stream: true`.
+   *
+   * The response body is read as server-sent events: every `data:` line
+   * other than the `[DONE]` sentinel is parsed as JSON, and lines that
+   * fail to parse are skipped.
+   *
+   * @param request - Conversation to send; `system` (when set) is prepended
+   *   as a system message and `maxTokens` defaults to 2048.
+   * @returns An async iterable yielding a `content` event per non-empty
+   *   delta, then one `done` event carrying the last reported usage (zeros
+   *   if the server sent none). Request and stream failures are not thrown;
+   *   they are reported as a final `error` event.
+   */
+  async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
+    const messages: LlamaCppMessage[] = [];
+
+    if (request.system) {
+      messages.push({ role: 'system', content: request.system });
+    }
+
+    for (const msg of request.messages) {
+      messages.push({ role: msg.role, content: msg.content });
+    }
+
+    const headers: Record<string, string> = {
+      'Content-Type': 'application/json',
+    };
+
+    if (this.authToken) {
+      headers['Authorization'] = `Bearer ${this.authToken}`;
+    }
+
+    try {
+      const response = await fetch(`${this.endpoint}/v1/chat/completions`, {
+        method: 'POST',
+        headers,
+        body: JSON.stringify({
+          messages,
+          max_tokens: request.maxTokens ?? 2048,
+          stream: true,
+        }),
+      });
+
+      if (!response.ok) {
+        const text = await response.text();
+        throw new Error(`llama-server error (${response.status}): ${text}`);
+      }
+
+      if (!response.body) {
+        throw new Error('No response body for streaming');
+      }
+
+      const reader = response.body.getReader();
+      const decoder = new TextDecoder();
+      let buffer = '';
+      let usage = { inputTokens: 0, outputTokens: 0 };
+      let drained = false;
+
+      try {
+        while (!drained) {
+          const { done, value } = await reader.read();
+          drained = done;
+
+          // At end of stream, flush the decoder so bytes held back from a
+          // split multi-byte character are not lost.
+          buffer += done
+            ? decoder.decode()
+            : decoder.decode(value, { stream: true });
+
+          const lines = buffer.split('\n');
+          // Mid-stream the last piece may be an incomplete line, so keep it
+          // for the next read; at end of stream it is final and is parsed.
+          buffer = done ? '' : (lines.pop() ?? '');
+
+          for (const line of lines) {
+            const trimmed = line.trim();
+            // The space after the colon is optional in the SSE format.
+            if (!trimmed.startsWith('data:')) continue;
+
+            const data = trimmed.slice(5).trimStart();
+            if (data === '[DONE]') continue;
+
+            let chunk: LlamaCppStreamChunk | null;
+            try {
+              chunk = JSON.parse(data) as LlamaCppStreamChunk | null;
+            } catch {
+              // Skip malformed JSON
+              continue;
+            }
+
+            // Valid JSON that is not an object carries nothing we can use.
+            if (typeof chunk !== 'object' || chunk === null) continue;
+
+            // Yield outside the parse guard so an error thrown into the
+            // generator by the consumer is not mistaken for bad JSON.
+            const content = chunk.choices?.[0]?.delta?.content;
+            if (content) {
+              yield { type: 'content', content };
+            }
+
+            if (chunk.usage) {
+              usage = {
+                inputTokens: chunk.usage.prompt_tokens,
+                outputTokens: chunk.usage.completion_tokens,
+              };
+            }
+          }
+        }
+      } finally {
+        // Also runs when the consumer stops iterating early or a read fails:
+        // cancel the unfinished stream, then release the reader's lock.
+        if (!drained) {
+          try {
+            await reader.cancel();
+          } catch {
+            // The stream had already errored; nothing left to cancel.
+          }
+        }
+        reader.releaseLock();
+      }
+
+      yield { type: 'done', usage };
+    } catch (error) {
+      yield {
+        type: 'error',
+        error: error instanceof Error ? error : new Error(String(error)),
+      };
+    }
+  }
 }