// openclaw/extensions/ollama/src/setup.ts

import type { OpenClawConfig } from "openclaw/plugin-sdk/provider-auth";
import { upsertAuthProfileWithLock } from "openclaw/plugin-sdk/provider-auth";
import { applyAgentDefaultModelPrimary } from "openclaw/plugin-sdk/provider-onboard";
import type { RuntimeEnv } from "openclaw/plugin-sdk/runtime";
import { WizardCancelledError, type WizardPrompter } from "openclaw/plugin-sdk/setup";
import { fetchWithSsrFGuard } from "openclaw/plugin-sdk/ssrf-runtime";
import { OLLAMA_DEFAULT_BASE_URL, OLLAMA_DEFAULT_MODEL } from "./defaults.js";
import {
buildOllamaBaseUrlSsrFPolicy,
buildOllamaModelDefinition,
enrichOllamaModelsWithContext,
fetchOllamaModels,
resolveOllamaApiBase,
type OllamaModelWithContext,
} from "./provider-models.js";
// Models surfaced first in the picker when only local models are usable.
const OLLAMA_SUGGESTED_MODELS_LOCAL = [OLLAMA_DEFAULT_MODEL];
// Models surfaced first once Ollama Cloud authentication has been verified.
const OLLAMA_SUGGESTED_MODELS_CLOUD = ["kimi-k2.5:cloud", "minimax-m2.5:cloud", "glm-5:cloud"];
// Upper bound on how many discovered models get a context-window lookup.
const OLLAMA_CONTEXT_ENRICH_LIMIT = 200;

// "remote" = cloud + local models; "local" = local models only.
type OllamaMode = "remote" | "local";

// Flags accepted by the non-interactive setup path.
type OllamaSetupOptions = {
  customBaseUrl?: string;
  customModelId?: string;
};

// Result of probing the daemon's `/api/me` endpoint for cloud auth.
type OllamaCloudAuthResult = {
  signedIn: boolean;
  // Present when the daemon responds 401 with a sign-in URL.
  signinUrl?: string;
};

// Shape of the provider entry returned by buildOllamaProvider.
type ProviderConfig = {
  baseUrl: string;
  api: "ollama";
  models: ReturnType<typeof buildOllamaModelDefinition>[];
};
/**
 * Normalizes a user-supplied model name: trims whitespace and strips a
 * leading "ollama/" provider prefix (case-insensitive on the prefix).
 * Returns undefined when nothing usable remains.
 */
function normalizeOllamaModelName(value: string | undefined): string | undefined {
  const name = value?.trim() ?? "";
  if (name === "") {
    return undefined;
  }
  const prefix = "ollama/";
  if (!name.toLowerCase().startsWith(prefix)) {
    return name;
  }
  const stripped = name.slice(prefix.length).trim();
  return stripped === "" ? undefined : stripped;
}
/** True when the model name carries Ollama's ":cloud" suffix (case-insensitive). */
function isOllamaCloudModel(modelName: string | undefined): boolean {
  if (modelName === undefined) {
    return false;
  }
  return modelName.trim().toLowerCase().endsWith(":cloud");
}
/**
 * Turns a raw pull-progress status line into display text.
 * Per-layer lines like "pulling sha256:abcd…" are shortened to "pulling part";
 * "verifying … digest" hides the percentage (the daemon reports none for it).
 */
function formatOllamaPullStatus(status: string): { text: string; hidePercent: boolean } {
  const trimmed = status.trim();
  const partPattern = /^([a-z-]+)\s+(?:sha256:)?[a-f0-9]{8,}$/i;
  const partMatch = partPattern.exec(trimmed);
  if (partMatch !== null) {
    return { text: `${partMatch[1]} part`, hidePercent: false };
  }
  const digestPattern = /^verifying\b.*\bdigest\b/i;
  if (digestPattern.test(trimmed)) {
    return { text: "verifying digest", hidePercent: true };
  }
  return { text: trimmed, hidePercent: false };
}
/**
 * Probes `POST /api/me` (5s timeout) to see whether the daemon is signed in
 * to Ollama Cloud. A 401 carries a sign-in URL; any other failure — network
 * error, SSRF rejection, non-OK status — is reported as simply not signed in.
 */
