// flynn/src/gateway/session-bridge.ts
// (source viewer metadata: 417 lines, 15 KiB, TypeScript)
import { randomUUID } from 'crypto';
import type { SessionManager } from '../session/manager.js';
import type { ModelClient } from '../models/types.js';
import type { ModelRouter, ModelTier } from '../models/router.js';
import { createClientFromConfig } from '../daemon/models.js';
import type { Config, ModelConfig, ModelProvider } from '../config/index.js';
import type { ToolRegistry } from '../tools/registry.js';
import type { ToolExecutor } from '../tools/executor.js';
import { AgentOrchestrator, type DelegationConfig } from '../backends/native/orchestrator.js';
import type { ToolUseEvent } from '../backends/native/agent.js';
import type { MemoryStore } from '../memory/store.js';
import { summarizeSessionOnEnd, type SessionEndSummaryConfig } from '../session/endSummary.js';
/**
 * Dependencies handed to {@link SessionBridge}. Everything here is shared by every
 * per-session AgentOrchestrator the bridge creates.
 */
export interface SessionBridgeConfig {
  /** Source of persisted sessions; the bridge looks sessions up under the 'ws' channel. */
  sessionManager: SessionManager;
  /** System prompt given to every orchestrator the bridge creates. */
  systemPrompt: string;
  /** Model access. The bridge passes this to orchestrators as their model router. */
  modelClient: ModelClient | ModelRouter;
  /** Tool catalogue shared by all sessions. */
  toolRegistry: ToolRegistry;
  /** Tool runner shared by all sessions. */
  toolExecutor: ToolExecutor;
  /** Optional runtime configuration; built-in defaults are used where it is absent. */
  config?: Config;
  /** Optional memory store, used by orchestrators and by session-end summaries. */
  memoryStore?: MemoryStore;
}
/** Book-keeping for one live WebSocket connection. */
interface ClientEntry {
  /** Identifier assigned at connect time; never changes for the life of the entry. */
  readonly connectionId: string;
  /** Session the connection is currently attached to (changed by switchSession). */
  sessionId: string;
  /** Orchestrator for `sessionId`, shared with other connections on the same session. */
  agent: AgentOrchestrator;
  /** Whether the caller has marked this connection's agent as running a request. */
  busy: boolean;
}
/**
 * Maps WebSocket connections onto persistent sessions, each of which is served by a
 * single AgentOrchestrator shared by every connection attached to that session.
 */
export class SessionBridge {
  /** Live connections keyed by connection ID. */
  private readonly clients: Map<string, ClientEntry> = new Map();
  /** One orchestrator per session ID. */
  private readonly agents: Map<string, AgentOrchestrator> = new Map();
  private readonly config: SessionBridgeConfig;

  constructor(config: SessionBridgeConfig) {
    this.config = config;
  }

  /**
   * Register a new WS connection on its own `ws:<id>` session.
   *
   * @param connectionId - Optional caller-chosen ID; a random UUID is generated otherwise.
   * @returns The assigned connection ID.
   */
  connect(connectionId?: string): string {
    const id = connectionId ?? randomUUID();
    const sessionId = `ws:${id}`;
    const agent = this.getOrCreateAgent(sessionId);
    this.clients.set(id, {
      connectionId: id,
      sessionId,
      agent,
      busy: false,
    });
    return id;
  }

  /**
   * Remove a WS connection. Does NOT destroy the session (persists in SQLite).
   *
   * When this was the last connection on its session, the session's agent is released
   * and, if `sessions.end_summary.enabled` is set, a session-end summary is produced.
   * Unknown connection IDs are ignored.
   */
  async disconnect(connectionId: string): Promise<void> {
    const client = this.clients.get(connectionId);
    if (!client) {
      return;
    }

    // Unregister synchronously, before any await. A second disconnect() for the same
    // connection (e.g. 'close' and 'error' both firing) is then a no-op instead of
    // producing a duplicate summary.
    this.clients.delete(connectionId);

    const { sessionId } = client;
    if (this.hasClientsForSession(sessionId)) {
      // Another connection still shares this session; keep its agent alive.
      return;
    }

    const agent = this.agents.get(sessionId);
    // Evict before awaiting. A connection that attaches to this session while the
    // summary runs then gets its own agent, rather than one that is about to be dropped.
    this.agents.delete(sessionId);

    if (agent) {
      await this.summarizeOnEnd(agent, sessionId);
    }
  }

