import { describe, expect, it } from "vitest";
import { convertToOllamaMessages, buildAssistantMessage } from "./ollama-stream.js";

// Unit tests for the agent-message -> Ollama `/api/chat` message conversion.
describe("convertToOllamaMessages", () => {
  it("converts user text messages", () => {
    const messages = [{ role: "user", content: "hello" }];
    const result = convertToOllamaMessages(messages);
    expect(result).toEqual([{ role: "user", content: "hello" }]);
  });

  it("converts user messages with content parts", () => {
    const messages = [
      {
        role: "user",
        content: [
          { type: "text", text: "describe this" },
          { type: "image", data: "base64data" },
        ],
      },
    ];
    const result = convertToOllamaMessages(messages);
    expect(result).toEqual([
      { role: "user", content: "describe this", images: ["base64data"] },
    ]);
  });

  it("prepends system message when provided", () => {
    const messages = [{ role: "user", content: "hello" }];
    const result = convertToOllamaMessages(messages, "You are helpful.");
    expect(result[0]).toEqual({ role: "system", content: "You are helpful." });
    expect(result[1]).toEqual({ role: "user", content: "hello" });
  });

  it("converts assistant messages with tool calls", () => {
    const messages = [
      {
        role: "assistant",
        content: [
          { type: "text", text: "Let me check." },
          { type: "tool_use", id: "call_1", name: "bash", input: { command: "ls" } },
        ],
      },
    ];
    const result = convertToOllamaMessages(messages);
    expect(result[0].role).toBe("assistant");
    expect(result[0].content).toBe("Let me check.");
    expect(result[0].tool_calls).toEqual([
      { function: { name: "bash", arguments: { command: "ls" } } },
    ]);
  });

  it("converts tool result messages", () => {
    const messages = [{ role: "tool", content: "file1.txt\nfile2.txt" }];
    const result = convertToOllamaMessages(messages);
    expect(result).toEqual([{ role: "tool", content: "file1.txt\nfile2.txt" }]);
  });

  it("handles empty messages array", () => {
    const result = convertToOllamaMessages([]);
    expect(result).toEqual([]);
  });
});

// Unit tests for turning a final Ollama chat response into an assistant message.
describe("buildAssistantMessage", () => {
  const modelInfo = { api: "ollama", provider: "ollama", id: "qwen3:32b" };

  it("builds text-only response", () => {
    const response = {
      model: "qwen3:32b",
      created_at: "2026-01-01T00:00:00Z",
      message: { role: "assistant" as const, content: "Hello!" },
      done: true,
      prompt_eval_count: 10,
      eval_count: 5,
    };
    const result = buildAssistantMessage(response, modelInfo);
    expect(result.role).toBe("assistant");
    expect(result.content).toEqual([{ type: "text", text: "Hello!" }]);
    expect(result.stopReason).toBe("stop");
    expect(result.usage.input).toBe(10);
    expect(result.usage.output).toBe(5);
    expect(result.usage.totalTokens).toBe(15);
  });

  it("builds response with tool calls", () => {
    const response = {
      model: "qwen3:32b",
      created_at: "2026-01-01T00:00:00Z",
      message: {
        role: "assistant" as const,
        content: "",
        tool_calls: [
          { function: { name: "bash", arguments: { command: "ls -la" } } },
        ],
      },
      done: true,
      prompt_eval_count: 20,
      eval_count: 10,
    };
    const result = buildAssistantMessage(response, modelInfo);
    expect(result.stopReason).toBe("end_turn");
    // buildAssistantMessage only emits a text part when `message.content` is
    // non-empty, so an empty string leaves the tool call as the sole entry.
    expect(result.content.length).toBe(1);
    expect(result.content[0].type).toBe("tool_use");
    const toolUse = result.content[0] as {
      type: "tool_use";
      name: string;
      input: Record<string, unknown>;
    };
    expect(toolUse.name).toBe("bash");
    expect(toolUse.input).toEqual({ command: "ls -la" });
  });

  it("sets all costs to zero for local models", () => {
    const response = {
      model: "qwen3:32b",
      created_at: "2026-01-01T00:00:00Z",
      message: { role: "assistant" as const, content: "ok" },
      done: true,
    };
    const result = buildAssistantMessage(response, modelInfo);
    expect(result.usage.cost).toEqual({
      input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0,
    });
  });
});
// ── Ollama /api/chat message and response types ─────────────────────────────

