xAI: preserve session auth in embedded runs

This commit is contained in:
huntharo 2026-03-27 14:32:42 -04:00 committed by Peter Steinberger
parent 2765fdc2dd
commit 80a1ccc552
2 changed files with 96 additions and 24 deletions

View File

@ -27,6 +27,7 @@ import {
wrapStreamFnRepairMalformedToolCallArguments,
wrapStreamFnSanitizeMalformedToolCalls,
wrapStreamFnTrimToolCallNames,
resolveEmbeddedAgentStreamFn,
} from "./attempt.js";
import { shouldInjectHeartbeatPromptForTrigger } from "./trigger-policy.js";
@ -1810,6 +1811,50 @@ describe("shouldInjectOllamaCompatNumCtx", () => {
});
});
describe("resolveEmbeddedAgentStreamFn", () => {
it("keeps the session-managed HTTP stream when no override applies", () => {
const currentStreamFn = vi.fn();
const resolved = resolveEmbeddedAgentStreamFn({
currentStreamFn: currentStreamFn as never,
shouldUseWebSocketTransport: false,
sessionId: "session-1",
model: { provider: "xai" } as never,
});
expect(resolved).toBe(currentStreamFn);
});
it("keeps the session-managed HTTP stream when websocket auth is unavailable", () => {
const currentStreamFn = vi.fn();
const resolved = resolveEmbeddedAgentStreamFn({
currentStreamFn: currentStreamFn as never,
shouldUseWebSocketTransport: true,
wsApiKey: undefined,
sessionId: "session-1",
model: { provider: "xai" } as never,
});
expect(resolved).toBe(currentStreamFn);
});
it("prefers a provider-owned stream override when present", () => {
const currentStreamFn = vi.fn();
const providerStreamFn = vi.fn();
const resolved = resolveEmbeddedAgentStreamFn({
currentStreamFn: currentStreamFn as never,
providerStreamFn: providerStreamFn as never,
shouldUseWebSocketTransport: false,
sessionId: "session-1",
model: { provider: "xai" } as never,
});
expect(resolved).toBe(providerStreamFn);
});
});
describe("decodeHtmlEntitiesInObject", () => {
it("decodes HTML entities in string values", () => {
const result = decodeHtmlEntitiesInObject(

View File

@ -1,6 +1,7 @@
import fs from "node:fs/promises";
import os from "node:os";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import type { AgentMessage, StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import {
createAgentSession,
DefaultResourceLoader,
@ -225,6 +226,35 @@ function summarizeProviderAuthKey(apiKey: string | undefined): string {
return `${trimmed.slice(0, 4)}${trimmed.slice(-4)}`;
}
export function resolveEmbeddedAgentStreamFn(params: {
currentStreamFn: StreamFn | undefined;
providerStreamFn?: StreamFn;
shouldUseWebSocketTransport: boolean;
wsApiKey?: string;
sessionId: string;
signal?: AbortSignal;
model: EmbeddedRunAttemptParams["model"];
}): StreamFn {
if (params.providerStreamFn) {
return params.providerStreamFn;
}
const currentStreamFn = params.currentStreamFn ?? streamSimple;
if (params.shouldUseWebSocketTransport) {
return params.wsApiKey
? createOpenAIWebSocketStreamFn(params.wsApiKey, params.sessionId, {
signal: params.signal,
})
: currentStreamFn;
}
if (params.model.provider === "anthropic-vertex") {
return createAnthropicVertexStreamFnForModel(params.model);
}
return currentStreamFn;
}
function summarizeMessagePayload(msg: AgentMessage): { textChars: number; imageBlocks: number } {
const content = (msg as { content?: unknown }).content;
if (typeof content === "string") {
@ -868,30 +898,27 @@ export async function runEmbeddedAttempt(
`[xai-auth] pre-stream setup: modelApi=${params.model.api} baseUrl=${params.model.baseUrl ?? "default"} runtimeAuthKey=${summarizeProviderAuthKey(runtimeApiKey)} headersAuth=${params.model.headers?.Authorization ? "present" : "absent"} responsesAuthPath=apiKey-argument`,
);
}
if (providerStreamFn) {
activeSession.agent.streamFn = providerStreamFn;
} else if (
shouldUseOpenAIWebSocketTransport({
provider: params.provider,
modelApi: params.model.api,
})
) {
const wsApiKey = await params.authStorage.getApiKey(params.provider);
if (wsApiKey) {
activeSession.agent.streamFn = createOpenAIWebSocketStreamFn(wsApiKey, params.sessionId, {
signal: runAbortController.signal,
});
} else {
log.warn(`[ws-stream] no API key for provider=${params.provider}; using HTTP transport`);
activeSession.agent.streamFn = defaultSessionStreamFn;
}
} else if (params.model.provider === "anthropic-vertex") {
// Anthropic Vertex AI: inject AnthropicVertex client into pi-ai's
// streamAnthropic for GCP IAM auth instead of Anthropic API keys.
activeSession.agent.streamFn = createAnthropicVertexStreamFnForModel(params.model);
} else {
activeSession.agent.streamFn = defaultSessionStreamFn;
const shouldUseWebSocketTransport = shouldUseOpenAIWebSocketTransport({
provider: params.provider,
modelApi: params.model.api,
});
const wsApiKey = shouldUseWebSocketTransport
? await params.authStorage.getApiKey(params.provider)
: undefined;
if (shouldUseWebSocketTransport && !wsApiKey) {
log.warn(
`[ws-stream] no API key for provider=${params.provider}; keeping session-managed HTTP transport`,
);
}
activeSession.agent.streamFn = resolveEmbeddedAgentStreamFn({
currentStreamFn: defaultSessionStreamFn,
providerStreamFn,
shouldUseWebSocketTransport,
wsApiKey,
sessionId: params.sessionId,
signal: runAbortController.signal,
model: params.model,
});
const { effectiveExtraParams } = applyExtraParamsToAgent(
activeSession.agent,