This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,514 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Serialization routines for pytrees including array and non-array serialization.
"""
from __future__ import annotations
from os import PathLike
import os
import re
from typing import Any
from uuid import uuid4, UUID
import json
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
import shutil
import logging
import jax
from jax._src import distributed
from jax._src.api_util import flatten_axes
from jax._src.layout import Format
from jax.experimental import multihost_utils
from jax.experimental.array_serialization import tensorstore_impl as ts_impl
import jax.experimental.array_serialization.pytree_serialization_utils as utils
from jax._src import path as pathlib
import numpy as np
logger = logging.getLogger(__name__)

# Serializes `save` calls so only one checkpoint write runs at a time per process.
_THREADING_SAVE_LOCK = threading.Lock()

# URL prefixes that mark a checkpoint location as remote object storage.
_REMOTE_URL_PREFIXES = ['gs://', 's3://']
# File / directory names used inside a checkpoint directory.
_PYTREEDEF_FILE = "pytreedef.json"
_ARCHIVE_NAME = "archive.zip"
_USE_OCDBT = True  # a lot of the code relies on this being True
_MAX_PATH_LENGTH = 4096  # maximum checkpoint-path length allgathered across hosts
_ARRAY_STORE_DIRNAME = "array_store"
# Text format (and matching regex) describing array leaves in the pytreedef
# JSON, e.g. "Array(float32[2, 3])".
_ARRAY_TYPE_FORMAT = "Array({dtype}[{shape}])"
_ARRAY_TYPE_REGEX = r"Array\(([a-zA-Z0-9_]+)\[([0-9, ]*)\]\)"
_MAX_CONCURRENCY = 32  # worker threads in the shared serialization executor
_TIMEOUT_SEC = 30  # multi-host barrier timeout (seconds)
PyTreeT = Any
__all__ = ["save", "load", "load_pytreedef",
           "nonblocking_load", "nonblocking_save"]
def _get_unique_sync_key() -> str | None:
"""Generate a thread-local key for ensuring all host finish (de)serializing"""
if jax.process_count() == 1:
return None
# broadcast a thread-local unique barrier name
sync_key_unique = multihost_utils.broadcast_one_to_all(
np.frombuffer(uuid4().bytes, dtype=np.int32))
sync_key_id = UUID(bytes=np.array(sync_key_unique).tobytes())
return f"jax_sync_key_{str(sync_key_id)}"
def _is_str_same_on_all_hosts(path: str | PathLike[str]) -> bool:
"""All-gather the location of the checkpoint and check if it's the same."""
if jax.process_count() <= 1:
return False
path_b = str(path).encode("utf-8")
if len(path_b) > _MAX_PATH_LENGTH:
raise ValueError(f"Path exceeds maximum length of {_MAX_PATH_LENGTH} in"
" multiprocess case.")
path_array = np.concatenate([
np.frombuffer(path_b, dtype=np.uint8), np.zeros(
_MAX_PATH_LENGTH - len(path_b), dtype=np.uint8)])
path_array = multihost_utils.process_allgather(path_array)
return bool(np.all(path_array[0] == path_array[1:]))
def _sync_on_key(key: str | None, extra_tag: str = "") -> None:
if key is None:
return
full_key = f"{key}-{extra_tag}" if extra_tag else key
if (client := distributed.global_state.client) is not None:
client.wait_at_barrier(full_key, timeout_in_ms=_TIMEOUT_SEC * 1000)
def _is_array_like(x):
return isinstance(x, (jax.Array, np.ndarray))
def _leaf_to_desc(leaf) -> str:
if leaf is None:
return "null"
elif _is_array_like(leaf):
return _ARRAY_TYPE_FORMAT.format(
dtype=leaf.dtype.name, shape=", ".join(map(str, leaf.shape)))
else:
return type(leaf).__name__
def _desc_to_leaf(leaf_desc: str | None) -> str | None | jax.ShapeDtypeStruct:
if leaf_desc is None:
return None
if not re.match(_ARRAY_TYPE_REGEX, leaf_desc):
return leaf_desc
shape_dtype_match = re.match(_ARRAY_TYPE_REGEX, leaf_desc)
assert shape_dtype_match is not None
dtype_str, shape_str = shape_dtype_match.groups()
shape = [int(x.strip()) for x in shape_str.strip("]").strip().split(",")
if len(x.strip()) > 0]
return jax.ShapeDtypeStruct(shape, jax.numpy.dtype(dtype_str))
def _is_remote_path(path: str | PathLike[str]):
  """Return True when `path` points at remote object storage (gs/s3)."""
  # pathlib.Path collapses "//" to "/", so compare against e.g. "gs:/" as well
  path_str = str(path)
  return any(path_str.startswith(prefix[:-1])
             for prefix in _REMOTE_URL_PREFIXES)
def _norm_path(path: str | PathLike[str]) -> Any:
  """Normalize to a pathlib path; expand and resolve only local paths."""
  normed = pathlib.Path(path)
  if _is_remote_path(path):
    return normed
  return normed.expanduser().resolve()
def _rm_dir(root: Any) -> None:
  """Recursively delete a checkpoint directory (local or remote)."""
  if not _is_remote_path(root):
    shutil.rmtree(root)
  else:
    # remote epath objects implement their own recursive delete
    root.rmtree()
def _set_up_destination(root: str | PathLike[str], overwrite: bool,
                        pytree_repr: dict[str, Any], distinct_locations: bool,
                        sync_key: str | None) -> dict[str, Any]:
  """Inspect the destination, set it up for writing, potentially read existing data.

  Args:
    root: Destination checkpoint directory.
    overwrite: Whether an existing checkpoint at `root` may be deleted.
    pytree_repr: Serialized pytreedef representation to be written.
    distinct_locations: True when hosts write to different paths.
    sync_key: Multi-host barrier key (None in the single-process case).

  Returns:
    The pytree representation to write.

  Raises:
    RuntimeError: If `root` contains entries not produced by this module.
    ValueError: If `root` is non-empty and `overwrite` is False.
  """
  root = _norm_path(root)
  if overwrite:
    if root.exists() and len(list(root.iterdir())) > 0:
      # check that we're only deleting things that come from JAX
      # refuse to rm directories containing additional entries
      extra_member_paths = [
          path for path in list(root.iterdir()) if path.name not in
          (_PYTREEDEF_FILE, _ARCHIVE_NAME, _ARRAY_STORE_DIRNAME)]
      if len(extra_member_paths) != 0:
        raise RuntimeError(
            "Refusing to work on a directory that is not a previous checkpoint."
            f" Unrecognized paths: {extra_member_paths}. Remove them manually"
            f" if you're sure you want to use {root} as the checkpoint"
            " directory.")
    # for a shared location only process 0 deletes; everyone then syncs
    if (jax.process_index() == 0 or distinct_locations) and root.exists():
      _rm_dir(root)
    _sync_on_key(sync_key, "overwrite")
    return pytree_repr
  else:
    if (root.exists() and len(list(root.iterdir())) > 0):  # not empty
      raise ValueError(f"Files already exist at path: `{root}`, but you"
                       f" specified `{overwrite=}`")
    return pytree_repr
def _prepare_directory(root: str | PathLike[str], overwrite: bool,
                       pytreedef_repr: dict[str, Any], distinct_locations: bool,
                       sync_key: str | None):
  """Prepare the directory: check destination, potentially read existing data
  and overwrite.

  Raises:
    RuntimeError: If the destination directory cannot be created.
  """
  root = _norm_path(root)
  # prepare the destination directory, overwrite destination directory or error
  pytreedef_repr = _set_up_destination(
      root, overwrite, pytreedef_repr, distinct_locations, sync_key)
  # remote paths are created implicitly on write; locally only the creating
  # host (process 0, or every host when locations are distinct) calls mkdir
  if not _is_remote_path(root) and (distinct_locations
                                    or jax.process_index() == 0):
    root.mkdir(exist_ok=True)  # do not make parents, that's too much
    if not root.exists() or not root.is_dir():
      raise RuntimeError(f"Could not create destination directory at {root}")
  _sync_on_key(sync_key, "mkdir")
  return pytreedef_repr
