|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
class item(torch.nn.Module):
    """Embed item features and concatenate them into a single vector per row.

    The rating category is looked up in a ``torch.nn.Embedding`` table. The
    genre, director and actor inputs are indicator (multi-hot) vectors. Each
    goes through a bias-free ``torch.nn.Linear`` layer, which sums the
    embedding columns selected by the active entries. The result is then
    divided by the number of active entries, giving their mean embedding.

    Args:
        config: mapping that provides ``num_rate``, ``num_genre``,
            ``num_director``, ``num_actor`` and ``embedding_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_rate = config['num_rate']
        self.num_genre = config['num_genre']
        self.num_director = config['num_director']
        self.num_actor = config['num_actor']
        self.embedding_dim = config['embedding_dim']

        # Registration order (rate, genre, director, actor) fixes the order of
        # parameters() / state_dict() and of random initialisation; keep it.
        self.embedding_rate = torch.nn.Embedding(
            num_embeddings=self.num_rate,
            embedding_dim=self.embedding_dim,
        )
        self.embedding_genre = torch.nn.Linear(
            in_features=self.num_genre,
            out_features=self.embedding_dim,
            bias=False,
        )
        self.embedding_director = torch.nn.Linear(
            in_features=self.num_director,
            out_features=self.embedding_dim,
            bias=False,
        )
        self.embedding_actor = torch.nn.Linear(
            in_features=self.num_actor,
            out_features=self.embedding_dim,
            bias=False,
        )

    @staticmethod
    def _mean_pool(layer, multi_hot):
        """Return the mean embedding of the active entries in each row of *multi_hot*.

        ``layer`` is a bias-free Linear, so ``layer(x)`` is the weighted sum of
        the embedding columns selected by ``x``. Dividing by the row sum turns
        that into a mean. A row with no active entries would divide 0 by 0 and
        produce NaN, so its divisor is replaced by 1 and the row maps to the
        zero vector instead. All other rows are left numerically unchanged.
        """
        values = multi_hot.float()
        counts = values.sum(dim=1, keepdim=True)
        counts = counts.masked_fill(counts == 0, 1.0)
        return layer(values) / counts

    def forward(self, rate_idx, genre_idx, director_idx, actors_idx, vars=None):
        """Embed one batch of items.

        Assumes ``rate_idx`` is a 1-D index tensor and the other three inputs
        are 2-D ``(batch, num_<feature>)`` indicator tensors. The per-row sum
        over dim 1 and the concatenation along dim 1 rely on this; confirm
        against callers.

        Args:
            rate_idx: indices into the rating embedding table.
            genre_idx: multi-hot genre indicators.
            director_idx: multi-hot director indicators.
            actors_idx: multi-hot actor indicators.
            vars: unused. Kept so existing callers that pass it keep working
                (NOTE(review): presumably an explicit-weights hook; confirm).

        Returns:
            The rating, genre, director and actor embeddings concatenated along
            dim 1.
        """
        rate_emb = self.embedding_rate(rate_idx)
        genre_emb = self._mean_pool(self.embedding_genre, genre_idx)
        director_emb = self._mean_pool(self.embedding_director, director_idx)
        actors_emb = self._mean_pool(self.embedding_actor, actors_idx)
        return torch.cat((rate_emb, genre_emb, director_emb, actors_emb), 1)
-
-
class user(torch.nn.Module):
    """Embed user categorical features and concatenate them per row.

    Gender, age, occupation and zipcode-area indices each have their own
    ``torch.nn.Embedding`` table of width ``config['embedding_dim']``.
    ``forward`` looks every index up and joins the four vectors along dim 1.

    Args:
        config: mapping that provides ``num_gender``, ``num_age``,
            ``num_occupation``, ``num_zipcode`` and ``embedding_dim``.
    """

    def __init__(self, config):
        super().__init__()
        (self.num_gender,
         self.num_age,
         self.num_occupation,
         self.num_zipcode,
         self.embedding_dim) = (
            config[key]
            for key in ('num_gender', 'num_age', 'num_occupation',
                        'num_zipcode', 'embedding_dim')
        )

        width = self.embedding_dim
        # Tables are created in gender, age, occupation, area order. This fixes
        # the order of parameters() / state_dict() and of random initialisation.
        self.embedding_gender = torch.nn.Embedding(self.num_gender, width)
        self.embedding_age = torch.nn.Embedding(self.num_age, width)
        self.embedding_occupation = torch.nn.Embedding(self.num_occupation, width)
        self.embedding_area = torch.nn.Embedding(self.num_zipcode, width)

    def forward(self, gender_idx, age_idx, occupation_idx, area_idx):
        """Return the gender, age, occupation and area embeddings joined on dim 1."""
        pairs = (
            (self.embedding_gender, gender_idx),
            (self.embedding_age, age_idx),
            (self.embedding_occupation, occupation_idx),
            (self.embedding_area, area_idx),
        )
        return torch.cat([table(idx) for table, idx in pairs], 1)
|