From 47db5abece27d3c9c7edbbd68401f29cb71058ef Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Mon, 23 Mar 2026 04:58:18 -0700 Subject: [PATCH] test: inject thread-safe base seams --- .../pi-embedded-runner-extraparams.test.ts | 59 ++++++------ src/agents/pi-embedded-runner/extra-params.ts | 33 ++++++- src/agents/provider-capabilities.ts | 22 ++++- src/auto-reply/reply/abort.test.ts | 27 +++++- src/auto-reply/reply/abort.ts | 39 +++++++- src/auto-reply/reply/queue/cleanup.ts | 27 +++++- src/gateway/call.test.ts | 92 +++++++++---------- src/gateway/call.ts | 64 +++++++++++-- 8 files changed, 265 insertions(+), 98 deletions(-) diff --git a/src/agents/pi-embedded-runner-extraparams.test.ts b/src/agents/pi-embedded-runner-extraparams.test.ts index dafb8b59319..87bfc7992bd 100644 --- a/src/agents/pi-embedded-runner-extraparams.test.ts +++ b/src/agents/pi-embedded-runner-extraparams.test.ts @@ -1,9 +1,22 @@ import type { StreamFn } from "@mariozechner/pi-agent-core"; import type { Context, Model, SimpleStreamOptions } from "@mariozechner/pi-ai"; -import { describe, expect, it, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import { __testing as extraParamsTesting } from "./pi-embedded-runner/extra-params.js"; +import { + createOpenRouterSystemCacheWrapper, + createOpenRouterWrapper, + isProxyReasoningUnsupported, +} from "./pi-embedded-runner/proxy-stream-wrappers.js"; +import type { ProviderCapabilities } from "./provider-capabilities.js"; +import { __testing as providerCapabilitiesTesting } from "./provider-capabilities.js"; const resolveProviderCapabilitiesWithPluginMock = vi.fn( - (params: { provider: string; workspaceDir?: string }) => { + (params: { + provider: string; + config?: import("../config/config.js").OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; + }): Partial | undefined => { if ( params.provider === "workspace-anthropic-proxy" && params.workspaceDir === "/tmp/workspace-capabilities" @@ -17,20 +30,12 @@ const resolveProviderCapabilitiesWithPluginMock = vi.fn( }, ); -vi.mock("../plugins/provider-runtime.js", async (importOriginal) => { - const actual = await importOriginal(); - const { - createOpenRouterSystemCacheWrapper, - createOpenRouterWrapper, - isProxyReasoningUnsupported, - } = await import("./pi-embedded-runner/proxy-stream-wrappers.js"); +import { applyExtraParamsToAgent, resolveExtraParams } from "./pi-embedded-runner.js"; +import { log } from "./pi-embedded-runner/logger.js"; - return { - ...actual, - prepareProviderExtraParams: (params: { - provider: string; - context: { extraParams?: Record }; - }) => { +beforeEach(() => { + extraParamsTesting.setProviderRuntimeDepsForTest({ + prepareProviderExtraParams: (params) => { if (params.provider !== "openai-codex") { return undefined; } @@ -43,15 +48,7 @@ vi.mock("../plugins/provider-runtime.js", async (importOriginal) => { transport: "auto", }; }, - wrapProviderStreamFn: (params: { - provider: string; - context: { - modelId: string; - thinkingLevel?: import("../auto-reply/thinking.js").ThinkLevel; - extraParams?: Record; - streamFn?: StreamFn; - }; - }) => { + wrapProviderStreamFn: (params) => { if (params.provider !== "openrouter") { return params.context.streamFn; } @@ -80,13 +77,17 @@ vi.mock("../plugins/provider-runtime.js", async (importOriginal) => { const thinkingLevel = skipReasoningInjection ? undefined : params.context.thinkingLevel; return createOpenRouterSystemCacheWrapper(createOpenRouterWrapper(streamFn, thinkingLevel)); }, - resolveProviderCapabilitiesWithPlugin: (params: { provider: string; workspaceDir?: string }) => - resolveProviderCapabilitiesWithPluginMock(params), - }; + }); + providerCapabilitiesTesting.setResolveProviderCapabilitiesWithPluginForTest( + resolveProviderCapabilitiesWithPluginMock, + ); + resolveProviderCapabilitiesWithPluginMock.mockClear(); }); -import { applyExtraParamsToAgent, resolveExtraParams } from "./pi-embedded-runner.js"; -import { log } from "./pi-embedded-runner/logger.js"; +afterEach(() => { + extraParamsTesting.resetProviderRuntimeDepsForTest(); + providerCapabilitiesTesting.resetDepsForTests(); +}); describe("resolveExtraParams", () => { it("returns undefined with no model config", () => { diff --git a/src/agents/pi-embedded-runner/extra-params.ts b/src/agents/pi-embedded-runner/extra-params.ts index 7978c506237..27dbeede797 100644 --- a/src/agents/pi-embedded-runner/extra-params.ts +++ b/src/agents/pi-embedded-runner/extra-params.ts @@ -4,8 +4,8 @@ import { streamSimple } from "@mariozechner/pi-ai"; import type { ThinkLevel } from "../../auto-reply/thinking.js"; import type { OpenClawConfig } from "../../config/config.js"; import { - prepareProviderExtraParams, - wrapProviderStreamFn, + prepareProviderExtraParams as prepareProviderExtraParamsRuntime, + wrapProviderStreamFn as wrapProviderStreamFnRuntime, } from "../../plugins/provider-runtime.js"; import { createAnthropicBetaHeadersWrapper, @@ -38,6 +38,31 @@ import { } from "./openai-stream-wrappers.js"; import { createXaiFastModeWrapper } from "./xai-stream-wrappers.js"; +const defaultProviderRuntimeDeps = { + prepareProviderExtraParams: prepareProviderExtraParamsRuntime, + wrapProviderStreamFn: wrapProviderStreamFnRuntime, +}; + +const providerRuntimeDeps = { + ...defaultProviderRuntimeDeps, +}; + +export const __testing = { + setProviderRuntimeDepsForTest( + deps: Partial | undefined, + ): void { + providerRuntimeDeps.prepareProviderExtraParams = + deps?.prepareProviderExtraParams ?? defaultProviderRuntimeDeps.prepareProviderExtraParams; + providerRuntimeDeps.wrapProviderStreamFn = + deps?.wrapProviderStreamFn ?? defaultProviderRuntimeDeps.wrapProviderStreamFn; + }, + resetProviderRuntimeDepsForTest(): void { + providerRuntimeDeps.prepareProviderExtraParams = + defaultProviderRuntimeDeps.prepareProviderExtraParams; + providerRuntimeDeps.wrapProviderStreamFn = defaultProviderRuntimeDeps.wrapProviderStreamFn; + }, +}; + /** * Resolve provider-specific extra params from model config. * Used to pass through stream params like temperature/maxTokens. @@ -206,7 +231,7 @@ export function applyExtraParamsToAgent( : undefined; const merged = Object.assign({}, resolvedExtraParams, override); const effectiveExtraParams = - prepareProviderExtraParams({ + providerRuntimeDeps.prepareProviderExtraParams({ provider, config: cfg, context: { @@ -257,7 +282,7 @@ export function applyExtraParamsToAgent( workspaceDir, }); const providerStreamBase = agent.streamFn; - const pluginWrappedStreamFn = wrapProviderStreamFn({ + const pluginWrappedStreamFn = providerRuntimeDeps.wrapProviderStreamFn({ provider, config: cfg, context: { diff --git a/src/agents/provider-capabilities.ts b/src/agents/provider-capabilities.ts index 01ec62f55f8..9842ac40290 100644 --- a/src/agents/provider-capabilities.ts +++ b/src/agents/provider-capabilities.ts @@ -1,5 +1,5 @@ import type { OpenClawConfig } from "../config/config.js"; -import { resolveProviderCapabilitiesWithPlugin } from "../plugins/provider-runtime.js"; +import { resolveProviderCapabilitiesWithPlugin as resolveProviderCapabilitiesWithPluginRuntime } from "../plugins/provider-runtime.js"; import { normalizeProviderId } from "./model-selection.js"; export type ProviderCapabilities = { @@ -82,13 +82,31 @@ const PLUGIN_CAPABILITIES_FALLBACKS: Record ({ })); const commandQueueMocks = vi.hoisted(() => ({ - clearCommandLane: vi.fn(), + clearCommandLane: vi.fn(() => 1), })); vi.mock("../../process/command-queue.js", () => commandQueueMocks); @@ -162,8 +164,29 @@ describe("abort detection", () => { expect(commandQueueMocks.clearCommandLane).toHaveBeenCalledWith(`session:${sessionKey}`); } + beforeEach(() => { + abortTesting.setDepsForTests({ + getAcpSessionManager: (() => + ({ + resolveSession: acpManagerMocks.resolveSession, + cancelSession: acpManagerMocks.cancelSession, + }) as never) as never, + abortEmbeddedPiRun: () => true, + listSubagentRunsForController: subagentRegistryMocks.listSubagentRunsForRequester, + markSubagentRunTerminated: subagentRegistryMocks.markSubagentRunTerminated, + }); + queueCleanupTesting.setDepsForTests({ + resolveEmbeddedSessionLane: (key) => `session:${key.trim() || "main"}`, + clearCommandLane: commandQueueMocks.clearCommandLane, + }); + commandQueueMocks.clearCommandLane.mockClear().mockReturnValue(1); + }); + afterEach(() => { resetAbortMemoryForTest(); + abortTesting.resetDepsForTests(); + queueCleanupTesting.resetDepsForTests(); + commandQueueMocks.clearCommandLane.mockClear().mockReturnValue(1); acpManagerMocks.resolveSession.mockReset().mockReturnValue({ kind: "none" }); acpManagerMocks.cancelSession.mockReset().mockResolvedValue(undefined); }); diff --git a/src/auto-reply/reply/abort.ts b/src/auto-reply/reply/abort.ts index 327c2c74334..3d51db17ddb 100644 --- a/src/auto-reply/reply/abort.ts +++ b/src/auto-reply/reply/abort.ts @@ -47,6 +47,35 @@ export { setAbortMemory, }; +const defaultAbortDeps = { + getAcpSessionManager, + abortEmbeddedPiRun, + listSubagentRunsForController, + markSubagentRunTerminated, +}; + +const abortDeps = { + ...defaultAbortDeps, +}; + +export const __testing = { + setDepsForTests(deps: Partial | undefined): void { + abortDeps.getAcpSessionManager = + deps?.getAcpSessionManager ?? defaultAbortDeps.getAcpSessionManager; + abortDeps.abortEmbeddedPiRun = deps?.abortEmbeddedPiRun ?? defaultAbortDeps.abortEmbeddedPiRun; + abortDeps.listSubagentRunsForController = + deps?.listSubagentRunsForController ?? defaultAbortDeps.listSubagentRunsForController; + abortDeps.markSubagentRunTerminated = + deps?.markSubagentRunTerminated ?? defaultAbortDeps.markSubagentRunTerminated; + }, + resetDepsForTests(): void { + abortDeps.getAcpSessionManager = defaultAbortDeps.getAcpSessionManager; + abortDeps.abortEmbeddedPiRun = defaultAbortDeps.abortEmbeddedPiRun; + abortDeps.listSubagentRunsForController = defaultAbortDeps.listSubagentRunsForController; + abortDeps.markSubagentRunTerminated = defaultAbortDeps.markSubagentRunTerminated; + }, +}; + export function formatAbortReplyText(stoppedSubagents?: number): string { if (typeof stoppedSubagents !== "number" || stoppedSubagents <= 0) { return "⚙️ Agent was aborted."; @@ -107,7 +136,7 @@ export function stopSubagentsForRequester(params: { if (!requesterKey) { return { stopped: 0 }; } - const runs = listSubagentRunsForController(requesterKey); + const runs = abortDeps.listSubagentRunsForController(requesterKey); if (runs.length === 0) { return { stopped: 0 }; } @@ -134,9 +163,9 @@ export function stopSubagentsForRequester(params: { } const entry = store[childKey]; const sessionId = entry?.sessionId; - const aborted = sessionId ? abortEmbeddedPiRun(sessionId) : false; + const aborted = sessionId ? abortDeps.abortEmbeddedPiRun(sessionId) : false; const markedTerminated = - markSubagentRunTerminated({ + abortDeps.markSubagentRunTerminated({ runId: run.runId, childSessionKey: childKey, reason: "killed", @@ -198,7 +227,7 @@ export async function tryFastAbortFromMessage(params: { const store = loadSessionStore(storePath); const { entry, key, legacyKeys } = resolveSessionEntryForKey(store, targetKey); const resolvedTargetKey = key ?? targetKey; - const acpManager = getAcpSessionManager(); + const acpManager = abortDeps.getAcpSessionManager(); const acpResolution = acpManager.resolveSession({ cfg, sessionKey: resolvedTargetKey, @@ -217,7 +246,7 @@ export async function tryFastAbortFromMessage(params: { } } const sessionId = entry?.sessionId; - const aborted = sessionId ? abortEmbeddedPiRun(sessionId) : false; + const aborted = sessionId ? abortDeps.abortEmbeddedPiRun(sessionId) : false; const cleared = clearSessionQueues([resolvedTargetKey, sessionId]); if (cleared.followupCleared > 0 || cleared.laneCleared > 0) { logVerbose( diff --git a/src/auto-reply/reply/queue/cleanup.ts b/src/auto-reply/reply/queue/cleanup.ts index 77b623455bf..0a33337362b 100644 --- a/src/auto-reply/reply/queue/cleanup.ts +++ b/src/auto-reply/reply/queue/cleanup.ts @@ -9,6 +9,29 @@ export type ClearSessionQueueResult = { keys: string[]; }; +const defaultQueueCleanupDeps = { + resolveEmbeddedSessionLane, + clearCommandLane, +}; + +const queueCleanupDeps = { + ...defaultQueueCleanupDeps, +}; + +export const __testing = { + setDepsForTests(deps: Partial | undefined): void { + queueCleanupDeps.resolveEmbeddedSessionLane = + deps?.resolveEmbeddedSessionLane ?? defaultQueueCleanupDeps.resolveEmbeddedSessionLane; + queueCleanupDeps.clearCommandLane = + deps?.clearCommandLane ?? defaultQueueCleanupDeps.clearCommandLane; + }, + resetDepsForTests(): void { + queueCleanupDeps.resolveEmbeddedSessionLane = + defaultQueueCleanupDeps.resolveEmbeddedSessionLane; + queueCleanupDeps.clearCommandLane = defaultQueueCleanupDeps.clearCommandLane; + }, +}; + export function clearSessionQueues(keys: Array): ClearSessionQueueResult { const seen = new Set(); let followupCleared = 0; @@ -24,7 +47,9 @@ export function clearSessionQueues(keys: Array): ClearSessio clearedKeys.push(cleaned); followupCleared += clearFollowupQueue(cleaned); clearFollowupDrainCallback(cleaned); - laneCleared += clearCommandLane(resolveEmbeddedSessionLane(cleaned)); + laneCleared += queueCleanupDeps.clearCommandLane( + queueCleanupDeps.resolveEmbeddedSessionLane(cleaned), + ); } return { followupCleared, laneCleared, keys: clearedKeys }; diff --git a/src/gateway/call.test.ts b/src/gateway/call.test.ts index 1504ec35b9e..615463295c0 100644 --- a/src/gateway/call.test.ts +++ b/src/gateway/call.test.ts @@ -28,54 +28,42 @@ let startMode: StartMode = "hello"; let closeCode = 1006; let closeReason = ""; let helloMethods: string[] | undefined = ["health", "secrets.resolve"]; - -vi.mock("./client.js", () => ({ - describeGatewayCloseCode: (code: number) => { - if (code === 1000) { - return "normal closure"; - } - if (code === 1006) { - return "abnormal closure (no close frame)"; - } - return undefined; - }, - GatewayClient: class { - constructor(opts: { - url?: string; - token?: string; - password?: string; - scopes?: string[]; - onHelloOk?: (hello: { features?: { methods?: string[] } }) => void | Promise; - onClose?: (code: number, reason: string) => void; - }) { - lastClientOptions = opts; - } - async request( - method: string, - params: unknown, - opts?: { expectFinal?: boolean; timeoutMs?: number | null }, - ) { - lastRequestOptions = { method, params, opts }; - return { ok: true }; - } - start() { - if (startMode === "hello") { - void lastClientOptions?.onHelloOk?.({ - features: { - methods: helloMethods, - }, - }); - } else if (startMode === "close") { - lastClientOptions?.onClose?.(closeCode, closeReason); - } - } - stop() {} - }, -})); - -const { buildGatewayConnectionDetails, callGateway, callGatewayCli, callGatewayScoped } = +const { __testing, buildGatewayConnectionDetails, callGateway, callGatewayCli, callGatewayScoped } = await import("./call.js"); +class StubGatewayClient { + constructor(opts: { + url?: string; + token?: string; + password?: string; + scopes?: string[]; + onHelloOk?: (hello: { features?: { methods?: string[] } }) => void | Promise; + onClose?: (code: number, reason: string) => void; + }) { + lastClientOptions = opts; + } + async request( + method: string, + params: unknown, + opts?: { expectFinal?: boolean; timeoutMs?: number | null }, + ) { + lastRequestOptions = { method, params, opts }; + return { ok: true }; + } + start() { + if (startMode === "hello") { + void lastClientOptions?.onHelloOk?.({ + features: { + methods: helloMethods, + }, + }); + } else if (startMode === "close") { + lastClientOptions?.onClose?.(closeCode, closeReason); + } + } + stop() {} +} + function resetGatewayCallMocks() { loadConfig.mockClear(); resolveGatewayPort.mockClear(); @@ -87,6 +75,17 @@ function resetGatewayCallMocks() { closeCode = 1006; closeReason = ""; helloMethods = ["health", "secrets.resolve"]; + const loadConfigForTests = loadConfig as unknown as () => OpenClawConfig; + const resolveGatewayPortForTests = resolveGatewayPort as unknown as ( + cfg?: OpenClawConfig, + env?: NodeJS.ProcessEnv, + ) => number; + __testing.setDepsForTests({ + createGatewayClient: (opts) => + new StubGatewayClient(opts as ConstructorParameters[0]) as never, + loadConfig: loadConfigForTests, + resolveGatewayPort: resolveGatewayPortForTests, + }); } function setGatewayNetworkDefaults(port = 18789) { @@ -126,6 +125,7 @@ describe("callGateway url resolution", () => { afterEach(() => { envSnapshot.restore(); + __testing.resetDepsForTests(); }); it.each([ diff --git a/src/gateway/call.ts b/src/gateway/call.ts index 8e948c7cb30..bb52fa119a2 100644 --- a/src/gateway/call.ts +++ b/src/gateway/call.ts @@ -17,7 +17,7 @@ import { type GatewayClientName, } from "../utils/message-channel.js"; import { VERSION } from "../version.js"; -import { GatewayClient } from "./client.js"; +import { GatewayClient, type GatewayClientOptions } from "./client.js"; import { GatewaySecretRefUnavailableError, resolveGatewayCredentialsFromConfig, @@ -81,6 +81,47 @@ export type GatewayConnectionDetails = { message: string; }; +const defaultCreateGatewayClient = (opts: GatewayClientOptions) => new GatewayClient(opts); +const defaultGatewayCallDeps = { + createGatewayClient: defaultCreateGatewayClient, + loadConfig, + resolveGatewayPort, + resolveConfigPath, + resolveStateDir, + loadGatewayTlsRuntime, +}; +const gatewayCallDeps = { + ...defaultGatewayCallDeps, +}; + +export const __testing = { + setDepsForTests(deps: Partial | undefined): void { + gatewayCallDeps.createGatewayClient = + deps?.createGatewayClient ?? defaultGatewayCallDeps.createGatewayClient; + gatewayCallDeps.loadConfig = deps?.loadConfig ?? defaultGatewayCallDeps.loadConfig; + gatewayCallDeps.resolveGatewayPort = + deps?.resolveGatewayPort ?? defaultGatewayCallDeps.resolveGatewayPort; + gatewayCallDeps.resolveConfigPath = + deps?.resolveConfigPath ?? defaultGatewayCallDeps.resolveConfigPath; + gatewayCallDeps.resolveStateDir = + deps?.resolveStateDir ?? defaultGatewayCallDeps.resolveStateDir; + gatewayCallDeps.loadGatewayTlsRuntime = + deps?.loadGatewayTlsRuntime ?? defaultGatewayCallDeps.loadGatewayTlsRuntime; + }, + setCreateGatewayClientForTests(createGatewayClient?: typeof defaultCreateGatewayClient): void { + gatewayCallDeps.createGatewayClient = + createGatewayClient ?? defaultGatewayCallDeps.createGatewayClient; + }, + resetDepsForTests(): void { + gatewayCallDeps.createGatewayClient = defaultGatewayCallDeps.createGatewayClient; + gatewayCallDeps.loadConfig = defaultGatewayCallDeps.loadConfig; + gatewayCallDeps.resolveGatewayPort = defaultGatewayCallDeps.resolveGatewayPort; + gatewayCallDeps.resolveConfigPath = defaultGatewayCallDeps.resolveConfigPath; + gatewayCallDeps.resolveStateDir = defaultGatewayCallDeps.resolveStateDir; + gatewayCallDeps.loadGatewayTlsRuntime = defaultGatewayCallDeps.loadGatewayTlsRuntime; + }, +}; + function shouldAttachDeviceIdentityForGatewayCall(params: { url: string; token?: string; @@ -155,13 +196,14 @@ export function buildGatewayConnectionDetails( urlSource?: "cli" | "env"; } = {}, ): GatewayConnectionDetails { - const config = options.config ?? loadConfig(); + const config = options.config ?? gatewayCallDeps.loadConfig(); const configPath = - options.configPath ?? resolveConfigPath(process.env, resolveStateDir(process.env)); + options.configPath ?? + gatewayCallDeps.resolveConfigPath(process.env, gatewayCallDeps.resolveStateDir(process.env)); const isRemoteMode = config.gateway?.mode === "remote"; const remote = isRemoteMode ? config.gateway?.remote : undefined; const tlsEnabled = config.gateway?.tls?.enabled === true; - const localPort = resolveGatewayPort(config); + const localPort = gatewayCallDeps.resolveGatewayPort(config); const bindMode = config.gateway?.bind ?? "loopback"; const scheme = tlsEnabled ? "wss" : "ws"; // Self-connections should always target loopback; bind mode only controls listener exposure. @@ -273,9 +315,10 @@ function resolveGatewayCallTimeout(timeoutValue: unknown): { } function resolveGatewayCallContext(opts: CallGatewayBaseOptions): ResolvedGatewayCallContext { - const config = opts.config ?? loadConfig(); + const config = opts.config ?? gatewayCallDeps.loadConfig(); const configPath = - opts.configPath ?? resolveConfigPath(process.env, resolveStateDir(process.env)); + opts.configPath ?? + gatewayCallDeps.resolveConfigPath(process.env, gatewayCallDeps.resolveStateDir(process.env)); const isRemoteMode = config.gateway?.mode === "remote"; const remote = isRemoteMode ? (config.gateway?.remote as GatewayRemoteSettings | undefined) @@ -683,7 +726,10 @@ export async function resolveGatewayCredentialsWithSecretInputs(params: { : undefined; const context: ResolvedGatewayCallContext = { config: params.config, - configPath: resolveConfigPath(process.env, resolveStateDir(process.env)), + configPath: gatewayCallDeps.resolveConfigPath( + process.env, + gatewayCallDeps.resolveStateDir(process.env), + ), isRemoteMode, remote: remoteFromOverride ?? remoteFromConfig, urlOverride: trimToUndefined(params.urlOverride), @@ -715,7 +761,7 @@ async function resolveGatewayTlsFingerprint(params: { !context.remoteUrl && url.startsWith("wss://"); const tlsRuntime = useLocalTls - ? await loadGatewayTlsRuntime(context.config.gateway?.tls) + ? await gatewayCallDeps.loadGatewayTlsRuntime(context.config.gateway?.tls) : undefined; const overrideTlsFingerprint = trimToUndefined(opts.tlsFingerprint); const remoteTlsFingerprint = @@ -809,7 +855,7 @@ async function executeGatewayRequestWithScopes(params: { } }; - const client = new GatewayClient({ + const client = gatewayCallDeps.createGatewayClient({ url, token, password,