Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loader.py 2.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import pickle
  2. import cv2
  3. from PIL import Image
  4. from numpy import asarray
  5. from data.data_loader import DatasetLoader
  6. class WeiboDatasetLoader(DatasetLoader):
  7. def __getitem__(self, idx):
  8. item = self.set_text(idx)
  9. if self.labels[idx] == 1:
  10. try:
  11. image = cv2.imread(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}")
  12. except:
  13. image = Image.open(f"{self.config.rumor_image_path}/{self.image_filenames[idx]}").convert('RGB')
  14. image = asarray(image)
  15. else:
  16. try:
  17. image = cv2.imread(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}")
  18. except:
  19. image = Image.open(f"{self.config.nonrumor_image_path}/{self.image_filenames[idx]}").convert('RGB')
  20. image = asarray(image)
  21. self.set_image(image, item)
  22. return item
  23. def __len__(self):
  24. return len(self.text)
  25. class WeiboEmbeddingDatasetLoader(DatasetLoader):
  26. def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
  27. super().__init__(config, image_filenames, texts, labels, tokenizer, transforms, mode)
  28. self.config = config
  29. self.image_filenames = image_filenames
  30. self.text = list(texts)
  31. self.encoded_text = tokenizer
  32. self.labels = labels
  33. self.transforms = transforms
  34. self.mode = mode
  35. if mode == 'train':
  36. with open(config.train_image_embedding_path, 'rb') as handle:
  37. self.image_embedding = pickle.load(handle)
  38. with open(config.train_text_embedding_path, 'rb') as handle:
  39. self.text_embedding = pickle.load(handle)
  40. else:
  41. with open(config.validation_image_embedding_path, 'rb') as handle:
  42. self.image_embedding = pickle.load(handle)
  43. with open(config.validation_text_embedding_path, 'rb') as handle:
  44. self.text_embedding = pickle.load(handle)
  45. def __getitem__(self, idx):
  46. item = dict()
  47. item['id'] = idx
  48. item['image_embedding'] = self.image_embedding[idx]
  49. item['text_embedding'] = self.text_embedding[idx]
  50. item['label'] = self.labels[idx]
  51. return item
  52. def __len__(self):
  53. return len(self.text)