Files
flynn/src/models/gemini.ts
T
William Valentin 0eb1f7a073 feat: add Gemini and Bedrock model providers
Add native GeminiClient using @google/generative-ai SDK and BedrockClient
using @aws-sdk/client-bedrock-runtime. Replace the previous Gemini fallback
(OpenAI-compatible shim) with the real implementation. Add OpenRouter as a
provider option (OpenAI-compatible with custom baseURL). Update model costs,
doctor CLI checks, and client factory tests.
2026-02-06 16:51:32 -08:00

176 lines
5.6 KiB
TypeScript

import { GoogleGenerativeAI } from '@google/generative-ai';
import type { GenerativeModel, Content, FunctionDeclaration, FunctionDeclarationSchema } from '@google/generative-ai';
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall, ToolDefinition } from './types.js';
/**
 * Construction options for {@link GeminiClient}.
 */
export interface GeminiClientConfig {
  /** Gemini model identifier handed to the SDK for every request. */
  model: string;

  /**
   * API key for the Gemini API. When omitted, the client reads the
   * `GOOGLE_API_KEY` environment variable instead.
   */
  apiKey?: string;

  /**
   * Default output-token limit for requests that do not set their own
   * `maxTokens`. The client uses 8192 when this is omitted.
   */
  maxTokens?: number;
}
/**
 * ModelClient backed by the native `@google/generative-ai` SDK.
 *
 * Every request builds a fresh GenerativeModel carrying that request's system
 * instruction, tool declarations and output-token limit, then translates the
 * SDK response into the provider-neutral ChatResponse / ChatStreamEvent shapes.
 */
export class GeminiClient implements ModelClient {
  private readonly genAI: GoogleGenerativeAI;
  private readonly model: string;
  private readonly defaultMaxTokens: number;

  /**
   * @param config Model name plus optional API key and default token limit.
   *   The key falls back to the `GOOGLE_API_KEY` environment variable and then
   *   to an empty string; the token limit falls back to 8192.
   */
  constructor(config: GeminiClientConfig) {
    const apiKey = config.apiKey ?? process.env.GOOGLE_API_KEY ?? '';
    this.genAI = new GoogleGenerativeAI(apiKey);
    this.model = config.model;
    this.defaultMaxTokens = config.maxTokens ?? 8192;
  }

  /**
   * Build the SDK model handle for a single request.
   *
   * @param request Source of the system prompt, tool list and token limit.
   * @returns A GenerativeModel configured for this request only.
   */
  private getModel(request: ChatRequest): GenerativeModel {
    // Omit `tools` entirely when there is nothing to declare.
    const tools = request.tools && request.tools.length > 0
      ? [{ functionDeclarations: request.tools.map(t => convertToolDefinition(t)) }]
      : undefined;

    return this.genAI.getGenerativeModel({
      model: this.model,
      // An empty system prompt is treated the same as no system prompt.
      systemInstruction: request.system || undefined,
      tools,
      generationConfig: {
        maxOutputTokens: request.maxTokens ?? this.defaultMaxTokens,
      },
    });
  }

  /**
   * Run a single non-streaming completion.
   *
   * @param request Messages, optional system prompt, tools and token limit.
   * @returns Text content, a normalised stop reason, token usage and, when the
   *   model requested any, the tool calls (the `toolCalls` key is absent
   *   otherwise).
   * @throws Whatever the SDK throws for transport/API failures, and the SDK's
   *   response error when the response carries no candidate at all.
   */
  async chat(request: ChatRequest): Promise<ChatResponse> {
    const model = this.getModel(request);
    const contents = convertMessages(request.messages);
    const result = await model.generateContent({ contents });
    const response = result.response;
    const candidate = response.candidates?.[0];

    // Extract text via the helper method.
    let content = '';
    try {
      content = response.text();
    } catch {
      // text() throws if blocked — fall back to whatever text parts exist.
      const textParts = candidate?.content?.parts?.filter(p => 'text' in p && p.text !== undefined) ?? [];
      content = textParts.map(p => (p as { text: string }).text).join('');
    }

    // functionCalls() throws under the same conditions as text(). Without this
    // guard the text fallback above could never be returned to the caller.
    let functionCalls: ReturnType<typeof response.functionCalls>;
    try {
      functionCalls = response.functionCalls();
    } catch (error) {
      // With no candidate there is no finish reason to report and nothing to
      // salvage, so keep surfacing the SDK error.
      if (!candidate) {
        throw error;
      }
      functionCalls = undefined;
    }

    // The SDK response carries no call ids here, so synthesise one per call.
    const toolCalls: ModelToolCall[] = (functionCalls ?? []).map((fc, i) => ({
      id: `gemini_${Date.now()}_${i}`,
      name: fc.name,
      args: fc.args,
    }));

    // Map finish reason. A pending tool call takes precedence over the rest;
    // unrecognised reasons are passed through lower-cased.
    const finishReason = candidate?.finishReason;
    let stopReason: string = 'end_turn';
    if (toolCalls.length > 0) {
      stopReason = 'tool_use';
    } else if (finishReason === 'MAX_TOKENS') {
      stopReason = 'max_tokens';
    } else if (finishReason === 'STOP') {
      stopReason = 'end_turn';
    } else if (finishReason) {
      stopReason = finishReason.toLowerCase();
    }

    // Extract usage, reporting zero when the API omits a counter.
    const usageMetadata = response.usageMetadata;
    const usage = {
      inputTokens: usageMetadata?.promptTokenCount ?? 0,
      outputTokens: usageMetadata?.candidatesTokenCount ?? 0,
    };

    return {
      content,
      stopReason,
      usage,
      ...(toolCalls.length > 0 ? { toolCalls } : {}),
    };
  }

