Browse Source

Final Commit?

master
Mahsa 2 years ago
commit
8e9a3ed855
6 changed files with 490 additions and 0 deletions
  1. BIN
      data/synthetic/sample-2000-f_128-g_4-gt_0.5-sp_0.5-100-200-500.pkl
  2. 129
    0
      main.py
  3. 37
    0
      models.py
  4. 76
    0
      readme.md
  5. 193
    0
      train_gkd.py
  6. 55
    0
      utils.py

BIN
data/synthetic/sample-2000-f_128-g_4-gt_0.5-sp_0.5-100-200-500.pkl View File


+ 129
- 0
main.py View File

@@ -0,0 +1,129 @@
import torch
import numpy as np
import os
import random

import argparse
import pickle
from train_gkd import run_gkd
from utils import calculate_imbalance_weight


def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


def last_argmax(l):
reverse = l[::-1]
return len(reverse) - np.argmax(reverse) - 1


def find_isolated(adj, idx_train):
adj = adj.to_dense()
deg = adj.sum(1)
idx_connected = torch.where(deg > 0)[0]
idx_connected = [x for x in idx_connected if x not in idx_train]
return idx_connected


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=42, type=int, help='Seed')
parser.add_argument('--use_cuda', action='store_true', default=True,
help='Use CUDA if available.')
parser.add_argument('--epochs_teacher', default=300, type=int, help='Number of epochs to train teacher network')
parser.add_argument('--epochs_student', default=200, type=int, help='Number of epochs to train student network')
parser.add_argument('--epochs_lpa', default=10, type=int, help='Number of epochs to train student network')

parser.add_argument('--lr_teacher', type=float, default=0.005, help='Learning rate for teacher network.')
parser.add_argument('--lr_student', type=float, default=0.005, help='Learning rate for student network.')

parser.add_argument('--wd_teacher', type=float, default=5e-4, help='Weight decay for teacher network.')
parser.add_argument('--wd_student', type=float, default=5e-4, help='Weight decay for student network.')

parser.add_argument('--dropout_teacher', type=float, default=0.3, help='Dropout for teacher network.')
parser.add_argument('--dropout_student', type=float, default=0.3, help='Dropout for student network.')

parser.add_argument('--burn_out_teacher', default=100, type=int, help='Number of epochs to drop for selecting best \
parameters based on validation set for teacher network')
parser.add_argument('--burn_out_student', default=100, type=int, help='Number of epochs to drop for selecting best \
parameters based on validation set for student network')

parser.add_argument('--alpha', default=0.1, type=float, help='alpha')

args = parser.parse_args()
seed = args.seed
use_cuda = args.use_cuda and torch.cuda.is_available()
os.environ['PYTHONHASHSEED'] = str(seed)
# Torch RNG
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Python RNG
np.random.seed(seed)
random.seed(seed)
### Following lists shows the number of hidden neurons in each hidden layer of teacher network and student network respectively
hidden_teacher = [8]
hidden_student = [4]
### select the best output of teacher for lpa and best output of student for reporting based on the following metrics
best_metric_teacher = 'f1macro_val' ### concat a member of these lists: [loss, acc, f1macro][train, val, test]
best_metric_student = 'f1macro_val'

### show the changes of statistics in training, validation and test set with figures
show_stats = False
### This is needed is you want to make sure that access to test set is not available during the training
isolated_test = True

### Data should be loaded here as follow:
### adj is a sparse tensor showing the adjacency matrix between nodes with size N by N
### Features is a tensor with size N by F
### labels is a list of node labels
### idx train is a list contains the index of training samples. idx_val and idx_test follow the same pattern
with open('./data/synthetic/sample-2000-f_128-g_4-gt_0.5-sp_0.5-100-200-500.pkl', 'rb') as f: # Python 3: open(..., 'rb')
adj, features, labels, idx_train, idx_val, idx_test = pickle.load(f)
### you can load weights for samples in the training set or calculate them here.
### the weights of nodes will be calculates in the train_gkd after the lpa step
sample_weight = calculate_imbalance_weight(idx_train, labels)

