You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 1.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc
  2. from matplotlib import pyplot
  3. def AUROC(scores, labels):
  4. scores = scores.cpu().numpy()
  5. labels = labels.cpu().numpy()
  6. ns_probs = [0 for _ in range(len(labels))]
  7. lr_auc = roc_auc_score(labels, scores)
  8. write_to_out('AUROC: %.3f \n' % (lr_auc))
  9. ns_fpr, ns_tpr, _ = roc_curve(labels, ns_probs)
  10. lr_fpr, lr_tpr, _ = roc_curve(labels, scores)
  11. pyplot.plot(ns_fpr, ns_tpr, linestyle='--', label='No Skill')
  12. pyplot.plot(lr_fpr, lr_tpr, label='Logistic')
  13. pyplot.xlabel('False Positive Rate')
  14. pyplot.ylabel('True Positive Rate')
  15. pyplot.legend()
  16. pyplot.savefig('../out/AUROC', dpi=180)
  17. pyplot.show()
  18. def AUPR(scores, labels):
  19. scores = scores.cpu().numpy()
  20. labels = labels.cpu().numpy()
  21. lr_precision, lr_recall, _ = precision_recall_curve(labels, scores)
  22. lr_auc = auc(lr_recall, lr_precision)
  23. write_to_out('AUPR: %.3f \n' % (lr_auc))
  24. no_skill = len(labels[labels==1]) / len(labels)
  25. pyplot.plot([0, 1], [no_skill, no_skill], linestyle='--', label='No Skill')
  26. pyplot.plot(lr_recall, lr_precision, label='HGT')
  27. pyplot.xlabel('Recall')
  28. pyplot.ylabel('Precision')
  29. pyplot.legend()
  30. pyplot.savefig('../out/AUPR', dpi=180)
  31. pyplot.show()
  32. def plot_losses(losses, val_losses):
  33. pyplot.plot(range(len(losses)), losses, label="loss")
  34. pyplot.plot(range(len(losses)), val_losses, label="val_loss")
  35. pyplot.legend()
  36. pyplot.savefig('../out/losses', dpi=200)
  37. pyplot.clf()
  38. def write_to_out(text):
  39. print(text)
  40. out_file = open('../out/out.txt', 'a')
  41. out_file.write(text)
  42. out_file.close()