hand
This commit is contained in:
+374
@@ -0,0 +1,374 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Array serialization and deserialization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import array
|
||||
from jax._src import distributed
|
||||
from jax._src import sharding
|
||||
from jax._src import typing
|
||||
from jax._src import util
|
||||
from jax._src.layout import Format
|
||||
from jax._src.lib import _jax
|
||||
from jax.experimental.array_serialization import tensorstore_impl as ts_impl
|
||||
# ruff: noqa: F401
|
||||
# import tensorstore-backed methods for backward compatibility.
|
||||
from jax.experimental.array_serialization.tensorstore_impl import (
|
||||
_run_deserialization as run_deserialization,
|
||||
_run_serialization as run_serialization,
|
||||
async_serialize, async_deserialize, _TS_CONTEXT as TS_CONTEXT,
|
||||
_DEFAULT_BASE_DRIVER as _DEFAULT_DRIVER, _LimitInFlightBytes)
|
||||
import tensorstore as ts
|
||||
|
||||
# Backward-compatible aliases for the older zarr format: each one forwards to
# its `ts_impl` counterpart with `driver='zarr'` already bound.
_get_metadata = functools.partial(
    ts_impl._get_tensorstore_metadata,
    driver='zarr',
)
# The spec builder additionally binds `ocdbt=False`.
get_tensorstore_spec = functools.partial(
    ts_impl.get_tensorstore_spec,
    driver='zarr',
    ocdbt=False,
)
|
||||
|
||||
|
||||
# Value process 0 stores in the distributed key-value store once a checkpoint
# has been committed.
_CHECKPOINT_SUCCESS = 'checkpoint_write_success'
# Module-wide counter; every asynchronous commit draws the next value from it.
_module_unique_count = itertools.count()
# Error text used when several processes run without a distributed client.
_DISTRIBUTED_SYSTEM_MSG = (
    'Please initialize the distributed system via '
    '`jax.distributed.initialize()` '
    'at the start of your program.')
# URL prefixes that `is_remote_storage` treats as cloud storage.
_REMOTE_URL_PREFIXES = ['gs://', 's3://']
# Driver rules for `is_remote_storage`; a `path_regex` of None means every path
# of that driver counts as remote.
_REMOTE_DRIVER_VALIDATIONS = [
    {'driver': driver_name, 'path_regex': None}
    for driver_name in ('gcs', 's3')
]
|
||||
|
||||
class BarrierTimeoutError(Exception):
  """Raised when a stored error reports that the barrier wait timed out."""
|
||||
|
||||
# Troubleshooting hints appended to the message of a barrier-timeout error.
_BARRIER_TIMED_OUT_MSG = '\n'.join((
    'Suggestions for possible fixes:',
    '* Check the logs to see if one or more processes failed.',
    '* Make sure the training and checkpointing endpoints are close '
    'geographically.',
    '* Try increasing the timeout you pass to GlobalAsyncCheckpointManager.',
))

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_remote_storage(tspec: dict[str, Any] | str) -> bool:
  """Detects whether a TensorStore spec points at cloud storage.

  Only the common ways of spelling a remote location are recognized; corner
  cases such as a gcsfuse mount cannot be detected.

  Args:
    tspec: Either a kvstore URL string or a (possibly nested) spec dict.

  Returns:
    True if `tspec` is a string starting with one of `_REMOTE_URL_PREFIXES`,
    or a dict whose driver matches a rule in `_REMOTE_DRIVER_VALIDATIONS`
    (after following a nested 'base' or 'kvstore' entry); False otherwise.
  """
  if isinstance(tspec, str):
    # KvStoreUrl. A literal prefix test: the prefixes are never interpreted as
    # regular-expression syntax, and an empty prefix list matches nothing.
    return tspec.startswith(tuple(_REMOTE_URL_PREFIXES))

  # A nested spec is decided entirely by the first of these keys present.
  for key in ('base', 'kvstore'):
    if key in tspec:
      return is_remote_storage(tspec[key])

  if 'driver' in tspec:
    for rule in _REMOTE_DRIVER_VALIDATIONS:
      if tspec['driver'] != rule['driver']:
        continue
      path_regex = rule['path_regex']
      if path_regex is None:
        # No regex: every path of this driver is remote.
        return True
      # Check if the path matches the regex; a spec without a path cannot
      # match, so it is skipped instead of raising KeyError.
      path = tspec.get('path')
      if path is not None and re.match(path_regex, path):
        return True

  return False
