Adapted to the MovieLens dataset

loader.py

import json
import random
import torch
import numpy as np
import pickle
import codecs
import re
import os
import datetime
import tqdm
import pandas as pd
# Convert a list of ids to a one-hot dictionary: key is the id (int), value is a
# 1 x N one-hot vector (LongTensor). Elements of the input list are strings, and
# the ids are assumed to lie in [0, len(id_list)).
def to_onehot_dict(id_list):
    onehot_dict = {}
    length = len(id_list)
    for index, element in enumerate(id_list):
        vector = torch.zeros(1, length).long()
        element = int(element)
        vector[:, element] = 1.0
        onehot_dict[element] = vector
    return onehot_dict
def load_list(fname):
    list_ = []
    with open(fname, encoding="utf-8") as f:
        for line in f.readlines():
            list_.append(line.strip())
    return list_


# Merge the keys of two dictionaries.
def merge_key(dict1, dict2):
    res = {**dict1, **dict2}
    return res


# Merge the values of shared keys in place (e.g., warm and cold item interactions).
def merge_value(dict1, dict2):
    for key, value in dict2.items():
        if key in dict1.keys():
            # If deduplicated with list(set(dict1[key] + value)), the final number of MovieLens-1M interactions is 1000205.
            new_value = dict1[key] + value
            dict1[key] = new_value
        else:
            print('Unexpected key.')


def count_values(d):
    count_val = 0
    for key, value in d.items():
        count_val += len(value)
    return count_val


def construct_dictionary(user_list, total_dict):
    sub_dict = {}
    for i in range(len(user_list)):
        sub_dict[str(user_list[i])] = total_dict[str(user_list[i])]
    return sub_dict
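# Illustrative sketch (not part of the original loader): to_onehot_dict assumes the ids
# are contiguous integers starting at 0, e.g.
#     to_onehot_dict(['0', '1', '2'])
# returns {0: tensor([[1, 0, 0]]), 1: tensor([[0, 1, 0]]), 2: tensor([[0, 0, 1]])},
# where each value is a 1 x 3 LongTensor.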
class Preprocess(object):
    """
    Preprocess the training, validation and test data.
    Generate the episode-style data.
    """
    def __init__(self, opt):
        self.batch_size = opt["batch_size"]
        self.opt = opt
        # warm (training) data ratio
        self.train_ratio = opt['train_ratio']
        self.valid_ratio = opt['valid_ratio']
        self.test_ratio = 1 - self.train_ratio - self.valid_ratio
        self.dataset_path = opt["data_dir"]
        self.support_size = opt['support_size']
        self.query_size = opt['query_size']
        self.max_len = opt['max_len']
        # save the one-hot dimension lengths
        uf_dim, if_dim = self.preprocess(self.dataset_path)
        self.uf_dim = uf_dim
        self.if_dim = if_dim
    def preprocess(self, dataset_path):
        """ Preprocess the data and convert to ids. """
        # Create the training, validation and test datasets.
        print('Create training, validation and test data from scratch!')
        with open('./{}/interaction_dict_x.json'.format(dataset_path), 'r', encoding='utf-8') as f:
            inter_dict_x = json.loads(f.read())
        with open('./{}/interaction_dict_y.json'.format(dataset_path), 'r', encoding='utf-8') as f:
            inter_dict_y = json.loads(f.read())
        print('The size of total interactions is %d.' % (count_values(inter_dict_x)))  # 42346
        assert count_values(inter_dict_x) == count_values(inter_dict_y)
        with open('./{}/user_list.json'.format(dataset_path), 'r', encoding='utf-8') as f:
            userids = json.loads(f.read())
        with open('./{}/item_list.json'.format(dataset_path), 'r', encoding='utf-8') as f:
            itemids = json.loads(f.read())
        # userids = list(inter_dict_x.keys())
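        # Expected structure of the input files (an assumption inferred from how they are
        # used below, not documented in the original code):
        #   interaction_dict_x.json: {"user_id": [item_id, ...], ...}
        #   interaction_dict_y.json: {"user_id": [rating, ...], ...}   aligned with _x
        #   user_list.json / item_list.json: lists of integer ids (as strings) in [0, N)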
        random.shuffle(userids)
        warm_user_size = int(len(userids) * self.train_ratio)
        valid_user_size = int(len(userids) * self.valid_ratio)
        warm_users = userids[:warm_user_size]
        valid_users = userids[warm_user_size:warm_user_size + valid_user_size]
        cold_users = userids[warm_user_size + valid_user_size:]
        assert len(userids) == len(warm_users) + len(valid_users) + len(cold_users)
        # Construct the training data dicts.
        training_dict_x = construct_dictionary(warm_users, inter_dict_x)
        training_dict_y = construct_dictionary(warm_users, inter_dict_y)
        # Collect the items seen during training, so that items new to cold users can be
        # removed from the test data later.
        item_set = set()
        for i in training_dict_x.values():
            i = set(i)
            item_set = item_set.union(i)
        # Construct the one-hot dictionaries.
        user_dict = to_onehot_dict(userids)
        # Only items contained in the full item list are encoded.
        item_dict = to_onehot_dict(itemids)
        # The validation data is not used further here, so it is not processed for now.
        valid_dict_x = construct_dictionary(valid_users, inter_dict_x)
        valid_dict_y = construct_dictionary(valid_users, inter_dict_y)
        assert count_values(valid_dict_x) == count_values(valid_dict_y)
        test_dict_x = construct_dictionary(cold_users, inter_dict_x)
        test_dict_y = construct_dictionary(cold_users, inter_dict_y)
        assert count_values(test_dict_x) == count_values(test_dict_y)
        print('Before deleting new items, the test data has %d interactions.' % (count_values(test_dict_x)))
        # Delete items that never appear in the training data from the test data.
        unseen_count = 0
        for key, value in test_dict_x.items():
            assert len(value) == len(test_dict_y[key])
            unseen_item_index = [index for index, i in enumerate(value) if i not in item_set]
            unseen_count += len(unseen_item_index)
            if len(unseen_item_index) == 0:
                continue
            else:
                new_value_x = [element for index, element in enumerate(value) if index not in unseen_item_index]
                new_value_y = [test_dict_y[key][index] for index, element in enumerate(value) if index not in unseen_item_index]
                test_dict_x[key] = new_value_x
                test_dict_y[key] = new_value_y
        print('After deleting new items, the test data has %d interactions.' % (count_values(test_dict_x)))
        assert count_values(test_dict_x) == count_values(test_dict_y)
        print('The number of total unseen interactions is %d.' % (unseen_count))
        pickle.dump(training_dict_x, open("{}/training_dict_x_{:2f}.pkl".format(dataset_path, self.train_ratio), "wb"))
        pickle.dump(training_dict_y, open("{}/training_dict_y_{:2f}.pkl".format(dataset_path, self.train_ratio), "wb"))
        pickle.dump(valid_dict_x, open("{}/valid_dict_x_{:2f}.pkl".format(dataset_path, self.valid_ratio), "wb"))
        pickle.dump(valid_dict_y, open("{}/valid_dict_y_{:2f}.pkl".format(dataset_path, self.valid_ratio), "wb"))
        pickle.dump(test_dict_x, open("{}/test_dict_x_{:2f}.pkl".format(dataset_path, self.test_ratio), "wb"))
        pickle.dump(test_dict_y, open("{}/test_dict_y_{:2f}.pkl".format(dataset_path, self.test_ratio), "wb"))
        def generate_episodes(dict_x, dict_y, category, support_size, query_size, max_len, dir="log"):
            idx = 0
            if not os.path.exists("{}/{}/{}".format(dataset_path, category, dir)):
                os.makedirs("{}/{}/{}".format(dataset_path, category, dir))
                os.makedirs("{}/{}/{}".format(dataset_path, category, "evidence"))
            for _, user_id in enumerate(dict_x.keys()):
                u_id = int(user_id)
                seen_movie_len = len(dict_x[str(u_id)])
                indices = list(range(seen_movie_len))
                # Filter out users (i.e., tasks) with too few or too many interactions.
                if seen_movie_len < (support_size + query_size) or seen_movie_len > max_len:
                    continue
                random.shuffle(indices)
                tmp_x = np.array(dict_x[str(u_id)])
                tmp_y = np.array(dict_y[str(u_id)])
                # Build the support set: concatenate the item and user one-hot vectors.
                support_x_app = None
                for m_id in tmp_x[indices[:support_size]]:
                    m_id = int(m_id)
                    tmp_x_converted = torch.cat((item_dict[m_id], user_dict[u_id]), 1)
                    if support_x_app is None:
                        support_x_app = tmp_x_converted
                    else:
                        support_x_app = torch.cat((support_x_app, tmp_x_converted), 0)
                # Build the query set from the remaining interactions.
                query_x_app = None
                for m_id in tmp_x[indices[support_size:]]:
                    m_id = int(m_id)
                    tmp_x_converted = torch.cat((item_dict[m_id], user_dict[u_id]), 1)
                    if query_x_app is None:
                        query_x_app = tmp_x_converted
                    else:
                        query_x_app = torch.cat((query_x_app, tmp_x_converted), 0)
                support_y_app = torch.FloatTensor(tmp_y[indices[:support_size]])
                query_y_app = torch.FloatTensor(tmp_y[indices[support_size:]])
                pickle.dump(support_x_app, open("{}/{}/{}/supp_x_{}.pkl".format(dataset_path, category, dir, idx), "wb"))
                pickle.dump(support_y_app, open("{}/{}/{}/supp_y_{}.pkl".format(dataset_path, category, dir, idx), "wb"))
                pickle.dump(query_x_app, open("{}/{}/{}/query_x_{}.pkl".format(dataset_path, category, dir, idx), "wb"))
                pickle.dump(query_y_app, open("{}/{}/{}/query_y_{}.pkl".format(dataset_path, category, dir, idx), "wb"))
                # Used for evidence candidate selection.
                with open("{}/{}/{}/supp_x_{}_u_m_ids.txt".format(dataset_path, category, "evidence", idx), "w") as f:
                    for m_id in tmp_x[indices[:support_size]]:
                        f.write("{}\t{}\n".format(u_id, m_id))
                with open("{}/{}/{}/query_x_{}_u_m_ids.txt".format(dataset_path, category, "evidence", idx), "w") as f:
                    for m_id in tmp_x[indices[support_size:]]:
                        f.write("{}\t{}\n".format(u_id, m_id))
                idx += 1
  181. print("Generate eposide data for training.")
  182. generate_episodes(training_dict_x, training_dict_y, "training", self.support_size, self.query_size, self.max_len)
  183. print("Generate eposide data for validation.")
  184. generate_episodes(valid_dict_x, valid_dict_y, "validation", self.support_size, self.query_size, self.max_len)
  185. print("Generate eposide data for testing.")
  186. generate_episodes(test_dict_x, test_dict_y, "testing", self.support_size, self.query_size, self.max_len)
  187. return len(userids), len(itemids)
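
# Minimal usage sketch (not from the original repository): the option keys below are the
# ones read in Preprocess.__init__, but the concrete values and the "movielens" data
# directory are illustrative assumptions.
if __name__ == "__main__":
    opt = {
        "data_dir": "movielens",   # hypothetical folder containing the four JSON files
        "batch_size": 32,
        "train_ratio": 0.7,
        "valid_ratio": 0.1,
        "support_size": 20,
        "query_size": 10,
        "max_len": 100,
    }
    prep = Preprocess(opt)
    print("user one-hot dim: %d, item one-hot dim: %d" % (prep.uf_dim, prep.if_dim))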