# Copyright 2023 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from collections.abc import Callable, Hashable, Sequence, Set import enum from functools import partial import inspect import itertools as it from math import prod import operator as op from typing import Any, TypeVar, Union, cast, overload import numpy as np from jax._src import api from jax._src import api_util from jax._src import config from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util from jax._src import util from jax._src.core import order_wrt_mesh from jax._src.core import pvary, Tracer, typeof, shard_aval, unshard_aval from jax._src.mesh import (AbstractMesh, Mesh, BaseMesh, AxisType, use_abstract_mesh, get_abstract_mesh, get_concrete_mesh, empty_abstract_mesh) from jax._src.lax import lax, parallel as lax_parallel from jax._src.lib import _jax from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo, sdy from jax._src.sharding_impls import (NamedSharding, PartitionSpec, canonicalize_sharding) from jax._src.util import (HashablePartial, unzip2, partition_list, merge_lists, split_list, subs_list2, fun_name as util_fun_name) from jax._src.state import discharge from jax._src.state.types import AbstractRef from jax._src.interpreters import batching from jax._src.interpreters import mlir from 
jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.interpreters import ad from jax._src.tree_util import ( broadcast_prefix, keystr, prefix_errors, generate_key_paths, tree_flatten, tree_leaves, tree_map, tree_structure, tree_unflatten, KeyPath, PyTreeDef, FlatTree) P = PartitionSpec map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip traceback_util.register_exclusion(__file__) # API Specs = Any # PyTree[PartitionSpec] AxisName = Hashable class InferFromArgs: def __repr__(self): return "jax.sharding.Infer" def __reduce__(self): return (_get_default_infer, ()) Infer = InferFromArgs() def _get_default_infer(): return Infer F = TypeVar("F", bound=Callable) G = TypeVar("G", bound=Callable) @overload def shard_map(f: F, /, *, out_specs: Specs, in_specs: Specs | None | InferFromArgs = ..., mesh: Mesh | AbstractMesh | None = ..., axis_names: Set[AxisName] = ..., check_vma: bool = ...) -> F: ... @overload def shard_map(f: None = None, /, *, out_specs: Specs, in_specs: Specs | None | InferFromArgs = ..., mesh: Mesh | AbstractMesh | None = ..., axis_names: Set[AxisName] = ..., check_vma: bool = ... ) -> Callable[[G], G]: ... # See https://github.com/jax-ml/jax/pull/30753 to understand why `in_specs` # defaults to `Infer`. def shard_map(f: F | None = None, /, *, out_specs: Specs, in_specs: Specs | None | InferFromArgs = Infer, mesh: Mesh | AbstractMesh | None = None, axis_names: Set[AxisName] = frozenset(), check_vma: bool = True ) -> F | Callable[[G], G]: """Map a function over shards of data using a mesh of devices. See the docs at https://docs.jax.dev/en/latest/notebooks/shard_map.html. Args: f: callable to be mapped. Each application of ``f``, or "instance" of ``f``, takes as input a shard of the mapped-over arguments and produces a shard of the output. 
mesh: (optional, default None) a ``jax.sharding.Mesh`` representing the array of devices over which to shard the data and on which to execute instances of ``f``. The names of the ``Mesh`` can be used in collective communication operations in ``f``. If mesh is None, it will be inferred from the context which can be set via `jax.set_mesh` context manager. in_specs: (optional, default `Infer`) a pytree with ``jax.sharding.PartitionSpec`` instances as leaves, with a tree structure that is a tree prefix of the args tuple to be mapped over. Similar to ``jax.sharding.NamedSharding``, each ``PartitionSpec`` represents how the corresponding argument (or subtree of arguments) should be sharded along the named axes of ``mesh``. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses sharding the corresponding argument array axis along that positional axis; not mentioning an axis name expresses replication. If ``Infer``, all mesh axes must be of type `Explicit`, in which case the in_specs are inferred from the argument types. If ``None``, inputs will be treated as static. out_specs: a pytree with ``PartitionSpec`` instances as leaves, with a tree structure that is a tree prefix of the output of ``f``. Each ``PartitionSpec`` represents how the corresponding output shards should be concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses concatenation of that mesh axis's shards along the corresponding positional axis; not mentioning a ``mesh`` axis name expresses a promise that the output values are equal along that mesh axis, and that rather than concatenating only a single value should be produced. axis_names: (optional, default set()) set of axis names from ``mesh`` over which the function ``f`` is manual. If empty, ``f``, is manual over all mesh axes. check_vma: (optional) boolean (default True) representing whether to enable additional validity checks and automatic differentiation optimizations. 
The validity checks concern whether any mesh axis names not mentioned in ``out_specs`` are consistent with how the outputs of ``f`` are replicated. Returns: A callable representing a mapped version of ``f``, which accepts positional arguments corresponding to those of ``f`` and produces output corresponding to that of ``f``. """ kwargs = dict(mesh=mesh, in_specs=in_specs, out_specs=out_specs, axis_names=axis_names, check_vma=check_vma) if f is None: return lambda g: _shard_map(g, **kwargs) return _shard_map(f, **kwargs) @overload def smap(f: F, /, *, in_axes: int | None | InferFromArgs | tuple[Any, ...] = ..., out_axes: Any, axis_name: AxisName) -> F: ... @overload def smap(f: None = None, /, *, in_axes: int | None | InferFromArgs | tuple[Any, ...] = ..., out_axes: Any, axis_name: AxisName ) -> Callable[[G], G]: ... def smap(f: F | None = None, /, *, in_axes: int | None | InferFromArgs | tuple[Any, ...] = Infer, out_axes: Any, axis_name: AxisName ) -> F | Callable[[G], G]: """Single axis shard_map that maps a function `f` one axis at a time. Args: f: Callable to be mapped. Each application of ``f``, or "instance" of ``f``, takes as input a shard of the mapped-over arguments and produces a shard of the output. in_axes: (optional) An integer, None, or sequence of values specifying which input array axes to map over. If not specified, `smap` will try to infer the axes from the arguments only under `Explicit` mode. An integer or ``None`` indicates which array axis to map over for all arguments (with ``None`` indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. Axis integers must be in the range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of dimensions (axes) of the corresponding input array. out_axes: An integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output. 
axis_name: ``mesh`` axis name over which the function ``f`` is manual. Returns: A callable representing a mapped version of ``f``, which accepts positional arguments corresponding to those of ``f`` and produces output corresponding to that of ``f``. """ kwargs = dict(in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) if f is None: return lambda g: _smap(g, **kwargs) return _smap(f, **kwargs) def _smap(f: F, *, in_axes: int | None | InferFromArgs | tuple[Any, ...], out_axes: Any, axis_name: AxisName) -> F: if isinstance(axis_name, (list, tuple)): raise TypeError( f"smap axis_name should be a `str` or a `Hashable`, but got {axis_name}") if (in_axes is not None and in_axes is not Infer and not isinstance(in_axes, (int, tuple))): raise TypeError( "smap in_axes must be an int, None, jax.sharding.Infer, or a tuple of" " entries corresponding to the positional arguments passed to the" f" function, but got {in_axes}.") if (in_axes is not Infer and not all(isinstance(l, int) for l in tree_leaves(in_axes))): raise TypeError( "smap in_axes must be an int, None, jax.sharding.Infer, or (nested)" f" container with those types as leaves, but got {in_axes}.") if not all(isinstance(l, int) for l in tree_leaves(out_axes)): raise TypeError("smap out_axes must be an int, None, or (nested) container " f"with those types as leaves, but got {out_axes}.") in_specs = (Infer if in_axes is Infer else tree_map(partial(_axes_to_pspec, axis_name), in_axes, is_leaf=lambda x: x is None)) out_specs = tree_map(partial(_axes_to_pspec, axis_name), out_axes, is_leaf=lambda x: x is None) return _shard_map(f, mesh=None, in_specs=in_specs, out_specs=out_specs, axis_names={axis_name}, check_vma=True, _smap=True) @partial(traceback_util.api_boundary, repro_api_name="jax.shard_map") def _shard_map(f: F, *, mesh: Mesh | AbstractMesh | None, in_specs: Specs, out_specs: Specs, axis_names: Set[AxisName], check_vma: bool, _smap: bool = False) -> F: if not callable(f): raise TypeError("shard_map requires a 
callable for its first argument, " f"but got {f} of type {type(f)}.") @util.wraps(f) @traceback_util.api_boundary def wrapped(*args): nonlocal mesh, axis_names mesh, axis_names = _shmap_checks( mesh, axis_names, in_specs, out_specs, _smap) dbg = api_util.debug_info("shard_map", f, args, {}) args_flat = FlatTree.flatten(args) api_util.check_no_transformed_refs_args(lambda: dbg, args_flat) try: in_specs_flat = broadcast_prefix( in_specs, args, is_leaf=lambda x: x is None) except ValueError: e, *_ = prefix_errors(in_specs, args) raise e('shard_map in_specs') from None if (in_specs is Infer and all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): arg_s = [typeof(a).sharding for a in args_flat] assert all(i is Infer for i in in_specs_flat), in_specs_flat in_specs_flat = [_manual_spec(axis_names, s.spec, mesh) for s in arg_s] in_tree = args_flat.tree which_dyn = [s is not None for s in in_specs_flat] static_args = [x for x, dyn in zip(args_flat, which_dyn) if not dyn] dyn_args = [x for x, dyn in zip(args_flat, which_dyn) if dyn] in_specs_flat = tuple(s for s, dyn in zip(in_specs_flat, which_dyn) if dyn) dyn_argnums = [i for i, dyn in enumerate( which_dyn) if dyn] _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, dyn_args) # TODO(yashkatariya): Add support for partial manual mesh_axis_names_wo_vmap = ( frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names) if (mesh_axis_names_wo_vmap == axis_names and all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): for a, s in zip(dyn_args, in_specs_flat): if not isinstance(s, P): continue arg_aval = typeof(a) s = s._normalized_spec_for_aval(arg_aval.ndim) if config.remove_size_one_mesh_axis_from_type.value: s = core.remove_size_one_mesh_axis(s, mesh) if arg_aval.sharding.spec != s: raise ValueError( f"in_specs passed to shard_map: {s} does not match the specs of" f" the input: {arg_aval.sharding.spec} for arg: {typeof(a)}." 
" `in_specs` is an optional argument so you can omit specifying" " it and shard_map will infer the in_specs from the arguments." " If you want to reshard your inputs, you can use `jax.reshard`" " on the arguments and then pass those args to shard_map.") if (dbg.arg_names is not None and len(dyn_args) != len(dbg.arg_names)): dbg = dbg.with_unknown_names() def f_wrapped(*dyn_args): dyn_args_iter = iter(dyn_args) static_args_iter = iter(static_args) all_args = [next(dyn_args_iter) if dyn else next(static_args_iter) for dyn in which_dyn] args = tree_unflatten(in_tree, all_args) ans = f(*args) ans_ft = FlatTree.flatten(ans) try: out_specs_flat = tuple(broadcast_prefix(out_specs, ans)) except ValueError: e, *_ = prefix_errors(out_specs, ans) raise e('shard_map out_specs') from None def add_implicit_pvary_and_unreduced(val, spec): if not isinstance(spec, P): return val aval = typeof(val) val = pvary(val, tuple(_spec_to_vma(spec) - aval.mat.varying)) return (lax_parallel.vary_unreduced_cast(val, tuple(unreduced)) if (unreduced := spec.unreduced - aval.mat.unreduced) else val) if check_vma: ans_ft = ans_ft.map2(add_implicit_pvary_and_unreduced, out_specs_flat) return ans_ft.with_aux(out_specs_flat) try: out_ft = shard_map_p.bind( *dyn_args, subfuns=(f_wrapped,), mesh=mesh, in_specs=in_specs_flat, check_vma=check_vma, manual_axes=axis_names, debug_info=dbg) except _SpecError as e: fails, out_tree = e.args msg = _spec_rank_error(SpecErrorType.out, f, out_tree, out_specs, fails) if any(fail is not no_fail and not fail.shape for fail in fails): msg += (" In particular, for rank 0 outputs which are not constant " "over the mesh, add at least one (singleton) axis to them so " "that they can be concatenated using out_specs.") raise ValueError(msg) from None except _RepError as e: fails, out_tree, = e.args msg = _inout_vma_error(f, mesh, out_tree, out_specs, fails) raise ValueError(msg) from None return out_ft.unflatten() return cast(F, wrapped) def _axes_to_pspec(axis_name, axis): 
if axis is None: return P() return P(*[None] * axis + [axis_name]) def _shmap_checks(mesh, axis_names, in_specs, out_specs, _smap): if mesh is None: mesh = get_abstract_mesh() if mesh.empty: raise ValueError( "The context mesh cannot be empty. Use" " `jax.set_mesh(mesh)` to enter into a mesh context") else: ctx_mesh = get_abstract_mesh() if not ctx_mesh.empty and mesh.abstract_mesh != ctx_mesh: raise ValueError( f"The context mesh {ctx_mesh} should match the mesh passed to" f" shard_map {mesh}") if not isinstance(mesh, (Mesh, AbstractMesh)): raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " "`jax.sharding.AbstractMesh` instance for its " f"second argument, but got {mesh} of type {type(mesh)}.") if mesh.empty: raise ValueError(f"shard_map requires a non-empty mesh. Got {mesh}") mesh_axis_names_wo_vmap = ( frozenset(mesh.axis_names) - core.get_axis_env().explicit_mesh_axis_names ) if not isinstance(axis_names, (frozenset, set)): raise TypeError( "`axis_names` argument of shard_map should be of type `frozenset` or" f" `set`. 
Got type: {type(axis_names)}") if isinstance(axis_names, set): axis_names = frozenset(axis_names) if not axis_names: axis_names = mesh_axis_names_wo_vmap if not axis_names.issubset(mesh_axis_names_wo_vmap): raise ValueError( f"jax.shard_map requires axis_names={axis_names} to be a subset of " f"mesh.axis_names={mesh_axis_names_wo_vmap}") if (in_specs is Infer and not all(mesh._name_to_type[a] == AxisType.Explicit for a in axis_names)): axis_types = ', '.join(str(mesh._name_to_type[a]) for a in axis_names) if _smap: msg = (f"in_axes was not specified when axis_name={axis_names} was of" f" type {axis_types}") else: msg = ("shard_map in_specs argument must be a pytree of" " `jax.sharding.PartitionSpec` instances, but it was `None` when" f" {axis_names=} are of type {axis_types}") raise TypeError(msg) if in_specs is not Infer and in_specs is not None: _check_specs(SpecErrorType.input, in_specs, axis_names) _check_unreduced(SpecErrorType.input, mesh, axis_names, in_specs) _check_specs(SpecErrorType.out, out_specs, axis_names) _check_unreduced(SpecErrorType.out, mesh, axis_names, out_specs) return mesh, axis_names def _manual_spec(manual_axes, spec: P, mesh) -> P: out: list[str | tuple[str | None, ...] | None] = [] s: str | None | tuple[str, ...] 
for s in spec: if s is None: out.append(s) elif isinstance(s, tuple): temp = [p if p in manual_axes else None for p in s] while temp and temp[-1] is None: temp.pop() if None in temp: raise ValueError(f"Invalid spec: {spec}") out.append(None if len(temp) == 0 else tuple(temp)) else: out.append(s if s in manual_axes else None) _check_unreduced(SpecErrorType.input, mesh, manual_axes, spec) return P(*out, unreduced=spec.unreduced, reduced=spec.reduced) # Error checking and messages SpecErrorType = enum.Enum('SpecErrorType', ['input', 'out']) def _check_unreduced(error_type, mesh, manual_axes, specs): from jax._src.hijax import HiPspec prefix = 'in' if error_type == SpecErrorType.input else 'out' full_manual = frozenset(mesh.axis_names) == manual_axes specs_flat, _ = tree_flatten(specs) for s in specs_flat: if isinstance(s, HiPspec): continue # TODO(mattjj,yashkatariya): add user validation method if not s.unreduced and not s.reduced: continue if not full_manual: raise NotImplementedError( f"unreduced/reduced can only be passed to {prefix}_specs when" " shard_map is in full manual mode. Got mesh axis names" f" {mesh.axis_names}, manual_axes: {manual_axes}, specs: {s}. Please" " file a bug at https://github.com/jax-ml/jax/issues.") if not all(mesh._name_to_type[u] == AxisType.Explicit for u in s.unreduced): raise ValueError( f"unreduced in {prefix}_specs {s} can only be used when the mesh" " passed to shard_map contains axis names all of type `Explicit`." f" Got mesh {mesh}") if not all(mesh._name_to_type[u] == AxisType.Explicit for u in s.reduced): raise ValueError( f"reduced in {prefix}_specs {s} can only be used when the mesh" " passed to shard_map contains axis names all of type `Explicit`." 
f" Got mesh {mesh}") def _check_specs(error_type: SpecErrorType, specs: Any, manual_axes) -> None: from jax._src.hijax import HiPspec if error_type == SpecErrorType.input and specs is None: raise TypeError( "shard_map in_specs argument must be a pytree of " "`jax.sharding.PartitionSpec` instances, but it was None.\n" "Instead of `in_specs=None`, did you mean `in_specs=P()`, " "where `P = jax.sharding.PartitionSpec`?") def check_spec(p): if isinstance(p, HiPspec): return True # TODO(mattjj,yashkatariya): add user validation method if not isinstance(p, PartitionSpec): return False for names in p: names = (names,) if not isinstance(names, tuple) else names for name in names: if name is not None and name not in manual_axes: return False return True if all(check_spec(p) for p in tree_leaves(specs)): return prefix = 'in' if error_type == SpecErrorType.input else 'out' msgs = [f" {prefix}_specs{keystr(key)} is {x} of type {type(x).__name__}, " for key, x in generate_key_paths(specs) if not isinstance(x, P)] if not msgs: for key, p in generate_key_paths(specs): for names in p: names = (names,) if not isinstance(names, tuple) else names for name in names: if name is not None and name not in manual_axes: msgs.append(f" {prefix}_specs{keystr(key)} refers to {repr(name)}") raise ValueError( f"shard_map {prefix}_specs argument must refer to an axis " f"marked as manual ({manual_axes}), but:\n\n" + '\n\n'.join(msgs) + '\n\n' f"Check the {prefix}_specs values passed to shard_map.") raise TypeError( f"shard_map {prefix}_specs argument must be a pytree of " f"`jax.sharding.PartitionSpec` instances, but:\n\n" + '\n\n'.join(msgs) + '\n\n' f"Check the {prefix}_specs values passed to shard_map.") class NoFail: def __repr__(self): return "NoFail()" no_fail = NoFail() def _check_specs_vs_args( f: Callable, mesh: Mesh | AbstractMesh, in_tree: PyTreeDef, in_specs: Specs, dyn_argnums: Sequence[int], in_specs_flat: Sequence[P], xs: Sequence) -> None: in_avals = map(core.shaped_abstractify, 
xs) fail = [a if isinstance(p, P) and len(p) > a.ndim else no_fail for p, a in zip(in_specs_flat, in_avals)] if any(f is not no_fail for f in fail): fail = _expand_fail(in_tree, dyn_argnums, fail) msg = _spec_rank_error(SpecErrorType.input, f, in_tree, in_specs, fail) raise ValueError(msg) bad = lambda a, d, ns: a.shape[d] % prod(mesh.shape[n] for n in ns) fail = [a if (isinstance(s, P) and any(bad(a, d, ns) for d, ns in _spec_to_names(s).items())) else no_fail for a, s in zip(in_avals, in_specs_flat)] if any(f is not no_fail for f in fail): fail = _expand_fail(in_tree, dyn_argnums, fail) msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) raise ValueError(msg) def _expand_fail(in_tree: PyTreeDef, dyn_argnums: Sequence[int], fail: Sequence[core.ShapedArray | NoFail] ) -> list[core.ShapedArray | NoFail]: fail_: list[core.ShapedArray | NoFail] = [no_fail] * in_tree.num_leaves for i, f in zip(dyn_argnums, fail): fail_[i] = f return fail_ def _spec_rank_error( error_type: SpecErrorType, f: Callable, tree: PyTreeDef, specs: Specs, fails: list[core.ShapedArray | NoFail]) -> str: fun_name = util_fun_name(f) if error_type == SpecErrorType.input: prefix, base = 'in', 'the passed args' ba = _try_infer_args(f, tree) else: prefix, base = 'out', f'{fun_name}(*args)' ba = None msgs = [] for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): extra = "" if error_type == SpecErrorType.input and ba is not None: arg_key, *_ = fail_key param_names, params = unzip2( (name, param) for name, param in ba.signature.parameters.items() if param.kind not in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD)) if (arg_key.idx >= len(params) or params[arg_key.idx].kind == inspect.Parameter.VAR_POSITIONAL): extra = (f", where args{arg_key} is the index " f"{arg_key.idx - len(params) + 1} component " f"of {fun_name}'s varargs parameter '{param_names[-1]}',") else: param_name = params[arg_key.idx] extra = (f", where args{arg_key} is bound to 
{fun_name}'s " f"parameter '{param_name}',") msgs.append( f"* {prefix}_specs{keystr(spec_key)} is {spec} which has length " f"{len(spec)}, but " f"{base}{keystr(fail_key)}{extra} has shape {aval.str_short()}, " f"which has rank {aval.ndim} (and {aval.ndim} < {len(spec)})") assert msgs if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point msg = (f"shard_map applied to the function '{fun_name}' was given an " f"{prefix}_specs entry which is too long to be compatible with the " f"corresponding {prefix}put value from the function:\n\n" + '\n\n'.join(msgs) + '\n\n' + f"Entries in {prefix}_specs must be of length no greater than the " f"number of axes in the corresponding {prefix}put value.\n\n" f"Either revise the spec to be shorter, or modify '{fun_name}' so " f"that its {prefix}puts have sufficient rank.") if any(not aval.ndim for _, (_, aval) in _iter_paths(tree, specs, fails)): msg += (f"\n\nFor scalar values (rank 0), consider using an {prefix}_specs " "entry of `P()`, where `P = jax.sharding.PartitionSpec`.") return msg def _spec_divisibility_error( f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs, fails: list[core.ShapedArray | NoFail]) -> str: ba = _try_infer_args(f, tree) fun_name = getattr(f, '__name__', str(f)) msgs = [] for (spec_key, spec), (fail_key, aval) in _iter_paths(tree, specs, fails): extra = "" if ba is not None: arg_key, *_ = fail_key param_names, params = unzip2( (name, param) for name, param in ba.signature.parameters.items() if param.kind not in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD)) if (arg_key.idx >= len(params) or params[arg_key.idx].kind == inspect.Parameter.VAR_POSITIONAL): extra = (f", where args{arg_key} is the index " f"{arg_key.idx - len(params) + 1} component " f"of {fun_name}'s varargs parameter '{param_names[-1]}',") else: param_name = params[arg_key.idx] extra = (f", where args{arg_key} is bound to {fun_name}'s " f"parameter '{param_name}',") names = _spec_to_names(spec) 
for d, ns in names.items(): if aval.shape[d] % prod(mesh.shape[n] for n in ns): axis = f"axes {ns}" if len(ns) > 1 else f"axis '{ns[0]}'" total = 'total ' if len(ns) > 1 else '' sz = prod(mesh.shape[n] for n in ns) msgs.append( f"* the passed args{keystr(fail_key)} of shape {aval.str_short()}{extra} " f"corresponds to in_specs{keystr(spec_key)} of value {spec}, " f"which maps array axis {d} (of size {aval.shape[d]}) to mesh " f"{axis} (of {total}size {sz}), but {sz} does not evenly divide " f"{aval.shape[d]}") assert msgs if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point msg = (f"shard_map applied to the function '{fun_name}' was given argument " f"arrays with axis sizes that are not evenly divisible by the " f"corresponding mesh axis sizes:\n\n" f"The mesh given has shape {tuple(mesh.shape.values())} with " f"corresponding axis names {mesh.axis_names}.\n\n" + '\n\n'.join(msgs) + '\n\n' + f"Array arguments' axis sizes must be evenly divisible by the mesh " f"axis or axes indicated by the corresponding elements of the " f"argument's in_specs entry. 
Consider checking that in_specs are " f"correct, and if so consider changing the mesh axis sizes or else " f"padding the input and adapting '{fun_name}' appropriately.") return msg def _inout_vma_error(f: Callable, mesh: Mesh | AbstractMesh, tree: PyTreeDef, specs: Specs, fails: list[core.ManualAxisType | NoFail] ) -> str: fun_name = getattr(f, '__name__', str(f)) msgs = [] for (spec_key, spec), (fail_key, mat) in _iter_paths(tree, specs, fails): unmentioned = _unmentioned(mesh, spec) if len(unmentioned) > 1: need_vma = ','.join(map(str, order_wrt_mesh(mesh, _spec_to_vma(spec)))) got_vma = ','.join(map(str, order_wrt_mesh(mesh, mat.varying))) diff = ','.join(map(str, order_wrt_mesh( mesh, [n for n in unmentioned if n in mat.varying]))) msgs.append( f"* out_specs{keystr(spec_key)} is {spec} which implies that the " f"corresponding output value is only varying across mesh axes " f"{{{need_vma}}} and not {{{diff}}}, but it was inferred to be " f"possibly varying over {{{got_vma}}}") else: need_rep_, = unmentioned msgs.append( f"* out_specs{keystr(spec_key)} is {spec} which implies that the " f"corresponding output value is replicated across mesh axis " f"'{need_rep_}', but could not infer replication over any axes") assert msgs if len(msgs) == 1: msgs = [msgs[0][2:]] # remove the bullet point msg = (f"shard_map applied to the function '{fun_name}' was given " f"out_specs which require replication which can't be statically " f"inferred given the mesh:\n\n" f"The mesh given has shape {tuple(mesh.shape.values())} with " f"corresponding axis names {mesh.axis_names}.\n\n" + '\n\n'.join(msgs) + '\n\n' + "Check if these output values are meant to be replicated over those " "mesh axes. If not, consider revising the corresponding out_specs " "entries. 
If so, consider disabling the check by passing the " "check_vma=False argument to `jax.shard_map`.") return msg def _unmentioned(mesh: Mesh | AbstractMesh, spec) -> list[AxisName]: vur = _spec_to_mat(spec).vur return [n for n in mesh.axis_names if n not in vur] def _try_infer_args(f, tree): dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) try: return inspect.signature(f).bind(*dummy_args) except (TypeError, ValueError): return None T = TypeVar('T') def _iter_paths(tree: PyTreeDef, specs: Specs, fails: list[T | NoFail] ) -> list[tuple[tuple[KeyPath, P], tuple[KeyPath, T]]]: failures = tree_unflatten(tree, fails) failures_aug = generate_key_paths(failures) specs_ = tree_unflatten(tree_structure(specs), map(Tup, generate_key_paths(specs))) specs_aug = broadcast_prefix(specs_, failures, is_leaf=lambda x: x is None) return [(s, (fail_key, fail_data)) for s, (fail_key, fail_data) in zip(specs_aug, failures_aug) if s is not None and fail_data is not no_fail] class Tup: def __init__(self, vals): self.vals = vals def __iter__(self): return iter(self.vals) # Primitive JaxType = Any MaybeTracer = Union[JaxType, Tracer] class ShardMapPrimitive(core.Primitive): multiple_results = True skip_canonicalization = True def bind_with_trace(self, trace, args, avals, params, /): fun, = params.pop('subfuns') # fun returns a FlatTree containing a tuple of the user-level data and a flat, # broadcasted out_specs wrapped in `Static`. # The result of `bind_with_trace` is a `FlatTree` of tracer-like things and # doesn't include the `Static` out_specs. 
# NOTE(review): this chunk was recovered from whitespace-collapsed source, so
# statement nesting (which lines sit inside each `with`/`if`/`for`) was
# reconstructed and should be confirmed against the upstream file.
# The first statement below is the tail of a `ShardMapPrimitive` method whose
# header precedes this chunk; `get_bind_params` is the next method of that class.
    return trace.process_shard_map(shard_map_p, fun, args, **params)

  def get_bind_params(self, params):
    """Turns staged shard_map params back into bind-time params.

    Pops `jaxpr` and `out_specs` from a copy of `params` and replaces them with
    a `subfuns` entry that evaluates the jaxpr and attaches the out_specs as
    aux data on the flattened result. Also records the jaxpr's debug_info.
    """
    new_params = dict(params)
    jaxpr = new_params.pop('jaxpr')
    assert isinstance(jaxpr, core.Jaxpr)
    axes = new_params.pop('out_specs')
    def eval_jaxpr(*args):
      result = core.eval_jaxpr(jaxpr, (), *args)
      # The out_specs travel alongside the outputs as FlatTree aux data.
      return FlatTree.flatten(result).with_aux(axes)
    new_params['subfuns'] = (eval_jaxpr,)
    new_params['debug_info'] = jaxpr.debug_info
    return new_params


shard_map_p = ShardMapPrimitive('shard_map')

# Staging

@util.cache(max_size=256, trace_context_in_key=False)
def _as_manual_mesh(mesh, manual_axes: frozenset) -> AbstractMesh:
  """Returns `mesh.abstract_mesh` with each axis in `manual_axes` made Manual.

  Results are memoized by `util.cache` (at most 256 entries).
  """
  return mesh.abstract_mesh.update_axis_types(
      {n: AxisType.Manual for n in manual_axes})


def _extend_axis_env(mesh, manual_axes):
  """Returns an axis-env context covering the (name, size) of each manual axis."""
  return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items()
                                  if k in manual_axes])


def _shard_map_staging(
    trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: Callable,
    args: Sequence[Any], *, mesh: Mesh, in_specs, check_vma: bool,
    manual_axes: frozenset, debug_info
  ) -> FlatTree:
  """Staging rule: traces `f` per-shard and emits a single `prim` equation.

  `f` is traced on the per-shard avals produced by `shard_aval` for each
  `in_specs` entry, under the manual-mesh, axis-env and check_vma contexts.
  The out_specs returned by the trace are validated (`_check_names`, plus
  `_check_mats` when `check_vma`). The jaxpr's constants are separated out
  and passed as leading inputs with `_repspec` specs. The equation's output
  avals come from `unshard_aval`. When `trace.requires_low`, specs, inputs and
  output avals are lowered (`to_lo`/`lower_val`/`lo_ty`) and the outputs are
  raised back with `pe.raise_lo_outs`.
  """
  source_info = source_info_util.current()
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  in_avals = [typeof(arg) for arg in args]
  # Per-shard (inside-the-body) avals for each argument.
  in_avals_ = map(partial(shard_aval, mesh, manual_axes, check_vma), in_specs,
                  in_avals)
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    in_avals_flat_tree = FlatTree.flatten((in_avals_, {}))
    jaxpr, out_data = pe.trace_to_jaxpr(
        f, in_avals_flat_tree, debug_info, fun_returns_flat_tree=True,
        requires_low=trace.requires_low)
  # The traced function returns its out_specs as aux data.
  out_avals_ft, out_specs = out_data.unpack_aux()
  _check_names(out_specs, out_avals_ft)
  if check_vma:
    _check_mats(mesh, out_specs, out_avals_ft)
  out_avals = [unshard_aval(mesh, check_vma, spec, aval)
               for spec, aval in zip(out_specs, out_avals_ft)]
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    jaxpr, consts = pe.separate_consts(jaxpr)
  # Hoisted constants become leading inputs of the staged equation.
  in_specs_staged = (*(_repspec(typeof(c)) for c in consts), *in_specs)
  if trace.requires_low:
    in_specs_staged = tuple(lo_spec for hi_spec in in_specs_staged
                            for lo_spec in hi_spec.to_lo())
    out_specs = tuple(lo_spec for hi_spec in out_specs
                      for lo_spec in hi_spec.to_lo())
  params = dict(mesh=mesh, in_specs=in_specs_staged, out_specs=out_specs,
                jaxpr=jaxpr.jaxpr, check_vma=check_vma, manual_axes=manual_axes)
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  to_jaxpr_tracer = partial(trace.to_jaxpr_tracer, source_info=source_info)
  const_tracers = map(to_jaxpr_tracer, consts)
  trace.frame.is_high |= jaxpr.is_high
  if trace.requires_low:
    in_tracers = [to_jaxpr_tracer(loval) for arg in args
                  for loval in typeof(arg).lower_val(arg)]
    out_avals_lo = [lo_aval for aval in out_avals for lo_aval in aval.lo_ty()]
  else:
    in_tracers = map(to_jaxpr_tracer, args)
    out_avals_lo = out_avals
  out = trace.emit_eqn([*const_tracers, *in_tracers], out_avals_lo, prim,
                       params, effs, source_info)
  if trace.requires_low:
    out = pe.raise_lo_outs(out_avals, out)
  return out_avals_ft.update(out)
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging

# TODO add underscore version, for direct-linearize to consume

def _spec_to_names(spec: PartitionSpec):
  """Maps each partitioned dimension index of `spec` to a tuple of axis names.

  Dimensions whose entry is None are omitted; a single name is wrapped in a
  1-tuple.
  """
  return {i: names if isinstance(names, tuple) else (names,)
          for i, names in enumerate(spec) if names is not None}


def _shard_shaped_array(mesh: Mesh, manual_axes: frozenset, check_vma, spec,
                        aval: core.ShapedArray) -> core.ShapedArray:
  """Returns the per-shard aval seen inside shard_map for `aval` under `spec`.

  Each dimension is floor-divided by the product of the sizes of the mesh
  axes `spec` assigns to it, and the sharding is moved onto the manual mesh.
  The varying set is `aval.mat.varying`, plus the axes named in `spec` when
  `check_vma`. The aval's unreduced/reduced axes are carried over only when
  `check_vma`.

  Raises:
    ValueError: if `spec`'s unreduced/reduced axes differ from those on
      `aval.sharding.spec`.
  """
  assert isinstance(aval, core.ShapedArray)
  if spec.unreduced != aval.sharding.spec.unreduced:
    raise ValueError(
        f"in_specs containing unreduced {spec} passed to shard_map should be"
        " equal to the unreduced present on the in_aval"
        f" {aval.str_short(True)}")
  if spec.reduced != aval.sharding.spec.reduced:
    raise ValueError(
        f"in_specs containing reduced {spec} passed to shard_map should be"
        f" equal to the reduced present on the in_aval {aval.str_short(True)}")
  names = _spec_to_names(spec)
  new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
                    for i, sz in enumerate(aval.shape))
  manual_mesh = _as_manual_mesh(mesh, manual_axes)
  new_sharding = aval.sharding.update(
      mesh=manual_mesh,
      spec=core.modify_spec_for_auto_manual(aval.sharding.spec, manual_mesh))
  vma = (_spec_to_vma(spec) if check_vma else frozenset()) | aval.mat.varying
  unreduced = aval.sharding.spec.unreduced if check_vma else frozenset()
  reduced = aval.sharding.spec.reduced if check_vma else frozenset()
  mat = core.ManualAxisType(varying=vma, unreduced=unreduced, reduced=reduced)
  return aval.update(shape=new_shape, sharding=new_sharding,
                     manual_axis_type=mat)
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array


def _unshard_shaped_array(mesh: Mesh, check_vma, spec, aval: core.ShapedArray
                          ) -> core.ShapedArray:
  """Returns the aval seen outside shard_map for a per-shard `aval`.

  Each dimension is multiplied by the product of the sizes of the mesh axes
  `spec` assigns to it. For each dimension, the output PartitionSpec lists
  `spec`'s axes before any axes already in `aval.sharding.spec`. The mesh is
  the current abstract mesh, or `mesh.abstract_mesh` if none is set. Only
  varying axes that are manual on that mesh are kept.

  Raises:
    ValueError: if `check_vma` and `spec`'s unreduced/reduced axes differ
      from `aval.mat`'s.
  """
  assert isinstance(aval, core.ShapedArray)
  if check_vma and spec.unreduced != aval.mat.unreduced:
    raise ValueError(
        "out_specs passed to shard_map should be equal to the unreduced"
        f" present on the out_aval. Got out_specs={spec} and"
        f" out_aval={aval.str_short(True)}")
  if check_vma and spec.reduced != aval.mat.reduced:
    raise ValueError(
        "out_specs passed to shard_map should be equal to the reduced present"
        f" on the out_aval. Got out_specs={spec} and"
        f" out_aval={aval.str_short(True)}")
  names = _spec_to_names(spec)
  new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
                    for i, sz in enumerate(aval.shape))
  names_spec = spec._normalized_spec_for_aval(aval.ndim)
  if aval.ndim == 0:
    out_spec = P(unreduced=spec.unreduced, reduced=spec.reduced)
  else:
    # Merge, per dimension, the axes from `spec` with those already on aval.
    out_spec = []
    for name_s, aval_s in zip(names_spec, aval.sharding.spec):
      if name_s and not aval_s:
        out_spec.append(name_s)
      elif aval_s and not name_s:
        out_spec.append(aval_s)
      elif not name_s and not aval_s:
        out_spec.append(None)
      else:
        assert name_s and aval_s
        name_s = name_s if isinstance(name_s, tuple) else (name_s,)
        aval_s = aval_s if isinstance(aval_s, tuple) else (aval_s,)
        out_spec.append(name_s + aval_s)
    out_spec = PartitionSpec(*out_spec, unreduced=spec.unreduced,
                             reduced=spec.reduced)
  new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
              get_abstract_mesh())
  new_sharding = NamedSharding(new_mesh, out_spec)
  manual_axes = set(new_mesh.manual_axes)
  vma = frozenset(v for v in aval.mat.varying if v in manual_axes)
  # TODO(yashkatariya): Handle partial manual unreduced/reduced.
  out_mat = core.ManualAxisType(varying=vma)
  return aval.update(shape=new_shape, sharding=new_sharding,
                     manual_axis_type=out_mat)
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array

# Type-checking

def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_specs, out_specs,
                         check_vma, manual_axes):
  """Custom typecheck rule for `shard_map_p`; returns (out_avals, effects).

  Raises:
    core.JaxprTypeError: if an argument's sharded aval is incompatible with
      its jaxpr binder, or (with `check_vma`) an output fails
      `_valid_repeats` for its PartitionSpec.
  """
  # TODO(mattjj,parkers): check auto
  for v, x, in_spec in zip(jaxpr.invars, in_atoms, in_specs):
    sharded_aval = shard_aval(mesh, manual_axes, check_vma, in_spec, x.aval)
    if not core.typecompat(v.aval, sharded_aval):
      raise core.JaxprTypeError("shard_map argument avals not compatible with "
                                "jaxpr binder avals and in_specs")
  with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
    core.check_jaxpr(jaxpr)
  if check_vma:
    for v, os in zip(jaxpr.outvars, out_specs):
      if isinstance(os, P) and not _valid_repeats(mesh, v.aval.mat, os):
        raise core.JaxprTypeError(
            "shard_map can't prove output is sufficiently replicated")
  out_avals_sharded = [x.aval for x in jaxpr.outvars]
  out_avals = map(partial(unshard_aval, mesh, check_vma), out_specs,
                  out_avals_sharded)
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  return out_avals, effs
core.custom_typechecks[shard_map_p] = _shard_map_typecheck


def _valid_repeats(mesh: Mesh, mat: core.ManualAxisType, spec) -> bool:
  """Returns False if an axis from `_unmentioned(mesh, spec)` is in `mat.vur`.

  Axes already manual on `mesh` are excluded from the check.
  """
  um = set(_unmentioned(mesh, spec)) - set(mesh.manual_axes)
  vur = mat.vur
  if any(u in vur for u in um):
    return False
  return True

# Lowering

def _shardy_shard_map_sharding(
    ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec, aval_in
) -> sharding_impls.SdyArray:
  """Shardy sharding for one operand/result of the ManualComputationOp.

  Extended dtypes use their physical sharding and aval. When only some mesh
  axes are manual, every dimension sharding is marked open.
  """
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_in, ns)
    aval_in = core.physical_aval(aval_in)
  sdy_sharding = ns._to_sdy_sharding(aval_in.ndim)
  if len(manual_axes) < len(mesh.axis_names):
    for dim_sharding in sdy_sharding.dim_shardings:
      dim_sharding.is_open = True
  return sdy_sharding


def _get_token_sharding(
    ctx: mlir.LoweringRuleContext, mesh
  ) -> sharding_impls.SdyArray:
  """Rank-0, `P()` Shardy sharding used for token and dim-var operands."""
  ns = _make_scoped_manual_sharding(ctx, mesh, P())
  return ns._to_sdy_sharding(0)


def _get_spmdaxis_ctx_mesh(mesh):
  """Returns the current concrete mesh in place of an AbstractMesh, if one is set."""
  if isinstance(mesh, AbstractMesh):
    concrete_mesh = get_concrete_mesh()
    return concrete_mesh if not concrete_mesh.empty else mesh
  return mesh


def _shard_map_lowering_shardy(
    ctx: mlir.LoweringRuleContext, in_nodes, jaxpr: core.Jaxpr, mesh, in_specs,
    out_specs, manual_axes, check_vma):
  """Shardy lowering: wraps the body in an `sdy.ManualComputationOp`.

  If every newly-manual axis has size 1, the body is inlined via
  `mlir.jaxpr_subcomp` instead. Otherwise the op's block arguments are, in
  order: dim vars, tokens, hoisted jaxpr constants, then the real inputs.
  """
  axis_ctx = ctx.module_context.axis_context
  in_avals_ = [v.aval for v in jaxpr.invars]
  if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
    # Nested `ManualComputationOp`s must only refer to the new manual axes, not
    # all existing ones. Grab the newly-added manual axes.
    shardy_manual_axes = manual_axes - axis_ctx.manual_axes
  else:
    shardy_manual_axes = manual_axes
  new_axis_context = sharding_impls.SPMDAxisContext(
      _get_spmdaxis_ctx_mesh(mesh), manual_axes)
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)

  tokens = [ctx.tokens_in.get(eff) for eff in ctx.tokens_in.effects()]
  num_tokens = len(tokens)
  # From here on `manual_axes` is the mesh-ordered tuple of newly-manual axes.
  manual_axes = order_wrt_mesh(mesh, shardy_manual_axes)
  if prod([mesh.shape[a] for a in manual_axes]) == 1:
    # No need for a `ManualComputationOp` if all manual axes are size 1.
    with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
      out_nodes, tokens_out = mlir.jaxpr_subcomp(
          sub_ctx, jaxpr, ctx.name_stack,
          mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)),
          (), *in_nodes,
          dim_var_values=ctx.dim_var_values,
          const_lowering=ctx.const_lowering,
          outer_traceback=_jax.Traceback())
    ctx.set_tokens_out(tokens_out)
    return out_nodes

  in_shardings = list(
      map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
          in_specs, ctx.avals_in))
  const_args_and_avals = core.jaxpr_const_args(jaxpr)
  const_args, const_avals = util.unzip2(const_args_and_avals)
  num_const_args = len(const_args)
  const_arg_values = mlir.flatten_ir_values(
      mlir.ir_constants(c, const_lowering=ctx.const_lowering, aval=aval)
      for c, aval in const_args_and_avals
  )
  # TODO(necula,yashkatariya): how to construct consts shardy shardings from
  # consts that can be ndarray or jax.Array?
  const_args_shardings = [
      _shardy_shard_map_sharding(ctx, mesh, manual_axes, P(), core.typeof(c))
      for c in const_args]

  num_dim_vars = len(ctx.dim_var_values)
  # Operand order must match `args` below: tokens/dim vars, consts, inputs.
  in_shardings = (
      [_get_token_sharding(ctx, mesh)] * (num_tokens + num_dim_vars) +
      const_args_shardings + in_shardings)
  in_shardings = sharding_impls.SdyArrayList(in_shardings).build()

  out_shardings = list(
      map(partial(_shardy_shard_map_sharding, ctx, mesh, manual_axes),
          out_specs, ctx.avals_out))
  out_shardings = [
      _get_token_sharding(ctx, mesh)] * num_tokens + out_shardings
  out_shardings = sharding_impls.SdyArrayList(out_shardings).build()

  output_types = ([hlo.TokenType.get()] * num_tokens +
                  mlir.flatten_ir_types(map(mlir.aval_to_ir_types, ctx.avals_out)))

  args = (*ctx.dim_var_values, *tokens, *const_arg_values, *in_nodes)
  manual_computation_op = sdy.ManualComputationOp(
      output_types,
      mlir.flatten_ir_values(args),
      in_shardings, out_shardings,
      sdy.ManualAxesAttr.get([ir.StringAttr.get(i) for i in manual_axes]))

  dim_var_types = [
      mlir.aval_to_ir_type(core.ShapedArray((), dtypes.default_int_dtype()))
  ] * num_dim_vars
  token_types = [hlo.TokenType.get()] * num_tokens
  const_arg_types = mlir.flatten_ir_types(map(mlir.aval_to_ir_types, const_avals))
  in_types = mlir.flatten_ir_types(map(mlir.aval_to_ir_types, in_avals_))
  block = ir.Block.create_at_start(
      manual_computation_op.body,
      (*dim_var_types, *token_types, *const_arg_types, *in_types))
  with (ir.InsertionPoint(block), _extend_axis_env(mesh, manual_axes),
        config._check_vma(check_vma)):
    dim_var_values, token_arg_values, const_arg_values, in_args = util.split_list(
        block.arguments, [num_dim_vars, num_tokens, num_const_args])
    out_nodes_, tokens_out = mlir.jaxpr_subcomp(
        sub_ctx, jaxpr, ctx.name_stack,
        mlir.TokenSet(zip(ctx.tokens_in.effects(), token_arg_values)),
        (), *in_args,
        dim_var_values=dim_var_values,
        # Inside the body, constants resolve to the block's const arguments.
        const_lowering={
            (id(c), aval): ca
            for c, aval, ca in zip(const_args, const_avals, const_arg_values)
        },
        outer_traceback=_jax.Traceback())
    sdy.ReturnOp(
        mlir.flatten_ir_values(
            it.chain((v for _, v in tokens_out.items()), out_nodes_)
        )
    )
    num_tokens = len(tokens_out.effects())
    tokens_out = tokens_out.update_tokens(mlir.TokenSet(zip(
        ctx.tokens_in.effects(), manual_computation_op.results[:num_tokens])))
    ctx.set_tokens_out(tokens_out)

  return manual_computation_op.results[num_tokens:]


def _shard_map_lowering(ctx: mlir.LoweringRuleContext, *in_nodes,
                        jaxpr: core.Jaxpr, mesh, in_specs, out_specs,
                        check_vma, manual_axes):
  """MLIR lowering rule for `shard_map_p`.

  Uses the Shardy path when `config.use_shardy_partitioner` is set. Otherwise
  it converts inputs with `_xla_shard`, lowers the body as a `shmap_body`
  call, and converts outputs with `_xla_unshard`.
  """
  if config.use_shardy_partitioner.value:
    return _shard_map_lowering_shardy(
        ctx, in_nodes, jaxpr, mesh, in_specs, out_specs, manual_axes, check_vma)

  in_avals_ = [v.aval for v in jaxpr.invars]
  out_avals_ = [x.aval for x in jaxpr.outvars]
  in_nodes_ = map(partial(_xla_shard, ctx, mesh, manual_axes), in_specs,
                  ctx.avals_in, in_avals_, in_nodes)
  new_axis_context = sharding_impls.SPMDAxisContext(
      _get_spmdaxis_ctx_mesh(mesh), manual_axes)
  sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
  with _extend_axis_env(mesh, manual_axes), config._check_vma(check_vma):
    out_nodes_, tokens_out = mlir.call_lowering(
        "shmap_body", pe.close_jaxpr(jaxpr), None, sub_ctx, in_avals_,
        out_avals_, ctx.tokens_in, *in_nodes_,
        dim_var_values=ctx.dim_var_values,
        const_lowering=ctx.const_lowering,
        arg_names=map(_pspec_mhlo_attrs, in_specs, in_avals_),
        result_names=map(_pspec_mhlo_attrs, out_specs, out_avals_))
  ctx.set_tokens_out(tokens_out)
  return map(partial(_xla_unshard, ctx, mesh, manual_axes), out_specs,
             out_avals_, ctx.avals_out, out_nodes_)
mlir.register_lowering(shard_map_p, _shard_map_lowering)


def _make_scoped_manual_sharding(ctx, mesh, spec):
  """Builds a NamedSharding for `spec` on `mesh.abstract_mesh`.

  When the lowering context is an SPMDAxisContext, its manual axes are marked
  Manual on the mesh.
  """
  axis_ctx = ctx.module_context.axis_context
  mesh = mesh.abstract_mesh
  if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
    mesh = mesh.update_axis_types(
        {a: AxisType.Manual for a in axis_ctx.manual_axes})
  return NamedSharding(mesh, spec)


def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
               aval_in, aval_out, x):
  """Non-Shardy input conversion into the manual region.

  Annotates `x` with `spec`'s sharding, then wraps it in a full-to-shard op.
  Returns `x` unchanged when the manual axes' sizes multiply to 1.
  """
  if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
    return x
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_in, ns)
    aval_in = core.physical_aval(aval_in)
  shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
  # With only some axes manual, all dims are left unspecified for the others.
  unspecified = (set(range(aval_in.ndim))
                 if len(manual_axes) < len(mesh.axis_names) else set())
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,
                                  unspecified_dims=unspecified)
  manual_proto = pxla.manual_proto(
      aval_in, manual_axes | set(mesh.manual_axes), mesh)
  return mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto,
                                         unspecified)


def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, manual_axes, spec,
                 aval_in, aval_out, x):
  """Non-Shardy output conversion out of the manual region.

  This mirrors `_xla_shard`: it annotates `x` as manual, then wraps it in a
  shard-to-full op carrying `spec`'s sharding. Returns `x` unchanged when the
  manual axes' sizes multiply to 1.
  """
  if prod([size for n, size in mesh.shape.items() if n in manual_axes]) == 1:
    return x
  ns = _make_scoped_manual_sharding(ctx, mesh, spec)
  if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
    ns = sharding_impls.physical_sharding(aval_out, ns)
    aval_out = core.physical_aval(aval_out)
  unspecified = (set(range(aval_in.ndim))
                 if len(manual_axes) < len(mesh.axis_names) else set())
  if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
    aval_in = core.physical_aval(aval_in)
  manual_proto = pxla.manual_proto(
      aval_in, manual_axes | set(mesh.manual_axes), mesh)
  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto,
                                  unspecified_dims=unspecified)
  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
                                         unspecified)


def _pspec_mhlo_attrs(spec, aval: core.AbstractValue) -> str:
  """String of per-dimension axis-name tuples (or None) used as arg/result names.

  Returns '' for non-ShapedArray avals.
  """
  if isinstance(aval, core.ShapedArray):
    names = _spec_to_names(spec)
    return str(map(names.get, range(aval.ndim)))
  return ''

# Eager evaluation

def get_mesh_from_args(args_flat, mesh):
  """Returns the concrete Mesh on which to run eager shard_map.

  Any argument whose `sharding` is a NamedSharding on a non-scalar mesh
  replaces `mesh`, after its mesh shape is checked against `mesh`'s.

  Raises:
    ValueError: on a mesh-shape mismatch, or if `mesh` is still an
      AbstractMesh after scanning the arguments.
  """
  for a in args_flat:
    if (hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding) and
        not a.sharding.mesh.is_scalar):  # pyrefly: ignore[missing-attribute]
      if a.sharding.mesh.shape_tuple != mesh.shape_tuple:
        aval = core.shaped_abstractify(a)
        raise ValueError(
            f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not"
            " match the mesh shape passed to shard_map "
            f" {mesh.shape_tuple} for shape {aval.str_short()}")
      mesh = a.sharding.mesh
  if isinstance(mesh, AbstractMesh):
    raise ValueError(
        "Please pass `jax.Array`s with a `NamedSharding` as input to"
        " `shard_map` when passing `AbstractMesh` to the mesh argument.")
  assert isinstance(mesh, Mesh)
  return mesh


def _spec_to_vma(spec):
  """Frozenset of every mesh axis name mentioned in `spec`'s dimensions."""
  return frozenset(p for s in spec if s is not None
                   for p in (s if isinstance(s, tuple) else (s,)))


def _mat_to_spec(mesh, mat):
  """PartitionSpec sharding dim 0 over `mat`'s varying axes (in mesh order).

  The spec also carries `mat`'s unreduced/reduced axes.
  """
  return P(order_wrt_mesh(mesh, mat.varying), unreduced=mat.unreduced,
           reduced=mat.reduced)


def _spec_to_mat(spec) -> core.ManualAxisType:
  """ManualAxisType whose varying axes are those mentioned in `spec`."""
  return core.ManualAxisType(varying=_spec_to_vma(spec),
                             unreduced=spec.unreduced, reduced=spec.reduced)


def _shard_map_impl(trace, prim, fun, args, *, mesh, in_specs, check_vma,
                    manual_axes, debug_info):
  """Eager (EvalTrace) rule: runs `fun` per-shard under a ShardMapTrace.

  Each input is passed through `_unmatch_spec`, which adds a leading axis via
  `_add_singleton` inside an inner shard_map. The body runs via `_run_shmap`,
  and each output goes through `_match_spec` from its source spec (derived
  from the output `mat` when `check_vma`) to its out_spec.
  """
  del prim
  if isinstance(mesh, AbstractMesh):
    concrete_mesh = get_concrete_mesh()
    mesh = concrete_mesh if not concrete_mesh.empty else mesh
  mesh = get_mesh_from_args(args, mesh)
  cur_mesh = get_abstract_mesh()
  args_ = map(partial(_unmatch_spec, mesh, check_vma, cur_mesh, manual_axes),
              in_specs, args)
  in_mat = map(_spec_to_mat, in_specs)
  outs, out_specs, out_mat = _run_shmap(fun, mesh, manual_axes, args_, in_mat,
                                        check_vma)
  # Per-shard avals: map away the leading axis of each output.
  out_avals = outs.map(lambda x: core.mapped_aval(x.shape[0], 0, core.typeof(x)))
  _check_names(out_specs, out_avals)
  if check_vma:
    _check_mats(mesh, out_specs, out_avals)
    src_pspecs = tuple(_mat_to_spec(mesh, m) for m in out_mat)
  else:
    src_pspecs = tuple(P(order_wrt_mesh(mesh, manual_axes))
                       for _ in range(len(out_mat)))
  dst_pspecs = out_specs
  return outs.map3(partial(_match_spec, mesh, check_vma, manual_axes),
                   src_pspecs, dst_pspecs)
core.EvalTrace.process_shard_map = _shard_map_impl


def _run_shmap_lu(f, mesh, manual_axes, args, mats, check_vma):
  """Calls the wrapped `f` on ShardMapTracers; returns (vals, mats) of outputs."""
  assert not mesh.manual_axes
  trace = ShardMapTrace(mesh, manual_axes, check_vma)
  in_tracers = map(partial(ShardMapTracer, trace), mats, args)
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(inner_mesh), config._check_vma(check_vma)):
    ans = f.call_wrapped(*in_tracers)
    outs, out_mat = unzip2(map(trace.to_val_mat_pair, ans))
  return outs, out_mat