|
||||
|
||||
def _get_key(key: int):
  """Returns the barrier / key-value-store name derived from `key`."""
  return 'tensorstore_checkpoint_' + format(key)
|
||||
|
||||
|
||||
class GlobalAsyncCheckpointManagerBase(util.StrictABC):
  """Abstract interface for checkpointing arrays asynchronously.

  An implementation keeps track of the checkpoint that is currently in
  flight. Suppose a checkpoint is taken on every step: the checkpoint for
  step 1 may still be committing to the storage layer when the model reaches
  step 2, in which case the checkpoint for step 2 has to wait. Holding that
  state in an object is what makes the waiting possible.

  Examples:

  A simplified training loop:

  ```
  # Call this at the start of your program.
  jax.distributed.initialize()

  manager = GlobalAsyncCheckpointManager()

  # Restore a checkpoint if one is available, otherwise initialize the
  # train_state from init_fn().
  train_state = manager.deserialize(shardings, tensorstore_specs)

  while ...:
    if step % num_steps_between_checkpoints == 0:
      manager.serialize(train_state, tensorstore_specs,
                        on_commit_callback=...)
    train_state = train_step(train_state, input)
    # This is a non-blocking call.
    manager.check_for_errors()

  manager.serialize(train_state, tensorstore_specs, on_commit_callback=...)
  # Wait before the end of the program for the checkpoint to finish. This is
  # a blocking call.
  manager.wait_until_finished()
  ```
  """

  @abc.abstractmethod
  def check_for_errors(self):
    """Surfaces any error raised in the child thread.

    Non-blocking; intended to be called from the main thread.
    """

  @abc.abstractmethod
  def wait_until_finished(self):
    """Blocks the caller until serialization has finished."""

  @abc.abstractmethod
  def serialize(self, arrays, tensorstore_specs, *,
                on_commit_callback: Callable[[], None]):
    """Serializes `arrays` to TensorStore using `tensorstore_specs`.

    Args:
      arrays: The arrays to write.
      tensorstore_specs: The TensorStore specs describing where to write.
      on_commit_callback: Zero-argument callable associated with the commit.
    """

  @abc.abstractmethod
  def deserialize(self, shardings: Sequence[sharding.Sharding],
                  tensorstore_specs: Sequence[dict[str, Any]],
                  global_shapes: Sequence[array.Shape] | None = None,
                  dtypes: Sequence[typing.DTypeLike] | None = None):
    """Deserializes arrays from TensorStore.

    Args:
      shardings: One sharding per array to restore.
      tensorstore_specs: The TensorStore specs describing where to read from.
      global_shapes: Optional global shapes of the arrays.
      dtypes: Optional dtypes of the arrays.
    """