async function checkOllamaCloudAuth(baseUrl: string): Promise<OllamaCloudAuthResult> {
  let result: OllamaCloudAuthResult;
  try {
    const apiBase = resolveOllamaApiBase(baseUrl);
    const { response, release } = await fetchWithSsrFGuard({
      url: `${apiBase}/api/me`,
      init: {
        method: "POST",
        signal: AbortSignal.timeout(5000),
      },
      policy: buildOllamaBaseUrlSsrFPolicy(apiBase),
      auditContext: "ollama-setup.me",
    });
    try {
      if (response.status === 401) {
        const body = (await response.json()) as { signin_url?: string };
        result = { signedIn: false, signinUrl: body.signin_url };
      } else {
        result = { signedIn: response.ok };
      }
    } finally {
      await release();
    }
  } catch {
    result = { signedIn: false };
  }
  return result;
}
// One NDJSON chunk emitted by Ollama's streaming `/api/pull` endpoint.
type OllamaPullChunk = {
  status?: string;
  total?: number; // total bytes for the layer being transferred
  completed?: number; // bytes transferred so far for that layer
  error?: string; // fatal error reported by the daemon
};

// Outcome of a pull attempt; `message` is a user-displayable failure reason.
type OllamaPullResult = { ok: true } | { ok: false; message: string };
/**
 * Downloads a model by streaming `POST /api/pull` from the Ollama daemon.
 *
 * The endpoint responds with newline-delimited JSON chunks carrying per-layer
 * progress; progress across all layers seen so far is aggregated into a single
 * percentage reported through `onStatus`. Malformed stream lines are ignored.
 *
 * Fix vs. previous revision: the streaming TextDecoder is flushed once the
 * reader reports `done`, so bytes it was still buffering (e.g. a multi-byte
 * UTF-8 sequence split across the final chunks) are no longer dropped from
 * the trailing line.
 *
 * @param params.baseUrl  Base URL of the daemon (normalized here).
 * @param params.modelName Model to pull; an "ollama/" prefix is stripped.
 * @param params.onStatus  Optional progress callback (status text, percent or null).
 * @returns `{ ok: true }` on success, otherwise `{ ok: false, message }`.
 */
