768 lines
30 KiB
Python
768 lines
30 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Primitive dispatch and jit dispatch.
|
|
from __future__ import annotations
|
|
|
|
import atexit
|
|
from collections.abc import Sequence
|
|
import dataclasses
|
|
from functools import partial
|
|
import logging
|
|
import threading
|
|
import time
|
|
from typing import Any
|
|
|
|
from jax._src import api
|
|
from jax._src import array
|
|
from jax._src import basearray
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src import dtypes
|
|
from jax._src import literals
|
|
from jax._src import pjit
|
|
from jax._src import traceback_util
|
|
from jax._src import util
|
|
|
|
from jax._src import xla_bridge
|
|
from jax._src.abstract_arrays import array_types
|
|
from jax._src.interpreters import ad
|
|
from jax._src.interpreters import batching
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.interpreters import partial_eval
|
|
from jax._src.interpreters import pxla
|
|
from jax._src.api_util import InternalFloatingPointError
|
|
from jax._src.layout import Layout, Format
|
|
from jax._src.lib import xla_client as xc
|
|
from jax._src.mesh import AbstractMesh, Mesh
|
|
from jax._src.monitoring import record_scalar, record_event_duration_secs, record_event_time_span
|
|
from jax._src.partition_spec import PartitionSpec
|
|
from jax._src.sharding import Sharding
|
|
from jax._src.sharding_impls import (
|
|
NamedSharding, make_single_device_sharding, GSPMDSharding)
|
|
from jax._src.stages import SourceInfo
|
|
import numpy as np
|
|
from jax._src.lib import jaxlib_extension_version
|
|
|
|
|
|
# Monitoring event names used to record the durations of compilation phases.
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"

# Exclude this file's frames from user-facing tracebacks.
traceback_util.register_exclusion(__file__)

xe = xc._xla

# Convenience aliases for commonly-used XLA client types.
Backend = xe.Client
Device = xc.Device
ArrayCopySemantics = xc.ArrayCopySemantics

CompileOptions = xc.CompileOptions

# Shadow the builtins with length-checking variants; the raw builtins remain
# available under the `unsafe_*` names.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

logger = logging.getLogger(__name__)

