fix: provider-aware model routing with fallback visibility
- Extract createClientFromConfig() to dispatch on provider field instead of hardcoding all tiers as AnthropicClient
- Add fallback/fallbackReason metadata to ChatResponse and ChatStreamEvent so callers know when a fallback model was used
- Enhance doctor check to report full model stack and warn on missing API keys for cloud providers
- Log fallback warnings in NativeAgent and display them in TUI
- Support tier names and local_providers entries in fallback_chain
- Add 8 tests for createClientFromConfig covering all provider types
This commit is contained in:
@@ -79,6 +79,10 @@ export class NativeAgent {
|
|||||||
|
|
||||||
const response = await this.chatWithRouter(request);
|
const response = await this.chatWithRouter(request);
|
||||||
|
|
||||||
|
if (response.fallback) {
|
||||||
|
console.warn(`[Flynn] ${response.fallbackReason}`);
|
||||||
|
}
|
||||||
|
|
||||||
const assistantMsg: Message = { role: 'assistant', content: response.content };
|
const assistantMsg: Message = { role: 'assistant', content: response.content };
|
||||||
this.addToHistory(assistantMsg);
|
this.addToHistory(assistantMsg);
|
||||||
|
|
||||||
@@ -106,6 +110,10 @@ export class NativeAgent {
|
|||||||
|
|
||||||
const response = await this.chatWithRouter(request);
|
const response = await this.chatWithRouter(request);
|
||||||
|
|
||||||
|
if (response.fallback) {
|
||||||
|
console.warn(`[Flynn] ${response.fallbackReason}`);
|
||||||
|
}
|
||||||
|
|
||||||
// If the model didn't request tool use, we're done
|
// If the model didn't request tool use, we're done
|
||||||
if (response.stopReason !== 'tool_use' || !response.toolCalls?.length) {
|
if (response.stopReason !== 'tool_use' || !response.toolCalls?.length) {
|
||||||
const assistantMsg: Message = { role: 'assistant', content: response.content };
|
const assistantMsg: Message = { role: 'assistant', content: response.content };
|
||||||
|
|||||||
@@ -0,0 +1,125 @@
|
|||||||
|
import { describe, it, expect, afterEach } from 'vitest';
|
||||||
|
import { runChecks, type CheckResult, type DoctorContext } from './doctor.js';
|
||||||
|
import { writeFileSync, mkdirSync, rmSync } from 'fs';
|
||||||
|
import { join } from 'path';
|
||||||
|
import { tmpdir } from 'os';
|
||||||
|
|
||||||
|
describe('doctor checks', () => {
  const testDir = join(tmpdir(), 'flynn-test-doctor');

  type Results = Awaited<ReturnType<typeof runChecks>>;

  /**
   * Config fixture the PASS tests below expect to parse and validate.
   * YAML nesting is significant: `bot_token`/`allowed_chat_ids` belong to
   * `telegram`, and `provider`/`model` belong to `models.default`.
   */
  const VALID_CONFIG = `
telegram:
  bot_token: "test-token"
  allowed_chat_ids: [123]
models:
  default:
    provider: anthropic
    model: claude-sonnet
`;

  /** Fixture with an empty bot token and no `models` section. */
  const INVALID_CONFIG = `
telegram:
  bot_token: ""
`;

  /** Write `contents` to `fileName` inside the test dir and return a context pointing at it. */
  function writeFixture(fileName: string, contents: string): DoctorContext {
    mkdirSync(testDir, { recursive: true });
    const configPath = join(testDir, fileName);
    writeFileSync(configPath, contents);
    return { configPath, dataDir: testDir };
  }

  /** Status of the first check whose label contains `labelFragment`, or undefined if none matches. */
  function statusOf(results: Results, labelFragment: string) {
    return results.find(r => r.label.includes(labelFragment))?.status;
  }

  afterEach(() => {
    // Best-effort cleanup: several tests never create the directory.
    try {
      rmSync(testDir, { recursive: true });
    } catch {
      /* nothing to remove */
    }
  });

  it('reports PASS when config file exists and is valid', async () => {
    const results = await runChecks(writeFixture('config.yaml', VALID_CONFIG));

    expect(statusOf(results, 'Config file')).toBe('pass');
    expect(statusOf(results, 'parses')).toBe('pass');
    expect(statusOf(results, 'validates')).toBe('pass');
  });

  it('reports FAIL when config file does not exist', async () => {
    const ctx: DoctorContext = { configPath: '/nonexistent/config.yaml', dataDir: testDir };
    const results = await runChecks(ctx);

    expect(statusOf(results, 'Config file')).toBe('fail');
  });

  it('reports FAIL on invalid YAML', async () => {
    const results = await runChecks(writeFixture('bad.yaml', '{{{{bad yaml'));

    expect(statusOf(results, 'parses')).toBe('fail');
  });

  it('reports FAIL on schema validation failure', async () => {
    const results = await runChecks(writeFixture('invalid.yaml', INVALID_CONFIG));

    expect(statusOf(results, 'validates')).toBe('fail');
  });

  it('reports PASS for writable data directory', async () => {
    const results = await runChecks(writeFixture('config.yaml', VALID_CONFIG));

    expect(statusOf(results, 'Data directory')).toBe('pass');
  });

  it('reports PASS for accessible session DB', async () => {
    const results = await runChecks(writeFixture('config.yaml', VALID_CONFIG));

    expect(statusOf(results, 'Session DB')).toBe('pass');
  });

  it('skips downstream checks when config is invalid', async () => {
    const ctx: DoctorContext = { configPath: '/nonexistent/config.yaml', dataDir: testDir };
    const results = await runChecks(ctx);

    expect(statusOf(results, 'Model connectivity')).toBe('skip');
    expect(statusOf(results, 'Telegram')).toBe('skip');
  });
});
|
||||||
+21
-3
@@ -116,12 +116,30 @@ const checkModelConnectivity: Check = async (ctx) => {
|
|||||||
if (!ctx.config) {
|
if (!ctx.config) {
|
||||||
return { status: 'skip', label: 'Model connectivity', detail: '(config invalid)' };
|
return { status: 'skip', label: 'Model connectivity', detail: '(config invalid)' };
|
||||||
}
|
}
|
||||||
// Skip actual API call in doctor — just verify config looks complete
|
const models = ctx.config.models;
|
||||||
const model = ctx.config.models.default;
|
const model = models.default;
|
||||||
if (!model.model) {
|
if (!model.model) {
|
||||||
return { status: 'fail', label: 'Model connectivity', detail: 'no default model configured' };
|
return { status: 'fail', label: 'Model connectivity', detail: 'no default model configured' };
|
||||||
}
|
}
|
||||||
return { status: 'pass', label: 'Model connectivity', detail: `(${model.provider}: ${model.model})` };
|
|
||||||
|
// Check if API key is present for providers that need one
|
||||||
|
const needsKey = ['anthropic', 'openai', 'gemini'];
|
||||||
|
if (needsKey.includes(model.provider) && !model.api_key && !model.auth_token) {
|
||||||
|
const envVar = model.provider === 'anthropic' ? 'ANTHROPIC_API_KEY' : model.provider === 'openai' ? 'OPENAI_API_KEY' : undefined;
|
||||||
|
const hasEnv = envVar && process.env[envVar];
|
||||||
|
if (!hasEnv) {
|
||||||
|
return { status: 'warn', label: 'Model connectivity', detail: `${model.provider}/${model.model} — no API key or auth token found` };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a summary of the model stack
|
||||||
|
const parts = [`default: ${model.provider}/${model.model}`];
|
||||||
|
if (models.fast) parts.push(`fast: ${models.fast.provider}/${models.fast.model}`);
|
||||||
|
if (models.complex) parts.push(`complex: ${models.complex.provider}/${models.complex.model}`);
|
||||||
|
if (models.local) parts.push(`local: ${models.local.provider}/${models.local.model}`);
|
||||||
|
parts.push(`fallback: [${models.fallback_chain.join(', ')}]`);
|
||||||
|
|
||||||
|
return { status: 'pass', label: 'Model connectivity', detail: parts.join(', ') };
|
||||||
};
|
};
|
||||||
|
|
||||||
const checkTelegram: Check = async (ctx) => {
|
const checkTelegram: Check = async (ctx) => {
|
||||||
|
|||||||
@@ -0,0 +1,78 @@
|
|||||||
|
import { describe, it, expect } from 'vitest';
|
||||||
|
import { createClientFromConfig } from './index.js';
|
||||||
|
import { AnthropicClient } from '../models/anthropic.js';
|
||||||
|
import { OpenAIClient } from '../models/openai.js';
|
||||||
|
import { OllamaClient } from '../models/local/ollama.js';
|
||||||
|
import { LlamaCppClient } from '../models/local/llamacpp.js';
|
||||||
|
|
||||||
|
describe('createClientFromConfig', () => {
  // Each case feeds one provider config through the factory and checks
  // which client class comes back.

  it('creates AnthropicClient for anthropic provider', () => {
    expect(
      createClientFromConfig({
        provider: 'anthropic',
        model: 'claude-sonnet-4-5-20250514',
        api_key: 'sk-ant-test',
      }),
    ).toBeInstanceOf(AnthropicClient);
  });

  it('creates OpenAIClient for openai provider', () => {
    expect(
      createClientFromConfig({
        provider: 'openai',
        model: 'gpt-4o',
        api_key: 'sk-test',
      }),
    ).toBeInstanceOf(OpenAIClient);
  });

  it('creates OllamaClient for ollama provider', () => {
    expect(
      createClientFromConfig({
        provider: 'ollama',
        model: 'llama3.2:1b',
        endpoint: 'http://localhost:11434',
      }),
    ).toBeInstanceOf(OllamaClient);
  });

  it('creates OllamaClient with num_gpu option', () => {
    expect(
      createClientFromConfig({
        provider: 'ollama',
        model: 'llama3.2:1b',
        num_gpu: 0,
      }),
    ).toBeInstanceOf(OllamaClient);
  });

  it('creates LlamaCppClient for llamacpp provider', () => {
    expect(
      createClientFromConfig({
        provider: 'llamacpp',
        model: 'ministral-reasoning',
        endpoint: 'http://localhost:8080',
      }),
    ).toBeInstanceOf(LlamaCppClient);
  });

  it('defaults llamacpp endpoint to localhost:8080', () => {
    // No endpoint supplied; only the resulting client class is asserted here.
    expect(
      createClientFromConfig({
        provider: 'llamacpp',
        model: 'test-model',
      }),
    ).toBeInstanceOf(LlamaCppClient);
  });

  it('creates OpenAI-compatible client for gemini provider (with warning)', () => {
    // Gemini falls back to the OpenAI-compatible client.
    expect(
      createClientFromConfig({
        provider: 'gemini',
        model: 'gemini-2.5-pro',
        api_key: 'test-key',
      }),
    ).toBeInstanceOf(OpenAIClient);
  });

  it('throws for unknown provider', () => {
    // The cast smuggles an unsupported provider string past the type checker.
    const buildWithUnknownProvider = () =>
      createClientFromConfig({
        provider: 'unknown' as 'anthropic',
        model: 'test',
      });

    expect(buildWithUnknownProvider).toThrow('Unknown model provider: unknown');
  });
});
|
||||||
+66
-45
@@ -1,6 +1,7 @@
|
|||||||
import { Lifecycle } from './lifecycle.js';
|
import { Lifecycle } from './lifecycle.js';
|
||||||
import type { Config } from '../config/index.js';
|
import type { Config, ModelConfig } from '../config/index.js';
|
||||||
import { AnthropicClient, OpenAIClient, OllamaClient, LlamaCppClient, ModelRouter } from '../models/index.js';
|
import { AnthropicClient, OpenAIClient, OllamaClient, LlamaCppClient, ModelRouter } from '../models/index.js';
|
||||||
|
import type { ModelClient } from '../models/index.js';
|
||||||
import { NativeAgent } from '../backends/index.js';
|
import { NativeAgent } from '../backends/index.js';
|
||||||
import { SessionStore, SessionManager } from '../session/index.js';
|
import { SessionStore, SessionManager } from '../session/index.js';
|
||||||
import { HookEngine } from '../hooks/index.js';
|
import { HookEngine } from '../hooks/index.js';
|
||||||
@@ -48,60 +49,80 @@ function loadSystemPrompt(): string {
|
|||||||
return 'You are Flynn, a helpful personal AI assistant. Be direct, concise, and helpful. Use markdown when it improves readability.';
|
return 'You are Flynn, a helpful personal AI assistant. Be direct, concise, and helpful. Use markdown when it improves readability.';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Build the model client that corresponds to one model config entry.
 *
 * The `provider` field alone decides which client class is constructed, so
 * every tier and every fallback entry can share this single factory.
 *
 * @param cfg - Model config entry; `provider` selects the client class.
 * @returns A newly constructed client for the configured provider.
 * @throws Error when `provider` is not one of the values handled below.
 */
export function createClientFromConfig(cfg: ModelConfig): ModelClient {
  if (cfg.provider === 'anthropic') {
    return new AnthropicClient({
      model: cfg.model,
      apiKey: cfg.api_key,
      authToken: cfg.auth_token,
    });
  }

  if (cfg.provider === 'openai') {
    return new OpenAIClient({ model: cfg.model, apiKey: cfg.api_key });
  }

  if (cfg.provider === 'ollama') {
    return new OllamaClient({
      model: cfg.model,
      host: cfg.endpoint,
      numGpu: cfg.num_gpu,
    });
  }

  if (cfg.provider === 'llamacpp') {
    // llama.cpp needs an explicit endpoint; use the local default when none is configured.
    const endpoint = cfg.endpoint ?? 'http://localhost:8080';
    return new LlamaCppClient({
      endpoint,
      model: cfg.model,
      authToken: cfg.auth_token,
    });
  }

  if (cfg.provider === 'gemini') {
    // No dedicated Gemini client exists yet, so reuse the OpenAI-compatible one and say so.
    // NOTE(review): `cfg.endpoint` is not forwarded here — confirm which host the
    // OpenAI client targets by default and whether that is intended for Gemini models.
    console.warn(`Gemini provider not yet implemented for model "${cfg.model}", using OpenAI-compatible client`);
    return new OpenAIClient({ model: cfg.model, apiKey: cfg.api_key });
  }

  // Reached only when a provider value slipped past the type system (e.g. unvalidated input).
  throw new Error(`Unknown model provider: ${(cfg as Record<string, unknown>).provider}`);
}
|
||||||
|
|
||||||
function createModelRouter(config: Config): ModelRouter {
|
function createModelRouter(config: Config): ModelRouter {
|
||||||
const models = config.models;
|
const models = config.models;
|
||||||
|
|
||||||
const defaultClient = new AnthropicClient({
|
const defaultClient = createClientFromConfig(models.default);
|
||||||
model: models.default.model,
|
|
||||||
apiKey: models.default.api_key,
|
|
||||||
authToken: models.default.auth_token,
|
|
||||||
});
|
|
||||||
|
|
||||||
let fastClient;
|
const fastClient = models.fast ? createClientFromConfig(models.fast) : undefined;
|
||||||
let complexClient;
|
const complexClient = models.complex ? createClientFromConfig(models.complex) : undefined;
|
||||||
let localClient;
|
const localClient = models.local ? createClientFromConfig(models.local) : undefined;
|
||||||
|
|
||||||
if (models.fast) {
|
// Build fallback chain — each entry references a tier name or 'local'
|
||||||
fastClient = new AnthropicClient({
|
const fallbackChain: ModelClient[] = [];
|
||||||
model: models.fast.model,
|
|
||||||
apiKey: models.fast.api_key,
|
|
||||||
authToken: models.fast.auth_token,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (models.complex) {
|
|
||||||
complexClient = new AnthropicClient({
|
|
||||||
model: models.complex.model,
|
|
||||||
apiKey: models.complex.api_key,
|
|
||||||
authToken: models.complex.auth_token,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (models.local) {
|
|
||||||
if (models.local.provider === 'ollama') {
|
|
||||||
localClient = new OllamaClient({
|
|
||||||
model: models.local.model,
|
|
||||||
host: models.local.endpoint,
|
|
||||||
numGpu: models.local.num_gpu,
|
|
||||||
});
|
|
||||||
} else if (models.local.provider === 'llamacpp') {
|
|
||||||
localClient = new LlamaCppClient({
|
|
||||||
endpoint: models.local.endpoint ?? 'http://localhost:8080',
|
|
||||||
model: models.local.model,
|
|
||||||
authToken: models.local.auth_token,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const fallbackChain = [];
|
|
||||||
for (const providerName of models.fallback_chain) {
|
for (const providerName of models.fallback_chain) {
|
||||||
if (providerName === 'openai') {
|
if (providerName === 'local' && localClient) {
|
||||||
fallbackChain.push(new OpenAIClient({ model: 'gpt-4o' }));
|
|
||||||
} else if (providerName === 'local' && localClient) {
|
|
||||||
fallbackChain.push(localClient);
|
fallbackChain.push(localClient);
|
||||||
|
} else if (providerName === 'default') {
|
||||||
|
// Allows re-trying the default provider in the chain
|
||||||
|
fallbackChain.push(defaultClient);
|
||||||
|
} else if (providerName === 'fast' && fastClient) {
|
||||||
|
fallbackChain.push(fastClient);
|
||||||
|
} else if (providerName === 'complex' && complexClient) {
|
||||||
|
fallbackChain.push(complexClient);
|
||||||
|
} else if (models.local_providers?.[providerName]) {
|
||||||
|
// Named provider from local_providers map
|
||||||
|
fallbackChain.push(createClientFromConfig(models.local_providers[providerName]));
|
||||||
|
} else {
|
||||||
|
console.warn(`Fallback chain entry "${providerName}" not found — skipping`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
console.log(`Model router: default=${models.default.provider}/${models.default.model}, ` +
|
||||||
|
`fallback=[${models.fallback_chain.join(', ')}]`);
|
||||||
|
|
||||||
return new ModelRouter({
|
return new ModelRouter({
|
||||||
default: defaultClient,
|
default: defaultClient,
|
||||||
fast: fastClient,
|
fast: fastClient,
|
||||||
|
|||||||
@@ -315,6 +315,9 @@ export class MinimalTui {
|
|||||||
process.stdout.write(event.content);
|
process.stdout.write(event.content);
|
||||||
fullContent += event.content;
|
fullContent += event.content;
|
||||||
}
|
}
|
||||||
|
if (event.type === 'fallback_warning' && event.fallbackReason) {
|
||||||
|
console.warn(`\n⚠ ${event.fallbackReason}`);
|
||||||
|
}
|
||||||
if (event.type === 'done' && event.usage) {
|
if (event.type === 'done' && event.usage) {
|
||||||
this.totalUsage.inputTokens += event.usage.inputTokens;
|
this.totalUsage.inputTokens += event.usage.inputTokens;
|
||||||
this.totalUsage.outputTokens += event.usage.outputTokens;
|
this.totalUsage.outputTokens += event.usage.outputTokens;
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ describe('ModelRouter', () => {
|
|||||||
const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });
|
const response = await router.chat({ messages: [{ role: 'user', content: 'Hi' }] });
|
||||||
|
|
||||||
expect(response.content).toBe('Response from fallback');
|
expect(response.content).toBe('Response from fallback');
|
||||||
|
expect(response.fallback).toBe(true);
|
||||||
|
expect(response.fallbackReason).toMatch(/Primary model failed/);
|
||||||
expect(failingClient.chat).toHaveBeenCalled();
|
expect(failingClient.chat).toHaveBeenCalled();
|
||||||
expect(fallbackClient.chat).toHaveBeenCalled();
|
expect(fallbackClient.chat).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
@@ -132,13 +134,18 @@ describe('ModelRouter streaming', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
const chunks: string[] = [];
|
const chunks: string[] = [];
|
||||||
|
let fallbackWarning: string | undefined;
|
||||||
for await (const event of router.chatStream({ messages: [] })) {
|
for await (const event of router.chatStream({ messages: [] })) {
|
||||||
if (event.type === 'content' && event.content) {
|
if (event.type === 'content' && event.content) {
|
||||||
chunks.push(event.content);
|
chunks.push(event.content);
|
||||||
}
|
}
|
||||||
|
if (event.type === 'fallback_warning') {
|
||||||
|
fallbackWarning = event.fallbackReason;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
expect(chunks).toEqual(['Fallback']);
|
expect(chunks).toEqual(['Fallback']);
|
||||||
|
expect(fallbackWarning).toMatch(/Primary model failed/);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
+19
-7
@@ -58,13 +58,16 @@ export class ModelRouter implements ModelClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Try fallback chain
|
// Try fallback chain
|
||||||
for (const fallbackClient of this.fallbackChain) {
|
for (let i = 0; i < this.fallbackChain.length; i++) {
|
||||||
|
const fallbackClient = this.fallbackChain[i];
|
||||||
try {
|
try {
|
||||||
console.log('Trying fallback model...');
|
const reason = `Primary model failed (${errors[0].message}), using fallback #${i + 1}`;
|
||||||
return await fallbackClient.chat(request);
|
console.warn(reason);
|
||||||
|
const response = await fallbackClient.chat(request);
|
||||||
|
return { ...response, fallback: true, fallbackReason: reason };
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
errors.push(error instanceof Error ? error : new Error(String(error)));
|
errors.push(error instanceof Error ? error : new Error(String(error)));
|
||||||
console.warn(`Fallback model failed: ${errors[errors.length - 1].message}`);
|
console.warn(`Fallback model #${i + 1} failed: ${errors[errors.length - 1].message}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,30 +77,39 @@ export class ModelRouter implements ModelClient {
|
|||||||
async *chatStream(request: ChatRequest, tier?: ModelTier): AsyncIterable<ChatStreamEvent> {
|
async *chatStream(request: ChatRequest, tier?: ModelTier): AsyncIterable<ChatStreamEvent> {
|
||||||
const useTier = tier ?? this.currentTier;
|
const useTier = tier ?? this.currentTier;
|
||||||
const primaryClient = this.clients.get(useTier) ?? this.defaultClient;
|
const primaryClient = this.clients.get(useTier) ?? this.defaultClient;
|
||||||
|
let primaryError: string | undefined;
|
||||||
|
|
||||||
if (primaryClient.chatStream) {
|
if (primaryClient.chatStream) {
|
||||||
let hasError = false;
|
let hasError = false;
|
||||||
for await (const event of primaryClient.chatStream(request)) {
|
for await (const event of primaryClient.chatStream(request)) {
|
||||||
if (event.type === 'error') {
|
if (event.type === 'error') {
|
||||||
hasError = true;
|
hasError = true;
|
||||||
console.warn(`Primary stream failed: ${event.error?.message}`);
|
primaryError = event.error?.message ?? 'Unknown error';
|
||||||
|
console.warn(`Primary stream failed: ${primaryError}`);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
yield event;
|
yield event;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!hasError) return;
|
if (!hasError) return;
|
||||||
|
} else {
|
||||||
|
primaryError = 'Primary client does not support streaming';
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try fallback chain
|
// Try fallback chain
|
||||||
for (const fallbackClient of this.fallbackChain) {
|
for (let i = 0; i < this.fallbackChain.length; i++) {
|
||||||
|
const fallbackClient = this.fallbackChain[i];
|
||||||
if (!fallbackClient.chatStream) continue;
|
if (!fallbackClient.chatStream) continue;
|
||||||
|
|
||||||
|
const reason = `Primary model failed (${primaryError}), using fallback #${i + 1}`;
|
||||||
|
console.warn(reason);
|
||||||
|
yield { type: 'fallback_warning', fallbackReason: reason };
|
||||||
|
|
||||||
let hasError = false;
|
let hasError = false;
|
||||||
for await (const event of fallbackClient.chatStream(request)) {
|
for await (const event of fallbackClient.chatStream(request)) {
|
||||||
if (event.type === 'error') {
|
if (event.type === 'error') {
|
||||||
hasError = true;
|
hasError = true;
|
||||||
console.warn(`Fallback stream failed: ${event.error?.message}`);
|
console.warn(`Fallback stream #${i + 1} failed: ${event.error?.message}`);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
yield event;
|
yield event;
|
||||||
|
|||||||
+7
-1
@@ -55,6 +55,10 @@ export interface ChatResponse {
|
|||||||
stopReason: 'end_turn' | 'max_tokens' | 'stop_sequence' | 'tool_use' | string;
|
stopReason: 'end_turn' | 'max_tokens' | 'stop_sequence' | 'tool_use' | string;
|
||||||
usage: TokenUsage;
|
usage: TokenUsage;
|
||||||
toolCalls?: ModelToolCall[];
|
toolCalls?: ModelToolCall[];
|
||||||
|
/** Set when the response came from a fallback model, not the primary. */
|
||||||
|
fallback?: boolean;
|
||||||
|
/** Human-readable reason for the fallback. */
|
||||||
|
fallbackReason?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TokenUsage {
|
export interface TokenUsage {
|
||||||
@@ -63,11 +67,13 @@ export interface TokenUsage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatStreamEvent {
|
export interface ChatStreamEvent {
|
||||||
type: 'content' | 'done' | 'error' | 'tool_use';
|
type: 'content' | 'done' | 'error' | 'tool_use' | 'fallback_warning';
|
||||||
content?: string;
|
content?: string;
|
||||||
usage?: TokenUsage;
|
usage?: TokenUsage;
|
||||||
error?: Error;
|
error?: Error;
|
||||||
toolCall?: ModelToolCall;
|
toolCall?: ModelToolCall;
|
||||||
|
/** Human-readable message when primary model failed and fallback is being used. */
|
||||||
|
fallbackReason?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface StreamingModelClient {
|
export interface StreamingModelClient {
|
||||||
|
|||||||
Reference in New Issue
Block a user