feat: add per-tier fallback support to ModelRouter
The router now accepts a tierFallbacks map so each model tier can have its own fallback providers. Tier fallbacks are tried before the global fallback chain in both chat() and chatStream(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -81,6 +81,83 @@ describe('ModelRouter', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('ModelRouter per-tier fallbacks', () => {
|
||||||
|
const createMockClient = (name: string, shouldFail = false): ModelClient => ({
|
||||||
|
chat: vi.fn().mockImplementation(async () => {
|
||||||
|
if (shouldFail) {
|
||||||
|
throw new Error(`${name} failed`);
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
content: `Response from ${name}`,
|
||||||
|
stopReason: 'end_turn',
|
||||||
|
usage: { inputTokens: 10, outputTokens: 5 },
|
||||||
|
} satisfies ChatResponse;
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
it('uses tier-specific fallback when tier client fails', async () => {
|
||||||
|
const defaultClient = createMockClient('default');
|
||||||
|
const failingFast = createMockClient('fast', true);
|
||||||
|
const fastFallback = createMockClient('fast-fallback');
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: defaultClient,
|
||||||
|
fast: failingFast,
|
||||||
|
fallbackChain: [],
|
||||||
|
tierFallbacks: new Map([['fast', [fastFallback]]]),
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await router.chat(
|
||||||
|
{ messages: [{ role: 'user', content: 'Hi' }] },
|
||||||
|
'fast',
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(response.content).toBe('Response from fast-fallback');
|
||||||
|
expect(response.fallback).toBe(true);
|
||||||
|
expect(defaultClient.chat).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('falls through to global chain when tier fallback also fails', async () => {
|
||||||
|
const failingFast = createMockClient('fast', true);
|
||||||
|
const failingTierFallback = createMockClient('tier-fallback', true);
|
||||||
|
const globalFallback = createMockClient('global-fallback');
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: createMockClient('default'),
|
||||||
|
fast: failingFast,
|
||||||
|
fallbackChain: [globalFallback],
|
||||||
|
tierFallbacks: new Map([['fast', [failingTierFallback]]]),
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await router.chat(
|
||||||
|
{ messages: [{ role: 'user', content: 'Hi' }] },
|
||||||
|
'fast',
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(response.content).toBe('Response from global-fallback');
|
||||||
|
expect(response.fallback).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('skips tier fallbacks when none configured for that tier', async () => {
|
||||||
|
const failingComplex = createMockClient('complex', true);
|
||||||
|
const globalFallback = createMockClient('global-fallback');
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: createMockClient('default'),
|
||||||
|
complex: failingComplex,
|
||||||
|
fallbackChain: [globalFallback],
|
||||||
|
tierFallbacks: new Map([['fast', [createMockClient('fast-fb')]]]),
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await router.chat(
|
||||||
|
{ messages: [{ role: 'user', content: 'Hi' }] },
|
||||||
|
'complex',
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(response.content).toBe('Response from global-fallback');
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('ModelRouter streaming', () => {
|
describe('ModelRouter streaming', () => {
|
||||||
it('streams from primary client', async () => {
|
it('streams from primary client', async () => {
|
||||||
const mockStream = async function* (): AsyncIterable<ChatStreamEvent> {
|
const mockStream = async function* (): AsyncIterable<ChatStreamEvent> {
|
||||||
@@ -147,6 +224,53 @@ describe('ModelRouter streaming', () => {
|
|||||||
expect(chunks).toEqual(['Fallback']);
|
expect(chunks).toEqual(['Fallback']);
|
||||||
expect(fallbackWarning).toMatch(/Primary model failed/);
|
expect(fallbackWarning).toMatch(/Primary model failed/);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('uses tier fallback for streaming before global chain', async () => {
|
||||||
|
const failingStream = async function* (): AsyncIterable<ChatStreamEvent> {
|
||||||
|
yield { type: 'error', error: new Error('Primary failed') };
|
||||||
|
};
|
||||||
|
|
||||||
|
const tierFallbackStream = async function* (): AsyncIterable<ChatStreamEvent> {
|
||||||
|
yield { type: 'content', content: 'TierFallback' };
|
||||||
|
yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
|
||||||
|
};
|
||||||
|
|
||||||
|
const globalStream = async function* (): AsyncIterable<ChatStreamEvent> {
|
||||||
|
yield { type: 'content', content: 'Global' };
|
||||||
|
yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
|
||||||
|
};
|
||||||
|
|
||||||
|
const primaryClient = {
|
||||||
|
chat: vi.fn(),
|
||||||
|
chatStream: vi.fn().mockReturnValue(failingStream()),
|
||||||
|
};
|
||||||
|
|
||||||
|
const tierFallbackClient = {
|
||||||
|
chat: vi.fn(),
|
||||||
|
chatStream: vi.fn().mockReturnValue(tierFallbackStream()),
|
||||||
|
};
|
||||||
|
|
||||||
|
const globalClient = {
|
||||||
|
chat: vi.fn(),
|
||||||
|
chatStream: vi.fn().mockReturnValue(globalStream()),
|
||||||
|
};
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: primaryClient,
|
||||||
|
fallbackChain: [globalClient],
|
||||||
|
tierFallbacks: new Map([['default', [tierFallbackClient]]]),
|
||||||
|
});
|
||||||
|
|
||||||
|
const chunks: string[] = [];
|
||||||
|
for await (const event of router.chatStream({ messages: [] })) {
|
||||||
|
if (event.type === 'content' && event.content) {
|
||||||
|
chunks.push(event.content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(chunks).toEqual(['TierFallback']);
|
||||||
|
expect(globalClient.chatStream).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('ModelRouter local client switching', () => {
|
describe('ModelRouter local client switching', () => {
|
||||||
|
|||||||
+48
-8
@@ -10,6 +10,7 @@ export interface ModelRouterConfig {
|
|||||||
complex?: ModelClient;
|
complex?: ModelClient;
|
||||||
local?: ModelClient;
|
local?: ModelClient;
|
||||||
fallbackChain: ModelClient[];
|
fallbackChain: ModelClient[];
|
||||||
|
tierFallbacks?: Map<ModelTier, ModelClient[]>;
|
||||||
retryConfig?: RetryConfig;
|
retryConfig?: RetryConfig;
|
||||||
labels?: Partial<Record<ModelTier, string>>;
|
labels?: Partial<Record<ModelTier, string>>;
|
||||||
}
|
}
|
||||||
@@ -19,6 +20,7 @@ export class ModelRouter implements ModelClient {
|
|||||||
private labels: Map<ModelTier, string>;
|
private labels: Map<ModelTier, string>;
|
||||||
private defaultClient: ModelClient;
|
private defaultClient: ModelClient;
|
||||||
private fallbackChain: ModelClient[];
|
private fallbackChain: ModelClient[];
|
||||||
|
private tierFallbacks: Map<ModelTier, ModelClient[]>;
|
||||||
private currentTier: ModelTier = 'default';
|
private currentTier: ModelTier = 'default';
|
||||||
private localProviderName?: string;
|
private localProviderName?: string;
|
||||||
private retryConfig?: RetryConfig;
|
private retryConfig?: RetryConfig;
|
||||||
@@ -28,6 +30,7 @@ export class ModelRouter implements ModelClient {
|
|||||||
this.labels = new Map();
|
this.labels = new Map();
|
||||||
this.defaultClient = config.default;
|
this.defaultClient = config.default;
|
||||||
this.fallbackChain = config.fallbackChain;
|
this.fallbackChain = config.fallbackChain;
|
||||||
|
this.tierFallbacks = config.tierFallbacks ?? new Map();
|
||||||
this.retryConfig = config.retryConfig;
|
this.retryConfig = config.retryConfig;
|
||||||
|
|
||||||
this.clients.set('default', config.default);
|
this.clients.set('default', config.default);
|
||||||
@@ -76,17 +79,31 @@ export class ModelRouter implements ModelClient {
|
|||||||
console.warn(`Primary model failed: ${errors[0].message}`);
|
console.warn(`Primary model failed: ${errors[0].message}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try fallback chain
|
// Try tier-specific fallbacks first
|
||||||
|
const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
|
||||||
|
for (let i = 0; i < tierFallbackList.length; i++) {
|
||||||
|
try {
|
||||||
|
const reason = `Primary model failed (${errors[0].message}), using tier fallback #${i + 1}`;
|
||||||
|
console.warn(reason);
|
||||||
|
const response = await tierFallbackList[i].chat(request);
|
||||||
|
return { ...response, fallback: true, fallbackReason: reason };
|
||||||
|
} catch (error) {
|
||||||
|
errors.push(error instanceof Error ? error : new Error(String(error)));
|
||||||
|
console.warn(`Tier fallback #${i + 1} failed: ${errors[errors.length - 1].message}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then try global fallback chain
|
||||||
for (let i = 0; i < this.fallbackChain.length; i++) {
|
for (let i = 0; i < this.fallbackChain.length; i++) {
|
||||||
const fallbackClient = this.fallbackChain[i];
|
const fallbackClient = this.fallbackChain[i];
|
||||||
try {
|
try {
|
||||||
const reason = `Primary model failed (${errors[0].message}), using fallback #${i + 1}`;
|
const reason = `Primary model failed (${errors[0].message}), using global fallback #${i + 1}`;
|
||||||
console.warn(reason);
|
console.warn(reason);
|
||||||
const response = await fallbackClient.chat(request);
|
const response = await fallbackClient.chat(request);
|
||||||
return { ...response, fallback: true, fallbackReason: reason };
|
return { ...response, fallback: true, fallbackReason: reason };
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
errors.push(error instanceof Error ? error : new Error(String(error)));
|
errors.push(error instanceof Error ? error : new Error(String(error)));
|
||||||
console.warn(`Fallback model #${i + 1} failed: ${errors[errors.length - 1].message}`);
|
console.warn(`Global fallback #${i + 1} failed: ${errors[errors.length - 1].message}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,12 +132,13 @@ export class ModelRouter implements ModelClient {
|
|||||||
primaryError = 'Primary client does not support streaming';
|
primaryError = 'Primary client does not support streaming';
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try fallback chain
|
// Try tier-specific fallbacks first
|
||||||
for (let i = 0; i < this.fallbackChain.length; i++) {
|
const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
|
||||||
const fallbackClient = this.fallbackChain[i];
|
for (let i = 0; i < tierFallbackList.length; i++) {
|
||||||
|
const fallbackClient = tierFallbackList[i];
|
||||||
if (!fallbackClient.chatStream) continue;
|
if (!fallbackClient.chatStream) continue;
|
||||||
|
|
||||||
const reason = `Primary model failed (${primaryError}), using fallback #${i + 1}`;
|
const reason = `Primary model failed (${primaryError}), using tier fallback #${i + 1}`;
|
||||||
console.warn(reason);
|
console.warn(reason);
|
||||||
yield { type: 'fallback_warning', fallbackReason: reason };
|
yield { type: 'fallback_warning', fallbackReason: reason };
|
||||||
|
|
||||||
@@ -128,7 +146,29 @@ export class ModelRouter implements ModelClient {
|
|||||||
for await (const event of fallbackClient.chatStream(request)) {
|
for await (const event of fallbackClient.chatStream(request)) {
|
||||||
if (event.type === 'error') {
|
if (event.type === 'error') {
|
||||||
hasError = true;
|
hasError = true;
|
||||||
console.warn(`Fallback stream #${i + 1} failed: ${event.error?.message}`);
|
console.warn(`Tier fallback stream #${i + 1} failed: ${event.error?.message}`);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
yield event;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasError) return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then try global fallback chain
|
||||||
|
for (let i = 0; i < this.fallbackChain.length; i++) {
|
||||||
|
const fallbackClient = this.fallbackChain[i];
|
||||||
|
if (!fallbackClient.chatStream) continue;
|
||||||
|
|
||||||
|
const reason = `Primary model failed (${primaryError}), using global fallback #${i + 1}`;
|
||||||
|
console.warn(reason);
|
||||||
|
yield { type: 'fallback_warning', fallbackReason: reason };
|
||||||
|
|
||||||
|
let hasError = false;
|
||||||
|
for await (const event of fallbackClient.chatStream(request)) {
|
||||||
|
if (event.type === 'error') {
|
||||||
|
hasError = true;
|
||||||
|
console.warn(`Global fallback stream #${i + 1} failed: ${event.error?.message}`);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
yield event;
|
yield event;
|
||||||
|
|||||||
Reference in New Issue
Block a user