From a9125ec0b07b61a5f646287cc8cf3cd89d2cd5f5 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sun, 5 Apr 2026 20:04:59 +0100 Subject: [PATCH] refactor: share OpenAI tool schema normalization --- src/agents/openai-tool-schema.ts | 167 +++++++++++++++++++++ src/agents/openai-transport-stream.ts | 154 ++----------------- src/agents/openai-ws-message-conversion.ts | 17 ++- src/agents/openai-ws-stream.test.ts | 51 +++++++ src/agents/openai-ws-stream.ts | 21 ++- 5 files changed, 265 insertions(+), 145 deletions(-) create mode 100644 src/agents/openai-tool-schema.ts diff --git a/src/agents/openai-tool-schema.ts b/src/agents/openai-tool-schema.ts new file mode 100644 index 00000000000..d3ae15315af --- /dev/null +++ b/src/agents/openai-tool-schema.ts @@ -0,0 +1,167 @@ +import { normalizeToolParameterSchema } from "./pi-tools.schema.js"; +import { resolveProviderRequestCapabilities } from "./provider-attribution.js"; + +type OpenAITransportKind = "stream" | "websocket"; + +type OpenAIStrictToolModel = { + provider?: unknown; + api?: unknown; + baseUrl?: unknown; + id?: unknown; + compat?: { supportsStore?: boolean }; +}; + +type ToolWithParameters = { + parameters: unknown; +}; + +export function normalizeStrictOpenAIJsonSchema(schema: unknown): unknown { + return normalizeStrictOpenAIJsonSchemaRecursive(normalizeToolParameterSchema(schema ?? {})); +} + +function normalizeStrictOpenAIJsonSchemaRecursive(schema: unknown): unknown { + if (Array.isArray(schema)) { + let changed = false; + const normalized = schema.map((entry) => { + const next = normalizeStrictOpenAIJsonSchemaRecursive(entry); + changed ||= next !== entry; + return next; + }); + return changed ? normalized : schema; + } + if (!schema || typeof schema !== "object") { + return schema; + } + + const record = schema as Record; + let changed = false; + const normalized: Record = {}; + for (const [key, value] of Object.entries(record)) { + const next = normalizeStrictOpenAIJsonSchemaRecursive(value); + normalized[key] = next; + changed ||= next !== value; + } + + if (normalized.type === "object") { + const properties = + normalized.properties && + typeof normalized.properties === "object" && + !Array.isArray(normalized.properties) + ? (normalized.properties as Record) + : undefined; + if (properties && Object.keys(properties).length === 0 && !Array.isArray(normalized.required)) { + normalized.required = []; + changed = true; + } + } + + return changed ? normalized : schema; +} + +export function normalizeOpenAIStrictToolParameters(schema: T, strict: boolean): T { + if (!strict) { + return normalizeToolParameterSchema(schema ?? {}) as T; + } + return normalizeStrictOpenAIJsonSchema(schema) as T; +} + +export function isStrictOpenAIJsonSchemaCompatible(schema: unknown): boolean { + return isStrictOpenAIJsonSchemaCompatibleRecursive(normalizeStrictOpenAIJsonSchema(schema)); +} + +function isStrictOpenAIJsonSchemaCompatibleRecursive(schema: unknown): boolean { + if (Array.isArray(schema)) { + return schema.every((entry) => isStrictOpenAIJsonSchemaCompatibleRecursive(entry)); + } + if (!schema || typeof schema !== "object") { + return true; + } + + const record = schema as Record; + if ("anyOf" in record || "oneOf" in record || "allOf" in record) { + return false; + } + if (Array.isArray(record.type)) { + return false; + } + if (record.type === "object" && record.additionalProperties !== false) { + return false; + } + if (record.type === "object") { + const properties = + record.properties && + typeof record.properties === "object" && + !Array.isArray(record.properties) + ? (record.properties as Record) + : {}; + const required = Array.isArray(record.required) + ? record.required.filter((entry): entry is string => typeof entry === "string") + : undefined; + if (!required) { + return false; + } + const requiredSet = new Set(required); + if (Object.keys(properties).some((key) => !requiredSet.has(key))) { + return false; + } + } + + return Object.entries(record).every(([key, entry]) => { + if (key === "properties" && entry && typeof entry === "object" && !Array.isArray(entry)) { + return Object.values(entry as Record).every((value) => + isStrictOpenAIJsonSchemaCompatibleRecursive(value), + ); + } + return isStrictOpenAIJsonSchemaCompatibleRecursive(entry); + }); +} + +export function resolveOpenAIStrictToolFlagForInventory( + tools: readonly T[], + strict: boolean | null | undefined, +): boolean | undefined { + if (strict !== true) { + return strict === false ? false : undefined; + } + return tools.every((tool) => isStrictOpenAIJsonSchemaCompatible(tool.parameters)); +} + +export function resolvesToNativeOpenAIStrictTools( + model: OpenAIStrictToolModel, + transport: OpenAITransportKind, +): boolean { + const capabilities = resolveProviderRequestCapabilities({ + provider: model.provider, + api: model.api, + baseUrl: model.baseUrl, + capability: "llm", + transport, + modelId: model.id, + compat: + model.compat && typeof model.compat === "object" + ? (model.compat as { supportsStore?: boolean }) + : undefined, + }); + if (!capabilities.usesKnownNativeOpenAIRoute) { + return false; + } + return ( + capabilities.provider === "openai" || + capabilities.provider === "openai-codex" || + capabilities.provider === "azure-openai" || + capabilities.provider === "azure-openai-responses" + ); +} + +export function resolveOpenAIStrictToolSetting( + model: OpenAIStrictToolModel, + options?: { transport?: OpenAITransportKind; supportsStrictMode?: boolean }, +): boolean | undefined { + if (resolvesToNativeOpenAIStrictTools(model, options?.transport ?? "stream")) { + return true; + } + if (options?.supportsStrictMode) { + return false; + } + return undefined; +} diff --git a/src/agents/openai-transport-stream.ts b/src/agents/openai-transport-stream.ts index 0ee3617fff9..90937082c7f 100644 --- a/src/agents/openai-transport-stream.ts +++ b/src/agents/openai-transport-stream.ts @@ -27,7 +27,11 @@ import { applyOpenAIResponsesPayloadPolicy, resolveOpenAIResponsesPayloadPolicy, } from "./openai-responses-payload-policy.js"; -import { resolveProviderRequestCapabilities } from "./provider-attribution.js"; +import { + normalizeOpenAIStrictToolParameters, + resolveOpenAIStrictToolFlagForInventory, + resolveOpenAIStrictToolSetting, +} from "./openai-tool-schema.js"; import { buildGuardedModelFetch } from "./provider-transport-fetch.js"; import { stripSystemPromptCacheBoundary } from "./system-prompt-cache-boundary.js"; import { transformTransportMessages } from "./transport-message-transform.js"; @@ -332,7 +336,7 @@ function convertResponsesTools( tools: NonNullable, options?: { strict?: boolean | null }, ): FunctionTool[] { - const strict = resolveStrictToolFlagForInventory(tools, options?.strict); + const strict = resolveOpenAIStrictToolFlagForInventory(tools, options?.strict); if (strict === undefined) { return tools.map((tool) => ({ type: "function", @@ -350,104 +354,6 @@ function convertResponsesTools( })); } -function normalizeOpenAIStrictToolParameters(schema: T, strict: boolean): T { - if (!strict) { - return schema; - } - return normalizeStrictOpenAIJsonSchema(schema) as T; -} - -function normalizeStrictOpenAIJsonSchema(schema: unknown): unknown { - if (Array.isArray(schema)) { - let changed = false; - const normalized = schema.map((entry) => { - const next = normalizeStrictOpenAIJsonSchema(entry); - changed ||= next !== entry; - return next; - }); - return changed ? normalized : schema; - } - if (!schema || typeof schema !== "object") { - return schema; - } - - const record = schema as Record; - let changed = false; - const normalized: Record = {}; - for (const [key, value] of Object.entries(record)) { - const next = normalizeStrictOpenAIJsonSchema(value); - normalized[key] = next; - changed ||= next !== value; - } - - if (normalized.type === "object") { - const properties = - normalized.properties && - typeof normalized.properties === "object" && - !Array.isArray(normalized.properties) - ? (normalized.properties as Record) - : undefined; - if (properties && Object.keys(properties).length === 0 && !Array.isArray(normalized.required)) { - normalized.required = []; - changed = true; - } - } - - return changed ? normalized : schema; -} - -function isStrictOpenAIJsonSchemaCompatible(schema: unknown): boolean { - if (Array.isArray(schema)) { - return schema.every((entry) => isStrictOpenAIJsonSchemaCompatible(entry)); - } - if (!schema || typeof schema !== "object") { - return true; - } - - const record = schema as Record; - if ("anyOf" in record || "oneOf" in record || "allOf" in record) { - return false; - } - if (Array.isArray(record.type)) { - return false; - } - if (record.type === "object" && record.additionalProperties !== false) { - return false; - } - if (record.type === "object") { - const properties = - record.properties && - typeof record.properties === "object" && - !Array.isArray(record.properties) - ? (record.properties as Record) - : {}; - const required = Array.isArray(record.required) - ? record.required.filter((entry): entry is string => typeof entry === "string") - : undefined; - if (!required) { - return false; - } - const requiredSet = new Set(required); - if (Object.keys(properties).some((key) => !requiredSet.has(key))) { - return false; - } - } - - return Object.values(record).every((entry) => isStrictOpenAIJsonSchemaCompatible(entry)); -} - -function resolveStrictToolFlagForInventory( - tools: NonNullable, - strict: boolean | null | undefined, -): boolean | undefined { - if (strict !== true) { - return strict === false ? false : undefined; - } - return tools.every((tool) => - isStrictOpenAIJsonSchemaCompatible(normalizeStrictOpenAIJsonSchema(tool.parameters)), - ); -} - async function processResponsesStream( openaiStream: AsyncIterable, output: MutableAssistantOutput, @@ -857,7 +763,9 @@ export function buildOpenAIResponsesParams( } if (context.tools) { params.tools = convertResponsesTools(context.tools, { - strict: resolveOpenAIStrictToolSetting(model as OpenAIModeModel), + strict: resolveOpenAIStrictToolSetting(model as OpenAIModeModel, { + transport: "stream", + }), }); } if (model.reasoning) { @@ -1318,51 +1226,17 @@ function mapReasoningEffort(effort: string, reasoningEffortMap: Record, -): boolean | undefined { - if (resolvesToNativeOpenAIStrictTools(model)) { - return true; - } - if (compat?.supportsStrictMode) { - return false; - } - return undefined; -} - function convertTools( tools: NonNullable, compat: ReturnType, model: OpenAIModeModel, ) { - const strict = resolveStrictToolFlagForInventory( + const strict = resolveOpenAIStrictToolFlagForInventory( tools, - resolveOpenAIStrictToolSetting(model, compat), + resolveOpenAIStrictToolSetting(model, { + transport: "stream", + supportsStrictMode: compat?.supportsStrictMode, + }), ); return tools.map((tool) => ({ type: "function", diff --git a/src/agents/openai-ws-message-conversion.ts b/src/agents/openai-ws-message-conversion.ts index fa2f7961102..9a2e997a715 100644 --- a/src/agents/openai-ws-message-conversion.ts +++ b/src/agents/openai-ws-message-conversion.ts @@ -1,6 +1,10 @@ import { randomUUID } from "node:crypto"; import type { Context, Message, StopReason } from "@mariozechner/pi-ai"; import type { AssistantMessage } from "@mariozechner/pi-ai"; +import { + normalizeOpenAIStrictToolParameters, + resolveOpenAIStrictToolFlagForInventory, +} from "./openai-tool-schema.js"; import type { ContentPart, FunctionToolDefinition, @@ -8,7 +12,6 @@ import type { OpenAIResponsesAssistantPhase, ResponseObject, } from "./openai-ws-connection.js"; -import { normalizeToolParameterSchema } from "./pi-tools.schema.js"; import { buildAssistantMessage, buildUsageWithNoCost } from "./stream-message-shared.js"; import { normalizeUsage } from "./usage.js"; @@ -274,16 +277,24 @@ function extractResponseReasoningText(item: unknown): string { return typeof record.content === "string" ? record.content.trim() : ""; } -export function convertTools(tools: Context["tools"]): FunctionToolDefinition[] { +export function convertTools( + tools: Context["tools"], + options?: { strict?: boolean | null }, +): FunctionToolDefinition[] { if (!tools || tools.length === 0) { return []; } + const strict = resolveOpenAIStrictToolFlagForInventory(tools, options?.strict); return tools.map((tool) => { return { type: "function" as const, name: tool.name, description: typeof tool.description === "string" ? tool.description : undefined, - parameters: normalizeToolParameterSchema(tool.parameters ?? {}) as Record, + parameters: normalizeOpenAIStrictToolParameters( + tool.parameters ?? {}, + strict === true, + ) as Record, + ...(strict === undefined ? {} : { strict }), }; }); } diff --git a/src/agents/openai-ws-stream.test.ts b/src/agents/openai-ws-stream.test.ts index b9852e7f5f0..14a05d02e19 100644 --- a/src/agents/openai-ws-stream.test.ts +++ b/src/agents/openai-ws-stream.test.ts @@ -503,6 +503,57 @@ describe("convertTools", () => { properties: { cmd: { type: "string" } }, }); }); + + it("adds strict:true and required:[] for native strict-compatible no-param tools", () => { + const tools = [ + { + name: "ping", + description: "No params", + parameters: { type: "object", properties: {}, additionalProperties: false }, + }, + ]; + const result = convertTools(tools as Parameters[0], { strict: true }); + + expect(result[0]).toEqual({ + type: "function", + name: "ping", + description: "No params", + parameters: { + type: "object", + properties: {}, + additionalProperties: false, + required: [], + }, + strict: true, + }); + }); + + it("falls back to strict:false for native tools with non-strict-compatible schemas", () => { + const tools = [ + { + name: "read", + description: "Read file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + additionalProperties: false, + }, + }, + ]; + const result = convertTools(tools as Parameters[0], { strict: true }); + + expect(result[0]).toEqual({ + type: "function", + name: "read", + description: "Read file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + additionalProperties: false, + }, + strict: false, + }); + }); }); // ───────────────────────────────────────────────────────────────────────────── diff --git a/src/agents/openai-ws-stream.ts b/src/agents/openai-ws-stream.ts index c9e5fd80f3b..4e453fdeef1 100644 --- a/src/agents/openai-ws-stream.ts +++ b/src/agents/openai-ws-stream.ts @@ -35,6 +35,7 @@ import { resolveProviderWebSocketSessionPolicyWithPlugin, } from "../plugins/provider-runtime.js"; import type { ProviderRuntimeModel, ProviderTransportTurnState } from "../plugins/types.js"; +import { resolveOpenAIStrictToolSetting } from "./openai-tool-schema.js"; import { getOpenAIWebSocketErrorDetails, OpenAIWebSocketManager, @@ -78,6 +79,18 @@ interface WsSession { degradeCooldownMs: number; } +function resolveOpenAIWebSocketStrictToolSetting( + model: Parameters[0], +): boolean | undefined { + return resolveOpenAIStrictToolSetting(model, { + transport: "websocket", + supportsStrictMode: + model.compat && typeof model.compat === "object" + ? (model.compat as { supportsStrictMode?: boolean }).supportsStrictMode + : undefined, + }); +} + /** Module-level registry: sessionId → WsSession */ const wsRegistry = new Map(); @@ -747,7 +760,9 @@ export function createOpenAIWebSocketStreamFn( await runWarmUp({ manager: session.manager, modelId: model.id, - tools: convertTools(context.tools), + tools: convertTools(context.tools, { + strict: resolveOpenAIWebSocketStrictToolSetting(model), + }), instructions: context.systemPrompt ? stripSystemPromptCacheBoundary(context.systemPrompt) : undefined, @@ -842,7 +857,9 @@ export function createOpenAIWebSocketStreamFn( context, options: options as WsOptions | undefined, turnInput, - tools: convertTools(context.tools), + tools: convertTools(context.tools, { + strict: resolveOpenAIWebSocketStrictToolSetting(model), + }), metadata: turnState?.metadata, }) as Record; const nextPayload = await options?.onPayload?.(payload, model);