Sequential Recommendation for cold-start users with meta transitional learning
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

sampler.py 7.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import sys
  2. import copy
  3. import torch
  4. import random
  5. import numpy as np
  6. from collections import defaultdict, Counter
  7. from multiprocessing import Process, Queue
  8. def random_neq(l, r, s, user_train,usernum):
  9. # t = np.random.randint(l, r)
  10. # while t in s:
  11. # t = np.random.randint(l, r)
  12. # return t
  13. # user = np.random.choice(1, usernum + 1)
  14. # user = random.randint(1,usernum+1)
  15. # while len(user_train[user])<3:
  16. # user = random.randint(1, usernum + 1)
  17. # candid_item = user_train[user][random.randint(0, len(user_train[user])-1)]
  18. #
  19. # while candid_item in s:
  20. # while len(user_train[user]) < 3:
  21. # user = random.randint(1, usernum + 1)
  22. # candid_item = user_train[user][random.randint(0, len(user_train[user])-1)]
  23. # return candid_item
  24. user = random.choice(list(user_train.keys()))
  25. item = random.choice(user_train[user])
  26. while item in s:
  27. user = random.choice(list(user_train.keys()))
  28. item = random.choice(user_train[user])
  29. return item
  30. def random_negetive_batch(l, r, s, user_train,usernum, batch_users):
  31. user = np.random.choice(batch_users)
  32. candid_item = user_train[user][np.random.randint(0, len(user_train[user]))]
  33. while candid_item in s:
  34. user = np.random.choice(batch_users)
  35. candid_item = user_train[user][np.random.randint(0, len(user_train[user]))]
  36. return candid_item
  37. def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED,number_of_neg):
  38. def sample(user,batch_users):
  39. if random.random()<=0.5:
  40. # user = np.random.randint(1, usernum + 1)
  41. # while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
  42. seq = np.zeros([maxlen], dtype=np.int32)
  43. pos = np.zeros([maxlen], dtype=np.int32)
  44. neg = np.zeros([(maxlen-1)*number_of_neg + 1], dtype=np.int32)
  45. if len(user_train[user]) < maxlen:
  46. nxt_idx = len(user_train[user]) - 1
  47. else:
  48. nxt_idx = np.random.randint(maxlen,len(user_train[user]))
  49. nxt = user_train[user][nxt_idx]
  50. idx = maxlen - 1
  51. ts = set(user_train[user])
  52. for i in reversed(user_train[user][(nxt_idx - maxlen) : nxt_idx ]):
  53. seq[idx] = i
  54. pos[idx] = nxt
  55. # if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts, user_train,usernum)
  56. nxt = i
  57. idx -= 1
  58. if idx == -1: break
  59. for i in range(len(neg)):
  60. # neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum)
  61. neg[i] = random_negetive_batch(1, itemnum + 1, ts, user_train, usernum, batch_users = batch_users)
  62. curr_rel = user
  63. support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
  64. for idx in range(maxlen-1):
  65. support_triples.append([seq[idx],curr_rel,pos[idx]])
  66. # support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
  67. # support_negative_triples.append([seq[-1], curr_rel, neg[idx]])
  68. # for idx in range(maxlen*30 - 1):
  69. # support_negative_triples.append([seq[-1], curr_rel, neg[idx]])
  70. for j in range(number_of_neg):
  71. for idx in range(maxlen-1):
  72. support_negative_triples.append([seq[idx], curr_rel, neg[j*(maxlen-1) + idx]])
  73. query_triples.append([seq[-1],curr_rel,pos[-1]])
  74. negative_triples.append([seq[-1],curr_rel,neg[-1]])
  75. return support_triples, support_negative_triples, query_triples, negative_triples, curr_rel
  76. else:
  77. # print("bug happened in sample_function_mixed")
  78. # user = np.random.randint(1, usernum + 1)
  79. # while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
  80. seq = np.zeros([maxlen], dtype=np.int32)
  81. pos = np.zeros([maxlen], dtype=np.int32)
  82. neg = np.zeros([maxlen*number_of_neg], dtype=np.int32)
  83. list_idx = random.sample([i for i in range(len(user_train[user]))], maxlen + 1)
  84. list_item = [user_train[user][i] for i in sorted(list_idx)]
  85. nxt = list_item[-1]
  86. idx = maxlen - 1
  87. ts = set(user_train[user])
  88. for i in reversed(list_item[:-1]):
  89. seq[idx] = i
  90. pos[idx] = nxt
  91. # if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
  92. nxt = i
  93. idx -= 1
  94. if idx == -1: break
  95. curr_rel = user
  96. support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
  97. for i in range(len(neg)):
  98. # neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum)
  99. neg[i] = random_negetive_batch(1, itemnum + 1, ts, user_train, usernum, batch_users = batch_users)
  100. for j in range(number_of_neg):
  101. for idx in range(maxlen-1):
  102. support_negative_triples.append([seq[idx], curr_rel, neg[j*maxlen + idx]])
  103. for idx in range(maxlen-1):
  104. support_triples.append([seq[idx],curr_rel,pos[idx]])
  105. # support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
  106. query_triples.append([seq[-1],curr_rel,pos[-1]])
  107. negative_triples.append([seq[-1],curr_rel,neg[-1]])
  108. return support_triples, support_negative_triples, query_triples, negative_triples, curr_rel
  109. np.random.seed(SEED)
  110. while True:
  111. one_batch = []
  112. users = []
  113. for i in range(batch_size):
  114. user = np.random.randint(1, usernum + 1)
  115. while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
  116. users.append(user)
  117. for i in range(batch_size):
  118. one_batch.append(sample(user = users[i], batch_users = users))
  119. support, support_negative, query, negative, curr_rel = zip(*one_batch)
  120. result_queue.put(([support, support_negative, query, negative], curr_rel))
  121. class WarpSampler(object):
  122. def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1,params = None):
  123. self.result_queue = Queue(maxsize=n_workers * 10)
  124. self.processors = []
  125. for i in range(n_workers):
  126. self.processors.append(
  127. Process(target=sample_function_mixed, args=(User,
  128. usernum,
  129. itemnum,
  130. batch_size,
  131. maxlen,
  132. self.result_queue,
  133. np.random.randint(2e9),
  134. params['number_of_neg']
  135. )))
  136. self.processors[-1].daemon = True
  137. self.processors[-1].start()
  138. def next_batch(self):
  139. return self.result_queue.get()
  140. def close(self):
  141. for p in self.processors:
  142. p.terminate()
  143. p.join()