make other meta-learning algorithms implemented in l2l.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

embeddings.py 2.9KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class item(torch.nn.Module):
  5. def __init__(self, config):
  6. super(item, self).__init__()
  7. self.num_rate = config['num_rate']
  8. self.num_genre = config['num_genre']
  9. self.num_director = config['num_director']
  10. self.num_actor = config['num_actor']
  11. self.embedding_dim = config['embedding_dim']
  12. self.embedding_rate = torch.nn.Embedding(
  13. num_embeddings=self.num_rate,
  14. embedding_dim=self.embedding_dim
  15. )
  16. self.embedding_genre = torch.nn.Linear(
  17. in_features=self.num_genre,
  18. out_features=self.embedding_dim,
  19. bias=False
  20. )
  21. self.embedding_director = torch.nn.Linear(
  22. in_features=self.num_director,
  23. out_features=self.embedding_dim,
  24. bias=False
  25. )
  26. self.embedding_actor = torch.nn.Linear(
  27. in_features=self.num_actor,
  28. out_features=self.embedding_dim,
  29. bias=False
  30. )
  31. def forward(self, rate_idx, genre_idx, director_idx, actors_idx, vars=None):
  32. rate_emb = self.embedding_rate(rate_idx)
  33. genre_emb = self.embedding_genre(genre_idx.float()) / torch.sum(genre_idx.float(), 1).view(-1, 1)
  34. director_emb = self.embedding_director(director_idx.float()) / torch.sum(director_idx.float(), 1).view(-1, 1)
  35. actors_emb = self.embedding_actor(actors_idx.float()) / torch.sum(actors_idx.float(), 1).view(-1, 1)
  36. return torch.cat((rate_emb, genre_emb, director_emb, actors_emb), 1)
  37. class user(torch.nn.Module):
  38. def __init__(self, config):
  39. super(user, self).__init__()
  40. self.num_gender = config['num_gender']
  41. self.num_age = config['num_age']
  42. self.num_occupation = config['num_occupation']
  43. self.num_zipcode = config['num_zipcode']
  44. self.embedding_dim = config['embedding_dim']
  45. self.embedding_gender = torch.nn.Embedding(
  46. num_embeddings=self.num_gender,
  47. embedding_dim=self.embedding_dim
  48. )
  49. self.embedding_age = torch.nn.Embedding(
  50. num_embeddings=self.num_age,
  51. embedding_dim=self.embedding_dim
  52. )
  53. self.embedding_occupation = torch.nn.Embedding(
  54. num_embeddings=self.num_occupation,
  55. embedding_dim=self.embedding_dim
  56. )
  57. self.embedding_area = torch.nn.Embedding(
  58. num_embeddings=self.num_zipcode,
  59. embedding_dim=self.embedding_dim
  60. )
  61. def forward(self, gender_idx, age_idx, occupation_idx, area_idx):
  62. gender_emb = self.embedding_gender(gender_idx)
  63. age_emb = self.embedding_age(age_idx)
  64. occupation_emb = self.embedding_occupation(occupation_idx)
  65. area_emb = self.embedding_area(area_idx)
  66. return torch.cat((gender_emb, age_emb, occupation_emb, area_emb), 1)