def _run_shmap(f, mesh, manual_axes, args, mats, check_vma):
  """Like `_run_shmap_lu`, for `f` returning a FlatTree with out_specs as aux.

  Returns (output vals, out_specs, list of output mats).
  """
  assert not mesh.manual_axes
  trace = ShardMapTrace(mesh, manual_axes, check_vma)
  in_tracers = map(partial(ShardMapTracer, trace), mats, args)
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  with (core.set_current_trace(trace), _extend_axis_env(mesh, manual_axes),
        use_abstract_mesh(inner_mesh), config._check_vma(check_vma)):
    ans, out_specs = f(*in_tracers).unpack_aux()
    outs, outs_mat = ans.map(trace.to_val_mat_pair).unzip2()
  return outs, out_specs, list(outs_mat)


def _unmatch_spec2(mesh, prev_manual, spec, x) -> JaxType:
  """Runs `_unmatch2` under jit with `mesh`'s abstract mesh as context."""
  with (core.eval_context(), api.disable_jit(False),
        use_abstract_mesh(mesh.abstract_mesh)):
    return api.jit(HashablePartial(_unmatch2, mesh, prev_manual, spec))(x)


def _unmatch2(mesh, prev_manual, spec, x):
  """Identity shard_map from `P(prev_manual, *spec)` to one sharded only on dim 0.

  Dim 0 is sharded over `prev_manual` plus the axes mentioned in `spec`.
  """
  src = P(order_wrt_mesh(mesh, prev_manual), *spec)
  newly_manual = _spec_to_vma(spec)
  dst = P(order_wrt_mesh(mesh, prev_manual | newly_manual))
  return shard_map(lambda x: x, in_specs=src, out_specs=dst,
                   axis_names=prev_manual | newly_manual)(x)


