
evaluation.py 5.5KB

from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix, f1_score, precision_score, \
    recall_score, accuracy_score

from utils import *

# Explicit imports for the names used directly in this module (in the
# repository these may also come in via the star import from utils).
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics
from scipy.stats import ttest_ind
class Evaluation:
    @staticmethod
    def plot_train_val_accuracy(train_accuracies, val_accuracies, num_epochs):
        """Plot training and validation accuracy per epoch."""
        plt.xlabel('epoch')
        plt.ylabel('accuracy')
        plt.title('Training vs. validation accuracy')
        plt.plot(range(1, num_epochs + 1), train_accuracies, label='train')
        plt.plot(range(1, num_epochs + 1), val_accuracies, label='validation')
        plt.legend()
        plt.show()

    @staticmethod
    def plot_train_val_loss(train_loss, val_loss, num_epochs):
        """Plot training and validation loss per epoch."""
        plt.xlabel('epoch')
        plt.ylabel('loss')
        plt.title('Training vs. validation loss')
        plt.plot(range(1, num_epochs + 1), train_loss, label='train')
        plt.plot(range(1, num_epochs + 1), val_loss, label='validation')
        plt.legend()
        plt.show()
    @staticmethod
    def evaluate(all_targets, mlp_output, show_plot=True):
        """Compute classification metrics from targets and model scores.

        Both arguments are expected to be (CPU) torch tensors, hence the
        `.numpy()` calls further down.
        """
        # Threshold the scores at 0.5 to obtain hard class labels.
        predicted_labels = np.where(mlp_output > 0.5, 1, 0)
        all_predictions = predicted_labels.reshape(-1)

        # ROC curve and AUC.
        fpr, tpr, thresholds = metrics.roc_curve(all_targets, mlp_output)
        auc = np.round(metrics.auc(fpr, tpr), 2)

        # Precision-recall curve and AUPRC.
        precision_curve, recall_curve, thresholds = metrics.precision_recall_curve(all_targets, mlp_output)
        auprc = np.round(metrics.auc(recall_curve, precision_curve), 2)
        # auprc = average_precision_score(all_targets, mlp_output)

        print('Accuracy: {:.2f}'.format(np.round(accuracy_score(all_targets, all_predictions), 2)))
        print('AUC: {:.2f}'.format(auc))
        print('AUPRC: {:.2f}'.format(auprc))

        # Confusion matrix and derived metrics; class 1 is treated as the positive class.
        cm = confusion_matrix(all_targets, all_predictions)
        tn, fp, fn, tp = cm.ravel()
        accuracy = cm.trace() / np.sum(cm)
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)  # named f1 to avoid shadowing the imported f1_score
        print('Confusion matrix:\n', cm, sep='')
        print(f'Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, F1 score: {f1:.3f}')
        if show_plot:
            # ROC curve.
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'ROC Curve: AUC={auc}')
            plt.plot(fpr, tpr)
            plt.show()

            # Precision-recall curve.
            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title(f'PR Curve: AUPRC={auprc}')
            plt.plot(recall_curve, precision_curve)
            plt.show()

            # Collect scores and targets into one data frame for the violin plot.
            res = pd.concat(
                [pd.DataFrame(mlp_output.numpy()), pd.DataFrame(all_targets.numpy())], axis=1,
                ignore_index=True)
            res.columns = ['Prediction', 'Target']
            prediction_targets = res

            # Split the predicted scores by target class.
            class_one = prediction_targets.loc[prediction_targets['Target'] == 0, 'Prediction'].astype(
                np.float32).tolist()
            class_minus_one = prediction_targets.loc[prediction_targets['Target'] == 1, 'Prediction'].astype(
                np.float32).tolist()

            # Violin plot of the score distributions for the two classes.
            fig, ax = plt.subplots()
            ax.set_ylabel("DeepDRA score")
            xticklabels = ['Responder', 'Non Responder']
            ax.set_xticks([1, 2])
            ax.set_xticklabels(xticklabels)
            data_to_plot = [class_minus_one, class_one]
            plt.ylim(0, 1)

            # Two-sample t-test between the two score distributions.
            p_value = ttest_ind(class_one, class_minus_one)[1]
            cancer = 'all'
            plt.title(
                f'Responder/Non responder scores for {cancer} cancer with \np-value ~= {p_value:.1e}')
            bp = ax.violinplot(data_to_plot, showextrema=True, showmeans=True, showmedians=True)
            bp['cmeans'].set_color('r')
            bp['cmedians'].set_color('g')
            plt.show()

        return {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1 score': f1,
                'AUC': auc, 'AUPRC': auprc}
    @staticmethod
    def add_results(result_list, current_result):
        """Append the metrics of one run to the accumulated result lists."""
        result_list['AUC'].append(current_result['AUC'])
        result_list['AUPRC'].append(current_result['AUPRC'])
        result_list['Accuracy'].append(current_result['Accuracy'])
        result_list['Precision'].append(current_result['Precision'])
        result_list['Recall'].append(current_result['Recall'])
        result_list['F1 score'].append(current_result['F1 score'])
        return result_list
    @staticmethod
    def show_final_results(result_list):
        """Print per-run metrics followed by summary statistics over all runs."""
        print("Final Results:")
        for i in range(len(result_list["AUC"])):
            accuracy = result_list['Accuracy'][i]
            precision = result_list['Precision'][i]
            recall = result_list['Recall'][i]
            f1 = result_list['F1 score'][i]
            auc = result_list['AUC'][i]
            auprc = result_list['AUPRC'][i]
            print(f'Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, '
                  f'F1 score: {f1:.3f}, AUC: {auc:.3f}, AUPRC: {auprc:.3f}')

        avg_auc = np.mean(result_list['AUC'])
        avg_auprc = np.mean(result_list['AUPRC'])
        std_auprc = np.std(result_list['AUPRC'])
        print("Average AUC: {:.3f} \t Average AUPRC: {:.3f} \t Std AUPRC: {:.3f}".format(avg_auc, avg_auprc, std_auprc))
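Below is a minimal usage sketch, not part of evaluation.py itself. It assumes the file is importable as `evaluation` and that model scores and targets arrive as CPU torch tensors, as the `.numpy()` calls above imply; the random tensors stand in for real model outputs and labels and are purely illustrative.

# Hypothetical usage of the Evaluation class across several runs/folds.
import torch
from evaluation import Evaluation

result_list = {'AUC': [], 'AUPRC': [], 'Accuracy': [], 'Precision': [], 'Recall': [], 'F1 score': []}

for fold in range(3):                       # e.g. one entry per cross-validation fold
    targets = torch.randint(0, 2, (100,))   # illustrative ground-truth labels (0/1)
    scores = torch.rand(100)                # illustrative model scores in [0, 1]
    fold_result = Evaluation.evaluate(targets, scores, show_plot=False)
    result_list = Evaluation.add_results(result_list, fold_result)

Evaluation.show_final_results(result_list)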