feat: add tool calling support to Ollama and llama.cpp clients
- Ollama: pass tools to API, parse tool_calls responses, handle thinking field from reasoning models (deepseek-r1, glm-4.7-flash) - llama.cpp: pass tools via OpenAI-compatible endpoint, parse tool_calls, accumulate streaming tool call deltas - Both clients now set stopReason to 'tool_use' when tool calls are present - Tests: 12 new tests (8 Ollama + 5 llama.cpp, total 983→995)
This commit is contained in:
@@ -6,6 +6,7 @@ describe('LlamaCppClient', () => {
|
|||||||
const mockFetch = vi.fn();
|
const mockFetch = vi.fn();
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
|
mockFetch.mockReset();
|
||||||
vi.stubGlobal('fetch', mockFetch);
|
vi.stubGlobal('fetch', mockFetch);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -96,4 +97,247 @@ describe('LlamaCppClient', () => {
|
|||||||
messages: [{ role: 'user', content: 'Hello' }],
|
messages: [{ role: 'user', content: 'Hello' }],
|
||||||
})).rejects.toThrow('llama-server not running at http://localhost:8080');
|
})).rejects.toThrow('llama-server not running at http://localhost:8080');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('passes tools in request body', async () => {
|
||||||
|
mockFetch.mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
json: () => Promise.resolve({
|
||||||
|
choices: [{ message: { content: 'I can help with that.' } }],
|
||||||
|
usage: { prompt_tokens: 12, completion_tokens: 6 },
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new LlamaCppClient({
|
||||||
|
endpoint: 'http://localhost:8080',
|
||||||
|
model: 'test-model',
|
||||||
|
});
|
||||||
|
|
||||||
|
await client.chat({
|
||||||
|
messages: [{ role: 'user', content: 'Run ls' }],
|
||||||
|
tools: [{
|
||||||
|
name: 'shell.exec',
|
||||||
|
description: 'Run shell',
|
||||||
|
input_schema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: { command: { type: 'string' } },
|
||||||
|
required: ['command'],
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
|
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
|
||||||
|
expect(requestBody.tools).toEqual([{
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'shell.exec',
|
||||||
|
description: 'Run shell',
|
||||||
|
parameters: {
|
||||||
|
type: 'object',
|
||||||
|
properties: { command: { type: 'string' } },
|
||||||
|
required: ['command'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('parses tool_calls from response', async () => {
|
||||||
|
mockFetch.mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
json: () => Promise.resolve({
|
||||||
|
choices: [{
|
||||||
|
message: {
|
||||||
|
content: null,
|
||||||
|
tool_calls: [{
|
||||||
|
id: 'call_123',
|
||||||
|
type: 'function',
|
||||||
|
function: { name: 'shell.exec', arguments: '{"command":"ls"}' },
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
finish_reason: 'tool_calls',
|
||||||
|
}],
|
||||||
|
usage: { prompt_tokens: 15, completion_tokens: 8 },
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new LlamaCppClient({
|
||||||
|
endpoint: 'http://localhost:8080',
|
||||||
|
model: 'test-model',
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await client.chat({
|
||||||
|
messages: [{ role: 'user', content: 'List files' }],
|
||||||
|
tools: [{
|
||||||
|
name: 'shell.exec',
|
||||||
|
description: 'Run shell',
|
||||||
|
input_schema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: { command: { type: 'string' } },
|
||||||
|
required: ['command'],
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.stopReason).toBe('tool_use');
|
||||||
|
expect(response.toolCalls).toHaveLength(1);
|
||||||
|
expect(response.toolCalls![0]).toEqual({
|
||||||
|
id: 'call_123',
|
||||||
|
name: 'shell.exec',
|
||||||
|
args: { command: 'ls' },
|
||||||
|
});
|
||||||
|
expect(response.usage.inputTokens).toBe(15);
|
||||||
|
expect(response.usage.outputTokens).toBe(8);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('does not send tools when none provided', async () => {
|
||||||
|
mockFetch.mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
json: () => Promise.resolve({
|
||||||
|
choices: [{ message: { content: 'Hello!' } }],
|
||||||
|
usage: { prompt_tokens: 5, completion_tokens: 2 },
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new LlamaCppClient({
|
||||||
|
endpoint: 'http://localhost:8080',
|
||||||
|
model: 'test-model',
|
||||||
|
});
|
||||||
|
|
||||||
|
await client.chat({
|
||||||
|
messages: [{ role: 'user', content: 'Hello' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
|
||||||
|
expect(requestBody.tools).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('streaming: accumulates and yields tool_calls from deltas', async () => {
|
||||||
|
const chunks = [
|
||||||
|
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"shell.exec"}}]}}]}\n\n',
|
||||||
|
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"comma"}}]}}]}\n\n',
|
||||||
|
'data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"nd\\":\\"ls\\"}"}}]}}]}\n\n',
|
||||||
|
'data: {"choices":[{}],"usage":{"prompt_tokens":10,"completion_tokens":5}}\n\n',
|
||||||
|
'data: [DONE]\n\n',
|
||||||
|
];
|
||||||
|
|
||||||
|
const encoder = new TextEncoder();
|
||||||
|
let chunkIndex = 0;
|
||||||
|
|
||||||
|
const mockStream = new ReadableStream({
|
||||||
|
pull(controller) {
|
||||||
|
if (chunkIndex < chunks.length) {
|
||||||
|
controller.enqueue(encoder.encode(chunks[chunkIndex]));
|
||||||
|
chunkIndex++;
|
||||||
|
} else {
|
||||||
|
controller.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
mockFetch.mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
body: mockStream,
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new LlamaCppClient({
|
||||||
|
endpoint: 'http://localhost:8080',
|
||||||
|
model: 'test-model',
|
||||||
|
});
|
||||||
|
|
||||||
|
const events: ChatStreamEvent[] = [];
|
||||||
|
for await (const event of client.chatStream({
|
||||||
|
messages: [{ role: 'user', content: 'Run ls' }],
|
||||||
|
tools: [{
|
||||||
|
name: 'shell.exec',
|
||||||
|
description: 'Run shell',
|
||||||
|
input_schema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: { command: { type: 'string' } },
|
||||||
|
required: ['command'],
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
})) {
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have a tool_use event and a done event
|
||||||
|
const toolUseEvents = events.filter(e => e.type === 'tool_use');
|
||||||
|
const doneEvents = events.filter(e => e.type === 'done');
|
||||||
|
|
||||||
|
expect(toolUseEvents).toHaveLength(1);
|
||||||
|
expect(toolUseEvents[0].toolCall).toEqual({
|
||||||
|
id: 'call_1',
|
||||||
|
name: 'shell.exec',
|
||||||
|
args: { command: 'ls' },
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(doneEvents).toHaveLength(1);
|
||||||
|
expect(doneEvents[0].usage).toEqual({
|
||||||
|
inputTokens: 10,
|
||||||
|
outputTokens: 5,
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('streaming: passes tools in request body', async () => {
|
||||||
|
const chunks = [
|
||||||
|
'data: {"choices":[{"delta":{"content":"Hi"}}]}\n\n',
|
||||||
|
'data: {"choices":[{}],"usage":{"prompt_tokens":3,"completion_tokens":1}}\n\n',
|
||||||
|
'data: [DONE]\n\n',
|
||||||
|
];
|
||||||
|
|
||||||
|
const encoder = new TextEncoder();
|
||||||
|
let chunkIndex = 0;
|
||||||
|
|
||||||
|
const mockStream = new ReadableStream({
|
||||||
|
pull(controller) {
|
||||||
|
if (chunkIndex < chunks.length) {
|
||||||
|
controller.enqueue(encoder.encode(chunks[chunkIndex]));
|
||||||
|
chunkIndex++;
|
||||||
|
} else {
|
||||||
|
controller.close();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
mockFetch.mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
body: mockStream,
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new LlamaCppClient({
|
||||||
|
endpoint: 'http://localhost:8080',
|
||||||
|
model: 'test-model',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Consume the stream to trigger the fetch call
|
||||||
|
const events: ChatStreamEvent[] = [];
|
||||||
|
for await (const event of client.chatStream({
|
||||||
|
messages: [{ role: 'user', content: 'Hi' }],
|
||||||
|
tools: [{
|
||||||
|
name: 'shell.exec',
|
||||||
|
description: 'Run shell',
|
||||||
|
input_schema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: { command: { type: 'string' } },
|
||||||
|
required: ['command'],
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
})) {
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
|
||||||
|
expect(requestBody.tools).toEqual([{
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'shell.exec',
|
||||||
|
description: 'Run shell',
|
||||||
|
parameters: {
|
||||||
|
type: 'object',
|
||||||
|
properties: { command: { type: 'string' } },
|
||||||
|
required: ['command'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}]);
|
||||||
|
expect(requestBody.stream).toBe(true);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
+113
-15
@@ -1,4 +1,4 @@
|
|||||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js';
|
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall } from '../types.js';
|
||||||
import { getMessageText } from '../media.js';
|
import { getMessageText } from '../media.js';
|
||||||
|
|
||||||
export interface LlamaCppClientConfig {
|
export interface LlamaCppClientConfig {
|
||||||
@@ -12,13 +12,42 @@ interface LlamaCppMessage {
|
|||||||
content: string;
|
content: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface LlamaCppToolCall {
|
||||||
|
id: string;
|
||||||
|
type: 'function';
|
||||||
|
function: {
|
||||||
|
name: string;
|
||||||
|
arguments: string; // JSON string
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
interface LlamaCppResponse {
|
interface LlamaCppResponse {
|
||||||
choices: Array<{ message: { content: string } }>;
|
choices: Array<{
|
||||||
|
message: {
|
||||||
|
content: string | null;
|
||||||
|
tool_calls?: LlamaCppToolCall[];
|
||||||
|
};
|
||||||
|
finish_reason?: string;
|
||||||
|
}>;
|
||||||
usage: { prompt_tokens: number; completion_tokens: number };
|
usage: { prompt_tokens: number; completion_tokens: number };
|
||||||
}
|
}
|
||||||
|
|
||||||
interface LlamaCppStreamChunk {
|
interface LlamaCppStreamChunk {
|
||||||
choices: Array<{ delta?: { content?: string } }>;
|
choices: Array<{
|
||||||
|
delta?: {
|
||||||
|
content?: string;
|
||||||
|
tool_calls?: Array<{
|
||||||
|
index: number;
|
||||||
|
id?: string;
|
||||||
|
type?: string;
|
||||||
|
function?: {
|
||||||
|
name?: string;
|
||||||
|
arguments?: string;
|
||||||
|
};
|
||||||
|
}>;
|
||||||
|
};
|
||||||
|
finish_reason?: string | null;
|
||||||
|
}>;
|
||||||
usage?: { prompt_tokens: number; completion_tokens: number };
|
usage?: { prompt_tokens: number; completion_tokens: number };
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,14 +83,28 @@ export class LlamaCppClient implements ModelClient {
|
|||||||
|
|
||||||
let response: Response;
|
let response: Response;
|
||||||
try {
|
try {
|
||||||
|
const body: Record<string, unknown> = {
|
||||||
|
model: this.model,
|
||||||
|
messages,
|
||||||
|
max_tokens: request.maxTokens ?? 2048,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pass tool definitions to the API if provided
|
||||||
|
if (request.tools && request.tools.length > 0) {
|
||||||
|
body.tools = request.tools.map(t => ({
|
||||||
|
type: 'function' as const,
|
||||||
|
function: {
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
parameters: t.input_schema,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
response = await fetch(`${this.endpoint}/v1/chat/completions`, {
|
response = await fetch(`${this.endpoint}/v1/chat/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers,
|
headers,
|
||||||
body: JSON.stringify({
|
body: JSON.stringify(body),
|
||||||
model: this.model,
|
|
||||||
messages,
|
|
||||||
max_tokens: request.maxTokens ?? 2048,
|
|
||||||
}),
|
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof TypeError && error.message.includes('fetch failed')) {
|
if (error instanceof TypeError && error.message.includes('fetch failed')) {
|
||||||
@@ -77,13 +120,24 @@ export class LlamaCppClient implements ModelClient {
|
|||||||
|
|
||||||
const data = (await response.json()) as LlamaCppResponse;
|
const data = (await response.json()) as LlamaCppResponse;
|
||||||
|
|
||||||
|
// Parse tool calls from the response, if present
|
||||||
|
const toolCalls: ModelToolCall[] = data.choices[0]?.message?.tool_calls?.map((tc) => ({
|
||||||
|
id: tc.id ?? `llamacpp_tc_${Math.random().toString(36).slice(2, 8)}`,
|
||||||
|
name: tc.function.name,
|
||||||
|
args: JSON.parse(tc.function.arguments),
|
||||||
|
})) ?? [];
|
||||||
|
|
||||||
|
// Set stopReason to 'tool_use' when tool_calls are present
|
||||||
|
const stopReason = toolCalls.length > 0 ? 'tool_use' : (data.choices[0]?.finish_reason ?? 'stop');
|
||||||
|
|
||||||
return {
|
return {
|
||||||
content: data.choices[0]?.message?.content ?? '',
|
content: data.choices[0]?.message?.content ?? '',
|
||||||
stopReason: 'stop',
|
stopReason,
|
||||||
usage: {
|
usage: {
|
||||||
inputTokens: data.usage?.prompt_tokens ?? 0,
|
inputTokens: data.usage?.prompt_tokens ?? 0,
|
||||||
outputTokens: data.usage?.completion_tokens ?? 0,
|
outputTokens: data.usage?.completion_tokens ?? 0,
|
||||||
},
|
},
|
||||||
|
...(toolCalls.length > 0 ? { toolCalls } : {}),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,15 +161,29 @@ export class LlamaCppClient implements ModelClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
const body: Record<string, unknown> = {
|
||||||
|
model: this.model,
|
||||||
|
messages,
|
||||||
|
max_tokens: request.maxTokens ?? 2048,
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pass tool definitions to the API if provided
|
||||||
|
if (request.tools && request.tools.length > 0) {
|
||||||
|
body.tools = request.tools.map(t => ({
|
||||||
|
type: 'function' as const,
|
||||||
|
function: {
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
parameters: t.input_schema,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
const response = await fetch(`${this.endpoint}/v1/chat/completions`, {
|
const response = await fetch(`${this.endpoint}/v1/chat/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers,
|
headers,
|
||||||
body: JSON.stringify({
|
body: JSON.stringify(body),
|
||||||
model: this.model,
|
|
||||||
messages,
|
|
||||||
max_tokens: request.maxTokens ?? 2048,
|
|
||||||
stream: true,
|
|
||||||
}),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
@@ -131,6 +199,8 @@ export class LlamaCppClient implements ModelClient {
|
|||||||
const decoder = new TextDecoder();
|
const decoder = new TextDecoder();
|
||||||
let buffer = '';
|
let buffer = '';
|
||||||
let usage = { inputTokens: 0, outputTokens: 0 };
|
let usage = { inputTokens: 0, outputTokens: 0 };
|
||||||
|
// Accumulate tool call deltas across streamed chunks
|
||||||
|
const toolCallAccumulators: Map<number, { id: string; name: string; arguments: string }> = new Map();
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
const { done, value } = await reader.read();
|
const { done, value } = await reader.read();
|
||||||
@@ -154,6 +224,22 @@ export class LlamaCppClient implements ModelClient {
|
|||||||
yield { type: 'content', content: chunk.choices[0].delta.content };
|
yield { type: 'content', content: chunk.choices[0].delta.content };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Accumulate tool call deltas from the stream
|
||||||
|
if (chunk.choices[0]?.delta?.tool_calls) {
|
||||||
|
for (const tc of chunk.choices[0].delta.tool_calls) {
|
||||||
|
if (!toolCallAccumulators.has(tc.index)) {
|
||||||
|
toolCallAccumulators.set(tc.index, {
|
||||||
|
id: tc.id ?? `llamacpp_tc_${tc.index}`,
|
||||||
|
name: tc.function?.name ?? '',
|
||||||
|
arguments: '',
|
||||||
|
});
|
||||||
|
}
|
||||||
|
const acc = toolCallAccumulators.get(tc.index)!;
|
||||||
|
if (tc.function?.name) acc.name = tc.function.name;
|
||||||
|
if (tc.function?.arguments) acc.arguments += tc.function.arguments;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (chunk.usage) {
|
if (chunk.usage) {
|
||||||
usage = {
|
usage = {
|
||||||
inputTokens: chunk.usage.prompt_tokens,
|
inputTokens: chunk.usage.prompt_tokens,
|
||||||
@@ -166,6 +252,18 @@ export class LlamaCppClient implements ModelClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Yield completed tool calls before the done event
|
||||||
|
for (const [, acc] of toolCallAccumulators) {
|
||||||
|
yield {
|
||||||
|
type: 'tool_use',
|
||||||
|
toolCall: {
|
||||||
|
id: acc.id,
|
||||||
|
name: acc.name,
|
||||||
|
args: JSON.parse(acc.arguments),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
yield { type: 'done', usage };
|
yield { type: 'done', usage };
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
yield {
|
yield {
|
||||||
|
|||||||
@@ -1,23 +1,29 @@
|
|||||||
import { describe, it, expect, vi } from 'vitest';
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||||
import { OllamaClient } from './ollama.js';
|
import { OllamaClient } from './ollama.js';
|
||||||
|
|
||||||
|
const mockChat = vi.fn();
|
||||||
|
|
||||||
vi.mock('ollama', () => ({
|
vi.mock('ollama', () => ({
|
||||||
Ollama: vi.fn().mockImplementation(() => ({
|
Ollama: vi.fn().mockImplementation(() => ({
|
||||||
chat: vi.fn().mockResolvedValue({
|
chat: mockChat,
|
||||||
message: { content: 'Hello from Ollama!' },
|
|
||||||
done_reason: 'stop',
|
|
||||||
prompt_eval_count: 10,
|
|
||||||
eval_count: 5,
|
|
||||||
}),
|
|
||||||
})),
|
})),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
describe('OllamaClient', () => {
|
describe('OllamaClient', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
mockChat.mockReset();
|
||||||
|
});
|
||||||
|
|
||||||
it('sends messages and returns response', async () => {
|
it('sends messages and returns response', async () => {
|
||||||
const client = new OllamaClient({
|
mockChat.mockResolvedValue({
|
||||||
model: 'llama3.2',
|
message: { content: 'Hello from Ollama!' },
|
||||||
|
done_reason: 'stop',
|
||||||
|
prompt_eval_count: 10,
|
||||||
|
eval_count: 5,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const client = new OllamaClient({ model: 'llama3.2' });
|
||||||
|
|
||||||
const response = await client.chat({
|
const response = await client.chat({
|
||||||
messages: [{ role: 'user', content: 'Hello' }],
|
messages: [{ role: 'user', content: 'Hello' }],
|
||||||
});
|
});
|
||||||
@@ -27,4 +33,205 @@ describe('OllamaClient', () => {
|
|||||||
expect(response.usage.inputTokens).toBe(10);
|
expect(response.usage.inputTokens).toBe(10);
|
||||||
expect(response.usage.outputTokens).toBe(5);
|
expect(response.usage.outputTokens).toBe(5);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('passes tools to Ollama API in correct format', async () => {
|
||||||
|
mockChat.mockResolvedValue({
|
||||||
|
message: { content: 'I can help with that.' },
|
||||||
|
done_reason: 'stop',
|
||||||
|
prompt_eval_count: 15,
|
||||||
|
eval_count: 8,
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new OllamaClient({ model: 'llama3.2' });
|
||||||
|
|
||||||
|
await client.chat({
|
||||||
|
messages: [{ role: 'user', content: 'List files' }],
|
||||||
|
tools: [
|
||||||
|
{
|
||||||
|
name: 'shell.exec',
|
||||||
|
description: 'Execute a shell command',
|
||||||
|
input_schema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
command: { type: 'string', description: 'The command to run' },
|
||||||
|
},
|
||||||
|
required: ['command'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockChat).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
tools: [
|
||||||
|
{
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'shell.exec',
|
||||||
|
description: 'Execute a shell command',
|
||||||
|
parameters: {
|
||||||
|
type: 'object',
|
||||||
|
required: ['command'],
|
||||||
|
properties: {
|
||||||
|
command: { type: 'string', description: 'The command to run' },
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('parses tool_calls from response', async () => {
|
||||||
|
mockChat.mockResolvedValue({
|
||||||
|
message: {
|
||||||
|
content: '',
|
||||||
|
tool_calls: [
|
||||||
|
{ function: { name: 'shell.exec', arguments: { command: 'ls' } } },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
done_reason: 'stop',
|
||||||
|
prompt_eval_count: 12,
|
||||||
|
eval_count: 6,
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new OllamaClient({ model: 'llama3.2' });
|
||||||
|
|
||||||
|
const response = await client.chat({
|
||||||
|
messages: [{ role: 'user', content: 'List files' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.stopReason).toBe('tool_use');
|
||||||
|
expect(response.toolCalls).toHaveLength(1);
|
||||||
|
expect(response.toolCalls![0]).toEqual({
|
||||||
|
id: 'ollama_tc_0',
|
||||||
|
name: 'shell.exec',
|
||||||
|
args: { command: 'ls' },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('handles thinking field from reasoning models', async () => {
|
||||||
|
mockChat.mockResolvedValue({
|
||||||
|
message: { content: '', thinking: 'Let me think...' },
|
||||||
|
done_reason: 'stop',
|
||||||
|
prompt_eval_count: 20,
|
||||||
|
eval_count: 15,
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new OllamaClient({ model: 'deepseek-r1' });
|
||||||
|
|
||||||
|
const response = await client.chat({
|
||||||
|
messages: [{ role: 'user', content: 'Solve this problem' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
// When content is empty, thinking is used as fallback
|
||||||
|
expect(response.content).toBe('Let me think...');
|
||||||
|
expect(response.thinkingContent).toBe('Let me think...');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('thinking field does not override existing content', async () => {
|
||||||
|
mockChat.mockResolvedValue({
|
||||||
|
message: { content: 'Final answer', thinking: 'Reasoning...' },
|
||||||
|
done_reason: 'stop',
|
||||||
|
prompt_eval_count: 20,
|
||||||
|
eval_count: 15,
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new OllamaClient({ model: 'deepseek-r1' });
|
||||||
|
|
||||||
|
const response = await client.chat({
|
||||||
|
messages: [{ role: 'user', content: 'Solve this problem' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.content).toBe('Final answer');
|
||||||
|
expect(response.thinkingContent).toBe('Reasoning...');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('does not send tools when none provided', async () => {
|
||||||
|
mockChat.mockResolvedValue({
|
||||||
|
message: { content: 'No tools needed.' },
|
||||||
|
done_reason: 'stop',
|
||||||
|
prompt_eval_count: 5,
|
||||||
|
eval_count: 3,
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new OllamaClient({ model: 'llama3.2' });
|
||||||
|
|
||||||
|
await client.chat({
|
||||||
|
messages: [{ role: 'user', content: 'Hello' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const callArgs = mockChat.mock.calls[0][0];
|
||||||
|
expect(callArgs.tools).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('streaming: yields content events', async () => {
|
||||||
|
mockChat.mockResolvedValue(
|
||||||
|
(async function* () {
|
||||||
|
yield { message: { content: 'Hello' }, done: false };
|
||||||
|
yield { message: { content: ' world' }, done: false };
|
||||||
|
yield { message: { content: '' }, done: true, prompt_eval_count: 10, eval_count: 5 };
|
||||||
|
})(),
|
||||||
|
);
|
||||||
|
|
||||||
|
const client = new OllamaClient({ model: 'llama3.2' });
|
||||||
|
|
||||||
|
const events: Array<{ type: string; content?: string; usage?: { inputTokens: number; outputTokens: number } }> = [];
|
||||||
|
for await (const event of client.chatStream({
|
||||||
|
messages: [{ role: 'user', content: 'Hello' }],
|
||||||
|
})) {
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(events).toHaveLength(3);
|
||||||
|
expect(events[0]).toEqual({ type: 'content', content: 'Hello' });
|
||||||
|
expect(events[1]).toEqual({ type: 'content', content: ' world' });
|
||||||
|
expect(events[2]).toEqual({
|
||||||
|
type: 'done',
|
||||||
|
usage: { inputTokens: 10, outputTokens: 5 },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('streaming: yields tool_use events from final chunk', async () => {
|
||||||
|
mockChat.mockResolvedValue(
|
||||||
|
(async function* () {
|
||||||
|
yield {
|
||||||
|
message: {
|
||||||
|
content: '',
|
||||||
|
tool_calls: [
|
||||||
|
{ function: { name: 'system.info', arguments: {} } },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
done: true,
|
||||||
|
prompt_eval_count: 5,
|
||||||
|
eval_count: 3,
|
||||||
|
};
|
||||||
|
})(),
|
||||||
|
);
|
||||||
|
|
||||||
|
const client = new OllamaClient({ model: 'llama3.2' });
|
||||||
|
|
||||||
|
const events: Array<{ type: string; toolCall?: { id: string; name: string; args: unknown }; usage?: { inputTokens: number; outputTokens: number } }> = [];
|
||||||
|
for await (const event of client.chatStream({
|
||||||
|
messages: [{ role: 'user', content: 'Get system info' }],
|
||||||
|
})) {
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have tool_use event followed by done
|
||||||
|
expect(events).toHaveLength(2);
|
||||||
|
expect(events[0]).toEqual({
|
||||||
|
type: 'tool_use',
|
||||||
|
toolCall: {
|
||||||
|
id: 'ollama_tc_0',
|
||||||
|
name: 'system.info',
|
||||||
|
args: {},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
expect(events[1]).toEqual({
|
||||||
|
type: 'done',
|
||||||
|
usage: { inputTokens: 5, outputTokens: 3 },
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { Ollama } from 'ollama';
|
import { Ollama, type Tool } from 'ollama';
|
||||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from '../types.js';
|
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ToolDefinition, ModelToolCall } from '../types.js';
|
||||||
import { getMessageText } from '../media.js';
|
import { getMessageText } from '../media.js';
|
||||||
|
|
||||||
export interface OllamaClientConfig {
|
export interface OllamaClientConfig {
|
||||||
@@ -21,6 +21,24 @@ export class OllamaClient implements ModelClient {
|
|||||||
this.numGpu = config.numGpu ?? -1;
|
this.numGpu = config.numGpu ?? -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert Flynn ToolDefinition[] to Ollama Tool[] format.
|
||||||
|
*/
|
||||||
|
private convertTools(tools: ToolDefinition[]): Tool[] {
|
||||||
|
return tools.map(t => ({
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
parameters: {
|
||||||
|
type: t.input_schema.type,
|
||||||
|
required: t.input_schema.required,
|
||||||
|
properties: t.input_schema.properties as Record<string, any>,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
async chat(request: ChatRequest): Promise<ChatResponse> {
|
async chat(request: ChatRequest): Promise<ChatResponse> {
|
||||||
const messages: Array<{ role: 'system' | 'user' | 'assistant'; content: string }> = [];
|
const messages: Array<{ role: 'system' | 'user' | 'assistant'; content: string }> = [];
|
||||||
|
|
||||||
@@ -32,21 +50,51 @@ export class OllamaClient implements ModelClient {
|
|||||||
messages.push({ role: msg.role, content: getMessageText(msg) });
|
messages.push({ role: msg.role, content: getMessageText(msg) });
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await this.client.chat({
|
// Build the chat params, optionally including tools
|
||||||
|
const chatParams: Parameters<typeof this.client.chat>[0] = {
|
||||||
model: this.model,
|
model: this.model,
|
||||||
messages,
|
messages,
|
||||||
options: {
|
options: {
|
||||||
num_gpu: this.numGpu,
|
num_gpu: this.numGpu,
|
||||||
},
|
},
|
||||||
});
|
};
|
||||||
|
|
||||||
|
if (request.tools && request.tools.length > 0) {
|
||||||
|
chatParams.tools = this.convertTools(request.tools);
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await this.client.chat(chatParams);
|
||||||
|
|
||||||
|
// Extract content, checking for thinking field from reasoning models
|
||||||
|
let content = response.message.content;
|
||||||
|
let thinkingContent: string | undefined;
|
||||||
|
const thinking = (response.message as any).thinking;
|
||||||
|
if (thinking && typeof thinking === 'string') {
|
||||||
|
if (!content) {
|
||||||
|
// If no regular content, use thinking as content
|
||||||
|
content = thinking;
|
||||||
|
}
|
||||||
|
thinkingContent = thinking;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse tool_calls from the response
|
||||||
|
const toolCalls: ModelToolCall[] = response.message.tool_calls?.map((tc, i) => ({
|
||||||
|
id: `ollama_tc_${i}`,
|
||||||
|
name: tc.function.name,
|
||||||
|
args: tc.function.arguments,
|
||||||
|
})) ?? [];
|
||||||
|
|
||||||
|
const hasToolCalls = toolCalls.length > 0;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
content: response.message.content,
|
content,
|
||||||
stopReason: response.done_reason ?? 'stop',
|
stopReason: hasToolCalls ? 'tool_use' : (response.done_reason ?? 'stop'),
|
||||||
usage: {
|
usage: {
|
||||||
inputTokens: response.prompt_eval_count ?? 0,
|
inputTokens: response.prompt_eval_count ?? 0,
|
||||||
outputTokens: response.eval_count ?? 0,
|
outputTokens: response.eval_count ?? 0,
|
||||||
},
|
},
|
||||||
|
...(hasToolCalls ? { toolCalls } : {}),
|
||||||
|
...(thinkingContent ? { thinkingContent } : {}),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,6 +110,11 @@ export class OllamaClient implements ModelClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
// Build tools array if provided
|
||||||
|
const tools = request.tools && request.tools.length > 0
|
||||||
|
? this.convertTools(request.tools)
|
||||||
|
: undefined;
|
||||||
|
|
||||||
const stream = await this.client.chat({
|
const stream = await this.client.chat({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
messages,
|
messages,
|
||||||
@@ -69,6 +122,7 @@ export class OllamaClient implements ModelClient {
|
|||||||
options: {
|
options: {
|
||||||
num_gpu: this.numGpu,
|
num_gpu: this.numGpu,
|
||||||
},
|
},
|
||||||
|
...(tools ? { tools } : {}),
|
||||||
});
|
});
|
||||||
|
|
||||||
let inputTokens = 0;
|
let inputTokens = 0;
|
||||||
@@ -79,6 +133,12 @@ export class OllamaClient implements ModelClient {
|
|||||||
yield { type: 'content', content: chunk.message.content };
|
yield { type: 'content', content: chunk.message.content };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle thinking field from reasoning models (e.g., deepseek-r1)
|
||||||
|
const thinking = (chunk.message as any)?.thinking;
|
||||||
|
if (thinking && typeof thinking === 'string') {
|
||||||
|
yield { type: 'content', content: thinking };
|
||||||
|
}
|
||||||
|
|
||||||
if (chunk.prompt_eval_count) {
|
if (chunk.prompt_eval_count) {
|
||||||
inputTokens = chunk.prompt_eval_count;
|
inputTokens = chunk.prompt_eval_count;
|
||||||
}
|
}
|
||||||
@@ -87,6 +147,22 @@ export class OllamaClient implements ModelClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (chunk.done) {
|
if (chunk.done) {
|
||||||
|
// Handle tool_calls in the final chunk
|
||||||
|
const toolCalls = (chunk.message as any)?.tool_calls;
|
||||||
|
if (toolCalls && Array.isArray(toolCalls)) {
|
||||||
|
for (let i = 0; i < toolCalls.length; i++) {
|
||||||
|
const tc = toolCalls[i];
|
||||||
|
yield {
|
||||||
|
type: 'tool_use',
|
||||||
|
toolCall: {
|
||||||
|
id: `ollama_tc_${i}`,
|
||||||
|
name: tc.function.name,
|
||||||
|
args: tc.function.arguments,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
type: 'done',
|
type: 'done',
|
||||||
usage: {
|
usage: {
|
||||||
|
|||||||
Reference in New Issue
Block a user