You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

model_trainer_base.py 2.4KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from __future__ import annotations
  2. from abc import ABC, abstractmethod
  3. from typing import TYPE_CHECKING
  4. import torch
  5. from torch import nn
  6. if TYPE_CHECKING:
  7. from torch.optim.lr_scheduler import LRScheduler
  8. class ModelTrainerBase(ABC):
  9. def __init__(
  10. self,
  11. model: nn.Module,
  12. optimizer: torch.optim.Optimizer,
  13. scheduler: LRScheduler | None = None,
  14. device: torch.device = "cpu",
  15. ) -> None:
  16. self.model = model.to(device)
  17. self.optimizer = optimizer
  18. self.scheduler = scheduler
  19. self.device = device
  20. def save(self, save_dpath: str, step: int) -> None:
  21. """
  22. Save the model and optimizer state to the specified directory.
  23. Args:
  24. save_dpath (str): The directory path where the model and optimizer state will be saved.
  25. """
  26. data = {
  27. "model": self.model.state_dict(),
  28. "optimizer": self.optimizer.state_dict(),
  29. "step": step,
  30. }
  31. if self.scheduler is not None:
  32. data["scheduler"] = self.scheduler.state_dict()
  33. torch.save(
  34. data,
  35. f"{save_dpath}/model-{step}.pt",
  36. )
  37. def load(self, load_fpath: str) -> None:
  38. """
  39. Load the model and optimizer state from the specified directory.
  40. Args:
  41. load_dpath (str): The directory path from which the model and
  42. optimizer state will be loaded.
  43. """
  44. data = torch.load(load_fpath, map_location="cpu")
  45. model_state_dict = data["model"]
  46. self.model.load_state_dict(model_state_dict) # move everything to the correct device
  47. self.model.to(self.device)
  48. optimizer_state_dict = data["optimizer"]
  49. self.optimizer.load_state_dict(optimizer_state_dict)
  50. if self.scheduler is not None:
  51. scheduler_data = data.get("scheduler", {})
  52. self.scheduler.load_state_dict(scheduler_data)
  53. return data["step"]
  54. @abstractmethod
  55. def _loss(self, batch: torch.Tensor) -> torch.Tensor:
  56. """
  57. Compute the loss for a batch of data.
  58. Args:
  59. batch (torch.Tensor): A tensor containing a batch of data.
  60. Returns:
  61. torch.Tensor: The computed loss for the batch.
  62. """
  63. ...
  64. @abstractmethod
  65. def train_step(self, batch: torch.Tensor) -> float: ...
  66. @abstractmethod
  67. def eval_step(self, batch: torch.Tensor) -> float: ...