mirror of https://github.com/openclaw/openclaw.git
fix: apply Mistral compat across proxy transports
This commit is contained in:
parent
4a5885df3a
commit
c664b67796
|
|
@ -12,24 +12,32 @@ export {
|
|||
|
||||
const MISTRAL_MAX_TOKENS_FIELD = "max_tokens";
|
||||
|
||||
export const MISTRAL_MODEL_COMPAT_PATCH = {
|
||||
supportsStore: false,
|
||||
supportsReasoningEffort: false,
|
||||
maxTokensField: MISTRAL_MAX_TOKENS_FIELD,
|
||||
} as const satisfies {
|
||||
supportsStore: boolean;
|
||||
supportsReasoningEffort: boolean;
|
||||
maxTokensField: "max_tokens";
|
||||
};
|
||||
|
||||
export function applyMistralModelCompat<T extends { compat?: unknown }>(model: T): T {
|
||||
const patch = {
|
||||
supportsStore: false,
|
||||
supportsReasoningEffort: false,
|
||||
maxTokensField: MISTRAL_MAX_TOKENS_FIELD,
|
||||
} satisfies Record<string, unknown>;
|
||||
const compat =
|
||||
model.compat && typeof model.compat === "object"
|
||||
? (model.compat as Record<string, unknown>)
|
||||
: undefined;
|
||||
if (compat && Object.entries(patch).every(([key, value]) => compat[key] === value)) {
|
||||
if (
|
||||
compat &&
|
||||
Object.entries(MISTRAL_MODEL_COMPAT_PATCH).every(([key, value]) => compat[key] === value)
|
||||
) {
|
||||
return model;
|
||||
}
|
||||
return {
|
||||
...model,
|
||||
compat: {
|
||||
...compat,
|
||||
...patch,
|
||||
...MISTRAL_MODEL_COMPAT_PATCH,
|
||||
} as T extends { compat?: infer TCompat } ? TCompat : never,
|
||||
} as T;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,51 @@
|
|||
import { defineSingleProviderPluginEntry } from "openclaw/plugin-sdk/provider-entry";
|
||||
import { applyMistralModelCompat } from "./api.js";
|
||||
import { applyMistralModelCompat, MISTRAL_MODEL_COMPAT_PATCH } from "./api.js";
|
||||
import { mistralMediaUnderstandingProvider } from "./media-understanding-provider.js";
|
||||
import { applyMistralConfig, MISTRAL_DEFAULT_MODEL_REF } from "./onboard.js";
|
||||
import { buildMistralProvider } from "./provider-catalog.js";
|
||||
|
||||
const PROVIDER_ID = "mistral";
|
||||
const MISTRAL_MODEL_HINTS = [
|
||||
"mistral",
|
||||
"mistralai",
|
||||
"mixtral",
|
||||
"codestral",
|
||||
"pixtral",
|
||||
"devstral",
|
||||
"ministral",
|
||||
] as const;
|
||||
|
||||
function isMistralBaseUrl(baseUrl: unknown): boolean {
|
||||
if (typeof baseUrl !== "string" || !baseUrl.trim()) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
return new URL(baseUrl).hostname.toLowerCase() === "api.mistral.ai";
|
||||
} catch {
|
||||
return baseUrl.toLowerCase().includes("api.mistral.ai");
|
||||
}
|
||||
}
|
||||
|
||||
function isMistralModelHint(modelId: string): boolean {
|
||||
const normalized = modelId.trim().toLowerCase();
|
||||
return MISTRAL_MODEL_HINTS.some(
|
||||
(hint) =>
|
||||
normalized === hint ||
|
||||
normalized.startsWith(`${hint}/`) ||
|
||||
normalized.startsWith(`${hint}-`) ||
|
||||
normalized.startsWith(`${hint}:`),
|
||||
);
|
||||
}
|
||||
|
||||
function shouldContributeMistralCompat(params: {
|
||||
modelId: string;
|
||||
model: { api?: unknown; baseUrl?: unknown };
|
||||
}): boolean {
|
||||
if (params.model.api !== "openai-completions") {
|
||||
return false;
|
||||
}
|
||||
return isMistralBaseUrl(params.model.baseUrl) || isMistralModelHint(params.modelId);
|
||||
}
|
||||
|
||||
export default defineSingleProviderPluginEntry({
|
||||
id: PROVIDER_ID,
|
||||
|
|
@ -34,6 +75,8 @@ export default defineSingleProviderPluginEntry({
|
|||
allowExplicitBaseUrl: true,
|
||||
},
|
||||
normalizeResolvedModel: ({ model }) => applyMistralModelCompat(model),
|
||||
contributeResolvedModelCompat: ({ modelId, model }) =>
|
||||
shouldContributeMistralCompat({ modelId, model }) ? MISTRAL_MODEL_COMPAT_PATCH : undefined,
|
||||
capabilities: {
|
||||
transcriptToolCallIdMode: "strict9",
|
||||
transcriptToolCallIdModelHints: [
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ vi.mock("../../plugins/provider-runtime.js", async (importOriginal) => {
|
|||
const actual = await importOriginal<typeof import("../../plugins/provider-runtime.js")>();
|
||||
return {
|
||||
...actual,
|
||||
applyProviderResolvedModelCompatWithPlugins: () => undefined,
|
||||
clearProviderRuntimeHookCache: clearProviderRuntimeHookCacheMock,
|
||||
normalizeProviderResolvedModelWithPlugin: () => undefined,
|
||||
prepareProviderDynamicModel: (params: unknown) => prepareProviderDynamicModelMock(params),
|
||||
|
|
|
|||
|
|
@ -1044,6 +1044,7 @@ describe("resolveModel", () => {
|
|||
authStorage: { mocked: true } as never,
|
||||
modelRegistry: discoverModels({ mocked: true } as never, "/tmp/agent"),
|
||||
runtimeHooks: {
|
||||
applyProviderResolvedModelCompatWithPlugins: () => undefined,
|
||||
buildProviderUnknownModelHintWithPlugin: () => undefined,
|
||||
prepareProviderDynamicModel: async () => {},
|
||||
runProviderDynamicModel: () => undefined,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import type { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent";
|
|||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import type { ModelDefinitionConfig } from "../../config/types.js";
|
||||
import {
|
||||
applyProviderResolvedModelCompatWithPlugins,
|
||||
buildProviderUnknownModelHintWithPlugin,
|
||||
clearProviderRuntimeHookCache,
|
||||
normalizeProviderTransportWithPlugin,
|
||||
|
|
@ -36,6 +37,9 @@ type InlineProviderConfig = {
|
|||
};
|
||||
|
||||
type ProviderRuntimeHooks = {
|
||||
applyProviderResolvedModelCompatWithPlugins?: (
|
||||
params: Parameters<typeof applyProviderResolvedModelCompatWithPlugins>[0],
|
||||
) => unknown;
|
||||
buildProviderUnknownModelHintWithPlugin: (
|
||||
params: Parameters<typeof buildProviderUnknownModelHintWithPlugin>[0],
|
||||
) => string | undefined;
|
||||
|
|
@ -52,6 +56,7 @@ type ProviderRuntimeHooks = {
|
|||
};
|
||||
|
||||
const DEFAULT_PROVIDER_RUNTIME_HOOKS: ProviderRuntimeHooks = {
|
||||
applyProviderResolvedModelCompatWithPlugins,
|
||||
buildProviderUnknownModelHintWithPlugin,
|
||||
prepareProviderDynamicModel,
|
||||
runProviderDynamicModel,
|
||||
|
|
@ -121,9 +126,20 @@ function normalizeResolvedModel(params: {
|
|||
model: normalizedInputModel,
|
||||
},
|
||||
}) as Model<Api> | undefined;
|
||||
const compatNormalized = runtimeHooks.applyProviderResolvedModelCompatWithPlugins?.({
|
||||
provider: params.provider,
|
||||
config: params.cfg,
|
||||
context: {
|
||||
config: params.cfg,
|
||||
agentDir: params.agentDir,
|
||||
provider: params.provider,
|
||||
modelId: normalizedInputModel.id,
|
||||
model: (pluginNormalized ?? normalizedInputModel) as never,
|
||||
},
|
||||
}) as Model<Api> | undefined;
|
||||
return normalizeResolvedProviderModel({
|
||||
provider: params.provider,
|
||||
model: pluginNormalized ?? normalizedInputModel,
|
||||
model: compatNormalized ?? pluginNormalized ?? normalizedInputModel,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -273,4 +273,111 @@ describe("discoverAuthStorage", () => {
|
|||
}
|
||||
});
|
||||
});
|
||||
|
||||
it("normalizes discovered Mistral compat flags for custom Mistral-hosted providers", async () => {
|
||||
await withAgentDir(async (agentDir) => {
|
||||
saveAuthProfileStore(
|
||||
{
|
||||
version: 1,
|
||||
profiles: {
|
||||
"custom-api-mistral-ai:default": {
|
||||
type: "api_key",
|
||||
provider: "custom-api-mistral-ai",
|
||||
key: "mistral-custom-key",
|
||||
},
|
||||
},
|
||||
},
|
||||
agentDir,
|
||||
);
|
||||
await writeModelsJson(agentDir, {
|
||||
providers: {
|
||||
"custom-api-mistral-ai": {
|
||||
api: "openai-completions",
|
||||
baseUrl: "https://api.mistral.ai/v1",
|
||||
apiKey: "custom-api-mistral-ai",
|
||||
models: [
|
||||
{
|
||||
id: "mistral-small-latest",
|
||||
name: "Mistral Small",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 262144,
|
||||
maxTokens: 16384,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const authStorage = discoverAuthStorage(agentDir);
|
||||
const modelRegistry = discoverModels(authStorage, agentDir);
|
||||
const model = modelRegistry.find("custom-api-mistral-ai", "mistral-small-latest") as {
|
||||
compat?: {
|
||||
supportsStore?: boolean;
|
||||
supportsReasoningEffort?: boolean;
|
||||
maxTokensField?: string;
|
||||
};
|
||||
} | null;
|
||||
|
||||
expect(model?.compat?.supportsStore).toBe(false);
|
||||
expect(model?.compat?.supportsReasoningEffort).toBe(false);
|
||||
expect(model?.compat?.maxTokensField).toBe("max_tokens");
|
||||
});
|
||||
});
|
||||
|
||||
it("normalizes discovered Mistral compat flags for OpenRouter Mistral model ids", async () => {
|
||||
await withAgentDir(async (agentDir) => {
|
||||
saveAuthProfileStore(
|
||||
{
|
||||
version: 1,
|
||||
profiles: {
|
||||
"openrouter:default": {
|
||||
type: "api_key",
|
||||
provider: "openrouter",
|
||||
key: "sk-or-v1-runtime",
|
||||
},
|
||||
},
|
||||
},
|
||||
agentDir,
|
||||
);
|
||||
await writeModelsJson(agentDir, {
|
||||
providers: {
|
||||
openrouter: {
|
||||
api: "openai-completions",
|
||||
baseUrl: "https://openrouter.ai/api/v1",
|
||||
apiKey: "OPENROUTER_API_KEY",
|
||||
models: [
|
||||
{
|
||||
id: "mistralai/mistral-small-3.2-24b-instruct",
|
||||
name: "Mistral Small via OpenRouter",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 262144,
|
||||
maxTokens: 16384,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const authStorage = discoverAuthStorage(agentDir);
|
||||
const modelRegistry = discoverModels(authStorage, agentDir);
|
||||
const model = modelRegistry.find(
|
||||
"openrouter",
|
||||
"mistralai/mistral-small-3.2-24b-instruct",
|
||||
) as {
|
||||
compat?: {
|
||||
supportsStore?: boolean;
|
||||
supportsReasoningEffort?: boolean;
|
||||
maxTokensField?: string;
|
||||
};
|
||||
} | null;
|
||||
|
||||
expect(model?.compat?.supportsStore).toBe(false);
|
||||
expect(model?.compat?.supportsReasoningEffort).toBe(false);
|
||||
expect(model?.compat?.maxTokensField).toBe("max_tokens");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -7,7 +7,10 @@ import type {
|
|||
ModelRegistry as PiModelRegistry,
|
||||
} from "@mariozechner/pi-coding-agent";
|
||||
import { normalizeModelCompat } from "../plugins/provider-model-compat.js";
|
||||
import { normalizeProviderResolvedModelWithPlugin } from "../plugins/provider-runtime.js";
|
||||
import {
|
||||
applyProviderResolvedModelCompatWithPlugins,
|
||||
normalizeProviderResolvedModelWithPlugin,
|
||||
} from "../plugins/provider-runtime.js";
|
||||
import type { ProviderRuntimeModel } from "../plugins/types.js";
|
||||
import { ensureAuthProfileStore } from "./auth-profiles.js";
|
||||
import { PROVIDER_ENV_API_KEY_CANDIDATES } from "./model-auth-env-vars.js";
|
||||
|
|
@ -75,7 +78,17 @@ function normalizeRegistryModel<T>(value: T, agentDir: string): T {
|
|||
agentDir,
|
||||
},
|
||||
}) ?? model;
|
||||
return normalizeModelCompat(pluginNormalized as Model<Api>) as T;
|
||||
const compatNormalized =
|
||||
applyProviderResolvedModelCompatWithPlugins({
|
||||
provider: model.provider,
|
||||
context: {
|
||||
provider: model.provider,
|
||||
modelId: model.id,
|
||||
model: pluginNormalized,
|
||||
agentDir,
|
||||
},
|
||||
}) ?? pluginNormalized;
|
||||
return normalizeModelCompat(compatNormalized as Model<Api>) as T;
|
||||
}
|
||||
|
||||
class OpenClawModelRegistry extends PiModelRegistryClass {
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ let applyProviderNativeStreamingUsageCompatWithPlugin: typeof import("./provider
|
|||
let formatProviderAuthProfileApiKeyWithPlugin: typeof import("./provider-runtime.js").formatProviderAuthProfileApiKeyWithPlugin;
|
||||
let normalizeProviderConfigWithPlugin: typeof import("./provider-runtime.js").normalizeProviderConfigWithPlugin;
|
||||
let normalizeProviderModelIdWithPlugin: typeof import("./provider-runtime.js").normalizeProviderModelIdWithPlugin;
|
||||
let applyProviderResolvedModelCompatWithPlugins: typeof import("./provider-runtime.js").applyProviderResolvedModelCompatWithPlugins;
|
||||
let normalizeProviderTransportWithPlugin: typeof import("./provider-runtime.js").normalizeProviderTransportWithPlugin;
|
||||
let prepareProviderExtraParams: typeof import("./provider-runtime.js").prepareProviderExtraParams;
|
||||
let resolveProviderConfigApiKeyWithPlugin: typeof import("./provider-runtime.js").resolveProviderConfigApiKeyWithPlugin;
|
||||
|
|
@ -211,6 +212,7 @@ describe("provider-runtime", () => {
|
|||
buildProviderMissingAuthMessageWithPlugin,
|
||||
buildProviderUnknownModelHintWithPlugin,
|
||||
applyProviderNativeStreamingUsageCompatWithPlugin,
|
||||
applyProviderResolvedModelCompatWithPlugins,
|
||||
formatProviderAuthProfileApiKeyWithPlugin,
|
||||
normalizeProviderConfigWithPlugin,
|
||||
normalizeProviderModelIdWithPlugin,
|
||||
|
|
@ -727,6 +729,13 @@ describe("provider-runtime", () => {
|
|||
api: "openai-codex-responses",
|
||||
});
|
||||
|
||||
expect(
|
||||
applyProviderResolvedModelCompatWithPlugins({
|
||||
provider: DEMO_PROVIDER_ID,
|
||||
context: createDemoResolvedModelContext({}),
|
||||
}),
|
||||
).toBeUndefined();
|
||||
|
||||
expect(
|
||||
formatProviderAuthProfileApiKeyWithPlugin({
|
||||
provider: DEMO_PROVIDER_ID,
|
||||
|
|
@ -854,6 +863,53 @@ describe("provider-runtime", () => {
|
|||
);
|
||||
});
|
||||
|
||||
it("merges compat contributions from owner and foreign provider plugins", () => {
|
||||
resolveOwningPluginIdsForProviderMock.mockReturnValue(["openrouter"]);
|
||||
resolvePluginProvidersMock.mockImplementation((params) => {
|
||||
const onlyPluginIds = params.onlyPluginIds ?? [];
|
||||
const plugins: ProviderPlugin[] = [
|
||||
{
|
||||
id: "openrouter",
|
||||
label: "OpenRouter",
|
||||
auth: [],
|
||||
contributeResolvedModelCompat: () => ({ supportsStrictMode: true }),
|
||||
},
|
||||
{
|
||||
id: "mistral",
|
||||
label: "Mistral",
|
||||
auth: [],
|
||||
contributeResolvedModelCompat: ({ modelId }) =>
|
||||
modelId.startsWith("mistralai/") ? { supportsStore: false } : undefined,
|
||||
},
|
||||
];
|
||||
return onlyPluginIds.length > 0
|
||||
? plugins.filter((plugin) => onlyPluginIds.includes(plugin.id))
|
||||
: plugins;
|
||||
});
|
||||
|
||||
expect(
|
||||
applyProviderResolvedModelCompatWithPlugins({
|
||||
provider: "openrouter",
|
||||
context: createDemoResolvedModelContext({
|
||||
provider: "openrouter",
|
||||
modelId: "mistralai/mistral-small-3.2-24b-instruct",
|
||||
model: {
|
||||
...MODEL,
|
||||
provider: "openrouter",
|
||||
id: "mistralai/mistral-small-3.2-24b-instruct",
|
||||
compat: { supportsDeveloperRole: false },
|
||||
},
|
||||
}),
|
||||
}),
|
||||
).toMatchObject({
|
||||
compat: {
|
||||
supportsDeveloperRole: false,
|
||||
supportsStrictMode: true,
|
||||
supportsStore: false,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("resolves bundled catalog hooks through provider plugins", async () => {
|
||||
resolveCatalogHookProviderPluginIdsMock.mockReturnValue(["openai"]);
|
||||
resolvePluginProvidersMock.mockImplementation((params?: { onlyPluginIds?: string[] }) => {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import type {
|
|||
ProviderFetchUsageSnapshotContext,
|
||||
ProviderNormalizeConfigContext,
|
||||
ProviderNormalizeModelIdContext,
|
||||
ProviderNormalizeResolvedModelContext,
|
||||
ProviderNormalizeTransportContext,
|
||||
ProviderModernModelPolicyContext,
|
||||
ProviderPrepareExtraParamsContext,
|
||||
|
|
@ -224,6 +225,79 @@ export function normalizeProviderResolvedModelWithPlugin(params: {
|
|||
);
|
||||
}
|
||||
|
||||
function resolveProviderCompatHookPlugins(params: {
|
||||
provider: string;
|
||||
config?: OpenClawConfig;
|
||||
workspaceDir?: string;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
}): ProviderPlugin[] {
|
||||
const candidates = resolveProviderPluginsForHooks(params);
|
||||
const owner = resolveProviderRuntimePlugin(params);
|
||||
if (!owner) {
|
||||
return candidates;
|
||||
}
|
||||
|
||||
const ordered = [owner, ...candidates];
|
||||
const seen = new Set<string>();
|
||||
return ordered.filter((candidate) => {
|
||||
const key = `${candidate.pluginId ?? ""}:${candidate.id}`;
|
||||
if (seen.has(key)) {
|
||||
return false;
|
||||
}
|
||||
seen.add(key);
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
function applyCompatPatchToModel(
|
||||
model: ProviderRuntimeModel,
|
||||
patch: Record<string, unknown>,
|
||||
): ProviderRuntimeModel {
|
||||
const compat =
|
||||
model.compat && typeof model.compat === "object"
|
||||
? (model.compat as Record<string, unknown>)
|
||||
: undefined;
|
||||
if (Object.entries(patch).every(([key, value]) => compat?.[key] === value)) {
|
||||
return model;
|
||||
}
|
||||
return {
|
||||
...model,
|
||||
compat: {
|
||||
...compat,
|
||||
...patch,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function applyProviderResolvedModelCompatWithPlugins(params: {
|
||||
provider: string;
|
||||
config?: OpenClawConfig;
|
||||
workspaceDir?: string;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
context: ProviderNormalizeResolvedModelContext;
|
||||
}): ProviderRuntimeModel | undefined {
|
||||
let nextModel = params.context.model;
|
||||
let changed = false;
|
||||
|
||||
for (const plugin of resolveProviderCompatHookPlugins(params)) {
|
||||
const patch = plugin.contributeResolvedModelCompat?.({
|
||||
...params.context,
|
||||
model: nextModel,
|
||||
});
|
||||
if (!patch || typeof patch !== "object") {
|
||||
continue;
|
||||
}
|
||||
const patchedModel = applyCompatPatchToModel(nextModel, patch as Record<string, unknown>);
|
||||
if (patchedModel === nextModel) {
|
||||
continue;
|
||||
}
|
||||
nextModel = patchedModel;
|
||||
changed = true;
|
||||
}
|
||||
|
||||
return changed ? nextModel : undefined;
|
||||
}
|
||||
|
||||
function resolveProviderHookPlugin(params: {
|
||||
provider: string;
|
||||
config?: OpenClawConfig;
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import type {
|
|||
ModelProviderAuthMode,
|
||||
ModelProviderConfig,
|
||||
} from "../config/types.js";
|
||||
import type { ModelCompatConfig } from "../config/types.models.js";
|
||||
import type { OperatorScope } from "../gateway/method-scopes.js";
|
||||
import type { GatewayRequestHandler } from "../gateway/server-methods/types.js";
|
||||
import type { InternalHookHandler } from "../hooks/internal-hooks.js";
|
||||
|
|
@ -885,6 +886,17 @@ export type ProviderPlugin = {
|
|||
normalizeResolvedModel?: (
|
||||
ctx: ProviderNormalizeResolvedModelContext,
|
||||
) => ProviderRuntimeModel | null | undefined;
|
||||
/**
|
||||
* Provider-owned compat contribution for resolved models outside direct
|
||||
* provider ownership.
|
||||
*
|
||||
* Use this when a plugin can recognize its vendor's models behind another
|
||||
* OpenAI-compatible transport (for example OpenRouter or a custom base URL)
|
||||
* and needs to contribute compat flags without taking over the provider.
|
||||
*/
|
||||
contributeResolvedModelCompat?: (
|
||||
ctx: ProviderNormalizeResolvedModelContext,
|
||||
) => Partial<ModelCompatConfig> | null | undefined;
|
||||
/**
|
||||
* Provider-owned model-id normalization.
|
||||
*
|
||||
|
|
|
|||
Loading…
Reference in New Issue