mirror of https://github.com/openclaw/openclaw.git
test: inject thread-safe base seams
This commit is contained in:
parent
8fd2fa13c6
commit
47db5abece
|
|
@ -1,9 +1,22 @@
|
|||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import type { Context, Model, SimpleStreamOptions } from "@mariozechner/pi-ai";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { __testing as extraParamsTesting } from "./pi-embedded-runner/extra-params.js";
|
||||
import {
|
||||
createOpenRouterSystemCacheWrapper,
|
||||
createOpenRouterWrapper,
|
||||
isProxyReasoningUnsupported,
|
||||
} from "./pi-embedded-runner/proxy-stream-wrappers.js";
|
||||
import type { ProviderCapabilities } from "./provider-capabilities.js";
|
||||
import { __testing as providerCapabilitiesTesting } from "./provider-capabilities.js";
|
||||
|
||||
const resolveProviderCapabilitiesWithPluginMock = vi.fn(
|
||||
(params: { provider: string; workspaceDir?: string }) => {
|
||||
(params: {
|
||||
provider: string;
|
||||
config?: import("../config/config.js").OpenClawConfig;
|
||||
workspaceDir?: string;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
}): Partial<ProviderCapabilities> | undefined => {
|
||||
if (
|
||||
params.provider === "workspace-anthropic-proxy" &&
|
||||
params.workspaceDir === "/tmp/workspace-capabilities"
|
||||
|
|
@ -17,20 +30,12 @@ const resolveProviderCapabilitiesWithPluginMock = vi.fn(
|
|||
},
|
||||
);
|
||||
|
||||
vi.mock("../plugins/provider-runtime.js", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("../plugins/provider-runtime.js")>();
|
||||
const {
|
||||
createOpenRouterSystemCacheWrapper,
|
||||
createOpenRouterWrapper,
|
||||
isProxyReasoningUnsupported,
|
||||
} = await import("./pi-embedded-runner/proxy-stream-wrappers.js");
|
||||
import { applyExtraParamsToAgent, resolveExtraParams } from "./pi-embedded-runner.js";
|
||||
import { log } from "./pi-embedded-runner/logger.js";
|
||||
|
||||
return {
|
||||
...actual,
|
||||
prepareProviderExtraParams: (params: {
|
||||
provider: string;
|
||||
context: { extraParams?: Record<string, unknown> };
|
||||
}) => {
|
||||
beforeEach(() => {
|
||||
extraParamsTesting.setProviderRuntimeDepsForTest({
|
||||
prepareProviderExtraParams: (params) => {
|
||||
if (params.provider !== "openai-codex") {
|
||||
return undefined;
|
||||
}
|
||||
|
|
@ -43,15 +48,7 @@ vi.mock("../plugins/provider-runtime.js", async (importOriginal) => {
|
|||
transport: "auto",
|
||||
};
|
||||
},
|
||||
wrapProviderStreamFn: (params: {
|
||||
provider: string;
|
||||
context: {
|
||||
modelId: string;
|
||||
thinkingLevel?: import("../auto-reply/thinking.js").ThinkLevel;
|
||||
extraParams?: Record<string, unknown>;
|
||||
streamFn?: StreamFn;
|
||||
};
|
||||
}) => {
|
||||
wrapProviderStreamFn: (params) => {
|
||||
if (params.provider !== "openrouter") {
|
||||
return params.context.streamFn;
|
||||
}
|
||||
|
|
@ -80,13 +77,17 @@ vi.mock("../plugins/provider-runtime.js", async (importOriginal) => {
|
|||
const thinkingLevel = skipReasoningInjection ? undefined : params.context.thinkingLevel;
|
||||
return createOpenRouterSystemCacheWrapper(createOpenRouterWrapper(streamFn, thinkingLevel));
|
||||
},
|
||||
resolveProviderCapabilitiesWithPlugin: (params: { provider: string; workspaceDir?: string }) =>
|
||||
resolveProviderCapabilitiesWithPluginMock(params),
|
||||
};
|
||||
});
|
||||
providerCapabilitiesTesting.setResolveProviderCapabilitiesWithPluginForTest(
|
||||
resolveProviderCapabilitiesWithPluginMock,
|
||||
);
|
||||
resolveProviderCapabilitiesWithPluginMock.mockClear();
|
||||
});
|
||||
|
||||
import { applyExtraParamsToAgent, resolveExtraParams } from "./pi-embedded-runner.js";
|
||||
import { log } from "./pi-embedded-runner/logger.js";
|
||||
afterEach(() => {
|
||||
extraParamsTesting.resetProviderRuntimeDepsForTest();
|
||||
providerCapabilitiesTesting.resetDepsForTests();
|
||||
});
|
||||
|
||||
describe("resolveExtraParams", () => {
|
||||
it("returns undefined with no model config", () => {
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ import { streamSimple } from "@mariozechner/pi-ai";
|
|||
import type { ThinkLevel } from "../../auto-reply/thinking.js";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import {
|
||||
prepareProviderExtraParams,
|
||||
wrapProviderStreamFn,
|
||||
prepareProviderExtraParams as prepareProviderExtraParamsRuntime,
|
||||
wrapProviderStreamFn as wrapProviderStreamFnRuntime,
|
||||
} from "../../plugins/provider-runtime.js";
|
||||
import {
|
||||
createAnthropicBetaHeadersWrapper,
|
||||
|
|
@ -38,6 +38,31 @@ import {
|
|||
} from "./openai-stream-wrappers.js";
|
||||
import { createXaiFastModeWrapper } from "./xai-stream-wrappers.js";
|
||||
|
||||
const defaultProviderRuntimeDeps = {
|
||||
prepareProviderExtraParams: prepareProviderExtraParamsRuntime,
|
||||
wrapProviderStreamFn: wrapProviderStreamFnRuntime,
|
||||
};
|
||||
|
||||
const providerRuntimeDeps = {
|
||||
...defaultProviderRuntimeDeps,
|
||||
};
|
||||
|
||||
export const __testing = {
|
||||
setProviderRuntimeDepsForTest(
|
||||
deps: Partial<typeof defaultProviderRuntimeDeps> | undefined,
|
||||
): void {
|
||||
providerRuntimeDeps.prepareProviderExtraParams =
|
||||
deps?.prepareProviderExtraParams ?? defaultProviderRuntimeDeps.prepareProviderExtraParams;
|
||||
providerRuntimeDeps.wrapProviderStreamFn =
|
||||
deps?.wrapProviderStreamFn ?? defaultProviderRuntimeDeps.wrapProviderStreamFn;
|
||||
},
|
||||
resetProviderRuntimeDepsForTest(): void {
|
||||
providerRuntimeDeps.prepareProviderExtraParams =
|
||||
defaultProviderRuntimeDeps.prepareProviderExtraParams;
|
||||
providerRuntimeDeps.wrapProviderStreamFn = defaultProviderRuntimeDeps.wrapProviderStreamFn;
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Resolve provider-specific extra params from model config.
|
||||
* Used to pass through stream params like temperature/maxTokens.
|
||||
|
|
@ -206,7 +231,7 @@ export function applyExtraParamsToAgent(
|
|||
: undefined;
|
||||
const merged = Object.assign({}, resolvedExtraParams, override);
|
||||
const effectiveExtraParams =
|
||||
prepareProviderExtraParams({
|
||||
providerRuntimeDeps.prepareProviderExtraParams({
|
||||
provider,
|
||||
config: cfg,
|
||||
context: {
|
||||
|
|
@ -257,7 +282,7 @@ export function applyExtraParamsToAgent(
|
|||
workspaceDir,
|
||||
});
|
||||
const providerStreamBase = agent.streamFn;
|
||||
const pluginWrappedStreamFn = wrapProviderStreamFn({
|
||||
const pluginWrappedStreamFn = providerRuntimeDeps.wrapProviderStreamFn({
|
||||
provider,
|
||||
config: cfg,
|
||||
context: {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import { resolveProviderCapabilitiesWithPlugin } from "../plugins/provider-runtime.js";
|
||||
import { resolveProviderCapabilitiesWithPlugin as resolveProviderCapabilitiesWithPluginRuntime } from "../plugins/provider-runtime.js";
|
||||
import { normalizeProviderId } from "./model-selection.js";
|
||||
|
||||
export type ProviderCapabilities = {
|
||||
|
|
@ -82,13 +82,31 @@ const PLUGIN_CAPABILITIES_FALLBACKS: Record<string, Partial<ProviderCapabilities
|
|||
},
|
||||
};
|
||||
|
||||
const defaultResolveProviderCapabilitiesWithPlugin = resolveProviderCapabilitiesWithPluginRuntime;
|
||||
const providerCapabilityDeps = {
|
||||
resolveProviderCapabilitiesWithPlugin: defaultResolveProviderCapabilitiesWithPlugin,
|
||||
};
|
||||
|
||||
export const __testing = {
|
||||
setResolveProviderCapabilitiesWithPluginForTest(
|
||||
resolveProviderCapabilitiesWithPlugin?: typeof defaultResolveProviderCapabilitiesWithPlugin,
|
||||
): void {
|
||||
providerCapabilityDeps.resolveProviderCapabilitiesWithPlugin =
|
||||
resolveProviderCapabilitiesWithPlugin ?? defaultResolveProviderCapabilitiesWithPlugin;
|
||||
},
|
||||
resetDepsForTests(): void {
|
||||
providerCapabilityDeps.resolveProviderCapabilitiesWithPlugin =
|
||||
defaultResolveProviderCapabilitiesWithPlugin;
|
||||
},
|
||||
};
|
||||
|
||||
export function resolveProviderCapabilities(
|
||||
provider?: string | null,
|
||||
options?: ProviderCapabilityLookupOptions,
|
||||
): ProviderCapabilities {
|
||||
const normalized = normalizeProviderId(provider ?? "");
|
||||
const pluginCapabilities = normalized
|
||||
? resolveProviderCapabilitiesWithPlugin({
|
||||
? providerCapabilityDeps.resolveProviderCapabilitiesWithPlugin({
|
||||
provider: normalized,
|
||||
config: options?.config,
|
||||
workspaceDir: options?.workspaceDir,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { SubagentRunRecord } from "../../agents/subagent-registry.js";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import {
|
||||
__testing as abortTesting,
|
||||
getAbortMemory,
|
||||
getAbortMemorySizeForTest,
|
||||
isAbortRequestText,
|
||||
|
|
@ -17,6 +18,7 @@ import {
|
|||
tryFastAbortFromMessage,
|
||||
} from "./abort.js";
|
||||
import { enqueueFollowupRun, getFollowupQueueDepth, type FollowupRun } from "./queue.js";
|
||||
import { __testing as queueCleanupTesting } from "./queue/cleanup.js";
|
||||
import { initSessionState } from "./session.js";
|
||||
import { buildTestCtx } from "./test-ctx.js";
|
||||
|
||||
|
|
@ -26,7 +28,7 @@ vi.mock("../../agents/pi-embedded.js", () => ({
|
|||
}));
|
||||
|
||||
const commandQueueMocks = vi.hoisted(() => ({
|
||||
clearCommandLane: vi.fn(),
|
||||
clearCommandLane: vi.fn(() => 1),
|
||||
}));
|
||||
|
||||
vi.mock("../../process/command-queue.js", () => commandQueueMocks);
|
||||
|
|
@ -162,8 +164,29 @@ describe("abort detection", () => {
|
|||
expect(commandQueueMocks.clearCommandLane).toHaveBeenCalledWith(`session:${sessionKey}`);
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
abortTesting.setDepsForTests({
|
||||
getAcpSessionManager: (() =>
|
||||
({
|
||||
resolveSession: acpManagerMocks.resolveSession,
|
||||
cancelSession: acpManagerMocks.cancelSession,
|
||||
}) as never) as never,
|
||||
abortEmbeddedPiRun: () => true,
|
||||
listSubagentRunsForController: subagentRegistryMocks.listSubagentRunsForRequester,
|
||||
markSubagentRunTerminated: subagentRegistryMocks.markSubagentRunTerminated,
|
||||
});
|
||||
queueCleanupTesting.setDepsForTests({
|
||||
resolveEmbeddedSessionLane: (key) => `session:${key.trim() || "main"}`,
|
||||
clearCommandLane: commandQueueMocks.clearCommandLane,
|
||||
});
|
||||
commandQueueMocks.clearCommandLane.mockClear().mockReturnValue(1);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
resetAbortMemoryForTest();
|
||||
abortTesting.resetDepsForTests();
|
||||
queueCleanupTesting.resetDepsForTests();
|
||||
commandQueueMocks.clearCommandLane.mockClear().mockReturnValue(1);
|
||||
acpManagerMocks.resolveSession.mockReset().mockReturnValue({ kind: "none" });
|
||||
acpManagerMocks.cancelSession.mockReset().mockResolvedValue(undefined);
|
||||
});
|
||||
|
|
|
|||
|
|
@ -47,6 +47,35 @@ export {
|
|||
setAbortMemory,
|
||||
};
|
||||
|
||||
const defaultAbortDeps = {
|
||||
getAcpSessionManager,
|
||||
abortEmbeddedPiRun,
|
||||
listSubagentRunsForController,
|
||||
markSubagentRunTerminated,
|
||||
};
|
||||
|
||||
const abortDeps = {
|
||||
...defaultAbortDeps,
|
||||
};
|
||||
|
||||
export const __testing = {
|
||||
setDepsForTests(deps: Partial<typeof defaultAbortDeps> | undefined): void {
|
||||
abortDeps.getAcpSessionManager =
|
||||
deps?.getAcpSessionManager ?? defaultAbortDeps.getAcpSessionManager;
|
||||
abortDeps.abortEmbeddedPiRun = deps?.abortEmbeddedPiRun ?? defaultAbortDeps.abortEmbeddedPiRun;
|
||||
abortDeps.listSubagentRunsForController =
|
||||
deps?.listSubagentRunsForController ?? defaultAbortDeps.listSubagentRunsForController;
|
||||
abortDeps.markSubagentRunTerminated =
|
||||
deps?.markSubagentRunTerminated ?? defaultAbortDeps.markSubagentRunTerminated;
|
||||
},
|
||||
resetDepsForTests(): void {
|
||||
abortDeps.getAcpSessionManager = defaultAbortDeps.getAcpSessionManager;
|
||||
abortDeps.abortEmbeddedPiRun = defaultAbortDeps.abortEmbeddedPiRun;
|
||||
abortDeps.listSubagentRunsForController = defaultAbortDeps.listSubagentRunsForController;
|
||||
abortDeps.markSubagentRunTerminated = defaultAbortDeps.markSubagentRunTerminated;
|
||||
},
|
||||
};
|
||||
|
||||
export function formatAbortReplyText(stoppedSubagents?: number): string {
|
||||
if (typeof stoppedSubagents !== "number" || stoppedSubagents <= 0) {
|
||||
return "⚙️ Agent was aborted.";
|
||||
|
|
@ -107,7 +136,7 @@ export function stopSubagentsForRequester(params: {
|
|||
if (!requesterKey) {
|
||||
return { stopped: 0 };
|
||||
}
|
||||
const runs = listSubagentRunsForController(requesterKey);
|
||||
const runs = abortDeps.listSubagentRunsForController(requesterKey);
|
||||
if (runs.length === 0) {
|
||||
return { stopped: 0 };
|
||||
}
|
||||
|
|
@ -134,9 +163,9 @@ export function stopSubagentsForRequester(params: {
|
|||
}
|
||||
const entry = store[childKey];
|
||||
const sessionId = entry?.sessionId;
|
||||
const aborted = sessionId ? abortEmbeddedPiRun(sessionId) : false;
|
||||
const aborted = sessionId ? abortDeps.abortEmbeddedPiRun(sessionId) : false;
|
||||
const markedTerminated =
|
||||
markSubagentRunTerminated({
|
||||
abortDeps.markSubagentRunTerminated({
|
||||
runId: run.runId,
|
||||
childSessionKey: childKey,
|
||||
reason: "killed",
|
||||
|
|
@ -198,7 +227,7 @@ export async function tryFastAbortFromMessage(params: {
|
|||
const store = loadSessionStore(storePath);
|
||||
const { entry, key, legacyKeys } = resolveSessionEntryForKey(store, targetKey);
|
||||
const resolvedTargetKey = key ?? targetKey;
|
||||
const acpManager = getAcpSessionManager();
|
||||
const acpManager = abortDeps.getAcpSessionManager();
|
||||
const acpResolution = acpManager.resolveSession({
|
||||
cfg,
|
||||
sessionKey: resolvedTargetKey,
|
||||
|
|
@ -217,7 +246,7 @@ export async function tryFastAbortFromMessage(params: {
|
|||
}
|
||||
}
|
||||
const sessionId = entry?.sessionId;
|
||||
const aborted = sessionId ? abortEmbeddedPiRun(sessionId) : false;
|
||||
const aborted = sessionId ? abortDeps.abortEmbeddedPiRun(sessionId) : false;
|
||||
const cleared = clearSessionQueues([resolvedTargetKey, sessionId]);
|
||||
if (cleared.followupCleared > 0 || cleared.laneCleared > 0) {
|
||||
logVerbose(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,29 @@ export type ClearSessionQueueResult = {
|
|||
keys: string[];
|
||||
};
|
||||
|
||||
const defaultQueueCleanupDeps = {
|
||||
resolveEmbeddedSessionLane,
|
||||
clearCommandLane,
|
||||
};
|
||||
|
||||
const queueCleanupDeps = {
|
||||
...defaultQueueCleanupDeps,
|
||||
};
|
||||
|
||||
export const __testing = {
|
||||
setDepsForTests(deps: Partial<typeof defaultQueueCleanupDeps> | undefined): void {
|
||||
queueCleanupDeps.resolveEmbeddedSessionLane =
|
||||
deps?.resolveEmbeddedSessionLane ?? defaultQueueCleanupDeps.resolveEmbeddedSessionLane;
|
||||
queueCleanupDeps.clearCommandLane =
|
||||
deps?.clearCommandLane ?? defaultQueueCleanupDeps.clearCommandLane;
|
||||
},
|
||||
resetDepsForTests(): void {
|
||||
queueCleanupDeps.resolveEmbeddedSessionLane =
|
||||
defaultQueueCleanupDeps.resolveEmbeddedSessionLane;
|
||||
queueCleanupDeps.clearCommandLane = defaultQueueCleanupDeps.clearCommandLane;
|
||||
},
|
||||
};
|
||||
|
||||
export function clearSessionQueues(keys: Array<string | undefined>): ClearSessionQueueResult {
|
||||
const seen = new Set<string>();
|
||||
let followupCleared = 0;
|
||||
|
|
@ -24,7 +47,9 @@ export function clearSessionQueues(keys: Array<string | undefined>): ClearSessio
|
|||
clearedKeys.push(cleaned);
|
||||
followupCleared += clearFollowupQueue(cleaned);
|
||||
clearFollowupDrainCallback(cleaned);
|
||||
laneCleared += clearCommandLane(resolveEmbeddedSessionLane(cleaned));
|
||||
laneCleared += queueCleanupDeps.clearCommandLane(
|
||||
queueCleanupDeps.resolveEmbeddedSessionLane(cleaned),
|
||||
);
|
||||
}
|
||||
|
||||
return { followupCleared, laneCleared, keys: clearedKeys };
|
||||
|
|
|
|||
|
|
@ -28,54 +28,42 @@ let startMode: StartMode = "hello";
|
|||
let closeCode = 1006;
|
||||
let closeReason = "";
|
||||
let helloMethods: string[] | undefined = ["health", "secrets.resolve"];
|
||||
|
||||
vi.mock("./client.js", () => ({
|
||||
describeGatewayCloseCode: (code: number) => {
|
||||
if (code === 1000) {
|
||||
return "normal closure";
|
||||
}
|
||||
if (code === 1006) {
|
||||
return "abnormal closure (no close frame)";
|
||||
}
|
||||
return undefined;
|
||||
},
|
||||
GatewayClient: class {
|
||||
constructor(opts: {
|
||||
url?: string;
|
||||
token?: string;
|
||||
password?: string;
|
||||
scopes?: string[];
|
||||
onHelloOk?: (hello: { features?: { methods?: string[] } }) => void | Promise<void>;
|
||||
onClose?: (code: number, reason: string) => void;
|
||||
}) {
|
||||
lastClientOptions = opts;
|
||||
}
|
||||
async request(
|
||||
method: string,
|
||||
params: unknown,
|
||||
opts?: { expectFinal?: boolean; timeoutMs?: number | null },
|
||||
) {
|
||||
lastRequestOptions = { method, params, opts };
|
||||
return { ok: true };
|
||||
}
|
||||
start() {
|
||||
if (startMode === "hello") {
|
||||
void lastClientOptions?.onHelloOk?.({
|
||||
features: {
|
||||
methods: helloMethods,
|
||||
},
|
||||
});
|
||||
} else if (startMode === "close") {
|
||||
lastClientOptions?.onClose?.(closeCode, closeReason);
|
||||
}
|
||||
}
|
||||
stop() {}
|
||||
},
|
||||
}));
|
||||
|
||||
const { buildGatewayConnectionDetails, callGateway, callGatewayCli, callGatewayScoped } =
|
||||
const { __testing, buildGatewayConnectionDetails, callGateway, callGatewayCli, callGatewayScoped } =
|
||||
await import("./call.js");
|
||||
|
||||
class StubGatewayClient {
|
||||
constructor(opts: {
|
||||
url?: string;
|
||||
token?: string;
|
||||
password?: string;
|
||||
scopes?: string[];
|
||||
onHelloOk?: (hello: { features?: { methods?: string[] } }) => void | Promise<void>;
|
||||
onClose?: (code: number, reason: string) => void;
|
||||
}) {
|
||||
lastClientOptions = opts;
|
||||
}
|
||||
async request(
|
||||
method: string,
|
||||
params: unknown,
|
||||
opts?: { expectFinal?: boolean; timeoutMs?: number | null },
|
||||
) {
|
||||
lastRequestOptions = { method, params, opts };
|
||||
return { ok: true };
|
||||
}
|
||||
start() {
|
||||
if (startMode === "hello") {
|
||||
void lastClientOptions?.onHelloOk?.({
|
||||
features: {
|
||||
methods: helloMethods,
|
||||
},
|
||||
});
|
||||
} else if (startMode === "close") {
|
||||
lastClientOptions?.onClose?.(closeCode, closeReason);
|
||||
}
|
||||
}
|
||||
stop() {}
|
||||
}
|
||||
|
||||
function resetGatewayCallMocks() {
|
||||
loadConfig.mockClear();
|
||||
resolveGatewayPort.mockClear();
|
||||
|
|
@ -87,6 +75,17 @@ function resetGatewayCallMocks() {
|
|||
closeCode = 1006;
|
||||
closeReason = "";
|
||||
helloMethods = ["health", "secrets.resolve"];
|
||||
const loadConfigForTests = loadConfig as unknown as () => OpenClawConfig;
|
||||
const resolveGatewayPortForTests = resolveGatewayPort as unknown as (
|
||||
cfg?: OpenClawConfig,
|
||||
env?: NodeJS.ProcessEnv,
|
||||
) => number;
|
||||
__testing.setDepsForTests({
|
||||
createGatewayClient: (opts) =>
|
||||
new StubGatewayClient(opts as ConstructorParameters<typeof StubGatewayClient>[0]) as never,
|
||||
loadConfig: loadConfigForTests,
|
||||
resolveGatewayPort: resolveGatewayPortForTests,
|
||||
});
|
||||
}
|
||||
|
||||
function setGatewayNetworkDefaults(port = 18789) {
|
||||
|
|
@ -126,6 +125,7 @@ describe("callGateway url resolution", () => {
|
|||
|
||||
afterEach(() => {
|
||||
envSnapshot.restore();
|
||||
__testing.resetDepsForTests();
|
||||
});
|
||||
|
||||
it.each([
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import {
|
|||
type GatewayClientName,
|
||||
} from "../utils/message-channel.js";
|
||||
import { VERSION } from "../version.js";
|
||||
import { GatewayClient } from "./client.js";
|
||||
import { GatewayClient, type GatewayClientOptions } from "./client.js";
|
||||
import {
|
||||
GatewaySecretRefUnavailableError,
|
||||
resolveGatewayCredentialsFromConfig,
|
||||
|
|
@ -81,6 +81,47 @@ export type GatewayConnectionDetails = {
|
|||
message: string;
|
||||
};
|
||||
|
||||
const defaultCreateGatewayClient = (opts: GatewayClientOptions) => new GatewayClient(opts);
|
||||
const defaultGatewayCallDeps = {
|
||||
createGatewayClient: defaultCreateGatewayClient,
|
||||
loadConfig,
|
||||
resolveGatewayPort,
|
||||
resolveConfigPath,
|
||||
resolveStateDir,
|
||||
loadGatewayTlsRuntime,
|
||||
};
|
||||
const gatewayCallDeps = {
|
||||
...defaultGatewayCallDeps,
|
||||
};
|
||||
|
||||
export const __testing = {
|
||||
setDepsForTests(deps: Partial<typeof defaultGatewayCallDeps> | undefined): void {
|
||||
gatewayCallDeps.createGatewayClient =
|
||||
deps?.createGatewayClient ?? defaultGatewayCallDeps.createGatewayClient;
|
||||
gatewayCallDeps.loadConfig = deps?.loadConfig ?? defaultGatewayCallDeps.loadConfig;
|
||||
gatewayCallDeps.resolveGatewayPort =
|
||||
deps?.resolveGatewayPort ?? defaultGatewayCallDeps.resolveGatewayPort;
|
||||
gatewayCallDeps.resolveConfigPath =
|
||||
deps?.resolveConfigPath ?? defaultGatewayCallDeps.resolveConfigPath;
|
||||
gatewayCallDeps.resolveStateDir =
|
||||
deps?.resolveStateDir ?? defaultGatewayCallDeps.resolveStateDir;
|
||||
gatewayCallDeps.loadGatewayTlsRuntime =
|
||||
deps?.loadGatewayTlsRuntime ?? defaultGatewayCallDeps.loadGatewayTlsRuntime;
|
||||
},
|
||||
setCreateGatewayClientForTests(createGatewayClient?: typeof defaultCreateGatewayClient): void {
|
||||
gatewayCallDeps.createGatewayClient =
|
||||
createGatewayClient ?? defaultGatewayCallDeps.createGatewayClient;
|
||||
},
|
||||
resetDepsForTests(): void {
|
||||
gatewayCallDeps.createGatewayClient = defaultGatewayCallDeps.createGatewayClient;
|
||||
gatewayCallDeps.loadConfig = defaultGatewayCallDeps.loadConfig;
|
||||
gatewayCallDeps.resolveGatewayPort = defaultGatewayCallDeps.resolveGatewayPort;
|
||||
gatewayCallDeps.resolveConfigPath = defaultGatewayCallDeps.resolveConfigPath;
|
||||
gatewayCallDeps.resolveStateDir = defaultGatewayCallDeps.resolveStateDir;
|
||||
gatewayCallDeps.loadGatewayTlsRuntime = defaultGatewayCallDeps.loadGatewayTlsRuntime;
|
||||
},
|
||||
};
|
||||
|
||||
function shouldAttachDeviceIdentityForGatewayCall(params: {
|
||||
url: string;
|
||||
token?: string;
|
||||
|
|
@ -155,13 +196,14 @@ export function buildGatewayConnectionDetails(
|
|||
urlSource?: "cli" | "env";
|
||||
} = {},
|
||||
): GatewayConnectionDetails {
|
||||
const config = options.config ?? loadConfig();
|
||||
const config = options.config ?? gatewayCallDeps.loadConfig();
|
||||
const configPath =
|
||||
options.configPath ?? resolveConfigPath(process.env, resolveStateDir(process.env));
|
||||
options.configPath ??
|
||||
gatewayCallDeps.resolveConfigPath(process.env, gatewayCallDeps.resolveStateDir(process.env));
|
||||
const isRemoteMode = config.gateway?.mode === "remote";
|
||||
const remote = isRemoteMode ? config.gateway?.remote : undefined;
|
||||
const tlsEnabled = config.gateway?.tls?.enabled === true;
|
||||
const localPort = resolveGatewayPort(config);
|
||||
const localPort = gatewayCallDeps.resolveGatewayPort(config);
|
||||
const bindMode = config.gateway?.bind ?? "loopback";
|
||||
const scheme = tlsEnabled ? "wss" : "ws";
|
||||
// Self-connections should always target loopback; bind mode only controls listener exposure.
|
||||
|
|
@ -273,9 +315,10 @@ function resolveGatewayCallTimeout(timeoutValue: unknown): {
|
|||
}
|
||||
|
||||
function resolveGatewayCallContext(opts: CallGatewayBaseOptions): ResolvedGatewayCallContext {
|
||||
const config = opts.config ?? loadConfig();
|
||||
const config = opts.config ?? gatewayCallDeps.loadConfig();
|
||||
const configPath =
|
||||
opts.configPath ?? resolveConfigPath(process.env, resolveStateDir(process.env));
|
||||
opts.configPath ??
|
||||
gatewayCallDeps.resolveConfigPath(process.env, gatewayCallDeps.resolveStateDir(process.env));
|
||||
const isRemoteMode = config.gateway?.mode === "remote";
|
||||
const remote = isRemoteMode
|
||||
? (config.gateway?.remote as GatewayRemoteSettings | undefined)
|
||||
|
|
@ -683,7 +726,10 @@ export async function resolveGatewayCredentialsWithSecretInputs(params: {
|
|||
: undefined;
|
||||
const context: ResolvedGatewayCallContext = {
|
||||
config: params.config,
|
||||
configPath: resolveConfigPath(process.env, resolveStateDir(process.env)),
|
||||
configPath: gatewayCallDeps.resolveConfigPath(
|
||||
process.env,
|
||||
gatewayCallDeps.resolveStateDir(process.env),
|
||||
),
|
||||
isRemoteMode,
|
||||
remote: remoteFromOverride ?? remoteFromConfig,
|
||||
urlOverride: trimToUndefined(params.urlOverride),
|
||||
|
|
@ -715,7 +761,7 @@ async function resolveGatewayTlsFingerprint(params: {
|
|||
!context.remoteUrl &&
|
||||
url.startsWith("wss://");
|
||||
const tlsRuntime = useLocalTls
|
||||
? await loadGatewayTlsRuntime(context.config.gateway?.tls)
|
||||
? await gatewayCallDeps.loadGatewayTlsRuntime(context.config.gateway?.tls)
|
||||
: undefined;
|
||||
const overrideTlsFingerprint = trimToUndefined(opts.tlsFingerprint);
|
||||
const remoteTlsFingerprint =
|
||||
|
|
@ -809,7 +855,7 @@ async function executeGatewayRequestWithScopes<T>(params: {
|
|||
}
|
||||
};
|
||||
|
||||
const client = new GatewayClient({
|
||||
const client = gatewayCallDeps.createGatewayClient({
|
||||
url,
|
||||
token,
|
||||
password,
|
||||
|
|
|
|||
Loading…
Reference in New Issue