You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

checkpoint_manager.py 7.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. from __future__ import annotations
  2. import re
  3. from pathlib import Path
  class CheckpointManager:
      """
      A checkpoint manager that handles checkpoint directory creation and path resolution.

      Features:
      - Creates sequentially numbered checkpoint directories (``001_<run_name>``,
        ``002_<run_name>``, ...) below a root directory
      - Supports two loading modes: by components or by full path

      Attributes set by ``__init__``:
      - ``root_directory``: root directory as a ``Path``
      - ``run_name``: the run name; after a directory has been created this holds
        the numbered directory name (e.g. ``001_<run_name>``)
      - ``checkpoint_directory``: ``Path`` of the created directory, or ``""`` when
        the manager was created with ``load_only=True``
      """

      def __init__(
          self,
          root_directory: str | Path,
          run_name: str,
          load_only: bool = False,
      ) -> None:
          """
          Initialize the checkpoint manager.

          Args:
              root_directory: Root directory for checkpoints
              run_name: Name of the run (used as suffix for checkpoint directories)
              load_only: If True, nothing is created on disk and
                  ``checkpoint_directory`` stays ``""``; otherwise the root
                  directory and a new numbered checkpoint directory are created
          """
          self.root_directory = Path(root_directory)
          # The un-numbered name is kept separately because ``run_name`` is
          # overwritten with the numbered directory name once a directory is
          # created; path construction must always start from the base name.
          self._base_run_name = run_name
          self.run_name = run_name
          self.checkpoint_directory: Path | str = ""
          if not load_only:
              self.root_directory.mkdir(parents=True, exist_ok=True)
              self.checkpoint_directory = self._create_checkpoint_directory()

      def _find_existing_directories(self) -> list[int]:
          """
          Find all existing directories named ``<number>_<run_name>``.

          The number has at least three digits; more are accepted so that
          directories beyond ``999`` are still recognised.

          Returns:
              Sorted list of existing sequence numbers
          """
          if not self.root_directory.exists():
              return []
          pattern = re.compile(rf"(\d{{3,}})_{re.escape(self._base_run_name)}")
          numbers = []
          for item in self.root_directory.iterdir():
              match = pattern.fullmatch(item.name)
              if match and item.is_dir():
                  numbers.append(int(match.group(1)))
          return sorted(numbers)

      def _create_checkpoint_directory(self) -> Path:
          """
          Create a new checkpoint directory with the next sequential number.

          Side effect: ``self.run_name`` is replaced by the numbered directory name.

          Returns:
              Path to the created checkpoint directory
          """
          existing_numbers = self._find_existing_directories()
          # Continue after the highest number in use; start at 1 for a fresh root.
          next_number = existing_numbers[-1] + 1 if existing_numbers else 1
          # Directory name uses a zero-padded number of at least three digits.
          dir_name = f"{next_number:03d}_{self._base_run_name}"
          checkpoint_dir = self.root_directory / dir_name
          self.run_name = dir_name
          checkpoint_dir.mkdir(parents=True, exist_ok=True)
          return checkpoint_dir

      def _require_checkpoint_directory(self) -> Path:
          """
          Return the checkpoint directory as a ``Path``.

          Raises:
              RuntimeError: If no checkpoint directory is set (load-only manager)
          """
          if not self.checkpoint_directory:
              msg = (
                  "No checkpoint directory is available "
                  "(the manager was created with load_only=True)"
              )
              raise RuntimeError(msg)
          return Path(self.checkpoint_directory)

      def get_checkpoint_directory(self) -> Path | str:
          """
          Get the current checkpoint directory.

          Returns:
              Path to the checkpoint directory, or ``""`` for a load-only manager
          """
          return self.checkpoint_directory

      def get_model_fpath(self, model_path: str) -> Path:
          """
          Get the full path to a model checkpoint file.

          Args:
              model_path: Either a string holding a tuple literal
                  ``"(model_id, iteration)"`` or a string path. An actual tuple
                  or a ``Path`` object is accepted as well.

          Returns:
              ``<root>/<model_id:03d>_<run_name>/model-<iteration>.pt`` for a tuple,
              otherwise the given path

          Raises:
              ValueError: If model_path is neither a 2-tuple nor a path
          """
          import ast

          msg = "model_path must be a tuple (model_id, iteration) or a string path"
          spec = model_path
          if isinstance(spec, str):
              # SECURITY: this used to call eval() on the argument, which runs
              # arbitrary code. literal_eval only accepts Python literals.
              try:
                  parsed = ast.literal_eval(spec)
              except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                  parsed = None
              if not isinstance(parsed, tuple):
                  # Anything that is not a tuple literal is taken as a path.
                  return Path(spec)
              spec = parsed
          if isinstance(spec, tuple):
              if len(spec) != 2:
                  raise ValueError(msg)
              model_id, iteration = spec
              # Built from the base run name, not ``self.run_name``, which may
              # already carry a numeric prefix.
              directory = self.root_directory / f"{model_id:03d}_{self._base_run_name}"
              return directory / f"model-{iteration}.pt"
          if isinstance(spec, Path):
              return spec
          raise ValueError(msg)

      def get_model_path_by_args(
          self,
          model_id: str | None = None,
          iteration: int | str | None = None,
          full_path: str | Path | None = None,
      ) -> Path:
          """
          Get the path for loading a model checkpoint.

          Two modes of operation:
          1. Component mode: Provide model_id and iteration to construct
             ``<checkpoint_directory>/<model_id>_iter_<iteration>.pt``
          2. Full path mode: Provide full_path directly

          Args:
              model_id: Model identifier (used in component mode)
              iteration: Training iteration (used in component mode)
              full_path: Full path to checkpoint (used in full path mode)

          Returns:
              Path to the checkpoint file

          Raises:
              ValueError: If neither component parameters nor full_path are provided,
                  if both modes are attempted simultaneously, or if only one of
                  model_id / iteration is given
              RuntimeError: In component mode when no checkpoint directory is set
          """
          component_mode = model_id is not None or iteration is not None
          full_path_mode = full_path is not None

          if component_mode and full_path_mode:
              msg = (
                  "Cannot use both component mode (model_id/iteration) "
                  "and full path mode simultaneously"
              )
              raise ValueError(msg)
          if not component_mode and not full_path_mode:
              msg = "Must provide either (model_id and iteration) or full_path"
              raise ValueError(msg)

          if full_path_mode:
              # Full path mode: return the path as-is
              return Path(full_path)

          # Component mode needs both parts.
          if model_id is None or iteration is None:
              msg = "Both model_id and iteration must be provided in component mode"
              raise ValueError(msg)
          filename = f"{model_id}_iter_{iteration}.pt"
          return self._require_checkpoint_directory() / filename

      def save_checkpoint_info(self, info: dict) -> None:
          """
          Save checkpoint information to ``checkpoint_info.json`` in the checkpoint directory.

          Args:
              info: Dictionary containing checkpoint metadata

          Raises:
              RuntimeError: If no checkpoint directory is set
          """
          import json

          info_file = self._require_checkpoint_directory() / "checkpoint_info.json"
          with info_file.open("w", encoding="utf-8") as f:
              json.dump(info, f, indent=2)

      def load_checkpoint_info(self) -> dict:
          """
          Load checkpoint information from ``checkpoint_info.json`` in the checkpoint directory.

          Returns:
              Dictionary containing checkpoint metadata, or an empty dict when
              the file does not exist

          Raises:
              RuntimeError: If no checkpoint directory is set
          """
          import json

          info_file = self._require_checkpoint_directory() / "checkpoint_info.json"
          if not info_file.exists():
              return {}
          with info_file.open(encoding="utf-8") as f:
              return json.load(f)

      def __str__(self) -> str:
          return (
              f"CheckpointManager(root='{self.root_directory}', "
              f"run='{self.run_name}', ckpt_dir='{self.checkpoint_directory}')"
          )

      def __repr__(self) -> str:
          return self.__str__()
  # Demo: run this module directly to see both path-resolution modes in action.
  if __name__ == "__main__":
      from tempfile import TemporaryDirectory

      # A throwaway root directory keeps the demo from touching real checkpoints.
      with TemporaryDirectory() as scratch_root:
          manager = CheckpointManager(scratch_root, "my_experiment")
          created_dir = manager.get_checkpoint_directory()
          print(f"Checkpoint directory: {created_dir}")

          # Mode 1: the path is assembled from a model id and an iteration.
          by_components = manager.get_model_path_by_args(
              model_id="transe",
              iteration=1000,
          )
          print(f"Model path (component mode): {by_components}")

          # Mode 2: an already known path is passed through unchanged.
          by_full_path = manager.get_model_path_by_args(
              full_path="/some/other/path/model.pt",
          )
          print(f"Model path (full path mode): {by_full_path}")