From 26bd6ce65da13d835099d256cf0313fa88c59f7d Mon Sep 17 00:00:00 2001 From: William Valentin Date: Tue, 3 Feb 2026 00:29:52 -0800 Subject: [PATCH] feat: add model router with fallback chain support Co-Authored-By: Claude Opus 4.5 --- src/models/index.ts | 1 + src/models/router.test.ts | 80 +++++++++++++++++++++++++++++++++++++++ src/models/router.ts | 58 ++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+) create mode 100644 src/models/router.test.ts create mode 100644 src/models/router.ts diff --git a/src/models/index.ts b/src/models/index.ts index be1ada3..3ec116f 100644 --- a/src/models/index.ts +++ b/src/models/index.ts @@ -1,4 +1,5 @@ export { AnthropicClient, type AnthropicClientConfig } from './anthropic.js'; export { OpenAIClient, type OpenAIClientConfig } from './openai.js'; export { OllamaClient, type OllamaClientConfig } from './local/index.js'; +export { ModelRouter, type ModelRouterConfig, type ModelTier } from './router.js'; export type { Message, ChatRequest, ChatResponse, ModelClient } from './types.js'; diff --git a/src/models/router.test.ts b/src/models/router.test.ts new file mode 100644 index 0000000..00fce36 --- /dev/null +++ b/src/models/router.test.ts @@ -0,0 +1,80 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ModelRouter } from './router.js'; +import type { ModelClient, ChatResponse } from './types.js'; + +describe('ModelRouter', () => { + const createMockClient = (name: string, shouldFail = false): ModelClient => ({ + chat: vi.fn().mockImplementation(async () => { + if (shouldFail) { + throw new Error(`${name} failed`); + } + return { + content: `Response from ${name}`, + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5 }, + } satisfies ChatResponse; + }), + }); + + it('uses default client when available', async () => { + const defaultClient = createMockClient('default'); + const router = new ModelRouter({ + default: defaultClient, + fallbackChain: [], + }); + + const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] }); + + expect(response.content).toBe('Response from default'); + expect(defaultClient.chat).toHaveBeenCalled(); + }); + + it('falls back to next provider on failure', async () => { + const failingClient = createMockClient('primary', true); + const fallbackClient = createMockClient('fallback'); + + const router = new ModelRouter({ + default: failingClient, + fallbackChain: [fallbackClient], + }); + + const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] }); + + expect(response.content).toBe('Response from fallback'); + expect(failingClient.chat).toHaveBeenCalled(); + expect(fallbackClient.chat).toHaveBeenCalled(); + }); + + it('throws when all providers fail', async () => { + const failing1 = createMockClient('primary', true); + const failing2 = createMockClient('fallback', true); + + const router = new ModelRouter({ + default: failing1, + fallbackChain: [failing2], + }); + + await expect(router.chat({ messages: [{ role: 'user', content: 'Hi' }] })) + .rejects.toThrow('All model providers failed'); + }); + + it('uses tier-specific client when specified', async () => { + const defaultClient = createMockClient('default'); + const fastClient = createMockClient('fast'); + + const router = new ModelRouter({ + default: defaultClient, + fast: fastClient, + fallbackChain: [], + }); + + const response = await router.chat( + { messages: [{ role: 'user', content: 'Hi' }] }, + 'fast' + ); + + expect(response.content).toBe('Response from fast'); + expect(fastClient.chat).toHaveBeenCalled(); + expect(defaultClient.chat).not.toHaveBeenCalled(); + }); +}); diff --git a/src/models/router.ts b/src/models/router.ts new file mode 100644 index 0000000..58c13f8 --- /dev/null +++ b/src/models/router.ts @@ -0,0 +1,58 @@ +import type { ChatRequest, ChatResponse, ModelClient } from './types.js'; + +export type ModelTier = 'fast' | 'default' | 'complex' | 'local'; + +export interface ModelRouterConfig { + default: ModelClient; + fast?: ModelClient; + complex?: ModelClient; + local?: ModelClient; + fallbackChain: ModelClient[]; +} + +export class ModelRouter implements ModelClient { + private clients: Map; + private defaultClient: ModelClient; + private fallbackChain: ModelClient[]; + + constructor(config: ModelRouterConfig) { + this.clients = new Map(); + this.defaultClient = config.default; + this.fallbackChain = config.fallbackChain; + + this.clients.set('default', config.default); + if (config.fast) this.clients.set('fast', config.fast); + if (config.complex) this.clients.set('complex', config.complex); + if (config.local) this.clients.set('local', config.local); + } + + async chat(request: ChatRequest, tier?: ModelTier): Promise { + const primaryClient = tier ? this.clients.get(tier) ?? this.defaultClient : this.defaultClient; + const errors: Error[] = []; + + // Try primary client + try { + return await primaryClient.chat(request); + } catch (error) { + errors.push(error instanceof Error ? error : new Error(String(error))); + console.warn(`Primary model failed: ${errors[0].message}`); + } + + // Try fallback chain + for (const fallbackClient of this.fallbackChain) { + try { + console.log('Trying fallback model...'); + return await fallbackClient.chat(request); + } catch (error) { + errors.push(error instanceof Error ? error : new Error(String(error))); + console.warn(`Fallback model failed: ${errors[errors.length - 1].message}`); + } + } + + throw new Error(`All model providers failed: ${errors.map(e => e.message).join(', ')}`); + } + + getClient(tier: ModelTier): ModelClient | undefined { + return this.clients.get(tier); + } +}