feat(gateway): support full /model switching in webchat sessions
This commit is contained in:
@@ -5559,6 +5559,19 @@
|
|||||||
"docs/plans/state.json"
|
"docs/plans/state.json"
|
||||||
],
|
],
|
||||||
"test_status": "pnpm typecheck + pnpm test:run src/commands/registry.test.ts src/commands/builtin/index.test.ts passing"
|
"test_status": "pnpm typecheck + pnpm test:run src/commands/registry.test.ts src/commands/builtin/index.test.ts passing"
|
||||||
|
},
|
||||||
|
"webchat-gateway-model-switch-support": {
|
||||||
|
"status": "completed",
|
||||||
|
"date": "2026-02-19",
|
||||||
|
"updated": "2026-02-19",
|
||||||
|
"summary": "Extended gateway command fast-path to support full /model syntax in WebChat sessions: /model <tier>, /model <tier> <provider/model>, and /model <tier> reset. Gateway now wires ModelRouter/runtime config into agent handlers so provider/model switching and strict-mode toggles match channel behavior.",
|
||||||
|
"files_modified": [
|
||||||
|
"src/gateway/handlers/agent.ts",
|
||||||
|
"src/gateway/server.ts",
|
||||||
|
"src/gateway/handlers/agent.test.ts",
|
||||||
|
"docs/plans/state.json"
|
||||||
|
],
|
||||||
|
"test_status": "pnpm test:run src/gateway/handlers/agent.test.ts src/gateway/server.test.ts + pnpm typecheck passing"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"overall_progress": {
|
"overall_progress": {
|
||||||
|
|||||||
@@ -143,6 +143,82 @@ describe('createAgentHandlers command fast-path', () => {
|
|||||||
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Switched to model tier: fast');
|
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Switched to model tier: fast');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('handles /model <tier> <provider/model> in gateway sessions', async () => {
|
||||||
|
const sent: OutboundMessage[] = [];
|
||||||
|
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||||
|
const modelRouter = {
|
||||||
|
setClient: vi.fn(),
|
||||||
|
setTierStrict: vi.fn(),
|
||||||
|
};
|
||||||
|
const handlersWithRouter = createAgentHandlers({
|
||||||
|
sessionBridge: sessionBridge as unknown as AgentHandlerDeps['sessionBridge'],
|
||||||
|
laneQueue: new LaneQueue(),
|
||||||
|
sessionManager: sessionManager as unknown as AgentHandlerDeps['sessionManager'],
|
||||||
|
commandRegistry,
|
||||||
|
modelRouter: modelRouter as unknown as AgentHandlerDeps['modelRouter'],
|
||||||
|
runtimeConfig: {
|
||||||
|
models: {
|
||||||
|
default: { provider: 'anthropic', model: 'claude-sonnet-4' },
|
||||||
|
fallback_chain: ['anthropic'],
|
||||||
|
},
|
||||||
|
} as unknown as AgentHandlerDeps['runtimeConfig'],
|
||||||
|
});
|
||||||
|
const req: GatewayRequest = {
|
||||||
|
id: 9,
|
||||||
|
method: 'agent.send',
|
||||||
|
params: {
|
||||||
|
message: '/model default github/gpt-5-mini',
|
||||||
|
connectionId: 'conn-1',
|
||||||
|
metadata: { isCommand: true, command: 'model', commandArgs: 'default github/gpt-5-mini' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
await handlersWithRouter['agent.send'](req, send);
|
||||||
|
|
||||||
|
expect(modelRouter.setClient).toHaveBeenCalledWith('default', expect.anything(), 'github/gpt-5-mini');
|
||||||
|
expect(modelRouter.setTierStrict).toHaveBeenCalledWith('default', true);
|
||||||
|
expect(mockAgent.setModelTier).toHaveBeenCalledWith('default');
|
||||||
|
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Set default to: github/gpt-5-mini');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('handles /model <tier> reset in gateway sessions', async () => {
|
||||||
|
const sent: OutboundMessage[] = [];
|
||||||
|
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||||
|
const modelRouter = {
|
||||||
|
setClient: vi.fn(),
|
||||||
|
setTierStrict: vi.fn(),
|
||||||
|
};
|
||||||
|
const handlersWithRouter = createAgentHandlers({
|
||||||
|
sessionBridge: sessionBridge as unknown as AgentHandlerDeps['sessionBridge'],
|
||||||
|
laneQueue: new LaneQueue(),
|
||||||
|
sessionManager: sessionManager as unknown as AgentHandlerDeps['sessionManager'],
|
||||||
|
commandRegistry,
|
||||||
|
modelRouter: modelRouter as unknown as AgentHandlerDeps['modelRouter'],
|
||||||
|
runtimeConfig: {
|
||||||
|
models: {
|
||||||
|
default: { provider: 'anthropic', model: 'claude-sonnet-4' },
|
||||||
|
fallback_chain: ['anthropic'],
|
||||||
|
},
|
||||||
|
} as unknown as AgentHandlerDeps['runtimeConfig'],
|
||||||
|
});
|
||||||
|
const req: GatewayRequest = {
|
||||||
|
id: 10,
|
||||||
|
method: 'agent.send',
|
||||||
|
params: {
|
||||||
|
message: '/model default reset',
|
||||||
|
connectionId: 'conn-1',
|
||||||
|
metadata: { isCommand: true, command: 'model', commandArgs: 'default reset' },
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
await handlersWithRouter['agent.send'](req, send);
|
||||||
|
|
||||||
|
expect(modelRouter.setClient).toHaveBeenCalledWith('default', expect.anything(), 'anthropic/claude-sonnet-4');
|
||||||
|
expect(modelRouter.setTierStrict).toHaveBeenCalledWith('default', false);
|
||||||
|
expect(mockAgent.setModelTier).toHaveBeenCalledWith('default');
|
||||||
|
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Reset default to: anthropic/claude-sonnet-4');
|
||||||
|
});
|
||||||
|
|
||||||
it('falls through to agent.process for unknown commands', async () => {
|
it('falls through to agent.process for unknown commands', async () => {
|
||||||
const sent: OutboundMessage[] = [];
|
const sent: OutboundMessage[] = [];
|
||||||
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||||
|
|||||||
+115
-11
@@ -9,7 +9,11 @@ import type { MetricsCollector } from '../metrics.js';
|
|||||||
import type { Attachment } from '../../channels/types.js';
|
import type { Attachment } from '../../channels/types.js';
|
||||||
import type { SessionManager } from '../../session/manager.js';
|
import type { SessionManager } from '../../session/manager.js';
|
||||||
import type { ModelTier } from '../../models/router.js';
|
import type { ModelTier } from '../../models/router.js';
|
||||||
|
import type { ModelRouter } from '../../models/router.js';
|
||||||
import type { CommandRegistry } from '../../commands/index.js';
|
import type { CommandRegistry } from '../../commands/index.js';
|
||||||
|
import type { Config, ModelConfig, ModelProvider } from '../../config/index.js';
|
||||||
|
import { MODEL_PROVIDERS } from '../../config/index.js';
|
||||||
|
import { createClientFromConfig } from '../../daemon/models.js';
|
||||||
import { auditLogger } from '../../audit/index.js';
|
import { auditLogger } from '../../audit/index.js';
|
||||||
import { randomUUID } from 'crypto';
|
import { randomUUID } from 'crypto';
|
||||||
|
|
||||||
@@ -25,6 +29,28 @@ export interface AgentHandlerDeps {
|
|||||||
metrics?: MetricsCollector;
|
metrics?: MetricsCollector;
|
||||||
sessionManager?: SessionManager;
|
sessionManager?: SessionManager;
|
||||||
commandRegistry?: CommandRegistry;
|
commandRegistry?: CommandRegistry;
|
||||||
|
modelRouter?: ModelRouter;
|
||||||
|
runtimeConfig?: Config;
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildProviderConfigMap(config: Config): Partial<Record<ModelProvider, ModelConfig>> {
|
||||||
|
const providerConfigs: Partial<Record<ModelProvider, ModelConfig>> = {};
|
||||||
|
const modelConfigs: ModelConfig[] = [
|
||||||
|
config.models.default,
|
||||||
|
...(config.models.fast ? [config.models.fast] : []),
|
||||||
|
...(config.models.complex ? [config.models.complex] : []),
|
||||||
|
...(config.models.local ? [config.models.local] : []),
|
||||||
|
...Object.values(config.models.local_providers ?? {}),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (const modelConfig of modelConfigs) {
|
||||||
|
providerConfigs[modelConfig.provider] = modelConfig;
|
||||||
|
if (modelConfig.fallback) {
|
||||||
|
providerConfigs[modelConfig.fallback.provider] = modelConfig.fallback;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return providerConfigs;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function createAgentHandlers(deps: AgentHandlerDeps) {
|
export function createAgentHandlers(deps: AgentHandlerDeps) {
|
||||||
@@ -157,22 +183,100 @@ export function createAgentHandlers(deps: AgentHandlerDeps) {
|
|||||||
setModel: (input) => {
|
setModel: (input) => {
|
||||||
const raw = input.trim();
|
const raw = input.trim();
|
||||||
if (!raw) {
|
if (!raw) {
|
||||||
return 'Usage: /model <tier>';
|
return 'Usage: /model <tier> OR /model <tier> <provider/model> OR /model <tier> reset';
|
||||||
}
|
}
|
||||||
const [requestedTier, ...rest] = raw.split(/\s+/);
|
const parts = raw.split(/\s+/);
|
||||||
const validTiers: ModelTier[] = ['fast', 'default', 'complex', 'local'];
|
const requestedTier = parts[0];
|
||||||
const modelTier = requestedTier as ModelTier;
|
const validTiers: ModelTier[] = ['default', 'fast', 'complex', 'local'];
|
||||||
if (!validTiers.includes(modelTier)) {
|
if (!validTiers.includes(requestedTier as ModelTier)) {
|
||||||
return `Invalid tier: ${requestedTier}. Available: ${validTiers.join(', ')}`;
|
return `Invalid tier: ${requestedTier}. Available: ${validTiers.join(', ')}`;
|
||||||
}
|
}
|
||||||
if (rest.length > 0) {
|
const modelTier = requestedTier as ModelTier;
|
||||||
return `Switched to model tier: ${modelTier}\nNote: provider/model switching is not available via gateway (/model <tier> <provider/model>).`;
|
|
||||||
|
// /model <tier>
|
||||||
|
if (parts.length === 1) {
|
||||||
|
agent.setModelTier(modelTier);
|
||||||
|
if (sessionId && deps.sessionManager) {
|
||||||
|
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
|
||||||
|
}
|
||||||
|
return `Switched to model tier: ${modelTier}`;
|
||||||
}
|
}
|
||||||
agent.setModelTier(modelTier);
|
|
||||||
if (sessionId && deps.sessionManager) {
|
if (!deps.modelRouter || !deps.runtimeConfig) {
|
||||||
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
|
return 'Provider/model switching is unavailable in this gateway runtime.';
|
||||||
|
}
|
||||||
|
|
||||||
|
const arg2 = parts[1];
|
||||||
|
// /model <tier> reset
|
||||||
|
if (arg2.toLowerCase() === 'reset') {
|
||||||
|
const configured: ModelConfig | undefined = modelTier === 'default'
|
||||||
|
? deps.runtimeConfig.models.default
|
||||||
|
: modelTier === 'fast'
|
||||||
|
? deps.runtimeConfig.models.fast
|
||||||
|
: modelTier === 'complex'
|
||||||
|
? deps.runtimeConfig.models.complex
|
||||||
|
: modelTier === 'local'
|
||||||
|
? deps.runtimeConfig.models.local
|
||||||
|
: undefined;
|
||||||
|
if (!configured) {
|
||||||
|
return `No configured model for tier: ${modelTier}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const client = createClientFromConfig(configured);
|
||||||
|
const label = `${configured.provider}/${configured.model}`;
|
||||||
|
deps.modelRouter.setClient(modelTier, client, label);
|
||||||
|
deps.modelRouter.setTierStrict(modelTier, false);
|
||||||
|
agent.setModelTier(modelTier);
|
||||||
|
if (sessionId && deps.sessionManager) {
|
||||||
|
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
|
||||||
|
}
|
||||||
|
return `Reset ${modelTier} to: ${label}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// /model <tier> <provider/model>
|
||||||
|
const providerModel = arg2;
|
||||||
|
if (!providerModel.includes('/')) {
|
||||||
|
return 'Invalid format. Use: /model <tier> <provider/model> (e.g. /model default github/gpt-5-mini)';
|
||||||
|
}
|
||||||
|
|
||||||
|
const slashIdx = providerModel.indexOf('/');
|
||||||
|
const provider = providerModel.slice(0, slashIdx);
|
||||||
|
const model = providerModel.slice(slashIdx + 1);
|
||||||
|
|
||||||
|
if (!MODEL_PROVIDERS.includes(provider as ModelProvider)) {
|
||||||
|
return `Unknown provider "${provider}". Known providers: ${MODEL_PROVIDERS.join(', ')}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const providerType = provider as ModelProvider;
|
||||||
|
const providerConfigs = buildProviderConfigMap(deps.runtimeConfig);
|
||||||
|
const template = providerConfigs[providerType];
|
||||||
|
|
||||||
|
try {
|
||||||
|
const client = createClientFromConfig(
|
||||||
|
template
|
||||||
|
? { ...template, provider: providerType, model }
|
||||||
|
: { provider: providerType, model },
|
||||||
|
);
|
||||||
|
|
||||||
|
deps.modelRouter.setClient(modelTier, client, providerModel);
|
||||||
|
deps.modelRouter.setTierStrict(modelTier, true);
|
||||||
|
agent.setModelTier(modelTier);
|
||||||
|
if (sessionId && deps.sessionManager) {
|
||||||
|
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
|
||||||
|
}
|
||||||
|
|
||||||
|
const lines = [
|
||||||
|
`Set ${modelTier} to: ${providerModel}`,
|
||||||
|
`Fallbacks for ${modelTier} disabled (strict tier mode).`,
|
||||||
|
];
|
||||||
|
if (parts.length > 2) {
|
||||||
|
lines.push(`Note: ignored extra args: ${parts.slice(2).join(' ')}`);
|
||||||
|
}
|
||||||
|
return lines.join('\n');
|
||||||
|
} catch (error) {
|
||||||
|
const message = error instanceof Error ? error.message : String(error);
|
||||||
|
return `Failed to switch ${modelTier} to ${providerModel}: ${message}`;
|
||||||
}
|
}
|
||||||
return `Switched to model tier: ${modelTier}`;
|
|
||||||
},
|
},
|
||||||
compact: async () => {
|
compact: async () => {
|
||||||
const result = await agent.compact();
|
const result = await agent.compact();
|
||||||
|
|||||||
@@ -399,6 +399,8 @@ export class GatewayServer {
|
|||||||
metrics: this.metrics,
|
metrics: this.metrics,
|
||||||
sessionManager: this.config.sessionManager,
|
sessionManager: this.config.sessionManager,
|
||||||
commandRegistry: this.config.commandRegistry,
|
commandRegistry: this.config.commandRegistry,
|
||||||
|
modelRouter: 'setClient' in this.config.modelClient ? this.config.modelClient : undefined,
|
||||||
|
runtimeConfig: this.config.config,
|
||||||
});
|
});
|
||||||
|
|
||||||
const intentHandlers = createIntentHandlers({
|
const intentHandlers = createIntentHandlers({
|
||||||
|
|||||||
Reference in New Issue
Block a user