  /**
   * Switch a connection to a different session (e.g. resuming an old session).
   *
   * If no other connection still uses the previous session, that session's agent is
   * released. Otherwise it would never be evicted, because disconnect() only releases
   * the connection's current session.
   * NOTE(review): no session-end summary is produced for the released agent here.
   *
   * @throws Error if the connection is unknown or currently busy.
   */
  switchSession(connectionId: string, sessionId: string): void {
    const client = this.clients.get(connectionId);
    if (!client) {throw new Error(`Unknown connection: ${connectionId}`);}
    if (client.busy) {throw new Error('Cannot switch session while agent is busy');}

    const previousSessionId = client.sessionId;
    const agent = this.getOrCreateAgent(sessionId);
    client.sessionId = sessionId;
    client.agent = agent;

    if (previousSessionId !== sessionId && !this.hasClientsForSession(previousSessionId)) {
      this.agents.delete(previousSessionId);
    }
  }

  /** Get the AgentOrchestrator for a connection. */
  getAgent(connectionId: string): AgentOrchestrator | undefined {
    return this.clients.get(connectionId)?.agent;
  }

  /** Get the session ID for a connection. */
  getSessionId(connectionId: string): string | undefined {
    return this.clients.get(connectionId)?.sessionId;
  }

  /** Check if a connection's agent is busy (false for unknown connections). */
  isBusy(connectionId: string): boolean {
    return this.clients.get(connectionId)?.busy ?? false;
  }

  /** Mark a connection's agent as busy/idle. Unknown connections are ignored. */
  setBusy(connectionId: string, busy: boolean): void {
    const client = this.clients.get(connectionId);
    if (client) {client.busy = busy;}
  }

  /**
   * Request cancellation for the current operation on a connection's agent.
   * @returns false if the connection is unknown or not busy; true once cancel() was requested.
   */
  cancel(connectionId: string): boolean {
    const client = this.clients.get(connectionId);
    if (!client || !client.busy) {
      return false;
    }
    client.agent.cancel();
    return true;
  }

  /**
   * Request cancellation for the active operation in a session.
   * Returns true if at least one connection in the session is currently busy.
   */
  cancelSession(sessionId: string): boolean {
    const clients = Array.from(this.clients.values()).filter((client) => client.sessionId === sessionId);
    if (clients.length === 0) {
      return false;
    }
    const hasBusyClient = clients.some((client) => client.busy);
    if (!hasBusyClient) {
      return false;
    }
    const agent = this.agents.get(sessionId);
    if (!agent) {
      return false;
    }
    agent.cancel();
    return true;
  }

  /**
   * Set onToolUse callback for a connection's agent. Unknown connections are ignored.
   * The agent is shared per session, so this affects every connection on that session.
   */
  setOnToolUse(connectionId: string, callback: ((event: ToolUseEvent) => void) | undefined): void {
    const client = this.clients.get(connectionId);
    if (client) {client.agent.setOnToolUse(callback);}
  }

  /** List all active sessions with connection counts. */
  listSessions(): Array<{ sessionId: string; connections: number }> {
    const sessionMap = new Map<string, number>();
    for (const client of this.clients.values()) {
      sessionMap.set(client.sessionId, (sessionMap.get(client.sessionId) ?? 0) + 1);
    }
    return Array.from(sessionMap.entries()).map(([sessionId, connections]) => ({
      sessionId,
      connections,
    }));
  }

  /** Get count of active connections. */
  get connectionCount(): number {
    return this.clients.size;
  }

  /** Get primary-model usage stats for a specific connection's agent. */
  getUsage(connectionId: string): { inputTokens: number; outputTokens: number; calls: number } | undefined {
    const agent = this.clients.get(connectionId)?.agent;
    if (!agent) {return undefined;}
    const usage = agent.getUsage();
    return {
      inputTokens: usage.primary.inputTokens,
      outputTokens: usage.primary.outputTokens,
      calls: usage.primary.calls,
    };
  }

  /** Get usage stats for all active sessions. Returns one entry per session. */
  getAllUsage(): Array<{
    sessionId: string;
    primary: { inputTokens: number; outputTokens: number; calls: number };
    delegation: Record<string, { inputTokens: number; outputTokens: number; calls: number }>;
    total: { inputTokens: number; outputTokens: number; calls: number; estimatedCost: number };
  }> {
    return this.uniqueSessionClients().map((client) => {
      const usage = client.agent.getUsage();
      return {
        sessionId: client.sessionId,
        primary: usage.primary,
        delegation: usage.delegation,
        total: usage.total,
      };
    });
  }

