diffpoolClassesAndFunctions.py

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import numpy as np


# GCN basic operation
class GraphConv(nn.Module):
    def __init__(self, input_dim, output_dim, add_self=False, normalize_embedding=False,
                 dropout=0.0, bias=True):
        super(GraphConv, self).__init__()
        self.add_self = add_self
        self.dropout = dropout
        if dropout > 0.001:
            self.dropout_layer = nn.Dropout(p=dropout)
        self.normalize_embedding = normalize_embedding
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim).cuda())
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(output_dim).cuda())
        else:
            self.bias = None

    def forward(self, x, adj):
        if self.dropout > 0.001:
            x = self.dropout_layer(x)
        # aggregate neighbor features: [batch x nodes x nodes] @ [batch x nodes x input_dim]
        y = torch.matmul(adj, x)
        if self.add_self:
            y += x
        y = torch.matmul(y, self.weight)
        if self.bias is not None:
            y = y + self.bias
        if self.normalize_embedding:
            # L2-normalize each node embedding
            y = F.normalize(y, p=2, dim=2)
            # print(y[0][0])
        return y
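

def _graphconv_shape_demo():
    # Illustrative usage sketch (hypothetical sizes and names; assumes a CUDA
    # device, since GraphConv allocates its parameters with .cuda()). Shows the
    # tensor shapes GraphConv expects: batched node features and a dense
    # per-graph adjacency matrix.
    conv = GraphConv(input_dim=16, output_dim=32, normalize_embedding=True)
    init.xavier_uniform_(conv.weight)     # GraphConv leaves its parameters uninitialized
    init.constant_(conv.bias, 0.0)
    x = torch.randn(8, 50, 16).cuda()     # node features   [batch_size x max_nodes x input_dim]
    adj = torch.rand(8, 50, 50).cuda()    # dense adjacency [batch_size x max_nodes x max_nodes]
    return conv(x, adj)                   # -> [8, 50, 32]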


class GcnEncoderGraph(nn.Module):
    def __init__(self, input_dim, hidden_dim, embedding_dim, label_dim, num_layers,
                 pred_hidden_dims=[], concat=True, bn=True, dropout=0.0, args=None):
        super(GcnEncoderGraph, self).__init__()
        self.concat = concat
        add_self = not concat
        self.bn = bn
        self.num_layers = num_layers
        self.num_aggs = 1

        self.bias = True
        if args is not None:
            self.bias = args.bias

        self.conv_first, self.conv_block, self.conv_last = self.build_conv_layers(
            input_dim, hidden_dim, embedding_dim, num_layers,
            add_self, normalize=True, dropout=dropout)
        self.act = nn.ReLU()
        self.label_dim = label_dim

        if concat:
            self.pred_input_dim = hidden_dim * (num_layers - 1) + embedding_dim
        else:
            self.pred_input_dim = embedding_dim
        self.pred_model = self.build_pred_layers(self.pred_input_dim, pred_hidden_dims,
                                                 label_dim, num_aggs=self.num_aggs)

        # initialize GraphConv weights with Xavier/Glorot and zero the biases
        for m in self.modules():
            if isinstance(m, GraphConv):
                init.xavier_uniform_(m.weight.data, gain=nn.init.calculate_gain('relu'))
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

    def build_conv_layers(self, input_dim, hidden_dim, embedding_dim, num_layers, add_self,
                          normalize=False, dropout=0.0):
        conv_first = GraphConv(input_dim=input_dim, output_dim=hidden_dim, add_self=add_self,
                               normalize_embedding=normalize, bias=self.bias)
        conv_block = nn.ModuleList(
            [GraphConv(input_dim=hidden_dim, output_dim=hidden_dim, add_self=add_self,
                       normalize_embedding=normalize, dropout=dropout, bias=self.bias)
             for i in range(num_layers - 2)])
        conv_last = GraphConv(input_dim=hidden_dim, output_dim=embedding_dim, add_self=add_self,
                              normalize_embedding=normalize, bias=self.bias)
        return conv_first, conv_block, conv_last
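
    # Illustrative note: with num_layers=3 this builds conv_first (input_dim -> hidden_dim),
    # one block layer (hidden_dim -> hidden_dim) and conv_last (hidden_dim -> embedding_dim),
    # which is why the concatenated embedding in __init__ has size
    # hidden_dim * (num_layers - 1) + embedding_dim.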

    def build_pred_layers(self, pred_input_dim, pred_hidden_dims, label_dim, num_aggs=1):
        pred_input_dim = pred_input_dim * num_aggs
        if len(pred_hidden_dims) == 0:
            pred_model = nn.Linear(pred_input_dim, label_dim)
        else:
            pred_layers = []
            for pred_dim in pred_hidden_dims:
                pred_layers.append(nn.Linear(pred_input_dim, pred_dim))
                pred_layers.append(self.act)
                pred_input_dim = pred_dim
            pred_layers.append(nn.Linear(pred_dim, label_dim))
            pred_model = nn.Sequential(*pred_layers)
        return pred_model
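
    # Illustrative example: pred_hidden_dims=[50] yields
    #   nn.Sequential(nn.Linear(pred_input_dim, 50), nn.ReLU(), nn.Linear(50, label_dim))
    # while the default pred_hidden_dims=[] reduces to a single nn.Linear layer.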

    def construct_mask(self, max_nodes, batch_num_nodes):
        ''' For each num_nodes in batch_num_nodes, the first num_nodes entries of the
        corresponding column are 1's, and the rest are 0's (to be masked out).
        Dimension of mask: [batch_size x max_nodes x 1]
        '''
        # masks
        packed_masks = [torch.ones(int(num)) for num in batch_num_nodes]
        batch_size = len(batch_num_nodes)
        out_tensor = torch.zeros(batch_size, max_nodes)
        for i, mask in enumerate(packed_masks):
            out_tensor[i, :batch_num_nodes[i]] = mask
        return out_tensor.unsqueeze(2).cuda()
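
    # Illustrative worked example: with max_nodes=4 and batch_num_nodes=[2, 3],
    # the mask before unsqueeze(2) is
    #   [[1., 1., 0., 0.],
    #    [1., 1., 1., 0.]]
    # so the returned tensor has shape [2 x 4 x 1] and broadcasts against node
    # embeddings of shape [batch_size x max_nodes x embedding_dim].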

    def apply_bn(self, x):
        ''' Batch normalization of 3D tensor x
        '''
        # note: a fresh BatchNorm1d (sized to the node dimension) is created on every
        # call, so its affine parameters and running statistics are never learned
        bn_module = nn.BatchNorm1d(x.size()[1]).cuda()
        return bn_module(x)

    def gcn_forward(self, x, adj, conv_first, conv_block, conv_last, embedding_mask=None):
        ''' Perform forward prop with graph convolution.
        Returns:
            Embedding matrix with dimension [batch_size x num_nodes x embedding]
        '''
        x = conv_first(x, adj)
        x = self.act(x)
        if self.bn:
            x = self.apply_bn(x)
        x_all = [x]
        # out_all = []
        # out, _ = torch.max(x, dim=1)
        # out_all.append(out)
        for i in range(len(conv_block)):
            x = conv_block[i](x, adj)
            x = self.act(x)
            if self.bn:
                x = self.apply_bn(x)
            x_all.append(x)
        x = conv_last(x, adj)
        x_all.append(x)
        # x_tensor: [batch_size x num_nodes x embedding]
        x_tensor = torch.cat(x_all, dim=2)
        if embedding_mask is not None:
            x_tensor = x_tensor * embedding_mask
        return x_tensor
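
    # Illustrative usage sketch (hypothetical sizes; assumes a CUDA device, since
    # GraphConv, construct_mask and apply_bn all allocate on the GPU):
    #
    #   model = GcnEncoderGraph(input_dim=16, hidden_dim=32, embedding_dim=64,
    #                           label_dim=2, num_layers=3)
    #   x = torch.randn(8, 50, 16).cuda()    # [batch_size x max_nodes x input_dim]
    #   adj = torch.rand(8, 50, 50).cuda()   # [batch_size x max_nodes x max_nodes]
    #   mask = model.construct_mask(50, [50] * 8)
    #   emb = model.gcn_forward(x, adj, model.conv_first, model.conv_block,
    #                           model.conv_last, embedding_mask=mask)
    #   # emb: [8, 50, hidden_dim * (num_layers - 1) + embedding_dim] = [8, 50, 128]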