GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference
Here is the code for node classification in graphs when the graph is not available at test time.
Ghorbani et al. “GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference” [1]
Usage
The main file is “main.py”. Run it with python main.py
To run the code, you need to load your data in main.py. The data-loading function should return the adjacency matrix, features, labels, and the training, validation, and test indices. Each variable is described below:
- adj: a sparse tensor holding the normalized adjacency matrix over all nodes (train, validation, and test). Note that validation and test nodes have only self-loops, with no edges to other nodes.
- features: a tensor that holds the features of all nodes (N by F).
- labels: a list of labels for all nodes (of length N).
- idx_train, idx_val, idx_test: lists of indices for the training, validation, and test samples, respectively.
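The adjacency requirement above can be sketched as follows. This is a minimal illustration, not the repository's loader: the function name build_masked_adjacency and the toy edge list are hypothetical, the graph is assumed to be undirected, and row normalization (as in the pygcn code this project builds on) is an assumption, since the normalization is not specified here. A dense NumPy array is used for readability.

```python
import numpy as np

def build_masked_adjacency(edges, num_nodes, idx_train):
    """Row-normalized adjacency in which only training nodes keep their edges."""
    is_train = np.zeros(num_nodes, dtype=bool)
    is_train[idx_train] = True

    adj = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    for i, j in edges:
        # Drop every edge that touches a validation or test node.
        if is_train[i] and is_train[j]:
            adj[i, j] = 1.0
            adj[j, i] = 1.0  # assumption: the graph is undirected

    adj += np.eye(num_nodes, dtype=np.float32)  # self-loop for every node
    adj /= adj.sum(axis=1, keepdims=True)       # assumed row normalization
    return adj

# Toy graph: a path 0-1-2-3-4 with nodes 3 and 4 held out.
edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
idx_train, idx_val, idx_test = [0, 1, 2], [3], [4]
adj = build_masked_adjacency(edges, num_nodes=5, idx_train=idx_train)
```

With this construction, the rows of validation and test nodes contain a single 1 on the diagonal, so their representations depend only on their own features. The dense array would still need to be converted to a sparse tensor before being passed to the model.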
Parameters
Here is a list of parameters that should be passed to the main function or set in the code:
- seed: seed number
- use-cuda: using CUDA for training if it is available
- epochs_teacher: number of epochs for training the teacher network (default: 300)
- epochs_student: number of epochs for training the student network (default: 200)
- epochs_lpa: number of epochs for running label-propagation algorithm (default: 10)
- lr_teacher: learning rate for the teacher network (default: 0.005)
- lr_student: learning rate for the student network (default: 0.005)
- wd_teacher: weight decay for the teacher network (default: 5e-4)
- wd_student: weight decay for the student network (default: 5e-4)
- dropout_teacher: dropout for the teacher network (default: 0.3)
- dropout_student: dropout for the student network (default: 0.3)
- burn_out_teacher: number of epochs to drop when selecting the best parameters of the teacher network based on the validation set (default: 100)
- burn_out_student: number of epochs to drop when selecting the best parameters of the student network based on the validation set (default: 100)
- alpha: a float number between 0 and 1 that shows the coefficient of remembrance term (default: 0.1)
- hidden_teacher: a list of hidden neurons in each layer of the teacher network. This variable should be set in the code (default: [8] which is a network with one hidden layer with eight neurons in it)
- hidden_student: a list of hidden neurons in each layer of the student network. This variable should be set in the code (default: [4])
- best_metric_teacher: the score used to select the best output of the teacher network, based on the network's performance on the validation set (should be a combination of one of [loss, acc, f1macro] and one of [train, val, test]).
- best_metric_student: the score used to select the best output of the student network, based on the network's performance on the validation set.
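As a rough sketch of how the parameters above could be wired together, the snippet below declares them with argparse using the documented defaults and shows one possible reading of the burn-out and best-metric options. Several details are assumptions rather than the repository's actual interface: the flag spellings, the seed default, the "metric_split" naming such as loss_val, the helper select_best_epoch, and the interpretation that the first burn_out epochs are ignored during selection.

```python
import argparse

# Set in the code rather than on the command line, per the list above.
HIDDEN_TEACHER = [8]  # one hidden layer with eight neurons
HIDDEN_STUDENT = [4]

def build_parser():
    p = argparse.ArgumentParser(description="GKD options (illustrative sketch)")
    p.add_argument("--seed", type=int, default=0)  # no documented default; 0 is an assumption
    p.add_argument("--use-cuda", action="store_true")
    for role, epochs in (("teacher", 300), ("student", 200)):
        p.add_argument(f"--epochs_{role}", type=int, default=epochs)
        p.add_argument(f"--lr_{role}", type=float, default=0.005)
        p.add_argument(f"--wd_{role}", type=float, default=5e-4)
        p.add_argument(f"--dropout_{role}", type=float, default=0.3)
        p.add_argument(f"--burn_out_{role}", type=int, default=100)
        # Hypothetical "<metric>_<split>" spelling and default value.
        p.add_argument(f"--best_metric_{role}", default="loss_val")
    p.add_argument("--epochs_lpa", type=int, default=10)
    p.add_argument("--alpha", type=float, default=0.1)
    return p

def select_best_epoch(history, metric, burn_out):
    """Return the index of the best epoch, ignoring the first `burn_out` epochs."""
    values = [epoch[metric] for epoch in history]
    candidates = range(burn_out, len(values))
    if metric.startswith("loss"):
        return min(candidates, key=values.__getitem__)  # lower loss is better
    return max(candidates, key=values.__getitem__)      # higher acc / f1macro is better

args = build_parser().parse_args([])  # empty list: use the defaults only
```

For example, select_best_epoch(history, args.best_metric_teacher, args.burn_out_teacher) would pick the teacher checkpoint once history holds one dictionary of scores per epoch.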
Metrics
Accuracy and macro F1 are calculated in the code. ROC AUC can also be calculated for binary classification tasks.
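For reference, the two main metrics can be written out in a few lines. This is a plain-Python sketch of the standard definitions, not the code used in the repository; the function names and the toy labels are illustrative.

```python
def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); defined as 0 when the class never occurs.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)

y_true = [0, 0, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2]
acc = accuracy(y_true, y_pred)
f1 = macro_f1(y_true, y_pred)
```

Because macro F1 weights every class equally, it is more sensitive than accuracy to errors on rare classes.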
Note
Thanks to Thomas Kipf. The code is based on “Graph Convolutional Networks in PyTorch” [2].
Bug Report
If you find a bug, please send an email to [email protected], including, if necessary, the input file and the parameters that caused the bug.
You can also send me any comments or suggestions about the program.
References
[1] Ghorbani, Mahsa, et al. “GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference.” International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2021.
[2] Kipf, Thomas N., and Max Welling. “Semi-Supervised Classification with Graph Convolutional Networks.” 2016.
Cite
Please cite our paper if you use this code in your own work:
@inproceedings{ghorbani2021gkd,
title={GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference},
author={Ghorbani, Mahsa and Bahrami, Mojtaba and Kazi, Anees and Soleymani Baghshah, Mahdieh and Rabiee, Hamid R and Navab, Nassir},
booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
pages={709--718},
year={2021},
organization={Springer}
}