diff --git a/docs/plans/2026-02-05-phase2-websocket-gateway.md b/docs/plans/2026-02-05-phase2-websocket-gateway.md new file mode 100644 index 0000000..774d9b0 --- /dev/null +++ b/docs/plans/2026-02-05-phase2-websocket-gateway.md @@ -0,0 +1,214 @@ +# Phase 2: WebSocket Gateway — Implementation Plan + +## Goal + +Add a WebSocket gateway to the Flynn daemon that allows real-time communication from web clients (and eventually any WS client). The gateway wraps existing daemon components (session manager, agent, tool registry) — no refactoring of existing code. + +## Approach + +Incremental, additive. Existing Telegram bot and TUI continue working unchanged. The gateway is a new module that starts alongside them. + +## Architecture + +``` +src/gateway/ +├── protocol.ts # JSON-RPC message types + event types +├── server.ts # WebSocket server (ws library), connection lifecycle +├── router.ts # Method routing → handler dispatch +├── auth.ts # Token auth + Tailscale identity headers +├── session-bridge.ts # Maps WS client connections to agent sessions +└── handlers/ + ├── agent.ts # agent.send (streaming), agent.cancel, agent.status + ├── sessions.ts # sessions.list, sessions.history, sessions.create + ├── tools.ts # tools.list, tools.invoke + └── system.ts # system.health, system.info +``` + +## Protocol + +JSON-RPC-like over WebSocket (not full JSON-RPC 2.0 — simpler): + +```typescript +// Client → Server +interface GatewayRequest { + id: number; // Client-assigned request ID + method: string; // e.g. "agent.send" + params: object; // Method-specific parameters +} + +// Server → Client (success) +interface GatewayResponse { + id: number; // Matches request ID + result: object; // Method-specific result +} + +// Server → Client (error) +interface GatewayError { + id: number; // Matches request ID + error: { + code: number; // Error code (negative = protocol, positive = app) + message: string; // Human-readable description + }; +} + +// Server → Client (streaming event, multiple per request) +interface GatewayEvent { + id: number; // Matches originating request ID + event: string; // Event type name + data: object; // Event-specific payload +} +``` + +### Event Types (for agent.send streaming) + +| Event | Data | When | +|-------|------|------| +| `content` | `{ text: string }` | Text chunk from model | +| `tool_start` | `{ tool: string, args: object }` | Tool execution beginning | +| `tool_end` | `{ tool: string, result: { success, output, error? } }` | Tool execution complete | +| `done` | `{ content: string, usage: { inputTokens, outputTokens } }` | Final response | +| `error` | `{ code: number, message: string }` | Error during processing | + +### Error Codes + +| Code | Meaning | +|------|---------| +| -1 | Parse error (invalid JSON) | +| -2 | Invalid request (missing id/method) | +| -3 | Method not found | +| -4 | Authentication required | +| -5 | Authentication failed | +| 1 | Session not found | +| 2 | Tool not found | +| 3 | Agent busy (already processing) | +| 4 | Request cancelled | + +## Methods + +### agent.send +Send a message to the agent and receive streaming response. + +``` +Params: { message: string, sessionId?: string } +Events: content, tool_start, tool_end, done, error +Response: (none — final state sent as "done" event) +``` + +### agent.cancel +Cancel the currently running agent request. + +``` +Params: { sessionId?: string } +Response: { cancelled: boolean } +``` + +### sessions.list +List all active sessions. + +``` +Params: {} +Response: { sessions: [{ id: string, messageCount: number }] } +``` + +### sessions.history +Get message history for a session. + +``` +Params: { sessionId: string, limit?: number, offset?: number } +Response: { messages: Message[], total: number } +``` + +### sessions.create +Create a new session. + +``` +Params: { sessionId?: string } +Response: { sessionId: string } +``` + +### tools.list +List all registered tools. + +``` +Params: {} +Response: { tools: [{ name, description, inputSchema }] } +``` + +### tools.invoke +Directly invoke a tool (bypasses agent). + +``` +Params: { tool: string, args: object } +Response: { success: boolean, output: string, error?: string } +``` + +### system.health +Health check. + +``` +Params: {} +Response: { status: "ok", uptime: number, version: string, sessions: number, tools: number } +``` + +## Session Bridge + +Each WebSocket connection is associated with a session: + +1. Client connects → assigned default session ID `ws:{connectionId}` +2. Client can specify `sessionId` in `agent.send` to use a named session +3. Sessions are created on demand via SessionManager +4. Multiple WS clients can share a session (e.g. multiple browser tabs) +5. Disconnection does NOT destroy the session (persistence via SQLite) + +## Auth (Phase 2b) + +Two auth modes (checked in order): + +1. **Token auth**: `Authorization: Bearer ` header on WS upgrade, or `{ method: "auth", params: { token: "..." } }` as first message +2. **Tailscale identity**: `Tailscale-User-Login` header set by Tailscale Funnel/proxy + +Config addition: +```yaml +server: + port: 18800 + auth: + token: "optional-static-token" # If set, required for all WS connections + tailscale_identity: true # Trust Tailscale-User-Login header +``` + +No auth initially (Phase 2a) — gateway only listens on localhost. + +## Implementation Order + +1. `protocol.ts` — Types only, no runtime code +2. `router.ts` — Method dispatch (pure function, easy to test) +3. `session-bridge.ts` — Client-to-session mapping +4. `handlers/system.ts` — Simplest handler, proves the pattern +5. `handlers/sessions.ts` — Session listing/history +6. `handlers/tools.ts` — Tool listing/invocation +7. `handlers/agent.ts` — The main handler (streaming, tool events) +8. `server.ts` — WebSocket server, ties everything together +9. Wire into `daemon/index.ts` +10. Tests throughout + +## Dependencies + +Need `ws` package: +```bash +pnpm add ws +pnpm add -D @types/ws +``` + +## Test Strategy + +- Unit tests for protocol message validation +- Unit tests for router dispatch +- Unit tests for each handler (mock session manager, agent, tool registry) +- Integration test: WS client → server → handler → response +- Test streaming events for agent.send + +--- + +*Plan Version: 1.0* +*Created: 2026-02-05* +*Parent: docs/plans/2026-02-05-openclaw-parity-design.md Phase 2* diff --git a/package.json b/package.json index 07ea637..6d0c12d 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,7 @@ "@types/marked-terminal": "^6.1.1", "@types/node": "^22.0.0", "@types/react": "^19.0.0", + "@types/ws": "^8.18.1", "eslint": "^9.0.0", "tsx": "^4.0.0", "typescript": "^5.7.0", @@ -45,6 +46,7 @@ "ollama": "^0.5.0", "openai": "^4.0.0", "react": "^19.0.0", + "ws": "^8.19.0", "yaml": "^2.7.0", "zod": "^3.24.0" }, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index cf9e81a..3a84b78 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -41,6 +41,9 @@ importers: react: specifier: ^19.0.0 version: 19.2.4 + ws: + specifier: ^8.19.0 + version: 8.19.0 yaml: specifier: ^2.7.0 version: 2.8.2 @@ -60,6 +63,9 @@ importers: '@types/react': specifier: ^19.0.0 version: 19.2.11 + '@types/ws': + specifier: ^8.18.1 + version: 8.18.1 eslint: specifier: ^9.0.0 version: 9.39.2 @@ -477,6 +483,9 @@ packages: '@types/react@19.2.11': resolution: {integrity: sha512-tORuanb01iEzWvMGVGv2ZDhYZVeRMrw453DCSAIn/5yvcSVnMoUMTyf33nQJLahYEnv9xqrTNbgz4qY5EfSh0g==} + '@types/ws@8.18.1': + resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==} + '@vitest/expect@3.2.4': resolution: {integrity: sha512-Io0yyORnB6sikFlt8QW5K7slY4OjqNX9jmJQ02QDda8lyM6B5oNgVWoSoKPac8/kgnCUzuHQKrSLtu/uOqqrig==} @@ -1867,6 +1876,10 @@ snapshots: dependencies: csstype: 3.2.3 + '@types/ws@8.18.1': + dependencies: + '@types/node': 22.19.7 + '@vitest/expect@3.2.4': dependencies: '@types/chai': 5.2.3 diff --git a/src/daemon/index.ts b/src/daemon/index.ts index 6056c13..290978e 100644 --- a/src/daemon/index.ts +++ b/src/daemon/index.ts @@ -7,6 +7,7 @@ import { createTelegramBot } from '../frontends/telegram/index.js'; import { SessionStore, SessionManager } from '../session/index.js'; import { HookEngine } from '../hooks/index.js'; import { ToolRegistry, ToolExecutor, allBuiltinTools } from '../tools/index.js'; +import { GatewayServer } from '../gateway/index.js'; import { resolve } from 'path'; import { homedir } from 'os'; import { mkdirSync, readFileSync, existsSync } from 'fs'; @@ -22,6 +23,7 @@ export interface DaemonContext { modelRouter: ModelRouter; toolRegistry: ToolRegistry; toolExecutor: ToolExecutor; + gateway: GatewayServer; } function loadSystemPrompt(): string { @@ -133,6 +135,9 @@ export async function startDaemon(config: Config): Promise { // Initialize model router const modelRouter = createModelRouter(config); + // Load system prompt once for reuse + const systemPrompt = loadSystemPrompt(); + // Get Telegram session const telegramUserId = String(config.telegram.allowed_chat_ids[0]); const session = sessionManager.getSession('telegram', telegramUserId); @@ -140,7 +145,7 @@ export async function startDaemon(config: Config): Promise { // Initialize native agent with session and tools const agent = new NativeAgent({ modelClient: modelRouter, - systemPrompt: loadSystemPrompt(), + systemPrompt, session, toolRegistry, toolExecutor, @@ -153,6 +158,17 @@ export async function startDaemon(config: Config): Promise { hookEngine, }); + // Initialize gateway WebSocket server + const gateway = new GatewayServer({ + port: config.server.port, + host: config.server.localhost ? '127.0.0.1' : '0.0.0.0', + sessionManager, + modelClient: modelRouter, + systemPrompt, + toolRegistry, + toolExecutor, + }); + // Register signal handlers const signalHandler = () => { lifecycle.shutdown().then(() => process.exit(0)); @@ -179,9 +195,17 @@ export async function startDaemon(config: Config): Promise { }, }); + // Start gateway + lifecycle.onShutdown(async () => { + await gateway.stop(); + console.log('Gateway server stopped'); + }); + + await gateway.start(); + console.log('Flynn daemon started'); - return { config, lifecycle, bot, agent, sessionStore, sessionManager, hookEngine, modelRouter, toolRegistry, toolExecutor }; + return { config, lifecycle, bot, agent, sessionStore, sessionManager, hookEngine, modelRouter, toolRegistry, toolExecutor, gateway }; } export { Lifecycle } from './lifecycle.js'; diff --git a/src/gateway/auth.test.ts b/src/gateway/auth.test.ts new file mode 100644 index 0000000..fcb998a --- /dev/null +++ b/src/gateway/auth.test.ts @@ -0,0 +1,85 @@ +import { describe, it, expect } from 'vitest'; +import { authenticateRequest } from './auth.js'; +import type { IncomingMessage } from 'http'; + +function mockRequest(headers: Record = {}): IncomingMessage { + return { headers } as unknown as IncomingMessage; +} + +describe('authenticateRequest', () => { + describe('no auth configured', () => { + it('allows all connections', () => { + const result = authenticateRequest(mockRequest(), {}); + expect(result.authenticated).toBe(true); + expect(result.identity).toBe('anonymous'); + }); + }); + + describe('token auth', () => { + const config = { token: 'secret-token-123' }; + + it('accepts valid Bearer token', () => { + const result = authenticateRequest( + mockRequest({ authorization: 'Bearer secret-token-123' }), + config, + ); + expect(result.authenticated).toBe(true); + expect(result.identity).toBe('token-user'); + }); + + it('rejects missing Authorization header', () => { + const result = authenticateRequest(mockRequest(), config); + expect(result.authenticated).toBe(false); + expect(result.error).toContain('Authorization header required'); + }); + + it('rejects invalid token', () => { + const result = authenticateRequest( + mockRequest({ authorization: 'Bearer wrong-token' }), + config, + ); + expect(result.authenticated).toBe(false); + expect(result.error).toContain('Invalid token'); + }); + + it('rejects non-Bearer format', () => { + const result = authenticateRequest( + mockRequest({ authorization: 'Basic dXNlcjpwYXNz' }), + config, + ); + expect(result.authenticated).toBe(false); + expect(result.error).toContain('Invalid Authorization format'); + }); + + it('uses Tailscale identity when both token and tailscale are configured', () => { + const result = authenticateRequest( + mockRequest({ + authorization: 'Bearer secret-token-123', + 'tailscale-user-login': 'will@example.com', + }), + { token: 'secret-token-123', tailscaleIdentity: true }, + ); + expect(result.authenticated).toBe(true); + expect(result.identity).toBe('will@example.com'); + }); + }); + + describe('tailscale identity', () => { + const config = { tailscaleIdentity: true }; + + it('extracts identity from Tailscale-User-Login header', () => { + const result = authenticateRequest( + mockRequest({ 'tailscale-user-login': 'will@example.com' }), + config, + ); + expect(result.authenticated).toBe(true); + expect(result.identity).toBe('will@example.com'); + }); + + it('allows connections without Tailscale header (local access)', () => { + const result = authenticateRequest(mockRequest(), config); + expect(result.authenticated).toBe(true); + expect(result.identity).toBe('anonymous'); + }); + }); +}); diff --git a/src/gateway/auth.ts b/src/gateway/auth.ts new file mode 100644 index 0000000..edc7466 --- /dev/null +++ b/src/gateway/auth.ts @@ -0,0 +1,67 @@ +import type { IncomingMessage } from 'http'; + +export interface AuthConfig { + /** Static bearer token. If set, all connections must provide it. */ + token?: string; + /** Trust Tailscale-User-Login header for identity. */ + tailscaleIdentity?: boolean; +} + +export interface AuthResult { + authenticated: boolean; + identity?: string; + error?: string; +} + +/** + * Authenticates a WebSocket upgrade request. + * + * Auth is checked in this order: + * 1. If token is configured, validate Authorization header (Bearer token) + * 2. If tailscaleIdentity is enabled, extract identity from Tailscale-User-Login header + * 3. If no auth is configured, allow all connections + */ +export function authenticateRequest(req: IncomingMessage, config: AuthConfig): AuthResult { + // If token auth is configured, it's required + if (config.token) { + const authHeader = req.headers['authorization']; + if (!authHeader) { + return { authenticated: false, error: 'Authorization header required' }; + } + + const parts = authHeader.split(' '); + if (parts.length !== 2 || parts[0] !== 'Bearer') { + return { authenticated: false, error: 'Invalid Authorization format (expected: Bearer )' }; + } + + if (parts[1] !== config.token) { + return { authenticated: false, error: 'Invalid token' }; + } + + // Token is valid — check for Tailscale identity too + const identity = extractTailscaleIdentity(req, config); + return { authenticated: true, identity: identity ?? 'token-user' }; + } + + // If Tailscale identity is configured (no token), use it + if (config.tailscaleIdentity) { + const identity = extractTailscaleIdentity(req, config); + if (identity) { + return { authenticated: true, identity }; + } + // Tailscale identity configured but header not present — still allow (might be local) + return { authenticated: true, identity: 'anonymous' }; + } + + // No auth configured — allow all + return { authenticated: true, identity: 'anonymous' }; +} + +function extractTailscaleIdentity(req: IncomingMessage, config: AuthConfig): string | undefined { + if (!config.tailscaleIdentity) return undefined; + const header = req.headers['tailscale-user-login']; + if (typeof header === 'string' && header.length > 0) { + return header; + } + return undefined; +} diff --git a/src/gateway/handlers/agent.ts b/src/gateway/handlers/agent.ts new file mode 100644 index 0000000..4255a1b --- /dev/null +++ b/src/gateway/handlers/agent.ts @@ -0,0 +1,77 @@ +import type { GatewayRequest, OutboundMessage } from '../protocol.js'; +import type { SendFn } from '../router.js'; +import { makeEvent, makeError, ErrorCode } from '../protocol.js'; +import type { SessionBridge } from '../session-bridge.js'; + +export interface AgentHandlerDeps { + sessionBridge: SessionBridge; +} + +export function createAgentHandlers(deps: AgentHandlerDeps) { + return { + 'agent.send': async (request: GatewayRequest, send: SendFn): Promise => { + const params = request.params as { message?: string; connectionId?: string } | undefined; + if (!params?.message) { + return makeError(request.id, ErrorCode.InvalidRequest, 'message is required'); + } + + const connectionId = params.connectionId as string; + if (!connectionId) { + return makeError(request.id, ErrorCode.InvalidRequest, 'connectionId is required (set by server)'); + } + + if (deps.sessionBridge.isBusy(connectionId)) { + return makeError(request.id, ErrorCode.AgentBusy, 'Agent is already processing a request'); + } + + const agent = deps.sessionBridge.getAgent(connectionId); + if (!agent) { + return makeError(request.id, ErrorCode.SessionNotFound, 'No agent for this connection'); + } + + deps.sessionBridge.setBusy(connectionId, true); + + // Set up tool use callback to emit streaming events + deps.sessionBridge.setOnToolUse(connectionId, (event) => { + if (event.type === 'start') { + send(makeEvent(request.id, 'tool_start', { tool: event.tool, args: event.args })); + } else if (event.type === 'end') { + send(makeEvent(request.id, 'tool_end', { + tool: event.tool, + result: event.result ? { + success: event.result.success, + output: event.result.output, + error: event.result.error, + } : undefined, + })); + } + }); + + try { + const response = await agent.process(params.message); + send(makeEvent(request.id, 'done', { content: response })); + } catch (err) { + const message = err instanceof Error ? err.message : 'Unknown error'; + send(makeEvent(request.id, 'error', { code: ErrorCode.InternalError, message })); + } finally { + deps.sessionBridge.setBusy(connectionId, false); + deps.sessionBridge.setOnToolUse(connectionId, undefined); + } + }, + + 'agent.cancel': async (request: GatewayRequest): Promise => { + // Cancel is a placeholder — proper cancellation requires abort controller support in NativeAgent. + // For now, just report whether the agent was busy. + const params = request.params as { connectionId?: string } | undefined; + const connectionId = params?.connectionId as string; + + if (!connectionId) { + return makeError(request.id, ErrorCode.InvalidRequest, 'connectionId is required'); + } + + const wasBusy = deps.sessionBridge.isBusy(connectionId); + // TODO: Wire AbortController into NativeAgent for actual cancellation + return { id: request.id, result: { cancelled: wasBusy } }; + }, + }; +} diff --git a/src/gateway/handlers/handlers.test.ts b/src/gateway/handlers/handlers.test.ts new file mode 100644 index 0000000..686f1d2 --- /dev/null +++ b/src/gateway/handlers/handlers.test.ts @@ -0,0 +1,271 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { createSystemHandlers } from './system.js'; +import { createSessionHandlers } from './sessions.js'; +import { createToolHandlers } from './tools.js'; +import { createAgentHandlers } from './agent.js'; +import { ErrorCode } from '../protocol.js'; +import type { GatewayRequest, GatewayResponse, GatewayError, GatewayEvent, OutboundMessage } from '../protocol.js'; + +describe('system handlers', () => { + const deps = { + startTime: Date.now() - 60_000, + version: '0.1.0', + getSessionCount: () => 3, + getToolCount: () => 6, + getConnectionCount: () => 2, + }; + const handlers = createSystemHandlers(deps); + + it('system.health returns status info', async () => { + const req: GatewayRequest = { id: 1, method: 'system.health' }; + const result = await handlers['system.health'](req) as GatewayResponse; + + expect(result.id).toBe(1); + const r = result.result as Record; + expect(r.status).toBe('ok'); + expect(r.version).toBe('0.1.0'); + expect(r.sessions).toBe(3); + expect(r.tools).toBe(6); + expect(r.connections).toBe(2); + expect(typeof r.uptime).toBe('number'); + expect(r.uptime).toBeGreaterThanOrEqual(59); + }); +}); + +describe('session handlers', () => { + const mockHistory = [ + { role: 'user' as const, content: 'hello' }, + { role: 'assistant' as const, content: 'hi' }, + ]; + + const mockSession = { + id: 'ws:test', + addMessage: vi.fn(), + getHistory: vi.fn(() => mockHistory), + clear: vi.fn(), + }; + + const mockSessionManager = { + listSessions: vi.fn(() => ['ws:test']), + getSession: vi.fn(() => mockSession), + transferSession: vi.fn(), + closeSession: vi.fn(), + }; + + const handlers = createSessionHandlers({ + sessionManager: mockSessionManager as any, + }); + + beforeEach(() => { + vi.clearAllMocks(); + mockSessionManager.listSessions.mockReturnValue(['ws:test']); + mockSessionManager.getSession.mockReturnValue(mockSession); + mockSession.getHistory.mockReturnValue(mockHistory); + }); + + it('sessions.list returns session list with message counts', async () => { + const req: GatewayRequest = { id: 1, method: 'sessions.list' }; + const result = await handlers['sessions.list'](req) as GatewayResponse; + + expect(result.id).toBe(1); + const r = result.result as { sessions: Array<{ id: string; messageCount: number }> }; + expect(r.sessions).toHaveLength(1); + expect(r.sessions[0].id).toBe('ws:test'); + expect(r.sessions[0].messageCount).toBe(2); + }); + + it('sessions.history returns messages with pagination', async () => { + const req: GatewayRequest = { id: 2, method: 'sessions.history', params: { sessionId: 'ws:test', limit: 1, offset: 0 } }; + const result = await handlers['sessions.history'](req) as GatewayResponse; + + const r = result.result as { messages: unknown[]; total: number }; + expect(r.messages).toHaveLength(1); + expect(r.total).toBe(2); + }); + + it('sessions.history requires sessionId', async () => { + const req: GatewayRequest = { id: 3, method: 'sessions.history', params: {} }; + const result = await handlers['sessions.history'](req) as GatewayError; + + expect(result.error.code).toBe(ErrorCode.InvalidRequest); + }); + + it('sessions.create creates a new session', async () => { + const req: GatewayRequest = { id: 4, method: 'sessions.create', params: { sessionId: 'ws:new' } }; + const result = await handlers['sessions.create'](req) as GatewayResponse; + + const r = result.result as { sessionId: string }; + expect(r.sessionId).toBe('ws:new'); + expect(mockSessionManager.getSession).toHaveBeenCalledWith('ws', 'new'); + }); + + it('sessions.create auto-generates session ID', async () => { + const req: GatewayRequest = { id: 5, method: 'sessions.create' }; + const result = await handlers['sessions.create'](req) as GatewayResponse; + + const r = result.result as { sessionId: string }; + expect(r.sessionId).toMatch(/^ws:\d+$/); + }); +}); + +describe('tool handlers', () => { + const mockTool = { + name: 'test.tool', + description: 'A test tool', + inputSchema: { type: 'object' as const, properties: {} }, + execute: vi.fn(), + }; + + const mockRegistry = { + list: vi.fn(() => [mockTool]), + get: vi.fn((name: string) => (name === 'test.tool' ? mockTool : undefined)), + register: vi.fn(), + toAnthropicFormat: vi.fn(), + toOpenAIFormat: vi.fn(), + }; + + const mockExecutor = { + execute: vi.fn(async () => ({ success: true, output: 'done' })), + }; + + const handlers = createToolHandlers({ + toolRegistry: mockRegistry as any, + toolExecutor: mockExecutor as any, + }); + + beforeEach(() => { + vi.clearAllMocks(); + mockRegistry.list.mockReturnValue([mockTool]); + mockRegistry.get.mockImplementation((name: string) => (name === 'test.tool' ? mockTool : undefined)); + mockExecutor.execute.mockResolvedValue({ success: true, output: 'done' }); + }); + + it('tools.list returns tool definitions', async () => { + const req: GatewayRequest = { id: 1, method: 'tools.list' }; + const result = await handlers['tools.list'](req) as GatewayResponse; + + const r = result.result as { tools: Array<{ name: string }> }; + expect(r.tools).toHaveLength(1); + expect(r.tools[0].name).toBe('test.tool'); + }); + + it('tools.invoke executes a tool', async () => { + const req: GatewayRequest = { id: 2, method: 'tools.invoke', params: { tool: 'test.tool', args: {} } }; + const result = await handlers['tools.invoke'](req) as GatewayResponse; + + expect(result.result).toEqual({ success: true, output: 'done' }); + expect(mockExecutor.execute).toHaveBeenCalledWith('test.tool', {}); + }); + + it('tools.invoke errors on missing tool name', async () => { + const req: GatewayRequest = { id: 3, method: 'tools.invoke', params: {} }; + const result = await handlers['tools.invoke'](req) as GatewayError; + expect(result.error.code).toBe(ErrorCode.InvalidRequest); + }); + + it('tools.invoke errors on unknown tool', async () => { + const req: GatewayRequest = { id: 4, method: 'tools.invoke', params: { tool: 'unknown' } }; + const result = await handlers['tools.invoke'](req) as GatewayError; + expect(result.error.code).toBe(ErrorCode.ToolNotFound); + }); +}); + +describe('agent handlers', () => { + const mockAgent = { + process: vi.fn(async () => 'response text'), + setOnToolUse: vi.fn(), + }; + + const mockBridge = { + getAgent: vi.fn(() => mockAgent), + getSessionId: vi.fn(() => 'ws:conn-1'), + isBusy: vi.fn(() => false), + setBusy: vi.fn(), + setOnToolUse: vi.fn(), + }; + + const handlers = createAgentHandlers({ + sessionBridge: mockBridge as any, + }); + + beforeEach(() => { + vi.clearAllMocks(); + mockBridge.isBusy.mockReturnValue(false); + mockBridge.getAgent.mockReturnValue(mockAgent); + mockAgent.process.mockResolvedValue('response text'); + }); + + it('agent.send processes message and sends done event', async () => { + const req: GatewayRequest = { id: 1, method: 'agent.send', params: { message: 'hello', connectionId: 'conn-1' } }; + const sent: OutboundMessage[] = []; + const send = vi.fn((msg: OutboundMessage) => sent.push(msg)); + + await handlers['agent.send'](req, send); + + expect(mockAgent.process).toHaveBeenCalledWith('hello'); + expect(sent).toHaveLength(1); + const doneEvent = sent[0] as GatewayEvent; + expect(doneEvent.event).toBe('done'); + expect((doneEvent.data as any).content).toBe('response text'); + }); + + it('agent.send requires message', async () => { + const req: GatewayRequest = { id: 2, method: 'agent.send', params: { connectionId: 'conn-1' } }; + const send = vi.fn(); + const result = await handlers['agent.send'](req, send) as GatewayError; + + expect(result.error.code).toBe(ErrorCode.InvalidRequest); + expect(result.error.message).toContain('message'); + }); + + it('agent.send rejects when busy', async () => { + mockBridge.isBusy.mockReturnValue(true); + const req: GatewayRequest = { id: 3, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } }; + const send = vi.fn(); + const result = await handlers['agent.send'](req, send) as GatewayError; + + expect(result.error.code).toBe(ErrorCode.AgentBusy); + }); + + it('agent.send handles errors gracefully', async () => { + mockAgent.process.mockRejectedValue(new Error('model failed')); + const req: GatewayRequest = { id: 4, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } }; + const sent: OutboundMessage[] = []; + const send = vi.fn((msg: OutboundMessage) => sent.push(msg)); + + await handlers['agent.send'](req, send); + + const errorEvent = sent[0] as GatewayEvent; + expect(errorEvent.event).toBe('error'); + expect((errorEvent.data as any).message).toBe('model failed'); + }); + + it('agent.send sets and cleans up tool use callback', async () => { + const req: GatewayRequest = { id: 5, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } }; + const send = vi.fn(); + + await handlers['agent.send'](req, send); + + // setOnToolUse called twice: once to set callback, once to clear it + expect(mockBridge.setOnToolUse).toHaveBeenCalledTimes(2); + expect(mockBridge.setOnToolUse).toHaveBeenLastCalledWith('conn-1', undefined); + }); + + it('agent.send sets busy state correctly', async () => { + const req: GatewayRequest = { id: 6, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } }; + const send = vi.fn(); + + await handlers['agent.send'](req, send); + + expect(mockBridge.setBusy).toHaveBeenCalledWith('conn-1', true); + expect(mockBridge.setBusy).toHaveBeenCalledWith('conn-1', false); + }); + + it('agent.cancel returns cancelled state', async () => { + mockBridge.isBusy.mockReturnValue(true); + const req: GatewayRequest = { id: 7, method: 'agent.cancel', params: { connectionId: 'conn-1' } }; + const result = await handlers['agent.cancel'](req) as GatewayResponse; + + expect((result.result as any).cancelled).toBe(true); + }); +}); diff --git a/src/gateway/handlers/index.ts b/src/gateway/handlers/index.ts new file mode 100644 index 0000000..6affefd --- /dev/null +++ b/src/gateway/handlers/index.ts @@ -0,0 +1,8 @@ +export { createSystemHandlers } from './system.js'; +export type { SystemHandlerDeps } from './system.js'; +export { createSessionHandlers } from './sessions.js'; +export type { SessionHandlerDeps } from './sessions.js'; +export { createToolHandlers } from './tools.js'; +export type { ToolHandlerDeps } from './tools.js'; +export { createAgentHandlers } from './agent.js'; +export type { AgentHandlerDeps } from './agent.js'; diff --git a/src/gateway/handlers/sessions.ts b/src/gateway/handlers/sessions.ts new file mode 100644 index 0000000..3e10df2 --- /dev/null +++ b/src/gateway/handlers/sessions.ts @@ -0,0 +1,59 @@ +import type { GatewayRequest, OutboundMessage } from '../protocol.js'; +import { makeResponse, makeError, ErrorCode } from '../protocol.js'; +import type { SessionManager } from '../../session/manager.js'; + +export interface SessionHandlerDeps { + sessionManager: SessionManager; +} + +export function createSessionHandlers(deps: SessionHandlerDeps) { + return { + 'sessions.list': async (request: GatewayRequest): Promise => { + const sessionIds = deps.sessionManager.listSessions(); + const sessions = sessionIds.map(id => ({ + id, + messageCount: deps.sessionManager.getSession( + id.split(':')[0], + id.split(':').slice(1).join(':') + ).getHistory().length, + })); + return makeResponse(request.id, { sessions }); + }, + + 'sessions.history': async (request: GatewayRequest): Promise => { + const params = request.params as { sessionId?: string; limit?: number; offset?: number } | undefined; + if (!params?.sessionId) { + return makeError(request.id, ErrorCode.InvalidRequest, 'sessionId is required'); + } + + const { sessionId, limit, offset } = params; + const parts = sessionId.split(':'); + const frontend = parts[0]; + const userId = parts.slice(1).join(':'); + const session = deps.sessionManager.getSession(frontend, userId); + const allMessages = session.getHistory(); + + const start = offset ?? 0; + const end = limit ? start + limit : allMessages.length; + const messages = allMessages.slice(start, end); + + return makeResponse(request.id, { + messages, + total: allMessages.length, + }); + }, + + 'sessions.create': async (request: GatewayRequest): Promise => { + const params = request.params as { sessionId?: string } | undefined; + const sessionId = params?.sessionId ?? `ws:${Date.now()}`; + const parts = sessionId.split(':'); + const frontend = parts[0]; + const userId = parts.slice(1).join(':'); + + // Creating a session via getSession is idempotent + deps.sessionManager.getSession(frontend, userId); + + return makeResponse(request.id, { sessionId }); + }, + }; +} diff --git a/src/gateway/handlers/system.ts b/src/gateway/handlers/system.ts new file mode 100644 index 0000000..b1ae11f --- /dev/null +++ b/src/gateway/handlers/system.ts @@ -0,0 +1,25 @@ +import type { GatewayRequest, OutboundMessage } from '../protocol.js'; +import { makeResponse } from '../protocol.js'; + +export interface SystemHandlerDeps { + startTime: number; + version: string; + getSessionCount: () => number; + getToolCount: () => number; + getConnectionCount: () => number; +} + +export function createSystemHandlers(deps: SystemHandlerDeps) { + return { + 'system.health': async (request: GatewayRequest): Promise => { + return makeResponse(request.id, { + status: 'ok', + uptime: Math.floor((Date.now() - deps.startTime) / 1000), + version: deps.version, + sessions: deps.getSessionCount(), + tools: deps.getToolCount(), + connections: deps.getConnectionCount(), + }); + }, + }; +} diff --git a/src/gateway/handlers/tools.ts b/src/gateway/handlers/tools.ts new file mode 100644 index 0000000..f5e6e94 --- /dev/null +++ b/src/gateway/handlers/tools.ts @@ -0,0 +1,37 @@ +import type { GatewayRequest, OutboundMessage } from '../protocol.js'; +import { makeResponse, makeError, ErrorCode } from '../protocol.js'; +import type { ToolRegistry } from '../../tools/registry.js'; +import type { ToolExecutor } from '../../tools/executor.js'; + +export interface ToolHandlerDeps { + toolRegistry: ToolRegistry; + toolExecutor: ToolExecutor; +} + +export function createToolHandlers(deps: ToolHandlerDeps) { + return { + 'tools.list': async (request: GatewayRequest): Promise => { + const tools = deps.toolRegistry.list().map(t => ({ + name: t.name, + description: t.description, + inputSchema: t.inputSchema, + })); + return makeResponse(request.id, { tools }); + }, + + 'tools.invoke': async (request: GatewayRequest): Promise => { + const params = request.params as { tool?: string; args?: Record } | undefined; + if (!params?.tool) { + return makeError(request.id, ErrorCode.InvalidRequest, 'tool name is required'); + } + + const tool = deps.toolRegistry.get(params.tool); + if (!tool) { + return makeError(request.id, ErrorCode.ToolNotFound, `Tool not found: ${params.tool}`); + } + + const result = await deps.toolExecutor.execute(params.tool, params.args ?? {}); + return makeResponse(request.id, result); + }, + }; +} diff --git a/src/gateway/index.ts b/src/gateway/index.ts new file mode 100644 index 0000000..6aa5c24 --- /dev/null +++ b/src/gateway/index.ts @@ -0,0 +1,29 @@ +export { GatewayServer } from './server.js'; +export type { GatewayServerConfig } from './server.js'; +export { Router } from './router.js'; +export type { HandlerFn, SendFn } from './router.js'; +export { SessionBridge } from './session-bridge.js'; +export type { SessionBridgeConfig } from './session-bridge.js'; +export { authenticateRequest } from './auth.js'; +export type { AuthConfig, AuthResult } from './auth.js'; +export { + ErrorCode, + isValidRequest, + parseMessage, + makeResponse, + makeError, + makeEvent, +} from './protocol.js'; +export type { + GatewayRequest, + GatewayResponse, + GatewayError, + GatewayEvent, + OutboundMessage, + EventType, + ContentEventData, + ToolStartEventData, + ToolEndEventData, + DoneEventData, + ErrorEventData, +} from './protocol.js'; diff --git a/src/gateway/protocol.test.ts b/src/gateway/protocol.test.ts new file mode 100644 index 0000000..4206d04 --- /dev/null +++ b/src/gateway/protocol.test.ts @@ -0,0 +1,90 @@ +import { describe, it, expect } from 'vitest'; +import { + isValidRequest, + parseMessage, + makeResponse, + makeError, + makeEvent, + ErrorCode, +} from './protocol.js'; + +describe('protocol', () => { + describe('isValidRequest', () => { + it('accepts valid request with params', () => { + expect(isValidRequest({ id: 1, method: 'agent.send', params: { message: 'hello' } })).toBe(true); + }); + + it('accepts valid request without params', () => { + expect(isValidRequest({ id: 1, method: 'system.health' })).toBe(true); + }); + + it('rejects missing id', () => { + expect(isValidRequest({ method: 'test' })).toBe(false); + }); + + it('rejects missing method', () => { + expect(isValidRequest({ id: 1 })).toBe(false); + }); + + it('rejects non-object', () => { + expect(isValidRequest('not an object')).toBe(false); + expect(isValidRequest(null)).toBe(false); + expect(isValidRequest(42)).toBe(false); + }); + + it('rejects non-numeric id', () => { + expect(isValidRequest({ id: 'abc', method: 'test' })).toBe(false); + }); + + it('rejects non-string method', () => { + expect(isValidRequest({ id: 1, method: 42 })).toBe(false); + }); + + it('rejects non-object params', () => { + expect(isValidRequest({ id: 1, method: 'test', params: 'bad' })).toBe(false); + }); + }); + + describe('parseMessage', () => { + it('parses valid JSON into GatewayRequest', () => { + const msg = parseMessage('{"id":1,"method":"agent.send","params":{"message":"hi"}}'); + expect(msg).toEqual({ id: 1, method: 'agent.send', params: { message: 'hi' } }); + }); + + it('returns null for invalid JSON', () => { + expect(parseMessage('not json')).toBeNull(); + }); + + it('returns null for valid JSON that is not a valid request', () => { + expect(parseMessage('{"method":"test"}')).toBeNull(); + }); + }); + + describe('makeResponse', () => { + it('creates a response message', () => { + expect(makeResponse(1, { status: 'ok' })).toEqual({ + id: 1, + result: { status: 'ok' }, + }); + }); + }); + + describe('makeError', () => { + it('creates an error message', () => { + expect(makeError(1, ErrorCode.MethodNotFound, 'Not found')).toEqual({ + id: 1, + error: { code: -3, message: 'Not found' }, + }); + }); + }); + + describe('makeEvent', () => { + it('creates an event message', () => { + expect(makeEvent(1, 'content', { text: 'hello' })).toEqual({ + id: 1, + event: 'content', + data: { text: 'hello' }, + }); + }); + }); +}); diff --git a/src/gateway/protocol.ts b/src/gateway/protocol.ts new file mode 100644 index 0000000..40374fa --- /dev/null +++ b/src/gateway/protocol.ts @@ -0,0 +1,119 @@ +// Gateway protocol types — JSON-RPC-like messages over WebSocket. + +// ── Client → Server ──────────────────────────────────────────── + +export interface GatewayRequest { + id: number; + method: string; + params?: Record; +} + +// ── Server → Client ──────────────────────────────────────────── + +export interface GatewayResponse { + id: number; + result: unknown; +} + +export interface GatewayError { + id: number; + error: { + code: ErrorCode; + message: string; + }; +} + +export interface GatewayEvent { + id: number; + event: EventType; + data: unknown; +} + +// ── Event types emitted during agent.send ────────────────────── + +export type EventType = + | 'content' + | 'tool_start' + | 'tool_end' + | 'done' + | 'error'; + +export interface ContentEventData { + text: string; +} + +export interface ToolStartEventData { + tool: string; + args: unknown; +} + +export interface ToolEndEventData { + tool: string; + result: { + success: boolean; + output: string; + error?: string; + }; +} + +export interface DoneEventData { + content: string; +} + +export interface ErrorEventData { + code: ErrorCode; + message: string; +} + +// ── Error codes ──────────────────────────────────────────────── + +export enum ErrorCode { + ParseError = -1, + InvalidRequest = -2, + MethodNotFound = -3, + AuthRequired = -4, + AuthFailed = -5, + SessionNotFound = 1, + ToolNotFound = 2, + AgentBusy = 3, + RequestCancelled = 4, + InternalError = 5, +} + +// ── Outbound message (union of all server → client types) ────── + +export type OutboundMessage = GatewayResponse | GatewayError | GatewayEvent; + +// ── Validation helpers ───────────────────────────────────────── + +export function isValidRequest(msg: unknown): msg is GatewayRequest { + if (typeof msg !== 'object' || msg === null) return false; + const obj = msg as Record; + return ( + typeof obj.id === 'number' && + typeof obj.method === 'string' && + (obj.params === undefined || (typeof obj.params === 'object' && obj.params !== null)) + ); +} + +export function parseMessage(raw: string): GatewayRequest | null { + try { + const parsed = JSON.parse(raw); + if (isValidRequest(parsed)) return parsed; + return null; + } catch { + return null; + } +} + +export function makeResponse(id: number, result: unknown): GatewayResponse { + return { id, result }; +} + +export function makeError(id: number, code: ErrorCode, message: string): GatewayError { + return { id, error: { code, message } }; +} + +export function makeEvent(id: number, event: EventType, data: unknown): GatewayEvent { + return { id, event, data }; +} diff --git a/src/gateway/router.test.ts b/src/gateway/router.test.ts new file mode 100644 index 0000000..da898ef --- /dev/null +++ b/src/gateway/router.test.ts @@ -0,0 +1,58 @@ +import { describe, it, expect, vi } from 'vitest'; +import { Router } from './router.js'; +import type { GatewayRequest, OutboundMessage } from './protocol.js'; +import { makeResponse, ErrorCode } from './protocol.js'; + +describe('Router', () => { + it('dispatches to registered handler', async () => { + const router = new Router(); + const handler = vi.fn(async (req: GatewayRequest) => makeResponse(req.id, { ok: true })); + router.register('test.method', handler); + + const request: GatewayRequest = { id: 1, method: 'test.method', params: {} }; + const send = vi.fn(); + const result = await router.dispatch(request, send); + + expect(handler).toHaveBeenCalledWith(request, send); + expect(result).toEqual({ id: 1, result: { ok: true } }); + }); + + it('returns MethodNotFound for unregistered method', async () => { + const router = new Router(); + const request: GatewayRequest = { id: 1, method: 'unknown.method' }; + const send = vi.fn(); + const result = await router.dispatch(request, send); + + expect(result).toEqual({ + id: 1, + error: { code: ErrorCode.MethodNotFound, message: 'Unknown method: unknown.method' }, + }); + }); + + it('lists registered methods', () => { + const router = new Router(); + router.register('a.method', async () => {}); + router.register('b.method', async () => {}); + + expect(router.listMethods()).toEqual(['a.method', 'b.method']); + }); + + it('handler can send streaming events via send function', async () => { + const router = new Router(); + router.register('stream.test', async (req, send) => { + send({ id: req.id, event: 'content', data: { text: 'chunk1' } }); + send({ id: req.id, event: 'content', data: { text: 'chunk2' } }); + send({ id: req.id, event: 'done', data: { content: 'chunk1chunk2' } }); + }); + + const request: GatewayRequest = { id: 1, method: 'stream.test' }; + const sent: OutboundMessage[] = []; + const send = vi.fn((msg: OutboundMessage) => sent.push(msg)); + + await router.dispatch(request, send); + + expect(sent).toHaveLength(3); + expect(sent[0]).toEqual({ id: 1, event: 'content', data: { text: 'chunk1' } }); + expect(sent[2]).toEqual({ id: 1, event: 'done', data: { content: 'chunk1chunk2' } }); + }); +}); diff --git a/src/gateway/router.ts b/src/gateway/router.ts new file mode 100644 index 0000000..fb81f55 --- /dev/null +++ b/src/gateway/router.ts @@ -0,0 +1,27 @@ +import type { GatewayRequest, OutboundMessage } from './protocol.js'; +import { makeError, ErrorCode } from './protocol.js'; + +// A handler function receives a request and a send function for streaming events. +// It returns a final response/error, or void if it already sent a done event. +export type SendFn = (msg: OutboundMessage) => void; +export type HandlerFn = (request: GatewayRequest, send: SendFn) => Promise; + +export class Router { + private handlers: Map = new Map(); + + register(method: string, handler: HandlerFn): void { + this.handlers.set(method, handler); + } + + async dispatch(request: GatewayRequest, send: SendFn): Promise { + const handler = this.handlers.get(request.method); + if (!handler) { + return makeError(request.id, ErrorCode.MethodNotFound, `Unknown method: ${request.method}`); + } + return handler(request, send); + } + + listMethods(): string[] { + return Array.from(this.handlers.keys()); + } +} diff --git a/src/gateway/server.test.ts b/src/gateway/server.test.ts new file mode 100644 index 0000000..c40cccf --- /dev/null +++ b/src/gateway/server.test.ts @@ -0,0 +1,188 @@ +import { describe, it, expect, beforeAll, afterAll, vi } from 'vitest'; +import { WebSocket } from 'ws'; +import { GatewayServer } from './server.js'; +import type { GatewayServerConfig } from './server.js'; +import type { GatewayResponse, GatewayError, GatewayEvent } from './protocol.js'; +import { ErrorCode } from './protocol.js'; + +// Minimal mocks for dependencies +const mockSession = { + id: 'test', + addMessage: vi.fn(), + getHistory: vi.fn(() => []), + clear: vi.fn(), + setHistory: vi.fn(), +}; + +const mockSessionManager = { + getSession: vi.fn(() => mockSession), + listSessions: vi.fn(() => ['ws:test']), + transferSession: vi.fn(), + closeSession: vi.fn(), +}; + +const mockModelClient = { + chat: vi.fn(async () => ({ + content: 'Hello from Flynn!', + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5 }, + })), +}; + +const mockToolRegistry = { + register: vi.fn(), + get: vi.fn((name: string) => (name === 'shell.exec' ? { name: 'shell.exec', description: 'Run shell', inputSchema: { type: 'object', properties: {} } } : undefined)), + list: vi.fn(() => [{ name: 'shell.exec', description: 'Run shell', inputSchema: { type: 'object', properties: {} } }]), + toAnthropicFormat: vi.fn(() => []), + toOpenAIFormat: vi.fn(() => []), +}; + +const mockToolExecutor = { + execute: vi.fn(async () => ({ success: true, output: 'executed' })), +}; + +const TEST_PORT = 18899; + +let server: GatewayServer; + +function createClient(): Promise { + return new Promise((resolve, reject) => { + const ws = new WebSocket(`ws://127.0.0.1:${TEST_PORT}`); + ws.on('open', () => resolve(ws)); + ws.on('error', reject); + }); +} + +function sendAndReceive(ws: WebSocket, msg: object): Promise { + return new Promise((resolve) => { + ws.once('message', (data) => { + resolve(JSON.parse(data.toString())); + }); + ws.send(JSON.stringify(msg)); + }); +} + +function sendAndReceiveAll(ws: WebSocket, msg: object, count: number): Promise> { + return new Promise((resolve) => { + const messages: Array = []; + const handler = (data: Buffer) => { + messages.push(JSON.parse(data.toString())); + if (messages.length >= count) { + ws.off('message', handler); + resolve(messages); + } + }; + ws.on('message', handler); + ws.send(JSON.stringify(msg)); + }); +} + +describe('GatewayServer integration', () => { + beforeAll(async () => { + server = new GatewayServer({ + port: TEST_PORT, + sessionManager: mockSessionManager as unknown as GatewayServerConfig['sessionManager'], + modelClient: mockModelClient, + systemPrompt: 'Test prompt', + toolRegistry: mockToolRegistry as unknown as GatewayServerConfig['toolRegistry'], + toolExecutor: mockToolExecutor as unknown as GatewayServerConfig['toolExecutor'], + version: '0.1.0-test', + }); + await server.start(); + }); + + afterAll(async () => { + await server.stop(); + }); + + it('responds to system.health', async () => { + const ws = await createClient(); + try { + const result = await sendAndReceive(ws, { id: 1, method: 'system.health' }); + const response = result as GatewayResponse; + expect(response.id).toBe(1); + const r = response.result as Record; + expect(r.status).toBe('ok'); + expect(r.version).toBe('0.1.0-test'); + expect(typeof r.uptime).toBe('number'); + } finally { + ws.close(); + } + }); + + it('returns MethodNotFound for unknown method', async () => { + const ws = await createClient(); + try { + const result = await sendAndReceive(ws, { id: 2, method: 'unknown.method' }); + const error = result as GatewayError; + expect(error.error.code).toBe(ErrorCode.MethodNotFound); + } finally { + ws.close(); + } + }); + + it('returns ParseError for invalid JSON', async () => { + const ws = await createClient(); + try { + const result = await new Promise((resolve) => { + ws.once('message', (data) => resolve(JSON.parse(data.toString()))); + ws.send('not valid json'); + }); + expect(result.error.code).toBe(ErrorCode.ParseError); + } finally { + ws.close(); + } + }); + + it('lists tools via tools.list', async () => { + const ws = await createClient(); + try { + const result = await sendAndReceive(ws, { id: 3, method: 'tools.list' }); + const response = result as GatewayResponse; + const r = response.result as { tools: Array<{ name: string }> }; + expect(r.tools.length).toBeGreaterThan(0); + expect(r.tools[0].name).toBe('shell.exec'); + } finally { + ws.close(); + } + }); + + it('sends agent message and receives done event', async () => { + const ws = await createClient(); + try { + // agent.send streams events — we expect a 'done' event + const messages = await sendAndReceiveAll(ws, { id: 4, method: 'agent.send', params: { message: 'hi' } }, 1); + const doneEvent = messages[0] as GatewayEvent; + expect(doneEvent.id).toBe(4); + expect(doneEvent.event).toBe('done'); + expect((doneEvent.data as any).content).toBe('Hello from Flynn!'); + } finally { + ws.close(); + } + }); + + it('tracks connections correctly', async () => { + const ws1 = await createClient(); + const ws2 = await createClient(); + try { + const result = await sendAndReceive(ws1, { id: 5, method: 'system.health' }); + const r = (result as GatewayResponse).result as Record; + expect(r.connections).toBe(2); + } finally { + ws1.close(); + ws2.close(); + } + }); + + it('lists registered methods', () => { + const methods = server.getMethods(); + expect(methods).toContain('system.health'); + expect(methods).toContain('agent.send'); + expect(methods).toContain('agent.cancel'); + expect(methods).toContain('sessions.list'); + expect(methods).toContain('sessions.history'); + expect(methods).toContain('sessions.create'); + expect(methods).toContain('tools.list'); + expect(methods).toContain('tools.invoke'); + }); +}); diff --git a/src/gateway/server.ts b/src/gateway/server.ts new file mode 100644 index 0000000..f76d64a --- /dev/null +++ b/src/gateway/server.ts @@ -0,0 +1,204 @@ +import { WebSocketServer, WebSocket } from 'ws'; +import { randomUUID } from 'crypto'; +import type { IncomingMessage } from 'http'; +import { Router } from './router.js'; +import { SessionBridge } from './session-bridge.js'; +import type { SessionBridgeConfig } from './session-bridge.js'; +import { authenticateRequest } from './auth.js'; +import type { AuthConfig } from './auth.js'; +import { + parseMessage, + makeError, + ErrorCode, + type OutboundMessage, +} from './protocol.js'; +import { + createSystemHandlers, + createSessionHandlers, + createToolHandlers, + createAgentHandlers, +} from './handlers/index.js'; +import type { SessionManager } from '../session/manager.js'; +import type { ToolRegistry } from '../tools/registry.js'; +import type { ToolExecutor } from '../tools/executor.js'; + +export interface GatewayServerConfig { + port: number; + host?: string; + sessionManager: SessionManager; + modelClient: SessionBridgeConfig['modelClient']; + systemPrompt: string; + toolRegistry: ToolRegistry; + toolExecutor: ToolExecutor; + version?: string; + auth?: AuthConfig; +} + +export class GatewayServer { + private wss: WebSocketServer | null = null; + private router: Router; + private sessionBridge: SessionBridge; + private connectionMap: Map = new Map(); + private config: GatewayServerConfig; + private startTime: number = Date.now(); + + constructor(config: GatewayServerConfig) { + this.config = config; + + this.sessionBridge = new SessionBridge({ + sessionManager: config.sessionManager, + modelClient: config.modelClient, + systemPrompt: config.systemPrompt, + toolRegistry: config.toolRegistry, + toolExecutor: config.toolExecutor, + }); + + this.router = new Router(); + this.registerHandlers(); + } + + private registerHandlers(): void { + const systemHandlers = createSystemHandlers({ + startTime: this.startTime, + version: this.config.version ?? '0.1.0', + getSessionCount: () => this.sessionBridge.listSessions().length, + getToolCount: () => this.config.toolRegistry.list().length, + getConnectionCount: () => this.sessionBridge.connectionCount, + }); + + const sessionHandlers = createSessionHandlers({ + sessionManager: this.config.sessionManager, + }); + + const toolHandlers = createToolHandlers({ + toolRegistry: this.config.toolRegistry, + toolExecutor: this.config.toolExecutor, + }); + + const agentHandlers = createAgentHandlers({ + sessionBridge: this.sessionBridge, + }); + + // Register all methods + for (const [method, handler] of Object.entries(systemHandlers)) { + this.router.register(method, handler); + } + for (const [method, handler] of Object.entries(sessionHandlers)) { + this.router.register(method, handler); + } + for (const [method, handler] of Object.entries(toolHandlers)) { + this.router.register(method, handler); + } + for (const [method, handler] of Object.entries(agentHandlers)) { + this.router.register(method, handler); + } + } + + async start(): Promise { + return new Promise((resolve) => { + this.wss = new WebSocketServer({ + port: this.config.port, + host: this.config.host ?? '127.0.0.1', + }); + + this.wss.on('connection', (ws: WebSocket, req: IncomingMessage) => { + // Auth check on upgrade + const authResult = authenticateRequest(req, this.config.auth ?? {}); + if (!authResult.authenticated) { + ws.close(4001, authResult.error ?? 'Authentication failed'); + return; + } + this.handleConnection(ws, authResult.identity); + }); + + this.wss.on('listening', () => { + const addr = this.wss?.address(); + const portStr = typeof addr === 'object' && addr ? `:${addr.port}` : ''; + console.log(`Gateway WebSocket server listening on ${this.config.host ?? '127.0.0.1'}${portStr}`); + resolve(); + }); + }); + } + + async stop(): Promise { + return new Promise((resolve) => { + if (!this.wss) { + resolve(); + return; + } + + // Close all connections + for (const [ws, connectionId] of this.connectionMap) { + this.sessionBridge.disconnect(connectionId); + ws.close(1001, 'Server shutting down'); + } + this.connectionMap.clear(); + + this.wss.close(() => { + this.wss = null; + resolve(); + }); + }); + } + + private handleConnection(ws: WebSocket, identity?: string): void { + const connectionId = randomUUID(); + this.sessionBridge.connect(connectionId); + this.connectionMap.set(ws, connectionId); + + ws.on('message', async (data) => { + const raw = data.toString(); + await this.handleMessage(ws, connectionId, raw); + }); + + ws.on('close', () => { + this.sessionBridge.disconnect(connectionId); + this.connectionMap.delete(ws); + }); + + ws.on('error', (err) => { + console.error(`WebSocket error (${connectionId}):`, err.message); + }); + } + + private async handleMessage(ws: WebSocket, connectionId: string, raw: string): Promise { + const request = parseMessage(raw); + + if (!request) { + this.send(ws, makeError(0, ErrorCode.ParseError, 'Invalid JSON or missing required fields')); + return; + } + + // Inject connectionId into params so handlers can identify the client + if (!request.params) request.params = {}; + request.params.connectionId = connectionId; + + const send = (msg: OutboundMessage) => this.send(ws, msg); + const response = await this.router.dispatch(request, send); + + if (response) { + this.send(ws, response); + } + } + + private send(ws: WebSocket, msg: OutboundMessage): void { + if (ws.readyState === WebSocket.OPEN) { + ws.send(JSON.stringify(msg)); + } + } + + /** Get the underlying WebSocketServer (for testing). */ + getWss(): WebSocketServer | null { + return this.wss; + } + + /** Get the session bridge (for testing/debugging). */ + getSessionBridge(): SessionBridge { + return this.sessionBridge; + } + + /** Get list of registered methods. */ + getMethods(): string[] { + return this.router.listMethods(); + } +} diff --git a/src/gateway/session-bridge.test.ts b/src/gateway/session-bridge.test.ts new file mode 100644 index 0000000..2602925 --- /dev/null +++ b/src/gateway/session-bridge.test.ts @@ -0,0 +1,144 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { SessionBridge } from './session-bridge.js'; +import type { SessionBridgeConfig } from './session-bridge.js'; + +// Minimal mocks +const mockSession = { + id: 'test', + addMessage: vi.fn(), + getHistory: vi.fn(() => []), + clear: vi.fn(), +}; + +const mockSessionManager = { + getSession: vi.fn(() => mockSession), + listSessions: vi.fn(() => []), + transferSession: vi.fn(), + closeSession: vi.fn(), +}; + +const mockModelClient = { + chat: vi.fn(async () => ({ + content: 'test', + stopReason: 'end_turn', + usage: { inputTokens: 0, outputTokens: 0 }, + })), +}; + +const mockToolRegistry = { + register: vi.fn(), + get: vi.fn(), + list: vi.fn(() => []), + toAnthropicFormat: vi.fn(() => []), + toOpenAIFormat: vi.fn(() => []), +}; + +const mockToolExecutor = { + execute: vi.fn(), +}; + +function createBridge(): SessionBridge { + return new SessionBridge({ + sessionManager: mockSessionManager as unknown as SessionBridgeConfig['sessionManager'], + modelClient: mockModelClient, + systemPrompt: 'test prompt', + toolRegistry: mockToolRegistry as unknown as SessionBridgeConfig['toolRegistry'], + toolExecutor: mockToolExecutor as unknown as SessionBridgeConfig['toolExecutor'], + }); +} + +describe('SessionBridge', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('connect assigns a connection ID', () => { + const bridge = createBridge(); + const id = bridge.connect('conn-1'); + expect(id).toBe('conn-1'); + expect(bridge.connectionCount).toBe(1); + }); + + it('connect auto-generates ID when not provided', () => { + const bridge = createBridge(); + const id = bridge.connect(); + expect(typeof id).toBe('string'); + expect(id.length).toBeGreaterThan(0); + }); + + it('getAgent returns agent for connected client', () => { + const bridge = createBridge(); + bridge.connect('conn-1'); + const agent = bridge.getAgent('conn-1'); + expect(agent).toBeDefined(); + }); + + it('getSessionId returns session ID for connected client', () => { + const bridge = createBridge(); + bridge.connect('conn-1'); + expect(bridge.getSessionId('conn-1')).toBe('ws:conn-1'); + }); + + it('disconnect removes client but preserves session', () => { + const bridge = createBridge(); + bridge.connect('conn-1'); + bridge.disconnect('conn-1'); + expect(bridge.connectionCount).toBe(0); + expect(bridge.getAgent('conn-1')).toBeUndefined(); + }); + + it('tracks busy state', () => { + const bridge = createBridge(); + bridge.connect('conn-1'); + expect(bridge.isBusy('conn-1')).toBe(false); + bridge.setBusy('conn-1', true); + expect(bridge.isBusy('conn-1')).toBe(true); + bridge.setBusy('conn-1', false); + expect(bridge.isBusy('conn-1')).toBe(false); + }); + + it('switchSession changes session for a connection', () => { + const bridge = createBridge(); + bridge.connect('conn-1'); + expect(bridge.getSessionId('conn-1')).toBe('ws:conn-1'); + + bridge.switchSession('conn-1', 'custom-session'); + expect(bridge.getSessionId('conn-1')).toBe('custom-session'); + }); + + it('switchSession throws when busy', () => { + const bridge = createBridge(); + bridge.connect('conn-1'); + bridge.setBusy('conn-1', true); + + expect(() => bridge.switchSession('conn-1', 'other')).toThrow('Cannot switch session while agent is busy'); + }); + + it('switchSession throws for unknown connection', () => { + const bridge = createBridge(); + expect(() => bridge.switchSession('unknown', 'other')).toThrow('Unknown connection'); + }); + + it('listSessions groups connections by session', () => { + const bridge = createBridge(); + bridge.connect('conn-1'); + bridge.connect('conn-2'); + // Switch conn-2 to share conn-1's session + bridge.switchSession('conn-2', 'ws:conn-1'); + + const sessions = bridge.listSessions(); + expect(sessions).toEqual([{ sessionId: 'ws:conn-1', connections: 2 }]); + }); + + it('shared sessions keep agent alive when one client disconnects', () => { + const bridge = createBridge(); + bridge.connect('conn-1'); + bridge.connect('conn-2'); + bridge.switchSession('conn-2', 'ws:conn-1'); + + bridge.disconnect('conn-1'); + // conn-2 still has the session + expect(bridge.getAgent('conn-2')).toBeDefined(); + expect(bridge.connectionCount).toBe(1); + }); +}); diff --git a/src/gateway/session-bridge.ts b/src/gateway/session-bridge.ts new file mode 100644 index 0000000..0e3b74b --- /dev/null +++ b/src/gateway/session-bridge.ts @@ -0,0 +1,135 @@ +import { randomUUID } from 'crypto'; +import type { SessionManager } from '../session/manager.js'; +import type { Session } from '../session/manager.js'; +import type { ModelClient } from '../models/types.js'; +import type { ModelRouter } from '../models/router.js'; +import type { ToolRegistry } from '../tools/registry.js'; +import type { ToolExecutor } from '../tools/executor.js'; +import { NativeAgent } from '../backends/native/agent.js'; +import type { ToolUseEvent } from '../backends/native/agent.js'; + +export interface SessionBridgeConfig { + sessionManager: SessionManager; + modelClient: ModelClient | ModelRouter; + systemPrompt: string; + toolRegistry: ToolRegistry; + toolExecutor: ToolExecutor; +} + +interface ClientEntry { + connectionId: string; + sessionId: string; + agent: NativeAgent; + busy: boolean; +} + +export class SessionBridge { + private clients: Map = new Map(); + private agents: Map = new Map(); + private config: SessionBridgeConfig; + + constructor(config: SessionBridgeConfig) { + this.config = config; + } + + /** Register a new WS connection. Returns the assigned connection ID. */ + connect(connectionId?: string): string { + const id = connectionId ?? randomUUID(); + const sessionId = `ws:${id}`; + const agent = this.getOrCreateAgent(sessionId); + + this.clients.set(id, { + connectionId: id, + sessionId, + agent, + busy: false, + }); + + return id; + } + + /** Remove a WS connection. Does NOT destroy the session (persists in SQLite). */ + disconnect(connectionId: string): void { + const client = this.clients.get(connectionId); + if (client) { + // Only remove the agent if no other clients share the session + const otherClients = Array.from(this.clients.values()) + .filter(c => c.sessionId === client.sessionId && c.connectionId !== connectionId); + if (otherClients.length === 0) { + this.agents.delete(client.sessionId); + } + this.clients.delete(connectionId); + } + } + + /** Switch a connection to a different session (e.g. resuming an old session). */ + switchSession(connectionId: string, sessionId: string): void { + const client = this.clients.get(connectionId); + if (!client) throw new Error(`Unknown connection: ${connectionId}`); + if (client.busy) throw new Error('Cannot switch session while agent is busy'); + + const agent = this.getOrCreateAgent(sessionId); + client.sessionId = sessionId; + client.agent = agent; + } + + /** Get the NativeAgent for a connection. */ + getAgent(connectionId: string): NativeAgent | undefined { + return this.clients.get(connectionId)?.agent; + } + + /** Get the session ID for a connection. */ + getSessionId(connectionId: string): string | undefined { + return this.clients.get(connectionId)?.sessionId; + } + + /** Check if a connection's agent is busy. */ + isBusy(connectionId: string): boolean { + return this.clients.get(connectionId)?.busy ?? false; + } + + /** Mark a connection's agent as busy/idle. */ + setBusy(connectionId: string, busy: boolean): void { + const client = this.clients.get(connectionId); + if (client) client.busy = busy; + } + + /** Set onToolUse callback for a connection's agent. */ + setOnToolUse(connectionId: string, callback: ((event: ToolUseEvent) => void) | undefined): void { + const client = this.clients.get(connectionId); + if (client) client.agent.setOnToolUse(callback); + } + + /** List all active sessions with connection counts. */ + listSessions(): Array<{ sessionId: string; connections: number }> { + const sessionMap = new Map(); + for (const client of this.clients.values()) { + sessionMap.set(client.sessionId, (sessionMap.get(client.sessionId) ?? 0) + 1); + } + return Array.from(sessionMap.entries()).map(([sessionId, connections]) => ({ + sessionId, + connections, + })); + } + + /** Get count of active connections. */ + get connectionCount(): number { + return this.clients.size; + } + + private getOrCreateAgent(sessionId: string): NativeAgent { + let agent = this.agents.get(sessionId); + if (!agent) { + const session = this.config.sessionManager.getSession('ws', sessionId); + agent = new NativeAgent({ + modelClient: this.config.modelClient, + systemPrompt: this.config.systemPrompt, + session, + toolRegistry: this.config.toolRegistry, + toolExecutor: this.config.toolExecutor, + }); + this.agents.set(sessionId, agent); + } + return agent; + } +}