diff --git a/extensions/ollama/src/stream.ts b/extensions/ollama/src/stream.ts index 2bbc0609010..133b15d5eee 100644 --- a/extensions/ollama/src/stream.ts +++ b/extensions/ollama/src/stream.ts @@ -689,12 +689,58 @@ export function createOllamaStreamFn( let accumulatedContent = ""; const accumulatedToolCalls: OllamaToolCall[] = []; let finalResponse: OllamaChatResponse | undefined; + const modelInfo = { api: model.api, provider: model.provider, id: model.id }; + let streamStarted = false; + let textBlockClosed = false; + + const closeTextBlock = () => { + if (!streamStarted || textBlockClosed) { + return; + } + textBlockClosed = true; + const partial = buildStreamAssistantMessage({ + model: modelInfo, + content: [{ type: "text", text: accumulatedContent }], + stopReason: "stop", + usage: buildUsageWithNoCost({}), + }); + stream.push({ + type: "text_end", + contentIndex: 0, + content: accumulatedContent, + partial, + }); + }; for await (const chunk of parseNdjsonStream(reader)) { if (chunk.message?.content) { - accumulatedContent += chunk.message.content; + const delta = chunk.message.content; + + if (!streamStarted) { + streamStarted = true; + // Emit start/text_start with an empty partial before accumulating + // the first delta, matching the Anthropic/OpenAI provider contract. 
+ const emptyPartial = buildStreamAssistantMessage({ + model: modelInfo, + content: [], + stopReason: "stop", + usage: buildUsageWithNoCost({}), + }); + stream.push({ type: "start", partial: emptyPartial }); + stream.push({ type: "text_start", contentIndex: 0, partial: emptyPartial }); + } + + accumulatedContent += delta; + const partial = buildStreamAssistantMessage({ + model: modelInfo, + content: [{ type: "text", text: accumulatedContent }], + stopReason: "stop", + usage: buildUsageWithNoCost({}), + }); + stream.push({ type: "text_delta", contentIndex: 0, delta, partial }); } if (chunk.message?.tool_calls) { + closeTextBlock(); accumulatedToolCalls.push(...chunk.message.tool_calls); } if (chunk.done) { @@ -712,11 +758,11 @@ export function createOllamaStreamFn( finalResponse.message.tool_calls = accumulatedToolCalls; } - const assistantMessage = buildAssistantMessage(finalResponse, { - api: model.api, - provider: model.provider, - id: model.id, - }); + const assistantMessage = buildAssistantMessage(finalResponse, modelInfo); + + // Close the text block if we emitted any text_delta events. + closeTextBlock(); + stream.push({ type: "done", reason: assistantMessage.stopReason === "toolUse" ? 
"toolUse" : "stop", diff --git a/src/agents/ollama-stream.test.ts b/src/agents/ollama-stream.test.ts index 59cbc57a353..a81fb327859 100644 --- a/src/agents/ollama-stream.test.ts +++ b/src/agents/ollama-stream.test.ts @@ -332,6 +332,40 @@ async function withMockNdjsonFetch( } } +function createControlledNdjsonFetch(): { + fetchMock: ReturnType<typeof vi.fn>; + pushLine: (line: string) => void; + close: () => void; +} { + const encoder = new TextEncoder(); + let controller: ReadableStreamDefaultController<Uint8Array> | undefined; + const body = new ReadableStream({ + start(streamController) { + controller = streamController; + }, + }); + return { + fetchMock: vi.fn(async () => { + return new Response(body, { + status: 200, + headers: { "Content-Type": "application/x-ndjson" }, + }); + }), + pushLine(line: string) { + if (!controller) { + throw new Error("NDJSON controller not initialized"); + } + controller.enqueue(encoder.encode(`${line}\n`)); + }, + close() { + if (!controller) { + throw new Error("NDJSON controller not initialized"); + } + controller.close(); + }, + }; +} + async function createOllamaTestStream(params: { baseUrl: string; defaultHeaders?: Record<string, string>; @@ -365,6 +399,219 @@ async function collectStreamEvents(stream: AsyncIterable): Promise { return events; } +async function nextEventWithin( + iterator: AsyncIterator<unknown>, + timeoutMs = 100, +): Promise<IteratorResult<unknown> | "timeout"> { + return await Promise.race([ + iterator.next(), + new Promise<"timeout">((resolve) => { + setTimeout(() => resolve("timeout"), timeoutMs); + }), + ]); +} + +describe("createOllamaStreamFn streaming events", () => { + it("emits start, text_start, text_delta, text_end, done for text responses", async () => { + await withMockNdjsonFetch( + [ + '{"model":"m","created_at":"t","message":{"role":"assistant","content":"Hello"},"done":false}', + '{"model":"m","created_at":"t","message":{"role":"assistant","content":" world"},"done":false}', +
'{"model":"m","created_at":"t","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":5,"eval_count":2}', + ], + async () => { + const stream = await createOllamaTestStream({ baseUrl: "http://ollama-host:11434" }); + const events = await collectStreamEvents(stream); + + const types = events.map((e) => e.type); + expect(types).toEqual([ + "start", + "text_start", + "text_delta", + "text_delta", + "text_end", + "done", + ]); + + // text_delta events carry incremental deltas + const deltas = events.filter((e) => e.type === "text_delta"); + expect(deltas[0]).toMatchObject({ contentIndex: 0, delta: "Hello" }); + expect(deltas[1]).toMatchObject({ contentIndex: 0, delta: " world" }); + + // text_end carries the full accumulated content + const textEnd = events.find((e) => e.type === "text_end"); + expect(textEnd).toMatchObject({ contentIndex: 0, content: "Hello world" }); + + // start/text_start carry empty partials (before any content accumulates) + const startEvent = events.find((e) => e.type === "start"); + expect(startEvent?.partial.content).toEqual([]); + const textStartEvent = events.find((e) => e.type === "text_start"); + expect(textStartEvent?.partial.content).toEqual([]); + + // text_delta partials accumulate content progressively + expect(deltas[0].partial.content).toEqual([{ type: "text", text: "Hello" }]); + expect(deltas[1].partial.content).toEqual([{ type: "text", text: "Hello world" }]); + + // done event contains the final message + const doneEvent = events.at(-1); + expect(doneEvent?.type).toBe("done"); + if (doneEvent?.type === "done") { + expect(doneEvent.message.content).toEqual([{ type: "text", text: "Hello world" }]); + } + }, + ); + }); + + it("emits only done for tool-call-only responses (no text content)", async () => { + await withMockNdjsonFetch( + [ + '{"model":"m","created_at":"t","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"bash","arguments":{"command":"ls"}}}]},"done":false}', + 
'{"model":"m","created_at":"t","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":10,"eval_count":5}', + ], + async () => { + const stream = await createOllamaTestStream({ baseUrl: "http://ollama-host:11434" }); + const events = await collectStreamEvents(stream); + + // No text content means no start/text_start/text_delta/text_end events + const types = events.map((e) => e.type); + expect(types).toEqual(["done"]); + const doneEvent = events[0]; + if (doneEvent.type === "done") { + expect(doneEvent.reason).toBe("toolUse"); + } + }, + ); + }); + + it("emits text streaming events before done for mixed text + tool responses", async () => { + await withMockNdjsonFetch( + [ + '{"model":"m","created_at":"t","message":{"role":"assistant","content":"Let me check."},"done":false}', + '{"model":"m","created_at":"t","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"bash","arguments":{"command":"ls"}}}]},"done":false}', + '{"model":"m","created_at":"t","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":10,"eval_count":5}', + ], + async () => { + const stream = await createOllamaTestStream({ baseUrl: "http://ollama-host:11434" }); + const events = await collectStreamEvents(stream); + + const types = events.map((e) => e.type); + expect(types).toEqual(["start", "text_start", "text_delta", "text_end", "done"]); + const doneEvent = events.at(-1); + if (doneEvent?.type === "done") { + expect(doneEvent.reason).toBe("toolUse"); + } + }, + ); + }); + + it("emits text_end as soon as Ollama switches from text to tool calls", async () => { + const originalFetch = globalThis.fetch; + const controlledFetch = createControlledNdjsonFetch(); + globalThis.fetch = controlledFetch.fetchMock as unknown as typeof fetch; + + try { + const stream = await createOllamaTestStream({ baseUrl: "http://ollama-host:11434" }); + const iterator = stream[Symbol.asyncIterator](); + + controlledFetch.pushLine( + 
'{"model":"m","created_at":"t","message":{"role":"assistant","content":"Let me check."},"done":false}', + ); + + const startEvent = await nextEventWithin(iterator); + const textStartEvent = await nextEventWithin(iterator); + const textDeltaEvent = await nextEventWithin(iterator); + + expect(startEvent).not.toBe("timeout"); + expect(textStartEvent).not.toBe("timeout"); + expect(textDeltaEvent).not.toBe("timeout"); + expect(startEvent).toMatchObject({ value: { type: "start" }, done: false }); + expect(textStartEvent).toMatchObject({ value: { type: "text_start" }, done: false }); + expect(textDeltaEvent).toMatchObject({ + value: { type: "text_delta", delta: "Let me check." }, + done: false, + }); + + controlledFetch.pushLine( + '{"model":"m","created_at":"t","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"bash","arguments":{"command":"ls"}}}]},"done":false}', + ); + + const textEndEvent = await nextEventWithin(iterator); + expect(textEndEvent).not.toBe("timeout"); + expect(textEndEvent).toMatchObject({ + value: { + type: "text_end", + contentIndex: 0, + content: "Let me check.", + partial: { + content: [{ type: "text", text: "Let me check." 
}], + }, + }, + done: false, + }); + + const nextBeforeDone = await nextEventWithin(iterator, 25); + expect(nextBeforeDone).toBe("timeout"); + + controlledFetch.pushLine( + '{"model":"m","created_at":"t","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":10,"eval_count":5}', + ); + controlledFetch.close(); + + const doneEvent = await nextEventWithin(iterator); + expect(doneEvent).not.toBe("timeout"); + expect(doneEvent).toMatchObject({ + value: { type: "done", reason: "toolUse" }, + done: false, + }); + + const streamEnd = await nextEventWithin(iterator); + expect(streamEnd).not.toBe("timeout"); + expect(streamEnd).toMatchObject({ value: undefined, done: true }); + } finally { + globalThis.fetch = originalFetch; + } + }); + + it("emits error without text_end when stream fails mid-response", async () => { + // Simulate a stream that sends one content chunk then ends without done:true. + // The stream function throws "Ollama API stream ended without a final response". + await withMockNdjsonFetch( + [ + '{"model":"m","created_at":"t","message":{"role":"assistant","content":"partial"},"done":false}', + ], + async () => { + const stream = await createOllamaTestStream({ baseUrl: "http://ollama-host:11434" }); + const events = await collectStreamEvents(stream); + + const types = events.map((e) => e.type); + // Should have streaming events for the partial content, then error (no text_end). 
+ expect(types).toEqual(["start", "text_start", "text_delta", "error"]); + const errorEvent = events.at(-1); + expect(errorEvent?.type).toBe("error"); + }, + ); + }); + + it("emits a single text_delta for single-chunk responses", async () => { + await withMockNdjsonFetch( + [ + '{"model":"m","created_at":"t","message":{"role":"assistant","content":"one shot"},"done":false}', + '{"model":"m","created_at":"t","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":1,"eval_count":1}', + ], + async () => { + const stream = await createOllamaTestStream({ baseUrl: "http://ollama-host:11434" }); + const events = await collectStreamEvents(stream); + + const types = events.map((e) => e.type); + expect(types).toEqual(["start", "text_start", "text_delta", "text_end", "done"]); + + const delta = events.find((e) => e.type === "text_delta"); + expect(delta).toMatchObject({ delta: "one shot" }); + }, + ); + }); +}); + describe("createOllamaStreamFn", () => { it("normalizes /v1 baseUrl and maps maxTokens + signal", async () => { await withMockNdjsonFetch(