feat(session): persist model tier overrides per session
Store per-session config in SQLite and route /model and /reset through command fast-paths so channel sessions keep independent model selection across reconnects and restarts.
This commit is contained in:
@@ -86,6 +86,27 @@ describe('createAgentHandlers command fast-path', () => {
|
||||
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Session reset.');
|
||||
});
|
||||
|
||||
it('handles /model command via fast-path and persists session tier', async () => {
|
||||
const sent: OutboundMessage[] = [];
|
||||
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||
const req: GatewayRequest = {
|
||||
id: 4,
|
||||
method: 'agent.send',
|
||||
params: {
|
||||
message: '/model fast',
|
||||
connectionId: 'conn-1',
|
||||
metadata: { isCommand: true, command: 'model', commandArgs: 'fast' },
|
||||
},
|
||||
};
|
||||
|
||||
await handlers['agent.send'](req, send);
|
||||
|
||||
expect(mockAgent.setModelTier).toHaveBeenCalledWith('fast');
|
||||
expect(sessionManager.setSessionConfig).toHaveBeenCalledWith('ws', 'ws:conn-1', 'modelTier', 'fast');
|
||||
expect(mockAgent.process).not.toHaveBeenCalled();
|
||||
expect(((sent[0] as GatewayEvent).data as { content: string }).content).toContain('Switched to model tier: fast');
|
||||
});
|
||||
|
||||
it('falls through to agent.process for unknown commands', async () => {
|
||||
const sent: OutboundMessage[] = [];
|
||||
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||
|
||||
@@ -7,12 +7,14 @@ import type { MetricsCollector } from '../metrics.js';
|
||||
import type { Attachment } from '../../channels/types.js';
|
||||
import type { SessionManager } from '../../session/manager.js';
|
||||
import type { ModelTier } from '../../models/router.js';
|
||||
import type { CommandRegistry } from '../../commands/index.js';
|
||||
|
||||
export interface AgentHandlerDeps {
|
||||
sessionBridge: SessionBridge;
|
||||
laneQueue: LaneQueue;
|
||||
metrics?: MetricsCollector;
|
||||
sessionManager?: SessionManager;
|
||||
commandRegistry?: CommandRegistry;
|
||||
}
|
||||
|
||||
export function createAgentHandlers(deps: AgentHandlerDeps) {
|
||||
@@ -46,59 +48,78 @@ export function createAgentHandlers(deps: AgentHandlerDeps) {
|
||||
return deps.laneQueue.enqueue(laneId, async () => {
|
||||
deps.sessionBridge.setBusy(connectionId, true);
|
||||
|
||||
// Handle slash commands via metadata (mirrors daemon/routing.ts pattern)
|
||||
if (params.metadata?.isCommand) {
|
||||
try {
|
||||
if (params.metadata.command === 'reset') {
|
||||
agent.reset();
|
||||
// Clear session config
|
||||
const sessionId = deps.sessionBridge.getSessionId(connectionId);
|
||||
if (sessionId && deps.sessionManager) {
|
||||
deps.sessionManager.deleteSessionConfig('ws', sessionId, 'modelTier');
|
||||
}
|
||||
send(makeEvent(request.id, 'done', { content: 'Session reset.' }));
|
||||
return;
|
||||
}
|
||||
const commandInput = params.metadata?.isCommand && typeof params.metadata.command === 'string'
|
||||
? `/${params.metadata.command}${params.metadata.commandArgs ? ` ${params.metadata.commandArgs}` : ''}`
|
||||
: params.message;
|
||||
|
||||
if (params.metadata.command === 'model') {
|
||||
const modelArg = params.metadata.commandArgs as string | undefined;
|
||||
const sessionId = deps.sessionBridge.getSessionId(connectionId);
|
||||
if (commandInput && deps.commandRegistry?.isCommand(commandInput)) {
|
||||
const sessionId = deps.sessionBridge.getSessionId(connectionId);
|
||||
const commandResult = await deps.commandRegistry.execute(commandInput, {
|
||||
channel: 'ws',
|
||||
senderId: connectionId,
|
||||
sessionId: sessionId ?? `ws:${connectionId}`,
|
||||
rawInput: commandInput,
|
||||
services: {
|
||||
getStatus: () => `Gateway session active. Current model tier: ${agent.getModelTier()}`,
|
||||
getUsage: () => {
|
||||
const usage = agent.getUsage();
|
||||
const lines = [
|
||||
'**Token Usage**',
|
||||
'',
|
||||
`Primary: ${usage.primary.inputTokens.toLocaleString()} in / ${usage.primary.outputTokens.toLocaleString()} out (${usage.primary.calls} calls)`,
|
||||
];
|
||||
|
||||
if (!modelArg) {
|
||||
// Show current tier info
|
||||
const currentTier = agent.getModelTier();
|
||||
send(makeEvent(request.id, 'done', {
|
||||
content: `Current model tier: ${currentTier}`,
|
||||
}));
|
||||
return;
|
||||
}
|
||||
const delegationEntries = Object.entries(usage.delegation);
|
||||
if (delegationEntries.length > 0) {
|
||||
lines.push('');
|
||||
lines.push('Delegation:');
|
||||
for (const [tier, stats] of delegationEntries) {
|
||||
lines.push(` ${tier}: ${stats.inputTokens.toLocaleString()} in / ${stats.outputTokens.toLocaleString()} out (${stats.calls} calls)`);
|
||||
}
|
||||
}
|
||||
|
||||
// Validate tier
|
||||
const validTiers: ModelTier[] = ['fast', 'default', 'complex', 'local'];
|
||||
const tier = modelArg as ModelTier;
|
||||
if (!validTiers.includes(tier)) {
|
||||
send(makeEvent(request.id, 'done', {
|
||||
content: `Invalid tier: ${modelArg}. Available: ${validTiers.join(', ')}`,
|
||||
}));
|
||||
return;
|
||||
}
|
||||
lines.push('');
|
||||
lines.push(`**Total:** ${usage.total.inputTokens.toLocaleString()} in / ${usage.total.outputTokens.toLocaleString()} out (${usage.total.calls} calls)`);
|
||||
|
||||
// Update agent tier
|
||||
agent.setModelTier(tier);
|
||||
if (usage.total.estimatedCost > 0) {
|
||||
lines.push(`**Estimated cost:** $${usage.total.estimatedCost.toFixed(4)}`);
|
||||
}
|
||||
|
||||
// Persist to session config
|
||||
if (sessionId && deps.sessionManager) {
|
||||
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', tier);
|
||||
}
|
||||
return lines.join('\n');
|
||||
},
|
||||
getModel: () => `Current model tier: ${agent.getModelTier()}`,
|
||||
setModel: (tier) => {
|
||||
const validTiers: ModelTier[] = ['fast', 'default', 'complex', 'local'];
|
||||
const modelTier = tier as ModelTier;
|
||||
if (!validTiers.includes(modelTier)) {
|
||||
return `Invalid tier: ${tier}. Available: ${validTiers.join(', ')}`;
|
||||
}
|
||||
agent.setModelTier(modelTier);
|
||||
if (sessionId && deps.sessionManager) {
|
||||
deps.sessionManager.setSessionConfig('ws', sessionId, 'modelTier', modelTier);
|
||||
}
|
||||
return `Switched to model tier: ${modelTier}`;
|
||||
},
|
||||
compact: async () => {
|
||||
const result = await agent.compact();
|
||||
if (result && result.compactedCount > 0) {
|
||||
return `Compacted ${result.compactedCount} messages: ${result.tokensBefore} → ${result.tokensAfter} tokens`;
|
||||
}
|
||||
return 'Nothing to compact.';
|
||||
},
|
||||
reset: () => {
|
||||
agent.reset();
|
||||
if (sessionId && deps.sessionManager) {
|
||||
deps.sessionManager.deleteSessionConfig('ws', sessionId, 'modelTier');
|
||||
}
|
||||
return 'Session reset.';
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
send(makeEvent(request.id, 'done', {
|
||||
content: `Switched to model tier: ${tier}`,
|
||||
}));
|
||||
return;
|
||||
}
|
||||
} finally {
|
||||
deps.sessionBridge.setBusy(connectionId, false);
|
||||
deps.metrics?.endRequest(requestId);
|
||||
if (commandResult.handled) {
|
||||
send(makeEvent(request.id, 'done', { content: commandResult.text }));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,12 +4,17 @@ import type { TokenUsageEntry } from './system.js';
|
||||
import { createSessionHandlers } from './sessions.js';
|
||||
import { createToolHandlers } from './tools.js';
|
||||
import { createAgentHandlers } from './agent.js';
|
||||
import { createIntentHandlers } from './intents.js';
|
||||
import { createRoutingHandlers } from './routing.js';
|
||||
import { createHistoryHandlers } from './history.js';
|
||||
import { createConfigHandlers, redactConfig } from './config.js';
|
||||
import { createPairingHandlers } from './pairing.js';
|
||||
import { PairingManager } from '../../channels/pairing.js';
|
||||
import { LaneQueue } from '../lane-queue.js';
|
||||
import { ErrorCode } from '../protocol.js';
|
||||
import type { GatewayRequest, GatewayResponse, GatewayError, GatewayEvent, OutboundMessage } from '../protocol.js';
|
||||
import { ComponentRegistry } from '../../intents/index.js';
|
||||
import { RoutingPolicy } from '../../routing/index.js';
|
||||
|
||||
describe('system handlers', () => {
|
||||
const deps = {
|
||||
@@ -402,6 +407,124 @@ describe('agent handlers', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('intent handlers', () => {
|
||||
it('intents.list returns configured rules', async () => {
|
||||
const registry = new ComponentRegistry({ matchThreshold: 0.6 });
|
||||
registry.register({
|
||||
name: 'deploy-route',
|
||||
patterns: ['deploy *'],
|
||||
target: { type: 'agent', name: 'coder' },
|
||||
priority: 5,
|
||||
enabled: true,
|
||||
});
|
||||
|
||||
const handlers = createIntentHandlers({
|
||||
intentRegistry: registry,
|
||||
enabled: true,
|
||||
});
|
||||
|
||||
const req: GatewayRequest = { id: 10, method: 'intents.list' };
|
||||
const result = await handlers['intents.list'](req) as GatewayResponse;
|
||||
const payload = result.result as { enabled: boolean; rules: Array<{ name: string }> };
|
||||
|
||||
expect(payload.enabled).toBe(true);
|
||||
expect(payload.rules).toHaveLength(1);
|
||||
expect(payload.rules[0].name).toBe('deploy-route');
|
||||
});
|
||||
|
||||
it('intents.match returns best rule match', async () => {
|
||||
const registry = new ComponentRegistry({ matchThreshold: 0.5 });
|
||||
registry.register({
|
||||
name: 'deploy-route',
|
||||
patterns: ['deploy *'],
|
||||
target: { type: 'agent', name: 'coder' },
|
||||
priority: 5,
|
||||
enabled: true,
|
||||
});
|
||||
|
||||
const handlers = createIntentHandlers({
|
||||
intentRegistry: registry,
|
||||
enabled: true,
|
||||
});
|
||||
|
||||
const req: GatewayRequest = {
|
||||
id: 11,
|
||||
method: 'intents.match',
|
||||
params: { input: 'deploy backend service' },
|
||||
};
|
||||
const result = await handlers['intents.match'](req) as GatewayResponse;
|
||||
const payload = result.result as { match: { rule: { name: string } } };
|
||||
|
||||
expect(payload.match.rule.name).toBe('deploy-route');
|
||||
});
|
||||
});
|
||||
|
||||
describe('routing handlers', () => {
|
||||
it('routing.decide returns match and policy decision', async () => {
|
||||
const registry = new ComponentRegistry({ matchThreshold: 0.5 });
|
||||
registry.register({
|
||||
name: 'deploy-route',
|
||||
patterns: ['deploy *'],
|
||||
target: { type: 'agent', name: 'coder' },
|
||||
priority: 5,
|
||||
enabled: true,
|
||||
});
|
||||
const policy = new RoutingPolicy({
|
||||
enabled: true,
|
||||
fastPathThreshold: 0.7,
|
||||
llmThreshold: 0.3,
|
||||
defaultPath: 'llm',
|
||||
});
|
||||
const handlers = createRoutingHandlers({
|
||||
intentRegistry: registry,
|
||||
routingPolicy: policy,
|
||||
});
|
||||
|
||||
const req: GatewayRequest = {
|
||||
id: 12,
|
||||
method: 'routing.decide',
|
||||
params: { input: 'deploy service' },
|
||||
};
|
||||
const result = await handlers['routing.decide'](req) as GatewayResponse;
|
||||
const payload = result.result as {
|
||||
match: { rule: { name: string } };
|
||||
decision: { path: string };
|
||||
};
|
||||
|
||||
expect(payload.match.rule.name).toBe('deploy-route');
|
||||
expect(payload.decision.path).toBe('fast');
|
||||
});
|
||||
});
|
||||
|
||||
describe('history handlers', () => {
|
||||
it('history.search returns ranked results', async () => {
|
||||
const handlers = createHistoryHandlers({
|
||||
sessionManager: {
|
||||
searchHistory: () => [{ sessionId: 'ws:test', messageId: 1, role: 'user', content: 'deploy', score: 0.9, createdAt: 123 }],
|
||||
reindexHistory: () => 0,
|
||||
} as any,
|
||||
});
|
||||
|
||||
const req: GatewayRequest = { id: 13, method: 'history.search', params: { query: 'deploy' } };
|
||||
const result = await handlers['history.search'](req) as GatewayResponse;
|
||||
const payload = result.result as { results: Array<{ sessionId: string }> };
|
||||
expect(payload.results[0].sessionId).toBe('ws:test');
|
||||
});
|
||||
|
||||
it('history.reindex returns count', async () => {
|
||||
const handlers = createHistoryHandlers({
|
||||
sessionManager: {
|
||||
searchHistory: () => [],
|
||||
reindexHistory: () => 42,
|
||||
} as any,
|
||||
});
|
||||
|
||||
const req: GatewayRequest = { id: 14, method: 'history.reindex' };
|
||||
const result = await handlers['history.reindex'](req) as GatewayResponse;
|
||||
expect((result.result as { reindexed: number }).reindexed).toBe(42);
|
||||
});
|
||||
});
|
||||
|
||||
describe('system.restart handler', () => {
|
||||
it('returns restarting:true and calls restart callback', async () => {
|
||||
const restartFn = vi.fn(async () => {});
|
||||
|
||||
@@ -10,3 +10,9 @@ export { createConfigHandlers } from './config.js';
|
||||
export type { ConfigHandlerDeps } from './config.js';
|
||||
export { createPairingHandlers } from './pairing.js';
|
||||
export type { PairingHandlerDeps } from './pairing.js';
|
||||
export { createIntentHandlers } from './intents.js';
|
||||
export type { IntentHandlerDeps } from './intents.js';
|
||||
export { createRoutingHandlers } from './routing.js';
|
||||
export type { RoutingHandlerDeps } from './routing.js';
|
||||
export { createHistoryHandlers } from './history.js';
|
||||
export type { HistoryHandlerDeps } from './history.js';
|
||||
|
||||
@@ -23,6 +23,9 @@ import {
|
||||
createAgentHandlers,
|
||||
createConfigHandlers,
|
||||
createPairingHandlers,
|
||||
createIntentHandlers,
|
||||
createRoutingHandlers,
|
||||
createHistoryHandlers,
|
||||
} from './handlers/index.js';
|
||||
import type { TokenUsageEntry } from './handlers/system.js';
|
||||
import type { SessionManager } from '../session/manager.js';
|
||||
@@ -33,6 +36,9 @@ import type { WebhookHandler } from '../automation/webhooks.js';
|
||||
import type { GmailWatcher } from '../automation/gmail.js';
|
||||
import type { PairingManager } from '../channels/pairing.js';
|
||||
import type { MemoryStore } from '../memory/store.js';
|
||||
import type { CommandRegistry } from '../commands/index.js';
|
||||
import type { ComponentRegistry } from '../intents/index.js';
|
||||
import type { RoutingPolicy } from '../routing/index.js';
|
||||
|
||||
export interface GatewayServerConfig {
|
||||
port: number;
|
||||
@@ -62,6 +68,9 @@ export interface GatewayServerConfig {
|
||||
/** Optional pairing manager for DM pairing code management via gateway. */
|
||||
pairingManager?: PairingManager;
|
||||
memoryStore?: MemoryStore;
|
||||
commandRegistry?: CommandRegistry;
|
||||
intentRegistry?: ComponentRegistry;
|
||||
routingPolicy?: RoutingPolicy;
|
||||
}
|
||||
|
||||
export class GatewayServer {
|
||||
@@ -122,6 +131,10 @@ export class GatewayServer {
|
||||
sessionBridge: this.sessionBridge,
|
||||
});
|
||||
|
||||
const historyHandlers = createHistoryHandlers({
|
||||
sessionManager: this.config.sessionManager,
|
||||
});
|
||||
|
||||
const toolHandlers = createToolHandlers({
|
||||
toolRegistry: this.config.toolRegistry,
|
||||
toolExecutor: this.config.toolExecutor,
|
||||
@@ -132,6 +145,17 @@ export class GatewayServer {
|
||||
laneQueue: this.laneQueue,
|
||||
metrics: this.metrics,
|
||||
sessionManager: this.config.sessionManager,
|
||||
commandRegistry: this.config.commandRegistry,
|
||||
});
|
||||
|
||||
const intentHandlers = createIntentHandlers({
|
||||
intentRegistry: this.config.intentRegistry,
|
||||
enabled: this.config.config?.intents.enabled ?? false,
|
||||
});
|
||||
|
||||
const routingHandlers = createRoutingHandlers({
|
||||
intentRegistry: this.config.intentRegistry,
|
||||
routingPolicy: this.config.routingPolicy,
|
||||
});
|
||||
|
||||
// Config handlers (only if config object is provided)
|
||||
@@ -157,12 +181,21 @@ export class GatewayServer {
|
||||
for (const [method, handler] of Object.entries(sessionHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
for (const [method, handler] of Object.entries(historyHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
for (const [method, handler] of Object.entries(toolHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
for (const [method, handler] of Object.entries(agentHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
for (const [method, handler] of Object.entries(intentHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
for (const [method, handler] of Object.entries(routingHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
|
||||
@@ -9,6 +9,9 @@ const mockSession = {
|
||||
getHistory: vi.fn(() => []),
|
||||
clear: vi.fn(),
|
||||
replaceHistory: vi.fn(),
|
||||
getConfig: vi.fn((_key: string) => undefined as string | undefined),
|
||||
setConfig: vi.fn(),
|
||||
deleteConfig: vi.fn(),
|
||||
};
|
||||
|
||||
const mockSessionManager = {
|
||||
@@ -48,9 +51,21 @@ function createBridge(): SessionBridge {
|
||||
});
|
||||
}
|
||||
|
||||
function createBridgeWithConfig(config: SessionBridgeConfig['config']): SessionBridge {
|
||||
return new SessionBridge({
|
||||
sessionManager: mockSessionManager as unknown as SessionBridgeConfig['sessionManager'],
|
||||
modelClient: mockModelClient,
|
||||
systemPrompt: 'test prompt',
|
||||
toolRegistry: mockToolRegistry as unknown as SessionBridgeConfig['toolRegistry'],
|
||||
toolExecutor: mockToolExecutor as unknown as SessionBridgeConfig['toolExecutor'],
|
||||
config,
|
||||
});
|
||||
}
|
||||
|
||||
describe('SessionBridge', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockSession.getConfig.mockImplementation((_key: string) => undefined);
|
||||
});
|
||||
|
||||
it('connect assigns a connection ID', () => {
|
||||
@@ -142,4 +157,78 @@ describe('SessionBridge', () => {
|
||||
expect(bridge.getAgent('conn-2')).toBeDefined();
|
||||
expect(bridge.connectionCount).toBe(1);
|
||||
});
|
||||
|
||||
it('loads model tier from per-session config when creating a session agent', () => {
|
||||
mockSession.getConfig.mockImplementation((key: string) => (key === 'modelTier' ? 'local' : undefined));
|
||||
|
||||
const bridge = createBridgeWithConfig({
|
||||
agents: {
|
||||
primary_tier: 'default',
|
||||
delegation: {
|
||||
compaction: 'fast',
|
||||
memory_extraction: 'fast',
|
||||
classification: 'fast',
|
||||
tool_summarisation: 'fast',
|
||||
complex_reasoning: 'complex',
|
||||
},
|
||||
max_delegation_depth: 3,
|
||||
},
|
||||
compaction: { enabled: false },
|
||||
models: { default: { provider: 'anthropic', model: 'claude-3-haiku' } },
|
||||
} as any);
|
||||
|
||||
bridge.connect('conn-tier');
|
||||
const agent = bridge.getAgent('conn-tier');
|
||||
expect(agent?.getModelTier()).toBe('local');
|
||||
});
|
||||
|
||||
it('keeps different sessions isolated by persisted model tier', () => {
|
||||
const sessionById: Record<string, any> = {};
|
||||
const localSessionManager = {
|
||||
...mockSessionManager,
|
||||
getSession: vi.fn((frontend: string, sessionId: string) => {
|
||||
const fullId = `${frontend}:${sessionId}`;
|
||||
if (!sessionById[fullId]) {
|
||||
const tier = fullId === 'ws:ws:conn-a' ? 'fast' : 'complex';
|
||||
sessionById[fullId] = {
|
||||
...mockSession,
|
||||
id: fullId,
|
||||
getConfig: vi.fn((key: string) => (key === 'modelTier' ? tier : undefined)),
|
||||
};
|
||||
}
|
||||
return sessionById[fullId];
|
||||
}),
|
||||
};
|
||||
|
||||
const bridge = new SessionBridge({
|
||||
sessionManager: localSessionManager as unknown as SessionBridgeConfig['sessionManager'],
|
||||
modelClient: mockModelClient,
|
||||
systemPrompt: 'test prompt',
|
||||
toolRegistry: mockToolRegistry as unknown as SessionBridgeConfig['toolRegistry'],
|
||||
toolExecutor: mockToolExecutor as unknown as SessionBridgeConfig['toolExecutor'],
|
||||
config: {
|
||||
agents: {
|
||||
primary_tier: 'default',
|
||||
delegation: {
|
||||
compaction: 'fast',
|
||||
memory_extraction: 'fast',
|
||||
classification: 'fast',
|
||||
tool_summarisation: 'fast',
|
||||
complex_reasoning: 'complex',
|
||||
},
|
||||
max_delegation_depth: 3,
|
||||
},
|
||||
compaction: { enabled: false },
|
||||
models: { default: { provider: 'anthropic', model: 'claude-3-haiku' } },
|
||||
} as any,
|
||||
});
|
||||
|
||||
bridge.connect('conn-a');
|
||||
bridge.connect('conn-b');
|
||||
const agentA = bridge.getAgent('conn-a');
|
||||
const agentB = bridge.getAgent('conn-b');
|
||||
|
||||
expect(agentA?.getModelTier()).toBe('fast');
|
||||
expect(agentB?.getModelTier()).toBe('complex');
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user