Files
2026-05-06 19:47:31 +07:00

2073 lines
94 KiB
Python

# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable, Hashable, Sequence, Set
import enum
from functools import partial
import inspect
import itertools as it
from math import prod
import operator as op
from typing import Any, TypeVar, Union, cast, overload
import numpy as np
from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.core import order_wrt_mesh
from jax._src.core import pvary, Tracer, typeof, shard_aval, unshard_aval
from jax._src.mesh import (AbstractMesh, Mesh, BaseMesh, AxisType,
use_abstract_mesh, get_abstract_mesh,
get_concrete_mesh, empty_abstract_mesh)
from jax._src.lax import lax, parallel as lax_parallel
from jax._src.lib import _jax
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo, sdy
from jax._src.sharding_impls import (NamedSharding, PartitionSpec,
canonicalize_sharding)
from jax._src.util import (HashablePartial, unzip2, partition_list, merge_lists,
split_list, subs_list2, fun_name as util_fun_name)
from jax._src.state import discharge
from jax._src.state.types import AbstractRef
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import ad
from jax._src.tree_util import (
broadcast_prefix, keystr, prefix_errors, generate_key_paths, tree_flatten,
tree_leaves, tree_map, tree_structure, tree_unflatten, KeyPath, PyTreeDef, FlatTree)
# Short alias for PartitionSpec used throughout this module.
P = PartitionSpec
# Rebind `map`/`zip` in this module to `util.safe_map`/`util.safe_zip`; the
# builtins remain reachable under the `unsafe_` names.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
traceback_util.register_exclusion(__file__)
# API
Specs = Any # PyTree[PartitionSpec]
# Mesh axis names only need to be hashable.
AxisName = Hashable
class InferFromArgs:
  """Type of the ``Infer`` sentinel (the default value of ``in_specs``)."""

  def __repr__(self) -> str:
    return "jax.sharding.Infer"

  def __reduce__(self):
    # Pickle by reference to the module-level accessor so that unpickling
    # returns the shared `Infer` object rather than a fresh instance.
    accessor, no_args = _get_default_infer, ()
    return accessor, no_args


def _get_default_infer():
  """Return the module-level ``Infer`` sentinel (used by ``__reduce__``)."""
  return Infer


Infer = InferFromArgs()
# Type variables for the wrapped callable: `F` is used when the function is
# passed directly, `G` in the decorator-factory signatures (`f=None`).
F = TypeVar("F", bound=Callable)
G = TypeVar("G", bound=Callable)
@overload
def shard_map(f: F, /, *, out_specs: Specs,
              in_specs: Specs | None | InferFromArgs = ...,
              mesh: Mesh | AbstractMesh | None = ...,
              axis_names: Set[AxisName] = ..., check_vma: bool = ...) -> F:
  ...


@overload
def shard_map(f: None = None, /, *, out_specs: Specs,
              in_specs: Specs | None | InferFromArgs = ...,
              mesh: Mesh | AbstractMesh | None = ...,
              axis_names: Set[AxisName] = ..., check_vma: bool = ...
              ) -> Callable[[G], G]:
  ...


# See https://github.com/jax-ml/jax/pull/30753 to understand why `in_specs`
# defaults to `Infer`.
def shard_map(f: F | None = None, /, *, out_specs: Specs,
              in_specs: Specs | None | InferFromArgs = Infer,
              mesh: Mesh | AbstractMesh | None = None,
              axis_names: Set[AxisName] = frozenset(), check_vma: bool = True
              ) -> F | Callable[[G], G]:
  """Map a function over shards of data using a mesh of devices.

  See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html.

  Can be called directly (``shard_map(f, ...)``) or, when ``f`` is omitted,
  used as a decorator factory (``@shard_map(out_specs=...)``).

  Args:
    f: callable to be mapped. Each application ("instance") of ``f`` receives
      a shard of the mapped-over arguments and produces a shard of the output.
    mesh: (optional, default None) a ``jax.sharding.Mesh`` describing the
      array of devices over which data is sharded and on which instances of
      ``f`` run. Its axis names can be used in collective communication
      operations inside ``f``. When None, the mesh is taken from the context,
      which can be set with the `jax.set_mesh` context manager.
    in_specs: (optional, default `Infer`) a pytree whose leaves are
      ``jax.sharding.PartitionSpec`` instances and whose structure is a tree
      prefix of the args tuple. As with ``jax.sharding.NamedSharding``, each
      ``PartitionSpec`` says how the corresponding argument (or subtree of
      arguments) is sharded along the named axes of ``mesh``: naming a mesh
      axis at a position shards that positional array axis over it, and not
      naming an axis means replication.
      If ``Infer``, all mesh axes must be of type `Explicit`, and the
      in_specs are inferred from the argument types.
      If ``None``, inputs are treated as static.
    out_specs: a pytree whose leaves are ``PartitionSpec`` instances and whose
      structure is a tree prefix of the output of ``f``. Each
      ``PartitionSpec`` says how the corresponding output shards are
      concatenated: naming a mesh axis at a position concatenates that axis's
      shards along the corresponding positional axis, while not naming a mesh
      axis is a promise that the output values are equal along that axis, so
      a single value is produced instead of a concatenation.
    axis_names: (optional, default set()) the set of axis names from ``mesh``
      over which ``f`` is manual. If empty, ``f`` is manual over all mesh
      axes.
    check_vma: (optional, default True) whether to enable additional validity
      checks and automatic differentiation optimizations. The validity checks
      concern whether mesh axis names not mentioned in ``out_specs`` are
      consistent with how the outputs of ``f`` are replicated.

  Returns:
    A callable representing a mapped version of ``f``; it accepts positional
    arguments corresponding to those of ``f`` and produces output
    corresponding to that of ``f``.
  """
  def apply(fun):
    return _shard_map(fun, mesh=mesh, in_specs=in_specs, out_specs=out_specs,
                      axis_names=axis_names, check_vma=check_vma)

  # No function supplied: behave as a decorator factory.
  return apply if f is None else apply(f)
@overload
def smap(f: F, /, *,
         in_axes: int | None | InferFromArgs | tuple[Any, ...] = ...,
         out_axes: Any, axis_name: AxisName) -> F:
  ...


@overload
def smap(f: None = None, /, *,
         in_axes: int | None | InferFromArgs | tuple[Any, ...] = ...,
         out_axes: Any, axis_name: AxisName
         ) -> Callable[[G], G]:
  ...


def smap(f: F | None = None, /, *,
         in_axes: int | None | InferFromArgs | tuple[Any, ...] = Infer,
         out_axes: Any, axis_name: AxisName
         ) -> F | Callable[[G], G]:
  """Single axis shard_map that maps a function `f` one axis at a time.

  Can be called directly (``smap(f, ...)``) or, when ``f`` is omitted, used as
  a decorator factory.

  Args:
    f: Callable to be mapped. Each application ("instance") of ``f`` receives
      a shard of the mapped-over arguments and produces a shard of the output.
    in_axes: (optional) An integer, None, or sequence of values saying which
      input array axes to map over. If not specified, `smap` tries to infer
      the axes from the arguments, which is only possible under `Explicit`
      mode. An integer or ``None`` applies to all arguments (``None`` meaning
      no axis is mapped); a tuple gives the axis for each corresponding
      positional argument. Axis integers must be in the range
      ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
      dimensions (axes) of the corresponding input array.
    out_axes: An integer, None, or (nested) standard Python container
      (tuple/list/dict) thereof saying where the mapped axis should appear in
      the output.
    axis_name: ``mesh`` axis name over which the function ``f`` is manual.

  Returns:
    A callable representing a mapped version of ``f``; it accepts positional
    arguments corresponding to those of ``f`` and produces output
    corresponding to that of ``f``.
  """
  def apply(fun):
    return _smap(fun, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name)

  # No function supplied: behave as a decorator factory.
  return apply if f is None else apply(f)
def _smap(f: F, *, in_axes: int | None | InferFromArgs | tuple[Any, ...],
          out_axes: Any, axis_name: AxisName) -> F:
  """Validate ``smap`` arguments and delegate to ``_shard_map``.

  The axis arguments are converted to PartitionSpecs via ``_axes_to_pspec``
  and the call is forwarded with ``mesh=None``, ``axis_names={axis_name}``,
  ``check_vma=True`` and ``_smap=True``.

  Raises:
    TypeError: if ``axis_name`` is a list/tuple, or if ``in_axes`` /
      ``out_axes`` are not of the accepted forms.
  """
  if isinstance(axis_name, (list, tuple)):
    raise TypeError(
        f"smap axis_name should be a `str` or a `Hashable`, but got {axis_name}")

  infer_inputs = in_axes is Infer
  if not infer_inputs:
    if in_axes is not None and not isinstance(in_axes, (int, tuple)):
      raise TypeError(
          "smap in_axes must be an int, None, jax.sharding.Infer, or a tuple of"
          " entries corresponding to the positional arguments passed to the"
          f" function, but got {in_axes}.")
    if any(not isinstance(leaf, int) for leaf in tree_leaves(in_axes)):
      raise TypeError(
          "smap in_axes must be an int, None, jax.sharding.Infer, or (nested)"
          f" container with those types as leaves, but got {in_axes}.")

  if any(not isinstance(leaf, int) for leaf in tree_leaves(out_axes)):
    raise TypeError("smap out_axes must be an int, None, or (nested) container "
                    f"with those types as leaves, but got {out_axes}.")

  to_pspec = partial(_axes_to_pspec, axis_name)
  # Treat `None` as a leaf so it reaches `_axes_to_pspec` instead of being
  # dropped as an empty subtree.
  none_is_leaf = lambda x: x is None
  if infer_inputs:
    in_specs = Infer
  else:
    in_specs = tree_map(to_pspec, in_axes, is_leaf=none_is_leaf)
  out_specs = tree_map(to_pspec, out_axes, is_leaf=none_is_leaf)
  return _shard_map(f, mesh=None, in_specs=in_specs, out_specs=out_specs,
                    axis_names={axis_name}, check_vma=True, _smap=True)