def _write_arrays(array_store_path: Any, arrs: list[Any],
                  arr_leaf_ids: list[int], ts_specs: list[Any | None],
                  distinct_locations: bool):
  """Serialize `arrs` to the tensorstore array store, one subdir per leaf.

  Args:
    array_store_path: Root directory of the array store.
    arrs: Array leaves to write.
    arr_leaf_ids: Flat leaf index of each array; names its subdirectory.
    ts_specs: Optional per-array tensorstore spec overrides (None entries use
      the defaults).
    distinct_locations: True when hosts write to different paths.
  """
  paths = [array_store_path / str(leaf_id) for leaf_id in arr_leaf_ids]
  # in the shared-location multi-host case each process writes its own
  # per-process store, later combined by `_finalize_array_store`
  process_idx = None
  if not distinct_locations and jax.process_count() > 1:
    process_idx = jax.process_index()
  default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT,
                                                   process_idx=process_idx,
                                                   arr=arr)
                      for (path, arr) in zip(paths, arrs)]
  # user-provided specs override the defaults where given
  processed_ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec)
                        for (default_ts_spec, ts_spec)
                        in zip(default_ts_specs, ts_specs)]
  # sanity check the ts specs
  if len(ts_specs) > 0:  # verify the base path is shared for all arrays
    expected_path = processed_ts_specs[0]["kvstore"]["base"]["path"]  # shared base path
    for ts_spec, arr in zip(processed_ts_specs, arrs):
      assert ts_spec is not None
      ts_impl.verify_tensorstore_spec(ts_spec, arr, expected_path,
                                      ocdbt=_USE_OCDBT, check_metadata=True)

  async def _serialize_arrays():
    await asyncio.gather(*[
        ts_impl.async_serialize(arr, ts_spec, primary_host=None)
        for (arr, ts_spec) in zip(arrs, processed_ts_specs)])

  asyncio.run(_serialize_arrays())
def _finalize_array_store(kvstore_path, distinct_locations: bool):
  """When multiple processes are writing, they must write to a per-process
  location followed by combining them via no-copy links to the final location.
  """
  # only in multiprocess case and only process 0
  if distinct_locations or jax.process_count() == 1 or jax.process_index() != 0:
    return
  # the key name is irrelevant; get_tensorstore_spec only needs some path
  # under the store to derive the ocdbt kvstore layout
  dummy_key_path = os.path.join(kvstore_path, "dummy_key")
  combined_kvstore = ts_impl.get_tensorstore_spec(
      dummy_key_path, ocdbt=True, process_idx=None)["kvstore"]
  children_kvstores = [ts_impl.get_tensorstore_spec(
      dummy_key_path, ocdbt=True, process_idx=i)["kvstore"]
      for i in range(jax.process_count())]
  # drop the per-key path: combining operates on whole kvstores
  _ = combined_kvstore.pop("path")
  _ = [kvstore.pop("path") for kvstore in children_kvstores]
  asyncio.run(ts_impl.combine_kvstores(combined_kvstore, children_kvstores))
def _write_pytreedef(directory: Any, pytree_repr: dict[str, Any],
                     distinct_locations: bool):
  """Write the pytreedef to the destination directory and aux data to the archive."""
  # with a shared destination only process 0 writes the pytreedef file
  if jax.process_index() != 0 and not distinct_locations:
    return
  pytreedef_path = _norm_path(directory) / _PYTREEDEF_FILE
  pytreedef_path.write_text(json.dumps(pytree_repr, indent=2))
def _tree_broadcast(a, b, is_leaf=lambda x: x is None):
"""Broadcast the prefix tree `a` to the full tree `b`
Uses `flatten_axes` for better error messages on mismatched arity but allowing
for custom is_leaf in the `a` and `b` trees.
"""
a_leaves, a_struct = jax.tree.flatten(a, is_leaf=is_leaf)
a_idx2leaf_map = dict(enumerate(a_leaves))
a_idx = jax.tree.unflatten(a_struct, a_idx2leaf_map.keys())
a_idx_broadcast = flatten_axes("tree_broadcast",
jax.tree.structure(b, is_leaf=is_leaf), a_idx)
return jax.tree.map(lambda i: a_idx2leaf_map[i], a_idx_broadcast)
# shared thread pool used for pytreedef/array writes and the nonblocking APIs
_serialization_executor = ThreadPoolExecutor(max_workers=_MAX_CONCURRENCY)
def save(data: PyTreeT, directory: str | PathLike[str], *,
         overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None:
  """Serialize a pytree of JAX arrays to the given directory.

  Flattens `data`, writes its structure description alongside every array
  leaf so that the full structure can be reconstructed later with `load`.

  This is a simple experimental array serialization API, for anything more
  complex and for all checkpointing prefer: https://github.com/google/orbax

  Args:
    data: The data structure to be saved. Arbitrary composition of JAX arrays,
      including nested structures.
    directory: The directory path where the data will be saved. A local path
      or a remote URL (e.g., gs://, s3://). For remote URLs, `etils` is
      required.
    overwrite: If True, any existing directory with the same name will be
      overwritten.
    ts_specs: Optional tensorstore specs to use for serialization. If None,
      defaults to using the default tensorstore specs.

  Example:
    >>> data = {"a": jnp.array([1, 2]), "b": None}
    >>> save(data, directory)
  """
  # only one checkpoint write may run at a time in this process
  with _THREADING_SAVE_LOCK:
    _save(data, directory, overwrite=overwrite, ts_specs=ts_specs)
def _save(data: PyTreeT, directory: str | PathLike[str], *,
          overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None:
  """Implementation of `save`; see `save` for the public contract.

  Raises:
    RuntimeError: If saving to a remote URL without `etils` installed.
    ValueError: If leaves are not None/array-like, or if hosts disagree on
      the destination path.
  """
  sync_key = _get_unique_sync_key()  # get a synchronization key for multi-host

  if _is_remote_path(directory) and not pathlib.epath_installed:
    raise RuntimeError("For saving to remote URLs (e.g., gs, s3) you need the"
                       " `etils` module installed. You can install it using"
                       " `pip install etils`.")
  # broadcast the (prefix tree of) tensorstore specs over the full data tree
  ts_specs = _tree_broadcast(ts_specs, data,
                             is_leaf=ts_impl.is_tensorstore_spec_leaf)
  data_flat, pytreedef = jax.tree.flatten(data, is_leaf=lambda x: x is None)
  if not all(x is None or _is_array_like(x) for x in data_flat):
    raise ValueError("For serialization, all leaves must be either None or"
                     " jax.Array-like objects.")
  distinct_locations = not _is_str_same_on_all_hosts(directory)
  if jax.process_count() > 1 and distinct_locations:
    raise ValueError(
        "Saving to different locations on different hosts is not supported,"
        " because it is extremely fragile. Consider using a single location.")
  root = _norm_path(directory)

  # 1. serialize the pytree #################################
  pytreedef_repr = utils.serialize_pytreedef(pytreedef)
  pytreedef_repr[utils._LEAF_IDS_KEY] = jax.tree.map(_leaf_to_desc, data_flat)
  pytreedef_repr = _prepare_directory(
      root, overwrite, pytreedef_repr, distinct_locations, sync_key)
  futures = []
  futures.append(_serialization_executor.submit(
      _write_pytreedef, root, pytreedef_repr, distinct_locations))

  # 2. serialize arrays #####################################
  array_store_path = root / _ARRAY_STORE_DIRNAME
  arrs = [data for data in data_flat if _is_array_like(data)]
  arr_leaf_ids = [i for i, data in enumerate(data_flat) if _is_array_like(data)]
  # keep only the tensorstore specs belonging to array leaves
  ts_specs_flat = jax.tree.leaves(ts_specs,
                                  is_leaf=ts_impl.is_tensorstore_spec_leaf)
  ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids]
  futures.append(_serialization_executor.submit(
      _write_arrays, array_store_path, arrs, arr_leaf_ids, ts_specs_flat,
      distinct_locations))

  # 3. wait for all futures to complete #####################
  _ = [fut.result() for fut in futures]
  _sync_on_key(sync_key, "array_serialization")

  # 4. finalize the array writing ###########################
  if len(arr_leaf_ids) > 0 and _USE_OCDBT:
    _serialization_executor.submit(  # call from a thread to not nest asyncio
        _finalize_array_store, array_store_path, distinct_locations).result()
  # we are done with all async ops here, we can block ####
  _sync_on_key(sync_key, "end")
def _read_arrays(array_store_path: str | PathLike[str], arr_leaf_ids: list[int],
                 ts_specs: list[Any], shardings: list[Any]):
  """Deserialize arrays from the store; returns a {leaf_id: array} dict.

  Args:
    array_store_path: Root directory of the array store.
    arr_leaf_ids: Flat leaf indices to read; each names a subdirectory.
    ts_specs: Optional per-leaf tensorstore spec overrides.
    shardings: Per-leaf sharding (or ShapeDtypeStruct carrying a sharding).
  """
  # array_store_path = root / _LEAF_DATA_DIR / _ARRAY_STORE_DIRNAME
  arr_store_path = _norm_path(array_store_path)
  arr_paths = [arr_store_path / str(leaf_id) for leaf_id in arr_leaf_ids]
  # byte limiter to limit number of parallel reads, resizes to largest read
  byte_limiter = ts_impl._LimitInFlightBytes(10 * 1024 ** 3)  # 10 GB
  default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT,
                                                   process_idx=None)
                      for path in arr_paths]
  # user-provided specs override the defaults where given
  ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec)
              for (default_ts_spec, ts_spec) in zip(default_ts_specs, ts_specs)]
  if len(ts_specs) > 0:  # verify the base path is shared for all arrays
    expected_path = ts_specs[0]["kvstore"]["base"]["path"]  # shared base path
    for ts_spec in ts_specs:
      ts_impl.verify_tensorstore_spec(ts_spec, arr=None, path=expected_path,
                                      ocdbt=_USE_OCDBT, check_metadata=False)

  async def _deserialize_arrays():
    return await asyncio.gather(*[
        ts_impl.async_deserialize(sharding, ts_spec, byte_limiter=byte_limiter)
        for (sharding, ts_spec) in zip(shardings, ts_specs)])

  return dict(zip(arr_leaf_ids, asyncio.run(_deserialize_arrays())))
