345 lines
16 KiB
TypeScript
345 lines
16 KiB
TypeScript
import type { GatewayRequest, GatewayAttachment, OutboundMessage } from '../protocol.js';
|
|
import type { SendFn } from '../router.js';
|
|
import { makeEvent, makeError, ErrorCode } from '../protocol.js';
|
|
import type { SessionBridge } from '../session-bridge.js';
|
|
import type { LaneQueue } from '../lane-queue.js';
|
|
import type { MetricsCollector } from '../metrics.js';
|
|
import type { Attachment } from '../../channels/types.js';
|
|
import type { SessionManager } from '../../session/manager.js';
|
|
import type { ModelTier } from '../../models/router.js';
|
|
import type { CommandRegistry } from '../../commands/index.js';
|
|
import { auditLogger } from '../../audit/index.js';
|
|
import { randomUUID } from 'crypto';
|
|
|
|
/**
 * Collaborators consumed by `createAgentHandlers`.
 *
 * `sessionBridge` and `laneQueue` are mandatory. The remaining services are optional:
 * the handlers reach them through optional chaining or explicit presence checks, so the
 * handlers still run (with reduced functionality) when they are omitted.
 */
export interface AgentHandlerDeps {
  /** Resolves a connection's agent and session id; tracks busy state, tool callbacks and cancellation. */
  sessionBridge: SessionBridge;

  /** Per-lane work queue used to serialise `agent.send` requests. */
  laneQueue: LaneQueue;

  /** Optional sink for request timing, message counts and error events. */
  metrics?: MetricsCollector;

  /** Optional session store used to persist the model tier and elevation settings. */
  sessionManager?: SessionManager;

  /** Optional slash-command registry; when absent, every message is forwarded to the agent. */
  commandRegistry?: CommandRegistry;
}
|
|
|
|
/** Shape the client is expected to send for `agent.send` (only lightly validated below). */
interface AgentSendParams {
  message?: string;
  connectionId?: string;
  attachments?: GatewayAttachment[];
  metadata?: { isCommand?: boolean; command?: string; commandArgs?: string };
}

/** Model tiers that `/model <tier>` accepts over the gateway. */
const VALID_MODEL_TIERS: readonly ModelTier[] = ['fast', 'default', 'complex', 'local'];

/** Session-config keys (channel `ws`) under which an active elevation grant is stored. */
const ELEVATION_UNTIL_KEY = 'elevation.until_ms';
const ELEVATION_REASON_KEY = 'elevation.reason';
const ELEVATION_ID_KEY = 'elevation.id';

/** Largest millisecond timestamp a JavaScript `Date` can hold; `toISOString()` throws beyond it. */
const MAX_DATE_MS = 8.64e15;

/** Milliseconds per unit suffix accepted by `/elevate <duration>`. */
const DURATION_UNIT_MS: Record<string, number> = {
  s: 1_000,
  m: 60_000,
  h: 3_600_000,
  d: 86_400_000,
};

/** Type guard: is `value` one of the tiers selectable through the gateway? */
function isModelTier(value: string): value is ModelTier {
  return (VALID_MODEL_TIERS as readonly string[]).includes(value);
}

/**
 * Parse an elevation duration such as `30s`, `10m`, `1h` or `1d` (unit is case-insensitive).
 *
 * @returns the duration in milliseconds, or `null` if the input is malformed or zero.
 */
function parseElevationDuration(dur: string): number | null {
  const match = /^(\d+)([smhd])$/i.exec(dur);
  if (!match) {
    return null;
  }
  const amount = Number.parseInt(match[1], 10);
  const unitMs = DURATION_UNIT_MS[match[2].toLowerCase()];
  if (!Number.isFinite(amount) || amount <= 0 || unitMs === undefined) {
    return null;
  }
  return amount * unitMs;
}

/** Delete every persisted elevation key for a `ws` session. */
function clearElevation(sessionManager: SessionManager, sessionId: string): void {
  sessionManager.deleteSessionConfig('ws', sessionId, ELEVATION_UNTIL_KEY);
  sessionManager.deleteSessionConfig('ws', sessionId, ELEVATION_REASON_KEY);
  sessionManager.deleteSessionConfig('ws', sessionId, ELEVATION_ID_KEY);
}

/**
 * Report the elevation state of a `ws` session for `/elevate` status queries.
 * An expired grant is removed from session config and an expiry audit event is emitted.
 */
