feat: add per-tier fallback support to ModelRouter
The router now accepts a tierFallbacks map so each model tier can have its own fallback providers. Tier fallbacks are tried before the global fallback chain in both chat() and chatStream().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -81,6 +81,83 @@ describe('ModelRouter', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('ModelRouter per-tier fallbacks', () => {
  /**
   * Builds a stub ModelClient for these tests.
   *
   * @param name - Label embedded in the canned response and in the error message.
   * @param shouldFail - When true, chat() rejects with `<name> failed`
   *   instead of resolving with a response.
   * @returns A client whose chat() is a vi.fn() mock, so calls can be asserted on.
   */
  const createMockClient = (name: string, shouldFail = false): ModelClient => ({
    chat: vi.fn().mockImplementation(async () => {
      if (shouldFail) {
        throw new Error(`${name} failed`);
      }
      return {
        content: `Response from ${name}`,
        stopReason: 'end_turn',
        usage: { inputTokens: 10, outputTokens: 5 },
      } satisfies ChatResponse;
    }),
  });

  it('uses tier-specific fallback when tier client fails', async () => {
    const defaultClient = createMockClient('default');
    const failingFast = createMockClient('fast', true);
    const fastFallback = createMockClient('fast-fallback');

    // Empty global chain: the only way to get a response is the tier fallback.
    const router = new ModelRouter({
      default: defaultClient,
      fast: failingFast,
      fallbackChain: [],
      tierFallbacks: new Map([['fast', [fastFallback]]]),
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast',
    );

    expect(response.content).toBe('Response from fast-fallback');
    expect(response.fallback).toBe(true);
    expect(defaultClient.chat).not.toHaveBeenCalled();
  });

  it('falls through to global chain when tier fallback also fails', async () => {
    const failingFast = createMockClient('fast', true);
    const failingTierFallback = createMockClient('tier-fallback', true);
    const globalFallback = createMockClient('global-fallback');

    const router = new ModelRouter({
      default: createMockClient('default'),
      fast: failingFast,
      fallbackChain: [globalFallback],
      tierFallbacks: new Map([['fast', [failingTierFallback]]]),
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast',
    );

    expect(response.content).toBe('Response from global-fallback');
    expect(response.fallback).toBe(true);
    // Without this, the test would still pass if the router jumped straight
    // to the global chain and never attempted the tier fallback.
    expect(failingTierFallback.chat).toHaveBeenCalled();
  });

  it('skips tier fallbacks when none configured for that tier', async () => {
    const failingComplex = createMockClient('complex', true);
    const globalFallback = createMockClient('global-fallback');
    // Fallback registered for a different tier ('fast') than the one requested.
    const fastTierFallback = createMockClient('fast-fb');

    const router = new ModelRouter({
      default: createMockClient('default'),
      complex: failingComplex,
      fallbackChain: [globalFallback],
      tierFallbacks: new Map([['fast', [fastTierFallback]]]),
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'complex',
    );

    expect(response.content).toBe('Response from global-fallback');
    // Make the "skips" claim explicit: another tier's fallback must stay untouched.
    expect(fastTierFallback.chat).not.toHaveBeenCalled();
  });
});
|
||||
|
||||
describe('ModelRouter streaming', () => {
|
||||
it('streams from primary client', async () => {
|
||||
const mockStream = async function* (): AsyncIterable<ChatStreamEvent> {
|
||||
@@ -147,6 +224,53 @@ describe('ModelRouter streaming', () => {
|
||||
expect(chunks).toEqual(['Fallback']);
|
||||
expect(fallbackWarning).toMatch(/Primary model failed/);
|
||||
});
|
||||
|
||||
it('uses tier fallback for streaming before global chain', async () => {
  // Stream that reports an error event straight away, standing in for a failed primary.
  async function* erroringStream(): AsyncIterable<ChatStreamEvent> {
    yield { type: 'error', error: new Error('Primary failed') };
  }

  // Stream that emits one content chunk followed by a 'done' event.
  async function* contentStream(text: string): AsyncIterable<ChatStreamEvent> {
    yield { type: 'content', content: text };
    yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
  }

  // Stub client whose chatStream() mock returns the supplied stream.
  const stubClient = (stream: AsyncIterable<ChatStreamEvent>) => ({
    chat: vi.fn(),
    chatStream: vi.fn().mockReturnValue(stream),
  });

  const primary = stubClient(erroringStream());
  const tierFallback = stubClient(contentStream('TierFallback'));
  const globalFallback = stubClient(contentStream('Global'));

  const router = new ModelRouter({
    default: primary,
    fallbackChain: [globalFallback],
    tierFallbacks: new Map([['default', [tierFallback]]]),
  });

  // Gather only the non-empty content chunks emitted by the router.
  const received: string[] = [];
  for await (const evt of router.chatStream({ messages: [] })) {
    if (evt.type !== 'content' || !evt.content) {
      continue;
    }
    received.push(evt.content);
  }

  expect(received).toEqual(['TierFallback']);
  expect(globalFallback.chatStream).not.toHaveBeenCalled();
});
|
||||
});
|
||||
|
||||
describe('ModelRouter local client switching', () => {
|
||||
|
||||
Reference in New Issue
Block a user