@partial(traceback_util.api_boundary, repro_api_name="jax.shard_map")
def _shard_map(f: F, *, mesh: Mesh | AbstractMesh | None,
               in_specs: Specs, out_specs: Specs, axis_names: Set[AxisName],
               check_vma: bool, _smap: bool = False) -> F:
  """Shared implementation behind ``shard_map`` and ``smap``.

  Returns a wrapper around ``f`` that, on every call, validates the mesh and
  specs, flattens the arguments, separates static (spec ``None``) from dynamic
  arguments, and binds ``shard_map_p``.

  Args:
    f: the callable to map.
    mesh: mesh passed by the user, or None to use the context mesh.
    in_specs: pytree prefix of PartitionSpecs, ``Infer``, or None.
    out_specs: pytree prefix of PartitionSpecs for the outputs of ``f``.
    axis_names: mesh axes over which ``f`` is manual (empty means all).
    check_vma: forwarded to ``shard_map_p.bind``; also controls whether the
      outputs of ``f`` get implicit pvary / unreduced casts.
    _smap: True when called from ``smap`` (forwarded to ``_shmap_checks``).

  Raises:
    TypeError: if ``f`` is not callable.
    ValueError: (from the wrapper) on spec/argument mismatches reported by the
      checks below or by ``shard_map_p.bind``.
  """
  if not callable(f):
    raise TypeError("shard_map requires a callable for its first argument, "
                    f"but got {f} of type {type(f)}.")

  @util.wraps(f)
  @traceback_util.api_boundary
  def wrapped(*args):
    # The validated mesh / axis names are written back to the enclosing scope,
    # so subsequent calls of `wrapped` start from the values resolved here.
    nonlocal mesh, axis_names
    mesh, axis_names = _shmap_checks(
        mesh, axis_names, in_specs, out_specs, _smap)
    dbg = api_util.debug_info("shard_map", f, args, {})
    args_flat = FlatTree.flatten(args)
    api_util.check_no_transformed_refs_args(lambda: dbg, args_flat)
    # Broadcast the (prefix) in_specs to one entry per flattened argument.
    try:
      in_specs_flat = broadcast_prefix(
          in_specs, args, is_leaf=lambda x: x is None)
    except ValueError:
      e, *_ = prefix_errors(in_specs, args)
      raise e('shard_map in_specs') from None

    # With `Infer`, derive each in_spec from the argument's own sharding.
    if (in_specs is Infer and
        all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
      arg_s = [typeof(a).sharding for a in args_flat]
      assert all(i is Infer for i in in_specs_flat), in_specs_flat
      in_specs_flat = [_manual_spec(axis_names, s.spec, mesh) for s in arg_s]

    # Arguments whose spec is None are static: they are closed over by
    # `f_wrapped` instead of being passed through the primitive.
    in_tree = args_flat.tree
    which_dyn = [s is not None for s in in_specs_flat]
    static_args = [x for x, dyn in zip(args_flat, which_dyn) if not dyn]
    dyn_args = [x for x, dyn in zip(args_flat, which_dyn) if dyn]
    in_specs_flat = tuple(s for s, dyn in zip(in_specs_flat, which_dyn) if dyn)
    dyn_argnums = [i for i, dyn in enumerate(which_dyn) if dyn]
    _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums,
                         in_specs_flat, dyn_args)

    # TODO(yashkatariya): Add support for partial manual
    mesh_axis_names_wo_vmap = (
        frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names)
    # Fully manual over an all-Explicit mesh: the given in_specs must agree
    # with the shardings already carried by the argument types.
    if (mesh_axis_names_wo_vmap == axis_names and
        all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
      for a, s in zip(dyn_args, in_specs_flat):
        if not isinstance(s, P): continue
        arg_aval = typeof(a)
        s = s._normalized_spec_for_aval(arg_aval.ndim)
        if config.remove_size_one_mesh_axis_from_type.value:
          s = core.remove_size_one_mesh_axis(s, mesh)
        if arg_aval.sharding.spec != s:
          raise ValueError(
              f"in_specs passed to shard_map: {s} does not match the specs of"
              f" the input: {arg_aval.sharding.spec} for arg: {typeof(a)}."
              " `in_specs` is an optional argument so you can omit specifying"
              " it and shard_map will infer the in_specs from the arguments."
              " If you want to reshard your inputs, you can use `jax.reshard`"
              " on the arguments and then pass those args to shard_map.")

    # Argument names no longer line up once static args are removed.
    if (dbg.arg_names is not None and len(dyn_args) != len(dbg.arg_names)):
      dbg = dbg.with_unknown_names()

    def f_wrapped(*dyn_args):
      # Re-interleave dynamic and static leaves in their original order.
      dyn_args_iter = iter(dyn_args)
      static_args_iter = iter(static_args)
      all_args = [next(dyn_args_iter) if dyn else next(static_args_iter)
                  for dyn in which_dyn]
      args = tree_unflatten(in_tree, all_args)
      ans = f(*args)
      ans_ft = FlatTree.flatten(ans)
      try:
        out_specs_flat = tuple(broadcast_prefix(out_specs, ans))
      except ValueError:
        e, *_ = prefix_errors(out_specs, ans)
        raise e('shard_map out_specs') from None

      def add_implicit_pvary_and_unreduced(val, spec):
        # Make the output vary over every axis its out_spec mentions, and cast
        # to unreduced over axes the spec marks unreduced but the value lacks.
        if not isinstance(spec, P): return val
        aval = typeof(val)
        val = pvary(val, tuple(_spec_to_vma(spec) - aval.mat.varying))
        return (lax_parallel.vary_unreduced_cast(val, tuple(unreduced))
                if (unreduced := spec.unreduced - aval.mat.unreduced) else val)

      if check_vma:
        ans_ft = ans_ft.map2(add_implicit_pvary_and_unreduced, out_specs_flat)
      # The flat out_specs travel with the result as auxiliary data.
      return ans_ft.with_aux(out_specs_flat)

    try:
      out_ft = shard_map_p.bind(
          *dyn_args, subfuns=(f_wrapped,), mesh=mesh, in_specs=in_specs_flat,
          check_vma=check_vma, manual_axes=axis_names, debug_info=dbg)
    except _SpecError as e:
      # An out_specs entry is longer than the rank of its output.
      fails, out_tree = e.args
      msg = _spec_rank_error(SpecErrorType.out, f, out_tree, out_specs, fails)
      if any(fail is not no_fail and not fail.shape for fail in fails):
        msg += (" In particular, for rank 0 outputs which are not constant "
                "over the mesh, add at least one (singleton) axis to them so "
                "that they can be concatenated using out_specs.")
      raise ValueError(msg) from None
    except _RepError as e:
      # out_specs imply replication that could not be inferred.
      fails, out_tree, = e.args
      msg = _inout_vma_error(f, mesh, out_tree, out_specs, fails)
      raise ValueError(msg) from None
    return out_ft.unflatten()

  return cast(F, wrapped)
def _axes_to_pspec(axis_name, axis):
  """Build the PartitionSpec placing ``axis_name`` at array axis ``axis``.

  ``None`` yields the empty spec ``P()``; an integer ``k`` yields ``k``
  leading ``None`` entries followed by ``axis_name``.
  """
  if axis is not None:
    # NOTE(review): for a negative `axis`, `[None] * axis` is empty, so the
    # result is `P(axis_name)` (array axis 0) — confirm that is intended.
    leading = [None] * axis
    return P(*leading, axis_name)
  return P()
def _shmap_checks(mesh, axis_names, in_specs, out_specs, _smap):
  """Validate and normalize the mesh, axis names and specs of a shard_map call.

  Args:
    mesh: the user-supplied mesh, or None to use the context mesh.
    axis_names: ``set``/``frozenset`` of manual axis names; empty means "all
      mesh axes that are not already explicit in the axis env".
    in_specs: user in_specs (pytree of PartitionSpecs, ``Infer`` or None).
    out_specs: user out_specs (pytree of PartitionSpecs).
    _smap: True when called on behalf of ``smap``; only selects the wording of
      one error message.

  Returns:
    ``(mesh, axis_names)`` where ``mesh`` is the resolved mesh and
    ``axis_names`` is a non-empty ``frozenset``.

  Raises:
    TypeError: if ``mesh`` or ``axis_names`` has the wrong type, or if
      ``in_specs`` is ``Infer`` while some manual axis is not `Explicit`.
    ValueError: if the mesh is empty, disagrees with the context mesh, or
      ``axis_names`` is not a subset of the usable mesh axes.
  """
  if mesh is None:
    mesh = get_abstract_mesh()
    if mesh.empty:
      raise ValueError(
          "The context mesh cannot be empty. Use"
          " `jax.set_mesh(mesh)` to enter into a mesh context")
  else:
    # Check the type before touching `mesh.abstract_mesh`: otherwise a
    # non-mesh argument fails with an AttributeError (instead of this
    # TypeError) whenever a context mesh is set.
    if not isinstance(mesh, (Mesh, AbstractMesh)):
      raise TypeError("shard_map requires a `jax.sharding.Mesh` or a "
                      "`jax.sharding.AbstractMesh` instance for its "
                      f"second argument, but got {mesh} of type {type(mesh)}.")
    ctx_mesh = get_abstract_mesh()
    if not ctx_mesh.empty and mesh.abstract_mesh != ctx_mesh:
      raise ValueError(
          f"The context mesh {ctx_mesh} should match the mesh passed to"
          f" shard_map {mesh}")

  if mesh.empty:
    raise ValueError(f"shard_map requires a non-empty mesh. Got {mesh}")

  # Mesh axes still available to shard_map, i.e. not already claimed as
  # explicit mesh axes in the current axis env.
  mesh_axis_names_wo_vmap = (
      frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names
  )

  if not isinstance(axis_names, (frozenset, set)):
    raise TypeError(
        "`axis_names` argument of shard_map should be of type `frozenset` or"
        f" `set`. Got type: {type(axis_names)}")
  if isinstance(axis_names, set):
    axis_names = frozenset(axis_names)
  if not axis_names:
    axis_names = mesh_axis_names_wo_vmap
  if not axis_names.issubset(mesh_axis_names_wo_vmap):
    raise ValueError(
        f"jax.shard_map requires axis_names={axis_names} to be a subset of "
        f"mesh.axis_names={mesh_axis_names_wo_vmap}")

  # Inference of in_specs is only possible when every manual axis is Explicit.
  if (in_specs is Infer and
      not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
    axis_types = ', '.join(str(mesh._name_to_type[a]) for a in axis_names)
    if _smap:
      msg = (f"in_axes was not specified when axis_name={axis_names} was of"
             f" type {axis_types}")
    else:
      # NOTE(review): this branch fires when in_specs is `Infer`, yet the
      # message says it was `None`; wording kept unchanged for compatibility.
      msg = ("shard_map in_specs argument must be a pytree of"
             " `jax.sharding.PartitionSpec` instances, but it was `None` when"
             f" {axis_names=} are of type {axis_types}")
    raise TypeError(msg)

  if in_specs is not Infer and in_specs is not None:
    _check_specs(SpecErrorType.input, in_specs, axis_names)
    _check_unreduced(SpecErrorType.input, mesh, axis_names, in_specs)
  _check_specs(SpecErrorType.out, out_specs, axis_names)
  _check_unreduced(SpecErrorType.out, mesh, axis_names, out_specs)
  return mesh, axis_names
def _manual_spec(manual_axes, spec: P, mesh) -> P:
  """Restrict ``spec`` to the axes in ``manual_axes``.

  Names outside ``manual_axes`` become ``None``. Within a tuple entry,
  trailing non-manual names are dropped; a non-manual name that precedes a
  manual one makes the spec invalid. ``unreduced``/``reduced`` are carried
  over after being validated by ``_check_unreduced``.

  Raises:
    ValueError: if a tuple entry has a non-manual name before a manual one.
  """
  def keep_manual(entry):
    if entry is None:
      return None
    if not isinstance(entry, tuple):
      return entry if entry in manual_axes else None
    kept = [name if name in manual_axes else None for name in entry]
    while kept and kept[-1] is None:
      kept.pop()
    if None in kept:
      raise ValueError(f"Invalid spec: {spec}")
    return tuple(kept) if kept else None

  out = [keep_manual(entry) for entry in spec]
  _check_unreduced(SpecErrorType.input, mesh, manual_axes, spec)
  return P(*out, unreduced=spec.unreduced, reduced=spec.reduced)
# Error checking and messages
class SpecErrorType(enum.Enum):
  """Whether a spec error concerns shard_map's in_specs or out_specs."""
  input = 1
  out = 2
def _check_unreduced(error_type, mesh, manual_axes, specs):
  """Validate use of ``unreduced``/``reduced`` in the given specs.

  Specs that set neither are ignored, and ``HiPspec`` leaves are skipped.
  Otherwise shard_map must be manual over every mesh axis, and each axis named
  in ``unreduced``/``reduced`` must be of type `Explicit`.

  Raises:
    NotImplementedError: unreduced/reduced used while not in full manual mode.
    ValueError: an unreduced/reduced axis is not `Explicit`.
  """
  from jax._src.hijax import HiPspec

  def all_explicit(names):
    return all(mesh._name_to_type[u] == AxisType.Explicit for u in names)

  prefix = 'in' if error_type == SpecErrorType.input else 'out'
  full_manual = frozenset(mesh.axis_names) == manual_axes
  for s in tree_leaves(specs):
    if isinstance(s, HiPspec):
      continue  # TODO(mattjj,yashkatariya): add user validation method
    if not (s.unreduced or s.reduced):
      continue
    if not full_manual:
      raise NotImplementedError(
          f"unreduced/reduced can only be passed to {prefix}_specs when"
          " shard_map is in full manual mode. Got mesh axis names"
          f" {mesh.axis_names}, manual_axes: {manual_axes}, specs: {s}. Please"
          " file a bug at https://github.com/jax-ml/jax/issues.")
    if not all_explicit(s.unreduced):
      raise ValueError(
          f"unreduced in {prefix}_specs {s} can only be used when the mesh"
          " passed to shard_map contains axis names all of type `Explicit`."
          f" Got mesh {mesh}")
    if not all_explicit(s.reduced):
      raise ValueError(
          f"reduced in {prefix}_specs {s} can only be used when the mesh"
          " passed to shard_map contains axis names all of type `Explicit`."
          f" Got mesh {mesh}")
def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None:
  """Check that ``specs`` is a pytree of PartitionSpecs over manual axes only.

  ``HiPspec`` leaves are accepted without further inspection.

  Args:
    error_type: whether ``specs`` are in_specs or out_specs (affects wording,
      and enables the ``None`` check for inputs).
    specs: the user-provided spec pytree.
    manual_axes: axis names the specs are allowed to mention.

  Raises:
    TypeError: if input ``specs`` is None, or some leaf is neither a
      ``PartitionSpec`` nor a ``HiPspec``.
    ValueError: if every leaf has a valid type but some PartitionSpec names an
      axis that is not in ``manual_axes``.
  """
  from jax._src.hijax import HiPspec
  if error_type == SpecErrorType.input and specs is None:
    raise TypeError(
        "shard_map in_specs argument must be a pytree of "
        "`jax.sharding.PartitionSpec` instances, but it was None.\n"
        "Instead of `in_specs=None`, did you mean `in_specs=P()`, "
        "where `P = jax.sharding.PartitionSpec`?")

  def non_manual_names(p):
    # Yield every axis name mentioned by `p` that is not a manual axis.
    for names in p:
      names = (names,) if not isinstance(names, tuple) else names
      for name in names:
        if name is not None and name not in manual_axes:
          yield name

  def check_spec(p):
    if isinstance(p, HiPspec):
      return True  # TODO(mattjj,yashkatariya): add user validation method
    if not isinstance(p, PartitionSpec):
      return False
    return next(non_manual_names(p), None) is None

  if all(check_spec(p) for p in tree_leaves(specs)):
    return

  prefix = 'in' if error_type == SpecErrorType.input else 'out'
  # Report with the same acceptance criteria as `check_spec`: a HiPspec leaf
  # is valid, so it must neither be listed as a wrongly-typed leaf nor be
  # scanned for axis names (either would mask the actual offender).
  msgs = [f"  {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, "
          for key, x in generate_key_paths(specs)
          if not isinstance(x, (P, HiPspec))]
  if not msgs:
    for key, p in generate_key_paths(specs):
      if isinstance(p, HiPspec):
        continue
      for name in non_manual_names(p):
        msgs.append(f"  {prefix}_specs{keystr(key)} refers to {repr(name)}")
    raise ValueError(
        f"shard_map {prefix}_specs argument must refer to an axis "
        f"marked as manual ({manual_axes}), but:\n\n"
        + '\n\n'.join(msgs) + '\n\n'
        f"Check the {prefix}_specs values passed to shard_map.")
  raise TypeError(
      f"shard_map {prefix}_specs argument must be a pytree of "
      f"`jax.sharding.PartitionSpec` instances, but:\n\n"
      + '\n\n'.join(msgs) + '\n\n'
      f"Check the {prefix}_specs values passed to shard_map.")
class NoFail:
  """Type of the ``no_fail`` marker: "this leaf passed the check"."""

  def __repr__(self) -> str:
    return "NoFail()"


# Shared marker instance; callers compare against it by identity.
no_fail = NoFail()
def _check_specs_vs_args(
    f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs,
    dyn_argnums: Sequence[int], in_specs_flat: Sequence[P],
    xs: Sequence) -> None:
  """Check the dynamic arguments ``xs`` against their flat in_specs.

  Two checks run in order: no PartitionSpec may be longer than the rank of
  its argument, and every sharded array axis must be evenly divisible by the
  product of the mesh axis sizes it is mapped to.

  Raises:
    ValueError: with a detailed message if either check fails.
  """
  in_avals = [core.shaped_abstractify(x) for x in xs]

  # Rank check: a spec cannot have more entries than the array has axes.
  rank_fails = [
      aval if isinstance(spec, P) and len(spec) > aval.ndim else no_fail
      for spec, aval in zip(in_specs_flat, in_avals)]
  if not all(r is no_fail for r in rank_fails):
    expanded = _expand_fail(in_tree, dyn_argnums, rank_fails)
    raise ValueError(
        _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, expanded))

  # Divisibility check: a non-zero remainder marks the argument as failing.
  def indivisible(aval, spec):
    return any(aval.shape[d] % prod(mesh.shape[n] for n in ns)
               for d, ns in _spec_to_names(spec).items())

  div_fails = [
      aval if isinstance(spec, P) and indivisible(aval, spec) else no_fail
      for aval, spec in zip(in_avals, in_specs_flat)]
  if not all(r is no_fail for r in div_fails):
    expanded = _expand_fail(in_tree, dyn_argnums, div_fails)
    raise ValueError(
        _spec_divisibility_error(f, mesh, in_tree, in_specs, expanded))
def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int],
                 fail: Sequence[core.ShapedArray | NoFail]
                 ) -> list[core.ShapedArray | NoFail]:
  """Spread per-dynamic-argument results over all leaves of ``in_tree``.

  Leaves whose index is not in ``dyn_argnums`` are filled with ``no_fail``.
  """
  expanded: list[core.ShapedArray | NoFail] = [
      no_fail for _ in range(in_tree.num_leaves)]
  for argnum, result in zip(dyn_argnums, fail):
    expanded[argnum] = result
  return expanded
def _spec_rank_error(
    error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
    fails: list[core.ShapedArray | NoFail]) -> str:
  """Build the message for specs that are longer than their value's rank.

  Args:
    error_type: whether the failing specs are in_specs or out_specs.
    f: the user function (used for its name and, for inputs, its signature).
    tree: pytree structure of the args (inputs) or outputs.
    specs: the user-provided spec pytree.
    fails: one entry per leaf of ``tree``; the failing aval, or ``no_fail``.

  Returns:
    The full, human-readable error message.
  """
  fun_name = util_fun_name(f)
  if error_type == SpecErrorType.input:
    prefix, base = 'in', 'the passed args'
    ba = _try_infer_args(f, tree)
  else:
    prefix, base = 'out', f'{fun_name}(*args)'
    ba = None

  if ba is not None:
    # Positional-capable parameters only; loop-invariant, so computed once.
    param_names, params = unzip2(
        (name, param) for name, param in ba.signature.parameters.items()
        if param.kind not in (inspect.Parameter.KEYWORD_ONLY,
                              inspect.Parameter.VAR_KEYWORD))

  paths = _iter_paths(tree, specs, fails)
  msgs = []
  for (spec_key, spec), (fail_key, aval) in paths:
    extra = ""
    if ba is not None:
      arg_key, *_ = fail_key
      if (arg_key.idx >= len(params) or
          params[arg_key.idx].kind == inspect.Parameter.VAR_POSITIONAL):
        extra = (f", where args{arg_key} is the index "
                 f"{arg_key.idx - len(params) + 1} component "
                 f"of {fun_name}'s varargs parameter '{param_names[-1]}',")
      else:
        # Use the parameter's *name*: formatting the `inspect.Parameter`
        # itself would also print its annotation and default (e.g. 'x=3').
        param_name = param_names[arg_key.idx]
        extra = (f", where args{arg_key} is bound to {fun_name}'s "
                 f"parameter '{param_name}',")
    msgs.append(
        f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length "
        f"{len(spec)}, but "
        f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, "
        f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
  assert msgs
  if len(msgs) == 1: msgs = [msgs[0][2:]]  # remove the bullet point
  msg = (f"shard_map applied to the function '{fun_name}' was given an "
         f"{prefix}_specs entry which is too long to be compatible with the "
         f"corresponding {prefix}put value from the function:\n\n"
         + '\n\n'.join(msgs) + '\n\n' +
         f"Entries in {prefix}_specs must be of length no greater than the "
         f"number of axes in the corresponding {prefix}put value.\n\n"
         f"Either revise the spec to be shorter, or modify '{fun_name}' so "
         f"that its {prefix}puts have sufficient rank.")
  if any(not aval.ndim for _, (_, aval) in paths):
    msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs "
            "entry of `P()`, where `P = jax.sharding.PartitionSpec`.")
  return msg
