Files
flynn/src/daemon/models.ts
T
2026-02-21 11:39:23 -08:00

499 lines
17 KiB
TypeScript

import type { Config, ModelConfig } from '../config/index.js';
import { AnthropicClient, OpenAIClient, OllamaClient, LlamaCppClient, GeminiClient, BedrockClient, GitHubModelsClient, SyntheticClient, RotatingModelClient, ModelRouter, DEFAULT_RETRY_CONFIG } from '../models/index.js';
import type { ModelClient, RetryConfig, ModelTier } from '../models/index.js';
import { logger } from '../logger.js';
import { getZaiApiKey } from '../auth/zai.js';
import { getAnthropicApiKey, getAnthropicAuthToken } from '../auth/anthropic.js';
import { getGeminiApiKey } from '../auth/gemini.js';
import { getOpenAIApiKey, loadStoredOpenAIAuth } from '../auth/openai.js';
type AuthMode = 'auto' | 'api_key' | 'oauth';

/**
 * Decide which authentication strategy a model config asks for.
 *
 * A truthy `auth_mode` wins outright. Without it, a truthy `use_oauth`
 * flag selects `'oauth'`; when neither is set the mode is `'auto'`.
 */
function getEffectiveAuthMode(cfg: ModelConfig): AuthMode {
  if (!cfg.auth_mode) {
    return cfg.use_oauth ? 'oauth' : 'auto';
  }
  return cfg.auth_mode;
}
/**
 * Resolve a single API key from config or an environment variable.
 *
 * `cfg.api_key` takes precedence. A blank or whitespace-only value is treated
 * as "not set", so it falls through to the environment variable instead of
 * masking it. The returned key is trimmed.
 *
 * @param cfg - Provider config entry; `api_key` and `provider` are read.
 * @param envVar - Name of the environment variable used as the fallback source.
 * @returns The first non-empty key found, trimmed.
 * @throws Error naming `envVar` when neither source yields a non-empty key.
 */
function requireApiKey(cfg: ModelConfig, envVar: string): string {
  // `||` (not `??`) on purpose: an empty string must fall back to the env var.
  const key = cfg.api_key?.trim() || process.env[envVar]?.trim();
  if (!key) {
    throw new Error(
      `API key required for ${cfg.provider}. ` +
      `Set ${envVar} environment variable or provide api_key in config.`,
    );
  }
  return key;
}
/**
 * Collect the API keys available for a provider. The first source that yields
 * at least one non-empty key wins:
 *   1. `cfg.api_keys` (each entry trimmed, blank entries dropped)
 *   2. `cfg.api_key` (trimmed)
 *   3. the environment variable named by `envVar`, when one is given (trimmed)
 *
 * @returns The resolved keys, or an empty array when no source provides one.
 */
function resolveApiKeyPool(cfg: ModelConfig, envVar?: string): string[] {
  const pool: string[] = [];
  for (const candidate of cfg.api_keys ?? []) {
    const trimmed = candidate.trim();
    if (trimmed) {
      pool.push(trimmed);
    }
  }
  if (pool.length > 0) {
    return pool;
  }

  const single = cfg.api_key?.trim();
  if (single) {
    return [single];
  }

  const fromEnv = envVar ? process.env[envVar]?.trim() : undefined;
  return fromEnv ? [fromEnv] : [];
}
/**
 * Turn a pool of API keys into one ModelClient.
 *
 * With exactly one key the client returned by `build` is handed back as-is.
 * For any other count, one client is built per key and the set is wrapped in
 * a RotatingModelClient configured with `options.cooldownMs` (0 when omitted).
 *
 * @param keys - API keys to build clients for.
 * @param build - Factory producing a client for a single key.
 * @param options - Optional cooldown passed to the RotatingModelClient.
 */
function createApiKeyClient(
  keys: string[],
  build: (apiKey: string) => ModelClient,
  options?: { cooldownMs?: number },
): ModelClient {
  if (keys.length !== 1) {
    const clients = keys.map((key) => build(key));
    const cooldownMs = options?.cooldownMs ?? 0;
    return new RotatingModelClient(clients, { cooldownMs });
  }
  // Single key: no wrapper needed.
  return build(keys[0]);
}
/**
 * Resolve the Z.AI credential for a config entry.
 *
 * Sources are tried in order: `cfg.api_key`, `cfg.auth_token`, then
 * `getZaiApiKey()`. A leading `"Bearer "` prefix is stripped from the result.
 *
 * @throws Error when the resolved value is missing or empty.
 */
