DataLoader_Feat2.py
''' Data Loader class for training iteration '''
import random
import numpy as np
import torch
from torch.autograd import Variable
import transformer.Constants as Constants
import logging
import pickle
import json


class Options(object):

    def __init__(self):
        print("HERE in dataloader feat 2")
        # data options.
        # train file path.
        self.train_data = 'data/weibo2/cascade.txt'
        # test file path.
        self.test_data = 'data/weibo2/cascadetest.txt'
        self.u2idx_dict = 'data/weibo2/u2idx.pickle'
        self.idx2u_dict = 'data/weibo2/idx2u.pickle'
        self.idx2vec_dict = 'data/weibo2/idx2vec.pickle'
        self.user_data = 'data/weibo2/users_limited.txt'
        # save path.
        self.save_path = ''
        self.batch_size = 32


class DataLoader(object):
    ''' For data iteration '''

    def __init__(
            self, use_valid=False, load_dict=True, cuda=True, batch_size=32, shuffle=True, test=False):
        self.options = Options()
        self.options.batch_size = batch_size
        self._u2idx = {}
        self._idx2u = []
        self._idx2vec = {}
        self.use_valid = use_valid
        if not load_dict:
            self._buildIndex()
            with open(self.options.u2idx_dict, 'wb') as handle:
                pickle.dump(self._u2idx, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.options.idx2u_dict, 'wb') as handle:
                pickle.dump(self._idx2u, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.options.idx2vec_dict, 'wb') as handle:
                pickle.dump(self._idx2vec, handle, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            with open(self.options.u2idx_dict, 'rb') as handle:
                self._u2idx = pickle.load(handle)
            with open(self.options.idx2u_dict, 'rb') as handle:
                self._idx2u = pickle.load(handle)
            self.user_size = len(self._u2idx)
        self._train_cascades = self._readFromFile(self.options.train_data)
        self._test_cascades = self._readFromFile(self.options.test_data)
        self.train_size = len(self._train_cascades)
        self.test_size = len(self._test_cascades)
        print("user size:%d" % (self.user_size - 2))  # minus pad and eos
        print("training set size:%d testing set size:%d" % (self.train_size, self.test_size))
        self.cuda = cuda
        self.test = test
        if not self.use_valid:
            self._n_batch = int(np.ceil(len(self._train_cascades) / batch_size))
        else:
            self._n_batch = int(np.ceil(len(self._test_cascades) / batch_size))
        self._batch_size = self.options.batch_size
        self._iter_count = 0
        self._need_shuffle = shuffle
        if self._need_shuffle:
            random.shuffle(self._train_cascades)

    def _buildIndex(self):
        # compute an index of the users that appear at least once in the training and testing cascades.
        opts = self.options
        train_user_set = set()
        test_user_set = set()
        lineid = 0
        for line in open(opts.train_data):
            lineid += 1
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    # report malformed "user,timestamp" chunks instead of crashing
                    print(line)
                    print(chunk)
                    print(lineid)
                train_user_set.add(user)
        for line in open(opts.test_data):
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                user, timestamp = chunk.split(',')
                test_user_set.add(user)
        user_set = train_user_set | test_user_set
        # indices 0 and 1 are reserved for the padding and end-of-sequence tokens
        pos = 0
        self._u2idx['<blank>'] = pos
        self._idx2u.append('<blank>')
        self._idx2vec[pos] = [pos] * 8
        pos += 1
        self._u2idx['</s>'] = pos
        self._idx2u.append('</s>')
        self._idx2vec[pos] = [pos] * 8
        pos += 1
        user_data = [json.loads(d) for d in open(opts.user_data, "rt").readlines()]
        user_dic = {}
        for user_vector, user_id in user_data:
            user_dic[user_id] = user_vector
        for user in user_set:
            self._u2idx[user] = pos
            self._idx2vec[pos] = user_dic[int(user)]
            self._idx2u.append(user)
            pos += 1
        opts.user_size = len(user_set) + 2
        self.user_size = len(user_set) + 2
        print("user_size : %d" % (opts.user_size))
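
    # Note on users_limited.txt (inferred from the parsing above, not from a documented
    # spec): each line is expected to be a JSON array of the form [user_vector, user_id],
    # where user_vector is the per-user feature vector (the <blank> and </s> placeholders
    # above use 8-dimensional vectors) and user_id matches the integer ids used in the
    # cascade files.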

    def _readFromFile(self, filename):
        """Read all cascades from the training or testing file."""
        t_cascades = []
        for line in open(filename):
            if len(line.strip()) == 0:
                continue
            userlist = []
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    print(chunk)
                if user in self._u2idx:
                    userlist.append(self._u2idx[user])
                # if len(userlist) > 500:
                #     break
                # uncomment these lines if your GPU memory is not enough
            if len(userlist) > 1:
                # userlist.append(Constants.EOS)
                t_cascades.append(userlist)
        return t_cascades
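
    # Note on the cascade file format (inferred from chunk.split(',') above): each
    # non-empty line is one cascade, written as whitespace-separated "user,timestamp"
    # pairs, e.g. "1001,0 1002,132 1005,460". Only users present in the index are kept,
    # and cascades with fewer than two indexed users are dropped.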

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self._n_batch

    def next(self):
        ''' Get the next batch '''

        def pad_to_longest(insts):
            ''' Pad the instances to the max seq length in the batch '''
            max_len = max(len(inst) for inst in insts)
            lens = [len(inst) for inst in insts]
            src = []
            trg_data = []
            for i in range(len(insts)):
                seq_len = lens[i]
                # the first half of each cascade is the source; EOS + second half is the target
                src.append(insts[i][:int(seq_len / 2)])
                trg_data.append([Constants.EOS] + insts[i][int(seq_len / 2):seq_len])
            max_len_src = max(len(inst) for inst in src)
            src = np.array([inst + [Constants.PAD] * (max_len_src - len(inst))
                            for inst in src])
            src = Variable(torch.LongTensor(src), volatile=self.test)  # legacy (pre-0.4) PyTorch API
            if self.cuda:
                src = src.cuda(0)
            max_len_trg = max(len(inst) for inst in trg_data)
            trg_data = np.array([inst + [Constants.PAD] * (max_len_trg - len(inst))
                                 for inst in trg_data])
            trg_data = Variable(torch.LongTensor(trg_data), volatile=self.test)
            if self.cuda:
                trg_data = trg_data.cuda(0)
            trg = trg_data[:, :-1]
            trg_y = trg_data[:, 1:]
            src_mask = (src != Constants.PAD).unsqueeze(-2)
            trg_mask = (trg_y != Constants.PAD).unsqueeze(-2)
            src_lengths = [int(x / 2) for x in lens]
            trg_lengths = [int(x / 2) + 1 for x in lens]  # same values as the original even/odd branches, kept as ints
            src_lengths = np.array(src_lengths)
            trg_lengths = np.array(trg_lengths)
            # sort source and target sides by their own descending lengths
            reverse_idx_src = np.argsort(-src_lengths)
            reverse_idx_trg = np.argsort(-trg_lengths)
            src = src[reverse_idx_src]
            src_mask = src_mask[reverse_idx_src]
            src_lengths = src_lengths[reverse_idx_src]
            trg = trg[reverse_idx_trg]
            trg_mask = trg_mask[reverse_idx_trg]
            trg_lengths = trg_lengths[reverse_idx_trg]
            trg_y = trg_y[reverse_idx_trg]  # was `try_y = ...`, which left trg_y unsorted and misaligned with trg
            # print(insts[10],src[10],src_mask[10],src_lengths[10],trg[10],trg_mask[10],trg_lengths[10])
            return src, src_mask, src_lengths, trg, trg_y, trg_mask, trg_lengths
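
        # Worked example of the split above (illustrative only): a cascade [5, 8, 2, 9, 7]
        # of length 5 gives src = [5, 8] (first half) and trg_data = [EOS, 2, 9, 7]
        # (EOS followed by the second half); after padding, trg = trg_data[:, :-1] and
        # trg_y = trg_data[:, 1:] form the usual teacher-forcing input/target pair.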

        if self._iter_count < self._n_batch:
            batch_idx = self._iter_count
            self._iter_count += 1
            start_idx = batch_idx * self._batch_size
            end_idx = (batch_idx + 1) * self._batch_size
            if not self.use_valid:
                seq_insts = self._train_cascades[start_idx:end_idx]
            else:
                seq_insts = self._test_cascades[start_idx:end_idx]
            src, src_mask, src_lengths, trg, trg_y, trg_mask, trg_lengths = pad_to_longest(seq_insts)
            # print('???')
            # print(seq_data.data)
            # print(seq_data.size())
            return src, src_mask, src_lengths, trg, trg_y, trg_mask, trg_lengths
        else:
            if self._need_shuffle:
                random.shuffle(self._train_cascades)
                # random.shuffle(self._test_cascades)
            self._iter_count = 0
            raise StopIteration()
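

# Minimal usage sketch (illustrative, not part of the original pipeline): iterate one
# epoch of batches on CPU. Assumes the data/weibo2 files referenced in Options exist and
# that the u2idx/idx2u pickles were produced by an earlier run with load_dict=False.
if __name__ == '__main__':
    loader = DataLoader(load_dict=True, cuda=False, batch_size=32, shuffle=True)
    for src, src_mask, src_lengths, trg, trg_y, trg_mask, trg_lengths in loader:
        # src: (batch, max_src_len) LongTensor of user indices, PAD-padded
        # trg / trg_y: shifted decoder input / target sequences, PAD-padded
        print(src.size(), trg.size())
        break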