
evaluation.py

from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score)
from sklearn import metrics
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from statsmodels.stats.weightstats import ttest_ind


class Evaluation:
    @staticmethod
    def plot_train_val_accuracy(train_accuracies, val_accuracies, num_epochs):
        """
        Plot training and validation accuracies over epochs.

        Parameters:
        - train_accuracies (list): List of training accuracies.
        - val_accuracies (list): List of validation accuracies.
        - num_epochs (int): Number of training epochs.

        Returns:
        - None
        """
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Validation Accuracies')
        plt.plot(range(1, num_epochs + 1), train_accuracies, label='Train Accuracy')
        plt.plot(range(1, num_epochs + 1), val_accuracies, label='Validation Accuracy')
        plt.legend()
        plt.show()

    @staticmethod
    def plot_train_val_loss(train_loss, val_loss, num_epochs):
        """
        Plot training and validation losses over epochs.

        Parameters:
        - train_loss (list): List of training losses.
        - val_loss (list): List of validation losses.
        - num_epochs (int): Number of training epochs.

        Returns:
        - None
        """
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Losses')
        plt.plot(range(1, num_epochs + 1), train_loss, label='Train Loss')
        plt.plot(range(1, num_epochs + 1), val_loss, label='Validation Loss')
        plt.legend()
        plt.show()
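
    # Hypothetical usage of the two plotting helpers above, assuming per-epoch
    # history lists collected during training (names are illustrative):
    #     Evaluation.plot_train_val_accuracy(train_accs, val_accs, num_epochs=50)
    #     Evaluation.plot_train_val_loss(train_losses, val_losses, num_epochs=50)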

    @staticmethod
    def evaluate(all_targets, mlp_output, show_plot=False):
        """
        Evaluate model performance based on predictions and targets.

        Parameters:
        - all_targets (torch.Tensor): True binary target labels.
        - mlp_output (torch.Tensor): Predicted probabilities.
        - show_plot (bool): Whether to display ROC, PR, and violin plots.

        Returns:
        - results (dict): Dictionary containing evaluation metrics.
        """
        # Step 1: Move tensors to CPU and threshold probabilities into binary labels
        mlp_output = mlp_output.cpu()
        all_targets = all_targets.cpu()
        all_predictions = np.where(mlp_output > 0.5, 1, 0).reshape(-1)

        # Step 2: Calculate AUC from the ROC curve
        fpr, tpr, _ = metrics.roc_curve(all_targets, mlp_output)
        auc = np.round(metrics.auc(fpr, tpr), 4)

        # Step 3: Calculate AUPRC from the precision-recall curve
        # (kept as arrays here so the PR curve can be plotted in Step 5)
        precisions, recalls, _ = metrics.precision_recall_curve(all_targets, mlp_output)
        auprc = np.round(metrics.auc(recalls, precisions), 4)

        # Step 4: Compute accuracy, precision, recall, F1 (for the positive
        # class), and the confusion matrix
        accuracy = accuracy_score(all_targets, all_predictions)
        cm = confusion_matrix(all_targets, all_predictions)
        precision = precision_score(all_targets, all_predictions)
        recall = recall_score(all_targets, all_predictions)
        f1 = f1_score(all_targets, all_predictions)
        print(f'Confusion matrix:\n{cm}')
        print(f'Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, '
              f'Recall: {recall:.3f}, F1 score: {f1:.3f}, '
              f'AUC: {auc:.3f}, AUPRC: {auprc:.3f}')

        # Step 5: Display ROC and PR curves if requested
        if show_plot:
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'ROC Curve: AUC={auc}')
            plt.plot(fpr, tpr)
            plt.show()

            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title(f'PR Curve: AUPRC={auprc}')
            plt.plot(recalls, precisions)
            plt.show()

            # Violin plot of DeepDRA scores, grouped by true label
            prediction_targets = pd.concat(
                [pd.DataFrame(mlp_output.numpy()), pd.DataFrame(all_targets.numpy())],
                axis=1, ignore_index=True)
            prediction_targets.columns = ['Prediction', 'Target']
            class_one = prediction_targets.loc[prediction_targets['Target'] == 0, 'Prediction']
            class_minus_one = prediction_targets.loc[prediction_targets['Target'] == 1, 'Prediction']

            fig, ax = plt.subplots()
            ax.set_ylabel("DeepDRA score")
            ax.set_xticks([1, 2])
            ax.set_xticklabels(['Responder', 'Non Responder'])
            data_to_plot = [class_minus_one, class_one]
            plt.ylim(0, 1)
            p_value = ttest_ind(class_one, class_minus_one)[1]
            cancer = 'all'
            plt.title(f'Responder/Non-responder scores for {cancer} cancer with '
                      f'\np-value ~= {p_value:.0e}')
            bp = ax.violinplot(data_to_plot, showextrema=True, showmeans=True, showmedians=True)
            bp['cmeans'].set_color('r')
            bp['cmedians'].set_color('g')
            plt.show()

        # Step 6: Return evaluation metrics in a dictionary
        return {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall,
                'F1 score': f1, 'AUC': auc, 'AUPRC': auprc}
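
    # Hypothetical example: for torch tensors `targets` (0/1 labels) and
    # `probs` (sigmoid outputs) collected over a validation set,
    #     results = Evaluation.evaluate(targets, probs, show_plot=True)
    # prints the metrics and shows the ROC, PR, and violin plots.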

    @staticmethod
    def add_results(result_list, current_result):
        result_list['AUC'].append(current_result['AUC'])
        result_list['AUPRC'].append(current_result['AUPRC'])
        result_list['Accuracy'].append(current_result['Accuracy'])
        result_list['Precision'].append(current_result['Precision'])
        result_list['Recall'].append(current_result['Recall'])
        result_list['F1 score'].append(current_result['F1 score'])
        return result_list

    @staticmethod
    def show_final_results(result_list):
        print("Final Results:")
        for i in range(len(result_list["AUC"])):
            accuracy = result_list['Accuracy'][i]
            precision = result_list['Precision'][i]
            recall = result_list['Recall'][i]
            f1 = result_list['F1 score'][i]
            auc = result_list['AUC'][i]
            auprc = result_list['AUPRC'][i]
            print(f'Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, '
                  f'Recall: {recall:.3f}, F1 score: {f1:.3f}, '
                  f'AUC: {auc:.3f}, AUPRC: {auprc:.3f}')

        avg_auc = np.mean(result_list['AUC'])
        avg_auprc = np.mean(result_list['AUPRC'])
        std_auprc = np.std(result_list['AUPRC'])
        avg_accuracy = np.mean(result_list['Accuracy'])
        avg_precision = np.mean(result_list['Precision'])
        avg_recall = np.mean(result_list['Recall'])
        avg_f1score = np.mean(result_list['F1 score'])
        print(f'AVG: Accuracy: {avg_accuracy:.3f}, Precision: {avg_precision:.3f}, '
              f'Recall: {avg_recall:.3f}, F1 score: {avg_f1score:.3f}, '
              f'AUC: {avg_auc:.3f}, AUPRC: {avg_auprc:.3f}')
        print('Average AUC: {:.3f} \t Average AUPRC: {:.3f} \t Std AUPRC: {:.3f}'.format(
            avg_auc, avg_auprc, std_auprc))
        return {'Accuracy': avg_accuracy, 'Precision': avg_precision, 'Recall': avg_recall,
                'F1 score': avg_f1score, 'AUC': avg_auc, 'AUPRC': avg_auprc}
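

# Minimal usage sketch (assumptions: binary 0/1 targets and sigmoid-style
# probabilities held in torch tensors; the mock random data below stands in
# for real model outputs and fold splits).
if __name__ == '__main__':
    import torch

    result_list = {'Accuracy': [], 'Precision': [], 'Recall': [],
                   'F1 score': [], 'AUC': [], 'AUPRC': []}
    for fold in range(3):
        targets = torch.randint(0, 2, (100,)).float()  # mock true labels
        probs = torch.rand(100)                        # mock predicted probabilities
        fold_result = Evaluation.evaluate(targets, probs, show_plot=False)
        result_list = Evaluation.add_results(result_list, fold_result)
    Evaluation.show_final_results(result_list)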