from __future__ import annotations import re from pathlib import Path class CheckpointManager: """ A checkpoint manager that handles checkpoint directory creation and path resolution. Features: - Creates sequentially numbered checkpoint directories - Supports two loading modes: by components or by full path """ def __init__( self, root_directory: str | Path, run_name: str, load_only: bool = False, ) -> None: """ Initialize the checkpoint manager. Args: root_directory: Root directory for checkpoints run_name: Name of the run (used as suffix for checkpoint directories) """ self.root_directory = Path(root_directory) self.run_name = run_name self.checkpoint_directory = None if not load_only: self.root_directory.mkdir(parents=True, exist_ok=True) self.checkpoint_directory = self._create_checkpoint_directory() else: self.checkpoint_directory = "" def _find_existing_directories(self) -> list[int]: """ Find all existing directories with pattern xxx_ where xxx is a 3-digit number. Returns: List of existing sequence numbers """ pattern = re.compile(rf"^(\d{{3}})_{re.escape(self.run_name)}$") existing_numbers = [] if self.root_directory.exists(): for item in self.root_directory.iterdir(): if item.is_dir(): match = pattern.match(item.name) if match: existing_numbers.append(int(match.group(1))) return sorted(existing_numbers) def _create_checkpoint_directory(self) -> Path: """ Create a new checkpoint directory with the next sequential number. Returns: Path to the created checkpoint directory """ existing_numbers = self._find_existing_directories() # Determine the next number next_number = max(existing_numbers) + 1 if existing_numbers else 1 # Create directory name with 3-digit zero-padded number dir_name = f"{next_number:03d}_{self.run_name}" checkpoint_dir = self.root_directory / dir_name self.run_name = dir_name # Create the directory checkpoint_dir.mkdir(parents=True, exist_ok=True) return checkpoint_dir def get_checkpoint_directory(self) -> Path: """ Get the current checkpoint directory. Returns: Path to the checkpoint directory """ return self.checkpoint_directory def get_model_fpath(self, model_path: str) -> Path: """ Get the full path to a model checkpoint file. Args: model_path: Either a tuple (model_id, iteration) or a string path Returns: Full path to the model checkpoint file """ try: model_path = eval(model_path) if isinstance(model_path, tuple): model_id, iteration = model_path checkpoint_directory = f"{self.root_directory}/{model_id:03d}_{self.run_name}" filename = f"model-{iteration}.pt" return Path(f"{checkpoint_directory}/{filename}") except (SyntaxError, NameError): pass if isinstance(model_path, str): return Path(model_path) msg = "model_path must be a tuple (model_id, iteration) or a string path" raise ValueError(msg) def get_model_path_by_args( self, model_id: str | None = None, iteration: int | str | None = None, full_path: str | Path | None = None, ) -> Path: """ Get the path for loading a model checkpoint. Two modes of operation: 1. Component mode: Provide model_id and iteration to construct path 2. Full path mode: Provide full_path directly Args: model_id: Model identifier (used in component mode) iteration: Training iteration (used in component mode) full_path: Full path to checkpoint (used in full path mode) Returns: Path to the checkpoint file Raises: ValueError: If neither component parameters nor full_path are provided, or if both modes are attempted simultaneously """ # Check which mode we're operating in component_mode = model_id is not None or iteration is not None full_path_mode = full_path is not None if component_mode and full_path_mode: msg = ( "Cannot use both component mode (model_id/iteration) " "and full path mode simultaneously" ) raise ValueError(msg) if not component_mode and not full_path_mode: msg = "Must provide either (model_id and iteration) or full_path" raise ValueError(msg) if full_path_mode: # Full path mode: return the path as-is return Path(full_path) # Component mode: construct path from checkpoint directory, model_id, and iteration if model_id is None or iteration is None: msg = "Both model_id and iteration must be provided in component mode" raise ValueError(msg) filename = f"{model_id}_iter_{iteration}.pt" return self.checkpoint_directory / filename def save_checkpoint_info(self, info: dict) -> None: """ Save checkpoint information to a JSON file in the checkpoint directory. Args: info: Dictionary containing checkpoint metadata """ import json info_file = self.checkpoint_directory / "checkpoint_info.json" with open(info_file, "w") as f: json.dump(info, f, indent=2) def load_checkpoint_info(self) -> dict: """ Load checkpoint information from the checkpoint directory. Returns: Dictionary containing checkpoint metadata """ import json info_file = self.checkpoint_directory / "checkpoint_info.json" if info_file.exists(): with open(info_file) as f: return json.load(f) return {} def __str__(self) -> str: return ( f"CheckpointManager(root='{self.root_directory}', " f"run='{self.run_name}', ckpt_dir='{self.checkpoint_directory}')" ) def __repr__(self) -> str: return self.__str__() # Example usage: if __name__ == "__main__": import tempfile # Create checkpoint manager with proper temp directory with tempfile.TemporaryDirectory() as temp_dir: ckpt_manager = CheckpointManager(temp_dir, "my_experiment") print(f"Checkpoint directory: {ckpt_manager.get_checkpoint_directory()}") # Component mode - construct path from components model_path = ckpt_manager.get_model_path_by_args(model_id="transe", iteration=1000) print(f"Model path (component mode): {model_path}") # Full path mode - use existing full path full_path = "/some/other/path/model.pt" model_path = ckpt_manager.get_model_path_by_args(full_path=full_path) print(f"Model path (full path mode): {model_path}")