def _match_spec2(mesh, prev_manual, spec, x) -> JaxType:
  """Runs `_match2` under jit with `mesh`'s abstract mesh as context."""
  with (core.eval_context(), api.disable_jit(False),
        use_abstract_mesh(mesh.abstract_mesh)):
    return api.jit(HashablePartial(_match2, mesh, prev_manual, spec))(x)


def _match2(mesh, prev_manual, spec, x):
  """Reverse of `_unmatch2`: identity shard_map back to `P(prev_manual, *spec)`."""
  newly_manual = _spec_to_vma(spec)
  src = P(order_wrt_mesh(mesh, prev_manual | newly_manual))
  dst = P(order_wrt_mesh(mesh, prev_manual), *spec)
  return shard_map(lambda x: x, in_specs=src, out_specs=dst,
                   axis_names=prev_manual | newly_manual)(x)


def _unmatch_spec(mesh: Mesh, check_vma, context_mesh, manual_axes, in_spec,
                  x: JaxType) -> JaxType:
  """Runs `_unmatch` under jit with `context_mesh` as the abstract mesh."""
  with (core.eval_context(), api.disable_jit(False),
        use_abstract_mesh(context_mesh)):
    return api.jit(HashablePartial(_unmatch, mesh, check_vma, in_spec,
                                   manual_axes))(x)


