feat: add OpenAI OAuth, strict model overrides, and Gmail pull mode
This commit is contained in:
@@ -140,6 +140,44 @@ describe('LlamaCppClient', () => {
|
||||
}]);
|
||||
});
|
||||
|
||||
// The body sent to the server must carry a reduced web_search schema:
// only the `query` property survives, and it stays required.
it('sanitizes web_search tool schema for llama.cpp', async () => {
  const completion = {
    choices: [{ message: { content: 'ok' } }],
    usage: { prompt_tokens: 1, completion_tokens: 1 },
  };
  mockFetch.mockResolvedValue({
    ok: true,
    json: () => Promise.resolve(completion),
  });

  const client = new LlamaCppClient({ endpoint: 'http://localhost:8080', model: 'test-model' });

  await client.chat({
    messages: [{ role: 'user', content: 'search' }],
    tools: [{
      name: 'web_search',
      description: 'Search',
      input_schema: {
        type: 'object',
        // `count` is the extra optional property expected to be dropped.
        properties: {
          query: { type: 'string' },
          count: { type: 'number' },
        },
        required: ['query'],
      },
    }],
  });

  // Inspect the JSON payload of the first fetch call.
  const firstCallInit = mockFetch.mock.calls[0][1];
  const requestBody = JSON.parse(firstCallInit.body);
  const sentParameters = requestBody.tools[0].function.parameters;

  expect(sentParameters).toEqual({
    type: 'object',
    properties: { query: { type: 'string' } },
    required: ['query'],
  });
});
|
||||
|
||||
it('parses tool_calls from response', async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
|
||||
@@ -48,6 +48,36 @@ interface LlamaCppStreamChunk {
|
||||
usage?: { prompt_tokens: number; completion_tokens: number };
|
||||
}
|
||||
|
||||
/**
 * Adapts a tool's JSON schema before it is sent to llama.cpp.
 *
 * llama.cpp is stricter than most tool-call APIs about JSON schema, and some
 * builds reject extra optional properties for common tools. Every tool keeps
 * its full schema except `web_search`, which is reduced to a single required
 * `query` property.
 *
 * @param toolName - Name of the tool the schema belongs to.
 * @param parameters - The tool's input schema as supplied by the caller.
 * @returns `parameters` itself when no rewrite applies (other tool, non-object
 *   schema, or no `query` property); otherwise a new minimal object schema.
 */
function sanitizeToolParametersForLlamaCpp(toolName: string, parameters: unknown): unknown {
  const isWebSearch = toolName === 'web_search';
  const isObjectLike = typeof parameters === 'object' && parameters !== null;
  if (!isWebSearch || !isObjectLike) {
    return parameters;
  }

  const { properties } = parameters as { properties?: Record<string, unknown> };
  const queryProperty = properties?.query;

  // Without a usable `query` definition there is nothing to reduce to.
  if (!queryProperty) {
    return parameters;
  }

  return {
    type: 'object',
    properties: { query: queryProperty },
    required: ['query'],
  };
}
|
||||
|
||||
/** Message format for OpenAI-compatible chat completions API. */
|
||||
interface LlamaCppChatMessage {
|
||||
role: 'system' | 'user' | 'assistant' | 'tool';
|
||||
@@ -211,7 +241,7 @@ export class LlamaCppClient implements ModelClient {
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.input_schema,
|
||||
parameters: sanitizeToolParametersForLlamaCpp(t.name, t.input_schema),
|
||||
},
|
||||
}));
|
||||
}
|
||||
@@ -292,7 +322,7 @@ export class LlamaCppClient implements ModelClient {
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.input_schema,
|
||||
parameters: sanitizeToolParametersForLlamaCpp(t.name, t.input_schema),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
|
||||
import { OpenAIClient } from './openai.js';
|
||||
|
||||
vi.mock('../auth/openai.js', () => ({
|
||||
ensureValidOpenAIAuth: vi.fn(async () => ({
|
||||
access_token: 'at',
|
||||
refresh_token: 'rt',
|
||||
expires_at: Date.now() + 60_000,
|
||||
created_at: new Date().toISOString(),
|
||||
account_id: 'acct',
|
||||
})),
|
||||
}));
|
||||
|
||||
/**
 * Serializes events into a Server-Sent Events payload.
 *
 * Each event becomes an `event:` line followed by a `data:` line holding the
 * JSON-encoded payload, terminated by a blank line.
 */
function makeSse(events: Array<{ event: string; data: any }>): string {
  let payload = '';
  for (const { event, data } of events) {
    payload += `event: ${event}\ndata: ${JSON.stringify(data)}\n\n`;
  }
  return payload;
}
|
||||
|
||||
describe('OpenAIClient OAuth (Codex)', () => {
  // Keep the real fetch so the stub installed by a test can be undone.
  const originalFetch = globalThis.fetch;

  beforeEach(() => {
    vi.useFakeTimers();
  });

  afterEach(() => {
    globalThis.fetch = originalFetch;
    vi.useRealTimers();
    vi.restoreAllMocks();
  });

  it('streams SSE and accumulates output_text.delta', async () => {
    // Builds one text-delta event; the two deltas below should join into "hello".
    const textDelta = (delta: string) => ({
      event: 'response.output_text.delta',
      data: { type: 'response.output_text.delta', delta },
    });

    const sse = makeSse([
      { event: 'response.created', data: { type: 'response.created', response: { id: 'r1' } } },
      textDelta('hel'),
      textDelta('lo'),
      {
        event: 'response.completed',
        data: { type: 'response.completed', response: { usage: { input_tokens: 2, output_tokens: 2 } } },
      },
    ]);

    const fetchStub = vi.fn(async (_url: any, init?: any) => {
      // Check the request shape before answering.
      const parsed = JSON.parse(init.body);
      expect(parsed.store).toBe(false);
      expect(parsed.stream).toBe(true);
      expect(typeof parsed.instructions).toBe('string');
      expect(Array.isArray(parsed.input)).toBe(true);

      // Deliver the whole SSE payload as a single chunk, then end the stream.
      const stream = new ReadableStream({
        start(controller) {
          controller.enqueue(new TextEncoder().encode(sse));
          controller.close();
        },
      });

      return new Response(stream, { status: 200 });
    });
    globalThis.fetch = fetchStub as any;

    const client = new OpenAIClient({ model: 'gpt-5.3-codex', useOAuth: true });
    const resp = await client.chat({
      system: 'You are helpful.',
      messages: [{ role: 'user', content: 'hi' }],
    });

    expect(resp.content).toBe('hello');
    expect(resp.usage).toEqual({ inputTokens: 2, outputTokens: 2 });
  });
});
|
||||
@@ -1,23 +1,38 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { OpenAIClient } from './openai.js';
|
||||
|
||||
// Shared mock function so we can override per-test
|
||||
const mockCreate = vi.fn().mockResolvedValue({
|
||||
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
|
||||
usage: { prompt_tokens: 10, completion_tokens: 5 },
|
||||
});
|
||||
|
||||
vi.mock('openai', () => ({
|
||||
default: vi.fn().mockImplementation(() => ({
|
||||
const { mockCreate, mockOpenAIConstructor } = vi.hoisted(() => {
|
||||
const mockCreate = vi.fn().mockResolvedValue({
|
||||
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
|
||||
usage: { prompt_tokens: 10, completion_tokens: 5 },
|
||||
});
|
||||
const mockOpenAIConstructor = vi.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate,
|
||||
},
|
||||
},
|
||||
})),
|
||||
}));
|
||||
return { mockCreate, mockOpenAIConstructor };
|
||||
});
|
||||
|
||||
vi.mock('openai', () => ({
|
||||
default: mockOpenAIConstructor,
|
||||
}));
|
||||
|
||||
describe('OpenAIClient', () => {
|
||||
// No timeoutMs is supplied here, so the constructor's default is what the SDK should receive.
it('sets request timeout and disables SDK retries', () => {
  new OpenAIClient({ apiKey: 'test-key', model: 'gpt-4o' });

  const expectedSdkOptions = expect.objectContaining({
    timeout: 20_000,
    maxRetries: 0,
  });
  expect(mockOpenAIConstructor).toHaveBeenCalledWith(expectedSdkOptions);
});
|
||||
|
||||
it('sends messages and returns response', async () => {
|
||||
const client = new OpenAIClient({
|
||||
apiKey: 'test-key',
|
||||
|
||||
+151
-6
@@ -1,11 +1,16 @@
|
||||
import OpenAI from 'openai';
|
||||
import type { ChatRequest, ChatResponse, ModelClient, MessageContentPart } from './types.js';
|
||||
import type { ChatRequest, ChatResponse, ModelClient, MessageContentPart, TokenUsage } from './types.js';
|
||||
import { getMessageTextWithTools } from './media.js';
|
||||
import { ensureValidOpenAIAuth } from '../auth/openai.js';
|
||||
|
||||
/** Options accepted by the `OpenAIClient` constructor. */
export interface OpenAIClientConfig {
  /** Model identifier to use for requests. */
  model: string;
  /** API key handed to the OpenAI SDK. */
  apiKey?: string;
  /** Base URL handed to the OpenAI SDK. */
  baseURL?: string;
  /** Default maximum token count for this client. */
  maxTokens?: number;
  /** Request timeout in milliseconds handed to the OpenAI SDK. */
  timeoutMs?: number;
  /** If true, use ChatGPT subscription OAuth via the Codex backend endpoint. */
  useOAuth?: boolean;
}
|
||||
|
||||
/**
|
||||
@@ -52,20 +57,160 @@ function toOpenAIContent(content: string | MessageContentPart[]): string | OpenA
|
||||
}
|
||||
|
||||
export class OpenAIClient implements ModelClient {
|
||||
private client: OpenAI;
|
||||
private client?: OpenAI;
|
||||
private model: string;
|
||||
private defaultMaxTokens: number;
|
||||
private useOAuth: boolean;
|
||||
|
||||
constructor(config: OpenAIClientConfig) {
|
||||
this.client = new OpenAI({
|
||||
apiKey: config.apiKey,
|
||||
baseURL: config.baseURL,
|
||||
});
|
||||
const timeoutMs = config.timeoutMs ?? 20_000;
|
||||
this.useOAuth = Boolean(config.useOAuth);
|
||||
|
||||
// OAuth mode uses a different backend (ChatGPT Codex) and a different API shape.
|
||||
// Only initialize the OpenAI SDK for API-key providers.
|
||||
if (!this.useOAuth) {
|
||||
this.client = new OpenAI({
|
||||
apiKey: config.apiKey,
|
||||
baseURL: config.baseURL,
|
||||
timeout: timeoutMs,
|
||||
maxRetries: 0,
|
||||
});
|
||||
}
|
||||
this.model = config.model;
|
||||
this.defaultMaxTokens = config.maxTokens ?? 4096;
|
||||
}
|
||||
|
||||
/**
 * Sends a chat request to the ChatGPT Codex backend using OAuth credentials
 * and collects the streamed reply.
 *
 * The endpoint always answers with Server-Sent Events. Text arrives in
 * `response.output_text.delta` events, token usage arrives on
 * `response.completed`, and a `response.failed` event aborts the call.
 *
 * @param request - Chat request; only `system` and `messages` are forwarded
 *   (tools and token limits are intentionally not sent).
 * @returns The concatenated output text with `stopReason` fixed to
 *   `'end_turn'`; usage falls back to zeros when the stream reports none.
 * @throws Error when the HTTP status is not OK, the response has no body, or
 *   the stream emits `response.failed`.
 */
private async chatViaOAuthCodex(request: ChatRequest): Promise<ChatResponse> {
  const CODEX_API_ENDPOINT = 'https://chatgpt.com/backend-api/codex/responses';

  /** The subset of a Codex stream event that this method reads. */
  type CodexStreamEvent = {
    type?: unknown;
    delta?: unknown;
    response?: {
      usage?: { input_tokens?: number; output_tokens?: number };
      error?: { message?: string };
    };
  };

  const auth = await ensureValidOpenAIAuth();

  // Codex endpoint requires:
  // - instructions (non-empty)
  // - input must be a list
  // - store must be false
  // - stream must be true (SSE)
  const instructions = (request.system ?? '').trim() || 'You are helpful.';

  // Messages that render to empty text are dropped rather than sent as blanks.
  const input = request.messages
    .map((m) => {
      const text = getMessageTextWithTools(m);
      if (!text) {return null;}
      return {
        role: m.role,
        content: [{ type: 'input_text', text }],
      };
    })
    .filter((x): x is NonNullable<typeof x> => Boolean(x));

  const body = {
    model: this.model,
    instructions,
    store: false,
    stream: true,
    input,
    // Intentionally omit max_output_tokens: Codex endpoint rejects it.
    // Also omit tools/tool_choice for now.
  };

  const headers: Record<string, string> = {
    'content-type': 'application/json',
    'authorization': `Bearer ${auth.access_token}`,
    'originator': 'flynn',
    'user-agent': 'flynn/0.1',
    'session_id': `flynn-${Date.now()}`,
  };
  if (auth.account_id) {
    headers['ChatGPT-Account-Id'] = auth.account_id;
  }

  const res = await fetch(CODEX_API_ENDPOINT, {
    method: 'POST',
    headers,
    body: JSON.stringify(body),
  });

  if (!res.ok) {
    const text = await res.text();
    throw new Error(`${res.status} ${res.statusText}${text ? `: ${text}` : ''}`);
  }

  if (!res.body) {
    throw new Error('OpenAI OAuth request failed: missing response body');
  }

  let buffer = '';
  let outputText = '';
  let usage: TokenUsage | undefined;

  const reader = res.body.getReader();
  // A streaming decoder keeps multi-byte UTF-8 characters intact when their
  // bytes are split across two network chunks; decoding each chunk on its own
  // would turn the halves into replacement characters.
  const decoder = new TextDecoder('utf-8');

  // Handles one SSE event block (the text between two blank lines).
  const processBlock = (block: string): void => {
    const dataLines: string[] = [];
    for (const line of block.split('\n')) {
      if (line.startsWith('data:')) {
        dataLines.push(line.slice('data:'.length).trim());
      }
    }
    // The SSE format joins multiple data lines of one event with a newline.
    const data = dataLines.join('\n');
    if (!data) {return;}

    let parsed: unknown;
    try {
      parsed = JSON.parse(data);
    } catch {
      // Non-JSON payloads carry nothing this method understands; skip them.
      return;
    }
    // JSON such as `null` or a bare number has no fields to inspect.
    if (typeof parsed !== 'object' || parsed === null) {return;}
    const event = parsed as CodexStreamEvent;

    if (event.type === 'response.output_text.delta' && typeof event.delta === 'string') {
      outputText += event.delta;
    }

    if (event.type === 'response.completed') {
      const u = event.response?.usage;
      if (u) {
        usage = {
          inputTokens: u.input_tokens ?? 0,
          outputTokens: u.output_tokens ?? 0,
        };
      }
    }

    if (event.type === 'response.failed') {
      const detail = event.response?.error?.message ?? 'OpenAI OAuth response failed';
      throw new Error(detail);
    }
  };

  try {
    while (true) {
      const { value, done } = await reader.read();
      if (done) {break;}

      // Normalise CRLF so events framed with "\r\n\r\n" are split as well.
      // Re-running this on the retained buffer also repairs a "\r\n" pair
      // that was cut in half by a chunk boundary.
      buffer = (buffer + decoder.decode(value, { stream: true })).replace(/\r\n/g, '\n');

      let boundary = buffer.indexOf('\n\n');
      while (boundary !== -1) {
        const block = buffer.slice(0, boundary);
        buffer = buffer.slice(boundary + 2);
        processBlock(block);
        boundary = buffer.indexOf('\n\n');
      }
    }
    // Anything left in the buffer is an event without its terminating blank
    // line; the SSE format says such incomplete events are discarded.
  } catch (err) {
    // Stop the HTTP stream so a failed call does not keep the connection open.
    try {
      await reader.cancel();
    } catch {
      // The stream may already be errored or closed; the original error wins.
    }
    throw err;
  }

  return {
    content: outputText,
    stopReason: 'end_turn',
    usage: usage ?? { inputTokens: 0, outputTokens: 0 },
  };
}
|
||||
|
||||
async chat(request: ChatRequest): Promise<ChatResponse> {
|
||||
if (this.useOAuth) {
|
||||
return this.chatViaOAuthCodex(request);
|
||||
}
|
||||
|
||||
if (!this.client) {
|
||||
throw new Error('OpenAI client not initialized');
|
||||
}
|
||||
|
||||
const messages: OpenAI.ChatCompletionMessageParam[] = [];
|
||||
|
||||
if (request.system) {
|
||||
|
||||
@@ -4,10 +4,15 @@ import type { RetryConfig } from './retry.js';
|
||||
|
||||
describe('isRetryable', () => {
|
||||
it('returns true for generic errors', () => {
|
||||
const error = new Error('Connection timeout');
|
||||
const error = new Error('Connection reset by peer');
|
||||
expect(isRetryable(error, DEFAULT_RETRY_CONFIG.nonRetryablePatterns)).toBe(true);
|
||||
});
|
||||
|
||||
// A "timed out" message should match the default non-retryable patterns.
it('returns false for timeout errors', () => {
  const timedOut = new Error('Request timed out after 20000ms');
  const retryable = isRetryable(timedOut, DEFAULT_RETRY_CONFIG.nonRetryablePatterns);
  expect(retryable).toBe(false);
});
|
||||
|
||||
it('returns false for authentication errors', () => {
|
||||
const error = new Error('Invalid API key: authentication failed');
|
||||
expect(isRetryable(error, DEFAULT_RETRY_CONFIG.nonRetryablePatterns)).toBe(false);
|
||||
@@ -75,8 +80,8 @@ describe('withRetry', () => {
|
||||
|
||||
it('retries on transient failure then succeeds', async () => {
|
||||
const fn = vi.fn()
|
||||
.mockRejectedValueOnce(new Error('timeout'))
|
||||
.mockRejectedValueOnce(new Error('timeout'))
|
||||
.mockRejectedValueOnce(new Error('temporary network issue'))
|
||||
.mockRejectedValueOnce(new Error('temporary network issue'))
|
||||
.mockResolvedValueOnce('recovered');
|
||||
|
||||
const result = await withRetry(fn, fastConfig, 'test-op');
|
||||
|
||||
@@ -26,6 +26,9 @@ export const DEFAULT_RETRY_CONFIG: RetryConfig = {
|
||||
'context_length_exceeded',
|
||||
'content_policy',
|
||||
'does not support',
|
||||
'timeout',
|
||||
'timed out',
|
||||
'request aborted',
|
||||
],
|
||||
};
|
||||
|
||||
|
||||
@@ -438,4 +438,29 @@ describe('setClient and labels', () => {
|
||||
expect(newFastClient!.chat).toHaveBeenCalledTimes(1);
|
||||
expect(initialFastClient!.chat).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
// With the tier marked strict, the primary failure must surface directly
// and the fallback client must never be consulted.
it('strict tier mode disables fallback chain for that tier', async () => {
  const primaryChat = vi.fn().mockRejectedValue(new Error('primary failed'));
  const fallbackChat = vi.fn().mockResolvedValue({
    content: 'fallback',
    stopReason: 'end_turn',
    usage: { inputTokens: 1, outputTokens: 1 },
  });

  const failingDefault = { chat: primaryChat } as unknown as ModelClient;
  const fallback = { chat: fallbackChat } as unknown as ModelClient;

  const router = new ModelRouter({
    default: failingDefault,
    fallbackChain: [fallback],
  });
  router.setTierStrict('default', true);

  const pending = router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'default');
  await expect(pending).rejects.toThrow('primary failed');

  expect(fallback.chat).not.toHaveBeenCalled();
  expect(router.isTierStrict('default')).toBe(true);
});
|
||||
});
|
||||
|
||||
@@ -27,6 +27,7 @@ export class ModelRouter implements ModelClient {
|
||||
private localProviderName?: string;
|
||||
private retryConfig?: RetryConfig;
|
||||
private tierChangeListeners: Array<(tier: ModelTier) => void> = [];
|
||||
private strictTiers: Set<ModelTier> = new Set();
|
||||
|
||||
constructor(config: ModelRouterConfig) {
|
||||
this.clients = new Map();
|
||||
@@ -97,6 +98,10 @@ export class ModelRouter implements ModelClient {
|
||||
logger.debug(`Primary model failed: ${errors[0].message}`);
|
||||
}
|
||||
|
||||
if (this.strictTiers.has(useTier)) {
|
||||
throw errors[0];
|
||||
}
|
||||
|
||||
// Try tier-specific fallbacks first
|
||||
const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
|
||||
for (let i = 0; i < tierFallbackList.length; i++) {
|
||||
@@ -150,6 +155,11 @@ export class ModelRouter implements ModelClient {
|
||||
primaryError = 'Primary client does not support streaming';
|
||||
}
|
||||
|
||||
if (this.strictTiers.has(useTier)) {
|
||||
yield { type: 'error', error: new Error(primaryError ?? 'Primary model failed') };
|
||||
return;
|
||||
}
|
||||
|
||||
// Try tier-specific fallbacks first
|
||||
const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
|
||||
for (let i = 0; i < tierFallbackList.length; i++) {
|
||||
@@ -216,6 +226,18 @@ export class ModelRouter implements ModelClient {
|
||||
this.labels.set(tier, label);
|
||||
}
|
||||
|
||||
/**
 * Enables or disables strict mode for a tier by adding it to, or removing it
 * from, the set of strict tiers.
 *
 * @param tier - Tier whose strict flag is being changed.
 * @param strict - `true` to mark the tier strict, `false` to clear the mark.
 */
setTierStrict(tier: ModelTier, strict: boolean): void {
  if (!strict) {
    this.strictTiers.delete(tier);
  } else {
    this.strictTiers.add(tier);
  }
}
|
||||
|
||||
/** Reports whether `tier` is currently in the set of strict tiers. */
isTierStrict(tier: ModelTier): boolean {
  const strict = this.strictTiers.has(tier);
  return strict;
}
|
||||
|
||||
/** Looks up the label stored for `tier`, yielding `'unknown'` when none is set. */
getLabel(tier: ModelTier): string {
  const label = this.labels.get(tier);
  return label ?? 'unknown';
}
|
||||
|
||||
Reference in New Issue
Block a user