1234567891011121314151617181920212223242526272829303132 |
-
- import cv2
- import pickle
- from PIL import Image
- from numpy import asarray
-
- from data.data_loader import DatasetLoader
-
-
- class TwitterDatasetLoader(DatasetLoader):
- def __getitem__(self, idx):
- item = self.set_text(idx)
- item.update(self.set_img(idx))
- return item
-
- def set_img(self, idx):
- try:
- if self.mode == 'train':
- image = cv2.imread(f"{self.config.train_image_path}/{self.image_filenames[idx]}")
- elif self.mode == 'validation':
- image = cv2.imread(f"{self.config.validation_image_path}/{self.image_filenames[idx]}")
- else:
- image = cv2.imread(f"{self.config.test_image_path}/{self.image_filenames[idx]}")
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
- except:
- image = Image.open(f"{self.config.train_image_path}/{self.image_filenames[idx]}").convert('RGB')
- image = asarray(image)
- return self.set_image(image)
-
- def __len__(self):
- return len(self.text)
|