You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train_gkd.py 8.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import matplotlib.pyplot as plt
  2. import torch
  3. import torch.optim as optim
  4. from utils import accuracy, encode_onehot_torch, class_f1, roc_auc, half_normalize, loss, calculate_imbalance_weight
  5. from models import Fully
  6. def run_gkd(adj, features, labels, idx_train, idx_val, idx_test, idx_connected,
  7. params_teacher, params_student, params_lpa, isolated_test, sample_weight, use_cuda=True, show_stats=False):
  8. def train_teacher(epoch):
  9. model_teacher.train()
  10. optimizer_teacher.zero_grad()
  11. output = model_teacher(features)
  12. loss_train = loss(output[idx_train], labels_one_hot[idx_train], sample_weight[idx_train])
  13. acc_train = accuracy(output[idx_train], labels[idx_train])
  14. stats['f1macro_train'].append(class_f1(output[idx_train], labels[idx_train], type='macro'))
  15. if 'auc_train' in stats:
  16. stats['auc_train'].append(roc_auc(output[idx_train, 1], labels[idx_train]))
  17. loss_train.backward()
  18. optimizer_teacher.step()
  19. model_teacher.eval()
  20. output = model_teacher(features)
  21. loss_val = loss(output[idx_val], labels_one_hot[idx_val], sample_weight[idx_val])
  22. acc_val = accuracy(output[idx_val], labels[idx_val])
  23. if 'auc_val' in stats:
  24. stats['auc_val'].append(roc_auc(output[idx_val, 1], labels[idx_val]))
  25. stats['f1macro_val'].append(class_f1(output[idx_val], labels[idx_val], type='macro'))
  26. stats['loss_train'].append(loss_train.item())
  27. stats['acc_train'].append(acc_train.item())
  28. stats['loss_val'].append(loss_val.item())
  29. stats['acc_val'].append(acc_val.item())
  30. print('Epoch: {:04d}'.format(epoch + 1),
  31. 'loss_train: {:.4f}'.format(loss_train.item()),
  32. 'acc_train: {:.4f}'.format(acc_train.item()),
  33. 'loss_val: {:.4f}'.format(loss_val.item()),
  34. 'acc_val: {:.4f}'.format(acc_val.item()))
  35. return output.detach()
  36. def test_teacher():
  37. model_teacher.eval()
  38. output = model_teacher(features)
  39. loss_test = loss(output[idx_test], labels_one_hot[idx_test], sample_weight[idx_test])
  40. acc_test = accuracy(output[idx_test], labels[idx_test])
  41. stats['f1macro_test'].append(class_f1(output[idx_test], labels[idx_test], type='macro'))
  42. stats['loss_test'].append(loss_test.item())
  43. stats['acc_test'].append(acc_test.item())
  44. if 'auc_test' in stats:
  45. stats['auc_test'].append(roc_auc(output[idx_test, 1], labels[idx_test]))
  46. print("Test set results:",
  47. "loss= {:.4f}".format(loss_test.item()),
  48. "accuracy= {:.4f}".format(acc_test.item()))
  49. def train_student(epoch):
  50. model_student.train()
  51. optimizer_student.zero_grad()
  52. output = model_student(features)
  53. loss_train = loss(output[idx_lpa], labels_lpa[idx_lpa], sample_weight[idx_lpa])
  54. acc_train = accuracy(output[idx_lpa], labels[idx_lpa])
  55. stats['f1macro_train'].append(class_f1(output[idx_train], labels[idx_train], type='macro'))
  56. if 'auc_train' in stats:
  57. stats['auc_train'].append(roc_auc(output[idx_train, 1], labels[idx_train]))
  58. if 'f1binary_train' in stats:
  59. stats['f1binary_train'].append(class_f1(output[idx_train], labels[idx_train], type='binary'))
  60. loss_train.backward()
  61. optimizer_student.step()
  62. model_student.eval()
  63. output = model_student(features)
  64. loss_val = loss(output[idx_val], labels_one_hot[idx_val], sample_weight[idx_val])
  65. acc_val = accuracy(output[idx_val], labels[idx_val])
  66. if 'auc_val' in stats:
  67. stats['auc_val'].append(roc_auc(output[idx_val, 1], labels[idx_val]))
  68. stats['f1macro_val'].append(class_f1(output[idx_val], labels[idx_val], type='macro'))
  69. if 'f1binary_val' in stats:
  70. stats['f1binary_val'].append(class_f1(output[idx_val], labels[idx_val], type='binary'))
  71. stats['loss_train'].append(loss_train.item())
  72. stats['acc_train'].append(acc_train.item())
  73. stats['loss_val'].append(loss_val.item())
  74. stats['acc_val'].append(acc_val.item())
  75. print('Epoch: {:04d}'.format(epoch + 1),
  76. 'loss_train: {:.4f}'.format(loss_train.item()),
  77. 'acc_train: {:.4f}'.format(acc_train.item()),
  78. 'loss_val: {:.4f}'.format(loss_val.item()),
  79. 'acc_val: {:.4f}'.format(acc_val.item()))
  80. def test_student():
  81. model_student.eval()
  82. output = model_student(features)
  83. loss_test = loss(output[idx_test], labels_one_hot[idx_test], sample_weight[idx_test])
  84. acc_test = accuracy(output[idx_test], labels[idx_test])
  85. stats['f1macro_test'].append(class_f1(output[idx_test], labels[idx_test], type='macro'))
  86. if 'f1binary_test' in stats:
  87. stats['f1binary_test'].append(class_f1(output[idx_test], labels[idx_test], type='binary'))
  88. stats['loss_test'].append(loss_test.item())
  89. stats['acc_test'].append(acc_test.item())
  90. if 'auc_test' in stats:
  91. stats['auc_test'].append(roc_auc(output[idx_test, 1], labels[idx_test]))
  92. print("Test set results:",
  93. "loss= {:.4f}".format(loss_test.item()),
  94. "accuracy= {:.4f}".format(acc_test.item()))
  95. return
  96. stats = dict()
  97. states = ['train', 'val', 'test']
  98. metrics = ['loss', 'acc', 'f1macro','auc']
  99. for s in states:
  100. for m in metrics:
  101. stats[m + '_' + s] = []
  102. labels_one_hot = encode_onehot_torch(labels)
  103. idx_lpa = list(idx_connected) + list(idx_train)
  104. if isolated_test:
  105. idx_lpa = [x for x in idx_lpa if (x not in idx_test) and (x not in idx_val)]
  106. idx_lpa = torch.LongTensor(idx_lpa)
  107. model_teacher = Fully(nfeat=features.shape[1],
  108. nhid=params_teacher['hidden'],
  109. nclass=int(labels_one_hot.shape[1]),
  110. dropout=params_teacher['dropout'])
  111. model_student = Fully(nfeat=features.shape[1],
  112. nhid=params_student['hidden'],
  113. nclass=int(labels_one_hot.shape[1]),
  114. dropout=params_student['dropout'])
  115. optimizer_teacher = optim.Adam(model_teacher.parameters(),
  116. lr=params_teacher['lr'],
  117. weight_decay=params_teacher['weight_decay'])
  118. optimizer_student = optim.Adam(model_student.parameters(),
  119. lr=params_student['lr'],
  120. weight_decay=params_student['weight_decay'])
  121. if use_cuda:
  122. model_student.cuda()
  123. model_teacher.cuda()
  124. features = features.cuda()
  125. sample_weight = sample_weight.cuda()
  126. adj = adj.cuda()
  127. labels = labels.cuda()
  128. labels_one_hot = labels_one_hot.cuda()
  129. idx_train = idx_train.cuda()
  130. idx_val = idx_val.cuda()
  131. idx_test = idx_test.cuda()
  132. idx_lpa = idx_lpa.cuda()
  133. labels_lpa = labels_one_hot
  134. best_metric_val = 0
  135. for epoch in range(params_teacher['epochs']):
  136. output_fc = train_teacher(epoch)
  137. test_teacher()
  138. if stats[params_teacher['best_metric']][-1] >= best_metric_val and epoch > params_teacher['burn_out']:
  139. best_metric_val = stats[params_teacher['best_metric']][-1]
  140. best_output = output_fc
  141. alpha = params_lpa['alpha']
  142. labels_start = torch.exp(best_output)
  143. labels_lpa = labels_start
  144. labels_lpa[idx_train, :] = labels_one_hot[idx_train, :].float()
  145. for i in range(params_lpa['epochs']):
  146. labels_lpa = (1 - alpha) * adj.mm(labels_lpa) + alpha * labels_start
  147. labels_lpa[idx_train, :] = labels_one_hot[idx_train, :].float()
  148. labels_lpa = half_normalize(labels_lpa)
  149. sample_weight = calculate_imbalance_weight(idx_lpa, labels_lpa.argmax(1))
  150. if use_cuda:
  151. sample_weight = sample_weight.cuda()
  152. ### empty stats
  153. for k, v in stats.items():
  154. stats[k] = []
  155. for epoch in range(params_student['epochs']):
  156. train_student(epoch)
  157. test_student()
  158. if show_stats:
  159. plt.figure()
  160. fig, axs = plt.subplots(len(states), len(metrics), figsize=(5 * len(metrics), 3 * len(states)))
  161. for i, s in enumerate(states):
  162. for j, m in enumerate(metrics):
  163. axs[i, j].plot(stats[m + '_' + s])
  164. axs[i, j].set_title(m + ' ' + s)
  165. plt.show()
  166. return stats