
DataLoader_Feat.py

''' Data Loader class for training iteration '''
import random
import numpy as np
import torch
import transformer.Constants as Constants
import logging
import pickle
import json

# GPU device index passed to .cuda() when moving batches onto the GPU.
CUDA = 0


class Options(object):

    def __init__(self):
        add = "/media/external_3TB/3TB/ramezani/pmoini/Trial/data_d8"
        # data options.
        # train file path.
        self.train_data = f'{add}/cascade.txt'
        print(self.train_data)
        # test file path.
        self.test_data = f'{add}/cascadetest.txt'
        self.u2idx_dict = f'{add}/u2idx.pickle'
        self.idx2u_dict = f'{add}/idx2u.pickle'
        self.idx2vec_dict = f'{add}/idx2vec.pickle'
        self.user_data = f'{add}/users_limited.txt'
        # save path.
        self.save_path = ''
        self.batch_size = 16
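

# Expected on-disk formats (inferred from the parsing code below; the sample
# values are illustrative only):
#   cascade.txt / cascadetest.txt : one cascade per line, whitespace-separated
#       "user,timestamp" pairs, e.g.  102,1325 87,1377 590,1402
#   users_limited.txt             : one JSON record per line of the form
#       [user_vector, user_id], e.g.  [[0.1, -0.3, 0.7], 102]
#   u2idx/idx2u/idx2vec .pickle   : index dictionaries written by DataLoader
#       on a first run with load_dict=False and reloaded afterwards.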


class DataLoader(object):
    ''' For data iteration '''

    def __init__(
            self, use_valid=False, load_dict=True, cuda=True,
            batch_size=32, shuffle=True, test=False):
        self.options = Options()
        self.options.batch_size = batch_size
        self._u2idx = {}
        self._idx2u = []
        self._idx2vec = {}
        self.use_valid = use_valid
        if not load_dict:
            self._buildIndex()
            with open(self.options.u2idx_dict, 'wb') as handle:
                pickle.dump(self._u2idx, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.options.idx2u_dict, 'wb') as handle:
                pickle.dump(self._idx2u, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.options.idx2vec_dict, 'wb') as handle:
                pickle.dump(self._idx2vec, handle, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            with open(self.options.u2idx_dict, 'rb') as handle:
                self._u2idx = pickle.load(handle)
            with open(self.options.idx2u_dict, 'rb') as handle:
                self._idx2u = pickle.load(handle)
            self.user_size = len(self._u2idx)
        self._train_cascades = self._readFromFile(self.options.train_data)
        self._test_cascades = self._readFromFile(self.options.test_data)
        self.train_size = len(self._train_cascades)
        self.test_size = len(self._test_cascades)
        print("user size:%d" % (self.user_size - 2))  # minus pad and eos
        print("training set size:%d testing set size:%d" % (self.train_size, self.test_size))
        self.cuda = cuda
        self.test = test
        if not self.use_valid:
            self._n_batch = int(np.ceil(len(self._train_cascades) / batch_size))
        else:
            self._n_batch = int(np.ceil(len(self._test_cascades) / batch_size))
        self._batch_size = self.options.batch_size
        self._iter_count = 0
        self._need_shuffle = shuffle
        if self._need_shuffle:
            random.shuffle(self._train_cascades)
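
    # Note: a first run with load_dict=False builds the user index from the
    # raw cascade files and pickles it; later runs can pass load_dict=True to
    # reuse the saved dictionaries instead of rebuilding them.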

    def _buildIndex(self):
        # compute an index of the users that appear at least once in the
        # training and testing cascades.
        opts = self.options
        train_user_set = set()
        test_user_set = set()
        lineid = 0
        for line in open(opts.train_data):
            lineid += 1
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    # report the malformed chunk and skip it, rather than
                    # reusing a stale `user` from the previous iteration.
                    print(line)
                    print(chunk)
                    print(lineid)
                    continue
                train_user_set.add(user)
        for line in open(opts.test_data):
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                user, timestamp = chunk.split(',')
                test_user_set.add(user)
        user_set = train_user_set | test_user_set
        # positions 0 and 1 are reserved for the padding and end-of-sequence
        # symbols; real users start at index 2.
        pos = 0
        self._u2idx['<blank>'] = pos
        self._idx2u.append('<blank>')
        self._idx2vec[pos] = [pos] * 8
        pos += 1
        self._u2idx['</s>'] = pos
        self._idx2u.append('</s>')
        self._idx2vec[pos] = [pos] * 8
        pos += 1
        user_data = [json.loads(d) for d in open(opts.user_data, "rt").readlines()]
        user_dic = {}
        for user_vector, user_id in user_data:
            user_dic[user_id] = user_vector
        for user in user_set:
            self._u2idx[user] = pos
            self._idx2vec[pos] = user_dic[int(user)]
            self._idx2u.append(user)
            pos += 1
        opts.user_size = len(user_set) + 2
        self.user_size = len(user_set) + 2
        print("user_size : %d" % (opts.user_size))
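
    # _idx2vec maps index -> user feature vector (loaded from
    # users_limited.txt); the two reserved slots get placeholder vectors of
    # length 8, which suggests the real feature vectors are 8-dimensional too.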

    def _readFromFile(self, filename):
        """Read all cascades from the training or testing file."""
        t_cascades = []
        for line in open(filename):
            if len(line.strip()) == 0:
                continue
            userlist = []
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    print(chunk)
                    continue
                # _u2idx is keyed by the raw user strings built in
                # _buildIndex, so look the user up as a string.
                if user in self._u2idx:
                    userlist.append(self._u2idx[user])
                # if len(userlist) > 500:
                #     break
                # uncomment these lines if your GPU memory is not enough
            # keep only cascades that contain at least two users.
            if len(userlist) > 1:
                # userlist.append(Constants.EOS)
                t_cascades.append(userlist)
        return t_cascades

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self._n_batch

    def next(self):
        ''' Get the next batch '''
        def pad_to_longest(insts):
            ''' Pad each instance to the max sequence length in the batch '''
            max_len = max(len(inst) for inst in insts)
            inst_data = np.array([
                inst + [Constants.PAD] * (max_len - len(inst))
                for inst in insts])
            # torch.autograd.Variable and its volatile flag are deprecated;
            # wrap evaluation in `with torch.no_grad():` at the call site
            # instead when self.test is True.
            inst_data_tensor = torch.LongTensor(inst_data)
            if self.cuda:
                inst_data_tensor = inst_data_tensor.cuda(CUDA)
            return inst_data_tensor

        if self._iter_count < self._n_batch:
            batch_idx = self._iter_count
            self._iter_count += 1
            start_idx = batch_idx * self._batch_size
            end_idx = (batch_idx + 1) * self._batch_size
            if not self.use_valid:
                seq_insts = self._train_cascades[start_idx:end_idx]
            else:
                seq_insts = self._test_cascades[start_idx:end_idx]
            seq_data = pad_to_longest(seq_insts)
            return seq_data
        else:
            # epoch finished: reshuffle the training cascades and reset the
            # counter before signalling the end of iteration.
            if self._need_shuffle:
                random.shuffle(self._train_cascades)
            self._iter_count = 0
            raise StopIteration()
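

# --- Minimal usage sketch (illustrative, not part of the original file) -----
# Assumes the pickled index files referenced in Options already exist, i.e.
# the loader was run once with load_dict=False to build and save them.
if __name__ == '__main__':
    loader = DataLoader(load_dict=True, cuda=False, batch_size=4, shuffle=False)
    print(len(loader))  # number of batches per epoch
    for batch in loader:
        # each batch is a LongTensor of shape (batch_size, max_cascade_len),
        # right-padded with Constants.PAD.
        print(batch.size())
        break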