# This flag is set on exit; no logging should be attempted
_on_exit = False
|
|
|
|
### op-by-op execution
|
|
|
|
def apply_primitive(prim, *args, **params):
  """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
  fun = xla_primitive_callable(prim, **params)
  # TODO(yashkatariya): Investigate adding is_primitive to jit and never
  # triggering the disable jit path instead of messing around with it here.
  # Temporarily re-enable jit so the cached jitted callable actually compiles
  # and executes, then restore the user's disable_jit setting.
  prev = config.disable_jit.swap_local(False)
  try:
    outs = fun(*args)
  finally:
    config.disable_jit.set_local(prev)
  return outs
|
|
|
|
# TODO(necula): this cache will contain strong references to all
# Jaxprs in `params` (for higher-order primitives).
# This is not immediately fixable by using
# util.multi_weakref_lru_cache, because the `params` (including the Jaxpr)
# are closed over in the `prim_fun` lambda. Leaving this fix for a later PR.
@util.cache()
def xla_primitive_callable(prim: core.Primitive, **params):
  """Returns a jitted callable that binds `prim` with `params` baked in.

  Results are cached on (prim, params) via `util.cache`, so repeated op-by-op
  executions of the same primitive reuse a single compiled computation.
  """
  util.test_event("xla_primitive_callable_cache_miss")
  def prim_fun(*args):
    # Disable eager constant folding so the bind actually lowers and executes.
    with config.eager_constant_folding(False):
      return prim.bind(*args, **params)
  # Give the wrapper the primitive's name so traces and profiles are legible.
  prim_fun.__name__ = prim.name
  prim_fun.__qualname__ = prim.name
  prim_fun._apply_primitive = True  # pyrefly: ignore[missing-attribute]
  return api.jit(prim_fun)
|
|
|
|
|
|
def simple_impl(prim):
  # Register op-by-op XLA execution as `prim`'s default impl rule.
  prim.def_impl(partial(apply_primitive, prim))
|
|
|
|
# Opaque runtime token; its concrete type is runtime-dependent, hence `Any`.
RuntimeToken = Any

class RuntimeTokenSet(threading.local):
  """See docstring for effects.py module for the calling convention for tokens."""

  # For each ordered effect, the token returned by the last dispatched
  # computation, sharded over the devices in that computation.
  current_tokens: dict[core.Effect, core.Token]

  # For each device, the runtime token returned by the last dispatched
  # computation on that device.
  output_runtime_tokens: dict[Device, RuntimeToken]

  def __init__(self):
    self.current_tokens = {}
    self.output_runtime_tokens = {}

  def get_token_input(
      self, eff: core.Effect, devices: list[Device]
  ) -> core.Token:
    """Returns the input token for `eff`, sharded over `devices`."""
    # A never-dispatched effect gets a zero-length placeholder value that is
    # sharded below; otherwise the previously stored token is reused.
    tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))

    if isinstance(tok, core.Token):
      # The order of devices may change, so we need to reshard if necessary.
      # TODO(yueshengys): This might still be buggy in a multi-process SPMD
      # scenario. Revise the logic later. A distributed shutdown barrier inside
      # the XLA program may be needed.
      return api.device_put(
          tok, NamedSharding(Mesh(devices, 'x'), PartitionSpec('x')))

    # We only use replicated sharding for the first time when the token for the
    # order effect hasn't been created.
    s = GSPMDSharding.get_replicated(devices)
    sharded_tok = core.Token(
        pxla.shard_args(
            [s], [None], [xc.ArrayCopySemantics.REUSE_INPUT], [tok]
        )[0]
    )
    self.current_tokens[eff] = sharded_tok
    return sharded_tok

  def set_token_result(self, eff: core.Effect, token: core.Token):
    # Record the token produced by the latest computation for `eff`.
    self.current_tokens[eff] = token

  def set_output_runtime_token(self, device: Device, token: RuntimeToken):
    # We're free to clobber the previous output token because on each
    # device we have a total ordering of computations. Only the token
    # from the latest computation matters.
    self.output_runtime_tokens[device] = token

  def clear(self):
    # Drop all stored tokens (used after blocking and when resetting state).
    self.current_tokens = {}
    self.output_runtime_tokens = {}

  def block_until_ready(self):
    # Wait for all outstanding effectful computations, then reset the state.
    for token in self.current_tokens.values():
      token.block_until_ready()
    for token in self.output_runtime_tokens.values():
      token.block_until_ready()
    self.clear()

# Thread-local singleton holding the runtime-token state for this process.
runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()

@atexit.register
def wait_for_tokens():
  # Drain outstanding effects at interpreter shutdown so effectful
  # computations complete before the process exits.
  runtime_tokens.block_until_ready()
|
|
|
|
|
|
class LogElapsedTimeContextManager:
  """Context manager that logs, and optionally records, elapsed wall time.

  On entry, it records the start time (and a scalar metric when `event` is
  set). On exit, it logs `fmt.format(fun_name=..., elapsed_time=...)` at
  WARNING level when `jax_log_compiles` is enabled (DEBUG otherwise), and
  records duration and time-span metrics for `event`.
  """
  __slots__ = ['fmt', 'fun_name', 'event', 'start_time']

  def __init__(self, fmt: str, fun_name: str, event: str | None = None):
    self.fmt = fmt
    self.fun_name = fun_name
    self.event = event

  def __enter__(self):
    self.start_time = time.time()
    if self.event is not None:
      record_scalar(
          self.event, self.start_time, fun_name=self.fun_name
      )

  def __exit__(self, exc_type, exc_value, traceback):
    # Logging machinery may already be torn down during interpreter exit.
    if _on_exit:
      return

    end_time = time.time()
    elapsed_time = end_time - self.start_time
    log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
    if logger.isEnabledFor(log_priority):
      logger.log(log_priority, self.fmt.format(
          fun_name=self.fun_name, elapsed_time=elapsed_time))
    if self.event is not None:
      record_event_duration_secs(
          self.event, elapsed_time, fun_name=self.fun_name
      )
      record_event_time_span(
          self.event, self.start_time, end_time, fun_name=self.fun_name
      )

# Conventional lowercase alias used at call sites.
log_elapsed_time = LogElapsedTimeContextManager
|
|
|
|
|
|
def should_tuple_args(num_args: int, platform: str) -> bool:
  """Whether XLA inputs should be wrapped in a tuple on this platform.

  CPU and GPU do not need tuples as they use host-side data structures that
  do not have small bounds; TPU only needs a tuple for very long argument
  lists.
  """
  return platform == "tpu" and num_args > 2000
|
|
|
|
def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool:
  """Whether there is a primitive given by user anywhere inside a Jaxpr."""
  # Check this jaxpr's own equations first (substring match on the name),
  # then recurse into any nested sub-jaxprs.
  if any(prim_name in eqn.primitive.name for eqn in jaxpr.eqns):
    return True
  return any(jaxpr_has_primitive(sub, prim_name)
             for sub in core.subjaxprs(jaxpr))
|
|
|
|
|
|
# Use this registry with caution. It will void the guarantee that lowering to
# stablehlo is oblivious of physical devices.
prim_requires_devices_during_lowering: set[core.Primitive] = set()

@util.weakref_lru_cache
def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool:
  """Whether `jaxpr` (or any sub-jaxpr) uses a primitive registered in
  `prim_requires_devices_during_lowering`."""
  for eqn in jaxpr.eqns:
    if eqn.primitive in prim_requires_devices_during_lowering:
      return True
  for subjaxpr in core.subjaxprs(jaxpr):
    if jaxpr_has_prim_requiring_devices(subjaxpr):
      return True
  return False
|
|
|
|
|
|
@util.weakref_lru_cache
def get_intermediate_shardings(
    jaxpr: core.Jaxpr) -> Sequence[tuple[Sharding, SourceInfo]]:
  """Collects concrete shardings appearing inside `jaxpr`, recursively.

  Gathers shardings from sharding_constraint equations, nested jit
  in/out_shardings, shard_map specs over concrete meshes, and device_put
  destinations that carry a memory kind. Shardings over abstract meshes are
  skipped, since they are not tied to physical devices.
  """
  from jax._src import shard_map  # pytype: disable=import-error

  out = []
  for eqn in jaxpr.eqns:
    if eqn.primitive is pjit.sharding_constraint_p:
      s = eqn.params['sharding']
      # Abstract-mesh shardings carry no physical device assignment; skip.
      if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
        continue
      source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
      out.append((s, source_info))
    elif eqn.primitive is pjit.jit_p:
      source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
      out.extend((i, source_info) for i in eqn.params['in_shardings'])
      out.extend((o, source_info) for o in eqn.params['out_shardings'])
    elif eqn.primitive is shard_map.shard_map_p:
      mesh = eqn.params['mesh']
      if isinstance(mesh, AbstractMesh):
        continue
      source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
      out.extend((NamedSharding(mesh, spec), source_info)
                 for spec in [*eqn.params['in_specs'], *eqn.params['out_specs']])
    elif eqn.primitive is device_put_p:
      source_info = SourceInfo(eqn.source_info, eqn.primitive.name)
      # Only device_puts that target a specific memory kind are relevant here.
      out.extend((s, source_info) for s in eqn.params['devices']
                 if isinstance(s, Sharding) and s.memory_kind is not None)
  for subjaxpr in core.subjaxprs(jaxpr):
    out.extend(get_intermediate_shardings(subjaxpr))
  return out
|
|
|
|
|
|
def check_arg(arg: Any):
  """Raises TypeError unless `arg` is a value JAX can handle."""
  if core.valid_jaxtype(arg):
    return
  raise TypeError(
      f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
|
|
|
|
|
|
def needs_check_special() -> bool:
  """True when nan/inf debug checking (debug_infs or debug_nans) is on."""
  return bool(config.debug_infs.value or config.debug_nans.value)
|
|
|
|
def check_special(name: str, bufs: Sequence[basearray.Array]) -> None:
  """Checks every buffer for nans/infs when debug checking is enabled."""
  if not needs_check_special():
    return
  for buf in bufs:
    _check_special(name, buf.dtype, buf)
|
|
|
|
|
|
def check_special_array(name: str, arr: array.ArrayImpl) -> array.ArrayImpl:
  """Checks `arr`'s shards for nans/infs (when enabled) and returns `arr`."""
  if needs_check_special():
    # Only inexact dtypes can hold nan/inf; skip the check otherwise.
    if dtypes.issubdtype(arr.dtype, np.inexact):
      for buf in arr._arrays:
        _check_special(name, buf.dtype, buf)
  return arr