  /**
   * Run a streaming completion.
   *
   * Yields `content` events for text, `tool_use` events for function calls,
   * and finally a `done` event with token usage. Failures are never thrown to
   * the consumer: they are yielded as a single `error` event that ends the
   * stream.
   *
   * @param request Messages, optional system prompt, tools and token limit.
   */
  async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
    const model = this.getModel(request);
    const contents = convertMessages(request.messages);

    try {
      const result = await model.generateContentStream({ contents });
      let totalInputTokens = 0;
      let totalOutputTokens = 0;
      // Per-stream counter so every synthesised id is unique even when several
      // calls arrive in one chunk or within the same millisecond.
      let toolCallIndex = 0;

      for await (const chunk of result.stream) {
        // Use the text() helper to extract text content from this chunk.
        try {
          const text = chunk.text();
          if (text) {
            yield { type: 'content', content: text };
          }
        } catch {
          // text() throws if blocked — skip
        }

        // Check for function calls in streaming chunks. This call is left
        // unguarded on purpose: the `done` event has no stop reason, so an
        // SDK error here is reported through the outer catch as `error`.
        const calls = chunk.functionCalls();
        if (calls) {
          for (const fc of calls) {
            yield {
              type: 'tool_use',
              toolCall: {
                id: `gemini_${Date.now()}_${toolCallIndex++}`,
                name: fc.name,
                args: fc.args,
              },
            };
          }
        }

        // Track usage from chunks, keeping the last value seen per counter.
        if (chunk.usageMetadata) {
          totalInputTokens = chunk.usageMetadata.promptTokenCount ?? totalInputTokens;
          totalOutputTokens = chunk.usageMetadata.candidatesTokenCount ?? totalOutputTokens;
        }
      }

      // Prefer usage from the aggregated response; fall back to chunk totals.
      const aggregated = await result.response;
      const usageMetadata = aggregated.usageMetadata;
      yield {
        type: 'done',
        usage: {
          inputTokens: usageMetadata?.promptTokenCount ?? totalInputTokens,
          outputTokens: usageMetadata?.candidatesTokenCount ?? totalOutputTokens,
        },
      };
    } catch (error) {
      yield {
        type: 'error',
        error: error instanceof Error ? error : new Error(String(error)),
      };
    }
  }
}
/**
 * Translate Flynn chat messages into the Gemini SDK's `Content[]` shape.
 *
 * The `assistant` role becomes Gemini's `model` role; every other role is
 * sent as `user`. Each message body becomes a single text part.
 *
 * @param messages Conversation turns in order.
 * @returns One Content entry per input message, in the same order.
 */
function convertMessages(messages: { role: string; content: string }[]): Content[] {
  const contents: Content[] = [];
  for (const { role, content } of messages) {
    const geminiRole = role === 'assistant' ? 'model' : 'user';
    contents.push({ role: geminiRole, parts: [{ text: content }] });
  }
  return contents;
}
/**
 * Translate a Flynn ToolDefinition into a Gemini FunctionDeclaration.
 *
 * @param tool Tool name, description and input schema.
 * @returns A declaration whose `parameters` is the tool's `input_schema`
 *   object itself (same reference, not a copy).
 */
function convertToolDefinition(tool: ToolDefinition): FunctionDeclaration {
  const { name, description, input_schema } = tool;
  // The SDK types `FunctionDeclarationSchema.type` as the SchemaType enum,
  // while the wire format takes plain strings. The schema is handed over
  // untouched and only re-typed, since the SDK serialises it to JSON for the
  // API request.
  const parameters = input_schema as unknown as FunctionDeclarationSchema;
  return { name, description, parameters };
}