idx_connected = find_isolated(adj, idx_train)

params_teacher = {
'lr': args.lr_teacher,
'hidden': hidden_teacher,
'weight_decay': args.wd_teacher,
'dropout': args.dropout_teacher,
'epochs': args.epochs_teacher,
'best_metric': best_metric_teacher,
'burn_out': args.burn_out_teacher
}
params_student = {
'lr': args.lr_student,
'hidden': hidden_student,
'weight_decay': args.wd_student,
'dropout': args.dropout_teacher,
'epochs': args.epochs_student,
'best_metric': best_metric_student,
'burn_out': args.burn_out_student
}
params_lpa = {
'epochs': args.epochs_lpa,
'alpha': args.alpha
}

stats = run_gkd(adj, features, labels, idx_train,
idx_val,idx_test, idx_connected, params_teacher, params_student, params_lpa,
sample_weight=sample_weight, isolated_test=isolated_test,
use_cuda=use_cuda, show_stats=show_stats)

ind_val_max = params_student['burn_out'] + last_argmax(stats[params_student['best_metric']][params_student['burn_out']:])
print('Last index with maximum metric on validation set: ' + str(ind_val_max))
print('Final accuracy on val: ' + str(stats['acc_test'][ind_val_max]))
print('Final Macro F1 on val: ' + str(stats['f1macro_test'][ind_val_max]))
print('Final AUC on val' + str(stats['auc_test'][ind_val_max]))

+ 37
- 0
models.py View File

@@ -0,0 +1,37 @@
import torch.nn as nn
import torch.nn.functional as F


class Fully(nn.Module):
def __init__(self, nfeat, nhid, nclass, dropout,multi_label=False):
super(Fully, self).__init__()
self.multi_label = multi_label
layers = []
if len(nhid) == 0:
layers.append(nn.Linear(nfeat, nclass))
else:
layers.append(nn.Linear(nfeat, nhid[0]))
for i in range(len(nhid) - 1):
layers.append(nn.Linear(nhid[i], nhid[i + 1]))
if nclass > 1:
layers.append(nn.Linear(nhid[-1], nclass))
self.fc = nn.ModuleList(layers)

self.dropout = dropout
self.nclass = nclass

def forward(self, x, adj=None):
end_layer = len(self.fc) - 1 if self.nclass > 1 else len(self.fc)
for i in range(end_layer):
x = F.dropout(x, self.dropout, training=self.training)
x = self.fc[i](x)
x = F.relu(x)

classifier = self.fc[-1](x)
if self.multi_label:
classifier = classifier
else:
classifier = F.log_softmax(classifier, dim=1)
return classifier



+ 76
- 0
readme.md View File

@@ -0,0 +1,76 @@
GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference
====

Here is the code for node classification in graphs when the graph is not available at test time.
Ghorbani et.al. "GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference" [1]



Usage
------------
The main file is "main.py". Run with ```python train.py```


Input Data
------------
For running the code, you need to load data in the main.py. adjacency matrices, features, labels, training, validation, and test indices should be returned in this function. More description about each variable is as follows:
- adj: is a sparse tensor showing the **normalized** adjacency matrix between all nodes (train, validation and test). It should be noted that validation and test nodes only has self-loop without any edge to other nodes.
- Features: is a tensor that includes the features of all nodes (N by F).
- labels: is a list of labels for all nodes (with length N)
- idx_train, idx_val, idx_test: are lists of indexes for training, validation, and test samples respectively.

