feat(agent): add iterative tool use loop with max iterations
Rewrites NativeAgent.process() from single-turn to an iterative tool loop. When toolRegistry and toolExecutor are provided, the agent calls the model, executes any requested tool calls, feeds results back, and loops until the model returns a text response or max iterations hit. - Backward compatible: works exactly as before without tools - Supports onToolUse callback for frontend status display - Max iterations (default 10) prevents infinite loops - Handles multiple tool calls per model response - 5 new tests (8 total)
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { NativeAgent } from './agent.js';
|
||||
import type { ModelClient, ChatResponse } from '../../models/types.js';
|
||||
import { ToolRegistry, ToolExecutor } from '../../tools/index.js';
|
||||
import { HookEngine } from '../../hooks/index.js';
|
||||
import type { Tool, ToolResult } from '../../tools/index.js';
|
||||
|
||||
describe('NativeAgent', () => {
|
||||
const createMockClient = (): ModelClient => ({
|
||||
@@ -67,3 +70,191 @@ describe('NativeAgent', () => {
|
||||
expect(mockSession.addMessage).toHaveBeenNthCalledWith(2, { role: 'assistant', content: 'Hello!' });
|
||||
});
|
||||
});
|
||||
|
||||
// Simple test tool
|
||||
const echoTool: Tool = {
|
||||
name: 'test.echo',
|
||||
description: 'Echo',
|
||||
inputSchema: { type: 'object', properties: { text: { type: 'string' } }, required: ['text'] },
|
||||
execute: async (args) => ({ success: true, output: (args as { text: string }).text }),
|
||||
};
|
||||
|
||||
describe('NativeAgent tool loop', () => {
|
||||
it('executes tool calls and feeds results back', async () => {
|
||||
let callCount = 0;
|
||||
const mockClient: ModelClient = {
|
||||
chat: vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
// First call: model requests tool use
|
||||
return {
|
||||
content: '',
|
||||
stopReason: 'tool_use',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
toolCalls: [{ id: 'call_1', name: 'test.echo', args: { text: 'hello' } }],
|
||||
};
|
||||
}
|
||||
// Second call: model gives final text response
|
||||
return {
|
||||
content: 'The tool returned: hello',
|
||||
stopReason: 'end_turn',
|
||||
usage: { inputTokens: 15, outputTokens: 10 },
|
||||
};
|
||||
}),
|
||||
};
|
||||
|
||||
const registry = new ToolRegistry();
|
||||
registry.register(echoTool);
|
||||
const hooks = new HookEngine({ confirm: [], log: [], silent: [] });
|
||||
const executor = new ToolExecutor(registry, hooks);
|
||||
|
||||
const agent = new NativeAgent({
|
||||
modelClient: mockClient,
|
||||
systemPrompt: 'You are helpful.',
|
||||
toolRegistry: registry,
|
||||
toolExecutor: executor,
|
||||
});
|
||||
|
||||
const response = await agent.process('echo hello');
|
||||
expect(response).toBe('The tool returned: hello');
|
||||
expect(mockClient.chat).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('respects max iterations', async () => {
|
||||
// Model always returns tool_use
|
||||
const mockClient: ModelClient = {
|
||||
chat: vi.fn().mockResolvedValue({
|
||||
content: '',
|
||||
stopReason: 'tool_use',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
toolCalls: [{ id: 'call_1', name: 'test.echo', args: { text: 'loop' } }],
|
||||
}),
|
||||
};
|
||||
|
||||
const registry = new ToolRegistry();
|
||||
registry.register(echoTool);
|
||||
const hooks = new HookEngine({ confirm: [], log: [], silent: [] });
|
||||
const executor = new ToolExecutor(registry, hooks);
|
||||
|
||||
const agent = new NativeAgent({
|
||||
modelClient: mockClient,
|
||||
systemPrompt: 'You are helpful.',
|
||||
toolRegistry: registry,
|
||||
toolExecutor: executor,
|
||||
maxIterations: 3,
|
||||
});
|
||||
|
||||
const response = await agent.process('loop forever');
|
||||
expect(response).toContain('max iterations');
|
||||
expect(mockClient.chat).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it('works without tools (backward compatible)', async () => {
|
||||
const mockClient: ModelClient = {
|
||||
chat: vi.fn().mockResolvedValue({
|
||||
content: 'Hello!',
|
||||
stopReason: 'end_turn',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
}),
|
||||
};
|
||||
|
||||
const agent = new NativeAgent({
|
||||
modelClient: mockClient,
|
||||
systemPrompt: 'You are helpful.',
|
||||
});
|
||||
|
||||
const response = await agent.process('Hi');
|
||||
expect(response).toBe('Hello!');
|
||||
});
|
||||
|
||||
it('calls onToolUse callback on start and end', async () => {
|
||||
let callCount = 0;
|
||||
const mockClient: ModelClient = {
|
||||
chat: vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return {
|
||||
content: '',
|
||||
stopReason: 'tool_use',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
toolCalls: [{ id: 'call_1', name: 'test.echo', args: { text: 'hi' } }],
|
||||
};
|
||||
}
|
||||
return {
|
||||
content: 'Done',
|
||||
stopReason: 'end_turn',
|
||||
usage: { inputTokens: 15, outputTokens: 10 },
|
||||
};
|
||||
}),
|
||||
};
|
||||
|
||||
const registry = new ToolRegistry();
|
||||
registry.register(echoTool);
|
||||
const hooks = new HookEngine({ confirm: [], log: [], silent: [] });
|
||||
const executor = new ToolExecutor(registry, hooks);
|
||||
const onToolUse = vi.fn();
|
||||
|
||||
const agent = new NativeAgent({
|
||||
modelClient: mockClient,
|
||||
systemPrompt: 'You are helpful.',
|
||||
toolRegistry: registry,
|
||||
toolExecutor: executor,
|
||||
onToolUse,
|
||||
});
|
||||
|
||||
await agent.process('echo hi');
|
||||
|
||||
expect(onToolUse).toHaveBeenCalledTimes(2);
|
||||
expect(onToolUse).toHaveBeenNthCalledWith(1, expect.objectContaining({
|
||||
type: 'start',
|
||||
tool: 'test.echo',
|
||||
args: { text: 'hi' },
|
||||
}));
|
||||
expect(onToolUse).toHaveBeenNthCalledWith(2, expect.objectContaining({
|
||||
type: 'end',
|
||||
tool: 'test.echo',
|
||||
result: expect.objectContaining({ success: true, output: 'hi' }),
|
||||
}));
|
||||
});
|
||||
|
||||
it('handles multiple tool calls in single response', async () => {
|
||||
let callCount = 0;
|
||||
const mockClient: ModelClient = {
|
||||
chat: vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return {
|
||||
content: '',
|
||||
stopReason: 'tool_use',
|
||||
usage: { inputTokens: 10, outputTokens: 5 },
|
||||
toolCalls: [
|
||||
{ id: 'call_1', name: 'test.echo', args: { text: 'first' } },
|
||||
{ id: 'call_2', name: 'test.echo', args: { text: 'second' } },
|
||||
],
|
||||
};
|
||||
}
|
||||
return {
|
||||
content: 'Got both results',
|
||||
stopReason: 'end_turn',
|
||||
usage: { inputTokens: 15, outputTokens: 10 },
|
||||
};
|
||||
}),
|
||||
};
|
||||
|
||||
const registry = new ToolRegistry();
|
||||
registry.register(echoTool);
|
||||
const hooks = new HookEngine({ confirm: [], log: [], silent: [] });
|
||||
const executor = new ToolExecutor(registry, hooks);
|
||||
|
||||
const agent = new NativeAgent({
|
||||
modelClient: mockClient,
|
||||
systemPrompt: 'You are helpful.',
|
||||
toolRegistry: registry,
|
||||
toolExecutor: executor,
|
||||
});
|
||||
|
||||
const response = await agent.process('echo both');
|
||||
expect(response).toBe('Got both results');
|
||||
expect(mockClient.chat).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
+128
-12
@@ -1,11 +1,32 @@
|
||||
import type { ModelClient, Message } from '../../models/types.js';
|
||||
import type { ModelClient, Message, ChatRequest, ChatResponse, ModelToolCall } from '../../models/types.js';
|
||||
import type { ModelRouter, ModelTier } from '../../models/router.js';
|
||||
import type { Session } from '../../session/index.js';
|
||||
import type { ToolRegistry } from '../../tools/registry.js';
|
||||
import type { ToolExecutor } from '../../tools/executor.js';
|
||||
import type { ToolResult } from '../../tools/types.js';
|
||||
|
||||
export interface ToolUseEvent {
|
||||
type: 'start' | 'end';
|
||||
tool: string;
|
||||
args?: unknown;
|
||||
result?: ToolResult;
|
||||
}
|
||||
|
||||
export interface NativeAgentConfig {
|
||||
modelClient: ModelClient | ModelRouter;
|
||||
systemPrompt: string;
|
||||
session?: Session;
|
||||
toolRegistry?: ToolRegistry;
|
||||
toolExecutor?: ToolExecutor;
|
||||
maxIterations?: number;
|
||||
onToolUse?: (event: ToolUseEvent) => void;
|
||||
}
|
||||
|
||||
// Internal message type for the tool loop — supports both text and structured content blocks.
|
||||
// This is broader than Message to accommodate Anthropic's tool_use/tool_result block format.
|
||||
interface LoopMessage {
|
||||
role: 'user' | 'assistant';
|
||||
content: string | unknown[];
|
||||
}
|
||||
|
||||
export class NativeAgent {
|
||||
@@ -14,11 +35,19 @@ export class NativeAgent {
|
||||
private session?: Session;
|
||||
private inMemoryHistory: Message[] = [];
|
||||
private currentTier: ModelTier = 'default';
|
||||
private toolRegistry?: ToolRegistry;
|
||||
private toolExecutor?: ToolExecutor;
|
||||
private maxIterations: number;
|
||||
private onToolUse?: (event: ToolUseEvent) => void;
|
||||
|
||||
constructor(config: NativeAgentConfig) {
|
||||
this.modelClient = config.modelClient;
|
||||
this.systemPrompt = config.systemPrompt;
|
||||
this.session = config.session;
|
||||
this.toolRegistry = config.toolRegistry;
|
||||
this.toolExecutor = config.toolExecutor;
|
||||
this.maxIterations = config.maxIterations ?? 10;
|
||||
this.onToolUse = config.onToolUse;
|
||||
}
|
||||
|
||||
private get history(): Message[] {
|
||||
@@ -34,27 +63,114 @@ export class NativeAgent {
|
||||
this.inMemoryHistory.push(userMsg);
|
||||
}
|
||||
|
||||
const request = {
|
||||
// If no tools configured, use the simple single-turn path
|
||||
if (!this.toolRegistry || !this.toolExecutor) {
|
||||
return this.singleTurn();
|
||||
}
|
||||
|
||||
return this.toolLoop();
|
||||
}
|
||||
|
||||
private async singleTurn(): Promise<string> {
|
||||
const request: ChatRequest = {
|
||||
messages: this.history,
|
||||
system: this.systemPrompt,
|
||||
};
|
||||
|
||||
// Use tier if modelClient is a ModelRouter
|
||||
const response = 'getClient' in this.modelClient
|
||||
? await (this.modelClient as ModelRouter).chat(request, this.currentTier)
|
||||
: await this.modelClient.chat(request);
|
||||
const response = await this.chatWithRouter(request);
|
||||
|
||||
const assistantMsg: Message = { role: 'assistant', content: response.content };
|
||||
|
||||
if (this.session) {
|
||||
this.session.addMessage(assistantMsg);
|
||||
} else {
|
||||
this.inMemoryHistory.push(assistantMsg);
|
||||
}
|
||||
this.addToHistory(assistantMsg);
|
||||
|
||||
return response.content;
|
||||
}
|
||||
|
||||
private async toolLoop(): Promise<string> {
|
||||
const tools = this.toolRegistry!.toAnthropicFormat();
|
||||
|
||||
// Build the loop messages from existing history.
|
||||
// These are the messages sent to the model, including any structured tool blocks.
|
||||
const loopMessages: LoopMessage[] = this.history.map(m => ({
|
||||
role: m.role,
|
||||
content: m.content,
|
||||
}));
|
||||
|
||||
for (let iteration = 0; iteration < this.maxIterations; iteration++) {
|
||||
// Build request — cast loopMessages to Message[] because the underlying
|
||||
// model client will pass them through to the API which accepts structured content.
|
||||
const request = {
|
||||
messages: loopMessages as unknown as Message[],
|
||||
system: this.systemPrompt,
|
||||
tools,
|
||||
};
|
||||
|
||||
const response = await this.chatWithRouter(request);
|
||||
|
||||
// If the model didn't request tool use, we're done
|
||||
if (response.stopReason !== 'tool_use' || !response.toolCalls?.length) {
|
||||
const assistantMsg: Message = { role: 'assistant', content: response.content };
|
||||
this.addToHistory(assistantMsg);
|
||||
return response.content;
|
||||
}
|
||||
|
||||
// Build the assistant message with tool_use content blocks
|
||||
const assistantContent: unknown[] = [];
|
||||
if (response.content) {
|
||||
assistantContent.push({ type: 'text', text: response.content });
|
||||
}
|
||||
for (const tc of response.toolCalls) {
|
||||
assistantContent.push({
|
||||
type: 'tool_use',
|
||||
id: tc.id,
|
||||
name: tc.name,
|
||||
input: tc.args,
|
||||
});
|
||||
}
|
||||
loopMessages.push({ role: 'assistant', content: assistantContent });
|
||||
|
||||
// Execute each tool call and collect results
|
||||
const toolResultBlocks: unknown[] = [];
|
||||
for (const tc of response.toolCalls) {
|
||||
this.onToolUse?.({ type: 'start', tool: tc.name, args: tc.args });
|
||||
|
||||
const result = await this.toolExecutor!.execute(tc.name, tc.args);
|
||||
|
||||
this.onToolUse?.({ type: 'end', tool: tc.name, result });
|
||||
|
||||
toolResultBlocks.push({
|
||||
type: 'tool_result',
|
||||
tool_use_id: tc.id,
|
||||
content: result.success ? result.output : (result.error ?? 'Unknown error'),
|
||||
is_error: !result.success,
|
||||
});
|
||||
}
|
||||
|
||||
// Add tool results as a user message
|
||||
loopMessages.push({ role: 'user', content: toolResultBlocks });
|
||||
}
|
||||
|
||||
// Max iterations reached
|
||||
const warningMsg = `Stopped after reaching max iterations (${this.maxIterations}). The task may be incomplete.`;
|
||||
const assistantMsg: Message = { role: 'assistant', content: warningMsg };
|
||||
this.addToHistory(assistantMsg);
|
||||
return warningMsg;
|
||||
}
|
||||
|
||||
private async chatWithRouter(request: ChatRequest): Promise<ChatResponse> {
|
||||
if ('getClient' in this.modelClient) {
|
||||
return (this.modelClient as ModelRouter).chat(request, this.currentTier);
|
||||
}
|
||||
return this.modelClient.chat(request);
|
||||
}
|
||||
|
||||
private addToHistory(msg: Message): void {
|
||||
if (this.session) {
|
||||
this.session.addMessage(msg);
|
||||
} else {
|
||||
this.inMemoryHistory.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
reset(): void {
|
||||
if (this.session) {
|
||||
this.session.clear();
|
||||
|
||||
Reference in New Issue
Block a user