feat: add setLocalClient and getLocalProviderName to ModelRouter
This commit is contained in:
@@ -141,3 +141,24 @@ describe('ModelRouter streaming', () => {
|
||||
expect(chunks).toEqual(['Fallback']);
|
||||
});
|
||||
});
|
||||
|
||||
describe('ModelRouter local client switching', () => {
|
||||
it('allows setting a new local client', () => {
|
||||
const mockDefault = { chat: vi.fn() } as unknown as ModelClient;
|
||||
const mockLocal1 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
const mockLocal2 = { chat: vi.fn() } as unknown as ModelClient;
|
||||
|
||||
const router = new ModelRouter({
|
||||
default: mockDefault,
|
||||
local: mockLocal1,
|
||||
fallbackChain: [],
|
||||
});
|
||||
|
||||
expect(router.getLocalProviderName()).toBe(undefined);
|
||||
|
||||
router.setLocalClient(mockLocal2, 'llamacpp');
|
||||
|
||||
expect(router.getLocalProviderName()).toBe('llamacpp');
|
||||
expect(router.getClient('local')).toBe(mockLocal2);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -15,6 +15,7 @@ export class ModelRouter implements ModelClient {
|
||||
private defaultClient: ModelClient;
|
||||
private fallbackChain: ModelClient[];
|
||||
private currentTier: ModelTier = 'default';
|
||||
private localProviderName?: string;
|
||||
|
||||
constructor(config: ModelRouterConfig) {
|
||||
this.clients = new Map();
|
||||
@@ -111,4 +112,13 @@ export class ModelRouter implements ModelClient {
|
||||
getClient(tier: ModelTier): ModelClient | undefined {
|
||||
return this.clients.get(tier);
|
||||
}
|
||||
|
||||
setLocalClient(client: ModelClient, providerName: string): void {
|
||||
this.clients.set('local', client);
|
||||
this.localProviderName = providerName;
|
||||
}
|
||||
|
||||
getLocalProviderName(): string | undefined {
|
||||
return this.localProviderName;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user