def _unmatch(mesh, check_vma, in_spec, manual_axes, x):
  """Adds a leading axis to each `in_spec` shard of `x`.

  The result is returned with that axis sharded over the axes `in_spec` uses
  (or over all mesh axes, with vma checking off, when `check_vma` is False).
  """
  if check_vma:
    used_axes = _spec_to_vma(in_spec)
    dst = P(order_wrt_mesh(mesh, used_axes), unreduced=in_spec.unreduced,
            reduced=in_spec.reduced)
  else:
    dst = P(mesh.axis_names)
    check_vma = False
  return shard_map(_add_singleton, mesh=mesh, in_specs=(in_spec,),
                   out_specs=dst, check_vma=check_vma, axis_names=manual_axes)(x)


def _check_names(specs, avals: FlatTree) -> None:
  """Raises _SpecError if a PartitionSpec out_spec has more entries than its aval's rank."""
  fail = [a if isinstance(sp, P) and sp and len(sp) > a.ndim else no_fail
          for sp, a in zip(specs, avals)]
  if any(f is not no_fail for f in fail):
    raise _SpecError(fail, avals.tree)

class _SpecError(Exception):
  pass


def _check_mats(mesh, specs, avals):
  """Raises _RepError if a PartitionSpec out_spec fails `_valid_repeats` for its aval."""
  fail = [a.mat if isinstance(sp, P) and not _valid_repeats(mesh, a.mat, sp)
          else no_fail for sp, a in zip(specs, avals)]
  if any(f is not no_fail for f in fail):
    raise _RepError(fail, avals.tree)