|
|
|
|
|
|
def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
|
|
if dtypes.issubdtype(dtype, np.inexact):
|
|
if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
|
|
raise InternalFloatingPointError(name, "nan")
|
|
if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):
|
|
raise InternalFloatingPointError(name, "inf")
|
|
|
|
# Identity function used as a jit target: the `out_shardings`/donation
# arguments on the jit call perform the actual resharding.
def _device_put_reshard(x): return x
|
|
|
|
|
|
@util.cache(max_size=2048, trace_context_in_key=False)
|
|
def _cached_logical_device_ids(
|
|
inp_device_list: xc.DeviceList,
|
|
target_device_list: xc.DeviceList
|
|
) -> tuple[int, ...]:
|
|
device_to_index = {d: i for i, d in enumerate(target_device_list)}
|
|
return tuple(device_to_index[d] for d in inp_device_list)
|
|
|
|
|
|
def _different_device_order_reshard(
    x: array.ArrayImpl, target_sharding: NamedSharding, copy: ArrayCopySemantics
) -> array.ArrayImpl:
  """Reshards `x` onto `target_sharding` when the device sets match but the
  device order may differ.

  Callers establish that `x.sharding` and `target_sharding` span the same
  device set. Shards are first reordered (no data movement across the same
  device), then a jit with `out_shardings` performs the final reshard.
  """
  x._check_if_deleted()
  inp_sharding = x.sharding
  assert isinstance(inp_sharding, NamedSharding)

  inp_device_list = inp_sharding._internal_device_list
  target_device_list = target_sharding._internal_device_list

  donate_argnums = 0 if copy == ArrayCopySemantics.DONATE_INPUT else None
  # Same device order: a plain jit with out_shardings does the reshard.
  if inp_device_list == target_device_list:
    return api.jit(_device_put_reshard, out_shardings=target_sharding,
                   donate_argnums=donate_argnums)(x)

  if inp_sharding.is_fully_replicated:
    # Every device holds the full value, so no logical reordering is needed.
    logical_device_ids = None
  else:
    logical_device_ids = _cached_logical_device_ids(
        inp_device_list, target_device_list,
    )

  # Build a sharding equivalent to the input's, but expressed over the
  # target's devices, so shards can be reordered in place first.
  new_mesh = Mesh(
      target_sharding.mesh.devices.reshape(inp_sharding.mesh.axis_sizes),
      inp_sharding.mesh.axis_names)
  new_s = NamedSharding(
      new_mesh, inp_sharding.spec, memory_kind=target_sharding.memory_kind,
      _logical_device_ids=logical_device_ids)
  new_x = xc.reorder_shards(x, new_s, ArrayCopySemantics.REUSE_INPUT)
  return api.jit(_device_put_reshard, out_shardings=target_sharding,
                 donate_argnums=donate_argnums)(new_x)
