feat: add Gemini and Bedrock model providers

Add native GeminiClient using @google/generative-ai SDK and BedrockClient
using @aws-sdk/client-bedrock-runtime. Replace the previous Gemini fallback
(OpenAI-compatible shim) with the real implementation. Add OpenRouter as a
provider option (OpenAI-compatible with custom baseURL). Update model costs,
doctor CLI checks, and client factory tests.
This commit is contained in:
William Valentin
2026-02-06 16:51:32 -08:00
parent e8e4fcd758
commit 0eb1f7a073
8 changed files with 908 additions and 5 deletions
+7 -2
View File
@@ -123,9 +123,14 @@ const checkModelConnectivity: Check = async (ctx) => {
}
// Check if API key is present for providers that need one
const needsKey = ['anthropic', 'openai', 'gemini'];
const needsKey = ['anthropic', 'openai', 'gemini', 'openrouter'];
if (needsKey.includes(model.provider) && !model.api_key && !model.auth_token) {
const envVar = model.provider === 'anthropic' ? 'ANTHROPIC_API_KEY' : model.provider === 'openai' ? 'OPENAI_API_KEY' : undefined;
const envVarMap: Record<string, string> = {
anthropic: 'ANTHROPIC_API_KEY',
openai: 'OPENAI_API_KEY',
openrouter: 'OPENROUTER_API_KEY',
};
const envVar = envVarMap[model.provider];
const hasEnv = envVar && process.env[envVar];
if (!hasEnv) {
return { status: 'warn', label: 'Model connectivity', detail: `${model.provider}/${model.model} — no API key or auth token found` };
+21 -3
View File
@@ -4,6 +4,8 @@ import { AnthropicClient } from '../models/anthropic.js';
import { OpenAIClient } from '../models/openai.js';
import { OllamaClient } from '../models/local/ollama.js';
import { LlamaCppClient } from '../models/local/llamacpp.js';
import { GeminiClient } from '../models/gemini.js';
import { BedrockClient } from '../models/bedrock.js';
describe('createClientFromConfig', () => {
it('creates AnthropicClient for anthropic provider', () => {
@@ -59,14 +61,13 @@ describe('createClientFromConfig', () => {
expect(client).toBeInstanceOf(LlamaCppClient);
});
it('creates OpenAI-compatible client for gemini provider (with warning)', () => {
it('creates GeminiClient for gemini provider', () => {
const client = createClientFromConfig({
provider: 'gemini',
model: 'gemini-2.5-pro',
api_key: 'test-key',
});
// Gemini falls back to OpenAI-compatible client
expect(client).toBeInstanceOf(OpenAIClient);
expect(client).toBeInstanceOf(GeminiClient);
});
it('throws for unknown provider', () => {
@@ -75,4 +76,21 @@ describe('createClientFromConfig', () => {
model: 'test',
})).toThrow('Unknown model provider: unknown');
});
it('creates OpenAIClient with OpenRouter baseURL for openrouter provider', () => {
const client = createClientFromConfig({
provider: 'openrouter',
model: 'meta-llama/llama-3.1-70b',
api_key: 'test-key',
});
expect(client).toBeInstanceOf(OpenAIClient);
});
it('creates BedrockClient for bedrock provider', () => {
const client = createClientFromConfig({
provider: 'bedrock',
model: 'anthropic.claude-3-sonnet',
});
expect(client).toBeInstanceOf(BedrockClient);
});
});
+180
View File
@@ -0,0 +1,180 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { BedrockClient } from './bedrock.js';
import type { ChatStreamEvent } from './types.js';
const mockSend = vi.fn().mockResolvedValue({
output: {
message: {
content: [{ text: 'Hello from Bedrock!' }],
},
},
stopReason: 'end_turn',
usage: { inputTokens: 10, outputTokens: 5 },
});
vi.mock('@aws-sdk/client-bedrock-runtime', () => ({
BedrockRuntimeClient: vi.fn().mockImplementation(() => ({
send: mockSend,
})),
ConverseCommand: vi.fn().mockImplementation((params) => params),
ConverseStreamCommand: vi.fn().mockImplementation((params) => params),
}));
describe('BedrockClient', () => {
beforeEach(() => {
vi.clearAllMocks();
mockSend.mockResolvedValue({
output: {
message: {
content: [{ text: 'Hello from Bedrock!' }],
},
},
stopReason: 'end_turn',
usage: { inputTokens: 10, outputTokens: 5 },
});
});
it('sends messages and returns response', async () => {
const client = new BedrockClient({
model: 'anthropic.claude-3-sonnet',
region: 'us-east-1',
});
const response = await client.chat({
messages: [{ role: 'user', content: 'Hello' }],
});
expect(response.content).toBe('Hello from Bedrock!');
expect(response.stopReason).toBe('end_turn');
expect(response.usage.inputTokens).toBe(10);
expect(response.usage.outputTokens).toBe(5);
});
it('parses tool use response', async () => {
mockSend.mockResolvedValueOnce({
output: {
message: {
content: [{
toolUse: {
toolUseId: 'tool_01',
name: 'shell.exec',
input: { command: 'ls' },
},
}],
},
},
stopReason: 'tool_use',
usage: { inputTokens: 20, outputTokens: 15 },
});
const client = new BedrockClient({
model: 'anthropic.claude-3-sonnet',
region: 'us-east-1',
});
const response = await client.chat({
messages: [{ role: 'user', content: 'list files' }],
tools: [{
name: 'shell.exec',
description: 'Run shell command',
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
}],
});
expect(response.stopReason).toBe('tool_use');
expect(response.toolCalls).toHaveLength(1);
expect(response.toolCalls![0].name).toBe('shell.exec');
expect(response.toolCalls![0].args).toEqual({ command: 'ls' });
});
it('uses default region when none provided', async () => {
const client = new BedrockClient({
model: 'anthropic.claude-3-sonnet',
});
const response = await client.chat({
messages: [{ role: 'user', content: 'Hello' }],
});
expect(response.content).toBe('Hello from Bedrock!');
});
it('passes system prompt to API', async () => {
const client = new BedrockClient({
model: 'anthropic.claude-3-sonnet',
region: 'us-east-1',
});
await client.chat({
messages: [{ role: 'user', content: 'Hello' }],
system: 'You are a helpful assistant.',
});
expect(mockSend).toHaveBeenCalledTimes(1);
// ConverseCommand is called with params that include system
const { ConverseCommand } = await import('@aws-sdk/client-bedrock-runtime');
expect(ConverseCommand).toHaveBeenCalledWith(
expect.objectContaining({
system: [{ text: 'You are a helpful assistant.' }],
}),
);
});
});
describe('BedrockClient streaming', () => {
beforeEach(() => {
vi.clearAllMocks();
});
it('streams content events', async () => {
mockSend.mockResolvedValueOnce({
stream: (async function* () {
yield { contentBlockDelta: { delta: { text: 'Hello ' } } };
yield { contentBlockDelta: { delta: { text: 'from Bedrock!' } } };
yield { metadata: { usage: { inputTokens: 10, outputTokens: 5 } } };
})(),
});
const client = new BedrockClient({
model: 'anthropic.claude-3-sonnet',
region: 'us-east-1',
});
const chunks: string[] = [];
let finalUsage: { inputTokens: number; outputTokens: number } | undefined;
for await (const event of client.chatStream({
messages: [{ role: 'user', content: 'Hello' }],
})) {
if (event.type === 'content' && event.content) {
chunks.push(event.content);
}
if (event.type === 'done' && event.usage) {
finalUsage = event.usage;
}
}
expect(chunks.join('')).toBe('Hello from Bedrock!');
expect(finalUsage).toEqual({ inputTokens: 10, outputTokens: 5 });
});
it('yields error event on failure', async () => {
mockSend.mockRejectedValueOnce(new Error('Service unavailable'));
const client = new BedrockClient({
model: 'anthropic.claude-3-sonnet',
region: 'us-east-1',
});
const events: ChatStreamEvent[] = [];
for await (const event of client.chatStream({
messages: [{ role: 'user', content: 'Hello' }],
})) {
events.push(event);
}
expect(events).toHaveLength(1);
expect(events[0].type).toBe('error');
expect(events[0].error?.message).toBe('Service unavailable');
});
});
+179
View File
@@ -0,0 +1,179 @@
import {
BedrockRuntimeClient,
ConverseCommand,
ConverseStreamCommand,
} from '@aws-sdk/client-bedrock-runtime';
import type {
Message as BedrockMessage,
ContentBlock,
ToolConfiguration,
Tool as BedrockTool,
ConverseCommandInput,
ConverseStreamCommandInput,
} from '@aws-sdk/client-bedrock-runtime';
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall, ToolDefinition } from './types.js';
export interface BedrockClientConfig {
model: string;
region?: string;
maxTokens?: number;
/** AWS access key ID (if not using default credential chain). */
accessKeyId?: string;
/** AWS secret access key (if not using default credential chain). */
secretAccessKey?: string;
}
export class BedrockClient implements ModelClient {
private client: BedrockRuntimeClient;
private model: string;
private defaultMaxTokens: number;
constructor(config: BedrockClientConfig) {
const clientConfig: Record<string, unknown> = {
region: config.region ?? process.env.AWS_REGION ?? 'us-east-1',
};
if (config.accessKeyId && config.secretAccessKey) {
clientConfig.credentials = {
accessKeyId: config.accessKeyId,
secretAccessKey: config.secretAccessKey,
};
}
this.client = new BedrockRuntimeClient(clientConfig);
this.model = config.model;
this.defaultMaxTokens = config.maxTokens ?? 4096;
}
async chat(request: ChatRequest): Promise<ChatResponse> {
const messages = convertMessages(request.messages);
const params: ConverseCommandInput = {
modelId: this.model,
messages,
inferenceConfig: {
maxTokens: request.maxTokens ?? this.defaultMaxTokens,
},
};
if (request.system) {
params.system = [{ text: request.system }];
}
if (request.tools && request.tools.length > 0) {
params.toolConfig = convertTools(request.tools);
}
const command = new ConverseCommand(params);
const response = await this.client.send(command);
// Extract text and tool_use content from the response
const outputContent = response.output?.message?.content ?? [];
const textParts: string[] = [];
const toolCalls: ModelToolCall[] = [];
for (const block of outputContent) {
if ('text' in block && block.text !== undefined) {
textParts.push(block.text);
}
if ('toolUse' in block && block.toolUse !== undefined) {
toolCalls.push({
id: block.toolUse.toolUseId ?? `bedrock_${Date.now()}`,
name: block.toolUse.name ?? '',
args: block.toolUse.input as unknown,
});
}
}
const content = textParts.join('');
// Map stop reason
let stopReason: string = 'end_turn';
if (response.stopReason === 'max_tokens') stopReason = 'max_tokens';
else if (response.stopReason === 'tool_use') stopReason = 'tool_use';
else if (response.stopReason === 'end_turn') stopReason = 'end_turn';
else if (response.stopReason) stopReason = response.stopReason;
return {
content,
stopReason,
usage: {
inputTokens: response.usage?.inputTokens ?? 0,
outputTokens: response.usage?.outputTokens ?? 0,
},
...(toolCalls.length > 0 ? { toolCalls } : {}),
};
}
async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
const messages = convertMessages(request.messages);
const params: ConverseStreamCommandInput = {
modelId: this.model,
messages,
inferenceConfig: {
maxTokens: request.maxTokens ?? this.defaultMaxTokens,
},
};
if (request.system) {
params.system = [{ text: request.system }];
}
if (request.tools && request.tools.length > 0) {
params.toolConfig = convertTools(request.tools);
}
try {
const command = new ConverseStreamCommand(params);
const response = await this.client.send(command);
let inputTokens = 0;
let outputTokens = 0;
if (response.stream) {
for await (const event of response.stream) {
if (event.contentBlockDelta?.delta && 'text' in event.contentBlockDelta.delta && event.contentBlockDelta.delta.text) {
yield { type: 'content', content: event.contentBlockDelta.delta.text };
}
if (event.metadata?.usage) {
inputTokens = event.metadata.usage.inputTokens ?? inputTokens;
outputTokens = event.metadata.usage.outputTokens ?? outputTokens;
}
}
}
yield {
type: 'done',
usage: { inputTokens, outputTokens },
};
} catch (error) {
yield {
type: 'error',
error: error instanceof Error ? error : new Error(String(error)),
};
}
}
}
function convertMessages(messages: { role: string; content: string }[]): BedrockMessage[] {
return messages.map(m => ({
role: m.role === 'assistant' ? 'assistant' as const : 'user' as const,
content: [{ text: m.content }] as ContentBlock[],
}));
}
function convertTools(tools: ToolDefinition[]): ToolConfiguration {
return {
tools: tools.map(t => ({
toolSpec: {
name: t.name,
description: t.description,
inputSchema: {
json: t.input_schema as Record<string, unknown>,
},
},
} as BedrockTool)),
};
}
+12
View File
@@ -7,8 +7,20 @@ export const MODEL_COSTS_PER_MILLION: Record<string, { input: number; output: nu
// OpenAI
'gpt-4o': { input: 2.50, output: 10 },
'gpt-4o-mini': { input: 0.15, output: 0.60 },
// Gemini
'gemini-2.0-flash': { input: 0.10, output: 0.40 },
'gemini-2.0-flash-lite': { input: 0.025, output: 0.10 },
'gemini-2.5-pro': { input: 1.25, output: 10 },
'gemini-2.5-flash': { input: 0.15, output: 0.60 },
'gemini-1.5-pro': { input: 1.25, output: 5 },
'gemini-1.5-flash': { input: 0.075, output: 0.30 },
// Local / unknown models
'default': { input: 0, output: 0 },
// Bedrock (Meta Llama)
'meta.llama3-1-70b-instruct-v1:0': { input: 0.72, output: 0.72 },
'meta.llama3-1-8b-instruct-v1:0': { input: 0.22, output: 0.22 },
// Bedrock (Amazon Titan)
'amazon.titan-text-express-v1': { input: 0.20, output: 0.60 },
};
/**
+332
View File
@@ -0,0 +1,332 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { GeminiClient } from './gemini.js';
// Shared mock functions
const mockGenerateContent = vi.fn();
const mockGenerateContentStream = vi.fn();
const mockGetGenerativeModel = vi.fn().mockReturnValue({
generateContent: mockGenerateContent,
generateContentStream: mockGenerateContentStream,
});
vi.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: vi.fn().mockImplementation(() => ({
getGenerativeModel: mockGetGenerativeModel,
})),
}));
function makeResponse(parts: unknown[], finishReason = 'STOP', usage = { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 }) {
return {
response: {
candidates: [{
index: 0,
content: { role: 'model', parts },
finishReason,
}],
usageMetadata: usage,
text: () => {
const textParts = parts.filter((p: unknown) => typeof p === 'object' && p !== null && 'text' in p);
if (textParts.length === 0) throw new Error('No text parts');
return textParts.map((p: unknown) => (p as { text: string }).text).join('');
},
functionCalls: () => {
const fcParts = parts.filter((p: unknown) => typeof p === 'object' && p !== null && 'functionCall' in p);
if (fcParts.length === 0) return undefined;
return fcParts.map((p: unknown) => (p as { functionCall: { name: string; args: object } }).functionCall);
},
},
};
}
describe('GeminiClient', () => {
beforeEach(() => {
vi.clearAllMocks();
mockGenerateContent.mockResolvedValue(
makeResponse([{ text: 'Hello from Gemini!' }]),
);
mockGetGenerativeModel.mockReturnValue({
generateContent: mockGenerateContent,
generateContentStream: mockGenerateContentStream,
});
});
it('sends messages and returns response', async () => {
const client = new GeminiClient({
apiKey: 'test-key',
model: 'gemini-2.0-flash',
});
const response = await client.chat({
messages: [{ role: 'user', content: 'Hello' }],
});
expect(response.content).toBe('Hello from Gemini!');
expect(response.stopReason).toBe('end_turn');
expect(response.usage.inputTokens).toBe(10);
expect(response.usage.outputTokens).toBe(5);
});
it('passes system instruction to model', async () => {
const client = new GeminiClient({
apiKey: 'test-key',
model: 'gemini-2.0-flash',
});
await client.chat({
messages: [{ role: 'user', content: 'Hello' }],
system: 'You are a helpful assistant',
});
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
expect.objectContaining({
systemInstruction: 'You are a helpful assistant',
}),
);
});
it('converts assistant role to model role', async () => {
const client = new GeminiClient({
apiKey: 'test-key',
model: 'gemini-2.0-flash',
});
await client.chat({
messages: [
{ role: 'user', content: 'Hello' },
{ role: 'assistant', content: 'Hi there!' },
{ role: 'user', content: 'How are you?' },
],
});
expect(mockGenerateContent).toHaveBeenCalledWith({
contents: [
{ role: 'user', parts: [{ text: 'Hello' }] },
{ role: 'model', parts: [{ text: 'Hi there!' }] },
{ role: 'user', parts: [{ text: 'How are you?' }] },
],
});
});
it('maps MAX_TOKENS finish reason', async () => {
mockGenerateContent.mockResolvedValueOnce(
makeResponse([{ text: 'Truncated...' }], 'MAX_TOKENS'),
);
const client = new GeminiClient({
apiKey: 'test-key',
model: 'gemini-2.0-flash',
});
const response = await client.chat({
messages: [{ role: 'user', content: 'Write a long story' }],
});
expect(response.stopReason).toBe('max_tokens');
});
it('uses environment variable for API key when not provided', () => {
const originalEnv = process.env.GOOGLE_API_KEY;
process.env.GOOGLE_API_KEY = 'env-key';
try {
// Just construct — we verify it doesn't throw
const _client = new GeminiClient({ model: 'gemini-2.0-flash' });
expect(_client).toBeDefined();
} finally {
if (originalEnv !== undefined) {
process.env.GOOGLE_API_KEY = originalEnv;
} else {
delete process.env.GOOGLE_API_KEY;
}
}
});
});
describe('GeminiClient streaming', () => {
beforeEach(() => {
vi.clearAllMocks();
mockGenerateContentStream.mockResolvedValue({
stream: (async function* () {
yield {
text: () => 'Hello ',
functionCalls: () => undefined,
candidates: [{ content: { parts: [{ text: 'Hello ' }] } }],
usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 2, totalTokenCount: 12 },
};
yield {
text: () => 'from Gemini!',
functionCalls: () => undefined,
candidates: [{ content: { parts: [{ text: 'from Gemini!' }] } }],
usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 },
};
})(),
response: Promise.resolve({
usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 },
}),
});
mockGetGenerativeModel.mockReturnValue({
generateContent: mockGenerateContent,
generateContentStream: mockGenerateContentStream,
});
});
it('streams messages chunk by chunk', async () => {
const client = new GeminiClient({
apiKey: 'test-key',
model: 'gemini-2.0-flash',
});
const chunks: string[] = [];
let finalUsage: { inputTokens: number; outputTokens: number } | undefined;
for await (const event of client.chatStream({
messages: [{ role: 'user', content: 'Hello' }],
})) {
if (event.type === 'content' && event.content) {
chunks.push(event.content);
}
if (event.type === 'done' && event.usage) {
finalUsage = event.usage;
}
}
expect(chunks.length).toBeGreaterThan(0);
expect(chunks.join('')).toBe('Hello from Gemini!');
expect(finalUsage).toEqual({ inputTokens: 10, outputTokens: 5 });
});
it('yields error event on stream failure', async () => {
mockGenerateContentStream.mockRejectedValueOnce(new Error('Network error'));
const client = new GeminiClient({
apiKey: 'test-key',
model: 'gemini-2.0-flash',
});
const events: { type: string; error?: Error }[] = [];
for await (const event of client.chatStream({
messages: [{ role: 'user', content: 'Hello' }],
})) {
events.push(event);
}
expect(events).toHaveLength(1);
expect(events[0].type).toBe('error');
expect(events[0].error?.message).toBe('Network error');
});
});
describe('GeminiClient tool use', () => {
beforeEach(() => {
vi.clearAllMocks();
mockGetGenerativeModel.mockReturnValue({
generateContent: mockGenerateContent,
generateContentStream: mockGenerateContentStream,
});
});
it('passes tools and parses function call response', async () => {
mockGenerateContent.mockResolvedValueOnce(
makeResponse([{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }]),
);
const client = new GeminiClient({
apiKey: 'test-key',
model: 'gemini-2.0-flash',
});
const response = await client.chat({
messages: [{ role: 'user', content: 'list files' }],
tools: [{
name: 'shell.exec',
description: 'Run shell command',
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
}],
});
expect(response.stopReason).toBe('tool_use');
expect(response.toolCalls).toHaveLength(1);
expect(response.toolCalls![0].name).toBe('shell.exec');
expect(response.toolCalls![0].args).toEqual({ command: 'ls' });
// Verify tools were passed to getGenerativeModel
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
expect.objectContaining({
tools: [{
functionDeclarations: [{
name: 'shell.exec',
description: 'Run shell command',
parameters: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
}],
}],
}),
);
});
it('handles mixed text and function call response', async () => {
mockGenerateContent.mockResolvedValueOnce(
makeResponse([
{ text: 'Let me run that for you.' },
{ functionCall: { name: 'shell.exec', args: { command: 'ls -la' } } },
]),
);
const client = new GeminiClient({
apiKey: 'test-key',
model: 'gemini-2.0-flash',
});
const response = await client.chat({
messages: [{ role: 'user', content: 'list files in detail' }],
tools: [{
name: 'shell.exec',
description: 'Run shell command',
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
}],
});
expect(response.content).toBe('Let me run that for you.');
expect(response.stopReason).toBe('tool_use');
expect(response.toolCalls).toHaveLength(1);
});
it('streams function calls', async () => {
mockGenerateContentStream.mockResolvedValueOnce({
stream: (async function* () {
yield {
text: () => '',
functionCalls: () => [{ name: 'shell.exec', args: { command: 'ls' } }],
candidates: [{ content: { parts: [{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }] } }],
usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 10, totalTokenCount: 25 },
};
})(),
response: Promise.resolve({
usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 10, totalTokenCount: 25 },
}),
});
const client = new GeminiClient({
apiKey: 'test-key',
model: 'gemini-2.0-flash',
});
const toolCalls: { name: string; args: unknown }[] = [];
for await (const event of client.chatStream({
messages: [{ role: 'user', content: 'list files' }],
tools: [{
name: 'shell.exec',
description: 'Run shell command',
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
}],
})) {
if (event.type === 'tool_use' && event.toolCall) {
toolCalls.push({ name: event.toolCall.name, args: event.toolCall.args });
}
}
expect(toolCalls).toHaveLength(1);
expect(toolCalls[0].name).toBe('shell.exec');
expect(toolCalls[0].args).toEqual({ command: 'ls' });
});
});
+175
View File
@@ -0,0 +1,175 @@
import { GoogleGenerativeAI } from '@google/generative-ai';
import type { GenerativeModel, Content, FunctionDeclaration, FunctionDeclarationSchema } from '@google/generative-ai';
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall, ToolDefinition } from './types.js';
export interface GeminiClientConfig {
apiKey?: string;
model: string;
maxTokens?: number;
}
export class GeminiClient implements ModelClient {
private genAI: GoogleGenerativeAI;
private model: string;
private defaultMaxTokens: number;
constructor(config: GeminiClientConfig) {
const apiKey = config.apiKey ?? process.env.GOOGLE_API_KEY ?? '';
this.genAI = new GoogleGenerativeAI(apiKey);
this.model = config.model;
this.defaultMaxTokens = config.maxTokens ?? 8192;
}
private getModel(request: ChatRequest): GenerativeModel {
const tools = request.tools && request.tools.length > 0
? [{ functionDeclarations: request.tools.map(t => convertToolDefinition(t)) }]
: undefined;
return this.genAI.getGenerativeModel({
model: this.model,
systemInstruction: request.system || undefined,
tools,
generationConfig: {
maxOutputTokens: request.maxTokens ?? this.defaultMaxTokens,
},
});
}
async chat(request: ChatRequest): Promise<ChatResponse> {
const model = this.getModel(request);
const contents = convertMessages(request.messages);
const result = await model.generateContent({ contents });
const response = result.response;
const candidate = response.candidates?.[0];
// Extract text via the helper method
let content = '';
try {
content = response.text();
} catch {
// text() throws if blocked — fall back to manual extraction
const textParts = candidate?.content?.parts?.filter(p => 'text' in p && p.text !== undefined) ?? [];
content = textParts.map(p => (p as { text: string }).text).join('');
}
// Extract function calls via the helper method
const functionCalls = response.functionCalls();
const toolCalls: ModelToolCall[] = functionCalls
? functionCalls.map((fc, i) => ({
id: `gemini_${Date.now()}_${i}`,
name: fc.name,
args: fc.args,
}))
: [];
// Map finish reason
const finishReason = candidate?.finishReason;
let stopReason: string = 'end_turn';
if (toolCalls.length > 0) {
stopReason = 'tool_use';
} else if (finishReason === 'MAX_TOKENS') {
stopReason = 'max_tokens';
} else if (finishReason === 'STOP') {
stopReason = 'end_turn';
} else if (finishReason) {
stopReason = finishReason.toLowerCase();
}
// Extract usage
const usageMetadata = response.usageMetadata;
const usage = {
inputTokens: usageMetadata?.promptTokenCount ?? 0,
outputTokens: usageMetadata?.candidatesTokenCount ?? 0,
};
return {
content,
stopReason,
usage,
...(toolCalls.length > 0 ? { toolCalls } : {}),
};
}
async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
const model = this.getModel(request);
const contents = convertMessages(request.messages);
try {
const result = await model.generateContentStream({ contents });
let totalInputTokens = 0;
let totalOutputTokens = 0;
for await (const chunk of result.stream) {
// Use the text() helper to extract text content from this chunk
try {
const text = chunk.text();
if (text) {
yield { type: 'content', content: text };
}
} catch {
// text() throws if blocked — skip
}
// Check for function calls in streaming chunks
const calls = chunk.functionCalls();
if (calls) {
for (const fc of calls) {
yield {
type: 'tool_use',
toolCall: {
id: `gemini_${Date.now()}`,
name: fc.name,
args: fc.args,
},
};
}
}
// Track usage from chunks
if (chunk.usageMetadata) {
totalInputTokens = chunk.usageMetadata.promptTokenCount ?? totalInputTokens;
totalOutputTokens = chunk.usageMetadata.candidatesTokenCount ?? totalOutputTokens;
}
}
// Final aggregated response for usage
const aggregated = await result.response;
const usageMetadata = aggregated.usageMetadata;
yield {
type: 'done',
usage: {
inputTokens: usageMetadata?.promptTokenCount ?? totalInputTokens,
outputTokens: usageMetadata?.candidatesTokenCount ?? totalOutputTokens,
},
};
} catch (error) {
yield {
type: 'error',
error: error instanceof Error ? error : new Error(String(error)),
};
}
}
}
/** Convert Flynn's Message[] to Gemini Content[] format */
function convertMessages(messages: { role: string; content: string }[]): Content[] {
return messages.map(m => ({
role: m.role === 'assistant' ? 'model' : 'user',
parts: [{ text: m.content }],
}));
}
/** Convert Flynn's ToolDefinition to Gemini FunctionDeclaration format */
function convertToolDefinition(tool: ToolDefinition): FunctionDeclaration {
// The Gemini SDK's FunctionDeclarationSchema expects `type: SchemaType` (enum)
// but the actual wire format accepts string values. We pass the schema through
// as-is since the SDK serialises it to JSON for the API request.
return {
name: tool.name,
description: tool.description,
parameters: tool.input_schema as unknown as FunctionDeclarationSchema,
};
}
+2
View File
@@ -1,5 +1,7 @@
export { AnthropicClient, type AnthropicClientConfig } from './anthropic.js';
export { OpenAIClient, type OpenAIClientConfig } from './openai.js';
export { GeminiClient, type GeminiClientConfig } from './gemini.js';
export { BedrockClient, type BedrockClientConfig } from './bedrock.js';
export { OllamaClient, type OllamaClientConfig } from './local/index.js';
export { LlamaCppClient, type LlamaCppClientConfig } from './local/index.js';
export { ModelRouter, type ModelRouterConfig, type ModelTier } from './router.js';