feat: add runtime provider/model switching via /model <tier> <provider/model>

- ModelRouter: add setClient(), labels map, getLabel(), getAllLabels()
- TUI commands: parse /model <tier> <provider/model> syntax with autocompletion
- TUI minimal: handle provider switching via createClientFromConfig factory
- Daemon: wire initial labels into router config
- Fix /model alias mappings (opus=complex, sonnet=default, haiku=fast)
- Add design doc and update state.json with feature status
This commit is contained in:
William Valentin
2026-02-06 23:42:14 -08:00
parent e92ce69067
commit d4530a7034
8 changed files with 527 additions and 37 deletions
+146
View File
@@ -169,3 +169,149 @@ describe('ModelRouter local client switching', () => {
expect(router.getClient('local')).toBe(mockLocal2);
});
});
describe('setClient and labels', () => {
it('setClient replaces an existing tier client', async () => {
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient,
fast: mockClient1,
fallbackChain: [],
});
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
expect(mockClient1.chat).toHaveBeenCalled();
expect(mockClient1.chat).toHaveBeenCalledTimes(1);
router.setClient('fast', mockClient2, 'fast-replaced');
const newFastClient = router.getClient('fast');
expect(newFastClient).toBeDefined();
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
expect(newFastClient!.chat).toHaveBeenCalled();
expect(newFastClient!.chat).toHaveBeenCalledTimes(1);
expect(mockClient1.chat).toHaveBeenCalledTimes(1);
});
it('setClient adds a new tier client', async () => {
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
const router = new ModelRouter({
default: mockClient1,
fallbackChain: [],
});
expect(router.getClient('complex')).toBeUndefined();
router.setClient('complex', mockClient2, 'complex-tier');
const newClient = router.getClient('complex');
expect(newClient).toBe(mockClient2);
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'complex');
expect(newClient!.chat).toHaveBeenCalled();
});
it('getLabel returns the label set by setClient', () => {
const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient,
fallbackChain: [],
});
expect(router.getLabel('fast')).toBe('unknown');
router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');
expect(router.getLabel('fast')).toBe('fast-tier');
});
it('getLabel returns "unknown" for unset tier', () => {
const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient,
fallbackChain: [],
});
expect(router.getLabel('fast')).toBe('unknown');
expect(router.getLabel('complex')).toBe('unknown');
});
it('getAllLabels returns all tier labels', () => {
const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient,
fallbackChain: [],
});
const labels = router.getAllLabels();
expect(labels).toEqual({});
router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');
router.setClient('complex', { chat: vi.fn() } as unknown as ModelClient, 'complex-tier');
const allLabels = router.getAllLabels();
expect(allLabels).toEqual({
fast: 'fast-tier',
complex: 'complex-tier',
});
});
it('constructor accepts initial labels', async () => {
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
const router = new ModelRouter({
default: mockClient1,
fast: mockClient2,
fallbackChain: [],
labels: {
default: 'default-tier',
fast: 'fast-tier',
},
});
expect(router.getClient('default')).toBe(mockClient1);
expect(router.getClient('fast')).toBe(mockClient2);
expect(router.getLabel('default')).toBe('default-tier');
expect(router.getLabel('fast')).toBe('fast-tier');
expect(router.getLabel('complex')).toBe('unknown');
await router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'fast');
expect(mockClient2.chat).toHaveBeenCalled();
});
it('chat uses the new client after setClient', async () => {
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
const router = new ModelRouter({
default: mockClient1,
fast: { chat: vi.fn() } as unknown as ModelClient,
fallbackChain: [],
labels: {
fast: 'original-fast',
},
});
const initialFastClient = router.getClient('fast');
expect(initialFastClient).toBeDefined();
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
expect(initialFastClient!.chat).toHaveBeenCalled();
expect(initialFastClient!.chat).toHaveBeenCalledTimes(1);
router.setClient('fast', mockClient2, 'fast-replaced');
const newFastClient = router.getClient('fast');
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
expect(newFastClient!.chat).toHaveBeenCalled();
expect(newFastClient!.chat).toHaveBeenCalledTimes(1);
expect(initialFastClient!.chat).toHaveBeenCalledTimes(1);
});
});
+28
View File
@@ -11,10 +11,12 @@ export interface ModelRouterConfig {
local?: ModelClient;
fallbackChain: ModelClient[];
retryConfig?: RetryConfig;
labels?: Partial<Record<ModelTier, string>>;
}
export class ModelRouter implements ModelClient {
private clients: Map<ModelTier, ModelClient>;
private labels: Map<ModelTier, string>;
private defaultClient: ModelClient;
private fallbackChain: ModelClient[];
private currentTier: ModelTier = 'default';
@@ -23,6 +25,7 @@ export class ModelRouter implements ModelClient {
constructor(config: ModelRouterConfig) {
this.clients = new Map();
this.labels = new Map();
this.defaultClient = config.default;
this.fallbackChain = config.fallbackChain;
this.retryConfig = config.retryConfig;
@@ -31,6 +34,14 @@ export class ModelRouter implements ModelClient {
if (config.fast) this.clients.set('fast', config.fast);
if (config.complex) this.clients.set('complex', config.complex);
if (config.local) this.clients.set('local', config.local);
if (config.labels) {
for (const tier of ['fast', 'default', 'complex', 'local'] as ModelTier[]) {
if (config.labels[tier]) {
this.labels.set(tier, config.labels[tier]);
}
}
}
}
setTier(tier: ModelTier): boolean {
@@ -141,4 +152,21 @@ export class ModelRouter implements ModelClient {
getLocalProviderName(): string | undefined {
return this.localProviderName;
}
setClient(tier: ModelTier, client: ModelClient, label: string): void {
this.clients.set(tier, client);
this.labels.set(tier, label);
}
getLabel(tier: ModelTier): string {
return this.labels.get(tier) ?? 'unknown';
}
getAllLabels(): Record<string, string> {
const result: Record<string, string> = {};
for (const tier of this.labels.keys()) {
result[tier] = this.labels.get(tier) ?? 'unknown';
}
return result;
}
}