import { describe, it, expect, vi } from 'vitest';
import { ModelRouter } from './router.js';
import type { ModelClient, ChatResponse, ChatStreamEvent } from './types.js';

/**
 * Builds a mock `ModelClient` whose `chat` is a `vi.fn()` spy.
 *
 * @param name - Embedded in the canned response (`Response from <name>`) and
 *   in the error message (`<name> failed`) so tests can tell clients apart.
 * @param shouldFail - When true, every `chat` call rejects instead of resolving.
 * @returns A client exposing only `chat`.
 */
const createMockClient = (name: string, shouldFail = false): ModelClient => ({
  chat: vi.fn().mockImplementation(async () => {
    if (shouldFail) {
      throw new Error(`${name} failed`);
    }
    return {
      content: `Response from ${name}`,
      stopReason: 'end_turn',
      usage: { inputTokens: 10, outputTokens: 5 },
    } satisfies ChatResponse;
  }),
});

/** Non-streaming routing: primary selection, global fallback chain, retries. */
describe('ModelRouter', () => {
  it('uses default client when available', async () => {
    const defaultClient = createMockClient('default');
    const router = new ModelRouter({
      default: defaultClient,
      fallbackChain: [],
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    expect(response.content).toBe('Response from default');
    expect(defaultClient.chat).toHaveBeenCalled();
  });

  it('falls back to next provider on failure', async () => {
    const failingClient = createMockClient('primary', true);
    const fallbackClient = createMockClient('fallback');
    const router = new ModelRouter({
      default: failingClient,
      fallbackChain: [fallbackClient],
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    expect(response.content).toBe('Response from fallback');
    expect(response.fallback).toBe(true);
    expect(response.fallbackReason).toMatch(/Primary model failed/);
    expect(failingClient.chat).toHaveBeenCalled();
    expect(fallbackClient.chat).toHaveBeenCalled();
  });

  it('skips duplicate fallback clients that already failed as primary', async () => {
    const failingPrimary = createMockClient('primary', true);
    const fallbackClient = createMockClient('fallback');
    // The same failing instance is listed both as primary and in the chain.
    const router = new ModelRouter({
      default: failingPrimary,
      fallbackChain: [failingPrimary, fallbackClient],
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    expect(response.content).toBe('Response from fallback');
    // Exactly one call each: the duplicate entry must not be retried.
    expect(failingPrimary.chat).toHaveBeenCalledTimes(1);
    expect(fallbackClient.chat).toHaveBeenCalledTimes(1);
  });

  it('applies retry policy to fallback clients', async () => {
    const failingPrimary = createMockClient('primary', true);
    let attempts = 0;
    // Fails on the first call only, then succeeds.
    const flakyFallback: ModelClient = {
      chat: vi.fn().mockImplementation(async () => {
        attempts += 1;
        if (attempts === 1) {
          throw new Error('transient');
        }
        return {
          content: 'Recovered fallback',
          stopReason: 'end_turn',
          usage: { inputTokens: 1, outputTokens: 1 },
        } satisfies ChatResponse;
      }),
    };
    const router = new ModelRouter({
      default: failingPrimary,
      fallbackChain: [flakyFallback],
      // 1 ms delays keep the test fast.
      retryConfig: {
        maxRetries: 1,
        initialDelayMs: 1,
        backoffMultiplier: 1,
        maxDelayMs: 1,
        nonRetryablePatterns: [],
      },
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'retry fallback' }] });

    expect(response.content).toBe('Recovered fallback');
    expect(flakyFallback.chat).toHaveBeenCalledTimes(2);
  });

  it('throws when all providers fail', async () => {
    const failing1 = createMockClient('primary', true);
    const failing2 = createMockClient('fallback', true);
    const router = new ModelRouter({
      default: failing1,
      fallbackChain: [failing2],
    });

    await expect(router.chat({ messages: [{ role: 'user', content: 'Hi' }] }))
      .rejects.toThrow('All model providers failed');
  });

  it('uses tier-specific client when specified', async () => {
    const defaultClient = createMockClient('default');
    const fastClient = createMockClient('fast');
    const router = new ModelRouter({
      default: defaultClient,
      fast: fastClient,
      fallbackChain: [],
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast',
    );

    expect(response.content).toBe('Response from fast');
    expect(fastClient.chat).toHaveBeenCalled();
    expect(defaultClient.chat).not.toHaveBeenCalled();
  });
});

/** Fallbacks configured per tier via `tierFallbacks`, and their interaction with the global chain. */
describe('ModelRouter per-tier fallbacks', () => {
  it('uses tier-specific fallback when tier client fails', async () => {
    const defaultClient = createMockClient('default');
    const failingFast = createMockClient('fast', true);
    const fastFallback = createMockClient('fast-fallback');
    const router = new ModelRouter({
      default: defaultClient,
      fast: failingFast,
      fallbackChain: [],
      tierFallbacks: new Map([['fast', [fastFallback]]]),
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast',
    );

    expect(response.content).toBe('Response from fast-fallback');
    expect(response.fallback).toBe(true);
    expect(defaultClient.chat).not.toHaveBeenCalled();
  });

  it('falls through to global chain when tier fallback also fails', async () => {
    const failingFast = createMockClient('fast', true);
    const failingTierFallback = createMockClient('tier-fallback', true);
    const globalFallback = createMockClient('global-fallback');
    const router = new ModelRouter({
      default: createMockClient('default'),
      fast: failingFast,
      fallbackChain: [globalFallback],
      tierFallbacks: new Map([['fast', [failingTierFallback]]]),
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast',
    );

    expect(response.content).toBe('Response from global-fallback');
    expect(response.fallback).toBe(true);
  });

  it('skips tier fallbacks when none configured for that tier', async () => {
    const failingComplex = createMockClient('complex', true);
    const globalFallback = createMockClient('global-fallback');
    // Tier fallbacks exist only for 'fast'; the request targets 'complex'.
    const router = new ModelRouter({
      default: createMockClient('default'),
      complex: failingComplex,
      fallbackChain: [globalFallback],
      tierFallbacks: new Map([['fast', [createMockClient('fast-fb')]]]),
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'complex',
    );

    expect(response.content).toBe('Response from global-fallback');
  });
});

/**
 * Streaming path (`chatStream`). Each mock client returns a pre-built async
 * generator; a failing provider is modelled by yielding an `error` event.
 */
describe('ModelRouter streaming', () => {
  it('streams from primary client', async () => {
    const mockStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Hello' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };
    const mockClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockReturnValue(mockStream()),
    };
    const router = new ModelRouter({
      default: mockClient,
      fallbackChain: [],
    });

    const chunks: string[] = [];
    for await (const event of router.chatStream({ messages: [] })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
    }

    expect(chunks).toEqual(['Hello']);
  });

  it('falls back when primary stream fails', async () => {
    const failingStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'error', error: new Error('Primary failed') };
    };
    const fallbackStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Fallback' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };
    const primaryClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockReturnValue(failingStream()),
    };
    const fallbackClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockReturnValue(fallbackStream()),
    };
    const router = new ModelRouter({
      default: primaryClient,
      fallbackChain: [fallbackClient],
    });

    const chunks: string[] = [];
    let fallbackWarning: string | undefined;
    for await (const event of router.chatStream({ messages: [] })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
      if (event.type === 'fallback_warning') {
        fallbackWarning = event.fallbackReason;
      }
    }

    expect(chunks).toEqual(['Fallback']);
    expect(fallbackWarning).toMatch(/Primary model failed/);
  });

  it('uses tier fallback for streaming before global chain', async () => {
    const failingStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'error', error: new Error('Primary failed') };
    };
    const tierFallbackStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'TierFallback' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };
    const globalStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Global' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };
    const primaryClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockReturnValue(failingStream()),
    };
    const tierFallbackClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockReturnValue(tierFallbackStream()),
    };
    const globalClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockReturnValue(globalStream()),
    };
    const router = new ModelRouter({
      default: primaryClient,
      fallbackChain: [globalClient],
      tierFallbacks: new Map([['default', [tierFallbackClient]]]),
    });

    const chunks: string[] = [];
    for await (const event of router.chatStream({ messages: [] })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
    }

    expect(chunks).toEqual(['TierFallback']);
    // The tier fallback succeeded, so the global chain must stay untouched.
    expect(globalClient.chatStream).not.toHaveBeenCalled();
  });
});