|
|
|
|
|
|
@util.cache(max_size=2048, trace_context_in_key=False)
|
|
def _is_supported_cross_host_transfer(ndim, src_sharding, dst_sharding):
|
|
"""Returns True if src->dst is a supported cross-host transfer."""
|
|
if (src_sharding._internal_device_list.device_kind !=
|
|
dst_sharding._internal_device_list.device_kind):
|
|
return False
|
|
if (src_sharding._to_xla_hlo_sharding(ndim) !=
|
|
dst_sharding._to_xla_hlo_sharding(ndim)):
|
|
return False
|
|
# This check excludes the case where the source and destination shardings
|
|
# have the same process index sets but there are shards that require
|
|
# cross-host transfers. This case is supportable but expensive to check for.
|
|
different_process_inds = (
|
|
src_sharding._internal_device_list.process_indices !=
|
|
dst_sharding._internal_device_list.process_indices)
|
|
backend = xla_bridge.get_backend()
|
|
# If a cross-host device transfer is requested but the backend does not
|
|
# support it, then the user must set the flags to enable DCN-based transfers.
|
|
if (different_process_inds and
|
|
(xla_bridge.FORCE_DCN_CROSS_HOST_TRANSFERS.value
|
|
or not getattr(backend, "supports_cross_host_transfers", False)) and
|
|
not xla_bridge.CROSS_HOST_TRANSFER_SOCKET_ADDRESS.value):
|
|
if xla_bridge.FORCE_DCN_CROSS_HOST_TRANSFERS.value:
|
|
msg = ("DCN-based cross-host transfers were requested with the "
|
|
"jax_force_dcn_cross_host_transfers flag.")
|
|
else:
|
|
msg = ("The backend ({backend.platform}, {backend.platform_version}) "
|
|
"does not support cross-host device transfers.")
|
|
raise ValueError(
|
|
f"{msg} Please set jax_cross_host_transfer_socket_address and "
|
|
"(optionally) jax_cross_host_transport_addresses flags to enable "
|
|
"DCN-based cross host device transfers.")
|
|
return different_process_inds
|
|
|
|
@dataclasses.dataclass(frozen=True)
class _DeferredShardArg:
  """Deferred call to `pxla.shard_args`.

  Per-array impls return this object instead of a result array to indicate a
  deferred `shard_args` call. `_batched_device_put_impl` then batches all
  `_DeferredShardArg` objects into a single `shard_args` call.
  """

  x: Any                    # value to be sharded
  s: Sharding               # destination sharding
  aval: core.AbstractValue  # abstract value of the result
  committed: bool           # whether the resulting array is committed
  copy_semantics: ArrayCopySemantics

  def result_handler(self, shard_arg_result):
    # Converts the raw shard_args output into a result array with
    # `self.aval`/`self.s`/`self.committed`.
    return pxla.global_aval_to_result_handler(
        self.aval, self.s, self.committed)(shard_arg_result)
|
|
|
|
@dataclasses.dataclass(frozen=True)
class _DeferredCrossHostTransferArg:
  """Deferred call to `xc.batched_copy_array_to_devices_with_sharding` for
  cross-host data transfers.

  Per-array impls return this object instead of a result array to indicate a
  deferred `batched_copy_array_to_devices_with_sharding` call for a cross-host
  data transfer. `_batched_device_put_impl` then batches all
  `_DeferredCrossHostTransferArg` objects into a single
  `_batched_device_put_impl` call.

  For any _DeferredCrossHostTransferArg, _is_supported_cross_host_transfer(
  x.ndim, x.sharding, dst_sharding) == True.
  """

  x: array.ArrayImpl      # committed source array
  dst_sharding: Sharding  # destination sharding (different process set)
  copy_semantics: ArrayCopySemantics