class _RepError(Exception):
  pass


def _match_spec(mesh: Mesh, check_vma, manual_axes, x: JaxType,
                src_pspec: PartitionSpec, dst_pspec: PartitionSpec) -> JaxType:
  """Runs `_match` under jit.

  When every mesh axis is manual, the jit's out_shardings are also set to
  `NamedSharding(mesh, dst_pspec)`.
  """
  fn = HashablePartial(_match, mesh, check_vma, manual_axes, src_pspec,
                       dst_pspec)
  with core.eval_context(), api.disable_jit(False):
    if set(mesh.axis_names) == manual_axes:
      return api.jit(fn, out_shardings=NamedSharding(mesh, dst_pspec))(x)
    return api.jit(fn)(x)


def _match(mesh, check_vma, manual_axes, src_pspec, dst_pspec, x):
  """Removes the leading per-shard axis via a shard_map from `src_pspec` to `dst_pspec`."""
  return shard_map(_rem_singleton, mesh=mesh, in_specs=src_pspec,
                   out_specs=dst_pspec, check_vma=check_vma,
                   axis_names=manual_axes)(x)


def _rem_singleton(x):
  # Drops axis 0 (lax.squeeze).
  return lax.squeeze(x, [0])


def _add_singleton(x):
  # Inserts a new axis at position 0 (lax.expand_dims).
  return lax.expand_dims(x, [0])


def _maybe_check_special(outs):
  """Checks all addressable shards of `outs` when debug_nans/debug_infs is on.

  Raises:
    FloatingPointError: if `dispatch.check_special` reports an invalid value.
  """
  if not config.debug_nans.value and not config.debug_infs.value:
    return
  bufs = [s.data for leaf in tree_leaves(outs)
          for s in getattr(leaf, 'addressable_shards', [])]
  try:
    dispatch.check_special('shard_map', bufs)
  except api_util.InternalFloatingPointError as e:
    raise FloatingPointError(f'Invalid value ({e.ty}) encountered in sharded computation.') from None


class ShardMapTrace(core.Trace):
  """Eager trace used while evaluating a shard_map body outside of jit.

  Values are held by ShardMapTracers whose `val` has a leading axis over
  shards. Primitives are applied via an `eager_rules` entry when one exists.
  Otherwise they are applied by jitting an inner shard_map (`_prim_applier`).
  """
  __slots__ = ("mesh", "manual_axes", "check", "amesh")

  mesh: Mesh  # outer concrete or abstract mesh
  manual_axes: frozenset[AxisName]
  check: bool

  def __init__(self, mesh, manual_axes, check):
    super().__init__()
    self.mesh = mesh
    self.manual_axes = manual_axes
    self.check = check
    self.amesh = mesh.abstract_mesh

  def to_val_mat_pair(self, val):
    """Returns (value, ManualAxisType) for `val`.

    Non-tracer constants are lifted through `_unmatch_spec` with `P()` and get
    `core.empty_mat`.

    Raises:
      Exception: if `val` is a tracer from a different trace.
    """
    if isinstance(val, ShardMapTracer):
      return val.val, val.mat
    elif isinstance(val, Tracer):
      raise Exception(f"Shouldn't have any non-shard_map tracers: {val}")
    else:
      val_ = _unmatch_spec(self.mesh, self.check, self.amesh,
                           self.manual_axes, P(), val)
      return val_, core.empty_mat

  def process_primitive(self, prim, tracers, params, /):
    in_vals, in_mat = unzip2(map(self.to_val_mat_pair, tracers))
    if self.check:
      # Derive output mats (and hence specs) from the primitive's abstract eval.
      out_avals, _ = prim.abstract_eval(*(typeof(t) for t in tracers), **params)
      out_avals = tuple(out_avals) if type(out_avals) is list else out_avals
      out_mat = tree_map(lambda a: a.mat, out_avals)
      in_specs = tuple(map(partial(_mat_to_spec, self.mesh), in_mat))
      out_specs = tree_map(partial(_mat_to_spec, self.mesh), out_mat)
    else:
      out_mat = core.empty_mat
      in_specs = out_specs = P(order_wrt_mesh(self.mesh, self.manual_axes))

    eager_rule = eager_rules.get(prim)
    if eager_rule:
      out_vals = eager_rule(self.mesh, *in_vals, **params)
    else:
      f = HashablePartial(
          _prim_applier, prim, self.check, tuple(params.items()), self.mesh,
          self.manual_axes, in_specs, out_specs)
      # debug_nans/infs are disabled for the inner jit; the check is done
      # afterwards by `_maybe_check_special`.
      with (core.eval_context(), api.disable_jit(False), config.debug_nans(False),
            config.debug_infs(False), use_abstract_mesh(self.amesh),
            sharding_impls._internal_use_concrete_mesh(self.mesh)):
        out_vals = api.jit(f)(*in_vals)
      _maybe_check_special(out_vals)
    if prim.multiple_results:
      out_mat = (out_mat if isinstance(out_mat, (list, tuple))
                 else [out_mat] * len(out_vals))
      return map(partial(ShardMapTracer, self), out_mat, out_vals)
    return ShardMapTracer(self, out_mat, out_vals)

  def process_shard_map(self, prim, fun, args, mesh, in_specs,
                        check_vma, manual_axes, debug_info):
    """Eager nested shard_map: makes `manual_axes` manual on top of this trace's.

    Raises:
      Exception: if an explicitly passed Mesh or `check_vma` differs from the
        outer shard_map's.
      NotImplementedError: if any input has unreduced/reduced axes.
    """
    # Check consistency between outer and inner shmaps on explicitly passed
    # mesh and check_vma.
    if isinstance(mesh, Mesh):
      if mesh != self.mesh:
        raise Exception
    del mesh
    if check_vma != self.check:  # TODO(mattjj): add check in jit path
      raise Exception
    del check_vma

    in_vals, in_mats = unzip2(map(self.to_val_mat_pair, args))
    if any(m.unreduced or m.reduced for m in in_mats):
      raise NotImplementedError(
          "Eager shard_map + unreduced/reduced + partial manual is not"
          " implemented. Please wrap your shard_map in `jax.jit`.")

    trace = ShardMapTrace(self.mesh, manual_axes | self.manual_axes, self.check)
    in_vals_ = [_unmatch_spec2(self.mesh, self.manual_axes, spec, x)
                for x, spec in zip(in_vals, in_specs)]
    # TODO(yashkatariya): Handle unreduced/reduced correctly.
    in_mats_ = [core.ManualAxisType(varying=mat.varying | _spec_to_vma(s))
                for mat, s in zip(in_mats, in_specs)]
    in_tracers = map(partial(ShardMapTracer, trace), in_mats_, in_vals_)
    inner_mesh = _as_manual_mesh(self.mesh, manual_axes | self.manual_axes)
    with (core.set_current_trace(trace), _extend_axis_env(self.mesh, manual_axes),
          use_abstract_mesh(inner_mesh)):
      ans_aux = fun(*in_tracers)
      ans, out_specs = ans_aux.unpack_aux()
      out_vals_, out_mats_ = ans.map(trace.to_val_mat_pair).unzip2()
    out_vals = out_vals_.map2(
        lambda x, spec: _match_spec2(self.mesh, self.manual_axes, spec, x),
        out_specs)
    # TODO(yashkatariya): Handle unreduced/reduced correctly.
    out_mats = [core.ManualAxisType(varying=mat.varying - _spec_to_vma(spec))
                for mat, spec in zip(out_mats_, out_specs)]
    return out_vals.map2(lambda val, vma: ShardMapTracer(self, vma, val),
                         out_mats)

  def process_call(self, call_primitive, fun, tracers, params, /):
    raise NotImplementedError(
        f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't "
        "yet supported. Put a `jax.jit` around the `shard_map`-decorated "
        "function, and open a feature request at "
        "https://github.com/jax-ml/jax/issues !")

  def process_custom_jvp_call(self, prim, fun, jvp, tracers, /, *,
                              symbolic_zeros):
    # Since ShardMapTrace is only used as a base main, we can drop the jvp.
    del prim, jvp, symbolic_zeros
    in_vals, in_mat = unzip2(map(self.to_val_mat_pair, tracers))
    out_vals, out_mat = _run_shmap_lu(fun, self.mesh, self.manual_axes,
                                      in_vals, in_mat, self.check)
    return map(partial(ShardMapTracer, self), out_mat, out_vals)

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, /, *,
                              out_trees, symbolic_zeros):
    # Like the jvp case, only the primal `fun` is run; fwd/bwd are dropped.
    if symbolic_zeros:
      msg = ("custom_vjp symbolic_zeros support with shard_map is not "
             "implemented; please open an issue at "
             "https://github.com/jax-ml/jax/issues")
      raise NotImplementedError(msg)
    del prim, fwd, bwd, out_trees, symbolic_zeros
    in_vals, in_mat = unzip2(map(self.to_val_mat_pair, tracers))
    out_vals, out_mat = _run_shmap_lu(fun, self.mesh, self.manual_axes,
                                      in_vals, in_mat, self.check)
    return map(partial(ShardMapTracer, self), out_mat, out_vals)


