2073 lines
94 KiB
Python
2073 lines
94 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Hashable, Sequence, Set
|
|
import enum
|
|
from functools import partial
|
|
import inspect
|
|
import itertools as it
|
|
from math import prod
|
|
import operator as op
|
|
from typing import Any, TypeVar, Union, cast, overload
|
|
|
|
import numpy as np
|
|
|
|
from jax._src import api
|
|
from jax._src import api_util
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src import dispatch
|
|
from jax._src import dtypes
|
|
from jax._src import sharding_impls
|
|
from jax._src import source_info_util
|
|
from jax._src import traceback_util
|
|
from jax._src import util
|
|
from jax._src.core import order_wrt_mesh
|
|
from jax._src.core import pvary, Tracer, typeof, shard_aval, unshard_aval
|
|
from jax._src.mesh import (AbstractMesh, Mesh, BaseMesh, AxisType,
|
|
use_abstract_mesh, get_abstract_mesh,
|
|
get_concrete_mesh, empty_abstract_mesh)
|
|
from jax._src.lax import lax, parallel as lax_parallel
|
|
from jax._src.lib import _jax
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.lib.mlir.dialects import hlo, sdy
|
|
from jax._src.sharding_impls import (NamedSharding, PartitionSpec,
|
|
canonicalize_sharding)
|
|
from jax._src.util import (HashablePartial, unzip2, partition_list, merge_lists,
|
|
split_list, subs_list2, fun_name as util_fun_name)
|
|
from jax._src.state import discharge
|
|
from jax._src.state.types import AbstractRef
|
|
from jax._src.interpreters import batching
|
|
from jax._src.interpreters import mlir
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.interpreters import pxla
|
|
from jax._src.interpreters import ad
|
|
from jax._src.tree_util import (
|
|
broadcast_prefix, keystr, prefix_errors, generate_key_paths, tree_flatten,
|
|
tree_leaves, tree_map, tree_structure, tree_unflatten, KeyPath, PyTreeDef, FlatTree)
|
|
|
|
P = PartitionSpec
|
|
|
|
map, unsafe_map = util.safe_map, map
|
|
zip, unsafe_zip = util.safe_zip, zip
|
|
traceback_util.register_exclusion(__file__)
|
|
|
|
# API
|
|
|
|
Specs = Any # PyTree[PartitionSpec]
|
|
AxisName = Hashable
|
|
|
|
class InferFromArgs:
|
|
def __repr__(self):
|
|
return "jax.sharding.Infer"
|
|
|
|
def __reduce__(self):
|
|
return (_get_default_infer, ())
|
|
|
|
Infer = InferFromArgs()
|
|
|
|
def _get_default_infer():
|
|
return Infer
|
|
|
|
|
|
# Type variables: F is the function passed directly; G is the function handed
# to the returned decorator when the decorator-factory form (`f is None`) is used.
F = TypeVar("F", bound=Callable)
G = TypeVar("G", bound=Callable)
|
|
|
|
|
|
@overload
def shard_map(f: F, /, *, out_specs: Specs,
              in_specs: Specs | None | InferFromArgs = ...,
              mesh: Mesh | AbstractMesh | None = ...,
              axis_names: Set[AxisName] = ..., check_vma: bool = ...) -> F:
  ...

@overload
def shard_map(f: None = None, /, *, out_specs: Specs,
              in_specs: Specs | None | InferFromArgs = ...,
              mesh: Mesh | AbstractMesh | None = ...,
              axis_names: Set[AxisName] = ..., check_vma: bool = ...
              ) -> Callable[[G], G]:
  ...

# See https://github.com/jax-ml/jax/pull/30753 to understand why `in_specs`
# defaults to `Infer`.
def shard_map(f: F | None = None, /, *, out_specs: Specs,
              in_specs: Specs | None | InferFromArgs = Infer,
              mesh: Mesh | AbstractMesh | None = None,
              axis_names: Set[AxisName] = frozenset(), check_vma: bool = True
              ) -> F | Callable[[G], G]:
  """Map a function over shards of data using a mesh of devices.

  See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html.

  Args:
    f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
      takes as input a shard of the mapped-over arguments and produces a shard
      of the output.
    mesh: (optional, default None) a ``jax.sharding.Mesh`` representing the
      array of devices over which to shard the data and on which to execute
      instances of ``f``. The names of the ``Mesh`` can be used in collective
      communication operations in ``f``. If mesh is None, it will be inferred
      from the context which can be set via `jax.set_mesh` context
      manager.
    in_specs: (optional, default `Infer`) a pytree with
      ``jax.sharding.PartitionSpec`` instances as leaves, with a tree structure
      that is a tree prefix of the args tuple to be mapped over. Similar to
      ``jax.sharding.NamedSharding``, each ``PartitionSpec`` represents how the
      corresponding argument (or subtree of arguments) should be sharded along
      the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a
      ``mesh`` axis name at a position expresses sharding the corresponding
      argument array axis along that positional axis; not mentioning an axis
      name expresses replication.
      If ``Infer``, all mesh axes must be of type
      `Explicit`, in which case the in_specs are inferred from the argument types.
      If ``None``, inputs will be treated as static.
    out_specs: a pytree with ``PartitionSpec`` instances as leaves, with a tree
      structure that is a tree prefix of the output of ``f``. Each
      ``PartitionSpec`` represents how the corresponding output shards should be
      concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name
      at a position expresses concatenation of that mesh axis's shards along the
      corresponding positional axis; not mentioning a ``mesh`` axis name
      expresses a promise that the output values are equal along that mesh axis,
      and that rather than concatenating only a single value should be produced.
    axis_names: (optional, default set()) set of axis names from ``mesh`` over
      which the function ``f`` is manual. If empty, ``f``, is manual
      over all mesh axes.
    check_vma: (optional) boolean (default True) representing whether to enable
      additional validity checks and automatic differentiation optimizations.
      The validity checks concern whether any mesh axis names not mentioned in
      ``out_specs`` are consistent with how the outputs of ``f`` are replicated.

  Returns:
    A callable representing a mapped version of ``f``, which accepts positional
    arguments corresponding to those of ``f`` and produces output corresponding
    to that of ``f``.
  """
  # Close over the settings once so both call forms share the same path.
  def apply(fn):
    return _shard_map(fn, mesh=mesh, in_specs=in_specs, out_specs=out_specs,
                      axis_names=axis_names, check_vma=check_vma)
  # `shard_map(f, ...)` maps immediately; `shard_map(out_specs=...)` is a
  # decorator factory.
  return apply if f is None else apply(f)
|
|
|
|
|
|
@overload
def smap(f: F, /, *,
         in_axes: int | None | InferFromArgs | tuple[Any, ...] = ...,
         out_axes: Any, axis_name: AxisName) -> F:
  ...

@overload
def smap(f: None = None, /, *,
         in_axes: int | None | InferFromArgs | tuple[Any, ...] = ...,
         out_axes: Any, axis_name: AxisName
         ) -> Callable[[G], G]:
  ...

def smap(f: F | None = None, /, *,
         in_axes: int | None | InferFromArgs | tuple[Any, ...] = Infer,
         out_axes: Any, axis_name: AxisName
         ) -> F | Callable[[G], G]:
  """Single axis shard_map that maps a function `f` one axis at a time.

  Args:
    f: Callable to be mapped. Each application of ``f``, or "instance" of ``f``,
      takes as input a shard of the mapped-over arguments and produces a shard
      of the output.
    in_axes: (optional) An integer, None, or sequence of values specifying which
      input array axes to map over. If not specified, `smap` will try to infer
      the axes from the arguments only under `Explicit` mode.
      An integer or ``None`` indicates which array axis to map over for all
      arguments (with ``None`` indicating not to map any axis), and a tuple
      indicates which axis to map for each corresponding positional argument.
      Axis integers must be in the range ``[-ndim, ndim)`` for each array,
      where ``ndim`` is the number of dimensions (axes) of the corresponding
      input array.
    out_axes: An integer, None, or (nested) standard Python container
      (tuple/list/dict) thereof indicating where the mapped axis should appear
      in the output.
    axis_name: ``mesh`` axis name over which the function ``f`` is manual.

  Returns:
    A callable representing a mapped version of ``f``, which accepts positional
    arguments corresponding to those of ``f`` and produces output corresponding
    to that of ``f``.
  """
  # Close over the settings once so both call forms share the same path.
  def apply(fn):
    return _smap(fn, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name)
  # `smap(f, ...)` maps immediately; `smap(out_axes=..., ...)` is a
  # decorator factory.
  return apply if f is None else apply(f)
|
|
|
|
def _smap(f: F, *, in_axes: int | None | InferFromArgs | tuple[Any, ...],
          out_axes: Any, axis_name: AxisName) -> F:
  """Implementation of `smap`: validate axes and lower to `_shard_map`.

  Translates vmap-style `in_axes`/`out_axes` axis positions into
  `PartitionSpec`s over the single `axis_name`, then calls `_shard_map`
  with `_smap=True` so errors are phrased in smap terms.
  """
  # smap maps over exactly one mesh axis, so a collection of names is an error.
  if isinstance(axis_name, (list, tuple)):
    raise TypeError(
        f"smap axis_name should be a `str` or a `Hashable`, but got {axis_name}")
  if (in_axes is not None and in_axes is not Infer and
      not isinstance(in_axes, (int, tuple))):
    raise TypeError(
        "smap in_axes must be an int, None, jax.sharding.Infer, or a tuple of"
        " entries corresponding to the positional arguments passed to the"
        f" function, but got {in_axes}.")
  # Leaves of a (possibly nested) in_axes container must be ints; None leaves
  # are dropped by tree_leaves, so `in_axes=None` passes this check.
  if (in_axes is not Infer and
      not all(isinstance(l, int) for l in tree_leaves(in_axes))):
    raise TypeError(
        "smap in_axes must be an int, None, jax.sharding.Infer, or (nested)"
        f" container with those types as leaves, but got {in_axes}.")
  if not all(isinstance(l, int) for l in tree_leaves(out_axes)):
    raise TypeError("smap out_axes must be an int, None, or (nested) container "
                    f"with those types as leaves, but got {out_axes}.")

  # Translate each axis position into a PartitionSpec placing `axis_name` at
  # that position; a None axis becomes P() (unmapped/replicated).
  in_specs = (Infer if in_axes is Infer else
              tree_map(partial(_axes_to_pspec, axis_name), in_axes,
                       is_leaf=lambda x: x is None))
  out_specs = tree_map(partial(_axes_to_pspec, axis_name), out_axes,
                       is_leaf=lambda x: x is None)
  return _shard_map(f, mesh=None, in_specs=in_specs, out_specs=out_specs,
                    axis_names={axis_name}, check_vma=True, _smap=True)
