diff --git a/src/gateway/server-methods/devices.test.ts b/src/gateway/server-methods/devices.test.ts new file mode 100644 index 00000000000..1d25a9cfd91 --- /dev/null +++ b/src/gateway/server-methods/devices.test.ts @@ -0,0 +1,120 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { deviceHandlers } from "./devices.js"; +import type { GatewayRequestHandlerOptions } from "./types.js"; + +const { removePairedDeviceMock, revokeDeviceTokenMock } = vi.hoisted(() => ({ + removePairedDeviceMock: vi.fn(), + revokeDeviceTokenMock: vi.fn(), +})); + +vi.mock("../../infra/device-pairing.js", async () => { + const actual = await vi.importActual( + "../../infra/device-pairing.js", + ); + return { + ...actual, + removePairedDevice: removePairedDeviceMock, + revokeDeviceToken: revokeDeviceTokenMock, + }; +}); + +function createOptions( + method: string, + params: Record, + overrides?: Partial, +): GatewayRequestHandlerOptions { + return { + req: { type: "req", id: "req-1", method, params }, + params, + client: null, + isWebchatConnect: () => false, + respond: vi.fn(), + context: { + disconnectClientsForDevice: vi.fn(), + logGateway: { + debug: vi.fn(), + error: vi.fn(), + info: vi.fn(), + warn: vi.fn(), + }, + }, + ...overrides, + } as unknown as GatewayRequestHandlerOptions; +} + +describe("deviceHandlers", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("disconnects active clients after removing a paired device", async () => { + removePairedDeviceMock.mockResolvedValue({ deviceId: "device-1", removedAtMs: 123 }); + const opts = createOptions("device.pair.remove", { deviceId: " device-1 " }); + + await deviceHandlers["device.pair.remove"](opts); + await Promise.resolve(); + + expect(removePairedDeviceMock).toHaveBeenCalledWith(" device-1 "); + expect(opts.context.disconnectClientsForDevice).toHaveBeenCalledWith("device-1"); + expect(opts.respond).toHaveBeenCalledWith( + true, + { deviceId: "device-1", removedAtMs: 123 }, + undefined, + ); + }); + + it("does not disconnect clients when device removal fails", async () => { + removePairedDeviceMock.mockResolvedValue(null); + const opts = createOptions("device.pair.remove", { deviceId: "device-1" }); + + await deviceHandlers["device.pair.remove"](opts); + + expect(opts.context.disconnectClientsForDevice).not.toHaveBeenCalled(); + expect(opts.respond).toHaveBeenCalledWith( + false, + undefined, + expect.objectContaining({ message: "unknown deviceId" }), + ); + }); + + it("disconnects active clients after revoking a device token", async () => { + revokeDeviceTokenMock.mockResolvedValue({ role: "operator", revokedAtMs: 456 }); + const opts = createOptions("device.token.revoke", { + deviceId: " device-1 ", + role: " operator ", + }); + + await deviceHandlers["device.token.revoke"](opts); + await Promise.resolve(); + + expect(revokeDeviceTokenMock).toHaveBeenCalledWith({ + deviceId: " device-1 ", + role: " operator ", + }); + expect(opts.context.disconnectClientsForDevice).toHaveBeenCalledWith("device-1", { + role: "operator", + }); + expect(opts.respond).toHaveBeenCalledWith( + true, + { deviceId: "device-1", role: "operator", revokedAtMs: 456 }, + undefined, + ); + }); + + it("does not disconnect clients when token revocation fails", async () => { + revokeDeviceTokenMock.mockResolvedValue(null); + const opts = createOptions("device.token.revoke", { + deviceId: "device-1", + role: "operator", + }); + + await deviceHandlers["device.token.revoke"](opts); + + expect(opts.context.disconnectClientsForDevice).not.toHaveBeenCalled(); + expect(opts.respond).toHaveBeenCalledWith( + false, + undefined, + expect.objectContaining({ message: "unknown deviceId/role" }), + ); + }); +}); diff --git a/src/gateway/server-methods/devices.ts b/src/gateway/server-methods/devices.ts index 3917f49d301..c27ce75311e 100644 --- a/src/gateway/server-methods/devices.ts +++ b/src/gateway/server-methods/devices.ts @@ -173,6 +173,9 @@ export const deviceHandlers: GatewayRequestHandlers = { } context.logGateway.info(`device pairing removed device=${removed.deviceId}`); respond(true, removed, undefined); + queueMicrotask(() => { + context.disconnectClientsForDevice?.(removed.deviceId); + }); }, "device.token.rotate": async ({ params, respond, context, client }) => { if (!validateDeviceTokenRotateParams(params)) { @@ -283,11 +286,19 @@ export const deviceHandlers: GatewayRequestHandlers = { respond(false, undefined, errorShape(ErrorCodes.INVALID_REQUEST, "unknown deviceId/role")); return; } - context.logGateway.info(`device token revoked device=${deviceId} role=${entry.role}`); + const normalizedDeviceId = deviceId.trim(); + context.logGateway.info(`device token revoked device=${normalizedDeviceId} role=${entry.role}`); respond( true, - { deviceId, role: entry.role, revokedAtMs: entry.revokedAtMs ?? Date.now() }, + { + deviceId: normalizedDeviceId, + role: entry.role, + revokedAtMs: entry.revokedAtMs ?? Date.now(), + }, undefined, ); + queueMicrotask(() => { + context.disconnectClientsForDevice?.(normalizedDeviceId, { role: entry.role }); + }); }, }; diff --git a/src/gateway/server-methods/types.ts b/src/gateway/server-methods/types.ts index ed778c52b3d..7028cea1904 100644 --- a/src/gateway/server-methods/types.ts +++ b/src/gateway/server-methods/types.ts @@ -57,6 +57,7 @@ export type GatewayRequestContext = { nodeUnsubscribeAll: (nodeId: string) => void; hasConnectedMobileNode: () => boolean; hasExecApprovalClients?: (excludeConnId?: string) => boolean; + disconnectClientsForDevice?: (deviceId: string, opts?: { role?: string }) => void; nodeRegistry: NodeRegistry; agentRunSeq: Map; chatAbortControllers: Map; diff --git a/src/gateway/server.impl.ts b/src/gateway/server.impl.ts index 33a82e560eb..4c29ce17071 100644 --- a/src/gateway/server.impl.ts +++ b/src/gateway/server.impl.ts @@ -1196,6 +1196,21 @@ export async function startGatewayServer( } return false; }, + disconnectClientsForDevice: (deviceId: string, opts?: { role?: string }) => { + for (const gatewayClient of clients) { + if (gatewayClient.connect.device?.id !== deviceId) { + continue; + } + if (opts?.role && gatewayClient.connect.role !== opts.role) { + continue; + } + try { + gatewayClient.socket.close(4001, "device removed"); + } catch { + /* ignore */ + } + } + }, nodeRegistry, agentRunSeq, chatAbortControllers,