/** Swapping the `local` tier client at runtime via `setLocalClient`. */
describe('ModelRouter local client switching', () => {
  it('allows setting a new local client', () => {
    const mockDefault = { chat: vi.fn() } as unknown as ModelClient;
    const mockLocal1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockLocal2 = { chat: vi.fn() } as unknown as ModelClient;
    const router = new ModelRouter({
      default: mockDefault,
      local: mockLocal1,
      fallbackChain: [],
    });

    expect(router.getLocalProviderName()).toBeUndefined();

    router.setLocalClient(mockLocal2, 'llamacpp');

    expect(router.getLocalProviderName()).toBe('llamacpp');
    expect(router.getClient('local')).toBe(mockLocal2);
  });
});

/**
 * Runtime client replacement, tier labels, strict-tier mode, abort handling
 * and tier-change listeners.
 */
describe('setClient and labels', () => {
  it('setClient replaces an existing tier client', async () => {
    const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fast: mockClient1,
      fallbackChain: [],
    });

    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
    expect(mockClient1.chat).toHaveBeenCalled();
    expect(mockClient1.chat).toHaveBeenCalledTimes(1);

    router.setClient('fast', mockClient2, 'fast-replaced');
    const newFastClient = router.getClient('fast');
    expect(newFastClient).toBeDefined();
    // Guard narrows the type for the assertions below.
    if (!newFastClient) {
      throw new Error('Expected fast client to be set');
    }

    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
    expect(newFastClient.chat).toHaveBeenCalled();
    expect(newFastClient.chat).toHaveBeenCalledTimes(1);
    // The replaced client must not receive the second request.
    expect(mockClient1.chat).toHaveBeenCalledTimes(1);
  });

  it('setClient adds a new tier client', async () => {
    const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
    const router = new ModelRouter({
      default: mockClient1,
      fallbackChain: [],
    });

    expect(router.getClient('complex')).toBeUndefined();

    router.setClient('complex', mockClient2, 'complex-tier');
    const newClient = router.getClient('complex');
    expect(newClient).toBe(mockClient2);
    if (!newClient) {
      throw new Error('Expected complex client to be set');
    }

    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'complex');
    expect(newClient.chat).toHaveBeenCalled();
  });

  it('getLabel returns the label set by setClient', () => {
    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
    });

    expect(router.getLabel('fast')).toBe('unknown');

    router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');

    expect(router.getLabel('fast')).toBe('fast-tier');
  });

  it('getLabel returns "unknown" for unset tier', () => {
    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
    });

    expect(router.getLabel('fast')).toBe('unknown');
    expect(router.getLabel('complex')).toBe('unknown');
  });

  it('getAllLabels returns all tier labels', () => {
    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
    });

    const labels = router.getAllLabels();
    expect(labels).toEqual({});

    router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');
    router.setClient('complex', { chat: vi.fn() } as unknown as ModelClient, 'complex-tier');

    const allLabels = router.getAllLabels();
    expect(allLabels).toEqual({
      fast: 'fast-tier',
      complex: 'complex-tier',
    });
  });

  it('constructor accepts initial labels', async () => {
    const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
    const router = new ModelRouter({
      default: mockClient1,
      fast: mockClient2,
      fallbackChain: [],
      labels: {
        default: 'default-tier',
        fast: 'fast-tier',
      },
    });

    expect(router.getClient('default')).toBe(mockClient1);
    expect(router.getClient('fast')).toBe(mockClient2);
    expect(router.getLabel('default')).toBe('default-tier');
    expect(router.getLabel('fast')).toBe('fast-tier');
    expect(router.getLabel('complex')).toBe('unknown');

    await router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'fast');
    expect(mockClient2.chat).toHaveBeenCalled();
  });

  it('chat uses the new client after setClient', async () => {
    const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
    const router = new ModelRouter({
      default: mockClient1,
      fast: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
      labels: {
        fast: 'original-fast',
      },
    });

    // The initial fast client was created inline, so fetch it back to spy on it.
    const initialFastClient = router.getClient('fast');
    expect(initialFastClient).toBeDefined();
    if (!initialFastClient) {
      throw new Error('Expected initial fast client to exist');
    }

    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
    expect(initialFastClient.chat).toHaveBeenCalled();
    expect(initialFastClient.chat).toHaveBeenCalledTimes(1);

    router.setClient('fast', mockClient2, 'fast-replaced');
    const newFastClient = router.getClient('fast');
    if (!newFastClient) {
      throw new Error('Expected replaced fast client to exist');
    }

    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
    expect(newFastClient.chat).toHaveBeenCalled();
    expect(newFastClient.chat).toHaveBeenCalledTimes(1);
    expect(initialFastClient.chat).toHaveBeenCalledTimes(1);
  });

  it('strict tier mode disables fallback chain for that tier', async () => {
    const failingDefault = {
      chat: vi.fn().mockRejectedValue(new Error('primary failed')),
    } as unknown as ModelClient;
    const fallback = {
      chat: vi.fn().mockResolvedValue({
        content: 'fallback',
        stopReason: 'end_turn',
        usage: { inputTokens: 1, outputTokens: 1 },
      }),
    } as unknown as ModelClient;
    const router = new ModelRouter({
      default: failingDefault,
      fallbackChain: [fallback],
    });

    router.setTierStrict('default', true);

    // The primary's own error surfaces rather than a fallback response.
    await expect(router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'default'))
      .rejects.toThrow('primary failed');
    expect(fallback.chat).not.toHaveBeenCalled();
    expect(router.isTierStrict('default')).toBe(true);
  });

  it('requestAbort interrupts retry loop before fallback chain', async () => {
    const primary = {
      chat: vi.fn().mockRejectedValue(new Error('temporary failure')),
    } as unknown as ModelClient;
    const fallback = {
      chat: vi.fn().mockResolvedValue({
        content: 'fallback',
        stopReason: 'end_turn',
        usage: { inputTokens: 1, outputTokens: 1 },
      }),
    } as unknown as ModelClient;
    const router = new ModelRouter({
      default: primary,
      fallbackChain: [fallback],
      retryConfig: {
        maxRetries: 3,
        initialDelayMs: 80,
        backoffMultiplier: 1,
        maxDelayMs: 80,
        nonRetryablePatterns: [],
      },
    });

    const run = router.chat({ messages: [{ role: 'user', content: 'hi' }] });
    // Abort at 10 ms, well before the 80 ms retry delay elapses; the test
    // expects that no second primary attempt is made.
    setTimeout(() => router.requestAbort(), 10);

    await expect(run).rejects.toMatchObject({ name: 'AbortError' });
    expect(primary.chat).toHaveBeenCalledTimes(1);
    expect(fallback.chat).not.toHaveBeenCalled();
  });

  it('setOnTierChange does not replace existing listeners', () => {
    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fast: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
    });
    const first = vi.fn();
    const second = vi.fn();

    router.addOnTierChange(first);
    router.setOnTierChange(second);
    router.setTier('fast');

    // Both listeners must be notified of the tier change.
    expect(first).toHaveBeenCalledWith('fast');
    expect(second).toHaveBeenCalledWith('fast');
  });
});