def load_pytreedef(directory: str | PathLike[str]) -> PyTreeT:
  """Read back only the pytree structure stored in a checkpoint directory.

  This is a simple experimental array serialization API, for anything more
  complex and for all checkpointing prefer: https://github.com/google/orbax

  Args:
    directory: Directory path to load from.
  Returns:
    The loaded pytree with arrays represented as jax.ShapeDtypeStruct's.
  """
  assert not _is_remote_path(directory) or pathlib.epath_installed, (
      "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`"
      " module installed. You can install it using `pip install etils`.")
  raw_tree = json.loads((_norm_path(directory) / _PYTREEDEF_FILE).read_text())
  treedef = utils.deserialize_pytreedef(raw_tree)
  leaf_descs = raw_tree[utils._LEAF_IDS_KEY]
  return jax.tree.unflatten(treedef, [_desc_to_leaf(d) for d in leaf_descs])
def load(directory: str | PathLike[str], shardings: PyTreeT, *,
         mask: PyTreeT | None = None, ts_specs: PyTreeT | None = None
         ) -> PyTreeT:
  """Loads and reconstructs a data structure from a directory.

  This is a simple experimental array serialization API, for anything more
  complex and for all checkpointing prefer: https://github.com/google/orbax

  Args:
    directory: Directory path where the data is stored.
    shardings: Sharding strategy for array objects, either a Sharding or a
      ShapeDtypeStruct with a Sharding/Format.
    mask: boolean prefix tree for partial loading, will return None for False
      leaves.
    ts_specs: Optional tensorstore specs to use for deserialization. If None,
      defaults to using the default tensorstore specs.
  Returns:
    Reconstructed data.
  Example:
    >>> save(data, directory)
    >>> restored_data = load(directory, SingleDeviceSharding(jax.devices()[0]))
  """
  assert not _is_remote_path(directory) or pathlib.epath_installed, (
      "For checkpointing using remote URLs (e.g., gs, s3) you need `etils`"
      " module installed. You can install it using `pip install etils`.")
  root = _norm_path(directory)
  assert root.is_dir(), f"Checkpoint directory {root} does not exist"
  is_leaf = lambda x: x is None

  # deserialize PyTreeDef
  pytree = load_pytreedef(directory)

  # broadcast the (prefix) shardings and tensorstore specs to the full pytree
  shardings = _tree_broadcast(shardings, pytree)
  ts_specs = _tree_broadcast(ts_specs, pytree,
                             is_leaf=ts_impl.is_tensorstore_spec_leaf)
  if mask is not None:
    # null-out masked-off subtrees so their leaves are skipped below
    _prefix_mask = lambda m, x: jax.tree.map(lambda _: None, x) if not m else x
    pytree = jax.tree.map(_prefix_mask, mask, pytree)
  pytreedef = jax.tree.structure(pytree, is_leaf=is_leaf)
  leaf_ids_flat = jax.tree.leaves(pytree, is_leaf=is_leaf)
  shardings_flat = jax.tree.leaves(shardings, is_leaf=is_leaf)
  if any(isinstance(shardings, Format) for shardings in shardings_flat):
    raise NotImplementedError(
        "Deserialization with `Format` instead of `Sharding` is not currently"
        " supported. Pass ShapeDtypeStruct(shape, dtype, sharding=format)"
        " instead.")
  ts_specs_flat = jax.tree.leaves(ts_specs,
                                  is_leaf=ts_impl.is_tensorstore_spec_leaf)

  # deserialize array objects
  arr_leaf_ids = [i for i, leaf_id in enumerate(leaf_ids_flat)
                  if leaf_id is not None]
  shardings_flat = [shardings_flat[i] for i in arr_leaf_ids]
  ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids]
  arrs_fut = _serialization_executor.submit(
      _read_arrays, root / _ARRAY_STORE_DIRNAME, arr_leaf_ids, ts_specs_flat,
      shardings_flat)
  arrs = arrs_fut.result()
  # masked-off / None leaves come back as None
  filled_values = [arrs.get(i, None) for i, _ in enumerate(leaf_ids_flat)]
  return jax.tree.unflatten(pytreedef, filled_values)
def nonblocking_save(data: PyTreeT, directory: str | PathLike[str], *,
                     overwrite: bool = True, ts_specs: PyTreeT | None = None
                     ) -> utils.PyTreeFuture:
  """Nonblocking alias of save, return an awaitable future with a pytree stub.

  This is a simple experimental array serialization API, for anything more
  complex and for all checkpointing prefer: https://github.com/google/orbax

  Examples:
    >>> fut = nonblocking_save(data, directory)
    >>> print(fut.pytree)  # a pytree of jax.ShapeDtypeStruct's
    >>> print(fut.result())  # None, blocking until the serialization is done
  """
  # kick off the (blocking) save on the shared executor right away
  fut = utils.PyTreeFuture(_serialization_executor.submit(
      save, data, directory, overwrite=overwrite, ts_specs=ts_specs))

  def _as_stub(leaf):
    # represent arrays by shape/dtype only; keep other leaves as-is
    if _is_array_like(leaf):
      return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype)
    return leaf

  fut.pytree = jax.tree.map(_as_stub, data)
  return fut
def nonblocking_load(directory: str | PathLike[str], shardings: PyTreeT, *,
                     mask: PyTreeT | None = None,
                     ts_specs: PyTreeT | None = None) -> utils.PyTreeFuture:
  """Nonblocking alias of load, return an awaitable future with a pytree stub.

  This is a simple experimental array serialization API, for anything more
  complex and for all checkpointing prefer: https://github.com/google/orbax

  Examples:
    >>> fut = nonblocking_load(directory)
    >>> print(fut.pytree)  # a pytree of jax.ShapeDtypeStruct
    >>> print(fut.result())  # the fully populated pytree
  """
  # TODO(rdyro): the awaitable future output is a workaround
  # it should return the fully populated pytree instead of just
  # jax.ShapeDtypeStruct for arrays by constructing them asynchronously
  structure_stub = load_pytreedef(directory)
  fut = utils.PyTreeFuture(_serialization_executor.submit(
      load, directory, shardings, mask=mask, ts_specs=ts_specs))
  fut.pytree = structure_stub
  return fut
