
model_training.py

import os
import random
import time
from typing import cast

import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torchvision
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import Config
from fragment_splitter import CustomFragmentLoader
from model_train_logger import set_config_for_logger
from thyroid_dataset import ThyroidDataset
from thyroid_ml_model import ThyroidClassificationModel
from transformation import get_transformation
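
# validate() computes class-balanced accuracy from a row-normalised confusion
# matrix, plus ROC/AUC from the positive-class score (column 1 of the model
# output); when a loss function is given it also returns the mean loss.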
@torch.no_grad()
def validate(model, data_loader, loss_function=None, show_tqdm=False):
    class_set = sorted(data_loader.dataset.class_to_idx_dict.values())
    loss_values = []
    y_preds = []
    y_targets = []
    y_positive_scores = []
    for images, labels in (tqdm(data_loader) if show_tqdm else data_loader):
        images = images.to(Config.available_device)
        labels = labels.to(Config.available_device)
        x = model(images, validate=True)
        if loss_function:
            loss_values.append(loss_function(x, labels))
        _, preds = torch.max(x, 1)
        y_positive_scores += x[:, 1].cpu()
        y_preds += preds.cpu()
        y_targets += labels.cpu()
    # Row-normalised confusion matrix: the diagonal holds per-class recall.
    cf_matrix = confusion_matrix(y_targets, y_preds, normalize="true")
    class_accuracies = [cf_matrix[c][c] for c in class_set]
    acc = sum(class_accuracies) / len(class_set)
    fpr, tpr, _ = roc_curve(y_targets, y_positive_scores)
    auc = roc_auc_score(y_targets, y_positive_scores)
    if loss_function:
        loss = sum(loss_values) / len(loss_values)
        return acc * 100, cf_matrix, (fpr, tpr, auc), loss
    return acc * 100, cf_matrix, (fpr, tpr, auc)
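
# Checkpoints and plots are laid out as:
#   ./train_state/<config_label>/            accuracy/loss curves
#   ./train_state/<config_label>/epoch-<n>/  model.state and test ROC chart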
def get_save_state_dirs(config_label, epoch=None):
    trains_state_dir = "./train_state"
    os.makedirs(trains_state_dir, exist_ok=True)
    config_train_dir = os.path.join(trains_state_dir, config_label)
    os.makedirs(config_train_dir, exist_ok=True)
    if epoch is not None:
        save_state_dir = os.path.join(config_train_dir, f"epoch-{epoch}")
        os.makedirs(save_state_dir, exist_ok=True)
    else:
        save_state_dir = None
    return trains_state_dir, config_train_dir, save_state_dir

def plot_and_save_model_per_epoch(epoch,
                                  model_to_save,
                                  val_acc_list,
                                  train_acc_list,
                                  val_loss_list,
                                  train_loss_list,
                                  config_label):
    trains_state_dir, config_train_dir, save_state_dir = get_save_state_dirs(config_label, epoch)

    fig_save_path = os.path.join(config_train_dir, "val_train_acc.jpeg")
    plt.plot(range(len(val_acc_list)), val_acc_list, label="validation")
    plt.plot(range(len(train_acc_list)), train_acc_list, label="train")
    plt.legend(loc="lower right")
    plt.xlabel('Epoch')
    plt.ylabel('Balanced Accuracy')
    plt.savefig(fig_save_path)
    plt.clf()

    fig_save_path = os.path.join(config_train_dir, "val_train_loss.jpeg")
    plt.plot(range(len(val_loss_list)), val_loss_list, label="validation")
    plt.plot(range(len(train_loss_list)), train_loss_list, label="train")
    plt.legend(loc="lower right")
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.savefig(fig_save_path)
    plt.clf()

    if model_to_save:
        model_save_path = os.path.join(save_state_dir, "model.state")
        model_to_save.save_model(model_save_path)

def save_auc_roc_chart_for_test(test_fpr, test_tpr, test_auc_score, config_label, epoch):
    trains_state_dir, config_train_dir, save_dir = get_save_state_dirs(config_label, epoch)
    fig_save_path = os.path.join(save_dir, f"test_roc_{time.time()}.jpeg")
    plt.plot(test_fpr, test_tpr, label=f"test, auc={test_auc_score}")
    plt.legend(loc="lower right")
    plt.xlabel('FPR')
    plt.ylabel('TPR')
    plt.savefig(fig_save_path)
    plt.clf()
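
# calculate_test: runs validate() on the test split, saves the ROC chart for
# the given epoch, and logs balanced accuracy with the confusion matrix.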
def calculate_test(image_model, epoch, test_data_loader, logger, config_name, show_tqdm=False):
    image_model.eval()
    test_acc, test_cf_matrix, (test_fpr, test_tpr, test_auc_score) = validate(image_model,
                                                                              test_data_loader,
                                                                              show_tqdm=show_tqdm)
    test_acc = float(test_acc)
    save_auc_roc_chart_for_test(test_fpr, test_tpr, test_auc_score, config_name, epoch)
    logger.info(f'Test|Epoch:{epoch}|Balanced Accuracy:{round(test_acc, 4)}%,\n{test_cf_matrix}')
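
# train_model: optionally restores a checkpoint, trains with Adam plus
# exponential LR decay under the chosen augmentation, and logs balanced
# accuracy on train/validation each epoch.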
def train_model(base_model, config_base_name, train_val_test_data_loaders, augmentation,
                adaptation_sample_dataset=None,
                train_model_flag=True,
                load_model_from_dir=None):
    config_name = f"{config_base_name}-{augmentation}-{','.join(Config.class_idx_dict.keys())}"
    logger = set_config_for_logger(config_name)
    logger.info(f"training config: {config_name}")
    try:
        _is_inception = isinstance(base_model, torchvision.models.inception.Inception3)
        train_data_loader, val_data_loader, test_data_loader = train_val_test_data_loaders
        logger.info(
            "train valid test splits:" +
            f" {len(train_data_loader.dataset.samples) if train_data_loader else None}," +
            f" {len(val_data_loader.dataset.samples) if val_data_loader else None}," +
            f" {len(test_data_loader.dataset.samples) if test_data_loader else None}")
        # MODEL
        if load_model_from_dir:
            # Load model weights from a previous checkpoint directory
            model_path = os.path.join(load_model_from_dir, 'model.state')
            image_model = ThyroidClassificationModel(base_model).load_model(model_path).to(Config.available_device)
        else:
            image_model = ThyroidClassificationModel(base_model).to(Config.available_device)

        if train_model_flag:
            # TRAIN
            transformation = get_transformation(augmentation=augmentation, base_dataset=adaptation_sample_dataset)
            train_dataset = cast(ThyroidDataset, train_data_loader.dataset)
            train_dataset.transform = transformation

            cec = nn.CrossEntropyLoss(weight=torch.tensor(train_dataset.class_weights).to(Config.available_device))
            optimizer = optim.Adam(image_model.parameters(), lr=Config.learning_rate)
            my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=Config.decay_rate)

            val_acc_history = []
            train_acc_history = []
            best_epoch_val_acc = 0
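
            # Checkpoint policy: an epoch is saved when validation accuracy is a
            # new best and stays within Config.train_val_acc_max_distance_for_best_epoch
            # of train accuracy, or when it is the last epoch.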
            for epoch in range(Config.n_epoch):
                # Per-epoch buffers for train accuracy (reset each epoch so the
                # logged accuracy reflects only the current epoch)
                class_set = sorted(train_data_loader.dataset.class_to_idx_dict.values())
                train_y_preds = []
                train_y_targets = []
                for images, labels in tqdm(train_data_loader, colour="#0000ff"):
                    # Skip unusually small (e.g. trailing) batches
                    if len(images) >= Config.batch_size // 2:
                        image_model.train()
                        images = images.to(Config.available_device)
                        labels = labels.to(Config.available_device)
                        optimizer.zero_grad()
                        pred = image_model(images)
                        if _is_inception:
                            # Inception3 returns (logits, aux_logits) in train mode;
                            # the auxiliary loss is down-weighted by 0.4
                            pred, aux_pred = pred
                            loss = cec(pred, labels) + 0.4 * cec(aux_pred, labels)
                        else:
                            loss = cec(pred, labels)
                        loss.backward()
                        optimizer.step()
                        # Collect train predictions and labels
                        _, preds = torch.max(pred, 1)
                        train_y_preds.extend(preds.cpu())
                        train_y_targets.extend(labels.cpu())

                # Epoch-level metrics on train and validation data
                image_model.eval()
                train_cf_matrix = confusion_matrix(train_y_targets, train_y_preds, normalize="true")
                class_accuracies = [train_cf_matrix[c][c] for c in class_set]
                train_acc = (100 * sum(class_accuracies) / len(class_set)).item()
                train_acc_history.append(train_acc)
                logger.info(f'Train|E:{epoch}|Balanced Accuracy:{round(train_acc, 4)}%,\n{train_cf_matrix}')

                val_acc, val_cf_matrix, _, val_loss = validate(image_model,
                                                               val_data_loader,
                                                               cec)
                val_acc = float(val_acc)
                val_acc_history.append(val_acc)
                logger.info(f'Val|E:{epoch}|Balanced Accuracy:{round(val_acc, 4)}%,\n{val_cf_matrix}')

                save_model = False
                is_last_epoch = epoch == Config.n_epoch - 1
                is_a_better_epoch = val_acc >= best_epoch_val_acc
                is_a_better_epoch &= abs(train_acc - val_acc) < Config.train_val_acc_max_distance_for_best_epoch
                if is_a_better_epoch or is_last_epoch:
                    save_model = True
                    best_epoch_val_acc = max(best_epoch_val_acc, val_acc)
                    calculate_test(image_model, epoch, test_data_loader, logger, config_name, show_tqdm=False)
                plot_and_save_model_per_epoch(epoch if save_model else None,
                                              image_model if save_model else None,
                                              val_acc_history,
                                              train_acc_history,
                                              [],
                                              [],
                                              config_label=config_name)
                my_lr_scheduler.step()
        else:
            # JUST EVALUATE
            calculate_test(image_model, 0, test_data_loader, logger, config_name,
                           show_tqdm=True)
    except Exception as e:
        logger.error(str(e))
        raise
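
# load_datasets: splits fragment lists into train/val/test and wraps each split
# in a DataLoader. Note that random.choices samples WITH replacement, so even
# sample_percent=1 can duplicate some items and drop others.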
def load_datasets(datasets_folders, test_percent=Config.test_percent, val_percent=Config.val_percent,
                  sample_percent=1, is_nci_per_slide=False):
    if is_nci_per_slide:
        l_train, l_val, l_test = CustomFragmentLoader(
            datasets_folders).national_cancer_image_and_labels_splitter_per_slide(
            test_percent=test_percent,
            val_percent=val_percent)
    else:
        l_train, l_val, l_test = CustomFragmentLoader(datasets_folders).load_image_path_and_labels_and_split(
            test_percent=test_percent,
            val_percent=val_percent)
    l_train = random.choices(l_train, k=int(sample_percent * len(l_train)))
    l_val = random.choices(l_val, k=int(sample_percent * len(l_val)))
    l_test = random.choices(l_test, k=int(sample_percent * len(l_test)))

    l_train_ds = ThyroidDataset(l_train, Config.class_idx_dict)
    l_val_ds = ThyroidDataset(l_val, Config.class_idx_dict)
    l_test_ds = ThyroidDataset(l_test, Config.class_idx_dict)

    l_train_data_loader = None
    if l_train:
        l_train_data_loader = DataLoader(l_train_ds, batch_size=Config.batch_size, shuffle=True)
    l_val_data_loader = None
    if l_val:
        l_val_data_loader = DataLoader(l_val_ds, batch_size=Config.eval_batch_size, shuffle=True)
    l_test_data_loader = None
    if l_test:
        l_test_data_loader = DataLoader(l_test_ds, batch_size=Config.eval_batch_size, shuffle=True)
    return (l_train, l_val, l_test), (l_train_ds, l_val_ds, l_test_ds), (
        l_train_data_loader, l_val_data_loader, l_test_data_loader)
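
# Per-slide evaluation: fragment-level positive-class scores are averaged per
# slide into a tumor-percent estimate; predictions and targets are then rounded
# to the nearest 10 percent before the confusion matrix is computed.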
@torch.no_grad()
def evaluate_nci_dataset_per_slide(config_base_name, augmentation, base_model, data_loader,
                                   load_model_from_dir):
    config_name = f"{config_base_name}-{augmentation}-tumor-percent"
    logger = set_config_for_logger(config_name)
    logger.info(f"evaluation config: {config_name}")
    _is_inception = isinstance(base_model, torchvision.models.inception.Inception3)
    logger.info(f"test: {len(data_loader.dataset.samples) if data_loader else None}")
    # MODEL: load weights from the given checkpoint directory
    model_path = os.path.join(load_model_from_dir, 'model.state')
    model = ThyroidClassificationModel(base_model).load_model(model_path).to(Config.available_device)

    y_positive_scores = []
    slides_preds = {}
    slide_labels = {}
    for images, (labels, slides) in tqdm(data_loader):
        images = images.to(Config.available_device)
        x = model(images, validate=True).cpu()
        preds = x[:, 1]
        logger.info("zero and one hundred percent scores")
        logger.info(x[:, 0])
        logger.info(x[:, 1])
        for row_index in range(len(labels)):
            slide_id = slides[row_index]
            slide_labels[slide_id] = labels[row_index]
            slides_preds[slide_id] = slides_preds.get(slide_id, []) + [preds[row_index].item()]
        y_positive_scores += x[:, 1]

    y_targets = []
    y_preds = []
    for key in slides_preds:
        # Average fragment scores per slide -> estimated tumor percent
        slides_preds[key] = (sum(slides_preds[key]) / len(slides_preds[key])) * 100
        y_preds.append(slides_preds[key])
        y_targets.append(int(slide_labels[key]))
    # Round targets and predictions to the nearest 10 percent
    y_targets_rounded = [int(round(v / 100, 1) * 100) for v in y_targets]
    y_preds_rounded = [int(round(v / 100, 1) * 100) for v in y_preds]
    cf_matrix = confusion_matrix(y_targets_rounded, y_preds_rounded, labels=Config.class_names, normalize="true")
    class_accuracies = [cf_matrix[c][c] for c in range(len(cf_matrix))]
    class_weights = [sum(cf_matrix[c]) for c in range(len(cf_matrix))]
    # Accuracy weighted by how often each rounded class occurs
    acc = sum(class_accuracies[i] * class_weights[i] for i in range(len(class_accuracies)))
    acc /= sum(class_weights)
    # fpr, tpr, _ = roc_curve(y_targets, y_positive_scores)
    # auc = roc_auc_score(y_targets, y_positive_scores)
    logger.info(f"target rounded:{y_targets_rounded}")
    logger.info(f"pred rounded:{y_preds_rounded}")
    logger.info(f"Results| acc:{acc * 100}\ncf:{cf_matrix}")
    return acc * 100, cf_matrix

##########
## Runs ##
##########

# train_phase block
if __name__ == '__main__' and Config.train_phase:
    _, (train_ds, _, _), (train_data_loader, val_data_loader, test_data_loader) = load_datasets(
        ["national_cancer_institute"],
        sample_percent=1)
    # Domain adaptation dataset on small real datasets
    # _, (_, _, domain_sample_test_dataset), _ = load_datasets(["stanford_tissue_microarray",
    #                                                           "papsociaty"],
    #                                                          sample_percent=0.5,
    #                                                          test_percent=100,
    #                                                          val_percent=0)
    for c_base_name, model, augmentations in [
        (f"resnet101_{Config.learning_rate}_{Config.decay_rate}_nci_final",
         torchvision.models.resnet101(pretrained=True, progress=True), [
             "mixup",
             # "jit",
             # "fda",
             # "jit-fda-mixup",
             # "shear",
             # "std"
         ]),
    ]:
        for aug in augmentations:
            Config.reset_random_seeds()
            train_model(model, c_base_name, (train_data_loader, val_data_loader, test_data_loader),
                        augmentation=aug, adaptation_sample_dataset=train_ds)
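
# The evaluate phase reinterprets the label space as integer tumor percentages
# (0-100), so class_names and class_idx_dict are rebuilt before loading the
# per-slide NCI split.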
# evaluate_phase block
if __name__ == '__main__' and Config.evaluate_phase:
    # Main data
    Config.class_names = [i for i in range(101)]
    Config.class_idx_dict = {i: i for i in range(101)}
    _, (train_ds, _, _), (_, _, test_data_loader) = load_datasets(
        ["national_cancer_institute"],
        sample_percent=1, test_percent=100, val_percent=0, is_nci_per_slide=True)
    for c_base_name, model, aug_best_epoch_list in [
        (f"resnet101_{Config.learning_rate}_{Config.decay_rate}_nci_eval",
         torchvision.models.resnet101(pretrained=True, progress=True), [
             ("mixup", "train_state/resnet101_0.0001_1_nci_final-mixup-BENIGN,MALIGNANT/epoch-19/"),
         ]),
        # (f"resnet101_{Config.learning_rate}_{Config.decay_rate}_test_nci_eval",
        #  torchvision.models.resnet101(pretrained=True, progress=True), [
        #      ("fda",
        #       "train_state/runs_0.0001_1_nic_test_benign_mal/resnet101_0.0001_1_nci-fda-BENIGN,MALIGNANT/epoch-3/"),
        #      ("mixup",
        #       "train_state/runs_0.0001_1_nic_test_benign_mal/resnet101_0.0001_1_nci-mixup-BENIGN,MALIGNANT/epoch-3/"),
        #      ("jit",
        #       "train_state/runs_0.0001_1_nic_test_benign_mal/resnet101_0.0001_1_nci-jit-BENIGN,MALIGNANT/epoch-3/"),
        #      ("jit-fda-mixup",
        #       "train_state/runs_0.0001_1_nic_test_benign_mal/resnet101_0.0001_1_nci-jit-fda-mixup-BENIGN,MALIGNANT/epoch-3/"),
        #  ]),
    ]:
        for aug, best_epoch in aug_best_epoch_list:
            Config.reset_random_seeds()
            evaluate_nci_dataset_per_slide(c_base_name, aug, model, test_data_loader,
                                           load_model_from_dir=best_epoch)