function describeElevation(sessionManager: SessionManager, sessionId: string, sender: string): string {
  const untilRaw = sessionManager.getSessionConfig('ws', sessionId, ELEVATION_UNTIL_KEY);
  const reason = sessionManager.getSessionConfig('ws', sessionId, ELEVATION_REASON_KEY) ?? '';
  const id = sessionManager.getSessionConfig('ws', sessionId, ELEVATION_ID_KEY) ?? '';
  if (!untilRaw || !id) {
    return 'Elevated mode: off';
  }

  const untilMs = Number.parseInt(untilRaw, 10);
  if (!Number.isFinite(untilMs)) {
    return 'Elevated mode: off';
  }

  const now = Date.now();
  if (untilMs <= now) {
    clearElevation(sessionManager, sessionId);
    auditLogger?.securityElevationExpired({
      session_id: `ws:${sessionId}`,
      channel: 'ws',
      sender,
      elevation_id: id,
      until_ms: untilMs,
      reason: reason || undefined,
    });
    return 'Elevated mode: off (expired)';
  }

  const remainingSec = Math.ceil((untilMs - now) / 1000);
  return `Elevated mode: on (${remainingSec}s remaining)${reason ? ` — ${reason}` : ''}`;
}

/**
 * Handle `/elevate <duration> <reason...> --yes` and `/elevate off --yes` for a `ws` session.
 * Both forms require `--yes` (or `--confirm`) and emit an audit event when they take effect.
 *
 * @returns the user-facing reply text.
 */
function applyElevationCommand(
  sessionManager: SessionManager,
  sessionId: string,
  sender: string,
  input: string,
): string {
  const raw = input.trim();
  // ''.split(/\s+/) yields [''], so empty input must produce no tokens to reach the usage hint.
  const parts = raw ? raw.split(/\s+/) : [];
  const hasYes = parts.includes('--yes') || parts.includes('--confirm');
  const args = parts.filter(p => p !== '--yes' && p !== '--confirm');

  if (args.length === 0) {
    return 'Usage: /elevate <duration> <reason...> --yes | /elevate off --yes';
  }

  if (args[0] === 'off') {
    if (!hasYes) {
      return 'Refusing to disable elevation without explicit confirmation. Use: /elevate off --yes';
    }
    const existingId = sessionManager.getSessionConfig('ws', sessionId, ELEVATION_ID_KEY) ?? randomUUID();
    const existingUntil = sessionManager.getSessionConfig('ws', sessionId, ELEVATION_UNTIL_KEY);
    const existingReason = sessionManager.getSessionConfig('ws', sessionId, ELEVATION_REASON_KEY) ?? '';
    const parsedUntil = existingUntil ? Number.parseInt(existingUntil, 10) : Number.NaN;
    clearElevation(sessionManager, sessionId);
    auditLogger?.securityElevationDisabled({
      session_id: `ws:${sessionId}`,
      channel: 'ws',
      sender,
      elevation_id: existingId,
      // Never log NaN if the stored value was corrupt.
      until_ms: Number.isFinite(parsedUntil) ? parsedUntil : undefined,
      reason: existingReason || undefined,
    });
    return 'Elevated mode: off';
  }

  if (!hasYes) {
    return 'Refusing to enable elevation without explicit confirmation. Use: /elevate <duration> <reason...> --yes';
  }

  const [duration, ...reasonWords] = args;
  const reason = reasonWords.join(' ').trim();
  const ttlMs = parseElevationDuration(duration);
  const now = Date.now();
  // Also reject expiries a Date cannot represent: toISOString() below would throw a RangeError.
  if (ttlMs === null || now + ttlMs > MAX_DATE_MS) {
    return 'Invalid duration. Use one of: 30s, 10m, 1h, 1d';
  }

  const untilMs = now + ttlMs;
  const id = randomUUID();
  sessionManager.setSessionConfig('ws', sessionId, ELEVATION_UNTIL_KEY, String(untilMs));
  sessionManager.setSessionConfig('ws', sessionId, ELEVATION_ID_KEY, id);
  if (reason) {
    sessionManager.setSessionConfig('ws', sessionId, ELEVATION_REASON_KEY, reason);
  } else {
    sessionManager.deleteSessionConfig('ws', sessionId, ELEVATION_REASON_KEY);
  }

  auditLogger?.securityElevationEnabled({
    session_id: `ws:${sessionId}`,
    channel: 'ws',
    sender,
    elevation_id: id,
    until_ms: untilMs,
    ttl_ms: ttlMs,
    reason: reason || undefined,
  });

  return `Elevated mode: on until ${new Date(untilMs).toISOString()}`;
}

/**
 * Build the gateway RPC handlers for the `agent.*` namespace.
 *
 * - `agent.send` validates the request, then runs the work on the session's lane (falling back
 *   to the connection id), so connections that share a session are serialised. Input recognised
 *   by `deps.commandRegistry` is executed as a slash command. Everything else is passed to the
 *   connection's agent, and tool start/end events are streamed to the client. The connection's
 *   busy flag, tool callback and request metrics are released on every exit path.
 * - `agent.cancel` drops queued work for the lane and asks the session bridge to cancel the
 *   in-flight operation.
 */
