A meta-learning approach for solving the cold-start problem

MeLU.py 5.9KB

import torch
import numpy as np
from copy import deepcopy
from torch.autograd import Variable
from torch.nn import functional as F
from collections import OrderedDict

from embeddings import item, user

class user_preference_estimator(torch.nn.Module):
    def __init__(self, config):
        super(user_preference_estimator, self).__init__()
        self.embedding_dim = config['embedding_dim']
        self.fc1_in_dim = config['embedding_dim'] * 8
        self.fc2_in_dim = config['first_fc_hidden_dim']
        self.fc2_out_dim = config['second_fc_hidden_dim']
        self.use_cuda = config['use_cuda']

        self.item_emb = item(config)
        self.user_emb = user(config)
        # Decision-making layers: a two-hidden-layer MLP on top of the
        # concatenated item and user embeddings, ending in a scalar rating.
        self.fc1 = torch.nn.Linear(self.fc1_in_dim, self.fc2_in_dim)
        self.fc2 = torch.nn.Linear(self.fc2_in_dim, self.fc2_out_dim)
        self.linear_out = torch.nn.Linear(self.fc2_out_dim, 1)

    def forward(self, x, training=True):
        # x packs all content features into one 10246-column index vector:
        # columns 0..10241 hold the item (rate, 25 genres, directors, actors)
        # and columns 10242..10245 hold the user (gender, age, occupation, area).
        # Variable is a no-op wrapper in PyTorch >= 0.4, kept for compatibility.
        rate_idx = Variable(x[:, 0], requires_grad=False)
        genre_idx = Variable(x[:, 1:26], requires_grad=False)
        director_idx = Variable(x[:, 26:2212], requires_grad=False)
        actor_idx = Variable(x[:, 2212:10242], requires_grad=False)
        gender_idx = Variable(x[:, 10242], requires_grad=False)
        age_idx = Variable(x[:, 10243], requires_grad=False)
        occupation_idx = Variable(x[:, 10244], requires_grad=False)
        area_idx = Variable(x[:, 10245], requires_grad=False)
        item_emb = self.item_emb(rate_idx, genre_idx, director_idx, actor_idx)
        user_emb = self.user_emb(gender_idx, age_idx, occupation_idx, area_idx)
        x = torch.cat((item_emb, user_emb), 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        return self.linear_out(x)
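
# Shape trace (d = embedding_dim): the 4 item features and 4 user features are
# each embedded to width d and concatenated into a (batch, 8*d) tensor, which
# is what fc1_in_dim = embedding_dim * 8 encodes; fc1 and fc2 then map it to
# (batch, first_fc_hidden_dim) and (batch, second_fc_hidden_dim), and
# linear_out produces the (batch, 1) predicted rating.
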
class MeLU(torch.nn.Module):
    def __init__(self, config):
        super(MeLU, self).__init__()
        self.use_cuda = config['use_cuda']
        self.model = user_preference_estimator(config)
        self.local_lr = config['local_lr']
        self.store_parameters()
        self.meta_optim = torch.optim.Adam(self.model.parameters(), lr=config['lr'])
        # Only the decision-making layers are adapted in the inner loop;
        # the embeddings are updated by the meta optimizer alone.
        self.local_update_target_weight_name = ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'linear_out.weight', 'linear_out.bias']

    def store_parameters(self):
        # Snapshot the current global parameters so the inner loop can be
        # rolled back after each task.
        self.keep_weight = deepcopy(self.model.state_dict())
        self.weight_name = list(self.keep_weight.keys())
        self.weight_len = len(self.keep_weight)
        self.fast_weights = OrderedDict()
    def forward(self, support_set_x, support_set_y, query_set_x, num_local_update):
        # this line was added by maheri: re-snapshot the global weights so
        # every call starts its inner loop from the up-to-date model
        self.keep_weight = deepcopy(self.model.state_dict())
        for idx in range(num_local_update):
            if idx > 0:
                self.model.load_state_dict(self.fast_weights)
            weight_for_local_update = list(self.model.state_dict().values())
            support_set_y_pred = self.model(support_set_x)
            loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1))
            self.model.zero_grad()
            # create_graph=True keeps the gradient graph so the meta update
            # can differentiate through the local steps (second-order MAML)
            grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
            # local update: one gradient step on the target layers only
            for i in range(self.weight_len):
                if self.weight_name[i] in self.local_update_target_weight_name:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
                else:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
        self.model.load_state_dict(self.fast_weights)
        query_set_y_pred = self.model(query_set_x)
        # restore the global parameters before the next task
        self.model.load_state_dict(self.keep_weight)
        del weight_for_local_update, loss, grad, support_set_y_pred
        return query_set_y_pred
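
    # Inner loop in equation form: for each of the num_local_update steps,
    #     theta' = theta - local_lr * grad_theta MSE(f_theta(support_x), support_y),
    # applied only to the layers listed in local_update_target_weight_name;
    # all other parameters (the embeddings) are carried over unchanged.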
    def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys, num_local_update):
        batch_sz = len(support_set_xs)
        losses_q = []
        if self.use_cuda:
            for i in range(batch_sz):
                support_set_xs[i] = support_set_xs[i].cuda()
                support_set_ys[i] = support_set_ys[i].cuda()
                query_set_xs[i] = query_set_xs[i].cuda()
                query_set_ys[i] = query_set_ys[i].cuda()
        for i in range(batch_sz):
            # adapt locally on each task's support set, evaluate on its query set
            query_set_y_pred = self.forward(support_set_xs[i], support_set_ys[i], query_set_xs[i], num_local_update)
            loss_q = F.mse_loss(query_set_y_pred, query_set_ys[i].view(-1, 1))
            losses_q.append(loss_q)
        # average query losses over the task batch and take one meta step
        losses_q = torch.stack(losses_q).mean(0)
        self.meta_optim.zero_grad()
        losses_q.backward()
        self.meta_optim.step()
        self.store_parameters()
        return
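
    # Meta update: with theta_i' produced by forward() for task i, this takes
    # one Adam step on (1/batch_sz) * sum_i MSE(f_theta_i'(query_x_i), query_y_i);
    # because the inner-loop grads were built with create_graph=True, backward()
    # differentiates through the local adaptation steps as in MAML.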
    def get_weight_avg_norm(self, support_set_x, support_set_y, num_local_update):
        tmp = 0.
        # bug fix: this previously read `if self.cuda():`, which moved the
        # whole module to the GPU and was always truthy
        if self.use_cuda:
            support_set_x = support_set_x.cuda()
            support_set_y = support_set_y.cuda()
        for idx in range(num_local_update):
            if idx > 0:
                self.model.load_state_dict(self.fast_weights)
            weight_for_local_update = list(self.model.state_dict().values())
            support_set_y_pred = self.model(support_set_x)
            loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1))
            # unit loss: scale to unit magnitude so gradient norms are comparable
            loss /= torch.norm(loss).tolist()
            self.model.zero_grad()
            grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
            for i in range(self.weight_len):
                # for averaging the Frobenius norm
                tmp += torch.norm(grad[i])
                if self.weight_name[i] in self.local_update_target_weight_name:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
                else:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
        return tmp / num_local_update
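
A minimal usage sketch, assuming the config keys read above; the concrete dimension values, the dummy tensors, and any extra keys required by embeddings.item / embeddings.user are illustrative assumptions, not values from this repo:

import torch
from MeLU import MeLU

config = {
    'embedding_dim': 32,         # assumed
    'first_fc_hidden_dim': 64,   # assumed
    'second_fc_hidden_dim': 64,  # assumed
    'use_cuda': False,
    'lr': 5e-5,                  # meta learning rate (assumed)
    'local_lr': 5e-6,            # inner-loop learning rate (assumed)
    # ...plus whatever keys embeddings.item / embeddings.user expect
}

melu = MeLU(config)

# One task per cold-start user: a handful of support ratings to adapt on and
# query ratings to evaluate. Rows follow the 10246-column index layout that
# user_preference_estimator.forward() slices.
support_x = torch.randint(0, 2, (13, 10246))
support_y = torch.randint(1, 6, (13,)).float()
query_x = torch.randint(0, 2, (10, 10246))
query_y = torch.randint(1, 6, (10,)).float()

melu.global_update([support_x], [support_y], [query_x], [query_y], num_local_update=5)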