Internal change.
PiperOrigin-RevId: 718666876
niketkumar authored and Orbax Authors committed Jan 23, 2025
1 parent 0500787 commit d4a2579
Showing 9 changed files with 548 additions and 78 deletions.
29 changes: 29 additions & 0 deletions checkpoint/orbax/checkpoint/_src/futures/future.py
@@ -224,3 +224,32 @@ def __init__(
def result(self, timeout: Optional[float] = None) -> Any:
"""Waits for the commit to complete."""
return self._t.join(timeout=timeout)


class CoroutineRunningFuture(Future):
"""Runs a coroutine with `asyncio.run` in a new thread.
This future resembles a concurrent.futures.Future because its result()
supports a timeout.
Extends the Orbax Future protocol.
Not thread-safe.
"""

def __init__(self, coro, name: Optional[str] = None):
"""Creates an Orbax Future for running given coroutine in a new thread.
Args:
coro: The coroutine to run.
name: The name of the underlying thread.
"""
super().__init__()
self._t = ThreadRaisingException(
name=name,
target=lambda: asyncio_utils.run_sync(coro),
)
self._t.start()

def result(self, timeout: Optional[int] = None) -> Any:
return self._t.join(timeout=timeout)
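
A minimal standalone sketch of the pattern CoroutineRunningFuture implements, using a plain threading.Thread in place of Orbax's ThreadRaisingException (which additionally re-raises worker-thread exceptions on join); all names here are illustrative, not part of the Orbax API:

import asyncio
import threading
from typing import Any, Optional

class _CoroutineFutureSketch:
  """Runs a coroutine to completion on a dedicated background thread."""

  def __init__(self, coro, name: Optional[str] = None):
    # asyncio.run creates a fresh event loop inside the new thread, so the
    # caller's event loop (if any) is never blocked or touched.
    self._result: Any = None
    self._thread = threading.Thread(
        name=name, target=lambda: self._run(coro), daemon=True
    )
    self._thread.start()

  def _run(self, coro) -> None:
    self._result = asyncio.run(coro)

  def result(self, timeout: Optional[float] = None) -> Any:
    # Like concurrent.futures.Future.result(), this blocks up to `timeout`.
    self._thread.join(timeout=timeout)
    return self._result

async def _demo():
  await asyncio.sleep(0.01)
  return 'done'

print(_CoroutineFutureSketch(_demo(), name='demo').result(timeout=5))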
3 changes: 2 additions & 1 deletion checkpoint/orbax/checkpoint/_src/handlers/BUILD
@@ -71,6 +71,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//orbax/checkpoint:utils",
"//orbax/checkpoint/_src/metadata:array_metadata_store",
],
)

@@ -81,7 +82,6 @@ py_library(
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint:checkpoint_args",
"//checkpoint/orbax/checkpoint:future",
"//checkpoint/orbax/checkpoint:options",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
@@ -95,6 +95,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//orbax/checkpoint:utils",
"//orbax/checkpoint/_src/futures:future",
"//orbax/checkpoint/_src/metadata:array_metadata_store",
],
)
checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py
@@ -36,10 +36,10 @@
import humanize
import jax
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import empty_values
@@ -284,6 +284,9 @@ def __init__(
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
tree_metadata.PYTREE_METADATA_OPTIONS
),
array_metadata_validator: array_metadata_store_lib.Validator = (
array_metadata_store_lib.Validator()
),
):
"""Creates BasePyTreeCheckpointHandler.
@@ -303,6 +306,7 @@ def __init__(
enable_post_merge_validation: If True, enables validation of the
parameters after the finalize step.
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
array_metadata_validator: Validator for ArrayMetadata.
"""
self._save_concurrent_bytes = save_concurrent_bytes
self._restore_concurrent_bytes = restore_concurrent_bytes
@@ -319,6 +323,7 @@ def __init__(
type_handler_registry
)
)
self._array_metadata_validator = array_metadata_validator


jax.monitoring.record_event(
@@ -329,8 +334,10 @@ def __init__(
max_workers=3, thread_name_prefix='base_pytree_ch'
)
logging.info(
'Created BasePyTreeCheckpointHandler: pytree_metadata_options=%s',
'Created BasePyTreeCheckpointHandler: pytree_metadata_options=%s,'
' array_metadata_store=%s',
self._pytree_metadata_options,
self._array_metadata_store,
)

def get_param_names(self, item: PyTree) -> PyTree:
@@ -487,14 +494,16 @@ async def async_save(
save_futures = []
if multihost.is_primary_host(self._primary_host):
save_futures.append(
self._thread_pool.submit(
self._write_metadata_after_commits,
commit_futures=commit_futures,
checkpoint_dir=directory,
param_infos=param_infos,
save_args=save_args,
custom_metadata=custom_metadata,
use_zarr3=self._use_zarr3,
future.CoroutineRunningFuture(
self._write_metadata_after_commits(
commit_futures,
checkpoint_dir=directory,
param_infos=param_infos,
save_args=save_args,
custom_metadata=custom_metadata,
use_zarr3=self._use_zarr3,
),
name='write_metadata_after_commits',
)
)
else:
@@ -745,7 +754,7 @@ class TrainState:
)
return restored_item

def _get_param_infos_with_write_shape(
async def _get_param_infos_with_write_shape(
self,
param_infos: PyTree,
checkpoint_dir: epath.Path,
@@ -764,7 +773,7 @@ def _get_param_infos_with_write_shape(
return param_infos
# Extract write_shape from ArrayMetadata for current process_index.
process_index = multihost.process_index()
array_metadatas = array_metadata_store.read(
array_metadatas = await array_metadata_store.read(
checkpoint_dir, process_index=process_index
)
if array_metadatas is None:
@@ -834,7 +843,7 @@ def _save_fn(param_infos):

return self._thread_pool.submit(_save_fn, param_infos)

def _write_metadata_after_commits(
async def _write_metadata_after_commits(
self,
commit_futures: List[future.Future],
checkpoint_dir: epath.Path,
@@ -847,7 +856,7 @@ def _write_metadata_after_commits(
if not utils.is_primary_host(self._primary_host):
return
for commit_future in commit_futures:
commit_future.result()
await asyncio.to_thread(commit_future.result)
# `write_shape` is extracted from ArrayMetadata store saved during
# materialization of commit_futures. Then it is written to the pytree
# metadata.
@@ -857,16 +866,17 @@ def _write_metadata_after_commits(
# BasePyTreeCheckpointHandler should delegate all metadata related code to
# that class.
if self._array_metadata_store is not None:
param_infos = self._get_param_infos_with_write_shape(
param_infos = await self._get_param_infos_with_write_shape(
param_infos, checkpoint_dir, self._array_metadata_store
)
self._write_metadata_file(
metadata_write_future = self._write_metadata_file(
checkpoint_dir,
param_infos=param_infos,
save_args=save_args,
custom_metadata=custom_metadata,
use_zarr3=use_zarr3,
).result()
)
await asyncio.to_thread(metadata_write_future.result)

def _read_metadata_file(
self, directory: epath.Path
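
The awaits in _write_metadata_after_commits rely on asyncio.to_thread, which runs a blocking call on a worker thread and yields control back to the event loop until it finishes. A small self-contained sketch of the same pattern; the "commit" futures here are stand-ins, not Orbax objects:

import asyncio
import time
from concurrent import futures

def _blocking_commit(i: int) -> str:
  # Stand-in for a commit future's blocking result() call.
  time.sleep(0.05)
  return f'commit-{i}'

async def _wait_for_commits(pool: futures.ThreadPoolExecutor) -> list[str]:
  commit_futures = [pool.submit(_blocking_commit, i) for i in range(3)]
  results = []
  for f in commit_futures:
    # f.result() would block the event loop; to_thread keeps it responsive.
    results.append(await asyncio.to_thread(f.result))
  return results

with futures.ThreadPoolExecutor(max_workers=3) as pool:
  print(asyncio.run(_wait_for_commits(pool)))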
@@ -947,20 +957,45 @@ def finalize(self, directory: epath.Path) -> None:
Args:
directory: Path where the checkpoint is located.
"""
merge_start_time = time.time()
ts_context = ts_utils.get_ts_context(use_ocdbt=True)
asyncio_utils.run_sync(
type_handlers.merge_ocdbt_per_process_files(
directory,
ts_context=ts_context,
use_zarr3=self._use_zarr3,
enable_validation=self._enable_post_merge_validation,
finalize_coros = []
if self._array_metadata_store is not None:
if self._primary_host is None:
logging.log_first_n(
logging.WARNING,
'[process=%s] Skipped cross-host ArrayMetadata validation'
' because all hosts are primary (e.g. local storage).',
1, # log only once
multihost.process_index(),
)
)
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/ocdbt_merge_duration_secs',
time.time() - merge_start_time,
)
elif utils.is_primary_host(self._primary_host):
finalize_coros.append(
array_metadata_store_lib.validate_all_array_metadatas(
self._array_metadata_validator,
self._array_metadata_store,
directory,
)
)

async def merge_ocdbt_per_process_files():
merge_start_time = time.time()
ts_context = ts_utils.get_ts_context(use_ocdbt=True)
await type_handlers.merge_ocdbt_per_process_files(
directory,
ts_context=ts_context,
use_zarr3=self._use_zarr3,
enable_validation=self._enable_post_merge_validation,
)
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/ocdbt_merge_duration_secs',
time.time() - merge_start_time,
)

finalize_coros.append(merge_ocdbt_per_process_files())

async def _fn():
await asyncio.gather(*finalize_coros)

asyncio_utils.run_sync(_fn())

def close(self):
"""Closes the handler. Called automatically by Checkpointer."""
checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py
@@ -41,6 +41,7 @@
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.serialization import serialization
@@ -477,6 +478,9 @@ def __init__(
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
tree_metadata.PYTREE_METADATA_OPTIONS
),
array_metadata_validator: array_metadata_store_lib.Validator = (
array_metadata_store_lib.Validator()
),
):
"""Creates PyTreeCheckpointHandler.
@@ -498,6 +502,7 @@ def __init__(
specified, the global type handler registry will be used.
handler_impl: Allows overriding the internal implementation.
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
array_metadata_validator: Validator for ArrayMetadata.
"""
self._aggregate_handler = MsgpackHandler(
primary_host=multiprocessing_options.primary_host,
@@ -520,6 +525,7 @@ def __init__(
multiprocessing_options=multiprocessing_options,
type_handler_registry=type_handler_registry,
pytree_metadata_options=pytree_metadata_options,
array_metadata_validator=array_metadata_validator,
)
self._pytree_metadata_options = pytree_metadata_options

281 changes: 262 additions & 19 deletions checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store.py
@@ -14,9 +14,11 @@

"""Storage for `array_metadata.ArrayMetadata` (not value.ArrayMetadata)."""

import asyncio
import json
import threading
from typing import Any, Iterator, List, Sequence
import time
from typing import Any, Iterator, List, Sequence, Tuple
from absl import logging
from etils import epath
import jax
@@ -78,7 +80,7 @@ def get_read_file_paths(


class SerDeserializer:
"""Serializes and deserializes `tensorstore_utils.ArrayMetadata`."""
"""Serializes and deserializes `array_metadata.ArrayMetadata`."""

def _to_dict(
self, array_metadata: array_metadata_lib.ArrayMetadata
@@ -117,13 +119,13 @@ def serialize(
def deserialize(
self, serialized: str
) -> List[array_metadata_lib.SerializedArrayMetadata]:
"""Deserializes `serialized` to `tensorstore_utils.ArrayMetadata`."""
"""Deserializes `serialized` to `array_metadata.ArrayMetadata`."""
obj = json.loads(serialized, object_hook=self._from_dict)
return obj['array_metadatas']


class Store:
"""Storage for `tensorstore_utils.ArrayMetadata` (not value.ArrayMetadata)."""
"""Storage for `array_metadata.ArrayMetadata` (not value.ArrayMetadata)."""

def __init__(
self,
@@ -151,18 +153,26 @@ async def write(
file_path = self._path_resolver.get_write_file_path(
checkpoint_dir, process_index
)
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text(self._ser_deser.serialize(array_metadatas))
await asyncio.to_thread(file_path.parent.mkdir, parents=True, exist_ok=True)
await asyncio.to_thread(
file_path.write_text, self._ser_deser.serialize(array_metadatas)
)
logging.info(
'[process=%s][thread=%s] Wrote %d tensorstore_utils.ArrayMetadata'
' to %s',
'[process=%s][thread=%s] Wrote %d array_metadata.ArrayMetadata to %s',
multihost.process_index(),
threading.current_thread().name,
len(array_metadatas),
file_path,
)

def read(
async def _get_array_metadatas(
self,
array_metadatas_file_path: epath.Path,
) -> Tuple[epath.Path, list[array_metadata_lib.SerializedArrayMetadata]]:
serialized = await asyncio.to_thread(array_metadatas_file_path.read_text)
return array_metadatas_file_path, self._ser_deser.deserialize(serialized)

async def read(
self,
checkpoint_dir: epath.Path,
process_index: int | None = None,
@@ -183,42 +193,73 @@ def read(
is None. A list of metadata if `process_index` is not None. None if
metadata does not exist.
"""
if not checkpoint_dir.exists():
if not await asyncio.to_thread(checkpoint_dir.exists):
raise ValueError(
f'Checkpoint directory does not exist: {checkpoint_dir}.'
)
start_time = time.time()
file_paths = self._path_resolver.get_read_file_paths(
checkpoint_dir, process_index
)
if file_paths is None:
logging.warning(
'[process=%s][thread=%s] No metadata found for process_index=%s,'
' checkpoint_dir=%s. Please ignore if input checkpoint does not'
' contain any jax.Array.',
' checkpoint_dir=%s. If the checkpoint does not contain jax.Array'
' then it is expected. If checkpoint contains jax.Array then it'
' should lead to an error eventually; if no error is raised then it'
' is a bug.',
multihost.process_index(),
threading.current_thread().name,
process_index,
checkpoint_dir,
)
return None

if isinstance(file_paths, epath.Path):
return self._ser_deser.deserialize(file_paths.read_text())
_, result = await self._get_array_metadatas(file_paths)
logging.vlog(
1,
'[process=%s][thread=%s] Read %s metadata from metadata path=%s'
' in %s seconds.',
multihost.process_index(),
threading.current_thread().name,
len(result),
file_paths,
time.time() - start_time,
)
return result

getter_coros = []
for file_path in file_paths:
getter_coros.append(self._get_array_metadatas(file_path))
path_metadatas_pairs = await asyncio.gather(*getter_coros)
result = {
self._path_resolver.get_process_index(
file_path
): self._ser_deser.deserialize(file_path.read_text())
for file_path in file_paths
self._path_resolver.get_process_index(file_path): metadatas
for file_path, metadatas in path_metadatas_pairs
}
if not result:
logging.warning(
'[process=%s][thread=%s] No metadata found for any process_index,'
' checkpoint_dir=%s. Please ignore if input checkpoint does not'
' contain any jax.Array.',
' checkpoint_dir=%s. time elapsed=%s seconds. If the checkpoint does'
' not contain jax.Array then it is expected. If checkpoint contains'
' jax.Array then it should lead to an error eventually; if no error'
' is raised then it is a bug.',
multihost.process_index(),
threading.current_thread().name,
checkpoint_dir,
time.time() - start_time,
)
return None

logging.vlog(
1,
'[process=%s][thread=%s] Read all metadata from checkpoint_dir=%s in %s'
' seconds.',
multihost.process_index(),
threading.current_thread().name,
checkpoint_dir,
time.time() - start_time,
)
return result
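
Store.read now fans out one coroutine per per-process metadata file and awaits them together, so file reads overlap instead of running serially. A minimal sketch of the fan-out, with plain local JSON files standing in for the array_metadatas entries:

import asyncio
import json
import pathlib
import tempfile

async def _read_one(path: pathlib.Path) -> tuple[pathlib.Path, list]:
  # Blocking file I/O is pushed to a worker thread.
  text = await asyncio.to_thread(path.read_text)
  return path, json.loads(text)

async def _read_all(paths: list[pathlib.Path]) -> dict[str, list]:
  pairs = await asyncio.gather(*[_read_one(p) for p in paths])
  return {path.stem: metadatas for path, metadatas in pairs}

with tempfile.TemporaryDirectory() as d:
  paths = []
  for i in range(3):
    p = pathlib.Path(d) / f'process_{i}.json'
    p.write_text(json.dumps([{'param_name': 'a', 'write_shape': [1]}]))
    paths.append(p)
  print(asyncio.run(_read_all(paths)))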


@@ -252,3 +293,205 @@ def resolve_array_metadata_store(
array_handler.__class__.__qualname__,
)
return None


class Validator:
"""Validates ArrayMetadata."""

def validate_all_array_metadatas(
self,
array_metadatas: dict[
int, List[array_metadata_lib.SerializedArrayMetadata]
],
) -> None:
"""Validates that all processes have the same array metadatas.
Args:
array_metadatas: A dictionary of process index to list of metadata.
"""
start_time = time.time()
ref_process_index, ref_process_array_metadatas = next(
iter(array_metadatas.items())
)
if not ref_process_array_metadatas:
raise ValueError(
'ArrayMetadata Store contains no metadata for process_index='
f'{ref_process_index}.'
)
if len(array_metadatas) == 1:
# check if the number of processes are indeed just one.
self._validate_process_count(ref_process_index=ref_process_index)
logging.warning(
'[process=%s][thread=%s] Skipped cross-host ArrayMetadata validation'
' because only one process is found: process_index=%s.',
multihost.process_index(),
threading.current_thread().name,
ref_process_index,
)
return

ref_process_cache = {
array_metadata.param_name: array_metadata
for array_metadata in ref_process_array_metadatas
}
for process_index, process_array_metadatas in array_metadatas.items():
if process_index == ref_process_index:
continue
process_cache = {
array_metadata.param_name: array_metadata
for array_metadata in process_array_metadatas
}
# Check if the number of params is the same.
self._validate_param_count(
ref_process_index=ref_process_index,
ref_process_array_metadatas=ref_process_array_metadatas,
ref_process_cache=ref_process_cache,
process_index=process_index,
process_array_metadatas=process_array_metadatas,
process_cache=process_cache,
)
# Check if the params are the same.
self._validate_params(
ref_process_index=ref_process_index,
ref_process_cache=ref_process_cache,
process_index=process_index,
process_cache=process_cache,
)
# Check if the chunk_shape and write_shape are the same for each param.
self._validate_chunk_shape_and_write_shape(
ref_process_index=ref_process_index,
ref_process_cache=ref_process_cache,
process_index=process_index,
process_cache=process_cache,
)
logging.info(
'[process=%s][thread=%s] Validated ArrayMetadata from all %s hosts in'
' %s seconds.',
multihost.process_index(),
threading.current_thread().name,
len(array_metadatas),
time.time() - start_time,
)

def _validate_process_count(
self,
*,
ref_process_index: int,
) -> None:
"""Validates that the number of processes is just one."""
process_count = multihost.process_count()
if process_count != 1:
raise ValueError(
'ArrayMetadata Store contains metadata from just one process'
f' (process_index={ref_process_index}), but found'
f' {process_count} processes.'
)

def _validate_param_count(
self,
*,
ref_process_index: int,
ref_process_array_metadatas: List[
array_metadata_lib.SerializedArrayMetadata
],
ref_process_cache: dict[str, array_metadata_lib.SerializedArrayMetadata],
process_index: int,
process_array_metadatas: List[array_metadata_lib.SerializedArrayMetadata],
process_cache: dict[str, array_metadata_lib.SerializedArrayMetadata],
) -> None:
"""Validates that the number of params is the same."""
if len(ref_process_array_metadatas) != len(process_array_metadatas):
missing_in_process = ref_process_cache.keys() - process_cache.keys()
missing_in_ref_process = process_cache.keys() - ref_process_cache.keys()
diff_msg = 'Diff:'
if missing_in_process:
diff_msg += (
f' process_index={process_index} is missing'
f' {len(missing_in_process)} params: {missing_in_process}.'
)
if missing_in_ref_process:
diff_msg += (
f' process_index={ref_process_index} is missing'
f' {len(missing_in_ref_process)} params:'
f' {missing_in_ref_process}.'
)
raise ValueError(
'ArrayMetadata Store contains different number of params:'
f' process_index={ref_process_index} has'
f' {len(ref_process_array_metadatas)}, but'
f' process_index={process_index} has'
f' {len(process_array_metadatas)} params. {diff_msg}'
)

def _validate_params(
self,
*,
ref_process_index: int,
ref_process_cache: dict[str, array_metadata_lib.SerializedArrayMetadata],
process_index: int,
process_cache: dict[str, array_metadata_lib.SerializedArrayMetadata],
) -> None:
"""Validates that the params are the same."""
missing_in_process = ref_process_cache.keys() - process_cache.keys()
if missing_in_process:
raise ValueError(
'ArrayMetadata Store contains different params: comparing with'
f' process_index={ref_process_index},'
f' process_index={process_index} is missing'
f' {len(missing_in_process)} params: {missing_in_process}.'
)
missing_in_ref_process = process_cache.keys() - ref_process_cache.keys()
if missing_in_ref_process:
raise ValueError(
'ArrayMetadata Store contains different params: comparing with'
f' process_index={process_index},'
f' process_index={ref_process_index} is missing'
f' {len(missing_in_ref_process)} params: {missing_in_ref_process}.'
)

def _validate_chunk_shape_and_write_shape(
self,
*,
ref_process_index: int,
ref_process_cache: dict[str, array_metadata_lib.SerializedArrayMetadata],
process_index: int,
process_cache: dict[str, array_metadata_lib.SerializedArrayMetadata],
) -> None:
"""Validates that chunk_shape and write_shape are the same for a param."""
for param_name, array_metadata in process_cache.items():
ref_array_metadata = ref_process_cache[param_name]
if array_metadata.chunk_shape != ref_array_metadata.chunk_shape:
raise ValueError(
'ArrayMetadata Store contains different chunk_shape for param:'
f' {param_name}. process_index={process_index} has'
f' {array_metadata.chunk_shape}, but'
f' process_index={ref_process_index} has'
f' {ref_array_metadata.chunk_shape}.'
)
if array_metadata.write_shape != ref_array_metadata.write_shape:
raise ValueError(
'ArrayMetadata Store contains different write_shape for param:'
f' {param_name}. process_index={process_index} has'
f' {array_metadata.write_shape}, but'
f' process_index={ref_process_index} has'
f' {ref_array_metadata.write_shape}.'
)


async def validate_all_array_metadatas(
validator: Validator,
array_metadata_store: Store,
directory: epath.Path,
) -> None:
"""Validates that all processes have the same array metadatas.
Args:
validator: The `Validator` instance for validation.
array_metadata_store: The `Store` instance for reading metadata from
`directory`.
directory: The checkpoint directory with array_metadatas subdir.
"""
array_metadatas = await array_metadata_store.read(directory)
if array_metadatas is not None:
assert isinstance(array_metadatas, dict) # read all processes.
validator.validate_all_array_metadatas(array_metadatas)
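
For reference, the synchronous Validator can also be exercised directly with an in-memory mapping, mirroring the unit tests in the next file; the import paths are assumed to match the modules touched in this change:

from orbax.checkpoint._src.metadata import array_metadata as array_metadata_lib
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib

validator = array_metadata_store_lib.Validator()
# Two processes reporting identical params, chunk_shape and write_shape:
# validation passes silently.
validator.validate_all_array_metadatas({
    0: [array_metadata_lib.SerializedArrayMetadata(
        param_name='a', write_shape=(1,), chunk_shape=(1,))],
    1: [array_metadata_lib.SerializedArrayMetadata(
        param_name='a', write_shape=(1,), chunk_shape=(1,))],
})
# A mismatched chunk_shape or write_shape would raise ValueError instead.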
182 changes: 175 additions & 7 deletions checkpoint/orbax/checkpoint/_src/metadata/array_metadata_store_test.py
@@ -14,6 +14,7 @@

"""Tests for `array_metadata_store` module."""

from typing import List
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@@ -33,19 +34,19 @@ def setUp(self):
self.checkpoint_dir = epath.Path(self.create_tempdir().full_path)
self.store = array_metadata_store_lib.Store()

def test_non_existing_checkpoint_dir(self):
async def test_non_existing_checkpoint_dir(self):
with self.assertRaisesRegex(
ValueError, 'Checkpoint directory does not exist'
):
_ = self.store.read(self.checkpoint_dir / 'unknown_dir')
_ = await self.store.read(self.checkpoint_dir / 'unknown_dir')

def test_non_existing_metadata_files(self):
self.assertIsNone(self.store.read(self.checkpoint_dir))
async def test_non_existing_metadata_files(self):
self.assertIsNone(await self.store.read(self.checkpoint_dir))

(self.checkpoint_dir / 'array_metadatas').mkdir(
parents=True, exist_ok=False
)
self.assertIsNone(self.store.read(self.checkpoint_dir))
self.assertIsNone(await self.store.read(self.checkpoint_dir))

async def test_write_and_read_single_process(self):
process_index = 0
@@ -74,7 +75,7 @@ async def test_write_and_read_single_process(self):
)

self.assertEqual(
self.store.read(self.checkpoint_dir, process_index=process_index),
await self.store.read(self.checkpoint_dir, process_index=process_index),
[
array_metadata_lib.SerializedArrayMetadata(
param_name='a',
@@ -107,7 +108,7 @@ async def test_write_and_read_multiple_process(self):
)

self.assertEqual(
self.store.read(self.checkpoint_dir, process_index=None),
await self.store.read(self.checkpoint_dir, process_index=None),
{
0: [
array_metadata_lib.SerializedArrayMetadata(
@@ -185,5 +186,172 @@ def __init__(self):
self.assertIsNone(store)


class ValidatorTest(parameterized.TestCase):

@parameterized.named_parameters([
dict(
testcase_name='empty_array_metadatas',
array_metadatas={0: []},
expected_error_regex=(
'ArrayMetadata Store contains no metadata for process_index=0'
),
),
dict(
testcase_name='different_number_of_params',
array_metadatas={
0: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
)
],
1: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='b', write_shape=(1,), chunk_shape=(1,)
),
],
},
expected_error_regex=(
'ArrayMetadata Store contains different number of params'
),
),
dict(
testcase_name='different_params',
array_metadatas={
0: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='c', write_shape=(1,), chunk_shape=(1,)
),
],
1: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='b', write_shape=(1,), chunk_shape=(1,)
),
],
},
expected_error_regex='ArrayMetadata Store contains different params',
),
dict(
testcase_name='different_chunk_shapes',
array_metadatas={
0: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='b', write_shape=(1,), chunk_shape=(2,)
),
],
1: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='b', write_shape=(1,), chunk_shape=(3,)
),
],
},
expected_error_regex=(
'ArrayMetadata Store contains different chunk_shape'
),
),
dict(
testcase_name='different_write_shapes',
array_metadatas={
0: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='b', write_shape=(2,), chunk_shape=(3,)
),
],
1: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='b', write_shape=(3,), chunk_shape=(3,)
),
],
},
expected_error_regex=(
'ArrayMetadata Store contains different write_shape'
),
),
dict(
testcase_name='single_process_array_metadatas',
array_metadatas={
0: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
)
]
},
expected_error_regex=None,
),
dict(
testcase_name='valid_array_metadatas',
array_metadatas={
0: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='b', write_shape=(2,), chunk_shape=(2,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='c', write_shape=(3,), chunk_shape=(3,)
),
],
1: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='b', write_shape=(2,), chunk_shape=(2,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='c', write_shape=(3,), chunk_shape=(3,)
),
],
2: [
array_metadata_lib.SerializedArrayMetadata(
param_name='a', write_shape=(1,), chunk_shape=(1,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='b', write_shape=(2,), chunk_shape=(2,)
),
array_metadata_lib.SerializedArrayMetadata(
param_name='c', write_shape=(3,), chunk_shape=(3,)
),
],
},
expected_error_regex=None,
),
])
def test_validate_all_array_metadatas(
self,
array_metadatas: dict[
int, List[array_metadata_lib.SerializedArrayMetadata]
],
expected_error_regex: str | None,
):
validator = array_metadata_store_lib.Validator()
if expected_error_regex is None:
validator.validate_all_array_metadatas(array_metadatas)
else:
with self.assertRaisesRegex(ValueError, expected_error_regex):
validator.validate_all_array_metadatas(array_metadatas)


if __name__ == '__main__':
absltest.main()
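
The Store tests above become async test methods to match the now-async read(). With only the standard library, the equivalent pattern is unittest.IsolatedAsyncioTestCase; this is an assumed-equivalent sketch, and the Orbax/absltest runner's async support may differ:

import asyncio
import unittest

class AsyncStoreTest(unittest.IsolatedAsyncioTestCase):

  async def test_read_missing_returns_none(self):
    # Stand-in for Store.read on a directory without array_metadatas.
    async def read(path):
      await asyncio.sleep(0)
      return None

    self.assertIsNone(await read('/tmp/does-not-exist'))

if __name__ == '__main__':
  unittest.main()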
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/async_utils.py
@@ -20,6 +20,8 @@
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.path import step as step_lib

# TODO(b/360190539): Why not use just asyncio.to_thread?


# TODO(b/360190539): This functionality should be provided by either an external
# library or Orbax should subclass epath.Path.
@@ -35,6 +37,8 @@ def async_makedirs(
)




def async_write_bytes(path: epath.Path, data: Any):
return asyncio_utils.as_async_function(path.write_bytes)(data)
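
The TODO above asks why these wrappers exist when asyncio.to_thread covers the same ground. A hedged sketch of the to_thread equivalents; the function names below are illustrative, not Orbax helpers:

import asyncio
import pathlib

async def async_makedirs_with_to_thread(path: pathlib.Path) -> None:
  # Same effect as wrapping path.mkdir in an async function: the blocking
  # call runs on the default executor while the event loop stays free.
  await asyncio.to_thread(path.mkdir, parents=True, exist_ok=True)

async def async_write_bytes_with_to_thread(
    path: pathlib.Path, data: bytes
) -> int:
  return await asyncio.to_thread(path.write_bytes, data)

async def _demo():
  base = pathlib.Path('/tmp/orbax_async_utils_demo')
  await async_makedirs_with_to_thread(base)
  print(await async_write_bytes_with_to_thread(base / 'x.bin', b'abc'))

asyncio.run(_demo())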

3 changes: 1 addition & 2 deletions checkpoint/orbax/checkpoint/_src/serialization/BUILD
@@ -35,8 +35,6 @@ py_library(
":serialization",
":tensorstore_utils",
":types",
"//checkpoint/orbax/checkpoint:future",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/arrays:subchunking",
"//checkpoint/orbax/checkpoint/_src/arrays:types",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
@@ -46,6 +44,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/multihost:multislice",
"//checkpoint/orbax/checkpoint/_src/path:async_utils",
"//checkpoint/orbax/checkpoint/_src/path:format_utils",
"//orbax/checkpoint/_src/futures:future",
"//orbax/checkpoint/_src/metadata:array_metadata_store",
],
)
23 changes: 4 additions & 19 deletions checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
@@ -35,10 +35,9 @@
from jax.experimental import layout
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import future
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.arrays import subchunking
from orbax.checkpoint._src.arrays import types as arrays_types
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import sharding as sharding_metadata
@@ -239,20 +238,6 @@ def _build_array_write_spec(
)


class _CommitFuture(future.Future):
"""Represents the result of a background commit."""

def __init__(self, coro, name: Optional[str] = None):
self._t = future.ThreadRaisingException(
name=name,
target=lambda: asyncio_utils.run_sync(coro),
)
self._t.start()

def result(self, timeout: Optional[int] = None) -> Any:
return self._t.join(timeout=timeout)


def check_input_arguments(*args):
l = None
for arg in args:
@@ -650,7 +635,7 @@ async def serialize(
_print_ts_debug_data(self._metadata_key, infos)
copied_values = [copy.deepcopy(v) for v in values]
return [
_CommitFuture(
future.CoroutineRunningFuture(
self._background_serialize(copied_values, infos, args),
name='np_type_handler',
)
@@ -1131,7 +1116,7 @@ async def serialize(
)

return [
_CommitFuture(
future.CoroutineRunningFuture(
self._background_serialize(values_on_host, infos, args),
name='array_type_handler',
)
@@ -1576,7 +1561,7 @@ async def serialize(
del args
# Copy is not needed since strings are passed by value.
return [
_CommitFuture(
future.CoroutineRunningFuture(
self._background_serialize(values, infos),
name='string_type_handler',
)
