feat: add runtime provider/model switching via /model <tier> <provider/model>

- ModelRouter: add setClient(), labels map, getLabel(), getAllLabels()
- TUI commands: parse /model <tier> <provider/model> syntax with autocompletion
- TUI minimal: handle provider switching via createClientFromConfig factory
- Daemon: wire initial labels into router config
- Fix /model alias mappings (opus=complex, sonnet=default, haiku=fast)
- Add design doc and update state.json with feature status
This commit is contained in:
William Valentin
2026-02-06 23:42:14 -08:00
parent e92ce69067
commit d4530a7034
8 changed files with 527 additions and 37 deletions
+6
View File
@@ -179,6 +179,12 @@ function createModelRouter(config: Config): ModelRouter {
local: localClient,
fallbackChain,
retryConfig,
labels: {
default: `${models.default.provider}/${models.default.model}`,
...(models.fast ? { fast: `${models.fast.provider}/${models.fast.model}` } : {}),
...(models.complex ? { complex: `${models.complex.provider}/${models.complex.model}` } : {}),
...(models.local ? { local: `${models.local.provider}/${models.local.model}` } : {}),
},
});
}
+26
View File
@@ -43,6 +43,32 @@ describe('parseCommand', () => {
expect(parseCommand('/model opus')).toEqual({ type: 'model', name: 'opus' });
});
it('parses /model with provider/model', () => {
expect(parseCommand('/model default anthropic/claude-sonnet-4')).toEqual({
type: 'model',
name: 'default',
providerModel: 'anthropic/claude-sonnet-4',
});
expect(parseCommand('/model fast github-copilot/gpt-4o-mini')).toEqual({
type: 'model',
name: 'fast',
providerModel: 'github-copilot/gpt-4o-mini',
});
expect(parseCommand('/model complex openai/o3')).toEqual({
type: 'model',
name: 'complex',
providerModel: 'openai/o3',
});
});
it('still parses /model fast as tier switch (no providerModel)', () => {
expect(parseCommand('/model fast')).toEqual({ type: 'model', name: 'fast' });
});
it('still parses /model as info (no args)', () => {
expect(parseCommand('/model')).toEqual({ type: 'model' });
});
it('parses /backend command without argument', () => {
expect(parseCommand('/backend')).toEqual({ type: 'backend' });
});
+69 -28
View File
@@ -6,7 +6,7 @@ export type Command =
| { type: 'fullscreen' }
| { type: 'compact' }
| { type: 'usage' }
| { type: 'model'; name?: string }
| { type: 'model'; name?: string; providerModel?: string }
| { type: 'backend'; provider?: string }
| { type: 'login'; provider?: string }
| { type: 'transfer'; target: string }
@@ -56,7 +56,16 @@ export function parseCommand(input: string): Command | null {
return { type: 'model' };
}
if (trimmed.startsWith('/model ')) {
const name = trimmed.slice('/model '.length).trim();
const args = trimmed.slice('/model '.length).trim();
const parts = args.split(/\s+/);
// /model <tier> <provider/model> - change tier's provider/model
if (parts.length === 2 && parts[1].includes('/')) {
return { type: 'model', name: parts[0], providerModel: parts[1] };
}
// /model <name> - single word (backward compatibility)
const name = parts[0];
return { type: 'model', name };
}
@@ -92,7 +101,8 @@ export function getHelpText(): string {
return `
Commands:
/help, /? Show this help
/model [name] Show or switch model (local, default, fast, complex)
/model [name] Show or switch model tier (local, default, fast, complex)
/model <tier> <p/m> Change tier's provider/model (e.g. /model default anthropic/claude-sonnet-4)
/backend [provider] Show or switch local backend (ollama, llamacpp)
/login [provider] Authenticate with GitHub
/reset, /clear, /new Clear conversation history
@@ -105,7 +115,7 @@ Commands:
`.trim();
}
export type ModelAlias = 'local' | 'default' | 'fast' | 'complex' | 'opus' | 'sonnet' | 'ollama';
export type ModelAlias = 'local' | 'default' | 'fast' | 'complex' | 'opus' | 'sonnet' | 'haiku' | 'ollama';
// List of all slash commands for autocompletion
export const SLASH_COMMANDS = [
@@ -146,28 +156,44 @@ export const COMMAND_TOOLTIPS: Record<string, string> = {
};
// Model aliases for /model command autocompletion
export const MODEL_ALIASES = ['local', 'default', 'fast', 'complex', 'opus', 'sonnet', 'ollama'];
export const MODEL_ALIASES = ['local', 'default', 'fast', 'complex', 'opus', 'sonnet', 'haiku', 'ollama'];
// Provider names for /model <tier> <provider/model> syntax
export const PROVIDER_NAMES = ['anthropic', 'openai', 'github-copilot', 'gemini', 'bedrock', 'ollama', 'llamacpp'];
// Model alias descriptions
export const MODEL_TOOLTIPS: Record<string, string> = {
local: 'Local Ollama model',
default: 'Default model (Opus)',
fast: 'Fast model (Sonnet)',
complex: 'Complex reasoning model',
opus: 'Claude Opus',
sonnet: 'Claude Sonnet',
ollama: 'Local Ollama model',
local: 'Local model (Ollama/llama.cpp)',
default: 'Default model tier',
fast: 'Fast/lightweight model tier',
complex: 'Complex reasoning model tier',
opus: 'Alias for complex tier',
sonnet: 'Alias for default tier',
haiku: 'Alias for fast tier',
ollama: 'Alias for local tier',
};
export function getCommandCompletions(partial: string): string[] {
const trimmed = partial.trim();
// Complete /model arguments
// Complete /model <tier> <provider/model>
if (trimmed.startsWith('/model ')) {
const modelPartial = trimmed.slice('/model '.length).toLowerCase();
return MODEL_ALIASES
.filter(alias => alias.startsWith(modelPartial))
.map(alias => `/model ${alias}`);
const args = trimmed.slice('/model '.length).trim();
const parts = args.split(/\s+/);
if (parts.length === 1) {
// Single word - suggest model aliases
const modelPartial = parts[0].toLowerCase();
return MODEL_ALIASES
.filter(alias => alias.startsWith(modelPartial))
.map(alias => `/model ${alias}`);
} else if (parts.length === 2) {
// Two words - suggest provider prefixes
const providerPartial = parts[1].toLowerCase();
return PROVIDER_NAMES
.filter(provider => provider.startsWith(providerPartial))
.map(provider => `/model ${parts[0]} ${provider}`);
}
}
// Complete slash commands
@@ -183,16 +209,30 @@ export function getCommandTooltip(partial: string): string | null {
// Tooltip for /model arguments
if (trimmed.startsWith('/model ')) {
const modelArg = trimmed.slice('/model '.length).trim();
if (modelArg && MODEL_TOOLTIPS[modelArg]) {
return MODEL_TOOLTIPS[modelArg];
const args = trimmed.slice('/model '.length).trim();
const parts = args.split(/\s+/);
if (parts.length === 1) {
// Single word - model tier or provider
const modelArg = parts[0].toLowerCase();
if (modelArg && MODEL_TOOLTIPS[modelArg]) {
return MODEL_TOOLTIPS[modelArg];
}
// Show tooltip for partial match
const matches = MODEL_ALIASES.filter(a => a.startsWith(modelArg));
if (matches.length === 1 && MODEL_TOOLTIPS[matches[0]]) {
return MODEL_TOOLTIPS[matches[0]];
}
return 'Choose: local, default, fast, complex';
} else if (parts.length === 2) {
// Two words - tier + provider
const providerPartial = parts[1].toLowerCase();
const matches = PROVIDER_NAMES.filter(p => p.startsWith(providerPartial));
if (matches.length === 1) {
return `Enter provider/model (e.g. ${matches[0]}/...)`;
}
return `Enter provider/model (e.g. anthropic/claude-sonnet-4)`;
}
// Show tooltip for partial match
const matches = MODEL_ALIASES.filter(a => a.startsWith(modelArg));
if (matches.length === 1 && MODEL_TOOLTIPS[matches[0]]) {
return MODEL_TOOLTIPS[matches[0]];
}
return 'Choose: local, default, fast, complex';
}
// Exact match tooltip
@@ -216,10 +256,11 @@ export function resolveModelAlias(alias: string): 'local' | 'default' | 'fast' |
local: 'local',
ollama: 'local',
default: 'default',
opus: 'default',
sonnet: 'default',
fast: 'fast',
sonnet: 'fast',
haiku: 'fast',
complex: 'complex',
opus: 'complex',
};
return map[alias.toLowerCase()] ?? 'default';
}
+38 -7
View File
@@ -7,6 +7,7 @@ import { parseCommand, getHelpText, resolveModelAlias, getCommandCompletions, ge
import { renderMarkdown } from './markdown.js';
import type { ModelConfig } from '../../config/schema.js';
import { OllamaClient, LlamaCppClient } from '../../models/index.js';
import { createClientFromConfig } from '../../daemon/index.js';
import { loginGitHub } from '../../auth/index.js';
export { parseCommand, type Command };
@@ -180,7 +181,7 @@ export class MinimalTui {
break;
case 'model':
this.handleModelCommand(command.name);
this.handleModelCommand(command.name, command.providerModel);
break;
case 'backend':
@@ -201,21 +202,51 @@ export class MinimalTui {
}
}
private handleModelCommand(name?: string): void {
private handleModelCommand(name?: string, providerModel?: string): void {
const router = this.config.modelRouter;
if (!router) {
console.log(`${colors.gray}Model switching not available.${colors.reset}\n`);
return;
}
if (!name) {
const current = router.getTier();
const available = router.getAvailableTiers();
console.log(`${colors.gray}Current model:${colors.reset} ${current}`);
console.log(`${colors.gray}Available:${colors.reset} ${available.join(', ')}\n`);
// /model <tier> <provider/model> — change a tier's provider and model
if (name && providerModel) {
const tier = resolveModelAlias(name);
const slashIdx = providerModel.indexOf('/');
if (slashIdx === -1) {
console.log(`${colors.gray}Invalid format. Use provider/model (e.g. anthropic/claude-sonnet-4)${colors.reset}\n`);
return;
}
const provider = providerModel.slice(0, slashIdx);
const model = providerModel.slice(slashIdx + 1);
try {
const client = createClientFromConfig({ provider: provider as 'anthropic', model });
router.setClient(tier, client, providerModel);
console.log(`${colors.gray}Set ${tier} to:${colors.reset} ${providerModel}\n`);
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
console.log(`${colors.gray}Failed to create client:${colors.reset} ${message}\n`);
}
return;
}
// /model — show all tiers with labels
if (!name) {
const current = router.getTier();
const available = router.getAvailableTiers();
const labels = router.getAllLabels();
console.log(`${colors.gray}Active tier:${colors.reset} ${current}`);
for (const tier of available) {
const label = labels[tier] ?? 'unknown';
const marker = tier === current ? ' ←' : '';
console.log(` ${tier}: ${label}${marker}`);
}
console.log();
return;
}
// /model <tier> — switch active tier
const tier = resolveModelAlias(name);
if (router.setTier(tier)) {
// Also update the agent tier so chatWithRouter uses the correct client
+146
View File
@@ -169,3 +169,149 @@ describe('ModelRouter local client switching', () => {
expect(router.getClient('local')).toBe(mockLocal2);
});
});
describe('setClient and labels', () => {
it('setClient replaces an existing tier client', async () => {
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient,
fast: mockClient1,
fallbackChain: [],
});
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
expect(mockClient1.chat).toHaveBeenCalled();
expect(mockClient1.chat).toHaveBeenCalledTimes(1);
router.setClient('fast', mockClient2, 'fast-replaced');
const newFastClient = router.getClient('fast');
expect(newFastClient).toBeDefined();
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
expect(newFastClient!.chat).toHaveBeenCalled();
expect(newFastClient!.chat).toHaveBeenCalledTimes(1);
expect(mockClient1.chat).toHaveBeenCalledTimes(1);
});
it('setClient adds a new tier client', async () => {
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
const router = new ModelRouter({
default: mockClient1,
fallbackChain: [],
});
expect(router.getClient('complex')).toBeUndefined();
router.setClient('complex', mockClient2, 'complex-tier');
const newClient = router.getClient('complex');
expect(newClient).toBe(mockClient2);
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'complex');
expect(newClient!.chat).toHaveBeenCalled();
});
it('getLabel returns the label set by setClient', () => {
const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient,
fallbackChain: [],
});
expect(router.getLabel('fast')).toBe('unknown');
router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');
expect(router.getLabel('fast')).toBe('fast-tier');
});
it('getLabel returns "unknown" for unset tier', () => {
const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient,
fallbackChain: [],
});
expect(router.getLabel('fast')).toBe('unknown');
expect(router.getLabel('complex')).toBe('unknown');
});
it('getAllLabels returns all tier labels', () => {
const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient,
fallbackChain: [],
});
const labels = router.getAllLabels();
expect(labels).toEqual({});
router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');
router.setClient('complex', { chat: vi.fn() } as unknown as ModelClient, 'complex-tier');
const allLabels = router.getAllLabels();
expect(allLabels).toEqual({
fast: 'fast-tier',
complex: 'complex-tier',
});
});
it('constructor accepts initial labels', async () => {
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
const router = new ModelRouter({
default: mockClient1,
fast: mockClient2,
fallbackChain: [],
labels: {
default: 'default-tier',
fast: 'fast-tier',
},
});
expect(router.getClient('default')).toBe(mockClient1);
expect(router.getClient('fast')).toBe(mockClient2);
expect(router.getLabel('default')).toBe('default-tier');
expect(router.getLabel('fast')).toBe('fast-tier');
expect(router.getLabel('complex')).toBe('unknown');
await router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'fast');
expect(mockClient2.chat).toHaveBeenCalled();
});
it('chat uses the new client after setClient', async () => {
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
const router = new ModelRouter({
default: mockClient1,
fast: { chat: vi.fn() } as unknown as ModelClient,
fallbackChain: [],
labels: {
fast: 'original-fast',
},
});
const initialFastClient = router.getClient('fast');
expect(initialFastClient).toBeDefined();
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
expect(initialFastClient!.chat).toHaveBeenCalled();
expect(initialFastClient!.chat).toHaveBeenCalledTimes(1);
router.setClient('fast', mockClient2, 'fast-replaced');
const newFastClient = router.getClient('fast');
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
expect(newFastClient!.chat).toHaveBeenCalled();
expect(newFastClient!.chat).toHaveBeenCalledTimes(1);
expect(initialFastClient!.chat).toHaveBeenCalledTimes(1);
});
});
+28
View File
@@ -11,10 +11,12 @@ export interface ModelRouterConfig {
local?: ModelClient;
fallbackChain: ModelClient[];
retryConfig?: RetryConfig;
labels?: Partial<Record<ModelTier, string>>;
}
export class ModelRouter implements ModelClient {
private clients: Map<ModelTier, ModelClient>;
private labels: Map<ModelTier, string>;
private defaultClient: ModelClient;
private fallbackChain: ModelClient[];
private currentTier: ModelTier = 'default';
@@ -23,6 +25,7 @@ export class ModelRouter implements ModelClient {
constructor(config: ModelRouterConfig) {
this.clients = new Map();
this.labels = new Map();
this.defaultClient = config.default;
this.fallbackChain = config.fallbackChain;
this.retryConfig = config.retryConfig;
@@ -31,6 +34,14 @@ export class ModelRouter implements ModelClient {
if (config.fast) this.clients.set('fast', config.fast);
if (config.complex) this.clients.set('complex', config.complex);
if (config.local) this.clients.set('local', config.local);
if (config.labels) {
for (const tier of ['fast', 'default', 'complex', 'local'] as ModelTier[]) {
if (config.labels[tier]) {
this.labels.set(tier, config.labels[tier]);
}
}
}
}
setTier(tier: ModelTier): boolean {
@@ -141,4 +152,21 @@ export class ModelRouter implements ModelClient {
getLocalProviderName(): string | undefined {
return this.localProviderName;
}
setClient(tier: ModelTier, client: ModelClient, label: string): void {
this.clients.set(tier, client);
this.labels.set(tier, label);
}
getLabel(tier: ModelTier): string {
return this.labels.get(tier) ?? 'unknown';
}
getAllLabels(): Record<string, string> {
const result: Record<string, string> = {};
for (const tier of this.labels.keys()) {
result[tier] = this.labels.get(tier) ?? 'unknown';
}
return result;
}
}