@@ -0,0 +1,84 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for representing pytreedefs in a serializable format.
"""
import base64
import logging
from types import ModuleType
from concurrent.futures import Future
from typing import Any, TypeVar
import jax
from jax._src.export.serialization import (flatbuffers, _serialize_pytreedef,
_deserialize_pytreedef_to_pytree,
ser_flatbuf)
from jax.export import register_pytree_node_serialization
T = TypeVar("T")
PickleModule = ModuleType
logger = logging.getLogger(__name__)
_READABLE_PYTREE_SERIALIZATION = True
_TREE_REPR_KEY = "__jax_pytreedef_repr"
_LEAF_IDS_KEY = "__jax_leaf_ids"
_NOT_REGISTERED_MESSAGE = (
" * If you want to register a custom leaf, register it via"
" `register_pytree_leaf_serialization` first.\n"
" * If you want to register a custom node, register is via"
" `register_pytree_node_serialization`")
__all__ = ["serialize_pytreedef", "deserialize_pytreedef",
"register_pytree_node_serialization"]
class PyTreeFuture(Future[Any]):
  """A wrapper around a Future that makes it look like an async function.

  Wraps a `concurrent.futures.Future` and adds a `pytree` attribute that
  callers fill with a stub of the value being produced. Awaiting the instance
  cooperatively busy-polls until the wrapped future completes.

  Fix: the original never called `Future.__init__` and only overrode
  `done`/`result`, so every other inherited method (e.g. `exception`,
  `cancel`, `add_done_callback`) crashed on missing base state. The base is
  now initialized and the common query methods delegate to the wrapped
  future.
  """

  def __init__(self, future: Future[Any]):
    super().__init__()  # initialize base state so inherited methods are safe
    self._future, self.pytree = future, None

  def done(self) -> bool:
    """Return True once the wrapped future has completed."""
    return self._future.done()

  def result(self, *args, **kw):
    """Block for and return the wrapped future's result."""
    return self._future.result(*args, **kw)

  def exception(self, *args, **kw):
    """Return the exception raised by the wrapped future, if any."""
    return self._future.exception(*args, **kw)

  def cancel(self) -> bool:
    """Attempt to cancel the wrapped future."""
    return self._future.cancel()

  def cancelled(self) -> bool:
    """Return True if the wrapped future was cancelled."""
    return self._future.cancelled()

  def add_done_callback(self, fn) -> None:
    """Attach `fn` to run when the wrapped future completes."""
    self._future.add_done_callback(fn)

  def __await__(self):
    # cooperative busy-wait: yield control until the wrapped future finishes
    while not self.done():
      yield
    return self.result()

  def __repr__(self):
    return f"PyTreeFuture(done={self.done()}, pytree={self.pytree})"
def serialize_pytreedef(node) -> dict[str, Any]:
  """Encode a PyTreeDef as a JSON-compatible dict (base64'd flatbuffer)."""
  builder = flatbuffers.Builder(65536)
  builder.Finish(_serialize_pytreedef(builder, node))
  encoded_tree = base64.b64encode(builder.Output()).decode("utf-8")
  return {_TREE_REPR_KEY: encoded_tree,
          _LEAF_IDS_KEY: list(range(node.num_leaves))}
def deserialize_pytreedef(pytreedef_repr: dict[str, Any]):
  """Decode a dict produced by `serialize_pytreedef` back into a PyTreeDef."""
  raw = base64.b64decode(pytreedef_repr[_TREE_REPR_KEY])
  decoded = _deserialize_pytreedef_to_pytree(ser_flatbuf.PyTreeDef.GetRootAs(raw))
  return jax.tree.structure(decoded)
