Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: #2373 Fix image description #2375

Merged
merged 9 commits into from
Jan 17, 2025
3 changes: 1 addition & 2 deletions packages/plugin-node/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
"tsup.config.ts"
],
"dependencies": {
"@elizaos/core": "workspace:*",
"@aws-sdk/client-s3": "^3.705.0",
"@aws-sdk/s3-request-presigner": "^3.705.0",
"@cliqz/adblocker-playwright": "1.34.0",
"@echogarden/espeak-ng-emscripten": "0.3.3",
"@echogarden/kissfft-wasm": "0.2.0",
"@echogarden/speex-resampler-wasm": "0.2.1",
"@elizaos/core": "workspace:*",
"@huggingface/transformers": "3.0.2",
"@opendocsg/pdf2md": "0.1.32",
"@types/uuid": "10.0.0",
Expand All @@ -46,7 +46,6 @@
"formdata-node": "6.0.3",
"fs-extra": "11.2.0",
"gaxios": "6.7.1",
"gif-frames": "0.4.1",
"glob": "11.0.0",
"graceful-fs": "4.2.11",
"html-escaper": "3.0.3",
Expand Down
10 changes: 8 additions & 2 deletions packages/plugin-node/src/actions/describe-image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,18 @@ export const describeImage: Action = {
stop: ["\n"],
});

