317 lines
9.8 KiB
TypeScript
317 lines
9.8 KiB
TypeScript
import OpenAI from 'openai';
|
|
import type { ChatRequest, ChatResponse, ChatStreamEvent, ModelClient, MessageContentPart } from './types.js';
|
|
import { getGitHubToken } from '../auth/index.js';
|
|
|
|
/**
 * Construction options for the GitHub Models client defined in this file.
 */
export interface GitHubModelsClientConfig {
  /** Model identifier sent with every request, e.g. 'gpt-4o' or 'claude-sonnet-4'. */
  model: string;

  /** GitHub PAT or gh auth token. Falls back to GITHUB_TOKEN env var. */
  apiKey?: string;

  /** Default completion token limit; the client uses 4096 when this is omitted. */
  maxTokens?: number;

  /** Override base URL (default: https://api.githubcopilot.com). */
  endpoint?: string;

  /**
   * Optional callback invoked when no token is available at API call time.
   * Should return a valid token (e.g. by running the OAuth device flow).
   * If not provided and no token is available, API calls will fail with auth errors.
   */
  onLoginRequired?: () => Promise<string>;
}
|
|
|
|
/** Base URL the client uses when `GitHubModelsClientConfig.endpoint` is not supplied. */
const DEFAULT_ENDPOINT = 'https://api.githubcopilot.com';
|
|
|
|
/**
 * Audio MIME type -> `format` value placed in an `input_audio` content part.
 * A Map (rather than an object literal) so lookups can never hit prototype keys.
 */
const AUDIO_FORMAT_BY_MEDIA_TYPE: ReadonlyMap<string, string> = new Map([
  ['audio/wav', 'wav'],
  ['audio/mpeg', 'mp3'],
  ['audio/mp3', 'mp3'],
  ['audio/ogg', 'ogg'],
  ['audio/webm', 'webm'],
  ['audio/mp4', 'mp4'],
  ['audio/x-m4a', 'mp4'],
]);

/**
 * Convert Flynn message content to OpenAI format.
 * Reuses the same pattern as openai.ts since GitHub Models uses an OpenAI-compatible API.
 *
 * - A plain string is returned unchanged.
 * - `text` parts map to `text` parts.
 * - `image` parts map to `image_url` parts (base64 sources become `data:` URLs).
 * - `audio` parts map to `input_audio` parts.
 *
 * Image or audio parts whose payload is missing are replaced by a short text
 * placeholder instead of sending an invalid part to the API.
 *
 * @param content Either the raw message text or a list of content parts.
 * @returns The string as-is, or one OpenAI content part per input part (same order).
 */
function toOpenAIContent(content: string | MessageContentPart[]): string | OpenAI.ChatCompletionContentPart[] {
  if (typeof content === 'string') {
    return content;
  }

  return content.map((part): OpenAI.ChatCompletionContentPart => {
    switch (part.type) {
      case 'text':
        return { type: 'text', text: part.text };

      case 'image': {
        const { source } = part;
        if (source.type === 'base64') {
          if (!source.data) {
            return { type: 'text', text: '[Image omitted: missing base64 data]' };
          }
          return {
            type: 'image_url',
            image_url: { url: `data:${source.media_type};base64,${source.data}` },
          };
        }
        if (!source.url) {
          return { type: 'text', text: '[Image omitted: missing URL]' };
        }
        return { type: 'image_url', image_url: { url: source.url } };
      }

      case 'audio': {
        // Mirror the image branch: never forward a part without a payload.
        if (!part.source.data) {
          return { type: 'text', text: '[Audio omitted: missing audio data]' };
        }
        // MIME types are case-insensitive and may carry parameters
        // (e.g. "audio/webm; codecs=opus"); compare only the bare type.
        const mediaType = (part.source.media_type ?? '').split(';')[0]?.trim().toLowerCase() ?? '';
        const format = AUDIO_FORMAT_BY_MEDIA_TYPE.get(mediaType) ?? 'wav';
        // GitHub Models uses OpenAI-compatible API — native audio via input_audio.
        // The cast is kept from the original: `format` is a plain string here,
        // which is wider than the SDK's declared literal type.
        return {
          type: 'input_audio',
          input_audio: { data: part.source.data, format },
        } as unknown as OpenAI.ChatCompletionContentPart;
      }

      default:
        // Fallback — shouldn't happen
        return { type: 'text', text: JSON.stringify(part) };
    }
  });
}
|
|
|
|
/** A tool call being assembled from streamed deltas, keyed by its stream index. */
interface PendingToolCall {
  id: string;
  name: string;
  /** Concatenation of the JSON argument fragments received so far. */
  arguments: string;
}

