feat: add OpenAI OAuth, strict model overrides, and Gmail pull mode
This commit is contained in:
@@ -7,3 +7,16 @@ export {
|
||||
loginGitHub,
|
||||
type DeviceCodeResponse,
|
||||
} from './github.js';
|
||||
|
||||
export {
|
||||
loadStoredOpenAIAuth,
|
||||
storeOpenAIAuth,
|
||||
clearOpenAIAuth,
|
||||
refreshOpenAIAuth,
|
||||
ensureValidOpenAIAuth,
|
||||
loginOpenAI,
|
||||
parseJwtClaims,
|
||||
extractAccountId,
|
||||
type OpenAIOAuthInfo,
|
||||
type IdTokenClaims,
|
||||
} from './openai.js';
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
|
||||
import { parseJwtClaims, extractAccountId } from './openai.js';
|
||||
|
||||
function base64UrlEncode(obj: unknown): string {
|
||||
return Buffer.from(JSON.stringify(obj)).toString('base64url');
|
||||
}
|
||||
|
||||
function makeJwt(payload: Record<string, unknown>): string {
|
||||
const header = base64UrlEncode({ alg: 'none', typ: 'JWT' });
|
||||
const body = base64UrlEncode(payload);
|
||||
// Signature is ignored by parseJwtClaims.
|
||||
return `${header}.${body}.sig`;
|
||||
}
|
||||
|
||||
describe('OpenAI OAuth helpers', () => {
|
||||
it('parseJwtClaims returns undefined for non-jwt strings', () => {
|
||||
expect(parseJwtClaims('not-a-jwt')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('parseJwtClaims parses base64url payload', () => {
|
||||
const token = makeJwt({ chatgpt_account_id: 'acct_123' });
|
||||
const claims = parseJwtClaims(token);
|
||||
expect(claims?.chatgpt_account_id).toBe('acct_123');
|
||||
});
|
||||
|
||||
it('extractAccountId prefers chatgpt_account_id', () => {
|
||||
const tokens = {
|
||||
access_token: makeJwt({ chatgpt_account_id: 'acct_a' }),
|
||||
refresh_token: 'rt',
|
||||
id_token: makeJwt({ chatgpt_account_id: 'acct_b' }),
|
||||
};
|
||||
expect(extractAccountId(tokens)).toBe('acct_b');
|
||||
});
|
||||
|
||||
it('extractAccountId falls back to organizations[0].id', () => {
|
||||
const tokens = {
|
||||
access_token: makeJwt({ organizations: [{ id: 'org_1' }] }),
|
||||
refresh_token: 'rt',
|
||||
};
|
||||
expect(extractAccountId(tokens)).toBe('org_1');
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,281 @@
|
||||
import { readFileSync, writeFileSync, mkdirSync, chmodSync } from 'fs';
|
||||
import { resolve } from 'path';
|
||||
import { homedir } from 'os';
|
||||
|
||||
const ISSUER = 'https://auth.openai.com';
|
||||
const CLIENT_ID = 'app_EMoamEEZ73f0CkXaXp7hrann';
|
||||
const DEVICE_URL = `${ISSUER}/codex/device`;
|
||||
const DEVICE_CODE_URL = `${ISSUER}/api/accounts/deviceauth/usercode`;
|
||||
const DEVICE_TOKEN_URL = `${ISSUER}/api/accounts/deviceauth/token`;
|
||||
const TOKEN_URL = `${ISSUER}/oauth/token`;
|
||||
|
||||
const POLLING_SAFETY_MARGIN_MS = 3000;
|
||||
const REFRESH_SAFETY_MARGIN_MS = 30_000;
|
||||
|
||||
const AUTH_DIR = resolve(homedir(), '.config/flynn');
|
||||
const AUTH_FILE = resolve(AUTH_DIR, 'auth.json');
|
||||
|
||||
export interface OpenAIOAuthInfo {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
/** Epoch millis. */
|
||||
expires_at: number;
|
||||
/** Optional account/org id used for subscription routing. */
|
||||
account_id?: string;
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
interface AuthStore {
|
||||
// Leave github entry untyped here so this module does not depend on github.ts.
|
||||
github?: unknown;
|
||||
openai?: OpenAIOAuthInfo;
|
||||
}
|
||||
|
||||
interface DeviceAuthResponse {
|
||||
device_auth_id: string;
|
||||
user_code: string;
|
||||
interval: string;
|
||||
}
|
||||
|
||||
interface DeviceTokenResponse {
|
||||
authorization_code: string;
|
||||
code_verifier: string;
|
||||
}
|
||||
|
||||
interface TokenResponse {
|
||||
id_token?: string;
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in?: number;
|
||||
}
|
||||
|
||||
export interface IdTokenClaims {
|
||||
chatgpt_account_id?: string;
|
||||
organizations?: Array<{ id: string }>;
|
||||
'https://api.openai.com/auth'?: {
|
||||
chatgpt_account_id?: string;
|
||||
};
|
||||
}
|
||||
|
||||
function safeJsonParse<T>(raw: string): T | null {
|
||||
try {
|
||||
return JSON.parse(raw) as T;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function readAuthStore(): AuthStore {
|
||||
try {
|
||||
const raw = readFileSync(AUTH_FILE, 'utf-8');
|
||||
const parsed = safeJsonParse<AuthStore>(raw);
|
||||
return parsed ?? {};
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function writeAuthStore(store: AuthStore): void {
|
||||
mkdirSync(AUTH_DIR, { recursive: true });
|
||||
writeFileSync(AUTH_FILE, JSON.stringify(store, null, 2) + '\n', 'utf-8');
|
||||
chmodSync(AUTH_FILE, 0o600);
|
||||
}
|
||||
|
||||
export function loadStoredOpenAIAuth(): OpenAIOAuthInfo | null {
|
||||
const store = readAuthStore();
|
||||
return store.openai ?? null;
|
||||
}
|
||||
|
||||
export function storeOpenAIAuth(info: OpenAIOAuthInfo): void {
|
||||
const store = readAuthStore();
|
||||
store.openai = info;
|
||||
writeAuthStore(store);
|
||||
}
|
||||
|
||||
export function clearOpenAIAuth(): void {
|
||||
const store = readAuthStore();
|
||||
delete store.openai;
|
||||
writeAuthStore(store);
|
||||
}
|
||||
|
||||
export function parseJwtClaims(token: string): IdTokenClaims | undefined {
|
||||
const parts = token.split('.');
|
||||
if (parts.length !== 3) {return undefined;}
|
||||
try {
|
||||
return JSON.parse(Buffer.from(parts[1], 'base64url').toString()) as IdTokenClaims;
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
function extractAccountIdFromClaims(claims: IdTokenClaims): string | undefined {
|
||||
return claims.chatgpt_account_id
|
||||
?? claims['https://api.openai.com/auth']?.chatgpt_account_id
|
||||
?? claims.organizations?.[0]?.id;
|
||||
}
|
||||
|
||||
export function extractAccountId(tokens: TokenResponse): string | undefined {
|
||||
const idToken = tokens.id_token;
|
||||
if (idToken) {
|
||||
const claims = parseJwtClaims(idToken);
|
||||
const id = claims && extractAccountIdFromClaims(claims);
|
||||
if (id) {return id;}
|
||||
}
|
||||
const accessToken = tokens.access_token;
|
||||
if (accessToken) {
|
||||
const claims = parseJwtClaims(accessToken);
|
||||
return claims ? extractAccountIdFromClaims(claims) : undefined;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async function requestDeviceAuth(): Promise<DeviceAuthResponse> {
|
||||
const response = await fetch(DEVICE_CODE_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': 'flynn',
|
||||
},
|
||||
body: JSON.stringify({ client_id: CLIENT_ID }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const body = await response.text();
|
||||
throw new Error(`OpenAI device auth start failed (${response.status}): ${body}`);
|
||||
}
|
||||
|
||||
return response.json() as Promise<DeviceAuthResponse>;
|
||||
}
|
||||
|
||||
async function pollDeviceToken(deviceAuthId: string, userCode: string, intervalMs: number): Promise<DeviceTokenResponse> {
|
||||
while (true) {
|
||||
await new Promise(r => setTimeout(r, intervalMs + POLLING_SAFETY_MARGIN_MS));
|
||||
|
||||
const response = await fetch(DEVICE_TOKEN_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': 'flynn',
|
||||
},
|
||||
body: JSON.stringify({ device_auth_id: deviceAuthId, user_code: userCode }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return response.json() as Promise<DeviceTokenResponse>;
|
||||
}
|
||||
|
||||
// OpenCode treats 403/404 as "pending".
|
||||
if (response.status === 403 || response.status === 404) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const body = await response.text();
|
||||
throw new Error(`OpenAI device auth token failed (${response.status}): ${body}`);
|
||||
}
|
||||
}
|
||||
|
||||
async function exchangeAuthorizationCode(authCode: string, codeVerifier: string): Promise<TokenResponse> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'User-Agent': 'flynn',
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'authorization_code',
|
||||
code: authCode,
|
||||
redirect_uri: `${ISSUER}/deviceauth/callback`,
|
||||
client_id: CLIENT_ID,
|
||||
code_verifier: codeVerifier,
|
||||
}).toString(),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const body = await response.text();
|
||||
throw new Error(`OpenAI token exchange failed (${response.status}): ${body}`);
|
||||
}
|
||||
|
||||
return response.json() as Promise<TokenResponse>;
|
||||
}
|
||||
|
||||
export async function refreshOpenAIAuth(refreshToken: string): Promise<TokenResponse> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'User-Agent': 'flynn',
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
grant_type: 'refresh_token',
|
||||
refresh_token: refreshToken,
|
||||
client_id: CLIENT_ID,
|
||||
}).toString(),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const body = await response.text();
|
||||
throw new Error(`OpenAI token refresh failed (${response.status}): ${body}`);
|
||||
}
|
||||
|
||||
return response.json() as Promise<TokenResponse>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure we have a valid (non-expired) OpenAI OAuth access token.
|
||||
* Refreshes and persists the token if needed.
|
||||
*/
|
||||
export async function ensureValidOpenAIAuth(): Promise<OpenAIOAuthInfo> {
|
||||
const current = loadStoredOpenAIAuth();
|
||||
if (!current) {
|
||||
throw new Error('OpenAI OAuth is not configured. Run `flynn openai-auth` to authenticate.');
|
||||
}
|
||||
|
||||
if (current.expires_at > Date.now() + REFRESH_SAFETY_MARGIN_MS) {
|
||||
return current;
|
||||
}
|
||||
|
||||
const refreshed = await refreshOpenAIAuth(current.refresh_token);
|
||||
const expiresAt = Date.now() + (refreshed.expires_in ?? 3600) * 1000;
|
||||
const accountId = extractAccountId(refreshed) ?? current.account_id;
|
||||
|
||||
const updated: OpenAIOAuthInfo = {
|
||||
access_token: refreshed.access_token,
|
||||
refresh_token: refreshed.refresh_token,
|
||||
expires_at: expiresAt,
|
||||
account_id: accountId,
|
||||
created_at: current.created_at,
|
||||
};
|
||||
|
||||
storeOpenAIAuth(updated);
|
||||
return updated;
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the OpenAI Codex device flow interactively.
|
||||
* @param onPrompt Callback to display the user code and verification URL to the user.
|
||||
*/
|
||||
export async function loginOpenAI(
|
||||
onPrompt: (userCode: string, verificationUri: string) => void,
|
||||
): Promise<OpenAIOAuthInfo> {
|
||||
const device = await requestDeviceAuth();
|
||||
const intervalMs = Math.max(parseInt(device.interval) || 5, 1) * 1000;
|
||||
|
||||
onPrompt(device.user_code, DEVICE_URL);
|
||||
|
||||
const deviceToken = await pollDeviceToken(device.device_auth_id, device.user_code, intervalMs);
|
||||
const tokens = await exchangeAuthorizationCode(deviceToken.authorization_code, deviceToken.code_verifier);
|
||||
|
||||
const expiresAt = Date.now() + (tokens.expires_in ?? 3600) * 1000;
|
||||
const accountId = extractAccountId(tokens);
|
||||
|
||||
const info: OpenAIOAuthInfo = {
|
||||
access_token: tokens.access_token,
|
||||
refresh_token: tokens.refresh_token,
|
||||
expires_at: expiresAt,
|
||||
...(accountId ? { account_id: accountId } : {}),
|
||||
created_at: new Date().toISOString(),
|
||||
};
|
||||
|
||||
storeOpenAIAuth(info);
|
||||
return info;
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { homedir } from 'os';
|
||||
import { GmailWatcher } from './gmail.js';
|
||||
import type { GmailWatcher as GmailWatcherType } from './gmail.js';
|
||||
import type { OutboundMessage } from '../channels/types.js';
|
||||
|
||||
// Mock googleapis module
|
||||
@@ -74,6 +74,23 @@ vi.mock('googleapis', () => {
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('@google-cloud/pubsub', () => {
|
||||
const pull = vi.fn().mockResolvedValue([{ receivedMessages: [] }]);
|
||||
const acknowledge = vi.fn().mockResolvedValue([{}]);
|
||||
const close = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
class SubscriberClient {
|
||||
pull = pull;
|
||||
acknowledge = acknowledge;
|
||||
close = close;
|
||||
}
|
||||
|
||||
return {
|
||||
v1: { SubscriberClient },
|
||||
_mocks: { pull, acknowledge, close },
|
||||
};
|
||||
});
|
||||
|
||||
// Mock fs operations
|
||||
vi.mock('fs', async () => {
|
||||
const actual = await vi.importActual<typeof import('fs')>('fs');
|
||||
@@ -86,6 +103,7 @@ vi.mock('fs', async () => {
|
||||
installed: {
|
||||
client_id: 'test-client-id',
|
||||
client_secret: 'test-client-secret',
|
||||
project_id: 'test-project',
|
||||
redirect_uris: ['http://localhost'],
|
||||
},
|
||||
});
|
||||
@@ -108,6 +126,9 @@ function createMockConfig(overrides = {}) {
|
||||
enabled: true,
|
||||
credentials_file: '~/.config/flynn/gmail-credentials.json',
|
||||
token_file: '~/.config/flynn/gmail-token.json',
|
||||
disable_push: false,
|
||||
pubsub_pull_interval: '60s',
|
||||
pubsub_max_messages: 10,
|
||||
watch_labels: ['INBOX'],
|
||||
poll_interval: '300s',
|
||||
output: {
|
||||
@@ -128,12 +149,16 @@ function createMockChannelLookup() {
|
||||
}
|
||||
|
||||
describe('GmailWatcher', () => {
|
||||
let watcher: GmailWatcher;
|
||||
let GmailWatcher: typeof GmailWatcherType;
|
||||
let watcher: GmailWatcherType;
|
||||
let channelLookup: ReturnType<typeof createMockChannelLookup>;
|
||||
|
||||
beforeEach(() => {
|
||||
beforeEach(async () => {
|
||||
vi.useFakeTimers();
|
||||
channelLookup = createMockChannelLookup();
|
||||
|
||||
// Import after mocks so ESM named imports (fs/googleapis) are properly mocked.
|
||||
({ GmailWatcher } = await import('./gmail.js'));
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
@@ -154,6 +179,60 @@ describe('GmailWatcher', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('push topic resolution', () => {
|
||||
it('returns null when pubsub_topic is not set', () => {
|
||||
const config = createMockConfig();
|
||||
watcher = new GmailWatcher(config, channelLookup);
|
||||
const topic = (watcher as unknown as { resolvePubSubTopicName: () => string | null }).resolvePubSubTopicName();
|
||||
expect(topic).toBe(null);
|
||||
});
|
||||
|
||||
it('expands shorthand topic id when project_id is known', () => {
|
||||
const config = createMockConfig({ pubsub_topic: 'my-topic' });
|
||||
watcher = new GmailWatcher(config, channelLookup);
|
||||
(watcher as unknown as { googleProjectId: string }).googleProjectId = 'test-project';
|
||||
|
||||
const topic = (watcher as unknown as { resolvePubSubTopicName: () => string | null }).resolvePubSubTopicName();
|
||||
expect(topic).toBe('projects/test-project/topics/my-topic');
|
||||
});
|
||||
|
||||
it('rejects invalid pubsub_topic formats', () => {
|
||||
const config = createMockConfig({ pubsub_topic: 'projects/test-project/topic/my-topic' });
|
||||
watcher = new GmailWatcher(config, channelLookup);
|
||||
|
||||
expect(() => {
|
||||
(watcher as unknown as { resolvePubSubTopicName: () => string | null }).resolvePubSubTopicName();
|
||||
}).toThrow(/Invalid pubsub_topic/);
|
||||
});
|
||||
});
|
||||
|
||||
describe('pull subscription resolution', () => {
|
||||
it('returns null when pubsub_subscription_id is not set', () => {
|
||||
const config = createMockConfig();
|
||||
watcher = new GmailWatcher(config, channelLookup);
|
||||
const sub = (watcher as unknown as { resolvePubSubSubscriptionName: () => string | null }).resolvePubSubSubscriptionName();
|
||||
expect(sub).toBe(null);
|
||||
});
|
||||
|
||||
it('expands shorthand subscription id when project_id is known', () => {
|
||||
const config = createMockConfig({ pubsub_subscription_id: 'my-sub' });
|
||||
watcher = new GmailWatcher(config, channelLookup);
|
||||
(watcher as unknown as { googleProjectId: string }).googleProjectId = 'test-project';
|
||||
|
||||
const sub = (watcher as unknown as { resolvePubSubSubscriptionName: () => string | null }).resolvePubSubSubscriptionName();
|
||||
expect(sub).toBe('projects/test-project/subscriptions/my-sub');
|
||||
});
|
||||
|
||||
it('rejects invalid pubsub_subscription_id formats', () => {
|
||||
const config = createMockConfig({ pubsub_subscription_id: 'projects/test-project/subscription/my-sub' });
|
||||
watcher = new GmailWatcher(config, channelLookup);
|
||||
|
||||
expect(() => {
|
||||
(watcher as unknown as { resolvePubSubSubscriptionName: () => string | null }).resolvePubSubSubscriptionName();
|
||||
}).toThrow(/Invalid pubsub_subscription_id/);
|
||||
});
|
||||
});
|
||||
|
||||
describe('connect() with missing credentials', () => {
|
||||
it('logs warning and sets status to error when credentials_file is missing', async () => {
|
||||
const config = createMockConfig({ credentials_file: undefined });
|
||||
@@ -200,6 +279,78 @@ describe('GmailWatcher', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('push disable flag', () => {
|
||||
it('skips watch setup when disable_push is true', async () => {
|
||||
const config = createMockConfig({ disable_push: true, pubsub_topic: 'projects/test-project/topics/gmail-push' });
|
||||
watcher = new GmailWatcher(config, channelLookup);
|
||||
|
||||
const { existsSync, readFileSync } = await import('fs');
|
||||
vi.mocked(existsSync).mockReturnValue(true);
|
||||
vi.mocked(readFileSync).mockImplementation((path: unknown) => {
|
||||
const p = String(path);
|
||||
if (p.includes('credentials')) {
|
||||
return JSON.stringify({
|
||||
installed: {
|
||||
client_id: 'test-client-id',
|
||||
client_secret: 'test-client-secret',
|
||||
project_id: 'test-project',
|
||||
redirect_uris: ['http://localhost'],
|
||||
},
|
||||
});
|
||||
}
|
||||
return JSON.stringify({
|
||||
access_token: 'test-access-token',
|
||||
refresh_token: 'test-refresh-token',
|
||||
expiry_date: Date.now() + 3600000,
|
||||
});
|
||||
});
|
||||
|
||||
const googleapis = await import('googleapis') as unknown as {
|
||||
_mocks: {
|
||||
mockWatch: ReturnType<typeof vi.fn>;
|
||||
mockOAuth2: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
};
|
||||
googleapis._mocks.mockOAuth2.mockImplementation(() => ({
|
||||
setCredentials: vi.fn(),
|
||||
on: vi.fn(),
|
||||
}));
|
||||
const watchSpy = googleapis._mocks.mockWatch;
|
||||
|
||||
await watcher.connect();
|
||||
|
||||
expect(watchSpy).not.toHaveBeenCalled();
|
||||
expect(watcher.status).toBe('connected');
|
||||
});
|
||||
});
|
||||
|
||||
describe('pullSubscriptionMessages', () => {
|
||||
it('pulls messages and acknowledges successfully processed ones', async () => {
|
||||
const config = createMockConfig({ pubsub_subscription_id: 'projects/test-project/subscriptions/gmail-pull' });
|
||||
watcher = new GmailWatcher(config, channelLookup);
|
||||
|
||||
const { _mocks: pubsubMocks } = await import('@google-cloud/pubsub') as unknown as {
|
||||
_mocks: { pull: ReturnType<typeof vi.fn>; acknowledge: ReturnType<typeof vi.fn> };
|
||||
};
|
||||
const payload = { emailAddress: 'bob@example.com', historyId: '200' };
|
||||
pubsubMocks.pull.mockResolvedValueOnce([
|
||||
{
|
||||
receivedMessages: [
|
||||
{ ackId: 'ack-1', message: { data: Buffer.from(JSON.stringify(payload)) } },
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
await (watcher as unknown as { pullSubscriptionMessages: () => Promise<void> }).pullSubscriptionMessages();
|
||||
|
||||
expect((watcher as unknown as { lastHistoryId: string }).lastHistoryId).toBe('200');
|
||||
expect(pubsubMocks.acknowledge).toHaveBeenCalledWith({
|
||||
subscription: 'projects/test-project/subscriptions/gmail-pull',
|
||||
ackIds: ['ack-1'],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('renderTemplate', () => {
|
||||
it('replaces all placeholders correctly', () => {
|
||||
const config = createMockConfig({
|
||||
|
||||
+222
-11
@@ -2,6 +2,7 @@ import { google, type Auth } from 'googleapis';
|
||||
import { readFileSync, writeFileSync, existsSync, mkdirSync, chmodSync } from 'fs';
|
||||
import { dirname, resolve } from 'path';
|
||||
import { homedir } from 'os';
|
||||
import type { v1 } from '@google-cloud/pubsub';
|
||||
import type { GmailConfig } from '../config/schema.js';
|
||||
import type { ChannelAdapter, ChannelStatus, InboundMessage, OutboundMessage } from '../channels/types.js';
|
||||
import { parseInterval } from './heartbeat.js';
|
||||
@@ -30,9 +31,7 @@ interface PubSubNotification {
|
||||
historyId: string;
|
||||
}
|
||||
|
||||
// Google Cloud Pub/Sub topic for Gmail push notifications.
|
||||
// This must be pre-configured in Google Cloud Console.
|
||||
const GMAIL_PUBSUB_TOPIC = 'projects/flynn-agent/topics/gmail-push';
|
||||
const DEFAULT_TOPIC_ID = 'gmail-push';
|
||||
|
||||
// Watch expires after ~7 days; renew at 6 days (in ms).
|
||||
const WATCH_RENEWAL_MS = 6 * 24 * 60 * 60 * 1000;
|
||||
@@ -56,7 +55,11 @@ export class GmailWatcher implements ChannelAdapter {
|
||||
private lastHistoryId?: string;
|
||||
private pollTimer?: ReturnType<typeof setInterval>;
|
||||
private watchTimer?: ReturnType<typeof setInterval>;
|
||||
private pullTimer?: ReturnType<typeof setInterval>;
|
||||
private pubsubSubscriber?: v1.SubscriberClient;
|
||||
private pullInFlight = false;
|
||||
private readonly config: NonNullable<GmailConfig>;
|
||||
private googleProjectId?: string;
|
||||
|
||||
constructor(
|
||||
config: NonNullable<GmailConfig>,
|
||||
@@ -82,12 +85,28 @@ export class GmailWatcher implements ChannelAdapter {
|
||||
return;
|
||||
}
|
||||
|
||||
// Set up Gmail push watch (Pub/Sub)
|
||||
// Set up Gmail push watch (Pub/Sub). Polling is always enabled.
|
||||
if (!this.config.disable_push) {
|
||||
try {
|
||||
await this.setupWatch();
|
||||
} catch (error) {
|
||||
const errMsg = error instanceof Error ? error.message : 'Unknown error';
|
||||
const hint = this.buildWatchErrorHint(errMsg);
|
||||
console.warn(`GmailWatcher: Watch setup failed (will use polling only) — ${errMsg}${hint}`);
|
||||
}
|
||||
} else {
|
||||
const configured = (this.config.pubsub_topic ?? process.env.FLYNN_GMAIL_PUBSUB_TOPIC ?? '').trim();
|
||||
if (configured) {
|
||||
console.log('GmailWatcher: Push disabled (disable_push=true)');
|
||||
}
|
||||
}
|
||||
|
||||
// Set up Pub/Sub pull subscription (optional).
|
||||
try {
|
||||
await this.setupWatch();
|
||||
await this.setupPullSubscription();
|
||||
} catch (error) {
|
||||
const errMsg = error instanceof Error ? error.message : 'Unknown error';
|
||||
console.warn(`GmailWatcher: Watch setup failed (will use polling only) — ${errMsg}`);
|
||||
console.warn(`GmailWatcher: Pull setup failed (will continue without pull) — ${errMsg}`);
|
||||
}
|
||||
|
||||
// Start polling fallback
|
||||
@@ -99,8 +118,23 @@ export class GmailWatcher implements ChannelAdapter {
|
||||
}, pollMs);
|
||||
|
||||
this._status = 'connected';
|
||||
console.log(`GmailWatcher: Connected (poll_interval=${this.config.poll_interval ?? '300s'})`);
|
||||
auditLogger?.systemStart('GmailWatcher', { poll_interval: this.config.poll_interval });
|
||||
|
||||
const modes: string[] = [];
|
||||
const pushConfigured = Boolean((this.config.pubsub_topic ?? process.env.FLYNN_GMAIL_PUBSUB_TOPIC ?? '').trim());
|
||||
const pullConfigured = Boolean((this.config.pubsub_subscription_id ?? '').trim());
|
||||
if (pushConfigured && !this.config.disable_push) {modes.push('push');}
|
||||
if (pullConfigured) {modes.push('pull');}
|
||||
modes.push('poll');
|
||||
|
||||
console.log(
|
||||
`GmailWatcher: Connected (${modes.join('+')}, poll_interval=${this.config.poll_interval ?? '300s'}${pullConfigured ? `, pubsub_pull_interval=${this.config.pubsub_pull_interval ?? '60s'}` : ''})`,
|
||||
);
|
||||
auditLogger?.systemStart('GmailWatcher', {
|
||||
modes: modes.join('+'),
|
||||
poll_interval: this.config.poll_interval,
|
||||
pubsub_topic: pushConfigured ? 'configured' : 'none',
|
||||
pubsub_subscription_id: pullConfigured ? 'configured' : 'none',
|
||||
});
|
||||
}
|
||||
|
||||
async disconnect(): Promise<void> {
|
||||
@@ -109,9 +143,21 @@ export class GmailWatcher implements ChannelAdapter {
|
||||
this.pollTimer = undefined;
|
||||
}
|
||||
if (this.watchTimer) {
|
||||
clearTimeout(this.watchTimer);
|
||||
clearInterval(this.watchTimer);
|
||||
this.watchTimer = undefined;
|
||||
}
|
||||
if (this.pullTimer) {
|
||||
clearInterval(this.pullTimer);
|
||||
this.pullTimer = undefined;
|
||||
}
|
||||
if (this.pubsubSubscriber) {
|
||||
try {
|
||||
await this.pubsubSubscriber.close();
|
||||
} catch {
|
||||
// Ignore shutdown errors
|
||||
}
|
||||
this.pubsubSubscriber = undefined;
|
||||
}
|
||||
this.oauth2Client = undefined;
|
||||
this._status = 'disconnected';
|
||||
auditLogger?.systemStop('GmailWatcher');
|
||||
@@ -178,7 +224,10 @@ export class GmailWatcher implements ChannelAdapter {
|
||||
}
|
||||
|
||||
const credentials = JSON.parse(readFileSync(expandedCredsPath, 'utf-8'));
|
||||
const { client_id, client_secret, redirect_uris } = credentials.installed ?? credentials.web ?? {};
|
||||
const { client_id, client_secret, redirect_uris, project_id } = credentials.installed ?? credentials.web ?? {};
|
||||
if (project_id && typeof project_id === 'string') {
|
||||
this.googleProjectId = project_id;
|
||||
}
|
||||
|
||||
if (!client_id || !client_secret) {
|
||||
throw new Error('Invalid credentials file — missing client_id or client_secret');
|
||||
@@ -217,13 +266,24 @@ export class GmailWatcher implements ChannelAdapter {
|
||||
private async setupWatch(): Promise<void> {
|
||||
if (!this.oauth2Client) {return;}
|
||||
|
||||
if (this.watchTimer) {
|
||||
clearInterval(this.watchTimer);
|
||||
this.watchTimer = undefined;
|
||||
}
|
||||
|
||||
const topicName = this.resolvePubSubTopicName();
|
||||
if (!topicName) {
|
||||
// Push notifications are optional; polling is always enabled.
|
||||
return;
|
||||
}
|
||||
|
||||
const gmail = google.gmail({ version: 'v1', auth: this.oauth2Client });
|
||||
|
||||
const watchResponse = await gmail.users.watch({
|
||||
userId: 'me',
|
||||
requestBody: {
|
||||
labelIds: this.config.watch_labels ?? ['INBOX'],
|
||||
topicName: GMAIL_PUBSUB_TOPIC,
|
||||
topicName,
|
||||
},
|
||||
});
|
||||
|
||||
@@ -241,6 +301,157 @@ export class GmailWatcher implements ChannelAdapter {
|
||||
}, WATCH_RENEWAL_MS);
|
||||
}
|
||||
|
||||
private buildWatchErrorHint(errMsg: string): string {
|
||||
const hints: string[] = [];
|
||||
|
||||
if (errMsg.includes('Invalid topicName')) {
|
||||
hints.push(
|
||||
`Tip: set automation.gmail.pubsub_topic to "projects/${this.googleProjectId ?? '<project-id>'}/topics/${DEFAULT_TOPIC_ID}"`,
|
||||
);
|
||||
}
|
||||
|
||||
if (/permission denied|PERMISSION_DENIED/i.test(errMsg)) {
|
||||
hints.push('Tip: ensure Gmail has permission to publish to the Pub/Sub topic (IAM)');
|
||||
}
|
||||
|
||||
hints.push('Tip: if Google cannot reach your gateway, set automation.gmail.pubsub_subscription_id for pull mode');
|
||||
|
||||
return hints.length > 0 ? `\n ${hints.join('\n ')}` : '';
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve the Pub/Sub topic resource name for Gmail push notifications.
|
||||
*
|
||||
* Priority:
|
||||
* 1) automation.gmail.pubsub_topic
|
||||
* 2) FLYNN_GMAIL_PUBSUB_TOPIC env var
|
||||
* If neither is provided, push notifications are disabled.
|
||||
*/
|
||||
private resolvePubSubTopicName(): string | null {
|
||||
const configured = this.config.pubsub_topic ?? process.env.FLYNN_GMAIL_PUBSUB_TOPIC;
|
||||
let topic = (configured ?? '').trim();
|
||||
|
||||
if (!topic) {return null;}
|
||||
|
||||
// Allow shorthand: just the topic id (e.g. "gmail-push")
|
||||
if (!topic.includes('/')) {
|
||||
if (!this.googleProjectId) {
|
||||
throw new Error(
|
||||
`pubsub_topic '${topic}' must be fully qualified (projects/<project-id>/topics/<topic>) because project_id was not found in credentials`,
|
||||
);
|
||||
}
|
||||
topic = `projects/${this.googleProjectId}/topics/${topic}`;
|
||||
}
|
||||
|
||||
const isValid = /^projects\/[^/]+\/topics\/[^/]+$/.test(topic);
|
||||
if (!isValid) {
|
||||
throw new Error(
|
||||
`Invalid pubsub_topic '${topic}'. Expected: projects/<project-id>/topics/<topic>`,
|
||||
);
|
||||
}
|
||||
|
||||
return topic;
|
||||
}
|
||||
|
||||
private resolvePubSubSubscriptionName(): string | null {
|
||||
let sub = (this.config.pubsub_subscription_id ?? '').trim();
|
||||
if (!sub) {return null;}
|
||||
|
||||
// Allow shorthand: just the subscription id (e.g. "gmail-pull")
|
||||
if (!sub.includes('/')) {
|
||||
if (!this.googleProjectId) {
|
||||
throw new Error(
|
||||
`pubsub_subscription_id '${sub}' must be fully qualified (projects/<project-id>/subscriptions/<subscription>) because project_id was not found in credentials`,
|
||||
);
|
||||
}
|
||||
sub = `projects/${this.googleProjectId}/subscriptions/${sub}`;
|
||||
}
|
||||
|
||||
const isValid = /^projects\/[^/]+\/subscriptions\/[^/]+$/.test(sub);
|
||||
if (!isValid) {
|
||||
throw new Error(
|
||||
`Invalid pubsub_subscription_id '${sub}'. Expected: projects/<project-id>/subscriptions/<subscription>`,
|
||||
);
|
||||
}
|
||||
|
||||
return sub;
|
||||
}
|
||||
|
||||
private async setupPullSubscription(): Promise<void> {
|
||||
const subscriptionName = this.resolvePubSubSubscriptionName();
|
||||
if (!subscriptionName) {return;}
|
||||
|
||||
if (this.pullTimer) {
|
||||
clearInterval(this.pullTimer);
|
||||
this.pullTimer = undefined;
|
||||
}
|
||||
|
||||
const pullMs = parseInterval(this.config.pubsub_pull_interval ?? '60s');
|
||||
|
||||
// Kick once immediately, then on interval.
|
||||
await this.pullSubscriptionMessages().catch((err) => {
|
||||
console.error('GmailWatcher: Pub/Sub pull error —', err instanceof Error ? err.message : err);
|
||||
});
|
||||
|
||||
this.pullTimer = setInterval(() => {
|
||||
this.pullSubscriptionMessages().catch((err) => {
|
||||
console.error('GmailWatcher: Pub/Sub pull error —', err instanceof Error ? err.message : err);
|
||||
});
|
||||
}, pullMs);
|
||||
|
||||
console.log(
|
||||
`GmailWatcher: Pull enabled (subscription=${subscriptionName}, interval=${this.config.pubsub_pull_interval ?? '60s'})`,
|
||||
);
|
||||
}
|
||||
|
||||
private async getSubscriberClient(): Promise<v1.SubscriberClient> {
|
||||
if (this.pubsubSubscriber) {return this.pubsubSubscriber;}
|
||||
const mod = await import('@google-cloud/pubsub');
|
||||
this.pubsubSubscriber = new mod.v1.SubscriberClient();
|
||||
return this.pubsubSubscriber;
|
||||
}
|
||||
|
||||
private async pullSubscriptionMessages(): Promise<void> {
|
||||
const subscription = this.resolvePubSubSubscriptionName();
|
||||
if (!subscription) {return;}
|
||||
if (this.pullInFlight) {return;}
|
||||
this.pullInFlight = true;
|
||||
|
||||
try {
|
||||
const client = await this.getSubscriberClient();
|
||||
const maxMessages = this.config.pubsub_max_messages ?? 10;
|
||||
|
||||
const [response] = await client.pull({
|
||||
subscription,
|
||||
maxMessages,
|
||||
});
|
||||
|
||||
const received = response.receivedMessages ?? [];
|
||||
if (received.length === 0) {return;}
|
||||
|
||||
const ackIds: string[] = [];
|
||||
for (const receivedMessage of received) {
|
||||
const ackId = receivedMessage.ackId;
|
||||
const data = receivedMessage.message?.data;
|
||||
if (!ackId || !data) {continue;}
|
||||
|
||||
const base64 = Buffer.from(data as Uint8Array).toString('base64');
|
||||
try {
|
||||
await this.handlePushNotification(base64);
|
||||
ackIds.push(ackId);
|
||||
} catch {
|
||||
// If processing fails, leave message unacked for retry.
|
||||
}
|
||||
}
|
||||
|
||||
if (ackIds.length > 0) {
|
||||
await client.acknowledge({ subscription, ackIds });
|
||||
}
|
||||
} finally {
|
||||
this.pullInFlight = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Poll Gmail History API for new messages since lastHistoryId.
|
||||
* Fallback mechanism when Pub/Sub push is not available.
|
||||
|
||||
@@ -228,6 +228,35 @@ describe('TelegramAdapter', () => {
|
||||
expect(msg.metadata).toEqual({ isCommand: true, command: 'reset' });
|
||||
});
|
||||
|
||||
it('/model command strips @bot suffix in groups', async () => {
|
||||
const handler = vi.fn();
|
||||
adapter.onMessage(handler);
|
||||
|
||||
await adapter.connect();
|
||||
|
||||
// Find the /model command handler
|
||||
const modelCall = mockCommand.mock.calls.find((call) => call[0] === 'model');
|
||||
expect(modelCall).toBeDefined();
|
||||
const modelHandler = modelCall![1];
|
||||
|
||||
const ctx = {
|
||||
message: { message_id: 123, text: '/model@flynn_bot default github/gpt-5-mini' },
|
||||
chat: { id: 100 },
|
||||
from: { first_name: 'Will' },
|
||||
};
|
||||
|
||||
await modelHandler(ctx);
|
||||
|
||||
expect(handler).toHaveBeenCalledTimes(1);
|
||||
const msg: InboundMessage = handler.mock.calls[0][0];
|
||||
expect(msg.text).toBe('/model default github/gpt-5-mini');
|
||||
expect(msg.metadata).toEqual({
|
||||
isCommand: true,
|
||||
command: 'model',
|
||||
commandArgs: 'default github/gpt-5-mini',
|
||||
});
|
||||
});
|
||||
|
||||
// ── Auth middleware ───────────────────────────────────────────
|
||||
|
||||
it('auth middleware blocks unauthorized chat IDs', async () => {
|
||||
|
||||
@@ -166,7 +166,9 @@ export class TelegramAdapter implements ChannelAdapter {
|
||||
this.bot.command('model', async (ctx) => {
|
||||
if (!this.messageHandler) {return;}
|
||||
|
||||
const args = ctx.message?.text?.replace(/^\/model\s*/, '').trim() ?? '';
|
||||
// Telegram can deliver group commands in the form: /model@bot_username ...
|
||||
// Strip the optional @mention so args parsing is consistent across DMs/groups.
|
||||
const args = ctx.message?.text?.replace(/^\/model(?:@\S+)?\s*/i, '').trim() ?? '';
|
||||
|
||||
this.messageHandler({
|
||||
id: String(ctx.message?.message_id ?? Date.now()),
|
||||
@@ -439,15 +441,48 @@ export class TelegramAdapter implements ChannelAdapter {
|
||||
if (!this.bot) {throw new Error('Telegram adapter not connected');}
|
||||
|
||||
const chatId = Number(peerId);
|
||||
const text = message.text;
|
||||
const text = message.text ?? '';
|
||||
|
||||
// Telegram rejects empty text messages.
|
||||
// If there is no text, skip straight to attachments.
|
||||
if (!text.trim()) {
|
||||
if (message.attachments && message.attachments.length > 0) {
|
||||
for (const attachment of message.attachments) {
|
||||
await this.sendAttachment(chatId, attachment);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const sendChunk = async (chunk: string): Promise<void> => {
|
||||
// We default to Markdown for nicer formatting, but Telegram's Markdown parsing
|
||||
// is strict and can fail on unescaped characters. If Telegram rejects the
|
||||
// message, retry once without parse_mode so users still get the content.
|
||||
try {
|
||||
await this.bot!.api.sendMessage(chatId, chunk, { parse_mode: 'Markdown' });
|
||||
} catch (error) {
|
||||
const description = error && typeof error === 'object' && 'description' in error
|
||||
? String((error as { description?: unknown }).description)
|
||||
: '';
|
||||
|
||||
const isParseError = description.includes("can't parse entities")
|
||||
|| description.includes('message text is empty');
|
||||
|
||||
if (!isParseError) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
await this.bot!.api.sendMessage(chatId, chunk);
|
||||
}
|
||||
};
|
||||
|
||||
// Telegram enforces a 4096-character limit per message
|
||||
if (text.length <= 4096) {
|
||||
await this.bot.api.sendMessage(chatId, text, { parse_mode: 'Markdown' });
|
||||
await sendChunk(text);
|
||||
} else {
|
||||
const chunks = splitMessage(text, 4096);
|
||||
for (const chunk of chunks) {
|
||||
await this.bot.api.sendMessage(chatId, chunk, { parse_mode: 'Markdown' });
|
||||
await sendChunk(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -163,6 +163,75 @@ automation:
|
||||
expect(gmailCheck?.detail).toContain('flynn gmail-auth');
|
||||
});
|
||||
|
||||
it('reports PASS for Gmail when enabled (poll only)', async () => {
|
||||
mkdirSync(testDir, { recursive: true });
|
||||
const configPath = join(testDir, 'config.yaml');
|
||||
const credsPath = join(testDir, 'gmail-creds.json');
|
||||
const tokenPath = join(testDir, 'gmail-token.json');
|
||||
writeFileSync(credsPath, JSON.stringify({ installed: { project_id: 'test-project' } }));
|
||||
writeFileSync(tokenPath, JSON.stringify({ refresh_token: 'x' }));
|
||||
|
||||
writeFileSync(configPath, `
|
||||
telegram:
|
||||
bot_token: "test-token"
|
||||
allowed_chat_ids: [123]
|
||||
models:
|
||||
default:
|
||||
provider: anthropic
|
||||
model: claude-sonnet
|
||||
automation:
|
||||
gmail:
|
||||
enabled: true
|
||||
credentials_file: "${credsPath}"
|
||||
token_file: "${tokenPath}"
|
||||
output:
|
||||
channel: telegram
|
||||
peer: "123"
|
||||
`);
|
||||
|
||||
const ctx: DoctorContext = { configPath, dataDir: testDir };
|
||||
const results = await runChecks(ctx);
|
||||
|
||||
const gmailCheck = results.find(r => r.label.includes('Gmail configured'));
|
||||
expect(gmailCheck?.status).toBe('pass');
|
||||
expect(gmailCheck?.detail).toContain('poll');
|
||||
});
|
||||
|
||||
it('reports WARN for Gmail when pubsub_topic shorthand used without project_id', async () => {
|
||||
mkdirSync(testDir, { recursive: true });
|
||||
const configPath = join(testDir, 'config.yaml');
|
||||
const credsPath = join(testDir, 'gmail-creds.json');
|
||||
const tokenPath = join(testDir, 'gmail-token.json');
|
||||
writeFileSync(credsPath, '{}');
|
||||
writeFileSync(tokenPath, JSON.stringify({ refresh_token: 'x' }));
|
||||
|
||||
writeFileSync(configPath, `
|
||||
telegram:
|
||||
bot_token: "test-token"
|
||||
allowed_chat_ids: [123]
|
||||
models:
|
||||
default:
|
||||
provider: anthropic
|
||||
model: claude-sonnet
|
||||
automation:
|
||||
gmail:
|
||||
enabled: true
|
||||
credentials_file: "${credsPath}"
|
||||
token_file: "${tokenPath}"
|
||||
pubsub_topic: gmail-push
|
||||
output:
|
||||
channel: telegram
|
||||
peer: "123"
|
||||
`);
|
||||
|
||||
const ctx: DoctorContext = { configPath, dataDir: testDir };
|
||||
const results = await runChecks(ctx);
|
||||
|
||||
const gmailCheck = results.find(r => r.label.includes('Gmail configured'));
|
||||
expect(gmailCheck?.status).toBe('warn');
|
||||
expect(gmailCheck?.detail).toContain('pubsub_topic shorthand');
|
||||
});
|
||||
|
||||
it('skips downstream checks when config is invalid', async () => {
|
||||
const ctx: DoctorContext = { configPath: '/nonexistent/config.yaml', dataDir: testDir };
|
||||
const results = await runChecks(ctx);
|
||||
|
||||
+55
-2
@@ -137,7 +137,8 @@ const checkModelConnectivity: Check = async (ctx) => {
|
||||
|
||||
// Check if API key is present for providers that need one
|
||||
const needsKey = ['anthropic', 'openai', 'gemini', 'openrouter'];
|
||||
if (needsKey.includes(model.provider) && !model.api_key && !model.auth_token) {
|
||||
const openaiUsingOAuth = model.provider === 'openai' && Boolean((model as unknown as { use_oauth?: boolean }).use_oauth);
|
||||
if (needsKey.includes(model.provider) && !openaiUsingOAuth && !model.api_key && !model.auth_token) {
|
||||
const envVarMap: Record<string, string> = {
|
||||
anthropic: 'ANTHROPIC_API_KEY',
|
||||
openai: 'OPENAI_API_KEY',
|
||||
@@ -256,12 +257,64 @@ const checkGmail: Check = async (ctx) => {
|
||||
return { status: 'fail', label: 'Gmail configured', detail: `credentials file not found: ${credentialsPath}` };
|
||||
}
|
||||
|
||||
let googleProjectId: string | undefined;
|
||||
try {
|
||||
const creds = JSON.parse(readFileSync(credentialsPath, 'utf-8')) as Record<string, unknown>;
|
||||
const installed = (creds.installed as Record<string, unknown> | undefined) ?? (creds.web as Record<string, unknown> | undefined);
|
||||
const projectId = installed?.project_id;
|
||||
if (typeof projectId === 'string' && projectId.trim()) {
|
||||
googleProjectId = projectId.trim();
|
||||
}
|
||||
} catch {
|
||||
// Ignore JSON parse errors; doctor will still validate token and output.
|
||||
}
|
||||
|
||||
const tokenPath = expandPath(gmail.token_file ?? '~/.config/flynn/gmail-token.json');
|
||||
if (!existsSync(tokenPath)) {
|
||||
return { status: 'warn', label: 'Gmail configured', detail: 'run `flynn gmail-auth` to authenticate' };
|
||||
}
|
||||
|
||||
return { status: 'pass', label: 'Gmail configured', detail: `(output: ${gmail.output.channel}/${gmail.output.peer})` };
|
||||
const modes: string[] = [];
|
||||
const warnings: string[] = [];
|
||||
|
||||
const topicRaw = (gmail.pubsub_topic ?? process.env.FLYNN_GMAIL_PUBSUB_TOPIC ?? '').trim();
|
||||
const pushEnabled = Boolean(topicRaw) && !gmail.disable_push;
|
||||
if (pushEnabled) {
|
||||
modes.push('push');
|
||||
if (topicRaw.includes('/')) {
|
||||
const ok = /^projects\/[^/]+\/topics\/[^/]+$/.test(topicRaw);
|
||||
if (!ok) {
|
||||
warnings.push('pubsub_topic format invalid (expected projects/<project>/topics/<topic>)');
|
||||
}
|
||||
} else if (!googleProjectId) {
|
||||
warnings.push('pubsub_topic shorthand requires project_id in Gmail credentials');
|
||||
}
|
||||
|
||||
if (ctx.config.server?.tailscale?.serve) {
|
||||
warnings.push('push requires a public HTTPS endpoint; Tailscale Serve is typically tailnet-only');
|
||||
}
|
||||
} else if (gmail.disable_push && topicRaw) {
|
||||
warnings.push('push disabled (disable_push=true)');
|
||||
}
|
||||
|
||||
const subRaw = (gmail.pubsub_subscription_id ?? '').trim();
|
||||
if (subRaw) {
|
||||
modes.push('pull');
|
||||
if (subRaw.includes('/')) {
|
||||
const ok = /^projects\/[^/]+\/subscriptions\/[^/]+$/.test(subRaw);
|
||||
if (!ok) {
|
||||
warnings.push('pubsub_subscription_id format invalid (expected projects/<project>/subscriptions/<sub>)');
|
||||
}
|
||||
} else if (!googleProjectId) {
|
||||
warnings.push('pubsub_subscription_id shorthand requires project_id in Gmail credentials');
|
||||
}
|
||||
}
|
||||
|
||||
modes.push('poll');
|
||||
const detail = `(${modes.join(' + ')} -> ${gmail.output.channel}/${gmail.output.peer})`;
|
||||
const withWarnings = warnings.length > 0 ? `${detail} — ${warnings.join('; ')}` : detail;
|
||||
|
||||
return { status: warnings.length > 0 ? 'warn' : 'pass', label: 'Gmail configured', detail: withWarnings };
|
||||
};
|
||||
|
||||
const allChecks: Check[] = [
|
||||
|
||||
@@ -18,6 +18,7 @@ import { registerGcalAuthCommand } from './gcal-auth.js';
|
||||
import { registerGdocsAuthCommand } from './gdocs-auth.js';
|
||||
import { registerGdriveAuthCommand } from './gdrive-auth.js';
|
||||
import { registerGtasksAuthCommand } from './gtasks-auth.js';
|
||||
import { registerOpenaiAuthCommand } from './openai-auth.js';
|
||||
import { registerSkillsCommand } from './skills.js';
|
||||
|
||||
export function createProgram(): Command {
|
||||
@@ -41,6 +42,7 @@ export function createProgram(): Command {
|
||||
registerGdocsAuthCommand(program);
|
||||
registerGdriveAuthCommand(program);
|
||||
registerGtasksAuthCommand(program);
|
||||
registerOpenaiAuthCommand(program);
|
||||
registerSkillsCommand(program);
|
||||
|
||||
return program;
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import type { Command } from 'commander';
|
||||
import { loadStoredOpenAIAuth, loginOpenAI } from '../auth/index.js';
|
||||
|
||||
export function registerOpenaiAuthCommand(program: Command): void {
|
||||
program
|
||||
.command('openai-auth')
|
||||
.description('Authenticate OpenAI (ChatGPT Plus/Pro) via OAuth device flow')
|
||||
.action(async () => {
|
||||
const existing = loadStoredOpenAIAuth();
|
||||
if (existing) {
|
||||
console.log('OpenAI OAuth token already exists.');
|
||||
console.log('Delete ~/.config/flynn/auth.json openai entry if you want to re-authenticate.');
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
console.log('Starting OpenAI OAuth device flow...');
|
||||
console.log('');
|
||||
|
||||
try {
|
||||
await loginOpenAI((userCode, verificationUri) => {
|
||||
console.log(`Please visit: ${verificationUri}`);
|
||||
console.log(`Enter code: ${userCode}`);
|
||||
console.log('');
|
||||
console.log('Waiting for authorization...');
|
||||
});
|
||||
|
||||
console.log('');
|
||||
console.log('OpenAI authentication successful! Token stored.');
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
console.error(`OpenAI login failed: ${message}`);
|
||||
process.exit(1);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -53,6 +53,45 @@ models:
|
||||
expect(result.config!.telegram?.bot_token).toBe('test-token');
|
||||
});
|
||||
|
||||
it('loads env vars from FLYNN_ENV_FILE before parsing config', () => {
|
||||
const prevEnvFile = process.env.FLYNN_ENV_FILE;
|
||||
const prevToken = process.env.TEST_BOT_TOKEN;
|
||||
delete process.env.TEST_BOT_TOKEN;
|
||||
|
||||
mkdirSync(testDir, { recursive: true });
|
||||
const envPath = join(testDir, 'cloud.env');
|
||||
const configPath = join(testDir, 'config.yaml');
|
||||
|
||||
writeFileSync(envPath, 'TEST_BOT_TOKEN=test-token\n');
|
||||
process.env.FLYNN_ENV_FILE = envPath;
|
||||
|
||||
writeFileSync(configPath, `
|
||||
telegram:
|
||||
bot_token: \${TEST_BOT_TOKEN}
|
||||
allowed_chat_ids: [123]
|
||||
models:
|
||||
default:
|
||||
provider: anthropic
|
||||
model: claude-sonnet
|
||||
`);
|
||||
|
||||
const result = loadConfigSafe(configPath);
|
||||
expect(result.config).toBeDefined();
|
||||
expect(result.error).toBeUndefined();
|
||||
expect(result.config!.telegram?.bot_token).toBe('test-token');
|
||||
|
||||
if (prevEnvFile !== undefined) {
|
||||
process.env.FLYNN_ENV_FILE = prevEnvFile;
|
||||
} else {
|
||||
delete process.env.FLYNN_ENV_FILE;
|
||||
}
|
||||
if (prevToken !== undefined) {
|
||||
process.env.TEST_BOT_TOKEN = prevToken;
|
||||
} else {
|
||||
delete process.env.TEST_BOT_TOKEN;
|
||||
}
|
||||
});
|
||||
|
||||
it('returns error when file not found', () => {
|
||||
const result = loadConfigSafe('/nonexistent/config.yaml');
|
||||
expect(result.config).toBeUndefined();
|
||||
|
||||
@@ -2,6 +2,38 @@ import { loadConfig } from '../config/index.js';
|
||||
import type { Config } from '../config/index.js';
|
||||
import { resolve, dirname, join } from 'path';
|
||||
import { homedir } from 'os';
|
||||
import { existsSync, readFileSync } from 'fs';
|
||||
|
||||
function loadEnvFileIfPresent(): void {
|
||||
const envFile = process.env.FLYNN_ENV_FILE ?? resolve(homedir(), '.config/flynn/cloud.env');
|
||||
if (!existsSync(envFile)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const raw = readFileSync(envFile, 'utf-8');
|
||||
for (const line of raw.split(/\r?\n/)) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || trimmed.startsWith('#')) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const idx = trimmed.indexOf('=');
|
||||
if (idx <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const key = trimmed.slice(0, idx).trim();
|
||||
const value = trimmed.slice(idx + 1);
|
||||
if (!key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Don't override existing env vars.
|
||||
if (process.env[key] === undefined) {
|
||||
process.env[key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Get the config file path from env or default location. */
|
||||
export function getConfigPath(): string {
|
||||
@@ -30,6 +62,7 @@ export function resolveOverlayPath(basePath: string): string | undefined {
|
||||
export function loadConfigSafe(configPath?: string): { config?: Config; error?: string } {
|
||||
const path = configPath ?? getConfigPath();
|
||||
try {
|
||||
loadEnvFileIfPresent();
|
||||
const overlayPath = resolveOverlayPath(path);
|
||||
const config = loadConfig(path, overlayPath);
|
||||
return { config };
|
||||
|
||||
+49
-22
@@ -1,5 +1,5 @@
|
||||
import type { Command } from 'commander';
|
||||
import type { Config } from '../config/index.js';
|
||||
import type { Config, ModelConfig, ModelProvider } from '../config/index.js';
|
||||
import { loadConfigSafe, getConfigPath } from './shared.js';
|
||||
import { existsSync, mkdirSync, readFileSync } from 'fs';
|
||||
import { resolve } from 'path';
|
||||
@@ -58,6 +58,26 @@ function loadSystemPrompt(): string {
|
||||
return 'You are Flynn, a helpful personal AI assistant. Be direct, concise, and helpful. Use markdown when it improves readability.';
|
||||
}
|
||||
|
||||
function buildProviderConfigMap(config: Config): Partial<Record<ModelProvider, ModelConfig>> {
|
||||
const providerConfigs: Partial<Record<ModelProvider, ModelConfig>> = {};
|
||||
const modelConfigs: ModelConfig[] = [
|
||||
config.models.default,
|
||||
...(config.models.fast ? [config.models.fast] : []),
|
||||
...(config.models.complex ? [config.models.complex] : []),
|
||||
...(config.models.local ? [config.models.local] : []),
|
||||
...Object.values(config.models.local_providers ?? {}),
|
||||
];
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
providerConfigs[modelConfig.provider] = modelConfig;
|
||||
if (modelConfig.fallback) {
|
||||
providerConfigs[modelConfig.fallback.provider] = modelConfig.fallback;
|
||||
}
|
||||
}
|
||||
|
||||
return providerConfigs;
|
||||
}
|
||||
|
||||
export function registerTuiCommand(program: Command): void {
|
||||
program
|
||||
.command('tui')
|
||||
@@ -179,6 +199,7 @@ export function registerTuiCommand(program: Command): void {
|
||||
const toolExecutor = new ToolExecutor(toolRegistry, hookEngine);
|
||||
|
||||
const session = sessionManager.getSession('tui', 'local');
|
||||
const modelProviderConfigs = buildProviderConfigMap(config);
|
||||
|
||||
const agent = new NativeAgent({
|
||||
modelClient: modelRouter,
|
||||
@@ -211,29 +232,33 @@ export function registerTuiCommand(program: Command): void {
|
||||
process.exit(0);
|
||||
});
|
||||
|
||||
if (opts.fullscreen) {
|
||||
await startFullscreenTui({
|
||||
session,
|
||||
modelClient: modelRouter,
|
||||
modelRouter,
|
||||
systemPrompt,
|
||||
model: config.models.default.model,
|
||||
agent,
|
||||
onExit: cleanup,
|
||||
});
|
||||
} else {
|
||||
if (opts.fullscreen) {
|
||||
await startFullscreenTui({
|
||||
session,
|
||||
modelClient: modelRouter,
|
||||
modelRouter,
|
||||
systemPrompt,
|
||||
model: config.models.default.model,
|
||||
agent,
|
||||
hookEngine,
|
||||
modelProviderConfigs,
|
||||
onExit: cleanup,
|
||||
});
|
||||
} else {
|
||||
let switchingToFullscreen = false;
|
||||
|
||||
const tui = new MinimalTui({
|
||||
session,
|
||||
modelClient: modelRouter,
|
||||
modelRouter,
|
||||
systemPrompt,
|
||||
agent,
|
||||
pairingManager,
|
||||
localProviders: config.models.local_providers,
|
||||
currentLocalProvider: config.models.local?.provider,
|
||||
onTransfer: (target) => {
|
||||
const tui = new MinimalTui({
|
||||
session,
|
||||
modelClient: modelRouter,
|
||||
modelRouter,
|
||||
systemPrompt,
|
||||
agent,
|
||||
hookEngine,
|
||||
pairingManager,
|
||||
localProviders: config.models.local_providers,
|
||||
modelProviderConfigs,
|
||||
currentLocalProvider: config.models.local?.provider,
|
||||
onTransfer: (target) => {
|
||||
if (target === 'telegram') {
|
||||
if (config.telegram && config.telegram.allowed_chat_ids.length > 0) {
|
||||
const telegramUserId = String(config.telegram.allowed_chat_ids[0]);
|
||||
@@ -263,6 +288,8 @@ export function registerTuiCommand(program: Command): void {
|
||||
systemPrompt,
|
||||
model: config.models.default.model,
|
||||
agent,
|
||||
hookEngine,
|
||||
modelProviderConfigs,
|
||||
onExit: cleanup,
|
||||
});
|
||||
return;
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
|
||||
import { createModelCommand } from './index.js';
|
||||
|
||||
describe('builtin /model command', () => {
|
||||
it('passes through the full argument string', async () => {
|
||||
const cmd = createModelCommand();
|
||||
const setModel = vi.fn(() => 'ok');
|
||||
|
||||
const result = await cmd.execute(['default', 'github/gpt-5-mini'], {
|
||||
channel: 'test',
|
||||
senderId: 'user',
|
||||
sessionId: 's1',
|
||||
rawInput: '/model default github/gpt-5-mini',
|
||||
services: { setModel },
|
||||
});
|
||||
|
||||
expect(setModel).toHaveBeenCalledWith('default github/gpt-5-mini');
|
||||
expect(result).toEqual({ handled: true, text: 'ok' });
|
||||
});
|
||||
|
||||
it('still works for single-argument tier switching', async () => {
|
||||
const cmd = createModelCommand();
|
||||
const setModel = vi.fn(() => 'switched');
|
||||
|
||||
const result = await cmd.execute(['fast'], {
|
||||
channel: 'test',
|
||||
senderId: 'user',
|
||||
sessionId: 's1',
|
||||
rawInput: '/model fast',
|
||||
services: { setModel },
|
||||
});
|
||||
|
||||
expect(setModel).toHaveBeenCalledWith('fast');
|
||||
expect(result).toEqual({ handled: true, text: 'switched' });
|
||||
});
|
||||
});
|
||||
@@ -86,7 +86,9 @@ export function createModelCommand(): CommandDefinition {
|
||||
|
||||
return {
|
||||
handled: true,
|
||||
text: await ctx.services.setModel(args[0]),
|
||||
// Pass through the full argument string so frontends can support
|
||||
// richer syntax like: /model <tier> <provider/model>
|
||||
text: await ctx.services.setModel(args.join(' ')),
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
@@ -49,6 +49,8 @@ const modelConfigBaseSchema = z.object({
|
||||
endpoint: z.string().optional(),
|
||||
api_key: z.string().optional(),
|
||||
auth_token: z.string().optional(),
|
||||
/** Use OAuth credential flow (provider-specific). */
|
||||
use_oauth: z.boolean().optional(),
|
||||
for: z.array(z.string()).optional(),
|
||||
num_gpu: z.number().optional(),
|
||||
context_window: z.number().optional(),
|
||||
@@ -178,6 +180,30 @@ const gmailSchema = z.object({
|
||||
enabled: z.boolean().default(false),
|
||||
credentials_file: z.string().optional(),
|
||||
token_file: z.string().default('~/.config/flynn/gmail-token.json'),
|
||||
/**
|
||||
* Optional Google Cloud Pub/Sub topic for Gmail push notifications.
|
||||
* Format: projects/<project-id>/topics/<topic>
|
||||
* If omitted, push notifications are disabled and Flynn will use polling.
|
||||
*/
|
||||
pubsub_topic: z.string().optional(),
|
||||
|
||||
/**
|
||||
* Explicitly disable Gmail push watch registration even if pubsub_topic is set.
|
||||
* Useful for environments where Google cannot reach the gateway (e.g. tailnet-only).
|
||||
*/
|
||||
disable_push: z.boolean().default(false),
|
||||
|
||||
/**
|
||||
* Optional Pub/Sub subscription for pull-based delivery (no inbound webhook required).
|
||||
* Format: projects/<project-id>/subscriptions/<subscription>
|
||||
*/
|
||||
pubsub_subscription_id: z.string().optional(),
|
||||
|
||||
/** How often to pull messages from pubsub_subscription_id (e.g. '60s'). */
|
||||
pubsub_pull_interval: z.string().default('60s'),
|
||||
|
||||
/** Max messages to pull per cycle (1..100). */
|
||||
pubsub_max_messages: z.number().min(1).max(100).default(10),
|
||||
watch_labels: z.array(z.string()).default(['INBOX']),
|
||||
poll_interval: z.string().default('300s'),
|
||||
history_start: z.string().optional(), // ISO date string — only process emails after this date
|
||||
|
||||
@@ -96,6 +96,34 @@ describe('createClientFromConfig', () => {
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
});
|
||||
|
||||
it('creates OpenAIClient for zhipuai when using auth_token', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'zhipuai',
|
||||
model: 'glm-4.5',
|
||||
auth_token: 'oauth-access-token',
|
||||
});
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
});
|
||||
|
||||
it('creates OpenAIClient for zhipuai using ZHIPUAI_AUTH_TOKEN env var', () => {
|
||||
const prev = process.env.ZHIPUAI_AUTH_TOKEN;
|
||||
process.env.ZHIPUAI_AUTH_TOKEN = 'oauth-access-token';
|
||||
|
||||
try {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'zhipuai',
|
||||
model: 'glm-4.5',
|
||||
});
|
||||
expect(client).toBeInstanceOf(OpenAIClient);
|
||||
} finally {
|
||||
if (prev === undefined) {
|
||||
delete process.env.ZHIPUAI_AUTH_TOKEN;
|
||||
} else {
|
||||
process.env.ZHIPUAI_AUTH_TOKEN = prev;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
it('creates BedrockClient for bedrock provider', () => {
|
||||
const client = createClientFromConfig({
|
||||
provider: 'bedrock',
|
||||
|
||||
+19
-1
@@ -18,6 +18,23 @@ function requireApiKey(cfg: ModelConfig, envVar: string): string {
|
||||
return key;
|
||||
}
|
||||
|
||||
function resolveAuthCredential(cfg: ModelConfig, apiKeyEnvVar: string, authTokenEnvVar?: string): string {
|
||||
const raw = cfg.api_key
|
||||
?? cfg.auth_token
|
||||
?? process.env[apiKeyEnvVar]
|
||||
?? (authTokenEnvVar ? process.env[authTokenEnvVar] : undefined);
|
||||
|
||||
if (!raw) {
|
||||
const envHint = authTokenEnvVar ? `${apiKeyEnvVar} or ${authTokenEnvVar}` : apiKeyEnvVar;
|
||||
throw new Error(
|
||||
`Credential required for ${cfg.provider}. ` +
|
||||
`Set ${envHint} environment variable or provide api_key/auth_token in config.`,
|
||||
);
|
||||
}
|
||||
|
||||
return raw.startsWith('Bearer ') ? raw.slice('Bearer '.length) : raw;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a ModelClient from a provider config entry.
|
||||
* Dispatches on the `provider` field so all tiers and fallback entries
|
||||
@@ -35,6 +52,7 @@ export function createClientFromConfig(cfg: ModelConfig): ModelClient {
|
||||
return new OpenAIClient({
|
||||
model: cfg.model,
|
||||
apiKey: cfg.api_key,
|
||||
useOAuth: Boolean(cfg.use_oauth),
|
||||
});
|
||||
case 'ollama':
|
||||
return new OllamaClient({
|
||||
@@ -62,7 +80,7 @@ export function createClientFromConfig(cfg: ModelConfig): ModelClient {
|
||||
case 'zhipuai':
|
||||
return new OpenAIClient({
|
||||
model: cfg.model,
|
||||
apiKey: requireApiKey(cfg, 'ZHIPUAI_API_KEY'),
|
||||
apiKey: resolveAuthCredential(cfg, 'ZHIPUAI_API_KEY', 'ZHIPUAI_AUTH_TOKEN'),
|
||||
baseURL: cfg.endpoint ?? 'https://api.z.ai/api/paas/v4',
|
||||
});
|
||||
case 'xai':
|
||||
|
||||
+112
-8
@@ -9,7 +9,7 @@ import { MemoryStore } from '../memory/index.js';
|
||||
import type { Tool } from '../tools/types.js';
|
||||
import { createMediaSendTool } from '../tools/index.js';
|
||||
import { createSandboxedShellTool, createSandboxedProcessStartTool, SandboxManager } from '../sandbox/index.js';
|
||||
import type { Config } from '../config/index.js';
|
||||
import { MODEL_PROVIDERS, type Config, type ModelConfig, type ModelProvider } from '../config/index.js';
|
||||
import { ModelRouter, type ModelTier } from '../models/index.js';
|
||||
import { ToolRegistry, ToolExecutor } from '../tools/index.js';
|
||||
import { SessionManager } from '../session/index.js';
|
||||
@@ -17,6 +17,27 @@ import { AgentConfigRegistry, AgentRouter } from '../agents/index.js';
|
||||
import type { CommandRegistry } from '../commands/index.js';
|
||||
import type { ComponentRegistry } from '../intents/index.js';
|
||||
import type { RoutingPolicy } from '../routing/index.js';
|
||||
import { createClientFromConfig } from './models.js';
|
||||
|
||||
function buildProviderConfigMap(config: Config): Partial<Record<ModelProvider, ModelConfig>> {
|
||||
const providerConfigs: Partial<Record<ModelProvider, ModelConfig>> = {};
|
||||
const modelConfigs: ModelConfig[] = [
|
||||
config.models.default,
|
||||
...(config.models.fast ? [config.models.fast] : []),
|
||||
...(config.models.complex ? [config.models.complex] : []),
|
||||
...(config.models.local ? [config.models.local] : []),
|
||||
...Object.values(config.models.local_providers ?? {}),
|
||||
];
|
||||
|
||||
for (const modelConfig of modelConfigs) {
|
||||
providerConfigs[modelConfig.provider] = modelConfig;
|
||||
if (modelConfig.fallback) {
|
||||
providerConfigs[modelConfig.fallback.provider] = modelConfig.fallback;
|
||||
}
|
||||
}
|
||||
|
||||
return providerConfigs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the unified message handler for the channel registry.
|
||||
@@ -263,14 +284,97 @@ export function createMessageRouter(deps: {
|
||||
return lines.join('\n');
|
||||
},
|
||||
setModel: (tier) => {
|
||||
const validTiers = deps.modelRouter.getAvailableTiers();
|
||||
if (!validTiers.includes(tier as ModelTier)) {
|
||||
return `Model tier not available: ${tier}`;
|
||||
const raw = tier.trim();
|
||||
if (!raw) {
|
||||
return 'Usage: /model <tier> OR /model <tier> <provider/model> OR /model <tier> reset';
|
||||
}
|
||||
|
||||
const parts = raw.split(/\s+/);
|
||||
const requestedTier = parts[0];
|
||||
|
||||
const validTiers = deps.modelRouter.getAvailableTiers();
|
||||
if (!validTiers.includes(requestedTier as ModelTier)) {
|
||||
return `Model tier not available: ${requestedTier}`;
|
||||
}
|
||||
|
||||
const modelTier = requestedTier as ModelTier;
|
||||
|
||||
// /model <tier>
|
||||
if (parts.length === 1) {
|
||||
session.setConfig('modelTier', modelTier);
|
||||
agent.setModelTier(modelTier);
|
||||
const label = deps.modelRouter.getLabel(modelTier);
|
||||
return `Switched to model: ${modelTier} (${label})`;
|
||||
}
|
||||
|
||||
const arg2 = parts[1];
|
||||
|
||||
// /model <tier> reset — restore configured provider/model and re-enable fallbacks
|
||||
if (arg2.toLowerCase() === 'reset') {
|
||||
const configured: ModelConfig | undefined = modelTier === 'default'
|
||||
? deps.config.models.default
|
||||
: modelTier === 'fast'
|
||||
? deps.config.models.fast
|
||||
: modelTier === 'complex'
|
||||
? deps.config.models.complex
|
||||
: modelTier === 'local'
|
||||
? deps.config.models.local
|
||||
: undefined;
|
||||
if (!configured) {
|
||||
return `No configured model for tier: ${modelTier}`;
|
||||
}
|
||||
|
||||
const client = createClientFromConfig(configured);
|
||||
const label = `${configured.provider}/${configured.model}`;
|
||||
deps.modelRouter.setClient(modelTier, client, label);
|
||||
deps.modelRouter.setTierStrict(modelTier, false);
|
||||
session.setConfig('modelTier', modelTier);
|
||||
agent.setModelTier(modelTier);
|
||||
return `Reset ${modelTier} to: ${label}`;
|
||||
}
|
||||
|
||||
// /model <tier> <provider/model>
|
||||
const providerModel = arg2;
|
||||
if (!providerModel.includes('/')) {
|
||||
return 'Invalid format. Use: /model <tier> <provider/model> (e.g. /model default github/gpt-5-mini)';
|
||||
}
|
||||
|
||||
const slashIdx = providerModel.indexOf('/');
|
||||
const provider = providerModel.slice(0, slashIdx);
|
||||
const model = providerModel.slice(slashIdx + 1);
|
||||
|
||||
if (!MODEL_PROVIDERS.includes(provider as ModelProvider)) {
|
||||
return `Unknown provider "${provider}". Known providers: ${MODEL_PROVIDERS.join(', ')}`;
|
||||
}
|
||||
|
||||
const providerType = provider as ModelProvider;
|
||||
const providerConfigs = buildProviderConfigMap(deps.config);
|
||||
const template = providerConfigs[providerType];
|
||||
|
||||
try {
|
||||
const client = createClientFromConfig(
|
||||
template
|
||||
? { ...template, provider: providerType, model }
|
||||
: { provider: providerType, model },
|
||||
);
|
||||
|
||||
deps.modelRouter.setClient(modelTier, client, providerModel);
|
||||
deps.modelRouter.setTierStrict(modelTier, true);
|
||||
session.setConfig('modelTier', modelTier);
|
||||
agent.setModelTier(modelTier);
|
||||
|
||||
const lines = [
|
||||
`Set ${modelTier} to: ${providerModel}`,
|
||||
`Fallbacks for ${modelTier} disabled (strict tier mode).`,
|
||||
];
|
||||
if (parts.length > 2) {
|
||||
lines.push(`Note: ignored extra args: ${parts.slice(2).join(' ')}`);
|
||||
}
|
||||
return lines.join('\n');
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
return `Failed to switch ${modelTier} to ${providerModel}: ${message}`;
|
||||
}
|
||||
session.setConfig('modelTier', tier);
|
||||
agent.setModelTier(tier as ModelTier);
|
||||
const label = deps.modelRouter.getLabel(tier as ModelTier);
|
||||
return `Switched to model: ${tier} (${label})`;
|
||||
},
|
||||
compact: async () => {
|
||||
const result = await agent.compact();
|
||||
|
||||
@@ -124,7 +124,7 @@ Commands:
|
||||
/model [name] Show or switch model tier (local, default, fast, complex)
|
||||
/model <tier> <p/m> Change tier's provider/model (e.g. /model default anthropic/claude-sonnet-4)
|
||||
/backend [provider] Show or switch local backend (ollama, llamacpp)
|
||||
/login [provider] Authenticate with GitHub
|
||||
/login [provider] Authenticate with GitHub or OpenAI
|
||||
/pair List pending pairing codes and approved senders
|
||||
/pair generate [label] Generate a new DM pairing code
|
||||
/pair revoke <ch> <id> Revoke an approved sender
|
||||
@@ -178,7 +178,7 @@ export const COMMAND_TOOLTIPS: Record<string, string> = {
|
||||
'/status': 'Show session info and token usage',
|
||||
'/fullscreen': 'Switch to fullscreen mode',
|
||||
'/fs': 'Switch to fullscreen mode',
|
||||
'/login': 'Authenticate with GitHub (OAuth device flow)',
|
||||
'/login': 'Authenticate with GitHub or OpenAI (OAuth device flow)',
|
||||
'/pair': 'Generate/list/revoke DM pairing codes',
|
||||
'/transfer': 'Transfer session to another frontend',
|
||||
'/quit': 'Exit TUI',
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React, { useState, useCallback, useRef, useEffect } from 'react';
|
||||
import { Box, useApp, useInput } from 'ink';
|
||||
import { Box, Text, useApp, useInput } from 'ink';
|
||||
import { StatusBar } from './StatusBar.js';
|
||||
import { MessageList } from './MessageList.js';
|
||||
import { InputBar } from './InputBar.js';
|
||||
@@ -7,8 +7,11 @@ import { parseCommand, getHelpText, resolveModelAlias, getCommandCompletions } f
|
||||
import type { Message, ModelClient, TokenUsage } from '../../../models/types.js';
|
||||
import type { ModelRouter } from '../../../models/router.js';
|
||||
import type { ManagedSession } from '../../../session/index.js';
|
||||
import type { NativeAgent } from '../../../backends/native/agent.js';
|
||||
import type { ToolUseEvent } from '../../../backends/native/agent.js';
|
||||
import type { NativeAgent, ToolUseEvent } from '../../../backends/native/agent.js';
|
||||
import type { HookEngine, HookResult } from '../../../hooks/index.js';
|
||||
import type { ModelConfig, ModelProvider } from '../../../config/schema.js';
|
||||
import { MODEL_PROVIDERS } from '../../../config/schema.js';
|
||||
import { createClientFromConfig } from '../../../daemon/index.js';
|
||||
|
||||
/** Format a tool name like "gmail.list" -> "Gmail: List" */
|
||||
function formatToolName(name: string): string {
|
||||
@@ -44,6 +47,8 @@ export interface AppProps {
|
||||
systemPrompt: string;
|
||||
model: string;
|
||||
agent?: NativeAgent;
|
||||
hookEngine?: HookEngine;
|
||||
modelProviderConfigs?: Partial<Record<ModelProvider, ModelConfig>>;
|
||||
onExit?: () => void;
|
||||
}
|
||||
|
||||
@@ -54,6 +59,8 @@ export function App({
|
||||
systemPrompt,
|
||||
model,
|
||||
agent,
|
||||
hookEngine,
|
||||
modelProviderConfigs,
|
||||
onExit,
|
||||
}: AppProps): React.ReactElement {
|
||||
const { exit } = useApp();
|
||||
@@ -63,13 +70,20 @@ export function App({
|
||||
const [streamingContent, setStreamingContent] = useState('');
|
||||
const [scrollOffset, setScrollOffset] = useState(0);
|
||||
const [tokenUsage, setTokenUsage] = useState<TokenUsage>({ inputTokens: 0, outputTokens: 0 });
|
||||
const [currentModel, setCurrentModel] = useState(model);
|
||||
const [currentModel, setCurrentModel] = useState(() => {
|
||||
if (!modelRouter) {return model;}
|
||||
return modelRouter.getLabel(modelRouter.getTier());
|
||||
});
|
||||
|
||||
const abortRef = useRef(false);
|
||||
const toolLinesRef = useRef<string[]>([]);
|
||||
|
||||
const confirmResolveRef = useRef<((result: HookResult) => void) | null>(null);
|
||||
const [confirmation, setConfirmation] = useState<{ tool: string; args: Record<string, unknown> } | null>(null);
|
||||
|
||||
// Set up an Ink-compatible onToolUse callback for the agent.
|
||||
// This replaces the process.stdout.write callback (which corrupts Ink rendering)
|
||||
// with one that updates React state to show tool activity in the streaming area.
|
||||
// This replaces process.stdout writes (which corrupt Ink rendering)
|
||||
// with one that updates React state to show tool activity.
|
||||
useEffect(() => {
|
||||
if (!agent) {return;}
|
||||
|
||||
@@ -79,7 +93,10 @@ export function App({
|
||||
const argsStr = event.args ? ` (${formatToolArgs(event.args)})` : '';
|
||||
toolLinesRef.current = [...toolLinesRef.current, `> ${label}${argsStr}`];
|
||||
setStreamingContent(toolLinesRef.current.join('\n'));
|
||||
} else if (event.type === 'end' && event.result) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.type === 'end' && event.result) {
|
||||
const icon = event.result.success ? 'done' : 'error';
|
||||
const detail = event.result.success
|
||||
? `(${event.result.output.split('\n').length} lines)`
|
||||
@@ -95,7 +112,43 @@ export function App({
|
||||
};
|
||||
}, [agent]);
|
||||
|
||||
// Inline confirmations for dangerous tools (e.g. shell.exec) in fullscreen mode.
|
||||
useEffect(() => {
|
||||
if (!hookEngine) {return;}
|
||||
|
||||
hookEngine.setInteractiveConfirmer(async (pending) => {
|
||||
return await new Promise<HookResult>((resolve) => {
|
||||
confirmResolveRef.current = resolve;
|
||||
setConfirmation({ tool: pending.tool, args: pending.args });
|
||||
});
|
||||
});
|
||||
|
||||
return () => {
|
||||
hookEngine.setInteractiveConfirmer(undefined);
|
||||
confirmResolveRef.current = null;
|
||||
setConfirmation(null);
|
||||
};
|
||||
}, [hookEngine]);
|
||||
|
||||
useInput((inputChar, key) => {
|
||||
// Confirmation prompt mode: capture y/n and ignore everything else.
|
||||
if (confirmation && confirmResolveRef.current) {
|
||||
const c = inputChar.toLowerCase();
|
||||
if (c === 'y') {
|
||||
confirmResolveRef.current({ approved: true });
|
||||
confirmResolveRef.current = null;
|
||||
setConfirmation(null);
|
||||
return;
|
||||
}
|
||||
if (c === 'n') {
|
||||
confirmResolveRef.current({ approved: false, reason: 'Denied by user' });
|
||||
confirmResolveRef.current = null;
|
||||
setConfirmation(null);
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (key.escape) {
|
||||
if (isStreaming) {
|
||||
abortRef.current = true;
|
||||
@@ -120,7 +173,6 @@ export function App({
|
||||
return;
|
||||
}
|
||||
|
||||
// Scroll handling
|
||||
if (key.upArrow && scrollOffset > 0) {
|
||||
setScrollOffset(prev => Math.max(0, prev - 1));
|
||||
}
|
||||
@@ -136,12 +188,15 @@ export function App({
|
||||
});
|
||||
|
||||
const handleSubmit = useCallback(async (value: string) => {
|
||||
if (confirmation) {
|
||||
return;
|
||||
}
|
||||
|
||||
const command = parseCommand(value);
|
||||
if (!command) {return;}
|
||||
|
||||
setInput('');
|
||||
|
||||
// Handle commands
|
||||
switch (command.type) {
|
||||
case 'quit':
|
||||
onExit?.();
|
||||
@@ -160,69 +215,124 @@ export function App({
|
||||
return;
|
||||
|
||||
case 'help': {
|
||||
// Show help as system message
|
||||
const helpMsg: Message = { role: 'assistant', content: getHelpText() };
|
||||
const helpWithTs = session.addMessage(helpMsg);
|
||||
setMessages(prev => [...prev, helpWithTs]);
|
||||
setMessages(prev => [...prev, session.addMessage(helpMsg)]);
|
||||
return;
|
||||
}
|
||||
|
||||
case 'status': {
|
||||
const status = `Session: ${session.id}\nMessages: ${messages.length}\nTokens: ${tokenUsage.inputTokens} in / ${tokenUsage.outputTokens} out`;
|
||||
const statusMsg: Message = { role: 'assistant', content: status };
|
||||
const statusWithTs = session.addMessage(statusMsg);
|
||||
setMessages(prev => [...prev, statusWithTs]);
|
||||
setMessages(prev => [...prev, session.addMessage(statusMsg)]);
|
||||
return;
|
||||
}
|
||||
|
||||
case 'model': {
|
||||
if (!modelRouter) {
|
||||
const errMsg: Message = { role: 'assistant', content: 'Model switching not available.' };
|
||||
const errWithTs = session.addMessage(errMsg);
|
||||
setMessages(prev => [...prev, errWithTs]);
|
||||
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: 'Model switching not available.' })]);
|
||||
return;
|
||||
}
|
||||
|
||||
// /model
|
||||
if (!command.name) {
|
||||
const info = `Current: ${modelRouter.getTier()}\nAvailable: ${modelRouter.getAvailableTiers().join(', ')}`;
|
||||
const infoMsg: Message = { role: 'assistant', content: info };
|
||||
const infoWithTs = session.addMessage(infoMsg);
|
||||
setMessages(prev => [...prev, infoWithTs]);
|
||||
const current = modelRouter.getTier();
|
||||
const available = modelRouter.getAvailableTiers();
|
||||
const labels = modelRouter.getAllLabels();
|
||||
|
||||
const lines: string[] = [];
|
||||
lines.push(`Active tier: ${current}`);
|
||||
for (const t of available) {
|
||||
const label = labels[t] ?? 'unknown';
|
||||
const strict = modelRouter.isTierStrict(t) ? ' (strict)' : '';
|
||||
lines.push(` ${t}: ${label}${strict}${t === current ? ' ←' : ''}`);
|
||||
}
|
||||
|
||||
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: lines.join('\n') })]);
|
||||
return;
|
||||
}
|
||||
|
||||
// /model <tier> <provider/model>
|
||||
if (command.providerModel) {
|
||||
const tier = resolveModelAlias(command.name);
|
||||
const providerModel = command.providerModel;
|
||||
|
||||
const slashIdx = providerModel.indexOf('/');
|
||||
if (slashIdx === -1) {
|
||||
setMessages(prev => [...prev, session.addMessage({
|
||||
role: 'assistant',
|
||||
content: 'Invalid format. Use provider/model (e.g. anthropic/claude-sonnet-4)',
|
||||
})]);
|
||||
return;
|
||||
}
|
||||
|
||||
const provider = providerModel.slice(0, slashIdx);
|
||||
const modelName = providerModel.slice(slashIdx + 1);
|
||||
|
||||
if (!MODEL_PROVIDERS.includes(provider as ModelProvider)) {
|
||||
setMessages(prev => [...prev, session.addMessage({
|
||||
role: 'assistant',
|
||||
content: `Unknown provider "${provider}". Known providers: ${MODEL_PROVIDERS.join(', ')}`,
|
||||
})]);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const providerType = provider as ModelProvider;
|
||||
const template = modelProviderConfigs?.[providerType];
|
||||
const client = createClientFromConfig({
|
||||
...(template ?? {}),
|
||||
provider: providerType,
|
||||
model: modelName,
|
||||
});
|
||||
|
||||
modelRouter.setClient(tier, client, providerModel);
|
||||
modelRouter.setTierStrict(tier, true);
|
||||
|
||||
if (agent && tier === modelRouter.getTier()) {
|
||||
agent.setModelTier(tier);
|
||||
setCurrentModel(modelRouter.getLabel(tier));
|
||||
}
|
||||
|
||||
setMessages(prev => [...prev, session.addMessage({
|
||||
role: 'assistant',
|
||||
content: `Set ${tier} to: ${providerModel}\nFallbacks for ${tier} disabled (strict tier mode).`,
|
||||
})]);
|
||||
} catch (error) {
|
||||
const msg = error instanceof Error ? error.message : String(error);
|
||||
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: `Failed to create client: ${msg}` })]);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// /model <tier>
|
||||
const tier = resolveModelAlias(command.name);
|
||||
if (modelRouter.setTier(tier)) {
|
||||
// Also update the agent tier so chatWithRouter uses the correct client
|
||||
if (agent) {
|
||||
agent.setModelTier(tier);
|
||||
}
|
||||
setCurrentModel(tier);
|
||||
const successMsg: Message = { role: 'assistant', content: `Switched to model: ${tier}` };
|
||||
const successWithTs = session.addMessage(successMsg);
|
||||
setMessages(prev => [...prev, successWithTs]);
|
||||
setCurrentModel(modelRouter.getLabel(tier));
|
||||
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: `Switched to model: ${tier}` })]);
|
||||
} else {
|
||||
const failMsg: Message = { role: 'assistant', content: `Model not available: ${command.name}` };
|
||||
const failWithTs = session.addMessage(failMsg);
|
||||
setMessages(prev => [...prev, failWithTs]);
|
||||
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: `Model not available: ${command.name}` })]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
case 'fullscreen':
|
||||
// Already in fullscreen
|
||||
return;
|
||||
|
||||
case 'transfer': {
|
||||
const xferMsg: Message = { role: 'assistant', content: 'Transfer not supported in fullscreen mode.' };
|
||||
const xferWithTs = session.addMessage(xferMsg);
|
||||
setMessages(prev => [...prev, xferWithTs]);
|
||||
case 'transfer':
|
||||
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: 'Transfer not supported in fullscreen mode.' })]);
|
||||
return;
|
||||
}
|
||||
|
||||
case 'message':
|
||||
break; // Continue to message handling
|
||||
break;
|
||||
}
|
||||
|
||||
if (command.type !== 'message' || isStreaming) {return;}
|
||||
if (command.type !== 'message' || isStreaming) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Add user message to UI (and session if no agent — agent adds it internally)
|
||||
const userMessage: Message = { role: 'user', content: command.content };
|
||||
@@ -232,9 +342,8 @@ export function App({
|
||||
} else {
|
||||
setMessages(prev => [...prev, { ...userMessage, timestamp: Date.now() }]);
|
||||
}
|
||||
setScrollOffset(0); // Auto-scroll to bottom
|
||||
setScrollOffset(0);
|
||||
|
||||
// Process response
|
||||
setIsStreaming(true);
|
||||
setStreamingContent('');
|
||||
toolLinesRef.current = [];
|
||||
@@ -242,16 +351,11 @@ export function App({
|
||||
|
||||
try {
|
||||
if (agent) {
|
||||
// agent.process() handles session history internally
|
||||
const response = await agent.process(command.content);
|
||||
|
||||
await agent.process(command.content);
|
||||
const usage = agent.getUsage();
|
||||
setTokenUsage({ inputTokens: usage.inputTokens, outputTokens: usage.outputTokens });
|
||||
|
||||
// Sync UI with session history (agent already added messages to session)
|
||||
setMessages(session.getHistory());
|
||||
} else if (modelClient.chatStream) {
|
||||
// Fallback: direct streaming without tools
|
||||
let fullContent = '';
|
||||
|
||||
for await (const event of modelClient.chatStream({
|
||||
@@ -279,10 +383,8 @@ export function App({
|
||||
}
|
||||
|
||||
const assistantMessage: Message = { role: 'assistant', content: fullContent };
|
||||
const assistantWithTimestamp = session.addMessage(assistantMessage);
|
||||
setMessages(prev => [...prev, assistantWithTimestamp]);
|
||||
setMessages(prev => [...prev, session.addMessage(assistantMessage)]);
|
||||
} else {
|
||||
// Fallback: non-streaming without tools
|
||||
const response = await modelClient.chat({
|
||||
messages: session.getHistory(),
|
||||
system: systemPrompt,
|
||||
@@ -294,21 +396,30 @@ export function App({
|
||||
}));
|
||||
|
||||
const assistantMessage: Message = { role: 'assistant', content: response.content };
|
||||
const assistantWithTimestamp = session.addMessage(assistantMessage);
|
||||
setMessages(prev => [...prev, assistantWithTimestamp]);
|
||||
setMessages(prev => [...prev, session.addMessage(assistantMessage)]);
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMessage: Message = {
|
||||
role: 'assistant',
|
||||
content: `Error: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
};
|
||||
const errorWithTimestamp = session.addMessage(errorMessage);
|
||||
setMessages(prev => [...prev, errorWithTimestamp]);
|
||||
const msg = error instanceof Error ? error.message : 'Unknown error';
|
||||
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: `Error: ${msg}` })]);
|
||||
} finally {
|
||||
setIsStreaming(false);
|
||||
setStreamingContent('');
|
||||
}
|
||||
}, [isStreaming, session, agent, modelClient, modelRouter, systemPrompt, exit, onExit, messages.length, tokenUsage]);
|
||||
}, [
|
||||
confirmation,
|
||||
session,
|
||||
agent,
|
||||
modelClient,
|
||||
modelRouter,
|
||||
systemPrompt,
|
||||
exit,
|
||||
onExit,
|
||||
isStreaming,
|
||||
messages.length,
|
||||
tokenUsage.inputTokens,
|
||||
tokenUsage.outputTokens,
|
||||
modelProviderConfigs,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Box flexDirection="column" height="100%">
|
||||
@@ -317,13 +428,27 @@ export function App({
|
||||
scrollOffset={scrollOffset}
|
||||
streamingContent={isStreaming ? streamingContent : undefined}
|
||||
/>
|
||||
|
||||
{confirmation ? (
|
||||
<Box paddingX={1} paddingY={0} borderStyle="round" borderColor="yellow">
|
||||
<Text color="yellow">
|
||||
Confirmation required: {confirmation.tool}{' '}
|
||||
{Object.keys(confirmation.args).length > 0 ? JSON.stringify(confirmation.args) : ''}
|
||||
</Text>
|
||||
<Text color="yellow">Press y to approve, n to deny.</Text>
|
||||
</Box>
|
||||
) : null}
|
||||
|
||||
<InputBar
|
||||
value={input}
|
||||
onChange={setInput}
|
||||
onSubmit={handleSubmit}
|
||||
isLoading={isStreaming}
|
||||
placeholder={isStreaming ? 'Flynn is typing... (Esc to cancel)' : 'Type a message... (Esc=exit, /help)'}
|
||||
isLoading={isStreaming || !!confirmation}
|
||||
placeholder={confirmation
|
||||
? 'Confirmation required (press y/n)'
|
||||
: (isStreaming ? 'Flynn is typing... (Esc to cancel)' : 'Type a message... (Esc=exit, /help)')}
|
||||
/>
|
||||
|
||||
<StatusBar
|
||||
sessionId={session.id}
|
||||
messageCount={messages.length}
|
||||
|
||||
@@ -5,6 +5,8 @@ import type { ManagedSession } from '../../session/index.js';
|
||||
import type { ModelClient } from '../../models/types.js';
|
||||
import type { ModelRouter } from '../../models/router.js';
|
||||
import type { NativeAgent } from '../../backends/native/agent.js';
|
||||
import type { HookEngine } from '../../hooks/index.js';
|
||||
import type { ModelConfig, ModelProvider } from '../../config/index.js';
|
||||
|
||||
export interface FullscreenTuiConfig {
|
||||
session: ManagedSession;
|
||||
@@ -13,6 +15,8 @@ export interface FullscreenTuiConfig {
|
||||
systemPrompt: string;
|
||||
model: string;
|
||||
agent?: NativeAgent;
|
||||
hookEngine?: HookEngine;
|
||||
modelProviderConfigs?: Partial<Record<ModelProvider, ModelConfig>>;
|
||||
onExit?: () => void;
|
||||
}
|
||||
|
||||
@@ -22,6 +26,10 @@ export async function startFullscreenTui(config: FullscreenTuiConfig): Promise<v
|
||||
process.stdin.resume();
|
||||
}
|
||||
|
||||
if (config.agent && config.modelRouter) {
|
||||
config.agent.setModelTier(config.modelRouter.getTier());
|
||||
}
|
||||
|
||||
const { waitUntilExit } = render(
|
||||
React.createElement(App, {
|
||||
session: config.session,
|
||||
@@ -30,6 +38,8 @@ export async function startFullscreenTui(config: FullscreenTuiConfig): Promise<v
|
||||
systemPrompt: config.systemPrompt,
|
||||
model: config.model,
|
||||
agent: config.agent,
|
||||
hookEngine: config.hookEngine,
|
||||
modelProviderConfigs: config.modelProviderConfigs,
|
||||
onExit: config.onExit,
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -118,4 +118,57 @@ describe('MinimalTui backend command', () => {
|
||||
expect(mockRouter.setTier).toHaveBeenCalledWith('local');
|
||||
expect(mockAgent.setModelTier).toHaveBeenCalledWith('local');
|
||||
});
|
||||
|
||||
it('reuses configured provider credentials for /model <tier> <provider/model>', () => {
|
||||
const prevOpenRouterKey = process.env.OPENROUTER_API_KEY;
|
||||
delete process.env.OPENROUTER_API_KEY;
|
||||
|
||||
try {
|
||||
const mockSession = {
|
||||
id: 'test',
|
||||
getHistory: () => [],
|
||||
addMessage: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
replaceHistory: vi.fn(),
|
||||
};
|
||||
|
||||
const mockRouter = {
|
||||
getTier: () => 'default' as const,
|
||||
getAvailableTiers: () => ['default', 'local'],
|
||||
setTier: vi.fn(() => true),
|
||||
getLocalProviderName: () => 'ollama',
|
||||
setLocalClient: vi.fn(),
|
||||
setClient: vi.fn(),
|
||||
setTierStrict: vi.fn(),
|
||||
chat: vi.fn(),
|
||||
getClient: vi.fn(),
|
||||
};
|
||||
|
||||
const tui = new MinimalTui({
|
||||
session: mockSession as any,
|
||||
modelClient: mockRouter as any,
|
||||
modelRouter: mockRouter as any,
|
||||
systemPrompt: 'test',
|
||||
modelProviderConfigs: {
|
||||
openrouter: {
|
||||
provider: 'openrouter',
|
||||
model: 'seed-model',
|
||||
api_key: 'test-key',
|
||||
endpoint: 'https://openrouter.ai/api/v1',
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
(tui as any).handleModelCommand('default', 'openrouter/deepseek/deepseek-chat');
|
||||
|
||||
expect(mockRouter.setClient).toHaveBeenCalledOnce();
|
||||
expect(mockRouter.setTierStrict).toHaveBeenCalledWith('default', true);
|
||||
} finally {
|
||||
if (prevOpenRouterKey) {
|
||||
process.env.OPENROUTER_API_KEY = prevOpenRouterKey;
|
||||
} else {
|
||||
delete process.env.OPENROUTER_API_KEY;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,9 +9,10 @@ import type { ModelConfig, ModelProvider } from '../../config/schema.js';
|
||||
import { MODEL_PROVIDERS } from '../../config/schema.js';
|
||||
import { OllamaClient, LlamaCppClient } from '../../models/index.js';
|
||||
import { createClientFromConfig } from '../../daemon/index.js';
|
||||
import { loginGitHub } from '../../auth/index.js';
|
||||
import { loginGitHub, loginOpenAI } from '../../auth/index.js';
|
||||
import type { PairingManager } from '../../channels/pairing.js';
|
||||
import { getColoredBanner } from './banner.js';
|
||||
import type { HookEngine } from '../../hooks/index.js';
|
||||
|
||||
export { parseCommand, type Command };
|
||||
|
||||
@@ -42,8 +43,10 @@ export interface MinimalTuiConfig {
|
||||
onFullscreen?: () => void;
|
||||
onTransfer?: (target: string) => void;
|
||||
localProviders?: Record<string, ModelConfig>;
|
||||
modelProviderConfigs?: Partial<Record<ModelProvider, ModelConfig>>;
|
||||
currentLocalProvider?: string;
|
||||
pairingManager?: PairingManager;
|
||||
hookEngine?: HookEngine;
|
||||
}
|
||||
|
||||
export class MinimalTui {
|
||||
@@ -99,6 +102,10 @@ export class MinimalTui {
|
||||
async start(): Promise<void> {
|
||||
this.running = true;
|
||||
|
||||
if (this.config.agent && this.config.modelRouter) {
|
||||
this.config.agent.setModelTier(this.config.modelRouter.getTier());
|
||||
}
|
||||
|
||||
this.rl = readline.createInterface({
|
||||
input: process.stdin,
|
||||
output: process.stdout,
|
||||
@@ -108,6 +115,26 @@ export class MinimalTui {
|
||||
},
|
||||
});
|
||||
|
||||
// In minimal TUI we can prompt inline for tool confirmations.
|
||||
// This avoids deadlocks when hooks are configured to require confirmation
|
||||
// (e.g. shell.exec) and the tool loop is awaiting a decision.
|
||||
if (this.config.hookEngine) {
|
||||
this.config.hookEngine.setInteractiveConfirmer(async (pending) => {
|
||||
const tool = pending.tool;
|
||||
const args = pending.args;
|
||||
const argsStr = Object.keys(args).length > 0 ? ` ${JSON.stringify(args)}` : '';
|
||||
console.log(`\n${colors.bold}Confirmation required${colors.reset}`);
|
||||
console.log(`${colors.gray}${tool}${colors.reset}${argsStr}`);
|
||||
|
||||
const answer = (await this.prompt(`${colors.orange}${colors.bold}Approve?${colors.reset} ${colors.gray}(y/N)${colors.reset} `))
|
||||
.trim()
|
||||
.toLowerCase();
|
||||
const approved = answer === 'y' || answer === 'yes';
|
||||
console.log(approved ? `${colors.gray}Approved.${colors.reset}\n` : `${colors.gray}Denied.${colors.reset}\n`);
|
||||
return approved ? { approved: true } : { approved: false, reason: 'Denied by user' };
|
||||
});
|
||||
}
|
||||
|
||||
// Listen for line changes to show hints
|
||||
process.stdin.on('keypress', () => {
|
||||
// Small delay to let readline update the line
|
||||
@@ -239,9 +266,22 @@ export class MinimalTui {
|
||||
}
|
||||
|
||||
try {
|
||||
const client = createClientFromConfig({ provider: provider as ModelProvider, model });
|
||||
const providerType = provider as ModelProvider;
|
||||
const template = this.config.modelProviderConfigs?.[providerType];
|
||||
const client = createClientFromConfig({
|
||||
...(template ?? {}),
|
||||
provider: providerType,
|
||||
model,
|
||||
});
|
||||
router.setClient(tier, client, providerModel);
|
||||
console.log(`${colors.gray}Set ${tier} to:${colors.reset} ${providerModel}\n`);
|
||||
router.setTierStrict(tier, true);
|
||||
|
||||
if (this.config.agent && tier === router.getTier()) {
|
||||
this.config.agent.setModelTier(tier);
|
||||
}
|
||||
|
||||
console.log(`${colors.gray}Set ${tier} to:${colors.reset} ${providerModel}`);
|
||||
console.log(`${colors.gray}Fallbacks for ${tier} disabled (strict tier mode).${colors.reset}\n`);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
console.log(`${colors.gray}Failed to create client:${colors.reset} ${message}\n`);
|
||||
@@ -383,27 +423,49 @@ export class MinimalTui {
|
||||
|
||||
private async handleLoginCommand(provider?: string): Promise<void> {
|
||||
const target = provider ?? 'github';
|
||||
if (target !== 'github') {
|
||||
console.log(`${colors.gray}Unknown login provider:${colors.reset} ${target}. Only 'github' is supported.\n`);
|
||||
if (target === 'github') {
|
||||
console.log(`${colors.gray}Starting GitHub OAuth device flow...${colors.reset}`);
|
||||
|
||||
try {
|
||||
await loginGitHub((userCode, verificationUri) => {
|
||||
console.log('');
|
||||
console.log(`${colors.gray}Please visit:${colors.reset} ${verificationUri}`);
|
||||
console.log(`${colors.gray}and enter code:${colors.reset} ${userCode}`);
|
||||
console.log('');
|
||||
console.log(`${colors.gray}Waiting for authorization...${colors.reset}`);
|
||||
});
|
||||
|
||||
console.log(`${colors.gray}GitHub authentication successful! Token stored.${colors.reset}\n`);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
console.log(`${colors.gray}GitHub login failed:${colors.reset} ${message}\n`);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`${colors.gray}Starting GitHub OAuth device flow...${colors.reset}`);
|
||||
if (target === 'openai') {
|
||||
console.log(`${colors.gray}Starting OpenAI OAuth device flow...${colors.reset}`);
|
||||
|
||||
try {
|
||||
await loginGitHub((userCode, verificationUri) => {
|
||||
console.log('');
|
||||
console.log(`${colors.gray}Please visit:${colors.reset} ${verificationUri}`);
|
||||
console.log(`${colors.gray}and enter code:${colors.reset} ${userCode}`);
|
||||
console.log('');
|
||||
console.log(`${colors.gray}Waiting for authorization...${colors.reset}`);
|
||||
});
|
||||
try {
|
||||
await loginOpenAI((userCode, verificationUri) => {
|
||||
console.log('');
|
||||
console.log(`${colors.gray}Please visit:${colors.reset} ${verificationUri}`);
|
||||
console.log(`${colors.gray}and enter code:${colors.reset} ${userCode}`);
|
||||
console.log('');
|
||||
console.log(`${colors.gray}Waiting for authorization...${colors.reset}`);
|
||||
});
|
||||
|
||||
console.log(`${colors.gray}GitHub authentication successful! Token stored.${colors.reset}\n`);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
console.log(`${colors.gray}GitHub login failed:${colors.reset} ${message}\n`);
|
||||
console.log(`${colors.gray}OpenAI authentication successful! Token stored.${colors.reset}\n`);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
console.log(`${colors.gray}OpenAI login failed:${colors.reset} ${message}\n`);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`${colors.gray}Unknown login provider:${colors.reset} ${target}. Supported: github, openai\n`);
|
||||
}
|
||||
|
||||
private handlePairCommand(action?: 'generate' | 'list' | 'revoke', args?: string): void {
|
||||
|
||||
@@ -88,11 +88,19 @@ export function createAgentHandlers(deps: AgentHandlerDeps) {
|
||||
return lines.join('\n');
|
||||
},
|
||||
getModel: () => `Current model tier: ${agent.getModelTier()}`,
|
||||
setModel: (tier) => {
|
||||
setModel: (input) => {
|
||||
const raw = input.trim();
|
||||
if (!raw) {
|
||||
return 'Usage: /model <tier>';
|
||||
}
|
||||
const [requestedTier, ...rest] = raw.split(/\s+/);
|
||||
const validTiers: ModelTier[] = ['fast', 'default', 'complex', 'local'];
|
||||
const modelTier = tier as ModelTier;
|
||||
const modelTier = requestedTier as ModelTier;
|
||||
if (!validTiers.includes(modelTier)) {
|
||||
return `Invalid tier: ${tier}. Available: ${validTiers.join(', ')}`;
|
||||
return `Invalid tier: ${requestedTier}. Available: ${validTiers.join(', ')}`;
|
||||
}
|
||||
if (rest.length > 0) {
|
||||
return `Switched to model tier: ${modelTier}\nNote: provider/model switching is not available via gateway (/model <tier> <provider/model>).`;
|
||||
}
|
||||
agent.setModelTier(modelTier);
|
||||
if (sessionId && deps.sessionManager) {
|
||||
|
||||
@@ -72,4 +72,19 @@ describe('HookEngine', () => {
|
||||
expect(result.approved).toBe(false);
|
||||
expect(result.reason).toBe('Too dangerous');
|
||||
});
|
||||
|
||||
it('uses interactive confirmer when set (no pending queue)', async () => {
|
||||
const engine = new HookEngine({ confirm: ['shell.*'], log: [], silent: [] });
|
||||
const confirmer = vi.fn(async () => ({ approved: true }));
|
||||
engine.setInteractiveConfirmer(confirmer);
|
||||
|
||||
const result = await engine.requestConfirmation('shell.exec', { cmd: 'ls' });
|
||||
expect(result.approved).toBe(true);
|
||||
expect(engine.getPendingConfirmations()).toHaveLength(0);
|
||||
expect(confirmer).toHaveBeenCalledOnce();
|
||||
expect(confirmer).toHaveBeenCalledWith(expect.objectContaining({
|
||||
tool: 'shell.exec',
|
||||
args: { cmd: 'ls' },
|
||||
}));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,16 +1,32 @@
|
||||
import { randomUUID } from 'crypto';
|
||||
import type { HookAction, HookResult, PendingConfirmation, HookConfig } from './types.js';
|
||||
|
||||
export type InteractiveConfirmer = (pending: {
|
||||
id: string;
|
||||
tool: string;
|
||||
args: Record<string, unknown>;
|
||||
}) => Promise<HookResult>;
|
||||
|
||||
export class HookEngine {
|
||||
private confirmPatterns: RegExp[];
|
||||
private logPatterns: RegExp[];
|
||||
private pendingConfirmations: Map<string, PendingConfirmation> = new Map();
|
||||
private interactiveConfirmer?: InteractiveConfirmer;
|
||||
|
||||
constructor(config: HookConfig) {
|
||||
this.confirmPatterns = config.confirm.map(p => this.patternToRegex(p));
|
||||
this.logPatterns = config.log.map(p => this.patternToRegex(p));
|
||||
}
|
||||
|
||||
/**
|
||||
* Optional interactive confirmation handler.
|
||||
* When set, confirmation requests are handled immediately (no pending queue).
|
||||
* Useful for CLI/TUI environments where we can prompt the user inline.
|
||||
*/
|
||||
setInteractiveConfirmer(confirmer: InteractiveConfirmer | undefined): void {
|
||||
this.interactiveConfirmer = confirmer;
|
||||
}
|
||||
|
||||
private patternToRegex(pattern: string): RegExp {
|
||||
const escaped = pattern
|
||||
.replace(/[.+^${}()|[\]\\]/g, '\\$&')
|
||||
@@ -31,6 +47,10 @@ export class HookEngine {
|
||||
async requestConfirmation(tool: string, args: Record<string, unknown>): Promise<HookResult> {
|
||||
const id = randomUUID();
|
||||
|
||||
if (this.interactiveConfirmer) {
|
||||
return await this.interactiveConfirmer({ id, tool, args });
|
||||
}
|
||||
|
||||
return new Promise((resolve) => {
|
||||
const pending: PendingConfirmation = {
|
||||
id,
|
||||
|
||||
@@ -140,6 +140,44 @@ describe('LlamaCppClient', () => {
|
||||
}]);
|
||||
});
|
||||
|
||||
it('sanitizes web_search tool schema for llama.cpp', async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({
|
||||
choices: [{ message: { content: 'ok' } }],
|
||||
usage: { prompt_tokens: 1, completion_tokens: 1 },
|
||||
}),
|
||||
});
|
||||
|
||||
const client = new LlamaCppClient({
|
||||
endpoint: 'http://localhost:8080',
|
||||
model: 'test-model',
|
||||
});
|
||||
|
||||
await client.chat({
|
||||
messages: [{ role: 'user', content: 'search' }],
|
||||
tools: [{
|
||||
name: 'web_search',
|
||||
description: 'Search',
|
||||
input_schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: { type: 'string' },
|
||||
count: { type: 'number' },
|
||||
},
|
||||
required: ['query'],
|
||||
},
|
||||
}],
|
||||
});
|
||||
|
||||
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
|
||||
expect(requestBody.tools[0].function.parameters).toEqual({
|
||||
type: 'object',
|
||||
properties: { query: { type: 'string' } },
|
||||
required: ['query'],
|
||||
});
|
||||
});
|
||||
|
||||
it('parses tool_calls from response', async () => {
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
|
||||
@@ -48,6 +48,36 @@ interface LlamaCppStreamChunk {
|
||||
usage?: { prompt_tokens: number; completion_tokens: number };
|
||||
}
|
||||
|
||||
function sanitizeToolParametersForLlamaCpp(toolName: string, parameters: unknown): unknown {
|
||||
// llama.cpp is stricter than most tool-call APIs about JSON schema.
|
||||
// In particular, some builds reject extra optional properties for common tools.
|
||||
// Keep the full schema for most tools, but reduce known-problematic ones.
|
||||
if (toolName !== 'web_search') {
|
||||
return parameters;
|
||||
}
|
||||
|
||||
if (!parameters || typeof parameters !== 'object') {
|
||||
return parameters;
|
||||
}
|
||||
|
||||
const schema = parameters as {
|
||||
type?: unknown;
|
||||
properties?: Record<string, unknown>;
|
||||
required?: unknown;
|
||||
};
|
||||
|
||||
const querySchema = schema.properties?.query;
|
||||
if (!querySchema) {
|
||||
return parameters;
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'object',
|
||||
properties: { query: querySchema },
|
||||
required: ['query'],
|
||||
};
|
||||
}
|
||||
|
||||
/** Message format for OpenAI-compatible chat completions API. */
|
||||
interface LlamaCppChatMessage {
|
||||
role: 'system' | 'user' | 'assistant' | 'tool';
|
||||
@@ -211,7 +241,7 @@ export class LlamaCppClient implements ModelClient {
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.input_schema,
|
||||
parameters: sanitizeToolParametersForLlamaCpp(t.name, t.input_schema),
|
||||
},
|
||||
}));
|
||||
}
|
||||
@@ -292,7 +322,7 @@ export class LlamaCppClient implements ModelClient {
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.input_schema,
|
||||
parameters: sanitizeToolParametersForLlamaCpp(t.name, t.input_schema),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
|
||||
import { OpenAIClient } from './openai.js';
|
||||
|
||||
vi.mock('../auth/openai.js', () => ({
|
||||
ensureValidOpenAIAuth: vi.fn(async () => ({
|
||||
access_token: 'at',
|
||||
refresh_token: 'rt',
|
||||
expires_at: Date.now() + 60_000,
|
||||
created_at: new Date().toISOString(),
|
||||
account_id: 'acct',
|
||||
})),
|
||||
}));
|
||||
|
||||
function makeSse(events: Array<{ event: string; data: any }>): string {
|
||||
return events
|
||||
.map((e) => `event: ${e.event}\ndata: ${JSON.stringify(e.data)}\n\n`)
|
||||
.join('');
|
||||
}
|
||||
|
||||
describe('OpenAIClient OAuth (Codex)', () => {
|
||||
const originalFetch = globalThis.fetch;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.fetch = originalFetch;
|
||||
vi.useRealTimers();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('streams SSE and accumulates output_text.delta', async () => {
|
||||
const sse = makeSse([
|
||||
{ event: 'response.created', data: { type: 'response.created', response: { id: 'r1' } } },
|
||||
{ event: 'response.output_text.delta', data: { type: 'response.output_text.delta', delta: 'hel' } },
|
||||
{ event: 'response.output_text.delta', data: { type: 'response.output_text.delta', delta: 'lo' } },
|
||||
{ event: 'response.completed', data: { type: 'response.completed', response: { usage: { input_tokens: 2, output_tokens: 2 } } } },
|
||||
]);
|
||||
|
||||
globalThis.fetch = vi.fn(async (_url: any, init?: any) => {
|
||||
const parsed = JSON.parse(init.body);
|
||||
expect(parsed.store).toBe(false);
|
||||
expect(parsed.stream).toBe(true);
|
||||
expect(typeof parsed.instructions).toBe('string');
|
||||
expect(Array.isArray(parsed.input)).toBe(true);
|
||||
|
||||
const stream = new ReadableStream({
|
||||
start(controller) {
|
||||
controller.enqueue(new TextEncoder().encode(sse));
|
||||
controller.close();
|
||||
},
|
||||
});
|
||||
|
||||
return new Response(stream, { status: 200 });
|
||||
}) as any;
|
||||
|
||||
const client = new OpenAIClient({ model: 'gpt-5.3-codex', useOAuth: true });
|
||||
const resp = await client.chat({
|
||||
system: 'You are helpful.',
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
});
|
||||
|
||||
expect(resp.content).toBe('hello');
|
||||
expect(resp.usage).toEqual({ inputTokens: 2, outputTokens: 2 });
|
||||
});
|
||||
});
|
||||
@@ -1,23 +1,38 @@
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { OpenAIClient } from './openai.js';
|
||||
|
||||
// Shared mock function so we can override per-test
|
||||
const mockCreate = vi.fn().mockResolvedValue({
|
||||
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
|
||||
usage: { prompt_tokens: 10, completion_tokens: 5 },
|
||||
});
|
||||
|
||||
vi.mock('openai', () => ({
|
||||
default: vi.fn().mockImplementation(() => ({
|
||||
const { mockCreate, mockOpenAIConstructor } = vi.hoisted(() => {
|
||||
const mockCreate = vi.fn().mockResolvedValue({
|
||||
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
|
||||
usage: { prompt_tokens: 10, completion_tokens: 5 },
|
||||
});
|
||||
const mockOpenAIConstructor = vi.fn().mockImplementation(() => ({
|
||||
chat: {
|
||||
completions: {
|
||||
create: mockCreate,
|
||||
},
|
||||
},
|
||||
})),
|
||||
}));
|
||||
return { mockCreate, mockOpenAIConstructor };
|
||||
});
|
||||
|
||||
vi.mock('openai', () => ({
|
||||
default: mockOpenAIConstructor,
|
||||
}));
|
||||
|
||||
describe('OpenAIClient', () => {
|
||||
it('sets request timeout and disables SDK retries', () => {
|
||||
new OpenAIClient({
|
||||
apiKey: 'test-key',
|
||||
model: 'gpt-4o',
|
||||
});
|
||||
|
||||
expect(mockOpenAIConstructor).toHaveBeenCalledWith(expect.objectContaining({
|
||||
timeout: 20_000,
|
||||
maxRetries: 0,
|
||||
}));
|
||||
});
|
||||
|
||||
it('sends messages and returns response', async () => {
|
||||
const client = new OpenAIClient({
|
||||
apiKey: 'test-key',
|
||||
|
||||
+151
-6
@@ -1,11 +1,16 @@
|
||||
import OpenAI from 'openai';
|
||||
import type { ChatRequest, ChatResponse, ModelClient, MessageContentPart } from './types.js';
|
||||
import type { ChatRequest, ChatResponse, ModelClient, MessageContentPart, TokenUsage } from './types.js';
|
||||
import { getMessageTextWithTools } from './media.js';
|
||||
import { ensureValidOpenAIAuth } from '../auth/openai.js';
|
||||
|
||||
export interface OpenAIClientConfig {
|
||||
apiKey?: string;
|
||||
model: string;
|
||||
maxTokens?: number;
|
||||
baseURL?: string;
|
||||
timeoutMs?: number;
|
||||
/** If true, use ChatGPT subscription OAuth via the Codex backend endpoint. */
|
||||
useOAuth?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -52,20 +57,160 @@ function toOpenAIContent(content: string | MessageContentPart[]): string | OpenA
|
||||
}
|
||||
|
||||
export class OpenAIClient implements ModelClient {
|
||||
private client: OpenAI;
|
||||
private client?: OpenAI;
|
||||
private model: string;
|
||||
private defaultMaxTokens: number;
|
||||
private useOAuth: boolean;
|
||||
|
||||
constructor(config: OpenAIClientConfig) {
|
||||
this.client = new OpenAI({
|
||||
apiKey: config.apiKey,
|
||||
baseURL: config.baseURL,
|
||||
});
|
||||
const timeoutMs = config.timeoutMs ?? 20_000;
|
||||
this.useOAuth = Boolean(config.useOAuth);
|
||||
|
||||
// OAuth mode uses a different backend (ChatGPT Codex) and a different API shape.
|
||||
// Only initialize the OpenAI SDK for API-key providers.
|
||||
if (!this.useOAuth) {
|
||||
this.client = new OpenAI({
|
||||
apiKey: config.apiKey,
|
||||
baseURL: config.baseURL,
|
||||
timeout: timeoutMs,
|
||||
maxRetries: 0,
|
||||
});
|
||||
}
|
||||
this.model = config.model;
|
||||
this.defaultMaxTokens = config.maxTokens ?? 4096;
|
||||
}
|
||||
|
||||
private async chatViaOAuthCodex(request: ChatRequest): Promise<ChatResponse> {
|
||||
const CODEX_API_ENDPOINT = 'https://chatgpt.com/backend-api/codex/responses';
|
||||
|
||||
const auth = await ensureValidOpenAIAuth();
|
||||
|
||||
// Codex endpoint requires:
|
||||
// - instructions (non-empty)
|
||||
// - input must be a list
|
||||
// - store must be false
|
||||
// - stream must be true (SSE)
|
||||
const instructions = (request.system ?? '').trim() || 'You are helpful.';
|
||||
|
||||
const input = request.messages
|
||||
.map((m) => {
|
||||
const text = getMessageTextWithTools(m);
|
||||
if (!text) {return null;}
|
||||
return {
|
||||
role: m.role,
|
||||
content: [{ type: 'input_text', text }],
|
||||
};
|
||||
})
|
||||
.filter((x): x is NonNullable<typeof x> => Boolean(x));
|
||||
|
||||
const body = {
|
||||
model: this.model,
|
||||
instructions,
|
||||
store: false,
|
||||
stream: true,
|
||||
input,
|
||||
// Intentionally omit max_output_tokens: Codex endpoint rejects it.
|
||||
// Also omit tools/tool_choice for now.
|
||||
};
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
'content-type': 'application/json',
|
||||
'authorization': `Bearer ${auth.access_token}`,
|
||||
'originator': 'flynn',
|
||||
'user-agent': 'flynn/0.1',
|
||||
'session_id': `flynn-${Date.now()}`,
|
||||
};
|
||||
if (auth.account_id) {
|
||||
headers['ChatGPT-Account-Id'] = auth.account_id;
|
||||
}
|
||||
|
||||
const res = await fetch(CODEX_API_ENDPOINT, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
const text = await res.text();
|
||||
throw new Error(`${res.status} ${res.statusText}${text ? `: ${text}` : ''}`);
|
||||
}
|
||||
|
||||
if (!res.body) {
|
||||
throw new Error('OpenAI OAuth request failed: missing response body');
|
||||
}
|
||||
|
||||
let buffer = '';
|
||||
let outputText = '';
|
||||
let usage: TokenUsage | undefined;
|
||||
|
||||
const reader = res.body.getReader();
|
||||
|
||||
const processBlock = (block: string): void => {
|
||||
const lines = block.split('\n');
|
||||
let data = '';
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data:')) {
|
||||
data += line.slice('data:'.length).trim();
|
||||
}
|
||||
}
|
||||
if (!data) {return;}
|
||||
let obj: any;
|
||||
try {
|
||||
obj = JSON.parse(data);
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
if (obj.type === 'response.output_text.delta' && typeof obj.delta === 'string') {
|
||||
outputText += obj.delta;
|
||||
}
|
||||
|
||||
if (obj.type === 'response.completed') {
|
||||
const u = obj.response?.usage;
|
||||
if (u) {
|
||||
usage = {
|
||||
inputTokens: u.input_tokens ?? 0,
|
||||
outputTokens: u.output_tokens ?? 0,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (obj.type === 'response.failed') {
|
||||
const detail = obj.response?.error?.message ?? 'OpenAI OAuth response failed';
|
||||
throw new Error(detail);
|
||||
}
|
||||
};
|
||||
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
if (done) {break;}
|
||||
buffer += Buffer.from(value).toString('utf8');
|
||||
|
||||
while (true) {
|
||||
const idx = buffer.indexOf('\n\n');
|
||||
if (idx === -1) {break;}
|
||||
const block = buffer.slice(0, idx);
|
||||
buffer = buffer.slice(idx + 2);
|
||||
processBlock(block);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content: outputText,
|
||||
stopReason: 'end_turn',
|
||||
usage: usage ?? { inputTokens: 0, outputTokens: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
async chat(request: ChatRequest): Promise<ChatResponse> {
|
||||
if (this.useOAuth) {
|
||||
return this.chatViaOAuthCodex(request);
|
||||
}
|
||||
|
||||
if (!this.client) {
|
||||
throw new Error('OpenAI client not initialized');
|
||||
}
|
||||
|
||||
const messages: OpenAI.ChatCompletionMessageParam[] = [];
|
||||
|
||||
if (request.system) {
|
||||
|
||||
@@ -4,10 +4,15 @@ import type { RetryConfig } from './retry.js';
|
||||
|
||||
describe('isRetryable', () => {
|
||||
it('returns true for generic errors', () => {
|
||||
const error = new Error('Connection timeout');
|
||||
const error = new Error('Connection reset by peer');
|
||||
expect(isRetryable(error, DEFAULT_RETRY_CONFIG.nonRetryablePatterns)).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false for timeout errors', () => {
|
||||
const error = new Error('Request timed out after 20000ms');
|
||||
expect(isRetryable(error, DEFAULT_RETRY_CONFIG.nonRetryablePatterns)).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false for authentication errors', () => {
|
||||
const error = new Error('Invalid API key: authentication failed');
|
||||
expect(isRetryable(error, DEFAULT_RETRY_CONFIG.nonRetryablePatterns)).toBe(false);
|
||||
@@ -75,8 +80,8 @@ describe('withRetry', () => {
|
||||
|
||||
it('retries on transient failure then succeeds', async () => {
|
||||
const fn = vi.fn()
|
||||
.mockRejectedValueOnce(new Error('timeout'))
|
||||
.mockRejectedValueOnce(new Error('timeout'))
|
||||
.mockRejectedValueOnce(new Error('temporary network issue'))
|
||||
.mockRejectedValueOnce(new Error('temporary network issue'))
|
||||
.mockResolvedValueOnce('recovered');
|
||||
|
||||
const result = await withRetry(fn, fastConfig, 'test-op');
|
||||
|
||||
@@ -26,6 +26,9 @@ export const DEFAULT_RETRY_CONFIG: RetryConfig = {
|
||||
'context_length_exceeded',
|
||||
'content_policy',
|
||||
'does not support',
|
||||
'timeout',
|
||||
'timed out',
|
||||
'request aborted',
|
||||
],
|
||||
};
|
||||
|
||||
|
||||
@@ -438,4 +438,29 @@ describe('setClient and labels', () => {
|
||||
expect(newFastClient!.chat).toHaveBeenCalledTimes(1);
|
||||
expect(initialFastClient!.chat).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('strict tier mode disables fallback chain for that tier', async () => {
|
||||
const failingDefault = {
|
||||
chat: vi.fn().mockRejectedValue(new Error('primary failed')),
|
||||
} as unknown as ModelClient;
|
||||
const fallback = {
|
||||
chat: vi.fn().mockResolvedValue({
|
||||
content: 'fallback',
|
||||
stopReason: 'end_turn',
|
||||
usage: { inputTokens: 1, outputTokens: 1 },
|
||||
}),
|
||||
} as unknown as ModelClient;
|
||||
|
||||
const router = new ModelRouter({
|
||||
default: failingDefault,
|
||||
fallbackChain: [fallback],
|
||||
});
|
||||
|
||||
router.setTierStrict('default', true);
|
||||
|
||||
await expect(router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'default'))
|
||||
.rejects.toThrow('primary failed');
|
||||
expect(fallback.chat).not.toHaveBeenCalled();
|
||||
expect(router.isTierStrict('default')).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -27,6 +27,7 @@ export class ModelRouter implements ModelClient {
|
||||
private localProviderName?: string;
|
||||
private retryConfig?: RetryConfig;
|
||||
private tierChangeListeners: Array<(tier: ModelTier) => void> = [];
|
||||
private strictTiers: Set<ModelTier> = new Set();
|
||||
|
||||
constructor(config: ModelRouterConfig) {
|
||||
this.clients = new Map();
|
||||
@@ -97,6 +98,10 @@ export class ModelRouter implements ModelClient {
|
||||
logger.debug(`Primary model failed: ${errors[0].message}`);
|
||||
}
|
||||
|
||||
if (this.strictTiers.has(useTier)) {
|
||||
throw errors[0];
|
||||
}
|
||||
|
||||
// Try tier-specific fallbacks first
|
||||
const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
|
||||
for (let i = 0; i < tierFallbackList.length; i++) {
|
||||
@@ -150,6 +155,11 @@ export class ModelRouter implements ModelClient {
|
||||
primaryError = 'Primary client does not support streaming';
|
||||
}
|
||||
|
||||
if (this.strictTiers.has(useTier)) {
|
||||
yield { type: 'error', error: new Error(primaryError ?? 'Primary model failed') };
|
||||
return;
|
||||
}
|
||||
|
||||
// Try tier-specific fallbacks first
|
||||
const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
|
||||
for (let i = 0; i < tierFallbackList.length; i++) {
|
||||
@@ -216,6 +226,18 @@ export class ModelRouter implements ModelClient {
|
||||
this.labels.set(tier, label);
|
||||
}
|
||||
|
||||
setTierStrict(tier: ModelTier, strict: boolean): void {
|
||||
if (strict) {
|
||||
this.strictTiers.add(tier);
|
||||
return;
|
||||
}
|
||||
this.strictTiers.delete(tier);
|
||||
}
|
||||
|
||||
isTierStrict(tier: ModelTier): boolean {
|
||||
return this.strictTiers.has(tier);
|
||||
}
|
||||
|
||||
getLabel(tier: ModelTier): string {
|
||||
return this.labels.get(tier) ?? 'unknown';
|
||||
}
|
||||
|
||||
@@ -44,6 +44,9 @@ const testConfig: NonNullable<GmailConfig> = {
|
||||
enabled: true,
|
||||
credentials_file: '/tmp/test-creds.json',
|
||||
token_file: '/tmp/test-token.json',
|
||||
disable_push: false,
|
||||
pubsub_pull_interval: '60s',
|
||||
pubsub_max_messages: 10,
|
||||
watch_labels: ['INBOX'],
|
||||
poll_interval: '300s',
|
||||
output: { channel: 'discord', peer: '123' },
|
||||
|
||||
Reference in New Issue
Block a user