"""Step-based Trainer that alternates training, evaluation, and checkpointing for PyKEEN-backed models."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
from pykeen.evaluation import RankBasedEvaluator
from pykeen.metrics.ranking import HitsAtK, InverseHarmonicMeanRank
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from tools import (
    CheckpointManager,
    DeterministicRandomSampler,
    TensorBoardHandler,
    get_pretty_logger,
    set_seed,
)

if TYPE_CHECKING:
    from collections.abc import Iterator

    from data.kg_dataset import KGDataset
    from tools import CommonParams, TrainingParams

    from .model_trainers.model_trainer_base import ModelTrainerBase

logger = get_pretty_logger(__name__)
cpu_device = torch.device("cpu")


class Trainer:
    def __init__(
        self,
        train_dataset: KGDataset,
        val_dataset: KGDataset | None,
        model_trainer: ModelTrainerBase,
        common_params: CommonParams,
        training_params: TrainingParams,
        device: torch.device = cpu_device,
    ) -> None:
        set_seed(training_params.seed)

        self.train_dataset = train_dataset
        # Fall back to the training split when no validation split is provided.
        self.val_dataset = val_dataset if val_dataset is not None else train_dataset
        self.model_trainer = model_trainer
        self.common_params = common_params
        self.training_params = training_params
        self.device = device

        self.train_loader = self.model_trainer.create_data_loader(
            triples_factory=self.train_dataset.triples_factory,
            batch_size=training_params.batch_size,
            shuffle=True,
        )
        self.train_iterator = self._data_iterator(self.train_loader)

        # Fixed seed so the validation subsample is identical across runs.
        seed = 1234
        val_sampler = DeterministicRandomSampler(
            len(self.val_dataset),
            sample_size=self.training_params.validation_sample_size,
            seed=seed,
        )
        self.valid_loader = DataLoader(
            self.val_dataset,
            batch_size=training_params.validation_batch_size,
            shuffle=False,
            num_workers=training_params.num_workers,
            sampler=val_sampler,
        )
        self.valid_iterator = self._data_iterator(self.valid_loader)

        self.step = 0
        self.log_every = common_params.log_every
        self.log_console_every = common_params.log_console_every
        self.save_dpath = common_params.save_dpath
        self.save_every = common_params.save_every

        self.checkpoint_manager = CheckpointManager(
            root_directory=self.save_dpath,
            run_name=common_params.run_name,
            load_only=common_params.evaluate_only,
        )
        self.ckpt_dpath = self.checkpoint_manager.checkpoint_directory
        if not common_params.evaluate_only:
            logger.info(f"Saving checkpoints to {self.ckpt_dpath}")

        if self.common_params.load_path:
            self.load()

        self.writer = None
        if not common_params.evaluate_only:
            run_name = self.checkpoint_manager.run_name
            self.writer = SummaryWriter(log_dir=f"{common_params.log_dir}/{run_name}")
            logger.addHandler(TensorBoardHandler(self.writer))

    def log(self, tag: str, log_value: Any, console_msg: str | None = None) -> None:
        """
        Log a scalar to TensorBoard and, optionally, a message to the console.

        Args:
            tag (str): TensorBoard tag under which the scalar is recorded.
            log_value (Any): Scalar value recorded at the current step.
            console_msg (str | None): Optional message to also emit via the logger.
        """
        if console_msg is not None:
            logger.info(console_msg)
        if self.writer is not None:
            self.writer.add_scalar(tag, log_value, self.step)

    def save(self) -> None:
        save_dpath = Path(self.ckpt_dpath)
        save_dpath.mkdir(parents=True, exist_ok=True)
        self.model_trainer.save(save_dpath, self.step)

    def load(self) -> None:
        load_path = self.checkpoint_manager.get_model_fpath(self.common_params.load_path)
        load_fpath = Path(load_path)
        if load_fpath.exists():
            self.step = self.model_trainer.load(load_fpath)
            logger.info(f"Model loaded from {load_fpath}.")
        else:
            msg = f"Load path {load_fpath} does not exist."
            raise FileNotFoundError(msg)

    def _data_iterator(self, loader: DataLoader) -> Iterator[Any]:
        """Yield batches of positive and negative triples indefinitely, restarting the loader when exhausted."""
        while True:
            yield from loader

    def train(self) -> None:
        """Run a single optimisation step and log the resulting loss."""
        data_batch = next(self.train_iterator)
        loss = self.model_trainer.train_step(data_batch, self.step)

        if self.step % self.log_console_every == 0:
            logger.info(f"[train] step {self.step:6d} | loss = {loss:.4f}")
        if self.step % self.log_every == 0:
            self.log("train/loss", loss)

        self.step += 1

    def evaluate(self) -> dict[str, float]:
        """Run filtered rank-based evaluation (MRR and Hits@k) on the validation split."""
        ks = [1, 3, 10]
        rank_metrics = [InverseHarmonicMeanRank()] + [HitsAtK(k) for k in ks]
        evaluator = RankBasedEvaluator(
            metrics=rank_metrics,
            filtered=True,
            batch_size=self.training_params.validation_batch_size,
        )

        result = evaluator.evaluate(
            model=self.model_trainer.model,
            mapped_triples=self.val_dataset.triples,
            additional_filter_triples=[
                self.train_dataset.triples,
                self.val_dataset.triples,
            ],
            batch_size=self.training_params.validation_batch_size,
        )

        mrr = result.get_metric("both.realistic.inverse_harmonic_mean_rank")
        hits_at = {f"hits_at_{k}": result.get_metric(f"both.realistic.hits_at_{k}") for k in ks}
        metrics = {
            "mrr": mrr,
            **hits_at,
        }
        for k, v in metrics.items():
            self.log(f"valid/{k}", v, console_msg=f"[valid] {k:6s} = {v:.4f}")
        return metrics

    # ───────────────────────── #
    # master schedule
    # ───────────────────────── #
    def run(self) -> None:
        """
        Alternate training and evaluation until `total_steps` optimisation
        steps have been executed.
        """
        total_steps = self.training_params.num_train_steps
        eval_every = self.training_params.eval_every

        while self.step < total_steps:
            self.train()

            if self.step % eval_every == 0:
                logger.info(f"Evaluating at step {self.step}...")
                self.evaluate()

            if self.step % self.save_every == 0:
                logger.info(f"Saving model at step {self.step}...")
                self.save()

        logger.info("Training complete.")

    def reset(self) -> None:
        """
        Reset the trainer state, including the model trainer and data iterators.
        """
        self.model_trainer.reset()
        self.train_iterator = self._data_iterator(self.train_loader)
        self.valid_iterator = self._data_iterator(self.valid_loader)
        self.step = 0
        logger.info("Trainer state has been reset.")
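

# --------------------------------------------------------------------------- #
# Usage sketch (illustrative only, not part of the module's API): shows how a
# Trainer is wired together and driven by `run()`.  The dataset and model
# trainer objects (`train_ds`, `val_ds`, `my_model_trainer`) and the exact
# constructor signatures of `CommonParams` / `TrainingParams` are assumptions;
# only the attributes read by this module are known from the code above.
#
#   common_params = CommonParams(...)      # run_name, save_dpath, log_dir,
#                                          # log_every, log_console_every,
#                                          # save_every, evaluate_only, load_path
#   training_params = TrainingParams(...)  # seed, batch_size, validation_batch_size,
#                                          # validation_sample_size, num_workers,
#                                          # num_train_steps, eval_every
#   trainer = Trainer(
#       train_dataset=train_ds,
#       val_dataset=val_ds,
#       model_trainer=my_model_trainer,
#       common_params=common_params,
#       training_params=training_params,
#       device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#   )
#   trainer.run()       # train / evaluate / checkpoint until num_train_steps is reached
#   trainer.evaluate()  # or run a standalone filtered-ranking evaluation
# --------------------------------------------------------------------------- #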