
main.py

import torch
import numpy as np
import os
import random
import argparse
import pickle

from train_gkd import run_gkd
from utils import calculate_imbalance_weight
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
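# str2bool is presumably meant to be passed as type= for boolean command-line flags,
# e.g. parser.add_argument('--some_flag', type=str2bool, default=False) (illustrative flag name);
# it is not used by the arguments defined below.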
def last_argmax(l):
    reverse = l[::-1]
    return len(reverse) - np.argmax(reverse) - 1
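# For example (illustrative values): last_argmax([1, 3, 2, 3]) returns 3, the position of the
# last occurrence of the maximum, whereas np.argmax would return the first occurrence (index 1).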
def find_isolated(adj, idx_train):
    # Return the indices of non-isolated nodes (degree > 0) that are not in the training set.
    adj = adj.to_dense()
    deg = adj.sum(1)
    idx_connected = torch.where(deg > 0)[0]
    idx_connected = [x for x in idx_connected if x not in idx_train]
    return idx_connected
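# Note: densifying the adjacency just to compute degrees can be costly for large graphs;
# torch.sparse.sum(adj, dim=1).to_dense() would give the same degree vector without
# materialising the full N x N matrix (an optional alternative, not used here).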
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=42, type=int, help='Random seed')
    parser.add_argument('--use_cuda', action='store_true', default=True,
                        help='Use CUDA if available.')
    parser.add_argument('--epochs_teacher', default=300, type=int, help='Number of epochs to train the teacher network')
    parser.add_argument('--epochs_student', default=200, type=int, help='Number of epochs to train the student network')
    parser.add_argument('--epochs_lpa', default=10, type=int, help='Number of label propagation (LPA) epochs')
    parser.add_argument('--lr_teacher', type=float, default=0.005, help='Learning rate for the teacher network.')
    parser.add_argument('--lr_student', type=float, default=0.005, help='Learning rate for the student network.')
    parser.add_argument('--wd_teacher', type=float, default=5e-4, help='Weight decay for the teacher network.')
    parser.add_argument('--wd_student', type=float, default=5e-4, help='Weight decay for the student network.')
    parser.add_argument('--dropout_teacher', type=float, default=0.3, help='Dropout for the teacher network.')
    parser.add_argument('--dropout_student', type=float, default=0.3, help='Dropout for the student network.')
    parser.add_argument('--burn_out_teacher', default=100, type=int,
                        help='Number of initial epochs to drop when selecting the best parameters '
                             'on the validation set for the teacher network')
    parser.add_argument('--burn_out_student', default=100, type=int,
                        help='Number of initial epochs to drop when selecting the best parameters '
                             'on the validation set for the student network')
    parser.add_argument('--alpha', default=0.1, type=float, help='Alpha for the LPA step')
    args = parser.parse_args()
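    # Note: with action='store_true' and default=True, --use_cuda cannot be switched off from the
    # command line; one alternative (not applied here) would be
    # parser.add_argument('--use_cuda', type=str2bool, default=True), which would also give the
    # str2bool helper above a use.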
    seed = args.seed
    use_cuda = args.use_cuda and torch.cuda.is_available()

    os.environ['PYTHONHASHSEED'] = str(seed)
    # Torch RNG
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Python RNG
    np.random.seed(seed)
    random.seed(seed)
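    # Optional extra beyond what this script sets: for fully deterministic GPU runs one could also
    # set torch.backends.cudnn.deterministic = True and torch.backends.cudnn.benchmark = False,
    # at some cost in speed.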
    ### The following lists give the number of hidden neurons in each hidden layer of the
    ### teacher network and the student network, respectively.
    hidden_teacher = [8]
    hidden_student = [4]

    ### Select the best output of the teacher for LPA and the best output of the student for
    ### reporting based on the following metrics.
    best_metric_teacher = 'f1macro_val'  ### combine a member of [loss, acc, f1macro] with a member of [train, val, test], e.g. 'f1macro_val'
    best_metric_student = 'f1macro_val'

    ### Show the evolution of the statistics on the training, validation and test sets with figures.
    show_stats = False

    ### Set this if you want to make sure that the test set is not accessible during training.
    isolated_test = True
    ### Data should be loaded here as follows:
    ### adj is a sparse tensor holding the N by N adjacency matrix between nodes.
    ### features is a tensor with size N by F.
    ### labels is a list of node labels.
    ### idx_train is a list containing the indices of the training samples; idx_val and idx_test follow the same pattern.
    with open('./data/synthetic/sample-2000-f_128-g_4-gt_0.5-sp_0.5-100-200-500.pkl', 'rb') as f:
        adj, features, labels, idx_train, idx_val, idx_test = pickle.load(f)
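    # A minimal sanity check of the data contract described above (optional; the exact shapes and
    # types are assumptions taken from the comments, not something run_gkd is known to enforce):
    assert adj.is_sparse and adj.shape[0] == adj.shape[1], 'adj should be a sparse N x N tensor'
    assert features.shape[0] == adj.shape[0], 'features should have one row per node (N x F)'
    assert len(labels) == features.shape[0], 'labels should have one entry per node'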
    ### You can load weights for the samples in the training set or calculate them here.
    ### The node weights will be recalculated in train_gkd after the LPA step.
    sample_weight = calculate_imbalance_weight(idx_train, labels)
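    # calculate_imbalance_weight is defined in utils.py; a common choice for such a function
    # (only an assumption about what it actually does) is inverse class frequency over the
    # training labels, roughly:
    #   counts = np.bincount([labels[i] for i in idx_train])
    #   weights = [len(idx_train) / (len(counts) * counts[labels[i]]) for i in idx_train]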
    idx_connected = find_isolated(adj, idx_train)
    params_teacher = {
        'lr': args.lr_teacher,
        'hidden': hidden_teacher,
        'weight_decay': args.wd_teacher,
        'dropout': args.dropout_teacher,
        'epochs': args.epochs_teacher,
        'best_metric': best_metric_teacher,
        'burn_out': args.burn_out_teacher
    }
    params_student = {
        'lr': args.lr_student,
        'hidden': hidden_student,
        'weight_decay': args.wd_student,
        'dropout': args.dropout_student,
        'epochs': args.epochs_student,
        'best_metric': best_metric_student,
        'burn_out': args.burn_out_student
    }
    params_lpa = {
        'epochs': args.epochs_lpa,
        'alpha': args.alpha
    }
    stats = run_gkd(adj, features, labels, idx_train,
                    idx_val, idx_test, idx_connected, params_teacher, params_student, params_lpa,
                    sample_weight=sample_weight, isolated_test=isolated_test,
                    use_cuda=use_cuda, show_stats=show_stats)

    ### Pick the last epoch (after burn-out) with the best validation metric and report the test metrics at that epoch.
    ind_val_max = params_student['burn_out'] + last_argmax(stats[params_student['best_metric']][params_student['burn_out']:])
    print('Last index with maximum metric on validation set: ' + str(ind_val_max))
    print('Final accuracy on test: ' + str(stats['acc_test'][ind_val_max]))
    print('Final Macro F1 on test: ' + str(stats['f1macro_test'][ind_val_max]))
    print('Final AUC on test: ' + str(stats['auc_test'][ind_val_max]))
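# Example invocation from the repository root (all flags shown with their default values):
#   python main.py --seed 42 --epochs_teacher 300 --epochs_student 200 --epochs_lpa 10 --alpha 0.1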