feat: add Gemini and Bedrock model providers
Add native GeminiClient using @google/generative-ai SDK and BedrockClient using @aws-sdk/client-bedrock-runtime. Replace the previous Gemini fallback (OpenAI-compatible shim) with the real implementation. Add OpenRouter as a provider option (OpenAI-compatible with custom baseURL). Update model costs, doctor CLI checks, and client factory tests.
This commit is contained in:
@@ -0,0 +1,175 @@
|
||||
import { GoogleGenerativeAI } from '@google/generative-ai';
|
||||
import type { GenerativeModel, Content, FunctionDeclaration, FunctionDeclarationSchema } from '@google/generative-ai';
|
||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall, ToolDefinition } from './types.js';
|
||||
|
||||
export interface GeminiClientConfig {
|
||||
apiKey?: string;
|
||||
model: string;
|
||||
maxTokens?: number;
|
||||
}
|
||||
|
||||
export class GeminiClient implements ModelClient {
|
||||
private genAI: GoogleGenerativeAI;
|
||||
private model: string;
|
||||
private defaultMaxTokens: number;
|
||||
|
||||
constructor(config: GeminiClientConfig) {
|
||||
const apiKey = config.apiKey ?? process.env.GOOGLE_API_KEY ?? '';
|
||||
this.genAI = new GoogleGenerativeAI(apiKey);
|
||||
this.model = config.model;
|
||||
this.defaultMaxTokens = config.maxTokens ?? 8192;
|
||||
}
|
||||
|
||||
private getModel(request: ChatRequest): GenerativeModel {
|
||||
const tools = request.tools && request.tools.length > 0
|
||||
? [{ functionDeclarations: request.tools.map(t => convertToolDefinition(t)) }]
|
||||
: undefined;
|
||||
|
||||
return this.genAI.getGenerativeModel({
|
||||
model: this.model,
|
||||
systemInstruction: request.system || undefined,
|
||||
tools,
|
||||
generationConfig: {
|
||||
maxOutputTokens: request.maxTokens ?? this.defaultMaxTokens,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async chat(request: ChatRequest): Promise<ChatResponse> {
|
||||
const model = this.getModel(request);
|
||||
const contents = convertMessages(request.messages);
|
||||
|
||||
const result = await model.generateContent({ contents });
|
||||
const response = result.response;
|
||||
const candidate = response.candidates?.[0];
|
||||
|
||||
// Extract text via the helper method
|
||||
let content = '';
|
||||
try {
|
||||
content = response.text();
|
||||
} catch {
|
||||
// text() throws if blocked — fall back to manual extraction
|
||||
const textParts = candidate?.content?.parts?.filter(p => 'text' in p && p.text !== undefined) ?? [];
|
||||
content = textParts.map(p => (p as { text: string }).text).join('');
|
||||
}
|
||||
|
||||
// Extract function calls via the helper method
|
||||
const functionCalls = response.functionCalls();
|
||||
const toolCalls: ModelToolCall[] = functionCalls
|
||||
? functionCalls.map((fc, i) => ({
|
||||
id: `gemini_${Date.now()}_${i}`,
|
||||
name: fc.name,
|
||||
args: fc.args,
|
||||
}))
|
||||
: [];
|
||||
|
||||
// Map finish reason
|
||||
const finishReason = candidate?.finishReason;
|
||||
let stopReason: string = 'end_turn';
|
||||
if (toolCalls.length > 0) {
|
||||
stopReason = 'tool_use';
|
||||
} else if (finishReason === 'MAX_TOKENS') {
|
||||
stopReason = 'max_tokens';
|
||||
} else if (finishReason === 'STOP') {
|
||||
stopReason = 'end_turn';
|
||||
} else if (finishReason) {
|
||||
stopReason = finishReason.toLowerCase();
|
||||
}
|
||||
|
||||
// Extract usage
|
||||
const usageMetadata = response.usageMetadata;
|
||||
const usage = {
|
||||
inputTokens: usageMetadata?.promptTokenCount ?? 0,
|
||||
outputTokens: usageMetadata?.candidatesTokenCount ?? 0,
|
||||
};
|
||||
|
||||
return {
|
||||
content,
|
||||
stopReason,
|
||||
usage,
|
||||
...(toolCalls.length > 0 ? { toolCalls } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
|
||||
const model = this.getModel(request);
|
||||
const contents = convertMessages(request.messages);
|
||||
|
||||
try {
|
||||
const result = await model.generateContentStream({ contents });
|
||||
|
||||
let totalInputTokens = 0;
|
||||
let totalOutputTokens = 0;
|
||||
|
||||
for await (const chunk of result.stream) {
|
||||
// Use the text() helper to extract text content from this chunk
|
||||
try {
|
||||
const text = chunk.text();
|
||||
if (text) {
|
||||
yield { type: 'content', content: text };
|
||||
}
|
||||
} catch {
|
||||
// text() throws if blocked — skip
|
||||
}
|
||||
|
||||
// Check for function calls in streaming chunks
|
||||
const calls = chunk.functionCalls();
|
||||
if (calls) {
|
||||
for (const fc of calls) {
|
||||
yield {
|
||||
type: 'tool_use',
|
||||
toolCall: {
|
||||
id: `gemini_${Date.now()}`,
|
||||
name: fc.name,
|
||||
args: fc.args,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Track usage from chunks
|
||||
if (chunk.usageMetadata) {
|
||||
totalInputTokens = chunk.usageMetadata.promptTokenCount ?? totalInputTokens;
|
||||
totalOutputTokens = chunk.usageMetadata.candidatesTokenCount ?? totalOutputTokens;
|
||||
}
|
||||
}
|
||||
|
||||
// Final aggregated response for usage
|
||||
const aggregated = await result.response;
|
||||
const usageMetadata = aggregated.usageMetadata;
|
||||
|
||||
yield {
|
||||
type: 'done',
|
||||
usage: {
|
||||
inputTokens: usageMetadata?.promptTokenCount ?? totalInputTokens,
|
||||
outputTokens: usageMetadata?.candidatesTokenCount ?? totalOutputTokens,
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
yield {
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Convert Flynn's Message[] to Gemini Content[] format */
|
||||
function convertMessages(messages: { role: string; content: string }[]): Content[] {
|
||||
return messages.map(m => ({
|
||||
role: m.role === 'assistant' ? 'model' : 'user',
|
||||
parts: [{ text: m.content }],
|
||||
}));
|
||||
}
|
||||
|
||||
/** Convert Flynn's ToolDefinition to Gemini FunctionDeclaration format */
|
||||
function convertToolDefinition(tool: ToolDefinition): FunctionDeclaration {
|
||||
// The Gemini SDK's FunctionDeclarationSchema expects `type: SchemaType` (enum)
|
||||
// but the actual wire format accepts string values. We pass the schema through
|
||||
// as-is since the SDK serialises it to JSON for the API request.
|
||||
return {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.input_schema as unknown as FunctionDeclarationSchema,
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user