A PyTorch implementation of the paper "CSI: A Hybrid Deep Model for Fake News Detection".
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data.py 596B

123456789101112131415161718192021
  1. import os
  2. import numpy as np
  3. import pickle
  4. import torch
  5. from torch.utils.data import Dataset
  class CSIDataset(Dataset):
      """Map-style dataset over the pickled samples stored in one directory.

      Every regular file directly inside ``pkl_dir`` is unpickled into a dict
      that must provide the keys ``'x_capture'``, ``'x_score'`` and
      ``'label'``. Samples whose ``x_score`` has length 0 along its first axis
      are dropped. All samples are loaded eagerly into memory at construction.

      Args:
          pkl_dir: Path of the directory holding the pickled samples. A
              trailing path separator is optional.

      Raises:
          KeyError: if a loaded sample lacks the ``'x_score'`` key.
          pickle.UnpicklingError: if a file in ``pkl_dir`` is not a pickle.
      """

      def __init__(self, pkl_dir):
          super().__init__()
          self.samples = []
          # Sort so sample indices are reproducible: os.listdir() returns
          # entries in arbitrary, filesystem-dependent order.
          for name in sorted(os.listdir(pkl_dir)):
              # os.path.join works whether or not pkl_dir ends with a separator.
              path = os.path.join(pkl_dir, name)
              # Skip sub-directories and other non-regular entries; open()
              # would raise on them.
              if not os.path.isfile(path):
                  continue
              # SECURITY: pickle.load can execute arbitrary code. Only point
              # this dataset at trusted, locally generated files.
              with open(path, 'rb') as fh:
                  sample = pickle.load(fh)
              # Drop samples with no score entries. x_score is presumably a
              # NumPy array or tensor (it exposes .shape) -- TODO confirm
              # against the preprocessing script.
              if sample['x_score'].shape[0] > 0:
                  self.samples.append(sample)

      def __getitem__(self, idx):
          """Return the ``(x_capture, x_score, label)`` triple of sample ``idx``."""
          data = self.samples[idx]
          return data['x_capture'], data['x_score'], data['label']

      def __len__(self):
          """Return the number of samples kept after filtering."""
          return len(self.samples)