/** One entry of the `messages` array sent to Ollama's `/api/chat`. */
interface OllamaChatMessage {
  role: "system" | "user" | "assistant" | "tool";
  content: string;
  images?: string[];
  tool_calls?: OllamaToolCall[];
}

/** Function-style tool definition as carried in a chat request's `tools`. */
interface OllamaTool {
  type: "function";
  function: {
    name: string;
    description: string;
    parameters: Record<string, unknown>;
  };
}

/** A tool invocation as it appears on an Ollama assistant message. */
interface OllamaToolCall {
  function: {
    name: string;
    arguments: Record<string, unknown>;
  };
}

/** One NDJSON object from `/api/chat`; counters are read when building usage. */
interface OllamaChatResponse {
  model: string;
  created_at: string;
  message: {
    role: "assistant";
    content: string;
    tool_calls?: OllamaToolCall[];
  };
  done: boolean;
  done_reason?: string;
  total_duration?: number;
  load_duration?: number;
  prompt_eval_count?: number;
  prompt_eval_duration?: number;
  eval_count?: number;
  eval_duration?: number;
}

// ── Message conversion ──────────────────────────────────────────────────────

/** Loosely-typed agent message; `content` is narrowed at runtime. */
type AgentMessage = {
  role: string;
  content: unknown;
  [key: string]: unknown;
};

type ContentPart =
  | { type: "text"; text: string }
  | { type: "image"; data: string; mediaType?: string }
  | { type: "tool_use"; id: string; name: string; input: Record<string, unknown> }
  // NOTE(review): pi-ai appears to name tool invocations `toolCall` with an
  // `arguments` field; accepted as an alias of `tool_use` — confirm against
  // the @mariozechner/pi-ai message types.
  | { type: "toolCall"; id: string; name: string; arguments: Record<string, unknown> }
  | { type: "tool_result"; tool_use_id: string; content: string };

/**
 * Returns the object-valued entries of an array `content`, or `[]` when
 * `content` is not an array. Dropping `null`/primitive entries keeps the
 * `part.type` reads below from throwing on malformed input.
 */
function toContentParts(content: unknown): ContentPart[] {
  if (!Array.isArray(content)) {
    return [];
  }
  return content.filter(
    (part): part is ContentPart => typeof part === "object" && part !== null,
  );
}

/** Returns string content as-is, or the concatenated `text` parts of an array. */
function extractTextContent(content: unknown): string {
  if (typeof content === "string") {
    return content;
  }
  return toContentParts(content)
    .filter((part): part is Extract<ContentPart, { type: "text" }> => part.type === "text")
    .map((part) => part.text)
    .join("");
}

/** Collects the `data` payload of every `image` part. */
function extractImages(content: unknown): string[] {
  return toContentParts(content)
    .filter((part): part is Extract<ContentPart, { type: "image" }> => part.type === "image")
    .map((part) => part.data);
}

/** Maps `tool_use` (and `toolCall`) parts to Ollama's `tool_calls` shape. */
function extractToolCalls(content: unknown): OllamaToolCall[] {
  const calls: OllamaToolCall[] = [];
  for (const part of toContentParts(content)) {
    if (part.type === "tool_use") {
      calls.push({ function: { name: part.name, arguments: part.input } });
    } else if (part.type === "toolCall") {
      calls.push({ function: { name: part.name, arguments: part.arguments } });
    }
  }
  return calls;
}

/**
 * Converts agent messages into the message list for Ollama's `/api/chat`.
 *
 * @param messages Conversation history; `content` may be a string or an array of parts.
 * @param system Optional system prompt, emitted first when non-empty.
 * @returns Ollama messages in input order. Messages whose role is not
 *   `user`, `assistant`, `tool` or `toolResult` are skipped.
 */
export function convertToOllamaMessages(
  messages: AgentMessage[],
  system?: string,
): OllamaChatMessage[] {
  const result: OllamaChatMessage[] = [];

  if (system) {
    result.push({ role: "system", content: system });
  }

  for (const msg of messages) {
    switch (msg.role) {
      case "user": {
        const images = extractImages(msg.content);
        result.push({
          role: "user",
          content: extractTextContent(msg.content),
          ...(images.length > 0 ? { images } : {}),
        });
        break;
      }
      case "assistant": {
        const toolCalls = extractToolCalls(msg.content);
        result.push({
          role: "assistant",
          content: extractTextContent(msg.content),
          ...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {}),
        });
        break;
      }
      case "tool":
      // NOTE(review): pi-ai appears to label tool outputs `toolResult`; without
      // this alias such messages were silently dropped — confirm with callers.
      case "toolResult":
        result.push({ role: "tool", content: extractTextContent(msg.content) });
        break;
      default:
        // Unknown roles are intentionally skipped.
        break;
    }
  }

  return result;
}

