Files
opencv/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py
T
2026-05-06 19:47:31 +07:00

2124 lines
83 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of pmap and related functionality."""
from __future__ import annotations
import collections
from collections.abc import Callable, Sequence, Iterable
import dataclasses
from functools import partial
import functools
import itertools as it
import logging
import math
from typing import Any, NamedTuple, Union
import numpy as np
from jax._src import api
from jax._src import array
from jax._src import compiler
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import jaxpr_util
from jax._src import literals
from jax._src import op_shardings
from jax._src import pjit
from jax._src import profiler
from jax._src import sharding_impls
from jax._src import stages
from jax._src import tree_util
from jax._src import typing
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.abstract_arrays import array_types
from jax._src.core import ShapedArray
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
from jax._src.layout import Layout, AutoLayout, Format
from jax._src.lib import _jax
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import (Sharding as JSharding, IndivisibleError,
common_is_equivalent_to)
from jax._src.mesh import (AbstractMesh, Mesh, get_abstract_mesh,
get_concrete_mesh)
from jax._src.sharding_impls import (
ArrayMapping, AUTO, UnspecifiedValue, SingleDeviceSharding,
make_single_device_sharding, GSPMDSharding,
NamedSharding, PartitionSpec as P)
from jax._src.util import (safe_map, safe_zip, partition_list,
unzip2, weakref_lru_cache, tuple_insert)
from jax._src.state.types import AbstractRef, RefEffect
from jax._src.typing import ArrayLike
# Within this module the names `map` and `zip` are rebound to `safe_map` and
# `safe_zip` (imported from jax._src.util); the builtins remain reachable as
# `unsafe_map` and `unsafe_zip`.
unsafe_map, map = map, safe_map
zip, unsafe_zip = safe_zip, zip
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# An index is an int, a slice, or a tuple of ints/slices.
Index = Union[int, slice, tuple[Union[int, slice], ...]]
# Local aliases for types defined in other modules.
PyTreeDef = tree_util.PyTreeDef
MeshAxisName = sharding_impls.MeshAxisName
### util
def identity(x):
  """Return `x` unchanged."""
  return x
@profiler.annotate_function
def shard_args(
    shardings: Sequence[JSharding],
    layouts: Sequence[Any | None],
    copy_semantics: Sequence[xc.ArrayCopySemantics],
    args: Sequence[Any],
    canonicalize: bool = True,
) -> Sequence[xc.ArrayImpl]:
  """Dispatch every argument to the handler registered for its type.

  Arguments sharing a Python type are grouped so that the matching entry of
  `shard_arg_handlers` is called once per type; the handler outputs are then
  written back to the positions the arguments had in `args`.

  Args:
    shardings: one sharding per argument.
    layouts: one layout (or None) per argument.
    copy_semantics: one copy-semantics value per argument.
    args: the values to shard.
    canonicalize: when true, each value first goes through
      `dtypes.canonicalize_value`.

  Returns:
    The handler results, in the same order as `args`.
  """
  # Single-argument fast path: no grouping needed.
  if len(args) == 1:
    only = args[0]
    if canonicalize:
      only = dtypes.canonicalize_value(only)
    handler = shard_arg_handlers[type(only)]
    return handler([only], shardings, layouts, copy_semantics)

  # type(arg) -> (positions, args, shardings, layouts, copy_semantics)
  grouped: dict[Any, tuple[list, list, list, list, list]] = {}
  for position, (value, sharding, layout, semantics) in enumerate(
      safe_zip(args, shardings, layouts, copy_semantics)):
    if canonicalize:
      value = dtypes.canonicalize_value(value)
    positions, values, group_shardings, group_layouts, group_semantics = (
        grouped.setdefault(type(value), ([], [], [], [], [])))
    positions.append(position)
    values.append(value)
    group_shardings.append(sharding)
    group_layouts.append(layout)
    group_semantics.append(semantics)

  # Groups are keyed by type, so handler outputs cannot simply be
  # concatenated: each one is placed back at its recorded original position.
  results: list[typing.Array | None] = [None] * len(args)
  for arg_type, group in grouped.items():
    positions, values, group_shardings, group_layouts, group_semantics = group
    handled = shard_arg_handlers[arg_type](
        values, group_shardings, group_layouts, group_semantics)
    for position, arr in safe_zip(positions, handled):
      results[position] = arr
  assert all(entry is not None for entry in results)
  return results
# Registry consulted by `shard_args`: maps the Python type of an argument to
# the handler called for arguments of that type. A handler is called with
# (args, shardings, layouts, copy_semantics) as parallel sequences and returns
# a sequence with one result per argument.
shard_arg_handlers: dict[
Any,
Callable[
[Sequence[Any], Sequence[Any], Sequence[Any],
Sequence[xc.ArrayCopySemantics]],
Sequence[Any],
],
] = {}
@util.cache(max_size=2048, trace_context_in_key=False)
def is_default_layout(curr_layout, sharding, aval):
  """Report whether `curr_layout` should be treated as the default layout.

  Returns True when there is no layout or no concrete sharding to compare
  against, when `aval` is the abstract token or has a float0/extended dtype,
  or when the layout equals the one the first assigned device's client
  reports for the shard shape. An `AutoLayout` is never the default. A
  runtime error whose message starts with "UNIMPLEMENTED" is also treated as
  "default"; any other runtime error propagates.
  """
  nothing_to_compare = (
      curr_layout is None or sharding is None
      or isinstance(sharding, UnspecifiedValue))
  if nothing_to_compare:
    return True

  layout_not_applicable = (
      aval is core.abstract_token or aval.dtype == dtypes.float0
      or dtypes.issubdtype(aval.dtype, dtypes.extended))
  if layout_not_applicable:
    return True

  if isinstance(curr_layout, AutoLayout):
    return False

  device = sharding._device_assignment[0]
  per_shard_shape = sharding.shard_shape(aval.shape)
  try:
    # TODO(yashkatariya): Replace this with normal `==` check once CPU supports
    # int4.
    pjrt_default = device.client.get_default_layout(
        aval.dtype, per_shard_shape, device)
    return is_user_xla_layout_equal(
        curr_layout, Layout.from_pjrt_layout(pjrt_default))
  except _jax.JaxRuntimeError as err:
    first_arg, *_ = err.args
    if isinstance(first_arg, str) and first_arg.startswith("UNIMPLEMENTED"):
      return True
    raise
def _masked_array_error(xs, shardings, layouts, copy_semantics):
  """Handler for numpy masked arrays: always raises ValueError."""
  message = (
      "numpy masked arrays are not supported as direct inputs to JAX functions. "
      "Use arr.filled() to convert the value to a standard numpy array.")
  raise ValueError(message)


shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error
def _shard_np_array(xs, shardings, layouts, copy_semantics):
  """Shard array-like values according to their shardings and layouts.

  A value with a layout goes through `api.device_put` with a `Format`;
  otherwise the value is split (or repeated, when fully replicated) across the
  sharding's addressable devices and passed to `batched_device_put`.
  `copy_semantics` is accepted for handler-signature compatibility but not
  read here.
  """
  out = []
  for value, sharding, layout in safe_zip(xs, shardings, layouts):
    devices = sharding._addressable_device_assignment
    if value.dtype == dtypes.float0:
      # float0 values are replaced by a bool array of zeros of the same shape.
      value = np.zeros(value.shape, dtype=np.dtype(bool))
    aval = core.shaped_abstractify(value)

    if layout is not None:
      out.append(api.device_put(value, Format(layout, sharding)))
      continue

    if sharding.is_fully_replicated:
      shards = [value] * len(devices)
    else:
      index_map = sharding.addressable_devices_indices_map(value.shape)
      shards = [value[idx] for idx in tuple(index_map.values())]
    out.append(batched_device_put(aval, sharding, shards, devices))
  return out


for _t in array_types:
  shard_arg_handlers[_t] = _shard_np_array
shard_arg_handlers[literals.TypedNdArray] = _shard_np_array
def _shard_python_scalar(xs, shardings, layouts, copy_semantics):
  """Wrap Python scalars with `np.array` and re-dispatch via `shard_args`."""
  as_arrays = [np.array(scalar) for scalar in xs]
  return shard_args(shardings, layouts, copy_semantics, as_arrays)


for _t in dtypes.python_scalar_types:
  shard_arg_handlers[_t] = _shard_python_scalar
def _shard_typed_scalar(xs, shardings, layouts, copy_semantics):
  """Wrap typed scalars as weakly-typed 0-d `TypedNdArray`s and shard them."""

  def _as_typed_ndarray(scalar):
    # Rank-0 array of the scalar's own dtype, tagged with a weak-typed aval.
    return literals.TypedNdArray(
        np.array(scalar, dtype=scalar.dtype),
        aval=core.ShapedArray((), scalar.dtype, weak_type=True))

  wrapped = [_as_typed_ndarray(scalar) for scalar in xs]
  return _shard_np_array(wrapped, shardings, layouts, copy_semantics)


for _t in literals.typed_scalar_types:
  shard_arg_handlers[_t] = _shard_typed_scalar
def _shard_mutable_array(xs, shardings, layouts, copy_semantics):
  """Shard the buffers held by `core.Ref` values via `shard_args`."""
  underlying = [ref._refs._buf for ref in xs]
  return shard_args(shardings, layouts, copy_semantics, underlying)


shard_arg_handlers[core.Ref] = _shard_mutable_array
def batched_device_put(aval: core.ShapedArray,
                       sharding: JSharding, xs: Sequence[Any],
                       devices: Sequence[xc.Device], committed: bool = True,
                       enable_x64: bool | None = None):
  """Build one array from per-device shards `xs`.

  If every shard is already a single-device `ArrayImpl` living on its target
  device, the result is assembled directly from those shards; otherwise the
  work is delegated to `xc.batched_device_put`. Test events are emitted on
  entry and (always) on exit.
  """

  def _already_placed(shard, device):
    return (isinstance(shard, array.ArrayImpl)
            and shard.sharding.num_devices == 1
            and shard.devices() == {device})

  util.test_event("batched_device_put_start")
  try:
    in_place = [shard for shard, device in safe_zip(xs, devices)
                if _already_placed(shard, device)]
    # Non-empty and every shard qualified: no transfer is needed.
    if in_place and len(in_place) == len(xs):
      return array.ArrayImpl(
          aval, sharding, in_place, committed=committed, _skip_checks=True)
    return xc.batched_device_put(
        aval, sharding, xs, list(devices), committed, enable_x64=enable_x64)
  finally:
    util.test_event("batched_device_put_end")
def global_aval_to_result_handler(
    aval: core.AbstractValue, out_sharding, committed: bool
) -> Callable[[Sequence[xc.ArrayImpl]], Any]:
  """Returns a function for handling the raw buffers of a single output aval.

  The handler is produced by the factory registered in
  `global_result_handlers` under `type(aval)`.

  Args:
    aval: The global output AbstractValue.
    out_sharding: Forwarded unchanged to the registered factory (presumably
      the sharding of this output).
    committed: Forwarded unchanged to the registered factory.

  Returns:
    A function for handling the Buffers that will eventually be produced
    for this output. The function will return an object suitable for returning
    to the user, e.g. an Array.

  Raises:
    TypeError: if a KeyError occurs while looking up or calling the factory,
      e.g. when no factory is registered for `type(aval)`.
  """
  try:
    return global_result_handlers[type(aval)](aval, out_sharding, committed)
  except KeyError as err:
    raise TypeError(
        f"No pxla_result_handler for type: {type(aval)}") from err
# Type of the factories stored in `global_result_handlers`.
PxlaResultHandler = Callable[..., xc._xla.ResultHandler]
# Registry consulted by `global_aval_to_result_handler`: maps an abstract value
# type to the factory called as `factory(aval, out_sharding, committed)`.
global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
class InputsHandler:
  """Callable that shards input buffers with fixed shardings and layouts.

  Calling an instance forwards the given buffers to `shard_args`, pre-bound
  to `in_shardings`, `in_layouts` and REUSE_INPUT copy semantics for every
  input.
  """
  __slots__ = ("handler", "in_shardings", "in_layouts", "local_devices",
               "input_indices")

  def __init__(self, in_shardings, in_layouts, local_devices=None,
               input_indices=None):
    reuse_all = [xc.ArrayCopySemantics.REUSE_INPUT] * len(in_shardings)
    self.handler = partial(shard_args, in_shardings, in_layouts, reuse_all)
    self.in_shardings = in_shardings
    self.in_layouts = in_layouts
    self.local_devices = local_devices
    self.input_indices = input_indices

  def __call__(self, input_buffers):
    return self.handler(input_buffers)

  def __str__(self):
    described = (
        f"in_shardings={self.in_shardings}",
        f"in_layouts={self.in_layouts}",
        f"local_devices={self.local_devices}",
        f"input_indices={self.input_indices}",
    )
    return "InputsHandler(\n" + ",\n".join(described) + ")"
class ResultsHandler:
  """Applies one handler per output to the matching group of output buffers."""
  __slots__ = ("handlers", "out_shardings", "out_avals")

  def __init__(self, handlers, out_shardings, out_avals):
    self.handlers = handlers
    self.out_shardings = out_shardings
    self.out_avals = out_avals

  def __call__(self, out_bufs):
    handled = []
    for handler, bufs in safe_zip(self.handlers, out_bufs):
      handled.append(handler(bufs))
    return handled
def global_avals_to_results_handler(
    global_out_avals: Sequence[ShapedArray],
    shardings: Sequence[JSharding],
    committed: bool) -> ResultsHandler:
  """Build a `ResultsHandler` with one result handler per (aval, sharding)."""
  per_output = []
  for out_aval, out_sharding in safe_zip(global_out_avals, shardings):
    per_output.append(
        global_aval_to_result_handler(out_aval, out_sharding, committed))
  return ResultsHandler(per_output, shardings, global_out_avals)
class ExecuteReplicated:
  """The logic to shard inputs, execute a replicated model, returning outputs."""
  __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
               'has_unordered_effects', 'ordered_effects', 'keepalive',
               'has_host_callbacks', '_local_devices', 'kept_var_idx',
               'mut', 'pgle_profiler', '__weakref__']

  def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
               out_handler: ResultsHandler,
               unordered_effects: list[core.Effect],
               ordered_effects: list[core.Effect], keepalive: Any,
               has_host_callbacks: bool, kept_var_idx: set[int],
               mut: MutationData | None,
               pgle_profiler: profiler.PGLEProfiler | None = None):
    """Store the executable and the handlers/metadata used by `__call__`.

    Only the truthiness of `unordered_effects` is retained (as
    `has_unordered_effects`); the local devices are read once from the
    executable.
    """
    self.xla_executable = xla_executable
    self.name = name
    self.backend = backend
    self.in_handler = in_handler
    self.out_handler = out_handler
    self.has_unordered_effects = bool(unordered_effects)
    self.ordered_effects = ordered_effects
    self._local_devices = self.xla_executable.local_devices()
    self.keepalive = keepalive
    self.has_host_callbacks = has_host_callbacks
    self.kept_var_idx = kept_var_idx
    self.mut = mut
    self.pgle_profiler = pgle_profiler

  def _add_tokens_to_inputs(self, input_bufs):
    """Prepend one input token buffer per ordered effect to `input_bufs`."""
    if self.ordered_effects:
      tokens = [
          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)._buf
          for eff in self.ordered_effects
      ]
      input_bufs = [*tokens, *input_bufs]
    return input_bufs

  def _handle_token_bufs(self, token_bufs, sharded_token):
    """Record the tokens produced by an execution in `dispatch.runtime_tokens`."""
    # token_bufs: Sequence[Sequence[tokenArray]], for each effect the returned
    # token buffers.
    # sharded_token: ShardedToken, containing the RuntimeTokens for each device
    for i, device in enumerate(self._local_devices):
      dispatch.runtime_tokens.set_output_runtime_token(
          device, sharded_token.get_token(i))
    # NOTE: `zip` is this module's rebinding of `safe_zip`.
    for eff, token_buf in zip(self.ordered_effects, token_bufs):
      assert len(token_buf) > 0
      if len(token_buf) == 1:
        dispatch.runtime_tokens.set_token_result(eff, core.Token(token_buf[0]))
      else:
        # Several single-device tokens: combine them into one array sharded
        # over a 1-D mesh ('x') built from the tokens' devices.
        token_devices = []
        for token in token_buf:
          assert len(token.sharding.device_set) == 1
          token_devices.append(token.sharding._device_assignment[0])
        s = NamedSharding(Mesh(token_devices, 'x'), P('x'))
        global_token_array = array.make_array_from_single_device_arrays(
            (0,), s, token_buf
        )
        dispatch.runtime_tokens.set_token_result(
            eff, core.Token(global_token_array)
        )

  @profiler.annotate_function
  def __call__(self, *args):
    """Shard `args`, run the executable, and return the handled outputs.

    Raises:
      RuntimeError: if the `no_execution` config option is set.
    """
    if config.no_execution.value:
      raise RuntimeError(
          f"JAX tried to execute function {self.name}, but the no_execution config "
          "option is set")
    # Drop arguments whose position is not in `kept_var_idx`.
    args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
    if self.mut:
      # Mutable inputs recorded in `mut.in_mut` are passed as trailing args.
      args = [*args, *self.mut.in_mut]
    input_bufs = self.in_handler(args)
    # NOTE(review): indentation was lost in this copy; the extent of the `with`
    # block below is reconstructed -- confirm against the upstream file.
    with profiler.PGLEProfiler.trace(self.pgle_profiler):
      if (self.ordered_effects or self.has_unordered_effects
          or self.has_host_callbacks):
        input_bufs = self._add_tokens_to_inputs(input_bufs)
        results = self.xla_executable.execute_sharded(input_bufs, with_tokens=True)
        # The leading results are the per-ordered-effect tokens; take them
        # with identity handlers before consuming the runtime token.
        result_token_bufs = results.consume_with_handlers(
            [lambda xs: xs] * len(self.ordered_effects), strict=False)
        sharded_runtime_token = results.consume_token()
        self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
      else:
        results = self.xla_executable.execute_sharded(input_bufs)
      handlers = self.out_handler.handlers
      if dispatch.needs_check_special():
        # Wrap every handler so `dispatch.check_special_array` runs first.
        special_check = functools.partial(
            dispatch.check_special_array, self.name)
        handlers = [h.pre_wrap(special_check) for h in handlers]
      out = results.consume_with_handlers(handlers)
      if (self.pgle_profiler is not None and self.pgle_profiler.is_running()
          and len(out) > 0):
        out[0].block_until_ready()
    if self.mut is None:
      return out
    else:
      # Outputs that correspond to a mutable input (index `i` in `args`) are
      # written back into that input's buffer; the rest are returned.
      out_ = []
      for i, o in zip(self.mut.out_mut, out):
        if i is not None:
          try: args[i]._refs._buf._replace_with(o)
          except AttributeError: pass # TODO(mattjj): remove float0
        else:
          out_.append(o)
      return out_
def _axis_read(axis_env, axis_name):
try:
return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
except ValueError:
raise NameError(f"unbound axis name: {axis_name}") from None
def axis_groups(axis_env: sharding_impls.AxisEnv, name) -> tuple[tuple[int, ...]]:
  """Compute replica groups for a collective over the named axis/axes.

  `name` may be a single axis name or a list/tuple of names. The mesh shape
  used is `axis_env.sizes` extended by one trailing dimension so that the
  total equals `axis_env.nreps` (which must divide evenly).
  """
  names = name if isinstance(name, (list, tuple)) else (name,)
  mesh_axes = tuple(_axis_read(axis_env, axis_name) for axis_name in names)
  trailing_size, ragged = divmod(axis_env.nreps, math.prod(axis_env.sizes))
  assert not ragged
  return _axis_groups(axis_env.sizes + (trailing_size,), mesh_axes)
def _axis_groups(mesh_spec, mesh_axes):
"""Computes replica group ids for a collective performed over a subset of the mesh.
Args:
mesh_spec: A sequence of integers representing the mesh shape.
mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
indicating over which axes the collective is performed.
Returns:
A tuple of replica groups (i.e. tuples containing replica ids).
"""
iota = np.arange(math.prod(mesh_spec)).reshape(mesh_spec)
groups = np.reshape(
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
(math.prod(np.take(mesh_spec, mesh_axes)), -1))
return tuple(unsafe_map(tuple, groups.T))
def tile_aval_nd(axis_sizes, in_axes: ArrayMapping, aval):
  """Divide each mapped dimension of `aval` by the size of its mesh axis.

  `in_axes` maps axis names to dimension indices of `aval`; each mapped
  dimension must be evenly divisible by `axis_sizes[name]`.
  """
  assert isinstance(aval, ShapedArray)
  dims = list(aval.shape)
  for axis_name, dim in in_axes.items():
    size = axis_sizes[axis_name]
    assert dims[dim] % size == 0
    dims[dim] = dims[dim] // size
  return aval.update(shape=tuple(dims))
def untile_aval_nd(axis_sizes, out_axes: ArrayMapping, aval):
  """Multiply each mapped dimension of `aval` by the size of its mesh axis.

  `out_axes` maps axis names to dimension indices of `aval`.
  """
  assert isinstance(aval, ShapedArray)
  dims = list(aval.shape)
  for axis_name, dim in out_axes.items():
    dims[dim] = dims[dim] * axis_sizes[axis_name]
  return aval.update(shape=tuple(dims))
def mesh_local_to_global(mesh, axes: ArrayMapping, aval):
  """Rescale `aval` from local-mesh shape to global-mesh shape.

  The mapped dimensions are first divided by the local mesh's axis sizes and
  then multiplied by the (global) mesh's axis sizes.
  """
  per_shard = tile_aval_nd(mesh.local_mesh.shape, axes, aval)
  return untile_aval_nd(mesh.shape, axes, per_shard)
def mesh_global_to_local(mesh, axes: ArrayMapping, aval):
  """Rescale `aval` from global-mesh shape to local-mesh shape.

  The mapped dimensions are first divided by the (global) mesh's axis sizes
  and then multiplied by the local mesh's axis sizes.
  """
  per_shard = tile_aval_nd(mesh.shape, axes, aval)
  return untile_aval_nd(mesh.local_mesh.shape, axes, per_shard)
def manual_proto(
    aval: core.ShapedArray,
    manual_axes_set: frozenset[sharding_impls.MeshAxisName], mesh: Mesh):
  """Create an OpSharding proto that declares all mesh axes in
  `manual_axes_set` as manual and all others as replicated.
  """
  sizes_by_name = mesh.shape
  position = {axis: idx for idx, axis in enumerate(mesh.axis_names)}
  # Manual axes are ordered by their string form; replicated axes keep the
  # mesh's own order.
  manual_axes = sorted(manual_axes_set, key=str)
  replicated_axes = [axis for axis in mesh.axis_names
                     if axis not in manual_axes_set]

  def _total_size(axes):
    return math.prod([sizes_by_name[axis] for axis in axes])

  # One tile per array dimension, followed by the replicated and manual groups.
  tile_dims = [1] * aval.ndim
  tile_dims.append(_total_size(replicated_axes))
  tile_dims.append(_total_size(manual_axes))
  transpose_perm = [position[axis] for axis in (*replicated_axes, *manual_axes)]

  proto = xc.OpSharding()
  proto.type = xc.OpSharding.Type.OTHER
  proto.tile_assignment_dimensions = tile_dims
  proto.iota_reshape_dims = list(sizes_by_name.values())
  proto.iota_transpose_perm = transpose_perm
  proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]
  return proto
def check_if_any_auto(
shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool:
for s in shardings:
if isinstance(s, AUTO):
return True
return False
# (sharding, mismatch type, source info) triple, as unpacked by
# `_get_and_check_device_assignment`.
ShardingInfo = tuple[
Union[JSharding, UnspecifiedValue, AUTO],
stages.MismatchType,
Union[Any, None], # Any is dispatch.SourceInfo to avoid circular imports
]
def get_default_device() -> xc.Device:
  """Return the device selected by the `default_device` config option.

  A string value is treated as a backend name and resolves to that backend's
  first local device; otherwise the configured value itself is returned, or
  the first local device when the value is falsy.
  """
  configured = config.default_device.value
  if isinstance(configured, str):
    return xb.get_backend(configured).local_devices()[0]
  return configured or xb.local_devices()[0]
def _get_and_check_device_assignment(
    shardings: Iterable[ShardingInfo],
    ctx_mesh: Mesh | AbstractMesh,
) -> tuple[xc.Client, tuple[xc.Device, ...] | None, int]:
  """Derive the backend and device assignment from shardings and context mesh.

  Args:
    shardings: (sharding, mismatch type, source info) triples to inspect.
    ctx_mesh: the context mesh, concrete or abstract; may be empty.

  Returns:
    `(backend, device_assignment, num_devices)`. `device_assignment` is None
    (and `num_devices` is the abstract mesh size) when only abstract meshes
    were seen.

  Raises:
    ValueError: if abstract meshes of different sizes are seen, or the abstract
      mesh size differs from the device assignment length.
    stages.DeviceAssignmentMismatchError: if a sharding's device assignment
      disagrees with the first one seen or with the context mesh.
  """
  first_sharding_info = None
  # A non-empty abstract context mesh seeds `abstract_mesh`; a non-empty
  # concrete context mesh counts as a concrete sharding.
  abstract_mesh = (
      ctx_mesh if not ctx_mesh.empty and isinstance(ctx_mesh, AbstractMesh)
      else None)
  any_concrete_sharding = (
      True if not ctx_mesh.empty and isinstance(ctx_mesh, Mesh) else False)
  for sh, s_type, source_info in shardings:
    if isinstance(sh, UnspecifiedValue):
      continue
    elif isinstance(sh, NamedSharding) and isinstance(sh.mesh, AbstractMesh):
      if (abstract_mesh is not None and not sh.mesh.empty and
          abstract_mesh.size != sh.mesh.size):
        raise ValueError("AbstractMesh should be of the same size across all "
                         f"shardings. Got {abstract_mesh} and {sh.mesh}")
      abstract_mesh = sh.mesh
    else:
      any_concrete_sharding = True
      arr_device_assignment = sh._device_assignment
      if first_sharding_info is None:
        first_sharding_info = (arr_device_assignment, s_type, source_info)
      # What the assignment is checked against depends on the context mesh:
      # none -> the first sharding seen; abstract -> its size; concrete -> its
      # flat device tuple.
      if ctx_mesh.empty:
        if first_sharding_info[0] != arr_device_assignment:
          raise stages.DeviceAssignmentMismatchError([
              stages.DeviceAssignmentMismatch(*first_sharding_info),
              stages.DeviceAssignmentMismatch(
                  arr_device_assignment, s_type, source_info)])
      elif isinstance(ctx_mesh, AbstractMesh):
        if ctx_mesh.size != len(arr_device_assignment):
          raise stages.DeviceAssignmentMismatchError([
              stages.DeviceAssignmentMismatch(
                  ctx_mesh.size, stages.MismatchType.CONTEXT_DEVICES, None),
              stages.DeviceAssignmentMismatch(
                  arr_device_assignment, s_type, source_info)])
      else:
        if ctx_mesh._flat_devices_tuple != arr_device_assignment:
          raise stages.DeviceAssignmentMismatchError([
              stages.DeviceAssignmentMismatch(
                  ctx_mesh._flat_devices_tuple,
                  stages.MismatchType.CONTEXT_DEVICES, None),
              stages.DeviceAssignmentMismatch(
                  arr_device_assignment, s_type, source_info)])
  # Fallback order when no concrete sharding supplied devices: the concrete
  # context mesh, then the default device.
  device_assignment: tuple[xc.Device, ...]
  if (first_sharding_info is None and not ctx_mesh.empty and
      isinstance(ctx_mesh, Mesh)):
    device_assignment = ctx_mesh._flat_devices_tuple
  elif first_sharding_info is None:
    device_assignment = (get_default_device(),)
  else:
    device_assignment = first_sharding_info[0] # pyrefly: ignore[bad-assignment]
  backend = xb.get_device_backend(device_assignment[0])
  if (any_concrete_sharding and abstract_mesh is not None and
      len(device_assignment) != abstract_mesh.size):
    raise ValueError(
        f"AbstractMesh size: {abstract_mesh.size} does not match the"
        f" device assignment size: {len(device_assignment)}")
  if any_concrete_sharding or abstract_mesh is None:
    return backend, device_assignment, len(device_assignment)
  else:
    return backend, None, abstract_mesh.size
# Either a sharding or an `UnspecifiedValue` placeholder.
MaybeSharding = Union[JSharding, UnspecifiedValue]
def prune_unused_inputs(
    jaxpr: core.Jaxpr,
) -> tuple[core.Jaxpr, set[int], set[int]]:
  """Run `pe.dce_jaxpr_consts` on `jaxpr`, treating every output as used.

  Returns:
    The new jaxpr, the set of kept constant indices, and the set of kept
    input indices.
  """
  every_output_used = [True] * len(jaxpr.outvars)
  pruned, const_used, input_used = pe.dce_jaxpr_consts(jaxpr, every_output_used)
  kept_const_idx = {idx for idx, used in enumerate(const_used) if used}
  kept_var_idx = {idx for idx, used in enumerate(input_used) if used}
  return pruned, kept_const_idx, kept_var_idx
@weakref_lru_cache
def _dce_jaxpr(closed_jaxpr, keep_unused, donated_invars, auto_spmd_lowering):
  """Optionally prune unused inputs and constants from `closed_jaxpr`.

  Pruning is skipped (all inputs are kept) when `keep_unused` or
  `auto_spmd_lowering` is set, or when any input aval has a non-constant
  shape.

  Returns:
    `(closed_jaxpr, donated_invars, kept_var_idx)`, where `donated_invars` is
    restricted to the kept inputs when pruning happened.
  """
  assert isinstance(closed_jaxpr, core.ClosedJaxpr)
  jaxpr = closed_jaxpr.jaxpr
  consts = closed_jaxpr.consts
  in_avals = closed_jaxpr.in_avals

  def _has_non_constant_shape():
    return any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
               for a in in_avals)

  keep_everything = (
      keep_unused or auto_spmd_lowering or _has_non_constant_shape())
  if keep_everything:
    kept_var_idx = set(range(len(in_avals)))
  else:
    jaxpr, kept_const_idx, kept_var_idx = prune_unused_inputs(jaxpr)
    consts = [const for idx, const in enumerate(consts)
              if idx in kept_const_idx]
    donated_invars = tuple(
        donated for idx, donated in enumerate(donated_invars)
        if idx in kept_var_idx)
  return core.ClosedJaxpr(jaxpr, consts), donated_invars, kept_var_idx
class MutationData(NamedTuple):
  """Bookkeeping linking mutable-array inputs to the outputs that update them."""
  # The `core.Ref` values hoisted out of the jaxpr's constants; they are passed
  # as extra trailing arguments at call time.
  in_mut: list[core.Ref]
  # out_mut[o_idx] = i_idx, when the output[o_idx] corresponds to the
  # mutable array args[i_idx]. None when it does not correspond to a mutable array.
  out_mut: list[int | None]
@weakref_lru_cache
def _discharge_refs(
    jaxpr: core.ClosedJaxpr
) -> tuple[core.ClosedJaxpr, Sequence[int | None], MutationData]:
  """Discharge ref state from `jaxpr` and record input/output correspondences.

  Returns:
    The discharged jaxpr; a tuple giving, per input of the new jaxpr, the
    index of its paired output (or None); and a `MutationData` with the
    hoisted mutable constants and the reverse (output -> input) mapping.
  """
  from jax._src.state.discharge import discharge_state2  # pytype: disable=import-error
  jaxpr, in_mut = _move_mutable_consts(jaxpr)
  new_jaxpr = discharge_state2(jaxpr)

  # new outputs are appended to the end
  next_out_idx = len(jaxpr.out_avals)
  inout_map = {}
  for in_idx, in_aval in enumerate(jaxpr.in_avals):
    if isinstance(in_aval, AbstractRef):
      inout_map[in_idx] = next_out_idx
      next_out_idx += 1
  outin_map = {out_idx: in_idx for in_idx, out_idx in inout_map.items()}

  inout_aliases = tuple(
      inout_map.get(in_idx) for in_idx in range(len(new_jaxpr.in_avals)))
  out_mut = [outin_map.get(out_idx)
             for out_idx in range(len(new_jaxpr.out_avals))]
  return new_jaxpr, inout_aliases, MutationData(in_mut, out_mut)
@weakref_lru_cache
def _move_mutable_consts(
    closed_jaxpr: core.ClosedJaxpr,
) -> tuple[core.ClosedJaxpr, list[core.Ref]]:
  """Turn `core.Ref` constants of `closed_jaxpr` into trailing inputs.

  Returns:
    A new closed jaxpr whose constants exclude the refs and whose invars end
    with the refs' former constvars, together with the list of hoisted refs.
  """
  inner = closed_jaxpr.jaxpr
  is_ref = [isinstance(const, core.Ref) for const in closed_jaxpr.consts]
  plain_consts, in_mut = partition_list(is_ref, closed_jaxpr.consts)
  plain_constvars, ref_vars = partition_list(is_ref, inner.constvars)
  new_invars = (*inner.invars, *ref_vars)
  new_effects = pe.make_jaxpr_effects(
      plain_constvars, new_invars, inner.outvars, inner.eqns)
  # TODO(mattjj): debug_info must be updated...
  rebuilt = closed_jaxpr.jaxpr.replace(
      constvars=plain_constvars, invars=new_invars, effects=new_effects,
      debug_info=closed_jaxpr.debug_info.with_unknown_names())
  return core.ClosedJaxpr(rebuilt, plain_consts), in_mut
@weakref_lru_cache
def _discharge_internal_refs(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
  """Apply `discharge_state` to `jaxpr`, keeping its original debug info."""
  from jax._src.state.discharge import discharge_state  # pytype: disable=import-error
  discharged, new_consts = discharge_state(jaxpr.jaxpr, jaxpr.consts)
  discharged._debug_info = jaxpr.jaxpr._debug_info
  return core.ClosedJaxpr(discharged, new_consts)
class SemanticallyEqualShardings:
  """Sharding sequence that hashes and compares by semantic equivalence.

  `UnspecifiedValue`/`AUTO` entries are hashed and compared as themselves.
  Other entries are hashed by their HLO sharding and memory kind, and compared
  with `common_is_equivalent_to` without checking devices.
  """

  def __init__(self, shardings, avals):
    self.shardings = shardings
    self.avals = avals

  def __hash__(self):
    keys = []
    for sharding, aval in zip(self.shardings, self.avals):
      if isinstance(sharding, (UnspecifiedValue, AUTO)):
        keys.append(sharding)
      else:
        keys.append(
            (sharding._to_xla_hlo_sharding(aval.ndim), sharding.memory_kind))
    return hash(tuple(keys))

  def __eq__(self, other):
    if not isinstance(other, SemanticallyEqualShardings):
      return False
    placeholders = (UnspecifiedValue, AUTO)
    for mine, theirs, aval in zip(self.shardings, other.shardings, self.avals):
      if isinstance(mine, placeholders) or isinstance(theirs, placeholders):
        same = mine == theirs
      else:
        same = common_is_equivalent_to(
            mine, theirs, aval.ndim, check_devices=False)
      if not same:
        return False
    return True
@weakref_lru_cache
def _cached_lowering_to_hlo(
    closed_jaxpr: core.ClosedJaxpr, module_name, backend, num_const_args: int,
    in_avals, semantic_in_shardings, semantic_out_shardings,
    in_layouts, out_layouts, num_devices, device_assignment, donated_invars,
    all_default_mem_kind, inout_aliases: None | tuple[None | int, ...],
    propagated_out_mem_kinds: tuple[None | str, ...], platforms: tuple[str, ...],
    lowering_parameters: mlir.LoweringParameters,
    abstract_mesh: AbstractMesh | None):
  """Lower `closed_jaxpr` to an MLIR module via `mlir.lower_jaxpr_to_module`.

  `semantic_in_shardings`/`semantic_out_shardings` are wrappers whose
  `.shardings` attribute holds the actual shardings.

  Returns:
    A tuple `(module, keepalive, host_callbacks, unordered_effects,
    ordered_effects, tuple_args, shape_poly_state)`.

  Raises:
    ValueError: if `num_devices > 1` and the jaxpr has ordered effects that
      are not in `effects.shardable_ordered_effects`.
  """
  # in_avals, in_shardings, in_layouts include the jaxpr_const_args(jaxpr)
  out_avals = closed_jaxpr.out_avals
  jaxpr = closed_jaxpr.jaxpr
  in_shardings = semantic_in_shardings.shardings
  out_shardings = semantic_out_shardings.shardings
  # Logged at WARNING when the `log_compiles` config is on, else at DEBUG.
  log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
  if logger.isEnabledFor(log_priority):
    logger.log(log_priority,
               "Compiling %s with global shapes and types %s. "
               "Argument mapping: %s.",
               module_name, in_avals, in_shardings)
  # NOTE: `map` is this module's rebinding of `safe_map`.
  in_mlir_shardings = map(_to_logical_sharding, in_avals, in_shardings)
  out_mlir_shardings = map(_to_logical_sharding, out_avals, out_shardings)
  replicated_args = [False] * len(in_avals)
  axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment,
                                            abstract_mesh)
  if num_devices > 1:
    # Ordered effects present in the jaxpr, minus the shardable ones.
    unsupported_effects = effects.ordered_effects.filter_in(closed_jaxpr.effects)
    unsupported_effects = effects.shardable_ordered_effects.filter_not_in(
        unsupported_effects)
    if len(unsupported_effects) > 0:
      raise ValueError(
          "The following ordered effects are not supported for "
          f"more than 1 device: {unsupported_effects}")
  ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
  # Hoisted constant arguments get empty names.
  arg_names = ("",) * num_const_args + jaxpr._debug_info.safe_arg_names(len(in_avals) - num_const_args)
  with dispatch.log_elapsed_time(
      "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec",
      fun_name=module_name, event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
    lowering_result = mlir.lower_jaxpr_to_module(
        module_name,
        closed_jaxpr,
        num_const_args=num_const_args,
        ordered_effects=ordered_effects,
        backend=backend,
        platforms=platforms,
        axis_context=axis_ctx,
        in_avals=in_avals,
        donated_args=donated_invars,
        replicated_args=replicated_args,
        arg_shardings=in_mlir_shardings,
        result_shardings=out_mlir_shardings,
        in_layouts=in_layouts,
        out_layouts=out_layouts,
        arg_names=arg_names,
        result_names=jaxpr._debug_info.safe_result_paths(len(out_avals)),
        num_partitions=num_devices,
        all_default_mem_kind=all_default_mem_kind,
        input_output_aliases=inout_aliases,
        propagated_out_mem_kinds=propagated_out_mem_kinds,
        lowering_parameters=lowering_parameters)
  tuple_args = dispatch.should_tuple_args(len(in_avals), backend.platform)
  # The jaxpr's effects that are not ordered effects.
  unordered_effects = list(
      effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
  return (lowering_result.module, lowering_result.keepalive,
          lowering_result.host_callbacks, unordered_effects, ordered_effects,
          tuple_args, lowering_result.shape_poly_state)
@util.cache(max_size=2048, trace_context_in_key=False)
def _create_device_list_cached(device_assignment: tuple[xc.Device, ...]
                               ) -> xc.DeviceList:
  """Build (and cache) an `xc.DeviceList` for `device_assignment`."""
  device_list = xc.DeviceList(device_assignment)
  return device_list
def _create_device_list(
    device_assignment: tuple[xc.Device, ...] | xc.DeviceList | None
) -> xc.DeviceList | None:
  """Convert a device tuple to an `xc.DeviceList`.

  None and existing `xc.DeviceList` values are returned unchanged.
  """
  needs_conversion = (
      device_assignment is not None
      and not isinstance(device_assignment, xc.DeviceList))
  if needs_conversion:
    return _create_device_list_cached(device_assignment)
  return device_assignment
@weakref_lru_cache
def jaxpr_transfer_mem_kinds(jaxpr: core.Jaxpr):
  """Collect memory spaces named by transfer-like equations in `jaxpr`.

  Gathers the `core.MemorySpace` entries of `device_put` equations, the
  `out_memory_spaces` of `compute_on` equations and the output memory spaces
  of `call_exported` equations, then recurses into sub-jaxprs.
  """
  found = []
  for eqn in jaxpr.eqns:
    prim = eqn.primitive
    if prim is dispatch.device_put_p:
      found.extend(dev for dev in eqn.params['devices']
                   if isinstance(dev, core.MemorySpace))
    elif prim.name == 'compute_on':
      found.extend(eqn.params['out_memory_spaces'])
    elif prim.name == 'call_exported':
      exported_avals = eqn.params['exported'].out_avals
      found.extend(out_aval.memory_space for out_aval in exported_avals)
  for inner in core.subjaxprs(jaxpr):
    found.extend(jaxpr_transfer_mem_kinds(inner))
  return found
def are_all_shardings_default_mem_kind(shardings):
  """Return False if any entry has a memory kind other than 'device'.

  `UnspecifiedValue`/`AUTO` entries and entries whose memory kind is None are
  ignored. `core.MemorySpace` entries are converted with
  `core.mem_space_to_kind`.
  """
  for entry in shardings:
    if isinstance(entry, (UnspecifiedValue, AUTO)):
      continue
    if isinstance(entry, core.MemorySpace):
      kind = core.mem_space_to_kind(entry)
    else:
      kind = entry.memory_kind
    if kind is not None and kind != 'device':
      return False
  return True
@weakref_lru_cache
def get_out_layouts_via_propagation(closed_jaxpr: core.ClosedJaxpr
                                    ) -> tuple[None | Layout]:
  """Return, per output of `closed_jaxpr`, the layout set by a constraint.

  Inputs, constants and literals have layout None. The outputs of
  sharding-constraint and layout-constraint equations take the equation's
  'layout' parameter; the outputs of every other equation get None.
  """
  jaxpr = closed_jaxpr.jaxpr
  layout_of = {}
  for var in jaxpr.invars:
    layout_of[var] = None
  for var in jaxpr.constvars:
    layout_of[var] = None

  for eqn in jaxpr.eqns:
    is_constraint = (eqn.primitive is pjit.sharding_constraint_p
                     or eqn.primitive is pjit.layout_constraint_p)
    if is_constraint:
      eqn_layouts = [eqn.params['layout']]
    else:
      eqn_layouts = [None] * len(eqn.outvars)
    for outvar, layout in safe_zip(eqn.outvars, eqn_layouts):
      layout_of[outvar] = layout

  return tuple(
      None if type(outvar) is core.Literal else layout_of[outvar]
      for outvar in jaxpr.outvars)
# A sequence whose entries are each a Layout, an AutoLayout, or None.
MaybeLayout = Sequence[Union[Layout, AutoLayout, None]]
class AllArgsInfo(NamedTuple):
  """Avals and debug_info for all arguments prior to DCE."""
  # One aval per argument.
  in_avals: Sequence[core.ShapedArray]
  # Debug info; `hoist_constants_as_args` reads and replaces its `arg_names`.
  debug_info: core.DebugInfo
def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
                          donated_invars, out_shardings, out_layouts):
  """Discharge refs from `closed_jaxpr` and extend the per-arg metadata.

  When the jaxpr has a `RefEffect` or any `AbstractRef` input, refs are
  discharged via `_discharge_refs`, and shardings/layouts/donation flags are
  extended for the hoisted mutable constants. Otherwise only internal mutable
  array effects (if any) are discharged.

  Returns:
    `(closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
    donated_invars, out_shardings, out_layouts)`; `inout_aliases` and `mut`
    are None when no ref discharge took place.
  """
  if (any(isinstance(e, RefEffect) for e in closed_jaxpr.effects) or
      any(isinstance(a, AbstractRef) for a in closed_jaxpr.in_avals)):
    closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
    # Hoisted mutable constants become trailing inputs: append their
    # shardings and layouts, and mark them as not donated.
    in_shardings = (*in_shardings, *(
        pjit.finalize_arg_sharding(c.sharding, c.committed) for c in mut.in_mut))
    in_layouts = (*in_layouts, *(c.format.layout if hasattr(c, 'format')
                                 else None for c in mut.in_mut))
    donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
    # Outputs paired with a mutable input reuse that input's sharding and
    # layout; the others consume the caller-provided pairs in order.
    out_layouts_ = iter(zip(out_shardings, out_layouts))
    out_shardings, out_layouts = unzip2(
        next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
        for i in mut.out_mut)
    # Every caller-provided pair must have been consumed.
    assert next(out_layouts_, None) is None
  else:
    inout_aliases = mut = None
    # NOTE(review): indentation was lost in this copy; this check is assumed to
    # belong to the `else` branch -- confirm against the upstream file.
    if any(isinstance(e, core.InternalMutableArrayEffect) for e in closed_jaxpr.effects):
      closed_jaxpr = _discharge_internal_refs(closed_jaxpr)
  return (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
          donated_invars, out_shardings, out_layouts)
def hoist_constants_as_args(
    closed_jaxpr: core.ClosedJaxpr, global_in_avals, in_shardings, in_layouts,
    donated_invars, kept_var_idx: set[int], inout_aliases, mut,
    all_args_info: AllArgsInfo):
  """Prepend the jaxpr's constant arguments to all per-argument metadata.

  The constant arguments come from `core.jaxpr_const_args`. When there are
  any, every per-argument sequence gets matching leading entries and every
  stored argument index is shifted by the number of constants.

  Returns:
    `(const_args, global_in_avals, in_shardings, in_layouts, donated_invars,
    kept_var_idx, inout_aliases, mut, all_args_info)`.
  """
  const_args, const_arg_avals = unzip2(
      core.jaxpr_const_args(closed_jaxpr.jaxpr)
  )
  num_const_args = len(const_args)
  # NOTE(review): indentation was lost in this copy; all updates below are
  # assumed to sit inside this `if` -- confirm against the upstream file.
  if num_const_args:
    global_in_avals = list(const_arg_avals) + global_in_avals
    ca_shardings = pjit.const_args_shardings(const_args)
    in_shardings = (*ca_shardings, *in_shardings)
    ca_layouts = pjit.const_args_layouts(const_args, const_arg_avals,
                                         ca_shardings)
    in_layouts = (*ca_layouts, *in_layouts)
    # Constant arguments are never donated and are always kept.
    donated_invars = (False,) * num_const_args + donated_invars
    kept_var_idx = set(range(num_const_args)).union(
        {kv + num_const_args for kv in kept_var_idx})
    if inout_aliases is not None:
      inout_aliases = (None,) * num_const_args + inout_aliases
    if mut is not None:
      # Input indices stored in `out_mut` shift past the new leading args.
      mut = MutationData(
          in_mut=mut.in_mut,
          out_mut=[None if i_idx is None else i_idx + num_const_args
                   for i_idx in mut.out_mut])
    if all_args_info.debug_info.arg_names is None:
      arg_names = None
    else:
      arg_names = (("",) * num_const_args + all_args_info.debug_info.arg_names)
    all_args_info = AllArgsInfo(
        [*const_arg_avals, *all_args_info.in_avals],
        all_args_info.debug_info._replace(arg_names=arg_names))
  return (const_args, global_in_avals, in_shardings, in_layouts, donated_invars,
          kept_var_idx, inout_aliases, mut, all_args_info)
@util.cache(max_size=1024, trace_context_in_key=False)
def _abstract_to_concrete_mesh(abstract_mesh, device_assignment):
  """Build a `Mesh` shaped like `abstract_mesh` from `device_assignment`.

  The devices are laid out in assignment order and reshaped to the abstract
  mesh's axis sizes; axis names and axis types are copied from it.
  """
  pick_device = np.vectorize(lambda i: device_assignment[i], otypes=[object])
  flat_devices = pick_device(np.arange(len(device_assignment)))
  device_grid = flat_devices.reshape(abstract_mesh.axis_sizes)
  return Mesh(device_grid, abstract_mesh.axis_names,
              axis_types=abstract_mesh.axis_types)
def _concretize_abstract_out_shardings(shardings, avals, device_assignment,
out_mem_kinds):
if device_assignment is None:
return shardings
out: list[UnspecifiedValue | JSharding] = []
for s, a, mem_kind in zip(shardings, avals, out_mem_kinds):
if isinstance(s, UnspecifiedValue) and isinstance(a, core.ShapedArray):
if a.sharding.mesh.empty:
out.append(s)
elif a.sharding.mesh._are_all_axes_auto_or_manual:
out.append(s)
else:
spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
for sp in a.sharding.spec])
if a.sharding.mesh._any_axis_auto else a.sharding.spec)
out.append(NamedSharding(
_abstract_to_concrete_mesh(a.sharding.mesh, device_assignment),
spec, memory_kind=mem_kind))
else:
out.append(s)
return tuple(out)
def _get_context_mesh(context_mesh: Mesh | AbstractMesh) -> Mesh | AbstractMesh:
  """Returns `context_mesh`, re-axed to the current abstract mesh if needed.

  `context_mesh` is returned untouched when it is an `AbstractMesh`, when
  `get_concrete_mesh()` is empty, when either it or `get_abstract_mesh()` is
  empty, or when the two already agree. Otherwise its devices are reshaped to
  the current abstract mesh's axis sizes, names and types.
  """
  if isinstance(context_mesh, AbstractMesh) or get_concrete_mesh().empty:
    return context_mesh
  ambient = get_abstract_mesh()
  nothing_to_do = (ambient.empty or context_mesh.empty or
                   ambient == context_mesh.abstract_mesh)
  if nothing_to_do:
    return context_mesh
  assert context_mesh.size == ambient.size
  reshaped_devices = context_mesh.devices.reshape(ambient.axis_sizes)
  return Mesh(reshaped_devices, ambient.axis_names, ambient.axis_types)
@profiler.annotate_function
def lower_sharding_computation(
    closed_jaxpr: core.ClosedJaxpr,
    api_name: str,
    fun_name: str,
    in_shardings: Sequence[MaybeSharding],
    out_shardings: Sequence[MaybeSharding],
    in_layouts: MaybeLayout,
    out_layouts: MaybeLayout,
    donated_invars: Sequence[bool],
    *,
    keep_unused: bool,
    context_mesh: Mesh | AbstractMesh,
    compiler_options_kvs: tuple[tuple[str, Any], ...],
    lowering_platforms: tuple[str, ...] | None,
    lowering_parameters: mlir.LoweringParameters,
    pgle_profiler: profiler.PGLEProfiler | None,
) -> MeshComputation:
  """Lowers a computation to XLA. It can take arbitrary shardings as input.

  The caller of this code can pass in a singleton UNSPECIFIED because the
  number of out_avals might not be known at that time and
  lower_sharding_computation calculates the number of out_avals so it can apply
  the singleton UNSPECIFIED to all out_avals.

  Returns:
    A `MeshComputation` holding the lowered module plus the metadata that is
    forwarded to compilation.

  Raises:
    ValueError: if no device assignment was found and `lowering_platforms` is
      None or the jaxpr has primitives that require devices; if the device
      assignment length differs from the size of an all-explicit mesh on an
      output aval; or if `NamedSharding`s with different abstract meshes are
      seen while a primitive requires devices.
  """
  auto_spmd_lowering = check_if_any_auto(it.chain(in_shardings, out_shardings))
  # Captured before any pruning so it describes the original arguments.
  all_args_info = AllArgsInfo(closed_jaxpr.in_avals, closed_jaxpr.jaxpr._debug_info)
  closed_jaxpr, donated_invars, kept_var_idx = _dce_jaxpr(
      closed_jaxpr, keep_unused, donated_invars, auto_spmd_lowering)
  # Drop the shardings and layouts of arguments that were not kept.
  in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
  in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)
  (closed_jaxpr, inout_aliases, mut, in_shardings, in_layouts,
   donated_invars, out_shardings, out_layouts) = _discharge_refs_jaxpr(
       closed_jaxpr, in_shardings, in_layouts, donated_invars, out_shardings,
       out_layouts)
  jaxpr = closed_jaxpr.jaxpr
  global_in_avals = closed_jaxpr.in_avals
  global_out_avals = closed_jaxpr.out_avals
  # Optionally turn the jaxpr's constants into leading arguments; all
  # per-argument metadata is extended accordingly.
  if lowering_parameters.hoist_constants_as_args:
    (const_args, global_in_avals, in_shardings, in_layouts, donated_invars,
     kept_var_idx, inout_aliases, mut, all_args_info) = hoist_constants_as_args(
         closed_jaxpr, global_in_avals, in_shardings, in_layouts,
         donated_invars, kept_var_idx, inout_aliases, mut, all_args_info)
  else:
    const_args = []
  # If layout is propagated, then set the out_layout in the top module to AUTO
  # so that XLA can override the entry_computation_layout. The propagated
  # layout will be set via a custom call.
  out_layouts_via_prop = get_out_layouts_via_propagation(closed_jaxpr)
  out_layouts = tuple(Layout.AUTO if p is not None else o
                      for o, p in safe_zip(out_layouts, out_layouts_via_prop))
  assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
      len(out_shardings), len(out_layouts), len(global_out_avals))
  context_mesh = _get_context_mesh(context_mesh)
  # Device assignment across all inputs, outputs and shardings inside jaxpr
  # should be the same.
  unique_intermediate_shardings = util.stable_unique(
      dispatch.get_intermediate_shardings(jaxpr))
  # When constants were hoisted, the first len(const_args) entries of
  # in_shardings belong to them.
  unique_const_shardings = util.stable_unique(in_shardings[:len(const_args)])
  unique_in_shardings = util.stable_unique(in_shardings[len(const_args):])
  unique_out_shardings = util.stable_unique(out_shardings)
  # Each sharding is tagged with where it came from; only intermediate
  # shardings carry source info.
  backend, device_assignment, num_devices = _get_and_check_device_assignment(
      it.chain(
          ((i, stages.MismatchType.ARG_SHARDING, None) for i in unique_in_shardings),
          ((c, stages.MismatchType.CONST_SHARDING, None) for c in unique_const_shardings),
          ((o, stages.MismatchType.OUT_SHARDING, None) for o in unique_out_shardings),
          ((js, stages.MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
           for js, source_info in unique_intermediate_shardings)),
      context_mesh)
  # From here on only the shardings are needed, not their source info.
  unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]
  # Constant shardings are treated as input shardings below.
  unique_in_shardings = unique_in_shardings | unique_const_shardings  # pyrefly: ignore[unsupported-operation]
  del unique_const_shardings
  prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
  # No device assignment: the target platforms must be given explicitly and
  # the jaxpr must not need devices during lowering.
  if device_assignment is None:
    if lowering_platforms is None:
      raise ValueError(
          "Passing lowering_platforms via jax.export or"
          " jit(f).trace(*args).lower(lowering_platforms=...) is required when"
          " only AbstractMesh exists in a jitted computation. Got context"
          f" mesh: {context_mesh}")
    if prim_requires_devices:
      raise ValueError(
          "AbstractMesh cannot be used when jaxpr contains primitives that"
          " require devices to be present during lowering.")
  # For device_assignment == 1, this doesn't matter.
  # Unspecified shardings of rank-0, non-token inputs become replicated.
  if device_assignment is not None and len(device_assignment) > 1:
    rep_gs = GSPMDSharding.get_replicated(device_assignment)
    in_shardings = tuple(
        rep_gs if (isinstance(s, UnspecifiedValue) and
                   aval is not core.abstract_token and aval.ndim == 0)
        else s for s, aval in zip(in_shardings, global_in_avals))
  # An output aval on an all-explicit mesh must agree in size with the device
  # assignment.
  for a in global_out_avals:
    if (a is not core.abstract_token and not a.sharding.mesh.empty and
        a.sharding.mesh.are_all_axes_explicit and
        device_assignment is not None and
        len(device_assignment) != a.sharding.mesh.size):
      raise ValueError(
          f"Length of device assignment {len(device_assignment)} is not equal"
          f" to the size of the mesh {a.sharding.mesh.size} of aval"
          f" {a.str_short(True, True)}. Please enter your `jit` into a mesh"
          " context via `jax.set_mesh`.")
  # TODO(parkers): Once _raw_platform has been unified with platform,
  # change this back to just read platform.
  platforms = lowering_platforms or (
      getattr(backend, "_raw_platform", backend.platform),)
  device_list = _create_device_list(device_assignment)
  transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
  # True when there is a non-empty context mesh, more than one device, any
  # specified sharding, or a memory-kind transfer in the jaxpr.
  committed = bool(
      not context_mesh.empty
      or num_devices > 1
      or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
          unique_in_shardings, unique_out_shardings,
          unique_intermediate_shardings))
      or transfer_mem_kind_in_jaxpr
  )
  all_default_mem_kind = are_all_shardings_default_mem_kind(
      it.chain(unique_in_shardings, unique_out_shardings,
               unique_intermediate_shardings, transfer_mem_kind_in_jaxpr))
  # Memory kinds are taken from the output avals only when some sharding
  # uses a non-default memory kind.
  if all_default_mem_kind:
    propagated_out_mem_kinds = (None,) * len(global_out_avals)
  else:
    propagated_out_mem_kinds = tuple(
        core.mem_space_to_kind(o.memory_space) for o in closed_jaxpr.out_avals)
  out_shardings = _concretize_abstract_out_shardings(
      out_shardings, global_out_avals, device_assignment,
      propagated_out_mem_kinds)
  # Record the (possibly updated) shardings on the ShapedArray avals.
  global_in_avals = [core.update_aval_with_sharding(a, sh)
                     if isinstance(a, core.ShapedArray) else a
                     for a, sh in zip(global_in_avals, in_shardings)]
  global_out_avals = [core.update_aval_with_sharding(a, sh)
                      if isinstance(a, core.ShapedArray) else a
                      for a, sh in zip(global_out_avals, out_shardings)]
  ############################ Build up the stableHLO ######################
  abstract_mesh = None
  # Only when a primitive requires devices: every NamedSharding must share
  # one abstract mesh, which is then passed to the lowering.
  if prim_requires_devices:
    assert device_list is not None
    for sharding in it.chain(unique_in_shardings, unique_out_shardings,
                             unique_intermediate_shardings):
      if isinstance(sharding, NamedSharding):
        if (abstract_mesh is not None and
            abstract_mesh != sharding.mesh.abstract_mesh):
          raise ValueError(
              "mesh should be the same across the entire program. Got mesh"
              f" shape for one sharding {abstract_mesh} and"
              f" {sharding.mesh.abstract_mesh} for another")
        abstract_mesh = sharding.mesh.abstract_mesh
  semantic_in_shardings = SemanticallyEqualShardings(
      in_shardings, global_in_avals)
  semantic_out_shardings = SemanticallyEqualShardings(
      out_shardings, global_out_avals)
  jaxpr_util.maybe_dump_jaxpr_to_file(fun_name, closed_jaxpr.jaxpr)
  module_name = util.wrap_name(api_name, fun_name)
  # The concrete devices are only passed when a primitive requires them.
  (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
   tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
       closed_jaxpr, module_name, backend,
       len(const_args), tuple(global_in_avals),
       semantic_in_shardings, semantic_out_shardings,
       in_layouts, out_layouts, num_devices,
       tuple(device_list) if prim_requires_devices else None,  # pyrefly: ignore[bad-argument-type]
       donated_invars, all_default_mem_kind, inout_aliases,
       propagated_out_mem_kinds, platforms,
       lowering_parameters=lowering_parameters,
       abstract_mesh=abstract_mesh)
  # backend and device_assignment is passed through to MeshExecutable because
  # if keep_unused=False and all in_shardings are pruned, then there is no way
  # to get the device_assignment and backend. So pass it to MeshExecutable
  # because we calculate the device_assignment and backend before in_shardings,
  # etc are pruned.
  return MeshComputation(
      module_name,
      module,
      const_args,
      donated_invars,
      platforms,
      compiler_options_kvs,
      device_list,
      global_in_avals=global_in_avals,
      global_out_avals=global_out_avals,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      tuple_args=tuple_args,
      auto_spmd_lowering=auto_spmd_lowering,
      unordered_effects=unordered_effects,
      ordered_effects=ordered_effects,
      host_callbacks=host_callbacks,
      keepalive=keepalive,
      kept_var_idx=kept_var_idx,
      mut=mut,
      backend=backend,
      num_devices=num_devices,
      committed=committed,
      in_layouts=in_layouts,
      out_layouts=out_layouts,
      shape_poly_state=shape_poly_state,
      all_args_info=all_args_info,
      pgle_profiler=pgle_profiler,
      intermediate_shardings=unique_intermediate_shardings,
      context_mesh=context_mesh)
def _to_logical_sharding(
    aval: core.AbstractValue, sharding: MaybeSharding | AUTO
) -> JSharding | AUTO | None:
  """Returns the sharding to use for `aval`, or None when there is none.

  Unspecified shardings and token avals map to None. `AUTO` shardings, and
  shardings of `ShapedArray`/`AbstractRef` avals, are returned unchanged.

  Raises:
    TypeError: if `aval` is of any other kind.
  """
  if isinstance(sharding, UnspecifiedValue):
    return None
  if isinstance(sharding, AUTO):
    return sharding
  if isinstance(aval, (ShapedArray, AbstractRef)):
    assert isinstance(sharding, JSharding)
    return sharding
  if isinstance(aval, core.AbstractToken):
    return None
  raise TypeError(aval)
class MeshComputation(stages.Lowering):
  """A lowered computation holding its StableHLO module.

  The positional constructor arguments are stored on the instance. Any extra
  keyword arguments are kept verbatim in `compile_args` and forwarded to
  `UnloadedMeshExecutable.from_hlo` by `compile`.
  """
  _hlo: ir.Module
  _executable: MeshExecutable | None

  def __init__(self, name: str, hlo: ir.Module,
               const_args: list[ArrayLike],
               donated_invars: Sequence[bool], platforms: Sequence[str],
               compiler_options_kvs: tuple[tuple[str, Any], ...],
               device_assignment: xc.DeviceList | tuple[xc.Device, ...] | None,
               **compile_args):
    self._name = name
    self._hlo = hlo
    self.const_args = const_args
    self._donated_invars = donated_invars
    self._platforms = platforms
    self._compiler_options_kvs = compiler_options_kvs
    self._device_list = _create_device_list(device_assignment)
    self.compile_args = compile_args
    # Set by `compile` the first time it runs with default arguments.
    self._executable = None

  # -- stages.Lowering overrides

  def stablehlo(self) -> ir.Module:
    """Returns the StableHLO module given at construction."""
    return self._hlo

  def compile(self, compiler_options=None, *, device_assignment=None,
              ) -> MeshExecutable:
    """Compiles the module and returns the resulting executable.

    Args:
      compiler_options: optional mapping whose items are appended to the
        compiler options given at construction.
      device_assignment: optional devices to compile for. If the lowering
        already has a device list, this must be equal to it.

    Raises:
      ValueError: if `device_assignment` disagrees with the device list
        stored at construction.
    """
    extra_kvs = (tuple(compiler_options.items())
                 if compiler_options is not None else ())
    merged_kvs = self._compiler_options_kvs + extra_kvs

    requested = _create_device_list(device_assignment)
    lowered = self._device_list
    if (requested is not None and lowered is not None
        and lowered != requested):
      requested_ids = [d.id for d in requested]
      requested_platform = requested[0].platform.upper()
      lowered_ids = [d.id for d in lowered]
      lowered_platform = lowered[0].platform.upper()
      raise ValueError(
          "device_assignment passed to `.compile` must match the"
          " device_assignment calculated from array shardings and"
          " out_shardings. Got device ids passed to compile"
          f" {requested_ids} on platform"
          f" {requested_platform} and devices ids"
          " calculated from array shardings and out_shardings"
          f" {lowered_ids} on platform"
          f" {lowered_platform}")
    compilation_device_list = lowered if requested is None else requested
    assert isinstance(compilation_device_list, (type(None), xc.DeviceList))

    # Only cache executable into `self` if `.compile()` in AOT is called without
    # specifying compiler_options and device_assignment.
    cacheable = compiler_options is None and device_assignment is None
    if cacheable and self._executable is not None:
      return self._executable

    executable = UnloadedMeshExecutable.from_hlo(
        self._name, self._hlo, **self.compile_args,
        compiler_options_kvs=merged_kvs,
        device_list=compilation_device_list)
    if cacheable:
      self._executable = executable
    return executable

  def cost_analysis(self) -> dict[str, float]:
    """Returns a cost analysis of the lowered (uncompiled) module.

    Raises:
      NotImplementedError: if `xb.using_pjrt_c_api` is true for the backend.
    """
    backend = self.compile_args["backend"]
    if not xb.using_pjrt_c_api(backend):
      return _jax.hlo_module_cost_analysis(backend, self.hlo().as_hlo_module())
    raise NotImplementedError(
        "Lowered.cost_analysis not implemented on platform "
        f"'{backend.platform}'. Use compile().cost_analysis() for "
        "post-compilation cost estimates.")
def get_op_sharding_from_executable(
    executable) -> tuple[Sequence[xc.OpSharding], Sequence[xc.OpSharding]]:
  """Returns the executable's (parameter, output) shardings.

  Each half is whatever `get_parameter_shardings()` /
  `get_output_shardings()` returned, with None replaced by an empty list.
  """
  param_shardings = executable.get_parameter_shardings()
  result_shardings = executable.get_output_shardings()
  return ([] if param_shardings is None else param_shardings,
          [] if result_shardings is None else result_shardings)
def get_pspec_from_executable(
    executable, mesh: Mesh
) -> tuple[tuple[PartitionSpec, ...], tuple[PartitionSpec, ...]]:
  """Returns the executable's input and output shardings as PartitionSpecs.

  Every op sharding reported by the executable is expanded with
  `sharding_impls.parse_flatten_op_sharding` against `mesh`, and the results
  are concatenated in order.
  """
  input_op_s, output_op_s = get_op_sharding_from_executable(executable)
  flatten = sharding_impls.parse_flatten_op_sharding
  in_pspec = tuple(spec for op_s in input_op_s for spec in flatten(op_s, mesh))
  out_pspec = tuple(spec for op_s in output_op_s for spec in flatten(op_s, mesh))
  return in_pspec, out_pspec
def get_out_shardings_from_executable(
    xla_executable,
    device_list: xc.DeviceList,
    num_out_avals: int,
    num_ordered_effects: int,
) -> Sequence[sharding_impls.GSPMDSharding] | None:
  """Returns the output shardings chosen by XLA as `GSPMDSharding`s.

  Args:
    xla_executable: the compiled executable to query.
    device_list: devices the executable was compiled for.
    num_out_avals: number of outputs, not counting ordered-effect tokens.
    num_ordered_effects: number of leading entries that belong to ordered
      effects and are dropped.

  Returns:
    One sharding per output, or None when the executable reports no output
    shardings or when there are no outputs and a single sharding is reported.
  """
  assert isinstance(device_list, xc.DeviceList)

  # Best effort: when memory kinds cannot be queried, fall back to None for
  # every output. `Exception` (rather than a bare except) keeps
  # KeyboardInterrupt and SystemExit propagating.
  try:
    omk = xla_executable.get_output_memory_kinds()[0]
    if num_ordered_effects > 0:
      omk = omk[num_ordered_effects:]
  except Exception:
    omk = [None] * num_out_avals

  assert len(omk) == num_out_avals, (len(omk), num_out_avals)

  # When the device assignment only has 1 device, SPMD partitioner will not run.
  # Hence the op shardings will not be set on the `hlo_module`.
  if len(device_list) == 1:
    return [sharding_impls.GSPMDSharding.get_replicated(device_list, memory_kind=mk)
            for mk in omk]

  out_op_shardings: Sequence[xc.OpSharding]
  _, out_op_shardings = get_op_sharding_from_executable(xla_executable)
  if not out_op_shardings:
    return None

  if num_ordered_effects > 0:
    out_op_shardings = out_op_shardings[num_ordered_effects:]

  # This means that there are no outputs for JAX but for XLA there is an empty
  # tuple output which gets a replicated sharding.
  if num_out_avals == 0 and len(out_op_shardings) == 1:
    return None

  # This condition happens when all the elements in the output tuple have the
  # same sharding, so XLA decides to run the `FusionTupleDeduplicator` to
  # put the sharding on ROOT instead of the tuple.
  # TODO(b/245667823): Remove this when XLA fixes this.
  if len(out_op_shardings) == 1 and len(out_op_shardings) < num_out_avals:
    out_op_shardings = list(out_op_shardings) * num_out_avals

  assert len(out_op_shardings) == num_out_avals == len(omk), (
      len(out_op_shardings), num_out_avals, len(omk))

  return [sharding_impls.GSPMDSharding(device_list, os, memory_kind=mk)
          for os, mk in safe_zip(out_op_shardings, omk)]
def _get_in_shardings_from_xla(
    xla_executable, device_list: xc.DeviceList, num_in_avals: int,
    num_ordered_effects: int
) -> Sequence[GSPMDSharding] | None:
  """Returns the input shardings chosen by XLA, or None if none are reported.

  The first `num_ordered_effects` parameter shardings are dropped.
  """
  assert isinstance(device_list, xc.DeviceList)
  # When the device assignment only has 1 device, SPMD partitioner will not run.
  # Hence the op shardings will not be set on the `hlo_module`.
  if len(device_list) == 1:
    return [GSPMDSharding.get_replicated(device_list)] * num_in_avals

  param_op_shardings = get_op_sharding_from_executable(xla_executable)[0]
  if not param_op_shardings:
    return None

  if num_ordered_effects > 0:
    param_op_shardings = param_op_shardings[num_ordered_effects:]
  num_found = len(param_op_shardings)
  assert num_found == num_in_avals, (num_found, num_in_avals)

  return [GSPMDSharding(device_list, op_s) for op_s in param_op_shardings]
# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
# without mesh.
def _get_mesh_pspec_shardings_from_executable(
    xla_executable, mesh: Mesh
) -> tuple[Sequence[NamedSharding], Sequence[NamedSharding]]:
  """Returns the executable's input and output shardings over `mesh`.

  The PartitionSpecs come from `get_pspec_from_executable`; each one is
  wrapped in a `NamedSharding` on `mesh`.
  """
  def on_mesh(pspecs):
    return [NamedSharding(mesh, pspec) for pspec in pspecs]

  in_pspec, out_pspec = get_pspec_from_executable(xla_executable, mesh)
  return on_mesh(in_pspec), on_mesh(out_pspec)
# Registry keyed by sharding class. Each value is a function called as
# `handler(out_s, out_aval, orig_in_s)` that converts a GSPMDSharding into a
# sharding of the key's class. `_get_out_sharding_from_orig_sharding` looks up
# the handler by `type(orig_in_s)`.
_orig_out_sharding_handlers: dict[Any, Any] = {}
def _gspmd_to_named_sharding(
    out_s: GSPMDSharding, out_aval, orig_in_s: NamedSharding) -> NamedSharding:
  """Converts `out_s` to a NamedSharding.

  The mesh is built from `out_aval`'s mesh and `out_s`'s device assignment
  when `out_aval` is given and its mesh is non-empty with no manual axis.
  Otherwise `orig_in_s.mesh` is used.
  """
  assert isinstance(out_s, GSPMDSharding)
  assert isinstance(orig_in_s, NamedSharding)
  assert isinstance(orig_in_s.mesh, Mesh)
  if out_aval is None:
    use_aval_mesh = False
  else:
    aval_mesh = out_aval.sharding.mesh
    use_aval_mesh = not (aval_mesh.empty or aval_mesh._any_axis_manual)
  if use_aval_mesh:
    target_mesh = _abstract_to_concrete_mesh(
        out_aval.sharding.mesh, out_s._device_assignment)
  else:
    target_mesh = orig_in_s.mesh
  return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, target_mesh)

_orig_out_sharding_handlers[NamedSharding] = _gspmd_to_named_sharding
def _gspmd_to_single_device_sharding(
    out_s: GSPMDSharding, out_aval, orig_in_s: SingleDeviceSharding
) -> SingleDeviceSharding:
  """Converts `out_s` to a SingleDeviceSharding on its first device.

  The memory kind of `out_s` is preserved; `out_aval` is not used.
  """
  assert isinstance(out_s, GSPMDSharding)
  assert isinstance(orig_in_s, SingleDeviceSharding)
  first_device = out_s._device_assignment[0]
  return SingleDeviceSharding(first_device, memory_kind=out_s.memory_kind)

_orig_out_sharding_handlers[SingleDeviceSharding] = _gspmd_to_single_device_sharding
def _get_out_sharding_from_orig_sharding(
    out_shardings, out_avals, orig_in_s, orig_aval):
  """Re-expresses GSPMD output shardings using the type of `orig_in_s`.

  Args:
    out_shardings: shardings to convert; entries that are not
      `GSPMDSharding`, or that belong to a token aval, pass through unchanged.
    out_avals: avals matching `out_shardings` (entries may be None).
    orig_in_s: reference sharding. Its type must have a handler registered in
      `_orig_out_sharding_handlers`.
    orig_aval: aval of `orig_in_s`, or None.

  Returns:
    A list with one sharding per entry of `out_shardings`.

  Raises:
    IndivisibleError: re-raised if the handler raises it.
  """
  out: list[JSharding] = []
  orig_handler = _orig_out_sharding_handlers[type(orig_in_s)]
  for o, out_aval in safe_zip(out_shardings, out_avals):
    if (isinstance(o, sharding_impls.GSPMDSharding) and
        out_aval is not core.abstract_token):
      # TODO(yashkatariya): Remove this condition and ask users to drop into
      # explicit mode.
      # Reuse the reference sharding object itself when the output is
      # equivalent to it.
      if (orig_aval is not None and out_aval is not None
          and out_aval.ndim == orig_aval.ndim
          and isinstance(orig_in_s, NamedSharding)
          and out_aval.sharding.mesh == orig_in_s.mesh.abstract_mesh
          and o.is_equivalent_to(orig_in_s, orig_aval.ndim)):
        out.append(orig_in_s)
      else:
        try:
          out.append(orig_handler(o, out_aval, orig_in_s))
        except IndivisibleError:
          raise
        except Exception:
          # Best effort: keep the GSPMD sharding when conversion fails.
          # KeyboardInterrupt and SystemExit are not caught here.
          out.append(o)
    else:
      out.append(o)
  return out
def maybe_recover_user_shardings(
    old_shardings, new_shardings, old_avals, new_avals,
    intermediate_shardings=None, context_mesh: Mesh | None = None):
  """Tries to turn GSPMD shardings in `new_shardings` into user-facing ones.

  If `new_shardings` has no `GSPMDSharding`, it is returned as is. Otherwise
  the first sharding with a registered handler is used as the reference for
  conversion, searching `old_shardings`, then `intermediate_shardings`, then
  `new_shardings` itself. If none is found and `context_mesh` is non-empty,
  GSPMD entries are converted via that mesh. Failing all of that,
  `new_shardings` is returned unchanged.
  """
  if not any(isinstance(s, sharding_impls.GSPMDSharding)
             for s in new_shardings):
    return new_shardings

  def has_handler(sharding):
    return sharding is not None and type(sharding) in _orig_out_sharding_handlers

  for old_s, old_aval in safe_zip(old_shardings, old_avals):
    if has_handler(old_s):
      return _get_out_sharding_from_orig_sharding(
          new_shardings, new_avals, old_s, old_aval)

  if intermediate_shardings is not None:
    for inner_s in intermediate_shardings:
      if has_handler(inner_s):
        return _get_out_sharding_from_orig_sharding(
            new_shardings, [None] * len(new_shardings), inner_s, None)

  # For nullary cases like: `jit(lambda: ..., out_shardings=(None, sharding))`
  for new_s in new_shardings:
    if has_handler(new_s):
      return _get_out_sharding_from_orig_sharding(
          new_shardings, new_avals, new_s, None)

  if context_mesh is None or context_mesh.empty:
    return new_shardings
  return [sharding_impls._gspmd_to_named_sharding_via_mesh(s, context_mesh)
          if isinstance(s, GSPMDSharding) else s
          for s in new_shardings]
def is_user_xla_layout_equal(ul: Layout | AutoLayout,
                             xl: Layout) -> bool:
  """Returns whether the user layout `ul` agrees with the XLA layout `xl`.

  A user `Layout` without tiling is compared on `major_to_minor` only.
  Anything else is compared with `==`.
  """
  tiling_unspecified = isinstance(ul, Layout) and not ul.tiling
  if tiling_unspecified:
    return ul.major_to_minor == xl.major_to_minor
  return ul == xl
def _get_layouts_from_executable(
xla_executable, in_layouts, out_layouts, num_ordered_effects
) -> tuple[Sequence[Layout | None], Sequence[Layout | None]]:
try:
in_layouts_xla = xla_executable.get_parameter_layouts()
out_layouts_xla = xla_executable.get_output_layouts()
except:
return (None,) * len(in_layouts), (None,) * len(out_layouts)
if num_ordered_effects > 0:
in_layouts_xla = in_layouts_xla[num_ordered_effects:]
out_layouts_xla = out_layouts_xla[num_ordered_effects:]
new_in_layouts = []
for x, l in safe_zip(in_layouts_xla, in_layouts):
x = Layout.from_pjrt_layout(x)
if isinstance(l, Layout) and not is_user_xla_layout_equal(l, x):
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {l} "
f"(User input layout)")
# Always append the XLA layout because it has the full information
# (tiling, etc) even if the user layout does not specify tiling.
new_in_layouts.append(x)
new_out_layouts = []
for x, l in safe_zip(out_layouts_xla, out_layouts):
x = Layout.from_pjrt_layout(x)
if isinstance(l, Layout) and not is_user_xla_layout_equal(l, x):
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {l} "
f"(User output layout)")
# Always append the XLA layout because it has the full information
# (tiling, etc) even if the user layout does not specify tiling.
new_out_layouts.append(x)
assert all(isinstance(i, Layout) for i in new_in_layouts)
assert all(isinstance(o, Layout) for o in new_out_layouts)
return new_in_layouts, new_out_layouts
def get_logical_mesh_ids(mesh_shape):
  """Returns consecutive ids `0..prod(mesh_shape)-1` arranged as `mesh_shape`."""
  num_ids = math.prod(mesh_shape)
  return np.reshape(np.arange(num_ids), mesh_shape)
def create_compile_options(
    computation, mesh, tuple_args, auto_spmd_lowering,
    allow_prop_to_inputs, allow_prop_to_outputs, backend,
    np_dev, compiler_options):
  """Builds the compile options for an SPMD-partitioned computation.

  Uses one replica and `np_dev.size` partitions. Note that the
  `"fdo_profile"` entry, if present, is popped from `compiler_options` in
  place and passed separately.
  """
  replica_count = 1
  partition_count = np_dev.size
  # Removed from the dict before it is passed on as the env overrides.
  fdo_profile = compiler_options.pop("fdo_profile", None)

  compile_options = compiler.get_compile_options(
      num_replicas=replica_count,
      num_partitions=partition_count,
      device_assignment=np_dev.reshape((replica_count, partition_count)),
      use_spmd_partitioning=True,
      use_auto_spmd_partitioning=auto_spmd_lowering,
      env_options_overrides=compiler_options,
      fdo_profile=fdo_profile,
      detailed_logging=compiler.use_detailed_logging(computation),
      backend=backend,
  )

  build_opts = compile_options.executable_build_options
  if auto_spmd_lowering:
    assert mesh is not None
    axis_sizes = list(mesh.shape.values())
    build_opts.auto_spmd_partitioning_mesh_shape = axis_sizes
    build_opts.auto_spmd_partitioning_mesh_ids = (
        get_logical_mesh_ids(axis_sizes).reshape(-1))
  compile_options.parameter_is_tupled_arguments = tuple_args
  build_opts.allow_spmd_sharding_propagation_to_parameters = list(allow_prop_to_inputs)
  build_opts.allow_spmd_sharding_propagation_to_output = list(allow_prop_to_outputs)
  return compile_options
@weakref_lru_cache
def _cached_compilation(computation, name, mesh,
                        tuple_args, auto_spmd_lowering, allow_prop_to_inputs,
                        allow_prop_to_outputs, host_callbacks, backend,
                        da, compiler_options_kvs, pgle_profiler):
  """Compiles `computation` for the devices in `da` and returns the executable.

  The compile time is logged under `dispatch.BACKEND_COMPILE_EVENT`.
  """
  # One would normally just write: np.array(da)
  # The formulation below is substantially faster if there are many devices.
  device_at = np.vectorize(lambda idx: da[idx], otypes=[object])
  np_devices = device_at(np.arange(len(da)))

  compile_options = create_compile_options(
      computation, mesh, tuple_args, auto_spmd_lowering,
      allow_prop_to_inputs, allow_prop_to_outputs, backend,
      np_devices, dict(compiler_options_kvs))

  with dispatch.log_elapsed_time(
      "Finished XLA compilation of {fun_name} in {elapsed_time:.9f} sec",
      fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
    return compiler.compile_or_get_cached(
        backend, computation, np_devices, compile_options, host_callbacks,
        da, pgle_profiler)
def _maybe_get_and_check_in_shardings(
    xla_executable, in_shardings, device_list, global_in_avals,
    num_ordered_effects):
  """Resolves input shardings against what XLA chose.

  For an unspecified input sharding, the sharding reported by XLA is used
  (converted to a logical sharding for extended dtypes). For a specified
  one, XLA's sharding must match it and the original is kept. If XLA reports
  no input shardings, `in_shardings` is returned unchanged.

  Raises:
    AssertionError: if XLA's sharding differs from a specified sharding,
      unless the dtype is extended or XLA's sharding is manual.
  """
  shardings_from_xla = _get_in_shardings_from_xla(
      xla_executable, device_list, len(global_in_avals), num_ordered_effects)
  if shardings_from_xla is None:
    return in_shardings

  resolved = []
  for xla_s, user_s, aval in safe_zip(shardings_from_xla, in_shardings,
                                      global_in_avals):
    if isinstance(user_s, UnspecifiedValue):
      has_extended_dtype = (aval is not core.abstract_token and
                            dtypes.issubdtype(aval.dtype, dtypes.extended))
      if has_extended_dtype:
        xla_s = sharding_impls.logical_sharding(aval.shape, aval.dtype, xla_s)
      resolved.append(xla_s)
      continue

    xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
    user_hlo_s = user_s._to_xla_hlo_sharding(aval.ndim)
    # MANUAL HloSharding comes from other partitioning frameworks.
    overridden = (
        not dtypes.issubdtype(aval.dtype, dtypes.extended)
        and not xla_hlo_s.is_manual()
        and not op_shardings.are_hlo_shardings_equal(xla_hlo_s, user_hlo_s))
    if overridden:
      raise AssertionError(
          f"Unexpected XLA sharding override: (XLA) {xla_s} != {user_s} "
          "(User sharding)")
    resolved.append(user_s)

  return maybe_recover_user_shardings(
      in_shardings, resolved, global_in_avals, global_in_avals)
def _maybe_get_and_check_out_shardings(
    xla_executable, out_shardings, device_list, global_out_avals,
    num_ordered_effects
):
  """Resolves output shardings against what XLA chose.

  * Unspecified sharding: XLA's sharding is used.
  * Sharding containing unconstrained dims: XLA's sharding is converted to a
    NamedSharding when possible, else used as is.
  * Otherwise: XLA's sharding (and memory kind) must match the user's, and
    the user's sharding is kept.

  If XLA reports no output shardings, `out_shardings` is returned unchanged.

  Raises:
    AssertionError: if XLA's sharding or memory kind differs from a fully
      specified user sharding, unless the dtype is extended, XLA's sharding
      is manual, or the aval has size 0.
  """
  out_shardings_xla = get_out_shardings_from_executable(
      xla_executable, device_list, len(global_out_avals),
      num_ordered_effects)
  if out_shardings_xla is None:
    return out_shardings

  new_out_shardings = []
  for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
                                    global_out_avals):
    if isinstance(orig, UnspecifiedValue):
      if (aval is not core.abstract_token and
          dtypes.issubdtype(aval.dtype, dtypes.extended)):
        xla_s = sharding_impls.logical_sharding(aval.shape, aval.dtype, xla_s)
      new_out_shardings.append(xla_s)
    elif mlir.contains_unconstrained(orig):
      if (aval is not core.abstract_token and
          dtypes.issubdtype(aval.dtype, dtypes.extended)):
        xla_s = sharding_impls.logical_sharding(aval.shape, aval.dtype, xla_s)
      try:
        new_out_shardings.append(_gspmd_to_named_sharding(xla_s, aval, orig))  # pyrefly: ignore[bad-argument-type]
      except Exception:
        # Best effort: keep XLA's sharding when conversion fails.
        # KeyboardInterrupt and SystemExit are not caught here.
        new_out_shardings.append(xla_s)
    else:
      xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
      orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)
      # MANUAL HloSharding comes from other partitioning frameworks.
      if (not dtypes.issubdtype(aval.dtype, dtypes.extended) and
          not xla_hlo_s.is_manual() and aval.size != 0 and
          (not op_shardings.are_hlo_shardings_equal(xla_hlo_s, orig_hlo_s) or
           xla_s.memory_kind != orig.memory_kind)):
        raise AssertionError(
            f"Unexpected XLA sharding override: (XLA) {xla_s} != {orig} "
            "(User sharding)")
      new_out_shardings.append(orig)
  return new_out_shardings
def finalize_shardings(shardings, device_assignment):
  """Replaces GSPMD shardings with single-device ones on a one-device run.

  With any other number of devices, `shardings` is returned unchanged.
  """
  if len(device_assignment) != 1:
    return shardings
  finalized = []
  for sharding in shardings:
    if isinstance(sharding, GSPMDSharding):
      sharding = make_single_device_sharding(
          device_assignment[0], memory_kind=sharding.memory_kind)
    finalized.append(sharding)
  return finalized
def get_prop_to_input_output(in_shardings, out_shardings,
                             num_ordered_effects):
  """Returns per-parameter and per-output sharding-propagation flags.

  Both tuples start with `num_ordered_effects` False entries. After that, a
  flag is True for inputs whose sharding is unspecified or `AUTO`, and for
  outputs whose sharding is unspecified, `AUTO`, or contains unconstrained
  dims.
  """
  effect_flags = (False,) * num_ordered_effects
  input_flags = tuple(
      isinstance(s, (UnspecifiedValue, AUTO)) for s in in_shardings)
  output_flags = tuple(
      isinstance(s, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(s)
      for s in out_shardings)
  return effect_flags + input_flags, effect_flags + output_flags
def maybe_concretize_mesh(sharding, da: xc.DeviceList):
  """Gives a NamedSharding on an AbstractMesh a concrete mesh over `da`.

  Any other sharding is returned unchanged.

  Raises:
    ValueError: if the abstract mesh size differs from `len(da)`.
  """
  on_abstract_mesh = (isinstance(sharding, NamedSharding) and
                      isinstance(sharding.mesh, AbstractMesh))
  if not on_abstract_mesh:
    return sharding
  if sharding.mesh.size != len(da):
    raise ValueError(
        f"The size of abstract mesh {sharding.mesh.size} in {sharding} must"
        f" match the length of device assignment: {len(da)}")
  concrete_mesh = _abstract_to_concrete_mesh(sharding.mesh, da)
  return sharding.update(mesh=concrete_mesh)
# Holds a compiled XLA executable together with the avals, shardings, layouts
# and effect metadata needed to build a callable from it. `from_hlo` compiles
# a module and builds one; `load` wraps it in a `MeshExecutable`.
@dataclasses.dataclass
class UnloadedMeshExecutable:
  xla_executable: Any
  device_list: xc.DeviceList
  backend: xc.Client
  input_avals: Sequence[ShapedArray]
  input_shardings: Sequence[JSharding]
  output_avals: Sequence[ShapedArray]
  output_shardings: Sequence[JSharding]
  committed: bool
  name: str
  unordered_effects: list[core.Effect]
  ordered_effects: list[core.Effect]
  keepalive: Sequence[Any]
  host_callbacks: Sequence[Any]
  kept_var_idx: set[int]
  mut: MutationData | None
  auto_spmd_lowering: bool
  # Layouts as reported by XLA for the inputs.
  xla_in_layouts: Sequence[Layout | None]
  # Same as `xla_in_layouts`, but with default layouts replaced by None.
  dispatch_in_layouts: Sequence[Layout | None]
  xla_out_layouts: Sequence[Layout | None]
  all_args_info: AllArgsInfo | None
  pgle_profiler: profiler.PGLEProfiler | None

  # Builds the `ExecuteReplicated` callable from the input and output
  # handlers. Only whether any host callbacks exist is passed on, not the
  # callbacks themselves.
  def build_unsafe_call(self):
    handle_args = InputsHandler(self.input_shardings, self.dispatch_in_layouts)
    handle_outs = global_avals_to_results_handler(
        self.output_avals, self.output_shardings, self.committed)

    unsafe_call = ExecuteReplicated(
        self.xla_executable, self.name, self.backend, handle_args,
        handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
        bool(self.host_callbacks), self.kept_var_idx, self.mut,
        self.pgle_profiler)
    return unsafe_call

  # Wraps this object in a `MeshExecutable`. Note that the bound method
  # `build_unsafe_call` is passed, not its result.
  def load(self) -> MeshExecutable:
    return MeshExecutable(self.xla_executable, self.build_unsafe_call,
                          self.input_avals, self.output_avals,
                          self.input_shardings, self.output_shardings,
                          self.auto_spmd_lowering, self.kept_var_idx,
                          self.xla_in_layouts, self.dispatch_in_layouts,
                          self.xla_out_layouts, self.mut, self.all_args_info,
                          self)

  # Compiles `hlo` for `device_list`, reconciles the user's shardings and
  # layouts with what XLA chose, and returns the loaded `MeshExecutable`.
  # Raises RuntimeError when `device_list` is None.
  @staticmethod
  def from_hlo(name: str,
               hlo: ir.Module,
               global_in_avals: Sequence[ShapedArray],
               global_out_avals: Sequence[ShapedArray],
               in_shardings: Sequence[JSharding | AUTO],
               out_shardings: Sequence[(JSharding | AUTO | UnspecifiedValue)],
               tuple_args: bool,
               auto_spmd_lowering: bool,
               unordered_effects: list[core.Effect],
               ordered_effects: list[core.Effect],
               host_callbacks: list[Any],
               keepalive: Any,
               kept_var_idx: set[int],
               backend: xc.Client,
               device_list: xc.DeviceList | None,
               committed: bool,
               in_layouts: MaybeLayout,
               out_layouts: MaybeLayout,
               compiler_options_kvs: tuple[tuple[str, Any], ...],
               num_devices: int,
               mut: MutationData | None = None,
               shape_poly_state: mlir.ShapePolyLoweringState | None = None,
               all_args_info: AllArgsInfo | None = None,
               pgle_profiler: profiler.PGLEProfiler | None = None,
               intermediate_shardings: Sequence[JSharding] | None = None,
               context_mesh: Mesh | None = None,
  ) -> MeshExecutable:
    del num_devices  # For compilation, we have an actual device_assignment
    if device_list is None:
      raise RuntimeError(
          "device_assignment cannot be `None` during compilation. Please pass a"
          " tuple of devices to `.compile(device_assignment=)`")

    assert isinstance(device_list, xc.DeviceList)
    # Shardings still on an AbstractMesh get a concrete mesh over device_list.
    in_shardings = tuple(maybe_concretize_mesh(i, device_list)
                         for i in in_shardings)
    out_shardings = tuple(maybe_concretize_mesh(o, device_list)
                          for o in out_shardings)

    if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
      hlo = mlir.refine_polymorphic_shapes(hlo)

    allow_prop_to_inputs, allow_prop_to_outputs = get_prop_to_input_output(
        in_shardings, out_shardings, len(ordered_effects))

    # With auto SPMD lowering, the mesh comes from the first AUTO sharding.
    mesh = None
    if auto_spmd_lowering:
      for i in it.chain(in_shardings, out_shardings):
        if isinstance(i, AUTO):
          mesh = i.mesh
          break

    util.test_event("pxla_cached_compilation")
    xla_executable = _cached_compilation(
        hlo, name, mesh, tuple_args, auto_spmd_lowering,
        allow_prop_to_inputs, allow_prop_to_outputs, tuple(host_callbacks),
        backend, device_list, compiler_options_kvs, pgle_profiler)

    if auto_spmd_lowering:
      assert mesh is not None
      # AUTO entries are replaced by the shardings XLA picked; all other
      # entries are kept.
      in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
          xla_executable, mesh)
      in_shardings = [x if isinstance(i, AUTO) else i
                      for x, i in safe_zip(in_shardings_xla, in_shardings)]
      out_shardings = [x if isinstance(o, AUTO) else o
                       for x, o in safe_zip(out_shardings_xla, out_shardings)]
    else:
      assert mesh is None
      in_shardings = _maybe_get_and_check_in_shardings(
          xla_executable, in_shardings, device_list, global_in_avals,
          len(ordered_effects))
      out_shardings = _maybe_get_and_check_out_shardings(
          xla_executable, out_shardings, device_list, global_out_avals,
          len(ordered_effects))

    # xla_in_layouts are all either None or Layout. Even default
    # layout are concrete layouts and they are used in `compiled.input_formats`
    # to return concrete layouts to users.
    # `dispatch_in_layouts` replaces default layouts with `None` to simplify
    # dispatch logic downstream.
    xla_in_layouts, xla_out_layouts = _get_layouts_from_executable(
        xla_executable, in_layouts, out_layouts, len(ordered_effects))
    # The user-provided layouts must not be used past this point.
    del in_layouts, out_layouts
    dispatch_in_layouts = [
        None if is_default_layout(l, s, a) else l
        for l, s, a, in safe_zip(xla_in_layouts, in_shardings, global_in_avals)
    ]

    # Uses the resolved in_shardings as the reference, so this must run
    # before `finalize_shardings` rewrites them below.
    out_shardings = maybe_recover_user_shardings(
        in_shardings, out_shardings, global_in_avals, global_out_avals,
        intermediate_shardings, context_mesh)

    in_shardings = finalize_shardings(in_shardings, device_list)
    out_shardings = finalize_shardings(out_shardings, device_list)

    return UnloadedMeshExecutable(
        xla_executable=xla_executable,
        device_list=device_list,
        backend=backend,
        input_avals=global_in_avals,
        input_shardings=in_shardings,
        output_avals=global_out_avals,
        output_shardings=out_shardings,
        committed=committed,
        name=name,
        unordered_effects=unordered_effects,
        ordered_effects=ordered_effects,
        keepalive=keepalive,
        host_callbacks=host_callbacks,
        kept_var_idx=kept_var_idx,
        mut=mut,
        auto_spmd_lowering=auto_spmd_lowering,
        xla_in_layouts=xla_in_layouts,
        dispatch_in_layouts=dispatch_in_layouts,
        xla_out_layouts=xla_out_layouts,
        all_args_info=all_args_info,
        pgle_profiler=pgle_profiler).load()
class MeshExecutableFastpathData(NamedTuple):
  """Data handed back by the cache-miss callback of the C++ dispatch path.

  Built positionally in `MeshExecutable.create_cpp_call`, so the order of the
  fields below is part of the interface.
  """
  # The loaded XLA executable.
  xla_executable: xc.LoadedExecutable
  # Treedef obtained by re-flattening the outputs with the dispatch registry.
  out_pytree_def: Any
  # Input shardings; avals with an extended dtype use their physical sharding.
  in_shardings: Sequence[JSharding]
  # Output shardings of the executable.
  out_shardings: Sequence[JSharding]
  # `.aval` of every flat output array.
  out_avals: Sequence[ShapedArray]
  # `._committed` of every flat output array.
  out_committed: Sequence[bool]
  # One flag per (const arg + flat arg) index: True if that index is kept.
  kept_var_bitvec: Iterable[bool]
  # The executable's dispatch input layouts (may contain None entries).
  in_device_local_layouts: Sequence[Layout | None]
  # Constant arguments taken from the compiled-call params.
  const_args: Sequence[ArrayLike]
def clear_in_memory_compilation_cache() -> None:
  """Empty the in-memory compilation cache.

  Every executable cached by the `_cached_compilation` function is dropped.
  """
  _cached_compilation.cache_clear()
@dataclasses.dataclass(frozen=True, kw_only=True)
class JitGlobalCppCacheKeys:
donate_argnums: tuple[int, ...] | None = None
donate_argnames: tuple[str, ...] | None = None
device: xc.Device | None = None
backend: str | None = None
in_shardings_treedef: PyTreeDef | None = None
in_shardings_leaves: tuple[Any, ...] | None = None
out_shardings_treedef: PyTreeDef | None = None
out_shardings_leaves: tuple[Any, ...] | None = None
in_layouts_treedef: PyTreeDef | None = None
in_layouts_leaves: tuple[Any, ...] | None = None
out_layouts_treedef: PyTreeDef | None = None
out_layouts_leaves: tuple[Any, ...] | None = None
compiler_options_kvs: tuple[tuple[str, Any], ...] | None = None
@functools.cached_property
def contains_explicit_attributes(self):
return (self.donate_argnums is not None or
self.donate_argnames is not None or
self.device is not None or
self.backend is not None or
any(not isinstance(i, UnspecifiedValue) for i in (self.in_shardings_leaves or [])) or
any(not isinstance(o, UnspecifiedValue) for o in (self.out_shardings_leaves or [])) or
any(i is not None for i in (self.in_layouts_leaves or [])) or
any(o is not None for o in (self.out_layouts_leaves or [])) or
self.compiler_options_kvs)
def reflatten_outputs_for_dispatch(out_tree, out_flat):
  """Re-flatten outputs using the dispatch-specific pytree registry.

  The outputs reach dispatch flattened with the default pytree registry, so
  they are first rebuilt from `out_tree` and then flattened again with
  `tree_util.dispatch_registry`.

  Returns:
    Whatever `dispatch_registry.flatten` returns for the rebuilt outputs;
    callers unpack it as `(flat_outputs, treedef)`.
  """
  rebuilt = tree_util.tree_unflatten(out_tree, out_flat)
  registry = tree_util.dispatch_registry
  return registry.flatten(rebuilt, None)
class MeshExecutable(stages.Executable):
  """An XLA executable bundled with the metadata needed to call it.

  `call` validates the arguments (avals, shardings, layouts) and then runs the
  executable through `unsafe_call`; `create_cpp_call` builds the C++ dispatch
  entry point when the executable is eligible for it.
  """

  __slots__ = [
      "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
      "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering",
      "_kept_var_idx", "_xla_in_layouts", "_dispatch_in_layouts",
      "_xla_out_layouts", "_mut", "_all_args_info", "_unloaded_executable",
  ]

  def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
               in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
               xla_in_layouts, dispatch_in_layouts, xla_out_layouts, mut,
               all_args_info: AllArgsInfo | None = None,
               unloaded_executable=None):
    # The executable and the factory for its call wrapper; the wrapper itself
    # is only built on first access of `unsafe_call`.
    self.xla_executable = xla_executable
    self.build_unsafe_call = build_unsafe_call
    self._unsafe_call = None

    # in_avals is a list of global and local avals. Aval is global if input
    # is a GDA or jax.Array else local.
    self.in_avals = in_avals  # has const_args, but not dead args
    self.out_avals = out_avals

    self._in_shardings = in_shardings  # has const_args, but not dead args
    self._out_shardings = out_shardings
    self._auto_spmd_lowering = auto_spmd_lowering

    # indices include const_args and dead args
    self._kept_var_idx = kept_var_idx

    self._xla_in_layouts = xla_in_layouts  # has const_args, but not dead args
    self._dispatch_in_layouts = dispatch_in_layouts
    self._xla_out_layouts = xla_out_layouts

    self._mut = mut
    self._all_args_info = all_args_info  # has const_args and also dead args
    self._unloaded_executable = unloaded_executable

  @property
  def unsafe_call(self) -> Callable[..., Any]:
    """The call wrapper, built on first access and cached afterwards."""
    built = self._unsafe_call
    if built is None:
      built = self._unsafe_call = self.build_unsafe_call()
    return built

  # -- stages.Executable overrides

  def xla_extension_executable(self):
    """Return the underlying XLA executable."""
    return self.xla_executable

  def call(self, *args):
    """Validate `args` against the compiled signature, then execute.

    `args` are the arguments before dead-argument elimination. Their avals are
    checked first, then the shardings and layouts of the kept arguments, and
    finally the full `args` are forwarded to `unsafe_call`.
    """
    kept = self._kept_var_idx
    info = self._all_args_info

    args_after_dce = tuple(a for i, a in enumerate(args) if i in kept)
    if info is not None and info.debug_info.arg_names is not None:
      arg_names_after_dce = tuple(
          n for i, n in enumerate(info.debug_info.arg_names) if i in kept)
    else:
      arg_names_after_dce = ("",) * len(args_after_dce)

    if info is None:
      # We can only check the args after DCE
      check_arg_avals_for_call(self.in_avals,
                               map(core.shaped_abstractify, args_after_dce),
                               core.DebugInfo("MeshExecutable", "<unknown>",
                                              arg_names_after_dce, None))
    else:
      # We check all args before DCE
      check_arg_avals_for_call(info.in_avals,
                               map(core.shaped_abstractify, args),
                               info.debug_info)

    if self._mut:
      # The mutable inputs are appended (unnamed) after the kept arguments.
      args_after_dce = [*args_after_dce, *self._mut.in_mut]
      arg_names_after_dce += (("",) * len(self._mut.in_mut))
    check_array_xla_sharding_layout_match(
        args_after_dce, self._in_shardings, self._xla_in_layouts,
        arg_names_after_dce)
    return self.unsafe_call(*args)

  def create_cpp_call(self, params: stages.CompiledCallParams):
    """Build the C++ `pjit` callable for this executable, if eligible.

    Returns None unless `unsafe_call` is an `ExecuteReplicated` that has
    neither unordered effects nor host callbacks.
    """
    unsafe_call = self.unsafe_call
    if (not isinstance(unsafe_call, ExecuteReplicated) or
        unsafe_call.has_unordered_effects or
        unsafe_call.has_host_callbacks):
      return None

    def aot_cache_miss(*args, **kwargs):
      # args do not include the const args.
      # See https://docs.jax.dev/en/latest/internals/constants.html.
      outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)

      # The trailing False in every return means: do not remove cache entry.
      if params.is_high:
        return outs, None, False

      out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch(
          params.out_tree, out_flat)
      use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat)
                      and not self._mut)
      if not use_fastpath:
        return outs, None, False

      out_avals = [o.aval for o in out_flat]
      out_committed = [o._committed for o in out_flat]
      num_inputs = len(params.const_args) + len(args_flat)
      kept_var_bitvec = [i in self._kept_var_idx for i in range(num_inputs)]
      # Avals with an extended dtype are given their physical sharding.
      in_shardings = [
          sharding_impls.physical_sharding(a, s)
          if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
          else s
          for s, a in zip(self._in_shardings, self.in_avals)
      ]
      fastpath_data = MeshExecutableFastpathData(
          self.xla_executable, out_tree_dispatch, in_shardings,
          self._out_shardings, out_avals, out_committed, kept_var_bitvec,
          self._dispatch_in_layouts, params.const_args)
      return outs, fastpath_data, False

    return xc._xla.pjit(
        unsafe_call.name, None, aot_cache_miss, [], [],
        JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
def cc_shard_arg(x, sharding, layout):
  """Shard the single argument `x` with the given sharding and layout.

  Wraps `shard_args`, which works on parallel lists, for the one-argument
  case; the copy semantics passed are `REUSE_INPUT`.
  """
  sharded = shard_args(
      [sharding], [layout], [xc.ArrayCopySemantics.REUSE_INPUT], [x])
  return sharded[0]
def check_arg_avals_for_call(ref_avals, arg_avals,
                             jaxpr_debug_info: core.DebugInfo):
  """Check that call-time avals match the avals the computation was compiled for.

  Args:
    ref_avals: sized collection of avals used at compile time.
    arg_avals: iterable of avals of the arguments passed at call time. Lazy
      iterables (generators, `map` objects) are accepted.
    jaxpr_debug_info: debug info used to look up argument names for messages.

  Raises:
    TypeError: if the number of arguments differs, or if any argument's
      (shape, dtype) differs from the compiled one. At most five mismatches
      are listed in the message.
  """
  # Materialize once so a lazy iterable can be both measured and iterated.
  arg_avals = list(arg_avals)
  if len(ref_avals) != len(arg_avals):
    raise TypeError(
        f"Computation compiled for {len(ref_avals)} inputs "
        f"but called with {len(arg_avals)}")

  arg_names = [f"'{name}'"
               for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))]

  # Don't compare shardings of avals because you can lower with
  # numpy arrays + in_shardings and call compiled executable with
  # sharded arrays. We also have sharding checks downstream.
  errors = [
      f"Argument {name} compiled with {ref_aval.str_short()} and called "
      f"with {arg_aval.str_short()}"
      for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names)
      if (ref_aval.shape, ref_aval.dtype) != (arg_aval.shape, arg_aval.dtype)
  ]
  if not errors:
    return

  max_num_errors = 5
  str_errors = "\n".join(errors[:max_num_errors])
  # Only claim truncation when some mismatches are actually left out.
  if len(errors) > max_num_errors:
    num_mismatch_str = f"The first {max_num_errors} of {len(errors)}"
  else:
    num_mismatch_str = "The"
  raise TypeError(
      "Argument types differ from the types for which this computation was "
      "compiled. Perhaps you are calling the compiled executable with a "
      "different enable_x64 mode than when it was AOT compiled? "
      f"{num_mismatch_str} mismatches are:\n{str_errors}")
def check_device_backend_on_shardings(shardings) -> bool:
  """Report whether any concrete sharding has a truthy `_device_backend`.

  Entries that are `UnspecifiedValue` or `AUTO` are ignored; a sharding
  without the attribute counts as not having it set.
  """
  return any(
      getattr(sharding, '_device_backend', False)
      for sharding in shardings
      if not isinstance(sharding, (UnspecifiedValue, AUTO)))
def check_array_xla_sharding_layout_match(
    args,
    in_shardings: Sequence[JSharding],
    in_layouts: Sequence[Layout],
    arg_names: Sequence[str]
) -> None:
  """Verify that committed array arguments match the compiled shardings/layouts.

  Only `array.ArrayImpl` arguments whose required sharding is neither
  `UnspecifiedValue` nor `AUTO` are inspected. Arguments that are not
  committed, or whose required sharding is flagged by
  `check_device_backend_on_shardings`, are skipped.

  Raises:
    ValueError: if any inspected argument has a sharding that is not
      equivalent to the required one, or a layout different from the required
      one. Up to five mismatches are listed in the message.
  """
  num_errors = 5
  errors = []
  for arg, xs, xl, name in zip(args, in_shardings, in_layouts, arg_names):
    if (not isinstance(arg, array.ArrayImpl) or
        isinstance(xs, (UnspecifiedValue, AUTO))):
      continue
    if check_device_backend_on_shardings([xs]) or not arg._committed:
      continue

    if not arg.sharding.is_equivalent_to(xs, arg.ndim):
      sharding_msg = (
          f"Argument {name} with shape {arg.aval.str_short()}:\n"
          f" Passed sharding: {arg.sharding}\n"
          f" Required sharding: {xs}")
      errors.append((sharding_msg, "sharding"))

    # A layout mismatch is only reported when both layouts are known.
    if (arg.format.layout is not None and xl is not None and
        arg.format.layout != xl):
      layout_msg = (
          f"Argument {name} with shape {arg.aval.str_short()}:\n"
          f" Passed layout: {arg.format.layout}\n"
          f" Required layout: {xl}")
      errors.append((layout_msg, "layout"))

  if not errors:
    return

  first_errors, error_kinds = unzip2(errors[:num_errors])
  str_errors = '\n'.join(first_errors)

  # The wording depends on which kinds appear among the reported mismatches.
  reported_kinds = set(error_kinds)
  if reported_kinds == {'sharding'}:
    kind_str = 'shardings'
  elif reported_kinds == {'layout'}:
    kind_str = 'layouts'
  else:
    kind_str = 'shardings and layouts'

  if len(errors) >= num_errors:
    num_mismatch_str = f"{num_errors} mismatches out of {len(errors)}"
  else:
    num_mismatch_str = f"the {len(errors)} mismatches"
  raise ValueError(
      f"Computation was compiled for input {kind_str} that disagree with the "
      f"{kind_str} of arguments passed to it. "
      f"Here are {num_mismatch_str}:\n{str_errors}")
def batch_spec(spec, dim, val):
  """Return a `PartitionSpec` equal to `spec` with `val` inserted at `dim`.

  When `spec` has fewer than `dim` entries it is first padded with `None`
  so that the insertion position exists.
  """
  deficit = dim - len(spec)
  if deficit > 0:
    spec += (None,) * deficit
  return PartitionSpec(*tuple_insert(spec, dim, val))