// File: flynn/src/models/tts.ts
// (viewer metadata: 260 lines, 7.1 KiB, TypeScript)

import type { OutboundAttachment } from '../channels/types.js';
/** Audio formats that can be requested from a TTS provider (sent as `response_format`). */
export type TtsOutputFormat = 'mp3' | 'wav' | 'opus';
/** Connection and voice settings for a single TTS provider. */
export interface TtsSynthesisConfig {
/** URL that receives the synthesis POST request. */
endpoint?: string;
/** When set, sent as an `Authorization: Bearer <apiKey>` header. */
apiKey?: string;
/** Model name sent in the request body; synthesis defaults to 'gpt-4o-mini-tts'. */
model?: string;
/** Voice name sent in the request body; synthesis defaults to 'alloy'. */
voice?: string;
/** Requested audio format; synthesis defaults to 'mp3'. */
format?: TtsOutputFormat;
/** Optional provider identity used by fallback/health tracking. */
id?: string;
/**
 * Optional provider type hint for endpoint defaults.
 * With 'openai' and no usable `endpoint`, the fallback chain targets
 * https://api.openai.com/v1/audio/speech.
 */
type?: 'openai' | 'custom';
}
/** Tuning options for `synthesizeSpeechWithFallback`. */
export interface TtsFallbackConfig {
/**
 * Maximum number of providers to try per reply.
 * Clamped to between 1 and the number of usable providers; when omitted or
 * non-finite, every usable provider may be tried.
 */
maxAttempts?: number;
/**
 * Cooldown window, in milliseconds, during which a provider that failed is
 * treated as unhealthy and ordered after healthy providers.
 * Defaults to 60_000; the health tracker enforces a 1_000 ms minimum.
 */
failureCooldownMs?: number;
}
/** Failure bookkeeping that `TtsHealthTracker` keeps for one provider. */
export interface TtsProviderHealth {
/** Number of failures recorded since the provider's state was last cleared by a success. */
consecutiveFailures: number;
/** Timestamp (same clock as `Date.now()`) from which the provider counts as healthy again. */
cooldownUntil: number;
/** Timestamp of the most recently recorded failure. */
lastFailureAt?: number;
/** Message of the most recently recorded failure. */
lastError?: string;
}
/** Outcome of `synthesizeSpeechWithFallback`. */
export interface TtsFallbackResult {
/** Synthesized audio, or null when no provider produced one. */
attachment: OutboundAttachment | null;
/** Id of the provider that produced `attachment`; absent when none did. */
providerId?: string;
/** Ids of the providers that were tried, in attempt order. */
attemptedProviders: string[];
/** Ids of the providers cut off by the `maxAttempts` limit. */
skippedProviders: string[];
/** When no attachment was produced: message of the last error thrown by an attempted provider, if any. */
lastError?: string;
}
/**
 * In-memory record of which TTS providers failed recently, keyed by provider id.
 *
 * A provider with no recorded state is healthy. Recording a failure puts the
 * provider into a cooldown that lasts until `cooldownUntil`; recording a
 * success clears its state entirely.
 */
export class TtsHealthTracker {
  private readonly states = new Map<string, TtsProviderHealth>();

  /**
   * @param providerId Provider to check.
   * @param now Current timestamp; defaults to `Date.now()`.
   * @returns true when the provider has no recorded failure or its cooldown has elapsed.
   */
  isHealthy(providerId: string, now = Date.now()): boolean {
    const state = this.states.get(providerId);
    return state === undefined || state.cooldownUntil <= now;
  }

  /** Forget any recorded failures for the provider. */
  markSuccess(providerId: string): void {
    this.states.delete(providerId);
  }

  /**
   * Record a failure and start (or restart) the provider's cooldown.
   *
   * @param providerId Provider that failed.
   * @param error Thrown value; its message (or string form) is stored as `lastError`.
   * @param now Timestamp of the failure; defaults to `Date.now()`.
   * @param failureCooldownMs Cooldown length in milliseconds. Values below
   *   1_000 are raised to 1_000; `NaN` falls back to the 60_000 default.
   */
  markFailure(providerId: string, error: unknown, now = Date.now(), failureCooldownMs = 60_000): void {
    const previous = this.states.get(providerId);
    // Math.max propagates NaN, and `NaN <= now` is never true, so a NaN
    // cooldown would leave the provider unhealthy until its next success.
    const cooldownMs = Number.isNaN(failureCooldownMs) ? 60_000 : failureCooldownMs;
    this.states.set(providerId, {
      consecutiveFailures: (previous?.consecutiveFailures ?? 0) + 1,
      cooldownUntil: now + Math.max(1_000, cooldownMs),
      lastFailureAt: now,
      lastError: error instanceof Error ? error.message : String(error),
    });
  }

