123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- import torch
- import torch.nn as nn
- from torch.nn import init
- import torch.nn.functional as F
-
- import numpy as np
-
-
# GCN basic operation
class GraphConv(nn.Module):
    """Single dense graph-convolution layer: ``(adj @ x [+ x]) @ W [+ b]``.

    Args:
        input_dim: Size of the last (feature) axis of ``x``.
        output_dim: Size of the last axis of the output.
        add_self: If True, add ``x`` itself to the neighbor aggregation
            ``adj @ x`` before the linear transform.
        normalize_embedding: If True, L2-normalize the output along its last
            (feature) axis.
        dropout: Dropout probability applied to the input ``x``; values
            <= 0.001 disable dropout entirely (no Dropout module is created).
        bias: Whether to learn an additive bias of size ``output_dim``.

    Parameters are created on the default device with Xavier-uniform weights
    and a zero bias; move the module with ``.to(device)`` / ``.cuda()``.
    """

    def __init__(self, input_dim, output_dim, add_self=False, normalize_embedding=False,
                 dropout=0.0, bias=True):
        super(GraphConv, self).__init__()
        self.add_self = add_self
        self.dropout = dropout
        if dropout > 0.001:
            self.dropout_layer = nn.Dropout(p=dropout)
        self.normalize_embedding = normalize_embedding
        self.input_dim = input_dim
        self.output_dim = output_dim
        # No hard-coded .cuda(): the owning model decides the device, so the
        # layer can also be built and run on CPU-only machines.
        self.weight = nn.Parameter(torch.empty(input_dim, output_dim, dtype=torch.float32))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_dim, dtype=torch.float32))
        else:
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the weight (Xavier-uniform, ReLU gain) and zero the bias.

        ``torch.empty`` leaves memory uninitialized, so without this a
        standalone layer would start from arbitrary (possibly non-finite) values.
        """
        init.xavier_uniform_(self.weight, gain=init.calculate_gain('relu'))
        if self.bias is not None:
            init.zeros_(self.bias)

    def forward(self, x, adj):
        """Aggregate neighbor features with ``adj`` and apply the linear map.

        Args:
            x: Node features whose last axis has size ``input_dim`` (e.g.
                ``[batch, nodes, input_dim]`` or ``[nodes, input_dim]``).
            adj: Adjacency matrix compatible with ``torch.matmul(adj, x)``.

        Returns:
            Tensor whose last axis has size ``output_dim``.
        """
        if self.dropout > 0.001:
            x = self.dropout_layer(x)
        y = torch.matmul(adj, x)
        if self.add_self:
            y = y + x
        y = torch.matmul(y, self.weight)
        if self.bias is not None:
            y = y + self.bias
        if self.normalize_embedding:
            # Last axis is the feature axis for both batched and unbatched input.
            y = F.normalize(y, p=2, dim=-1)
        return y
-
-
class GcnEncoderGraph(nn.Module):
    """Stack of GraphConv layers followed by a linear / MLP prediction head.

    The convolution stack is ``conv_first`` (input_dim -> hidden_dim),
    ``num_layers - 2`` ``conv_block`` layers (hidden_dim -> hidden_dim) and
    ``conv_last`` (hidden_dim -> embedding_dim). Every layer L2-normalizes its
    output; GraphConv's self-connection is enabled exactly when ``concat`` is
    False.

    Args:
        input_dim: Feature size of the input node features.
        hidden_dim: Output size of every convolution except the last.
        embedding_dim: Output size of the last convolution.
        label_dim: Output size of the prediction head.
        num_layers: Total number of graph-convolution layers.
        pred_hidden_dims: Hidden sizes of the prediction MLP. ``None`` or an
            empty sequence yields a single ``nn.Linear`` head.
        concat: If True, the head's input size assumes all layer outputs are
            concatenated; otherwise it is ``embedding_dim``.
        bn: Apply batch normalization after each hidden (non-final) layer.
        dropout: Dropout probability, passed to the ``conv_block`` layers only.
        args: Optional namespace; only ``args.bias`` is read here.
    """

    def __init__(self, input_dim, hidden_dim, embedding_dim, label_dim, num_layers,
                 pred_hidden_dims=None, concat=True, bn=True, dropout=0.0, args=None):
        super(GcnEncoderGraph, self).__init__()
        # None instead of a shared mutable [] default.
        if pred_hidden_dims is None:
            pred_hidden_dims = []
        self.concat = concat
        add_self = not concat
        self.bn = bn
        self.num_layers = num_layers
        self.num_aggs = 1

        self.bias = True
        if args is not None:
            self.bias = args.bias

        self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers(
            input_dim, hidden_dim, embedding_dim, num_layers,
            add_self, normalize=True, dropout=dropout)
        # Must exist before build_pred_layers, which inserts it between layers.
        self.act = nn.ReLU()
        self.label_dim = label_dim

        if concat:
            self.pred_input_dim = hidden_dim * (num_layers - 1) + embedding_dim
        else:
            self.pred_input_dim = embedding_dim
        self.pred_model = self.build_pred_layers(self.pred_input_dim, pred_hidden_dims,
                                                 label_dim, num_aggs=self.num_aggs)

        # Re-initialize every GraphConv: Xavier-uniform weights, zero bias.
        gain = init.calculate_gain('relu')
        for m in self.modules():
            if isinstance(m, GraphConv):
                init.xavier_uniform_(m.weight, gain=gain)
                if m.bias is not None:
                    init.constant_(m.bias, 0.0)

    def build_conv_layers(self, input_dim, hidden_dim, embedding_dim, num_layers, add_self,
                          normalize=False, dropout=0.0):
        """Build the (first, middle, last) GraphConv layers.

        Only the middle ``conv_block`` layers receive ``dropout``; the first
        and last layers are built without it.

        Returns:
            Tuple ``(conv_first, conv_block, conv_last)`` where ``conv_block``
            is an ``nn.ModuleList`` of ``num_layers - 2`` layers.
        """
        conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim, add_self=add_self,
                               normalize_embedding=normalize, bias=self.bias)
        conv_block = nn.ModuleList(
            [GraphConv(input_dim=hidden_dim, output_dim=hidden_dim, add_self=add_self,
                       normalize_embedding=normalize, dropout=dropout, bias=self.bias)
             for _ in range(num_layers - 2)])
        conv_last = GraphConv(input_dim=hidden_dim, output_dim=embedding_dim, add_self=add_self,
                              normalize_embedding=normalize, bias=self.bias)
        return conv_first, conv_block, conv_last

    def build_pred_layers(self, pred_input_dim, pred_hidden_dims, label_dim, num_aggs=1):
        """Build the prediction head.

        Args:
            pred_input_dim: Per-aggregation input size; multiplied by ``num_aggs``.
            pred_hidden_dims: Hidden layer sizes; ``None``/empty gives one Linear.
            label_dim: Output size.
            num_aggs: Number of aggregated feature groups fed to the head.

        Returns:
            ``nn.Linear`` or ``nn.Sequential`` of Linear layers with ``self.act``
            between them (no activation after the final layer).
        """
        pred_input_dim = pred_input_dim * num_aggs
        if pred_hidden_dims is None or len(pred_hidden_dims) == 0:
            return nn.Linear(pred_input_dim, label_dim)
        pred_layers = []
        for pred_dim in pred_hidden_dims:
            pred_layers.append(nn.Linear(pred_input_dim, pred_dim))
            pred_layers.append(self.act)
            pred_input_dim = pred_dim
        pred_layers.append(nn.Linear(pred_input_dim, label_dim))
        return nn.Sequential(*pred_layers)

    def construct_mask(self, max_nodes, batch_num_nodes, device=None):
        ''' For each num_nodes in batch_num_nodes, the first num_nodes entries of the
        corresponding row are 1's, and the rest are 0's (to be masked out).
        Dimension of mask: [batch_size x max_nodes x 1]

        Args:
            max_nodes: Padded node count (second dimension of the mask).
            batch_num_nodes: Sequence of per-graph node counts.
            device: Device for the mask. Defaults to the device of this
                model's parameters (CPU if the model has none).
        '''
        if device is None:
            # Follow the model's device instead of hard-coding CUDA.
            first_param = next(self.parameters(), None)
            device = first_param.device if first_param is not None else torch.device('cpu')
        out_tensor = torch.zeros(len(batch_num_nodes), max_nodes, device=device)
        for i, num_nodes in enumerate(batch_num_nodes):
            out_tensor[i, :int(num_nodes)] = 1.0
        return out_tensor.unsqueeze(2)

    def apply_bn(self, x):
        ''' Batch normalization of 3D tensor x, treating dim 1 as the channel axis
        (the node axis for [batch_size x num_nodes x features] input).

        Statistics always come from the current batch (also under model.eval())
        and no affine parameters are learned: this is numerically the same as a
        freshly constructed nn.BatchNorm1d(x.size(1)) in its default training
        mode, but it allocates no module and runs on whatever device x is on.
        '''
        return F.batch_norm(x, None, None, weight=None, bias=None, training=True,
                            momentum=0.1, eps=1e-5)

    def gcn_forward(self, x, adj, conv_first, conv_block, conv_last, embedding_mask=None):
        ''' Perform forward prop with graph convolution.

        Each non-final layer is followed by self.act and then (if self.bn)
        apply_bn; the last layer's output is used as-is. The outputs of ALL
        layers are concatenated along dim 2, regardless of self.concat.

        Args:
            x: Node features, [batch_size x num_nodes x input_dim].
            adj: Adjacency passed unchanged to every convolution.
            conv_first, conv_block, conv_last: Layers as built by build_conv_layers.
            embedding_mask: Optional mask (e.g. from construct_mask) multiplied
                into the result to zero out padded nodes.
        Returns:
            Embedding matrix with dimension [batch_size x num_nodes x embedding]
        '''
        x = conv_first(x, adj)
        x = self.act(x)
        if self.bn:
            x = self.apply_bn(x)
        x_all = [x]
        for conv in conv_block:
            x = conv(x, adj)
            x = self.act(x)
            if self.bn:
                x = self.apply_bn(x)
            x_all.append(x)
        x = conv_last(x, adj)
        x_all.append(x)
        # x_tensor: [batch_size x num_nodes x embedding]
        x_tensor = torch.cat(x_all, dim=2)
        if embedding_mask is not None:
            x_tensor = x_tensor * embedding_mask
        return x_tensor
|