From fcb6dd911cffd873a8da5d8626c143f2dae29d2f Mon Sep 17 00:00:00 2001 From: Gustavo Madeira Santana Date: Sun, 15 Mar 2026 21:40:25 +0000 Subject: [PATCH] Memory: add embedding fallback policy --- .../embedding-manager-runtime.test.ts | 65 +++++++++++++++++++ .../embedding-manager-runtime.ts | 41 ++++-------- .../embedding-runtime-backends.ts | 24 +++++++ .../embedding-runtime-policy.test.ts | 44 +++++++++++++ .../embedding-runtime-policy.ts | 33 ++++++++++ .../embedding-runtime-registry.test.ts | 36 +++++++++- .../embedding-runtime-registry.ts | 20 +++++- .../runtime-backend-catalog.test.ts | 13 +++- src/extension-host/runtime-backend-catalog.ts | 2 + 9 files changed, 242 insertions(+), 36 deletions(-) create mode 100644 src/extension-host/embedding-manager-runtime.test.ts create mode 100644 src/extension-host/embedding-runtime-policy.test.ts create mode 100644 src/extension-host/embedding-runtime-policy.ts diff --git a/src/extension-host/embedding-manager-runtime.test.ts b/src/extension-host/embedding-manager-runtime.test.ts new file mode 100644 index 00000000000..95af8df8936 --- /dev/null +++ b/src/extension-host/embedding-manager-runtime.test.ts @@ -0,0 +1,65 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const createEmbeddingProvider = vi.hoisted(() => vi.fn()); +const resolveAgentDir = vi.hoisted(() => vi.fn(() => "/tmp/agent")); + +vi.mock("./embedding-runtime.js", () => ({ + createEmbeddingProvider, +})); + +vi.mock("../agents/agent-scope.js", () => ({ + resolveAgentDir, +})); + +describe("embedding-manager-runtime", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("uses the shared fallback policy for manager fallback activation", async () => { + createEmbeddingProvider.mockResolvedValue({ + provider: { + id: "ollama", + model: "nomic-embed-text", + embedQuery: vi.fn(), + embedBatch: vi.fn(), + }, + ollama: { kind: "ollama" }, + }); + + const { activateEmbeddingManagerFallbackProvider } = + await import("./embedding-manager-runtime.js"); + const result = await activateEmbeddingManagerFallbackProvider({ + cfg: {} as never, + agentId: "main", + settings: { + fallback: "ollama", + model: "text-embedding-3-small", + outputDimensionality: undefined, + remote: undefined, + local: undefined, + }, + state: { + provider: { + id: "openai", + model: "text-embedding-3-small", + embedQuery: vi.fn(), + embedBatch: vi.fn(), + }, + }, + reason: "forced fallback", + }); + + expect(createEmbeddingProvider).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "ollama", + model: "nomic-embed-text", + fallback: "none", + }), + ); + expect(result).toMatchObject({ + fallbackFrom: "openai", + fallbackReason: "forced fallback", + }); + }); +}); diff --git a/src/extension-host/embedding-manager-runtime.ts b/src/extension-host/embedding-manager-runtime.ts index deff3f18ddb..2f75caa05e0 100644 --- a/src/extension-host/embedding-manager-runtime.ts +++ b/src/extension-host/embedding-manager-runtime.ts @@ -1,11 +1,7 @@ import { resolveAgentDir } from "../agents/agent-scope.js"; import type { ResolvedMemorySearchConfig } from "../agents/memory-search.js"; import type { OpenClawConfig } from "../config/config.js"; -import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "../memory/embeddings-gemini.js"; -import { DEFAULT_MISTRAL_EMBEDDING_MODEL } from "../memory/embeddings-mistral.js"; -import { DEFAULT_OLLAMA_EMBEDDING_MODEL } from "../memory/embeddings-ollama.js"; -import { DEFAULT_OPENAI_EMBEDDING_MODEL } from "../memory/embeddings-openai.js"; -import { DEFAULT_VOYAGE_EMBEDDING_MODEL } from "../memory/embeddings-voyage.js"; +import { resolveExtensionHostEmbeddingFallbackPolicy } from "./embedding-runtime-policy.js"; import { createEmbeddingProvider, type EmbeddingProvider, @@ -72,18 +68,25 @@ export async function activateEmbeddingManagerFallbackProvider(params: { state: EmbeddingManagerRuntimeState; reason: string; }): Promise { - const fallback = params.settings.fallback; const { provider, fallbackFrom } = params.state; - if (!fallback || fallback === "none" || !provider || fallback === provider.id || fallbackFrom) { + if (!provider || fallbackFrom) { + return null; + } + const fallbackPolicy = resolveExtensionHostEmbeddingFallbackPolicy({ + requestedProvider: provider.id as EmbeddingProviderId, + fallback: params.settings.fallback, + configuredModel: params.settings.model, + }); + if (!fallbackPolicy) { return null; } const result = await createEmbeddingProvider({ config: params.cfg, agentDir: resolveAgentDir(params.cfg, params.agentId), - provider: fallback, + provider: fallbackPolicy.provider, remote: params.settings.remote, - model: resolveEmbeddingFallbackModel(fallback, params.settings.model), + model: fallbackPolicy.model, outputDimensionality: params.settings.outputDimensionality, fallback: "none", local: params.settings.local, @@ -100,23 +103,3 @@ export async function activateEmbeddingManagerFallbackProvider(params: { ollama: result.ollama, }; } - -function resolveEmbeddingFallbackModel( - fallback: Exclude, - configuredModel: string, -): string { - switch (fallback) { - case "gemini": - return DEFAULT_GEMINI_EMBEDDING_MODEL; - case "openai": - return DEFAULT_OPENAI_EMBEDDING_MODEL; - case "voyage": - return DEFAULT_VOYAGE_EMBEDDING_MODEL; - case "mistral": - return DEFAULT_MISTRAL_EMBEDDING_MODEL; - case "ollama": - return DEFAULT_OLLAMA_EMBEDDING_MODEL; - case "local": - return configuredModel; - } -} diff --git a/src/extension-host/embedding-runtime-backends.ts b/src/extension-host/embedding-runtime-backends.ts index 7de07359def..a0cdc8eb328 100644 --- a/src/extension-host/embedding-runtime-backends.ts +++ b/src/extension-host/embedding-runtime-backends.ts @@ -2,6 +2,11 @@ import type { EmbeddingProviderId } from "./embedding-runtime-types.js"; export const DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL = "hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf"; +export const DEFAULT_EXTENSION_HOST_OPENAI_EMBEDDING_MODEL = "text-embedding-3-small"; +export const DEFAULT_EXTENSION_HOST_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001"; +export const DEFAULT_EXTENSION_HOST_VOYAGE_EMBEDDING_MODEL = "voyage-4-large"; +export const DEFAULT_EXTENSION_HOST_MISTRAL_EMBEDDING_MODEL = "mistral-embed"; +export const DEFAULT_EXTENSION_HOST_OLLAMA_EMBEDDING_MODEL = "nomic-embed-text"; export const EXTENSION_HOST_REMOTE_EMBEDDING_PROVIDER_IDS = [ "openai", @@ -21,3 +26,22 @@ export function isExtensionHostEmbeddingRuntimeBackendAutoSelectable( ): boolean { return backendId === "local" || EXTENSION_HOST_REMOTE_EMBEDDING_PROVIDER_IDS.includes(backendId); } + +export function resolveExtensionHostEmbeddingRuntimeDefaultModel( + backendId: EmbeddingProviderId, +): string { + switch (backendId) { + case "openai": + return DEFAULT_EXTENSION_HOST_OPENAI_EMBEDDING_MODEL; + case "gemini": + return DEFAULT_EXTENSION_HOST_GEMINI_EMBEDDING_MODEL; + case "voyage": + return DEFAULT_EXTENSION_HOST_VOYAGE_EMBEDDING_MODEL; + case "mistral": + return DEFAULT_EXTENSION_HOST_MISTRAL_EMBEDDING_MODEL; + case "ollama": + return DEFAULT_EXTENSION_HOST_OLLAMA_EMBEDDING_MODEL; + case "local": + return DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL; + } +} diff --git a/src/extension-host/embedding-runtime-policy.test.ts b/src/extension-host/embedding-runtime-policy.test.ts new file mode 100644 index 00000000000..7d79d8e6102 --- /dev/null +++ b/src/extension-host/embedding-runtime-policy.test.ts @@ -0,0 +1,44 @@ +import { describe, expect, it } from "vitest"; +import { + resolveExtensionHostEmbeddingFallbackModel, + resolveExtensionHostEmbeddingFallbackPolicy, +} from "./embedding-runtime-policy.js"; + +describe("embedding-runtime-policy", () => { + it("returns null when fallback is disabled or would repeat the requested provider", () => { + expect( + resolveExtensionHostEmbeddingFallbackPolicy({ + requestedProvider: "openai", + fallback: "none", + configuredModel: "configured-local-model", + }), + ).toBeNull(); + + expect( + resolveExtensionHostEmbeddingFallbackPolicy({ + requestedProvider: "openai", + fallback: "openai", + configuredModel: "configured-local-model", + }), + ).toBeNull(); + }); + + it("resolves host-owned fallback requests with provider-specific models", () => { + expect( + resolveExtensionHostEmbeddingFallbackPolicy({ + requestedProvider: "openai", + fallback: "gemini", + configuredModel: "configured-local-model", + }), + ).toEqual({ + provider: "gemini", + model: "gemini-embedding-001", + }); + }); + + it("keeps the configured model only for local fallback", () => { + expect(resolveExtensionHostEmbeddingFallbackModel("local", "configured-local-model")).toBe( + "configured-local-model", + ); + }); +}); diff --git a/src/extension-host/embedding-runtime-policy.ts b/src/extension-host/embedding-runtime-policy.ts new file mode 100644 index 00000000000..a6d46d9341b --- /dev/null +++ b/src/extension-host/embedding-runtime-policy.ts @@ -0,0 +1,33 @@ +import { resolveExtensionHostEmbeddingRuntimeDefaultModel } from "./embedding-runtime-backends.js"; +import type { + EmbeddingProviderFallback, + EmbeddingProviderId, + EmbeddingProviderRequest, +} from "./embedding-runtime-types.js"; + +export function resolveExtensionHostEmbeddingFallbackPolicy(params: { + requestedProvider: EmbeddingProviderRequest | EmbeddingProviderId; + fallback: EmbeddingProviderFallback | undefined; + configuredModel: string; +}): { + provider: EmbeddingProviderId; + model: string; +} | null { + const fallback = params.fallback; + if (!fallback || fallback === "none" || fallback === params.requestedProvider) { + return null; + } + return { + provider: fallback, + model: resolveExtensionHostEmbeddingFallbackModel(fallback, params.configuredModel), + }; +} + +export function resolveExtensionHostEmbeddingFallbackModel( + fallback: Exclude, + configuredModel: string, +): string { + return fallback === "local" + ? configuredModel + : resolveExtensionHostEmbeddingRuntimeDefaultModel(fallback); +} diff --git a/src/extension-host/embedding-runtime-registry.test.ts b/src/extension-host/embedding-runtime-registry.test.ts index 9a6fcf4cc62..aae75dace70 100644 --- a/src/extension-host/embedding-runtime-registry.test.ts +++ b/src/extension-host/embedding-runtime-registry.test.ts @@ -43,7 +43,7 @@ describe("extension host embedding runtime registry", () => { createGeminiEmbeddingProvider.mockResolvedValue({ provider: { id: "gemini", - model: "text-embedding-004", + model: "gemini-embedding-001", embedQuery: vi.fn(), embedBatch: vi.fn(), }, @@ -55,7 +55,7 @@ describe("extension host embedding runtime registry", () => { const result = await createExtensionHostEmbeddingProvider({ config: {} as never, provider: "auto", - model: "text-embedding-004", + model: "gemini-embedding-001", fallback: "none", }); @@ -77,4 +77,36 @@ describe("extension host embedding runtime registry", () => { expect(message).toContain('agents.defaults.memorySearch.provider = "gemini"'); expect(message).toContain('agents.defaults.memorySearch.provider = "openai"'); }); + + it("uses the shared fallback policy for explicit provider fallback requests", async () => { + createOpenAiEmbeddingProvider.mockRejectedValueOnce(new Error("openai failed")); + createGeminiEmbeddingProvider.mockResolvedValueOnce({ + provider: { + id: "gemini", + model: "gemini-embedding-001", + embedQuery: vi.fn(), + embedBatch: vi.fn(), + }, + client: { kind: "gemini" }, + }); + + const { createExtensionHostEmbeddingProvider } = + await import("./embedding-runtime-registry.js"); + const result = await createExtensionHostEmbeddingProvider({ + config: {} as never, + provider: "openai", + model: "text-embedding-3-small", + fallback: "gemini", + }); + + expect(createGeminiEmbeddingProvider).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "gemini", + model: "gemini-embedding-001", + fallback: "none", + }), + ); + expect(result.fallbackFrom).toBe("openai"); + expect(result.provider?.id).toBe("gemini"); + }); }); diff --git a/src/extension-host/embedding-runtime-registry.ts b/src/extension-host/embedding-runtime-registry.ts index 0b7ffa5735c..096f5989bb9 100644 --- a/src/extension-host/embedding-runtime-registry.ts +++ b/src/extension-host/embedding-runtime-registry.ts @@ -26,6 +26,7 @@ import { import { importNodeLlamaCpp } from "../memory/node-llama.js"; import { resolveUserPath } from "../utils.js"; import { DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL } from "./embedding-runtime-backends.js"; +import { resolveExtensionHostEmbeddingFallbackPolicy } from "./embedding-runtime-policy.js"; import type { EmbeddingProvider, EmbeddingProviderId, @@ -220,9 +221,22 @@ export async function createExtensionHostEmbeddingProvider( return { ...primary, requestedProvider }; } catch (primaryErr) { const reason = formatExtensionHostPrimaryEmbeddingError(primaryErr, requestedProvider); - if (fallback && fallback !== "none" && fallback !== requestedProvider) { + const fallbackPolicy = resolveExtensionHostEmbeddingFallbackPolicy({ + requestedProvider, + fallback, + configuredModel: options.model, + }); + if (fallbackPolicy) { try { - const fallbackResult = await createExtensionHostEmbeddingProviderById(fallback, options); + const fallbackResult = await createExtensionHostEmbeddingProviderById( + fallbackPolicy.provider, + { + ...options, + provider: fallbackPolicy.provider, + model: fallbackPolicy.model, + fallback: "none", + }, + ); return { ...fallbackResult, requestedProvider, @@ -231,7 +245,7 @@ export async function createExtensionHostEmbeddingProvider( }; } catch (fallbackErr) { const fallbackReason = formatErrorMessage(fallbackErr); - const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`; + const combinedReason = `${reason}\n\nFallback to ${fallbackPolicy.provider} failed: ${fallbackReason}`; if ( isMissingExtensionHostEmbeddingApiKeyError(primaryErr) && isMissingExtensionHostEmbeddingApiKeyError(fallbackErr) diff --git a/src/extension-host/runtime-backend-catalog.test.ts b/src/extension-host/runtime-backend-catalog.test.ts index ce45d62291a..82316524a10 100644 --- a/src/extension-host/runtime-backend-catalog.test.ts +++ b/src/extension-host/runtime-backend-catalog.test.ts @@ -13,6 +13,9 @@ vi.mock("./embedding-runtime-backends.js", () => ({ isExtensionHostEmbeddingRuntimeBackendAutoSelectable: vi.fn( (backendId: string) => backendId !== "ollama", ), + resolveExtensionHostEmbeddingRuntimeDefaultModel: vi.fn((backendId: string) => + backendId === "local" ? "local-model.gguf" : `${backendId}-default-model`, + ), })); vi.mock("./media-runtime-backends.js", () => ({ @@ -74,8 +77,14 @@ describe("runtime-backend-catalog", () => { ).toBe(true); expect(entries.every((entry) => entry.subsystemId === "embedding")).toBe(true); expect(entries[0]?.capabilities).toContain("embed.query"); - expect(entries[0]?.metadata).toMatchObject({ autoSelectable: true }); - expect(entries.at(-1)?.metadata).toMatchObject({ autoSelectable: false }); + expect(entries[0]?.metadata).toMatchObject({ + autoSelectable: true, + defaultModel: "local-model.gguf", + }); + expect(entries.at(-1)?.metadata).toMatchObject({ + autoSelectable: false, + defaultModel: "ollama-default-model", + }); }); it("splits media providers into subsystem-specific runtime-backend catalog entries", async () => { diff --git a/src/extension-host/runtime-backend-catalog.ts b/src/extension-host/runtime-backend-catalog.ts index 4416cd0500c..488cd160bf3 100644 --- a/src/extension-host/runtime-backend-catalog.ts +++ b/src/extension-host/runtime-backend-catalog.ts @@ -1,6 +1,7 @@ import type { TtsProvider } from "../config/types.tts.js"; import type { MediaUnderstandingCapability } from "../media-understanding/types.js"; import { + resolveExtensionHostEmbeddingRuntimeDefaultModel, EXTENSION_HOST_EMBEDDING_RUNTIME_BACKEND_IDS, isExtensionHostEmbeddingRuntimeBackendAutoSelectable, } from "./embedding-runtime-backends.js"; @@ -77,6 +78,7 @@ export function listExtensionHostEmbeddingRuntimeBackendCatalogEntries(): readon capabilities: ["embed.query", "embed.batch"], metadata: { autoSelectable: isExtensionHostEmbeddingRuntimeBackendAutoSelectable(backendId), + defaultModel: resolveExtensionHostEmbeddingRuntimeDefaultModel(backendId), }, })); }