"""Compute F1 metrics and pick the best epoch from pickled result tables."""

import os
import pickle

import torch

# Directory scanned by the ``__main__`` entry point for pickled result tables.
RESULTS_DIR = "results_edge_classification_bitcoin_otc/"

# Columns of each result row that are averaged to rank epochs. Judging by the
# names used for them below they hold one F1 score per data split -- TODO
# confirm against the code that writes these pickles.
F1_COLUMNS = (2, 4, 6)

# Number of columns every pickled result table must have.
EXPECTED_NUM_COLUMNS = 12


def f1_score(y_pred, y_true):
    """Return ``(f1, precision, recall)``, treating label 0 as the positive class.

    A prediction counts as positive when it equals 0, and so does a target.
    All counts are accumulated in float64.

    Args:
        y_pred: Tensor of predicted labels.
        y_true: Tensor of true labels, comparable element-wise with ``y_pred``.

    Returns:
        Three 0-d float64 tensors ``(f1, precision, recall)``.

        - When ``tp + fp == 0``, precision is NaN (0/0).
        - When ``tp + fn == 0``, recall is NaN (0/0).
        - F1 is NaN whenever precision or recall is NaN.
        - When precision and recall are both exactly 0, F1 is 0 rather than
          the NaN produced by the naive harmonic-mean formula.
    """
    tp = torch.sum((y_pred == 0) & (y_true == 0), dtype=torch.float64)
    fp = torch.sum((y_pred == 0) & (y_true != 0), dtype=torch.float64)
    fn = torch.sum((y_pred != 0) & (y_true == 0), dtype=torch.float64)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    denom = precision + recall
    # p == r == 0 would otherwise give 0/0 = NaN; the limit of F1 there is 0.
    # NaN inputs still propagate because ``NaN == 0`` is False.
    f1 = torch.where(
        denom == 0,
        torch.zeros_like(denom),
        2 * (precision * recall) / denom,
    )
    return f1, precision, recall


def find_best_epoch(results_dir=RESULTS_DIR):
    """Scan pickled result tables and return the row with the best mean score.

    Every file in ``results_dir`` is unpickled and must be a 2-D array-like
    with ``EXPECTED_NUM_COLUMNS`` columns. Each row is an "epoch". The row's
    score is the mean of the values in ``F1_COLUMNS``.

    Rows in which any of those values is exactly 1 are skipped. The reason is
    not visible here; presumably 1 marks a degenerate or sentinel entry --
    TODO confirm.

    Files are visited in sorted order, and only a strictly greater mean
    replaces the current best. Ties therefore go to the earliest file name
    and then the earliest row.

    Args:
        results_dir: Directory containing the pickled result tables.

    Returns:
        ``(fname, epoch, each_split, max_f1)``. If no row beats the initial
        threshold of 0, this is ``(None, None, [], 0)``.

    Raises:
        ValueError: If a loaded table is not 2-D with the expected column
            count.
    """
    max_f1 = 0
    best_fname = None
    best_epoch = None
    best_split = []
    # os.listdir order is arbitrary; sorting makes tie-breaking reproducible.
    for bin_fname in sorted(os.listdir(results_dir)):
        # NOTE(review): pickle.load can execute arbitrary code -- only point
        # this at result files produced by a trusted run.
        with open(os.path.join(results_dir, bin_fname), "rb") as bin_file:
            data = pickle.load(bin_file)
        if data.ndim != 2 or data.shape[1] != EXPECTED_NUM_COLUMNS:
            raise ValueError(
                f"{bin_fname}: expected a 2-D table with "
                f"{EXPECTED_NUM_COLUMNS} columns, got shape {tuple(data.shape)}"
            )
        for epoch, row in enumerate(data):
            scores = [row[col] for col in F1_COLUMNS]
            if any(score == 1 for score in scores):
                continue
            mean_f1 = sum(scores) / len(scores)
            if mean_f1 > max_f1:
                max_f1 = mean_f1
                best_fname = bin_fname
                best_epoch = epoch
                best_split = scores
    return best_fname, best_epoch, best_split, max_f1


if __name__ == "__main__":
    fname, epoch, each_split, max_f1 = find_best_epoch()
    print(fname, epoch, each_split, str(max_f1))