from __future__ import annotations

from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
from pykeen.evaluation import RankBasedEvaluator
from pykeen.metrics.ranking import HitsAtK, InverseHarmonicMeanRank
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from tools import (
    CheckpointManager,
    DeterministicRandomSampler,
    TensorBoardHandler,
    get_pretty_logger,
    set_seed,
)

if TYPE_CHECKING:
    from data.kg_dataset import KGDataset
    from tools import CommonParams, TrainingParams

    from .model_trainers.model_trainer_base import ModelTrainerBase


logger = get_pretty_logger(__name__)


cpu_device = torch.device("cpu")


class Trainer:
    def __init__(
        self,
        train_dataset: KGDataset,
        val_dataset: KGDataset | None,
        model_trainer: ModelTrainerBase,
        common_params: CommonParams,
        training_params: TrainingParams,
        device: torch.device = cpu_device,
    ) -> None:
        set_seed(training_params.seed)

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset if val_dataset is not None else train_dataset
        self.model_trainer = model_trainer

        self.common_params = common_params
        self.training_params = training_params

        self.device = device

        self.train_loader = self.model_trainer.create_data_loader(
            triples_factory=self.train_dataset.triples_factory,
            batch_size=training_params.batch_size,
            shuffle=True,
        )

        self.train_iterator = self._data_iterator(self.train_loader)

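        # Fixed seed (independent of training_params.seed) so that the validation
        # loader draws the same deterministic subset of the validation set on
        # every run, keeping validation metrics comparable across runs.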
        seed = 1234
        val_sampler = DeterministicRandomSampler(
            len(self.val_dataset),
            sample_size=self.training_params.validation_sample_size,
            seed=seed,
        )
        self.valid_loader = DataLoader(
            self.val_dataset,
            batch_size=training_params.validation_batch_size,
            shuffle=False,
            num_workers=training_params.num_workers,
            sampler=val_sampler,
        )
        self.valid_iterator = self._data_iterator(self.valid_loader)

        self.step = 0
        self.log_every = common_params.log_every
        self.log_console_every = common_params.log_console_every

        self.save_dpath = common_params.save_dpath
        self.save_every = common_params.save_every

        self.checkpoint_manager = CheckpointManager(
            root_directory=self.save_dpath,
            run_name=common_params.run_name,
            load_only=common_params.evaluate_only,
        )
        self.ckpt_dpath = self.checkpoint_manager.checkpoint_directory
        if not common_params.evaluate_only:
            logger.info(f"Saving checkpoints to {self.ckpt_dpath}")

        if self.common_params.load_path:
            self.load()

        self.writer = None
        if not common_params.evaluate_only:
            run_name = self.checkpoint_manager.run_name
            self.writer = SummaryWriter(log_dir=f"{common_params.log_dir}/{run_name}")
            logger.addHandler(TensorBoardHandler(self.writer))

    def log(self, tag: str, log_value: Any, console_msg: str | None = None) -> None:
        """
        Log a scalar to TensorBoard and optionally print a message to the console.

        Args:
            tag (str): TensorBoard tag under which the value is recorded.
            log_value (Any): Scalar value to record at the current step.
            console_msg (str | None): Optional message to also log to the console.
        """
        if console_msg is not None:
            logger.info(console_msg)
        if self.writer is not None:
            self.writer.add_scalar(tag, log_value, self.step)

    def save(self) -> None:
        save_dpath = Path(self.ckpt_dpath)
        save_dpath.mkdir(parents=True, exist_ok=True)
        self.model_trainer.save(save_dpath, self.step)

    def load(self) -> None:
        load_path = self.checkpoint_manager.get_model_fpath(self.common_params.load_path)
        load_fpath = Path(load_path)

        if load_fpath.exists():
            self.step = self.model_trainer.load(load_fpath)
            logger.info(f"Model loaded from {load_fpath}.")
        else:
            msg = f"Load path {load_fpath} does not exist."
            raise FileNotFoundError(msg)

    def _data_iterator(self, loader: DataLoader) -> Iterator:
        """Cycle over the loader indefinitely, yielding batches of triples."""
        while True:
            yield from loader

    def train(self) -> None:
        """Run a single optimisation step and log the resulting loss."""
        data_batch = next(self.train_iterator)

        loss = self.model_trainer.train_step(data_batch, self.step)

        if self.step % self.log_console_every == 0:
            logger.info(f"[train] step {self.step:6d} | loss = {loss:.4f}")
        if self.step % self.log_every == 0:
            self.log("train/loss", loss)

        self.step += 1

    def evaluate(self) -> dict[str, float]:
        ks = [1, 3, 10]

        ranking_metrics = [InverseHarmonicMeanRank()] + [HitsAtK(k) for k in ks]

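        # "Filtered" ranking: candidate triples that are known to be true (i.e. appear
        # in the filter triples passed below) are excluded when computing ranks,
        # following the standard filtered-evaluation protocol for KG embeddings.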
        evaluator = RankBasedEvaluator(
            metrics=ranking_metrics,
            filtered=True,
            batch_size=self.training_params.validation_batch_size,
        )

        result = evaluator.evaluate(
            model=self.model_trainer.model,
            mapped_triples=self.val_dataset.triples,
            additional_filter_triples=[
                self.train_dataset.triples,
                self.val_dataset.triples,
            ],
            batch_size=self.training_params.validation_batch_size,
        )

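        # PyKEEN metric keys follow "<side>.<rank_type>.<metric>": "both" averages
        # head- and tail-prediction ranks, "realistic" uses the expected rank over ties.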
        mrr = result.get_metric("both.realistic.inverse_harmonic_mean_rank")
        hits_at = {f"hits_at_{k}": result.get_metric(f"both.realistic.hits_at_{k}") for k in ks}

        metrics = {
            "mrr": mrr,
            **hits_at,
        }

        for k, v in metrics.items():
            self.log(f"valid/{k}", v, console_msg=f"[valid] {k:6s} = {v:.4f}")

        return metrics

    # ───────────────────────── #
    #      master schedule      #
    # ───────────────────────── #
    def run(self) -> None:
        """
        Alternate training and evaluation until `total_steps`
        optimisation steps have been executed.
        """
        total_steps = self.training_params.num_train_steps
        eval_every = self.training_params.eval_every

        while self.step < total_steps:
            self.train()
            if self.step % eval_every == 0:
                logger.info(f"Evaluating at step {self.step}...")
                self.evaluate()

            if self.step % self.save_every == 0:
                logger.info(f"Saving model at step {self.step}...")
                self.save()

        logger.info("Training complete.")

    def reset(self) -> None:
        """
        Reset the trainer state, including the model trainer and data iterators.
        """
        self.model_trainer.reset()
        self.train_iterator = self._data_iterator(self.train_loader)
        self.valid_iterator = self._data_iterator(self.valid_loader)
        self.step = 0
        logger.info("Trainer state has been reset.")