feat(pi): support Agent runtime export in pi_embedded backend

This commit is contained in:
William Valentin
2026-02-23 21:59:51 -08:00
parent 559fe61168
commit e8204f5d42
2 changed files with 225 additions and 27 deletions
+29 -1
View File
@@ -66,7 +66,35 @@ describe('PiEmbeddedBackend', () => {
try { try {
const backend = new PiEmbeddedBackend({ module: mod.moduleUrl, timeoutMs: 2000 }); const backend = new PiEmbeddedBackend({ module: mod.moduleUrl, timeoutMs: 2000 });
await expect(backend.process({ prompt: 'hello', history: [] })) await expect(backend.process({ prompt: 'hello', history: [] }))
.rejects.toThrow('supported session factory'); .rejects.toThrow('supported runtime API');
} finally {
mod.cleanup();
}
});
it('uses Agent class runtime when session factory exports are absent', async () => {
const mod = createModule(`
export class Agent {
constructor() {
this.state = { messages: [] };
}
replaceMessages(messages) {
this.state.messages = messages.slice();
}
async prompt(input) {
this.state.messages.push({ role: "user", content: [{ type: "text", text: input }] });
this.state.messages.push({ role: "assistant", content: [{ type: "text", text: "agent says: " + input }] });
}
}
`);
try {
const backend = new PiEmbeddedBackend({ module: mod.moduleUrl, timeoutMs: 2000 });
const result = await backend.process({
prompt: 'hello',
history: [{ role: 'assistant', content: 'previous answer' }],
});
expect(result).toBe('agent says: hello');
} finally { } finally {
mod.cleanup(); mod.cleanup();
} }
+196 -26
View File
@@ -1,3 +1,5 @@
import { createRequire } from 'module';
import { pathToFileURL } from 'url';
import type { ExternalBackend, ExternalBackendRequest } from './external.js'; import type { ExternalBackend, ExternalBackendRequest } from './external.js';
const DEFAULT_TIMEOUT_MS = 120_000; const DEFAULT_TIMEOUT_MS = 120_000;
@@ -7,6 +9,12 @@ const DEFAULT_MODULE_CANDIDATES = [
'@openclaw/pi-agent-core', '@openclaw/pi-agent-core',
'pi-agent-core', 'pi-agent-core',
] as const; ] as const;
const PI_AI_MODULE_CANDIDATES = [
'@mariozechner/pi-ai',
'@badlogic/pi-ai',
'@openclaw/pi-ai',
'pi-ai',
] as const;
type PiSystemPromptMode = 'flynn' | 'pi_default' | 'hybrid'; type PiSystemPromptMode = 'flynn' | 'pi_default' | 'hybrid';
@@ -18,6 +26,10 @@ interface PiSessionLike {
[key: string]: unknown; [key: string]: unknown;
} }
interface PiModelFactoryModuleLike extends PiModuleLike {
getModel?: (provider: string, modelId: string) => unknown;
}
export interface PiEmbeddedBackendOptions { export interface PiEmbeddedBackendOptions {
timeoutMs?: number; timeoutMs?: number;
model?: string; model?: string;
@@ -87,6 +99,45 @@ function getSessionFactory(moduleLike: PiModuleLike): ((args: Record<string, unk
return undefined; return undefined;
} }
function getAgentConstructor(moduleLike: PiModuleLike): (new (options?: Record<string, unknown>) => unknown) | undefined {
const direct = moduleLike.Agent;
if (typeof direct === 'function') {
return direct as new (options?: Record<string, unknown>) => unknown;
}
const defaultExport = moduleLike.default;
if (defaultExport && typeof defaultExport === 'object') {
const nested = (defaultExport as PiModuleLike).Agent;
if (typeof nested === 'function') {
return nested as new (options?: Record<string, unknown>) => unknown;
}
}
return undefined;
}
function parseModelSpec(raw: string): { provider: string; modelId: string } {
const spec = raw.trim();
const separator = spec.includes(':') ? ':' : spec.includes('/') ? '/' : undefined;
if (!separator) {
throw new Error(
`Invalid pi_embedded model "${raw}". Use "<provider>:<model>" (for example "openai:gpt-5.2").`,
);
}
const index = spec.indexOf(separator);
const provider = spec.slice(0, index).trim();
const modelId = spec.slice(index + 1).trim();
if (!provider || !modelId) {
throw new Error(
`Invalid pi_embedded model "${raw}". Use "<provider>:<model>" (for example "openai:gpt-5.2").`,
);
}
return { provider, modelId };
}
function isPackageSpecifier(moduleName: string): boolean {
return !(moduleName.startsWith('.') || moduleName.startsWith('/') || moduleName.startsWith('file:'));
}
function extractText(value: unknown): string | undefined { function extractText(value: unknown): string | undefined {
if (typeof value === 'string') { if (typeof value === 'string') {
const trimmed = value.trim(); const trimmed = value.trim();
@@ -158,15 +209,8 @@ export class PiEmbeddedBackend implements ExternalBackend {
async process(input: ExternalBackendRequest): Promise<string> { async process(input: ExternalBackendRequest): Promise<string> {
const prompt = buildPrompt(input); const prompt = buildPrompt(input);
const moduleLike = await this.loadPiModule(); const { moduleLike, moduleName } = await this.loadPiModule();
const factory = getSessionFactory(moduleLike); const factory = getSessionFactory(moduleLike);
if (!factory) {
throw new Error(
'Loaded Pi module does not expose a supported session factory ' +
'(expected one of: createAgentSession, createSession, createPiSession, createAgent)',
);
}
const requestPayload: Record<string, unknown> = { const requestPayload: Record<string, unknown> = {
prompt, prompt,
input: input.prompt, input: input.prompt,
@@ -179,29 +223,44 @@ export class PiEmbeddedBackend implements ExternalBackend {
systemPromptMode: this.systemPromptMode, systemPromptMode: this.systemPromptMode,
}; };
const session = await withTimeout( if (factory) {
Promise.resolve(factory(requestPayload)), const session = await withTimeout(
this.timeoutMs, Promise.resolve(factory(requestPayload)),
'Pi embedded session initialization',
);
try {
const response = await withTimeout(
this.invokeSession(session, requestPayload, prompt),
this.timeoutMs, this.timeoutMs,
'Pi embedded request', 'Pi embedded session initialization',
); );
const text = extractText(response);
if (!text) { try {
throw new Error('Pi embedded backend returned no text output'); const response = await withTimeout(
this.invokeSession(session, requestPayload, prompt),
this.timeoutMs,
'Pi embedded request',
);
const text = extractText(response);
if (!text) {
throw new Error('Pi embedded backend returned no text output');
}
return text;
} finally {
await maybeDisposeSession(session);
} }
return text;
} finally {
await maybeDisposeSession(session);
} }
const AgentCtor = getAgentConstructor(moduleLike);
if (!AgentCtor) {
throw new Error(
'Loaded Pi module does not expose a supported runtime API ' +
'(expected one of: createAgentSession/createSession/createPiSession/createAgent, or Agent class export)',
);
}
return withTimeout(
this.invokeAgentRuntime(AgentCtor, input, moduleName),
this.timeoutMs,
'Pi embedded request',
);
} }
private async loadPiModule(): Promise<PiModuleLike> { private async loadPiModule(): Promise<{ moduleLike: PiModuleLike; moduleName: string }> {
const candidates = this.moduleOverride const candidates = this.moduleOverride
? [this.moduleOverride] ? [this.moduleOverride]
: [...DEFAULT_MODULE_CANDIDATES]; : [...DEFAULT_MODULE_CANDIDATES];
@@ -210,7 +269,10 @@ export class PiEmbeddedBackend implements ExternalBackend {
for (const moduleName of candidates) { for (const moduleName of candidates) {
try { try {
const loaded = await import(moduleName); const loaded = await import(moduleName);
return loaded as PiModuleLike; return {
moduleLike: loaded as PiModuleLike,
moduleName,
};
} catch (error) { } catch (error) {
if (isModuleNotFound(error, moduleName)) { if (isModuleNotFound(error, moduleName)) {
failures.push(`${moduleName}: not installed`); failures.push(`${moduleName}: not installed`);
@@ -227,6 +289,114 @@ export class PiEmbeddedBackend implements ExternalBackend {
); );
} }
private async invokeAgentRuntime(
AgentCtor: new (options?: Record<string, unknown>) => unknown,
input: ExternalBackendRequest,
moduleName: string,
): Promise<string> {
const agent = new AgentCtor();
if (!agent || typeof agent !== 'object') {
throw new Error('Pi Agent constructor returned an invalid object');
}
const agentObj = agent as PiSessionLike;
if (this.model) {
const model = await this.resolvePiModel(this.model, moduleName);
const setModel = agentObj.setModel;
if (typeof setModel === 'function') {
await Promise.resolve(setModel.call(agent, model));
}
}
const replaceMessages = agentObj.replaceMessages;
if (typeof replaceMessages === 'function') {
const historyMessages = input.history.map((entry, index) => ({
role: entry.role,
content: [{ type: 'text', text: entry.content }],
timestamp: Date.now() + index,
}));
await Promise.resolve(replaceMessages.call(agent, historyMessages));
}
const promptMethod = agentObj.prompt;
if (typeof promptMethod !== 'function') {
throw new Error('Pi Agent runtime does not expose prompt()');
}
await Promise.resolve(promptMethod.call(agent, input.prompt));
const state = agentObj.state;
if (!state || typeof state !== 'object') {
throw new Error('Pi Agent runtime returned no state after prompt()');
}
const messages = (state as PiSessionLike).messages;
if (!Array.isArray(messages)) {
throw new Error('Pi Agent runtime state does not include messages');
}
for (let index = messages.length - 1; index >= 0; index -= 1) {
const message = messages[index];
if (!message || typeof message !== 'object') {
continue;
}
const role = (message as PiSessionLike).role;
if (role !== 'assistant') {
continue;
}
const text = extractText(message);
if (text) {
return text;
}
}
throw new Error('Pi Agent runtime produced no assistant text');
}
private async resolvePiModel(modelSpec: string, moduleName: string): Promise<unknown> {
const { provider, modelId } = parseModelSpec(modelSpec);
const piAi = await this.loadPiAiModule(moduleName);
const getModel = piAi.getModel;
if (typeof getModel !== 'function') {
throw new Error('Pi AI module does not expose getModel() required for pi_embedded.model override');
}
try {
return getModel(provider, modelId);
} catch (error) {
throw new Error(
`Failed to resolve pi_embedded.model "${modelSpec}": ${toErrorMessage(error)}`,
);
}
}
private async loadPiAiModule(moduleName: string): Promise<PiModelFactoryModuleLike> {
const rootRequire = createRequire(import.meta.url);
let resolver = rootRequire;
if (isPackageSpecifier(moduleName)) {
try {
const modulePackageJson = rootRequire.resolve(`${moduleName}/package.json`);
resolver = createRequire(modulePackageJson);
} catch {
resolver = rootRequire;
}
}
const failures: string[] = [];
for (const candidate of PI_AI_MODULE_CANDIDATES) {
try {
const resolvedEntry = resolver.resolve(candidate);
const loaded = await import(pathToFileURL(resolvedEntry).href);
return loaded as PiModelFactoryModuleLike;
} catch (error) {
failures.push(`${candidate}: ${toErrorMessage(error)}`);
}
}
throw new Error(
'Failed to load Pi AI model registry module. ' +
`Tried: ${PI_AI_MODULE_CANDIDATES.join(', ')}. ` +
`Details: ${failures.join(' | ')}`,
);
}
private async invokeSession( private async invokeSession(
session: unknown, session: unknown,
requestPayload: Record<string, unknown>, requestPayload: Record<string, unknown>,