MeLU project implemented with l2l (learn2learn), using MetaSGD instead of MAML.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

embedding_module.py 1.2KB

12345678910111213141516171819202122232425262728293031323334
  1. import torch
  2. from copy import deepcopy
  3. from torch.autograd import Variable
  4. from torch.nn import functional as F
  5. from collections import OrderedDict
  6. from embeddings import item, user
  7. class EmbeddingModule(torch.nn.Module):
  8. def __init__(self, config):
  9. super(EmbeddingModule, self).__init__()
  10. self.embedding_dim = config['embedding_dim']
  11. self.use_cuda = config['use_cuda']
  12. self.item_emb = item(config)
  13. self.user_emb = user(config)
  14. def forward(self, x, training = True):
  15. rate_idx = Variable(x[:, 0], requires_grad=False)
  16. genre_idx = Variable(x[:, 1:26], requires_grad=False)
  17. director_idx = Variable(x[:, 26:2212], requires_grad=False)
  18. actor_idx = Variable(x[:, 2212:10242], requires_grad=False)
  19. gender_idx = Variable(x[:, 10242], requires_grad=False)
  20. age_idx = Variable(x[:, 10243], requires_grad=False)
  21. occupation_idx = Variable(x[:, 10244], requires_grad=False)
  22. area_idx = Variable(x[:, 10245], requires_grad=False)
  23. item_emb = self.item_emb(rate_idx, genre_idx, director_idx, actor_idx)
  24. user_emb = self.user_emb(gender_idx, age_idx, occupation_idx, area_idx)
  25. x = torch.cat((item_emb, user_emb), 1)
  26. return x