openclaw/extensions/fal/image-generation-provider.ts

303 lines
8.6 KiB
TypeScript

import type {
GeneratedImageAsset,
ImageGenerationProvider,
} from "openclaw/plugin-sdk/image-generation";
import { resolveApiKeyForProvider } from "openclaw/plugin-sdk/provider-auth";
const DEFAULT_FAL_BASE_URL = "https://fal.run";
const DEFAULT_FAL_IMAGE_MODEL = "fal-ai/flux/dev";
const DEFAULT_FAL_EDIT_SUBPATH = "image-to-image";
const DEFAULT_OUTPUT_FORMAT = "png";
const FAL_SUPPORTED_SIZES = [
"1024x1024",
"1024x1536",
"1536x1024",
"1024x1792",
"1792x1024",
] as const;
const FAL_SUPPORTED_ASPECT_RATIOS = ["1:1", "4:3", "3:4", "16:9", "9:16"] as const;
type FalGeneratedImage = {
url?: string;
content_type?: string;
};
type FalImageGenerationResponse = {
images?: FalGeneratedImage[];
prompt?: string;
};
type FalImageSize = string | { width: number; height: number };
function resolveFalBaseUrl(cfg: Parameters<typeof resolveApiKeyForProvider>[0]["cfg"]): string {
const direct = cfg?.models?.providers?.fal?.baseUrl?.trim();
return (direct || DEFAULT_FAL_BASE_URL).replace(/\/+$/u, "");
}
function ensureFalModelPath(model: string | undefined, hasInputImages: boolean): string {
const trimmed = model?.trim() || DEFAULT_FAL_IMAGE_MODEL;
if (!hasInputImages) {
return trimmed;
}
if (
trimmed.endsWith(`/${DEFAULT_FAL_EDIT_SUBPATH}`) ||
trimmed.endsWith("/edit") ||
trimmed.includes("/image-to-image/")
) {
return trimmed;
}
return `${trimmed}/${DEFAULT_FAL_EDIT_SUBPATH}`;
}
function parseSize(raw: string | undefined): { width: number; height: number } | null {
const trimmed = raw?.trim();
if (!trimmed) {
return null;
}
const match = /^(\d{2,5})x(\d{2,5})$/iu.exec(trimmed);
if (!match) {
return null;
}
const width = Number.parseInt(match[1] ?? "", 10);
const height = Number.parseInt(match[2] ?? "", 10);
if (!Number.isFinite(width) || !Number.isFinite(height) || width <= 0 || height <= 0) {
return null;
}
return { width, height };
}
function mapResolutionToEdge(resolution: "1K" | "2K" | "4K" | undefined): number | undefined {
if (!resolution) {
return undefined;
}
return resolution === "4K" ? 4096 : resolution === "2K" ? 2048 : 1024;
}
function aspectRatioToEnum(aspectRatio: string | undefined): string | undefined {
const normalized = aspectRatio?.trim();
if (!normalized) {
return undefined;
}
if (normalized === "1:1") {
return "square_hd";
}
if (normalized === "4:3") {
return "landscape_4_3";
}
if (normalized === "3:4") {
return "portrait_4_3";
}
if (normalized === "16:9") {
return "landscape_16_9";
}
if (normalized === "9:16") {
return "portrait_16_9";
}
return undefined;
}
function aspectRatioToDimensions(
aspectRatio: string,
edge: number,
): { width: number; height: number } {
const match = /^(\d+):(\d+)$/u.exec(aspectRatio.trim());
if (!match) {
throw new Error(`Invalid fal aspect ratio: ${aspectRatio}`);
}
const widthRatio = Number.parseInt(match[1] ?? "", 10);
const heightRatio = Number.parseInt(match[2] ?? "", 10);
if (
!Number.isFinite(widthRatio) ||
!Number.isFinite(heightRatio) ||
widthRatio <= 0 ||
heightRatio <= 0
) {
throw new Error(`Invalid fal aspect ratio: ${aspectRatio}`);
}
if (widthRatio >= heightRatio) {
return {
width: edge,
height: Math.max(1, Math.round((edge * heightRatio) / widthRatio)),
};
}
return {
width: Math.max(1, Math.round((edge * widthRatio) / heightRatio)),
height: edge,
};
}
function resolveFalImageSize(params: {
size?: string;
resolution?: "1K" | "2K" | "4K";
aspectRatio?: string;
hasInputImages: boolean;
}): FalImageSize | undefined {
const parsed = parseSize(params.size);
if (parsed) {
return parsed;
}
const normalizedAspectRatio = params.aspectRatio?.trim();
if (normalizedAspectRatio && params.hasInputImages) {
throw new Error("fal image edit endpoint does not support aspectRatio overrides");
}
const edge = mapResolutionToEdge(params.resolution);
if (normalizedAspectRatio && edge) {
return aspectRatioToDimensions(normalizedAspectRatio, edge);
}
if (edge) {
return { width: edge, height: edge };
}
if (normalizedAspectRatio) {
return (
aspectRatioToEnum(normalizedAspectRatio) ??
aspectRatioToDimensions(normalizedAspectRatio, 1024)
);
}
return undefined;
}
function toDataUri(buffer: Buffer, mimeType: string): string {
return `data:${mimeType};base64,${buffer.toString("base64")}`;
}
function fileExtensionForMimeType(mimeType: string | undefined): string {
const normalized = mimeType?.toLowerCase().trim();
if (!normalized) {
return "png";
}
if (normalized.includes("jpeg")) {
return "jpg";
}
const slashIndex = normalized.indexOf("/");
return slashIndex >= 0 ? normalized.slice(slashIndex + 1) || "png" : "png";
}
async function fetchImageBuffer(url: string): Promise<{ buffer: Buffer; mimeType: string }> {
const response = await fetch(url);
if (!response.ok) {
const text = await response.text().catch(() => "");
throw new Error(
`fal image download failed (${response.status}): ${text || response.statusText}`,
);
}
const mimeType = response.headers.get("content-type")?.trim() || "image/png";
const arrayBuffer = await response.arrayBuffer();
return { buffer: Buffer.from(arrayBuffer), mimeType };
}
export function buildFalImageGenerationProvider(): ImageGenerationProvider {
return {
id: "fal",
label: "fal",
defaultModel: DEFAULT_FAL_IMAGE_MODEL,
models: [DEFAULT_FAL_IMAGE_MODEL, `${DEFAULT_FAL_IMAGE_MODEL}/${DEFAULT_FAL_EDIT_SUBPATH}`],
capabilities: {
generate: {
maxCount: 4,
supportsSize: true,
supportsAspectRatio: true,
supportsResolution: true,
},
edit: {
enabled: true,
maxCount: 4,
maxInputImages: 1,
supportsSize: true,
supportsAspectRatio: false,
supportsResolution: true,
},
geometry: {
sizes: [...FAL_SUPPORTED_SIZES],
aspectRatios: [...FAL_SUPPORTED_ASPECT_RATIOS],
resolutions: ["1K", "2K", "4K"],
},
},
async generateImage(req) {
const auth = await resolveApiKeyForProvider({
provider: "fal",
cfg: req.cfg,
agentDir: req.agentDir,
store: req.authStore,
});
if (!auth.apiKey) {
throw new Error("fal API key missing");
}
if ((req.inputImages?.length ?? 0) > 1) {
throw new Error("fal image generation currently supports at most one reference image");
}
const hasInputImages = (req.inputImages?.length ?? 0) > 0;
const imageSize = resolveFalImageSize({
size: req.size,
resolution: req.resolution,
aspectRatio: req.aspectRatio,
hasInputImages,
});
const model = ensureFalModelPath(req.model, hasInputImages);
const requestBody: Record<string, unknown> = {
prompt: req.prompt,
num_images: req.count ?? 1,
output_format: DEFAULT_OUTPUT_FORMAT,
};
if (imageSize !== undefined) {
requestBody.image_size = imageSize;
}
if (hasInputImages) {
const [input] = req.inputImages ?? [];
if (!input) {
throw new Error("fal image edit request missing reference image");
}
requestBody.image_url = toDataUri(input.buffer, input.mimeType);
}
const response = await fetch(`${resolveFalBaseUrl(req.cfg)}/${model}`, {
method: "POST",
headers: {
Authorization: `Key ${auth.apiKey}`,
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
});
if (!response.ok) {
const text = await response.text().catch(() => "");
throw new Error(
`fal image generation failed (${response.status}): ${text || response.statusText}`,
);
}
const payload = (await response.json()) as FalImageGenerationResponse;
const images: GeneratedImageAsset[] = [];
let imageIndex = 0;
for (const entry of payload.images ?? []) {
const url = entry.url?.trim();
if (!url) {
continue;
}
const downloaded = await fetchImageBuffer(url);
imageIndex += 1;
images.push({
buffer: downloaded.buffer,
mimeType: downloaded.mimeType,
fileName: `image-${imageIndex}.${fileExtensionForMimeType(
downloaded.mimeType || entry.content_type,
)}`,
});
}
if (images.length === 0) {
throw new Error("fal image generation response missing image data");
}
return {
images,
model,
metadata: payload.prompt ? { prompt: payload.prompt } : undefined,
};
},
};
}