class ShardMapTracer(core.Tracer[ShardMapTrace]):
  """Tracer for ShardMapTrace.

  `val` carries a leading axis whose size is the product of the sizes of the
  `mat.varying` mesh axes. The tracer's aval is `val`'s aval with that axis
  mapped out, placed on the manual mesh. With the trace's check off, all of
  `trace.manual_axes` are treated as varying.
  """
  mat: core.ManualAxisType
  val: JaxType

  def __init__(self, trace, mat, val):
    assert isinstance(mat, core.ManualAxisType)
    aval = core.typeof(val)
    mat = (mat if trace.check else
           core.ManualAxisType(varying=trace.manual_axes))
    size = prod(trace.mesh.shape[n] for n in mat.varying)
    out = core.mapped_aval(size, 0, aval)
    manual_mesh = _as_manual_mesh(trace.amesh, trace.manual_axes)
    spec = core.modify_spec_for_auto_manual(out.sharding.spec, manual_mesh)  # pyrefly: ignore[missing-attribute]
    new_sharding = NamedSharding(manual_mesh, spec)
    mat_out = mat if trace.check else core.empty_mat
    computed_aval = out.update(sharding=new_sharding, manual_axis_type=mat_out)
    super().__init__(trace, computed_aval)
    self.mat = mat
    self.val = val

  def to_concrete_value(self):
    """Returns shard 0's concrete value if `check` is on and `mat.vur` is empty, else None."""
    if self._trace.check and self.mat.vur == frozenset():
      with (core.eval_context(), use_abstract_mesh(self._trace.amesh),
            sharding_impls._internal_use_concrete_mesh(self._trace.mesh)):
        return core.to_concrete_value(self.val[0])
    else:
      return None

  def __str__(self) -> str:
    """Renders one block per device, labeled with its mesh coordinates."""
    # pvary over axes not already in `mat.vur` so every device has a block.
    pb_names = set(self._trace.mesh.axis_names) - self.mat.vur
    self = pvary(self, tuple(pb_names))
    with (core.eval_context(), use_abstract_mesh(self._trace.amesh),
          sharding_impls._internal_use_concrete_mesh(self._trace.mesh)):
      blocks = list(self.val)
    mesh = self._trace.mesh
    axis_names = f"({', '.join(map(str, mesh.axis_names))},)"
    return '\n'.join(
        f"On {device} at mesh coordinates {axis_names} = {idx}:\n{block}\n"
        for (idx, device), block in zip(np.ndenumerate(mesh.devices), blocks))

  __repr__ = __str__  # for debuggers, like `p x`


def _prim_applier(prim, check_vma, params_tup, concrete_mesh, manual_axes,
                  in_specs, out_specs, *args):
  """Binds `prim` per-shard inside a shard_map over `concrete_mesh`.

  The leading per-shard axis is removed before `prim.bind` and added back to
  every output.
  """
  def apply(*args):
    outs = prim.bind(*map(_rem_singleton, args), **dict(params_tup))
    return tree_map(_add_singleton, outs)
  out_specs = list(out_specs) if type(out_specs) is tuple else out_specs
  return shard_map(apply, mesh=concrete_mesh, in_specs=in_specs,
                   out_specs=out_specs, check_vma=check_vma,
                   axis_names=manual_axes)(*args)

# Primitive -> eager rule used by ShardMapTrace.process_primitive instead of
# jitting the primitive.
eager_rules: dict[core.Primitive, Callable] = {}


def _device_put_eager_rule(mesh, *xs, srcs, devices, copy_semantics):
  """Eager device_put: returns `xs` unchanged; explicit devices are rejected."""
  del mesh, srcs, copy_semantics
  for device in devices:
    if device is not None:
      raise ValueError("device_put with explicit device not allowed within "
                       f"shard_map-decorated functions, but got device {device}")
  return xs
eager_rules[dispatch.device_put_p] = _device_put_eager_rule


def _ref_raise_valueerror(*args, **kwargs):
  """Eager rule for ref-creating primitives: always raises ValueError."""
  raise ValueError(
      "Eager shard_map cannot return a `jax.Ref`. Please wrap"
      " your shard_map in `jax.jit`.")
eager_rules[core.ref_p] = _ref_raise_valueerror
eager_rules[core.empty_ref_p] = _ref_raise_valueerror

# Batching

def used_axis_names(spec):
  """Returns `_spec_to_mat(spec).vur` for `spec`."""
  return _spec_to_mat(spec).vur


def _shard_map_batch(
    trace: batching.BatchTrace, prim: core.Primitive, fun: Callable,
    in_tracers: Sequence[batching.BatchTracer], mesh: Mesh,
    in_specs, check_vma: bool, manual_axes: frozenset, debug_info
    ) -> Sequence[batching.BatchTracer]:
  """vmap rule: inserts the batch dimension into in/out specs and rebinds `prim`.

  With `spmd_axis_name`, the batch dimension is sharded over those axes and
  the per-shard axis size shrinks accordingly. With an explicit mesh axis,
  that axis must not appear in the in_specs.

  Raises:
    ValueError: if the vmapped spmd/explicit mesh axes are also used in
      `in_specs` (the spmd check can be disabled via
      `config.disable_vmap_shmap_error`).
  """
  in_vals, in_dims = unzip2(map(trace.to_batch_info, in_tracers))
  spmd_axis_name = trace.axis_data.spmd_name
  explicit_mesh_axis = trace.axis_data.explicit_mesh_axis
  if spmd_axis_name is not None:
    used = {n for spec in in_specs for n in used_axis_names(spec)}
    if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used:
      raise ValueError("vmap spmd_axis_name cannot appear in shard_map in_specs")
    new_in_specs = [
        sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_axis_name)
        for sp, d in zip(in_specs, in_dims)]
    new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name)
    new_axis_data = batching.AxisData(
        trace.axis_data.name, new_size, trace.axis_data.spmd_name,
        trace.axis_data.explicit_mesh_axis)
  elif explicit_mesh_axis is not None:
    used = {n for spec in in_specs for n in used_axis_names(spec)}
    if set(explicit_mesh_axis) & used:
      raise ValueError("vmapped away explicit mesh axis cannot appear in "
                       "shard_map in_specs")
    new_in_specs = [
        sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
        for sp, d in zip(in_specs, in_dims)]
    new_axis_data = trace.axis_data
  else:
    new_in_specs = [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
                    for sp, d in zip(in_specs, in_dims)]
    new_axis_data = trace.axis_data

  def fun_batched(*args):
    # Batch the body; out_dims and the batched out_specs ride along as aux.
    ans_aux, out_dims = batching.batch_subtrace_2(
        fun, trace.tag, new_axis_data, tuple(in_dims), args)
    ans, out_specs = ans_aux.unpack_aux()
    new_out_specs = _batch_out_specs(spmd_axis_name, explicit_mesh_axis,
                                     out_dims, out_specs)
    return ans.with_aux(out_dims).with_aux(tuple(new_out_specs))

  new_params = dict(mesh=mesh, in_specs=new_in_specs, check_vma=check_vma,
                    manual_axes=manual_axes, debug_info=debug_info)
  # TODO(yashkatariya): Remove remove_explicit_mesh_axis_names when vmap
  # mesh ctx is correctly set.
  with (core.set_current_trace(trace.parent_trace),
        core.remove_explicit_mesh_axis_names(trace.axis_data.explicit_mesh_axis)):
    out_vals = prim.bind(*in_vals, subfuns=(fun_batched,), **new_params)
  make_tracer = partial(batching.BatchTracer, trace,
                        source_info=source_info_util.current())
  out_vals, out_dims = out_vals.unpack_aux()
  return out_vals.map2(make_tracer, out_dims)
batching.BatchTrace.process_shard_map = _shard_map_batch


def _batch_out_specs(spmd_name, explicit_mesh_axis, dims, out_specs):
  """Out-spec counterpart of the in-spec batching logic in `_shard_map_batch`.

  Raises:
    ValueError: if the vmapped spmd/explicit mesh axes are used in
      `out_specs`.
  """
  if spmd_name is not None:
    used = {n for spec in out_specs for n in used_axis_names(spec)}
    if not config.disable_vmap_shmap_error.value and set(spmd_name) & used:
      raise ValueError("vmap spmd_axis_name cannot appear in shard_map out_specs")
    return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, spmd_name)
            for sp, d in zip(out_specs, dims)]
  elif explicit_mesh_axis is not None:
    used = {n for spec in out_specs for n in used_axis_names(spec)}
    if set(explicit_mesh_axis) & used:
      raise ValueError("vmapped away explicit mesh axis cannot appear in "
                       "shard_map out_specs")
    return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
            for sp, d in zip(out_specs, dims)]
  else:
    return [sp if d is batching.not_mapped else pxla.batch_spec(sp, d, None)
            for sp, d in zip(out_specs, dims)]

# Autodiff

def _shard_map_jvp(trace, shard_map_p, f, tracers, mesh, in_specs,
                   check_vma, manual_axes, debug_info):
  """JVP rule: binds one shard_map over primals plus the nonzero tangents.

  Symbolic-zero tangents are neither passed in nor returned. Outputs with a
  zero tangent are returned as plain primal values.
  """
  debug_info = debug_info.with_unknown_names()
  primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  which_nz = [type(t) is not ad.Zero for t in tangents]
  # Zero tangents become None so tree_flatten drops them from the operands.
  tangents = [t if type(t) is not ad.Zero else None for t in tangents]
  args, in_zeros_tree = tree_flatten((primals, tangents))
  tangent_in_specs = [sp.to_tangent_spec() for sp, nz in zip(in_specs, which_nz)
                      if nz]

  def f_jvp(*primals_and_nz_tangents_flat):
    primals, tangents = tree_unflatten(in_zeros_tree, primals_and_nz_tangents_flat)
    tangents = [ad.p2tz(p) if t is None else t for p, t in zip(primals, tangents)]
    primals_out_aux, tangents_out = ad.jvp_subtrace_2(f, trace.tag, primals, tangents)
    primals_out_ft, out_ax = primals_out_aux.unpack_aux()
    which_nz_out = [type(r) is not ad.Zero for r in tangents_out]
    tangent_out_specs = [s.to_tangent_spec() for s, nz in zip(out_ax, which_nz_out)
                         if nz]
    new_out_specs = (*out_ax, *tangent_out_specs)
    tangents_out = [None if not nz else t
                    for t, nz in zip(tangents_out, which_nz_out)]
    tangents_out_ft = FlatTree.flatten(list(tangents_out))
    out_primals_tangents = FlatTree.pack((primals_out_ft, tangents_out_ft))
    return out_primals_tangents.with_aux(which_nz_out).with_aux(new_out_specs)

  params = dict(mesh=mesh, in_specs=(*in_specs, *tangent_in_specs),
                check_vma=check_vma, manual_axes=manual_axes,
                debug_info=debug_info.with_unknown_names())
  avals = [typeof(x) for x in args]
  result = shard_map_p.bind_with_trace(
      trace.parent_trace, tuple(args), avals, dict(params, subfuns=(f_jvp,)))
  pt_out, which_nz_out = result.unpack_aux()
  primal_out, nz_tangents_out = pt_out.unpack()
  # Nonzero tangents are consumed in order by popping from the reversed list.
  tangents_stack = list(nz_tangents_out)[::-1]
  make_tracer = lambda p, nz: ad.JVPTracer(trace, p, tangents_stack.pop()) if nz else p
  tracers_out = primal_out.map2(make_tracer, which_nz_out)
  assert not tangents_stack
  return tracers_out
ad.JVPTrace.process_shard_map = _shard_map_jvp


