feat: add Gemini and Bedrock model providers
Add native GeminiClient using @google/generative-ai SDK and BedrockClient using @aws-sdk/client-bedrock-runtime. Replace the previous Gemini fallback (OpenAI-compatible shim) with the real implementation. Add OpenRouter as a provider option (OpenAI-compatible with custom baseURL). Update model costs, doctor CLI checks, and client factory tests.
This commit is contained in:
@@ -0,0 +1,332 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { GeminiClient } from './gemini.js';
|
||||
|
||||
// Shared mock functions.
// These are module-level so that every test can both program their return
// values and assert on how the client called them.
const mockGenerateContent = vi.fn();
const mockGenerateContentStream = vi.fn();

// Stand-in for `GoogleGenerativeAI#getGenerativeModel`: hands back an object
// exposing only the two model methods these tests program and assert on.
const mockGetGenerativeModel = vi.fn().mockReturnValue({
  generateContent: mockGenerateContent,
  generateContentStream: mockGenerateContentStream,
});

// Replace the real SDK with a fake whose constructor returns an object that
// exposes the shared `getGenerativeModel` mock.
// NOTE: Vitest hoists `vi.mock` calls; this factory is safe because
// `mockGetGenerativeModel` is only dereferenced inside the implementation
// callback, i.e. when the mocked constructor is actually invoked.
vi.mock('@google/generative-ai', () => ({
  GoogleGenerativeAI: vi.fn().mockImplementation(() => ({
    getGenerativeModel: mockGetGenerativeModel,
  })),
}));
|
||||
|
||||
/**
 * Builds a fake `generateContent` result wrapping a single candidate.
 *
 * @param parts - Content parts placed on the candidate; objects with a `text`
 *   key are treated as text parts, objects with a `functionCall` key as
 *   function-call parts.
 * @param finishReason - Finish reason reported on the candidate.
 * @param usage - Token counts exposed as `usageMetadata`.
 * @returns An object of the shape `{ response }` where `response.text()`
 *   concatenates all text parts (throwing `Error('No text parts')` when there
 *   are none) and `response.functionCalls()` returns the function-call
 *   payloads, or `undefined` when there are none.
 */
