import json
import os
import pickle
import random
import re

import numpy as np
import torch
from tqdm import tqdm

from options import states
from dataset import movielens_1m
-
-
def item_converting(row, rate_list, genre_list, director_list, actor_list):
    """Encode one movie row as a 1 x F long tensor of item features.

    Layout along dim 1: [rate index | genre multi-hot | director multi-hot |
    actor multi-hot].  Multi-hot widths are derived from the vocabulary list
    lengths (the original hard-coded 25/2186/8030, which silently breaks if
    the vocabulary files change).

    Args:
        row: mapping/Series with 'rate', 'genre', 'director', 'actors' fields;
            multi-valued fields are ", "-separated strings.
        rate_list, genre_list, director_list, actor_list: vocabulary lists;
            position in the list defines the feature index.

    Returns:
        torch.LongTensor of shape (1, 1 + len(genre_list) +
        len(director_list) + len(actor_list)).

    Raises:
        ValueError: if a field value is not present in its vocabulary.
    """
    rate_idx = torch.tensor([[rate_list.index(str(row['rate']))]]).long()

    genre_idx = torch.zeros(1, len(genre_list)).long()
    for genre in str(row['genre']).split(", "):
        genre_idx[0, genre_list.index(genre)] = 1

    director_idx = torch.zeros(1, len(director_list)).long()
    for director in str(row['director']).split(", "):
        # Strip a trailing parenthesised qualifier (e.g. "(I)") before lookup,
        # since the vocabulary file stores bare director names.
        director_idx[0, director_list.index(re.sub(r'\([^()]*\)', '', director))] = 1

    actor_idx = torch.zeros(1, len(actor_list)).long()
    for actor in str(row['actors']).split(", "):
        actor_idx[0, actor_list.index(actor)] = 1

    return torch.cat((rate_idx, genre_idx, director_idx, actor_idx), 1)
-
-
def user_converting(row, gender_list, age_list, occupation_list, zipcode_list):
    """Encode one user row as a 1 x 4 long tensor of vocabulary indices.

    Order along dim 1: gender, age, occupation code, 5-digit zip prefix.
    Raises ValueError if any value is missing from its vocabulary list.
    """
    lookups = (
        (gender_list, str(row['gender'])),
        (age_list, str(row['age'])),
        (occupation_list, str(row['occupation_code'])),
        # Only the leading 5 characters of the zip code are significant.
        (zipcode_list, str(row['zip'])[:5]),
    )
    indices = [vocab.index(value) for vocab, value in lookups]
    return torch.tensor([indices]).long()
-
-
def load_list(fname):
    """Read a UTF-8 vocabulary file and return its lines, stripped, in order.

    Args:
        fname: path to a text file with one vocabulary entry per line.

    Returns:
        list[str]: each line with surrounding whitespace removed (blank
        lines become empty strings, matching the original behavior).
    """
    with open(fname, encoding="utf-8") as f:
        # Iterate the file lazily instead of materializing readlines().
        return [line.strip() for line in f]
