mirror of https://github.com/openclaw/openclaw.git
208 lines
5.8 KiB
TypeScript
208 lines
5.8 KiB
TypeScript
import type { OpenClawConfig } from "openclaw/plugin-sdk/provider-auth";
|
|
import { normalizeOptionalSecretInput } from "openclaw/plugin-sdk/provider-auth";
|
|
import { resolveEnvApiKey } from "openclaw/plugin-sdk/provider-auth-runtime";
|
|
import {
|
|
hasConfiguredSecretInput,
|
|
normalizeResolvedSecretInputString,
|
|
} from "openclaw/plugin-sdk/secret-input";
|
|
import {
|
|
fetchWithSsrFGuard,
|
|
formatErrorMessage,
|
|
type SsrFPolicy,
|
|
} from "openclaw/plugin-sdk/ssrf-runtime";
|
|
import { resolveOllamaApiBase } from "./provider-models.js";
|
|
|
|
export type OllamaEmbeddingProvider = {
|
|
id: string;
|
|
model: string;
|
|
maxInputTokens?: number;
|
|
embedQuery: (text: string) => Promise<number[]>;
|
|
embedBatch: (texts: string[]) => Promise<number[][]>;
|
|
};
|
|
|
|
type OllamaEmbeddingOptions = {
|
|
config: OpenClawConfig;
|
|
agentDir?: string;
|
|
provider?: string;
|
|
remote?: {
|
|
baseUrl?: string;
|
|
apiKey?: unknown;
|
|
headers?: Record<string, string>;
|
|
};
|
|
model: string;
|
|
fallback?: string;
|
|
local?: unknown;
|
|
outputDimensionality?: number;
|
|
taskType?: unknown;
|
|
};
|
|
|
|
export type OllamaEmbeddingClient = {
|
|
baseUrl: string;
|
|
headers: Record<string, string>;
|
|
ssrfPolicy?: SsrFPolicy;
|
|
model: string;
|
|
embedBatch: (texts: string[]) => Promise<number[][]>;
|
|
};
|
|
|
|
type OllamaEmbeddingClientConfig = Omit<OllamaEmbeddingClient, "embedBatch">;
|
|
|
|
export const DEFAULT_OLLAMA_EMBEDDING_MODEL = "nomic-embed-text";
|
|
|
|
function sanitizeAndNormalizeEmbedding(vec: number[]): number[] {
|
|
const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0));
|
|
const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0));
|
|
if (magnitude < 1e-10) {
|
|
return sanitized;
|
|
}
|
|
return sanitized.map((value) => value / magnitude);
|
|
}
|
|
|
|
function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined {
|
|
const trimmed = baseUrl.trim();
|
|
if (!trimmed) {
|
|
return undefined;
|
|
}
|
|
try {
|
|
const parsed = new URL(trimmed);
|
|
if (parsed.protocol !== "http:" && parsed.protocol !== "https:") {
|
|
return undefined;
|
|
}
|
|
return { allowedHostnames: [parsed.hostname] };
|
|
} catch {
|
|
return undefined;
|
|
}
|
|
}
|
|
|
|
async function withRemoteHttpResponse<T>(params: {
|
|
url: string;
|
|
init?: RequestInit;
|
|
ssrfPolicy?: SsrFPolicy;
|
|
onResponse: (response: Response) => Promise<T>;
|
|
}): Promise<T> {
|
|
const { response, release } = await fetchWithSsrFGuard({
|
|
url: params.url,
|
|
init: params.init,
|
|
policy: params.ssrfPolicy,
|
|
auditContext: "memory-remote",
|
|
});
|
|
try {
|
|
return await params.onResponse(response);
|
|
} finally {
|
|
await release();
|
|
}
|
|
}
|
|
|
|
function normalizeEmbeddingModel(model: string): string {
|
|
const trimmed = model.trim();
|
|
if (!trimmed) {
|
|
return DEFAULT_OLLAMA_EMBEDDING_MODEL;
|
|
}
|
|
return trimmed.startsWith("ollama/") ? trimmed.slice("ollama/".length) : trimmed;
|
|
}
|
|
|
|
function resolveMemorySecretInputString(params: {
|
|
value: unknown;
|
|
path: string;
|
|
}): string | undefined {
|
|
if (!hasConfiguredSecretInput(params.value)) {
|
|
return undefined;
|
|
}
|
|
return normalizeResolvedSecretInputString({
|
|
value: params.value,
|
|
path: params.path,
|
|
});
|
|
}
|
|
|
|
function resolveOllamaApiKey(options: OllamaEmbeddingOptions): string | undefined {
|
|
const remoteApiKey = resolveMemorySecretInputString({
|
|
value: options.remote?.apiKey,
|
|
path: "agents.*.memorySearch.remote.apiKey",
|
|
});
|
|
if (remoteApiKey) {
|
|
return remoteApiKey;
|
|
}
|
|
const providerApiKey = normalizeOptionalSecretInput(
|
|
options.config.models?.providers?.ollama?.apiKey,
|
|
);
|
|
if (providerApiKey) {
|
|
return providerApiKey;
|
|
}
|
|
return resolveEnvApiKey("ollama")?.apiKey;
|
|
}
|
|
|
|
function resolveOllamaEmbeddingClient(
|
|
options: OllamaEmbeddingOptions,
|
|
): OllamaEmbeddingClientConfig {
|
|
const providerConfig = options.config.models?.providers?.ollama;
|
|
const rawBaseUrl = options.remote?.baseUrl?.trim() || providerConfig?.baseUrl?.trim();
|
|
const baseUrl = resolveOllamaApiBase(rawBaseUrl);
|
|
const model = normalizeEmbeddingModel(options.model);
|
|
const headerOverrides = Object.assign({}, providerConfig?.headers, options.remote?.headers);
|
|
const headers: Record<string, string> = {
|
|
"Content-Type": "application/json",
|
|
...headerOverrides,
|
|
};
|
|
const apiKey = resolveOllamaApiKey(options);
|
|
if (apiKey) {
|
|
headers.Authorization = `Bearer ${apiKey}`;
|
|
}
|
|
return {
|
|
baseUrl,
|
|
headers,
|
|
ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl),
|
|
model,
|
|
};
|
|
}
|
|
|
|
export async function createOllamaEmbeddingProvider(
|
|
options: OllamaEmbeddingOptions,
|
|
): Promise<{ provider: OllamaEmbeddingProvider; client: OllamaEmbeddingClient }> {
|
|
const client = resolveOllamaEmbeddingClient(options);
|
|
const embedUrl = `${client.baseUrl.replace(/\/$/, "")}/api/embeddings`;
|
|
|
|
const embedOne = async (text: string): Promise<number[]> => {
|
|
const json = await withRemoteHttpResponse({
|
|
url: embedUrl,
|
|
ssrfPolicy: client.ssrfPolicy,
|
|
init: {
|
|
method: "POST",
|
|
headers: client.headers,
|
|
body: JSON.stringify({ model: client.model, prompt: text }),
|
|
},
|
|
onResponse: async (response) => {
|
|
if (!response.ok) {
|
|
throw new Error(`Ollama embeddings HTTP ${response.status}: ${await response.text()}`);
|
|
}
|
|
return (await response.json()) as { embedding?: number[] };
|
|
},
|
|
});
|
|
if (!Array.isArray(json.embedding)) {
|
|
throw new Error("Ollama embeddings response missing embedding[]");
|
|
}
|
|
return sanitizeAndNormalizeEmbedding(json.embedding);
|
|
};
|
|
|
|
const provider: OllamaEmbeddingProvider = {
|
|
id: "ollama",
|
|
model: client.model,
|
|
embedQuery: embedOne,
|
|
embedBatch: async (texts) => {
|
|
return await Promise.all(texts.map(embedOne));
|
|
},
|
|
};
|
|
|
|
return {
|
|
provider,
|
|
client: {
|
|
...client,
|
|
embedBatch: async (texts) => {
|
|
try {
|
|
return await provider.embedBatch(texts);
|
|
} catch (err) {
|
|
throw new Error(formatErrorMessage(err), { cause: err });
|
|
}
|
|
},
|
|
},
|
|
};
|
|
}
|