From fb20acfbcdc878ba5269f524f97c0652befa9538 Mon Sep 17 00:00:00 2001 From: William Valentin Date: Sat, 7 Feb 2026 17:20:27 -0800 Subject: [PATCH] feat: add tool calling support to Ollama and llama.cpp clients MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Ollama: pass tools to API, parse tool_calls responses, handle thinking field from reasoning models (deepseek-r1, glm-4.7-flash) - llama.cpp: pass tools via OpenAI-compatible endpoint, parse tool_calls, accumulate streaming tool call deltas - Both clients now set stopReason to 'tool_use' when tool calls are present - Tests: 12 new tests (7 Ollama + 5 llama.cpp, total 983→995) --- src/models/local/llamacpp.test.ts | 244 ++++++++++++++++++++++++++++++ src/models/local/llamacpp.ts | 128 ++++++++++++++-- src/models/local/ollama.test.ts | 225 +++++++++++++++++++++++++-- src/models/local/ollama.ts | 88 ++++++++++- 4 files changed, 655 insertions(+), 30 deletions(-) diff --git a/src/models/local/llamacpp.test.ts b/src/models/local/llamacpp.test.ts index 8e0a91d..1735f14 100644 --- a/src/models/local/llamacpp.test.ts +++ b/src/models/local/llamacpp.test.ts @@ -6,6 +6,7 @@ describe('LlamaCppClient', () => { const mockFetch = vi.fn(); beforeEach(() => { + mockFetch.mockReset(); vi.stubGlobal('fetch', mockFetch); }); @@ -96,4 +97,247 @@ describe('LlamaCppClient', () => { messages: [{ role: 'user', content: 'Hello' }], })).rejects.toThrow('llama-server not running at http://localhost:8080'); }); + + it('passes tools in request body', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ + choices: [{ message: { content: 'I can help with that.' 
} }], + usage: { prompt_tokens: 12, completion_tokens: 6 }, + }), + }); + + const client = new LlamaCppClient({ + endpoint: 'http://localhost:8080', + model: 'test-model', + }); + + await client.chat({ + messages: [{ role: 'user', content: 'Run ls' }], + tools: [{ + name: 'shell.exec', + description: 'Run shell', + input_schema: { + type: 'object', + properties: { command: { type: 'string' } }, + required: ['command'], + }, + }], + }); + + const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(requestBody.tools).toEqual([{ + type: 'function', + function: { + name: 'shell.exec', + description: 'Run shell', + parameters: { + type: 'object', + properties: { command: { type: 'string' } }, + required: ['command'], + }, + }, + }]); + }); + + it('parses tool_calls from response', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ + choices: [{ + message: { + content: null, + tool_calls: [{ + id: 'call_123', + type: 'function', + function: { name: 'shell.exec', arguments: '{"command":"ls"}' }, + }], + }, + finish_reason: 'tool_calls', + }], + usage: { prompt_tokens: 15, completion_tokens: 8 }, + }), + }); + + const client = new LlamaCppClient({ + endpoint: 'http://localhost:8080', + model: 'test-model', + }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'List files' }], + tools: [{ + name: 'shell.exec', + description: 'Run shell', + input_schema: { + type: 'object', + properties: { command: { type: 'string' } }, + required: ['command'], + }, + }], + }); + + expect(response.stopReason).toBe('tool_use'); + expect(response.toolCalls).toHaveLength(1); + expect(response.toolCalls![0]).toEqual({ + id: 'call_123', + name: 'shell.exec', + args: { command: 'ls' }, + }); + expect(response.usage.inputTokens).toBe(15); + expect(response.usage.outputTokens).toBe(8); + }); + + it('does not send tools when none provided', async () => { + mockFetch.mockResolvedValue({ + ok: true, + json: () => 
Promise.resolve({ + choices: [{ message: { content: 'Hello!' } }], + usage: { prompt_tokens: 5, completion_tokens: 2 }, + }), + }); + + const client = new LlamaCppClient({ + endpoint: 'http://localhost:8080', + model: 'test-model', + }); + + await client.chat({ + messages: [{ role: 'user', content: 'Hello' }], + }); + + const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(requestBody.tools).toBeUndefined(); + }); + + it('streaming: accumulates and yields tool_calls from deltas', async () => { + const chunks = [ + 'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"shell.exec"}}]}}]}\n\n', + 'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"comma"}}]}}]}\n\n', + 'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"nd\\":\\"ls\\"}"}}]}}]}\n\n', + 'data: {"choices":[{}],"usage":{"prompt_tokens":10,"completion_tokens":5}}\n\n', + 'data: [DONE]\n\n', + ]; + + const encoder = new TextEncoder(); + let chunkIndex = 0; + + const mockStream = new ReadableStream({ + pull(controller) { + if (chunkIndex < chunks.length) { + controller.enqueue(encoder.encode(chunks[chunkIndex])); + chunkIndex++; + } else { + controller.close(); + } + }, + }); + + mockFetch.mockResolvedValue({ + ok: true, + body: mockStream, + }); + + const client = new LlamaCppClient({ + endpoint: 'http://localhost:8080', + model: 'test-model', + }); + + const events: ChatStreamEvent[] = []; + for await (const event of client.chatStream({ + messages: [{ role: 'user', content: 'Run ls' }], + tools: [{ + name: 'shell.exec', + description: 'Run shell', + input_schema: { + type: 'object', + properties: { command: { type: 'string' } }, + required: ['command'], + }, + }], + })) { + events.push(event); + } + + // Should have a tool_use event and a done event + const toolUseEvents = events.filter(e => e.type === 'tool_use'); + const doneEvents = events.filter(e => e.type === 'done'); + + 
expect(toolUseEvents).toHaveLength(1); + expect(toolUseEvents[0].toolCall).toEqual({ + id: 'call_1', + name: 'shell.exec', + args: { command: 'ls' }, + }); + + expect(doneEvents).toHaveLength(1); + expect(doneEvents[0].usage).toEqual({ + inputTokens: 10, + outputTokens: 5, + }); + }); + + it('streaming: passes tools in request body', async () => { + const chunks = [ + 'data: {"choices":[{"delta":{"content":"Hi"}}]}\n\n', + 'data: {"choices":[{}],"usage":{"prompt_tokens":3,"completion_tokens":1}}\n\n', + 'data: [DONE]\n\n', + ]; + + const encoder = new TextEncoder(); + let chunkIndex = 0; + + const mockStream = new ReadableStream({ + pull(controller) { + if (chunkIndex < chunks.length) { + controller.enqueue(encoder.encode(chunks[chunkIndex])); + chunkIndex++; + } else { + controller.close(); + } + }, + }); + + mockFetch.mockResolvedValue({ + ok: true, + body: mockStream, + }); + + const client = new LlamaCppClient({ + endpoint: 'http://localhost:8080', + model: 'test-model', + }); + + // Consume the stream to trigger the fetch call + const events: ChatStreamEvent[] = []; + for await (const event of client.chatStream({ + messages: [{ role: 'user', content: 'Hi' }], + tools: [{ + name: 'shell.exec', + description: 'Run shell', + input_schema: { + type: 'object', + properties: { command: { type: 'string' } }, + required: ['command'], + }, + }], + })) { + events.push(event); + } + + const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body); + expect(requestBody.tools).toEqual([{ + type: 'function', + function: { + name: 'shell.exec', + description: 'Run shell', + parameters: { + type: 'object', + properties: { command: { type: 'string' } }, + required: ['command'], + }, + }, + }]); + expect(requestBody.stream).toBe(true); + }); }); diff --git a/src/models/local/llamacpp.ts b/src/models/local/llamacpp.ts index d3c50bf..6c0ea58 100644 --- a/src/models/local/llamacpp.ts +++ b/src/models/local/llamacpp.ts @@ -1,4 +1,4 @@ -import type { ChatRequest, ChatResponse, 
ChatStreamEvent, ModelClient } from '../types.js'; +import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall } from '../types.js'; import { getMessageText } from '../media.js'; export interface LlamaCppClientConfig { @@ -12,13 +12,42 @@ interface LlamaCppMessage { content: string; } +interface LlamaCppToolCall { + id: string; + type: 'function'; + function: { + name: string; + arguments: string; // JSON string + }; +} + interface LlamaCppResponse { - choices: Array<{ message: { content: string } }>; + choices: Array<{ + message: { + content: string | null; + tool_calls?: LlamaCppToolCall[]; + }; + finish_reason?: string; + }>; usage: { prompt_tokens: number; completion_tokens: number }; } interface LlamaCppStreamChunk { - choices: Array<{ delta?: { content?: string } }>; + choices: Array<{ + delta?: { + content?: string; + tool_calls?: Array<{ + index: number; + id?: string; + type?: string; + function?: { + name?: string; + arguments?: string; + }; + }>; + }; + finish_reason?: string | null; + }>; usage?: { prompt_tokens: number; completion_tokens: number }; } @@ -54,14 +83,28 @@ export class LlamaCppClient implements ModelClient { let response: Response; try { + const body: Record = { + model: this.model, + messages, + max_tokens: request.maxTokens ?? 2048, + }; + + // Pass tool definitions to the API if provided + if (request.tools && request.tools.length > 0) { + body.tools = request.tools.map(t => ({ + type: 'function' as const, + function: { + name: t.name, + description: t.description, + parameters: t.input_schema, + }, + })); + } + response = await fetch(`${this.endpoint}/v1/chat/completions`, { method: 'POST', headers, - body: JSON.stringify({ - model: this.model, - messages, - max_tokens: request.maxTokens ?? 
2048, - }), + body: JSON.stringify(body), }); } catch (error) { if (error instanceof TypeError && error.message.includes('fetch failed')) { @@ -77,13 +120,24 @@ export class LlamaCppClient implements ModelClient { const data = (await response.json()) as LlamaCppResponse; + // Parse tool calls from the response, if present + const toolCalls: ModelToolCall[] = data.choices[0]?.message?.tool_calls?.map((tc) => ({ + id: tc.id ?? `llamacpp_tc_${Math.random().toString(36).slice(2, 8)}`, + name: tc.function.name, + args: JSON.parse(tc.function.arguments), + })) ?? []; + + // Set stopReason to 'tool_use' when tool_calls are present + const stopReason = toolCalls.length > 0 ? 'tool_use' : (data.choices[0]?.finish_reason ?? 'stop'); + return { content: data.choices[0]?.message?.content ?? '', - stopReason: 'stop', + stopReason, usage: { inputTokens: data.usage?.prompt_tokens ?? 0, outputTokens: data.usage?.completion_tokens ?? 0, }, + ...(toolCalls.length > 0 ? { toolCalls } : {}), }; } @@ -107,15 +161,29 @@ export class LlamaCppClient implements ModelClient { } try { + const body: Record = { + model: this.model, + messages, + max_tokens: request.maxTokens ?? 2048, + stream: true, + }; + + // Pass tool definitions to the API if provided + if (request.tools && request.tools.length > 0) { + body.tools = request.tools.map(t => ({ + type: 'function' as const, + function: { + name: t.name, + description: t.description, + parameters: t.input_schema, + }, + })); + } + const response = await fetch(`${this.endpoint}/v1/chat/completions`, { method: 'POST', headers, - body: JSON.stringify({ - model: this.model, - messages, - max_tokens: request.maxTokens ?? 
2048, - stream: true, - }), + body: JSON.stringify(body), }); if (!response.ok) { @@ -131,6 +199,8 @@ export class LlamaCppClient implements ModelClient { const decoder = new TextDecoder(); let buffer = ''; let usage = { inputTokens: 0, outputTokens: 0 }; + // Accumulate tool call deltas across streamed chunks + const toolCallAccumulators: Map = new Map(); while (true) { const { done, value } = await reader.read(); @@ -154,6 +224,22 @@ export class LlamaCppClient implements ModelClient { yield { type: 'content', content: chunk.choices[0].delta.content }; } + // Accumulate tool call deltas from the stream + if (chunk.choices[0]?.delta?.tool_calls) { + for (const tc of chunk.choices[0].delta.tool_calls) { + if (!toolCallAccumulators.has(tc.index)) { + toolCallAccumulators.set(tc.index, { + id: tc.id ?? `llamacpp_tc_${tc.index}`, + name: tc.function?.name ?? '', + arguments: '', + }); + } + const acc = toolCallAccumulators.get(tc.index)!; + if (tc.function?.name) acc.name = tc.function.name; + if (tc.function?.arguments) acc.arguments += tc.function.arguments; + } + } + if (chunk.usage) { usage = { inputTokens: chunk.usage.prompt_tokens, @@ -166,6 +252,18 @@ export class LlamaCppClient implements ModelClient { } } + // Yield completed tool calls before the done event + for (const [, acc] of toolCallAccumulators) { + yield { + type: 'tool_use', + toolCall: { + id: acc.id, + name: acc.name, + args: JSON.parse(acc.arguments), + }, + }; + } + yield { type: 'done', usage }; } catch (error) { yield { diff --git a/src/models/local/ollama.test.ts b/src/models/local/ollama.test.ts index 8a8447e..7c45e32 100644 --- a/src/models/local/ollama.test.ts +++ b/src/models/local/ollama.test.ts @@ -1,23 +1,29 @@ -import { describe, it, expect, vi } from 'vitest'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; import { OllamaClient } from './ollama.js'; +const mockChat = vi.fn(); + vi.mock('ollama', () => ({ Ollama: vi.fn().mockImplementation(() => ({ - chat: 
vi.fn().mockResolvedValue({ - message: { content: 'Hello from Ollama!' }, - done_reason: 'stop', - prompt_eval_count: 10, - eval_count: 5, - }), + chat: mockChat, })), })); describe('OllamaClient', () => { + beforeEach(() => { + mockChat.mockReset(); + }); + it('sends messages and returns response', async () => { - const client = new OllamaClient({ - model: 'llama3.2', + mockChat.mockResolvedValue({ + message: { content: 'Hello from Ollama!' }, + done_reason: 'stop', + prompt_eval_count: 10, + eval_count: 5, }); + const client = new OllamaClient({ model: 'llama3.2' }); + const response = await client.chat({ messages: [{ role: 'user', content: 'Hello' }], }); @@ -27,4 +33,205 @@ describe('OllamaClient', () => { expect(response.usage.inputTokens).toBe(10); expect(response.usage.outputTokens).toBe(5); }); + + it('passes tools to Ollama API in correct format', async () => { + mockChat.mockResolvedValue({ + message: { content: 'I can help with that.' }, + done_reason: 'stop', + prompt_eval_count: 15, + eval_count: 8, + }); + + const client = new OllamaClient({ model: 'llama3.2' }); + + await client.chat({ + messages: [{ role: 'user', content: 'List files' }], + tools: [ + { + name: 'shell.exec', + description: 'Execute a shell command', + input_schema: { + type: 'object', + properties: { + command: { type: 'string', description: 'The command to run' }, + }, + required: ['command'], + }, + }, + ], + }); + + expect(mockChat).toHaveBeenCalledWith( + expect.objectContaining({ + tools: [ + { + type: 'function', + function: { + name: 'shell.exec', + description: 'Execute a shell command', + parameters: { + type: 'object', + required: ['command'], + properties: { + command: { type: 'string', description: 'The command to run' }, + }, + }, + }, + }, + ], + }), + ); + }); + + it('parses tool_calls from response', async () => { + mockChat.mockResolvedValue({ + message: { + content: '', + tool_calls: [ + { function: { name: 'shell.exec', arguments: { command: 'ls' } } }, + ], + }, 
+ done_reason: 'stop', + prompt_eval_count: 12, + eval_count: 6, + }); + + const client = new OllamaClient({ model: 'llama3.2' }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'List files' }], + }); + + expect(response.stopReason).toBe('tool_use'); + expect(response.toolCalls).toHaveLength(1); + expect(response.toolCalls![0]).toEqual({ + id: 'ollama_tc_0', + name: 'shell.exec', + args: { command: 'ls' }, + }); + }); + + it('handles thinking field from reasoning models', async () => { + mockChat.mockResolvedValue({ + message: { content: '', thinking: 'Let me think...' }, + done_reason: 'stop', + prompt_eval_count: 20, + eval_count: 15, + }); + + const client = new OllamaClient({ model: 'deepseek-r1' }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'Solve this problem' }], + }); + + // When content is empty, thinking is used as fallback + expect(response.content).toBe('Let me think...'); + expect(response.thinkingContent).toBe('Let me think...'); + }); + + it('thinking field does not override existing content', async () => { + mockChat.mockResolvedValue({ + message: { content: 'Final answer', thinking: 'Reasoning...' }, + done_reason: 'stop', + prompt_eval_count: 20, + eval_count: 15, + }); + + const client = new OllamaClient({ model: 'deepseek-r1' }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'Solve this problem' }], + }); + + expect(response.content).toBe('Final answer'); + expect(response.thinkingContent).toBe('Reasoning...'); + }); + + it('does not send tools when none provided', async () => { + mockChat.mockResolvedValue({ + message: { content: 'No tools needed.' 
}, + done_reason: 'stop', + prompt_eval_count: 5, + eval_count: 3, + }); + + const client = new OllamaClient({ model: 'llama3.2' }); + + await client.chat({ + messages: [{ role: 'user', content: 'Hello' }], + }); + + const callArgs = mockChat.mock.calls[0][0]; + expect(callArgs.tools).toBeUndefined(); + }); + + it('streaming: yields content events', async () => { + mockChat.mockResolvedValue( + (async function* () { + yield { message: { content: 'Hello' }, done: false }; + yield { message: { content: ' world' }, done: false }; + yield { message: { content: '' }, done: true, prompt_eval_count: 10, eval_count: 5 }; + })(), + ); + + const client = new OllamaClient({ model: 'llama3.2' }); + + const events: Array<{ type: string; content?: string; usage?: { inputTokens: number; outputTokens: number } }> = []; + for await (const event of client.chatStream({ + messages: [{ role: 'user', content: 'Hello' }], + })) { + events.push(event); + } + + expect(events).toHaveLength(3); + expect(events[0]).toEqual({ type: 'content', content: 'Hello' }); + expect(events[1]).toEqual({ type: 'content', content: ' world' }); + expect(events[2]).toEqual({ + type: 'done', + usage: { inputTokens: 10, outputTokens: 5 }, + }); + }); + + it('streaming: yields tool_use events from final chunk', async () => { + mockChat.mockResolvedValue( + (async function* () { + yield { + message: { + content: '', + tool_calls: [ + { function: { name: 'system.info', arguments: {} } }, + ], + }, + done: true, + prompt_eval_count: 5, + eval_count: 3, + }; + })(), + ); + + const client = new OllamaClient({ model: 'llama3.2' }); + + const events: Array<{ type: string; toolCall?: { id: string; name: string; args: unknown }; usage?: { inputTokens: number; outputTokens: number } }> = []; + for await (const event of client.chatStream({ + messages: [{ role: 'user', content: 'Get system info' }], + })) { + events.push(event); + } + + // Should have tool_use event followed by done + expect(events).toHaveLength(2); + 
expect(events[0]).toEqual({ + type: 'tool_use', + toolCall: { + id: 'ollama_tc_0', + name: 'system.info', + args: {}, + }, + }); + expect(events[1]).toEqual({ + type: 'done', + usage: { inputTokens: 5, outputTokens: 3 }, + }); + }); }); diff --git a/src/models/local/ollama.ts b/src/models/local/ollama.ts index bce5cdd..884bd96 100644 --- a/src/models/local/ollama.ts +++ b/src/models/local/ollama.ts @@ -1,5 +1,5 @@ -import { Ollama } from 'ollama'; -import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js'; +import { Ollama, type Tool } from 'ollama'; +import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ToolDefinition, ModelToolCall } from '../types.js'; import { getMessageText } from '../media.js'; export interface OllamaClientConfig { @@ -21,6 +21,24 @@ export class OllamaClient implements ModelClient { this.numGpu = config.numGpu ?? -1; } + /** + * Convert Flynn ToolDefinition[] to Ollama Tool[] format. + */ + private convertTools(tools: ToolDefinition[]): Tool[] { + return tools.map(t => ({ + type: 'function', + function: { + name: t.name, + description: t.description, + parameters: { + type: t.input_schema.type, + required: t.input_schema.required, + properties: t.input_schema.properties as Record, + }, + }, + })); + } + async chat(request: ChatRequest): Promise { const messages: Array<{ role: 'system' | 'user' | 'assistant'; content: string }> = []; @@ -32,21 +50,51 @@ export class OllamaClient implements ModelClient { messages.push({ role: msg.role, content: getMessageText(msg) }); } - const response = await this.client.chat({ + // Build the chat params, optionally including tools + const chatParams: Parameters[0] = { model: this.model, messages, options: { num_gpu: this.numGpu, }, - }); + }; + + if (request.tools && request.tools.length > 0) { + chatParams.tools = this.convertTools(request.tools); + } + + const response = await this.client.chat(chatParams); + + // Extract content, checking for thinking 
field from reasoning models + let content = response.message.content; + let thinkingContent: string | undefined; + const thinking = (response.message as any).thinking; + if (thinking && typeof thinking === 'string') { + if (!content) { + // If no regular content, use thinking as content + content = thinking; + } + thinkingContent = thinking; + } + + // Parse tool_calls from the response + const toolCalls: ModelToolCall[] = response.message.tool_calls?.map((tc, i) => ({ + id: `ollama_tc_${i}`, + name: tc.function.name, + args: tc.function.arguments, + })) ?? []; + + const hasToolCalls = toolCalls.length > 0; return { - content: response.message.content, - stopReason: response.done_reason ?? 'stop', + content, + stopReason: hasToolCalls ? 'tool_use' : (response.done_reason ?? 'stop'), usage: { inputTokens: response.prompt_eval_count ?? 0, outputTokens: response.eval_count ?? 0, }, + ...(hasToolCalls ? { toolCalls } : {}), + ...(thinkingContent ? { thinkingContent } : {}), }; } @@ -62,6 +110,11 @@ export class OllamaClient implements ModelClient { } try { + // Build tools array if provided + const tools = request.tools && request.tools.length > 0 + ? this.convertTools(request.tools) + : undefined; + const stream = await this.client.chat({ model: this.model, messages, @@ -69,6 +122,7 @@ export class OllamaClient implements ModelClient { options: { num_gpu: this.numGpu, }, + ...(tools ? 
{ tools } : {}), }); let inputTokens = 0; @@ -79,6 +133,12 @@ export class OllamaClient implements ModelClient { yield { type: 'content', content: chunk.message.content }; } + // Handle thinking field from reasoning models (e.g., deepseek-r1) + const thinking = (chunk.message as any)?.thinking; + if (thinking && typeof thinking === 'string') { + yield { type: 'content', content: thinking }; + } + if (chunk.prompt_eval_count) { inputTokens = chunk.prompt_eval_count; } @@ -87,6 +147,22 @@ export class OllamaClient implements ModelClient { } if (chunk.done) { + // Handle tool_calls in the final chunk + const toolCalls = (chunk.message as any)?.tool_calls; + if (toolCalls && Array.isArray(toolCalls)) { + for (let i = 0; i < toolCalls.length; i++) { + const tc = toolCalls[i]; + yield { + type: 'tool_use', + toolCall: { + id: `ollama_tc_${i}`, + name: tc.function.name, + args: tc.function.arguments, + }, + }; + } + } + yield { type: 'done', usage: {