feat: add setLocalClient and getLocalProviderName to ModelRouter
This commit is contained in:
@@ -141,3 +141,24 @@ describe('ModelRouter streaming', () => {
|
|||||||
expect(chunks).toEqual(['Fallback']);
|
expect(chunks).toEqual(['Fallback']);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('ModelRouter local client switching', () => {
|
||||||
|
it('allows setting a new local client', () => {
|
||||||
|
const mockDefault = { chat: vi.fn() } as unknown as ModelClient;
|
||||||
|
const mockLocal1 = { chat: vi.fn() } as unknown as ModelClient;
|
||||||
|
const mockLocal2 = { chat: vi.fn() } as unknown as ModelClient;
|
||||||
|
|
||||||
|
const router = new ModelRouter({
|
||||||
|
default: mockDefault,
|
||||||
|
local: mockLocal1,
|
||||||
|
fallbackChain: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(router.getLocalProviderName()).toBe(undefined);
|
||||||
|
|
||||||
|
router.setLocalClient(mockLocal2, 'llamacpp');
|
||||||
|
|
||||||
|
expect(router.getLocalProviderName()).toBe('llamacpp');
|
||||||
|
expect(router.getClient('local')).toBe(mockLocal2);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ export class ModelRouter implements ModelClient {
|
|||||||
private defaultClient: ModelClient;
|
private defaultClient: ModelClient;
|
||||||
private fallbackChain: ModelClient[];
|
private fallbackChain: ModelClient[];
|
||||||
private currentTier: ModelTier = 'default';
|
private currentTier: ModelTier = 'default';
|
||||||
|
private localProviderName?: string;
|
||||||
|
|
||||||
constructor(config: ModelRouterConfig) {
|
constructor(config: ModelRouterConfig) {
|
||||||
this.clients = new Map();
|
this.clients = new Map();
|
||||||
@@ -111,4 +112,13 @@ export class ModelRouter implements ModelClient {
|
|||||||
getClient(tier: ModelTier): ModelClient | undefined {
|
getClient(tier: ModelTier): ModelClient | undefined {
|
||||||
return this.clients.get(tier);
|
return this.clients.get(tier);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setLocalClient(client: ModelClient, providerName: string): void {
|
||||||
|
this.clients.set('local', client);
|
||||||
|
this.localProviderName = providerName;
|
||||||
|
}
|
||||||
|
|
||||||
|
getLocalProviderName(): string | undefined {
|
||||||
|
return this.localProviderName;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user