Files
flynn/src/models/github.ts
T

317 lines
9.8 KiB
TypeScript

import OpenAI from 'openai';
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, MessageContentPart } from './types.js';
import { getGitHubToken } from '../auth/index.js';
/** Construction options for the GitHub Models / Copilot chat client. */
export interface GitHubModelsClientConfig {
  /** Model identifier, e.g. 'gpt-4o' or 'claude-sonnet-4'. */
  model: string;

  /**
   * GitHub PAT or `gh auth` token. When omitted, the client falls back to the
   * token resolved by `getGitHubToken()` (e.g. the GITHUB_TOKEN env var).
   */
  apiKey?: string;

  /** Default `max_tokens` for requests that do not specify their own limit. */
  maxTokens?: number;

  /** Base URL override. When unset, `DEFAULT_ENDPOINT` (https://api.githubcopilot.com) is used. */
  endpoint?: string;

  /**
   * Invoked at API-call time when no token can be found. It should resolve to
   * a usable token, for example by running the OAuth device flow. Without this
   * callback, calls made with no token will fail with authentication errors.
   */
  onLoginRequired?: () => Promise<string>;
}

/** Base URL used when `GitHubModelsClientConfig.endpoint` is not provided. */
const DEFAULT_ENDPOINT = 'https://api.githubcopilot.com';
/**
 * Maps an audio MIME type to the `input_audio.format` value sent upstream.
 * MIME types not listed here are sent as 'wav' by toOpenAIContent.
 * NOTE(review): OpenAI's Chat Completions reference documents only 'wav' and
 * 'mp3' for `input_audio.format`; confirm the GitHub endpoint accepts the rest.
 */
const AUDIO_FORMAT_BY_MIME_TYPE: Readonly<Record<string, string>> = {
  'audio/wav': 'wav',
  'audio/mpeg': 'mp3',
  'audio/mp3': 'mp3',
  'audio/ogg': 'ogg',
  'audio/webm': 'webm',
  'audio/mp4': 'mp4',
  'audio/x-m4a': 'mp4',
};

/**
 * Convert Flynn message content to the OpenAI Chat Completions content format.
 * GitHub's endpoint is OpenAI-compatible, so this mirrors the OpenAI conversion.
 *
 * - A plain string is returned unchanged.
 * - Text parts become `text` parts.
 * - Image parts become `image_url` parts; base64 sources are inlined as `data:` URLs.
 *   An image with no base64 data / no URL is replaced by a text placeholder.
 * - Audio parts become `input_audio` parts. Audio with no data is replaced by a
 *   text placeholder, matching the image handling.
 * - Any other part type is serialized with JSON.stringify as a text part.
 *
 * @param content - A plain string or a list of multimodal content parts.
 * @returns The string unchanged, or the converted list of content parts.
 */
function toOpenAIContent(content: string | MessageContentPart[]): string | OpenAI.ChatCompletionContentPart[] {
  if (typeof content === 'string') {
    return content;
  }
  return content.map((part): OpenAI.ChatCompletionContentPart => {
    if (part.type === 'text') {
      return { type: 'text', text: part.text };
    }
    if (part.type === 'image') {
      if (part.source.type === 'base64' && !part.source.data) {
        return { type: 'text', text: '[Image omitted: missing base64 data]' };
      }
      if (part.source.type !== 'base64' && !part.source.url) {
        return { type: 'text', text: '[Image omitted: missing URL]' };
      }
      const url = part.source.type === 'base64'
        ? `data:${part.source.media_type};base64,${part.source.data}`
        : (part.source.url ?? '');
      return { type: 'image_url', image_url: { url } };
    }
    if (part.type === 'audio') {
      // An empty payload would only be rejected upstream; degrade to a text
      // placeholder instead, exactly as the image branch does.
      if (!part.source.data) {
        return { type: 'text', text: '[Audio omitted: missing base64 data]' };
      }
      const format = AUDIO_FORMAT_BY_MIME_TYPE[part.source.media_type] ?? 'wav';
      // Double cast kept from the original: presumably the SDK's input_audio
      // part type is narrower than the formats mapped above.
      return {
        type: 'input_audio',
        input_audio: { data: part.source.data, format },
      } as unknown as OpenAI.ChatCompletionContentPart;
    }
    // Fallback — shouldn't happen
    return { type: 'text', text: JSON.stringify(part) };
  });
}
/**
 * ModelClient for GitHub's OpenAI-compatible chat endpoint
 * (https://api.githubcopilot.com unless `endpoint` overrides it).
 *
 * Authentication is resolved lazily. If no token is known at construction
 * time, each API call first re-checks `getGitHubToken()`. If that also fails,
 * the call invokes `onLoginRequired` (when configured) before sending the
 * request.
 */
