feat(session): persist model tier overrides per session

Store per-session config in SQLite and route /model and /reset through command fast-paths so channel sessions keep independent model selection across reconnects and restarts.
This commit is contained in:
William Valentin
2026-02-13 01:04:26 -08:00
parent 3472a0b926
commit 9f81c01603
35 changed files with 1438 additions and 144 deletions
+21
View File
@@ -86,6 +86,27 @@ describe('createAgentHandlers command fast-path', () => {
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Session reset.');
});
it('handles /model command via fast-path and persists session tier', async () => {
const sent: OutboundMessage[] = [];
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
const req: GatewayRequest = {
id: 4,
method: 'agent.send',
params: {
message: '/model fast',
connectionId: 'conn-1',
metadata: { isCommand: true, command: 'model', commandArgs: 'fast' },
},
};
await handlers['agent.send'](req, send);
expect(mockAgent.setModelTier).toHaveBeenCalledWith('fast');
expect(sessionManager.setSessionConfig).toHaveBeenCalledWith('ws', 'ws:conn-1', 'modelTier', 'fast');
expect(mockAgent.process).not.toHaveBeenCalled();
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Switched to model tier: fast');
});
it('falls through to agent.process for unknown commands', async () => {
const sent: OutboundMessage[] = [];
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
+68 -47
View File
@@ -7,12 +7,14 @@ import type { MetricsCollector } from '../metrics.js';
import type { Attachment } from '../../channels/types.js';
import type { SessionManager } from '../../session/manager.js';
import type { ModelTier } from '../../models/router.js';
import type { CommandRegistry } from '../../commands/index.js';
export interface AgentHandlerDeps {
sessionBridge: SessionBridge;
laneQueue: LaneQueue;
metrics?: MetricsCollector;
sessionManager?: SessionManager;
commandRegistry?: CommandRegistry;
}
export function createAgentHandlers(deps: AgentHandlerDeps) {
@@ -46,59 +48,78 @@ export function createAgentHandlers(deps: AgentHandlerDeps) {
return deps.laneQueue.enqueue(laneId, async () => {
deps.sessionBridge.setBusy(connectionId, true);
// Handle slash commands via metadata (mirrors daemon/routing.ts pattern)
if (params.metadata?.isCommand) {
try {
if (params.metadata.command === 'reset') {
agent.reset();
// Clear session config
const sessionId = deps.sessionBridge.getSessionId(connectionId);
if (sessionId && deps.sessionManager) {
deps.sessionManager.deleteSessionConfig('ws', sessionId, 'modelTier');
}
send(makeEvent(request.id, 'done', { content: 'Session reset.' }));
return;
}
const commandInput = params.metadata?.isCommand && typeof params.metadata.command === 'string'
? `/${params.metadata.command}${params.metadata.commandArgs ? ` ${params.metadata.commandArgs}` : ''}`
: params.message;
if (params.metadata.command === 'model') {
const modelArg = params.metadata.commandArgs as string | undefined;
const sessionId = deps.sessionBridge.getSessionId(connectionId);
if (commandInput && deps.commandRegistry?.isCommand(commandInput)) {
const sessionId = deps.sessionBridge.getSessionId(connectionId);
const commandResult = await deps.commandRegistry.execute(commandInput, {
channel: 'ws',
senderId: connectionId,
sessionId: sessionId ?? `ws:${connectionId}`,
rawInput: commandInput,
services: {
getStatus: () => `Gateway session active. Current model tier: ${agent.getModelTier()}`,
getUsage: () => {
const usage = agent.getUsage();
const lines = [
'**Token Usage**',
'',
`Primary: ${usage.primary.inputTokens.toLocaleString()} in / ${usage.primary.outputTokens.toLocaleString()} out (${usage.primary.calls} calls)`,
];
if (!modelArg) {
// Show current tier info
const currentTier = agent.getModelTier();
send(makeEvent(request.id, 'done', {
content: `Current model tier: ${currentTier}`,
}));
return;
}
const delegationEntries = Object.entries(usage.delegation);
if (delegationEntries.length > 0) {
lines.push('');
lines.push('Delegation:');
for (const [tier, stats] of delegationEntries) {
lines.push(` ${tier}: ${stats.inputTokens.toLocaleString()} in / ${stats.outputTokens.toLocaleString()} out (${stats.calls} calls)`);
}
}
// Validate tier
const validTiers: ModelTier[] = ['fast', 'default', 'complex', 'local'];
const tier = modelArg as ModelTier;
if (!validTiers.includes(tier)) {
send(makeEvent(request.id, 'done', {
content: `Invalid tier: ${modelArg}. Available: ${validTiers.join(', ')}`,
}));
return;
}
lines.push('');
lines.push(`**Total:** ${usage.total.inputTokens.toLocaleString()} in / ${usage.total.outputTokens.toLocaleString()} out (${usage.total.calls} calls)`);
// Update agent tier
agent.setModelTier(tier);
if (usage.total.estimatedCost > 0) {
lines.push(`**Estimated cost:** $${usage.total.estimatedCost.toFixed(4)}`);
}
// Persist to session config
if (sessionId && deps.sessionManager) {
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', tier);
}
return lines.join('\n');
},
getModel: () => `Current model tier: ${agent.getModelTier()}`,
setModel: (tier) => {
const validTiers: ModelTier[] = ['fast', 'default', 'complex', 'local'];
const modelTier = tier as ModelTier;
if (!validTiers.includes(modelTier)) {
return `Invalid tier: ${tier}. Available: ${validTiers.join(', ')}`;
}
agent.setModelTier(modelTier);
if (sessionId && deps.sessionManager) {
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
}
return `Switched to model tier: ${modelTier}`;
},
compact: async () => {
const result = await agent.compact();
if (result && result.compactedCount > 0) {
return `Compacted ${result.compactedCount} messages: ${result.tokensBefore}${result.tokensAfter} tokens`;
}
return 'Nothing to compact.';
},
reset: () => {
agent.reset();
if (sessionId && deps.sessionManager) {
deps.sessionManager.deleteSessionConfig('ws', sessionId, 'modelTier');
}
return 'Session reset.';
},
},
});
send(makeEvent(request.id, 'done', {
content: `Switched to model tier: ${tier}`,
}));
return;
}
} finally {
deps.sessionBridge.setBusy(connectionId, false);
deps.metrics?.endRequest(requestId);
if (commandResult.handled) {
send(makeEvent(request.id, 'done', { content: commandResult.text }));
return;
}
}
+123
View File
@@ -4,12 +4,17 @@ import type { TokenUsageEntry } from './system.js';
import { createSessionHandlers } from './sessions.js';
import { createToolHandlers } from './tools.js';
import { createAgentHandlers } from './agent.js';
import { createIntentHandlers } from './intents.js';
import { createRoutingHandlers } from './routing.js';
import { createHistoryHandlers } from './history.js';
import { createConfigHandlers, redactConfig } from './config.js';
import { createPairingHandlers } from './pairing.js';
import { PairingManager } from '../../channels/pairing.js';
import { LaneQueue } from '../lane-queue.js';
import { ErrorCode } from '../protocol.js';
import type { GatewayRequest, GatewayResponse, GatewayError, GatewayEvent, OutboundMessage } from '../protocol.js';
import { ComponentRegistry } from '../../intents/index.js';
import { RoutingPolicy } from '../../routing/index.js';
describe('system handlers', () => {
const deps = {
@@ -402,6 +407,124 @@ describe('agent handlers', () => {
});
});
describe('intent handlers', () => {
it('intents.list returns configured rules', async () => {
const registry = new ComponentRegistry({ matchThreshold: 0.6 });
registry.register({
name: 'deploy-route',
patterns: ['deploy *'],
target: { type: 'agent', name: 'coder' },
priority: 5,
enabled: true,
});
const handlers = createIntentHandlers({
intentRegistry: registry,
enabled: true,
});
const req: GatewayRequest = { id: 10, method: 'intents.list' };
const result = await handlers['intents.list'](req) as GatewayResponse;
const payload = result.result as { enabled: boolean; rules: Array<{ name: string }> };
expect(payload.enabled).toBe(true);
expect(payload.rules).toHaveLength(1);
expect(payload.rules[0].name).toBe('deploy-route');
});
it('intents.match returns best rule match', async () => {
const registry = new ComponentRegistry({ matchThreshold: 0.5 });
registry.register({
name: 'deploy-route',
patterns: ['deploy *'],
target: { type: 'agent', name: 'coder' },
priority: 5,
enabled: true,
});
const handlers = createIntentHandlers({
intentRegistry: registry,
enabled: true,
});
const req: GatewayRequest = {
id: 11,
method: 'intents.match',
params: { input: 'deploy backend service' },
};
const result = await handlers['intents.match'](req) as GatewayResponse;
const payload = result.result as { match: { rule: { name: string } } };
expect(payload.match.rule.name).toBe('deploy-route');
});
});
describe('routing handlers', () => {
it('routing.decide returns match and policy decision', async () => {
const registry = new ComponentRegistry({ matchThreshold: 0.5 });
registry.register({
name: 'deploy-route',
patterns: ['deploy *'],
target: { type: 'agent', name: 'coder' },
priority: 5,
enabled: true,
});
const policy = new RoutingPolicy({
enabled: true,
fastPathThreshold: 0.7,
llmThreshold: 0.3,
defaultPath: 'llm',
});
const handlers = createRoutingHandlers({
intentRegistry: registry,
routingPolicy: policy,
});
const req: GatewayRequest = {
id: 12,
method: 'routing.decide',
params: { input: 'deploy service' },
};
const result = await handlers['routing.decide'](req) as GatewayResponse;
const payload = result.result as {
match: { rule: { name: string } };
decision: { path: string };
};
expect(payload.match.rule.name).toBe('deploy-route');
expect(payload.decision.path).toBe('fast');
});
});
describe('history handlers', () => {
it('history.search returns ranked results', async () => {
const handlers = createHistoryHandlers({
sessionManager: {
searchHistory: () => [{ sessionId: 'ws:test', messageId: 1, role: 'user', content: 'deploy', score: 0.9, createdAt: 123 }],
reindexHistory: () => 0,
} as any,
});
const req: GatewayRequest = { id: 13, method: 'history.search', params: { query: 'deploy' } };
const result = await handlers['history.search'](req) as GatewayResponse;
const payload = result.result as { results: Array<{ sessionId: string }> };
expect(payload.results[0].sessionId).toBe('ws:test');
});
it('history.reindex returns count', async () => {
const handlers = createHistoryHandlers({
sessionManager: {
searchHistory: () => [],
reindexHistory: () => 42,
} as any,
});
const req: GatewayRequest = { id: 14, method: 'history.reindex' };
const result = await handlers['history.reindex'](req) as GatewayResponse;
expect((result.result as { reindexed: number }).reindexed).toBe(42);
});
});
describe('system.restart handler', () => {
it('returns restarting:true and calls restart callback', async () => {
const restartFn = vi.fn(async () => {});
+6
View File
@@ -10,3 +10,9 @@ export { createConfigHandlers } from './config.js';
export type { ConfigHandlerDeps } from './config.js';
export { createPairingHandlers } from './pairing.js';
export type { PairingHandlerDeps } from './pairing.js';
export { createIntentHandlers } from './intents.js';
export type { IntentHandlerDeps } from './intents.js';
export { createRoutingHandlers } from './routing.js';
export type { RoutingHandlerDeps } from './routing.js';
export { createHistoryHandlers } from './history.js';
export type { HistoryHandlerDeps } from './history.js';
+33
View File
@@ -23,6 +23,9 @@ import {
createAgentHandlers,
createConfigHandlers,
createPairingHandlers,
createIntentHandlers,
createRoutingHandlers,
createHistoryHandlers,
} from './handlers/index.js';
import type { TokenUsageEntry } from './handlers/system.js';
import type { SessionManager } from '../session/manager.js';
@@ -33,6 +36,9 @@ import type { WebhookHandler } from '../automation/webhooks.js';
import type { GmailWatcher } from '../automation/gmail.js';
import type { PairingManager } from '../channels/pairing.js';
import type { MemoryStore } from '../memory/store.js';
import type { CommandRegistry } from '../commands/index.js';
import type { ComponentRegistry } from '../intents/index.js';
import type { RoutingPolicy } from '../routing/index.js';
export interface GatewayServerConfig {
port: number;
@@ -62,6 +68,9 @@ export interface GatewayServerConfig {
/** Optional pairing manager for DM pairing code management via gateway. */
pairingManager?: PairingManager;
memoryStore?: MemoryStore;
commandRegistry?: CommandRegistry;
intentRegistry?: ComponentRegistry;
routingPolicy?: RoutingPolicy;
}
export class GatewayServer {
@@ -122,6 +131,10 @@ export class GatewayServer {
sessionBridge: this.sessionBridge,
});
const historyHandlers = createHistoryHandlers({
sessionManager: this.config.sessionManager,
});
const toolHandlers = createToolHandlers({
toolRegistry: this.config.toolRegistry,
toolExecutor: this.config.toolExecutor,
@@ -132,6 +145,17 @@ export class GatewayServer {
laneQueue: this.laneQueue,
metrics: this.metrics,
sessionManager: this.config.sessionManager,
commandRegistry: this.config.commandRegistry,
});
const intentHandlers = createIntentHandlers({
intentRegistry: this.config.intentRegistry,
enabled: this.config.config?.intents.enabled ?? false,
});
const routingHandlers = createRoutingHandlers({
intentRegistry: this.config.intentRegistry,
routingPolicy: this.config.routingPolicy,
});
// Config handlers (only if config object is provided)
@@ -157,12 +181,21 @@ export class GatewayServer {
for (const [method, handler] of Object.entries(sessionHandlers)) {
this.router.register(method, handler);
}
for (const [method, handler] of Object.entries(historyHandlers)) {
this.router.register(method, handler);
}
for (const [method, handler] of Object.entries(toolHandlers)) {
this.router.register(method, handler);
}
for (const [method, handler] of Object.entries(agentHandlers)) {
this.router.register(method, handler);
}
for (const [method, handler] of Object.entries(intentHandlers)) {
this.router.register(method, handler);
}
for (const [method, handler] of Object.entries(routingHandlers)) {
this.router.register(method, handler);
}
}
async start(): Promise<void> {
+89
View File
@@ -9,6 +9,9 @@ const mockSession = {
getHistory: vi.fn(() => []),
clear: vi.fn(),
replaceHistory: vi.fn(),
getConfig: vi.fn((_key: string) => undefined as string | undefined),
setConfig: vi.fn(),
deleteConfig: vi.fn(),
};
const mockSessionManager = {
@@ -48,9 +51,21 @@ function createBridge(): SessionBridge {
});
}
function createBridgeWithConfig(config: SessionBridgeConfig['config']): SessionBridge {
return new SessionBridge({
sessionManager: mockSessionManager as unknown as SessionBridgeConfig['sessionManager'],
modelClient: mockModelClient,
systemPrompt: 'test prompt',
toolRegistry: mockToolRegistry as unknown as SessionBridgeConfig['toolRegistry'],
toolExecutor: mockToolExecutor as unknown as SessionBridgeConfig['toolExecutor'],
config,
});
}
describe('SessionBridge', () => {
beforeEach(() => {
vi.clearAllMocks();
mockSession.getConfig.mockImplementation((_key: string) => undefined);
});
it('connect assigns a connection ID', () => {
@@ -142,4 +157,78 @@ describe('SessionBridge', () => {
expect(bridge.getAgent('conn-2')).toBeDefined();
expect(bridge.connectionCount).toBe(1);
});
it('loads model tier from per-session config when creating a session agent', () => {
mockSession.getConfig.mockImplementation((key: string) => (key === 'modelTier' ? 'local' : undefined));
const bridge = createBridgeWithConfig({
agents: {
primary_tier: 'default',
delegation: {
compaction: 'fast',
memory_extraction: 'fast',
classification: 'fast',
tool_summarisation: 'fast',
complex_reasoning: 'complex',
},
max_delegation_depth: 3,
},
compaction: { enabled: false },
models: { default: { provider: 'anthropic', model: 'claude-3-haiku' } },
} as any);
bridge.connect('conn-tier');
const agent = bridge.getAgent('conn-tier');
expect(agent?.getModelTier()).toBe('local');
});
it('keeps different sessions isolated by persisted model tier', () => {
const sessionById: Record<string, any> = {};
const localSessionManager = {
...mockSessionManager,
getSession: vi.fn((frontend: string, sessionId: string) => {
const fullId = `${frontend}:${sessionId}`;
if (!sessionById[fullId]) {
const tier = fullId === 'ws:ws:conn-a' ? 'fast' : 'complex';
sessionById[fullId] = {
...mockSession,
id: fullId,
getConfig: vi.fn((key: string) => (key === 'modelTier' ? tier : undefined)),
};
}
return sessionById[fullId];
}),
};
const bridge = new SessionBridge({
sessionManager: localSessionManager as unknown as SessionBridgeConfig['sessionManager'],
modelClient: mockModelClient,
systemPrompt: 'test prompt',
toolRegistry: mockToolRegistry as unknown as SessionBridgeConfig['toolRegistry'],
toolExecutor: mockToolExecutor as unknown as SessionBridgeConfig['toolExecutor'],
config: {
agents: {
primary_tier: 'default',
delegation: {
compaction: 'fast',
memory_extraction: 'fast',
classification: 'fast',
tool_summarisation: 'fast',
complex_reasoning: 'complex',
},
max_delegation_depth: 3,
},
compaction: { enabled: false },
models: { default: { provider: 'anthropic', model: 'claude-3-haiku' } },
} as any,
});
bridge.connect('conn-a');
bridge.connect('conn-b');
const agentA = bridge.getAgent('conn-a');
const agentB = bridge.getAgent('conn-b');
expect(agentA?.getModelTier()).toBe('fast');
expect(agentB?.getModelTier()).toBe('complex');
});
});