import {
  collectProviderApiKeysForExecution,
  executeWithApiKeyRotation,
} from "../agents/api-key-rotation.js";
import { requireApiKey, resolveApiKeyForProvider } from "../agents/model-auth.js";
import { parseGeminiAuth } from "../infra/gemini-auth.js";
import type { SsrFPolicy } from "../infra/net/ssrf.js";
import type { EmbeddingInput } from "./embedding-inputs.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import { debugEmbeddingsLog } from "./embeddings-debug.js";
import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
import { resolveMemorySecretInputString } from "./secret-input.js";

/** Resolved settings needed to call the Gemini embedding endpoints. */
export type GeminiEmbeddingClient = {
  /** Base URL after `normalizeGeminiBaseUrl` has been applied. */
  baseUrl: string;
  /** Extra request headers; these are spread after the auth headers on each request. */
  headers: Record<string, string>;
  ssrfPolicy?: SsrFPolicy;
  /** Model name after `normalizeGeminiModel` (no `models/` prefix). */
  model: string;
  /** Model name with the `models/` prefix, as used in endpoint URLs. */
  modelPath: string;
  apiKeys: string[];
  /** Only set for gemini-embedding-2 models; `undefined` otherwise. */
  outputDimensionality?: number;
};

const DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta";
export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001";

/** Per-model input token limits; models not listed here yield `undefined`. */
const GEMINI_MAX_INPUT_TOKENS: Record<string, number> = {
  "text-embedding-004": 2048,
};

// --- gemini-embedding-2-preview support ---
export const GEMINI_EMBEDDING_2_MODELS = new Set([
  "gemini-embedding-2-preview",
  // Add the GA model name here once released.
]);

const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072;
const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const;

export type GeminiTaskType =
  | "RETRIEVAL_QUERY"
  | "RETRIEVAL_DOCUMENT"
  | "SEMANTIC_SIMILARITY"
  | "CLASSIFICATION"
  | "CLUSTERING"
  | "QUESTION_ANSWERING"
  | "FACT_VERIFICATION";

export type GeminiTextPart = { text: string };

export type GeminiInlinePart = {
  inlineData: { mimeType: string; data: string };
};

export type GeminiPart = GeminiTextPart | GeminiInlinePart;

export type GeminiEmbeddingRequest = {
  content: { parts: GeminiPart[] };
  taskType: GeminiTaskType;
  outputDimensionality?: number;
  model?: string;
};

export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest;

/** JSON body shape read from the embed / batch-embed responses. */
type GeminiEmbeddingResponse = {
  embedding?: { values?: number[] };
  embeddings?: Array<{ values?: number[] }>;
};

/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */
export function buildGeminiTextEmbeddingRequest(params: {
  text: string;
  taskType: GeminiTaskType;
  outputDimensionality?: number;
  modelPath?: string;
}): GeminiTextEmbeddingRequest {
  return buildGeminiEmbeddingRequest({
    input: { text: params.text },
    taskType: params.taskType,
    outputDimensionality: params.outputDimensionality,
    modelPath: params.modelPath,
  });
}

/**
 * Builds a Gemini embedding request from a generic embedding input.
 *
 * When the input carries `parts`, each part is mapped to a text or inline-data
 * part; otherwise the request contains a single text part built from
 * `input.text`. `model` and `outputDimensionality` are only set on the request
 * when the corresponding parameter is provided.
 */
export function buildGeminiEmbeddingRequest(params: {
  input: EmbeddingInput;
  taskType: GeminiTaskType;
  outputDimensionality?: number;
  modelPath?: string;
}): GeminiEmbeddingRequest {
  const request: GeminiEmbeddingRequest = {
    content: {
      parts: params.input.parts?.map((part) =>
        part.type === "text"
          ? ({ text: part.text } satisfies GeminiTextPart)
          : ({
              inlineData: { mimeType: part.mimeType, data: part.data },
            } satisfies GeminiInlinePart),
      ) ?? [{ text: params.input.text }],
    },
    taskType: params.taskType,
  };
  if (params.modelPath) {
    request.model = params.modelPath;
  }
  if (params.outputDimensionality != null) {
    request.outputDimensionality = params.outputDimensionality;
  }
  return request;
}

/**
 * Returns true if the given model name is a gemini-embedding-2 variant that
 * supports `outputDimensionality` and extended task types.
 */
export function isGeminiEmbedding2Model(model: string): boolean {
  return GEMINI_EMBEDDING_2_MODELS.has(model);
}