-
-
def generate(master_path):
    """Materialize per-user support/query episodes for every evaluation state.

    For each user in a state's interaction json with between 13 and 100 rated
    movies, the item index list is shuffled; the last 10 items become the
    query set and the remainder the support set.  Each episode is pickled as
    four files (supp_x/supp_y/query_x/query_y) under ``master_path/<state>/``
    and the matching (user_id, movie_id) pairs are logged under
    ``master_path/log/<state>/``.

    Fixes vs. the original: directories are created with
    ``os.makedirs(exist_ok=True)`` (the original only probed ``warm_state/``
    and could crash on a partially created tree); pickle files are opened via
    ``with`` instead of leaking inline ``open()`` handles; tensors are built
    with a single ``torch.cat`` instead of a bare ``except:`` growth trick;
    ``movielens_1m()`` is only loaded when a feature cache is missing.
    """
    dataset_path = "movielens/ml-1m"
    rate_list = load_list("{}/m_rate.txt".format(dataset_path))
    genre_list = load_list("{}/m_genre.txt".format(dataset_path))
    actor_list = load_list("{}/m_actor.txt".format(dataset_path))
    director_list = load_list("{}/m_director.txt".format(dataset_path))
    gender_list = load_list("{}/m_gender.txt".format(dataset_path))
    age_list = load_list("{}/m_age.txt".format(dataset_path))
    occupation_list = load_list("{}/m_occupation.txt".format(dataset_path))
    zipcode_list = load_list("{}/m_zipcode.txt".format(dataset_path))

    # Create every state directory idempotently.
    for state in states:
        os.makedirs("{}/{}/".format(master_path, state), exist_ok=True)
    os.makedirs("{}/log/".format(master_path), exist_ok=True)

    movie_dict_path = "{}/m_movie_dict.pkl".format(master_path)
    user_dict_path = "{}/m_user_dict.pkl".format(master_path)
    raw = None  # movielens_1m() is loaded lazily, only if a cache is missing

    # Hashmap for item information (movie_id -> 1 x F feature tensor).
    if os.path.exists(movie_dict_path):
        with open(movie_dict_path, "rb") as f:
            movie_dict = pickle.load(f)
    else:
        raw = movielens_1m()
        movie_dict = {}
        for _, row in raw.item_data.iterrows():
            movie_dict[row['movie_id']] = item_converting(
                row, rate_list, genre_list, director_list, actor_list)
        with open(movie_dict_path, "wb") as f:
            pickle.dump(movie_dict, f)

    # Hashmap for user profiles (user_id -> 1 x 4 index tensor).
    if os.path.exists(user_dict_path):
        with open(user_dict_path, "rb") as f:
            user_dict = pickle.load(f)
    else:
        if raw is None:
            raw = movielens_1m()
        user_dict = {}
        for _, row in raw.user_data.iterrows():
            user_dict[row['user_id']] = user_converting(
                row, gender_list, age_list, occupation_list, zipcode_list)
        with open(user_dict_path, "wb") as f:
            pickle.dump(user_dict, f)

    for state in states:
        idx = 0
        os.makedirs("{}/log/{}".format(master_path, state), exist_ok=True)
        with open("{}/{}.json".format(dataset_path, state), encoding="utf-8") as f:
            interactions = json.load(f)
        with open("{}/{}_y.json".format(dataset_path, state), encoding="utf-8") as f:
            targets = json.load(f)

        for user_id in tqdm(interactions.keys()):
            u_id = int(user_id)
            seen_movie_len = len(interactions[str(u_id)])
            # Keep only users with enough history for >=3 support + 10 query
            # items; drop heavy users (>100) as in the original protocol.
            if seen_movie_len < 13 or seen_movie_len > 100:
                continue

            indices = list(range(seen_movie_len))
            random.shuffle(indices)
            tmp_x = np.array(interactions[str(u_id)])
            tmp_y = np.array(targets[str(u_id)])
            support_sel = indices[:-10]
            query_sel = indices[-10:]

            # One torch.cat over an explicit list per episode, instead of
            # growing a tensor inside a bare try/except on every iteration.
            support_x_app = torch.cat(
                [torch.cat((movie_dict[int(m_id)], user_dict[u_id]), 1)
                 for m_id in tmp_x[support_sel]], 0)
            query_x_app = torch.cat(
                [torch.cat((movie_dict[int(m_id)], user_dict[u_id]), 1)
                 for m_id in tmp_x[query_sel]], 0)
            support_y_app = torch.FloatTensor(tmp_y[support_sel])
            query_y_app = torch.FloatTensor(tmp_y[query_sel])

            # Persist the episode; `with` guarantees handles are closed.
            episode = (
                ("supp_x", support_x_app), ("supp_y", support_y_app),
                ("query_x", query_x_app), ("query_y", query_y_app),
            )
            for tag, tensor in episode:
                with open("{}/{}/{}_{}.pkl".format(master_path, state, tag, idx), "wb") as f:
                    pickle.dump(tensor, f)
            with open("{}/log/{}/supp_x_{}_u_m_ids.txt".format(master_path, state, idx), "w") as f:
                for m_id in tmp_x[support_sel]:
                    f.write("{}\t{}\n".format(u_id, m_id))
            with open("{}/log/{}/query_x_{}_u_m_ids.txt".format(master_path, state, idx), "w") as f:
                for m_id in tmp_x[query_sel]:
                    f.write("{}\t{}\n".format(u_id, m_id))
            idx += 1