// Unit tests for ModelRouter: default routing, provider fallback, tier selection, and streaming.
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
|
import { ModelRouter } from './router.js';
|
|
import type { ModelClient, ChatResponse, ChatStreamEvent } from './types.js';
|
|
|
|
describe('ModelRouter', () => {
  /**
   * Builds a stub client whose `chat` mock either resolves with a canned
   * reply that names the client, or rejects with "<label> failed".
   */
  function createMockClient(label: string, shouldFail = false): ModelClient {
    const chat = vi.fn().mockImplementation(async () => {
      if (!shouldFail) {
        return {
          content: `Response from ${label}`,
          stopReason: 'end_turn',
          usage: { inputTokens: 10, outputTokens: 5 },
        } satisfies ChatResponse;
      }
      throw new Error(`${label} failed`);
    });
    return { chat };
  }

  it('uses default client when available', async () => {
    const primary = createMockClient('default');
    const router = new ModelRouter({ default: primary, fallbackChain: [] });

    const { content } = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    expect(content).toBe('Response from default');
    expect(primary.chat).toHaveBeenCalled();
  });

  it('falls back to next provider on failure', async () => {
    // The default provider always rejects, so the reply has to come from the chain.
    const broken = createMockClient('primary', true);
    const backup = createMockClient('fallback');
    const router = new ModelRouter({ default: broken, fallbackChain: [backup] });

    const { content } = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    expect(content).toBe('Response from fallback');
    expect(broken.chat).toHaveBeenCalled();
    expect(backup.chat).toHaveBeenCalled();
  });

  it('throws when all providers fail', async () => {
    const brokenDefault = createMockClient('primary', true);
    const brokenBackup = createMockClient('fallback', true);
    const router = new ModelRouter({
      default: brokenDefault,
      fallbackChain: [brokenBackup],
    });

    const attempt = router.chat({ messages: [{ role: 'user', content: 'Hi' }] });

    await expect(attempt).rejects.toThrow('All model providers failed');
  });

  it('uses tier-specific client when specified', async () => {
    const standard = createMockClient('default');
    const speedy = createMockClient('fast');
    const router = new ModelRouter({
      default: standard,
      fast: speedy,
      fallbackChain: [],
    });

    // Requesting the 'fast' tier must bypass the default client entirely.
    const { content } = await router.chat(
      { messages: [{ role: 'user', content: 'Hi' }] },
      'fast'
    );

    expect(content).toBe('Response from fast');
    expect(speedy.chat).toHaveBeenCalled();
    expect(standard.chat).not.toHaveBeenCalled();
  });
});
|
|
|
|
describe('ModelRouter streaming', () => {
  /**
   * Builds a stub client whose `chatStream` mock produces a *fresh* stream on
   * every call.
   *
   * An async generator can only be consumed once, so handing the mock a single
   * pre-built generator (`mockReturnValue(gen())`) would give any second call
   * an already-exhausted stream. Deferring creation to call time avoids that.
   */
  const createStreamingClient = (makeStream: () => AsyncIterable<ChatStreamEvent>) => ({
    chat: vi.fn(),
    chatStream: vi.fn().mockImplementation(() => makeStream()),
  });

  /** Drains a stream and returns the text of its non-empty `content` events, in order. */
  const collectContent = async (
    stream: AsyncIterable<ChatStreamEvent>
  ): Promise<string[]> => {
    const chunks: string[] = [];
    for await (const event of stream) {
      if (event.type === 'content' && event.content) {
        chunks.push(event.content);
      }
    }
    return chunks;
  };

  it('streams from primary client', async () => {
    const mockStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Hello' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };

    const mockClient = createStreamingClient(mockStream);
    const router = new ModelRouter({
      default: mockClient,
      fallbackChain: [],
    });

    const chunks = await collectContent(router.chatStream({ messages: [] }));

    expect(chunks).toEqual(['Hello']);
    expect(mockClient.chatStream).toHaveBeenCalled();
  });

  it('falls back when primary stream fails', async () => {
    // The primary signals failure with an in-band 'error' event rather than by throwing.
    const failingStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'error', error: new Error('Primary failed') };
    };

    const fallbackStream = async function* (): AsyncIterable<ChatStreamEvent> {
      yield { type: 'content', content: 'Fallback' };
      yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
    };

    const primaryClient = createStreamingClient(failingStream);
    const fallbackClient = createStreamingClient(fallbackStream);

    const router = new ModelRouter({
      default: primaryClient,
      fallbackChain: [fallbackClient],
    });

    const chunks = await collectContent(router.chatStream({ messages: [] }));

    expect(chunks).toEqual(['Fallback']);
    // Mirror the non-streaming fallback test: both providers must have been tried.
    expect(primaryClient.chatStream).toHaveBeenCalled();
    expect(fallbackClient.chatStream).toHaveBeenCalled();
  });
});
|