PyTorch implementation of Dynamic Graph Convolutional Network
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

train.py 5.4KB

  1. import os
  2. import argparse
  3. import torch
  4. import torch.nn as nn
  5. import numpy as np
  6. import pickle
  7. from wd_gc import WD_GCN
  8. from data_loader import DataLoader
  9. from utils import f1_score
  10. parser = argparse.ArgumentParser()
  11. # directories
  12. parser.add_argument('--data_root', type=str, default='data/', help='path to the root of data directory')
  13. parser.add_argument('--dataset', type=str, default='bitcoin_alpha', help='dataset name')
  14. parser.add_argument('--dataset_mat_file_path', type=str, default='Bitcoin_Alpha/saved_content_bitcoin_alpha.mat', help='dataset mat file path')
  15. parser.add_argument('--save_dir', type=str, default='results_edge_classification_bitcoin_alpha/', help='path to save directory')
  16. # model params
  17. parser.add_argument('--hidden_feat', type=list, default=[6, 2], help='hidden feature sizes')
  18. parser.add_argument('--time_partitioning', type=list, help='time partitioning for train/test/val split')
  19. # training params
  20. parser.add_argument('--epochs', type=int, default=1001, help='number of epochs')
  21. parser.add_argument('--optimizer', type=str, default='SGD', help='optimizer')
  22. parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate')
  23. parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
  24. parser.add_argument('--alpha_vec', type=list,
  25. default=[.75, .76, .77, .78, .79, .80, .81, .82, .83, .84, .85,
  26. .86, .87, .88, .89, .90, .91, .92, .93, .94, .95], help='alpha rate for weighted cross-entropy')
  27. args = parser.parse_args()
  28. print('__pyTorch VERSION:', torch.__version__)
  29. print('# of available devices ', torch.cuda.device_count())
  30. print('Is cuda GPU available? ', torch.cuda.is_available())
  31. print('========================')
  32. if torch.cuda.is_available():
  33. dev = "cuda:1"
  34. else:
  35. dev = "cpu"
  36. device = torch.device(dev)
  37. torch.cuda.set_device(device)
  38. print('Current cuda device', torch.cuda.current_device())
  39. # Make the save directory
  40. if not os.path.exists(args.save_dir):
  41. os.makedirs(args.save_dir)
  42. time_slices = args.time_partitioning
  43. S_train, S_val, S_test = time_slices[0], time_slices[1], time_slices[2]
  44. # Create data loader
  45. data_loader = DataLoader(S_train=S_train, S_val=S_val, S_test=S_test,
  46. data_dir=args.data_root,
  47. mat_file_path=args.dataset_mat_file_path)
  48. # load data
  49. data, C, targets, edges = data_loader.load_data()
  50. # train
  51. print("Training Started...")
  52. for alpha in args.alpha_vec:
  53. print(">> alpha = {}".format(alpha))
  54. class_weights = torch.tensor([alpha, 1.0 - alpha]).to(device)
  55. save_res_fname = "results_WDGCN" + "_w" + str(round(float(class_weights[0]*100))) + "_" + args.dataset
  56. # model definition
  57. gcn = WD_GCN(C['C_train'], data['X_train'].to(device), edges["edges_train"].to(device), args.hidden_feat)
  58. gcn.to(device)
  59. if args.optimizer == "SGD":
  60. optimizer = torch.optim.SGD(gcn.parameters(), lr=args.learning_rate, momentum=args.momentum)
  61. criterion = nn.CrossEntropyLoss(weight=class_weights)
  62. ep_acc_loss = np.zeros((args.epochs, 12)) # (precision_train, recall_train, f1_train, loss_train, precision_val, recall_val, f1_val, loss_val, precision_test, recall_test, f1_test, loss_test)
  63. for ep in range(args.epochs):
  64. # compute loss and take step
  65. optimizer.zero_grad()
  66. output_train = gcn(C['C_train'], data['X_train'].to(device), edges["edges_train"].to(device)) # forward passing
  67. loss_train = criterion(output_train, targets['target_train'].to(device))
  68. loss_train.backward() # backward propagation
  69. optimizer.step()
  70. with torch.no_grad():
  71. prediction_train = torch.argmax(output_train, dim=1)
  72. f1_train, precision_train, recall_train = f1_score(prediction_train, targets['target_train'].to(device))
  73. if ep % 100 == 0:
  74. # validation
  75. output_val = gcn(C['C_val'], data['X_val'].to(device), edges["edges_val"].to(device))
  76. prediction_val = torch.argmax(output_val, dim=1)
  77. f1_val, precision_val, recall_val = f1_score(prediction_val, targets['target_val'].to(device))
  78. loss_val = criterion(output_val, targets['target_val'].to(device))
  79. # test
  80. output_test = gcn(C['C_test'], data['X_test'].to(device), edges["edges_test"].to(device))
  81. prediction_test = torch.argmax(output_test, dim=1)
  82. f1_test, precision_test, recall_test = f1_score(prediction_test, targets['target_test'].to(device))
  83. loss_test = criterion(output_test, targets['target_test'].to(device))
  84. # log
  85. print("Epoch %d:" % ep)
  86. print("%d/%d: Train precision/recall/f1 %.4f/ %.4f/ %.4f. Train loss %.4f." % (ep, args.epochs, precision_train, recall_train, f1_train, loss_train))
  87. print("%d/%d: Val precision/recall/f1 %.4f/ %.4f/ %.4f. Val loss %.4f." % (ep, args.epochs, precision_val, recall_val, f1_val, loss_val))
  88. print("%d/%d: Test precision/recall/f1 %.4f/ %.4f/ %.4f. Test loss %.4f." % (ep, args.epochs, precision_test, recall_test, f1_test, loss_test))
  89. ep_acc_loss[ep] = [precision_train, recall_train, f1_train, loss_train,
  90. precision_val, recall_val, f1_val, loss_val,
  91. precision_test, recall_test, f1_test, loss_test]
  92. pickle.dump(ep_acc_loss, open(args.save_dir + save_res_fname, "wb"))