feat(models): add tool use support to OpenAIClient
This commit is contained in:
@@ -1,14 +1,17 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { OpenAIClient } from './openai.js';
|
||||
|
||||
// Shared mock function so we can override per-test
|
||||
const mockCreate = vi.fn().mockResolvedValue({
|
||||
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
|
||||
usage: { prompt_tokens: 10, completion_tokens: 5 },
|
||||
});
|
||||
|
||||
vi.mock('openai', () => ({
|
||||
default: vi.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: vi.fn().mockResolvedValue({
|
||||
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
|
||||
usage: { prompt_tokens: 10, completion_tokens: 5 },
|
||||
}),
|
||||
create: mockCreate,
|
||||
},
|
||||
},
|
||||
})),
|
||||
@@ -31,3 +34,44 @@ describe('OpenAIClient', () => {
|
||||
expect(response.usage.outputTokens).toBe(5);
|
||||
});
|
||||
});
|
||||
|
||||
describe('OpenAIClient tool use', () => {
|
||||
it('passes tools to API and parses tool_calls response', async () => {
|
||||
mockCreate.mockResolvedValueOnce({
|
||||
choices: [{
|
||||
message: {
|
||||
content: null,
|
||||
tool_calls: [{
|
||||
id: 'call_1',
|
||||
type: 'function',
|
||||
function: { name: 'shell.exec', arguments: '{"command":"ls"}' },
|
||||
}],
|
||||
},
|
||||
finish_reason: 'tool_calls',
|
||||
}],
|
||||
usage: { prompt_tokens: 20, completion_tokens: 15 },
|
||||
});
|
||||
|
||||
const client = new OpenAIClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gpt-4o',
|
||||
});
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'list files' }],
|
||||
tools: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell command',
|
||||
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
|
||||
}],
|
||||
});
|
||||
|
||||
expect(response.stopReason).toBe('tool_calls');
|
||||
expect(response.toolCalls).toHaveLength(1);
|
||||
expect(response.toolCalls![0]).toEqual({
|
||||
id: 'call_1',
|
||||
name: 'shell.exec',
|
||||
args: { command: 'ls' },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
+24
-2
@@ -33,15 +33,36 @@ export class OpenAIClient implements ModelClient {
|
||||
messages.push({ role: msg.role, content: msg.content });
|
||||
}
|
||||
|
||||
const response = await this.client.chat.completions.create({
|
||||
// Build params, conditionally including tools
|
||||
const params: OpenAI.ChatCompletionCreateParamsNonStreaming = {
|
||||
model: this.model,
|
||||
max_tokens: request.maxTokens ?? this.defaultMaxTokens,
|
||||
messages,
|
||||
});
|
||||
};
|
||||
|
||||
if (request.tools && request.tools.length > 0) {
|
||||
params.tools = request.tools.map(t => ({
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.input_schema as OpenAI.FunctionParameters,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
const response = await this.client.chat.completions.create(params);
|
||||
|
||||
const choice = response.choices[0];
|
||||
const content = choice?.message?.content ?? '';
|
||||
|
||||
// Parse tool_calls from the response if present
|
||||
const toolCalls = choice?.message?.tool_calls?.map((tc: OpenAI.ChatCompletionMessageToolCall) => ({
|
||||
id: tc.id,
|
||||
name: tc.function.name,
|
||||
args: JSON.parse(tc.function.arguments),
|
||||
})) ?? [];
|
||||
|
||||
return {
|
||||
content,
|
||||
stopReason: choice?.finish_reason ?? 'stop',
|
||||
@@ -49,6 +70,7 @@ export class OpenAIClient implements ModelClient {
|
||||
inputTokens: response.usage?.prompt_tokens ?? 0,
|
||||
outputTokens: response.usage?.completion_tokens ?? 0,
|
||||
},
|
||||
...(toolCalls.length > 0 ? { toolCalls } : {}),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user