Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 53 additions & 48 deletions react_retrieval/react_retrieval/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,16 @@
from PIL import Image
import json
import logging

from torch.utils import data
import clip
import numpy as np

from .tsv import TSVFile

import os
from zipfile import ZipFile, BadZipFile


class ICinWJsonDataset(data.Dataset):
def __init__(self, data_root, infolist, transform=None):
super().__init__()

logging.info(f'Initializing ICinW JSON dataset with {infolist}')
with open(infolist, 'r') as fp:
self.infolist = json.load(fp)
Expand All @@ -28,57 +23,58 @@ def __init__(self, data_root, infolist, transform=None):
def __len__(self):
return len(self.infolist)

def load_zipfile(self, zipfile):
zipfile = os.path.join(self.data_root, zipfile)
if zipfile not in self.zipfiles:
self.zipfiles[zipfile] = ZipFile(zipfile)
return self.zipfiles[zipfile]
def load_zipfile(self, zipfile_name):
full_path = os.path.join(self.data_root, zipfile_name) # 수정됨: 명확한 경로 처리
if full_path not in self.zipfiles:
try:
self.zipfiles[full_path] = ZipFile(full_path)
except BadZipFile as e:
logging.error(f"Failed to open zipfile: {full_path}") # 개선됨: 로깅 추가
raise e
return self.zipfiles[full_path]

def read_image(self, index):
img_info = self.infolist[index]
zipfile, imagefile = img_info['img_path'].split('@')
zipfile = self.load_zipfile(zipfile)
zipfile_name, imagefile = img_info['img_path'].split('@')
zipfile = self.load_zipfile(zipfile_name)

try:
image = Image.open(BytesIO(zipfile.read(imagefile))).convert('RGB')
with zipfile.open(imagefile) as img_file:
image = Image.open(img_file).convert('RGB') # 개선됨: with-context 사용
except KeyError:
logging.error(f"Image file {imagefile} not found in zip archive {zipfile_name}") # ✅ 개선됨
raise
except BadZipFile:
assert False, f"bad zip file in reading {img_info['img_path']}"
raise RuntimeError(f"Bad zip file in reading {img_info['img_path']}") # 개선됨

return image

def __getitem__(self, index):
image = self.read_image(index)
if self.transform is not None:
return self.transform(image)
return image

return self.transform(image) if self.transform else image # 개선됨: 간결화

class TSVDataset(data.Dataset):
    """Dataset wrapper around a ``TSVFile`` with an optional per-item transform."""

    def __init__(self, file_name, transform=None):
        """Open ``file_name`` as a ``TSVFile`` and remember the transform.

        Args:
            file_name: Passed straight to ``TSVFile``.
            transform: Optional callable applied to each item in
                ``__getitem__``; ``None`` returns items unchanged.
        """
        super().__init__()

        self.tsv_file = TSVFile(file_name)
        self.transform = transform

    def __len__(self):
        """Return the number of rows reported by the underlying TSV file."""
        return len(self.tsv_file)

    def __getitem__(self, index):
        """Return row ``index``, passed through ``self.transform`` if set."""
        item = self.tsv_file[index]
        # Compare against None rather than using truthiness: a callable that
        # defines ``__len__``/``__bool__`` may be falsy and would be skipped.
        if self.transform is not None:
            return self.transform(item)
        return item

class PairsDataset(data.Dataset):
def __init__(self, image_file_name, text_file_name, image_transform=None, text_transform=None):
    """Build paired image/text TSV datasets and check they line up.

    Args:
        image_file_name: TSV file passed to the image ``TSVDataset``.
        text_file_name: TSV file passed to the text ``TSVDataset``.
        image_transform: Optional transform for image rows.
        text_transform: Optional transform for text rows.

    Raises:
        ValueError: If the two datasets report different lengths.
    """
    super().__init__()

    self.image_dataset = TSVDataset(image_file_name, image_transform)
    self.text_dataset = TSVDataset(text_file_name, text_transform)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(self.image_dataset) != len(self.text_dataset):
        raise ValueError("Image and text datasets must be of equal length.")

def __len__(self):
return len(self.image_dataset)

Expand All @@ -87,47 +83,56 @@ def get_image(self, index):
return Image.open(BytesIO(base64.b64decode(raw_image_data[1]))).convert('RGB')

