feat: add runtime provider/model switching via /model <tier> <provider/model>
- ModelRouter: add setClient(), labels map, getLabel(), getAllLabels() - TUI commands: parse /model <tier> <provider/model> syntax with autocompletion - TUI minimal: handle provider switching via createClientFromConfig factory - Daemon: wire initial labels into router config - Fix /model alias mappings (opus=complex, sonnet=default, haiku=fast) - Add design doc and update state.json with feature status
This commit is contained in:
@@ -169,3 +169,149 @@ describe('ModelRouter local client switching', () => {
|
||||
expect(router.getClient('local')).toBe(mockLocal2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setClient and labels', () => {
|
||||
it('setClient replaces an existing tier client', async () => {
|
||||
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
|
||||
const router = new ModelRouter({
|
||||
default: { chat: vi.fn() } as unknown as ModelClient,
|
||||
fast: mockClient1,
|
||||
fallbackChain: [],
|
||||
});
|
||||
|
||||
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
|
||||
|
||||
expect(mockClient1.chat).toHaveBeenCalled();
|
||||
expect(mockClient1.chat).toHaveBeenCalledTimes(1);
|
||||
|
||||
router.setClient('fast', mockClient2, 'fast-replaced');
|
||||
|
||||
const newFastClient = router.getClient('fast');
|
||||
expect(newFastClient).toBeDefined();
|
||||
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
|
||||
|
||||
expect(newFastClient!.chat).toHaveBeenCalled();
|
||||
expect(newFastClient!.chat).toHaveBeenCalledTimes(1);
|
||||
expect(mockClient1.chat).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('setClient adds a new tier client', async () => {
|
||||
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
|
||||
const router = new ModelRouter({
|
||||
default: mockClient1,
|
||||
fallbackChain: [],
|
||||
});
|
||||
|
||||
expect(router.getClient('complex')).toBeUndefined();
|
||||
|
||||
router.setClient('complex', mockClient2, 'complex-tier');
|
||||
|
||||
const newClient = router.getClient('complex');
|
||||
expect(newClient).toBe(mockClient2);
|
||||
|
||||
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'complex');
|
||||
|
||||
expect(newClient!.chat).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('getLabel returns the label set by setClient', () => {
|
||||
const router = new ModelRouter({
|
||||
default: { chat: vi.fn() } as unknown as ModelClient,
|
||||
fallbackChain: [],
|
||||
});
|
||||
|
||||
expect(router.getLabel('fast')).toBe('unknown');
|
||||
|
||||
router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');
|
||||
|
||||
expect(router.getLabel('fast')).toBe('fast-tier');
|
||||
});
|
||||
|
||||
it('getLabel returns "unknown" for unset tier', () => {
|
||||
const router = new ModelRouter({
|
||||
default: { chat: vi.fn() } as unknown as ModelClient,
|
||||
fallbackChain: [],
|
||||
});
|
||||
|
||||
expect(router.getLabel('fast')).toBe('unknown');
|
||||
expect(router.getLabel('complex')).toBe('unknown');
|
||||
});
|
||||
|
||||
it('getAllLabels returns all tier labels', () => {
|
||||
const router = new ModelRouter({
|
||||
default: { chat: vi.fn() } as unknown as ModelClient,
|
||||
fallbackChain: [],
|
||||
});
|
||||
|
||||
const labels = router.getAllLabels();
|
||||
expect(labels).toEqual({});
|
||||
|
||||
router.setClient('fast', { chat: vi.fn() } as unknown as ModelClient, 'fast-tier');
|
||||
router.setClient('complex', { chat: vi.fn() } as unknown as ModelClient, 'complex-tier');
|
||||
|
||||
const allLabels = router.getAllLabels();
|
||||
expect(allLabels).toEqual({
|
||||
fast: 'fast-tier',
|
||||
complex: 'complex-tier',
|
||||
});
|
||||
});
|
||||
|
||||
it('constructor accepts initial labels', async () => {
|
||||
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
|
||||
const router = new ModelRouter({
|
||||
default: mockClient1,
|
||||
fast: mockClient2,
|
||||
fallbackChain: [],
|
||||
labels: {
|
||||
default: 'default-tier',
|
||||
fast: 'fast-tier',
|
||||
},
|
||||
});
|
||||
|
||||
expect(router.getClient('default')).toBe(mockClient1);
|
||||
expect(router.getClient('fast')).toBe(mockClient2);
|
||||
expect(router.getLabel('default')).toBe('default-tier');
|
||||
expect(router.getLabel('fast')).toBe('fast-tier');
|
||||
expect(router.getLabel('complex')).toBe('unknown');
|
||||
|
||||
await router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'fast');
|
||||
|
||||
expect(mockClient2.chat).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('chat uses the new client after setClient', async () => {
|
||||
const mockClient1 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
const mockClient2 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
|
||||
const router = new ModelRouter({
|
||||
default: mockClient1,
|
||||
fast: { chat: vi.fn() } as unknown as ModelClient,
|
||||
fallbackChain: [],
|
||||
labels: {
|
||||
fast: 'original-fast',
|
||||
},
|
||||
});
|
||||
|
||||
const initialFastClient = router.getClient('fast');
|
||||
expect(initialFastClient).toBeDefined();
|
||||
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
|
||||
|
||||
expect(initialFastClient!.chat).toHaveBeenCalled();
|
||||
expect(initialFastClient!.chat).toHaveBeenCalledTimes(1);
|
||||
|
||||
router.setClient('fast', mockClient2, 'fast-replaced');
|
||||
|
||||
const newFastClient = router.getClient('fast');
|
||||
await router.chat({ messages: [{ role: 'user', content: 'Test' }] }, 'fast');
|
||||
|
||||
expect(newFastClient!.chat).toHaveBeenCalled();
|
||||
expect(newFastClient!.chat).toHaveBeenCalledTimes(1);
|
||||
expect(initialFastClient!.chat).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,10 +11,12 @@ export interface ModelRouterConfig {
|
||||
local?: ModelClient;
|
||||
fallbackChain: ModelClient[];
|
||||
retryConfig?: RetryConfig;
|
||||
labels?: Partial<Record<ModelTier, string>>;
|
||||
}
|
||||
|
||||
export class ModelRouter implements ModelClient {
|
||||
private clients: Map<ModelTier, ModelClient>;
|
||||
private labels: Map<ModelTier, string>;
|
||||
private defaultClient: ModelClient;
|
||||
private fallbackChain: ModelClient[];
|
||||
private currentTier: ModelTier = 'default';
|
||||
@@ -23,6 +25,7 @@ export class ModelRouter implements ModelClient {
|
||||
|
||||
constructor(config: ModelRouterConfig) {
|
||||
this.clients = new Map();
|
||||
this.labels = new Map();
|
||||
this.defaultClient = config.default;
|
||||
this.fallbackChain = config.fallbackChain;
|
||||
this.retryConfig = config.retryConfig;
|
||||
@@ -31,6 +34,14 @@ export class ModelRouter implements ModelClient {
|
||||
if (config.fast) this.clients.set('fast', config.fast);
|
||||
if (config.complex) this.clients.set('complex', config.complex);
|
||||
if (config.local) this.clients.set('local', config.local);
|
||||
|
||||
if (config.labels) {
|
||||
for (const tier of ['fast', 'default', 'complex', 'local'] as ModelTier[]) {
|
||||
if (config.labels[tier]) {
|
||||
this.labels.set(tier, config.labels[tier]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setTier(tier: ModelTier): boolean {
|
||||
@@ -141,4 +152,21 @@ export class ModelRouter implements ModelClient {
|
||||
getLocalProviderName(): string | undefined {
|
||||
return this.localProviderName;
|
||||
}
|
||||
|
||||
setClient(tier: ModelTier, client: ModelClient, label: string): void {
|
||||
this.clients.set(tier, client);
|
||||
this.labels.set(tier, label);
|
||||
}
|
||||
|
||||
getLabel(tier: ModelTier): string {
|
||||
return this.labels.get(tier) ?? 'unknown';
|
||||
}
|
||||
|
||||
getAllLabels(): Record<string, string> {
|
||||
const result: Record<string, string> = {};
|
||||
for (const tier of this.labels.keys()) {
|
||||
result[tier] = this.labels.get(tier) ?? 'unknown';
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user