// Vitest suite for ModelRouter: provider fallback, per-tier fallbacks,
// streaming, local client switching, and tier client/label management.
import { describe, it, expect, vi } from 'vitest';
import { ModelRouter } from './router.js';
import type { ModelClient, ChatResponse, ChatStreamEvent } from './types.js';

/**
 * Core behaviour of `ModelRouter.chat`: primary client selection, traversal of
 * the global fallback chain, retry handling on fallback clients, and explicit
 * tier selection.
 */
describe('ModelRouter', () => {
  /**
   * Builds a `ModelClient` whose `chat` method is a vitest mock.
   *
   * @param name - Embedded in the successful response content and in the
   *   rejection message so assertions can tell clients apart.
   * @param shouldFail - When true, every `chat` call rejects with
   *   `${name} failed`; otherwise it resolves with a fixed response.
   */
  const createMockClient = (name: string, shouldFail = false): ModelClient => ({
    chat: vi.fn().mockImplementation(async () => {
      if (shouldFail) {
        throw new Error(`${name} failed`);
      }
      return {
        content: `Response from ${name}`,
        stopReason: 'end_turn',
        usage: { inputTokens: 10, outputTokens: 5 },
      } satisfies ChatResponse;
    }),
  });

  it('uses default client when available', async () => {
    const defaultClient = createMockClient('default');
    const router = new ModelRouter({
      default: defaultClient,
      fallbackChain: [],
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    expect(response.content).toBe('Response from default');
    expect(defaultClient.chat).toHaveBeenCalled();
  });

  it('falls back to next provider on failure', async () => {
    const failingClient = createMockClient('primary', true);
    const fallbackClient = createMockClient('fallback');

    const router = new ModelRouter({
      default: failingClient,
      fallbackChain: [fallbackClient],
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    // The response must come from the fallback and be flagged as such.
    expect(response.content).toBe('Response from fallback');
    expect(response.fallback).toBe(true);
    expect(response.fallbackReason).toMatch(/Primary model failed/);
    expect(failingClient.chat).toHaveBeenCalled();
    expect(fallbackClient.chat).toHaveBeenCalled();
  });

  it('skips duplicate fallback clients that already failed as primary', async () => {
    const failingPrimary = createMockClient('primary', true);
    const fallbackClient = createMockClient('fallback');

    // The failing primary is deliberately listed again in the fallback chain.
    const router = new ModelRouter({
      default: failingPrimary,
      fallbackChain: [failingPrimary, fallbackClient],
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    expect(response.content).toBe('Response from fallback');
    // Exactly one call each: the duplicate entry must not be tried a second time.
    expect(failingPrimary.chat).toHaveBeenCalledTimes(1);
    expect(fallbackClient.chat).toHaveBeenCalledTimes(1);
  });

  it('applies retry policy to fallback clients', async () => {
    const failingPrimary = createMockClient('primary', true);

    // Fails on its first invocation only, then succeeds.
    let attempts = 0;
    const flakyFallback: ModelClient = {
      chat: vi.fn().mockImplementation(async () => {
        attempts += 1;
        if (attempts === 1) {
          throw new Error('transient');
        }
        return {
          content: 'Recovered fallback',
          stopReason: 'end_turn',
          usage: { inputTokens: 1, outputTokens: 1 },
        } satisfies ChatResponse;
      }),
    };

    const router = new ModelRouter({
      default: failingPrimary,
      fallbackChain: [flakyFallback],
      // 1 ms delays keep the single permitted retry fast.
      retryConfig: {
        maxRetries: 1,
        initialDelayMs: 1,
        backoffMultiplier: 1,
        maxDelayMs: 1,
        nonRetryablePatterns: [],
      },
    });

    const response = await router.chat({ messages: [{ role: 'user', content: 'retry fallback' }] });

    expect(response.content).toBe('Recovered fallback');
    // One failed attempt plus one successful retry.
    expect(flakyFallback.chat).toHaveBeenCalledTimes(2);
  });

  it('throws when all providers fail', async () => {
    const failing1 = createMockClient('primary', true);
    const failing2 = createMockClient('fallback', true);

    const router = new ModelRouter({
      default: failing1,
      fallbackChain: [failing2],
    });

    await expect(router.chat({ messages: [{ role: 'user', content: 'Hi' }] }))
      .rejects.toThrow('All model providers failed');
  });

  it('uses tier-specific client when specified', async () => {
    const defaultClient = createMockClient('default');
    const fastClient = createMockClient('fast');

    const router = new ModelRouter({
      default: defaultClient,
      fast: fastClient,
      fallbackChain: [],
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast',
    );

    expect(response.content).toBe('Response from fast');
    expect(fastClient.chat).toHaveBeenCalled();
    expect(defaultClient.chat).not.toHaveBeenCalled();
  });
});

/**
 * Behaviour of the `tierFallbacks` option: tier-specific fallbacks are tried
 * for the requested tier, and the global `fallbackChain` is used when they
 * fail or when none are configured for that tier.
 */
describe('ModelRouter per-tier fallbacks', () => {
  /**
   * Builds a `ModelClient` whose `chat` method is a vitest mock.
   *
   * @param name - Embedded in the successful response content and in the
   *   rejection message.
   * @param shouldFail - When true, every `chat` call rejects with
   *   `${name} failed`.
   */
  const createMockClient = (name: string, shouldFail = false): ModelClient => ({
    chat: vi.fn().mockImplementation(async () => {
      if (shouldFail) {
        throw new Error(`${name} failed`);
      }
      return {
        content: `Response from ${name}`,
        stopReason: 'end_turn',
        usage: { inputTokens: 10, outputTokens: 5 },
      } satisfies ChatResponse;
    }),
  });

  it('uses tier-specific fallback when tier client fails', async () => {
    const defaultClient = createMockClient('default');
    const failingFast = createMockClient('fast', true);
    const fastFallback = createMockClient('fast-fallback');

    const router = new ModelRouter({
      default: defaultClient,
      fast: failingFast,
      fallbackChain: [],
      tierFallbacks: new Map([['fast', [fastFallback]]]),
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast',
    );

    expect(response.content).toBe('Response from fast-fallback');
    expect(response.fallback).toBe(true);
    // The default-tier client must not be consulted for a 'fast' request.
    expect(defaultClient.chat).not.toHaveBeenCalled();
  });

  it('falls through to global chain when tier fallback also fails', async () => {
    const failingFast = createMockClient('fast', true);
    const failingTierFallback = createMockClient('tier-fallback', true);
    const globalFallback = createMockClient('global-fallback');

    const router = new ModelRouter({
      default: createMockClient('default'),
      fast: failingFast,
      fallbackChain: [globalFallback],
      tierFallbacks: new Map([['fast', [failingTierFallback]]]),
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast',
    );

    expect(response.content).toBe('Response from global-fallback');
    expect(response.fallback).toBe(true);
  });

  it('skips tier fallbacks when none configured for that tier', async () => {
    const failingComplex = createMockClient('complex', true);
    const globalFallback = createMockClient('global-fallback');

    // Tier fallbacks exist only for 'fast'; the request below targets 'complex'.
    const router = new ModelRouter({
      default: createMockClient('default'),
      complex: failingComplex,
      fallbackChain: [globalFallback],
      tierFallbacks: new Map([['fast', [createMockClient('fast-fb')]]]),
    });

    const response = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'complex',
    );

    expect(response.content).toBe('Response from global-fallback');
  });
});

/**
 * Behaviour of `ModelRouter.chatStream`: pass-through of primary stream
 * events, fallback when the primary stream yields an error event, and
 * preference for tier fallbacks over the global chain.
 *
 * Each `chatStream` mock uses `mockImplementation(() => gen())` so that every
 * invocation receives a fresh async generator. A generator object can only be
 * iterated once, so sharing one instance across calls would hand an already
 * exhausted stream to any second invocation.
 */
describe('ModelRouter streaming', () => {
  it('streams from primary client', async () => {
    const mockStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Hello' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };

    const mockClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockImplementation(() => mockStream()),
    };

    const router = new ModelRouter({
      default: mockClient,
      fallbackChain: [],
    });

    // Collect only the text of 'content' events.
    const chunks: string[] = [];
    for await (const event of router.chatStream({ messages: [] })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
    }

    expect(chunks).toEqual(['Hello']);
  });

  it('falls back when primary stream fails', async () => {
    // The primary signals failure through an 'error' event rather than by throwing.
    const failingStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'error', error: new Error('Primary failed') };
    };

    const fallbackStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Fallback' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };

    const primaryClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockImplementation(() => failingStream()),
    };

    const fallbackClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockImplementation(() => fallbackStream()),
    };

    const router = new ModelRouter({
      default: primaryClient,
      fallbackChain: [fallbackClient],
    });

    const chunks: string[] = [];
    let fallbackWarning: string | undefined;
    for await (const event of router.chatStream({ messages: [] })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
      if (event.type === 'fallback_warning') {
        fallbackWarning = event.fallbackReason;
      }
    }

    expect(chunks).toEqual(['Fallback']);
    expect(fallbackWarning).toMatch(/Primary model failed/);
  });

  it('uses tier fallback for streaming before global chain', async () => {
    const failingStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'error', error: new Error('Primary failed') };
    };

    const tierFallbackStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'TierFallback' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };

    const globalStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Global' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };

    const primaryClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockImplementation(() => failingStream()),
    };

    const tierFallbackClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockImplementation(() => tierFallbackStream()),
    };

    const globalClient = {
      chat: vi.fn(),
      chatStream: vi.fn().mockImplementation(() => globalStream()),
    };

    // No tier is passed to chatStream below; the tier fallback is registered
    // under 'default'.
    const router = new ModelRouter({
      default: primaryClient,
      fallbackChain: [globalClient],
      tierFallbacks: new Map([['default', [tierFallbackClient]]]),
    });

    const chunks: string[] = [];
    for await (const event of router.chatStream({ messages: [] })) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
    }

    expect(chunks).toEqual(['TierFallback']);
    // The tier fallback succeeded, so the global chain must stay untouched.
    expect(globalClient.chatStream).not.toHaveBeenCalled();
  });
});