export class GitHubModelsClient implements ModelClient {
  private client: OpenAI;
  private readonly model: string;
  private readonly defaultMaxTokens: number;
  private readonly baseURL: string;
  private readonly onLoginRequired?: () => Promise<string>;
  /** True once a real (non-placeholder) API key has been installed on `client`. */
  private tokenResolved = false;

  constructor(config: GitHubModelsClientConfig) {
    const apiKey = config.apiKey ?? getGitHubToken() ?? '';
    this.baseURL = config.endpoint ?? DEFAULT_ENDPOINT;
    this.onLoginRequired = config.onLoginRequired;
    this.model = config.model;
    this.defaultMaxTokens = config.maxTokens ?? 4096;
    this.tokenResolved = !!apiKey;
    // With no token yet, build the client with a placeholder key;
    // ensureToken() swaps in a real key before the first API call.
    this.client = this.createClient(apiKey || 'placeholder');
  }

  /** Create an OpenAI SDK client for this endpoint using the given API key. */
  private createClient(apiKey: string): OpenAI {
    return new OpenAI({
      apiKey,
      baseURL: this.baseURL,
      defaultHeaders: {
        'Openai-Intent': 'conversation-edits',
      },
    });
  }

  /**
   * Ensure we have a valid token before making an API call.
   * If no token was resolved yet, first re-check `getGitHubToken()`, since the
   * user may have logged in (e.g. via /login) after construction. If that
   * fails, invoke `onLoginRequired` when provided. With neither, the request is
   * sent with the placeholder key and fails with an auth error.
   */
  private async ensureToken(): Promise<void> {
    if (this.tokenResolved) {
      return;
    }
    const token = getGitHubToken();
    if (token) {
      this.rebuildClient(token);
      return;
    }
    if (this.onLoginRequired) {
      const newToken = await this.onLoginRequired();
      // Don't lock in an empty key: leaving tokenResolved false lets a later
      // call retry resolution instead of failing forever.
      if (newToken) {
        this.rebuildClient(newToken);
      }
    }
  }

  /** Replace the SDK client with one using `apiKey` and mark the token resolved. */
  private rebuildClient(apiKey: string): void {
    this.client = this.createClient(apiKey);
    this.tokenResolved = true;
  }

  /** Build the OpenAI message list: optional system prompt, then the conversation. */
  private buildMessages(request: ChatRequest): OpenAI.ChatCompletionMessageParam[] {
    const messages: OpenAI.ChatCompletionMessageParam[] = [];
    if (request.system) {
      messages.push({ role: 'system', content: request.system });
    }
    for (const msg of request.messages) {
      messages.push({
        role: msg.role,
        content: toOpenAIContent(msg.content),
      } as OpenAI.ChatCompletionMessageParam);
    }
    return messages;
  }

  /** Convert Flynn tool definitions to OpenAI function tools; undefined when there are none. */
  private buildTools(request: ChatRequest): OpenAI.ChatCompletionCreateParamsNonStreaming['tools'] {
    if (!request.tools || request.tools.length === 0) {
      return undefined;
    }
    return request.tools.map(t => ({
      type: 'function' as const,
      function: {
        name: t.name,
        description: t.description,
        parameters: t.input_schema as OpenAI.FunctionParameters,
      },
    }));
  }

  /**
   * Enable extended thinking/reasoning mode on `params` when `request.thinking`
   * is set. Shared by chat() and chatStream() so both modes honor the flag.
   */
  private applyThinking(
    request: ChatRequest,
    params: OpenAI.ChatCompletionCreateParamsNonStreaming | OpenAI.ChatCompletionCreateParamsStreaming,
  ): void {
    if (!request.thinking) {
      return;
    }
    (params as typeof params & { reasoning_effort?: 'low' | 'medium' | 'high' }).reasoning_effort = 'medium';
  }

  /** Map an OpenAI finish_reason to Flynn's stop reason (used when no tool calls were returned). */
  private static toStopReason(reason: string | null | undefined): string {
    switch (reason) {
      case 'stop':
        return 'end_turn';
      case 'length':
        return 'max_tokens';
      case 'tool_calls':
        // Edge case: finish_reason says tool_calls but none were parsed
        return 'end_turn';
      default:
        return reason ?? 'end_turn';
    }
  }

