# Patch summary: per-tier fallback clients for ModelRouter
#
# Files touched: src/models/router.ts and src/models/router.test.ts.
#
# src/models/router.ts
#   - ModelRouterConfig gains an optional `tierFallbacks` map; the constructor
#     stores it and substitutes an empty Map when the option is omitted.
#   - Non-streaming path (exercised through router.chat in the tests): after
#     the primary failure is logged, the list returned by
#     `tierFallbacks.get(useTier)` is tried in order, and only then the global
#     `fallbackChain`. The first fallback that resolves is returned with
#     `fallback: true` and a `fallbackReason`; each failure is appended to
#     `errors` and logged with console.warn.
#   - Streaming path (exercised through router.chatStream in the tests): same
#     ordering. A fallback without `chatStream` is skipped. A `fallback_warning`
#     event is yielded before each attempt. An `error` event ends that attempt
#     and moves on to the next fallback; in the tier loop the generator returns
#     after the first attempt that finishes without an `error` event.
#   - Log and reason strings now say "tier fallback" / "global fallback".
#
# src/models/router.test.ts
#   - chat: the tier fallback is used when the tier client fails; the global
#     chain is used when the tier fallback also fails; the global chain is used
#     when the requested tier has no tier fallbacks configured.
#   - chatStream: the tier fallback is used before the global chain.
#
# Notes for reviewers
#   - Events already yielded by a streaming fallback are not withdrawn when that
#     fallback later emits `error`, so a consumer can receive partial content
#     from a failed fallback followed by content from the next one.
#   - Every fallback reason quotes the primary error (`errors[0]` /
#     `primaryError`), including when earlier fallbacks have also failed.
#   - The "#N" in streaming messages is the position in the list, so skipped
#     non-streaming clients still consume a number.
#   - No de-duplication between `tierFallbacks` and `fallbackChain` is visible
#     in these hunks; a client present in both would be attempted twice.
#   - NOTE(review): generic type arguments look stripped in this copy
#     (`Map;`, `AsyncIterable {`, `Partial>`) - presumably an extraction
#     artifact; confirm against the repository before applying.
#   - NOTE(review): the final hunk stops at `yield event;`; the end of the
#     global streaming loop is outside this view.
#
diff --git a/src/models/router.test.ts b/src/models/router.test.ts index 013e10e..4afa1dd 100644 --- a/src/models/router.test.ts +++ b/src/models/router.test.ts @@ -81,6 +81,83 @@ describe('ModelRouter', () => { }); }); +describe('ModelRouter per-tier fallbacks', () => { + const createMockClient = (name: string, shouldFail = false): ModelClient => ({ + chat: vi.fn().mockImplementation(async () => { + if (shouldFail) { + throw new Error(`${name} failed`); + } + return { + content: `Response from ${name}`, + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5 }, + } satisfies ChatResponse; + }), + }); + + it('uses tier-specific fallback when tier client fails', async () => { + const defaultClient = createMockClient('default'); + const failingFast = createMockClient('fast', true); + const fastFallback = createMockClient('fast-fallback'); + + const router = new ModelRouter({ + default: defaultClient, + fast: failingFast, + fallbackChain: [], + tierFallbacks: new Map([['fast', [fastFallback]]]), + }); + + const response = await router.chat( + { messages: [{ role: 'user', content: 'Hi' }] }, + 'fast', + ); + + expect(response.content).toBe('Response from fast-fallback'); + expect(response.fallback).toBe(true); + expect(defaultClient.chat).not.toHaveBeenCalled(); + }); + + it('falls through to global chain when tier fallback also fails', async () => { + const failingFast = createMockClient('fast', true); + const failingTierFallback = createMockClient('tier-fallback', true); + const globalFallback = createMockClient('global-fallback'); + + const router = new ModelRouter({ + default: createMockClient('default'), + fast: failingFast, + fallbackChain: [globalFallback], + tierFallbacks: new Map([['fast', [failingTierFallback]]]), + }); + + const response = await router.chat( + { messages: [{ role: 'user', content: 'Hi' }] }, + 'fast', + ); + + expect(response.content).toBe('Response from global-fallback'); + expect(response.fallback).toBe(true); + }); + + 
it('skips tier fallbacks when none configured for that tier', async () => { + const failingComplex = createMockClient('complex', true); + const globalFallback = createMockClient('global-fallback'); + + const router = new ModelRouter({ + default: createMockClient('default'), + complex: failingComplex, + fallbackChain: [globalFallback], + tierFallbacks: new Map([['fast', [createMockClient('fast-fb')]]]), + }); + + const response = await router.chat( + { messages: [{ role: 'user', content: 'Hi' }] }, + 'complex', + ); + + expect(response.content).toBe('Response from global-fallback'); + }); +}); + describe('ModelRouter streaming', () => { it('streams from primary client', async () => { const mockStream = async function* (): AsyncIterable { @@ -147,6 +224,53 @@ describe('ModelRouter streaming', () => { expect(chunks).toEqual(['Fallback']); expect(fallbackWarning).toMatch(/Primary model failed/); }); + + it('uses tier fallback for streaming before global chain', async () => { + const failingStream = async function* (): AsyncIterable { + yield { type: 'error', error: new Error('Primary failed') }; + }; + + const tierFallbackStream = async function* (): AsyncIterable { + yield { type: 'content', content: 'TierFallback' }; + yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } }; + }; + + const globalStream = async function* (): AsyncIterable { + yield { type: 'content', content: 'Global' }; + yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } }; + }; + + const primaryClient = { + chat: vi.fn(), + chatStream: vi.fn().mockReturnValue(failingStream()), + }; + + const tierFallbackClient = { + chat: vi.fn(), + chatStream: vi.fn().mockReturnValue(tierFallbackStream()), + }; + + const globalClient = { + chat: vi.fn(), + chatStream: vi.fn().mockReturnValue(globalStream()), + }; + + const router = new ModelRouter({ + default: primaryClient, + fallbackChain: [globalClient], + tierFallbacks: new Map([['default', [tierFallbackClient]]]), + }); + + const 
chunks: string[] = []; + for await (const event of router.chatStream({ messages: [] })) { + if (event.type === 'content' && event.content) { + chunks.push(event.content); + } + } + + expect(chunks).toEqual(['TierFallback']); + expect(globalClient.chatStream).not.toHaveBeenCalled(); + }); }); describe('ModelRouter local client switching', () => { diff --git a/src/models/router.ts b/src/models/router.ts index d33cbd9..774c6ad 100644 --- a/src/models/router.ts +++ b/src/models/router.ts @@ -10,6 +10,7 @@ export interface ModelRouterConfig { complex?: ModelClient; local?: ModelClient; fallbackChain: ModelClient[]; + tierFallbacks?: Map; retryConfig?: RetryConfig; labels?: Partial>; } @@ -19,6 +20,7 @@ export class ModelRouter implements ModelClient { private labels: Map; private defaultClient: ModelClient; private fallbackChain: ModelClient[]; + private tierFallbacks: Map; private currentTier: ModelTier = 'default'; private localProviderName?: string; private retryConfig?: RetryConfig; @@ -28,6 +30,7 @@ export class ModelRouter implements ModelClient { this.labels = new Map(); this.defaultClient = config.default; this.fallbackChain = config.fallbackChain; + this.tierFallbacks = config.tierFallbacks ?? new Map(); this.retryConfig = config.retryConfig; this.clients.set('default', config.default); @@ -76,17 +79,31 @@ export class ModelRouter implements ModelClient { console.warn(`Primary model failed: ${errors[0].message}`); } - // Try fallback chain + // Try tier-specific fallbacks first + const tierFallbackList = this.tierFallbacks.get(useTier) ?? []; + for (let i = 0; i < tierFallbackList.length; i++) { + try { + const reason = `Primary model failed (${errors[0].message}), using tier fallback #${i + 1}`; + console.warn(reason); + const response = await tierFallbackList[i].chat(request); + return { ...response, fallback: true, fallbackReason: reason }; + } catch (error) { + errors.push(error instanceof Error ? 
error : new Error(String(error))); + console.warn(`Tier fallback #${i + 1} failed: ${errors[errors.length - 1].message}`); + } + } + + // Then try global fallback chain for (let i = 0; i < this.fallbackChain.length; i++) { const fallbackClient = this.fallbackChain[i]; try { - const reason = `Primary model failed (${errors[0].message}), using fallback #${i + 1}`; + const reason = `Primary model failed (${errors[0].message}), using global fallback #${i + 1}`; console.warn(reason); const response = await fallbackClient.chat(request); return { ...response, fallback: true, fallbackReason: reason }; } catch (error) { errors.push(error instanceof Error ? error : new Error(String(error))); - console.warn(`Fallback model #${i + 1} failed: ${errors[errors.length - 1].message}`); + console.warn(`Global fallback #${i + 1} failed: ${errors[errors.length - 1].message}`); } } @@ -115,12 +132,13 @@ export class ModelRouter implements ModelClient { primaryError = 'Primary client does not support streaming'; } - // Try fallback chain - for (let i = 0; i < this.fallbackChain.length; i++) { - const fallbackClient = this.fallbackChain[i]; + // Try tier-specific fallbacks first + const tierFallbackList = this.tierFallbacks.get(useTier) ?? 
[]; + for (let i = 0; i < tierFallbackList.length; i++) { + const fallbackClient = tierFallbackList[i]; if (!fallbackClient.chatStream) continue; - const reason = `Primary model failed (${primaryError}), using fallback #${i + 1}`; + const reason = `Primary model failed (${primaryError}), using tier fallback #${i + 1}`; console.warn(reason); yield { type: 'fallback_warning', fallbackReason: reason }; @@ -128,7 +146,29 @@ export class ModelRouter implements ModelClient { for await (const event of fallbackClient.chatStream(request)) { if (event.type === 'error') { hasError = true; - console.warn(`Fallback stream #${i + 1} failed: ${event.error?.message}`); + console.warn(`Tier fallback stream #${i + 1} failed: ${event.error?.message}`); break; } yield event; } if (!hasError) return; } // Then try global fallback chain for (let i = 0; i < this.fallbackChain.length; i++) { const fallbackClient = this.fallbackChain[i]; if (!fallbackClient.chatStream) continue; const reason = `Primary model failed (${primaryError}), using global fallback #${i + 1}`; console.warn(reason); yield { type: 'fallback_warning', fallbackReason: reason }; let hasError = false; for await (const event of fallbackClient.chatStream(request)) { if (event.type === 'error') { hasError = true; console.warn(`Global fallback stream #${i + 1} failed: ${event.error?.message}`); break; } yield event;