feat: add model router with fallback chain support
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
export { AnthropicClient, type AnthropicClientConfig } from './anthropic.js';
|
export { AnthropicClient, type AnthropicClientConfig } from './anthropic.js';
|
||||||
export { OpenAIClient, type OpenAIClientConfig } from './openai.js';
|
export { OpenAIClient, type OpenAIClientConfig } from './openai.js';
|
||||||
export { OllamaClient, type OllamaClientConfig } from './local/index.js';
|
export { OllamaClient, type OllamaClientConfig } from './local/index.js';
|
||||||
|
export { ModelRouter, type ModelRouterConfig, type ModelTier } from './router.js';
|
||||||
export type { Message, ChatRequest, ChatResponse, ModelClient } from './types.js';
|
export type { Message, ChatRequest, ChatResponse, ModelClient } from './types.js';
|
||||||
|
|||||||
@@ -0,0 +1,80 @@
|
|||||||
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||||
|
import { ModelRouter } from './router.js';
|
||||||
|
import type { ModelClient, ChatResponse } from './types.js';
|
||||||
|
|
||||||
|
describe('ModelRouter', () => {
|
||||||
|
const createMockClient = (name: string, shouldFail = false): ModelClient => ({
|
||||||
|
chat: vi.fn().mockImplementation(async () => {
|
||||||
|
if (shouldFail) {
|
||||||
|
throw new Error(`${name} failed`);
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
content: `Response from ${name}`,
|
||||||
|
stopReason: 'end_turn',
|
||||||
|
usage: { inputTokens: 10, outputTokens: 5 },
|
||||||
|
} satisfies ChatResponse;
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
it('uses default client when available', async () => {
|
||||||
|
const defaultClient = createMockClient('default');
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: defaultClient,
|
||||||
|
fallbackChain: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });
|
||||||
|
|
||||||
|
expect(response.content).toBe('Response from default');
|
||||||
|
expect(defaultClient.chat).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('falls back to next provider on failure', async () => {
|
||||||
|
const failingClient = createMockClient('primary', true);
|
||||||
|
const fallbackClient = createMockClient('fallback');
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: failingClient,
|
||||||
|
fallbackChain: [fallbackClient],
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });
|
||||||
|
|
||||||
|
expect(response.content).toBe('Response from fallback');
|
||||||
|
expect(failingClient.chat).toHaveBeenCalled();
|
||||||
|
expect(fallbackClient.chat).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('throws when all providers fail', async () => {
|
||||||
|
const failing1 = createMockClient('primary', true);
|
||||||
|
const failing2 = createMockClient('fallback', true);
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: failing1,
|
||||||
|
fallbackChain: [failing2],
|
||||||
|
});
|
||||||
|
|
||||||
|
await expect(router.chat({ messages: [{ role: 'user', content: 'Hi' }] }))
|
||||||
|
.rejects.toThrow('All model providers failed');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('uses tier-specific client when specified', async () => {
|
||||||
|
const defaultClient = createMockClient('default');
|
||||||
|
const fastClient = createMockClient('fast');
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: defaultClient,
|
||||||
|
fast: fastClient,
|
||||||
|
fallbackChain: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await router.chat(
|
||||||
|
{ messages: [{ role: 'user', content: 'Hi' }] },
|
||||||
|
'fast'
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(response.content).toBe('Response from fast');
|
||||||
|
expect(fastClient.chat).toHaveBeenCalled();
|
||||||
|
expect(defaultClient.chat).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
import type { ChatRequest, ChatResponse, ModelClient } from './types.js';
|
||||||
|
|
||||||
|
export type ModelTier = 'fast' | 'default' | 'complex' | 'local';
|
||||||
|
|
||||||
|
export interface ModelRouterConfig {
|
||||||
|
default: ModelClient;
|
||||||
|
fast?: ModelClient;
|
||||||
|
complex?: ModelClient;
|
||||||
|
local?: ModelClient;
|
||||||
|
fallbackChain: ModelClient[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ModelRouter implements ModelClient {
|
||||||
|
private clients: Map<ModelTier, ModelClient>;
|
||||||
|
private defaultClient: ModelClient;
|
||||||
|
private fallbackChain: ModelClient[];
|
||||||
|
|
||||||
|
constructor(config: ModelRouterConfig) {
|
||||||
|
this.clients = new Map();
|
||||||
|
this.defaultClient = config.default;
|
||||||
|
this.fallbackChain = config.fallbackChain;
|
||||||
|
|
||||||
|
this.clients.set('default', config.default);
|
||||||
|
if (config.fast) this.clients.set('fast', config.fast);
|
||||||
|
if (config.complex) this.clients.set('complex', config.complex);
|
||||||
|
if (config.local) this.clients.set('local', config.local);
|
||||||
|
}
|
||||||
|
|
||||||
|
async chat(request: ChatRequest, tier?: ModelTier): Promise<ChatResponse> {
|
||||||
|
const primaryClient = tier ? this.clients.get(tier) ?? this.defaultClient : this.defaultClient;
|
||||||
|
const errors: Error[] = [];
|
||||||
|
|
||||||
|
// Try primary client
|
||||||
|
try {
|
||||||
|
return await primaryClient.chat(request);
|
||||||
|
} catch (error) {
|
||||||
|
errors.push(error instanceof Error ? error : new Error(String(error)));
|
||||||
|
console.warn(`Primary model failed: ${errors[0].message}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try fallback chain
|
||||||
|
for (const fallbackClient of this.fallbackChain) {
|
||||||
|
try {
|
||||||
|
console.log('Trying fallback model...');
|
||||||
|
return await fallbackClient.chat(request);
|
||||||
|
} catch (error) {
|
||||||
|
errors.push(error instanceof Error ? error : new Error(String(error)));
|
||||||
|
console.warn(`Fallback model failed: ${errors[errors.length - 1].message}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error(`All model providers failed: ${errors.map(e => e.message).join(', ')}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
getClient(tier: ModelTier): ModelClient | undefined {
|
||||||
|
return this.clients.get(tier);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user