Official implementation of the Fake News Revealer paper.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loaders.py 2.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import numpy as np
  2. import pandas as pd
  3. from sklearn.utils import compute_class_weight
  4. from torch.utils.data import DataLoader
  5. def make_dfs(config):
  6. train_dataframe = pd.read_csv(config.train_text_path)
  7. train_dataframe.dropna(subset=['text'], inplace=True)
  8. train_dataframe = train_dataframe.sample(frac=1).reset_index(drop=True)
  9. train_dataframe.label = train_dataframe.label.apply(lambda x: config.classes.index(x))
  10. config.class_weights = get_class_weights(train_dataframe.label.values)
  11. if config.test_text_path is None:
  12. offset = int(train_dataframe.shape[0] * 0.80)
  13. test_dataframe = train_dataframe[offset:]
  14. train_dataframe = train_dataframe[:offset]
  15. else:
  16. test_dataframe = pd.read_csv(config.test_text_path)
  17. test_dataframe.dropna(subset=['text'], inplace=True)
  18. test_dataframe.label = test_dataframe.label.apply(lambda x: config.classes.index(x))
  19. if config.validation_text_path is None:
  20. offset = int(train_dataframe.shape[0] * 0.90)
  21. validation_dataframe = train_dataframe[offset:]
  22. train_dataframe = train_dataframe[:offset]
  23. else:
  24. validation_dataframe = pd.read_csv(config.validation_text_path)
  25. validation_dataframe.dropna(subset=['text'], inplace=True)
  26. validation_dataframe = validation_dataframe.sample(frac=1).reset_index(drop=True)
  27. validation_dataframe.label = validation_dataframe.label.apply(lambda x: config.classes.index(x))
  28. return train_dataframe, test_dataframe, validation_dataframe
  29. def build_loaders(config, dataframe, mode):
  30. dataset = config.DatasetLoader(config, dataframe=dataframe, mode=mode)
  31. if mode != 'train':
  32. dataloader = DataLoader(
  33. dataset,
  34. batch_size=config.batch_size // 2,
  35. num_workers=config.num_workers // 2,
  36. pin_memory=False,
  37. shuffle=False,
  38. )
  39. else:
  40. dataloader = DataLoader(
  41. dataset,
  42. batch_size=config.batch_size,
  43. num_workers=config.num_workers,
  44. pin_memory=True,
  45. shuffle=True,
  46. )
  47. return dataloader
  48. def get_class_weights(y):
  49. class_weights = compute_class_weight('balanced', np.unique(y), y)
  50. return class_weights