@@ -0,0 +1,374 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Array serialization and deserialization."""
from __future__ import annotations
import abc
import asyncio
from collections.abc import Callable, Sequence
import functools
import itertools
import logging
import re
import threading
import time
from typing import Any
import jax
from jax._src import array
from jax._src import distributed
from jax._src import sharding
from jax._src import typing
from jax._src import util
from jax._src.layout import Format
from jax._src.lib import _jax
from jax.experimental.array_serialization import tensorstore_impl as ts_impl
# ruff: noqa: F401
# import tensorstore-backed methods for backward compatibility.
from jax.experimental.array_serialization.tensorstore_impl import (
_run_deserialization as run_deserialization,
_run_serialization as run_serialization,
async_serialize, async_deserialize, _TS_CONTEXT as TS_CONTEXT,
_DEFAULT_BASE_DRIVER as _DEFAULT_DRIVER, _LimitInFlightBytes)
import tensorstore as ts
# for compatibility with older zarr format
_get_metadata = functools.partial(ts_impl._get_tensorstore_metadata,
                                  driver='zarr')
get_tensorstore_spec = functools.partial(ts_impl.get_tensorstore_spec,
                                         driver='zarr', ocdbt=False)

# value process 0 writes to the kv-store once a checkpoint commit succeeds
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
# monotonically increasing id used to name per-checkpoint barrier/kv keys
_module_unique_count = itertools.count()
_DISTRIBUTED_SYSTEM_MSG = (
    'Please initialize the distributed system via '
    '`jax.distributed.initialize()` at the start of your program.')
# prefixes / tensorstore drivers that identify remote (cloud) storage
_REMOTE_URL_PREFIXES = ['gs://', 's3://']
_REMOTE_DRIVER_VALIDATIONS = [
    {'driver': 'gcs', 'path_regex': None},
    {'driver': 's3', 'path_regex': None},
]
class BarrierTimeoutError(Exception):
  """Raised when the cross-process checkpoint barrier times out."""
  pass
# appended to BarrierTimeoutError messages to aid debugging
_BARRIER_TIMED_OUT_MSG = (
    "Suggestions for possible fixes:\n"
    "* Check the logs to see if one or more processes failed.\n"
    "* Make sure the training and checkpointing endpoints are close geographically.\n"
    "* Try increasing the timeout you pass to GlobalAsyncCheckpointManager.")

logger = logging.getLogger(__name__)
def is_remote_storage(tspec: dict[str, Any] | str) -> bool:
  """Detect if user is using cloud storages.

  This can detect common defines and unable to detect some corner cases such as
  using gcsfuse.
  """
  if isinstance(tspec, str):
    # tspec is a KvStoreUrl string: check the URL scheme prefix
    return re.match(rf'^({"|".join(_REMOTE_URL_PREFIXES)})', tspec) is not None

  # recurse into nested kvstore specs
  for nested_key in ('base', 'kvstore'):
    if nested_key in tspec:
      return is_remote_storage(tspec[nested_key])

  if 'driver' in tspec:
    for rule in _REMOTE_DRIVER_VALIDATIONS:
      if tspec['driver'] == rule['driver']:
        if rule['path_regex'] is None:
          return True
        # check if path matches the regex.
        if re.match(rule['path_regex'], tspec['path']):
          return True
  return False
def _get_key(key: int):
return f'tensorstore_checkpoint_{key}'
class GlobalAsyncCheckpointManagerBase(util.StrictABC):
  """Interface for checkpointing GDAs asynchronously.

  This class manages the state of an ongoing asynchronous checkpoint.

  For example, say a checkpoint happens on every step. If you checkpoint on
  step 1 and after some computation the model is on checkpoint 2. But step 1's
  checkpoint hasn't finished committing to the storage layer yet. So until that
  is finished, checkpoint for step 2 will need to be blocked. Maintaining a
  class allows to maintain that state.

  Examples:

  Below is a simplified training loop:

  ```
  # Call this at the start of your program.
  jax.distributed.initialize()

  manager = GlobalAsyncCheckpointManager()

  # Restore checkpoint if available or initialize the train_state from
  # init_fn().
  train_state = manager.deserialize(...)

  while ...:
    if step % num_steps_between_checkpoints == 0:
      manager.serialize(train_state, temp_checkpoint_dir=...,
                        final_checkpoint_dir=...)
      train_state = train_step(train_state, input)
      # This is a non-blocking call.
      manager.check_for_errors()

  manager.serialize(train_state, temp_checkpoint_dir=...,
                    final_checkpoint_dir=...)
  # Wait before the end of the program for the checkpoint to finish. This is a
  # blocking call.
  manager.wait_until_finished()
  ```
  """

  @abc.abstractmethod
  def check_for_errors(self):
    """Checks if any errors have been raised in the child thread.

    This is a non-blocking call that can be called in the main thread.
    """

  @abc.abstractmethod
  def wait_until_finished(self):
    """Blocks until serialization has finished."""

  @abc.abstractmethod
  def serialize(self, arrays, tensorstore_specs, *,
                on_commit_callback: Callable[[], None]):
    """Serializes GDAs to TensorStore."""

  @abc.abstractmethod
  def deserialize(self, shardings: Sequence[sharding.Sharding],
                  tensorstore_specs: Sequence[dict[str, Any]],
                  global_shapes: Sequence[array.Shape] | None = None,
                  dtypes: Sequence[typing.DTypeLike] | None = None):
    """Deserializes GDAs from TensorStore."""
class AsyncManager:
  """Runs checkpoint commits in a background thread and syncs hosts."""

  def __init__(self, timeout_secs=300):
    # timeout applied to the cross-process barrier and kv-store get
    self._timeout_secs = timeout_secs
    self._timeout_in_ms = self._timeout_secs * 1000

    self._commit_futures = None  # tensorstore commit futures of current write
    self._thread = None  # background thread waiting on the commit futures
    self._exception = None  # exception raised in the background thread, if any

    if jax.process_count() > 1 and distributed.global_state.client is None:
      raise ValueError(_DISTRIBUTED_SYSTEM_MSG)
    self._client = distributed.global_state.client
    self._count: int | None = None  # unique id of the in-flight checkpoint

  def __del__(self):
    if self._thread is not None and self._thread.is_alive():
      logger.warning('Please add `.wait_until_finished()` in the main thread '
                     'before your program finishes because there is a '
                     'possibility of losing errors raised if the '
                     'this class is deleted before writing is completed.')

  def _thread_func(self):
    """Background-thread body; any exception is stored for `check_for_errors`."""
    try:
      current_process = jax.process_index()
      process_count = jax.process_count()
      logger.info('Starting commit to storage layer by process: %s',
                  current_process)
      thread_start_time = time.time()
      assert self._commit_futures is not None
      # block until every local tensorstore write has committed
      for future in self._commit_futures:
        future.result()
      logger.info('Finished committing to storage layer by process: %s',
                  current_process)

      key_for_barrier: str | None = None
      if process_count > 1:
        assert self._client is not None
        assert self._count is not None
        # All processes will wait at the barrier. When all processes are at the
        # barrier, the barrier will be satisfied. If not, then it will timeout.
        key_for_barrier = _get_key(self._count)
        logger.info('Key used for barrier is %s for process %s',
                    key_for_barrier, current_process)
        self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
        logger.info('Finished waiting at barrier for process %s',
                    current_process)

      if current_process == 0:
        # only process 0 runs the commit callback and publishes success
        if self._on_commit_callback is not None:
          self._on_commit_callback()
          logger.info('on_commit_callback successfully ran!')
        if process_count > 1:
          assert self._client is not None
          assert key_for_barrier is not None
          self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS)
          logger.info('Process 0 successfully set key %s in the kv store',
                      key_for_barrier)

      jax.monitoring.record_event_duration_secs(
          '/jax/checkpoint/write/async/thread_duration_sec',
          time.time() - thread_start_time)

    except Exception as e:  # pylint: disable=broad-except
      # stored and re-raised on the main thread by check_for_errors()
      self._exception = e

  def _start_async_commit(self, on_commit_callback):
    """Spawn the background thread that waits on commits and syncs hosts."""
    self._count = next(_module_unique_count)

    self._on_commit_callback = on_commit_callback
    self._thread = threading.Thread(target=self._thread_func)
    self._thread.start()

  def check_for_errors(self):
    """Re-raise (at most once) any exception captured in the child thread."""
    if self._exception is not None:
      # Clears self._exception so it is only raised once.
      exception = self._exception
      self._exception = None
      if (isinstance(exception, _jax.JaxRuntimeError) and
          'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)):
        raise BarrierTimeoutError(
            '\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG]))
      raise exception  # pylint: disable=raising-bad-type

  def wait_until_finished(self):
    """Block until the in-flight write (if any) completes; surface errors."""
    if self._thread is not None:
      self._thread.join()
      self._thread = None
      logger.info('Thread joined successfully')

    self.check_for_errors()
    logger.info('Error check finished successfully')

    if jax.process_count() > 1 and self._count is not None:
      assert self._client is not None
      # Block until process 0 writes success value to the key value store.
      # If it fails to write it, then `blocking_key_value_get` will time out.
      get_key = _get_key(self._count)
      self._client.blocking_key_value_get(get_key, self._timeout_in_ms)
      logger.info('blocking_key_value_get on key %s was successfully '
                  'completed.', get_key)

  def _add_futures(self, futures: Sequence[ts.Future]):
    """Register the tensorstore commit futures for the current write."""
    self._commit_futures = futures
class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
  """Responsible for serializing GDAs via TensorStore."""

  def serialize(
      self,
      arrays,
      tensorstore_specs,
      *,
      on_commit_callback: Callable[[], None] | None = None,
      transaction: ts_impl.Transaction | None = None,
  ):
    """Serializes Arrays or Arrays via TensorStore asynchronously.

    TensorStore writes to a storage layer in 2 steps:
    * Reading/copying from the source after which the source can be modified.
      * Returns a copy future.
    * Writing/committing to the storage layer.
      * Returns a commit future.

    In asynchronous mode, the serialization waits for the commit future to
    finish in a separate thread allowing other computation to proceed.

    Args:
      arrays: Arrays or Arrays that should be serialized.
      tensorstore_specs: TensorStore specs that are used to serialize GDAs or
        Arrays.
      on_commit_callback: This callback will be executed after all processes
        have finished writing their checkpoints to disk. Filesystems where
        atomic rename operations are supported, you can rename from the
        temporary directory to the final directory. On GCS, you write to the
        final directory directly and in `on_commit_callback` you write a success
        file indicating that the serialization was successful because GCS does
        not support atomic rename operations.
      transaction: Optional TensorStore transaction to use.
    """
    logger.info('Waiting for previous serialization to finish.')
    self.wait_until_finished()

    commit_futures: list[ts_impl.Future] = []

    async def _run_serializer():
      # One async_serialize per (array, spec) pair; each call appends its
      # commit future(s) into the shared `commit_futures` list.
      future_writer = jax.tree_util.tree_map(
          lambda arr_inp, tensorstore_spec: ts_impl.async_serialize(
              arr_inp,
              tensorstore_spec,
              commit_future=commit_futures,
              transaction=transaction,
          ),
          arrays,
          tensorstore_specs,
      )
      return await asyncio.gather(*future_writer)

    asyncio.run(_run_serializer())

    self._add_futures(commit_futures)

    # Used in wait_until_finished to check on process != 0, if the checkpoint
    # has finished writing.
    self._start_async_commit(on_commit_callback)

  def serialize_with_paths(
      self,
      arrays: Sequence[jax.Array],
      paths: Sequence[str],
      *,
      on_commit_callback: Callable[[], None] | None = None,
      transaction: ts_impl.Transaction | None = None,
  ):
    """Serializes `arrays`, deriving one tensorstore spec per path."""
    tspecs = jax.tree.map(get_tensorstore_spec, paths)
    return self.serialize(
        arrays,
        tspecs,
        on_commit_callback=on_commit_callback,
        transaction=transaction,
    )

  def deserialize(self, shardings: Sequence[sharding.Sharding | Format],
                  tensorstore_specs: Sequence[dict[str, Any]],
                  global_shapes: Sequence[array.Shape] | None = None,
                  dtypes: Sequence[typing.DTypeLike] | None = None,
                  concurrent_gb: int = 32):
    """Deserializes arrays, first waiting for any in-flight serialization."""
    self.wait_until_finished()
    return ts_impl._run_deserialization(
        shardings, tensorstore_specs, global_shapes, dtypes, concurrent_gb)

  def deserialize_with_paths(
      self, shardings: Sequence[sharding.Sharding],
      paths: Sequence[str],
      global_shapes: Sequence[array.Shape] | None = None,
      dtypes: Sequence[typing.DTypeLike] | None = None,
      concurrent_gb: int = 32):
    """Deserializes from `paths`, deriving one tensorstore spec per path."""
    tspecs = jax.tree.map(get_tensorstore_spec, paths)
    return self.deserialize(shardings, tspecs, global_shapes, dtypes,
                            concurrent_gb)
@@ -0,0 +1,604 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from functools import partial
import functools
import os
from os import PathLike
import re
from typing import Any
from collections.abc import Awaitable, Callable, Sequence
import math
import logging
import jax
from jax import numpy as jnp
from jax._src import array
from jax._src.layout import Format
from jax._src import typing
from jax._src.sharding_impls import make_single_device_sharding
import numpy as np
import tensorstore as ts
# Default TensorStore array driver used for newly written arrays.
_TS_ARRAY_DRIVER = "zarr3"

# Process-wide TensorStore context: caps file I/O and data-copy concurrency
# and limits each in-memory cache pool to 10 GB.
_TS_CONTEXT = ts.Context({
    'file_io_concurrency': {'limit': 128},
    'cache_pool': {'total_bytes_limit': 10_000_000_000},  # 10 GB RAM limit
    'cache_pool#remote': {'total_bytes_limit': 10_000_000_000},
    'data_copy_concurrency': {'limit': 128}
})
# Default chunk layout used when opening arrays.
_TS_CHUNK_LAYOUT = ts.ChunkLayout(
    chunk=ts.ChunkLayout.Grid(elements=100_000_000),  # 100M (800MB for float64) file size
)
# Base kvstore driver used for plain local filesystem paths.
_DEFAULT_BASE_DRIVER = 'file'
# Per-process subdirectory name template for multi-process writes.
_PROCESS_DIR_FORMAT = "process_{}"
# Target upper bound for the size of a single file written to disk.
_FILE_SIZE_TARGET = 2 * 1024 ** 3  # 2 GB

# Re-exported so callers do not need to import tensorstore directly.
Future, Transaction = ts.Future, ts.Transaction

logger = logging.getLogger(__name__)
# Lifted from T5X.
class _LimitInFlightBytes:
"""Limits host scratch memory usage when reading/writing checkpoints per process."""
def __init__(self, host_memory_bytes_limit: int):
self._max_bytes = host_memory_bytes_limit
self._available_bytes = host_memory_bytes_limit
self._cv = asyncio.Condition(lock=asyncio.Lock())
async def wait_for_bytes(self, requested_bytes):
if requested_bytes > self._max_bytes:
logger.debug("A single array item requests more bytes than we reserved"
" space for in the parallel pool: %d > %d. Increasing the"
" limit to %d.", requested_bytes, self._max_bytes,
requested_bytes)
bytes_currently_used = self._max_bytes - self._available_bytes
self._max_bytes = requested_bytes
self._available_bytes = self._max_bytes - bytes_currently_used
async with self._cv:
await self._cv.wait_for(lambda: self._available_bytes >= requested_bytes)
self._available_bytes -= requested_bytes
assert self._available_bytes >= 0
async def release_bytes(self, requested_bytes):
async with self._cv:
self._available_bytes += requested_bytes
assert self._available_bytes <= self._max_bytes
self._cv.notify_all()
def is_tensorstore_spec_leaf(leaf: Any):
  """Heuristically decide whether ``leaf`` is a tensorstore spec (or None)."""
  # TODO(rdyro): think of a better way to detect which leaf is a ts config
  if leaf is None:
    return True
  return isinstance(leaf, dict) and ("driver" in leaf or "kvstore" in leaf)
def _prime_factors(x: int) -> list[int]:
# find prime factors of axis sizes to help efficiently find divisor chunks
factors = []
while x % 2 == 0:
factors.append(2)
x //= 2
for i in range(3, int(math.sqrt(x)) + 1, 2):
while x % i == 0:
factors.append(i)
x //= i
if x > 1:
factors.append(x)
return sorted(factors)
@functools.lru_cache(maxsize=1024)
def _compute_chunk_shape(
    local_shape: Sequence[int], dtype: str | jnp.dtype,
    file_size_target: int = _FILE_SIZE_TARGET) -> list[int]:
  """Compute a chunk such that it divides the local shape and is less than
  target file size. This helps the tensorstore kvstore driver limit the largest
  file size on disk to below the ``file_size_target``. We compute a chunk with a
  byte size at most 110% of the ``file_size_target``.

  Args:
    local_shape: Per-shard shape the chunk must divide. Must be hashable
      (e.g. a tuple) because this function is lru-cached.
    dtype: Array dtype, used to convert element counts to bytes.
    file_size_target: Soft upper bound (10% tolerance) on chunk byte size.

  Returns:
    Per-axis chunk sizes, each dividing the corresponding axis of
    ``local_shape``.
  """
  local_shape = list(local_shape)
  if len(local_shape) == 0 or math.prod(local_shape) == 0:
    # a zero size array needs a non-zero chunk passed to tensorstore for compat.
    return [max(z, 1) for z in local_shape]
  total_size = math.prod(local_shape) * jnp.dtype(dtype).itemsize
  axis_prime_factors = [_prime_factors(z) for z in local_shape]
  chunk_shape, chunk_size = list(local_shape), total_size
  # while chunk_size exceeds target size, reduce chunk_shape
  while chunk_size > 1.1 * file_size_target:  # 10% buffer
    # 1. find the smallest axis divisor across all axes
    chosen_axis_idx: int | None = None
    chosen_divisor = 1
    for axis_idx in range(len(chunk_shape)):
      # Skip axes with at most one prime factor remaining: a size-1 axis has
      # none at all (indexing [0] below would raise IndexError), and an
      # axis's last factor is never divided out so the chunk axis stays > 1.
      if len(axis_prime_factors[axis_idx]) <= 1:
        continue
      if (chosen_axis_idx is None
          or chosen_divisor > axis_prime_factors[axis_idx][0]):
        chosen_axis_idx = axis_idx
        chosen_divisor = axis_prime_factors[axis_idx][0]
    # 2. if no divisor found, give up, return current chunk shape
    if chosen_axis_idx is None:
      return chunk_shape
    # 3. remove the applied divisor from prime factors
    axis_prime_factors[chosen_axis_idx].pop(0)
    # 4. apply the found divisor to reduce the chunk size
    chunk_shape[chosen_axis_idx] //= chosen_divisor
    chunk_size //= chosen_divisor
  return chunk_shape
def _get_tensorstore_metadata(arr, is_remote: bool = False,
                              file_size_target: int = _FILE_SIZE_TARGET,
                              driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]:
  """Build driver-specific tensorstore metadata for ``arr``.

  For a ``jax.Array`` the chunking is based on the per-device shard shape;
  for a ``np.ndarray`` the full shape is used. Delegates to the cached helper.
  """
  if isinstance(arr, jax.Array):
    local_shape = arr.sharding.shard_shape(arr.shape)
  else:  # np.ndarray
    local_shape = arr.shape
  return _get_tensorstore_metadata_cached(arr.shape, arr.dtype, local_shape,
                                          is_remote, file_size_target, driver)
@functools.lru_cache(maxsize=1024)
def _get_tensorstore_metadata_cached(
    global_shape: Sequence[int], dtype: jnp.dtype, local_shape: Sequence[int],
    is_remote: bool = False, file_size_target: int = _FILE_SIZE_TARGET,
    driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]:
  """Construct the metadata dict for the given driver (cached on the args)."""
  if driver == "zarr3":
    chunk_shape = _compute_chunk_shape(local_shape, dtype,
                                       file_size_target=file_size_target)
    return {
        # zstd compression is only worth it for remote (network) storage.
        'codecs': [{"name": "zstd"}] if is_remote else [],
        'shape': global_shape,
        'data_type': jnp.dtype(dtype).name,
        'chunk_grid': {
            'name': 'regular',
            'configuration': {'chunk_shape': chunk_shape}
        }
    }
  if driver == "zarr":  # in zarr dtype goes in the base spec
    return {'compressor': {'id': 'zstd'}, 'shape': global_shape,
            'chunks': np.array(np.maximum(1, local_shape)).tolist()}
  raise ValueError(f"Unsupported driver: {driver}")
_divides = lambda x, y: np.all((np.array(x) % np.array(y)) == 0)
def merge_nested_ts_specs(dict1: dict[Any, Any], dict2: dict[Any, Any] | None):
"""Merge two ts specs, dict2 takes precedence."""
if dict2 is None: # nothing to do
return dict1
# TODO(rdyro): this is an opinionated merge, we should get user feedback
# merge kvstore explicitly
kvstore = dict1.get("kvstore", {}) | dict2.get("kvstore", {})
return dict1 | dict(dict2, kvstore=kvstore) # merge with dict2 preferred
def verify_tensorstore_spec(spec: dict[str, Any], arr: jax.Array | None,
path: str | os.PathLike[str], ocdbt: bool,
check_metadata: bool = True) -> None:
"""Verify the minimum requirements for a tensorstore spec."""
if ocdbt:
if spec.get("kvstore", {}).get("driver", "") != "ocdbt":
raise ValueError(f"Expected ocdbt driver, got {spec=}")
if check_metadata:
if arr is None:
raise ValueError("Array is required for metadata verification.")
metadata = spec['metadata']
if spec.get("driver", "") == "zarr3":
if metadata['data_type'] != jnp.dtype(arr.dtype).name:
raise ValueError(f"Provided dtype ({metadata['data_type']=}) doesn't"
f" match ({arr.dtype=})")
if 'shape' in metadata:
if metadata['shape'] != arr.shape:
raise ValueError(f"Provided shape ({metadata['shape']=}) doesn't match"
f" ({arr.shape=})")
if isinstance(arr, jax.Array):
local_shape = arr.sharding.shard_shape(arr.shape)
else: # np.ndarray
local_shape = arr.shape
if spec.get("driver", "") == "zarr3":
chunk_shape = metadata['chunk_grid']['configuration']['chunk_shape']
if not _divides(local_shape, chunk_shape):
raise ValueError(f"Provided chunk shape {chunk_shape} does not divide"
f" the local shape of the array {local_shape}")
# check path is still the same one we expect
if ocdbt:
found_path = spec["kvstore"]['base']['path']
else:
found_path = spec["kvstore"]['path']
if str(found_path) != str(path):
raise ValueError(f"Provided {path=} does not match the spec path:"
f" {spec['kvstore']}")
def _spec_has_metadata(tree):
if not isinstance(tree, dict):
return False
return 'metadata' in tree or any(
_spec_has_metadata(subtree) for _, subtree in tree.items())
def _get_kvstore_for_gcs(ckpt_path: str):
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path)
if m is None:
raise ValueError('The ckpt_path should contain the bucket name and the '
f'file path inside the bucket. Got: {ckpt_path}')
bucket = m.group(1)
path_without_bucket = m.group(2)
return {'driver': 'gcs', 'bucket': bucket, 'path': path_without_bucket}
def _get_kvstore_for_s3(ckpt_path: str):
m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
if m is None:
raise ValueError('The ckpt_path should contain the bucket name and the '
f'file path inside the bucket. Got: {ckpt_path}')
bucket = m.group(1)
path_without_bucket = m.group(2)
return {'driver': 's3', 'bucket': bucket, 'path': path_without_bucket}
def get_tensorstore_spec(
    ckpt_path: str | PathLike[str], ocdbt: bool = True,
    process_idx: int | None = None, arr: jax.Array | None = None,
    driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]:
  """Construct a TensorStore spec dict for ``ckpt_path``.

  Args:
    ckpt_path: Destination path; a local path, ``gs://...`` or ``s3://...``.
    ocdbt: If True, wrap the base kvstore in an OCDBT kvstore; the parent
      directory becomes the store and the last path component the entry key.
    process_idx: If not None, redirect the path into a per-process
      ``process_<idx>`` subdirectory under the parent of ``ckpt_path``.
    arr: Optional array; when given, array metadata is added to the spec.
    driver: TensorStore array driver name (default ``zarr3``).

  Returns:
    A spec dictionary with ``driver`` and ``kvstore`` entries (plus
    ``metadata`` when ``arr`` is provided).

  Raises:
    ValueError: If OCDBT is requested with a bucket path lacking a directory
      component, or with a relative local path.
  """
  # Normalize path to exclude trailing '/'. In GCS path case, normpath will
  # replace a the double '//' with a single '/' and we need to restore the
  # filesystem type:// prefix for GCS (gs://) and S3 paths (s3://)
  ckpt_path = os.path.normpath(str(ckpt_path))
  ckpt_path = re.sub(r"^([a-z]+):/", r"\1://", ckpt_path)
  # in cases of multi-process writes, we need to write to a different location
  # for each process and finally created a combined symlink to the final
  # location, tensorstore can do this via ts.KvStore.experimental_copy_range_to
  if process_idx is not None:
    _parent, _name = os.path.split(ckpt_path)
    ckpt_path = os.path.join(_parent, _PROCESS_DIR_FORMAT.format(process_idx),
                             _name)
  is_gcs_path = ckpt_path.startswith('gs://')
  is_s3_path = ckpt_path.startswith('s3://')
  spec = {'driver': driver, 'kvstore': {}}
  # use a combined OCDBT store, the actual path is the parent path
  # the name (filename/last part of the path) is the key in the ocdbt kvstore
  entry_key = None
  if ocdbt:
    (ckpt_path, entry_key), org_ckpt_path = os.path.split(ckpt_path), ckpt_path
    if is_gcs_path:
      m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path)
    elif is_s3_path:
      m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path)
    else:
      # Local paths need no bucket/directory validation; use a trivially
      # successful match so the check below passes.
      m = re.match("a", "a")  # make it True
    if m is None:
      raise ValueError('Using OCDBT requires the bucket name, the directory'
                       ' name and the array name, your path is: '
                       f'{org_ckpt_path}')
  if is_gcs_path:
    base_kvstore = _get_kvstore_for_gcs(ckpt_path)
  elif is_s3_path:
    base_kvstore = _get_kvstore_for_s3(ckpt_path)
  else:
    base_kvstore = {'driver': _DEFAULT_BASE_DRIVER, 'path': ckpt_path}
  if ocdbt:
    if not is_gcs_path and not is_s3_path and not os.path.isabs(ckpt_path):
      raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
    spec['kvstore'] = {'driver': 'ocdbt', 'base': base_kvstore,
                       'path': entry_key}
  else:
    spec['kvstore'] = base_kvstore  # pyrefly: ignore[bad-typed-dict-key]
  # done writing tensorstore spec based on destination path
  # optionally, if array is provided, we can add metadata to the spec
  if arr is not None:
    spec["metadata"] = _get_tensorstore_metadata(
        arr, driver=str(spec["driver"]))
  return spec
async def _create_async_array_from_callback(
    global_shape: array.Shape,
    dtype: str | jnp.dtype | None,
    inp_sharding: jax.sharding.Sharding,
    data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]],
):
  """Assemble a global jax.Array by awaiting one data callback per local device."""
  index_map = inp_sharding.devices_indices_map(global_shape)
  local_devices = inp_sharding._addressable_device_assignment
  # Fetch all per-device shards concurrently.
  shard_futures = [data_callback(index_map[dev], dev) for dev in local_devices]
  shards = await asyncio.gather(*shard_futures)
  return array.make_array_from_single_device_arrays(
      global_shape, inp_sharding, shards, dtype=dtype)
async def _transfer_shard_to_host(shard: array.Shard) -> np.ndarray:
  """Asynchronously copy one shard's device data into a host numpy array."""
  data = shard.data
  has_pinned_host = any(mem.kind == "pinned_host"
                        for mem in shard.device.addressable_memories())
  if has_pinned_host:
    # If available, transfer to pinned host memory
    pinned = make_single_device_sharding(shard.device,
                                         memory_kind="pinned_host")
    data = jax.device_put(data, pinned)
  else:
    data.copy_to_host_async()
  # Allow other transfers to be scheduled simultaneously
  await asyncio.sleep(0)
  # Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore
  # implicitly converts the written data to a numpy array, and would otherwise
  # silently copy host-to-host.
  return np.array(data, copy=False)
