fix: normalize huggingface refs and discovery timeout

This commit is contained in:
Peter Steinberger 2026-04-05 09:17:34 +01:00
parent b5f8cd4fcf
commit 5ac07b8ef0
No known key found for this signature in database
6 changed files with 135 additions and 22 deletions

View File

@ -1,10 +1,21 @@
import { describe, expect, it } from "vitest";
import { afterEach, describe, expect, it, vi } from "vitest";
import {
buildHuggingfaceModelDefinition,
discoverHuggingfaceModels,
HUGGINGFACE_MODEL_CATALOG,
isHuggingfacePolicyLocked,
} from "./api.js";
import { HUGGINGFACE_DISCOVERY_TIMEOUT_MS } from "./models.js";
// Snapshot of the env vars the discovery tests override, taken at module load.
const ORIGINAL_VITEST = process.env.VITEST;
const ORIGINAL_NODE_ENV = process.env.NODE_ENV;

/**
 * Puts an environment variable back to its captured value.
 *
 * Assigning `undefined` to a `process.env` key stores the string "undefined",
 * so a variable that was originally unset has to be deleted instead.
 */
function restoreEnvVar(name: string, original: string | undefined): void {
  if (original === undefined) {
    delete process.env[name];
    return;
  }
  process.env[name] = original;
}

