import torch
from copy import deepcopy
from torch.autograd import Variable
from torch.nn import functional as F
from collections import OrderedDict

from embeddings import item, user


class EmbeddingModule(torch.nn.Module):
    def __init__(self, config):
        super(EmbeddingModule, self).__init__()
        self.embedding_dim = config['embedding_dim']
        self.use_cuda = config['use_cuda']

        # Item-side and user-side embedding layers defined in embeddings.py.
        self.item_emb = item(config)
        self.user_emb = user(config)

    def forward(self, x, training=True):
        # x packs all item and user features into one row per example.
        # Column layout (widths follow from the slices below):
        #   0            : rate              (1 column)
        #   1:26         : genre             (25 columns, multi-hot)
        #   26:2212      : director          (2186 columns, multi-hot)
        #   2212:10242   : actor             (8030 columns, multi-hot)
        #   10242        : gender id
        #   10243        : age id
        #   10244        : occupation id
        #   10245        : area id
        # Variable(..., requires_grad=False) is a legacy wrapper; in
        # PyTorch >= 0.4 it simply returns the underlying tensor.
        rate_idx = Variable(x[:, 0], requires_grad=False)
        genre_idx = Variable(x[:, 1:26], requires_grad=False)
        director_idx = Variable(x[:, 26:2212], requires_grad=False)
        actor_idx = Variable(x[:, 2212:10242], requires_grad=False)
        gender_idx = Variable(x[:, 10242], requires_grad=False)
        age_idx = Variable(x[:, 10243], requires_grad=False)
        occupation_idx = Variable(x[:, 10244], requires_grad=False)
        area_idx = Variable(x[:, 10245], requires_grad=False)

        # Embed item and user features separately, then concatenate them
        # along the feature dimension.
        item_emb = self.item_emb(rate_idx, genre_idx, director_idx, actor_idx)
        user_emb = self.user_emb(gender_idx, age_idx, occupation_idx, area_idx)
        x = torch.cat((item_emb, user_emb), 1)
        return x
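# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the module above).
# It builds a random batch laid out exactly as forward() expects and checks
# the total width. The category cardinalities used here (6 rates, 2 genders,
# 7 ages, 21 occupations, 3402 areas) are assumptions for the sketch; the
# config keys beyond 'embedding_dim' and 'use_cuda' depend on what
# embeddings.item and embeddings.user actually read.
if __name__ == '__main__':
    batch_size = 4

    # Item features: one rate id followed by three multi-hot blocks.
    rate = torch.randint(0, 6, (batch_size, 1))
    genre = torch.randint(0, 2, (batch_size, 25))
    director = torch.randint(0, 2, (batch_size, 2186))
    actor = torch.randint(0, 2, (batch_size, 8030))

    # User features: one categorical id per column.
    gender = torch.randint(0, 2, (batch_size, 1))
    age = torch.randint(0, 7, (batch_size, 1))
    occupation = torch.randint(0, 21, (batch_size, 1))
    area = torch.randint(0, 3402, (batch_size, 1))

    x = torch.cat((rate, genre, director, actor,
                   gender, age, occupation, area), dim=1)
    assert x.shape == (batch_size, 10246)

    # With a config holding 'embedding_dim', 'use_cuda', and whatever keys
    # embeddings.item / embeddings.user expect, the batch is embedded with:
    #   model = EmbeddingModule(config)
    #   out = model(x)   # (batch_size, item embedding width + user embedding width)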