Sequential Recommendation for cold-start users with meta transitional learning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 5.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import sys
  2. import copy
  3. import torch
  4. import random
  5. import numpy as np
  6. from collections import defaultdict, Counter
  7. from multiprocessing import Process, Queue
  8. # sampler for batch generation
  9. def random_neq(l, r, s):
  10. t = np.random.randint(l, r)
  11. while t in s:
  12. t = np.random.randint(l, r)
  13. return t
  14. # user = random.choice(list(user_train.keys()))
  15. # item = random.choice(user_train[user])
  16. #
  17. # while item in s:
  18. # user = random.choice(list(user_train.keys()))
  19. # item = random.choice(user_train[user])
  20. # return item
  21. def trans_to_cuda(variable):
  22. if torch.cuda.is_available():
  23. return variable.cuda()
  24. else:
  25. return variable
  26. def trans_to_cpu(variable):
  27. if torch.cuda.is_available():
  28. return variable.cpu()
  29. else:
  30. return variable
  31. # train/val/test data generation
  32. def data_load(fname, num_sample):
  33. usernum = 0
  34. itemnum = 0
  35. user_train = defaultdict(list)
  36. # assume user/item index starting from 1
  37. f = open('/home/maheri/metaTL/data/%s/%s_train.csv' % (fname, fname), 'r')
  38. for line in f:
  39. u, i, t = line.rstrip().split('\t')
  40. u = int(u)
  41. i = int(i)
  42. usernum = max(u, usernum)
  43. itemnum = max(i, itemnum)
  44. user_train[u].append(i)
  45. f.close()
  46. # read in new users for testing
  47. user_input_test = {}
  48. user_input_valid = {}
  49. user_valid = {}
  50. user_test = {}
  51. User_test_new = defaultdict(list)
  52. f = open('/home/maheri/metaTL/data/%s/%s_test_new_user.csv' % (fname, fname), 'r')
  53. for line in f:
  54. u, i, t = line.rstrip().split('\t')
  55. u = int(u)
  56. i = int(i)
  57. User_test_new[u].append(i)
  58. f.close()
  59. for user in User_test_new:
  60. if len(User_test_new[user]) > num_sample:
  61. if random.random()<0.3:
  62. user_input_valid[user] = User_test_new[user][:num_sample]
  63. user_valid[user] = []
  64. user_valid[user].append(User_test_new[user][num_sample])
  65. else:
  66. user_input_test[user] = User_test_new[user][:num_sample]
  67. user_test[user] = []
  68. user_test[user].append(User_test_new[user][num_sample])
  69. return [user_train, usernum, itemnum, user_input_test, user_test, user_input_valid, user_valid]
  70. class DataLoader(object):
  71. def __init__(self, user_train, user_test, itemnum, parameter):
  72. self.curr_rel_idx = 0
  73. self.bs = parameter['batch_size']
  74. self.maxlen = parameter['K']
  75. self.valid_user = []
  76. for u in user_train:
  77. if len(user_train[u]) < self.maxlen or len(user_test[u]) < 1: continue
  78. self.valid_user.append(u)
  79. self.num_tris = len(self.valid_user)
  80. self.train = user_train
  81. self.test = user_test
  82. self.itemnum = itemnum
  83. # if parameter['number_of_neg']:
  84. # self.number_of_neg = parameter['number_of_neg']
  85. # else:
  86. # self.number_of_neg = 5
  87. def next_one_on_eval(self):
  88. if self.curr_tri_idx == self.num_tris:
  89. return "EOT", "EOT"
  90. u = self.valid_user[self.curr_tri_idx]
  91. self.curr_tri_idx += 1
  92. seq = np.zeros([self.maxlen], dtype=np.int32)
  93. pos = np.zeros([self.maxlen - 1], dtype=np.int32)
  94. # neg = np.zeros([self.maxlen * self.number_of_neg], dtype=np.int32)
  95. neg = np.zeros([self.maxlen - 1], dtype=np.int32)
  96. idx = self.maxlen - 1
  97. ts = set(self.train[u])
  98. for i in reversed(self.train[u]):
  99. seq[idx] = i
  100. if idx > 0:
  101. pos[idx - 1] = i
  102. if i != 0: neg[idx - 1] = random_neq(1, self.itemnum + 1, ts)
  103. idx -= 1
  104. if idx == -1: break
  105. # for i in range(len(neg)):
  106. # neg[i] = random_neq(1, self.itemnum + 1, ts,self.train)
  107. curr_rel = u
  108. support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
  109. for idx in range(self.maxlen-1):
  110. support_triples.append([seq[idx],curr_rel,pos[idx]])
  111. support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
  112. # support_negative_triples.append([seq[-1],curr_rel,neg[idx]])
  113. # for idx in range(len(neg)):
  114. # support_negative_triples.append([seq[-1],curr_rel,neg[idx]])
  115. # print("injam",self.maxlen,list(range(self.maxlen-1)))
  116. # print("====")
  117. # for j in range(self.number_of_neg):
  118. # for idx in range(self.maxlen-1):
  119. # # print(j * self.maxlen + idx)
  120. # support_negative_triples.append([seq[idx], curr_rel, neg[j * (self.maxlen-1) + idx]])
  121. # print("=end=\n\n")
  122. rated = ts
  123. rated.add(0)
  124. query_triples.append([seq[-1],curr_rel,self.test[u][0]])
  125. for _ in range(100):
  126. t = np.random.randint(1, self.itemnum + 1)
  127. while t in rated: t = np.random.randint(1, self.itemnum + 1)
  128. negative_triples.append([seq[-1],curr_rel,t])
  129. support_triples = [support_triples]
  130. support_negative_triples = [support_negative_triples]
  131. query_triples = [query_triples]
  132. negative_triples = [negative_triples]
  133. return [support_triples, support_negative_triples, query_triples, negative_triples], curr_rel