function resolveZaiCredential(cfg: ModelConfig): string {
  const bearerPrefix = 'Bearer ';
  const credential = cfg.api_key ?? cfg.auth_token ?? getZaiApiKey();
  if (credential) {
    return credential.startsWith(bearerPrefix)
      ? credential.slice(bearerPrefix.length)
      : credential;
  }
  throw new Error(
    'Z.AI credential not configured. ' +
    'Run `flynn zai-auth` or set ZAI_API_KEY / ZHIPUAI_API_KEY / ZHIPUAI_AUTH_TOKEN, ' +
    'or provide api_key/auth_token in config.',
  );
}
/**
 * Build the ModelClient for one provider config entry.
 *
 * The `provider` field selects the client class, so every tier and every
 * fallback entry goes through the same dispatch.
 *
 * Credential handling by provider:
 * - `anthropic` / `openai`: driven by the effective auth mode. `oauth` uses
 *   only the auth token / stored OAuth login, `api_key` uses only API keys,
 *   and `auto` tries API keys first, then the token / OAuth login.
 * - `openrouter`, `xai`, `minimax`, `moonshot`: OpenAI-compatible endpoints
 *   that need at least one key from config or the provider's env variable.
 * - `vercel`, `zhipuai`: OpenAI-compatible endpoints with a single credential.
 * - every other provider passes config fields straight to its client.
 *
 * @throws Error when required credentials are missing or the provider is unknown.
 */
