RA-GCN: Graph Convolutional Network for Disease Prediction Problems with Imbalanced Data | |||||
==== | |||||
This repository contains PyTorch code for node classification in graphs with imbalanced classes, as described in:
Ghorbani et al., "RA-GCN: Graph Convolutional Network for Disease Prediction Problems with Imbalanced Data" [1].
Usage | |||||
------------ | |||||
The main file is "main_medical.py":
```PYTHONPATH=../ python main_medical.py```
Input Data | |||||
------------ | |||||
To run the code on your own data, modify the data loading function named "load_data_medical". This function should return the adjacency matrices, features, labels, and the training, validation, and test indices. Each variable is described below (a minimal sketch of the expected contract follows the list):
- adj: a dictionary with the keys 'D' and 'W'. adj['D'] contains the normalized adjacency matrix (with self-loops) between all nodes and is used by the discriminator. adj['W'] contains a list of normalized adjacency matrices (with self-loops); the k-th element is the adjacency matrix between training samples with label k.
- features: a tensor containing the features of all nodes (N by F).
- labels: a list of labels for all nodes (length N).
- idx_train, idx_val, idx_test: lists of indices of the training, validation, and test samples, respectively.
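A minimal sketch of this contract, with illustrative shapes and synthetic values (N=100 nodes, F=16 features, C=2 classes, and the per-class sizes are assumptions for the example, not values from the paper):
```
# Hypothetical skeleton of load_data_medical; only the return types/shapes matter here.
import torch

def load_data_medical(dataset_addr, train_ratio, test_ratio=0.2):
    N, F, C = 100, 16, 2                         # illustrative sizes
    adj = {
        'D': torch.eye(N),                       # normalized N x N adjacency (with self-loops)
        'W': [torch.eye(30) for _ in range(C)],  # one adjacency per class over its training nodes
    }
    features = torch.randn(N, F)                 # N x F feature tensor
    labels = torch.randint(0, C, (N,))           # length-N labels
    idx_train = torch.arange(0, 60)              # index tensors for the three splits
    idx_val = torch.arange(60, 80)
    idx_test = torch.arange(80, 100)
    return adj, features, labels, idx_train, idx_val, idx_test
```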
Parameters | |||||
------------ | |||||
Here is a list of parameters that should be passed to the main function or set in the code (an example invocation follows the list):
- epochs: number of epochs for training the whole network (default: 1000) | |||||
- epoch_D: number of epochs for training discriminator in each iteration (default: 1) | |||||
- epoch_W: number of epochs for training weighting networks in each iteration (default: 1) | |||||
- lr_D: learning rate for the discriminator (default: 0.01)
- lr_W: common learning rate for all weighting networks (default: 0.01) | |||||
- dropout_D: dropout for discriminator (default: 0.5) | |||||
- dropout_W: common dropout for all weighting networks (default: 0.5) | |||||
- gamma: coefficient of the entropy term in the loss function (float, default: 1)
- no-cuda: a boolean flag that disables GPU training when set to True
- structure_D: a list of the number of hidden neurons in each layer of the discriminator. This variable should be set in the code (default: [2], i.e., one hidden layer with two neurons)
- structure_W: a list of the number of hidden neurons in each layer of every weighting network. This variable should be set in the code (default: [4])
- drop_epochs: the best model is selected by the macro F1 score on the validation set. To avoid choosing the network before it has stabilized, the first drop_epochs epochs are ignored when picking the best model (default: 500)
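For example, an illustrative run that sets a few of these parameters explicitly (the flag names follow the argparse definitions in the code):
```
PYTHONPATH=../ python main_medical.py --epochs 1000 --epoch_D 1 --epoch_W 1 --lr_D 0.01 --lr_W 0.01 --gamma 1
```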
Metrics | |||||
------------ | |||||
Accuracy and macro F1 are calculated in the code. Binary F1 and ROC AUC are additionally calculated for binary classification tasks.
Note | |||||
------------ | |||||
Thanks to Thomas Kipf. The code is based on his "Graph Convolutional Networks in PyTorch" implementation [2].
Bug Report | |||||
------------ | |||||
If you find a bug, please send an email to [email protected], including, if necessary, the input file and the parameters that caused the bug.
Comments and suggestions about the program are also welcome.
References | |||||
------------ | |||||
[1] [Ghorbani, Mahsa, et al. "RA-GCN: Graph convolutional network for disease prediction problems with imbalanced data." Medical Image Analysis 75 (2022): 102272.](https://arxiv.org/pdf/2103.00221) | |||||
[2] [Kipf & Welling, Semi-Supervised Classification with Graph Convolutional Networks, 2016](https://arxiv.org/abs/1609.02907) | |||||
Cite | |||||
------------ | |||||
Please cite our paper if you use this code in your own work: | |||||
``` | |||||
@article{ghorbani2022ra, | |||||
  title={RA-GCN: Graph convolutional network for disease prediction problems with imbalanced data},
author={Ghorbani, Mahsa and Kazi, Anees and Baghshah, Mahdieh Soleymani and Rabiee, Hamid R and Navab, Nassir}, | |||||
journal={Medical Image Analysis}, | |||||
volume={75}, | |||||
pages={102272}, | |||||
year={2022}, | |||||
publisher={Elsevier} | |||||
} | |||||
``` |
import math | |||||
import torch | |||||
from torch.nn.parameter import Parameter | |||||
from torch.nn.modules.module import Module | |||||
class GraphConvolution(Module): | |||||
""" | |||||
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 | |||||
""" | |||||
def __init__(self, in_features, out_features, order, bias=True): | |||||
super(GraphConvolution, self).__init__() | |||||
self.in_features = in_features | |||||
self.out_features = out_features | |||||
self.order = order | |||||
self.weight = torch.nn.ParameterList([]) | |||||
for i in range(self.order): | |||||
self.weight.append(Parameter(torch.FloatTensor(in_features, out_features))) | |||||
if bias: | |||||
self.bias = Parameter(torch.FloatTensor(out_features)) | |||||
else: | |||||
self.register_parameter('bias', None) | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
for i in range(self.order): | |||||
stdv = 1. / math.sqrt(self.weight[i].size(1)) | |||||
self.weight[i].data.uniform_(-stdv, stdv) | |||||
if self.bias is not None: | |||||
self.bias.data.uniform_(-stdv, stdv) | |||||
    def forward(self, input, adj):
        output = []
        # Accept a single adjacency matrix as well as a list of supports
        if self.order == 1 and type(adj) != list:
            adj = [adj]
        for i in range(self.order):
            # Feature transform (X W_i), then neighborhood aggregation (A_i X W_i)
            support = torch.mm(input, self.weight[i])
            output.append(torch.mm(adj[i], support))
output = sum(output) | |||||
if self.bias is not None: | |||||
return output + self.bias | |||||
else: | |||||
return output | |||||
def __repr__(self): | |||||
return self.__class__.__name__ + ' (' \ | |||||
+ str(self.in_features) + ' -> ' \ | |||||
+ str(self.out_features) + ')' |
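# Usage sketch (illustrative addition, not part of the original file): a layer
# with order=2 combines two adjacency supports in a single forward pass.
def _graph_convolution_demo():
    layer = GraphConvolution(in_features=16, out_features=8, order=2)
    x = torch.randn(5, 16)                       # 5 nodes, 16 features each
    adjs = [torch.eye(5), torch.ones(5, 5) / 5]  # one support matrix per order
    return layer(x, adjs)                        # tensor of shape (5, 8)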
import time | |||||
import argparse | |||||
import torch | |||||
seed_num = 17 | |||||
torch.manual_seed(seed_num) | |||||
import torch.nn.functional as F | |||||
import itertools as it | |||||
from code.model import RAGCN | |||||
from code.utils import accuracy, load_data_medical, encode_onehot_torch, class_f1, auc_score | |||||
from code.model import my_sigmoid | |||||
def train(): | |||||
### create structure for discriminator and weighting networks | |||||
struc_D = {'dropout': args.dropout_D, 'wd': 5e-4, 'lr': args.lr_D, 'nhid': structure_D} | |||||
    ### For simplicity, all weighting networks share the same hyper-parameters. They can instead be defined as a
    ### list of dictionaries, where each dictionary holds the hyper-parameters of one weighting network.
struc_Ws = n_ws*[{'dropout': args.dropout_W, 'wd': 5e-4, 'lr': args.lr_W, 'nhid': structure_W}] | |||||
### stats variable keeps the statistics of the network on train, validation and test sets | |||||
stats = dict() | |||||
### act is the function for normalization of weights of samples in each class | |||||
act = my_sigmoid | |||||
### Definition of model | |||||
model = RAGCN(adj=adj, features=features, nclass=nclass, struc_D=struc_D, struc_Ws=struc_Ws,n_ws=n_ws, | |||||
weighing_output_dim=1, act=act, gamma=args.gamma) | |||||
if use_cuda: | |||||
model.cuda() | |||||
### Keeping the best stats based on the value of Macro F1 in validation set | |||||
max_val = dict() | |||||
max_val['f1Macro_val'] = 0 | |||||
for epoch in range(args.epochs): | |||||
model.train() | |||||
### Train discriminator and weighting networks | |||||
model.run_both(epoch_for_D=args.epoch_D, epoch_for_W=args.epoch_W, labels_one_hot=labels_one_hot[idx_train, :], | |||||
samples=idx_train, args_cuda=use_cuda, equal_weights=False) | |||||
model.eval() | |||||
### calculate stats for training set | |||||
class_prob, embed = model.run_D(samples=idx_train) | |||||
weights, _ = model.run_W(samples=idx_train, labels=labels[idx_train], args_cuda=use_cuda, equal_weights=False) | |||||
stats['loss_train'] = model.loss_function_D(class_prob, labels_one_hot[idx_train], weights).item() | |||||
stats['nll_train'] = F.nll_loss(class_prob, labels[idx_train]).item() | |||||
stats['acc_train'] = accuracy(class_prob, labels=labels[idx_train]).item() | |||||
stats['f1Macro_train'] = class_f1(class_prob, labels[idx_train], type='macro') | |||||
if nclass == 2: | |||||
stats['f1Binary_train'] = class_f1(class_prob, labels[idx_train], type='binary', pos_label=pos_label) | |||||
stats['AUC_train'] = auc_score(class_prob, labels[idx_train]) | |||||
### calculate stats for validation and test set | |||||
test(model, stats) | |||||
### Drop first epochs and keep the best based on the macro F1 on validation set just for reporting | |||||
if epoch > drop_epochs and max_val['f1Macro_val'] < stats['f1Macro_val']: | |||||
for key, val in stats.items(): | |||||
max_val[key] = val | |||||
### Print stats in each epoch | |||||
print('Epoch: {:04d}'.format(epoch + 1)) | |||||
print('acc_train: {:.4f}'.format(stats['acc_train'])) | |||||
print('f1_macro_train: {:.4f}'.format(stats['f1Macro_train'])) | |||||
print('loss_train: {:.4f}'.format(stats['loss_train'])) | |||||
### Reporting the best results on test set | |||||
print('========Results==========') | |||||
for key, val in max_val.items(): | |||||
if 'loss' in key or 'nll' in key or 'test' not in key: | |||||
continue | |||||
print(key.replace('_', ' ') + ' : ' + str(val)) | |||||
### Calculate metrics on validation and test sets | |||||
def test(model, stats): | |||||
model.eval() | |||||
class_prob, embed = model.run_D(samples=idx_val) | |||||
weights, _ = model.run_W(samples=idx_val, labels=labels[idx_val], args_cuda=use_cuda, equal_weights=True) | |||||
stats['loss_val'] = model.loss_function_D(class_prob, labels_one_hot[idx_val], weights).item() | |||||
stats['nll_val'] = F.nll_loss(class_prob, labels[idx_val]).item() | |||||
stats['acc_val'] = accuracy(class_prob, labels[idx_val]).item() | |||||
stats['f1Macro_val'] = class_f1(class_prob, labels[idx_val], type='macro') | |||||
if nclass == 2: | |||||
stats['f1Binary_val'] = class_f1(class_prob, labels[idx_val], type='binary', pos_label=pos_label) | |||||
stats['AUC_val'] = auc_score(class_prob, labels[idx_val]) | |||||
class_prob, embed = model.run_D(samples=idx_test) | |||||
weights, _ = model.run_W(samples=idx_test, labels=labels[idx_test], args_cuda=use_cuda, equal_weights=True) | |||||
stats['loss_test'] = model.loss_function_D(class_prob, labels_one_hot[idx_test], weights).item() | |||||
stats['nll_test'] = F.nll_loss(class_prob, labels[idx_test]).item() | |||||
stats['acc_test'] = accuracy(class_prob, labels[idx_test]).item() | |||||
stats['f1Macro_test'] = class_f1(class_prob, labels[idx_test], type='macro') | |||||
if nclass == 2: | |||||
stats['f1Binary_test'] = class_f1(class_prob, labels[idx_test], type='binary', pos_label=pos_label) | |||||
stats['AUC_test'] = auc_score(class_prob, labels[idx_test]) | |||||
if __name__ == '__main__': | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument('--epochs', type=int, default=1000, | |||||
help='Number of epochs to train.') | |||||
    parser.add_argument('--epoch_D', type=int, default=1,
                        help='Number of training loops for the discriminator in each epoch.')
    parser.add_argument('--epoch_W', type=int, default=1,
                        help='Number of training loops for the weighting networks in each epoch.')
parser.add_argument('--lr_D', type=float, default=0.01, | |||||
help='Learning rate for discriminator.') | |||||
parser.add_argument('--lr_W', type=float, default=0.01, | |||||
help='Equal learning rate for weighting networks.') | |||||
parser.add_argument('--dropout_D', type=float, default=0.5, | |||||
help='Dropout rate for discriminator.') | |||||
parser.add_argument('--dropout_W', type=float, default=0.5, | |||||
help='Dropout rate for weighting networks.') | |||||
parser.add_argument('--gamma', type=float, default=1, | |||||
help='Coefficient of entropy term in loss function.') | |||||
parser.add_argument('--no-cuda', action='store_true', default=False, | |||||
help='Disables CUDA training.') | |||||
### This list shows the number of hidden neurons in each hidden layer of discriminator | |||||
structure_D = [2] | |||||
### This list shows the number of hidden neurons in each hidden layer of weighting networks | |||||
structure_W = [4] | |||||
### The results of first drop_epochs will be dropped for choosing the best network based on the validation set | |||||
drop_epochs = 500 | |||||
args = parser.parse_args() | |||||
use_cuda = not args.no_cuda and torch.cuda.is_available() | |||||
    ### The loading function should return the following variables:
    ### adj is a dictionary with 'D' and 'W' as keys
    ### adj['D'] contains the main adjacency matrix between all samples, used by the discriminator
    ### adj['W'] contains a list of adjacency matrices; element i is the adjacency matrix between the
    ###   training samples of class i
    ### features is a tensor of size N by F
    ### labels is a list of node labels
    ### idx_train is a list containing the indices of training samples; idx_val and idx_test follow the same pattern
adj, features, labels, idx_train, idx_val, idx_test = load_data_medical(dataset_addr='../data/synthetic/per-90gt-0.5.pkl', | |||||
train_ratio=0.6, test_ratio=0.2) | |||||
### start of code | |||||
labels_one_hot = encode_onehot_torch(labels) | |||||
nclass = labels_one_hot.shape[1] | |||||
n_ws = nclass | |||||
pos_label = None | |||||
if nclass == 2: | |||||
pos_label = 1 | |||||
zero_class = (labels == 0).sum() | |||||
one_class = (labels == 1).sum() | |||||
if zero_class < one_class: | |||||
pos_label = 0 | |||||
if use_cuda: | |||||
for key, val in adj.items(): | |||||
if type(val) is list: | |||||
                for i in range(len(val)):
adj[key][i] = adj[key][i].cuda() | |||||
else: | |||||
adj[key] = adj[key].cuda() | |||||
features = features.cuda() | |||||
labels_one_hot = labels_one_hot.cuda() | |||||
labels = labels.cuda() | |||||
### Training the networks | |||||
train() | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import torch.optim as optim | |||||
from code.layer import GraphConvolution | |||||
def my_sigmoid(mx, dim): | |||||
mx = torch.sigmoid(mx) | |||||
return F.normalize(mx, p=1, dim=dim) | |||||
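# Example (illustrative, added): with dim=0 the sigmoid outputs are rescaled so the
# per-class weights sum to 1, e.g. my_sigmoid(torch.zeros(2, 1), dim=0) gives
# tensor([[0.5], [0.5]]).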
class Discriminator(nn.Module): | |||||
def __init__(self, nfeat, nhid, nclass, dropout, order): | |||||
super(Discriminator, self).__init__() | |||||
layers = [] | |||||
if len(nhid) == 0: | |||||
layers.append(GraphConvolution(nfeat, nclass, order=order)) | |||||
else: | |||||
layers.append(GraphConvolution(nfeat, nhid[0], order=order)) | |||||
for i in range(len(nhid) - 1): | |||||
layers.append(GraphConvolution(nhid[i], nhid[i + 1], order=order)) | |||||
if nclass > 1: | |||||
layers.append(GraphConvolution(nhid[-1], nclass, order=order)) | |||||
self.gc = nn.ModuleList(layers) | |||||
self.dropout = dropout | |||||
self.nclass = nclass | |||||
def forward(self, x, adj, samples=-1, func=F.relu): | |||||
end_layer = len(self.gc) - 1 if self.nclass > 1 else len(self.gc) | |||||
for i in range(end_layer): | |||||
x = F.dropout(x, self.dropout, training=self.training) | |||||
x = self.gc[i](x, adj) | |||||
x = func(x) | |||||
if self.nclass > 1: | |||||
classifier = self.gc[-1](x, adj) | |||||
classifier = F.log_softmax(classifier, dim=1) | |||||
return classifier[samples,:], x | |||||
else: | |||||
return None, x | |||||
class Weighing(nn.Module): | |||||
def __init__(self, nfeat, nhid, weighing_output_dim, dropout, order): | |||||
super(Weighing, self).__init__() | |||||
layers = [] | |||||
layers.append(GraphConvolution(nfeat, nhid[0], order=order)) | |||||
for i in range(len(nhid) - 1): | |||||
layers.append(GraphConvolution(nhid[i], nhid[i + 1], order=order)) | |||||
layers.append(GraphConvolution(nhid[-1], weighing_output_dim, order=order)) | |||||
self.gc = nn.ModuleList(layers) | |||||
self.dropout = dropout | |||||
def forward(self, x, adj, func, samples=-1): | |||||
end_layer = len(self.gc) - 1 | |||||
for i in range(end_layer): | |||||
x = self.gc[i](x, adj) | |||||
x = F.leaky_relu(x, inplace=False) | |||||
x = F.dropout(x, self.dropout, training=self.training) | |||||
if type(samples) is int: | |||||
weights = func(self.gc[-1](x, adj), dim=0) | |||||
else: | |||||
weights = func(self.gc[-1](x, adj)[samples, :], dim=0) | |||||
if len(weights.shape) == 1: | |||||
weights = weights[None, :] | |||||
return weights, x | |||||
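### Added commentary: each Weighing network is itself a small GCN that maps the
### features of one class's training nodes to a scalar per node (when
### weighing_output_dim=1); `func` (my_sigmoid by default) then normalizes those
### scalars so the weights within the class sum to 1.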
class RAGCN: | |||||
    def __init__(self, features, adj, nclass, struc_D, struc_Ws, n_ws, weighing_output_dim, act, gamma=1):
nfeat = features.shape[1] | |||||
if type(adj['D']) is list: | |||||
order_D = len(adj['D']) | |||||
else: | |||||
order_D = 1 | |||||
if type(adj['W'][0]) is list: | |||||
order_W = len(adj['W'][0]) | |||||
else: | |||||
order_W = 1 | |||||
self.n_ws = n_ws | |||||
self.wod = weighing_output_dim | |||||
self.net_D = Discriminator(nfeat, struc_D['nhid'], nclass, struc_D['dropout'], order=order_D) | |||||
self.gamma = gamma | |||||
self.opt_D = optim.Adam(self.net_D.parameters(), lr=struc_D['lr'], weight_decay=struc_D['wd']) | |||||
self.net_Ws = [] | |||||
self.opt_Ws = [] | |||||
for i in range(n_ws): | |||||
self.net_Ws.append(Weighing(nfeat=nfeat, nhid=struc_Ws[i]['nhid'], weighing_output_dim=weighing_output_dim | |||||
, dropout=struc_Ws[i]['dropout'], order=order_W)) | |||||
self.opt_Ws.append( | |||||
optim.Adam(self.net_Ws[-1].parameters(), lr=struc_Ws[i]['lr'], weight_decay=struc_Ws[i]['wd'])) | |||||
self.adj_D = adj['D'] | |||||
self.adj_W = adj['W'] | |||||
self.features = features | |||||
self.act = act | |||||
def run_D(self, samples): | |||||
class_prob, embed = self.net_D(self.features, self.adj_D, samples) | |||||
return class_prob, embed | |||||
def run_W(self, samples, labels, args_cuda=False, equal_weights=False): | |||||
batch_size = samples.shape[0] | |||||
embed = None | |||||
if equal_weights: | |||||
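            ### Added commentary: validation/test mode. Every sample in class i gets
            ### weight 1/|class i|, so all classes contribute equally to the loss.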
max_label = int(labels.max().item() + 1) | |||||
weight = torch.empty(batch_size) | |||||
for i in range(max_label): | |||||
labels_indices = (labels == i).nonzero().squeeze() | |||||
if len(labels_indices.shape) == 0: | |||||
batch_size = 1 | |||||
else: | |||||
batch_size = len(labels_indices) | |||||
if labels_indices is not None: | |||||
weight[labels_indices] = 1 / batch_size * torch.ones(batch_size) | |||||
weight = weight / max_label | |||||
else: | |||||
max_label = int(labels.max().item() + 1) | |||||
weight = torch.empty(batch_size) | |||||
if args_cuda: | |||||
weight = weight.cuda() | |||||
for i in range(max_label): | |||||
labels_indices = (labels == i).nonzero().squeeze() | |||||
if labels_indices is not None: | |||||
sub_samples = samples[labels_indices] | |||||
weight_, embed = self.net_Ws[i](x=self.features[sub_samples,:], adj=self.adj_W[i], samples=-1, func=self.act) | |||||
weight[labels_indices] = weight_.squeeze() if self.wod == 1 else weight_[:,i] | |||||
weight = weight / max_label | |||||
if args_cuda: | |||||
weight = weight.cuda() | |||||
return weight, embed | |||||
def loss_function_D(self, output, labels, weights): | |||||
return torch.sum(- weights * (labels.float() * output).sum(1), -1) | |||||
def loss_function_G(self, output, labels, weights): | |||||
return torch.sum(- weights * (labels.float() * output).sum(1), -1) - self.gamma*torch.sum(weights*torch.log(weights+1e-20)) | |||||
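    ### Added commentary: both losses are the weight-scaled cross-entropy of the
    ### discriminator's log-probabilities. loss_function_G additionally rewards the
    ### entropy of the weights (coefficient gamma) so they do not collapse onto a few
    ### samples. In run_both, the discriminator minimizes loss_function_D while the
    ### weighting networks minimize -loss_function_G, i.e., the two sides play an
    ### adversarial game.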
def zero_grad_both(self): | |||||
self.opt_D.zero_grad() | |||||
for opt in self.opt_Ws: | |||||
opt.zero_grad() | |||||
def run_both(self, epoch_for_D, epoch_for_W, labels_one_hot, samples=-1, | |||||
args_cuda=False, equal_weights=False): | |||||
labels_not_onehot = labels_one_hot.max(1)[1].type_as(labels_one_hot) | |||||
for e_D in range(epoch_for_D): | |||||
self.zero_grad_both() | |||||
class_prob_1, embed = self.run_D(samples) | |||||
weights_1, _ = self.run_W(samples=samples, labels=labels_not_onehot, args_cuda=args_cuda, equal_weights=equal_weights) | |||||
loss_D = self.loss_function_D(output=class_prob_1, labels=labels_one_hot, weights=weights_1) | |||||
loss_D.backward() | |||||
self.opt_D.step() | |||||
for e_W in range(epoch_for_W): | |||||
self.zero_grad_both() | |||||
class_prob_2, embed = self.run_D(samples) | |||||
weights_2, embeds = self.run_W(samples=samples, labels=labels_not_onehot, args_cuda=args_cuda, equal_weights=equal_weights) | |||||
loss_G = -self.loss_function_G(output=class_prob_2, labels=labels_one_hot, weights=weights_2) | |||||
loss_G.backward() | |||||
for opt in self.opt_Ws: | |||||
opt.step() | |||||
def train(self): | |||||
self.net_D.train() | |||||
for i in range(self.n_ws): | |||||
self.net_Ws[i].train() | |||||
def eval(self): | |||||
self.net_D.eval() | |||||
for i in range(self.n_ws): | |||||
self.net_Ws[i].eval() | |||||
    def cuda(self):
        self.features = self.features.to(device='cuda')
        # adj_D may be a single adjacency matrix or a list of supports
        if type(self.adj_D) is list:
            for i in range(len(self.adj_D)):
                self.adj_D[i] = self.adj_D[i].to(device='cuda')
        else:
            self.adj_D = self.adj_D.to(device='cuda')
        for i in range(len(self.adj_W)):
            self.adj_W[i] = self.adj_W[i].to(device='cuda')
self.net_D.cuda() | |||||
for i in range(self.n_ws): | |||||
self.net_Ws[i].cuda() |
import numpy as np | |||||
import scipy.sparse as sp | |||||
import torch | |||||
import networkx as nx | |||||
from sklearn.model_selection import train_test_split | |||||
from sklearn.metrics import f1_score, roc_auc_score | |||||
import pickle | |||||
def encode_onehot_torch(labels): | |||||
num_classes = int(labels.max() + 1) | |||||
y = torch.eye(num_classes) | |||||
return y[labels] | |||||
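# Example (illustrative, added): encode_onehot_torch(torch.LongTensor([0, 2, 1]))
# returns tensor([[1., 0., 0.],
#                 [0., 0., 1.],
#                 [0., 1., 0.]])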
def encode_onehot(labels): | |||||
classes = set(labels) | |||||
classes_dict = {c: np.identity(len(classes))[i, :] for i, c in | |||||
enumerate(classes)} | |||||
labels_onehot = np.array(list(map(classes_dict.get, labels)), | |||||
dtype=np.int32) | |||||
return labels_onehot | |||||
def fill_features(features): | |||||
empty_indices = np.argwhere(features != features) | |||||
for index in empty_indices: | |||||
features[index[0], index[1]] = np.nanmean(features[:,index[1]]) | |||||
return features | |||||
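# Example (illustrative, added): NaN entries are imputed column-wise with the
# column's nanmean, e.g. a feature column [1.0, nan, 3.0] becomes [1.0, 2.0, 3.0].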
def load_data_medical(dataset_addr, train_ratio, test_ratio=0.2): | |||||
with open(dataset_addr, 'rb') as f: # Python 3: open(..., 'rb') | |||||
adj, features, labels = pickle.load(f) | |||||
n_node = adj.shape[1] | |||||
adj = nx.adjacency_matrix(nx.from_numpy_array(adj)) | |||||
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) | |||||
adj_D = normalize(adj + sp.eye(adj.shape[0])) | |||||
adj_W = normalize(adj + sp.eye(adj.shape[0])) | |||||
adj= dict() | |||||
adj['D'] = adj_D | |||||
    features = features.astype(float)
features = fill_features(features) | |||||
idx_train, idx_test = train_test_split(range(n_node), test_size=test_ratio, random_state=42, stratify=labels) | |||||
idx_train, idx_val = train_test_split(idx_train, train_size=train_ratio/(1-test_ratio), random_state=42, stratify=labels[idx_train]) | |||||
adj['W'] = [] | |||||
for nc in range(labels.max() + 1): | |||||
nc_idx = np.where(labels[idx_train] == nc)[0] | |||||
nc_idx = np.array(idx_train)[nc_idx] | |||||
adj['W'].append(adj_W[np.ix_(nc_idx,nc_idx)]) | |||||
features = torch.FloatTensor(features) | |||||
for key,val in adj.items(): | |||||
if key == 'D': | |||||
            adj[key] = torch.FloatTensor(np.array(adj[key].todense()))
else: | |||||
for i in range(len(val)): | |||||
adj[key][i] = torch.FloatTensor(np.array(adj[key][i].todense())) | |||||
labels = torch.LongTensor(labels) | |||||
idx_train = torch.LongTensor(idx_train) | |||||
idx_val = torch.LongTensor(idx_val) | |||||
idx_test = torch.LongTensor(idx_test) | |||||
return adj, features, labels, idx_train, idx_val, idx_test | |||||
def normalize(mx): | |||||
"""Row-normalize sparse matrix""" | |||||
rowsum = np.array(mx.sum(1)) | |||||
r_inv = np.power(rowsum, -1).flatten() | |||||
r_inv[np.isinf(r_inv)] = 0. | |||||
r_mat_inv = sp.diags(r_inv) | |||||
mx = r_mat_inv.dot(mx) | |||||
return mx | |||||
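# Added commentary: this computes D^-1 A (each row divided by its row sum);
# all-zero rows are left as zeros instead of producing infinities.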
def accuracy(output, labels): | |||||
preds = output.max(1)[1].type_as(labels) | |||||
correct = preds.eq(labels).double() | |||||
correct = correct.sum() | |||||
return correct / len(labels) | |||||
def class_f1(output, labels, type='micro', pos_label=None): | |||||
preds = output.max(1)[1].type_as(labels) | |||||
    if pos_label is None:
        return f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average=type)
    else:
        return f1_score(labels.cpu().numpy(), preds.cpu().numpy(), average=type, pos_label=pos_label)
def auc_score(output, labels): | |||||
return roc_auc_score(labels.cpu().numpy(), output[:,1].detach().cpu().numpy()) | |||||
def sparse_mx_to_torch_sparse_tensor(sparse_mx): | |||||
"""Convert a scipy sparse matrix to a torch sparse tensor.""" | |||||
sparse_mx = sparse_mx.tocoo().astype(np.float32) | |||||
indices = torch.from_numpy( | |||||
np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) | |||||
values = torch.from_numpy(sparse_mx.data) | |||||
shape = torch.Size(sparse_mx.shape) | |||||
    return torch.sparse_coo_tensor(indices, values, shape)