Improve in-flight cancel latency via run abort signal propagation
This commit is contained in:
@@ -24,10 +24,11 @@ describe('NativeAgent', () => {
|
||||
const response = await agent.process('Hi');
|
||||
|
||||
expect(response).toBe('Hello!');
|
||||
expect(mockClient.chat).toHaveBeenCalledWith({
|
||||
expect(mockClient.chat).toHaveBeenCalledWith(expect.objectContaining({
|
||||
messages: [{ role: 'user', content: 'Hi' }],
|
||||
system: 'You are helpful.',
|
||||
});
|
||||
signal: expect.any(AbortSignal),
|
||||
}));
|
||||
|
||||
const history = agent.getHistory();
|
||||
expect(history).toHaveLength(2);
|
||||
|
||||
@@ -83,6 +83,7 @@ export class NativeAgent {
|
||||
private _lastToolFingerprint?: string;
|
||||
private _cancelRequested = false;
|
||||
private _runInProgress = false;
|
||||
private _runAbortController?: AbortController;
|
||||
private modelTimeoutMs: number;
|
||||
|
||||
constructor(config: NativeAgentConfig) {
|
||||
@@ -106,6 +107,7 @@ export class NativeAgent {
|
||||
|
||||
async process(userMessage: string, attachments?: Attachment[]): Promise<string> {
|
||||
this._cancelRequested = false;
|
||||
this._runAbortController = new AbortController();
|
||||
if ('clearAbort' in this.modelClient && typeof this.modelClient.clearAbort === 'function') {
|
||||
this.modelClient.clearAbort();
|
||||
}
|
||||
@@ -144,6 +146,7 @@ export class NativeAgent {
|
||||
} finally {
|
||||
this._runInProgress = false;
|
||||
this._cancelRequested = false;
|
||||
this._runAbortController = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -353,7 +356,9 @@ export class NativeAgent {
|
||||
}
|
||||
: undefined;
|
||||
|
||||
const result = await toolExecutor.execute(internalName, tc.args, perCallContext);
|
||||
const result = await toolExecutor.execute(internalName, tc.args, perCallContext, {
|
||||
signal: this._runAbortController?.signal,
|
||||
});
|
||||
|
||||
this.onToolUse?.({ type: 'end', tool: internalName, result });
|
||||
|
||||
@@ -426,11 +431,22 @@ export class NativeAgent {
|
||||
}
|
||||
|
||||
private async chatWithRouter(request: ChatRequest): Promise<ChatResponse> {
|
||||
const runSignal = this._runAbortController?.signal;
|
||||
const requestSignal = request.signal;
|
||||
const signal = runSignal && requestSignal
|
||||
? AbortSignal.any([runSignal, requestSignal])
|
||||
: (runSignal ?? requestSignal);
|
||||
|
||||
const requestWithSignal = signal
|
||||
? { ...request, signal }
|
||||
: request;
|
||||
|
||||
const requestPromise = 'getClient' in this.modelClient
|
||||
? (this.modelClient as ModelRouter).chat(request, this.currentTier)
|
||||
: this.modelClient.chat(request);
|
||||
? (this.modelClient as ModelRouter).chat(requestWithSignal, this.currentTier)
|
||||
: this.modelClient.chat(requestWithSignal);
|
||||
|
||||
let timer: NodeJS.Timeout | undefined;
|
||||
let abortCleanup: (() => void) | undefined;
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
timer = setTimeout(() => {
|
||||
const error = new Error(`Model request timed out after ${this.modelTimeoutMs}ms`);
|
||||
@@ -439,13 +455,31 @@ export class NativeAgent {
|
||||
}, this.modelTimeoutMs);
|
||||
timer.unref?.();
|
||||
});
|
||||
const abortPromise = signal
|
||||
? new Promise<never>((_, reject) => {
|
||||
if (signal.aborted) {
|
||||
const error = new Error('Operation cancelled by user.');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
const onAbort = () => {
|
||||
const error = new Error('Operation cancelled by user.');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
};
|
||||
signal.addEventListener('abort', onAbort, { once: true });
|
||||
abortCleanup = () => signal.removeEventListener('abort', onAbort);
|
||||
})
|
||||
: null;
|
||||
|
||||
try {
|
||||
return await Promise.race([requestPromise, timeoutPromise]);
|
||||
return await Promise.race([requestPromise, timeoutPromise, ...(abortPromise ? [abortPromise] : [])]);
|
||||
} finally {
|
||||
if (timer) {
|
||||
clearTimeout(timer);
|
||||
}
|
||||
abortCleanup?.();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -544,6 +578,7 @@ export class NativeAgent {
|
||||
cancel(): void {
|
||||
if (this._runInProgress) {
|
||||
this._cancelRequested = true;
|
||||
this._runAbortController?.abort();
|
||||
if ('requestAbort' in this.modelClient && typeof this.modelClient.requestAbort === 'function') {
|
||||
this.modelClient.requestAbort();
|
||||
}
|
||||
@@ -555,7 +590,7 @@ export class NativeAgent {
|
||||
}
|
||||
|
||||
private throwIfCancelled(): void {
|
||||
if (!this._cancelRequested) {
|
||||
if (!this._cancelRequested && !this._runAbortController?.signal.aborted) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -100,7 +100,10 @@ export class AnthropicClient implements ModelClient {
|
||||
params.thinking = { type: 'enabled', budget_tokens: 4096 };
|
||||
}
|
||||
|
||||
const response = await this.client.messages.create(params) as AnthropicMessage;
|
||||
const response = await this.client.messages.create(
|
||||
params,
|
||||
request.signal ? { signal: request.signal } : undefined,
|
||||
) as AnthropicMessage;
|
||||
|
||||
const textContent = response.content.find((c) => c.type === 'text');
|
||||
const content = textContent?.type === 'text' ? textContent.text : '';
|
||||
|
||||
@@ -65,7 +65,10 @@ export class BedrockClient implements ModelClient {
|
||||
}
|
||||
|
||||
const command = new ConverseCommand(params);
|
||||
const response = await this.client.send(command);
|
||||
const response = await this.client.send(
|
||||
command,
|
||||
request.signal ? { abortSignal: request.signal } : undefined,
|
||||
);
|
||||
|
||||
// Extract text and tool_use content from the response
|
||||
const outputContent = response.output?.message?.content ?? [];
|
||||
@@ -126,7 +129,10 @@ export class BedrockClient implements ModelClient {
|
||||
|
||||
try {
|
||||
const command = new ConverseStreamCommand(params);
|
||||
const response = await this.client.send(command);
|
||||
const response = await this.client.send(
|
||||
command,
|
||||
request.signal ? { abortSignal: request.signal } : undefined,
|
||||
);
|
||||
|
||||
let inputTokens = 0;
|
||||
let outputTokens = 0;
|
||||
|
||||
@@ -163,7 +163,10 @@ export class GitHubModelsClient implements ModelClient {
|
||||
(params as OpenAI.ChatCompletionCreateParamsNonStreaming & { reasoning_effort?: 'low' | 'medium' | 'high' }).reasoning_effort = 'medium';
|
||||
}
|
||||
|
||||
const response = await this.client.chat.completions.create(params);
|
||||
const response = await this.client.chat.completions.create(
|
||||
params,
|
||||
request.signal ? { signal: request.signal } : undefined,
|
||||
);
|
||||
|
||||
const choice = response.choices[0];
|
||||
const content = choice?.message?.content ?? '';
|
||||
@@ -237,7 +240,10 @@ export class GitHubModelsClient implements ModelClient {
|
||||
}
|
||||
|
||||
try {
|
||||
const stream = await this.client.chat.completions.create(params);
|
||||
const stream = await this.client.chat.completions.create(
|
||||
params,
|
||||
request.signal ? { signal: request.signal } : undefined,
|
||||
);
|
||||
|
||||
let totalInputTokens = 0;
|
||||
let totalOutputTokens = 0;
|
||||
|
||||
@@ -247,13 +247,16 @@ export class LlamaCppClient implements ModelClient {
|
||||
}
|
||||
|
||||
const controller = new AbortController();
|
||||
const signal = request.signal
|
||||
? AbortSignal.any([request.signal, controller.signal])
|
||||
: controller.signal;
|
||||
const timer = setTimeout(() => controller.abort(), this.requestTimeout);
|
||||
try {
|
||||
response = await fetch(`${this.endpoint}/v1/chat/completions`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
signal: controller.signal,
|
||||
signal,
|
||||
});
|
||||
} finally {
|
||||
clearTimeout(timer);
|
||||
@@ -331,6 +334,7 @@ export class LlamaCppClient implements ModelClient {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
signal: request.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
||||
@@ -140,6 +140,7 @@ export class OpenAIClient implements ModelClient {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
signal: request.signal,
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
@@ -277,7 +278,10 @@ export class OpenAIClient implements ModelClient {
|
||||
|
||||
let response: OpenAI.ChatCompletion;
|
||||
try {
|
||||
response = await this.client.chat.completions.create(params);
|
||||
response = await this.client.chat.completions.create(
|
||||
params,
|
||||
request.signal ? { signal: request.signal } : undefined,
|
||||
);
|
||||
} catch (error) {
|
||||
const status = typeof (error as { status?: unknown })?.status === 'number'
|
||||
? (error as { status: number }).status
|
||||
|
||||
@@ -80,6 +80,8 @@ export interface ChatRequest {
|
||||
tools?: ToolDefinition[];
|
||||
/** Enable extended thinking/reasoning mode for this request. */
|
||||
thinking?: boolean;
|
||||
/** Optional abort signal for cancelling in-flight provider requests. */
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface ChatResponse {
|
||||
|
||||
+49
-6
@@ -24,6 +24,10 @@ export interface ToolExecutionObserverEvent {
|
||||
timestampSeconds: number;
|
||||
}
|
||||
|
||||
export interface ToolExecuteOptions {
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export class ToolExecutor {
|
||||
private registry: ToolRegistry;
|
||||
private hooks: HookEngine;
|
||||
@@ -64,7 +68,12 @@ export class ToolExecutor {
|
||||
return base;
|
||||
}
|
||||
|
||||
async execute(toolName: string, args: unknown, context?: ToolPolicyContext): Promise<ToolResult> {
|
||||
async execute(
|
||||
toolName: string,
|
||||
args: unknown,
|
||||
context?: ToolPolicyContext,
|
||||
options?: ToolExecuteOptions,
|
||||
): Promise<ToolResult> {
|
||||
const executionId = randomUUID();
|
||||
const executionEnvironment = this.resolveEffectiveExecutionEnvironment(toolName, context);
|
||||
const skillName = context?.skillName;
|
||||
@@ -279,31 +288,56 @@ export class ToolExecutor {
|
||||
});
|
||||
|
||||
let timeoutHandle: NodeJS.Timeout | undefined;
|
||||
const abortController = new AbortController();
|
||||
const timeoutAbortController = new AbortController();
|
||||
const externalSignal = options?.signal;
|
||||
const combinedSignal = externalSignal
|
||||
? AbortSignal.any([externalSignal, timeoutAbortController.signal])
|
||||
: timeoutAbortController.signal;
|
||||
let externalAbortCleanup: (() => void) | undefined;
|
||||
|
||||
try {
|
||||
const externalAbortPromise = externalSignal
|
||||
? new Promise<ToolResult>((_, reject) => {
|
||||
if (externalSignal.aborted) {
|
||||
const error = new Error('Operation cancelled by user.');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
const onAbort = () => {
|
||||
const error = new Error('Operation cancelled by user.');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
};
|
||||
externalSignal.addEventListener('abort', onAbort, { once: true });
|
||||
externalAbortCleanup = () => externalSignal.removeEventListener('abort', onAbort);
|
||||
})
|
||||
: null;
|
||||
|
||||
const result = await Promise.race([
|
||||
(async () => {
|
||||
if (executionEnvironment === 'sandbox' && this.sandboxManager) {
|
||||
const sandboxSessionId = context?.sessionId ?? `${context?.channel ?? 'unknown'}:${context?.sender ?? 'unknown'}`;
|
||||
const sandbox = await this.sandboxManager.getOrCreate(sandboxSessionId);
|
||||
if (toolName === 'shell.exec') {
|
||||
return createSandboxedShellTool(sandbox).execute(args, { signal: abortController.signal });
|
||||
return createSandboxedShellTool(sandbox).execute(args, { signal: combinedSignal });
|
||||
}
|
||||
if (toolName === 'process.start') {
|
||||
return createSandboxedProcessStartTool(sandbox).execute(args, { signal: abortController.signal });
|
||||
return createSandboxedProcessStartTool(sandbox).execute(args, { signal: combinedSignal });
|
||||
}
|
||||
}
|
||||
return tool.execute(args, { signal: abortController.signal });
|
||||
return tool.execute(args, { signal: combinedSignal });
|
||||
})(),
|
||||
new Promise<ToolResult>((_, reject) => {
|
||||
timeoutHandle = setTimeout(
|
||||
() => {
|
||||
abortController.abort();
|
||||
timeoutAbortController.abort();
|
||||
reject(new Error(`Tool '${toolName}' timed out after ${this.defaultTimeoutMs}ms`));
|
||||
},
|
||||
this.defaultTimeoutMs,
|
||||
);
|
||||
}),
|
||||
...(externalAbortPromise ? [externalAbortPromise] : []),
|
||||
]);
|
||||
|
||||
const duration = Date.now() - startTime;
|
||||
@@ -357,6 +391,10 @@ export class ToolExecutor {
|
||||
timestampSeconds: Math.floor(Date.now() / 1000),
|
||||
});
|
||||
|
||||
if (externalSignal?.aborted && this.isAbortError(error)) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
return {
|
||||
success: false,
|
||||
output: '',
|
||||
@@ -366,9 +404,14 @@ export class ToolExecutor {
|
||||
if (timeoutHandle) {
|
||||
clearTimeout(timeoutHandle);
|
||||
}
|
||||
externalAbortCleanup?.();
|
||||
}
|
||||
}
|
||||
|
||||
private isAbortError(error: unknown): boolean {
|
||||
return error instanceof Error && error.name === 'AbortError';
|
||||
}
|
||||
|
||||
private notifyExecutionObserver(event: ToolExecutionObserverEvent): void {
|
||||
if (!this.executionObserver) {
|
||||
return;
|
||||
|
||||
Reference in New Issue
Block a user