
GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference
====

This repository contains the code for node classification on graphs when the graph is not available at test time, as described in Ghorbani et al., "GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference" [1].

Usage
------------
The main file is `main.py`. Run with `python train.py`.

Input Data
------------
To run the code, you need to load your data in `main.py`. The data-loading function should return the adjacency matrix, features, labels, and the training, validation, and test indices, as described below:

- adj: a sparse tensor holding the **normalized** adjacency matrix over all nodes (training, validation, and test). Validation and test nodes have only self-loops, without any edges to other nodes.
- features: a tensor containing the features of all nodes (N by F).
- labels: a list of labels for all nodes (length N).
- idx_train, idx_val, idx_test: lists of indices of the training, validation, and test nodes, respectively.
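
As a rough illustration of the expected `adj`, the sketch below builds an adjacency matrix from an edge list so that validation and test nodes keep only their self-loops. The helper `build_inductive_adj`, the undirected edge list, and the row normalization D^-1 (A + I) (the scheme used in Kipf's "Graph Convolutional Networks in PyTorch") are assumptions for illustration; this repository may build or normalize the matrix differently.

```python
import numpy as np


def build_inductive_adj(edges, num_nodes, idx_val, idx_test):
    """Hypothetical helper: row-normalized adjacency (dense NumPy array) in which
    validation and test nodes keep only their self-loops."""
    held_out = set(idx_val) | set(idx_test)
    adj = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    for u, v in edges:
        # Drop every edge that touches a validation or test node.
        if u in held_out or v in held_out:
            continue
        adj[u, v] = adj[v, u] = 1.0  # assumption: the graph is undirected
    np.fill_diagonal(adj, 1.0)  # self-loop on every node, i.e. A + I
    # Row normalization D^-1 (A + I); no row sums to zero thanks to the self-loops.
    return adj / adj.sum(axis=1, keepdims=True)


if __name__ == "__main__":
    # Toy graph: nodes 0 and 1 are training nodes, node 2 is validation, node 3 is test.
    adj = build_inductive_adj([(0, 1), (1, 2), (2, 3)], num_nodes=4, idx_val=[2], idx_test=[3])
    print(adj)
```

To obtain the sparse tensor described above, the result could then be converted with PyTorch, for example `torch.from_numpy(adj).to_sparse()`.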

Parameters
------------
Here is a list of parameters that should be passed to the main function or set in the code:

- seed: random seed
- use-cuda: use CUDA for training if it is available
- epochs_teacher: number of epochs for training the teacher network (default: 300)
- epochs_student: number of epochs for training the student network (default: 200)
- epochs_lpa: number of epochs for running the label-propagation algorithm (default: 10)
- lr_teacher: learning rate for the teacher network (default: 0.005)
- lr_student: learning rate for the student network (default: 0.005)
- wd_teacher: weight decay for the teacher network (default: 5e-4)
- wd_student: weight decay for the student network (default: 5e-4)
- dropout_teacher: dropout rate for the teacher network (default: 0.3)
- dropout_student: dropout rate for the student network (default: 0.3)
- burn_out_teacher: number of initial epochs to ignore when selecting the best teacher parameters based on the validation set (default: 100)
- burn_out_student: number of initial epochs to ignore when selecting the best student parameters based on the validation set (default: 100)
- alpha: a float between 0 and 1 giving the coefficient of the remembrance term (default: 0.1)
- hidden_teacher: a list with the number of hidden neurons in each layer of the teacher network. This variable should be set in the code (default: [8], i.e., a network with one hidden layer of eight neurons)
- hidden_student: a list with the number of hidden neurons in each layer of the student network. This variable should be set in the code (default: [4])
- best_metric_teacher: the score used to select the best output of the teacher network, based on the network's performance on the validation set. It should combine one of [loss, acc, f1macro] with one of [train, val, test].
- best_metric_student: the score used in the same way to select the best output of the student network.
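
The interplay of burn_out_* and best_metric_* can be sketched as follows: skip the first burn_out epochs, then keep the epoch with the best value of the chosen score. The helper name `select_best_epoch`, the per-epoch `history` dicts, and the `<metric>_<split>` naming (e.g. `acc_val`, `f1macro_val`, `loss_val`) are assumptions for illustration, not the repository's actual interface.

```python
def select_best_epoch(history, best_metric="acc_val", burn_out=100):
    """Return the index of the best epoch, ignoring the first `burn_out` epochs.

    `history` holds one dict of scores per epoch, e.g.
    {"loss_train": 0.9, "loss_val": 1.1, "acc_val": 0.7, "f1macro_val": 0.65}.
    The "<metric>_<split>" naming is an assumption made for this sketch.
    """
    metric, _, split = best_metric.rpartition("_")
    if metric not in ("loss", "acc", "f1macro") or split not in ("train", "val", "test"):
        raise ValueError(f"unsupported best_metric: {best_metric!r}")

    # Epochs inside the burn-out window are never candidates.
    candidates = range(burn_out, len(history))
    if not candidates:
        raise ValueError("burn_out must be smaller than the number of epochs")

    def score(epoch):
        return history[epoch][best_metric]

    # Lower is better for losses; higher is better for accuracy and macro-F1.
    if metric == "loss":
        return min(candidates, key=score)
    return max(candidates, key=score)


if __name__ == "__main__":
    history = [
        {"acc_val": 0.9, "loss_val": 0.1},  # spuriously good first epoch
        {"acc_val": 0.5, "loss_val": 0.6},
        {"acc_val": 0.7, "loss_val": 0.3},
        {"acc_val": 0.6, "loss_val": 0.2},
    ]
    print(select_best_epoch(history, "acc_val", burn_out=0))  # → 0
    print(select_best_epoch(history, "acc_val", burn_out=1))  # → 2
```

The toy history shows why a burn-out window can help: without it, an unstable early epoch can be picked as the "best" one.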

Metrics
------------
Accuracy and macro-F1 are calculated in the code. ROC AUC can also be calculated for binary classification tasks.
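
For reference, these metrics can be computed without external libraries as in the sketch below. The function names are hypothetical, and the repository may compute the scores differently (for example with a library). Here ROC AUC is computed as the fraction of positive/negative pairs that are ranked correctly, with ties counted as one half, which equals the area under the ROC curve.

```python
def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over all classes seen in y_true or y_pred."""
    pairs = list(zip(y_true, y_pred))
    scores = []
    for c in set(y_true) | set(y_pred):
        tp = sum(t == c and p == c for t, p in pairs)
        fp = sum(t != c and p == c for t, p in pairs)
        fn = sum(t == c and p != c for t, p in pairs)
        # F1 = 2TP / (2TP + FP + FN); the denominator is positive because
        # class c occurs in y_true (TP or FN) or in y_pred (TP or FP).
        scores.append(2 * tp / (2 * tp + fp + fn))
    return sum(scores) / len(scores)


def roc_auc_binary(y_true, scores):
    """ROC AUC for labels in {0, 1}: the probability that a randomly chosen
    positive gets a higher score than a randomly chosen negative (ties count 1/2)."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


if __name__ == "__main__":
    print(accuracy([0, 0, 1, 1, 2], [0, 1, 1, 1, 2]))  # → 0.8
    print(roc_auc_binary([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # → 0.75
```

The pairwise ROC AUC is quadratic in the number of samples, which is fine for small validation or test sets; a sort-based computation scales better on larger ones.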

Note
------------
Thanks to Thomas Kipf. The code is based on "Graph Convolutional Networks in PyTorch" [2].

Bug Report
------------
If you find a bug, please send an email to [email protected], including, if necessary, the input file and the parameters that caused the bug.
You can also send me any comments or suggestions about the program.

References
------------
[1] [Ghorbani, Mahsa, et al. "GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference." International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2021.](https://arxiv.org/pdf/2104.03597)

[2] [Kipf, Thomas N., and Max Welling. "Semi-Supervised Classification with Graph Convolutional Networks." arXiv preprint arXiv:1609.02907, 2016.](https://arxiv.org/abs/1609.02907)

Cite
------------
Please cite our paper if you use this code in your own work:
```
@inproceedings{ghorbani2021gkd,
  title={GKD: Semi-supervised Graph Knowledge Distillation for Graph-Independent Inference},
  author={Ghorbani, Mahsa and Bahrami, Mojtaba and Kazi, Anees and Soleymani Baghshah, Mahdieh and Rabiee, Hamid R and Navab, Nassir},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={709--718},
  year={2021},
  organization={Springer}
}
```