async function pullOllamaModelCore(params: {
  baseUrl: string;
  modelName: string;
  onStatus?: (status: string, percent: number | null) => void;
}): Promise<OllamaPullResult> {
  const baseUrl = resolveOllamaApiBase(params.baseUrl);
  const modelName = normalizeOllamaModelName(params.modelName) ?? params.modelName.trim();
  try {
    const { response, release } = await fetchWithSsrFGuard({
      url: `${baseUrl}/api/pull`,
      init: {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ name: modelName }),
      },
      policy: buildOllamaBaseUrlSsrFPolicy(baseUrl),
      auditContext: "ollama-setup.pull",
    });
    try {
      if (!response.ok) {
        return { ok: false, message: `Failed to download ${modelName} (HTTP ${response.status})` };
      }
      if (!response.body) {
        return { ok: false, message: `Failed to download ${modelName} (no response body)` };
      }
      const reader = response.body.getReader();
      const decoder = new TextDecoder();
      let buffer = "";
      // Per-layer progress keyed by the status line; sums feed the overall percent.
      const layers = new Map<string, { total: number; completed: number }>();
      const parseLine = (line: string): OllamaPullResult => {
        const trimmed = line.trim();
        if (!trimmed) {
          return { ok: true };
        }
        try {
          const chunk = JSON.parse(trimmed) as OllamaPullChunk;
          if (chunk.error) {
            return { ok: false, message: `Download failed: ${chunk.error}` };
          }
          if (!chunk.status) {
            return { ok: true };
          }
          if (chunk.total && chunk.completed !== undefined) {
            layers.set(chunk.status, { total: chunk.total, completed: chunk.completed });
            let totalSum = 0;
            let completedSum = 0;
            for (const layer of layers.values()) {
              totalSum += layer.total;
              completedSum += layer.completed;
            }
            params.onStatus?.(
              chunk.status,
              totalSum > 0 ? Math.round((completedSum / totalSum) * 100) : null,
            );
          } else {
            // Status without byte counts (e.g. "verifying digest"): no percent.
            params.onStatus?.(chunk.status, null);
          }
        } catch {
          // Ignore malformed streaming lines from Ollama.
        }
        return { ok: true };
      };
      for (;;) {
        const { done, value } = await reader.read();
        if (done) {
          break;
        }
        buffer += decoder.decode(value, { stream: true });
        const lines = buffer.split("\n");
        // Last element may be an incomplete line; keep it for the next chunk.
        buffer = lines.pop() ?? "";
        for (const line of lines) {
          const parsed = parseLine(line);
          if (!parsed.ok) {
            return parsed;
          }
        }
      }
      // Flush any bytes the streaming decoder is still holding so the final
      // line is complete before parsing it.
      buffer += decoder.decode();
      const trailing = buffer.trim();
      if (trailing) {
        const parsed = parseLine(trailing);
        if (!parsed.ok) {
          return parsed;
        }
      }
      return { ok: true };
    } finally {
      await release();
    }
  } catch (err) {
    const reason = err instanceof Error ? err.message : String(err);
    return { ok: false, message: `Failed to download ${modelName}: ${reason}` };
  }
}
/**
 * Interactive wrapper around pullOllamaModelCore: drives a wizard progress
 * spinner with formatted status/percent updates. Returns true on success;
 * on failure the spinner is stopped with the error message and false returned.
 */
async function pullOllamaModel(
  baseUrl: string,
  modelName: string,
  prompter: WizardPrompter,
): Promise<boolean> {
  const spinner = prompter.progress(`Downloading ${modelName}...`);
  const onStatus = (status: string, percent: number | null): void => {
    const formatted = formatOllamaPullStatus(status);
    const suffix = formatted.hidePercent ? "" : ` - ${percent ?? 0}%`;
    spinner.update(`Downloading ${modelName} - ${formatted.text}${suffix}`);
  };
  const result = await pullOllamaModelCore({ baseUrl, modelName, onStatus });
  if (result.ok) {
    spinner.stop(`Downloaded ${modelName}`);
    return true;
  }
  spinner.stop(result.message);
  return false;
}
/**
 * Non-interactive wrapper around pullOllamaModelCore: logs start/finish to
 * the runtime instead of driving a spinner. Returns true on success.
 */
async function pullOllamaModelNonInteractive(
  baseUrl: string,
  modelName: string,
  runtime: RuntimeEnv,
): Promise<boolean> {
  runtime.log(`Downloading ${modelName}...`);
  const outcome = await pullOllamaModelCore({ baseUrl, modelName });
  if (outcome.ok) {
    runtime.log(`Downloaded ${modelName}`);
    return true;
  }
  runtime.error(outcome.message);
  return false;
}
/**
 * Builds model definitions for the given names, attaching a discovered
 * context window when one is available in the lookup map.
 */
function buildOllamaModelsConfig(
  modelNames: string[],
  discoveredModelsByName?: Map<string, OllamaModelWithContext>,
) {
  const definitions: ReturnType<typeof buildOllamaModelDefinition>[] = [];
  for (const name of modelNames) {
    const contextWindow = discoveredModelsByName?.get(name)?.contextWindow;
    definitions.push(buildOllamaModelDefinition(name, contextWindow));
  }
  return definitions;
}
/**
 * Returns a copy of `cfg` with the "ollama" provider entry (re)registered
 * under models.providers. Existing providers are preserved; the models mode
 * defaults to "merge" when unset.
 */
