Sequential Recommendation for cold-start users with meta transitional learning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.1KB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import sys
  2. import copy
  3. import torch
  4. import random
  5. import numpy as np
  6. from collections import defaultdict, Counter
  7. from multiprocessing import Process, Queue
  8. # sampler for batch generation
  def random_neq(l, r, s):
      """Draw a random integer from ``[l, r)`` that is not contained in ``s``.

      Candidates come from ``np.random.randint(l, r)`` and are redrawn until
      one falls outside ``s``.  The loop never terminates if ``s`` already
      contains every integer of the range.
      """
      while True:
          candidate = np.random.randint(l, r)
          if candidate not in s:
              return candidate
  def trans_to_cuda(variable):
      """Return ``variable.cuda()`` when CUDA is available, else ``variable`` as is."""
      return variable.cuda() if torch.cuda.is_available() else variable
  def trans_to_cpu(variable):
      """Return ``variable.cpu()`` when CUDA is available, else ``variable`` as is."""
      if not torch.cuda.is_available():
          # Without CUDA the input is handed back untouched.
          return variable
      return variable.cpu()
  24. # train/val/test data generation
  def _read_interactions(path):
      """Parse a tab-separated ``user<TAB>item<TAB>time`` file into per-user item lists."""
      sequences = defaultdict(list)
      # ``with`` guarantees the handle is closed even if a row fails to parse.
      with open(path, 'r') as handle:
          for line in handle:
              line = line.rstrip()
              if not line:
                  # Tolerate blank / trailing empty lines instead of crashing.
                  continue
              user, item, _ = line.split('\t')
              sequences[int(user)].append(int(item))
      return sequences


  def data_load(fname, num_sample):
      """Load the training interactions and the new-user validation/test split.

      Reads ``data/<fname>/<fname>_train.csv`` and
      ``data/<fname>/<fname>_test_new_user.csv``; each non-empty line must hold
      three tab-separated fields whose first two are integer user and item ids.

      Every new user with more than ``num_sample`` interactions contributes
      their first ``num_sample`` items as input and the following item as the
      target.  Such a user goes to the validation split when
      ``random.random() < 0.3`` and to the test split otherwise; new users with
      ``num_sample`` or fewer interactions are dropped.

      Args:
          fname: dataset name, used for both the directory and the file prefix.
          num_sample: number of leading interactions kept as input per new user.

      Returns:
          ``[user_train, usernum, itemnum, user_input_test, user_test,
          user_input_valid, user_valid]`` where ``user_train`` is a
          ``defaultdict(list)`` of user -> item list, ``usernum`` / ``itemnum``
          are the largest user / item ids seen in the training file (0 when it
          is empty), and the remaining entries are plain dicts of
          user -> item list.
      """
      # assume user/item index starting from 1
      user_train = _read_interactions('data/%s/%s_train.csv' % (fname, fname))
      usernum = max([0, *user_train])
      itemnum = max([0, *(item for items in user_train.values() for item in items)])

      # read in new users for testing
      new_user_items = _read_interactions(
          'data/%s/%s_test_new_user.csv' % (fname, fname))

      user_input_test, user_test = {}, {}
      user_input_valid, user_valid = {}, {}
      for user, items in new_user_items.items():
          if len(items) <= num_sample:
              continue
          support = items[:num_sample]
          target = [items[num_sample]]
          if random.random() < 0.3:
              user_input_valid[user] = support
              user_valid[user] = target
          else:
              user_input_test[user] = support
              user_test[user] = target

      return [user_train, usernum, itemnum, user_input_test, user_test,
              user_input_valid, user_valid]
  63. class DataLoader(object):
  64. def __init__(self, user_train, user_test, itemnum, parameter):
  65. self.curr_rel_idx = 0
  66. self.bs = parameter['batch_size']
  67. self.maxlen = parameter['K']
  68. self.valid_user = []
  69. for u in user_train:
  70. if len(user_train[u]) < self.maxlen or len(user_test[u]) < 1: continue
  71. self.valid_user.append(u)
  72. self.num_tris = len(self.valid_user)
  73. self.train = user_train
  74. self.test = user_test
  75. self.itemnum = itemnum
  def next_one_on_eval(self):
      """Build the evaluation task for the next user in ``self.valid_user``.

      Returns ``("EOT", "EOT")`` once ``self.curr_tri_idx`` has reached
      ``self.num_tris``.  Otherwise advances the cursor and returns
      ``(task, user)`` where ``task`` is
      ``[support, support_negative, query, negative]``, each wrapped in a
      one-element outer list:

      * support: ``[item, user, next_item]`` for the ``K - 1`` consecutive
        pairs of the user's last ``K`` training items (0 where unfilled).
      * support_negative: same heads, with a randomly drawn item outside the
        user's training items as tail (0 where none was drawn).
      * query: one triple from the last training item to ``self.test[user][0]``.
      * negative: 100 triples from the last training item to randomly drawn
        items outside the user's training items.
      """
      cursor = self.curr_tri_idx
      if cursor == self.num_tris:
          return "EOT", "EOT"
      user = self.valid_user[cursor]
      self.curr_tri_idx = cursor + 1

      window = self.maxlen
      seq = np.zeros([window], dtype=np.int32)
      pos = np.zeros([window - 1], dtype=np.int32)
      neg = np.zeros([window - 1], dtype=np.int32)

      history = self.train[user]
      seen = set(history)
      # Fill right-to-left so the most recent item lands in the last slot;
      # zip stops as soon as either the slots or the history run out.
      for slot, item in zip(range(window - 1, -1, -1), reversed(history)):
          seq[slot] = item
          if slot > 0:
              pos[slot - 1] = item
              if item != 0:
                  neg[slot - 1] = random_neq(1, self.itemnum + 1, seen)

      pair_slots = range(window - 1)
      support_triples = [[seq[j], user, pos[j]] for j in pair_slots]
      support_negative_triples = [[seq[j], user, neg[j]] for j in pair_slots]

      # From here on 0 is also excluded when drawing negatives.
      seen.add(0)
      query_triples = [[seq[-1], user, self.test[user][0]]]
      negative_triples = [
          [seq[-1], user, random_neq(1, self.itemnum + 1, seen)]
          for _ in range(100)
      ]

      task = [
          [support_triples],
          [support_negative_triples],
          [query_triples],
          [negative_triples],
      ]
      return task, user