|
|
|
|
|
|
@partial(traceback_util.api_boundary, repro_api_name="jax.shard_map")
def _shard_map(f: F, *, mesh: Mesh | AbstractMesh | None,
               in_specs: Specs, out_specs: Specs, axis_names: Set[AxisName],
               check_vma: bool, _smap: bool = False) -> F:
  """Shared implementation of `shard_map` and `smap`.

  Returns a wrapped callable that, at call time, flattens its args,
  resolves/validates mesh and specs, and binds the `shard_map_p` primitive.
  """
  if not callable(f):
    raise TypeError("shard_map requires a callable for its first argument, "
                    f"but got {f} of type {type(f)}.")

  @util.wraps(f)
  @traceback_util.api_boundary
  def wrapped(*args):
    # Mesh/axis_names are canonicalized lazily at call time (so the ambient
    # context mesh can be picked up) and written back to the closure.
    nonlocal mesh, axis_names
    mesh, axis_names = _shmap_checks(
        mesh, axis_names, in_specs, out_specs, _smap)
    dbg = api_util.debug_info("shard_map", f, args, {})
    args_flat = FlatTree.flatten(args)
    api_util.check_no_transformed_refs_args(lambda: dbg, args_flat)

    # Broadcast the (prefix) in_specs over the args pytree, one spec per leaf.
    try:
      in_specs_flat = broadcast_prefix(
          in_specs, args, is_leaf=lambda x: x is None)
    except ValueError:
      e, *_ = prefix_errors(in_specs, args)
      raise e('shard_map in_specs') from None

    # With `Infer` and all-Explicit manual axes, derive each leaf's spec from
    # the argument's own sharding, restricted to the manual axes.
    if (in_specs is Infer and
        all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
      arg_s = [typeof(a).sharding for a in args_flat]
      assert all(i is Infer for i in in_specs_flat), in_specs_flat
      in_specs_flat = [_manual_spec(axis_names, s.spec, mesh) for s in arg_s]

    # Leaves with a None spec are static: they are closed over by `f_wrapped`
    # below rather than passed through the primitive.
    in_tree = args_flat.tree
    which_dyn = [s is not None for s in in_specs_flat]
    static_args = [x for x, dyn in zip(args_flat, which_dyn) if not dyn]
    dyn_args = [x for x, dyn in zip(args_flat, which_dyn) if dyn]
    in_specs_flat = tuple(s for s, dyn in zip(in_specs_flat, which_dyn) if dyn)
    dyn_argnums = [i for i, dyn in enumerate(which_dyn) if dyn]
    _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums,
                         in_specs_flat, dyn_args)

    # TODO(yashkatariya): Add support for partial manual
    mesh_axis_names_wo_vmap = (
        frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names)
    # Full-manual, all-Explicit case: explicitly passed in_specs must agree
    # with the arguments' actual shardings (inference would have produced
    # exactly those), otherwise the user likely wanted a reshard.
    if (mesh_axis_names_wo_vmap == axis_names and
        all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
      for a, s in zip(dyn_args, in_specs_flat):
        if not isinstance(s, P): continue
        arg_aval = typeof(a)
        s = s._normalized_spec_for_aval(arg_aval.ndim)
        if config.remove_size_one_mesh_axis_from_type.value:
          s = core.remove_size_one_mesh_axis(s, mesh)
        if arg_aval.sharding.spec != s:
          raise ValueError(
              f"in_specs passed to shard_map: {s} does not match the specs of"
              f" the input: {arg_aval.sharding.spec} for arg: {typeof(a)}."
              " `in_specs` is an optional argument so you can omit specifying"
              " it and shard_map will infer the in_specs from the arguments."
              " If you want to reshard your inputs, you can use `jax.reshard`"
              " on the arguments and then pass those args to shard_map.")

    # Dropping static args desyncs the recorded arg names; discard them.
    if (dbg.arg_names is not None and len(dyn_args) != len(dbg.arg_names)):
      dbg = dbg.with_unknown_names()

    def f_wrapped(*dyn_args):
      # Interleave dynamic and static leaves back into their original order.
      dyn_args_iter = iter(dyn_args)
      static_args_iter = iter(static_args)
      all_args = [next(dyn_args_iter) if dyn else next(static_args_iter)
                  for dyn in which_dyn]
      args = tree_unflatten(in_tree, all_args)
      ans = f(*args)
      ans_ft = FlatTree.flatten(ans)
      # Broadcast the (prefix) out_specs over the output pytree.
      try:
        out_specs_flat = tuple(broadcast_prefix(out_specs, ans))
      except ValueError:
        e, *_ = prefix_errors(out_specs, ans)
        raise e('shard_map out_specs') from None
      def add_implicit_pvary_and_unreduced(val, spec):
        # Align each output's varying/unreduced type with its out_spec.
        if not isinstance(spec, P): return val
        aval = typeof(val)
        val = pvary(val, tuple(_spec_to_vma(spec) - aval.mat.varying))
        return (lax_parallel.vary_unreduced_cast(val, tuple(unreduced))
                if (unreduced := spec.unreduced - aval.mat.unreduced) else val)
      if check_vma:
        ans_ft = ans_ft.map2(add_implicit_pvary_and_unreduced, out_specs_flat)
      # The flat out_specs ride along as auxiliary (static) output data.
      return ans_ft.with_aux(out_specs_flat)

    try:
      out_ft = shard_map_p.bind(
          *dyn_args, subfuns=(f_wrapped,), mesh=mesh, in_specs=in_specs_flat,
          check_vma=check_vma, manual_axes=axis_names, debug_info=dbg)
    except _SpecError as e:
      # An out_spec is longer than the corresponding output's rank.
      fails, out_tree = e.args
      msg = _spec_rank_error(SpecErrorType.out, f, out_tree, out_specs, fails)
      if any(fail is not no_fail and not fail.shape for fail in fails):
        msg += (" In particular, for rank 0 outputs which are not constant "
                "over the mesh, add at least one (singleton) axis to them so "
                "that they can be concatenated using out_specs.")
      raise ValueError(msg) from None
    except _RepError as e:
      # An output was varying over mesh axes its out_spec promises replicated.
      fails, out_tree, = e.args
      msg = _inout_vma_error(f, mesh, out_tree, out_specs, fails)
      raise ValueError(msg) from None
    return out_ft.unflatten()
  return cast(F, wrapped)
|
|
|
|
|
|
def _axes_to_pspec(axis_name, axis):
|
|
if axis is None:
|
|
return P()
|
|
return P(*[None] * axis + [axis_name])
|
|
|
|
|
|
def _shmap_checks(mesh, axis_names, in_specs, out_specs, _smap):
  """Resolve and validate shard_map's mesh, axis_names and specs.

  Returns the canonicalized `(mesh, axis_names)` pair, with `axis_names` as a
  frozenset. Raises TypeError/ValueError on invalid combinations; `_smap`
  only changes the wording of the `Infer` error message.
  """
  if mesh is None:
    # No mesh passed: fall back to the context mesh set via `jax.set_mesh`.
    mesh = get_abstract_mesh()
    if mesh.empty:
      raise ValueError(
          "The context mesh cannot be empty. Use"
          " `jax.set_mesh(mesh)` to enter into a mesh context")
  else:
    # A mesh was passed explicitly: it must agree with any context mesh.
    ctx_mesh = get_abstract_mesh()
    if not ctx_mesh.empty and mesh.abstract_mesh != ctx_mesh:
      raise ValueError(
          f"The context mesh {ctx_mesh} should match the mesh passed to"
          f" shard_map {mesh}")

  if not isinstance(mesh, (Mesh, AbstractMesh)):
    raise TypeError("shard_map requires a `jax.sharding.Mesh` or a "
                    "`jax.sharding.AbstractMesh` instance for its "
                    f"second argument, but got {mesh} of type {type(mesh)}.")
  if mesh.empty:
    raise ValueError(f"shard_map requires a non-empty mesh. Got {mesh}")

  # Mesh axes already recorded in the ambient axis env (e.g. by an enclosing
  # vmap) are excluded from the axes shard_map may treat as manual.
  mesh_axis_names_wo_vmap = (
      frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names
  )

  if not isinstance(axis_names, (frozenset, set)):
    raise TypeError(
        "`axis_names` argument of shard_map should be of type `frozenset` or"
        f" `set`. Got type: {type(axis_names)}")
  if isinstance(axis_names, set):
    axis_names = frozenset(axis_names)
  if not axis_names:
    # Empty axis_names means "manual over all (non-excluded) mesh axes".
    axis_names = mesh_axis_names_wo_vmap
  if not axis_names.issubset(mesh_axis_names_wo_vmap):
    raise ValueError(
        f"jax.shard_map requires axis_names={axis_names} to be a subset of "
        f"mesh.axis_names={mesh_axis_names_wo_vmap}")

  # `Infer` in_specs only works when every manual axis is of type Explicit.
  if (in_specs is Infer and
      not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)):
    axis_types = ', '.join(str(mesh._name_to_type[a]) for a in axis_names)
    if _smap:
      msg = (f"in_axes was not specified when axis_name={axis_names} was of"
             f" type {axis_types}")
    else:
      msg = ("shard_map in_specs argument must be a pytree of"
             " `jax.sharding.PartitionSpec` instances, but it was `None` when"
             f" {axis_names=} are of type {axis_types}")
    raise TypeError(msg)

  # Validate the specs themselves (out_specs always; in_specs only when given).
  if in_specs is not Infer and in_specs is not None:
    _check_specs(SpecErrorType.input, in_specs, axis_names)
    _check_unreduced(SpecErrorType.input, mesh, axis_names, in_specs)
  _check_specs(SpecErrorType.out, out_specs, axis_names)
  _check_unreduced(SpecErrorType.out, mesh, axis_names, out_specs)
  return mesh, axis_names
|
|
|
|
|
|
def _manual_spec(manual_axes, spec: P, mesh) -> P:
  """Restrict `spec` to `manual_axes`, replacing non-manual names with None.

  Tuple entries keep only their manual names; trailing Nones inside a tuple
  are trimmed, and an interior None (a gap) is an error. unreduced/reduced
  are validated and carried over unchanged.
  """
  new_entries: list[str | tuple[str | None, ...] | None] = []
  entry: str | None | tuple[str, ...]
  for entry in spec:
    if entry is None:
      new_entries.append(None)
    elif isinstance(entry, tuple):
      kept = [name if name in manual_axes else None for name in entry]
      # Trim trailing Nones; an interior None would leave a gap in the tuple.
      while kept and kept[-1] is None:
        kept.pop()
      if None in kept:
        raise ValueError(f"Invalid spec: {spec}")
      new_entries.append(tuple(kept) if kept else None)
    else:
      new_entries.append(entry if entry in manual_axes else None)
  _check_unreduced(SpecErrorType.input, mesh, manual_axes, spec)
  return P(*new_entries, unreduced=spec.unreduced, reduced=spec.reduced)
|
|
|
|
|
|
# Error checking and messages
|
|
|
|
# Distinguishes whether a spec error concerns in_specs or out_specs, so error
# messages can name the right argument.
SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out'])
|
|
|
|
def _check_unreduced(error_type, mesh, manual_axes, specs):
  """Validate any `unreduced`/`reduced` entries in `specs`.

  unreduced/reduced are only allowed in full-manual mode and only over mesh
  axes of type Explicit; raises otherwise. Specs without such entries pass
  through without checks.
  """
  from jax._src.hijax import HiPspec
  prefix = 'in' if error_type == SpecErrorType.input else 'out'
  full_manual = frozenset(mesh.axis_names) == manual_axes
  specs_flat, _ = tree_flatten(specs)
  for s in specs_flat:
    if isinstance(s, HiPspec):
      continue  # TODO(mattjj,yashkatariya): add user validation method
    if not s.unreduced and not s.reduced:
      continue
    if not full_manual:
      raise NotImplementedError(
          f"unreduced/reduced can only be passed to {prefix}_specs when"
          " shard_map is in full manual mode. Got mesh axis names"
          f" {mesh.axis_names}, manual_axes: {manual_axes}, specs: {s}. Please"
          " file a bug at https://github.com/jax-ml/jax/issues.")
    if not all(mesh._name_to_type[u] == AxisType.Explicit for u in s.unreduced):
      raise ValueError(
          f"unreduced in {prefix}_specs {s} can only be used when the mesh"
          " passed to shard_map contains axis names all of type `Explicit`."
          f" Got mesh {mesh}")
    if not all(mesh._name_to_type[u] == AxisType.Explicit for u in s.reduced):
      raise ValueError(
          f"reduced in {prefix}_specs {s} can only be used when the mesh"
          " passed to shard_map contains axis names all of type `Explicit`."
          f" Got mesh {mesh}")
|
|
|
|
|
|
def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None:
  """Check that every leaf of `specs` is a PartitionSpec naming only manual axes.

  Raises TypeError for non-PartitionSpec leaves (or None in_specs), and
  ValueError for specs that mention axes outside `manual_axes`.
  """
  from jax._src.hijax import HiPspec
  if error_type == SpecErrorType.input and specs is None:
    raise TypeError(
        "shard_map in_specs argument must be a pytree of "
        "`jax.sharding.PartitionSpec` instances, but it was None.\n"
        "Instead of `in_specs=None`, did you mean `in_specs=P()`, "
        "where `P = jax.sharding.PartitionSpec`?")

  def check_spec(p):
    # Returns False both for non-PartitionSpec leaves and for specs that
    # mention a non-manual axis; the two cases are distinguished below.
    if isinstance(p, HiPspec):
      return True  # TODO(mattjj,yashkatariya): add user validation method
    if not isinstance(p, PartitionSpec):
      return False
    for names in p:
      names = (names,) if not isinstance(names, tuple) else names
      for name in names:
        if name is not None and name not in manual_axes:
          return False
    return True

  if all(check_spec(p) for p in tree_leaves(specs)):
    return
  prefix = 'in' if error_type == SpecErrorType.input else 'out'
  # First collect any leaves that are not PartitionSpecs at all.
  msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, "
          for key, x in generate_key_paths(specs) if not isinstance(x, P)]
  if not msgs:
    # All leaves were PartitionSpecs, so the failure must be a reference to a
    # non-manual axis; collect those references instead.
    for key, p in generate_key_paths(specs):
      for names in p:
        names = (names,) if not isinstance(names, tuple) else names
        for name in names:
          if name is not None and name not in manual_axes:
            msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}")
    raise ValueError(
        f"shard_map {prefix}_specs argument must refer to an axis "
        f"marked as manual ({manual_axes}), but:\n\n"
        + '\n\n'.join(msgs) + '\n\n'
        f"Check the {prefix}_specs values passed to shard_map.")
  raise TypeError(
      f"shard_map {prefix}_specs argument must be a pytree of "
      f"`jax.sharding.PartitionSpec` instances, but:\n\n"
      + '\n\n'.join(msgs) + '\n\n'
      f"Check the {prefix}_specs values passed to shard_map.")
|
|
|
|
class NoFail:
  """Per-leaf sentinel marking "no failure at this position" in error lists."""

  def __repr__(self) -> str:
    return "NoFail()"


# Shared singleton; failure lists compare against it with `is`.
no_fail = NoFail()
|
|
|
|
def _check_specs_vs_args(
    f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs,
    dyn_argnums: Sequence[int], in_specs_flat: Sequence[P],
    xs: Sequence) -> None:
  """Check each arg against its spec: rank compatibility and divisibility.

  Raises ValueError (via _spec_rank_error / _spec_divisibility_error) when a
  spec has more entries than the arg's rank, or when a mapped dimension is
  not evenly divisible by the product of its mesh axis sizes.
  """
  in_avals = map(core.shaped_abstractify, xs)
  # Rank check: a spec may not have more entries than the value has axes.
  fail = [a if isinstance(p, P) and len(p) > a.ndim else no_fail
          for p, a in zip(in_specs_flat, in_avals)]
  if any(f is not no_fail for f in fail):
    fail = _expand_fail(in_tree, dyn_argnums, fail)
    msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail)
    raise ValueError(msg)
  # Divisibility check: nonzero remainder means the dimension can't be split
  # evenly across the named mesh axes.
  bad = lambda a, d, ns: a.shape[d] % prod(mesh.shape[n] for n in ns)
  fail = [a if (isinstance(s, P) and
                any(bad(a, d, ns) for d, ns in _spec_to_names(s).items()))
          else no_fail for a, s in zip(in_avals, in_specs_flat)]
  if any(f is not no_fail for f in fail):
    fail = _expand_fail(in_tree, dyn_argnums, fail)
    msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail)
    raise ValueError(msg)
|
|
|
|
def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int],
                 fail: Sequence[core.ShapedArray | NoFail]
                 ) -> list[core.ShapedArray | NoFail]:
  """Scatter per-dynamic-arg failures back to full-leaf positions.

  Positions not in `dyn_argnums` (static leaves) are filled with `no_fail`.
  """
  by_position = dict(zip(dyn_argnums, fail))
  return [by_position.get(i, no_fail) for i in range(in_tree.num_leaves)]
|
|
|
|
def _spec_rank_error(
    error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs,
    fails: list[core.ShapedArray | NoFail]) -> str:
  """Build the error message for a spec longer than its value's rank."""
  fun_name = util_fun_name(f)
  if error_type == SpecErrorType.input:
    prefix, base = 'in', 'the passed args'
    # Try to bind dummy args to f's signature so the message can name the
    # offending parameter; ba is None if binding fails.
    ba = _try_infer_args(f, tree)
  else:
    prefix, base = 'out', f'{fun_name}(*args)'
    ba = None
  msgs = []
  for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
    extra = ""
    if error_type == SpecErrorType.input and ba is not None:
      arg_key, *_ = fail_key
      # Positional parameters only: drop keyword-only and **kwargs entries.
      param_names, params = unzip2(
          (name, param) for name, param in ba.signature.parameters.items()
          if param.kind not in (inspect.Parameter.KEYWORD_ONLY,
                                inspect.Parameter.VAR_KEYWORD))
      if (arg_key.idx >= len(params) or
          params[arg_key.idx].kind == inspect.Parameter.VAR_POSITIONAL):
        # The leaf landed in *args: report its index within the varargs.
        extra = (f", where args{arg_key} is the index "
                 f"{arg_key.idx - len(params) + 1} component "
                 f"of {fun_name}'s varargs parameter '{param_names[-1]}',")
      else:
        # NOTE(review): params[arg_key.idx] is an inspect.Parameter whose str()
        # includes any default (e.g. "x=3"); param_names[arg_key.idx] may be
        # the intended value here — confirm.
        param_name = params[arg_key.idx]
        extra = (f", where args{arg_key} is bound to {fun_name}'s "
                 f"parameter '{param_name}',")
    msgs.append(
        f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length "
        f"{len(spec)}, but "
        f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, "
        f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})")
  assert msgs
  if len(msgs) == 1: msgs = [msgs[0][2:]]  # remove the bullet point
  msg = (f"shard_map applied to the function '{fun_name}' was given an "
         f"{prefix}_specs entry which is too long to be compatible with the "
         f"corresponding {prefix}put value from the function:\n\n"
         + '\n\n'.join(msgs) + '\n\n' +
         f"Entries in {prefix}_specs must be of length no greater than the "
         f"number of axes in the corresponding {prefix}put value.\n\n"
         f"Either revise the spec to be shorter, or modify '{fun_name}' so "
         f"that its {prefix}puts have sufficient rank.")
  # Special hint for scalar (rank-0) failures.
  if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)):
    msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs "
            "entry of `P()`, where `P = jax.sharding.PartitionSpec`.")
  return msg
