diff --git a/src/agents/pi-embedded-runner/run/attempt.test.ts b/src/agents/pi-embedded-runner/run/attempt.test.ts index bc6cddfb5d6..27982edcf05 100644 --- a/src/agents/pi-embedded-runner/run/attempt.test.ts +++ b/src/agents/pi-embedded-runner/run/attempt.test.ts @@ -8,6 +8,7 @@ import { resolvePromptBuildHookResult, resolvePromptModeForSession, shouldInjectOllamaCompatNumCtx, + decodeHtmlEntitiesInObject, wrapOllamaCompatNumCtx, wrapStreamFnTrimToolCallNames, } from "./attempt.js"; @@ -453,3 +454,42 @@ describe("shouldInjectOllamaCompatNumCtx", () => { ).toBe(false); }); }); + +describe("decodeHtmlEntitiesInObject", () => { + it("decodes HTML entities in string values", () => { + const result = decodeHtmlEntitiesInObject( + "source .env && psql "$DB" -c <query>", + ); + expect(result).toBe('source .env && psql "$DB" -c '); + }); + + it("recursively decodes nested objects", () => { + const input = { + command: "cd ~/dev && npm run build", + args: ["--flag="value"", "<input>"], + nested: { deep: "a & b" }, + }; + const result = decodeHtmlEntitiesInObject(input) as Record; + expect(result.command).toBe("cd ~/dev && npm run build"); + expect((result.args as string[])[0]).toBe('--flag="value"'); + expect((result.args as string[])[1]).toBe(""); + expect((result.nested as Record).deep).toBe("a & b"); + }); + + it("passes through non-string primitives unchanged", () => { + expect(decodeHtmlEntitiesInObject(42)).toBe(42); + expect(decodeHtmlEntitiesInObject(null)).toBe(null); + expect(decodeHtmlEntitiesInObject(true)).toBe(true); + expect(decodeHtmlEntitiesInObject(undefined)).toBe(undefined); + }); + + it("returns strings without entities unchanged", () => { + const input = "plain string with no entities"; + expect(decodeHtmlEntitiesInObject(input)).toBe(input); + }); + + it("decodes numeric character references", () => { + expect(decodeHtmlEntitiesInObject("'hello'")).toBe("'hello'"); + expect(decodeHtmlEntitiesInObject("'world'")).toBe("'world'"); + }); +}); diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index c34043a5351..627fb017953 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -102,6 +102,7 @@ import { type EmbeddedPiQueueHandle, setActiveEmbeddedRun, } from "../runs.js"; +import { isXaiProvider } from "../../schema/clean-for-xai.js"; import { buildEmbeddedSandboxInfo } from "../sandbox-info.js"; import { prewarmSessionFile, trackSessionManagerAccess } from "../session-manager-cache.js"; import { prepareSessionManagerForRun } from "../session-manager-init.js"; @@ -421,6 +422,110 @@ export function wrapStreamFnTrimToolCallNames( }; } +// --------------------------------------------------------------------------- +// xAI / Grok: decode HTML entities in tool call arguments +// --------------------------------------------------------------------------- + +const HTML_ENTITY_RE = /&(?:amp|lt|gt|quot|apos|#39|#x[0-9a-f]+|#\d+);/i; + +function decodeHtmlEntities(value: string): string { + return value + .replace(/&/gi, "&") + .replace(/"/gi, '"') + .replace(/'/gi, "'") + .replace(/'/gi, "'") + .replace(/</gi, "<") + .replace(/>/gi, ">") + .replace(/&#x([0-9a-f]+);/gi, (_, hex) => String.fromCodePoint(Number.parseInt(hex, 16))) + .replace(/&#(\d+);/gi, (_, dec) => String.fromCodePoint(Number.parseInt(dec, 10))); +} + +export function decodeHtmlEntitiesInObject(obj: unknown): unknown { + if (typeof obj === "string") { + return HTML_ENTITY_RE.test(obj) ? decodeHtmlEntities(obj) : obj; + } + if (Array.isArray(obj)) { + return obj.map(decodeHtmlEntitiesInObject); + } + if (obj && typeof obj === "object") { + const result: Record = {}; + for (const [key, val] of Object.entries(obj as Record)) { + result[key] = decodeHtmlEntitiesInObject(val); + } + return result; + } + return obj; +} + +function decodeXaiToolCallArgumentsInMessage(message: unknown): void { + if (!message || typeof message !== "object") { + return; + } + const content = (message as { content?: unknown }).content; + if (!Array.isArray(content)) { + return; + } + for (const block of content) { + if (!block || typeof block !== "object") { + continue; + } + const typedBlock = block as { type?: unknown; arguments?: unknown }; + if (typedBlock.type !== "toolCall" || !typedBlock.arguments) { + continue; + } + if (typeof typedBlock.arguments === "object") { + typedBlock.arguments = decodeHtmlEntitiesInObject(typedBlock.arguments); + } + } +} + +function wrapStreamDecodeXaiToolCallArguments( + stream: ReturnType, +): ReturnType { + const originalResult = stream.result.bind(stream); + stream.result = async () => { + const message = await originalResult(); + decodeXaiToolCallArgumentsInMessage(message); + return message; + }; + + const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream); + (stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] = + function () { + const iterator = originalAsyncIterator(); + return { + async next() { + const result = await iterator.next(); + if (!result.done && result.value && typeof result.value === "object") { + const event = result.value as { partial?: unknown; message?: unknown }; + decodeXaiToolCallArgumentsInMessage(event.partial); + decodeXaiToolCallArgumentsInMessage(event.message); + } + return result; + }, + async return(value?: unknown) { + return iterator.return?.(value) ?? { done: true as const, value: undefined }; + }, + async throw(error?: unknown) { + return iterator.throw?.(error) ?? { done: true as const, value: undefined }; + }, + }; + }; + return stream; +} + +function wrapStreamFnDecodeXaiToolCallArguments(baseFn: StreamFn): StreamFn { + return (model, context, options) => { + const maybeStream = baseFn(model, context, options); + if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) { + return Promise.resolve(maybeStream).then((stream) => + wrapStreamDecodeXaiToolCallArguments(stream), + ); + } + return wrapStreamDecodeXaiToolCallArguments(maybeStream); + }; +} + export async function resolvePromptBuildHookResult(params: { prompt: string; messages: unknown[]; @@ -1158,6 +1263,12 @@ export async function runEmbeddedAttempt( allowedToolNames, ); + if (isXaiProvider(params.provider, params.modelId)) { + activeSession.agent.streamFn = wrapStreamFnDecodeXaiToolCallArguments( + activeSession.agent.streamFn, + ); + } + if (anthropicPayloadLogger) { activeSession.agent.streamFn = anthropicPayloadLogger.wrapStreamFn( activeSession.agent.streamFn,