from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn

if TYPE_CHECKING:
    from torch.optim.lr_scheduler import LRScheduler


class ModelTrainerBase(ABC):
    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: LRScheduler | None = None,
        device: torch.device | str = "cpu",
    ) -> None:
        self.model = model.to(device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = torch.device(device)

    def save(self, save_dpath: str, step: int) -> None:
        """
        Save the model, optimizer, and (if present) scheduler state to the specified directory.

        Args:
            save_dpath (str): The directory path where the checkpoint will be saved.
            step (int): The current training step, stored in the checkpoint and used in the filename.
        """
        data = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "step": step,
        }
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        torch.save(
            data,
            f"{save_dpath}/model-{step}.pt",
        )

    def load(self, load_fpath: str) -> int:
        """
        Load the model, optimizer, and (if present) scheduler state from a checkpoint file.

        Args:
            load_fpath (str): The path of the checkpoint file to load.

        Returns:
            int: The training step stored in the checkpoint.
        """
        data = torch.load(load_fpath, map_location="cpu")
        model_state_dict = data["model"]
        self.model.load_state_dict(model_state_dict)
        # move everything to the correct device
        self.model.to(self.device)

        optimizer_state_dict = data["optimizer"]
        # optimizer state tensors are cast to the parameters' device on load
        self.optimizer.load_state_dict(optimizer_state_dict)

        if self.scheduler is not None and "scheduler" in data:
            self.scheduler.load_state_dict(data["scheduler"])

        return data["step"]

    @abstractmethod
    def _loss(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Compute the loss for a batch of data.

        Args:
            batch (torch.Tensor): A tensor containing a batch of data.

        Returns:
            torch.Tensor: The computed loss for the batch.
        """
        ...

    @abstractmethod
    def train_step(self, batch: torch.Tensor) -> float:
        """Run a single optimization step on `batch` and return the loss value."""
        ...

    @abstractmethod
    def eval_step(self, batch: torch.Tensor) -> float:
        """Compute the loss on `batch` without updating the model and return it."""
        ...
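

# --- Illustrative usage (sketch, not part of the original module) ---
# A minimal concrete trainer showing how the abstract hooks are meant to fit
# together. The class name `AutoencoderTrainer` and the reconstruction (MSE)
# loss are assumptions made for this example, chosen because `_loss` receives
# a single tensor batch.
class AutoencoderTrainer(ModelTrainerBase):
    def _loss(self, batch: torch.Tensor) -> torch.Tensor:
        # Reconstruction loss: the model is assumed to map the batch back onto itself.
        batch = batch.to(self.device)
        return nn.functional.mse_loss(self.model(batch), batch)

    def train_step(self, batch: torch.Tensor) -> float:
        self.model.train()
        self.optimizer.zero_grad()
        loss = self._loss(batch)
        loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        return loss.item()

    @torch.no_grad()
    def eval_step(self, batch: torch.Tensor) -> float:
        self.model.eval()
        return self._loss(batch).item()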