def _shard_map_partial_eval(trace: pe.JaxprTrace, shard_map_p,
                            f: Callable, tracers, mesh, in_specs,
                            check_vma, manual_axes, debug_info):
  """Partial-eval rule for shard_map.

  The known part is bound as a shard_map on the parent trace, returning known
  outputs plus residuals. The unknown part is recorded as a new `shard_map_p`
  equation recipe. That equation consumes residuals, env values and the
  unknown inputs.
  """
  tracers = map(trace.to_jaxpr_tracer, tracers)
  in_pvals = [t.pval for t in tracers]
  in_knowns, in_avals, in_consts = pe.partition_pvals(in_pvals)
  unk_in_specs, known_in_specs = pe.partition_list(in_knowns, in_specs)
  in_avals_sharded = map(partial(shard_aval, mesh, manual_axes, check_vma),
                         unk_in_specs, in_avals)
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)

  def f_pe(*in_consts):
    # Re-interleave known consts and unknown (sharded) avals in input order.
    in_avals_, in_consts_ = iter(in_avals_sharded), iter(in_consts)
    in_pvals = [pe.PartialVal.known(next(in_consts_)) if known else
                pe.PartialVal.unknown(next(in_avals_)) for known in in_knowns]
    sentinel = object()
    assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel
    jaxpr, fwd_data = pe.trace_to_subjaxpr_nounits_fwd2(
        f, trace.tag, debug_info.with_unknown_names(), False, in_pvals)
    (in_fwds, out_fwds, out_pvals, res, env) = fwd_data
    # Non-forwarded scalar residuals are promoted to shape (1,); presumably so
    # they can carry a sharded residual spec — TODO confirm.
    which = [f1 is None and f2 is None and not v.aval.shape
             for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)]
    jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which)
    res = [lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x
           for x in res]
    out_pvals, out_specs = out_pvals.unpack_aux()
    out_knowns, _, out_consts = pe.partition_pvals(out_pvals)
    res_avals = [typeof(r) for r in res]
    _, out_known_specs = pe.partition_list(out_knowns, out_specs)
    res_specs = [a.nospec(mesh, check_vma, all_names) for a in res_avals]
    new_out_specs = (*out_known_specs, *res_specs)
    ft = out_pvals.map(lambda _: None)
    ans_ft = FlatTree.flatten((out_consts, res))
    aux = (in_fwds, out_fwds, out_knowns, res_avals, jaxpr, env, out_specs,
           new_out_specs, ft)
    return ans_ft.with_aux(aux).with_aux(new_out_specs)

  known_params = dict(mesh=mesh, in_specs=(*known_in_specs,),
                      check_vma=check_vma, manual_axes=manual_axes,
                      debug_info=debug_info.with_unknown_names())
  avals = [typeof(x) for x in in_consts]
  out = shard_map_p.bind_with_trace(trace.parent_trace, tuple(in_consts), avals,
                                    dict(known_params, subfuns=(f_pe,)))
  outs, (in_fwd, out_fwd, out_knowns, res_avals, jaxpr, env, out_specs,
         new_out_specs, ft) = out.unpack_aux()
  out_consts, non_fwd_res = outs.unflatten()
  assert not jaxpr.constvars
  unk_out_specs, _ = pe.partition_list(out_knowns, out_specs)
  # Residuals: forwarded inputs/outputs are substituted back in.
  res = subs_list2(in_fwd, out_fwd, in_consts, out_consts, non_fwd_res)
  # TODO make res_avals be the full set, not just the non-fwd ones
  res_avals_iter = iter(res_avals)
  res_specs = []
  for f1, f2 in zip(in_fwd, out_fwd):
    if f1 is not None:
      res_specs.append(known_in_specs[f1])
    elif f2 is not None:
      res_specs.append(new_out_specs[f2])
    else:
      raval = next(res_avals_iter)
      res_specs.append(raval.nospec(mesh, check_vma, all_names))
  env_specs = [_repspec(typeof(e)) for e in env]
  unk_in_specs = (*res_specs, *env_specs, *unk_in_specs)
  const_tracers = map(trace.new_instantiated_const, res)
  env_tracers = map(trace.to_jaxpr_tracer, env)
  unk_arg_tracers = [t for t in tracers if not t.is_known()]
  out_avals_sharded = [v.aval for v in jaxpr.outvars]
  unk_params = dict(mesh=mesh, in_specs=unk_in_specs,
                    out_specs=tuple(unk_out_specs),
                    jaxpr=jaxpr.replace(debug_info=jaxpr.debug_info.with_unknown_names()),
                    check_vma=check_vma, manual_axes=manual_axes)
  out_avals = map(partial(unshard_aval, mesh, check_vma), unk_out_specs,
                  out_avals_sharded)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                 for a in out_avals]
  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
  eqn = pe.new_eqn_recipe(trace, (*const_tracers, *env_tracers, *unk_arg_tracers),
                          out_tracers, shard_map_p, unk_params,
                          effs, source_info_util.current())
  for t in out_tracers:
    t.recipe = eqn
  results = merge_lists(out_knowns, out_tracers, out_consts)
  return ft.update(results)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval


# NOTE: `_shard_map_linearize` continues past the end of this chunk; only its
# opening statements are shown here.
def _shard_map_linearize(trace, shard_map_p, f: Callable,
                         tracers, mesh, in_specs, check_vma, manual_axes,
                         debug_info):
  debug_info = debug_info.with_unknown_names()
  primals, tangents = unzip2(map(trace.to_primal_tangent_pair, tracers))
  nzs_in = tuple(type(t) is not ad.Zero for t in tangents)
  all_names = _all_newly_manual_mesh_names(mesh, manual_axes)

  def f_lin(*primals):
    res, ans_aux, lin_data = ad.linearize_subtrace_2(
        f, trace.is_vjp, trace.tag, nzs_in, debug_info, primals)
    primals_out, out_specs = ans_aux.unpack_aux()
    res_avals, _, _, _, in_fwd,
out_fwd = lin_data res_avals = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) if f1 is None and f2 is None] res_specs = [a.nospec(mesh, check_vma, all_names) for a in res_avals] new_out_specs = (*res_specs, *out_specs) res = [lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x for x in res] res_and_primal = FlatTree.pack((FlatTree.flatten(res), primals_out)) return res_and_primal.with_aux((lin_data, out_specs)).with_aux(new_out_specs) fwd_params = dict( mesh=mesh, in_specs=in_specs, check_vma=check_vma, manual_axes=manual_axes, debug_info=debug_info) avals = [typeof(x) for x in primals] all_results_aux = shard_map_p.bind_with_trace( trace.parent_trace, tuple(primals), avals, dict(fwd_params, subfuns=(f_lin,))) all_results, (lin_data, out_specs) = all_results_aux.unpack_aux() res_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = lin_data non_fwd_res, primals_out = all_results.unpack() residuals = subs_list2(in_fwd, out_fwd, primals, (*primals_out,), non_fwd_res) args_to_promote = [getattr(aval, 'shape', ()) == () and f1 is None and f2 is None for aval, f1, f2 in zip(res_avals, in_fwd, out_fwd)] with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): lin_jaxpr = _promote_scalar_residuals_jaxpr(lin_jaxpr, args_to_promote) res_avals2 = [r for r, f1, f2 in zip(res_avals, in_fwd, out_fwd) if f1 is None and f2 is None] res_avals_iter = iter(res_avals2) res_specs = [in_specs[f1] if f1 is not None else out_specs[f2] if f2 is not None else next(res_avals_iter).nospec(mesh, check_vma, all_names) for f1, f2 in zip(in_fwd, out_fwd)] assert next(res_avals_iter, None) is None env_specs = [_repspec(typeof(e)) for e in env] new_in_specs = (*res_specs, *env_specs, *(s.to_tangent_spec() for s, nz in zip(in_specs, nzs_in) if nz)) tangent_out_specs = tuple(s.to_tangent_spec() for s, nz in zip(out_specs, nzs_out) if nz) tangent_params = dict( mesh=mesh, in_specs=new_in_specs, check_vma=check_vma, 
manual_axes=manual_axes, debug_info=lin_jaxpr.debug_info) # TODO(mattjj): avoid round-tripping the jaxpr through eval_jaxpr here def f_tangent(*args): ans = core.eval_jaxpr(lin_jaxpr, (), *args) return FlatTree.flatten(ans).with_aux(tangent_out_specs) nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz] args = (*residuals, *env, *nz_tangents_in) avals = [typeof(x) for x in args] nz_tangents_out = shard_map_p.bind_with_trace( trace.tangent_trace, args, avals, dict(tangent_params, subfuns=(f_tangent,))) nz_tangents_out_iter = iter(nz_tangents_out) tangents_out = [next(nz_tangents_out_iter) if nz else ad.p2tz(primal) for nz, primal in zip(nzs_out, primals_out)] return primals_out.map3(partial(ad.maybe_linearize_tracer, trace), nzs_out, tangents_out) ad.LinearizeTrace.process_shard_map = _shard_map_linearize def _promote_scalar_residuals_jaxpr(jaxpr: core.Jaxpr, which: Sequence[bool]): def fun(*res_and_args): res, args = split_list(res_and_args, [len(jaxpr.constvars)]) res = [_rem_singleton(x) if w else x for x, w in zip(res, which)] return core.eval_jaxpr(jaxpr, res, *args) res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval for v, w in zip(jaxpr.constvars, which)] in_avals = FlatTree.flatten(((*res_avals, *[v.aval for v in jaxpr.invars]), {})) closed_jaxpr, _ = pe.trace_to_jaxpr(fun, in_avals, debug_info=jaxpr.debug_info) closed_jaxpr, _ = pe.separate_consts(closed_jaxpr) return closed_jaxpr.jaxpr def _unmentioned2(mesh: Mesh, spec, manual_axes: frozenset[AxisName] ) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over # more chips than required in the transpose-no-check-vma case. 
name_set = _spec_to_vma(spec) | spec.unreduced return [n for n in _all_mesh_names_except_spmd(mesh, manual_axes) if n not in name_set] def _shard_map_transpose(out_cts, *args, jaxpr: core.Jaxpr, mesh, in_specs, out_specs, check_vma, manual_axes): mb_div = lambda x, y: x / y if y != 1 else x out_cts = [ ad.Zero(shard_aval(mesh, manual_axes, check_vma, sp, x.aval)) if type(x) is ad.Zero else x if check_vma or dtypes.dtype(x) == dtypes.float0 else mb_div(x, prod(map(mesh.shape.get, _unmentioned2(mesh, sp, manual_axes)))) for sp, x in zip(out_specs, out_cts) ] args = [x if type(x) is not ad.UndefinedPrimal else ad.UndefinedPrimal(shard_aval(mesh, manual_axes, check_vma, sp, x.aval)) for sp, x in zip(in_specs, args)] all_args, in_tree = tree_flatten((out_cts, tuple(args))) def fun_trans_callable(*right_flat): right_cts, primals_or_undefs = tree_unflatten(in_tree, right_flat) left_cts = ad.backward_pass(jaxpr, False, (), primals_or_undefs, right_cts) left_cts = [x if type(x) is ad.Zero or check_vma else lax_parallel.psum(x, tuple(_unmentioned2(mesh, sp, manual_axes))) for sp, x in zip(in_specs, left_cts)] left_specs_nz = tuple( s.to_ct_spec() for ct, s in zip(left_cts, in_specs) if ct is not None and type(ct) is not ad.Zero) return FlatTree.flatten(left_cts).with_aux(left_specs_nz) dbg = jaxpr.debug_info.with_unknown_names() new_in_specs = ( [s.to_ct_spec() for s, x in zip(out_specs, out_cts) if type(x) is not ad.Zero] + [s for s, x in zip(in_specs, args) if type(x) is not ad.UndefinedPrimal]) left_ct = shard_map_p.bind( *all_args, subfuns=(fun_trans_callable,), mesh=mesh, in_specs=tuple(new_in_specs), check_vma=check_vma, manual_axes=manual_axes, debug_info=dbg) left_cts = left_ct.unflatten() return [ad.Zero(unshard_aval(mesh, check_vma, sp.to_ct_spec(), x.aval)) if type(x) is ad.Zero else x for sp, x in zip(in_specs, left_cts)] ad.primitive_transposes[shard_map_p] = _shard_map_transpose # Remat def _partial_eval_jaxpr_custom_rule( saveable: Callable[..., 
pe.RematCases_], unks_in: Sequence[bool], inst_in: Sequence[bool], eqn: core.JaxprEqn ) -> tuple[core.JaxprEqn, core.JaxprEqn, Sequence[bool], Sequence[bool], list[core.Var]]: jaxpr, mesh = eqn.params['jaxpr'], eqn.params['mesh'] check_vma, manual_axes = eqn.params['check_vma'], eqn.params['manual_axes'] with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma), use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))): jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) num_out_primals = len(jaxpr_known.outvars) - num_res in_fwd = pe._jaxpr_forwarding(jaxpr_known)[num_out_primals:] out_binders_known, _ = partition_list(unks_out, eqn.outvars) out_vars, res_vars = split_list(jaxpr_known.outvars, [num_out_primals]) idx_map = {id(v): i for i, (v, b) in enumerate(zip(out_vars, out_binders_known)) if not isinstance(b, core.DropVar)} out_fwd = [idx_map.get(id(v)) for v in res_vars] which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)] mesh = eqn.params['mesh'] with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(_as_manual_mesh(mesh, manual_axes)), config._check_vma(check_vma)): jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which) jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged) jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names) jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names) ins_known, _ = partition_list(unks_in, eqn.invars) _, ins_staged = partition_list(inst_in, eqn.invars) _, out_binders_staged = partition_list(inst_out, eqn.outvars) nv = core.gensym() all_names = _all_newly_manual_mesh_names(mesh, manual_axes) lns = lambda a: a.nospec(mesh, check_vma, all_names) residuals, staged_in_res_specs = unzip2( [(nv(unshard_aval(mesh, check_vma, (rn := lns(var.aval)), var.aval)), rn) for var, w in zip(jaxpr_staged.invars[:num_res], which) if w]) 
out_res_specs_known = [var.aval.nospec(mesh, check_vma, all_names) # pyrefly: ignore[missing-attribute] for var, w in zip(res_vars, which) if w] params_known, params_staged = _pe_custom_params( unks_in, inst_in, map(op.not_, unks_out), inst_out, in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs, dict(eqn.params, jaxpr=jaxpr_known), dict(eqn.params, jaxpr=jaxpr_staged)) eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, eqn.source_info, eqn.ctx) full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals) eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged, eqn.primitive, params_staged, jaxpr_staged.effects, eqn.source_info, eqn.ctx) assert len(eqn_staged.invars) == len(jaxpr_staged.invars) new_inst = [x for x, inst in zip(eqn.invars, inst_in) if type(x) is core.Var and not inst] new_inst += [out_binders_known[f] for f in {i for i in out_fwd if i is not None}] return eqn_known, eqn_staged, unks_out, inst_out, new_inst + list(residuals) pe.partial_eval_jaxpr_custom_rules[shard_map_p] = \ _partial_eval_jaxpr_custom_rule def _add_reshapes(which: Sequence[bool], jaxpr_known: core.Jaxpr, jaxpr_staged: core.Jaxpr) -> tuple[core.Jaxpr, core.Jaxpr]: # add singleton axes to residuals which are from jaxpr_known and are scalars which_ = [w and not v.aval.shape # pyrefly: ignore[missing-attribute] for w, v in zip(which, jaxpr_staged.invars[:len(which)])] if not any(which_): return jaxpr_known, jaxpr_staged assert not jaxpr_known.constvars and not jaxpr_staged.constvars def known(*args): out = core.eval_jaxpr(jaxpr_known, (), *args) out_known, res = split_list(out, [len(out) - sum(which)]) res = [_add_singleton(x) if not x.shape else x for x in res] return [*out_known, *res] avals_in = tuple(v.aval for v in jaxpr_known.invars) avals_in = FlatTree.flatten((avals_in, {})) jaxpr_known_closed, _ = pe.trace_to_jaxpr( known, avals_in, 
debug_info=jaxpr_known.debug_info) def staged(*args): res_, ins = split_list(args, [len(which)]) res = [_rem_singleton(x) if w else x for x, w in zip(res_, which_)] return core.eval_jaxpr(jaxpr_staged, (), *res, *ins) res_avals = [core.unmapped_aval(1, 0, v.aval) if w else v.aval for w, v in zip(which_, jaxpr_staged.invars[:len(which)])] avals_in = (*res_avals, *[v.aval for v in jaxpr_staged.invars[len(which):]]) avals_in = FlatTree.flatten((avals_in, {})) jaxpr_staged_closed, _ = pe.trace_to_jaxpr( staged, avals_in, debug_info=jaxpr_staged.debug_info) return jaxpr_known_closed.jaxpr, jaxpr_staged_closed.jaxpr def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, in_fwd, out_fwd, out_res_specs_known, staged_in_res_specs, params_known, params_staged): # prune inputs to jaxpr_known according to unks_in in_specs_known, _ = partition_list(unks_in, params_known['in_specs']) _, out_specs_known = partition_list(kept_outs_known, params_known['out_specs']) out_specs_known = out_specs_known + out_res_specs_known assert len(out_specs_known) == len(params_known['jaxpr'].outvars) new_params_known = dict(params_known, in_specs=tuple(in_specs_known), out_specs=tuple(out_specs_known)) # added num_res new inputs to jaxpr_staged, pruning according to inst_in _, in_specs_staged = partition_list(inst_in, params_staged['in_specs']) iter_staged = iter(staged_in_res_specs) res_specs = [in_specs_known[f1] if f1 is not None else out_specs_known[f2] if f2 is not None else next(iter_staged) for f1, f2 in zip(in_fwd, out_fwd)] in_specs_staged = res_specs + in_specs_staged _, out_specs_staged = partition_list(kept_outs_staged, params_staged['out_specs']) new_params_staged = dict(params_staged, in_specs=tuple(in_specs_staged), out_specs=tuple(out_specs_staged)) return new_params_known, new_params_staged # TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_mesh_names_except_spmd( mesh: Mesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: 
axis_env = core.get_axis_env() spmd_names = axis_env.spmd_axis_names return tuple(name for name in mesh.axis_names if name not in spmd_names and name in manual_axes) def _all_newly_manual_mesh_names( mesh: BaseMesh, manual_axes: frozenset[AxisName]) -> tuple[AxisName, ...]: axis_env = core.get_axis_env() vmap_spmd_names = set(axis_env.spmd_axis_names) if not (ctx_mesh := get_abstract_mesh()).empty: mesh = ctx_mesh already_manual_names = set(ctx_mesh.manual_axes) else: # TODO(mattjj): remove this mechanism when we revise mesh scopes already_manual_names = set(axis_env.axis_sizes) # may include vmap axis_names return tuple(name for name in mesh.axis_names if (name not in vmap_spmd_names | already_manual_names and name in manual_axes)) # DCE # TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: if not any(used_outputs) and not pe.has_effects(eqn): return [False] * len(eqn.invars), None mesh = eqn.params["mesh"] manual_axes = eqn.params["manual_axes"] check_vma = eqn.params["check_vma"] with (_extend_axis_env(mesh, manual_axes), config._check_vma(check_vma), use_abstract_mesh(_as_manual_mesh(mesh, manual_axes))): jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects: return used_inputs, None else: _, in_specs = partition_list(used_inputs, eqn.params['in_specs']) _, out_specs = partition_list(used_outputs, eqn.params['out_specs']) new_params = dict(eqn.params, jaxpr=jaxpr, in_specs=tuple(in_specs), out_specs=tuple(out_specs)) effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names) new_eqn = pe.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [x for x, used in zip(eqn.outvars, used_outputs) if used], eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx) return used_inputs, new_eqn pe.dce_rules[shard_map_p] = 
_shard_map_dce


