From 0eb1f7a0738bb9a839b24c742418cd5f8c485578 Mon Sep 17 00:00:00 2001 From: William Valentin Date: Fri, 6 Feb 2026 16:51:32 -0800 Subject: [PATCH] feat: add Gemini and Bedrock model providers Add native GeminiClient using @google/generative-ai SDK and BedrockClient using @aws-sdk/client-bedrock-runtime. Replace the previous Gemini fallback (OpenAI-compatible shim) with the real implementation. Add OpenRouter as a provider option (OpenAI-compatible with custom baseURL). Update model costs, doctor CLI checks, and client factory tests. --- src/cli/doctor.ts | 9 +- src/daemon/clientFactory.test.ts | 24 ++- src/models/bedrock.test.ts | 180 +++++++++++++++++ src/models/bedrock.ts | 179 +++++++++++++++++ src/models/costs.ts | 12 ++ src/models/gemini.test.ts | 332 +++++++++++++++++++++++++++++++ src/models/gemini.ts | 175 ++++++++++++++++ src/models/index.ts | 2 + 8 files changed, 908 insertions(+), 5 deletions(-) create mode 100644 src/models/bedrock.test.ts create mode 100644 src/models/bedrock.ts create mode 100644 src/models/gemini.test.ts create mode 100644 src/models/gemini.ts diff --git a/src/cli/doctor.ts b/src/cli/doctor.ts index 45a45a1..84172d4 100644 --- a/src/cli/doctor.ts +++ b/src/cli/doctor.ts @@ -123,9 +123,14 @@ const checkModelConnectivity: Check = async (ctx) => { } // Check if API key is present for providers that need one - const needsKey = ['anthropic', 'openai', 'gemini']; + const needsKey = ['anthropic', 'openai', 'gemini', 'openrouter']; if (needsKey.includes(model.provider) && !model.api_key && !model.auth_token) { - const envVar = model.provider === 'anthropic' ? 'ANTHROPIC_API_KEY' : model.provider === 'openai' ? 
'OPENAI_API_KEY' : undefined; + const envVarMap: Record<string, string> = { + anthropic: 'ANTHROPIC_API_KEY', + openai: 'OPENAI_API_KEY', + openrouter: 'OPENROUTER_API_KEY', + }; + const envVar = envVarMap[model.provider]; const hasEnv = envVar && process.env[envVar]; if (!hasEnv) { return { status: 'warn', label: 'Model connectivity', detail: `${model.provider}/${model.model} — no API key or auth token found` }; diff --git a/src/daemon/clientFactory.test.ts b/src/daemon/clientFactory.test.ts index 640ef7f..428ab38 100644 --- a/src/daemon/clientFactory.test.ts +++ b/src/daemon/clientFactory.test.ts @@ -4,6 +4,8 @@ import { AnthropicClient } from '../models/anthropic.js'; import { OpenAIClient } from '../models/openai.js'; import { OllamaClient } from '../models/local/ollama.js'; import { LlamaCppClient } from '../models/local/llamacpp.js'; +import { GeminiClient } from '../models/gemini.js'; +import { BedrockClient } from '../models/bedrock.js'; describe('createClientFromConfig', () => { it('creates AnthropicClient for anthropic provider', () => { @@ -59,14 +61,13 @@ describe('createClientFromConfig', () => { expect(client).toBeInstanceOf(LlamaCppClient); }); - it('creates OpenAI-compatible client for gemini provider (with warning)', () => { + it('creates GeminiClient for gemini provider', () => { const client = createClientFromConfig({ provider: 'gemini', model: 'gemini-2.5-pro', api_key: 'test-key', }); - // Gemini falls back to OpenAI-compatible client - expect(client).toBeInstanceOf(OpenAIClient); + expect(client).toBeInstanceOf(GeminiClient); }); it('throws for unknown provider', () => { @@ -75,4 +76,21 @@ describe('createClientFromConfig', () => { model: 'test', })).toThrow('Unknown model provider: unknown'); }); + + it('creates OpenAIClient with OpenRouter baseURL for openrouter provider', () => { + const client = createClientFromConfig({ + provider: 'openrouter', + model: 'meta-llama/llama-3.1-70b', + api_key: 'test-key', + }); +
expect(client).toBeInstanceOf(OpenAIClient); + }); + + it('creates BedrockClient for bedrock provider', () => { + const client = createClientFromConfig({ + provider: 'bedrock', + model: 'anthropic.claude-3-sonnet', + }); + expect(client).toBeInstanceOf(BedrockClient); + }); }); diff --git a/src/models/bedrock.test.ts b/src/models/bedrock.test.ts new file mode 100644 index 0000000..09f7d91 --- /dev/null +++ b/src/models/bedrock.test.ts @@ -0,0 +1,180 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { BedrockClient } from './bedrock.js'; +import type { ChatStreamEvent } from './types.js'; + +const mockSend = vi.fn().mockResolvedValue({ + output: { + message: { + content: [{ text: 'Hello from Bedrock!' }], + }, + }, + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5 }, +}); + +vi.mock('@aws-sdk/client-bedrock-runtime', () => ({ + BedrockRuntimeClient: vi.fn().mockImplementation(() => ({ + send: mockSend, + })), + ConverseCommand: vi.fn().mockImplementation((params) => params), + ConverseStreamCommand: vi.fn().mockImplementation((params) => params), +})); + +describe('BedrockClient', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockSend.mockResolvedValue({ + output: { + message: { + content: [{ text: 'Hello from Bedrock!' 
}], + }, + }, + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5 }, + }); + }); + + it('sends messages and returns response', async () => { + const client = new BedrockClient({ + model: 'anthropic.claude-3-sonnet', + region: 'us-east-1', + }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'Hello' }], + }); + + expect(response.content).toBe('Hello from Bedrock!'); + expect(response.stopReason).toBe('end_turn'); + expect(response.usage.inputTokens).toBe(10); + expect(response.usage.outputTokens).toBe(5); + }); + + it('parses tool use response', async () => { + mockSend.mockResolvedValueOnce({ + output: { + message: { + content: [{ + toolUse: { + toolUseId: 'tool_01', + name: 'shell.exec', + input: { command: 'ls' }, + }, + }], + }, + }, + stopReason: 'tool_use', + usage: { inputTokens: 20, outputTokens: 15 }, + }); + + const client = new BedrockClient({ + model: 'anthropic.claude-3-sonnet', + region: 'us-east-1', + }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'list files' }], + tools: [{ + name: 'shell.exec', + description: 'Run shell command', + input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] }, + }], + }); + + expect(response.stopReason).toBe('tool_use'); + expect(response.toolCalls).toHaveLength(1); + expect(response.toolCalls![0].name).toBe('shell.exec'); + expect(response.toolCalls![0].args).toEqual({ command: 'ls' }); + }); + + it('uses default region when none provided', async () => { + const client = new BedrockClient({ + model: 'anthropic.claude-3-sonnet', + }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'Hello' }], + }); + + expect(response.content).toBe('Hello from Bedrock!'); + }); + + it('passes system prompt to API', async () => { + const client = new BedrockClient({ + model: 'anthropic.claude-3-sonnet', + region: 'us-east-1', + }); + + await client.chat({ + messages: [{ role: 
'user', content: 'Hello' }], + system: 'You are a helpful assistant.', + }); + + expect(mockSend).toHaveBeenCalledTimes(1); + // ConverseCommand is called with params that include system + const { ConverseCommand } = await import('@aws-sdk/client-bedrock-runtime'); + expect(ConverseCommand).toHaveBeenCalledWith( + expect.objectContaining({ + system: [{ text: 'You are a helpful assistant.' }], + }), + ); + }); +}); + +describe('BedrockClient streaming', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('streams content events', async () => { + mockSend.mockResolvedValueOnce({ + stream: (async function* () { + yield { contentBlockDelta: { delta: { text: 'Hello ' } } }; + yield { contentBlockDelta: { delta: { text: 'from Bedrock!' } } }; + yield { metadata: { usage: { inputTokens: 10, outputTokens: 5 } } }; + })(), + }); + + const client = new BedrockClient({ + model: 'anthropic.claude-3-sonnet', + region: 'us-east-1', + }); + + const chunks: string[] = []; + let finalUsage: { inputTokens: number; outputTokens: number } | undefined; + + for await (const event of client.chatStream({ + messages: [{ role: 'user', content: 'Hello' }], + })) { + if (event.type === 'content' && event.content) { + chunks.push(event.content); + } + if (event.type === 'done' && event.usage) { + finalUsage = event.usage; + } + } + + expect(chunks.join('')).toBe('Hello from Bedrock!'); + expect(finalUsage).toEqual({ inputTokens: 10, outputTokens: 5 }); + }); + + it('yields error event on failure', async () => { + mockSend.mockRejectedValueOnce(new Error('Service unavailable')); + + const client = new BedrockClient({ + model: 'anthropic.claude-3-sonnet', + region: 'us-east-1', + }); + + const events: ChatStreamEvent[] = []; + for await (const event of client.chatStream({ + messages: [{ role: 'user', content: 'Hello' }], + })) { + events.push(event); + } + + expect(events).toHaveLength(1); + expect(events[0].type).toBe('error'); + expect(events[0].error?.message).toBe('Service 
unavailable'); + }); +}); diff --git a/src/models/bedrock.ts b/src/models/bedrock.ts new file mode 100644 index 0000000..9fd2cb8 --- /dev/null +++ b/src/models/bedrock.ts @@ -0,0 +1,179 @@ +import { + BedrockRuntimeClient, + ConverseCommand, + ConverseStreamCommand, +} from '@aws-sdk/client-bedrock-runtime'; +import type { + Message as BedrockMessage, + ContentBlock, + ToolConfiguration, + Tool as BedrockTool, + ConverseCommandInput, + ConverseStreamCommandInput, +} from '@aws-sdk/client-bedrock-runtime'; +import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall, ToolDefinition } from './types.js'; + +export interface BedrockClientConfig { + model: string; + region?: string; + maxTokens?: number; + /** AWS access key ID (if not using default credential chain). */ + accessKeyId?: string; + /** AWS secret access key (if not using default credential chain). */ + secretAccessKey?: string; +} + +export class BedrockClient implements ModelClient { + private client: BedrockRuntimeClient; + private model: string; + private defaultMaxTokens: number; + + constructor(config: BedrockClientConfig) { + const clientConfig: Record<string, unknown> = { + region: config.region ?? process.env.AWS_REGION ?? 'us-east-1', + }; + + if (config.accessKeyId && config.secretAccessKey) { + clientConfig.credentials = { + accessKeyId: config.accessKeyId, + secretAccessKey: config.secretAccessKey, + }; + } + + this.client = new BedrockRuntimeClient(clientConfig); + this.model = config.model; + this.defaultMaxTokens = config.maxTokens ?? 4096; + } + + async chat(request: ChatRequest): Promise<ChatResponse> { + const messages = convertMessages(request.messages); + + const params: ConverseCommandInput = { + modelId: this.model, + messages, + inferenceConfig: { + maxTokens: request.maxTokens ??
this.defaultMaxTokens, + }, + }; + + if (request.system) { + params.system = [{ text: request.system }]; + } + + if (request.tools && request.tools.length > 0) { + params.toolConfig = convertTools(request.tools); + } + + const command = new ConverseCommand(params); + const response = await this.client.send(command); + + // Extract text and tool_use content from the response + const outputContent = response.output?.message?.content ?? []; + const textParts: string[] = []; + const toolCalls: ModelToolCall[] = []; + + for (const block of outputContent) { + if ('text' in block && block.text !== undefined) { + textParts.push(block.text); + } + if ('toolUse' in block && block.toolUse !== undefined) { + toolCalls.push({ + id: block.toolUse.toolUseId ?? `bedrock_${Date.now()}`, + name: block.toolUse.name ?? '', + args: block.toolUse.input as unknown, + }); + } + } + + const content = textParts.join(''); + + // Map stop reason + let stopReason: string = 'end_turn'; + if (response.stopReason === 'max_tokens') stopReason = 'max_tokens'; + else if (response.stopReason === 'tool_use') stopReason = 'tool_use'; + else if (response.stopReason === 'end_turn') stopReason = 'end_turn'; + else if (response.stopReason) stopReason = response.stopReason; + + return { + content, + stopReason, + usage: { + inputTokens: response.usage?.inputTokens ?? 0, + outputTokens: response.usage?.outputTokens ?? 0, + }, + ...(toolCalls.length > 0 ? { toolCalls } : {}), + }; + } + + async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> { + const messages = convertMessages(request.messages); + + const params: ConverseStreamCommandInput = { + modelId: this.model, + messages, + inferenceConfig: { + maxTokens: request.maxTokens ??
this.defaultMaxTokens, + }, + }; + + if (request.system) { + params.system = [{ text: request.system }]; + } + + if (request.tools && request.tools.length > 0) { + params.toolConfig = convertTools(request.tools); + } + + try { + const command = new ConverseStreamCommand(params); + const response = await this.client.send(command); + + let inputTokens = 0; + let outputTokens = 0; + + if (response.stream) { + for await (const event of response.stream) { + if (event.contentBlockDelta?.delta && 'text' in event.contentBlockDelta.delta && event.contentBlockDelta.delta.text) { + yield { type: 'content', content: event.contentBlockDelta.delta.text }; + } + + if (event.metadata?.usage) { + inputTokens = event.metadata.usage.inputTokens ?? inputTokens; + outputTokens = event.metadata.usage.outputTokens ?? outputTokens; + } + } + } + + yield { + type: 'done', + usage: { inputTokens, outputTokens }, + }; + } catch (error) { + yield { + type: 'error', + error: error instanceof Error ? error : new Error(String(error)), + }; + } + } +} + +function convertMessages(messages: { role: string; content: string }[]): BedrockMessage[] { + return messages.map(m => ({ + role: m.role === 'assistant' ? 
'assistant' as const : 'user' as const, + content: [{ text: m.content }] as ContentBlock[], + })); +} + +function convertTools(tools: ToolDefinition[]): ToolConfiguration { + return { + tools: tools.map(t => ({ + toolSpec: { + name: t.name, + description: t.description, + inputSchema: { + json: t.input_schema as Record, + }, + }, + } as BedrockTool)), + }; +} diff --git a/src/models/costs.ts b/src/models/costs.ts index ae758ea..d81d205 100644 --- a/src/models/costs.ts +++ b/src/models/costs.ts @@ -7,8 +7,20 @@ export const MODEL_COSTS_PER_MILLION: Record ({ + GoogleGenerativeAI: vi.fn().mockImplementation(() => ({ + getGenerativeModel: mockGetGenerativeModel, + })), +})); + +function makeResponse(parts: unknown[], finishReason = 'STOP', usage = { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 }) { + return { + response: { + candidates: [{ + index: 0, + content: { role: 'model', parts }, + finishReason, + }], + usageMetadata: usage, + text: () => { + const textParts = parts.filter((p: unknown) => typeof p === 'object' && p !== null && 'text' in p); + if (textParts.length === 0) throw new Error('No text parts'); + return textParts.map((p: unknown) => (p as { text: string }).text).join(''); + }, + functionCalls: () => { + const fcParts = parts.filter((p: unknown) => typeof p === 'object' && p !== null && 'functionCall' in p); + if (fcParts.length === 0) return undefined; + return fcParts.map((p: unknown) => (p as { functionCall: { name: string; args: object } }).functionCall); + }, + }, + }; +} + +describe('GeminiClient', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockGenerateContent.mockResolvedValue( + makeResponse([{ text: 'Hello from Gemini!' 
}]), + ); + mockGetGenerativeModel.mockReturnValue({ + generateContent: mockGenerateContent, + generateContentStream: mockGenerateContentStream, + }); + }); + + it('sends messages and returns response', async () => { + const client = new GeminiClient({ + apiKey: 'test-key', + model: 'gemini-2.0-flash', + }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'Hello' }], + }); + + expect(response.content).toBe('Hello from Gemini!'); + expect(response.stopReason).toBe('end_turn'); + expect(response.usage.inputTokens).toBe(10); + expect(response.usage.outputTokens).toBe(5); + }); + + it('passes system instruction to model', async () => { + const client = new GeminiClient({ + apiKey: 'test-key', + model: 'gemini-2.0-flash', + }); + + await client.chat({ + messages: [{ role: 'user', content: 'Hello' }], + system: 'You are a helpful assistant', + }); + + expect(mockGetGenerativeModel).toHaveBeenCalledWith( + expect.objectContaining({ + systemInstruction: 'You are a helpful assistant', + }), + ); + }); + + it('converts assistant role to model role', async () => { + const client = new GeminiClient({ + apiKey: 'test-key', + model: 'gemini-2.0-flash', + }); + + await client.chat({ + messages: [ + { role: 'user', content: 'Hello' }, + { role: 'assistant', content: 'Hi there!' }, + { role: 'user', content: 'How are you?' }, + ], + }); + + expect(mockGenerateContent).toHaveBeenCalledWith({ + contents: [ + { role: 'user', parts: [{ text: 'Hello' }] }, + { role: 'model', parts: [{ text: 'Hi there!' }] }, + { role: 'user', parts: [{ text: 'How are you?' }] }, + ], + }); + }); + + it('maps MAX_TOKENS finish reason', async () => { + mockGenerateContent.mockResolvedValueOnce( + makeResponse([{ text: 'Truncated...' 
}], 'MAX_TOKENS'), + ); + + const client = new GeminiClient({ + apiKey: 'test-key', + model: 'gemini-2.0-flash', + }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'Write a long story' }], + }); + + expect(response.stopReason).toBe('max_tokens'); + }); + + it('uses environment variable for API key when not provided', () => { + const originalEnv = process.env.GOOGLE_API_KEY; + process.env.GOOGLE_API_KEY = 'env-key'; + + try { + // Just construct — we verify it doesn't throw + const _client = new GeminiClient({ model: 'gemini-2.0-flash' }); + expect(_client).toBeDefined(); + } finally { + if (originalEnv !== undefined) { + process.env.GOOGLE_API_KEY = originalEnv; + } else { + delete process.env.GOOGLE_API_KEY; + } + } + }); +}); + +describe('GeminiClient streaming', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockGenerateContentStream.mockResolvedValue({ + stream: (async function* () { + yield { + text: () => 'Hello ', + functionCalls: () => undefined, + candidates: [{ content: { parts: [{ text: 'Hello ' }] } }], + usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 2, totalTokenCount: 12 }, + }; + yield { + text: () => 'from Gemini!', + functionCalls: () => undefined, + candidates: [{ content: { parts: [{ text: 'from Gemini!' 
}] } }], + usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 }, + }; + })(), + response: Promise.resolve({ + usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 }, + }), + }); + mockGetGenerativeModel.mockReturnValue({ + generateContent: mockGenerateContent, + generateContentStream: mockGenerateContentStream, + }); + }); + + it('streams messages chunk by chunk', async () => { + const client = new GeminiClient({ + apiKey: 'test-key', + model: 'gemini-2.0-flash', + }); + + const chunks: string[] = []; + let finalUsage: { inputTokens: number; outputTokens: number } | undefined; + + for await (const event of client.chatStream({ + messages: [{ role: 'user', content: 'Hello' }], + })) { + if (event.type === 'content' && event.content) { + chunks.push(event.content); + } + if (event.type === 'done' && event.usage) { + finalUsage = event.usage; + } + } + + expect(chunks.length).toBeGreaterThan(0); + expect(chunks.join('')).toBe('Hello from Gemini!'); + expect(finalUsage).toEqual({ inputTokens: 10, outputTokens: 5 }); + }); + + it('yields error event on stream failure', async () => { + mockGenerateContentStream.mockRejectedValueOnce(new Error('Network error')); + + const client = new GeminiClient({ + apiKey: 'test-key', + model: 'gemini-2.0-flash', + }); + + const events: { type: string; error?: Error }[] = []; + for await (const event of client.chatStream({ + messages: [{ role: 'user', content: 'Hello' }], + })) { + events.push(event); + } + + expect(events).toHaveLength(1); + expect(events[0].type).toBe('error'); + expect(events[0].error?.message).toBe('Network error'); + }); +}); + +describe('GeminiClient tool use', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockGetGenerativeModel.mockReturnValue({ + generateContent: mockGenerateContent, + generateContentStream: mockGenerateContentStream, + }); + }); + + it('passes tools and parses function call response', async () => { + 
mockGenerateContent.mockResolvedValueOnce( + makeResponse([{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }]), + ); + + const client = new GeminiClient({ + apiKey: 'test-key', + model: 'gemini-2.0-flash', + }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'list files' }], + tools: [{ + name: 'shell.exec', + description: 'Run shell command', + input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] }, + }], + }); + + expect(response.stopReason).toBe('tool_use'); + expect(response.toolCalls).toHaveLength(1); + expect(response.toolCalls![0].name).toBe('shell.exec'); + expect(response.toolCalls![0].args).toEqual({ command: 'ls' }); + + // Verify tools were passed to getGenerativeModel + expect(mockGetGenerativeModel).toHaveBeenCalledWith( + expect.objectContaining({ + tools: [{ + functionDeclarations: [{ + name: 'shell.exec', + description: 'Run shell command', + parameters: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] }, + }], + }], + }), + ); + }); + + it('handles mixed text and function call response', async () => { + mockGenerateContent.mockResolvedValueOnce( + makeResponse([ + { text: 'Let me run that for you.' 
}, + { functionCall: { name: 'shell.exec', args: { command: 'ls -la' } } }, + ]), + ); + + const client = new GeminiClient({ + apiKey: 'test-key', + model: 'gemini-2.0-flash', + }); + + const response = await client.chat({ + messages: [{ role: 'user', content: 'list files in detail' }], + tools: [{ + name: 'shell.exec', + description: 'Run shell command', + input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] }, + }], + }); + + expect(response.content).toBe('Let me run that for you.'); + expect(response.stopReason).toBe('tool_use'); + expect(response.toolCalls).toHaveLength(1); + }); + + it('streams function calls', async () => { + mockGenerateContentStream.mockResolvedValueOnce({ + stream: (async function* () { + yield { + text: () => '', + functionCalls: () => [{ name: 'shell.exec', args: { command: 'ls' } }], + candidates: [{ content: { parts: [{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }] } }], + usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 10, totalTokenCount: 25 }, + }; + })(), + response: Promise.resolve({ + usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 10, totalTokenCount: 25 }, + }), + }); + + const client = new GeminiClient({ + apiKey: 'test-key', + model: 'gemini-2.0-flash', + }); + + const toolCalls: { name: string; args: unknown }[] = []; + for await (const event of client.chatStream({ + messages: [{ role: 'user', content: 'list files' }], + tools: [{ + name: 'shell.exec', + description: 'Run shell command', + input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] }, + }], + })) { + if (event.type === 'tool_use' && event.toolCall) { + toolCalls.push({ name: event.toolCall.name, args: event.toolCall.args }); + } + } + + expect(toolCalls).toHaveLength(1); + expect(toolCalls[0].name).toBe('shell.exec'); + expect(toolCalls[0].args).toEqual({ command: 'ls' }); + }); +}); diff --git a/src/models/gemini.ts 
b/src/models/gemini.ts new file mode 100644 index 0000000..ccb58ce --- /dev/null +++ b/src/models/gemini.ts @@ -0,0 +1,175 @@ +import { GoogleGenerativeAI } from '@google/generative-ai'; +import type { GenerativeModel, Content, FunctionDeclaration, FunctionDeclarationSchema } from '@google/generative-ai'; +import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall, ToolDefinition } from './types.js'; + +export interface GeminiClientConfig { + apiKey?: string; + model: string; + maxTokens?: number; +} + +export class GeminiClient implements ModelClient { + private genAI: GoogleGenerativeAI; + private model: string; + private defaultMaxTokens: number; + + constructor(config: GeminiClientConfig) { + const apiKey = config.apiKey ?? process.env.GOOGLE_API_KEY ?? ''; + this.genAI = new GoogleGenerativeAI(apiKey); + this.model = config.model; + this.defaultMaxTokens = config.maxTokens ?? 8192; + } + + private getModel(request: ChatRequest): GenerativeModel { + const tools = request.tools && request.tools.length > 0 + ? [{ functionDeclarations: request.tools.map(t => convertToolDefinition(t)) }] + : undefined; + + return this.genAI.getGenerativeModel({ + model: this.model, + systemInstruction: request.system || undefined, + tools, + generationConfig: { + maxOutputTokens: request.maxTokens ?? this.defaultMaxTokens, + }, + }); + } + + async chat(request: ChatRequest): Promise<ChatResponse> { + const model = this.getModel(request); + const contents = convertMessages(request.messages); + + const result = await model.generateContent({ contents }); + const response = result.response; + const candidate = response.candidates?.[0]; + + // Extract text via the helper method + let content = ''; + try { + content = response.text(); + } catch { + // text() throws if blocked — fall back to manual extraction + const textParts = candidate?.content?.parts?.filter(p => 'text' in p && p.text !== undefined) ??
[]; + content = textParts.map(p => (p as { text: string }).text).join(''); + } + + // Extract function calls via the helper method + const functionCalls = response.functionCalls(); + const toolCalls: ModelToolCall[] = functionCalls + ? functionCalls.map((fc, i) => ({ + id: `gemini_${Date.now()}_${i}`, + name: fc.name, + args: fc.args, + })) + : []; + + // Map finish reason + const finishReason = candidate?.finishReason; + let stopReason: string = 'end_turn'; + if (toolCalls.length > 0) { + stopReason = 'tool_use'; + } else if (finishReason === 'MAX_TOKENS') { + stopReason = 'max_tokens'; + } else if (finishReason === 'STOP') { + stopReason = 'end_turn'; + } else if (finishReason) { + stopReason = finishReason.toLowerCase(); + } + + // Extract usage + const usageMetadata = response.usageMetadata; + const usage = { + inputTokens: usageMetadata?.promptTokenCount ?? 0, + outputTokens: usageMetadata?.candidatesTokenCount ?? 0, + }; + + return { + content, + stopReason, + usage, + ...(toolCalls.length > 0 ? { toolCalls } : {}), + }; + } + + async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> { + const model = this.getModel(request); + const contents = convertMessages(request.messages); + + try { + const result = await model.generateContentStream({ contents }); + + let totalInputTokens = 0; + let totalOutputTokens = 0; + + for await (const chunk of result.stream) { + // Use the text() helper to extract text content from this chunk + try { + const text = chunk.text(); + if (text) { + yield { type: 'content', content: text }; + } + } catch { + // text() throws if blocked — skip + } + + // Check for function calls in streaming chunks + const calls = chunk.functionCalls(); + if (calls) { + for (const fc of calls) { + yield { + type: 'tool_use', + toolCall: { + id: `gemini_${Date.now()}`, + name: fc.name, + args: fc.args, + }, + }; + } + } + + // Track usage from chunks + if (chunk.usageMetadata) { + totalInputTokens = chunk.usageMetadata.promptTokenCount ??
totalInputTokens; + totalOutputTokens = chunk.usageMetadata.candidatesTokenCount ?? totalOutputTokens; + } + } + + // Final aggregated response for usage + const aggregated = await result.response; + const usageMetadata = aggregated.usageMetadata; + + yield { + type: 'done', + usage: { + inputTokens: usageMetadata?.promptTokenCount ?? totalInputTokens, + outputTokens: usageMetadata?.candidatesTokenCount ?? totalOutputTokens, + }, + }; + } catch (error) { + yield { + type: 'error', + error: error instanceof Error ? error : new Error(String(error)), + }; + } + } +} + +/** Convert Flynn's Message[] to Gemini Content[] format */ +function convertMessages(messages: { role: string; content: string }[]): Content[] { + return messages.map(m => ({ + role: m.role === 'assistant' ? 'model' : 'user', + parts: [{ text: m.content }], + })); +} + +/** Convert Flynn's ToolDefinition to Gemini FunctionDeclaration format */ +function convertToolDefinition(tool: ToolDefinition): FunctionDeclaration { + // The Gemini SDK's FunctionDeclarationSchema expects `type: SchemaType` (enum) + // but the actual wire format accepts string values. We pass the schema through + // as-is since the SDK serialises it to JSON for the API request. 
+ return { + name: tool.name, + description: tool.description, + parameters: tool.input_schema as unknown as FunctionDeclarationSchema, + }; +} diff --git a/src/models/index.ts b/src/models/index.ts index 346c0b5..0dc8583 100644 --- a/src/models/index.ts +++ b/src/models/index.ts @@ -1,5 +1,7 @@ export { AnthropicClient, type AnthropicClientConfig } from './anthropic.js'; export { OpenAIClient, type OpenAIClientConfig } from './openai.js'; +export { GeminiClient, type GeminiClientConfig } from './gemini.js'; +export { BedrockClient, type BedrockClientConfig } from './bedrock.js'; export { OllamaClient, type OllamaClientConfig } from './local/index.js'; export { LlamaCppClient, type LlamaCppClientConfig } from './local/index.js'; export { ModelRouter, type ModelRouterConfig, type ModelTier } from './router.js';