PyTorch implementation of Dynamic Graph Convolutional Network
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py (1.2 KB, 39 lines)
  1. import torch
  2. import pickle
  3. import os
  4. def f1_score(y_pred, y_true):
  5. tp = torch.sum((y_pred == 0) & (y_true == 0), dtype=torch.float64)
  6. fp = torch.sum((y_pred == 0) & (y_true != 0), dtype=torch.float64)
  7. fn = torch.sum((y_pred != 0) & (y_true == 0), dtype=torch.float64)
  8. precision = tp / (tp + fp)
  9. recall = tp / (tp + fn)
  10. f1 = 2 * (precision * recall) / (precision + recall)
  11. return f1, precision, recall
  12. if __name__ == "__main__":
  13. max_f1 = 0
  14. fname = None
  15. epoch = None
  16. each_split = []
  17. for bin_fname in os.listdir('results_edge_classification_bitcoin_otc/'):
  18. bin_file = open("results_edge_classification_bitcoin_otc/" + bin_fname, "rb")
  19. data = pickle.load(bin_file)
  20. assert data.shape[1] == 12
  21. for i in range(len(data)):
  22. if data[i,2] == 1 or data[i,4] == 1 or data[i,6] == 1:
  23. continue
  24. mean_f1 = (data[i, 2] + data[i, 4] + data[i, 6] ) / 3
  25. if mean_f1 > max_f1:
  26. max_f1 = mean_f1
  27. fname = bin_fname
  28. epoch = i
  29. each_split = [data[i, 2], data[i, 4], data[i, 6]]
  30. print(fname, epoch, each_split, str(max_f1))