|
|
|
|
|
|
def _device_put_sharding_impl(
    x: Any,
    aval: core.ShapedArray,
    device: Device | Sharding | None,
    copy: ArrayCopySemantics,
):
  """Performs a single device_put of `x` to a Device/Sharding/None destination.

  Returns either a concrete array (when the transfer can be completed or
  elided immediately) or a deferred marker (`_DeferredShardArg` /
  `_DeferredCrossHostTransferArg`) to be batched by the caller.

  Raises:
    ValueError: for unsupported cross-host reshard requests, non-addressable
      destination shardings with committed inputs, or a Device destination
      paired with a non-addressable multi-device array.
  """
  from jax.experimental import multihost_utils  # pytype: disable=import-error

  # Use a dynamic type, because the static type depends on the value of
  # ``x_is_jax_array``.
  x_sharding: Any
  if isinstance(x, array.ArrayImpl):
    x_is_jax_array = True
    x_is_fully_addressable, x_sharding = x.is_fully_addressable, x.sharding
  else:
    x_is_jax_array = False
    x_is_fully_addressable, x_sharding = None, None

  if isinstance(device, Sharding):
    s = device
    s_is_fully_addressable = s.is_fully_addressable
    # Already committed on the requested sharding: nothing to do.
    if (getattr(x, 'sharding', None) == s and getattr(x, '_committed', False)
        and copy == ArrayCopySemantics.REUSE_INPUT):
      return x

    if isinstance(s, NamedSharding) and s.spec.unreduced:
      # TODO(mattjj,yashkatariya): handle donation
      if jaxlib_extension_version >= 428:
        return api.jit(_device_put_reshard, out_shardings=s)(x)
      else:
        return pjit.reshard(x, s)

    # Non-addressable source and destination over the same device set:
    # only the device order differs.
    if (not s_is_fully_addressable and
        x_is_jax_array and not x_is_fully_addressable and
        s.device_set == x_sharding.device_set):
      assert isinstance(s, NamedSharding), s
      return _different_device_order_reshard(x, s, copy)

    # Fully addressable source and destination over the same device set but a
    # different device order.
    if (s_is_fully_addressable and x_is_jax_array and
        x_is_fully_addressable and s.num_devices > 1 and
        s._internal_device_list != x_sharding._internal_device_list and
        s.device_set == x_sharding.device_set):
      assert isinstance(s, NamedSharding), s
      return _different_device_order_reshard(x, s, copy)

    # Committed array in a multi-process setup: defer to a batched cross-host
    # transfer when the runtime supports it.
    if (x_is_jax_array and x._committed and xla_bridge.process_count() > 1
        and _is_supported_cross_host_transfer(x.ndim, x_sharding, s)):
      return _DeferredCrossHostTransferArg(x, s, copy)

    if not s_is_fully_addressable:
      # If both the source and target shardings are not fully addressable and
      # one of the above conditions has not been met, then assume that the user
      # is attempting a different device order reshard.
      if (x_is_jax_array and not x_is_fully_addressable
          and s.device_set != x_sharding.device_set):
        inp_ids = [d.id for d in x_sharding._device_assignment]
        inp_plat = x_sharding._device_assignment[0].platform.upper()
        target_ids = [d.id for d in s._device_assignment]
        target_plat = s._device_assignment[0].platform.upper()
        raise ValueError(
            "For a cross-host reshard in multi-controller JAX, input and target"
            " sharding should have the same set of devices. Got input's device"
            f" set ids: {inp_ids} on platform {inp_plat} and target sharding's"
            f" device set ids: {target_ids} on platform {target_plat}.\n\n"
            "There is experimental support for cross-host transfers with "
            "different device sets, when input/output shardings have the same "
            "indices and layouts, in the TFRT TPU runtime only.")

      if ((x_is_jax_array and not x._committed) or
          type(x) in array_types or type(x) in dtypes.python_scalar_types):
        # If all hosts participate in the sharding, assert that the input is the
        # same on all hosts. If some hosts have no addressable devices in the
        # sharding, bypass the check, since we can't easily distinguish between
        # these two cases: (1) the sharding contains the same subset of global
        # devices on all hosts (and hosts with no addressable devices in the
        # sharding do not transfer data) or (2) the sharding contains a
        # different subset of devices on each host. For (1), the input should be
        # the same on all hosts, but for (2) it need not be.
        if xla_bridge.process_count() == len(s._internal_device_list.process_indices):
          multihost_utils.assert_equal(
              x, fail_message=(
                  f"{type(x)} passed to device_put is not the same on each"
                  " process. Make sure you are passing the same value of"
                  f" {type(x)} on each process."))
        return _DeferredShardArg(x, s, aval, True, copy)
      # TODO(yashkatariya,mattjj): Link to a doc about McJAX and jax.Array.
      raise ValueError(
          "device_put's second argument must be a Device or a Sharding which"
          f" represents addressable devices, but got {s}. Please pass device or"
          " Sharding which represents addressable devices.")
    return _DeferredShardArg(x, s, aval, True, copy)

  # Only `Device` exists below. `Sharding` instance is handled above.
  if x_is_jax_array:
    if not x_is_fully_addressable and not x_sharding.num_devices == 1:
      raise ValueError(
          "When the second argument to `device_put` is a Device, the first "
          "argument must be a fully addressable array or a non-addressable "
          "array with a single device sharding. Got value with devices "
          f"{x.devices()}")
    if device is None:
      # No explicit destination: keep (or re-commit) the array where it is.
      if copy == ArrayCopySemantics.REUSE_INPUT:
        return x
      else:
        return _DeferredShardArg(x, x_sharding, aval, x.committed, copy)
    elif x_sharding.num_devices == 1:
      device = x_sharding._device_assignment[0] if device is None else device
      sharding = make_single_device_sharding(device)
      if not x._committed and not sharding.has_addressable_devices:
        # For uncommitted arrays in McJAX, each process has a local copy of the
        # array. If the destination sharding is not addressable, no data
        # transfer is needed, since the data was transferred in the process
        # in which the sharding is addressable.
        shards, devices = [], []
      else:
        shards, devices = [x], [device]
      if copy == ArrayCopySemantics.ALWAYS_COPY:
        return xc.batched_device_put(aval, sharding, shards, devices, True,
                                     True)
      return pxla.batched_device_put(aval, sharding, shards, devices)

  # Remaining cases (e.g. numpy arrays and Python scalars) fall back to a
  # deferred shard_args call onto a single-device sharding.
  sh = make_single_device_sharding(pxla.get_default_device()
                                   if device is None else device)
  return _DeferredShardArg(x, sh, aval, device is not None, copy)
