refactor: share OpenAI tool schema normalization

This commit is contained in:
Peter Steinberger 2026-04-05 20:04:59 +01:00
parent 31016c5ed9
commit a9125ec0b0
No known key found for this signature in database
5 changed files with 265 additions and 145 deletions

View File

@ -0,0 +1,167 @@
import { normalizeToolParameterSchema } from "./pi-tools.schema.js";
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
type OpenAITransportKind = "stream" | "websocket";
type OpenAIStrictToolModel = {
provider?: unknown;
api?: unknown;
baseUrl?: unknown;
id?: unknown;
compat?: { supportsStore?: boolean };
};
type ToolWithParameters = {
parameters: unknown;
};
export function normalizeStrictOpenAIJsonSchema(schema: unknown): unknown {
return normalizeStrictOpenAIJsonSchemaRecursive(normalizeToolParameterSchema(schema ?? {}));
}
function normalizeStrictOpenAIJsonSchemaRecursive(schema: unknown): unknown {
if (Array.isArray(schema)) {
let changed = false;
const normalized = schema.map((entry) => {
const next = normalizeStrictOpenAIJsonSchemaRecursive(entry);
changed ||= next !== entry;
return next;
});
return changed ? normalized : schema;
}
if (!schema || typeof schema !== "object") {
return schema;
}
const record = schema as Record<string, unknown>;
let changed = false;
const normalized: Record<string, unknown> = {};
for (const [key, value] of Object.entries(record)) {
const next = normalizeStrictOpenAIJsonSchemaRecursive(value);
normalized[key] = next;
changed ||= next !== value;
}
if (normalized.type === "object") {
const properties =
normalized.properties &&
typeof normalized.properties === "object" &&
!Array.isArray(normalized.properties)
? (normalized.properties as Record<string, unknown>)
: undefined;
if (properties && Object.keys(properties).length === 0 && !Array.isArray(normalized.required)) {
normalized.required = [];
changed = true;
}
}
return changed ? normalized : schema;
}
export function normalizeOpenAIStrictToolParameters<T>(schema: T, strict: boolean): T {
if (!strict) {
return normalizeToolParameterSchema(schema ?? {}) as T;
}
return normalizeStrictOpenAIJsonSchema(schema) as T;
}
export function isStrictOpenAIJsonSchemaCompatible(schema: unknown): boolean {
return isStrictOpenAIJsonSchemaCompatibleRecursive(normalizeStrictOpenAIJsonSchema(schema));
}
function isStrictOpenAIJsonSchemaCompatibleRecursive(schema: unknown): boolean {
if (Array.isArray(schema)) {
return schema.every((entry) => isStrictOpenAIJsonSchemaCompatibleRecursive(entry));
}
if (!schema || typeof schema !== "object") {
return true;
}
const record = schema as Record<string, unknown>;
if ("anyOf" in record || "oneOf" in record || "allOf" in record) {
return false;
}
if (Array.isArray(record.type)) {
return false;
}
if (record.type === "object" && record.additionalProperties !== false) {
return false;
}
if (record.type === "object") {
const properties =
record.properties &&
typeof record.properties === "object" &&
!Array.isArray(record.properties)
? (record.properties as Record<string, unknown>)
: {};
const required = Array.isArray(record.required)
? record.required.filter((entry): entry is string => typeof entry === "string")
: undefined;
if (!required) {
return false;
}
const requiredSet = new Set(required);
if (Object.keys(properties).some((key) => !requiredSet.has(key))) {
return false;
}
}
return Object.entries(record).every(([key, entry]) => {
if (key === "properties" && entry && typeof entry === "object" && !Array.isArray(entry)) {
return Object.values(entry as Record<string, unknown>).every((value) =>
isStrictOpenAIJsonSchemaCompatibleRecursive(value),
);
}
return isStrictOpenAIJsonSchemaCompatibleRecursive(entry);
});
}
export function resolveOpenAIStrictToolFlagForInventory<T extends ToolWithParameters>(
tools: readonly T[],
strict: boolean | null | undefined,
): boolean | undefined {
if (strict !== true) {
return strict === false ? false : undefined;
}
return tools.every((tool) => isStrictOpenAIJsonSchemaCompatible(tool.parameters));
}
export function resolvesToNativeOpenAIStrictTools(
model: OpenAIStrictToolModel,
transport: OpenAITransportKind,
): boolean {
const capabilities = resolveProviderRequestCapabilities({
provider: model.provider,
api: model.api,
baseUrl: model.baseUrl,
capability: "llm",
transport,
modelId: model.id,
compat:
model.compat && typeof model.compat === "object"
? (model.compat as { supportsStore?: boolean })
: undefined,
});
if (!capabilities.usesKnownNativeOpenAIRoute) {
return false;
}
return (
capabilities.provider === "openai" ||
capabilities.provider === "openai-codex" ||
capabilities.provider === "azure-openai" ||
capabilities.provider === "azure-openai-responses"
);
}
export function resolveOpenAIStrictToolSetting(
model: OpenAIStrictToolModel,
options?: { transport?: OpenAITransportKind; supportsStrictMode?: boolean },
): boolean | undefined {
if (resolvesToNativeOpenAIStrictTools(model, options?.transport ?? "stream")) {
return true;
}
if (options?.supportsStrictMode) {
return false;
}
return undefined;
}

