Sequential Recommendation for Cold-Start Users with Meta Transitional Learning

models.py

The file defines four modules: Embedding (the item embedding table), MetaLearner (infers a per-user transition relation from the support set), EmbeddingLearner (a TransE-style scorer), and MetaTL (ties them together and performs the inner gradient update on the relation).

from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn import functional as F


class Embedding(nn.Module):
    # Item embedding table; looks up the head and tail items of each triple.
    def __init__(self, num_ent, parameter):
        super(Embedding, self).__init__()
        self.device = parameter['device']
        self.es = parameter['embed_dim']

        # num_ent + 1 rows so that item ids may run from 0 to num_ent inclusive.
        self.embedding = nn.Embedding(num_ent + 1, self.es)
        nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, triples):
        # Keep only the head (t[0]) and tail (t[2]) of each (h, r, t) triple.
        idx = [[[t[0], t[2]] for t in batch] for batch in triples]
        idx = torch.LongTensor(idx).to(self.device)
        return self.embedding(idx)  # [batch, num_triples, 2, embed_dim]


class MetaLearner(nn.Module):
    # Produces one transition-relation vector per task from the K support
    # transitions: an MLP applied per transition, then mean pooling over K.
    def __init__(self, K, embed_size=100, num_hidden1=500, num_hidden2=200,
                 out_size=100, dropout_p=0.5):
        super(MetaLearner, self).__init__()
        self.embed_size = embed_size
        self.K = K
        self.out_size = out_size
        self.rel_fc1 = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(2 * embed_size, num_hidden1)),
            ('bn', nn.BatchNorm1d(K)),
            ('relu', nn.LeakyReLU()),
            ('drop', nn.Dropout(p=dropout_p)),
        ]))
        self.rel_fc2 = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(num_hidden1, num_hidden2)),
            ('bn', nn.BatchNorm1d(K)),
            ('relu', nn.LeakyReLU()),
            ('drop', nn.Dropout(p=dropout_p)),
        ]))
        self.rel_fc3 = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(num_hidden2, out_size)),
            ('bn', nn.BatchNorm1d(K)),
        ]))
        nn.init.xavier_normal_(self.rel_fc1.fc.weight)
        nn.init.xavier_normal_(self.rel_fc2.fc.weight)
        nn.init.xavier_normal_(self.rel_fc3.fc.weight)

    def forward(self, inputs):
        # inputs: [batch, K, 2, embed_dim]; flatten each head/tail pair.
        size = inputs.shape
        x = inputs.contiguous().view(size[0], size[1], -1)
        x = self.rel_fc1(x)
        x = self.rel_fc2(x)
        x = self.rel_fc3(x)
        x = torch.mean(x, 1)  # average over the K transitions
        return x.view(size[0], 1, 1, self.out_size)


class EmbeddingLearner(nn.Module):
    # TransE-style scorer: score(h, r, t) = -||h + r - t||_2, so larger
    # scores mean more plausible transitions.
    def __init__(self):
        super(EmbeddingLearner, self).__init__()

    def forward(self, h, t, r, pos_num):
        score = -torch.norm(h + r - t, 2, -1).squeeze(2)
        p_score = score[:, :pos_num]  # positive triples come first
        n_score = score[:, pos_num:]  # negatives follow
        return p_score, n_score


class MetaTL(nn.Module):
    def __init__(self, itemnum, parameter):
        super(MetaTL, self).__init__()
        self.device = parameter['device']
        self.beta = parameter['beta']  # step size of the inner gradient update
        self.dropout_p = parameter['dropout_p']
        self.embed_dim = parameter['embed_dim']
        self.margin = parameter['margin']

        self.embedding = Embedding(itemnum, parameter)
        # K interactions in a user's support set give K - 1 transitions.
        self.relation_learner = MetaLearner(parameter['K'] - 1, embed_size=100,
                                            num_hidden1=500, num_hidden2=200,
                                            out_size=100, dropout_p=self.dropout_p)
        self.embedding_learner = EmbeddingLearner()
        self.loss_func = nn.MarginRankingLoss(self.margin)
        # Cache of updated relation vectors, keyed by curr_rel, so a relation
        # refined on the support set can be reused at evaluation time.
        self.rel_q_sharing = dict()

    def split_concat(self, positive, negative):
        # Concatenate positive and negative triples along the triple axis and
        # return separate head and tail tensors of shape [batch, n, 1, dim].
        pos_neg_e1 = torch.cat([positive[:, :, 0, :],
                                negative[:, :, 0, :]], 1).unsqueeze(2)
        pos_neg_e2 = torch.cat([positive[:, :, 1, :],
                                negative[:, :, 1, :]], 1).unsqueeze(2)
        return pos_neg_e1, pos_neg_e2

    def forward(self, task, iseval=False, curr_rel=''):
        # Transfer the task's id triples into embeddings.
        support, support_negative, query, negative = [self.embedding(t) for t in task]

        K = support.shape[1]                # number of support transitions
        num_sn = support_negative.shape[1]  # number of support negatives
        num_q = query.shape[1]              # number of queries
        num_n = negative.shape[1]           # number of query negatives

        rel = self.relation_learner(support)
        rel.retain_grad()  # rel is non-leaf; keep its grad for the inner update
        rel_s = rel.expand(-1, K + num_sn, -1, -1)

        if iseval and curr_rel != '' and curr_rel in self.rel_q_sharing.keys():
            rel_q = self.rel_q_sharing[curr_rel]
        else:
            # Score the support set with the initial relation ...
            sup_neg_e1, sup_neg_e2 = self.split_concat(support, support_negative)
            p_score, n_score = self.embedding_learner(sup_neg_e1, sup_neg_e2, rel_s, K)

            y = torch.Tensor([1]).to(self.device)
            self.zero_grad()
            loss = self.loss_func(p_score, n_score, y)
            loss.backward(retain_graph=True)

            # ... then take one gradient step on the relation (inner update).
            grad_meta = rel.grad
            rel_q = rel - self.beta * grad_meta

            self.rel_q_sharing[curr_rel] = rel_q

        # Score the query set with the updated relation.
        rel_q = rel_q.expand(-1, num_q + num_n, -1, -1)
        que_neg_e1, que_neg_e2 = self.split_concat(query, negative)
        p_score, n_score = self.embedding_learner(que_neg_e1, que_neg_e2, rel_q, num_q)

        return p_score, n_score
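
For a quick sanity check, the model above can be exercised end to end on random ids. The following is a minimal smoke test, assuming CPU execution and the class definitions from models.py; the parameter values and the rand_triples helper are illustrative assumptions, not values or code from this repository.

import random

itemnum = 50
parameter = {
    'device': torch.device('cpu'),
    'beta': 0.1,        # inner-loop step size (assumed value)
    'dropout_p': 0.5,
    'embed_dim': 100,
    'margin': 1.0,
    'K': 4,             # K interactions -> K - 1 support transitions
}
model = MetaTL(itemnum, parameter)

def rand_triples(batch, n):
    # Hypothetical helper: (head_item, relation_placeholder, tail_item)
    # triples; only positions 0 and 2 are looked up by Embedding.
    return [[(random.randint(0, itemnum), 0, random.randint(0, itemnum))
             for _ in range(n)] for _ in range(batch)]

K = parameter['K'] - 1
task = [rand_triples(2, K),   # support
        rand_triples(2, K),   # support negatives
        rand_triples(2, 1),   # queries
        rand_triples(2, 1)]   # query negatives

p_score, n_score = model(task, iseval=False, curr_rel='user_0')
print(p_score.shape, n_score.shape)  # torch.Size([2, 1]) torch.Size([2, 1])

A later call with iseval=True and the same curr_rel reuses the cached rel_q instead of redoing the gradient step, which is how the refined relation is shared between support and query scoring at test time.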