mirror of https://github.com/openclaw/openclaw.git
refactor: share OpenAI tool schema normalization
This commit is contained in:
parent
31016c5ed9
commit
a9125ec0b0
|
|
@ -0,0 +1,167 @@
|
|||
import { normalizeToolParameterSchema } from "./pi-tools.schema.js";
|
||||
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
|
||||
|
||||
type OpenAITransportKind = "stream" | "websocket";
|
||||
|
||||
type OpenAIStrictToolModel = {
|
||||
provider?: unknown;
|
||||
api?: unknown;
|
||||
baseUrl?: unknown;
|
||||
id?: unknown;
|
||||
compat?: { supportsStore?: boolean };
|
||||
};
|
||||
|
||||
type ToolWithParameters = {
|
||||
parameters: unknown;
|
||||
};
|
||||
|
||||
export function normalizeStrictOpenAIJsonSchema(schema: unknown): unknown {
|
||||
return normalizeStrictOpenAIJsonSchemaRecursive(normalizeToolParameterSchema(schema ?? {}));
|
||||
}
|
||||
|
||||
function normalizeStrictOpenAIJsonSchemaRecursive(schema: unknown): unknown {
|
||||
if (Array.isArray(schema)) {
|
||||
let changed = false;
|
||||
const normalized = schema.map((entry) => {
|
||||
const next = normalizeStrictOpenAIJsonSchemaRecursive(entry);
|
||||
changed ||= next !== entry;
|
||||
return next;
|
||||
});
|
||||
return changed ? normalized : schema;
|
||||
}
|
||||
if (!schema || typeof schema !== "object") {
|
||||
return schema;
|
||||
}
|
||||
|
||||
const record = schema as Record<string, unknown>;
|
||||
let changed = false;
|
||||
const normalized: Record<string, unknown> = {};
|
||||
for (const [key, value] of Object.entries(record)) {
|
||||
const next = normalizeStrictOpenAIJsonSchemaRecursive(value);
|
||||
normalized[key] = next;
|
||||
changed ||= next !== value;
|
||||
}
|
||||
|
||||
if (normalized.type === "object") {
|
||||
const properties =
|
||||
normalized.properties &&
|
||||
typeof normalized.properties === "object" &&
|
||||
!Array.isArray(normalized.properties)
|
||||
? (normalized.properties as Record<string, unknown>)
|
||||
: undefined;
|
||||
if (properties && Object.keys(properties).length === 0 && !Array.isArray(normalized.required)) {
|
||||
normalized.required = [];
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
return changed ? normalized : schema;
|
||||
}
|
||||
|
||||
export function normalizeOpenAIStrictToolParameters<T>(schema: T, strict: boolean): T {
|
||||
if (!strict) {
|
||||
return normalizeToolParameterSchema(schema ?? {}) as T;
|
||||
}
|
||||
return normalizeStrictOpenAIJsonSchema(schema) as T;
|
||||
}
|
||||
|
||||
export function isStrictOpenAIJsonSchemaCompatible(schema: unknown): boolean {
|
||||
return isStrictOpenAIJsonSchemaCompatibleRecursive(normalizeStrictOpenAIJsonSchema(schema));
|
||||
}
|
||||
|
||||
function isStrictOpenAIJsonSchemaCompatibleRecursive(schema: unknown): boolean {
|
||||
if (Array.isArray(schema)) {
|
||||
return schema.every((entry) => isStrictOpenAIJsonSchemaCompatibleRecursive(entry));
|
||||
}
|
||||
if (!schema || typeof schema !== "object") {
|
||||
return true;
|
||||
}
|
||||
|
||||
const record = schema as Record<string, unknown>;
|
||||
if ("anyOf" in record || "oneOf" in record || "allOf" in record) {
|
||||
return false;
|
||||
}
|
||||
if (Array.isArray(record.type)) {
|
||||
return false;
|
||||
}
|
||||
if (record.type === "object" && record.additionalProperties !== false) {
|
||||
return false;
|
||||
}
|
||||
if (record.type === "object") {
|
||||
const properties =
|
||||
record.properties &&
|
||||
typeof record.properties === "object" &&
|
||||
!Array.isArray(record.properties)
|
||||
? (record.properties as Record<string, unknown>)
|
||||
: {};
|
||||
const required = Array.isArray(record.required)
|
||||
? record.required.filter((entry): entry is string => typeof entry === "string")
|
||||
: undefined;
|
||||
if (!required) {
|
||||
return false;
|
||||
}
|
||||
const requiredSet = new Set(required);
|
||||
if (Object.keys(properties).some((key) => !requiredSet.has(key))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return Object.entries(record).every(([key, entry]) => {
|
||||
if (key === "properties" && entry && typeof entry === "object" && !Array.isArray(entry)) {
|
||||
return Object.values(entry as Record<string, unknown>).every((value) =>
|
||||
isStrictOpenAIJsonSchemaCompatibleRecursive(value),
|
||||
);
|
||||
}
|
||||
return isStrictOpenAIJsonSchemaCompatibleRecursive(entry);
|
||||
});
|
||||
}
|
||||
|
||||
export function resolveOpenAIStrictToolFlagForInventory<T extends ToolWithParameters>(
|
||||
tools: readonly T[],
|
||||
strict: boolean | null | undefined,
|
||||
): boolean | undefined {
|
||||
if (strict !== true) {
|
||||
return strict === false ? false : undefined;
|
||||
}
|
||||
return tools.every((tool) => isStrictOpenAIJsonSchemaCompatible(tool.parameters));
|
||||
}
|
||||
|
||||
export function resolvesToNativeOpenAIStrictTools(
|
||||
model: OpenAIStrictToolModel,
|
||||
transport: OpenAITransportKind,
|
||||
): boolean {
|
||||
const capabilities = resolveProviderRequestCapabilities({
|
||||
provider: model.provider,
|
||||
api: model.api,
|
||||
baseUrl: model.baseUrl,
|
||||
capability: "llm",
|
||||
transport,
|
||||
modelId: model.id,
|
||||
compat:
|
||||
model.compat && typeof model.compat === "object"
|
||||
? (model.compat as { supportsStore?: boolean })
|
||||
: undefined,
|
||||
});
|
||||
if (!capabilities.usesKnownNativeOpenAIRoute) {
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
capabilities.provider === "openai" ||
|
||||
capabilities.provider === "openai-codex" ||
|
||||
capabilities.provider === "azure-openai" ||
|
||||
capabilities.provider === "azure-openai-responses"
|
||||
);
|
||||
}
|
||||
|
||||
export function resolveOpenAIStrictToolSetting(
|
||||
model: OpenAIStrictToolModel,
|
||||
options?: { transport?: OpenAITransportKind; supportsStrictMode?: boolean },
|
||||
): boolean | undefined {
|
||||
if (resolvesToNativeOpenAIStrictTools(model, options?.transport ?? "stream")) {
|
||||
return true;
|
||||
}
|
||||
if (options?.supportsStrictMode) {
|
||||
return false;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
|
@ -27,7 +27,11 @@ import {
|
|||
applyOpenAIResponsesPayloadPolicy,
|
||||
resolveOpenAIResponsesPayloadPolicy,
|
||||
} from "./openai-responses-payload-policy.js";
|
||||
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
|
||||
import {
|
||||
normalizeOpenAIStrictToolParameters,
|
||||
resolveOpenAIStrictToolFlagForInventory,
|
||||
resolveOpenAIStrictToolSetting,
|
||||
} from "./openai-tool-schema.js";
|
||||
import { buildGuardedModelFetch } from "./provider-transport-fetch.js";
|
||||
import { stripSystemPromptCacheBoundary } from "./system-prompt-cache-boundary.js";
|
||||
import { transformTransportMessages } from "./transport-message-transform.js";
|
||||
|
|
@ -332,7 +336,7 @@ function convertResponsesTools(
|
|||
tools: NonNullable<Context["tools"]>,
|
||||
options?: { strict?: boolean | null },
|
||||
): FunctionTool[] {
|
||||
const strict = resolveStrictToolFlagForInventory(tools, options?.strict);
|
||||
const strict = resolveOpenAIStrictToolFlagForInventory(tools, options?.strict);
|
||||
if (strict === undefined) {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
|
|
@ -350,104 +354,6 @@ function convertResponsesTools(
|
|||
}));
|
||||
}
|
||||
|
||||
function normalizeOpenAIStrictToolParameters<T>(schema: T, strict: boolean): T {
|
||||
if (!strict) {
|
||||
return schema;
|
||||
}
|
||||
return normalizeStrictOpenAIJsonSchema(schema) as T;
|
||||
}
|
||||
|
||||
function normalizeStrictOpenAIJsonSchema(schema: unknown): unknown {
|
||||
if (Array.isArray(schema)) {
|
||||
let changed = false;
|
||||
const normalized = schema.map((entry) => {
|
||||
const next = normalizeStrictOpenAIJsonSchema(entry);
|
||||
changed ||= next !== entry;
|
||||
return next;
|
||||
});
|
||||
return changed ? normalized : schema;
|
||||
}
|
||||
if (!schema || typeof schema !== "object") {
|
||||
return schema;
|
||||
}
|
||||
|
||||
const record = schema as Record<string, unknown>;
|
||||
let changed = false;
|
||||
const normalized: Record<string, unknown> = {};
|
||||
for (const [key, value] of Object.entries(record)) {
|
||||
const next = normalizeStrictOpenAIJsonSchema(value);
|
||||
normalized[key] = next;
|
||||
changed ||= next !== value;
|
||||
}
|
||||
|
||||
if (normalized.type === "object") {
|
||||
const properties =
|
||||
normalized.properties &&
|
||||
typeof normalized.properties === "object" &&
|
||||
!Array.isArray(normalized.properties)
|
||||
? (normalized.properties as Record<string, unknown>)
|
||||
: undefined;
|
||||
if (properties && Object.keys(properties).length === 0 && !Array.isArray(normalized.required)) {
|
||||
normalized.required = [];
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
return changed ? normalized : schema;
|
||||
}
|
||||
|
||||
function isStrictOpenAIJsonSchemaCompatible(schema: unknown): boolean {
|
||||
if (Array.isArray(schema)) {
|
||||
return schema.every((entry) => isStrictOpenAIJsonSchemaCompatible(entry));
|
||||
}
|
||||
if (!schema || typeof schema !== "object") {
|
||||
return true;
|
||||
}
|
||||
|
||||
const record = schema as Record<string, unknown>;
|
||||
if ("anyOf" in record || "oneOf" in record || "allOf" in record) {
|
||||
return false;
|
||||
}
|
||||
if (Array.isArray(record.type)) {
|
||||
return false;
|
||||
}
|
||||
if (record.type === "object" && record.additionalProperties !== false) {
|
||||
return false;
|
||||
}
|
||||
if (record.type === "object") {
|
||||
const properties =
|
||||
record.properties &&
|
||||
typeof record.properties === "object" &&
|
||||
!Array.isArray(record.properties)
|
||||
? (record.properties as Record<string, unknown>)
|
||||
: {};
|
||||
const required = Array.isArray(record.required)
|
||||
? record.required.filter((entry): entry is string => typeof entry === "string")
|
||||
: undefined;
|
||||
if (!required) {
|
||||
return false;
|
||||
}
|
||||
const requiredSet = new Set(required);
|
||||
if (Object.keys(properties).some((key) => !requiredSet.has(key))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return Object.values(record).every((entry) => isStrictOpenAIJsonSchemaCompatible(entry));
|
||||
}
|
||||
|
||||
function resolveStrictToolFlagForInventory(
|
||||
tools: NonNullable<Context["tools"]>,
|
||||
strict: boolean | null | undefined,
|
||||
): boolean | undefined {
|
||||
if (strict !== true) {
|
||||
return strict === false ? false : undefined;
|
||||
}
|
||||
return tools.every((tool) =>
|
||||
isStrictOpenAIJsonSchemaCompatible(normalizeStrictOpenAIJsonSchema(tool.parameters)),
|
||||
);
|
||||
}
|
||||
|
||||
async function processResponsesStream(
|
||||
openaiStream: AsyncIterable<unknown>,
|
||||
output: MutableAssistantOutput,
|
||||
|
|
@ -857,7 +763,9 @@ export function buildOpenAIResponsesParams(
|
|||
}
|
||||
if (context.tools) {
|
||||
params.tools = convertResponsesTools(context.tools, {
|
||||
strict: resolveOpenAIStrictToolSetting(model as OpenAIModeModel),
|
||||
strict: resolveOpenAIStrictToolSetting(model as OpenAIModeModel, {
|
||||
transport: "stream",
|
||||
}),
|
||||
});
|
||||
}
|
||||
if (model.reasoning) {
|
||||
|
|
@ -1318,51 +1226,17 @@ function mapReasoningEffort(effort: string, reasoningEffortMap: Record<string, s
|
|||
return reasoningEffortMap[effort] ?? effort;
|
||||
}
|
||||
|
||||
function resolvesToNativeOpenAIStrictTools(model: OpenAIModeModel): boolean {
|
||||
const capabilities = resolveProviderRequestCapabilities({
|
||||
provider: model.provider,
|
||||
api: model.api,
|
||||
baseUrl: model.baseUrl,
|
||||
capability: "llm",
|
||||
transport: "stream",
|
||||
modelId: model.id,
|
||||
compat:
|
||||
model.compat && typeof model.compat === "object"
|
||||
? (model.compat as { supportsStore?: boolean })
|
||||
: undefined,
|
||||
});
|
||||
if (!capabilities.usesKnownNativeOpenAIRoute) {
|
||||
return false;
|
||||
}
|
||||
return (
|
||||
capabilities.provider === "openai" ||
|
||||
capabilities.provider === "openai-codex" ||
|
||||
capabilities.provider === "azure-openai" ||
|
||||
capabilities.provider === "azure-openai-responses"
|
||||
);
|
||||
}
|
||||
|
||||
function resolveOpenAIStrictToolSetting(
|
||||
model: OpenAIModeModel,
|
||||
compat?: ReturnType<typeof getCompat>,
|
||||
): boolean | undefined {
|
||||
if (resolvesToNativeOpenAIStrictTools(model)) {
|
||||
return true;
|
||||
}
|
||||
if (compat?.supportsStrictMode) {
|
||||
return false;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function convertTools(
|
||||
tools: NonNullable<Context["tools"]>,
|
||||
compat: ReturnType<typeof getCompat>,
|
||||
model: OpenAIModeModel,
|
||||
) {
|
||||
const strict = resolveStrictToolFlagForInventory(
|
||||
const strict = resolveOpenAIStrictToolFlagForInventory(
|
||||
tools,
|
||||
resolveOpenAIStrictToolSetting(model, compat),
|
||||
resolveOpenAIStrictToolSetting(model, {
|
||||
transport: "stream",
|
||||
supportsStrictMode: compat?.supportsStrictMode,
|
||||
}),
|
||||
);
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
import { randomUUID } from "node:crypto";
|
||||
import type { Context, Message, StopReason } from "@mariozechner/pi-ai";
|
||||
import type { AssistantMessage } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
normalizeOpenAIStrictToolParameters,
|
||||
resolveOpenAIStrictToolFlagForInventory,
|
||||
} from "./openai-tool-schema.js";
|
||||
import type {
|
||||
ContentPart,
|
||||
FunctionToolDefinition,
|
||||
|
|
@ -8,7 +12,6 @@ import type {
|
|||
OpenAIResponsesAssistantPhase,
|
||||
ResponseObject,
|
||||
} from "./openai-ws-connection.js";
|
||||
import { normalizeToolParameterSchema } from "./pi-tools.schema.js";
|
||||
import { buildAssistantMessage, buildUsageWithNoCost } from "./stream-message-shared.js";
|
||||
import { normalizeUsage } from "./usage.js";
|
||||
|
||||
|
|
@ -274,16 +277,24 @@ function extractResponseReasoningText(item: unknown): string {
|
|||
return typeof record.content === "string" ? record.content.trim() : "";
|
||||
}
|
||||
|
||||
export function convertTools(tools: Context["tools"]): FunctionToolDefinition[] {
|
||||
export function convertTools(
|
||||
tools: Context["tools"],
|
||||
options?: { strict?: boolean | null },
|
||||
): FunctionToolDefinition[] {
|
||||
if (!tools || tools.length === 0) {
|
||||
return [];
|
||||
}
|
||||
const strict = resolveOpenAIStrictToolFlagForInventory(tools, options?.strict);
|
||||
return tools.map((tool) => {
|
||||
return {
|
||||
type: "function" as const,
|
||||
name: tool.name,
|
||||
description: typeof tool.description === "string" ? tool.description : undefined,
|
||||
parameters: normalizeToolParameterSchema(tool.parameters ?? {}) as Record<string, unknown>,
|
||||
parameters: normalizeOpenAIStrictToolParameters(
|
||||
tool.parameters ?? {},
|
||||
strict === true,
|
||||
) as Record<string, unknown>,
|
||||
...(strict === undefined ? {} : { strict }),
|
||||
};
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -503,6 +503,57 @@ describe("convertTools", () => {
|
|||
properties: { cmd: { type: "string" } },
|
||||
});
|
||||
});
|
||||
|
||||
it("adds strict:true and required:[] for native strict-compatible no-param tools", () => {
|
||||
const tools = [
|
||||
{
|
||||
name: "ping",
|
||||
description: "No params",
|
||||
parameters: { type: "object", properties: {}, additionalProperties: false },
|
||||
},
|
||||
];
|
||||
const result = convertTools(tools as Parameters<typeof convertTools>[0], { strict: true });
|
||||
|
||||
expect(result[0]).toEqual({
|
||||
type: "function",
|
||||
name: "ping",
|
||||
description: "No params",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: {},
|
||||
additionalProperties: false,
|
||||
required: [],
|
||||
},
|
||||
strict: true,
|
||||
});
|
||||
});
|
||||
|
||||
it("falls back to strict:false for native tools with non-strict-compatible schemas", () => {
|
||||
const tools = [
|
||||
{
|
||||
name: "read",
|
||||
description: "Read file",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: { path: { type: "string" } },
|
||||
additionalProperties: false,
|
||||
},
|
||||
},
|
||||
];
|
||||
const result = convertTools(tools as Parameters<typeof convertTools>[0], { strict: true });
|
||||
|
||||
expect(result[0]).toEqual({
|
||||
type: "function",
|
||||
name: "read",
|
||||
description: "Read file",
|
||||
parameters: {
|
||||
type: "object",
|
||||
properties: { path: { type: "string" } },
|
||||
additionalProperties: false,
|
||||
},
|
||||
strict: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ import {
|
|||
resolveProviderWebSocketSessionPolicyWithPlugin,
|
||||
} from "../plugins/provider-runtime.js";
|
||||
import type { ProviderRuntimeModel, ProviderTransportTurnState } from "../plugins/types.js";
|
||||
import { resolveOpenAIStrictToolSetting } from "./openai-tool-schema.js";
|
||||
import {
|
||||
getOpenAIWebSocketErrorDetails,
|
||||
OpenAIWebSocketManager,
|
||||
|
|
@ -78,6 +79,18 @@ interface WsSession {
|
|||
degradeCooldownMs: number;
|
||||
}
|
||||
|
||||
function resolveOpenAIWebSocketStrictToolSetting(
|
||||
model: Parameters<StreamFn>[0],
|
||||
): boolean | undefined {
|
||||
return resolveOpenAIStrictToolSetting(model, {
|
||||
transport: "websocket",
|
||||
supportsStrictMode:
|
||||
model.compat && typeof model.compat === "object"
|
||||
? (model.compat as { supportsStrictMode?: boolean }).supportsStrictMode
|
||||
: undefined,
|
||||
});
|
||||
}
|
||||
|
||||
/** Module-level registry: sessionId → WsSession */
|
||||
const wsRegistry = new Map<string, WsSession>();
|
||||
|
||||
|
|
@ -747,7 +760,9 @@ export function createOpenAIWebSocketStreamFn(
|
|||
await runWarmUp({
|
||||
manager: session.manager,
|
||||
modelId: model.id,
|
||||
tools: convertTools(context.tools),
|
||||
tools: convertTools(context.tools, {
|
||||
strict: resolveOpenAIWebSocketStrictToolSetting(model),
|
||||
}),
|
||||
instructions: context.systemPrompt
|
||||
? stripSystemPromptCacheBoundary(context.systemPrompt)
|
||||
: undefined,
|
||||
|
|
@ -842,7 +857,9 @@ export function createOpenAIWebSocketStreamFn(
|
|||
context,
|
||||
options: options as WsOptions | undefined,
|
||||
turnInput,
|
||||
tools: convertTools(context.tools),
|
||||
tools: convertTools(context.tools, {
|
||||
strict: resolveOpenAIWebSocketStrictToolSetting(model),
|
||||
}),
|
||||
metadata: turnState?.metadata,
|
||||
}) as Record<string, unknown>;
|
||||
const nextPayload = await options?.onPayload?.(payload, model);
|
||||
|
|
|
|||
Loading…
Reference in New Issue