feat(ollama): implement native /api/chat StreamFn (#11828)

This commit is contained in:
BrokenFinger98 2026-02-08 21:19:02 +09:00 committed by Peter Steinberger
parent 91dda25d3b
commit e9900993a2
2 changed files with 496 additions and 0 deletions

View File

@ -0,0 +1,121 @@
import { describe, expect, it } from "vitest";
import { convertToOllamaMessages, buildAssistantMessage } from "./ollama-stream.js";
describe("convertToOllamaMessages", () => {
it("converts user text messages", () => {
const messages = [{ role: "user", content: "hello" }];
const result = convertToOllamaMessages(messages);
expect(result).toEqual([{ role: "user", content: "hello" }]);
});
it("converts user messages with content parts", () => {
const messages = [
{
role: "user",
content: [
{ type: "text", text: "describe this" },
{ type: "image", data: "base64data" },
],
},
];
const result = convertToOllamaMessages(messages);
expect(result).toEqual([
{ role: "user", content: "describe this", images: ["base64data"] },
]);
});
it("prepends system message when provided", () => {
const messages = [{ role: "user", content: "hello" }];
const result = convertToOllamaMessages(messages, "You are helpful.");
expect(result[0]).toEqual({ role: "system", content: "You are helpful." });
expect(result[1]).toEqual({ role: "user", content: "hello" });
});
it("converts assistant messages with tool calls", () => {
const messages = [
{
role: "assistant",
content: [
{ type: "text", text: "Let me check." },
{ type: "tool_use", id: "call_1", name: "bash", input: { command: "ls" } },
],
},
];
const result = convertToOllamaMessages(messages);
expect(result[0].role).toBe("assistant");
expect(result[0].content).toBe("Let me check.");
expect(result[0].tool_calls).toEqual([
{ function: { name: "bash", arguments: { command: "ls" } } },
]);
});
it("converts tool result messages", () => {
const messages = [{ role: "tool", content: "file1.txt\nfile2.txt" }];
const result = convertToOllamaMessages(messages);
expect(result).toEqual([{ role: "tool", content: "file1.txt\nfile2.txt" }]);
});
it("handles empty messages array", () => {
const result = convertToOllamaMessages([]);
expect(result).toEqual([]);
});
});
describe("buildAssistantMessage", () => {
const modelInfo = { api: "ollama", provider: "ollama", id: "qwen3:32b" };
it("builds text-only response", () => {
const response = {
model: "qwen3:32b",
created_at: "2026-01-01T00:00:00Z",
message: { role: "assistant" as const, content: "Hello!" },
done: true,
prompt_eval_count: 10,
eval_count: 5,
};
const result = buildAssistantMessage(response, modelInfo);
expect(result.role).toBe("assistant");
expect(result.content).toEqual([{ type: "text", text: "Hello!" }]);
expect(result.stopReason).toBe("stop");
expect(result.usage.input).toBe(10);
expect(result.usage.output).toBe(5);
expect(result.usage.totalTokens).toBe(15);
});
it("builds response with tool calls", () => {
const response = {
model: "qwen3:32b",
created_at: "2026-01-01T00:00:00Z",
message: {
role: "assistant" as const,
content: "",
tool_calls: [
{ function: { name: "bash", arguments: { command: "ls -la" } } },
],
},
done: true,
prompt_eval_count: 20,
eval_count: 10,
};
const result = buildAssistantMessage(response, modelInfo);
expect(result.stopReason).toBe("end_turn");
expect(result.content.length).toBe(2); // empty text + tool_use
expect(result.content[1].type).toBe("tool_use");
const toolUse = result.content[1] as { type: "tool_use"; name: string; input: Record<string, unknown> };
expect(toolUse.name).toBe("bash");
expect(toolUse.input).toEqual({ command: "ls -la" });
});
it("sets all costs to zero for local models", () => {
const response = {
model: "qwen3:32b",
created_at: "2026-01-01T00:00:00Z",
message: { role: "assistant" as const, content: "ok" },
done: true,
};
const result = buildAssistantMessage(response, modelInfo);
expect(result.usage.cost).toEqual({
input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0,
});
});
});

375
src/agents/ollama-stream.ts Normal file
View File

@ -0,0 +1,375 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import type { SimpleStreamOptions } from "@mariozechner/pi-ai";
import { AssistantMessageEventStream } from "@mariozechner/pi-ai";
// ── Ollama /api/chat request types ──────────────────────────────────────────
interface OllamaChatRequest {
model: string;
messages: OllamaChatMessage[];
stream: boolean;
tools?: OllamaTool[];
options?: Record<string, unknown>;
}
interface OllamaChatMessage {
role: "system" | "user" | "assistant" | "tool";
content: string;
images?: string[];
tool_calls?: OllamaToolCall[];
}
interface OllamaTool {
type: "function";
function: {
name: string;
description: string;
parameters: Record<string, unknown>;
};
}
interface OllamaToolCall {
function: {
name: string;
arguments: Record<string, unknown>;
};
}
// ── Ollama /api/chat response types ─────────────────────────────────────────
interface OllamaChatResponse {
model: string;
created_at: string;
message: {
role: "assistant";
content: string;
tool_calls?: OllamaToolCall[];
};
done: boolean;
done_reason?: string;
total_duration?: number;
load_duration?: number;
prompt_eval_count?: number;
prompt_eval_duration?: number;
eval_count?: number;
eval_duration?: number;
}
// ── Message conversion ──────────────────────────────────────────────────────
type AgentMessage = {
role: string;
content: unknown;
[key: string]: unknown;
};
type ContentPart =
| { type: "text"; text: string }
| { type: "image"; data: string; mediaType?: string }
| { type: "tool_use"; id: string; name: string; input: Record<string, unknown> }
| { type: "tool_result"; tool_use_id: string; content: string };
function extractTextContent(content: unknown): string {
if (typeof content === "string") {
return content;
}
if (!Array.isArray(content)) {
return "";
}
return (content as ContentPart[])
.filter((part): part is { type: "text"; text: string } => part.type === "text")
.map((part) => part.text)
.join("");
}
function extractImages(content: unknown): string[] {
if (!Array.isArray(content)) {
return [];
}
return (content as ContentPart[])
.filter((part): part is { type: "image"; data: string } => part.type === "image")
.map((part) => part.data);
}
function extractToolCalls(content: unknown): OllamaToolCall[] {
if (!Array.isArray(content)) {
return [];
}
return (content as ContentPart[])
.filter(
(part): part is { type: "tool_use"; id: string; name: string; input: Record<string, unknown> } =>
part.type === "tool_use",
)
.map((part) => ({
function: {
name: part.name,
arguments: part.input,
},
}));
}
export function convertToOllamaMessages(
messages: AgentMessage[],
system?: string,
): OllamaChatMessage[] {
const result: OllamaChatMessage[] = [];
if (system) {
result.push({ role: "system", content: system });
}
for (const msg of messages) {
const role = msg.role as string;
if (role === "user") {
const text = extractTextContent(msg.content);
const images = extractImages(msg.content);
result.push({
role: "user",
content: text,
...(images.length > 0 ? { images } : {}),
});
} else if (role === "assistant") {
const text = extractTextContent(msg.content);
const toolCalls = extractToolCalls(msg.content);
result.push({
role: "assistant",
content: text,
...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {}),
});
} else if (role === "tool") {
const text = extractTextContent(msg.content);
result.push({ role: "tool", content: text });
}
}
return result;
}
// ── Response conversion ─────────────────────────────────────────────────────
interface AssistantMessageLike {
role: "assistant";
content: ContentPart[];
stopReason: string;
api: string;
provider: string;
model: string;
usage: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
totalTokens: number;
cost: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
total: number;
};
};
timestamp: number;
}
let toolCallIdCounter = 0;
export function buildAssistantMessage(
response: OllamaChatResponse,
modelInfo: { api: string; provider: string; id: string },
): AssistantMessageLike {
const content: ContentPart[] = [];
if (response.message.content) {
content.push({ type: "text", text: response.message.content });
}
const toolCalls = response.message.tool_calls;
if (toolCalls && toolCalls.length > 0) {
for (const tc of toolCalls) {
toolCallIdCounter += 1;
content.push({
type: "tool_use",
id: `ollama_call_${toolCallIdCounter}_${Date.now()}`,
name: tc.function.name,
input: tc.function.arguments,
});
}
}
const hasToolCalls = toolCalls && toolCalls.length > 0;
const stopReason = hasToolCalls ? "end_turn" : "stop";
return {
role: "assistant",
content,
stopReason,
api: modelInfo.api,
provider: modelInfo.provider,
model: modelInfo.id,
usage: {
input: response.prompt_eval_count ?? 0,
output: response.eval_count ?? 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: (response.prompt_eval_count ?? 0) + (response.eval_count ?? 0),
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
timestamp: Date.now(),
};
}
// ── NDJSON streaming parser ─────────────────────────────────────────────────
export async function* parseNdjsonStream(
reader: ReadableStreamDefaultReader<Uint8Array>,
): AsyncGenerator<OllamaChatResponse> {
const decoder = new TextDecoder();
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) {
break;
}
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n");
buffer = lines.pop() ?? "";
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) {
continue;
}
try {
yield JSON.parse(trimmed) as OllamaChatResponse;
} catch {
// Skip malformed lines
}
}
}
if (buffer.trim()) {
try {
yield JSON.parse(buffer.trim()) as OllamaChatResponse;
} catch {
// Skip malformed trailing data
}
}
}
// ── Main StreamFn factory ───────────────────────────────────────────────────
export function createOllamaStreamFn(baseUrl: string): StreamFn {
const chatUrl = `${baseUrl.replace(/\/+$/, "")}/api/chat`;
return (model, context, options) => {
const stream = new AssistantMessageEventStream();
const run = async () => {
try {
const ctx = context as { messages?: AgentMessage[]; system?: string };
const ollamaMessages = convertToOllamaMessages(ctx.messages ?? [], ctx.system as string);
const body: OllamaChatRequest = {
model: model.id,
messages: ollamaMessages,
stream: true,
...(typeof options?.temperature === "number"
? { options: { temperature: options.temperature } }
: {}),
};
const headers: Record<string, string> = {
"Content-Type": "application/json",
...(options?.headers ?? {}),
};
if (options?.apiKey) {
headers.Authorization = `Bearer ${options.apiKey}`;
}
const response = await fetch(chatUrl, {
method: "POST",
headers,
body: JSON.stringify(body),
});
if (!response.ok) {
const errorText = await response.text().catch(() => "unknown error");
throw new Error(`Ollama API error ${response.status}: ${errorText}`);
}
if (!response.body) {
throw new Error("Ollama API returned empty response body");
}
const reader = response.body.getReader();
let accumulatedContent = "";
let finalResponse: OllamaChatResponse | undefined;
for await (const chunk of parseNdjsonStream(reader)) {
if (chunk.done) {
finalResponse = chunk;
if (chunk.message?.content) {
accumulatedContent += chunk.message.content;
}
break;
}
if (chunk.message?.content) {
accumulatedContent += chunk.message.content;
}
}
if (!finalResponse) {
throw new Error("Ollama API stream ended without a final response");
}
finalResponse.message.content = accumulatedContent;
const assistantMessage = buildAssistantMessage(finalResponse, {
api: model.api as string,
provider: model.provider as string,
id: model.id,
});
stream.push({
type: "done",
reason: "stop",
message: assistantMessage,
});
} catch (err) {
const errorMessage = err instanceof Error ? err.message : String(err);
stream.push({
type: "done",
reason: "error",
message: {
role: "assistant" as const,
content: [],
stopReason: "error",
errorMessage,
api: model.api,
provider: model.provider,
model: model.id,
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
timestamp: Date.now(),
},
});
} finally {
stream.end();
}
};
queueMicrotask(() => void run());
return stream;
};
}
export const OLLAMA_NATIVE_BASE_URL = "http://127.0.0.1:11434";