|
|
|
|
|
|
def _device_put_impl(
    x, *, device: Device | Sharding | Format | None,
    src: Device | Sharding | Format | None, copy: ArrayCopySemantics, aval):
  """Dispatches one device_put to the appropriate destination-specific impl.

  Handles aval inference for raw inputs, MemorySpace destinations (via the
  primitive so they lower properly), Format (layout) destinations, and
  finally Device/Sharding/None destinations via `_device_put_sharding_impl`.

  Raises:
    TypeError: if `x` is a tracer or not a valid JAX type.
    ValueError: if a `Format` destination is not concrete.
  """
  if aval is None:
    try:
      # Tracers cannot be device_put directly; surface a uniform TypeError.
      if isinstance(x, core.Tracer):
        raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
      aval = core.typeof(x)
      aval = update_dp_aval(aval, device)
    except TypeError as err:
      raise TypeError(
          f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err

  if isinstance(device, core.MemorySpace):
    # Memory-space moves go through the primitive so they lower correctly.
    return apply_primitive(device_put_p, x, devices=(device,), srcs=(src,),
                           copy_semantics=(copy,))[0]

  if isinstance(device, Format):
    l = device
    dll = l.layout
    x_dll = x.format.layout if hasattr(x, 'format') else None
    if dll is None and l.sharding is None:
      return _device_put_sharding_impl(x, aval, l.sharding, copy)
    if (not isinstance(l.sharding, Sharding) or
        not isinstance(dll, (Layout, type(None)))):
      raise ValueError(
          "sharding and layout in `Layout` instance should be"
          f" concrete. Got layout: {l} for input {aval.str_short()}")
    # Already committed with the exact requested format: nothing to do.
    if (getattr(x, 'format', None) == l and getattr(x, '_committed', False) and
        copy == ArrayCopySemantics.REUSE_INPUT):
      return x
    if x_dll is None and dll is None:
      return _device_put_sharding_impl(x, aval, l.sharding, copy)
    # A concrete layout change is performed by jit with out_shardings=Format.
    return api.jit(
        _device_put_reshard,
        out_shardings=l,
        donate_argnums=(0 if copy == ArrayCopySemantics.DONATE_INPUT else None),
    )(x)

  return _device_put_sharding_impl(x, aval, device, copy)
|
|
|
|
|
|
def _batched_device_put_impl(
    *xs,
    devices: Sequence[Device | Sharding | Format | None],
    srcs: Sequence[Device | Sharding | Format | None],
    copy_semantics: Sequence[ArrayCopySemantics],
    dst_avals: Sequence[core.ShapedArray | None]):
  """Runs per-array device_put impls, batching any deferred transfer calls.

  Each input is dispatched via `_device_put_impl`, which returns either a
  concrete result or a deferred marker (`_DeferredShardArg` /
  `_DeferredCrossHostTransferArg`). All deferred markers of each kind are
  then resolved with a single batched runtime call.
  """
  ys = []
  # Accumulators for deferred shard_args calls ("dsa") ...
  dsa_indices, dsa_xs, dsa_shardings, dsa_copy_semantics = [], [], [], []
  # ... and for deferred cross-host copies ("dca").
  dca_indices, dca_xs, dca_shardings, dca_device_lists, dca_copy_semantics = \
      [], [], [], [], []

  for i, (x, device, src, cp, aval) in enumerate(
      zip(xs, devices, srcs, copy_semantics, dst_avals)):
    y = _device_put_impl(x, device=device, src=src, copy=cp, aval=aval)
    if isinstance(y, _DeferredShardArg):
      dsa_indices.append(i)
      dsa_xs.append(y.x)
      dsa_shardings.append(y.s)
      dsa_copy_semantics.append(y.copy_semantics)
    elif isinstance(y, _DeferredCrossHostTransferArg):
      dca_indices.append(i)
      dca_xs.append(y.x)
      dca_shardings.append(y.dst_sharding)
      dca_device_lists.append(y.dst_sharding._internal_device_list)
      dca_copy_semantics.append(y.copy_semantics)
    ys.append(y)

  if dsa_xs:
    # Batch a single `shard_args` call for all deferred shard-args markers.
    shard_arg_results = pxla.shard_args(dsa_shardings, [None] * len(dsa_xs),
                                        dsa_copy_semantics, dsa_xs)
    for i, shard_arg_result in zip(dsa_indices, shard_arg_results):
      assert isinstance(ys[i], _DeferredShardArg)
      ys[i] = ys[i].result_handler(shard_arg_result)
  if dca_xs:
    # Batch all cross-host transfers into one runtime call.
    copy_array_results = xc.batched_copy_array_to_devices_with_sharding(
        dca_xs, dca_device_lists, dca_shardings, dca_copy_semantics)
    for i, copy_array_result in zip(dca_indices, copy_array_results):
      assert isinstance(ys[i], _DeferredCrossHostTransferArg)
      ys[i] = copy_array_result

  return ys
|
|
|
|
def batched_device_put_impl(
    *xs,
    devices: Sequence[Device | Sharding | Format | None],
    srcs: Sequence[Device | Sharding | Format | None],
    copy_semantics: Sequence[ArrayCopySemantics]):
  """Impl rule for `device_put_p`: batched device_put with no dst avals."""
  no_avals = [None] * len(devices)
  return _batched_device_put_impl(
      *xs, devices=devices, srcs=srcs, copy_semantics=copy_semantics,
      dst_avals=no_avals)
|
|
|
|
|
|
# Primitive for (possibly batched) device_put: takes n values and returns n
# results, one per (device, src, copy_semantics) triple.
device_put_p = core.Primitive('device_put')
device_put_p.multiple_results = True
device_put_p.def_impl(batched_device_put_impl)
|
|
|
|
|
|
def _device_put_folding_rule(consts, params, out_avals):
  """Const-folding rule for device_put.

  We elide device_puts that do nothing; these can be generated by jnp.array,
  for example. Returns `consts` unchanged when every destination is None,
  every input is a TypedNdArray, and every copy semantic is REUSE_INPUT;
  otherwise returns None (no folding).
  """
  is_noop = (
      all(d is None for d in params["devices"])
      and all(isinstance(c, literals.TypedNdArray) for c in consts)
      and all(cs == ArrayCopySemantics.REUSE_INPUT
              for cs in params["copy_semantics"]))
  return consts if is_noop else None

partial_eval.const_fold_rules[device_put_p] = _device_put_folding_rule
|
|
|
|
|
|
def update_dp_aval(aval, d):
  """Returns `aval` updated to reflect a device_put destination `d`.

  For a NamedSharding destination, the aval's sharding is rewritten onto the
  destination's abstract mesh/spec; any other Sharding clears the aval's
  sharding. A destination memory kind or MemorySpace is reflected in the
  aval's memory_space. Non-ShapedArray avals are returned unchanged.
  """
  if not isinstance(aval, core.ShapedArray):
    return aval
  if isinstance(d, Sharding):
    aval = (aval.update(sharding=aval.sharding.update(mesh=d.mesh.abstract_mesh,
                                                      spec=d.spec))
            if isinstance(d, NamedSharding) else aval.update(sharding=None))
    if d.memory_kind is not None:
      aval = aval.update(memory_space=core.mem_kind_to_space(d.memory_kind))
    return aval
  elif isinstance(d, core.MemorySpace):
    return aval.update(memory_space=d)
  return aval
|
|
|
|
def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics):
  """Abstract eval: each output aval is its input aval retargeted at the
  corresponding destination."""
  del srcs, copy_semantics  # unused
  return list(map(update_dp_aval, xs, devices))
