|
35 | 35 | from orbax.checkpoint import options as options_lib
|
36 | 36 | from orbax.checkpoint import utils
|
37 | 37 | from orbax.checkpoint._src import threading as threading_lib
|
| 38 | +from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib |
38 | 39 | from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib
|
39 | 40 | from orbax.checkpoint._src.checkpointers import abstract_checkpointer
|
40 | 41 | from orbax.checkpoint._src.checkpointers import async_checkpointer
|
@@ -206,6 +207,60 @@ def _get_default_save_decision_policy(
|
206 | 207 | return save_decision_policy_lib.AnySavePolicy(save_interval_policies)
|
207 | 208 |
|
208 | 209 |
|
| 210 | +@dataclasses.dataclass |
| 211 | +class _ShouldKeepFnPolicy(preservation_policy_lib.PreservationPolicy): |
| 212 | + """Return true based on a provided function of the step.""" |
| 213 | + should_keep_fn: Callable[[int], bool] |
| 214 | + |
| 215 | + def should_preserve( |
| 216 | + self, |
| 217 | + checkpoints: Sequence[checkpoint_info.CheckpointInfo], |
| 218 | + *, |
| 219 | + context: preservation_policy_lib.PreservationContext, |
| 220 | + ) -> list[bool]: |
| 221 | + return [self.should_keep_fn(ckpt.step) for ckpt in checkpoints] |
| 222 | + |
| 223 | + |
| 224 | +def _get_default_preservation_policy( |
| 225 | + options: CheckpointManagerOptions, |
| 226 | +) -> preservation_policy_lib.PreservationPolicy: |
| 227 | + """Returns a default preservation policy.""" |
| 228 | + # Must have set max_to_keep in order to remove any checkpoints. |
| 229 | + preservation_policies = [] |
| 230 | + if options.keep_period is not None: |
| 231 | + preservation_policies.append( |
| 232 | + preservation_policy_lib.EveryNSteps(options.keep_period) |
| 233 | + ) |
| 234 | + if options.should_keep_fn is not None: |
| 235 | + preservation_policies.append( |
| 236 | + _ShouldKeepFnPolicy( |
| 237 | + should_keep_fn=options.should_keep_fn |
| 238 | + ) |
| 239 | + ) |
| 240 | + if options.keep_time_interval is not None: |
| 241 | + total_seconds = int(options.keep_time_interval.total_seconds()) |
| 242 | + preservation_policies.append( |
| 243 | + preservation_policy_lib.EveryNSeconds( |
| 244 | + interval_secs=total_seconds |
| 245 | + ) |
| 246 | + ) |
| 247 | + if options.best_fn is not None: |
| 248 | + preservation_policies.append( |
| 249 | + preservation_policy_lib.BestN( |
| 250 | + best_fn=options.best_fn, |
| 251 | + reverse=(options.best_mode == 'min'), |
| 252 | + n=options.max_to_keep, |
| 253 | + ) |
| 254 | + ) |
| 255 | + else: |
| 256 | + preservation_policies.append( |
| 257 | + preservation_policy_lib.LatestN(n=options.max_to_keep) |
| 258 | + ) |
| 259 | + return preservation_policy_lib.AnyPreservationPolicy( |
| 260 | + preservation_policies |
| 261 | + ) |
| 262 | + |
| 263 | + |
209 | 264 | # TODO(b/268051457) Clean up when no longer depended upon by internal users.
|
210 | 265 | def is_async_checkpointer(checkpointer: AbstractCheckpointer):
|
211 | 266 | return isinstance(
|
@@ -319,6 +374,12 @@ class CheckpointManagerOptions:
|
319 | 374 | is the sole means of determining when a checkpoint should be saved. If not
|
320 | 375 | provided, these other options are used instead. Prefer to use this option
|
321 | 376 | over others.
|
| 377 | + preservation_policy: An object used to determine which checkpoints to |
| 378 | + preserve. If provided, overrides any other options dealing with this |
| 379 | + subject, including `max_to_keep`, `keep_time_interval`, `keep_period`, and |
| 380 | + `should_keep_fn`, `best_fn`, and is the sole means of determining which |
| 381 | + checkpoints to preserve. If not provided, these other options are used |
| 382 | + instead. Prefer to use this option over others. |
322 | 383 | """
|
323 | 384 |
|
324 | 385 | save_interval_steps: int = 1
|
@@ -351,6 +412,9 @@ class CheckpointManagerOptions:
|
351 | 412 | save_decision_policy: Optional[
|
352 | 413 | save_decision_policy_lib.SaveDecisionPolicy
|
353 | 414 | ] = None
|
| 415 | + preservation_policy: Optional[ |
| 416 | + preservation_policy_lib.PreservationPolicy |
| 417 | + ] = None |
354 | 418 |
|
355 | 419 | def __post_init__(self):
|
356 | 420 | step_name_format_single_host_load_and_broadcast = (
|
@@ -632,6 +696,10 @@ def __init__(
|
632 | 696 | self._options.save_decision_policy
|
633 | 697 | or _get_default_save_decision_policy(self._options)
|
634 | 698 | )
|
| 699 | + self._preservation_policy = ( |
| 700 | + self._options.preservation_policy |
| 701 | + or _get_default_preservation_policy(self._options) |
| 702 | + ) |
635 | 703 |
|
636 | 704 | if self._options.best_mode not in ['min', 'max']:
|
637 | 705 | raise ValueError('`best_mode` must be one of: "min", "max"')
|
@@ -1151,17 +1219,12 @@ def delete(self, step: int):
|
1151 | 1219 |
|
1152 | 1220 | Args:
|
1153 | 1221 | step: The step to delete.
|
1154 |
| -
|
1155 |
| - Raises: |
1156 |
| - FileNotFoundError: If the step does not exist. |
1157 | 1222 | """
|
1158 | 1223 | if self._options.read_only:
|
1159 | 1224 | logging.warning('%s is read only, delete will be skipped', self.directory)
|
1160 | 1225 | return
|
1161 | 1226 | if step not in self.all_steps():
|
1162 |
| - raise FileNotFoundError( |
1163 |
| - f'Requested deleting a non-existent step: {step}.' |
1164 |
| - ) |
| 1227 | + raise ValueError(f'Requested deleting a non-existent step: {step}.') |
1165 | 1228 | self._checkpoint_deleter.delete(step)
|
1166 | 1229 | multihost.sync_global_processes(
|
1167 | 1230 | multihost.unique_barrier_key(
|
@@ -1704,22 +1767,6 @@ def build_checkpoint_info(step_metadata):
|
1704 | 1767 | )
|
1705 | 1768 | return checkpoint_infos
|
1706 | 1769 |
|
1707 |
| - def _get_interval_preserved_checkpoints( |
1708 |
| - self, checkpoints: checkpoint_info.CheckpointInfos |
1709 |
| - ) -> List[CheckpointInfo]: |
1710 |
| - """Gets which checkpoints should be kept based on keep_time_interval.""" |
1711 |
| - if checkpoints.empty(): |
1712 |
| - return [] |
1713 |
| - interval_preserved_checkpoints = [checkpoints[0]] |
1714 |
| - if self._options.keep_time_interval is not None: |
1715 |
| - for info in checkpoints[1:]: |
1716 |
| - if info.time >= ( |
1717 |
| - interval_preserved_checkpoints[-1].time |
1718 |
| - + self._options.keep_time_interval |
1719 |
| - ): |
1720 |
| - interval_preserved_checkpoints.append(info) |
1721 |
| - return interval_preserved_checkpoints |
1722 |
| - |
1723 | 1770 | def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]):
|
1724 | 1771 | self._checkpoints.append(
|
1725 | 1772 | CheckpointInfo(
|
@@ -1866,102 +1913,21 @@ def _cleanup_tmp_directories(self):
|
1866 | 1913 |
|
1867 | 1914 | def _get_old_steps_to_remove(self) -> List[int]:
|
1868 | 1915 | """Returns checkpoints that should be deleted."""
|
1869 |
| - # Must have set max_to_keep in order to remove any checkpoints. |
1870 |
| - if self._options.max_to_keep is None: |
1871 |
| - return [] |
1872 |
| - # Not enough checkpoints accumulated to consider deletion. |
1873 |
| - if self._checkpoints.size() <= self._options.max_to_keep: |
1874 |
| - return [] |
1875 |
| - |
1876 | 1916 | # This isn't a duration but there isn't a general counter that we can use so
|
1877 | 1917 | # we abuse a duration metric to count the number of steps examined.
|
1878 | 1918 | jax.monitoring.record_event_duration_secs(
|
1879 | 1919 | '/jax/checkpoint/write/old_steps_examined_count',
|
1880 | 1920 | self._checkpoints.size(),
|
1881 | 1921 | )
|
1882 |
| - |
1883 |
| - if self._track_best: |
1884 |
| - # Best steps (to keep) are at the end, after sorting. |
1885 |
| - ( |
1886 |
| - checkpoints_without_metrics, |
1887 |
| - sorted_checkpoints, |
1888 |
| - ) = self._sort_checkpoints_by_metrics(self._checkpoints) |
1889 |
| - else: |
1890 |
| - # checkpoints already sorted by ascending step |
1891 |
| - checkpoints_without_metrics = [] |
1892 |
| - sorted_checkpoints = [info for info in self._checkpoints] |
1893 |
| - |
1894 |
| - keep = int(self._options.max_to_keep) |
1895 |
| - if self._options.keep_checkpoints_without_metrics: |
1896 |
| - maybe_delete = ( |
1897 |
| - sorted_checkpoints[:-keep] if keep > 0 else sorted_checkpoints |
1898 |
| - ) |
1899 |
| - active_checkpoints = set( |
1900 |
| - checkpoints_without_metrics |
1901 |
| - + (sorted_checkpoints[-keep:] if keep > 0 else []) |
1902 |
| - ) |
1903 |
| - else: |
1904 |
| - all_checkpoints = checkpoints_without_metrics + sorted_checkpoints |
1905 |
| - maybe_delete = all_checkpoints[:-keep] if keep > 0 else sorted_checkpoints |
1906 |
| - active_checkpoints = set(all_checkpoints[-keep:] if keep > 0 else []) |
1907 |
| - |
1908 |
| - interval_preserved_checkpoints = self._get_interval_preserved_checkpoints( |
1909 |
| - self._checkpoints |
| 1922 | + preservation_result = self._preservation_policy.should_preserve( |
| 1923 | + [info for info in self._checkpoints], |
| 1924 | + context=preservation_policy_lib.PreservationContext(), |
1910 | 1925 | )
|
1911 |
| - kept_checkpoints = set() |
1912 |
| - for info in maybe_delete: |
1913 |
| - if ( |
1914 |
| - self._options.keep_time_interval is not None |
1915 |
| - and interval_preserved_checkpoints |
1916 |
| - ): |
1917 |
| - if info in interval_preserved_checkpoints: |
1918 |
| - logging.info( |
1919 |
| - 'Preserving %s: (Reason: older falling on keep_time_interval).', |
1920 |
| - info, |
1921 |
| - ) |
1922 |
| - kept_checkpoints.add(info) |
1923 |
| - continue |
1924 |
| - elif info.time >= ( |
1925 |
| - interval_preserved_checkpoints[-1].time |
1926 |
| - + self._options.keep_time_interval |
1927 |
| - ): |
1928 |
| - interval_preserved_checkpoints.append(info) |
1929 |
| - logging.info( |
1930 |
| - 'Preserving %s: (Reason: latest falling on keep_time_interval).', |
1931 |
| - info, |
1932 |
| - ) |
1933 |
| - kept_checkpoints.add(info) |
1934 |
| - continue |
1935 |
| - if ( |
1936 |
| - self._options.should_keep_fn is not None |
1937 |
| - and self._options.should_keep_fn(info.step) |
1938 |
| - ): |
1939 |
| - logging.info( |
1940 |
| - 'Preserving %s: (Reason: on should_keep_fn callback).', info |
1941 |
| - ) |
1942 |
| - kept_checkpoints.add(info) |
1943 |
| - continue |
1944 |
| - if ( |
1945 |
| - self._options.keep_period is not None |
1946 |
| - and info.step % self._options.keep_period == 0 |
1947 |
| - ): |
1948 |
| - logging.info( |
1949 |
| - 'Preserving %s: (Reason: on keep_period=%s).', |
1950 |
| - info, |
1951 |
| - self._options.keep_period, |
1952 |
| - ) |
1953 |
| - kept_checkpoints.add(info) |
1954 |
| - continue |
1955 |
| - |
1956 |
| - kept_checkpoints.update(active_checkpoints) |
1957 |
| - |
1958 |
| - steps_to_remove = [] |
1959 |
| - for info in self._checkpoints: |
1960 |
| - if info not in kept_checkpoints: |
1961 |
| - reason = 'worse metric' if self._track_best else 'old checkpoint' |
1962 |
| - logging.info('Deleting %s: (Reason: %s).', info, reason) |
1963 |
| - steps_to_remove.append(info.step) |
1964 |
| - return steps_to_remove |
| 1926 | + result = [] |
| 1927 | + for i in range(len(self._checkpoints)): |
| 1928 | + if not preservation_result[i]: |
| 1929 | + result.append(self._checkpoints[i].step) |
| 1930 | + return result |
1965 | 1931 |
|
1966 | 1932 | def _wait_for_checkpointers(self):
|
1967 | 1933 | if is_async_checkpointer(self._checkpointer):
|
|
0 commit comments