def _spec_divisibility_error(
    f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs,
    fails: list[core.ShapedArray | NoFail]) -> str:
  """Build the message for args not evenly divisible by their mesh axes.

  Args:
    f: the user function (used for its name and signature).
    mesh: the mesh whose axis sizes must divide the array axis sizes.
    tree: pytree structure of the args.
    specs: the user-provided in_specs pytree.
    fails: one entry per leaf of ``tree``; the failing aval, or ``no_fail``.

  Returns:
    The full, human-readable error message.
  """
  ba = _try_infer_args(f, tree)
  fun_name = getattr(f, '__name__', str(f))

  if ba is not None:
    # Positional-capable parameters only; loop-invariant, so computed once.
    param_names, params = unzip2(
        (name, param) for name, param in ba.signature.parameters.items()
        if param.kind not in (inspect.Parameter.KEYWORD_ONLY,
                              inspect.Parameter.VAR_KEYWORD))

  msgs = []
  for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
    extra = ""
    if ba is not None:
      arg_key, *_ = fail_key
      if (arg_key.idx >= len(params) or
          params[arg_key.idx].kind == inspect.Parameter.VAR_POSITIONAL):
        extra = (f", where args{arg_key} is the index "
                 f"{arg_key.idx - len(params) + 1} component "
                 f"of {fun_name}'s varargs parameter '{param_names[-1]}',")
      else:
        # Use the parameter's *name*: formatting the `inspect.Parameter`
        # itself would also print its annotation and default (e.g. 'x=3').
        param_name = param_names[arg_key.idx]
        extra = (f", where args{arg_key} is bound to {fun_name}'s "
                 f"parameter '{param_name}',")
    for d, ns in _spec_to_names(spec).items():
      # Total number of shards this array axis is split into.
      sz = prod(mesh.shape[n] for n in ns)
      if aval.shape[d] % sz:
        axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
        total = 'total ' if len(ns) > 1 else ''
        msgs.append(
            f"* the passed args{keystr(fail_key)} of shape {aval.str_short()}{extra} "
            f"corresponds to in_specs{keystr(spec_key)} of value {spec}, "
            f"which maps array axis {d} (of size {aval.shape[d]}) to mesh "
            f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
            f"{aval.shape[d]}")
  assert msgs
  if len(msgs) == 1: msgs = [msgs[0][2:]]  # remove the bullet point
  msg = (f"shard_map applied to the function '{fun_name}' was given argument "
         f"arrays with axis sizes that are not evenly divisible by the "
         f"corresponding mesh axis sizes:\n\n"
         f"The mesh given has shape {tuple(mesh.shape.values())} with "
         f"corresponding axis names {mesh.axis_names}.\n\n"
         + '\n\n'.join(msgs) + '\n\n' +
         f"Array arguments' axis sizes must be evenly divisible by the mesh "
         f"axis or axes indicated by the corresponding elements of the "
         f"argument's in_specs entry. Consider checking that in_specs are "
         f"correct, and if so consider changing the mesh axis sizes or else "
         f"padding the input and adapting '{fun_name}' appropriately.")
  return msg
def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef,
                     specs: Specs, fails: list[core.ManualAxisType | NoFail]
                     ) -> str:
  """Build the message for out_specs whose implied replication is unproven.

  Args:
    f: the user function (used for its name).
    mesh: the mesh passed to shard_map.
    tree: pytree structure of the outputs.
    specs: the user-provided out_specs pytree.
    fails: one entry per output leaf; the offending manual-axis type, or
      ``no_fail``.

  Returns:
    The full, human-readable error message.
  """
  fun_name = getattr(f, '__name__', str(f))

  def in_mesh_order(axes) -> str:
    return ','.join(map(str, order_wrt_mesh(mesh, axes)))

  msgs = []
  for (spec_key, spec), (fail_key, mat) in _iter_paths(tree, specs, fails):
    unmentioned = _unmentioned(mesh, spec)
    if len(unmentioned) <= 1:
      need_rep_, = unmentioned
      msgs.append(
          f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
          f"corresponding output value is replicated across mesh axis "
          f"'{need_rep_}', but could not infer replication over any axes")
      continue
    need_vma = in_mesh_order(_spec_to_vma(spec))
    got_vma = in_mesh_order(mat.varying)
    diff = in_mesh_order([n for n in unmentioned if n in mat.varying])
    msgs.append(
        f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
        f"corresponding output value is only varying across mesh axes "
        f"{{{need_vma}}} and not {{{diff}}}, but it was inferred to be "
        f"possibly varying over {{{got_vma}}}")
  assert msgs
  if len(msgs) == 1:
    msgs = [msgs[0][2:]]  # remove the bullet point

  header = (f"shard_map applied to the function '{fun_name}' was given "
            f"out_specs which require replication which can't be statically "
            f"inferred given the mesh:\n\n"
            f"The mesh given has shape {tuple(mesh.shape.values())} with "
            f"corresponding axis names {mesh.axis_names}.\n\n")
  footer = ("Check if these output values are meant to be replicated over those "
            "mesh axes. If not, consider revising the corresponding out_specs "
            "entries. If so, consider disabling the check by passing the "
            "check_vma=False argument to `jax.shard_map`.")
  return header + '\n\n'.join(msgs) + '\n\n' + footer
def _unmentioned(mesh: Mesh | AbstractMesh, spec) -> list[AxisName]:
  """Return, in mesh order, the mesh axes absent from ``_spec_to_mat(spec).vur``."""
  mentioned = _spec_to_mat(spec).vur
  return [name for name in mesh.axis_names if name not in mentioned]
def _try_infer_args(f, tree):
  """Bind placeholder args shaped like ``tree`` to the signature of ``f``.

  Returns the ``inspect.BoundArguments``, or None when the signature cannot
  be obtained or the placeholders do not fit it.
  """
  placeholders = [False for _ in range(tree.num_leaves)]
  dummy_args = tree_unflatten(tree, placeholders)
  try:
    bound = inspect.signature(f).bind(*dummy_args)
  except (TypeError, ValueError):
    return None
  return bound
T = TypeVar('T')


def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail]
                ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]:
  """Pair each failing leaf with the spec entry that covers it.

  ``fails`` is unflattened with ``tree``; ``specs`` is broadcast (as a tree
  prefix) against that structure. Each result item is
  ``((spec_key, spec), (fail_key, fail_data))``, where the first element is a
  ``Tup`` that unpacks to the spec's key path and value. Leaves whose spec is
  None or whose fail entry is ``no_fail`` are omitted.
  """
  failures = tree_unflatten(tree, fails)
  keyed_failures = generate_key_paths(failures)
  # Wrap each (key path, spec) pair so it is carried through broadcasting as
  # one opaque leaf.
  keyed_specs = [Tup(pair) for pair in generate_key_paths(specs)]
  specs_tree = tree_unflatten(tree_structure(specs), keyed_specs)
  specs_bcast = broadcast_prefix(specs_tree, failures,
                                 is_leaf=lambda x: x is None)
  pairs = []
  for s, (fail_key, fail_data) in zip(specs_bcast, keyed_failures):
    if s is None or fail_data is no_fail:
      continue
    pairs.append((s, (fail_key, fail_data)))
  return pairs
class Tup:
  """Iterable wrapper around ``vals`` that is not itself a tuple."""

  def __init__(self, vals):
    self.vals = vals

  def __iter__(self):
    return iter(self.vals)
# Primitive
# Loose aliases used for annotations: `JaxType` is unconstrained (`Any`), and
# `MaybeTracer` is either such a value or a `Tracer`.
JaxType = Any
MaybeTracer = Union[JaxType, Tracer]
class ShardMapPrimitive(core.Primitive):
  """Primitive for shard_map; binding is delegated to the trace."""
  multiple_results = True
  skip_canonicalization = True

  def bind_with_trace(self, trace, args, avals, params, /):
    # `subfuns` holds exactly one function. It returns a FlatTree containing a
    # tuple of the user-level data and a flat, broadcasted out_specs wrapped
    # in `Static`. The result of `bind_with_trace` is a `FlatTree` of
    # tracer-like things and doesn't include the `Static` out_specs.
    (fun,) = params.pop('subfuns')
    return trace.process_shard_map(shard_map_p, fun, args, **params)

  def get_bind_params(self, params):
    """Convert equation params (`jaxpr`, `out_specs`) into bind params."""
    bind_params = dict(params)
    jaxpr = bind_params.pop('jaxpr')
    assert isinstance(jaxpr, core.Jaxpr)
    out_specs = bind_params.pop('out_specs')

    def eval_jaxpr(*args):
      outs = core.eval_jaxpr(jaxpr, (), *args)
      # Attach the out_specs as aux data, mirroring what a traced function
      # returns.
      return FlatTree.flatten(outs).with_aux(out_specs)

    bind_params.update(subfuns=(eval_jaxpr,), debug_info=jaxpr.debug_info)
    return bind_params


shard_map_p = ShardMapPrimitive('shard_map')
# Staging
@util.cache(max_size=256, trace_context_in_key=False)
def _as_manual_mesh(mesh, manual_axes: frozenset) -> AbstractMesh:
  """Return ``mesh``'s abstract mesh with ``manual_axes`` typed as Manual."""
  manual_types = dict.fromkeys(manual_axes, AxisType.Manual)
  return mesh.abstract_mesh.update_axis_types(manual_types)
def _extend_axis_env(mesh, manual_axes):
  """Extend the axis env with the (name, size) pairs of the manual mesh axes."""
  manual_sizes = [(name, size) for name, size in mesh.shape.items()
                  if name in manual_axes]
  return core.extend_axis_env_nd(manual_sizes)
def _shard_map_staging(
    trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: Callable,
    args: Sequence[Any], *, mesh: Mesh,
    in_specs, check_vma: bool, manual_axes: frozenset, debug_info
  ) -> FlatTree:
  """Stage a shard_map call into ``trace`` as a single equation.

  ``f`` is traced to a jaxpr on per-shard input avals, the out specs it
  returns are validated, and one ``prim`` equation carrying the jaxpr is
  emitted. Installed below as ``DynamicJaxprTrace.process_shard_map``.

  Returns:
    The output FlatTree of ``f`` with its leaves replaced by the emitted
    equation's outputs.
  """
  source_info = source_info_util.current()
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  in_avals = [typeof(arg) for arg in args]
  # Per-shard avals seen by the body.
  in_avals_ = map(partial(shard_aval, mesh, manual_axes, check_vma),
                  in_specs, in_avals)
  # Trace the body with the manual axes in scope and the manual mesh active.
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    in_avals_flat_tree = FlatTree.flatten((in_avals_, {}))
    jaxpr, out_data = pe.trace_to_jaxpr(
        f, in_avals_flat_tree, debug_info, fun_returns_flat_tree=True,
        requires_low=trace.requires_low)
  # The traced function returns its flat out_specs as aux data.
  out_avals_ft, out_specs = out_data.unpack_aux()
  _check_names(out_specs, out_avals_ft)
  if check_vma:
    _check_mats(mesh, out_specs, out_avals_ft)
  # Avals of the outputs as seen outside the shard_map.
  out_avals = [unshard_aval(mesh, check_vma, spec, aval)
               for spec, aval in zip(out_specs, out_avals_ft)]
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    jaxpr, consts = pe.separate_consts(jaxpr)
  # Constants become leading inputs; each gets a spec from `_repspec`.
  in_specs_staged = (*(_repspec(typeof(c)) for c in consts), *in_specs)
  if trace.requires_low:
    # Expand each spec into its lowered specs.
    in_specs_staged = tuple(lo_spec for hi_spec in in_specs_staged for lo_spec in hi_spec.to_lo())
    out_specs = tuple(lo_spec for hi_spec in out_specs for lo_spec in hi_spec.to_lo())
  params = dict(mesh=mesh, in_specs=in_specs_staged,
                out_specs=out_specs, jaxpr=jaxpr.jaxpr,
                check_vma=check_vma, manual_axes=manual_axes)
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info)
  const_tracers = map(to_jaxpr_tracer, consts)
  trace.frame.is_high |= jaxpr.is_high
  if trace.requires_low:
    # Lower every argument and output aval to its low-level pieces.
    in_tracers = [to_jaxpr_tracer(loval) for arg in args
                  for loval in typeof(arg).lower_val(arg)]
    out_avals_lo = [lo_aval for aval in out_avals for lo_aval in aval.lo_ty()]
  else:
    in_tracers = map(to_jaxpr_tracer, args)
    out_avals_lo = out_avals
  out = trace.emit_eqn([*const_tracers, *in_tracers], out_avals_lo, prim, params,
                       effs, source_info)
  if trace.requires_low:
    out = pe.raise_lo_outs(out_avals, out)
  return out_avals_ft.update(out)
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
# TODO add underscore version, for direct-linearize to consume
def _spec_to_names(spec: PartitionSpec):
  """Map each sharded dimension index of ``spec`` to a tuple of axis names.

  ``None`` entries are omitted, and a single name is wrapped in a 1-tuple.
  """
  names_by_dim = {}
  for dim, entry in enumerate(spec):
    if entry is None:
      continue
    names_by_dim[dim] = entry if isinstance(entry, tuple) else (entry,)
  return names_by_dim