|
|
|
|
def _spec_divisibility_error(
    f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs,
    fails: list[core.ShapedArray | NoFail]) -> str:
  """Build the error message for arg dims not divisible by mesh axis sizes."""
  ba = _try_infer_args(f, tree)
  fun_name = getattr(f, '__name__', str(f))
  msgs = []
  for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails):
    extra = ""
    if ba is not None:
      arg_key, *_ = fail_key
      # Positional parameters only: drop keyword-only and **kwargs entries.
      param_names, params = unzip2(
          (name, param) for name, param in ba.signature.parameters.items()
          if param.kind not in (inspect.Parameter.KEYWORD_ONLY,
                                inspect.Parameter.VAR_KEYWORD))
      if (arg_key.idx >= len(params) or
          params[arg_key.idx].kind == inspect.Parameter.VAR_POSITIONAL):
        # The leaf landed in *args: report its index within the varargs.
        extra = (f", where args{arg_key} is the index "
                 f"{arg_key.idx - len(params) + 1} component "
                 f"of {fun_name}'s varargs parameter '{param_names[-1]}',")
      else:
        # NOTE(review): params[arg_key.idx] is an inspect.Parameter whose str()
        # includes any default (e.g. "x=3"); param_names[arg_key.idx] may be
        # the intended value here — confirm.
        param_name = params[arg_key.idx]
        extra = (f", where args{arg_key} is bound to {fun_name}'s "
                 f"parameter '{param_name}',")
    names = _spec_to_names(spec)
    for d, ns in names.items():
      # Report each array axis whose size doesn't divide evenly over the
      # combined size of the mesh axes sharding it.
      if aval.shape[d] % prod(mesh.shape[n] for n in ns):
        axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'"
        total = 'total ' if len(ns) > 1 else ''
        sz = prod(mesh.shape[n] for n in ns)
        msgs.append(
            f"* the passed args{keystr(fail_key)} of shape {aval.str_short()}{extra} "
            f"corresponds to in_specs{keystr(spec_key)} of value {spec}, "
            f"which maps array axis {d} (of size {aval.shape[d]}) to mesh "
            f"{axis} (of {total}size {sz}), but {sz} does not evenly divide "
            f"{aval.shape[d]}")
  assert msgs
  if len(msgs) == 1: msgs = [msgs[0][2:]]  # remove the bullet point
  msg = (f"shard_map applied to the function '{fun_name}' was given argument "
         f"arrays with axis sizes that are not evenly divisible by the "
         f"corresponding mesh axis sizes:\n\n"
         f"The mesh given has shape {tuple(mesh.shape.values())} with "
         f"corresponding axis names {mesh.axis_names}.\n\n"
         + '\n\n'.join(msgs) + '\n\n' +
         f"Array arguments' axis sizes must be evenly divisible by the mesh "
         f"axis or axes indicated by the corresponding elements of the "
         f"argument's in_specs entry. Consider checking that in_specs are "
         f"correct, and if so consider changing the mesh axis sizes or else "
         f"padding the input and adapting '{fun_name}' appropriately.")
  return msg
|
|
|
|
def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef,
                     specs: Specs, fails: list[core.ManualAxisType | NoFail]
                     ) -> str:
  """Build the error message for outputs varying over axes out_specs omit."""
  fun_name = getattr(f, '__name__', str(f))
  msgs = []
  for (spec_key, spec), (fail_key, mat) in _iter_paths(tree, specs, fails):
    unmentioned = _unmentioned(mesh, spec)
    if len(unmentioned) > 1:
      # Order axis-name sets wrt the mesh for stable, readable messages.
      need_vma = ','.join(map(str, order_wrt_mesh(mesh, _spec_to_vma(spec))))
      got_vma = ','.join(map(str, order_wrt_mesh(mesh, mat.varying)))
      diff = ','.join(map(str, order_wrt_mesh(
          mesh, [n for n in unmentioned if n in mat.varying])))
      msgs.append(
          f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
          f"corresponding output value is only varying across mesh axes "
          f"{{{need_vma}}} and not {{{diff}}}, but it was inferred to be "
          f"possibly varying over {{{got_vma}}}")
    else:
      # Exactly one unmentioned axis: simpler single-axis phrasing.
      need_rep_, = unmentioned
      msgs.append(
          f"* out_specs{keystr(spec_key)} is {spec} which implies that the "
          f"corresponding output value is replicated across mesh axis "
          f"'{need_rep_}', but could not infer replication over any axes")
  assert msgs
  if len(msgs) == 1: msgs = [msgs[0][2:]]  # remove the bullet point
  msg = (f"shard_map applied to the function '{fun_name}' was given "
         f"out_specs which require replication which can't be statically "
         f"inferred given the mesh:\n\n"
         f"The mesh given has shape {tuple(mesh.shape.values())} with "
         f"corresponding axis names {mesh.axis_names}.\n\n"
         + '\n\n'.join(msgs) + '\n\n' +
         "Check if these output values are meant to be replicated over those "
         "mesh axes. If not, consider revising the corresponding out_specs "
         "entries. If so, consider disabling the check by passing the "
         "check_vma=False argument to `jax.shard_map`.")
  return msg
|
|
|
|
def _unmentioned(mesh: Mesh | AbstractMesh, spec) -> list[AxisName]:
  """Return mesh axis names not in `_spec_to_mat(spec).vur`, in mesh order."""
  mentioned = _spec_to_mat(spec).vur
  unmentioned = []
  for name in mesh.axis_names:
    if name not in mentioned:
      unmentioned.append(name)
  return unmentioned
|
|
|
|
|
|
def _try_infer_args(f, tree):
|
|
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
|
|
try:
|
|
return inspect.signature(f).bind(*dummy_args)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
T = TypeVar('T')

def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail]
                ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]:
  """Pair each failing leaf with the (prefix) spec entry that covers it.

  Returns ((spec_key_path, spec), (fail_key_path, fail_data)) pairs for each
  leaf whose entry is not `no_fail`.
  """
  failures = tree_unflatten(tree, fails)
  failures_aug = generate_key_paths(failures)
  # Wrap each (key_path, spec) pair in Tup so broadcast_prefix treats the pair
  # as a single leaf while broadcasting the spec prefix over the failure tree.
  specs_ = tree_unflatten(tree_structure(specs), map(Tup, generate_key_paths(specs)))
  specs_aug = broadcast_prefix(specs_, failures, is_leaf=lambda x: x is None)
  return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data)
          in zip(specs_aug, failures_aug)
          if s is not None and fail_data is not no_fail]
|
|
|
|
class Tup:
  """Opaque iterable wrapper so a (key_path, value) pair acts as one pytree leaf."""

  def __init__(self, vals):
    self.vals = vals

  def __iter__(self):
    return iter(self.vals)
|
|
|
|
# Primitive
|
|
|
|
# Loose type aliases: JaxType is any JAX-compatible value; MaybeTracer
# additionally admits Tracers seen during transformations.
JaxType = Any
MaybeTracer = Union[JaxType, Tracer]
|
|
|
|
class ShardMapPrimitive(core.Primitive):
  """The shard_map primitive; each trace handles it via `process_shard_map`."""
  multiple_results = True
  skip_canonicalization = True

  def bind_with_trace(self, trace, args, avals, params, /):
    # Pull the traced callable out of params before dispatching to the trace.
    fun, = params.pop('subfuns')
    # fun returns a FlatTree containing a tuple of the user-level data and a flat,
    # broadcasted out_specs wrapped in `Static`.
    # The result of `bind_with_trace` is a `FlatTree` of tracer-like things and
    # doesn't include the `Static` out_specs.
    return trace.process_shard_map(shard_map_p, fun, args, **params)

  def get_bind_params(self, params):
    # Reconstruct `subfuns` from the staged jaxpr so the primitive can be
    # re-bound (the inverse of the staging that consumed `subfuns`).
    new_params = dict(params)
    jaxpr = new_params.pop('jaxpr')
    assert isinstance(jaxpr, core.Jaxpr)
    axes = new_params.pop('out_specs')
    def eval_jaxpr(*args):
      result = core.eval_jaxpr(jaxpr, (), *args)
      # Reattach out_specs as auxiliary data, matching f_wrapped's contract.
      return FlatTree.flatten(result).with_aux(axes)
    new_params['subfuns'] = (eval_jaxpr,)
    new_params['debug_info'] = jaxpr.debug_info
    return new_params


shard_map_p = ShardMapPrimitive('shard_map')
|
|
|
|
# Staging
|
|
|
|
@util.cache(max_size=256, trace_context_in_key=False)
def _as_manual_mesh(mesh, manual_axes: frozenset) -> AbstractMesh:
  """Return mesh's abstract mesh with `manual_axes` re-typed as Manual (cached)."""
  return mesh.abstract_mesh.update_axis_types(
      {n: AxisType.Manual for n in manual_axes})
|
|
|
|
def _extend_axis_env(mesh, manual_axes):
  """Context manager extending the axis env with the manual axes of `mesh`."""
  frame = [(name, size) for name, size in mesh.shape.items()
           if name in manual_axes]
  return core.extend_axis_env_nd(frame)
|
|
|
|
def _shard_map_staging(
    trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: Callable,
    args: Sequence[Any], *, mesh: Mesh,
    in_specs, check_vma: bool, manual_axes: frozenset, debug_info
) -> FlatTree:
  """Staging rule for shard_map under a DynamicJaxprTrace (i.e. inside jit).

  Traces `f` on per-shard avals inside a manual-axes mesh context, checks the
  resulting out_specs, and emits a shard_map equation into the enclosing
  jaxpr frame. Returns the outputs as a FlatTree matching `f`'s output tree.
  """
  source_info = source_info_util.current()

  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  in_avals = [typeof(arg) for arg in args]
  # Shard the outer (global) avals to the per-device-block view seen by `f`.
  in_avals_ = map(partial(shard_aval, mesh, manual_axes, check_vma),
                  in_specs, in_avals)
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    in_avals_flat_tree = FlatTree.flatten((in_avals_, {}))
    jaxpr, out_data = pe.trace_to_jaxpr(
        f, in_avals_flat_tree, debug_info, fun_returns_flat_tree=True,
        requires_low=trace.requires_low)
    # out_data carries the user outputs plus the broadcast out_specs as aux.
    out_avals_ft, out_specs = out_data.unpack_aux()
    _check_names(out_specs, out_avals_ft)
    if check_vma:
      _check_mats(mesh, out_specs, out_avals_ft)
    # Unshard back to the outer (global) view for the enclosing trace.
    out_avals = [unshard_aval(mesh, check_vma, spec, aval)
                 for spec, aval in zip(out_specs, out_avals_ft)]
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    jaxpr, consts = pe.separate_consts(jaxpr)
  # Hoisted consts become extra (replicated-spec) inputs of the equation.
  in_specs_staged = (*(_repspec(typeof(c)) for c in consts), *in_specs)
  if trace.requires_low:
    # Lower hi-level specs/avals to their low-level expansions.
    in_specs_staged = tuple(lo_spec for hi_spec in in_specs_staged for lo_spec in hi_spec.to_lo())
    out_specs = tuple(lo_spec for hi_spec in out_specs for lo_spec in hi_spec.to_lo())
  params = dict(mesh=mesh, in_specs=in_specs_staged,
                out_specs=out_specs, jaxpr=jaxpr.jaxpr,
                check_vma=check_vma, manual_axes=manual_axes)
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info)
  const_tracers = map(to_jaxpr_tracer, consts)
  trace.frame.is_high |= jaxpr.is_high
  if trace.requires_low:
    in_tracers = [to_jaxpr_tracer(loval) for arg in args
                  for loval in typeof(arg).lower_val(arg)]
    out_avals_lo = [lo_aval for aval in out_avals for lo_aval in aval.lo_ty()]
  else:
    in_tracers = map(to_jaxpr_tracer, args)
    out_avals_lo = out_avals
  out = trace.emit_eqn([*const_tracers, *in_tracers], out_avals_lo, prim, params,
                       effs, source_info)
  if trace.requires_low:
    out = pe.raise_lo_outs(out_avals, out)
  # Reattach the output tree structure recorded during tracing.
  return out_avals_ft.update(out)
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
|
|
|
|
# TODO add underscore version, for direct-linearize to consume
|
|
|
|
def _spec_to_names(spec: PartitionSpec):
|
|
return {i: names if isinstance(names, tuple) else (names,)
|
|
for i, names in enumerate(spec) if names is not None}
|
|
|
|
def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma,
                        spec, aval: core.ShapedArray) -> core.ShapedArray:
  """Compute the per-shard aval of `aval` when partitioned by `spec`.

  Divides each sharded dimension by the product of its mesh axis sizes and
  switches the aval's sharding onto the manual mesh. Raises ValueError if the
  spec's unreduced/reduced sets disagree with the aval's.
  """
  assert isinstance(aval, core.ShapedArray)
  if spec.unreduced != aval.sharding.spec.unreduced:
    raise ValueError(
        f"in_specs containing unreduced {spec} passed to shard_map should be"
        " equal to the unreduced present on the in_aval"
        f" {aval.str_short(True)}")
  if spec.reduced != aval.sharding.spec.reduced:
    raise ValueError(
        f"in_specs containing reduced {spec} passed to shard_map should be"
        f" equal to the reduced present on the in_aval {aval.str_short(True)}")
  names = _spec_to_names(spec)
  # Each dimension shrinks by the product of the mesh axes sharding it.
  new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
                    for i, sz in enumerate(aval.shape))
  manual_mesh = _as_manual_mesh(mesh, manual_axes)
  new_sharding = aval.sharding.update(
      mesh=manual_mesh,
      spec=core.modify_spec_for_auto_manual(aval.sharding.spec, manual_mesh))
  # Axes mentioned in the spec become varying on the shard; pre-existing
  # varying axes are kept.
  vma = (_spec_to_vma(spec) if check_vma else frozenset()) | aval.mat.varying
  unreduced = aval.sharding.spec.unreduced if check_vma else frozenset()
  reduced = aval.sharding.spec.reduced if check_vma else frozenset()
  mat = core.ManualAxisType(varying=vma, unreduced=unreduced, reduced=reduced)
  return aval.update(shape=new_shape, sharding=new_sharding,
                     manual_axis_type=mat)
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array
|
|
|
|
def _unshard_shaped_array(mesh: Mesh, check_vma, spec, aval: core.ShapedArray
                          ) -> core.ShapedArray:
  """Compute the global aval from a per-shard `aval` and its out `spec`.

  Inverse of `_shard_shaped_array`: multiplies sharded dimensions back up,
  merges the spec's names with any sharding remaining on the shard aval, and
  drops manual varying axes that are no longer manual in the outer mesh.
  """
  assert isinstance(aval, core.ShapedArray)
  if check_vma and spec.unreduced != aval.mat.unreduced:
    raise ValueError(
        "out_specs passed to shard_map should be equal to the unreduced"
        f" present on the out_aval. Got out_specs={spec} and"
        f" out_aval={aval.str_short(True)}")
  if check_vma and spec.reduced != aval.mat.reduced:
    raise ValueError(
        "out_specs passed to shard_map should be equal to the reduced present"
        f" on the out_aval. Got out_specs={spec} and"
        f" out_aval={aval.str_short(True)}")
  names = _spec_to_names(spec)
  # Each dimension grows by the product of the mesh axes that shard it.
  new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
                    for i, sz in enumerate(aval.shape))
  names_spec = spec._normalized_spec_for_aval(aval.ndim)
  if aval.ndim == 0:
    out_spec = P(unreduced=spec.unreduced, reduced=spec.reduced)
  else:
    # Per dimension, merge the out_spec entry with any residual sharding on
    # the shard aval; both present means concatenating the axis-name tuples.
    out_spec = []
    for name_s, aval_s in zip(names_spec, aval.sharding.spec):
      if name_s and not aval_s:
        out_spec.append(name_s)
      elif aval_s and not name_s:
        out_spec.append(aval_s)
      elif not name_s and not aval_s:
        out_spec.append(None)
      else:
        assert name_s and aval_s
        name_s = name_s if isinstance(name_s, tuple) else (name_s,)
        aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,)
        out_spec.append(name_s + aval_s)
    out_spec = PartitionSpec(*out_spec, unreduced=spec.unreduced,
                             reduced=spec.reduced)
  # Prefer the ambient abstract mesh when one is installed.
  new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
              get_abstract_mesh())
  new_sharding = NamedSharding(new_mesh, out_spec)
  manual_axes = set(new_mesh.manual_axes)
  # Keep only varying axes that are still manual in the outer mesh.
  vma = frozenset(v for v in aval.mat.varying if v in manual_axes)
  # TODO(yashkatariya): Handle partial manual unreduced/reduced.
  out_mat = core.ManualAxisType(varying=vma)
  return aval.update(shape=new_shape, sharding=new_sharding,
                     manual_axis_type=out_mat)
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array
|
|
|
|
# Type-checking
|
|
|
|
def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_specs, out_specs,
                         check_vma, manual_axes):
  """Jaxpr type rule for shard_map_p.

  Verifies that sharded input avals match the jaxpr binders, that the inner
  jaxpr typechecks under the extended axis env, and (when check_vma) that
  outputs are provably replicated over unmentioned axes. Returns the global
  output avals and the filtered effects.
  """
  # TODO(mattjj,parkers): check auto
  for v, x, in_spec in zip(jaxpr.invars, in_atoms, in_specs):
    sharded_aval = shard_aval(mesh, manual_axes, check_vma, in_spec, x.aval)
    if not core.typecompat(v.aval, sharded_aval):
      raise core.JaxprTypeError("shard_map argument avals not compatible with "
                                "jaxpr binder avals and in_specs")
  with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
    core.check_jaxpr(jaxpr)
  if check_vma:
    for v, os in zip(jaxpr.outvars, out_specs):
      if isinstance(os, P) and not _valid_repeats(mesh, v.aval.mat, os):
        raise core.JaxprTypeError(
            "shard_map can't prove output is sufficiently replicated")
  out_avals_sharded = [x.aval for x in jaxpr.outvars]
  out_avals = map(partial(unshard_aval, mesh, check_vma), out_specs,
                  out_avals_sharded)
  # Named-axis effects over this mesh's axes are resolved inside the body.
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  return out_avals, effs
core.custom_typechecks[shard_map_p] = _shard_map_typecheck
|
|
|
|
|
|
def _valid_repeats(mesh: Mesh, mat: core.ManualAxisType, spec) -> bool:
  """True iff `mat` proves replication over every axis `spec` leaves unmentioned."""
  unmentioned = set(_unmentioned(mesh, spec)) - set(mesh.manual_axes)
  # Valid exactly when no unmentioned axis is varying/unreduced/reduced.
  return unmentioned.isdisjoint(mat.vur)
