import torch
from copy import deepcopy
from torch.autograd import Variable  # no-op wrapper since PyTorch 0.4; kept for compatibility
from torch.nn import functional as F
from collections import OrderedDict

from embeddings import item, user


class user_preference_estimator(torch.nn.Module):
    """Rating predictor: embeds item and user features, then regresses a rating."""

    def __init__(self, config):
        super(user_preference_estimator, self).__init__()
        self.embedding_dim = config['embedding_dim']
        self.fc1_in_dim = config['embedding_dim'] * 8
        self.fc2_in_dim = config['first_fc_hidden_dim']
        self.fc2_out_dim = config['second_fc_hidden_dim']
        self.use_cuda = config['use_cuda']

        self.item_emb = item(config)
        self.user_emb = user(config)
        self.fc1 = torch.nn.Linear(self.fc1_in_dim, self.fc2_in_dim)
        self.fc2 = torch.nn.Linear(self.fc2_in_dim, self.fc2_out_dim)
        self.linear_out = torch.nn.Linear(self.fc2_out_dim, 1)

    def forward(self, x, training=True):  # `training` is unused; kept for API compatibility
        # Slice the concatenated feature vector into its categorical fields.
        rate_idx = Variable(x[:, 0], requires_grad=False)
        genre_idx = Variable(x[:, 1:26], requires_grad=False)
        director_idx = Variable(x[:, 26:2212], requires_grad=False)
        actor_idx = Variable(x[:, 2212:10242], requires_grad=False)
        gender_idx = Variable(x[:, 10242], requires_grad=False)
        age_idx = Variable(x[:, 10243], requires_grad=False)
        occupation_idx = Variable(x[:, 10244], requires_grad=False)
        area_idx = Variable(x[:, 10245], requires_grad=False)

        item_emb = self.item_emb(rate_idx, genre_idx, director_idx, actor_idx)
        user_emb = self.user_emb(gender_idx, age_idx, occupation_idx, area_idx)
        x = torch.cat((item_emb, user_emb), 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.linear_out(x)
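
# --- Editor's sketch (not part of the original code) ---
# MeLU below performs a MAML-style local update: it takes gradients of the
# support-set loss with create_graph=True so that, in principle, the
# meta-optimizer can backpropagate through the adaptation step. A minimal,
# self-contained sketch of that fast-weight pattern on a toy linear model
# (all names here are illustrative, not part of MeLU):
def _maml_local_update_sketch(inner_lr=0.1):
    toy = torch.nn.Linear(3, 1)
    x, y = torch.randn(8, 3), torch.randn(8, 1)
    loss = F.mse_loss(toy(x), y)
    # create_graph=True keeps the gradients themselves differentiable.
    grads = torch.autograd.grad(loss, toy.parameters(), create_graph=True)
    # Fast weights: one SGD step, still connected to the original parameters.
    fast_weights = OrderedDict(
        (name, p - inner_lr * g)
        for (name, p), g in zip(toy.named_parameters(), grads)
    )
    return fast_weights
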
class MeLU(torch.nn.Module):
    """MAML-style meta-learner wrapped around user_preference_estimator."""

    def __init__(self, config):
        super(MeLU, self).__init__()
        self.use_cuda = config['use_cuda']
        self.model = user_preference_estimator(config)
        self.local_lr = config['local_lr']
        self.store_parameters()
        self.meta_optim = torch.optim.Adam(self.model.parameters(), lr=config['lr'])
        # Only the fully connected head is adapted during the local update;
        # the embedding tables keep their meta-learned values.
        self.local_update_target_weight_name = [
            'fc1.weight', 'fc1.bias',
            'fc2.weight', 'fc2.bias',
            'linear_out.weight', 'linear_out.bias',
        ]

    def store_parameters(self):
        self.keep_weight = deepcopy(self.model.state_dict())
        self.weight_name = list(self.keep_weight.keys())
        self.weight_len = len(self.keep_weight)
        self.fast_weights = OrderedDict()

    def forward(self, support_set_x, support_set_y, query_set_x, num_local_update):
        # This line was added by maheri: snapshot the meta-parameters before
        # local adaptation so they can be restored afterwards.
        self.keep_weight = deepcopy(self.model.state_dict())

        for idx in range(num_local_update):
            if idx > 0:
                self.model.load_state_dict(self.fast_weights)
            weight_for_local_update = list(self.model.state_dict().values())
            support_set_y_pred = self.model(support_set_x)
            loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1))
            self.model.zero_grad()
            grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
            # Local update: one SGD step on the target weights, identity elsewhere.
            for i in range(self.weight_len):
                if self.weight_name[i] in self.local_update_target_weight_name:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
                else:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
        # Evaluate the query set with the adapted weights, then restore.
        self.model.load_state_dict(self.fast_weights)
        query_set_y_pred = self.model(query_set_x)
        self.model.load_state_dict(self.keep_weight)
        del weight_for_local_update, loss, grad, support_set_y_pred
        return query_set_y_pred

    def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys, num_local_update):
        batch_sz = len(support_set_xs)
        losses_q = []
        if self.use_cuda:
            for i in range(batch_sz):
                support_set_xs[i] = support_set_xs[i].cuda()
                support_set_ys[i] = support_set_ys[i].cuda()
                query_set_xs[i] = query_set_xs[i].cuda()
                query_set_ys[i] = query_set_ys[i].cuda()
        for i in range(batch_sz):
            query_set_y_pred = self.forward(support_set_xs[i], support_set_ys[i], query_set_xs[i], num_local_update)
            loss_q = F.mse_loss(query_set_y_pred, query_set_ys[i].view(-1, 1))
            losses_q.append(loss_q)
        # Meta-update: average the post-adaptation query losses across the batch.
        losses_q = torch.stack(losses_q).mean(0)
        self.meta_optim.zero_grad()
        losses_q.backward()
        self.meta_optim.step()
        self.store_parameters()
        return

    def get_weight_avg_norm(self, support_set_x, support_set_y, num_local_update):
        tmp = 0.
        if self.use_cuda:  # was `if self.cuda():`, which moved the module to GPU and always evaluated truthy
            support_set_x = support_set_x.cuda()
            support_set_y = support_set_y.cuda()
        for idx in range(num_local_update):
            if idx > 0:
                self.model.load_state_dict(self.fast_weights)
            weight_for_local_update = list(self.model.state_dict().values())
            support_set_y_pred = self.model(support_set_x)
            loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1))
            # Normalize to unit loss; .tolist() detaches the norm to a Python float.
            loss /= torch.norm(loss).tolist()
            self.model.zero_grad()
            grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
            for i in range(self.weight_len):
                # Accumulate Frobenius norms of the gradients for averaging.
                tmp += torch.norm(grad[i])
                if self.weight_name[i] in self.local_update_target_weight_name:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
                else:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
        return tmp / num_local_update
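

# --- Usage sketch (editor's addition; key names and values are assumptions) ---
# This file itself reads only 'embedding_dim', 'first_fc_hidden_dim',
# 'second_fc_hidden_dim', 'use_cuda', 'local_lr', and 'lr' from `config`.
# The embedding-table sizes below follow the reference MeLU config and the
# slice widths in user_preference_estimator.forward; verify them against the
# local embeddings.py before relying on this block.
if __name__ == '__main__':
    config = {
        # item features
        'num_rate': 6, 'num_genre': 25, 'num_director': 2186, 'num_actor': 8030,
        # user features
        'num_gender': 2, 'num_age': 7, 'num_occupation': 21, 'num_zipcode': 3402,
        # network and optimization (illustrative values)
        'embedding_dim': 32, 'first_fc_hidden_dim': 64, 'second_fc_hidden_dim': 64,
        'use_cuda': False, 'local_lr': 5e-6, 'lr': 5e-5,
    }
    melu = MeLU(config)
    # One meta-batch: lists with one entry per user, as produced by the MeLU
    # data pipeline (omitted here); each x is a feature matrix of width 10246,
    # each y the matching rating vector.
    # melu.global_update(support_set_xs, support_set_ys,
    #                    query_set_xs, query_set_ys, num_local_update=5)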