  /** Get estimated context budget for all active sessions (one entry per session). */
  getAllContextUsage(): Array<{
    sessionId: string;
    budget: {
      estimatedTokens: number;
      contextWindow: number;
      remainingTokens: number;
      usagePct: number;
      thresholdPct: number;
      thresholdTokens: number;
      shouldCompact: boolean;
    };
  }> {
    return this.uniqueSessionClients().map((client) => ({
      sessionId: client.sessionId,
      budget: client.agent.getContextBudget(),
    }));
  }

  /**
   * True if any registered connection is attached to `sessionId`.
   * Pass `excludeConnectionId` to ignore one connection in the check.
   */
  private hasClientsForSession(sessionId: string, excludeConnectionId?: string): boolean {
    for (const client of this.clients.values()) {
      if (client.sessionId === sessionId && client.connectionId !== excludeConnectionId) {
        return true;
      }
    }
    return false;
  }

  /**
   * The first registered connection for each distinct session, in registration order.
   * Multiple connections may share a session.
   */
  private uniqueSessionClients(): ClientEntry[] {
    const bySession = new Map<string, ClientEntry>();
    for (const client of this.clients.values()) {
      if (!bySession.has(client.sessionId)) {
        bySession.set(client.sessionId, client);
      }
    }
    return Array.from(bySession.values());
  }

  /**
   * Produce the session-end summary for an agent that is being released.
   * Does nothing unless `sessions.end_summary.enabled` is set in the runtime config.
   * Failures from the summarizer are logged and swallowed, because summaries are best-effort.
   */
  private async summarizeOnEnd(agent: AgentOrchestrator, sessionId: string): Promise<void> {
    const summaryConfig = this.config.config?.sessions?.end_summary;
    if (!summaryConfig?.enabled) {
      return;
    }
    const history = agent.getHistory();
    // Translate the snake_case config file keys into the summarizer's camelCase options.
    const mappedConfig: SessionEndSummaryConfig = {
      enabled: summaryConfig.enabled,
      tier: summaryConfig.tier,
      maxMessages: summaryConfig.max_messages,
      maxInputChars: summaryConfig.max_input_chars,
      maxTokens: summaryConfig.max_tokens,
      writeToMemory: summaryConfig.write_to_memory,
      memoryNamespace: summaryConfig.memory_namespace,
    };
    try {
      await summarizeSessionOnEnd({
        agent,
        sessionId,
        history,
        config: mappedConfig,
        memoryStore: this.config.memoryStore,
      });
    } catch (error) {
      console.warn('Session end summary failed:', error);
    }
  }

  /**
   * Return the cached orchestrator for `sessionId`, or build and cache a new one.
   * Its configuration comes from the runtime config (with defaults) and the session's
   * own 'modelTier' override.
   */
  private getOrCreateAgent(sessionId: string): AgentOrchestrator {
    const existing = this.agents.get(sessionId);
    if (existing) {
      return existing;
    }

    const session = this.config.sessionManager.getSession('ws', sessionId);
    const config = this.config.config;
    // A per-session tier override (if the session supports config) beats the global primary tier.
    const sessionTier = session.getConfig?.('modelTier') as ModelTier | undefined;
    const primaryTier = sessionTier ?? config?.agents.primary_tier ?? 'default';
    const delegationConfig: DelegationConfig = {
      compaction: config?.agents.delegation.compaction ?? 'fast',
      memory_extraction: config?.agents.delegation.memory_extraction ?? 'fast',
      classification: config?.agents.delegation.classification ?? 'fast',
      tool_summarisation: config?.agents.delegation.tool_summarisation ?? 'fast',
      complex_reasoning: config?.agents.delegation.complex_reasoning ?? 'complex',
    };
    const backgroundModelOverrides = this.buildBackgroundModelOverrides();

    const agent = new AgentOrchestrator({
      // NOTE(review): modelClient is typed ModelClient | ModelRouter but is cast to a
      // router without a runtime check; confirm callers always supply a ModelRouter.
      modelRouter: this.config.modelClient as ModelRouter,
      systemPrompt: this.config.systemPrompt,
      session,
      toolRegistry: this.config.toolRegistry,
      toolExecutor: this.config.toolExecutor,
      primaryTier,
      delegation: delegationConfig,
      backgroundModelOverrides,
      maxDelegationDepth: config?.agents.max_delegation_depth ?? 3,
      maxIterations: config?.agents.max_iterations,
      compaction: config?.compaction.enabled ? {
        thresholdPct: config.compaction.threshold_pct,
        keepTurns: config.compaction.keep_turns,
        summaryMaxTokens: config.compaction.summary_max_tokens,
        importanceThreshold: config.compaction.importance_threshold,
        proactive: {
          enabled: config.compaction.proactive.enabled,
          warnPct: config.compaction.proactive.warn_pct,
          checkpointPct: config.compaction.proactive.checkpoint_pct,
          autoCompactPct: config.compaction.proactive.auto_compact_pct,
          checkpointCooldownMs: config.compaction.proactive.checkpoint_cooldown_ms,
          memoryNamespace: config.compaction.proactive.memory_namespace,
        },
      } : undefined,
      modelName: config?.models.default.model,
      contextWindow: config?.models.default.context_window,
      memoryStore: this.config.memoryStore,
      memoryAutoExtract: config?.memory?.auto_extract,
      memoryInjectionStrategy: config?.memory?.injection_strategy,
      memoryMaxInjectionTokens: config?.memory?.max_injection_tokens,
      memoryProactiveExtractEnabled: config?.memory?.proactive_extract?.enabled,
      memoryProactiveExtractMinToolCalls: config?.memory?.proactive_extract?.min_tool_calls,
      memoryProactiveExtractNamespace: config?.memory?.proactive_extract?.namespace,
      memoryDailyLogEnabled: config?.memory?.daily_log?.enabled,
      memoryDailyLogNamespacePrefix: config?.memory?.daily_log?.namespace_prefix,
      toolPolicyContext: {
        agent: primaryTier,
        provider: config?.models.default.provider,
        autonomyLevel: config?.agents.autonomy_level ?? 'standard',
        sensitiveMode: config?.agents.sensitive_mode ?? 'deny_without_elevation',
        immutableDenylist: (config?.agents.immutable_denylist ?? []).map((rule) => ({
          tool: rule.tool,
          argsPattern: rule.args_pattern,
          reason: rule.reason,
        })),
      },
    });
    this.agents.set(sessionId, agent);
    return agent;
  }

