Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

data_loader.py 2.3KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import cv2
  2. import pickle
  3. from PIL import Image
  4. from numpy import asarray
  5. from data.data_loader import DatasetLoader
  class TwitterDatasetLoader(DatasetLoader):
      """Dataset that pairs each text sample with its image loaded from disk.

      The image directory is picked from ``self.config`` according to
      ``self.mode`` ('train', 'validation', anything else -> test).
      ``set_text`` and ``set_image`` are not defined here; presumably they are
      inherited from ``DatasetLoader`` -- confirm against the base class.
      """

      def __getitem__(self, idx):
          """Return the item for sample ``idx`` with its RGB image attached.

          The image is read with OpenCV and converted from BGR to RGB. If that
          fails, the same file is re-read with PIL as a fallback.

          Args:
              idx: Position of the sample in ``self.image_filenames``.

          Returns:
              The object produced by ``self.set_text(idx)`` after
              ``self.set_image(image, item)`` has been called on it.
          """
          item = self.set_text(idx)

          # Resolve the directory once so the primary read and the fallback
          # read always look at the same file for the current mode.
          if self.mode == 'train':
              image_dir = self.config.train_image_path
          elif self.mode == 'validation':
              image_dir = self.config.validation_image_path
          else:
              image_dir = self.config.test_image_path
          image_path = f"{image_dir}/{self.image_filenames[idx]}"

          try:
              # cv2.imread signals an unreadable file by returning None rather
              # than raising; cvtColor then raises on that None, which is what
              # sends us to the PIL fallback below.
              image = cv2.imread(image_path)
              image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
          except Exception:
              # Best-effort fallback for files OpenCV cannot decode. Catching
              # Exception (not a bare except) lets KeyboardInterrupt and
              # SystemExit propagate.
              image = Image.open(image_path).convert('RGB')
              image = asarray(image)

          self.set_image(image, item)
          return item

      def __len__(self):
          """Return the number of samples, taken from ``self.text``."""
          return len(self.text)
  class TwitterEmbeddingDatasetLoader(DatasetLoader):
      """Dataset that serves pickled image/text embeddings instead of raw data.

      Two pickle files are loaded at construction time; their paths are read
      from ``config``. Each item is a dict holding the sample index, both
      embeddings at that index, and the label.
      """

      def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
          """Initialise the base loader, then load the embedding pickles.

          Args:
              config: Object exposing ``train_*_embedding_path`` and
                  ``validation_*_embedding_path`` attributes.
              image_filenames: Stored as-is on the instance.
              texts: Iterable of texts; materialised into a list.
              labels: Indexable collection of labels.
              tokenizer: Stored on the instance as ``encoded_text``.
              transforms: Stored as-is on the instance.
              mode: ``'train'`` selects the train pickles; every other value
                  selects the validation pickles.
          """
          super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)

          def unpickle(path):
              # NOTE(review): pickle.load executes arbitrary code from the
              # file; only point the config at trusted embedding files.
              with open(path, 'rb') as handle:
                  return pickle.load(handle)

          # These assignments run after the base initialiser, so they replace
          # whatever it stored under the same attribute names.
          self.config = config
          self.image_filenames = image_filenames
          self.text = list(texts)
          # NOTE(review): this holds the tokenizer object itself, not encoded
          # text -- confirm that is intended.
          self.encoded_text = tokenizer
          self.labels = labels
          self.transforms = transforms
          self.mode = mode

          if mode == 'train':
              self.image_embedding = unpickle(config.train_image_embedding_path)
              self.text_embedding = unpickle(config.train_text_embedding_path)
          else:
              # Any non-train mode (including a test mode, if one is used)
              # reads the validation embeddings.
              self.image_embedding = unpickle(config.validation_image_embedding_path)
              self.text_embedding = unpickle(config.validation_text_embedding_path)

      def __getitem__(self, idx):
          """Return a dict with the id, both embeddings and label for ``idx``."""
          return {
              'id': idx,
              'image_embedding': self.image_embedding[idx],
              'text_embedding': self.text_embedding[idx],
              'label': self.labels[idx],
          }

      def __len__(self):
          """Return the number of samples, taken from ``self.text``."""
          return len(self.text)