
DataLoader_new.py 6.4KB

''' Data Loader class for training iteration '''
import random
import numpy as np
import torch
from torch.autograd import Variable
import transformer.Constants as Constants
import logging
import pickle
import json  # needed by _buildIndex, which parses the user-vector file


class Options(object):

    def __init__(self):
        # data options.
        # train file path.
        self.train_data = 'data/weibo/cascade.txt'
        # test file path.
        self.test_data = 'data/weibo/cascadetest.txt'
        # user-index dictionaries.
        self.u2idx_dict = 'data/weibo/u2vec.pickle'
        self.idx2u_dict = 'data/weibo/vec2u.pickle'
        # pretrained user-vector file.
        self.user_data = 'data/weibo/users_limited.txt'
        # save path.
        self.save_path = ''
        self.batch_size = 32


class DataLoader(object):
    ''' For data iteration '''

    def __init__(
            self, use_valid=False, load_dict=True, cuda=True, batch_size=32, shuffle=True, test=False):
        self.options = Options()
        self.options.batch_size = batch_size
        self._u2idx = {}
        self._idx2u = []
        self._u2vec = {}  # user id -> pretrained user vector (filled in _buildIndex)
        self.use_valid = use_valid
        if not load_dict:
            self._buildIndex()
            with open(self.options.u2idx_dict, 'wb') as handle:
                pickle.dump(self._u2idx, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.options.idx2u_dict, 'wb') as handle:
                pickle.dump(self._idx2u, handle, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            with open(self.options.u2idx_dict, 'rb') as handle:
                self._u2idx = pickle.load(handle)
            with open(self.options.idx2u_dict, 'rb') as handle:
                self._idx2u = pickle.load(handle)
            self.user_size = len(self._u2idx)
        self._train_cascades = self._readFromFile(self.options.train_data)
        self._test_cascades = self._readFromFile(self.options.test_data)
        self.train_size = len(self._train_cascades)
        self.test_size = len(self._test_cascades)
        print("user size:%d" % (self.user_size - 2))  # minus pad and eos
        print("training set size:%d testing set size:%d" % (self.train_size, self.test_size))
        self.cuda = cuda
        self.test = test
        if not self.use_valid:
            self._n_batch = int(np.ceil(len(self._train_cascades) / batch_size))
        else:
            self._n_batch = int(np.ceil(len(self._test_cascades) / batch_size))
        self._batch_size = self.options.batch_size
        self._iter_count = 0
        self._need_shuffle = shuffle
        if self._need_shuffle:
            random.shuffle(self._train_cascades)

    def _buildIndex(self):
        # compute an index of the users that appear at least once in the training and testing cascades.
        opts = self.options

        train_user_set = set()
        test_user_set = set()

        lineid = 0
        for line in open(opts.train_data):
            lineid += 1
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    print(line)
                    print(chunk)
                    print(lineid)
                    continue
                train_user_set.add(user)

        for line in open(opts.test_data):
            if len(line.strip()) == 0:
                continue
            chunks = line.strip().split()
            for chunk in chunks:
                user, timestamp = chunk.split(',')
                test_user_set.add(user)

        user_set = train_user_set | test_user_set

        # index 0 is the padding token, index 1 is the end-of-sequence token.
        pos = 0
        self._u2idx['<blank>'] = pos
        self._idx2u.append('<blank>')
        pos += 1
        self._u2idx['</s>'] = pos
        self._idx2u.append('</s>')
        pos += 1

        # load the pretrained user vectors (one JSON record per line).
        user_data = [json.loads(d) for d in open(opts.user_data, "rt").readlines()]
        user_dic = {}
        for user_vector, user_id in user_data:
            user_dic[user_id] = user_vector
        for user in user_set:
            # assign each user an index and attach its pretrained vector,
            # so that _readFromFile can map users to indices.
            self._u2idx[user] = pos
            self._idx2u.append(user)
            pos += 1
            self._u2vec[user] = user_dic[user]
        opts.user_size = len(user_set) + 2
        self.user_size = len(user_set) + 2
        print("user_size : %d" % (opts.user_size))

    def _readFromFile(self, filename):
        """read all cascades from the training or testing files. """
        t_cascades = []
        for line in open(filename):
            if len(line.strip()) == 0:
                continue
            userlist = []
            chunks = line.strip().split()
            for chunk in chunks:
                try:
                    user, timestamp = chunk.split(',')
                except ValueError:
                    print(chunk)
                    continue
                if user in self._u2idx:
                    userlist.append(self._u2idx[user])
                    # if len(userlist) > 500:
                    #     break
                    # uncomment these lines if your GPU memory is not enough
            if len(userlist) > 1:
                userlist.append(Constants.EOS)
                t_cascades.append(userlist)
        return t_cascades
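
    # Illustrative sketch of the expected cascade file format (an assumption
    # inferred from the whitespace/comma splitting above; one cascade per line,
    # with made-up "user,timestamp" pairs):
    #   1000234,1355368800 1000517,1355369400 1000998,1355370000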

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        return self._n_batch

    def next(self):
        ''' Get the next batch '''

        def pad_to_longest(insts):
            ''' Pad the instance to the max seq length in batch '''
            max_len = max(len(inst) for inst in insts)

            inst_data = np.array([
                inst + [Constants.PAD] * (max_len - len(inst))
                for inst in insts])

            inst_data_tensor = Variable(
                torch.LongTensor(inst_data), volatile=self.test)

            if self.cuda:
                inst_data_tensor = inst_data_tensor.cuda()
            return inst_data_tensor

        if self._iter_count < self._n_batch:
            batch_idx = self._iter_count
            self._iter_count += 1

            start_idx = batch_idx * self._batch_size
            end_idx = (batch_idx + 1) * self._batch_size

            if not self.use_valid:
                seq_insts = self._train_cascades[start_idx:end_idx]
            else:
                seq_insts = self._test_cascades[start_idx:end_idx]
            seq_data = pad_to_longest(seq_insts)
            # print('???')
            # print(seq_data.data)
            # print(seq_data.size())
            return seq_data
        else:
            if self._need_shuffle:
                random.shuffle(self._train_cascades)
                # random.shuffle(self._test_cascades)
            self._iter_count = 0
            raise StopIteration()
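

# Minimal usage sketch (assumption: the pickled index dictionaries and the cascade
# files referenced in Options already exist at the configured paths).
if __name__ == '__main__':
    loader = DataLoader(load_dict=True, cuda=False, batch_size=32, shuffle=True)
    for seq_data in loader:
        # seq_data is a LongTensor of shape (batch_size, max_cascade_len),
        # padded with Constants.PAD up to the longest cascade in the batch.
        print(seq_data.size())
        break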