# Mutable arrays / refs

@discharge.register_discharge_rule(shard_map_p)
def _shard_map_discharge(
    in_avals, out_avals, *args, jaxpr, mesh, in_specs, out_specs, check_vma,
    manual_axes):
  """State-discharge rule for shard_map_p.

  Discharges the Ref state of the inner `jaxpr` inside the manual-mesh
  context, then re-binds shard_map_p on the discharged body. The body returns
  the original outputs followed by one value per Ref-typed input. Those extra
  outputs reuse the input specs of the corresponding Ref arguments.

  Returns:
    ``(new_invals, out_vals)``. ``new_invals`` has one entry per input: the
    returned value for Ref-typed inputs and ``None`` otherwise.

  Raises:
    NotImplementedError: if discharging produced constants.
  """
  inner_mesh = _as_manual_mesh(mesh, manual_axes)
  with (_extend_axis_env(mesh, manual_axes), use_abstract_mesh(inner_mesh),
        config._check_vma(check_vma)):
    discharged_jaxpr, discharged_consts = discharge.discharge_state(jaxpr, ())
  if discharged_consts: raise NotImplementedError
  del discharged_consts

  ref_specs = [spec for spec, invar in zip(in_specs, jaxpr.invars)
               if isinstance(invar.aval, AbstractRef)]
  params = dict(jaxpr=discharged_jaxpr, out_specs=(*out_specs, *ref_specs))
  params_ = shard_map_p.get_bind_params(params)
  f, = params_.pop('subfuns')
  debug_info = params_['debug_info']
  out_and_ref_vals = shard_map_p.bind(
      *args, subfuns=(f,), mesh=mesh, in_specs=in_specs,
      manual_axes=manual_axes, debug_info=debug_info, check_vma=check_vma)
  out_vals, ref_vals = split_list(out_and_ref_vals, [len(jaxpr.outvars)])
  ref_vals_ = iter(ref_vals)
  new_invals = [next(ref_vals_) if isinstance(a, AbstractRef) else None
                for a in in_avals]
  # Every Ref output must correspond to exactly one Ref input.
  assert next(ref_vals_, None) is None
  return new_invals, out_vals


def _repspec(aval):
  """Return `aval.nospec(...)` against the empty abstract mesh.

  Uses check_vma=False and no manual axis names.
  """
  return aval.nospec(empty_abstract_mesh, False, ())


# ----------------------- top level collectives --------------------------------

def _top_level_ag(x, aval, out_sh_, multi_dim):
  """All-gather a single array `x` of type `aval` to the sharding `out_sh_`.

  Consider each array dimension whose out-spec entry differs from the input
  entry. The out entry must be None or a prefix of the input's mesh axes for
  that dimension. The mesh axes after that prefix are all-gathered (tiled)
  along the dimension, inside a jitted `shard_map`. Spec entries are
  normalized before comparison, so a bare axis name and the 1-tuple holding
  it are treated as the same entry.

  Raises:
    ValueError: in any of these cases:
      - `out_sh_` does not canonicalize to a sharding;
      - its mesh differs from the input's mesh;
      - more than one dimension changes while `multi_dim` is False;
      - a changed dimension is unsharded in the input;
      - the out entry is not a prefix of the input entry.
  """
  assert aval.sharding.mesh.are_all_axes_explicit, aval.sharding.mesh
  out_sh = canonicalize_sharding(out_sh_, "top_level_all_gather")
  if out_sh is None:
    raise ValueError(
        f'out_sharding passed to top_level_all_gather cannot be {out_sh_}. It'
        ' should be a PartitionSpec or NamedSharding.')
  if aval.sharding.mesh != out_sh.mesh:
    raise ValueError(
        f'Input sharding mesh {aval.sharding.mesh} should be equal to'
        f' out_sharding mesh {out_sh.mesh}')
  in_spec = aval.sharding.spec
  out_spec = out_sh.spec._normalized_spec_for_aval(len(in_spec))
  if config.remove_size_one_mesh_axis_from_type.value:
    out_spec = core.remove_size_one_mesh_axis(out_spec, out_sh.mesh)

  def as_axis_tuple(entry):
    # None stays None; a single axis name becomes a 1-tuple.
    return entry if entry is None or isinstance(entry, tuple) else (entry,)

  def f_shmap(x):
    # Maybe this can just be 1 AG where we gather in a new dim and then do
    # AG(new_dim) -> reshape -> transpose -> reshape but it might be expensive.
    count = 0
    for axis, (i, o) in enumerate(zip(in_spec, out_spec)):
      # Normalize before comparing, so that 'x' vs ('x',) is not treated as a
      # change. Otherwise it would gather along 'x' and count toward multi_dim.
      i, o = as_axis_tuple(i), as_axis_tuple(o)
      if i == o:
        continue
      if not multi_dim and count > 0:
        raise ValueError(
            "multiple dimensions cannot be all_gathered since multi_dim=False"
            f" passed to `top_level_all_gather`. Got {in_spec=} and {out_spec=}")
      count += 1
      if i is None:
        raise ValueError(
            f"top_level_all_gather doesn't allow input {aval} to be unsharded"
            f" on dimension {axis} when {out_spec=}.")
      if o is not None and i[:len(o)] != o:
        raise ValueError(
            'top_level_all_gather maintains `top_level_all_gather(x, ...) == x`'
            f" property. The {in_spec=} and {out_spec=} don't satisfy this"
            f' property. Please change your out_spec of array dimension {axis} so'
            " that it's a prefix of in_spec")
      # `o` is a prefix of `i`, so gather exactly the mesh axes that follow
      # it. If `o` is None, gather all of them.
      axis_name = i if o is None else i[len(o):]
      x = lax_parallel.all_gather(x, axis_name=axis_name, axis=axis,
                                  tiled=True, to='reduced')
    return x

  return api.jit(shard_map(f_shmap, out_specs=out_spec))(x)


def top_level_all_gather(xs, out_sharding, *, multi_dim: bool = False):
  """All-gather every array leaf of `xs` to its entry of `out_sharding`.

  For each changed dimension, the out spec must keep a prefix of the mesh
  axes that shard the corresponding input dimension. The remaining axes are
  gathered, so ``top_level_all_gather(x, ...) == x`` holds value-wise; only
  the sharding changes.

  Args:
    xs: pytree of arrays.
    out_sharding: PartitionSpec/NamedSharding (or a pytree of them), flattened
      against the structure of `xs` with `api_util.flatten_axis_resources`.
    multi_dim: allow more than one array dimension of a leaf to be gathered.

  Returns:
    A pytree with the same structure as `xs`.

  Raises:
    ValueError: if the context mesh does not have all axes explicit, or for
      any per-leaf error raised by `_top_level_ag`.
  """
  if not get_abstract_mesh().are_all_axes_explicit:
    raise ValueError(
        'top_level_all_gather works when all mesh axes of context mesh are'
        f' explicit. Got {get_abstract_mesh()}')
  x_flat, treedef = tree_flatten(xs)
  out_sharding_flat = api_util.flatten_axis_resources(
      "top_level_all_gather out_sharding", treedef, out_sharding,
      tupled_args=True)
  x_avals_flat = [core.typeof(x) for x in x_flat]
  out_flat = [_top_level_ag(x, aval, sh, multi_dim)
              for x, aval, sh in zip(x_flat, x_avals_flat, out_sharding_flat)]
  return tree_unflatten(treedef, out_flat)