From 80a1ccc55209042d4e2c04af7a9ce90ad7332b9a Mon Sep 17 00:00:00 2001 From: huntharo Date: Fri, 27 Mar 2026 14:32:42 -0400 Subject: [PATCH] xAI: preserve session auth in embedded runs --- .../pi-embedded-runner/run/attempt.test.ts | 45 +++++++++++ src/agents/pi-embedded-runner/run/attempt.ts | 75 +++++++++++++------ 2 files changed, 96 insertions(+), 24 deletions(-) diff --git a/src/agents/pi-embedded-runner/run/attempt.test.ts b/src/agents/pi-embedded-runner/run/attempt.test.ts index e80513d1853..89426348788 100644 --- a/src/agents/pi-embedded-runner/run/attempt.test.ts +++ b/src/agents/pi-embedded-runner/run/attempt.test.ts @@ -27,6 +27,7 @@ import { wrapStreamFnRepairMalformedToolCallArguments, wrapStreamFnSanitizeMalformedToolCalls, wrapStreamFnTrimToolCallNames, + resolveEmbeddedAgentStreamFn, } from "./attempt.js"; import { shouldInjectHeartbeatPromptForTrigger } from "./trigger-policy.js"; @@ -1810,6 +1811,50 @@ describe("shouldInjectOllamaCompatNumCtx", () => { }); }); +describe("resolveEmbeddedAgentStreamFn", () => { + it("keeps the session-managed HTTP stream when no override applies", () => { + const currentStreamFn = vi.fn(); + + const resolved = resolveEmbeddedAgentStreamFn({ + currentStreamFn: currentStreamFn as never, + shouldUseWebSocketTransport: false, + sessionId: "session-1", + model: { provider: "xai" } as never, + }); + + expect(resolved).toBe(currentStreamFn); + }); + + it("keeps the session-managed HTTP stream when websocket auth is unavailable", () => { + const currentStreamFn = vi.fn(); + + const resolved = resolveEmbeddedAgentStreamFn({ + currentStreamFn: currentStreamFn as never, + shouldUseWebSocketTransport: true, + wsApiKey: undefined, + sessionId: "session-1", + model: { provider: "xai" } as never, + }); + + expect(resolved).toBe(currentStreamFn); + }); + + it("prefers a provider-owned stream override when present", () => { + const currentStreamFn = vi.fn(); + const providerStreamFn = vi.fn(); + + const resolved = resolveEmbeddedAgentStreamFn({ + currentStreamFn: currentStreamFn as never, + providerStreamFn: providerStreamFn as never, + shouldUseWebSocketTransport: false, + sessionId: "session-1", + model: { provider: "xai" } as never, + }); + + expect(resolved).toBe(providerStreamFn); + }); +}); + describe("decodeHtmlEntitiesInObject", () => { it("decodes HTML entities in string values", () => { const result = decodeHtmlEntitiesInObject( diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index b88a9dd1c38..ac2ae8fa0e9 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -1,6 +1,7 @@ import fs from "node:fs/promises"; import os from "node:os"; -import type { AgentMessage } from "@mariozechner/pi-agent-core"; +import type { AgentMessage, StreamFn } from "@mariozechner/pi-agent-core"; +import { streamSimple } from "@mariozechner/pi-ai"; import { createAgentSession, DefaultResourceLoader, @@ -225,6 +226,35 @@ function summarizeProviderAuthKey(apiKey: string | undefined): string { return `${trimmed.slice(0, 4)}…${trimmed.slice(-4)}`; } +export function resolveEmbeddedAgentStreamFn(params: { + currentStreamFn: StreamFn | undefined; + providerStreamFn?: StreamFn; + shouldUseWebSocketTransport: boolean; + wsApiKey?: string; + sessionId: string; + signal?: AbortSignal; + model: EmbeddedRunAttemptParams["model"]; +}): StreamFn { + if (params.providerStreamFn) { + return params.providerStreamFn; + } + + const currentStreamFn = params.currentStreamFn ?? streamSimple; + if (params.shouldUseWebSocketTransport) { + return params.wsApiKey + ? createOpenAIWebSocketStreamFn(params.wsApiKey, params.sessionId, { + signal: params.signal, + }) + : currentStreamFn; + } + + if (params.model.provider === "anthropic-vertex") { + return createAnthropicVertexStreamFnForModel(params.model); + } + + return currentStreamFn; +} + function summarizeMessagePayload(msg: AgentMessage): { textChars: number; imageBlocks: number } { const content = (msg as { content?: unknown }).content; if (typeof content === "string") { @@ -868,30 +898,27 @@ export async function runEmbeddedAttempt( `[xai-auth] pre-stream setup: modelApi=${params.model.api} baseUrl=${params.model.baseUrl ?? "default"} runtimeAuthKey=${summarizeProviderAuthKey(runtimeApiKey)} headersAuth=${params.model.headers?.Authorization ? "present" : "absent"} responsesAuthPath=apiKey-argument`, ); } - if (providerStreamFn) { - activeSession.agent.streamFn = providerStreamFn; - } else if ( - shouldUseOpenAIWebSocketTransport({ - provider: params.provider, - modelApi: params.model.api, - }) - ) { - const wsApiKey = await params.authStorage.getApiKey(params.provider); - if (wsApiKey) { - activeSession.agent.streamFn = createOpenAIWebSocketStreamFn(wsApiKey, params.sessionId, { - signal: runAbortController.signal, - }); - } else { - log.warn(`[ws-stream] no API key for provider=${params.provider}; using HTTP transport`); - activeSession.agent.streamFn = defaultSessionStreamFn; - } - } else if (params.model.provider === "anthropic-vertex") { - // Anthropic Vertex AI: inject AnthropicVertex client into pi-ai's - // streamAnthropic for GCP IAM auth instead of Anthropic API keys. - activeSession.agent.streamFn = createAnthropicVertexStreamFnForModel(params.model); - } else { - activeSession.agent.streamFn = defaultSessionStreamFn; + const shouldUseWebSocketTransport = shouldUseOpenAIWebSocketTransport({ + provider: params.provider, + modelApi: params.model.api, + }); + const wsApiKey = shouldUseWebSocketTransport + ? await params.authStorage.getApiKey(params.provider) + : undefined; + if (shouldUseWebSocketTransport && !wsApiKey) { + log.warn( + `[ws-stream] no API key for provider=${params.provider}; keeping session-managed HTTP transport`, + ); } + activeSession.agent.streamFn = resolveEmbeddedAgentStreamFn({ + currentStreamFn: defaultSessionStreamFn, + providerStreamFn, + shouldUseWebSocketTransport, + wsApiKey, + sessionId: params.sessionId, + signal: runAbortController.signal, + model: params.model, + }); const { effectiveExtraParams } = applyExtraParamsToAgent( activeSession.agent,