From ea4265a820636ceab3a13da3bd227dc87eda15f4 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Sat, 4 Apr 2026 02:56:28 +0900 Subject: [PATCH] feat(providers): add anthropic transport runtime --- src/agents/anthropic-transport-stream.test.ts | 267 +++++ src/agents/anthropic-transport-stream.ts | 918 ++++++++++++++++++ src/agents/copilot-dynamic-headers.ts | 29 + src/agents/openai-transport-stream.test.ts | 45 +- src/agents/openai-transport-stream.ts | 334 +------ src/agents/provider-stream.ts | 2 +- src/agents/provider-transport-fetch.ts | 100 ++ src/agents/provider-transport-stream.ts | 75 ++ .../simple-completion-transport.test.ts | 2 +- src/agents/simple-completion-transport.ts | 4 +- src/agents/transport-message-transform.ts | 131 +++ 11 files changed, 1575 insertions(+), 332 deletions(-) create mode 100644 src/agents/anthropic-transport-stream.test.ts create mode 100644 src/agents/anthropic-transport-stream.ts create mode 100644 src/agents/copilot-dynamic-headers.ts create mode 100644 src/agents/provider-transport-fetch.ts create mode 100644 src/agents/provider-transport-stream.ts create mode 100644 src/agents/transport-message-transform.ts diff --git a/src/agents/anthropic-transport-stream.test.ts b/src/agents/anthropic-transport-stream.test.ts new file mode 100644 index 00000000000..24455228b65 --- /dev/null +++ b/src/agents/anthropic-transport-stream.test.ts @@ -0,0 +1,267 @@ +import type { Model } from "@mariozechner/pi-ai"; +import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import { attachModelProviderRequestTransport } from "./provider-request-config.js"; + +const { + anthropicCtorMock, + anthropicMessagesStreamMock, + buildGuardedModelFetchMock, + guardedFetchMock, +} = vi.hoisted(() => ({ + anthropicCtorMock: vi.fn(), + anthropicMessagesStreamMock: vi.fn(), + buildGuardedModelFetchMock: vi.fn(), + guardedFetchMock: vi.fn(), +})); + +vi.mock("@anthropic-ai/sdk", () => ({ + default: anthropicCtorMock, +})); + +vi.mock("./provider-transport-fetch.js", () => ({ + buildGuardedModelFetch: buildGuardedModelFetchMock, +})); + +let createAnthropicMessagesTransportStreamFn: typeof import("./anthropic-transport-stream.js").createAnthropicMessagesTransportStreamFn; + +function emptyEventStream(): AsyncIterable> { + return (async function* () {})(); +} + +describe("anthropic transport stream", () => { + beforeAll(async () => { + ({ createAnthropicMessagesTransportStreamFn } = + await import("./anthropic-transport-stream.js")); + }); + + beforeEach(() => { + anthropicCtorMock.mockReset(); + anthropicMessagesStreamMock.mockReset(); + buildGuardedModelFetchMock.mockReset(); + guardedFetchMock.mockReset(); + buildGuardedModelFetchMock.mockReturnValue(guardedFetchMock); + anthropicMessagesStreamMock.mockReturnValue(emptyEventStream()); + anthropicCtorMock.mockImplementation(function mockAnthropicClient() { + return { + messages: { + stream: anthropicMessagesStreamMock, + }, + }; + }); + }); + + it("uses the guarded fetch transport for api-key Anthropic requests", async () => { + const model = attachModelProviderRequestTransport( + { + id: "claude-sonnet-4-6", + name: "Claude Sonnet 4.6", + api: "anthropic-messages", + provider: "anthropic", + baseUrl: "https://api.anthropic.com", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 8192, + headers: { "X-Provider": "anthropic" }, + } satisfies Model<"anthropic-messages">, + { + proxy: { + mode: "explicit-proxy", + url: "http://proxy.internal:8443", + }, + }, + ); + const streamFn = createAnthropicMessagesTransportStreamFn(); + + const stream = await Promise.resolve( + streamFn( + model, + { + messages: [{ role: "user", content: "hello" }], + } as Parameters[1], + { + apiKey: "sk-ant-api", + headers: { "X-Call": "1" }, + } as Parameters[2], + ), + ); + await stream.result(); + + expect(buildGuardedModelFetchMock).toHaveBeenCalledWith(model); + expect(anthropicCtorMock).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: "sk-ant-api", + baseURL: "https://api.anthropic.com", + fetch: guardedFetchMock, + defaultHeaders: expect.objectContaining({ + accept: "application/json", + "anthropic-dangerous-direct-browser-access": "true", + "X-Provider": "anthropic", + "X-Call": "1", + }), + }), + ); + expect(anthropicMessagesStreamMock).toHaveBeenCalledWith( + expect.objectContaining({ + model: "claude-sonnet-4-6", + stream: true, + }), + undefined, + ); + }); + + it("preserves Anthropic OAuth identity and tool-name remapping with transport overrides", async () => { + anthropicMessagesStreamMock.mockReturnValueOnce( + (async function* () { + yield { + type: "message_start", + message: { id: "msg_1", usage: { input_tokens: 10, output_tokens: 0 } }, + }; + yield { + type: "content_block_start", + index: 0, + content_block: { + type: "tool_use", + id: "tool_1", + name: "Read", + input: { path: "/tmp/a" }, + }, + }; + yield { + type: "content_block_stop", + index: 0, + }; + yield { + type: "message_delta", + delta: { stop_reason: "tool_use" }, + usage: { input_tokens: 10, output_tokens: 5 }, + }; + })(), + ); + const model = attachModelProviderRequestTransport( + { + id: "claude-sonnet-4-6", + name: "Claude Sonnet 4.6", + api: "anthropic-messages", + provider: "anthropic", + baseUrl: "https://api.anthropic.com", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 8192, + } satisfies Model<"anthropic-messages">, + { + tls: { + ca: "ca-pem", + }, + }, + ); + const streamFn = createAnthropicMessagesTransportStreamFn(); + const stream = await Promise.resolve( + streamFn( + model, + { + systemPrompt: "Follow policy.", + messages: [{ role: "user", content: "Read the file" }], + tools: [ + { + name: "read", + description: "Read a file", + parameters: { + type: "object", + properties: { + path: { type: "string" }, + }, + required: ["path"], + }, + }, + ], + } as unknown as Parameters[1], + { + apiKey: "sk-ant-oat-example", + } as Parameters[2], + ), + ); + const result = await stream.result(); + + expect(anthropicCtorMock).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: null, + authToken: "sk-ant-oat-example", + fetch: guardedFetchMock, + defaultHeaders: expect.objectContaining({ + "x-app": "cli", + "user-agent": expect.stringContaining("claude-cli/"), + }), + }), + ); + const firstCallParams = anthropicMessagesStreamMock.mock.calls[0]?.[0] as Record< + string, + unknown + >; + expect(firstCallParams.system).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + text: "You are Claude Code, Anthropic's official CLI for Claude.", + }), + expect.objectContaining({ + text: "Follow policy.", + }), + ]), + ); + expect(firstCallParams.tools).toEqual( + expect.arrayContaining([expect.objectContaining({ name: "Read" })]), + ); + expect(result.stopReason).toBe("toolUse"); + expect(result.content).toEqual( + expect.arrayContaining([expect.objectContaining({ type: "toolCall", name: "read" })]), + ); + }); + + it("maps adaptive thinking effort for Claude 4.6 transport runs", async () => { + const model = attachModelProviderRequestTransport( + { + id: "claude-opus-4-6", + name: "Claude Opus 4.6", + api: "anthropic-messages", + provider: "anthropic", + baseUrl: "https://api.anthropic.com", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 8192, + } satisfies Model<"anthropic-messages">, + { + proxy: { + mode: "env-proxy", + }, + }, + ); + const streamFn = createAnthropicMessagesTransportStreamFn(); + + const stream = await Promise.resolve( + streamFn( + model, + { + messages: [{ role: "user", content: "Think deeply." }], + } as Parameters[1], + { + apiKey: "sk-ant-api", + reasoning: "xhigh", + } as Parameters[2], + ), + ); + await stream.result(); + + expect(anthropicMessagesStreamMock).toHaveBeenCalledWith( + expect.objectContaining({ + thinking: { type: "adaptive" }, + output_config: { effort: "max" }, + }), + undefined, + ); + }); +}); diff --git a/src/agents/anthropic-transport-stream.ts b/src/agents/anthropic-transport-stream.ts new file mode 100644 index 00000000000..15ed98efb6d --- /dev/null +++ b/src/agents/anthropic-transport-stream.ts @@ -0,0 +1,918 @@ +import Anthropic from "@anthropic-ai/sdk"; +import type { StreamFn } from "@mariozechner/pi-agent-core"; +import { + calculateCost, + createAssistantMessageEventStream, + getEnvApiKey, + parseStreamingJson, + type AnthropicOptions, + type Context, + type Model, + type SimpleStreamOptions, + type ThinkingLevel, +} from "@mariozechner/pi-ai"; +import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./copilot-dynamic-headers.js"; +import { sanitizeTransportPayloadText } from "./openai-transport-stream.js"; +import { buildGuardedModelFetch } from "./provider-transport-fetch.js"; +import { transformTransportMessages } from "./transport-message-transform.js"; + +const CLAUDE_CODE_VERSION = "2.1.75"; +const CLAUDE_CODE_TOOLS = [ + "Read", + "Write", + "Edit", + "Bash", + "Grep", + "Glob", + "AskUserQuestion", + "EnterPlanMode", + "ExitPlanMode", + "KillShell", + "NotebookEdit", + "Skill", + "Task", + "TaskOutput", + "TodoWrite", + "WebFetch", + "WebSearch", +] as const; +const CLAUDE_CODE_TOOL_LOOKUP = new Map( + CLAUDE_CODE_TOOLS.map((tool) => [tool.toLowerCase(), tool]), +); + +type AnthropicTransportModel = Model<"anthropic-messages"> & { + headers?: Record; + provider: string; +}; + +type AnthropicTransportOptions = AnthropicOptions & + Pick; + +type TransportContentBlock = + | { type: "text"; text: string; index?: number } + | { + type: "thinking"; + thinking: string; + thinkingSignature: string; + redacted?: boolean; + index?: number; + } + | { + type: "toolCall"; + id: string; + name: string; + arguments: unknown; + partialJson?: string; + index?: number; + }; + +type MutableAssistantOutput = { + role: "assistant"; + content: Array; + api: "anthropic-messages"; + provider: string; + model: string; + usage: { + input: number; + output: number; + cacheRead: number; + cacheWrite: number; + totalTokens: number; + cost: { input: number; output: number; cacheRead: number; cacheWrite: number; total: number }; + }; + stopReason: string; + timestamp: number; + responseId?: string; + errorMessage?: string; +}; + +function sanitizeAnthropicText(text: string): string { + return sanitizeTransportPayloadText(text); +} + +function supportsAdaptiveThinking(modelId: string): boolean { + return ( + modelId.includes("opus-4-6") || + modelId.includes("opus-4.6") || + modelId.includes("sonnet-4-6") || + modelId.includes("sonnet-4.6") + ); +} + +function mapThinkingLevelToEffort( + level: ThinkingLevel, + modelId: string, +): NonNullable { + switch (level) { + case "minimal": + case "low": + return "low"; + case "medium": + return "medium"; + case "xhigh": + return modelId.includes("opus-4-6") || modelId.includes("opus-4.6") ? "max" : "high"; + default: + return "high"; + } +} + +function clampReasoningLevel(level: ThinkingLevel): "minimal" | "low" | "medium" | "high" { + return level === "xhigh" ? "high" : level; +} + +function adjustMaxTokensForThinking(params: { + baseMaxTokens: number; + modelMaxTokens: number; + reasoningLevel: ThinkingLevel; + customBudgets?: SimpleStreamOptions["thinkingBudgets"]; +}): { maxTokens: number; thinkingBudget: number } { + const budgets = { + minimal: 1024, + low: 2048, + medium: 8192, + high: 16384, + ...params.customBudgets, + }; + const minOutputTokens = 1024; + const level = clampReasoningLevel(params.reasoningLevel); + let thinkingBudget = budgets[level]; + const maxTokens = Math.min(params.baseMaxTokens + thinkingBudget, params.modelMaxTokens); + if (maxTokens <= thinkingBudget) { + thinkingBudget = Math.max(0, maxTokens - minOutputTokens); + } + return { maxTokens, thinkingBudget }; +} + +function mergeHeaders( + ...headerSources: Array | undefined> +): Record | undefined { + const merged: Record = {}; + for (const headers of headerSources) { + if (headers) { + Object.assign(merged, headers); + } + } + return Object.keys(merged).length > 0 ? merged : undefined; +} + +function isAnthropicOAuthToken(apiKey: string): boolean { + return apiKey.includes("sk-ant-oat"); +} + +function toClaudeCodeName(name: string): string { + return CLAUDE_CODE_TOOL_LOOKUP.get(name.toLowerCase()) ?? name; +} + +function fromClaudeCodeName(name: string, tools: Context["tools"] | undefined): string { + if (tools && tools.length > 0) { + const lowerName = name.toLowerCase(); + const matchedTool = tools.find((tool) => tool.name.toLowerCase() === lowerName); + if (matchedTool) { + return matchedTool.name; + } + } + return name; +} + +function resolveCacheControl( + baseUrl: string | undefined, + cacheRetention: AnthropicOptions["cacheRetention"], +): { type: "ephemeral"; ttl?: "1h" } | undefined { + const retention = + cacheRetention ?? (process.env.PI_CACHE_RETENTION === "long" ? "long" : "short"); + if (retention === "none") { + return undefined; + } + const ttl = + retention === "long" && typeof baseUrl === "string" && baseUrl.includes("api.anthropic.com") + ? "1h" + : undefined; + return { type: "ephemeral", ...(ttl ? { ttl } : {}) }; +} + +function convertContentBlocks( + content: Array< + { type: "text"; text: string } | { type: "image"; data: string; mimeType: string } + >, +) { + const hasImages = content.some((item) => item.type === "image"); + if (!hasImages) { + return sanitizeAnthropicText( + content.map((item) => ("text" in item ? item.text : "")).join("\n"), + ); + } + const blocks = content.map((block) => { + if (block.type === "text") { + return { + type: "text", + text: sanitizeAnthropicText(block.text), + }; + } + return { + type: "image", + source: { + type: "base64", + media_type: block.mimeType, + data: block.data, + }, + }; + }); + if (!blocks.some((block) => block.type === "text")) { + blocks.unshift({ + type: "text", + text: "(see attached image)", + }); + } + return blocks; +} + +function normalizeToolCallId(id: string): string { + return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64); +} + +function convertAnthropicMessages( + messages: Context["messages"], + model: AnthropicTransportModel, + isOAuthToken: boolean, + cacheControl: { type: "ephemeral"; ttl?: "1h" } | undefined, +) { + const params: Array> = []; + const transformedMessages = transformTransportMessages(messages, model, normalizeToolCallId); + for (let i = 0; i < transformedMessages.length; i += 1) { + const msg = transformedMessages[i]; + if (msg.role === "user") { + if (typeof msg.content === "string") { + if (msg.content.trim().length > 0) { + params.push({ + role: "user", + content: sanitizeAnthropicText(msg.content), + }); + } + continue; + } + const blocks: Array< + | { type: "text"; text: string } + | { + type: "image"; + source: { type: "base64"; media_type: string; data: string }; + } + > = msg.content.map((item) => + item.type === "text" + ? { + type: "text", + text: sanitizeAnthropicText(item.text), + } + : { + type: "image", + source: { + type: "base64", + media_type: item.mimeType, + data: item.data, + }, + }, + ); + let filteredBlocks = model.input.includes("image") + ? blocks + : blocks.filter((block) => block.type !== "image"); + filteredBlocks = filteredBlocks.filter( + (block) => block.type !== "text" || block.text.trim().length > 0, + ); + if (filteredBlocks.length === 0) { + continue; + } + params.push({ + role: "user", + content: filteredBlocks, + }); + continue; + } + if (msg.role === "assistant") { + const blocks: Array> = []; + for (const block of msg.content) { + if (block.type === "text") { + if (block.text.trim().length > 0) { + blocks.push({ + type: "text", + text: sanitizeAnthropicText(block.text), + }); + } + continue; + } + if (block.type === "thinking") { + if (block.redacted) { + blocks.push({ + type: "redacted_thinking", + data: block.thinkingSignature, + }); + continue; + } + if (block.thinking.trim().length === 0) { + continue; + } + if (!block.thinkingSignature || block.thinkingSignature.trim().length === 0) { + blocks.push({ + type: "text", + text: sanitizeAnthropicText(block.thinking), + }); + } else { + blocks.push({ + type: "thinking", + thinking: sanitizeAnthropicText(block.thinking), + signature: block.thinkingSignature, + }); + } + continue; + } + if (block.type === "toolCall") { + blocks.push({ + type: "tool_use", + id: block.id, + name: isOAuthToken ? toClaudeCodeName(block.name) : block.name, + input: block.arguments ?? {}, + }); + } + } + if (blocks.length > 0) { + params.push({ + role: "assistant", + content: blocks, + }); + } + continue; + } + if (msg.role === "toolResult") { + const toolResult = msg; + const toolResults: Array> = [ + { + type: "tool_result", + tool_use_id: toolResult.toolCallId, + content: convertContentBlocks(toolResult.content), + is_error: toolResult.isError, + }, + ]; + let j = i + 1; + while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") { + const nextMsg = transformedMessages[j] as Extract< + Context["messages"][number], + { role: "toolResult" } + >; + toolResults.push({ + type: "tool_result", + tool_use_id: nextMsg.toolCallId, + content: convertContentBlocks(nextMsg.content), + is_error: nextMsg.isError, + }); + j += 1; + } + i = j - 1; + params.push({ + role: "user", + content: toolResults, + }); + } + } + if (cacheControl && params.length > 0) { + const lastMessage = params[params.length - 1]; + if (lastMessage.role === "user") { + const content = lastMessage.content; + if (Array.isArray(content)) { + const lastBlock = content[content.length - 1]; + if ( + lastBlock && + typeof lastBlock === "object" && + "type" in lastBlock && + (lastBlock.type === "text" || + lastBlock.type === "image" || + lastBlock.type === "tool_result") + ) { + (lastBlock as Record).cache_control = cacheControl; + } + } else if (typeof content === "string") { + lastMessage.content = [ + { + type: "text", + text: content, + cache_control: cacheControl, + }, + ]; + } + } + } + return params; +} + +function convertAnthropicTools(tools: Context["tools"], isOAuthToken: boolean) { + if (!tools) { + return []; + } + return tools.map((tool) => ({ + name: isOAuthToken ? toClaudeCodeName(tool.name) : tool.name, + description: tool.description, + input_schema: { + type: "object", + properties: tool.parameters.properties || {}, + required: tool.parameters.required || [], + }, + })); +} + +function mapStopReason(reason: string | undefined): string { + switch (reason) { + case "end_turn": + return "stop"; + case "max_tokens": + return "length"; + case "tool_use": + return "toolUse"; + case "pause_turn": + return "stop"; + case "refusal": + case "sensitive": + return "error"; + case "stop_sequence": + return "stop"; + default: + throw new Error(`Unhandled stop reason: ${String(reason)}`); + } +} + +function createAnthropicTransportClient(params: { + model: AnthropicTransportModel; + context: Context; + apiKey: string; + options: AnthropicTransportOptions | undefined; +}) { + const { model, context, apiKey, options } = params; + const needsInterleavedBeta = + (options?.interleavedThinking ?? true) && !supportsAdaptiveThinking(model.id); + const fetch = buildGuardedModelFetch(model); + if (model.provider === "github-copilot") { + const betaFeatures = needsInterleavedBeta ? ["interleaved-thinking-2025-05-14"] : []; + return { + client: new Anthropic({ + apiKey: null, + authToken: apiKey, + baseURL: model.baseUrl, + dangerouslyAllowBrowser: true, + defaultHeaders: mergeHeaders( + { + accept: "application/json", + "anthropic-dangerous-direct-browser-access": "true", + ...(betaFeatures.length > 0 ? { "anthropic-beta": betaFeatures.join(",") } : {}), + }, + model.headers, + buildCopilotDynamicHeaders({ + messages: context.messages, + hasImages: hasCopilotVisionInput(context.messages), + }), + options?.headers, + ), + fetch, + }), + isOAuthToken: false, + }; + } + const betaFeatures = ["fine-grained-tool-streaming-2025-05-14"]; + if (needsInterleavedBeta) { + betaFeatures.push("interleaved-thinking-2025-05-14"); + } + if (isAnthropicOAuthToken(apiKey)) { + return { + client: new Anthropic({ + apiKey: null, + authToken: apiKey, + baseURL: model.baseUrl, + dangerouslyAllowBrowser: true, + defaultHeaders: mergeHeaders( + { + accept: "application/json", + "anthropic-dangerous-direct-browser-access": "true", + "anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`, + "user-agent": `claude-cli/${CLAUDE_CODE_VERSION}`, + "x-app": "cli", + }, + model.headers, + options?.headers, + ), + fetch, + }), + isOAuthToken: true, + }; + } + return { + client: new Anthropic({ + apiKey, + baseURL: model.baseUrl, + dangerouslyAllowBrowser: true, + defaultHeaders: mergeHeaders( + { + accept: "application/json", + "anthropic-dangerous-direct-browser-access": "true", + "anthropic-beta": betaFeatures.join(","), + }, + model.headers, + options?.headers, + ), + fetch, + }), + isOAuthToken: false, + }; +} + +function buildAnthropicParams( + model: AnthropicTransportModel, + context: Context, + isOAuthToken: boolean, + options: AnthropicTransportOptions | undefined, +) { + const cacheControl = resolveCacheControl(model.baseUrl, options?.cacheRetention); + const defaultMaxTokens = Math.min(model.maxTokens, 32_000); + const params: Record = { + model: model.id, + messages: convertAnthropicMessages(context.messages, model, isOAuthToken, cacheControl), + max_tokens: options?.maxTokens || defaultMaxTokens, + stream: true, + }; + if (isOAuthToken) { + params.system = [ + { + type: "text", + text: "You are Claude Code, Anthropic's official CLI for Claude.", + ...(cacheControl ? { cache_control: cacheControl } : {}), + }, + ...(context.systemPrompt + ? [ + { + type: "text", + text: sanitizeAnthropicText(context.systemPrompt), + ...(cacheControl ? { cache_control: cacheControl } : {}), + }, + ] + : []), + ]; + } else if (context.systemPrompt) { + params.system = [ + { + type: "text", + text: sanitizeAnthropicText(context.systemPrompt), + ...(cacheControl ? { cache_control: cacheControl } : {}), + }, + ]; + } + if (options?.temperature !== undefined && !options.thinkingEnabled) { + params.temperature = options.temperature; + } + if (context.tools) { + params.tools = convertAnthropicTools(context.tools, isOAuthToken); + } + if (model.reasoning) { + if (options?.thinkingEnabled) { + if (supportsAdaptiveThinking(model.id)) { + params.thinking = { type: "adaptive" }; + if (options.effort) { + params.output_config = { effort: options.effort }; + } + } else { + params.thinking = { + type: "enabled", + budget_tokens: options.thinkingBudgetTokens || 1024, + }; + } + } else if (options?.thinkingEnabled === false) { + params.thinking = { type: "disabled" }; + } + } + if (options?.metadata && typeof options.metadata.user_id === "string") { + params.metadata = { user_id: options.metadata.user_id }; + } + if (options?.toolChoice) { + params.tool_choice = + typeof options.toolChoice === "string" ? { type: options.toolChoice } : options.toolChoice; + } + return params; +} + +function resolveAnthropicTransportOptions( + model: AnthropicTransportModel, + options: AnthropicTransportOptions | undefined, + apiKey: string, +): AnthropicTransportOptions { + const baseMaxTokens = options?.maxTokens || Math.min(model.maxTokens, 32_000); + const resolved: AnthropicTransportOptions = { + temperature: options?.temperature, + maxTokens: baseMaxTokens, + signal: options?.signal, + apiKey, + cacheRetention: options?.cacheRetention, + sessionId: options?.sessionId, + headers: options?.headers, + onPayload: options?.onPayload, + maxRetryDelayMs: options?.maxRetryDelayMs, + metadata: options?.metadata, + interleavedThinking: options?.interleavedThinking, + toolChoice: options?.toolChoice, + thinkingBudgets: options?.thinkingBudgets, + reasoning: options?.reasoning, + }; + if (!options?.reasoning) { + resolved.thinkingEnabled = false; + return resolved; + } + if (supportsAdaptiveThinking(model.id)) { + resolved.thinkingEnabled = true; + resolved.effort = mapThinkingLevelToEffort(options.reasoning, model.id); + return resolved; + } + const adjusted = adjustMaxTokensForThinking({ + baseMaxTokens, + modelMaxTokens: model.maxTokens, + reasoningLevel: options.reasoning, + customBudgets: options.thinkingBudgets, + }); + resolved.maxTokens = adjusted.maxTokens; + resolved.thinkingEnabled = true; + resolved.thinkingBudgetTokens = adjusted.thinkingBudget; + return resolved; +} + +export function createAnthropicMessagesTransportStreamFn(): StreamFn { + return (rawModel, context, rawOptions) => { + const model = rawModel as AnthropicTransportModel; + const options = rawOptions as AnthropicTransportOptions | undefined; + const stream = createAssistantMessageEventStream(); + void (async () => { + const output: MutableAssistantOutput = { + role: "assistant", + content: [], + api: "anthropic-messages", + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }; + try { + const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? ""; + if (!apiKey) { + throw new Error(`No API key for provider: ${model.provider}`); + } + const transportOptions = resolveAnthropicTransportOptions(model, options, apiKey); + const { client, isOAuthToken } = createAnthropicTransportClient({ + model, + context, + apiKey, + options: transportOptions, + }); + let params = buildAnthropicParams(model, context, isOAuthToken, transportOptions); + const nextParams = await transportOptions.onPayload?.(params, model); + if (nextParams !== undefined) { + params = nextParams as Record; + } + const anthropicStream = client.messages.stream( + { ...params, stream: true } as never, + transportOptions.signal ? { signal: transportOptions.signal } : undefined, + ) as AsyncIterable>; + stream.push({ type: "start", partial: output as never }); + const blocks = output.content; + for await (const event of anthropicStream) { + if (event.type === "message_start") { + const message = event.message as + | { id?: string; usage?: Record } + | undefined; + const usage = message?.usage ?? {}; + output.responseId = typeof message?.id === "string" ? message.id : undefined; + output.usage.input = typeof usage.input_tokens === "number" ? usage.input_tokens : 0; + output.usage.output = typeof usage.output_tokens === "number" ? usage.output_tokens : 0; + output.usage.cacheRead = + typeof usage.cache_read_input_tokens === "number" ? usage.cache_read_input_tokens : 0; + output.usage.cacheWrite = + typeof usage.cache_creation_input_tokens === "number" + ? usage.cache_creation_input_tokens + : 0; + output.usage.totalTokens = + output.usage.input + + output.usage.output + + output.usage.cacheRead + + output.usage.cacheWrite; + calculateCost(model, output.usage); + continue; + } + if (event.type === "content_block_start") { + const contentBlock = event.content_block as Record | undefined; + const index = typeof event.index === "number" ? event.index : -1; + if (contentBlock?.type === "text") { + const block: TransportContentBlock = { type: "text", text: "", index }; + output.content.push(block); + stream.push({ + type: "text_start", + contentIndex: output.content.length - 1, + partial: output as never, + }); + continue; + } + if (contentBlock?.type === "thinking") { + const block: TransportContentBlock = { + type: "thinking", + thinking: "", + thinkingSignature: "", + index, + }; + output.content.push(block); + stream.push({ + type: "thinking_start", + contentIndex: output.content.length - 1, + partial: output as never, + }); + continue; + } + if (contentBlock?.type === "redacted_thinking") { + const block: TransportContentBlock = { + type: "thinking", + thinking: "[Reasoning redacted]", + thinkingSignature: typeof contentBlock.data === "string" ? contentBlock.data : "", + redacted: true, + index, + }; + output.content.push(block); + stream.push({ + type: "thinking_start", + contentIndex: output.content.length - 1, + partial: output as never, + }); + continue; + } + if (contentBlock?.type === "tool_use") { + const block: TransportContentBlock = { + type: "toolCall", + id: typeof contentBlock.id === "string" ? contentBlock.id : "", + name: + typeof contentBlock.name === "string" + ? isOAuthToken + ? fromClaudeCodeName(contentBlock.name, context.tools) + : contentBlock.name + : "", + arguments: + contentBlock.input && typeof contentBlock.input === "object" + ? (contentBlock.input as Record) + : {}, + partialJson: "", + index, + }; + output.content.push(block); + stream.push({ + type: "toolcall_start", + contentIndex: output.content.length - 1, + partial: output as never, + }); + } + continue; + } + if (event.type === "content_block_delta") { + const index = blocks.findIndex((block) => block.index === event.index); + const block = blocks[index]; + const delta = event.delta as Record | undefined; + if ( + block?.type === "text" && + delta?.type === "text_delta" && + typeof delta.text === "string" + ) { + block.text += delta.text; + stream.push({ + type: "text_delta", + contentIndex: index, + delta: delta.text, + partial: output as never, + }); + continue; + } + if ( + block?.type === "thinking" && + delta?.type === "thinking_delta" && + typeof delta.thinking === "string" + ) { + block.thinking += delta.thinking; + stream.push({ + type: "thinking_delta", + contentIndex: index, + delta: delta.thinking, + partial: output as never, + }); + continue; + } + if ( + block?.type === "toolCall" && + delta?.type === "input_json_delta" && + typeof delta.partial_json === "string" + ) { + block.partialJson += delta.partial_json; + block.arguments = parseStreamingJson(block.partialJson); + stream.push({ + type: "toolcall_delta", + contentIndex: index, + delta: delta.partial_json, + partial: output as never, + }); + continue; + } + if ( + block?.type === "thinking" && + delta?.type === "signature_delta" && + typeof delta.signature === "string" + ) { + block.thinkingSignature = `${String(block.thinkingSignature ?? "")}${delta.signature}`; + } + continue; + } + if (event.type === "content_block_stop") { + const index = blocks.findIndex((block) => block.index === event.index); + const block = blocks[index]; + if (!block) { + continue; + } + delete block.index; + if (block.type === "text") { + stream.push({ + type: "text_end", + contentIndex: index, + content: block.text, + partial: output as never, + }); + continue; + } + if (block.type === "thinking") { + stream.push({ + type: "thinking_end", + contentIndex: index, + content: block.thinking, + partial: output as never, + }); + continue; + } + if (block.type === "toolCall") { + if (typeof block.partialJson === "string" && block.partialJson.length > 0) { + block.arguments = parseStreamingJson(block.partialJson); + } + delete block.partialJson; + stream.push({ + type: "toolcall_end", + contentIndex: index, + toolCall: block as never, + partial: output as never, + }); + } + continue; + } + if (event.type === "message_delta") { + const delta = event.delta as { stop_reason?: string } | undefined; + const usage = event.usage as Record | undefined; + if (delta?.stop_reason) { + output.stopReason = mapStopReason(delta.stop_reason); + } + if (typeof usage?.input_tokens === "number") { + output.usage.input = usage.input_tokens; + } + if (typeof usage?.output_tokens === "number") { + output.usage.output = usage.output_tokens; + } + if (typeof usage?.cache_read_input_tokens === "number") { + output.usage.cacheRead = usage.cache_read_input_tokens; + } + if (typeof usage?.cache_creation_input_tokens === "number") { + output.usage.cacheWrite = usage.cache_creation_input_tokens; + } + output.usage.totalTokens = + output.usage.input + + output.usage.output + + output.usage.cacheRead + + output.usage.cacheWrite; + calculateCost(model, output.usage); + } + } + if (transportOptions.signal?.aborted) { + throw new Error("Request was aborted"); + } + if (output.stopReason === "aborted" || output.stopReason === "error") { + throw new Error("An unknown error occurred"); + } + stream.push({ type: "done", reason: output.stopReason as never, message: output as never }); + stream.end(); + } catch (error) { + for (const block of output.content) { + delete block.index; + } + output.stopReason = options?.signal?.aborted ? "aborted" : "error"; + output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error); + stream.push({ type: "error", reason: output.stopReason as never, error: output as never }); + stream.end(); + } + })(); + return stream as ReturnType; + }; +} diff --git a/src/agents/copilot-dynamic-headers.ts b/src/agents/copilot-dynamic-headers.ts new file mode 100644 index 00000000000..c6551ebb64a --- /dev/null +++ b/src/agents/copilot-dynamic-headers.ts @@ -0,0 +1,29 @@ +import type { Context } from "@mariozechner/pi-ai"; + +function inferCopilotInitiator(messages: Context["messages"]): "agent" | "user" { + const last = messages[messages.length - 1]; + return last && last.role !== "user" ? "agent" : "user"; +} + +export function hasCopilotVisionInput(messages: Context["messages"]): boolean { + return messages.some((message) => { + if (message.role === "user" && Array.isArray(message.content)) { + return message.content.some((item) => item.type === "image"); + } + if (message.role === "toolResult" && Array.isArray(message.content)) { + return message.content.some((item) => item.type === "image"); + } + return false; + }); +} + +export function buildCopilotDynamicHeaders(params: { + messages: Context["messages"]; + hasImages: boolean; +}): Record { + return { + "X-Initiator": inferCopilotInitiator(params.messages), + "Openai-Intent": "conversation-edits", + ...(params.hasImages ? { "Copilot-Vision-Request": "true" } : {}), + }; +} diff --git a/src/agents/openai-transport-stream.test.ts b/src/agents/openai-transport-stream.test.ts index 3952d197ad3..454cc1b53c8 100644 --- a/src/agents/openai-transport-stream.test.ts +++ b/src/agents/openai-transport-stream.test.ts @@ -3,22 +3,24 @@ import { describe, expect, it } from "vitest"; import { buildOpenAIResponsesParams, buildOpenAICompletionsParams, - buildTransportAwareSimpleStreamFn, - isTransportAwareApiSupported, parseTransportChunkUsage, - prepareTransportAwareSimpleModel, resolveAzureOpenAIApiVersion, - resolveTransportAwareSimpleApi, sanitizeTransportPayloadText, } from "./openai-transport-stream.js"; import { attachModelProviderRequestTransport } from "./provider-request-config.js"; +import { + buildTransportAwareSimpleStreamFn, + isTransportAwareApiSupported, + prepareTransportAwareSimpleModel, + resolveTransportAwareSimpleApi, +} from "./provider-transport-stream.js"; describe("openai transport stream", () => { it("reports the supported transport-aware APIs", () => { expect(isTransportAwareApiSupported("openai-responses")).toBe(true); expect(isTransportAwareApiSupported("openai-completions")).toBe(true); expect(isTransportAwareApiSupported("azure-openai-responses")).toBe(true); - expect(isTransportAwareApiSupported("anthropic-messages")).toBe(false); + expect(isTransportAwareApiSupported("anthropic-messages")).toBe(true); }); it("prepares a custom simple-completion api alias when transport overrides are attached", () => { @@ -54,6 +56,39 @@ describe("openai transport stream", () => { expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function"); }); + it("prepares an Anthropic simple-completion api alias when transport overrides are attached", () => { + const model = attachModelProviderRequestTransport( + { + id: "claude-sonnet-4-6", + name: "Claude Sonnet 4.6", + api: "anthropic-messages", + provider: "anthropic", + baseUrl: "https://api.anthropic.com", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 8192, + } satisfies Model<"anthropic-messages">, + { + proxy: { + mode: "explicit-proxy", + url: "http://proxy.internal:8443", + }, + }, + ); + + const prepared = prepareTransportAwareSimpleModel(model); + + expect(resolveTransportAwareSimpleApi(model.api)).toBe("openclaw-anthropic-messages-transport"); + expect(prepared).toMatchObject({ + api: "openclaw-anthropic-messages-transport", + provider: "anthropic", + id: "claude-sonnet-4-6", + }); + expect(buildTransportAwareSimpleStreamFn(model)).toBeTypeOf("function"); + }); + it("removes unpaired surrogate code units but preserves valid surrogate pairs", () => { const high = String.fromCharCode(0xd83d); const low = String.fromCharCode(0xdc00); diff --git a/src/agents/openai-transport-stream.ts b/src/agents/openai-transport-stream.ts index 3440f2ec5a4..99ce51ad3ed 100644 --- a/src/agents/openai-transport-stream.ts +++ b/src/agents/openai-transport-stream.ts @@ -18,29 +18,14 @@ import type { ResponseInput, ResponseInputMessageContentList, } from "openai/resources/responses/responses.js"; -import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js"; +import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./copilot-dynamic-headers.js"; import { resolveOpenAICompletionsCompatDefaultsFromCapabilities } from "./openai-completions-compat.js"; import { resolveProviderRequestCapabilities } from "./provider-attribution.js"; -import { - buildProviderRequestDispatcherPolicy, - getModelProviderRequestTransport, - resolveProviderRequestPolicyConfig, -} from "./provider-request-config.js"; +import { buildGuardedModelFetch } from "./provider-transport-fetch.js"; +import { transformTransportMessages } from "./transport-message-transform.js"; const DEFAULT_AZURE_OPENAI_API_VERSION = "2024-12-01-preview"; -const SUPPORTED_TRANSPORT_APIS = new Set([ - "openai-responses", - "openai-completions", - "azure-openai-responses", -]); - -const SIMPLE_TRANSPORT_API_ALIAS: Record = { - "openai-responses": "openclaw-openai-responses-transport", - "openai-completions": "openclaw-openai-completions-transport", - "azure-openai-responses": "openclaw-azure-openai-responses-transport", -}; - type BaseStreamOptions = { temperature?: number; maxTokens?: number; @@ -165,164 +150,6 @@ function shortHash(value: string): string { return Math.abs(hash).toString(36); } -function inferCopilotInitiator(messages: Context["messages"]): "agent" | "user" { - const last = messages[messages.length - 1]; - return last && last.role !== "user" ? "agent" : "user"; -} - -function hasCopilotVisionInput(messages: Context["messages"]): boolean { - return messages.some((message) => { - if (message.role === "user" && Array.isArray(message.content)) { - return message.content.some((item) => item.type === "image"); - } - if (message.role === "toolResult" && Array.isArray(message.content)) { - return message.content.some((item) => item.type === "image"); - } - return false; - }); -} - -function buildCopilotDynamicHeaders(params: { - messages: Context["messages"]; - hasImages: boolean; -}): Record { - return { - "X-Initiator": inferCopilotInitiator(params.messages), - "Openai-Intent": "conversation-edits", - ...(params.hasImages ? { "Copilot-Vision-Request": "true" } : {}), - }; -} - -function transformMessages( - messages: Context["messages"], - model: Model, - normalizeToolCallId?: ( - id: string, - targetModel: Model, - source: { provider: string; api: Api; model: string }, - ) => string, -): Context["messages"] { - const toolCallIdMap = new Map(); - const transformed = messages.map((msg) => { - if (msg.role === "user") { - return msg; - } - if (msg.role === "toolResult") { - const normalizedId = toolCallIdMap.get(msg.toolCallId); - return normalizedId && normalizedId !== msg.toolCallId - ? { ...msg, toolCallId: normalizedId } - : msg; - } - if (msg.role !== "assistant") { - return msg; - } - const isSameModel = - msg.provider === model.provider && msg.api === model.api && msg.model === model.id; - const content: typeof msg.content = []; - for (const block of msg.content) { - if (block.type === "thinking") { - if (block.redacted) { - if (isSameModel) { - content.push(block); - } - continue; - } - if (isSameModel && block.thinkingSignature) { - content.push(block); - continue; - } - if (!block.thinking.trim()) { - continue; - } - content.push(isSameModel ? block : { type: "text", text: block.thinking }); - continue; - } - if (block.type === "text") { - content.push(isSameModel ? block : { type: "text", text: block.text }); - continue; - } - if (block.type !== "toolCall") { - content.push(block); - continue; - } - let normalizedToolCall = block; - if (!isSameModel && block.thoughtSignature) { - normalizedToolCall = { ...normalizedToolCall }; - delete normalizedToolCall.thoughtSignature; - } - if (!isSameModel && normalizeToolCallId) { - const normalizedId = normalizeToolCallId(block.id, model, msg); - if (normalizedId !== block.id) { - toolCallIdMap.set(block.id, normalizedId); - normalizedToolCall = { ...normalizedToolCall, id: normalizedId }; - } - } - content.push(normalizedToolCall); - } - return { ...msg, content }; - }); - - const result: Context["messages"] = []; - let pendingToolCalls: Array<{ id: string; name: string }> = []; - let existingToolResultIds = new Set(); - for (const msg of transformed) { - if (msg.role === "assistant") { - if (pendingToolCalls.length > 0) { - for (const toolCall of pendingToolCalls) { - if (!existingToolResultIds.has(toolCall.id)) { - result.push({ - role: "toolResult", - toolCallId: toolCall.id, - toolName: toolCall.name, - content: [{ type: "text", text: "No result provided" }], - isError: true, - timestamp: Date.now(), - }); - } - } - pendingToolCalls = []; - existingToolResultIds = new Set(); - } - if (msg.stopReason === "error" || msg.stopReason === "aborted") { - continue; - } - const toolCalls = msg.content.filter( - (block): block is Extract<(typeof msg.content)[number], { type: "toolCall" }> => - block.type === "toolCall", - ); - if (toolCalls.length > 0) { - pendingToolCalls = toolCalls.map((block) => ({ id: block.id, name: block.name })); - existingToolResultIds = new Set(); - } - result.push(msg); - continue; - } - if (msg.role === "toolResult") { - existingToolResultIds.add(msg.toolCallId); - result.push(msg); - continue; - } - if (pendingToolCalls.length > 0) { - for (const toolCall of pendingToolCalls) { - if (!existingToolResultIds.has(toolCall.id)) { - result.push({ - role: "toolResult", - toolCallId: toolCall.id, - toolName: toolCall.name, - content: [{ type: "text", text: "No result provided" }], - isError: true, - timestamp: Date.now(), - }); - } - } - pendingToolCalls = []; - existingToolResultIds = new Set(); - } - result.push(msg); - } - return result; -} - function encodeTextSignatureV1(id: string, phase?: "commentary" | "final_answer"): string { return JSON.stringify({ v: 1, id, ...(phase ? { phase } : {}) }); } @@ -386,7 +213,11 @@ function convertResponsesMessages( } return `${normalizedCallId}|${normalizedItemId}`; }; - const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId); + const transformedMessages = transformTransportMessages( + context.messages, + model, + normalizeToolCallId, + ); const includeSystemPrompt = options?.includeSystemPrompt ?? true; if (includeSystemPrompt && context.systemPrompt) { messages.push({ @@ -718,133 +549,6 @@ function mapResponsesStopReason(status: string | undefined): string { } } -function hasTransportOverrides(model: Model): boolean { - const request = getModelProviderRequestTransport(model); - return Boolean(request?.proxy || request?.tls); -} - -export function isTransportAwareApiSupported(api: Api): boolean { - return SUPPORTED_TRANSPORT_APIS.has(api); -} - -export function resolveTransportAwareSimpleApi(api: Api): Api | undefined { - return SIMPLE_TRANSPORT_API_ALIAS[api]; -} - -export function createTransportAwareStreamFnForModel(model: Model): StreamFn | undefined { - if (!hasTransportOverrides(model)) { - return undefined; - } - if (!isTransportAwareApiSupported(model.api)) { - throw new Error( - `Model-provider request.proxy/request.tls is not yet supported for api "${model.api}"`, - ); - } - switch (model.api) { - case "openai-responses": - return createOpenAIResponsesTransportStreamFn(); - case "openai-completions": - return createOpenAICompletionsTransportStreamFn(); - case "azure-openai-responses": - return createAzureOpenAIResponsesTransportStreamFn(); - default: - return undefined; - } -} - -function resolveModelRequestPolicy(model: Model) { - return resolveProviderRequestPolicyConfig({ - provider: model.provider, - api: model.api, - baseUrl: model.baseUrl, - capability: "llm", - transport: "stream", - request: getModelProviderRequestTransport(model), - }); -} - -function buildManagedResponse(response: Response, release: () => Promise): Response { - if (!response.body) { - void release(); - return response; - } - const source = response.body; - let reader: ReadableStreamDefaultReader | undefined; - let released = false; - const finalize = async () => { - if (released) { - return; - } - released = true; - await release().catch(() => undefined); - }; - const wrappedBody = new ReadableStream({ - start() { - reader = source.getReader(); - }, - async pull(controller) { - try { - const chunk = await reader?.read(); - if (!chunk || chunk.done) { - controller.close(); - await finalize(); - return; - } - controller.enqueue(chunk.value); - } catch (error) { - controller.error(error); - await finalize(); - } - }, - async cancel(reason) { - try { - await reader?.cancel(reason); - } finally { - await finalize(); - } - }, - }); - return new Response(wrappedBody, { - status: response.status, - statusText: response.statusText, - headers: response.headers, - }); -} - -function buildGuardedModelFetch(model: Model): typeof fetch { - const requestConfig = resolveModelRequestPolicy(model); - const dispatcherPolicy = buildProviderRequestDispatcherPolicy(requestConfig); - return async (input, init) => { - const request = input instanceof Request ? new Request(input, init) : undefined; - const url = - request?.url ?? - (input instanceof URL - ? input.toString() - : typeof input === "string" - ? input - : (() => { - throw new Error("Unsupported fetch input for transport-aware model request"); - })()); - const requestInit = - request && - ({ - method: request.method, - headers: request.headers, - body: request.body ?? undefined, - redirect: request.redirect, - signal: request.signal, - ...(request.body ? ({ duplex: "half" } as const) : {}), - } satisfies RequestInit & { duplex?: "half" }); - const result = await fetchWithSsrFGuard({ - url, - init: requestInit ?? init, - dispatcherPolicy, - ...(requestConfig.allowPrivateNetwork ? { policy: { allowPrivateNetwork: true } } : {}), - }); - return buildManagedResponse(result.response, result.release); - }; -} - function buildOpenAIClientHeaders( model: Model, context: Context, @@ -881,7 +585,7 @@ function createOpenAIResponsesClient( }); } -function createOpenAIResponsesTransportStreamFn(): StreamFn { +export function createOpenAIResponsesTransportStreamFn(): StreamFn { return (model, context, options) => { const eventStream = createAssistantMessageEventStream(); const stream = eventStream as unknown as { push(event: unknown): void; end(): void }; @@ -1008,7 +712,7 @@ export function buildOpenAIResponsesParams( return params; } -function createAzureOpenAIResponsesTransportStreamFn(): StreamFn { +export function createAzureOpenAIResponsesTransportStreamFn(): StreamFn { return (model, context, options) => { const eventStream = createAssistantMessageEventStream(); const stream = eventStream as unknown as { push(event: unknown): void; end(): void }; @@ -1137,7 +841,7 @@ function createOpenAICompletionsClient( }); } -function createOpenAICompletionsTransportStreamFn(): StreamFn { +export function createOpenAICompletionsTransportStreamFn(): StreamFn { return (model, context, options) => { const eventStream = createAssistantMessageEventStream(); const stream = eventStream as unknown as { push(event: unknown): void; end(): void }; @@ -1553,19 +1257,3 @@ function mapStopReason(reason: string | null) { }; } } - -export function prepareTransportAwareSimpleModel(model: Model): Model { - const streamFn = createTransportAwareStreamFnForModel(model as Model); - const alias = resolveTransportAwareSimpleApi(model.api); - if (!streamFn || !alias) { - return model; - } - return { - ...model, - api: alias, - }; -} - -export function buildTransportAwareSimpleStreamFn(model: Model): StreamFn | undefined { - return createTransportAwareStreamFnForModel(model); -} diff --git a/src/agents/provider-stream.ts b/src/agents/provider-stream.ts index c8d3edb4557..420af2bff28 100644 --- a/src/agents/provider-stream.ts +++ b/src/agents/provider-stream.ts @@ -3,7 +3,7 @@ import type { Api, Model } from "@mariozechner/pi-ai"; import type { OpenClawConfig } from "../config/config.js"; import { resolveProviderStreamFn } from "../plugins/provider-runtime.js"; import { ensureCustomApiRegistered } from "./custom-api-registry.js"; -import { createTransportAwareStreamFnForModel } from "./openai-transport-stream.js"; +import { createTransportAwareStreamFnForModel } from "./provider-transport-stream.js"; export function registerProviderStreamForModel(params: { model: Model; diff --git a/src/agents/provider-transport-fetch.ts b/src/agents/provider-transport-fetch.ts new file mode 100644 index 00000000000..15d9bc42abd --- /dev/null +++ b/src/agents/provider-transport-fetch.ts @@ -0,0 +1,100 @@ +import type { Api, Model } from "@mariozechner/pi-ai"; +import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js"; +import { + buildProviderRequestDispatcherPolicy, + getModelProviderRequestTransport, + resolveProviderRequestPolicyConfig, +} from "./provider-request-config.js"; + +function buildManagedResponse(response: Response, release: () => Promise): Response { + if (!response.body) { + void release(); + return response; + } + const source = response.body; + let reader: ReadableStreamDefaultReader | undefined; + let released = false; + const finalize = async () => { + if (released) { + return; + } + released = true; + await release().catch(() => undefined); + }; + const wrappedBody = new ReadableStream({ + start() { + reader = source.getReader(); + }, + async pull(controller) { + try { + const chunk = await reader?.read(); + if (!chunk || chunk.done) { + controller.close(); + await finalize(); + return; + } + controller.enqueue(chunk.value); + } catch (error) { + controller.error(error); + await finalize(); + } + }, + async cancel(reason) { + try { + await reader?.cancel(reason); + } finally { + await finalize(); + } + }, + }); + return new Response(wrappedBody, { + status: response.status, + statusText: response.statusText, + headers: response.headers, + }); +} + +function resolveModelRequestPolicy(model: Model) { + return resolveProviderRequestPolicyConfig({ + provider: model.provider, + api: model.api, + baseUrl: model.baseUrl, + capability: "llm", + transport: "stream", + request: getModelProviderRequestTransport(model), + }); +} + +export function buildGuardedModelFetch(model: Model): typeof fetch { + const requestConfig = resolveModelRequestPolicy(model); + const dispatcherPolicy = buildProviderRequestDispatcherPolicy(requestConfig); + return async (input, init) => { + const request = input instanceof Request ? new Request(input, init) : undefined; + const url = + request?.url ?? + (input instanceof URL + ? input.toString() + : typeof input === "string" + ? input + : (() => { + throw new Error("Unsupported fetch input for transport-aware model request"); + })()); + const requestInit = + request && + ({ + method: request.method, + headers: request.headers, + body: request.body ?? undefined, + redirect: request.redirect, + signal: request.signal, + ...(request.body ? ({ duplex: "half" } as const) : {}), + } satisfies RequestInit & { duplex?: "half" }); + const result = await fetchWithSsrFGuard({ + url, + init: requestInit ?? init, + dispatcherPolicy, + ...(requestConfig.allowPrivateNetwork ? { policy: { allowPrivateNetwork: true } } : {}), + }); + return buildManagedResponse(result.response, result.release); + }; +} diff --git a/src/agents/provider-transport-stream.ts b/src/agents/provider-transport-stream.ts new file mode 100644 index 00000000000..3a1de3f7e2b --- /dev/null +++ b/src/agents/provider-transport-stream.ts @@ -0,0 +1,75 @@ +import type { StreamFn } from "@mariozechner/pi-agent-core"; +import type { Api, Model } from "@mariozechner/pi-ai"; +import { createAnthropicMessagesTransportStreamFn } from "./anthropic-transport-stream.js"; +import { + createAzureOpenAIResponsesTransportStreamFn, + createOpenAICompletionsTransportStreamFn, + createOpenAIResponsesTransportStreamFn, +} from "./openai-transport-stream.js"; +import { getModelProviderRequestTransport } from "./provider-request-config.js"; + +const SUPPORTED_TRANSPORT_APIS = new Set([ + "openai-responses", + "openai-completions", + "azure-openai-responses", + "anthropic-messages", +]); + +const SIMPLE_TRANSPORT_API_ALIAS: Record = { + "openai-responses": "openclaw-openai-responses-transport", + "openai-completions": "openclaw-openai-completions-transport", + "azure-openai-responses": "openclaw-azure-openai-responses-transport", + "anthropic-messages": "openclaw-anthropic-messages-transport", +}; + +function hasTransportOverrides(model: Model): boolean { + const request = getModelProviderRequestTransport(model); + return Boolean(request?.proxy || request?.tls); +} + +export function isTransportAwareApiSupported(api: Api): boolean { + return SUPPORTED_TRANSPORT_APIS.has(api); +} + +export function resolveTransportAwareSimpleApi(api: Api): Api | undefined { + return SIMPLE_TRANSPORT_API_ALIAS[api]; +} + +export function createTransportAwareStreamFnForModel(model: Model): StreamFn | undefined { + if (!hasTransportOverrides(model)) { + return undefined; + } + if (!isTransportAwareApiSupported(model.api)) { + throw new Error( + `Model-provider request.proxy/request.tls is not yet supported for api "${model.api}"`, + ); + } + switch (model.api) { + case "openai-responses": + return createOpenAIResponsesTransportStreamFn(); + case "openai-completions": + return createOpenAICompletionsTransportStreamFn(); + case "azure-openai-responses": + return createAzureOpenAIResponsesTransportStreamFn(); + case "anthropic-messages": + return createAnthropicMessagesTransportStreamFn(); + default: + return undefined; + } +} + +export function prepareTransportAwareSimpleModel(model: Model): Model { + const streamFn = createTransportAwareStreamFnForModel(model as Model); + const alias = resolveTransportAwareSimpleApi(model.api); + if (!streamFn || !alias) { + return model; + } + return { + ...model, + api: alias, + }; +} + +export function buildTransportAwareSimpleStreamFn(model: Model): StreamFn | undefined { + return createTransportAwareStreamFnForModel(model); +} diff --git a/src/agents/simple-completion-transport.test.ts b/src/agents/simple-completion-transport.test.ts index 90cb9939a7d..4b3036e3d72 100644 --- a/src/agents/simple-completion-transport.test.ts +++ b/src/agents/simple-completion-transport.test.ts @@ -16,7 +16,7 @@ vi.mock("./custom-api-registry.js", () => ({ ensureCustomApiRegistered, })); -vi.mock("./openai-transport-stream.js", () => ({ +vi.mock("./provider-transport-stream.js", () => ({ buildTransportAwareSimpleStreamFn, prepareTransportAwareSimpleModel, })); diff --git a/src/agents/simple-completion-transport.ts b/src/agents/simple-completion-transport.ts index 3448a5f026f..f920a597831 100644 --- a/src/agents/simple-completion-transport.ts +++ b/src/agents/simple-completion-transport.ts @@ -2,11 +2,11 @@ import { getApiProvider, type Api, type Model } from "@mariozechner/pi-ai"; import type { OpenClawConfig } from "../config/config.js"; import { createAnthropicVertexStreamFnForModel } from "./anthropic-vertex-stream.js"; import { ensureCustomApiRegistered } from "./custom-api-registry.js"; +import { registerProviderStreamForModel } from "./provider-stream.js"; import { buildTransportAwareSimpleStreamFn, prepareTransportAwareSimpleModel, -} from "./openai-transport-stream.js"; -import { registerProviderStreamForModel } from "./provider-stream.js"; +} from "./provider-transport-stream.js"; function resolveAnthropicVertexSimpleApi(baseUrl?: string): Api { const suffix = baseUrl?.trim() ? encodeURIComponent(baseUrl.trim()) : "default"; diff --git a/src/agents/transport-message-transform.ts b/src/agents/transport-message-transform.ts new file mode 100644 index 00000000000..5760578d7cb --- /dev/null +++ b/src/agents/transport-message-transform.ts @@ -0,0 +1,131 @@ +import type { Api, Context, Model } from "@mariozechner/pi-ai"; + +export function transformTransportMessages( + messages: Context["messages"], + model: Model, + normalizeToolCallId?: ( + id: string, + targetModel: Model, + source: { provider: string; api: Api; model: string }, + ) => string, +): Context["messages"] { + const toolCallIdMap = new Map(); + const transformed = messages.map((msg) => { + if (msg.role === "user") { + return msg; + } + if (msg.role === "toolResult") { + const normalizedId = toolCallIdMap.get(msg.toolCallId); + return normalizedId && normalizedId !== msg.toolCallId + ? { ...msg, toolCallId: normalizedId } + : msg; + } + if (msg.role !== "assistant") { + return msg; + } + const isSameModel = + msg.provider === model.provider && msg.api === model.api && msg.model === model.id; + const content: typeof msg.content = []; + for (const block of msg.content) { + if (block.type === "thinking") { + if (block.redacted) { + if (isSameModel) { + content.push(block); + } + continue; + } + if (isSameModel && block.thinkingSignature) { + content.push(block); + continue; + } + if (!block.thinking.trim()) { + continue; + } + content.push(isSameModel ? block : { type: "text", text: block.thinking }); + continue; + } + if (block.type === "text") { + content.push(isSameModel ? block : { type: "text", text: block.text }); + continue; + } + if (block.type !== "toolCall") { + content.push(block); + continue; + } + let normalizedToolCall = block; + if (!isSameModel && block.thoughtSignature) { + normalizedToolCall = { ...normalizedToolCall }; + delete normalizedToolCall.thoughtSignature; + } + if (!isSameModel && normalizeToolCallId) { + const normalizedId = normalizeToolCallId(block.id, model, msg); + if (normalizedId !== block.id) { + toolCallIdMap.set(block.id, normalizedId); + normalizedToolCall = { ...normalizedToolCall, id: normalizedId }; + } + } + content.push(normalizedToolCall); + } + return { ...msg, content }; + }); + + const result: Context["messages"] = []; + let pendingToolCalls: Array<{ id: string; name: string }> = []; + let existingToolResultIds = new Set(); + for (const msg of transformed) { + if (msg.role === "assistant") { + if (pendingToolCalls.length > 0) { + for (const toolCall of pendingToolCalls) { + if (!existingToolResultIds.has(toolCall.id)) { + result.push({ + role: "toolResult", + toolCallId: toolCall.id, + toolName: toolCall.name, + content: [{ type: "text", text: "No result provided" }], + isError: true, + timestamp: Date.now(), + }); + } + } + pendingToolCalls = []; + existingToolResultIds = new Set(); + } + if (msg.stopReason === "error" || msg.stopReason === "aborted") { + continue; + } + const toolCalls = msg.content.filter( + (block): block is Extract<(typeof msg.content)[number], { type: "toolCall" }> => + block.type === "toolCall", + ); + if (toolCalls.length > 0) { + pendingToolCalls = toolCalls.map((block) => ({ id: block.id, name: block.name })); + existingToolResultIds = new Set(); + } + result.push(msg); + continue; + } + if (msg.role === "toolResult") { + existingToolResultIds.add(msg.toolCallId); + result.push(msg); + continue; + } + if (pendingToolCalls.length > 0) { + for (const toolCall of pendingToolCalls) { + if (!existingToolResultIds.has(toolCall.id)) { + result.push({ + role: "toolResult", + toolCallId: toolCall.id, + toolName: toolCall.name, + content: [{ type: "text", text: "No result provided" }], + isError: true, + timestamp: Date.now(), + }); + } + } + pendingToolCalls = []; + existingToolResultIds = new Set(); + } + result.push(msg); + } + return result; +}