feat(gateway): add WebSocket gateway with JSON-RPC protocol and auth
Phase 2 of the Flynn roadmap. Adds a WebSocket gateway server that starts alongside the Telegram bot, providing real-time API access to the agent, sessions, and tools. Protocol: JSON-RPC-like (request/response/event) over WebSocket. 8 methods: agent.send, agent.cancel, sessions.list, sessions.history, sessions.create, tools.list, tools.invoke, system.health. Auth: Bearer token + Tailscale identity header support. Session bridge: per-connection agent instances with shared model router. New files: src/gateway/ (protocol, router, server, auth, session-bridge, handlers for agent/sessions/tools/system). 57 new tests (181 total), typecheck clean.
This commit is contained in:
@@ -0,0 +1,214 @@
|
||||
# Phase 2: WebSocket Gateway — Implementation Plan
|
||||
|
||||
## Goal
|
||||
|
||||
Add a WebSocket gateway to the Flynn daemon that allows real-time communication from web clients (and eventually any WS client). The gateway wraps existing daemon components (session manager, agent, tool registry) — no refactoring of existing code.
|
||||
|
||||
## Approach
|
||||
|
||||
Incremental, additive. Existing Telegram bot and TUI continue working unchanged. The gateway is a new module that starts alongside them.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
src/gateway/
|
||||
├── protocol.ts # JSON-RPC message types + event types
|
||||
├── server.ts # WebSocket server (ws library), connection lifecycle
|
||||
├── router.ts # Method routing → handler dispatch
|
||||
├── auth.ts # Token auth + Tailscale identity headers
|
||||
├── session-bridge.ts # Maps WS client connections to agent sessions
|
||||
└── handlers/
|
||||
├── agent.ts # agent.send (streaming), agent.cancel, agent.status
|
||||
├── sessions.ts # sessions.list, sessions.history, sessions.create
|
||||
├── tools.ts # tools.list, tools.invoke
|
||||
└── system.ts # system.health, system.info
|
||||
```
|
||||
|
||||
## Protocol
|
||||
|
||||
JSON-RPC-like over WebSocket (not full JSON-RPC 2.0 — simpler):
|
||||
|
||||
```typescript
|
||||
// Client → Server
|
||||
interface GatewayRequest {
|
||||
id: number; // Client-assigned request ID
|
||||
method: string; // e.g. "agent.send"
|
||||
params: object; // Method-specific parameters
|
||||
}
|
||||
|
||||
// Server → Client (success)
|
||||
interface GatewayResponse {
|
||||
id: number; // Matches request ID
|
||||
result: object; // Method-specific result
|
||||
}
|
||||
|
||||
// Server → Client (error)
|
||||
interface GatewayError {
|
||||
id: number; // Matches request ID
|
||||
error: {
|
||||
code: number; // Error code (negative = protocol, positive = app)
|
||||
message: string; // Human-readable description
|
||||
};
|
||||
}
|
||||
|
||||
// Server → Client (streaming event, multiple per request)
|
||||
interface GatewayEvent {
|
||||
id: number; // Matches originating request ID
|
||||
event: string; // Event type name
|
||||
data: object; // Event-specific payload
|
||||
}
|
||||
```
|
||||
|
||||
### Event Types (for agent.send streaming)
|
||||
|
||||
| Event | Data | When |
|
||||
|-------|------|------|
|
||||
| `content` | `{ text: string }` | Text chunk from model |
|
||||
| `tool_start` | `{ tool: string, args: object }` | Tool execution beginning |
|
||||
| `tool_end` | `{ tool: string, result: { success, output, error? } }` | Tool execution complete |
|
||||
| `done` | `{ content: string, usage: { inputTokens, outputTokens } }` | Final response |
|
||||
| `error` | `{ code: number, message: string }` | Error during processing |
|
||||
|
||||
### Error Codes
|
||||
|
||||
| Code | Meaning |
|
||||
|------|---------|
|
||||
| -1 | Parse error (invalid JSON) |
|
||||
| -2 | Invalid request (missing id/method) |
|
||||
| -3 | Method not found |
|
||||
| -4 | Authentication required |
|
||||
| -5 | Authentication failed |
|
||||
| 1 | Session not found |
|
||||
| 2 | Tool not found |
|
||||
| 3 | Agent busy (already processing) |
|
||||
| 4 | Request cancelled |
|
||||
|
||||
## Methods
|
||||
|
||||
### agent.send
|
||||
Send a message to the agent and receive streaming response.
|
||||
|
||||
```
|
||||
Params: { message: string, sessionId?: string }
|
||||
Events: content, tool_start, tool_end, done, error
|
||||
Response: (none — final state sent as "done" event)
|
||||
```
|
||||
|
||||
### agent.cancel
|
||||
Cancel the currently running agent request.
|
||||
|
||||
```
|
||||
Params: { sessionId?: string }
|
||||
Response: { cancelled: boolean }
|
||||
```
|
||||
|
||||
### sessions.list
|
||||
List all active sessions.
|
||||
|
||||
```
|
||||
Params: {}
|
||||
Response: { sessions: [{ id: string, messageCount: number }] }
|
||||
```
|
||||
|
||||
### sessions.history
|
||||
Get message history for a session.
|
||||
|
||||
```
|
||||
Params: { sessionId: string, limit?: number, offset?: number }
|
||||
Response: { messages: Message[], total: number }
|
||||
```
|
||||
|
||||
### sessions.create
|
||||
Create a new session.
|
||||
|
||||
```
|
||||
Params: { sessionId?: string }
|
||||
Response: { sessionId: string }
|
||||
```
|
||||
|
||||
### tools.list
|
||||
List all registered tools.
|
||||
|
||||
```
|
||||
Params: {}
|
||||
Response: { tools: [{ name, description, inputSchema }] }
|
||||
```
|
||||
|
||||
### tools.invoke
|
||||
Directly invoke a tool (bypasses agent).
|
||||
|
||||
```
|
||||
Params: { tool: string, args: object }
|
||||
Response: { success: boolean, output: string, error?: string }
|
||||
```
|
||||
|
||||
### system.health
|
||||
Health check.
|
||||
|
||||
```
|
||||
Params: {}
|
||||
Response: { status: "ok", uptime: number, version: string, sessions: number, tools: number }
|
||||
```
|
||||
|
||||
## Session Bridge
|
||||
|
||||
Each WebSocket connection is associated with a session:
|
||||
|
||||
1. Client connects → assigned default session ID `ws:{connectionId}`
|
||||
2. Client can specify `sessionId` in `agent.send` to use a named session
|
||||
3. Sessions are created on demand via SessionManager
|
||||
4. Multiple WS clients can share a session (e.g. multiple browser tabs)
|
||||
5. Disconnection does NOT destroy the session (persistence via SQLite)
|
||||
|
||||
## Auth (Phase 2b)
|
||||
|
||||
Two auth modes (checked in order):
|
||||
|
||||
1. **Token auth**: `Authorization: Bearer <token>` header on WS upgrade, or `{ method: "auth", params: { token: "..." } }` as first message
|
||||
2. **Tailscale identity**: `Tailscale-User-Login` header set by Tailscale Funnel/proxy
|
||||
|
||||
Config addition:
|
||||
```yaml
|
||||
server:
|
||||
port: 18800
|
||||
auth:
|
||||
token: "optional-static-token" # If set, required for all WS connections
|
||||
tailscale_identity: true # Trust Tailscale-User-Login header
|
||||
```
|
||||
|
||||
No auth initially (Phase 2a) — gateway only listens on localhost.
|
||||
|
||||
## Implementation Order
|
||||
|
||||
1. `protocol.ts` — Types only, no runtime code
|
||||
2. `router.ts` — Method dispatch (pure function, easy to test)
|
||||
3. `session-bridge.ts` — Client-to-session mapping
|
||||
4. `handlers/system.ts` — Simplest handler, proves the pattern
|
||||
5. `handlers/sessions.ts` — Session listing/history
|
||||
6. `handlers/tools.ts` — Tool listing/invocation
|
||||
7. `handlers/agent.ts` — The main handler (streaming, tool events)
|
||||
8. `server.ts` — WebSocket server, ties everything together
|
||||
9. Wire into `daemon/index.ts`
|
||||
10. Tests throughout
|
||||
|
||||
## Dependencies
|
||||
|
||||
Need `ws` package:
|
||||
```bash
|
||||
pnpm add ws
|
||||
pnpm add -D @types/ws
|
||||
```
|
||||
|
||||
## Test Strategy
|
||||
|
||||
- Unit tests for protocol message validation
|
||||
- Unit tests for router dispatch
|
||||
- Unit tests for each handler (mock session manager, agent, tool registry)
|
||||
- Integration test: WS client → server → handler → response
|
||||
- Test streaming events for agent.send
|
||||
|
||||
---
|
||||
|
||||
*Plan Version: 1.0*
|
||||
*Created: 2026-02-05*
|
||||
*Parent: docs/plans/2026-02-05-openclaw-parity-design.md Phase 2*
|
||||
@@ -28,6 +28,7 @@
|
||||
"@types/marked-terminal": "^6.1.1",
|
||||
"@types/node": "^22.0.0",
|
||||
"@types/react": "^19.0.0",
|
||||
"@types/ws": "^8.18.1",
|
||||
"eslint": "^9.0.0",
|
||||
"tsx": "^4.0.0",
|
||||
"typescript": "^5.7.0",
|
||||
@@ -45,6 +46,7 @@
|
||||
"ollama": "^0.5.0",
|
||||
"openai": "^4.0.0",
|
||||
"react": "^19.0.0",
|
||||
"ws": "^8.19.0",
|
||||
"yaml": "^2.7.0",
|
||||
"zod": "^3.24.0"
|
||||
},
|
||||
|
||||
Generated
+13
@@ -41,6 +41,9 @@ importers:
|
||||
react:
|
||||
specifier: ^19.0.0
|
||||
version: 19.2.4
|
||||
ws:
|
||||
specifier: ^8.19.0
|
||||
version: 8.19.0
|
||||
yaml:
|
||||
specifier: ^2.7.0
|
||||
version: 2.8.2
|
||||
@@ -60,6 +63,9 @@ importers:
|
||||
'@types/react':
|
||||
specifier: ^19.0.0
|
||||
version: 19.2.11
|
||||
'@types/ws':
|
||||
specifier: ^8.18.1
|
||||
version: 8.18.1
|
||||
eslint:
|
||||
specifier: ^9.0.0
|
||||
version: 9.39.2
|
||||
@@ -477,6 +483,9 @@ packages:
|
||||
'@types/react@19.2.11':
|
||||
resolution: {integrity: sha512-tORuanb01iEzWvMGVGv2ZDhYZVeRMrw453DCSAIn/5yvcSVnMoUMTyf33nQJLahYEnv9xqrTNbgz4qY5EfSh0g==}
|
||||
|
||||
'@types/ws@8.18.1':
|
||||
resolution: {integrity: sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==}
|
||||
|
||||
'@vitest/expect@3.2.4':
|
||||
resolution: {integrity: sha512-Io0yyORnB6sikFlt8QW5K7slY4OjqNX9jmJQ02QDda8lyM6B5oNgVWoSoKPac8/kgnCUzuHQKrSLtu/uOqqrig==}
|
||||
|
||||
@@ -1867,6 +1876,10 @@ snapshots:
|
||||
dependencies:
|
||||
csstype: 3.2.3
|
||||
|
||||
'@types/ws@8.18.1':
|
||||
dependencies:
|
||||
'@types/node': 22.19.7
|
||||
|
||||
'@vitest/expect@3.2.4':
|
||||
dependencies:
|
||||
'@types/chai': 5.2.3
|
||||
|
||||
+26
-2
@@ -7,6 +7,7 @@ import { createTelegramBot } from '../frontends/telegram/index.js';
|
||||
import { SessionStore, SessionManager } from '../session/index.js';
|
||||
import { HookEngine } from '../hooks/index.js';
|
||||
import { ToolRegistry, ToolExecutor, allBuiltinTools } from '../tools/index.js';
|
||||
import { GatewayServer } from '../gateway/index.js';
|
||||
import { resolve } from 'path';
|
||||
import { homedir } from 'os';
|
||||
import { mkdirSync, readFileSync, existsSync } from 'fs';
|
||||
@@ -22,6 +23,7 @@ export interface DaemonContext {
|
||||
modelRouter: ModelRouter;
|
||||
toolRegistry: ToolRegistry;
|
||||
toolExecutor: ToolExecutor;
|
||||
gateway: GatewayServer;
|
||||
}
|
||||
|
||||
function loadSystemPrompt(): string {
|
||||
@@ -133,6 +135,9 @@ export async function startDaemon(config: Config): Promise<DaemonContext> {
|
||||
// Initialize model router
|
||||
const modelRouter = createModelRouter(config);
|
||||
|
||||
// Load system prompt once for reuse
|
||||
const systemPrompt = loadSystemPrompt();
|
||||
|
||||
// Get Telegram session
|
||||
const telegramUserId = String(config.telegram.allowed_chat_ids[0]);
|
||||
const session = sessionManager.getSession('telegram', telegramUserId);
|
||||
@@ -140,7 +145,7 @@ export async function startDaemon(config: Config): Promise<DaemonContext> {
|
||||
// Initialize native agent with session and tools
|
||||
const agent = new NativeAgent({
|
||||
modelClient: modelRouter,
|
||||
systemPrompt: loadSystemPrompt(),
|
||||
systemPrompt,
|
||||
session,
|
||||
toolRegistry,
|
||||
toolExecutor,
|
||||
@@ -153,6 +158,17 @@ export async function startDaemon(config: Config): Promise<DaemonContext> {
|
||||
hookEngine,
|
||||
});
|
||||
|
||||
// Initialize gateway WebSocket server
|
||||
const gateway = new GatewayServer({
|
||||
port: config.server.port,
|
||||
host: config.server.localhost ? '127.0.0.1' : '0.0.0.0',
|
||||
sessionManager,
|
||||
modelClient: modelRouter,
|
||||
systemPrompt,
|
||||
toolRegistry,
|
||||
toolExecutor,
|
||||
});
|
||||
|
||||
// Register signal handlers
|
||||
const signalHandler = () => {
|
||||
lifecycle.shutdown().then(() => process.exit(0));
|
||||
@@ -179,9 +195,17 @@ export async function startDaemon(config: Config): Promise<DaemonContext> {
|
||||
},
|
||||
});
|
||||
|
||||
// Start gateway
|
||||
lifecycle.onShutdown(async () => {
|
||||
await gateway.stop();
|
||||
console.log('Gateway server stopped');
|
||||
});
|
||||
|
||||
await gateway.start();
|
||||
|
||||
console.log('Flynn daemon started');
|
||||
|
||||
return { config, lifecycle, bot, agent, sessionStore, sessionManager, hookEngine, modelRouter, toolRegistry, toolExecutor };
|
||||
return { config, lifecycle, bot, agent, sessionStore, sessionManager, hookEngine, modelRouter, toolRegistry, toolExecutor, gateway };
|
||||
}
|
||||
|
||||
export { Lifecycle } from './lifecycle.js';
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { authenticateRequest } from './auth.js';
|
||||
import type { IncomingMessage } from 'http';
|
||||
|
||||
function mockRequest(headers: Record<string, string> = {}): IncomingMessage {
|
||||
return { headers } as unknown as IncomingMessage;
|
||||
}
|
||||
|
||||
describe('authenticateRequest', () => {
|
||||
describe('no auth configured', () => {
|
||||
it('allows all connections', () => {
|
||||
const result = authenticateRequest(mockRequest(), {});
|
||||
expect(result.authenticated).toBe(true);
|
||||
expect(result.identity).toBe('anonymous');
|
||||
});
|
||||
});
|
||||
|
||||
describe('token auth', () => {
|
||||
const config = { token: 'secret-token-123' };
|
||||
|
||||
it('accepts valid Bearer token', () => {
|
||||
const result = authenticateRequest(
|
||||
mockRequest({ authorization: 'Bearer secret-token-123' }),
|
||||
config,
|
||||
);
|
||||
expect(result.authenticated).toBe(true);
|
||||
expect(result.identity).toBe('token-user');
|
||||
});
|
||||
|
||||
it('rejects missing Authorization header', () => {
|
||||
const result = authenticateRequest(mockRequest(), config);
|
||||
expect(result.authenticated).toBe(false);
|
||||
expect(result.error).toContain('Authorization header required');
|
||||
});
|
||||
|
||||
it('rejects invalid token', () => {
|
||||
const result = authenticateRequest(
|
||||
mockRequest({ authorization: 'Bearer wrong-token' }),
|
||||
config,
|
||||
);
|
||||
expect(result.authenticated).toBe(false);
|
||||
expect(result.error).toContain('Invalid token');
|
||||
});
|
||||
|
||||
it('rejects non-Bearer format', () => {
|
||||
const result = authenticateRequest(
|
||||
mockRequest({ authorization: 'Basic dXNlcjpwYXNz' }),
|
||||
config,
|
||||
);
|
||||
expect(result.authenticated).toBe(false);
|
||||
expect(result.error).toContain('Invalid Authorization format');
|
||||
});
|
||||
|
||||
it('uses Tailscale identity when both token and tailscale are configured', () => {
|
||||
const result = authenticateRequest(
|
||||
mockRequest({
|
||||
authorization: 'Bearer secret-token-123',
|
||||
'tailscale-user-login': 'will@example.com',
|
||||
}),
|
||||
{ token: 'secret-token-123', tailscaleIdentity: true },
|
||||
);
|
||||
expect(result.authenticated).toBe(true);
|
||||
expect(result.identity).toBe('will@example.com');
|
||||
});
|
||||
});
|
||||
|
||||
describe('tailscale identity', () => {
|
||||
const config = { tailscaleIdentity: true };
|
||||
|
||||
it('extracts identity from Tailscale-User-Login header', () => {
|
||||
const result = authenticateRequest(
|
||||
mockRequest({ 'tailscale-user-login': 'will@example.com' }),
|
||||
config,
|
||||
);
|
||||
expect(result.authenticated).toBe(true);
|
||||
expect(result.identity).toBe('will@example.com');
|
||||
});
|
||||
|
||||
it('allows connections without Tailscale header (local access)', () => {
|
||||
const result = authenticateRequest(mockRequest(), config);
|
||||
expect(result.authenticated).toBe(true);
|
||||
expect(result.identity).toBe('anonymous');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,67 @@
|
||||
import type { IncomingMessage } from 'http';
|
||||
|
||||
export interface AuthConfig {
|
||||
/** Static bearer token. If set, all connections must provide it. */
|
||||
token?: string;
|
||||
/** Trust Tailscale-User-Login header for identity. */
|
||||
tailscaleIdentity?: boolean;
|
||||
}
|
||||
|
||||
export interface AuthResult {
|
||||
authenticated: boolean;
|
||||
identity?: string;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Authenticates a WebSocket upgrade request.
|
||||
*
|
||||
* Auth is checked in this order:
|
||||
* 1. If token is configured, validate Authorization header (Bearer token)
|
||||
* 2. If tailscaleIdentity is enabled, extract identity from Tailscale-User-Login header
|
||||
* 3. If no auth is configured, allow all connections
|
||||
*/
|
||||
export function authenticateRequest(req: IncomingMessage, config: AuthConfig): AuthResult {
|
||||
// If token auth is configured, it's required
|
||||
if (config.token) {
|
||||
const authHeader = req.headers['authorization'];
|
||||
if (!authHeader) {
|
||||
return { authenticated: false, error: 'Authorization header required' };
|
||||
}
|
||||
|
||||
const parts = authHeader.split(' ');
|
||||
if (parts.length !== 2 || parts[0] !== 'Bearer') {
|
||||
return { authenticated: false, error: 'Invalid Authorization format (expected: Bearer <token>)' };
|
||||
}
|
||||
|
||||
if (parts[1] !== config.token) {
|
||||
return { authenticated: false, error: 'Invalid token' };
|
||||
}
|
||||
|
||||
// Token is valid — check for Tailscale identity too
|
||||
const identity = extractTailscaleIdentity(req, config);
|
||||
return { authenticated: true, identity: identity ?? 'token-user' };
|
||||
}
|
||||
|
||||
// If Tailscale identity is configured (no token), use it
|
||||
if (config.tailscaleIdentity) {
|
||||
const identity = extractTailscaleIdentity(req, config);
|
||||
if (identity) {
|
||||
return { authenticated: true, identity };
|
||||
}
|
||||
// Tailscale identity configured but header not present — still allow (might be local)
|
||||
return { authenticated: true, identity: 'anonymous' };
|
||||
}
|
||||
|
||||
// No auth configured — allow all
|
||||
return { authenticated: true, identity: 'anonymous' };
|
||||
}
|
||||
|
||||
function extractTailscaleIdentity(req: IncomingMessage, config: AuthConfig): string | undefined {
|
||||
if (!config.tailscaleIdentity) return undefined;
|
||||
const header = req.headers['tailscale-user-login'];
|
||||
if (typeof header === 'string' && header.length > 0) {
|
||||
return header;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
import type { GatewayRequest, OutboundMessage } from '../protocol.js';
|
||||
import type { SendFn } from '../router.js';
|
||||
import { makeEvent, makeError, ErrorCode } from '../protocol.js';
|
||||
import type { SessionBridge } from '../session-bridge.js';
|
||||
|
||||
export interface AgentHandlerDeps {
|
||||
sessionBridge: SessionBridge;
|
||||
}
|
||||
|
||||
export function createAgentHandlers(deps: AgentHandlerDeps) {
|
||||
return {
|
||||
'agent.send': async (request: GatewayRequest, send: SendFn): Promise<OutboundMessage | void> => {
|
||||
const params = request.params as { message?: string; connectionId?: string } | undefined;
|
||||
if (!params?.message) {
|
||||
return makeError(request.id, ErrorCode.InvalidRequest, 'message is required');
|
||||
}
|
||||
|
||||
const connectionId = params.connectionId as string;
|
||||
if (!connectionId) {
|
||||
return makeError(request.id, ErrorCode.InvalidRequest, 'connectionId is required (set by server)');
|
||||
}
|
||||
|
||||
if (deps.sessionBridge.isBusy(connectionId)) {
|
||||
return makeError(request.id, ErrorCode.AgentBusy, 'Agent is already processing a request');
|
||||
}
|
||||
|
||||
const agent = deps.sessionBridge.getAgent(connectionId);
|
||||
if (!agent) {
|
||||
return makeError(request.id, ErrorCode.SessionNotFound, 'No agent for this connection');
|
||||
}
|
||||
|
||||
deps.sessionBridge.setBusy(connectionId, true);
|
||||
|
||||
// Set up tool use callback to emit streaming events
|
||||
deps.sessionBridge.setOnToolUse(connectionId, (event) => {
|
||||
if (event.type === 'start') {
|
||||
send(makeEvent(request.id, 'tool_start', { tool: event.tool, args: event.args }));
|
||||
} else if (event.type === 'end') {
|
||||
send(makeEvent(request.id, 'tool_end', {
|
||||
tool: event.tool,
|
||||
result: event.result ? {
|
||||
success: event.result.success,
|
||||
output: event.result.output,
|
||||
error: event.result.error,
|
||||
} : undefined,
|
||||
}));
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
const response = await agent.process(params.message);
|
||||
send(makeEvent(request.id, 'done', { content: response }));
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : 'Unknown error';
|
||||
send(makeEvent(request.id, 'error', { code: ErrorCode.InternalError, message }));
|
||||
} finally {
|
||||
deps.sessionBridge.setBusy(connectionId, false);
|
||||
deps.sessionBridge.setOnToolUse(connectionId, undefined);
|
||||
}
|
||||
},
|
||||
|
||||
'agent.cancel': async (request: GatewayRequest): Promise<OutboundMessage> => {
|
||||
// Cancel is a placeholder — proper cancellation requires abort controller support in NativeAgent.
|
||||
// For now, just report whether the agent was busy.
|
||||
const params = request.params as { connectionId?: string } | undefined;
|
||||
const connectionId = params?.connectionId as string;
|
||||
|
||||
if (!connectionId) {
|
||||
return makeError(request.id, ErrorCode.InvalidRequest, 'connectionId is required');
|
||||
}
|
||||
|
||||
const wasBusy = deps.sessionBridge.isBusy(connectionId);
|
||||
// TODO: Wire AbortController into NativeAgent for actual cancellation
|
||||
return { id: request.id, result: { cancelled: wasBusy } };
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { createSystemHandlers } from './system.js';
|
||||
import { createSessionHandlers } from './sessions.js';
|
||||
import { createToolHandlers } from './tools.js';
|
||||
import { createAgentHandlers } from './agent.js';
|
||||
import { ErrorCode } from '../protocol.js';
|
||||
import type { GatewayRequest, GatewayResponse, GatewayError, GatewayEvent, OutboundMessage } from '../protocol.js';
|
||||
|
||||
describe('system handlers', () => {
|
||||
const deps = {
|
||||
startTime: Date.now() - 60_000,
|
||||
version: '0.1.0',
|
||||
getSessionCount: () => 3,
|
||||
getToolCount: () => 6,
|
||||
getConnectionCount: () => 2,
|
||||
};
|
||||
const handlers = createSystemHandlers(deps);
|
||||
|
||||
it('system.health returns status info', async () => {
|
||||
const req: GatewayRequest = { id: 1, method: 'system.health' };
|
||||
const result = await handlers['system.health'](req) as GatewayResponse;
|
||||
|
||||
expect(result.id).toBe(1);
|
||||
const r = result.result as Record<string, unknown>;
|
||||
expect(r.status).toBe('ok');
|
||||
expect(r.version).toBe('0.1.0');
|
||||
expect(r.sessions).toBe(3);
|
||||
expect(r.tools).toBe(6);
|
||||
expect(r.connections).toBe(2);
|
||||
expect(typeof r.uptime).toBe('number');
|
||||
expect(r.uptime).toBeGreaterThanOrEqual(59);
|
||||
});
|
||||
});
|
||||
|
||||
describe('session handlers', () => {
|
||||
const mockHistory = [
|
||||
{ role: 'user' as const, content: 'hello' },
|
||||
{ role: 'assistant' as const, content: 'hi' },
|
||||
];
|
||||
|
||||
const mockSession = {
|
||||
id: 'ws:test',
|
||||
addMessage: vi.fn(),
|
||||
getHistory: vi.fn(() => mockHistory),
|
||||
clear: vi.fn(),
|
||||
};
|
||||
|
||||
const mockSessionManager = {
|
||||
listSessions: vi.fn(() => ['ws:test']),
|
||||
getSession: vi.fn(() => mockSession),
|
||||
transferSession: vi.fn(),
|
||||
closeSession: vi.fn(),
|
||||
};
|
||||
|
||||
const handlers = createSessionHandlers({
|
||||
sessionManager: mockSessionManager as any,
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockSessionManager.listSessions.mockReturnValue(['ws:test']);
|
||||
mockSessionManager.getSession.mockReturnValue(mockSession);
|
||||
mockSession.getHistory.mockReturnValue(mockHistory);
|
||||
});
|
||||
|
||||
it('sessions.list returns session list with message counts', async () => {
|
||||
const req: GatewayRequest = { id: 1, method: 'sessions.list' };
|
||||
const result = await handlers['sessions.list'](req) as GatewayResponse;
|
||||
|
||||
expect(result.id).toBe(1);
|
||||
const r = result.result as { sessions: Array<{ id: string; messageCount: number }> };
|
||||
expect(r.sessions).toHaveLength(1);
|
||||
expect(r.sessions[0].id).toBe('ws:test');
|
||||
expect(r.sessions[0].messageCount).toBe(2);
|
||||
});
|
||||
|
||||
it('sessions.history returns messages with pagination', async () => {
|
||||
const req: GatewayRequest = { id: 2, method: 'sessions.history', params: { sessionId: 'ws:test', limit: 1, offset: 0 } };
|
||||
const result = await handlers['sessions.history'](req) as GatewayResponse;
|
||||
|
||||
const r = result.result as { messages: unknown[]; total: number };
|
||||
expect(r.messages).toHaveLength(1);
|
||||
expect(r.total).toBe(2);
|
||||
});
|
||||
|
||||
it('sessions.history requires sessionId', async () => {
|
||||
const req: GatewayRequest = { id: 3, method: 'sessions.history', params: {} };
|
||||
const result = await handlers['sessions.history'](req) as GatewayError;
|
||||
|
||||
expect(result.error.code).toBe(ErrorCode.InvalidRequest);
|
||||
});
|
||||
|
||||
it('sessions.create creates a new session', async () => {
|
||||
const req: GatewayRequest = { id: 4, method: 'sessions.create', params: { sessionId: 'ws:new' } };
|
||||
const result = await handlers['sessions.create'](req) as GatewayResponse;
|
||||
|
||||
const r = result.result as { sessionId: string };
|
||||
expect(r.sessionId).toBe('ws:new');
|
||||
expect(mockSessionManager.getSession).toHaveBeenCalledWith('ws', 'new');
|
||||
});
|
||||
|
||||
it('sessions.create auto-generates session ID', async () => {
|
||||
const req: GatewayRequest = { id: 5, method: 'sessions.create' };
|
||||
const result = await handlers['sessions.create'](req) as GatewayResponse;
|
||||
|
||||
const r = result.result as { sessionId: string };
|
||||
expect(r.sessionId).toMatch(/^ws:\d+$/);
|
||||
});
|
||||
});
|
||||
|
||||
describe('tool handlers', () => {
|
||||
const mockTool = {
|
||||
name: 'test.tool',
|
||||
description: 'A test tool',
|
||||
inputSchema: { type: 'object' as const, properties: {} },
|
||||
execute: vi.fn(),
|
||||
};
|
||||
|
||||
const mockRegistry = {
|
||||
list: vi.fn(() => [mockTool]),
|
||||
get: vi.fn((name: string) => (name === 'test.tool' ? mockTool : undefined)),
|
||||
register: vi.fn(),
|
||||
toAnthropicFormat: vi.fn(),
|
||||
toOpenAIFormat: vi.fn(),
|
||||
};
|
||||
|
||||
const mockExecutor = {
|
||||
execute: vi.fn(async () => ({ success: true, output: 'done' })),
|
||||
};
|
||||
|
||||
const handlers = createToolHandlers({
|
||||
toolRegistry: mockRegistry as any,
|
||||
toolExecutor: mockExecutor as any,
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockRegistry.list.mockReturnValue([mockTool]);
|
||||
mockRegistry.get.mockImplementation((name: string) => (name === 'test.tool' ? mockTool : undefined));
|
||||
mockExecutor.execute.mockResolvedValue({ success: true, output: 'done' });
|
||||
});
|
||||
|
||||
it('tools.list returns tool definitions', async () => {
|
||||
const req: GatewayRequest = { id: 1, method: 'tools.list' };
|
||||
const result = await handlers['tools.list'](req) as GatewayResponse;
|
||||
|
||||
const r = result.result as { tools: Array<{ name: string }> };
|
||||
expect(r.tools).toHaveLength(1);
|
||||
expect(r.tools[0].name).toBe('test.tool');
|
||||
});
|
||||
|
||||
it('tools.invoke executes a tool', async () => {
|
||||
const req: GatewayRequest = { id: 2, method: 'tools.invoke', params: { tool: 'test.tool', args: {} } };
|
||||
const result = await handlers['tools.invoke'](req) as GatewayResponse;
|
||||
|
||||
expect(result.result).toEqual({ success: true, output: 'done' });
|
||||
expect(mockExecutor.execute).toHaveBeenCalledWith('test.tool', {});
|
||||
});
|
||||
|
||||
it('tools.invoke errors on missing tool name', async () => {
|
||||
const req: GatewayRequest = { id: 3, method: 'tools.invoke', params: {} };
|
||||
const result = await handlers['tools.invoke'](req) as GatewayError;
|
||||
expect(result.error.code).toBe(ErrorCode.InvalidRequest);
|
||||
});
|
||||
|
||||
it('tools.invoke errors on unknown tool', async () => {
|
||||
const req: GatewayRequest = { id: 4, method: 'tools.invoke', params: { tool: 'unknown' } };
|
||||
const result = await handlers['tools.invoke'](req) as GatewayError;
|
||||
expect(result.error.code).toBe(ErrorCode.ToolNotFound);
|
||||
});
|
||||
});
|
||||
|
||||
describe('agent handlers', () => {
|
||||
const mockAgent = {
|
||||
process: vi.fn(async () => 'response text'),
|
||||
setOnToolUse: vi.fn(),
|
||||
};
|
||||
|
||||
const mockBridge = {
|
||||
getAgent: vi.fn(() => mockAgent),
|
||||
getSessionId: vi.fn(() => 'ws:conn-1'),
|
||||
isBusy: vi.fn(() => false),
|
||||
setBusy: vi.fn(),
|
||||
setOnToolUse: vi.fn(),
|
||||
};
|
||||
|
||||
const handlers = createAgentHandlers({
|
||||
sessionBridge: mockBridge as any,
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockBridge.isBusy.mockReturnValue(false);
|
||||
mockBridge.getAgent.mockReturnValue(mockAgent);
|
||||
mockAgent.process.mockResolvedValue('response text');
|
||||
});
|
||||
|
||||
it('agent.send processes message and sends done event', async () => {
|
||||
const req: GatewayRequest = { id: 1, method: 'agent.send', params: { message: 'hello', connectionId: 'conn-1' } };
|
||||
const sent: OutboundMessage[] = [];
|
||||
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||
|
||||
await handlers['agent.send'](req, send);
|
||||
|
||||
expect(mockAgent.process).toHaveBeenCalledWith('hello');
|
||||
expect(sent).toHaveLength(1);
|
||||
const doneEvent = sent[0] as GatewayEvent;
|
||||
expect(doneEvent.event).toBe('done');
|
||||
expect((doneEvent.data as any).content).toBe('response text');
|
||||
});
|
||||
|
||||
it('agent.send requires message', async () => {
|
||||
const req: GatewayRequest = { id: 2, method: 'agent.send', params: { connectionId: 'conn-1' } };
|
||||
const send = vi.fn();
|
||||
const result = await handlers['agent.send'](req, send) as GatewayError;
|
||||
|
||||
expect(result.error.code).toBe(ErrorCode.InvalidRequest);
|
||||
expect(result.error.message).toContain('message');
|
||||
});
|
||||
|
||||
it('agent.send rejects when busy', async () => {
|
||||
mockBridge.isBusy.mockReturnValue(true);
|
||||
const req: GatewayRequest = { id: 3, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } };
|
||||
const send = vi.fn();
|
||||
const result = await handlers['agent.send'](req, send) as GatewayError;
|
||||
|
||||
expect(result.error.code).toBe(ErrorCode.AgentBusy);
|
||||
});
|
||||
|
||||
it('agent.send handles errors gracefully', async () => {
|
||||
mockAgent.process.mockRejectedValue(new Error('model failed'));
|
||||
const req: GatewayRequest = { id: 4, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } };
|
||||
const sent: OutboundMessage[] = [];
|
||||
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||
|
||||
await handlers['agent.send'](req, send);
|
||||
|
||||
const errorEvent = sent[0] as GatewayEvent;
|
||||
expect(errorEvent.event).toBe('error');
|
||||
expect((errorEvent.data as any).message).toBe('model failed');
|
||||
});
|
||||
|
||||
it('agent.send sets and cleans up tool use callback', async () => {
|
||||
const req: GatewayRequest = { id: 5, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } };
|
||||
const send = vi.fn();
|
||||
|
||||
await handlers['agent.send'](req, send);
|
||||
|
||||
// setOnToolUse called twice: once to set callback, once to clear it
|
||||
expect(mockBridge.setOnToolUse).toHaveBeenCalledTimes(2);
|
||||
expect(mockBridge.setOnToolUse).toHaveBeenLastCalledWith('conn-1', undefined);
|
||||
});
|
||||
|
||||
it('agent.send sets busy state correctly', async () => {
|
||||
const req: GatewayRequest = { id: 6, method: 'agent.send', params: { message: 'hi', connectionId: 'conn-1' } };
|
||||
const send = vi.fn();
|
||||
|
||||
await handlers['agent.send'](req, send);
|
||||
|
||||
expect(mockBridge.setBusy).toHaveBeenCalledWith('conn-1', true);
|
||||
expect(mockBridge.setBusy).toHaveBeenCalledWith('conn-1', false);
|
||||
});
|
||||
|
||||
it('agent.cancel returns cancelled state', async () => {
|
||||
mockBridge.isBusy.mockReturnValue(true);
|
||||
const req: GatewayRequest = { id: 7, method: 'agent.cancel', params: { connectionId: 'conn-1' } };
|
||||
const result = await handlers['agent.cancel'](req) as GatewayResponse;
|
||||
|
||||
expect((result.result as any).cancelled).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,8 @@
|
||||
export { createSystemHandlers } from './system.js';
|
||||
export type { SystemHandlerDeps } from './system.js';
|
||||
export { createSessionHandlers } from './sessions.js';
|
||||
export type { SessionHandlerDeps } from './sessions.js';
|
||||
export { createToolHandlers } from './tools.js';
|
||||
export type { ToolHandlerDeps } from './tools.js';
|
||||
export { createAgentHandlers } from './agent.js';
|
||||
export type { AgentHandlerDeps } from './agent.js';
|
||||
@@ -0,0 +1,59 @@
|
||||
import type { GatewayRequest, OutboundMessage } from '../protocol.js';
|
||||
import { makeResponse, makeError, ErrorCode } from '../protocol.js';
|
||||
import type { SessionManager } from '../../session/manager.js';
|
||||
|
||||
export interface SessionHandlerDeps {
|
||||
sessionManager: SessionManager;
|
||||
}
|
||||
|
||||
export function createSessionHandlers(deps: SessionHandlerDeps) {
|
||||
return {
|
||||
'sessions.list': async (request: GatewayRequest): Promise<OutboundMessage> => {
|
||||
const sessionIds = deps.sessionManager.listSessions();
|
||||
const sessions = sessionIds.map(id => ({
|
||||
id,
|
||||
messageCount: deps.sessionManager.getSession(
|
||||
id.split(':')[0],
|
||||
id.split(':').slice(1).join(':')
|
||||
).getHistory().length,
|
||||
}));
|
||||
return makeResponse(request.id, { sessions });
|
||||
},
|
||||
|
||||
'sessions.history': async (request: GatewayRequest): Promise<OutboundMessage> => {
|
||||
const params = request.params as { sessionId?: string; limit?: number; offset?: number } | undefined;
|
||||
if (!params?.sessionId) {
|
||||
return makeError(request.id, ErrorCode.InvalidRequest, 'sessionId is required');
|
||||
}
|
||||
|
||||
const { sessionId, limit, offset } = params;
|
||||
const parts = sessionId.split(':');
|
||||
const frontend = parts[0];
|
||||
const userId = parts.slice(1).join(':');
|
||||
const session = deps.sessionManager.getSession(frontend, userId);
|
||||
const allMessages = session.getHistory();
|
||||
|
||||
const start = offset ?? 0;
|
||||
const end = limit ? start + limit : allMessages.length;
|
||||
const messages = allMessages.slice(start, end);
|
||||
|
||||
return makeResponse(request.id, {
|
||||
messages,
|
||||
total: allMessages.length,
|
||||
});
|
||||
},
|
||||
|
||||
'sessions.create': async (request: GatewayRequest): Promise<OutboundMessage> => {
|
||||
const params = request.params as { sessionId?: string } | undefined;
|
||||
const sessionId = params?.sessionId ?? `ws:${Date.now()}`;
|
||||
const parts = sessionId.split(':');
|
||||
const frontend = parts[0];
|
||||
const userId = parts.slice(1).join(':');
|
||||
|
||||
// Creating a session via getSession is idempotent
|
||||
deps.sessionManager.getSession(frontend, userId);
|
||||
|
||||
return makeResponse(request.id, { sessionId });
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
import type { GatewayRequest, OutboundMessage } from '../protocol.js';
|
||||
import { makeResponse } from '../protocol.js';
|
||||
|
||||
export interface SystemHandlerDeps {
|
||||
startTime: number;
|
||||
version: string;
|
||||
getSessionCount: () => number;
|
||||
getToolCount: () => number;
|
||||
getConnectionCount: () => number;
|
||||
}
|
||||
|
||||
export function createSystemHandlers(deps: SystemHandlerDeps) {
|
||||
return {
|
||||
'system.health': async (request: GatewayRequest): Promise<OutboundMessage> => {
|
||||
return makeResponse(request.id, {
|
||||
status: 'ok',
|
||||
uptime: Math.floor((Date.now() - deps.startTime) / 1000),
|
||||
version: deps.version,
|
||||
sessions: deps.getSessionCount(),
|
||||
tools: deps.getToolCount(),
|
||||
connections: deps.getConnectionCount(),
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
import type { GatewayRequest, OutboundMessage } from '../protocol.js';
|
||||
import { makeResponse, makeError, ErrorCode } from '../protocol.js';
|
||||
import type { ToolRegistry } from '../../tools/registry.js';
|
||||
import type { ToolExecutor } from '../../tools/executor.js';
|
||||
|
||||
export interface ToolHandlerDeps {
|
||||
toolRegistry: ToolRegistry;
|
||||
toolExecutor: ToolExecutor;
|
||||
}
|
||||
|
||||
export function createToolHandlers(deps: ToolHandlerDeps) {
|
||||
return {
|
||||
'tools.list': async (request: GatewayRequest): Promise<OutboundMessage> => {
|
||||
const tools = deps.toolRegistry.list().map(t => ({
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
inputSchema: t.inputSchema,
|
||||
}));
|
||||
return makeResponse(request.id, { tools });
|
||||
},
|
||||
|
||||
'tools.invoke': async (request: GatewayRequest): Promise<OutboundMessage> => {
|
||||
const params = request.params as { tool?: string; args?: Record<string, unknown> } | undefined;
|
||||
if (!params?.tool) {
|
||||
return makeError(request.id, ErrorCode.InvalidRequest, 'tool name is required');
|
||||
}
|
||||
|
||||
const tool = deps.toolRegistry.get(params.tool);
|
||||
if (!tool) {
|
||||
return makeError(request.id, ErrorCode.ToolNotFound, `Tool not found: ${params.tool}`);
|
||||
}
|
||||
|
||||
const result = await deps.toolExecutor.execute(params.tool, params.args ?? {});
|
||||
return makeResponse(request.id, result);
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
export { GatewayServer } from './server.js';
|
||||
export type { GatewayServerConfig } from './server.js';
|
||||
export { Router } from './router.js';
|
||||
export type { HandlerFn, SendFn } from './router.js';
|
||||
export { SessionBridge } from './session-bridge.js';
|
||||
export type { SessionBridgeConfig } from './session-bridge.js';
|
||||
export { authenticateRequest } from './auth.js';
|
||||
export type { AuthConfig, AuthResult } from './auth.js';
|
||||
export {
|
||||
ErrorCode,
|
||||
isValidRequest,
|
||||
parseMessage,
|
||||
makeResponse,
|
||||
makeError,
|
||||
makeEvent,
|
||||
} from './protocol.js';
|
||||
export type {
|
||||
GatewayRequest,
|
||||
GatewayResponse,
|
||||
GatewayError,
|
||||
GatewayEvent,
|
||||
OutboundMessage,
|
||||
EventType,
|
||||
ContentEventData,
|
||||
ToolStartEventData,
|
||||
ToolEndEventData,
|
||||
DoneEventData,
|
||||
ErrorEventData,
|
||||
} from './protocol.js';
|
||||
@@ -0,0 +1,90 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
isValidRequest,
|
||||
parseMessage,
|
||||
makeResponse,
|
||||
makeError,
|
||||
makeEvent,
|
||||
ErrorCode,
|
||||
} from './protocol.js';
|
||||
|
||||
describe('protocol', () => {
|
||||
describe('isValidRequest', () => {
|
||||
it('accepts valid request with params', () => {
|
||||
expect(isValidRequest({ id: 1, method: 'agent.send', params: { message: 'hello' } })).toBe(true);
|
||||
});
|
||||
|
||||
it('accepts valid request without params', () => {
|
||||
expect(isValidRequest({ id: 1, method: 'system.health' })).toBe(true);
|
||||
});
|
||||
|
||||
it('rejects missing id', () => {
|
||||
expect(isValidRequest({ method: 'test' })).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects missing method', () => {
|
||||
expect(isValidRequest({ id: 1 })).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects non-object', () => {
|
||||
expect(isValidRequest('not an object')).toBe(false);
|
||||
expect(isValidRequest(null)).toBe(false);
|
||||
expect(isValidRequest(42)).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects non-numeric id', () => {
|
||||
expect(isValidRequest({ id: 'abc', method: 'test' })).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects non-string method', () => {
|
||||
expect(isValidRequest({ id: 1, method: 42 })).toBe(false);
|
||||
});
|
||||
|
||||
it('rejects non-object params', () => {
|
||||
expect(isValidRequest({ id: 1, method: 'test', params: 'bad' })).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseMessage', () => {
|
||||
it('parses valid JSON into GatewayRequest', () => {
|
||||
const msg = parseMessage('{"id":1,"method":"agent.send","params":{"message":"hi"}}');
|
||||
expect(msg).toEqual({ id: 1, method: 'agent.send', params: { message: 'hi' } });
|
||||
});
|
||||
|
||||
it('returns null for invalid JSON', () => {
|
||||
expect(parseMessage('not json')).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null for valid JSON that is not a valid request', () => {
|
||||
expect(parseMessage('{"method":"test"}')).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('makeResponse', () => {
|
||||
it('creates a response message', () => {
|
||||
expect(makeResponse(1, { status: 'ok' })).toEqual({
|
||||
id: 1,
|
||||
result: { status: 'ok' },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('makeError', () => {
|
||||
it('creates an error message', () => {
|
||||
expect(makeError(1, ErrorCode.MethodNotFound, 'Not found')).toEqual({
|
||||
id: 1,
|
||||
error: { code: -3, message: 'Not found' },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('makeEvent', () => {
|
||||
it('creates an event message', () => {
|
||||
expect(makeEvent(1, 'content', { text: 'hello' })).toEqual({
|
||||
id: 1,
|
||||
event: 'content',
|
||||
data: { text: 'hello' },
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,119 @@
|
||||
// Gateway protocol types — JSON-RPC-like messages over WebSocket.
|
||||
|
||||
// ── Client → Server ────────────────────────────────────────────
|
||||
|
||||
export interface GatewayRequest {
|
||||
id: number;
|
||||
method: string;
|
||||
params?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
// ── Server → Client ────────────────────────────────────────────
|
||||
|
||||
export interface GatewayResponse {
|
||||
id: number;
|
||||
result: unknown;
|
||||
}
|
||||
|
||||
export interface GatewayError {
|
||||
id: number;
|
||||
error: {
|
||||
code: ErrorCode;
|
||||
message: string;
|
||||
};
|
||||
}
|
||||
|
||||
export interface GatewayEvent {
|
||||
id: number;
|
||||
event: EventType;
|
||||
data: unknown;
|
||||
}
|
||||
|
||||
// ── Event types emitted during agent.send ──────────────────────
|
||||
|
||||
export type EventType =
|
||||
| 'content'
|
||||
| 'tool_start'
|
||||
| 'tool_end'
|
||||
| 'done'
|
||||
| 'error';
|
||||
|
||||
export interface ContentEventData {
|
||||
text: string;
|
||||
}
|
||||
|
||||
export interface ToolStartEventData {
|
||||
tool: string;
|
||||
args: unknown;
|
||||
}
|
||||
|
||||
export interface ToolEndEventData {
|
||||
tool: string;
|
||||
result: {
|
||||
success: boolean;
|
||||
output: string;
|
||||
error?: string;
|
||||
};
|
||||
}
|
||||
|
||||
export interface DoneEventData {
|
||||
content: string;
|
||||
}
|
||||
|
||||
export interface ErrorEventData {
|
||||
code: ErrorCode;
|
||||
message: string;
|
||||
}
|
||||
|
||||
// ── Error codes ────────────────────────────────────────────────
|
||||
|
||||
export enum ErrorCode {
|
||||
ParseError = -1,
|
||||
InvalidRequest = -2,
|
||||
MethodNotFound = -3,
|
||||
AuthRequired = -4,
|
||||
AuthFailed = -5,
|
||||
SessionNotFound = 1,
|
||||
ToolNotFound = 2,
|
||||
AgentBusy = 3,
|
||||
RequestCancelled = 4,
|
||||
InternalError = 5,
|
||||
}
|
||||
|
||||
// ── Outbound message (union of all server → client types) ──────
|
||||
|
||||
export type OutboundMessage = GatewayResponse | GatewayError | GatewayEvent;
|
||||
|
||||
// ── Validation helpers ─────────────────────────────────────────
|
||||
|
||||
export function isValidRequest(msg: unknown): msg is GatewayRequest {
|
||||
if (typeof msg !== 'object' || msg === null) return false;
|
||||
const obj = msg as Record<string, unknown>;
|
||||
return (
|
||||
typeof obj.id === 'number' &&
|
||||
typeof obj.method === 'string' &&
|
||||
(obj.params === undefined || (typeof obj.params === 'object' && obj.params !== null))
|
||||
);
|
||||
}
|
||||
|
||||
export function parseMessage(raw: string): GatewayRequest | null {
|
||||
try {
|
||||
const parsed = JSON.parse(raw);
|
||||
if (isValidRequest(parsed)) return parsed;
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function makeResponse(id: number, result: unknown): GatewayResponse {
|
||||
return { id, result };
|
||||
}
|
||||
|
||||
export function makeError(id: number, code: ErrorCode, message: string): GatewayError {
|
||||
return { id, error: { code, message } };
|
||||
}
|
||||
|
||||
export function makeEvent(id: number, event: EventType, data: unknown): GatewayEvent {
|
||||
return { id, event, data };
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { Router } from './router.js';
|
||||
import type { GatewayRequest, OutboundMessage } from './protocol.js';
|
||||
import { makeResponse, ErrorCode } from './protocol.js';
|
||||
|
||||
describe('Router', () => {
|
||||
it('dispatches to registered handler', async () => {
|
||||
const router = new Router();
|
||||
const handler = vi.fn(async (req: GatewayRequest) => makeResponse(req.id, { ok: true }));
|
||||
router.register('test.method', handler);
|
||||
|
||||
const request: GatewayRequest = { id: 1, method: 'test.method', params: {} };
|
||||
const send = vi.fn();
|
||||
const result = await router.dispatch(request, send);
|
||||
|
||||
expect(handler).toHaveBeenCalledWith(request, send);
|
||||
expect(result).toEqual({ id: 1, result: { ok: true } });
|
||||
});
|
||||
|
||||
it('returns MethodNotFound for unregistered method', async () => {
|
||||
const router = new Router();
|
||||
const request: GatewayRequest = { id: 1, method: 'unknown.method' };
|
||||
const send = vi.fn();
|
||||
const result = await router.dispatch(request, send);
|
||||
|
||||
expect(result).toEqual({
|
||||
id: 1,
|
||||
error: { code: ErrorCode.MethodNotFound, message: 'Unknown method: unknown.method' },
|
||||
});
|
||||
});
|
||||
|
||||
it('lists registered methods', () => {
|
||||
const router = new Router();
|
||||
router.register('a.method', async () => {});
|
||||
router.register('b.method', async () => {});
|
||||
|
||||
expect(router.listMethods()).toEqual(['a.method', 'b.method']);
|
||||
});
|
||||
|
||||
it('handler can send streaming events via send function', async () => {
|
||||
const router = new Router();
|
||||
router.register('stream.test', async (req, send) => {
|
||||
send({ id: req.id, event: 'content', data: { text: 'chunk1' } });
|
||||
send({ id: req.id, event: 'content', data: { text: 'chunk2' } });
|
||||
send({ id: req.id, event: 'done', data: { content: 'chunk1chunk2' } });
|
||||
});
|
||||
|
||||
const request: GatewayRequest = { id: 1, method: 'stream.test' };
|
||||
const sent: OutboundMessage[] = [];
|
||||
const send = vi.fn((msg: OutboundMessage) => sent.push(msg));
|
||||
|
||||
await router.dispatch(request, send);
|
||||
|
||||
expect(sent).toHaveLength(3);
|
||||
expect(sent[0]).toEqual({ id: 1, event: 'content', data: { text: 'chunk1' } });
|
||||
expect(sent[2]).toEqual({ id: 1, event: 'done', data: { content: 'chunk1chunk2' } });
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,27 @@
|
||||
import type { GatewayRequest, OutboundMessage } from './protocol.js';
|
||||
import { makeError, ErrorCode } from './protocol.js';
|
||||
|
||||
// A handler function receives a request and a send function for streaming events.
|
||||
// It returns a final response/error, or void if it already sent a done event.
|
||||
export type SendFn = (msg: OutboundMessage) => void;
|
||||
export type HandlerFn = (request: GatewayRequest, send: SendFn) => Promise<OutboundMessage | void>;
|
||||
|
||||
export class Router {
|
||||
private handlers: Map<string, HandlerFn> = new Map();
|
||||
|
||||
register(method: string, handler: HandlerFn): void {
|
||||
this.handlers.set(method, handler);
|
||||
}
|
||||
|
||||
async dispatch(request: GatewayRequest, send: SendFn): Promise<OutboundMessage | void> {
|
||||
const handler = this.handlers.get(request.method);
|
||||
if (!handler) {
|
||||
return makeError(request.id, ErrorCode.MethodNotFound, `Unknown method: ${request.method}`);
|
||||
}
|
||||
return handler(request, send);
|
||||
}
|
||||
|
||||
listMethods(): string[] {
|
||||
return Array.from(this.handlers.keys());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
import { describe, it, expect, beforeAll, afterAll, vi } from 'vitest';
|
||||
import { WebSocket } from 'ws';
|
||||
import { GatewayServer } from './server.js';
|
||||
import type { GatewayServerConfig } from './server.js';
|
||||
import type { GatewayResponse, GatewayError, GatewayEvent } from './protocol.js';
|
||||
import { ErrorCode } from './protocol.js';
|
||||
|
||||
// Minimal mocks for dependencies
|
||||
const mockSession = {
|
||||
id: 'test',
|
||||
addMessage: vi.fn(),
|
||||
getHistory: vi.fn(() => []),
|
||||
clear: vi.fn(),
|
||||
setHistory: vi.fn(),
|
||||
};
|
||||
|
||||
const mockSessionManager = {
|
||||
getSession: vi.fn(() => mockSession),
|
||||
listSessions: vi.fn(() => ['ws:test']),
|
||||
transferSession: vi.fn(),
|
||||
closeSession: vi.fn(),
|
||||
};
|
||||
|
||||
const mockModelClient = {
|
||||
chat: vi.fn(async () => ({
|
||||
content: 'Hello from Flynn!',
|
||||
stopReason: 'end_turn',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
})),
|
||||
};
|
||||
|
||||
const mockToolRegistry = {
|
||||
register: vi.fn(),
|
||||
get: vi.fn((name: string) => (name === 'shell.exec' ? { name: 'shell.exec', description: 'Run shell', inputSchema: { type: 'object', properties: {} } } : undefined)),
|
||||
list: vi.fn(() => [{ name: 'shell.exec', description: 'Run shell', inputSchema: { type: 'object', properties: {} } }]),
|
||||
toAnthropicFormat: vi.fn(() => []),
|
||||
toOpenAIFormat: vi.fn(() => []),
|
||||
};
|
||||
|
||||
const mockToolExecutor = {
|
||||
execute: vi.fn(async () => ({ success: true, output: 'executed' })),
|
||||
};
|
||||
|
||||
const TEST_PORT = 18899;
|
||||
|
||||
let server: GatewayServer;
|
||||
|
||||
function createClient(): Promise<WebSocket> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const ws = new WebSocket(`ws://127.0.0.1:${TEST_PORT}`);
|
||||
ws.on('open', () => resolve(ws));
|
||||
ws.on('error', reject);
|
||||
});
|
||||
}
|
||||
|
||||
function sendAndReceive(ws: WebSocket, msg: object): Promise<GatewayResponse | GatewayError | GatewayEvent> {
|
||||
return new Promise((resolve) => {
|
||||
ws.once('message', (data) => {
|
||||
resolve(JSON.parse(data.toString()));
|
||||
});
|
||||
ws.send(JSON.stringify(msg));
|
||||
});
|
||||
}
|
||||
|
||||
function sendAndReceiveAll(ws: WebSocket, msg: object, count: number): Promise<Array<GatewayResponse | GatewayError | GatewayEvent>> {
|
||||
return new Promise((resolve) => {
|
||||
const messages: Array<GatewayResponse | GatewayError | GatewayEvent> = [];
|
||||
const handler = (data: Buffer) => {
|
||||
messages.push(JSON.parse(data.toString()));
|
||||
if (messages.length >= count) {
|
||||
ws.off('message', handler);
|
||||
resolve(messages);
|
||||
}
|
||||
};
|
||||
ws.on('message', handler);
|
||||
ws.send(JSON.stringify(msg));
|
||||
});
|
||||
}
|
||||
|
||||
describe('GatewayServer integration', () => {
|
||||
beforeAll(async () => {
|
||||
server = new GatewayServer({
|
||||
port: TEST_PORT,
|
||||
sessionManager: mockSessionManager as unknown as GatewayServerConfig['sessionManager'],
|
||||
modelClient: mockModelClient,
|
||||
systemPrompt: 'Test prompt',
|
||||
toolRegistry: mockToolRegistry as unknown as GatewayServerConfig['toolRegistry'],
|
||||
toolExecutor: mockToolExecutor as unknown as GatewayServerConfig['toolExecutor'],
|
||||
version: '0.1.0-test',
|
||||
});
|
||||
await server.start();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await server.stop();
|
||||
});
|
||||
|
||||
it('responds to system.health', async () => {
|
||||
const ws = await createClient();
|
||||
try {
|
||||
const result = await sendAndReceive(ws, { id: 1, method: 'system.health' });
|
||||
const response = result as GatewayResponse;
|
||||
expect(response.id).toBe(1);
|
||||
const r = response.result as Record<string, unknown>;
|
||||
expect(r.status).toBe('ok');
|
||||
expect(r.version).toBe('0.1.0-test');
|
||||
expect(typeof r.uptime).toBe('number');
|
||||
} finally {
|
||||
ws.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('returns MethodNotFound for unknown method', async () => {
|
||||
const ws = await createClient();
|
||||
try {
|
||||
const result = await sendAndReceive(ws, { id: 2, method: 'unknown.method' });
|
||||
const error = result as GatewayError;
|
||||
expect(error.error.code).toBe(ErrorCode.MethodNotFound);
|
||||
} finally {
|
||||
ws.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('returns ParseError for invalid JSON', async () => {
|
||||
const ws = await createClient();
|
||||
try {
|
||||
const result = await new Promise<GatewayError>((resolve) => {
|
||||
ws.once('message', (data) => resolve(JSON.parse(data.toString())));
|
||||
ws.send('not valid json');
|
||||
});
|
||||
expect(result.error.code).toBe(ErrorCode.ParseError);
|
||||
} finally {
|
||||
ws.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('lists tools via tools.list', async () => {
|
||||
const ws = await createClient();
|
||||
try {
|
||||
const result = await sendAndReceive(ws, { id: 3, method: 'tools.list' });
|
||||
const response = result as GatewayResponse;
|
||||
const r = response.result as { tools: Array<{ name: string }> };
|
||||
expect(r.tools.length).toBeGreaterThan(0);
|
||||
expect(r.tools[0].name).toBe('shell.exec');
|
||||
} finally {
|
||||
ws.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('sends agent message and receives done event', async () => {
|
||||
const ws = await createClient();
|
||||
try {
|
||||
// agent.send streams events — we expect a 'done' event
|
||||
const messages = await sendAndReceiveAll(ws, { id: 4, method: 'agent.send', params: { message: 'hi' } }, 1);
|
||||
const doneEvent = messages[0] as GatewayEvent;
|
||||
expect(doneEvent.id).toBe(4);
|
||||
expect(doneEvent.event).toBe('done');
|
||||
expect((doneEvent.data as any).content).toBe('Hello from Flynn!');
|
||||
} finally {
|
||||
ws.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('tracks connections correctly', async () => {
|
||||
const ws1 = await createClient();
|
||||
const ws2 = await createClient();
|
||||
try {
|
||||
const result = await sendAndReceive(ws1, { id: 5, method: 'system.health' });
|
||||
const r = (result as GatewayResponse).result as Record<string, unknown>;
|
||||
expect(r.connections).toBe(2);
|
||||
} finally {
|
||||
ws1.close();
|
||||
ws2.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('lists registered methods', () => {
|
||||
const methods = server.getMethods();
|
||||
expect(methods).toContain('system.health');
|
||||
expect(methods).toContain('agent.send');
|
||||
expect(methods).toContain('agent.cancel');
|
||||
expect(methods).toContain('sessions.list');
|
||||
expect(methods).toContain('sessions.history');
|
||||
expect(methods).toContain('sessions.create');
|
||||
expect(methods).toContain('tools.list');
|
||||
expect(methods).toContain('tools.invoke');
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,204 @@
|
||||
import { WebSocketServer, WebSocket } from 'ws';
|
||||
import { randomUUID } from 'crypto';
|
||||
import type { IncomingMessage } from 'http';
|
||||
import { Router } from './router.js';
|
||||
import { SessionBridge } from './session-bridge.js';
|
||||
import type { SessionBridgeConfig } from './session-bridge.js';
|
||||
import { authenticateRequest } from './auth.js';
|
||||
import type { AuthConfig } from './auth.js';
|
||||
import {
|
||||
parseMessage,
|
||||
makeError,
|
||||
ErrorCode,
|
||||
type OutboundMessage,
|
||||
} from './protocol.js';
|
||||
import {
|
||||
createSystemHandlers,
|
||||
createSessionHandlers,
|
||||
createToolHandlers,
|
||||
createAgentHandlers,
|
||||
} from './handlers/index.js';
|
||||
import type { SessionManager } from '../session/manager.js';
|
||||
import type { ToolRegistry } from '../tools/registry.js';
|
||||
import type { ToolExecutor } from '../tools/executor.js';
|
||||
|
||||
export interface GatewayServerConfig {
|
||||
port: number;
|
||||
host?: string;
|
||||
sessionManager: SessionManager;
|
||||
modelClient: SessionBridgeConfig['modelClient'];
|
||||
systemPrompt: string;
|
||||
toolRegistry: ToolRegistry;
|
||||
toolExecutor: ToolExecutor;
|
||||
version?: string;
|
||||
auth?: AuthConfig;
|
||||
}
|
||||
|
||||
export class GatewayServer {
|
||||
private wss: WebSocketServer | null = null;
|
||||
private router: Router;
|
||||
private sessionBridge: SessionBridge;
|
||||
private connectionMap: Map<WebSocket, string> = new Map();
|
||||
private config: GatewayServerConfig;
|
||||
private startTime: number = Date.now();
|
||||
|
||||
constructor(config: GatewayServerConfig) {
|
||||
this.config = config;
|
||||
|
||||
this.sessionBridge = new SessionBridge({
|
||||
sessionManager: config.sessionManager,
|
||||
modelClient: config.modelClient,
|
||||
systemPrompt: config.systemPrompt,
|
||||
toolRegistry: config.toolRegistry,
|
||||
toolExecutor: config.toolExecutor,
|
||||
});
|
||||
|
||||
this.router = new Router();
|
||||
this.registerHandlers();
|
||||
}
|
||||
|
||||
private registerHandlers(): void {
|
||||
const systemHandlers = createSystemHandlers({
|
||||
startTime: this.startTime,
|
||||
version: this.config.version ?? '0.1.0',
|
||||
getSessionCount: () => this.sessionBridge.listSessions().length,
|
||||
getToolCount: () => this.config.toolRegistry.list().length,
|
||||
getConnectionCount: () => this.sessionBridge.connectionCount,
|
||||
});
|
||||
|
||||
const sessionHandlers = createSessionHandlers({
|
||||
sessionManager: this.config.sessionManager,
|
||||
});
|
||||
|
||||
const toolHandlers = createToolHandlers({
|
||||
toolRegistry: this.config.toolRegistry,
|
||||
toolExecutor: this.config.toolExecutor,
|
||||
});
|
||||
|
||||
const agentHandlers = createAgentHandlers({
|
||||
sessionBridge: this.sessionBridge,
|
||||
});
|
||||
|
||||
// Register all methods
|
||||
for (const [method, handler] of Object.entries(systemHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
for (const [method, handler] of Object.entries(sessionHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
for (const [method, handler] of Object.entries(toolHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
for (const [method, handler] of Object.entries(agentHandlers)) {
|
||||
this.router.register(method, handler);
|
||||
}
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
this.wss = new WebSocketServer({
|
||||
port: this.config.port,
|
||||
host: this.config.host ?? '127.0.0.1',
|
||||
});
|
||||
|
||||
this.wss.on('connection', (ws: WebSocket, req: IncomingMessage) => {
|
||||
// Auth check on upgrade
|
||||
const authResult = authenticateRequest(req, this.config.auth ?? {});
|
||||
if (!authResult.authenticated) {
|
||||
ws.close(4001, authResult.error ?? 'Authentication failed');
|
||||
return;
|
||||
}
|
||||
this.handleConnection(ws, authResult.identity);
|
||||
});
|
||||
|
||||
this.wss.on('listening', () => {
|
||||
const addr = this.wss?.address();
|
||||
const portStr = typeof addr === 'object' && addr ? `:${addr.port}` : '';
|
||||
console.log(`Gateway WebSocket server listening on ${this.config.host ?? '127.0.0.1'}${portStr}`);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
if (!this.wss) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
// Close all connections
|
||||
for (const [ws, connectionId] of this.connectionMap) {
|
||||
this.sessionBridge.disconnect(connectionId);
|
||||
ws.close(1001, 'Server shutting down');
|
||||
}
|
||||
this.connectionMap.clear();
|
||||
|
||||
this.wss.close(() => {
|
||||
this.wss = null;
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private handleConnection(ws: WebSocket, identity?: string): void {
|
||||
const connectionId = randomUUID();
|
||||
this.sessionBridge.connect(connectionId);
|
||||
this.connectionMap.set(ws, connectionId);
|
||||
|
||||
ws.on('message', async (data) => {
|
||||
const raw = data.toString();
|
||||
await this.handleMessage(ws, connectionId, raw);
|
||||
});
|
||||
|
||||
ws.on('close', () => {
|
||||
this.sessionBridge.disconnect(connectionId);
|
||||
this.connectionMap.delete(ws);
|
||||
});
|
||||
|
||||
ws.on('error', (err) => {
|
||||
console.error(`WebSocket error (${connectionId}):`, err.message);
|
||||
});
|
||||
}
|
||||
|
||||
private async handleMessage(ws: WebSocket, connectionId: string, raw: string): Promise<void> {
|
||||
const request = parseMessage(raw);
|
||||
|
||||
if (!request) {
|
||||
this.send(ws, makeError(0, ErrorCode.ParseError, 'Invalid JSON or missing required fields'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Inject connectionId into params so handlers can identify the client
|
||||
if (!request.params) request.params = {};
|
||||
request.params.connectionId = connectionId;
|
||||
|
||||
const send = (msg: OutboundMessage) => this.send(ws, msg);
|
||||
const response = await this.router.dispatch(request, send);
|
||||
|
||||
if (response) {
|
||||
this.send(ws, response);
|
||||
}
|
||||
}
|
||||
|
||||
private send(ws: WebSocket, msg: OutboundMessage): void {
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify(msg));
|
||||
}
|
||||
}
|
||||
|
||||
/** Get the underlying WebSocketServer (for testing). */
|
||||
getWss(): WebSocketServer | null {
|
||||
return this.wss;
|
||||
}
|
||||
|
||||
/** Get the session bridge (for testing/debugging). */
|
||||
getSessionBridge(): SessionBridge {
|
||||
return this.sessionBridge;
|
||||
}
|
||||
|
||||
/** Get list of registered methods. */
|
||||
getMethods(): string[] {
|
||||
return this.router.listMethods();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { SessionBridge } from './session-bridge.js';
|
||||
import type { SessionBridgeConfig } from './session-bridge.js';
|
||||
|
||||
// Minimal mocks
|
||||
const mockSession = {
|
||||
id: 'test',
|
||||
addMessage: vi.fn(),
|
||||
getHistory: vi.fn(() => []),
|
||||
clear: vi.fn(),
|
||||
};
|
||||
|
||||
const mockSessionManager = {
|
||||
getSession: vi.fn(() => mockSession),
|
||||
listSessions: vi.fn(() => []),
|
||||
transferSession: vi.fn(),
|
||||
closeSession: vi.fn(),
|
||||
};
|
||||
|
||||
const mockModelClient = {
|
||||
chat: vi.fn(async () => ({
|
||||
content: 'test',
|
||||
stopReason: 'end_turn',
|
||||
usage: { inputTokens: 0, outputTokens: 0 },
|
||||
})),
|
||||
};
|
||||
|
||||
const mockToolRegistry = {
|
||||
register: vi.fn(),
|
||||
get: vi.fn(),
|
||||
list: vi.fn(() => []),
|
||||
toAnthropicFormat: vi.fn(() => []),
|
||||
toOpenAIFormat: vi.fn(() => []),
|
||||
};
|
||||
|
||||
const mockToolExecutor = {
|
||||
execute: vi.fn(),
|
||||
};
|
||||
|
||||
function createBridge(): SessionBridge {
|
||||
return new SessionBridge({
|
||||
sessionManager: mockSessionManager as unknown as SessionBridgeConfig['sessionManager'],
|
||||
modelClient: mockModelClient,
|
||||
systemPrompt: 'test prompt',
|
||||
toolRegistry: mockToolRegistry as unknown as SessionBridgeConfig['toolRegistry'],
|
||||
toolExecutor: mockToolExecutor as unknown as SessionBridgeConfig['toolExecutor'],
|
||||
});
|
||||
}
|
||||
|
||||
describe('SessionBridge', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('connect assigns a connection ID', () => {
|
||||
const bridge = createBridge();
|
||||
const id = bridge.connect('conn-1');
|
||||
expect(id).toBe('conn-1');
|
||||
expect(bridge.connectionCount).toBe(1);
|
||||
});
|
||||
|
||||
it('connect auto-generates ID when not provided', () => {
|
||||
const bridge = createBridge();
|
||||
const id = bridge.connect();
|
||||
expect(typeof id).toBe('string');
|
||||
expect(id.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('getAgent returns agent for connected client', () => {
|
||||
const bridge = createBridge();
|
||||
bridge.connect('conn-1');
|
||||
const agent = bridge.getAgent('conn-1');
|
||||
expect(agent).toBeDefined();
|
||||
});
|
||||
|
||||
it('getSessionId returns session ID for connected client', () => {
|
||||
const bridge = createBridge();
|
||||
bridge.connect('conn-1');
|
||||
expect(bridge.getSessionId('conn-1')).toBe('ws:conn-1');
|
||||
});
|
||||
|
||||
it('disconnect removes client but preserves session', () => {
|
||||
const bridge = createBridge();
|
||||
bridge.connect('conn-1');
|
||||
bridge.disconnect('conn-1');
|
||||
expect(bridge.connectionCount).toBe(0);
|
||||
expect(bridge.getAgent('conn-1')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('tracks busy state', () => {
|
||||
const bridge = createBridge();
|
||||
bridge.connect('conn-1');
|
||||
expect(bridge.isBusy('conn-1')).toBe(false);
|
||||
bridge.setBusy('conn-1', true);
|
||||
expect(bridge.isBusy('conn-1')).toBe(true);
|
||||
bridge.setBusy('conn-1', false);
|
||||
expect(bridge.isBusy('conn-1')).toBe(false);
|
||||
});
|
||||
|
||||
it('switchSession changes session for a connection', () => {
|
||||
const bridge = createBridge();
|
||||
bridge.connect('conn-1');
|
||||
expect(bridge.getSessionId('conn-1')).toBe('ws:conn-1');
|
||||
|
||||
bridge.switchSession('conn-1', 'custom-session');
|
||||
expect(bridge.getSessionId('conn-1')).toBe('custom-session');
|
||||
});
|
||||
|
||||
it('switchSession throws when busy', () => {
|
||||
const bridge = createBridge();
|
||||
bridge.connect('conn-1');
|
||||
bridge.setBusy('conn-1', true);
|
||||
|
||||
expect(() => bridge.switchSession('conn-1', 'other')).toThrow('Cannot switch session while agent is busy');
|
||||
});
|
||||
|
||||
it('switchSession throws for unknown connection', () => {
|
||||
const bridge = createBridge();
|
||||
expect(() => bridge.switchSession('unknown', 'other')).toThrow('Unknown connection');
|
||||
});
|
||||
|
||||
it('listSessions groups connections by session', () => {
|
||||
const bridge = createBridge();
|
||||
bridge.connect('conn-1');
|
||||
bridge.connect('conn-2');
|
||||
// Switch conn-2 to share conn-1's session
|
||||
bridge.switchSession('conn-2', 'ws:conn-1');
|
||||
|
||||
const sessions = bridge.listSessions();
|
||||
expect(sessions).toEqual([{ sessionId: 'ws:conn-1', connections: 2 }]);
|
||||
});
|
||||
|
||||
it('shared sessions keep agent alive when one client disconnects', () => {
|
||||
const bridge = createBridge();
|
||||
bridge.connect('conn-1');
|
||||
bridge.connect('conn-2');
|
||||
bridge.switchSession('conn-2', 'ws:conn-1');
|
||||
|
||||
bridge.disconnect('conn-1');
|
||||
// conn-2 still has the session
|
||||
expect(bridge.getAgent('conn-2')).toBeDefined();
|
||||
expect(bridge.connectionCount).toBe(1);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,135 @@
|
||||
import { randomUUID } from 'crypto';
|
||||
import type { SessionManager } from '../session/manager.js';
|
||||
import type { Session } from '../session/manager.js';
|
||||
import type { ModelClient } from '../models/types.js';
|
||||
import type { ModelRouter } from '../models/router.js';
|
||||
import type { ToolRegistry } from '../tools/registry.js';
|
||||
import type { ToolExecutor } from '../tools/executor.js';
|
||||
import { NativeAgent } from '../backends/native/agent.js';
|
||||
import type { ToolUseEvent } from '../backends/native/agent.js';
|
||||
|
||||
export interface SessionBridgeConfig {
|
||||
sessionManager: SessionManager;
|
||||
modelClient: ModelClient | ModelRouter;
|
||||
systemPrompt: string;
|
||||
toolRegistry: ToolRegistry;
|
||||
toolExecutor: ToolExecutor;
|
||||
}
|
||||
|
||||
interface ClientEntry {
|
||||
connectionId: string;
|
||||
sessionId: string;
|
||||
agent: NativeAgent;
|
||||
busy: boolean;
|
||||
}
|
||||
|
||||
export class SessionBridge {
|
||||
private clients: Map<string, ClientEntry> = new Map();
|
||||
private agents: Map<string, NativeAgent> = new Map();
|
||||
private config: SessionBridgeConfig;
|
||||
|
||||
constructor(config: SessionBridgeConfig) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
/** Register a new WS connection. Returns the assigned connection ID. */
|
||||
connect(connectionId?: string): string {
|
||||
const id = connectionId ?? randomUUID();
|
||||
const sessionId = `ws:${id}`;
|
||||
const agent = this.getOrCreateAgent(sessionId);
|
||||
|
||||
this.clients.set(id, {
|
||||
connectionId: id,
|
||||
sessionId,
|
||||
agent,
|
||||
busy: false,
|
||||
});
|
||||
|
||||
return id;
|
||||
}
|
||||
|
||||
/** Remove a WS connection. Does NOT destroy the session (persists in SQLite). */
|
||||
disconnect(connectionId: string): void {
|
||||
const client = this.clients.get(connectionId);
|
||||
if (client) {
|
||||
// Only remove the agent if no other clients share the session
|
||||
const otherClients = Array.from(this.clients.values())
|
||||
.filter(c => c.sessionId === client.sessionId && c.connectionId !== connectionId);
|
||||
if (otherClients.length === 0) {
|
||||
this.agents.delete(client.sessionId);
|
||||
}
|
||||
this.clients.delete(connectionId);
|
||||
}
|
||||
}
|
||||
|
||||
/** Switch a connection to a different session (e.g. resuming an old session). */
|
||||
switchSession(connectionId: string, sessionId: string): void {
|
||||
const client = this.clients.get(connectionId);
|
||||
if (!client) throw new Error(`Unknown connection: ${connectionId}`);
|
||||
if (client.busy) throw new Error('Cannot switch session while agent is busy');
|
||||
|
||||
const agent = this.getOrCreateAgent(sessionId);
|
||||
client.sessionId = sessionId;
|
||||
client.agent = agent;
|
||||
}
|
||||
|
||||
/** Get the NativeAgent for a connection. */
|
||||
getAgent(connectionId: string): NativeAgent | undefined {
|
||||
return this.clients.get(connectionId)?.agent;
|
||||
}
|
||||
|
||||
/** Get the session ID for a connection. */
|
||||
getSessionId(connectionId: string): string | undefined {
|
||||
return this.clients.get(connectionId)?.sessionId;
|
||||
}
|
||||
|
||||
/** Check if a connection's agent is busy. */
|
||||
isBusy(connectionId: string): boolean {
|
||||
return this.clients.get(connectionId)?.busy ?? false;
|
||||
}
|
||||
|
||||
/** Mark a connection's agent as busy/idle. */
|
||||
setBusy(connectionId: string, busy: boolean): void {
|
||||
const client = this.clients.get(connectionId);
|
||||
if (client) client.busy = busy;
|
||||
}
|
||||
|
||||
/** Set onToolUse callback for a connection's agent. */
|
||||
setOnToolUse(connectionId: string, callback: ((event: ToolUseEvent) => void) | undefined): void {
|
||||
const client = this.clients.get(connectionId);
|
||||
if (client) client.agent.setOnToolUse(callback);
|
||||
}
|
||||
|
||||
/** List all active sessions with connection counts. */
|
||||
listSessions(): Array<{ sessionId: string; connections: number }> {
|
||||
const sessionMap = new Map<string, number>();
|
||||
for (const client of this.clients.values()) {
|
||||
sessionMap.set(client.sessionId, (sessionMap.get(client.sessionId) ?? 0) + 1);
|
||||
}
|
||||
return Array.from(sessionMap.entries()).map(([sessionId, connections]) => ({
|
||||
sessionId,
|
||||
connections,
|
||||
}));
|
||||
}
|
||||
|
||||
/** Get count of active connections. */
|
||||
get connectionCount(): number {
|
||||
return this.clients.size;
|
||||
}
|
||||
|
||||
private getOrCreateAgent(sessionId: string): NativeAgent {
|
||||
let agent = this.agents.get(sessionId);
|
||||
if (!agent) {
|
||||
const session = this.config.sessionManager.getSession('ws', sessionId);
|
||||
agent = new NativeAgent({
|
||||
modelClient: this.config.modelClient,
|
||||
systemPrompt: this.config.systemPrompt,
|
||||
session,
|
||||
toolRegistry: this.config.toolRegistry,
|
||||
toolExecutor: this.config.toolExecutor,
|
||||
});
|
||||
this.agents.set(sessionId, agent);
|
||||
}
|
||||
return agent;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user