|
|
|
|
# Lowering
|
|
|
|
def _shardy_shard_map_sharding(
    ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in
) -> sharding_impls.SdyArray:
  """Build the Shardy sharding attr for one shard_map operand/result.

  Extended dtypes are first converted to their physical representation. When
  only some mesh axes are manual, the dim shardings are marked open so the
  partitioner may further shard the free (auto) axes.
  """
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_in, ns)
    aval_in = core.physical_aval(aval_in)
  sdy_sharding = ns._to_sdy_sharding(aval_in.ndim)
  if len(manual_axes) < len(mesh.axis_names):
    # Partial-manual case: leave dims open for the remaining auto axes.
    for dim_sharding in sdy_sharding.dim_shardings:
      dim_sharding.is_open = True
  return sdy_sharding
|
|
|
|
|
|
def _get_token_sharding(
    ctx: mlir.LoweringRuleContext, mesh
) -> sharding_impls.SdyArray:
  """Shardy sharding for an effect token: rank 0 with an empty spec."""
  replicated = _make_scoped_manual_sharding(ctx, mesh, P())
  return replicated._to_sdy_sharding(0)
|
|
|
|
|
|
def _get_spmdaxis_ctx_mesh(mesh):
  """Prefer the ambient concrete mesh over an AbstractMesh, when available."""
  if not isinstance(mesh, AbstractMesh):
    return mesh
  concrete_mesh = get_concrete_mesh()
  if concrete_mesh.empty:
    return mesh
  return concrete_mesh
|
|
|
|
|
|
def _shard_map_lowering_shardy(
    ctx: mlir.LoweringRuleContext, in_nodes,
    jaxpr: core.Jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma):
  """Lower shard_map_p to a Shardy `sdy.ManualComputationOp`.

  Handles effect tokens, dimension variables, and hoisted jaxpr consts as
  extra block arguments/results. If all newly-manual axes have size 1, the
  body is inlined with no ManualComputationOp at all.
  """
  axis_ctx = ctx.module_context.axis_context
  in_avals_ = [v.aval for v in jaxpr.invars]
  if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
    # Nested `ManualComputationOp`s must only refer to the new manual axes, not
    # all existing ones. Grab the newly-added manual axes.
    shardy_manual_axes = manual_axes - axis_ctx.manual_axes
  else:
    shardy_manual_axes = manual_axes
  new_axis_context = sharding_impls.SPMDAxisContext(
      _get_spmdaxis_ctx_mesh(mesh), manual_axes)
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)

  tokens = [ctx.tokens_in.get(eff) for eff in ctx.tokens_in.effects()]
  num_tokens = len(tokens)
  # From here on, `manual_axes` is the mesh-ordered tuple of new manual axes.
  manual_axes = order_wrt_mesh(mesh, shardy_manual_axes)
  if prod([mesh.shape[a] for a in manual_axes]) == 1:
    # No need for a `ManualComputationOp` if all manual axes are size 1.
    with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
      out_nodes, tokens_out = mlir.jaxpr_subcomp(
          sub_ctx, jaxpr, ctx.name_stack,
          mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)),
          (), *in_nodes,
          dim_var_values=ctx.dim_var_values,
          const_lowering=ctx.const_lowering,
          outer_traceback=_jax.Traceback())
    ctx.set_tokens_out(tokens_out)
    return out_nodes

  in_shardings = list(
      map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
          in_specs, ctx.avals_in))
  const_args_and_avals = core.jaxpr_const_args(jaxpr)
  const_args, const_avals = util.unzip2(const_args_and_avals)
  num_const_args = len(const_args)
  const_arg_values = mlir.flatten_ir_values(
      mlir.ir_constants(c, const_lowering=ctx.const_lowering, aval=aval)
      for c, aval in const_args_and_avals
  )
  # TODO(necula,yashkatariya): how to construct consts shardy shardings from
  # consts that can be ndarray or jax.Array?
  const_args_shardings = [
      _shardy_shard_map_sharding(ctx, mesh, manual_axes, P(), core.typeof(c))
      for c in const_args]

  # Operand order is: dim vars, tokens, const args, then user inputs; the
  # sharding lists below must follow the same order.
  num_dim_vars = len(ctx.dim_var_values)
  in_shardings = (
      [_get_token_sharding(ctx, mesh)] * (num_tokens + num_dim_vars) +
      const_args_shardings + in_shardings)
  in_shardings = sharding_impls.SdyArrayList(in_shardings).build()

  out_shardings = list(
      map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
          out_specs, ctx.avals_out))
  out_shardings = [
      _get_token_sharding(ctx, mesh)] * num_tokens + out_shardings
  out_shardings = sharding_impls.SdyArrayList(out_shardings).build()

  output_types = ([hlo.TokenType.get()] * num_tokens +
                  mlir.flatten_ir_types(map(mlir.aval_to_ir_types, ctx.avals_out)))

  args = (*ctx.dim_var_values, *tokens, *const_arg_values, *in_nodes)
  manual_computation_op = sdy.ManualComputationOp(
      output_types, mlir.flatten_ir_values(args), in_shardings, out_shardings,
      sdy.ManualAxesAttr.get([ir.StringAttr.get(i) for i in manual_axes]))

  dim_var_types = [
      mlir.aval_to_ir_type(core.ShapedArray((), dtypes.default_int_dtype()))
  ] * num_dim_vars
  token_types = [hlo.TokenType.get()] * num_tokens
  const_arg_types = mlir.flatten_ir_types(map(mlir.aval_to_ir_types, const_avals))
  in_types = mlir.flatten_ir_types(map(mlir.aval_to_ir_types, in_avals_))
  block = ir.Block.create_at_start(
      manual_computation_op.body,
      (*dim_var_types, *token_types, *const_arg_types, *in_types))

  # Lower the body jaxpr into the op's region.
  with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes),
        config._check_vma(check_vma)):
    dim_var_values, token_arg_values, const_arg_values, in_args = util.split_list(
        block.arguments, [num_dim_vars, num_tokens, num_const_args])
    out_nodes_, tokens_out = mlir.jaxpr_subcomp(
        sub_ctx, jaxpr, ctx.name_stack,
        mlir.TokenSet(zip(ctx.tokens_in.effects(), token_arg_values)),
        (), *in_args,
        dim_var_values=dim_var_values,
        const_lowering={
            (id(c), aval): ca
            for c, aval, ca in zip(const_args, const_avals, const_arg_values)
        },
        outer_traceback=_jax.Traceback())
    sdy.ReturnOp(
        mlir.flatten_ir_values(
            it.chain((v for _, v in tokens_out.items()), out_nodes_)
        )
    )
  # Tokens come first among the op results; rebind them for the caller.
  num_tokens = len(tokens_out.effects())
  tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip(
      ctx.tokens_in.effects(), manual_computation_op.results[:num_tokens])))
  ctx.set_tokens_out(tokens_out)

  return manual_computation_op.results[num_tokens:]
|
|
|
|
|
|
def _shard_map_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
                        jaxpr: core.Jaxpr, mesh, in_specs, out_specs,
                        check_vma, manual_axes):
  """MLIR lowering rule for shard_map_p.

  Dispatches to the Shardy path when enabled; otherwise emits GSPMD-style
  full_to_shard / body-call / shard_to_full ops.
  """
  if config.use_shardy_partitioner.value:
    return _shard_map_lowering_shardy(
        ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma)

  in_avals_ = [v.aval for v in jaxpr.invars]
  out_avals_ = [x.aval for x in jaxpr.outvars]
  # Convert each global operand to its per-shard value.
  in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_specs,
                  ctx.avals_in, in_avals_, in_nodes)
  new_axis_context = sharding_impls.SPMDAxisContext(
      _get_spmdaxis_ctx_mesh(mesh), manual_axes)
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
  with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
    out_nodes_, tokens_out = mlir.call_lowering(
        "shmap_body", pe.close_jaxpr(jaxpr), None, sub_ctx, in_avals_,
        out_avals_, ctx.tokens_in, *in_nodes_,
        dim_var_values=ctx.dim_var_values,
        const_lowering=ctx.const_lowering,
        arg_names=map(_pspec_mhlo_attrs, in_specs, in_avals_),
        result_names=map(_pspec_mhlo_attrs, out_specs, out_avals_))
  ctx.set_tokens_out(tokens_out)
  # Convert each per-shard result back to its global value.
  return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_specs,
             out_avals_, ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)
|
|
|
|
def _make_scoped_manual_sharding(ctx, mesh, spec):
  """NamedSharding on the abstract mesh, with any already-manual axes of the
  surrounding SPMD context marked Manual."""
  amesh = mesh.abstract_mesh
  axis_ctx = ctx.module_context.axis_context
  if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
    manual_types = {a: AxisType.Manual for a in axis_ctx.manual_axes}
    amesh = amesh.update_axis_types(manual_types)
  return NamedSharding(amesh, spec)
|
|
|
|
def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
               aval_in, aval_out, x):
  """Wrap a global value `x` with sharding + full_to_shard ops (GSPMD path).

  No-op when the manual axes have total size 1.
  """
  if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
    return x
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    # Extended dtypes are sharded via their physical representation.
    ns = sharding_impls.physical_sharding(aval_in, ns)
    aval_in = core.physical_aval(aval_in)
  shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
  # Partial-manual: leave all dims unspecified so auto axes stay shardable.
  unspecified = (set(range(aval_in.ndim))
                 if len(manual_axes) < len(mesh.axis_names) else set())
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
                                  unspecified_dims=unspecified)
  manual_proto = pxla.manual_proto(
      aval_in, manual_axes | set(mesh.manual_axes), mesh)
  return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto,
                                         unspecified)
|
|
|
|
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
                 aval_in, aval_out, x):
  """Wrap a per-shard value `x` with manual-sharding + shard_to_full ops.

  Inverse of `_xla_shard`; a no-op when the manual axes have total size 1.
  """
  if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
    return x
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    # Extended dtypes are unsharded via their physical representation.
    ns = sharding_impls.physical_sharding(aval_out, ns)
    aval_out = core.physical_aval(aval_out)
  unspecified = (set(range(aval_in.ndim))
                 if len(manual_axes) < len(mesh.axis_names) else set())
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    aval_in = core.physical_aval(aval_in)
  manual_proto = pxla.manual_proto(
      aval_in, manual_axes | set(mesh.manual_axes), mesh)
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto,
                                  unspecified_dims=unspecified)
  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
                                         unspecified)
