feat(gateway): add WebSocket gateway with JSON-RPC protocol and auth

Phase 2 of the Flynn roadmap. Adds a WebSocket gateway server that
starts alongside the Telegram bot, providing real-time API access to
the agent, sessions, and tools.

Protocol: JSON-RPC-like (request/response/event) over WebSocket.
8 methods: agent.send, agent.cancel, sessions.list, sessions.history,
sessions.create, tools.list, tools.invoke, system.health.

Auth: Bearer token + Tailscale identity header support.
Session bridge: per-connection agent instances with shared model router.

New files: src/gateway/ (protocol, router, server, auth, session-bridge,
handlers for agent/sessions/tools/system).
57 new tests (181 total), typecheck clean.
This commit is contained in:
William Valentin
2026-02-05 19:11:25 -08:00
parent ad7fc241f1
commit f30a8bc318
21 changed files with 1878 additions and 2 deletions
@@ -0,0 +1,214 @@
# Phase 2: WebSocket Gateway — Implementation Plan
## Goal
Add a WebSocket gateway to the Flynn daemon that allows real-time communication from web clients (and eventually any WS client). The gateway wraps existing daemon components (session manager, agent, tool registry) — no refactoring of existing code.
## Approach
Incremental, additive. Existing Telegram bot and TUI continue working unchanged. The gateway is a new module that starts alongside them.
## Architecture
```
src/gateway/
├── protocol.ts # JSON-RPC message types + event types
├── server.ts # WebSocket server (ws library), connection lifecycle
├── router.ts # Method routing → handler dispatch
├── auth.ts # Token auth + Tailscale identity headers
├── session-bridge.ts # Maps WS client connections to agent sessions
└── handlers/
├── agent.ts # agent.send (streaming), agent.cancel, agent.status
├── sessions.ts # sessions.list, sessions.history, sessions.create
├── tools.ts # tools.list, tools.invoke
└── system.ts # system.health, system.info
```
## Protocol
JSON-RPC-like over WebSocket (not full JSON-RPC 2.0 — simpler):
```typescript
// Client → Server
interface GatewayRequest {
id: number; // Client-assigned request ID
method: string; // e.g. "agent.send"
params: object; // Method-specific parameters
}
// Server → Client (success)
interface GatewayResponse {
id: number; // Matches request ID
result: object; // Method-specific result
}
// Server → Client (error)
interface GatewayError {
id: number; // Matches request ID
error: {
code: number; // Error code (negative = protocol, positive = app)
message: string; // Human-readable description
};
}
// Server → Client (streaming event, multiple per request)
interface GatewayEvent {
id: number; // Matches originating request ID
event: string; // Event type name
data: object; // Event-specific payload
}
```
### Event Types (for agent.send streaming)
| Event | Data | When |
|-------|------|------|
| `content` | `{ text: string }` | Text chunk from model |
| `tool_start` | `{ tool: string, args: object }` | Tool execution beginning |
| `tool_end` | `{ tool: string, result: { success, output, error? } }` | Tool execution complete |
| `done` | `{ content: string, usage: { inputTokens, outputTokens } }` | Final response |
| `error` | `{ code: number, message: string }` | Error during processing |
### Error Codes
| Code | Meaning |
|------|---------|
| -1 | Parse error (invalid JSON) |
| -2 | Invalid request (missing id/method) |
| -3 | Method not found |
| -4 | Authentication required |
| -5 | Authentication failed |
| 1 | Session not found |
| 2 | Tool not found |
| 3 | Agent busy (already processing) |
| 4 | Request cancelled |
## Methods
### agent.send
Send a message to the agent and receive streaming response.
```
Params: { message: string, sessionId?: string }
Events: content, tool_start, tool_end, done, error
Response: (none — final state sent as "done" event)
```
### agent.cancel
Cancel the currently running agent request.
```
Params: { sessionId?: string }
Response: { cancelled: boolean }
```
### sessions.list
List all active sessions.
```
Params: {}
Response: { sessions: [{ id: string, messageCount: number }] }
```
### sessions.history
Get message history for a session.
```
Params: { sessionId: string, limit?: number, offset?: number }
Response: { messages: Message[], total: number }
```
### sessions.create
Create a new session.
```
Params: { sessionId?: string }
Response: { sessionId: string }
```
### tools.list
List all registered tools.
```
Params: {}
Response: { tools: [{ name, description, inputSchema }] }
```
### tools.invoke
Directly invoke a tool (bypasses agent).
```
Params: { tool: string, args: object }
Response: { success: boolean, output: string, error?: string }
```
### system.health
Health check.
```
Params: {}
Response: { status: "ok", uptime: number, version: string, sessions: number, tools: number }
```
## Session Bridge
Each WebSocket connection is associated with a session:
1. Client connects → assigned default session ID `ws:{connectionId}`
2. Client can specify `sessionId` in `agent.send` to use a named session
3. Sessions are created on demand via SessionManager
4. Multiple WS clients can share a session (e.g. multiple browser tabs)
5. Disconnection does NOT destroy the session (persistence via SQLite)
## Auth (Phase 2b)
Two auth modes (checked in order):
1. **Token auth**: `Authorization: Bearer <token>` header on WS upgrade, or `{ method: "auth", params: { token: "..." } }` as first message
2. **Tailscale identity**: `Tailscale-User-Login` header set by Tailscale Funnel/proxy
Config addition:
```yaml
server:
port: 18800
auth:
token: "optional-static-token" # If set, required for all WS connections
tailscale_identity: true # Trust Tailscale-User-Login header
```
No auth initially (Phase 2a) — gateway only listens on localhost.
## Implementation Order
1. `protocol.ts` — Types only, no runtime code
2. `router.ts` — Method dispatch (pure function, easy to test)
3. `session-bridge.ts` — Client-to-session mapping
4. `handlers/system.ts` — Simplest handler, proves the pattern
5. `handlers/sessions.ts` — Session listing/history
6. `handlers/tools.ts` — Tool listing/invocation
7. `handlers/agent.ts` — The main handler (streaming, tool events)
8. `server.ts` — WebSocket server, ties everything together
9. Wire into `daemon/index.ts`
10. Tests throughout
## Dependencies
Need `ws` package:
```bash
pnpm add ws
pnpm add -D @types/ws
```
## Test Strategy
- Unit tests for protocol message validation
- Unit tests for router dispatch
- Unit tests for each handler (mock session manager, agent, tool registry)
- Integration test: WS client → server → handler → response
- Test streaming events for agent.send
---
*Plan Version: 1.0*
*Created: 2026-02-05*
*Parent: docs/plans/2026-02-05-openclaw-parity-design.md Phase 2*
+2
View File
@@ -28,6 +28,7 @@
"@types/marked-terminal": "^6.1.1", "@types/marked-terminal": "^6.1.1",
"@types/node": "^22.0.0", "@types/node": "^22.0.0",
"@types/react": "^19.0.0", "@types/react": "^19.0.0",
"@types/ws": "^8.18.1",
"eslint": "^9.0.0", "eslint": "^9.0.0",
"tsx": "^4.0.0", "tsx": "^4.0.0",
"typescript": "^5.7.0", "typescript": "^5.7.0",
@@ -45,6 +46,7 @@
"ollama": "^0.5.0", "ollama": "^0.5.0",
"openai": "^4.0.0", "openai": "^4.0.0",
"react": "^19.0.0", "react": "^19.0.0",
"ws": "^8.19.0",
"yaml": "^2.7.0", "yaml": "^2.7.0",
"zod": "^3.24.0" "zod": "^3.24.0"
}, },
+13
View File
@@ -41,6 +41,9 @@ importers:
react: react:
specifier: ^19.0.0 specifier: ^19.0.0
version: 19.2.4 version: 19.2.4
ws:
specifier: ^8.19.0
version: 8.19.0
yaml: yaml:
specifier: ^2.7.0 specifier: ^2.7.0
version: 2.8.2 version: 2.8.2
@@ -60,6 +63,9 @@ importers:
'@types/react': '@types/react':
specifier: ^19.0.0 specifier: ^19.0.0
version: 19.2.11 version: 19.2.11
'@types/ws':
specifier: ^8.18.1
version: 8.18.1
eslint: eslint:
specifier: ^9.0.0 specifier: ^9.0.0
version: 9.39.2 version: 9.39.2
@@ -477,6 +483,9 @@ packages:
'@types/react@19.2.11': '@types/react@19.2.11':
resolution: {integrity: sha512-tORuanb01iEzWvMGVGv2ZDhYZVeRMrw453DCSAIn/5yvcSVnMoUMTyf33nQJLahYEnv9xqrTNbgz4qY5EfSh0g==} resolution: {integrity: sha512-tORuanb01iEzWvMGVGv2ZDhYZVeRMrw453DCSAIn/5yvcSVnMoUMTyf33nQJLahYEnv9xqrTNbgz4qY5EfSh0g==}
'@types/ws@8.18.1':
resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==}
'@vitest/expect@3.2.4': '@vitest/expect@3.2.4':
resolution: {integrity: sha512-Io0yyORnB6sikFlt8QW5K7slY4OjqNX9jmJQ02QDda8lyM6B5oNgVWoSoKPac8/kgnCUzuHQKrSLtu/uOqqrig==} resolution: {integrity: sha512-Io0yyORnB6sikFlt8QW5K7slY4OjqNX9jmJQ02QDda8lyM6B5oNgVWoSoKPac8/kgnCUzuHQKrSLtu/uOqqrig==}
@@ -1867,6 +1876,10 @@ snapshots:
dependencies: dependencies:
csstype: 3.2.3 csstype: 3.2.3
'@types/ws@8.18.1':
dependencies:
'@types/node': 22.19.7
'@vitest/expect@3.2.4': '@vitest/expect@3.2.4':
dependencies: dependencies:
'@types/chai': 5.2.3 '@types/chai': 5.2.3
+26 -2
View File
@@ -7,6 +7,7 @@ import { createTelegramBot } from '../frontends/telegram/index.js';
import { SessionStore, SessionManager } from '../session/index.js'; import { SessionStore, SessionManager } from '../session/index.js';
import { HookEngine } from '../hooks/index.js'; import { HookEngine } from '../hooks/index.js';
import { ToolRegistry, ToolExecutor, allBuiltinTools } from '../tools/index.js'; import { ToolRegistry, ToolExecutor, allBuiltinTools } from '../tools/index.js';
import { GatewayServer } from '../gateway/index.js';
import { resolve } from 'path'; import { resolve } from 'path';
import { homedir } from 'os'; import { homedir } from 'os';
import { mkdirSync, readFileSync, existsSync } from 'fs'; import { mkdirSync, readFileSync, existsSync } from 'fs';
@@ -22,6 +23,7 @@ export interface DaemonContext {
modelRouter: ModelRouter; modelRouter: ModelRouter;
toolRegistry: ToolRegistry; toolRegistry: ToolRegistry;
toolExecutor: ToolExecutor; toolExecutor: ToolExecutor;
gateway: GatewayServer;
} }
function loadSystemPrompt(): string { function loadSystemPrompt(): string {
@@ -133,6 +135,9 @@ export async function startDaemon(config: Config): Promise<DaemonContext> {
// Initialize model router // Initialize model router
const modelRouter = createModelRouter(config); const modelRouter = createModelRouter(config);
// Load system prompt once for reuse
const systemPrompt = loadSystemPrompt();
// Get Telegram session // Get Telegram session
const telegramUserId = String(config.telegram.allowed_chat_ids[0]); const telegramUserId = String(config.telegram.allowed_chat_ids[0]);
const session = sessionManager.getSession('telegram', telegramUserId); const session = sessionManager.getSession('telegram', telegramUserId);
@@ -140,7 +145,7 @@ export async function startDaemon(config: Config): Promise<DaemonContext> {
// Initialize native agent with session and tools // Initialize native agent with session and tools
const agent = new NativeAgent({ const agent = new NativeAgent({
modelClient: modelRouter, modelClient: modelRouter,
systemPrompt: loadSystemPrompt(), systemPrompt,
session, session,
toolRegistry, toolRegistry,
toolExecutor, toolExecutor,
@@ -153,6 +158,17 @@ export async function startDaemon(config: Config): Promise<DaemonContext> {
hookEngine, hookEngine,
}); });
// Initialize gateway WebSocket server
const gateway = new GatewayServer({
port: config.server.port,
host: config.server.localhost ? '127.0.0.1' : '0.0.0.0',
sessionManager,
modelClient: modelRouter,
systemPrompt,
toolRegistry,
toolExecutor,
});
// Register signal handlers // Register signal handlers
const signalHandler = () => { const signalHandler = () => {
lifecycle.shutdown().then(() => process.exit(0)); lifecycle.shutdown().then(() => process.exit(0));
@@ -179,9 +195,17 @@ export async function startDaemon(config: Config): Promise<DaemonContext> {
}, },
}); });
// Start gateway
lifecycle.onShutdown(async () => {
await gateway.stop();
console.log('Gateway server stopped');
});
await gateway.start();
console.log('Flynn daemon started'); console.log('Flynn daemon started');
return { config, lifecycle, bot, agent, sessionStore, sessionManager, hookEngine, modelRouter, toolRegistry, toolExecutor }; return { config, lifecycle, bot, agent, sessionStore, sessionManager, hookEngine, modelRouter, toolRegistry, toolExecutor, gateway };
} }
export { Lifecycle } from './lifecycle.js'; export { Lifecycle } from './lifecycle.js';
+85
View File
@@ -0,0 +1,85 @@
import { describe, it, expect } from 'vitest';
import { authenticateRequest } from './auth.js';
import type { IncomingMessage } from 'http';
function mockRequest(headers: Record<string, string> = {}): IncomingMessage {
return { headers } as unknown as IncomingMessage;
}
describe('authenticateRequest', () => {
describe('no auth configured', () => {
it('allows all connections', () => {
const result = authenticateRequest(mockRequest(), {});
expect(result.authenticated).toBe(true);
expect(result.identity).toBe('anonymous');
});
});
describe('token auth', () => {
const config = { token: 'secret-token-123' };
it('accepts valid Bearer token', () => {
const result = authenticateRequest(
mockRequest({ authorization: 'Bearer secret-token-123' }),
config,
);
expect(result.authenticated).toBe(true);
expect(result.identity).toBe('token-user');
});
it('rejects missing Authorization header', () => {
const result = authenticateRequest(mockRequest(), config);
expect(result.authenticated).toBe(false);
expect(result.error).toContain('Authorization header required');
});
it('rejects invalid token', () => {
const result = authenticateRequest(
mockRequest({ authorization: 'Bearer wrong-token' }),
config,
);
expect(result.authenticated).toBe(false);
expect(result.error).toContain('Invalid token');
});
it('rejects non-Bearer format', () => {
const result = authenticateRequest(
mockRequest({ authorization: 'Basic dXNlcjpwYXNz' }),
config,
);
expect(result.authenticated).toBe(false);
expect(result.error).toContain('Invalid Authorization format');
});
it('uses Tailscale identity when both token and tailscale are configured', () => {
const result = authenticateRequest(
mockRequest({
authorization: 'Bearer secret-token-123',
'tailscale-user-login': 'will@example.com',
}),
{ token: 'secret-token-123', tailscaleIdentity: true },
);
expect(result.authenticated).toBe(true);
expect(result.identity).toBe('will@example.com');
});
});
describe('tailscale identity', () => {
const config = { tailscaleIdentity: true };
it('extracts identity from Tailscale-User-Login header', () => {
const result = authenticateRequest(
mockRequest({ 'tailscale-user-login': 'will@example.com' }),
config,
);
expect(result.authenticated).toBe(true);
expect(result.identity).toBe('will@example.com');
});
it('allows connections without Tailscale header (local access)', () => {
const result = authenticateRequest(mockRequest(), config);
expect(result.authenticated).toBe(true);
expect(result.identity).toBe('anonymous');
});
});
});
+67
View File
@@ -0,0 +1,67 @@
import type { IncomingMessage } from 'http';
export interface AuthConfig {
/** Static bearer token. If set, all connections must provide it. */
token?: string;
/** Trust Tailscale-User-Login header for identity. */
tailscaleIdentity?: boolean;
}
export interface AuthResult {
authenticated: boolean;
identity?: string;
error?: string;
}
/**
* Authenticates a WebSocket upgrade request.
*
* Auth is checked in this order:
* 1. If token is configured, validate Authorization header (Bearer token)
* 2. If tailscaleIdentity is enabled, extract identity from Tailscale-User-Login header
* 3. If no auth is configured, allow all connections
*/
export function authenticateRequest(req: IncomingMessage, config: AuthConfig): AuthResult {
// If token auth is configured, it's required
if (config.token) {
const authHeader = req.headers['authorization'];
if (!authHeader) {
return { authenticated: false, error: 'Authorization header required' };
}
const parts = authHeader.split(' ');
if (parts.length !== 2 || parts[0] !== 'Bearer') {
return { authenticated: false, error: 'Invalid Authorization format (expected: Bearer <token>)' };
}
if (parts[1] !== config.token) {
return { authenticated: false, error: 'Invalid token' };
}
// Token is valid — check for Tailscale identity too
const identity = extractTailscaleIdentity(req, config);
return { authenticated: true, identity: identity ?? 'token-user' };
}
// If Tailscale identity is configured (no token), use it
if (config.tailscaleIdentity) {
const identity = extractTailscaleIdentity(req, config);
if (identity) {
return { authenticated: true, identity };
}
// Tailscale identity configured but header not present — still allow (might be local)
return { authenticated: true, identity: 'anonymous' };
}
// No auth configured — allow all
return { authenticated: true, identity: 'anonymous' };
}
function extractTailscaleIdentity(req: IncomingMessage, config: AuthConfig): string | undefined {
if (!config.tailscaleIdentity) return undefined;
const header = req.headers['tailscale-user-login'];
if (typeof header === 'string' && header.length > 0) {
return header;
}
return undefined;
}
+77
View File
@@ -0,0 +1,77 @@
import type { GatewayRequest, OutboundMessage } from '../protocol.js';
import type { SendFn } from '../router.js';
import { makeEvent, makeError, ErrorCode } from '../protocol.js';
import type { SessionBridge } from '../session-bridge.js';
export interface AgentHandlerDeps {
sessionBridge: SessionBridge;
}
export function createAgentHandlers(deps: AgentHandlerDeps) {
return {
'agent.send': async (request: GatewayRequest, send: SendFn): Promise<OutboundMessage | void> => {
const params = request.params as { message?: string; connectionId?: string } | undefined;
if (!params?.message) {
return makeError(request.id, ErrorCode.InvalidRequest, 'message is required');
}
const connectionId = params.connectionId as string;
if (!connectionId) {
return makeError(request.id, ErrorCode.InvalidRequest, 'connectionId is required (set by server)');
}
if (deps.sessionBridge.isBusy(connectionId)) {
return makeError(request.id, ErrorCode.AgentBusy, 'Agent is already processing a request');
}
const agent = deps.sessionBridge.getAgent(connectionId);
if (!agent) {
return makeError(request.id, ErrorCode.SessionNotFound, 'No agent for this connection');
}
deps.sessionBridge.setBusy(connectionId, true);
// Set up tool use callback to emit streaming events
deps.sessionBridge.setOnToolUse(connectionId, (event) => {
if (event.type === 'start') {
send(makeEvent(request.id, 'tool_start', { tool: event.tool, args: event.args }));
} else if (event.type === 'end') {
send(makeEvent(request.id, 'tool_end', {
tool: event.tool,
result: event.result ? {
success: event.result.success,
output: event.result.output,
error: event.result.error,
} : undefined,
}));
}
});
try {
const response = await agent.process(params.message);
send(makeEvent(request.id, 'done', { content: response }));
} catch (err) {
const message = err instanceof Error ? err.message : 'Unknown error';
send(makeEvent(request.id, 'error', { code: ErrorCode.InternalError, message }));
} finally {
deps.sessionBridge.setBusy(connectionId, false);
deps.sessionBridge.setOnToolUse(connectionId, undefined);
}
},
'agent.cancel': async (request: GatewayRequest): Promise<OutboundMessage> => {
// Cancel is a placeholder — proper cancellation requires abort controller support in NativeAgent.
// For now, just report whether the agent was busy.
const params = request.params as { connectionId?: string } | undefined;
const connectionId = params?.connectionId as string;
if (!connectionId) {
return makeError(request.id, ErrorCode.InvalidRequest, 'connectionId is required');
}
const wasBusy = deps.sessionBridge.isBusy(connectionId);
// TODO: Wire AbortController into NativeAgent for actual cancellation
return { id: request.id, result: { cancelled: wasBusy } };
},
};
}
+271
View File
@@ -0,0 +1,271 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { createSystemHandlers } from './system.js';
import { createSessionHandlers } from './sessions.js';
import { createToolHandlers } from './tools.js';
import { createAgentHandlers } from './agent.js';
import { ErrorCode } from '../protocol.js';
import type { GatewayRequest, GatewayResponse, GatewayError, GatewayEvent, OutboundMessage } from '../protocol.js';
describe('system handlers', () => {
const deps = {
startTime: Date.now() - 60_000,
version: '0.1.0',
getSessionCount: () => 3,
getToolCount: () => 6,
getConnectionCount: () => 2,
};
const handlers = createSystemHandlers(deps);
it('system.health returns status info', async () => {
const req: GatewayRequest = { id: 1, method: 'system.health' };
const result = await handlers['system.health'](req) as GatewayResponse;
expect(result.id).toBe(1);
const r = result.result as Record<string, unknown>;
expect(r.status).toBe('ok');
expect(r.version).toBe('0.1.0');
expect(r.sessions).toBe(3);
expect(r.tools).toBe(6);
expect(r.connections).toBe(2);
expect(typeof r.uptime).toBe('number');
expect(r.uptime).toBeGreaterThanOrEqual(59);
});
});
describe('session handlers', () => {
const mockHistory = [
{ role: 'user' as const, content: 'hello' },
{ role: 'assistant' as const, content: 'hi' },
];
const mockSession = {
id: 'ws:test',
addMessage: vi.fn(),
getHistory: vi.fn(() => mockHistory),
clear: vi.fn(),
};
const mockSessionManager = {
listSessions: vi.fn(() => ['ws:test']),
getSession: vi.fn(() => mockSession),
transferSession: vi.fn(),
closeSession: vi.fn(),
};
const handlers = createSessionHandlers({
sessionManager: mockSessionManager as any,
});
beforeEach(() => {
vi.clearAllMocks();
mockSessionManager.listSessions.mockReturnValue(['ws:test']);
mockSessionManager.getSession.mockReturnValue(mockSession);
mockSession.getHistory.mockReturnValue(mockHistory);
});
it('sessions.list returns session list with message counts', async () => {
const req: GatewayRequest = { id: 1, method: 'sessions.list' };
const result = await handlers['sessions.list'](req) as GatewayResponse;
expect(result.id).toBe(1);
const r = result.result as { sessions: Array<{ id: string; messageCount: number }> };
expect(r.sessions).toHaveLength(1);
expect(r.sessions[0].id).toBe('ws:test');
expect(r.sessions[0].messageCount).toBe(2);
});
it('sessions.history returns messages with pagination', async () => {
const req: GatewayRequest = { id: 2, method: 'sessions.history', params: { sessionId: 'ws:test', limit: 1, offset: 0 } };
const result = await handlers['sessions.history'](req) as GatewayResponse;
const r = result.result as { messages: unknown[]; total: number };
expect(r.messages).toHaveLength(1);
expect(r.total).toBe(2);
});
it('sessions.history requires sessionId', async () => {
const req: GatewayRequest = { id: 3, method: 'sessions.history', params: {} };
const result = await handlers['sessions.history'](req) as GatewayError;
expect(result.error.code).toBe(ErrorCode.InvalidRequest);
});
it('sessions.create creates a new session', async () => {
const req: GatewayRequest = { id: 4, method: 'sessions.create', params: { sessionId: 'ws:new' } };
const result = await handlers['sessions.create'](req) as GatewayResponse;
const r = result.result as { sessionId: string };
expect(r.sessionId).toBe('ws:new');
expect(mockSessionManager.getSession).toHaveBeenCalledWith('ws', 'new');
});
it('sessions.create auto-generates session ID', async () => {
const req: GatewayRequest = { id: 5, method: 'sessions.create' };
const result = await handlers['sessions.create'](req) as GatewayResponse;
const r = result.result as { sessionId: string };
expect(r.sessionId).toMatch(/^ws:\d+$/);
});
});
describe('tool handlers', () => {
const mockTool = {
name: 'test.tool',
description: 'A test tool',
inputSchema: { type: 'object' as const, properties: {} },
execute: vi.fn(),
};
const mockRegistry = {
list: vi.fn(() => [mockTool]),
get: vi.fn((name: string) => (name === 'test.tool' ? mockTool : undefined)),
register: vi.fn(),
toAnthropicFormat: vi.fn(),
toOpenAIFormat: vi.fn(),
};
const mockExecutor = {
execute: vi.fn(async () => ({ success: true, output: 'done' })),
};
const handlers = createToolHandlers({
toolRegistry: mockRegistry as any,
toolExecutor: mockExecutor as any,
});
beforeEach(() => {
vi.clearAllMocks();
mockRegistry.list.mockReturnValue([mockTool]);
mockRegistry.get.mockImplementation((name: string) => (name === 'test.tool' ? mockTool : undefined));
mockExecutor.execute.mockResolvedValue({ success: true, output: 'done' });
});
it('tools.list returns tool definitions', async () => {
const req: GatewayRequest = { id: 1, method: 'tools.list' };
const result = await handlers['tools.list'](req) as GatewayResponse;
const r = result.result as { tools: Array<{ name: string }> };
expect(r.tools).toHaveLength(1);
expect(r.tools[0].name).toBe('test.tool');
});
it('tools.invoke executes a tool', async () => {
const req: GatewayRequest = { id: 2, method: 'tools.invoke', params: { tool: 'test.tool', args: {} } };
const result = await handlers['tools.invoke'](req) as GatewayResponse;
expect(result.result).toEqual({ success: true, output: 'done' });
expect(mockExecutor.execute).toHaveBeenCalledWith('test.tool', {});
});
it('tools.invoke errors on missing tool name', async () => {
const req: GatewayRequest = { id: 3, method: 'tools.invoke', params: {} };
const result = await handlers['tools.invoke'](req) as GatewayError;
expect(result.error.code).toBe(ErrorCode.InvalidRequest);
});
it('tools.invoke errors on unknown tool', async () => {
const req: GatewayRequest = { id: 4, method: 'tools.invoke', params: { tool: 'unknown' } };
const result = await handlers['tools.invoke'](req) as GatewayError;
expect(result.error.code).toBe(ErrorCode.ToolNotFound);
});
});
describe('agent handlers', () => {
const mockAgent = {
process: vi.fn(async () => 'response text'),
setOnToolUse: vi.fn(),
};
const mockBridge = {
getAgent: vi.fn(() => mockAgent),
getSessionId: vi.fn(() => 'ws:conn-1'),
isBusy: vi.fn(() => false),
setBusy: vi.fn(),
setOnToolUse: vi.fn(),
};
const handlers = createAgentHandlers({
sessionBridge: mockBridge as any,
});
beforeEach(() => {
vi.clearAllMocks();
mockBridge.isBusy.mockReturnValue(false);
mockBridge.getAgent.mockReturnValue(mockAgent);
mockAgent.process.mockResolvedValue('response text');
});
it('agent.send processes message and sends done event', async () => {
const req: GatewayRequest = { id: 1, method: 'agent.send', params: { message: 'hello', connectionId: 'conn-1' } };
const sent: OutboundMessage[] = [];
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
await handlers['agent.send'](req, send);
expect(mockAgent.process).toHaveBeenCalledWith('hello');
expect(sent).toHaveLength(1);
const doneEvent = sent[0] as GatewayEvent;
expect(doneEvent.event).toBe('done');
expect((doneEvent.data as any).content).toBe('response text');
});
it('agent.send requires message', async () => {
const req: GatewayRequest = { id: 2, method: 'agent.send', params: { connectionId: 'conn-1' } };
const send = vi.fn();
const result = await handlers['agent.send'](req, send) as GatewayError;
expect(result.error.code).toBe(ErrorCode.InvalidRequest);
expect(result.error.message).toContain('message');
});
it('agent.send rejects when busy', async () => {
mockBridge.isBusy.mockReturnValue(true);
const req: GatewayRequest = { id: 3, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } };
const send = vi.fn();
const result = await handlers['agent.send'](req, send) as GatewayError;
expect(result.error.code).toBe(ErrorCode.AgentBusy);
});
it('agent.send handles errors gracefully', async () => {
mockAgent.process.mockRejectedValue(new Error('model failed'));
const req: GatewayRequest = { id: 4, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } };
const sent: OutboundMessage[] = [];
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
await handlers['agent.send'](req, send);
const errorEvent = sent[0] as GatewayEvent;
expect(errorEvent.event).toBe('error');
expect((errorEvent.data as any).message).toBe('model failed');
});
it('agent.send sets and cleans up tool use callback', async () => {
const req: GatewayRequest = { id: 5, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } };
const send = vi.fn();
await handlers['agent.send'](req, send);
// setOnToolUse called twice: once to set callback, once to clear it
expect(mockBridge.setOnToolUse).toHaveBeenCalledTimes(2);
expect(mockBridge.setOnToolUse).toHaveBeenLastCalledWith('conn-1', undefined);
});
it('agent.send sets busy state correctly', async () => {
const req: GatewayRequest = { id: 6, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } };
const send = vi.fn();
await handlers['agent.send'](req, send);
expect(mockBridge.setBusy).toHaveBeenCalledWith('conn-1', true);
expect(mockBridge.setBusy).toHaveBeenCalledWith('conn-1', false);
});
it('agent.cancel returns cancelled state', async () => {
mockBridge.isBusy.mockReturnValue(true);
const req: GatewayRequest = { id: 7, method: 'agent.cancel', params: { connectionId: 'conn-1' } };
const result = await handlers['agent.cancel'](req) as GatewayResponse;
expect((result.result as any).cancelled).toBe(true);
});
});
+8
View File
@@ -0,0 +1,8 @@
export { createSystemHandlers } from './system.js';
export type { SystemHandlerDeps } from './system.js';
export { createSessionHandlers } from './sessions.js';
export type { SessionHandlerDeps } from './sessions.js';
export { createToolHandlers } from './tools.js';
export type { ToolHandlerDeps } from './tools.js';
export { createAgentHandlers } from './agent.js';
export type { AgentHandlerDeps } from './agent.js';
+59
View File
@@ -0,0 +1,59 @@
import type { GatewayRequest, OutboundMessage } from '../protocol.js';
import { makeResponse, makeError, ErrorCode } from '../protocol.js';
import type { SessionManager } from '../../session/manager.js';
export interface SessionHandlerDeps {
sessionManager: SessionManager;
}
export function createSessionHandlers(deps: SessionHandlerDeps) {
return {
'sessions.list': async (request: GatewayRequest): Promise<OutboundMessage> => {
const sessionIds = deps.sessionManager.listSessions();
const sessions = sessionIds.map(id => ({
id,
messageCount: deps.sessionManager.getSession(
id.split(':')[0],
id.split(':').slice(1).join(':')
).getHistory().length,
}));
return makeResponse(request.id, { sessions });
},
'sessions.history': async (request: GatewayRequest): Promise<OutboundMessage> => {
const params = request.params as { sessionId?: string; limit?: number; offset?: number } | undefined;
if (!params?.sessionId) {
return makeError(request.id, ErrorCode.InvalidRequest, 'sessionId is required');
}
const { sessionId, limit, offset } = params;
const parts = sessionId.split(':');
const frontend = parts[0];
const userId = parts.slice(1).join(':');
const session = deps.sessionManager.getSession(frontend, userId);
const allMessages = session.getHistory();
const start = offset ?? 0;
const end = limit ? start + limit : allMessages.length;
const messages = allMessages.slice(start, end);
return makeResponse(request.id, {
messages,
total: allMessages.length,
});
},
'sessions.create': async (request: GatewayRequest): Promise<OutboundMessage> => {
const params = request.params as { sessionId?: string } | undefined;
const sessionId = params?.sessionId ?? `ws:${Date.now()}`;
const parts = sessionId.split(':');
const frontend = parts[0];
const userId = parts.slice(1).join(':');
// Creating a session via getSession is idempotent
deps.sessionManager.getSession(frontend, userId);
return makeResponse(request.id, { sessionId });
},
};
}
+25
View File
@@ -0,0 +1,25 @@
import type { GatewayRequest, OutboundMessage } from '../protocol.js';
import { makeResponse } from '../protocol.js';
export interface SystemHandlerDeps {
startTime: number;
version: string;
getSessionCount: () => number;
getToolCount: () => number;
getConnectionCount: () => number;
}
export function createSystemHandlers(deps: SystemHandlerDeps) {
return {
'system.health': async (request: GatewayRequest): Promise<OutboundMessage> => {
return makeResponse(request.id, {
status: 'ok',
uptime: Math.floor((Date.now() - deps.startTime) / 1000),
version: deps.version,
sessions: deps.getSessionCount(),
tools: deps.getToolCount(),
connections: deps.getConnectionCount(),
});
},
};
}
+37
View File
@@ -0,0 +1,37 @@
import type { GatewayRequest, OutboundMessage } from '../protocol.js';
import { makeResponse, makeError, ErrorCode } from '../protocol.js';
import type { ToolRegistry } from '../../tools/registry.js';
import type { ToolExecutor } from '../../tools/executor.js';
export interface ToolHandlerDeps {
toolRegistry: ToolRegistry;
toolExecutor: ToolExecutor;
}
export function createToolHandlers(deps: ToolHandlerDeps) {
return {
'tools.list': async (request: GatewayRequest): Promise<OutboundMessage> => {
const tools = deps.toolRegistry.list().map(t => ({
name: t.name,
description: t.description,
inputSchema: t.inputSchema,
}));
return makeResponse(request.id, { tools });
},
'tools.invoke': async (request: GatewayRequest): Promise<OutboundMessage> => {
const params = request.params as { tool?: string; args?: Record<string, unknown> } | undefined;
if (!params?.tool) {
return makeError(request.id, ErrorCode.InvalidRequest, 'tool name is required');
}
const tool = deps.toolRegistry.get(params.tool);
if (!tool) {
return makeError(request.id, ErrorCode.ToolNotFound, `Tool not found: ${params.tool}`);
}
const result = await deps.toolExecutor.execute(params.tool, params.args ?? {});
return makeResponse(request.id, result);
},
};
}
+29
View File
@@ -0,0 +1,29 @@
export { GatewayServer } from './server.js';
export type { GatewayServerConfig } from './server.js';
export { Router } from './router.js';
export type { HandlerFn, SendFn } from './router.js';
export { SessionBridge } from './session-bridge.js';
export type { SessionBridgeConfig } from './session-bridge.js';
export { authenticateRequest } from './auth.js';
export type { AuthConfig, AuthResult } from './auth.js';
export {
ErrorCode,
isValidRequest,
parseMessage,
makeResponse,
makeError,
makeEvent,
} from './protocol.js';
export type {
GatewayRequest,
GatewayResponse,
GatewayError,
GatewayEvent,
OutboundMessage,
EventType,
ContentEventData,
ToolStartEventData,
ToolEndEventData,
DoneEventData,
ErrorEventData,
} from './protocol.js';
+90
View File
@@ -0,0 +1,90 @@
import { describe, it, expect } from 'vitest';
import {
isValidRequest,
parseMessage,
makeResponse,
makeError,
makeEvent,
ErrorCode,
} from './protocol.js';
describe('protocol', () => {
describe('isValidRequest', () => {
it('accepts valid request with params', () => {
expect(isValidRequest({ id: 1, method: 'agent.send', params: { message: 'hello' } })).toBe(true);
});
it('accepts valid request without params', () => {
expect(isValidRequest({ id: 1, method: 'system.health' })).toBe(true);
});
it('rejects missing id', () => {
expect(isValidRequest({ method: 'test' })).toBe(false);
});
it('rejects missing method', () => {
expect(isValidRequest({ id: 1 })).toBe(false);
});
it('rejects non-object', () => {
expect(isValidRequest('not an object')).toBe(false);
expect(isValidRequest(null)).toBe(false);
expect(isValidRequest(42)).toBe(false);
});
it('rejects non-numeric id', () => {
expect(isValidRequest({ id: 'abc', method: 'test' })).toBe(false);
});
it('rejects non-string method', () => {
expect(isValidRequest({ id: 1, method: 42 })).toBe(false);
});
it('rejects non-object params', () => {
expect(isValidRequest({ id: 1, method: 'test', params: 'bad' })).toBe(false);
});
});
describe('parseMessage', () => {
it('parses valid JSON into GatewayRequest', () => {
const msg = parseMessage('{"id":1,"method":"agent.send","params":{"message":"hi"}}');
expect(msg).toEqual({ id: 1, method: 'agent.send', params: { message: 'hi' } });
});
it('returns null for invalid JSON', () => {
expect(parseMessage('not json')).toBeNull();
});
it('returns null for valid JSON that is not a valid request', () => {
expect(parseMessage('{"method":"test"}')).toBeNull();
});
});
describe('makeResponse', () => {
it('creates a response message', () => {
expect(makeResponse(1, { status: 'ok' })).toEqual({
id: 1,
result: { status: 'ok' },
});
});
});
describe('makeError', () => {
it('creates an error message', () => {
expect(makeError(1, ErrorCode.MethodNotFound, 'Not found')).toEqual({
id: 1,
error: { code: -3, message: 'Not found' },
});
});
});
describe('makeEvent', () => {
it('creates an event message', () => {
expect(makeEvent(1, 'content', { text: 'hello' })).toEqual({
id: 1,
event: 'content',
data: { text: 'hello' },
});
});
});
});
+119
View File
@@ -0,0 +1,119 @@
// Gateway protocol types — JSON-RPC-like messages over WebSocket.
// ── Client → Server ────────────────────────────────────────────
export interface GatewayRequest {
id: number;
method: string;
params?: Record<string, unknown>;
}
// ── Server → Client ────────────────────────────────────────────
export interface GatewayResponse {
id: number;
result: unknown;
}
export interface GatewayError {
id: number;
error: {
code: ErrorCode;
message: string;
};
}
export interface GatewayEvent {
id: number;
event: EventType;
data: unknown;
}
// ── Event types emitted during agent.send ──────────────────────
export type EventType =
| 'content'
| 'tool_start'
| 'tool_end'
| 'done'
| 'error';
export interface ContentEventData {
text: string;
}
export interface ToolStartEventData {
tool: string;
args: unknown;
}
export interface ToolEndEventData {
tool: string;
result: {
success: boolean;
output: string;
error?: string;
};
}
export interface DoneEventData {
content: string;
}
export interface ErrorEventData {
code: ErrorCode;
message: string;
}
// ── Error codes ────────────────────────────────────────────────
export enum ErrorCode {
ParseError = -1,
InvalidRequest = -2,
MethodNotFound = -3,
AuthRequired = -4,
AuthFailed = -5,
SessionNotFound = 1,
ToolNotFound = 2,
AgentBusy = 3,
RequestCancelled = 4,
InternalError = 5,
}
// ── Outbound message (union of all server → client types) ──────
export type OutboundMessage = GatewayResponse | GatewayError | GatewayEvent;
// ── Validation helpers ─────────────────────────────────────────
export function isValidRequest(msg: unknown): msg is GatewayRequest {
if (typeof msg !== 'object' || msg === null) return false;
const obj = msg as Record<string, unknown>;
return (
typeof obj.id === 'number' &&
typeof obj.method === 'string' &&
(obj.params === undefined || (typeof obj.params === 'object' && obj.params !== null))
);
}
export function parseMessage(raw: string): GatewayRequest | null {
try {
const parsed = JSON.parse(raw);
if (isValidRequest(parsed)) return parsed;
return null;
} catch {
return null;
}
}
export function makeResponse(id: number, result: unknown): GatewayResponse {
return { id, result };
}
export function makeError(id: number, code: ErrorCode, message: string): GatewayError {
return { id, error: { code, message } };
}
export function makeEvent(id: number, event: EventType, data: unknown): GatewayEvent {
return { id, event, data };
}
+58
View File
@@ -0,0 +1,58 @@
import { describe, it, expect, vi } from 'vitest';
import { Router } from './router.js';
import type { GatewayRequest, OutboundMessage } from './protocol.js';
import { makeResponse, ErrorCode } from './protocol.js';
describe('Router', () => {
it('dispatches to registered handler', async () => {
const router = new Router();
const handler = vi.fn(async (req: GatewayRequest) => makeResponse(req.id, { ok: true }));
router.register('test.method', handler);
const request: GatewayRequest = { id: 1, method: 'test.method', params: {} };
const send = vi.fn();
const result = await router.dispatch(request, send);
expect(handler).toHaveBeenCalledWith(request, send);
expect(result).toEqual({ id: 1, result: { ok: true } });
});
it('returns MethodNotFound for unregistered method', async () => {
const router = new Router();
const request: GatewayRequest = { id: 1, method: 'unknown.method' };
const send = vi.fn();
const result = await router.dispatch(request, send);
expect(result).toEqual({
id: 1,
error: { code: ErrorCode.MethodNotFound, message: 'Unknown method: unknown.method' },
});
});
it('lists registered methods', () => {
const router = new Router();
router.register('a.method', async () => {});
router.register('b.method', async () => {});
expect(router.listMethods()).toEqual(['a.method', 'b.method']);
});
it('handler can send streaming events via send function', async () => {
const router = new Router();
router.register('stream.test', async (req, send) => {
send({ id: req.id, event: 'content', data: { text: 'chunk1' } });
send({ id: req.id, event: 'content', data: { text: 'chunk2' } });
send({ id: req.id, event: 'done', data: { content: 'chunk1chunk2' } });
});
const request: GatewayRequest = { id: 1, method: 'stream.test' };
const sent: OutboundMessage[] = [];
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
await router.dispatch(request, send);
expect(sent).toHaveLength(3);
expect(sent[0]).toEqual({ id: 1, event: 'content', data: { text: 'chunk1' } });
expect(sent[2]).toEqual({ id: 1, event: 'done', data: { content: 'chunk1chunk2' } });
});
});
+27
View File
@@ -0,0 +1,27 @@
import type { GatewayRequest, OutboundMessage } from './protocol.js';
import { makeError, ErrorCode } from './protocol.js';
// A handler function receives a request and a send function for streaming events.
// It returns a final response/error, or void if it already sent a done event.
export type SendFn = (msg: OutboundMessage) => void;
export type HandlerFn = (request: GatewayRequest, send: SendFn) => Promise<OutboundMessage | void>;
export class Router {
private handlers: Map<string, HandlerFn> = new Map();
register(method: string, handler: HandlerFn): void {
this.handlers.set(method, handler);
}
async dispatch(request: GatewayRequest, send: SendFn): Promise<OutboundMessage | void> {
const handler = this.handlers.get(request.method);
if (!handler) {
return makeError(request.id, ErrorCode.MethodNotFound, `Unknown method: ${request.method}`);
}
return handler(request, send);
}
listMethods(): string[] {
return Array.from(this.handlers.keys());
}
}
+188
View File
@@ -0,0 +1,188 @@
import { describe, it, expect, beforeAll, afterAll, vi } from 'vitest';
import { WebSocket } from 'ws';
import { GatewayServer } from './server.js';
import type { GatewayServerConfig } from './server.js';
import type { GatewayResponse, GatewayError, GatewayEvent } from './protocol.js';
import { ErrorCode } from './protocol.js';
// Minimal mocks for dependencies
const mockSession = {
id: 'test',
addMessage: vi.fn(),
getHistory: vi.fn(() => []),
clear: vi.fn(),
setHistory: vi.fn(),
};
const mockSessionManager = {
getSession: vi.fn(() => mockSession),
listSessions: vi.fn(() => ['ws:test']),
transferSession: vi.fn(),
closeSession: vi.fn(),
};
const mockModelClient = {
chat: vi.fn(async () => ({
content: 'Hello from Flynn!',
stopReason: 'end_turn',
usage: { inputTokens: 10, outputTokens: 5 },
})),
};
const mockToolRegistry = {
register: vi.fn(),
get: vi.fn((name: string) => (name === 'shell.exec' ? { name: 'shell.exec', description: 'Run shell', inputSchema: { type: 'object', properties: {} } } : undefined)),
list: vi.fn(() => [{ name: 'shell.exec', description: 'Run shell', inputSchema: { type: 'object', properties: {} } }]),
toAnthropicFormat: vi.fn(() => []),
toOpenAIFormat: vi.fn(() => []),
};
const mockToolExecutor = {
execute: vi.fn(async () => ({ success: true, output: 'executed' })),
};
const TEST_PORT = 18899;
let server: GatewayServer;
function createClient(): Promise<WebSocket> {
return new Promise((resolve, reject) => {
const ws = new WebSocket(`ws://127.0.0.1:${TEST_PORT}`);
ws.on('open', () => resolve(ws));
ws.on('error', reject);
});
}
function sendAndReceive(ws: WebSocket, msg: object): Promise<GatewayResponse | GatewayError | GatewayEvent> {
return new Promise((resolve) => {
ws.once('message', (data) => {
resolve(JSON.parse(data.toString()));
});
ws.send(JSON.stringify(msg));
});
}
function sendAndReceiveAll(ws: WebSocket, msg: object, count: number): Promise<Array<GatewayResponse | GatewayError | GatewayEvent>> {
return new Promise((resolve) => {
const messages: Array<GatewayResponse | GatewayError | GatewayEvent> = [];
const handler = (data: Buffer) => {
messages.push(JSON.parse(data.toString()));
if (messages.length >= count) {
ws.off('message', handler);
resolve(messages);
}
};
ws.on('message', handler);
ws.send(JSON.stringify(msg));
});
}
describe('GatewayServer integration', () => {
beforeAll(async () => {
server = new GatewayServer({
port: TEST_PORT,
sessionManager: mockSessionManager as unknown as GatewayServerConfig['sessionManager'],
modelClient: mockModelClient,
systemPrompt: 'Test prompt',
toolRegistry: mockToolRegistry as unknown as GatewayServerConfig['toolRegistry'],
toolExecutor: mockToolExecutor as unknown as GatewayServerConfig['toolExecutor'],
version: '0.1.0-test',
});
await server.start();
});
afterAll(async () => {
await server.stop();
});
it('responds to system.health', async () => {
const ws = await createClient();
try {
const result = await sendAndReceive(ws, { id: 1, method: 'system.health' });
const response = result as GatewayResponse;
expect(response.id).toBe(1);
const r = response.result as Record<string, unknown>;
expect(r.status).toBe('ok');
expect(r.version).toBe('0.1.0-test');
expect(typeof r.uptime).toBe('number');
} finally {
ws.close();
}
});
it('returns MethodNotFound for unknown method', async () => {
const ws = await createClient();
try {
const result = await sendAndReceive(ws, { id: 2, method: 'unknown.method' });
const error = result as GatewayError;
expect(error.error.code).toBe(ErrorCode.MethodNotFound);
} finally {
ws.close();
}
});
it('returns ParseError for invalid JSON', async () => {
const ws = await createClient();
try {
const result = await new Promise<GatewayError>((resolve) => {
ws.once('message', (data) => resolve(JSON.parse(data.toString())));
ws.send('not valid json');
});
expect(result.error.code).toBe(ErrorCode.ParseError);
} finally {
ws.close();
}
});
it('lists tools via tools.list', async () => {
const ws = await createClient();
try {
const result = await sendAndReceive(ws, { id: 3, method: 'tools.list' });
const response = result as GatewayResponse;
const r = response.result as { tools: Array<{ name: string }> };
expect(r.tools.length).toBeGreaterThan(0);
expect(r.tools[0].name).toBe('shell.exec');
} finally {
ws.close();
}
});
it('sends agent message and receives done event', async () => {
const ws = await createClient();
try {
// agent.send streams events — we expect a 'done' event
const messages = await sendAndReceiveAll(ws, { id: 4, method: 'agent.send', params: { message: 'hi' } }, 1);
const doneEvent = messages[0] as GatewayEvent;
expect(doneEvent.id).toBe(4);
expect(doneEvent.event).toBe('done');
expect((doneEvent.data as any).content).toBe('Hello from Flynn!');
} finally {
ws.close();
}
});
it('tracks connections correctly', async () => {
const ws1 = await createClient();
const ws2 = await createClient();
try {
const result = await sendAndReceive(ws1, { id: 5, method: 'system.health' });
const r = (result as GatewayResponse).result as Record<string, unknown>;
expect(r.connections).toBe(2);
} finally {
ws1.close();
ws2.close();
}
});
it('lists registered methods', () => {
const methods = server.getMethods();
expect(methods).toContain('system.health');
expect(methods).toContain('agent.send');
expect(methods).toContain('agent.cancel');
expect(methods).toContain('sessions.list');
expect(methods).toContain('sessions.history');
expect(methods).toContain('sessions.create');
expect(methods).toContain('tools.list');
expect(methods).toContain('tools.invoke');
});
});
+204
View File
@@ -0,0 +1,204 @@
import { WebSocketServer, WebSocket } from 'ws';
import { randomUUID } from 'crypto';
import type { IncomingMessage } from 'http';
import { Router } from './router.js';
import { SessionBridge } from './session-bridge.js';
import type { SessionBridgeConfig } from './session-bridge.js';
import { authenticateRequest } from './auth.js';
import type { AuthConfig } from './auth.js';
import {
parseMessage,
makeError,
ErrorCode,
type OutboundMessage,
} from './protocol.js';
import {
createSystemHandlers,
createSessionHandlers,
createToolHandlers,
createAgentHandlers,
} from './handlers/index.js';
import type { SessionManager } from '../session/manager.js';
import type { ToolRegistry } from '../tools/registry.js';
import type { ToolExecutor } from '../tools/executor.js';
export interface GatewayServerConfig {
port: number;
host?: string;
sessionManager: SessionManager;
modelClient: SessionBridgeConfig['modelClient'];
systemPrompt: string;
toolRegistry: ToolRegistry;
toolExecutor: ToolExecutor;
version?: string;
auth?: AuthConfig;
}
export class GatewayServer {
private wss: WebSocketServer | null = null;
private router: Router;
private sessionBridge: SessionBridge;
private connectionMap: Map<WebSocket, string> = new Map();
private config: GatewayServerConfig;
private startTime: number = Date.now();
constructor(config: GatewayServerConfig) {
this.config = config;
this.sessionBridge = new SessionBridge({
sessionManager: config.sessionManager,
modelClient: config.modelClient,
systemPrompt: config.systemPrompt,
toolRegistry: config.toolRegistry,
toolExecutor: config.toolExecutor,
});
this.router = new Router();
this.registerHandlers();
}
private registerHandlers(): void {
const systemHandlers = createSystemHandlers({
startTime: this.startTime,
version: this.config.version ?? '0.1.0',
getSessionCount: () => this.sessionBridge.listSessions().length,
getToolCount: () => this.config.toolRegistry.list().length,
getConnectionCount: () => this.sessionBridge.connectionCount,
});
const sessionHandlers = createSessionHandlers({
sessionManager: this.config.sessionManager,
});
const toolHandlers = createToolHandlers({
toolRegistry: this.config.toolRegistry,
toolExecutor: this.config.toolExecutor,
});
const agentHandlers = createAgentHandlers({
sessionBridge: this.sessionBridge,
});
// Register all methods
for (const [method, handler] of Object.entries(systemHandlers)) {
this.router.register(method, handler);
}
for (const [method, handler] of Object.entries(sessionHandlers)) {
this.router.register(method, handler);
}
for (const [method, handler] of Object.entries(toolHandlers)) {
this.router.register(method, handler);
}
for (const [method, handler] of Object.entries(agentHandlers)) {
this.router.register(method, handler);
}
}
async start(): Promise<void> {
return new Promise((resolve) => {
this.wss = new WebSocketServer({
port: this.config.port,
host: this.config.host ?? '127.0.0.1',
});
this.wss.on('connection', (ws: WebSocket, req: IncomingMessage) => {
// Auth check on upgrade
const authResult = authenticateRequest(req, this.config.auth ?? {});
if (!authResult.authenticated) {
ws.close(4001, authResult.error ?? 'Authentication failed');
return;
}
this.handleConnection(ws, authResult.identity);
});
this.wss.on('listening', () => {
const addr = this.wss?.address();
const portStr = typeof addr === 'object' && addr ? `:${addr.port}` : '';
console.log(`Gateway WebSocket server listening on ${this.config.host ?? '127.0.0.1'}${portStr}`);
resolve();
});
});
}
async stop(): Promise<void> {
return new Promise((resolve) => {
if (!this.wss) {
resolve();
return;
}
// Close all connections
for (const [ws, connectionId] of this.connectionMap) {
this.sessionBridge.disconnect(connectionId);
ws.close(1001, 'Server shutting down');
}
this.connectionMap.clear();
this.wss.close(() => {
this.wss = null;
resolve();
});
});
}
private handleConnection(ws: WebSocket, identity?: string): void {
const connectionId = randomUUID();
this.sessionBridge.connect(connectionId);
this.connectionMap.set(ws, connectionId);
ws.on('message', async (data) => {
const raw = data.toString();
await this.handleMessage(ws, connectionId, raw);
});
ws.on('close', () => {
this.sessionBridge.disconnect(connectionId);
this.connectionMap.delete(ws);
});
ws.on('error', (err) => {
console.error(`WebSocket error (${connectionId}):`, err.message);
});
}
private async handleMessage(ws: WebSocket, connectionId: string, raw: string): Promise<void> {
const request = parseMessage(raw);
if (!request) {
this.send(ws, makeError(0, ErrorCode.ParseError, 'Invalid JSON or missing required fields'));
return;
}
// Inject connectionId into params so handlers can identify the client
if (!request.params) request.params = {};
request.params.connectionId = connectionId;
const send = (msg: OutboundMessage) => this.send(ws, msg);
const response = await this.router.dispatch(request, send);
if (response) {
this.send(ws, response);
}
}
private send(ws: WebSocket, msg: OutboundMessage): void {
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify(msg));
}
}
/** Get the underlying WebSocketServer (for testing). */
getWss(): WebSocketServer | null {
return this.wss;
}
/** Get the session bridge (for testing/debugging). */
getSessionBridge(): SessionBridge {
return this.sessionBridge;
}
/** Get list of registered methods. */
getMethods(): string[] {
return this.router.listMethods();
}
}
+144
View File
@@ -0,0 +1,144 @@
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { SessionBridge } from './session-bridge.js';
import type { SessionBridgeConfig } from './session-bridge.js';
// Minimal mocks
const mockSession = {
id: 'test',
addMessage: vi.fn(),
getHistory: vi.fn(() => []),
clear: vi.fn(),
};
const mockSessionManager = {
getSession: vi.fn(() => mockSession),
listSessions: vi.fn(() => []),
transferSession: vi.fn(),
closeSession: vi.fn(),
};
const mockModelClient = {
chat: vi.fn(async () => ({
content: 'test',
stopReason: 'end_turn',
usage: { inputTokens: 0, outputTokens: 0 },
})),
};
const mockToolRegistry = {
register: vi.fn(),
get: vi.fn(),
list: vi.fn(() => []),
toAnthropicFormat: vi.fn(() => []),
toOpenAIFormat: vi.fn(() => []),
};
const mockToolExecutor = {
execute: vi.fn(),
};
function createBridge(): SessionBridge {
return new SessionBridge({
sessionManager: mockSessionManager as unknown as SessionBridgeConfig['sessionManager'],
modelClient: mockModelClient,
systemPrompt: 'test prompt',
toolRegistry: mockToolRegistry as unknown as SessionBridgeConfig['toolRegistry'],
toolExecutor: mockToolExecutor as unknown as SessionBridgeConfig['toolExecutor'],
});
}
describe('SessionBridge', () => {
beforeEach(() => {
vi.clearAllMocks();
});
it('connect assigns a connection ID', () => {
const bridge = createBridge();
const id = bridge.connect('conn-1');
expect(id).toBe('conn-1');
expect(bridge.connectionCount).toBe(1);
});
it('connect auto-generates ID when not provided', () => {
const bridge = createBridge();
const id = bridge.connect();
expect(typeof id).toBe('string');
expect(id.length).toBeGreaterThan(0);
});
it('getAgent returns agent for connected client', () => {
const bridge = createBridge();
bridge.connect('conn-1');
const agent = bridge.getAgent('conn-1');
expect(agent).toBeDefined();
});
it('getSessionId returns session ID for connected client', () => {
const bridge = createBridge();
bridge.connect('conn-1');
expect(bridge.getSessionId('conn-1')).toBe('ws:conn-1');
});
it('disconnect removes client but preserves session', () => {
const bridge = createBridge();
bridge.connect('conn-1');
bridge.disconnect('conn-1');
expect(bridge.connectionCount).toBe(0);
expect(bridge.getAgent('conn-1')).toBeUndefined();
});
it('tracks busy state', () => {
const bridge = createBridge();
bridge.connect('conn-1');
expect(bridge.isBusy('conn-1')).toBe(false);
bridge.setBusy('conn-1', true);
expect(bridge.isBusy('conn-1')).toBe(true);
bridge.setBusy('conn-1', false);
expect(bridge.isBusy('conn-1')).toBe(false);
});
it('switchSession changes session for a connection', () => {
const bridge = createBridge();
bridge.connect('conn-1');
expect(bridge.getSessionId('conn-1')).toBe('ws:conn-1');
bridge.switchSession('conn-1', 'custom-session');
expect(bridge.getSessionId('conn-1')).toBe('custom-session');
});
it('switchSession throws when busy', () => {
const bridge = createBridge();
bridge.connect('conn-1');
bridge.setBusy('conn-1', true);
expect(() => bridge.switchSession('conn-1', 'other')).toThrow('Cannot switch session while agent is busy');
});
it('switchSession throws for unknown connection', () => {
const bridge = createBridge();
expect(() => bridge.switchSession('unknown', 'other')).toThrow('Unknown connection');
});
it('listSessions groups connections by session', () => {
const bridge = createBridge();
bridge.connect('conn-1');
bridge.connect('conn-2');
// Switch conn-2 to share conn-1's session
bridge.switchSession('conn-2', 'ws:conn-1');
const sessions = bridge.listSessions();
expect(sessions).toEqual([{ sessionId: 'ws:conn-1', connections: 2 }]);
});
it('shared sessions keep agent alive when one client disconnects', () => {
const bridge = createBridge();
bridge.connect('conn-1');
bridge.connect('conn-2');
bridge.switchSession('conn-2', 'ws:conn-1');
bridge.disconnect('conn-1');
// conn-2 still has the session
expect(bridge.getAgent('conn-2')).toBeDefined();
expect(bridge.connectionCount).toBe(1);
});
});
+135
View File
@@ -0,0 +1,135 @@
import { randomUUID } from 'crypto';
import type { SessionManager } from '../session/manager.js';
import type { Session } from '../session/manager.js';
import type { ModelClient } from '../models/types.js';
import type { ModelRouter } from '../models/router.js';
import type { ToolRegistry } from '../tools/registry.js';
import type { ToolExecutor } from '../tools/executor.js';
import { NativeAgent } from '../backends/native/agent.js';
import type { ToolUseEvent } from '../backends/native/agent.js';
export interface SessionBridgeConfig {
sessionManager: SessionManager;
modelClient: ModelClient | ModelRouter;
systemPrompt: string;
toolRegistry: ToolRegistry;
toolExecutor: ToolExecutor;
}
interface ClientEntry {
connectionId: string;
sessionId: string;
agent: NativeAgent;
busy: boolean;
}
export class SessionBridge {
private clients: Map<string, ClientEntry> = new Map();
private agents: Map<string, NativeAgent> = new Map();
private config: SessionBridgeConfig;
constructor(config: SessionBridgeConfig) {
this.config = config;
}
/** Register a new WS connection. Returns the assigned connection ID. */
connect(connectionId?: string): string {
const id = connectionId ?? randomUUID();
const sessionId = `ws:${id}`;
const agent = this.getOrCreateAgent(sessionId);
this.clients.set(id, {
connectionId: id,
sessionId,
agent,
busy: false,
});
return id;
}
/** Remove a WS connection. Does NOT destroy the session (persists in SQLite). */
disconnect(connectionId: string): void {
const client = this.clients.get(connectionId);
if (client) {
// Only remove the agent if no other clients share the session
const otherClients = Array.from(this.clients.values())
.filter(c => c.sessionId === client.sessionId && c.connectionId !== connectionId);
if (otherClients.length === 0) {
this.agents.delete(client.sessionId);
}
this.clients.delete(connectionId);
}
}
/** Switch a connection to a different session (e.g. resuming an old session). */
switchSession(connectionId: string, sessionId: string): void {
const client = this.clients.get(connectionId);
if (!client) throw new Error(`Unknown connection: ${connectionId}`);
if (client.busy) throw new Error('Cannot switch session while agent is busy');
const agent = this.getOrCreateAgent(sessionId);
client.sessionId = sessionId;
client.agent = agent;
}
/** Get the NativeAgent for a connection. */
getAgent(connectionId: string): NativeAgent | undefined {
return this.clients.get(connectionId)?.agent;
}
/** Get the session ID for a connection. */
getSessionId(connectionId: string): string | undefined {
return this.clients.get(connectionId)?.sessionId;
}
/** Check if a connection's agent is busy. */
isBusy(connectionId: string): boolean {
return this.clients.get(connectionId)?.busy ?? false;
}
/** Mark a connection's agent as busy/idle. */
setBusy(connectionId: string, busy: boolean): void {
const client = this.clients.get(connectionId);
if (client) client.busy = busy;
}
/** Set onToolUse callback for a connection's agent. */
setOnToolUse(connectionId: string, callback: ((event: ToolUseEvent) => void) | undefined): void {
const client = this.clients.get(connectionId);
if (client) client.agent.setOnToolUse(callback);
}
/** List all active sessions with connection counts. */
listSessions(): Array<{ sessionId: string; connections: number }> {
const sessionMap = new Map<string, number>();
for (const client of this.clients.values()) {
sessionMap.set(client.sessionId, (sessionMap.get(client.sessionId) ?? 0) + 1);
}
return Array.from(sessionMap.entries()).map(([sessionId, connections]) => ({
sessionId,
connections,
}));
}
/** Get count of active connections. */
get connectionCount(): number {
return this.clients.size;
}
private getOrCreateAgent(sessionId: string): NativeAgent {
let agent = this.agents.get(sessionId);
if (!agent) {
const session = this.config.sessionManager.getSession('ws', sessionId);
agent = new NativeAgent({
modelClient: this.config.modelClient,
systemPrompt: this.config.systemPrompt,
session,
toolRegistry: this.config.toolRegistry,
toolExecutor: this.config.toolExecutor,
});
this.agents.set(sessionId, agent);
}
return agent;
}
}