export function createClientFromConfig(cfg: ModelConfig): ModelClient {
  // Options handed to createApiKeyClient wherever a key pool is supported.
  const keyPoolOptions = { cooldownMs: cfg.auth_profile_cooldown_ms };

  // Shared builder for OpenAI-compatible providers that require a key pool.
  const openAiCompatible = (provider: string, envVar: string, defaultBaseUrl: string): ModelClient => {
    const pool = resolveApiKeyPool(cfg, envVar);
    if (pool.length === 0) {
      throw new Error(
        `API key required for ${provider}. Set ${envVar} or provide api_key/api_keys in config.`,
      );
    }
    return createApiKeyClient(pool, (apiKey) => new OpenAIClient({
      model: cfg.model,
      apiKey,
      baseURL: cfg.endpoint ?? defaultBaseUrl,
    }), keyPoolOptions);
  };

  switch (cfg.provider) {
    case 'anthropic': {
      const mode = getEffectiveAuthMode(cfg);

      // api_key and auto both start by looking for API keys (config first, then env/stored key).
      if (mode !== 'oauth') {
        const fromConfig = resolveApiKeyPool(cfg);
        const fromEnv = getAnthropicApiKey();
        const pool = fromConfig.length > 0
          ? fromConfig
          : (fromEnv ? [fromEnv] : []);
        if (pool.length > 0) {
          return createApiKeyClient(pool, (apiKey) => new AnthropicClient({
            model: cfg.model,
            apiKey,
          }), keyPoolOptions);
        }
        if (mode === 'api_key') {
          throw new Error(
            'Anthropic API key not configured (auth_mode: api_key). ' +
            'Set ANTHROPIC_API_KEY, run `flynn anthropic-auth`, or provide api_key/api_keys in config.',
          );
        }
      }

      // oauth, or auto with no API key available: use the auth token.
      const authToken = cfg.auth_token ?? getAnthropicAuthToken();
      if (authToken) {
        return new AnthropicClient({
          model: cfg.model,
          authToken,
        });
      }
      if (mode === 'oauth') {
        throw new Error(
          'Anthropic auth token not configured (auth_mode: oauth). ' +
          'Set ANTHROPIC_AUTH_TOKEN, run `flynn anthropic-auth --token`, or provide auth_token in config.',
        );
      }
      throw new Error(
        'Anthropic credentials not configured (auth_mode: auto). ' +
        'Set ANTHROPIC_API_KEY (or run `flynn anthropic-auth`), ' +
        'or set ANTHROPIC_AUTH_TOKEN (or run `flynn anthropic-auth --token`).',
      );
    }

    case 'openai': {
      const mode = getEffectiveAuthMode(cfg);

      // api_key and auto both start by looking for API keys (config first, then env/stored key).
      if (mode !== 'oauth') {
        const fromConfig = resolveApiKeyPool(cfg);
        const fromEnv = getOpenAIApiKey();
        const pool = fromConfig.length > 0
          ? fromConfig
          : (fromEnv ? [fromEnv] : []);
        if (pool.length > 0) {
          return createApiKeyClient(pool, (apiKey) => new OpenAIClient({
            model: cfg.model,
            apiKey,
          }), keyPoolOptions);
        }
        if (mode === 'api_key') {
          throw new Error(
            'OpenAI API key not configured (auth_mode: api_key). ' +
            'Set OPENAI_API_KEY, run `flynn openai-key`, or provide api_key/api_keys in config.',
          );
        }
      }

      // oauth, or auto with no API key available: require a stored OAuth login.
      if (loadStoredOpenAIAuth()) {
        return new OpenAIClient({
          model: cfg.model,
          useOAuth: true,
        });
      }
      if (mode === 'oauth') {
        throw new Error(
          'OpenAI OAuth is not configured (auth_mode: oauth). ' +
          'Run `flynn openai-auth` to authenticate.',
        );
      }
      throw new Error(
        'OpenAI credentials not configured (auth_mode: auto). ' +
        'Set OPENAI_API_KEY (or run `flynn openai-key`), ' +
        'or run `flynn openai-auth` for OAuth.',
      );
    }

    case 'ollama':
      return new OllamaClient({
        model: cfg.model,
        host: cfg.endpoint,
        numGpu: cfg.num_gpu,
      });

    case 'llamacpp':
      return new LlamaCppClient({
        endpoint: cfg.endpoint ?? 'http://localhost:8080',
        model: cfg.model,
        authToken: cfg.auth_token,
      });

    case 'gemini': {
      // No key is an accepted state here; the client receives `undefined`.
      const geminiKey = cfg.api_key ?? getGeminiApiKey() ?? undefined;
      return new GeminiClient({
        model: cfg.model,
        apiKey: geminiKey,
      });
    }

    case 'openrouter':
      return openAiCompatible('openrouter', 'OPENROUTER_API_KEY', 'https://openrouter.ai/api/v1');

    case 'vercel':
      return new OpenAIClient({
        model: cfg.model,
        apiKey: requireApiKey(cfg, 'AI_GATEWAY_API_KEY'),
        baseURL: cfg.endpoint ?? 'https://ai-gateway.vercel.sh/v1',
      });

    case 'zhipuai':
      return new OpenAIClient({
        model: cfg.model,
        apiKey: resolveZaiCredential(cfg),
        baseURL: cfg.endpoint ?? 'https://api.z.ai/api/paas/v4',
      });

    case 'xai':
      return openAiCompatible('xai', 'XAI_API_KEY', 'https://api.x.ai/v1');

    case 'minimax':
      return openAiCompatible('minimax', 'MINIMAX_API_KEY', 'https://api.minimax.io/v1');

    case 'moonshot':
      return openAiCompatible('moonshot', 'MOONSHOT_API_KEY', 'https://api.moonshot.cn/v1');

    case 'bedrock':
      // Config fields are repurposed: endpoint → region, api_key → access key id,
      // auth_token → secret access key.
      return new BedrockClient({
        model: cfg.model,
        region: cfg.endpoint,
        accessKeyId: cfg.api_key,
        secretAccessKey: cfg.auth_token,
      });

    case 'github':
      return new GitHubModelsClient({
        // Translate Anthropic-style identifiers; anything unmapped is passed through.
        model: anthropicToGitHubModel(cfg.model) ?? cfg.model,
        apiKey: cfg.api_key,
        endpoint: cfg.endpoint,
        // The auth module is imported lazily, only when a login is actually requested.
        onLoginRequired: async () => {
          const { loginGitHub } = await import('../auth/index.js');
          return loginGitHub((userCode, verificationUri) => {
            console.log(`GitHub login required. Visit: ${verificationUri}`);
            console.log(`Enter code: ${userCode}`);
          });
        },
      });

    case 'synthetic':
      return new SyntheticClient({ model: cfg.model });

    default:
      throw new Error(`Unknown model provider: ${(cfg as Record<string, unknown>).provider}`);
  }
}
/**
 * Explicit Anthropic → GitHub Models identifier pairs.
 * Held in a Map so lookups only ever see these entries; a plain object would
 * also resolve inherited keys such as "constructor" or "toString".
 */
