diff --git a/src/backends/native/agent.ts b/src/backends/native/agent.ts index 952a33e..7c75766 100644 --- a/src/backends/native/agent.ts +++ b/src/backends/native/agent.ts @@ -4,6 +4,7 @@ import type { Session } from '../../session/index.js'; import type { ToolRegistry } from '../../tools/registry.js'; import type { ToolExecutor } from '../../tools/executor.js'; import type { ToolResult } from '../../tools/types.js'; +import type { ToolPolicyContext } from '../../tools/policy.js'; export interface ToolUseEvent { type: 'start' | 'end'; @@ -20,6 +21,8 @@ export interface NativeAgentConfig { toolExecutor?: ToolExecutor; maxIterations?: number; onToolUse?: (event: ToolUseEvent) => void; + /** Policy context for tool filtering (agent tier, provider). */ + toolPolicyContext?: ToolPolicyContext; } // Internal message type for the tool loop — supports both text and structured content blocks. @@ -41,6 +44,7 @@ export class NativeAgent { private onToolUse?: (event: ToolUseEvent) => void; private _totalUsage: TokenUsage = { inputTokens: 0, outputTokens: 0 }; private _callCount: number = 0; + private _toolPolicyContext?: ToolPolicyContext; constructor(config: NativeAgentConfig) { this.modelClient = config.modelClient; @@ -50,6 +54,7 @@ export class NativeAgent { this.toolExecutor = config.toolExecutor; this.maxIterations = config.maxIterations ?? 10; this.onToolUse = config.onToolUse; + this._toolPolicyContext = config.toolPolicyContext; } private get history(): Message[] { @@ -96,7 +101,7 @@ export class NativeAgent { } private async toolLoop(): Promise { - const tools = this.toolRegistry!.toAnthropicFormat(); + const tools = this.toolRegistry!.filteredToAnthropicFormat(this._toolPolicyContext); // Build the loop messages from existing history. // These are the messages sent to the model, including any structured tool blocks. @@ -151,7 +156,7 @@ export class NativeAgent { for (const tc of response.toolCalls) { this.onToolUse?.({ type: 'start', tool: tc.name, args: tc.args }); - const result = await this.toolExecutor!.execute(tc.name, tc.args); + const result = await this.toolExecutor!.execute(tc.name, tc.args, this._toolPolicyContext); this.onToolUse?.({ type: 'end', tool: tc.name, result }); @@ -226,4 +231,12 @@ export class NativeAgent { setOnToolUse(callback: ((event: ToolUseEvent) => void) | undefined): void { this.onToolUse = callback; } + + setToolPolicyContext(context: ToolPolicyContext | undefined): void { + this._toolPolicyContext = context; + } + + getToolPolicyContext(): ToolPolicyContext | undefined { + return this._toolPolicyContext; + } } diff --git a/src/backends/native/orchestrator.ts b/src/backends/native/orchestrator.ts index abf82ea..7392719 100644 --- a/src/backends/native/orchestrator.ts +++ b/src/backends/native/orchestrator.ts @@ -4,6 +4,7 @@ import type { Session } from '../../session/index.js'; import type { ToolRegistry } from '../../tools/registry.js'; import type { ToolExecutor } from '../../tools/executor.js'; import type { MemoryStore } from '../../memory/store.js'; +import type { ToolPolicyContext } from '../../tools/policy.js'; import { NativeAgent } from './agent.js'; import type { ToolUseEvent } from './agent.js'; import { shouldCompact } from '../../context/tokens.js'; @@ -87,6 +88,8 @@ export interface OrchestratorConfig { contextWindow?: number; /** Optional memory store for injecting persistent memory into the system prompt. */ memoryStore?: MemoryStore; + /** Policy context for tool filtering (agent tier, provider). */ + toolPolicyContext?: ToolPolicyContext; } // ── AgentOrchestrator ───────────────────────────────────────────────── @@ -134,6 +137,7 @@ export class AgentOrchestrator { toolExecutor: config.toolExecutor, maxIterations: config.maxIterations, onToolUse: config.onToolUse, + toolPolicyContext: config.toolPolicyContext, }); // Set the primary tier on the agent @@ -174,9 +178,10 @@ export class AgentOrchestrator { maxTokens: request.maxTokens, }; - // Optionally include tools from the registry + // Optionally include tools from the registry (filtered by policy) if (request.tools && this._toolRegistry) { - chatRequest.tools = this._toolRegistry.toAnthropicFormat(); + const policyContext = this._agent.getToolPolicyContext(); + chatRequest.tools = this._toolRegistry.filteredToAnthropicFormat(policyContext); } const response = await this._modelRouter.chat(chatRequest, tier); diff --git a/src/config/index.ts b/src/config/index.ts index fe7493e..45e4ce5 100644 --- a/src/config/index.ts +++ b/src/config/index.ts @@ -1,2 +1,2 @@ export { loadConfig } from './loader.js'; -export { configSchema, type Config, type TelegramConfig, type ModelConfig, type CronJobConfig, type AgentsConfig, type CompactionConfig } from './schema.js'; +export { configSchema, type Config, type TelegramConfig, type ModelConfig, type CronJobConfig, type AgentsConfig, type CompactionConfig, type ToolProfile, type ToolOverrideConfig, type ToolsConfig } from './schema.js'; diff --git a/src/config/schema.ts b/src/config/schema.ts index 88505b2..1dd90bb 100644 --- a/src/config/schema.ts +++ b/src/config/schema.ts @@ -161,6 +161,24 @@ const webSearchSchema = z.object({ max_results: z.number().min(1).max(20).default(5), }).default({}); +// ── Tool policy schemas ────────────────────────────────────────────── + +const toolProfileEnum = z.enum(['minimal', 'messaging', 'coding', 'full']); + +const toolOverrideSchema = z.object({ + profile: toolProfileEnum.optional(), + allow: z.array(z.string()).default([]), + deny: z.array(z.string()).default([]), +}).default({}); + +const toolsSchema = z.object({ + profile: toolProfileEnum.default('full'), + allow: z.array(z.string()).default([]), + deny: z.array(z.string()).default([]), + agents: z.record(z.string(), toolOverrideSchema).default({}), + providers: z.record(z.string(), toolOverrideSchema).default({}), +}).default({}); + const promptSchema = z.object({ /** Additional directories to search for prompt template files. */ search_dirs: z.array(z.string()).default([]), @@ -190,6 +208,7 @@ export const configSchema = z.object({ retry: retrySchema, web_search: webSearchSchema, prompt: promptSchema, + tools: toolsSchema, }); export type Config = z.infer; @@ -206,3 +225,6 @@ export type SlackConfig = z.infer; export type WhatsAppConfig = z.infer; export type RetryPolicyConfig = z.infer; export type PromptConfig = z.infer; +export type ToolProfile = z.infer; +export type ToolOverrideConfig = z.infer; +export type ToolsConfig = z.infer; diff --git a/src/daemon/index.ts b/src/daemon/index.ts index f7a2fb6..bc4dee6 100644 --- a/src/daemon/index.ts +++ b/src/daemon/index.ts @@ -5,7 +5,7 @@ import type { ModelClient, RetryConfig } from '../models/index.js'; import { AgentOrchestrator, type DelegationConfig } from '../backends/index.js'; import { SessionStore, SessionManager } from '../session/index.js'; import { HookEngine } from '../hooks/index.js'; -import { ToolRegistry, ToolExecutor, allBuiltinTools, createWebSearchTools, createProcessTools, ProcessManager } from '../tools/index.js'; +import { ToolRegistry, ToolExecutor, ToolPolicy, allBuiltinTools, createWebSearchTools, createProcessTools, ProcessManager } from '../tools/index.js'; import { MemoryStore } from '../memory/index.js'; import { createMemoryTools } from '../tools/builtin/index.js'; import { GatewayServer } from '../gateway/index.js'; @@ -197,6 +197,10 @@ function createMessageRouter(deps: { modelName: deps.config.models.default.model, contextWindow: deps.config.models.default.context_window, memoryStore: deps.memoryStore, + toolPolicyContext: { + agent: deps.config.agents.primary_tier ?? 'default', + provider: deps.config.models.default.provider, + }, }); agents.set(sessionId, agent); } @@ -338,6 +342,15 @@ export async function startDaemon(config: Config): Promise { const toolExecutor = new ToolExecutor(toolRegistry, hookEngine); + // Initialize tool policy from config + const toolPolicy = new ToolPolicy(config.tools); + toolRegistry.setPolicy(toolPolicy); + + const effectiveProfile = toolPolicy.getEffectiveProfile(); + if (effectiveProfile !== 'full') { + console.log(`Tool policy: profile=${effectiveProfile}, deny=[${config.tools.deny.join(', ')}]`); + } + // Initialize MCP manager and start configured servers const mcpManager = new McpManager(toolRegistry); diff --git a/src/gateway/handlers/handlers.test.ts b/src/gateway/handlers/handlers.test.ts index 4d4fd2c..4c4eac3 100644 --- a/src/gateway/handlers/handlers.test.ts +++ b/src/gateway/handlers/handlers.test.ts @@ -120,6 +120,7 @@ describe('tool handlers', () => { const mockRegistry = { list: vi.fn(() => [mockTool]), + filteredList: vi.fn(() => [mockTool]), get: vi.fn((name: string) => (name === 'test.tool' ? mockTool : undefined)), register: vi.fn(), toAnthropicFormat: vi.fn(), @@ -138,6 +139,7 @@ describe('tool handlers', () => { beforeEach(() => { vi.clearAllMocks(); mockRegistry.list.mockReturnValue([mockTool]); + mockRegistry.filteredList.mockReturnValue([mockTool]); mockRegistry.get.mockImplementation((name: string) => (name === 'test.tool' ? mockTool : undefined)); mockExecutor.execute.mockResolvedValue({ success: true, output: 'done' }); }); diff --git a/src/gateway/handlers/tools.ts b/src/gateway/handlers/tools.ts index f5e6e94..d8eb3e9 100644 --- a/src/gateway/handlers/tools.ts +++ b/src/gateway/handlers/tools.ts @@ -11,7 +11,8 @@ export interface ToolHandlerDeps { export function createToolHandlers(deps: ToolHandlerDeps) { return { 'tools.list': async (request: GatewayRequest): Promise => { - const tools = deps.toolRegistry.list().map(t => ({ + // Use filteredList to respect tool policy (gateway context has no agent/provider) + const tools = deps.toolRegistry.filteredList().map(t => ({ name: t.name, description: t.description, inputSchema: t.inputSchema, @@ -30,6 +31,7 @@ export function createToolHandlers(deps: ToolHandlerDeps) { return makeError(request.id, ErrorCode.ToolNotFound, `Tool not found: ${params.tool}`); } + // Pass no context — gateway uses global policy only const result = await deps.toolExecutor.execute(params.tool, params.args ?? {}); return makeResponse(request.id, result); }, diff --git a/src/gateway/server.test.ts b/src/gateway/server.test.ts index ef408c7..f83906d 100644 --- a/src/gateway/server.test.ts +++ b/src/gateway/server.test.ts @@ -35,8 +35,11 @@ const mockToolRegistry = { register: vi.fn(), get: vi.fn((name: string) => (name === 'shell.exec' ? { name: 'shell.exec', description: 'Run shell', inputSchema: { type: 'object', properties: {} } } : undefined)), list: vi.fn(() => [{ name: 'shell.exec', description: 'Run shell', inputSchema: { type: 'object', properties: {} } }]), + filteredList: vi.fn(() => [{ name: 'shell.exec', description: 'Run shell', inputSchema: { type: 'object', properties: {} } }]), toAnthropicFormat: vi.fn(() => []), toOpenAIFormat: vi.fn(() => []), + filteredToAnthropicFormat: vi.fn(() => []), + filteredToOpenAIFormat: vi.fn(() => []), }; const mockToolExecutor = { diff --git a/src/tools/executor.ts b/src/tools/executor.ts index 7a9501c..992498f 100644 --- a/src/tools/executor.ts +++ b/src/tools/executor.ts @@ -1,6 +1,7 @@ import type { ToolResult } from './types.js'; import type { ToolRegistry } from './registry.js'; import type { HookEngine } from '../hooks/engine.js'; +import type { ToolPolicyContext } from './policy.js'; export interface ToolExecutorConfig { defaultTimeoutMs?: number; @@ -20,12 +21,25 @@ export class ToolExecutor { this.maxOutputBytes = config?.maxOutputBytes ?? 51_200; } - async execute(toolName: string, args: unknown): Promise { + async execute(toolName: string, args: unknown, context?: ToolPolicyContext): Promise { const tool = this.registry.get(toolName); if (!tool) { return { success: false, output: '', error: `Tool '${toolName}' not found` }; } + // Policy check (defense in depth — tools should also be filtered at listing time) + const policy = this.registry.getPolicy(); + if (policy) { + const allNames = this.registry.list().map(t => t.name); + if (!policy.isAllowed(toolName, allNames, context)) { + return { + success: false, + output: '', + error: `Tool '${toolName}' is not allowed by tool policy`, + }; + } + } + // Check hooks const action = this.hooks.getAction(toolName); if (action === 'confirm') { diff --git a/src/tools/index.ts b/src/tools/index.ts index c22d8da..dbf7eb9 100644 --- a/src/tools/index.ts +++ b/src/tools/index.ts @@ -3,6 +3,8 @@ export { ToolRegistry } from './registry.js'; export type { AnthropicToolDef, OpenAIToolDef } from './registry.js'; export { ToolExecutor } from './executor.js'; export type { ToolExecutorConfig } from './executor.js'; +export { ToolPolicy } from './policy.js'; +export type { ToolPolicyContext } from './policy.js'; export { allBuiltinTools, createWebSearchTools, createProcessTools, ProcessManager } from './builtin/index.js'; export type { WebSearchConfig } from './builtin/web-search.js'; export type { ProcessManagerConfig } from './builtin/process/index.js'; diff --git a/src/tools/policy.test.ts b/src/tools/policy.test.ts new file mode 100644 index 0000000..002a97b --- /dev/null +++ b/src/tools/policy.test.ts @@ -0,0 +1,442 @@ +import { describe, it, expect } from 'vitest'; +import { ToolPolicy, PROFILE_TOOLS, matchesAnyPattern } from './policy.js'; +import type { ToolsConfig } from '../config/schema.js'; +import type { Tool } from './types.js'; + +// ── Helpers ───────────────────────────────────────────────────────── + +/** All tool names that would be in a fully loaded Flynn instance. */ +const ALL_TOOL_NAMES = [ + 'shell.exec', + 'file.read', + 'file.write', + 'file.edit', + 'file.list', + 'web.fetch', + 'web.search', + 'memory.read', + 'memory.write', + 'memory.search', + 'process.start', + 'process.status', + 'process.output', + 'process.kill', + 'process.list', + 'mcp:filesystem:read_file', + 'mcp:filesystem:write_file', +]; + +function makeTool(name: string): Tool { + return { + name, + description: `Mock ${name}`, + inputSchema: { type: 'object', properties: {} }, + execute: async () => ({ success: true, output: '' }), + }; +} + +const ALL_TOOLS = ALL_TOOL_NAMES.map(makeTool); + +function defaultConfig(overrides: Partial = {}): ToolsConfig { + return { + profile: 'full', + allow: [], + deny: [], + agents: {}, + providers: {}, + ...overrides, + }; +} + +// ── matchesAnyPattern ─────────────────────────────────────────────── + +describe('matchesAnyPattern', () => { + it('matches exact names', () => { + expect(matchesAnyPattern('shell.exec', ['shell.exec'])).toBe(true); + expect(matchesAnyPattern('shell.exec', ['file.read'])).toBe(false); + }); + + it('matches wildcard patterns', () => { + expect(matchesAnyPattern('file.read', ['file.*'])).toBe(true); + expect(matchesAnyPattern('file.write', ['file.*'])).toBe(true); + expect(matchesAnyPattern('shell.exec', ['file.*'])).toBe(false); + }); + + it('matches mcp wildcard patterns', () => { + expect(matchesAnyPattern('mcp:filesystem:read_file', ['mcp:*'])).toBe(true); + expect(matchesAnyPattern('mcp:filesystem:read_file', ['mcp:filesystem:*'])).toBe(true); + expect(matchesAnyPattern('shell.exec', ['mcp:*'])).toBe(false); + }); + + it('matches catch-all wildcard', () => { + expect(matchesAnyPattern('anything', ['*'])).toBe(true); + }); +}); + +// ── Profile definitions ───────────────────────────────────────────── + +describe('PROFILE_TOOLS', () => { + it('minimal contains only read-only tools', () => { + expect(PROFILE_TOOLS.minimal.has('file.read')).toBe(true); + expect(PROFILE_TOOLS.minimal.has('file.list')).toBe(true); + expect(PROFILE_TOOLS.minimal.has('web.fetch')).toBe(true); + expect(PROFILE_TOOLS.minimal.has('shell.exec')).toBe(false); + expect(PROFILE_TOOLS.minimal.has('file.write')).toBe(false); + }); + + it('messaging is a superset of minimal', () => { + for (const tool of PROFILE_TOOLS.minimal) { + expect(PROFILE_TOOLS.messaging.has(tool)).toBe(true); + } + expect(PROFILE_TOOLS.messaging.has('memory.read')).toBe(true); + expect(PROFILE_TOOLS.messaging.has('web.search')).toBe(true); + }); + + it('coding is a superset of messaging', () => { + for (const tool of PROFILE_TOOLS.messaging) { + expect(PROFILE_TOOLS.coding.has(tool)).toBe(true); + } + expect(PROFILE_TOOLS.coding.has('shell.exec')).toBe(true); + expect(PROFILE_TOOLS.coding.has('file.write')).toBe(true); + expect(PROFILE_TOOLS.coding.has('process.start')).toBe(true); + }); + + it('full is empty (special: matches everything)', () => { + expect(PROFILE_TOOLS.full.size).toBe(0); + }); +}); + +// ── ToolPolicy ────────────────────────────────────────────────────── + +describe('ToolPolicy', () => { + describe('default config (full profile)', () => { + it('allows all tools when profile is full', () => { + const policy = new ToolPolicy(defaultConfig()); + const result = policy.filterTools(ALL_TOOLS); + expect(result).toHaveLength(ALL_TOOLS.length); + }); + + it('allows all tool names when profile is full', () => { + const policy = new ToolPolicy(defaultConfig()); + const allowed = policy.resolveAllowedNames(ALL_TOOL_NAMES); + expect(allowed.size).toBe(ALL_TOOL_NAMES.length); + }); + }); + + describe('profile filtering', () => { + it('minimal profile only allows read-only tools', () => { + const policy = new ToolPolicy(defaultConfig({ profile: 'minimal' })); + const result = policy.filterTools(ALL_TOOLS); + const names = result.map(t => t.name); + + expect(names).toContain('file.read'); + expect(names).toContain('file.list'); + expect(names).toContain('web.fetch'); + expect(names).not.toContain('shell.exec'); + expect(names).not.toContain('file.write'); + expect(names).not.toContain('memory.read'); + expect(names).not.toContain('mcp:filesystem:read_file'); + }); + + it('messaging profile includes memory and web search', () => { + const policy = new ToolPolicy(defaultConfig({ profile: 'messaging' })); + const result = policy.filterTools(ALL_TOOLS); + const names = result.map(t => t.name); + + expect(names).toContain('memory.read'); + expect(names).toContain('memory.write'); + expect(names).toContain('web.search'); + expect(names).not.toContain('shell.exec'); + expect(names).not.toContain('file.write'); + }); + + it('coding profile includes file writes and shell', () => { + const policy = new ToolPolicy(defaultConfig({ profile: 'coding' })); + const result = policy.filterTools(ALL_TOOLS); + const names = result.map(t => t.name); + + expect(names).toContain('shell.exec'); + expect(names).toContain('file.write'); + expect(names).toContain('file.edit'); + expect(names).toContain('process.start'); + // MCP tools are not in the coding profile by default + expect(names).not.toContain('mcp:filesystem:read_file'); + }); + }); + + describe('global allow list', () => { + it('adds specific tools beyond profile', () => { + const policy = new ToolPolicy(defaultConfig({ + profile: 'minimal', + allow: ['shell.exec'], + })); + const result = policy.filterTools(ALL_TOOLS); + const names = result.map(t => t.name); + + expect(names).toContain('file.read'); + expect(names).toContain('shell.exec'); + expect(names).not.toContain('file.write'); + }); + + it('adds tools matching glob patterns', () => { + const policy = new ToolPolicy(defaultConfig({ + profile: 'minimal', + allow: ['mcp:*'], + })); + const result = policy.filterTools(ALL_TOOLS); + const names = result.map(t => t.name); + + expect(names).toContain('mcp:filesystem:read_file'); + expect(names).toContain('mcp:filesystem:write_file'); + expect(names).not.toContain('shell.exec'); + }); + }); + + describe('global deny list', () => { + it('removes tools from full profile', () => { + const policy = new ToolPolicy(defaultConfig({ + deny: ['shell.exec'], + })); + const result = policy.filterTools(ALL_TOOLS); + const names = result.map(t => t.name); + + expect(names).not.toContain('shell.exec'); + expect(names).toContain('file.read'); + }); + + it('removes tools matching glob patterns', () => { + const policy = new ToolPolicy(defaultConfig({ + deny: ['mcp:*'], + })); + const result = policy.filterTools(ALL_TOOLS); + const names = result.map(t => t.name); + + expect(names).not.toContain('mcp:filesystem:read_file'); + expect(names).not.toContain('mcp:filesystem:write_file'); + expect(names).toContain('shell.exec'); + }); + + it('deny wins over allow', () => { + const policy = new ToolPolicy(defaultConfig({ + profile: 'minimal', + allow: ['shell.exec'], + deny: ['shell.exec'], + })); + const result = policy.filterTools(ALL_TOOLS); + const names = result.map(t => t.name); + + expect(names).not.toContain('shell.exec'); + }); + }); + + describe('agent overrides', () => { + it('restricts tools for a specific agent tier', () => { + const policy = new ToolPolicy(defaultConfig({ + agents: { + fast: { profile: 'minimal', allow: [], deny: [] }, + }, + })); + + // Without agent context, full profile applies + const allResult = policy.filterTools(ALL_TOOLS); + expect(allResult).toHaveLength(ALL_TOOLS.length); + + // With agent context, minimal profile applies + const fastResult = policy.filterTools(ALL_TOOLS, { agent: 'fast' }); + const fastNames = fastResult.map(t => t.name); + expect(fastNames).toContain('file.read'); + expect(fastNames).not.toContain('shell.exec'); + }); + + it('agent deny removes tools from agent set', () => { + const policy = new ToolPolicy(defaultConfig({ + agents: { + fast: { profile: 'minimal', allow: [], deny: ['web.fetch'] }, + }, + })); + + const result = policy.filterTools(ALL_TOOLS, { agent: 'fast' }); + const names = result.map(t => t.name); + expect(names).toContain('file.read'); + expect(names).not.toContain('web.fetch'); + }); + + it('agent allow adds tools beyond agent profile', () => { + const policy = new ToolPolicy(defaultConfig({ + agents: { + complex: { profile: 'coding', allow: ['mcp:*'], deny: [] }, + }, + })); + + const result = policy.filterTools(ALL_TOOLS, { agent: 'complex' }); + const names = result.map(t => t.name); + expect(names).toContain('shell.exec'); + expect(names).toContain('mcp:filesystem:read_file'); + }); + + it('agent override intersects with global — cannot add what global denies', () => { + const policy = new ToolPolicy(defaultConfig({ + deny: ['shell.exec'], + agents: { + complex: { profile: 'coding', allow: ['shell.exec'], deny: [] }, + }, + })); + + const result = policy.filterTools(ALL_TOOLS, { agent: 'complex' }); + const names = result.map(t => t.name); + // Global deny of shell.exec overrides agent allow + expect(names).not.toContain('shell.exec'); + }); + + it('unknown agent tier has no effect', () => { + const policy = new ToolPolicy(defaultConfig()); + const result = policy.filterTools(ALL_TOOLS, { agent: 'nonexistent' }); + expect(result).toHaveLength(ALL_TOOLS.length); + }); + }); + + describe('provider overrides', () => { + it('restricts tools for a specific provider', () => { + const policy = new ToolPolicy(defaultConfig({ + providers: { + ollama: { profile: 'minimal', allow: [], deny: [] }, + }, + })); + + const result = policy.filterTools(ALL_TOOLS, { provider: 'ollama' }); + const names = result.map(t => t.name); + expect(names).toContain('file.read'); + expect(names).not.toContain('shell.exec'); + }); + + it('provider deny takes effect', () => { + const policy = new ToolPolicy(defaultConfig({ + providers: { + ollama: { profile: 'messaging', allow: [], deny: ['web.search'] }, + }, + })); + + const result = policy.filterTools(ALL_TOOLS, { provider: 'ollama' }); + const names = result.map(t => t.name); + expect(names).toContain('memory.read'); + expect(names).not.toContain('web.search'); + }); + }); + + describe('combined agent + provider', () => { + it('intersects agent and provider restrictions', () => { + const policy = new ToolPolicy(defaultConfig({ + agents: { + fast: { profile: 'messaging', allow: [], deny: [] }, + }, + providers: { + ollama: { profile: 'coding', allow: [], deny: [] }, + }, + })); + + // Fast agent has messaging tools, ollama provider has coding tools. + // Intersection = messaging tools (subset of coding). + const result = policy.filterTools(ALL_TOOLS, { agent: 'fast', provider: 'ollama' }); + const names = result.map(t => t.name); + expect(names).toContain('file.read'); + expect(names).toContain('memory.read'); + expect(names).not.toContain('shell.exec'); // in coding but not messaging + }); + }); + + describe('isAllowed', () => { + it('returns true for allowed tools', () => { + const policy = new ToolPolicy(defaultConfig()); + expect(policy.isAllowed('shell.exec', ALL_TOOL_NAMES)).toBe(true); + }); + + it('returns false for denied tools', () => { + const policy = new ToolPolicy(defaultConfig({ deny: ['shell.exec'] })); + expect(policy.isAllowed('shell.exec', ALL_TOOL_NAMES)).toBe(false); + }); + + it('respects context', () => { + const policy = new ToolPolicy(defaultConfig({ + agents: { fast: { profile: 'minimal', allow: [], deny: [] } }, + })); + expect(policy.isAllowed('shell.exec', ALL_TOOL_NAMES, { agent: 'fast' })).toBe(false); + expect(policy.isAllowed('file.read', ALL_TOOL_NAMES, { agent: 'fast' })).toBe(true); + }); + }); + + describe('getEffectiveProfile', () => { + it('returns global profile by default', () => { + const policy = new ToolPolicy(defaultConfig({ profile: 'coding' })); + expect(policy.getEffectiveProfile()).toBe('coding'); + }); + + it('returns agent profile override', () => { + const policy = new ToolPolicy(defaultConfig({ + profile: 'full', + agents: { fast: { profile: 'minimal', allow: [], deny: [] } }, + })); + expect(policy.getEffectiveProfile({ agent: 'fast' })).toBe('minimal'); + }); + + it('returns provider profile override', () => { + const policy = new ToolPolicy(defaultConfig({ + providers: { ollama: { profile: 'messaging', allow: [], deny: [] } }, + })); + expect(policy.getEffectiveProfile({ provider: 'ollama' })).toBe('messaging'); + }); + + it('agent override takes precedence over provider', () => { + const policy = new ToolPolicy(defaultConfig({ + agents: { fast: { profile: 'minimal', allow: [], deny: [] } }, + providers: { ollama: { profile: 'messaging', allow: [], deny: [] } }, + })); + expect(policy.getEffectiveProfile({ agent: 'fast', provider: 'ollama' })).toBe('minimal'); + }); + }); + + describe('backward compatibility', () => { + it('no tools config means full profile (all tools allowed)', () => { + // This simulates the default Zod output when no tools: section in yaml + const policy = new ToolPolicy({ + profile: 'full', + allow: [], + deny: [], + agents: {}, + providers: {}, + }); + const result = policy.filterTools(ALL_TOOLS); + expect(result).toHaveLength(ALL_TOOLS.length); + }); + }); + + describe('edge cases', () => { + it('handles empty tool list', () => { + const policy = new ToolPolicy(defaultConfig()); + const result = policy.filterTools([]); + expect(result).toHaveLength(0); + }); + + it('handles profile with unregistered tools', () => { + // If only some tools from the profile are registered + const fewTools = [makeTool('file.read'), makeTool('web.fetch')]; + const policy = new ToolPolicy(defaultConfig({ profile: 'coding' })); + const result = policy.filterTools(fewTools); + const names = result.map(t => t.name); + expect(names).toContain('file.read'); + expect(names).toContain('web.fetch'); + expect(names).toHaveLength(2); + }); + + it('deny pattern removes multiple tools', () => { + const policy = new ToolPolicy(defaultConfig({ deny: ['process.*'] })); + const result = policy.filterTools(ALL_TOOLS); + const names = result.map(t => t.name); + expect(names).not.toContain('process.start'); + expect(names).not.toContain('process.status'); + expect(names).not.toContain('process.output'); + expect(names).not.toContain('process.kill'); + expect(names).not.toContain('process.list'); + expect(names).toContain('shell.exec'); + }); + }); +}); diff --git a/src/tools/policy.ts b/src/tools/policy.ts new file mode 100644 index 0000000..c53e31c --- /dev/null +++ b/src/tools/policy.ts @@ -0,0 +1,229 @@ +import type { ToolsConfig, ToolProfile } from '../config/schema.js'; +import type { Tool } from './types.js'; + +// ── Profile definitions ───────────────────────────────────────────── + +/** Built-in tool name sets for each profile level. Profiles are cumulative. */ +const PROFILE_TOOLS: Record> = { + minimal: new Set([ + 'file.read', + 'file.list', + 'web.fetch', + ]), + messaging: new Set([ + 'file.read', + 'file.list', + 'web.fetch', + 'memory.read', + 'memory.write', + 'memory.search', + 'web.search', + ]), + coding: new Set([ + 'file.read', + 'file.list', + 'web.fetch', + 'memory.read', + 'memory.write', + 'memory.search', + 'web.search', + 'file.write', + 'file.edit', + 'shell.exec', + 'process.start', + 'process.status', + 'process.output', + 'process.kill', + 'process.list', + ]), + full: new Set(), // Special: matches everything +}; + +// ── Glob matching ─────────────────────────────────────────────────── + +/** + * Convert a simple glob pattern to a regex. + * Supports `*` (matches any characters) and `.` is escaped. + * Same algorithm as HookEngine.patternToRegex for consistency. + */ +function patternToRegex(pattern: string): RegExp { + const escaped = pattern + .replace(/[.+^${}()|[\]\\]/g, '\\$&') + .replace(/\*/g, '.*'); + return new RegExp(`^${escaped}$`); +} + +function matchesAnyPattern(toolName: string, patterns: string[]): boolean { + return patterns.some(p => patternToRegex(p).test(toolName)); +} + +// ── Policy context ────────────────────────────────────────────────── + +/** Identifies the runtime context for tool policy resolution. */ +export interface ToolPolicyContext { + /** Model tier name (e.g. 'fast', 'default', 'complex', 'local'). */ + agent?: string; + /** Provider name (e.g. 'ollama', 'anthropic'). */ + provider?: string; +} + +// ── ToolPolicy engine ─────────────────────────────────────────────── + +/** + * Resolves which tools are permitted for a given runtime context. + * + * Resolution order: + * 1. Start with profile's tool set (or all tools for 'full') + * 2. Apply global allow list (adds tools back in) + * 3. Apply global deny list (removes tools) + * 4. If agent/provider overrides exist, compute their resolved sets + * and intersect with the global set + * 5. Deny always wins over allow at every level + */ +export class ToolPolicy { + private config: ToolsConfig; + + constructor(config: ToolsConfig) { + this.config = config; + } + + /** + * Return the list of tools permitted for the given context. + * This is the primary API — filters an array of Tool objects. + */ + filterTools(tools: Tool[], context?: ToolPolicyContext): Tool[] { + const allowed = this.resolveAllowedNames( + tools.map(t => t.name), + context, + ); + return tools.filter(t => allowed.has(t.name)); + } + + /** + * Check whether a specific tool name is permitted in the given context. + * Used for runtime enforcement in the executor (defense in depth). + */ + isAllowed(toolName: string, allToolNames: string[], context?: ToolPolicyContext): boolean { + const allowed = this.resolveAllowedNames(allToolNames, context); + return allowed.has(toolName); + } + + /** + * Resolve the full set of allowed tool names given an array of all + * registered tool names and an optional context. + */ + resolveAllowedNames(allToolNames: string[], context?: ToolPolicyContext): Set { + // Step 1: Start from global profile + let allowed = this.applyProfile(this.config.profile, allToolNames); + + // Step 2: Apply global allow (adds tools) + if (this.config.allow.length > 0) { + for (const name of allToolNames) { + if (matchesAnyPattern(name, this.config.allow)) { + allowed.add(name); + } + } + } + + // Step 3: Apply global deny (removes tools) + if (this.config.deny.length > 0) { + allowed = new Set( + [...allowed].filter(name => !matchesAnyPattern(name, this.config.deny)), + ); + } + + // Step 4: Apply agent override if present + if (context?.agent && this.config.agents[context.agent]) { + const agentOverride = this.config.agents[context.agent]; + const agentAllowed = this.resolveOverride(agentOverride, allToolNames); + allowed = intersect(allowed, agentAllowed); + } + + // Step 5: Apply provider override if present + if (context?.provider && this.config.providers[context.provider]) { + const providerOverride = this.config.providers[context.provider]; + const providerAllowed = this.resolveOverride(providerOverride, allToolNames); + allowed = intersect(allowed, providerAllowed); + } + + return allowed; + } + + /** + * Get the effective profile for a given context. + * Used for informational/debugging purposes. + */ + getEffectiveProfile(context?: ToolPolicyContext): ToolProfile { + // Check agent override first, then provider, then global + if (context?.agent && this.config.agents[context.agent]?.profile) { + return this.config.agents[context.agent].profile!; + } + if (context?.provider && this.config.providers[context.provider]?.profile) { + return this.config.providers[context.provider].profile!; + } + return this.config.profile; + } + + // ── Private helpers ───────────────────────────────────────────────── + + /** + * Resolve the tool set for a profile. + * 'full' means all tools; other profiles return their defined set, + * filtered to only include actually-registered tools. + */ + private applyProfile(profile: ToolProfile, allToolNames: string[]): Set { + if (profile === 'full') { + return new Set(allToolNames); + } + const profileSet = PROFILE_TOOLS[profile]; + return new Set(allToolNames.filter(name => profileSet.has(name))); + } + + /** + * Resolve an override block (agent or provider) to a set of allowed names. + * An override inherits from the global profile if it doesn't specify its own. + */ + private resolveOverride( + override: { profile?: ToolProfile; allow: string[]; deny: string[] }, + allToolNames: string[], + ): Set { + // Start from the override's profile, or inherit global + const baseProfile = override.profile ?? this.config.profile; + let allowed = this.applyProfile(baseProfile, allToolNames); + + // Apply override allow + if (override.allow.length > 0) { + for (const name of allToolNames) { + if (matchesAnyPattern(name, override.allow)) { + allowed.add(name); + } + } + } + + // Apply override deny (deny always wins) + if (override.deny.length > 0) { + allowed = new Set( + [...allowed].filter(name => !matchesAnyPattern(name, override.deny)), + ); + } + + return allowed; + } +} + +// ── Utility ───────────────────────────────────────────────────────── + +function intersect(a: Set, b: Set): Set { + const result = new Set(); + for (const item of a) { + if (b.has(item)) { + result.add(item); + } + } + return result; +} + +/** + * Exported for testing and for use in HookEngine (DRY). + */ +export { patternToRegex, matchesAnyPattern, PROFILE_TOOLS }; diff --git a/src/tools/registry.ts b/src/tools/registry.ts index 541ae03..71f40b9 100644 --- a/src/tools/registry.ts +++ b/src/tools/registry.ts @@ -1,4 +1,5 @@ import type { Tool, ToolInputSchema } from './types.js'; +import type { ToolPolicy, ToolPolicyContext } from './policy.js'; export interface AnthropicToolDef { name: string; @@ -17,6 +18,7 @@ export interface OpenAIToolDef { export class ToolRegistry { private tools: Map = new Map(); + private _policy?: ToolPolicy; register(tool: Tool): void { if (this.tools.has(tool.name)) { @@ -37,6 +39,22 @@ export class ToolRegistry { return Array.from(this.tools.values()); } + /** Set the tool policy for filtering. */ + setPolicy(policy: ToolPolicy): void { + this._policy = policy; + } + + /** Get the tool policy (if set). */ + getPolicy(): ToolPolicy | undefined { + return this._policy; + } + + /** Return tools filtered by the policy for a given context. */ + filteredList(context?: ToolPolicyContext): Tool[] { + if (!this._policy) return this.list(); + return this._policy.filterTools(this.list(), context); + } + toAnthropicFormat(): AnthropicToolDef[] { return this.list().map(t => ({ name: t.name, @@ -45,6 +63,15 @@ export class ToolRegistry { })); } + /** Return Anthropic-format tools filtered by policy. */ + filteredToAnthropicFormat(context?: ToolPolicyContext): AnthropicToolDef[] { + return this.filteredList(context).map(t => ({ + name: t.name, + description: t.description, + input_schema: t.inputSchema, + })); + } + toOpenAIFormat(): OpenAIToolDef[] { return this.list().map(t => ({ type: 'function' as const, @@ -55,4 +82,16 @@ export class ToolRegistry { }, })); } + + /** Return OpenAI-format tools filtered by policy. */ + filteredToOpenAIFormat(context?: ToolPolicyContext): OpenAIToolDef[] { + return this.filteredList(context).map(t => ({ + type: 'function' as const, + function: { + name: t.name, + description: t.description, + parameters: t.inputSchema, + }, + })); + } }