// Undo per-test env overrides, spies and global stubs so tests stay isolated.
afterEach(() => {
  restoreEnvVar("VITEST", ORIGINAL_VITEST);
  restoreEnvVar("NODE_ENV", ORIGINAL_NODE_ENV);
  vi.restoreAllMocks();
  vi.unstubAllGlobals();
});
describe("huggingface models", () => {
it("buildHuggingfaceModelDefinition returns config with required fields", () => {
@ -31,6 +42,40 @@ describe("huggingface models", () => {
expect(models[0].id).toBe("deepseek-ai/DeepSeek-R1");
});
it("uses the default discovery timeout for live Hugging Face fetches", async () => {
  // discoverHuggingfaceModels returns the static catalog when VITEST is "true"
  // or NODE_ENV is "test", so override both to reach the fetch call.
  process.env.VITEST = "false";
  process.env.NODE_ENV = "development";

  const abortTimeout = vi.spyOn(AbortSignal, "timeout");
  // Stubbed fetch answers with HTTP 500; only the timeout argument is inspected.
  const failingFetch = vi.fn(async () => {
    return new Response("{}", {
      status: 500,
      headers: { "Content-Type": "application/json" },
    });
  });
  vi.stubGlobal("fetch", failingFetch);

  await discoverHuggingfaceModels("hf_test_token");

  expect(abortTimeout).toHaveBeenCalledWith(HUGGINGFACE_DISCOVERY_TIMEOUT_MS);
});
it("accepts a custom discovery timeout override", async () => {
  // discoverHuggingfaceModels returns the static catalog when VITEST is "true"
  // or NODE_ENV is "test", so override both to reach the fetch call.
  process.env.VITEST = "false";
  process.env.NODE_ENV = "development";

  const abortTimeout = vi.spyOn(AbortSignal, "timeout");
  // Stubbed fetch answers with HTTP 500; only the timeout argument is inspected.
  const failingFetch = vi.fn(async () => {
    return new Response("{}", {
      status: 500,
      headers: { "Content-Type": "application/json" },
    });
  });
  vi.stubGlobal("fetch", failingFetch);

  await discoverHuggingfaceModels("hf_test_token", 25_000);

  expect(abortTimeout).toHaveBeenCalledWith(25_000);
});
describe("isHuggingfacePolicyLocked", () => {
it("returns true for :cheapest and :fastest refs", () => {
expect(isHuggingfacePolicyLocked("huggingface/deepseek-ai/DeepSeek-R1:cheapest")).toBe(true);

View File

@ -2,6 +2,7 @@ import type { ModelDefinitionConfig } from "openclaw/plugin-sdk/provider-model-s
// Base URL of the Hugging Face router; model discovery fetches `${HUGGINGFACE_BASE_URL}/models`.
export const HUGGINGFACE_BASE_URL = "https://router.huggingface.co/v1";
// Policy suffix names; presumably matched against `:cheapest` / `:fastest` ref endings — confirm against isHuggingfacePolicyLocked.
export const HUGGINGFACE_POLICY_SUFFIXES = ["cheapest", "fastest"] as const;
// Default discovery timeout in milliseconds; used as the default `timeoutMs` of discoverHuggingfaceModels.
export const HUGGINGFACE_DISCOVERY_TIMEOUT_MS = 30_000;
const HUGGINGFACE_DEFAULT_COST = {
input: 0,
@ -123,7 +124,10 @@ function displayNameFromApiEntry(entry: HFModelEntry, inferredName: string): str
return inferredName;
}
export async function discoverHuggingfaceModels(apiKey: string): Promise<ModelDefinitionConfig[]> {
export async function discoverHuggingfaceModels(
apiKey: string,
timeoutMs = HUGGINGFACE_DISCOVERY_TIMEOUT_MS,
): Promise<ModelDefinitionConfig[]> {
if (process.env.VITEST === "true" || process.env.NODE_ENV === "test") {
return HUGGINGFACE_MODEL_CATALOG.map(buildHuggingfaceModelDefinition);
}
@ -135,7 +139,7 @@ export async function discoverHuggingfaceModels(apiKey: string): Promise<ModelDe
try {
const response = await fetch(`${HUGGINGFACE_BASE_URL}/models`, {
signal: AbortSignal.timeout(10_000),
signal: AbortSignal.timeout(timeoutMs),
headers: {
Authorization: `Bearer ${trimmedKey}`,
"Content-Type": "application/json",

View File

@ -42,10 +42,22 @@ export function normalizeAnthropicModelId(model: string): string {
}
}
/**
 * Normalizes a Hugging Face model id by trimming surrounding whitespace and
 * removing one leading `huggingface/` provider prefix (matched
 * case-insensitively). Ids without that prefix are returned trimmed but
 * otherwise untouched; only a single prefix occurrence is removed.
 *
 * @param model - Raw model id, possibly carrying a redundant provider prefix.
 * @returns The trimmed id with at most one leading provider prefix stripped.
 */
function normalizeHuggingfaceModelId(model: string): string {
  const providerPrefix = "huggingface/";
  const candidate = model.trim();
  // An empty candidate can never start with the prefix, so it falls through here too.
  const hasProviderPrefix = candidate.toLowerCase().startsWith(providerPrefix);
  if (!hasProviderPrefix) {
    return candidate;
  }
  return candidate.slice(providerPrefix.length);
}
export function normalizeStaticProviderModelId(provider: string, model: string): string {
if (provider === "anthropic") {
return normalizeAnthropicModelId(model);
}
if (provider === "huggingface") {
return normalizeHuggingfaceModelId(model);
}
if (provider === "google" || provider === "google-vertex") {
return normalizeGooglePreviewModelId(model);
}

View File

@ -223,6 +223,12 @@ describe("model-selection", () => {
defaultProvider: "openai",
expected: { provider: "openrouter", model: "anthropic/claude-sonnet-4-6" },
},
{
name: "strips duplicate Hugging Face provider prefixes",
variants: ["huggingface/deepseek-ai/DeepSeek-R1"],
defaultProvider: "huggingface",
expected: { provider: "huggingface", model: "deepseek-ai/DeepSeek-R1" },
},
{
name: "normalizes Vercel Claude shorthand to anthropic-prefixed model ids",
variants: ["vercel-ai-gateway/claude-opus-4.6"],

View File

@ -674,6 +674,34 @@ describe("resolveModel", () => {
});
});
it("matches prefixed Hugging Face ids against discovered registry models", () => {
  // The registry entry is registered under the unprefixed model id.
  const registryModelId = "deepseek-ai/DeepSeek-R1";
  const templateModel = {
    ...makeModel(registryModelId),
    provider: "huggingface",
    baseUrl: "https://router.huggingface.co/v1",
    reasoning: true,
    input: ["text"],
  };
  mockDiscoveredModel(discoverModels, {
    provider: "huggingface",
    modelId: registryModelId,
    templateModel,
  });

  // The lookup passes the id with a redundant provider prefix.
  const resolved = resolveModelForTest(
    "huggingface",
    "huggingface/deepseek-ai/DeepSeek-R1",
    "/tmp/agent",
  );

  expect(resolved.error).toBeUndefined();
  expect(resolved.model).toMatchObject({
    provider: "huggingface",
    id: registryModelId,
    reasoning: true,
    input: ["text"],
  });
});
it("preloads OpenRouter capabilities before first async resolve of an unknown model", async () => {
mockLoadOpenRouterModelCapabilities.mockImplementation(async (modelId) => {
if (modelId === "google/gemini-3.1-flash-image-preview") {

View File

@ -17,6 +17,7 @@ import { resolveOpenClawAgentDir } from "../agent-paths.js";
import { DEFAULT_CONTEXT_TOKENS } from "../defaults.js";
import { buildModelAliasLines } from "../model-alias-lines.js";
import { isSecretRefHeaderValueMarker } from "../model-auth-markers.js";
import { normalizeStaticProviderModelId } from "../model-ref-shared.js";
import { findNormalizedProviderValue, normalizeProviderId } from "../model-selection.js";
import {
buildSuppressedBuiltInModelError,
@ -618,7 +619,16 @@ export function resolveModelWithRegistry(params: {
agentDir?: string;
runtimeHooks?: ProviderRuntimeHooks;
}): Model<Api> | undefined {
const explicitModel = resolveExplicitModelWithRegistry(params);
const normalizedRef = {
provider: params.provider,
model: normalizeStaticProviderModelId(normalizeProviderId(params.provider), params.modelId),
};
const normalizedParams = {
...params,
provider: normalizedRef.provider,
modelId: normalizedRef.model,
};
const explicitModel = resolveExplicitModelWithRegistry(normalizedParams);
if (explicitModel?.kind === "suppressed") {
return undefined;
}
@ -626,12 +636,12 @@ export function resolveModelWithRegistry(params: {
return explicitModel.model;
}
const pluginDynamicModel = resolvePluginDynamicModelWithRegistry(params);
const pluginDynamicModel = resolvePluginDynamicModelWithRegistry(normalizedParams);
if (pluginDynamicModel) {
return pluginDynamicModel;
}
return resolveConfiguredFallbackModel(params);
return resolveConfiguredFallbackModel(normalizedParams);
}
export function resolveModel(
@ -651,13 +661,17 @@ export function resolveModel(
authStorage: AuthStorage;
modelRegistry: ModelRegistry;
} {
const normalizedRef = {
provider,
model: normalizeStaticProviderModelId(normalizeProviderId(provider), modelId),
};
const resolvedAgentDir = agentDir ?? resolveOpenClawAgentDir();
const authStorage = options?.authStorage ?? discoverAuthStorage(resolvedAgentDir);
const modelRegistry = options?.modelRegistry ?? discoverModels(authStorage, resolvedAgentDir);
const runtimeHooks = resolveRuntimeHooks(options);
const model = resolveModelWithRegistry({
provider,
modelId,
provider: normalizedRef.provider,
modelId: normalizedRef.model,
modelRegistry,
cfg,
agentDir: resolvedAgentDir,
@ -669,8 +683,8 @@ export function resolveModel(
return {
error: buildUnknownModelError({
provider,
modelId,
provider: normalizedRef.provider,
modelId: normalizedRef.model,
cfg,
agentDir: resolvedAgentDir,
runtimeHooks,
@ -698,13 +712,17 @@ export async function resolveModelAsync(
authStorage: AuthStorage;
modelRegistry: ModelRegistry;
}> {
const normalizedRef = {
provider,
model: normalizeStaticProviderModelId(normalizeProviderId(provider), modelId),
};
const resolvedAgentDir = agentDir ?? resolveOpenClawAgentDir();
const authStorage = options?.authStorage ?? discoverAuthStorage(resolvedAgentDir);
const modelRegistry = options?.modelRegistry ?? discoverModels(authStorage, resolvedAgentDir);
const runtimeHooks = resolveRuntimeHooks(options);
const explicitModel = resolveExplicitModelWithRegistry({
provider,
modelId,
provider: normalizedRef.provider,
modelId: normalizedRef.model,
modelRegistry,
cfg,
agentDir: resolvedAgentDir,
@ -713,8 +731,8 @@ export async function resolveModelAsync(
if (explicitModel?.kind === "suppressed") {
return {
error: buildUnknownModelError({
provider,
modelId,
provider: normalizedRef.provider,
modelId: normalizedRef.model,
cfg,
agentDir: resolvedAgentDir,
runtimeHooks,
@ -723,26 +741,26 @@ export async function resolveModelAsync(
modelRegistry,
};
}
const providerConfig = resolveConfiguredProviderConfig(cfg, provider);
const providerConfig = resolveConfiguredProviderConfig(cfg, normalizedRef.provider);
const resolveDynamicAttempt = async (attemptOptions?: { clearHookCache?: boolean }) => {
if (attemptOptions?.clearHookCache) {
runtimeHooks.clearProviderRuntimeHookCache();
}
await runtimeHooks.prepareProviderDynamicModel({
provider,
provider: normalizedRef.provider,
config: cfg,
context: {
config: cfg,
agentDir: resolvedAgentDir,
provider,
modelId,
provider: normalizedRef.provider,
modelId: normalizedRef.model,
modelRegistry,
providerConfig,
},
});
return resolveModelWithRegistry({
provider,
modelId,
provider: normalizedRef.provider,
modelId: normalizedRef.model,
modelRegistry,
cfg,
agentDir: resolvedAgentDir,
@ -763,8 +781,8 @@ export async function resolveModelAsync(
return {
error: buildUnknownModelError({
provider,
modelId,
provider: normalizedRef.provider,
modelId: normalizedRef.model,
cfg,
agentDir: resolvedAgentDir,
runtimeHooks,