@@ -0,0 +1,13 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
+514
@@ -0,0 +1,514 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||||
|
||||
"""
|
||||
Serializations routines for pytrees including array and non-array serialization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from os import PathLike
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from uuid import uuid4, UUID
|
||||
import json
|
||||
import asyncio
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import shutil
|
||||
import logging
|
||||
|
||||
import jax
|
||||
from jax._src import distributed
|
||||
from jax._src.api_util import flatten_axes
|
||||
from jax._src.layout import Format
|
||||
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.experimental.array_serialization import tensorstore_impl as ts_impl
|
||||
import jax.experimental.array_serialization.pytree_serialization_utils as utils
|
||||
from jax._src import path as pathlib
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_THREADING_SAVE_LOCK = threading.Lock()
|
||||
|
||||
_REMOTE_URL_PREFIXES = ['gs://', 's3://']
|
||||
_PYTREEDEF_FILE = "pytreedef.json"
|
||||
_ARCHIVE_NAME = "archive.zip"
|
||||
_USE_OCDBT = True # a lot of the code relies on this being True
|
||||
_MAX_PATH_LENGTH = 4096
|
||||
_ARRAY_STORE_DIRNAME = "array_store"
|
||||
_ARRAY_TYPE_FORMAT = "Array({dtype}[{shape}])"
|
||||
_ARRAY_TYPE_REGEX = r"Array\(([a-zA-Z0-9_]+)\[([0-9, ]*)\]\)"
|
||||
_MAX_CONCURRENCY = 32
|
||||
_TIMEOUT_SEC = 30
|
||||
|
||||
PyTreeT = Any
|
||||
|
||||
__all__ = ["save", "load", "load_pytreedef",
|
||||
"nonblocking_load", "nonblocking_save"]
|
||||
|
||||
|
||||
def _get_unique_sync_key() -> str | None:
|
||||
"""Generate a thread-local key for ensuring all host finish (de)serializing"""
|
||||
if jax.process_count() == 1:
|
||||
return None
|
||||
# broadcast a thread-local unique barrier name
|
||||
sync_key_unique = multihost_utils.broadcast_one_to_all(
|
||||
np.frombuffer(uuid4().bytes, dtype=np.int32))
|
||||
sync_key_id = UUID(bytes=np.array(sync_key_unique).tobytes())
|
||||
return f"jax_sync_key_{str(sync_key_id)}"
|
||||
|
||||
|
||||
def _is_str_same_on_all_hosts(path: str | PathLike[str]) -> bool:
|
||||
"""All-gather the location of the checkpoint and check if it's the same."""
|
||||
if jax.process_count() <= 1:
|
||||
return False
|
||||
path_b = str(path).encode("utf-8")
|
||||
if len(path_b) > _MAX_PATH_LENGTH:
|
||||
raise ValueError(f"Path exceeds maximum length of {_MAX_PATH_LENGTH} in"
|
||||
" multiprocess case.")
|
||||
path_array = np.concatenate([
|
||||
np.frombuffer(path_b, dtype=np.uint8), np.zeros(
|
||||
_MAX_PATH_LENGTH - len(path_b), dtype=np.uint8)])
|
||||
path_array = multihost_utils.process_allgather(path_array)
|
||||
return bool(np.all(path_array[0] == path_array[1:]))
|
||||
|
||||
|
||||
def _sync_on_key(key: str | None, extra_tag: str = "") -> None:
|
||||
if key is None:
|
||||
return
|
||||
full_key = f"{key}-{extra_tag}" if extra_tag else key
|
||||
if (client := distributed.global_state.client) is not None:
|
||||
client.wait_at_barrier(full_key, timeout_in_ms=_TIMEOUT_SEC * 1000)
|
||||
|
||||
|
||||
def _is_array_like(x):
|
||||
return isinstance(x, (jax.Array, np.ndarray))
|
||||
|
||||
|
||||
def _leaf_to_desc(leaf) -> str:
|
||||
if leaf is None:
|
||||
return "null"
|
||||
elif _is_array_like(leaf):
|
||||
return _ARRAY_TYPE_FORMAT.format(
|
||||
dtype=leaf.dtype.name, shape=", ".join(map(str, leaf.shape)))
|
||||
else:
|
||||
return type(leaf).__name__
|
||||
|
||||
|
||||
def _desc_to_leaf(leaf_desc: str | None) -> str | None | jax.ShapeDtypeStruct:
|
||||
if leaf_desc is None:
|
||||
return None
|
||||
if not re.match(_ARRAY_TYPE_REGEX, leaf_desc):
|
||||
return leaf_desc
|
||||
shape_dtype_match = re.match(_ARRAY_TYPE_REGEX, leaf_desc)
|
||||
assert shape_dtype_match is not None
|
||||
dtype_str, shape_str = shape_dtype_match.groups()
|
||||
shape = [int(x.strip()) for x in shape_str.strip("]").strip().split(",")
|
||||
if len(x.strip()) > 0]
|
||||
return jax.ShapeDtypeStruct(shape, jax.numpy.dtype(dtype_str))
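# Illustrative round trip for the leaf descriptions above (a sketch; the shape
# and dtype are arbitrary examples):
#   >>> _leaf_to_desc(np.zeros((2, 3), dtype=np.float32))
#   'Array(float32[2, 3])'
#   >>> _desc_to_leaf('Array(float32[2, 3])')
#   ShapeDtypeStruct(shape=(2, 3), dtype=float32)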
|
||||
|
||||
|
||||
def _is_remote_path(path: str | PathLike[str]):
|
||||
"""Check whether a path is remote by examining the prefix."""
|
||||
# we need to truncate e.g., gs:// to gs:/ because pathlib.Path collapses //
|
||||
return any(str(path).startswith(prefix[:-1])
|
||||
for prefix in _REMOTE_URL_PREFIXES)
|
||||
|
||||
|
||||
def _norm_path(path: str | PathLike[str]) -> Any:
|
||||
if _is_remote_path(path):
|
||||
return pathlib.Path(path)
|
||||
return pathlib.Path(path).expanduser().resolve()
|
||||
|
||||
|
||||
def _rm_dir(root: Any) -> None:
|
||||
if _is_remote_path(root):
|
||||
root.rmtree()
|
||||
else:
|
||||
shutil.rmtree(root)
|
||||
|
||||
|
||||
def _set_up_destination(root: str | PathLike[str], overwrite: bool,
|
||||
pytree_repr: dict[str, Any], distinct_locations: bool,
|
||||
sync_key: str | None) -> dict[str, Any]:
|
||||
"""Inspect the destination, set it up for writing, potentially read existing data."""
|
||||
root = _norm_path(root)
|
||||
if overwrite:
|
||||
if root.exists() and len(list(root.iterdir())) > 0:
|
||||
# check that we're only deleting things that come from JAX
|
||||
# refuse to rm directories containing additional entries
|
||||
extra_member_paths = [
|
||||
path for path in list(root.iterdir()) if path.name not in
|
||||
(_PYTREEDEF_FILE, _ARCHIVE_NAME, _ARRAY_STORE_DIRNAME)]
|
||||
|
||||
if len(extra_member_paths) != 0:
|
||||
raise RuntimeError(
|
||||
"Refusing to work on a directory that is not a previous checkpoint."
|
||||
f" Unrecognized paths: {extra_member_paths}. Remove them manually"
|
||||
f" if you're sure you want to use {root} as the checkpoint"
|
||||
" directory.")
|
||||
|
||||
if (jax.process_index() == 0 or distinct_locations) and root.exists():
|
||||
_rm_dir(root)
|
||||
_sync_on_key(sync_key, "overwrite")
|
||||
return pytree_repr
|
||||
else:
|
||||
if (root.exists() and len(list(root.iterdir())) > 0): # not empty
|
||||
raise ValueError(f"Files already exist at path: `{root}`, but you"
|
||||
f" specified `{overwrite=}`")
|
||||
return pytree_repr
|
||||
|
||||
|
||||
def _prepare_directory(root: str | PathLike[str], overwrite: bool,
|
||||
pytreedef_repr: dict[str, Any], distinct_locations: bool,
|
||||
sync_key: str | None):
|
||||
"""Prepare the directory: check destination, potentially read existing data
|
||||
and overwrite.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the destination directory cannot be created.
|
||||
"""
|
||||
root = _norm_path(root)
|
||||
# prepare the destination directory, overwrite destination directory or error
|
||||
pytreedef_repr = _set_up_destination(
|
||||
root, overwrite, pytreedef_repr, distinct_locations, sync_key)
|
||||
|
||||
if not _is_remote_path(root) and (distinct_locations
|
||||
or jax.process_index() == 0):
|
||||
root.mkdir(exist_ok=True) # do not make parents, that's too much
|
||||
if not root.exists() or not root.is_dir():
|
||||
raise RuntimeError(f"Could not create destination directory at {root}")
|
||||
_sync_on_key(sync_key, "mkdir")
|
||||
return pytreedef_repr
|
||||
|
||||
|
||||
def _write_arrays(array_store_path: Any, arrs: list[Any],
|
||||
arr_leaf_ids: list[int], ts_specs: list[Any | None],
|
||||
distinct_locations: bool):
|
||||
paths = [array_store_path / str(leaf_id) for leaf_id in arr_leaf_ids]
|
||||
process_idx = None
|
||||
if not distinct_locations and jax.process_count() > 1:
|
||||
process_idx = jax.process_index()
|
||||
default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT,
|
||||
process_idx=process_idx,
|
||||
arr=arr)
|
||||
for (path, arr) in zip(paths, arrs)]
|
||||
processed_ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec)
|
||||
for (default_ts_spec, ts_spec) in zip(default_ts_specs, ts_specs)]
|
||||
|
||||
# sanity check the ts specs
|
||||
if len(ts_specs) > 0: # verify the base path is shared for all arrays
|
||||
expected_path = processed_ts_specs[0]["kvstore"]["base"]["path"] # shared base path
|
||||
for ts_spec, arr in zip(processed_ts_specs, arrs):
|
||||
assert ts_spec is not None
|
||||
ts_impl.verify_tensorstore_spec(ts_spec, arr, expected_path,
|
||||
ocdbt=_USE_OCDBT, check_metadata=True)
|
||||
|
||||
async def _serialize_arrays():
|
||||
await asyncio.gather(*[
|
||||
ts_impl.async_serialize(arr, ts_spec, primary_host=None)
|
||||
for (arr, ts_spec) in zip(arrs, processed_ts_specs)])
|
||||
|
||||
asyncio.run(_serialize_arrays())
|
||||
|
||||
|
||||
def _finalize_array_store(kvstore_path, distinct_locations: bool):
|
||||
"""When multiple processes are writing, they must write to a per-process
|
||||
location followed by combining them via no-copy links to the final location.
|
||||
"""
|
||||
# only in multiprocess case and only process 0
|
||||
if distinct_locations or jax.process_count() == 1 or jax.process_index() != 0:
|
||||
return
|
||||
dummy_key_path = os.path.join(kvstore_path, "dummy_key")
|
||||
combined_kvstore = ts_impl.get_tensorstore_spec(
|
||||
dummy_key_path, ocdbt=True, process_idx=None)["kvstore"]
|
||||
children_kvstores = [ts_impl.get_tensorstore_spec(
|
||||
dummy_key_path, ocdbt=True, process_idx=i)["kvstore"]
|
||||
for i in range(jax.process_count())]
|
||||
_ = combined_kvstore.pop("path")
|
||||
_ = [kvstore.pop("path") for kvstore in children_kvstores]
|
||||
asyncio.run(ts_impl.combine_kvstores(combined_kvstore, children_kvstores))
|
||||
|
||||
|
||||
def _write_pytreedef(directory: Any, pytree_repr: dict[str, Any],
|
||||
distinct_locations: bool):
|
||||
"""Write the pytreedef to the destination directory and aux data to the archive."""
|
||||
if not (jax.process_index() == 0 or distinct_locations):
|
||||
return
|
||||
root = _norm_path(directory)
|
||||
(root / _PYTREEDEF_FILE).write_text(json.dumps(pytree_repr, indent=2))
|
||||
|
||||
|
||||
def _tree_broadcast(a, b, is_leaf=lambda x: x is None):
|
||||
"""Broadcast the prefix tree `a` to the full tree `b`
|
||||
|
||||
Uses `flatten_axes` for better error messages on mismatched arity but allowing
|
||||
for custom is_leaf in the `a` and `b` trees.
|
||||
"""
|
||||
a_leaves, a_struct = jax.tree.flatten(a, is_leaf=is_leaf)
|
||||
a_idx2leaf_map = dict(enumerate(a_leaves))
|
||||
a_idx = jax.tree.unflatten(a_struct, a_idx2leaf_map.keys())
|
||||
a_idx_broadcast = flatten_axes("tree_broadcast",
|
||||
jax.tree.structure(b, is_leaf=is_leaf), a_idx)
|
||||
return jax.tree.map(lambda i: a_idx2leaf_map[i], a_idx_broadcast)
|
||||
|
||||
|
||||
_serialization_executor = ThreadPoolExecutor(max_workers=_MAX_CONCURRENCY)
|
||||
|
||||
|
||||
def save(data: PyTreeT, directory: str | PathLike[str], *,
|
||||
overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None:
|
||||
"""Saves the given data structure to the provided directory path.
|
||||
|
||||
This function provides functionality to serialize and save a data structure
|
||||
comprising JAX arrays, along with its structure to a given directory. It
|
||||
leverages `PyTree` for flattening and reconstructing the data structure.
|
||||
|
||||
This is a simple experimental array serialization API, for anything more
|
||||
complex and for all checkpointing prefer: https://github.com/google/orbax
|
||||
|
||||
Args:
|
||||
data: The data structure to be saved. Arbitrary composition of JAX arrays,
|
||||
including nested structures.
|
||||
directory: The directory path where the data will be saved. A local path or
|
||||
a remote URL (e.g., gs://, s3://). For remote URLs, `etils` is required.
|
||||
overwrite: If True, any existing directory with the same name will be
|
||||
overwritten.
|
||||
ts_specs: Optional tensorstore specs to use for serialization. If None,
|
||||
defaults to using the default tensorstore specs.
|
||||
|
||||
Example:
|
||||
>>> data = {"a": jnp.array([1, 2]), "b": None}
|
||||
>>> save(data, directory)
|
||||
"""
|
||||
with _THREADING_SAVE_LOCK:
|
||||
return _save(data, directory, overwrite=overwrite, ts_specs=ts_specs)
|
||||
|
||||
|
||||
def _save(data: PyTreeT, directory: str | PathLike[str], *,
|
||||
overwrite: bool = True, ts_specs: PyTreeT | None = None) -> None:
|
||||
sync_key = _get_unique_sync_key() # get a synchronization key for multi-host
|
||||
|
||||
if _is_remote_path(directory) and not pathlib.epath_installed:
|
||||
raise RuntimeError("For saving to remote URLs (e.g., gs, s3) you need the"
|
||||
" `etils` module installed. You can install it using"
|
||||
" `pip install etils`.")
|
||||
ts_specs = _tree_broadcast(ts_specs, data,
|
||||
is_leaf=ts_impl.is_tensorstore_spec_leaf)
|
||||
data_flat, pytreedef = jax.tree.flatten(data, is_leaf=lambda x: x is None)
|
||||
if not all(x is None or _is_array_like(x) for x in data_flat):
|
||||
raise ValueError("For serialization, all leaves must be either None or"
|
||||
" jax.Array-like objects.")
|
||||
distinct_locations = not _is_str_same_on_all_hosts(directory)
|
||||
if jax.process_count() > 1 and distinct_locations:
|
||||
raise ValueError(
|
||||
"Saving to different locations on different hosts is not supported,"
|
||||
" because it is extremely fragile. Consider using a single location.")
|
||||
root = _norm_path(directory)
|
||||
|
||||
# 1. serialize the pytree #################################
|
||||
pytreedef_repr = utils.serialize_pytreedef(pytreedef)
|
||||
pytreedef_repr[utils._LEAF_IDS_KEY] = jax.tree.map(_leaf_to_desc, data_flat)
|
||||
|
||||
pytreedef_repr = _prepare_directory(
|
||||
root, overwrite, pytreedef_repr, distinct_locations, sync_key)
|
||||
futures = []
|
||||
futures.append(_serialization_executor.submit(
|
||||
_write_pytreedef, root, pytreedef_repr, distinct_locations))
|
||||
|
||||
# 2. serialize arrays #####################################
|
||||
array_store_path = root / _ARRAY_STORE_DIRNAME
|
||||
arrs = [data for data in data_flat if _is_array_like(data)]
|
||||
arr_leaf_ids = [i for i, data in enumerate(data_flat) if _is_array_like(data)]
|
||||
ts_specs_flat = jax.tree.leaves(ts_specs,
|
||||
is_leaf=ts_impl.is_tensorstore_spec_leaf)
|
||||
ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids]
|
||||
futures.append(_serialization_executor.submit(
|
||||
_write_arrays, array_store_path, arrs, arr_leaf_ids, ts_specs_flat,
|
||||
distinct_locations))
|
||||
|
||||
# 3. wait for all futures to complete #####################
|
||||
_ = [fut.result() for fut in futures]
|
||||
_sync_on_key(sync_key, "array_serialization")
|
||||
|
||||
# 4. finalize the array writing ###########################
|
||||
if len(arr_leaf_ids) > 0 and _USE_OCDBT:
|
||||
_serialization_executor.submit( # call from a thread to not nest asyncio
|
||||
_finalize_array_store, array_store_path, distinct_locations).result()
|
||||
# we are done with all async ops here, we can block ####
|
||||
_sync_on_key(sync_key, "end")
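# End-to-end usage sketch (single host; "/tmp/my_ckpt" is a hypothetical path):
#   >>> import jax.numpy as jnp
#   >>> from jax.sharding import SingleDeviceSharding
#   >>> save({"a": jnp.arange(4), "b": None}, "/tmp/my_ckpt")
#   >>> load("/tmp/my_ckpt", SingleDeviceSharding(jax.devices()[0]))
#   {'a': Array([0, 1, 2, 3], dtype=int32), 'b': None}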
|
||||
|
||||
|
||||
def _read_arrays(array_store_path: str | PathLike[str], arr_leaf_ids: list[int],
|
||||
ts_specs: list[Any], shardings: list[Any]):
|
||||
# array_store_path = root / _LEAF_DATA_DIR / _ARRAY_STORE_DIRNAME
|
||||
arr_store_path = _norm_path(array_store_path)
|
||||
arr_paths = [arr_store_path / str(leaf_id) for leaf_id in arr_leaf_ids]
|
||||
|
||||
# byte limiter to limit number of parallel reads, resizes to largest read
|
||||
byte_limiter = ts_impl._LimitInFlightBytes(10 * 1024 ** 3) # 10 GB
|
||||
|
||||
default_ts_specs = [ts_impl.get_tensorstore_spec(path, ocdbt=_USE_OCDBT,
|
||||
process_idx=None)
|
||||
for path in arr_paths]
|
||||
ts_specs = [ts_impl.merge_nested_ts_specs(default_ts_spec, ts_spec)
|
||||
for (default_ts_spec, ts_spec) in zip(default_ts_specs, ts_specs)]
|
||||
|
||||
if len(ts_specs) > 0: # verify the base path is shared for all arrays
|
||||
expected_path = ts_specs[0]["kvstore"]["base"]["path"] # shared base path
|
||||
for ts_spec in ts_specs:
|
||||
ts_impl.verify_tensorstore_spec(ts_spec, arr=None, path=expected_path,
|
||||
ocdbt=_USE_OCDBT, check_metadata=False)
|
||||
|
||||
async def _deserialize_arrays():
|
||||
return await asyncio.gather(*[
|
||||
ts_impl.async_deserialize(sharding, ts_spec, byte_limiter=byte_limiter)
|
||||
for (sharding, ts_spec) in zip(shardings, ts_specs)])
|
||||
|
||||
return dict(zip(arr_leaf_ids, asyncio.run(_deserialize_arrays())))
|
||||
|
||||
|
||||
def load_pytreedef(directory: str | PathLike[str]) -> PyTreeT:
|
||||
"""Loads a pytree from the given directory.
|
||||
|
||||
This is a simple experimental array serialization API, for anything more
|
||||
complex and for all checkpointing prefer: https://github.com/google/orbax
|
||||
|
||||
Args:
|
||||
directory: Directory path to load from.
|
||||
Returns:
|
||||
The loaded pytree with arrays represented as jax.ShapeDtypeStruct's.
|
||||
"""
|
||||
assert not _is_remote_path(directory) or pathlib.epath_installed, (
|
||||
"For checkpointing using remote URLs (e.g., gs, s3) you need `etils`"
|
||||
" module installed. You can install it using `pip install etils`.")
|
||||
json_content = (_norm_path(directory) / _PYTREEDEF_FILE).read_text()
|
||||
raw_tree = json.loads(json_content)
|
||||
leaves = map(_desc_to_leaf, raw_tree[utils._LEAF_IDS_KEY])
|
||||
return jax.tree.unflatten(utils.deserialize_pytreedef(raw_tree), leaves)
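# Sketch: inspect a checkpoint's structure without reading any array data
# (continuing the hypothetical "/tmp/my_ckpt" example above):
#   >>> load_pytreedef("/tmp/my_ckpt")
#   {'a': ShapeDtypeStruct(shape=(4,), dtype=int32), 'b': None}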
|
||||
|
||||
|
||||
def load(directory: str | PathLike[str], shardings: PyTreeT, *,
|
||||
mask: PyTreeT | None = None, ts_specs: PyTreeT | None = None
|
||||
) -> PyTreeT:
|
||||
"""Loads and reconstructs a data structure from a directory.
|
||||
|
||||
This is a simple experimental array serialization API, for anything more
|
||||
complex and for all checkpointing prefer: https://github.com/google/orbax
|
||||
|
||||
Args:
|
||||
directory: Directory path where the data is stored.
|
||||
shardings: Sharding strategy for array objects, either a Sharding or a
|
||||
ShapeDtypeStruct with a Sharding/Format.
|
||||
mask: boolean prefix tree for partial loading, will return None for False
|
||||
leaves.
|
||||
ts_specs: Optional tensorstore specs to use for deserialization. If None,
|
||||
defaults to using the default tensorstore specs.
|
||||
|
||||
Returns:
|
||||
Reconstructed data.
|
||||
|
||||
Example:
|
||||
>>> save(data, directory)
|
||||
>>> restored_data = load(directory, SingleDeviceSharding(jax.devices()[0]))
|
||||
"""
|
||||
assert not _is_remote_path(directory) or pathlib.epath_installed, (
|
||||
"For checkpointing using remote URLs (e.g., gs, s3) you need `etils`"
|
||||
" module installed. You can install it using `pip install etils`.")
|
||||
|
||||
root = _norm_path(directory)
|
||||
assert root.is_dir(), f"Checkpoint directory {root} does not exist"
|
||||
is_leaf = lambda x: x is None
|
||||
|
||||
# deserialize PyTreeDef
|
||||
pytree = load_pytreedef(directory)
|
||||
# broadcast the (prefix) shardings and tensorstore specs to the full pytree
|
||||
shardings = _tree_broadcast(shardings, pytree)
|
||||
ts_specs = _tree_broadcast(ts_specs, pytree,
|
||||
is_leaf=ts_impl.is_tensorstore_spec_leaf)
|
||||
if mask is not None:
|
||||
_prefix_mask = lambda m, x: jax.tree.map(lambda _: None, x) if not m else x
|
||||
pytree = jax.tree.map(_prefix_mask, mask, pytree)
|
||||
pytreedef = jax.tree.structure(pytree, is_leaf=is_leaf)
|
||||
leaf_ids_flat = jax.tree.leaves(pytree, is_leaf=is_leaf)
|
||||
shardings_flat = jax.tree.leaves(shardings, is_leaf=is_leaf)
|
||||
if any(isinstance(shardings, Format) for shardings in shardings_flat):
|
||||
raise NotImplementedError(
|
||||
"Deserialization with `Format` instead of `Sharding` is not currently"
|
||||
" supported. Pass ShapeDtypeStruct(shape, dtype, sharding=format)"
|
||||
" instead.")
|
||||
ts_specs_flat = jax.tree.leaves(ts_specs,
|
||||
is_leaf=ts_impl.is_tensorstore_spec_leaf)
|
||||
|
||||
# deserialize array objects
|
||||
arr_leaf_ids = [i for i, leaf_id in enumerate(leaf_ids_flat)
|
||||
if leaf_id is not None]
|
||||
shardings_flat = [shardings_flat[i] for i in arr_leaf_ids]
|
||||
ts_specs_flat = [ts_specs_flat[i] for i in arr_leaf_ids]
|
||||
|
||||
arrs_fut = _serialization_executor.submit(
|
||||
_read_arrays, root / _ARRAY_STORE_DIRNAME, arr_leaf_ids, ts_specs_flat,
|
||||
shardings_flat)
|
||||
|
||||
arrs = arrs_fut.result()
|
||||
filled_values = [arrs.get(i, None) for i, _ in enumerate(leaf_ids_flat)]
|
||||
return jax.tree.unflatten(pytreedef, filled_values)
|
||||
|
||||
|
||||
def nonblocking_save(data: PyTreeT, directory: str | PathLike[str], *,
|
||||
overwrite: bool = True, ts_specs: PyTreeT | None = None
|
||||
) -> utils.PyTreeFuture:
|
||||
"""Nonblocking alias of save, return an awaitable future with a pytree stub.
|
||||
|
||||
This is a simple experimental array serialization API, for anything more
|
||||
complex and for all checkpointing prefer: https://github.com/google/orbax
|
||||
|
||||
Examples:
|
||||
>>> fut = nonblocking_save(data, directory)
|
||||
>>> print(fut.pytree) # a pytree of jax.ShapeDtypeStruct's
|
||||
>>> print(fut.result()) # None, blocking until the serialization is done
|
||||
"""
|
||||
# start serialization immediately
|
||||
fut = utils.PyTreeFuture(_serialization_executor.submit(
|
||||
save, data, directory, overwrite=overwrite, ts_specs=ts_specs))
|
||||
# construct a nice looking pytree representing the nodes being read
|
||||
fut.pytree = jax.tree.map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
|
||||
if _is_array_like(x) else x, data)
|
||||
return fut
|
||||
|
||||
|
||||
def nonblocking_load(directory: str | PathLike[str], shardings: PyTreeT, *,
|
||||
mask: PyTreeT | None = None,
|
||||
ts_specs: PyTreeT | None = None) -> utils.PyTreeFuture:
|
||||
"""Nonblocking alias of load, return an awaitable future with a pytree stub.
|
||||
|
||||
This is a simple experimental array serialization API, for anything more
|
||||
complex and for all checkpointing prefer: https://github.com/google/orbax
|
||||
|
||||
Examples:
|
||||
>>> fut = nonblocking_load(directory)
|
||||
>>> print(fut.pytree) # a pytree of jax.ShapeDtypeStruct
|
||||
>>> print(fut.result()) # the fully populated pytree
|
||||
"""
|
||||
# TODO(rdyro): the awaitable future output is a workaround
|
||||
# it should return the fully populated pytree instead of just
|
||||
# jax.ShapeDtypeStruct for arrays by constructing them asynchronously
|
||||
fut = utils.PyTreeFuture(_serialization_executor.submit(
|
||||
load, directory, shardings, mask=mask, ts_specs=ts_specs))
|
||||
fut.pytree = load_pytreedef(directory)
|
||||
return fut
|
||||
+84
@@ -0,0 +1,84 @@
|
||||
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utilities for representing pytreedefs in a serializable format.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from concurrent.futures import Future
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import jax
|
||||
from jax._src.export.serialization import (flatbuffers, _serialize_pytreedef,
|
||||
_deserialize_pytreedef_to_pytree,
|
||||
ser_flatbuf)
|
||||
from jax.export import register_pytree_node_serialization
|
||||
|
||||
T = TypeVar("T")
|
||||
PickleModule = ModuleType
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_READABLE_PYTREE_SERIALIZATION = True
|
||||
_TREE_REPR_KEY = "__jax_pytreedef_repr"
|
||||
_LEAF_IDS_KEY = "__jax_leaf_ids"
|
||||
|
||||
_NOT_REGISTERED_MESSAGE = (
|
||||
" * If you want to register a custom leaf, register it via"
|
||||
" `register_pytree_leaf_serialization` first.\n"
|
||||
" * If you want to register a custom node, register is via"
|
||||
" `register_pytree_node_serialization`")
|
||||
|
||||
__all__ = ["serialize_pytreedef", "deserialize_pytreedef",
|
||||
"register_pytree_node_serialization"]
|
||||
|
||||
class PyTreeFuture(Future[Any]):
|
||||
"""A wrapper around a Future that makes it look like an async function."""
|
||||
def __init__(self, future: Future[Any]):
|
||||
self._future, self.pytree = future, None
|
||||
|
||||
def done(self):
|
||||
return self._future.done()
|
||||
|
||||
def result(self, *args, **kw):
|
||||
return self._future.result(*args, **kw)
|
||||
|
||||
def __await__(self):
|
||||
while not self.done():
|
||||
yield
|
||||
return self.result()
|
||||
|
||||
def __repr__(self):
|
||||
return f"PyTreeFuture(done={self.done()}, pytree={self.pytree})"
|
||||
|
||||
|
||||
def serialize_pytreedef(node) -> dict[str, Any]:
|
||||
builder = flatbuffers.Builder(65536)
|
||||
exported = _serialize_pytreedef(builder, node)
|
||||
builder.Finish(exported)
|
||||
root_repr = base64.b64encode(builder.Output()).decode("utf-8")
|
||||
leaf_count = node.num_leaves
|
||||
pytree_repr = {_TREE_REPR_KEY: root_repr,
|
||||
_LEAF_IDS_KEY: list(range(leaf_count))}
|
||||
return pytree_repr
|
||||
|
||||
|
||||
def deserialize_pytreedef(pytreedef_repr: dict[str, Any]):
|
||||
buf = base64.b64decode(pytreedef_repr[_TREE_REPR_KEY])
|
||||
exp = ser_flatbuf.PyTreeDef.GetRootAs(buf)
|
||||
treestruct = jax.tree.structure(_deserialize_pytreedef_to_pytree(exp))
|
||||
return treestruct
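# Round-trip sketch: a PyTreeDef survives serialization to the JSON-friendly
# dict representation and back:
#   >>> treedef = jax.tree.structure({"w": 0, "b": 1})
#   >>> rep = serialize_pytreedef(treedef)
#   >>> deserialize_pytreedef(rep) == treedef
#   True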
|
||||
+374
@@ -0,0 +1,374 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Array serialization and deserialization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import array
|
||||
from jax._src import distributed
|
||||
from jax._src import sharding
|
||||
from jax._src import typing
|
||||
from jax._src import util
|
||||
from jax._src.layout import Format
|
||||
from jax._src.lib import _jax
|
||||
from jax.experimental.array_serialization import tensorstore_impl as ts_impl
|
||||
# ruff: noqa: F401
|
||||
# import tensorstore-backed methods for backward compatibility.
|
||||
from jax.experimental.array_serialization.tensorstore_impl import (
|
||||
_run_deserialization as run_deserialization,
|
||||
_run_serialization as run_serialization,
|
||||
async_serialize, async_deserialize, _TS_CONTEXT as TS_CONTEXT,
|
||||
_DEFAULT_BASE_DRIVER as _DEFAULT_DRIVER, _LimitInFlightBytes)
|
||||
import tensorstore as ts
|
||||
|
||||
# for compatibility with older zarr format
|
||||
_get_metadata = functools.partial(ts_impl._get_tensorstore_metadata,
|
||||
driver='zarr')
|
||||
get_tensorstore_spec = functools.partial(ts_impl.get_tensorstore_spec,
|
||||
driver='zarr', ocdbt=False)
|
||||
|
||||
|
||||
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
|
||||
_module_unique_count = itertools.count()
|
||||
_DISTRIBUTED_SYSTEM_MSG = (
|
||||
'Please initialize the distributed system via '
|
||||
'`jax.distributed.initialize()` at the start of your program.')
|
||||
_REMOTE_URL_PREFIXES = ['gs://', 's3://']
|
||||
_REMOTE_DRIVER_VALIDATIONS = [
|
||||
{'driver': 'gcs', 'path_regex': None},
|
||||
{'driver': 's3', 'path_regex': None},
|
||||
]
|
||||
|
||||
class BarrierTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
_BARRIER_TIMED_OUT_MSG = (
|
||||
"Suggestions for possible fixes:\n"
|
||||
"* Check the logs to see if one or more processes failed.\n"
|
||||
"* Make sure the training and checkpointing endpoints are close geographically.\n"
|
||||
"* Try increasing the timeout you pass to GlobalAsyncCheckpointManager.")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_remote_storage(tspec: dict[str, Any] | str) -> bool:
|
||||
"""Detect if user is using cloud storages.
|
||||
|
||||
This can detect common defines and unable to detect some corner cases such as
|
||||
using gcsfuse.
|
||||
"""
|
||||
if isinstance(tspec, str):
|
||||
# KvStoreUrl
|
||||
if re.match(rf'^({"|".join(_REMOTE_URL_PREFIXES)})', tspec):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
for key in ('base', 'kvstore'):
|
||||
if key in tspec:
|
||||
return is_remote_storage(tspec[key])
|
||||
|
||||
if 'driver' in tspec:
|
||||
for rule in _REMOTE_DRIVER_VALIDATIONS:
|
||||
if tspec['driver'] == rule['driver']:
|
||||
if rule['path_regex'] is None:
|
||||
return True
|
||||
|
||||
# check if path matches the regex.
|
||||
if re.match(rule['path_regex'], tspec['path']):
|
||||
return True
|
||||
|
||||
return False
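# Illustrative checks (the spec fragments are hypothetical examples):
#   >>> is_remote_storage('gs://my-bucket/ckpt')
#   True
#   >>> is_remote_storage({'driver': 'zarr3',
#   ...                    'kvstore': {'driver': 'file', 'path': '/tmp/ckpt'}})
#   False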
|
||||
|
||||
def _get_key(key: int):
|
||||
return f'tensorstore_checkpoint_{key}'
|
||||
|
||||
|
||||
class GlobalAsyncCheckpointManagerBase(util.StrictABC):
|
||||
"""Interface for checkpointing GDAs asynchronously.
|
||||
|
||||
This class manages the state of an ongoing asynchronous checkpoint.
|
||||
|
||||
  For example, say a checkpoint happens on every step. If you checkpoint on
  step 1 and, after some computation, the model is ready for checkpoint 2, but
  step 1's checkpoint hasn't finished committing to the storage layer yet, then
  the checkpoint for step 2 must be blocked until step 1's finishes. Keeping
  this state in a class makes that possible.
|
||||
|
||||
Examples:
|
||||
|
||||
Below is a simplified training loop:
|
||||
|
||||
```
|
||||
# Call this at the start of your program.
|
||||
jax.distributed.initialize()
|
||||
|
||||
manager = GlobalAsyncCheckpointManager()
|
||||
|
||||
# Restore checkpoint if available or initialize the train_state from
|
||||
# init_fn().
|
||||
train_state = manager.deserialize(...)
|
||||
|
||||
while ...:
|
||||
if step % num_steps_between_checkpoints == 0:
|
||||
manager.serialize(train_state, temp_checkpoint_dir=...,
|
||||
final_checkpoint_dir=...)
|
||||
train_state = train_step(train_state, input)
|
||||
# This is a non-blocking call.
|
||||
manager.check_for_errors()
|
||||
|
||||
manager.serialize(train_state, temp_checkpoint_dir=...,
|
||||
final_checkpoint_dir=...)
|
||||
# Wait before the end of the program for the checkpoint to finish. This is a
|
||||
# blocking call.
|
||||
manager.wait_until_finished()
|
||||
```
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def check_for_errors(self):
|
||||
"""Checks if any errors have been raised in the child thread.
|
||||
|
||||
This is a non-blocking call that can be called in the main thread.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def wait_until_finished(self):
|
||||
"""Blocks until serialization has finished."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def serialize(self, arrays, tensorstore_specs, *,
|
||||
on_commit_callback: Callable[[], None]):
|
||||
"""Serializes GDAs to TensorStore."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def deserialize(self, shardings: Sequence[sharding.Sharding],
|
||||
tensorstore_specs: Sequence[dict[str, Any]],
|
||||
global_shapes: Sequence[array.Shape] | None = None,
|
||||
dtypes: Sequence[typing.DTypeLike] | None = None):
|
||||
"""Deserializes GDAs from TensorStore."""
|
||||
|
||||
|
||||
class AsyncManager:
|
||||
|
||||
def __init__(self, timeout_secs=300):
|
||||
self._timeout_secs = timeout_secs
|
||||
self._timeout_in_ms = self._timeout_secs * 1000
|
||||
|
||||
self._commit_futures = None
|
||||
self._thread = None
|
||||
self._exception = None
|
||||
|
||||
if jax.process_count() > 1 and distributed.global_state.client is None:
|
||||
raise ValueError(_DISTRIBUTED_SYSTEM_MSG)
|
||||
self._client = distributed.global_state.client
|
||||
self._count: int | None = None
|
||||
|
||||
def __del__(self):
|
||||
if self._thread is not None and self._thread.is_alive():
|
||||
logger.warning('Please add `.wait_until_finished()` in the main thread '
|
||||
'before your program finishes because there is a '
|
||||
                   'possibility of losing errors raised if '
                   'this class is deleted before writing is completed.')
|
||||
|
||||
def _thread_func(self):
|
||||
try:
|
||||
current_process = jax.process_index()
|
||||
process_count = jax.process_count()
|
||||
logger.info('Starting commit to storage layer by process: %s',
|
||||
current_process)
|
||||
thread_start_time = time.time()
|
||||
assert self._commit_futures is not None
|
||||
for future in self._commit_futures:
|
||||
future.result()
|
||||
logger.info('Finished committing to storage layer by process: %s',
|
||||
current_process)
|
||||
|
||||
key_for_barrier: str | None = None
|
||||
if process_count > 1:
|
||||
assert self._client is not None
|
||||
assert self._count is not None
|
||||
# All processes will wait at the barrier. When all processes are at the
|
||||
# barrier, the barrier will be satisfied. If not, then it will timeout.
|
||||
key_for_barrier = _get_key(self._count)
|
||||
logger.info('Key used for barrier is %s for process %s',
|
||||
key_for_barrier, current_process)
|
||||
self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
|
||||
logger.info('Finished waiting at barrier for process %s',
|
||||
current_process)
|
||||
|
||||
if current_process == 0:
|
||||
if self._on_commit_callback is not None:
|
||||
self._on_commit_callback()
|
||||
logger.info('on_commit_callback successfully ran!')
|
||||
if process_count > 1:
|
||||
assert self._client is not None
|
||||
assert key_for_barrier is not None
|
||||
self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS)
|
||||
logger.info('Process 0 successfully set key %s in the kv store',
|
||||
key_for_barrier)
|
||||
|
||||
jax.monitoring.record_event_duration_secs(
|
||||
'/jax/checkpoint/write/async/thread_duration_sec',
|
||||
time.time() - thread_start_time)
|
||||
|
||||
except Exception as e:
|
||||
self._exception = e
|
||||
|
||||
def _start_async_commit(self, on_commit_callback):
|
||||
self._count = next(_module_unique_count)
|
||||
|
||||
self._on_commit_callback = on_commit_callback
|
||||
self._thread = threading.Thread(target=self._thread_func)
|
||||
self._thread.start()
|
||||
|
||||
def check_for_errors(self):
|
||||
if self._exception is not None:
|
||||
# Clears self._exception so it is only raised once.
|
||||
exception = self._exception
|
||||
self._exception = None
|
||||
if (isinstance(exception, _jax.JaxRuntimeError) and
|
||||
'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)):
|
||||
raise BarrierTimeoutError(
|
||||
'\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG]))
|
||||
raise exception
|
||||
|
||||
def wait_until_finished(self):
|
||||
if self._thread is not None:
|
||||
self._thread.join()
|
||||
self._thread = None
|
||||
logger.info('Thread joined successfully')
|
||||
|
||||
self.check_for_errors()
|
||||
logger.info('Error check finished successfully')
|
||||
|
||||
if jax.process_count() > 1 and self._count is not None:
|
||||
assert self._client is not None
|
||||
# Block until process 0 writes success value to the key value store.
|
||||
# If it fails to write it, then `blocking_key_value_get` will time out.
|
||||
get_key = _get_key(self._count)
|
||||
self._client.blocking_key_value_get(get_key, self._timeout_in_ms)
|
||||
logger.info('blocking_key_value_get on key %s was successfully '
|
||||
'completed.', get_key)
|
||||
|
||||
def _add_futures(self, futures: Sequence[ts.Future]):
|
||||
self._commit_futures = futures
|
||||
|
||||
|
||||
class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
|
||||
"""Responsible for serializing GDAs via TensorStore."""
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
arrays,
|
||||
tensorstore_specs,
|
||||
*,
|
||||
on_commit_callback: Callable[[], None] | None = None,
|
||||
transaction: ts_impl.Transaction | None = None,
|
||||
):
|
||||
"""Serializes Arrays or Arrays via TensorStore asynchronously.
|
||||
|
||||
TensorStore writes to a storage layer in 2 steps:
|
||||
* Reading/copying from the source after which the source can be modified.
|
||||
* Returns a copy future.
|
||||
* Writing/committing to the storage layer.
|
||||
* Returns a commit future.
|
||||
|
||||
In asynchronous mode, the serialization waits for the commit future to
|
||||
finish in a separate thread allowing other computation to proceed.
|
||||
|
||||
Args:
|
||||
      arrays: Arrays that should be serialized.
|
||||
tensorstore_specs: TensorStore specs that are used to serialize GDAs or
|
||||
Arrays.
|
||||
on_commit_callback: This callback will be executed after all processes
|
||||
        have finished writing their checkpoints to disk. On filesystems where
        atomic rename operations are supported, you can rename from the
        temporary directory to the final directory. On GCS, which does not
        support atomic rename operations, you write to the final directory
        directly and in `on_commit_callback` you write a success file
        indicating that the serialization was successful.
|
||||
transaction: Optional TensorStore transaction to use.
|
||||
"""
|
||||
logger.info('Waiting for previous serialization to finish.')
|
||||
self.wait_until_finished()
|
||||
|
||||
commit_futures: list[ts_impl.Future] = []
|
||||
|
||||
async def _run_serializer():
|
||||
future_writer = jax.tree_util.tree_map(
|
||||
lambda arr_inp, tensorstore_spec: ts_impl.async_serialize(
|
||||
arr_inp,
|
||||
tensorstore_spec,
|
||||
commit_future=commit_futures,
|
||||
transaction=transaction,
|
||||
),
|
||||
arrays,
|
||||
tensorstore_specs,
|
||||
)
|
||||
return await asyncio.gather(*future_writer)
|
||||
asyncio.run(_run_serializer())
|
||||
|
||||
self._add_futures(commit_futures)
|
||||
|
||||
# Used in wait_until_finished to check on process != 0, if the checkpoint
|
||||
# has finished writing.
|
||||
self._start_async_commit(on_commit_callback)
|
||||
|
||||
def serialize_with_paths(
|
||||
self,
|
||||
arrays: Sequence[jax.Array],
|
||||
paths: Sequence[str],
|
||||
*,
|
||||
on_commit_callback: Callable[[], None] | None = None,
|
||||
transaction: ts_impl.Transaction | None = None,
|
||||
):
|
||||
tspecs = jax.tree.map(get_tensorstore_spec, paths)
|
||||
return self.serialize(
|
||||
arrays,
|
||||
tspecs,
|
||||
on_commit_callback=on_commit_callback,
|
||||
transaction=transaction,
|
||||
)
|
||||
|
||||
def deserialize(self, shardings: Sequence[sharding.Sharding | Format],
|
||||
tensorstore_specs: Sequence[dict[str, Any]],
|
||||
global_shapes: Sequence[array.Shape] | None = None,
|
||||
dtypes: Sequence[typing.DTypeLike] | None = None,
|
||||
concurrent_gb: int = 32):
|
||||
self.wait_until_finished()
|
||||
return ts_impl._run_deserialization(
|
||||
shardings, tensorstore_specs, global_shapes, dtypes, concurrent_gb)
|
||||
|
||||
def deserialize_with_paths(
|
||||
self, shardings: Sequence[sharding.Sharding],
|
||||
paths: Sequence[str],
|
||||
global_shapes: Sequence[array.Shape] | None = None,
|
||||
dtypes: Sequence[typing.DTypeLike] | None = None,
|
||||
concurrent_gb: int = 32):
|
||||
tspecs = jax.tree.map(get_tensorstore_spec, paths)
|
||||
return self.deserialize(shardings, tspecs, global_shapes, dtypes,
|
||||
concurrent_gb)
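# Usage sketch for the path-based helpers (hypothetical absolute paths; `arr` is
# assumed to be a jax.Array defined elsewhere, and the sharding is an arbitrary
# single-device example):
#   >>> mgr = GlobalAsyncCheckpointManager()
#   >>> mgr.serialize_with_paths([arr], ['/tmp/ckpt/arr0'],
#   ...                          on_commit_callback=lambda: None)
#   >>> mgr.wait_until_finished()  # block until the async commit completes
#   >>> restored, = mgr.deserialize_with_paths(
#   ...     [jax.sharding.SingleDeviceSharding(jax.devices()[0])],
#   ...     ['/tmp/ckpt/arr0'])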
|
||||
+604
@@ -0,0 +1,604 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
from functools import partial
|
||||
import functools
|
||||
import os
|
||||
from os import PathLike
|
||||
import re
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
import math
|
||||
import logging
|
||||
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
from jax._src import array
|
||||
from jax._src.layout import Format
|
||||
from jax._src import typing
|
||||
from jax._src.sharding_impls import make_single_device_sharding
|
||||
import numpy as np
|
||||
import tensorstore as ts
|
||||
|
||||
_TS_ARRAY_DRIVER = "zarr3"
|
||||
|
||||
_TS_CONTEXT = ts.Context({
|
||||
'file_io_concurrency': {'limit': 128},
|
||||
'cache_pool': {'total_bytes_limit': 10_000_000_000}, # 10 GB RAM limit
|
||||
'cache_pool#remote': {'total_bytes_limit': 10_000_000_000},
|
||||
'data_copy_concurrency': {'limit': 128}
|
||||
})
|
||||
_TS_CHUNK_LAYOUT = ts.ChunkLayout(
|
||||
chunk=ts.ChunkLayout.Grid(elements=100_000_000), # 100M (800MB for float64) file size
|
||||
)
|
||||
|
||||
_DEFAULT_BASE_DRIVER = 'file'
|
||||
_PROCESS_DIR_FORMAT = "process_{}"
|
||||
_FILE_SIZE_TARGET = 2 * 1024 ** 3 # 2 GB
|
||||
|
||||
Future, Transaction = ts.Future, ts.Transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lifted from T5X.
|
||||
class _LimitInFlightBytes:
|
||||
"""Limits host scratch memory usage when reading/writing checkpoints per process."""
|
||||
|
||||
def __init__(self, host_memory_bytes_limit: int):
|
||||
self._max_bytes = host_memory_bytes_limit
|
||||
self._available_bytes = host_memory_bytes_limit
|
||||
self._cv = asyncio.Condition(lock=asyncio.Lock())
|
||||
|
||||
async def wait_for_bytes(self, requested_bytes):
|
||||
if requested_bytes > self._max_bytes:
|
||||
logger.debug("A single array item requests more bytes than we reserved"
|
||||
" space for in the parallel pool: %d > %d. Increasing the"
|
||||
" limit to %d.", requested_bytes, self._max_bytes,
|
||||
requested_bytes)
|
||||
bytes_currently_used = self._max_bytes - self._available_bytes
|
||||
self._max_bytes = requested_bytes
|
||||
self._available_bytes = self._max_bytes - bytes_currently_used
|
||||
async with self._cv:
|
||||
await self._cv.wait_for(lambda: self._available_bytes >= requested_bytes)
|
||||
self._available_bytes -= requested_bytes
|
||||
assert self._available_bytes >= 0
|
||||
|
||||
async def release_bytes(self, requested_bytes):
|
||||
async with self._cv:
|
||||
self._available_bytes += requested_bytes
|
||||
assert self._available_bytes <= self._max_bytes
|
||||
self._cv.notify_all()
|
||||
|
||||
def is_tensorstore_spec_leaf(leaf: Any):
|
||||
# TODO(rdyro): think of a better way to detect which leaf is a ts config
|
||||
return leaf is None or (isinstance(leaf, dict)
|
||||
and ("driver" in leaf or "kvstore" in leaf))
|
||||
|
||||
def _prime_factors(x: int) -> list[int]:
|
||||
# find prime factors of axis sizes to help efficiently find divisor chunks
|
||||
factors = []
|
||||
while x % 2 == 0:
|
||||
factors.append(2)
|
||||
x //= 2
|
||||
for i in range(3, int(math.sqrt(x)) + 1, 2):
|
||||
while x % i == 0:
|
||||
factors.append(i)
|
||||
x //= i
|
||||
if x > 1:
|
||||
factors.append(x)
|
||||
return sorted(factors)
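# Example: the factors are returned in sorted order with multiplicity, e.g.
#   >>> _prime_factors(360)
#   [2, 2, 2, 3, 3, 5]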
|
||||
|
||||
@functools.lru_cache(maxsize=1024)
|
||||
def _compute_chunk_shape(
|
||||
local_shape: Sequence[int], dtype: str | jnp.dtype,
|
||||
file_size_target: int = _FILE_SIZE_TARGET) -> list[int]:
|
||||
"""Compute a chunk such that it divides the local shape and is less than
|
||||
target file size. This helps the tensorstore kvstore driver limit the largest
|
||||
file size on disk to below the ``file_size_target``. We compute a chunk with a
|
||||
byte size at most 110% of the ``file_size_target``.
|
||||
"""
|
||||
local_shape = list(local_shape)
|
||||
if len(local_shape) == 0 or math.prod(local_shape) == 0:
|
||||
# a zero size array needs a non-zero chunk passed to tensorstore for compat.
|
||||
return [max(z, 1) for z in local_shape]
|
||||
total_size = math.prod(local_shape) * jnp.dtype(dtype).itemsize
|
||||
axis_prime_factors = [_prime_factors(z) for z in local_shape]
|
||||
chunk_shape, chunk_size = list(local_shape), total_size
|
||||
# while chunk_size exceeds target size, reduce chunk_shape
|
||||
while chunk_size > 1.1 * file_size_target: # 10% buffer
|
||||
# 1. find the smallest axis divisor across all axes
|
||||
chosen_axis_idx: int | None = None
|
||||
chosen_divisor = 1
|
||||
for axis_idx in range(len(chunk_shape)):
|
||||
if len(axis_prime_factors[axis_idx]) == 1: # ignore axes sizes == 1
|
||||
continue
|
||||
if (chosen_axis_idx is None
|
||||
or chosen_divisor > axis_prime_factors[axis_idx][0]):
|
||||
chosen_axis_idx = axis_idx
|
||||
chosen_divisor = axis_prime_factors[axis_idx][0]
|
||||
# 2. if no divisor found, give up, return current chunk shape
|
||||
if chosen_axis_idx is None:
|
||||
return chunk_shape
|
||||
# 3. remove the applied divisor from prime factors
|
||||
prime_factors = axis_prime_factors[chosen_axis_idx]
|
||||
prime_factors.pop(0)
|
||||
# 4. apply the found divisor to reduce the chunk size
|
||||
chunk_shape[chosen_axis_idx] //= chosen_divisor
|
||||
chunk_size //= chosen_divisor
|
||||
return chunk_shape
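# Chunking sketch: a (4096, 4096) float32 shard is 64 MB; with a hypothetical
# 16 MB file-size target the leading axis is halved twice:
#   >>> _compute_chunk_shape((4096, 4096), jnp.float32,
#   ...                      file_size_target=16 * 1024 ** 2)
#   [1024, 4096]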
|
||||
|
||||
def _get_tensorstore_metadata(arr, is_remote: bool = False,
|
||||
file_size_target: int = _FILE_SIZE_TARGET,
|
||||
driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]:
|
||||
global_shape, dtype = arr.shape, arr.dtype
|
||||
if isinstance(arr, jax.Array):
|
||||
local_shape = arr.sharding.shard_shape(global_shape)
|
||||
else: # np.ndarray
|
||||
local_shape = global_shape
|
||||
return _get_tensorstore_metadata_cached(global_shape, dtype, local_shape,
|
||||
is_remote, file_size_target, driver)
|
||||
|
||||
@functools.lru_cache(maxsize=1024)
|
||||
def _get_tensorstore_metadata_cached(
|
||||
global_shape: Sequence[int], dtype: jnp.dtype, local_shape: Sequence[int],
|
||||
is_remote: bool = False, file_size_target: int = _FILE_SIZE_TARGET,
|
||||
driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]:
|
||||
if driver == "zarr3":
|
||||
codecs = ([{"name": "zstd"}] if is_remote else [])
|
||||
return {
|
||||
'codecs': codecs,
|
||||
'shape': global_shape,
|
||||
'data_type': jnp.dtype(dtype).name,
|
||||
'chunk_grid': {
|
||||
'name': 'regular',
|
||||
'configuration': {'chunk_shape': _compute_chunk_shape(
|
||||
local_shape, dtype, file_size_target=file_size_target)}
|
||||
}
|
||||
}
|
||||
elif driver == "zarr": # in zarr dtype goes in the base spec
|
||||
return {'compressor': {'id': 'zstd'}, 'shape': global_shape,
|
||||
'chunks': np.array(np.maximum(1, local_shape)).tolist()}
|
||||
else:
|
||||
raise ValueError(f"Unsupported driver: {driver}")
|
||||
|
||||
_divides = lambda x, y: np.all((np.array(x) % np.array(y)) == 0)
|
||||
|
||||
def merge_nested_ts_specs(dict1: dict[Any, Any], dict2: dict[Any, Any] | None):
|
||||
"""Merge two ts specs, dict2 takes precedence."""
|
||||
if dict2 is None: # nothing to do
|
||||
return dict1
|
||||
# TODO(rdyro): this is an opinionated merge, we should get user feedback
|
||||
# merge kvstore explicitly
|
||||
kvstore = dict1.get("kvstore", {}) | dict2.get("kvstore", {})
|
||||
return dict1 | dict(dict2, kvstore=kvstore) # merge with dict2 preferred
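# Merge sketch: entries from dict2 win on conflicts, and the "kvstore" mappings
# are merged key-wise rather than replaced wholesale:
#   >>> merge_nested_ts_specs(
#   ...     {"driver": "zarr3", "kvstore": {"driver": "file", "path": "/a"}},
#   ...     {"kvstore": {"path": "/b"}})
#   {'driver': 'zarr3', 'kvstore': {'driver': 'file', 'path': '/b'}}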
|
||||
|
||||
def verify_tensorstore_spec(spec: dict[str, Any], arr: jax.Array | None,
|
||||
path: str | os.PathLike[str], ocdbt: bool,
|
||||
check_metadata: bool = True) -> None:
|
||||
"""Verify the minimum requirements for a tensorstore spec."""
|
||||
if ocdbt:
|
||||
if spec.get("kvstore", {}).get("driver", "") != "ocdbt":
|
||||
raise ValueError(f"Expected ocdbt driver, got {spec=}")
|
||||
if check_metadata:
|
||||
if arr is None:
|
||||
raise ValueError("Array is required for metadata verification.")
|
||||
metadata = spec['metadata']
|
||||
if spec.get("driver", "") == "zarr3":
|
||||
if metadata['data_type'] != jnp.dtype(arr.dtype).name:
|
||||
raise ValueError(f"Provided dtype ({metadata['data_type']=}) doesn't"
|
||||
f" match ({arr.dtype=})")
|
||||
if 'shape' in metadata:
|
||||
if metadata['shape'] != arr.shape:
|
||||
raise ValueError(f"Provided shape ({metadata['shape']=}) doesn't match"
|
||||
f" ({arr.shape=})")
|
||||
if isinstance(arr, jax.Array):
|
||||
local_shape = arr.sharding.shard_shape(arr.shape)
|
||||
else: # np.ndarray
|
||||
local_shape = arr.shape
|
||||
if spec.get("driver", "") == "zarr3":
|
||||
chunk_shape = metadata['chunk_grid']['configuration']['chunk_shape']
|
||||
if not _divides(local_shape, chunk_shape):
|
||||
raise ValueError(f"Provided chunk shape {chunk_shape} does not divide"
|
||||
f" the local shape of the array {local_shape}")
|
||||
# check path is still the same one we expect
|
||||
if ocdbt:
|
||||
found_path = spec["kvstore"]['base']['path']
|
||||
else:
|
||||
found_path = spec["kvstore"]['path']
|
||||
if str(found_path) != str(path):
|
||||
raise ValueError(f"Provided {path=} does not match the spec path:"
|
||||
f" {spec['kvstore']}")
|
||||
|
||||
def _spec_has_metadata(tree):
|
||||
if not isinstance(tree, dict):
|
||||
return False
|
||||
return 'metadata' in tree or any(
|
||||
_spec_has_metadata(subtree) for _, subtree in tree.items())
|
||||
|
||||
def _get_kvstore_for_gcs(ckpt_path: str):
|
||||
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path)
|
||||
if m is None:
|
||||
raise ValueError('The ckpt_path should contain the bucket name and the '
|
||||
f'file path inside the bucket. Got: {ckpt_path}')
|
||||
bucket = m.group(1)
|
||||
path_without_bucket = m.group(2)
|
||||
return {'driver': 'gcs', 'bucket': bucket, 'path': path_without_bucket}
|
||||
|
||||
def _get_kvstore_for_s3(ckpt_path: str):
|
||||
m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path, re.DOTALL)
|
||||
if m is None:
|
||||
raise ValueError('The ckpt_path should contain the bucket name and the '
|
||||
f'file path inside the bucket. Got: {ckpt_path}')
|
||||
bucket = m.group(1)
|
||||
path_without_bucket = m.group(2)
|
||||
return {'driver': 's3', 'bucket': bucket, 'path': path_without_bucket}
|
||||
|
||||
def get_tensorstore_spec(
|
||||
ckpt_path: str | PathLike[str], ocdbt: bool = True,
|
||||
process_idx: int | None = None, arr: jax.Array | None = None,
|
||||
driver: str = _TS_ARRAY_DRIVER) -> dict[str, Any]:
|
||||
|
||||
  # Normalize the path to exclude a trailing '/'. In the GCS path case, normpath
  # will replace the double '//' with a single '/', and we need to restore the
  # filesystem type:// prefix for GCS (gs://) and S3 (s3://) paths.
|
||||
ckpt_path = os.path.normpath(str(ckpt_path))
|
||||
ckpt_path = re.sub(r"^([a-z]+):/", r"\1://", ckpt_path)
|
||||
|
||||
  # In the case of multi-process writes, we need to write to a different location
  # for each process and finally create a combined symlink to the final
  # location; tensorstore can do this via ts.KvStore.experimental_copy_range_to.
|
||||
if process_idx is not None:
|
||||
_parent, _name = os.path.split(ckpt_path)
|
||||
ckpt_path = os.path.join(_parent, _PROCESS_DIR_FORMAT.format(process_idx),
|
||||
_name)
|
||||
|
||||
is_gcs_path = ckpt_path.startswith('gs://')
|
||||
is_s3_path = ckpt_path.startswith('s3://')
|
||||
spec = {'driver': driver, 'kvstore': {}}
|
||||
|
||||
# use a combined OCDBT store, the actual path is the parent path
|
||||
# the name (filename/last part of the path) is the key in the ocdbt kvstore
|
||||
entry_key = None
|
||||
if ocdbt:
|
||||
(ckpt_path, entry_key), org_ckpt_path = os.path.split(ckpt_path), ckpt_path
|
||||
if is_gcs_path:
|
||||
m = re.fullmatch('^gs://([^/]*)/(.*)$', ckpt_path)
|
||||
elif is_s3_path:
|
||||
m = re.fullmatch('^s3://([^/]*)/(.*)$', ckpt_path)
|
||||
else:
|
||||
m = re.match("a", "a") # make it True
|
||||
if m is None:
|
||||
raise ValueError('Using OCDBT requires the bucket name, the directory'
|
||||
' name and the array name, your path is: '
|
||||
f'{org_ckpt_path}')
|
||||
|
||||
if is_gcs_path:
|
||||
base_kvstore = _get_kvstore_for_gcs(ckpt_path)
|
||||
elif is_s3_path:
|
||||
base_kvstore = _get_kvstore_for_s3(ckpt_path)
|
||||
else:
|
||||
base_kvstore = {'driver': _DEFAULT_BASE_DRIVER, 'path': ckpt_path}
|
||||
|
||||
if ocdbt:
|
||||
if not is_gcs_path and not is_s3_path and not os.path.isabs(ckpt_path):
|
||||
raise ValueError(f'Checkpoint path should be absolute. Got {ckpt_path}')
|
||||
spec['kvstore'] = {'driver': 'ocdbt', 'base': base_kvstore,
|
||||
'path': entry_key}
|
||||
else:
|
||||
spec['kvstore'] = base_kvstore # pyrefly: ignore[bad-typed-dict-key]
|
||||
# done writing tensorstore spec based on destination path
|
||||
# optionally, if array is provided, we can add metadata to the spec
|
||||
if arr is not None:
|
||||
spec["metadata"] = _get_tensorstore_metadata(
|
||||
arr, driver=str(spec["driver"]))
|
||||
return spec
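# Spec sketch for a local OCDBT store (hypothetical absolute path): the parent
# directory becomes the kvstore base and the final path component the entry key.
#   >>> get_tensorstore_spec("/tmp/ckpt/arr0", ocdbt=True)
#   {'driver': 'zarr3', 'kvstore': {'driver': 'ocdbt',
#    'base': {'driver': 'file', 'path': '/tmp/ckpt'}, 'path': 'arr0'}}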
|
||||
|
||||
async def _create_async_array_from_callback(
|
||||
global_shape: array.Shape,
|
||||
dtype: str | jnp.dtype | None,
|
||||
inp_sharding: jax.sharding.Sharding,
|
||||
data_callback: Callable[[array.Index, jax.Device], Awaitable[jax.Array]],
|
||||
):
|
||||
device_to_index_map = inp_sharding.devices_indices_map(global_shape)
|
||||
addressable_da = inp_sharding._addressable_device_assignment
|
||||
future_arrays = [data_callback(device_to_index_map[d], d)
|
||||
for d in addressable_da]
|
||||
dbs = await asyncio.gather(*future_arrays)
|
||||
return array.make_array_from_single_device_arrays(
|
||||
global_shape, inp_sharding, dbs, dtype=dtype)
|
||||
|
||||
async def _transfer_shard_to_host(shard: array.Shard) -> np.ndarray:
|
||||
data = shard.data
|
||||
has_pinned_host = any(
|
||||
m.kind == "pinned_host" for m in shard.device.addressable_memories())
|
||||
if has_pinned_host:
|
||||
# If available, transfer to pinned host memory
|
||||
sharding = make_single_device_sharding(shard.device,
|
||||
memory_kind="pinned_host")
|
||||
data = jax.device_put(data, sharding)
|
||||
else:
|
||||
data.copy_to_host_async()
|
||||
# Allow other transfers to be scheduled simultaneously
|
||||
await asyncio.sleep(0)
|
||||
# Ensure that jax.Array's internal numpy array can be zero-copied. Tensorstore
|
||||
# implicitly converts the written data to a numpy array, and would otherwise
|
||||
# silently copy host-to-host.
|
||||
return np.array(data, copy=False)


async def combine_kvstores(combined_kvstore: dict[str, Any],
                           kvstores: list[dict[str, Any]],
                           context: ts.Context | None = _TS_CONTEXT
                           ) -> None:
  """Merge a list of kvstores into a single kvstore. NOT multi-process safe."""
  combined_fut = ts.KvStore.open(combined_kvstore, context=context)
  kvstores_futs = [ts.KvStore.open(kvstore, context=context)
                   for kvstore in kvstores]
  combined, opened_kvstores = await asyncio.gather(
      combined_fut, asyncio.gather(*kvstores_futs)
  )
  tx = ts.Transaction()
  await asyncio.gather(*(
      kvstore.experimental_copy_range_to(combined.with_transaction(tx))
      for kvstore in opened_kvstores
  ))
  await tx.commit_async()
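

# Illustrative sketch (hypothetical paths): merging several per-process OCDBT
# stores into one combined store. Must be driven by a single process, since
# combine_kvstores is not multi-process safe.
def _example_combine_process_stores(process_paths: list[str],
                                    combined_path: str) -> None:
  combined = {'driver': 'ocdbt',
              'base': {'driver': 'file', 'path': combined_path}}
  parts = [{'driver': 'ocdbt', 'base': {'driver': 'file', 'path': p}}
           for p in process_paths]
  asyncio.run(combine_kvstores(combined, parts))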


async def async_serialize(
    arr_inp,
    tensorstore_spec,
    commit_future=None,
    context=_TS_CONTEXT,
    chunk_layout=_TS_CHUNK_LAYOUT,
    primary_host: int | None = None,
    replica_id: int = 0,
    transaction: ts.Transaction | None = None,
):
  """Serialize an array using TensorStore.

  Args:
    arr_inp: The array to serialize.
    tensorstore_spec: The tensorstore spec to use.
    commit_future: A list of futures that will be appended to. The futures can
      be awaited asynchronously. If None, the futures will be awaited
      synchronously by this method.
    context: ts.Context instance.
    primary_host: Primary host, which indicates the host that will be treated
      as the "leader". If None, all hosts are treated as the primary. DO NOT
      USE unless you are sure you know what you are doing.
    replica_id: Allows overriding the shard replica id that will be saved. DO
      NOT USE unless you are sure you know what you are doing.
    transaction: TensorStore transaction to use for opening and writing the
      array. If not specified, a non-transactional write will be used.
  """
  if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
      arr_inp.is_fully_addressable):
    raise ValueError(
        f'Passing fully addressable arrays to a multiprocess '
        f'serialization is not allowed, as this may lead to a race condition '
        f'between processes. Serialization has failed for the array with '
        f'the path from kvstore: "{tensorstore_spec["kvstore"]}".')

  # 'metadata' may not be present at the top level (for example, if we are
  # using a 'cast' driver).
  if not _spec_has_metadata(tensorstore_spec):
    tensorstore_spec['metadata'] = _get_tensorstore_metadata(
        arr_inp, driver=tensorstore_spec['driver'])
  # The zarr driver requires specifying the dtype in the spec base.
  if tensorstore_spec['driver'] == 'zarr' and 'dtype' not in tensorstore_spec:
    tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name

  # If primary_host is None, all hosts will checkpoint. This is used
  # for checkpointing to local filesystem.
  if primary_host is None or jax.process_index() == primary_host:
    open_future = ts.open(
        ts.Spec(tensorstore_spec),
        create=True,
        open=True,
        context=context,
        chunk_layout=chunk_layout,
        transaction=transaction,
    )
    # Asynchronous case.
    if commit_future is not None:
      assert isinstance(commit_future, list)
      commit_future.append(open_future)
    else:
      await open_future

  # `ts.open` runs twice for process `primary_host` because the first time we
  # just get the future to be awaited upon in the background thread. The
  # second call runs with `assume_metadata=True`, which does no I/O and
  # returns the tensorstore object.
  # For every process other than `primary_host`, we open with
  # `assume_metadata=True`.
  t = await ts.open(
      ts.Spec(tensorstore_spec),
      open=True,
      assume_metadata=True,
      context=context,
      chunk_layout=chunk_layout,
      transaction=transaction,
  )

  async def _write_array(shard):
    if shard.replica_id == replica_id:
      data = await _transfer_shard_to_host(shard)
      write_future = t[shard.index].write(
          data,
          # Avoid additional copy of input array into the TensorStore chunk
          # cache. If `arr_inp` is a jax.Array, the result of converting
          # it to a NumPy array, as is done internally by TensorStore, is
          # guaranteed to be immutable and therefore it is safe to retain a
          # reference indefinitely.
          can_reference_source_data_indefinitely=isinstance(
              arr_inp, array.ArrayImpl
          ),
      )
      if commit_future is not None:
        assert isinstance(commit_future, list)
        commit_future.append(write_future.commit)
        await write_future.copy
      else:
        await write_future.commit

  local_shards = arr_inp.addressable_shards
  future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
  return await asyncio.gather(*future_write_state)
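

# Illustrative sketch (hypothetical path and helper name): serializing one
# array from a single process with a hand-written spec. In multi-process
# settings the array must not be fully addressable (see the check above).
def _example_serialize_one(arr: jax.Array) -> None:
  spec = {'driver': 'zarr3',
          'kvstore': {'driver': 'file', 'path': '/tmp/ckpt/arr_0'}}
  asyncio.run(async_serialize(arr, spec))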


# TODO(rdyro): Remove this function.
def _run_serialization(arrays, tensorstore_specs):
  """Legacy serialization of a list of arrays."""
  async def _run_serializer():
    future_writer = jax.tree_util.tree_map(async_serialize, arrays,
                                           tensorstore_specs)
    return await asyncio.gather(*future_writer)
  asyncio.run(_run_serializer())
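

# Illustrative sketch (hypothetical paths): the legacy helper above pairs each
# array with its spec positionally, e.g. two arrays written to two local
# zarr3 stores.
def _example_run_serialization(arrays: list[jax.Array]) -> None:
  specs = [{'driver': 'zarr3',
            'kvstore': {'driver': 'file', 'path': f'/tmp/ckpt/arr_{i}'}}
           for i in range(len(arrays))]
  _run_serialization(arrays, specs)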


def estimate_read_memory_footprint(t: ts.TensorStore,
                                   domain: ts.IndexDomain) -> int:
  """Estimate the number of bytes a read of `domain` touches, chunk-aligned."""
  rank = t.rank
  num_bytes = t.dtype.numpy_dtype.itemsize
  chunk_template = t.chunk_layout.read_chunk_template
  if domain is None:
    domain = t.domain
  origin = domain.origin
  shape = domain.shape
  chunk_origin = chunk_template.origin
  chunk_shape = chunk_template.shape

  # Some TensorStore drivers are not chunked, e.g. the inline 'array' driver.
  # For those, instead of returning a near-infinite memory footprint, estimate
  # the footprint as the entire shape.
  for i in range(rank):
    if not chunk_template[i].finite:
      return domain.size * num_bytes

  # Otherwise, if we have a chunked driver, estimate based on chunk size.
  for i in range(rank):
    origin_value = origin[i]
    chunk_origin_value = chunk_origin[i]
    chunk_size = chunk_shape[i]
    lower = origin_value - chunk_origin_value
    upper = origin_value + shape[i] - chunk_origin_value
    lower_aligned = lower // chunk_size * chunk_size
    upper_aligned = -(-upper // chunk_size) * chunk_size
    num_bytes *= (upper_aligned - lower_aligned)

  return num_bytes
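

# Worked example (hypothetical numbers): for a 1-D float32 store with read
# chunks of size 64 starting at origin 0, reading the domain [10, 100) touches
# the chunk-aligned span [0, 128), so the estimate is 128 * 4 = 512 bytes
# rather than the exact 90 * 4 = 360 bytes of the requested domain.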


async def async_deserialize(
    in_type: jax.sharding.Sharding | Format | jax.ShapeDtypeStruct,
    tensorstore_spec: ts.Spec | dict[str, Any],
    global_shape: Sequence[int] | None = None,
    dtype=None,
    byte_limiter: _LimitInFlightBytes | None = None,
    context=_TS_CONTEXT,
    chunk_layout=_TS_CHUNK_LAYOUT,
    assume_metadata: bool = False,
):
  """Main performant deserialization routine for arrays using tensorstore."""
  if isinstance(in_type, Format):
    in_sharding, layout = in_type.sharding, in_type.layout
  elif isinstance(in_type, jax.ShapeDtypeStruct):
    dtype = in_type.dtype if dtype is None else dtype
    in_sharding = in_type.sharding
    layout = in_type.format.layout
  else:
    if not isinstance(in_type, jax.sharding.Sharding):
      raise TypeError(
          'sharding passed to deserialization should be specified, concrete and'
          f' an instance of `jax.sharding.Sharding`. Got {in_type}')
    in_sharding = in_type
    layout = None
  assert isinstance(in_sharding, jax.sharding.Sharding)
  t = await ts.open(
      tensorstore_spec,
      open=True,
      assume_metadata=assume_metadata,
      context=context,
      chunk_layout=chunk_layout,
  )
  shape = t.shape if global_shape is None else global_shape
  dtype = dtype if dtype is not None else t.dtype.numpy_dtype
  new_shard_shape = in_sharding.shard_shape(tuple(shape))

  async def cb(index: array.Index, device: jax.Device):
    requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
    restricted_domain = t.domain.intersect(requested_domain)
    requested_bytes = estimate_read_memory_footprint(t, restricted_domain)
    # Limit the bytes read for every shard.
    if byte_limiter is not None:
      await byte_limiter.wait_for_bytes(requested_bytes)
    # This may be needed because the shape the array was saved with can be
    # smaller than the requested shape of the array in which it will be
    # reloaded; the extra values are filled with zeros.
    out = np.zeros(new_shard_shape, dtype=t.dtype.numpy_dtype)
    await ts.array(out)[ts.d[:].translate_to[requested_domain.origin]][
        restricted_domain].write(t[restricted_domain])
    if dtype is not None:
      # Cast while reloading on process to avoid 2 copies on device if the
      # casting is done on device.
      out = out.astype(dtype)
    # Convert to jnp array so that layouts are initialized properly for
    # sub-byte dtypes.
    # TODO(yashkatariya): This is a band-aid fix. Figure out a better way to
    # make this work.
    if out.dtype == jnp.int4:
      out = jnp.asarray(out)
    result = jax.device_put(
        out, Format(layout, make_single_device_sharding(device)))
    if byte_limiter is not None:
      # NB: `out` might not actually be ready for garbage collection by the
      # time we call release_bytes, so peak memory usage can still grow beyond
      # what the byte_limiter limit suggests it should. The simplest option
      # would be to call `result.block_until_ready()` here, but that comes
      # with a ~15-20% perf penalty, since we would be waiting for the
      # CPU->GPU transfer instead of loading data. In the future, if memory
      # pressure becomes a problem, we can instead instrument the byte_limiter
      # to keep track of all in-flight tensors and only block_until_ready when
      # the byte limiter hits its limit, reducing memory usage without losing
      # performance in common use cases.
      await byte_limiter.release_bytes(requested_bytes)
    return result

  # For deserialization, canonicalize the dtype to one representable in JAX.
  return await _create_async_array_from_callback(
      tuple(shape), jax.dtypes.canonicalize_dtype(dtype), in_sharding, cb)
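

# Illustrative sketch (hypothetical path): reading an array back onto a single
# device from a local zarr3 store written earlier.
def _example_deserialize_one() -> jax.Array:
  spec = {'driver': 'zarr3',
          'kvstore': {'driver': 'file', 'path': '/tmp/ckpt/arr_0'}}
  sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
  return asyncio.run(async_deserialize(sharding, spec))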


# TODO(rdyro): Remove this function.
def _run_deserialization(shardings: Sequence[jax.sharding.Sharding | Format],
                         tensorstore_specs: Sequence[dict[str, Any] | ts.Spec],
                         global_shapes: Sequence[array.Shape] | None = None,
                         dtypes: Sequence[typing.DTypeLike] | None = None,
                         concurrent_gb: int = 32):
  """Legacy deserialization of a list of arrays.

  Optionally pass global_shapes and dtypes for type-checking.
  """
  concurrent_bytes = concurrent_gb * 10**9

  async def _run_deserializer():
    # Object should be created once per process.
    byte_limiter = _LimitInFlightBytes(concurrent_bytes)

    future_arrays = jax.tree_util.tree_map(
        partial(async_deserialize, byte_limiter=byte_limiter),
        list(shardings), list(tensorstore_specs),
        [None] * len(tensorstore_specs) if global_shapes is None else global_shapes,
        [None] * len(tensorstore_specs) if dtypes is None else dtypes)
    return await asyncio.gather(*future_arrays)
  return asyncio.run(_run_deserializer())
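

# Illustrative sketch (hypothetical paths): restoring two arrays onto a single
# device with the legacy helper above.
def _example_run_deserialization() -> list[jax.Array]:
  device = jax.devices()[0]
  shardings = [jax.sharding.SingleDeviceSharding(device)] * 2
  specs = [{'driver': 'zarr3',
            'kvstore': {'driver': 'file', 'path': f'/tmp/ckpt/arr_{i}'}}
           for i in range(2)]
  return _run_deserialization(shardings, specs)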