def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma,
                        spec, aval: core.ShapedArray) -> core.ShapedArray:
  """Compute the per-shard `ShapedArray` for `aval` under `spec`.

  Registered below as the `core.shard_aval_handlers` entry for
  `core.ShapedArray`.

  Args:
    mesh: mesh whose `shape` gives the size of each named axis.
    manual_axes: axes passed to `_as_manual_mesh` to build the mesh used by
      the resulting sharding.
    check_vma: when truthy, the axes named in `spec` are added to the varying
      set and the aval's unreduced/reduced sets are carried over; otherwise
      only `aval.mat.varying` is kept and unreduced/reduced are empty.
    spec: partition spec describing how `aval` is split.
    aval: the unsharded abstract value.

  Returns:
    `aval` updated so that each dimension is floor-divided by the product of
    the mesh axis sizes `spec` names for that dimension, with a sharding over
    the manual mesh and a new `ManualAxisType`.

  Raises:
    ValueError: if `spec.unreduced` or `spec.reduced` differs from the
      corresponding set on `aval.sharding.spec`.
  """
  assert isinstance(aval, core.ShapedArray)
  if spec.unreduced != aval.sharding.spec.unreduced:
    raise ValueError(
        f"in_specs containing unreduced {spec} passed to shard_map should be"
        " equal to the unreduced present on the in_aval"
        f" {aval.str_short(True)}")
  if spec.reduced != aval.sharding.spec.reduced:
    raise ValueError(
        f"in_specs containing reduced {spec} passed to shard_map should be"
        f" equal to the reduced present on the in_aval {aval.str_short(True)}")
  names = _spec_to_names(spec)
  # Dimensions not mentioned in `spec` divide by prod(()) == 1, i.e. unchanged.
  new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
                    for i, sz in enumerate(aval.shape))
  manual_mesh = _as_manual_mesh(mesh, manual_axes)
  new_sharding = aval.sharding.update(
      mesh=manual_mesh,
      spec=core.modify_spec_for_auto_manual(aval.sharding.spec, manual_mesh))
  # The incoming aval's varying axes are always preserved; axes from `spec`
  # are only added when VMA checking is on.
  vma = (_spec_to_vma(spec) if check_vma else frozenset()) | aval.mat.varying
  unreduced = aval.sharding.spec.unreduced if check_vma else frozenset()
  reduced = aval.sharding.spec.reduced if check_vma else frozenset()
  mat = core.ManualAxisType(varying=vma, unreduced=unreduced, reduced=reduced)
  return aval.update(shape=new_shape, sharding=new_sharding,
                     manual_axis_type=mat)
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array
def _unshard_shaped_array(mesh: Mesh, check_vma, spec, aval: core.ShapedArray
                          ) -> core.ShapedArray:
  """Compute the unsharded `ShapedArray` for a per-shard `aval` under `spec`.

  Registered below as the `core.unshard_aval_handlers` entry for
  `core.ShapedArray`.

  Args:
    mesh: mesh whose `shape` gives the size of each named axis.
    check_vma: when truthy, `spec.unreduced` / `spec.reduced` must equal the
      corresponding sets on `aval.mat`.
    spec: the out_spec for this output.
    aval: the per-shard abstract value.

  Returns:
    `aval` updated so that each dimension is multiplied by the product of the
    mesh axis sizes `spec` names for that dimension, with a `NamedSharding`
    whose spec merges `spec` with `aval.sharding.spec`.

  Raises:
    ValueError: if `check_vma` is truthy and the unreduced or reduced sets of
      `spec` and `aval.mat` disagree.
  """
  assert isinstance(aval, core.ShapedArray)
  if check_vma and spec.unreduced != aval.mat.unreduced:
    raise ValueError(
        "out_specs passed to shard_map should be equal to the unreduced"
        f" present on the out_aval. Got out_specs={spec} and"
        f" out_aval={aval.str_short(True)}")
  if check_vma and spec.reduced != aval.mat.reduced:
    raise ValueError(
        "out_specs passed to shard_map should be equal to the reduced present"
        f" on the out_aval. Got out_specs={spec} and"
        f" out_aval={aval.str_short(True)}")
  names = _spec_to_names(spec)
  new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
                    for i, sz in enumerate(aval.shape))
  names_spec = spec._normalized_spec_for_aval(aval.ndim)
  if aval.ndim == 0:
    out_spec = P(unreduced=spec.unreduced, reduced=spec.reduced)
  else:
    # Merge per dimension: take whichever side names axes; when both do, the
    # names from `spec` come first, followed by those already on the aval.
    out_spec = []
    for name_s, aval_s in zip(names_spec, aval.sharding.spec):
      if name_s and not aval_s:
        out_spec.append(name_s)
      elif aval_s and not name_s:
        out_spec.append(aval_s)
      elif not name_s and not aval_s:
        out_spec.append(None)
      else:
        assert name_s and aval_s
        name_s = name_s if isinstance(name_s, tuple) else (name_s,)
        aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,)
        out_spec.append(name_s + aval_s)
    out_spec = PartitionSpec(*out_spec, unreduced=spec.unreduced,
                             reduced=spec.reduced)
  # Use the ambient abstract mesh when one is set, else the mesh argument's.
  new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
              get_abstract_mesh())
  new_sharding = NamedSharding(new_mesh, out_spec)
  manual_axes = set(new_mesh.manual_axes)
  # Only axes that are still manual in the resulting mesh stay varying.
  vma = frozenset(v for v in aval.mat.varying if v in manual_axes)
  # TODO(yashkatariya): Handle partial manual unreduced/reduced.
  out_mat = core.ManualAxisType(varying=vma)
  return aval.update(shape=new_shape, sharding=new_sharding,
                     manual_axis_type=out_mat)
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array
# Type-checking
def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_specs, out_specs,
                         check_vma, manual_axes):
  """Type-check a `shard_map` equation; return its output avals and effects.

  Checks that every jaxpr binder is type-compatible with the sharded aval of
  the matching input atom, checks the body jaxpr, and (when `check_vma`)
  checks that each `P` out_spec passes `_valid_repeats`.

  Raises:
    core.JaxprTypeError: on an incompatible input aval or an out_spec that
      fails `_valid_repeats`.
  """
  # TODO(mattjj,parkers): check auto
  for binder, atom, in_spec in zip(jaxpr.invars, in_atoms, in_specs):
    expected_aval = shard_aval(mesh, manual_axes, check_vma, in_spec,
                               atom.aval)
    if core.typecompat(binder.aval, expected_aval):
      continue
    raise core.JaxprTypeError("shard_map argument avals not compatible with "
                              "jaxpr binder avals and in_specs")

  with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
    core.check_jaxpr(jaxpr)

  if check_vma:
    for outvar, out_spec in zip(jaxpr.outvars, out_specs):
      if not isinstance(out_spec, P):
        continue
      if not _valid_repeats(mesh, outvar.aval.mat, out_spec):
        raise core.JaxprTypeError(
            "shard_map can't prove output is sufficiently replicated")

  sharded_out_avals = [outvar.aval for outvar in jaxpr.outvars]
  unshard = partial(unshard_aval, mesh, check_vma)
  out_avals = map(unshard, out_specs, sharded_out_avals)
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  return out_avals, effs
core.custom_typechecks[shard_map_p] = _shard_map_typecheck
def _valid_repeats(mesh: Mesh, mat: core.ManualAxisType, spec) -> bool:
  """Return False iff a non-manual axis unmentioned by `spec` is in `mat.vur`.

  The axes considered are those returned by `_unmentioned(mesh, spec)` minus
  `mesh.manual_axes`.
  """
  unmentioned = set(_unmentioned(mesh, spec)) - set(mesh.manual_axes)
  vur = mat.vur
  return not any(axis in vur for axis in unmentioned)
# Lowering
def _shardy_shard_map_sharding(
    ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in
) -> sharding_impls.SdyArray:
  """Build the Shardy sharding for one `shard_map` operand or result.

  Extended dtypes are first converted to their physical sharding and aval.
  When fewer axes are manual than the mesh has, every dimension sharding is
  marked open.
  """
  named = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    named = sharding_impls.physical_sharding(aval_in, named)
    aval_in = core.physical_aval(aval_in)
  result = named._to_sdy_sharding(aval_in.ndim)
  partially_manual = len(manual_axes) < len(mesh.axis_names)
  if not partially_manual:
    return result
  for dim in result.dim_shardings:
    dim.is_open = True
  return result
def _get_token_sharding(
    ctx: mlir.LoweringRuleContext, mesh
) -> sharding_impls.SdyArray:
  """Return the rank-0 Shardy sharding (empty spec) used for token values."""
  return _make_scoped_manual_sharding(ctx, mesh, P())._to_sdy_sharding(0)
def _get_spmdaxis_ctx_mesh(mesh):
  """Pick the mesh to store in an `SPMDAxisContext`.

  A non-abstract mesh is returned unchanged. For an `AbstractMesh`, the
  current concrete mesh is preferred when it is non-empty.
  """
  if not isinstance(mesh, AbstractMesh):
    return mesh
  concrete_mesh = get_concrete_mesh()
  return mesh if concrete_mesh.empty else concrete_mesh
def _shard_map_lowering_shardy(
    ctx: mlir.LoweringRuleContext, in_nodes,
    jaxpr: core.Jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma):
  """Lower `shard_map` to an `sdy.ManualComputationOp`.

  When the product of the (newly) manual axis sizes is 1 the body jaxpr is
  lowered inline and no `ManualComputationOp` is emitted. Otherwise the op's
  operands are, in order: dimension variables, tokens, jaxpr constant
  arguments, then `in_nodes`; its results are tokens followed by the outputs.

  Returns:
    The lowered output values (the op's results after the token results, or
    the inlined body's outputs).
  """
  axis_ctx = ctx.module_context.axis_context
  in_avals_ = [v.aval for v in jaxpr.invars]
  if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
    # Nested `ManualComputationOp`s must only refer to the new manual axes, not
    # all existing ones. Grab the newly-added manual axes.
    shardy_manual_axes = manual_axes - axis_ctx.manual_axes
  else:
    shardy_manual_axes = manual_axes
  # The sub-context records the full `manual_axes` argument, not only the
  # newly-added ones.
  new_axis_context = sharding_impls.SPMDAxisContext(
      _get_spmdaxis_ctx_mesh(mesh), manual_axes)
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)

  tokens = [ctx.tokens_in.get(eff) for eff in ctx.tokens_in.effects()]
  num_tokens = len(tokens)
  # From here on `manual_axes` is the newly-manual axes ordered by the mesh.
  manual_axes = order_wrt_mesh(mesh, shardy_manual_axes)
  if prod([mesh.shape[a] for a in manual_axes]) == 1:
    # No need for a `ManualComputationOp` if all manual axes are size 1.
    with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
      out_nodes, tokens_out = mlir.jaxpr_subcomp(
          sub_ctx, jaxpr, ctx.name_stack,
          mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)),
          (), *in_nodes,
          dim_var_values=ctx.dim_var_values,
          const_lowering=ctx.const_lowering,
          outer_traceback=_jax.Traceback())
    ctx.set_tokens_out(tokens_out)
    return out_nodes

  in_shardings = list(
      map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
          in_specs, ctx.avals_in))
  const_args_and_avals = core.jaxpr_const_args(jaxpr)
  const_args, const_avals = util.unzip2(const_args_and_avals)
  num_const_args = len(const_args)
  const_arg_values = mlir.flatten_ir_values(
      mlir.ir_constants(c, const_lowering=ctx.const_lowering, aval=aval)
      for c, aval in const_args_and_avals
  )
  # TODO(necula,yashkatariya): how to construct consts shardy shardings from
  # consts that can be ndarray or jax.Array?
  const_args_shardings = [
      _shardy_shard_map_sharding(ctx, mesh, manual_axes, P(), core.typeof(c))
      for c in const_args]
  num_dim_vars = len(ctx.dim_var_values)
  # Dimension variables and tokens both get the token sharding.
  in_shardings = (
      [_get_token_sharding(ctx, mesh)] * (num_tokens + num_dim_vars) +
      const_args_shardings + in_shardings)
  in_shardings = sharding_impls.SdyArrayList(in_shardings).build()

  out_shardings = list(
      map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
          out_specs, ctx.avals_out))
  out_shardings = [
      _get_token_sharding(ctx, mesh)] * num_tokens + out_shardings
  out_shardings = sharding_impls.SdyArrayList(out_shardings).build()

  output_types = ([hlo.TokenType.get()] * num_tokens +
                  mlir.flatten_ir_types(map(mlir.aval_to_ir_types, ctx.avals_out)))

  args = (*ctx.dim_var_values, *tokens, *const_arg_values, *in_nodes)
  manual_computation_op = sdy.ManualComputationOp(
      output_types, mlir.flatten_ir_values(args), in_shardings, out_shardings,
      sdy.ManualAxesAttr.get([ir.StringAttr.get(i) for i in manual_axes]))

  # Block argument types mirror the operand order used for `args` above.
  dim_var_types = [
      mlir.aval_to_ir_type(core.ShapedArray((), dtypes.default_int_dtype()))
  ] * num_dim_vars
  token_types = [hlo.TokenType.get()] * num_tokens
  const_arg_types = mlir.flatten_ir_types(map(mlir.aval_to_ir_types, const_avals))
  in_types = mlir.flatten_ir_types(map(mlir.aval_to_ir_types, in_avals_))
  block = ir.Block.create_at_start(
      manual_computation_op.body,
      (*dim_var_types, *token_types, *const_arg_types, *in_types))
  with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes),
        config._check_vma(check_vma)):
    dim_var_values, token_arg_values, const_arg_values, in_args = util.split_list(
        block.arguments, [num_dim_vars, num_tokens, num_const_args])
    out_nodes_, tokens_out = mlir.jaxpr_subcomp(
        sub_ctx, jaxpr, ctx.name_stack,
        mlir.TokenSet(zip(ctx.tokens_in.effects(), token_arg_values)),
        (), *in_args,
        dim_var_values=dim_var_values,
        # Inside the body, constants resolve to the block arguments, keyed by
        # (id(constant), aval).
        const_lowering={
            (id(c), aval): ca
            for c, aval, ca in zip(const_args, const_avals, const_arg_values)
        },
        outer_traceback=_jax.Traceback())
    sdy.ReturnOp(
        mlir.flatten_ir_values(
            it.chain((v for _, v in tokens_out.items()), out_nodes_)
        )
    )
  # The leading results of the op are the output tokens.
  num_tokens = len(tokens_out.effects())
  tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip(
      ctx.tokens_in.effects(), manual_computation_op.results[:num_tokens])))
  ctx.set_tokens_out(tokens_out)

  return manual_computation_op.results[num_tokens:]
def _shard_map_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
                        jaxpr: core.Jaxpr, mesh, in_specs, out_specs,
                        check_vma, manual_axes):
  """MLIR lowering rule for `shard_map_p`.

  Delegates to `_shard_map_lowering_shardy` when
  `config.use_shardy_partitioner` is set. Otherwise each input is passed
  through `_xla_shard`, the body is lowered as a call named "shmap_body"
  under a new `SPMDAxisContext`, and each output is passed through
  `_xla_unshard`.
  """
  if config.use_shardy_partitioner.value:
    return _shard_map_lowering_shardy(
        ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma)

  in_avals_ = [v.aval for v in jaxpr.invars]
  out_avals_ = [x.aval for x in jaxpr.outvars]
  in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_specs,
                  ctx.avals_in, in_avals_, in_nodes)
  new_axis_context = sharding_impls.SPMDAxisContext(
      _get_spmdaxis_ctx_mesh(mesh), manual_axes)
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
  with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
    out_nodes_, tokens_out = mlir.call_lowering(
        "shmap_body", pe.close_jaxpr(jaxpr), None, sub_ctx, in_avals_,
        out_avals_, ctx.tokens_in, *in_nodes_,
        dim_var_values=ctx.dim_var_values,
        const_lowering=ctx.const_lowering,
        arg_names=map(_pspec_mhlo_attrs, in_specs, in_avals_),
        result_names=map(_pspec_mhlo_attrs, out_specs, out_avals_))
  ctx.set_tokens_out(tokens_out)
  return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_specs,
             out_avals_, ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)
def _make_scoped_manual_sharding(ctx, mesh, spec):
  """Return a `NamedSharding` of `spec` over the abstract form of `mesh`.

  If the lowering context is already inside an `SPMDAxisContext`, that
  context's manual axes are marked `AxisType.Manual` on the mesh first.
  """
  enclosing_ctx = ctx.module_context.axis_context
  scoped_mesh = mesh.abstract_mesh
  if isinstance(enclosing_ctx, sharding_impls.SPMDAxisContext):
    manual_types = dict.fromkeys(enclosing_ctx.manual_axes, AxisType.Manual)
    scoped_mesh = scoped_mesh.update_axis_types(manual_types)
  return NamedSharding(scoped_mesh, spec)
def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
               aval_in, aval_out, x):
  """Convert a full value `x` into its per-shard form (non-Shardy lowering).

  Returns `x` unchanged when the product of the manual axis sizes is 1.
  Otherwise wraps `x` in a sharding op for `spec`, then a full-to-shard op
  using the manual sharding proto.
  """
  if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
    return x
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_in, ns)
    aval_in = core.physical_aval(aval_in)
  shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
  # With only some mesh axes manual, every dimension is left unspecified.
  unspecified = (set(range(aval_in.ndim))
                 if len(manual_axes) < len(mesh.axis_names) else set())
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
                                  unspecified_dims=unspecified)
  manual_proto = pxla.manual_proto(
      aval_in, manual_axes | set(mesh.manual_axes), mesh)
  return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto,
                                         unspecified)
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
                 aval_in, aval_out, x):
  """Convert a per-shard value `x` back to its full form (non-Shardy lowering).

  Returns `x` unchanged when the product of the manual axis sizes is 1.
  Otherwise wraps `x` in a sharding op carrying the manual sharding proto,
  then a shard-to-full op using the sharding for `spec`.
  """
  if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
    return x
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_out, ns)
    aval_out = core.physical_aval(aval_out)
  # NOTE(review): unlike `_xla_shard`, `aval_in.ndim` is read here before
  # `aval_in` is converted to its physical aval -- confirm this is intended.
  unspecified = (set(range(aval_in.ndim))
                 if len(manual_axes) < len(mesh.axis_names) else set())
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    aval_in = core.physical_aval(aval_in)
  manual_proto = pxla.manual_proto(
      aval_in, manual_axes | set(mesh.manual_axes), mesh)
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto,
                                  unspecified_dims=unspecified)
  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
                                         unspecified)