Parameters
------------
Here is a list of parameters that should be passed to the main function or set in the code:
- seed: seed number
- use-cuda: using CUDA for training if it is available
- epochs_teacher: number of epochs for training the teacher network (default: 300)
- epochs_student: number of epochs for training the student network (default: 200)
- epochs_lpa: number of epochs for running label-propagation algorithm (default: 10)
- lr_teacher: learning rate for the teacher network (default: 0.005)
- lr_student: learning rate for the student network (default: 0.005)
- wd_teacher: weight decay for the teacher network (default: 5e-4)
- wd_student: weight decay for the student network (default: 5e-4)
- dropout_teacher: dropout for the teacher network (default: 0.3)
- dropout_student: dropout for the student network (default: 0.3)
- burn_out_teacher: Number of epochs to drop for selecting best parameters based on validation set for teacher network (default: 100)
- burn_out_student: Number of epochs to drop for selecting best parameters based on validation set for student network (default: 100)
- alpha: a float number between 0 and 1 that shows the coefficient of remembrance term (default: 0.1)
- hidden_teacher: a list of hidden neurons in each layer of the teacher network. This variable should be set in the code (default: [8] which is a network with one hidden layer with eight neurons in it)
- hidden_student: a list of hidden neurons in each layer of the student network. This variable should be set in the code (default: [4])
- best_metric_teacher: to select the best output of teacher network, we use the performance of the network on the validation set based on this score (should be a combination between [loss, acc, f1macro] and [train, val, test]).
- best_metric_student: to select the best output of student network, we use the performance of the network on the validation set based on this score.

Metrics
------------
Accuracy, macro F1 are calculated in the code. ROAUC can be calculated for binary classification tasks.

Note
------------
Thanks to Thomas Kipf. The code is written based on the "Graph Convolutional Networks in PyTorch" [2].

Bug Report
------------
If you find a bug, please send email to [email protected] including if necessary the input file and the parameters that caused the bug.
You can also send me any comment or suggestion about the program.

