import { describe, it, expect, vi, beforeEach } from 'vitest';
import { GeminiClient } from './gemini.js';

/*
 * Unit tests for GeminiClient. The `@google/generative-ai` SDK is replaced by a
 * mock whose `getGenerativeModel` returns an object exposing the shared
 * `generateContent` / `generateContentStream` spies below, so each test can
 * script the SDK's responses and inspect what the client passed to it.
 */

// Shared mock functions
const mockGenerateContent = vi.fn();
const mockGenerateContentStream = vi.fn();
const mockGetGenerativeModel = vi.fn().mockReturnValue({
  generateContent: mockGenerateContent,
  generateContentStream: mockGenerateContentStream,
});

// Vitest hoists vi.mock above the declarations above, so the factory must not
// read them eagerly. It only touches `mockGetGenerativeModel` inside the mock
// implementation, i.e. when the mocked `GoogleGenerativeAI` is actually invoked.
vi.mock('@google/generative-ai', () => ({
  GoogleGenerativeAI: vi.fn().mockImplementation(() => ({
    getGenerativeModel: mockGetGenerativeModel,
  })),
}));

/** Token counts in the shape the mocked SDK reports them (`usageMetadata`). */
interface UsageMetadata {
  promptTokenCount: number;
  candidatesTokenCount: number;
  totalTokenCount: number;
}

/**
 * Builds a fake non-streaming `generateContent` result with a single candidate.
 *
 * `text()` joins the text parts and throws when there are none;
 * `functionCalls()` returns the function-call payloads, or `undefined` when
 * there are none.
 *
 * @param parts - Content parts of the candidate (`{ text }` / `{ functionCall }`).
 * @param finishReason - Finish reason reported for the candidate.
 * @param usage - Token usage reported for the call.
 */
function makeResponse(
  parts: unknown[],
  finishReason = 'STOP',
  usage: UsageMetadata = { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 },
) {
  const hasKey =
    (key: string) =>
    (p: unknown): boolean =>
      typeof p === 'object' && p !== null && key in p;

  return {
    response: {
      candidates: [
        {
          index: 0,
          content: { role: 'model', parts },
          finishReason,
        },
      ],
      usageMetadata: usage,
      text: () => {
        const textParts = parts.filter(hasKey('text'));
        if (textParts.length === 0) {
          throw new Error('No text parts');
        }
        return textParts.map((p) => (p as { text: string }).text).join('');
      },
      functionCalls: () => {
        const fcParts = parts.filter(hasKey('functionCall'));
        if (fcParts.length === 0) {
          return undefined;
        }
        return fcParts.map(
          (p) => (p as { functionCall: { name: string; args: object } }).functionCall,
        );
      },
    },
  };
}

/**
 * Builds the `{ stream, response }` object that `generateContentStream` resolves to.
 * A new async generator is created on every call, so each consumer receives an
 * unconsumed stream even if the mock is invoked more than once in a test.
 */
function makeStreamResult(chunks: readonly unknown[], finalUsage: UsageMetadata) {
  return {
    stream: (async function* () {
      yield* chunks;
    })(),
    response: Promise.resolve({ usageMetadata: finalUsage }),
  };
}

/** Builds one streamed chunk that carries only text. */
function makeTextChunk(text: string, usage: UsageMetadata) {
  return {
    text: () => text,
    functionCalls: () => undefined,
    candidates: [{ content: { parts: [{ text }] } }],
    usageMetadata: usage,
  };
}

/** Creates a client with the fixed test key and model used throughout this file. */
function makeClient(): GeminiClient {
  return new GeminiClient({
    apiKey: 'test-key',
    model: 'gemini-2.0-flash',
  });
}

/**
 * Clears recorded mock calls and re-installs the default model object returned
 * by `getGenerativeModel`. The re-install keeps tests working even if mocks are
 * configured to reset implementations between tests.
 */
function resetMocks(): void {
  vi.clearAllMocks();
  mockGetGenerativeModel.mockReturnValue({
    generateContent: mockGenerateContent,
    generateContentStream: mockGenerateContentStream,
  });
}