device_put_p.def_abstract_eval(_device_put_abstract_eval)
|
|
|
|
def _device_put_transpose(cts, *args, devices, srcs, copy_semantics):
|
|
results: list[Any | None] = [None] * len(cts)
|
|
dp_cts = []
|
|
for i, (ct, arg, device, src, cp) in enumerate(zip(
|
|
cts, args, devices, srcs, copy_semantics)):
|
|
if ad.is_undefined_primal(arg):
|
|
if type(ct) is ad.Zero:
|
|
results[i] = ad.Zero(arg.aval.to_ct_aval())
|
|
else:
|
|
dp_cts.append((i, ct, arg, device, src, cp))
|
|
|
|
if dp_cts:
|
|
indices, dp_ct, args, devices, srcs, copy_semantics = list(zip(*dp_cts))
|
|
# TODO(yashkatariya): Maybe remove the special carve out for Host?
|
|
srcs = tuple(a.aval.memory_space
|
|
if s is None and a.aval.memory_space == core.MemorySpace.Host
|
|
else s for s, a in zip(srcs, args))
|
|
new_copy_semantics = []
|
|
for cp in copy_semantics:
|
|
if cp == ArrayCopySemantics.DONATE_INPUT:
|
|
raise ValueError(
|
|
"donate=True is not allowed during tranposition of device_put."
|
|
" Please file an issue if you want this to be supported.")
|
|
elif cp == ArrayCopySemantics.REUSE_INPUT:
|
|
new_copy_semantics.append(ArrayCopySemantics.ALWAYS_COPY)
|
|
else:
|
|
assert cp == ArrayCopySemantics.ALWAYS_COPY
|
|
new_copy_semantics.append(ArrayCopySemantics.ALWAYS_COPY)
|
|
ys = device_put_p.bind(*dp_ct, devices=srcs, srcs=devices,
|
|
copy_semantics=tuple(new_copy_semantics))
|
|
for i, y in zip(indices, ys):
|
|
results[i] = y
|
|
return results
|
|
# device_put is linear, so its JVP is the primitive applied to the tangents.
ad.primitive_jvps[device_put_p] = partial(ad.linear_jvp, device_put_p)
ad.primitive_transposes[device_put_p] = _device_put_transpose
|
|
|
|
def _device_put_batcher(batched_args, batch_dims, **params):
  """Batching rule: device_put maps straight through over the batched args,
  provided all mapped batch dimensions agree."""
  mapped = [bd for bd in batch_dims if bd is not batching.not_mapped]
  if mapped:
    assert all(bd == mapped[0] for bd in mapped[1:]), batch_dims
  return device_put_p.bind(*batched_args, **params), batch_dims
