Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: #2373 Fix image description #2375

Merged
merged 9 commits into from
Jan 17, 2025
3 changes: 1 addition & 2 deletions packages/plugin-node/package.json
Original file line number Diff line number Diff line change
@@ -23,13 +23,13 @@
"tsup.config.ts"
],
"dependencies": {
"@elizaos/core": "workspace:*",
"@aws-sdk/client-s3": "^3.705.0",
"@aws-sdk/s3-request-presigner": "^3.705.0",
"@cliqz/adblocker-playwright": "1.34.0",
"@echogarden/espeak-ng-emscripten": "0.3.3",
"@echogarden/kissfft-wasm": "0.2.0",
"@echogarden/speex-resampler-wasm": "0.2.1",
"@elizaos/core": "workspace:*",
"@huggingface/transformers": "3.0.2",
"@opendocsg/pdf2md": "0.1.32",
"@types/uuid": "10.0.0",
@@ -46,7 +46,6 @@
"formdata-node": "6.0.3",
"fs-extra": "11.2.0",
"gaxios": "6.7.1",
"gif-frames": "0.4.1",
"glob": "11.0.0",
"graceful-fs": "4.2.11",
"html-escaper": "3.0.3",
10 changes: 8 additions & 2 deletions packages/plugin-node/src/actions/describe-image.ts
Original file line number Diff line number Diff line change
@@ -43,12 +43,18 @@ export const describeImage: Action = {
stop: ["\n"],
});