function applyOllamaProviderConfig(
  cfg: OpenClawConfig,
  baseUrl: string,
  modelNames: string[],
  discoveredModelsByName?: Map<string, OllamaModelWithContext>,
): OpenClawConfig {
  const ollamaProvider = {
    baseUrl,
    api: "ollama" as const,
    // pragma: allowlist secret
    apiKey: "OLLAMA_API_KEY",
    models: buildOllamaModelsConfig(modelNames, discoveredModelsByName),
  };
  const models = {
    ...cfg.models,
    mode: cfg.models?.mode ?? "merge",
    providers: {
      ...cfg.models?.providers,
      ollama: ollamaProvider,
    },
  };
  return { ...cfg, models };
}
/**
 * Persists the placeholder local-Ollama API-key profile ("ollama:default")
 * into the agent's auth store under a file lock.
 */
async function storeOllamaCredential(agentDir?: string): Promise<void> {
  const credential = { type: "api_key", provider: "ollama", key: "ollama-local" } as const;
  await upsertAuthProfileWithLock({ profileId: "ollama:default", credential, agentDir });
}
/**
 * Discovers models from an Ollama daemon and returns a provider config entry.
 * Warns on stderr (unless `opts.quiet`) when the daemon is unreachable, in
 * which case the returned model list is whatever discovery yielded (empty).
 */
export async function buildOllamaProvider(
  configuredBaseUrl?: string,
  opts?: { quiet?: boolean },
): Promise<ProviderConfig> {
  const apiBase = resolveOllamaApiBase(configuredBaseUrl);
  const discovery = await fetchOllamaModels(apiBase);
  if (!discovery.reachable && !opts?.quiet) {
    console.warn(`Ollama could not be reached at ${apiBase}.`);
  }
  // Context-window lookup is bounded to keep setup fast on large installs.
  const limited = discovery.models.slice(0, OLLAMA_CONTEXT_ENRICH_LIMIT);
  const enriched = await enrichOllamaModelsWithContext(apiBase, limited);
  const models = enriched.map((model) =>
    buildOllamaModelDefinition(model.name, model.contextWindow),
  );
  return { baseUrl: apiBase, api: "ollama", models };
}
/**
 * Interactive wizard flow for configuring the Ollama provider.
 *
 * Prompts for the base URL, verifies the daemon is reachable, optionally
 * walks the user through Ollama Cloud sign-in, and returns an updated config
 * with the discovered models registered (suggested models listed first).
 *
 * @throws WizardCancelledError when the daemon is unreachable or the user
 *   declines a required sign-in step.
 */