View File

@ -27,7 +27,11 @@ import {
applyOpenAIResponsesPayloadPolicy,
resolveOpenAIResponsesPayloadPolicy,
} from "./openai-responses-payload-policy.js";
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
import {
normalizeOpenAIStrictToolParameters,
resolveOpenAIStrictToolFlagForInventory,
resolveOpenAIStrictToolSetting,
} from "./openai-tool-schema.js";
import { buildGuardedModelFetch } from "./provider-transport-fetch.js";
import { stripSystemPromptCacheBoundary } from "./system-prompt-cache-boundary.js";
import { transformTransportMessages } from "./transport-message-transform.js";
@ -332,7 +336,7 @@ function convertResponsesTools(
tools: NonNullable<Context["tools"]>,
options?: { strict?: boolean | null },
): FunctionTool[] {
const strict = resolveStrictToolFlagForInventory(tools, options?.strict);
const strict = resolveOpenAIStrictToolFlagForInventory(tools, options?.strict);
if (strict === undefined) {
return tools.map((tool) => ({
type: "function",
@ -350,104 +354,6 @@ function convertResponsesTools(
}));
}
function normalizeOpenAIStrictToolParameters<T>(schema: T, strict: boolean): T {
if (!strict) {
return schema;
}
return normalizeStrictOpenAIJsonSchema(schema) as T;
}
function normalizeStrictOpenAIJsonSchema(schema: unknown): unknown {
if (Array.isArray(schema)) {
let changed = false;
const normalized = schema.map((entry) => {
const next = normalizeStrictOpenAIJsonSchema(entry);
changed ||= next !== entry;
return next;
});
return changed ? normalized : schema;
}
if (!schema || typeof schema !== "object") {
return schema;
}
const record = schema as Record<string, unknown>;
let changed = false;
const normalized: Record<string, unknown> = {};
for (const [key, value] of Object.entries(record)) {
const next = normalizeStrictOpenAIJsonSchema(value);
normalized[key] = next;
changed ||= next !== value;
}
if (normalized.type === "object") {
const properties =
normalized.properties &&
typeof normalized.properties === "object" &&
!Array.isArray(normalized.properties)
? (normalized.properties as Record<string, unknown>)
: undefined;
if (properties && Object.keys(properties).length === 0 && !Array.isArray(normalized.required)) {
normalized.required = [];
changed = true;
}
}
return changed ? normalized : schema;
}
function isStrictOpenAIJsonSchemaCompatible(schema: unknown): boolean {
if (Array.isArray(schema)) {
return schema.every((entry) => isStrictOpenAIJsonSchemaCompatible(entry));
}
if (!schema || typeof schema !== "object") {
return true;
}
const record = schema as Record<string, unknown>;
if ("anyOf" in record || "oneOf" in record || "allOf" in record) {
return false;
}
if (Array.isArray(record.type)) {
return false;
}
if (record.type === "object" && record.additionalProperties !== false) {
return false;
}
if (record.type === "object") {
const properties =
record.properties &&
typeof record.properties === "object" &&
!Array.isArray(record.properties)
? (record.properties as Record<string, unknown>)
: {};
const required = Array.isArray(record.required)
? record.required.filter((entry): entry is string => typeof entry === "string")
: undefined;
if (!required) {
return false;
}
const requiredSet = new Set(required);
if (Object.keys(properties).some((key) => !requiredSet.has(key))) {
return false;
}
}
return Object.values(record).every((entry) => isStrictOpenAIJsonSchemaCompatible(entry));
}
function resolveStrictToolFlagForInventory(
tools: NonNullable<Context["tools"]>,
strict: boolean | null | undefined,
): boolean | undefined {
if (strict !== true) {
return strict === false ? false : undefined;
}
return tools.every((tool) =>
isStrictOpenAIJsonSchemaCompatible(normalizeStrictOpenAIJsonSchema(tool.parameters)),
);
}
async function processResponsesStream(
openaiStream: AsyncIterable<unknown>,
output: MutableAssistantOutput,
@ -857,7 +763,9 @@ export function buildOpenAIResponsesParams(
}
if (context.tools) {
params.tools = convertResponsesTools(context.tools, {
strict: resolveOpenAIStrictToolSetting(model as OpenAIModeModel),
strict: resolveOpenAIStrictToolSetting(model as OpenAIModeModel, {
transport: "stream",
}),
});
}
if (model.reasoning) {
@ -1318,51 +1226,17 @@ function mapReasoningEffort(effort: string, reasoningEffortMap: Record<string, s
return reasoningEffortMap[effort] ?? effort;
}
function resolvesToNativeOpenAIStrictTools(model: OpenAIModeModel): boolean {
const capabilities = resolveProviderRequestCapabilities({
provider: model.provider,
api: model.api,
baseUrl: model.baseUrl,
capability: "llm",
transport: "stream",
modelId: model.id,
compat:
model.compat && typeof model.compat === "object"
? (model.compat as { supportsStore?: boolean })
: undefined,
});
if (!capabilities.usesKnownNativeOpenAIRoute) {
return false;
}
return (
capabilities.provider === "openai" ||
capabilities.provider === "openai-codex" ||
capabilities.provider === "azure-openai" ||
capabilities.provider === "azure-openai-responses"
);
}
function resolveOpenAIStrictToolSetting(
model: OpenAIModeModel,
compat?: ReturnType<typeof getCompat>,
): boolean | undefined {
if (resolvesToNativeOpenAIStrictTools(model)) {
return true;
}
if (compat?.supportsStrictMode) {
return false;
}
return undefined;
}
function convertTools(
tools: NonNullable<Context["tools"]>,
compat: ReturnType<typeof getCompat>,
model: OpenAIModeModel,
) {
const strict = resolveStrictToolFlagForInventory(
const strict = resolveOpenAIStrictToolFlagForInventory(
tools,
resolveOpenAIStrictToolSetting(model, compat),
resolveOpenAIStrictToolSetting(model, {
transport: "stream",
supportsStrictMode: compat?.supportsStrictMode,
}),
);
return tools.map((tool) => ({
type: "function",

View File

@ -1,6 +1,10 @@
import { randomUUID } from "node:crypto";
import type { Context, Message, StopReason } from "@mariozechner/pi-ai";
import type { AssistantMessage } from "@mariozechner/pi-ai";
import {
normalizeOpenAIStrictToolParameters,
resolveOpenAIStrictToolFlagForInventory,
} from "./openai-tool-schema.js";
import type {
ContentPart,
FunctionToolDefinition,
@ -8,7 +12,6 @@ import type {
OpenAIResponsesAssistantPhase,
ResponseObject,
} from "./openai-ws-connection.js";
import { normalizeToolParameterSchema } from "./pi-tools.schema.js";
import { buildAssistantMessage, buildUsageWithNoCost } from "./stream-message-shared.js";
import { normalizeUsage } from "./usage.js";
@ -274,16 +277,24 @@ function extractResponseReasoningText(item: unknown): string {
return typeof record.content === "string" ? record.content.trim() : "";
}
export function convertTools(tools: Context["tools"]): FunctionToolDefinition[] {
export function convertTools(
tools: Context["tools"],
options?: { strict?: boolean | null },
): FunctionToolDefinition[] {
if (!tools || tools.length === 0) {
return [];
}
const strict = resolveOpenAIStrictToolFlagForInventory(tools, options?.strict);
return tools.map((tool) => {
return {
type: "function" as const,
name: tool.name,
description: typeof tool.description === "string" ? tool.description : undefined,
parameters: normalizeToolParameterSchema(tool.parameters ?? {}) as Record<string, unknown>,
parameters: normalizeOpenAIStrictToolParameters(
tool.parameters ?? {},
strict === true,
) as Record<string, unknown>,
...(strict === undefined ? {} : { strict }),
};
});
}

View File

@ -503,6 +503,57 @@ describe("convertTools", () => {
properties: { cmd: { type: "string" } },
});
});
it("adds strict:true and required:[] for native strict-compatible no-param tools", () => {
const tools = [
{
name: "ping",
description: "No params",
parameters: { type: "object", properties: {}, additionalProperties: false },
},
];
const result = convertTools(tools as Parameters<typeof convertTools>[0], { strict: true });
expect(result[0]).toEqual({
type: "function",
name: "ping",
description: "No params",
parameters: {
type: "object",
properties: {},
additionalProperties: false,
required: [],
},
strict: true,
});
});
it("falls back to strict:false for native tools with non-strict-compatible schemas", () => {
const tools = [
{
name: "read",
description: "Read file",
parameters: {
type: "object",
properties: { path: { type: "string" } },
additionalProperties: false,
},
},
];
const result = convertTools(tools as Parameters<typeof convertTools>[0], { strict: true });
expect(result[0]).toEqual({
type: "function",
name: "read",
description: "Read file",
parameters: {
type: "object",
properties: { path: { type: "string" } },
additionalProperties: false,
},
strict: false,
});
});
});
// ─────────────────────────────────────────────────────────────────────────────

View File

@ -35,6 +35,7 @@ import {
resolveProviderWebSocketSessionPolicyWithPlugin,
} from "../plugins/provider-runtime.js";
import type { ProviderRuntimeModel, ProviderTransportTurnState } from "../plugins/types.js";
import { resolveOpenAIStrictToolSetting } from "./openai-tool-schema.js";
import {
getOpenAIWebSocketErrorDetails,
OpenAIWebSocketManager,
@ -78,6 +79,18 @@ interface WsSession {
degradeCooldownMs: number;
}
function resolveOpenAIWebSocketStrictToolSetting(
model: Parameters<StreamFn>[0],
): boolean | undefined {
return resolveOpenAIStrictToolSetting(model, {
transport: "websocket",
supportsStrictMode:
model.compat && typeof model.compat === "object"
? (model.compat as { supportsStrictMode?: boolean }).supportsStrictMode
: undefined,
});
}
/** Module-level registry: sessionId → WsSession */
const wsRegistry = new Map<string, WsSession>();
@ -747,7 +760,9 @@ export function createOpenAIWebSocketStreamFn(
await runWarmUp({
manager: session.manager,
modelId: model.id,
tools: convertTools(context.tools),
tools: convertTools(context.tools, {
strict: resolveOpenAIWebSocketStrictToolSetting(model),
}),
instructions: context.systemPrompt
? stripSystemPromptCacheBoundary(context.systemPrompt)
: undefined,
@ -842,7 +857,9 @@ export function createOpenAIWebSocketStreamFn(
context,
options: options as WsOptions | undefined,
turnInput,
tools: convertTools(context.tools),
tools: convertTools(context.tools, {
strict: resolveOpenAIWebSocketStrictToolSetting(model),
}),
metadata: turnState?.metadata,
}) as Record<string, unknown>;
const nextPayload = await options?.onPayload?.(payload, model);