Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loader.py 1.2KB

123456789101112131415161718192021222324252627282930313233
  1. import torch
  2. from torch.utils.data import Dataset
  3. class DatasetLoader(Dataset):
  4. def __init__(self, config, image_filenames, texts, labels, tokenizer, transforms, mode):
  5. self.config = config
  6. self.image_filenames = image_filenames
  7. self.text = list(texts)
  8. self.encoded_text = tokenizer(
  9. list(texts), padding=True, truncation=True, max_length=config.max_length, return_tensors='pt'
  10. )
  11. self.labels = labels
  12. self.transforms = transforms
  13. self.mode = mode
  14. def set_text(self, idx):
  15. item = {
  16. key: values[idx].clone().detach()
  17. for key, values in self.encoded_text.items()
  18. }
  19. item['text'] = self.text[idx]
  20. item['label'] = self.labels[idx]
  21. item['id'] = idx
  22. return item
  23. def set_image(self, image, item):
  24. if 'resnet' in self.config.image_model_name:
  25. image = self.transforms(image=image)['image']
  26. item['image'] = torch.as_tensor(image).reshape((3, 224, 224))
  27. else:
  28. image = self.transforms(images=image, return_tensors='pt')
  29. image = image.convert_to_tensors(tensor_type='pt')['pixel_values']
  30. item['image'] = image.reshape((3, 224, 224))