feat(models): add tool use support to OpenAIClient
This commit is contained in:
@@ -1,14 +1,17 @@
|
|||||||
import { describe, it, expect, vi } from 'vitest';
|
import { describe, it, expect, vi } from 'vitest';
|
||||||
import { OpenAIClient } from './openai.js';
|
import { OpenAIClient } from './openai.js';
|
||||||
|
|
||||||
|
// Shared mock function so we can override per-test
|
||||||
|
const mockCreate = vi.fn().mockResolvedValue({
|
||||||
|
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
|
||||||
|
usage: { prompt_tokens: 10, completion_tokens: 5 },
|
||||||
|
});
|
||||||
|
|
||||||
vi.mock('openai', () => ({
|
vi.mock('openai', () => ({
|
||||||
default: vi.fn().mockImplementation(() => ({
|
default: vi.fn().mockImplementation(() => ({
|
||||||
chat: {
|
chat: {
|
||||||
completions: {
|
completions: {
|
||||||
create: vi.fn().mockResolvedValue({
|
create: mockCreate,
|
||||||
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
|
|
||||||
usage: { prompt_tokens: 10, completion_tokens: 5 },
|
|
||||||
}),
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})),
|
})),
|
||||||
@@ -31,3 +34,44 @@ describe('OpenAIClient', () => {
|
|||||||
expect(response.usage.outputTokens).toBe(5);
|
expect(response.usage.outputTokens).toBe(5);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('OpenAIClient tool use', () => {
|
||||||
|
it('passes tools to API and parses tool_calls response', async () => {
|
||||||
|
mockCreate.mockResolvedValueOnce({
|
||||||
|
choices: [{
|
||||||
|
message: {
|
||||||
|
content: null,
|
||||||
|
tool_calls: [{
|
||||||
|
id: 'call_1',
|
||||||
|
type: 'function',
|
||||||
|
function: { name: 'shell.exec', arguments: '{"command":"ls"}' },
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
finish_reason: 'tool_calls',
|
||||||
|
}],
|
||||||
|
usage: { prompt_tokens: 20, completion_tokens: 15 },
|
||||||
|
});
|
||||||
|
|
||||||
|
const client = new OpenAIClient({
|
||||||
|
apiKey: 'test-key',
|
||||||
|
model: 'gpt-4o',
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await client.chat({
|
||||||
|
messages: [{ role: 'user', content: 'list files' }],
|
||||||
|
tools: [{
|
||||||
|
name: 'shell.exec',
|
||||||
|
description: 'Run shell command',
|
||||||
|
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(response.stopReason).toBe('tool_calls');
|
||||||
|
expect(response.toolCalls).toHaveLength(1);
|
||||||
|
expect(response.toolCalls![0]).toEqual({
|
||||||
|
id: 'call_1',
|
||||||
|
name: 'shell.exec',
|
||||||
|
args: { command: 'ls' },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
+24
-2
@@ -33,15 +33,36 @@ export class OpenAIClient implements ModelClient {
|
|||||||
messages.push({ role: msg.role, content: msg.content });
|
messages.push({ role: msg.role, content: msg.content });
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await this.client.chat.completions.create({
|
// Build params, conditionally including tools
|
||||||
|
const params: OpenAI.ChatCompletionCreateParamsNonStreaming = {
|
||||||
model: this.model,
|
model: this.model,
|
||||||
max_tokens: request.maxTokens ?? this.defaultMaxTokens,
|
max_tokens: request.maxTokens ?? this.defaultMaxTokens,
|
||||||
messages,
|
messages,
|
||||||
});
|
};
|
||||||
|
|
||||||
|
if (request.tools && request.tools.length > 0) {
|
||||||
|
params.tools = request.tools.map(t => ({
|
||||||
|
type: 'function' as const,
|
||||||
|
function: {
|
||||||
|
name: t.name,
|
||||||
|
description: t.description,
|
||||||
|
parameters: t.input_schema as OpenAI.FunctionParameters,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await this.client.chat.completions.create(params);
|
||||||
|
|
||||||
const choice = response.choices[0];
|
const choice = response.choices[0];
|
||||||
const content = choice?.message?.content ?? '';
|
const content = choice?.message?.content ?? '';
|
||||||
|
|
||||||
|
// Parse tool_calls from the response if present
|
||||||
|
const toolCalls = choice?.message?.tool_calls?.map((tc: OpenAI.ChatCompletionMessageToolCall) => ({
|
||||||
|
id: tc.id,
|
||||||
|
name: tc.function.name,
|
||||||
|
args: JSON.parse(tc.function.arguments),
|
||||||
|
})) ?? [];
|
||||||
|
|
||||||
return {
|
return {
|
||||||
content,
|
content,
|
||||||
stopReason: choice?.finish_reason ?? 'stop',
|
stopReason: choice?.finish_reason ?? 'stop',
|
||||||
@@ -49,6 +70,7 @@ export class OpenAIClient implements ModelClient {
|
|||||||
inputTokens: response.usage?.prompt_tokens ?? 0,
|
inputTokens: response.usage?.prompt_tokens ?? 0,
|
||||||
outputTokens: response.usage?.completion_tokens ?? 0,
|
outputTokens: response.usage?.completion_tokens ?? 0,
|
||||||
},
|
},
|
||||||
|
...(toolCalls.length > 0 ? { toolCalls } : {}),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user