Adapted to the MovieLens dataset

TaNP.py

import torch
import numpy as np
from random import randint
from torch.autograd import Variable
from torch.nn import functional as F
from embeddings_TaNP import Item, User, Encoder, MuSigmaEncoder, Decoder, Gating_Decoder, TaskEncoder, MemoryUnit, Movie_item, Movie_user
import torch.nn as nn


class NP(nn.Module):
    def __init__(self, config):
        super(NP, self).__init__()
        # change for MovieLens
        # self.x_dim = config['second_embedding_dim'] * 2
        # Concatenated item (4 features) and user (4 features) embeddings, 32 dims per feature.
        self.x_dim = 32 * 8
        # use one-hot or not?
        self.y_dim = 1
        self.z1_dim = config['z1_dim']
        self.z2_dim = config['z2_dim']
        # z_dim is the dimension size of mu and sigma.
        self.z_dim = config['z_dim']
        # the dimension size of rc.
        self.enc_h1_dim = config['enc_h1_dim']
        self.enc_h2_dim = config['enc_h2_dim']
        self.dec_h1_dim = config['dec_h1_dim']
        self.dec_h2_dim = config['dec_h2_dim']
        self.dec_h3_dim = config['dec_h3_dim']
        self.taskenc_h1_dim = config['taskenc_h1_dim']
        self.taskenc_h2_dim = config['taskenc_h2_dim']
        self.taskenc_final_dim = config['taskenc_final_dim']
        self.clusters_k = config['clusters_k']
        self.temperature = config['temperature']
        self.dropout_rate = config['dropout_rate']
        # Initialize networks.
        # change for MovieLens
        # self.item_emb = Item(config)
        # self.user_emb = User(config)
        self.item_emb = Movie_item(config)
        self.user_emb = Movie_user(config)
        # This encoder generates z; it is the latent encoder in ANP.
        self.xy_to_z = Encoder(self.x_dim, self.y_dim, self.enc_h1_dim, self.enc_h2_dim, self.z1_dim, self.dropout_rate)
        self.z_to_mu_sigma = MuSigmaEncoder(self.z1_dim, self.z2_dim, self.z_dim)
        # This encoder generates r; it is the deterministic encoder in ANP.
        self.xy_to_task = TaskEncoder(self.x_dim, self.y_dim, self.taskenc_h1_dim, self.taskenc_h2_dim,
                                      self.taskenc_final_dim, self.dropout_rate)
        self.memoryunit = MemoryUnit(self.clusters_k, self.taskenc_final_dim, self.temperature)
        # self.xz_to_y = Gating_Decoder(self.x_dim, self.z_dim, self.taskenc_final_dim, self.dec_h1_dim, self.dec_h2_dim, self.dec_h3_dim, self.y_dim, self.dropout_rate)
        self.xz_to_y = Decoder(self.x_dim, self.z_dim, self.taskenc_final_dim, self.dec_h1_dim, self.dec_h2_dim, self.dec_h3_dim, self.y_dim, self.dropout_rate)

    def aggregate(self, z_i):
        return torch.mean(z_i, dim=0)

    def xy_to_mu_sigma(self, x, y):
        # Encode each point into a representation z_i.
        z_i = self.xy_to_z(x, y)
        # Aggregate the representations z_i into a single representation z.
        z = self.aggregate(z_i)
        # Return the parameters of the latent distribution.
        return self.z_to_mu_sigma(z)

    # Embed each (item, user) pair as the x for the NP.
    def embedding(self, x):
        # change for MovieLens: each row of x is a 10246-dim feature vector,
        # laid out as rate(1) + genre(25) + director(2186) + actor(8030)
        # + gender(1) + age(1) + occupation(1) + area(1).
        rate_idx = Variable(x[:, 0], requires_grad=False)
        genre_idx = Variable(x[:, 1:26], requires_grad=False)
        director_idx = Variable(x[:, 26:2212], requires_grad=False)
        actor_idx = Variable(x[:, 2212:10242], requires_grad=False)
        gender_idx = Variable(x[:, 10242], requires_grad=False)
        age_idx = Variable(x[:, 10243], requires_grad=False)
        occupation_idx = Variable(x[:, 10244], requires_grad=False)
        area_idx = Variable(x[:, 10245], requires_grad=False)
        item_emb = self.item_emb(rate_idx, genre_idx, director_idx, actor_idx)
        user_emb = self.user_emb(gender_idx, age_idx, occupation_idx, area_idx)
        x = torch.cat((item_emb, user_emb), 1)
        return x

    def forward(self, x_context, y_context, x_target, y_target):
        # change for MovieLens
        x_context_embed = self.embedding(x_context)
        x_target_embed = self.embedding(x_target)
        if self.training:
            # sigma is log_sigma actually
            mu_target, sigma_target, z_target = self.xy_to_mu_sigma(x_target_embed, y_target)
            mu_context, sigma_context, z_context = self.xy_to_mu_sigma(x_context_embed, y_context)
            task = self.xy_to_task(x_context_embed, y_context)
            mean_task = self.aggregate(task)
            C_distribution, new_task_embed = self.memoryunit(mean_task)
            p_y_pred = self.xz_to_y(x_target_embed, z_target, new_task_embed)
            return p_y_pred, mu_target, sigma_target, mu_context, sigma_context, C_distribution
        else:
            mu_context, sigma_context, z_context = self.xy_to_mu_sigma(x_context_embed, y_context)
            task = self.xy_to_task(x_context_embed, y_context)
            mean_task = self.aggregate(task)
            C_distribution, new_task_embed = self.memoryunit(mean_task)
            p_y_pred = self.xz_to_y(x_target_embed, z_context, new_task_embed)
            return p_y_pred
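
# Note on NP.forward: in training mode it returns the full tuple
# (p_y_pred, mu_target, sigma_target, mu_context, sigma_context, C_distribution),
# while in eval mode it returns only p_y_pred for the target points.
# Trainer.global_update and Trainer.query_rec below rely on this difference.
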
class Trainer(torch.nn.Module):
    def __init__(self, config):
        super(Trainer, self).__init__()
        self.opt = config
        self.use_cuda = config['use_cuda']
        self.np = NP(self.opt)
        self._lambda = config['lambda']
        self.optimizer = torch.optim.Adam(self.np.parameters(), lr=config['lr'])

    # Our KL divergence, KL(target || context), where logsigma_* are log standard deviations:
    # KL = log(sigma_c / sigma_t) - 1/2 + (sigma_t^2 + (mu_t - mu_c)^2) / (2 * sigma_c^2)
    def kl_div(self, mu_target, logsigma_target, mu_context, logsigma_context):
        target_sigma = torch.exp(logsigma_target)
        context_sigma = torch.exp(logsigma_context)
        kl_div = (logsigma_context - logsigma_target) - 0.5 \
            + (target_sigma ** 2 + (mu_target - mu_context) ** 2) / (2 * context_sigma ** 2)
        kl_div = kl_div.sum()
        return kl_div

    # New KL divergence -- KL(s_t || s_c), where the *_var arguments are log variances.
    def new_kl_div(self, prior_mu, prior_var, posterior_mu, posterior_var):
        kl_div = (torch.exp(posterior_var) + (posterior_mu - prior_mu) ** 2) / torch.exp(prior_var) - 1. + (prior_var - posterior_var)
        kl_div = 0.5 * kl_div.sum()
        return kl_div
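
    # Sanity check (a hedged sketch, not part of the original file): with the
    # log-variance convention above, new_kl_div(mu_c, logvar_c, mu_t, logvar_t)
    # should match
    #   kl_divergence(Normal(mu_t, (0.5 * logvar_t).exp()),
    #                 Normal(mu_c, (0.5 * logvar_c).exp())).sum()
    # from torch.distributions; see the smoke test at the bottom of this file.
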
    def loss(self, p_y_pred, y_target, mu_target, sigma_target, mu_context, sigma_context):
        regression_loss = F.mse_loss(p_y_pred, y_target.view(-1, 1))
        # KL divergence between the target and context latent distributions.
        kl = self.new_kl_div(mu_context, sigma_context, mu_target, sigma_target)
        return regression_loss + kl

    def context_target_split(self, support_set_x, support_set_y, query_set_x, query_set_y):
        total_x = torch.cat((support_set_x, query_set_x), 0)
        total_y = torch.cat((support_set_y, query_set_y), 0)
        total_size = total_x.size(0)
        context_min = self.opt['context_min']
        context_max = self.opt['context_max']
        extra_tar_min = self.opt['target_extra_min']
        # Here we simply use total_size as the maximum of the target size.
        num_context = randint(context_min, context_max)
        num_target = randint(extra_tar_min, total_size - num_context)
        sampled = np.random.choice(total_size, num_context + num_target, replace=False)
        x_context = total_x[sampled[:num_context], :]
        y_context = total_y[sampled[:num_context]]
        x_target = total_x[sampled, :]
        y_target = total_y[sampled]
        return x_context, y_context, x_target, y_target

    def new_context_target_split(self, support_set_x, support_set_y, query_set_x, query_set_y):
        total_x = torch.cat((support_set_x, query_set_x), 0)
        total_y = torch.cat((support_set_y, query_set_y), 0)
        total_size = total_x.size(0)
        context_min = self.opt['context_min']
        # change for MovieLens: cap context_min so the draw below is always valid.
        context_min = min(context_min, total_size - 1)
        num_context = np.random.randint(context_min, total_size)
        num_target = np.random.randint(0, total_size - num_context)
        sampled = np.random.choice(total_size, num_context + num_target, replace=False)
        x_context = total_x[sampled[:num_context], :]
        y_context = total_y[sampled[:num_context]]
        x_target = total_x[sampled, :]
        y_target = total_y[sampled]
        return x_context, y_context, x_target, y_target
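
    # In both splits the target set is a superset of the context set: x_target
    # indexes all of `sampled` while x_context takes only its first num_context
    # rows, matching the standard Neural Process training setup. For example,
    # with total_size = 20 and context_min = 15, new_context_target_split draws
    # num_context from [15, 19] and num_target from [0, 20 - num_context - 1].
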
    def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys):
        batch_sz = len(support_set_xs)
        losses = []
        C_distribs = []
        if self.use_cuda:
            for i in range(batch_sz):
                support_set_xs[i] = support_set_xs[i].cuda()
                support_set_ys[i] = support_set_ys[i].cuda()
                query_set_xs[i] = query_set_xs[i].cuda()
                query_set_ys[i] = query_set_ys[i].cuda()
        for i in range(batch_sz):
            x_context, y_context, x_target, y_target = self.new_context_target_split(
                support_set_xs[i], support_set_ys[i], query_set_xs[i], query_set_ys[i])
            p_y_pred, mu_target, sigma_target, mu_context, sigma_context, C_distribution = self.np(
                x_context, y_context, x_target, y_target)
            C_distribs.append(C_distribution)
            loss = self.loss(p_y_pred, y_target, mu_target, sigma_target, mu_context, sigma_context)
            losses.append(loss)
        # Calculate the target distribution for clustering in a batch manner
        # (the sharpened self-training target used in deep embedded clustering).
        # batch_size * k
        C_distribs = torch.stack(C_distribs)
        # batch_size * k
        C_distribs_sq = torch.pow(C_distribs, 2)
        # 1 * k
        C_distribs_sum = torch.sum(C_distribs, dim=0, keepdim=True)
        # batch_size * k
        temp = C_distribs_sq / C_distribs_sum
        # batch_size * 1
        temp_sum = torch.sum(temp, dim=1, keepdim=True)
        target_distribs = temp / temp_sum
        # Calculate the clustering KL loss.
        clustering_loss = self._lambda * F.kl_div(C_distribs.log(), target_distribs, reduction='batchmean')
        np_losses_mean = torch.stack(losses).mean(0)
        total_loss = np_losses_mean + clustering_loss
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        return total_loss.item(), C_distribs.detach().cpu().numpy()

    def query_rec(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys):
        batch_sz = 1
        # Used for calculating the RMSE.
        losses_q = []
        if self.use_cuda:
            for i in range(batch_sz):
                support_set_xs[i] = support_set_xs[i].cuda()
                support_set_ys[i] = support_set_ys[i].cuda()
                query_set_xs[i] = query_set_xs[i].cuda()
                query_set_ys[i] = query_set_ys[i].cuda()
        for i in range(batch_sz):
            # In eval mode NP.forward returns only p_y_pred for the query points.
            query_set_y_pred = self.np(support_set_xs[i], support_set_ys[i], query_set_xs[i], query_set_ys[i])
            loss_q = F.mse_loss(query_set_y_pred, query_set_ys[i].view(-1, 1))
            losses_q.append(loss_q)
        losses_q = torch.stack(losses_q).mean(0)
        # Rank the query items by predicted score, highest first.
        output_list, recommendation_list = query_set_y_pred.view(-1).sort(descending=True)
        return losses_q.item(), recommendation_list
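
# A minimal smoke test for the KL terms above. This is a hedged sketch, not
# part of the original TaNP code: it only checks that new_kl_div's closed form
# matches torch.distributions under the log-variance reading of its arguments;
# it does not exercise the NP model itself, which needs the embeddings_TaNP
# module and a full config dict to instantiate.
if __name__ == '__main__':
    from torch.distributions import Normal, kl_divergence

    torch.manual_seed(0)
    mu_c, logvar_c = torch.randn(8), torch.randn(8)
    mu_t, logvar_t = torch.randn(8), torch.randn(8)

    # Closed form, as in Trainer.new_kl_div(prior=context, posterior=target).
    closed = 0.5 * ((torch.exp(logvar_t) + (mu_t - mu_c) ** 2) / torch.exp(logvar_c)
                    - 1. + (logvar_c - logvar_t)).sum()

    # Reference: KL(N_target || N_context) from torch.distributions.
    reference = kl_divergence(Normal(mu_t, (0.5 * logvar_t).exp()),
                              Normal(mu_c, (0.5 * logvar_c).exp())).sum()

    assert torch.allclose(closed, reference, atol=1e-5)
    print('new_kl_div closed form matches torch.distributions:', closed.item())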