import torch
from copy import deepcopy
from torch.autograd import Variable
from torch.nn import functional as F
from collections import OrderedDict

from embeddings import item, user


class EmbeddingModule(torch.nn.Module):
    """Embed a packed item/user feature row and concatenate the two embeddings.

    Each row of the input is split by fixed column ranges into item features
    (rate, genre, director, actor) and user features (gender, age, occupation,
    area). These are fed to the project's ``item`` and ``user`` modules, and
    the two outputs are concatenated along dimension 1.
    """

    # Column layout of the packed input ``x``. Single columns are plain ints so
    # that ``x[:, COL]`` drops the column dimension, exactly like ``x[:, 0]``.
    RATE_COL = 0
    GENRE_COLS = slice(1, 26)
    DIRECTOR_COLS = slice(26, 2212)
    ACTOR_COLS = slice(2212, 10242)
    GENDER_COL = 10242
    AGE_COL = 10243
    OCCUPATION_COL = 10244
    AREA_COL = 10245

    def __init__(self, config):
        """Build the item and user embedding sub-modules.

        Args:
            config: mapping read with ``config['embedding_dim']`` and
                ``config['use_cuda']``; it is also passed unchanged to
                ``item(config)`` and ``user(config)``.
        """
        super().__init__()
        # Stored for use by callers/other code; not read by forward() below.
        self.embedding_dim = config['embedding_dim']
        self.use_cuda = config['use_cuda']
        self.item_emb = item(config)
        self.user_emb = user(config)

    def forward(self, x, training=True):
        """Split ``x`` into feature columns, embed them, and concatenate.

        Args:
            x: 2-D tensor indexed as ``x[:, col]``; it must have at least
                ``AREA_COL + 1`` (10246) columns. Presumably it holds integer
                indices for embedding lookups (TODO confirm with callers).
            training: accepted for interface compatibility; not used here.

        Returns:
            ``torch.cat((item_emb, user_emb), 1)``, where ``item_emb`` and
            ``user_emb`` are the outputs of the item and user sub-modules.
        """
        # Variable(t, requires_grad=False) is deprecated; since PyTorch 0.4 it
        # returns a tensor detached from the graph, which .detach() gives
        # directly. Detaching once up front covers every slice below.
        idx = x.detach()

        rate_idx = idx[:, self.RATE_COL]
        genre_idx = idx[:, self.GENRE_COLS]
        director_idx = idx[:, self.DIRECTOR_COLS]
        actor_idx = idx[:, self.ACTOR_COLS]

        gender_idx = idx[:, self.GENDER_COL]
        age_idx = idx[:, self.AGE_COL]
        occupation_idx = idx[:, self.OCCUPATION_COL]
        area_idx = idx[:, self.AREA_COL]

        item_emb = self.item_emb(rate_idx, genre_idx, director_idx, actor_idx)
        user_emb = self.user_emb(gender_idx, age_idx, occupation_idx, area_idx)
        return torch.cat((item_emb, user_emb), 1)