import pickle

import cv2
from numpy import asarray
from PIL import Image

from data.data_loader import DatasetLoader


class TwitterDatasetLoader(DatasetLoader):
    """Dataset that pairs each text sample with an image read from disk.

    Text fields come from the inherited ``set_text`` and the decoded RGB
    image is attached to the sample via the inherited ``set_image``.
    """

    def _image_path(self, idx):
        """Return the image path for sample *idx* in the directory of the current mode."""
        if self.mode == 'train':
            image_dir = self.config.train_image_path
        elif self.mode == 'validation':
            image_dir = self.config.validation_image_path
        else:
            # Any mode other than 'train' / 'validation' is treated as the test split.
            image_dir = self.config.test_image_path
        return f"{image_dir}/{self.image_filenames[idx]}"

    def __getitem__(self, idx):
        """Build the sample dict for *idx*, including its image as an RGB array.

        The image is decoded with OpenCV first; if OpenCV cannot read it,
        Pillow is tried on the same path (same split directory).
        """
        item = self.set_text(idx)
        path = self._image_path(idx)

        image = cv2.imread(path)
        if image is not None:
            # OpenCV loads channels in BGR order; convert to RGB to match Pillow output.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            # cv2.imread returns None when it cannot read/decode the file
            # (presumably formats OpenCV lacks a codec for); fall back to Pillow.
            with Image.open(path) as pil_image:
                image = asarray(pil_image.convert('RGB'))

        self.set_image(image, item)
        return item

    def __len__(self):
        """Return the number of samples (one per text entry)."""
        return len(self.text)


class TwitterEmbeddingDatasetLoader(DatasetLoader):
    """Dataset serving precomputed image/text embeddings loaded from pickle files."""

    def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
        super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
        self.config = config
        self.image_filenames = image_filenames
        self.text = list(texts)
        # NOTE(review): this stores the tokenizer object itself under the name
        # `encoded_text`; kept as-is because other code may read it.
        self.encoded_text = tokenizer
        self.labels = labels
        self.transforms = transforms
        self.mode = mode

        if mode == 'train':
            image_embedding_path = config.train_image_embedding_path
            text_embedding_path = config.train_text_embedding_path
        else:
            # NOTE(review): every non-train mode, including 'test', reads the
            # validation embedding files — confirm this is intended.
            image_embedding_path = config.validation_image_embedding_path
            text_embedding_path = config.validation_text_embedding_path

        # pickle.load can execute arbitrary code; these paths must point at trusted files.
        self.image_embedding = self._load_pickle(image_embedding_path)
        self.text_embedding = self._load_pickle(text_embedding_path)

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at *path*."""
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    def __getitem__(self, idx):
        """Return the embeddings and label for sample *idx* as a dict."""
        return {
            'id': idx,
            'image_embedding': self.image_embedding[idx],
            'text_embedding': self.text_embedding[idx],
            'label': self.labels[idx],
        }

    def __len__(self):
        """Return the number of samples (one per text entry)."""
        return len(self.text)