|
|
|
|
def _pspec_mhlo_attrs(spec, aval: core.AbstractValue) -> str:
  """Render the per-dimension axis names of `spec` as a string attribute."""
  if not isinstance(aval, core.ShapedArray):
    return ''
  names = _spec_to_names(spec)
  # One entry per dimension; unsharded dimensions render as None.
  return str([names.get(dim) for dim in range(aval.ndim)])
|
|
|
|
# Eager evaluation
|
|
|
|
def get_mesh_from_args(args_flat, mesh):
  """Resolve a concrete Mesh from the NamedSharding-carrying arguments.

  Raises ValueError if an argument's mesh shape disagrees with `mesh`, or if
  no argument supplies a concrete mesh to replace an AbstractMesh.
  """
  for a in args_flat:
    if (hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding)
        and not a.sharding.mesh.is_scalar):  # pyrefly: ignore[missing-attribute]

      if a.sharding.mesh.shape_tuple != mesh.shape_tuple:
        aval = core.shaped_abstractify(a)
        raise ValueError(
            f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not"
            " match the mesh shape passed to shard_map "
            f" {mesh.shape_tuple} for shape {aval.str_short()}")
      # Adopt the argument's (concrete) mesh.
      mesh = a.sharding.mesh
  if isinstance(mesh, AbstractMesh):
    raise ValueError(
        "Please pass `jax.Array`s with a `NamedSharding` as input to"
        " `shard_map` when passing `AbstractMesh` to the mesh argument.")
  assert isinstance(mesh, Mesh)
  return mesh
|
|
|
|
def _spec_to_vma(spec):
|
|
return frozenset(p for s in spec if s is not None
|
|
for p in (s if isinstance(s, tuple) else (s,)))
|
|
|
|
def _mat_to_spec(mesh, mat):
  """Build a PartitionSpec from a ManualAxisType, mesh-ordering the varying axes."""
  varying = order_wrt_mesh(mesh, mat.varying)
  return P(varying, unreduced=mat.unreduced, reduced=mat.reduced)
|
|
|
|
def _spec_to_mat(spec) -> core.ManualAxisType:
  """Build a ManualAxisType whose varying axes are those mentioned in `spec`."""
  varying = _spec_to_vma(spec)
  return core.ManualAxisType(varying=varying, unreduced=spec.unreduced,
                             reduced=spec.reduced)
|
|
|
|
def _shard_map_impl(trace, prim, fun, args, *, mesh, in_specs,
                    check_vma, manual_axes, debug_info):
  """Eager (EvalTrace) rule for shard_map.

  Reshapes each argument to a leading-device-axis layout, runs the body under
  a ShardMapTrace, then matches the results back out to their out_specs.
  """
  del prim
  if isinstance(mesh, AbstractMesh):
    # Resolve the abstract mesh to a concrete one from context or arguments.
    concrete_mesh = get_concrete_mesh()
    mesh = concrete_mesh if not concrete_mesh.empty else mesh
    mesh = get_mesh_from_args(args, mesh)
  cur_mesh = get_abstract_mesh()
  args_ = map(partial(_unmatch_spec, mesh, check_vma, cur_mesh, manual_axes),
              in_specs, args)
  in_mat = map(_spec_to_mat, in_specs)
  outs, out_specs, out_mat = _run_shmap(fun, mesh, manual_axes, args_, in_mat,
                                        check_vma)
  # Per-shard avals: strip the leading (stacked) axis added by _unmatch.
  out_avals = outs.map(lambda x: core.mapped_aval(x.shape[0], 0, core.typeof(x)))
  _check_names(out_specs, out_avals)
  if check_vma:
    _check_mats(mesh, out_specs, out_avals)
    src_pspecs = tuple(_mat_to_spec(mesh, m) for m in out_mat)
  else:
    # Without VMA tracking, treat everything as varying over all manual axes.
    src_pspecs = tuple(P(order_wrt_mesh(mesh, manual_axes))
                       for _ in range(len(out_mat)))
  dst_pspecs = out_specs
  return outs.map3(partial(_match_spec, mesh, check_vma, manual_axes),
                   src_pspecs, dst_pspecs)
core.EvalTrace.process_shard_map = _shard_map_impl
|
|
|
|
def _run_shmap_lu(f, mesh, manual_axes, args, mats, check_vma):
  """Run linear_util-style function `f` eagerly under a fresh ShardMapTrace.

  Like `_run_shmap` but calls `f.call_wrapped` and returns (out_vals, out_mats)
  without out_specs; used for custom_jvp/custom_vjp subtraces.
  """
  assert not mesh.manual_axes
  trace = ShardMapTrace(mesh, manual_axes, check_vma)
  in_tracers = map(partial(ShardMapTracer, trace), mats, args)
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(inner_mesh), config._check_vma(check_vma)):
    ans = f.call_wrapped(*in_tracers)
    outs, out_mat = unzip2(map(trace.to_val_mat_pair, ans))
  return outs, out_mat
|
|
|
|
def _run_shmap(f, mesh, manual_axes, args, mats, check_vma):
  """Run `f` eagerly under a fresh ShardMapTrace.

  Returns (out_vals, out_specs, out_mats); `f` is expected to return a
  FlatTree with the broadcast out_specs attached as aux data.
  """
  assert not mesh.manual_axes
  trace = ShardMapTrace(mesh, manual_axes, check_vma)
  in_tracers = map(partial(ShardMapTracer, trace), mats, args)
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(inner_mesh), config._check_vma(check_vma)):
    ans, out_specs = f(*in_tracers).unpack_aux()
    outs, outs_mat = ans.map(trace.to_val_mat_pair).unzip2()
  return outs, out_specs, list(outs_mat)
|
|
|
|
def _unmatch_spec2(mesh, prev_manual, spec, x) -> JaxType:
  """Jit-dispatch `_unmatch2` outside the current trace (nested-shmap path)."""
  fn = HashablePartial(_unmatch2, mesh, prev_manual, spec)
  with (core.eval_context(), api.disable_jit(False),
        use_abstract_mesh(mesh.abstract_mesh)):
    return api.jit(fn)(x)
|
|
|
|
def _unmatch2(mesh, prev_manual, spec, x):
  """Reshard from `spec` (over already-manual axes) to a fully-stacked layout."""
  newly_manual = _spec_to_vma(spec)
  all_manual = prev_manual | newly_manual
  src = P(order_wrt_mesh(mesh, prev_manual), *spec)
  dst = P(order_wrt_mesh(mesh, all_manual))
  return shard_map(lambda y: y, in_specs=src, out_specs=dst,
                   axis_names=all_manual)(x)
|
|
|
|
def _match_spec2(mesh, prev_manual, spec, x) -> JaxType:
  """Jit-dispatch `_match2` outside the current trace (nested-shmap path)."""
  fn = HashablePartial(_match2, mesh, prev_manual, spec)
  with (core.eval_context(), api.disable_jit(False),
        use_abstract_mesh(mesh.abstract_mesh)):
    return api.jit(fn)(x)
|
|
|
|
def _match2(mesh, prev_manual, spec, x):
  """Reshard from a fully-stacked layout back out to `spec` (inverse of _unmatch2)."""
  newly_manual = _spec_to_vma(spec)
  all_manual = prev_manual | newly_manual
  src = P(order_wrt_mesh(mesh, all_manual))
  dst = P(order_wrt_mesh(mesh, prev_manual), *spec)
  return shard_map(lambda y: y, in_specs=src, out_specs=dst,
                   axis_names=all_manual)(x)
|
|
|
|
|
|
def _unmatch_spec(mesh: Mesh, check_vma, context_mesh, manual_axes, in_spec,
                  x: JaxType) -> JaxType:
  """Jit-dispatch `_unmatch` outside the current trace, under `context_mesh`."""
  fn = HashablePartial(_unmatch, mesh, check_vma, in_spec, manual_axes)
  with (core.eval_context(), api.disable_jit(False),
        use_abstract_mesh(context_mesh)):
    return api.jit(fn)(x)
|
|
|
|
def _unmatch(mesh, check_vma, in_spec, manual_axes, x):
  """Reshape `x` from its `in_spec` layout to per-shard blocks stacked along
  a new leading axis (via an identity-plus-expand_dims shard_map)."""
  if check_vma:
    used = _spec_to_vma(in_spec)
    dst = P(order_wrt_mesh(mesh, used), unreduced=in_spec.unreduced,
            reduced=in_spec.reduced)
  else:
    dst = P(mesh.axis_names)
    check_vma = False  # normalize any falsy value before forwarding
  return shard_map(_add_singleton, mesh=mesh, in_specs=(in_spec,),
                   out_specs=dst, check_vma=check_vma,
                   axis_names=manual_axes)(x)
|
|
|
|
def _check_names(specs, avals: FlatTree) -> None:
  """Raise _SpecError if any spec mentions more dimensions than its aval has."""
  fail = []
  for sp, a in zip(specs, avals):
    too_many_dims = isinstance(sp, P) and sp and len(sp) > a.ndim
    fail.append(a if too_many_dims else no_fail)
  if any(f is not no_fail for f in fail):
    raise _SpecError(fail, avals.tree)
|
|
|
|
class _SpecError(Exception):
  # Raised by _check_names; carries (fail_list, tree) for error reporting.
  pass
|
|
|
|
def _check_mats(mesh, specs, avals):
  """Raise _RepError if any output isn't provably replicated per its spec."""
  fail = []
  for sp, a in zip(specs, avals):
    invalid = isinstance(sp, P) and not _valid_repeats(mesh, a.mat, sp)
    fail.append(a.mat if invalid else no_fail)
  if any(f is not no_fail for f in fail):
    raise _RepError(fail, avals.tree)
|
|
|
|
class _RepError(Exception):
  # Raised by _check_mats; carries (fail_list, tree) for error reporting.
  pass
|
|
|
|
def _match_spec(mesh: Mesh, check_vma, manual_axes, x: JaxType,
                src_pspec: PartitionSpec, dst_pspec: PartitionSpec) -> JaxType:
  """Jit-dispatch `_match` to reshard `x` from src_pspec to dst_pspec."""
  fn = HashablePartial(_match, mesh, check_vma, manual_axes, src_pspec,
                       dst_pspec)
  fully_manual = set(mesh.axis_names) == manual_axes
  with core.eval_context(), api.disable_jit(False):
    if fully_manual:
      # Pin the output sharding when every mesh axis is manual.
      jitted = api.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))
    else:
      jitted = api.jit(fn)
    return jitted(x)
|
|
|
|
def _match(mesh, check_vma, manual_axes, src_pspec, dst_pspec, x):
  """Drop the stacked leading axis and reshard out to `dst_pspec`."""
  shmapped = shard_map(_rem_singleton, mesh=mesh, in_specs=src_pspec,
                       out_specs=dst_pspec, check_vma=check_vma,
                       axis_names=manual_axes)
  return shmapped(x)
|
|
|
|
def _rem_singleton(x): return lax.squeeze(x, [0])
|
|
def _add_singleton(x): return lax.expand_dims(x, [0])
|
|
|
|
def _maybe_check_special(outs):
  """Best-effort NaN/Inf check on eagerly-computed shard_map outputs.

  Only runs when debug_nans/debug_infs is enabled; re-raises the internal
  error as a user-facing FloatingPointError.
  """
  if not config.debug_nans.value and not config.debug_infs.value: return
  # Gather the raw per-shard buffers of every output leaf that has them.
  bufs = [s.data for leaf in tree_leaves(outs)
          for s in getattr(leaf, 'addressable_shards', [])]
  try:
    dispatch.check_special('shard_map', bufs)
  except api_util.InternalFloatingPointError as e:
    raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None
