Skip to content

Commit 2899e37

Browse files
abhishek002002Orbax Authors
authored andcommitted
Use the new preservation policy in CheckpointManager.
PiperOrigin-RevId: 750914697
1 parent 4968dc7 commit 2899e37

File tree

3 files changed

+91
-111
lines changed

3 files changed

+91
-111
lines changed

checkpoint/orbax/checkpoint/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ py_library(
139139
"//checkpoint/orbax/checkpoint/_src/path:step",
140140
"//checkpoint/orbax/checkpoint/_src/path:utils",
141141
"//orbax/checkpoint/_src:threading",
142+
"//orbax/checkpoint/_src/checkpoint_managers:preservation_policy",
142143
],
143144
)
144145

@@ -353,5 +354,6 @@ py_library(
353354
":abstract_checkpoint_manager",
354355
":checkpoint_manager",
355356
"//checkpoint/orbax/checkpoint/_src/checkpoint_managers:save_decision_policy",
357+
"//orbax/checkpoint/_src/checkpoint_managers:preservation_policy",
356358
],
357359
)

checkpoint/orbax/checkpoint/checkpoint_manager.py

Lines changed: 77 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from orbax.checkpoint import options as options_lib
3636
from orbax.checkpoint import utils
3737
from orbax.checkpoint._src import threading as threading_lib
38+
from orbax.checkpoint._src.checkpoint_managers import preservation_policy as preservation_policy_lib
3839
from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib
3940
from orbax.checkpoint._src.checkpointers import abstract_checkpointer
4041
from orbax.checkpoint._src.checkpointers import async_checkpointer
@@ -206,6 +207,60 @@ def _get_default_save_decision_policy(
206207
return save_decision_policy_lib.AnySavePolicy(save_interval_policies)
207208

208209

210+
@dataclasses.dataclass
211+
class _ShouldKeepFnPolicy(preservation_policy_lib.PreservationPolicy):
212+
"""Return true based on a provided function of the step."""
213+
should_keep_fn: Callable[[int], bool]
214+
215+
def should_preserve(
216+
self,
217+
checkpoints: Sequence[checkpoint_info.CheckpointInfo],
218+
*,
219+
context: preservation_policy_lib.PreservationContext,
220+
) -> list[bool]:
221+
return [self.should_keep_fn(ckpt.step) for ckpt in checkpoints]
222+
223+
224+
def _get_default_preservation_policy(
225+
options: CheckpointManagerOptions,
226+
) -> preservation_policy_lib.PreservationPolicy:
227+
"""Returns a default preservation policy."""
228+
# Must have set max_to_keep in order to remove any checkpoints.
229+
preservation_policies = []
230+
if options.keep_period is not None:
231+
preservation_policies.append(
232+
preservation_policy_lib.EveryNSteps(options.keep_period)
233+
)
234+
if options.should_keep_fn is not None:
235+
preservation_policies.append(
236+
_ShouldKeepFnPolicy(
237+
should_keep_fn=options.should_keep_fn
238+
)
239+
)
240+
if options.keep_time_interval is not None:
241+
total_seconds = int(options.keep_time_interval.total_seconds())
242+
preservation_policies.append(
243+
preservation_policy_lib.EveryNSeconds(
244+
interval_secs=total_seconds
245+
)
246+
)
247+
if options.best_fn is not None:
248+
preservation_policies.append(
249+
preservation_policy_lib.BestN(
250+
best_fn=options.best_fn,
251+
reverse=(options.best_mode == 'min'),
252+
n=options.max_to_keep,
253+
)
254+
)
255+
else:
256+
preservation_policies.append(
257+
preservation_policy_lib.LatestN(n=options.max_to_keep)
258+
)
259+
return preservation_policy_lib.AnyPreservationPolicy(
260+
preservation_policies
261+
)
262+
263+
209264
# TODO(b/268051457) Clean up when no longer depended upon by internal users.
210265
def is_async_checkpointer(checkpointer: AbstractCheckpointer):
211266
return isinstance(
@@ -319,6 +374,12 @@ class CheckpointManagerOptions:
319374
is the sole means of determining when a checkpoint should be saved. If not
320375
provided, these other options are used instead. Prefer to use this option
321376
over others.
377+
preservation_policy: An object used to determine which checkpoints to
378+
preserve. If provided, overrides any other options dealing with this
379+
subject, including `max_to_keep`, `keep_time_interval`, `keep_period`, and
380+
`should_keep_fn`, `best_fn`, and is the sole means of determining which
381+
checkpoints to preserve. If not provided, these other options are used
382+
instead. Prefer to use this option over others.
322383
"""
323384

324385
save_interval_steps: int = 1
@@ -351,6 +412,9 @@ class CheckpointManagerOptions:
351412
save_decision_policy: Optional[
352413
save_decision_policy_lib.SaveDecisionPolicy
353414
] = None
415+
preservation_policy: Optional[
416+
preservation_policy_lib.PreservationPolicy
417+
] = None
354418

355419
def __post_init__(self):
356420
step_name_format_single_host_load_and_broadcast = (
@@ -632,6 +696,10 @@ def __init__(
632696
self._options.save_decision_policy
633697
or _get_default_save_decision_policy(self._options)
634698
)
699+
self._preservation_policy = (
700+
self._options.preservation_policy
701+
or _get_default_preservation_policy(self._options)
702+
)
635703

636704
if self._options.best_mode not in ['min', 'max']:
637705
raise ValueError('`best_mode` must be one of: "min", "max"')
@@ -1151,17 +1219,12 @@ def delete(self, step: int):
11511219
11521220
Args:
11531221
step: The step to delete.
1154-
1155-
Raises:
1156-
FileNotFoundError: If the step does not exist.
11571222
"""
11581223
if self._options.read_only:
11591224
logging.warning('%s is read only, delete will be skipped', self.directory)
11601225
return
11611226
if step not in self.all_steps():
1162-
raise FileNotFoundError(
1163-
f'Requested deleting a non-existent step: {step}.'
1164-
)
1227+
raise ValueError(f'Requested deleting a non-existent step: {step}.')
11651228
self._checkpoint_deleter.delete(step)
11661229
multihost.sync_global_processes(
11671230
multihost.unique_barrier_key(
@@ -1704,22 +1767,6 @@ def build_checkpoint_info(step_metadata):
17041767
)
17051768
return checkpoint_infos
17061769

1707-
def _get_interval_preserved_checkpoints(
1708-
self, checkpoints: checkpoint_info.CheckpointInfos
1709-
) -> List[CheckpointInfo]:
1710-
"""Gets which checkpoints should be kept based on keep_time_interval."""
1711-
if checkpoints.empty():
1712-
return []
1713-
interval_preserved_checkpoints = [checkpoints[0]]
1714-
if self._options.keep_time_interval is not None:
1715-
for info in checkpoints[1:]:
1716-
if info.time >= (
1717-
interval_preserved_checkpoints[-1].time
1718-
+ self._options.keep_time_interval
1719-
):
1720-
interval_preserved_checkpoints.append(info)
1721-
return interval_preserved_checkpoints
1722-
17231770
def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]):
17241771
self._checkpoints.append(
17251772
CheckpointInfo(
@@ -1866,102 +1913,21 @@ def _cleanup_tmp_directories(self):
18661913

18671914
def _get_old_steps_to_remove(self) -> List[int]:
18681915
"""Returns checkpoints that should be deleted."""
1869-
# Must have set max_to_keep in order to remove any checkpoints.
1870-
if self._options.max_to_keep is None:
1871-
return []
1872-
# Not enough checkpoints accumulated to consider deletion.
1873-
if self._checkpoints.size() <= self._options.max_to_keep:
1874-
return []
1875-
18761916
# This isn't a duration but there isn't a general counter that we can use so
18771917
# we abuse a duration metric to count the number of steps examined.
18781918
jax.monitoring.record_event_duration_secs(
18791919
'/jax/checkpoint/write/old_steps_examined_count',
18801920
self._checkpoints.size(),
18811921
)
1882-
1883-
if self._track_best:
1884-
# Best steps (to keep) are at the end, after sorting.
1885-
(
1886-
checkpoints_without_metrics,
1887-
sorted_checkpoints,
1888-
) = self._sort_checkpoints_by_metrics(self._checkpoints)
1889-
else:
1890-
# checkpoints already sorted by ascending step
1891-
checkpoints_without_metrics = []
1892-
sorted_checkpoints = [info for info in self._checkpoints]
1893-
1894-
keep = int(self._options.max_to_keep)
1895-
if self._options.keep_checkpoints_without_metrics:
1896-
maybe_delete = (
1897-
sorted_checkpoints[:-keep] if keep > 0 else sorted_checkpoints
1898-
)
1899-
active_checkpoints = set(
1900-
checkpoints_without_metrics
1901-
+ (sorted_checkpoints[-keep:] if keep > 0 else [])
1902-
)
1903-
else:
1904-
all_checkpoints = checkpoints_without_metrics + sorted_checkpoints
1905-
maybe_delete = all_checkpoints[:-keep] if keep > 0 else sorted_checkpoints
1906-
active_checkpoints = set(all_checkpoints[-keep:] if keep > 0 else [])
1907-
1908-
interval_preserved_checkpoints = self._get_interval_preserved_checkpoints(
1909-
self._checkpoints
1922+
preservation_result = self._preservation_policy.should_preserve(
1923+
[info for info in self._checkpoints],
1924+
context=preservation_policy_lib.PreservationContext(),
19101925
)
1911-
kept_checkpoints = set()
1912-
for info in maybe_delete:
1913-
if (
1914-
self._options.keep_time_interval is not None
1915-
and interval_preserved_checkpoints
1916-
):
1917-
if info in interval_preserved_checkpoints:
1918-
logging.info(
1919-
'Preserving %s: (Reason: older falling on keep_time_interval).',
1920-
info,
1921-
)
1922-
kept_checkpoints.add(info)
1923-
continue
1924-
elif info.time >= (
1925-
interval_preserved_checkpoints[-1].time
1926-
+ self._options.keep_time_interval
1927-
):
1928-
interval_preserved_checkpoints.append(info)
1929-
logging.info(
1930-
'Preserving %s: (Reason: latest falling on keep_time_interval).',
1931-
info,
1932-
)
1933-
kept_checkpoints.add(info)
1934-
continue
1935-
if (
1936-
self._options.should_keep_fn is not None
1937-
and self._options.should_keep_fn(info.step)
1938-
):
1939-
logging.info(
1940-
'Preserving %s: (Reason: on should_keep_fn callback).', info
1941-
)
1942-
kept_checkpoints.add(info)
1943-
continue
1944-
if (
1945-
self._options.keep_period is not None
1946-
and info.step % self._options.keep_period == 0
1947-
):
1948-
logging.info(
1949-
'Preserving %s: (Reason: on keep_period=%s).',
1950-
info,
1951-
self._options.keep_period,
1952-
)
1953-
kept_checkpoints.add(info)
1954-
continue
1955-
1956-
kept_checkpoints.update(active_checkpoints)
1957-
1958-
steps_to_remove = []
1959-
for info in self._checkpoints:
1960-
if info not in kept_checkpoints:
1961-
reason = 'worse metric' if self._track_best else 'old checkpoint'
1962-
logging.info('Deleting %s: (Reason: %s).', info, reason)
1963-
steps_to_remove.append(info.step)
1964-
return steps_to_remove
1926+
result = []
1927+
for i in range(len(self._checkpoints)):
1928+
if not preservation_result[i]:
1929+
result.append(self._checkpoints[i].step)
1930+
return result
19651931

19661932
def _wait_for_checkpointers(self):
19671933
if is_async_checkpointer(self._checkpointer):

checkpoint/orbax/checkpoint/checkpoint_managers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@
2626
AnySavePolicy,
2727
)
2828

29+
from orbax.checkpoint._src.checkpoint_managers import preservation_policy
30+
from orbax.checkpoint._src.checkpoint_managers.preservation_policy import (
31+
PreservationPolicy,
32+
LatestN,
33+
EveryNSeconds,
34+
EveryNSteps,
35+
CustomSteps,
36+
AnyPreservationPolicy,
37+
BestN,
38+
)
39+
40+
2941
from orbax.checkpoint.checkpoint_manager import CheckpointManagerOptions
3042
from orbax.checkpoint.checkpoint_manager import CheckpointManager
3143

0 commit comments

Comments
 (0)