From 9a48c39b072c86c75cec61306222ac83be679002 Mon Sep 17 00:00:00 2001 From: William Valentin Date: Thu, 5 Feb 2026 10:48:41 -0800 Subject: [PATCH] feat(models): add streaming and tier switching to ModelRouter --- src/models/router.test.ts | 65 ++++++++++++++++++++++++++++++++++++++- src/models/router.ts | 60 ++++++++++++++++++++++++++++++++++-- 2 files changed, 122 insertions(+), 3 deletions(-) diff --git a/src/models/router.test.ts b/src/models/router.test.ts index 00fce36..6c08179 100644 --- a/src/models/router.test.ts +++ b/src/models/router.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { ModelRouter } from './router.js'; -import type { ModelClient, ChatResponse } from './types.js'; +import type { ModelClient, ChatResponse, ChatStreamEvent } from './types.js'; describe('ModelRouter', () => { const createMockClient = (name: string, shouldFail = false): ModelClient => ({ @@ -78,3 +78,66 @@ describe('ModelRouter', () => { expect(defaultClient.chat).not.toHaveBeenCalled(); }); }); + +describe('ModelRouter streaming', () => { + it('streams from primary client', async () => { + const mockStream = async function* (): AsyncIterable { + yield { type: 'content', content: 'Hello' }; + yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } }; + }; + + const mockClient = { + chat: vi.fn(), + chatStream: vi.fn().mockReturnValue(mockStream()), + }; + + const router = new ModelRouter({ + default: mockClient, + fallbackChain: [], + }); + + const chunks: string[] = []; + for await (const event of router.chatStream({ messages: [] })) { + if (event.type === 'content' && event.content) { + chunks.push(event.content); + } + } + + expect(chunks).toEqual(['Hello']); + }); + + it('falls back when primary stream fails', async () => { + const failingStream = async function* (): AsyncIterable { + yield { type: 'error', error: new Error('Primary failed') }; + }; + + const fallbackStream = async function* (): AsyncIterable { + yield { type: 'content', content: 'Fallback' }; + yield { type: 'done', usage: { inputTokens: 5, outputTokens: 3 } }; + }; + + const primaryClient = { + chat: vi.fn(), + chatStream: vi.fn().mockReturnValue(failingStream()), + }; + + const fallbackClient = { + chat: vi.fn(), + chatStream: vi.fn().mockReturnValue(fallbackStream()), + }; + + const router = new ModelRouter({ + default: primaryClient, + fallbackChain: [fallbackClient], + }); + + const chunks: string[] = []; + for await (const event of router.chatStream({ messages: [] })) { + if (event.type === 'content' && event.content) { + chunks.push(event.content); + } + } + + expect(chunks).toEqual(['Fallback']); + }); +}); diff --git a/src/models/router.ts b/src/models/router.ts index 58c13f8..3b0704e 100644 --- a/src/models/router.ts +++ b/src/models/router.ts @@ -1,4 +1,4 @@ -import type { ChatRequest, ChatResponse, ModelClient } from './types.js'; +import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient } from './types.js'; export type ModelTier = 'fast' | 'default' | 'complex' | 'local'; @@ -14,6 +14,7 @@ export class ModelRouter implements ModelClient { private clients: Map; private defaultClient: ModelClient; private fallbackChain: ModelClient[]; + private currentTier: ModelTier = 'default'; constructor(config: ModelRouterConfig) { this.clients = new Map(); @@ -26,8 +27,25 @@ export class ModelRouter implements ModelClient { if (config.local) this.clients.set('local', config.local); } + setTier(tier: ModelTier): boolean { + if (this.clients.has(tier)) { + this.currentTier = tier; + return true; + } + return false; + } + + getTier(): ModelTier { + return this.currentTier; + } + + getAvailableTiers(): ModelTier[] { + return Array.from(this.clients.keys()); + } + async chat(request: ChatRequest, tier?: ModelTier): Promise { - const primaryClient = tier ? this.clients.get(tier) ?? this.defaultClient : this.defaultClient; + const useTier = tier ?? this.currentTier; + const primaryClient = this.clients.get(useTier) ?? this.defaultClient; const errors: Error[] = []; // Try primary client @@ -52,6 +70,44 @@ export class ModelRouter implements ModelClient { throw new Error(`All model providers failed: ${errors.map(e => e.message).join(', ')}`); } + async *chatStream(request: ChatRequest, tier?: ModelTier): AsyncIterable { + const useTier = tier ?? this.currentTier; + const primaryClient = this.clients.get(useTier) ?? this.defaultClient; + + if (primaryClient.chatStream) { + let hasError = false; + for await (const event of primaryClient.chatStream(request)) { + if (event.type === 'error') { + hasError = true; + console.warn(`Primary stream failed: ${event.error?.message}`); + break; + } + yield event; + } + + if (!hasError) return; + } + + // Try fallback chain + for (const fallbackClient of this.fallbackChain) { + if (!fallbackClient.chatStream) continue; + + let hasError = false; + for await (const event of fallbackClient.chatStream(request)) { + if (event.type === 'error') { + hasError = true; + console.warn(`Fallback stream failed: ${event.error?.message}`); + break; + } + yield event; + } + + if (!hasError) return; + } + + yield { type: 'error', error: new Error('All streaming providers failed') }; + } + getClient(tier: ModelTier): ModelClient | undefined { return this.clients.get(tier); }