  /** @returns The recorded failure state for the provider, or undefined when none exists. */
  getState(providerId: string): TtsProviderHealth | undefined {
    return this.states.get(providerId);
  }
}
/**
 * Map a TTS output format to the MIME type used for the attachment.
 * Anything other than 'wav' or 'opus' (including 'mp3') maps to 'audio/mpeg'.
 */
function outputFormatToMimeType(format: TtsOutputFormat): string {
  if (format === 'wav') {
    return 'audio/wav';
  }
  if (format === 'opus') {
    return 'audio/ogg';
  }
  return 'audio/mpeg';
}
/**
 * Map a TTS output format to the file extension used in the attachment filename.
 * 'opus' audio is named '.ogg'; anything other than 'wav' or 'opus' falls back to 'mp3'.
 */
function outputFormatToExtension(format: TtsOutputFormat): string {
  if (format === 'wav') {
    return 'wav';
  }
  return format === 'opus' ? 'ogg' : 'mp3';
}
/**
 * Work out which URL a provider config should be called at.
 *
 * An explicit endpoint wins as long as it is not blank (it is returned
 * unmodified, without trimming). Otherwise providers of type 'openai' get the
 * public OpenAI speech endpoint, and everything else resolves to undefined.
 */
function resolveProviderEndpoint(config: TtsSynthesisConfig): string | undefined {
  const explicit = config.endpoint;
  if (explicit && explicit.trim() !== '') {
    return explicit;
  }
  return config.type === 'openai' ? 'https://api.openai.com/v1/audio/speech' : undefined;
}
/**
 * Turn the configured attempt limit into a usable count.
 *
 * @param maxAttempts Configured limit; may be missing or not a finite number.
 * @param providerCount Number of providers available to try.
 * @returns 0 when there are no providers; `providerCount` when the limit is
 *   missing or non-finite; otherwise the limit rounded down and clamped to
 *   the range 1..providerCount.
 */
function normalizeMaxAttempts(maxAttempts: number | undefined, providerCount: number): number {
  if (providerCount <= 0) {
    return 0;
  }
  if (typeof maxAttempts !== 'number' || !Number.isFinite(maxAttempts)) {
    return providerCount;
  }
  const requested = Math.max(1, Math.floor(maxAttempts));
  return Math.min(providerCount, requested);
}
/**
 * Synthesize speech via an OpenAI-compatible /v1/audio/speech endpoint.
 *
 * The target URL is resolved with `resolveProviderEndpoint`, the same helper
 * the fallback chain uses, so a blank `endpoint` is treated as missing and a
 * provider of type 'openai' without an endpoint uses the default OpenAI URL.
 *
 * @param text Text to speak; surrounding whitespace is trimmed before sending.
 * @param config Provider settings. Missing `format`, `model` and `voice`
 *   default to 'mp3', 'gpt-4o-mini-tts' and 'alloy'.
 * @returns The audio as a base64-encoded attachment, or null when the text is
 *   blank or no endpoint can be resolved.
 * @throws Error when the provider answers with a non-OK status. Rejections
 *   from `fetch` itself are not caught and propagate to the caller.
 */
