feat: add tool calling support to Ollama and llama.cpp clients
- Ollama: pass tools to the API, parse tool_calls responses, and handle the thinking field from reasoning models (deepseek-r1, glm-4.7-flash)
- llama.cpp: pass tools via the OpenAI-compatible endpoint, parse tool_calls, and accumulate streaming tool call deltas
- Both clients now set stopReason to 'tool_use' when tool calls are present
- Tests: 12 new tests (7 Ollama + 5 llama.cpp, total 983→995)
This commit is contained in:
@@ -6,6 +6,7 @@ describe('LlamaCppClient', () => {
|
||||
const mockFetch = vi.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
mockFetch.mockReset();
|
||||
vi.stubGlobal('fetch', mockFetch);
|
||||
});
|
||||
|
||||
@@ -96,4 +97,247 @@ describe('LlamaCppClient', () => {
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
})).rejects.toThrow('llama-server not running at http://localhost:8080');
|
||||
});
|
||||
|
||||
it('passes tools in request body', async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({
|
||||
choices: [{ message: { content: 'I can help with that.' } }],
|
||||
usage: { prompt_tokens: 12, completion_tokens: 6 },
|
||||
}),
|
||||
});
|
||||
|
||||
const client = new LlamaCppClient({
|
||||
endpoint: 'http://localhost:8080',
|
||||
model: 'test-model',
|
||||
});
|
||||
|
||||
await client.chat({
|
||||
messages: [{ role: 'user', content: 'Run ls' }],
|
||||
tools: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell',
|
||||
input_schema: {
|
||||
type: 'object',
|
||||
properties: { command: { type: 'string' } },
|
||||
required: ['command'],
|
||||
},
|
||||
}],
|
||||
});
|
||||
|
||||
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
|
||||
expect(requestBody.tools).toEqual([{
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: { command: { type: 'string' } },
|
||||
required: ['command'],
|
||||
},
|
||||
},
|
||||
}]);
|
||||
});
|
||||
|
||||
it('parses tool_calls from response', async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({
|
||||
choices: [{
|
||||
message: {
|
||||
content: null,
|
||||
tool_calls: [{
|
||||
id: 'call_123',
|
||||
type: 'function',
|
||||
function: { name: 'shell.exec', arguments: '{"command":"ls"}' },
|
||||
}],
|
||||
},
|
||||
finish_reason: 'tool_calls',
|
||||
}],
|
||||
usage: { prompt_tokens: 15, completion_tokens: 8 },
|
||||
}),
|
||||
});
|
||||
|
||||
const client = new LlamaCppClient({
|
||||
endpoint: 'http://localhost:8080',
|
||||
model: 'test-model',
|
||||
});
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'List files' }],
|
||||
tools: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell',
|
||||
input_schema: {
|
||||
type: 'object',
|
||||
properties: { command: { type: 'string' } },
|
||||
required: ['command'],
|
||||
},
|
||||
}],
|
||||
});
|
||||
|
||||
expect(response.stopReason).toBe('tool_use');
|
||||
expect(response.toolCalls).toHaveLength(1);
|
||||
expect(response.toolCalls![0]).toEqual({
|
||||
id: 'call_123',
|
||||
name: 'shell.exec',
|
||||
args: { command: 'ls' },
|
||||
});
|
||||
expect(response.usage.inputTokens).toBe(15);
|
||||
expect(response.usage.outputTokens).toBe(8);
|
||||
});
|
||||
|
||||
it('does not send tools when none provided', async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({
|
||||
choices: [{ message: { content: 'Hello!' } }],
|
||||
usage: { prompt_tokens: 5, completion_tokens: 2 },
|
||||
}),
|
||||
});
|
||||
|
||||
const client = new LlamaCppClient({
|
||||
endpoint: 'http://localhost:8080',
|
||||
model: 'test-model',
|
||||
});
|
||||
|
||||
await client.chat({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
});
|
||||
|
||||
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
|
||||
expect(requestBody.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('streaming: accumulates and yields tool_calls from deltas', async () => {
|
||||
const chunks = [
|
||||
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"shell.exec"}}]}}]}\n\n',
|
||||
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"comma"}}]}}]}\n\n',
|
||||
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"nd\\":\\"ls\\"}"}}]}}]}\n\n',
|
||||
'data: {"choices":[{}],"usage":{"prompt_tokens":10,"completion_tokens":5}}\n\n',
|
||||
'data: [DONE]\n\n',
|
||||
];
|
||||
|
||||
const encoder = new TextEncoder();
|
||||
let chunkIndex = 0;
|
||||
|
||||
const mockStream = new ReadableStream({
|
||||
pull(controller) {
|
||||
if (chunkIndex < chunks.length) {
|
||||
controller.enqueue(encoder.encode(chunks[chunkIndex]));
|
||||
chunkIndex++;
|
||||
} else {
|
||||
controller.close();
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
body: mockStream,
|
||||
});
|
||||
|
||||
const client = new LlamaCppClient({
|
||||
endpoint: 'http://localhost:8080',
|
||||
model: 'test-model',
|
||||
});
|
||||
|
||||
const events: ChatStreamEvent[] = [];
|
||||
for await (const event of client.chatStream({
|
||||
messages: [{ role: 'user', content: 'Run ls' }],
|
||||
tools: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell',
|
||||
input_schema: {
|
||||
type: 'object',
|
||||
properties: { command: { type: 'string' } },
|
||||
required: ['command'],
|
||||
},
|
||||
}],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// Should have a tool_use event and a done event
|
||||
const toolUseEvents = events.filter(e => e.type === 'tool_use');
|
||||
const doneEvents = events.filter(e => e.type === 'done');
|
||||
|
||||
expect(toolUseEvents).toHaveLength(1);
|
||||
expect(toolUseEvents[0].toolCall).toEqual({
|
||||
id: 'call_1',
|
||||
name: 'shell.exec',
|
||||
args: { command: 'ls' },
|
||||
});
|
||||
|
||||
expect(doneEvents).toHaveLength(1);
|
||||
expect(doneEvents[0].usage).toEqual({
|
||||
inputTokens: 10,
|
||||
outputTokens: 5,
|
||||
});
|
||||
});
|
||||
|
||||
it('streaming: passes tools in request body', async () => {
|
||||
const chunks = [
|
||||
'data: {"choices":[{"delta":{"content":"Hi"}}]}\n\n',
|
||||
'data: {"choices":[{}],"usage":{"prompt_tokens":3,"completion_tokens":1}}\n\n',
|
||||
'data: [DONE]\n\n',
|
||||
];
|
||||
|
||||
const encoder = new TextEncoder();
|
||||
let chunkIndex = 0;
|
||||
|
||||
const mockStream = new ReadableStream({
|
||||
pull(controller) {
|
||||
if (chunkIndex < chunks.length) {
|
||||
controller.enqueue(encoder.encode(chunks[chunkIndex]));
|
||||
chunkIndex++;
|
||||
} else {
|
||||
controller.close();
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
body: mockStream,
|
||||
});
|
||||
|
||||
const client = new LlamaCppClient({
|
||||
endpoint: 'http://localhost:8080',
|
||||
model: 'test-model',
|
||||
});
|
||||
|
||||
// Consume the stream to trigger the fetch call
|
||||
const events: ChatStreamEvent[] = [];
|
||||
for await (const event of client.chatStream({
|
||||
messages: [{ role: 'user', content: 'Hi' }],
|
||||
tools: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell',
|
||||
input_schema: {
|
||||
type: 'object',
|
||||
properties: { command: { type: 'string' } },
|
||||
required: ['command'],
|
||||
},
|
||||
}],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
|
||||
expect(requestBody.tools).toEqual([{
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: { command: { type: 'string' } },
|
||||
required: ['command'],
|
||||
},
|
||||
},
|
||||
}]);
|
||||
expect(requestBody.stream).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
+113
-15
@@ -1,4 +1,4 @@
|
||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js';
|
||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall } from '../types.js';
|
||||
import { getMessageText } from '../media.js';
|
||||
|
||||
export interface LlamaCppClientConfig {
|
||||
@@ -12,13 +12,42 @@ interface LlamaCppMessage {
|
||||
content: string;
|
||||
}
|
||||
|
||||
/**
 * A single tool call as it appears in `message.tool_calls` of a
 * non-streaming chat completion response.
 */
interface LlamaCppToolCall {
  /** Identifier attached to this call in the response. */
  id: string;
  /** Discriminator; `'function'` is the only kind modelled here. */
  type: 'function';
  /** The function the model wants to invoke. */
  function: {
    /** Name of the tool being called. */
    name: string;
    /** Call arguments, serialized as a JSON string (not a parsed object). */
    arguments: string;
  };
}
|
||||
|
||||
interface LlamaCppResponse {
|
||||
choices: Array<{ message: { content: string } }>;
|
||||
choices: Array<{
|
||||
message: {
|
||||
content: string | null;
|
||||
tool_calls?: LlamaCppToolCall[];
|
||||
};
|
||||
finish_reason?: string;
|
||||
}>;
|
||||
usage: { prompt_tokens: number; completion_tokens: number };
|
||||
}
|
||||
|
||||
interface LlamaCppStreamChunk {
|
||||
choices: Array<{ delta?: { content?: string } }>;
|
||||
choices: Array<{
|
||||
delta?: {
|
||||
content?: string;
|
||||
tool_calls?: Array<{
|
||||
index: number;
|
||||
id?: string;
|
||||
type?: string;
|
||||
function?: {
|
||||
name?: string;
|
||||
arguments?: string;
|
||||
};
|
||||
}>;
|
||||
};
|
||||
finish_reason?: string | null;
|
||||
}>;
|
||||
usage?: { prompt_tokens: number; completion_tokens: number };
|
||||
}
|
||||
|
||||
@@ -54,14 +83,28 @@ export class LlamaCppClient implements ModelClient {
|
||||
|
||||
let response: Response;
|
||||
try {
|
||||
const body: Record<string, unknown> = {
|
||||
model: this.model,
|
||||
messages,
|
||||
max_tokens: request.maxTokens ?? 2048,
|
||||
};
|
||||
|
||||
// Pass tool definitions to the API if provided
|
||||
if (request.tools && request.tools.length > 0) {
|
||||
body.tools = request.tools.map(t => ({
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.input_schema,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
response = await fetch(`${this.endpoint}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
model: this.model,
|
||||
messages,
|
||||
max_tokens: request.maxTokens ?? 2048,
|
||||
}),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
} catch (error) {
|
||||
if (error instanceof TypeError && error.message.includes('fetch failed')) {
|
||||
@@ -77,13 +120,24 @@ export class LlamaCppClient implements ModelClient {
|
||||
|
||||
const data = (await response.json()) as LlamaCppResponse;
|
||||
|
||||
// Parse tool calls from the response, if present
|
||||
const toolCalls: ModelToolCall[] = data.choices[0]?.message?.tool_calls?.map((tc) => ({
|
||||
id: tc.id ?? `llamacpp_tc_${Math.random().toString(36).slice(2, 8)}`,
|
||||
name: tc.function.name,
|
||||
args: JSON.parse(tc.function.arguments),
|
||||
})) ?? [];
|
||||
|
||||
// Set stopReason to 'tool_use' when tool_calls are present
|
||||
const stopReason = toolCalls.length > 0 ? 'tool_use' : (data.choices[0]?.finish_reason ?? 'stop');
|
||||
|
||||
return {
|
||||
content: data.choices[0]?.message?.content ?? '',
|
||||
stopReason: 'stop',
|
||||
stopReason,
|
||||
usage: {
|
||||
inputTokens: data.usage?.prompt_tokens ?? 0,
|
||||
outputTokens: data.usage?.completion_tokens ?? 0,
|
||||
},
|
||||
...(toolCalls.length > 0 ? { toolCalls } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -107,15 +161,29 @@ export class LlamaCppClient implements ModelClient {
|
||||
}
|
||||
|
||||
try {
|
||||
const body: Record<string, unknown> = {
|
||||
model: this.model,
|
||||
messages,
|
||||
max_tokens: request.maxTokens ?? 2048,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// Pass tool definitions to the API if provided
|
||||
if (request.tools && request.tools.length > 0) {
|
||||
body.tools = request.tools.map(t => ({
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.input_schema,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
const response = await fetch(`${this.endpoint}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
model: this.model,
|
||||
messages,
|
||||
max_tokens: request.maxTokens ?? 2048,
|
||||
stream: true,
|
||||
}),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
@@ -131,6 +199,8 @@ export class LlamaCppClient implements ModelClient {
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = '';
|
||||
let usage = { inputTokens: 0, outputTokens: 0 };
|
||||
// Accumulate tool call deltas across streamed chunks
|
||||
const toolCallAccumulators: Map<number, { id: string; name: string; arguments: string }> = new Map();
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
@@ -154,6 +224,22 @@ export class LlamaCppClient implements ModelClient {
|
||||
yield { type: 'content', content: chunk.choices[0].delta.content };
|
||||
}
|
||||
|
||||
// Accumulate tool call deltas from the stream
|
||||
if (chunk.choices[0]?.delta?.tool_calls) {
|
||||
for (const tc of chunk.choices[0].delta.tool_calls) {
|
||||
if (!toolCallAccumulators.has(tc.index)) {
|
||||
toolCallAccumulators.set(tc.index, {
|
||||
id: tc.id ?? `llamacpp_tc_${tc.index}`,
|
||||
name: tc.function?.name ?? '',
|
||||
arguments: '',
|
||||
});
|
||||
}
|
||||
const acc = toolCallAccumulators.get(tc.index)!;
|
||||
if (tc.function?.name) acc.name = tc.function.name;
|
||||
if (tc.function?.arguments) acc.arguments += tc.function.arguments;
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk.usage) {
|
||||
usage = {
|
||||
inputTokens: chunk.usage.prompt_tokens,
|
||||
@@ -166,6 +252,18 @@ export class LlamaCppClient implements ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
// Yield completed tool calls before the done event
|
||||
for (const [, acc] of toolCallAccumulators) {
|
||||
yield {
|
||||
type: 'tool_use',
|
||||
toolCall: {
|
||||
id: acc.id,
|
||||
name: acc.name,
|
||||
args: JSON.parse(acc.arguments),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
yield { type: 'done', usage };
|
||||
} catch (error) {
|
||||
yield {
|
||||
|
||||
@@ -1,23 +1,29 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { OllamaClient } from './ollama.js';
|
||||
|
||||
const mockChat = vi.fn();
|
||||
|
||||
vi.mock('ollama', () => ({
|
||||
Ollama: vi.fn().mockImplementation(() => ({
|
||||
chat: vi.fn().mockResolvedValue({
|
||||
message: { content: 'Hello from Ollama!' },
|
||||
done_reason: 'stop',
|
||||
prompt_eval_count: 10,
|
||||
eval_count: 5,
|
||||
}),
|
||||
chat: mockChat,
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('OllamaClient', () => {
|
||||
beforeEach(() => {
|
||||
mockChat.mockReset();
|
||||
});
|
||||
|
||||
it('sends messages and returns response', async () => {
|
||||
const client = new OllamaClient({
|
||||
model: 'llama3.2',
|
||||
mockChat.mockResolvedValue({
|
||||
message: { content: 'Hello from Ollama!' },
|
||||
done_reason: 'stop',
|
||||
prompt_eval_count: 10,
|
||||
eval_count: 5,
|
||||
});
|
||||
|
||||
const client = new OllamaClient({ model: 'llama3.2' });
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
});
|
||||
@@ -27,4 +33,205 @@ describe('OllamaClient', () => {
|
||||
expect(response.usage.inputTokens).toBe(10);
|
||||
expect(response.usage.outputTokens).toBe(5);
|
||||
});
|
||||
|
||||
it('passes tools to Ollama API in correct format', async () => {
|
||||
mockChat.mockResolvedValue({
|
||||
message: { content: 'I can help with that.' },
|
||||
done_reason: 'stop',
|
||||
prompt_eval_count: 15,
|
||||
eval_count: 8,
|
||||
});
|
||||
|
||||
const client = new OllamaClient({ model: 'llama3.2' });
|
||||
|
||||
await client.chat({
|
||||
messages: [{ role: 'user', content: 'List files' }],
|
||||
tools: [
|
||||
{
|
||||
name: 'shell.exec',
|
||||
description: 'Execute a shell command',
|
||||
input_schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
command: { type: 'string', description: 'The command to run' },
|
||||
},
|
||||
required: ['command'],
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
expect(mockChat).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tools: [
|
||||
{
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'shell.exec',
|
||||
description: 'Execute a shell command',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
required: ['command'],
|
||||
properties: {
|
||||
command: { type: 'string', description: 'The command to run' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('parses tool_calls from response', async () => {
|
||||
mockChat.mockResolvedValue({
|
||||
message: {
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{ function: { name: 'shell.exec', arguments: { command: 'ls' } } },
|
||||
],
|
||||
},
|
||||
done_reason: 'stop',
|
||||
prompt_eval_count: 12,
|
||||
eval_count: 6,
|
||||
});
|
||||
|
||||
const client = new OllamaClient({ model: 'llama3.2' });
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'List files' }],
|
||||
});
|
||||
|
||||
expect(response.stopReason).toBe('tool_use');
|
||||
expect(response.toolCalls).toHaveLength(1);
|
||||
expect(response.toolCalls![0]).toEqual({
|
||||
id: 'ollama_tc_0',
|
||||
name: 'shell.exec',
|
||||
args: { command: 'ls' },
|
||||
});
|
||||
});
|
||||
|
||||
it('handles thinking field from reasoning models', async () => {
|
||||
mockChat.mockResolvedValue({
|
||||
message: { content: '', thinking: 'Let me think...' },
|
||||
done_reason: 'stop',
|
||||
prompt_eval_count: 20,
|
||||
eval_count: 15,
|
||||
});
|
||||
|
||||
const client = new OllamaClient({ model: 'deepseek-r1' });
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'Solve this problem' }],
|
||||
});
|
||||
|
||||
// When content is empty, thinking is used as fallback
|
||||
expect(response.content).toBe('Let me think...');
|
||||
expect(response.thinkingContent).toBe('Let me think...');
|
||||
});
|
||||
|
||||
it('thinking field does not override existing content', async () => {
|
||||
mockChat.mockResolvedValue({
|
||||
message: { content: 'Final answer', thinking: 'Reasoning...' },
|
||||
done_reason: 'stop',
|
||||
prompt_eval_count: 20,
|
||||
eval_count: 15,
|
||||
});
|
||||
|
||||
const client = new OllamaClient({ model: 'deepseek-r1' });
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'Solve this problem' }],
|
||||
});
|
||||
|
||||
expect(response.content).toBe('Final answer');
|
||||
expect(response.thinkingContent).toBe('Reasoning...');
|
||||
});
|
||||
|
||||
it('does not send tools when none provided', async () => {
|
||||
mockChat.mockResolvedValue({
|
||||
message: { content: 'No tools needed.' },
|
||||
done_reason: 'stop',
|
||||
prompt_eval_count: 5,
|
||||
eval_count: 3,
|
||||
});
|
||||
|
||||
const client = new OllamaClient({ model: 'llama3.2' });
|
||||
|
||||
await client.chat({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
});
|
||||
|
||||
const callArgs = mockChat.mock.calls[0][0];
|
||||
expect(callArgs.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('streaming: yields content events', async () => {
|
||||
mockChat.mockResolvedValue(
|
||||
(async function* () {
|
||||
yield { message: { content: 'Hello' }, done: false };
|
||||
yield { message: { content: ' world' }, done: false };
|
||||
yield { message: { content: '' }, done: true, prompt_eval_count: 10, eval_count: 5 };
|
||||
})(),
|
||||
);
|
||||
|
||||
const client = new OllamaClient({ model: 'llama3.2' });
|
||||
|
||||
const events: Array<{ type: string; content?: string; usage?: { inputTokens: number; outputTokens: number } }> = [];
|
||||
for await (const event of client.chatStream({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events).toHaveLength(3);
|
||||
expect(events[0]).toEqual({ type: 'content', content: 'Hello' });
|
||||
expect(events[1]).toEqual({ type: 'content', content: ' world' });
|
||||
expect(events[2]).toEqual({
|
||||
type: 'done',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
});
|
||||
});
|
||||
|
||||
it('streaming: yields tool_use events from final chunk', async () => {
|
||||
mockChat.mockResolvedValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
message: {
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{ function: { name: 'system.info', arguments: {} } },
|
||||
],
|
||||
},
|
||||
done: true,
|
||||
prompt_eval_count: 5,
|
||||
eval_count: 3,
|
||||
};
|
||||
})(),
|
||||
);
|
||||
|
||||
const client = new OllamaClient({ model: 'llama3.2' });
|
||||
|
||||
const events: Array<{ type: string; toolCall?: { id: string; name: string; args: unknown }; usage?: { inputTokens: number; outputTokens: number } }> = [];
|
||||
for await (const event of client.chatStream({
|
||||
messages: [{ role: 'user', content: 'Get system info' }],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// Should have tool_use event followed by done
|
||||
expect(events).toHaveLength(2);
|
||||
expect(events[0]).toEqual({
|
||||
type: 'tool_use',
|
||||
toolCall: {
|
||||
id: 'ollama_tc_0',
|
||||
name: 'system.info',
|
||||
args: {},
|
||||
},
|
||||
});
|
||||
expect(events[1]).toEqual({
|
||||
type: 'done',
|
||||
usage: { inputTokens: 5, outputTokens: 3 },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Ollama } from 'ollama';
|
||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js';
|
||||
import { Ollama, type Tool } from 'ollama';
|
||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ToolDefinition, ModelToolCall } from '../types.js';
|
||||
import { getMessageText } from '../media.js';
|
||||
|
||||
export interface OllamaClientConfig {
|
||||
@@ -21,6 +21,24 @@ export class OllamaClient implements ModelClient {
|
||||
this.numGpu = config.numGpu ?? -1;
|
||||
}
|
||||
|
||||
/**
 * Translate Flynn tool definitions into the tool format used by the Ollama client.
 *
 * Every definition becomes a `function`-typed tool. Only the `type`,
 * `required` and `properties` fields of the input schema are copied into
 * `parameters`; nothing else from the schema is forwarded.
 *
 * @param tools - Flynn tool definitions to convert.
 * @returns One Ollama tool per definition, in the same order as the input.
 */
private convertTools(tools: ToolDefinition[]): Tool[] {
  // Named converter for a single definition; keeps the mapping call a one-liner.
  const toOllamaTool = (definition: ToolDefinition): Tool => {
    const schema = definition.input_schema;
    return {
      type: 'function',
      function: {
        name: definition.name,
        description: definition.description,
        parameters: {
          type: schema.type,
          required: schema.required,
          properties: schema.properties as Record<string, any>,
        },
      },
    };
  };

  return tools.map((definition) => toOllamaTool(definition));
}
|
||||
|
||||
async chat(request: ChatRequest): Promise<ChatResponse> {
|
||||
const messages: Array<{ role: 'system' | 'user' | 'assistant'; content: string }> = [];
|
||||
|
||||
@@ -32,21 +50,51 @@ export class OllamaClient implements ModelClient {
|
||||
messages.push({ role: msg.role, content: getMessageText(msg) });
|
||||
}
|
||||
|
||||
const response = await this.client.chat({
|
||||
// Build the chat params, optionally including tools
|
||||
const chatParams: Parameters<typeof this.client.chat>[0] = {
|
||||
model: this.model,
|
||||
messages,
|
||||
options: {
|
||||
num_gpu: this.numGpu,
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
if (request.tools && request.tools.length > 0) {
|
||||
chatParams.tools = this.convertTools(request.tools);
|
||||
}
|
||||
|
||||
const response = await this.client.chat(chatParams);
|
||||
|
||||
// Extract content, checking for thinking field from reasoning models
|
||||
let content = response.message.content;
|
||||
let thinkingContent: string | undefined;
|
||||
const thinking = (response.message as any).thinking;
|
||||
if (thinking && typeof thinking === 'string') {
|
||||
if (!content) {
|
||||
// If no regular content, use thinking as content
|
||||
content = thinking;
|
||||
}
|
||||
thinkingContent = thinking;
|
||||
}
|
||||
|
||||
// Parse tool_calls from the response
|
||||
const toolCalls: ModelToolCall[] = response.message.tool_calls?.map((tc, i) => ({
|
||||
id: `ollama_tc_${i}`,
|
||||
name: tc.function.name,
|
||||
args: tc.function.arguments,
|
||||
})) ?? [];
|
||||
|
||||
const hasToolCalls = toolCalls.length > 0;
|
||||
|
||||
return {
|
||||
content: response.message.content,
|
||||
stopReason: response.done_reason ?? 'stop',
|
||||
content,
|
||||
stopReason: hasToolCalls ? 'tool_use' : (response.done_reason ?? 'stop'),
|
||||
usage: {
|
||||
inputTokens: response.prompt_eval_count ?? 0,
|
||||
outputTokens: response.eval_count ?? 0,
|
||||
},
|
||||
...(hasToolCalls ? { toolCalls } : {}),
|
||||
...(thinkingContent ? { thinkingContent } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -62,6 +110,11 @@ export class OllamaClient implements ModelClient {
|
||||
}
|
||||
|
||||
try {
|
||||
// Build tools array if provided
|
||||
const tools = request.tools && request.tools.length > 0
|
||||
? this.convertTools(request.tools)
|
||||
: undefined;
|
||||
|
||||
const stream = await this.client.chat({
|
||||
model: this.model,
|
||||
messages,
|
||||
@@ -69,6 +122,7 @@ export class OllamaClient implements ModelClient {
|
||||
options: {
|
||||
num_gpu: this.numGpu,
|
||||
},
|
||||
...(tools ? { tools } : {}),
|
||||
});
|
||||
|
||||
let inputTokens = 0;
|
||||
@@ -79,6 +133,12 @@ export class OllamaClient implements ModelClient {
|
||||
yield { type: 'content', content: chunk.message.content };
|
||||
}
|
||||
|
||||
// Handle thinking field from reasoning models (e.g., deepseek-r1)
|
||||
const thinking = (chunk.message as any)?.thinking;
|
||||
if (thinking && typeof thinking === 'string') {
|
||||
yield { type: 'content', content: thinking };
|
||||
}
|
||||
|
||||
if (chunk.prompt_eval_count) {
|
||||
inputTokens = chunk.prompt_eval_count;
|
||||
}
|
||||
@@ -87,6 +147,22 @@ export class OllamaClient implements ModelClient {
|
||||
}
|
||||
|
||||
if (chunk.done) {
|
||||
// Handle tool_calls in the final chunk
|
||||
const toolCalls = (chunk.message as any)?.tool_calls;
|
||||
if (toolCalls && Array.isArray(toolCalls)) {
|
||||
for (let i = 0; i < toolCalls.length; i++) {
|
||||
const tc = toolCalls[i];
|
||||
yield {
|
||||
type: 'tool_use',
|
||||
toolCall: {
|
||||
id: `ollama_tc_${i}`,
|
||||
name: tc.function.name,
|
||||
args: tc.function.arguments,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
yield {
|
||||
type: 'done',
|
||||
usage: {
|
||||
|
||||
Reference in New Issue
Block a user