feat: add GitHub Copilot model provider with OAuth device flow
Add a new 'github' model provider backed by the Copilot API (api.githubcopilot.com), with OAuth device flow for authentication. - New src/auth/github.ts: device flow login, token storage at ~/.config/flynn/auth.json with 0600 permissions - New src/models/github.ts: OpenAI-compatible client with streaming, tool calling, and Copilot-specific headers - Add 'github' to provider enum in config schema - Register provider in daemon factory and TUI client factory - Refactor TUI to use provider-agnostic client factory (was hardcoded to AnthropicClient for all tiers) - Add /login command to TUI for interactive OAuth authorization - Add Copilot model cost tracking entries
This commit is contained in:
@@ -0,0 +1,164 @@
|
||||
import { readFileSync, writeFileSync, mkdirSync, chmodSync } from 'fs';
|
||||
import { resolve } from 'path';
|
||||
import { homedir } from 'os';
|
||||
|
||||
const COPILOT_CLIENT_ID = 'Ov23li8tweQw6odWQebz';
|
||||
const DEVICE_CODE_URL = 'https://github.com/login/device/code';
|
||||
const TOKEN_URL = 'https://github.com/login/oauth/access_token';
|
||||
const POLLING_SAFETY_MARGIN_MS = 3000;
|
||||
|
||||
const AUTH_DIR = resolve(homedir(), '.config/flynn');
|
||||
const AUTH_FILE = resolve(AUTH_DIR, 'auth.json');
|
||||
|
||||
export interface DeviceCodeResponse {
|
||||
device_code: string;
|
||||
user_code: string;
|
||||
verification_uri: string;
|
||||
expires_in: number;
|
||||
interval: number;
|
||||
}
|
||||
|
||||
interface AuthStore {
|
||||
github?: {
|
||||
access_token: string;
|
||||
created_at: string;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Request a device code from GitHub to start the OAuth device flow.
|
||||
*/
|
||||
export async function requestDeviceCode(): Promise<DeviceCodeResponse> {
|
||||
const response = await fetch(DEVICE_CODE_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ client_id: COPILOT_CLIENT_ID }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to request device code: ${response.status} ${response.statusText}`);
|
||||
}
|
||||
|
||||
return response.json() as Promise<DeviceCodeResponse>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Poll GitHub for an access token after the user has entered the device code.
|
||||
* Blocks until the user authorizes or the code expires.
|
||||
*/
|
||||
export async function pollForToken(deviceCode: string, interval: number): Promise<string> {
|
||||
let currentInterval = interval;
|
||||
|
||||
while (true) {
|
||||
await new Promise(r => setTimeout(r, currentInterval * 1000 + POLLING_SAFETY_MARGIN_MS));
|
||||
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
client_id: COPILOT_CLIENT_ID,
|
||||
device_code: deviceCode,
|
||||
grant_type: 'urn:ietf:params:oauth:grant-type:device_code',
|
||||
}),
|
||||
});
|
||||
|
||||
const data = await response.json() as Record<string, unknown>;
|
||||
|
||||
if (data.access_token) {
|
||||
return data.access_token as string;
|
||||
}
|
||||
|
||||
if (data.error === 'authorization_pending') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (data.error === 'slow_down') {
|
||||
// Add 5 seconds as per GitHub spec
|
||||
currentInterval = (data.interval as number) ?? currentInterval + 5;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (data.error === 'expired_token') {
|
||||
throw new Error('Device code expired. Please try again.');
|
||||
}
|
||||
|
||||
if (data.error === 'access_denied') {
|
||||
throw new Error('Authorization was denied by the user.');
|
||||
}
|
||||
|
||||
throw new Error(`OAuth error: ${data.error ?? 'unknown'} - ${data.error_description ?? ''}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a previously stored GitHub OAuth token from disk.
|
||||
* Returns null if no token is stored or the file doesn't exist.
|
||||
*/
|
||||
export function loadStoredToken(): string | null {
|
||||
try {
|
||||
const raw = readFileSync(AUTH_FILE, 'utf-8');
|
||||
const store = JSON.parse(raw) as AuthStore;
|
||||
return store.github?.access_token ?? null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Store a GitHub OAuth token to disk with secure permissions.
|
||||
*/
|
||||
export function storeToken(token: string): void {
|
||||
mkdirSync(AUTH_DIR, { recursive: true });
|
||||
|
||||
let store: AuthStore = {};
|
||||
try {
|
||||
const raw = readFileSync(AUTH_FILE, 'utf-8');
|
||||
store = JSON.parse(raw) as AuthStore;
|
||||
} catch {
|
||||
// File doesn't exist yet — start fresh
|
||||
}
|
||||
|
||||
store.github = {
|
||||
access_token: token,
|
||||
created_at: new Date().toISOString(),
|
||||
};
|
||||
|
||||
writeFileSync(AUTH_FILE, JSON.stringify(store, null, 2) + '\n', 'utf-8');
|
||||
chmodSync(AUTH_FILE, 0o600);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a GitHub token from any available source.
|
||||
* Priority: GITHUB_TOKEN env var → stored OAuth token → null
|
||||
*/
|
||||
export function getGitHubToken(): string | null {
|
||||
// 1. Environment variable
|
||||
const envToken = process.env.GITHUB_TOKEN;
|
||||
if (envToken) return envToken;
|
||||
|
||||
// 2. Stored OAuth token
|
||||
return loadStoredToken();
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the full GitHub OAuth device flow interactively.
|
||||
* @param onPrompt Callback to display the user code and verification URL to the user.
|
||||
* @returns The access token.
|
||||
*/
|
||||
export async function loginGitHub(
|
||||
onPrompt: (userCode: string, verificationUri: string) => void,
|
||||
): Promise<string> {
|
||||
const deviceCode = await requestDeviceCode();
|
||||
|
||||
onPrompt(deviceCode.user_code, deviceCode.verification_uri);
|
||||
|
||||
const token = await pollForToken(deviceCode.device_code, deviceCode.interval);
|
||||
storeToken(token);
|
||||
return token;
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
export {
|
||||
requestDeviceCode,
|
||||
pollForToken,
|
||||
loadStoredToken,
|
||||
storeToken,
|
||||
getGitHubToken,
|
||||
loginGitHub,
|
||||
type DeviceCodeResponse,
|
||||
} from './github.js';
|
||||
+27
-41
@@ -43,7 +43,7 @@ export function registerTuiCommand(program: Command): void {
|
||||
|
||||
// Dynamic imports to keep CLI startup fast
|
||||
const { SessionStore, SessionManager } = await import('../session/index.js');
|
||||
const { AnthropicClient, OpenAIClient, OllamaClient, LlamaCppClient, ModelRouter } = await import('../models/index.js');
|
||||
const { AnthropicClient, OpenAIClient, OllamaClient, LlamaCppClient, GitHubModelsClient, GeminiClient, BedrockClient, ModelRouter } = await import('../models/index.js');
|
||||
const { MinimalTui, startFullscreenTui } = await import('../frontends/tui/index.js');
|
||||
const { NativeAgent } = await import('../backends/index.js');
|
||||
const { ToolRegistry, ToolExecutor, allBuiltinTools } = await import('../tools/index.js');
|
||||
@@ -56,49 +56,35 @@ export function registerTuiCommand(program: Command): void {
|
||||
const sessionManager = new SessionManager(sessionStore);
|
||||
const models = config.models;
|
||||
|
||||
// Build model router
|
||||
const defaultClient = new AnthropicClient({
|
||||
model: models.default.model,
|
||||
apiKey: models.default.api_key,
|
||||
authToken: models.default.auth_token,
|
||||
});
|
||||
|
||||
let fastClient;
|
||||
let complexClient;
|
||||
let localClient;
|
||||
|
||||
if (models.fast) {
|
||||
fastClient = new AnthropicClient({
|
||||
model: models.fast.model,
|
||||
apiKey: models.fast.api_key,
|
||||
authToken: models.fast.auth_token,
|
||||
});
|
||||
}
|
||||
|
||||
if (models.complex) {
|
||||
complexClient = new AnthropicClient({
|
||||
model: models.complex.model,
|
||||
apiKey: models.complex.api_key,
|
||||
authToken: models.complex.auth_token,
|
||||
});
|
||||
}
|
||||
|
||||
if (models.local) {
|
||||
if (models.local.provider === 'ollama') {
|
||||
localClient = new OllamaClient({
|
||||
model: models.local.model,
|
||||
host: models.local.endpoint,
|
||||
numGpu: models.local.num_gpu,
|
||||
});
|
||||
} else if (models.local.provider === 'llamacpp') {
|
||||
localClient = new LlamaCppClient({
|
||||
endpoint: models.local.endpoint ?? 'http://localhost:8080',
|
||||
model: models.local.model,
|
||||
authToken: models.local.auth_token,
|
||||
});
|
||||
// Provider-agnostic client factory for TUI
|
||||
function createClient(cfg: typeof models.default) {
|
||||
switch (cfg.provider) {
|
||||
case 'anthropic':
|
||||
return new AnthropicClient({ model: cfg.model, apiKey: cfg.api_key, authToken: cfg.auth_token });
|
||||
case 'openai':
|
||||
return new OpenAIClient({ model: cfg.model, apiKey: cfg.api_key });
|
||||
case 'gemini':
|
||||
return new GeminiClient({ model: cfg.model, apiKey: cfg.api_key });
|
||||
case 'ollama':
|
||||
return new OllamaClient({ model: cfg.model, host: cfg.endpoint, numGpu: cfg.num_gpu });
|
||||
case 'llamacpp':
|
||||
return new LlamaCppClient({ endpoint: cfg.endpoint ?? 'http://localhost:8080', model: cfg.model, authToken: cfg.auth_token });
|
||||
case 'openrouter':
|
||||
return new OpenAIClient({ model: cfg.model, apiKey: cfg.api_key ?? process.env.OPENROUTER_API_KEY, baseURL: cfg.endpoint ?? 'https://openrouter.ai/api/v1' });
|
||||
case 'bedrock':
|
||||
return new BedrockClient({ model: cfg.model, region: cfg.endpoint, accessKeyId: cfg.api_key, secretAccessKey: cfg.auth_token });
|
||||
case 'github':
|
||||
return new GitHubModelsClient({ model: cfg.model, apiKey: cfg.api_key, endpoint: cfg.endpoint });
|
||||
default:
|
||||
throw new Error(`Unknown provider: ${cfg.provider}`);
|
||||
}
|
||||
}
|
||||
|
||||
const defaultClient = createClient(models.default);
|
||||
const fastClient = models.fast ? createClient(models.fast) : undefined;
|
||||
const complexClient = models.complex ? createClient(models.complex) : undefined;
|
||||
const localClient = models.local ? createClient(models.local) : undefined;
|
||||
|
||||
const fallbackChain = [];
|
||||
for (const providerName of models.fallback_chain) {
|
||||
if (providerName === 'openai') {
|
||||
|
||||
@@ -19,7 +19,7 @@ const serverSchema = z.object({
|
||||
});
|
||||
|
||||
const modelConfigSchema = z.object({
|
||||
provider: z.enum(['anthropic', 'openai', 'gemini', 'ollama', 'llamacpp', 'openrouter', 'bedrock']),
|
||||
provider: z.enum(['anthropic', 'openai', 'gemini', 'ollama', 'llamacpp', 'openrouter', 'bedrock', 'github']),
|
||||
model: z.string(),
|
||||
endpoint: z.string().optional(),
|
||||
api_key: z.string().optional(),
|
||||
|
||||
+7
-1
@@ -1,6 +1,6 @@
|
||||
import { Lifecycle } from './lifecycle.js';
|
||||
import type { Config, ModelConfig } from '../config/index.js';
|
||||
import { AnthropicClient, OpenAIClient, OllamaClient, LlamaCppClient, GeminiClient, BedrockClient, ModelRouter, DEFAULT_RETRY_CONFIG } from '../models/index.js';
|
||||
import { AnthropicClient, OpenAIClient, OllamaClient, LlamaCppClient, GeminiClient, BedrockClient, GitHubModelsClient, ModelRouter, DEFAULT_RETRY_CONFIG } from '../models/index.js';
|
||||
import type { ModelClient, RetryConfig } from '../models/index.js';
|
||||
import { AgentOrchestrator, type DelegationConfig } from '../backends/index.js';
|
||||
import { SessionStore, SessionManager } from '../session/index.js';
|
||||
@@ -109,6 +109,12 @@ export function createClientFromConfig(cfg: ModelConfig): ModelClient {
|
||||
accessKeyId: cfg.api_key,
|
||||
secretAccessKey: cfg.auth_token,
|
||||
});
|
||||
case 'github':
|
||||
return new GitHubModelsClient({
|
||||
model: cfg.model,
|
||||
apiKey: cfg.api_key,
|
||||
endpoint: cfg.endpoint,
|
||||
});
|
||||
default:
|
||||
throw new Error(`Unknown model provider: ${(cfg as Record<string, unknown>).provider}`);
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ export type Command =
|
||||
| { type: 'usage' }
|
||||
| { type: 'model'; name?: string }
|
||||
| { type: 'backend'; provider?: string }
|
||||
| { type: 'login'; provider?: string }
|
||||
| { type: 'transfer'; target: string }
|
||||
| { type: 'message'; content: string };
|
||||
|
||||
@@ -74,6 +75,15 @@ export function parseCommand(input: string): Command | null {
|
||||
return { type: 'transfer', target };
|
||||
}
|
||||
|
||||
// Login
|
||||
if (trimmed === '/login') {
|
||||
return { type: 'login' };
|
||||
}
|
||||
if (trimmed.startsWith('/login ')) {
|
||||
const provider = trimmed.slice('/login '.length).trim();
|
||||
return { type: 'login', provider: provider || undefined };
|
||||
}
|
||||
|
||||
// Regular message
|
||||
return { type: 'message', content: trimmed };
|
||||
}
|
||||
@@ -84,6 +94,7 @@ Commands:
|
||||
/help, /? Show this help
|
||||
/model [name] Show or switch model (local, default, fast, complex)
|
||||
/backend [provider] Show or switch local backend (ollama, llamacpp)
|
||||
/login [provider] Authenticate with GitHub
|
||||
/reset, /clear, /new Clear conversation history
|
||||
/compact Compact conversation history
|
||||
/usage Show token usage and estimated cost
|
||||
@@ -109,6 +120,7 @@ export const SLASH_COMMANDS = [
|
||||
'/status',
|
||||
'/fullscreen',
|
||||
'/fs',
|
||||
'/login',
|
||||
'/transfer',
|
||||
'/quit',
|
||||
'/exit',
|
||||
@@ -127,6 +139,7 @@ export const COMMAND_TOOLTIPS: Record<string, string> = {
|
||||
'/status': 'Show session info and token usage',
|
||||
'/fullscreen': 'Switch to fullscreen mode',
|
||||
'/fs': 'Switch to fullscreen mode',
|
||||
'/login': 'Authenticate with GitHub (OAuth device flow)',
|
||||
'/transfer': 'Transfer session to another frontend',
|
||||
'/quit': 'Exit TUI',
|
||||
'/exit': 'Exit TUI',
|
||||
|
||||
@@ -7,6 +7,7 @@ import { parseCommand, getHelpText, resolveModelAlias, getCommandCompletions, ge
|
||||
import { renderMarkdown } from './markdown.js';
|
||||
import type { ModelConfig } from '../../config/schema.js';
|
||||
import { OllamaClient, LlamaCppClient } from '../../models/index.js';
|
||||
import { loginGitHub } from '../../auth/index.js';
|
||||
|
||||
export { parseCommand, type Command };
|
||||
|
||||
@@ -186,6 +187,10 @@ export class MinimalTui {
|
||||
this.handleBackendCommand(command.provider);
|
||||
break;
|
||||
|
||||
case 'login':
|
||||
await this.handleLoginCommand(command.provider);
|
||||
break;
|
||||
|
||||
case 'transfer':
|
||||
this.config.onTransfer?.(command.target);
|
||||
break;
|
||||
@@ -256,6 +261,31 @@ export class MinimalTui {
|
||||
console.log(`Switched to backend: ${provider}\n`);
|
||||
}
|
||||
|
||||
private async handleLoginCommand(provider?: string): Promise<void> {
|
||||
const target = provider ?? 'github';
|
||||
if (target !== 'github') {
|
||||
console.log(`${colors.gray}Unknown login provider:${colors.reset} ${target}. Only 'github' is supported.\n`);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`${colors.gray}Starting GitHub OAuth device flow...${colors.reset}`);
|
||||
|
||||
try {
|
||||
await loginGitHub((userCode, verificationUri) => {
|
||||
console.log('');
|
||||
console.log(`${colors.gray}Please visit:${colors.reset} ${verificationUri}`);
|
||||
console.log(`${colors.gray}and enter code:${colors.reset} ${userCode}`);
|
||||
console.log('');
|
||||
console.log(`${colors.gray}Waiting for authorization...${colors.reset}`);
|
||||
});
|
||||
|
||||
console.log(`${colors.gray}GitHub authentication successful! Token stored.${colors.reset}\n`);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
console.log(`${colors.gray}GitHub login failed:${colors.reset} ${message}\n`);
|
||||
}
|
||||
}
|
||||
|
||||
private getAvailableBackends(): string[] {
|
||||
const backends: string[] = [];
|
||||
if (this.config.currentLocalProvider) {
|
||||
|
||||
@@ -14,6 +14,11 @@ export const MODEL_COSTS_PER_MILLION: Record<string, { input: number; output: nu
|
||||
'gemini-2.5-flash': { input: 0.15, output: 0.60 },
|
||||
'gemini-1.5-pro': { input: 1.25, output: 5 },
|
||||
'gemini-1.5-flash': { input: 0.075, output: 0.30 },
|
||||
// GitHub Copilot (included in subscription, tracked at $0)
|
||||
'gpt-4.1': { input: 0, output: 0 },
|
||||
'gpt-4.1-mini': { input: 0, output: 0 },
|
||||
'claude-sonnet-4': { input: 0, output: 0 },
|
||||
'claude-haiku-4': { input: 0, output: 0 },
|
||||
// Local / unknown models
|
||||
'default': { input: 0, output: 0 },
|
||||
// Bedrock (Meta Llama)
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
import OpenAI from 'openai';
|
||||
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, MessageContentPart } from './types.js';
|
||||
import { getGitHubToken } from '../auth/index.js';
|
||||
|
||||
export interface GitHubModelsClientConfig {
|
||||
apiKey?: string; // GitHub PAT or gh auth token. Falls back to GITHUB_TOKEN env var
|
||||
model: string; // e.g., 'gpt-4o' or 'claude-sonnet-4'
|
||||
maxTokens?: number;
|
||||
endpoint?: string; // Override base URL (default: https://models.github.ai/inference)
|
||||
}
|
||||
|
||||
const DEFAULT_ENDPOINT = 'https://api.githubcopilot.com';
|
||||
|
||||
/**
|
||||
* Convert Flynn message content to OpenAI format.
|
||||
* Reuses the same pattern as openai.ts since GitHub Models uses an OpenAI-compatible API.
|
||||
*/
|
||||
function toOpenAIContent(content: string | MessageContentPart[]): string | OpenAI.ChatCompletionContentPart[] {
|
||||
if (typeof content === 'string') {
|
||||
return content;
|
||||
}
|
||||
|
||||
return content.map((part): OpenAI.ChatCompletionContentPart => {
|
||||
if (part.type === 'text') {
|
||||
return { type: 'text', text: part.text };
|
||||
}
|
||||
if (part.type === 'image') {
|
||||
const url = part.source.type === 'base64'
|
||||
? `data:${part.source.media_type};base64,${part.source.data!}`
|
||||
: part.source.url!;
|
||||
return { type: 'image_url', image_url: { url } };
|
||||
}
|
||||
// Fallback — shouldn't happen
|
||||
return { type: 'text', text: JSON.stringify(part) };
|
||||
});
|
||||
}
|
||||
|
||||
export class GitHubModelsClient implements ModelClient {
|
||||
private client: OpenAI;
|
||||
private model: string;
|
||||
private defaultMaxTokens: number;
|
||||
|
||||
constructor(config: GitHubModelsClientConfig) {
|
||||
const apiKey = config.apiKey ?? getGitHubToken() ?? '';
|
||||
const baseURL = config.endpoint ?? DEFAULT_ENDPOINT;
|
||||
|
||||
this.client = new OpenAI({
|
||||
apiKey,
|
||||
baseURL,
|
||||
defaultHeaders: {
|
||||
'X-GitHub-Api-Version': '2022-11-28',
|
||||
'Openai-Intent': 'conversation-edits',
|
||||
},
|
||||
});
|
||||
this.model = config.model;
|
||||
this.defaultMaxTokens = config.maxTokens ?? 4096;
|
||||
}
|
||||
|
||||
async chat(request: ChatRequest): Promise<ChatResponse> {
|
||||
const messages: OpenAI.ChatCompletionMessageParam[] = [];
|
||||
|
||||
if (request.system) {
|
||||
messages.push({ role: 'system', content: request.system });
|
||||
}
|
||||
|
||||
for (const msg of request.messages) {
|
||||
messages.push({
|
||||
role: msg.role,
|
||||
content: toOpenAIContent(msg.content),
|
||||
} as OpenAI.ChatCompletionMessageParam);
|
||||
}
|
||||
|
||||
const params: OpenAI.ChatCompletionCreateParamsNonStreaming = {
|
||||
model: this.model,
|
||||
max_tokens: request.maxTokens ?? this.defaultMaxTokens,
|
||||
messages,
|
||||
};
|
||||
|
||||
if (request.tools && request.tools.length > 0) {
|
||||
params.tools = request.tools.map(t => ({
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.input_schema as OpenAI.FunctionParameters,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
const response = await this.client.chat.completions.create(params);
|
||||
|
||||
const choice = response.choices[0];
|
||||
const content = choice?.message?.content ?? '';
|
||||
|
||||
const toolCalls = choice?.message?.tool_calls?.map((tc: OpenAI.ChatCompletionMessageToolCall) => ({
|
||||
id: tc.id,
|
||||
name: tc.function.name,
|
||||
args: JSON.parse(tc.function.arguments),
|
||||
})) ?? [];
|
||||
|
||||
// Map OpenAI finish reasons to Flynn's stop reasons
|
||||
let stopReason: string;
|
||||
if (toolCalls.length > 0) {
|
||||
stopReason = 'tool_use';
|
||||
} else {
|
||||
const reason = choice?.finish_reason;
|
||||
if (reason === 'stop') {
|
||||
stopReason = 'end_turn';
|
||||
} else if (reason === 'length') {
|
||||
stopReason = 'max_tokens';
|
||||
} else {
|
||||
stopReason = reason ?? 'end_turn';
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content,
|
||||
stopReason,
|
||||
usage: {
|
||||
inputTokens: response.usage?.prompt_tokens ?? 0,
|
||||
outputTokens: response.usage?.completion_tokens ?? 0,
|
||||
},
|
||||
...(toolCalls.length > 0 ? { toolCalls } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
|
||||
const messages: OpenAI.ChatCompletionMessageParam[] = [];
|
||||
|
||||
if (request.system) {
|
||||
messages.push({ role: 'system', content: request.system });
|
||||
}
|
||||
|
||||
for (const msg of request.messages) {
|
||||
messages.push({
|
||||
role: msg.role,
|
||||
content: toOpenAIContent(msg.content),
|
||||
} as OpenAI.ChatCompletionMessageParam);
|
||||
}
|
||||
|
||||
const params: OpenAI.ChatCompletionCreateParamsStreaming = {
|
||||
model: this.model,
|
||||
max_tokens: request.maxTokens ?? this.defaultMaxTokens,
|
||||
messages,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
if (request.tools && request.tools.length > 0) {
|
||||
params.tools = request.tools.map(t => ({
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description,
|
||||
parameters: t.input_schema as OpenAI.FunctionParameters,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
try {
|
||||
const stream = await this.client.chat.completions.create(params);
|
||||
|
||||
let totalInputTokens = 0;
|
||||
let totalOutputTokens = 0;
|
||||
|
||||
// Accumulate tool call deltas across chunks
|
||||
const toolCallAccumulator = new Map<number, {
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: string;
|
||||
}>();
|
||||
|
||||
for await (const chunk of stream) {
|
||||
const delta = chunk.choices[0]?.delta;
|
||||
const finishReason = chunk.choices[0]?.finish_reason;
|
||||
|
||||
// Emit text content deltas
|
||||
if (delta?.content) {
|
||||
yield { type: 'content', content: delta.content };
|
||||
}
|
||||
|
||||
// Accumulate tool call deltas
|
||||
if (delta?.tool_calls) {
|
||||
for (const tc of delta.tool_calls) {
|
||||
const existing = toolCallAccumulator.get(tc.index);
|
||||
if (existing) {
|
||||
// Append argument fragments
|
||||
if (tc.function?.arguments) {
|
||||
existing.arguments += tc.function.arguments;
|
||||
}
|
||||
} else {
|
||||
toolCallAccumulator.set(tc.index, {
|
||||
id: tc.id ?? '',
|
||||
name: tc.function?.name ?? '',
|
||||
arguments: tc.function?.arguments ?? '',
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track usage from chunks (some providers include it in stream)
|
||||
if (chunk.usage) {
|
||||
totalInputTokens = chunk.usage.prompt_tokens ?? totalInputTokens;
|
||||
totalOutputTokens = chunk.usage.completion_tokens ?? totalOutputTokens;
|
||||
}
|
||||
|
||||
// On finish, emit accumulated tool calls
|
||||
if (finishReason === 'tool_calls' || finishReason === 'stop') {
|
||||
for (const [, tc] of toolCallAccumulator) {
|
||||
yield {
|
||||
type: 'tool_use',
|
||||
toolCall: {
|
||||
id: tc.id,
|
||||
name: tc.name,
|
||||
args: JSON.parse(tc.arguments || '{}'),
|
||||
},
|
||||
};
|
||||
}
|
||||
toolCallAccumulator.clear();
|
||||
}
|
||||
}
|
||||
|
||||
yield {
|
||||
type: 'done',
|
||||
usage: {
|
||||
inputTokens: totalInputTokens,
|
||||
outputTokens: totalOutputTokens,
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
yield {
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ export { AnthropicClient, type AnthropicClientConfig } from './anthropic.js';
|
||||
export { OpenAIClient, type OpenAIClientConfig } from './openai.js';
|
||||
export { GeminiClient, type GeminiClientConfig } from './gemini.js';
|
||||
export { BedrockClient, type BedrockClientConfig } from './bedrock.js';
|
||||
export { GitHubModelsClient, type GitHubModelsClientConfig } from './github.js';
|
||||
export { OllamaClient, type OllamaClientConfig } from './local/index.js';
|
||||
export { LlamaCppClient, type LlamaCppClientConfig } from './local/index.js';
|
||||
export { ModelRouter, type ModelRouterConfig, type ModelTier } from './router.js';
|
||||
|
||||
Reference in New Issue
Block a user