// ── Response conversion ─────────────────────────────────────────────────────

/** Assistant message produced from a completed Ollama chat response. */
interface AssistantMessageLike {
  role: "assistant";
  content: ContentPart[];
  stopReason: string;
  api: string;
  provider: string;
  model: string;
  usage: {
    input: number;
    output: number;
    cacheRead: number;
    cacheWrite: number;
    totalTokens: number;
    cost: {
      input: number;
      output: number;
      cacheRead: number;
      cacheWrite: number;
      total: number;
    };
  };
  timestamp: number;
}

// Module-wide sequence number that keeps generated tool-call ids distinct.
let toolCallIdCounter = 0;

/**
 * Builds the assistant message for a completed Ollama response.
 *
 * Non-empty text becomes a `text` part and each tool call becomes a
 * `tool_use` part with a generated id. Token counts come from
 * `prompt_eval_count` / `eval_count` (0 when absent); all costs are zero.
 *
 * @param response Final chat response (text already accumulated by the caller).
 * @param modelInfo Identifiers copied onto the message.
 * @returns The message; `stopReason` is `"end_turn"` when tool calls are
 *   present, otherwise `"stop"`.
 */
export function buildAssistantMessage(
  response: OllamaChatResponse,
  modelInfo: { api: string; provider: string; id: string },
): AssistantMessageLike {
  const content: ContentPart[] = [];

  const text = response.message.content;
  if (text) {
    content.push({ type: "text", text });
  }

  const toolCalls = response.message.tool_calls ?? [];
  for (const tc of toolCalls) {
    toolCallIdCounter += 1;
    content.push({
      type: "tool_use",
      id: `ollama_call_${toolCallIdCounter}_${Date.now()}`,
      name: tc.function.name,
      // Guard against a call that arrives without an arguments object.
      input: tc.function.arguments ?? {},
    });
  }

  const inputTokens = response.prompt_eval_count ?? 0;
  const outputTokens = response.eval_count ?? 0;

  return {
    role: "assistant",
    content,
    // NOTE(review): confirm these literals match the stop reasons the
    // consuming agent loop expects for tool-calling turns.
    stopReason: toolCalls.length > 0 ? "end_turn" : "stop",
    api: modelInfo.api,
    provider: modelInfo.provider,
    model: modelInfo.id,
    usage: {
      input: inputTokens,
      output: outputTokens,
      cacheRead: 0,
      cacheWrite: 0,
      totalTokens: inputTokens + outputTokens,
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
    },
    timestamp: Date.now(),
  };
}
// ── NDJSON streaming parser ─────────────────────────────────────────────────

/**
 * Parses a single NDJSON line.
 *
 * @returns The parsed object, or `undefined` for blank lines, malformed JSON
 *   and JSON values that are not objects (e.g. `null`), which consumers
 *   cannot safely read `done` / `message` from.
 */
function parseNdjsonLine(line: string): OllamaChatResponse | undefined {
  const trimmed = line.trim();
  if (!trimmed) {
    return undefined;
  }
  try {
    const parsed: unknown = JSON.parse(trimmed);
    return typeof parsed === "object" && parsed !== null
      ? (parsed as OllamaChatResponse)
      : undefined;
  } catch {
    // Best-effort stream: malformed lines are skipped rather than fatal.
    return undefined;
  }
}

/**
 * Reads a byte stream of newline-delimited JSON and yields one object per line.
 *
 * Lines may be split across reads; the partial tail is buffered until its
 * newline arrives, and any unterminated final line is parsed at end of stream.
 *
 * @param reader Reader over the response body.
 */
export async function* parseNdjsonStream(
  reader: ReadableStreamDefaultReader<Uint8Array>,
): AsyncGenerator<OllamaChatResponse> {
  const decoder = new TextDecoder();
  let buffer = "";

  for (;;) {
    const { done, value } = await reader.read();
    if (done) {
      break;
    }
    // `stream: true` keeps an incomplete multi-byte sequence for the next read.
    buffer += decoder.decode(value, { stream: true });
    const lines = buffer.split("\n");
    buffer = lines.pop() ?? "";

    for (const line of lines) {
      const parsed = parseNdjsonLine(line);
      if (parsed) {
        yield parsed;
      }
    }
  }

  // Flush bytes the decoder is still holding before parsing the tail.
  buffer += decoder.decode();
  const last = parseNdjsonLine(buffer);
  if (last) {
    yield last;
  }
}