if (!isFileLocationResult(fileLocationResultObject?.object)) {
if (
!isFileLocationResult(
fileLocationResultObject?.object ?? fileLocationResultObject
)
) {
elizaLogger.error("Failed to generate file location");
return false;
}

const { fileLocation } = fileLocationResultObject.object;
let fileLocation = (fileLocationResultObject?.object as any)
?.fileLocation;
fileLocation ??= fileLocationResultObject;

const { description } = await runtime
.getService<IImageDescriptionService>(ServiceType.IMAGE_DESCRIPTION)
129 changes: 72 additions & 57 deletions packages/plugin-node/src/services/image.ts
Original file line number Diff line number Diff line change
@@ -19,8 +19,8 @@ import {
RawImage,
type Tensor,
} from "@huggingface/transformers";
import sharp, { AvailableFormatInfo, FormatEnum } from "sharp";
import fs from "fs";
import gifFrames from "gif-frames";
import os from "os";
import path from "path";

@@ -111,15 +111,14 @@ class LocalImageProvider implements ImageProvider {
}

async describeImage(
imageData: Buffer
imageData: Buffer,
mimeType: string
): Promise<{ title: string; description: string }> {
if (!this.model || !this.processor || !this.tokenizer) {
throw new Error("Model components not initialized");
}

const base64Data = imageData.toString("base64");
const dataUrl = `data:image/jpeg;base64,${base64Data}`;
const image = await RawImage.fromURL(dataUrl);
const blob = new Blob([imageData], { type: mimeType });
const image = await RawImage.fromBlob(blob);
const visionInputs = await this.processor(image);
const prompts = this.processor.construct_prompts("<DETAILED_CAPTION>");
const textInputs = this.tokenizer(prompts);
@@ -314,10 +313,12 @@ export class ImageDescriptionService
if (this.runtime.imageVisionModelProvider) {
if (
this.runtime.imageVisionModelProvider ===
ModelProviderName.LLAMALOCAL
ModelProviderName.LLAMALOCAL ||
this.runtime.imageVisionModelProvider ===
ModelProviderName.OLLAMA
) {
this.provider = new LocalImageProvider();
elizaLogger.debug("Using llama local for vision model");
elizaLogger.debug("Using local provider for vision model");
} else if (
this.runtime.imageVisionModelProvider ===
ModelProviderName.GOOGLE
@@ -343,9 +344,12 @@ export class ImageDescriptionService
);
return false;
}
} else if (model === models[ModelProviderName.LLAMALOCAL]) {
} else if (
model === models[ModelProviderName.LLAMALOCAL] ||
model === models[ModelProviderName.OLLAMA]
) {
this.provider = new LocalImageProvider();
elizaLogger.debug("Using llama local for vision model");
elizaLogger.debug("Using local provider for vision model");
} else if (model === models[ModelProviderName.GOOGLE]) {
this.provider = new GoogleImageProvider(this.runtime);
elizaLogger.debug("Using google for vision model");
@@ -369,74 +373,85 @@ export class ImageDescriptionService
}

private async loadImageData(
imageUrl: string
imageUrlOrPath: string
): Promise<{ data: Buffer; mimeType: string }> {
const isGif = imageUrl.toLowerCase().endsWith(".gif");
let imageData: Buffer;
let mimeType: string;

if (isGif) {
const { filePath } = await this.extractFirstFrameFromGif(imageUrl);
imageData = fs.readFileSync(filePath);
mimeType = "image/png";
fs.unlinkSync(filePath); // Clean up temp file
let loadedImageData: Buffer;
let loadedMimeType: string;
const { imageData, mimeType } = await this.fetchImage(imageUrlOrPath);
const skipConversion =
mimeType === "image/jpeg" ||
mimeType === "image/jpg" ||
mimeType === "image/png";
if (skipConversion) {
loadedImageData = imageData;
loadedMimeType = mimeType;
} else {
if (fs.existsSync(imageUrl)) {
imageData = fs.readFileSync(imageUrl);
const ext = path.extname(imageUrl).slice(1);
mimeType = ext ? `image/${ext}` : "image/jpeg";
} else {
const response = await fetch(imageUrl);
if (!response.ok) {
throw new Error(
`Failed to fetch image: ${response.statusText}`
);
}
imageData = Buffer.from(await response.arrayBuffer());
mimeType = response.headers.get("content-type") || "image/jpeg";
}
const converted = await this.convertImageDataToFormat(
imageData,
"png"
);
loadedImageData = converted.imageData;
loadedMimeType = converted.mimeType;
}

if (!imageData || imageData.length === 0) {
if (!loadedImageData || loadedImageData.length === 0) {
throw new Error("Failed to fetch image data");
}

return { data: imageData, mimeType };
return { data: loadedImageData, mimeType: loadedMimeType };
}

private async extractFirstFrameFromGif(
gifUrl: string
): Promise<{ filePath: string }> {
const frameData = await gifFrames({
url: gifUrl,
frames: 1,
outputType: "png",
});

private async convertImageDataToFormat(
data: Buffer,
format: keyof FormatEnum | AvailableFormatInfo = "png"
): Promise<{ imageData: Buffer; mimeType: string }> {
const tempFilePath = path.join(
os.tmpdir(),
`gif_frame_${Date.now()}.png`
`tmp_img_${Date.now()}.${format}`
);
try {
await sharp(data).toFormat(format).toFile(tempFilePath);
const { imageData, mimeType } = await this.fetchImage(tempFilePath);
return {
imageData,
mimeType,
};
} finally {
fs.unlinkSync(tempFilePath); // Clean up temp file
}
}

return new Promise((resolve, reject) => {
const writeStream = fs.createWriteStream(tempFilePath);
frameData[0].getImage().pipe(writeStream);
writeStream.on("finish", () => resolve({ filePath: tempFilePath }));
writeStream.on("error", reject);
});
/**
 * Loads the raw bytes of an image together with a MIME type string for it.
 *
 * The argument is first checked against the local filesystem; only when no
 * such path exists is it handed to `fetch` as a URL.
 *
 * @param imageUrlOrPath - Local file path or URL of the image to load.
 * @returns The image bytes (`imageData`) and the MIME type (`mimeType`).
 * @throws Error when the HTTP response is not OK (`response.ok` is false).
 */
private async fetchImage(
imageUrlOrPath: string
): Promise<{ imageData: Buffer; mimeType: string }> {
let imageData: Buffer;
let mimeType: string;
// An existing filesystem path takes precedence over treating the value as a URL.
if (fs.existsSync(imageUrlOrPath)) {
// Synchronous read of the whole file into memory.
imageData = fs.readFileSync(imageUrlOrPath);
// The MIME type is built from the lowercased extension, so ".jpg" yields
// "image/jpg"; a path without an extension falls back to "image/jpeg".
const ext = path.extname(imageUrlOrPath).slice(1).toLowerCase();
mimeType = ext ? `image/${ext}` : "image/jpeg";
} else {
const response = await fetch(imageUrlOrPath);
if (!response.ok) {
throw new Error(
`Failed to fetch image: ${response.statusText}`
);
}
imageData = Buffer.from(await response.arrayBuffer());
// The Content-Type header value is used as-is (nothing is stripped or
// normalized here); a missing header falls back to "image/jpeg".
mimeType = response.headers.get("content-type") || "image/jpeg";
}
return { imageData, mimeType };
}

async describeImage(
imageUrl: string
imageUrlOrPath: string
): Promise<{ title: string; description: string }> {
if (!this.initialized) {
this.initialized = await this.initializeProvider();
}

if (this.initialized) {
try {
const { data, mimeType } = await this.loadImageData(imageUrl);
return await this.provider!.describeImage(data, mimeType);
const { data, mimeType } = await this.loadImageData(imageUrlOrPath);
return await this.provider.describeImage(data, mimeType);
} catch (error) {
elizaLogger.error("Error in describeImage:", error);
throw error;
Loading