You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

main_medical.py 8.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import time
  2. import argparse
  3. import torch
  4. seed_num = 17
  5. torch.manual_seed(seed_num)
  6. import torch.nn.functional as F
  7. import itertools as it
  8. import matplotlib.pyplot as plt
  9. from code.model import RAGCN
  10. from code.utils import accuracy, load_data_medical, encode_onehot_torch, class_f1, auc_score
  11. from code.model import my_sigmoid
  12. def train():
  13. ### create structure for discriminator and weighting networks
  14. struc_D = {'dropout': args.dropout_D, 'wd': 5e-4, 'lr': args.lr_D, 'nhid': structure_D}
  15. ### For simplicity, all weighting networks have the same hyper-parameters. They can be defined as a list of /
  16. # / dictionaries which each dictionary contains the hyper-parameters of each weighting network
  17. struc_Ws = n_ws*[{'dropout': args.dropout_W, 'wd': 5e-4, 'lr': args.lr_W, 'nhid': structure_W}]
  18. ### stats variable keeps the statistics of the network on train, validation and test sets
  19. stats = dict()
  20. ### act is the function for normalization of weights of samples in each class
  21. act = my_sigmoid
  22. ### Definition of model
  23. model = RAGCN(adj=adj, features=features, nclass=nclass, struc_D=struc_D, struc_Ws=struc_Ws,n_ws=n_ws,
  24. weighing_output_dim=1, act=act, gamma=args.gamma)
  25. if use_cuda:
  26. model.cuda()
  27. ### Keeping the best stats based on the value of Macro F1 in validation set
  28. max_val = dict()
  29. max_val['f1Macro_val'] = 0
  30. for epoch in range(args.epochs):
  31. model.train()
  32. ### Train discriminator and weighting networks
  33. model.run_both(epoch_for_D=args.epoch_D, epoch_for_W=args.epoch_W, labels_one_hot=labels_one_hot[idx_train, :],
  34. samples=idx_train, args_cuda=use_cuda, equal_weights=False)
  35. model.eval()
  36. ### calculate stats for training set
  37. class_prob, embed = model.run_D(samples=idx_train)
  38. weights, _ = model.run_W(samples=idx_train, labels=labels[idx_train], args_cuda=use_cuda, equal_weights=False)
  39. stats['loss_train'] = model.loss_function_D(class_prob, labels_one_hot[idx_train], weights).item()
  40. stats['nll_train'] = F.nll_loss(class_prob, labels[idx_train]).item()
  41. stats['acc_train'] = accuracy(class_prob, labels=labels[idx_train]).item()
  42. stats['f1Macro_train'] = class_f1(class_prob, labels[idx_train], type='macro')
  43. if nclass == 2:
  44. stats['f1Binary_train'] = class_f1(class_prob, labels[idx_train], type='binary', pos_label=pos_label)
  45. stats['AUC_train'] = auc_score(class_prob, labels[idx_train])
  46. ### calculate stats for validation and test set
  47. test(model, stats)
  48. ### Drop first epochs and keep the best based on the macro F1 on validation set just for reporting
  49. if epoch > drop_epochs and max_val['f1Macro_val'] < stats['f1Macro_val']:
  50. for key, val in stats.items():
  51. max_val[key] = val
  52. ### Print stats in each epoch
  53. print('Epoch: {:04d}'.format(epoch + 1))
  54. print('acc_train: {:.4f}'.format(stats['acc_train']))
  55. print('f1_macro_train: {:.4f}'.format(stats['f1Macro_train']))
  56. print('loss_train: {:.4f}'.format(stats['loss_train']))
  57. ### Reporting the best results on test set
  58. print('========Results==========')
  59. for key, val in max_val.items():
  60. if 'loss' in key or 'nll' in key or 'test' not in key:
  61. continue
  62. print(key.replace('_', ' ') + ' : ' + str(val))
  63. ### Calculate metrics on validation and test sets
  64. def test(model, stats):
  65. model.eval()
  66. class_prob, embed = model.run_D(samples=idx_val)
  67. weights, _ = model.run_W(samples=idx_val, labels=labels[idx_val], args_cuda=use_cuda, equal_weights=True)
  68. stats['loss_val'] = model.loss_function_D(class_prob, labels_one_hot[idx_val], weights).item()
  69. stats['nll_val'] = F.nll_loss(class_prob, labels[idx_val]).item()
  70. stats['acc_val'] = accuracy(class_prob, labels[idx_val]).item()
  71. stats['f1Macro_val'] = class_f1(class_prob, labels[idx_val], type='macro')
  72. if nclass == 2:
  73. stats['f1Binary_val'] = class_f1(class_prob, labels[idx_val], type='binary', pos_label=pos_label)
  74. stats['AUC_val'] = auc_score(class_prob, labels[idx_val])
  75. class_prob, embed = model.run_D(samples=idx_test)
  76. weights, _ = model.run_W(samples=idx_test, labels=labels[idx_test], args_cuda=use_cuda, equal_weights=True)
  77. stats['loss_test'] = model.loss_function_D(class_prob, labels_one_hot[idx_test], weights).item()
  78. stats['nll_test'] = F.nll_loss(class_prob, labels[idx_test]).item()
  79. stats['acc_test'] = accuracy(class_prob, labels[idx_test]).item()
  80. stats['f1Macro_test'] = class_f1(class_prob, labels[idx_test], type='macro')
  81. if nclass == 2:
  82. stats['f1Binary_test'] = class_f1(class_prob, labels[idx_test], type='binary', pos_label=pos_label)
  83. stats['AUC_test'] = auc_score(class_prob, labels[idx_test])
  84. if __name__ == '__main__':
  85. parser = argparse.ArgumentParser()
  86. parser.add_argument('--epochs', type=int, default=1000,
  87. help='Number of epochs to train.')
  88. parser.add_argument('--epoch_D', type=int, default=1,
  89. help='Number of training loop for discriminator in each epoch.')
  90. parser.add_argument('--epoch_W', type=int, default=1,
  91. help='Number of training loop for discriminator in each epoch.')
  92. parser.add_argument('--lr_D', type=float, default=0.01,
  93. help='Learning rate for discriminator.')
  94. parser.add_argument('--lr_W', type=float, default=0.01,
  95. help='Equal learning rate for weighting networks.')
  96. parser.add_argument('--dropout_D', type=float, default=0.5,
  97. help='Dropout rate for discriminator.')
  98. parser.add_argument('--dropout_W', type=float, default=0.5,
  99. help='Dropout rate for weighting networks.')
  100. parser.add_argument('--gamma', type=float, default=1,
  101. help='Coefficient of entropy term in loss function.')
  102. parser.add_argument('--no-cuda', action='store_true', default=False,
  103. help='Disables CUDA training.')
  104. ### This list shows the number of hidden neurons in each hidden layer of discriminator
  105. structure_D = [2]
  106. ### This list shows the number of hidden neurons in each hidden layer of weighting networks
  107. structure_W = [4]
  108. ### The results of first drop_epochs will be dropped for choosing the best network based on the validation set
  109. drop_epochs = 500
  110. args = parser.parse_args()
  111. use_cuda = not args.no_cuda and torch.cuda.is_available()
  112. ### Loading function should return the following variables
  113. ### adj is a dictionary including 'D' and 'W' as keys
  114. ### adj['D'] contains the main adjacency matrix between all samples for discriminator
  115. ### adj['W'] contains a list of adjacency matrices. Element i contains the adjacency matrix between samples of /
  116. ### / class i in the training samples
  117. ### Features is a tensor with size N by F
  118. ### labels is a list of node labels
  119. ### idx train is a list contains the index of training samples. idx_val and idx_test follow the same pattern
  120. adj, features, labels, idx_train, idx_val, idx_test = load_data_medical(dataset_addr='../data/synthetic/per-90gt-0.5.pkl',
  121. train_ratio=0.6, test_ratio=0.2)
  122. ### start of code
  123. labels_one_hot = encode_onehot_torch(labels)
  124. nclass = labels_one_hot.shape[1]
  125. n_ws = nclass
  126. pos_label = None
  127. if nclass == 2:
  128. pos_label = 1
  129. zero_class = (labels == 0).sum()
  130. one_class = (labels == 1).sum()
  131. if zero_class < one_class:
  132. pos_label = 0
  133. if use_cuda:
  134. for key, val in adj.items():
  135. if type(val) is list:
  136. for i in range(len(adj)):
  137. adj[key][i] = adj[key][i].cuda()
  138. else:
  139. adj[key] = adj[key].cuda()
  140. features = features.cuda()
  141. labels_one_hot = labels_one_hot.cuda()
  142. labels = labels.cuda()
  143. ### Training the networks
  144. train()