Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loader.py 2.3KB

(Lines 1–64)
  1. import torch
  2. from torch.utils.data import Dataset
  3. import albumentations as A
  4. from transformers import ViTFeatureExtractor
  5. from transformers import BertTokenizer, BigBirdTokenizer, XLNetTokenizer
  6. def get_transforms(config):
  7. return A.Compose(
  8. [
  9. A.Resize(config.size, config.size, always_apply=True),
  10. A.Normalize(max_pixel_value=255.0, always_apply=True),
  11. ]
  12. )
  13. def get_tokenizer(config):
  14. if 'bigbird' in config.text_encoder_model:
  15. tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
  16. elif 'xlnet' in config.text_encoder_model:
  17. tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
  18. else:
  19. tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer)
  20. return tokenizer
  21. class DatasetLoader(Dataset):
  22. def __init__(self, config, dataframe, mode):
  23. self.config = config
  24. self.mode = mode
  25. if mode == 'lime':
  26. self.image_filenames = [dataframe["image"],]
  27. self.text = [dataframe["text"],]
  28. self.labels = [dataframe["label"],]
  29. else:
  30. self.image_filenames = dataframe["image"].values
  31. self.text = list(dataframe["text"].values)
  32. self.labels = dataframe["label"].values
  33. tokenizer = get_tokenizer(config)
  34. self.encoded_text = tokenizer(self.text, padding=True, truncation=True, max_length=config.max_length, return_tensors='pt')
  35. if 'resnet' in config.image_model_name:
  36. self.transforms = get_transforms(config)
  37. else:
  38. self.transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)
  39. def set_text(self, idx):
  40. item = {
  41. key: values[idx].clone().detach()
  42. for key, values in self.encoded_text.items()
  43. }
  44. item['text'] = self.text[idx]
  45. item['label'] = self.labels[idx]
  46. item['id'] = idx
  47. return item
  48. def set_image(self, image):
  49. if 'resnet' in self.config.image_model_name:
  50. image = self.transforms(image=image)['image']
  51. return {'image': torch.as_tensor(image).reshape((3, 224, 224))}
  52. else:
  53. image = self.transforms(images=image, return_tensors='pt')
  54. image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
  55. return {'image': image.reshape((3, 224, 224))}