fix: abort model retries immediately on user cancellation

This commit is contained in:
William Valentin
2026-02-18 11:21:57 -08:00
parent a76c5ae346
commit 55cde541ea
6 changed files with 177 additions and 3 deletions
+16 -1
View File
@@ -5408,10 +5408,25 @@
"docs/plans/state.json" "docs/plans/state.json"
], ],
"test_status": "pnpm test:run src/config/schema.test.ts src/cli/setup/config.test.ts src/cli/setup/sections.test.ts + pnpm typecheck passing" "test_status": "pnpm test:run src/config/schema.test.ts src/cli/setup/config.test.ts src/cli/setup/sections.test.ts + pnpm typecheck passing"
},
"model-retry-cancel-abort-immediacy": {
"status": "completed",
"date": "2026-02-18",
"updated": "2026-02-18",
"summary": "Fixed cancellation responsiveness for model retries: `/cancel` now aborts model-router retry backoff loops immediately via abort-aware retry execution and router abort signaling, preventing long waits through remaining retries/fallback steps.",
"files_modified": [
"src/models/retry.ts",
"src/models/retry.test.ts",
"src/models/router.ts",
"src/models/router.test.ts",
"src/backends/native/agent.ts",
"docs/plans/state.json"
],
"test_status": "pnpm test:run src/models/retry.test.ts src/models/router.test.ts src/backends/native/agent.test.ts + pnpm typecheck passing"
} }
}, },
"overall_progress": { "overall_progress": {
"total_test_count": 1927, "total_test_count": 1930,
"all_tests_passing": true, "all_tests_passing": true,
"p0_completion": "3/3 (100%)", "p0_completion": "3/3 (100%)",
"p1_completion": "4/4 (100%)", "p1_completion": "4/4 (100%)",
+6
View File
@@ -100,6 +100,9 @@ export class NativeAgent {
async process(userMessage: string, attachments?: Attachment[]): Promise<string> { async process(userMessage: string, attachments?: Attachment[]): Promise<string> {
this._cancelRequested = false; this._cancelRequested = false;
if ('clearAbort' in this.modelClient && typeof this.modelClient.clearAbort === 'function') {
this.modelClient.clearAbort();
}
this._runInProgress = true; this._runInProgress = true;
// Detect and strip !!think prefix for per-message thinking mode // Detect and strip !!think prefix for per-message thinking mode
@@ -541,6 +544,9 @@ export class NativeAgent {
cancel(): void { cancel(): void {
if (this._runInProgress) { if (this._runInProgress) {
this._cancelRequested = true; this._cancelRequested = true;
if ('requestAbort' in this.modelClient && typeof this.modelClient.requestAbort === 'function') {
this.modelClient.requestAbort();
}
} }
} }
+28
View File
@@ -159,6 +159,34 @@ describe('withRetry', () => {
expect(fn).toHaveBeenCalledTimes(1); expect(fn).toHaveBeenCalledTimes(1);
}); });
it('aborts before first attempt when shouldAbort is true', async () => {
const fn = vi.fn().mockResolvedValue('never');
await expect(
withRetry(fn, fastConfig, 'abort-test', { shouldAbort: () => true }),
).rejects.toMatchObject({ name: 'AbortError' });
expect(fn).not.toHaveBeenCalled();
});
it('aborts during backoff sleep when shouldAbort flips true', async () => {
let abort = false;
const fn = vi.fn().mockRejectedValue(new Error('temporary failure'));
const run = withRetry(
fn,
{ ...fastConfig, maxRetries: 3, initialDelayMs: 80, backoffMultiplier: 1, maxDelayMs: 80 },
'abort-backoff-test',
{ shouldAbort: () => abort },
);
setTimeout(() => {
abort = true;
}, 10);
await expect(run).rejects.toMatchObject({ name: 'AbortError' });
expect(fn).toHaveBeenCalledTimes(1);
});
it('increases delay exponentially between retries', async () => { it('increases delay exponentially between retries', async () => {
const timestamps: number[] = []; const timestamps: number[] = [];
const config: RetryConfig = { const config: RetryConfig = {
+71 -1
View File
@@ -13,6 +13,11 @@ export interface RetryConfig {
nonRetryablePatterns: string[]; nonRetryablePatterns: string[];
} }
export interface RetryExecutionOptions {
/** Abort retry loop before next attempt and during backoff sleeps. */
shouldAbort?: () => boolean;
}
export const DEFAULT_RETRY_CONFIG: RetryConfig = { export const DEFAULT_RETRY_CONFIG: RetryConfig = {
maxRetries: 3, maxRetries: 3,
initialDelayMs: 1000, initialDelayMs: 1000,
@@ -38,15 +43,34 @@ export async function withRetry<T>(
fn: () => Promise<T>, fn: () => Promise<T>,
config: RetryConfig = DEFAULT_RETRY_CONFIG, config: RetryConfig = DEFAULT_RETRY_CONFIG,
label?: string, label?: string,
options?: RetryExecutionOptions,
): Promise<T> { ): Promise<T> {
let lastError: Error | undefined; let lastError: Error | undefined;
const throwAbort = (): never => {
const error = new Error('Operation cancelled by user.');
error.name = 'AbortError';
throw error;
};
for (let attempt = 0; attempt <= config.maxRetries; attempt++) { for (let attempt = 0; attempt <= config.maxRetries; attempt++) {
if (options?.shouldAbort?.()) {
throwAbort();
}
try { try {
return await fn(); return await fn();
} catch (error) { } catch (error) {
lastError = error instanceof Error ? error : new Error(String(error)); lastError = error instanceof Error ? error : new Error(String(error));
if (lastError.name === 'AbortError') {
throw lastError;
}
if (options?.shouldAbort?.()) {
throwAbort();
}
// Don't retry non-retryable errors // Don't retry non-retryable errors
if (!isRetryable(lastError, config.nonRetryablePatterns)) { if (!isRetryable(lastError, config.nonRetryablePatterns)) {
throw lastError; throw lastError;
@@ -66,9 +90,55 @@ export async function withRetry<T>(
`[retry] ${label ?? 'operation'} attempt ${attempt + 1}/${config.maxRetries} failed: ${lastError.message}. Retrying in ${Math.round(jitter)}ms...`, `[retry] ${label ?? 'operation'} attempt ${attempt + 1}/${config.maxRetries} failed: ${lastError.message}. Retrying in ${Math.round(jitter)}ms...`,
); );
await new Promise(resolve => setTimeout(resolve, jitter)); await sleepWithAbort(jitter, options?.shouldAbort);
} }
} }
throw lastError ?? new Error('Retry failed with no error'); throw lastError ?? new Error('Retry failed with no error');
} }
async function sleepWithAbort(delayMs: number, shouldAbort?: () => boolean): Promise<void> {
if (!shouldAbort) {
await new Promise<void>((resolve) => setTimeout(resolve, delayMs));
return;
}
await new Promise<void>((resolve, reject) => {
const endAt = Date.now() + delayMs;
const checkIntervalMs = Math.min(100, Math.max(20, Math.floor(delayMs / 5)));
let timeout: NodeJS.Timeout | undefined;
let interval: NodeJS.Timeout | undefined;
const cleanup = () => {
if (timeout) {
clearTimeout(timeout);
}
if (interval) {
clearInterval(interval);
}
};
const rejectAbort = () => {
cleanup();
const error = new Error('Operation cancelled by user.');
error.name = 'AbortError';
reject(error);
};
if (shouldAbort()) {
rejectAbort();
return;
}
timeout = setTimeout(() => {
cleanup();
resolve();
}, Math.max(0, endAt - Date.now()));
interval = setInterval(() => {
if (shouldAbort()) {
rejectAbort();
}
}, checkIntervalMs);
});
}
+32
View File
@@ -476,6 +476,38 @@ describe('setClient and labels', () => {
expect(router.isTierStrict('default')).toBe(true); expect(router.isTierStrict('default')).toBe(true);
}); });
it('requestAbort interrupts retry loop before fallback chain', async () => {
const primary = {
chat: vi.fn().mockRejectedValue(new Error('temporary failure')),
} as unknown as ModelClient;
const fallback = {
chat: vi.fn().mockResolvedValue({
content: 'fallback',
stopReason: 'end_turn',
usage: { inputTokens: 1, outputTokens: 1 },
}),
} as unknown as ModelClient;
const router = new ModelRouter({
default: primary,
fallbackChain: [fallback],
retryConfig: {
maxRetries: 3,
initialDelayMs: 80,
backoffMultiplier: 1,
maxDelayMs: 80,
nonRetryablePatterns: [],
},
});
const run = router.chat({ messages: [{ role: 'user', content: 'hi' }] });
setTimeout(() => router.requestAbort(), 10);
await expect(run).rejects.toMatchObject({ name: 'AbortError' });
expect(primary.chat).toHaveBeenCalledTimes(1);
expect(fallback.chat).not.toHaveBeenCalled();
});
it('setOnTierChange does not replace existing listeners', () => { it('setOnTierChange does not replace existing listeners', () => {
const router = new ModelRouter({ const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient, default: { chat: vi.fn() } as unknown as ModelClient,
+24 -1
View File
@@ -28,6 +28,7 @@ export class ModelRouter implements ModelClient {
private retryConfig?: RetryConfig; private retryConfig?: RetryConfig;
private tierChangeListeners: Array<(tier: ModelTier) => void> = []; private tierChangeListeners: Array<(tier: ModelTier) => void> = [];
private strictTiers: Set<ModelTier> = new Set(); private strictTiers: Set<ModelTier> = new Set();
private abortRequested = false;
constructor(config: ModelRouterConfig) { constructor(config: ModelRouterConfig) {
this.clients = new Map(); this.clients = new Map();
@@ -89,8 +90,11 @@ export class ModelRouter implements ModelClient {
// Try primary client (with retry if configured) // Try primary client (with retry if configured)
try { try {
this.throwIfAborted();
if (this.retryConfig) { if (this.retryConfig) {
return await withRetry(() => primaryClient.chat(request), this.retryConfig, 'primary model'); return await withRetry(() => primaryClient.chat(request), this.retryConfig, 'primary model', {
shouldAbort: () => this.abortRequested,
});
} }
return await primaryClient.chat(request); return await primaryClient.chat(request);
} catch (error) { } catch (error) {
@@ -105,6 +109,7 @@ export class ModelRouter implements ModelClient {
// Try tier-specific fallbacks first // Try tier-specific fallbacks first
const tierFallbackList = this.tierFallbacks.get(useTier) ?? []; const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
for (let i = 0; i < tierFallbackList.length; i++) { for (let i = 0; i < tierFallbackList.length; i++) {
this.throwIfAborted();
try { try {
const reason = `Primary model failed (${errors[0].message}), using tier fallback #${i + 1}`; const reason = `Primary model failed (${errors[0].message}), using tier fallback #${i + 1}`;
logger.debug(reason); logger.debug(reason);
@@ -118,6 +123,7 @@ export class ModelRouter implements ModelClient {
// Then try global fallback chain // Then try global fallback chain
for (let i = 0; i < this.fallbackChain.length; i++) { for (let i = 0; i < this.fallbackChain.length; i++) {
this.throwIfAborted();
const fallbackClient = this.fallbackChain[i]; const fallbackClient = this.fallbackChain[i];
try { try {
const reason = `Primary model failed (${errors[0].message}), using global fallback #${i + 1}`; const reason = `Primary model failed (${errors[0].message}), using global fallback #${i + 1}`;
@@ -238,6 +244,14 @@ export class ModelRouter implements ModelClient {
return this.strictTiers.has(tier); return this.strictTiers.has(tier);
} }
requestAbort(): void {
this.abortRequested = true;
}
clearAbort(): void {
this.abortRequested = false;
}
getLabel(tier: ModelTier): string { getLabel(tier: ModelTier): string {
return this.labels.get(tier) ?? 'unknown'; return this.labels.get(tier) ?? 'unknown';
} }
@@ -249,4 +263,13 @@ export class ModelRouter implements ModelClient {
} }
return result; return result;
} }
private throwIfAborted(): void {
if (!this.abortRequested) {
return;
}
const error = new Error('Operation cancelled by user.');
error.name = 'AbortError';
throw error;
}
} }