feat: add OpenAI OAuth, strict model overrides, and Gmail pull mode

This commit is contained in:
William Valentin
2026-02-13 14:55:40 -08:00
parent 8f644d5e25
commit 955b9e28e0
50 changed files with 5955 additions and 160 deletions
+13
View File
@@ -7,3 +7,16 @@ export {
loginGitHub,
type DeviceCodeResponse,
} from './github.js';
export {
loadStoredOpenAIAuth,
storeOpenAIAuth,
clearOpenAIAuth,
refreshOpenAIAuth,
ensureValidOpenAIAuth,
loginOpenAI,
parseJwtClaims,
extractAccountId,
type OpenAIOAuthInfo,
type IdTokenClaims,
} from './openai.js';
+43
View File
@@ -0,0 +1,43 @@
import { describe, it, expect } from 'vitest';
import { parseJwtClaims, extractAccountId } from './openai.js';
function base64UrlEncode(obj: unknown): string {
return Buffer.from(JSON.stringify(obj)).toString('base64url');
}
function makeJwt(payload: Record<string, unknown>): string {
const header = base64UrlEncode({ alg: 'none', typ: 'JWT' });
const body = base64UrlEncode(payload);
// Signature is ignored by parseJwtClaims.
return `${header}.${body}.sig`;
}
describe('OpenAI OAuth helpers', () => {
it('parseJwtClaims returns undefined for non-jwt strings', () => {
expect(parseJwtClaims('not-a-jwt')).toBeUndefined();
});
it('parseJwtClaims parses base64url payload', () => {
const token = makeJwt({ chatgpt_account_id: 'acct_123' });
const claims = parseJwtClaims(token);
expect(claims?.chatgpt_account_id).toBe('acct_123');
});
it('extractAccountId prefers chatgpt_account_id', () => {
const tokens = {
access_token: makeJwt({ chatgpt_account_id: 'acct_a' }),
refresh_token: 'rt',
id_token: makeJwt({ chatgpt_account_id: 'acct_b' }),
};
expect(extractAccountId(tokens)).toBe('acct_b');
});
it('extractAccountId falls back to organizations[0].id', () => {
const tokens = {
access_token: makeJwt({ organizations: [{ id: 'org_1' }] }),
refresh_token: 'rt',
};
expect(extractAccountId(tokens)).toBe('org_1');
});
});
+281
View File
@@ -0,0 +1,281 @@
import { readFileSync, writeFileSync, mkdirSync, chmodSync } from 'fs';
import { resolve } from 'path';
import { homedir } from 'os';
const ISSUER = 'https://auth.openai.com';
const CLIENT_ID = 'app_EMoamEEZ73f0CkXaXp7hrann';
const DEVICE_URL = `${ISSUER}/codex/device`;
const DEVICE_CODE_URL = `${ISSUER}/api/accounts/deviceauth/usercode`;
const DEVICE_TOKEN_URL = `${ISSUER}/api/accounts/deviceauth/token`;
const TOKEN_URL = `${ISSUER}/oauth/token`;
const POLLING_SAFETY_MARGIN_MS = 3000;
const REFRESH_SAFETY_MARGIN_MS = 30_000;
const AUTH_DIR = resolve(homedir(), '.config/flynn');
const AUTH_FILE = resolve(AUTH_DIR, 'auth.json');
export interface OpenAIOAuthInfo {
access_token: string;
refresh_token: string;
/** Epoch millis. */
expires_at: number;
/** Optional account/org id used for subscription routing. */
account_id?: string;
created_at: string;
}
interface AuthStore {
// Leave github entry untyped here so this module does not depend on github.ts.
github?: unknown;
openai?: OpenAIOAuthInfo;
}
interface DeviceAuthResponse {
device_auth_id: string;
user_code: string;
interval: string;
}
interface DeviceTokenResponse {
authorization_code: string;
code_verifier: string;
}
interface TokenResponse {
id_token?: string;
access_token: string;
refresh_token: string;
expires_in?: number;
}
export interface IdTokenClaims {
chatgpt_account_id?: string;
organizations?: Array<{ id: string }>;
'https://api.openai.com/auth'?: {
chatgpt_account_id?: string;
};
}
function safeJsonParse<T>(raw: string): T | null {
try {
return JSON.parse(raw) as T;
} catch {
return null;
}
}
function readAuthStore(): AuthStore {
try {
const raw = readFileSync(AUTH_FILE, 'utf-8');
const parsed = safeJsonParse<AuthStore>(raw);
return parsed ?? {};
} catch {
return {};
}
}
function writeAuthStore(store: AuthStore): void {
mkdirSync(AUTH_DIR, { recursive: true });
writeFileSync(AUTH_FILE, JSON.stringify(store, null, 2) + '\n', 'utf-8');
chmodSync(AUTH_FILE, 0o600);
}
export function loadStoredOpenAIAuth(): OpenAIOAuthInfo | null {
const store = readAuthStore();
return store.openai ?? null;
}
export function storeOpenAIAuth(info: OpenAIOAuthInfo): void {
const store = readAuthStore();
store.openai = info;
writeAuthStore(store);
}
export function clearOpenAIAuth(): void {
const store = readAuthStore();
delete store.openai;
writeAuthStore(store);
}
export function parseJwtClaims(token: string): IdTokenClaims | undefined {
const parts = token.split('.');
if (parts.length !== 3) {return undefined;}
try {
return JSON.parse(Buffer.from(parts[1], 'base64url').toString()) as IdTokenClaims;
} catch {
return undefined;
}
}
function extractAccountIdFromClaims(claims: IdTokenClaims): string | undefined {
return claims.chatgpt_account_id
?? claims['https://api.openai.com/auth']?.chatgpt_account_id
?? claims.organizations?.[0]?.id;
}
export function extractAccountId(tokens: TokenResponse): string | undefined {
const idToken = tokens.id_token;
if (idToken) {
const claims = parseJwtClaims(idToken);
const id = claims && extractAccountIdFromClaims(claims);
if (id) {return id;}
}
const accessToken = tokens.access_token;
if (accessToken) {
const claims = parseJwtClaims(accessToken);
return claims ? extractAccountIdFromClaims(claims) : undefined;
}
return undefined;
}
async function requestDeviceAuth(): Promise<DeviceAuthResponse> {
const response = await fetch(DEVICE_CODE_URL, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'User-Agent': 'flynn',
},
body: JSON.stringify({ client_id: CLIENT_ID }),
});
if (!response.ok) {
const body = await response.text();
throw new Error(`OpenAI device auth start failed (${response.status}): ${body}`);
}
return response.json() as Promise<DeviceAuthResponse>;
}
async function pollDeviceToken(deviceAuthId: string, userCode: string, intervalMs: number): Promise<DeviceTokenResponse> {
while (true) {
await new Promise(r => setTimeout(r, intervalMs + POLLING_SAFETY_MARGIN_MS));
const response = await fetch(DEVICE_TOKEN_URL, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'User-Agent': 'flynn',
},
body: JSON.stringify({ device_auth_id: deviceAuthId, user_code: userCode }),
});
if (response.ok) {
return response.json() as Promise<DeviceTokenResponse>;
}
// OpenCode treats 403/404 as "pending".
if (response.status === 403 || response.status === 404) {
continue;
}
const body = await response.text();
throw new Error(`OpenAI device auth token failed (${response.status}): ${body}`);
}
}
async function exchangeAuthorizationCode(authCode: string, codeVerifier: string): Promise<TokenResponse> {
const response = await fetch(TOKEN_URL, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
'User-Agent': 'flynn',
},
body: new URLSearchParams({
grant_type: 'authorization_code',
code: authCode,
redirect_uri: `${ISSUER}/deviceauth/callback`,
client_id: CLIENT_ID,
code_verifier: codeVerifier,
}).toString(),
});
if (!response.ok) {
const body = await response.text();
throw new Error(`OpenAI token exchange failed (${response.status}): ${body}`);
}
return response.json() as Promise<TokenResponse>;
}
export async function refreshOpenAIAuth(refreshToken: string): Promise<TokenResponse> {
const response = await fetch(TOKEN_URL, {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
'User-Agent': 'flynn',
},
body: new URLSearchParams({
grant_type: 'refresh_token',
refresh_token: refreshToken,
client_id: CLIENT_ID,
}).toString(),
});
if (!response.ok) {
const body = await response.text();
throw new Error(`OpenAI token refresh failed (${response.status}): ${body}`);
}
return response.json() as Promise<TokenResponse>;
}
/**
* Ensure we have a valid (non-expired) OpenAI OAuth access token.
* Refreshes and persists the token if needed.
*/
export async function ensureValidOpenAIAuth(): Promise<OpenAIOAuthInfo> {
const current = loadStoredOpenAIAuth();
if (!current) {
throw new Error('OpenAI OAuth is not configured. Run `flynn openai-auth` to authenticate.');
}
if (current.expires_at > Date.now() + REFRESH_SAFETY_MARGIN_MS) {
return current;
}
const refreshed = await refreshOpenAIAuth(current.refresh_token);
const expiresAt = Date.now() + (refreshed.expires_in ?? 3600) * 1000;
const accountId = extractAccountId(refreshed) ?? current.account_id;
const updated: OpenAIOAuthInfo = {
access_token: refreshed.access_token,
refresh_token: refreshed.refresh_token,
expires_at: expiresAt,
account_id: accountId,
created_at: current.created_at,
};
storeOpenAIAuth(updated);
return updated;
}
/**
* Run the OpenAI Codex device flow interactively.
* @param onPrompt Callback to display the user code and verification URL to the user.
*/
export async function loginOpenAI(
onPrompt: (userCode: string, verificationUri: string) => void,
): Promise<OpenAIOAuthInfo> {
const device = await requestDeviceAuth();
const intervalMs = Math.max(parseInt(device.interval) || 5, 1) * 1000;
onPrompt(device.user_code, DEVICE_URL);
const deviceToken = await pollDeviceToken(device.device_auth_id, device.user_code, intervalMs);
const tokens = await exchangeAuthorizationCode(deviceToken.authorization_code, deviceToken.code_verifier);
const expiresAt = Date.now() + (tokens.expires_in ?? 3600) * 1000;
const accountId = extractAccountId(tokens);
const info: OpenAIOAuthInfo = {
access_token: tokens.access_token,
refresh_token: tokens.refresh_token,
expires_at: expiresAt,
...(accountId ? { account_id: accountId } : {}),
created_at: new Date().toISOString(),
};
storeOpenAIAuth(info);
return info;
}
+154 -3
View File
@@ -1,6 +1,6 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { homedir } from 'os';
import { GmailWatcher } from './gmail.js';
import type { GmailWatcher as GmailWatcherType } from './gmail.js';
import type { OutboundMessage } from '../channels/types.js';
// Mock googleapis module
@@ -74,6 +74,23 @@ vi.mock('googleapis', () => {
};
});
vi.mock('@google-cloud/pubsub', () => {
const pull = vi.fn().mockResolvedValue([{ receivedMessages: [] }]);
const acknowledge = vi.fn().mockResolvedValue([{}]);
const close = vi.fn().mockResolvedValue(undefined);
class SubscriberClient {
pull = pull;
acknowledge = acknowledge;
close = close;
}
return {
v1: { SubscriberClient },
_mocks: { pull, acknowledge, close },
};
});
// Mock fs operations
vi.mock('fs', async () => {
const actual = await vi.importActual<typeof import('fs')>('fs');
@@ -86,6 +103,7 @@ vi.mock('fs', async () => {
installed: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
project_id: 'test-project',
redirect_uris: ['http://localhost'],
},
});
@@ -108,6 +126,9 @@ function createMockConfig(overrides = {}) {
enabled: true,
credentials_file: '~/.config/flynn/gmail-credentials.json',
token_file: '~/.config/flynn/gmail-token.json',
disable_push: false,
pubsub_pull_interval: '60s',
pubsub_max_messages: 10,
watch_labels: ['INBOX'],
poll_interval: '300s',
output: {
@@ -128,12 +149,16 @@ function createMockChannelLookup() {
}
describe('GmailWatcher', () => {
let watcher: GmailWatcher;
let GmailWatcher: typeof GmailWatcherType;
let watcher: GmailWatcherType;
let channelLookup: ReturnType<typeof createMockChannelLookup>;
beforeEach(() => {
beforeEach(async () => {
vi.useFakeTimers();
channelLookup = createMockChannelLookup();
// Import after mocks so ESM named imports (fs/googleapis) are properly mocked.
({ GmailWatcher } = await import('./gmail.js'));
});
afterEach(async () => {
@@ -154,6 +179,60 @@ describe('GmailWatcher', () => {
});
});
describe('push topic resolution', () => {
it('returns null when pubsub_topic is not set', () => {
const config = createMockConfig();
watcher = new GmailWatcher(config, channelLookup);
const topic = (watcher as unknown as { resolvePubSubTopicName: () => string | null }).resolvePubSubTopicName();
expect(topic).toBe(null);
});
it('expands shorthand topic id when project_id is known', () => {
const config = createMockConfig({ pubsub_topic: 'my-topic' });
watcher = new GmailWatcher(config, channelLookup);
(watcher as unknown as { googleProjectId: string }).googleProjectId = 'test-project';
const topic = (watcher as unknown as { resolvePubSubTopicName: () => string | null }).resolvePubSubTopicName();
expect(topic).toBe('projects/test-project/topics/my-topic');
});
it('rejects invalid pubsub_topic formats', () => {
const config = createMockConfig({ pubsub_topic: 'projects/test-project/topic/my-topic' });
watcher = new GmailWatcher(config, channelLookup);
expect(() => {
(watcher as unknown as { resolvePubSubTopicName: () => string | null }).resolvePubSubTopicName();
}).toThrow(/Invalid pubsub_topic/);
});
});
describe('pull subscription resolution', () => {
it('returns null when pubsub_subscription_id is not set', () => {
const config = createMockConfig();
watcher = new GmailWatcher(config, channelLookup);
const sub = (watcher as unknown as { resolvePubSubSubscriptionName: () => string | null }).resolvePubSubSubscriptionName();
expect(sub).toBe(null);
});
it('expands shorthand subscription id when project_id is known', () => {
const config = createMockConfig({ pubsub_subscription_id: 'my-sub' });
watcher = new GmailWatcher(config, channelLookup);
(watcher as unknown as { googleProjectId: string }).googleProjectId = 'test-project';
const sub = (watcher as unknown as { resolvePubSubSubscriptionName: () => string | null }).resolvePubSubSubscriptionName();
expect(sub).toBe('projects/test-project/subscriptions/my-sub');
});
it('rejects invalid pubsub_subscription_id formats', () => {
const config = createMockConfig({ pubsub_subscription_id: 'projects/test-project/subscription/my-sub' });
watcher = new GmailWatcher(config, channelLookup);
expect(() => {
(watcher as unknown as { resolvePubSubSubscriptionName: () => string | null }).resolvePubSubSubscriptionName();
}).toThrow(/Invalid pubsub_subscription_id/);
});
});
describe('connect() with missing credentials', () => {
it('logs warning and sets status to error when credentials_file is missing', async () => {
const config = createMockConfig({ credentials_file: undefined });
@@ -200,6 +279,78 @@ describe('GmailWatcher', () => {
});
});
describe('push disable flag', () => {
it('skips watch setup when disable_push is true', async () => {
const config = createMockConfig({ disable_push: true, pubsub_topic: 'projects/test-project/topics/gmail-push' });
watcher = new GmailWatcher(config, channelLookup);
const { existsSync, readFileSync } = await import('fs');
vi.mocked(existsSync).mockReturnValue(true);
vi.mocked(readFileSync).mockImplementation((path: unknown) => {
const p = String(path);
if (p.includes('credentials')) {
return JSON.stringify({
installed: {
client_id: 'test-client-id',
client_secret: 'test-client-secret',
project_id: 'test-project',
redirect_uris: ['http://localhost'],
},
});
}
return JSON.stringify({
access_token: 'test-access-token',
refresh_token: 'test-refresh-token',
expiry_date: Date.now() + 3600000,
});
});
const googleapis = await import('googleapis') as unknown as {
_mocks: {
mockWatch: ReturnType<typeof vi.fn>;
mockOAuth2: ReturnType<typeof vi.fn>;
};
};
googleapis._mocks.mockOAuth2.mockImplementation(() => ({
setCredentials: vi.fn(),
on: vi.fn(),
}));
const watchSpy = googleapis._mocks.mockWatch;
await watcher.connect();
expect(watchSpy).not.toHaveBeenCalled();
expect(watcher.status).toBe('connected');
});
});
describe('pullSubscriptionMessages', () => {
it('pulls messages and acknowledges successfully processed ones', async () => {
const config = createMockConfig({ pubsub_subscription_id: 'projects/test-project/subscriptions/gmail-pull' });
watcher = new GmailWatcher(config, channelLookup);
const { _mocks: pubsubMocks } = await import('@google-cloud/pubsub') as unknown as {
_mocks: { pull: ReturnType<typeof vi.fn>; acknowledge: ReturnType<typeof vi.fn> };
};
const payload = { emailAddress: 'bob@example.com', historyId: '200' };
pubsubMocks.pull.mockResolvedValueOnce([
{
receivedMessages: [
{ ackId: 'ack-1', message: { data: Buffer.from(JSON.stringify(payload)) } },
],
},
]);
await (watcher as unknown as { pullSubscriptionMessages: () => Promise<void> }).pullSubscriptionMessages();
expect((watcher as unknown as { lastHistoryId: string }).lastHistoryId).toBe('200');
expect(pubsubMocks.acknowledge).toHaveBeenCalledWith({
subscription: 'projects/test-project/subscriptions/gmail-pull',
ackIds: ['ack-1'],
});
});
});
describe('renderTemplate', () => {
it('replaces all placeholders correctly', () => {
const config = createMockConfig({
+222 -11
View File
@@ -2,6 +2,7 @@ import { google, type Auth } from 'googleapis';
import { readFileSync, writeFileSync, existsSync, mkdirSync, chmodSync } from 'fs';
import { dirname, resolve } from 'path';
import { homedir } from 'os';
import type { v1 } from '@google-cloud/pubsub';
import type { GmailConfig } from '../config/schema.js';
import type { ChannelAdapter, ChannelStatus, InboundMessage, OutboundMessage } from '../channels/types.js';
import { parseInterval } from './heartbeat.js';
@@ -30,9 +31,7 @@ interface PubSubNotification {
historyId: string;
}
// Google Cloud Pub/Sub topic for Gmail push notifications.
// This must be pre-configured in Google Cloud Console.
const GMAIL_PUBSUB_TOPIC = 'projects/flynn-agent/topics/gmail-push';
const DEFAULT_TOPIC_ID = 'gmail-push';
// Watch expires after ~7 days; renew at 6 days (in ms).
const WATCH_RENEWAL_MS = 6 * 24 * 60 * 60 * 1000;
@@ -56,7 +55,11 @@ export class GmailWatcher implements ChannelAdapter {
private lastHistoryId?: string;
private pollTimer?: ReturnType<typeof setInterval>;
private watchTimer?: ReturnType<typeof setInterval>;
private pullTimer?: ReturnType<typeof setInterval>;
private pubsubSubscriber?: v1.SubscriberClient;
private pullInFlight = false;
private readonly config: NonNullable<GmailConfig>;
private googleProjectId?: string;
constructor(
config: NonNullable<GmailConfig>,
@@ -82,12 +85,28 @@ export class GmailWatcher implements ChannelAdapter {
return;
}
// Set up Gmail push watch (Pub/Sub)
// Set up Gmail push watch (Pub/Sub). Polling is always enabled.
if (!this.config.disable_push) {
try {
await this.setupWatch();
} catch (error) {
const errMsg = error instanceof Error ? error.message : 'Unknown error';
const hint = this.buildWatchErrorHint(errMsg);
console.warn(`GmailWatcher: Watch setup failed (will use polling only) — ${errMsg}${hint}`);
}
} else {
const configured = (this.config.pubsub_topic ?? process.env.FLYNN_GMAIL_PUBSUB_TOPIC ?? '').trim();
if (configured) {
console.log('GmailWatcher: Push disabled (disable_push=true)');
}
}
// Set up Pub/Sub pull subscription (optional).
try {
await this.setupWatch();
await this.setupPullSubscription();
} catch (error) {
const errMsg = error instanceof Error ? error.message : 'Unknown error';
console.warn(`GmailWatcher: Watch setup failed (will use polling only) — ${errMsg}`);
console.warn(`GmailWatcher: Pull setup failed (will continue without pull) — ${errMsg}`);
}
// Start polling fallback
@@ -99,8 +118,23 @@ export class GmailWatcher implements ChannelAdapter {
}, pollMs);
this._status = 'connected';
console.log(`GmailWatcher: Connected (poll_interval=${this.config.poll_interval ?? '300s'})`);
auditLogger?.systemStart('GmailWatcher', { poll_interval: this.config.poll_interval });
const modes: string[] = [];
const pushConfigured = Boolean((this.config.pubsub_topic ?? process.env.FLYNN_GMAIL_PUBSUB_TOPIC ?? '').trim());
const pullConfigured = Boolean((this.config.pubsub_subscription_id ?? '').trim());
if (pushConfigured && !this.config.disable_push) {modes.push('push');}
if (pullConfigured) {modes.push('pull');}
modes.push('poll');
console.log(
`GmailWatcher: Connected (${modes.join('+')}, poll_interval=${this.config.poll_interval ?? '300s'}${pullConfigured ? `, pubsub_pull_interval=${this.config.pubsub_pull_interval ?? '60s'}` : ''})`,
);
auditLogger?.systemStart('GmailWatcher', {
modes: modes.join('+'),
poll_interval: this.config.poll_interval,
pubsub_topic: pushConfigured ? 'configured' : 'none',
pubsub_subscription_id: pullConfigured ? 'configured' : 'none',
});
}
async disconnect(): Promise<void> {
@@ -109,9 +143,21 @@ export class GmailWatcher implements ChannelAdapter {
this.pollTimer = undefined;
}
if (this.watchTimer) {
clearTimeout(this.watchTimer);
clearInterval(this.watchTimer);
this.watchTimer = undefined;
}
if (this.pullTimer) {
clearInterval(this.pullTimer);
this.pullTimer = undefined;
}
if (this.pubsubSubscriber) {
try {
await this.pubsubSubscriber.close();
} catch {
// Ignore shutdown errors
}
this.pubsubSubscriber = undefined;
}
this.oauth2Client = undefined;
this._status = 'disconnected';
auditLogger?.systemStop('GmailWatcher');
@@ -178,7 +224,10 @@ export class GmailWatcher implements ChannelAdapter {
}
const credentials = JSON.parse(readFileSync(expandedCredsPath, 'utf-8'));
const { client_id, client_secret, redirect_uris } = credentials.installed ?? credentials.web ?? {};
const { client_id, client_secret, redirect_uris, project_id } = credentials.installed ?? credentials.web ?? {};
if (project_id && typeof project_id === 'string') {
this.googleProjectId = project_id;
}
if (!client_id || !client_secret) {
throw new Error('Invalid credentials file — missing client_id or client_secret');
@@ -217,13 +266,24 @@ export class GmailWatcher implements ChannelAdapter {
private async setupWatch(): Promise<void> {
if (!this.oauth2Client) {return;}
if (this.watchTimer) {
clearInterval(this.watchTimer);
this.watchTimer = undefined;
}
const topicName = this.resolvePubSubTopicName();
if (!topicName) {
// Push notifications are optional; polling is always enabled.
return;
}
const gmail = google.gmail({ version: 'v1', auth: this.oauth2Client });
const watchResponse = await gmail.users.watch({
userId: 'me',
requestBody: {
labelIds: this.config.watch_labels ?? ['INBOX'],
topicName: GMAIL_PUBSUB_TOPIC,
topicName,
},
});
@@ -241,6 +301,157 @@ export class GmailWatcher implements ChannelAdapter {
}, WATCH_RENEWAL_MS);
}
private buildWatchErrorHint(errMsg: string): string {
const hints: string[] = [];
if (errMsg.includes('Invalid topicName')) {
hints.push(
`Tip: set automation.gmail.pubsub_topic to "projects/${this.googleProjectId ?? '<project-id>'}/topics/${DEFAULT_TOPIC_ID}"`,
);
}
if (/permission denied|PERMISSION_DENIED/i.test(errMsg)) {
hints.push('Tip: ensure Gmail has permission to publish to the Pub/Sub topic (IAM)');
}
hints.push('Tip: if Google cannot reach your gateway, set automation.gmail.pubsub_subscription_id for pull mode');
return hints.length > 0 ? `\n ${hints.join('\n ')}` : '';
}
/**
* Resolve the Pub/Sub topic resource name for Gmail push notifications.
*
* Priority:
* 1) automation.gmail.pubsub_topic
* 2) FLYNN_GMAIL_PUBSUB_TOPIC env var
* If neither is provided, push notifications are disabled.
*/
private resolvePubSubTopicName(): string | null {
const configured = this.config.pubsub_topic ?? process.env.FLYNN_GMAIL_PUBSUB_TOPIC;
let topic = (configured ?? '').trim();
if (!topic) {return null;}
// Allow shorthand: just the topic id (e.g. "gmail-push")
if (!topic.includes('/')) {
if (!this.googleProjectId) {
throw new Error(
`pubsub_topic '${topic}' must be fully qualified (projects/<project-id>/topics/<topic>) because project_id was not found in credentials`,
);
}
topic = `projects/${this.googleProjectId}/topics/${topic}`;
}
const isValid = /^projects\/[^/]+\/topics\/[^/]+$/.test(topic);
if (!isValid) {
throw new Error(
`Invalid pubsub_topic '${topic}'. Expected: projects/<project-id>/topics/<topic>`,
);
}
return topic;
}
private resolvePubSubSubscriptionName(): string | null {
let sub = (this.config.pubsub_subscription_id ?? '').trim();
if (!sub) {return null;}
// Allow shorthand: just the subscription id (e.g. "gmail-pull")
if (!sub.includes('/')) {
if (!this.googleProjectId) {
throw new Error(
`pubsub_subscription_id '${sub}' must be fully qualified (projects/<project-id>/subscriptions/<subscription>) because project_id was not found in credentials`,
);
}
sub = `projects/${this.googleProjectId}/subscriptions/${sub}`;
}
const isValid = /^projects\/[^/]+\/subscriptions\/[^/]+$/.test(sub);
if (!isValid) {
throw new Error(
`Invalid pubsub_subscription_id '${sub}'. Expected: projects/<project-id>/subscriptions/<subscription>`,
);
}
return sub;
}
private async setupPullSubscription(): Promise<void> {
const subscriptionName = this.resolvePubSubSubscriptionName();
if (!subscriptionName) {return;}
if (this.pullTimer) {
clearInterval(this.pullTimer);
this.pullTimer = undefined;
}
const pullMs = parseInterval(this.config.pubsub_pull_interval ?? '60s');
// Kick once immediately, then on interval.
await this.pullSubscriptionMessages().catch((err) => {
console.error('GmailWatcher: Pub/Sub pull error —', err instanceof Error ? err.message : err);
});
this.pullTimer = setInterval(() => {
this.pullSubscriptionMessages().catch((err) => {
console.error('GmailWatcher: Pub/Sub pull error —', err instanceof Error ? err.message : err);
});
}, pullMs);
console.log(
`GmailWatcher: Pull enabled (subscription=${subscriptionName}, interval=${this.config.pubsub_pull_interval ?? '60s'})`,
);
}
private async getSubscriberClient(): Promise<v1.SubscriberClient> {
if (this.pubsubSubscriber) {return this.pubsubSubscriber;}
const mod = await import('@google-cloud/pubsub');
this.pubsubSubscriber = new mod.v1.SubscriberClient();
return this.pubsubSubscriber;
}
private async pullSubscriptionMessages(): Promise<void> {
const subscription = this.resolvePubSubSubscriptionName();
if (!subscription) {return;}
if (this.pullInFlight) {return;}
this.pullInFlight = true;
try {
const client = await this.getSubscriberClient();
const maxMessages = this.config.pubsub_max_messages ?? 10;
const [response] = await client.pull({
subscription,
maxMessages,
});
const received = response.receivedMessages ?? [];
if (received.length === 0) {return;}
const ackIds: string[] = [];
for (const receivedMessage of received) {
const ackId = receivedMessage.ackId;
const data = receivedMessage.message?.data;
if (!ackId || !data) {continue;}
const base64 = Buffer.from(data as Uint8Array).toString('base64');
try {
await this.handlePushNotification(base64);
ackIds.push(ackId);
} catch {
// If processing fails, leave message unacked for retry.
}
}
if (ackIds.length > 0) {
await client.acknowledge({ subscription, ackIds });
}
} finally {
this.pullInFlight = false;
}
}
/**
* Poll Gmail History API for new messages since lastHistoryId.
* Fallback mechanism when Pub/Sub push is not available.
+29
View File
@@ -228,6 +228,35 @@ describe('TelegramAdapter', () => {
expect(msg.metadata).toEqual({ isCommand: true, command: 'reset' });
});
it('/model command strips @bot suffix in groups', async () => {
const handler = vi.fn();
adapter.onMessage(handler);
await adapter.connect();
// Find the /model command handler
const modelCall = mockCommand.mock.calls.find((call) => call[0] === 'model');
expect(modelCall).toBeDefined();
const modelHandler = modelCall![1];
const ctx = {
message: { message_id: 123, text: '/model@flynn_bot default github/gpt-5-mini' },
chat: { id: 100 },
from: { first_name: 'Will' },
};
await modelHandler(ctx);
expect(handler).toHaveBeenCalledTimes(1);
const msg: InboundMessage = handler.mock.calls[0][0];
expect(msg.text).toBe('/model default github/gpt-5-mini');
expect(msg.metadata).toEqual({
isCommand: true,
command: 'model',
commandArgs: 'default github/gpt-5-mini',
});
});
// ── Auth middleware ───────────────────────────────────────────
it('auth middleware blocks unauthorized chat IDs', async () => {
+39 -4
View File
@@ -166,7 +166,9 @@ export class TelegramAdapter implements ChannelAdapter {
this.bot.command('model', async (ctx) => {
if (!this.messageHandler) {return;}
const args = ctx.message?.text?.replace(/^\/model\s*/, '').trim() ?? '';
// Telegram can deliver group commands in the form: /model@bot_username ...
// Strip the optional @mention so args parsing is consistent across DMs/groups.
const args = ctx.message?.text?.replace(/^\/model(?:@\S+)?\s*/i, '').trim() ?? '';
this.messageHandler({
id: String(ctx.message?.message_id ?? Date.now()),
@@ -439,15 +441,48 @@ export class TelegramAdapter implements ChannelAdapter {
if (!this.bot) {throw new Error('Telegram adapter not connected');}
const chatId = Number(peerId);
const text = message.text;
const text = message.text ?? '';
// Telegram rejects empty text messages.
// If there is no text, skip straight to attachments.
if (!text.trim()) {
if (message.attachments && message.attachments.length > 0) {
for (const attachment of message.attachments) {
await this.sendAttachment(chatId, attachment);
}
}
return;
}
const sendChunk = async (chunk: string): Promise<void> => {
// We default to Markdown for nicer formatting, but Telegram's Markdown parsing
// is strict and can fail on unescaped characters. If Telegram rejects the
// message, retry once without parse_mode so users still get the content.
try {
await this.bot!.api.sendMessage(chatId, chunk, { parse_mode: 'Markdown' });
} catch (error) {
const description = error && typeof error === 'object' && 'description' in error
? String((error as { description?: unknown }).description)
: '';
const isParseError = description.includes("can't parse entities")
|| description.includes('message text is empty');
if (!isParseError) {
throw error;
}
await this.bot!.api.sendMessage(chatId, chunk);
}
};
// Telegram enforces a 4096-character limit per message
if (text.length <= 4096) {
await this.bot.api.sendMessage(chatId, text, { parse_mode: 'Markdown' });
await sendChunk(text);
} else {
const chunks = splitMessage(text, 4096);
for (const chunk of chunks) {
await this.bot.api.sendMessage(chatId, chunk, { parse_mode: 'Markdown' });
await sendChunk(chunk);
}
}
+69
View File
@@ -163,6 +163,75 @@ automation:
expect(gmailCheck?.detail).toContain('flynn gmail-auth');
});
it('reports PASS for Gmail when enabled (poll only)', async () => {
mkdirSync(testDir, { recursive: true });
const configPath = join(testDir, 'config.yaml');
const credsPath = join(testDir, 'gmail-creds.json');
const tokenPath = join(testDir, 'gmail-token.json');
writeFileSync(credsPath, JSON.stringify({ installed: { project_id: 'test-project' } }));
writeFileSync(tokenPath, JSON.stringify({ refresh_token: 'x' }));
writeFileSync(configPath, `
telegram:
bot_token: "test-token"
allowed_chat_ids: [123]
models:
default:
provider: anthropic
model: claude-sonnet
automation:
gmail:
enabled: true
credentials_file: "${credsPath}"
token_file: "${tokenPath}"
output:
channel: telegram
peer: "123"
`);
const ctx: DoctorContext = { configPath, dataDir: testDir };
const results = await runChecks(ctx);
const gmailCheck = results.find(r => r.label.includes('Gmail configured'));
expect(gmailCheck?.status).toBe('pass');
expect(gmailCheck?.detail).toContain('poll');
});
it('reports WARN for Gmail when pubsub_topic shorthand used without project_id', async () => {
mkdirSync(testDir, { recursive: true });
const configPath = join(testDir, 'config.yaml');
const credsPath = join(testDir, 'gmail-creds.json');
const tokenPath = join(testDir, 'gmail-token.json');
writeFileSync(credsPath, '{}');
writeFileSync(tokenPath, JSON.stringify({ refresh_token: 'x' }));
writeFileSync(configPath, `
telegram:
bot_token: "test-token"
allowed_chat_ids: [123]
models:
default:
provider: anthropic
model: claude-sonnet
automation:
gmail:
enabled: true
credentials_file: "${credsPath}"
token_file: "${tokenPath}"
pubsub_topic: gmail-push
output:
channel: telegram
peer: "123"
`);
const ctx: DoctorContext = { configPath, dataDir: testDir };
const results = await runChecks(ctx);
const gmailCheck = results.find(r => r.label.includes('Gmail configured'));
expect(gmailCheck?.status).toBe('warn');
expect(gmailCheck?.detail).toContain('pubsub_topic shorthand');
});
it('skips downstream checks when config is invalid', async () => {
const ctx: DoctorContext = { configPath: '/nonexistent/config.yaml', dataDir: testDir };
const results = await runChecks(ctx);
+55 -2
View File
@@ -137,7 +137,8 @@ const checkModelConnectivity: Check = async (ctx) => {
// Check if API key is present for providers that need one
const needsKey = ['anthropic', 'openai', 'gemini', 'openrouter'];
if (needsKey.includes(model.provider) && !model.api_key && !model.auth_token) {
const openaiUsingOAuth = model.provider === 'openai' && Boolean((model as unknown as { use_oauth?: boolean }).use_oauth);
if (needsKey.includes(model.provider) && !openaiUsingOAuth && !model.api_key && !model.auth_token) {
const envVarMap: Record<string, string> = {
anthropic: 'ANTHROPIC_API_KEY',
openai: 'OPENAI_API_KEY',
@@ -256,12 +257,64 @@ const checkGmail: Check = async (ctx) => {
return { status: 'fail', label: 'Gmail configured', detail: `credentials file not found: ${credentialsPath}` };
}
let googleProjectId: string | undefined;
try {
const creds = JSON.parse(readFileSync(credentialsPath, 'utf-8')) as Record<string, unknown>;
const installed = (creds.installed as Record<string, unknown> | undefined) ?? (creds.web as Record<string, unknown> | undefined);
const projectId = installed?.project_id;
if (typeof projectId === 'string' && projectId.trim()) {
googleProjectId = projectId.trim();
}
} catch {
// Ignore JSON parse errors; doctor will still validate token and output.
}
const tokenPath = expandPath(gmail.token_file ?? '~/.config/flynn/gmail-token.json');
if (!existsSync(tokenPath)) {
return { status: 'warn', label: 'Gmail configured', detail: 'run `flynn gmail-auth` to authenticate' };
}
return { status: 'pass', label: 'Gmail configured', detail: `(output: ${gmail.output.channel}/${gmail.output.peer})` };
const modes: string[] = [];
const warnings: string[] = [];
const topicRaw = (gmail.pubsub_topic ?? process.env.FLYNN_GMAIL_PUBSUB_TOPIC ?? '').trim();
const pushEnabled = Boolean(topicRaw) && !gmail.disable_push;
if (pushEnabled) {
modes.push('push');
if (topicRaw.includes('/')) {
const ok = /^projects\/[^/]+\/topics\/[^/]+$/.test(topicRaw);
if (!ok) {
warnings.push('pubsub_topic format invalid (expected projects/<project>/topics/<topic>)');
}
} else if (!googleProjectId) {
warnings.push('pubsub_topic shorthand requires project_id in Gmail credentials');
}
if (ctx.config.server?.tailscale?.serve) {
warnings.push('push requires a public HTTPS endpoint; Tailscale Serve is typically tailnet-only');
}
} else if (gmail.disable_push && topicRaw) {
warnings.push('push disabled (disable_push=true)');
}
const subRaw = (gmail.pubsub_subscription_id ?? '').trim();
if (subRaw) {
modes.push('pull');
if (subRaw.includes('/')) {
const ok = /^projects\/[^/]+\/subscriptions\/[^/]+$/.test(subRaw);
if (!ok) {
warnings.push('pubsub_subscription_id format invalid (expected projects/<project>/subscriptions/<sub>)');
}
} else if (!googleProjectId) {
warnings.push('pubsub_subscription_id shorthand requires project_id in Gmail credentials');
}
}
modes.push('poll');
const detail = `(${modes.join(' + ')} -> ${gmail.output.channel}/${gmail.output.peer})`;
const withWarnings = warnings.length > 0 ? `${detail}${warnings.join('; ')}` : detail;
return { status: warnings.length > 0 ? 'warn' : 'pass', label: 'Gmail configured', detail: withWarnings };
};
const allChecks: Check[] = [
+2
View File
@@ -18,6 +18,7 @@ import { registerGcalAuthCommand } from './gcal-auth.js';
import { registerGdocsAuthCommand } from './gdocs-auth.js';
import { registerGdriveAuthCommand } from './gdrive-auth.js';
import { registerGtasksAuthCommand } from './gtasks-auth.js';
import { registerOpenaiAuthCommand } from './openai-auth.js';
import { registerSkillsCommand } from './skills.js';
export function createProgram(): Command {
@@ -41,6 +42,7 @@ export function createProgram(): Command {
registerGdocsAuthCommand(program);
registerGdriveAuthCommand(program);
registerGtasksAuthCommand(program);
registerOpenaiAuthCommand(program);
registerSkillsCommand(program);
return program;
+35
View File
@@ -0,0 +1,35 @@
import type { Command } from 'commander';
import { loadStoredOpenAIAuth, loginOpenAI } from '../auth/index.js';
export function registerOpenaiAuthCommand(program: Command): void {
program
.command('openai-auth')
.description('Authenticate OpenAI (ChatGPT Plus/Pro) via OAuth device flow')
.action(async () => {
const existing = loadStoredOpenAIAuth();
if (existing) {
console.log('OpenAI OAuth token already exists.');
console.log('Delete ~/.config/flynn/auth.json openai entry if you want to re-authenticate.');
process.exit(0);
}
console.log('Starting OpenAI OAuth device flow...');
console.log('');
try {
await loginOpenAI((userCode, verificationUri) => {
console.log(`Please visit: ${verificationUri}`);
console.log(`Enter code: ${userCode}`);
console.log('');
console.log('Waiting for authorization...');
});
console.log('');
console.log('OpenAI authentication successful! Token stored.');
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
console.error(`OpenAI login failed: ${message}`);
process.exit(1);
}
});
}
+39
View File
@@ -53,6 +53,45 @@ models:
expect(result.config!.telegram?.bot_token).toBe('test-token');
});
it('loads env vars from FLYNN_ENV_FILE before parsing config', () => {
const prevEnvFile = process.env.FLYNN_ENV_FILE;
const prevToken = process.env.TEST_BOT_TOKEN;
delete process.env.TEST_BOT_TOKEN;
mkdirSync(testDir, { recursive: true });
const envPath = join(testDir, 'cloud.env');
const configPath = join(testDir, 'config.yaml');
writeFileSync(envPath, 'TEST_BOT_TOKEN=test-token\n');
process.env.FLYNN_ENV_FILE = envPath;
writeFileSync(configPath, `
telegram:
bot_token: \${TEST_BOT_TOKEN}
allowed_chat_ids: [123]
models:
default:
provider: anthropic
model: claude-sonnet
`);
const result = loadConfigSafe(configPath);
expect(result.config).toBeDefined();
expect(result.error).toBeUndefined();
expect(result.config!.telegram?.bot_token).toBe('test-token');
if (prevEnvFile !== undefined) {
process.env.FLYNN_ENV_FILE = prevEnvFile;
} else {
delete process.env.FLYNN_ENV_FILE;
}
if (prevToken !== undefined) {
process.env.TEST_BOT_TOKEN = prevToken;
} else {
delete process.env.TEST_BOT_TOKEN;
}
});
it('returns error when file not found', () => {
const result = loadConfigSafe('/nonexistent/config.yaml');
expect(result.config).toBeUndefined();
+33
View File
@@ -2,6 +2,38 @@ import { loadConfig } from '../config/index.js';
import type { Config } from '../config/index.js';
import { resolve, dirname, join } from 'path';
import { homedir } from 'os';
import { existsSync, readFileSync } from 'fs';
function loadEnvFileIfPresent(): void {
const envFile = process.env.FLYNN_ENV_FILE ?? resolve(homedir(), '.config/flynn/cloud.env');
if (!existsSync(envFile)) {
return;
}
const raw = readFileSync(envFile, 'utf-8');
for (const line of raw.split(/\r?\n/)) {
const trimmed = line.trim();
if (!trimmed || trimmed.startsWith('#')) {
continue;
}
const idx = trimmed.indexOf('=');
if (idx <= 0) {
continue;
}
const key = trimmed.slice(0, idx).trim();
const value = trimmed.slice(idx + 1);
if (!key) {
continue;
}
// Don't override existing env vars.
if (process.env[key] === undefined) {
process.env[key] = value;
}
}
}
/** Get the config file path from env or default location. */
export function getConfigPath(): string {
@@ -30,6 +62,7 @@ export function resolveOverlayPath(basePath: string): string | undefined {
export function loadConfigSafe(configPath?: string): { config?: Config; error?: string } {
const path = configPath ?? getConfigPath();
try {
loadEnvFileIfPresent();
const overlayPath = resolveOverlayPath(path);
const config = loadConfig(path, overlayPath);
return { config };
+49 -22
View File
@@ -1,5 +1,5 @@
import type { Command } from 'commander';
import type { Config } from '../config/index.js';
import type { Config, ModelConfig, ModelProvider } from '../config/index.js';
import { loadConfigSafe, getConfigPath } from './shared.js';
import { existsSync, mkdirSync, readFileSync } from 'fs';
import { resolve } from 'path';
@@ -58,6 +58,26 @@ function loadSystemPrompt(): string {
return 'You are Flynn, a helpful personal AI assistant. Be direct, concise, and helpful. Use markdown when it improves readability.';
}
function buildProviderConfigMap(config: Config): Partial<Record<ModelProvider, ModelConfig>> {
const providerConfigs: Partial<Record<ModelProvider, ModelConfig>> = {};
const modelConfigs: ModelConfig[] = [
config.models.default,
...(config.models.fast ? [config.models.fast] : []),
...(config.models.complex ? [config.models.complex] : []),
...(config.models.local ? [config.models.local] : []),
...Object.values(config.models.local_providers ?? {}),
];
for (const modelConfig of modelConfigs) {
providerConfigs[modelConfig.provider] = modelConfig;
if (modelConfig.fallback) {
providerConfigs[modelConfig.fallback.provider] = modelConfig.fallback;
}
}
return providerConfigs;
}
export function registerTuiCommand(program: Command): void {
program
.command('tui')
@@ -179,6 +199,7 @@ export function registerTuiCommand(program: Command): void {
const toolExecutor = new ToolExecutor(toolRegistry, hookEngine);
const session = sessionManager.getSession('tui', 'local');
const modelProviderConfigs = buildProviderConfigMap(config);
const agent = new NativeAgent({
modelClient: modelRouter,
@@ -211,29 +232,33 @@ export function registerTuiCommand(program: Command): void {
process.exit(0);
});
if (opts.fullscreen) {
await startFullscreenTui({
session,
modelClient: modelRouter,
modelRouter,
systemPrompt,
model: config.models.default.model,
agent,
onExit: cleanup,
});
} else {
if (opts.fullscreen) {
await startFullscreenTui({
session,
modelClient: modelRouter,
modelRouter,
systemPrompt,
model: config.models.default.model,
agent,
hookEngine,
modelProviderConfigs,
onExit: cleanup,
});
} else {
let switchingToFullscreen = false;
const tui = new MinimalTui({
session,
modelClient: modelRouter,
modelRouter,
systemPrompt,
agent,
pairingManager,
localProviders: config.models.local_providers,
currentLocalProvider: config.models.local?.provider,
onTransfer: (target) => {
const tui = new MinimalTui({
session,
modelClient: modelRouter,
modelRouter,
systemPrompt,
agent,
hookEngine,
pairingManager,
localProviders: config.models.local_providers,
modelProviderConfigs,
currentLocalProvider: config.models.local?.provider,
onTransfer: (target) => {
if (target === 'telegram') {
if (config.telegram && config.telegram.allowed_chat_ids.length > 0) {
const telegramUserId = String(config.telegram.allowed_chat_ids[0]);
@@ -263,6 +288,8 @@ export function registerTuiCommand(program: Command): void {
systemPrompt,
model: config.models.default.model,
agent,
hookEngine,
modelProviderConfigs,
onExit: cleanup,
});
return;
+37
View File
@@ -0,0 +1,37 @@
import { describe, it, expect, vi } from 'vitest';
import { createModelCommand } from './index.js';
describe('builtin /model command', () => {
it('passes through the full argument string', async () => {
const cmd = createModelCommand();
const setModel = vi.fn(() => 'ok');
const result = await cmd.execute(['default', 'github/gpt-5-mini'], {
channel: 'test',
senderId: 'user',
sessionId: 's1',
rawInput: '/model default github/gpt-5-mini',
services: { setModel },
});
expect(setModel).toHaveBeenCalledWith('default github/gpt-5-mini');
expect(result).toEqual({ handled: true, text: 'ok' });
});
it('still works for single-argument tier switching', async () => {
const cmd = createModelCommand();
const setModel = vi.fn(() => 'switched');
const result = await cmd.execute(['fast'], {
channel: 'test',
senderId: 'user',
sessionId: 's1',
rawInput: '/model fast',
services: { setModel },
});
expect(setModel).toHaveBeenCalledWith('fast');
expect(result).toEqual({ handled: true, text: 'switched' });
});
});
+3 -1
View File
@@ -86,7 +86,9 @@ export function createModelCommand(): CommandDefinition {
return {
handled: true,
text: await ctx.services.setModel(args[0]),
// Pass through the full argument string so frontends can support
// richer syntax like: /model <tier> <provider/model>
text: await ctx.services.setModel(args.join(' ')),
};
},
};
+26
View File
@@ -49,6 +49,8 @@ const modelConfigBaseSchema = z.object({
endpoint: z.string().optional(),
api_key: z.string().optional(),
auth_token: z.string().optional(),
/** Use OAuth credential flow (provider-specific). */
use_oauth: z.boolean().optional(),
for: z.array(z.string()).optional(),
num_gpu: z.number().optional(),
context_window: z.number().optional(),
@@ -178,6 +180,30 @@ const gmailSchema = z.object({
enabled: z.boolean().default(false),
credentials_file: z.string().optional(),
token_file: z.string().default('~/.config/flynn/gmail-token.json'),
/**
* Optional Google Cloud Pub/Sub topic for Gmail push notifications.
* Format: projects/<project-id>/topics/<topic>
* If omitted, push notifications are disabled and Flynn will use polling.
*/
pubsub_topic: z.string().optional(),
/**
* Explicitly disable Gmail push watch registration even if pubsub_topic is set.
* Useful for environments where Google cannot reach the gateway (e.g. tailnet-only).
*/
disable_push: z.boolean().default(false),
/**
* Optional Pub/Sub subscription for pull-based delivery (no inbound webhook required).
* Format: projects/<project-id>/subscriptions/<subscription>
*/
pubsub_subscription_id: z.string().optional(),
/** How often to pull messages from pubsub_subscription_id (e.g. '60s'). */
pubsub_pull_interval: z.string().default('60s'),
/** Max messages to pull per cycle (1..100). */
pubsub_max_messages: z.number().min(1).max(100).default(10),
watch_labels: z.array(z.string()).default(['INBOX']),
poll_interval: z.string().default('300s'),
history_start: z.string().optional(), // ISO date string — only process emails after this date
+28
View File
@@ -96,6 +96,34 @@ describe('createClientFromConfig', () => {
expect(client).toBeInstanceOf(OpenAIClient);
});
it('creates OpenAIClient for zhipuai when using auth_token', () => {
const client = createClientFromConfig({
provider: 'zhipuai',
model: 'glm-4.5',
auth_token: 'oauth-access-token',
});
expect(client).toBeInstanceOf(OpenAIClient);
});
it('creates OpenAIClient for zhipuai using ZHIPUAI_AUTH_TOKEN env var', () => {
const prev = process.env.ZHIPUAI_AUTH_TOKEN;
process.env.ZHIPUAI_AUTH_TOKEN = 'oauth-access-token';
try {
const client = createClientFromConfig({
provider: 'zhipuai',
model: 'glm-4.5',
});
expect(client).toBeInstanceOf(OpenAIClient);
} finally {
if (prev === undefined) {
delete process.env.ZHIPUAI_AUTH_TOKEN;
} else {
process.env.ZHIPUAI_AUTH_TOKEN = prev;
}
}
});
it('creates BedrockClient for bedrock provider', () => {
const client = createClientFromConfig({
provider: 'bedrock',
+19 -1
View File
@@ -18,6 +18,23 @@ function requireApiKey(cfg: ModelConfig, envVar: string): string {
return key;
}
function resolveAuthCredential(cfg: ModelConfig, apiKeyEnvVar: string, authTokenEnvVar?: string): string {
const raw = cfg.api_key
?? cfg.auth_token
?? process.env[apiKeyEnvVar]
?? (authTokenEnvVar ? process.env[authTokenEnvVar] : undefined);
if (!raw) {
const envHint = authTokenEnvVar ? `${apiKeyEnvVar} or ${authTokenEnvVar}` : apiKeyEnvVar;
throw new Error(
`Credential required for ${cfg.provider}. ` +
`Set ${envHint} environment variable or provide api_key/auth_token in config.`,
);
}
return raw.startsWith('Bearer ') ? raw.slice('Bearer '.length) : raw;
}
/**
* Create a ModelClient from a provider config entry.
* Dispatches on the `provider` field so all tiers and fallback entries
@@ -35,6 +52,7 @@ export function createClientFromConfig(cfg: ModelConfig): ModelClient {
return new OpenAIClient({
model: cfg.model,
apiKey: cfg.api_key,
useOAuth: Boolean(cfg.use_oauth),
});
case 'ollama':
return new OllamaClient({
@@ -62,7 +80,7 @@ export function createClientFromConfig(cfg: ModelConfig): ModelClient {
case 'zhipuai':
return new OpenAIClient({
model: cfg.model,
apiKey: requireApiKey(cfg, 'ZHIPUAI_API_KEY'),
apiKey: resolveAuthCredential(cfg, 'ZHIPUAI_API_KEY', 'ZHIPUAI_AUTH_TOKEN'),
baseURL: cfg.endpoint ?? 'https://api.z.ai/api/paas/v4',
});
case 'xai':
+112 -8
View File
@@ -9,7 +9,7 @@ import { MemoryStore } from '../memory/index.js';
import type { Tool } from '../tools/types.js';
import { createMediaSendTool } from '../tools/index.js';
import { createSandboxedShellTool, createSandboxedProcessStartTool, SandboxManager } from '../sandbox/index.js';
import type { Config } from '../config/index.js';
import { MODEL_PROVIDERS, type Config, type ModelConfig, type ModelProvider } from '../config/index.js';
import { ModelRouter, type ModelTier } from '../models/index.js';
import { ToolRegistry, ToolExecutor } from '../tools/index.js';
import { SessionManager } from '../session/index.js';
@@ -17,6 +17,27 @@ import { AgentConfigRegistry, AgentRouter } from '../agents/index.js';
import type { CommandRegistry } from '../commands/index.js';
import type { ComponentRegistry } from '../intents/index.js';
import type { RoutingPolicy } from '../routing/index.js';
import { createClientFromConfig } from './models.js';
function buildProviderConfigMap(config: Config): Partial<Record<ModelProvider, ModelConfig>> {
const providerConfigs: Partial<Record<ModelProvider, ModelConfig>> = {};
const modelConfigs: ModelConfig[] = [
config.models.default,
...(config.models.fast ? [config.models.fast] : []),
...(config.models.complex ? [config.models.complex] : []),
...(config.models.local ? [config.models.local] : []),
...Object.values(config.models.local_providers ?? {}),
];
for (const modelConfig of modelConfigs) {
providerConfigs[modelConfig.provider] = modelConfig;
if (modelConfig.fallback) {
providerConfigs[modelConfig.fallback.provider] = modelConfig.fallback;
}
}
return providerConfigs;
}
/**
* Create the unified message handler for the channel registry.
@@ -263,14 +284,97 @@ export function createMessageRouter(deps: {
return lines.join('\n');
},
setModel: (tier) => {
const validTiers = deps.modelRouter.getAvailableTiers();
if (!validTiers.includes(tier as ModelTier)) {
return `Model tier not available: ${tier}`;
const raw = tier.trim();
if (!raw) {
return 'Usage: /model <tier> OR /model <tier> <provider/model> OR /model <tier> reset';
}
const parts = raw.split(/\s+/);
const requestedTier = parts[0];
const validTiers = deps.modelRouter.getAvailableTiers();
if (!validTiers.includes(requestedTier as ModelTier)) {
return `Model tier not available: ${requestedTier}`;
}
const modelTier = requestedTier as ModelTier;
// /model <tier>
if (parts.length === 1) {
session.setConfig('modelTier', modelTier);
agent.setModelTier(modelTier);
const label = deps.modelRouter.getLabel(modelTier);
return `Switched to model: ${modelTier} (${label})`;
}
const arg2 = parts[1];
// /model <tier> reset — restore configured provider/model and re-enable fallbacks
if (arg2.toLowerCase() === 'reset') {
const configured: ModelConfig | undefined = modelTier === 'default'
? deps.config.models.default
: modelTier === 'fast'
? deps.config.models.fast
: modelTier === 'complex'
? deps.config.models.complex
: modelTier === 'local'
? deps.config.models.local
: undefined;
if (!configured) {
return `No configured model for tier: ${modelTier}`;
}
const client = createClientFromConfig(configured);
const label = `${configured.provider}/${configured.model}`;
deps.modelRouter.setClient(modelTier, client, label);
deps.modelRouter.setTierStrict(modelTier, false);
session.setConfig('modelTier', modelTier);
agent.setModelTier(modelTier);
return `Reset ${modelTier} to: ${label}`;
}
// /model <tier> <provider/model>
const providerModel = arg2;
if (!providerModel.includes('/')) {
return 'Invalid format. Use: /model <tier> <provider/model> (e.g. /model default github/gpt-5-mini)';
}
const slashIdx = providerModel.indexOf('/');
const provider = providerModel.slice(0, slashIdx);
const model = providerModel.slice(slashIdx + 1);
if (!MODEL_PROVIDERS.includes(provider as ModelProvider)) {
return `Unknown provider "${provider}". Known providers: ${MODEL_PROVIDERS.join(', ')}`;
}
const providerType = provider as ModelProvider;
const providerConfigs = buildProviderConfigMap(deps.config);
const template = providerConfigs[providerType];
try {
const client = createClientFromConfig(
template
? { ...template, provider: providerType, model }
: { provider: providerType, model },
);
deps.modelRouter.setClient(modelTier, client, providerModel);
deps.modelRouter.setTierStrict(modelTier, true);
session.setConfig('modelTier', modelTier);
agent.setModelTier(modelTier);
const lines = [
`Set ${modelTier} to: ${providerModel}`,
`Fallbacks for ${modelTier} disabled (strict tier mode).`,
];
if (parts.length > 2) {
lines.push(`Note: ignored extra args: ${parts.slice(2).join(' ')}`);
}
return lines.join('\n');
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
return `Failed to switch ${modelTier} to ${providerModel}: ${message}`;
}
session.setConfig('modelTier', tier);
agent.setModelTier(tier as ModelTier);
const label = deps.modelRouter.getLabel(tier as ModelTier);
return `Switched to model: ${tier} (${label})`;
},
compact: async () => {
const result = await agent.compact();
+2 -2
View File
@@ -124,7 +124,7 @@ Commands:
/model [name] Show or switch model tier (local, default, fast, complex)
/model <tier> <p/m> Change tier's provider/model (e.g. /model default anthropic/claude-sonnet-4)
/backend [provider] Show or switch local backend (ollama, llamacpp)
/login [provider] Authenticate with GitHub
/login [provider] Authenticate with GitHub or OpenAI
/pair List pending pairing codes and approved senders
/pair generate [label] Generate a new DM pairing code
/pair revoke <ch> <id> Revoke an approved sender
@@ -178,7 +178,7 @@ export const COMMAND_TOOLTIPS: Record<string, string> = {
'/status': 'Show session info and token usage',
'/fullscreen': 'Switch to fullscreen mode',
'/fs': 'Switch to fullscreen mode',
'/login': 'Authenticate with GitHub (OAuth device flow)',
'/login': 'Authenticate with GitHub or OpenAI (OAuth device flow)',
'/pair': 'Generate/list/revoke DM pairing codes',
'/transfer': 'Transfer session to another frontend',
'/quit': 'Exit TUI',
+184 -59
View File
@@ -1,5 +1,5 @@
import React, { useState, useCallback, useRef, useEffect } from 'react';
import { Box, useApp, useInput } from 'ink';
import { Box, Text, useApp, useInput } from 'ink';
import { StatusBar } from './StatusBar.js';
import { MessageList } from './MessageList.js';
import { InputBar } from './InputBar.js';
@@ -7,8 +7,11 @@ import { parseCommand, getHelpText, resolveModelAlias, getCommandCompletions } f
import type { Message, ModelClient, TokenUsage } from '../../../models/types.js';
import type { ModelRouter } from '../../../models/router.js';
import type { ManagedSession } from '../../../session/index.js';
import type { NativeAgent } from '../../../backends/native/agent.js';
import type { ToolUseEvent } from '../../../backends/native/agent.js';
import type { NativeAgent, ToolUseEvent } from '../../../backends/native/agent.js';
import type { HookEngine, HookResult } from '../../../hooks/index.js';
import type { ModelConfig, ModelProvider } from '../../../config/schema.js';
import { MODEL_PROVIDERS } from '../../../config/schema.js';
import { createClientFromConfig } from '../../../daemon/index.js';
/** Format a tool name like "gmail.list" -> "Gmail: List" */
function formatToolName(name: string): string {
@@ -44,6 +47,8 @@ export interface AppProps {
systemPrompt: string;
model: string;
agent?: NativeAgent;
hookEngine?: HookEngine;
modelProviderConfigs?: Partial<Record<ModelProvider, ModelConfig>>;
onExit?: () => void;
}
@@ -54,6 +59,8 @@ export function App({
systemPrompt,
model,
agent,
hookEngine,
modelProviderConfigs,
onExit,
}: AppProps): React.ReactElement {
const { exit } = useApp();
@@ -63,13 +70,20 @@ export function App({
const [streamingContent, setStreamingContent] = useState('');
const [scrollOffset, setScrollOffset] = useState(0);
const [tokenUsage, setTokenUsage] = useState<TokenUsage>({ inputTokens: 0, outputTokens: 0 });
const [currentModel, setCurrentModel] = useState(model);
const [currentModel, setCurrentModel] = useState(() => {
if (!modelRouter) {return model;}
return modelRouter.getLabel(modelRouter.getTier());
});
const abortRef = useRef(false);
const toolLinesRef = useRef<string[]>([]);
const confirmResolveRef = useRef<((result: HookResult) => void) | null>(null);
const [confirmation, setConfirmation] = useState<{ tool: string; args: Record<string, unknown> } | null>(null);
// Set up an Ink-compatible onToolUse callback for the agent.
// This replaces the process.stdout.write callback (which corrupts Ink rendering)
// with one that updates React state to show tool activity in the streaming area.
// This replaces process.stdout writes (which corrupt Ink rendering)
// with one that updates React state to show tool activity.
useEffect(() => {
if (!agent) {return;}
@@ -79,7 +93,10 @@ export function App({
const argsStr = event.args ? ` (${formatToolArgs(event.args)})` : '';
toolLinesRef.current = [...toolLinesRef.current, `> ${label}${argsStr}`];
setStreamingContent(toolLinesRef.current.join('\n'));
} else if (event.type === 'end' && event.result) {
return;
}
if (event.type === 'end' && event.result) {
const icon = event.result.success ? 'done' : 'error';
const detail = event.result.success
? `(${event.result.output.split('\n').length} lines)`
@@ -95,7 +112,43 @@ export function App({
};
}, [agent]);
// Inline confirmations for dangerous tools (e.g. shell.exec) in fullscreen mode.
useEffect(() => {
if (!hookEngine) {return;}
hookEngine.setInteractiveConfirmer(async (pending) => {
return await new Promise<HookResult>((resolve) => {
confirmResolveRef.current = resolve;
setConfirmation({ tool: pending.tool, args: pending.args });
});
});
return () => {
hookEngine.setInteractiveConfirmer(undefined);
confirmResolveRef.current = null;
setConfirmation(null);
};
}, [hookEngine]);
useInput((inputChar, key) => {
// Confirmation prompt mode: capture y/n and ignore everything else.
if (confirmation && confirmResolveRef.current) {
const c = inputChar.toLowerCase();
if (c === 'y') {
confirmResolveRef.current({ approved: true });
confirmResolveRef.current = null;
setConfirmation(null);
return;
}
if (c === 'n') {
confirmResolveRef.current({ approved: false, reason: 'Denied by user' });
confirmResolveRef.current = null;
setConfirmation(null);
return;
}
return;
}
if (key.escape) {
if (isStreaming) {
abortRef.current = true;
@@ -120,7 +173,6 @@ export function App({
return;
}
// Scroll handling
if (key.upArrow && scrollOffset > 0) {
setScrollOffset(prev => Math.max(0, prev - 1));
}
@@ -136,12 +188,15 @@ export function App({
});
const handleSubmit = useCallback(async (value: string) => {
if (confirmation) {
return;
}
const command = parseCommand(value);
if (!command) {return;}
setInput('');
// Handle commands
switch (command.type) {
case 'quit':
onExit?.();
@@ -160,69 +215,124 @@ export function App({
return;
case 'help': {
// Show help as system message
const helpMsg: Message = { role: 'assistant', content: getHelpText() };
const helpWithTs = session.addMessage(helpMsg);
setMessages(prev => [...prev, helpWithTs]);
setMessages(prev => [...prev, session.addMessage(helpMsg)]);
return;
}
case 'status': {
const status = `Session: ${session.id}\nMessages: ${messages.length}\nTokens: ${tokenUsage.inputTokens} in / ${tokenUsage.outputTokens} out`;
const statusMsg: Message = { role: 'assistant', content: status };
const statusWithTs = session.addMessage(statusMsg);
setMessages(prev => [...prev, statusWithTs]);
setMessages(prev => [...prev, session.addMessage(statusMsg)]);
return;
}
case 'model': {
if (!modelRouter) {
const errMsg: Message = { role: 'assistant', content: 'Model switching not available.' };
const errWithTs = session.addMessage(errMsg);
setMessages(prev => [...prev, errWithTs]);
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: 'Model switching not available.' })]);
return;
}
// /model
if (!command.name) {
const info = `Current: ${modelRouter.getTier()}\nAvailable: ${modelRouter.getAvailableTiers().join(', ')}`;
const infoMsg: Message = { role: 'assistant', content: info };
const infoWithTs = session.addMessage(infoMsg);
setMessages(prev => [...prev, infoWithTs]);
const current = modelRouter.getTier();
const available = modelRouter.getAvailableTiers();
const labels = modelRouter.getAllLabels();
const lines: string[] = [];
lines.push(`Active tier: ${current}`);
for (const t of available) {
const label = labels[t] ?? 'unknown';
const strict = modelRouter.isTierStrict(t) ? ' (strict)' : '';
lines.push(` ${t}: ${label}${strict}${t === current ? ' ←' : ''}`);
}
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: lines.join('\n') })]);
return;
}
// /model <tier> <provider/model>
if (command.providerModel) {
const tier = resolveModelAlias(command.name);
const providerModel = command.providerModel;
const slashIdx = providerModel.indexOf('/');
if (slashIdx === -1) {
setMessages(prev => [...prev, session.addMessage({
role: 'assistant',
content: 'Invalid format. Use provider/model (e.g. anthropic/claude-sonnet-4)',
})]);
return;
}
const provider = providerModel.slice(0, slashIdx);
const modelName = providerModel.slice(slashIdx + 1);
if (!MODEL_PROVIDERS.includes(provider as ModelProvider)) {
setMessages(prev => [...prev, session.addMessage({
role: 'assistant',
content: `Unknown provider "${provider}". Known providers: ${MODEL_PROVIDERS.join(', ')}`,
})]);
return;
}
try {
const providerType = provider as ModelProvider;
const template = modelProviderConfigs?.[providerType];
const client = createClientFromConfig({
...(template ?? {}),
provider: providerType,
model: modelName,
});
modelRouter.setClient(tier, client, providerModel);
modelRouter.setTierStrict(tier, true);
if (agent && tier === modelRouter.getTier()) {
agent.setModelTier(tier);
setCurrentModel(modelRouter.getLabel(tier));
}
setMessages(prev => [...prev, session.addMessage({
role: 'assistant',
content: `Set ${tier} to: ${providerModel}\nFallbacks for ${tier} disabled (strict tier mode).`,
})]);
} catch (error) {
const msg = error instanceof Error ? error.message : String(error);
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: `Failed to create client: ${msg}` })]);
}
return;
}
// /model <tier>
const tier = resolveModelAlias(command.name);
if (modelRouter.setTier(tier)) {
// Also update the agent tier so chatWithRouter uses the correct client
if (agent) {
agent.setModelTier(tier);
}
setCurrentModel(tier);
const successMsg: Message = { role: 'assistant', content: `Switched to model: ${tier}` };
const successWithTs = session.addMessage(successMsg);
setMessages(prev => [...prev, successWithTs]);
setCurrentModel(modelRouter.getLabel(tier));
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: `Switched to model: ${tier}` })]);
} else {
const failMsg: Message = { role: 'assistant', content: `Model not available: ${command.name}` };
const failWithTs = session.addMessage(failMsg);
setMessages(prev => [...prev, failWithTs]);
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: `Model not available: ${command.name}` })]);
}
return;
}
case 'fullscreen':
// Already in fullscreen
return;
case 'transfer': {
const xferMsg: Message = { role: 'assistant', content: 'Transfer not supported in fullscreen mode.' };
const xferWithTs = session.addMessage(xferMsg);
setMessages(prev => [...prev, xferWithTs]);
case 'transfer':
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: 'Transfer not supported in fullscreen mode.' })]);
return;
}
case 'message':
break; // Continue to message handling
break;
}
if (command.type !== 'message' || isStreaming) {return;}
if (command.type !== 'message' || isStreaming) {
return;
}
// Add user message to UI (and session if no agent — agent adds it internally)
const userMessage: Message = { role: 'user', content: command.content };
@@ -232,9 +342,8 @@ export function App({
} else {
setMessages(prev => [...prev, { ...userMessage, timestamp: Date.now() }]);
}
setScrollOffset(0); // Auto-scroll to bottom
setScrollOffset(0);
// Process response
setIsStreaming(true);
setStreamingContent('');
toolLinesRef.current = [];
@@ -242,16 +351,11 @@ export function App({
try {
if (agent) {
// agent.process() handles session history internally
const response = await agent.process(command.content);
await agent.process(command.content);
const usage = agent.getUsage();
setTokenUsage({ inputTokens: usage.inputTokens, outputTokens: usage.outputTokens });
// Sync UI with session history (agent already added messages to session)
setMessages(session.getHistory());
} else if (modelClient.chatStream) {
// Fallback: direct streaming without tools
let fullContent = '';
for await (const event of modelClient.chatStream({
@@ -279,10 +383,8 @@ export function App({
}
const assistantMessage: Message = { role: 'assistant', content: fullContent };
const assistantWithTimestamp = session.addMessage(assistantMessage);
setMessages(prev => [...prev, assistantWithTimestamp]);
setMessages(prev => [...prev, session.addMessage(assistantMessage)]);
} else {
// Fallback: non-streaming without tools
const response = await modelClient.chat({
messages: session.getHistory(),
system: systemPrompt,
@@ -294,21 +396,30 @@ export function App({
}));
const assistantMessage: Message = { role: 'assistant', content: response.content };
const assistantWithTimestamp = session.addMessage(assistantMessage);
setMessages(prev => [...prev, assistantWithTimestamp]);
setMessages(prev => [...prev, session.addMessage(assistantMessage)]);
}
} catch (error) {
const errorMessage: Message = {
role: 'assistant',
content: `Error: ${error instanceof Error ? error.message : 'Unknown error'}`,
};
const errorWithTimestamp = session.addMessage(errorMessage);
setMessages(prev => [...prev, errorWithTimestamp]);
const msg = error instanceof Error ? error.message : 'Unknown error';
setMessages(prev => [...prev, session.addMessage({ role: 'assistant', content: `Error: ${msg}` })]);
} finally {
setIsStreaming(false);
setStreamingContent('');
}
}, [isStreaming, session, agent, modelClient, modelRouter, systemPrompt, exit, onExit, messages.length, tokenUsage]);
}, [
confirmation,
session,
agent,
modelClient,
modelRouter,
systemPrompt,
exit,
onExit,
isStreaming,
messages.length,
tokenUsage.inputTokens,
tokenUsage.outputTokens,
modelProviderConfigs,
]);
return (
<Box flexDirection="column" height="100%">
@@ -317,13 +428,27 @@ export function App({
scrollOffset={scrollOffset}
streamingContent={isStreaming ? streamingContent : undefined}
/>
{confirmation ? (
<Box paddingX={1} paddingY={0} borderStyle="round" borderColor="yellow">
<Text color="yellow">
Confirmation required: {confirmation.tool}{' '}
{Object.keys(confirmation.args).length > 0 ? JSON.stringify(confirmation.args) : ''}
</Text>
<Text color="yellow">Press y to approve, n to deny.</Text>
</Box>
) : null}
<InputBar
value={input}
onChange={setInput}
onSubmit={handleSubmit}
isLoading={isStreaming}
placeholder={isStreaming ? 'Flynn is typing... (Esc to cancel)' : 'Type a message... (Esc=exit, /help)'}
isLoading={isStreaming || !!confirmation}
placeholder={confirmation
? 'Confirmation required (press y/n)'
: (isStreaming ? 'Flynn is typing... (Esc to cancel)' : 'Type a message... (Esc=exit, /help)')}
/>
<StatusBar
sessionId={session.id}
messageCount={messages.length}
+10
View File
@@ -5,6 +5,8 @@ import type { ManagedSession } from '../../session/index.js';
import type { ModelClient } from '../../models/types.js';
import type { ModelRouter } from '../../models/router.js';
import type { NativeAgent } from '../../backends/native/agent.js';
import type { HookEngine } from '../../hooks/index.js';
import type { ModelConfig, ModelProvider } from '../../config/index.js';
export interface FullscreenTuiConfig {
session: ManagedSession;
@@ -13,6 +15,8 @@ export interface FullscreenTuiConfig {
systemPrompt: string;
model: string;
agent?: NativeAgent;
hookEngine?: HookEngine;
modelProviderConfigs?: Partial<Record<ModelProvider, ModelConfig>>;
onExit?: () => void;
}
@@ -22,6 +26,10 @@ export async function startFullscreenTui(config: FullscreenTuiConfig): Promise<v
process.stdin.resume();
}
if (config.agent && config.modelRouter) {
config.agent.setModelTier(config.modelRouter.getTier());
}
const { waitUntilExit } = render(
React.createElement(App, {
session: config.session,
@@ -30,6 +38,8 @@ export async function startFullscreenTui(config: FullscreenTuiConfig): Promise<v
systemPrompt: config.systemPrompt,
model: config.model,
agent: config.agent,
hookEngine: config.hookEngine,
modelProviderConfigs: config.modelProviderConfigs,
onExit: config.onExit,
}),
);
+53
View File
@@ -118,4 +118,57 @@ describe('MinimalTui backend command', () => {
expect(mockRouter.setTier).toHaveBeenCalledWith('local');
expect(mockAgent.setModelTier).toHaveBeenCalledWith('local');
});
it('reuses configured provider credentials for /model <tier> <provider/model>', () => {
const prevOpenRouterKey = process.env.OPENROUTER_API_KEY;
delete process.env.OPENROUTER_API_KEY;
try {
const mockSession = {
id: 'test',
getHistory: () => [],
addMessage: vi.fn(),
clear: vi.fn(),
replaceHistory: vi.fn(),
};
const mockRouter = {
getTier: () => 'default' as const,
getAvailableTiers: () => ['default', 'local'],
setTier: vi.fn(() => true),
getLocalProviderName: () => 'ollama',
setLocalClient: vi.fn(),
setClient: vi.fn(),
setTierStrict: vi.fn(),
chat: vi.fn(),
getClient: vi.fn(),
};
const tui = new MinimalTui({
session: mockSession as any,
modelClient: mockRouter as any,
modelRouter: mockRouter as any,
systemPrompt: 'test',
modelProviderConfigs: {
openrouter: {
provider: 'openrouter',
model: 'seed-model',
api_key: 'test-key',
endpoint: 'https://openrouter.ai/api/v1',
},
},
});
(tui as any).handleModelCommand('default', 'openrouter/deepseek/deepseek-chat');
expect(mockRouter.setClient).toHaveBeenCalledOnce();
expect(mockRouter.setTierStrict).toHaveBeenCalledWith('default', true);
} finally {
if (prevOpenRouterKey) {
process.env.OPENROUTER_API_KEY = prevOpenRouterKey;
} else {
delete process.env.OPENROUTER_API_KEY;
}
}
});
});
+80 -18
View File
@@ -9,9 +9,10 @@ import type { ModelConfig, ModelProvider } from '../../config/schema.js';
import { MODEL_PROVIDERS } from '../../config/schema.js';
import { OllamaClient, LlamaCppClient } from '../../models/index.js';
import { createClientFromConfig } from '../../daemon/index.js';
import { loginGitHub } from '../../auth/index.js';
import { loginGitHub, loginOpenAI } from '../../auth/index.js';
import type { PairingManager } from '../../channels/pairing.js';
import { getColoredBanner } from './banner.js';
import type { HookEngine } from '../../hooks/index.js';
export { parseCommand, type Command };
@@ -42,8 +43,10 @@ export interface MinimalTuiConfig {
onFullscreen?: () => void;
onTransfer?: (target: string) => void;
localProviders?: Record<string, ModelConfig>;
modelProviderConfigs?: Partial<Record<ModelProvider, ModelConfig>>;
currentLocalProvider?: string;
pairingManager?: PairingManager;
hookEngine?: HookEngine;
}
export class MinimalTui {
@@ -99,6 +102,10 @@ export class MinimalTui {
async start(): Promise<void> {
this.running = true;
if (this.config.agent && this.config.modelRouter) {
this.config.agent.setModelTier(this.config.modelRouter.getTier());
}
this.rl = readline.createInterface({
input: process.stdin,
output: process.stdout,
@@ -108,6 +115,26 @@ export class MinimalTui {
},
});
// In minimal TUI we can prompt inline for tool confirmations.
// This avoids deadlocks when hooks are configured to require confirmation
// (e.g. shell.exec) and the tool loop is awaiting a decision.
if (this.config.hookEngine) {
this.config.hookEngine.setInteractiveConfirmer(async (pending) => {
const tool = pending.tool;
const args = pending.args;
const argsStr = Object.keys(args).length > 0 ? ` ${JSON.stringify(args)}` : '';
console.log(`\n${colors.bold}Confirmation required${colors.reset}`);
console.log(`${colors.gray}${tool}${colors.reset}${argsStr}`);
const answer = (await this.prompt(`${colors.orange}${colors.bold}Approve?${colors.reset} ${colors.gray}(y/N)${colors.reset} `))
.trim()
.toLowerCase();
const approved = answer === 'y' || answer === 'yes';
console.log(approved ? `${colors.gray}Approved.${colors.reset}\n` : `${colors.gray}Denied.${colors.reset}\n`);
return approved ? { approved: true } : { approved: false, reason: 'Denied by user' };
});
}
// Listen for line changes to show hints
process.stdin.on('keypress', () => {
// Small delay to let readline update the line
@@ -239,9 +266,22 @@ export class MinimalTui {
}
try {
const client = createClientFromConfig({ provider: provider as ModelProvider, model });
const providerType = provider as ModelProvider;
const template = this.config.modelProviderConfigs?.[providerType];
const client = createClientFromConfig({
...(template ?? {}),
provider: providerType,
model,
});
router.setClient(tier, client, providerModel);
console.log(`${colors.gray}Set ${tier} to:${colors.reset} ${providerModel}\n`);
router.setTierStrict(tier, true);
if (this.config.agent && tier === router.getTier()) {
this.config.agent.setModelTier(tier);
}
console.log(`${colors.gray}Set ${tier} to:${colors.reset} ${providerModel}`);
console.log(`${colors.gray}Fallbacks for ${tier} disabled (strict tier mode).${colors.reset}\n`);
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
console.log(`${colors.gray}Failed to create client:${colors.reset} ${message}\n`);
@@ -383,27 +423,49 @@ export class MinimalTui {
private async handleLoginCommand(provider?: string): Promise<void> {
const target = provider ?? 'github';
if (target !== 'github') {
console.log(`${colors.gray}Unknown login provider:${colors.reset} ${target}. Only 'github' is supported.\n`);
if (target === 'github') {
console.log(`${colors.gray}Starting GitHub OAuth device flow...${colors.reset}`);
try {
await loginGitHub((userCode, verificationUri) => {
console.log('');
console.log(`${colors.gray}Please visit:${colors.reset} ${verificationUri}`);
console.log(`${colors.gray}and enter code:${colors.reset} ${userCode}`);
console.log('');
console.log(`${colors.gray}Waiting for authorization...${colors.reset}`);
});
console.log(`${colors.gray}GitHub authentication successful! Token stored.${colors.reset}\n`);
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
console.log(`${colors.gray}GitHub login failed:${colors.reset} ${message}\n`);
}
return;
}
console.log(`${colors.gray}Starting GitHub OAuth device flow...${colors.reset}`);
if (target === 'openai') {
console.log(`${colors.gray}Starting OpenAI OAuth device flow...${colors.reset}`);
try {
await loginGitHub((userCode, verificationUri) => {
console.log('');
console.log(`${colors.gray}Please visit:${colors.reset} ${verificationUri}`);
console.log(`${colors.gray}and enter code:${colors.reset} ${userCode}`);
console.log('');
console.log(`${colors.gray}Waiting for authorization...${colors.reset}`);
});
try {
await loginOpenAI((userCode, verificationUri) => {
console.log('');
console.log(`${colors.gray}Please visit:${colors.reset} ${verificationUri}`);
console.log(`${colors.gray}and enter code:${colors.reset} ${userCode}`);
console.log('');
console.log(`${colors.gray}Waiting for authorization...${colors.reset}`);
});
console.log(`${colors.gray}GitHub authentication successful! Token stored.${colors.reset}\n`);
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
console.log(`${colors.gray}GitHub login failed:${colors.reset} ${message}\n`);
console.log(`${colors.gray}OpenAI authentication successful! Token stored.${colors.reset}\n`);
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
console.log(`${colors.gray}OpenAI login failed:${colors.reset} ${message}\n`);
}
return;
}
console.log(`${colors.gray}Unknown login provider:${colors.reset} ${target}. Supported: github, openai\n`);
}
private handlePairCommand(action?: 'generate' | 'list' | 'revoke', args?: string): void {
+11 -3
View File
@@ -88,11 +88,19 @@ export function createAgentHandlers(deps: AgentHandlerDeps) {
return lines.join('\n');
},
getModel: () => `Current model tier: ${agent.getModelTier()}`,
setModel: (tier) => {
setModel: (input) => {
const raw = input.trim();
if (!raw) {
return 'Usage: /model <tier>';
}
const [requestedTier, ...rest] = raw.split(/\s+/);
const validTiers: ModelTier[] = ['fast', 'default', 'complex', 'local'];
const modelTier = tier as ModelTier;
const modelTier = requestedTier as ModelTier;
if (!validTiers.includes(modelTier)) {
return `Invalid tier: ${tier}. Available: ${validTiers.join(', ')}`;
return `Invalid tier: ${requestedTier}. Available: ${validTiers.join(', ')}`;
}
if (rest.length > 0) {
return `Switched to model tier: ${modelTier}\nNote: provider/model switching is not available via gateway (/model <tier> <provider/model>).`;
}
agent.setModelTier(modelTier);
if (sessionId && deps.sessionManager) {
+15
View File
@@ -72,4 +72,19 @@ describe('HookEngine', () => {
expect(result.approved).toBe(false);
expect(result.reason).toBe('Too dangerous');
});
it('uses interactive confirmer when set (no pending queue)', async () => {
const engine = new HookEngine({ confirm: ['shell.*'], log: [], silent: [] });
const confirmer = vi.fn(async () => ({ approved: true }));
engine.setInteractiveConfirmer(confirmer);
const result = await engine.requestConfirmation('shell.exec', { cmd: 'ls' });
expect(result.approved).toBe(true);
expect(engine.getPendingConfirmations()).toHaveLength(0);
expect(confirmer).toHaveBeenCalledOnce();
expect(confirmer).toHaveBeenCalledWith(expect.objectContaining({
tool: 'shell.exec',
args: { cmd: 'ls' },
}));
});
});
+20
View File
@@ -1,16 +1,32 @@
import { randomUUID } from 'crypto';
import type { HookAction, HookResult, PendingConfirmation, HookConfig } from './types.js';
export type InteractiveConfirmer = (pending: {
id: string;
tool: string;
args: Record<string, unknown>;
}) => Promise<HookResult>;
export class HookEngine {
private confirmPatterns: RegExp[];
private logPatterns: RegExp[];
private pendingConfirmations: Map<string, PendingConfirmation> = new Map();
private interactiveConfirmer?: InteractiveConfirmer;
constructor(config: HookConfig) {
this.confirmPatterns = config.confirm.map(p => this.patternToRegex(p));
this.logPatterns = config.log.map(p => this.patternToRegex(p));
}
/**
* Optional interactive confirmation handler.
* When set, confirmation requests are handled immediately (no pending queue).
* Useful for CLI/TUI environments where we can prompt the user inline.
*/
setInteractiveConfirmer(confirmer: InteractiveConfirmer | undefined): void {
this.interactiveConfirmer = confirmer;
}
private patternToRegex(pattern: string): RegExp {
const escaped = pattern
.replace(/[.+^${}()|[\]\\]/g, '\\$&')
@@ -31,6 +47,10 @@ export class HookEngine {
async requestConfirmation(tool: string, args: Record<string, unknown>): Promise<HookResult> {
const id = randomUUID();
if (this.interactiveConfirmer) {
return await this.interactiveConfirmer({ id, tool, args });
}
return new Promise((resolve) => {
const pending: PendingConfirmation = {
id,
+38
View File
@@ -140,6 +140,44 @@ describe('LlamaCppClient', () => {
}]);
});
it('sanitizes web_search tool schema for llama.cpp', async () => {
mockFetch.mockResolvedValue({
ok: true,
json: () => Promise.resolve({
choices: [{ message: { content: 'ok' } }],
usage: { prompt_tokens: 1, completion_tokens: 1 },
}),
});
const client = new LlamaCppClient({
endpoint: 'http://localhost:8080',
model: 'test-model',
});
await client.chat({
messages: [{ role: 'user', content: 'search' }],
tools: [{
name: 'web_search',
description: 'Search',
input_schema: {
type: 'object',
properties: {
query: { type: 'string' },
count: { type: 'number' },
},
required: ['query'],
},
}],
});
const requestBody = JSON.parse(mockFetch.mock.calls[0][1].body);
expect(requestBody.tools[0].function.parameters).toEqual({
type: 'object',
properties: { query: { type: 'string' } },
required: ['query'],
});
});
it('parses tool_calls from response', async () => {
mockFetch.mockResolvedValue({
ok: true,
+32 -2
View File
@@ -48,6 +48,36 @@ interface LlamaCppStreamChunk {
usage?: { prompt_tokens: number; completion_tokens: number };
}
function sanitizeToolParametersForLlamaCpp(toolName: string, parameters: unknown): unknown {
// llama.cpp is stricter than most tool-call APIs about JSON schema.
// In particular, some builds reject extra optional properties for common tools.
// Keep the full schema for most tools, but reduce known-problematic ones.
if (toolName !== 'web_search') {
return parameters;
}
if (!parameters || typeof parameters !== 'object') {
return parameters;
}
const schema = parameters as {
type?: unknown;
properties?: Record<string, unknown>;
required?: unknown;
};
const querySchema = schema.properties?.query;
if (!querySchema) {
return parameters;
}
return {
type: 'object',
properties: { query: querySchema },
required: ['query'],
};
}
/** Message format for OpenAI-compatible chat completions API. */
interface LlamaCppChatMessage {
role: 'system' | 'user' | 'assistant' | 'tool';
@@ -211,7 +241,7 @@ export class LlamaCppClient implements ModelClient {
function: {
name: t.name,
description: t.description,
parameters: t.input_schema,
parameters: sanitizeToolParametersForLlamaCpp(t.name, t.input_schema),
},
}));
}
@@ -292,7 +322,7 @@ export class LlamaCppClient implements ModelClient {
function: {
name: t.name,
description: t.description,
parameters: t.input_schema,
parameters: sanitizeToolParametersForLlamaCpp(t.name, t.input_schema),
},
}));
}
+68
View File
@@ -0,0 +1,68 @@
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { OpenAIClient } from './openai.js';
vi.mock('../auth/openai.js', () => ({
ensureValidOpenAIAuth: vi.fn(async () => ({
access_token: 'at',
refresh_token: 'rt',
expires_at: Date.now() + 60_000,
created_at: new Date().toISOString(),
account_id: 'acct',
})),
}));
function makeSse(events: Array<{ event: string; data: any }>): string {
return events
.map((e) => `event: ${e.event}\ndata: ${JSON.stringify(e.data)}\n\n`)
.join('');
}
describe('OpenAIClient OAuth (Codex)', () => {
const originalFetch = globalThis.fetch;
beforeEach(() => {
vi.useFakeTimers();
});
afterEach(() => {
globalThis.fetch = originalFetch;
vi.useRealTimers();
vi.restoreAllMocks();
});
it('streams SSE and accumulates output_text.delta', async () => {
const sse = makeSse([
{ event: 'response.created', data: { type: 'response.created', response: { id: 'r1' } } },
{ event: 'response.output_text.delta', data: { type: 'response.output_text.delta', delta: 'hel' } },
{ event: 'response.output_text.delta', data: { type: 'response.output_text.delta', delta: 'lo' } },
{ event: 'response.completed', data: { type: 'response.completed', response: { usage: { input_tokens: 2, output_tokens: 2 } } } },
]);
globalThis.fetch = vi.fn(async (_url: any, init?: any) => {
const parsed = JSON.parse(init.body);
expect(parsed.store).toBe(false);
expect(parsed.stream).toBe(true);
expect(typeof parsed.instructions).toBe('string');
expect(Array.isArray(parsed.input)).toBe(true);
const stream = new ReadableStream({
start(controller) {
controller.enqueue(new TextEncoder().encode(sse));
controller.close();
},
});
return new Response(stream, { status: 200 });
}) as any;
const client = new OpenAIClient({ model: 'gpt-5.3-codex', useOAuth: true });
const resp = await client.chat({
system: 'You are helpful.',
messages: [{ role: 'user', content: 'hi' }],
});
expect(resp.content).toBe('hello');
expect(resp.usage).toEqual({ inputTokens: 2, outputTokens: 2 });
});
});
+24 -9
View File
@@ -1,23 +1,38 @@
import { describe, it, expect, vi } from 'vitest';
import { OpenAIClient } from './openai.js';
// Shared mock function so we can override per-test
const mockCreate = vi.fn().mockResolvedValue({
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
usage: { prompt_tokens: 10, completion_tokens: 5 },
});
vi.mock('openai', () => ({
default: vi.fn().mockImplementation(() => ({
const { mockCreate, mockOpenAIConstructor } = vi.hoisted(() => {
const mockCreate = vi.fn().mockResolvedValue({
choices: [{ message: { content: 'Hello from GPT!' }, finish_reason: 'stop' }],
usage: { prompt_tokens: 10, completion_tokens: 5 },
});
const mockOpenAIConstructor = vi.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate,
},
},
})),
}));
return { mockCreate, mockOpenAIConstructor };
});
vi.mock('openai', () => ({
default: mockOpenAIConstructor,
}));
describe('OpenAIClient', () => {
it('sets request timeout and disables SDK retries', () => {
new OpenAIClient({
apiKey: 'test-key',
model: 'gpt-4o',
});
expect(mockOpenAIConstructor).toHaveBeenCalledWith(expect.objectContaining({
timeout: 20_000,
maxRetries: 0,
}));
});
it('sends messages and returns response', async () => {
const client = new OpenAIClient({
apiKey: 'test-key',
+151 -6
View File
@@ -1,11 +1,16 @@
import OpenAI from 'openai';
import type { ChatRequest, ChatResponse, ModelClient, MessageContentPart } from './types.js';
import type { ChatRequest, ChatResponse, ModelClient, MessageContentPart, TokenUsage } from './types.js';
import { getMessageTextWithTools } from './media.js';
import { ensureValidOpenAIAuth } from '../auth/openai.js';
export interface OpenAIClientConfig {
apiKey?: string;
model: string;
maxTokens?: number;
baseURL?: string;
timeoutMs?: number;
/** If true, use ChatGPT subscription OAuth via the Codex backend endpoint. */
useOAuth?: boolean;
}
/**
@@ -52,20 +57,160 @@ function toOpenAIContent(content: string | MessageContentPart[]): string | OpenA
}
export class OpenAIClient implements ModelClient {
private client: OpenAI;
private client?: OpenAI;
private model: string;
private defaultMaxTokens: number;
private useOAuth: boolean;
constructor(config: OpenAIClientConfig) {
this.client = new OpenAI({
apiKey: config.apiKey,
baseURL: config.baseURL,
});
const timeoutMs = config.timeoutMs ?? 20_000;
this.useOAuth = Boolean(config.useOAuth);
// OAuth mode uses a different backend (ChatGPT Codex) and a different API shape.
// Only initialize the OpenAI SDK for API-key providers.
if (!this.useOAuth) {
this.client = new OpenAI({
apiKey: config.apiKey,
baseURL: config.baseURL,
timeout: timeoutMs,
maxRetries: 0,
});
}
this.model = config.model;
this.defaultMaxTokens = config.maxTokens ?? 4096;
}
private async chatViaOAuthCodex(request: ChatRequest): Promise<ChatResponse> {
const CODEX_API_ENDPOINT = 'https://chatgpt.com/backend-api/codex/responses';
const auth = await ensureValidOpenAIAuth();
// Codex endpoint requires:
// - instructions (non-empty)
// - input must be a list
// - store must be false
// - stream must be true (SSE)
const instructions = (request.system ?? '').trim() || 'You are helpful.';
const input = request.messages
.map((m) => {
const text = getMessageTextWithTools(m);
if (!text) {return null;}
return {
role: m.role,
content: [{ type: 'input_text', text }],
};
})
.filter((x): x is NonNullable<typeof x> => Boolean(x));
const body = {
model: this.model,
instructions,
store: false,
stream: true,
input,
// Intentionally omit max_output_tokens: Codex endpoint rejects it.
// Also omit tools/tool_choice for now.
};
const headers: Record<string, string> = {
'content-type': 'application/json',
'authorization': `Bearer ${auth.access_token}`,
'originator': 'flynn',
'user-agent': 'flynn/0.1',
'session_id': `flynn-${Date.now()}`,
};
if (auth.account_id) {
headers['ChatGPT-Account-Id'] = auth.account_id;
}
const res = await fetch(CODEX_API_ENDPOINT, {
method: 'POST',
headers,
body: JSON.stringify(body),
});
if (!res.ok) {
const text = await res.text();
throw new Error(`${res.status} ${res.statusText}${text ? `: ${text}` : ''}`);
}
if (!res.body) {
throw new Error('OpenAI OAuth request failed: missing response body');
}
let buffer = '';
let outputText = '';
let usage: TokenUsage | undefined;
const reader = res.body.getReader();
const processBlock = (block: string): void => {
const lines = block.split('\n');
let data = '';
for (const line of lines) {
if (line.startsWith('data:')) {
data += line.slice('data:'.length).trim();
}
}
if (!data) {return;}
let obj: any;
try {
obj = JSON.parse(data);
} catch {
return;
}
if (obj.type === 'response.output_text.delta' && typeof obj.delta === 'string') {
outputText += obj.delta;
}
if (obj.type === 'response.completed') {
const u = obj.response?.usage;
if (u) {
usage = {
inputTokens: u.input_tokens ?? 0,
outputTokens: u.output_tokens ?? 0,
};
}
}
if (obj.type === 'response.failed') {
const detail = obj.response?.error?.message ?? 'OpenAI OAuth response failed';
throw new Error(detail);
}
};
while (true) {
const { value, done } = await reader.read();
if (done) {break;}
buffer += Buffer.from(value).toString('utf8');
while (true) {
const idx = buffer.indexOf('\n\n');
if (idx === -1) {break;}
const block = buffer.slice(0, idx);
buffer = buffer.slice(idx + 2);
processBlock(block);
}
}
return {
content: outputText,
stopReason: 'end_turn',
usage: usage ?? { inputTokens: 0, outputTokens: 0 },
};
}
async chat(request: ChatRequest): Promise<ChatResponse> {
if (this.useOAuth) {
return this.chatViaOAuthCodex(request);
}
if (!this.client) {
throw new Error('OpenAI client not initialized');
}
const messages: OpenAI.ChatCompletionMessageParam[] = [];
if (request.system) {
+8 -3
View File
@@ -4,10 +4,15 @@ import type { RetryConfig } from './retry.js';
describe('isRetryable', () => {
it('returns true for generic errors', () => {
const error = new Error('Connection timeout');
const error = new Error('Connection reset by peer');
expect(isRetryable(error, DEFAULT_RETRY_CONFIG.nonRetryablePatterns)).toBe(true);
});
it('returns false for timeout errors', () => {
const error = new Error('Request timed out after 20000ms');
expect(isRetryable(error, DEFAULT_RETRY_CONFIG.nonRetryablePatterns)).toBe(false);
});
it('returns false for authentication errors', () => {
const error = new Error('Invalid API key: authentication failed');
expect(isRetryable(error, DEFAULT_RETRY_CONFIG.nonRetryablePatterns)).toBe(false);
@@ -75,8 +80,8 @@ describe('withRetry', () => {
it('retries on transient failure then succeeds', async () => {
const fn = vi.fn()
.mockRejectedValueOnce(new Error('timeout'))
.mockRejectedValueOnce(new Error('timeout'))
.mockRejectedValueOnce(new Error('temporary network issue'))
.mockRejectedValueOnce(new Error('temporary network issue'))
.mockResolvedValueOnce('recovered');
const result = await withRetry(fn, fastConfig, 'test-op');
+3
View File
@@ -26,6 +26,9 @@ export const DEFAULT_RETRY_CONFIG: RetryConfig = {
'context_length_exceeded',
'content_policy',
'does not support',
'timeout',
'timed out',
'request aborted',
],
};
+25
View File
@@ -438,4 +438,29 @@ describe('setClient and labels', () => {
expect(newFastClient!.chat).toHaveBeenCalledTimes(1);
expect(initialFastClient!.chat).toHaveBeenCalledTimes(1);
});
it('strict tier mode disables fallback chain for that tier', async () => {
const failingDefault = {
chat: vi.fn().mockRejectedValue(new Error('primary failed')),
} as unknown as ModelClient;
const fallback = {
chat: vi.fn().mockResolvedValue({
content: 'fallback',
stopReason: 'end_turn',
usage: { inputTokens: 1, outputTokens: 1 },
}),
} as unknown as ModelClient;
const router = new ModelRouter({
default: failingDefault,
fallbackChain: [fallback],
});
router.setTierStrict('default', true);
await expect(router.chat({ messages: [{ role: 'user', content: 'Hi' }] }, 'default'))
.rejects.toThrow('primary failed');
expect(fallback.chat).not.toHaveBeenCalled();
expect(router.isTierStrict('default')).toBe(true);
});
});
+22
View File
@@ -27,6 +27,7 @@ export class ModelRouter implements ModelClient {
private localProviderName?: string;
private retryConfig?: RetryConfig;
private tierChangeListeners: Array<(tier: ModelTier) => void> = [];
private strictTiers: Set<ModelTier> = new Set();
constructor(config: ModelRouterConfig) {
this.clients = new Map();
@@ -97,6 +98,10 @@ export class ModelRouter implements ModelClient {
logger.debug(`Primary model failed: ${errors[0].message}`);
}
if (this.strictTiers.has(useTier)) {
throw errors[0];
}
// Try tier-specific fallbacks first
const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
for (let i = 0; i < tierFallbackList.length; i++) {
@@ -150,6 +155,11 @@ export class ModelRouter implements ModelClient {
primaryError = 'Primary client does not support streaming';
}
if (this.strictTiers.has(useTier)) {
yield { type: 'error', error: new Error(primaryError ?? 'Primary model failed') };
return;
}
// Try tier-specific fallbacks first
const tierFallbackList = this.tierFallbacks.get(useTier) ?? [];
for (let i = 0; i < tierFallbackList.length; i++) {
@@ -216,6 +226,18 @@ export class ModelRouter implements ModelClient {
this.labels.set(tier, label);
}
setTierStrict(tier: ModelTier, strict: boolean): void {
if (strict) {
this.strictTiers.add(tier);
return;
}
this.strictTiers.delete(tier);
}
isTierStrict(tier: ModelTier): boolean {
return this.strictTiers.has(tier);
}
getLabel(tier: ModelTier): string {
return this.labels.get(tier) ?? 'unknown';
}
+3
View File
@@ -44,6 +44,9 @@ const testConfig: NonNullable<GmailConfig> = {
enabled: true,
credentials_file: '/tmp/test-creds.json',
token_file: '/tmp/test-token.json',
disable_push: false,
pubsub_pull_interval: '60s',
pubsub_max_messages: 10,
watch_labels: ['INBOX'],
poll_interval: '300s',
output: { channel: 'discord', peer: '123' },