def _pspec_mhlo_attrs(spec, aval: core.AbstractValue) -> str:
  """Render the per-dimension axis names of `spec` as a string.

  Returns the empty string for anything that is not a `core.ShapedArray`.
  """
  if not isinstance(aval, core.ShapedArray):
    return ''
  dim_names = _spec_to_names(spec)
  return str(map(dim_names.get, range(aval.ndim)))
# Eager evaluation
def get_mesh_from_args(args_flat, mesh):
  """Return a concrete `Mesh`, taken from the arguments when available.

  Every argument whose `sharding` is a `NamedSharding` over a non-scalar mesh
  must have the same `shape_tuple` as `mesh`; the mesh of the last such
  argument replaces `mesh`.

  Raises:
    ValueError: if an argument's mesh shape differs from `mesh`'s, or if the
      resulting mesh is still an `AbstractMesh`.
  """
  for a in args_flat:
    if (hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding)
        and not a.sharding.mesh.is_scalar):  # pyrefly: ignore[missing-attribute]
      if a.sharding.mesh.shape_tuple != mesh.shape_tuple:
        aval = core.shaped_abstractify(a)
        raise ValueError(
            f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not"
            " match the mesh shape passed to shard_map "
            f" {mesh.shape_tuple} for shape {aval.str_short()}")
      mesh = a.sharding.mesh
  # No argument supplied a concrete mesh to replace the abstract one.
  if isinstance(mesh, AbstractMesh):
    raise ValueError(
        "Please pass `jax.Array`s with a `NamedSharding` as input to"
        " `shard_map` when passing `AbstractMesh` to the mesh argument.")
  assert isinstance(mesh, Mesh)
  return mesh
def _spec_to_vma(spec):
return frozenset(p for s in spec if s is not None
for p in (s if isinstance(s, tuple) else (s,)))
def _mat_to_spec(mesh, mat):
  """Build a `P` whose single entry is `mat.varying` ordered by `mesh`.

  The unreduced and reduced sets of `mat` are carried over to the spec.
  """
  varying_axes = order_wrt_mesh(mesh, mat.varying)
  return P(varying_axes, unreduced=mat.unreduced, reduced=mat.reduced)
def _spec_to_mat(spec) -> core.ManualAxisType:
  """Build a `ManualAxisType` from `spec`.

  The varying set is every axis named in `spec`; unreduced and reduced are
  copied from the spec.
  """
  varying = _spec_to_vma(spec)
  return core.ManualAxisType(
      varying=varying, unreduced=spec.unreduced, reduced=spec.reduced)
def _shard_map_impl(trace, prim, fun, args, *, mesh, in_specs,
                    check_vma, manual_axes, debug_info):
  """Eager (`EvalTrace`) implementation of `shard_map`.

  Inputs are converted with `_unmatch_spec`, the function is run by
  `_run_shmap`, out_specs are validated, and each output is converted back
  with `_match_spec`.
  """
  del prim
  if isinstance(mesh, AbstractMesh):
    # Resolve an abstract mesh to a concrete one: prefer the ambient concrete
    # mesh, then let the arguments' shardings supply it.
    concrete_mesh = get_concrete_mesh()
    mesh = concrete_mesh if not concrete_mesh.empty else mesh
    mesh = get_mesh_from_args(args, mesh)
  cur_mesh = get_abstract_mesh()
  args_ = map(partial(_unmatch_spec, mesh, check_vma, cur_mesh, manual_axes),
              in_specs, args)
  in_mat = map(_spec_to_mat, in_specs)
  outs, out_specs, out_mat = _run_shmap(fun, mesh, manual_axes, args_, in_mat,
                                        check_vma)
  # Outputs carry a leading axis of stacked blocks; strip it for the checks.
  out_avals = outs.map(lambda x: core.mapped_aval(x.shape[0], 0, core.typeof(x)))
  _check_names(out_specs, out_avals)
  if check_vma:
    _check_mats(mesh, out_specs, out_avals)
    src_pspecs = tuple(_mat_to_spec(mesh, m) for m in out_mat)
  else:
    # Without VMA tracking every output is treated as split over all manual
    # axes.
    src_pspecs = tuple(P(order_wrt_mesh(mesh, manual_axes))
                       for _ in range(len(out_mat)))
  dst_pspecs = out_specs
  return outs.map3(partial(_match_spec, mesh, check_vma, manual_axes),
                   src_pspecs, dst_pspecs)
core.EvalTrace.process_shard_map = _shard_map_impl
def _run_shmap_lu(f, mesh, manual_axes, args, mats, check_vma):
  """Run `f.call_wrapped` under a fresh `ShardMapTrace`.

  Variant of `_run_shmap` for callables exposing `call_wrapped`; returns only
  `(outs, out_mat)` with no out_specs.
  """
  assert not mesh.manual_axes
  trace = ShardMapTrace(mesh, manual_axes, check_vma)
  in_tracers = map(partial(ShardMapTracer, trace), mats, args)
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(inner_mesh), config._check_vma(check_vma)):
    ans = f.call_wrapped(*in_tracers)
    outs, out_mat = unzip2(map(trace.to_val_mat_pair, ans))
  return outs, out_mat
def _run_shmap(f, mesh, manual_axes, args, mats, check_vma):
  """Run `f` under a fresh `ShardMapTrace`.

  `f(*in_tracers)` must return an object whose `unpack_aux()` yields
  `(ans, out_specs)`.

  Returns:
    `(outs, out_specs, outs_mat)`: the raw output values, the out_specs
    reported by `f`, and a list of the outputs' `ManualAxisType`s.
  """
  assert not mesh.manual_axes
  trace = ShardMapTrace(mesh, manual_axes, check_vma)
  in_tracers = map(partial(ShardMapTracer, trace), mats, args)
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(inner_mesh), config._check_vma(check_vma)):
    ans, out_specs = f(*in_tracers).unpack_aux()
    outs, outs_mat = ans.map(trace.to_val_mat_pair).unzip2()
  return outs, out_specs, list(outs_mat)
def _unmatch_spec2(mesh, prev_manual, spec, x) -> JaxType:
  """Apply `_unmatch2` to `x` under `jit`, outside any active trace.

  Runs with jit enabled and the abstract form of `mesh` as the ambient mesh.
  """
  with (core.eval_context(), api.disable_jit(False),
        use_abstract_mesh(mesh.abstract_mesh)):
    unmatch_fn = HashablePartial(_unmatch2, mesh, prev_manual, spec)
    return api.jit(unmatch_fn)(x)
def _unmatch2(mesh, prev_manual, spec, x):
  """Identity `shard_map` moving `x` from the outer to the inner layout.

  The input spec is the previously-manual axes followed by `spec`; the output
  spec puts all manual axes (previous plus those named in `spec`) on the
  leading dimension.
  """
  newly_manual = _spec_to_vma(spec)
  all_manual = prev_manual | newly_manual
  src = P(order_wrt_mesh(mesh, prev_manual), *spec)
  dst = P(order_wrt_mesh(mesh, all_manual))
  mapped = shard_map(lambda x: x, in_specs=src, out_specs=dst,
                     axis_names=all_manual)
  return mapped(x)
def _match_spec2(mesh, prev_manual, spec, x) -> JaxType:
  """Apply `_match2` to `x` under `jit`, outside any active trace.

  Runs with jit enabled and the abstract form of `mesh` as the ambient mesh.
  """
  with (core.eval_context(), api.disable_jit(False),
        use_abstract_mesh(mesh.abstract_mesh)):
    match_fn = HashablePartial(_match2, mesh, prev_manual, spec)
    return api.jit(match_fn)(x)
def _match2(mesh, prev_manual, spec, x):
  """Identity `shard_map` moving `x` from the inner back to the outer layout.

  Inverse direction of `_unmatch2`: the input spec has all manual axes on the
  leading dimension; the output spec is the previously-manual axes followed by
  `spec`.
  """
  newly_manual = _spec_to_vma(spec)
  all_manual = prev_manual | newly_manual
  src = P(order_wrt_mesh(mesh, all_manual))
  dst = P(order_wrt_mesh(mesh, prev_manual), *spec)
  mapped = shard_map(lambda x: x, in_specs=src, out_specs=dst,
                     axis_names=all_manual)
  return mapped(x)
def _unmatch_spec(mesh: Mesh, check_vma, context_mesh, manual_axes, in_spec,
                  x: JaxType) -> JaxType:
  """Apply `_unmatch` to `x` under `jit`, outside any active trace.

  Runs with jit enabled and `context_mesh` as the ambient abstract mesh.
  """
  with (core.eval_context(), api.disable_jit(False),
        use_abstract_mesh(context_mesh)):
    unmatch_fn = HashablePartial(_unmatch, mesh, check_vma, in_spec,
                                 manual_axes)
    return api.jit(unmatch_fn)(x)
def _unmatch(mesh, check_vma, in_spec, manual_axes, x):
  """`shard_map` that adds a leading singleton axis to each block of `x`.

  With `check_vma` the output spec names only the axes used by `in_spec`
  (keeping its unreduced/reduced sets); otherwise it names every mesh axis and
  the inner `shard_map` runs with `check_vma=False`.
  """
  if not check_vma:
    dst = P(mesh.axis_names)
    check_vma = False
  else:
    used_axes = _spec_to_vma(in_spec)
    dst = P(order_wrt_mesh(mesh, used_axes), unreduced=in_spec.unreduced,
            reduced=in_spec.reduced)
  mapped = shard_map(_add_singleton, mesh=mesh, in_specs=(in_spec,),
                     out_specs=dst, check_vma=check_vma,
                     axis_names=manual_axes)
  return mapped(x)
def _check_names(specs, avals: FlatTree) -> None:
  """Raise `_SpecError` if any `P` spec has more entries than its aval's rank.

  The error carries a list with the offending aval at each failing position
  and `no_fail` elsewhere, plus `avals.tree`.
  """
  fail = []
  for spec, aval in zip(specs, avals):
    too_long = isinstance(spec, P) and spec and len(spec) > aval.ndim
    fail.append(aval if too_long else no_fail)
  if any(entry is not no_fail for entry in fail):
    raise _SpecError(fail, avals.tree)
class _SpecError(Exception):
pass
def _check_mats(mesh, specs, avals):
  """Raise `_RepError` if any `P` spec fails `_valid_repeats` for its aval.

  The error carries a list with the offending aval's `mat` at each failing
  position and `no_fail` elsewhere, plus `avals.tree`.
  """
  fail = []
  for spec, aval in zip(specs, avals):
    invalid = isinstance(spec, P) and not _valid_repeats(mesh, aval.mat, spec)
    fail.append(aval.mat if invalid else no_fail)
  if any(entry is not no_fail for entry in fail):
    raise _RepError(fail, avals.tree)
class _RepError(Exception):
pass
def _match_spec(mesh: Mesh, check_vma, manual_axes, x: JaxType,
                src_pspec: PartitionSpec, dst_pspec: PartitionSpec) -> JaxType:
  """Apply `_match` to `x` under `jit`, outside any active trace.

  When every mesh axis is manual, the jitted call is additionally given
  `out_shardings=NamedSharding(mesh, dst_pspec)`.
  """
  fn = HashablePartial(_match, mesh, check_vma, manual_axes, src_pspec,
                       dst_pspec)
  with core.eval_context(), api.disable_jit(False):
    jit_kwargs = {}
    if set(mesh.axis_names) == manual_axes:
      jit_kwargs['out_shardings'] = NamedSharding(mesh, dst_pspec)
    return api.jit(fn, **jit_kwargs)(x)
def _match(mesh, check_vma, manual_axes, src_pspec, dst_pspec, x):
  """`shard_map` that removes the leading singleton axis of each block of `x`.

  Maps from `src_pspec` to `dst_pspec` over `manual_axes`.
  """
  mapped = shard_map(_rem_singleton, mesh=mesh, in_specs=src_pspec,
                     out_specs=dst_pspec, check_vma=check_vma,
                     axis_names=manual_axes)
  return mapped(x)
def _rem_singleton(x): return lax.squeeze(x, [0])
def _add_singleton(x): return lax.expand_dims(x, [0])
def _maybe_check_special(outs):
  """Check output buffers for special float values when debugging is enabled.

  Does nothing unless `debug_nans` or `debug_infs` is set. Otherwise gathers
  the data of every addressable shard of every leaf of `outs` and passes it to
  `dispatch.check_special`.

  Raises:
    FloatingPointError: if `dispatch.check_special` raises
      `InternalFloatingPointError`.
  """
  if not (config.debug_nans.value or config.debug_infs.value):
    return
  bufs = []
  for leaf in tree_leaves(outs):
    for shard in getattr(leaf, 'addressable_shards', []):
      bufs.append(shard.data)
  try:
    dispatch.check_special('shard_map', bufs)
  except api_util.InternalFloatingPointError as e:
    raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None
