Files
flynn/src/models/gemini.ts
T

245 lines
7.6 KiB
TypeScript

import { GoogleGenerativeAI } from '@google/generative-ai';
import type { GenerativeModel, Content, Part, FunctionDeclaration, FunctionDeclarationSchema } from '@google/generative-ai';
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall, ToolDefinition, Message } from './types.js';
/**
 * Construction options accepted by `GeminiClient`.
 */
export interface GeminiClientConfig {
  /** Gemini model identifier handed to `getGenerativeModel` for every request. */
  model: string;
  /** API key; when omitted, `GeminiClient` reads `GOOGLE_API_KEY` from the environment. */
  apiKey?: string;
  /** Output-token cap for requests that do not set `maxTokens`; `GeminiClient` uses 8192 when omitted. */
  maxTokens?: number;
}
/**
 * `ModelClient` backed by Google's Gemini API through `@google/generative-ai`.
 *
 * Converts Flynn's provider-neutral `ChatRequest` into `generateContent` /
 * `generateContentStream` calls and maps the results back into `ChatResponse` /
 * `ChatStreamEvent` values.
 */
export class GeminiClient implements ModelClient {
  /** Reasoning-token budget sent as `thinkingConfig.thinkingBudget` when `request.thinking` is set. */
  private static readonly THINKING_BUDGET = 4096;

  private readonly genAI: GoogleGenerativeAI;
  private readonly model: string;
  private readonly defaultMaxTokens: number;
  /** Per-client sequence number that keeps synthesized tool-call IDs unique. */
  private toolCallSeq = 0;

  /**
   * @param config - Model name plus optional API key and default output-token cap.
   *   Without `apiKey`, the `GOOGLE_API_KEY` environment variable is used (or `''` if
   *   that is unset too); without `maxTokens`, requests default to 8192 output tokens.
   */
  constructor(config: GeminiClientConfig) {
    const apiKey = config.apiKey ?? process.env.GOOGLE_API_KEY ?? '';
    this.genAI = new GoogleGenerativeAI(apiKey);
    this.model = config.model;
    this.defaultMaxTokens = config.maxTokens ?? 8192;
  }

  /** Build a model handle carrying this request's system prompt, tools and generation limits. */
  private getModel(request: ChatRequest): GenerativeModel {
    const tools =
      request.tools && request.tools.length > 0
        ? [{ functionDeclarations: request.tools.map((t) => convertToolDefinition(t)) }]
        : undefined;
    const generationConfig: Record<string, unknown> = {
      maxOutputTokens: request.maxTokens ?? this.defaultMaxTokens,
    };
    // Extended thinking mode
    if (request.thinking) {
      generationConfig.thinkingConfig = { thinkingBudget: GeminiClient.THINKING_BUDGET };
    }
    return this.genAI.getGenerativeModel({
      model: this.model,
      systemInstruction: request.system || undefined,
      tools,
      generationConfig,
    });
  }

  /**
   * Gemini does not assign IDs to function calls, so synthesize one. The counter keeps
   * IDs distinct even when several calls are produced within the same millisecond
   * (e.g. multiple calls in one streamed chunk).
   */
  private nextToolCallId(): string {
    return `gemini_${Date.now()}_${this.toolCallSeq++}`;
  }

  /**
   * Split a candidate's parts into concatenated text and function calls.
   *
   * Parts are read directly rather than through the SDK's `text()` / `functionCalls()`
   * helpers: both helpers throw when the response was blocked, which previously made
   * the text fallback unreachable and aborted streams on a blocked chunk.
   */
  private extractParts(parts: Part[] | undefined): { text: string; toolCalls: ModelToolCall[] } {
    let text = '';
    const toolCalls: ModelToolCall[] = [];
    for (const part of parts ?? []) {
      if (part.text !== undefined) {
        text += part.text;
      }
      if (part.functionCall) {
        toolCalls.push({
          id: this.nextToolCallId(),
          name: part.functionCall.name,
          args: part.functionCall.args,
        });
      }
    }
    return { text, toolCalls };
  }

