From e8204f5d42d061199b9e791b845dd4542ea9a05a Mon Sep 17 00:00:00 2001 From: William Valentin Date: Mon, 23 Feb 2026 21:59:51 -0800 Subject: [PATCH] feat(pi): support Agent runtime export in pi_embedded backend --- src/backends/piEmbedded.test.ts | 30 ++++- src/backends/piEmbedded.ts | 222 ++++++++++++++++++++++++++++---- 2 files changed, 225 insertions(+), 27 deletions(-) diff --git a/src/backends/piEmbedded.test.ts b/src/backends/piEmbedded.test.ts index 9fb7d5b..6b077dc 100644 --- a/src/backends/piEmbedded.test.ts +++ b/src/backends/piEmbedded.test.ts @@ -66,7 +66,35 @@ describe('PiEmbeddedBackend', () => { try { const backend = new PiEmbeddedBackend({ module: mod.moduleUrl, timeoutMs: 2000 }); await expect(backend.process({ prompt: 'hello', history: [] })) - .rejects.toThrow('supported session factory'); + .rejects.toThrow('supported runtime API'); + } finally { + mod.cleanup(); + } + }); + + it('uses Agent class runtime when session factory exports are absent', async () => { + const mod = createModule(` + export class Agent { + constructor() { + this.state = { messages: [] }; + } + replaceMessages(messages) { + this.state.messages = messages.slice(); + } + async prompt(input) { + this.state.messages.push({ role: "user", content: [{ type: "text", text: input }] }); + this.state.messages.push({ role: "assistant", content: [{ type: "text", text: "agent says: " + input }] }); + } + } + `); + + try { + const backend = new PiEmbeddedBackend({ module: mod.moduleUrl, timeoutMs: 2000 }); + const result = await backend.process({ + prompt: 'hello', + history: [{ role: 'assistant', content: 'previous answer' }], + }); + expect(result).toBe('agent says: hello'); } finally { mod.cleanup(); } diff --git a/src/backends/piEmbedded.ts b/src/backends/piEmbedded.ts index 745319d..4626e08 100644 --- a/src/backends/piEmbedded.ts +++ b/src/backends/piEmbedded.ts @@ -1,3 +1,5 @@ +import { createRequire } from 'module'; +import { pathToFileURL } from 'url'; import type { ExternalBackend, ExternalBackendRequest } from './external.js'; const DEFAULT_TIMEOUT_MS = 120_000; @@ -7,6 +9,12 @@ const DEFAULT_MODULE_CANDIDATES = [ '@openclaw/pi-agent-core', 'pi-agent-core', ] as const; +const PI_AI_MODULE_CANDIDATES = [ + '@mariozechner/pi-ai', + '@badlogic/pi-ai', + '@openclaw/pi-ai', + 'pi-ai', +] as const; type PiSystemPromptMode = 'flynn' | 'pi_default' | 'hybrid'; @@ -18,6 +26,10 @@ interface PiSessionLike { [key: string]: unknown; } +interface PiModelFactoryModuleLike extends PiModuleLike { + getModel?: (provider: string, modelId: string) => unknown; +} + export interface PiEmbeddedBackendOptions { timeoutMs?: number; model?: string; @@ -87,6 +99,45 @@ function getSessionFactory(moduleLike: PiModuleLike): ((args: Record) => unknown) | undefined { + const direct = moduleLike.Agent; + if (typeof direct === 'function') { + return direct as new (options?: Record) => unknown; + } + + const defaultExport = moduleLike.default; + if (defaultExport && typeof defaultExport === 'object') { + const nested = (defaultExport as PiModuleLike).Agent; + if (typeof nested === 'function') { + return nested as new (options?: Record) => unknown; + } + } + return undefined; +} + +function parseModelSpec(raw: string): { provider: string; modelId: string } { + const spec = raw.trim(); + const separator = spec.includes(':') ? ':' : spec.includes('/') ? '/' : undefined; + if (!separator) { + throw new Error( + `Invalid pi_embedded model "${raw}". Use ":" (for example "openai:gpt-5.2").`, + ); + } + const index = spec.indexOf(separator); + const provider = spec.slice(0, index).trim(); + const modelId = spec.slice(index + 1).trim(); + if (!provider || !modelId) { + throw new Error( + `Invalid pi_embedded model "${raw}". Use ":" (for example "openai:gpt-5.2").`, + ); + } + return { provider, modelId }; +} + +function isPackageSpecifier(moduleName: string): boolean { + return !(moduleName.startsWith('.') || moduleName.startsWith('/') || moduleName.startsWith('file:')); +} + function extractText(value: unknown): string | undefined { if (typeof value === 'string') { const trimmed = value.trim(); @@ -158,15 +209,8 @@ export class PiEmbeddedBackend implements ExternalBackend { async process(input: ExternalBackendRequest): Promise { const prompt = buildPrompt(input); - const moduleLike = await this.loadPiModule(); + const { moduleLike, moduleName } = await this.loadPiModule(); const factory = getSessionFactory(moduleLike); - if (!factory) { - throw new Error( - 'Loaded Pi module does not expose a supported session factory ' + - '(expected one of: createAgentSession, createSession, createPiSession, createAgent)', - ); - } - const requestPayload: Record = { prompt, input: input.prompt, @@ -179,29 +223,44 @@ export class PiEmbeddedBackend implements ExternalBackend { systemPromptMode: this.systemPromptMode, }; - const session = await withTimeout( - Promise.resolve(factory(requestPayload)), - this.timeoutMs, - 'Pi embedded session initialization', - ); - - try { - const response = await withTimeout( - this.invokeSession(session, requestPayload, prompt), + if (factory) { + const session = await withTimeout( + Promise.resolve(factory(requestPayload)), this.timeoutMs, - 'Pi embedded request', + 'Pi embedded session initialization', ); - const text = extractText(response); - if (!text) { - throw new Error('Pi embedded backend returned no text output'); + + try { + const response = await withTimeout( + this.invokeSession(session, requestPayload, prompt), + this.timeoutMs, + 'Pi embedded request', + ); + const text = extractText(response); + if (!text) { + throw new Error('Pi embedded backend returned no text output'); + } + return text; + } finally { + await maybeDisposeSession(session); } - return text; - } finally { - await maybeDisposeSession(session); } + + const AgentCtor = getAgentConstructor(moduleLike); + if (!AgentCtor) { + throw new Error( + 'Loaded Pi module does not expose a supported runtime API ' + + '(expected one of: createAgentSession/createSession/createPiSession/createAgent, or Agent class export)', + ); + } + return withTimeout( + this.invokeAgentRuntime(AgentCtor, input, moduleName), + this.timeoutMs, + 'Pi embedded request', + ); } - private async loadPiModule(): Promise { + private async loadPiModule(): Promise<{ moduleLike: PiModuleLike; moduleName: string }> { const candidates = this.moduleOverride ? [this.moduleOverride] : [...DEFAULT_MODULE_CANDIDATES]; @@ -210,7 +269,10 @@ export class PiEmbeddedBackend implements ExternalBackend { for (const moduleName of candidates) { try { const loaded = await import(moduleName); - return loaded as PiModuleLike; + return { + moduleLike: loaded as PiModuleLike, + moduleName, + }; } catch (error) { if (isModuleNotFound(error, moduleName)) { failures.push(`${moduleName}: not installed`); @@ -227,6 +289,114 @@ export class PiEmbeddedBackend implements ExternalBackend { ); } + private async invokeAgentRuntime( + AgentCtor: new (options?: Record) => unknown, + input: ExternalBackendRequest, + moduleName: string, + ): Promise { + const agent = new AgentCtor(); + if (!agent || typeof agent !== 'object') { + throw new Error('Pi Agent constructor returned an invalid object'); + } + const agentObj = agent as PiSessionLike; + + if (this.model) { + const model = await this.resolvePiModel(this.model, moduleName); + const setModel = agentObj.setModel; + if (typeof setModel === 'function') { + await Promise.resolve(setModel.call(agent, model)); + } + } + + const replaceMessages = agentObj.replaceMessages; + if (typeof replaceMessages === 'function') { + const historyMessages = input.history.map((entry, index) => ({ + role: entry.role, + content: [{ type: 'text', text: entry.content }], + timestamp: Date.now() + index, + })); + await Promise.resolve(replaceMessages.call(agent, historyMessages)); + } + + const promptMethod = agentObj.prompt; + if (typeof promptMethod !== 'function') { + throw new Error('Pi Agent runtime does not expose prompt()'); + } + await Promise.resolve(promptMethod.call(agent, input.prompt)); + + const state = agentObj.state; + if (!state || typeof state !== 'object') { + throw new Error('Pi Agent runtime returned no state after prompt()'); + } + const messages = (state as PiSessionLike).messages; + if (!Array.isArray(messages)) { + throw new Error('Pi Agent runtime state does not include messages'); + } + + for (let index = messages.length - 1; index >= 0; index -= 1) { + const message = messages[index]; + if (!message || typeof message !== 'object') { + continue; + } + const role = (message as PiSessionLike).role; + if (role !== 'assistant') { + continue; + } + const text = extractText(message); + if (text) { + return text; + } + } + + throw new Error('Pi Agent runtime produced no assistant text'); + } + + private async resolvePiModel(modelSpec: string, moduleName: string): Promise { + const { provider, modelId } = parseModelSpec(modelSpec); + const piAi = await this.loadPiAiModule(moduleName); + const getModel = piAi.getModel; + if (typeof getModel !== 'function') { + throw new Error('Pi AI module does not expose getModel() required for pi_embedded.model override'); + } + try { + return getModel(provider, modelId); + } catch (error) { + throw new Error( + `Failed to resolve pi_embedded.model "${modelSpec}": ${toErrorMessage(error)}`, + ); + } + } + + private async loadPiAiModule(moduleName: string): Promise { + const rootRequire = createRequire(import.meta.url); + let resolver = rootRequire; + + if (isPackageSpecifier(moduleName)) { + try { + const modulePackageJson = rootRequire.resolve(`${moduleName}/package.json`); + resolver = createRequire(modulePackageJson); + } catch { + resolver = rootRequire; + } + } + + const failures: string[] = []; + for (const candidate of PI_AI_MODULE_CANDIDATES) { + try { + const resolvedEntry = resolver.resolve(candidate); + const loaded = await import(pathToFileURL(resolvedEntry).href); + return loaded as PiModelFactoryModuleLike; + } catch (error) { + failures.push(`${candidate}: ${toErrorMessage(error)}`); + } + } + throw new Error( + 'Failed to load Pi AI model registry module. ' + + `Tried: ${PI_AI_MODULE_CANDIDATES.join(', ')}. ` + + `Details: ${failures.join(' | ')}`, + ); + } + private async invokeSession( session: unknown, requestPayload: Record,