fix: provider-aware model routing with fallback visibility
- Extract createClientFromConfig() to dispatch on provider field instead of hardcoding all tiers as AnthropicClient - Add fallback/fallbackReason metadata to ChatResponse and ChatStreamEvent so callers know when a fallback model was used - Enhance doctor check to report full model stack and warn on missing API keys for cloud providers - Log fallback warnings in NativeAgent and display them in TUI - Support tier names and local_providers entries in fallback_chain - Add 8 tests for createClientFromConfig covering all provider types
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { createClientFromConfig } from './index.js';
|
||||
import { AnthropicClient } from '../models/anthropic.js';
|
||||
import { OpenAIClient } from '../models/openai.js';
|
||||
import { OllamaClient } from '../models/local/ollama.js';
|
||||
import { LlamaCppClient } from '../models/local/llamacpp.js';
|
||||
|
||||
describe('createClientFromConfig', () => {
|
||||
it('creates AnthropicClient for anthropic provider', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'anthropic',
|
||||
model: 'claude-sonnet-4-5-20250514',
|
||||
api_key: 'sk-ant-test',
|
||||
});
|
||||
expect(client).toBeInstanceOf(AnthropicClient);
|
||||
});
|
||||
|
||||
it('creates OpenAIClient for openai provider', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'openai',
|
||||
model: 'gpt-4o',
|
||||
api_key: 'sk-test',
|
||||
});
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
});
|
||||
|
||||
it('creates OllamaClient for ollama provider', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'ollama',
|
||||
model: 'llama3.2:1b',
|
||||
endpoint: 'http://localhost:11434',
|
||||
});
|
||||
expect(client).toBeInstanceOf(OllamaClient);
|
||||
});
|
||||
|
||||
it('creates OllamaClient with num_gpu option', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'ollama',
|
||||
model: 'llama3.2:1b',
|
||||
num_gpu: 0,
|
||||
});
|
||||
expect(client).toBeInstanceOf(OllamaClient);
|
||||
});
|
||||
|
||||
it('creates LlamaCppClient for llamacpp provider', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'llamacpp',
|
||||
model: 'ministral-reasoning',
|
||||
endpoint: 'http://localhost:8080',
|
||||
});
|
||||
expect(client).toBeInstanceOf(LlamaCppClient);
|
||||
});
|
||||
|
||||
it('defaults llamacpp endpoint to localhost:8080', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'llamacpp',
|
||||
model: 'test-model',
|
||||
});
|
||||
expect(client).toBeInstanceOf(LlamaCppClient);
|
||||
});
|
||||
|
||||
it('creates OpenAI-compatible client for gemini provider (with warning)', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'gemini',
|
||||
model: 'gemini-2.5-pro',
|
||||
api_key: 'test-key',
|
||||
});
|
||||
// Gemini falls back to OpenAI-compatible client
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
});
|
||||
|
||||
it('throws for unknown provider', () => {
|
||||
expect(() => createClientFromConfig({
|
||||
provider: 'unknown' as 'anthropic',
|
||||
model: 'test',
|
||||
})).toThrow('Unknown model provider: unknown');
|
||||
});
|
||||
});
|
||||
+66
-45
@@ -1,6 +1,7 @@
|
||||
import { Lifecycle } from './lifecycle.js';
|
||||
import type { Config } from '../config/index.js';
|
||||
import type { Config, ModelConfig } from '../config/index.js';
|
||||
import { AnthropicClient, OpenAIClient, OllamaClient, LlamaCppClient, ModelRouter } from '../models/index.js';
|
||||
import type { ModelClient } from '../models/index.js';
|
||||
import { NativeAgent } from '../backends/index.js';
|
||||
import { SessionStore, SessionManager } from '../session/index.js';
|
||||
import { HookEngine } from '../hooks/index.js';
|
||||
@@ -48,60 +49,80 @@ function loadSystemPrompt(): string {
|
||||
return 'You are Flynn, a helpful personal AI assistant. Be direct, concise, and helpful. Use markdown when it improves readability.';
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a ModelClient from a provider config entry.
|
||||
* Dispatches on the `provider` field so all tiers and fallback entries
|
||||
* use the correct client implementation.
|
||||
*/
|
||||
export function createClientFromConfig(cfg: ModelConfig): ModelClient {
|
||||
switch (cfg.provider) {
|
||||
case 'anthropic':
|
||||
return new AnthropicClient({
|
||||
model: cfg.model,
|
||||
apiKey: cfg.api_key,
|
||||
authToken: cfg.auth_token,
|
||||
});
|
||||
case 'openai':
|
||||
return new OpenAIClient({
|
||||
model: cfg.model,
|
||||
apiKey: cfg.api_key,
|
||||
});
|
||||
case 'ollama':
|
||||
return new OllamaClient({
|
||||
model: cfg.model,
|
||||
host: cfg.endpoint,
|
||||
numGpu: cfg.num_gpu,
|
||||
});
|
||||
case 'llamacpp':
|
||||
return new LlamaCppClient({
|
||||
endpoint: cfg.endpoint ?? 'http://localhost:8080',
|
||||
model: cfg.model,
|
||||
authToken: cfg.auth_token,
|
||||
});
|
||||
case 'gemini':
|
||||
// Gemini support not yet implemented — fall back to OpenAI-compatible client
|
||||
console.warn(`Gemini provider not yet implemented for model "${cfg.model}", using OpenAI-compatible client`);
|
||||
return new OpenAIClient({
|
||||
model: cfg.model,
|
||||
apiKey: cfg.api_key,
|
||||
});
|
||||
default:
|
||||
throw new Error(`Unknown model provider: ${(cfg as Record<string, unknown>).provider}`);
|
||||
}
|
||||
}
|
||||
|
||||
function createModelRouter(config: Config): ModelRouter {
|
||||
const models = config.models;
|
||||
|
||||
const defaultClient = new AnthropicClient({
|
||||
model: models.default.model,
|
||||
apiKey: models.default.api_key,
|
||||
authToken: models.default.auth_token,
|
||||
});
|
||||
const defaultClient = createClientFromConfig(models.default);
|
||||
|
||||
let fastClient;
|
||||
let complexClient;
|
||||
let localClient;
|
||||
const fastClient = models.fast ? createClientFromConfig(models.fast) : undefined;
|
||||
const complexClient = models.complex ? createClientFromConfig(models.complex) : undefined;
|
||||
const localClient = models.local ? createClientFromConfig(models.local) : undefined;
|
||||
|
||||
if (models.fast) {
|
||||
fastClient = new AnthropicClient({
|
||||
model: models.fast.model,
|
||||
apiKey: models.fast.api_key,
|
||||
authToken: models.fast.auth_token,
|
||||
});
|
||||
}
|
||||
|
||||
if (models.complex) {
|
||||
complexClient = new AnthropicClient({
|
||||
model: models.complex.model,
|
||||
apiKey: models.complex.api_key,
|
||||
authToken: models.complex.auth_token,
|
||||
});
|
||||
}
|
||||
|
||||
if (models.local) {
|
||||
if (models.local.provider === 'ollama') {
|
||||
localClient = new OllamaClient({
|
||||
model: models.local.model,
|
||||
host: models.local.endpoint,
|
||||
numGpu: models.local.num_gpu,
|
||||
});
|
||||
} else if (models.local.provider === 'llamacpp') {
|
||||
localClient = new LlamaCppClient({
|
||||
endpoint: models.local.endpoint ?? 'http://localhost:8080',
|
||||
model: models.local.model,
|
||||
authToken: models.local.auth_token,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const fallbackChain = [];
|
||||
// Build fallback chain — each entry references a tier name or 'local'
|
||||
const fallbackChain: ModelClient[] = [];
|
||||
for (const providerName of models.fallback_chain) {
|
||||
if (providerName === 'openai') {
|
||||
fallbackChain.push(new OpenAIClient({ model: 'gpt-4o' }));
|
||||
} else if (providerName === 'local' && localClient) {
|
||||
if (providerName === 'local' && localClient) {
|
||||
fallbackChain.push(localClient);
|
||||
} else if (providerName === 'default') {
|
||||
// Allows re-trying the default provider in the chain
|
||||
fallbackChain.push(defaultClient);
|
||||
} else if (providerName === 'fast' && fastClient) {
|
||||
fallbackChain.push(fastClient);
|
||||
} else if (providerName === 'complex' && complexClient) {
|
||||
fallbackChain.push(complexClient);
|
||||
} else if (models.local_providers?.[providerName]) {
|
||||
// Named provider from local_providers map
|
||||
fallbackChain.push(createClientFromConfig(models.local_providers[providerName]));
|
||||
} else {
|
||||
console.warn(`Fallback chain entry "${providerName}" not found — skipping`);
|
||||
}
|
||||
}
|
||||
|
||||
console.log(`Model router: default=${models.default.provider}/${models.default.model}, ` +
|
||||
`fallback=[${models.fallback_chain.join(', ')}]`);
|
||||
|
||||
return new ModelRouter({
|
||||
default: defaultClient,
|
||||
fast: fastClient,
|
||||
|
||||
Reference in New Issue
Block a user