Patch summary (free text before the first "diff --git" line; patch tools skip it):

  * moonshot-stream-wrappers.ts: shouldApplyMoonshotPayloadCompat now calls
    usesMoonshotThinkingPayloadCompat(normalizedProvider) instead of comparing
    the normalized provider to the literal "moonshot".
  * xai-stream-wrappers.ts: XAI_FAST_MODEL_IDS gains grok-3 -> grok-3-fast and
    grok-3-mini -> grok-3-mini-fast.
  * xai-stream-wrappers.test.ts (new): checks that createXaiFastModeWrapper
    rewrites the model id passed to the base stream function only when fast
    mode is on and the id has a mapping.
  * provider-capabilities.ts: ProviderCapabilities gains
    openAiPayloadNormalizationMode ("default" | "moonshot-thinking"); the
    default capability set uses "default".
  * provider-capabilities.test.ts: expectations updated for the new field, plus
    a test for usesMoonshotThinkingPayloadCompat.

Review notes:
  * The stub return in the new test is cast to ReturnType<StreamFn>; a bare
    "ReturnType" is missing its required type argument and does not type-check.
  * Code indentation below was reconstructed assuming 2-space formatting;
    confirm against the target tree before applying.
  * The last hunk of provider-capabilities.ts is cut off in this view and is
    left exactly as found.

diff --git a/src/agents/pi-embedded-runner/moonshot-stream-wrappers.ts b/src/agents/pi-embedded-runner/moonshot-stream-wrappers.ts
index c066a168a0f..f9ce52eb234 100644
--- a/src/agents/pi-embedded-runner/moonshot-stream-wrappers.ts
+++ b/src/agents/pi-embedded-runner/moonshot-stream-wrappers.ts
@@ -1,6 +1,7 @@
 import type { StreamFn } from "@mariozechner/pi-agent-core";
 import { streamSimple } from "@mariozechner/pi-ai";
 import type { ThinkLevel } from "../../auto-reply/thinking.js";
+import { usesMoonshotThinkingPayloadCompat } from "../provider-capabilities.js";
 
 type MoonshotThinkingType = "enabled" | "disabled";
 
