MeLU project implemented with learn2learn (l2l), using MetaSGD instead of MAML.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_generation.py 6.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import re
  2. import os
  3. import json
  4. import torch
  5. import numpy as np
  6. import random
  7. import pickle
  8. from tqdm import tqdm
  9. from options import states
  10. from dataset import movielens_1m
  11. def item_converting(row, rate_list, genre_list, director_list, actor_list):
  12. rate_idx = torch.tensor([[rate_list.index(str(row['rate']))]]).long()
  13. genre_idx = torch.zeros(1, 25).long()
  14. for genre in str(row['genre']).split(", "):
  15. idx = genre_list.index(genre)
  16. genre_idx[0, idx] = 1
  17. director_idx = torch.zeros(1, 2186).long()
  18. for director in str(row['director']).split(", "):
  19. idx = director_list.index(re.sub(r'\([^()]*\)', '', director))
  20. director_idx[0, idx] = 1
  21. actor_idx = torch.zeros(1, 8030).long()
  22. for actor in str(row['actors']).split(", "):
  23. idx = actor_list.index(actor)
  24. actor_idx[0, idx] = 1
  25. return torch.cat((rate_idx, genre_idx, director_idx, actor_idx), 1)
  26. def user_converting(row, gender_list, age_list, occupation_list, zipcode_list):
  27. gender_idx = torch.tensor([[gender_list.index(str(row['gender']))]]).long()
  28. age_idx = torch.tensor([[age_list.index(str(row['age']))]]).long()
  29. occupation_idx = torch.tensor([[occupation_list.index(str(row['occupation_code']))]]).long()
  30. zip_idx = torch.tensor([[zipcode_list.index(str(row['zip'])[:5])]]).long()
  31. return torch.cat((gender_idx, age_idx, occupation_idx, zip_idx), 1)
  32. def load_list(fname):
  33. list_ = []
  34. with open(fname, encoding="utf-8") as f:
  35. for line in f.readlines():
  36. list_.append(line.strip())
  37. return list_
  38. def generate(master_path):
  39. dataset_path = "movielens/ml-1m"
  40. rate_list = load_list("{}/m_rate.txt".format(dataset_path))
  41. genre_list = load_list("{}/m_genre.txt".format(dataset_path))
  42. actor_list = load_list("{}/m_actor.txt".format(dataset_path))
  43. director_list = load_list("{}/m_director.txt".format(dataset_path))
  44. gender_list = load_list("{}/m_gender.txt".format(dataset_path))
  45. age_list = load_list("{}/m_age.txt".format(dataset_path))
  46. occupation_list = load_list("{}/m_occupation.txt".format(dataset_path))
  47. zipcode_list = load_list("{}/m_zipcode.txt".format(dataset_path))
  48. if not os.path.exists("{}/warm_state/".format(master_path)):
  49. for state in states:
  50. os.mkdir("{}/{}/".format(master_path, state))
  51. if not os.path.exists("{}/log/".format(master_path)):
  52. os.mkdir("{}/log/".format(master_path))
  53. dataset = movielens_1m()
  54. # hashmap for item information
  55. if not os.path.exists("{}/m_movie_dict.pkl".format(master_path)):
  56. movie_dict = {}
  57. for idx, row in dataset.item_data.iterrows():
  58. m_info = item_converting(row, rate_list, genre_list, director_list, actor_list)
  59. movie_dict[row['movie_id']] = m_info
  60. pickle.dump(movie_dict, open("{}/m_movie_dict.pkl".format(master_path), "wb"))
  61. else:
  62. movie_dict = pickle.load(open("{}/m_movie_dict.pkl".format(master_path), "rb"))
  63. # hashmap for user profile
  64. if not os.path.exists("{}/m_user_dict.pkl".format(master_path)):
  65. user_dict = {}
  66. for idx, row in dataset.user_data.iterrows():
  67. u_info = user_converting(row, gender_list, age_list, occupation_list, zipcode_list)
  68. user_dict[row['user_id']] = u_info
  69. pickle.dump(user_dict, open("{}/m_user_dict.pkl".format(master_path), "wb"))
  70. else:
  71. user_dict = pickle.load(open("{}/m_user_dict.pkl".format(master_path), "rb"))
  72. for state in states:
  73. idx = 0
  74. if not os.path.exists("{}/{}/{}".format(master_path, "log", state)):
  75. os.mkdir("{}/{}/{}".format(master_path, "log", state))
  76. with open("{}/{}.json".format(dataset_path, state), encoding="utf-8") as f:
  77. dataset = json.loads(f.read())
  78. with open("{}/{}_y.json".format(dataset_path, state), encoding="utf-8") as f:
  79. dataset_y = json.loads(f.read())
  80. for _, user_id in tqdm(enumerate(dataset.keys())):
  81. u_id = int(user_id)
  82. seen_movie_len = len(dataset[str(u_id)])
  83. indices = list(range(seen_movie_len))
  84. if seen_movie_len < 13 or seen_movie_len > 100:
  85. continue
  86. random.shuffle(indices)
  87. tmp_x = np.array(dataset[str(u_id)])
  88. tmp_y = np.array(dataset_y[str(u_id)])
  89. support_x_app = None
  90. for m_id in tmp_x[indices[:-10]]:
  91. m_id = int(m_id)
  92. tmp_x_converted = torch.cat((movie_dict[m_id], user_dict[u_id]), 1)
  93. try:
  94. support_x_app = torch.cat((support_x_app, tmp_x_converted), 0)
  95. except:
  96. support_x_app = tmp_x_converted
  97. query_x_app = None
  98. for m_id in tmp_x[indices[-10:]]:
  99. m_id = int(m_id)
  100. u_id = int(user_id)
  101. tmp_x_converted = torch.cat((movie_dict[m_id], user_dict[u_id]), 1)
  102. try:
  103. query_x_app = torch.cat((query_x_app, tmp_x_converted), 0)
  104. except:
  105. query_x_app = tmp_x_converted
  106. support_y_app = torch.FloatTensor(tmp_y[indices[:-10]])
  107. query_y_app = torch.FloatTensor(tmp_y[indices[-10:]])
  108. pickle.dump(support_x_app, open("{}/{}/supp_x_{}.pkl".format(master_path, state, idx), "wb"))
  109. pickle.dump(support_y_app, open("{}/{}/supp_y_{}.pkl".format(master_path, state, idx), "wb"))
  110. pickle.dump(query_x_app, open("{}/{}/query_x_{}.pkl".format(master_path, state, idx), "wb"))
  111. pickle.dump(query_y_app, open("{}/{}/query_y_{}.pkl".format(master_path, state, idx), "wb"))
  112. with open("{}/log/{}/supp_x_{}_u_m_ids.txt".format(master_path, state, idx), "w") as f:
  113. for m_id in tmp_x[indices[:-10]]:
  114. f.write("{}\t{}\n".format(u_id, m_id))
  115. with open("{}/log/{}/query_x_{}_u_m_ids.txt".format(master_path, state, idx), "w") as f:
  116. for m_id in tmp_x[indices[-10:]]:
  117. f.write("{}\t{}\n".format(u_id, m_id))
  118. idx += 1