describe('GeminiClient', () => {
  beforeEach(() => {
    resetMocks();
    mockGenerateContent.mockResolvedValue(makeResponse([{ text: 'Hello from Gemini!' }]));
  });

  it('sends messages and returns response', async () => {
    const client = makeClient();

    const response = await client.chat({
      messages: [{ role: 'user', content: 'Hello' }],
    });

    expect(response.content).toBe('Hello from Gemini!');
    expect(response.stopReason).toBe('end_turn');
    expect(response.usage.inputTokens).toBe(10);
    expect(response.usage.outputTokens).toBe(5);
  });

  it('passes system instruction to model', async () => {
    const client = makeClient();

    await client.chat({
      messages: [{ role: 'user', content: 'Hello' }],
      system: 'You are a helpful assistant',
    });

    expect(mockGetGenerativeModel).toHaveBeenCalledWith(
      expect.objectContaining({
        systemInstruction: 'You are a helpful assistant',
      }),
    );
  });

  it('converts assistant role to model role', async () => {
    const client = makeClient();

    await client.chat({
      messages: [
        { role: 'user', content: 'Hello' },
        { role: 'assistant', content: 'Hi there!' },
        { role: 'user', content: 'How are you?' },
      ],
    });

    expect(mockGenerateContent).toHaveBeenCalledWith({
      contents: [
        { role: 'user', parts: [{ text: 'Hello' }] },
        { role: 'model', parts: [{ text: 'Hi there!' }] },
        { role: 'user', parts: [{ text: 'How are you?' }] },
      ],
    });
  });

  it('maps MAX_TOKENS finish reason', async () => {
    mockGenerateContent.mockResolvedValueOnce(
      makeResponse([{ text: 'Truncated...' }], 'MAX_TOKENS'),
    );
    const client = makeClient();

    const response = await client.chat({
      messages: [{ role: 'user', content: 'Write a long story' }],
    });

    expect(response.stopReason).toBe('max_tokens');
  });

  it('uses environment variable for API key when not provided', () => {
    const originalEnv = process.env.GOOGLE_API_KEY;
    process.env.GOOGLE_API_KEY = 'env-key';
    try {
      // NOTE(review): this only checks that construction succeeds without an
      // explicit apiKey; it does not verify that 'env-key' reaches the SDK.
      const client = new GeminiClient({ model: 'gemini-2.0-flash' });
      expect(client).toBeDefined();
    } finally {
      // Restore the exact prior state, including "unset".
      if (originalEnv !== undefined) {
        process.env.GOOGLE_API_KEY = originalEnv;
      } else {
        delete process.env.GOOGLE_API_KEY;
      }
    }
  });
});

describe('GeminiClient streaming', () => {
  beforeEach(() => {
    resetMocks();
    // mockImplementation (not mockResolvedValue) so every call gets its own
    // fresh generator instead of sharing one that can only be consumed once.
    mockGenerateContentStream.mockImplementation(async () =>
      makeStreamResult(
        [
          makeTextChunk('Hello ', {
            promptTokenCount: 10,
            candidatesTokenCount: 2,
            totalTokenCount: 12,
          }),
          makeTextChunk('from Gemini!', {
            promptTokenCount: 10,
            candidatesTokenCount: 5,
            totalTokenCount: 15,
          }),
        ],
        { promptTokenCount: 10, candidatesTokenCount: 5, totalTokenCount: 15 },
      ),
    );
  });

  it('streams messages chunk by chunk', async () => {
    const client = makeClient();

    const chunks: string[] = [];
    let finalUsage: { inputTokens: number; outputTokens: number } | undefined;

    for await (const event of client.chatStream({
      messages: [{ role: 'user', content: 'Hello' }],
    })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
      if (event.type === 'done' && event.usage) {
        finalUsage = event.usage;
      }
    }

    expect(chunks.length).toBeGreaterThan(0);
    expect(chunks.join('')).toBe('Hello from Gemini!');
    expect(finalUsage).toEqual({ inputTokens: 10, outputTokens: 5 });
  });

  it('yields error event on stream failure', async () => {
    mockGenerateContentStream.mockRejectedValueOnce(new Error('Network error'));
    const client = makeClient();

    const events: { type: string; error?: Error }[] = [];
    for await (const event of client.chatStream({
      messages: [{ role: 'user', content: 'Hello' }],
    })) {
      events.push(event);
    }

    expect(events).toHaveLength(1);
    const [event] = events;
    expect(event?.type).toBe('error');
    expect(event?.error?.message).toBe('Network error');
  });
});