|
|
|
|
class ShardMapTrace(core.Trace):
  """Eager-mode trace for shard_map bodies.

  Each primitive application is re-dispatched through a jitted single-op
  shard_map (or a registered eager rule), carrying a ManualAxisType (`mat`)
  alongside every value when VMA checking is enabled.
  """
  __slots__ = ("mesh", "manual_axes", "check", "amesh")

  mesh: Mesh  # outer concrete or abstract mesh
  manual_axes: frozenset[AxisName]
  check: bool

  def __init__(self, mesh, manual_axes, check):
    super().__init__()
    self.mesh = mesh
    self.manual_axes = manual_axes
    self.check = check
    self.amesh = mesh.abstract_mesh

  def to_val_mat_pair(self, val):
    """Return (per-shard value, ManualAxisType) for a tracer or constant."""
    if isinstance(val, ShardMapTracer):
      return val.val, val.mat
    elif isinstance(val, Tracer):
      raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
    else:
      # Constants are lifted in as fully-replicated (empty spec) values.
      val_ = _unmatch_spec(self.mesh, self.check, self.amesh, self.manual_axes,
                           P(), val)
      return val_, core.empty_mat

  def process_primitive(self, prim, tracers, params, /):
    in_vals, in_mat = unzip2(map(self.to_val_mat_pair, tracers))
    if self.check:
      # Derive output mats/specs from the primitive's abstract eval.
      out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params)
      out_avals = tuple(out_avals) if type(out_avals) is list else out_avals
      out_mat = tree_map(lambda a: a.mat, out_avals)
      in_specs = tuple(map(partial(_mat_to_spec, self.mesh), in_mat))
      out_specs = tree_map(partial(_mat_to_spec, self.mesh), out_mat)
    else:
      # Without VMA tracking everything is treated as varying on all manual axes.
      out_mat = core.empty_mat
      in_specs = out_specs = P(order_wrt_mesh(self.mesh, self.manual_axes))

    eager_rule = eager_rules.get(prim)
    if eager_rule:
      out_vals = eager_rule(self.mesh, *in_vals, **params)
    else:
      # Re-dispatch the op as a jitted one-primitive shard_map; NaN/Inf checks
      # are deferred so we can attribute failures to shard_map afterwards.
      f = HashablePartial(
          _prim_applier, prim, self.check, tuple(params.items()), self.mesh,
          self.manual_axes, in_specs, out_specs)
      with (core.eval_context(), api.disable_jit(False), config.debug_nans(False),
            config.debug_infs(False), use_abstract_mesh(self.amesh),
            sharding_impls._internal_use_concrete_mesh(self.mesh)):
        out_vals = api.jit(f)(*in_vals)
      _maybe_check_special(out_vals)
    if prim.multiple_results:
      out_mat = (out_mat if isinstance(out_mat, (list, tuple))
                 else [out_mat] * len(out_vals))
      return map(partial(ShardMapTracer, self), out_mat, out_vals)
    return ShardMapTracer(self, out_mat, out_vals)

  def process_shard_map(self, prim, fun, args, mesh, in_specs,
                        check_vma, manual_axes, debug_info):
    """Handle a nested shard_map encountered while eagerly evaluating."""
    # Check consistency between outer and inner shmaps on explicitly passed
    # mesh and check_vma.
    if isinstance(mesh, Mesh):
      if mesh != self.mesh: raise Exception
    del mesh
    if check_vma != self.check:  # TODO(mattjj): add check in jit path
      raise Exception
    del check_vma

    in_vals, in_mats = unzip2(map(self.to_val_mat_pair, args))
    if any(m.unreduced or m.reduced for m in in_mats):
      raise NotImplementedError(
          "Eager shard_map + unreduced/reduced + partial manual is not"
          " implemented. Please wrap your shard_map in `jax.jit`.")
    # Inner trace is manual over the union of outer and inner manual axes.
    trace = ShardMapTrace(self.mesh, manual_axes | self.manual_axes, self.check)
    in_vals_ = [_unmatch_spec2(self.mesh, self.manual_axes, spec, x)
                for x, spec in zip(in_vals, in_specs)]
    # TODO(yashkatariya): Handle unreduced/reduced correctly.
    in_mats_ = [core.ManualAxisType(varying=mat.varying | _spec_to_vma(s))
                for mat, s in zip(in_mats, in_specs)]
    in_tracers = map(partial(ShardMapTracer, trace), in_mats_, in_vals_)
    inner_mesh = _as_manual_mesh(self.mesh, manual_axes | self.manual_axes)
    with (core.set_current_trace(trace), _extend_axis_env(self.mesh, manual_axes),
          use_abstract_mesh(inner_mesh)):
      ans_aux = fun(*in_tracers)
      ans, out_specs = ans_aux.unpack_aux()
      out_vals_, out_mats_ = ans.map(trace.to_val_mat_pair).unzip2()
    out_vals = out_vals_.map2(
        lambda x, spec: _match_spec2(self.mesh, self.manual_axes, spec, x), out_specs)
    # TODO(yashkatariya): Handle unreduced/reduced correctly.
    out_mats = [core.ManualAxisType(varying=mat.varying - _spec_to_vma(spec))
                for mat, spec in zip(out_mats_, out_specs)]
    return out_vals.map2(lambda val, vma: ShardMapTracer(self, vma, val),
                         out_mats)

  def process_call(self, call_primitive, fun, tracers, params, /):
    raise NotImplementedError(
        f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't "
        "yet supported. Put a `jax.jit` around the `shard_map`-decorated "
        "function, and open a feature request at "
        "https://github.com/jax-ml/jax/issues !")

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *, symbolic_zeros):
    # Since ShardMapTrace is only used as a base main, we can drop the jvp.
    del prim, jvp, symbolic_zeros
    in_vals, in_mat = unzip2(map(self.to_val_mat_pair, tracers))
    out_vals, out_mat = _run_shmap_lu(fun, self.mesh, self.manual_axes, in_vals,
                                      in_mat, self.check)
    return map(partial(ShardMapTracer, self), out_mat, out_vals)

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, /, *, out_trees,
                              symbolic_zeros):
    # Eagerly we only need the primal function; fwd/bwd are dropped.
    if symbolic_zeros:
      msg = ("custom_vjp symbolic_zeros support with shard_map is not "
             "implemented; please open an issue at "
             "https://github.com/jax-ml/jax/issues")
      raise NotImplementedError(msg)
    del prim, fwd, bwd, out_trees, symbolic_zeros
    in_vals, in_mat = unzip2(map(self.to_val_mat_pair, tracers))
    out_vals, out_mat = _run_shmap_lu(fun, self.mesh, self.manual_axes, in_vals,
                                      in_mat, self.check)
    return map(partial(ShardMapTracer, self), out_mat, out_vals)
|
|
|
|
|
|
class ShardMapTracer(core.Tracer[ShardMapTrace]):
  """Tracer for eager shard_map: holds the stacked per-shard value plus its
  ManualAxisType (varying/unreduced/reduced axes)."""
  mat: core.ManualAxisType
  val: JaxType  # per-shard blocks stacked along a leading axis

  def __init__(self, trace, mat, val):
    assert isinstance(mat, core.ManualAxisType)
    aval = core.typeof(val)
    # Without VMA checking, conservatively treat all manual axes as varying.
    mat = (mat if trace.check else
           core.ManualAxisType(varying=trace.manual_axes))
    size = prod(trace.mesh.shape[n] for n in mat.varying)
    # Strip the stacked leading axis to get the logical per-shard aval.
    out = core.mapped_aval(size, 0, aval)
    manual_mesh = _as_manual_mesh(trace.amesh, trace.manual_axes)
    spec = core.modify_spec_for_auto_manual(out.sharding.spec, manual_mesh)  # pyrefly: ignore[missing-attribute]
    new_sharding = NamedSharding(manual_mesh, spec)
    mat_out = mat if trace.check else core.empty_mat
    computed_aval = out.update(sharding=new_sharding, manual_axis_type=mat_out)
    super().__init__(trace, computed_aval)
    self.mat = mat
    self.val = val

  def to_concrete_value(self):
    # Only safe when the value is provably replicated (no varying/unreduced/
    # reduced axes): then shard 0 is representative.
    if self._trace.check and self.mat.vur == frozenset():
      with (core.eval_context(), use_abstract_mesh(self._trace.amesh),
            sharding_impls._internal_use_concrete_mesh(self._trace.mesh)):
        return core.to_concrete_value(self.val[0])
    else:
      return None

  def __str__(self) -> str:
    # Broadcast over non-varying axes so there is one block per device.
    pb_names = set(self._trace.mesh.axis_names) - self.mat.vur
    self = pvary(self, tuple(pb_names))
    with (core.eval_context(), use_abstract_mesh(self._trace.amesh),
          sharding_impls._internal_use_concrete_mesh(self._trace.mesh)):
      blocks = list(self.val)
    mesh = self._trace.mesh
    axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
    return '\n'.join(
        f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n"
        for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks))

  __repr__ = __str__  # for debuggers, like `p x`
|
|
|
|
def _prim_applier(prim, check_vma, params_tup, concrete_mesh, manual_axes,
                  in_specs, out_specs, *args):
  """Apply one primitive inside a shard_map (used by the eager trace's jit path).

  `params_tup` is a hashable tuple of items so this partial can be jit-cached.
  """
  def apply(*args):
    # Inputs carry a stacked leading axis; strip it, bind, then re-add it.
    outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
    return tree_map(_add_singleton, outs)
  # NOTE(review): a tuple of out_specs is converted to a list — presumably so
  # shard_map treats it as per-output specs rather than a single pytree spec;
  # confirm against shard_map's out_specs handling.
  out_specs = list(out_specs) if type(out_specs) is tuple else out_specs
  return shard_map(apply, mesh=concrete_mesh, in_specs=in_specs,
                   out_specs=out_specs, check_vma=check_vma,
                   axis_names=manual_axes)(*args)
|
|
|
|
eager_rules: dict[core.Primitive, Callable] = {}
|
|
|
|
def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics):
|
|
del mesh, srcs, copy_semantics
|
|
for device in devices:
|
|
if device is not None:
|
|
raise ValueError("device_put with explicit device not allowed within "
|
|
f"shard_map-decorated functions, but got device {device}")
|
|
return xs
|
|
# device_put is allowed under eager shard_map only without explicit devices.
eager_rules[dispatch.device_put_p] = _device_put_eager_rule
|
|
|
|
def _ref_raise_valueerror(*args, **kwargs):
|
|
raise ValueError(
|
|
"Eager shard_map cannot return a `jax.Ref`. Please wrap"
|
|
" your shard_map in `jax.jit`.")
|
|
|
|
# Refs cannot be created inside an eagerly-evaluated shard_map; both ref
# construction primitives route to the unconditional error above.
eager_rules[core.ref_p] = _ref_raise_valueerror
eager_rules[core.empty_ref_p] = _ref_raise_valueerror
|
|
|
|
# Batching
|
|
|
|
def used_axis_names(spec):
  """Return the set of mesh axis names mentioned by a partition spec."""
  return _spec_to_mat(spec).vur
|
|
|
|
def _shard_map_batch(
    trace: batching.BatchTrace, prim: core.Primitive, fun: Callable,
    in_tracers: Sequence[batching.BatchTracer], mesh: Mesh,
    in_specs, check_vma: bool, manual_axes: frozenset,
    debug_info) -> Sequence[batching.BatchTracer]:
  """Batching (vmap) rule for shard_map_p.

  Pushes the batch dimension through the shard_map by extending each mapped
  input's spec with the vmapped axis, binding the primitive on the batched
  values under the parent trace, and rewrapping the outputs as BatchTracers.
  """
  in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers))
  spmd_axis_name = trace.axis_data.spmd_name
  explicit_mesh_axis = trace.axis_data.explicit_mesh_axis
  if spmd_axis_name is not None:
    # vmap with spmd_axis_name: shard the batched dim over those mesh axes.
    # The axes must not already appear in the in_specs.
    used = {n for spec in in_specs for n in used_axis_names(spec)}
    if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used:
      raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
    new_in_specs = [
        sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_axis_name)
        for sp, d in zip(in_specs, in_dims)]
    # Inside the shard_map the batch axis is split across the spmd mesh axes,
    # so the logical batch size shrinks by their total size.
    new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name)
    new_axis_data = batching.AxisData(
        trace.axis_data.name, new_size, trace.axis_data.spmd_name,
        trace.axis_data.explicit_mesh_axis)
  elif explicit_mesh_axis is not None:
    used = {n for spec in in_specs for n in used_axis_names(spec)}
    if set(explicit_mesh_axis) & used:
      raise ValueError("vmapped away explicit mesh axis cannot appear in "
                       "shard_map in_specs")
    new_in_specs = [
        sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
        for sp, d in zip(in_specs, in_dims)]
    new_axis_data = trace.axis_data
  else:
    # Plain vmap: the batched dim is unsharded inside the shard_map.
    new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
                    for sp, d in zip(in_specs, in_dims)]
    new_axis_data = trace.axis_data

  def fun_batched(*args):
    # Run the original function under a fresh batch subtrace; out_specs come
    # back as aux data and are batched to account for the new dimension.
    ans_aux, out_dims = batching.batch_subtrace_2(
        fun, trace.tag, new_axis_data, tuple(in_dims), args)
    ans, out_specs = ans_aux.unpack_aux()
    new_out_specs = _batch_out_specs(spmd_axis_name, explicit_mesh_axis,
                                     out_dims, out_specs)
    return ans.with_aux(out_dims).with_aux(tuple(new_out_specs))

  new_params = dict(mesh=mesh, in_specs=new_in_specs, check_vma=check_vma,
                    manual_axes=manual_axes, debug_info=debug_info)
  # TODO(yashkatariya): Remove remove_explicit_mesh_axis_names when vmap
  # mesh ctx is correctly set.
  with (core.set_current_trace(trace.parent_trace),
        core.remove_explicit_mesh_axis_names(trace.axis_data.explicit_mesh_axis)):
    out_vals = prim.bind(*in_vals, subfuns=(fun_batched,), **new_params)
  make_tracer = partial(batching.BatchTracer, trace,
                        source_info=source_info_util.current())
  out_vals, out_dims = out_vals.unpack_aux()
  return out_vals.map2(make_tracer, out_dims)
batching.BatchTrace.process_shard_map = _shard_map_batch
|
|
|
|
def _batch_out_specs(spmd_name, explicit_mesh_axis, dims, out_specs):
  """Extend each mapped output spec with the vmapped batch dimension.

  Validates that the vmap spmd axes (or a vmapped-away explicit mesh axis)
  don't already appear in the out_specs, then batches every spec whose output
  actually carries a batch dimension.
  """
  if spmd_name is not None:
    used = {n for spec in out_specs for n in used_axis_names(spec)}
    if not config.disable_vmap_shmap_error.value and set(spmd_name) & used:
      raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs")
    batch_names = spmd_name
  else:
    if explicit_mesh_axis is not None:
      used = {n for spec in out_specs for n in used_axis_names(spec)}
      if set(explicit_mesh_axis) & used:
        raise ValueError("vmapped away explicit mesh axis cannot appear in "
                         "shard_map out_specs")
    # Without spmd_axis_name the new batch dimension stays unsharded.
    batch_names = None
  return [sp if d is batching.not_mapped
          else pxla.batch_spec(sp, d, batch_names)
          for sp, d in zip(out_specs, dims)]
|
|
|
|
|
|
# Autodiff
|
|
|
|
def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_specs,
                   check_vma, manual_axes, debug_info):
  """JVP rule for shard_map_p: runs the jvp inside a single shard_map.

  Splits the input tracers into primals and tangents (symbolic zeros are
  dropped and re-instantiated inside), binds one shard_map computing both
  primal and nonzero tangent outputs, and rebuilds JVPTracers on the way out.
  """
  debug_info = debug_info.with_unknown_names()
  primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  which_nz = [ type(t) is not ad.Zero for t in tangents]
  # Symbolic zeros become None placeholders so they don't flow into the bind.
  tangents = [t if type(t) is not ad.Zero else None for t in tangents]
  args, in_zeros_tree = tree_flatten((primals, tangents))
  # Only nonzero tangents are passed; their specs mirror the primal specs.
  tangent_in_specs = [sp.to_tangent_spec() for sp, nz in zip(in_specs, which_nz) if nz]

  def f_jvp(*primals_and_nz_tangents_flat):
    primals, tangents = tree_unflatten(in_zeros_tree, primals_and_nz_tangents_flat)
    # Re-instantiate symbolic zero tangents from their primals.
    tangents = [ad.p2tz(p) if t is None else t for p, t in zip(primals, tangents)]
    primals_out_aux, tangents_out = ad.jvp_subtrace_2(f, trace.tag, primals, tangents)
    primals_out_ft, out_ax = primals_out_aux.unpack_aux()
    which_nz_out = [type(r) is not ad.Zero for r in tangents_out]
    tangent_out_specs = [s.to_tangent_spec() for s, nz in zip(out_ax, which_nz_out) if nz]
    # Outputs are (all primals, then only the nonzero tangents).
    new_out_specs = (*out_ax, *tangent_out_specs)
    tangents_out = [None if not nz else t for t, nz in zip(tangents_out, which_nz_out)]
    tangents_out_ft = FlatTree.flatten(list(tangents_out))
    out_primals_tangents = FlatTree.pack((primals_out_ft, tangents_out_ft))
    return out_primals_tangents.with_aux(which_nz_out).with_aux(new_out_specs)

  params = dict(mesh=mesh, in_specs=(*in_specs, *tangent_in_specs),
                check_vma=check_vma, manual_axes=manual_axes,
                debug_info=debug_info.with_unknown_names())
  avals = [typeof(x) for x in args]
  result = shard_map_p.bind_with_trace(
      trace.parent_trace, tuple(args), avals, dict(params, subfuns=(f_jvp,)))
  pt_out, which_nz_out = result.unpack_aux()
  primal_out, nz_tangents_out = pt_out.unpack()
  # Consume nonzero tangents in order (reversed list + pop keeps it O(1));
  # zero-tangent outputs get just the primal value back.
  tangents_stack = list(nz_tangents_out)[::-1]
  make_tracer = lambda p, nz: ad.JVPTracer(trace, p, tangents_stack.pop()) if nz else p
  tracers_out = primal_out.map2(make_tracer, which_nz_out)
  assert not tangents_stack
  return tracers_out