/**
 * Parse the JSON-encoded argument string of a tool call.
 *
 * An empty (or whitespace-only / missing) string is treated as "no arguments"
 * and yields `{}`, so the streaming and non-streaming paths agree.
 *
 * The return type deliberately mirrors `JSON.parse` so that assigning the
 * result to the tool-call `args` field type-checks exactly as it did before.
 *
 * @param raw JSON text produced by the model for the call's arguments.
 * @param toolName Name of the tool, used only to make the error message useful.
 * @returns The parsed value.
 * @throws SyntaxError when `raw` is non-empty but not valid JSON.
 */
function parseToolArguments(raw: string | null | undefined, toolName: string): ReturnType<typeof JSON.parse> {
  const text = (raw ?? '').trim();
  if (text === '') {
    return {};
  }
  try {
    return JSON.parse(text);
  } catch (cause) {
    const detail = cause instanceof Error ? cause.message : String(cause);
    throw new SyntaxError(`Invalid JSON arguments for tool call "${toolName}": ${detail}`);
  }
}

/**
 * Map an OpenAI `finish_reason` to Flynn's stop reason.
 *
 * @param finishReason Value reported on the first choice (may be absent).
 * @param hasToolCalls Whether any tool calls were parsed from the response.
 * @returns 'tool_use' when tool calls exist; otherwise 'end_turn' for
 *   'stop' / 'tool_calls' / missing, 'max_tokens' for 'length', and any other
 *   reason passed through unchanged.
 */
function toStopReason(finishReason: string | null | undefined, hasToolCalls: boolean): string {
  if (hasToolCalls) {
    return 'tool_use';
  }
  switch (finishReason) {
    case 'length':
      return 'max_tokens';
    case 'stop':
    case 'tool_calls': // Edge case: finish_reason says tool_calls but none were parsed
      return 'end_turn';
    default:
      return finishReason ?? 'end_turn';
  }
}

/**
 * Emit one `tool_use` event per accumulated tool call (in first-seen order)
 * and then empty the accumulator so a call is never emitted twice.
 *
 * @throws SyntaxError if a call's accumulated arguments are not valid JSON.
 */
function* drainToolCalls(pending: Map<number, PendingToolCall>): Generator<ChatStreamEvent> {
  for (const call of pending.values()) {
    yield {
      type: 'tool_use',
      toolCall: {
        id: call.id,
        name: call.name,
        args: parseToolArguments(call.arguments, call.name),
      },
    };
  }
  pending.clear();
}

/**
 * `ModelClient` implementation that talks to an OpenAI-compatible endpoint
 * (default `DEFAULT_ENDPOINT`) through the OpenAI SDK.
 *
 * A token is taken from the config or from `getGitHubToken()`. If none is
 * available at construction time, resolution is retried before each API call
 * and, when configured, `onLoginRequired` is invoked to obtain one.
 */
export class GitHubModelsClient implements ModelClient {
  private client: OpenAI;
  private readonly model: string;
  private readonly defaultMaxTokens: number;
  private readonly baseURL: string;
  private readonly onLoginRequired?: () => Promise<string>;
  private tokenResolved = false;

