import torch
from copy import deepcopy
from collections import OrderedDict

from torch.nn import functional as F

from embeddings import item, user


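# user_preference_estimator is the base recommender: one-hot item and user
# features are embedded, concatenated, and passed through two hidden layers
# to regress a rating. The slicing in forward() assumes the MovieLens-style
# feature layout (10246 columns) produced by the MeLU data preparation.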
class user_preference_estimator(torch.nn.Module):
    def __init__(self, config):
        super(user_preference_estimator, self).__init__()
        self.embedding_dim = config['embedding_dim']
        self.fc1_in_dim = config['embedding_dim'] * 8
        self.fc2_in_dim = config['first_fc_hidden_dim']
        self.fc2_out_dim = config['second_fc_hidden_dim']
        self.use_cuda = config['use_cuda']

        self.item_emb = item(config)
        self.user_emb = user(config)
        self.fc1 = torch.nn.Linear(self.fc1_in_dim, self.fc2_in_dim)
        self.fc2 = torch.nn.Linear(self.fc2_in_dim, self.fc2_out_dim)
        self.linear_out = torch.nn.Linear(self.fc2_out_dim, 1)
    def forward(self, x):
        # Slice the concatenated one-hot input into its feature fields:
        # item side (rate, genre, director, actor), then user side
        # (gender, age, occupation, area).
        rate_idx = x[:, 0]
        genre_idx = x[:, 1:26]
        director_idx = x[:, 26:2212]
        actor_idx = x[:, 2212:10242]
        gender_idx = x[:, 10242]
        age_idx = x[:, 10243]
        occupation_idx = x[:, 10244]
        area_idx = x[:, 10245]

        item_emb = self.item_emb(rate_idx, genre_idx, director_idx, actor_idx)
        user_emb = self.user_emb(gender_idx, age_idx, occupation_idx, area_idx)
        x = torch.cat((item_emb, user_emb), 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.linear_out(x)


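# MeLU wraps the estimator in a MAML-style meta-learning loop: forward()
# adapts the decision layers to a user's support set with a few gradient
# steps (local update), and global_update() back-propagates the query-set
# loss through those steps to update all parameters (global update).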
class MeLU(torch.nn.Module):
    def __init__(self, config):
        super(MeLU, self).__init__()
        self.use_cuda = config['use_cuda']
        self.model = user_preference_estimator(config)
        self.local_lr = config['local_lr']
        self.store_parameters()
        self.meta_optim = torch.optim.Adam(self.model.parameters(), lr=config['lr'])
        # Only the fully connected decision layers are adapted locally; the
        # embeddings are updated only by the global step.
        self.local_update_target_weight_name = ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'linear_out.weight', 'linear_out.bias']

    def store_parameters(self):
        # Snapshot the current global parameters so they can be restored
        # after task-specific adaptation.
        self.keep_weight = deepcopy(self.model.state_dict())
        self.weight_name = list(self.keep_weight.keys())
        self.weight_len = len(self.keep_weight)
        self.fast_weights = OrderedDict()
    def forward(self, support_set_x, support_set_y, query_set_x, num_local_update):
        # Refresh the snapshot so local updates start from the current
        # global parameters.
        self.keep_weight = deepcopy(self.model.state_dict())

        for idx in range(num_local_update):
            if idx > 0:
                self.model.load_state_dict(self.fast_weights)
            weight_for_local_update = list(self.model.state_dict().values())

            support_set_y_pred = self.model(support_set_x)
            loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1))
            self.model.zero_grad()
            # create_graph=True keeps the inner-loop graph so the meta
            # gradient can later flow through these local updates.
            grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)

            # Local update: one gradient step on the target layers only.
            for i in range(self.weight_len):
                if self.weight_name[i] in self.local_update_target_weight_name:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
                else:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]

        # Predict on the query set with the adapted weights, then restore
        # the global parameters.
        self.model.load_state_dict(self.fast_weights)
        query_set_y_pred = self.model(query_set_x)
        self.model.load_state_dict(self.keep_weight)

        del weight_for_local_update, loss, grad, support_set_y_pred
        return query_set_y_pred
    def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys, num_local_update):
        batch_sz = len(support_set_xs)
        losses_q = []
        if self.use_cuda:
            for i in range(batch_sz):
                support_set_xs[i] = support_set_xs[i].cuda()
                support_set_ys[i] = support_set_ys[i].cuda()
                query_set_xs[i] = query_set_xs[i].cuda()
                query_set_ys[i] = query_set_ys[i].cuda()
        for i in range(batch_sz):
            query_set_y_pred = self.forward(support_set_xs[i], support_set_ys[i], query_set_xs[i], num_local_update)
            loss_q = F.mse_loss(query_set_y_pred, query_set_ys[i].view(-1, 1))
            losses_q.append(loss_q)
        # Global (meta) update: average the query losses over the task batch
        # and back-propagate through the local updates.
        losses_q = torch.stack(losses_q).mean(0)
        self.meta_optim.zero_grad()
        losses_q.backward()
        self.meta_optim.step()
        self.store_parameters()
    def get_weight_avg_norm(self, support_set_x, support_set_y, num_local_update):
        # Returns the gradient norm accumulated over the local updates,
        # averaged over the number of updates.
        tmp = 0.
        if self.use_cuda:  # was `if self.cuda():`, which moved the module to GPU and was always truthy
            support_set_x = support_set_x.cuda()
            support_set_y = support_set_y.cuda()
        for idx in range(num_local_update):
            if idx > 0:
                self.model.load_state_dict(self.fast_weights)
            weight_for_local_update = list(self.model.state_dict().values())
            support_set_y_pred = self.model(support_set_x)
            loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1))
            # Normalize to unit loss so gradient magnitudes are comparable.
            loss = loss / torch.norm(loss).item()
            self.model.zero_grad()
            grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
            for i in range(self.weight_len):
                # Accumulate the Frobenius norm of each gradient for averaging.
                tmp += torch.norm(grad[i])
                if self.weight_name[i] in self.local_update_target_weight_name:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
                else:
                    self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
        return tmp / num_local_update
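

# A minimal usage sketch (an assumption, not part of the original file). The
# config keys shown are the ones this module reads; `embeddings.item` and
# `embeddings.user` consume additional keys (vocabulary sizes, etc.) that are
# omitted here, and the learning-rate values are placeholders.
#
#     config = {
#         'embedding_dim': 32,
#         'first_fc_hidden_dim': 64,
#         'second_fc_hidden_dim': 64,
#         'use_cuda': False,
#         'lr': 5e-5,        # global (meta) learning rate
#         'local_lr': 5e-6,  # local (task-specific) learning rate
#     }
#     melu = MeLU(config)
#     # support/query sets are lists of per-task tensors; one meta step with
#     # five local updates per task:
#     melu.global_update(support_set_xs, support_set_ys,
#                        query_set_xs, query_set_ys, num_local_update=5)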