diff --git a/src/channels/index.ts b/src/channels/index.ts index 11e892b..a4c10db 100644 --- a/src/channels/index.ts +++ b/src/channels/index.ts @@ -15,3 +15,4 @@ export { WebChatAdapter, type WebChatAdapterConfig } from './webchat/index.js'; export { DiscordAdapter, type DiscordAdapterConfig } from './discord/index.js'; export { SlackAdapter, type SlackAdapterConfig } from './slack/index.js'; export { WhatsAppAdapter, type WhatsAppAdapterConfig } from './whatsapp/index.js'; +export { PairingManager, type PairingConfig } from './pairing.js'; diff --git a/src/channels/pairing.test.ts b/src/channels/pairing.test.ts new file mode 100644 index 0000000..095973f --- /dev/null +++ b/src/channels/pairing.test.ts @@ -0,0 +1,159 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { PairingManager } from './pairing.js'; + +describe('PairingManager', () => { + let manager: PairingManager; + + beforeEach(() => { + manager = new PairingManager({ + enabled: true, + codeTtl: 300_000, // 5 minutes + codeLength: 6, + }); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it('generateCode returns a code of the configured length', () => { + const code = manager.generateCode(); + expect(code).toHaveLength(6); + expect(code).toMatch(/^[0-9A-F]+$/); + }); + + it('generateCode with a label stores the label', () => { + const code = manager.generateCode('for alice'); + const pending = manager.listPendingCodes(); + expect(pending).toHaveLength(1); + expect(pending[0].code).toBe(code); + expect(pending[0].label).toBe('for alice'); + }); + + it('validateCode succeeds with a valid code and approves the sender', () => { + const code = manager.generateCode(); + const result = manager.validateCode('telegram', '12345', code); + expect(result).toBe(true); + expect(manager.isApproved('telegram', '12345')).toBe(true); + }); + + it('validateCode fails with an invalid code', () => { + manager.generateCode(); + const result = manager.validateCode('telegram', '12345', 'ZZZZZZ'); + expect(result).toBe(false); + expect(manager.isApproved('telegram', '12345')).toBe(false); + }); + + it('validateCode is case-insensitive', () => { + const code = manager.generateCode(); + const result = manager.validateCode('telegram', '12345', code.toLowerCase()); + expect(result).toBe(true); + expect(manager.isApproved('telegram', '12345')).toBe(true); + }); + + it('validateCode fails with an expired code', () => { + vi.useFakeTimers(); + const code = manager.generateCode(); + + // Advance time past the TTL + vi.advanceTimersByTime(300_001); + + const result = manager.validateCode('telegram', '12345', code); + expect(result).toBe(false); + expect(manager.isApproved('telegram', '12345')).toBe(false); + }); + + it('validateCode removes the code after successful use', () => { + const code = manager.generateCode(); + manager.validateCode('telegram', '12345', code); + + // Code should no longer be pending + expect(manager.listPendingCodes()).toHaveLength(0); + + // Second use of the same code should fail + const result = manager.validateCode('discord', '67890', code); + expect(result).toBe(false); + }); + + it('isApproved returns true after validation', () => { + const code = manager.generateCode(); + manager.validateCode('telegram', '12345', code); + expect(manager.isApproved('telegram', '12345')).toBe(true); + }); + + it('isApproved returns false for unapproved senders', () => { + expect(manager.isApproved('telegram', 'unknown')).toBe(false); + }); + + it('isApproved distinguishes between channels', () => { + const code = manager.generateCode(); + manager.validateCode('telegram', '12345', code); + expect(manager.isApproved('telegram', '12345')).toBe(true); + expect(manager.isApproved('discord', '12345')).toBe(false); + }); + + it('revokeApproval removes approval', () => { + const code = manager.generateCode(); + manager.validateCode('telegram', '12345', code); + expect(manager.isApproved('telegram', '12345')).toBe(true); + + const revoked = manager.revokeApproval('telegram', '12345'); + expect(revoked).toBe(true); + expect(manager.isApproved('telegram', '12345')).toBe(false); + }); + + it('revokeApproval returns false for non-existent sender', () => { + const revoked = manager.revokeApproval('telegram', 'nonexistent'); + expect(revoked).toBe(false); + }); + + it('listApproved returns all approved senders', () => { + const code1 = manager.generateCode(); + const code2 = manager.generateCode(); + manager.validateCode('telegram', '111', code1); + manager.validateCode('discord', '222', code2); + + const approved = manager.listApproved(); + expect(approved).toHaveLength(2); + expect(approved.map(a => a.senderId)).toContain('111'); + expect(approved.map(a => a.senderId)).toContain('222'); + }); + + it('listPendingCodes returns only non-expired codes', () => { + vi.useFakeTimers(); + const code1 = manager.generateCode('first'); + + // Advance time so the first code is almost expired + vi.advanceTimersByTime(200_000); + const code2 = manager.generateCode('second'); + + // Advance past first code's expiry + vi.advanceTimersByTime(100_001); + + const pending = manager.listPendingCodes(); + expect(pending).toHaveLength(1); + expect(pending[0].code).toBe(code2); + expect(pending[0].label).toBe('second'); + }); + + it('cleanup removes expired codes', () => { + vi.useFakeTimers(); + manager.generateCode(); + + vi.advanceTimersByTime(300_001); + manager.cleanup(); + + expect(manager.listPendingCodes()).toHaveLength(0); + }); + + it('enabled getter reflects config', () => { + expect(manager.enabled).toBe(true); + + const disabled = new PairingManager({ + enabled: false, + codeTtl: 300_000, + codeLength: 6, + }); + expect(disabled.enabled).toBe(false); + }); +}); diff --git a/src/channels/pairing.ts b/src/channels/pairing.ts new file mode 100644 index 0000000..5993ced --- /dev/null +++ b/src/channels/pairing.ts @@ -0,0 +1,133 @@ +import { randomBytes } from 'crypto'; + +export interface PairingConfig { + enabled: boolean; + codeTtl: number; // milliseconds + codeLength: number; // number of characters +} + +interface PendingCode { + code: string; + createdAt: number; + expiresAt: number; + /** Optional label for the code (e.g. "for alice"). */ + label?: string; +} + +interface ApprovedSender { + channel: string; + senderId: string; + approvedAt: number; + /** The code that was used. */ + codeUsed: string; +} + +/** + * Manages DM pairing codes for authenticating unknown senders. + * + * Flow: + * 1. Admin generates a pairing code via gateway API or TUI command. + * 2. Unknown sender DMs the bot with the code as their first message. + * 3. If the code is valid and not expired, the sender is approved. + * 4. Approved senders bypass the allowlist check for subsequent messages. + */ +export class PairingManager { + private config: PairingConfig; + private pendingCodes: Map = new Map(); + private approvedSenders: Map = new Map(); + + constructor(config: PairingConfig) { + this.config = config; + } + + /** Generate a new pairing code. Returns the code string. */ + generateCode(label?: string): string { + this.cleanup(); + const code = randomBytes(Math.ceil(this.config.codeLength / 2)) + .toString('hex') + .slice(0, this.config.codeLength) + .toUpperCase(); + + const now = Date.now(); + this.pendingCodes.set(code, { + code, + createdAt: now, + expiresAt: now + this.config.codeTtl, + label, + }); + + return code; + } + + /** + * Validate a code for a given channel+sender. + * If valid, adds the sender to the approved list and removes the code. + * Returns true if the code was valid. + */ + validateCode(channel: string, senderId: string, code: string): boolean { + this.cleanup(); + const normalizedCode = code.trim().toUpperCase(); + const pending = this.pendingCodes.get(normalizedCode); + + if (!pending) return false; + if (Date.now() > pending.expiresAt) { + this.pendingCodes.delete(normalizedCode); + return false; + } + + // Code is valid — approve the sender + const key = `${channel}:${senderId}`; + this.approvedSenders.set(key, { + channel, + senderId, + approvedAt: Date.now(), + codeUsed: normalizedCode, + }); + + // Remove the used code + this.pendingCodes.delete(normalizedCode); + return true; + } + + /** Check if a sender is already approved. */ + isApproved(channel: string, senderId: string): boolean { + const key = `${channel}:${senderId}`; + return this.approvedSenders.has(key); + } + + /** Revoke approval for a sender. Returns true if the sender was found and removed. */ + revokeApproval(channel: string, senderId: string): boolean { + const key = `${channel}:${senderId}`; + return this.approvedSenders.delete(key); + } + + /** List all currently approved senders. */ + listApproved(): ApprovedSender[] { + return Array.from(this.approvedSenders.values()); + } + + /** List all pending (non-expired) codes. */ + listPendingCodes(): Array<{ code: string; expiresAt: number; label?: string }> { + this.cleanup(); + return Array.from(this.pendingCodes.values()).map(p => ({ + code: p.code, + expiresAt: p.expiresAt, + label: p.label, + })); + } + + /** Remove expired codes. */ + cleanup(): void { + const now = Date.now(); + for (const [code, pending] of this.pendingCodes) { + if (now > pending.expiresAt) { + this.pendingCodes.delete(code); + } + } + } + + /** Whether pairing is enabled. */ + get enabled(): boolean { + return this.config.enabled; + } +} diff --git a/src/gateway/server.test.ts b/src/gateway/server.test.ts index 7bd37d1..0e76110 100644 --- a/src/gateway/server.test.ts +++ b/src/gateway/server.test.ts @@ -219,6 +219,102 @@ describe('GatewayServer integration', () => { }); }); +describe('GatewayServer lock mode', () => { + const LOCK_PORT = 18897; + let lockServer: GatewayServer; + + beforeAll(async () => { + lockServer = new GatewayServer({ + port: LOCK_PORT, + sessionManager: mockSessionManager as unknown as GatewayServerConfig['sessionManager'], + modelClient: mockModelClient, + systemPrompt: 'Test prompt', + toolRegistry: mockToolRegistry as unknown as GatewayServerConfig['toolRegistry'], + toolExecutor: mockToolExecutor as unknown as GatewayServerConfig['toolExecutor'], + lock: true, + uiDir: resolve(import.meta.dirname, 'ui'), + }); + await lockServer.start(); + }); + + afterAll(async () => { + await lockServer.stop(); + }); + + function createLockClient(): Promise { + return new Promise((resolve, reject) => { + const ws = new WebSocket(`ws://127.0.0.1:${LOCK_PORT}`); + ws.on('open', () => resolve(ws)); + ws.on('error', reject); + }); + } + + it('allows the first client to connect', async () => { + const ws = await createLockClient(); + try { + const result = await sendAndReceive(ws, { id: 1, method: 'system.health' }); + const response = result as GatewayResponse; + expect((response.result as any).status).toBe('ok'); + } finally { + ws.close(); + // Wait for the close to propagate so connectionMap is empty + await new Promise(r => setTimeout(r, 100)); + } + }); + + it('rejects second client with code 4003 when locked', async () => { + const ws1 = await createLockClient(); + try { + // Second client should be rejected + const closePromise = new Promise<{ code: number; reason: string }>((resolve) => { + const ws2 = new WebSocket(`ws://127.0.0.1:${LOCK_PORT}`); + ws2.on('close', (code, reason) => { + resolve({ code, reason: reason.toString() }); + }); + }); + + const { code, reason } = await closePromise; + expect(code).toBe(4003); + expect(reason).toContain('locked'); + } finally { + ws1.close(); + await new Promise(r => setTimeout(r, 100)); + } + }); + + it('allows a new client after the previous one disconnects', async () => { + const ws1 = await createLockClient(); + ws1.close(); + // Wait for the close to propagate + await new Promise(r => setTimeout(r, 100)); + + const ws2 = await createLockClient(); + try { + const result = await sendAndReceive(ws2, { id: 2, method: 'system.health' }); + const response = result as GatewayResponse; + expect((response.result as any).status).toBe('ok'); + } finally { + ws2.close(); + await new Promise(r => setTimeout(r, 100)); + } + }); + + it('system.lock handler returns lock status', async () => { + const ws = await createLockClient(); + try { + const result = await sendAndReceive(ws, { id: 3, method: 'system.lock' }); + const response = result as GatewayResponse; + const r = response.result as { locked: boolean; activeClients: number; maxClients: number | null }; + expect(r.locked).toBe(true); + expect(r.activeClients).toBe(1); + expect(r.maxClients).toBe(1); + } finally { + ws.close(); + await new Promise(r => setTimeout(r, 100)); + } + }); +}); + describe('GatewayServer HTTP auth', () => { const AUTH_PORT = 18898; let authServer: GatewayServer;