class ShardMapTrace(core.Trace):
  """Trace used to evaluate the body of a `shard_map` eagerly.

  Values are carried by `ShardMapTracer`s as `(val, mat)` pairs, where `mat`
  is a `core.ManualAxisType`. Primitives are executed either through an entry
  in `eager_rules` or by jitting a `shard_map` of the primitive.
  """
  __slots__ = ("mesh", "manual_axes", "check", "amesh")

  mesh: Mesh  # outer concrete or abstract mesh
  manual_axes: frozenset[AxisName]
  check: bool

  def __init__(self, mesh, manual_axes, check):
    super().__init__()
    self.mesh = mesh
    self.manual_axes = manual_axes
    self.check = check
    self.amesh = mesh.abstract_mesh

  def to_val_mat_pair(self, val):
    """Return `(raw value, ManualAxisType)` for a tracer or a plain value.

    Plain (non-tracer) values are converted with `_unmatch_spec` using an
    empty spec and paired with `core.empty_mat`.
    """
    if isinstance(val, ShardMapTracer):
      return val.val, val.mat
    elif isinstance(val, Tracer):
      raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
    else:
      val_ = _unmatch_spec(self.mesh, self.check, self.amesh, self.manual_axes,
                           P(), val)
      return val_, core.empty_mat

  def process_primitive(self, prim, tracers, params, /):
    """Evaluate `prim` on the tracers' underlying values."""
    in_vals, in_mat = unzip2(map(self.to_val_mat_pair, tracers))
    if self.check:
      # Derive the outputs' manual-axis types from abstract evaluation.
      out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params)
      out_avals = tuple(out_avals) if type(out_avals) is list else out_avals
      out_mat = tree_map(lambda a: a.mat, out_avals)
      in_specs = tuple(map(partial(_mat_to_spec, self.mesh), in_mat))
      out_specs = tree_map(partial(_mat_to_spec, self.mesh), out_mat)
    else:
      out_mat = core.empty_mat
      in_specs = out_specs = P(order_wrt_mesh(self.mesh, self.manual_axes))

    eager_rule = eager_rules.get(prim)
    if eager_rule:
      out_vals = eager_rule(self.mesh, *in_vals, **params)
    else:
      f = HashablePartial(
          _prim_applier, prim, self.check, tuple(params.items()), self.mesh,
          self.manual_axes, in_specs, out_specs)
      # NaN/Inf debugging is disabled for the jitted call; the check is done
      # afterwards by `_maybe_check_special`.
      with (core.eval_context(), api.disable_jit(False), config.debug_nans(False),
            config.debug_infs(False), use_abstract_mesh(self.amesh),
            sharding_impls._internal_use_concrete_mesh(self.mesh)):
        out_vals = api.jit(f)(*in_vals)
      _maybe_check_special(out_vals)
    if prim.multiple_results:
      out_mat = (out_mat if isinstance(out_mat, (list, tuple))
                 else [out_mat] * len(out_vals))
      return map(partial(ShardMapTracer, self), out_mat, out_vals)
    return ShardMapTracer(self, out_mat, out_vals)

  def process_shard_map(self, prim, fun, args, mesh, in_specs,
                        check_vma, manual_axes, debug_info):
    """Evaluate a `shard_map` nested inside this eager `shard_map`.

    Raises:
      Exception: if an explicitly passed concrete `mesh` differs from this
        trace's mesh, or if `check_vma` differs from this trace's setting.
      NotImplementedError: if any input has unreduced or reduced axes.
    """
    # Check consistency between outer and inner shmaps on explicitly passed
    # mesh and check_vma.
    if isinstance(mesh, Mesh):
      if mesh != self.mesh:
        raise Exception(
            "Nested shard_map was passed a mesh that differs from the mesh of"
            f" the enclosing eager shard_map. Got {mesh} but the enclosing"
            f" shard_map uses {self.mesh}.")
    del mesh
    if check_vma != self.check:  # TODO(mattjj): add check in jit path
      raise Exception(
          f"Nested shard_map was called with check_vma={check_vma}, which"
          " differs from the enclosing eager shard_map's"
          f" check_vma={self.check}.")
    del check_vma
    in_vals, in_mats = unzip2(map(self.to_val_mat_pair, args))
    if any(m.unreduced or m.reduced for m in in_mats):
      raise NotImplementedError(
          "Eager shard_map + unreduced/reduced + partial manual is not"
          " implemented. Please wrap your shard_map in `jax.jit`.")
    # The inner trace is manual over both the outer and the newly-manual axes.
    trace = ShardMapTrace(self.mesh, manual_axes | self.manual_axes, self.check)
    in_vals_ = [_unmatch_spec2(self.mesh, self.manual_axes, spec, x)
                for x, spec in zip(in_vals, in_specs)]
    # TODO(yashkatariya): Handle unreduced/reduced correctly.
    in_mats_ = [core.ManualAxisType(varying=mat.varying | _spec_to_vma(s))
                for mat, s in zip(in_mats, in_specs)]
    in_tracers = map(partial(ShardMapTracer, trace), in_mats_, in_vals_)
    inner_mesh = _as_manual_mesh(self.mesh, manual_axes | self.manual_axes)
    with (core.set_current_trace(trace), _extend_axis_env(self.mesh, manual_axes),
          use_abstract_mesh(inner_mesh)):
      ans_aux = fun(*in_tracers)
      ans, out_specs = ans_aux.unpack_aux()
      out_vals_, out_mats_ = ans.map(trace.to_val_mat_pair).unzip2()
    out_vals = out_vals_.map2(
        lambda x, spec: _match_spec2(self.mesh, self.manual_axes, spec, x), out_specs)
    # TODO(yashkatariya): Handle unreduced/reduced correctly.
    out_mats = [core.ManualAxisType(varying=mat.varying - _spec_to_vma(spec))
                for mat, spec in zip(out_mats_, out_specs)]
    return out_vals.map2(lambda val, vma: ShardMapTracer(self, vma, val),
                         out_mats)

  def process_call(self, call_primitive, fun, tracers, params, /):
    raise NotImplementedError(
        f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't "
        "yet supported. Put a `jax.jit` around the `shard_map`-decorated "
        "function, and open a feature request at "
        "https://github.com/jax-ml/jax/issues !")

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros):
    # Since ShardMapTrace is only used as a base main, we can drop the jvp.
    del prim, jvp, symbolic_zeros
    in_vals, in_mat = unzip2(map(self.to_val_mat_pair, tracers))
    out_vals, out_mat = _run_shmap_lu(fun, self.mesh, self.manual_axes, in_vals,
                                      in_mat, self.check)
    return map(partial(ShardMapTracer, self), out_mat, out_vals)

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, /, *, out_trees,
                              symbolic_zeros):
    if symbolic_zeros:
      msg = ("custom_vjp symbolic_zeros support with shard_map is not "
             "implemented; please open an issue at "
             "https://github.com/jax-ml/jax/issues")
      raise NotImplementedError(msg)
    # Only the primal function is evaluated; fwd/bwd are dropped.
    del prim, fwd, bwd, out_trees, symbolic_zeros
    in_vals, in_mat = unzip2(map(self.to_val_mat_pair, tracers))
    out_vals, out_mat = _run_shmap_lu(fun, self.mesh, self.manual_axes, in_vals,
                                      in_mat, self.check)
    return map(partial(ShardMapTracer, self), out_mat, out_vals)
class ShardMapTracer(core.Tracer[ShardMapTrace]):
  """Tracer for `ShardMapTrace`, pairing a raw value with its `ManualAxisType`."""
  mat: core.ManualAxisType
  val: JaxType

  def __init__(self, trace, mat, val):
    assert isinstance(mat, core.ManualAxisType)
    aval = core.typeof(val)
    # Without VMA checking every value is treated as varying over all of the
    # trace's manual axes.
    mat = (mat if trace.check else
           core.ManualAxisType(varying=trace.manual_axes))
    # `val` stacks one block per position along the varying axes on its
    # leading dimension; the tracer's aval is that of a single block.
    size = prod(trace.mesh.shape[n] for n in mat.varying)
    out = core.mapped_aval(size, 0, aval)
    manual_mesh = _as_manual_mesh(trace.amesh, trace.manual_axes)
    spec = core.modify_spec_for_auto_manual(out.sharding.spec, manual_mesh)  # pyrefly: ignore[missing-attribute]
    new_sharding = NamedSharding(manual_mesh, spec)
    mat_out = mat if trace.check else core.empty_mat
    computed_aval = out.update(sharding=new_sharding, manual_axis_type=mat_out)
    super().__init__(trace, computed_aval)
    self.mat = mat
    self.val = val

  def to_concrete_value(self):
    """Return the concrete value of block 0, or None.

    A value is only returned when the trace checks VMA and `self.mat.vur` is
    empty.
    """
    if self._trace.check and self.mat.vur == frozenset():
      with (core.eval_context(), use_abstract_mesh(self._trace.amesh),
            sharding_impls._internal_use_concrete_mesh(self._trace.mesh)):
        return core.to_concrete_value(self.val[0])
    else:
      return None

  def __str__(self) -> str:
    # Apply `pvary` over the mesh axes not already in `mat.vur` so that the
    # value below can be listed alongside every device of the mesh.
    pb_names = set(self._trace.mesh.axis_names) - self.mat.vur
    self = pvary(self, tuple(pb_names))
    with (core.eval_context(), use_abstract_mesh(self._trace.amesh),
          sharding_impls._internal_use_concrete_mesh(self._trace.mesh)):
      blocks = list(self.val)
    mesh = self._trace.mesh
    axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
    return '\n'.join(
        f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n"
        for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks))

  __repr__ = __str__  # for debuggers, like `p x`
def _prim_applier(prim, check_vma, params_tup, concrete_mesh, manual_axes,
                  in_specs, out_specs, *args):
  """Apply `prim` to `args` inside a `shard_map`.

  Inside the map each argument has its leading singleton axis removed before
  `prim.bind`, and each output gets a leading singleton axis added back.

  Args:
    params_tup: the primitive's params as a tuple of `(key, value)` items.
  """
  def apply(*args):
    outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
    return tree_map(_add_singleton, outs)
  # A tuple of out_specs is converted to a list before being handed to
  # `shard_map`.
  out_specs = list(out_specs) if type(out_specs) is tuple else out_specs
  return shard_map(apply, mesh=concrete_mesh, in_specs=in_specs,
                   out_specs=out_specs, check_vma=check_vma,
                   axis_names=manual_axes)(*args)
# Per-primitive overrides for eager evaluation. `ShardMapTrace` calls a
# registered rule as `rule(mesh, *in_vals, **params)` instead of jitting the
# primitive.
eager_rules: dict[core.Primitive, Callable] = {}
def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics):
del mesh, srcs, copy_semantics
for device in devices:
if device is not None:
raise ValueError("device_put with explicit device not allowed within "
f"shard_map-decorated functions, but got device {device}")
return xs
# Register the eager override for `device_put`.
eager_rules[dispatch.device_put_p] = _device_put_eager_rule
def _ref_raise_valueerror(*args, **kwargs):
raise ValueError(
"Eager shard_map cannot return a `jax.Ref`. Please wrap"
" your shard_map in `jax.jit`.")
# Both ref-creating primitives use the always-raising eager rule.
eager_rules[core.ref_p] = _ref_raise_valueerror
eager_rules[core.empty_ref_p] = _ref_raise_valueerror
# Batching
def used_axis_names(spec):
  """Return the `vur` set of the `ManualAxisType` built from `spec`."""
  mat = _spec_to_mat(spec)
  return mat.vur
def _shard_map_batch(
    trace: batching.BatchTrace, prim: core.Primitive, fun: Callable,
    in_tracers: Sequence[batching.BatchTracer], mesh: Mesh,
    in_specs, check_vma: bool, manual_axes: frozenset,
    debug_info) -> Sequence[batching.BatchTracer]:
  """Batching (`vmap`) rule for `shard_map`.

  Each mapped input gets its in_spec extended by `pxla.batch_spec` at the
  batch dimension; the body is wrapped so that it is batched too and reports
  correspondingly extended out_specs.

  Raises:
    ValueError: if the vmap `spmd_axis_name` (unless
      `config.disable_vmap_shmap_error` is set) or the vmapped-away explicit
      mesh axis appears in `in_specs`.
  """
  in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers))
  spmd_axis_name = trace.axis_data.spmd_name
  explicit_mesh_axis = trace.axis_data.explicit_mesh_axis
  if spmd_axis_name is not None:
    used = {n for spec in in_specs for n in used_axis_names(spec)}
    if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used:
      raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
    new_in_specs = [
        sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_axis_name)
        for sp, d in zip(in_specs, in_dims)]
    # Inside the map the batch axis is split over the spmd axes, so its size
    # is divided by the product of their mesh sizes.
    new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name)
    new_axis_data = batching.AxisData(
        trace.axis_data.name, new_size, trace.axis_data.spmd_name,
        trace.axis_data.explicit_mesh_axis)
  elif explicit_mesh_axis is not None:
    used = {n for spec in in_specs for n in used_axis_names(spec)}
    if set(explicit_mesh_axis) & used:
      raise ValueError("vmapped away explicit mesh axis cannot appear in "
                       "shard_map in_specs")
    new_in_specs = [
        sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
        for sp, d in zip(in_specs, in_dims)]
    new_axis_data = trace.axis_data
  else:
    new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
                    for sp, d in zip(in_specs, in_dims)]
    new_axis_data = trace.axis_data

  def fun_batched(*args):
    ans_aux, out_dims = batching.batch_subtrace_2(
        fun, trace.tag, new_axis_data, tuple(in_dims), args)
    ans, out_specs = ans_aux.unpack_aux()
    new_out_specs = _batch_out_specs(spmd_axis_name, explicit_mesh_axis,
                                     out_dims, out_specs)
    # Two aux layers: the outer one is the out_specs consumed by the
    # primitive, the inner one is the out_dims unpacked below.
    return ans.with_aux(out_dims).with_aux(tuple(new_out_specs))

  new_params = dict(mesh=mesh, in_specs=new_in_specs, check_vma=check_vma,
                    manual_axes=manual_axes, debug_info=debug_info)
  # TODO(yashkatariya): Remove remove_explicit_mesh_axis_names when vmap
  # mesh ctx is correctly set.
  with (core.set_current_trace(trace.parent_trace),
        core.remove_explicit_mesh_axis_names(trace.axis_data.explicit_mesh_axis)):
    out_vals = prim.bind(*in_vals, subfuns=(fun_batched,), **new_params)
  make_tracer = partial(batching.BatchTracer, trace,
                        source_info=source_info_util.current())
  out_vals, out_dims = out_vals.unpack_aux()
  return out_vals.map2(make_tracer, out_dims)
batching.BatchTrace.process_shard_map = _shard_map_batch
def _batch_out_specs(spmd_name, explicit_mesh_axis, dims, out_specs):
  """Extend each mapped out_spec with the batch dimension.

  Unmapped outputs keep their spec; mapped ones go through `pxla.batch_spec`,
  which is given `spmd_name` when it is set and `None` otherwise.

  Raises:
    ValueError: if `spmd_name` (unless `config.disable_vmap_shmap_error` is
      set) or, when `spmd_name` is None, `explicit_mesh_axis` names an axis
      used by `out_specs`.
  """
  if spmd_name is not None:
    used = {n for spec in out_specs for n in used_axis_names(spec)}
    if not config.disable_vmap_shmap_error.value and set(spmd_name) & used:
      raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs")
    batch_axis_name = spmd_name
  else:
    if explicit_mesh_axis is not None:
      used = {n for spec in out_specs for n in used_axis_names(spec)}
      if set(explicit_mesh_axis) & used:
        raise ValueError("vmapped away explicit mesh axis cannot appear in "
                         "shard_map out_specs")
    batch_axis_name = None
  return [sp if d is batching.not_mapped
          else pxla.batch_spec(sp, d, batch_axis_name)
          for sp, d in zip(out_specs, dims)]
# Autodiff
def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_specs,
                   check_vma, manual_axes, debug_info):
  """JVP rule for `shard_map`.

  Binds a single `shard_map` whose inputs are the primals followed by the
  non-zero tangents, and whose body computes primal and tangent outputs
  together. Zero tangents are not passed through the primitive.
  """
  debug_info = debug_info.with_unknown_names()
  primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  which_nz = [ type(t) is not ad.Zero for t in tangents]
  # Symbolic-zero tangents become None so that flattening drops them.
  tangents = [t if type(t) is not ad.Zero else None for t in tangents]
  args, in_zeros_tree = tree_flatten((primals, tangents))
  tangent_in_specs = [sp.to_tangent_spec() for sp, nz in zip(in_specs, which_nz) if nz]

  def f_jvp(*primals_and_nz_tangents_flat):
    primals, tangents = tree_unflatten(in_zeros_tree, primals_and_nz_tangents_flat)
    # Re-materialize the dropped tangents from their primals via `ad.p2tz`.
    tangents = [ad.p2tz(p) if t is None else t for p, t in zip(primals, tangents)]
    primals_out_aux, tangents_out = ad.jvp_subtrace_2(f, trace.tag, primals, tangents)
    primals_out_ft, out_ax = primals_out_aux.unpack_aux()
    which_nz_out = [type(r) is not ad.Zero for r in tangents_out]
    tangent_out_specs = [s.to_tangent_spec() for s, nz in zip(out_ax, which_nz_out) if nz]
    new_out_specs = (*out_ax, *tangent_out_specs)
    tangents_out = [None if not nz else t for t, nz in zip(tangents_out, which_nz_out)]
    tangents_out_ft = FlatTree.flatten(list(tangents_out))
    out_primals_tangents = FlatTree.pack((primals_out_ft, tangents_out_ft))
    return out_primals_tangents.with_aux(which_nz_out).with_aux(new_out_specs)

  params = dict(mesh=mesh, in_specs=(*in_specs, *tangent_in_specs),
                check_vma=check_vma, manual_axes=manual_axes,
                debug_info=debug_info.with_unknown_names())
  avals = [typeof(x) for x in args]
  result = shard_map_p.bind_with_trace(
      trace.parent_trace, tuple(args), avals, dict(params, subfuns=(f_jvp,)))
  pt_out, which_nz_out = result.unpack_aux()
  primal_out, nz_tangents_out = pt_out.unpack()
  # Reversed so that `pop()` yields the non-zero tangents in output order.
  tangents_stack = list(nz_tangents_out)[::-1]
  make_tracer = lambda p, nz: ad.JVPTracer(trace, p, tangents_stack.pop()) if nz else p
  tracers_out = primal_out.map2(make_tracer, which_nz_out)
  assert not tangents_stack
  return tracers_out