/**
 * Validate and return the `outputDimensionality` for gemini-embedding-2 models.
 * Returns `undefined` for older models (they don't support the param).
 *
 * @param model - Model name, matched exactly against `GEMINI_EMBEDDING_2_MODELS`.
 * @param requested - Desired dimensionality; the model default is used when omitted.
 * @throws Error when `requested` is not one of the supported dimension values.
 */
export function resolveGeminiOutputDimensionality(
  model: string,
  requested?: number,
): number | undefined {
  if (!isGeminiEmbedding2Model(model)) {
    return undefined;
  }
  if (requested == null) {
    return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS;
  }
  const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS;
  if (!valid.includes(requested)) {
    throw new Error(
      `Invalid outputDimensionality ${requested} for ${model}. Valid values: ${valid.join(", ")}`,
    );
  }
  return requested;
}

/**
 * Resolves the API key configured under `memorySearch.remote.apiKey`.
 *
 * If the resolved value is literally `GOOGLE_API_KEY` or `GEMINI_API_KEY`, it is
 * treated as the name of an environment variable and that variable's trimmed
 * value is returned instead (which may be `undefined` when it is unset).
 */
function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined {
  const trimmed = resolveMemorySecretInputString({
    value: remoteApiKey,
    path: "agents.*.memorySearch.remote.apiKey",
  });
  if (!trimmed) {
    return undefined;
  }
  if (trimmed === "GOOGLE_API_KEY" || trimmed === "GEMINI_API_KEY") {
    return process.env[trimmed]?.trim();
  }
  return trimmed;
}

/**
 * Normalizes a user-supplied model identifier to a bare Gemini model name.
 *
 * Blank input yields the default embedding model. A leading `models/` is
 * removed first, then a leading `gemini/` or `google/` provider prefix.
 */
export function normalizeGeminiModel(model: string): string {
  const trimmed = model.trim();
  if (!trimmed) {
    return DEFAULT_GEMINI_EMBEDDING_MODEL;
  }
  const withoutPrefix = trimmed.replace(/^models\//, "");
  if (withoutPrefix.startsWith("gemini/")) {
    return withoutPrefix.slice("gemini/".length);
  }
  if (withoutPrefix.startsWith("google/")) {
    return withoutPrefix.slice("google/".length);
  }
  return withoutPrefix;
}

/**
 * POSTs `body` as JSON to `endpoint` and returns the parsed response payload.
 *
 * The call is delegated to `executeWithApiKeyRotation` with the client's API
 * keys; auth headers are derived from whichever key is handed to the callback.
 *
 * @throws Error including the HTTP status and response text when the response is not OK.
 */
async function fetchGeminiEmbeddingPayload(params: {
  client: GeminiEmbeddingClient;
  endpoint: string;
  body: unknown;
}): Promise<GeminiEmbeddingResponse> {
  return await executeWithApiKeyRotation({
    provider: "google",
    apiKeys: params.client.apiKeys,
    execute: async (apiKey) => {
      const authHeaders = parseGeminiAuth(apiKey);
      // Client headers are spread last, so they win over auth headers with the same name.
      const headers = {
        ...authHeaders.headers,
        ...params.client.headers,
      };
      return await withRemoteHttpResponse({
        url: params.endpoint,
        ssrfPolicy: params.client.ssrfPolicy,
        init: {
          method: "POST",
          headers,
          body: JSON.stringify(params.body),
        },
        onResponse: async (res) => {
          if (!res.ok) {
            const text = await res.text();
            throw new Error(`gemini embeddings failed: ${res.status} ${text}`);
          }
          return (await res.json()) as GeminiEmbeddingResponse;
        },
      });
    },
  });
}

/**
 * Strips trailing slashes and drops everything from the first `/openai`
 * occurrence onward.
 * NOTE(review): presumably this maps an OpenAI-compatible base URL back to the
 * native Gemini API root — confirm against configuration docs.
 */
function normalizeGeminiBaseUrl(raw: string): string {
  const trimmed = raw.replace(/\/+$/, "");
  const openAiIndex = trimmed.indexOf("/openai");
  if (openAiIndex > -1) {
    return trimmed.slice(0, openAiIndex);
  }
  return trimmed;
}

/** Ensures the model name carries the `models/` prefix used in endpoint paths. */
function buildGeminiModelPath(model: string): string {
  return model.startsWith("models/") ? model : `models/${model}`;
}

/**
 * Creates the Gemini embedding provider together with the resolved client.
 *
 * - `embedQuery` calls `:embedContent`; blank text short-circuits to `[]`
 *   without a network call. Task type defaults to `RETRIEVAL_QUERY`.
 * - `embedBatchInputs` calls `:batchEmbedContents`; an empty input list
 *   short-circuits to `[]`. Task type defaults to `RETRIEVAL_DOCUMENT`.
 * - `embedBatch` wraps plain strings as text inputs and delegates to
 *   `embedBatchInputs`.
 *
 * `outputDimensionality` is only sent for gemini-embedding-2 models.
 */
export async function createGeminiEmbeddingProvider(
  options: EmbeddingProviderOptions,
): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> {
  const client = await resolveGeminiEmbeddingClient(options);
  const baseUrl = client.baseUrl.replace(/\/$/, "");
  const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`;
  const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`;
  const isV2 = isGeminiEmbedding2Model(client.model);
  const outputDimensionality = client.outputDimensionality;

  const embedQuery = async (text: string): Promise<number[]> => {
    if (!text.trim()) {
      return [];
    }
    const payload = await fetchGeminiEmbeddingPayload({
      client,
      endpoint: embedUrl,
      body: buildGeminiTextEmbeddingRequest({
        text,
        taskType: options.taskType ?? "RETRIEVAL_QUERY",
        outputDimensionality: isV2 ? outputDimensionality : undefined,
      }),
    });
    return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []);
  };

  const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise<number[][]> => {
    if (inputs.length === 0) {
      return [];
    }
    const payload = await fetchGeminiEmbeddingPayload({
      client,
      endpoint: batchUrl,
      body: {
        requests: inputs.map((input) =>
          buildGeminiEmbeddingRequest({
            input,
            modelPath: client.modelPath,
            taskType: options.taskType ?? "RETRIEVAL_DOCUMENT",
            outputDimensionality: isV2 ? outputDimensionality : undefined,
          }),
        ),
      },
    });
    const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : [];
    // Always return one entry per input; inputs without a matching response
    // embedding are passed through the sanitizer as an empty vector.
    return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? []));
  };

  const embedBatch = async (texts: string[]): Promise<number[][]> => {
    return await embedBatchInputs(
      texts.map((text) => ({
        text,
      })),
    );
  };

  return {
    provider: {
      id: "gemini",
      model: client.model,
      maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model],
      embedQuery,
      embedBatch,
      embedBatchInputs,
    },
    client,
  };
}

