# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for synchronization and communication across multiple hosts."""

from __future__ import annotations

from functools import partial, lru_cache
import zlib

import contextlib
from typing import Any
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import core
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src import array
from jax._src import sharding_impls
from jax._src.interpreters import pxla
from jax._src import pjit as pjit_lib
from jax._src import prng
from jax.sharding import PartitionSpec as P
from jax._src import distributed
from jax._src.util import safe_zip
from jax._src import xla_bridge
from jax._src.lib import xla_client
import numpy as np


def _psum(xs: Any) -> Any:
  return jax.tree.map(lambda x: jnp.sum(x, dtype=x.dtype, axis=0), xs)


def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any:
  """Broadcast data from a source host (host 0 by default) to all other hosts.

  Args:
    in_tree: pytree of arrays - each array *must* have the same shape across the
      hosts.
    is_source: optional bool denoting whether the caller is the source. Only
      the source host will contribute the data for the broadcast. If None, then
      host 0 is used.

  Returns:
    A pytree matching in_tree where the leaves now all contain the data from the
    source host.
  """
  if jax.process_count() == 1:
    return jax.tree.map(np.asarray, in_tree)

  if is_source is None:
    is_source = jax.process_index() == 0

  devices: np.ndarray = np.array(
      jax.devices()).reshape(jax.process_count(), jax.local_device_count())
  global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
  pspec = P('processes')

  def pre_jit(x):
    if is_source:
      inp = x
    else:
      inp = np.zeros_like(x)
    inp = np.expand_dims(inp, axis=0)
    return host_local_array_to_global_array(inp, global_mesh, pspec)

  def post_jit(x):
    return jax.device_get(x.addressable_data(0))

  in_tree = jax.tree.map(pre_jit, in_tree)
  with jax.set_mesh(global_mesh):
    out_tree = jax.jit(_psum, out_shardings=P())(in_tree)

  return jax.tree.map(post_jit, out_tree)
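

# Illustrative usage sketch (not part of the library API): making every process
# agree on a value chosen by process 0 via `broadcast_one_to_all`. The function
# below is hypothetical documentation only, is never called at import time, and
# assumes jax.distributed has been initialized in a multi-process setup.
def _example_broadcast_one_to_all():
  # Each process proposes its own seed, but only process 0's value is kept.
  local_seed = np.int64(12345 + jax.process_index())
  # After the broadcast every process holds process 0's seed (12345).
  return broadcast_one_to_all(local_seed)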


# Identity function is at the top level so that `process_allgather` doesn't
# recompile on every invocation.
def _identity_fn(x):
  return x


def _handle_array_process_allgather(inp, tiled):
  if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
    if not tiled:
      raise ValueError(
          'Gathering global non-fully-addressable arrays only supports'
          ' tiled=True')
    if isinstance(inp.sharding, sharding_impls.NamedSharding):
      reps = inp.sharding.update(spec=P())
    else:
      reps = sharding_impls.GSPMDSharding.get_replicated(
          inp.sharding._device_assignment, memory_kind=inp.sharding.memory_kind)
    out = jax.jit(_identity_fn, out_shardings=reps)(inp)
  else:
    # All inputs here will be fully addressable.
    if jax.process_count() == 1:
      out = np.asarray(inp)
      return np.expand_dims(out, axis=0) if not tiled else out

    devices = np.array(jax.devices()).reshape(jax.process_count(),
                                              jax.local_device_count())
    global_mesh = jax.sharding.Mesh(devices, ('processes', 'local_devices'))
    pspec = P('processes')
    s = jax.sharding.NamedSharding(global_mesh, pspec)

    host_np_arr = np.asarray(inp)
    if host_np_arr.ndim == 0 or not tiled:
      host_np_arr = np.expand_dims(host_np_arr, axis=0)

    aval = core.ShapedArray(host_np_arr.shape, host_np_arr.dtype)
    pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
    global_aval = pxla.mesh_local_to_global(
        global_mesh, sharding_impls.get_array_mapping(pspec), aval)

    bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
    global_arr = array.make_array_from_single_device_arrays(
        global_aval.shape, s, bufs)
    with jax.set_mesh(global_mesh):
      out = jax.jit(_identity_fn, out_shardings=P())(global_arr)

  return np.asarray(out.addressable_data(0))


def process_allgather(in_tree: Any, tiled: bool = False) -> Any:
  """Gather data from across processes.

  Args:
    in_tree: pytree of arrays - each array _must_ have the same shape across the
      hosts.
    tiled: Whether to stack or concat the output. Defaults to False, i.e. stack
      into a new positional axis at index 0.

  Returns:
    A pytree of numpy arrays.
      * If the input is a non-fully addressable jax.Array, then the data is
        fully replicated.
      * If the input is a numpy array or a fully addressable jax.Array, then the
        output shape depends on the `tiled` argument: if it is False, the output
        is stacked; otherwise it is concatenated.
      * If the input is a scalar, then the output will be stacked.
  """

  def _pjit(inp):
    return _handle_array_process_allgather(inp, tiled)
  return jax.tree.map(_pjit, in_tree)
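

# Illustrative usage sketch (not part of the library API): gathering a small
# per-process value onto every host with `process_allgather`. Hypothetical,
# never called at import time, and assumes a multi-process setup.
def _example_process_allgather():
  # Some per-process scalar, e.g. the number of examples this host processed.
  local_count = np.int32(128)
  # With tiled=False (the default) the per-process values are stacked along a
  # new leading axis, so `all_counts` has shape (jax.process_count(),).
  all_counts = process_allgather(local_count)
  return int(all_counts.sum())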


def sync_global_devices(name: str):
  """Creates a barrier across all hosts/devices."""
  h = np.uint32(zlib.crc32(name.encode()))
  assert_equal(h, f"sync_global_devices name mismatch ('{name}')")


def assert_equal(in_tree, fail_message: str = ''):
  """Verifies that all the hosts have the same tree of values."""
  def concat_in_tree(x):
    if isinstance(x, array.ArrayImpl) and not x.is_fully_addressable:
      return np.asarray(x.addressable_data(0))
    else:
      x = np.asarray(x)
      if x.ndim == 0:
        x = np.expand_dims(x, axis=0)
      return np.concat([x] * jax.process_count())

  out = process_allgather(in_tree, tiled=True)
  expected_in_tree = jax.tree.map(concat_in_tree, in_tree)
  if not jax.tree.all(
      jax.tree.map(lambda *x: np.all(np.equal(*x)), expected_in_tree, out)):
    raise AssertionError(
        f'{fail_message}. Expected: {out}; got: {in_tree}.')
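

# Illustrative usage sketch (not part of the library API): using `assert_equal`
# and `sync_global_devices` around a checkpoint write. `write_checkpoint_fn` is
# a hypothetical callable; the function is never called at import time.
def _example_checkpoint_barrier(step: int, write_checkpoint_fn):
  # Fail fast if the processes disagree about which step is being saved.
  assert_equal(np.int64(step), fail_message='checkpoint step mismatch')
  if jax.process_index() == 0:
    write_checkpoint_fn(step)
  # Barrier: no process continues until all processes (including the writer)
  # have reached this named sync point.
  sync_global_devices(f'checkpoint_{step}')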


def reached_preemption_sync_point(step_id: int) -> bool:
  """Determine whether all hosts have reached a preemption sync step.

  When any host receives a preemption notice, the notice is propagated to all
  hosts and triggers a synchronization protocol in the background. The
  synchronization protocol calculates the maximum step id across all hosts, and
  uses the next step id (i.e., max + 1) as the safe step to save a checkpoint.
  All hosts should continue training more steps until this method returns True,
  indicating that the `step_id` is equal to the safe step and the hosts should
  start saving a checkpoint.

  To use this API, all hosts must start training from the same step and call it
  at every training step. Example usage:

  ```
  def should_save(step_id: int) -> bool:
    # Should save an on-demand checkpoint for preemption
    if multihost_utils.reached_preemption_sync_point(step_id):
      return True

    # Should save a regular checkpoint
    return step_id - last_saved_checkpoint_step >= save_interval_steps
  ```

  The preemption notice is provided by the cluster scheduler to notify the
  application in advance before it gets evicted. By default, we use SIGTERM as
  the signal for the preemption notice.

  TODO(b/230630494): Add instructions for customized preemption notice.

  Returns:
    A boolean indicating whether all hosts have reached a synchronization step
    after some hosts are preempted.

  Raises:
    RuntimeError: if the preemption sync manager has not been initialized.
  """
  if distributed.global_state.client is None:
    return False
  sync_manager = distributed.global_state.preemption_sync_manager
  if sync_manager is None:
    raise RuntimeError(
        "Preemption sync manager has not been initialized. Make sure the"
        " 'jax_enable_preemption_service' config is enabled."
    )
  return sync_manager.reached_sync_point(step_id)
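

# Illustrative usage sketch (not part of the library API): calling
# `reached_preemption_sync_point` at every step of a training loop, as the
# docstring above requires. `train_step`, `save_checkpoint`, and `num_steps`
# are hypothetical; the function is never called at import time.
def _example_preemption_aware_loop(state, train_step, save_checkpoint,
                                   num_steps: int):
  for step in range(num_steps):
    state = train_step(state)
    # Every process calls this at every step; once a preemption notice has been
    # received it returns True on all processes at the same (safe) step.
    if reached_preemption_sync_point(step):
      save_checkpoint(state, step)
      break
  return state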


@lru_cache
def _flatten_pspecs(name, in_tree, pspecs_thunk):
  return pjit_lib.flatten_axis_resources(
      name, in_tree, pspecs_thunk(), tupled_args=True)


@lru_cache
def _local_to_global_aval(local_aval, mesh, pspec):
  pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
  return pxla.mesh_local_to_global(
      mesh, sharding_impls.get_array_mapping(pspec), local_aval)


@lru_cache
def _global_to_local_aval(global_aval, mesh, pspec):
  pspec = sharding_impls.prepare_axis_resources(pspec, "pspec to array_mapping")
  return pxla.mesh_global_to_local(
      mesh, sharding_impls.get_array_mapping(pspec), global_aval)


def host_local_array_to_global_array_impl(
    arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
  if pspec is None:
    raise ValueError(
        '`None` is not a valid input to the pspecs argument. Please use '
        'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
  # If the Array is not fully addressable i.e. not host local, return it.
  if isinstance(arr, array.ArrayImpl) and not arr.is_fully_addressable:
    return arr
  if not hasattr(arr, 'shape'):
    arr = np.array(arr)
  if arr.dtype == dtypes.float0:
    arr = np.zeros(arr.shape, dtype=np.dtype(bool))
  dtype = arr.dtype
  if is_prng_key_array := isinstance(arr, prng.PRNGKeyArray):
    arr = arr._base_array

  local_sharding = jax.sharding.NamedSharding(global_mesh.local_mesh, pspec)

  # If the input is a concrete jax.Array and the input array sharding
  # matches the `local_sharding`, then there's no need to reshard and create
  # copies.
  if (isinstance(arr, array.ArrayImpl) and
      arr.sharding.is_equivalent_to(local_sharding, arr.ndim)):
    arrays = [x.data for x in arr.addressable_shards]
  else:
    arr = dtypes.canonicalize_value(arr)
    arrays = [
        arr[i] for i in local_sharding.devices_indices_map(arr.shape).values()
    ]

  global_aval = _local_to_global_aval(
      core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)

  out = pxla.batched_device_put(
      global_aval, jax.sharding.NamedSharding(global_mesh, pspec),
      arrays, list(global_mesh.local_mesh.devices.flat))
  if is_prng_key_array:
    return prng.PRNGKeyArray(dtype._impl, out)
  return out


def host_local_array_to_global_array(
    local_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
  r"""Converts a host local value to a globally sharded jax.Array.

  This function takes host-local data (which might be different across hosts),
  and populates a global array with this data, where each device on each host
  gets the appropriate slice of the data according to the sharding defined by
  global_mesh/pspecs.

  For example:

  >>> global_mesh = jax.sharding.Mesh(jax.devices(), 'x')
  >>> pspecs = jax.sharding.PartitionSpec('x')
  >>> host_id = jax.process_index()
  >>> arr = host_local_array_to_global_array(np.arange(4) * host_id, global_mesh, pspecs)  # NB: assumes jax.local_device_count() divides 4.  # doctest: +SKIP

  The resulting array will have the shape (4 * num_processes) and will
  have distributed value of: (0, 1, 2, 3, 0, 2, 4, 6, 0, 3, 6, 9, ... ),
  where each slice np.arange(4) * host_id will be partitioned across the
  corresponding host's devices.

  Similarly:

  >>> mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(jax.process_count(), jax.local_device_count()), ['host', 'dev'])
  >>> pspecs = jax.sharding.PartitionSpec('host')
  >>> host_id = jax.process_index()
  >>> arr = host_local_array_to_global_array(np.arange(4) * host_id, mesh, pspecs)  # doctest: +SKIP

  will create the same distributed value (0, 1, 2, 3, 0, 2, 4, 6, ...),
  however each slice np.arange(4) * host_id will be *replicated* across the
  corresponding host's devices.

  On the other hand, if pspecs = PartitionSpec(), which means
  replication across all axes, then this snippet:

  >>> pspecs = jax.sharding.PartitionSpec()
  >>> arr = host_local_array_to_global_array(np.arange(4), mesh, pspecs)  # doctest: +SKIP

  will have the shape (4,) and the value (0, 1, 2, 3) will be replicated
  across all hosts and devices.

  It is undefined behavior to pass non-identical local_inputs when the pspec
  indicates data replication.

  You can use this function to transition to jax.Array. Using jax.Array with
  pjit has the same semantics as using GDA with pjit, i.e. all jax.Array
  inputs to pjit should be globally shaped.

  If you are currently passing host local values to pjit, you can use this
  function to convert your host local values to global Arrays and then pass
  those to pjit.

  Example usage:

  >>> from jax.experimental import multihost_utils # doctest: +SKIP
  >>>
  >>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
  >>>
  >>> with mesh: # doctest: +SKIP
  ...   global_out = pjitted_fun(global_inputs) # doctest: +SKIP
  >>>
  >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP

  Please note this function requires the global mesh to be a contiguous mesh,
  meaning that the devices that belong to each host should form a subcube in
  this mesh. To move local data to a global array with a non-contiguous mesh,
  use jax.make_array_from_callback or jax.make_array_from_single_device_arrays
  instead.

  Args:
    local_inputs: A Pytree of host local values.
    global_mesh: A jax.sharding.Mesh object. The mesh must be a contiguous mesh,
      that is all hosts' devices must form a subcube in this mesh.
    pspecs: A Pytree of jax.sharding.PartitionSpec's.

  Returns:
    A pytree of global arrays.
  """
  flat_inps, in_tree = tree_flatten(local_inputs)
  in_pspecs = _flatten_pspecs('input pspecs', in_tree,
                              pjit_lib.hashable_pytree(pspecs))
  out_flat = [
      host_local_array_to_global_array_p.bind(inp, global_mesh=global_mesh,
                                              pspec=in_spec)
      for inp, in_spec in safe_zip(flat_inps, in_pspecs)
  ]
  return tree_unflatten(in_tree, out_flat)
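

# Illustrative usage sketch (not part of the library API): building a global
# batch from per-host data with `host_local_array_to_global_array`. The mesh
# axis names and shapes are hypothetical; the function is never called at
# import time and assumes a multi-process setup.
def _example_build_global_batch():
  devices = np.array(jax.devices()).reshape(
      jax.process_count(), jax.local_device_count())
  mesh = jax.sharding.Mesh(devices, ('hosts', 'devices'))
  # Each host loads only its own slice of the batch.
  per_host_batch = np.zeros((8, 128), dtype=np.float32)
  # The per-host slices are laid out along the leading axis of a single global
  # array of shape (8 * jax.process_count(), 128), sharded over 'hosts'.
  return host_local_array_to_global_array(per_host_batch, mesh, P('hosts'))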


host_local_array_to_global_array_p = core.Primitive('host_local_array_to_global_array')
host_local_array_to_global_array_p.def_impl(host_local_array_to_global_array_impl)

def ltg_abstract_eval(arr, *, global_mesh, pspec):
  return _local_to_global_aval(
      core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
host_local_array_to_global_array_p.def_abstract_eval(ltg_abstract_eval)

ad.deflinear2(host_local_array_to_global_array_p,
              lambda ct, _, **params: (
                  host_local_array_to_global_array_p.bind(ct, **params),))

def ltg_batcher(insert_axis, axis_data, vals_in, dims_in, global_mesh, pspec):
  del insert_axis
  x, = vals_in
  d, = dims_in
  new_parts = None if axis_data.spmd_name is None else axis_data.spmd_name
  new_pspec = list(pspec)
  if d is not None:
    new_pspec.insert(d, new_parts)
  new_pspec = P(*new_pspec)
  y = host_local_array_to_global_array_p.bind(
      x, global_mesh=global_mesh, pspec=new_pspec)
  return y, d
batching.fancy_primitive_batchers[host_local_array_to_global_array_p] = partial(
    ltg_batcher, False)

def _ltg_lowering(ctx, x, *, global_mesh, pspec):
  return [x]
mlir.register_lowering(host_local_array_to_global_array_p, _ltg_lowering)


def global_array_to_host_local_array_impl(
    arr: Any, *, global_mesh: jax.sharding.Mesh, pspec: Any):
  if pspec is None:
    raise ValueError(
        '`None` is not a valid input to the pspecs argument. Please use '
        'jax.sharding.PartitionSpec() if you wanted to replicate your input.')
  # If the Array is already fully addressable i.e. host local, return it.
  if isinstance(arr, array.ArrayImpl) and arr.is_fully_addressable:
    return arr
  if not hasattr(arr, 'shape'):
    arr = np.array(arr)
  if arr.dtype == dtypes.float0:
    arr = np.zeros(arr.shape, dtype=np.dtype(bool))
  dtype = arr.dtype
  if is_prng_key_array := isinstance(arr, prng.PRNGKeyArray):
    arr = arr._base_array

  global_sharding = jax.sharding.NamedSharding(global_mesh, pspec)
  local_sharding = jax.sharding.NamedSharding(global_mesh.local_mesh, pspec)
  local_aval = _global_to_local_aval(
      core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)

  if isinstance(arr, array.ArrayImpl):
    if arr.sharding.is_equivalent_to(global_sharding, arr.ndim):
      arrays = arr._arrays
    else:
      resharded_array = jax.device_put(arr, global_sharding)
      arrays = resharded_array._arrays
    out = array.ArrayImpl(local_aval, local_sharding, arrays, committed=True)
    if is_prng_key_array:
      return prng.PRNGKeyArray(dtype._impl, out)
    return out
  else:
    # numpy array can show up here during AD.
    arr = dtypes.canonicalize_value(arr)
    arrays = [
        arr[i] for i in local_sharding.devices_indices_map(arr.shape).values()
    ]
    return pxla.batched_device_put(
        local_aval, local_sharding, arrays,
        list(global_mesh.local_mesh.devices.flat))


def global_array_to_host_local_array(
    global_inputs: Any, global_mesh: jax.sharding.Mesh, pspecs: Any):
  r"""Converts a global `jax.Array` to a host local `jax.Array`.

  You can use this function to transition to `jax.Array`. Using `jax.Array`
  with pjit has the same semantics as using GDA with pjit, i.e. all `jax.Array`
  inputs to pjit should be globally shaped and the output from pjit will also
  be globally shaped jax.Array's.

  You can use this function to convert the globally shaped `jax.Array` output
  from pjit to host local values again so that the transition to jax.Array can
  be a mechanical change.

  Example usage:

  >>> from jax.experimental import multihost_utils # doctest: +SKIP
  >>>
  >>> global_inputs = multihost_utils.host_local_array_to_global_array(host_local_inputs, global_mesh, in_pspecs) # doctest: +SKIP
  >>>
  >>> with mesh: # doctest: +SKIP
  ...   global_out = pjitted_fun(global_inputs) # doctest: +SKIP
  >>>
  >>> host_local_output = multihost_utils.global_array_to_host_local_array(global_out, mesh, out_pspecs) # doctest: +SKIP

  Args:
    global_inputs: A Pytree of global jax.Array's.
    global_mesh: A :class:`jax.sharding.Mesh` object. The mesh must be
      contiguous, meaning all of each host's local devices must form a subcube.
    pspecs: A Pytree of :class:`jax.sharding.PartitionSpec` objects.

  Returns:
    A Pytree of host local arrays.
  """
  flat_inps, out_tree = tree_flatten(global_inputs)
  out_pspecs = _flatten_pspecs('output pspecs', out_tree,
                               pjit_lib.hashable_pytree(pspecs))
  out_flat = [
      global_array_to_host_local_array_p.bind(inp, global_mesh=global_mesh,
                                              pspec=o)
      for inp, o in safe_zip(flat_inps, out_pspecs)
  ]
  return tree_unflatten(out_tree, out_flat)
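

# Illustrative usage sketch (not part of the library API): the host-local ->
# global -> host-local round trip described in the docstrings above. `step_fn`
# and the input pytree are hypothetical; the function is never called at import
# time and assumes a multi-process setup.
def _example_round_trip(step_fn, host_local_batch):
  devices = np.array(jax.devices()).reshape(
      jax.process_count(), jax.local_device_count())
  mesh = jax.sharding.Mesh(devices, ('hosts', 'devices'))
  pspec = P('hosts')
  # Promote the per-host inputs to a globally shaped jax.Array ...
  global_batch = host_local_array_to_global_array(host_local_batch, mesh, pspec)
  with jax.set_mesh(mesh):
    global_out = jax.jit(step_fn)(global_batch)
  # ... and bring the globally shaped output back to host-local form.
  return global_array_to_host_local_array(global_out, mesh, pspec)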


global_array_to_host_local_array_p = core.Primitive('global_array_to_host_local_array')
global_array_to_host_local_array_p.def_impl(global_array_to_host_local_array_impl)

def gtl_abstract_eval(arr, *, global_mesh, pspec):
  return _global_to_local_aval(
      core.ShapedArray(arr.shape, arr.dtype), global_mesh, pspec)
global_array_to_host_local_array_p.def_abstract_eval(gtl_abstract_eval)

ad.deflinear2(global_array_to_host_local_array_p,
              lambda ct, _, **params: (
                  global_array_to_host_local_array_p.bind(ct, **params),))
batching.defvectorized(global_array_to_host_local_array_p)

def _gtl_lowering(ctx, x, *, global_mesh, pspec):
  return [x]
mlir.register_lowering(global_array_to_host_local_array_p, _gtl_lowering)


def _live_devices(client, devices: list[xla_client.Device]) -> dict[xla_client.Device, int]:
  """Returns the subset of the provided devices that are live and healthy."""
  process_ids = {d.process_index for d in devices}
  if xla_bridge.process_index() not in process_ids:
    # A process can only participate in a live_devices call if it hosts some of
    # the provided devices.
    raise ValueError('Provided devices do not have any local devices.')

  live_process_ids = client.get_live_nodes(list(process_ids))
  return {
      d: live_process_ids[d.process_index]
      for d in devices
      if d.process_index in live_process_ids
  }


class _LiveDevices:
  """A context manager for atomically running code on the set of live devices.

  THIS API IS UNDER ACTIVE DEVELOPMENT AND IS NOT STABLE.

  # Overview

  `live_devices` is a low-level primitive that can be used to make
  multi-controller JAX programs fault tolerant. A multi-controller JAX program
  runs across many devices, and the machines that host these devices might
  fail. `live_devices` is a context manager that yields the current set of
  healthy devices, allowing you to run JAX code on the healthy devices while
  ignoring the failed ones.

  Concretely, `live_devices` is a context manager. You provide it the set of
  devices you are interested in, and it yields the subset of these devices that
  are live. In the body of the `with` statement, you can execute arbitrary JAX
  code using the set of live devices.

  # Example Usage

      try:
        with jax.live_devices(jax.devices()) as devices:
          # Run JAX code here with devices.
          pass
      except:
        # A device died while executing the with statement above.
        pass
      else:
        # The with statement executed successfully.
        pass

  # Barrier Semantics

  It's important that every process agrees on which devices are live, to keep
  the processes' behavior from diverging. For example, imagine a set of
  processes trying to run an AllGather, but they all disagree on which devices
  should be participating in the AllGather. This is buggy.

  To ensure that every process agrees on the set of live devices, the
  `live_devices` context manager has barrier-like semantics. Consider an
  invocation `with live_devices(devices)` where `devices` includes devices
  across a set of processes P. The invocation acts as a barrier, waiting for
  every process in P to call `with live_devices(devices)`. Afterwards,
  `live_devices` returns the same set of live devices `A` to all the processes
  in P. This ensures that every process agrees on the set of live devices.

  `live_devices` does not actually act as a barrier for *every* process in P
  because some processes in P might have failed. Instead, the `live_devices`
  function waits only for the processes with a device in the returned set of
  live devices A.

  # An Example

  Imagine we have four processes, each with two devices:

      Process A: Devices 1 and 2
      Process B: Devices 3 and 4
      Process C: Devices 5 and 6
      Process D: Devices 7 and 8

  Further imagine that process D fails and that every process calls `with
  live_devices(jax.devices())`. The invocation returns devices 1, 2, 3, 4, 5,
  and 6. Because these devices are hosted by processes A, B, and C, the call to
  `live_devices` acts as a barrier across processes A, B, and C. Process D,
  which failed, is ignored.

  # Atomicity

  `live_devices` also provides the following transaction-like atomicity
  property. When a process exits the body of a `with jax.live_devices(...) as
  devices:` block, there are two possibilities:

  1. Every process in `devices` successfully executed all the code in the block
     without any exceptions being raised, or
  2. the code in the block did not execute successfully on every process in
     `devices`, and every process raises an exception.

  Consider the following code.

      try:
        with jax.live_devices(...) as devices:
          pass
      except:
        pass  # A
      else:
        pass  # B

  The atomicity property says that either every process with devices in
  `devices` will enter the except branch (A) or every process with devices in
  `devices` will enter the else branch (B). It is impossible for some processes
  to enter A and others to enter B.

  TODO: mwhittaker - Link to formal live devices semantics.

  Args:
    devices: A list of devices. The provided devices must include at least one
      local device.

  Returns:
    The subset of the provided devices that are live and healthy.

  Raises:
    RuntimeError: If the distributed runtime was not initialized.
    ValueError: If no local devices are provided.
  """

  def __init__(self):
    self.devices = None

  @contextlib.contextmanager
  def __call__(self, devices):
    client = distributed.global_state.client
    if client is None:
      raise RuntimeError('Distributed JAX not initialized.')

    if not devices:
      # TODO(mwhittaker): Make devices optional. If it's not provided, use
      # jax.devices() as a default.
      raise ValueError('No devices provided.')

    if self.devices is None:
      self.devices = _live_devices(client, devices)
    exception = None
    try:
      alive = list(self.devices.keys())
      alive.sort(key=lambda d: d.id)
      yield alive
    except Exception as e:
      exception = e
    finally:
      old_devices = self.devices
      new_devices = _live_devices(client, devices)
      self.devices = new_devices
      if exception:
        raise exception
      if not old_devices.items() <= new_devices.items():
        raise ValueError(f'{old_devices} is not a subset of {new_devices}')


live_devices = _LiveDevices()
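

# Illustrative usage sketch (not part of the library API): running a step on
# whichever devices are currently healthy via the `live_devices` context
# manager defined above. `run_step_on` is hypothetical; the function is never
# called at import time and requires the distributed runtime to be initialized.
def _example_live_devices(run_step_on):
  try:
    with live_devices(jax.devices()) as devices:
      # `devices` is the same sorted list of healthy devices on every
      # participating process.
      result = run_step_on(devices)
  except Exception:
    # A device (or its process) failed while the block was executing; every
    # surviving process takes this branch.
    return None
  else:
    # The block executed successfully on every participating process.
    return result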