describe('GeminiClient tool use', () => {
  beforeEach(() => {
    resetMocks();
  });

  it('passes tools and parses function call response', async () => {
    mockGenerateContent.mockResolvedValueOnce(
      makeResponse([{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }]),
    );
    const client = makeClient();

    const response = await client.chat({
      messages: [{ role: 'user', content: 'list files' }],
      tools: [
        {
          name: 'shell.exec',
          description: 'Run shell command',
          input_schema: {
            type: 'object',
            properties: { command: { type: 'string' } },
            required: ['command'],
          },
        },
      ],
    });

    expect(response.stopReason).toBe('tool_use');
    // Exactly one call, with the expected name and arguments.
    expect(response.toolCalls).toEqual([
      expect.objectContaining({ name: 'shell.exec', args: { command: 'ls' } }),
    ]);

    // Verify tools were passed to getGenerativeModel
    expect(mockGetGenerativeModel).toHaveBeenCalledWith(
      expect.objectContaining({
        tools: [
          {
            functionDeclarations: [
              {
                name: 'shell.exec',
                description: 'Run shell command',
                parameters: {
                  type: 'object',
                  properties: { command: { type: 'string' } },
                  required: ['command'],
                },
              },
            ],
          },
        ],
      }),
    );
  });

  it('handles mixed text and function call response', async () => {
    mockGenerateContent.mockResolvedValueOnce(
      makeResponse([
        { text: 'Let me run that for you.' },
        { functionCall: { name: 'shell.exec', args: { command: 'ls -la' } } },
      ]),
    );
    const client = makeClient();

    const response = await client.chat({
      messages: [{ role: 'user', content: 'list files in detail' }],
      tools: [
        {
          name: 'shell.exec',
          description: 'Run shell command',
          input_schema: {
            type: 'object',
            properties: { command: { type: 'string' } },
            required: ['command'],
          },
        },
      ],
    });

    expect(response.content).toBe('Let me run that for you.');
    expect(response.stopReason).toBe('tool_use');
    expect(response.toolCalls).toHaveLength(1);
  });

  it('streams function calls', async () => {
    const usage: UsageMetadata = {
      promptTokenCount: 15,
      candidatesTokenCount: 10,
      totalTokenCount: 25,
    };
    mockGenerateContentStream.mockResolvedValueOnce(
      makeStreamResult(
        [
          {
            text: () => '',
            functionCalls: () => [{ name: 'shell.exec', args: { command: 'ls' } }],
            candidates: [
              {
                content: {
                  parts: [{ functionCall: { name: 'shell.exec', args: { command: 'ls' } } }],
                },
              },
            ],
            usageMetadata: usage,
          },
        ],
        usage,
      ),
    );
    const client = makeClient();

    const toolCalls: { name: string; args: unknown }[] = [];
    for await (const event of client.chatStream({
      messages: [{ role: 'user', content: 'list files' }],
      tools: [
        {
          name: 'shell.exec',
          description: 'Run shell command',
          input_schema: {
            type: 'object',
            properties: { command: { type: 'string' } },
            required: ['command'],
          },
        },
      ],
    })) {
      if (event.type === 'tool_use' && event.toolCall) {
        toolCalls.push({ name: event.toolCall.name, args: event.toolCall.args });
      }
    }

    // Exactly one streamed call, with the expected name and arguments.
    expect(toolCalls).toEqual([{ name: 'shell.exec', args: { command: 'ls' } }]);
  });
});