Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loaders.py 2.3KB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import numpy as np
  2. import pandas as pd
  3. from sklearn.utils import compute_class_weight
  4. from torch.utils.data import DataLoader
  def _read_split(config, csv_path):
      """Load one CSV split, drop rows without text, shuffle, and encode labels.

      The ``label`` column is replaced by each label's position in
      ``config.classes``.
      """
      def encode(label):
          return config.classes.index(label)

      frame = pd.read_csv(csv_path).dropna(subset=['text'])
      frame = frame.sample(frac=1).reset_index(drop=True)
      frame.label = frame.label.apply(encode)
      return frame


  def make_dfs(config):
      """Build the train, test and validation dataframes described by ``config``.

      The training CSV is always read from ``config.train_text_path``. When
      ``config.test_text_path`` is ``None`` the last 20% of the shuffled
      training rows become the test set; when ``config.validation_text_path``
      is ``None`` the last 10% of the remaining training rows become the
      validation set. Otherwise the corresponding CSV is loaded and prepared
      the same way as the training file.

      Side effect: ``config.class_weights`` is set from the labels of the whole
      training file, before any hold-out rows are sliced off.

      Returns:
          A ``(train, test, validation)`` tuple of dataframes. Frames produced
          by slicing keep the row index of the shuffled training frame.
      """
      train_df = _read_split(config, config.train_text_path)
      config.class_weights = get_class_weights(train_df.label.values)

      if config.test_text_path is not None:
          test_df = _read_split(config, config.test_text_path)
      else:
          # No dedicated test file: hold out the trailing 20% of the train rows.
          cut = int(len(train_df) * 0.80)
          test_df, train_df = train_df.iloc[cut:], train_df.iloc[:cut]

      if config.validation_text_path is not None:
          validation_df = _read_split(config, config.validation_text_path)
      else:
          # No dedicated validation file: hold out the trailing 10% of what is left.
          cut = int(len(train_df) * 0.90)
          validation_df, train_df = train_df.iloc[cut:], train_df.iloc[:cut]

      return train_df, test_df, validation_df
  def build_loaders(config, dataframe, mode):
      """Wrap ``dataframe`` in ``config.DatasetLoader`` and return a ``DataLoader``.

      Args:
          config: object providing ``DatasetLoader``, ``batch_size`` and
              ``num_workers``.
          dataframe: data handed to ``config.DatasetLoader``.
          mode: ``'train'`` selects the shuffled, pinned-memory loader; any
              other value selects the unshuffled loader.

      Returns:
          The configured ``torch.utils.data.DataLoader``.
      """
      dataset = config.DatasetLoader(config, dataframe=dataframe, mode=mode)

      if mode == 'train':
          return DataLoader(
              dataset,
              batch_size=config.batch_size,
              num_workers=config.num_workers,
              pin_memory=True,
              shuffle=True,
          )

      # Non-training loaders: 'lime' uses half the configured batch size,
      # every other mode processes one sample per batch.
      if mode == 'lime':
          eval_batch_size = config.batch_size // 2
      else:
          eval_batch_size = 1

      return DataLoader(
          dataset,
          batch_size=eval_batch_size,
          num_workers=config.num_workers // 2,
          pin_memory=False,
          shuffle=False,
      )
  def get_class_weights(y):
      """Return 'balanced' class weights for the label array ``y``.

      Args:
          y: array-like of class labels, one entry per sample.

      Returns:
          The array produced by ``compute_class_weight``, with one weight per
          distinct label, ordered like ``np.unique(y)`` (sorted labels).
      """
      # `classes` and `y` are keyword-only in current scikit-learn releases;
      # passing them positionally raises TypeError. Keywords also work on
      # older releases, so this form is safe across versions.
      return compute_class_weight(class_weight='balanced', classes=np.unique(y), y=y)