
DataLoader.py

''' Data Loader class for training iteration '''
import random
import numpy as np
import torch
from torch.autograd import Variable
import transformer.Constants as Constants
import pickle

# index of the CUDA device that batch tensors are moved to
CUDA = 1


class Options(object):
    def __init__(self):
        # data options.
        # train file path.
        self.train_data = 'data/gossipcop/cascade.txt'
        # test file path.
        self.test_data = 'data/gossipcop/cascadetest.txt'
        # pickled user-index dictionaries.
        self.u2idx_dict = 'data/gossipcop/u2idx.pickle'
        self.idx2u_dict = 'data/gossipcop/idx2u.pickle'
        # save path.
        self.save_path = ''
        self.batch_size = 32
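
# The cascade files are expected to contain one diffusion cascade per line,
# written as whitespace-separated "user,timestamp" pairs; this is the format
# that _buildIndex and _readFromFile parse below. An illustrative (made-up)
# line:
#
#   u42,1529164800 u7,1529165100 u133,1529167200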


class DataLoader(object):
    ''' For data iteration '''

    def __init__(self, use_valid=False, load_dict=True, cuda=True,
                 batch_size=32, shuffle=True, test=False):
        self.options = Options()
        self.options.batch_size = batch_size
        self._u2idx = {}
        self._idx2u = []
        self.use_valid = use_valid
        if not load_dict:
            # build the user index from the cascade files and cache it on disk.
            self._buildIndex()
            with open(self.options.u2idx_dict, 'wb') as handle:
                pickle.dump(self._u2idx, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.options.idx2u_dict, 'wb') as handle:
                pickle.dump(self._idx2u, handle, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            # reuse the cached user index.
            with open(self.options.u2idx_dict, 'rb') as handle:
                self._u2idx = pickle.load(handle)
            with open(self.options.idx2u_dict, 'rb') as handle:
                self._idx2u = pickle.load(handle)
        self.user_size = len(self._u2idx)
        self._train_cascades = self._readFromFile(self.options.train_data)
        self._test_cascades = self._readFromFile(self.options.test_data)
        self.train_size = len(self._train_cascades)
        self.test_size = len(self._test_cascades)
        print("user size:%d" % (self.user_size - 2))  # minus pad and eos
        print("training set size:%d testing set size:%d" % (self.train_size, self.test_size))
        self.cuda = cuda
        self.test = test
        if not self.use_valid:
            self._n_batch = int(np.ceil(len(self._train_cascades) / batch_size))
        else:
            self._n_batch = int(np.ceil(len(self._test_cascades) / batch_size))
        self._batch_size = self.options.batch_size
        self._iter_count = 0
        self._need_shuffle = shuffle
        if self._need_shuffle:
            random.shuffle(self._train_cascades)

    def _buildIndex(self):
        # compute an index of the users that appear at least once in the
        # training or testing cascades.
        opts = self.options
        train_user_set = set()
        test_user_set = set()
        lineid = 0
        for line in open(opts.train_data):
            lineid += 1
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    # report the malformed chunk and skip it, rather than
                    # silently reusing the user from the previous chunk.
                    print(line)
                    print(chunk)
                    print(lineid)
                    continue
                train_user_set.add(user)
        for line in open(opts.test_data):
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                user, timestamp = chunk.split(',')
                test_user_set.add(user)
        user_set = train_user_set | test_user_set
        pos = 0
        self._u2idx['<blank>'] = pos  # padding token
        self._idx2u.append('<blank>')
        pos += 1
        self._u2idx['</s>'] = pos  # end-of-sequence token
        self._idx2u.append('</s>')
        pos += 1
        for user in user_set:
            self._u2idx[user] = pos
            self._idx2u.append(user)
            pos += 1
        opts.user_size = len(user_set) + 2
        self.user_size = len(user_set) + 2
        print("user_size : %d" % (opts.user_size))
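    # Resulting index layout: position 0 is the '<blank>' padding token,
    # position 1 is the '</s>' end-of-sequence token, and real users start
    # at position 2, e.g. {'<blank>': 0, '</s>': 1, 'some_user': 2, ...}.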

    def _readFromFile(self, filename):
        """Read all cascades from the training or testing file."""
        t_cascades = []
        for line in open(filename):
            if len(line.strip()) == 0:
                continue
            userlist = []
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    # skip malformed chunks instead of reusing a stale user.
                    print(chunk)
                    continue
                if user in self._u2idx:
                    userlist.append(self._u2idx[user])
                # if len(userlist) > 500:
                #     break
                # uncomment these lines if your GPU memory is not enough
            if len(userlist) > 1:
                userlist.append(Constants.EOS)
                t_cascades.append(userlist)
        return t_cascades
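    # Each returned cascade is a list of user indices terminated by the EOS
    # index, e.g. [2, 17, 5, Constants.EOS]; cascades with fewer than two
    # known users are dropped.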

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self._n_batch

    def next(self):
        ''' Get the next batch '''
        def pad_to_longest(insts):
            ''' Pad the instances to the max sequence length in the batch '''
            max_len = max(len(inst) for inst in insts)
            inst_data = np.array([
                inst + [Constants.PAD] * (max_len - len(inst))
                for inst in insts])
            # volatile is the legacy (pre-0.4) PyTorch way of disabling
            # gradient tracking at test time.
            inst_data_tensor = Variable(
                torch.LongTensor(inst_data), volatile=self.test)
            if self.cuda:
                inst_data_tensor = inst_data_tensor.cuda(CUDA)
            return inst_data_tensor

        if self._iter_count < self._n_batch:
            batch_idx = self._iter_count
            self._iter_count += 1
            start_idx = batch_idx * self._batch_size
            end_idx = (batch_idx + 1) * self._batch_size
            if not self.use_valid:
                seq_insts = self._train_cascades[start_idx:end_idx]
            else:
                seq_insts = self._test_cascades[start_idx:end_idx]
            seq_data = pad_to_longest(seq_insts)
            return seq_data
        else:
            if self._need_shuffle:
                random.shuffle(self._train_cascades)
            self._iter_count = 0
            raise StopIteration()
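

# A minimal usage sketch (an assumption for illustration: the pickled
# u2idx/idx2u dictionaries and the cascade files referenced in Options
# already exist on disk; pass cuda=False if no GPU is available at
# device index CUDA):
if __name__ == '__main__':
    loader = DataLoader(load_dict=True, cuda=False, batch_size=32)
    for batch in loader:
        # each batch is a (batch_size, max_len) LongTensor of user indices
        print(batch.size())
        break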