  /**
   * @param config Client options; see `GitHubModelsClientConfig`.
   */
  constructor(config: GitHubModelsClientConfig) {
    // `||` (not `??`) so an empty-string apiKey also falls back to the stored token.
    const apiKey = config.apiKey || getGitHubToken() || '';
    this.baseURL = config.endpoint ?? DEFAULT_ENDPOINT;
    this.onLoginRequired = config.onLoginRequired;
    this.tokenResolved = apiKey !== '';

    // A placeholder key lets the SDK client be constructed before a real token
    // exists; ensureToken() swaps in a real client before the first request.
    this.client = this.createClient(apiKey || 'placeholder');
    this.model = config.model;
    this.defaultMaxTokens = config.maxTokens ?? 4096;
  }

  /** Create an SDK client for this instance's base URL with the given key. */
  private createClient(apiKey: string): OpenAI {
    return new OpenAI({
      apiKey,
      baseURL: this.baseURL,
      defaultHeaders: {
        'Openai-Intent': 'conversation-edits',
      },
    });
  }

  /**
   * Ensure we have a valid token before making an API call.
   * If no token was resolved at construction time and an onLoginRequired
   * callback is provided, invoke it to obtain a token (e.g. via OAuth device flow).
   *
   * An empty token returned by the callback is ignored, so the next call
   * tries again instead of locking in an unusable client.
   */
  private async ensureToken(): Promise<void> {
    if (this.tokenResolved) {
      return;
    }

    // Try resolving again (user might have logged in via /login since construction)
    const token = getGitHubToken();
    if (token) {
      this.rebuildClient(token);
      return;
    }

    // Trigger auto-login if callback provided
    if (this.onLoginRequired) {
      const newToken = await this.onLoginRequired();
      if (newToken) {
        this.rebuildClient(newToken);
      }
    }

    // No token and no callback — the API call will fail with an auth error
  }

  /** Rebuild the OpenAI client with a new API key. */
  private rebuildClient(apiKey: string): void {
    this.client = this.createClient(apiKey);
    this.tokenResolved = true;
  }

  /** Build the OpenAI message list: optional system prompt first, then the conversation. */
  private buildMessages(request: ChatRequest): OpenAI.ChatCompletionMessageParam[] {
    const messages: OpenAI.ChatCompletionMessageParam[] = [];

    if (request.system) {
      messages.push({ role: 'system', content: request.system });
    }

    for (const msg of request.messages) {
      messages.push({
        role: msg.role,
        content: toOpenAIContent(msg.content),
      } as OpenAI.ChatCompletionMessageParam);
    }

    return messages;
  }

  /** Convert Flynn tool definitions to OpenAI function tools; `undefined` when there are none. */
  private buildTools(tools: ChatRequest['tools']): OpenAI.ChatCompletionTool[] | undefined {
    if (!tools || tools.length === 0) {
      return undefined;
    }
    return tools.map((t) => ({
      type: 'function' as const,
      function: {
        name: t.name,
        description: t.description,
        parameters: t.input_schema as OpenAI.FunctionParameters,
      },
    }));
  }

  /**
   * Send a non-streaming chat completion request.
   *
   * @param request Conversation, optional system prompt, tools and limits.
   * @returns Text content, mapped stop reason, token usage and (when present) tool calls.
   * @throws SyntaxError if the model returns tool-call arguments that are not valid JSON;
   *   errors from the SDK call and from `onLoginRequired` propagate unchanged.
   */
  async chat(request: ChatRequest): Promise<ChatResponse> {
    await this.ensureToken();

    const params: OpenAI.ChatCompletionCreateParamsNonStreaming = {
      model: this.model,
      max_tokens: request.maxTokens ?? this.defaultMaxTokens,
      messages: this.buildMessages(request),
    };

    const tools = this.buildTools(request.tools);
    if (tools) {
      params.tools = tools;
    }

    // Extended thinking/reasoning mode
    if (request.thinking) {
      (params as OpenAI.ChatCompletionCreateParamsNonStreaming & { reasoning_effort?: 'low' | 'medium' | 'high' }).reasoning_effort = 'medium';
    }

    const response = await this.client.chat.completions.create(params);

    const choice = response.choices[0];
    const content = choice?.message?.content ?? '';

    const toolCalls = (choice?.message?.tool_calls ?? []).map((tc: OpenAI.ChatCompletionMessageToolCall) => ({
      id: tc.id,
      name: tc.function.name,
      // Empty argument strings mean "no arguments", same as on the streaming path.
      args: parseToolArguments(tc.function.arguments, tc.function.name),
    }));

    return {
      content,
      stopReason: toStopReason(choice?.finish_reason, toolCalls.length > 0),
      usage: {
        inputTokens: response.usage?.prompt_tokens ?? 0,
        outputTokens: response.usage?.completion_tokens ?? 0,
      },
      ...(toolCalls.length > 0 ? { toolCalls } : {}),
    };
  }

