from os import makedirs, path, remove from typing import Dict, Tuple, TYPE_CHECKING from time import time import torch import numpy as np from tqdm import tqdm from ..models.model import Model from ..data.data_loader import DataLoader from ..model_evaluation.evaluator import Evaluator from ..data.dataflow import DataFlow if TYPE_CHECKING: from ..configs.base_config import BaseConfig class Trainer(): def __init__(self, the_model: Model, conf: 'BaseConfig', train_loader: DataLoader, val_loader: DataLoader): self.conf = conf self.the_model = the_model self.train_loader = train_loader self.val_loader = val_loader self.batch_size = self.conf.batch_size self.dev_name = self.conf.dev_name self.the_device = self.conf.device self.the_model_ptr: Model = the_model self.train_evaluator: Evaluator = self.conf.evaluator_cls( self.the_model_ptr, train_loader, self.conf) self.val_evaluator: Evaluator = self.conf.evaluator_cls( self.the_model_ptr, val_loader, self.conf) self.the_model.to(self.conf.device) self.best_val_metric = float('inf') self.dataflow = DataFlow(the_model, train_loader, self.conf.device) self.optimizer = conf.optimizer_creator( filter(lambda p: p.requires_grad, self.the_model.parameters()) ) self._val_ref_metric_index = -1 def train(self): """ Does the training on the received model in initializer based on training samples and validation samples received in the initializer and saves the model trained in each epoch.""" start_epoch, best_val_epoch = self.load_state() print('Training is starting from epoch %d, best val epoch till now is %d' % (start_epoch, best_val_epoch)) iters_per_epoch = self._calcuate_iters_per_epoch() self.prepare_for_train(start_epoch) for epoch in range(start_epoch, self.conf.max_epochs): s_time = time() # training the model t1 = time() self.train_model_for_one_epoch(epoch, iters_per_epoch) print('Training epoch %d was done in %.2f secs.' % (epoch, time() - t1,), flush=True) # evaluating validation data for the trained model t1 = time() new_best_val_epoch = self.evaluate_model_after_one_epoch(epoch, best_val_epoch) print('Validating for epoch %d was done in %.2f secs.' % (epoch, time() - t1,), flush=True) # Updating the log file self.save_log_file(epoch) # Saving epoch self.save_training_stats_after_epoch(epoch, new_best_val_epoch) old_best_val_epoch = best_val_epoch if new_best_val_epoch != best_val_epoch: print('@@@ Best val epoch was changed from %d to %d :)' % (best_val_epoch, new_best_val_epoch)) # Eliminating best_val_epoch, if in config if self.conf.keep_best_and_last_epochs_only: self.eliminate_epoch_specific_information( self._get_epoch_specific_save_dir(self.conf.save_dir, best_val_epoch)) best_val_epoch = new_best_val_epoch # Eliminating prev last epoch if it is not the best if epoch > 0 and \ self.conf.keep_best_and_last_epochs_only and \ epoch - 1 != best_val_epoch and epoch - 1 != old_best_val_epoch: self.eliminate_epoch_specific_information( self._get_epoch_specific_save_dir(self.conf.save_dir, epoch - 1)) self.after_epoch_ended(epoch, best_val_epoch) print('Epoch %d ended in %.2f secs.\n' % (epoch, time() - s_time)) # checking auto-stop! if self.conf.n_unsuccessful_epochs_to_stop is not None and \ (epoch - best_val_epoch) >= self.conf.n_unsuccessful_epochs_to_stop: print('Training stopped because the model\'s performance was not improved for %d epochs.' % (epoch - best_val_epoch,)) break self.after_train() # LOADING def load_state(self, epoch=None) -> Tuple[int, int]: """ :param epoch: The epoch to load from. If none, would load from the epoch set in config and if that's not set eighter, would load from the last traine epoch. :return: The number of the next epoch and the number of the epoch with the best val results """ save_dir = self.conf.save_dir makedirs(save_dir, exist_ok=True) start_epoch = 0 best_val_epoch = 0 if path.exists(save_dir + '/GeneralInfo'): checkpoint = torch.load(save_dir + '/GeneralInfo') if epoch is not None: last_epoch = epoch best_val_epoch = epoch elif self.conf.epoch is not None: last_epoch = int(self.conf.epoch) best_val_epoch = last_epoch else: last_epoch = checkpoint['epoch'] best_val_epoch = checkpoint['best_val_epoch'] start_epoch = last_epoch + 1 self.load_required_parameters(checkpoint, last_epoch) return start_epoch, best_val_epoch # PARAMETERS SETTING def _calcuate_iters_per_epoch(self): """ Sets the number of iterations per epoch, if defined in the configurations to the defined value, otherwise to some value to go over the whole samples at least once. """ if self.conf.iters_per_epoch is None: iters_per_epoch = \ int(np.ceil( 1.0 * self.train_loader.get_classes_num() * self.train_loader.get_max_class_samples_num() / (self.conf.batch_size * self.conf.big_batch_size))) print('Iterations per epoch is set to %d' % iters_per_epoch, flush=True) return iters_per_epoch return self.conf.iters_per_epoch # MIDDLE TRAINING SAVING AND LOADING def load_required_parameters(self, checkpoint, epoch): """ Checkpoint is a dictionary keeping values of important parameters of training. It has been loaded from the specified location and this function is to load the values related to the optimizer and all other subclass dependent required things from that dictionary. """ # loading the optimizer self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # for backcward compatibility if 'best_val_loss' in checkpoint: self.best_val_metric = checkpoint['best_val_loss'] else: self.best_val_metric = checkpoint['best_val_metric'] # loading the model checkpoint = torch.load('%s/%d' % (self.conf.save_dir, epoch,)) # for backward compatibility! if 'model_state_dict' in checkpoint: checkpoint = checkpoint['model_state_dict'] self.the_model.load_state_dict(checkpoint) def add_required_state_to_dictionary(self, save_dict): """ This function adds the important values, e.g. state dictionary of all the used optimizers and all other subclass dependent required things to the dictionary passed to it so they will be saved in the process of saving the training parameters.""" save_dict['optimizer_state_dict'] = self.optimizer.state_dict() save_dict['best_val_metric'] = self.best_val_metric # TRAIN PHASES def prepare_for_train(self, start_epoch): if start_epoch == 0: self.the_model.init_weights_from_other_model(self.conf.pretrained_model_file) # freezeing the parameters of the model! self.the_model.freeze_parameters(self.conf.freezing_regexes) def after_epoch_ended(self, epoch: int, best_val_epoch: int): """ Arbitrary function, is called after training and validation of each epoch is ended""" pass def after_train(self): """ Does the final stuff, frees the memory allocated. """ del self.optimizer def train_model_for_one_epoch(self, epoch: int, iters_per_epoch: int) -> None: self.train_evaluator.reset() # changing the model's status to training mode self.the_model.train() iters = int(round(np.ceil(iters_per_epoch), 0)) with self.dataflow, tqdm(total=iters) as pbar: dataflow_iteration = self.dataflow.iterate() for i in range(iters): self.optimizer.zero_grad() nan_losses_encountered = 0 for _ in range(self.conf.big_batch_size): model_output: Dict[str, torch.Tensor] = next(dataflow_iteration) batch_loss = model_output['loss'] if torch.isnan(batch_loss): nan_losses_encountered += 1 if self.conf.skip_nan_loss: print('Nan loss at epoch %d iteration %d' % (epoch, i)) for k in model_output: model_output[k] = model_output[k].detach() batch_loss.detach() batch_loss = 0 self.optimizer.zero_grad() continue else: raise Exception('Nan loss at epoch %d iteration %d' % (epoch, i)) batch_loss = (1.0 / self.conf.big_batch_size) * batch_loss if len(batch_loss.shape) == 0: batch_loss.backward() else: batch_loss.backward(torch.ones_like(batch_loss)) # Detaching outputs for k in model_output: model_output[k] = model_output[k].detach() # updating model's prediction self.train_evaluator.update_summaries_based_on_model_output(model_output) # Clearing the optimizer if nan_losses_encountered < self.conf.big_batch_size: self.optimizer.step() self.optimizer.zero_grad() pbar.update(1) # printing iteration evaluations pbar.set_description(self.train_evaluator.get_evaluation_metrics()) # printing the stats self.train_evaluator.print_evaluation_metrics('*** Epoch %d' % epoch) def evaluate_model_after_one_epoch(self, epoch: int, best_val_epoch: int) -> int: self.val_evaluator.reset() self.val_evaluator.evaluate(self.conf.val_iters_per_epoch) self.val_evaluator.print_evaluation_metrics('*** Epoch %d' % epoch) if self._val_ref_metric_index == -1: if self.conf.title_of_reference_metric_to_choose_best_epoch not in self.val_evaluator.get_titles_of_evaluation_metrics(): print(f'WARNING!!! Your evaluator has no metric {self.conf.title_of_reference_metric_to_choose_best_epoch} to choose best epoch by that! Considering the last epoch instead!') else: self._val_ref_metric_index = self.val_evaluator.get_titles_of_evaluation_metrics().index( self.conf.title_of_reference_metric_to_choose_best_epoch) self._val_ref_metric_index = self.val_evaluator.get_titles_of_evaluation_metrics().index( self.conf.title_of_reference_metric_to_choose_best_epoch) if self._val_ref_metric_index != -1: val_reference_metric = float(self.val_evaluator.get_values_of_evaluation_metrics()[self._val_ref_metric_index]) if epoch == 0 or self._check_if_reference_metric_for_val_is_improved(self.best_val_metric, val_reference_metric): self.best_val_metric = val_reference_metric return epoch else: return best_val_epoch else: return epoch def save_log_file(self, epoch): """ Appends the evaluation metrics related to the given epoch to the end of the file. WARNING: VAL AND TRAIN EVALUATORS SHOULD NOT BE RESET BEFORE CALLING THIS FUNCTION!!!""" log_dir = self.conf.save_dir + '/log.csv' # Writing the headers if the file does not exist if not path.exists(log_dir): f = open(log_dir, 'w') f.write(','.join( ['epoch'] + ['train_%s' % x for x in self.train_evaluator.get_titles_of_evaluation_metrics()] + ['val_%s' % x for x in self.val_evaluator.get_titles_of_evaluation_metrics()]) + '\n') else: f = open(log_dir, 'a') # appending the information of the current epoch # Changing pandas to some cave age code for HPC! f.write(','.join( [str(epoch)] + self.train_evaluator.get_values_of_evaluation_metrics() + self.val_evaluator.get_values_of_evaluation_metrics() ) + '\n') f.close() def save_training_stats_after_epoch(self, epoch: int, best_val_epoch: int) -> None: """ Saves the training stats related to the last epoch of training and the trained model at the end pf the epoch, updates best val epoch. Best val epoch used to be used for keeping that epoch only and eliminating the model saved in the other epochs but the feature is commented now as we don't use exact evaluating for validation data. """ save_dir = self.conf.save_dir save_dict = {'epoch': epoch, 'best_val_epoch': best_val_epoch} self.add_required_state_to_dictionary(save_dict) torch.save(save_dict, save_dir + '/GeneralInfo') epoch_save_dir = self._get_epoch_specific_save_dir(save_dir, epoch) self.save_epoch_specific_information(epoch_save_dir) def save_epoch_specific_information(self, epoch_save_dir: str) -> None: # saving the model torch.save(self.the_model.state_dict(), epoch_save_dir) def eliminate_epoch_specific_information(self, epoch_save_dir: str) -> None: # to prevent space occupation, if the file goes to trash instead of being removed! if path.exists(epoch_save_dir): f = open(epoch_save_dir, 'w') f.close() remove(epoch_save_dir) else: print(f'Warning: file {epoch_save_dir} was not found to be removed!') @staticmethod def _get_epoch_specific_save_dir(save_dir: str, epoch: int) -> str: return '%s/%d' % (save_dir, epoch) def _check_if_reference_metric_for_val_is_improved(self, old_val, new_val): if self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '<=': return new_val <= old_val elif self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '<': return new_val < old_val elif self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '>=': return new_val >= old_val elif self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '>': return new_val > old_val else: raise Exception( f'You can only have <, <=, >, >= for ' f'operator_to_decide_on_improvement_of_val_reference_metric not' f' {self.conf.operator_to_decide_on_improvement_of_val_reference_metric}')