  /**
   * Send a non-streaming chat completion request.
   *
   * @returns The assistant text, the mapped stop reason, token usage (0 when
   *   not reported), and any tool calls.
   * @throws Errors from token resolution (`onLoginRequired`) or the OpenAI SDK;
   *   a SyntaxError if a tool call's arguments are not valid JSON.
   */
  async chat(request: ChatRequest): Promise<ChatResponse> {
    await this.ensureToken();

    const params: OpenAI.ChatCompletionCreateParamsNonStreaming = {
      model: this.model,
      max_tokens: request.maxTokens ?? this.defaultMaxTokens,
      messages: this.buildMessages(request),
    };
    const tools = this.buildTools(request);
    if (tools) {
      params.tools = tools;
    }
    this.applyThinking(request, params);

    const response = await this.client.chat.completions.create(params);
    const choice = response.choices[0];
    const content = choice?.message?.content ?? '';
    const toolCalls = choice?.message?.tool_calls?.map((tc: OpenAI.ChatCompletionMessageToolCall) => ({
      id: tc.id,
      name: tc.function.name,
      // An empty arguments string (argument-less call) is treated as {},
      // matching chatStream(); JSON.parse('') would otherwise throw.
      args: JSON.parse(tc.function.arguments || '{}'),
    })) ?? [];

    const stopReason = toolCalls.length > 0
      ? 'tool_use'
      : GitHubModelsClient.toStopReason(choice?.finish_reason);

    return {
      content,
      stopReason,
      usage: {
        inputTokens: response.usage?.prompt_tokens ?? 0,
        outputTokens: response.usage?.completion_tokens ?? 0,
      },
      ...(toolCalls.length > 0 ? { toolCalls } : {}),
    };
  }

  /**
   * Stream a chat completion.
   *
   * Yields `content` events for text deltas. When the finish reason is
   * `tool_calls` or `stop`, it yields one `tool_use` event per accumulated
   * tool call. On success the stream ends with a `done` event. Any failure,
   * including token resolution and request construction, is yielded as a final
   * `error` event and is not thrown.
   */
  async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
    try {
      await this.ensureToken();

      const params: OpenAI.ChatCompletionCreateParamsStreaming = {
        model: this.model,
        max_tokens: request.maxTokens ?? this.defaultMaxTokens,
        messages: this.buildMessages(request),
        stream: true,
      };
      const tools = this.buildTools(request);
      if (tools) {
        params.tools = tools;
      }
      this.applyThinking(request, params);

      const stream = await this.client.chat.completions.create(params);
      let totalInputTokens = 0;
      let totalOutputTokens = 0;
      // Tool calls arrive as fragments keyed by index; arguments are
      // concatenated until a finish reason is seen.
      const toolCallAccumulator = new Map<number, {
        id: string;
        name: string;
        arguments: string;
      }>();

      for await (const chunk of stream) {
        const delta = chunk.choices[0]?.delta;
        const finishReason = chunk.choices[0]?.finish_reason;

        if (delta?.content) {
          yield { type: 'content', content: delta.content };
        }

        if (delta?.tool_calls) {
          for (const tc of delta.tool_calls) {
            const existing = toolCallAccumulator.get(tc.index);
            if (existing) {
              // Later fragments usually carry only argument text, but fill in
              // id/name if the first fragment did not include them.
              if (!existing.id && tc.id) {
                existing.id = tc.id;
              }
              if (!existing.name && tc.function?.name) {
                existing.name = tc.function.name;
              }
              if (tc.function?.arguments) {
                existing.arguments += tc.function.arguments;
              }
            } else {
              toolCallAccumulator.set(tc.index, {
                id: tc.id ?? '',
                name: tc.function?.name ?? '',
                arguments: tc.function?.arguments ?? '',
              });
            }
          }
        }

        // Usage is only present if the provider includes it in stream chunks.
        // NOTE(review): OpenAI itself requires stream_options.include_usage,
        // which is not requested here, so totals may stay 0 — confirm for GitHub.
        if (chunk.usage) {
          totalInputTokens = chunk.usage.prompt_tokens ?? totalInputTokens;
          totalOutputTokens = chunk.usage.completion_tokens ?? totalOutputTokens;
        }

        if (finishReason === 'tool_calls' || finishReason === 'stop') {
          for (const [, tc] of toolCallAccumulator) {
            yield {
              type: 'tool_use',
              toolCall: {
                id: tc.id,
                name: tc.name,
                args: JSON.parse(tc.arguments || '{}'),
              },
            };
          }
          toolCallAccumulator.clear();
        }
      }

      yield {
        type: 'done',
        usage: {
          inputTokens: totalInputTokens,
          outputTokens: totalOutputTokens,
        },
      };
    } catch (error) {
      yield {
        type: 'error',
        error: error instanceof Error ? error : new Error(String(error)),
      };
    }
  }
}