feat: add OpenAI OAuth, strict model overrides, and Gmail pull mode
This commit is contained in:
@@ -96,6 +96,34 @@ describe('createClientFromConfig', () => {
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
});
|
||||
|
||||
it('creates OpenAIClient for zhipuai when using auth_token', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'zhipuai',
|
||||
model: 'glm-4.5',
|
||||
auth_token: 'oauth-access-token',
|
||||
});
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
});
|
||||
|
||||
it('creates OpenAIClient for zhipuai using ZHIPUAI_AUTH_TOKEN env var', () => {
|
||||
const prev = process.env.ZHIPUAI_AUTH_TOKEN;
|
||||
process.env.ZHIPUAI_AUTH_TOKEN = 'oauth-access-token';
|
||||
|
||||
try {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'zhipuai',
|
||||
model: 'glm-4.5',
|
||||
});
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
} finally {
|
||||
if (prev === undefined) {
|
||||
delete process.env.ZHIPUAI_AUTH_TOKEN;
|
||||
} else {
|
||||
process.env.ZHIPUAI_AUTH_TOKEN = prev;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it('creates BedrockClient for bedrock provider', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'bedrock',
|
||||
|
||||
+19
-1
@@ -18,6 +18,23 @@ function requireApiKey(cfg: ModelConfig, envVar: string): string {
|
||||
return key;
|
||||
}
|
||||
|
||||
function resolveAuthCredential(cfg: ModelConfig, apiKeyEnvVar: string, authTokenEnvVar?: string): string {
|
||||
const raw = cfg.api_key
|
||||
?? cfg.auth_token
|
||||
?? process.env[apiKeyEnvVar]
|
||||
?? (authTokenEnvVar ? process.env[authTokenEnvVar] : undefined);
|
||||
|
||||
if (!raw) {
|
||||
const envHint = authTokenEnvVar ? `${apiKeyEnvVar} or ${authTokenEnvVar}` : apiKeyEnvVar;
|
||||
throw new Error(
|
||||
`Credential required for ${cfg.provider}. ` +
|
||||
`Set ${envHint} environment variable or provide api_key/auth_token in config.`,
|
||||
);
|
||||
}
|
||||
|
||||
return raw.startsWith('Bearer ') ? raw.slice('Bearer '.length) : raw;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a ModelClient from a provider config entry.
|
||||
* Dispatches on the `provider` field so all tiers and fallback entries
|
||||
@@ -35,6 +52,7 @@ export function createClientFromConfig(cfg: ModelConfig): ModelClient {
|
||||
return new OpenAIClient({
|
||||
model: cfg.model,
|
||||
apiKey: cfg.api_key,
|
||||
useOAuth: Boolean(cfg.use_oauth),
|
||||
});
|
||||
case 'ollama':
|
||||
return new OllamaClient({
|
||||
@@ -62,7 +80,7 @@ export function createClientFromConfig(cfg: ModelConfig): ModelClient {
|
||||
case 'zhipuai':
|
||||
return new OpenAIClient({
|
||||
model: cfg.model,
|
||||
apiKey: requireApiKey(cfg, 'ZHIPUAI_API_KEY'),
|
||||
apiKey: resolveAuthCredential(cfg, 'ZHIPUAI_API_KEY', 'ZHIPUAI_AUTH_TOKEN'),
|
||||
baseURL: cfg.endpoint ?? 'https://api.z.ai/api/paas/v4',
|
||||
});
|
||||
case 'xai':
|
||||
|
||||
+112
-8
@@ -9,7 +9,7 @@ import { MemoryStore } from '../memory/index.js';
|
||||
import type { Tool } from '../tools/types.js';
|
||||
import { createMediaSendTool } from '../tools/index.js';
|
||||
import { createSandboxedShellTool, createSandboxedProcessStartTool, SandboxManager } from '../sandbox/index.js';
|
||||
import type { Config } from '../config/index.js';
|
||||
import { MODEL_PROVIDERS, type Config, type ModelConfig, type ModelProvider } from '../config/index.js';
|
||||
import { ModelRouter, type ModelTier } from '../models/index.js';
|
||||
import { ToolRegistry, ToolExecutor } from '../tools/index.js';
|
||||
import { SessionManager } from '../session/index.js';
|
||||
@@ -17,6 +17,27 @@ import { AgentConfigRegistry, AgentRouter } from '../agents/index.js';
|
||||
import type { CommandRegistry } from '../commands/index.js';
|
||||
import type { ComponentRegistry } from '../intents/index.js';
|
||||
import type { RoutingPolicy } from '../routing/index.js';
|
||||
import { createClientFromConfig } from './models.js';
|
||||
|
||||
function buildProviderConfigMap(config: Config): Partial<Record<ModelProvider, ModelConfig>> {
|
||||
const providerConfigs: Partial<Record<ModelProvider, ModelConfig>> = {};
|
||||
const modelConfigs: ModelConfig[] = [
|
||||
config.models.default,
|
||||
...(config.models.fast ? [config.models.fast] : []),
|
||||
...(config.models.complex ? [config.models.complex] : []),
|
||||
...(config.models.local ? [config.models.local] : []),
|
||||
...Object.values(config.models.local_providers ?? {}),
|
||||
];
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
providerConfigs[modelConfig.provider] = modelConfig;
|
||||
if (modelConfig.fallback) {
|
||||
providerConfigs[modelConfig.fallback.provider] = modelConfig.fallback;
|
||||
}
|
||||
}
|
||||
|
||||
return providerConfigs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the unified message handler for the channel registry.
|
||||
@@ -263,14 +284,97 @@ export function createMessageRouter(deps: {
|
||||
return lines.join('\n');
|
||||
},
|
||||
setModel: (tier) => {
|
||||
const validTiers = deps.modelRouter.getAvailableTiers();
|
||||
if (!validTiers.includes(tier as ModelTier)) {
|
||||
return `Model tier not available: ${tier}`;
|
||||
const raw = tier.trim();
|
||||
if (!raw) {
|
||||
return 'Usage: /model <tier> OR /model <tier> <provider/model> OR /model <tier> reset';
|
||||
}
|
||||
|
||||
const parts = raw.split(/\s+/);
|
||||
const requestedTier = parts[0];
|
||||
|
||||
const validTiers = deps.modelRouter.getAvailableTiers();
|
||||
if (!validTiers.includes(requestedTier as ModelTier)) {
|
||||
return `Model tier not available: ${requestedTier}`;
|
||||
}
|
||||
|
||||
const modelTier = requestedTier as ModelTier;
|
||||
|
||||
// /model <tier>
|
||||
if (parts.length === 1) {
|
||||
session.setConfig('modelTier', modelTier);
|
||||
agent.setModelTier(modelTier);
|
||||
const label = deps.modelRouter.getLabel(modelTier);
|
||||
return `Switched to model: ${modelTier} (${label})`;
|
||||
}
|
||||
|
||||
const arg2 = parts[1];
|
||||
|
||||
// /model <tier> reset — restore configured provider/model and re-enable fallbacks
|
||||
if (arg2.toLowerCase() === 'reset') {
|
||||
const configured: ModelConfig | undefined = modelTier === 'default'
|
||||
? deps.config.models.default
|
||||
: modelTier === 'fast'
|
||||
? deps.config.models.fast
|
||||
: modelTier === 'complex'
|
||||
? deps.config.models.complex
|
||||
: modelTier === 'local'
|
||||
? deps.config.models.local
|
||||
: undefined;
|
||||
if (!configured) {
|
||||
return `No configured model for tier: ${modelTier}`;
|
||||
}
|
||||
|
||||
const client = createClientFromConfig(configured);
|
||||
const label = `${configured.provider}/${configured.model}`;
|
||||
deps.modelRouter.setClient(modelTier, client, label);
|
||||
deps.modelRouter.setTierStrict(modelTier, false);
|
||||
session.setConfig('modelTier', modelTier);
|
||||
agent.setModelTier(modelTier);
|
||||
return `Reset ${modelTier} to: ${label}`;
|
||||
}
|
||||
|
||||
// /model <tier> <provider/model>
|
||||
const providerModel = arg2;
|
||||
if (!providerModel.includes('/')) {
|
||||
return 'Invalid format. Use: /model <tier> <provider/model> (e.g. /model default github/gpt-5-mini)';
|
||||
}
|
||||
|
||||
const slashIdx = providerModel.indexOf('/');
|
||||
const provider = providerModel.slice(0, slashIdx);
|
||||
const model = providerModel.slice(slashIdx + 1);
|
||||
|
||||
if (!MODEL_PROVIDERS.includes(provider as ModelProvider)) {
|
||||
return `Unknown provider "${provider}". Known providers: ${MODEL_PROVIDERS.join(', ')}`;
|
||||
}
|
||||
|
||||
const providerType = provider as ModelProvider;
|
||||
const providerConfigs = buildProviderConfigMap(deps.config);
|
||||
const template = providerConfigs[providerType];
|
||||
|
||||
try {
|
||||
const client = createClientFromConfig(
|
||||
template
|
||||
? { ...template, provider: providerType, model }
|
||||
: { provider: providerType, model },
|
||||
);
|
||||
|
||||
deps.modelRouter.setClient(modelTier, client, providerModel);
|
||||
deps.modelRouter.setTierStrict(modelTier, true);
|
||||
session.setConfig('modelTier', modelTier);
|
||||
agent.setModelTier(modelTier);
|
||||
|
||||
const lines = [
|
||||
`Set ${modelTier} to: ${providerModel}`,
|
||||
`Fallbacks for ${modelTier} disabled (strict tier mode).`,
|
||||
];
|
||||
if (parts.length > 2) {
|
||||
lines.push(`Note: ignored extra args: ${parts.slice(2).join(' ')}`);
|
||||
}
|
||||
return lines.join('\n');
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
return `Failed to switch ${modelTier} to ${providerModel}: ${message}`;
|
||||
}
|
||||
session.setConfig('modelTier', tier);
|
||||
agent.setModelTier(tier as ModelTier);
|
||||
const label = deps.modelRouter.getLabel(tier as ModelTier);
|
||||
return `Switched to model: ${tier} (${label})`;
|
||||
},
|
||||
compact: async () => {
|
||||
const result = await agent.compact();
|
||||
|
||||
Reference in New Issue
Block a user