|
|
from os import makedirs, path, remove
from time import time
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm

from ..data.data_loader import DataLoader
from ..data.dataflow import DataFlow
from ..model_evaluation.evaluator import Evaluator
from ..models.model import Model

if TYPE_CHECKING:
    from ..configs.base_config import BaseConfig
-
-
class Trainer:
    """Epoch-based training loop for a ``Model``.

    Trains on ``train_loader`` (through a ``DataFlow``), evaluates on
    ``val_loader`` after every epoch, appends the metrics to
    ``<save_dir>/log.csv``, saves a checkpoint per epoch and keeps track of the
    best validation epoch. Subclasses may hook in via ``prepare_for_train``,
    ``after_epoch_ended``, ``after_train``, ``load_required_parameters`` and
    ``add_required_state_to_dictionary``.
    """

    def __init__(self, the_model: 'Model', conf: 'BaseConfig',
                 train_loader: 'DataLoader', val_loader: 'DataLoader'):
        self.conf = conf
        self.the_model = the_model

        self.train_loader = train_loader
        self.val_loader = val_loader

        self.batch_size = self.conf.batch_size
        self.dev_name = self.conf.dev_name
        self.the_device = self.conf.device
        self.the_model_ptr: Model = the_model

        self.train_evaluator: Evaluator = self.conf.evaluator_cls(
            self.the_model_ptr, train_loader, self.conf)
        self.val_evaluator: Evaluator = self.conf.evaluator_cls(
            self.the_model_ptr, val_loader, self.conf)

        self.the_model.to(self.conf.device)
        # Best value seen so far of the reference validation metric; it is
        # overwritten from the checkpoint when training is resumed.
        self.best_val_metric = float('inf')

        self.dataflow = DataFlow(the_model, train_loader, self.conf.device)

        # Only parameters that require gradients (at the time optimizer_creator
        # consumes this lazy iterator) are handed to the optimizer.
        self.optimizer = conf.optimizer_creator(
            filter(lambda p: p.requires_grad, self.the_model.parameters())
        )

        # Index of the reference metric among the validation evaluator's metric
        # titles; -1 while unresolved or when the evaluator lacks that metric.
        self._val_ref_metric_index = -1

    def train(self):
        """Train the model received in the initializer.

        Resumes from the saved state (see ``load_state``), then for every epoch
        up to ``conf.max_epochs``: trains, evaluates on the validation data,
        appends the metrics to the log file and saves the model of that epoch.
        When ``conf.keep_best_and_last_epochs_only`` is set, model files of
        epochs that are neither the best nor the latest are deleted. Training
        stops early once the reference validation metric has not improved for
        ``conf.n_unsuccessful_epochs_to_stop`` epochs (if that is not None).
        """
        start_epoch, best_val_epoch = self.load_state()

        print('Training is starting from epoch %d, best val epoch till now is %d' %
              (start_epoch, best_val_epoch))

        iters_per_epoch = self._calcuate_iters_per_epoch()

        self.prepare_for_train(start_epoch)

        for epoch in range(start_epoch, self.conf.max_epochs):

            s_time = time()

            # training the model
            t1 = time()
            self.train_model_for_one_epoch(epoch, iters_per_epoch)
            print('Training epoch %d was done in %.2f secs.' % (epoch, time() - t1,), flush=True)

            # evaluating validation data for the trained model
            t1 = time()
            new_best_val_epoch = self.evaluate_model_after_one_epoch(epoch, best_val_epoch)

            print('Validating for epoch %d was done in %.2f secs.' % (epoch, time() - t1,), flush=True)

            # Updating the log file (the evaluators must not be reset before this)
            self.save_log_file(epoch)

            # Saving epoch
            self.save_training_stats_after_epoch(epoch, new_best_val_epoch)
            old_best_val_epoch = best_val_epoch

            if new_best_val_epoch != best_val_epoch:
                print('@@@ Best val epoch was changed from %d to %d :)' % (best_val_epoch, new_best_val_epoch))

                # The previous best epoch is superseded: drop its model file
                # when only the best and the last epochs are to be kept.
                if self.conf.keep_best_and_last_epochs_only:
                    self.eliminate_epoch_specific_information(
                        self._get_epoch_specific_save_dir(self.conf.save_dir, best_val_epoch))

                best_val_epoch = new_best_val_epoch

            # Drop the previous epoch's model file unless it is the best epoch.
            # If it was the old best, it has already been removed just above.
            if epoch > 0 and \
                    self.conf.keep_best_and_last_epochs_only and \
                    epoch - 1 != best_val_epoch and epoch - 1 != old_best_val_epoch:
                self.eliminate_epoch_specific_information(
                    self._get_epoch_specific_save_dir(self.conf.save_dir, epoch - 1))

            self.after_epoch_ended(epoch, best_val_epoch)
            print('Epoch %d ended in %.2f secs.\n' % (epoch, time() - s_time))

            # checking auto-stop!
            if self.conf.n_unsuccessful_epochs_to_stop is not None and \
                    (epoch - best_val_epoch) >= self.conf.n_unsuccessful_epochs_to_stop:
                print('Training stopped because the model\'s performance was not improved for %d epochs.' %
                      (epoch - best_val_epoch,))
                break

        self.after_train()

    # LOADING

    def load_state(self, epoch: Optional[int] = None) -> Tuple[int, int]:
        """Restore the training state saved in ``conf.save_dir``, if any.

        Creates ``conf.save_dir`` when missing. When ``<save_dir>/GeneralInfo``
        exists, the optimizer state, the best validation metric and the model
        weights of the selected epoch are loaded.

        :param epoch: The epoch to load from. If None, loads from the epoch set
            in the config (``conf.epoch``) and, if that is not set either, from
            the last trained epoch.
        :return: The number of the next epoch and the number of the epoch with
            the best val results (both 0 when nothing has been saved yet).
        """
        save_dir = self.conf.save_dir
        makedirs(save_dir, exist_ok=True)

        start_epoch = 0
        best_val_epoch = 0

        if path.exists(save_dir + '/GeneralInfo'):

            checkpoint = torch.load(save_dir + '/GeneralInfo')

            if epoch is not None:
                last_epoch = epoch
                best_val_epoch = epoch
            elif self.conf.epoch is not None:
                last_epoch = int(self.conf.epoch)
                best_val_epoch = last_epoch
            else:
                last_epoch = checkpoint['epoch']
                best_val_epoch = checkpoint['best_val_epoch']

            start_epoch = last_epoch + 1
            self.load_required_parameters(checkpoint, last_epoch)

        return start_epoch, best_val_epoch

    # PARAMETERS SETTING

    def _calcuate_iters_per_epoch(self):
        """Return the number of iterations per epoch.

        Uses ``conf.iters_per_epoch`` when set; otherwise enough iterations to
        visit ``classes_num * max_class_samples_num`` samples at least once,
        given ``batch_size * big_batch_size`` samples per iteration.
        """
        if self.conf.iters_per_epoch is None:
            iters_per_epoch = \
                int(np.ceil(
                    1.0 * self.train_loader.get_classes_num() *
                    self.train_loader.get_max_class_samples_num() /
                    (self.conf.batch_size * self.conf.big_batch_size)))

            print('Iterations per epoch is set to %d' % iters_per_epoch, flush=True)
            return iters_per_epoch

        return self.conf.iters_per_epoch

    # MIDDLE TRAINING SAVING AND LOADING

    def load_required_parameters(self, checkpoint, epoch):
        """Load the optimizer state, best val metric and model of ``epoch``.

        ``checkpoint`` is the dictionary loaded from ``<save_dir>/GeneralInfo``.
        Subclasses that save extra state should extend this method.
        """
        # loading the optimizer
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # for backward compatibility with checkpoints storing 'best_val_loss'
        if 'best_val_loss' in checkpoint:
            self.best_val_metric = checkpoint['best_val_loss']
        else:
            self.best_val_metric = checkpoint['best_val_metric']

        # loading the model
        model_checkpoint = torch.load(
            self._get_epoch_specific_save_dir(self.conf.save_dir, epoch))

        # for backward compatibility with files wrapping the state dict
        if 'model_state_dict' in model_checkpoint:
            model_checkpoint = model_checkpoint['model_state_dict']

        self.the_model.load_state_dict(model_checkpoint)

    def add_required_state_to_dictionary(self, save_dict):
        """Add the optimizer state and best val metric to ``save_dict``.

        Subclasses that need to persist more state should extend this method
        (and ``load_required_parameters`` accordingly).
        """
        save_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        save_dict['best_val_metric'] = self.best_val_metric

    # TRAIN PHASES

    def prepare_for_train(self, start_epoch):
        """Initialize weights from the pretrained model (fresh runs only) and freeze parameters."""
        if start_epoch == 0:
            self.the_model.init_weights_from_other_model(self.conf.pretrained_model_file)

        # freezing the parameters of the model!
        self.the_model.freeze_parameters(self.conf.freezing_regexes)

    def after_epoch_ended(self, epoch: int, best_val_epoch: int):
        """Hook called after training and validation of each epoch; no-op by default."""
        pass

    def after_train(self):
        """Does the final stuff, frees the memory allocated (drops the optimizer)."""
        del self.optimizer

    def train_model_for_one_epoch(self, epoch: int, iters_per_epoch: int) -> None:
        """Run ``ceil(iters_per_epoch)`` optimization steps on the training data.

        Each step accumulates gradients over ``conf.big_batch_size`` batches,
        each loss scaled by ``1 / big_batch_size``. A batch with a NaN loss is
        skipped (without backward) when ``conf.skip_nan_loss`` is set,
        otherwise an exception is raised. The optimizer step is skipped only
        when every batch of the big batch had a NaN loss.
        """
        self.train_evaluator.reset()

        # changing the model's status to training mode
        self.the_model.train()

        iters = int(np.ceil(iters_per_epoch))
        with self.dataflow, tqdm(total=iters) as pbar:
            dataflow_iteration = self.dataflow.iterate()

            for i in range(iters):

                self.optimizer.zero_grad()
                nan_losses_encountered = 0

                for _ in range(self.conf.big_batch_size):
                    model_output: Dict[str, torch.Tensor] = next(dataflow_iteration)
                    batch_loss = model_output['loss']

                    # `.any()` keeps the truth test valid for the non-scalar
                    # losses that are supported further below.
                    if torch.isnan(batch_loss).any():
                        nan_losses_encountered += 1
                        if not self.conf.skip_nan_loss:
                            raise Exception('Nan loss at epoch %d iteration %d' % (epoch, i))
                        print('Nan loss at epoch %d iteration %d' % (epoch, i))
                        # Release the autograd graph of the skipped batch. No
                        # backward() was run for it, so the gradients already
                        # accumulated from the other batches stay valid and
                        # must not be zeroed here.
                        for k in model_output:
                            model_output[k] = model_output[k].detach()
                        continue

                    batch_loss = (1.0 / self.conf.big_batch_size) * batch_loss

                    if len(batch_loss.shape) == 0:
                        batch_loss.backward()
                    else:
                        batch_loss.backward(torch.ones_like(batch_loss))

                    # Detaching outputs
                    for k in model_output:
                        model_output[k] = model_output[k].detach()

                    # updating model's prediction
                    self.train_evaluator.update_summaries_based_on_model_output(model_output)

                # Update only if at least one batch contributed gradients
                if nan_losses_encountered < self.conf.big_batch_size:
                    self.optimizer.step()
                self.optimizer.zero_grad()

                pbar.update(1)
                # printing iteration evaluations
                pbar.set_description(self.train_evaluator.get_evaluation_metrics())

        # printing the stats
        self.train_evaluator.print_evaluation_metrics('*** Epoch %d' % epoch)

    def evaluate_model_after_one_epoch(self, epoch: int, best_val_epoch: int) -> int:
        """Evaluate on the validation data and return the (possibly new) best epoch.

        The metric titled ``conf.title_of_reference_metric_to_choose_best_epoch``
        is compared with ``self.best_val_metric`` using
        ``conf.operator_to_decide_on_improvement_of_val_reference_metric``; on
        improvement (or at epoch 0) ``self.best_val_metric`` is updated.

        :return: ``epoch`` if it is the new best epoch, else ``best_val_epoch``.
            If the evaluator does not report the reference metric, a warning is
            printed and ``epoch`` is returned.
        """
        self.val_evaluator.reset()
        self.val_evaluator.evaluate(self.conf.val_iters_per_epoch)
        self.val_evaluator.print_evaluation_metrics('*** Epoch %d' % epoch)

        ref_title = self.conf.title_of_reference_metric_to_choose_best_epoch
        if self._val_ref_metric_index == -1:
            titles = self.val_evaluator.get_titles_of_evaluation_metrics()
            if ref_title not in titles:
                print(f'WARNING!!! Your evaluator has no metric {ref_title} to choose best epoch by that! Considering the last epoch instead!')
            else:
                self._val_ref_metric_index = titles.index(ref_title)

        if self._val_ref_metric_index == -1:
            # No reference metric available: the latest epoch counts as best.
            return epoch

        val_reference_metric = float(
            self.val_evaluator.get_values_of_evaluation_metrics()[self._val_ref_metric_index])
        if epoch == 0 or self._check_if_reference_metric_for_val_is_improved(
                self.best_val_metric, val_reference_metric):
            self.best_val_metric = val_reference_metric
            return epoch
        return best_val_epoch

    def save_log_file(self, epoch):
        """Append the evaluation metrics of ``epoch`` to ``<save_dir>/log.csv``.

        Writes the header row first if the file does not exist yet.
        WARNING: VAL AND TRAIN EVALUATORS SHOULD NOT BE RESET
        BEFORE CALLING THIS FUNCTION!!!
        """
        log_file = self.conf.save_dir + '/log.csv'
        write_header = not path.exists(log_file)

        # Plain CSV writing (no pandas) so it runs on minimal HPC setups.
        with open(log_file, 'a') as f:
            if write_header:
                f.write(','.join(
                    ['epoch'] +
                    ['train_%s' % x for x in self.train_evaluator.get_titles_of_evaluation_metrics()] +
                    ['val_%s' % x for x in self.val_evaluator.get_titles_of_evaluation_metrics()]) + '\n')

            # appending the information of the current epoch
            f.write(','.join(
                [str(epoch)] +
                [str(v) for v in self.train_evaluator.get_values_of_evaluation_metrics()] +
                [str(v) for v in self.val_evaluator.get_values_of_evaluation_metrics()]
            ) + '\n')

    def save_training_stats_after_epoch(self, epoch: int, best_val_epoch: int) -> None:
        """Save the general training state and the model trained in ``epoch``.

        ``<save_dir>/GeneralInfo`` receives ``epoch``, ``best_val_epoch`` and
        whatever ``add_required_state_to_dictionary`` adds; the model weights
        are written to ``<save_dir>/<epoch>``.
        """
        save_dir = self.conf.save_dir

        save_dict = {'epoch': epoch, 'best_val_epoch': best_val_epoch}
        self.add_required_state_to_dictionary(save_dict)
        torch.save(save_dict, save_dir + '/GeneralInfo')

        epoch_save_dir = self._get_epoch_specific_save_dir(save_dir, epoch)
        self.save_epoch_specific_information(epoch_save_dir)

    def save_epoch_specific_information(self, epoch_save_dir: str) -> None:
        """Save the model's state dict to ``epoch_save_dir``."""
        torch.save(self.the_model.state_dict(), epoch_save_dir)

    def eliminate_epoch_specific_information(self, epoch_save_dir: str) -> None:
        """Delete the file at ``epoch_save_dir``; print a warning if it is missing."""
        if path.exists(epoch_save_dir):
            # Truncate first so the space is freed even if the file goes to a
            # trash folder instead of being removed.
            with open(epoch_save_dir, 'w'):
                pass
            remove(epoch_save_dir)
        else:
            print(f'Warning: file {epoch_save_dir} was not found to be removed!')

    @staticmethod
    def _get_epoch_specific_save_dir(save_dir: str, epoch: int) -> str:
        """Return the path of the model file saved for ``epoch``."""
        return '%s/%d' % (save_dir, epoch)

    def _check_if_reference_metric_for_val_is_improved(self, old_val, new_val):
        """Return whether ``new_val`` improves on ``old_val`` per the configured operator.

        :raises Exception: if the operator is not one of ``<``, ``<=``, ``>``, ``>=``.
        """
        if self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '<=':
            return new_val <= old_val
        elif self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '<':
            return new_val < old_val
        elif self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '>=':
            return new_val >= old_val
        elif self.conf.operator_to_decide_on_improvement_of_val_reference_metric == '>':
            return new_val > old_val
        else:
            raise Exception(
                f'You can only have <, <=, >, >= for '
                f'operator_to_decide_on_improvement_of_val_reference_metric not'
                f' {self.conf.operator_to_decide_on_improvement_of_val_reference_metric}')
|