function makeResponse(
  parts: unknown[],
  finishReason = 'STOP',
  usage = { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 },
) {
  // Narrowing guards: they only check for the presence of the key, exactly
  // like the membership tests this helper has always used.
  const hasText = (part: unknown): part is { text: unknown } =>
    typeof part === 'object' && part !== null && 'text' in part;
  const hasFunctionCall = (part: unknown): part is { functionCall: unknown } =>
    typeof part === 'object' && part !== null && 'functionCall' in part;

  return {
    response: {
      candidates: [
        {
          index: 0,
          content: { role: 'model', parts },
          finishReason,
        },
      ],
      usageMetadata: usage,
      text: () => {
        const textParts = parts.filter(hasText);
        if (textParts.length === 0) throw new Error('No text parts');
        return textParts.map((part) => part.text).join('');
      },
      functionCalls: () => {
        const callParts = parts.filter(hasFunctionCall);
        if (callParts.length === 0) return undefined;
        return callParts.map(
          (part) => part.functionCall as { name: string; args: object },
        );
      },
    },
  };
}
|
||||
|
||||
// Non-streaming `chat()` behaviour of GeminiClient against the mocked SDK.
describe('GeminiClient', () => {
  beforeEach(() => {
    vi.clearAllMocks();
    // Default reply for every test; individual tests override it with
    // `mockResolvedValueOnce` where they need a different response.
    mockGenerateContent.mockResolvedValue(
      makeResponse([{ text: 'Hello from Gemini!' }]),
    );
    mockGetGenerativeModel.mockReturnValue({
      generateContent: mockGenerateContent,
      generateContentStream: mockGenerateContentStream,
    });
  });

  it('sends messages and returns response', async () => {
    const client = new GeminiClient({
      apiKey: 'test-key',
      model: 'gemini-2.0-flash',
    });

    const response = await client.chat({
      messages: [{ role: 'user', content: 'Hello' }],
    });

    // Values come from the default `makeResponse` fixture (STOP, 10/5 tokens).
    expect(response.content).toBe('Hello from Gemini!');
    expect(response.stopReason).toBe('end_turn');
    expect(response.usage.inputTokens).toBe(10);
    expect(response.usage.outputTokens).toBe(5);
  });

  it('passes system instruction to model', async () => {
    const client = new GeminiClient({
      apiKey: 'test-key',
      model: 'gemini-2.0-flash',
    });

    await client.chat({
      messages: [{ role: 'user', content: 'Hello' }],
      system: 'You are a helpful assistant',
    });

    expect(mockGetGenerativeModel).toHaveBeenCalledWith(
      expect.objectContaining({
        systemInstruction: 'You are a helpful assistant',
      }),
    );
  });

  it('converts assistant role to model role', async () => {
    const client = new GeminiClient({
      apiKey: 'test-key',
      model: 'gemini-2.0-flash',
    });

    await client.chat({
      messages: [
        { role: 'user', content: 'Hello' },
        { role: 'assistant', content: 'Hi there!' },
        { role: 'user', content: 'How are you?' },
      ],
    });

    // Each message becomes a `{ role, parts: [{ text }] }` entry, with
    // `assistant` renamed to `model`.
    expect(mockGenerateContent).toHaveBeenCalledWith({
      contents: [
        { role: 'user', parts: [{ text: 'Hello' }] },
        { role: 'model', parts: [{ text: 'Hi there!' }] },
        { role: 'user', parts: [{ text: 'How are you?' }] },
      ],
    });
  });

  it('maps MAX_TOKENS finish reason', async () => {
    mockGenerateContent.mockResolvedValueOnce(
      makeResponse([{ text: 'Truncated...' }], 'MAX_TOKENS'),
    );

    const client = new GeminiClient({
      apiKey: 'test-key',
      model: 'gemini-2.0-flash',
    });

    const response = await client.chat({
      messages: [{ role: 'user', content: 'Write a long story' }],
    });

    expect(response.stopReason).toBe('max_tokens');
  });

  it('uses environment variable for API key when not provided', () => {
    const originalEnv = process.env.GOOGLE_API_KEY;
    process.env.GOOGLE_API_KEY = 'env-key';

    try {
      // Just construct — we verify it doesn't throw
      const client = new GeminiClient({ model: 'gemini-2.0-flash' });
      expect(client).toBeDefined();
    } finally {
      // Restore the environment so other tests see the original value.
      if (originalEnv !== undefined) {
        process.env.GOOGLE_API_KEY = originalEnv;
      } else {
        delete process.env.GOOGLE_API_KEY;
      }
    }
  });
});
|
||||
|
||||
// Streaming `chatStream()` behaviour of GeminiClient against the mocked SDK.
describe('GeminiClient streaming', () => {
  beforeEach(() => {
    vi.clearAllMocks();
    // Default stream: two text chunks followed by an aggregated response
    // promise carrying the final usage metadata.
    mockGenerateContentStream.mockResolvedValue({
      stream: (async function* () {
        yield {
          text: () => 'Hello ',
          functionCalls: () => undefined,
          candidates: [{ content: { parts: [{ text: 'Hello ' }] } }],
          usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 2, totalTokenCount: 12 },
        };
        yield {
          text: () => 'from Gemini!',
          functionCalls: () => undefined,
          candidates: [{ content: { parts: [{ text: 'from Gemini!' }] } }],
          usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 },
        };
      })(),
      response: Promise.resolve({
        usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 },
      }),
    });
    mockGetGenerativeModel.mockReturnValue({
      generateContent: mockGenerateContent,
      generateContentStream: mockGenerateContentStream,
    });
  });

  it('streams messages chunk by chunk', async () => {
    const client = new GeminiClient({
      apiKey: 'test-key',
      model: 'gemini-2.0-flash',
    });

    const chunks: string[] = [];
    let finalUsage: { inputTokens: number; outputTokens: number } | undefined;

    for await (const event of client.chatStream({
      messages: [{ role: 'user', content: 'Hello' }],
    })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
      if (event.type === 'done' && event.usage) {
        finalUsage = event.usage;
      }
    }

    expect(chunks.length).toBeGreaterThan(0);
    expect(chunks.join('')).toBe('Hello from Gemini!');
    expect(finalUsage).toEqual({ inputTokens: 10, outputTokens: 5 });
  });

  it('yields error event on stream failure', async () => {
    // Fail the SDK call itself, before any chunk is produced.
    mockGenerateContentStream.mockRejectedValueOnce(new Error('Network error'));

    const client = new GeminiClient({
      apiKey: 'test-key',
      model: 'gemini-2.0-flash',
    });

    const events: { type: string; error?: Error }[] = [];
    for await (const event of client.chatStream({
      messages: [{ role: 'user', content: 'Hello' }],
    })) {
      events.push(event);
    }

    // The failure must surface as exactly one `error` event, not a throw.
    expect(events).toHaveLength(1);
    expect(events[0].type).toBe('error');
    expect(events[0].error?.message).toBe('Network error');
  });
});
|
||||
|
||||
// Tool / function-calling behaviour of GeminiClient, both for `chat()` and
// `chatStream()`, against the mocked SDK.
describe('GeminiClient tool use', () => {
  beforeEach(() => {
    vi.clearAllMocks();
    mockGetGenerativeModel.mockReturnValue({
      generateContent: mockGenerateContent,
      generateContentStream: mockGenerateContentStream,
    });
  });

  it('passes tools and parses function call response', async () => {
    mockGenerateContent.mockResolvedValueOnce(
      makeResponse([{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }]),
    );

    const client = new GeminiClient({
      apiKey: 'test-key',
      model: 'gemini-2.0-flash',
    });

    const response = await client.chat({
      messages: [{ role: 'user', content: 'list files' }],
      tools: [{
        name: 'shell.exec',
        description: 'Run shell command',
        input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
      }],
    });

    expect(response.stopReason).toBe('tool_use');
    expect(response.toolCalls).toHaveLength(1);
    expect(response.toolCalls![0].name).toBe('shell.exec');
    expect(response.toolCalls![0].args).toEqual({ command: 'ls' });

    // Verify tools were passed to getGenerativeModel
    // (`input_schema` is expected to arrive as `parameters` inside a single
    // `functionDeclarations` group).
    expect(mockGetGenerativeModel).toHaveBeenCalledWith(
      expect.objectContaining({
        tools: [{
          functionDeclarations: [{
            name: 'shell.exec',
            description: 'Run shell command',
            parameters: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
          }],
        }],
      }),
    );
  });

  it('handles mixed text and function call response', async () => {
    mockGenerateContent.mockResolvedValueOnce(
      makeResponse([
        { text: 'Let me run that for you.' },
        { functionCall: { name: 'shell.exec', args: { command: 'ls -la' } } },
      ]),
    );

    const client = new GeminiClient({
      apiKey: 'test-key',
      model: 'gemini-2.0-flash',
    });

    const response = await client.chat({
      messages: [{ role: 'user', content: 'list files in detail' }],
      tools: [{
        name: 'shell.exec',
        description: 'Run shell command',
        input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
      }],
    });

    // Text is preserved alongside the tool call, and the tool call decides
    // the stop reason.
    expect(response.content).toBe('Let me run that for you.');
    expect(response.stopReason).toBe('tool_use');
    expect(response.toolCalls).toHaveLength(1);
  });

  it('streams function calls', async () => {
    // A single chunk that carries no text, only one function call.
    mockGenerateContentStream.mockResolvedValueOnce({
      stream: (async function* () {
        yield {
          text: () => '',
          functionCalls: () => [{ name: 'shell.exec', args: { command: 'ls' } }],
          candidates: [{ content: { parts: [{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }] } }],
          usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 10, totalTokenCount: 25 },
        };
      })(),
      response: Promise.resolve({
        usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 10, totalTokenCount: 25 },
      }),
    });

    const client = new GeminiClient({
      apiKey: 'test-key',
      model: 'gemini-2.0-flash',
    });

    const toolCalls: { name: string; args: unknown }[] = [];
    for await (const event of client.chatStream({
      messages: [{ role: 'user', content: 'list files' }],
      tools: [{
        name: 'shell.exec',
        description: 'Run shell command',
        input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
      }],
    })) {
      if (event.type === 'tool_use' && event.toolCall) {
        toolCalls.push({ name: event.toolCall.name, args: event.toolCall.args });
      }
    }

    expect(toolCalls).toHaveLength(1);
    expect(toolCalls[0].name).toBe('shell.exec');
    expect(toolCalls[0].args).toEqual({ command: 'ls' });
  });
});
|
||||
Reference in New Issue
Block a user