if (!isFileLocationResult(fileLocationResultObject?.object)) {
if (
!isFileLocationResult(
fileLocationResultObject?.object ?? fileLocationResultObject
)
) {
elizaLogger.error("Failed to generate file location");
return false;
}

const { fileLocation } = fileLocationResultObject.object;
let fileLocation = (fileLocationResultObject?.object as any)
?.fileLocation;
fileLocation ??= fileLocationResultObject;

const { description } = await runtime
.getService<IImageDescriptionService>(ServiceType.IMAGE_DESCRIPTION)
Expand Down
128 changes: 72 additions & 56 deletions packages/plugin-node/src/services/image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import {
RawImage,
type Tensor,
} from "@huggingface/transformers";
import sharp, { AvailableFormatInfo, FormatEnum } from "sharp";
import fs from "fs";
import gifFrames from "gif-frames";
import os from "os";
import path from "path";

Expand Down Expand Up @@ -111,15 +111,14 @@ class LocalImageProvider implements ImageProvider {
}

async describeImage(
imageData: Buffer
imageData: Buffer,
mimeType: string
): Promise<{ title: string; description: string }> {
if (!this.model || !this.processor || !this.tokenizer) {
throw new Error("Model components not initialized");
}

const base64Data = imageData.toString("base64");
const dataUrl = `data:image/jpeg;base64,${base64Data}`;
const image = await RawImage.fromURL(dataUrl);
const blob = new Blob([imageData], { type: mimeType });
const image = await RawImage.fromBlob(blob);
const visionInputs = await this.processor(image);
const prompts = this.processor.construct_prompts("<DETAILED_CAPTION>");
const textInputs = this.tokenizer(prompts);
Expand Down Expand Up @@ -264,10 +263,12 @@ export class ImageDescriptionService
if (this.runtime.imageVisionModelProvider) {
if (
this.runtime.imageVisionModelProvider ===
ModelProviderName.LLAMALOCAL
ModelProviderName.LLAMALOCAL ||
this.runtime.imageVisionModelProvider ===
ModelProviderName.OLLAMA
) {
this.provider = new LocalImageProvider();
elizaLogger.debug("Using llama local for vision model");
elizaLogger.debug("Using local provider for vision model");
} else if (
this.runtime.imageVisionModelProvider ===
ModelProviderName.GOOGLE
Expand All @@ -285,9 +286,12 @@ export class ImageDescriptionService
`Unsupported image vision model provider: ${this.runtime.imageVisionModelProvider}`
);
}
} else if (model === models[ModelProviderName.LLAMALOCAL]) {
} else if (
model === models[ModelProviderName.LLAMALOCAL] ||
model === models[ModelProviderName.OLLAMA]
) {
this.provider = new LocalImageProvider();
elizaLogger.debug("Using llama local for vision model");
elizaLogger.debug("Using local provider for vision model");
} else if (model === models[ModelProviderName.GOOGLE]) {
this.provider = new GoogleImageProvider(this.runtime);
elizaLogger.debug("Using google for vision model");
Expand All @@ -300,73 +304,85 @@ export class ImageDescriptionService
this.initialized = true;
}

/**
* @param imageUrlOrPath URL or file system path to an image. Any format that the sharp library supports as input should work. Images that aren't JPEG or PNG will be converted to PNG.
* @returns Image data buffer and mime type.
*/
private async loadImageData(
imageUrl: string
imageUrlOrPath: string
): Promise<{ data: Buffer; mimeType: string }> {
const isGif = imageUrl.toLowerCase().endsWith(".gif");
let imageData: Buffer;
let mimeType: string;

if (isGif) {
const { filePath } = await this.extractFirstFrameFromGif(imageUrl);
imageData = fs.readFileSync(filePath);
mimeType = "image/png";
fs.unlinkSync(filePath); // Clean up temp file
let loadedImageData: Buffer;
let loadedMimeType: string;
const { imageData, mimeType } = await this.fetchImage(imageUrlOrPath);
const skipConversion =
mimeType === "image/jpeg" ||
mimeType === "image/jpg" ||
mimeType === "image/png";
if (skipConversion) {
loadedImageData = imageData;
loadedMimeType = mimeType;
} else {
if (fs.existsSync(imageUrl)) {
imageData = fs.readFileSync(imageUrl);
const ext = path.extname(imageUrl).slice(1);
mimeType = ext ? `image/${ext}` : "image/jpeg";
} else {
const response = await fetch(imageUrl);
if (!response.ok) {
throw new Error(
`Failed to fetch image: ${response.statusText}`
);
}
imageData = Buffer.from(await response.arrayBuffer());
mimeType = response.headers.get("content-type") || "image/jpeg";
}
const converted = await this.convertImageDataToFormat(
imageData,
"png"
);
loadedImageData = converted.imageData;
loadedMimeType = converted.mimeType;
}

if (!imageData || imageData.length === 0) {
if (!loadedImageData || loadedImageData.length === 0) {
throw new Error("Failed to fetch image data");
}

return { data: imageData, mimeType };
return { data: loadedImageData, mimeType: loadedMimeType };
}

private async extractFirstFrameFromGif(
gifUrl: string
): Promise<{ filePath: string }> {
const frameData = await gifFrames({
url: gifUrl,
frames: 1,
outputType: "png",
});

private async convertImageDataToFormat(
data: Buffer,
format: keyof FormatEnum | AvailableFormatInfo = "png"
): Promise<{ imageData: Buffer; mimeType: string }> {
const tempFilePath = path.join(
os.tmpdir(),
`gif_frame_${Date.now()}.png`
`tmp_img_${Date.now()}.${format}`
);
await sharp(data).toFormat(format).toFile(tempFilePath);
const { imageData, mimeType } = await this.fetchImage(tempFilePath);
fs.unlinkSync(tempFilePath); // Clean up temp file
return {
imageData,
mimeType,
};
}

return new Promise((resolve, reject) => {
const writeStream = fs.createWriteStream(tempFilePath);
frameData[0].getImage().pipe(writeStream);
writeStream.on("finish", () => resolve({ filePath: tempFilePath }));
writeStream.on("error", reject);
});
private async fetchImage(
imageUrlOrPath: string
): Promise<{ imageData: Buffer; mimeType: string }> {
let imageData: Buffer;
let mimeType: string;
if (fs.existsSync(imageUrlOrPath)) {
imageData = fs.readFileSync(imageUrlOrPath);
const ext = path.extname(imageUrlOrPath).slice(1).toLowerCase();
mimeType = ext ? `image/${ext}` : "image/jpeg";
} else {
const response = await fetch(imageUrlOrPath);
if (!response.ok) {
throw new Error(
`Failed to fetch image: ${response.statusText}`
);
}
imageData = Buffer.from(await response.arrayBuffer());
mimeType = response.headers.get("content-type") || "image/jpeg";
}
return { imageData, mimeType };
}

async describeImage(
imageUrl: string
imageUrlOrPath: string
): Promise<{ title: string; description: string }> {
if (!this.initialized) {
await this.initializeProvider();
}

try {
const { data, mimeType } = await this.loadImageData(imageUrl);
const { data, mimeType } = await this.loadImageData(imageUrlOrPath);
return await this.provider!.describeImage(data, mimeType);
} catch (error) {
elizaLogger.error("Error in describeImage:", error);
Expand Down
Loading
Loading