References
------------
[1] [Ghorbani, Mahsa, et al. "GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference." International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2021.](https://arxiv.org/pdf/2104.03597)

[2] [Kipf & Welling, Semi-Supervised Classification with Graph Convolutional Networks, 2016](https://arxiv.org/abs/1609.02907)

Cite
------------
Please cite our paper if you use this code in your own work:

```
@inproceedings{ghorbani2021gkd,
title={GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference},
author={Ghorbani, Mahsa and Bahrami, Mojtaba and Kazi, Anees and Soleymani Baghshah, Mahdieh and Rabiee, Hamid R and Navab, Nassir},
booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
pages={709--718},
year={2021},
organization={Springer}
}
```

+ 193
- 0
train_gkd.py View File

@@ -0,0 +1,193 @@
import matplotlib.pyplot as plt

import torch
import torch.optim as optim

from utils import accuracy, encode_onehot_torch, class_f1, roc_auc, half_normalize, loss, calculate_imbalance_weight
from models import Fully


def run_gkd(adj, features, labels, idx_train, idx_val, idx_test, idx_connected,
params_teacher, params_student, params_lpa, isolated_test, sample_weight, use_cuda=True, show_stats=False):

def train_teacher(epoch):
model_teacher.train()
optimizer_teacher.zero_grad()
output = model_teacher(features)
loss_train = loss(output[idx_train], labels_one_hot[idx_train], sample_weight[idx_train])
acc_train = accuracy(output[idx_train], labels[idx_train])
stats['f1macro_train'].append(class_f1(output[idx_train], labels[idx_train], type='macro'))
if 'auc_train' in stats:
stats['auc_train'].append(roc_auc(output[idx_train, 1], labels[idx_train]))
loss_train.backward()
optimizer_teacher.step()

model_teacher.eval()
output = model_teacher(features)

loss_val = loss(output[idx_val], labels_one_hot[idx_val], sample_weight[idx_val])
acc_val = accuracy(output[idx_val], labels[idx_val])
if 'auc_val' in stats:
stats['auc_val'].append(roc_auc(output[idx_val, 1], labels[idx_val]))
stats['f1macro_val'].append(class_f1(output[idx_val], labels[idx_val], type='macro'))
stats['loss_train'].append(loss_train.item())
stats['acc_train'].append(acc_train.item())
stats['loss_val'].append(loss_val.item())
stats['acc_val'].append(acc_val.item())

print('Epoch: {:04d}'.format(epoch + 1),
'loss_train: {:.4f}'.format(loss_train.item()),
'acc_train: {:.4f}'.format(acc_train.item()),
'loss_val: {:.4f}'.format(loss_val.item()),
'acc_val: {:.4f}'.format(acc_val.item()))

return output.detach()

def test_teacher():
model_teacher.eval()
output = model_teacher(features)
loss_test = loss(output[idx_test], labels_one_hot[idx_test], sample_weight[idx_test])
acc_test = accuracy(output[idx_test], labels[idx_test])
stats['f1macro_test'].append(class_f1(output[idx_test], labels[idx_test], type='macro'))
stats['loss_test'].append(loss_test.item())
stats['acc_test'].append(acc_test.item())
if 'auc_test' in stats:
stats['auc_test'].append(roc_auc(output[idx_test, 1], labels[idx_test]))

print("Test set results:",
"loss= {:.4f}".format(loss_test.item()),
"accuracy= {:.4f}".format(acc_test.item()))

def train_student(epoch):
model_student.train()
optimizer_student.zero_grad()
output = model_student(features)
loss_train = loss(output[idx_lpa], labels_lpa[idx_lpa], sample_weight[idx_lpa])
acc_train = accuracy(output[idx_lpa], labels[idx_lpa])
stats['f1macro_train'].append(class_f1(output[idx_train], labels[idx_train], type='macro'))
if 'auc_train' in stats:
stats['auc_train'].append(roc_auc(output[idx_train, 1], labels[idx_train]))
if 'f1binary_train' in stats:
stats['f1binary_train'].append(class_f1(output[idx_train], labels[idx_train], type='binary'))
loss_train.backward()
optimizer_student.step()

model_student.eval()
output = model_student(features)
loss_val = loss(output[idx_val], labels_one_hot[idx_val], sample_weight[idx_val])
acc_val = accuracy(output[idx_val], labels[idx_val])
if 'auc_val' in stats:
stats['auc_val'].append(roc_auc(output[idx_val, 1], labels[idx_val]))
stats['f1macro_val'].append(class_f1(output[idx_val], labels[idx_val], type='macro'))
if 'f1binary_val' in stats:
stats['f1binary_val'].append(class_f1(output[idx_val], labels[idx_val], type='binary'))
stats['loss_train'].append(loss_train.item())
stats['acc_train'].append(acc_train.item())
stats['loss_val'].append(loss_val.item())
stats['acc_val'].append(acc_val.item())

print('Epoch: {:04d}'.format(epoch + 1),
'loss_train: {:.4f}'.format(loss_train.item()),
'acc_train: {:.4f}'.format(acc_train.item()),
'loss_val: {:.4f}'.format(loss_val.item()),
'acc_val: {:.4f}'.format(acc_val.item()))

def test_student():
model_student.eval()
output = model_student(features)
loss_test = loss(output[idx_test], labels_one_hot[idx_test], sample_weight[idx_test])
acc_test = accuracy(output[idx_test], labels[idx_test])
stats['f1macro_test'].append(class_f1(output[idx_test], labels[idx_test], type='macro'))
if 'f1binary_test' in stats:
stats['f1binary_test'].append(class_f1(output[idx_test], labels[idx_test], type='binary'))
stats['loss_test'].append(loss_test.item())
stats['acc_test'].append(acc_test.item())
if 'auc_test' in stats:
stats['auc_test'].append(roc_auc(output[idx_test, 1], labels[idx_test]))

print("Test set results:",
"loss= {:.4f}".format(loss_test.item()),
"accuracy= {:.4f}".format(acc_test.item()))
return

stats = dict()
states = ['train', 'val', 'test']
metrics = ['loss', 'acc', 'f1macro','auc']

for s in states:
for m in metrics:
stats[m + '_' + s] = []

labels_one_hot = encode_onehot_torch(labels)
idx_lpa = list(idx_connected) + list(idx_train)
if isolated_test:
idx_lpa = [x for x in idx_lpa if (x not in idx_test) and (x not in idx_val)]
idx_lpa = torch.LongTensor(idx_lpa)
model_teacher = Fully(nfeat=features.shape[1],
nhid=params_teacher['hidden'],
nclass=int(labels_one_hot.shape[1]),
dropout=params_teacher['dropout'])
model_student = Fully(nfeat=features.shape[1],
nhid=params_student['hidden'],
nclass=int(labels_one_hot.shape[1]),
dropout=params_student['dropout'])

optimizer_teacher = optim.Adam(model_teacher.parameters(),
lr=params_teacher['lr'],
weight_decay=params_teacher['weight_decay'])
optimizer_student = optim.Adam(model_student.parameters(),
lr=params_student['lr'],
weight_decay=params_student['weight_decay'])

if use_cuda:
model_student.cuda()
model_teacher.cuda()
features = features.cuda()
sample_weight = sample_weight.cuda()
adj = adj.cuda()
labels = labels.cuda()
labels_one_hot = labels_one_hot.cuda()
idx_train = idx_train.cuda()
idx_val = idx_val.cuda()
idx_test = idx_test.cuda()
idx_lpa = idx_lpa.cuda()

labels_lpa = labels_one_hot
best_metric_val = 0
for epoch in range(params_teacher['epochs']):
output_fc = train_teacher(epoch)
test_teacher()
if stats[params_teacher['best_metric']][-1] >= best_metric_val and epoch > params_teacher['burn_out']:
best_metric_val = stats[params_teacher['best_metric']][-1]
best_output = output_fc

alpha = params_lpa['alpha']
labels_start = torch.exp(best_output)
labels_lpa = labels_start
labels_lpa[idx_train, :] = labels_one_hot[idx_train, :].float()
for i in range(params_lpa['epochs']):
labels_lpa = (1 - alpha) * adj.mm(labels_lpa) + alpha * labels_start
labels_lpa[idx_train, :] = labels_one_hot[idx_train, :].float()
labels_lpa = half_normalize(labels_lpa)

sample_weight = calculate_imbalance_weight(idx_lpa, labels_lpa.argmax(1))
if use_cuda:
sample_weight = sample_weight.cuda()
### empty stats
for k, v in stats.items():
stats[k] = []

for epoch in range(params_student['epochs']):
train_student(epoch)
test_student()

if show_stats:
plt.figure()
fig, axs = plt.subplots(len(states), len(metrics), figsize=(5 * len(metrics), 3 * len(states)))
for i, s in enumerate(states):
for j, m in enumerate(metrics):
axs[i, j].plot(stats[m + '_' + s])
axs[i, j].set_title(m + ' ' + s)
plt.show()

return stats

+ 55
- 0
utils.py View File

@@ -0,0 +1,55 @@
import torch
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score


def accuracy(output, labels):
preds = output.max(1)[1].type_as(labels)
correct = preds.eq(labels).double()
correct = correct.sum()
return correct / len(labels)


def class_f1(output, labels, type='micro', pos_label=1):
preds = output.max(1)[1].type_as(labels)
return f1_score(labels.detach().cpu().numpy(), preds.cpu(), average=type, pos_label=pos_label)


def roc_auc(output, labels):
return roc_auc_score(labels.cpu().numpy(), output.detach().cpu().numpy())


def loss(output,labels, weights=None):
if weights is None:
weights = torch.ones(labels.shape[0])
return torch.sum(- weights * (labels.float() * output).sum(1), -1)


def half_normalize(mx):
rowsum = mx.sum(1).float()
r_inv = rowsum.pow(-1).flatten()
r_inv[torch.isinf(r_inv)] = 0.
r_mat_inv = torch.diag(r_inv)
mx = r_mat_inv.mm(mx)
return mx


def encode_onehot_torch(labels,num_classes=None):
if num_classes is None:
num_classes = int(labels.max() + 1)
y = torch.eye(num_classes)
return y[labels]


def calculate_imbalance_weight(idx,labels):
weights = torch.ones(len(labels))
for i in range(labels.max()+1):
sub_node = torch.where(labels == i)[0]
sub_idx = [x.item() for x in sub_node if x in idx]
weights[sub_idx] = 1 - len(sub_idx)/ len(idx)
return weights






Loading…
Cancel
Save