|
123456789101112131415161718192021222324252627282930313233343536373839 |
- import torch
- import pickle
- import os
-
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, using 0 where the denominator is 0."""
    return torch.where(
        denominator > 0,
        numerator / denominator,
        torch.zeros_like(numerator),
    )


def f1_score(y_pred, y_true):
    """Compute F1, precision and recall, treating label ``0`` as the positive class.

    Args:
        y_pred: Tensor of predicted labels.
        y_true: Tensor of ground-truth labels, comparable element-wise
            with ``y_pred``.

    Returns:
        A tuple ``(f1, precision, recall)`` of scalar ``float64`` tensors.
        A ratio whose denominator is zero (no positive predictions, no
        positive labels, or no true positives) is reported as 0 rather
        than NaN.
    """
    # Label 0 is the positive class: every count below is taken w.r.t. `== 0`.
    pred_pos = y_pred == 0
    true_pos = y_true == 0

    tp = torch.sum(pred_pos & true_pos, dtype=torch.float64)
    fp = torch.sum(pred_pos & ~true_pos, dtype=torch.float64)
    fn = torch.sum(~pred_pos & true_pos, dtype=torch.float64)

    # Plain division would yield NaN (0 / 0) in the degenerate cases, and a
    # NaN silently breaks any later averaging or comparison of the scores.
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1 = _safe_ratio(2 * (precision * recall), precision + recall)

    return f1, precision, recall
-
def find_best_epoch(results_dir='results_edge_classification_bitcoin_otc/'):
    """Scan pickled result tables and pick the row with the highest mean score.

    Every entry of ``results_dir`` is unpickled and must be a 2-D table with
    12 columns. For each row, columns 2, 4 and 6 are averaged; the row with
    the largest average over all files wins.
    NOTE(review): columns 2/4/6 are presumably per-split F1 scores and the
    row index an epoch number -- confirm against the code that writes them.

    Args:
        results_dir: Directory containing the pickled result files.

    Returns:
        A tuple ``(fname, epoch, each_split, max_f1)``: the winning file
        name, the winning row index, the three column values of that row,
        and their mean. If no row qualifies the result is
        ``(None, None, [], 0)``.

    Raises:
        ValueError: If a loaded table does not have exactly 12 columns.
    """
    max_f1 = 0
    fname = None
    epoch = None
    each_split = []

    for bin_fname in os.listdir(results_dir):
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # at result files produced by trusted runs.
        # The context manager closes the handle even if unpickling fails.
        with open(os.path.join(results_dir, bin_fname), "rb") as bin_file:
            data = pickle.load(bin_file)

        # Explicit check instead of `assert`, which is stripped under -O.
        if data.shape[1] != 12:
            raise ValueError(
                f"{bin_fname}: expected 12 columns, got {data.shape[1]}"
            )

        # 2-D indexing is kept as in the original because the container
        # type of the pickled table is not known here.
        for i in range(len(data)):
            scores = [data[i, 2], data[i, 4], data[i, 6]]
            # Rows with any score exactly equal to 1 are ignored
            # (presumably treated as degenerate results -- confirm).
            if any(score == 1 for score in scores):
                continue
            mean_f1 = (scores[0] + scores[1] + scores[2]) / 3
            # Strict comparison: on a tie the earliest row seen is kept.
            if mean_f1 > max_f1:
                max_f1 = mean_f1
                fname = bin_fname
                epoch = i
                each_split = scores

    return fname, epoch, each_split, max_f1


if __name__ == "__main__":
    fname, epoch, each_split, max_f1 = find_best_epoch(
        'results_edge_classification_bitcoin_otc/'
    )
    print(fname, epoch, each_split, str(max_f1))
-
-
|