ad.JVPTrace.process_shard_map = _shard_map_jvp
|
|
|
|
def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
                            f: Callable, tracers, mesh, in_specs,
                            check_vma, manual_axes, debug_info):
  """Partial-evaluation rule for shard_map_p.

  Executes the known part of the computation now as a shard_map over the
  known inputs (also emitting residuals), and stages out a second shard_map
  eqn over residuals, captured environment values, and the unknown inputs.
  """
  tracers = map(trace.to_jaxpr_tracer, tracers)
  in_pvals = [t.pval for t in tracers]
  in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
  unk_in_specs, known_in_specs = pe.partition_list(in_knowns, in_specs)
  in_avals_sharded = map(partial(shard_aval, mesh, manual_axes, check_vma),
                         unk_in_specs, in_avals)
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)
  def f_pe(*in_consts):
    # Re-interleave known consts and unknown (sharded) avals in pval order.
    in_avals_, in_consts_ = iter(in_avals_sharded), iter(in_consts)
    in_pvals = [pe.PartialVal.known(next(in_consts_)) if known else
                pe.PartialVal.unknown(next(in_avals_)) for known in in_knowns]
    sentinel = object()
    assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel
    jaxpr, fwd_data = pe.trace_to_subjaxpr_nounits_fwd2(
        f, trace.tag, debug_info.with_unknown_names(), False, in_pvals)
    (in_fwds, out_fwds, out_pvals, res, env) = fwd_data
    # Non-forwarded scalar residuals get a singleton leading axis so they can
    # cross the shard_map boundary; see _promote_scalar_residuals_jaxpr.
    which = [f1 is None and f2 is None and not v.aval.shape
             for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
    jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
    res = [lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
           for x in res]
    out_pvals, out_specs = out_pvals.unpack_aux()
    out_knowns, _, out_consts = pe.partition_pvals(out_pvals)
    res_avals = [typeof(r) for r in res]
    _, out_known_specs = pe.partition_list(out_knowns, out_specs)
    # Residuals get an unsharded spec over the newly-manual axes.
    res_specs = [a.nospec(mesh, check_vma, all_names) for a in res_avals]
    new_out_specs = (*out_known_specs, *res_specs)
    # `ft` remembers the output tree structure so known/unknown results can be
    # merged back into it after the bind.
    ft = out_pvals.map(lambda _: None)
    ans_ft = FlatTree.flatten((out_consts, res))
    aux = (in_fwds, out_fwds, out_knowns, res_avals, jaxpr, env, out_specs, new_out_specs, ft)
    return ans_ft.with_aux(aux).with_aux(new_out_specs)

  known_params = dict(mesh=mesh, in_specs=(*known_in_specs,),
                      check_vma=check_vma, manual_axes=manual_axes,
                      debug_info=debug_info.with_unknown_names())
  avals = [typeof(x) for x in in_consts]
  out = shard_map_p.bind_with_trace(trace.parent_trace, tuple(in_consts), avals,
                                    dict(known_params, subfuns=(f_pe,)))
  outs, (in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env, out_specs, new_out_specs, ft) = out.unpack_aux()
  out_consts, non_fwd_res = outs.unflatten()

  assert not jaxpr.constvars
  unk_out_specs, _ = pe.partition_list(out_knowns, out_specs)
  # Recover the full residual list: forwarded-from-input, forwarded-from-
  # output, and freshly computed residuals, in order.
  res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
  # TODO make res_avals be the full set, not just the non-fwd ones
  res_avals_iter = iter(res_avals)
  res_specs = []
  # Forwarded residuals reuse the spec of what they forward; fresh residuals
  # get an unsharded spec over the newly-manual axes.
  for f1, f2 in zip(in_fwd, out_fwd):
    if f1 is not None:
      res_specs.append(known_in_specs[f1])
    elif f2 is not None:
      res_specs.append(new_out_specs[f2])
    else:
      raval = next(res_avals_iter)
      res_specs.append(raval.nospec(mesh, check_vma, all_names))
  env_specs = [_repspec(typeof(e)) for e in env]
  unk_in_specs = (*res_specs, *env_specs, *unk_in_specs)
  const_tracers = map(trace.new_instantiated_const, res)
  env_tracers = map(trace.to_jaxpr_tracer, env)
  unk_arg_tracers = [t for t in tracers if not t.is_known()]
  out_avals_sharded = [v.aval for v in jaxpr.outvars]
  unk_params = dict(mesh=mesh, in_specs=unk_in_specs,
                    out_specs=tuple(unk_out_specs),
                    jaxpr=jaxpr.replace(debug_info=jaxpr.debug_info.with_unknown_names()),
                    check_vma=check_vma, manual_axes=manual_axes)
  out_avals = map(partial(unshard_aval, mesh, check_vma), unk_out_specs,
                  out_avals_sharded)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                 for a in out_avals]
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers),
                          out_tracers, shard_map_p, unk_params,
                          effs, source_info_util.current())
  for t in out_tracers: t.recipe = eqn
  results = merge_lists(out_knowns, out_tracers, out_consts)
  return ft.update(results)

pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
|
|
|
|
def _shard_map_linearize(trace, shard_map_p, f: Callable,
                         tracers, mesh, in_specs, check_vma,
                         manual_axes, debug_info):
  """Linearize rule for shard_map_p.

  Runs one forward shard_map that produces primal outputs plus residuals,
  then binds a second shard_map (on the tangent trace) that evaluates the
  linearized jaxpr over residuals, environment values, and nonzero tangents.
  """
  debug_info = debug_info.with_unknown_names()
  primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)

  def f_lin(*primals):
    res, ans_aux, lin_data = ad.linearize_subtrace_2(
        f, trace.is_vjp, trace.tag, nzs_in, debug_info, primals)
    primals_out, out_specs = ans_aux.unpack_aux()

    # Keep only avals for residuals that aren't forwarded from an input or
    # primal output; those are the ones actually emitted here.
    res_avals, _, _, _, in_fwd, out_fwd = lin_data
    res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd)
                 if f1 is None and f2 is None]
    res_specs = [a.nospec(mesh, check_vma, all_names) for a in res_avals]
    new_out_specs = (*res_specs, *out_specs)
    # Scalars get a singleton axis so they can cross the shard_map boundary.
    res = [lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
           for x in res]
    res_and_primal = FlatTree.pack((FlatTree.flatten(res), primals_out))
    return res_and_primal.with_aux((lin_data, out_specs)).with_aux(new_out_specs)

  fwd_params = dict(
      mesh=mesh, in_specs=in_specs,
      check_vma=check_vma, manual_axes=manual_axes, debug_info=debug_info)
  avals = [typeof(x) for x in primals]
  all_results_aux = shard_map_p.bind_with_trace(
      trace.parent_trace, tuple(primals), avals, dict(fwd_params, subfuns=(f_lin,)))
  all_results, (lin_data, out_specs) = all_results_aux.unpack_aux()
  res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = lin_data
  non_fwd_res, primals_out = all_results.unpack()
  # Reassemble the full residual list from inputs / primal outputs / fresh.
  residuals = subs_list2(in_fwd, out_fwd, primals, (*primals_out,), non_fwd_res)
  args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None
                     for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)]
  # The linearized jaxpr was traced inside the manual mesh; rewrite it in the
  # same context so scalar residuals pick up their singleton axes.
  with (_extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)),
        config._check_vma(check_vma)):
    lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote)
  res_avals2 = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd)
                if f1 is None and f2 is None]
  res_avals_iter = iter(res_avals2)
  # Forwarded residuals reuse the spec of their source; fresh ones get an
  # unsharded spec over the newly-manual axes.
  res_specs = [in_specs[f1] if f1 is not None else out_specs[f2] if f2 is not None
               else next(res_avals_iter).nospec(mesh, check_vma, all_names)
               for f1, f2 in zip(in_fwd, out_fwd)]
  assert next(res_avals_iter, None) is None
  env_specs = [_repspec(typeof(e)) for e in env]
  new_in_specs = (*res_specs, *env_specs,
                  *(s.to_tangent_spec() for s, nz in zip(in_specs, nzs_in) if nz))
  tangent_out_specs = tuple(s.to_tangent_spec() for s, nz in zip(out_specs, nzs_out) if nz)
  tangent_params = dict(
      mesh=mesh, in_specs=new_in_specs,
      check_vma=check_vma, manual_axes=manual_axes, debug_info=lin_jaxpr.debug_info)

  # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here
  def f_tangent(*args):
    ans = core.eval_jaxpr(lin_jaxpr, (), *args)
    return FlatTree.flatten(ans).with_aux(tangent_out_specs)

  nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
  args = (*residuals, *env, *nz_tangents_in)
  avals = [typeof(x) for x in args]
  nz_tangents_out = shard_map_p.bind_with_trace(
      trace.tangent_trace, args, avals,
      dict(tangent_params, subfuns=(f_tangent,)))
  # Zero-tangent outputs are rebuilt as symbolic zeros from their primals.
  nz_tangents_out_iter = iter(nz_tangents_out)
  tangents_out = [next(nz_tangents_out_iter) if nz else ad.p2tz(primal)
                  for nz, primal in zip(nzs_out, primals_out)]
  return primals_out.map3(partial(ad.maybe_linearize_tracer, trace), nzs_out, tangents_out)
ad.LinearizeTrace.process_shard_map = _shard_map_linearize
|
|
|
|
|
|
def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]):
  """Retrace `jaxpr` so each selected scalar constvar takes a (1,)-shaped input.

  The new jaxpr accepts residuals with a singleton leading axis (for the
  positions flagged in `which`) and strips the axis before evaluating the
  original jaxpr, so scalar residuals can cross a shard_map boundary.
  """
  def fun(*res_and_args):
    res, args = split_list(res_and_args, [len(jaxpr.constvars)])
    # Remove the singleton leading axis added to promoted scalars.
    res = [_rem_singleton(x) if w else x for x, w in zip(res, which)]
    return core.eval_jaxpr(jaxpr, res, *args)
  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
               for v, w in zip(jaxpr.constvars, which)]
  in_avals = FlatTree.flatten(((*res_avals, *[v.aval for v in jaxpr.invars]), {}))
  closed_jaxpr, _ = pe.trace_to_jaxpr(fun, in_avals, debug_info=jaxpr.debug_info)
  # Hoist any consts produced by retracing back out into constvars.
  closed_jaxpr, _ = pe.separate_consts(closed_jaxpr)
  return closed_jaxpr.jaxpr
|
|
|
|
|
|
def _unmentioned2(mesh: Mesh, spec, manual_axes: frozenset[AxisName]
                  ) -> list[AxisName]:
  # We use a filtered-down version of unmentioned to avoid defensive-psum over
  # more chips than required in the transpose-no-check-vma case.
  mentioned = _spec_to_vma(spec) | spec.unreduced
  candidates = _all_mesh_names_except_spmd(mesh, manual_axes)
  return [axis for axis in candidates if axis not in mentioned]
|
|
|
|
|
|
def _shard_map_transpose(out_cts, *args,
                         jaxpr: core.Jaxpr, mesh, in_specs, out_specs,
                         check_vma, manual_axes):
  """Transpose rule for shard_map_p: runs the backward pass in a shard_map.

  In the no-check-vma case, correctness is maintained by dividing incoming
  cotangents by the number of devices they were over-summed across, and
  psum-ing outgoing cotangents over axes their spec doesn't mention.
  """
  mb_div = lambda x, y: x / y if y != 1 else x
  out_cts = [
      ad.Zero(shard_aval(mesh, manual_axes, check_vma, sp, x.aval))
      if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0
      else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, sp, manual_axes))))
      for sp, x in zip(out_specs, out_cts)
  ]
  # Undefined primals keep their aval but sharded to the per-device view.
  args = [x if type(x) is not ad.UndefinedPrimal else
          ad.UndefinedPrimal(shard_aval(mesh, manual_axes, check_vma, sp, x.aval))
          for sp, x in zip(in_specs, args)]
  all_args, in_tree = tree_flatten((out_cts, tuple(args)))

  def fun_trans_callable(*right_flat):
    right_cts, primals_or_undefs = tree_unflatten(in_tree, right_flat)
    left_cts = ad.backward_pass(jaxpr, False, (), primals_or_undefs, right_cts)
    # Without vma checking, defensively psum each cotangent over the mesh
    # axes its in_spec does not mention.
    left_cts = [x if type(x) is ad.Zero or check_vma
                else lax_parallel.psum(x, tuple(_unmentioned2(mesh, sp, manual_axes)))
                for sp, x in zip(in_specs, left_cts)]
    left_specs_nz = tuple(
        s.to_ct_spec() for ct, s in zip(left_cts, in_specs) if ct is not None
        and type(ct) is not ad.Zero)
    return FlatTree.flatten(left_cts).with_aux(left_specs_nz)

  dbg = jaxpr.debug_info.with_unknown_names()
  # Inputs to the transposed shard_map: nonzero output cotangents followed by
  # the defined primal values (Zeros/UndefinedPrimals don't flow through).
  new_in_specs = (
      [s.to_ct_spec() for s, x in zip(out_specs, out_cts) if type(x) is not ad.Zero] +
      [s for s, x in zip(in_specs, args) if type(x) is not ad.UndefinedPrimal])

  left_ct = shard_map_p.bind(
      *all_args, subfuns=(fun_trans_callable,), mesh=mesh, in_specs=tuple(new_in_specs),
      check_vma=check_vma, manual_axes=manual_axes, debug_info=dbg)
  left_cts = left_ct.unflatten()
  # Zero cotangents are re-expressed at the unsharded (outer) aval.
  return [ad.Zero(unshard_aval(mesh, check_vma, sp.to_ct_spec(), x.aval))
          if type(x) is ad.Zero else x for sp, x in zip(in_specs, left_cts)]