ad.JVPTrace.process_shard_map = _shard_map_jvp
def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
                            f: Callable, tracers, mesh, in_specs,
                            check_vma, manual_axes, debug_info):
  """Partial-evaluation rule for ``shard_map_p``.

  Splits the shard_map into a known part, which is bound immediately in
  ``trace.parent_trace`` through the body ``f_pe``, and an unknown part, which
  is staged as a ``shard_map_p`` equation recipe attached to fresh
  ``pe.JaxprTracer`` outputs.

  Args:
    trace: the partial-eval trace processing this bind.
    shard_map_p: the primitive to bind (known part) and to stage (unknown part).
    f: the shard_map body.
    tracers: the shard_map inputs.
    mesh: forwarded to both the known bind and the staged equation.
    in_specs: one spec per entry of ``tracers``.
    check_vma: forwarded to both the known bind and the staged equation.
    manual_axes: forwarded to both the known bind and the staged equation.
    debug_info: debug info for ``f``; used via ``with_unknown_names()``.

  Returns:
    ``ft.update(results)``, where ``results`` merges the known output values
    with the unknown output tracers according to ``out_knowns``.
  """
  tracers = map(trace.to_jaxpr_tracer, tracers)
  in_pvals = [t.pval for t in tracers]
  in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
  unk_in_specs, known_in_specs = pe.partition_list(in_knowns, in_specs)
  # Avals of the unknown inputs as seen inside the shard_map body.
  in_avals_sharded = map(partial(shard_aval, mesh, manual_axes, check_vma),
                         unk_in_specs, in_avals)
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)

  def f_pe(*in_consts):
    # Interleave known values and unknown avals back into input order.
    in_avals_, in_consts_ = iter(in_avals_sharded), iter(in_consts)
    in_pvals = [pe.PartialVal.known(next(in_consts_)) if known else
                pe.PartialVal.unknown(next(in_avals_)) for known in in_knowns]
    # Both iterators must be exhausted at this point.
    sentinel = object()
    assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel
    jaxpr, fwd_data = pe.trace_to_subjaxpr_nounits_fwd2(
        f, trace.tag, debug_info.with_unknown_names(), False, in_pvals)
    (in_fwds, out_fwds, out_pvals, res, env) = fwd_data
    # Promote only residuals that are not forwarded and have an empty shape.
    which = [f1 is None and f2 is None and not v.aval.shape
             for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
    jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
    # Match the promotion above on the residual values themselves.
    res = [lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
           for x in res]
    out_pvals, out_specs = out_pvals.unpack_aux()
    out_knowns, _, out_consts = pe.partition_pvals(out_pvals)
    res_avals = [typeof(r) for r in res]
    _, out_known_specs = pe.partition_list(out_knowns, out_specs)
    res_specs = [a.nospec(mesh, check_vma, all_names) for a in res_avals]
    # The known shard_map returns the known outputs followed by the residuals.
    new_out_specs = (*out_known_specs, *res_specs)
    ft = out_pvals.map(lambda _: None)
    ans_ft = FlatTree.flatten((out_consts, res))
    # Everything the outer function needs after the bind travels as aux data.
    aux = (in_fwds, out_fwds, out_knowns, res_avals, jaxpr, env, out_specs, new_out_specs, ft)
    return ans_ft.with_aux(aux).with_aux(new_out_specs)

  known_params = dict(mesh=mesh, in_specs=(*known_in_specs,),
                      check_vma=check_vma, manual_axes=manual_axes,
                      debug_info=debug_info.with_unknown_names())
  avals = [typeof(x) for x in in_consts]
  out = shard_map_p.bind_with_trace(trace.parent_trace, tuple(in_consts), avals,
                                    dict(known_params, subfuns=(f_pe,)))
  outs, (in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env, out_specs, new_out_specs, ft) = out.unpack_aux()
  out_consts, non_fwd_res = outs.unflatten()
  assert not jaxpr.constvars
  unk_out_specs, _ = pe.partition_list(out_knowns, out_specs)
  # Full residual list: forwarded known inputs / known outputs substituted in
  # for the residuals that were not returned separately.
  res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
  # TODO make res_avals be the full set, not just the non-fwd ones
  res_avals_iter = iter(res_avals)
  res_specs = []
  for f1, f2 in zip(in_fwd, out_fwd):
    if f1 is not None:
      # Residual is a forwarded known input: reuse that input's spec.
      res_specs.append(known_in_specs[f1])
    elif f2 is not None:
      # Residual is a forwarded known output: reuse that output's spec.
      res_specs.append(new_out_specs[f2])
    else:
      raval = next(res_avals_iter)
      res_specs.append(raval.nospec(mesh, check_vma, all_names))
  env_specs = [_repspec(typeof(e)) for e in env]
  # Staged equation inputs are ordered: residuals, env values, unknown args.
  unk_in_specs = (*res_specs, *env_specs, *unk_in_specs)
  const_tracers = map(trace.new_instantiated_const, res)
  env_tracers = map(trace.to_jaxpr_tracer, env)
  unk_arg_tracers = [t for t in tracers if not t.is_known()]
  out_avals_sharded = [v.aval for v in jaxpr.outvars]
  unk_params = dict(mesh=mesh, in_specs=unk_in_specs,
                    out_specs=tuple(unk_out_specs),
                    jaxpr=jaxpr.replace(debug_info=jaxpr.debug_info.with_unknown_names()),
                    check_vma=check_vma, manual_axes=manual_axes)
  out_avals = map(partial(unshard_aval, mesh, check_vma), unk_out_specs,
                  out_avals_sharded)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                 for a in out_avals]
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers),
                          out_tracers, shard_map_p, unk_params,
                          effs, source_info_util.current())
  for t in out_tracers: t.recipe = eqn
  results = merge_lists(out_knowns, out_tracers, out_consts)
  return ft.update(results)
# Install the shard_map partial-eval rule as the JaxprTrace handler for
# shard_map binds.
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
def _shard_map_linearize(trace, shard_map_p, f: Callable,
                         tracers, mesh, in_specs, check_vma,
                         manual_axes, debug_info):
  """Linearize rule for ``shard_map_p``.

  Binds ``shard_map_p`` twice: once in ``trace.parent_trace`` to compute the
  primal outputs together with the residuals (body ``f_lin``), and once in
  ``trace.tangent_trace`` to evaluate the linearized jaxpr on the residuals,
  the env values and the non-Zero input tangents (body ``f_tangent``).

  Args:
    trace: the linearize trace processing this bind.
    shard_map_p: the primitive to bind.
    f: the shard_map body.
    tracers: the shard_map inputs; each is split into a (primal, tangent) pair.
    mesh: forwarded to both binds.
    in_specs: one spec per entry of ``tracers``.
    check_vma: forwarded to both binds.
    manual_axes: forwarded to both binds.
    debug_info: debug info for ``f``; used via ``with_unknown_names()``.

  Returns:
    ``primals_out.map3(partial(ad.maybe_linearize_tracer, trace), nzs_out,
    tangents_out)``.
  """
  debug_info = debug_info.with_unknown_names()
  primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)

  def f_lin(*primals):
    res, ans_aux, lin_data = ad.linearize_subtrace_2(
        f, trace.is_vjp, trace.tag, nzs_in, debug_info, primals)
    primals_out, out_specs = ans_aux.unpack_aux()
    res_avals, _, _, _, in_fwd, out_fwd = lin_data
    # Only residuals that are not forwarded from an input or output get a spec
    # here.
    res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd)
                 if f1 is None and f2 is None]
    res_specs = [a.nospec(mesh, check_vma, all_names) for a in res_avals]
    # This shard_map returns the residuals followed by the primal outputs.
    new_out_specs = (*res_specs, *out_specs)
    # Residuals with an empty shape get a size-1 axis via lax.broadcast.
    res = [lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
           for x in res]
    res_and_primal = FlatTree.pack((FlatTree.flatten(res), primals_out))
    return res_and_primal.with_aux((lin_data, out_specs)).with_aux(new_out_specs)

  fwd_params = dict(
      mesh=mesh, in_specs=in_specs,
      check_vma=check_vma, manual_axes=manual_axes, debug_info=debug_info)
  avals = [typeof(x) for x in primals]
  all_results_aux = shard_map_p.bind_with_trace(
      trace.parent_trace, tuple(primals), avals, dict(fwd_params, subfuns=(f_lin,)))
  all_results, (lin_data, out_specs) = all_results_aux.unpack_aux()
  res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = lin_data
  non_fwd_res, primals_out = all_results.unpack()
  # Full residual list with forwarded primal inputs / outputs substituted in.
  residuals = subs_list2(in_fwd, out_fwd, primals, (*primals_out,), non_fwd_res)
  # Promote the same residuals in lin_jaxpr that f_lin broadcast above:
  # non-forwarded ones whose aval has an empty shape.
  args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None
                     for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)]
  with (_extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)),
        config._check_vma(check_vma)):
    lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
  res_avals2 = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd)
                if f1 is None and f2 is None]
  res_avals_iter = iter(res_avals2)
  # Forwarded residuals reuse the spec of the input/output they come from.
  res_specs = [in_specs[f1] if f1 is not None else out_specs[f2] if f2 is not None
               else next(res_avals_iter).nospec(mesh, check_vma, all_names)
               for f1, f2 in zip(in_fwd, out_fwd)]
  assert next(res_avals_iter, None) is None
  env_specs = [_repspec(typeof(e)) for e in env]
  # Tangent shard_map inputs are ordered: residuals, env values, non-Zero
  # input tangents.
  new_in_specs = (*res_specs, *env_specs,
                  *(s.to_tangent_spec() for s, nz in zip(in_specs, nzs_in) if nz))
  tangent_out_specs = tuple(s.to_tangent_spec() for s, nz in zip(out_specs, nzs_out) if nz)
  tangent_params = dict(
      mesh=mesh, in_specs=new_in_specs,
      check_vma=check_vma, manual_axes=manual_axes, debug_info=lin_jaxpr.debug_info)

  # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here
  def f_tangent(*args):
    ans = core.eval_jaxpr(lin_jaxpr, (), *args)
    return FlatTree.flatten(ans).with_aux(tangent_out_specs)

  nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
  args = (*residuals, *env, *nz_tangents_in)
  avals = [typeof(x) for x in args]
  nz_tangents_out = shard_map_p.bind_with_trace(
      trace.tangent_trace, args, avals,
      dict(tangent_params, subfuns=(f_tangent,)))
  # Re-insert tangent zeros for the outputs whose tangent is Zero.
  nz_tangents_out_iter = iter(nz_tangents_out)
  tangents_out = [next(nz_tangents_out_iter) if nz else ad.p2tz(primal)
                  for nz, primal in zip(nzs_out, primals_out)]
  return primals_out.map3(partial(ad.maybe_linearize_tracer, trace), nzs_out, tangents_out)
# Install the shard_map linearize rule as the LinearizeTrace handler for
# shard_map binds.
ad.LinearizeTrace.process_shard_map = _shard_map_linearize
def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]):
  """Re-trace ``jaxpr`` so that its flagged residuals arrive in promoted form.

  The constvars of ``jaxpr`` are treated as residuals. The returned jaxpr takes
  them as leading inputs, followed by the original invars. For each residual
  whose entry in ``which`` is true, the input aval is
  ``core.unmapped_aval(1, 0, aval)`` and the value is passed through
  ``_rem_singleton`` before the original jaxpr is evaluated.

  Args:
    jaxpr: the jaxpr whose constvars are the residuals.
    which: one flag per constvar selecting the residuals to promote.

  Returns:
    The re-traced jaxpr, after ``pe.separate_consts``.
  """
  num_res = len(jaxpr.constvars)

  def fun(*res_and_args):
    promoted, args = split_list(res_and_args, [num_res])
    res = []
    for x, w in zip(promoted, which):
      res.append(_rem_singleton(x) if w else x)
    return core.eval_jaxpr(jaxpr, res, *args)

  res_avals = []
  for v, w in zip(jaxpr.constvars, which):
    res_avals.append(core.unmapped_aval(1, 0, v.aval) if w else v.aval)
  arg_avals = [v.aval for v in jaxpr.invars]
  in_avals = FlatTree.flatten(((*res_avals, *arg_avals), {}))

  traced, _ = pe.trace_to_jaxpr(fun, in_avals, debug_info=jaxpr.debug_info)
  traced, _ = pe.separate_consts(traced)
  return traced.jaxpr
def _unmentioned2(mesh: Mesh, spec, manual_axes: frozenset[AxisName]
                  ) -> list[AxisName]:
  """Return the non-spmd manual mesh axis names that ``spec`` does not cover.

  An axis counts as covered when it is in ``_spec_to_vma(spec)`` or in
  ``spec.unreduced``.
  """
  # A filtered-down notion of "unmentioned" is used so that the defensive psum
  # in the transpose-no-check-vma case does not span more chips than required.
  covered = _spec_to_vma(spec) | spec.unreduced
  leftover = []
  for axis in _all_mesh_names_except_spmd(mesh, manual_axes):
    if axis not in covered:
      leftover.append(axis)
  return leftover
def _shard_map_transpose(out_cts, *args,
                         jaxpr: core.Jaxpr, mesh, in_specs, out_specs,
                         check_vma, manual_axes):
  """Transpose rule for ``shard_map_p``.

  Binds a new ``shard_map_p`` whose body runs ``ad.backward_pass`` on ``jaxpr``.

  Args:
    out_cts: cotangents for the shard_map outputs; entries may be ``ad.Zero``.
    *args: the shard_map inputs; linear inputs are ``ad.UndefinedPrimal``.
    jaxpr: the shard_map body to transpose.
    mesh: forwarded to the new bind.
    in_specs: one spec per entry of ``args``.
    out_specs: one spec per entry of ``out_cts``.
    check_vma: forwarded to the new bind; it also disables the manual
      rescaling and psum below when true.
    manual_axes: forwarded to the new bind.

  Returns:
    One cotangent per entry of ``in_specs``; ``ad.Zero`` results are rebuilt
    with an aval from ``unshard_aval``.
  """
  # Divide only when the divisor is not 1.
  mb_div = lambda x, y: x / y if y != 1 else x
  # Zero cotangents get an aval from shard_aval. Others pass through unchanged
  # when check_vma is on or the dtype is float0; otherwise they are divided by
  # the product of the sizes of the axes returned by _unmentioned2.
  out_cts = [
      ad.Zero(shard_aval(mesh, manual_axes, check_vma, sp, x.aval))
      if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0
      else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, sp, manual_axes))))
      for sp, x in zip(out_specs, out_cts)
  ]
  # Undefined primals likewise get an aval from shard_aval.
  args = [x if type(x) is not ad.UndefinedPrimal else
          ad.UndefinedPrimal(shard_aval(mesh, manual_axes, check_vma, sp, x.aval))
          for sp, x in zip(in_specs, args)]
  all_args, in_tree = tree_flatten((out_cts, tuple(args)))

  def fun_trans_callable(*right_flat):
    right_cts, primals_or_undefs = tree_unflatten(in_tree, right_flat)
    left_cts = ad.backward_pass(jaxpr, False, (), primals_or_undefs, right_cts)
    # Without check_vma, non-Zero input cotangents are psum-ed over the axes
    # returned by _unmentioned2 for the matching input spec.
    left_cts = [x if type(x) is ad.Zero or check_vma
                else lax_parallel.psum(x, tuple(_unmentioned2(mesh, sp, manual_axes)))
                for sp, x in zip(in_specs, left_cts)]
    # Output specs cover only cotangents that are neither None nor Zero.
    left_specs_nz = tuple(
        s.to_ct_spec() for ct, s in zip(left_cts, in_specs) if ct is not None
        and type(ct) is not ad.Zero)
    return FlatTree.flatten(left_cts).with_aux(left_specs_nz)

  dbg = jaxpr.debug_info.with_unknown_names()
  # Input specs for the transposed shard_map: non-Zero output cotangents first,
  # then the inputs that are not UndefinedPrimal.
  new_in_specs = (
      [s.to_ct_spec() for s, x in zip(out_specs, out_cts) if type(x) is not ad.Zero] +
      [s for s, x in zip(in_specs, args) if type(x) is not ad.UndefinedPrimal])
  left_ct = shard_map_p.bind(
      *all_args, subfuns=(fun_trans_callable,), mesh=mesh, in_specs=tuple(new_in_specs),
      check_vma=check_vma, manual_axes=manual_axes, debug_info=dbg)
  left_cts = left_ct.unflatten()
  return [ad.Zero(unshard_aval(mesh, check_vma, sp.to_ct_spec(), x.aval))
          if type(x) is ad.Zero else x for sp, x in zip(in_specs, left_cts)]