  /**
   * Build per-task model clients from `agents.background_models`.
   *
   * Entries that are missing or have `enabled: false` are skipped. When another
   * configured model uses the same provider, its config is used as a template, so
   * provider-level settings carry over. A client that fails to initialize is logged
   * and skipped.
   */
  private buildBackgroundModelOverrides(): Partial<Record<keyof DelegationConfig, {
    client: ModelClient;
    label: string;
    fallbackTier: ModelTier;
  }>> {
    const runtimeConfig = this.config.config;
    const overrides: Partial<Record<keyof DelegationConfig, {
      client: ModelClient;
      label: string;
      fallbackTier: ModelTier;
    }>> = {};
    if (!runtimeConfig) {
      return overrides;
    }
    const providerConfigs = this.buildProviderConfigMap(runtimeConfig);
    const configured = runtimeConfig.agents?.background_models ?? {};
    const tasks: Array<keyof DelegationConfig> = [
      'compaction',
      'memory_extraction',
      'classification',
      'tool_summarisation',
      'complex_reasoning',
    ];
    for (const task of tasks) {
      const entry = configured[task];
      if (!entry || entry.enabled === false) {
        continue;
      }
      const template = providerConfigs[entry.provider];
      try {
        const client = createClientFromConfig(
          template
            ? { ...template, provider: entry.provider, model: entry.model }
            : { provider: entry.provider, model: entry.model },
        );
        overrides[task] = {
          client,
          label: `${entry.provider}/${entry.model}`,
          fallbackTier: entry.fallback_tier,
        };
      } catch (error) {
        console.warn(
          `[Flynn:gateway] Failed to initialize background model override for ${task} ` +
          `(${entry.provider}/${entry.model}): ${error instanceof Error ? error.message : String(error)}`,
        );
      }
    }
    return overrides;
  }

  /**
   * Index every configured model (and its fallback) by provider.
   * When several models share a provider, the one visited last wins. The visit order
   * is default, fast, complex, local, then local_providers, with each model's fallback
   * recorded right after it.
   */
  private buildProviderConfigMap(config: Config): Partial<Record<ModelProvider, ModelConfig>> {
    const providerConfigs: Partial<Record<ModelProvider, ModelConfig>> = {};
    const modelConfigs: ModelConfig[] = [
      config.models.default,
      ...(config.models.fast ? [config.models.fast] : []),
      ...(config.models.complex ? [config.models.complex] : []),
      ...(config.models.local ? [config.models.local] : []),
      ...Object.values(config.models.local_providers ?? {}),
    ];
    for (const modelConfig of modelConfigs) {
      providerConfigs[modelConfig.provider] = modelConfig;
      if (modelConfig.fallback) {
        providerConfigs[modelConfig.fallback.provider] = modelConfig.fallback;
      }
    }
    return providerConfigs;
  }
}