mirror of https://github.com/openclaw/openclaw.git
Memory: add embedding fallback policy
This commit is contained in:
parent
3afa2508be
commit
fcb6dd911c
|
|
@ -0,0 +1,65 @@
|
|||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const createEmbeddingProvider = vi.hoisted(() => vi.fn());
|
||||
const resolveAgentDir = vi.hoisted(() => vi.fn(() => "/tmp/agent"));
|
||||
|
||||
vi.mock("./embedding-runtime.js", () => ({
|
||||
createEmbeddingProvider,
|
||||
}));
|
||||
|
||||
vi.mock("../agents/agent-scope.js", () => ({
|
||||
resolveAgentDir,
|
||||
}));
|
||||
|
||||
describe("embedding-manager-runtime", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("uses the shared fallback policy for manager fallback activation", async () => {
|
||||
createEmbeddingProvider.mockResolvedValue({
|
||||
provider: {
|
||||
id: "ollama",
|
||||
model: "nomic-embed-text",
|
||||
embedQuery: vi.fn(),
|
||||
embedBatch: vi.fn(),
|
||||
},
|
||||
ollama: { kind: "ollama" },
|
||||
});
|
||||
|
||||
const { activateEmbeddingManagerFallbackProvider } =
|
||||
await import("./embedding-manager-runtime.js");
|
||||
const result = await activateEmbeddingManagerFallbackProvider({
|
||||
cfg: {} as never,
|
||||
agentId: "main",
|
||||
settings: {
|
||||
fallback: "ollama",
|
||||
model: "text-embedding-3-small",
|
||||
outputDimensionality: undefined,
|
||||
remote: undefined,
|
||||
local: undefined,
|
||||
},
|
||||
state: {
|
||||
provider: {
|
||||
id: "openai",
|
||||
model: "text-embedding-3-small",
|
||||
embedQuery: vi.fn(),
|
||||
embedBatch: vi.fn(),
|
||||
},
|
||||
},
|
||||
reason: "forced fallback",
|
||||
});
|
||||
|
||||
expect(createEmbeddingProvider).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: "ollama",
|
||||
model: "nomic-embed-text",
|
||||
fallback: "none",
|
||||
}),
|
||||
);
|
||||
expect(result).toMatchObject({
|
||||
fallbackFrom: "openai",
|
||||
fallbackReason: "forced fallback",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,11 +1,7 @@
|
|||
import { resolveAgentDir } from "../agents/agent-scope.js";
|
||||
import type { ResolvedMemorySearchConfig } from "../agents/memory-search.js";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "../memory/embeddings-gemini.js";
|
||||
import { DEFAULT_MISTRAL_EMBEDDING_MODEL } from "../memory/embeddings-mistral.js";
|
||||
import { DEFAULT_OLLAMA_EMBEDDING_MODEL } from "../memory/embeddings-ollama.js";
|
||||
import { DEFAULT_OPENAI_EMBEDDING_MODEL } from "../memory/embeddings-openai.js";
|
||||
import { DEFAULT_VOYAGE_EMBEDDING_MODEL } from "../memory/embeddings-voyage.js";
|
||||
import { resolveExtensionHostEmbeddingFallbackPolicy } from "./embedding-runtime-policy.js";
|
||||
import {
|
||||
createEmbeddingProvider,
|
||||
type EmbeddingProvider,
|
||||
|
|
@ -72,18 +68,25 @@ export async function activateEmbeddingManagerFallbackProvider(params: {
|
|||
state: EmbeddingManagerRuntimeState;
|
||||
reason: string;
|
||||
}): Promise<EmbeddingManagerFallbackActivation | null> {
|
||||
const fallback = params.settings.fallback;
|
||||
const { provider, fallbackFrom } = params.state;
|
||||
if (!fallback || fallback === "none" || !provider || fallback === provider.id || fallbackFrom) {
|
||||
if (!provider || fallbackFrom) {
|
||||
return null;
|
||||
}
|
||||
const fallbackPolicy = resolveExtensionHostEmbeddingFallbackPolicy({
|
||||
requestedProvider: provider.id as EmbeddingProviderId,
|
||||
fallback: params.settings.fallback,
|
||||
configuredModel: params.settings.model,
|
||||
});
|
||||
if (!fallbackPolicy) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const result = await createEmbeddingProvider({
|
||||
config: params.cfg,
|
||||
agentDir: resolveAgentDir(params.cfg, params.agentId),
|
||||
provider: fallback,
|
||||
provider: fallbackPolicy.provider,
|
||||
remote: params.settings.remote,
|
||||
model: resolveEmbeddingFallbackModel(fallback, params.settings.model),
|
||||
model: fallbackPolicy.model,
|
||||
outputDimensionality: params.settings.outputDimensionality,
|
||||
fallback: "none",
|
||||
local: params.settings.local,
|
||||
|
|
@ -100,23 +103,3 @@ export async function activateEmbeddingManagerFallbackProvider(params: {
|
|||
ollama: result.ollama,
|
||||
};
|
||||
}
|
||||
|
||||
function resolveEmbeddingFallbackModel(
|
||||
fallback: Exclude<ResolvedMemorySearchConfig["fallback"], undefined | "none">,
|
||||
configuredModel: string,
|
||||
): string {
|
||||
switch (fallback) {
|
||||
case "gemini":
|
||||
return DEFAULT_GEMINI_EMBEDDING_MODEL;
|
||||
case "openai":
|
||||
return DEFAULT_OPENAI_EMBEDDING_MODEL;
|
||||
case "voyage":
|
||||
return DEFAULT_VOYAGE_EMBEDDING_MODEL;
|
||||
case "mistral":
|
||||
return DEFAULT_MISTRAL_EMBEDDING_MODEL;
|
||||
case "ollama":
|
||||
return DEFAULT_OLLAMA_EMBEDDING_MODEL;
|
||||
case "local":
|
||||
return configuredModel;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,11 @@ import type { EmbeddingProviderId } from "./embedding-runtime-types.js";
|
|||
|
||||
export const DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL =
|
||||
"hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf";
|
||||
export const DEFAULT_EXTENSION_HOST_OPENAI_EMBEDDING_MODEL = "text-embedding-3-small";
|
||||
export const DEFAULT_EXTENSION_HOST_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";
|
||||
export const DEFAULT_EXTENSION_HOST_VOYAGE_EMBEDDING_MODEL = "voyage-4-large";
|
||||
export const DEFAULT_EXTENSION_HOST_MISTRAL_EMBEDDING_MODEL = "mistral-embed";
|
||||
export const DEFAULT_EXTENSION_HOST_OLLAMA_EMBEDDING_MODEL = "nomic-embed-text";
|
||||
|
||||
export const EXTENSION_HOST_REMOTE_EMBEDDING_PROVIDER_IDS = [
|
||||
"openai",
|
||||
|
|
@ -21,3 +26,22 @@ export function isExtensionHostEmbeddingRuntimeBackendAutoSelectable(
|
|||
): boolean {
|
||||
return backendId === "local" || EXTENSION_HOST_REMOTE_EMBEDDING_PROVIDER_IDS.includes(backendId);
|
||||
}
|
||||
|
||||
export function resolveExtensionHostEmbeddingRuntimeDefaultModel(
|
||||
backendId: EmbeddingProviderId,
|
||||
): string {
|
||||
switch (backendId) {
|
||||
case "openai":
|
||||
return DEFAULT_EXTENSION_HOST_OPENAI_EMBEDDING_MODEL;
|
||||
case "gemini":
|
||||
return DEFAULT_EXTENSION_HOST_GEMINI_EMBEDDING_MODEL;
|
||||
case "voyage":
|
||||
return DEFAULT_EXTENSION_HOST_VOYAGE_EMBEDDING_MODEL;
|
||||
case "mistral":
|
||||
return DEFAULT_EXTENSION_HOST_MISTRAL_EMBEDDING_MODEL;
|
||||
case "ollama":
|
||||
return DEFAULT_EXTENSION_HOST_OLLAMA_EMBEDDING_MODEL;
|
||||
case "local":
|
||||
return DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
resolveExtensionHostEmbeddingFallbackModel,
|
||||
resolveExtensionHostEmbeddingFallbackPolicy,
|
||||
} from "./embedding-runtime-policy.js";
|
||||
|
||||
describe("embedding-runtime-policy", () => {
|
||||
it("returns null when fallback is disabled or would repeat the requested provider", () => {
|
||||
expect(
|
||||
resolveExtensionHostEmbeddingFallbackPolicy({
|
||||
requestedProvider: "openai",
|
||||
fallback: "none",
|
||||
configuredModel: "configured-local-model",
|
||||
}),
|
||||
).toBeNull();
|
||||
|
||||
expect(
|
||||
resolveExtensionHostEmbeddingFallbackPolicy({
|
||||
requestedProvider: "openai",
|
||||
fallback: "openai",
|
||||
configuredModel: "configured-local-model",
|
||||
}),
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
it("resolves host-owned fallback requests with provider-specific models", () => {
|
||||
expect(
|
||||
resolveExtensionHostEmbeddingFallbackPolicy({
|
||||
requestedProvider: "openai",
|
||||
fallback: "gemini",
|
||||
configuredModel: "configured-local-model",
|
||||
}),
|
||||
).toEqual({
|
||||
provider: "gemini",
|
||||
model: "gemini-embedding-001",
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps the configured model only for local fallback", () => {
|
||||
expect(resolveExtensionHostEmbeddingFallbackModel("local", "configured-local-model")).toBe(
|
||||
"configured-local-model",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
import { resolveExtensionHostEmbeddingRuntimeDefaultModel } from "./embedding-runtime-backends.js";
|
||||
import type {
|
||||
EmbeddingProviderFallback,
|
||||
EmbeddingProviderId,
|
||||
EmbeddingProviderRequest,
|
||||
} from "./embedding-runtime-types.js";
|
||||
|
||||
export function resolveExtensionHostEmbeddingFallbackPolicy(params: {
|
||||
requestedProvider: EmbeddingProviderRequest | EmbeddingProviderId;
|
||||
fallback: EmbeddingProviderFallback | undefined;
|
||||
configuredModel: string;
|
||||
}): {
|
||||
provider: EmbeddingProviderId;
|
||||
model: string;
|
||||
} | null {
|
||||
const fallback = params.fallback;
|
||||
if (!fallback || fallback === "none" || fallback === params.requestedProvider) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
provider: fallback,
|
||||
model: resolveExtensionHostEmbeddingFallbackModel(fallback, params.configuredModel),
|
||||
};
|
||||
}
|
||||
|
||||
export function resolveExtensionHostEmbeddingFallbackModel(
|
||||
fallback: Exclude<EmbeddingProviderFallback, "none">,
|
||||
configuredModel: string,
|
||||
): string {
|
||||
return fallback === "local"
|
||||
? configuredModel
|
||||
: resolveExtensionHostEmbeddingRuntimeDefaultModel(fallback);
|
||||
}
|
||||
|
|
@ -43,7 +43,7 @@ describe("extension host embedding runtime registry", () => {
|
|||
createGeminiEmbeddingProvider.mockResolvedValue({
|
||||
provider: {
|
||||
id: "gemini",
|
||||
model: "text-embedding-004",
|
||||
model: "gemini-embedding-001",
|
||||
embedQuery: vi.fn(),
|
||||
embedBatch: vi.fn(),
|
||||
},
|
||||
|
|
@ -55,7 +55,7 @@ describe("extension host embedding runtime registry", () => {
|
|||
const result = await createExtensionHostEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "auto",
|
||||
model: "text-embedding-004",
|
||||
model: "gemini-embedding-001",
|
||||
fallback: "none",
|
||||
});
|
||||
|
||||
|
|
@ -77,4 +77,36 @@ describe("extension host embedding runtime registry", () => {
|
|||
expect(message).toContain('agents.defaults.memorySearch.provider = "gemini"');
|
||||
expect(message).toContain('agents.defaults.memorySearch.provider = "openai"');
|
||||
});
|
||||
|
||||
it("uses the shared fallback policy for explicit provider fallback requests", async () => {
|
||||
createOpenAiEmbeddingProvider.mockRejectedValueOnce(new Error("openai failed"));
|
||||
createGeminiEmbeddingProvider.mockResolvedValueOnce({
|
||||
provider: {
|
||||
id: "gemini",
|
||||
model: "gemini-embedding-001",
|
||||
embedQuery: vi.fn(),
|
||||
embedBatch: vi.fn(),
|
||||
},
|
||||
client: { kind: "gemini" },
|
||||
});
|
||||
|
||||
const { createExtensionHostEmbeddingProvider } =
|
||||
await import("./embedding-runtime-registry.js");
|
||||
const result = await createExtensionHostEmbeddingProvider({
|
||||
config: {} as never,
|
||||
provider: "openai",
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "gemini",
|
||||
});
|
||||
|
||||
expect(createGeminiEmbeddingProvider).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
provider: "gemini",
|
||||
model: "gemini-embedding-001",
|
||||
fallback: "none",
|
||||
}),
|
||||
);
|
||||
expect(result.fallbackFrom).toBe("openai");
|
||||
expect(result.provider?.id).toBe("gemini");
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ import {
|
|||
import { importNodeLlamaCpp } from "../memory/node-llama.js";
|
||||
import { resolveUserPath } from "../utils.js";
|
||||
import { DEFAULT_EXTENSION_HOST_LOCAL_EMBEDDING_MODEL } from "./embedding-runtime-backends.js";
|
||||
import { resolveExtensionHostEmbeddingFallbackPolicy } from "./embedding-runtime-policy.js";
|
||||
import type {
|
||||
EmbeddingProvider,
|
||||
EmbeddingProviderId,
|
||||
|
|
@ -220,9 +221,22 @@ export async function createExtensionHostEmbeddingProvider(
|
|||
return { ...primary, requestedProvider };
|
||||
} catch (primaryErr) {
|
||||
const reason = formatExtensionHostPrimaryEmbeddingError(primaryErr, requestedProvider);
|
||||
if (fallback && fallback !== "none" && fallback !== requestedProvider) {
|
||||
const fallbackPolicy = resolveExtensionHostEmbeddingFallbackPolicy({
|
||||
requestedProvider,
|
||||
fallback,
|
||||
configuredModel: options.model,
|
||||
});
|
||||
if (fallbackPolicy) {
|
||||
try {
|
||||
const fallbackResult = await createExtensionHostEmbeddingProviderById(fallback, options);
|
||||
const fallbackResult = await createExtensionHostEmbeddingProviderById(
|
||||
fallbackPolicy.provider,
|
||||
{
|
||||
...options,
|
||||
provider: fallbackPolicy.provider,
|
||||
model: fallbackPolicy.model,
|
||||
fallback: "none",
|
||||
},
|
||||
);
|
||||
return {
|
||||
...fallbackResult,
|
||||
requestedProvider,
|
||||
|
|
@ -231,7 +245,7 @@ export async function createExtensionHostEmbeddingProvider(
|
|||
};
|
||||
} catch (fallbackErr) {
|
||||
const fallbackReason = formatErrorMessage(fallbackErr);
|
||||
const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`;
|
||||
const combinedReason = `${reason}\n\nFallback to ${fallbackPolicy.provider} failed: ${fallbackReason}`;
|
||||
if (
|
||||
isMissingExtensionHostEmbeddingApiKeyError(primaryErr) &&
|
||||
isMissingExtensionHostEmbeddingApiKeyError(fallbackErr)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ vi.mock("./embedding-runtime-backends.js", () => ({
|
|||
isExtensionHostEmbeddingRuntimeBackendAutoSelectable: vi.fn(
|
||||
(backendId: string) => backendId !== "ollama",
|
||||
),
|
||||
resolveExtensionHostEmbeddingRuntimeDefaultModel: vi.fn((backendId: string) =>
|
||||
backendId === "local" ? "local-model.gguf" : `${backendId}-default-model`,
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock("./media-runtime-backends.js", () => ({
|
||||
|
|
@ -74,8 +77,14 @@ describe("runtime-backend-catalog", () => {
|
|||
).toBe(true);
|
||||
expect(entries.every((entry) => entry.subsystemId === "embedding")).toBe(true);
|
||||
expect(entries[0]?.capabilities).toContain("embed.query");
|
||||
expect(entries[0]?.metadata).toMatchObject({ autoSelectable: true });
|
||||
expect(entries.at(-1)?.metadata).toMatchObject({ autoSelectable: false });
|
||||
expect(entries[0]?.metadata).toMatchObject({
|
||||
autoSelectable: true,
|
||||
defaultModel: "local-model.gguf",
|
||||
});
|
||||
expect(entries.at(-1)?.metadata).toMatchObject({
|
||||
autoSelectable: false,
|
||||
defaultModel: "ollama-default-model",
|
||||
});
|
||||
});
|
||||
|
||||
it("splits media providers into subsystem-specific runtime-backend catalog entries", async () => {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import type { TtsProvider } from "../config/types.tts.js";
|
||||
import type { MediaUnderstandingCapability } from "../media-understanding/types.js";
|
||||
import {
|
||||
resolveExtensionHostEmbeddingRuntimeDefaultModel,
|
||||
EXTENSION_HOST_EMBEDDING_RUNTIME_BACKEND_IDS,
|
||||
isExtensionHostEmbeddingRuntimeBackendAutoSelectable,
|
||||
} from "./embedding-runtime-backends.js";
|
||||
|
|
@ -77,6 +78,7 @@ export function listExtensionHostEmbeddingRuntimeBackendCatalogEntries(): readon
|
|||
capabilities: ["embed.query", "embed.batch"],
|
||||
metadata: {
|
||||
autoSelectable: isExtensionHostEmbeddingRuntimeBackendAutoSelectable(backendId),
|
||||
defaultModel: resolveExtensionHostEmbeddingRuntimeDefaultModel(backendId),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue