feat(gateway): support full /model switching in webchat sessions
This commit is contained in:
@@ -5559,6 +5559,19 @@
|
||||
"docs/plans/state.json"
|
||||
],
|
||||
"test_status": "pnpm typecheck + pnpm test:run src/commands/registry.test.ts src/commands/builtin/index.test.ts passing"
|
||||
},
|
||||
"webchat-gateway-model-switch-support": {
|
||||
"status": "completed",
|
||||
"date": "2026-02-19",
|
||||
"updated": "2026-02-19",
|
||||
"summary": "Extended gateway command fast-path to support full /model syntax in WebChat sessions: /model <tier>, /model <tier> <provider/model>, and /model <tier> reset. Gateway now wires ModelRouter/runtime config into agent handlers so provider/model switching and strict-mode toggles match channel behavior.",
|
||||
"files_modified": [
|
||||
"src/gateway/handlers/agent.ts",
|
||||
"src/gateway/server.ts",
|
||||
"src/gateway/handlers/agent.test.ts",
|
||||
"docs/plans/state.json"
|
||||
],
|
||||
"test_status": "pnpm test:run src/gateway/handlers/agent.test.ts src/gateway/server.test.ts + pnpm typecheck passing"
|
||||
}
|
||||
},
|
||||
"overall_progress": {
|
||||
|
||||
@@ -143,6 +143,82 @@ describe('createAgentHandlers command fast-path', () => {
|
||||
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Switched to model tier: fast');
|
||||
});
|
||||
|
||||
it('handles /model <tier> <provider/model> in gateway sessions', async () => {
|
||||
const sent: OutboundMessage[] = [];
|
||||
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||
const modelRouter = {
|
||||
setClient: vi.fn(),
|
||||
setTierStrict: vi.fn(),
|
||||
};
|
||||
const handlersWithRouter = createAgentHandlers({
|
||||
sessionBridge: sessionBridge as unknown as AgentHandlerDeps['sessionBridge'],
|
||||
laneQueue: new LaneQueue(),
|
||||
sessionManager: sessionManager as unknown as AgentHandlerDeps['sessionManager'],
|
||||
commandRegistry,
|
||||
modelRouter: modelRouter as unknown as AgentHandlerDeps['modelRouter'],
|
||||
runtimeConfig: {
|
||||
models: {
|
||||
default: { provider: 'anthropic', model: 'claude-sonnet-4' },
|
||||
fallback_chain: ['anthropic'],
|
||||
},
|
||||
} as unknown as AgentHandlerDeps['runtimeConfig'],
|
||||
});
|
||||
const req: GatewayRequest = {
|
||||
id: 9,
|
||||
method: 'agent.send',
|
||||
params: {
|
||||
message: '/model default github/gpt-5-mini',
|
||||
connectionId: 'conn-1',
|
||||
metadata: { isCommand: true, command: 'model', commandArgs: 'default github/gpt-5-mini' },
|
||||
},
|
||||
};
|
||||
|
||||
await handlersWithRouter['agent.send'](req, send);
|
||||
|
||||
expect(modelRouter.setClient).toHaveBeenCalledWith('default', expect.anything(), 'github/gpt-5-mini');
|
||||
expect(modelRouter.setTierStrict).toHaveBeenCalledWith('default', true);
|
||||
expect(mockAgent.setModelTier).toHaveBeenCalledWith('default');
|
||||
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Set default to: github/gpt-5-mini');
|
||||
});
|
||||
|
||||
it('handles /model <tier> reset in gateway sessions', async () => {
|
||||
const sent: OutboundMessage[] = [];
|
||||
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||
const modelRouter = {
|
||||
setClient: vi.fn(),
|
||||
setTierStrict: vi.fn(),
|
||||
};
|
||||
const handlersWithRouter = createAgentHandlers({
|
||||
sessionBridge: sessionBridge as unknown as AgentHandlerDeps['sessionBridge'],
|
||||
laneQueue: new LaneQueue(),
|
||||
sessionManager: sessionManager as unknown as AgentHandlerDeps['sessionManager'],
|
||||
commandRegistry,
|
||||
modelRouter: modelRouter as unknown as AgentHandlerDeps['modelRouter'],
|
||||
runtimeConfig: {
|
||||
models: {
|
||||
default: { provider: 'anthropic', model: 'claude-sonnet-4' },
|
||||
fallback_chain: ['anthropic'],
|
||||
},
|
||||
} as unknown as AgentHandlerDeps['runtimeConfig'],
|
||||
});
|
||||
const req: GatewayRequest = {
|
||||
id: 10,
|
||||
method: 'agent.send',
|
||||
params: {
|
||||
message: '/model default reset',
|
||||
connectionId: 'conn-1',
|
||||
metadata: { isCommand: true, command: 'model', commandArgs: 'default reset' },
|
||||
},
|
||||
};
|
||||
|
||||
await handlersWithRouter['agent.send'](req, send);
|
||||
|
||||
expect(modelRouter.setClient).toHaveBeenCalledWith('default', expect.anything(), 'anthropic/claude-sonnet-4');
|
||||
expect(modelRouter.setTierStrict).toHaveBeenCalledWith('default', false);
|
||||
expect(mockAgent.setModelTier).toHaveBeenCalledWith('default');
|
||||
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Reset default to: anthropic/claude-sonnet-4');
|
||||
});
|
||||
|
||||
it('falls through to agent.process for unknown commands', async () => {
|
||||
const sent: OutboundMessage[] = [];
|
||||
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||
|
||||
+115
-11
@@ -9,7 +9,11 @@ import type { MetricsCollector } from '../metrics.js';
|
||||
import type { Attachment } from '../../channels/types.js';
|
||||
import type { SessionManager } from '../../session/manager.js';
|
||||
import type { ModelTier } from '../../models/router.js';
|
||||
import type { ModelRouter } from '../../models/router.js';
|
||||
import type { CommandRegistry } from '../../commands/index.js';
|
||||
import type { Config, ModelConfig, ModelProvider } from '../../config/index.js';
|
||||
import { MODEL_PROVIDERS } from '../../config/index.js';
|
||||
import { createClientFromConfig } from '../../daemon/models.js';
|
||||
import { auditLogger } from '../../audit/index.js';
|
||||
import { randomUUID } from 'crypto';
|
||||
|
||||
@@ -25,6 +29,28 @@ export interface AgentHandlerDeps {
|
||||
metrics?: MetricsCollector;
|
||||
sessionManager?: SessionManager;
|
||||
commandRegistry?: CommandRegistry;
|
||||
modelRouter?: ModelRouter;
|
||||
runtimeConfig?: Config;
|
||||
}
|
||||
|
||||
function buildProviderConfigMap(config: Config): Partial<Record<ModelProvider, ModelConfig>> {
|
||||
const providerConfigs: Partial<Record<ModelProvider, ModelConfig>> = {};
|
||||
const modelConfigs: ModelConfig[] = [
|
||||
config.models.default,
|
||||
...(config.models.fast ? [config.models.fast] : []),
|
||||
...(config.models.complex ? [config.models.complex] : []),
|
||||
...(config.models.local ? [config.models.local] : []),
|
||||
...Object.values(config.models.local_providers ?? {}),
|
||||
];
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
providerConfigs[modelConfig.provider] = modelConfig;
|
||||
if (modelConfig.fallback) {
|
||||
providerConfigs[modelConfig.fallback.provider] = modelConfig.fallback;
|
||||
}
|
||||
}
|
||||
|
||||
return providerConfigs;
|
||||
}
|
||||
|
||||
export function createAgentHandlers(deps: AgentHandlerDeps) {
|
||||
@@ -157,22 +183,100 @@ export function createAgentHandlers(deps: AgentHandlerDeps) {
|
||||
setModel: (input) => {
|
||||
const raw = input.trim();
|
||||
if (!raw) {
|
||||
return 'Usage: /model <tier>';
|
||||
return 'Usage: /model <tier> OR /model <tier> <provider/model> OR /model <tier> reset';
|
||||
}
|
||||
const [requestedTier, ...rest] = raw.split(/\s+/);
|
||||
const validTiers: ModelTier[] = ['fast', 'default', 'complex', 'local'];
|
||||
const modelTier = requestedTier as ModelTier;
|
||||
if (!validTiers.includes(modelTier)) {
|
||||
const parts = raw.split(/\s+/);
|
||||
const requestedTier = parts[0];
|
||||
const validTiers: ModelTier[] = ['default', 'fast', 'complex', 'local'];
|
||||
if (!validTiers.includes(requestedTier as ModelTier)) {
|
||||
return `Invalid tier: ${requestedTier}. Available: ${validTiers.join(', ')}`;
|
||||
}
|
||||
if (rest.length > 0) {
|
||||
return `Switched to model tier: ${modelTier}\nNote: provider/model switching is not available via gateway (/model <tier> <provider/model>).`;
|
||||
const modelTier = requestedTier as ModelTier;
|
||||
|
||||
// /model <tier>
|
||||
if (parts.length === 1) {
|
||||
agent.setModelTier(modelTier);
|
||||
if (sessionId && deps.sessionManager) {
|
||||
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
|
||||
}
|
||||
return `Switched to model tier: ${modelTier}`;
|
||||
}
|
||||
agent.setModelTier(modelTier);
|
||||
if (sessionId && deps.sessionManager) {
|
||||
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
|
||||
|
||||
if (!deps.modelRouter || !deps.runtimeConfig) {
|
||||
return 'Provider/model switching is unavailable in this gateway runtime.';
|
||||
}
|
||||
|
||||
const arg2 = parts[1];
|
||||
// /model <tier> reset
|
||||
if (arg2.toLowerCase() === 'reset') {
|
||||
const configured: ModelConfig | undefined = modelTier === 'default'
|
||||
? deps.runtimeConfig.models.default
|
||||
: modelTier === 'fast'
|
||||
? deps.runtimeConfig.models.fast
|
||||
: modelTier === 'complex'
|
||||
? deps.runtimeConfig.models.complex
|
||||
: modelTier === 'local'
|
||||
? deps.runtimeConfig.models.local
|
||||
: undefined;
|
||||
if (!configured) {
|
||||
return `No configured model for tier: ${modelTier}`;
|
||||
}
|
||||
|
||||
const client = createClientFromConfig(configured);
|
||||
const label = `${configured.provider}/${configured.model}`;
|
||||
deps.modelRouter.setClient(modelTier, client, label);
|
||||
deps.modelRouter.setTierStrict(modelTier, false);
|
||||
agent.setModelTier(modelTier);
|
||||
if (sessionId && deps.sessionManager) {
|
||||
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
|
||||
}
|
||||
return `Reset ${modelTier} to: ${label}`;
|
||||
}
|
||||
|
||||
// /model <tier> <provider/model>
|
||||
const providerModel = arg2;
|
||||
if (!providerModel.includes('/')) {
|
||||
return 'Invalid format. Use: /model <tier> <provider/model> (e.g. /model default github/gpt-5-mini)';
|
||||
}
|
||||
|
||||
const slashIdx = providerModel.indexOf('/');
|
||||
const provider = providerModel.slice(0, slashIdx);
|
||||
const model = providerModel.slice(slashIdx + 1);
|
||||
|
||||
if (!MODEL_PROVIDERS.includes(provider as ModelProvider)) {
|
||||
return `Unknown provider "${provider}". Known providers: ${MODEL_PROVIDERS.join(', ')}`;
|
||||
}
|
||||
|
||||
const providerType = provider as ModelProvider;
|
||||
const providerConfigs = buildProviderConfigMap(deps.runtimeConfig);
|
||||
const template = providerConfigs[providerType];
|
||||
|
||||
try {
|
||||
const client = createClientFromConfig(
|
||||
template
|
||||
? { ...template, provider: providerType, model }
|
||||
: { provider: providerType, model },
|
||||
);
|
||||
|
||||
deps.modelRouter.setClient(modelTier, client, providerModel);
|
||||
deps.modelRouter.setTierStrict(modelTier, true);
|
||||
agent.setModelTier(modelTier);
|
||||
if (sessionId && deps.sessionManager) {
|
||||
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
|
||||
}
|
||||
|
||||
const lines = [
|
||||
`Set ${modelTier} to: ${providerModel}`,
|
||||
`Fallbacks for ${modelTier} disabled (strict tier mode).`,
|
||||
];
|
||||
if (parts.length > 2) {
|
||||
lines.push(`Note: ignored extra args: ${parts.slice(2).join(' ')}`);
|
||||
}
|
||||
return lines.join('\n');
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
return `Failed to switch ${modelTier} to ${providerModel}: ${message}`;
|
||||
}
|
||||
return `Switched to model tier: ${modelTier}`;
|
||||
},
|
||||
compact: async () => {
|
||||
const result = await agent.compact();
|
||||
|
||||
@@ -399,6 +399,8 @@ export class GatewayServer {
|
||||
metrics: this.metrics,
|
||||
sessionManager: this.config.sessionManager,
|
||||
commandRegistry: this.config.commandRegistry,
|
||||
modelRouter: 'setClient' in this.config.modelClient ? this.config.modelClient : undefined,
|
||||
runtimeConfig: this.config.config,
|
||||
});
|
||||
|
||||
const intentHandlers = createIntentHandlers({
|
||||
|
||||
Reference in New Issue
Block a user