Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loaders.py 3.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import albumentations as A
  2. import numpy as np
  3. import pandas as pd
  4. from sklearn.utils import compute_class_weight
  5. from torch.utils.data import DataLoader
  6. from transformers import ViTFeatureExtractor, BertTokenizer, BigBirdTokenizer, XLNetTokenizer
  7. def get_transforms(config, mode="train"):
  8. if mode == "train":
  9. return A.Compose(
  10. [
  11. A.Resize(config.size, config.size, always_apply=True),
  12. A.Normalize(max_pixel_value=255.0, always_apply=True),
  13. ]
  14. )
  15. else:
  16. return A.Compose(
  17. [
  18. A.Resize(config.size, config.size, always_apply=True),
  19. A.Normalize(max_pixel_value=255.0, always_apply=True),
  20. ]
  21. )
  22. def make_dfs(config):
  23. train_dataframe = pd.read_csv(config.train_text_path)
  24. train_dataframe.dropna(subset=['text'], inplace=True)
  25. train_dataframe = train_dataframe.sample(frac=1).reset_index(drop=True)
  26. train_dataframe.label = train_dataframe.label.apply(lambda x: config.classes.index(x))
  27. config.class_weights = get_class_weights(train_dataframe.label.values)
  28. if config.test_text_path is None:
  29. offset = int(train_dataframe.shape[0] * 0.80)
  30. test_dataframe = train_dataframe[offset:]
  31. train_dataframe = train_dataframe[:offset]
  32. else:
  33. test_dataframe = pd.read_csv(config.test_text_path)
  34. test_dataframe.dropna(subset=['text'], inplace=True)
  35. test_dataframe = test_dataframe.sample(frac=1).reset_index(drop=True)
  36. test_dataframe.label = test_dataframe.label.apply(lambda x: config.classes.index(x))
  37. if config.validation_text_path is None:
  38. offset = int(train_dataframe.shape[0] * 0.90)
  39. validation_dataframe = train_dataframe[offset:]
  40. train_dataframe = train_dataframe[:offset]
  41. else:
  42. validation_dataframe = pd.read_csv(config.validation_text_path)
  43. validation_dataframe.dropna(subset=['text'], inplace=True)
  44. validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True)
  45. validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x))
  46. unlabeled_dataframe = None
  47. if config.has_unlabeled_data:
  48. unlabeled_dataframe = pd.read_csv(config.unlabeled_text_path)
  49. unlabeled_dataframe = unlabeled_dataframe.sample(frac=1).reset_index(drop=True)
  50. return train_dataframe, test_dataframe, validation_dataframe, unlabeled_dataframe
  51. def get_tokenizer(config):
  52. if 'bigbird' in config.text_encoder_model:
  53. tokenizer = BigBirdTokenizer.from_pretrained(config.text_tokenizer)
  54. elif 'xlnet' in config.text_encoder_model:
  55. tokenizer = XLNetTokenizer.from_pretrained(config.text_tokenizer)
  56. else:
  57. tokenizer = BertTokenizer.from_pretrained(config.text_tokenizer)
  58. return tokenizer
  59. def build_loaders(config, dataframe, mode):
  60. if 'resnet' in config.image_model_name:
  61. transforms = get_transforms(config, mode=mode)
  62. else:
  63. transforms = ViTFeatureExtractor.from_pretrained(config.image_model_name)
  64. dataset = config.DatasetLoader(
  65. config,
  66. dataframe["image"].values,
  67. dataframe["text"].values,
  68. dataframe["label"].values,
  69. tokenizer=get_tokenizer(config),
  70. transforms=transforms,
  71. mode=mode,
  72. )
  73. if mode != 'train':
  74. dataloader = DataLoader(
  75. dataset,
  76. batch_size=config.batch_size,
  77. num_workers=config.num_workers,
  78. pin_memory=False,
  79. shuffle=False,
  80. )
  81. else:
  82. dataloader = DataLoader(
  83. dataset,
  84. batch_size=config.batch_size,
  85. num_workers=config.num_workers,
  86. pin_memory=True,
  87. shuffle=True,
  88. )
  89. return dataloader
  90. def get_class_weights(y):
  91. class_weights = compute_class_weight('balanced', np.unique(y), y)
  92. return class_weights