diff --git a/src/models/openai.test.ts b/src/models/openai.test.ts index 1b37a0d..823f0cd 100644 --- a/src/models/openai.test.ts +++ b/src/models/openai.test.ts @@ -1,14 +1,17 @@ import { describe, it, expect, vi } from 'vitest'; import { OpenAIClient } from './openai.js'; +// Shared mock function so we can override per-test +const mockCreate = vi.fn().mockResolvedValue({ + choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }], + usage: { prompt_tokens: 10, completion_tokens: 5 }, +}); + vi.mock('openai', () => ({ default: vi.fn().mockImplementation(() => ({ chat: { completions: { - create: vi.fn().mockResolvedValue({ - choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }], - usage: { prompt_tokens: 10, completion_tokens: 5 }, - }), + create: mockCreate, }, }, })), @@ -31,3 +34,44 @@ describe('OpenAIClient', () => { expect(response.usage.outputTokens).toBe(5); }); }); + +describe('OpenAIClient tool use', () => { + it('passes tools to API and parses tool_calls response', async () => { + mockCreate.mockResolvedValueOnce({ + choices: [{ + message: { + content: null, + tool_calls: [{ + id: 'call_1', + type: 'function', + function: { name: 'shell.exec', arguments: '{"command":"ls"}' }, + }], + }, + finish_reason: 'tool_calls', + }], + usage: { prompt_tokens: 20, completion_tokens: 15 }, + }); + + const client = new OpenAIClient({ + apiKey: 'test-key', + model: 'gpt-4o', + }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'list files' }], + tools: [{ + name: 'shell.exec', + description: 'Run shell command', + input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] }, + }], + }); + + expect(response.stopReason).toBe('tool_calls'); + expect(response.toolCalls).toHaveLength(1); + expect(response.toolCalls![0]).toEqual({ + id: 'call_1', + name: 'shell.exec', + args: { command: 'ls' }, + }); + }); +}); diff --git a/src/models/openai.ts b/src/models/openai.ts index 3b75ca0..51faeac 100644 --- a/src/models/openai.ts +++ b/src/models/openai.ts @@ -33,15 +33,36 @@ export class OpenAIClient implements ModelClient { messages.push({ role: msg.role, content: msg.content }); } - const response = await this.client.chat.completions.create({ + // Build params, conditionally including tools + const params: OpenAI.ChatCompletionCreateParamsNonStreaming = { model: this.model, max_tokens: request.maxTokens ?? this.defaultMaxTokens, messages, - }); + }; + + if (request.tools && request.tools.length > 0) { + params.tools = request.tools.map(t => ({ + type: 'function' as const, + function: { + name: t.name, + description: t.description, + parameters: t.input_schema as OpenAI.FunctionParameters, + }, + })); + } + + const response = await this.client.chat.completions.create(params); const choice = response.choices[0]; const content = choice?.message?.content ?? ''; + // Parse tool_calls from the response if present + const toolCalls = choice?.message?.tool_calls?.map((tc: OpenAI.ChatCompletionMessageToolCall) => ({ + id: tc.id, + name: tc.function.name, + args: JSON.parse(tc.function.arguments), + })) ?? []; + return { content, stopReason: choice?.finish_reason ?? 'stop', @@ -49,6 +70,7 @@ export class OpenAIClient implements ModelClient { inputTokens: response.usage?.prompt_tokens ?? 0, outputTokens: response.usage?.completion_tokens ?? 0, }, + ...(toolCalls.length > 0 ? { toolCalls } : {}), }; } }