123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- from __future__ import annotations
-
- import re
- from pathlib import Path
-
-
class CheckpointManager:
    """
    A checkpoint manager that handles checkpoint directory creation and path resolution.

    Features:
    - Creates sequentially numbered checkpoint directories (``001_<run_name>``, ...)
    - Supports two loading modes: by components or by full path
    """

    def __init__(
        self,
        root_directory: str | Path,
        run_name: str,
        load_only: bool = False,
    ) -> None:
        """
        Initialize the checkpoint manager.

        Args:
            root_directory: Root directory for checkpoints
            run_name: Name of the run (used as suffix for checkpoint directories)
            load_only: When True, nothing is created on disk and
                ``checkpoint_directory`` stays an empty string; only the
                path-resolution helpers that do not need a run directory
                are usable.
        """
        self.root_directory = Path(root_directory)
        # The caller-supplied name. ``run_name`` is overwritten with the numbered
        # directory name once a directory is created, so the original is kept
        # separately for building and matching ``NNN_<run_name>`` directory names.
        self._base_run_name = run_name
        self.run_name = run_name
        # Path of this run's directory, or "" in load-only mode.
        self.checkpoint_directory = ""

        if not load_only:
            self.root_directory.mkdir(parents=True, exist_ok=True)
            self.checkpoint_directory = self._create_checkpoint_directory()

    def _find_existing_directories(self) -> list[int]:
        """
        Find all existing directories named ``<number>_<run_name>``.

        The number has at least three digits; longer numbers are accepted so
        that numbering keeps advancing after sequence number 999.

        Returns:
            Sorted list of existing sequence numbers
        """
        pattern = re.compile(rf"(\d{{3,}})_{re.escape(self._base_run_name)}")

        if not self.root_directory.exists():
            return []

        existing_numbers = []
        for item in self.root_directory.iterdir():
            match = pattern.fullmatch(item.name)
            if match and item.is_dir():
                existing_numbers.append(int(match.group(1)))

        return sorted(existing_numbers)

    def _create_checkpoint_directory(self) -> Path:
        """
        Create a new checkpoint directory with the next sequential number.

        Side effect: ``self.run_name`` is replaced by the numbered directory name.

        Returns:
            Path to the created checkpoint directory
        """
        existing_numbers = self._find_existing_directories()
        next_number = existing_numbers[-1] + 1 if existing_numbers else 1

        # Zero-padded to at least three digits: 001_<run_name>, 002_<run_name>, ...
        dir_name = f"{next_number:03d}_{self._base_run_name}"
        checkpoint_dir = self.root_directory / dir_name
        # Kept for backward compatibility: callers may read the numbered name
        # from ``run_name`` after construction.
        self.run_name = dir_name

        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        return checkpoint_dir

    def _require_checkpoint_directory(self) -> Path:
        """
        Return the checkpoint directory as a Path, or fail with a clear error.

        Raises:
            ValueError: If no checkpoint directory is set (load-only mode)
        """
        directory = self.checkpoint_directory
        if directory is None or directory == "":
            msg = "No checkpoint directory is available (manager was created with load_only=True)"
            raise ValueError(msg)
        return Path(directory)

    def get_checkpoint_directory(self) -> Path | str:
        """
        Get the current checkpoint directory.

        Returns:
            Path to the checkpoint directory, or an empty string in load-only mode
        """
        return self.checkpoint_directory

    def get_model_fpath(self, model_path: str | Path | tuple) -> Path:
        """
        Get the full path to a model checkpoint file.

        Args:
            model_path: One of
                - a tuple ``(model_id, iteration)``,
                - a string holding such a tuple literal, e.g. ``"(3, 1000)"``,
                - a string or Path pointing at the checkpoint file.

        Returns:
            ``<root>/<model_id:03d>_<run_name>/model-<iteration>.pt`` for the
            tuple forms, otherwise ``model_path`` itself as a Path

        Raises:
            ValueError: If ``model_path`` is none of the accepted forms, or the
                tuple is not ``(int, iteration)``
        """
        import ast

        msg = "model_path must be a tuple (model_id, iteration) or a string path"

        spec = model_path
        if isinstance(model_path, str):
            # NOTE(security): this used to call eval() on the incoming string,
            # which executes arbitrary code. literal_eval only parses Python
            # literals; anything that is not a literal is treated as a path.
            try:
                spec = ast.literal_eval(model_path)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                spec = model_path

        if isinstance(spec, tuple):
            if len(spec) != 2 or not isinstance(spec[0], int):
                raise ValueError(msg)
            model_id, iteration = spec
            # Use the base run name: ``self.run_name`` may already carry this
            # manager's own numeric prefix.
            run_directory = self.root_directory / f"{model_id:03d}_{self._base_run_name}"
            return run_directory / f"model-{iteration}.pt"

        # Judge by the original argument so a path that merely looks like a
        # literal (e.g. "123") is still returned as a path.
        if isinstance(model_path, (str, Path)):
            return Path(model_path)
        raise ValueError(msg)

    def get_model_path_by_args(
        self,
        model_id: str | None = None,
        iteration: int | str | None = None,
        full_path: str | Path | None = None,
    ) -> Path:
        """
        Get the path for loading a model checkpoint.

        Two modes of operation:
        1. Component mode: Provide model_id and iteration to construct
           ``<checkpoint_directory>/<model_id>_iter_<iteration>.pt``
        2. Full path mode: Provide full_path directly

        Args:
            model_id: Model identifier (used in component mode)
            iteration: Training iteration (used in component mode)
            full_path: Full path to checkpoint (used in full path mode)

        Returns:
            Path to the checkpoint file

        Raises:
            ValueError: If neither component parameters nor full_path are provided,
                if both modes are attempted simultaneously, if only one of the
                component parameters is given, or if component mode is used
                without a checkpoint directory (load-only mode)
        """
        component_mode = model_id is not None or iteration is not None
        full_path_mode = full_path is not None

        if component_mode and full_path_mode:
            msg = (
                "Cannot use both component mode (model_id/iteration) "
                "and full path mode simultaneously"
            )
            raise ValueError(msg)

        if not component_mode and not full_path_mode:
            msg = "Must provide either (model_id and iteration) or full_path"
            raise ValueError(msg)

        if full_path_mode:
            # Full path mode: return the path as-is
            return Path(full_path)

        # Component mode needs both parts of the file name.
        if model_id is None or iteration is None:
            msg = "Both model_id and iteration must be provided in component mode"
            raise ValueError(msg)

        filename = f"{model_id}_iter_{iteration}.pt"
        return self._require_checkpoint_directory() / filename

    def save_checkpoint_info(self, info: dict) -> None:
        """
        Save checkpoint information to ``checkpoint_info.json`` in the checkpoint directory.

        Args:
            info: Dictionary containing checkpoint metadata (must be JSON-serializable)

        Raises:
            ValueError: If no checkpoint directory is set (load-only mode)
        """
        import json

        info_file = self._require_checkpoint_directory() / "checkpoint_info.json"
        with open(info_file, "w", encoding="utf-8") as f:
            json.dump(info, f, indent=2)

    def load_checkpoint_info(self) -> dict:
        """
        Load checkpoint information from ``checkpoint_info.json`` in the checkpoint directory.

        Returns:
            Dictionary containing checkpoint metadata, or an empty dict when
            the file does not exist

        Raises:
            ValueError: If no checkpoint directory is set (load-only mode)
        """
        import json

        info_file = self._require_checkpoint_directory() / "checkpoint_info.json"
        if info_file.exists():
            with open(info_file, encoding="utf-8") as f:
                return json.load(f)
        return {}

    def __str__(self) -> str:
        return (
            f"CheckpointManager(root='{self.root_directory}', "
            f"run='{self.run_name}', ckpt_dir='{self.checkpoint_directory}')"
        )

    def __repr__(self) -> str:
        return self.__str__()
-
-
# Demo: run this module directly to see both path-resolution modes.
if __name__ == "__main__":
    import tempfile

    # A throwaway root keeps the demo from leaving directories behind.
    with tempfile.TemporaryDirectory() as scratch_root:
        manager = CheckpointManager(scratch_root, "my_experiment")
        run_directory = manager.get_checkpoint_directory()
        print(f"Checkpoint directory: {run_directory}")

        # Component mode: the file name is built from model_id and iteration.
        by_components = manager.get_model_path_by_args(model_id="transe", iteration=1000)
        print(f"Model path (component mode): {by_components}")

        # Full path mode: the supplied path is handed back as a Path.
        by_full_path = manager.get_model_path_by_args(full_path="/some/other/path/model.pt")
        print(f"Model path (full path mode): {by_full_path}")
|