batching.primitive_batchers[device_put_p] = _device_put_batcher
|
|
|
|
def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
  """TPU/GPU lowering: annotates operands with sharding and memory-kind ops.

  Annotations are only emitted when the module uses non-default memory kinds;
  otherwise device_put lowers to a no-op (identity on the operands).
  """
  # TODO(yashkatariya): Maybe we should add the custom calls anyways if it's
  # being used inside jit? Atleast for now, this preserves the old behavior.
  if ctx.module_context.all_default_mem_kind:
    return xs
  def lower(x, device, aval, out_aval):
    # Annotate only destinations that carry a memory kind / memory space.
    if ((isinstance(device, Sharding) and device.memory_kind is not None) or
        isinstance(device, core.MemorySpace)):
      if isinstance(device, Sharding):
        if config.use_shardy_partitioner.value:
          x = mlir.wrap_with_sharding_op(
              ctx, x, out_aval,
              device._to_sdy_sharding(aval.ndim))
        else:
          x = mlir.wrap_with_sharding_op(
              ctx, x, out_aval,
              device._to_xla_hlo_sharding(aval.ndim).to_proto())
      mem_kind = (core.mem_space_to_kind(device)
                  if isinstance(device, core.MemorySpace) else device.memory_kind)
      assert mem_kind is not None
      x = mlir.wrap_with_memory_kind(x, mem_kind, out_aval)
      return x
    return x
  return list(map(lower, xs, devices, ctx.avals_in, ctx.avals_out))

mlir.register_lowering(
    device_put_p, _tpu_gpu_device_put_lowering, platform='tpu')
mlir.register_lowering(
    device_put_p, _tpu_gpu_device_put_lowering, platform='gpu')
|
|
|
|
|
|
def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
  # Default lowering: device_put is an identity inside the computation;
  # sharding/memory-kind annotations are added only by the TPU/GPU lowering.
  return xs
mlir.register_lowering(device_put_p, _common_device_put_lowering)
|