export function createAgentHandlers(deps: AgentHandlerDeps) {
  return {
    'agent.send': async (request: GatewayRequest, send: SendFn): Promise<OutboundMessage | void> => {
      // NOTE(review): params come from the client; the shape is asserted, not validated.
      const params = request.params as AgentSendParams | undefined;
      if (!params) {
        return makeError(request.id, ErrorCode.InvalidRequest, 'params are required');
      }

      // Guard the values we call methods on (`trim`, `map`) so malformed input yields a
      // protocol error rather than a TypeError.
      const message = typeof params.message === 'string' ? params.message : undefined;
      const rawAttachments = Array.isArray(params.attachments) ? params.attachments : undefined;
      const metadata = params.metadata;

      const hasMessage = Boolean(message?.trim());
      const hasAttachments = (rawAttachments?.length ?? 0) > 0;
      if (!hasMessage && !hasAttachments && !metadata?.isCommand) {
        return makeError(request.id, ErrorCode.InvalidRequest, 'message or attachments are required');
      }

      const connectionId = params.connectionId;
      if (!connectionId) {
        return makeError(request.id, ErrorCode.InvalidRequest, 'connectionId is required (set by server)');
      }

      const agent = deps.sessionBridge.getAgent(connectionId);
      if (!agent) {
        return makeError(request.id, ErrorCode.SessionNotFound, 'No agent for this connection');
      }

      // Queue by session ID so multiple connections sharing a session are serialised.
      // Falls back to connectionId if session lookup fails (shouldn't happen).
      const laneId = deps.sessionBridge.getSessionId(connectionId) ?? connectionId;

      // Enqueue the work — if the lane is idle it runs immediately,
      // otherwise it waits for earlier requests on the same session to finish.
      const requestId = request.id.toString();
      // NOTE(review): if the lane queue discards this task before it starts (e.g. agent.cancel),
      // endRequest below never runs — confirm LaneQueue semantics.
      deps.metrics?.startRequest(requestId, { sessionId: laneId, channel: 'ws' });

      return deps.laneQueue.enqueue(laneId, async () => {
        deps.sessionBridge.setBusy(connectionId, true);

        // Everything after setBusy(true) lives in this try/finally so that handled slash
        // commands and commands that throw also clear the busy flag and end the request.
        try {
          const commandInput = metadata?.isCommand && typeof metadata.command === 'string'
            ? `/${metadata.command}${metadata.commandArgs ? ` ${metadata.commandArgs}` : ''}`
            : (message ?? '');

          if (commandInput && deps.commandRegistry?.isCommand(commandInput)) {
            // Looked up when the task runs rather than reusing the value from enqueue time.
            const currentSessionId = deps.sessionBridge.getSessionId(connectionId);

            const commandResult = await deps.commandRegistry.execute(commandInput, {
              channel: 'ws',
              senderId: connectionId,
              sessionId: currentSessionId ?? `ws:${connectionId}`,
              rawInput: commandInput,
              services: {
                getStatus: () => `Gateway session active. Current model tier: ${agent.getModelTier()}`,

                getUsage: () => {
                  const usage = agent.getUsage();
                  const lines = [
                    '**Token Usage**',
                    '',
                    `Primary: ${usage.primary.inputTokens.toLocaleString()} in / ${usage.primary.outputTokens.toLocaleString()} out (${usage.primary.calls} calls)`,
                  ];

                  const delegationEntries = Object.entries(usage.delegation);
                  if (delegationEntries.length > 0) {
                    lines.push('', 'Delegation:');
                    for (const [tier, stats] of delegationEntries) {
                      lines.push(`  ${tier}: ${stats.inputTokens.toLocaleString()} in / ${stats.outputTokens.toLocaleString()} out (${stats.calls} calls)`);
                    }
                  }

                  lines.push('', `**Total:** ${usage.total.inputTokens.toLocaleString()} in / ${usage.total.outputTokens.toLocaleString()} out (${usage.total.calls} calls)`);
                  if (usage.total.estimatedCost > 0) {
                    lines.push(`**Estimated cost:** $${usage.total.estimatedCost.toFixed(4)}`);
                  }
                  return lines.join('\n');
                },

                getModel: () => `Current model tier: ${agent.getModelTier()}`,

                setModel: (input) => {
                  const raw = input.trim();
                  if (!raw) {
                    return 'Usage: /model <tier>';
                  }
                  const [requestedTier, ...rest] = raw.split(/\s+/);
                  if (!isModelTier(requestedTier)) {
                    return `Invalid tier: ${requestedTier}. Available: ${VALID_MODEL_TIERS.join(', ')}`;
                  }

                  // Apply the tier even when a provider/model was also supplied; only that
                  // extra part is unsupported here (previously the reply claimed a switch
                  // that never happened).
                  agent.setModelTier(requestedTier);
                  if (currentSessionId && deps.sessionManager) {
                    deps.sessionManager.setSessionConfig('ws', currentSessionId, 'modelTier', requestedTier);
                  }

                  if (rest.length > 0) {
                    return `Switched to model tier: ${requestedTier}\nNote: provider/model switching is not available via gateway (/model <tier> <provider/model>).`;
                  }
                  return `Switched to model tier: ${requestedTier}`;
                },

                compact: async () => {
                  const result = await agent.compact();
                  if (result && result.compactedCount > 0) {
                    return `Compacted ${result.compactedCount} messages: ${result.tokensBefore} → ${result.tokensAfter} tokens`;
                  }
                  return 'Nothing to compact.';
                },

                reset: () => {
                  agent.reset();
                  if (currentSessionId && deps.sessionManager) {
                    deps.sessionManager.deleteSessionConfig('ws', currentSessionId, 'modelTier');
                  }
                  return 'Session reset.';
                },

                getElevation: () => (currentSessionId && deps.sessionManager
                  ? describeElevation(deps.sessionManager, currentSessionId, connectionId)
                  : 'Elevated mode: off'),

                setElevation: (input: string) => (currentSessionId && deps.sessionManager
                  ? applyElevationCommand(deps.sessionManager, currentSessionId, connectionId, input)
                  : 'Elevate command is not available in this session.'),
              },
            });

            if (commandResult.handled) {
              send(makeEvent(request.id, 'done', { content: commandResult.text }));
              return;
            }
          }

          // Stream tool start/end events to the client while the agent runs.
          deps.sessionBridge.setOnToolUse(connectionId, (event) => {
            if (event.type === 'start') {
              send(makeEvent(request.id, 'tool_start', { tool: event.tool, args: event.args }));
            } else if (event.type === 'end') {
              send(makeEvent(request.id, 'tool_end', {
                tool: event.tool,
                result: event.result ? {
                  success: event.result.success,
                  output: event.result.output,
                  error: event.result.error,
                } : undefined,
              }));

              // Record tool failures as error events
              if (event.result && !event.result.success) {
                deps.metrics?.incrementErrors();
                deps.metrics?.recordEvent({
                  timestamp: Date.now(),
                  level: 'error',
                  source: 'tool',
                  message: `Tool '${event.tool}' failed: ${event.result.error ?? 'unknown error'}`,
                  context: { sessionId: laneId, tool: event.tool },
                });
              }
            }
          });

          // Convert gateway attachments to channel attachments
          const attachments: Attachment[] | undefined = rawAttachments?.map(a => ({
            mimeType: a.mimeType,
            data: a.data,
            url: a.url,
            filename: a.filename,
          }));

          const response = await agent.process(message ?? '', attachments);
          deps.metrics?.incrementMessages();
          send(makeEvent(request.id, 'done', { content: response }));
        } catch (err) {
          const errorMessage = err instanceof Error ? err.message : 'Unknown error';
          deps.metrics?.incrementErrors();
          deps.metrics?.recordEvent({
            timestamp: Date.now(),
            level: 'error',
            source: 'agent.send',
            message: errorMessage,
            context: { sessionId: laneId },
          });
          send(makeEvent(request.id, 'error', { code: ErrorCode.InternalError, message: errorMessage }));
        } finally {
          deps.sessionBridge.setBusy(connectionId, false);
          deps.sessionBridge.setOnToolUse(connectionId, undefined);
          deps.metrics?.endRequest(requestId);
        }
      });
    },

    'agent.cancel': async (request: GatewayRequest): Promise<OutboundMessage> => {
      const params = request.params as { connectionId?: string } | undefined;
      const connectionId = params?.connectionId;
      if (!connectionId) {
        return makeError(request.id, ErrorCode.InvalidRequest, 'connectionId is required');
      }

      const laneId = deps.sessionBridge.getSessionId(connectionId) ?? connectionId;

      // Clear any queued (not-yet-started) work first, then signal the running operation.
      deps.laneQueue.cancel(laneId);
      const cancelled = deps.sessionBridge.cancel(connectionId);

      return {
        id: request.id,
        result: {
          cancelled,
          message: cancelled
            ? 'Cancellation requested. The active operation will stop at the next safe point.'
            : 'No active operation to cancel.',
        },
      };
    },
  };
}
|