From afb6e4b1858f76ec1447b958df53aa8d547a3f3f Mon Sep 17 00:00:00 2001 From: Gustavo Madeira Santana Date: Sun, 15 Mar 2026 17:47:31 +0000 Subject: [PATCH] Plugins: extract provider auth and wizard flows --- src/agents/google-model-id.test.ts | 11 + src/agents/google-model-id.ts | 21 ++ src/agents/model-ref.test.ts | 38 ++++ src/agents/model-ref.ts | 94 ++++++++ src/agents/model-selection.ts | 108 +--------- src/agents/models-config.providers.ts | 23 +- src/commands/provider-auth-helpers.ts | 72 +------ src/extension-host/provider-auth.test.ts | 106 +++++++++ src/extension-host/provider-auth.ts | 82 +++++++ src/extension-host/provider-wizard.test.ts | 83 ++++++++ src/extension-host/provider-wizard.ts | 201 ++++++++++++++++++ .../providers/google/inline-data.ts | 2 +- src/plugins/provider-wizard.ts | 178 ++-------------- 13 files changed, 670 insertions(+), 349 deletions(-) create mode 100644 src/agents/google-model-id.test.ts create mode 100644 src/agents/google-model-id.ts create mode 100644 src/agents/model-ref.test.ts create mode 100644 src/agents/model-ref.ts create mode 100644 src/extension-host/provider-auth.test.ts create mode 100644 src/extension-host/provider-auth.ts create mode 100644 src/extension-host/provider-wizard.test.ts create mode 100644 src/extension-host/provider-wizard.ts diff --git a/src/agents/google-model-id.test.ts b/src/agents/google-model-id.test.ts new file mode 100644 index 00000000000..bae8a44a241 --- /dev/null +++ b/src/agents/google-model-id.test.ts @@ -0,0 +1,11 @@ +import { describe, expect, it } from "vitest"; +import { normalizeGoogleModelId } from "./google-model-id.js"; + +describe("normalizeGoogleModelId", () => { + it("preserves compatibility with legacy Gemini aliases", () => { + expect(normalizeGoogleModelId("gemini-3.1-flash")).toBe("gemini-3-flash-preview"); + expect(normalizeGoogleModelId("gemini-3.1-flash-preview")).toBe("gemini-3-flash-preview"); + expect(normalizeGoogleModelId("gemini-3.1-flash-lite")).toBe("gemini-3.1-flash-lite-preview"); + expect(normalizeGoogleModelId("gemini-3-pro")).toBe("gemini-3-pro-preview"); + }); +}); diff --git a/src/agents/google-model-id.ts b/src/agents/google-model-id.ts new file mode 100644 index 00000000000..c7cfac6f891 --- /dev/null +++ b/src/agents/google-model-id.ts @@ -0,0 +1,21 @@ +export function normalizeGoogleModelId(id: string): string { + if (id === "gemini-3-pro") { + return "gemini-3-pro-preview"; + } + if (id === "gemini-3-flash") { + return "gemini-3-flash-preview"; + } + if (id === "gemini-3.1-pro") { + return "gemini-3.1-pro-preview"; + } + if (id === "gemini-3.1-flash-lite") { + return "gemini-3.1-flash-lite-preview"; + } + // Preserve compatibility with earlier OpenClaw docs/config that pointed at a + // non-existent Gemini Flash preview ID. Google's current Flash text model is + // `gemini-3-flash-preview`. + if (id === "gemini-3.1-flash" || id === "gemini-3.1-flash-preview") { + return "gemini-3-flash-preview"; + } + return id; +} diff --git a/src/agents/model-ref.test.ts b/src/agents/model-ref.test.ts new file mode 100644 index 00000000000..b63e3f9ecd8 --- /dev/null +++ b/src/agents/model-ref.test.ts @@ -0,0 +1,38 @@ +import { describe, expect, it } from "vitest"; +import { modelKey, parseModelRef } from "./model-ref.js"; + +describe("modelKey", () => { + it("keeps canonical OpenRouter native ids without duplicating the provider", () => { + expect(modelKey("openrouter", "openrouter/hunter-alpha")).toBe("openrouter/hunter-alpha"); + }); +}); + +describe("parseModelRef", () => { + it("uses the default provider when omitted", () => { + expect(parseModelRef("claude-3-5-sonnet", "anthropic")).toEqual({ + provider: "anthropic", + model: "claude-3-5-sonnet", + }); + }); + + it("normalizes anthropic shorthand aliases", () => { + expect(parseModelRef("anthropic/opus-4.6", "openai")).toEqual({ + provider: "anthropic", + model: "claude-opus-4-6", + }); + }); + + it("preserves nested model ids after the provider prefix", () => { + expect(parseModelRef("nvidia/moonshotai/kimi-k2.5", "anthropic")).toEqual({ + provider: "nvidia", + model: "moonshotai/kimi-k2.5", + }); + }); + + it("normalizes OpenRouter-native model refs without duplicating the provider", () => { + expect(parseModelRef("openrouter/hunter-alpha", "anthropic")).toEqual({ + provider: "openrouter", + model: "openrouter/hunter-alpha", + }); + }); +}); diff --git a/src/agents/model-ref.ts b/src/agents/model-ref.ts new file mode 100644 index 00000000000..3803e6c0798 --- /dev/null +++ b/src/agents/model-ref.ts @@ -0,0 +1,94 @@ +import { normalizeGoogleModelId } from "./google-model-id.js"; +import { normalizeProviderId } from "./provider-id.js"; + +export type ModelRef = { + provider: string; + model: string; +}; + +export function modelKey(provider: string, model: string) { + const providerId = provider.trim(); + const modelId = model.trim(); + if (!providerId) { + return modelId; + } + if (!modelId) { + return providerId; + } + return modelId.toLowerCase().startsWith(`${providerId.toLowerCase()}/`) + ? modelId + : `${providerId}/${modelId}`; +} + +export function legacyModelKey(provider: string, model: string): string | null { + const providerId = provider.trim(); + const modelId = model.trim(); + if (!providerId || !modelId) { + return null; + } + const rawKey = `${providerId}/${modelId}`; + const canonicalKey = modelKey(providerId, modelId); + return rawKey === canonicalKey ? null : rawKey; +} + +function normalizeAnthropicModelId(model: string): string { + const trimmed = model.trim(); + if (!trimmed) { + return trimmed; + } + const lower = trimmed.toLowerCase(); + switch (lower) { + case "opus-4.6": + return "claude-opus-4-6"; + case "opus-4.5": + return "claude-opus-4-5"; + case "sonnet-4.6": + return "claude-sonnet-4-6"; + case "sonnet-4.5": + return "claude-sonnet-4-5"; + default: + return trimmed; + } +} + +function normalizeProviderModelId(provider: string, model: string): string { + if (provider === "anthropic") { + return normalizeAnthropicModelId(model); + } + if (provider === "vercel-ai-gateway" && !model.includes("/")) { + const normalizedAnthropicModel = normalizeAnthropicModelId(model); + if (normalizedAnthropicModel.startsWith("claude-")) { + return `anthropic/${normalizedAnthropicModel}`; + } + } + if (provider === "google" || provider === "google-vertex") { + return normalizeGoogleModelId(model); + } + if (provider === "openrouter" && !model.includes("/")) { + return `openrouter/${model}`; + } + return model; +} + +export function normalizeModelRef(provider: string, model: string): ModelRef { + const normalizedProvider = normalizeProviderId(provider); + const normalizedModel = normalizeProviderModelId(normalizedProvider, model.trim()); + return { provider: normalizedProvider, model: normalizedModel }; +} + +export function parseModelRef(raw: string, defaultProvider: string): ModelRef | null { + const trimmed = raw.trim(); + if (!trimmed) { + return null; + } + const slash = trimmed.indexOf("/"); + if (slash === -1) { + return normalizeModelRef(defaultProvider, trimmed); + } + const providerRaw = trimmed.slice(0, slash).trim(); + const model = trimmed.slice(slash + 1).trim(); + if (!providerRaw || !model) { + return null; + } + return normalizeModelRef(providerRaw, model); +} diff --git a/src/agents/model-selection.ts b/src/agents/model-selection.ts index d301cddde67..0237508c4d2 100644 --- a/src/agents/model-selection.ts +++ b/src/agents/model-selection.ts @@ -15,17 +15,19 @@ import { import { DEFAULT_MODEL, DEFAULT_PROVIDER } from "./defaults.js"; import type { ModelCatalogEntry } from "./model-catalog.js"; import { splitTrailingAuthProfile } from "./model-ref-profile.js"; -import { normalizeGoogleModelId } from "./models-config.providers.js"; +import { + legacyModelKey, + modelKey, + normalizeModelRef, + parseModelRef, + type ModelRef, +} from "./model-ref.js"; import { normalizeProviderId } from "./provider-id.js"; export { normalizeProviderId, normalizeProviderIdForAuth } from "./provider-id.js"; +export { legacyModelKey, modelKey, normalizeModelRef, parseModelRef } from "./model-ref.js"; const log = createSubsystemLogger("model-selection"); -export type ModelRef = { - provider: string; - model: string; -}; - export type ThinkLevel = "off" | "minimal" | "low" | "medium" | "high" | "xhigh" | "adaptive"; export type ModelAliasIndex = { @@ -37,31 +39,6 @@ function normalizeAliasKey(value: string): string { return value.trim().toLowerCase(); } -export function modelKey(provider: string, model: string) { - const providerId = provider.trim(); - const modelId = model.trim(); - if (!providerId) { - return modelId; - } - if (!modelId) { - return providerId; - } - return modelId.toLowerCase().startsWith(`${providerId.toLowerCase()}/`) - ? modelId - : `${providerId}/${modelId}`; -} - -export function legacyModelKey(provider: string, model: string): string | null { - const providerId = provider.trim(); - const modelId = model.trim(); - if (!providerId || !modelId) { - return null; - } - const rawKey = `${providerId}/${modelId}`; - const canonicalKey = modelKey(providerId, modelId); - return rawKey === canonicalKey ? null : rawKey; -} - export function findNormalizedProviderValue( entries: Record | undefined, provider: string, @@ -101,75 +78,6 @@ export function isCliProvider(provider: string, cfg?: OpenClawConfig): boolean { return Object.keys(backends).some((key) => normalizeProviderId(key) === normalized); } -function normalizeAnthropicModelId(model: string): string { - const trimmed = model.trim(); - if (!trimmed) { - return trimmed; - } - const lower = trimmed.toLowerCase(); - // Keep alias resolution local so bundled startup paths cannot trip a TDZ on - // a module-level alias table while config parsing is still initializing. - switch (lower) { - case "opus-4.6": - return "claude-opus-4-6"; - case "opus-4.5": - return "claude-opus-4-5"; - case "sonnet-4.6": - return "claude-sonnet-4-6"; - case "sonnet-4.5": - return "claude-sonnet-4-5"; - default: - return trimmed; - } -} - -function normalizeProviderModelId(provider: string, model: string): string { - if (provider === "anthropic") { - return normalizeAnthropicModelId(model); - } - if (provider === "vercel-ai-gateway" && !model.includes("/")) { - // Allow Vercel-specific Claude refs without an upstream prefix. - const normalizedAnthropicModel = normalizeAnthropicModelId(model); - if (normalizedAnthropicModel.startsWith("claude-")) { - return `anthropic/${normalizedAnthropicModel}`; - } - } - if (provider === "google" || provider === "google-vertex") { - return normalizeGoogleModelId(model); - } - // OpenRouter-native models (e.g. "openrouter/aurora-alpha") need the full - // "openrouter/" as the model ID sent to the API. Models from external - // providers already contain a slash (e.g. "anthropic/claude-sonnet-4-5") and - // are passed through as-is (#12924). - if (provider === "openrouter" && !model.includes("/")) { - return `openrouter/${model}`; - } - return model; -} - -export function normalizeModelRef(provider: string, model: string): ModelRef { - const normalizedProvider = normalizeProviderId(provider); - const normalizedModel = normalizeProviderModelId(normalizedProvider, model.trim()); - return { provider: normalizedProvider, model: normalizedModel }; -} - -export function parseModelRef(raw: string, defaultProvider: string): ModelRef | null { - const trimmed = raw.trim(); - if (!trimmed) { - return null; - } - const slash = trimmed.indexOf("/"); - if (slash === -1) { - return normalizeModelRef(defaultProvider, trimmed); - } - const providerRaw = trimmed.slice(0, slash).trim(); - const model = trimmed.slice(slash + 1).trim(); - if (!providerRaw || !model) { - return null; - } - return normalizeModelRef(providerRaw, model); -} - export function inferUniqueProviderFromConfiguredModels(params: { cfg: OpenClawConfig; model: string; diff --git a/src/agents/models-config.providers.ts b/src/agents/models-config.providers.ts index b4ef8f4b0b1..053357da925 100644 --- a/src/agents/models-config.providers.ts +++ b/src/agents/models-config.providers.ts @@ -12,6 +12,7 @@ import { buildCloudflareAiGatewayModelDefinition, resolveCloudflareAiGatewayBaseUrl, } from "./cloudflare-ai-gateway.js"; +import { normalizeGoogleModelId } from "./google-model-id.js"; import { buildHuggingfaceProvider, buildKilocodeProviderWithDiscovery, @@ -223,27 +224,7 @@ function resolveApiKeyFromProfiles(params: { return undefined; } -export function normalizeGoogleModelId(id: string): string { - if (id === "gemini-3-pro") { - return "gemini-3-pro-preview"; - } - if (id === "gemini-3-flash") { - return "gemini-3-flash-preview"; - } - if (id === "gemini-3.1-pro") { - return "gemini-3.1-pro-preview"; - } - if (id === "gemini-3.1-flash-lite") { - return "gemini-3.1-flash-lite-preview"; - } - // Preserve compatibility with earlier OpenClaw docs/config that pointed at a - // non-existent Gemini Flash preview ID. Google's current Flash text model is - // `gemini-3-flash-preview`. - if (id === "gemini-3.1-flash" || id === "gemini-3.1-flash-preview") { - return "gemini-3-flash-preview"; - } - return id; -} +export { normalizeGoogleModelId } from "./google-model-id.js"; const ANTIGRAVITY_BARE_PRO_IDS = new Set(["gemini-3-pro", "gemini-3.1-pro", "gemini-3-1-pro"]); diff --git a/src/commands/provider-auth-helpers.ts b/src/commands/provider-auth-helpers.ts index f36c1c3de73..23ba7da64fc 100644 --- a/src/commands/provider-auth-helpers.ts +++ b/src/commands/provider-auth-helpers.ts @@ -1,82 +1,30 @@ -import { normalizeProviderId } from "../agents/model-selection.js"; import type { OpenClawConfig } from "../config/config.js"; +import { + applyExtensionHostDefaultModel, + mergeExtensionHostConfigPatch, + pickExtensionHostAuthMethod, + resolveExtensionHostProviderMatch, +} from "../extension-host/provider-auth.js"; import type { ProviderAuthMethod, ProviderPlugin } from "../plugins/types.js"; export function resolveProviderMatch( providers: ProviderPlugin[], rawProvider?: string, ): ProviderPlugin | null { - const raw = rawProvider?.trim(); - if (!raw) { - return null; - } - const normalized = normalizeProviderId(raw); - return ( - providers.find((provider) => normalizeProviderId(provider.id) === normalized) ?? - providers.find( - (provider) => - provider.aliases?.some((alias) => normalizeProviderId(alias) === normalized) ?? false, - ) ?? - null - ); + return resolveExtensionHostProviderMatch(providers, rawProvider); } export function pickAuthMethod( provider: ProviderPlugin, rawMethod?: string, ): ProviderAuthMethod | null { - const raw = rawMethod?.trim(); - if (!raw) { - return null; - } - const normalized = raw.toLowerCase(); - return ( - provider.auth.find((method) => method.id.toLowerCase() === normalized) ?? - provider.auth.find((method) => method.label.toLowerCase() === normalized) ?? - null - ); -} - -function isPlainRecord(value: unknown): value is Record { - return Boolean(value && typeof value === "object" && !Array.isArray(value)); + return pickExtensionHostAuthMethod(provider, rawMethod); } export function mergeConfigPatch(base: T, patch: unknown): T { - if (!isPlainRecord(base) || !isPlainRecord(patch)) { - return patch as T; - } - - const next: Record = { ...base }; - for (const [key, value] of Object.entries(patch)) { - const existing = next[key]; - if (isPlainRecord(existing) && isPlainRecord(value)) { - next[key] = mergeConfigPatch(existing, value); - } else { - next[key] = value; - } - } - return next as T; + return mergeExtensionHostConfigPatch(base, patch); } export function applyDefaultModel(cfg: OpenClawConfig, model: string): OpenClawConfig { - const models = { ...cfg.agents?.defaults?.models }; - models[model] = models[model] ?? {}; - - const existingModel = cfg.agents?.defaults?.model; - return { - ...cfg, - agents: { - ...cfg.agents, - defaults: { - ...cfg.agents?.defaults, - models, - model: { - ...(existingModel && typeof existingModel === "object" && "fallbacks" in existingModel - ? { fallbacks: (existingModel as { fallbacks?: string[] }).fallbacks } - : undefined), - primary: model, - }, - }, - }, - }; + return applyExtensionHostDefaultModel(cfg, model); } diff --git a/src/extension-host/provider-auth.test.ts b/src/extension-host/provider-auth.test.ts new file mode 100644 index 00000000000..20844620745 --- /dev/null +++ b/src/extension-host/provider-auth.test.ts @@ -0,0 +1,106 @@ +import { describe, expect, it, vi } from "vitest"; +import type { ProviderPlugin } from "../plugins/types.js"; +import { + applyExtensionHostDefaultModel, + mergeExtensionHostConfigPatch, + pickExtensionHostAuthMethod, + resolveExtensionHostProviderMatch, +} from "./provider-auth.js"; + +function makeProvider(overrides: Partial & Pick) { + return { + auth: [], + ...overrides, + } satisfies ProviderPlugin; +} + +describe("resolveExtensionHostProviderMatch", () => { + it("matches providers by normalized id and aliases", () => { + const providers = [ + makeProvider({ + id: "openrouter", + label: "OpenRouter", + aliases: ["Open Router"], + }), + ]; + + expect(resolveExtensionHostProviderMatch(providers, "openrouter")?.id).toBe("openrouter"); + expect(resolveExtensionHostProviderMatch(providers, " Open Router ")?.id).toBe("openrouter"); + expect(resolveExtensionHostProviderMatch(providers, "missing")).toBeNull(); + }); +}); + +describe("pickExtensionHostAuthMethod", () => { + it("matches auth methods by id or label", () => { + const provider = makeProvider({ + id: "ollama", + label: "Ollama", + auth: [ + { id: "local", label: "Local", kind: "custom", run: vi.fn() }, + { id: "cloud", label: "Cloud", kind: "custom", run: vi.fn() }, + ], + }); + + expect(pickExtensionHostAuthMethod(provider, "local")?.id).toBe("local"); + expect(pickExtensionHostAuthMethod(provider, "cloud")?.id).toBe("cloud"); + expect(pickExtensionHostAuthMethod(provider, "Cloud")?.id).toBe("cloud"); + expect(pickExtensionHostAuthMethod(provider, "missing")).toBeNull(); + }); +}); + +describe("mergeExtensionHostConfigPatch", () => { + it("deep-merges plain record config patches", () => { + expect( + mergeExtensionHostConfigPatch( + { + models: { providers: { ollama: { baseUrl: "http://127.0.0.1:11434" } } }, + auth: { profiles: { existing: { provider: "anthropic" } } }, + }, + { + models: { providers: { ollama: { api: "ollama" } } }, + auth: { profiles: { fresh: { provider: "ollama" } } }, + }, + ), + ).toEqual({ + models: { providers: { ollama: { baseUrl: "http://127.0.0.1:11434", api: "ollama" } } }, + auth: { + profiles: { + existing: { provider: "anthropic" }, + fresh: { provider: "ollama" }, + }, + }, + }); + }); +}); + +describe("applyExtensionHostDefaultModel", () => { + it("sets the primary model while preserving fallback config", () => { + expect( + applyExtensionHostDefaultModel( + { + agents: { + defaults: { + model: { + primary: "anthropic/claude-sonnet-4-5", + fallbacks: ["openai/gpt-5"], + }, + }, + }, + }, + "ollama/qwen3:4b", + ), + ).toEqual({ + agents: { + defaults: { + models: { + "ollama/qwen3:4b": {}, + }, + model: { + primary: "ollama/qwen3:4b", + fallbacks: ["openai/gpt-5"], + }, + }, + }, + }); + }); +}); diff --git a/src/extension-host/provider-auth.ts b/src/extension-host/provider-auth.ts new file mode 100644 index 00000000000..1d9a926d365 --- /dev/null +++ b/src/extension-host/provider-auth.ts @@ -0,0 +1,82 @@ +import { normalizeProviderId } from "../agents/provider-id.js"; +import type { OpenClawConfig } from "../config/config.js"; +import type { ProviderAuthMethod, ProviderPlugin } from "../plugins/types.js"; + +export function resolveExtensionHostProviderMatch( + providers: ProviderPlugin[], + rawProvider?: string, +): ProviderPlugin | null { + const raw = rawProvider?.trim(); + if (!raw) { + return null; + } + const normalized = normalizeProviderId(raw); + return ( + providers.find((provider) => normalizeProviderId(provider.id) === normalized) ?? + providers.find( + (provider) => + provider.aliases?.some((alias) => normalizeProviderId(alias) === normalized) ?? false, + ) ?? + null + ); +} + +export function pickExtensionHostAuthMethod( + provider: ProviderPlugin, + rawMethod?: string, +): ProviderAuthMethod | null { + const raw = rawMethod?.trim(); + if (!raw) { + return null; + } + const normalized = raw.toLowerCase(); + return ( + provider.auth.find((method) => method.id.toLowerCase() === normalized) ?? + provider.auth.find((method) => method.label.toLowerCase() === normalized) ?? + null + ); +} + +function isPlainRecord(value: unknown): value is Record { + return Boolean(value && typeof value === "object" && !Array.isArray(value)); +} + +export function mergeExtensionHostConfigPatch(base: T, patch: unknown): T { + if (!isPlainRecord(base) || !isPlainRecord(patch)) { + return patch as T; + } + + const next: Record = { ...base }; + for (const [key, value] of Object.entries(patch)) { + const existing = next[key]; + if (isPlainRecord(existing) && isPlainRecord(value)) { + next[key] = mergeExtensionHostConfigPatch(existing, value); + } else { + next[key] = value; + } + } + return next as T; +} + +export function applyExtensionHostDefaultModel(cfg: OpenClawConfig, model: string): OpenClawConfig { + const models = { ...cfg.agents?.defaults?.models }; + models[model] = models[model] ?? {}; + + const existingModel = cfg.agents?.defaults?.model; + return { + ...cfg, + agents: { + ...cfg.agents, + defaults: { + ...cfg.agents?.defaults, + models, + model: { + ...(existingModel && typeof existingModel === "object" && "fallbacks" in existingModel + ? { fallbacks: (existingModel as { fallbacks?: string[] }).fallbacks } + : undefined), + primary: model, + }, + }, + }, + }; +} diff --git a/src/extension-host/provider-wizard.test.ts b/src/extension-host/provider-wizard.test.ts new file mode 100644 index 00000000000..000e4b6af8a --- /dev/null +++ b/src/extension-host/provider-wizard.test.ts @@ -0,0 +1,83 @@ +import { describe, expect, it, vi } from "vitest"; +import type { ProviderPlugin } from "../plugins/types.js"; +import { + buildExtensionHostProviderMethodChoice, + resolveExtensionHostProviderChoice, + resolveExtensionHostProviderModelPickerEntries, + resolveExtensionHostProviderWizardOptions, +} from "./provider-wizard.js"; + +function makeProvider(overrides: Partial & Pick) { + return { + auth: [], + ...overrides, + } satisfies ProviderPlugin; +} + +describe("resolveExtensionHostProviderWizardOptions", () => { + it("uses explicit onboarding choice ids and bound method ids", () => { + const provider = makeProvider({ + id: "vllm", + label: "vLLM", + auth: [ + { id: "local", label: "Local", kind: "custom", run: vi.fn() }, + { id: "cloud", label: "Cloud", kind: "custom", run: vi.fn() }, + ], + wizard: { + onboarding: { + choiceId: "self-hosted-vllm", + methodId: "local", + choiceLabel: "vLLM local", + groupId: "local-runtimes", + groupLabel: "Local runtimes", + }, + }, + }); + + expect(resolveExtensionHostProviderWizardOptions([provider])).toEqual([ + { + value: "self-hosted-vllm", + label: "vLLM local", + groupId: "local-runtimes", + groupLabel: "Local runtimes", + }, + ]); + expect( + resolveExtensionHostProviderChoice({ + providers: [provider], + choice: "self-hosted-vllm", + }), + ).toEqual({ + provider, + method: provider.auth[0], + }); + }); +}); + +describe("resolveExtensionHostProviderModelPickerEntries", () => { + it("builds model-picker entries from provider metadata", () => { + const provider = makeProvider({ + id: "sglang", + label: "SGLang", + auth: [ + { id: "server", label: "Server", kind: "custom", run: vi.fn() }, + { id: "cloud", label: "Cloud", kind: "custom", run: vi.fn() }, + ], + wizard: { + modelPicker: { + label: "SGLang server", + hint: "OpenAI-compatible local runtime", + methodId: "server", + }, + }, + }); + + expect(resolveExtensionHostProviderModelPickerEntries([provider])).toEqual([ + { + value: buildExtensionHostProviderMethodChoice("sglang", "server"), + label: "SGLang server", + hint: "OpenAI-compatible local runtime", + }, + ]); + }); +}); diff --git a/src/extension-host/provider-wizard.ts b/src/extension-host/provider-wizard.ts new file mode 100644 index 00000000000..4fc87576038 --- /dev/null +++ b/src/extension-host/provider-wizard.ts @@ -0,0 +1,201 @@ +import { normalizeProviderId } from "../agents/provider-id.js"; +import type { + ProviderAuthMethod, + ProviderPlugin, + ProviderPluginWizardModelPicker, + ProviderPluginWizardOnboarding, +} from "../plugins/types.js"; + +export const EXTENSION_HOST_PROVIDER_CHOICE_PREFIX = "provider-plugin:"; + +export type ExtensionHostProviderWizardOption = { + value: string; + label: string; + hint?: string; + groupId: string; + groupLabel: string; + groupHint?: string; +}; + +export type ExtensionHostProviderModelPickerEntry = { + value: string; + label: string; + hint?: string; +}; + +function normalizeChoiceId(choiceId: string): string { + return choiceId.trim(); +} + +function resolveWizardOnboardingChoiceId( + provider: ProviderPlugin, + wizard: ProviderPluginWizardOnboarding, +): string { + const explicit = wizard.choiceId?.trim(); + if (explicit) { + return explicit; + } + const explicitMethodId = wizard.methodId?.trim(); + if (explicitMethodId) { + return buildExtensionHostProviderMethodChoice(provider.id, explicitMethodId); + } + if (provider.auth.length === 1) { + return provider.id; + } + return buildExtensionHostProviderMethodChoice(provider.id, provider.auth[0]?.id ?? "default"); +} + +function resolveMethodById( + provider: ProviderPlugin, + methodId?: string, +): ProviderAuthMethod | undefined { + const normalizedMethodId = methodId?.trim().toLowerCase(); + if (!normalizedMethodId) { + return provider.auth[0]; + } + return provider.auth.find((method) => method.id.trim().toLowerCase() === normalizedMethodId); +} + +function buildOnboardingOptionForMethod(params: { + provider: ProviderPlugin; + wizard: ProviderPluginWizardOnboarding; + method: ProviderAuthMethod; + value: string; +}): ExtensionHostProviderWizardOption { + const normalizedGroupId = params.wizard.groupId?.trim() || params.provider.id; + return { + value: normalizeChoiceId(params.value), + label: + params.wizard.choiceLabel?.trim() || + (params.provider.auth.length === 1 ? params.provider.label : params.method.label), + hint: params.wizard.choiceHint?.trim() || params.method.hint, + groupId: normalizedGroupId, + groupLabel: params.wizard.groupLabel?.trim() || params.provider.label, + groupHint: params.wizard.groupHint?.trim(), + }; +} + +function resolveModelPickerChoiceValue( + provider: ProviderPlugin, + modelPicker: ProviderPluginWizardModelPicker, +): string { + const explicitMethodId = modelPicker.methodId?.trim(); + if (explicitMethodId) { + return buildExtensionHostProviderMethodChoice(provider.id, explicitMethodId); + } + if (provider.auth.length === 1) { + return provider.id; + } + return buildExtensionHostProviderMethodChoice(provider.id, provider.auth[0]?.id ?? "default"); +} + +export function buildExtensionHostProviderMethodChoice( + providerId: string, + methodId: string, +): string { + return `${EXTENSION_HOST_PROVIDER_CHOICE_PREFIX}${providerId.trim()}:${methodId.trim()}`; +} + +export function resolveExtensionHostProviderWizardOptions( + providers: ProviderPlugin[], +): ExtensionHostProviderWizardOption[] { + const options: ExtensionHostProviderWizardOption[] = []; + + for (const provider of providers) { + const wizard = provider.wizard?.onboarding; + if (!wizard) { + continue; + } + const explicitMethod = resolveMethodById(provider, wizard.methodId); + if (explicitMethod) { + options.push( + buildOnboardingOptionForMethod({ + provider, + wizard, + method: explicitMethod, + value: resolveWizardOnboardingChoiceId(provider, wizard), + }), + ); + continue; + } + + for (const method of provider.auth) { + options.push( + buildOnboardingOptionForMethod({ + provider, + wizard, + method, + value: buildExtensionHostProviderMethodChoice(provider.id, method.id), + }), + ); + } + } + + return options; +} + +export function resolveExtensionHostProviderModelPickerEntries( + providers: ProviderPlugin[], +): ExtensionHostProviderModelPickerEntry[] { + const entries: ExtensionHostProviderModelPickerEntry[] = []; + + for (const provider of providers) { + const modelPicker = provider.wizard?.modelPicker; + if (!modelPicker) { + continue; + } + entries.push({ + value: resolveModelPickerChoiceValue(provider, modelPicker), + label: modelPicker.label?.trim() || `${provider.label} (custom)`, + hint: modelPicker.hint?.trim(), + }); + } + + return entries; +} + +export function resolveExtensionHostProviderChoice(params: { + providers: ProviderPlugin[]; + choice: string; +}): { provider: ProviderPlugin; method: ProviderAuthMethod } | null { + const choice = params.choice.trim(); + if (!choice) { + return null; + } + + if (choice.startsWith(EXTENSION_HOST_PROVIDER_CHOICE_PREFIX)) { + const payload = choice.slice(EXTENSION_HOST_PROVIDER_CHOICE_PREFIX.length); + const separator = payload.indexOf(":"); + const providerId = separator >= 0 ? payload.slice(0, separator) : payload; + const methodId = separator >= 0 ? payload.slice(separator + 1) : undefined; + const provider = params.providers.find( + (entry) => normalizeProviderId(entry.id) === normalizeProviderId(providerId), + ); + if (!provider) { + return null; + } + const method = resolveMethodById(provider, methodId); + return method ? { provider, method } : null; + } + + for (const provider of params.providers) { + const onboarding = provider.wizard?.onboarding; + if (onboarding) { + const onboardingChoiceId = resolveWizardOnboardingChoiceId(provider, onboarding); + if (normalizeChoiceId(onboardingChoiceId) === choice) { + const method = resolveMethodById(provider, onboarding.methodId); + if (method) { + return { provider, method }; + } + } + } + if ( + normalizeProviderId(provider.id) === normalizeProviderId(choice) && + provider.auth.length > 0 + ) { + return { provider, method: provider.auth[0] }; + } + } + + return null; +} diff --git a/src/media-understanding/providers/google/inline-data.ts b/src/media-understanding/providers/google/inline-data.ts index 69fd41871e8..3b76822ce04 100644 --- a/src/media-understanding/providers/google/inline-data.ts +++ b/src/media-understanding/providers/google/inline-data.ts @@ -1,4 +1,4 @@ -import { normalizeGoogleModelId } from "../../../agents/models-config.providers.js"; +import { normalizeGoogleModelId } from "../../../agents/google-model-id.js"; import { parseGeminiAuth } from "../../../infra/gemini-auth.js"; import { assertOkOrThrowHttpError, normalizeBaseUrl, postJsonRequest } from "../shared.js"; diff --git a/src/plugins/provider-wizard.ts b/src/plugins/provider-wizard.ts index 4b02fcd3cf7..ac5ab29e2f1 100644 --- a/src/plugins/provider-wizard.ts +++ b/src/plugins/provider-wizard.ts @@ -1,15 +1,16 @@ import { DEFAULT_PROVIDER } from "../agents/defaults.js"; -import { parseModelRef } from "../agents/model-selection.js"; -import { normalizeProviderId } from "../agents/model-selection.js"; +import { parseModelRef } from "../agents/model-ref.js"; +import { normalizeProviderId } from "../agents/provider-id.js"; import type { OpenClawConfig } from "../config/config.js"; +import { + buildExtensionHostProviderMethodChoice, + resolveExtensionHostProviderChoice, + resolveExtensionHostProviderModelPickerEntries, + resolveExtensionHostProviderWizardOptions, +} from "../extension-host/provider-wizard.js"; import type { WizardPrompter } from "../wizard/prompts.js"; import { resolvePluginProviders } from "./providers.js"; -import type { - ProviderAuthMethod, - ProviderPlugin, - ProviderPluginWizardModelPicker, - ProviderPluginWizardOnboarding, -} from "./types.js"; +import type { ProviderAuthMethod, ProviderPlugin } from "./types.js"; export const PROVIDER_PLUGIN_CHOICE_PREFIX = "provider-plugin:"; @@ -28,60 +29,8 @@ export type ProviderModelPickerEntry = { hint?: string; }; -function normalizeChoiceId(choiceId: string): string { - return choiceId.trim(); -} - -function resolveWizardOnboardingChoiceId( - provider: ProviderPlugin, - wizard: ProviderPluginWizardOnboarding, -): string { - const explicit = wizard.choiceId?.trim(); - if (explicit) { - return explicit; - } - const explicitMethodId = wizard.methodId?.trim(); - if (explicitMethodId) { - return buildProviderPluginMethodChoice(provider.id, explicitMethodId); - } - if (provider.auth.length === 1) { - return provider.id; - } - return buildProviderPluginMethodChoice(provider.id, provider.auth[0]?.id ?? "default"); -} - -function resolveMethodById( - provider: ProviderPlugin, - methodId?: string, -): ProviderAuthMethod | undefined { - const normalizedMethodId = methodId?.trim().toLowerCase(); - if (!normalizedMethodId) { - return provider.auth[0]; - } - return provider.auth.find((method) => method.id.trim().toLowerCase() === normalizedMethodId); -} - -function buildOnboardingOptionForMethod(params: { - provider: ProviderPlugin; - wizard: ProviderPluginWizardOnboarding; - method: ProviderAuthMethod; - value: string; -}): ProviderWizardOption { - const normalizedGroupId = params.wizard.groupId?.trim() || params.provider.id; - return { - value: normalizeChoiceId(params.value), - label: - params.wizard.choiceLabel?.trim() || - (params.provider.auth.length === 1 ? params.provider.label : params.method.label), - hint: params.wizard.choiceHint?.trim() || params.method.hint, - groupId: normalizedGroupId, - groupLabel: params.wizard.groupLabel?.trim() || params.provider.label, - groupHint: params.wizard.groupHint?.trim(), - }; -} - export function buildProviderPluginMethodChoice(providerId: string, methodId: string): string { - return `${PROVIDER_PLUGIN_CHOICE_PREFIX}${providerId.trim()}:${methodId.trim()}`; + return buildExtensionHostProviderMethodChoice(providerId, methodId); } export function resolveProviderWizardOptions(params: { @@ -89,54 +38,7 @@ export function resolveProviderWizardOptions(params: { workspaceDir?: string; env?: NodeJS.ProcessEnv; }): ProviderWizardOption[] { - const providers = resolvePluginProviders(params); - const options: ProviderWizardOption[] = []; - - for (const provider of providers) { - const wizard = provider.wizard?.onboarding; - if (!wizard) { - continue; - } - const explicitMethod = resolveMethodById(provider, wizard.methodId); - if (explicitMethod) { - options.push( - buildOnboardingOptionForMethod({ - provider, - wizard, - method: explicitMethod, - value: resolveWizardOnboardingChoiceId(provider, wizard), - }), - ); - continue; - } - - for (const method of provider.auth) { - options.push( - buildOnboardingOptionForMethod({ - provider, - wizard, - method, - value: buildProviderPluginMethodChoice(provider.id, method.id), - }), - ); - } - } - - return options; -} - -function resolveModelPickerChoiceValue( - provider: ProviderPlugin, - modelPicker: ProviderPluginWizardModelPicker, -): string { - const explicitMethodId = modelPicker.methodId?.trim(); - if (explicitMethodId) { - return buildProviderPluginMethodChoice(provider.id, explicitMethodId); - } - if (provider.auth.length === 1) { - return provider.id; - } - return buildProviderPluginMethodChoice(provider.id, provider.auth[0]?.id ?? "default"); + return resolveExtensionHostProviderWizardOptions(resolvePluginProviders(params)); } export function resolveProviderModelPickerEntries(params: { @@ -144,68 +46,14 @@ export function resolveProviderModelPickerEntries(params: { workspaceDir?: string; env?: NodeJS.ProcessEnv; }): ProviderModelPickerEntry[] { - const providers = resolvePluginProviders(params); - const entries: ProviderModelPickerEntry[] = []; - - for (const provider of providers) { - const modelPicker = provider.wizard?.modelPicker; - if (!modelPicker) { - continue; - } - entries.push({ - value: resolveModelPickerChoiceValue(provider, modelPicker), - label: modelPicker.label?.trim() || `${provider.label} (custom)`, - hint: modelPicker.hint?.trim(), - }); - } - - return entries; + return resolveExtensionHostProviderModelPickerEntries(resolvePluginProviders(params)); } export function resolveProviderPluginChoice(params: { providers: ProviderPlugin[]; choice: string; }): { provider: ProviderPlugin; method: ProviderAuthMethod } | null { - const choice = params.choice.trim(); - if (!choice) { - return null; - } - - if (choice.startsWith(PROVIDER_PLUGIN_CHOICE_PREFIX)) { - const payload = choice.slice(PROVIDER_PLUGIN_CHOICE_PREFIX.length); - const separator = payload.indexOf(":"); - const providerId = separator >= 0 ? payload.slice(0, separator) : payload; - const methodId = separator >= 0 ? payload.slice(separator + 1) : undefined; - const provider = params.providers.find( - (entry) => normalizeProviderId(entry.id) === normalizeProviderId(providerId), - ); - if (!provider) { - return null; - } - const method = resolveMethodById(provider, methodId); - return method ? { provider, method } : null; - } - - for (const provider of params.providers) { - const onboarding = provider.wizard?.onboarding; - if (onboarding) { - const onboardingChoiceId = resolveWizardOnboardingChoiceId(provider, onboarding); - if (normalizeChoiceId(onboardingChoiceId) === choice) { - const method = resolveMethodById(provider, onboarding.methodId); - if (method) { - return { provider, method }; - } - } - } - if ( - normalizeProviderId(provider.id) === normalizeProviderId(choice) && - provider.auth.length > 0 - ) { - return { provider, method: provider.auth[0] }; - } - } - - return null; + return resolveExtensionHostProviderChoice(params); } export async function runProviderModelSelectedHook(params: {