@@ -62,7 +63,7 @@ export function shouldApplyMoonshotPayloadCompat(params: {
   const normalizedProvider = params.provider.trim().toLowerCase();
   const normalizedModelId = params.modelId.trim().toLowerCase();
 
-  if (normalizedProvider === "moonshot") {
+  if (usesMoonshotThinkingPayloadCompat(normalizedProvider)) {
     return true;
   }
 
diff --git a/src/agents/pi-embedded-runner/xai-stream-wrappers.test.ts b/src/agents/pi-embedded-runner/xai-stream-wrappers.test.ts
new file mode 100644
index 00000000000..a005f2c4721
--- /dev/null
+++ b/src/agents/pi-embedded-runner/xai-stream-wrappers.test.ts
@@ -0,0 +1,39 @@
+import type { StreamFn } from "@mariozechner/pi-agent-core";
+import type { Context, Model } from "@mariozechner/pi-ai";
+import { describe, expect, it } from "vitest";
+import { createXaiFastModeWrapper } from "./xai-stream-wrappers.js";
+
+function captureWrappedModelId(params: { modelId: string; fastMode: boolean }): string {
+  let capturedModelId = "";
+  const baseStreamFn: StreamFn = (model) => {
+    capturedModelId = model.id;
+    return {} as ReturnType<StreamFn>;
+  };
+
+  const wrapped = createXaiFastModeWrapper(baseStreamFn, params.fastMode);
+  void wrapped(
+    {
+      api: "openai-completions",
+      provider: "xai",
+      id: params.modelId,
+    } as Model<"openai-completions">,
+    { messages: [] } as Context,
+    {},
+  );
+
+  return capturedModelId;
+}
+
+describe("xai fast mode wrapper", () => {
+  it("rewrites Grok 3 models to fast variants", () => {
+    expect(captureWrappedModelId({ modelId: "grok-3", fastMode: true })).toBe("grok-3-fast");
+    expect(captureWrappedModelId({ modelId: "grok-3-mini", fastMode: true })).toBe(
+      "grok-3-mini-fast",
+    );
+  });
+
+  it("leaves unsupported or disabled models unchanged", () => {
+    expect(captureWrappedModelId({ modelId: "grok-3-fast", fastMode: true })).toBe("grok-3-fast");
+    expect(captureWrappedModelId({ modelId: "grok-3", fastMode: false })).toBe("grok-3");
+  });
+});
diff --git a/src/agents/pi-embedded-runner/xai-stream-wrappers.ts b/src/agents/pi-embedded-runner/xai-stream-wrappers.ts
index 747c3fb4158..ceef1dcb36d 100644
--- a/src/agents/pi-embedded-runner/xai-stream-wrappers.ts
+++ b/src/agents/pi-embedded-runner/xai-stream-wrappers.ts
@@ -2,6 +2,8 @@ import type { StreamFn } from "@mariozechner/pi-agent-core";
 import { streamSimple } from "@mariozechner/pi-ai";
 
 const XAI_FAST_MODEL_IDS = new Map([
+  ["grok-3", "grok-3-fast"],
+  ["grok-3-mini", "grok-3-mini-fast"],
   ["grok-4", "grok-4-fast"],
   ["grok-4-0709", "grok-4-fast"],
 ]);
diff --git a/src/agents/provider-capabilities.test.ts b/src/agents/provider-capabilities.test.ts
index 09f19468776..5dee5ea2113 100644
--- a/src/agents/provider-capabilities.test.ts
+++ b/src/agents/provider-capabilities.test.ts
@@ -53,6 +53,7 @@ import {
   shouldDropThinkingBlocksForModel,
   shouldSanitizeGeminiThoughtSignaturesForModel,
   supportsOpenAiCompatTurnValidation,
+  usesMoonshotThinkingPayloadCompat,
 } from "./provider-capabilities.js";
 
 describe("resolveProviderCapabilities", () => {
@@ -60,6 +61,7 @@ describe("resolveProviderCapabilities", () => {
     expect(resolveProviderCapabilities("anthropic")).toEqual({
       anthropicToolSchemaMode: "native",
       anthropicToolChoiceMode: "native",
+      openAiPayloadNormalizationMode: "default",
       providerFamily: "anthropic",
       preserveAnthropicThinkingSignatures: true,
       openAiCompatTurnValidation: true,
@@ -72,6 +74,7 @@ describe("resolveProviderCapabilities", () => {
     expect(resolveProviderCapabilities("anthropic-vertex")).toEqual({
       anthropicToolSchemaMode: "native",
       anthropicToolChoiceMode: "native",
+      openAiPayloadNormalizationMode: "default",
       providerFamily: "anthropic",
       preserveAnthropicThinkingSignatures: true,
       openAiCompatTurnValidation: true,
@@ -84,6 +87,7 @@ describe("resolveProviderCapabilities", () => {
     expect(resolveProviderCapabilities("amazon-bedrock")).toEqual({
       anthropicToolSchemaMode: "native",
       anthropicToolChoiceMode: "native",
+      openAiPayloadNormalizationMode: "default",
       providerFamily: "anthropic",
       preserveAnthropicThinkingSignatures: true,
       openAiCompatTurnValidation: true,
@@ -100,6 +104,7 @@ describe("resolveProviderCapabilities", () => {
     expect(resolveProviderCapabilities("kimi-code")).toEqual({
       anthropicToolSchemaMode: "native",
       anthropicToolChoiceMode: "native",
+      openAiPayloadNormalizationMode: "default",
       providerFamily: "default",
       preserveAnthropicThinkingSignatures: false,
       openAiCompatTurnValidation: true,
@@ -118,6 +123,11 @@ describe("resolveProviderCapabilities", () => {
     expect(supportsOpenAiCompatTurnValidation("moonshot")).toBe(true);
   });
 
+  it("routes moonshot payload compatibility through the capability registry", () => {
+    expect(usesMoonshotThinkingPayloadCompat("moonshot")).toBe(true);
+    expect(usesMoonshotThinkingPayloadCompat("openai")).toBe(false);
+  });
+
   it("resolves transcript thought-signature and tool-call quirks through the registry", () => {
     expect(
       shouldSanitizeGeminiThoughtSignaturesForModel({
diff --git a/src/agents/provider-capabilities.ts b/src/agents/provider-capabilities.ts
index c52be686387..01ec62f55f8 100644
--- a/src/agents/provider-capabilities.ts
+++ b/src/agents/provider-capabilities.ts
@@ -5,6 +5,7 @@ import { normalizeProviderId } from "./model-selection.js";
 export type ProviderCapabilities = {
   anthropicToolSchemaMode: "native" | "openai-functions";
   anthropicToolChoiceMode: "native" | "openai-string-modes";
+  openAiPayloadNormalizationMode: "default" | "moonshot-thinking";
   providerFamily: "default" | "openai" | "anthropic";
   preserveAnthropicThinkingSignatures: boolean;
   openAiCompatTurnValidation: boolean;
@@ -24,6 +25,7 @@ export type ProviderCapabilityLookupOptions = {
 const DEFAULT_PROVIDER_CAPABILITIES: ProviderCapabilities = {
   anthropicToolSchemaMode: "native",
   anthropicToolChoiceMode: "native",
+  openAiPayloadNormalizationMode: "default",
   providerFamily: "default",
   preserveAnthropicThinkingSignatures: true,
   openAiCompatTurnValidation: true,
@@ -62,6 +64,9 @@ const PLUGIN_CAPABILITIES_FALLBACKS: Record