export async function promptAndConfigureOllama(params: {
  cfg: OpenClawConfig;
  prompter: WizardPrompter;
  // True when setup runs on a remote host, where opening a local browser is pointless.
  isRemote: boolean;
  openUrl: (url: string) => Promise<void>;
}): Promise<{ config: OpenClawConfig }> {
  const baseUrlRaw = await params.prompter.text({
    message: "Ollama base URL",
    initialValue: OLLAMA_DEFAULT_BASE_URL,
    placeholder: OLLAMA_DEFAULT_BASE_URL,
    validate: (value) => (value?.trim() ? undefined : "Required"),
  });
  // Strip trailing slashes before resolving to the API base.
  const baseUrl = resolveOllamaApiBase(
    String(baseUrlRaw ?? "")
      .trim()
      .replace(/\/+$/, ""),
  );
  const { reachable, models } = await fetchOllamaModels(baseUrl);
  if (!reachable) {
    await params.prompter.note(
      [
        `Ollama could not be reached at ${baseUrl}.`,
        "Download it at https://ollama.com/download",
        "",
        "Start Ollama and re-run setup.",
      ].join("\n"),
      "Ollama",
    );
    throw new WizardCancelledError("Ollama not reachable");
  }
  // Context-window lookup for a bounded number of discovered models.
  const enrichedModels = await enrichOllamaModelsWithContext(
    baseUrl,
    models.slice(0, OLLAMA_CONTEXT_ENRICH_LIMIT),
  );
  const discoveredModelsByName = new Map(enrichedModels.map((model) => [model.name, model]));
  const modelNames = models.map((model) => model.name);
  const mode = (await params.prompter.select({
    message: "Ollama mode",
    options: [
      { value: "remote", label: "Cloud + Local", hint: "Ollama cloud models + local models" },
      { value: "local", label: "Local", hint: "Local models only" },
    ],
  })) as OllamaMode;
  let cloudAuthVerified = false;
  if (mode === "remote") {
    const authResult = await checkOllamaCloudAuth(baseUrl);
    if (!authResult.signedIn) {
      if (authResult.signinUrl) {
        // Open the browser locally when possible; always display the URL too.
        if (!params.isRemote) {
          await params.openUrl(authResult.signinUrl);
        }
        await params.prompter.note(
          ["Sign in to Ollama Cloud:", authResult.signinUrl].join("\n"),
          "Ollama Cloud",
        );
        const confirmed = await params.prompter.confirm({ message: "Have you signed in?" });
        if (!confirmed) {
          throw new WizardCancelledError("Ollama cloud sign-in cancelled");
        }
        // Re-probe to confirm the sign-in actually completed.
        if (!(await checkOllamaCloudAuth(baseUrl)).signedIn) {
          throw new WizardCancelledError("Ollama cloud sign-in required");
        }
        cloudAuthVerified = true;
      } else {
        // No sign-in URL available: let the user proceed at their own risk.
        await params.prompter.note(
          [
            "Could not verify Ollama Cloud authentication.",
            "Cloud models may not work until you sign in at https://ollama.com.",
          ].join("\n"),
          "Ollama Cloud",
        );
        if (!(await params.prompter.confirm({ message: "Continue without cloud auth?" }))) {
          throw new WizardCancelledError("Ollama cloud auth could not be verified");
        }
      }
    } else {
      cloudAuthVerified = true;
    }
  }
  // Cloud suggestions only when cloud auth was actually verified.
  const suggestedModels =
    mode === "local" || !cloudAuthVerified
      ? OLLAMA_SUGGESTED_MODELS_LOCAL
      : OLLAMA_SUGGESTED_MODELS_CLOUD;
  // Suggested models first, then the remaining discovered ones.
  const orderedModelNames = [
    ...suggestedModels,
    ...modelNames.filter((name) => !suggestedModels.includes(name)),
  ];
  return {
    config: applyOllamaProviderConfig(
      params.cfg,
      baseUrl,
      orderedModelNames,
      discoveredModelsByName,
    ),
  };
}
/**
 * Non-interactive (flag-driven) Ollama setup.
 *
 * Resolves the base URL, verifies the daemon is reachable, stores the
 * placeholder credential, pulls the requested model when it is missing
 * locally, and returns a config whose primary agent model points at the
 * chosen default. On unrecoverable failures it logs, calls
 * `runtime.exit(1)`, and returns the input config unchanged (in case exit
 * does not terminate the process).
 */