export async function synthesizeSpeechAttachment(
  text: string,
  config: TtsSynthesisConfig,
): Promise<OutboundAttachment | null> {
  const trimmed = text.trim();
  if (!trimmed) {
    return null;
  }

  // Resolve through the shared helper rather than reading config.endpoint
  // directly, so direct callers and the fallback chain agree on the URL.
  const endpoint = resolveProviderEndpoint(config);
  if (!endpoint) {
    return null;
  }

  const format = config.format ?? 'mp3';
  const model = config.model ?? 'gpt-4o-mini-tts';
  const voice = config.voice ?? 'alloy';

  const headers: Record<string, string> = {
    'Content-Type': 'application/json',
  };
  if (config.apiKey) {
    headers.Authorization = `Bearer ${config.apiKey}`;
  }

  const response = await fetch(endpoint, {
    method: 'POST',
    headers,
    body: JSON.stringify({
      model,
      voice,
      input: trimmed,
      response_format: format,
    }),
  });

  if (!response.ok) {
    // Reading the error body is best-effort; an unreadable body must not hide the status.
    const detail = await response.text().catch(() => '');
    throw new Error(
      `TTS request failed: ${response.status} ${response.statusText}${detail ? ` - ${detail.slice(0, 200)}` : ''}`,
    );
  }

  const audioBytes = await response.arrayBuffer();
  const data = Buffer.from(audioBytes).toString('base64');
  const extension = outputFormatToExtension(format);
  return {
    mimeType: outputFormatToMimeType(format),
    data,
    filename: `flynn-reply-${Date.now()}.${extension}`,
  };
}
/**
 * Try TTS synthesis against a list of providers, preferring ones that are not
 * in a failure cooldown.
 *
 * Providers without a resolvable endpoint are dropped. The rest keep their
 * configured order, with healthy providers placed ahead of cooling-down ones,
 * and the list is cut to `fallback.maxAttempts`. Each remaining provider is
 * tried in turn until one yields an attachment. Errors thrown by a provider
 * are recorded on the health tracker (when supplied) and reported through
 * `lastError`; they are never rethrown.
 *
 * @param text Text to speak; blank text yields an empty result without any attempt.
 * @param input Provider chain plus optional fallback tuning and health tracker.
 * @returns The attachment and the provider that produced it, or a null
 *   attachment with the attempted/skipped provider ids and the last error message.
 */
export async function synthesizeSpeechWithFallback(
  text: string,
  input: {
    providers: TtsSynthesisConfig[];
    fallback?: TtsFallbackConfig;
    healthTracker?: TtsHealthTracker;
  },
): Promise<TtsFallbackResult> {
  type ResolvedProvider = TtsSynthesisConfig & { id: string; endpoint: string };

  const spokenText = text.trim();
  if (spokenText.length === 0) {
    return { attachment: null, attemptedProviders: [], skippedProviders: [] };
  }

  // Keep only providers with a usable endpoint. Generated ids use the
  // provider's position in the original list, counting dropped entries.
  const candidates: ResolvedProvider[] = [];
  input.providers.forEach((provider, index) => {
    const endpoint = resolveProviderEndpoint(provider);
    if (typeof endpoint === 'string' && endpoint.length > 0) {
      candidates.push({
        ...provider,
        endpoint,
        id: provider.id ?? `tts-provider-${index + 1}`,
      });
    }
  });
  if (candidates.length === 0) {
    return { attachment: null, attemptedProviders: [], skippedProviders: [] };
  }

  const tracker = input.healthTracker;
  const startedAt = Date.now();
  const ready: ResolvedProvider[] = [];
  const coolingDown: ResolvedProvider[] = [];
  for (const candidate of candidates) {
    const usable = !tracker || tracker.isHealthy(candidate.id, startedAt);
    (usable ? ready : coolingDown).push(candidate);
  }

  // Healthy providers go first; cooling-down ones stay in the chain as a last resort.
  const chain = ready.concat(coolingDown);
  const limit = normalizeMaxAttempts(input.fallback?.maxAttempts, chain.length);
  const skippedProviders = chain.slice(limit).map((entry) => entry.id);
  const failureCooldownMs = input.fallback?.failureCooldownMs ?? 60_000;

  const attemptedProviders: string[] = [];
  let lastError: string | undefined;

  for (const provider of chain.slice(0, limit)) {
    attemptedProviders.push(provider.id);
    try {
      const attachment = await synthesizeSpeechAttachment(spokenText, provider);
      if (attachment) {
        tracker?.markSuccess(provider.id);
        return {
          attachment,
          providerId: provider.id,
          attemptedProviders,
          skippedProviders,
        };
      }
    } catch (error) {
      tracker?.markFailure(provider.id, error, Date.now(), failureCooldownMs);
      lastError = error instanceof Error ? error.message : String(error);
    }
  }

  return {
    attachment: null,
    attemptedProviders,
    skippedProviders,
    lastError,
  };
}