You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

trainer.py 6.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
from pykeen.evaluation import RankBasedEvaluator
from pykeen.metrics.ranking import HitsAtK, InverseHarmonicMeanRank
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from tools import (
    CheckpointManager,
    DeterministicRandomSampler,
    TensorBoardHandler,
    get_pretty_logger,
    set_seed,
)

if TYPE_CHECKING:
    from collections.abc import Iterator

    from data.kg_dataset import KGDataset
    from tools import CommonParams, TrainingParams

    from .model_trainers.model_trainer_base import ModelTrainerBase
# Module-wide logger; Trainer.__init__ additionally attaches a
# TensorBoardHandler to it for training runs.
logger = get_pretty_logger(__name__)

# Shared CPU device handle.
cpu_device = torch.device("cpu")
  22. class Trainer:
  23. def __init__(
  24. self,
  25. train_dataset: KGDataset,
  26. val_dataset: KGDataset | None,
  27. model_trainer: ModelTrainerBase,
  28. common_params: CommonParams,
  29. training_params: TrainingParams,
  30. device: torch.device = cpu_device,
  31. ) -> None:
  32. set_seed(training_params.seed)
  33. self.train_dataset = train_dataset
  34. self.val_dataset = val_dataset if val_dataset is not None else train_dataset
  35. self.model_trainer = model_trainer
  36. self.common_params = common_params
  37. self.training_params = training_params
  38. self.device = device
  39. self.train_loader = self.model_trainer.create_data_loader(
  40. triples_factory=self.train_dataset.triples_factory,
  41. batch_size=training_params.batch_size,
  42. shuffle=True,
  43. )
  44. self.train_iterator = self._data_iterator(self.train_loader)
  45. seed = 1234
  46. val_sampler = DeterministicRandomSampler(
  47. len(self.val_dataset),
  48. sample_size=self.training_params.validation_sample_size,
  49. seed=seed,
  50. )
  51. self.valid_loader = DataLoader(
  52. self.val_dataset,
  53. batch_size=training_params.validation_batch_size,
  54. shuffle=False,
  55. num_workers=training_params.num_workers,
  56. sampler=val_sampler,
  57. )
  58. self.valid_iterator = self._data_iterator(self.valid_loader)
  59. self.step = 0
  60. self.log_every = common_params.log_every
  61. self.log_console_every = common_params.log_console_every
  62. self.save_dpath = common_params.save_dpath
  63. self.save_every = common_params.save_every
  64. self.checkpoint_manager = CheckpointManager(
  65. root_directory=self.save_dpath,
  66. run_name=common_params.run_name,
  67. load_only=common_params.evaluate_only,
  68. )
  69. self.ckpt_dpath = self.checkpoint_manager.checkpoint_directory
  70. if not common_params.evaluate_only:
  71. logger.info(f"Saving checkpoints to {self.ckpt_dpath}")
  72. if self.common_params.load_path:
  73. self.load()
  74. self.writer = None
  75. if not common_params.evaluate_only:
  76. run_name = self.checkpoint_manager.run_name
  77. self.writer = SummaryWriter(log_dir=f"{common_params.log_dir}/{run_name}")
  78. logger.addHandler(TensorBoardHandler(self.writer))
  79. def log(self, tag: str, log_value: Any, console_msg: str | None = None) -> None:
  80. """
  81. Log a message to the console and TensorBoard.
  82. Args:
  83. message (str): The message to log.
  84. """
  85. if console_msg is not None:
  86. logger.info(f"{console_msg}")
  87. if self.writer is not None:
  88. self.writer.add_scalar(tag, log_value, self.step)
  89. def save(self) -> None:
  90. save_dpath = self.ckpt_dpath
  91. save_dpath = Path(save_dpath)
  92. save_dpath.mkdir(parents=True, exist_ok=True)
  93. self.model_trainer.save(save_dpath, self.step)
  94. def load(self) -> None:
  95. load_path = self.checkpoint_manager.get_model_fpath(self.common_params.load_path)
  96. load_fpath = Path(load_path)
  97. if load_fpath.exists():
  98. self.step = self.model_trainer.load(load_fpath)
  99. logger.info(f"Model loaded from {load_fpath}.")
  100. else:
  101. msg = f"Load path {load_fpath} does not exist."
  102. raise FileNotFoundError(msg)
  103. def _data_iterator(self, loader: DataLoader) -> iter:
  104. """Yield batches of positive and negative triples."""
  105. while True:
  106. yield from loader
  107. def train(self) -> None:
  108. """Run `n_steps` optimisation steps; return mean loss."""
  109. data_batch = next(self.train_iterator)
  110. loss = self.model_trainer.train_step(data_batch, self.step)
  111. if self.step % self.log_console_every == 0:
  112. logger.info(f"[train] step {self.step:6d} | loss = {loss:.4f}")
  113. if self.step % self.log_every == 0:
  114. self.log("train/loss", loss)
  115. self.step += 1
  116. def evaluate(self) -> dict[str, torch.Tensor]:
  117. ks = [1, 3, 10]
  118. metrics = [InverseHarmonicMeanRank()] + [HitsAtK(k) for k in ks]
  119. evaluator = RankBasedEvaluator(
  120. metrics=metrics,
  121. filtered=True,
  122. batch_size=self.training_params.validation_batch_size,
  123. )
  124. result = evaluator.evaluate(
  125. model=self.model_trainer.model,
  126. mapped_triples=self.val_dataset.triples,
  127. additional_filter_triples=[
  128. self.train_dataset.triples,
  129. self.val_dataset.triples,
  130. ],
  131. batch_size=self.training_params.validation_batch_size,
  132. )
  133. mrr = result.get_metric("both.realistic.inverse_harmonic_mean_rank")
  134. hits_at = {f"hits_at_{k}": result.get_metric(f"both.realistic.hits_at_{k}") for k in ks}
  135. metrics = {
  136. "mrr": mrr,
  137. **hits_at,
  138. }
  139. for k, v in metrics.items():
  140. self.log(f"valid/{k}", v, console_msg=f"[valid] {k:6s} = {v:.4f}")
  141. # ───────────────────────── #
  142. # master schedule
  143. # ───────────────────────── #
  144. def run(self) -> None:
  145. """
  146. Alternate training and evaluation until `total_steps`
  147. optimisation steps have been executed.
  148. """
  149. total_steps = self.training_params.num_train_steps
  150. eval_every = self.training_params.eval_every
  151. while self.step < total_steps:
  152. self.train()
  153. if self.step % eval_every == 0:
  154. logger.info(f"Evaluating at step {self.step}...")
  155. self.evaluate()
  156. if self.step % self.save_every == 0:
  157. logger.info(f"Saving model at step {self.step}...")
  158. self.save()
  159. logger.info("Training complete.")
  160. def reset(self) -> None:
  161. """
  162. Reset the trainer state, including the model trainer and data iterators.
  163. """
  164. self.model_trainer.reset()
  165. self.train_iterator = self._data_iterator(self.train_loader)
  166. self.valid_iterator = self._data_iterator(self.valid_loader)
  167. self.step = 0
  168. logger.info("Trainer state has been reset.")