From d4530a703447af08842ae423dea220066e7ccb24 Mon Sep 17 00:00:00 2001 From: William Valentin Date: Fri, 6 Feb 2026 23:42:14 -0800 Subject: [PATCH] feat: add runtime provider/model switching via /model - ModelRouter: add setClient(), labels map, getLabel(), getAllLabels() - TUI commands: parse /model syntax with autocompletion - TUI minimal: handle provider switching via createClientFromConfig factory - Daemon: wire initial labels into router config - Fix /model alias mappings (opus=complex, sonnet=default, haiku=fast) - Add design doc and update state.json with feature status --- ...6-02-06-provider-model-switching-design.md | 79 ++++++++++ docs/plans/state.json | 137 +++++++++++++++- src/daemon/index.ts | 6 + src/frontends/tui/commands.test.ts | 26 ++++ src/frontends/tui/commands.ts | 97 ++++++++---- src/frontends/tui/minimal.ts | 45 +++++- src/models/router.test.ts | 146 ++++++++++++++++++ src/models/router.ts | 28 ++++ 8 files changed, 527 insertions(+), 37 deletions(-) create mode 100644 docs/plans/2026-02-06-provider-model-switching-design.md diff --git a/docs/plans/2026-02-06-provider-model-switching-design.md b/docs/plans/2026-02-06-provider-model-switching-design.md new file mode 100644 index 0000000..7398086 --- /dev/null +++ b/docs/plans/2026-02-06-provider-model-switching-design.md @@ -0,0 +1,79 @@ +# Provider/Model Runtime Switching + +**Date:** 2026-02-06 +**Status:** In Progress + +## Goal + +Enable easy runtime switching of model providers per tier via the `/model` command, using `provider/model` syntax. + +## Commands + +``` +/model — Show all tiers with provider/model labels +/model fast — Switch active tier (existing behavior) +/model default github-copilot/claude-sonnet-4-5 — Change default tier's provider+model +/model complex anthropic/claude-opus-4 — Change complex tier's provider+model +/model fast github-copilot/gpt-4o-mini — Change fast tier's provider+model +``` + +## Design Decisions + +1. 
**No presets** — Direct `provider/model` targeting per tier. YAGNI. +2. **Full override** — When you set a tier, it fully replaces the previous client. +3. **Local tier excluded** — `/model local` continues to use `/backend` for switching. Local models are a different concern. +4. **Auth** — Config-based `api_key` for most providers. `/login` OAuth flow for GitHub Copilot (already implemented). +5. **Merged into /model** — No separate `/provider` command. Everything lives under `/model`. + +## Implementation + +### 1. ModelRouter (`src/models/router.ts`) + +- Add `setClient(tier: ModelTier, client: ModelClient, label: string)` — replaces a tier's client at runtime +- Add `getLabel(tier: ModelTier): string` — returns `provider/model` string for display +- Track labels in a `Map<ModelTier, string>` populated at construction and updated by `setClient()` + +### 2. Client Factory (`src/daemon/index.ts`) + +Extract `createClientFromProvider(provider: string, model: string, opts?: { apiKey?: string; endpoint?: string }): ModelClient` factory function from the existing inline client creation logic. Used by both daemon startup and runtime `/model` switching. + +### 3. Command Parser (`src/frontends/tui/commands.ts`) + +- Extend `Command` type: `{ type: 'model'; name?: string; providerModel?: string }` +- Parse `/model <tier> <provider/model>` — split on space to get tier + provider/model +- Parse provider/model string: split on first `/` to get provider and model name +- Update autocompletion to suggest available providers after tier name +- Update tooltips + +### 4. Daemon Wiring (`src/daemon/index.ts`) + +Handle the new command variant: +1. Receive `{ type: 'model', name: 'default', providerModel: 'github-copilot/claude-sonnet-4-5' }` +2. Parse provider and model from `providerModel` +3. Call `createClientFromProvider(provider, model)` to instantiate client +4. Call `router.setClient('default', client, 'github-copilot/claude-sonnet-4-5')` +5. Respond with confirmation message + +### 5. 
Provider Name Mapping + +Map short provider names to client constructors: + +| Provider Name | Client Class | +|---------------|-------------| +| `anthropic` | AnthropicClient | +| `openai` | OpenAIClient | +| `github` / `github-copilot` | GitHubModelsClient | +| `gemini` | GeminiClient | +| `bedrock` | BedrockClient | +| `ollama` | OllamaClient | +| `llamacpp` | LlamaCppClient | + +## Files Changed + +| File | Change | +|------|--------| +| `src/models/router.ts` | Add `setClient()`, `getLabel()`, label tracking | +| `src/models/router.test.ts` | Tests for new methods | +| `src/frontends/tui/commands.ts` | Extended parser, completions, tooltips | +| `src/frontends/tui/commands.test.ts` | Tests for new parsing | +| `src/daemon/index.ts` | Extract factory, wire new command handler | diff --git a/docs/plans/state.json b/docs/plans/state.json index c6ce465..a570971 100644 --- a/docs/plans/state.json +++ b/docs/plans/state.json @@ -336,6 +336,136 @@ } } }, + "p4-media-pipeline": { + "status": "completed", + "date": "2026-02-06", + "summary": "Multimodal media pipeline: receive images from channel adapters and pass through to vision-capable models (Anthropic, OpenAI, Gemini, Bedrock)", + "phases": { + "type_widening": { + "priority": "P4", + "status": "completed", + "description": "Widen Message.content from string to string | MessageContentPart[], add Attachment type to channel layer, add ImageSource/MessageContentPart types", + "files_created": [ + "src/models/media.ts", + "src/models/media.test.ts" + ], + "files_modified": [ + "src/models/types.ts", + "src/models/index.ts", + "src/channels/types.ts", + "src/channels/index.ts" + ], + "test_status": "25/25 passing" + }, + "model_client_multimodal": { + "priority": "P4", + "status": "completed", + "description": "Update all model clients to convert MessageContentPart[] to provider-specific image formats (Anthropic base64, OpenAI data URI, Gemini inlineData, Bedrock image bytes)", + "files_modified": [ + 
"src/models/anthropic.ts", + "src/models/openai.ts", + "src/models/gemini.ts", + "src/models/bedrock.ts", + "src/models/local/llamacpp.ts", + "src/models/local/ollama.ts" + ] + }, + "agent_attachment_passthrough": { + "priority": "P4", + "status": "completed", + "description": "Wire attachments through NativeAgent.process() and AgentOrchestrator.process() to daemon message handler", + "files_modified": [ + "src/backends/native/agent.ts", + "src/backends/native/orchestrator.ts", + "src/daemon/index.ts" + ] + }, + "downstream_type_fixes": { + "priority": "P4", + "status": "completed", + "description": "Fix all consumers of Message.content to use getMessageText() helper: token estimation, compaction, TUI rendering", + "files_modified": [ + "src/context/tokens.ts", + "src/context/compaction.ts", + "src/frontends/tui/components/MessageList.tsx" + ] + }, + "channel_adapter_extraction": { + "priority": "P4", + "status": "completed", + "description": "Extract images from platform messages in all channel adapters", + "sub_phases": { + "telegram": { + "status": "completed", + "description": "Handle message:photo (largest size, download via getFile API, base64) and image message:document events with caption text", + "files_modified": ["src/channels/telegram/adapter.ts"] + }, + "discord": { + "status": "completed", + "description": "Extract image attachments from message.attachments Collection, pass Discord CDN URLs directly", + "files_modified": ["src/channels/discord/adapter.ts"] + }, + "slack": { + "status": "completed", + "description": "Download image files via url_private_download with bot token auth, base64 encode", + "files_modified": ["src/channels/slack/adapter.ts"] + }, + "whatsapp": { + "status": "completed", + "description": "Use downloadMedia() from whatsapp-web.js (returns base64 natively)", + "files_modified": ["src/channels/whatsapp/adapter.ts"] + }, + "webchat": { + "status": "deferred", + "description": "Requires gateway protocol update for WebSocket 
attachment messages" + } + } + } + } + }, + "p5-github-copilot-provider": { + "status": "completed", + "date": "2026-02-06", + "summary": "GitHub Copilot as a model provider with OAuth device flow and auto-login on first use", + "phases": { + "copilot_client": { + "priority": "P5", + "status": "completed", + "description": "GitHubModelsClient using OpenAI SDK against api.githubcopilot.com with Copilot-specific headers, multimodal support, streaming, tool calls", + "files_created": [ + "src/models/github.ts", + "src/auth/github.ts", + "src/auth/index.ts" + ], + "files_modified": [ + "src/config/schema.ts", + "src/models/index.ts", + "src/models/costs.ts", + "src/daemon/index.ts", + "src/cli/tui.ts" + ] + }, + "oauth_device_flow": { + "priority": "P5", + "status": "completed", + "description": "Interactive OAuth device flow via /login github command, token stored at ~/.config/flynn/auth.json with chmod 0600", + "files_modified": [ + "src/frontends/tui/minimal.ts", + "src/frontends/tui/commands.ts" + ] + }, + "auto_login": { + "priority": "P5", + "status": "completed", + "description": "Lazy token resolution with onLoginRequired callback — triggers OAuth device flow automatically on first API call when no token is available", + "files_modified": [ + "src/models/github.ts", + "src/daemon/index.ts", + "src/cli/tui.ts" + ] + } + } + }, "earlier_plans": { "status": "completed", "summary": "Original design and implementation phases from 2026-02-02 to 2026-02-05", @@ -361,11 +491,14 @@ }, "overall_progress": { - "total_test_count": 655, + "total_test_count": 742, "all_tests_passing": true, "p0_completion": "3/3 (100%)", "p1_completion": "4/4 (100%)", "p2_completion": "7/7 (100%)", - "next_up": "p3_remaining (group chat support, gateway auth, gemini provider, browser control, additional model providers)" + "p3_completion": "completed (group chat, gateway auth, Gemini, OpenRouter, Bedrock, browser control)", + "p4_completion": "1/1 (100%) — multimodal media pipeline", + 
"p5_completion": "1/1 (100%) — GitHub Copilot provider with auto-login", + "next_up": "p6 (image.analyze tool, audio transcription, outbound attachments, gateway protocol attachments)" } } diff --git a/src/daemon/index.ts b/src/daemon/index.ts index b594100..4ce088b 100644 --- a/src/daemon/index.ts +++ b/src/daemon/index.ts @@ -179,6 +179,12 @@ function createModelRouter(config: Config): ModelRouter { local: localClient, fallbackChain, retryConfig, + labels: { + default: `${models.default.provider}/${models.default.model}`, + ...(models.fast ? { fast: `${models.fast.provider}/${models.fast.model}` } : {}), + ...(models.complex ? { complex: `${models.complex.provider}/${models.complex.model}` } : {}), + ...(models.local ? { local: `${models.local.provider}/${models.local.model}` } : {}), + }, }); } diff --git a/src/frontends/tui/commands.test.ts b/src/frontends/tui/commands.test.ts index 63b04ea..8ae9adf 100644 --- a/src/frontends/tui/commands.test.ts +++ b/src/frontends/tui/commands.test.ts @@ -43,6 +43,32 @@ describe('parseCommand', () => { expect(parseCommand('/model opus')).toEqual({ type: 'model', name: 'opus' }); }); + it('parses /model with provider/model', () => { + expect(parseCommand('/model default anthropic/claude-sonnet-4')).toEqual({ + type: 'model', + name: 'default', + providerModel: 'anthropic/claude-sonnet-4', + }); + expect(parseCommand('/model fast github-copilot/gpt-4o-mini')).toEqual({ + type: 'model', + name: 'fast', + providerModel: 'github-copilot/gpt-4o-mini', + }); + expect(parseCommand('/model complex openai/o3')).toEqual({ + type: 'model', + name: 'complex', + providerModel: 'openai/o3', + }); + }); + + it('still parses /model fast as tier switch (no providerModel)', () => { + expect(parseCommand('/model fast')).toEqual({ type: 'model', name: 'fast' }); + }); + + it('still parses /model as info (no args)', () => { + expect(parseCommand('/model')).toEqual({ type: 'model' }); + }); + it('parses /backend command without argument', () => { 
expect(parseCommand('/backend')).toEqual({ type: 'backend' }); }); diff --git a/src/frontends/tui/commands.ts b/src/frontends/tui/commands.ts index c368381..6c5edaa 100644 --- a/src/frontends/tui/commands.ts +++ b/src/frontends/tui/commands.ts @@ -6,7 +6,7 @@ export type Command = | { type: 'fullscreen' } | { type: 'compact' } | { type: 'usage' } - | { type: 'model'; name?: string } + | { type: 'model'; name?: string; providerModel?: string } | { type: 'backend'; provider?: string } | { type: 'login'; provider?: string } | { type: 'transfer'; target: string } @@ -56,7 +56,16 @@ export function parseCommand(input: string): Command | null { return { type: 'model' }; } if (trimmed.startsWith('/model ')) { - const name = trimmed.slice('/model '.length).trim(); + const args = trimmed.slice('/model '.length).trim(); + const parts = args.split(/\s+/); + + // /model - change tier's provider/model + if (parts.length === 2 && parts[1].includes('/')) { + return { type: 'model', name: parts[0], providerModel: parts[1] }; + } + + // /model - single word (backward compatibility) + const name = parts[0]; return { type: 'model', name }; } @@ -92,7 +101,8 @@ export function getHelpText(): string { return ` Commands: /help, /? Show this help - /model [name] Show or switch model (local, default, fast, complex) + /model [name] Show or switch model tier (local, default, fast, complex) + /model

Change tier's provider/model (e.g. /model default anthropic/claude-sonnet-4) /backend [provider] Show or switch local backend (ollama, llamacpp) /login [provider] Authenticate with GitHub /reset, /clear, /new Clear conversation history @@ -105,7 +115,7 @@ Commands: `.trim(); } -export type ModelAlias = 'local' | 'default' | 'fast' | 'complex' | 'opus' | 'sonnet' | 'ollama'; +export type ModelAlias = 'local' | 'default' | 'fast' | 'complex' | 'opus' | 'sonnet' | 'haiku' | 'ollama'; // List of all slash commands for autocompletion export const SLASH_COMMANDS = [ @@ -146,28 +156,44 @@ export const COMMAND_TOOLTIPS: Record = { }; // Model aliases for /model command autocompletion -export const MODEL_ALIASES = ['local', 'default', 'fast', 'complex', 'opus', 'sonnet', 'ollama']; +export const MODEL_ALIASES = ['local', 'default', 'fast', 'complex', 'opus', 'sonnet', 'haiku', 'ollama']; + +// Provider names for /model syntax +export const PROVIDER_NAMES = ['anthropic', 'openai', 'github-copilot', 'gemini', 'bedrock', 'ollama', 'llamacpp']; // Model alias descriptions export const MODEL_TOOLTIPS: Record = { - local: 'Local Ollama model', - default: 'Default model (Opus)', - fast: 'Fast model (Sonnet)', - complex: 'Complex reasoning model', - opus: 'Claude Opus', - sonnet: 'Claude Sonnet', - ollama: 'Local Ollama model', + local: 'Local model (Ollama/llama.cpp)', + default: 'Default model tier', + fast: 'Fast/lightweight model tier', + complex: 'Complex reasoning model tier', + opus: 'Alias for complex tier', + sonnet: 'Alias for default tier', + haiku: 'Alias for fast tier', + ollama: 'Alias for local tier', }; export function getCommandCompletions(partial: string): string[] { const trimmed = partial.trim(); - // Complete /model arguments + // Complete /model if (trimmed.startsWith('/model ')) { - const modelPartial = trimmed.slice('/model '.length).toLowerCase(); - return MODEL_ALIASES - .filter(alias => alias.startsWith(modelPartial)) - .map(alias => `/model ${alias}`); + 
const args = trimmed.slice('/model '.length).trim(); + const parts = args.split(/\s+/); + + if (parts.length === 1) { + // Single word - suggest model aliases + const modelPartial = parts[0].toLowerCase(); + return MODEL_ALIASES + .filter(alias => alias.startsWith(modelPartial)) + .map(alias => `/model ${alias}`); + } else if (parts.length === 2) { + // Two words - suggest provider prefixes + const providerPartial = parts[1].toLowerCase(); + return PROVIDER_NAMES + .filter(provider => provider.startsWith(providerPartial)) + .map(provider => `/model ${parts[0]} ${provider}`); + } } // Complete slash commands @@ -183,16 +209,30 @@ export function getCommandTooltip(partial: string): string | null { // Tooltip for /model arguments if (trimmed.startsWith('/model ')) { - const modelArg = trimmed.slice('/model '.length).trim(); - if (modelArg && MODEL_TOOLTIPS[modelArg]) { - return MODEL_TOOLTIPS[modelArg]; + const args = trimmed.slice('/model '.length).trim(); + const parts = args.split(/\s+/); + + if (parts.length === 1) { + // Single word - model tier or provider + const modelArg = parts[0].toLowerCase(); + if (modelArg && MODEL_TOOLTIPS[modelArg]) { + return MODEL_TOOLTIPS[modelArg]; + } + // Show tooltip for partial match + const matches = MODEL_ALIASES.filter(a => a.startsWith(modelArg)); + if (matches.length === 1 && MODEL_TOOLTIPS[matches[0]]) { + return MODEL_TOOLTIPS[matches[0]]; + } + return 'Choose: local, default, fast, complex'; + } else if (parts.length === 2) { + // Two words - tier + provider + const providerPartial = parts[1].toLowerCase(); + const matches = PROVIDER_NAMES.filter(p => p.startsWith(providerPartial)); + if (matches.length === 1) { + return `Enter provider/model (e.g. ${matches[0]}/...)`; + } + return `Enter provider/model (e.g. 
anthropic/claude-sonnet-4)`; } - // Show tooltip for partial match - const matches = MODEL_ALIASES.filter(a => a.startsWith(modelArg)); - if (matches.length === 1 && MODEL_TOOLTIPS[matches[0]]) { - return MODEL_TOOLTIPS[matches[0]]; - } - return 'Choose: local, default, fast, complex'; } // Exact match tooltip @@ -216,10 +256,11 @@ export function resolveModelAlias(alias: string): 'local' | 'default' | 'fast' | local: 'local', ollama: 'local', default: 'default', - opus: 'default', + sonnet: 'default', fast: 'fast', - sonnet: 'fast', + haiku: 'fast', complex: 'complex', + opus: 'complex', }; return map[alias.toLowerCase()] ?? 'default'; } diff --git a/src/frontends/tui/minimal.ts b/src/frontends/tui/minimal.ts index 6078eae..67f4cea 100644 --- a/src/frontends/tui/minimal.ts +++ b/src/frontends/tui/minimal.ts @@ -7,6 +7,7 @@ import { parseCommand, getHelpText, resolveModelAlias, getCommandCompletions, ge import { renderMarkdown } from './markdown.js'; import type { ModelConfig } from '../../config/schema.js'; import { OllamaClient, LlamaCppClient } from '../../models/index.js'; +import { createClientFromConfig } from '../../daemon/index.js'; import { loginGitHub } from '../../auth/index.js'; export { parseCommand, type Command }; @@ -180,7 +181,7 @@ export class MinimalTui { break; case 'model': - this.handleModelCommand(command.name); + this.handleModelCommand(command.name, command.providerModel); break; case 'backend': @@ -201,21 +202,51 @@ export class MinimalTui { } } - private handleModelCommand(name?: string): void { + private handleModelCommand(name?: string, providerModel?: string): void { const router = this.config.modelRouter; if (!router) { console.log(`${colors.gray}Model switching not available.${colors.reset}\n`); return; } - if (!name) { - const current = router.getTier(); - const available = router.getAvailableTiers(); - console.log(`${colors.gray}Current model:${colors.reset} ${current}`); - console.log(`${colors.gray}Available:${colors.reset} 
${available.join(', ')}\n`); + // /model — change a tier's provider and model + if (name && providerModel) { + const tier = resolveModelAlias(name); + const slashIdx = providerModel.indexOf('/'); + if (slashIdx === -1) { + console.log(`${colors.gray}Invalid format. Use provider/model (e.g. anthropic/claude-sonnet-4)${colors.reset}\n`); + return; + } + const provider = providerModel.slice(0, slashIdx); + const model = providerModel.slice(slashIdx + 1); + + try { + const client = createClientFromConfig({ provider: provider as 'anthropic', model }); + router.setClient(tier, client, providerModel); + console.log(`${colors.gray}Set ${tier} to:${colors.reset} ${providerModel}\n`); + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + console.log(`${colors.gray}Failed to create client:${colors.reset} ${message}\n`); + } return; } + // /model — show all tiers with labels + if (!name) { + const current = router.getTier(); + const available = router.getAvailableTiers(); + const labels = router.getAllLabels(); + console.log(`${colors.gray}Active tier:${colors.reset} ${current}`); + for (const tier of available) { + const label = labels[tier] ?? 'unknown'; + const marker = tier === current ? 
' ←' : ''; + console.log(` ${tier}: ${label}${marker}`); + } + console.log(); + return; + } + + // /model — switch active tier const tier = resolveModelAlias(name); if (router.setTier(tier)) { // Also update the agent tier so chatWithRouter uses the correct client diff --git a/src/models/router.test.ts b/src/models/router.test.ts index d6697c1..013e10e 100644 --- a/src/models/router.test.ts +++ b/src/models/router.test.ts @@ -169,3 +169,149 @@ describe('ModelRouter local client switching', () => { expect(router.getClient('local')).toBe(mockLocal2); }); }); + +describe('setClient and labels', () => { + it('setClient replaces an existing tier client', async () => { + const mockClient1 = { chat: vi.fn() } as unknown as ModelClient; + const mockClient2 = { chat: vi.fn() } as unknown as ModelClient; + + const router = new ModelRouter({ + default: { chat: vi.fn() } as unknown as ModelClient, + fast: mockClient1, + fallbackChain: [], + }); + + await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast'); + + expect(mockClient1.chat).toHaveBeenCalled(); + expect(mockClient1.chat).toHaveBeenCalledTimes(1); + + router.setClient('fast', mockClient2, 'fast-replaced'); + + const newFastClient = router.getClient('fast'); + expect(newFastClient).toBeDefined(); + await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast'); + + expect(newFastClient!.chat).toHaveBeenCalled(); + expect(newFastClient!.chat).toHaveBeenCalledTimes(1); + expect(mockClient1.chat).toHaveBeenCalledTimes(1); + }); + + it('setClient adds a new tier client', async () => { + const mockClient1 = { chat: vi.fn() } as unknown as ModelClient; + const mockClient2 = { chat: vi.fn() } as unknown as ModelClient; + + const router = new ModelRouter({ + default: mockClient1, + fallbackChain: [], + }); + + expect(router.getClient('complex')).toBeUndefined(); + + router.setClient('complex', mockClient2, 'complex-tier'); + + const newClient = router.getClient('complex'); + 
expect(newClient).toBe(mockClient2); + + await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'complex'); + + expect(newClient!.chat).toHaveBeenCalled(); + }); + + it('getLabel returns the label set by setClient', () => { + const router = new ModelRouter({ + default: { chat: vi.fn() } as unknown as ModelClient, + fallbackChain: [], + }); + + expect(router.getLabel('fast')).toBe('unknown'); + + router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier'); + + expect(router.getLabel('fast')).toBe('fast-tier'); + }); + + it('getLabel returns "unknown" for unset tier', () => { + const router = new ModelRouter({ + default: { chat: vi.fn() } as unknown as ModelClient, + fallbackChain: [], + }); + + expect(router.getLabel('fast')).toBe('unknown'); + expect(router.getLabel('complex')).toBe('unknown'); + }); + + it('getAllLabels returns all tier labels', () => { + const router = new ModelRouter({ + default: { chat: vi.fn() } as unknown as ModelClient, + fallbackChain: [], + }); + + const labels = router.getAllLabels(); + expect(labels).toEqual({}); + + router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier'); + router.setClient('complex', { chat: vi.fn() } as unknown as ModelClient, 'complex-tier'); + + const allLabels = router.getAllLabels(); + expect(allLabels).toEqual({ + fast: 'fast-tier', + complex: 'complex-tier', + }); + }); + + it('constructor accepts initial labels', async () => { + const mockClient1 = { chat: vi.fn() } as unknown as ModelClient; + const mockClient2 = { chat: vi.fn() } as unknown as ModelClient; + + const router = new ModelRouter({ + default: mockClient1, + fast: mockClient2, + fallbackChain: [], + labels: { + default: 'default-tier', + fast: 'fast-tier', + }, + }); + + expect(router.getClient('default')).toBe(mockClient1); + expect(router.getClient('fast')).toBe(mockClient2); + expect(router.getLabel('default')).toBe('default-tier'); + 
expect(router.getLabel('fast')).toBe('fast-tier'); + expect(router.getLabel('complex')).toBe('unknown'); + + await router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'fast'); + + expect(mockClient2.chat).toHaveBeenCalled(); + }); + + it('chat uses the new client after setClient', async () => { + const mockClient1 = { chat: vi.fn() } as unknown as ModelClient; + const mockClient2 = { chat: vi.fn() } as unknown as ModelClient; + + const router = new ModelRouter({ + default: mockClient1, + fast: { chat: vi.fn() } as unknown as ModelClient, + fallbackChain: [], + labels: { + fast: 'original-fast', + }, + }); + + const initialFastClient = router.getClient('fast'); + expect(initialFastClient).toBeDefined(); + await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast'); + + expect(initialFastClient!.chat).toHaveBeenCalled(); + expect(initialFastClient!.chat).toHaveBeenCalledTimes(1); + + router.setClient('fast', mockClient2, 'fast-replaced'); + + const newFastClient = router.getClient('fast'); + await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast'); + + expect(newFastClient!.chat).toHaveBeenCalled(); + expect(newFastClient!.chat).toHaveBeenCalledTimes(1); + expect(initialFastClient!.chat).toHaveBeenCalledTimes(1); + }); +}); diff --git a/src/models/router.ts b/src/models/router.ts index 38d4a46..d33cbd9 100644 --- a/src/models/router.ts +++ b/src/models/router.ts @@ -11,10 +11,12 @@ export interface ModelRouterConfig { local?: ModelClient; fallbackChain: ModelClient[]; retryConfig?: RetryConfig; + labels?: Partial>; } export class ModelRouter implements ModelClient { private clients: Map; + private labels: Map; private defaultClient: ModelClient; private fallbackChain: ModelClient[]; private currentTier: ModelTier = 'default'; @@ -23,6 +25,7 @@ export class ModelRouter implements ModelClient { constructor(config: ModelRouterConfig) { this.clients = new Map(); + this.labels = new Map(); this.defaultClient = config.default; 
this.fallbackChain = config.fallbackChain; this.retryConfig = config.retryConfig; @@ -31,6 +34,14 @@ export class ModelRouter implements ModelClient { if (config.fast) this.clients.set('fast', config.fast); if (config.complex) this.clients.set('complex', config.complex); if (config.local) this.clients.set('local', config.local); + + if (config.labels) { + for (const tier of ['fast', 'default', 'complex', 'local'] as ModelTier[]) { + if (config.labels[tier]) { + this.labels.set(tier, config.labels[tier]); + } + } + } } setTier(tier: ModelTier): boolean { @@ -141,4 +152,21 @@ export class ModelRouter implements ModelClient { getLocalProviderName(): string | undefined { return this.localProviderName; } + + setClient(tier: ModelTier, client: ModelClient, label: string): void { + this.clients.set(tier, client); + this.labels.set(tier, label); + } + + getLabel(tier: ModelTier): string { + return this.labels.get(tier) ?? 'unknown'; + } + + getAllLabels(): Record { + const result: Record = {}; + for (const tier of this.labels.keys()) { + result[tier] = this.labels.get(tier) ?? 'unknown'; + } + return result; + } }