feat: add tool calling support to Ollama and llama.cpp clients

- Ollama: pass tools to API, parse tool_calls responses, handle thinking field from reasoning models (deepseek-r1, glm-4.7-flash)
- llama.cpp: pass tools via OpenAI-compatible endpoint, parse tool_calls, accumulate streaming tool call deltas
- Both clients now set stopReason to 'tool_use' when tool calls are present
- Tests: 12 new tests (7 Ollama + 5 llama.cpp, total 983→995)
This commit is contained in:
William Valentin
2026-02-07 17:20:27 -08:00
parent fcbab1e1ee
commit fb20acfbcd
4 changed files with 655 additions and 30 deletions
+244
View File
@@ -6,6 +6,7 @@ describe('LlamaCppClient', () => {
const mockFetch = vi.fn();
beforeEach(() => {
mockFetch.mockReset();
vi.stubGlobal('fetch', mockFetch);
});
@@ -96,4 +97,247 @@ describe('LlamaCppClient', () => {
messages: [{ role: 'user', content: 'Hello' }],
})).rejects.toThrow('llama-server not running at http://localhost:8080');
});
it('passes tools in request body', async () => {
mockFetch.mockResolvedValue({
ok: true,
json: () => Promise.resolve({
choices: [{ message: { content: 'I can help with that.' } }],
usage: { prompt_tokens: 12, completion_tokens: 6 },
}),
});
const client = new LlamaCppClient({
endpoint: 'http://localhost:8080',
model: 'test-model',
});
await client.chat({
messages: [{ role: 'user', content: 'Run ls' }],
tools: [{
name: 'shell.exec',
description: 'Run shell',
input_schema: {
type: 'object',
properties: { command: { type: 'string' } },
required: ['command'],
},
}],
});
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
expect(requestBody.tools).toEqual([{
type: 'function',
function: {
name: 'shell.exec',
description: 'Run shell',
parameters: {
type: 'object',
properties: { command: { type: 'string' } },
required: ['command'],
},
},
}]);
});
it('parses tool_calls from response', async () => {
mockFetch.mockResolvedValue({
ok: true,
json: () => Promise.resolve({
choices: [{
message: {
content: null,
tool_calls: [{
id: 'call_123',
type: 'function',
function: { name: 'shell.exec', arguments: '{"command":"ls"}' },
}],
},
finish_reason: 'tool_calls',
}],
usage: { prompt_tokens: 15, completion_tokens: 8 },
}),
});
const client = new LlamaCppClient({
endpoint: 'http://localhost:8080',
model: 'test-model',
});
const response = await client.chat({
messages: [{ role: 'user', content: 'List files' }],
tools: [{
name: 'shell.exec',
description: 'Run shell',
input_schema: {
type: 'object',
properties: { command: { type: 'string' } },
required: ['command'],
},
}],
});
expect(response.stopReason).toBe('tool_use');
expect(response.toolCalls).toHaveLength(1);
expect(response.toolCalls![0]).toEqual({
id: 'call_123',
name: 'shell.exec',
args: { command: 'ls' },
});
expect(response.usage.inputTokens).toBe(15);
expect(response.usage.outputTokens).toBe(8);
});
it('does not send tools when none provided', async () => {
mockFetch.mockResolvedValue({
ok: true,
json: () => Promise.resolve({
choices: [{ message: { content: 'Hello!' } }],
usage: { prompt_tokens: 5, completion_tokens: 2 },
}),
});
const client = new LlamaCppClient({
endpoint: 'http://localhost:8080',
model: 'test-model',
});
await client.chat({
messages: [{ role: 'user', content: 'Hello' }],
});
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
expect(requestBody.tools).toBeUndefined();
});
it('streaming: accumulates and yields tool_calls from deltas', async () => {
const chunks = [
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"shell.exec"}}]}}]}\n\n',
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"comma"}}]}}]}\n\n',
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"nd\\":\\"ls\\"}"}}]}}]}\n\n',
'data: {"choices":[{}],"usage":{"prompt_tokens":10,"completion_tokens":5}}\n\n',
'data: [DONE]\n\n',
];
const encoder = new TextEncoder();
let chunkIndex = 0;
const mockStream = new ReadableStream({
pull(controller) {
if (chunkIndex < chunks.length) {
controller.enqueue(encoder.encode(chunks[chunkIndex]));
chunkIndex++;
} else {
controller.close();
}
},
});
mockFetch.mockResolvedValue({
ok: true,
body: mockStream,
});
const client = new LlamaCppClient({
endpoint: 'http://localhost:8080',
model: 'test-model',
});
const events: ChatStreamEvent[] = [];
for await (const event of client.chatStream({
messages: [{ role: 'user', content: 'Run ls' }],
tools: [{
name: 'shell.exec',
description: 'Run shell',
input_schema: {
type: 'object',
properties: { command: { type: 'string' } },
required: ['command'],
},
}],
})) {
events.push(event);
}
// Should have a tool_use event and a done event
const toolUseEvents = events.filter(e => e.type === 'tool_use');
const doneEvents = events.filter(e => e.type === 'done');
expect(toolUseEvents).toHaveLength(1);
expect(toolUseEvents[0].toolCall).toEqual({
id: 'call_1',
name: 'shell.exec',
args: { command: 'ls' },
});
expect(doneEvents).toHaveLength(1);
expect(doneEvents[0].usage).toEqual({
inputTokens: 10,
outputTokens: 5,
});
});
it('streaming: passes tools in request body', async () => {
const chunks = [
'data: {"choices":[{"delta":{"content":"Hi"}}]}\n\n',
'data: {"choices":[{}],"usage":{"prompt_tokens":3,"completion_tokens":1}}\n\n',
'data: [DONE]\n\n',
];
const encoder = new TextEncoder();
let chunkIndex = 0;
const mockStream = new ReadableStream({
pull(controller) {
if (chunkIndex < chunks.length) {
controller.enqueue(encoder.encode(chunks[chunkIndex]));
chunkIndex++;
} else {
controller.close();
}
},
});
mockFetch.mockResolvedValue({
ok: true,
body: mockStream,
});
const client = new LlamaCppClient({
endpoint: 'http://localhost:8080',
model: 'test-model',
});
// Consume the stream to trigger the fetch call
const events: ChatStreamEvent[] = [];
for await (const event of client.chatStream({
messages: [{ role: 'user', content: 'Hi' }],
tools: [{
name: 'shell.exec',
description: 'Run shell',
input_schema: {
type: 'object',
properties: { command: { type: 'string' } },
required: ['command'],
},
}],
})) {
events.push(event);
}
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
expect(requestBody.tools).toEqual([{
type: 'function',
function: {
name: 'shell.exec',
description: 'Run shell',
parameters: {
type: 'object',
properties: { command: { type: 'string' } },
required: ['command'],
},
},
}]);
expect(requestBody.stream).toBe(true);
});
});
+113 -15
View File
@@ -1,4 +1,4 @@
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js';
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall } from '../types.js';
import { getMessageText } from '../media.js';
export interface LlamaCppClientConfig {
@@ -12,13 +12,42 @@ interface LlamaCppMessage {
content: string;
}
/**
 * A single tool call as it appears in `message.tool_calls` of a
 * chat-completions response.
 */
interface LlamaCppToolCall {
  /**
   * Call identifier. Optional because the response parser substitutes a
   * generated id when the server omits it.
   */
  id?: string;
  type: 'function';
  function: {
    /** Name of the tool the model wants to invoke. */
    name: string;
    /** Tool arguments encoded as a JSON string; must be parsed before use. */
    arguments: string;
  };
}
interface LlamaCppResponse {
choices: Array<{ message: { content: string } }>;
choices: Array<{
message: {
content: string | null;
tool_calls?: LlamaCppToolCall[];
};
finish_reason?: string;
}>;
usage: { prompt_tokens: number; completion_tokens: number };
}
interface LlamaCppStreamChunk {
choices: Array<{ delta?: { content?: string } }>;
choices: Array<{
delta?: {
content?: string;
tool_calls?: Array<{
index: number;
id?: string;
type?: string;
function?: {
name?: string;
arguments?: string;
};
}>;
};
finish_reason?: string | null;
}>;
usage?: { prompt_tokens: number; completion_tokens: number };
}
@@ -54,14 +83,28 @@ export class LlamaCppClient implements ModelClient {
let response: Response;
try {
const body: Record<string, unknown> = {
model: this.model,
messages,
max_tokens: request.maxTokens ?? 2048,
};
// Pass tool definitions to the API if provided
if (request.tools && request.tools.length > 0) {
body.tools = request.tools.map(t => ({
type: 'function' as const,
function: {
name: t.name,
description: t.description,
parameters: t.input_schema,
},
}));
}
response = await fetch(`${this.endpoint}/v1/chat/completions`, {
method: 'POST',
headers,
body: JSON.stringify({
model: this.model,
messages,
max_tokens: request.maxTokens ?? 2048,
}),
body: JSON.stringify(body),
});
} catch (error) {
if (error instanceof TypeError && error.message.includes('fetch failed')) {
@@ -77,13 +120,24 @@ export class LlamaCppClient implements ModelClient {
const data = (await response.json()) as LlamaCppResponse;
// Parse tool calls from the response, if present.
// An empty or blank `arguments` string (a tool invoked without arguments) is
// mapped to `{}` because JSON.parse('') throws; malformed JSON still throws.
const toolCalls: ModelToolCall[] = data.choices[0]?.message?.tool_calls?.map((tc) => {
  const rawArgs = tc.function.arguments?.trim();
  return {
    id: tc.id ?? `llamacpp_tc_${Math.random().toString(36).slice(2, 8)}`,
    name: tc.function.name,
    args: rawArgs ? JSON.parse(rawArgs) : {},
  };
}) ?? [];
// Set stopReason to 'tool_use' when tool_calls are present
const stopReason = toolCalls.length > 0 ? 'tool_use' : (data.choices[0]?.finish_reason ?? 'stop');
return {
content: data.choices[0]?.message?.content ?? '',
stopReason: 'stop',
stopReason,
usage: {
inputTokens: data.usage?.prompt_tokens ?? 0,
outputTokens: data.usage?.completion_tokens ?? 0,
},
...(toolCalls.length > 0 ? { toolCalls } : {}),
};
}
@@ -107,15 +161,29 @@ export class LlamaCppClient implements ModelClient {
}
try {
const body: Record<string, unknown> = {
model: this.model,
messages,
max_tokens: request.maxTokens ?? 2048,
stream: true,
};
// Pass tool definitions to the API if provided
if (request.tools && request.tools.length > 0) {
body.tools = request.tools.map(t => ({
type: 'function' as const,
function: {
name: t.name,
description: t.description,
parameters: t.input_schema,
},
}));
}
const response = await fetch(`${this.endpoint}/v1/chat/completions`, {
method: 'POST',
headers,
body: JSON.stringify({
model: this.model,
messages,
max_tokens: request.maxTokens ?? 2048,
stream: true,
}),
body: JSON.stringify(body),
});
if (!response.ok) {
@@ -131,6 +199,8 @@ export class LlamaCppClient implements ModelClient {
const decoder = new TextDecoder();
let buffer = '';
let usage = { inputTokens: 0, outputTokens: 0 };
// Accumulate tool call deltas across streamed chunks
const toolCallAccumulators: Map<number, { id: string; name: string; arguments: string }> = new Map();
while (true) {
const { done, value } = await reader.read();
@@ -154,6 +224,22 @@ export class LlamaCppClient implements ModelClient {
yield { type: 'content', content: chunk.choices[0].delta.content };
}
// Accumulate tool call deltas from the stream
if (chunk.choices[0]?.delta?.tool_calls) {
for (const tc of chunk.choices[0].delta.tool_calls) {
if (!toolCallAccumulators.has(tc.index)) {
toolCallAccumulators.set(tc.index, {
id: tc.id ?? `llamacpp_tc_${tc.index}`,
name: tc.function?.name ?? '',
arguments: '',
});
}
const acc = toolCallAccumulators.get(tc.index)!;
if (tc.function?.name) acc.name = tc.function.name;
if (tc.function?.arguments) acc.arguments += tc.function.arguments;
}
}
if (chunk.usage) {
usage = {
inputTokens: chunk.usage.prompt_tokens,
@@ -166,6 +252,18 @@ export class LlamaCppClient implements ModelClient {
}
}
// Yield completed tool calls before the done event.
for (const acc of toolCallAccumulators.values()) {
  // An accumulator starts with an empty `arguments` string and only grows
  // when a delta carries argument text, so a call without arguments stays
  // empty. JSON.parse('') throws, which would turn the whole stream into an
  // error event; treat empty/blank as an empty argument object instead.
  const rawArgs = acc.arguments.trim();
  yield {
    type: 'tool_use',
    toolCall: {
      id: acc.id,
      name: acc.name,
      args: rawArgs ? JSON.parse(rawArgs) : {},
    },
  };
}
yield { type: 'done', usage };
} catch (error) {
yield {
+216 -9
View File
@@ -1,23 +1,29 @@
import { describe, it, expect, vi } from 'vitest';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { OllamaClient } from './ollama.js';
const mockChat = vi.fn();
vi.mock('ollama', () => ({
Ollama: vi.fn().mockImplementation(() => ({
chat: vi.fn().mockResolvedValue({
message: { content: 'Hello from Ollama!' },
done_reason: 'stop',
prompt_eval_count: 10,
eval_count: 5,
}),
chat: mockChat,
})),
}));
describe('OllamaClient', () => {
beforeEach(() => {
mockChat.mockReset();
});
it('sends messages and returns response', async () => {
const client = new OllamaClient({
model: 'llama3.2',
mockChat.mockResolvedValue({
message: { content: 'Hello from Ollama!' },
done_reason: 'stop',
prompt_eval_count: 10,
eval_count: 5,
});
const client = new OllamaClient({ model: 'llama3.2' });
const response = await client.chat({
messages: [{ role: 'user', content: 'Hello' }],
});
@@ -27,4 +33,205 @@ describe('OllamaClient', () => {
expect(response.usage.inputTokens).toBe(10);
expect(response.usage.outputTokens).toBe(5);
});
it('passes tools to Ollama API in correct format', async () => {
mockChat.mockResolvedValue({
message: { content: 'I can help with that.' },
done_reason: 'stop',
prompt_eval_count: 15,
eval_count: 8,
});
const client = new OllamaClient({ model: 'llama3.2' });
await client.chat({
messages: [{ role: 'user', content: 'List files' }],
tools: [
{
name: 'shell.exec',
description: 'Execute a shell command',
input_schema: {
type: 'object',
properties: {
command: { type: 'string', description: 'The command to run' },
},
required: ['command'],
},
},
],
});
expect(mockChat).toHaveBeenCalledWith(
expect.objectContaining({
tools: [
{
type: 'function',
function: {
name: 'shell.exec',
description: 'Execute a shell command',
parameters: {
type: 'object',
required: ['command'],
properties: {
command: { type: 'string', description: 'The command to run' },
},
},
},
},
],
}),
);
});
it('parses tool_calls from response', async () => {
mockChat.mockResolvedValue({
message: {
content: '',
tool_calls: [
{ function: { name: 'shell.exec', arguments: { command: 'ls' } } },
],
},
done_reason: 'stop',
prompt_eval_count: 12,
eval_count: 6,
});
const client = new OllamaClient({ model: 'llama3.2' });
const response = await client.chat({
messages: [{ role: 'user', content: 'List files' }],
});
expect(response.stopReason).toBe('tool_use');
expect(response.toolCalls).toHaveLength(1);
expect(response.toolCalls![0]).toEqual({
id: 'ollama_tc_0',
name: 'shell.exec',
args: { command: 'ls' },
});
});
it('handles thinking field from reasoning models', async () => {
mockChat.mockResolvedValue({
message: { content: '', thinking: 'Let me think...' },
done_reason: 'stop',
prompt_eval_count: 20,
eval_count: 15,
});
const client = new OllamaClient({ model: 'deepseek-r1' });
const response = await client.chat({
messages: [{ role: 'user', content: 'Solve this problem' }],
});
// When content is empty, thinking is used as fallback
expect(response.content).toBe('Let me think...');
expect(response.thinkingContent).toBe('Let me think...');
});
it('thinking field does not override existing content', async () => {
mockChat.mockResolvedValue({
message: { content: 'Final answer', thinking: 'Reasoning...' },
done_reason: 'stop',
prompt_eval_count: 20,
eval_count: 15,
});
const client = new OllamaClient({ model: 'deepseek-r1' });
const response = await client.chat({
messages: [{ role: 'user', content: 'Solve this problem' }],
});
expect(response.content).toBe('Final answer');
expect(response.thinkingContent).toBe('Reasoning...');
});
it('does not send tools when none provided', async () => {
mockChat.mockResolvedValue({
message: { content: 'No tools needed.' },
done_reason: 'stop',
prompt_eval_count: 5,
eval_count: 3,
});
const client = new OllamaClient({ model: 'llama3.2' });
await client.chat({
messages: [{ role: 'user', content: 'Hello' }],
});
const callArgs = mockChat.mock.calls[0][0];
expect(callArgs.tools).toBeUndefined();
});
it('streaming: yields content events', async () => {
mockChat.mockResolvedValue(
(async function* () {
yield { message: { content: 'Hello' }, done: false };
yield { message: { content: ' world' }, done: false };
yield { message: { content: '' }, done: true, prompt_eval_count: 10, eval_count: 5 };
})(),
);
const client = new OllamaClient({ model: 'llama3.2' });
const events: Array<{ type: string; content?: string; usage?: { inputTokens: number; outputTokens: number } }> = [];
for await (const event of client.chatStream({
messages: [{ role: 'user', content: 'Hello' }],
})) {
events.push(event);
}
expect(events).toHaveLength(3);
expect(events[0]).toEqual({ type: 'content', content: 'Hello' });
expect(events[1]).toEqual({ type: 'content', content: ' world' });
expect(events[2]).toEqual({
type: 'done',
usage: { inputTokens: 10, outputTokens: 5 },
});
});
it('streaming: yields tool_use events from final chunk', async () => {
mockChat.mockResolvedValue(
(async function* () {
yield {
message: {
content: '',
tool_calls: [
{ function: { name: 'system.info', arguments: {} } },
],
},
done: true,
prompt_eval_count: 5,
eval_count: 3,
};
})(),
);
const client = new OllamaClient({ model: 'llama3.2' });
const events: Array<{ type: string; toolCall?: { id: string; name: string; args: unknown }; usage?: { inputTokens: number; outputTokens: number } }> = [];
for await (const event of client.chatStream({
messages: [{ role: 'user', content: 'Get system info' }],
})) {
events.push(event);
}
// Should have tool_use event followed by done
expect(events).toHaveLength(2);
expect(events[0]).toEqual({
type: 'tool_use',
toolCall: {
id: 'ollama_tc_0',
name: 'system.info',
args: {},
},
});
expect(events[1]).toEqual({
type: 'done',
usage: { inputTokens: 5, outputTokens: 3 },
});
});
});
+82 -6
View File
@@ -1,5 +1,5 @@
import { Ollama } from 'ollama';
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js';
import { Ollama, type Tool } from 'ollama';
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ToolDefinition, ModelToolCall } from '../types.js';
import { getMessageText } from '../media.js';
export interface OllamaClientConfig {
@@ -21,6 +21,24 @@ export class OllamaClient implements ModelClient {
this.numGpu = config.numGpu ?? -1;
}
/**
 * Translate Flynn tool definitions into the `Tool` shape expected by the
 * Ollama client.
 *
 * Each definition becomes a `function`-typed tool whose `parameters` carry
 * the `type`, `required` and `properties` fields of its `input_schema`.
 *
 * @param tools Flynn tool definitions to convert.
 * @returns One Ollama tool per input definition, in the same order.
 */
private convertTools(tools: ToolDefinition[]): Tool[] {
  const converted: Tool[] = [];
  for (const { name, description, input_schema: schema } of tools) {
    converted.push({
      type: 'function',
      function: {
        name,
        description,
        parameters: {
          type: schema.type,
          required: schema.required,
          properties: schema.properties as Record<string, any>,
        },
      },
    });
  }
  return converted;
}
async chat(request: ChatRequest): Promise<ChatResponse> {
const messages: Array<{ role: 'system' | 'user' | 'assistant'; content: string }> = [];
@@ -32,21 +50,51 @@ export class OllamaClient implements ModelClient {
messages.push({ role: msg.role, content: getMessageText(msg) });
}
const response = await this.client.chat({
// Build the chat params, optionally including tools
const chatParams: Parameters<typeof this.client.chat>[0] = {
model: this.model,
messages,
options: {
num_gpu: this.numGpu,
},
});
};
if (request.tools && request.tools.length > 0) {
chatParams.tools = this.convertTools(request.tools);
}
const response = await this.client.chat(chatParams);
// Extract content, checking for thinking field from reasoning models
let content = response.message.content;
let thinkingContent: string | undefined;
const thinking = (response.message as any).thinking;
if (thinking && typeof thinking === 'string') {
if (!content) {
// If no regular content, use thinking as content
content = thinking;
}
thinkingContent = thinking;
}
// Parse tool_calls from the response
const toolCalls: ModelToolCall[] = response.message.tool_calls?.map((tc, i) => ({
id: `ollama_tc_${i}`,
name: tc.function.name,
args: tc.function.arguments,
})) ?? [];
const hasToolCalls = toolCalls.length > 0;
return {
content: response.message.content,
stopReason: response.done_reason ?? 'stop',
content,
stopReason: hasToolCalls ? 'tool_use' : (response.done_reason ?? 'stop'),
usage: {
inputTokens: response.prompt_eval_count ?? 0,
outputTokens: response.eval_count ?? 0,
},
...(hasToolCalls ? { toolCalls } : {}),
...(thinkingContent ? { thinkingContent } : {}),
};
}
@@ -62,6 +110,11 @@ export class OllamaClient implements ModelClient {
}
try {
// Build tools array if provided
const tools = request.tools && request.tools.length > 0
? this.convertTools(request.tools)
: undefined;
const stream = await this.client.chat({
model: this.model,
messages,
@@ -69,6 +122,7 @@ export class OllamaClient implements ModelClient {
options: {
num_gpu: this.numGpu,
},
...(tools ? { tools } : {}),
});
let inputTokens = 0;
@@ -79,6 +133,12 @@ export class OllamaClient implements ModelClient {
yield { type: 'content', content: chunk.message.content };
}
// Handle thinking field from reasoning models (e.g., deepseek-r1)
const thinking = (chunk.message as any)?.thinking;
if (thinking && typeof thinking === 'string') {
yield { type: 'content', content: thinking };
}
if (chunk.prompt_eval_count) {
inputTokens = chunk.prompt_eval_count;
}
@@ -87,6 +147,22 @@ export class OllamaClient implements ModelClient {
}
if (chunk.done) {
// Handle tool_calls in the final chunk.
// `message.tool_calls` is already part of the message type (the
// non-streaming path reads it without a cast), so no `any` cast is needed.
const toolCalls = chunk.message?.tool_calls;
if (toolCalls && Array.isArray(toolCalls)) {
  for (const [i, tc] of toolCalls.entries()) {
    yield {
      type: 'tool_use',
      toolCall: {
        // Ids are synthesised from the call's position in the array.
        id: `ollama_tc_${i}`,
        name: tc.function.name,
        args: tc.function.arguments,
      },
    };
  }
}
yield {
type: 'done',
usage: {