From 5ac07b8ef0868c9fa25160acf9b5aa49cb3efe7a Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sun, 5 Apr 2026 09:17:34 +0100 Subject: [PATCH] fix: normalize huggingface refs and discovery timeout --- extensions/huggingface/models.test.ts | 47 ++++++++++++++++- extensions/huggingface/models.ts | 8 ++- src/agents/model-ref-shared.ts | 12 +++++ src/agents/model-selection.test.ts | 6 +++ src/agents/pi-embedded-runner/model.test.ts | 28 +++++++++++ src/agents/pi-embedded-runner/model.ts | 56 ++++++++++++++------- 6 files changed, 135 insertions(+), 22 deletions(-) diff --git a/extensions/huggingface/models.test.ts b/extensions/huggingface/models.test.ts index 0f6bc708c55..ed982d519e3 100644 --- a/extensions/huggingface/models.test.ts +++ b/extensions/huggingface/models.test.ts @@ -1,10 +1,21 @@ -import { describe, expect, it } from "vitest"; +import { afterEach, describe, expect, it, vi } from "vitest"; import { buildHuggingfaceModelDefinition, discoverHuggingfaceModels, HUGGINGFACE_MODEL_CATALOG, isHuggingfacePolicyLocked, } from "./api.js"; +import { HUGGINGFACE_DISCOVERY_TIMEOUT_MS } from "./models.js"; + +const ORIGINAL_VITEST = process.env.VITEST; +const ORIGINAL_NODE_ENV = process.env.NODE_ENV; + +afterEach(() => { + process.env.VITEST = ORIGINAL_VITEST; + process.env.NODE_ENV = ORIGINAL_NODE_ENV; + vi.restoreAllMocks(); + vi.unstubAllGlobals(); +}); describe("huggingface models", () => { it("buildHuggingfaceModelDefinition returns config with required fields", () => { @@ -31,6 +42,40 @@ describe("huggingface models", () => { expect(models[0].id).toBe("deepseek-ai/DeepSeek-R1"); }); + it("uses the default discovery timeout for live Hugging Face fetches", async () => { + process.env.VITEST = "false"; + process.env.NODE_ENV = "development"; + const timeoutSpy = vi.spyOn(AbortSignal, "timeout"); + vi.stubGlobal( + "fetch", + vi.fn( + async () => + new Response("{}", { status: 500, headers: { "Content-Type": "application/json" } }), + ), + ); + + await discoverHuggingfaceModels("hf_test_token"); + + expect(timeoutSpy).toHaveBeenCalledWith(HUGGINGFACE_DISCOVERY_TIMEOUT_MS); + }); + + it("accepts a custom discovery timeout override", async () => { + process.env.VITEST = "false"; + process.env.NODE_ENV = "development"; + const timeoutSpy = vi.spyOn(AbortSignal, "timeout"); + vi.stubGlobal( + "fetch", + vi.fn( + async () => + new Response("{}", { status: 500, headers: { "Content-Type": "application/json" } }), + ), + ); + + await discoverHuggingfaceModels("hf_test_token", 25_000); + + expect(timeoutSpy).toHaveBeenCalledWith(25_000); + }); + describe("isHuggingfacePolicyLocked", () => { it("returns true for :cheapest and :fastest refs", () => { expect(isHuggingfacePolicyLocked("huggingface/deepseek-ai/DeepSeek-R1:cheapest")).toBe(true); diff --git a/extensions/huggingface/models.ts b/extensions/huggingface/models.ts index 4775151e1a9..9bad187bad8 100644 --- a/extensions/huggingface/models.ts +++ b/extensions/huggingface/models.ts @@ -2,6 +2,7 @@ import type { ModelDefinitionConfig } from "openclaw/plugin-sdk/provider-model-s export const HUGGINGFACE_BASE_URL = "https://router.huggingface.co/v1"; export const HUGGINGFACE_POLICY_SUFFIXES = ["cheapest", "fastest"] as const; +export const HUGGINGFACE_DISCOVERY_TIMEOUT_MS = 30_000; const HUGGINGFACE_DEFAULT_COST = { input: 0, @@ -123,7 +124,10 @@ function displayNameFromApiEntry(entry: HFModelEntry, inferredName: string): str return inferredName; } -export async function discoverHuggingfaceModels(apiKey: string): Promise { +export async function discoverHuggingfaceModels( + apiKey: string, + timeoutMs = HUGGINGFACE_DISCOVERY_TIMEOUT_MS, +): Promise { if (process.env.VITEST === "true" || process.env.NODE_ENV === "test") { return HUGGINGFACE_MODEL_CATALOG.map(buildHuggingfaceModelDefinition); } @@ -135,7 +139,7 @@ export async function discoverHuggingfaceModels(apiKey: string): Promise { defaultProvider: "openai", expected: { provider: "openrouter", model: "anthropic/claude-sonnet-4-6" }, }, + { + name: "strips duplicate Hugging Face provider prefixes", + variants: ["huggingface/deepseek-ai/DeepSeek-R1"], + defaultProvider: "huggingface", + expected: { provider: "huggingface", model: "deepseek-ai/DeepSeek-R1" }, + }, { name: "normalizes Vercel Claude shorthand to anthropic-prefixed model ids", variants: ["vercel-ai-gateway/claude-opus-4.6"], diff --git a/src/agents/pi-embedded-runner/model.test.ts b/src/agents/pi-embedded-runner/model.test.ts index 1c2cbf341f3..68d684cf5f6 100644 --- a/src/agents/pi-embedded-runner/model.test.ts +++ b/src/agents/pi-embedded-runner/model.test.ts @@ -674,6 +674,34 @@ describe("resolveModel", () => { }); }); + it("matches prefixed Hugging Face ids against discovered registry models", () => { + mockDiscoveredModel(discoverModels, { + provider: "huggingface", + modelId: "deepseek-ai/DeepSeek-R1", + templateModel: { + ...makeModel("deepseek-ai/DeepSeek-R1"), + provider: "huggingface", + baseUrl: "https://router.huggingface.co/v1", + reasoning: true, + input: ["text"], + }, + }); + + const result = resolveModelForTest( + "huggingface", + "huggingface/deepseek-ai/DeepSeek-R1", + "/tmp/agent", + ); + + expect(result.error).toBeUndefined(); + expect(result.model).toMatchObject({ + provider: "huggingface", + id: "deepseek-ai/DeepSeek-R1", + reasoning: true, + input: ["text"], + }); + }); + it("preloads OpenRouter capabilities before first async resolve of an unknown model", async () => { mockLoadOpenRouterModelCapabilities.mockImplementation(async (modelId) => { if (modelId === "google/gemini-3.1-flash-image-preview") { diff --git a/src/agents/pi-embedded-runner/model.ts b/src/agents/pi-embedded-runner/model.ts index 48f0689f4c0..f40af139692 100644 --- a/src/agents/pi-embedded-runner/model.ts +++ b/src/agents/pi-embedded-runner/model.ts @@ -17,6 +17,7 @@ import { resolveOpenClawAgentDir } from "../agent-paths.js"; import { DEFAULT_CONTEXT_TOKENS } from "../defaults.js"; import { buildModelAliasLines } from "../model-alias-lines.js"; import { isSecretRefHeaderValueMarker } from "../model-auth-markers.js"; +import { normalizeStaticProviderModelId } from "../model-ref-shared.js"; import { findNormalizedProviderValue, normalizeProviderId } from "../model-selection.js"; import { buildSuppressedBuiltInModelError, @@ -618,7 +619,16 @@ export function resolveModelWithRegistry(params: { agentDir?: string; runtimeHooks?: ProviderRuntimeHooks; }): Model | undefined { - const explicitModel = resolveExplicitModelWithRegistry(params); + const normalizedRef = { + provider: params.provider, + model: normalizeStaticProviderModelId(normalizeProviderId(params.provider), params.modelId), + }; + const normalizedParams = { + ...params, + provider: normalizedRef.provider, + modelId: normalizedRef.model, + }; + const explicitModel = resolveExplicitModelWithRegistry(normalizedParams); if (explicitModel?.kind === "suppressed") { return undefined; } @@ -626,12 +636,12 @@ export function resolveModelWithRegistry(params: { return explicitModel.model; } - const pluginDynamicModel = resolvePluginDynamicModelWithRegistry(params); + const pluginDynamicModel = resolvePluginDynamicModelWithRegistry(normalizedParams); if (pluginDynamicModel) { return pluginDynamicModel; } - return resolveConfiguredFallbackModel(params); + return resolveConfiguredFallbackModel(normalizedParams); } export function resolveModel( @@ -651,13 +661,17 @@ export function resolveModel( authStorage: AuthStorage; modelRegistry: ModelRegistry; } { + const normalizedRef = { + provider, + model: normalizeStaticProviderModelId(normalizeProviderId(provider), modelId), + }; const resolvedAgentDir = agentDir ?? resolveOpenClawAgentDir(); const authStorage = options?.authStorage ?? discoverAuthStorage(resolvedAgentDir); const modelRegistry = options?.modelRegistry ?? discoverModels(authStorage, resolvedAgentDir); const runtimeHooks = resolveRuntimeHooks(options); const model = resolveModelWithRegistry({ - provider, - modelId, + provider: normalizedRef.provider, + modelId: normalizedRef.model, modelRegistry, cfg, agentDir: resolvedAgentDir, @@ -669,8 +683,8 @@ export function resolveModel( return { error: buildUnknownModelError({ - provider, - modelId, + provider: normalizedRef.provider, + modelId: normalizedRef.model, cfg, agentDir: resolvedAgentDir, runtimeHooks, @@ -698,13 +712,17 @@ export async function resolveModelAsync( authStorage: AuthStorage; modelRegistry: ModelRegistry; }> { + const normalizedRef = { + provider, + model: normalizeStaticProviderModelId(normalizeProviderId(provider), modelId), + }; const resolvedAgentDir = agentDir ?? resolveOpenClawAgentDir(); const authStorage = options?.authStorage ?? discoverAuthStorage(resolvedAgentDir); const modelRegistry = options?.modelRegistry ?? discoverModels(authStorage, resolvedAgentDir); const runtimeHooks = resolveRuntimeHooks(options); const explicitModel = resolveExplicitModelWithRegistry({ - provider, - modelId, + provider: normalizedRef.provider, + modelId: normalizedRef.model, modelRegistry, cfg, agentDir: resolvedAgentDir, @@ -713,8 +731,8 @@ export async function resolveModelAsync( if (explicitModel?.kind === "suppressed") { return { error: buildUnknownModelError({ - provider, - modelId, + provider: normalizedRef.provider, + modelId: normalizedRef.model, cfg, agentDir: resolvedAgentDir, runtimeHooks, @@ -723,26 +741,26 @@ export async function resolveModelAsync( modelRegistry, }; } - const providerConfig = resolveConfiguredProviderConfig(cfg, provider); + const providerConfig = resolveConfiguredProviderConfig(cfg, normalizedRef.provider); const resolveDynamicAttempt = async (attemptOptions?: { clearHookCache?: boolean }) => { if (attemptOptions?.clearHookCache) { runtimeHooks.clearProviderRuntimeHookCache(); } await runtimeHooks.prepareProviderDynamicModel({ - provider, + provider: normalizedRef.provider, config: cfg, context: { config: cfg, agentDir: resolvedAgentDir, - provider, - modelId, + provider: normalizedRef.provider, + modelId: normalizedRef.model, modelRegistry, providerConfig, }, }); return resolveModelWithRegistry({ - provider, - modelId, + provider: normalizedRef.provider, + modelId: normalizedRef.model, modelRegistry, cfg, agentDir: resolvedAgentDir, @@ -763,8 +781,8 @@ export async function resolveModelAsync( return { error: buildUnknownModelError({ - provider, - modelId, + provider: normalizedRef.provider, + modelId: normalizedRef.model, cfg, agentDir: resolvedAgentDir, runtimeHooks,