You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 11KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. from __future__ import division
  2. from __future__ import print_function
  3. import time
  4. import argparse
  5. import torch
  6. import torch.nn.functional as F
  7. import torch.optim as optim
  8. from sklearn.metrics import f1_score
  9. from utils import load_data, class_accuracy, class_f1, layer_accuracy, layer_f1, dict_to_writer, trains_vals_tests_split
  10. from models import CCN
  11. from tensorboardX import SummaryWriter
  12. # f1 name should be f1_micro or f1_macro
  13. # class_metrics = {'loss': F.nll_loss, 'acc': class_accuracy, 'f1_micro': class_f1, 'f1_macro': class_f1,
  14. # 'all_f1_micro': f1_score, 'all_f1_macro': f1_score}
  15. class_metrics = {'loss': F.nll_loss, 'all_f1_micro': f1_score, 'all_f1_macro': f1_score}
  16. layer_metrics = {'loss': torch.nn.BCEWithLogitsLoss()}
  17. # Preparing the output of classifier for every layer separately (class numbers should be different over layers)
  18. def prepare_output_labels(outputs, labels, idx):
  19. tmp_labels = []
  20. tmp_outputs = []
  21. max_label = 0
  22. for i in range(len(labels)):
  23. tmp_labels.extend(labels[i][idx[i]] + max_label)
  24. tmp_outputs.extend((outputs[i].max(1)[1].type_as(labels[i]) + max_label)[idx[i]])
  25. max_label = labels[i].max()
  26. return tmp_outputs, tmp_labels
  27. # Calculate classification metrics for every layer
  28. def model_stat_in_layer_class(in_classes_pred, in_labels, idx, metrics):
  29. stats = {}
  30. for i in range(len(in_classes_pred)):
  31. if not in_classes_pred[i] is None:
  32. for metr_name,metr_func in metrics.items():
  33. if not metr_name in stats:
  34. stats[metr_name] = []
  35. if metr_name[:2] == 'f1':
  36. try:
  37. stats[metr_name].append(metr_func(in_classes_pred[i][idx[i]], in_labels[i][idx[i]], metr_name.split('f1_')[-1]))
  38. except:
  39. print("Exception in F1 for class" + str(i) + " for classification")
  40. stats[metr_name].append(0)
  41. elif metr_name[:3] == 'all':
  42. continue
  43. else:
  44. stats[metr_name].append(metr_func(in_classes_pred[i][idx[i]], in_labels[i][idx[i]]))
  45. else:
  46. for metr_name, metr_func in metrics.items():
  47. if not metr_name in stats:
  48. stats[metr_name] = []
  49. stats[metr_name].append(0)
  50. tmp_in_classes_pred, tmp_labels = prepare_output_labels(in_classes_pred, in_labels, idx)
  51. for metr_name, metr_func in metrics.items():
  52. if metr_name[:3] == 'all':
  53. stats[metr_name] = metr_func(tmp_in_classes_pred, tmp_labels, average=metr_name.split('f1_')[-1])
  54. return stats
  55. # Calculate reconstruction metrics for within and between layer edges (based on the input)
  56. def model_stat_struc(preds, reals, metrics, pos_weights, norms, idx=None):
  57. stats = {}
  58. for i in range(len(preds)):
  59. if reals[i] is None:
  60. continue
  61. if 'sparse' in reals[i].type():
  62. curr_real = reals[i].to_dense()
  63. else:
  64. curr_real = reals[i]
  65. if idx is None:
  66. final_preds = preds[i]
  67. final_real = curr_real
  68. else:
  69. final_preds = preds[i][idx[i][0],idx[i][1]]
  70. final_real = curr_real[idx[i][0],idx[i][1]]
  71. for metr_name,metr_func in metrics.items():
  72. if not metr_name in stats:
  73. stats[metr_name] = []
  74. if metr_name == 'f1':
  75. try:
  76. stats[metr_name].append(metr_func(final_preds, final_real, type = metr_name.split('f1_')[-1]))
  77. except:
  78. print("Exception in F1 for layer" + str(i) + " for structure")
  79. stats[metr_name].append(0)
  80. elif metr_name == 'loss':
  81. if args.cuda:
  82. pw = pos_weights[i] * torch.ones(final_real.size()).cuda()
  83. else:
  84. pw = pos_weights[i] * torch.ones(final_real.size())
  85. stats[metr_name].append(norms[i]*torch.nn.BCEWithLogitsLoss(pos_weight=pw)(final_preds, final_real))
  86. else:
  87. try:
  88. stats[metr_name].append(metr_func(final_preds, final_real))
  89. except:
  90. print("Odd problem")
  91. return stats
  92. def train(epoch):
  93. t = time.time()
  94. model.train()
  95. optimizer.zero_grad()
  96. # Run model
  97. t1 = time.time()
  98. classes_pred, bet_layers_pred, in_layers_pred = model(features, adjs)
  99. # Final statistics
  100. stats = dict()
  101. stats['in_class'] = model_stat_in_layer_class(classes_pred, labels, idx_trains, class_metrics)
  102. # Reconstructing all the within and between layer edges without train and test split (for using part of the edges
  103. # for reconstruction, idx is needed)
  104. stats['in_struc'] = model_stat_struc(in_layers_pred, adjs_orig, layer_metrics, adjs_pos_weights, adjs_norms)
  105. stats['bet_struc'] = model_stat_struc(bet_layers_pred, bet_adjs_orig, layer_metrics, bet_pos_weights, bet_norms)
  106. # Write to writer
  107. for name, stat in stats.items():
  108. dict_to_writer(stat, writer, epoch, name, 'Train')
  109. # Calculate loss
  110. loss_train = 0
  111. for name, stat in stats.items():
  112. # Weighted loss
  113. if name == 'in_struc':
  114. loss_train += wlambda * sum([a * b for a, b in zip(stat['loss'], adj_weight)])
  115. elif name== 'bet_struc':
  116. loss_train += wlambda * sum([a * b for a, b in zip(stat['loss'], adj_weight)])
  117. else:
  118. loss_train += sum([a * b for a, b in zip(stat['loss'], adj_weight)])
  119. # Gradients
  120. loss_train.backward()
  121. optimizer.step()
  122. # Validation if needed
  123. if not args.fastmode:
  124. model.eval()
  125. classes_pred, bet_layers_pred, in_layers_pred = model(features, adjs)
  126. stats = dict()
  127. stats['in_class'] = model_stat_in_layer_class(classes_pred, labels, idx_vals, class_metrics)
  128. stats['in_struc'] = model_stat_struc(in_layers_pred, adjs_orig, layer_metrics, adjs_pos_weights, adjs_norms)
  129. stats['bet_struc'] = model_stat_struc(bet_layers_pred, bet_adjs_orig, layer_metrics, bet_pos_weights, bet_norms)
  130. for name, stat in stats.items():
  131. dict_to_writer(stat, writer, epoch, name, 'Validation')
  132. loss_val = sum([sum(stat['loss']) for stat in stats.values()])
  133. print('Epoch: {:04d}'.format(epoch+1),
  134. 'loss_train: {:.4f}'.format(loss_train.item()),
  135. 'time: {:.4f}s'.format(time.time() - t))
  136. def test():
  137. # Test
  138. model.eval()
  139. classes_pred, bet_layers_pred, in_layers_pred = model(features, adjs)
  140. stats = dict()
  141. stats['in_class'] = model_stat_in_layer_class(classes_pred, labels, idx_tests, class_metrics)
  142. for name, stat in stats.items():
  143. dict_to_writer(stat, writer, epoch, name, 'Test')
  144. loss_test = sum([sum(stat['loss']) for stat in stats.values()])
  145. print("Test set results:",
  146. "loss= {:.4f}".format(loss_test.item()))
  147. if __name__=='__main__':
  148. # Training settings
  149. parser = argparse.ArgumentParser()
  150. parser.add_argument('--no-cuda', action='store_true', default=True,
  151. help='Disables CUDA training.')
  152. parser.add_argument('--fastmode', action='store_true', default=False,
  153. help='Validate during training pass.')
  154. parser.add_argument('--seed', type=int, default=42, help='Random seed.')
  155. parser.add_argument('--epochs', type=int, default=20,
  156. help='Number of epochs to train.')
  157. parser.add_argument('--weight_decay', type=float, default=5e-4,
  158. help='Weight decay (L2 loss on parameters).')
  159. parser.add_argument('--dropout', type=float, default=0.5,
  160. help='Dropout rate (1 - keep probability).')
  161. args = parser.parse_args()
  162. args.cuda = not args.no_cuda and torch.cuda.is_available()
  163. dataset_str = "infra"
  164. # parameter
  165. # All combination of parameters will be tested
  166. # adj_weights contains a list of different configs, for example [[1,1,1]] contains one config with equal weights
  167. # for a three-layered input
  168. adj_weights = []
  169. # parameter, weight of link reconstruction loss function
  170. wlambdas = [10]
  171. # parameter, hidden dimension for every layer should be defined
  172. hidden_structures = [[[32],[32],[32]]]
  173. # hidden_structures = [[[16], [16], [16]],[[32], [32], [32]],[[64], [64], [64]],[[128], [128], [128]],[[256], [256], [256]]]
  174. # Learning rate
  175. lrs = [0.01]
  176. # test size
  177. test_sizes = [0.2]
  178. # Load data
  179. adjs, adjs_orig, adjs_sizes, adjs_pos_weights, adjs_norms, bet_pos_weights, bet_norms, bet_adjs, bet_adjs_orig, bet_adjs_sizes, \
  180. features, features_sizes, labels, labels_nclass = load_data(path="../data/" + dataset_str + "/", dataset=dataset_str)
  181. # Number of layers
  182. n_inputs = len(adjs)
  183. # Weights of layers
  184. adj_weights.append(n_inputs * [1])
  185. if args.cuda:
  186. for i in range(n_inputs):
  187. labels[i] = labels[i].cuda()
  188. adjs_pos_weights[i] = adjs_pos_weights[i].cuda()
  189. adjs[i] = adjs[i].cuda()
  190. adjs_orig[i] = adjs_orig[i].cuda()
  191. features[i] = features[i].cuda()
  192. for i in range(len(bet_adjs)):
  193. if not bet_adjs[i] is None:
  194. bet_adjs[i] = bet_adjs[i].cuda()
  195. bet_adjs_orig[i] = bet_adjs_orig[i].cuda()
  196. bet_pos_weights[i] = bet_pos_weights[i].cuda()
  197. # Number of runs
  198. for run in range(10):
  199. for ts in test_sizes:
  200. idx_trains, idx_vals, idx_tests = trains_vals_tests_split(n_inputs, [s[0] for s in adjs_sizes], val_size=0.1,
  201. test_size=ts, random_state=int(run + ts*100))
  202. if args.cuda:
  203. for i in range(n_inputs):
  204. idx_trains[i] = idx_trains[i].cuda()
  205. idx_vals[i] = idx_vals[i].cuda()
  206. idx_tests[i] = idx_tests[i].cuda()
  207. # Train model
  208. t_total = time.time()
  209. for wlambda in wlambdas:
  210. for adj_weight in adj_weights:
  211. for lr in lrs:
  212. for hidden_structure in hidden_structures:
  213. temp_weight = ['{:.2f}'.format(x) for x in adj_weight]
  214. w_str = '-'.join(temp_weight)
  215. h_str = '-'.join([','.join(map(str, temp)) for temp in hidden_structure])
  216. writer = SummaryWriter('log/' + dataset_str + '_run-'+ str(run)+ '/' + '_lambda-' +
  217. str(wlambda) + "_adj_w-" + w_str + "_LR-" + str(lr) +
  218. '_hStruc-' + h_str + '_test_size-' + str(ts))
  219. model = CCN(n_inputs=n_inputs,
  220. inputs_nfeat=features_sizes,
  221. inputs_nhid=hidden_structure,
  222. inputs_nclass=labels_nclass,
  223. dropout=args.dropout)
  224. optimizer = optim.Adam(model.parameters(),
  225. lr=lr, weight_decay=args.weight_decay)
  226. if args.cuda:
  227. model.cuda()
  228. for epoch in range(args.epochs):
  229. train(epoch)
  230. # Testing
  231. test()
  232. print("Optimization Finished!")
  233. print("Total time elapsed: {:.4f}s".format(time.time() - t_total))