Sequential Recommendation for cold-start users with meta transitional learning
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

sampler.py 4.6KB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import sys
  2. import copy
  3. import torch
  4. import random
  5. import numpy as np
  6. from collections import defaultdict, Counter
  7. from multiprocessing import Process, Queue
  8. def random_neq(l, r, s):
  9. t = np.random.randint(l, r)
  10. while t in s:
  11. t = np.random.randint(l, r)
  12. return t
  13. def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED):
  14. def sample():
  15. if random.random()<0.5:
  16. user = np.random.randint(1, usernum + 1)
  17. while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
  18. seq = np.zeros([maxlen], dtype=np.int32)
  19. pos = np.zeros([maxlen], dtype=np.int32)
  20. neg = np.zeros([maxlen], dtype=np.int32)
  21. if len(user_train[user]) < maxlen:
  22. nxt_idx = len(user_train[user]) - 1
  23. else:
  24. nxt_idx = np.random.randint(maxlen,len(user_train[user]))
  25. nxt = user_train[user][nxt_idx]
  26. idx = maxlen - 1
  27. ts = set(user_train[user])
  28. for i in reversed(user_train[user][min(0, nxt_idx - 1 - maxlen) : nxt_idx - 1]):
  29. seq[idx] = i
  30. pos[idx] = nxt
  31. if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
  32. nxt = i
  33. idx -= 1
  34. if idx == -1: break
  35. curr_rel = user
  36. support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
  37. for idx in range(maxlen-1):
  38. support_triples.append([seq[idx],curr_rel,pos[idx]])
  39. support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
  40. query_triples.append([seq[-1],curr_rel,pos[-1]])
  41. negative_triples.append([seq[-1],curr_rel,neg[-1]])
  42. return support_triples, support_negative_triples, query_triples, negative_triples, curr_rel
  43. else:
  44. user = np.random.randint(1, usernum + 1)
  45. while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
  46. seq = np.zeros([maxlen], dtype=np.int32)
  47. pos = np.zeros([maxlen], dtype=np.int32)
  48. neg = np.zeros([maxlen], dtype=np.int32)
  49. list_idx = random.sample([i for i in range(len(user_train[user]))], maxlen + 1)
  50. list_item = [user_train[user][i] for i in sorted(list_idx)]
  51. nxt = list_item[-1]
  52. idx = maxlen - 1
  53. ts = set(user_train[user])
  54. for i in reversed(list_item[:-1]):
  55. seq[idx] = i
  56. pos[idx] = nxt
  57. if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
  58. nxt = i
  59. idx -= 1
  60. if idx == -1: break
  61. curr_rel = user
  62. support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
  63. for idx in range(maxlen-1):
  64. support_triples.append([seq[idx],curr_rel,pos[idx]])
  65. support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
  66. query_triples.append([seq[-1],curr_rel,pos[-1]])
  67. negative_triples.append([seq[-1],curr_rel,neg[-1]])
  68. return support_triples, support_negative_triples, query_triples, negative_triples, curr_rel
  69. np.random.seed(SEED)
  70. while True:
  71. one_batch = []
  72. for i in range(batch_size):
  73. one_batch.append(sample())
  74. support, support_negative, query, negative, curr_rel = zip(*one_batch)
  75. result_queue.put(([support, support_negative, query, negative], curr_rel))
  76. class WarpSampler(object):
  77. def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
  78. self.result_queue = Queue(maxsize=n_workers * 10)
  79. self.processors = []
  80. for i in range(n_workers):
  81. self.processors.append(
  82. Process(target=sample_function_mixed, args=(User,
  83. usernum,
  84. itemnum,
  85. batch_size,
  86. maxlen,
  87. self.result_queue,
  88. np.random.randint(2e9)
  89. )))
  90. self.processors[-1].daemon = True
  91. self.processors[-1].start()
  92. def next_batch(self):
  93. return self.result_queue.get()
  94. def close(self):
  95. for p in self.processors:
  96. p.terminate()
  97. p.join()