feat: add Gemini and Bedrock model providers
Add native GeminiClient using @google/generative-ai SDK and BedrockClient using @aws-sdk/client-bedrock-runtime. Replace the previous Gemini fallback (OpenAI-compatible shim) with the real implementation. Add OpenRouter as a provider option (OpenAI-compatible with custom baseURL). Update model costs, doctor CLI checks, and client factory tests.
This commit is contained in:
+7
-2
@@ -123,9 +123,14 @@ const checkModelConnectivity: Check = async (ctx) => {
|
||||
}
|
||||
|
||||
// Check if API key is present for providers that need one
|
||||
const needsKey = ['anthropic', 'openai', 'gemini'];
|
||||
const needsKey = ['anthropic', 'openai', 'gemini', 'openrouter'];
|
||||
if (needsKey.includes(model.provider) && !model.api_key && !model.auth_token) {
|
||||
const envVar = model.provider === 'anthropic' ? 'ANTHROPIC_API_KEY' : model.provider === 'openai' ? 'OPENAI_API_KEY' : undefined;
|
||||
const envVarMap: Record<string, string> = {
|
||||
anthropic: 'ANTHROPIC_API_KEY',
|
||||
openai: 'OPENAI_API_KEY',
|
||||
openrouter: 'OPENROUTER_API_KEY',
|
||||
};
|
||||
const envVar = envVarMap[model.provider];
|
||||
const hasEnv = envVar && process.env[envVar];
|
||||
if (!hasEnv) {
|
||||
return { status: 'warn', label: 'Model connectivity', detail: `${model.provider}/${model.model} — no API key or auth token found` };
|
||||
|
||||
@@ -4,6 +4,8 @@ import { AnthropicClient } from '../models/anthropic.js';
|
||||
import { OpenAIClient } from '../models/openai.js';
|
||||
import { OllamaClient } from '../models/local/ollama.js';
|
||||
import { LlamaCppClient } from '../models/local/llamacpp.js';
|
||||
import { GeminiClient } from '../models/gemini.js';
|
||||
import { BedrockClient } from '../models/bedrock.js';
|
||||
|
||||
describe('createClientFromConfig', () => {
|
||||
it('creates AnthropicClient for anthropic provider', () => {
|
||||
@@ -59,14 +61,13 @@ describe('createClientFromConfig', () => {
|
||||
expect(client).toBeInstanceOf(LlamaCppClient);
|
||||
});
|
||||
|
||||
it('creates OpenAI-compatible client for gemini provider (with warning)', () => {
|
||||
it('creates GeminiClient for gemini provider', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'gemini',
|
||||
model: 'gemini-2.5-pro',
|
||||
api_key: 'test-key',
|
||||
});
|
||||
// Gemini falls back to OpenAI-compatible client
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
expect(client).toBeInstanceOf(GeminiClient);
|
||||
});
|
||||
|
||||
it('throws for unknown provider', () => {
|
||||
@@ -75,4 +76,21 @@ describe('createClientFromConfig', () => {
|
||||
model: 'test',
|
||||
})).toThrow('Unknown model provider: unknown');
|
||||
});
|
||||
|
||||
it('creates OpenAIClient with OpenRouter baseURL for openrouter provider', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'openrouter',
|
||||
model: 'meta-llama/llama-3.1-70b',
|
||||
api_key: 'test-key',
|
||||
});
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
});
|
||||
|
||||
it('creates BedrockClient for bedrock provider', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'bedrock',
|
||||
model: 'anthropic.claude-3-sonnet',
|
||||
});
|
||||
expect(client).toBeInstanceOf(BedrockClient);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { BedrockClient } from './bedrock.js';
|
||||
import type { ChatStreamEvent } from './types.js';
|
||||
|
||||
const mockSend = vi.fn().mockResolvedValue({
|
||||
output: {
|
||||
message: {
|
||||
content: [{ text: 'Hello from Bedrock!' }],
|
||||
},
|
||||
},
|
||||
stopReason: 'end_turn',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
});
|
||||
|
||||
vi.mock('@aws-sdk/client-bedrock-runtime', () => ({
|
||||
BedrockRuntimeClient: vi.fn().mockImplementation(() => ({
|
||||
send: mockSend,
|
||||
})),
|
||||
ConverseCommand: vi.fn().mockImplementation((params) => params),
|
||||
ConverseStreamCommand: vi.fn().mockImplementation((params) => params),
|
||||
}));
|
||||
|
||||
describe('BedrockClient', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockSend.mockResolvedValue({
|
||||
output: {
|
||||
message: {
|
||||
content: [{ text: 'Hello from Bedrock!' }],
|
||||
},
|
||||
},
|
||||
stopReason: 'end_turn',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
});
|
||||
});
|
||||
|
||||
it('sends messages and returns response', async () => {
|
||||
const client = new BedrockClient({
|
||||
model: 'anthropic.claude-3-sonnet',
|
||||
region: 'us-east-1',
|
||||
});
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
});
|
||||
|
||||
expect(response.content).toBe('Hello from Bedrock!');
|
||||
expect(response.stopReason).toBe('end_turn');
|
||||
expect(response.usage.inputTokens).toBe(10);
|
||||
expect(response.usage.outputTokens).toBe(5);
|
||||
});
|
||||
|
||||
it('parses tool use response', async () => {
|
||||
mockSend.mockResolvedValueOnce({
|
||||
output: {
|
||||
message: {
|
||||
content: [{
|
||||
toolUse: {
|
||||
toolUseId: 'tool_01',
|
||||
name: 'shell.exec',
|
||||
input: { command: 'ls' },
|
||||
},
|
||||
}],
|
||||
},
|
||||
},
|
||||
stopReason: 'tool_use',
|
||||
usage: { inputTokens: 20, outputTokens: 15 },
|
||||
});
|
||||
|
||||
const client = new BedrockClient({
|
||||
model: 'anthropic.claude-3-sonnet',
|
||||
region: 'us-east-1',
|
||||
});
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'list files' }],
|
||||
tools: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell command',
|
||||
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
|
||||
}],
|
||||
});
|
||||
|
||||
expect(response.stopReason).toBe('tool_use');
|
||||
expect(response.toolCalls).toHaveLength(1);
|
||||
expect(response.toolCalls![0].name).toBe('shell.exec');
|
||||
expect(response.toolCalls![0].args).toEqual({ command: 'ls' });
|
||||
});
|
||||
|
||||
it('uses default region when none provided', async () => {
|
||||
const client = new BedrockClient({
|
||||
model: 'anthropic.claude-3-sonnet',
|
||||
});
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
});
|
||||
|
||||
expect(response.content).toBe('Hello from Bedrock!');
|
||||
});
|
||||
|
||||
it('passes system prompt to API', async () => {
|
||||
const client = new BedrockClient({
|
||||
model: 'anthropic.claude-3-sonnet',
|
||||
region: 'us-east-1',
|
||||
});
|
||||
|
||||
await client.chat({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
system: 'You are a helpful assistant.',
|
||||
});
|
||||
|
||||
expect(mockSend).toHaveBeenCalledTimes(1);
|
||||
// ConverseCommand is called with params that include system
|
||||
const { ConverseCommand } = await import('@aws-sdk/client-bedrock-runtime');
|
||||
expect(ConverseCommand).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
system: [{ text: 'You are a helpful assistant.' }],
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('BedrockClient streaming', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('streams content events', async () => {
|
||||
mockSend.mockResolvedValueOnce({
|
||||
stream: (async function* () {
|
||||
yield { contentBlockDelta: { delta: { text: 'Hello ' } } };
|
||||
yield { contentBlockDelta: { delta: { text: 'from Bedrock!' } } };
|
||||
yield { metadata: { usage: { inputTokens: 10, outputTokens: 5 } } };
|
||||
})(),
|
||||
});
|
||||
|
||||
const client = new BedrockClient({
|
||||
model: 'anthropic.claude-3-sonnet',
|
||||
region: 'us-east-1',
|
||||
});
|
||||
|
||||
const chunks: string[] = [];
|
||||
let finalUsage: { inputTokens: number; outputTokens: number } | undefined;
|
||||
|
||||
for await (const event of client.chatStream({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
})) {
|
||||
if (event.type === 'content' && event.content) {
|
||||
chunks.push(event.content);
|
||||
}
|
||||
if (event.type === 'done' && event.usage) {
|
||||
finalUsage = event.usage;
|
||||
}
|
||||
}
|
||||
|
||||
expect(chunks.join('')).toBe('Hello from Bedrock!');
|
||||
expect(finalUsage).toEqual({ inputTokens: 10, outputTokens: 5 });
|
||||
});
|
||||
|
||||
it('yields error event on failure', async () => {
|
||||
mockSend.mockRejectedValueOnce(new Error('Service unavailable'));
|
||||
|
||||
const client = new BedrockClient({
|
||||
model: 'anthropic.claude-3-sonnet',
|
||||
region: 'us-east-1',
|
||||
});
|
||||
|
||||
const events: ChatStreamEvent[] = [];
|
||||
for await (const event of client.chatStream({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events).toHaveLength(1);
|
||||
expect(events[0].type).toBe('error');
|
||||
expect(events[0].error?.message).toBe('Service unavailable');
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,179 @@
|
||||
import {
|
||||
BedrockRuntimeClient,
|
||||
ConverseCommand,
|
||||
ConverseStreamCommand,
|
||||
} from '@aws-sdk/client-bedrock-runtime';
|
||||
import type {
|
||||
Message as BedrockMessage,
|
||||
ContentBlock,
|
||||
ToolConfiguration,
|
||||
Tool as BedrockTool,
|
||||
ConverseCommandInput,
|
||||
ConverseStreamCommandInput,
|
||||
} from '@aws-sdk/client-bedrock-runtime';
|
||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall, ToolDefinition } from './types.js';
|
||||
|
||||
export interface BedrockClientConfig {
|
||||
model: string;
|
||||
region?: string;
|
||||
maxTokens?: number;
|
||||
/** AWS access key ID (if not using default credential chain). */
|
||||
accessKeyId?: string;
|
||||
/** AWS secret access key (if not using default credential chain). */
|
||||
secretAccessKey?: string;
|
||||
}
|
||||
|
||||
export class BedrockClient implements ModelClient {
|
||||
private client: BedrockRuntimeClient;
|
||||
private model: string;
|
||||
private defaultMaxTokens: number;
|
||||
|
||||
constructor(config: BedrockClientConfig) {
|
||||
const clientConfig: Record<string, unknown> = {
|
||||
region: config.region ?? process.env.AWS_REGION ?? 'us-east-1',
|
||||
};
|
||||
|
||||
if (config.accessKeyId && config.secretAccessKey) {
|
||||
clientConfig.credentials = {
|
||||
accessKeyId: config.accessKeyId,
|
||||
secretAccessKey: config.secretAccessKey,
|
||||
};
|
||||
}
|
||||
|
||||
this.client = new BedrockRuntimeClient(clientConfig);
|
||||
this.model = config.model;
|
||||
this.defaultMaxTokens = config.maxTokens ?? 4096;
|
||||
}
|
||||
|
||||
async chat(request: ChatRequest): Promise<ChatResponse> {
|
||||
const messages = convertMessages(request.messages);
|
||||
|
||||
const params: ConverseCommandInput = {
|
||||
modelId: this.model,
|
||||
messages,
|
||||
inferenceConfig: {
|
||||
maxTokens: request.maxTokens ?? this.defaultMaxTokens,
|
||||
},
|
||||
};
|
||||
|
||||
if (request.system) {
|
||||
params.system = [{ text: request.system }];
|
||||
}
|
||||
|
||||
if (request.tools && request.tools.length > 0) {
|
||||
params.toolConfig = convertTools(request.tools);
|
||||
}
|
||||
|
||||
const command = new ConverseCommand(params);
|
||||
const response = await this.client.send(command);
|
||||
|
||||
// Extract text and tool_use content from the response
|
||||
const outputContent = response.output?.message?.content ?? [];
|
||||
const textParts: string[] = [];
|
||||
const toolCalls: ModelToolCall[] = [];
|
||||
|
||||
for (const block of outputContent) {
|
||||
if ('text' in block && block.text !== undefined) {
|
||||
textParts.push(block.text);
|
||||
}
|
||||
if ('toolUse' in block && block.toolUse !== undefined) {
|
||||
toolCalls.push({
|
||||
id: block.toolUse.toolUseId ?? `bedrock_${Date.now()}`,
|
||||
name: block.toolUse.name ?? '',
|
||||
args: block.toolUse.input as unknown,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const content = textParts.join('');
|
||||
|
||||
// Map stop reason
|
||||
let stopReason: string = 'end_turn';
|
||||
if (response.stopReason === 'max_tokens') stopReason = 'max_tokens';
|
||||
else if (response.stopReason === 'tool_use') stopReason = 'tool_use';
|
||||
else if (response.stopReason === 'end_turn') stopReason = 'end_turn';
|
||||
else if (response.stopReason) stopReason = response.stopReason;
|
||||
|
||||
return {
|
||||
content,
|
||||
stopReason,
|
||||
usage: {
|
||||
inputTokens: response.usage?.inputTokens ?? 0,
|
||||
outputTokens: response.usage?.outputTokens ?? 0,
|
||||
},
|
||||
...(toolCalls.length > 0 ? { toolCalls } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
|
||||
const messages = convertMessages(request.messages);
|
||||
|
||||
const params: ConverseStreamCommandInput = {
|
||||
modelId: this.model,
|
||||
messages,
|
||||
inferenceConfig: {
|
||||
maxTokens: request.maxTokens ?? this.defaultMaxTokens,
|
||||
},
|
||||
};
|
||||
|
||||
if (request.system) {
|
||||
params.system = [{ text: request.system }];
|
||||
}
|
||||
|
||||
if (request.tools && request.tools.length > 0) {
|
||||
params.toolConfig = convertTools(request.tools);
|
||||
}
|
||||
|
||||
try {
|
||||
const command = new ConverseStreamCommand(params);
|
||||
const response = await this.client.send(command);
|
||||
|
||||
let inputTokens = 0;
|
||||
let outputTokens = 0;
|
||||
|
||||
if (response.stream) {
|
||||
for await (const event of response.stream) {
|
||||
if (event.contentBlockDelta?.delta && 'text' in event.contentBlockDelta.delta && event.contentBlockDelta.delta.text) {
|
||||
yield { type: 'content', content: event.contentBlockDelta.delta.text };
|
||||
}
|
||||
|
||||
if (event.metadata?.usage) {
|
||||
inputTokens = event.metadata.usage.inputTokens ?? inputTokens;
|
||||
outputTokens = event.metadata.usage.outputTokens ?? outputTokens;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
yield {
|
||||
type: 'done',
|
||||
usage: { inputTokens, outputTokens },
|
||||
};
|
||||
} catch (error) {
|
||||
yield {
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function convertMessages(messages: { role: string; content: string }[]): BedrockMessage[] {
|
||||
return messages.map(m => ({
|
||||
role: m.role === 'assistant' ? 'assistant' as const : 'user' as const,
|
||||
content: [{ text: m.content }] as ContentBlock[],
|
||||
}));
|
||||
}
|
||||
|
||||
function convertTools(tools: ToolDefinition[]): ToolConfiguration {
|
||||
return {
|
||||
tools: tools.map(t => ({
|
||||
toolSpec: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
inputSchema: {
|
||||
json: t.input_schema as Record<string, unknown>,
|
||||
},
|
||||
},
|
||||
} as BedrockTool)),
|
||||
};
|
||||
}
|
||||
@@ -7,8 +7,20 @@ export const MODEL_COSTS_PER_MILLION: Record<string, { input: number; output: nu
|
||||
// OpenAI
|
||||
'gpt-4o': { input: 2.50, output: 10 },
|
||||
'gpt-4o-mini': { input: 0.15, output: 0.60 },
|
||||
// Gemini
|
||||
'gemini-2.0-flash': { input: 0.10, output: 0.40 },
|
||||
'gemini-2.0-flash-lite': { input: 0.025, output: 0.10 },
|
||||
'gemini-2.5-pro': { input: 1.25, output: 10 },
|
||||
'gemini-2.5-flash': { input: 0.15, output: 0.60 },
|
||||
'gemini-1.5-pro': { input: 1.25, output: 5 },
|
||||
'gemini-1.5-flash': { input: 0.075, output: 0.30 },
|
||||
// Local / unknown models
|
||||
'default': { input: 0, output: 0 },
|
||||
// Bedrock (Meta Llama)
|
||||
'meta.llama3-1-70b-instruct-v1:0': { input: 0.72, output: 0.72 },
|
||||
'meta.llama3-1-8b-instruct-v1:0': { input: 0.22, output: 0.22 },
|
||||
// Bedrock (Amazon Titan)
|
||||
'amazon.titan-text-express-v1': { input: 0.20, output: 0.60 },
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,332 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { GeminiClient } from './gemini.js';
|
||||
|
||||
// Shared mock functions
|
||||
const mockGenerateContent = vi.fn();
|
||||
const mockGenerateContentStream = vi.fn();
|
||||
|
||||
const mockGetGenerativeModel = vi.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
generateContentStream: mockGenerateContentStream,
|
||||
});
|
||||
|
||||
vi.mock('@google/generative-ai', () => ({
|
||||
GoogleGenerativeAI: vi.fn().mockImplementation(() => ({
|
||||
getGenerativeModel: mockGetGenerativeModel,
|
||||
})),
|
||||
}));
|
||||
|
||||
function makeResponse(parts: unknown[], finishReason = 'STOP', usage = { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 }) {
|
||||
return {
|
||||
response: {
|
||||
candidates: [{
|
||||
index: 0,
|
||||
content: { role: 'model', parts },
|
||||
finishReason,
|
||||
}],
|
||||
usageMetadata: usage,
|
||||
text: () => {
|
||||
const textParts = parts.filter((p: unknown) => typeof p === 'object' && p !== null && 'text' in p);
|
||||
if (textParts.length === 0) throw new Error('No text parts');
|
||||
return textParts.map((p: unknown) => (p as { text: string }).text).join('');
|
||||
},
|
||||
functionCalls: () => {
|
||||
const fcParts = parts.filter((p: unknown) => typeof p === 'object' && p !== null && 'functionCall' in p);
|
||||
if (fcParts.length === 0) return undefined;
|
||||
return fcParts.map((p: unknown) => (p as { functionCall: { name: string; args: object } }).functionCall);
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe('GeminiClient', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockGenerateContent.mockResolvedValue(
|
||||
makeResponse([{ text: 'Hello from Gemini!' }]),
|
||||
);
|
||||
mockGetGenerativeModel.mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
generateContentStream: mockGenerateContentStream,
|
||||
});
|
||||
});
|
||||
|
||||
it('sends messages and returns response', async () => {
|
||||
const client = new GeminiClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gemini-2.0-flash',
|
||||
});
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
});
|
||||
|
||||
expect(response.content).toBe('Hello from Gemini!');
|
||||
expect(response.stopReason).toBe('end_turn');
|
||||
expect(response.usage.inputTokens).toBe(10);
|
||||
expect(response.usage.outputTokens).toBe(5);
|
||||
});
|
||||
|
||||
it('passes system instruction to model', async () => {
|
||||
const client = new GeminiClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gemini-2.0-flash',
|
||||
});
|
||||
|
||||
await client.chat({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
system: 'You are a helpful assistant',
|
||||
});
|
||||
|
||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
systemInstruction: 'You are a helpful assistant',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('converts assistant role to model role', async () => {
|
||||
const client = new GeminiClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gemini-2.0-flash',
|
||||
});
|
||||
|
||||
await client.chat({
|
||||
messages: [
|
||||
{ role: 'user', content: 'Hello' },
|
||||
{ role: 'assistant', content: 'Hi there!' },
|
||||
{ role: 'user', content: 'How are you?' },
|
||||
],
|
||||
});
|
||||
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith({
|
||||
contents: [
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
{ role: 'model', parts: [{ text: 'Hi there!' }] },
|
||||
{ role: 'user', parts: [{ text: 'How are you?' }] },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('maps MAX_TOKENS finish reason', async () => {
|
||||
mockGenerateContent.mockResolvedValueOnce(
|
||||
makeResponse([{ text: 'Truncated...' }], 'MAX_TOKENS'),
|
||||
);
|
||||
|
||||
const client = new GeminiClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gemini-2.0-flash',
|
||||
});
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'Write a long story' }],
|
||||
});
|
||||
|
||||
expect(response.stopReason).toBe('max_tokens');
|
||||
});
|
||||
|
||||
it('uses environment variable for API key when not provided', () => {
|
||||
const originalEnv = process.env.GOOGLE_API_KEY;
|
||||
process.env.GOOGLE_API_KEY = 'env-key';
|
||||
|
||||
try {
|
||||
// Just construct — we verify it doesn't throw
|
||||
const _client = new GeminiClient({ model: 'gemini-2.0-flash' });
|
||||
expect(_client).toBeDefined();
|
||||
} finally {
|
||||
if (originalEnv !== undefined) {
|
||||
process.env.GOOGLE_API_KEY = originalEnv;
|
||||
} else {
|
||||
delete process.env.GOOGLE_API_KEY;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('GeminiClient streaming', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockGenerateContentStream.mockResolvedValue({
|
||||
stream: (async function* () {
|
||||
yield {
|
||||
text: () => 'Hello ',
|
||||
functionCalls: () => undefined,
|
||||
candidates: [{ content: { parts: [{ text: 'Hello ' }] } }],
|
||||
usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 2, totalTokenCount: 12 },
|
||||
};
|
||||
yield {
|
||||
text: () => 'from Gemini!',
|
||||
functionCalls: () => undefined,
|
||||
candidates: [{ content: { parts: [{ text: 'from Gemini!' }] } }],
|
||||
usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 },
|
||||
};
|
||||
})(),
|
||||
response: Promise.resolve({
|
||||
usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 },
|
||||
}),
|
||||
});
|
||||
mockGetGenerativeModel.mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
generateContentStream: mockGenerateContentStream,
|
||||
});
|
||||
});
|
||||
|
||||
it('streams messages chunk by chunk', async () => {
|
||||
const client = new GeminiClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gemini-2.0-flash',
|
||||
});
|
||||
|
||||
const chunks: string[] = [];
|
||||
let finalUsage: { inputTokens: number; outputTokens: number } | undefined;
|
||||
|
||||
for await (const event of client.chatStream({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
})) {
|
||||
if (event.type === 'content' && event.content) {
|
||||
chunks.push(event.content);
|
||||
}
|
||||
if (event.type === 'done' && event.usage) {
|
||||
finalUsage = event.usage;
|
||||
}
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0);
|
||||
expect(chunks.join('')).toBe('Hello from Gemini!');
|
||||
expect(finalUsage).toEqual({ inputTokens: 10, outputTokens: 5 });
|
||||
});
|
||||
|
||||
it('yields error event on stream failure', async () => {
|
||||
mockGenerateContentStream.mockRejectedValueOnce(new Error('Network error'));
|
||||
|
||||
const client = new GeminiClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gemini-2.0-flash',
|
||||
});
|
||||
|
||||
const events: { type: string; error?: Error }[] = [];
|
||||
for await (const event of client.chatStream({
|
||||
messages: [{ role: 'user', content: 'Hello' }],
|
||||
})) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events).toHaveLength(1);
|
||||
expect(events[0].type).toBe('error');
|
||||
expect(events[0].error?.message).toBe('Network error');
|
||||
});
|
||||
});
|
||||
|
||||
describe('GeminiClient tool use', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockGetGenerativeModel.mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
generateContentStream: mockGenerateContentStream,
|
||||
});
|
||||
});
|
||||
|
||||
it('passes tools and parses function call response', async () => {
|
||||
mockGenerateContent.mockResolvedValueOnce(
|
||||
makeResponse([{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }]),
|
||||
);
|
||||
|
||||
const client = new GeminiClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gemini-2.0-flash',
|
||||
});
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'list files' }],
|
||||
tools: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell command',
|
||||
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
|
||||
}],
|
||||
});
|
||||
|
||||
expect(response.stopReason).toBe('tool_use');
|
||||
expect(response.toolCalls).toHaveLength(1);
|
||||
expect(response.toolCalls![0].name).toBe('shell.exec');
|
||||
expect(response.toolCalls![0].args).toEqual({ command: 'ls' });
|
||||
|
||||
// Verify tools were passed to getGenerativeModel
|
||||
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tools: [{
|
||||
functionDeclarations: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell command',
|
||||
parameters: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
|
||||
}],
|
||||
}],
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('handles mixed text and function call response', async () => {
|
||||
mockGenerateContent.mockResolvedValueOnce(
|
||||
makeResponse([
|
||||
{ text: 'Let me run that for you.' },
|
||||
{ functionCall: { name: 'shell.exec', args: { command: 'ls -la' } } },
|
||||
]),
|
||||
);
|
||||
|
||||
const client = new GeminiClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gemini-2.0-flash',
|
||||
});
|
||||
|
||||
const response = await client.chat({
|
||||
messages: [{ role: 'user', content: 'list files in detail' }],
|
||||
tools: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell command',
|
||||
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
|
||||
}],
|
||||
});
|
||||
|
||||
expect(response.content).toBe('Let me run that for you.');
|
||||
expect(response.stopReason).toBe('tool_use');
|
||||
expect(response.toolCalls).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('streams function calls', async () => {
|
||||
mockGenerateContentStream.mockResolvedValueOnce({
|
||||
stream: (async function* () {
|
||||
yield {
|
||||
text: () => '',
|
||||
functionCalls: () => [{ name: 'shell.exec', args: { command: 'ls' } }],
|
||||
candidates: [{ content: { parts: [{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }] } }],
|
||||
usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 10, totalTokenCount: 25 },
|
||||
};
|
||||
})(),
|
||||
response: Promise.resolve({
|
||||
usageMetadata: { promptTokenCount: 15, candidatesTokenCount: 10, totalTokenCount: 25 },
|
||||
}),
|
||||
});
|
||||
|
||||
const client = new GeminiClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gemini-2.0-flash',
|
||||
});
|
||||
|
||||
const toolCalls: { name: string; args: unknown }[] = [];
|
||||
for await (const event of client.chatStream({
|
||||
messages: [{ role: 'user', content: 'list files' }],
|
||||
tools: [{
|
||||
name: 'shell.exec',
|
||||
description: 'Run shell command',
|
||||
input_schema: { type: 'object', properties: { command: { type: 'string' } }, required: ['command'] },
|
||||
}],
|
||||
})) {
|
||||
if (event.type === 'tool_use' && event.toolCall) {
|
||||
toolCalls.push({ name: event.toolCall.name, args: event.toolCall.args });
|
||||
}
|
||||
}
|
||||
|
||||
expect(toolCalls).toHaveLength(1);
|
||||
expect(toolCalls[0].name).toBe('shell.exec');
|
||||
expect(toolCalls[0].args).toEqual({ command: 'ls' });
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,175 @@
|
||||
import { GoogleGenerativeAI } from '@google/generative-ai';
|
||||
import type { GenerativeModel, Content, FunctionDeclaration, FunctionDeclarationSchema } from '@google/generative-ai';
|
||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, ModelToolCall, ToolDefinition } from './types.js';
|
||||
|
||||
export interface GeminiClientConfig {
|
||||
apiKey?: string;
|
||||
model: string;
|
||||
maxTokens?: number;
|
||||
}
|
||||
|
||||
export class GeminiClient implements ModelClient {
|
||||
private genAI: GoogleGenerativeAI;
|
||||
private model: string;
|
||||
private defaultMaxTokens: number;
|
||||
|
||||
constructor(config: GeminiClientConfig) {
|
||||
const apiKey = config.apiKey ?? process.env.GOOGLE_API_KEY ?? '';
|
||||
this.genAI = new GoogleGenerativeAI(apiKey);
|
||||
this.model = config.model;
|
||||
this.defaultMaxTokens = config.maxTokens ?? 8192;
|
||||
}
|
||||
|
||||
private getModel(request: ChatRequest): GenerativeModel {
|
||||
const tools = request.tools && request.tools.length > 0
|
||||
? [{ functionDeclarations: request.tools.map(t => convertToolDefinition(t)) }]
|
||||
: undefined;
|
||||
|
||||
return this.genAI.getGenerativeModel({
|
||||
model: this.model,
|
||||
systemInstruction: request.system || undefined,
|
||||
tools,
|
||||
generationConfig: {
|
||||
maxOutputTokens: request.maxTokens ?? this.defaultMaxTokens,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async chat(request: ChatRequest): Promise<ChatResponse> {
|
||||
const model = this.getModel(request);
|
||||
const contents = convertMessages(request.messages);
|
||||
|
||||
const result = await model.generateContent({ contents });
|
||||
const response = result.response;
|
||||
const candidate = response.candidates?.[0];
|
||||
|
||||
// Extract text via the helper method
|
||||
let content = '';
|
||||
try {
|
||||
content = response.text();
|
||||
} catch {
|
||||
// text() throws if blocked — fall back to manual extraction
|
||||
const textParts = candidate?.content?.parts?.filter(p => 'text' in p && p.text !== undefined) ?? [];
|
||||
content = textParts.map(p => (p as { text: string }).text).join('');
|
||||
}
|
||||
|
||||
// Extract function calls via the helper method
|
||||
const functionCalls = response.functionCalls();
|
||||
const toolCalls: ModelToolCall[] = functionCalls
|
||||
? functionCalls.map((fc, i) => ({
|
||||
id: `gemini_${Date.now()}_${i}`,
|
||||
name: fc.name,
|
||||
args: fc.args,
|
||||
}))
|
||||
: [];
|
||||
|
||||
// Map finish reason
|
||||
const finishReason = candidate?.finishReason;
|
||||
let stopReason: string = 'end_turn';
|
||||
if (toolCalls.length > 0) {
|
||||
stopReason = 'tool_use';
|
||||
} else if (finishReason === 'MAX_TOKENS') {
|
||||
stopReason = 'max_tokens';
|
||||
} else if (finishReason === 'STOP') {
|
||||
stopReason = 'end_turn';
|
||||
} else if (finishReason) {
|
||||
stopReason = finishReason.toLowerCase();
|
||||
}
|
||||
|
||||
// Extract usage
|
||||
const usageMetadata = response.usageMetadata;
|
||||
const usage = {
|
||||
inputTokens: usageMetadata?.promptTokenCount ?? 0,
|
||||
outputTokens: usageMetadata?.candidatesTokenCount ?? 0,
|
||||
};
|
||||
|
||||
return {
|
||||
content,
|
||||
stopReason,
|
||||
usage,
|
||||
...(toolCalls.length > 0 ? { toolCalls } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
|
||||
const model = this.getModel(request);
|
||||
const contents = convertMessages(request.messages);
|
||||
|
||||
try {
|
||||
const result = await model.generateContentStream({ contents });
|
||||
|
||||
let totalInputTokens = 0;
|
||||
let totalOutputTokens = 0;
|
||||
|
||||
for await (const chunk of result.stream) {
|
||||
// Use the text() helper to extract text content from this chunk
|
||||
try {
|
||||
const text = chunk.text();
|
||||
if (text) {
|
||||
yield { type: 'content', content: text };
|
||||
}
|
||||
} catch {
|
||||
// text() throws if blocked — skip
|
||||
}
|
||||
|
||||
// Check for function calls in streaming chunks
|
||||
const calls = chunk.functionCalls();
|
||||
if (calls) {
|
||||
for (const fc of calls) {
|
||||
yield {
|
||||
type: 'tool_use',
|
||||
toolCall: {
|
||||
id: `gemini_${Date.now()}`,
|
||||
name: fc.name,
|
||||
args: fc.args,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Track usage from chunks
|
||||
if (chunk.usageMetadata) {
|
||||
totalInputTokens = chunk.usageMetadata.promptTokenCount ?? totalInputTokens;
|
||||
totalOutputTokens = chunk.usageMetadata.candidatesTokenCount ?? totalOutputTokens;
|
||||
}
|
||||
}
|
||||
|
||||
// Final aggregated response for usage
|
||||
const aggregated = await result.response;
|
||||
const usageMetadata = aggregated.usageMetadata;
|
||||
|
||||
yield {
|
||||
type: 'done',
|
||||
usage: {
|
||||
inputTokens: usageMetadata?.promptTokenCount ?? totalInputTokens,
|
||||
outputTokens: usageMetadata?.candidatesTokenCount ?? totalOutputTokens,
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
yield {
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Convert Flynn's Message[] to Gemini Content[] format */
|
||||
function convertMessages(messages: { role: string; content: string }[]): Content[] {
|
||||
return messages.map(m => ({
|
||||
role: m.role === 'assistant' ? 'model' : 'user',
|
||||
parts: [{ text: m.content }],
|
||||
}));
|
||||
}
|
||||
|
||||
/** Convert Flynn's ToolDefinition to Gemini FunctionDeclaration format */
|
||||
function convertToolDefinition(tool: ToolDefinition): FunctionDeclaration {
|
||||
// The Gemini SDK's FunctionDeclarationSchema expects `type: SchemaType` (enum)
|
||||
// but the actual wire format accepts string values. We pass the schema through
|
||||
// as-is since the SDK serialises it to JSON for the API request.
|
||||
return {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.input_schema as unknown as FunctionDeclarationSchema,
|
||||
};
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
export { AnthropicClient, type AnthropicClientConfig } from './anthropic.js';
|
||||
export { OpenAIClient, type OpenAIClientConfig } from './openai.js';
|
||||
export { GeminiClient, type GeminiClientConfig } from './gemini.js';
|
||||
export { BedrockClient, type BedrockClientConfig } from './bedrock.js';
|
||||
export { OllamaClient, type OllamaClientConfig } from './local/index.js';
|
||||
export { LlamaCppClient, type LlamaCppClientConfig } from './local/index.js';
|
||||
export { ModelRouter, type ModelRouterConfig, type ModelTier } from './router.js';
|
||||
|
||||
Reference in New Issue
Block a user