  /**
   * Send a non-streaming request.
   *
   * Text and function calls come from the first candidate. `stopReason` is `tool_use`
   * when any function call is returned; otherwise Gemini's finish reason is mapped
   * (`MAX_TOKENS` -> `max_tokens`, `STOP` -> `end_turn`, anything else lower-cased).
   * If the prompt itself was blocked (no candidate), the lower-cased block reason is used.
   */
  async chat(request: ChatRequest): Promise<ChatResponse> {
    const model = this.getModel(request);
    const contents = await convertMessages(request.messages);
    const result = await model.generateContent({ contents });
    const response = result.response;
    const candidate = response.candidates?.[0];

    const { text: content, toolCalls } = this.extractParts(candidate?.content?.parts);

    // Map finish reason
    const finishReason = candidate?.finishReason;
    const blockReason = response.promptFeedback?.blockReason;
    let stopReason: string = 'end_turn';
    if (toolCalls.length > 0) {
      stopReason = 'tool_use';
    } else if (finishReason === 'MAX_TOKENS') {
      stopReason = 'max_tokens';
    } else if (finishReason === 'STOP') {
      stopReason = 'end_turn';
    } else if (finishReason) {
      stopReason = finishReason.toLowerCase();
    } else if (!candidate && blockReason) {
      // Prompt rejected before any candidate was produced; surface why instead of 'end_turn'.
      stopReason = blockReason.toLowerCase();
    }

    const usageMetadata = response.usageMetadata;
    const usage = {
      inputTokens: usageMetadata?.promptTokenCount ?? 0,
      outputTokens: usageMetadata?.candidatesTokenCount ?? 0,
    };

    return {
      content,
      stopReason,
      usage,
      ...(toolCalls.length > 0 ? { toolCalls } : {}),
    };
  }

  /**
   * Stream a request, yielding `content` events for text, `tool_use` events for function
   * calls, then a final `done` event with token usage. Any failure — including model
   * setup and message conversion — is reported as a single `error` event.
   */
  async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
    try {
      const model = this.getModel(request);
      const contents = await convertMessages(request.messages);
      const result = await model.generateContentStream({ contents });
      let totalInputTokens = 0;
      let totalOutputTokens = 0;

      for await (const chunk of result.stream) {
        const { text, toolCalls } = this.extractParts(chunk.candidates?.[0]?.content?.parts);
        if (text) {
          yield { type: 'content', content: text };
        }
        for (const toolCall of toolCalls) {
          yield { type: 'tool_use', toolCall };
        }
        // Usage metadata may arrive on intermediate chunks; keep the latest values seen.
        if (chunk.usageMetadata) {
          totalInputTokens = chunk.usageMetadata.promptTokenCount ?? totalInputTokens;
          totalOutputTokens = chunk.usageMetadata.candidatesTokenCount ?? totalOutputTokens;
        }
      }

      // Prefer usage from the SDK's aggregated response; fall back to per-chunk values.
      const aggregated = await result.response;
      const usageMetadata = aggregated.usageMetadata;
      yield {
        type: 'done',
        usage: {
          inputTokens: usageMetadata?.promptTokenCount ?? totalInputTokens,
          outputTokens: usageMetadata?.candidatesTokenCount ?? totalOutputTokens,
        },
      };
    } catch (error) {
      yield {
        type: 'error',
        error: error instanceof Error ? error : new Error(String(error)),
      };
    }
  }
}
/** Convert Flynn's Message[] to Gemini Content[] format, including multimodal parts */
/**
 * Translate Flynn `Message`s into Gemini `Content` entries, including multimodal parts.
 *
 * Roles: `assistant` becomes Gemini's `model`; every other role is sent as `user`.
 * String content becomes a single text part. Array content is converted part by part
 * (all parts and messages concurrently, order preserved):
 * - `text`  -> text part
 * - `image` -> `inlineData` for base64 sources with data; URL sources are downloaded via
 *              `fetchImageAsInlineData`; otherwise an `[Image: ...]` text placeholder
 * - `audio` -> `inlineData` built from the source's media type and data
 * - anything else -> the part serialised as JSON text
 */
