12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- from __future__ import annotations
-
- from abc import ABC, abstractmethod
- from typing import TYPE_CHECKING
-
- import torch
- from torch import nn
-
- if TYPE_CHECKING:
- from torch.optim.lr_scheduler import LRScheduler
-
-
class ModelTrainerBase(ABC):
    """Abstract base class for model trainers.

    Owns a model/optimizer (and optionally an LR scheduler), provides
    checkpoint save/load, and defers loss computation and the train/eval
    step logic to subclasses.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: LRScheduler | None = None,
        device: torch.device | str = "cpu",
    ) -> None:
        """
        Args:
            model: The model to train; moved to ``device`` immediately.
            optimizer: Optimizer already bound to ``model``'s parameters.
            scheduler: Optional learning-rate scheduler.
            device: Target device for the model (``torch.device`` or string
                such as ``"cpu"``/``"cuda"``).
        """
        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device

    def save(self, save_dpath: str, step: int) -> None:
        """
        Save the model, optimizer (and scheduler, if any) state to a
        checkpoint file named ``model-{step}.pt`` in the given directory.

        Args:
            save_dpath (str): Directory path where the checkpoint will be
                written. Assumed to already exist — TODO confirm callers
                create it.
            step (int): Training step, stored in the checkpoint and used in
                the file name.
        """
        data = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "step": step,
        }

        # Scheduler state is optional; only persist it when one is configured.
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()

        torch.save(
            data,
            f"{save_dpath}/model-{step}.pt",
        )

    def load(self, load_fpath: str) -> int:
        """
        Load model, optimizer, and (if present) scheduler state from a
        checkpoint file produced by :meth:`save`.

        Args:
            load_fpath (str): Path to the checkpoint file (not a directory).

        Returns:
            int: The training step stored in the checkpoint.
        """
        # Load onto CPU first, then move the model to the configured device.
        data = torch.load(load_fpath, map_location="cpu")
        model_state_dict = data["model"]
        self.model.load_state_dict(model_state_dict)
        self.model.to(self.device)  # move everything to the correct device
        optimizer_state_dict = data["optimizer"]
        self.optimizer.load_state_dict(optimizer_state_dict)
        # Only restore the scheduler when the checkpoint actually carries its
        # state — feeding an empty dict into load_state_dict would raise.
        if self.scheduler is not None and "scheduler" in data:
            self.scheduler.load_state_dict(data["scheduler"])

        return data["step"]

    @abstractmethod
    def _loss(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss for a batch of data.

        Args:
            batch (torch.Tensor): A tensor containing a batch of data.

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        ...

    @abstractmethod
    def train_step(self, batch: torch.Tensor) -> float:
        """Run one optimization step on ``batch`` and return the loss value."""
        ...

    @abstractmethod
    def eval_step(self, batch: torch.Tensor) -> float:
        """Evaluate ``batch`` without updating weights and return the loss value."""
        ...
|