import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from code.layer import GraphConvolution


def my_sigmoid(mx, dim):
    # Element-wise sigmoid followed by L1 normalization along `dim`, so the
    # outputs are positive and sum to 1 (a soft sample weighting).
    mx = torch.sigmoid(mx)
    return F.normalize(mx, p=1, dim=dim)
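
# Example: my_sigmoid(torch.zeros(3), dim=0) -> tensor([1/3, 1/3, 1/3]):
# sigmoid(0) = 0.5 for each entry, and F.normalize(p=1) rescales to unit sum.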


class Discriminator(nn.Module):
    # Stacked GraphConvolution layers; when nclass > 1 the last layer is a
    # classifier head that emits log-softmax class scores.
    def __init__(self, nfeat, nhid, nclass, dropout, order):
        super(Discriminator, self).__init__()

        layers = []
        if len(nhid) == 0:
            layers.append(GraphConvolution(nfeat, nclass, order=order))
        else:
            layers.append(GraphConvolution(nfeat, nhid[0], order=order))
            for i in range(len(nhid) - 1):
                layers.append(GraphConvolution(nhid[i], nhid[i + 1], order=order))
            if nclass > 1:
                layers.append(GraphConvolution(nhid[-1], nclass, order=order))
        self.gc = nn.ModuleList(layers)

        self.dropout = dropout
        self.nclass = nclass

    def forward(self, x, adj, samples=-1, func=F.relu):
        # Every layer except the classifier head: dropout -> GC -> activation.
        end_layer = len(self.gc) - 1 if self.nclass > 1 else len(self.gc)
        for i in range(end_layer):
            x = F.dropout(x, self.dropout, training=self.training)
            x = self.gc[i](x, adj)
            x = func(x)

        if self.nclass > 1:
            classifier = self.gc[-1](x, adj)
            classifier = F.log_softmax(classifier, dim=1)
            # `samples` should be a 1-D index tensor selecting the nodes whose
            # log-probabilities are returned alongside the shared embedding.
            return classifier[samples, :], x
        else:
            return None, x


class Weighing(nn.Module):
    def __init__(self, nfeat, nhid, weighing_output_dim, dropout, order):
        super(Weighing, self).__init__()

        layers = []
        layers.append(GraphConvolution(nfeat, nhid[0], order=order))
        for i in range(len(nhid) - 1):
            layers.append(GraphConvolution(nhid[i], nhid[i + 1], order=order))
        layers.append(GraphConvolution(nhid[-1], weighing_output_dim, order=order))
        self.gc = nn.ModuleList(layers)

        self.dropout = dropout

    def forward(self, x, adj, func, samples=-1):
        end_layer = len(self.gc) - 1
        for i in range(end_layer):
            x = self.gc[i](x, adj)
            x = F.leaky_relu(x, inplace=False)
            x = F.dropout(x, self.dropout, training=self.training)

        # An int `samples` (the -1 default) is a sentinel meaning "weigh all
        # rows"; otherwise only the indexed rows are scored. `func` normalizes
        # the raw scores (e.g. my_sigmoid) along the sample dimension.
        if type(samples) is int:
            weights = func(self.gc[-1](x, adj), dim=0)
        else:
            weights = func(self.gc[-1](x, adj)[samples, :], dim=0)

        # Keep a consistent 2-D (samples x output_dim) shape.
        if len(weights.shape) == 1:
            weights = weights[None, :]

        return weights, x


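# RAGCN pairs one GCN discriminator with per-class GCN weighing networks and
# trains them adversarially: the discriminator minimizes a sample-weighted
# classification loss while the weighing nets reweight samples to maximize it
# (see run_both).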
class RAGCN:
    def __init__(self, features, adj, nclass, struc_D, struc_Ws, n_ws, weighing_output_dim, act, gamma=1):
        nfeat = features.shape[1]
        # adj['D'] may be a list of adjacency powers (one per propagation
        # order) or a single tensor; the same holds for each entry of adj['W'].
        if type(adj['D']) is list:
            order_D = len(adj['D'])
        else:
            order_D = 1
        if type(adj['W'][0]) is list:
            order_W = len(adj['W'][0])
        else:
            order_W = 1
        self.n_ws = n_ws
        self.wod = weighing_output_dim
        self.net_D = Discriminator(nfeat, struc_D['nhid'], nclass, struc_D['dropout'], order=order_D)
        self.gamma = gamma
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=struc_D['lr'], weight_decay=struc_D['wd'])
        # One Weighing network (and optimizer) per class.
        self.net_Ws = []
        self.opt_Ws = []
        for i in range(n_ws):
            self.net_Ws.append(Weighing(nfeat=nfeat, nhid=struc_Ws[i]['nhid'],
                                        weighing_output_dim=weighing_output_dim,
                                        dropout=struc_Ws[i]['dropout'], order=order_W))
            self.opt_Ws.append(optim.Adam(self.net_Ws[-1].parameters(),
                                          lr=struc_Ws[i]['lr'], weight_decay=struc_Ws[i]['wd']))
        self.adj_D = adj['D']
        self.adj_W = adj['W']
        self.features = features
        self.act = act

    def run_D(self, samples):
        class_prob, embed = self.net_D(self.features, self.adj_D, samples)
        return class_prob, embed

    def run_W(self, samples, labels, args_cuda=False, equal_weights=False):
        batch_size = samples.shape[0]
        embed = None
        if equal_weights:
            # Uniform weights within each class, so every class contributes
            # equally to the loss regardless of its size.
            max_label = int(labels.max().item() + 1)
            weight = torch.empty(batch_size)
            for i in range(max_label):
                labels_indices = (labels == i).nonzero().squeeze()
                if labels_indices.numel() == 0:
                    continue  # class absent from this batch
                class_size = 1 if len(labels_indices.shape) == 0 else len(labels_indices)
                weight[labels_indices] = torch.ones(class_size) / class_size
            weight = weight / max_label
        else:
            # Learned weights: each class is scored by its own Weighing net on
            # its class-specific adjacency.
            max_label = int(labels.max().item() + 1)
            weight = torch.empty(batch_size)
            if args_cuda:
                weight = weight.cuda()
            for i in range(max_label):
                labels_indices = (labels == i).nonzero().squeeze()
                if labels_indices.numel() == 0:
                    continue  # class absent from this batch
                sub_samples = samples[labels_indices]
                weight_, embed = self.net_Ws[i](x=self.features[sub_samples, :],
                                                adj=self.adj_W[i], samples=-1, func=self.act)
                weight[labels_indices] = weight_.squeeze() if self.wod == 1 else weight_[:, i]
            weight = weight / max_label

        if args_cuda:
            weight = weight.cuda()

        return weight, embed

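    # loss_function_D is the sample-weighted NLL: `output` holds log-softmax
    # scores and `labels` is one-hot, so (labels * output).sum(1) is the
    # per-sample log-likelihood. loss_function_G adds gamma * H(weights),
    # where H(w) = -sum_i w_i log(w_i); run_both minimizes its negation, so
    # the weighing nets raise the discriminator's loss while the entropy term
    # keeps the weights from collapsing onto a few samples.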
    def loss_function_D(self, output, labels, weights):
        return torch.sum(-weights * (labels.float() * output).sum(1), -1)

    def loss_function_G(self, output, labels, weights):
        return (torch.sum(-weights * (labels.float() * output).sum(1), -1)
                - self.gamma * torch.sum(weights * torch.log(weights + 1e-20)))

    def zero_grad_both(self):
        self.opt_D.zero_grad()
        for opt in self.opt_Ws:
            opt.zero_grad()

    def run_both(self, epoch_for_D, epoch_for_W, labels_one_hot, samples=-1,
                 args_cuda=False, equal_weights=False):
        labels_not_onehot = labels_one_hot.max(1)[1].type_as(labels_one_hot)
        # Discriminator phase: minimize the weighted classification loss.
        for e_D in range(epoch_for_D):
            self.zero_grad_both()

            class_prob_1, embed = self.run_D(samples)
            weights_1, _ = self.run_W(samples=samples, labels=labels_not_onehot,
                                      args_cuda=args_cuda, equal_weights=equal_weights)

            loss_D = self.loss_function_D(output=class_prob_1, labels=labels_one_hot, weights=weights_1)
            loss_D.backward()
            self.opt_D.step()

        # Weighing phase: maximize the entropy-regularized loss (note the
        # leading minus sign); only the weighing optimizers step.
        for e_W in range(epoch_for_W):
            self.zero_grad_both()

            class_prob_2, embed = self.run_D(samples)
            weights_2, embeds = self.run_W(samples=samples, labels=labels_not_onehot,
                                           args_cuda=args_cuda, equal_weights=equal_weights)

            loss_G = -self.loss_function_G(output=class_prob_2, labels=labels_one_hot, weights=weights_2)
            loss_G.backward()
            for opt in self.opt_Ws:
                opt.step()

    def train(self):
        self.net_D.train()
        for i in range(self.n_ws):
            self.net_Ws[i].train()

    def eval(self):
        self.net_D.eval()
        for i in range(self.n_ws):
            self.net_Ws[i].eval()

    def cuda(self):
        self.features = self.features.to(device='cuda')
        # adj['D'] and the entries of adj['W'] may each be a single tensor or
        # a list of tensors (one per propagation order); handle both cases.
        if type(self.adj_D) is list:
            self.adj_D = [a.to(device='cuda') for a in self.adj_D]
        else:
            self.adj_D = self.adj_D.to(device='cuda')
        self.adj_W = [[a.to(device='cuda') for a in adj] if type(adj) is list
                      else adj.to(device='cuda') for adj in self.adj_W]
        self.net_D.cuda()
        for i in range(self.n_ws):
            self.net_Ws[i].cuda()
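

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative, not part of the training pipeline).
# It assumes GraphConvolution accepts dense (n x d) features with a dense
# normalized adjacency; identity matrices stand in for real adjacencies, and
# run_W implies adj['W'][c] is the adjacency restricted to class-c nodes.
# All hyperparameters below are placeholders.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    n_nodes, nfeat, nclass = 100, 16, 2
    features = torch.randn(n_nodes, nfeat)
    labels = torch.randint(0, nclass, (n_nodes,))
    adj = {'D': torch.eye(n_nodes),
           'W': [torch.eye(int((labels == c).sum())) for c in range(nclass)]}
    struc_D = {'nhid': [8], 'dropout': 0.5, 'lr': 0.01, 'wd': 5e-4}
    struc_Ws = [{'nhid': [8], 'dropout': 0.5, 'lr': 0.01, 'wd': 5e-4}
                for _ in range(nclass)]
    labels_one_hot = F.one_hot(labels, nclass).float()

    model = RAGCN(features, adj, nclass, struc_D, struc_Ws, n_ws=nclass,
                  weighing_output_dim=1, act=my_sigmoid)
    model.train()
    model.run_both(epoch_for_D=1, epoch_for_W=1, labels_one_hot=labels_one_hot,
                   samples=torch.arange(n_nodes))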