async function convertMessages(messages: Message[]): Promise<Content[]> {
  const toContent = async (message: Message): Promise<Content> => {
    const role = message.role === 'assistant' ? 'model' : 'user';

    if (typeof message.content === 'string') {
      return { role, parts: [{ text: message.content }] };
    }

    const converted = await Promise.all(
      message.content.map(async (part): Promise<Part> => {
        switch (part.type) {
          case 'text':
            return { text: part.text };

          case 'image': {
            const { source } = part;
            if (source.type === 'base64' && source.data) {
              return { inlineData: { mimeType: source.media_type, data: source.data } };
            }
            if (source.type === 'url' && source.url) {
              const downloaded = await fetchImageAsInlineData(source.url, source.media_type);
              if (downloaded) {
                return downloaded;
              }
            }
            // Neither inlined nor downloadable: keep a textual hint of what was there.
            return { text: `[Image: ${source.url ?? 'unavailable'}]` };
          }

          case 'audio':
            // Gemini accepts audio through inlineData, exactly like images.
            return { inlineData: { mimeType: part.source.media_type, data: part.source.data } };

          default:
            return { text: JSON.stringify(part) };
        }
      }),
    );

    return { role, parts: converted };
  };

  return Promise.all(messages.map(toContent));
}
/**
 * Download an image and wrap it as a Gemini `inlineData` part.
 *
 * Best-effort by design: every failure yields `null` so the caller can substitute a text
 * placeholder instead of failing the whole request. Failures covered:
 * - network error or timeout
 * - non-2xx status
 * - a response that is clearly not an image
 * - an empty or oversized body
 *
 * @param url - Image URL to fetch.
 * @param fallbackMimeType - MIME type used when the server sends no `Content-Type` or only
 *   a generic octet-stream type.
 * @param timeoutMs - Abort the request after this many milliseconds (default 10 s).
 * @param maxBytes - Reject images whose raw size exceeds this many bytes (default 20 MiB,
 *   in line with the Gemini API's ~20 MB limit on requests carrying inline data).
 * @returns The inline-data part, or `null` if the image could not be used.
 */
async function fetchImageAsInlineData(
  url: string,
  fallbackMimeType: string,
  timeoutMs = 10_000,
  maxBytes = 20 * 1024 * 1024,
): Promise<Part | null> {
  try {
    // Bound the request so an unresponsive host cannot stall message conversion forever.
    const response = await fetch(url, { signal: AbortSignal.timeout(timeoutMs) });

    // Release the underlying connection when rejecting a response without reading its body.
    const discardBody = (): null => {
      void response.body?.cancel().catch(() => undefined);
      return null;
    };

    if (!response.ok) {
      return discardBody();
    }

    const headerType =
      response.headers.get('content-type')?.split(';')[0]?.trim().toLowerCase() ?? '';
    let mimeType: string;
    if (headerType.startsWith('image/')) {
      mimeType = headerType;
    } else if (
      headerType === '' ||
      headerType === 'application/octet-stream' ||
      headerType === 'binary/octet-stream'
    ) {
      // Server gave no useful type (common for object storage); trust the caller's type.
      mimeType = fallbackMimeType;
    } else {
      // e.g. an HTML error or login page served with 200 — not image data.
      return discardBody();
    }

    const declaredLength = Number(response.headers.get('content-length'));
    if (Number.isFinite(declaredLength) && declaredLength > maxBytes) {
      return discardBody();
    }

    const bytes = await response.arrayBuffer();
    if (bytes.byteLength === 0 || bytes.byteLength > maxBytes) {
      return null;
    }

    return {
      inlineData: {
        mimeType,
        data: Buffer.from(bytes).toString('base64'),
      },
    };
  } catch {
    // Network failure, abort/timeout, or body read error: caller falls back to a placeholder.
    return null;
  }
}
/** Convert Flynn's ToolDefinition to Gemini FunctionDeclaration format */
/**
 * Map a Flynn `ToolDefinition` onto a Gemini `FunctionDeclaration`.
 *
 * `input_schema` is forwarded untouched as `parameters`. The SDK types `parameters.type`
 * as its `SchemaType` enum, but the request is serialised to JSON and the wire format
 * accepts the plain JSON-Schema strings Flynn already uses, so only a type-level cast
 * is needed here.
 */
function convertToolDefinition(tool: ToolDefinition): FunctionDeclaration {
  const { name, description, input_schema: inputSchema } = tool;
  const parameters = inputSchema as unknown as FunctionDeclarationSchema;
  return { name, description, parameters };
}