You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 15KB


from os import makedirs, path, remove
from time import time
from typing import Dict, Optional, Tuple, TYPE_CHECKING

import numpy as np
import torch
from tqdm import tqdm

from ..models.model import Model
from ..data.data_loader import DataLoader
from ..model_evaluation.evaluator import Evaluator
from ..data.dataflow import DataFlow

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig
  13. class Trainer():
  14. def __init__(self, the_model: Model, conf: 'BaseConfig',
  15. train_loader: DataLoader, val_loader: DataLoader):
  16. self.conf = conf
  17. self.the_model = the_model
  18. self.train_loader = train_loader
  19. self.val_loader = val_loader
  20. self.batch_size = self.conf.batch_size
  21. self.dev_name = self.conf.dev_name
  22. self.the_device = self.conf.device
  23. self.the_model_ptr: Model = the_model
  24. self.train_evaluator: Evaluator = self.conf.evaluator_cls(
  25. self.the_model_ptr, train_loader, self.conf)
  26. self.val_evaluator: Evaluator = self.conf.evaluator_cls(
  27. self.the_model_ptr, val_loader, self.conf)
  28. self.the_model.to(self.conf.device)
  29. self.best_val_metric = float('inf')
  30. self.dataflow = DataFlow(the_model, train_loader,
  31. self.conf.device)
  32. self.optimizer = conf.optimizer_creator(
  33. filter(lambda p: p.requires_grad, self.the_model.parameters())
  34. )
  35. self._val_ref_metric_index = -1
  36. def train(self):
  37. """ Does the training on the received model in initializer based on training samples
  38. and validation samples received in the initializer and saves the model trained in
  39. each epoch."""
  40. start_epoch, best_val_epoch = self.load_state()
  41. print('Training is starting from epoch %d, best val epoch till now is %d' %
  42. (start_epoch, best_val_epoch))
  43. iters_per_epoch = self._calcuate_iters_per_epoch()
  44. self.prepare_for_train(start_epoch)
  45. for epoch in range(start_epoch, self.conf.max_epochs):
  46. s_time = time()
  47. # training the model
  48. t1 = time()
  49. self.train_model_for_one_epoch(epoch, iters_per_epoch)
  50. print('Training epoch %d was done in %.2f secs.' % (epoch, time() - t1,), flush=True)
  51. # evaluating validation data for the trained model
  52. t1 = time()
  53. new_best_val_epoch = self.evaluate_model_after_one_epoch(epoch, best_val_epoch)
  54. print('Validating for epoch %d was done in %.2f secs.' % (epoch, time() - t1,), flush=True)
  55. # Updating the log file
  56. self.save_log_file(epoch)
  57. # Saving epoch
  58. self.save_training_stats_after_epoch(epoch, new_best_val_epoch)
  59. old_best_val_epoch = best_val_epoch
  60. if new_best_val_epoch != best_val_epoch:
  61. print('@@@ Best val epoch was changed from %d to %d :)' % (best_val_epoch, new_best_val_epoch))
  62. # Eliminating best_val_epoch, if in config
  63. if self.conf.keep_best_and_last_epochs_only:
  64. self.eliminate_epoch_specific_information(
  65. self._get_epoch_specific_save_dir(self.conf.save_dir, best_val_epoch))
  66. best_val_epoch = new_best_val_epoch
  67. # Eliminating prev last epoch if it is not the best
  68. if epoch > 0 and \
  69. self.conf.keep_best_and_last_epochs_only and \
  70. epoch - 1 != best_val_epoch and epoch - 1 != old_best_val_epoch:
  71. self.eliminate_epoch_specific_information(
  72. self._get_epoch_specific_save_dir(self.conf.save_dir, epoch - 1))
  73. self.after_epoch_ended(epoch, best_val_epoch)
  74. print('Epoch %d ended in %.2f secs.\n' % (epoch, time() - s_time))
  75. # checking auto-stop!
  76. if self.conf.n_unsuccessful_epochs_to_stop is not None and \
  77. (epoch - best_val_epoch) >= self.conf.n_unsuccessful_epochs_to_stop:
  78. print('Training stopped because the model\'s performance was not improved for %d epochs.' %
  79. (epoch - best_val_epoch,))
  80. break
  81. self.after_train()
  82. # LOADING
  83. def load_state(self, epoch=None) -> Tuple[int, int]:
  84. """
  85. :param epoch: The epoch to load from. If none, would load from the epoch set in config
  86. and if that's not set eighter, would load from the last traine epoch.
  87. :return: The number of the next epoch and the number of the epoch with the best val results
  88. """
  89. save_dir = self.conf.save_dir
  90. makedirs(save_dir, exist_ok=True)
  91. start_epoch = 0
  92. best_val_epoch = 0
  93. if path.exists(save_dir + '/GeneralInfo'):
  94. checkpoint = torch.load(save_dir + '/GeneralInfo')
  95. if epoch is not None:
  96. last_epoch = epoch
  97. best_val_epoch = epoch
  98. elif self.conf.epoch is not None:
  99. last_epoch = int(self.conf.epoch)
  100. best_val_epoch = last_epoch
  101. else:
  102. last_epoch = checkpoint['epoch']
  103. best_val_epoch = checkpoint['best_val_epoch']
  104. start_epoch = last_epoch + 1
  105. self.load_required_parameters(checkpoint, last_epoch)
  106. return start_epoch, best_val_epoch
  107. # PARAMETERS SETTING
  108. def _calcuate_iters_per_epoch(self):
  109. """ Sets the number of iterations per epoch, if defined in the configurations to the defined value,
  110. otherwise to some value to go over the whole samples at least once. """
  111. if self.conf.iters_per_epoch is None:
  112. iters_per_epoch = \
  113. int(np.ceil(
  114. 1.0 * self.train_loader.get_classes_num() *
  115. self.train_loader.get_max_class_samples_num() /
  116. (self.conf.batch_size * self.conf.big_batch_size)))
  117. print('Iterations per epoch is set to %d' % iters_per_epoch, flush=True)
  118. return iters_per_epoch
  119. return self.conf.iters_per_epoch
  120. # MIDDLE TRAINING SAVING AND LOADING
  121. def load_required_parameters(self, checkpoint, epoch):
  122. """ Checkpoint is a dictionary keeping values of important parameters of training.
  123. It has been loaded from the specified location and this function is to load the
  124. values related to the optimizer and all other subclass dependent required things
  125. from that dictionary. """
  126. # loading the optimizer
  127. self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  128. # for backcward compatibility
  129. if 'best_val_loss' in checkpoint:
  130. self.best_val_metric = checkpoint['best_val_loss']
  131. else:
  132. self.best_val_metric = checkpoint['best_val_metric']
  133. # loading the model
  134. checkpoint = torch.load('%s/%d' % (self.conf.save_dir, epoch,))
  135. # for backward compatibility!
  136. if 'model_state_dict' in checkpoint:
  137. checkpoint = checkpoint['model_state_dict']
  138. self.the_model.load_state_dict(checkpoint)
  139. def add_required_state_to_dictionary(self, save_dict):
  140. """ This function adds the important values, e.g. state dictionary of all the used
  141. optimizers and all other subclass dependent required things
  142. to the dictionary passed to it so they will be saved in the process of
  143. saving the training parameters."""
  144. save_dict['optimizer_state_dict'] = self.optimizer.state_dict()
  145. save_dict['best_val_metric'] = self.best_val_metric
  146. # TRAIN PHASES
  147. def prepare_for_train(self, start_epoch):
  148. if start_epoch == 0:
  149. self.the_model.init_weights_from_other_model(self.conf.pretrained_model_file)
  150. # freezeing the parameters of the model!
  151. self.the_model.freeze_parameters(self.conf.freezing_regexes)
  152. def after_epoch_ended(self, epoch: int, best_val_epoch: int):
  153. """ Arbitrary function, is called after training and validation of each epoch is ended"""
  154. pass
  155. def after_train(self):
  156. """ Does the final stuff, frees the memory allocated. """
  157. del self.optimizer
  158. def train_model_for_one_epoch(self, epoch: int, iters_per_epoch: int) -> None:
  159. self.train_evaluator.reset()
  160. # changing the model's status to training mode
  161. self.the_model.train()
  162. iters = int(round(np.ceil(iters_per_epoch), 0))
  163. with self.dataflow, tqdm(total=iters) as pbar:
  164. dataflow_iteration = self.dataflow.iterate()
  165. for i in range(iters):
  166. self.optimizer.zero_grad()
  167. nan_losses_encountered = 0
  168. for _ in range(self.conf.big_batch_size):
  169. model_output: Dict[str, torch.Tensor] = next(dataflow_iteration)
  170. batch_loss = model_output['loss']
  171. if torch.isnan(batch_loss):
  172. nan_losses_encountered += 1
  173. if self.conf.skip_nan_loss:
  174. print('Nan loss at epoch %d iteration %d' % (epoch, i))
  175. for k in model_output:
  176. model_output[k] = model_output[k].detach()
  177. batch_loss.detach()
  178. batch_loss = 0
  179. self.optimizer.zero_grad()
  180. continue
  181. else:
  182. raise Exception('Nan loss at epoch %d iteration %d' % (epoch, i))
  183. batch_loss = (1.0 / self.conf.big_batch_size) * batch_loss
  184. if len(batch_loss.shape) == 0:
  185. batch_loss.backward()
  186. else:
  187. batch_loss.backward(torch.ones_like(batch_loss))
  188. # Detaching outputs
  189. for k in model_output:
  190. model_output[k] = model_output[k].detach()
  191. # updating model's prediction
  192. self.train_evaluator.update_summaries_based_on_model_output(model_output)
  193. # Clearing the optimizer
  194. if nan_losses_encountered < self.conf.big_batch_size:
  195. self.optimizer.step()
  196. self.optimizer.zero_grad()
  197. pbar.update(1)
  198. # printing iteration evaluations
  199. pbar.set_description(self.train_evaluator.get_evaluation_metrics())
  200. # printing the stats
  201. self.train_evaluator.print_evaluation_metrics('*** Epoch %d' % epoch)
  202. def evaluate_model_after_one_epoch(self, epoch: int, best_val_epoch: int) -> int:
  203. self.val_evaluator.reset()
  204. self.val_evaluator.evaluate(self.conf.val_iters_per_epoch)
  205. self.val_evaluator.print_evaluation_metrics('*** Epoch %d' % epoch)
  206. if self._val_ref_metric_index == -1:
  207. if self.conf.title_of_reference_metric_to_choose_best_epoch not in self.val_evaluator.get_titles_of_evaluation_metrics():
  208. print(f'WARNING!!! Your evaluator has no metric {self.conf.title_of_reference_metric_to_choose_best_epoch} to choose best epoch by that! Considering the last epoch instead!')
  209. else:
  210. self._val_ref_metric_index = self.val_evaluator.get_titles_of_evaluation_metrics().index(
  211. self.conf.title_of_reference_metric_to_choose_best_epoch)
  212. self._val_ref_metric_index = self.val_evaluator.get_titles_of_evaluation_metrics().index(
  213. self.conf.title_of_reference_metric_to_choose_best_epoch)
  214. if self._val_ref_metric_index != -1:
  215. val_reference_metric = float(self.val_evaluator.get_values_of_evaluation_metrics()[self._val_ref_metric_index])
  216. if epoch == 0 or self._check_if_reference_metric_for_val_is_improved(self.best_val_metric, val_reference_metric):
  217. self.best_val_metric = val_reference_metric
  218. return epoch
  219. else:
  220. return best_val_epoch
  221. else:
  222. return epoch
  223. def save_log_file(self, epoch):
  224. """ Appends the evaluation metrics related to the given epoch to the
  225. end of the file.
  226. WARNING: VAL AND TRAIN EVALUATORS SHOULD NOT BE RESET
  227. BEFORE CALLING THIS FUNCTION!!!"""
  228. log_dir = self.conf.save_dir + '/log.csv'
  229. # Writing the headers if the file does not exist
  230. if not path.exists(log_dir):
  231. f = open(log_dir, 'w')
  232. f.write(','.join(
  233. ['epoch'] +
  234. ['train_%s' % x for x in self.train_evaluator.get_titles_of_evaluation_metrics()] +
  235. ['val_%s' % x for x in self.val_evaluator.get_titles_of_evaluation_metrics()]) + '\n')
  236. else:
  237. f = open(log_dir, 'a')
  238. # appending the information of the current epoch
  239. # Changing pandas to some cave age code for HPC!
  240. f.write(','.join(
  241. [str(epoch)] + self.train_evaluator.get_values_of_evaluation_metrics() +
  242. self.val_evaluator.get_values_of_evaluation_metrics()
  243. ) + '\n')
  244. f.close()
  245. def save_training_stats_after_epoch(self, epoch: int, best_val_epoch: int) -> None:
  246. """ Saves the training stats related to the last epoch of training and the
  247. trained model at the end pf the epoch, updates best val epoch.
  248. Best val epoch used to be used for keeping that epoch only and
  249. eliminating the model saved in the other epochs but the feature is commented now
  250. as we don't use exact evaluating for validation data. """
  251. save_dir = self.conf.save_dir
  252. save_dict = {'epoch': epoch, 'best_val_epoch': best_val_epoch}
  253. self.add_required_state_to_dictionary(save_dict)
  254. torch.save(save_dict, save_dir + '/GeneralInfo')
  255. epoch_save_dir = self._get_epoch_specific_save_dir(save_dir, epoch)
  256. self.save_epoch_specific_information(epoch_save_dir)
  257. def save_epoch_specific_information(self, epoch_save_dir: str) -> None:
  258. # saving the model
  259. torch.save(self.the_model.state_dict(), epoch_save_dir)
  260. def eliminate_epoch_specific_information(self, epoch_save_dir: str) -> None:
  261. # to prevent space occupation, if the file goes to trash instead of being removed!
  262. if path.exists(epoch_save_dir):
  263. f = open(epoch_save_dir, 'w')
  264. f.close()
  265. remove(epoch_save_dir)
  266. else:
  267. print(f'Warning: file {epoch_save_dir} was not found to be removed!')
  268. @staticmethod
  269. def _get_epoch_specific_save_dir(save_dir: str, epoch: int) -> str:
  270. return '%s/%d' % (save_dir, epoch)
  271. def _check_if_reference_metric_for_val_is_improved(self, old_val, new_val):
  272. if self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '<=':
  273. return new_val <= old_val
  274. elif self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '<':
  275. return new_val < old_val
  276. elif self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '>=':
  277. return new_val >= old_val
  278. elif self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '>':
  279. return new_val > old_val
  280. else:
  281. raise Exception(
  282. f'You can only have <, <=, >, >= for '
  283. f'operator_to_decide_on_improvement_of_val_reference_metric not'
  284. f' {self.conf.operator_to_decide_on_improvement_of_val_reference_metric}')