const ANTHROPIC_TO_GITHUB_MODELS: ReadonlyMap<string, string> = new Map([
  // Sonnet family
  ['claude-sonnet-4-20250514', 'claude-sonnet-4'],
  ['claude-sonnet-4-5-20250929', 'claude-sonnet-4.5'],
  ['claude-sonnet-4-6-20260217', 'claude-sonnet-4.6'],
  // Opus family
  ['claude-opus-4-20250514', 'claude-opus-4'],
  ['claude-opus-4-5-20250918', 'claude-opus-4.5'],
  ['claude-opus-4-6-20250715', 'claude-opus-4.6'],
  // Haiku family
  ['claude-3-5-haiku-20241022', 'claude-haiku-4.5'],
  ['claude-haiku-4-5-20251001', 'claude-haiku-4.5'],
]);

/**
 * Map an Anthropic model identifier to its GitHub Models equivalent.
 *
 * Anthropic uses hyphens and date suffixes: claude-sonnet-4-5-20250929
 * GitHub Copilot uses dots, no dates:       claude-sonnet-4.5
 *
 * Known identifiers come from the explicit table. Any other identifier that
 * ends in an 8-digit date suffix goes through a generic conversion: the date
 * is stripped and a trailing "<digit>-<digits>" becomes "<digit>.<digits>".
 *
 * @param anthropicModel - Model identifier in Anthropic's naming scheme.
 * @returns The GitHub Models identifier, or undefined when the input is not in
 *          the table and has no date suffix.
 */
export function anthropicToGitHubModel(anthropicModel: string): string | undefined {
  const known = ANTHROPIC_TO_GITHUB_MODELS.get(anthropicModel);
  if (known !== undefined) {
    return known;
  }

  // Generic fallback: strip the date suffix first.
  // e.g. "claude-sonnet-4-7-20260301" → "claude-sonnet-4-7"
  const dateMatch = /^(.+)-\d{8}$/.exec(anthropicModel);
  if (!dateMatch) {
    return undefined;
  }

  // Only a hyphen between a digit and trailing digits becomes a dot:
  // "claude-sonnet-4-7" → "claude-sonnet-4.7", while "claude-sonnet-5" is unchanged.
  return dateMatch[1].replace(/(\d)-(\d+)$/, '$1.$2');
}
/**
 * Create the automatic same-model fallback for a tier.
 *
 * Only tiers whose provider is `anthropic` qualify: the tier's model is
 * translated with `anthropicToGitHubModel` and served through a
 * GitHubModelsClient.
 *
 * @param tierConfig - Provider and model of the tier being protected.
 * @returns The fallback client, or undefined when the tier is not Anthropic
 *          or its model has no GitHub Models equivalent.
 */
export function createAutoFallbackClient(tierConfig: { provider: string; model: string }): ModelClient | undefined {
  const { provider, model } = tierConfig;
  const githubModel = provider === 'anthropic' ? anthropicToGitHubModel(model) : undefined;
  if (!githubModel) {
    return undefined;
  }

  return new GitHubModelsClient({
    model: githubModel,
    // The auth module is imported lazily, only when a login is actually requested.
    onLoginRequired: async () => {
      const { loginGitHub } = await import('../auth/index.js');
      return loginGitHub((userCode, verificationUri) => {
        console.log(`GitHub login required. Visit: ${verificationUri}`);
        console.log(`Enter code: ${userCode}`);
      });
    },
  });
}
/**
 * Assemble the ModelRouter described by the `models` and `retry` config sections.
 *
 * Steps, in order:
 *   1. Build one client per configured tier (default, fast, complex, local).
 *   2. Resolve `fallback_chain` entries to tier clients or to named entries of
 *      `local_providers`; unknown entries are logged and skipped.
 *   3. Build per-tier fallbacks: the tier's inline `fallback` when present,
 *      otherwise the automatic same-model GitHub Models fallback (if any).
 *   4. Translate the retry settings when retries are enabled.
 */