export async function configureOllamaNonInteractive(params: {
  nextConfig: OpenClawConfig;
  opts: OllamaSetupOptions;
  runtime: RuntimeEnv;
  agentDir?: string;
}): Promise<OpenClawConfig> {
  // Fall back to the default daemon URL; strip trailing slashes first.
  const baseUrl = resolveOllamaApiBase(
    (params.opts.customBaseUrl?.trim() || OLLAMA_DEFAULT_BASE_URL).replace(/\/+$/, ""),
  );
  const { reachable, models } = await fetchOllamaModels(baseUrl);
  const explicitModel = normalizeOllamaModelName(params.opts.customModelId);
  if (!reachable) {
    params.runtime.error(
      [
        `Ollama could not be reached at ${baseUrl}.`,
        "Download it at https://ollama.com/download",
      ].join("\n"),
    );
    params.runtime.exit(1);
    return params.nextConfig;
  }
  await storeOllamaCredential(params.agentDir);
  // Context-window lookup for a bounded number of discovered models.
  const enrichedModels = await enrichOllamaModelsWithContext(
    baseUrl,
    models.slice(0, OLLAMA_CONTEXT_ENRICH_LIMIT),
  );
  const discoveredModelsByName = new Map(enrichedModels.map((model) => [model.name, model]));
  const modelNames = models.map((model) => model.name);
  // Suggested local models first, then the rest of the discovered ones.
  const orderedModelNames = [
    ...OLLAMA_SUGGESTED_MODELS_LOCAL,
    ...modelNames.filter((name) => !OLLAMA_SUGGESTED_MODELS_LOCAL.includes(name)),
  ];
  const requestedDefaultModelId = explicitModel ?? OLLAMA_SUGGESTED_MODELS_LOCAL[0];
  const availableModelNames = new Set(modelNames);
  const requestedCloudModel = isOllamaCloudModel(requestedDefaultModelId);
  let pulledRequestedModel = false;
  if (requestedCloudModel) {
    // Cloud models are served remotely; treat as available without pulling.
    availableModelNames.add(requestedDefaultModelId);
  } else if (!modelNames.includes(requestedDefaultModelId)) {
    // Requested local model is not present: try to pull it (best effort).
    pulledRequestedModel = await pullOllamaModelNonInteractive(
      baseUrl,
      requestedDefaultModelId,
      params.runtime,
    );
    if (pulledRequestedModel) {
      availableModelNames.add(requestedDefaultModelId);
    }
  }
  let allModelNames = orderedModelNames;
  let defaultModelId = requestedDefaultModelId;
  // Register the requested model if it became available but was not among
  // the discovered/suggested names.
  if (
    (pulledRequestedModel || requestedCloudModel) &&
    !allModelNames.includes(requestedDefaultModelId)
  ) {
    allModelNames = [...allModelNames, requestedDefaultModelId];
  }
  if (!availableModelNames.has(requestedDefaultModelId)) {
    if (availableModelNames.size === 0) {
      params.runtime.error(
        [
          `No Ollama models are available at ${baseUrl}.`,
          "Pull a model first, then re-run setup.",
        ].join("\n"),
      );
      params.runtime.exit(1);
      return params.nextConfig;
    }
    // Fall back to the first available model, preferring registered order.
    defaultModelId =
      allModelNames.find((name) => availableModelNames.has(name)) ??
      Array.from(availableModelNames)[0]!;
    params.runtime.log(
      `Ollama model ${requestedDefaultModelId} was not available; using ${defaultModelId} instead.`,
    );
  }
  const config = applyOllamaProviderConfig(
    params.nextConfig,
    baseUrl,
    allModelNames,
    discoveredModelsByName,
  );
  params.runtime.log(`Default Ollama model: ${defaultModelId}`);
  return applyAgentDefaultModelPrimary(config, `ollama/${defaultModelId}`);
}
/**
 * Ensures a locally-served Ollama model referenced as "ollama/<name>" is
 * present on the daemon, pulling it interactively when missing. Non-Ollama
 * models and cloud-hosted models (":cloud" suffix) are skipped.
 *
 * @throws WizardCancelledError when the download fails.
 */
export async function ensureOllamaModelPulled(params: {
  config: OpenClawConfig;
  model: string;
  prompter: WizardPrompter;
}): Promise<void> {
  const prefix = "ollama/";
  if (!params.model.startsWith(prefix)) {
    return;
  }
  const modelName = params.model.slice(prefix.length);
  if (isOllamaCloudModel(modelName)) {
    return;
  }
  const baseUrl = params.config.models?.providers?.ollama?.baseUrl ?? OLLAMA_DEFAULT_BASE_URL;
  const { models } = await fetchOllamaModels(baseUrl);
  const alreadyPresent = models.some((model) => model.name === modelName);
  if (alreadyPresent) {
    return;
  }
  const pulled = await pullOllamaModel(baseUrl, modelName, params.prompter);
  if (!pulled) {
    throw new WizardCancelledError("Failed to download selected Ollama model");
  }
}