# main.py
import torch
import numpy as np
import os
import random
import argparse
import pickle

from train_gkd import run_gkd
from utils import calculate_imbalance_weight

def str2bool(v):
    """Parse a command-line string into a boolean."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def last_argmax(l):
    """Return the index of the last occurrence of the maximum value in l."""
    reverse = l[::-1]
    return len(reverse) - np.argmax(reverse) - 1


def find_isolated(adj, idx_train):
    """Return the indices of connected (non-isolated) nodes outside the training set."""
    adj = adj.to_dense()
    deg = adj.sum(1)
    idx_connected = torch.where(deg > 0)[0]
    idx_connected = [x.item() for x in idx_connected if x not in idx_train]
    return idx_connected

if __name__ == "__main__": | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--seed', default=42, type=int, help='Seed') | |||
parser.add_argument('--use_cuda', action='store_true', default=True, | |||
help='Use CUDA if available.') | |||
    parser.add_argument('--epochs_teacher', default=300, type=int, help='Number of epochs to train the teacher network')
    parser.add_argument('--epochs_student', default=200, type=int, help='Number of epochs to train the student network')
    parser.add_argument('--epochs_lpa', default=10, type=int, help='Number of epochs to run label propagation')
    parser.add_argument('--lr_teacher', type=float, default=0.005, help='Learning rate for the teacher network.')
    parser.add_argument('--lr_student', type=float, default=0.005, help='Learning rate for the student network.')
    parser.add_argument('--wd_teacher', type=float, default=5e-4, help='Weight decay for the teacher network.')
    parser.add_argument('--wd_student', type=float, default=5e-4, help='Weight decay for the student network.')
    parser.add_argument('--dropout_teacher', type=float, default=0.3, help='Dropout for the teacher network.')
    parser.add_argument('--dropout_student', type=float, default=0.3, help='Dropout for the student network.')
    parser.add_argument('--burn_out_teacher', default=100, type=int,
                        help='Number of initial epochs to drop when selecting the best parameters '
                             'based on the validation set for the teacher network')
    parser.add_argument('--burn_out_student', default=100, type=int,
                        help='Number of initial epochs to drop when selecting the best parameters '
                             'based on the validation set for the student network')
    parser.add_argument('--alpha', default=0.1, type=float,
                        help='Coefficient of the remembrance term in label propagation')
    args = parser.parse_args()
    seed = args.seed
    use_cuda = args.use_cuda and torch.cuda.is_available()
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Torch RNG
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Python / NumPy RNG
    np.random.seed(seed)
    random.seed(seed)
    ### The following lists give the number of hidden neurons in each hidden layer
    ### of the teacher network and the student network, respectively.
    hidden_teacher = [8]
    hidden_student = [4]
    ### Select the best output of the teacher (for LPA) and the best output of the student (for reporting)
    ### based on the following metrics: a combination of [loss, acc, f1macro] and [train, val, test],
    ### e.g. 'f1macro_val'.
    best_metric_teacher = 'f1macro_val'
    best_metric_student = 'f1macro_val'
    ### Plot the training, validation, and test statistics over epochs.
    show_stats = False
    ### Set this if you want to make sure the test set is not accessed during training.
    isolated_test = True
    ### Data should be loaded here as follows:
    ### adj is a sparse tensor holding the N-by-N adjacency matrix between nodes.
    ### features is a tensor of size N by F.
    ### labels is a list of node labels.
    ### idx_train is a list containing the indices of training samples; idx_val and idx_test follow the same pattern.
    with open('./data/synthetic/sample-2000-f_128-g_4-gt_0.5-sp_0.5-100-200-500.pkl', 'rb') as f:
        adj, features, labels, idx_train, idx_val, idx_test = pickle.load(f)
    ### You can load weights for the training samples or calculate them here.
    ### Node weights are recalculated in train_gkd after the LPA step.
    sample_weight = calculate_imbalance_weight(idx_train, labels)
    idx_connected = find_isolated(adj, idx_train)
    params_teacher = {
        'lr': args.lr_teacher,
        'hidden': hidden_teacher,
        'weight_decay': args.wd_teacher,
        'dropout': args.dropout_teacher,
        'epochs': args.epochs_teacher,
        'best_metric': best_metric_teacher,
        'burn_out': args.burn_out_teacher
    }
    params_student = {
        'lr': args.lr_student,
        'hidden': hidden_student,
        'weight_decay': args.wd_student,
        'dropout': args.dropout_student,
        'epochs': args.epochs_student,
        'best_metric': best_metric_student,
        'burn_out': args.burn_out_student
    }
    params_lpa = {
        'epochs': args.epochs_lpa,
        'alpha': args.alpha
    }
    stats = run_gkd(adj, features, labels, idx_train,
                    idx_val, idx_test, idx_connected, params_teacher, params_student, params_lpa,
                    sample_weight=sample_weight, isolated_test=isolated_test,
                    use_cuda=use_cuda, show_stats=show_stats)
    # Pick the last epoch (after burn-out) with the best validation metric,
    # then report the test-set performance at that epoch.
    ind_val_max = params_student['burn_out'] + last_argmax(stats[params_student['best_metric']][params_student['burn_out']:])
    print('Last index with maximum metric on validation set: ' + str(ind_val_max))
    print('Final accuracy on test: ' + str(stats['acc_test'][ind_val_max]))
    print('Final macro F1 on test: ' + str(stats['f1macro_test'][ind_val_max]))
    print('Final AUC on test: ' + str(stats['auc_test'][ind_val_max]))
# models.py
import torch.nn as nn
import torch.nn.functional as F


class Fully(nn.Module):
    """A fully connected network; returns log-probabilities unless multi_label is set."""

    def __init__(self, nfeat, nhid, nclass, dropout, multi_label=False):
        super(Fully, self).__init__()
        self.multi_label = multi_label
        layers = []
        if len(nhid) == 0:
            layers.append(nn.Linear(nfeat, nclass))
        else:
            layers.append(nn.Linear(nfeat, nhid[0]))
            for i in range(len(nhid) - 1):
                layers.append(nn.Linear(nhid[i], nhid[i + 1]))
            if nclass > 1:
                layers.append(nn.Linear(nhid[-1], nclass))
        self.fc = nn.ModuleList(layers)
        self.dropout = dropout
        self.nclass = nclass

    def forward(self, x, adj=None):
        # Every layer except the final classifier gets dropout followed by ReLU.
        end_layer = len(self.fc) - 1 if self.nclass > 1 else len(self.fc)
        for i in range(end_layer):
            x = F.dropout(x, self.dropout, training=self.training)
            x = self.fc[i](x)
            x = F.relu(x)
        classifier = self.fc[-1](x)
        if not self.multi_label:
            classifier = F.log_softmax(classifier, dim=1)
        return classifier
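

# Usage sketch: a minimal, hedged example of instantiating Fully and running a
# forward pass; all shapes and values below are hypothetical.
if __name__ == "__main__":
    import torch
    x = torch.randn(2000, 128)                     # 2000 nodes, 128 features each
    model = Fully(nfeat=128, nhid=[8], nclass=4, dropout=0.3)
    model.eval()                                   # disable dropout for inference
    log_probs = model(x)                           # (2000, 4) log-probabilities
    preds = log_probs.argmax(dim=1)                # hard class prediction per node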
<!-- README.md -->
GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference
====
This repository contains the code for node classification in graphs when the graph is not available at test time.

Ghorbani et al., "GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference" [1]
Usage
------------
The main file is "main.py". Run it with ```python main.py```.
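All hyperparameters have command-line defaults (listed under Parameters below), so a typical invocation only overrides the ones of interest, for example:
```
python main.py --seed 42 --epochs_teacher 300 --epochs_student 200 --alpha 0.1
```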
Input Data
------------
To run the code, you need to load the data in main.py: the adjacency matrix, features, labels, and the training, validation, and test indices. More details on each variable are given below, followed by a minimal preparation sketch:
- adj: a sparse tensor holding the **normalized** adjacency matrix between all nodes (train, validation, and test). It should be noted that validation and test nodes only have self-loops, without any edges to other nodes.
- features: a tensor with the features of all nodes (N by F).
- labels: a list of labels for all nodes (with length N).
- idx_train, idx_val, idx_test: lists of indices for the training, validation, and test samples, respectively.
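
This sketch is illustrative only: the toy graph, sizes, and file path are hypothetical, and it assumes a row-normalized adjacency, matching how `adj` is multiplied into the label matrix during label propagation in `train_gkd.py`:
```
import pickle
import torch

N, F, C = 12, 5, 2                          # toy sizes: nodes, features, classes
features = torch.randn(N, F)
labels = torch.randint(0, C, (N,))
idx_train = torch.arange(0, 6)
idx_val = torch.arange(6, 9)
idx_test = torch.arange(9, 12)

# Edges only among training nodes; validation/test nodes keep just a
# self-loop so no neighborhood information about them is available.
A = torch.zeros(N, N)
A[0, 1] = A[1, 0] = 1.0
A[2, 3] = A[3, 2] = 1.0
A = A + torch.eye(N)                        # self-loops for every node
A = A / A.sum(1, keepdim=True)              # row-normalize
adj = A.to_sparse()                         # a sparse tensor, as main.py expects

with open('./data/toy.pkl', 'wb') as f:     # hypothetical path
    pickle.dump((adj, features, labels, idx_train, idx_val, idx_test), f)
```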
Parameters
------------
Here is a list of parameters that should be passed to the main function or set in the code:
- seed: seed number
- use_cuda: use CUDA for training if it is available
- epochs_teacher: number of epochs for training the teacher network (default: 300)
- epochs_student: number of epochs for training the student network (default: 200)
- epochs_lpa: number of epochs for running the label-propagation algorithm (default: 10)
- lr_teacher: learning rate for the teacher network (default: 0.005)
- lr_student: learning rate for the student network (default: 0.005)
- wd_teacher: weight decay for the teacher network (default: 5e-4)
- wd_student: weight decay for the student network (default: 5e-4)
- dropout_teacher: dropout for the teacher network (default: 0.3)
- dropout_student: dropout for the student network (default: 0.3)
- burn_out_teacher: number of initial epochs to drop when selecting the best parameters based on the validation set for the teacher network (default: 100)
- burn_out_student: number of initial epochs to drop when selecting the best parameters based on the validation set for the student network (default: 100)
- alpha: a float between 0 and 1 giving the coefficient of the remembrance term in label propagation (default: 0.1); see the update rule after this list
- hidden_teacher: a list with the number of hidden neurons in each layer of the teacher network. This variable is set in the code (default: [8], i.e. a network with one hidden layer of eight neurons).
- hidden_student: a list with the number of hidden neurons in each layer of the student network. This variable is set in the code (default: [4]).
- best_metric_teacher: the score used to select the best output of the teacher network on the validation set (a combination of [loss, acc, f1macro] and [train, val, test], e.g. 'f1macro_val').
- best_metric_student: the score used to select the best output of the student network, in the same format.
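
For reference, `alpha` enters the label-propagation update implemented in `train_gkd.py`. Writing $\hat{A}$ for the normalized adjacency matrix and $Y^{(0)}$ for the teacher's predicted class probabilities (with the rows of training nodes clamped to their one-hot labels), each LPA epoch computes

$$Y^{(t+1)} = (1 - \alpha)\,\hat{A}\,Y^{(t)} + \alpha\,Y^{(0)}$$

so larger values of `alpha` make the propagated labels remember the teacher's predictions more strongly.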
Metrics
------------
Accuracy and macro F1 are calculated in the code. ROC AUC can be calculated for binary classification tasks.
Note
------------
Thanks to Thomas Kipf. The code is based on "Graph Convolutional Networks in PyTorch" [2].
Bug Report
------------
If you find a bug, please send an email to [email protected], including the input file and the parameters that caused the bug if necessary.
You can also send me any comments or suggestions about the program.
References
------------
[1] [Ghorbani, Mahsa, et al. "GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference." International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2021.](https://arxiv.org/pdf/2104.03597)

[2] [Kipf & Welling, Semi-Supervised Classification with Graph Convolutional Networks, 2016](https://arxiv.org/abs/1609.02907)
Cite
------------
Please cite our paper if you use this code in your own work:
```
@inproceedings{ghorbani2021gkd,
  title={GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference},
  author={Ghorbani, Mahsa and Bahrami, Mojtaba and Kazi, Anees and Soleymani Baghshah, Mahdieh and Rabiee, Hamid R and Navab, Nassir},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={709--718},
  year={2021},
  organization={Springer}
}
```
# train_gkd.py
import matplotlib.pyplot as plt
import torch
import torch.optim as optim

from utils import accuracy, encode_onehot_torch, class_f1, roc_auc, half_normalize, loss, calculate_imbalance_weight
from models import Fully


def run_gkd(adj, features, labels, idx_train, idx_val, idx_test, idx_connected,
            params_teacher, params_student, params_lpa, isolated_test, sample_weight, use_cuda=True, show_stats=False):
    """Train the teacher, propagate its predictions over the graph (LPA), then train the student on the propagated labels."""
    def train_teacher(epoch):
        # One optimization step on the training nodes.
        model_teacher.train()
        optimizer_teacher.zero_grad()
        output = model_teacher(features)
        loss_train = loss(output[idx_train], labels_one_hot[idx_train], sample_weight[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])
        stats['f1macro_train'].append(class_f1(output[idx_train], labels[idx_train], type='macro'))
        if 'auc_train' in stats:
            stats['auc_train'].append(roc_auc(output[idx_train, 1], labels[idx_train]))
        loss_train.backward()
        optimizer_teacher.step()

        # Evaluate on the validation set.
        model_teacher.eval()
        output = model_teacher(features)
        loss_val = loss(output[idx_val], labels_one_hot[idx_val], sample_weight[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])
        if 'auc_val' in stats:
            stats['auc_val'].append(roc_auc(output[idx_val, 1], labels[idx_val]))
        stats['f1macro_val'].append(class_f1(output[idx_val], labels[idx_val], type='macro'))
        stats['loss_train'].append(loss_train.item())
        stats['acc_train'].append(acc_train.item())
        stats['loss_val'].append(loss_val.item())
        stats['acc_val'].append(acc_val.item())
        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()))
        return output.detach()
    def test_teacher():
        # Evaluate the teacher on the test set (used only for reporting).
        model_teacher.eval()
        output = model_teacher(features)
        loss_test = loss(output[idx_test], labels_one_hot[idx_test], sample_weight[idx_test])
        acc_test = accuracy(output[idx_test], labels[idx_test])
        stats['f1macro_test'].append(class_f1(output[idx_test], labels[idx_test], type='macro'))
        stats['loss_test'].append(loss_test.item())
        stats['acc_test'].append(acc_test.item())
        if 'auc_test' in stats:
            stats['auc_test'].append(roc_auc(output[idx_test, 1], labels[idx_test]))
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
    def train_student(epoch):
        # One optimization step on the LPA nodes, using the propagated soft labels.
        model_student.train()
        optimizer_student.zero_grad()
        output = model_student(features)
        loss_train = loss(output[idx_lpa], labels_lpa[idx_lpa], sample_weight[idx_lpa])
        acc_train = accuracy(output[idx_lpa], labels[idx_lpa])
        stats['f1macro_train'].append(class_f1(output[idx_train], labels[idx_train], type='macro'))
        if 'auc_train' in stats:
            stats['auc_train'].append(roc_auc(output[idx_train, 1], labels[idx_train]))
        if 'f1binary_train' in stats:
            stats['f1binary_train'].append(class_f1(output[idx_train], labels[idx_train], type='binary'))
        loss_train.backward()
        optimizer_student.step()

        # Evaluate on the validation set against the true labels.
        model_student.eval()
        output = model_student(features)
        loss_val = loss(output[idx_val], labels_one_hot[idx_val], sample_weight[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])
        if 'auc_val' in stats:
            stats['auc_val'].append(roc_auc(output[idx_val, 1], labels[idx_val]))
        stats['f1macro_val'].append(class_f1(output[idx_val], labels[idx_val], type='macro'))
        if 'f1binary_val' in stats:
            stats['f1binary_val'].append(class_f1(output[idx_val], labels[idx_val], type='binary'))
        stats['loss_train'].append(loss_train.item())
        stats['acc_train'].append(acc_train.item())
        stats['loss_val'].append(loss_val.item())
        stats['acc_val'].append(acc_val.item())
        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()))

    def test_student():
        # Evaluate the student on the test set; only features are used, no graph.
        model_student.eval()
        output = model_student(features)
        loss_test = loss(output[idx_test], labels_one_hot[idx_test], sample_weight[idx_test])
        acc_test = accuracy(output[idx_test], labels[idx_test])
        stats['f1macro_test'].append(class_f1(output[idx_test], labels[idx_test], type='macro'))
        if 'f1binary_test' in stats:
            stats['f1binary_test'].append(class_f1(output[idx_test], labels[idx_test], type='binary'))
        stats['loss_test'].append(loss_test.item())
        stats['acc_test'].append(acc_test.item())
        if 'auc_test' in stats:
            stats['auc_test'].append(roc_auc(output[idx_test, 1], labels[idx_test]))
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
    # Initialize the per-metric, per-split history lists.
    stats = dict()
    states = ['train', 'val', 'test']
    metrics = ['loss', 'acc', 'f1macro', 'auc']
    for s in states:
        for m in metrics:
            stats[m + '_' + s] = []
    labels_one_hot = encode_onehot_torch(labels)

    # Nodes participating in label propagation: connected nodes plus the training set,
    # optionally excluding validation/test nodes so they are never touched during training.
    idx_lpa = list(idx_connected) + list(idx_train)
    if isolated_test:
        idx_lpa = [x for x in idx_lpa if (x not in idx_test) and (x not in idx_val)]
    idx_lpa = torch.LongTensor(idx_lpa)

    model_teacher = Fully(nfeat=features.shape[1],
                          nhid=params_teacher['hidden'],
                          nclass=int(labels_one_hot.shape[1]),
                          dropout=params_teacher['dropout'])
    model_student = Fully(nfeat=features.shape[1],
                          nhid=params_student['hidden'],
                          nclass=int(labels_one_hot.shape[1]),
                          dropout=params_student['dropout'])
    optimizer_teacher = optim.Adam(model_teacher.parameters(),
                                   lr=params_teacher['lr'],
                                   weight_decay=params_teacher['weight_decay'])
    optimizer_student = optim.Adam(model_student.parameters(),
                                   lr=params_student['lr'],
                                   weight_decay=params_student['weight_decay'])
    if use_cuda:
        model_student.cuda()
        model_teacher.cuda()
        features = features.cuda()
        sample_weight = sample_weight.cuda()
        adj = adj.cuda()
        labels = labels.cuda()
        labels_one_hot = labels_one_hot.cuda()
        idx_train = idx_train.cuda()
        idx_val = idx_val.cuda()
        idx_test = idx_test.cuda()
        idx_lpa = idx_lpa.cuda()
    labels_lpa = labels_one_hot
    # Train the teacher and keep its best output (after burn-out) on the chosen validation metric.
    best_metric_val = 0
    for epoch in range(params_teacher['epochs']):
        output_fc = train_teacher(epoch)
        test_teacher()
        if stats[params_teacher['best_metric']][-1] >= best_metric_val and epoch > params_teacher['burn_out']:
            best_metric_val = stats[params_teacher['best_metric']][-1]
            best_output = output_fc

    # Label propagation: start from the teacher's class probabilities (exp of log-softmax),
    # clamp the training rows to the ground truth, and iterate
    #   Y <- (1 - alpha) * A_hat @ Y + alpha * Y0
    alpha = params_lpa['alpha']
    labels_start = torch.exp(best_output)
    labels_lpa = labels_start
    labels_lpa[idx_train, :] = labels_one_hot[idx_train, :].float()
    for i in range(params_lpa['epochs']):
        labels_lpa = (1 - alpha) * adj.mm(labels_lpa) + alpha * labels_start
        labels_lpa[idx_train, :] = labels_one_hot[idx_train, :].float()
        labels_lpa = half_normalize(labels_lpa)

    # Recompute class-imbalance weights from the propagated (hard) labels.
    sample_weight = calculate_imbalance_weight(idx_lpa, labels_lpa.argmax(1))
    if use_cuda:
        sample_weight = sample_weight.cuda()

    ### Reset the stats before training the student.
    for k, v in stats.items():
        stats[k] = []
    for epoch in range(params_student['epochs']):
        train_student(epoch)
        test_student()

    if show_stats:
        # One subplot per (split, metric) pair showing its history over epochs.
        fig, axs = plt.subplots(len(states), len(metrics), figsize=(5 * len(metrics), 3 * len(states)))
        for i, s in enumerate(states):
            for j, m in enumerate(metrics):
                axs[i, j].plot(stats[m + '_' + s])
                axs[i, j].set_title(m + ' ' + s)
        plt.show()
    return stats
# utils.py
import torch
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score


def accuracy(output, labels):
    """Fraction of correct hard predictions, given log-probability outputs."""
    preds = output.max(1)[1].type_as(labels)
    correct = preds.eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)


def class_f1(output, labels, type='micro', pos_label=1):
    preds = output.max(1)[1].type_as(labels)
    return f1_score(labels.detach().cpu().numpy(), preds.detach().cpu().numpy(), average=type, pos_label=pos_label)


def roc_auc(output, labels):
    return roc_auc_score(labels.cpu().numpy(), output.detach().cpu().numpy())


def loss(output, labels, weights=None):
    """Weighted negative log-likelihood: output holds log-probabilities, labels are one-hot."""
    if weights is None:
        weights = torch.ones(labels.shape[0], device=output.device)
    return torch.sum(-weights * (labels.float() * output).sum(1), -1)


def half_normalize(mx):
    """Row-normalize a dense matrix so each row sums to one (zero rows stay zero)."""
    rowsum = mx.sum(1).float()
    r_inv = rowsum.pow(-1).flatten()
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diag(r_inv)
    mx = r_mat_inv.mm(mx)
    return mx


def encode_onehot_torch(labels, num_classes=None):
    if num_classes is None:
        num_classes = int(labels.max() + 1)
    y = torch.eye(num_classes)
    return y[labels]


def calculate_imbalance_weight(idx, labels):
    """Weight each sample in idx by one minus its class frequency within idx."""
    weights = torch.ones(len(labels))
    for i in range(int(labels.max()) + 1):
        sub_node = torch.where(labels == i)[0]
        sub_idx = [x.item() for x in sub_node if x in idx]
        weights[sub_idx] = 1 - len(sub_idx) / len(idx)
    return weights
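

# Usage sketch: how these helpers compose. `loss` expects log-probabilities
# (e.g. the log-softmax output of models.Fully) and one-hot targets; the
# numbers below are hypothetical.
if __name__ == "__main__":
    log_probs = torch.log_softmax(torch.randn(4, 3), dim=1)
    targets = torch.tensor([0, 2, 1, 2])
    one_hot = encode_onehot_torch(targets)   # 4-by-3 one-hot matrix
    print(loss(log_probs, one_hot))          # summed negative log-likelihood (scalar)
    print(accuracy(log_probs, targets))      # fraction of correct predictions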