/**
 * Resolves API keys, base URL, headers, model and output dimensionality into a
 * `GeminiEmbeddingClient`.
 *
 * - API key: the remote override wins; otherwise the key is resolved for the
 *   `google` provider and required to be present.
 * - Base URL: remote override, then the `google` provider config, then the default.
 * - Headers: provider config headers, overridden by remote headers.
 *
 * @throws Error from `resolveGeminiOutputDimensionality` when an unsupported
 *   dimensionality is requested for a gemini-embedding-2 model.
 */
export async function resolveGeminiEmbeddingClient(
  options: EmbeddingProviderOptions,
): Promise<GeminiEmbeddingClient> {
  const remote = options.remote;
  const remoteApiKey = resolveRemoteApiKey(remote?.apiKey);
  const remoteBaseUrl = remote?.baseUrl?.trim();

  const apiKey = remoteApiKey
    ? remoteApiKey
    : requireApiKey(
        await resolveApiKeyForProvider({
          provider: "google",
          cfg: options.config,
          agentDir: options.agentDir,
        }),
        "google",
      );

  const providerConfig = options.config.models?.providers?.google;
  const rawBaseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GEMINI_BASE_URL;
  const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl);
  const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl);
  const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers);
  const headers: Record<string, string> = {
    ...headerOverrides,
  };
  const apiKeys = collectProviderApiKeysForExecution({
    provider: "google",
    primaryApiKey: apiKey,
  });
  const model = normalizeGeminiModel(options.model);
  const modelPath = buildGeminiModelPath(model);
  const outputDimensionality = resolveGeminiOutputDimensionality(
    model,
    options.outputDimensionality,
  );
  debugEmbeddingsLog("memory embeddings: gemini client", {
    rawBaseUrl,
    baseUrl,
    model,
    modelPath,
    outputDimensionality,
    embedEndpoint: `${baseUrl}/${modelPath}:embedContent`,
    batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`,
  });
  return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality };
}