|
||||
|
||||
|
||||
class AsyncManager:
  """Waits for checkpoint commit futures on a background thread.

  Commit futures are registered with `_add_futures`; `_start_async_commit`
  then launches a thread that waits for them, synchronizes all processes at a
  barrier (only when more than one process is running) and, on process 0,
  runs the commit callback and records success in the distributed key-value
  store. An exception raised on that thread is stored and re-raised on the
  caller's thread by `check_for_errors` / `wait_until_finished`.
  """

  def __init__(self, timeout_secs=300):
    """Initializes the manager.

    Args:
      timeout_secs: Timeout in seconds; it is converted to milliseconds and
        passed to the barrier wait and to the blocking key-value read.

    Raises:
      ValueError: If more than one process is running and the distributed
        client has not been initialized.
    """
    self._timeout_secs = timeout_secs
    self._timeout_in_ms = self._timeout_secs * 1000

    self._commit_futures = None
    self._thread = None
    self._exception = None
    # Assigned by `_start_async_commit`; declared here so the attribute
    # always exists on a constructed instance.
    self._on_commit_callback = None

    if jax.process_count() > 1 and distributed.global_state.client is None:
      raise ValueError(_DISTRIBUTED_SYSTEM_MSG)
    self._client = distributed.global_state.client
    # Unique id of the most recently started commit; None until the first one.
    self._count: int | None = None

  def __del__(self):
    # `__init__` can raise before `_thread` is assigned, and the finalizer
    # still runs in that case, so the attribute is looked up defensively.
    thread = getattr(self, '_thread', None)
    if thread is not None and thread.is_alive():
      logger.warning('Please add `.wait_until_finished()` in the main thread '
                     'before your program finishes because there is a '
                     'possibility of losing errors raised if the '
                     'this class is deleted before writing is completed.')

  def _thread_func(self):
    """Thread body: waits for the commit, then synchronizes and reports.

    Any exception is stored in `self._exception` instead of propagating, so
    that `check_for_errors` can re-raise it on the caller's thread.
    """
    try:
      current_process = jax.process_index()
      process_count = jax.process_count()
      logger.info('Starting commit to storage layer by process: %s',
                  current_process)
      # Monotonic clock: the elapsed time is unaffected by wall-clock changes.
      thread_start_time = time.monotonic()
      assert self._commit_futures is not None
      for future in self._commit_futures:
        future.result()
      logger.info('Finished committing to storage layer by process: %s',
                  current_process)

      key_for_barrier: str | None = None
      if process_count > 1:
        assert self._client is not None
        assert self._count is not None
        # All processes will wait at the barrier. When all processes are at the
        # barrier, the barrier will be satisfied. If not, then it will timeout.
        key_for_barrier = _get_key(self._count)
        logger.info('Key used for barrier is %s for process %s',
                    key_for_barrier, current_process)
        self._client.wait_at_barrier(key_for_barrier, self._timeout_in_ms)
        logger.info('Finished waiting at barrier for process %s',
                    current_process)

      if current_process == 0:
        if self._on_commit_callback is not None:
          self._on_commit_callback()
          logger.info('on_commit_callback successfully ran!')
        if process_count > 1:
          assert self._client is not None
          assert key_for_barrier is not None
          # `wait_until_finished` blocks on this key.
          self._client.key_value_set(key_for_barrier, _CHECKPOINT_SUCCESS)
          logger.info('Process 0 successfully set key %s in the kv store',
                      key_for_barrier)

      jax.monitoring.record_event_duration_secs(
          '/jax/checkpoint/write/async/thread_duration_sec',
          time.monotonic() - thread_start_time)

    except Exception as e:
      self._exception = e

  def _start_async_commit(self, on_commit_callback):
    """Draws a fresh commit id and starts the background commit thread.

    Args:
      on_commit_callback: Zero-argument callable run by process 0 after the
        commit (and barrier, if any) has completed, or None.
    """
    self._count = next(_module_unique_count)

    self._on_commit_callback = on_commit_callback
    self._thread = threading.Thread(target=self._thread_func)
    self._thread.start()

  def check_for_errors(self):
    """Re-raises, once, an exception stored by the background thread.

    Raises:
      BarrierTimeoutError: If the stored exception is a `JaxRuntimeError`
        whose message reports that the barrier timed out.
      Exception: Any other stored exception, unchanged.
    """
    if self._exception is not None:
      # Clears self._exception so it is only raised once.
      exception = self._exception
      self._exception = None
      if (isinstance(exception, _jax.JaxRuntimeError) and
          'DEADLINE_EXCEEDED: Barrier timed out' in str(exception)):
        # Chain the original error so its traceback is not lost.
        raise BarrierTimeoutError(
            '\n'.join([str(exception), _BARRIER_TIMED_OUT_MSG])) from exception
      raise exception

  def wait_until_finished(self):
    """Joins the commit thread, raises stored errors and awaits success.

    When more than one process is running and a commit has been started, this
    also blocks until the success key of the latest commit can be read from
    the key-value store.
    """
    if self._thread is not None:
      self._thread.join()
      self._thread = None
      logger.info('Thread joined successfully')

    self.check_for_errors()
    logger.info('Error check finished successfully')

    if jax.process_count() > 1 and self._count is not None:
      assert self._client is not None
      # Block until process 0 writes success value to the key value store.
      # If it fails to write it, then `blocking_key_value_get` will time out.
      get_key = _get_key(self._count)
      self._client.blocking_key_value_get(get_key, self._timeout_in_ms)
      logger.info('blocking_key_value_get on key %s was successfully '
                  'completed.', get_key)

  def _add_futures(self, futures: Sequence[ts.Future]):
    """Stores the commit futures the next commit thread will wait on."""
    self._commit_futures = futures
