feat(models): add streaming and tier switching to ModelRouter
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||||
import { ModelRouter } from './router.js';
|
import { ModelRouter } from './router.js';
|
||||||
import type { ModelClient, ChatResponse } from './types.js';
|
import type { ModelClient, ChatResponse, ChatStreamEvent } from './types.js';
|
||||||
|
|
||||||
describe('ModelRouter', () => {
|
describe('ModelRouter', () => {
|
||||||
const createMockClient = (name: string, shouldFail = false): ModelClient => ({
|
const createMockClient = (name: string, shouldFail = false): ModelClient => ({
|
||||||
@@ -78,3 +78,66 @@ describe('ModelRouter', () => {
|
|||||||
expect(defaultClient.chat).not.toHaveBeenCalled();
|
expect(defaultClient.chat).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('ModelRouter streaming', () => {
|
||||||
|
it('streams from primary client', async () => {
|
||||||
|
const mockStream = async function* (): AsyncIterable<ChatStreamEvent> {
|
||||||
|
yield { type: 'content', content: 'Hello' };
|
||||||
|
yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockClient = {
|
||||||
|
chat: vi.fn(),
|
||||||
|
chatStream: vi.fn().mockReturnValue(mockStream()),
|
||||||
|
};
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: mockClient,
|
||||||
|
fallbackChain: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
const chunks: string[] = [];
|
||||||
|
for await (const event of router.chatStream({ messages: [] })) {
|
||||||
|
if (event.type === 'content' && event.content) {
|
||||||
|
chunks.push(event.content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(chunks).toEqual(['Hello']);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('falls back when primary stream fails', async () => {
|
||||||
|
const failingStream = async function* (): AsyncIterable<ChatStreamEvent> {
|
||||||
|
yield { type: 'error', error: new Error('Primary failed') };
|
||||||
|
};
|
||||||
|
|
||||||
|
const fallbackStream = async function* (): AsyncIterable<ChatStreamEvent> {
|
||||||
|
yield { type: 'content', content: 'Fallback' };
|
||||||
|
yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } };
|
||||||
|
};
|
||||||
|
|
||||||
|
const primaryClient = {
|
||||||
|
chat: vi.fn(),
|
||||||
|
chatStream: vi.fn().mockReturnValue(failingStream()),
|
||||||
|
};
|
||||||
|
|
||||||
|
const fallbackClient = {
|
||||||
|
chat: vi.fn(),
|
||||||
|
chatStream: vi.fn().mockReturnValue(fallbackStream()),
|
||||||
|
};
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: primaryClient,
|
||||||
|
fallbackChain: [fallbackClient],
|
||||||
|
});
|
||||||
|
|
||||||
|
const chunks: string[] = [];
|
||||||
|
for await (const event of router.chatStream({ messages: [] })) {
|
||||||
|
if (event.type === 'content' && event.content) {
|
||||||
|
chunks.push(event.content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(chunks).toEqual(['Fallback']);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
+58
-2
@@ -1,4 +1,4 @@
|
|||||||
import type { ChatRequest, ChatResponse, ModelClient } from './types.js';
|
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from './types.js';
|
||||||
|
|
||||||
export type ModelTier = 'fast' | 'default' | 'complex' | 'local';
|
export type ModelTier = 'fast' | 'default' | 'complex' | 'local';
|
||||||
|
|
||||||
@@ -14,6 +14,7 @@ export class ModelRouter implements ModelClient {
|
|||||||
private clients: Map<ModelTier, ModelClient>;
|
private clients: Map<ModelTier, ModelClient>;
|
||||||
private defaultClient: ModelClient;
|
private defaultClient: ModelClient;
|
||||||
private fallbackChain: ModelClient[];
|
private fallbackChain: ModelClient[];
|
||||||
|
private currentTier: ModelTier = 'default';
|
||||||
|
|
||||||
constructor(config: ModelRouterConfig) {
|
constructor(config: ModelRouterConfig) {
|
||||||
this.clients = new Map();
|
this.clients = new Map();
|
||||||
@@ -26,8 +27,25 @@ export class ModelRouter implements ModelClient {
|
|||||||
if (config.local) this.clients.set('local', config.local);
|
if (config.local) this.clients.set('local', config.local);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setTier(tier: ModelTier): boolean {
|
||||||
|
if (this.clients.has(tier)) {
|
||||||
|
this.currentTier = tier;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
getTier(): ModelTier {
|
||||||
|
return this.currentTier;
|
||||||
|
}
|
||||||
|
|
||||||
|
getAvailableTiers(): ModelTier[] {
|
||||||
|
return Array.from(this.clients.keys());
|
||||||
|
}
|
||||||
|
|
||||||
async chat(request: ChatRequest, tier?: ModelTier): Promise<ChatResponse> {
|
async chat(request: ChatRequest, tier?: ModelTier): Promise<ChatResponse> {
|
||||||
const primaryClient = tier ? this.clients.get(tier) ?? this.defaultClient : this.defaultClient;
|
const useTier = tier ?? this.currentTier;
|
||||||
|
const primaryClient = this.clients.get(useTier) ?? this.defaultClient;
|
||||||
const errors: Error[] = [];
|
const errors: Error[] = [];
|
||||||
|
|
||||||
// Try primary client
|
// Try primary client
|
||||||
@@ -52,6 +70,44 @@ export class ModelRouter implements ModelClient {
|
|||||||
throw new Error(`All model providers failed: ${errors.map(e => e.message).join(', ')}`);
|
throw new Error(`All model providers failed: ${errors.map(e => e.message).join(', ')}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async *chatStream(request: ChatRequest, tier?: ModelTier): AsyncIterable<ChatStreamEvent> {
|
||||||
|
const useTier = tier ?? this.currentTier;
|
||||||
|
const primaryClient = this.clients.get(useTier) ?? this.defaultClient;
|
||||||
|
|
||||||
|
if (primaryClient.chatStream) {
|
||||||
|
let hasError = false;
|
||||||
|
for await (const event of primaryClient.chatStream(request)) {
|
||||||
|
if (event.type === 'error') {
|
||||||
|
hasError = true;
|
||||||
|
console.warn(`Primary stream failed: ${event.error?.message}`);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
yield event;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasError) return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try fallback chain
|
||||||
|
for (const fallbackClient of this.fallbackChain) {
|
||||||
|
if (!fallbackClient.chatStream) continue;
|
||||||
|
|
||||||
|
let hasError = false;
|
||||||
|
for await (const event of fallbackClient.chatStream(request)) {
|
||||||
|
if (event.type === 'error') {
|
||||||
|
hasError = true;
|
||||||
|
console.warn(`Fallback stream failed: ${event.error?.message}`);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
yield event;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!hasError) return;
|
||||||
|
}
|
||||||
|
|
||||||
|
yield { type: 'error', error: new Error('All streaming providers failed') };
|
||||||
|
}
|
||||||
|
|
||||||
getClient(tier: ModelTier): ModelClient | undefined {
|
getClient(tier: ModelTier): ModelClient | undefined {
|
||||||
return this.clients.get(tier);
|
return this.clients.get(tier);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user