
# DataLoader_f.py

import random
import numpy as np
import torch
import transformer.Constants as Constants
import logging
import pickle
import json
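
# Assumption (not stated in this file): transformer.Constants is expected to
# provide the integer PAD and EOS ids used below; this module does not define
# them itself.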

class Options(object):

    def __init__(self):
        # data options.
        # train file path.
        self.train_data = 'data/weibo3/cascade.txt'
        # test file path.
        self.test_data = 'data/weibo3/cascadetest.txt'
        self.u2vec_dict = 'data/weibo3/u2vec.pickle'
        self.idx2u_dict = 'data/weibo3/idx2u.pickle'
        self.user_data = 'data/weibo3/users_limited.txt'
        # save path.
        self.save_path = ''
        self.batch_size = 32
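
# Note: DataLoader below overwrites Options.batch_size with its own
# batch_size argument, so the value above is only a default.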

class DataLoader(object):
    ''' For data iteration '''

    def __init__(
            self, use_valid=False, load_dict=True, cuda=True,
            batch_size=32, shuffle=True, test=False):
        self.options = Options()
        self.options.batch_size = batch_size
        self._u2vec = {}
        self._idx2u = []
        # bug fix: idx2vec is written by _buildIndex but was never
        # initialized anywhere, so create it here.
        self.idx2vec = {}
        self.use_valid = use_valid
        if not load_dict:
            self._buildIndex()
            with open(self.options.u2vec_dict, 'wb') as handle:
                pickle.dump(self._u2vec, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.options.idx2u_dict, 'wb') as handle:
                pickle.dump(self._idx2u, handle, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            with open(self.options.u2vec_dict, 'rb') as handle:
                self._u2vec = pickle.load(handle)
            with open(self.options.idx2u_dict, 'rb') as handle:
                self._idx2u = pickle.load(handle)
        self.user_size = len(self._u2vec)
        self._train_cascades = self._readFromFile(self.options.train_data)
        self._test_cascades = self._readFromFile(self.options.test_data)
        self.train_size = len(self._train_cascades)
        self.test_size = len(self._test_cascades)
        print("user size:%d" % (self.user_size - 2))  # minus pad and eos
        print("training set size:%d testing set size:%d"
              % (self.train_size, self.test_size))
        self.cuda = cuda
        self.test = test
        if not self.use_valid:
            self._n_batch = int(np.ceil(len(self._train_cascades) / batch_size))
        else:
            self._n_batch = int(np.ceil(len(self._test_cascades) / batch_size))
        self._batch_size = self.options.batch_size
        self._iter_count = 0
        self._need_shuffle = shuffle
        if self._need_shuffle:
            random.shuffle(self._train_cascades)
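
    # Typical flow, inferred from the flags above: a first run with
    # load_dict=False builds the dictionaries from the raw files and pickles
    # them; later runs with load_dict=True just reload the pickles.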

    def _buildIndex(self):
        # compute an index of the users that appear at least once in the
        # training and testing cascades.
        opts = self.options
        train_user_set = set()
        test_user_set = set()
        lineid = 0
        for line in open(opts.train_data):
            lineid += 1
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    # report the malformed "user,timestamp" chunk and its line.
                    print(line)
                    print(chunk)
                    print(lineid)
                train_user_set.add(user)
        for line in open(opts.test_data):
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                user, timestamp = chunk.split(',')
                test_user_set.add(user)
        user_set = train_user_set | test_user_set
        # reserve index 0 for padding and index 1 for end-of-sequence.
        pos = 0
        self._u2vec['<blank>'] = pos
        self._idx2u.append('<blank>')
        self.idx2vec[pos] = [pos] * 8
        pos += 1
        self._u2vec['</s>'] = pos
        self._idx2u.append('</s>')
        self.idx2vec[pos] = [pos] * 8
        pos += 1
        # map each user id to its feature vector from users_limited.txt.
        user_data = [json.loads(d) for d in open(opts.user_data, "rt").readlines()]
        user_dic = {}
        for user_vector, user_id in user_data:
            user_dic[user_id] = user_vector
        for user in user_set:
            self._u2vec[user] = user_dic[user]
            self._idx2u.append(user)
            pos += 1
        opts.user_size = len(user_set) + 2
        self.user_size = len(user_set) + 2
        print("user_size : %d" % (opts.user_size))
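
    # Assumed input formats, inferred from the parsing above rather than from
    # project documentation: each line of users_limited.txt is a JSON array
    # [user_vector, user_id]; each line of a cascade file is a sequence of
    # space-separated "user,timestamp" chunks, e.g. (hypothetical ids):
    #   101,1325 102,1388 103,1401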

    def _readFromFile(self, filename):
        """ read all cascades from the training or testing files. """
        t_cascades = []
        for line in open(filename):
            if len(line.strip()) == 0:
                continue
            userlist = []
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    print(chunk)
                userlist.append(user)
                # if len(userlist) > 500:
                #     break
                # uncomment these lines if your GPU memory is not enough
            if len(userlist) > 1:
                userlist.append(Constants.EOS)
                t_cascades.append(userlist)
        return t_cascades
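
    # Caveat, not in the original comments: pad_to_longest in next() builds a
    # LongTensor from these lists, so the user fields in the cascade files are
    # assumed to already be integer indices; raw string ids would first need
    # an id-to-index mapping.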

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self._n_batch

    def next(self):
        ''' Get the next batch '''

        def pad_to_longest(insts):
            ''' Pad the instances to the max seq length in the batch '''
            max_len = max(len(inst) for inst in insts)
            inst_data = np.array([
                inst + [Constants.PAD] * (max_len - len(inst))
                for inst in insts])
            # print(inst_data)  # debug: dump the padded batch
            # torch.autograd.Variable and its volatile flag are deprecated;
            # return a plain LongTensor and let callers wrap evaluation in
            # torch.no_grad() instead.
            inst_data_tensor = torch.LongTensor(inst_data)
            if self.cuda:
                inst_data_tensor = inst_data_tensor.cuda()
            return inst_data_tensor

        if self._iter_count < self._n_batch:
            batch_idx = self._iter_count
            self._iter_count += 1
            start_idx = batch_idx * self._batch_size
            end_idx = (batch_idx + 1) * self._batch_size
            if not self.use_valid:
                seq_insts = self._train_cascades[start_idx:end_idx]
            else:
                seq_insts = self._test_cascades[start_idx:end_idx]
            seq_data = pad_to_longest(seq_insts)
            # print(seq_data.data)
            # print(seq_data.size())
            return seq_data
        else:
            if self._need_shuffle:
                random.shuffle(self._train_cascades)
                # random.shuffle(self._test_cascades)
            self._iter_count = 0
            raise StopIteration()
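
# A minimal usage sketch, not part of the original module: it assumes the
# weibo3 cascade files and the pickled u2vec/idx2u dictionaries already exist
# on disk, and that the cascade files store integer user indices as noted
# above.
if __name__ == '__main__':
    loader = DataLoader(load_dict=True, cuda=False, batch_size=32)
    for batch in loader:
        # each batch is a (batch_size, max_len) LongTensor of padded cascades
        print(batch.size())
        break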