diff --git a/src/backends/native/agent.test.ts b/src/backends/native/agent.test.ts index d17da0b..42d5b78 100644 --- a/src/backends/native/agent.test.ts +++ b/src/backends/native/agent.test.ts @@ -1,6 +1,9 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'; import { NativeAgent } from './agent.js'; import type { ModelClient, ChatResponse } from '../../models/types.js'; +import { ToolRegistry, ToolExecutor } from '../../tools/index.js'; +import { HookEngine } from '../../hooks/index.js'; +import type { Tool, ToolResult } from '../../tools/index.js'; describe('NativeAgent', () => { const createMockClient = (): ModelClient => ({ @@ -67,3 +70,191 @@ describe('NativeAgent', () => { expect(mockSession.addMessage).toHaveBeenNthCalledWith(2, { role: 'assistant', content: 'Hello!' }); }); }); + +// Simple test tool +const echoTool: Tool = { + name: 'test.echo', + description: 'Echo', + inputSchema: { type: 'object', properties: { text: { type: 'string' } }, required: ['text'] }, + execute: async (args) => ({ success: true, output: (args as { text: string }).text }), +}; + +describe('NativeAgent tool loop', () => { + it('executes tool calls and feeds results back', async () => { + let callCount = 0; + const mockClient: ModelClient = { + chat: vi.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + // First call: model requests tool use + return { + content: '', + stopReason: 'tool_use', + usage: { inputTokens: 10, outputTokens: 5 }, + toolCalls: [{ id: 'call_1', name: 'test.echo', args: { text: 'hello' } }], + }; + } + // Second call: model gives final text response + return { + content: 'The tool returned: hello', + stopReason: 'end_turn', + usage: { inputTokens: 15, outputTokens: 10 }, + }; + }), + }; + + const registry = new ToolRegistry(); + registry.register(echoTool); + const hooks = new HookEngine({ confirm: [], log: [], silent: [] }); + const executor = new ToolExecutor(registry, hooks); + + const agent = new NativeAgent({ + modelClient: mockClient, + systemPrompt: 'You are helpful.', + toolRegistry: registry, + toolExecutor: executor, + }); + + const response = await agent.process('echo hello'); + expect(response).toBe('The tool returned: hello'); + expect(mockClient.chat).toHaveBeenCalledTimes(2); + }); + + it('respects max iterations', async () => { + // Model always returns tool_use + const mockClient: ModelClient = { + chat: vi.fn().mockResolvedValue({ + content: '', + stopReason: 'tool_use', + usage: { inputTokens: 10, outputTokens: 5 }, + toolCalls: [{ id: 'call_1', name: 'test.echo', args: { text: 'loop' } }], + }), + }; + + const registry = new ToolRegistry(); + registry.register(echoTool); + const hooks = new HookEngine({ confirm: [], log: [], silent: [] }); + const executor = new ToolExecutor(registry, hooks); + + const agent = new NativeAgent({ + modelClient: mockClient, + systemPrompt: 'You are helpful.', + toolRegistry: registry, + toolExecutor: executor, + maxIterations: 3, + }); + + const response = await agent.process('loop forever'); + expect(response).toContain('max iterations'); + expect(mockClient.chat).toHaveBeenCalledTimes(3); + }); + + it('works without tools (backward compatible)', async () => { + const mockClient: ModelClient = { + chat: vi.fn().mockResolvedValue({ + content: 'Hello!', + stopReason: 'end_turn', + usage: { inputTokens: 10, outputTokens: 5 }, + }), + }; + + const agent = new NativeAgent({ + modelClient: mockClient, + systemPrompt: 'You are helpful.', + }); + + const response = await agent.process('Hi'); + expect(response).toBe('Hello!'); + }); + + it('calls onToolUse callback on start and end', async () => { + let callCount = 0; + const mockClient: ModelClient = { + chat: vi.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + return { + content: '', + stopReason: 'tool_use', + usage: { inputTokens: 10, outputTokens: 5 }, + toolCalls: [{ id: 'call_1', name: 'test.echo', args: { text: 'hi' } }], + }; + } + return { + content: 'Done', + stopReason: 'end_turn', + usage: { inputTokens: 15, outputTokens: 10 }, + }; + }), + }; + + const registry = new ToolRegistry(); + registry.register(echoTool); + const hooks = new HookEngine({ confirm: [], log: [], silent: [] }); + const executor = new ToolExecutor(registry, hooks); + const onToolUse = vi.fn(); + + const agent = new NativeAgent({ + modelClient: mockClient, + systemPrompt: 'You are helpful.', + toolRegistry: registry, + toolExecutor: executor, + onToolUse, + }); + + await agent.process('echo hi'); + + expect(onToolUse).toHaveBeenCalledTimes(2); + expect(onToolUse).toHaveBeenNthCalledWith(1, expect.objectContaining({ + type: 'start', + tool: 'test.echo', + args: { text: 'hi' }, + })); + expect(onToolUse).toHaveBeenNthCalledWith(2, expect.objectContaining({ + type: 'end', + tool: 'test.echo', + result: expect.objectContaining({ success: true, output: 'hi' }), + })); + }); + + it('handles multiple tool calls in single response', async () => { + let callCount = 0; + const mockClient: ModelClient = { + chat: vi.fn().mockImplementation(() => { + callCount++; + if (callCount === 1) { + return { + content: '', + stopReason: 'tool_use', + usage: { inputTokens: 10, outputTokens: 5 }, + toolCalls: [ + { id: 'call_1', name: 'test.echo', args: { text: 'first' } }, + { id: 'call_2', name: 'test.echo', args: { text: 'second' } }, + ], + }; + } + return { + content: 'Got both results', + stopReason: 'end_turn', + usage: { inputTokens: 15, outputTokens: 10 }, + }; + }), + }; + + const registry = new ToolRegistry(); + registry.register(echoTool); + const hooks = new HookEngine({ confirm: [], log: [], silent: [] }); + const executor = new ToolExecutor(registry, hooks); + + const agent = new NativeAgent({ + modelClient: mockClient, + systemPrompt: 'You are helpful.', + toolRegistry: registry, + toolExecutor: executor, + }); + + const response = await agent.process('echo both'); + expect(response).toBe('Got both results'); + expect(mockClient.chat).toHaveBeenCalledTimes(2); + }); +}); diff --git a/src/backends/native/agent.ts b/src/backends/native/agent.ts index cf42654..c65e065 100644 --- a/src/backends/native/agent.ts +++ b/src/backends/native/agent.ts @@ -1,11 +1,32 @@ -import type { ModelClient, Message } from '../../models/types.js'; +import type { ModelClient, Message, ChatRequest, ChatResponse, ModelToolCall } from '../../models/types.js'; import type { ModelRouter, ModelTier } from '../../models/router.js'; import type { Session } from '../../session/index.js'; +import type { ToolRegistry } from '../../tools/registry.js'; +import type { ToolExecutor } from '../../tools/executor.js'; +import type { ToolResult } from '../../tools/types.js'; + +export interface ToolUseEvent { + type: 'start' | 'end'; + tool: string; + args?: unknown; + result?: ToolResult; +} export interface NativeAgentConfig { modelClient: ModelClient | ModelRouter; systemPrompt: string; session?: Session; + toolRegistry?: ToolRegistry; + toolExecutor?: ToolExecutor; + maxIterations?: number; + onToolUse?: (event: ToolUseEvent) => void; +} + +// Internal message type for the tool loop — supports both text and structured content blocks. +// This is broader than Message to accommodate Anthropic's tool_use/tool_result block format. +interface LoopMessage { + role: 'user' | 'assistant'; + content: string | unknown[]; } export class NativeAgent { @@ -14,11 +35,19 @@ export class NativeAgent { private session?: Session; private inMemoryHistory: Message[] = []; private currentTier: ModelTier = 'default'; + private toolRegistry?: ToolRegistry; + private toolExecutor?: ToolExecutor; + private maxIterations: number; + private onToolUse?: (event: ToolUseEvent) => void; constructor(config: NativeAgentConfig) { this.modelClient = config.modelClient; this.systemPrompt = config.systemPrompt; this.session = config.session; + this.toolRegistry = config.toolRegistry; + this.toolExecutor = config.toolExecutor; + this.maxIterations = config.maxIterations ?? 10; + this.onToolUse = config.onToolUse; } private get history(): Message[] { @@ -34,27 +63,114 @@ export class NativeAgent { this.inMemoryHistory.push(userMsg); } - const request = { + // If no tools configured, use the simple single-turn path + if (!this.toolRegistry || !this.toolExecutor) { + return this.singleTurn(); + } + + return this.toolLoop(); + } + + private async singleTurn(): Promise { + const request: ChatRequest = { messages: this.history, system: this.systemPrompt, }; - // Use tier if modelClient is a ModelRouter - const response = 'getClient' in this.modelClient - ? await (this.modelClient as ModelRouter).chat(request, this.currentTier) - : await this.modelClient.chat(request); + const response = await this.chatWithRouter(request); const assistantMsg: Message = { role: 'assistant', content: response.content }; - - if (this.session) { - this.session.addMessage(assistantMsg); - } else { - this.inMemoryHistory.push(assistantMsg); - } + this.addToHistory(assistantMsg); return response.content; } + private async toolLoop(): Promise { + const tools = this.toolRegistry!.toAnthropicFormat(); + + // Build the loop messages from existing history. + // These are the messages sent to the model, including any structured tool blocks. + const loopMessages: LoopMessage[] = this.history.map(m => ({ + role: m.role, + content: m.content, + })); + + for (let iteration = 0; iteration < this.maxIterations; iteration++) { + // Build request — cast loopMessages to Message[] because the underlying + // model client will pass them through to the API which accepts structured content. + const request = { + messages: loopMessages as unknown as Message[], + system: this.systemPrompt, + tools, + }; + + const response = await this.chatWithRouter(request); + + // If the model didn't request tool use, we're done + if (response.stopReason !== 'tool_use' || !response.toolCalls?.length) { + const assistantMsg: Message = { role: 'assistant', content: response.content }; + this.addToHistory(assistantMsg); + return response.content; + } + + // Build the assistant message with tool_use content blocks + const assistantContent: unknown[] = []; + if (response.content) { + assistantContent.push({ type: 'text', text: response.content }); + } + for (const tc of response.toolCalls) { + assistantContent.push({ + type: 'tool_use', + id: tc.id, + name: tc.name, + input: tc.args, + }); + } + loopMessages.push({ role: 'assistant', content: assistantContent }); + + // Execute each tool call and collect results + const toolResultBlocks: unknown[] = []; + for (const tc of response.toolCalls) { + this.onToolUse?.({ type: 'start', tool: tc.name, args: tc.args }); + + const result = await this.toolExecutor!.execute(tc.name, tc.args); + + this.onToolUse?.({ type: 'end', tool: tc.name, result }); + + toolResultBlocks.push({ + type: 'tool_result', + tool_use_id: tc.id, + content: result.success ? result.output : (result.error ?? 'Unknown error'), + is_error: !result.success, + }); + } + + // Add tool results as a user message + loopMessages.push({ role: 'user', content: toolResultBlocks }); + } + + // Max iterations reached + const warningMsg = `Stopped after reaching max iterations (${this.maxIterations}). The task may be incomplete.`; + const assistantMsg: Message = { role: 'assistant', content: warningMsg }; + this.addToHistory(assistantMsg); + return warningMsg; + } + + private async chatWithRouter(request: ChatRequest): Promise { + if ('getClient' in this.modelClient) { + return (this.modelClient as ModelRouter).chat(request, this.currentTier); + } + return this.modelClient.chat(request); + } + + private addToHistory(msg: Message): void { + if (this.session) { + this.session.addMessage(msg); + } else { + this.inMemoryHistory.push(msg); + } + } + reset(): void { if (this.session) { this.session.clear();