async def combine_kvstores(combined_kvstore: dict[str, Any],
                           kvstores: list[dict[str, Any]],
                           context: ts.Context | None = _TS_CONTEXT
                           ) -> None:
  """Merge a list of kvstores into a single kvstore. NOT multi-process safe."""
  # Open the destination and every source concurrently.
  open_futures = [ts.KvStore.open(combined_kvstore, context=context)]
  open_futures += [ts.KvStore.open(kvstore, context=context)
                   for kvstore in kvstores]
  combined, *sources = await asyncio.gather(*open_futures)
  # Copy every source range into the destination under a single transaction.
  tx = ts.Transaction()
  await asyncio.gather(*(
      source.experimental_copy_range_to(combined.with_transaction(tx))
      for source in sources
  ))
  await tx.commit_async()
async def async_serialize(
    arr_inp,
    tensorstore_spec,
    commit_future=None,
    context=_TS_CONTEXT,
    chunk_layout=_TS_CHUNK_LAYOUT,
    primary_host: int | None = None,
    replica_id: int = 0,
    transaction: ts.Transaction | None = None,
):
  """Serialize an array using TensorStore.

  Args:
    arr_inp: The array to serialize.
    tensorstore_spec: The tensorstore spec to use.
    commit_future: A list of futures that will be appended to. The futures can
      be awaited asynchronously. If None, the futures will be awaited
      synchronously by this method.
    context: ts.Context instance.
    chunk_layout: ts.ChunkLayout instance passed to ``ts.open``.
    primary_host: Primary host, which indicates the host that will be treated as
      the "leader". If None, all hosts are treated as the primary. DO NOT USE
      unless you are sure you know what you are doing.
    replica_id: Allows overriding the shard replica id that will be saved. DO
      NOT USE unless you are sure you know what you are doing.
    transaction: TensorStore transaction to use for opening and writing the
      array. If not specified, a non-transactional write will be used.
  """
  if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
      arr_inp.is_fully_addressable):
    raise ValueError(
        f'Passing fully addressable arrays to a multiprocess '
        f'serialization is not allowed, as this may lead to a race condition '
        f'between processes. Serialization have failed for the array with '
        f'the path from kvstore: "{tensorstore_spec["kvstore"]}".')

  # 'metadata' may not be present at the top level (for example, if we are using
  # a 'cast' driver).
  if not _spec_has_metadata(tensorstore_spec):
    tensorstore_spec['metadata'] = _get_tensorstore_metadata(
        arr_inp, driver=tensorstore_spec['driver'])
  ## zarr driver requires specifying the dtype in the spec base
  if tensorstore_spec['driver'] == 'zarr' and 'dtype' not in tensorstore_spec:
    tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name

  # If primary_host is None, all hosts will checkpoint. This is used
  # for checkpointing to local filesystem.
  if primary_host is None or jax.process_index() == primary_host:
    open_future = ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=context,
        chunk_layout=chunk_layout,
        transaction=transaction,
    )
    # Asynchronous case.
    if commit_future is not None:
      assert isinstance(commit_future, list)
      commit_future.append(open_future)
    else:
      await open_future

  # `ts.open` runs twice for process `primary_host` because for the first time,
  # we just get the future to be awaited upon in the background thread. The
  # second one runs with `assume_metadata=True` which does no I/O operation and
  # returns the tensorstore object.
  # For every process other than `primary_host`, we open with
  # `assume_metadata=True`.
  t = await ts.open(
      ts.Spec(tensorstore_spec),
      open=True,
      assume_metadata=True,
      context=context,
      chunk_layout=chunk_layout,
      transaction=transaction,
  )

  async def _write_array(shard):
    # Only the designated replica writes each shard, avoiding duplicate
    # writes of the same data from multiple replicas.
    if shard.replica_id == replica_id:
      data = await _transfer_shard_to_host(shard)
      write_future = t[shard.index].write(
          data,
          # Avoid additional copy of input array into the TensorStore chunk
          # cache. If `arr_inp` is a jax.Array, the result of converting
          # it to a NumPy array, as is done internally by TensorStore, is
          # guaranteed to be immutable and therefore it is safe to retain a
          # reference indefinitely.
          can_reference_source_data_indefinitely=isinstance(
              arr_inp, array.ArrayImpl
          ),
      )
      if commit_future is not None:
        # Async path: hand the commit future to the caller, but wait for the
        # copy so the source buffer may be safely reused afterwards.
        assert isinstance(commit_future, list)
        commit_future.append(write_future.commit)
        await write_future.copy
      else:
        await write_future.commit

  local_shards = arr_inp.addressable_shards
  future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
  return await asyncio.gather(*future_write_state)