export function createModelRouter(config: Config): ModelRouter {
  const { models, retry } = config;
  const describe = (cfg: { provider: string; model: string }): string => `${cfg.provider}/${cfg.model}`;

  const defaultClient = createClientFromConfig(models.default);
  const fastClient = models.fast ? createClientFromConfig(models.fast) : undefined;
  const complexClient = models.complex ? createClientFromConfig(models.complex) : undefined;
  const localClient = models.local ? createClientFromConfig(models.local) : undefined;

  // Global fallback chain. Each entry names a tier ('default' allows re-trying
  // the default provider) or a key of the local_providers map.
  const tierClients = new Map<string, ModelClient | undefined>([
    ['local', localClient],
    ['default', defaultClient],
    ['fast', fastClient],
    ['complex', complexClient],
  ]);
  const fallbackChain: ModelClient[] = [];
  for (const entry of models.fallback_chain) {
    const tierClient = tierClients.get(entry);
    if (tierClient) {
      fallbackChain.push(tierClient);
      continue;
    }
    // Tier name without a configured client, or not a tier name at all:
    // look for a named provider instead.
    const namedProvider = models.local_providers?.[entry];
    if (namedProvider) {
      fallbackChain.push(createClientFromConfig(namedProvider));
      continue;
    }
    logger.warn(`Fallback chain entry "${entry}" not found — skipping`);
  }

  // Per-tier fallbacks. A tier with an inline `fallback` uses exactly that.
  // A tier without one gets the automatic same-model fallback when it exists,
  // so the same model is tried via GitHub Models before the router degrades
  // to the global fallback chain.
  const tiers: [ModelTier, typeof models.default | undefined][] = [
    ['default', models.default],
    ['fast', models.fast],
    ['complex', models.complex],
    ['local', models.local],
  ];
  const tierFallbacks = new Map<ModelTier, ModelClient[]>();
  const autoFallbackTiers: string[] = [];
  for (const [tier, tierCfg] of tiers) {
    if (!tierCfg) {
      continue;
    }
    if (tierCfg.fallback) {
      tierFallbacks.set(tier, [createClientFromConfig(tierCfg.fallback)]);
      continue;
    }
    const autoClient = createAutoFallbackClient(tierCfg);
    if (autoClient) {
      autoFallbackTiers.push(tier);
      tierFallbacks.set(tier, [autoClient]);
    }
  }

  if (tierFallbacks.size > 0) {
    logger.info(`Per-tier fallbacks configured for: ${[...tierFallbacks.keys()].join(', ')}`);
  }
  if (autoFallbackTiers.length > 0) {
    logger.info(`Auto same-model fallback (via GitHub Models) for: ${autoFallbackTiers.join(', ')}`);
  }
  logger.info(`Model router: default=${describe(models.default)}, ` +
    `fallback=[${models.fallback_chain.join(', ')}]`);

  // Retry policy is passed on only when enabled in config.
  let retryConfig: RetryConfig | undefined;
  if (retry.enabled) {
    retryConfig = {
      maxRetries: retry.max_retries,
      initialDelayMs: retry.initial_delay_ms,
      backoffMultiplier: retry.backoff_multiplier,
      maxDelayMs: retry.max_delay_ms,
      nonRetryablePatterns: DEFAULT_RETRY_CONFIG.nonRetryablePatterns,
    };
    logger.info(`Retry policy: max_retries=${retryConfig.maxRetries}, initial_delay=${retryConfig.initialDelayMs}ms`);
  }

  return new ModelRouter({
    default: defaultClient,
    fast: fastClient,
    complex: complexClient,
    local: localClient,
    fallbackChain,
    tierFallbacks,
    retryConfig,
    // Labels exist only for tiers that are configured.
    labels: {
      default: describe(models.default),
      ...(models.fast ? { fast: describe(models.fast) } : {}),
      ...(models.complex ? { complex: describe(models.complex) } : {}),
      ...(models.local ? { local: describe(models.local) } : {}),
    },
  });
}