|
||||
|
||||
|
||||
class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
  """Responsible for serializing arrays via TensorStore."""

  def serialize(
      self,
      arrays,
      tensorstore_specs,
      *,
      on_commit_callback: Callable[[], None] | None = None,
      transaction: ts_impl.Transaction | None = None,
  ):
    """Serializes arrays via TensorStore asynchronously.

    TensorStore writes to a storage layer in 2 steps:
    *  Reading/copying from the source after which the source can be modified.
         * Returns a copy future.
    *  Writing/committing to the storage layer.
         * Returns a commit future.

    In asynchronous mode, the serialization waits for the commit future to
    finish in a separate thread allowing other computation to proceed.

    Args:
      arrays: Arrays that should be serialized. May be any pytree whose
        structure matches `tensorstore_specs`.
      tensorstore_specs: TensorStore specs that are used to serialize the
        arrays.
      on_commit_callback: This callback will be executed after all processes
        have finished writing their checkpoints to disk. Filesystems where
        atomic rename operations are supported, you can rename from the
        temporary directory to the final directory. On GCS, you write to the
        final directory directly and in `on_commit_callback` you write a success
        file indicating that the serialization was successful because GCS does
        not support atomic rename operations.
      transaction: Optional TensorStore transaction to use.
    """
    logger.info('Waiting for previous serialization to finish.')
    self.wait_until_finished()

    # Handed to every `async_serialize` call as its `commit_future` argument.
    commit_futures: list[ts_impl.Future] = []

    async def _run_serializer():
      future_writer = jax.tree_util.tree_map(
          lambda arr_inp, tensorstore_spec: ts_impl.async_serialize(
              arr_inp,
              tensorstore_spec,
              commit_future=commit_futures,
              transaction=transaction,
          ),
          arrays,
          tensorstore_specs,
      )
      # Flatten before gathering so that nested containers of arrays work,
      # not only flat sequences.
      return await asyncio.gather(*jax.tree_util.tree_leaves(future_writer))

    asyncio.run(_run_serializer())

    self._add_futures(commit_futures)

    # Starts the background commit thread. This also assigns the commit id
    # that `wait_until_finished` uses, on multi-process runs, to wait for the
    # success key written by process 0.
    self._start_async_commit(on_commit_callback)

  def serialize_with_paths(
      self,
      arrays: Sequence[jax.Array],
      paths: Sequence[str],
      *,
      on_commit_callback: Callable[[], None] | None = None,
      transaction: ts_impl.Transaction | None = None,
  ):
    """Like `serialize`, but builds the TensorStore specs from `paths`.

    Args:
      arrays: Arrays that should be serialized.
      paths: One path per array; each is converted with
        `get_tensorstore_spec`.
      on_commit_callback: Forwarded to `serialize`.
      transaction: Forwarded to `serialize`.

    Returns:
      Whatever `serialize` returns.
    """
    tspecs = jax.tree.map(get_tensorstore_spec, paths)
    return self.serialize(
        arrays,
        tspecs,
        on_commit_callback=on_commit_callback,
        transaction=transaction,
    )

  def deserialize(self, shardings: Sequence[sharding.Sharding | Format],
                  tensorstore_specs: Sequence[dict[str, Any]],
                  global_shapes: Sequence[array.Shape] | None = None,
                  dtypes: Sequence[typing.DTypeLike] | None = None,
                  concurrent_gb: int = 32):
    """Waits for any in-flight serialization, then deserializes arrays.

    Args:
      shardings: One sharding (or format) per array to restore.
      tensorstore_specs: TensorStore specs to read from.
      global_shapes: Optional global shapes of the arrays.
      dtypes: Optional dtypes of the arrays.
      concurrent_gb: Forwarded to `ts_impl._run_deserialization`; presumably
        a limit, in GB, on concurrently read data -- TODO confirm.

    Returns:
      The result of `ts_impl._run_deserialization`.
    """
    self.wait_until_finished()
    return ts_impl._run_deserialization(
        shardings, tensorstore_specs, global_shapes, dtypes, concurrent_gb)

  def deserialize_with_paths(
      self, shardings: Sequence[sharding.Sharding | Format],
      paths: Sequence[str],
      global_shapes: Sequence[array.Shape] | None = None,
      dtypes: Sequence[typing.DTypeLike] | None = None,
      concurrent_gb: int = 32):
    """Like `deserialize`, but builds the TensorStore specs from `paths`.

    Args:
      shardings: One sharding (or format) per array to restore.
      paths: One path per array; each is converted with
        `get_tensorstore_spec`.
      global_shapes: Forwarded to `deserialize`.
      dtypes: Forwarded to `deserialize`.
      concurrent_gb: Forwarded to `deserialize`.

    Returns:
      Whatever `deserialize` returns.
    """
    tspecs = jax.tree.map(get_tensorstore_spec, paths)
    return self.deserialize(shardings, tspecs, global_shapes, dtypes,
                            concurrent_gb)
|
||||
Reference in New Issue
Block a user