ad.primitive_transposes[shard_map_p] = _shard_map_transpose
|
|
|
|
# Remat
|
|
|
|
def _partial_eval_jaxpr_custom_rule(
    saveable: Callable[..., pe.RematCases_], unks_in: Sequence[bool],
    inst_in: Sequence[bool], eqn: core.JaxprEqn
    ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool],
               list[core.Var]]:
  """Remat (partial_eval_jaxpr_custom) rule for shard_map_p.

  Splits the eqn's jaxpr into known and staged jaxprs, forwards residuals
  from known inputs/outputs where possible, and returns the known eqn, the
  staged eqn, the unknown/instantiated output masks, and new residual vars.
  """
  jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh']
  check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes']
  # Partial-eval the body inside the manual-mesh context it was traced in.
  with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))):
    jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
        pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
  num_out_primals = len(jaxpr_known.outvars) - num_res
  # in_fwd[i] is the known-input index residual i forwards, or None.
  in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:]
  out_binders_known, _ = partition_list(unks_out, eqn.outvars)
  out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals])
  # out_fwd[i] is the known-output index residual i duplicates, or None.
  idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_binders_known))
             if not isinstance(b, core.DropVar)}
  out_fwd = [idx_map.get(id(v)) for v in res_vars]
  # `which` marks residuals that must actually be materialized (not forwarded).
  which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
  mesh = eqn.params['mesh']
  with (_extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)),
        config._check_vma(check_vma)):
    jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
    # Promote scalar residuals to shape (1,) so they can cross the boundary.
    jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
  jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
  jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names)
  ins_known, _ = partition_list(unks_in, eqn.invars)
  _, ins_staged = partition_list(inst_in, eqn.invars)
  _, out_binders_staged = partition_list(inst_out, eqn.outvars)
  nv = core.gensym()
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)
  # Residual specs: unsharded over the newly-manual axes.
  lns = lambda a: a.nospec(mesh, check_vma, all_names)
  residuals, staged_in_res_specs = unzip2(
      [(nv(unshard_aval(mesh, check_vma, (rn := lns(var.aval)), var.aval)), rn)
       for var, w in zip(jaxpr_staged.invars[:num_res], which) if w])
  out_res_specs_known = [var.aval.nospec(mesh, check_vma, all_names)  # pyrefly: ignore[missing-attribute]
                         for var, w in zip(res_vars, which) if w]
  params_known, params_staged = _pe_custom_params(
      unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd,
      out_res_specs_known, staged_in_res_specs,
      dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged))
  eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals],
                               eqn.primitive, params_known, jaxpr_known.effects,
                               eqn.source_info, eqn.ctx)
  # Staged eqn consumes the full residual list (forwarded + fresh), in order.
  full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals)
  eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged,
                                eqn.primitive, params_staged,
                                jaxpr_staged.effects, eqn.source_info, eqn.ctx)
  assert len(eqn_staged.invars) == len(jaxpr_staged.invars)
  new_inst = [x for x, inst in zip(eqn.invars, inst_in)
              if type(x) is core.Var and not inst]
  new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}]
  return eqn_known, eqn_staged, unks_out, inst_out, new_inst + list(residuals)
pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \
    _partial_eval_jaxpr_custom_rule
|
|
|
|
def _add_reshapes(which: Sequence[bool],
                  jaxpr_known: core.Jaxpr,
                  jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]:
  """Give scalar residuals a singleton leading axis across the known/staged split.

  jaxpr_known is rewritten to emit scalar residuals with shape (1,), and
  jaxpr_staged to strip that axis on the way in, so scalar residuals can
  cross a shard_map boundary. No-op when no residual is a scalar.
  """
  # add singleton axes to residuals which are from jaxpr_known and are scalars
  which_ = [w and not v.aval.shape  # pyrefly: ignore[missing-attribute]
            for w, v in zip(which, jaxpr_staged.invars[:len(which)])]
  if not any(which_): return jaxpr_known, jaxpr_staged
  assert not jaxpr_known.constvars and not jaxpr_staged.constvars

  def known(*args):
    out = core.eval_jaxpr(jaxpr_known, (), *args)
    out_known, res = split_list(out, [len(out) - sum(which)])
    res = [_add_singleton(x) if not x.shape else x for x in res]
    return [*out_known, *res]
  avals_in = tuple(v.aval for v in jaxpr_known.invars)
  avals_in = FlatTree.flatten((avals_in, {}))
  jaxpr_known_closed, _ = pe.trace_to_jaxpr(
      known, avals_in, debug_info=jaxpr_known.debug_info)

  def staged(*args):
    res_, ins = split_list(args, [len(which)])
    res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)]
    return core.eval_jaxpr(jaxpr_staged, (), *res, *ins)
  # Promoted residual inputs are typed with the unmapped (1,)-shaped aval.
  res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval
               for w, v in zip(which_, jaxpr_staged.invars[:len(which)])]
  avals_in = (*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]])
  avals_in = FlatTree.flatten((avals_in, {}))
  jaxpr_staged_closed, _ = pe.trace_to_jaxpr(
      staged, avals_in, debug_info=jaxpr_staged.debug_info)

  return jaxpr_known_closed.jaxpr, jaxpr_staged_closed.jaxpr
|
|
|
|
def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
                      in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs,
                      params_known, params_staged):
  """Build the shard_map eqn params for the known and staged halves of remat.

  Prunes in/out specs to match the split jaxprs and threads residual specs:
  the known jaxpr gains residual outputs, the staged jaxpr gains residual
  inputs (reusing forwarded-input/-output specs where applicable).
  """
  # prune inputs to jaxpr_known according to unks_in
  in_specs_known, _ = partition_list(unks_in, params_known['in_specs'])
  _, out_specs_known = partition_list(kept_outs_known, params_known['out_specs'])
  out_specs_known = out_specs_known + out_res_specs_known
  assert len(out_specs_known) == len(params_known['jaxpr'].outvars)
  new_params_known = dict(params_known, in_specs=tuple(in_specs_known),
                          out_specs=tuple(out_specs_known))

  # added num_res new inputs to jaxpr_staged, pruning according to inst_in
  _, in_specs_staged = partition_list(inst_in, params_staged['in_specs'])
  iter_staged = iter(staged_in_res_specs)
  # Each residual spec comes from its forwarding source (known input or known
  # output) when forwarded, otherwise from the fresh staged residual specs.
  res_specs = [in_specs_known[f1] if f1 is not None else
               out_specs_known[f2] if f2 is not None else
               next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)]

  in_specs_staged = res_specs + in_specs_staged
  _, out_specs_staged = partition_list(kept_outs_staged, params_staged['out_specs'])
  new_params_staged = dict(params_staged, in_specs=tuple(in_specs_staged),
                           out_specs=tuple(out_specs_staged))
  return new_params_known, new_params_staged
|
|
|
|
# TODO(mattjj): remove this mechanism when we revise mesh scopes
|
|
def _all_mesh_names_except_spmd(
    mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]:
  """Manual mesh axis names, minus any vmap spmd axis names in the axis env."""
  spmd_names = core.get_axis_env().spmd_axis_names
  return tuple(axis for axis in mesh.axis_names
               if axis in manual_axes and axis not in spmd_names)
|
|
|
|
def _all_newly_manual_mesh_names(
    mesh: BaseMesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]:
  """Axis names this shard_map newly makes manual.

  Starts from the requested manual_axes, then drops vmap spmd axes and axes
  that are already manual in the surrounding context. When a non-empty
  abstract mesh is in scope, its axis names take precedence over `mesh`'s.
  """
  axis_env = core.get_axis_env()
  vmap_spmd_names = set(axis_env.spmd_axis_names)
  ctx_mesh = get_abstract_mesh()
  if ctx_mesh.empty:
    # TODO(mattjj): remove this mechanism when we revise mesh scopes
    already_manual_names = set(axis_env.axis_sizes)  # may include vmap axis_names
  else:
    mesh = ctx_mesh
    already_manual_names = set(ctx_mesh.manual_axes)
  excluded = vmap_spmd_names | already_manual_names
  return tuple(axis for axis in mesh.axis_names
               if axis in manual_axes and axis not in excluded)
|
|
|
|
|
|
# DCE
|
|
|
|
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
|
|
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
                   ) -> tuple[list[bool], core.JaxprEqn | None]:
  """DCE rule for shard_map_p: prune unused outputs, inputs, and specs.

  Returns the used-input mask plus a rebuilt eqn, or None when the whole eqn
  can be removed (no used outputs/inputs and no effects).
  """
  if not any(used_outputs) and not pe.has_effects(eqn):
    return [False] * len(eqn.invars), None
  mesh = eqn.params["mesh"]
  manual_axes = eqn.params["manual_axes"]
  check_vma = eqn.params["check_vma"]
  # DCE the body in the manual-mesh context it was traced in.
  with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma),
        use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))):
    jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
  if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects:
    return used_inputs, None
  else:
    # Keep only the specs for surviving inputs/outputs.
    _, in_specs = partition_list(used_inputs, eqn.params['in_specs'])
    _, out_specs = partition_list(used_outputs, eqn.params['out_specs'])
    new_params = dict(eqn.params, jaxpr=jaxpr, in_specs=tuple(in_specs),
                      out_specs=tuple(out_specs))
    effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
    new_eqn = pe.new_jaxpr_eqn(
        [v for v, used in zip(eqn.invars, used_inputs) if used],
        [x for x, used in zip(eqn.outvars, used_outputs) if used],
        eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx)
    return used_inputs, new_eqn
pe.dce_rules[shard_map_p] = _shard_map_dce
|
|
|
|
# Mutable arrays / refs
|
|
|
|
@discharge.register_discharge_rule(shard_map_p)
def _shard_map_discharge(
    in_avals, out_avals, *args, jaxpr, mesh, in_specs, out_specs, check_vma,
    manual_axes):
  """State-discharge rule for shard_map_p.

  Discharges refs in the body jaxpr (turning ref effects into explicit final
  ref values appended to the outputs), re-binds the shard_map, and returns
  the new ref values aligned with the ref positions in in_avals.
  """
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    discharged_jaxpr, discharged_consts = discharge.discharge_state(jaxpr, ())
  if discharged_consts: raise NotImplementedError
  del discharged_consts

  # Final ref values are emitted after the regular outputs; they reuse the
  # specs of the corresponding ref inputs.
  ref_specs = [spec for spec, invar in zip(in_specs, jaxpr.invars)
               if isinstance(invar.aval, AbstractRef)]
  params = dict(jaxpr=discharged_jaxpr, out_specs=(*out_specs, *ref_specs))
  params_ = shard_map_p.get_bind_params(params)
  f, = params_.pop('subfuns')
  debug_info = params_['debug_info']
  out_and_ref_vals = shard_map_p.bind(
      *args, subfuns=(f,), mesh=mesh, in_specs=in_specs, manual_axes=manual_axes,
      debug_info=debug_info, check_vma=check_vma)
  out_vals, ref_vals = split_list(out_and_ref_vals, [len(jaxpr.outvars)])
  ref_vals_ = iter(ref_vals)
  # Slot each ref value back into its original input position.
  new_invals = [next(ref_vals_) if isinstance(a, AbstractRef) else None
                for a in in_avals]
  assert next(ref_vals_, None) is None
  return new_invals, out_vals
|
|
|
|
def _repspec(aval):
  """Spec for an unsharded value: empty abstract mesh, no vma check, no axes."""
  return aval.nospec(empty_abstract_mesh, False, ())
|
|
|
|
# ----------------------- top level collectives --------------------------------
|
|
|
|
def _top_level_ag(x, aval, out_sh_, multi_dim):
  """Gather one array from its current spec to `out_sh_` via all_gather.

  For each array dimension whose out spec is a prefix of its in spec, emits a
  tiled all_gather over the removed axes inside a jitted shard_map. Requires
  all input mesh axes to be explicit and in/out meshes to match.
  """
  assert aval.sharding.mesh.are_all_axes_explicit, aval.sharding.mesh
  out_sh = canonicalize_sharding(out_sh_, "top_level_all_gather")
  if out_sh is None:
    raise ValueError(
        f'out_sharding passed to top_level_all_gather cannot be {out_sh_}. It'
        ' should be a PartitionSpec or NamedSharding.')
  if aval.sharding.mesh != out_sh.mesh:
    raise ValueError(
        f'Input sharding mesh {aval.sharding.mesh} should be equal to'
        f' out_sharding mesh {out_sh.mesh}')

  in_spec = aval.sharding.spec
  out_spec = out_sh.spec._normalized_spec_for_aval(len(in_spec))
  if config.remove_size_one_mesh_axis_from_type.value:
    out_spec = core.remove_size_one_mesh_axis(out_spec, out_sh.mesh)

  def f_shmap(x):
    # Maybe this can just be 1 AG where we gather in a new dim and then do
    # AG(new_dim) -> reshape -> transpose -> reshape but it might be expensive.
    count = 0
    for axis, (i, o) in enumerate(zip(in_spec, out_spec)):
      if i == o:
        continue
      if not multi_dim and count > 0:
        raise ValueError(
            "multiple dimensions cannot be all_gathered since multi_dim=False"
            f" passed to `top_level_all_gather`. Got {in_spec=} and {out_spec=}")
      count += 1
      if i is None:
        raise ValueError(
            f"top_level_all_gather doesn't allow input {aval} to be unsharded"
            f" on dimension {axis} when {out_spec=}.")
      i = i if isinstance(i, tuple) else (i,)
      o = o if o is None or isinstance(o, tuple) else (o,)
      # Enforce the prefix property: out spec must be a prefix of in spec so
      # gathering the remaining axes leaves the data unchanged.
      if o is not None and i[:len(o)] != o:
        raise ValueError(
            'top_level_all_gather maintains `top_level_all_gather(x, ...) == x`'
            f" property. The {in_spec=} and {out_spec=} don't satisfy this"
            f' property. Please change your out_spec of array dimension {axis} so'
            " that it's a prefix of in_spec")
      # NOTE(review): when o is not None the gathered axes are i[-len(o):],
      # the *last* len(o) names, not the suffix i[len(o):] left after the
      # prefix check. These coincide only when len(i) == 2*len(o) (e.g. the
      # common len(i)=2, len(o)=1 case) — confirm intended for longer specs.
      axis_name = i if o is None else i[-len(o):]
      x = lax_parallel.all_gather(x, axis_name=axis_name, axis=axis,
                                  tiled=True, to='reduced')
    return x
  return api.jit(shard_map(f_shmap, out_specs=out_spec))(x)
|
|
|
|
def top_level_all_gather(xs, out_sharding, *, multi_dim: bool = False):
  """All-gather a pytree of explicitly-sharded arrays to `out_sharding`."""
  ctx_mesh = get_abstract_mesh()
  if not ctx_mesh.are_all_axes_explicit:
    raise ValueError(
        'top_level_all_gather works when all mesh axes of context mesh are'
        f' explicit. Got {get_abstract_mesh()}')
  x_flat, treedef = tree_flatten(xs)
  # Broadcast/match the requested shardings against the pytree structure.
  out_sharding_flat = api_util.flatten_axis_resources(
      "top_level_all_gather out_sharding", treedef, out_sharding,
      tupled_args=True)
  out_flat = [_top_level_ag(leaf, core.typeof(leaf), sh, multi_dim)
              for leaf, sh in zip(x_flat, out_sharding_flat)]
  return tree_unflatten(treedef, out_flat)
|