Sequential Recommendation for cold-start users with meta transitional learning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sampler.py 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import sys
  2. import copy
  3. import random
  4. import numpy as np
  5. from collections import defaultdict, Counter
  6. from multiprocessing import Process, Queue
  7. def random_neq(l, r, s):
  8. t = np.random.randint(l, r)
  9. while t in s:
  10. t = np.random.randint(l, r)
  11. return t
  12. def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED):
  13. def sample():
  14. if random.random() < 0.5:
  15. user = np.random.randint(1, usernum + 1)
  16. while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
  17. seq = np.zeros([maxlen], dtype=np.int32)
  18. pos = np.zeros([maxlen], dtype=np.int32)
  19. neg = np.zeros([maxlen], dtype=np.int32)
  20. if len(user_train[user]) < maxlen:
  21. nxt_idx = len(user_train[user]) - 1
  22. else:
  23. nxt_idx = np.random.randint(maxlen, len(user_train[user]))
  24. nxt = user_train[user][nxt_idx]
  25. idx = maxlen - 1
  26. ts = set(user_train[user])
  27. for i in reversed(user_train[user][min(0, nxt_idx - 1 - maxlen): nxt_idx - 1]):
  28. seq[idx] = i
  29. pos[idx] = nxt
  30. if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
  31. nxt = i
  32. idx -= 1
  33. if idx == -1: break
  34. curr_rel = user
  35. support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
  36. for idx in range(maxlen - 1):
  37. support_triples.append([seq[idx], curr_rel, pos[idx]])
  38. support_negative_triples.append([seq[idx], curr_rel, neg[idx]])
  39. query_triples.append([seq[-1], curr_rel, pos[-1]])
  40. negative_triples.append([seq[-1], curr_rel, neg[-1]])
  41. return support_triples, support_negative_triples, query_triples, negative_triples, curr_rel
  42. else:
  43. user = np.random.randint(1, usernum + 1)
  44. while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
  45. seq = np.zeros([maxlen], dtype=np.int32)
  46. pos = np.zeros([maxlen], dtype=np.int32)
  47. neg = np.zeros([maxlen], dtype=np.int32)
  48. list_idx = random.sample([i for i in range(len(user_train[user]))], maxlen + 1)
  49. list_item = [user_train[user][i] for i in sorted(list_idx)]
  50. nxt = list_item[-1]
  51. idx = maxlen - 1
  52. ts = set(user_train[user])
  53. for i in reversed(list_item[:-1]):
  54. seq[idx] = i
  55. pos[idx] = nxt
  56. if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
  57. nxt = i
  58. idx -= 1
  59. if idx == -1: break
  60. curr_rel = user
  61. support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
  62. for idx in range(maxlen - 1):
  63. support_triples.append([seq[idx], curr_rel, pos[idx]])
  64. support_negative_triples.append([seq[idx], curr_rel, neg[idx]])
  65. query_triples.append([seq[-1], curr_rel, pos[-1]])
  66. negative_triples.append([seq[-1], curr_rel, neg[-1]])
  67. return support_triples, support_negative_triples, query_triples, negative_triples, curr_rel
  68. np.random.seed(SEED)
  69. while True:
  70. one_batch = []
  71. for i in range(batch_size):
  72. one_batch.append(sample())
  73. support, support_negative, query, negative, curr_rel = zip(*one_batch)
  74. result_queue.put(([support, support_negative, query, negative], curr_rel))
  75. class WarpSampler(object):
  76. def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
  77. self.result_queue = Queue(maxsize=n_workers * 10)
  78. self.processors = []
  79. for i in range(n_workers):
  80. self.processors.append(
  81. Process(target=sample_function_mixed, args=(User,
  82. usernum,
  83. itemnum,
  84. batch_size,
  85. maxlen,
  86. self.result_queue,
  87. np.random.randint(2e9)
  88. )))
  89. self.processors[-1].daemon = True
  90. self.processors[-1].start()
  91. def next_batch(self):
  92. return self.result_queue.get()
  93. def close(self):
  94. for p in self.processors:
  95. p.terminate()
  96. p.join()