# TODO(rdyro): Remove this function.
def _run_serialization(arrays, tensorstore_specs):
  """Legacy serialization of a list of arrays."""
  async def _serialize_all():
    writers = jax.tree_util.tree_map(async_serialize, arrays,
                                     tensorstore_specs)
    return await asyncio.gather(*writers)
  asyncio.run(_serialize_all())
def estimate_read_memory_footprint(t: ts.TensorStore,
                                   domain: ts.IndexDomain) -> int:
  """Estimate the host bytes needed to read ``domain`` from tensorstore ``t``."""
  if domain is None:
    domain = t.domain
  itemsize = t.dtype.numpy_dtype.itemsize
  chunk_template = t.chunk_layout.read_chunk_template
  # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
  # For those, instead of returning a near-infinite memory footprint, estimate
  # the footprint as the entire shape.
  if any(not chunk_template[i].finite for i in range(t.rank)):
    return domain.size * itemsize
  # Otherwise, if we have a chunked driver, estimate based on chunk size:
  # the read covers the chunk-aligned region enclosing the requested domain.
  num_bytes = itemsize
  for origin_value, extent, chunk_origin_value, chunk_size in zip(
      domain.origin, domain.shape, chunk_template.origin,
      chunk_template.shape):
    lower = origin_value - chunk_origin_value
    upper = lower + extent
    lower_aligned = lower // chunk_size * chunk_size
    upper_aligned = -(-upper // chunk_size) * chunk_size  # ceil to chunk edge
    num_bytes *= (upper_aligned - lower_aligned)
  return num_bytes
async def async_deserialize(
    in_type: jax.sharding.Sharding | Format | jax.ShapeDtypeStruct,
    tensorstore_spec: ts.Spec | dict[str, Any],
    global_shape: Sequence[int] | None = None,
    dtype=None,
    byte_limiter: _LimitInFlightBytes | None = None,
    context=_TS_CONTEXT,
    chunk_layout=_TS_CHUNK_LAYOUT,
    assume_metadata: bool = False,
):
  """Main performant deserialization routine for arrays using tensorstore.

  Args:
    in_type: Target placement: a ``Sharding``, a ``Format`` (sharding +
      layout), or a ``jax.ShapeDtypeStruct`` (sharding + layout + dtype).
    tensorstore_spec: Spec of the stored array to read.
    global_shape: Optional shape override; defaults to the stored shape.
    dtype: Optional dtype override; data is cast on host before device put.
    byte_limiter: Optional limiter bounding in-flight host read bytes.
    context: ts.Context instance.
    chunk_layout: ts.ChunkLayout passed to ``ts.open``.
    assume_metadata: If True, ``ts.open`` skips reading metadata from storage.

  Returns:
    The deserialized jax.Array with the requested sharding.
  """
  # Unpack the requested sharding (and optional layout/dtype) from in_type.
  if isinstance(in_type, Format):
    in_sharding, layout = in_type.sharding, in_type.layout
  elif isinstance(in_type, jax.ShapeDtypeStruct):
    dtype = in_type.dtype if dtype is None else dtype
    in_sharding = in_type.sharding
    layout = in_type.format.layout
  else:
    if not isinstance(in_type, jax.sharding.Sharding):
      raise TypeError(
          'sharding passed to deserialization should be specified, concrete and'
          f' an instance of `jax.sharding.Sharding`. Got {in_type}')
    in_sharding = in_type
    layout = None
  assert isinstance(in_sharding, jax.sharding.Sharding)
  t = await ts.open(
      tensorstore_spec,
      open=True,
      assume_metadata=assume_metadata,
      context=context,
      chunk_layout=chunk_layout,
  )
  shape = t.shape if global_shape is None else global_shape
  dtype = dtype if dtype is not None else t.dtype.numpy_dtype
  new_shard_shape = in_sharding.shard_shape(tuple(shape))

  async def cb(index: array.Index, device: jax.Device):
    # Read one shard's slice of the stored array and place it on `device`.
    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
    restricted_domain = t.domain.intersect(requested_domain)
    requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
    # Limit the bytes read for every shard.
    if byte_limiter is not None:
      await byte_limiter.wait_for_bytes(requested_bytes)
    # This maybe needed because the shape the array was saved with is smaller
    # than the requested shape of the array in which it will be reloaded. So
    # the extra values will be filled with 0s.
    out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
    await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][
        restricted_domain].write(t[restricted_domain])
    if dtype is not None:
      # Cast while reloading on process to avoid 2 copies on device if the
      # casting is done on device.
      out = out.astype(dtype)
    # Convert to jnp array so that layouts are initialized properly for
    # sub-byte dtypes.
    # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to
    # make this work.
    if out.dtype == jnp.int4:
      out = jnp.asarray(out)
    result = jax.device_put(
        out, Format(layout, make_single_device_sharding(device)))
    if byte_limiter is not None:
      # NB: `out` actually might not be ready for garbage collection by the
      # time we call release_bytes . Thus peak memory usage still might grow
      # beyond what byte_limiter limit suggests it should. The simplest option
      # would be to call `result.block_until_ready()`` here. However it
      # also comes with ~15-20% perf penalty as we would be waiting for CPU->GPU
      # transfer instead of loading data. In the future, if memory pressure
      # becomes a problem, we can instead instrument bytelimiter to
      # keep track of all in-flight tensors and only block_until_ready, if byte
      # limiter hits the limit to get reduced memory usage, without losing
      # performance in common use cases.
      await byte_limiter.release_bytes(requested_bytes)
    return result

  # for deserialization canonicalize dtype to a dtype representable in jax
  return await _create_async_array_from_callback(
      tuple(shape), jax.dtypes.canonicalize_dtype(dtype), in_sharding, cb)
# TODO(rdyro): Remove this function.
def _run_deserialization(shardings: Sequence[jax.sharding.Sharding | Format],
                         tensorstore_specs: Sequence[dict[str, Any] | ts.Spec],
                         global_shapes: Sequence[array.Shape] | None = None,
                         dtypes: Sequence[typing.DTypeLike] | None = None,
                         concurrent_gb: int = 32):
  """Legacy deserialization of a list of arrays. Optionally pass global_shapes
  and dtypes for type-checking.
  """
  concurrent_bytes = concurrent_gb * 10**9
  n = len(tensorstore_specs)
  shapes = [None] * n if global_shapes is None else global_shapes
  dtypes_seq = [None] * n if dtypes is None else dtypes

  async def _deserialize_all():
    # Object should be created once per process.
    byte_limiter = _LimitInFlightBytes(concurrent_bytes)
    reads = jax.tree_util.tree_map(
        partial(async_deserialize, byte_limiter=byte_limiter),
        list(shardings), list(tensorstore_specs), shapes, dtypes_seq)
    return await asyncio.gather(*reads)

  return asyncio.run(_deserialize_all())