def get_image_raw(self, index):
raw_image_data = self.image_dataset.tsv_file[index]
return raw_image_data[1]
return self.image_dataset.tsv_file[index][1]

def get_text(self, index):
raw_text_data = self.text_dataset.tsv_file[index]
return json.loads(raw_text_data[1])['captions'][0]
captions = json.loads(raw_text_data[1]).get('captions', [""]) # 개선됨: get 사용
return captions[0] if captions else ""

def __getitem__(self, index):
image_filename, image = self.image_dataset[index]
text_filename, text = self.text_dataset[index]

assert image_filename == text_filename
if image_filename != text_filename: # 개선됨: assert 대신 명시적 예외 처리
raise ValueError(f"Filename mismatch: {image_filename} != {text_filename}")

return image, text, {
'index': index,
'filename': image_filename,
}


def decode_image(image_item, fn):
    """Decode a base64-encoded image row and apply ``fn`` to the RGB image.

    Args:
        image_item: Indexable row; element 0 is returned unchanged as the key
            and element 1 is the base64-encoded image payload.
        fn: Callable applied to the decoded RGB image.

    Returns:
        ``(image_item[0], fn(image))``.

    Raises:
        Exception: Any failure while decoding or transforming is logged and
            re-raised unchanged.
    """
    try:
        raw_bytes = base64.b64decode(image_item[1])
        image = Image.open(BytesIO(raw_bytes)).convert('RGB')
        return image_item[0], fn(image)
    except Exception as e:
        # Log for visibility, then let the caller see the original error.
        logging.error(f"Failed to decode image: {e}")
        raise

def decode_text(text_item):
    """Tokenize the first caption of a text row with ``clip.tokenize``.

    Args:
        text_item: Indexable row; element 0 is returned unchanged as the key
            and element 1 is a JSON string read for its ``'captions'`` entry.

    Returns:
        ``(text_item[0], tokens)`` where ``tokens`` is
        ``clip.tokenize([...], context_length=77, truncate=True).squeeze()``.
        A missing, null, or empty first caption is tokenized as ``""`` and a
        warning is logged.

    Raises:
        Exception: Any failure while parsing or tokenizing is logged and
            re-raised unchanged.
    """
    try:
        # ``or []`` also covers an explicit JSON null stored under 'captions'.
        captions = json.loads(text_item[1]).get('captions') or []
        text_captions_first = captions[0] if captions and captions[0] else ""
        if not text_captions_first:
            logging.warning(f"Found null or empty caption in file {text_item[0]}, using empty string.")
        texts = clip.tokenize([text_captions_first], context_length=77, truncate=True)
        return text_item[0], texts.squeeze()
    except Exception as e:
        # Log for visibility, then let the caller see the original error.
        logging.error(f"Failed to decode text: {e}")
        raise

def encode_as_string(arr):
    """Return the raw bytes of ``arr`` as a base64-encoded UTF-8 string.

    Args:
        arr: A NumPy array (or subclass), or an object exposing
            ``.data.cpu().numpy()`` — presumably a torch tensor; confirm
            against callers.

    Returns:
        ``base64.b64encode(arr.tobytes())`` decoded to ``str``.
    """
    # isinstance (not an exact type comparison) so ndarray subclasses are
    # used as-is instead of being sent down the tensor conversion path.
    if not isinstance(arr, np.ndarray):
        arr = arr.data.cpu().numpy()
    return base64.b64encode(arr.tobytes()).decode('utf-8')


def decode_pairs_feature(item):
    """Decode one stored ``(index, filename, image_feature, text_feature)`` row.

    Args:
        item: Four-element iterable. The index is converted with ``int``;
            the two features are base64 strings decoded into ``float16``
            arrays via ``np.frombuffer``.

    Returns:
        ``(index, filename, image_feature, text_feature)`` with the index as
        ``int`` and both features as ``float16`` NumPy arrays.

    Raises:
        Exception: Any failure while unpacking or decoding is logged and
            re-raised unchanged.
    """
    try:
        index, filename, image_feature, text_feature = item
        index = int(index)
        image_feature = np.frombuffer(base64.b64decode(image_feature), dtype='float16')
        text_feature = np.frombuffer(base64.b64decode(text_feature), dtype='float16')
        return index, filename, image_feature, text_feature
    except Exception as e:
        # Log for visibility, then let the caller see the original error.
        logging.error(f"Failed to decode pair features: {e}")
        raise