# Register the transpose rule for the shard_map primitive.
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
# Remat
def _partial_eval_jaxpr_custom_rule(
    saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool],
    inst_in: Sequence[bool], eqn: core.JaxprEqn
  ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool],
             list[core.Var]]:
  """Split a shard_map equation into a known equation and a staged equation.

  Runs ``pe.partial_eval_jaxpr_custom`` on the equation's inner jaxpr, then
  builds two new shard_map equations around the resulting known and staged
  jaxprs. Residuals that are forwarded known inputs or known outputs are not
  emitted as extra outputs of the known equation.

  Args:
    saveable: passed through to ``pe.partial_eval_jaxpr_custom``.
    unks_in: one flag per equation input, true where the input is unknown.
    inst_in: one flag per equation input, passed through to
      ``pe.partial_eval_jaxpr_custom`` and used to select staged inputs.
    eqn: the shard_map equation to split.

  Returns:
    ``(eqn_known, eqn_staged, unks_out, inst_out, new_vars)``, where
    ``new_vars`` lists the ``core.Var`` inputs with a false ``inst_in`` flag,
    the known output binders that are forwarded as residuals, and the new
    residual variables.
  """
  jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh']
  check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes']
  with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))):
    jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
  # The last num_res outputs of jaxpr_known are the residuals.
  num_out_primals = len(jaxpr_known.outvars) - num_res
  # Forwarding info restricted to the residual outputs.
  in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:]
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals])
  # Map each known output variable (by identity) to its index, skipping
  # outputs whose binder is a DropVar.
  idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_binders_known))
             if not isinstance(b, core.DropVar)}
  out_fwd = [idx_map.get(id(v)) for v in res_vars]
  # True for residuals that must be emitted as separate outputs.
  which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
  mesh = eqn.params['mesh']
  with (_extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)),
        config._check_vma(check_vma)):
    jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
    jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
  jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
  jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names)
  ins_known, _ = partition_list(unks_in, eqn.invars)
  _, ins_staged = partition_list(inst_in, eqn.invars)
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  nv = core.gensym()
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)
  lns = lambda a: a.nospec(mesh, check_vma, all_names)
  # For every non-forwarded residual, create a fresh outer variable (with an
  # aval from unshard_aval) and record the spec used for it.
  residuals, staged_in_res_specs = unzip2(
      [(nv(unshard_aval(mesh, check_vma, (rn := lns(var.aval)), var.aval)), rn)
       for var, w in zip(jaxpr_staged.invars[:num_res], which) if w])
  out_res_specs_known = [var.aval.nospec(mesh, check_vma, all_names)  # pyrefly: ignore[missing-attribute]
                         for var, w in zip(res_vars, which) if w]
  params_known, params_staged = _pe_custom_params(
      unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd,
      out_res_specs_known, staged_in_res_specs,
      dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
  eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
                               eqn.primitive, params_known, jaxpr_known.effects,
                               eqn.source_info, eqn.ctx)
  # Full residual list for the staged equation, with forwarded known inputs /
  # outputs substituted in.
  full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals)
  eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged,
                                eqn.primitive, params_staged,
                                jaxpr_staged.effects, eqn.source_info, eqn.ctx)
  assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  # Each forwarded known output is listed once (the set removes duplicates).
  new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}]
  return eqn_known, eqn_staged, unks_out, inst_out, new_inst + list(residuals)
# Register the custom partial-eval rule for the shard_map primitive.
pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
    _partial_eval_jaxpr_custom_rule
def _add_reshapes(which: Sequence[bool],
                  jaxpr_known: core.Jaxpr,
                  jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]:
  """Give residuals with an empty shape a singleton axis between the jaxprs.

  Args:
    which: one flag per residual input of ``jaxpr_staged``, true where the
      residual is produced by ``jaxpr_known`` as a separate output.
    jaxpr_known: jaxpr whose last ``sum(which)`` outputs are those residuals.
    jaxpr_staged: jaxpr whose first ``len(which)`` inputs are the residuals.

  Returns:
    ``(jaxpr_known, jaxpr_staged)``, unchanged if no flagged residual has an
    empty shape, and otherwise re-traced with the reshapes inserted.
  """
  # add singleton axes to residuals which are from jaxpr_known and are scalars
  which_ = [w and not v.aval.shape  # pyrefly: ignore[missing-attribute]
            for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
  if not any(which_): return jaxpr_known, jaxpr_staged
  assert not jaxpr_known.constvars and not jaxpr_staged.constvars

  def known(*args):
    out = core.eval_jaxpr(jaxpr_known, (), *args)
    # The residuals are the trailing sum(which) outputs.
    out_known, res = split_list(out, [len(out) - sum(which)])
    res = [_add_singleton(x) if not x.shape else x for x in res]
    return [*out_known, *res]

  avals_in = tuple(v.aval for v in jaxpr_known.invars)
  avals_in = FlatTree.flatten((avals_in, {}))
  jaxpr_known_closed, _ = pe.trace_to_jaxpr(
      known, avals_in, debug_info=jaxpr_known.debug_info)

  def staged(*args):
    res_, ins = split_list(args, [len(which)])
    # Undo the singleton axis only on the residuals flagged in which_.
    res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
    return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)

  # Input avals of the new staged jaxpr: flagged residuals use
  # core.unmapped_aval(1, 0, aval), everything else keeps its aval.
  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
               for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
  avals_in = (*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]])
  avals_in = FlatTree.flatten((avals_in, {}))
  jaxpr_staged_closed, _ = pe.trace_to_jaxpr(
      staged, avals_in, debug_info=jaxpr_staged.debug_info)
  return jaxpr_known_closed.jaxpr, jaxpr_staged_closed.jaxpr
def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs,
params_known, params_staged):
# prune inputs to jaxpr_known according to unks_in
in_specs_known, _ = partition_list(unks_in, params_known['in_specs'])
_, out_specs_known = partition_list(kept_outs_known, params_known['out_specs'])
out_specs_known = out_specs_known + out_res_specs_known
assert len(out_specs_known) == len(params_known['jaxpr'].outvars)
new_params_known = dict(params_known, in_specs=tuple(in_specs_known),
out_specs=tuple(out_specs_known))
# added num_res new inputs to jaxpr_staged, pruning according to inst_in
_, in_specs_staged = partition_list(inst_in, params_staged['in_specs'])
iter_staged = iter(staged_in_res_specs)
res_specs = [in_specs_known[f1] if f1 is not None else
out_specs_known[f2] if f2 is not None else
next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)]
in_specs_staged = res_specs + in_specs_staged
_, out_specs_staged = partition_list(kept_outs_staged, params_staged['out_specs'])
new_params_staged = dict(params_staged, in_specs=tuple(in_specs_staged),
out_specs=tuple(out_specs_staged))
return new_params_known, new_params_staged
# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names_except_spmd(
mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]:
axis_env = core.get_axis_env()
spmd_names = axis_env.spmd_axis_names
return tuple(name for name in mesh.axis_names
if name not in spmd_names and name in manual_axes)
def _all_newly_manual_mesh_names(
mesh: BaseMesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]:
axis_env = core.get_axis_env()
vmap_spmd_names = set(axis_env.spmd_axis_names)
if not (ctx_mesh := get_abstract_mesh()).empty:
mesh = ctx_mesh
already_manual_names = set(ctx_mesh.manual_axes)
else:
# TODO(mattjj): remove this mechanism when we revise mesh scopes
already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names
return tuple(name for name in mesh.axis_names
if (name not in vmap_spmd_names | already_manual_names and
name in manual_axes))
# DCE
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
                   ) -> tuple[list[bool], core.JaxprEqn | None]:
  """Dead-code-elimination rule for shard_map equations.

  Args:
    used_outputs: one flag per equation output, true where the output is used.
    eqn: the shard_map equation.

  Returns:
    ``(used_inputs, new_eqn)``. ``new_eqn`` is None when the equation can be
    dropped entirely; otherwise it is a shard_map equation over the DCE'd
    inner jaxpr, keeping only the used inputs and outputs and their specs.
  """
  # Nothing is used and the equation has no effects: drop it without looking
  # inside the jaxpr.
  if not (any(used_outputs) or pe.has_effects(eqn)):
    return [False] * len(eqn.invars), None

  params = eqn.params
  mesh = params["mesh"]
  manual_axes = params["manual_axes"]
  check_vma = params["check_vma"]
  with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))):
    jaxpr, used_inputs = pe.dce_jaxpr(params['jaxpr'], used_outputs)

  if not (any(used_inputs) or any(used_outputs) or jaxpr.effects):
    return used_inputs, None

  _, kept_in_specs = partition_list(used_inputs, params['in_specs'])
  _, kept_out_specs = partition_list(used_outputs, params['out_specs'])
  new_params = dict(params, jaxpr=jaxpr, in_specs=tuple(kept_in_specs),
                    out_specs=tuple(kept_out_specs))
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  kept_invars = [v for v, used in zip(eqn.invars, used_inputs) if used]
  kept_outvars = [x for x, used in zip(eqn.outvars, used_outputs) if used]
  new_eqn = pe.new_jaxpr_eqn(
      kept_invars, kept_outvars,
      eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx)
  return used_inputs, new_eqn
# Register the DCE rule for the shard_map primitive.
pe.dce_rules[shard_map_p] = _shard_map_dce
# Mutable arrays / refs
@discharge.register_discharge_rule(shard_map_p)
def _shard_map_discharge(
    in_avals, out_avals, *args, jaxpr, mesh, in_specs, out_specs, check_vma,
    manual_axes):
  """State-discharge rule for ``shard_map_p``.

  Discharges the inner jaxpr and binds a new shard_map whose outputs are the
  original outputs followed by one value per input whose jaxpr invar has an
  ``AbstractRef`` aval. Those extra outputs use the spec of that input.

  Returns:
    ``(new_invals, out_vals)``: ``new_invals`` has one entry per ``in_avals``
    element, which is the value returned for that ref, or None where the aval
    is not an ``AbstractRef``; ``out_vals`` are the original outputs.

  Raises:
    NotImplementedError: if discharging the inner jaxpr produces consts.
  """
  manual_mesh = _as_manual_mesh(mesh, manual_axes)
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(manual_mesh),
        config._check_vma(check_vma)):
    discharged_jaxpr, leftover_consts = discharge.discharge_state(jaxpr, ())
  if leftover_consts:
    raise NotImplementedError

  # Specs of the ref inputs, reused as the specs of the extra outputs.
  ref_specs = []
  for spec, invar in zip(in_specs, jaxpr.invars):
    if isinstance(invar.aval, AbstractRef):
      ref_specs.append(spec)

  bind_params = shard_map_p.get_bind_params(
      dict(jaxpr=discharged_jaxpr, out_specs=(*out_specs, *ref_specs)))
  (body,) = bind_params.pop('subfuns')
  all_outs = shard_map_p.bind(
      *args, subfuns=(body,), mesh=mesh, in_specs=in_specs,
      manual_axes=manual_axes, debug_info=bind_params['debug_info'],
      check_vma=check_vma)
  out_vals, ref_vals = split_list(all_outs, [len(jaxpr.outvars)])

  # Hand each ref value back to the position of its ref input.
  pending_refs = iter(ref_vals)
  new_invals = [next(pending_refs) if isinstance(a, AbstractRef) else None
                for a in in_avals]
  assert next(pending_refs, None) is None
  return new_invals, out_vals
def _repspec(aval):
return aval.nospec(empty_abstract_mesh, False, ())
# ----------------------- top level collectives --------------------------------
def _top_level_ag(x, aval, out_sh_, multi_dim):
  """All-gather a single array ``x`` so that it ends up with ``out_sh_``.

  Args:
    x: the array to gather.
    aval: the aval of ``x``; its sharding's mesh must have all axes explicit.
    out_sh_: the requested output sharding (a PartitionSpec or NamedSharding).
    multi_dim: whether more than one array dimension may be gathered.

  Returns:
    The result of running the gathering shard_map on ``x`` under ``api.jit``.

  Raises:
    ValueError: if ``out_sh_`` canonicalizes to None, if the input and output
      meshes differ, if more than one dimension changes while ``multi_dim`` is
      false, if a dimension that must change is unsharded in the input, or if
      the output spec of a dimension is not a prefix of its input spec.
  """
  assert aval.sharding.mesh.are_all_axes_explicit, aval.sharding.mesh
  out_sh = canonicalize_sharding(out_sh_, "top_level_all_gather")
  if out_sh is None:
    raise ValueError(
        f'out_sharding passed to top_level_all_gather cannot be {out_sh_}. It'
        ' should be a PartitionSpec or NamedSharding.')
  if aval.sharding.mesh != out_sh.mesh:
    raise ValueError(
        f'Input sharding mesh {aval.sharding.mesh} should be equal to'
        f' out_sharding mesh {out_sh.mesh}')
  in_spec = aval.sharding.spec
  out_spec = out_sh.spec._normalized_spec_for_aval(len(in_spec))
  if config.remove_size_one_mesh_axis_from_type.value:
    out_spec = core.remove_size_one_mesh_axis(out_spec, out_sh.mesh)

  def f_shmap(x):
    # Maybe this can just be 1 AG where we gather in a new dim and then do
    # AG(new_dim) -> reshape -> transpose -> reshape but it might be expensive.
    count = 0
    for axis, (i, o) in enumerate(zip(in_spec, out_spec)):
      if i == o:
        continue
      if not multi_dim and count > 0:
        raise ValueError(
            "multiple dimensions cannot be all_gathered since multi_dim=False"
            f" passed to `top_level_all_gather`. Got {in_spec=} and {out_spec=}")
      count += 1
      if i is None:
        raise ValueError(
            f"top_level_all_gather doesn't allow input {aval} to be unsharded"
            f" on dimension {axis} when {out_spec=}.")
      i = i if isinstance(i, tuple) else (i,)
      o = o if o is None or isinstance(o, tuple) else (o,)
      if o is not None and i[:len(o)] != o:
        raise ValueError(
            'top_level_all_gather maintains `top_level_all_gather(x, ...) == x`'
            f" property. The {in_spec=} and {out_spec=} don't satisfy this"
            f' property. Please change your out_spec of array dimension {axis} so'
            " that it's a prefix of in_spec")
      # `o` is a prefix of `i`, so the axes that still have to be gathered are
      # exactly the ones after that prefix. (Slicing with `i[-len(o):]` would
      # pick the last len(o) axes instead, which is only the same thing when
      # len(i) == 2 * len(o).)
      axis_name = i if o is None else i[len(o):]
      x = lax_parallel.all_gather(x, axis_name=axis_name, axis=axis,
                                  tiled=True, to='reduced')
    return x

  return api.jit(shard_map(f_shmap, out_specs=out_spec))(x)
def top_level_all_gather(xs, out_sharding, *, multi_dim: bool = False):
  """All-gather every leaf of the pytree ``xs`` to the matching ``out_sharding``.

  Args:
    xs: a pytree of arrays.
    out_sharding: output shardings, flattened against the tree structure of
      ``xs`` with ``api_util.flatten_axis_resources``.
    multi_dim: forwarded to the per-leaf gather; whether more than one array
      dimension of a leaf may be gathered.

  Returns:
    A pytree with the structure of ``xs`` holding the gathered leaves.

  Raises:
    ValueError: if the context abstract mesh does not have all axes explicit.
  """
  if not get_abstract_mesh().are_all_axes_explicit:
    raise ValueError(
        'top_level_all_gather works when all mesh axes of context mesh are'
        f' explicit. Got {get_abstract_mesh()}')

  leaves, treedef = tree_flatten(xs)
  shardings = api_util.flatten_axis_resources(
      "top_level_all_gather out_sharding", treedef, out_sharding,
      tupled_args=True)
  leaf_avals = [core.typeof(leaf) for leaf in leaves]

  gathered = []
  for leaf, leaf_aval, sharding in zip(leaves, leaf_avals, shardings):
    gathered.append(_top_level_ag(leaf, leaf_aval, sharding, multi_dim))
  return tree_unflatten(treedef, gathered)