A PyTorch implementation of the paper "CSI: A Hybrid Deep Model for Fake News Detection".
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

2 years ago
123456789101112131415161718192021
  1. import os
  2. import numpy as np
  3. import pickle
  4. import torch
  5. from torch.utils.data import Dataset
  6. class CSIDataset(Dataset):
  7. def __init__(self, pkl_dir):
  8. super().__init__()
  9. self.samples = []
  10. for file in os.listdir(pkl_dir):
  11. sample = pickle.load(open(pkl_dir + file, 'rb'))
  12. if sample['x_score'].shape[0] > 0:
  13. self.samples += [sample]
  14. def __getitem__(self, idx):
  15. data = self.samples[idx]
  16. return data['x_capture'], data['x_score'], data['label']
  17. def __len__(self):
  18. return len(self.samples)