/**
 * Replacing the 'local' tier client at runtime via `setLocalClient` and
 * reading back the provider name it was registered with.
 */
describe('ModelRouter local client switching', () => {
  it('allows setting a new local client', () => {
    const mockDefault = { chat: vi.fn() } as unknown as ModelClient;
    const mockLocal1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockLocal2 = { chat: vi.fn() } as unknown as ModelClient;

    const router = new ModelRouter({
      default: mockDefault,
      local: mockLocal1,
      fallbackChain: [],
    });

    // A local client passed to the constructor carries no provider name.
    expect(router.getLocalProviderName()).toBeUndefined();

    router.setLocalClient(mockLocal2, 'llamacpp');

    expect(router.getLocalProviderName()).toBe('llamacpp');
    expect(router.getClient('local')).toBe(mockLocal2);
  });
});

/**
 * Runtime tier management on `ModelRouter`: `setClient`, `getClient`,
 * `getLabel` / `getAllLabels`, constructor-supplied labels, strict tier mode,
 * `requestAbort`, and tier-change listeners.
 */
describe('setClient and labels', () => {
  it('setClient replaces an existing tier client', async () => {
    const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;

    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fast: mockClient1,
      fallbackChain: [],
    });

    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');

    expect(mockClient1.chat).toHaveBeenCalledTimes(1);

    router.setClient('fast', mockClient2, 'fast-replaced');

    const newFastClient = router.getClient('fast');
    // Identity check: a bare "defined" check would also pass if the old
    // client were still registered.
    expect(newFastClient).toBe(mockClient2);
    if (!newFastClient) {
      throw new Error('Expected fast client to be set');
    }
    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');

    expect(newFastClient.chat).toHaveBeenCalledTimes(1);
    // The replaced client must not receive the second request.
    expect(mockClient1.chat).toHaveBeenCalledTimes(1);
  });

  it('setClient adds a new tier client', async () => {
    const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;

    const router = new ModelRouter({
      default: mockClient1,
      fallbackChain: [],
    });

    expect(router.getClient('complex')).toBeUndefined();

    router.setClient('complex', mockClient2, 'complex-tier');

    const newClient = router.getClient('complex');
    expect(newClient).toBe(mockClient2);
    if (!newClient) {
      throw new Error('Expected complex client to be set');
    }

    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'complex');

    expect(newClient.chat).toHaveBeenCalled();
  });

  it('getLabel returns the label set by setClient', () => {
    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
    });

    expect(router.getLabel('fast')).toBe('unknown');

    router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');

    expect(router.getLabel('fast')).toBe('fast-tier');
  });

  it('getLabel returns "unknown" for unset tier', () => {
    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
    });

    expect(router.getLabel('fast')).toBe('unknown');
    expect(router.getLabel('complex')).toBe('unknown');
  });

  it('getAllLabels returns all tier labels', () => {
    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
    });

    // No labels were supplied at construction time.
    const labels = router.getAllLabels();
    expect(labels).toEqual({});

    router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');
    router.setClient('complex', { chat: vi.fn() } as unknown as ModelClient, 'complex-tier');

    const allLabels = router.getAllLabels();
    expect(allLabels).toEqual({
      fast: 'fast-tier',
      complex: 'complex-tier',
    });
  });

  it('constructor accepts initial labels', async () => {
    const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;

    const router = new ModelRouter({
      default: mockClient1,
      fast: mockClient2,
      fallbackChain: [],
      labels: {
        default: 'default-tier',
        fast: 'fast-tier',
      },
    });

    expect(router.getClient('default')).toBe(mockClient1);
    expect(router.getClient('fast')).toBe(mockClient2);
    expect(router.getLabel('default')).toBe('default-tier');
    expect(router.getLabel('fast')).toBe('fast-tier');
    expect(router.getLabel('complex')).toBe('unknown');

    await router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'fast');

    expect(mockClient2.chat).toHaveBeenCalled();
  });

  it('chat uses the new client after setClient', async () => {
    const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
    const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;

    const router = new ModelRouter({
      default: mockClient1,
      fast: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
      labels: {
        fast: 'original-fast',
      },
    });

    // The initial fast client was created inline, so fetch it from the router.
    const initialFastClient = router.getClient('fast');
    expect(initialFastClient).toBeDefined();
    if (!initialFastClient) {
      throw new Error('Expected initial fast client to exist');
    }
    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');

    expect(initialFastClient.chat).toHaveBeenCalledTimes(1);

    router.setClient('fast', mockClient2, 'fast-replaced');

    const newFastClient = router.getClient('fast');
    if (!newFastClient) {
      throw new Error('Expected replaced fast client to exist');
    }
    await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');

    expect(newFastClient.chat).toHaveBeenCalledTimes(1);
    // The original client must not receive the second request.
    expect(initialFastClient.chat).toHaveBeenCalledTimes(1);
  });

  it('strict tier mode disables fallback chain for that tier', async () => {
    const failingDefault = {
      chat: vi.fn().mockRejectedValue(new Error('primary failed')),
    } as unknown as ModelClient;
    const fallback = {
      chat: vi.fn().mockResolvedValue({
        content: 'fallback',
        stopReason: 'end_turn',
        usage: { inputTokens: 1, outputTokens: 1 },
      }),
    } as unknown as ModelClient;

    const router = new ModelRouter({
      default: failingDefault,
      fallbackChain: [fallback],
    });

    router.setTierStrict('default', true);

    // The primary's own error is expected to surface, with no fallback attempt.
    await expect(router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'default'))
      .rejects.toThrow('primary failed');
    expect(fallback.chat).not.toHaveBeenCalled();
    expect(router.isTierStrict('default')).toBe(true);
  });

  it('requestAbort interrupts retry loop before fallback chain', async () => {
    const primary = {
      chat: vi.fn().mockRejectedValue(new Error('temporary failure')),
    } as unknown as ModelClient;
    const fallback = {
      chat: vi.fn().mockResolvedValue({
        content: 'fallback',
        stopReason: 'end_turn',
        usage: { inputTokens: 1, outputTokens: 1 },
      }),
    } as unknown as ModelClient;

    const router = new ModelRouter({
      default: primary,
      fallbackChain: [fallback],
      retryConfig: {
        maxRetries: 3,
        initialDelayMs: 80,
        backoffMultiplier: 1,
        maxDelayMs: 80,
        nonRetryablePatterns: [],
      },
    });

    const run = router.chat({ messages: [{ role: 'user', content: 'hi' }] });
    // The abort is requested after 10 ms, well inside the configured 80 ms
    // retry delay.
    setTimeout(() => router.requestAbort(), 10);

    await expect(run).rejects.toMatchObject({ name: 'AbortError' });
    // Only the first attempt ran: no retry and no fallback.
    expect(primary.chat).toHaveBeenCalledTimes(1);
    expect(fallback.chat).not.toHaveBeenCalled();
  });

  it('setOnTierChange does not replace existing listeners', () => {
    const router = new ModelRouter({
      default: { chat: vi.fn() } as unknown as ModelClient,
      fast: { chat: vi.fn() } as unknown as ModelClient,
      fallbackChain: [],
    });

    // One listener is registered through each API; both must be notified.
    const first = vi.fn();
    const second = vi.fn();
    router.addOnTierChange(first);
    router.setOnTierChange(second);

    router.setTier('fast');

    expect(first).toHaveBeenCalledWith('fast');
    expect(second).toHaveBeenCalledWith('fast');
  });
});