diff --git a/docs/plans/state.json b/docs/plans/state.json
index e473ef0..10f6cc3 100644
--- a/docs/plans/state.json
+++ b/docs/plans/state.json
@@ -5559,6 +5559,19 @@
         "docs/plans/state.json"
       ],
       "test_status": "pnpm typecheck + pnpm test:run src/commands/registry.test.ts src/commands/builtin/index.test.ts passing"
+    },
+    "webchat-gateway-model-switch-support": {
+      "status": "completed",
+      "date": "2026-02-19",
+      "updated": "2026-02-19",
+      "summary": "Extended gateway command fast-path to support full /model syntax in WebChat sessions: /model <tier>, /model <tier> <provider/model>, and /model <tier> reset. Gateway now wires ModelRouter/runtime config into agent handlers so provider/model switching and strict-mode toggles match channel behavior.",
+      "files_modified": [
+        "src/gateway/handlers/agent.ts",
+        "src/gateway/server.ts",
+        "src/gateway/handlers/agent.test.ts",
+        "docs/plans/state.json"
+      ],
+      "test_status": "pnpm test:run src/gateway/handlers/agent.test.ts src/gateway/server.test.ts + pnpm typecheck passing"
     }
   },
   "overall_progress": {
diff --git a/src/gateway/handlers/agent.test.ts b/src/gateway/handlers/agent.test.ts
index 4add825..f204f01 100644
--- a/src/gateway/handlers/agent.test.ts
+++ b/src/gateway/handlers/agent.test.ts
@@ -143,6 +143,82 @@ describe('createAgentHandlers command fast-path', () => {
     expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Switched to model tier: fast');
   });
 
+  it('handles /model <tier> <provider/model> in gateway sessions', async () => {
+    const sent: OutboundMessage[] = [];
+    const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
+    const modelRouter = {
+      setClient: vi.fn(),
+      setTierStrict: vi.fn(),
+    };
+    const handlersWithRouter = createAgentHandlers({
+      sessionBridge: sessionBridge as unknown as AgentHandlerDeps['sessionBridge'],
+      laneQueue: new LaneQueue(),
+      sessionManager: sessionManager as unknown as AgentHandlerDeps['sessionManager'],
+      commandRegistry,
+      modelRouter: modelRouter as unknown as AgentHandlerDeps['modelRouter'],
+      runtimeConfig: {
+        models: {
+          default: { provider: 'anthropic', model: 'claude-sonnet-4' },
+          fallback_chain: ['anthropic'],
+        },
+      } as unknown as AgentHandlerDeps['runtimeConfig'],
+    });
+    const req: GatewayRequest = {
+      id: 9,
+      method: 'agent.send',
+      params: {
+        message: '/model default github/gpt-5-mini',
+        connectionId: 'conn-1',
+        metadata: { isCommand: true, command: 'model', commandArgs: 'default github/gpt-5-mini' },
+      },
+    };
+
+    await handlersWithRouter['agent.send'](req, send);
+
+    expect(modelRouter.setClient).toHaveBeenCalledWith('default', expect.anything(), 'github/gpt-5-mini');
+    expect(modelRouter.setTierStrict).toHaveBeenCalledWith('default', true);
+    expect(mockAgent.setModelTier).toHaveBeenCalledWith('default');
+    expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Set default to: github/gpt-5-mini');
+  });
+
+  it('handles /model <tier> reset in gateway sessions', async () => {
+    const sent: OutboundMessage[] = [];
+    const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
+    const modelRouter = {
+      setClient: vi.fn(),
+      setTierStrict: vi.fn(),
+    };
+    const handlersWithRouter = createAgentHandlers({
+      sessionBridge: sessionBridge as unknown as AgentHandlerDeps['sessionBridge'],
+      laneQueue: new LaneQueue(),
+      sessionManager: sessionManager as unknown as AgentHandlerDeps['sessionManager'],
+      commandRegistry,
+      modelRouter: modelRouter as unknown as AgentHandlerDeps['modelRouter'],
+      runtimeConfig: {
+        models: {
+          default: { provider: 'anthropic', model: 'claude-sonnet-4' },
+          fallback_chain: ['anthropic'],
+        },
+      } as unknown as AgentHandlerDeps['runtimeConfig'],
+    });
+    const req: GatewayRequest = {
+      id: 10,
+      method: 'agent.send',
+      params: {
+        message: '/model default reset',
+        connectionId: 'conn-1',
+        metadata: { isCommand: true, command: 'model', commandArgs: 'default reset' },
+      },
+    };
+
+    await handlersWithRouter['agent.send'](req, send);
+
+    expect(modelRouter.setClient).toHaveBeenCalledWith('default', expect.anything(), 'anthropic/claude-sonnet-4');
+    expect(modelRouter.setTierStrict).toHaveBeenCalledWith('default', false);
+    expect(mockAgent.setModelTier).toHaveBeenCalledWith('default');
+    expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Reset default to: anthropic/claude-sonnet-4');
+  });
+
   it('falls through to agent.process for unknown commands', async () => {
     const sent: OutboundMessage[] = [];
     const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
diff --git a/src/gateway/handlers/agent.ts b/src/gateway/handlers/agent.ts
index a273fa3..2bea192 100644
--- a/src/gateway/handlers/agent.ts
+++ b/src/gateway/handlers/agent.ts
@@ -9,7 +9,11 @@ import type { MetricsCollector } from '../metrics.js';
 import type { Attachment } from '../../channels/types.js';
 import type { SessionManager } from '../../session/manager.js';
 import type { ModelTier } from '../../models/router.js';
+import type { ModelRouter } from '../../models/router.js';
 import type { CommandRegistry } from '../../commands/index.js';
+import type { Config, ModelConfig, ModelProvider } from '../../config/index.js';
+import { MODEL_PROVIDERS } from '../../config/index.js';
+import { createClientFromConfig } from '../../daemon/models.js';
 import { auditLogger } from '../../audit/index.js';
 import { randomUUID } from 'crypto';
 
@@ -25,6 +29,33 @@ export interface AgentHandlerDeps {
   metrics?: MetricsCollector;
   sessionManager?: SessionManager;
   commandRegistry?: CommandRegistry;
+  modelRouter?: ModelRouter;
+  runtimeConfig?: Config;
+}
+
+/**
+ * Indexes the configured models (default, fast, complex, local, local_providers,
+ * plus each entry's fallback) by provider so a provider's settings can be reused
+ * as a template. When several entries share a provider, the last one visited wins.
+ */
+function buildProviderConfigMap(config: Config): Partial<Record<ModelProvider, ModelConfig>> {
+  const providerConfigs: Partial<Record<ModelProvider, ModelConfig>> = {};
+  const modelConfigs: ModelConfig[] = [
+    config.models.default,
+    ...(config.models.fast ? [config.models.fast] : []),
+    ...(config.models.complex ? [config.models.complex] : []),
+    ...(config.models.local ? [config.models.local] : []),
+    ...Object.values(config.models.local_providers ?? {}),
+  ];
+
+  for (const modelConfig of modelConfigs) {
+    providerConfigs[modelConfig.provider] = modelConfig;
+    if (modelConfig.fallback) {
+      providerConfigs[modelConfig.fallback.provider] = modelConfig.fallback;
+    }
+  }
+
+  return providerConfigs;
 }
 
 export function createAgentHandlers(deps: AgentHandlerDeps) {
@@ -157,22 +188,101 @@ export function createAgentHandlers(deps: AgentHandlerDeps) {
           setModel: (input) => {
             const raw = input.trim();
             if (!raw) {
-              return 'Usage: /model <tier>';
+              return 'Usage: /model <tier> OR /model <tier> <provider/model> OR /model <tier> reset';
             }
-            const [requestedTier, ...rest] = raw.split(/\s+/);
-            const validTiers: ModelTier[] = ['fast', 'default', 'complex', 'local'];
-            const modelTier = requestedTier as ModelTier;
-            if (!validTiers.includes(modelTier)) {
+            const parts = raw.split(/\s+/);
+            const requestedTier = parts[0];
+            const validTiers: ModelTier[] = ['default', 'fast', 'complex', 'local'];
+            if (!validTiers.includes(requestedTier as ModelTier)) {
               return `Invalid tier: ${requestedTier}. Available: ${validTiers.join(', ')}`;
             }
-            if (rest.length > 0) {
-              return `Switched to model tier: ${modelTier}\nNote: provider/model switching is not available via gateway (/model <tier> <provider/model>).`;
+            const modelTier = requestedTier as ModelTier;
+
+            // /model <tier>
+            if (parts.length === 1) {
+              agent.setModelTier(modelTier);
+              if (sessionId && deps.sessionManager) {
+                deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
+              }
+              return `Switched to model tier: ${modelTier}`;
             }
-            agent.setModelTier(modelTier);
-            if (sessionId && deps.sessionManager) {
-              deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
+
+            if (!deps.modelRouter || !deps.runtimeConfig) {
+              return 'Provider/model switching is unavailable in this gateway runtime.';
+            }
+
+            const arg2 = parts[1];
+            // /model <tier> reset
+            if (arg2.toLowerCase() === 'reset') {
+              const configured: ModelConfig | undefined = modelTier === 'default'
+                ? deps.runtimeConfig.models.default
+                : modelTier === 'fast'
+                  ? deps.runtimeConfig.models.fast
+                  : modelTier === 'complex'
+                    ? deps.runtimeConfig.models.complex
+                    : modelTier === 'local'
+                      ? deps.runtimeConfig.models.local
+                      : undefined;
+              if (!configured) {
+                return `No configured model for tier: ${modelTier}`;
+              }
+
+              const client = createClientFromConfig(configured);
+              const label = `${configured.provider}/${configured.model}`;
+              deps.modelRouter.setClient(modelTier, client, label);
+              deps.modelRouter.setTierStrict(modelTier, false);
+              agent.setModelTier(modelTier);
+              if (sessionId && deps.sessionManager) {
+                deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
+              }
+              return `Reset ${modelTier} to: ${label}`;
+            }
+
+            // /model <tier> <provider/model>
+            const providerModel = arg2;
+            const slashIdx = providerModel.indexOf('/');
+            // Reject a missing, leading, or trailing slash: provider and model must both be non-empty.
+            if (slashIdx <= 0 || slashIdx === providerModel.length - 1) {
+              return 'Invalid format. Use: /model <tier> <provider/model> (e.g. /model default github/gpt-5-mini)';
+            }
+
+            const provider = providerModel.slice(0, slashIdx);
+            const model = providerModel.slice(slashIdx + 1);
+
+            if (!MODEL_PROVIDERS.includes(provider as ModelProvider)) {
+              return `Unknown provider "${provider}". Known providers: ${MODEL_PROVIDERS.join(', ')}`;
+            }
+
+            const providerType = provider as ModelProvider;
+            const providerConfigs = buildProviderConfigMap(deps.runtimeConfig);
+            const template = providerConfigs[providerType];
+
+            try {
+              const client = createClientFromConfig(
+                template
+                  ? { ...template, provider: providerType, model }
+                  : { provider: providerType, model },
+              );
+
+              deps.modelRouter.setClient(modelTier, client, providerModel);
+              deps.modelRouter.setTierStrict(modelTier, true);
+              agent.setModelTier(modelTier);
+              if (sessionId && deps.sessionManager) {
+                deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
+              }
+
+              const lines = [
+                `Set ${modelTier} to: ${providerModel}`,
+                `Fallbacks for ${modelTier} disabled (strict tier mode).`,
+              ];
+              if (parts.length > 2) {
+                lines.push(`Note: ignored extra args: ${parts.slice(2).join(' ')}`);
+              }
+              return lines.join('\n');
+            } catch (error) {
+              const message = error instanceof Error ? error.message : String(error);
+              return `Failed to switch ${modelTier} to ${providerModel}: ${message}`;
             }
-            return `Switched to model tier: ${modelTier}`;
           },
           compact: async () => {
             const result = await agent.compact();
diff --git a/src/gateway/server.ts b/src/gateway/server.ts
index e3b7eea..f1944fd 100644
--- a/src/gateway/server.ts
+++ b/src/gateway/server.ts
@@ -399,6 +399,8 @@ export class GatewayServer {
       metrics: this.metrics,
       sessionManager: this.config.sessionManager,
       commandRegistry: this.config.commandRegistry,
+      modelRouter: 'setClient' in this.config.modelClient ? this.config.modelClient : undefined,
+      runtimeConfig: this.config.config,
     });
 
     const intentHandlers = createIntentHandlers({