
evaluation.py 7.3KB

from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix, f1_score, precision_score, \
    recall_score, accuracy_score
from matplotlib import pyplot as plt
from sklearn import metrics
import numpy as np
import pandas as pd
from statsmodels.stats.weightstats import ttest_ind


class Evaluation:
    @staticmethod
    def plot_train_val_accuracy(train_accuracies, val_accuracies, num_epochs):
        """
        Plot training and validation accuracies over epochs.

        Parameters:
        - train_accuracies (list): List of training accuracies.
        - val_accuracies (list): List of validation accuracies.
        - num_epochs (int): Number of training epochs.

        Returns:
        - None
        """
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Validation Accuracies')
        plt.plot(range(1, num_epochs + 1), train_accuracies, label='Train Accuracy')
        plt.plot(range(1, num_epochs + 1), val_accuracies, label='Validation Accuracy')
        plt.legend()
        plt.show()

    @staticmethod
    def plot_train_val_loss(train_loss, val_loss, num_epochs):
        """
        Plot training and validation losses over epochs.

        Parameters:
        - train_loss (list): List of training losses.
        - val_loss (list): List of validation losses.
        - num_epochs (int): Number of training epochs.

        Returns:
        - None
        """
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Losses')
        plt.plot(range(1, num_epochs + 1), train_loss, label='Train Loss')
        plt.plot(range(1, num_epochs + 1), val_loss, label='Validation Loss')
        plt.legend()
        plt.show()

    @staticmethod
    def evaluate(all_targets, mlp_output, show_plot=False):
        """
        Evaluate model performance based on predictions and targets.

        Parameters:
        - all_targets (torch.Tensor): True binary target labels.
        - mlp_output (torch.Tensor): Predicted probabilities.
        - show_plot (bool): Whether to display ROC, PR, and score-distribution plots.

        Returns:
        - results (dict): Dictionary containing evaluation metrics.
        """
        # Step 1: Convert predicted probabilities to binary labels
        mlp_output = mlp_output.cpu()
        all_targets = all_targets.cpu()
        predicted_labels = np.where(mlp_output > 0.5, 1, 0)
        predicted_labels = predicted_labels.reshape(-1)
        all_predictions = predicted_labels
        # Step 2: Compute the ROC curve and AUC
        fpr, tpr, _ = metrics.roc_curve(all_targets, mlp_output)
        auc = np.round(metrics.auc(fpr, tpr), 3)

        # Step 3: Compute the precision-recall curve and AUPRC
        precision_curve, recall_curve, _ = metrics.precision_recall_curve(all_targets, mlp_output)
        auprc = np.round(metrics.auc(recall_curve, precision_curve), 3)

        # Step 4: Compute accuracy, confusion matrix, precision, recall, and F1 score
        accuracy = accuracy_score(all_targets, all_predictions)
        cm = confusion_matrix(all_targets, all_predictions)
        precision = precision_score(all_targets, all_predictions)
        recall = recall_score(all_targets, all_predictions)
        f1 = f1_score(all_targets, all_predictions)
        print(f'Accuracy: {accuracy:.2f}')
        print(f'AUC: {auc:.2f}')
        print(f'AUPRC: {auprc:.2f}')
        print(f'Confusion matrix:\n{cm}')
        print(f'Precision: {precision:.3f}, Recall: {recall:.3f}, F1 score: {f1:.3f}')
        # Step 5: Display ROC, PR, and score-distribution plots if requested
        if show_plot:
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'ROC Curve: AUC={auc}')
            plt.plot(fpr, tpr)
            plt.show()

            plt.xlabel('Recall')
            plt.ylabel('Precision')
            plt.title(f'PR Curve: AUPRC={auprc}')
            plt.plot(recall_curve, precision_curve)
            plt.show()

            # Violin plot of DeepDRA scores per class
            prediction_targets = pd.DataFrame({
                'Prediction': mlp_output.numpy().reshape(-1),
                'Target': all_targets.numpy().reshape(-1),
            })
            class_one = prediction_targets.loc[prediction_targets['Target'] == 0, 'Prediction']
            class_minus_one = prediction_targets.loc[prediction_targets['Target'] == 1, 'Prediction']

            fig, ax = plt.subplots()
            ax.set_ylabel("DeepDRA score")
            xticklabels = ['Responder', 'Non Responder']
            ax.set_xticks([1, 2])
            ax.set_xticklabels(xticklabels)
            data_to_plot = [class_minus_one, class_one]
            plt.ylim(0, 1)
            p_value = np.format_float_scientific(ttest_ind(class_one, class_minus_one)[1], precision=2)
            cancer = 'all'
            plt.title(f'Responder/Non-responder scores for {cancer} cancer with \np-value ~= {p_value}')
            bp = ax.violinplot(data_to_plot, showextrema=True, showmeans=True, showmedians=True)
            bp['cmeans'].set_color('r')
            bp['cmedians'].set_color('g')
            plt.show()

        # Step 6: Return evaluation metrics in a dictionary
        return {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1 score': f1, 'AUC': auc,
                'AUPRC': auprc}

    @staticmethod
    def add_results(result_list, current_result):
        result_list['AUC'].append(current_result['AUC'])
        result_list['AUPRC'].append(current_result['AUPRC'])
        result_list['Accuracy'].append(current_result['Accuracy'])
        result_list['Precision'].append(current_result['Precision'])
        result_list['Recall'].append(current_result['Recall'])
        result_list['F1 score'].append(current_result['F1 score'])
        return result_list

    @staticmethod
    def show_final_results(result_list):
        print("Final Results:")
        for i in range(len(result_list["AUC"])):
            accuracy = result_list['Accuracy'][i]
            precision = result_list['Precision'][i]
            recall = result_list['Recall'][i]
            f1 = result_list['F1 score'][i]
            auc = result_list['AUC'][i]
            auprc = result_list['AUPRC'][i]
            print(f'Accuracy: {accuracy:.3f}, Precision: {precision:.3f}, Recall: {recall:.3f}, '
                  f'F1 score: {f1:.3f}, AUC: {auc:.3f}, AUPRC: {auprc:.3f}')

        avg_auc = np.mean(result_list['AUC'])
        avg_auprc = np.mean(result_list['AUPRC'])
        std_auprc = np.std(result_list['AUPRC'])
        avg_accuracy = np.mean(result_list['Accuracy'])
        avg_precision = np.mean(result_list['Precision'])
        avg_recall = np.mean(result_list['Recall'])
        avg_f1score = np.mean(result_list['F1 score'])
        print(f'AVG: Accuracy: {avg_accuracy:.3f}, Precision: {avg_precision:.3f}, Recall: {avg_recall:.3f}, '
              f'F1 score: {avg_f1score:.3f}, AUC: {avg_auc:.3f}, AUPRC: {avg_auprc:.3f}')
        print("Average AUC: {:.3f} \t Average AUPRC: {:.3f} \t Std AUPRC: {:.3f}".format(avg_auc, avg_auprc, std_auprc))
        return {'Accuracy': avg_accuracy, 'Precision': avg_precision, 'Recall': avg_recall, 'F1 score': avg_f1score,
                'AUC': avg_auc, 'AUPRC': avg_auprc}