fix: abort model retries immediately on user cancellation

This commit is contained in:
William Valentin
2026-02-18 11:21:57 -08:00
parent a76c5ae346
commit 55cde541ea
6 changed files with 177 additions and 3 deletions
+28
View File
@@ -159,6 +159,34 @@ describe('withRetry', () => {
expect(fn).toHaveBeenCalledTimes(1);
});
it('aborts before first attempt when shouldAbort is true', async () => {
const fn = vi.fn().mockResolvedValue('never');
await expect(
withRetry(fn, fastConfig, 'abort-test', { shouldAbort: () => true }),
).rejects.toMatchObject({ name: 'AbortError' });
expect(fn).not.toHaveBeenCalled();
});
it('aborts during backoff sleep when shouldAbort flips true', async () => {
let abort = false;
const fn = vi.fn().mockRejectedValue(new Error('temporary failure'));
const run = withRetry(
fn,
{ ...fastConfig, maxRetries: 3, initialDelayMs: 80, backoffMultiplier: 1, maxDelayMs: 80 },
'abort-backoff-test',
{ shouldAbort: () => abort },
);
setTimeout(() => {
abort = true;
}, 10);
await expect(run).rejects.toMatchObject({ name: 'AbortError' });
expect(fn).toHaveBeenCalledTimes(1);
});
it('increases delay exponentially between retries', async () => {
const timestamps: number[] = [];
const config: RetryConfig = {
+71 -1
View File
@@ -13,6 +13,11 @@ export interface RetryConfig {
nonRetryablePatterns: string[];
}
export interface RetryExecutionOptions {
/** Abort retry loop before next attempt and during backoff sleeps. */
shouldAbort?: () => boolean;
}
export const DEFAULT_RETRY_CONFIG: RetryConfig = {
maxRetries: 3,
initialDelayMs: 1000,
@@ -38,15 +43,34 @@ export async function withRetry<T>(
fn: () => Promise<T>,
config: RetryConfig = DEFAULT_RETRY_CONFIG,
label?: string,
options?: RetryExecutionOptions,
): Promise<T> {
let lastError: Error | undefined;
const throwAbort = (): never => {
const error = new Error('Operation cancelled by user.');
error.name = 'AbortError';
throw error;
};
for (let attempt = 0; attempt <= config.maxRetries; attempt++) {
if (options?.shouldAbort?.()) {
throwAbort();
}
try {
return await fn();
} catch (error) {
lastError = error instanceof Error ? error : new Error(String(error));
if (lastError.name === 'AbortError') {
throw lastError;
}
if (options?.shouldAbort?.()) {
throwAbort();
}
// Don't retry non-retryable errors
if (!isRetryable(lastError, config.nonRetryablePatterns)) {
throw lastError;
@@ -66,9 +90,55 @@ export async function withRetry<T>(
`[retry] ${label ?? 'operation'} attempt ${attempt + 1}/${config.maxRetries} failed: ${lastError.message}. Retrying in ${Math.round(jitter)}ms...`,
);
await new Promise(resolve => setTimeout(resolve, jitter));
await sleepWithAbort(jitter, options?.shouldAbort);
}
}
throw lastError ?? new Error('Retry failed with no error');
}
async function sleepWithAbort(delayMs: number, shouldAbort?: () => boolean): Promise<void> {
if (!shouldAbort) {
await new Promise<void>((resolve) => setTimeout(resolve, delayMs));
return;
}
await new Promise<void>((resolve, reject) => {
const endAt = Date.now() + delayMs;
const checkIntervalMs = Math.min(100, Math.max(20, Math.floor(delayMs / 5)));
let timeout: NodeJS.Timeout | undefined;
let interval: NodeJS.Timeout | undefined;
const cleanup = () => {
if (timeout) {
clearTimeout(timeout);
}
if (interval) {
clearInterval(interval);
}
};
const rejectAbort = () => {
cleanup();
const error = new Error('Operation cancelled by user.');
error.name = 'AbortError';
reject(error);
};
if (shouldAbort()) {
rejectAbort();
return;
}
timeout = setTimeout(() => {
cleanup();
resolve();
}, Math.max(0, endAt - Date.now()));
interval = setInterval(() => {
if (shouldAbort()) {
rejectAbort();
}
}, checkIntervalMs);
});
}
+32
View File
@@ -476,6 +476,38 @@ describe('setClient and labels', () => {
expect(router.isTierStrict('default')).toBe(true);
});
it('requestAbort interrupts retry loop before fallback chain', async () => {
const primary = {
chat: vi.fn().mockRejectedValue(new Error('temporary failure')),
} as unknown as ModelClient;
const fallback = {
chat: vi.fn().mockResolvedValue({
content: 'fallback',
stopReason: 'end_turn',
usage: { inputTokens: 1, outputTokens: 1 },
}),
} as unknown as ModelClient;
const router = new ModelRouter({
default: primary,
fallbackChain: [fallback],
retryConfig: {
maxRetries: 3,
initialDelayMs: 80,
backoffMultiplier: 1,
maxDelayMs: 80,
nonRetryablePatterns: [],
},
});
const run = router.chat({ messages: [{ role: 'user', content: 'hi' }] });
setTimeout(() => router.requestAbort(), 10);
await expect(run).rejects.toMatchObject({ name: 'AbortError' });
expect(primary.chat).toHaveBeenCalledTimes(1);
expect(fallback.chat).not.toHaveBeenCalled();
});
it('setOnTierChange does not replace existing listeners', () => {
const router = new ModelRouter({
default: { chat: vi.fn() } as unknown as ModelClient,
+24 -1
View File
@@ -28,6 +28,7 @@ export class ModelRouter implements ModelClient {
private retryConfig?: RetryConfig;
private tierChangeListeners: Array<(tier: ModelTier) => void> = [];
private strictTiers: Set<ModelTier> = new Set();
private abortRequested = false;
constructor(config: ModelRouterConfig) {
this.clients = new Map();
@@ -89,8 +90,11 @@ export class ModelRouter implements ModelClient {
// Try primary client (with retry if configured)
try {
this.throwIfAborted();
if (this.retryConfig) {
return await withRetry(() => primaryClient.chat(request), this.retryConfig, 'primary model');
return await withRetry(() => primaryClient.chat(request), this.retryConfig, 'primary model', {
shouldAbort: () => this.abortRequested,
});
}
return await primaryClient.chat(request);
} catch (error) {
@@ -105,6 +109,7 @@ export class ModelRouter implements ModelClient {
// Try tier-specific fallbacks first
const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
for (let i = 0; i < tierFallbackList.length; i++) {
this.throwIfAborted();
try {
const reason = `Primary model failed (${errors[0].message}), using tier fallback #${i + 1}`;
logger.debug(reason);
@@ -118,6 +123,7 @@ export class ModelRouter implements ModelClient {
// Then try global fallback chain
for (let i = 0; i < this.fallbackChain.length; i++) {
this.throwIfAborted();
const fallbackClient = this.fallbackChain[i];
try {
const reason = `Primary model failed (${errors[0].message}), using global fallback #${i + 1}`;
@@ -238,6 +244,14 @@ export class ModelRouter implements ModelClient {
return this.strictTiers.has(tier);
}
requestAbort(): void {
this.abortRequested = true;
}
clearAbort(): void {
this.abortRequested = false;
}
getLabel(tier: ModelTier): string {
return this.labels.get(tier) ?? 'unknown';
}
@@ -249,4 +263,13 @@ export class ModelRouter implements ModelClient {
}
return result;
}
private throwIfAborted(): void {
if (!this.abortRequested) {
return;
}
const error = new Error('Operation cancelled by user.');
error.name = 'AbortError';
throw error;
}
}