import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ModelRouter } from './router.js';
import type { ModelClient, ChatResponse, ChatStreamEvent } from './types.js';

describe('ModelRouter', () => {
  /**
   * Builds a stub `ModelClient` whose `chat` mock resolves with a response
   * naming the client, or rejects with `"<name> failed"` when `shouldFail` is set.
   *
   * @param name - Label embedded in the response content / error message.
   * @param shouldFail - When true, every `chat` call rejects.
   */
  const createMockClient = (name: string, shouldFail = false): ModelClient => ({
    chat: vi.fn().mockImplementation(async () => {
      if (shouldFail) {
        throw new Error(`${name} failed`);
      }
      return {
        content: `Response from ${name}`,
        stopReason: 'end_turn',
        usage: { inputTokens: 10, outputTokens: 5 },
      } satisfies ChatResponse;
    }),
  });

  it('uses default client when available', async () => {
    const defaultClient = createMockClient('default');
    const router = new ModelRouter({
      default: defaultClient,
      fallbackChain: [],
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    expect(response.content).toBe('Response from default');
    expect(defaultClient.chat).toHaveBeenCalled();
  });

  it('falls back to next provider on failure', async () => {
    const failingClient = createMockClient('primary', true);
    const fallbackClient = createMockClient('fallback');
    const router = new ModelRouter({
      default: failingClient,
      fallbackChain: [fallbackClient],
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    expect(response.content).toBe('Response from fallback');
    expect(failingClient.chat).toHaveBeenCalled();
    expect(fallbackClient.chat).toHaveBeenCalled();
  });

  it('throws when all providers fail', async () => {
    const failing1 = createMockClient('primary', true);
    const failing2 = createMockClient('fallback', true);
    const router = new ModelRouter({
      default: failing1,
      fallbackChain: [failing2],
    });

    await expect(router.chat({ messages: [{ role: 'user', content: 'Hi' }] }))
      .rejects.toThrow('All model providers failed');
    // "All providers failed" should mean every provider was actually tried.
    expect(failing1.chat).toHaveBeenCalled();
    expect(failing2.chat).toHaveBeenCalled();
  });

  it('uses tier-specific client when specified', async () => {
    const defaultClient = createMockClient('default');
    const fastClient = createMockClient('fast');
    const router = new ModelRouter({
      default: defaultClient,
      fast: fastClient,
      fallbackChain: [],
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast'
    );

    expect(response.content).toBe('Response from fast');
    expect(fastClient.chat).toHaveBeenCalled();
    expect(defaultClient.chat).not.toHaveBeenCalled();
  });
});

describe('ModelRouter streaming', () => {
  it('streams from primary client', async () => {
    const mockStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Hello' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };
    const mockClient = {
      chat: vi.fn(),
      // mockImplementation (not mockReturnValue(mockStream())) so each call
      // gets a fresh generator instead of sharing one that may already be drained.
      chatStream: vi.fn().mockImplementation(mockStream),
    };
    const router = new ModelRouter({
      default: mockClient,
      fallbackChain: [],
    });

    const chunks: string[] = [];
    for await (const event of router.chatStream({ messages: [] })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
    }

    expect(chunks).toEqual(['Hello']);
    expect(mockClient.chatStream).toHaveBeenCalled();
  });

  it('falls back when primary stream fails', async () => {
    // The primary signals failure in-band via an 'error' event rather than throwing.
    const failingStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'error', error: new Error('Primary failed') };
    };
    const fallbackStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Fallback' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };
    const primaryClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockImplementation(failingStream),
    };
    const fallbackClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockImplementation(fallbackStream),
    };
    const router = new ModelRouter({
      default: primaryClient,
      fallbackChain: [fallbackClient],
    });

    const chunks: string[] = [];
    for await (const event of router.chatStream({ messages: [] })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
    }

    expect(chunks).toEqual(['Fallback']);
    expect(primaryClient.chatStream).toHaveBeenCalled();
    expect(fallbackClient.chatStream).toHaveBeenCalled();
  });
});