/**
 * Folds streamed chunks into one final response.
 *
 * Text and tool calls are accumulated from every chunk, including the ones
 * with `done: false`, and attached to a copy of the chunk that has `done`
 * set. Consumption stops at that chunk.
 *
 * @returns The merged final response, or `undefined` if no chunk had `done`.
 */
async function collectOllamaResponse(
  chunks: AsyncIterable<OllamaChatResponse>,
): Promise<OllamaChatResponse | undefined> {
  let text = "";
  const toolCalls: OllamaToolCall[] = [];

  for await (const chunk of chunks) {
    if (chunk.message?.content) {
      text += chunk.message.content;
    }
    if (chunk.message?.tool_calls?.length) {
      toolCalls.push(...chunk.message.tool_calls);
    }
    if (chunk.done) {
      return {
        ...chunk,
        message: {
          ...chunk.message,
          role: "assistant",
          content: text,
          ...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {}),
        },
      };
    }
  }

  return undefined;
}

// ── Main StreamFn factory ───────────────────────────────────────────────────

/**
 * Creates a `StreamFn` that sends the conversation to `<baseUrl>/api/chat`
 * with `stream: true` and pushes a single `done` event once the response has
 * been fully read. Failures are reported as a `done` event with
 * `reason: "error"`; the returned stream is always ended.
 *
 * @param baseUrl Ollama server root; trailing slashes are stripped.
 */
export function createOllamaStreamFn(baseUrl: string): StreamFn {
  const chatUrl = `${baseUrl.replace(/\/+$/, "")}/api/chat`;

  return (model, context, options) => {
    const stream = new AssistantMessageEventStream();

    const run = async () => {
      let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
      try {
        // NOTE(review): pi-ai's context appears to carry the prompt as
        // `systemPrompt`; `system` is still honoured as a fallback — confirm.
        const ctx = context as {
          messages?: AgentMessage[];
          system?: string;
          systemPrompt?: string;
        };
        const ollamaMessages = convertToOllamaMessages(
          ctx.messages ?? [],
          ctx.systemPrompt ?? ctx.system,
        );

        // NOTE(review): `tools` is never set on the request, so the model is
        // not told which tools exist — confirm whether the context carries
        // tool definitions that should be mapped here.
        const body: OllamaChatRequest = {
          model: model.id,
          messages: ollamaMessages,
          stream: true,
          ...(typeof options?.temperature === "number"
            ? { options: { temperature: options.temperature } }
            : {}),
        };

        const headers: Record<string, string> = {
          "Content-Type": "application/json",
          ...(options?.headers ?? {}),
        };
        if (options?.apiKey) {
          headers.Authorization = `Bearer ${options.apiKey}`;
        }

        const response = await fetch(chatUrl, {
          method: "POST",
          headers,
          body: JSON.stringify(body),
          // Lets the caller abort an in-flight generation.
          signal: options?.signal,
        });

        if (!response.ok) {
          const errorText = await response.text().catch(() => "unknown error");
          throw new Error(`Ollama API error ${response.status}: ${errorText}`);
        }

        if (!response.body) {
          throw new Error("Ollama API returned empty response body");
        }

        reader = response.body.getReader();
        const finalResponse = await collectOllamaResponse(parseNdjsonStream(reader));

        if (!finalResponse) {
          throw new Error("Ollama API stream ended without a final response");
        }

        const assistantMessage = buildAssistantMessage(finalResponse, {
          api: model.api as string,
          provider: model.provider as string,
          id: model.id,
        });

        // NOTE(review): the event reason is always "stop", even when the
        // message carries tool calls — confirm what the consumer expects.
        stream.push({
          type: "done",
          reason: "stop",
          message: assistantMessage,
        });
      } catch (err) {
        const errorMessage = err instanceof Error ? err.message : String(err);
        stream.push({
          type: "done",
          reason: "error",
          message: {
            role: "assistant" as const,
            content: [],
            stopReason: "error",
            errorMessage,
            api: model.api,
            provider: model.provider,
            model: model.id,
            usage: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              totalTokens: 0,
              cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
            },
            timestamp: Date.now(),
          },
        });
      } finally {
        // Release the connection whether the body was drained or abandoned.
        await reader?.cancel().catch(() => undefined);
        stream.end();
      }
    };

    queueMicrotask(() => void run());
    return stream;
  };
}

export const OLLAMA_NATIVE_BASE_URL = "http://127.0.0.1:11434";