diff --git a/src/agents/openai-transport-stream.test.ts b/src/agents/openai-transport-stream.test.ts
new file mode 100644
index 00000000000..bf335445c01
--- /dev/null
+++ b/src/agents/openai-transport-stream.test.ts
@@ -0,0 +1,120 @@
+import type { Model } from "@mariozechner/pi-ai";
+import { describe, expect, it } from "vitest";
+import {
+  buildTransportAwareSimpleStreamFn,
+  isTransportAwareApiSupported,
+  parseTransportChunkUsage,
+  prepareTransportAwareSimpleModel,
+  resolveAzureOpenAIApiVersion,
+  resolveTransportAwareSimpleApi,
+  sanitizeTransportPayloadText,
+} from "./openai-transport-stream.js";
+import { attachModelProviderRequestTransport } from "./provider-request-config.js";
+
+describe("openai transport stream", () => {
+  it("reports the supported transport-aware APIs", () => {
+    expect(isTransportAwareApiSupported("openai-responses")).toBe(true);
+    expect(isTransportAwareApiSupported("openai-completions")).toBe(true);
+    expect(isTransportAwareApiSupported("azure-openai-responses")).toBe(true);
+    expect(isTransportAwareApiSupported("anthropic-messages")).toBe(false);
+  });
+
+  it("prepares a custom simple-completion api alias when transport overrides are attached", () => {
+    const model = attachModelProviderRequestTransport(
+      {
+        id: "gpt-5",
+        name: "GPT-5",
+        api: "openai-responses",
+        provider: "openai",
+        baseUrl: "https://api.openai.com/v1",
+        reasoning: true,
+        input: ["text"],
+        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
+        contextWindow: 200000,
+        maxTokens: 8192,
+      } satisfies Model<"openai-responses">,
+      {
+        proxy: {
+          mode: "explicit-proxy",
+          url: "http://proxy.internal:8443",
+        },
+      },
+    );
+
+    const prepared = prepareTransportAwareSimpleModel(model);
+
+    expect(resolveTransportAwareSimpleApi(model.api)).toBe("openclaw-openai-responses-transport");
+    expect(prepared).toMatchObject({
+      api: "openclaw-openai-responses-transport",
+      provider: "openai",
+      id: "gpt-5",
+    });
+    expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function");
+  });
+
+  it("removes unpaired surrogate code units but preserves valid surrogate pairs", () => {
+    const high = String.fromCharCode(0xd83d);
+    const low = String.fromCharCode(0xdc00);
+
+    expect(sanitizeTransportPayloadText(`left${high}right`)).toBe("leftright");
+    expect(sanitizeTransportPayloadText(`left${low}right`)).toBe("leftright");
+    expect(sanitizeTransportPayloadText("emoji 🙈 ok")).toBe("emoji 🙈 ok");
+  });
+
+  it("uses a valid Azure API version default when the environment is unset", () => {
+    expect(resolveAzureOpenAIApiVersion({})).toBe("2024-12-01-preview");
+    expect(resolveAzureOpenAIApiVersion({ AZURE_OPENAI_API_VERSION: "2025-01-01-preview" })).toBe(
+      "2025-01-01-preview",
+    );
+  });
+
+  it("does not double-count reasoning tokens and clamps uncached prompt usage at zero", () => {
+    const model = {
+      id: "gpt-5",
+      name: "GPT-5",
+      api: "openai-completions",
+      provider: "openai",
+      baseUrl: "https://api.openai.com/v1",
+      reasoning: true,
+      input: ["text"],
+      cost: { input: 1, output: 2, cacheRead: 0.5, cacheWrite: 0 },
+      contextWindow: 200000,
+      maxTokens: 8192,
+    } satisfies Model<"openai-completions">;
+
+    expect(
+      parseTransportChunkUsage(
+        {
+          prompt_tokens: 10,
+          completion_tokens: 20,
+          total_tokens: 30,
+          prompt_tokens_details: { cached_tokens: 3 },
+          completion_tokens_details: { reasoning_tokens: 7 },
+        },
+        model,
+      ),
+    ).toMatchObject({
+      input: 7,
+      output: 20,
+      cacheRead: 3,
+      totalTokens: 30,
+    });
+
+    expect(
+      parseTransportChunkUsage(
+        {
+          prompt_tokens: 2,
+
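+          // cached_tokens (4) exceeds prompt_tokens (2): uncached input clamps to 0 and
+          // totalTokens is recomputed as input + output + cacheRead = 0 + 5 + 4 = 9, so the
+          // provider-reported total_tokens of 7 is intentionally ignored.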
completion_tokens: 5, + total_tokens: 7, + prompt_tokens_details: { cached_tokens: 4 }, + }, + model, + ), + ).toMatchObject({ + input: 0, + output: 5, + cacheRead: 4, + totalTokens: 9, + }); + }); +}); diff --git a/src/agents/openai-transport-stream.ts b/src/agents/openai-transport-stream.ts new file mode 100644 index 00000000000..775c45e7329 --- /dev/null +++ b/src/agents/openai-transport-stream.ts @@ -0,0 +1,1518 @@ +import type { StreamFn } from "@mariozechner/pi-agent-core"; +import { + calculateCost, + createAssistantMessageEventStream, + getEnvApiKey, + parseStreamingJson, + type Api, + type Context, + type Model, +} from "@mariozechner/pi-ai"; +import { convertMessages } from "@mariozechner/pi-ai/openai-completions"; +import OpenAI, { AzureOpenAI } from "openai"; +import type { ChatCompletionChunk } from "openai/resources/chat/completions.js"; +import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js"; +import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js"; +import { + buildProviderRequestDispatcherPolicy, + getModelProviderRequestTransport, + resolveProviderRequestPolicyConfig, +} from "./provider-request-config.js"; + +const DEFAULT_AZURE_OPENAI_API_VERSION = "2024-12-01-preview"; + +const SUPPORTED_TRANSPORT_APIS = new Set([ + "openai-responses", + "openai-completions", + "azure-openai-responses", +]); + +const SIMPLE_TRANSPORT_API_ALIAS: Record = { + "openai-responses": "openclaw-openai-responses-transport", + "openai-completions": "openclaw-openai-completions-transport", + "azure-openai-responses": "openclaw-azure-openai-responses-transport", +}; + +type BaseStreamOptions = { + temperature?: number; + maxTokens?: number; + signal?: AbortSignal; + apiKey?: string; + cacheRetention?: "none" | "short" | "long"; + sessionId?: string; + onPayload?: (payload: unknown, model: Model) => unknown; + headers?: Record; +}; + +type OpenAIResponsesOptions = BaseStreamOptions & { + reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh"; + reasoningSummary?: "auto" | "detailed" | "concise" | null; + serviceTier?: ResponseCreateParamsStreaming["service_tier"]; +}; + +type OpenAICompletionsOptions = BaseStreamOptions & { + toolChoice?: + | "auto" + | "none" + | "required" + | { + type: "function"; + function: { + name: string; + }; + }; + reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh"; +}; + +type OpenAIModeModel = Model & { + compat?: Record; +}; + +type MutableAssistantOutput = { + role: "assistant"; + content: Array>; + api: Api; + provider: string; + model: string; + usage: { + input: number; + output: number; + cacheRead: number; + cacheWrite: number; + totalTokens: number; + cost: { input: number; output: number; cacheRead: number; cacheWrite: number; total: number }; + }; + stopReason: string; + timestamp: number; + responseId?: string; + errorMessage?: string; +}; + +export function sanitizeTransportPayloadText(text: string): string { + return text.replace( + /[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(? 
<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g,
+    "",
+  );
+}
+
+function hasCopilotVisionInput(messages: Context["messages"]): boolean {
+  return messages.some((message) =>
{ + if (message.role === "user" && Array.isArray(message.content)) { + return message.content.some((item) => item.type === "image"); + } + if (message.role === "toolResult" && Array.isArray(message.content)) { + return message.content.some((item) => item.type === "image"); + } + return false; + }); +} + +function buildCopilotDynamicHeaders(params: { + messages: Context["messages"]; + hasImages: boolean; +}): Record { + return { + "X-Initiator": inferCopilotInitiator(params.messages), + "Openai-Intent": "conversation-edits", + ...(params.hasImages ? { "Copilot-Vision-Request": "true" } : {}), + }; +} + +function transformMessages( + messages: Context["messages"], + model: Model, + normalizeToolCallId?: ( + id: string, + targetModel: Model, + source: { provider: string; api: Api; model: string }, + ) => string, +): Context["messages"] { + const toolCallIdMap = new Map(); + const transformed = messages.map((msg) => { + if (msg.role === "user") { + return msg; + } + if (msg.role === "toolResult") { + const normalizedId = toolCallIdMap.get(msg.toolCallId); + return normalizedId && normalizedId !== msg.toolCallId + ? { ...msg, toolCallId: normalizedId } + : msg; + } + if (msg.role !== "assistant") { + return msg; + } + const isSameModel = + msg.provider === model.provider && msg.api === model.api && msg.model === model.id; + const content: typeof msg.content = []; + for (const block of msg.content) { + if (block.type === "thinking") { + if (block.redacted) { + if (isSameModel) { + content.push(block); + } + continue; + } + if (isSameModel && block.thinkingSignature) { + content.push(block); + continue; + } + if (!block.thinking.trim()) { + continue; + } + content.push(isSameModel ? block : { type: "text", text: block.thinking }); + continue; + } + if (block.type === "text") { + content.push(isSameModel ? 
block : { type: "text", text: block.text }); + continue; + } + if (block.type !== "toolCall") { + content.push(block); + continue; + } + let normalizedToolCall = block; + if (!isSameModel && block.thoughtSignature) { + normalizedToolCall = { ...normalizedToolCall }; + delete normalizedToolCall.thoughtSignature; + } + if (!isSameModel && normalizeToolCallId) { + const normalizedId = normalizeToolCallId(block.id, model, msg); + if (normalizedId !== block.id) { + toolCallIdMap.set(block.id, normalizedId); + normalizedToolCall = { ...normalizedToolCall, id: normalizedId }; + } + } + content.push(normalizedToolCall); + } + return { ...msg, content }; + }); + + const result: Context["messages"] = []; + let pendingToolCalls: Array<{ id: string; name: string }> = []; + let existingToolResultIds = new Set(); + for (const msg of transformed) { + if (msg.role === "assistant") { + if (pendingToolCalls.length > 0) { + for (const toolCall of pendingToolCalls) { + if (!existingToolResultIds.has(toolCall.id)) { + result.push({ + role: "toolResult", + toolCallId: toolCall.id, + toolName: toolCall.name, + content: [{ type: "text", text: "No result provided" }], + isError: true, + timestamp: Date.now(), + }); + } + } + pendingToolCalls = []; + existingToolResultIds = new Set(); + } + if (msg.stopReason === "error" || msg.stopReason === "aborted") { + continue; + } + const toolCalls = msg.content.filter( + (block): block is Extract<(typeof msg.content)[number], { type: "toolCall" }> => + block.type === "toolCall", + ); + if (toolCalls.length > 0) { + pendingToolCalls = toolCalls.map((block) => ({ id: block.id, name: block.name })); + existingToolResultIds = new Set(); + } + result.push(msg); + continue; + } + if (msg.role === "toolResult") { + existingToolResultIds.add(msg.toolCallId); + result.push(msg); + continue; + } + if (pendingToolCalls.length > 0) { + for (const toolCall of pendingToolCalls) { + if (!existingToolResultIds.has(toolCall.id)) { + result.push({ + role: "toolResult", + toolCallId: toolCall.id, + toolName: toolCall.name, + content: [{ type: "text", text: "No result provided" }], + isError: true, + timestamp: Date.now(), + }); + } + } + pendingToolCalls = []; + existingToolResultIds = new Set(); + } + result.push(msg); + } + return result; +} + +function encodeTextSignatureV1(id: string, phase?: "commentary" | "final_answer"): string { + return JSON.stringify({ v: 1, id, ...(phase ? { phase } : {}) }); +} + +function parseTextSignature( + signature: string | undefined, +): { id: string; phase?: "commentary" | "final_answer" } | undefined { + if (!signature) { + return undefined; + } + if (signature.startsWith("{")) { + try { + const parsed = JSON.parse(signature) as { v?: unknown; id?: unknown; phase?: unknown }; + if (parsed.v === 1 && typeof parsed.id === "string") { + return parsed.phase === "commentary" || parsed.phase === "final_answer" + ? { id: parsed.id, phase: parsed.phase } + : { id: parsed.id }; + } + } catch { + // Keep legacy plain-string behavior below. + } + } + return { id: signature }; +} + +function convertResponsesMessages( + model: Model, + context: Context, + allowedToolCallProviders: Set, + options?: { includeSystemPrompt?: boolean }, +) { + const messages: unknown[] = []; + const normalizeIdPart = (part: string) => { + const sanitized = part.replace(/[^a-zA-Z0-9_-]/g, "_"); + const normalized = sanitized.length > 64 ? 
sanitized.slice(0, 64) : sanitized; + return normalized.replace(/_+$/, ""); + }; + const buildForeignResponsesItemId = (itemId: string) => { + const normalized = `fc_${shortHash(itemId)}`; + return normalized.length > 64 ? normalized.slice(0, 64) : normalized; + }; + const normalizeToolCallId = ( + id: string, + _targetModel: Model, + source: { provider: string; api: Api }, + ) => { + if (!allowedToolCallProviders.has(model.provider)) { + return normalizeIdPart(id); + } + if (!id.includes("|")) { + return normalizeIdPart(id); + } + const [callId, itemId] = id.split("|"); + const normalizedCallId = normalizeIdPart(callId); + const isForeignToolCall = source.provider !== model.provider || source.api !== model.api; + let normalizedItemId = isForeignToolCall + ? buildForeignResponsesItemId(itemId) + : normalizeIdPart(itemId); + if (!normalizedItemId.startsWith("fc_")) { + normalizedItemId = normalizeIdPart(`fc_${normalizedItemId}`); + } + return `${normalizedCallId}|${normalizedItemId}`; + }; + const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId); + const includeSystemPrompt = options?.includeSystemPrompt ?? true; + if (includeSystemPrompt && context.systemPrompt) { + messages.push({ + role: model.reasoning ? "developer" : "system", + content: sanitizeTransportPayloadText(context.systemPrompt), + }); + } + let msgIndex = 0; + for (const msg of transformedMessages) { + if (msg.role === "user") { + if (typeof msg.content === "string") { + messages.push({ + role: "user", + content: [{ type: "input_text", text: sanitizeTransportPayloadText(msg.content) }], + }); + } else { + const content = msg.content + .map((item) => + item.type === "text" + ? { type: "input_text", text: sanitizeTransportPayloadText(item.text) } + : { + type: "input_image", + detail: "auto", + image_url: `data:${item.mimeType};base64,${item.data}`, + }, + ) + .filter((item) => model.input.includes("image") || item.type !== "input_image"); + if (content.length > 0) { + messages.push({ role: "user", content }); + } + } + } else if (msg.role === "assistant") { + const output: unknown[] = []; + const isDifferentModel = + msg.model !== model.id && msg.provider === model.provider && msg.api === model.api; + for (const block of msg.content) { + if (block.type === "thinking") { + if (block.thinkingSignature) { + output.push(JSON.parse(block.thinkingSignature)); + } + } else if (block.type === "text") { + let msgId = parseTextSignature(block.textSignature)?.id ?? `msg_${msgIndex}`; + if (msgId.length > 64) { + msgId = `msg_${shortHash(msgId)}`; + } + output.push({ + type: "message", + role: "assistant", + content: [ + { + type: "output_text", + text: sanitizeTransportPayloadText(block.text), + annotations: [], + }, + ], + status: "completed", + id: msgId, + phase: parseTextSignature(block.textSignature)?.phase, + }); + } else if (block.type === "toolCall") { + const [callId, itemIdRaw] = block.id.split("|"); + const itemId = isDifferentModel && itemIdRaw?.startsWith("fc_") ? 
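+          // fc_ item ids minted under another model id (same provider/api) are dropped rather than replayed.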
undefined : itemIdRaw; + output.push({ + type: "function_call", + id: itemId, + call_id: callId, + name: block.name, + arguments: JSON.stringify(block.arguments), + }); + } + } + if (output.length > 0) { + messages.push(...output); + } + } else if (msg.role === "toolResult") { + const textResult = msg.content + .filter((item) => item.type === "text") + .map((item) => item.text) + .join("\n"); + const hasImages = msg.content.some((item) => item.type === "image"); + const [callId] = msg.toolCallId.split("|"); + messages.push({ + type: "function_call_output", + call_id: callId, + output: + hasImages && model.input.includes("image") + ? [ + ...(textResult + ? [{ type: "input_text", text: sanitizeTransportPayloadText(textResult) }] + : []), + ...msg.content + .filter((item) => item.type === "image") + .map((item) => ({ + type: "input_image", + detail: "auto", + image_url: `data:${item.mimeType};base64,${item.data}`, + })), + ] + : sanitizeTransportPayloadText(textResult || "(see attached image)"), + }); + } + msgIndex += 1; + } + return messages; +} + +function convertResponsesTools( + tools: NonNullable, + options?: { strict?: boolean | null }, +) { + const strict = options?.strict === undefined ? false : options.strict; + return tools.map((tool) => ({ + type: "function", + name: tool.name, + description: tool.description, + parameters: tool.parameters, + strict, + })); +} + +async function processResponsesStream( + openaiStream: AsyncIterable, + output: MutableAssistantOutput, + stream: { push(event: unknown): void }, + model: Model, + options?: { + serviceTier?: ResponseCreateParamsStreaming["service_tier"]; + applyServiceTierPricing?: ( + usage: MutableAssistantOutput["usage"], + serviceTier?: ResponseCreateParamsStreaming["service_tier"], + ) => void; + }, +) { + let currentItem: Record | null = null; + let currentBlock: Record | null = null; + const blockIndex = () => output.content.length - 1; + for await (const rawEvent of openaiStream) { + const event = rawEvent as Record; + const type = stringifyUnknown(event.type); + if (type === "response.created") { + output.responseId = stringifyUnknown((event.response as { id?: string } | undefined)?.id); + } else if (type === "response.output_item.added") { + const item = event.item as Record; + if (item.type === "reasoning") { + currentItem = item; + currentBlock = { type: "thinking", thinking: "" }; + output.content.push(currentBlock); + stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output }); + } else if (item.type === "message") { + currentItem = item; + currentBlock = { type: "text", text: "" }; + output.content.push(currentBlock); + stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output }); + } else if (item.type === "function_call") { + currentItem = item; + currentBlock = { + type: "toolCall", + id: `${stringifyUnknown(item.call_id)}|${stringifyUnknown(item.id)}`, + name: stringifyUnknown(item.name), + arguments: {}, + partialJson: stringifyJsonLike(item.arguments), + }; + output.content.push(currentBlock); + stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output }); + } + } else if (type === "response.reasoning_summary_text.delta") { + if (currentItem?.type === "reasoning" && currentBlock?.type === "thinking") { + currentBlock.thinking = `${stringifyUnknown(currentBlock.thinking)}${stringifyUnknown(event.delta)}`; + stream.push({ + type: "thinking_delta", + contentIndex: blockIndex(), + delta: stringifyUnknown(event.delta), + partial: output, + }); + } + } 
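+      // Text and refusal deltas below share the same text-block accumulation path.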
else if (type === "response.output_text.delta" || type === "response.refusal.delta") { + if (currentItem?.type === "message" && currentBlock?.type === "text") { + currentBlock.text = `${stringifyUnknown(currentBlock.text)}${stringifyUnknown(event.delta)}`; + stream.push({ + type: "text_delta", + contentIndex: blockIndex(), + delta: stringifyUnknown(event.delta), + partial: output, + }); + } + } else if (type === "response.function_call_arguments.delta") { + if (currentItem?.type === "function_call" && currentBlock?.type === "toolCall") { + currentBlock.partialJson = `${stringifyJsonLike(currentBlock.partialJson)}${stringifyJsonLike(event.delta)}`; + currentBlock.arguments = parseStreamingJson(stringifyJsonLike(currentBlock.partialJson)); + stream.push({ + type: "toolcall_delta", + contentIndex: blockIndex(), + delta: stringifyJsonLike(event.delta), + partial: output, + }); + } + } else if (type === "response.output_item.done") { + const item = event.item as Record; + if (item.type === "reasoning" && currentBlock?.type === "thinking") { + const summary = Array.isArray(item.summary) + ? item.summary.map((part) => String((part as { text?: string }).text ?? "")).join("\n\n") + : ""; + currentBlock.thinking = summary; + currentBlock.thinkingSignature = JSON.stringify(item); + stream.push({ + type: "thinking_end", + contentIndex: blockIndex(), + content: stringifyUnknown(currentBlock.thinking), + partial: output, + }); + currentBlock = null; + } else if (item.type === "message" && currentBlock?.type === "text") { + const content = Array.isArray(item.content) ? item.content : []; + currentBlock.text = content + .map((part) => + (part as { type?: string; text?: string; refusal?: string }).type === "output_text" + ? String((part as { text?: string }).text ?? "") + : String((part as { refusal?: string }).refusal ?? ""), + ) + .join(""); + currentBlock.textSignature = encodeTextSignatureV1( + stringifyUnknown(item.id), + (item.phase as "commentary" | "final_answer" | undefined) ?? undefined, + ); + stream.push({ + type: "text_end", + contentIndex: blockIndex(), + content: stringifyUnknown(currentBlock.text), + partial: output, + }); + currentBlock = null; + } else if (item.type === "function_call") { + const args = + currentBlock?.type === "toolCall" && currentBlock.partialJson + ? 
parseStreamingJson(stringifyJsonLike(currentBlock.partialJson, "{}")) + : parseStreamingJson(stringifyJsonLike(item.arguments, "{}")); + stream.push({ + type: "toolcall_end", + contentIndex: blockIndex(), + toolCall: { + type: "toolCall", + id: `${stringifyUnknown(item.call_id)}|${stringifyUnknown(item.id)}`, + name: stringifyUnknown(item.name), + arguments: args, + }, + partial: output, + }); + currentBlock = null; + } + } else if (type === "response.completed") { + const response = event.response as Record | undefined; + if (typeof response?.id === "string") { + output.responseId = response.id; + } + const usage = response?.usage as + | { + input_tokens?: number; + output_tokens?: number; + total_tokens?: number; + input_tokens_details?: { cached_tokens?: number }; + service_tier?: ResponseCreateParamsStreaming["service_tier"]; + status?: string; + } + | undefined; + if (usage) { + const cachedTokens = usage.input_tokens_details?.cached_tokens || 0; + output.usage = { + input: (usage.input_tokens || 0) - cachedTokens, + output: usage.output_tokens || 0, + cacheRead: cachedTokens, + cacheWrite: 0, + totalTokens: usage.total_tokens || 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }; + } + calculateCost(model as never, output.usage as never); + if (options?.applyServiceTierPricing) { + options.applyServiceTierPricing( + output.usage, + (response?.service_tier as ResponseCreateParamsStreaming["service_tier"] | undefined) ?? + options.serviceTier, + ); + } + output.stopReason = mapResponsesStopReason(response?.status as string | undefined); + if ( + output.content.some((block) => block.type === "toolCall") && + output.stopReason === "stop" + ) { + output.stopReason = "toolUse"; + } + } else if (type === "error") { + throw new Error( + `Error Code ${stringifyUnknown(event.code, "unknown")}: ${stringifyUnknown(event.message, "Unknown error")}`, + ); + } else if (type === "response.failed") { + const response = event.response as + | { + error?: { code?: string; message?: string }; + incomplete_details?: { reason?: string }; + } + | undefined; + const msg = response?.error + ? `${response.error.code || "unknown"}: ${response.error.message || "no message"}` + : response?.incomplete_details?.reason + ? 
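+          // Failure message precedence: explicit error object, then incomplete_details.reason, then a generic fallback.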
`incomplete: ${response.incomplete_details.reason}` + : "Unknown error (no error details in response)"; + throw new Error(msg); + } + } +} + +function mapResponsesStopReason(status: string | undefined): string { + if (!status) { + return "stop"; + } + switch (status) { + case "completed": + return "stop"; + case "incomplete": + return "length"; + case "failed": + case "cancelled": + return "error"; + case "in_progress": + case "queued": + return "stop"; + default: + throw new Error(`Unhandled stop reason: ${status}`); + } +} + +function hasTransportOverrides(model: Model): boolean { + const request = getModelProviderRequestTransport(model); + return Boolean(request?.proxy || request?.tls); +} + +export function isTransportAwareApiSupported(api: Api): boolean { + return SUPPORTED_TRANSPORT_APIS.has(api); +} + +export function resolveTransportAwareSimpleApi(api: Api): Api | undefined { + return SIMPLE_TRANSPORT_API_ALIAS[api]; +} + +export function createTransportAwareStreamFnForModel(model: Model): StreamFn | undefined { + if (!hasTransportOverrides(model)) { + return undefined; + } + if (!isTransportAwareApiSupported(model.api)) { + throw new Error( + `Model-provider request.proxy/request.tls is not yet supported for api "${model.api}"`, + ); + } + switch (model.api) { + case "openai-responses": + return createOpenAIResponsesTransportStreamFn(); + case "openai-completions": + return createOpenAICompletionsTransportStreamFn(); + case "azure-openai-responses": + return createAzureOpenAIResponsesTransportStreamFn(); + default: + return undefined; + } +} + +function resolveModelRequestPolicy(model: Model) { + return resolveProviderRequestPolicyConfig({ + provider: model.provider, + api: model.api, + baseUrl: model.baseUrl, + capability: "llm", + transport: "stream", + request: getModelProviderRequestTransport(model), + }); +} + +function buildManagedResponse(response: Response, release: () => Promise): Response { + if (!response.body) { + void release(); + return response; + } + const source = response.body; + let reader: ReadableStreamDefaultReader | undefined; + let released = false; + const finalize = async () => { + if (released) { + return; + } + released = true; + await release().catch(() => undefined); + }; + const wrappedBody = new ReadableStream({ + start() { + reader = source.getReader(); + }, + async pull(controller) { + try { + const chunk = await reader?.read(); + if (!chunk || chunk.done) { + controller.close(); + await finalize(); + return; + } + controller.enqueue(chunk.value); + } catch (error) { + controller.error(error); + await finalize(); + } + }, + async cancel(reason) { + try { + await reader?.cancel(reason); + } finally { + await finalize(); + } + }, + }); + return new Response(wrappedBody, { + status: response.status, + statusText: response.statusText, + headers: response.headers, + }); +} + +function buildGuardedModelFetch(model: Model): typeof fetch { + const requestConfig = resolveModelRequestPolicy(model); + const dispatcherPolicy = buildProviderRequestDispatcherPolicy(requestConfig); + return async (input, init) => { + const request = input instanceof Request ? new Request(input, init) : undefined; + const url = + request?.url ?? + (input instanceof URL + ? input.toString() + : typeof input === "string" + ? input + : (() => { + throw new Error("Unsupported fetch input for transport-aware model request"); + })()); + const requestInit = + request && + ({ + method: request.method, + headers: request.headers, + body: request.body ?? 
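+            // Node fetch needs duplex: "half" for streaming request bodies; the spread below adds it only when a body exists.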
undefined, + redirect: request.redirect, + signal: request.signal, + ...(request.body ? ({ duplex: "half" } as const) : {}), + } satisfies RequestInit & { duplex?: "half" }); + const result = await fetchWithSsrFGuard({ + url, + init: requestInit ?? init, + dispatcherPolicy, + ...(requestConfig.allowPrivateNetwork ? { policy: { allowPrivateNetwork: true } } : {}), + }); + return buildManagedResponse(result.response, result.release); + }; +} + +function buildOpenAIClientHeaders( + model: Model, + context: Context, + optionHeaders?: Record, +): Record { + const headers = { ...model.headers }; + if (model.provider === "github-copilot") { + Object.assign( + headers, + buildCopilotDynamicHeaders({ + messages: context.messages, + hasImages: hasCopilotVisionInput(context.messages), + }), + ); + } + if (optionHeaders) { + Object.assign(headers, optionHeaders); + } + return headers; +} + +function createOpenAIResponsesClient( + model: Model, + context: Context, + apiKey: string, + optionHeaders?: Record, +) { + return new OpenAI({ + apiKey, + baseURL: model.baseUrl, + dangerouslyAllowBrowser: true, + defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders), + fetch: buildGuardedModelFetch(model), + }); +} + +function createOpenAIResponsesTransportStreamFn(): StreamFn { + return (model, context, options) => { + const eventStream = createAssistantMessageEventStream(); + const stream = eventStream as unknown as { push(event: unknown): void; end(): void }; + void (async () => { + const output: MutableAssistantOutput = { + role: "assistant" as const, + content: [], + api: model.api, + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }; + try { + const apiKey = options?.apiKey || getEnvApiKey(model.provider) || ""; + const client = createOpenAIResponsesClient(model, context, apiKey, options?.headers); + let params = buildOpenAIResponsesParams(model, context, options as OpenAIResponsesOptions); + const nextParams = await options?.onPayload?.(params, model); + if (nextParams !== undefined) { + params = nextParams as typeof params; + } + const responseStream = (await client.responses.create( + params as never, + options?.signal ? { signal: options.signal } : undefined, + )) as unknown as AsyncIterable; + stream.push({ type: "start", partial: output as never }); + await processResponsesStream(responseStream, output, stream, model, { + serviceTier: (options as OpenAIResponsesOptions | undefined)?.serviceTier, + applyServiceTierPricing, + }); + if (options?.signal?.aborted) { + throw new Error("Request was aborted"); + } + if (output.stopReason === "aborted" || output.stopReason === "error") { + throw new Error("An unknown error occurred"); + } + stream.push({ type: "done", reason: output.stopReason as never, message: output as never }); + stream.end(); + } catch (error) { + output.stopReason = options?.signal?.aborted ? "aborted" : "error"; + output.errorMessage = error instanceof Error ? 
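+          // Non-Error throw values are JSON-stringified so errorMessage stays meaningful.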
error.message : JSON.stringify(error); + stream.push({ type: "error", reason: output.stopReason as never, error: output as never }); + stream.end(); + } + })(); + return eventStream as unknown as ReturnType; + }; +} + +function resolveCacheRetention(cacheRetention: string | undefined): "short" | "long" | "none" { + if (cacheRetention === "short" || cacheRetention === "long" || cacheRetention === "none") { + return cacheRetention; + } + if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") { + return "long"; + } + return "short"; +} + +function getPromptCacheRetention( + baseUrl: string | undefined, + cacheRetention: "short" | "long" | "none", +) { + if (cacheRetention !== "long") { + return undefined; + } + return baseUrl?.includes("api.openai.com") ? "24h" : undefined; +} + +function buildOpenAIResponsesParams( + model: Model, + context: Context, + options: OpenAIResponsesOptions | undefined, +) { + const messages = convertResponsesMessages( + model, + context, + new Set(["openai", "openai-codex", "opencode", "azure-openai-responses"]), + ); + const cacheRetention = resolveCacheRetention(options?.cacheRetention); + const params: Record = { + model: model.id, + input: messages, + stream: true, + prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId, + prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention), + store: false, + }; + if (options?.maxTokens) { + params.max_output_tokens = options.maxTokens; + } + if (options?.temperature !== undefined) { + params.temperature = options.temperature; + } + if (options?.serviceTier !== undefined) { + params.service_tier = options.serviceTier; + } + if (context.tools) { + params.tools = convertResponsesTools(context.tools); + } + if (model.reasoning) { + if (options?.reasoningEffort || options?.reasoningSummary) { + params.reasoning = { + effort: options?.reasoningEffort || "medium", + summary: options?.reasoningSummary || "auto", + }; + params.include = ["reasoning.encrypted_content"]; + } else if (model.provider !== "github-copilot") { + params.reasoning = { effort: "none" }; + } + } + return params; +} + +function createAzureOpenAIResponsesTransportStreamFn(): StreamFn { + return (model, context, options) => { + const eventStream = createAssistantMessageEventStream(); + const stream = eventStream as unknown as { push(event: unknown): void; end(): void }; + void (async () => { + const output: MutableAssistantOutput = { + role: "assistant" as const, + content: [], + api: "azure-openai-responses", + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }; + try { + const apiKey = options?.apiKey || getEnvApiKey(model.provider) || ""; + const client = createAzureOpenAIClient(model, context, apiKey, options?.headers); + const deploymentName = resolveAzureDeploymentName(model); + let params = buildAzureOpenAIResponsesParams( + model, + context, + options as OpenAIResponsesOptions | undefined, + deploymentName, + ); + const nextParams = await options?.onPayload?.(params, model); + if (nextParams !== undefined) { + params = nextParams as typeof params; + } + const responseStream = (await client.responses.create( + params as never, + options?.signal ? 
{ signal: options.signal } : undefined, + )) as unknown as AsyncIterable; + stream.push({ type: "start", partial: output as never }); + await processResponsesStream(responseStream, output, stream, model); + if (options?.signal?.aborted) { + throw new Error("Request was aborted"); + } + if (output.stopReason === "aborted" || output.stopReason === "error") { + throw new Error("An unknown error occurred"); + } + stream.push({ type: "done", reason: output.stopReason as never, message: output as never }); + stream.end(); + } catch (error) { + output.stopReason = options?.signal?.aborted ? "aborted" : "error"; + output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error); + stream.push({ type: "error", reason: output.stopReason as never, error: output as never }); + stream.end(); + } + })(); + return eventStream as unknown as ReturnType; + }; +} + +function normalizeAzureBaseUrl(baseUrl: string): string { + return baseUrl.replace(/\/+$/, ""); +} + +function resolveAzureDeploymentName(model: Model): string { + const deploymentMap = process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP; + if (deploymentMap) { + for (const entry of deploymentMap.split(",")) { + const [modelId, deploymentName] = entry.split("=", 2).map((value) => value?.trim()); + if (modelId === model.id && deploymentName) { + return deploymentName; + } + } + } + return model.id; +} + +function createAzureOpenAIClient( + model: Model, + context: Context, + apiKey: string, + optionHeaders?: Record, +) { + return new AzureOpenAI({ + apiKey, + apiVersion: resolveAzureOpenAIApiVersion(), + dangerouslyAllowBrowser: true, + defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders), + baseURL: normalizeAzureBaseUrl(model.baseUrl), + fetch: buildGuardedModelFetch(model), + }); +} + +function buildAzureOpenAIResponsesParams( + model: Model, + context: Context, + options: OpenAIResponsesOptions | undefined, + deploymentName: string, +) { + const params = buildOpenAIResponsesParams(model, context, options); + params.model = deploymentName; + delete params.store; + return params; +} + +function hasToolHistory(messages: Context["messages"]): boolean { + return messages.some( + (message) => + message.role === "toolResult" || + (message.role === "assistant" && message.content.some((block) => block.type === "toolCall")), + ); +} + +function createOpenAICompletionsClient( + model: Model, + context: Context, + apiKey: string, + optionHeaders?: Record, +) { + return new OpenAI({ + apiKey, + baseURL: model.baseUrl, + dangerouslyAllowBrowser: true, + defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders), + fetch: buildGuardedModelFetch(model), + }); +} + +function createOpenAICompletionsTransportStreamFn(): StreamFn { + return (model, context, options) => { + const eventStream = createAssistantMessageEventStream(); + const stream = eventStream as unknown as { push(event: unknown): void; end(): void }; + void (async () => { + const output: MutableAssistantOutput = { + role: "assistant" as const, + content: [], + api: model.api, + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }; + try { + const apiKey = options?.apiKey || getEnvApiKey(model.provider) || ""; + const client = createOpenAICompletionsClient(model, context, apiKey, options?.headers); + let params = buildOpenAICompletionsParams( + model 
as OpenAIModeModel, + context, + options as OpenAICompletionsOptions | undefined, + ); + const nextParams = await options?.onPayload?.(params, model); + if (nextParams !== undefined) { + params = nextParams as typeof params; + } + const responseStream = (await client.chat.completions.create(params as never, { + signal: options?.signal, + })) as unknown as AsyncIterable; + stream.push({ type: "start", partial: output as never }); + await processOpenAICompletionsStream(responseStream, output, model, stream); + if (options?.signal?.aborted) { + throw new Error("Request was aborted"); + } + stream.push({ type: "done", reason: output.stopReason as never, message: output as never }); + stream.end(); + } catch (error) { + output.stopReason = options?.signal?.aborted ? "aborted" : "error"; + output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error); + stream.push({ type: "error", reason: output.stopReason as never, error: output as never }); + stream.end(); + } + })(); + return eventStream as unknown as ReturnType; + }; +} + +async function processOpenAICompletionsStream( + responseStream: AsyncIterable, + output: MutableAssistantOutput, + model: Model, + stream: { push(event: unknown): void }, +) { + let currentBlock: + | { type: "text"; text: string } + | { type: "thinking"; thinking: string; thinkingSignature?: string } + | { + type: "toolCall"; + id: string; + name: string; + arguments: Record; + partialArgs: string; + } + | null = null; + const blockIndex = () => output.content.length - 1; + const finishCurrentBlock = () => { + if (!currentBlock) { + return; + } + if (currentBlock.type === "toolCall") { + currentBlock.arguments = parseStreamingJson(currentBlock.partialArgs); + const completed = { + ...currentBlock, + arguments: parseStreamingJson(currentBlock.partialArgs), + }; + output.content[blockIndex()] = completed; + } + }; + for await (const chunk of responseStream) { + output.responseId ||= chunk.id; + if (chunk.usage) { + output.usage = parseTransportChunkUsage(chunk.usage, model); + } + const choice = Array.isArray(chunk.choices) ? 
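+      // Providers may emit a usage-only final chunk with no choices; guard before indexing into it.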
chunk.choices[0] : undefined; + if (!choice) { + continue; + } + const choiceUsage = (choice as unknown as { usage?: ChatCompletionChunk["usage"] }).usage; + if (!chunk.usage && choiceUsage) { + output.usage = parseTransportChunkUsage(choiceUsage, model); + } + if (choice.finish_reason) { + const finishReasonResult = mapStopReason(choice.finish_reason); + output.stopReason = finishReasonResult.stopReason; + if (finishReasonResult.errorMessage) { + output.errorMessage = finishReasonResult.errorMessage; + } + } + if (!choice.delta) { + continue; + } + if (choice.delta.content) { + if (!currentBlock || currentBlock.type !== "text") { + finishCurrentBlock(); + currentBlock = { type: "text", text: "" }; + output.content.push(currentBlock); + stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output }); + } + currentBlock.text += choice.delta.content; + stream.push({ + type: "text_delta", + contentIndex: blockIndex(), + delta: choice.delta.content, + partial: output, + }); + continue; + } + const reasoningFields = ["reasoning_content", "reasoning", "reasoning_text"] as const; + const reasoningField = reasoningFields.find((field) => { + const value = (choice.delta as Record)[field]; + return typeof value === "string" && value.length > 0; + }); + if (reasoningField) { + if (!currentBlock || currentBlock.type !== "thinking") { + finishCurrentBlock(); + currentBlock = { type: "thinking", thinking: "", thinkingSignature: reasoningField }; + output.content.push(currentBlock); + stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output }); + } + currentBlock.thinking += String((choice.delta as Record)[reasoningField]); + stream.push({ + type: "thinking_delta", + contentIndex: blockIndex(), + delta: String((choice.delta as Record)[reasoningField]), + partial: output, + }); + continue; + } + if (choice.delta.tool_calls) { + for (const toolCall of choice.delta.tool_calls) { + if ( + !currentBlock || + currentBlock.type !== "toolCall" || + (toolCall.id && currentBlock.id !== toolCall.id) + ) { + finishCurrentBlock(); + currentBlock = { + type: "toolCall", + id: toolCall.id || "", + name: toolCall.function?.name || "", + arguments: {}, + partialArgs: "", + }; + output.content.push(currentBlock); + stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output }); + } + if (currentBlock.type !== "toolCall") { + continue; + } + if (toolCall.id) { + currentBlock.id = toolCall.id; + } + if (toolCall.function?.name) { + currentBlock.name = toolCall.function.name; + } + if (toolCall.function?.arguments) { + currentBlock.partialArgs += toolCall.function.arguments; + currentBlock.arguments = parseStreamingJson(currentBlock.partialArgs); + stream.push({ + type: "toolcall_delta", + contentIndex: blockIndex(), + delta: toolCall.function.arguments, + partial: output, + }); + } + } + } + } + finishCurrentBlock(); +} + +function detectCompat(model: OpenAIModeModel) { + const provider = model.provider; + const baseUrl = model.baseUrl ?? 
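+  // Compat quirks are keyed off provider ids and baseUrl substrings; a missing baseUrl falls back to "".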
""; + const isZai = provider === "zai" || baseUrl.includes("api.z.ai"); + const isNonStandard = + provider === "cerebras" || + baseUrl.includes("cerebras.ai") || + provider === "xai" || + baseUrl.includes("api.x.ai") || + baseUrl.includes("chutes.ai") || + baseUrl.includes("deepseek.com") || + isZai || + provider === "opencode" || + baseUrl.includes("opencode.ai"); + const useMaxTokens = baseUrl.includes("chutes.ai"); + const isGrok = provider === "xai" || baseUrl.includes("api.x.ai"); + const isGroq = provider === "groq" || baseUrl.includes("groq.com"); + const reasoningEffortMap: Record = + isGroq && model.id === "qwen/qwen3-32b" + ? { + minimal: "default", + low: "default", + medium: "default", + high: "default", + xhigh: "default", + } + : {}; + return { + supportsStore: !isNonStandard, + supportsDeveloperRole: !isNonStandard, + supportsReasoningEffort: !isGrok && !isZai, + reasoningEffortMap, + supportsUsageInStreaming: true, + maxTokensField: useMaxTokens ? "max_tokens" : "max_completion_tokens", + requiresToolResultName: false, + requiresAssistantAfterToolResult: false, + requiresThinkingAsText: false, + thinkingFormat: isZai + ? "zai" + : provider === "openrouter" || baseUrl.includes("openrouter.ai") + ? "openrouter" + : "openai", + openRouterRouting: {}, + vercelGatewayRouting: {}, + supportsStrictMode: true, + }; +} + +function getCompat(model: OpenAIModeModel) { + const detected = detectCompat(model); + const compat = model.compat ?? {}; + return { + supportsStore: compat.supportsStore ?? detected.supportsStore, + supportsDeveloperRole: compat.supportsDeveloperRole ?? detected.supportsDeveloperRole, + supportsReasoningEffort: compat.supportsReasoningEffort ?? detected.supportsReasoningEffort, + reasoningEffortMap: + (compat.reasoningEffortMap as Record | undefined) ?? + detected.reasoningEffortMap, + supportsUsageInStreaming: + (compat.supportsUsageInStreaming as boolean | undefined) ?? detected.supportsUsageInStreaming, + maxTokensField: (compat.maxTokensField as string | undefined) ?? detected.maxTokensField, + requiresToolResultName: + (compat.requiresToolResultName as boolean | undefined) ?? detected.requiresToolResultName, + requiresAssistantAfterToolResult: + (compat.requiresAssistantAfterToolResult as boolean | undefined) ?? + detected.requiresAssistantAfterToolResult, + requiresThinkingAsText: + (compat.requiresThinkingAsText as boolean | undefined) ?? detected.requiresThinkingAsText, + thinkingFormat: (compat.thinkingFormat as string | undefined) ?? detected.thinkingFormat, + openRouterRouting: (compat.openRouterRouting as Record | undefined) ?? {}, + vercelGatewayRouting: + (compat.vercelGatewayRouting as Record | undefined) ?? + detected.vercelGatewayRouting, + supportsStrictMode: + (compat.supportsStrictMode as boolean | undefined) ?? detected.supportsStrictMode, + }; +} + +function mapReasoningEffort(effort: string, reasoningEffortMap: Record): string { + return reasoningEffortMap[effort] ?? effort; +} + +function convertTools(tools: NonNullable, compat: ReturnType) { + return tools.map((tool) => ({ + type: "function", + function: { + name: tool.name, + description: tool.description, + parameters: tool.parameters, + ...(compat.supportsStrictMode ? 
{ strict: false } : {}), + }, + })); +} + +function buildOpenAICompletionsParams( + model: OpenAIModeModel, + context: Context, + options: OpenAICompletionsOptions | undefined, +) { + const compat = getCompat(model); + const params: Record = { + model: model.id, + messages: convertMessages(model as never, context, compat as never), + stream: true, + }; + if (compat.supportsUsageInStreaming) { + params.stream_options = { include_usage: true }; + } + if (compat.supportsStore) { + params.store = false; + } + if (options?.maxTokens) { + if (compat.maxTokensField === "max_tokens") { + params.max_tokens = options.maxTokens; + } else { + params.max_completion_tokens = options.maxTokens; + } + } + if (options?.temperature !== undefined) { + params.temperature = options.temperature; + } + if (context.tools) { + params.tools = convertTools(context.tools, compat); + } else if (hasToolHistory(context.messages)) { + params.tools = []; + } + if (options?.toolChoice) { + params.tool_choice = options.toolChoice; + } + if (compat.thinkingFormat === "openrouter" && model.reasoning && options?.reasoningEffort) { + params.reasoning = { + effort: mapReasoningEffort(options.reasoningEffort, compat.reasoningEffortMap), + }; + } else if (options?.reasoningEffort && model.reasoning && compat.supportsReasoningEffort) { + params.reasoning_effort = mapReasoningEffort( + options.reasoningEffort, + compat.reasoningEffortMap, + ); + } + return params; +} + +export function parseTransportChunkUsage( + rawUsage: NonNullable, + model: Model, +) { + const cachedTokens = rawUsage.prompt_tokens_details?.cached_tokens || 0; + const promptTokens = rawUsage.prompt_tokens || 0; + const input = Math.max(0, promptTokens - cachedTokens); + const outputTokens = rawUsage.completion_tokens || 0; + const usage = { + input, + output: outputTokens, + cacheRead: cachedTokens, + cacheWrite: 0, + totalTokens: input + outputTokens + cachedTokens, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }; + calculateCost(model as never, usage as never); + return usage; +} + +function mapStopReason(reason: string | null) { + if (reason === null) { + return { stopReason: "stop" }; + } + switch (reason) { + case "stop": + case "end": + return { stopReason: "stop" }; + case "length": + return { stopReason: "length" }; + case "function_call": + case "tool_calls": + return { stopReason: "toolUse" }; + case "content_filter": + return { stopReason: "error", errorMessage: "Provider finish_reason: content_filter" }; + case "network_error": + return { stopReason: "error", errorMessage: "Provider finish_reason: network_error" }; + default: + return { + stopReason: "error", + errorMessage: `Provider finish_reason: ${reason}`, + }; + } +} + +export function prepareTransportAwareSimpleModel(model: Model): Model { + const streamFn = createTransportAwareStreamFnForModel(model as Model); + const alias = resolveTransportAwareSimpleApi(model.api); + if (!streamFn || !alias) { + return model; + } + return { + ...model, + api: alias, + }; +} + +export function buildTransportAwareSimpleStreamFn(model: Model): StreamFn | undefined { + return createTransportAwareStreamFnForModel(model); +} diff --git a/src/agents/pi-embedded-runner/model.test.ts b/src/agents/pi-embedded-runner/model.test.ts index 1e8009bd408..7809f418b08 100644 --- a/src/agents/pi-embedded-runner/model.test.ts +++ b/src/agents/pi-embedded-runner/model.test.ts @@ -232,22 +232,27 @@ describe("buildInlineProviderModels", () => { expect(result[0].headers).toEqual({ "X-Tenant": "acme" }); 
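+  // Note: the transport override exercised below is stored by attachModelProviderRequestTransport under a
+  // symbol key, so it survives object spreads without appearing in header or snapshot comparisons.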
}); - it("rejects inline provider transport overrides that the llm model path cannot carry", () => { - expect(() => - buildInlineProviderModels({ - proxy: { - baseUrl: "https://proxy.example.com/v1", - api: "openai-completions", - request: { - proxy: { - mode: "explicit-proxy", - url: "http://proxy.internal:8443", - }, + it("keeps inline provider transport overrides once the llm transport adapter is available", () => { + const result = buildInlineProviderModels({ + proxy: { + baseUrl: "https://proxy.example.com/v1", + api: "openai-completions", + request: { + proxy: { + mode: "explicit-proxy", + url: "http://proxy.internal:8443", }, - models: [makeModel("proxy-model")], }, - } as unknown as Parameters[0]), - ).toThrow(/models\.providers\.\*\.request only supports headers and auth overrides/i); + models: [makeModel("proxy-model")], + }, + } as unknown as Parameters[0]); + + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + provider: "proxy", + api: "openai-completions", + baseUrl: "https://proxy.example.com/v1", + }); }); it("omits headers when neither provider nor model specifies them", () => { diff --git a/src/agents/pi-embedded-runner/model.ts b/src/agents/pi-embedded-runner/model.ts index e55bc4ef428..60b71e8e82f 100644 --- a/src/agents/pi-embedded-runner/model.ts +++ b/src/agents/pi-embedded-runner/model.ts @@ -23,6 +23,7 @@ import { } from "../model-suppression.js"; import { discoverAuthStorage, discoverModels } from "../pi-model-discovery.js"; import { + attachModelProviderRequestTransport, resolveProviderRequestConfig, sanitizeConfiguredModelProviderRequest, } from "../provider-request-config.js"; @@ -355,18 +356,21 @@ function applyConfiguredProviderOverrides(params: { capability: "llm", transport: "stream", }); - return { - ...discoveredModel, - api: requestConfig.api ?? "openai-responses", - baseUrl: requestConfig.baseUrl ?? discoveredModel.baseUrl, - reasoning: configuredModel?.reasoning ?? discoveredModel.reasoning, - input: normalizedInput, - cost: configuredModel?.cost ?? discoveredModel.cost, - contextWindow: configuredModel?.contextWindow ?? discoveredModel.contextWindow, - maxTokens: configuredModel?.maxTokens ?? discoveredModel.maxTokens, - headers: requestConfig.headers, - compat: configuredModel?.compat ?? discoveredModel.compat, - }; + return attachModelProviderRequestTransport( + { + ...discoveredModel, + api: requestConfig.api ?? "openai-responses", + baseUrl: requestConfig.baseUrl ?? discoveredModel.baseUrl, + reasoning: configuredModel?.reasoning ?? discoveredModel.reasoning, + input: normalizedInput, + cost: configuredModel?.cost ?? discoveredModel.cost, + contextWindow: configuredModel?.contextWindow ?? discoveredModel.contextWindow, + maxTokens: configuredModel?.maxTokens ?? discoveredModel.maxTokens, + headers: requestConfig.headers, + compat: configuredModel?.compat ?? discoveredModel.compat, + }, + providerRequest, + ); } export function buildInlineProviderModels( @@ -401,13 +405,16 @@ export function buildInlineProviderModels( capability: "llm", transport: "stream", }); - return { - ...model, - provider: trimmed, - baseUrl: requestConfig.baseUrl, - api: requestConfig.api ?? model.api, - headers: requestConfig.headers, - }; + return attachModelProviderRequestTransport( + { + ...model, + provider: trimmed, + baseUrl: requestConfig.baseUrl ?? transport.baseUrl, + api: requestConfig.api ?? 
model.api, + headers: requestConfig.headers, + }, + providerRequest, + ); }); }); } @@ -571,25 +578,28 @@ function resolveConfiguredFallbackModel(params: { provider, cfg, agentDir, - model: { - id: modelId, - name: modelId, - api: requestConfig.api ?? "openai-responses", - provider, - baseUrl: requestConfig.baseUrl, - reasoning: configuredModel?.reasoning ?? false, - input: ["text"], - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, - contextWindow: - configuredModel?.contextWindow ?? - providerConfig?.models?.[0]?.contextWindow ?? - DEFAULT_CONTEXT_TOKENS, - maxTokens: - configuredModel?.maxTokens ?? - providerConfig?.models?.[0]?.maxTokens ?? - DEFAULT_CONTEXT_TOKENS, - headers: requestConfig.headers, - } as Model, + model: attachModelProviderRequestTransport( + { + id: modelId, + name: modelId, + api: requestConfig.api ?? "openai-responses", + provider, + baseUrl: requestConfig.baseUrl, + reasoning: configuredModel?.reasoning ?? false, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: + configuredModel?.contextWindow ?? + providerConfig?.models?.[0]?.contextWindow ?? + DEFAULT_CONTEXT_TOKENS, + maxTokens: + configuredModel?.maxTokens ?? + providerConfig?.models?.[0]?.maxTokens ?? + DEFAULT_CONTEXT_TOKENS, + headers: requestConfig.headers, + } as Model, + providerRequest, + ), runtimeHooks, }); } diff --git a/src/agents/provider-request-config.test.ts b/src/agents/provider-request-config.test.ts index 92d5fe883f0..df58847d189 100644 --- a/src/agents/provider-request-config.test.ts +++ b/src/agents/provider-request-config.test.ts @@ -310,8 +310,8 @@ describe("provider request config", () => { ).toThrow(/request\.(headers\.X-Tenant|auth\.token|tls\.cert): unresolved SecretRef/i); }); - it("rejects model-provider transport overrides that the llm path cannot carry", () => { - expect(() => + it("keeps model-provider transport overrides once the llm path can carry them", () => { + expect( sanitizeConfiguredModelProviderRequest({ headers: { "X-Tenant": "acme", @@ -321,7 +321,15 @@ describe("provider request config", () => { url: "http://proxy.internal:8443", }, }), - ).toThrow(/models\.providers\.\*\.request only supports headers and auth overrides/i); + ).toEqual({ + headers: { + "X-Tenant": "acme", + }, + proxy: { + mode: "explicit-proxy", + url: "http://proxy.internal:8443", + }, + }); }); it("merges configured request overrides with later entries winning", () => { diff --git a/src/agents/provider-request-config.ts b/src/agents/provider-request-config.ts index d008f0bc61f..b41a3c13231 100644 --- a/src/agents/provider-request-config.ts +++ b/src/agents/provider-request-config.ts @@ -300,23 +300,10 @@ export function sanitizeConfiguredProviderRequest( }; } -const MODEL_PROVIDER_REQUEST_TRANSPORT_MESSAGE = - "models.providers.*.request only supports headers and auth overrides; proxy and TLS transport settings are not wired for model-provider requests"; - export function sanitizeConfiguredModelProviderRequest( request: ConfiguredModelProviderRequest | ConfiguredProviderRequest | undefined, ): ProviderRequestTransportOverrides | undefined { - const sanitized = sanitizeConfiguredProviderRequest(request); - if (!sanitized) { - return undefined; - } - if (sanitized.proxy || sanitized.tls) { - throw new Error(MODEL_PROVIDER_REQUEST_TRANSPORT_MESSAGE); - } - return { - ...(sanitized.headers ? { headers: sanitized.headers } : {}), - ...(sanitized.auth ? 
{ auth: sanitized.auth } : {}), - }; + return sanitizeConfiguredProviderRequest(request); } export function mergeProviderRequestOverrides( @@ -700,3 +687,29 @@ export function resolveProviderRequestHeaders(params: { request: params.request, }).headers; } + +const MODEL_PROVIDER_REQUEST_TRANSPORT_SYMBOL = Symbol.for( + "openclaw.modelProviderRequestTransport", +); + +type ModelWithProviderRequestTransport = { + [MODEL_PROVIDER_REQUEST_TRANSPORT_SYMBOL]?: ProviderRequestTransportOverrides; +}; + +export function attachModelProviderRequestTransport( + model: TModel, + request: ProviderRequestTransportOverrides | undefined, +): TModel { + if (!request) { + return model; + } + const next = { ...model } as TModel & ModelWithProviderRequestTransport; + next[MODEL_PROVIDER_REQUEST_TRANSPORT_SYMBOL] = request; + return next; +} + +export function getModelProviderRequestTransport( + model: object, +): ProviderRequestTransportOverrides | undefined { + return (model as ModelWithProviderRequestTransport)[MODEL_PROVIDER_REQUEST_TRANSPORT_SYMBOL]; +} diff --git a/src/agents/provider-stream.ts b/src/agents/provider-stream.ts index 9eab449c8d5..c8d3edb4557 100644 --- a/src/agents/provider-stream.ts +++ b/src/agents/provider-stream.ts @@ -3,6 +3,7 @@ import type { Api, Model } from "@mariozechner/pi-ai"; import type { OpenClawConfig } from "../config/config.js"; import { resolveProviderStreamFn } from "../plugins/provider-runtime.js"; import { ensureCustomApiRegistered } from "./custom-api-registry.js"; +import { createTransportAwareStreamFnForModel } from "./openai-transport-stream.js"; export function registerProviderStreamForModel(params: { model: Model; @@ -11,20 +12,21 @@ export function registerProviderStreamForModel(params: { workspaceDir?: string; env?: NodeJS.ProcessEnv; }): StreamFn | undefined { - const streamFn = resolveProviderStreamFn({ - provider: params.model.provider, - config: params.cfg, - workspaceDir: params.workspaceDir, - env: params.env, - context: { - config: params.cfg, - agentDir: params.agentDir, - workspaceDir: params.workspaceDir, + const streamFn = + resolveProviderStreamFn({ provider: params.model.provider, - modelId: params.model.id, - model: params.model, - }, - }); + config: params.cfg, + workspaceDir: params.workspaceDir, + env: params.env, + context: { + config: params.cfg, + agentDir: params.agentDir, + workspaceDir: params.workspaceDir, + provider: params.model.provider, + modelId: params.model.id, + model: params.model, + }, + }) ?? 
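+      // A plugin-registered stream fn wins; the transport-aware fallback applies only when none resolves.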
createTransportAwareStreamFnForModel(params.model); if (!streamFn) { return undefined; } diff --git a/src/agents/simple-completion-transport.test.ts b/src/agents/simple-completion-transport.test.ts index 085c25780a2..90cb9939a7d 100644 --- a/src/agents/simple-completion-transport.test.ts +++ b/src/agents/simple-completion-transport.test.ts @@ -5,6 +5,8 @@ import type { OpenClawConfig } from "../config/config.js"; const createAnthropicVertexStreamFnForModel = vi.fn(); const ensureCustomApiRegistered = vi.fn(); const resolveProviderStreamFn = vi.fn(); +const buildTransportAwareSimpleStreamFn = vi.fn(); +const prepareTransportAwareSimpleModel = vi.fn(); vi.mock("./anthropic-vertex-stream.js", () => ({ createAnthropicVertexStreamFnForModel, @@ -14,6 +16,11 @@ vi.mock("./custom-api-registry.js", () => ({ ensureCustomApiRegistered, })); +vi.mock("./openai-transport-stream.js", () => ({ + buildTransportAwareSimpleStreamFn, + prepareTransportAwareSimpleModel, +})); + vi.mock("../plugins/provider-runtime.js", () => ({ resolveProviderStreamFn, })); @@ -29,8 +36,12 @@ describe("prepareModelForSimpleCompletion", () => { createAnthropicVertexStreamFnForModel.mockReset(); ensureCustomApiRegistered.mockReset(); resolveProviderStreamFn.mockReset(); + buildTransportAwareSimpleStreamFn.mockReset(); + prepareTransportAwareSimpleModel.mockReset(); createAnthropicVertexStreamFnForModel.mockReturnValue("vertex-stream"); resolveProviderStreamFn.mockReturnValue("ollama-stream"); + buildTransportAwareSimpleStreamFn.mockReturnValue(undefined); + prepareTransportAwareSimpleModel.mockImplementation((model) => model); }); it("registers the configured Ollama transport and keeps the original api", () => { @@ -106,4 +117,39 @@ describe("prepareModelForSimpleCompletion", () => { api: "openclaw-anthropic-vertex-simple:https%3A%2F%2Fus-central1-aiplatform.googleapis.com", }); }); + + it("uses a transport-aware custom api alias when llm request transport overrides are present", () => { + const model: Model<"openai-responses"> = { + id: "gpt-5", + name: "GPT-5", + api: "openai-responses", + provider: "openai", + baseUrl: "https://api.openai.com/v1", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 8192, + }; + + resolveProviderStreamFn.mockReturnValueOnce(undefined); + buildTransportAwareSimpleStreamFn.mockReturnValueOnce("transport-stream"); + prepareTransportAwareSimpleModel.mockReturnValueOnce({ + ...model, + api: "openclaw-openai-responses-transport", + }); + + const result = prepareModelForSimpleCompletion({ model }); + + expect(prepareTransportAwareSimpleModel).toHaveBeenCalledWith(model); + expect(buildTransportAwareSimpleStreamFn).toHaveBeenCalledWith(model); + expect(ensureCustomApiRegistered).toHaveBeenCalledWith( + "openclaw-openai-responses-transport", + "transport-stream", + ); + expect(result).toEqual({ + ...model, + api: "openclaw-openai-responses-transport", + }); + }); }); diff --git a/src/agents/simple-completion-transport.ts b/src/agents/simple-completion-transport.ts index cb9d5cd67b5..3448a5f026f 100644 --- a/src/agents/simple-completion-transport.ts +++ b/src/agents/simple-completion-transport.ts @@ -2,6 +2,10 @@ import { getApiProvider, type Api, type Model } from "@mariozechner/pi-ai"; import type { OpenClawConfig } from "../config/config.js"; import { createAnthropicVertexStreamFnForModel } from "./anthropic-vertex-stream.js"; import { ensureCustomApiRegistered } from "./custom-api-registry.js"; +import { + 
  buildTransportAwareSimpleStreamFn,
+  prepareTransportAwareSimpleModel,
+} from "./openai-transport-stream.js";
 import { registerProviderStreamForModel } from "./provider-stream.js";
 
 function resolveAnthropicVertexSimpleApi(baseUrl?: string): Api {
@@ -19,6 +23,15 @@ export function prepareModelForSimpleCompletion(params: {
     return model;
   }
 
+  const transportAwareModel = prepareTransportAwareSimpleModel(model);
+  if (transportAwareModel !== model) {
+    const streamFn = buildTransportAwareSimpleStreamFn(model);
+    if (streamFn) {
+      ensureCustomApiRegistered(transportAwareModel.api, streamFn);
+      return transportAwareModel;
+    }
+  }
+
   if (model.provider === "anthropic-vertex") {
     const api = resolveAnthropicVertexSimpleApi(model.baseUrl);
     ensureCustomApiRegistered(api, createAnthropicVertexStreamFnForModel(model));
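+// Illustrative sketch only, not part of this patch: for a model carrying proxy/tls overrides,
+//   const prepared = prepareModelForSimpleCompletion({ model });
+// yields a copy whose api is the "openclaw-*-transport" alias registered earlier, while models
+// without overrides keep falling through to provider-specific handling such as the
+// anthropic-vertex branch above.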