  /**
   * Send a streaming chat completion request.
   *
   * Yields `content` events for text deltas, `tool_use` events for completed
   * tool calls, then a final `done` event carrying token usage. Failures of the
   * request or while consuming the stream are reported as a single `error`
   * event instead of being thrown.
   *
   * NOTE(review): unlike chat(), `request.thinking` is not forwarded here —
   * confirm whether that is intentional.
   *
   * @param request Conversation, optional system prompt, tools and limits.
   */
  async *chatStream(request: ChatRequest): AsyncIterable<ChatStreamEvent> {
    await this.ensureToken();

    const params: OpenAI.ChatCompletionCreateParamsStreaming = {
      model: this.model,
      max_tokens: request.maxTokens ?? this.defaultMaxTokens,
      messages: this.buildMessages(request),
      stream: true,
    };

    const tools = this.buildTools(request.tools);
    if (tools) {
      params.tools = tools;
    }

    try {
      const stream = await this.client.chat.completions.create(params);

      let totalInputTokens = 0;
      let totalOutputTokens = 0;

      // Accumulate tool call deltas across chunks
      const pendingToolCalls = new Map<number, PendingToolCall>();

      for await (const chunk of stream) {
        // Usage-only chunks may carry no choices at all.
        const choice = chunk.choices?.[0];
        const delta = choice?.delta;

        // Emit text content deltas
        if (delta?.content) {
          yield { type: 'content', content: delta.content };
        }

        // Accumulate tool call deltas
        for (const tc of delta?.tool_calls ?? []) {
          const existing = pendingToolCalls.get(tc.index);
          if (!existing) {
            pendingToolCalls.set(tc.index, {
              id: tc.id ?? '',
              name: tc.function?.name ?? '',
              arguments: tc.function?.arguments ?? '',
            });
            continue;
          }
          // The id/name may arrive after the first delta for this index;
          // keep the first non-empty value seen.
          if (!existing.id && tc.id) {
            existing.id = tc.id;
          }
          if (!existing.name && tc.function?.name) {
            existing.name = tc.function.name;
          }
          // Append argument fragments
          if (tc.function?.arguments) {
            existing.arguments += tc.function.arguments;
          }
        }

        // Track usage from chunks (some providers include it in stream)
        if (chunk.usage) {
          totalInputTokens = chunk.usage.prompt_tokens ?? totalInputTokens;
          totalOutputTokens = chunk.usage.completion_tokens ?? totalOutputTokens;
        }

        // On finish, emit accumulated tool calls
        const finishReason = choice?.finish_reason;
        if (finishReason === 'tool_calls' || finishReason === 'stop') {
          yield* drainToolCalls(pendingToolCalls);
        }
      }

      // The stream may end with another finish reason (or none); emit whatever
      // is still pending so tool calls are never silently dropped. Truncated
      // arguments surface as an `error` event via the catch below.
      yield* drainToolCalls(pendingToolCalls);

      yield {
        type: 'done',
        usage: {
          inputTokens: totalInputTokens,
          outputTokens: totalOutputTokens,
        },
      };
    } catch (error) {
      yield {
        type: 'error',
        error: error instanceof Error ? error : new Error(String(error)),
      };
    }
  }
}
|