# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Parallelization primitives.
"""

from __future__ import annotations

from collections.abc import Sequence
from functools import partial
from dataclasses import dataclass
import itertools
import math
from typing import Any

from jax._src import core
from jax._src import config
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects as effects_lib
from jax._src import tree_util
from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext,
                                     NamedSharding, PartitionSpec as P)
from jax._src.core import AxisName, ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.core import check_unreduced_args
from jax._src.mesh import get_abstract_mesh
from jax._src.core import abstract_token, pvary
from jax._src.lax import control_flow
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client as xc
from jax._src.typing import Array
from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip,
                           unzip2)
import numpy as np

unsafe_map, map = map, safe_map
unsafe_zip, zip = zip, safe_zip


### parallel traceables

def psum(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce sum on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Inputs of boolean dtype are converted to integers before the reduction.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform psums over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce sum along the axis ``axis_name``.

  Examples:
    For example, with 4 XLA devices available:

    >>> x = np.arange(4)
    >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [6 6 6 6]
    >>> y = jax.pmap(lambda x: x / jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [0.         0.16666667 0.33333334 0.5       ]

    Suppose we want to perform ``psum`` among two groups, one with ``device0``
    and ``device1``, the other with ``device2`` and ``device3``,

    >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
    >>> print(y)
    [1 1 5 5]

    An example using 2D-shaped x. Each row is data from one device.

    >>> x = np.arange(16).reshape(4, 4)
    >>> print(x)
    [[ 0  1  2  3]
     [ 4  5  6  7]
     [ 8  9 10 11]
     [12 13 14 15]]

    Full ``psum`` across all devices:

    >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(x)
    >>> print(y)
    [[24 28 32 36]
     [24 28 32 36]
     [24 28 32 36]
     [24 28 32 36]]

    Perform ``psum`` among two groups:

    >>> y = jax.pmap(lambda x: jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]), axis_name='i')(x)
    >>> print(y)
    [[ 4  6  8 10]
     [ 4  6  8 10]
     [20 22 24 26]
     [20 22 24 26]]
  """
  return _psum_is_async(x, axis_name, axis_index_groups=axis_index_groups,
                        is_async=False)


def _psum_is_async(x, axis_name, *, axis_index_groups=None, is_async=False):
  axes = ((axis_name,) if not isinstance(axis_name, (tuple, list)) else
          tuple(axis_name))
  # TODO(yashkatariya): Remove this handling and remove_size_one_mesh_axis_from_type
  # generally from JAX.
  axes = _maybe_skip_one_sized_axes(axes)
  if not axes:
    return x
  def bind(leaf):
    from_ = _get_from(core.typeof(leaf), axes, 'jax.lax.psum')
    if from_ == 'unreduced':
      if axis_index_groups is not None:
        raise NotImplementedError
      return unreduced_psum(leaf, axes, is_async)
    else:
      return _psum(leaf, axes, axis_index_groups=axis_index_groups,
                   is_async=is_async)
  return tree_util.tree_map(bind, x)


def _psum(x, axis_name, *, axis_index_groups, is_async):
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  if not axis_name:
    return x
  if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
    raise ValueError("axis_index_groups only supported for sums over just named axes")
  _validate_reduce_axis_index_groups(axis_index_groups)
  leaves, treedef = tree_util.tree_flatten(x)
  leaves = [lax.convert_element_type(l, np.int32)
            if dtypes.dtype(l) == np.bool_ else l for l in leaves]
  axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
  # handle the constant case specially
  if all(not isinstance(leaf, core.Tracer) for leaf in leaves):
    named_axes, pos_axes = axes_partition = [], []
    for axis in axis_name:
      axes_partition[isinstance(axis, int)].append(axis)
    def pos_reduce(x):
      if not pos_axes:
        return x
      return lax.reduce_sum(x, [canonicalize_axis(axis, getattr(x, 'ndim', 0))
                                for axis in pos_axes])
    if axis_index_groups is not None:
      assert not pos_axes
      size = len(axis_index_groups[0])
    else:
      size = math.prod([core.get_axis_env().axis_size(name) for name in named_axes])
    out_flat = tuple(lax._const(leaf, size) * pos_reduce(leaf) for leaf in leaves)
  else:
    if config._check_vma.value:
      out_flat = [bind_psum_invariant(leaf, axes=tuple(axis_name),
                                      axis_index_groups=axis_index_groups,
                                      is_async=is_async)
                  for leaf in leaves]
    else:
      prim = psum_start_p if is_async else psum_p
      out_flat = [prim.bind(leaf, axes=tuple(axis_name),
                            axis_index_groups=axis_index_groups)
                  for leaf in leaves]
  return tree_util.tree_unflatten(treedef, out_flat)


def _maybe_skip_one_sized_axes(axes):
  if config.remove_size_one_mesh_axis_from_type.value:
    cur_mesh = get_abstract_mesh()
    return tuple(i for i in axes
                 if (size := cur_mesh.shape.get(i)) is None or size != 1)
  return axes


def pmean(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce mean on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform pmeans over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and on TPUs all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce mean along the axis ``axis_name``.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.pmean(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [1.5 1.5 1.5 1.5]
  >>> y = jax.pmap(lambda x: x / jax.lax.pmean(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [0.        0.6666667 1.3333334 2.       ]
  """
  x = psum(x, axis_name=axis_name, axis_index_groups=axis_index_groups)
  n = _axis_size(axis_name, axis_index_groups)
  return tree_util.tree_map(lambda v: v / n, x)


def pmax(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce max on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform pmaxes over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and on TPUs all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce max along the axis ``axis_name``.
  """
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
    raise ValueError("axis_index_groups only supported for reductions over just named axes")
  _validate_reduce_axis_index_groups(axis_index_groups)
  axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
  def bind(leaf):
    leaf = insert_collective_pvary(axis_name, leaf)
    return pmax_p.bind(leaf, axes=axis_name, axis_index_groups=axis_index_groups)
  return tree_util.tree_map(bind, x)


def pmin(x, axis_name, *, axis_index_groups=None):
  """Compute an all-reduce min on ``x`` over the pmapped axis ``axis_name``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would perform pmins over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and on TPUs all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x`` representing the result of an
    all-reduce min along the axis ``axis_name``.
  """
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  if any(isinstance(axis, int) for axis in axis_name) and axis_index_groups is not None:
    raise ValueError("axis_index_groups only supported for reductions over just named axes")
  _validate_reduce_axis_index_groups(axis_index_groups)
  axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
  def bind(leaf):
    leaf = insert_collective_pvary(axis_name, leaf)
    return pmin_p.bind(leaf, axes=axis_name, axis_index_groups=axis_index_groups)
  return tree_util.tree_map(bind, x)


# TODO(mattjj): add a pargmin_p, or add named axis support to lax.argmin_p
def pargmin(x, axis_name):
  if isinstance(axis_name, (tuple, list)):
    raise TypeError(f"pargmin only accepts a single axis, got {axis_name}")
  return _axis_index_of_val(x, pmin(x, axis_name), axis_name)


# TODO(mattjj): add a pargmax_p, or add named axis support to lax.argmax_p
def pargmax(x, axis_name):
  if isinstance(axis_name, (tuple, list)):
    raise TypeError(f"pargmax only accepts a single axis, got {axis_name}")
  return _axis_index_of_val(x, pmax(x, axis_name), axis_name)


def _axis_index_of_val(x, val, axis_name):
  idx = axis_index(axis_name)
  mask = (val == x)
  validx = lax.select(mask,
                      lax.full(mask.shape, idx),
                      lax.full(mask.shape, dtypes.iinfo(idx.dtype).max, idx.dtype))
  return pmin(validx, axis_name)
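
# A minimal usage sketch of ``pargmax`` (a hedged illustration, not a recorded
# doctest): it assumes 4 XLA devices are available, and the output shown is the
# expected value. Ties resolve to the smallest axis index because
# ``_axis_index_of_val`` takes a ``pmin`` over the masked candidate indices.
#
#   >>> x = np.array([3., 7., 7., 1.])
#   >>> jax.pmap(lambda v: jax.lax.pargmax(v, 'i'), axis_name='i')(x)
#   Array([1, 1, 1, 1], dtype=int32)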


def _validate_reduce_axis_index_groups(axis_index_groups):
  if axis_index_groups is None:
    return
  axis_space = range(sum(len(group) for group in axis_index_groups))
  if {i for g in axis_index_groups for i in g} != set(axis_space):
    raise ValueError("axis_index_groups must cover all indices exactly once")


def _canonicalize_axis_index_groups(axis_index_groups):
  if axis_index_groups is None:
    return
  return tuple(map(tuple, axis_index_groups))


def pbroadcast(x, axis_name, source):
  """Perform a collective broadcast and replicate from ``source``.

  This is equivalent to::

    def pbroadcast(x, axis_name, source):
      masked = jnp.where(axis_index(axis_name) == source, x, zeros_like(x))
      return psum(masked, axis_name)

  but implemented in a hardware optimized way.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  This function is an analog of the CollectiveBroadcast HLO.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    source: int, representing which index into ``axis_name`` that should be copied.

  Returns:
    Array(s) with ``x`` being copied from the ``source`` index slice of ``axis_name``.
  """
  return _pbroadcast_is_async(x, axis_name, source, is_async=False)


def _pbroadcast_is_async(x, axis_name, source, is_async=False):
  prim = pbroadcast_start_p if is_async else pbroadcast_p
  return tree_util.tree_map(
      partial(prim.bind, axis_name=axis_name, source=source), x)
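
# A minimal usage sketch of ``pbroadcast`` (a hedged illustration, not a
# recorded doctest; assumes 4 XLA devices, and the output shown is the expected
# value): every participant ends up with the value held at ``source``.
#
#   >>> x = np.arange(4.)
#   >>> jax.pmap(lambda v: jax.lax.pbroadcast(v, 'i', source=2), axis_name='i')(x)
#   Array([2., 2., 2., 2.], dtype=float32)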


def ppermute(x, axis_name, perm):
  """Perform a collective permutation according to the permutation ``perm``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  This function is an analog of the CollectivePermute HLO.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of pairs of ints, representing
      ``(source_index, destination_index)``
      pairs that encode how the mapped axis named ``axis_name`` should be
      shuffled. The integer values are treated as indices into the mapped axis
      ``axis_name``. Any two pairs should not have the same source index or the
      same destination index. For each index of the axis ``axis_name`` that does
      not correspond to a destination index in ``perm``, the corresponding
      values in the result are filled with zeros of the appropriate type.

  Returns:
    Array(s) with the same shape as ``x`` with slices along the axis
    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
  """
  return _ppermute_is_async(x, axis_name, perm, is_async=False)


def _ppermute_is_async(x, axis_name, perm, is_async=False):
  if not isinstance(axis_name, (list, tuple)):
    axis_name = (axis_name,)
  def bind(leaf):
    leaf = insert_collective_pvary(axis_name, leaf)
    prim = ppermute_start_p if is_async else ppermute_p
    return prim.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm)))
  return tree_util.tree_map(bind, x)
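
# A minimal usage sketch of ``ppermute`` (a hedged illustration, not a recorded
# doctest; assumes 4 XLA devices, and the output shown is the expected value).
# The permutation below rotates values one step around a ring; any axis index
# that is not a destination in ``perm`` would instead be filled with zeros.
#
#   >>> n = 4
#   >>> x = np.arange(float(n))
#   >>> shift_right = [(i, (i + 1) % n) for i in range(n)]
#   >>> jax.pmap(lambda v: jax.lax.ppermute(v, 'i', shift_right), axis_name='i')(x)
#   Array([3., 0., 1., 2.], dtype=float32)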


def psend(x, axis_name, perm):
  """Perform a collective send according to the permutation ``perm``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  This function is an analog of the Send HLO.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of pairs of ints, representing ``(source_index,
      destination_index)`` pairs that encode how the mapped axis named
      ``axis_name`` should be shuffled. The integer values are treated as
      indices into the mapped axis ``axis_name``. Any two pairs should not have
      the same source index or the same destination index. For each index of the
      axis ``axis_name`` that does not correspond to a destination index in
      ``perm``, the corresponding values in the result are filled with zeros of
      the appropriate type. The semantics here are platform-specific, and for
      GPU they correspond to NCCL send.

  Returns:
    A compiler token that can be used by precv and lax.optimization_barrier to
    enforce ordering of collective ops.
  """
  axis_name = tuple(axis_name) if isinstance(axis_name, (list, tuple)) else (axis_name,)

  def bind(leaf):
    leaf = insert_collective_pvary(axis_name, leaf)
    return psend_p.bind(leaf, axis_name=axis_name, perm=tuple(map(tuple, perm)))

  return tree_util.tree_map(bind, x)


def precv(token, out_shape, axis_name, perm):
  """Perform a collective recv according to the permutation ``perm``.

  This function is an analog of the Recv HLO.

  Args:
    token: a compiler token, either generated by a matching psend or
      lax.create_token(). This is used to enforce control dependencies between
      collectives.
    out_shape: ShapeDtypeStruct(s) containing the dtype and shape
      of the result.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of pairs of ints, representing ``(source_index,
      destination_index)`` pairs that encode how the mapped axis named
      ``axis_name`` should be shuffled. The integer values are treated as
      indices into the mapped axis ``axis_name``. Any two pairs should not have
      the same source index or the same destination index. For each index of the
      axis ``axis_name`` that does not correspond to a destination index in
      ``perm``, the corresponding values in the result are filled with zeros of
      the appropriate type. The semantics here are platform-specific, and for
      GPU they correspond to NCCL recv.

  Returns:
    Array(s) with the same shape as ``out_shape``.
  """
  axis_name = tuple(axis_name) if isinstance(axis_name, (list, tuple)) else (axis_name,)

  return precv_p.bind(
      token,
      out_shape=core.ShapedArray(
          out_shape.shape, out_shape.dtype
      ),
      axis_name=axis_name,
      perm=tuple(map(tuple, perm)),
  )
def pshuffle(x, axis_name, perm):
  """Convenience wrapper of jax.lax.ppermute with alternate permutation encoding.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    perm: list of ints encoding sources for the permutation to be applied to
      the axis named ``axis_name``, so that the output at axis index i
      comes from the input at axis index perm[i]. Every integer in [0, N) should
      be included exactly once for axis size N.

  Returns:
    Array(s) with the same shape as ``x`` with slices along the axis
    ``axis_name`` gathered from ``x`` according to the permutation ``perm``.
  """
  if set(perm) != set(range(len(perm))):
    raise ValueError(f"`perm` does not represent a permutation: {perm}")
  return ppermute(x, axis_name, list(zip(perm, range(len(perm)))))


def pswapaxes(x, axis_name, axis, *, axis_index_groups=None):
  """Swap the pmapped axis ``axis_name`` with the unmapped axis ``axis``.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  The group size of the mapped axis size must be equal to the size of the
  unmapped axis; that is, we must have
  ``lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]``.
  By default, when ``axis_index_groups=None``, this encompasses all the devices.

  This function is a special case of ``all_to_all`` where the pmapped axis of
  the input is placed at the position ``axis`` in the output. That is, it is
  equivalent to ``all_to_all(x, axis_name, axis, axis)``.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis: int indicating the unmapped axis of ``x`` to map with the name
      ``axis_name``.
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would run pswapaxes over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.

  Returns:
    Array(s) with the same shape as ``x``.
  """
  return all_to_all(x, axis_name, axis, axis, axis_index_groups=axis_index_groups)


def all_to_all(x, axis_name, split_axis, concat_axis, *, axis_index_groups=None,
               tiled=False):
  """Materialize the mapped axis and map a different axis.

  If ``x`` is a pytree then the result is equivalent to mapping this function to
  each leaf in the tree.

  In the output, the input mapped axis ``axis_name`` is materialized at the
  logical axis position ``concat_axis``, and the input unmapped axis at position
  ``split_axis`` is mapped with the name ``axis_name``.

  The group size of the mapped axis size must be equal to the size of the
  unmapped axis; that is, we must have
  ``lax.psum(1, axis_name, axis_index_groups=axis_index_groups) == x.shape[axis]``.
  By default, when ``axis_index_groups=None``, this encompasses all the devices.

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    split_axis: int indicating the unmapped axis of ``x`` to map with the name
      ``axis_name``.
    concat_axis: int indicating the position in the output to materialize the
      mapped axis of the input with the name ``axis_name``.
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would run all_to_all over the first
      two and last two replicas). Groups must cover all axis indices exactly
      once, and all groups must be the same size.
    tiled: when True, all_to_all will divide split_axis into chunks and concatenate
      them along concat_axis. In particular, no dimensions are added or removed.
      False by default.

  Returns:
    When tiled is False, array(s) with shape given by the expression::

      np.insert(np.delete(x.shape, split_axis), concat_axis, axis_size)

    where ``axis_size`` is the size of the mapped axis named ``axis_name`` in
    the input ``x``.

    Otherwise array with shape similar to the input shape, except with split_axis
    divided by axis size and concat_axis multiplied by axis size.
  """
  return _all_to_all_is_async(x, axis_name, split_axis, concat_axis,
                              axis_index_groups=axis_index_groups, tiled=tiled,
                              is_async=False)


def _all_to_all_is_async(x, axis_name, split_axis, concat_axis, *,
                         axis_index_groups=None, tiled=False, is_async=False):
  axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
  def bind(x, split_axis=split_axis, concat_axis=concat_axis):
    group_size = _axis_size(axis_name, axis_index_groups)
    if tiled:
      if x.shape[split_axis] % group_size != 0:
        raise ValueError(f"The size of all_to_all split_axis ({x.shape[split_axis]}) "
                         f"has to be divisible by the size of the named axis "
                         f"{axis_name} ({group_size})")
    else:
      if group_size != x.shape[split_axis]:
        msg = ("all_to_all requires the size of the mapped axis axis_name to "
               "equal x.shape[split_axis], but they are {} and {} respectively.")
        raise ValueError(msg.format(group_size, x.shape[split_axis]))
      if split_axis < concat_axis:
        concat_axis += 1  # concat_axis gives a position _after_ split_axis is removed
        x = lax.expand_dims(x, (concat_axis,))  # insert the new axis
      elif split_axis == concat_axis:
        pass
      else:  # concat_axis < split_axis
        x = lax.expand_dims(x, (concat_axis,))  # insert the new axis
        split_axis += 1   # we have a new axis before split_axis now
    x = insert_collective_pvary(axis_name, x)
    prim = all_to_all_start_p if is_async else all_to_all_p
    result = prim.bind(x, split_axis=split_axis, concat_axis=concat_axis,
                       axis_name=axis_name,
                       axis_index_groups=axis_index_groups,
                       tiled=tiled)
    if not tiled and split_axis != concat_axis:
      result = lax.squeeze(result, (split_axis,))
    return result

  return tree_util.tree_map(bind, x)
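
# A minimal shape sketch for ``all_to_all`` (a hedged illustration, not a
# recorded doctest; assumes 2 XLA devices). Per device, an ``f32[2, 3]`` input
# with ``split_axis=0`` and ``concat_axis=0`` keeps its shape while its rows
# are exchanged across devices: row ``j`` of device ``i`` ends up as row ``i``
# on device ``j``. With ``tiled=True`` the split axis is instead chunked and
# the chunks are concatenated along ``concat_axis``.
#
#   >>> x = np.arange(2 * 2 * 3.).reshape(2, 2, 3)  # leading axis is mapped
#   >>> f = jax.pmap(lambda v: jax.lax.all_to_all(v, 'i', 0, 0), axis_name='i')
#   >>> f(x).shape
#   (2, 2, 3)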


def ragged_all_to_all(
    operand, output, input_offsets, send_sizes, output_offsets, recv_sizes, *,
    axis_name, axis_index_groups = None):
  """Ragged version of :func:`all_to_all` collective.

  We say data are "ragged" when they can be represented as a list of arrays
  whose shapes differ only in the size of the leading axis. For example, these
  data are ragged, comprising four component arrays::

    ragged_data = [jnp.arange(3), jnp.arange(1), jnp.arange(4), jnp.arange(1)]

  We often instead want a contiguous representation, e.g. for batching. But
  because the shapes of the components differ, we can't apply ``jnp.stack`` to
  represent these data by a single rectangular array with the leading axis
  indexing the component arrays. So instead of stacking, we concatenate along
  the leading axis and keep track of offsets and sizes.

  That is, we can represent ragged data contiguously using a triple of dense
  arrays ``(data, offsets, sizes)``:

  * ``data``: the concatenated component arrays,
  * ``offsets``: 1D array of indices into the leading axis of ``data``
    indicating where the data for each component array begins,
  * ``sizes``: 1D array of sizes of the leading axis of each component array.

  We refer to this triple as a ragged array. (Offsets can't be computed from
  sizes in general to allow for internal padding.)

  For example::

    data: f32[8,3] = jnp.array([
        [a,b,c], [d,e,f], [g,h,i], [j,k,l], [m,n,o], [p,q,r], [s,t,u], [v,w,x],
    ])
    offsets: i32[3] = jnp.array([0, 1, 4])
    sizes: i32[3] = jnp.array([1, 3, 4])

    # To extract the first component array, of type f32[1,3]
    data[offsets[0]:offsets[0]+sizes[0]]

    # To extract the second component array, of type f32[3,3]
    data[offsets[1]:offsets[1]+sizes[1]]

    # To extract the third component array, of type f32[4,3]
    data[offsets[2]:offsets[2]+sizes[2]]

  The ``ragged_all_to_all`` collective operation communicates slices of ragged
  arrays between devices. Each caller is both a sender and a receiver. The
  ``input_offsets`` and ``send_sizes`` arguments indicate the slices of the
  caller's ``operand`` to be sent. Received results are returned in an array
  that has the same value of the argument ``output`` except with received values
  written at some slices. The ``output_offsets`` argument does *not* indicate
  the offsets at which all the received results are written; instead,
  ``output_offsets`` indicates the offsets at which the *sent* slices are
  written on their corresponding receivers. The sizes of received slices are
  indicated by ``recv_sizes``. See below for details.

  The arrays ``input_offsets``, ``send_sizes``, ``output_offsets``, and
  ``recv_sizes`` must all be the same length, and that length must be divisible
  by the size of the mapped axis ``axis_name``. Moreover, ``send_sizes`` and
  ``recv_sizes`` must satisfy::

    jnp.all(send_sizes == jax.lax.all_to_all(recv_sizes, axis_name, 0, 0, tiled=True))

  Specifically, given a call::

    result = ragged_all_to_all(operand, output, input_offsets, send_sizes,
                               output_offsets, recv_sizes, axis_name)

  the caller sends data like::

    assert len(input_offsets) == len(send_sizes) == len(output_offsets) == len(recv_sizes)
    N = len(input_offsets)
    slices_per_device, leftover = divmod(N, lax.axis_size(axis_name))
    assert not leftover

    for i in range(N):
      dst_idx = i // slices_per_device
      SEND(data=operand[input_offsets[i]:input_offsets[i]+send_sizes[i]],
           axis_name=axis_name, to_axis_index=dst_idx)

  and receives data in ``result`` like::

    result = output
    output_offsets_ = jax.lax.all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
    for i in range(N):
      src_idx = i // slices_per_device
      result = result.at[output_offsets_[i]:output_offsets_[i]+recv_sizes[i]
                         ].set(RECEIVE(axis_name=axis_name, from_axis_index=src_idx))

  where ``SEND`` and ``RECEIVE`` are pseudocode. Notice that a caller's local
  ``output_offsets`` does not indicate the offsets at which its local ``result``
  is updated; instead, it indicates where the corresponding sent slices are
  written on their destination instances. To compute the local offsets at which
  received data are written, we apply an ``all_to_all`` on ``output_offsets``.

  For example, if we apply a ``ragged_all_to_all`` along an axis of size 2, with
  these arguments in each mapped function instance::

    axis index 0:
      operand = [1, 2, 2]
      output = [0, 0, 0, 0]
      input_offsets = [0, 1]
      send_sizes = [1, 2]
      output_offsets = [0, 0]
      recv_sizes = [1, 1]

    axis index 1:
      operand = [3, 4, 0]
      output = [0, 0, 0, 0]
      input_offsets = [0, 1]
      send_sizes = [1, 1]
      output_offsets = [1, 2]
      recv_sizes = [2, 1]

  then::

    axis index 0:
      result = [1, 3, 0, 0]

    axis index 1:
      result = [2, 2, 4, 0]

  Args:
    operand: data array of shape (N, A, B, ...) representing concatenated
      (possibly padded) ragged data to be sent.
    output: data array of shape (M, A, B, ...) to update with received data.
    input_offsets: 1D integer array of shape (K,) representing the offsets of
      leading-axis slices into ``operand`` to be sent.
    send_sizes: 1D integer array of shape (K,) representing the sizes of
      leading-axis slices into ``operand`` to be sent.
    output_offsets: 1D integer array of shape (K,) representing where the
      corresponding sent data is written on each corresponding receiver.
    recv_sizes: 1D integer array of shape (K,) representing sizes of
      leading-axis slices into ``output`` to update with received data.
    axis_name: name of the mapped axis over which to perform the communication.
    axis_index_groups: optional list of lists containing axis indices (e.g. for
      an axis of size 4, [[0, 1], [2, 3]] would run ragged all to all over the
      first two and last two replicas). Groups must cover all axis indices
      exactly once, and all groups must be the same size. Otherwise, the
      behavior is undefined.

  Returns:
    Array of shape (M, A, B, ...) with the same value as the ``output`` except
    with received data written into slices starting at
    ``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)`` and with size
    ``recv_sizes``.
  """

  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)

  axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
  return ragged_all_to_all_p.bind(operand, output, input_offsets, send_sizes,
                                  output_offsets, recv_sizes,
                                  axis_name=axis_name,
                                  axis_index_groups=axis_index_groups)
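
# Tracing the worked example in the docstring above (a hedged illustration, not
# a recorded doctest): the local offsets at which each caller's ``result`` is
# written are ``all_to_all(output_offsets, axis_name, 0, 0, tiled=True)``, which
# for the values above evaluates to
#
#   axis index 0: [0, 1]   # with recv_sizes [1, 1] -> result [1, 3, 0, 0]
#   axis index 1: [0, 2]   # with recv_sizes [2, 1] -> result [2, 2, 4, 0]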


def axis_index(axis_name: AxisName) -> Array:
  """Return the index along the mapped axis ``axis_name``.

  Args:
    axis_name: hashable Python object used to name the mapped axis.

  Returns:
    An integer representing the index.

  For example, with 8 XLA devices available:

  >>> mesh = jax.make_mesh((8,), 'i')
  >>> @jax.shard_map(mesh=mesh, in_specs=(), out_specs=jax.P('i'))
  ... def f():
  ...   return lax.axis_index('i')[None]
  ...
  >>> f()
  Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

  >>> mesh = jax.make_mesh((4, 2), ('i', 'j'))
  >>> @jax.shard_map(mesh=mesh, in_specs=(), out_specs=jax.P('i', 'j'))
  ... def f():
  ...   return lax.axis_index(('i', 'j'))[None, None]
  ...
  >>> f()
  Array([[0, 1],
         [2, 3],
         [4, 5],
         [6, 7]], dtype=int32)
  """
  if not isinstance(axis_name, (tuple, list)):
    return axis_index_p.bind(axis_name=axis_name)
  else:
    inner_size = 1
    index = lax.asarray(0)
    for name in reversed(axis_name):
      index += axis_index(name) * inner_size
      inner_size *= axis_size(name)
    return index


def axis_size(axis_name: AxisName) -> int:
  """Return the size of the mapped axis ``axis_name``.

  Args:
    axis_name: hashable Python object used to name the mapped axis.

  Returns:
    An integer representing the size.

  For example, with 8 XLA devices available:

  >>> mesh = jax.make_mesh((8,), 'i')
  >>> @jax.shard_map(mesh=mesh, in_specs=jax.P('i'), out_specs=jax.P())
  ... def f(_):
  ...   return lax.axis_size('i')
  ...
  >>> f(jax.device_put(jnp.zeros(16), jax.NamedSharding(mesh, P('i'))))
  Array(8, dtype=int32, weak_type=True)

  >>> mesh = jax.make_mesh((4, 2), ('i', 'j'))
  >>> @jax.shard_map(mesh=mesh, in_specs=jax.P('i', 'j'), out_specs=jax.P())
  ... def f(_):
  ...   return lax.axis_size(('i', 'j'))
  ...
  >>> f(jax.device_put(jnp.zeros((16, 8)), jax.NamedSharding(mesh, P('i', 'j'))))
  Array(8, dtype=int32, weak_type=True)
  """
  return _axis_size(axis_name)


def _axis_size(
    axis_name: AxisName,
    axis_index_groups: Sequence[Sequence[int]] | None = None,
    /,
) -> int:
  axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
  return psum(1, axis_name, axis_index_groups=axis_index_groups)

### parallel primitives


def _constant_reduction(prim, axis_data, arg, axes, axis_index_groups):
  assert axis_data.name in axes
  if axis_index_groups: raise NotImplementedError
  new_axes = tuple(n for n in axes if n != axis_data.name)
  if new_axes:
    arg = (prim.bind(arg, axes=new_axes) if prim is psum_invariant_p else
           prim.bind(arg, axes=new_axes, axis_index_groups=axis_index_groups))
  if prim is psum_p:
    out = lax._const(arg, axis_data.size) * arg
  elif prim in (pmin_p, pmax_p):
    out = arg
  else:
    raise Exception(f"Unrecognized reducer: {prim}")
  return out, None


def _reduction_with_positional_batcher(
    prim, v, d, axis_index_groups, transform_unmapped, transform_mapped):
  if axis_index_groups is not None:
    raise NotImplementedError("axis_index_groups not supported in vmap collectives. "
                              "Please open a feature request!")
  v = v if d is batching.not_mapped or d == 0 else _moveaxis(d, 0, v)
  if d is batching.not_mapped:
    unmapped_axes, unmapped_vals_in = transform_unmapped(0, v)
    return (prim.bind(unmapped_vals_in, axes=unmapped_axes)
            if prim is psum_invariant_p else
            prim.bind(unmapped_vals_in, axes=unmapped_axes, axis_index_groups=None))

  mapped_axes, mapped_vals_in = transform_mapped(0, v)
  return (prim.bind(mapped_vals_in, axes=mapped_axes)
          if prim is psum_invariant_p else
          prim.bind(mapped_vals_in, axes=mapped_axes, axis_index_groups=None))


def _reduction_batcher(prim, v, d, *, axes, axis_index_groups):
  assert not prim.multiple_results
  if not any(isinstance(axis, int) for axis in axes):
    out = (prim.bind(v, axes=axes) if prim is psum_invariant_p else
           prim.bind(v, axes=axes, axis_index_groups=axis_index_groups))
    return out, d
  val_out = _reduction_with_positional_batcher(
      prim, v, d, axis_index_groups,
      lambda d, v: (axes, v),
      lambda d, v: (tuple(axis + (axis >= d) if isinstance(axis, int) else axis
                          for axis in axes),
                    v))
  # _reduction_with_positional_batcher moves all map dims to 0
  return val_out, d if d is batching.not_mapped else 0


def _batched_reduction_collective(prim, if_unmapped, axis_data, vals_in,
                                  dims_in, axes, axis_index_groups):
  assert not prim.multiple_results
  (v,), (d,) = vals_in, dims_in
  del vals_in, dims_in

  if d is None:
    if axis_data.name in axes:
      return _constant_reduction(prim, axis_data, v, axes, axis_index_groups)
    else:
      out = (prim.bind(v, axes=axes) if prim is psum_invariant_p else
             prim.bind(v, axes=axes, axis_index_groups=axis_index_groups))
      return out, d

  if axis_data.name not in axes:
    return _reduction_batcher(
        prim, v, d, axes=axes, axis_index_groups=axis_index_groups)

  # Note that we have a choice here. We can either unfuse the reduction into one
  # that handles the batched dims and then another one that handles the rest.
  # Alternatively, we can keep the dimension reduction fused with the rest, but
  # we have to split the primitive into one for unmapped inputs and another
  # one for mapped, because they differ in their `axes` parameter.
  # We choose the second strategy here.
  val_out = _reduction_with_positional_batcher(
      prim, v, d, axis_index_groups,
      lambda d, v: (tuple(axis for axis in axes if axis != axis_data.name),
                    if_unmapped(v, axis_data.size)),
      lambda d, v: (tuple(axis + (axis >= d) if isinstance(axis, int) else axis
                          if axis != axis_data.name else d for axis in axes),
                    v))
  return val_out, batching.not_mapped


def _replica_groups(axis_env, axis_name, axis_index_groups):
  replica_groups = pxla.axis_groups(axis_env, axis_name)
  if axis_index_groups is not None:
    replica_groups = [[axis_group[i] for i in axis_index_group]
                      for axis_group in replica_groups
                      for axis_index_group in axis_index_groups]
  return replica_groups


def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]]
                        ) -> ir.DenseElementsAttr:
  # Uneven replica groups are padded with -1.
  groups = np.array(list(itertools.zip_longest(*replica_groups, fillvalue=-1)),
                    dtype=np.int64).T
  return ir.DenseIntElementsAttr.get(np.ascontiguousarray(groups))  # pyrefly: ignore[no-matching-overload]
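
# A small illustration of the padding convention above (a hedged sketch):
# uneven groups such as [[0, 1, 2], [3, 4]] are zipped with a fill value of -1
# and transposed back, so the rectangular attribute holds [[0, 1, 2], [3, 4, -1]].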


def _allreduce_impl(prim, pos_reducer, arg, *, axes, axis_index_groups):
  assert axis_index_groups is None
  if not all(isinstance(axis, int) for axis in axes):
    return dispatch.apply_primitive(prim, arg, axes=axes,
                                    axis_index_groups=axis_index_groups)
  assert all(isinstance(axis, int) for axis in axes)
  return pos_reducer(arg, axes)


def _allreduce_effectful_abstract_eval(aval, *, axes, axis_index_groups):
  _check_axis_names(axes, 'psum')
  named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
  pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
  if axis_index_groups is not None:
    if len(pos_axes) != 0:
      raise ValueError(f"axis_index_groups can only be used with reductions over "
                       f"named axes, but got: {axes}")
  core.check_avals_context_mesh([aval], 'psum')
  check_unreduced_args([aval], axes, 'psum')
  out_aval = ShapedArray(
      lax._reduce_op_shape_rule(aval, axes=pos_axes), aval.dtype,
      sharding=lax._reduce_op_sharding_rule(aval, axes=pos_axes))
  return out_aval, {core.NamedAxisEffect(axis) for axis in named_axes}


# TODO(yashkatariya): Replace this with _psum_invariant_abstract_eval
def _pmin_pmax_abstract_eval(name, aval, *, axes, axis_index_groups):
  if not config._check_vma.value:
    return _allreduce_effectful_abstract_eval(
        aval, axes=axes, axis_index_groups=axis_index_groups)
  return _psum_invariant_abstract_eval(name, aval, axes=axes)


def _check_axis_names(axes, api_name):
  named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
  axis_env = core.get_axis_env()
  for name in named_axes:
    if not axis_env.axis_exists(name):
      raise NameError(
          f"Found an unbound axis name: {name}. To fix this, please call"
          f" {api_name} under `jax.shard_map`.")


def _allreduce_lowering(prim, pos_fn, ctx, arg, *, axes, axis_index_groups):
  aval_in, = ctx.avals_in
  if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms):
    len_0 = len(axis_index_groups[0])
    if any(len(g) != len_0 for g in axis_index_groups):
      raise ValueError("axis_index_groups must all be the same size for TPU lowering")
  named_axes, positional_axes = axes_partition = [], []
  for axis in axes:
    axes_partition[isinstance(axis, int)].append(axis)

  if positional_axes:
    reducer = mlir.lower_fun(pos_fn, multiple_results=False)
    def _positional_reduce(aval, arg):
      aval_out = aval.update(
          shape=np.delete(np.array(aval.shape, dtype=np.int64),
                          positional_axes))
      reducer_ctx = ctx.replace(primitive=None, avals_in=[aval], avals_out=[aval_out])
      out, = reducer(reducer_ctx, arg, axes=tuple(positional_axes))
      return out
    arg = _positional_reduce(aval_in, arg)
  if not named_axes:
    return [arg]

  replica_groups = _replica_groups_hlo(
      _replica_groups(ctx.module_context.axis_env, named_axes,
                      axis_index_groups))
  axis_context = ctx.module_context.axis_context
  is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext))

  def all_reduce(aval, x):
    if is_spmd:
      other_args: dict[str, Any] = dict(
          channel_handle=hlo.ChannelHandle.get(
              mlir.COLLECTIVE_CHANNEL_ID, mlir.DEVICE_TO_DEVICE_TYPE),
          use_global_device_ids=ir.BoolAttr.get(True))
    else:
      other_args = {}

    op = hlo.AllReduceOp(
        [x.type], [x], replica_groups=replica_groups, **other_args)
    scalar_aval = core.ShapedArray(
        (), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P()))
    scalar_type = mlir.aval_to_ir_type(scalar_aval)
    reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
    with ir.InsertionPoint(reducer_block):
      lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
      reducer_ctx = ctx.replace(primitive=None,
                                avals_in=[scalar_aval] * 2, avals_out=[scalar_aval])
      out_nodes = lower_reducer(reducer_ctx, *reducer_block.arguments)
      hlo.return_(mlir.flatten_ir_values(out_nodes))
    return op.result
  return [all_reduce(aval_in, arg)]


def _psum_transpose_rule(cts, arg, *, axes, axis_index_groups):
  named_axes, pos_axes = axes_partition = [], []
  for axis in axes:
    axes_partition[isinstance(axis, int)].append(axis)

  if pos_axes:
    def broadcast_positional(ct, arg):
      assert ad.is_undefined_primal(arg)
      if type(ct) is ad.Zero: return ad.Zero(arg.aval)
      return lax._reduce_sum_transpose_rule(ct, arg, axes=pos_axes,
                                            out_sharding=None)[0]
    cts = broadcast_positional(cts, arg)

  # We treat psum as psum + pbroadcast, which is why the transpose reduces
  # over the named axes again (unlike for positional axes).
  return (psum_p.bind(cts, axes=tuple(named_axes),
                      axis_index_groups=axis_index_groups),)


psum_p = core.Primitive('psum')
psum_p.def_impl(partial(_allreduce_impl, psum_p, lax.reduce_sum))
psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering(
    psum_p, partial(_allreduce_lowering, lax.add_p, lax.reduce_sum))
ad.deflinear2(psum_p, _psum_transpose_rule)
batching.fancy_primitive_batchers[psum_p] = \
    partial(_batched_reduction_collective, psum_p, lambda v, axis_size: axis_size * v)


pmax_p = core.Primitive('pmax')
pmax_p.def_impl(partial(_allreduce_impl, pmax_p, lax.reduce_max))
pmax_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmax'))
mlir.register_lowering(
    pmax_p, partial(_allreduce_lowering, lax.max_p, lax.reduce_max))
batching.fancy_primitive_batchers[pmax_p] = \
    partial(_batched_reduction_collective, pmax_p, lambda v, axis_size: v)


pmin_p = core.Primitive('pmin')
pmin_p.def_impl(partial(_allreduce_impl, pmin_p, lax.reduce_min))
pmin_p.def_effectful_abstract_eval(partial(_pmin_pmax_abstract_eval, 'pmin'))
mlir.register_lowering(
    pmin_p, partial(_allreduce_lowering, lax.min_p, lax.reduce_min))
batching.fancy_primitive_batchers[pmin_p] = \
    partial(_batched_reduction_collective, pmin_p, lambda v, axis_size: v)


def _pcollectives_lowering_common(ctx, *, axis_name, perm, op_name):
  replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None)
  group_size = len(replica_groups[0])
  srcs, dsts = unzip2((src % group_size, dst % group_size) for src, dst in perm)
  if not (len(srcs) == len(set(srcs)) and len(dsts) == len(set(dsts))):
    msg = f"{op_name} sources and destinations must be unique, got {{}}."
    raise ValueError(msg.format(perm))

  full_perm = np.zeros((len(replica_groups), len(perm), 2), np.int64)
  for i, grp in enumerate(replica_groups):
    grp = sorted(grp)
    for j, (src, dst) in enumerate(perm):
      full_perm[i, j, 0] = grp[src]
      full_perm[i, j, 1] = grp[dst]
  full_perm = full_perm.reshape((-1, 2))

  axis_context = ctx.module_context.axis_context
  is_manual = (
      isinstance(axis_context, SPMDAxisContext)
      and axis_context.manual_axes
  )
  if is_manual:
    other_args: dict[str, Any] = dict(
        channel_handle=hlo.ChannelHandle.get(
            mlir.COLLECTIVE_CHANNEL_ID, mlir.DEVICE_TO_DEVICE_TYPE
        )
    )
  else:
    other_args = {}
  return full_perm, other_args


def _ppermute_lowering(ctx, x, *, axis_name, perm):
  full_perm, other_args = _pcollectives_lowering_common(
      ctx, axis_name=axis_name, perm=perm, op_name="ppermute"
  )
  return hlo.CollectivePermuteOp(
      x, mlir.dense_int_elements(full_perm), **other_args).results


def _ppermute_transpose_rule(t, x, perm, axis_name):
  srcs, dsts = unzip2(perm)
  inverse_perm = list(zip(dsts, srcs))
  return [ppermute(t, axis_name=axis_name, perm=inverse_perm)]


def _ppermute_batcher(axis_data, vals_in, dims_in, axis_name, perm):
  axis_size, frame_name = axis_data.size, axis_data.name
  (v,), (d,) = vals_in, dims_in
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  if d is None and axis_data.name not in axis_name:
    return ppermute_p.bind(v, perm=perm, axis_name=axis_name), None
  if axis_data.name not in axis_name:
    return ppermute_p.bind(v, perm=perm, axis_name=axis_name), d
  remaining_axes = tuple(axis for axis in axis_name if axis != frame_name)
  if remaining_axes:
    return ppermute_p.bind(v, perm=perm, axis_name=remaining_axes), d
  assert axis_name[0] == frame_name, "ppermute batcher called with a wrong axis!"
  assert len(perm) == axis_size, "Permutation doesn't match the axis size!"
  if d is batching.not_mapped:
    return v, d
  perm_indices = np.zeros(axis_size, dtype=int)
  for src, dst in perm:
    perm_indices[dst] = src
  return v.take(perm_indices, d), d


def _raise_to_shaped_abstract_eval(x, *, axis_name, **params):
  _check_axis_names(axis_name, 'ppermute')
  collective_vma_rule('ppermute', axis_name, x)
  check_unreduced_args([x], axis_name, 'ppermute')
  return x


ppermute_p = core.Primitive('ppermute')
ppermute_p.def_abstract_eval(_raise_to_shaped_abstract_eval)
ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
mlir.register_lowering(ppermute_p, _ppermute_lowering)
batching.fancy_primitive_batchers[ppermute_p] = _ppermute_batcher


@dataclass(frozen=True)
class SingleSideCollectiveEffect(core.Effect):
  __str__ = lambda _: "one-sided communication"
  def __hash__(self):
    return hash(SingleSideCollectiveEffect)
  def __eq__(self, other):
    return isinstance(other, SingleSideCollectiveEffect)


single_side_collective_effect = SingleSideCollectiveEffect()
core.effects.control_flow_allowed_effects.add_type(SingleSideCollectiveEffect)


def _psend_lowering_gpu(ctx, x, *, axis_name, perm):
  if ("cuda" not in ctx.module_context.platforms and
      "rocm" not in ctx.module_context.platforms):
    raise NotImplementedError("psend is currently only implemented on GPUs")

  full_perm, other_args = _pcollectives_lowering_common(
      ctx, axis_name=axis_name, perm=perm, op_name="psend"
  )
  token = hlo.create_token()
  send_op = hlo.SendOp(
      [x],
      token,
      source_target_pairs=mlir.dense_int_elements(full_perm),
      **other_args,
  )
  axis_ctx = ctx.module_context.axis_context
  if not isinstance(axis_ctx, SPMDAxisContext):
    raise NotImplementedError("psend currently only supports manual sharding")

  sharding = xc.OpSharding()
  sharding.type = xc.OpSharding.Type.MANUAL
  mlir.set_sharding(send_op, sharding)
  return send_op.results


effects_lib.lowerable_effects.add_type(SingleSideCollectiveEffect)


def _psend_abstract_eval(x, *, axis_name, **params):
  _check_axis_names(axis_name, 'psend')
  return abstract_token, {
      *map(core.NamedAxisEffect, axis_name),
      single_side_collective_effect,
  }


psend_p = core.Primitive("psend")
psend_p.def_impl(partial(dispatch.apply_primitive, psend_p))
psend_p.def_effectful_abstract_eval(_psend_abstract_eval)
mlir.register_lowering(psend_p, _psend_lowering_gpu, platform="gpu")


def _psend_lowering(ctx, x, *, axis_name, perm):
  raise NotImplementedError("psend is currently only implemented on GPU")
mlir.register_lowering(psend_p, _psend_lowering)


batching.fancy_primitive_batchers[psend_p] = _ppermute_batcher


def _precv_lowering_gpu(ctx, token, *, out_shape, axis_name, perm):
  full_perm, other_args = _pcollectives_lowering_common(
      ctx, axis_name=axis_name, perm=perm, op_name="precv"
  )
  out_type = mlir.aval_to_ir_type(out_shape)
  recv_op = hlo.RecvOp(
      [out_type, token.type],
      token,
      source_target_pairs=mlir.dense_int_elements(full_perm),
      **other_args,
  )
  axis_ctx = ctx.module_context.axis_context
  if not isinstance(axis_ctx, SPMDAxisContext):
    raise NotImplementedError("precv currently only supports manual sharding")

  sharding = xc.OpSharding()
  sharding.type = xc.OpSharding.Type.MANUAL
  mlir.set_sharding(recv_op, sharding)

  # recv_op should return an array of [RankedTensorType, StableHlo.token]; we
  # only need the tensor.
  results = recv_op.results
  return [results[0]]


def _precv_abstract_eval(
    token, *, out_shape, axis_name, **params
):
  return out_shape, {*map(core.NamedAxisEffect, axis_name),
                     single_side_collective_effect}


precv_p = core.Primitive("precv")
precv_p.def_effectful_abstract_eval(_precv_abstract_eval)
mlir.register_lowering(precv_p, _precv_lowering_gpu, platform='gpu')


def _precv_lowering(ctx, token, *, out_shape, axis_name, perm):
  raise NotImplementedError("precv is currently only implemented on GPU")
mlir.register_lowering(precv_p, _precv_lowering)


batching.fancy_primitive_batchers[precv_p] = _ppermute_batcher


def _pbroadcast_transpose_rule(t, x, source, axis_name):
  is_source = axis_index(axis_name) == source
  tsum = psum(t, axis_name)
  return [lax.select(is_source, lax.full_like(t, tsum), lax.full_like(t, 0))]


def _pbroadcast_batcher(axis_data, vals_in, dims_in, axis_name, source):
  axis_size = axis_data.size
  (v,), (d,) = vals_in, dims_in
  if not isinstance(axis_name, (tuple, list)):
    axis_name = (axis_name,)
  if d is None and axis_data.name not in axis_name:
    return pbroadcast_p.bind(v, axis_name=axis_name, source=source), None
  if axis_data.name not in axis_name:
    return pbroadcast_p.bind(v, axis_name=axis_name, source=source), d
  remaining_axes = tuple(axis for axis in axis_name if axis != axis_data.name)
  if remaining_axes:
    raise NotImplementedError("pbroadcast batcher only supports a single axis")
  assert axis_name[0] == axis_data.name, "pbroadcast batcher called with a wrong axis!"
  assert source >= 0 and source < axis_size, "collective broadcast doesn't fit in the axis size!"
  if axis_size == 1 and remaining_axes:
    return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d
  if d is batching.not_mapped:
    return v, d
  return v.take([source] * axis_size, d), d


def _pbroadcast_lowering(ctx, x, *, axis_name, source):
  replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None)
  def source_to_front(group):
    return [group[source]] + list(group[:source]) + list(group[source + 1:])
  replica_groups = [source_to_front(group) for group in replica_groups]
  is_spmd = isinstance(
      ctx.module_context.axis_context,
      (SPMDAxisContext, ShardingContext),
  )
  if is_spmd:
    # We want to emit the collective-broadcast with global device IDs and a
    # channel ID, as otherwise it interprets the devices as replicas instead
    # of partitions - and XLA is configured with only a single replica.
    channel_handle = hlo.ChannelHandle.get(mlir.COLLECTIVE_CHANNEL_ID,
                                           mlir.DEVICE_TO_DEVICE_TYPE)
    other_args: dict[str, Any] = dict(channel_handle=channel_handle)
  else:
    other_args = {}
  return hlo.CollectiveBroadcastOp(
      x, replica_groups=_replica_groups_hlo(replica_groups), **other_args
  ).results


pbroadcast_p = core.Primitive('pbroadcast')
pbroadcast_p.def_abstract_eval(_raise_to_shaped_abstract_eval)
ad.deflinear2(pbroadcast_p, _pbroadcast_transpose_rule)
mlir.register_lowering(pbroadcast_p, _pbroadcast_lowering, platform='gpu')
batching.fancy_primitive_batchers[pbroadcast_p] = _pbroadcast_batcher


def _moveaxis(src, dst, x):
  perm = [i for i in range(x.ndim) if i != src]
  perm.insert(dst, src)
  return lax.transpose(x, perm)


def _splitaxis(axis, factor, x):
  new_shape = list(x.shape)
  assert new_shape[axis] % factor == 0, (new_shape[axis], factor)
  new_shape[axis:axis+1] = [factor, new_shape[axis] // factor]
  return x.reshape(new_shape)


def _foldaxis(axis, x):
  new_shape = list(x.shape)
  new_shape[axis:axis+2] = [x.shape[axis] * x.shape[axis + 1]]
  return x.reshape(new_shape)
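
# Shape bookkeeping for the three helpers above (hedged examples; the values
# assume a rank-3 input ``x`` of shape (2, 6, 5)):
#   _moveaxis(0, 2, x)    : (2, 6, 5)    -> (6, 5, 2)
#   _splitaxis(1, 3, x)   : (2, 6, 5)    -> (2, 3, 2, 5)
#   _foldaxis(1, x_split) : (2, 3, 2, 5) -> (2, 6, 5)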
|
|
|
|
def _all_to_all_lowering(
|
|
ctx, x, *, split_axis, concat_axis, axis_name, axis_index_groups, tiled
|
|
):
|
|
del tiled # expand_dims and squeeze is done in `all_to_all` if `True`
|
|
# Workaround for AllToAll not being implemented on CPU.
|
|
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
|
|
axis_index_groups)
|
|
if len(replica_groups[0]) == 1:
|
|
return [x]
|
|
split_count = len(replica_groups[0])
|
|
if not all(split_count == len(g) for g in replica_groups):
|
|
raise ValueError('Replica groups must be equally sized')
|
|
is_spmd = isinstance(
|
|
ctx.module_context.axis_context,
|
|
(SPMDAxisContext, ShardingContext),
|
|
)
|
|
if is_spmd:
|
|
# We want to emit the all-gather with global device IDs and a
|
|
# channel ID, as otherwise it interprets the devices as replicas instead
|
|
# of partitions - and XLA is configured with only a single replica.
|
|
channel_handle = hlo.ChannelHandle.get(mlir.COLLECTIVE_CHANNEL_ID,
|
|
mlir.DEVICE_TO_DEVICE_TYPE)
|
|
other_args: dict[str, Any] = dict(channel_handle=channel_handle)
|
|
else:
|
|
other_args = {}
|
|
return hlo.AllToAllOp(
|
|
[x],
|
|
split_dimension=mlir.i64_attr(split_axis),
|
|
concat_dimension=mlir.i64_attr(concat_axis),
|
|
split_count=mlir.i64_attr(split_count),
|
|
replica_groups=_replica_groups_hlo(replica_groups),
|
|
**other_args).results
|
|
|
|
def _all_to_all_transpose_rule(
|
|
cts, x, axis_name, split_axis, concat_axis, axis_index_groups, tiled
|
|
):
|
|
return (all_to_all(
|
|
cts,
|
|
axis_name=axis_name,
|
|
split_axis=concat_axis,
|
|
concat_axis=split_axis,
|
|
axis_index_groups=axis_index_groups,
|
|
tiled=tiled),)
|
|
|
|
def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis,
                        concat_axis, axis_index_groups, tiled):
  x, = vals_in
  d, = dims_in
  result = all_to_all_p.bind(
      x,
      axis_name=axis_name,
      split_axis=split_axis + (d <= split_axis),
      concat_axis=concat_axis + (d <= concat_axis),
      axis_index_groups=axis_index_groups,
      tiled=tiled,
  )
  return result, d

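# A minimal worked example (assumed ranks, for exposition only) of the axis
# offsets above: inserting the batch dimension `d` at or before `split_axis`
# shifts every positional axis at or after `d` right by one, so the collective
# must act on `split_axis + 1`. For an unbatched operand of rank 3 with
# split_axis=1 and concat_axis=2, batching at d=0 gives
#
#   split_axis + (0 <= 1)  == 2
#   concat_axis + (0 <= 2) == 3
#
# while batching at d=3 (after both axes) leaves them unchanged.
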
def _all_to_all_batched_collective(axis_data, vals_in, dims_in,
|
|
axis_name, split_axis, concat_axis,
|
|
axis_index_groups, tiled):
|
|
if axis_index_groups is not None:
|
|
raise NotImplementedError("Please open a feature request!")
|
|
x, = vals_in
|
|
d, = dims_in
|
|
axis_size, frame_name = axis_data.size, axis_data.name
|
|
axes_names = axis_name if isinstance(axis_name, (list, tuple)) else [axis_name]
|
|
if d is None and frame_name not in axes_names:
|
|
out = all_to_all_p.bind(
|
|
x, axis_name=axis_name, split_axis=split_axis, concat_axis=concat_axis,
|
|
axis_index_groups=axis_index_groups, tiled=tiled)
|
|
return out, None
|
|
if frame_name not in axes_names:
|
|
return _all_to_all_batcher(
|
|
vals_in, dims_in, axis_name=axis_name, split_axis=split_axis,
|
|
concat_axis=concat_axis, axis_index_groups=axis_index_groups, tiled=tiled)
|
|
|
|
if d is batching.not_mapped:
|
|
# TODO(sharadmv,apaszke): Remove this broadcast that comes from
|
|
# all_gather_transpose and instead avoid using all_to_all in
|
|
# all_gather_transpose.
|
|
x = lax.broadcast(x, (axis_size, *x.shape))
|
|
d = 0
|
|
if isinstance(axis_name, (list, tuple)):
|
|
pos = axis_name.index(frame_name)
|
|
major_axes, minor_axes = axis_name[:pos], axis_name[pos + 1:]
|
|
else:
|
|
major_axes, minor_axes = (), ()
|
|
# Optimized case when no splitting is necessary
|
|
if not major_axes and not minor_axes:
|
|
if split_axis == concat_axis:
|
|
axis = split_axis + (d <= split_axis)
|
|
d_pre_split = d
|
|
x = _splitaxis(axis, axis_size, x)
|
|
d += (axis <= d)
|
|
return _foldaxis(axis, moveaxis(x, (d, axis), (axis, d))), d_pre_split
|
|
else:
|
|
x_concat = _foldaxis(concat_axis, _moveaxis(d, concat_axis, x))
|
|
return _splitaxis(split_axis, axis_size, x_concat), split_axis
|
|
# Here we have to handle either the major or the minor dimensions
|
|
# We will be accumulating chunks into the three leading dims: [Major, Current, Minor, ...]
|
|
x, d = lax.expand_dims(_moveaxis(d, 0, x), (0, 2)), 1
|
|
split_axis += 3; concat_axis += 3 # Offset by extra three leading dims
|
|
|
|
if major_axes:
|
|
x = all_to_all_p.bind(x, axis_name=major_axes,
|
|
split_axis=split_axis, concat_axis=0,
|
|
axis_index_groups=axis_index_groups,
|
|
tiled=tiled)
|
|
# Split out the local part into axis new_d (NOTE: d is already in axis 1)
|
|
assert d == 1
|
|
x = _splitaxis(split_axis, axis_size, x)
|
|
new_d = split_axis
|
|
concat_axis += (split_axis <= concat_axis) # Offset the existing axes by the new batch axis
|
|
split_axis += 1
|
|
if minor_axes:
|
|
x = all_to_all_p.bind(x, axis_name=minor_axes,
|
|
split_axis=split_axis, concat_axis=2,
|
|
axis_index_groups=axis_index_groups,
|
|
tiled=tiled)
|
|
|
|
# Fold the chunk axes into a single one
|
|
x = _foldaxis(0, _foldaxis(0, x))
|
|
split_axis -= 2; concat_axis -= 2; new_d -= 2
|
|
# Fold gathered axes into concat_axis
|
|
x = _foldaxis(concat_axis - 1, _moveaxis(0, concat_axis - 1, x))
|
|
new_d -= 1 # We've removed 0th dimension, so new_d needs to be adjusted
|
|
return x, new_d
|
|
|
|
|
|
def _all_to_all_effectful_abstract_eval(
    input_aval, axis_name, split_axis, concat_axis, axis_index_groups, tiled
):
  del tiled  # expand_dims and squeeze are done in `all_to_all` if `True`
  if isinstance(axis_name, list):
    axis_name = tuple(axis_name)
  elif not isinstance(axis_name, tuple):
    axis_name = (axis_name,)
  _check_axis_names(axis_name, 'all_to_all')
  check_unreduced_args([input_aval], axis_name, 'all_to_all')
  shape = list(input_aval.shape)
  axis_size = (
      _axis_size(axis_name)
      if axis_index_groups is None
      else len(axis_index_groups[0])
  )
  assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size)
  shape[split_axis] //= axis_size
  shape[concat_axis] *= axis_size
  vma = collective_vma_rule('all_to_all', axis_name, input_aval)
  out_aval = input_aval.update(
      shape=tuple(shape), weak_type=False,
      manual_axis_type=input_aval.mat.update(varying=vma))
  effects = {*map(core.NamedAxisEffect, axis_name)}
  return out_aval, effects
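
# A small worked example of the shape rule above (assumed values, for
# exposition only): for an input of shape (8, 2) with split_axis=0,
# concat_axis=1 and an axis of size 4, the output shape is
# (8 // 4, 2 * 4) == (2, 8). The split dimension must be divisible by the
# axis (or group) size, which the assert above enforces.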
|
|
|
|
def _all_to_all_impl(*args, **kwargs):
  raise RuntimeError("all_to_all must be used within a mapped context"
                     " like vmap or shard_map.")

all_to_all_p = core.Primitive('all_to_all')
all_to_all_p.def_impl(_all_to_all_impl)
all_to_all_p.def_effectful_abstract_eval(_all_to_all_effectful_abstract_eval)
mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
batching.fancy_primitive_batchers[all_to_all_p] = _all_to_all_batched_collective
|
|
|
|
|
|
def _ragged_all_to_all_lowering(
|
|
ctx, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
|
|
*, axis_name, axis_index_groups
|
|
):
|
|
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
|
|
axis_index_groups)
|
|
|
|
# Assumes all groups are the same size
|
|
split_count = len(replica_groups[0])
|
|
if not all(split_count == len(g) for g in replica_groups):
|
|
raise ValueError('Replica groups must be equally sized')
|
|
|
|
ragged_all_to_all_attrs: dict[str, ir.Attribute] = {
|
|
"replica_groups": _replica_groups_hlo(replica_groups)
|
|
}
|
|
is_spmd = isinstance(
|
|
ctx.module_context.axis_context, (SPMDAxisContext, ShardingContext))
|
|
if is_spmd:
|
|
ragged_all_to_all_attrs['channel_id'] = ir.IntegerAttr.get(
|
|
ir.IntegerType.get_signless(64), mlir.COLLECTIVE_CHANNEL_ID
|
|
)
|
|
|
|
return hlo.CustomCallOp(
|
|
result=[output.type],
|
|
inputs=[operand, output, input_offsets, send_sizes, output_offsets,
|
|
recv_sizes],
|
|
call_target_name=ir.StringAttr.get('ragged_all_to_all'),
|
|
backend_config=ir.DictAttr.get(ragged_all_to_all_attrs),
|
|
api_version=ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 4),
|
|
).results
|
|
|
|
def _ragged_all_to_all_effectful_abstract_eval(
|
|
operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
|
|
axis_name, axis_index_groups
|
|
):
|
|
del operand, axis_index_groups
|
|
if not dtypes.issubdtype(input_offsets.dtype, np.integer):
|
|
raise ValueError("ragged_all_to_all input_offsets must be integer type.")
|
|
if not dtypes.issubdtype(send_sizes.dtype, np.integer):
|
|
raise ValueError("ragged_all_to_all send_sizes must be integer type.")
|
|
if not dtypes.issubdtype(output_offsets.dtype, np.integer):
|
|
raise ValueError("ragged_all_to_all output_offsets must be integer type.")
|
|
if not dtypes.issubdtype(recv_sizes.dtype, np.integer):
|
|
raise ValueError("ragged_all_to_all recv_sizes must be integer type.")
|
|
if len(input_offsets.shape) != 1 or input_offsets.shape[0] < 1:
|
|
raise ValueError(
|
|
"ragged_all_to_all input_offsets must be rank 1 with positive dimension"
|
|
" size, but got shape {}".format(input_offsets.shape)
|
|
)
|
|
if len(send_sizes.shape) != 1 or send_sizes.shape[0] < 1:
|
|
raise ValueError(
|
|
"ragged_all_to_all send_sizes must be rank 1 with positive dimension"
|
|
" size, but got shape {}".format(send_sizes.shape)
|
|
)
|
|
if len(output_offsets.shape) != 1 or output_offsets.shape[0] < 1:
|
|
raise ValueError(
|
|
"ragged_all_to_all output_offsets must be rank 1 with positive"
|
|
" dimension size, but got shape {}".format(output_offsets.shape)
|
|
)
|
|
if len(recv_sizes.shape) != 1 or recv_sizes.shape[0] < 1:
|
|
raise ValueError(
|
|
"ragged_all_to_all recv_sizes must be rank 1 with positive dimension"
|
|
" size, but got shape {}".format(recv_sizes.shape)
|
|
)
|
|
|
|
_check_axis_names(axis_name, 'ragged_all_to_all')
|
|
out_aval = output.update(shape=output.shape, weak_type=False)
|
|
effects = {*map(core.NamedAxisEffect, axis_name)}
|
|
return out_aval, effects
|
|
|
|
def _ragged_all_to_all_jvp(primals, tangents, **params):
  operand, output, *sizes_and_offsets = primals
  operand_dot, output_dot, *_ = tangents
  result = ragged_all_to_all_p.bind(
      operand, output, *sizes_and_offsets, **params)
  if type(operand_dot) is type(output_dot) is ad.Zero:
    result_dot = ad.Zero.from_primal_value(result)
  else:
    operand_dot = ad.instantiate_zeros(operand_dot)
    output_dot = ad.instantiate_zeros(output_dot)
    result_dot = ragged_all_to_all_p.bind(
        operand_dot, output_dot, *sizes_and_offsets, **params)
  return result, result_dot
|
|
|
|
def _ragged_all_to_all_transpose(
|
|
t, operand, output, input_offsets, send_sizes, output_offsets, recv_sizes,
|
|
*, axis_name, axis_index_groups):
|
|
if type(t) is ad.Zero:
|
|
operand_t = ad.Zero(operand.aval) if ad.is_undefined_primal(operand) else None
|
|
output_t = ad.Zero(output.aval) if ad.is_undefined_primal(output) else None
|
|
else:
|
|
zero = ad.zeros_like_aval(operand.aval)
|
|
output_offsets_ = all_to_all(output_offsets, axis_name, 0, 0, tiled=True)
|
|
input_offsets_ = all_to_all(input_offsets, axis_name, 0, 0, tiled=True)
|
|
operand_t = ragged_all_to_all_p.bind(
|
|
t, zero, output_offsets_, recv_sizes, input_offsets_, send_sizes,
|
|
axis_name=axis_name, axis_index_groups=axis_index_groups)
|
|
mask = control_flow.cumsum(
|
|
lax.full(t.shape[0], 0, dtype='int32').at[output_offsets_].set(1)
|
|
.at[output_offsets_ + recv_sizes].add(-1))
|
|
mask = lax.expand_dims(mask, (*range(1, t.ndim),))
|
|
mask = lax.broadcast_in_dim(mask, shape=t.shape, broadcast_dimensions=tuple(range(t.ndim)))
|
|
output_t = lax.select(mask, lax._zeros(t), t)
|
|
return [operand_t, output_t] + [None] * 4
|
|
|
|
def _ragged_all_to_all_batched_collective(axis_data, vals_in, dims_in,
|
|
axis_name, axis_index_groups):
|
|
if all(bdim is None for bdim in dims_in) and axis_data.name not in axis_name:
|
|
out = ragged_all_to_all_p.bind(*vals_in, axis_name=axis_name,
|
|
axis_index_groups=axis_index_groups)
|
|
return out, None
|
|
if axis_data.name in axis_name:
|
|
raise NotImplementedError("Please open a feature request!")
|
|
if axis_index_groups:
|
|
raise NotImplementedError("Please open a feature request!")
|
|
size = axis_data.size
|
|
|
|
def bdim_at_second(x, d):
|
|
assert x.ndim == 2
|
|
return (batching.broadcast(x, size, 1, None) if d is None else
|
|
x if d == 1 else x.T)
|
|
def merge(x): return x.reshape(-1, *x.shape[2:])
|
|
def split(x): return x.reshape(size, -1, *x.shape[1:])
|
|
|
|
operand, output = map(partial(batching.bdim_at_front, size=size), vals_in[:2], dims_in[:2])
|
|
N, M = operand.shape[1], output.shape[1]
|
|
input_offsets, send_sizes, output_offsets, recv_sizes = \
|
|
map(bdim_at_second, vals_in[2:], dims_in[2:])
|
|
input_offsets += lax.iota(input_offsets.dtype, size)[None, :] * N
|
|
output_offsets += lax.iota(output_offsets.dtype, size)[None, :] * M
|
|
vals_in = operand, output, input_offsets, send_sizes, output_offsets, recv_sizes
|
|
result = split(ragged_all_to_all(*map(merge, vals_in), axis_name=axis_name))
|
|
return result, 0
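
# Sketch of the offset arithmetic above (hypothetical sizes, for exposition
# only): merging the batch dimension of size `size` into the leading ragged
# dimension turns per-batch row indices into flat indices, so each offset must
# be shifted by its batch's row stride. With size=2 and N=5, the iota
# broadcast adds 0 to the first column and 5 to the second, e.g.
# [[0, 1], [3, 4]] becomes [[0, 6], [3, 9]], i.e. offset + batch_index * N.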
|
|
|
|
def _ragged_all_to_all_impl(*args, **kwargs):
  raise RuntimeError("ragged_all_to_all must be used within a mapped context"
                     " like vmap or shard_map.")

ragged_all_to_all_p = core.Primitive('ragged_all_to_all')
ragged_all_to_all_p.def_impl(_ragged_all_to_all_impl)
ragged_all_to_all_p.def_effectful_abstract_eval(_ragged_all_to_all_effectful_abstract_eval)
ad.primitive_jvps[ragged_all_to_all_p] = _ragged_all_to_all_jvp
ad.primitive_transposes[ragged_all_to_all_p] = _ragged_all_to_all_transpose
mlir.register_lowering(ragged_all_to_all_p, _ragged_all_to_all_lowering)
batching.fancy_primitive_batchers[ragged_all_to_all_p] = _ragged_all_to_all_batched_collective
|
|
|
|
|
|
def insert_collective_pvary(axis_name, x):
  if not config._check_vma.value:
    return x

  axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
  aval = core.typeof(x)
  names_union = set(axis_name) | aval.mat.varying
  x = pvary(x, tuple(n for n in names_union if n not in aval.mat.varying))
  return x
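
# Illustrative sketch of the helper above (axis names assumed, for exposition
# only): inside `shard_map` over mesh axes ('i', 'j'), a value that is only
# varying over 'i' is promoted before a collective over both axes:
#
#   aval = core.typeof(x)            # aval.mat.varying == frozenset({'i'})
#   y = insert_collective_pvary(('i', 'j'), x)
#   core.typeof(y).mat.varying       # frozenset({'i', 'j'})
#
# Values already varying over every requested axis pass through unchanged.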
|
|
|
|
def all_gather(x, axis_name, *, axis_index_groups=None, axis=0, tiled=False,
|
|
to: str = 'varying'):
|
|
"""Gather values of x across all replicas.
|
|
|
|
If ``x`` is a pytree then the result is equivalent to mapping this function to
|
|
each leaf in the tree.
|
|
|
|
This is equivalent to, but faster than, all_to_all(broadcast(x)).
|
|
|
|
Args:
|
|
x: array(s) with a mapped axis named ``axis_name``.
|
|
axis_name: hashable Python object used to name a pmapped axis (see the
|
|
:func:`jax.pmap` documentation for more details).
|
|
axis_index_groups: optional list of lists containing axis indices (e.g. for
|
|
an axis of size 4, [[0, 1], [2, 3]] would run all gather over the first
|
|
two and last two replicas). Groups must cover all axis indices exactly
|
|
once, and all groups must be the same size.
|
|
axis: a positional axis into which the chunks along ``axis_name`` will be
|
|
concatenated.
|
|
tiled: when ``False``, the chunks will be stacked into a fresh positional
|
|
axis at index ``axis`` in the output. When ``True``, ``axis`` has to
|
|
refer to an existing positional dimension and the chunks will be
|
|
concatenated into that dimension.
|
|
|
|
Returns:
|
|
Array(s) representing the result of an all-gather along the axis
|
|
``axis_name``. Shapes are the same as ``x.shape``, but:
|
|
|
|
- when ``tiled`` is ``False``, there is a new dimension equal to the
|
|
size of axis ``axis_name`` in position ``axis``,
|
|
- when ``tiled`` is ``True``, the size of dimension in position ``axis``
|
|
is multiplied by the size of axis ``axis_name``.
|
|
|
|
For example, with 4 XLA devices available:
|
|
|
|
>>> x = np.arange(4)
|
|
>>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
|
|
>>> print(y)
|
|
[[0 1 2 3]
|
|
[0 1 2 3]
|
|
[0 1 2 3]
|
|
[0 1 2 3]]
|
|
|
|
An example of using axis_index_groups, groups split by even & odd device ids:
|
|
|
|
>>> x = np.arange(16).reshape(4, 4)
|
|
>>> print(x)
|
|
[[ 0 1 2 3]
|
|
[ 4 5 6 7]
|
|
[ 8 9 10 11]
|
|
[12 13 14 15]]
|
|
>>> def f(x):
|
|
... return jax.lax.all_gather(
|
|
... x, 'i', axis_index_groups=[[0, 2], [3, 1]])
|
|
>>> y = jax.pmap(f, axis_name='i')(x)
|
|
>>> print(y)
|
|
[[[ 0 1 2 3]
|
|
[ 8 9 10 11]]
|
|
[[12 13 14 15]
|
|
[ 4 5 6 7]]
|
|
[[ 0 1 2 3]
|
|
[ 8 9 10 11]]
|
|
[[12 13 14 15]
|
|
[ 4 5 6 7]]]
|
|
"""
|
|
return _all_gather_is_async(x, axis_name, axis_index_groups=axis_index_groups,
|
|
axis=axis, tiled=tiled, to=to, is_async=False)
|
|
|
|
def _all_gather_is_async(x, axis_name, *, axis_index_groups=None, axis=0,
|
|
tiled=False, to: str = 'varying', is_async: bool =
|
|
False):
|
|
_allowed_ag_to = {'varying', 'reduced', 'invarying'}
|
|
if to not in _allowed_ag_to:
|
|
raise ValueError(
|
|
"Got unexpected `to` value for `jax.lax.all_gather`. Allowed `to`"
|
|
f" values are: {_allowed_ag_to}")
|
|
if to == 'varying':
|
|
return _all_gather(x, axis_name, axis_index_groups=axis_index_groups,
|
|
axis=axis, tiled=tiled, is_async=is_async)
|
|
elif to == 'invarying':
|
|
if axis_index_groups is not None:
|
|
raise NotImplementedError
|
|
return all_gather_invariant(x, axis_name, axis=axis, tiled=tiled)
|
|
else:
|
|
assert to == 'reduced'
|
|
if axis_index_groups is not None:
|
|
raise NotImplementedError
|
|
return all_gather_reduced(x, axis_name, axis=axis, tiled=tiled,
|
|
is_async=is_async)
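
# Sketch of the `to` dispatch above (axis name assumed, for exposition only):
#
#   all_gather(x, 'i')                    # to='varying': Varying -> Varying
#   all_gather(x, 'i', to='invarying')    # Varying -> Invariant; transposes
#                                         # to a dynamic_slice
#   all_gather(x, 'i', to='reduced')      # Varying -> Reduced; transposes to
#                                         # unreduced_psum_scatter
#
# Only to='varying' currently supports axis_index_groups.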
|
|
|
|
|
|
def _all_gather(x, axis_name, *, axis_index_groups, axis, tiled, is_async):
|
|
if not isinstance(axis_name, tuple):
|
|
axis_name = (axis_name,)
|
|
if not axis_name:
|
|
return x
|
|
axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
|
|
axis_size = _axis_size(axis_name, axis_index_groups)
|
|
def bind(leaf):
|
|
leaf = insert_collective_pvary(axis_name, leaf)
|
|
prim = all_gather_start_p if is_async else all_gather_p
|
|
return prim.bind(
|
|
leaf,
|
|
all_gather_dimension=canonicalize_axis(
|
|
axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1),
|
|
axis_name=axis_name, axis_index_groups=axis_index_groups,
|
|
axis_size=axis_size, tiled=tiled)
|
|
return tree_util.tree_map(bind, x)
|
|
|
|
def _all_gather_impl(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
|
|
raise AssertionError("Unexpected call to _all_gather_impl")
|
|
|
|
def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
|
|
axis_index_groups, axis_size, tiled,
|
|
platform=None, is_async=False):
|
|
x_aval, = ctx.avals_in
|
|
out_aval, = ctx.avals_out
|
|
if is_async:
|
|
out_aval = out_aval.inner_aval
|
|
axis_context = ctx.module_context.axis_context
|
|
is_spmd = isinstance(axis_context, (SPMDAxisContext, ShardingContext))
|
|
if not tiled:
|
|
new_shape = list(x_aval.shape)
|
|
new_shape.insert(all_gather_dimension, 1)
|
|
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
|
|
x = hlo.broadcast_in_dim(
|
|
mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
|
|
mlir.dense_int_array(broadcast_dimensions))
|
|
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
|
|
axis_index_groups)
|
|
if is_spmd:
|
|
# We want to emit the all-gather with global device IDs and a
|
|
# channel ID, as otherwise it interprets the devices as replicas instead
|
|
# of partitions - and XLA is configured with only a single replica.
|
|
other_args: dict[str, Any] = dict(
|
|
channel_handle=hlo.ChannelHandle.get(
|
|
mlir.COLLECTIVE_CHANNEL_ID, mlir.DEVICE_TO_DEVICE_TYPE),
|
|
use_global_device_ids=ir.BoolAttr.get(True))
|
|
else:
|
|
other_args = {}
|
|
|
|
out_type = mlir.aval_to_ir_type(out_aval)
|
|
if not is_async:
|
|
return hlo.AllGatherOp(
|
|
[out_type],
|
|
[x], all_gather_dim=mlir.i64_attr(all_gather_dimension),
|
|
replica_groups=_replica_groups_hlo(replica_groups),
|
|
**other_args).results
|
|
|
|
future_type = hlo.FutureType.get([out_type])
|
|
async_start = hlo.AsyncStartOp(future_type, [x])
|
|
block = async_start.regions[0].blocks.append(x.type)
|
|
with ir.InsertionPoint(block):
|
|
results = hlo.AllGatherOp(
|
|
[out_type],
|
|
[block.arguments[0]],
|
|
all_gather_dim=mlir.i64_attr(all_gather_dimension),
|
|
replica_groups=_replica_groups_hlo(replica_groups),
|
|
**other_args,
|
|
).results
|
|
hlo.return_(results)
|
|
return async_start.results
|
|
|
|
|
|
def collective_vma_rule(prim_name, axis_name, x_aval):
|
|
if not config._check_vma.value:
|
|
return frozenset()
|
|
axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
|
|
if any(a not in x_aval.mat.varying for a in axis_name):
|
|
raise ValueError(
|
|
f"Collective {prim_name} must be applied to a device-varying "
|
|
f" type, but got {x_aval.mat.varying} for collective acting "
|
|
f"over axis name {axis_name}. Please open an issue at "
|
|
"https://github.com/jax-ml/jax/issues and as a temporary "
|
|
"workaround pass the check_vma=False argument to `jax.shard_map`")
|
|
return x_aval.mat.varying
|
|
|
|
def _all_gather_effectful_abstract_eval(
|
|
x_aval, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
|
|
):
|
|
if not isinstance(axis_name, (list, tuple)):
|
|
axis_name = (axis_name,)
|
|
_check_axis_names(axis_name, 'all_gather')
|
|
check_unreduced_args([x_aval], axis_name, 'all_gather')
|
|
new_shape = list(x_aval.shape)
|
|
if tiled:
|
|
new_shape[all_gather_dimension] *= axis_size
|
|
else:
|
|
new_shape.insert(all_gather_dimension, axis_size)
|
|
out_vma = collective_vma_rule('all_gather', axis_name, x_aval)
|
|
return (x_aval.update(shape=new_shape,
|
|
manual_axis_type=x_aval.mat.update(varying=out_vma)),
|
|
{*map(core.NamedAxisEffect, axis_name)})
|
|
|
|
def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name,
|
|
axis_index_groups, axis_size, tiled):
|
|
return (psum_scatter(cts, axis_name=axis_name,
|
|
scatter_dimension=all_gather_dimension,
|
|
axis_index_groups=axis_index_groups,
|
|
tiled=tiled),)
|
|
|
|
def _all_gather_batcher(prim, vals_in, dims_in, *, all_gather_dimension, axis_name,
|
|
axis_index_groups, axis_size, tiled):
|
|
(x,), (d,) = vals_in, dims_in
|
|
if d is not batching.not_mapped:
|
|
if d <= all_gather_dimension:
|
|
all_gather_dimension += 1
|
|
elif not tiled: # Tiled all-gather doesn't modify the set of dimensions
|
|
d += 1
|
|
if prim is all_gather_p:
|
|
result = all_gather_p.bind(
|
|
x, all_gather_dimension=all_gather_dimension, axis_name=axis_name,
|
|
axis_index_groups=axis_index_groups, axis_size=axis_size,
|
|
tiled=tiled)
|
|
return result, d
|
|
else:
|
|
assert prim is all_gather_invariant_p
|
|
result = all_gather_invariant_p.bind(
|
|
x, all_gather_dimension=all_gather_dimension, axis_name=axis_name,
|
|
axis_size=axis_size, tiled=tiled)
|
|
return result, d
|
|
|
|
def _all_gather_batched_collective(prim, axis_data, vals_in, dims_in,
|
|
all_gather_dimension, axis_name,
|
|
axis_index_groups, axis_size, tiled):
|
|
frame_size, frame_name = axis_data.size, axis_data.name
|
|
if not isinstance(axis_name, tuple):
|
|
axis_name = (axis_name,)
|
|
(x,), (d,) = vals_in, dims_in
|
|
if d is None and axis_data.name not in axis_name:
|
|
kwargs = dict(all_gather_dimension=all_gather_dimension, axis_name=axis_name,
|
|
axis_size=axis_size, tiled=tiled)
|
|
out = (prim.bind(x, axis_index_groups=axis_index_groups, **kwargs)
|
|
if prim is all_gather_p else prim.bind(x, **kwargs))
|
|
return out, None
|
|
if frame_name not in axis_name:
|
|
return _all_gather_batcher(
|
|
prim, vals_in, dims_in, all_gather_dimension=all_gather_dimension,
|
|
axis_name=axis_name, axis_index_groups=axis_index_groups,
|
|
axis_size=axis_size, tiled=tiled)
|
|
if axis_index_groups is not None:
|
|
raise NotImplementedError("axis_index_groups not supported in vmap")
|
|
assert axis_size == frame_size, "axis size doesn't match"
|
|
if len(axis_name) > 1:
|
|
raise NotImplementedError("Please open a feature request!")
|
|
assert axis_name == (frame_name,), "batcher called with wrong axis name"
|
|
if d is batching.not_mapped:
|
|
out_shape = list(np.shape(x))
|
|
out_shape.insert(all_gather_dimension, axis_size)
|
|
broadcast_dims = [i for i in range(len(out_shape)) if i != all_gather_dimension]
|
|
y = lax.broadcast_in_dim(x, out_shape, broadcast_dims)
|
|
else:
|
|
y = _moveaxis(d, all_gather_dimension, x)
|
|
if tiled:
|
|
y = _foldaxis(all_gather_dimension, y)
|
|
return y, batching.not_mapped
|
|
|
|
all_gather_p = core.Primitive('all_gather')
|
|
all_gather_p.def_effectful_abstract_eval(_all_gather_effectful_abstract_eval)
|
|
all_gather_p.def_impl(_all_gather_impl)
|
|
mlir.register_lowering(all_gather_p, _all_gather_lowering)
|
|
for p in ("cuda", "rocm", "tpu"):
|
|
mlir.register_lowering(all_gather_p,
|
|
partial(_all_gather_lowering, platform=p),
|
|
platform=p)
|
|
ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
|
|
batching.fancy_primitive_batchers[all_gather_p] = partial(
|
|
_all_gather_batched_collective, all_gather_p)
|
|
|
|
|
|
def all_gather_invariant(x, axis_name, *, axis: int = 0, tiled: bool = False):
|
|
"""Gather values of x across all replicas.
|
|
|
|
If ``x`` is a pytree then the result is equivalent to mapping this function to
|
|
each leaf in the tree.
|
|
|
|
all_gather_invariant differs from all_gather in the following ways:
|
|
|
|
* all_gather_invariant is Varying -> Invariant.
|
|
For example: `out: f32[8] = all_gather_invariant(inp: f32[4]{V: x}, 'x')`
|
|
where the size of mesh axis `x` is 2.
|
|
While all_gather is Varying -> Varying.
|
|
|
|
* all_gather_invariant transposes to dynamic_slice which is
|
|
Invariant -> Varying. While all_gather transposes to reduce_scatter
|
|
which is Varying -> Varying.
|
|
"""
|
|
if not isinstance(axis_name, tuple):
|
|
axis_name = (axis_name,)
|
|
if not axis_name:
|
|
return x
|
|
axis_size = _axis_size(axis_name, None)
|
|
axes_ = frozenset(axis_name)
|
|
def bind(leaf):
|
|
in_vma = core.typeof(leaf).mat.varying
|
|
if vary_names := axes_ - in_vma:
|
|
leaf = pvary(leaf, tuple(vary_names))
|
|
return all_gather_invariant_p.bind(
|
|
leaf,
|
|
all_gather_dimension=canonicalize_axis(axis, np.ndim(leaf) if tiled else
|
|
np.ndim(leaf) + 1),
|
|
axis_name=axis_name, axis_size=axis_size, tiled=tiled)
|
|
return tree_util.tree_map(bind, x)
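
# Minimal usage sketch (mesh, axis size and shapes assumed, for exposition
# only) contrasting the two gathers inside shard_map over a size-2 axis 'x',
# with a f32[4] shard `a`:
#
#   all_gather(a, 'x', tiled=True)            # f32[8], still varying over 'x'
#   all_gather_invariant(a, 'x', tiled=True)  # f32[8], invariant over 'x'
#
# Both produce the full array locally; they differ in the variance of the
# result type and therefore in how they transpose (reduce-scatter vs.
# dynamic_slice).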
|
|
|
|
all_gather_invariant_p = core.Primitive('all_gather_invariant')
|
|
|
|
def _all_gather_invariant_effectful_abstract_eval(
|
|
x_aval, *, all_gather_dimension, axis_name, axis_size, tiled
|
|
):
|
|
_check_axis_names(axis_name, 'all_gather_invariant')
|
|
check_unreduced_args([x_aval], axis_name, 'all_gather_invariant')
|
|
new_shape = list(x_aval.shape)
|
|
if tiled:
|
|
new_shape[all_gather_dimension] *= axis_size
|
|
else:
|
|
new_shape.insert(all_gather_dimension, axis_size)
|
|
out_vma = frozenset(v for v in x_aval.mat.varying if v not in axis_name)
|
|
return (x_aval.update(shape=new_shape,
|
|
manual_axis_type=x_aval.mat.update(varying=out_vma)),
|
|
{*map(core.NamedAxisEffect, axis_name)})
|
|
|
|
all_gather_invariant_p.def_effectful_abstract_eval(
|
|
_all_gather_invariant_effectful_abstract_eval)
|
|
|
|
def _all_gather_invariant_impl(x, *, all_gather_dimension, axis_name, axis_size,
|
|
tiled):
|
|
raise NotImplementedError
|
|
all_gather_invariant_p.def_impl(_all_gather_invariant_impl)
|
|
|
|
|
|
def _all_gather_invariant_lowering(
|
|
ctx, x, *, all_gather_dimension, axis_name, axis_size, tiled, platform=None):
|
|
return _all_gather_lowering(
|
|
ctx, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name,
|
|
axis_index_groups=None, axis_size=axis_size, tiled=tiled,
|
|
platform=platform)
|
|
|
|
mlir.register_lowering(all_gather_invariant_p, _all_gather_invariant_lowering)
|
|
for p in ("cuda", "rocm", "tpu"):
|
|
mlir.register_lowering(all_gather_invariant_p,
|
|
partial(_all_gather_invariant_lowering, platform=p),
|
|
platform=p)
|
|
|
|
def _all_gather_invariant_transpose_rule(
|
|
cts, x, *, all_gather_dimension, axis_name, axis_size, tiled):
|
|
slice_size, rem = divmod(cts.shape[all_gather_dimension], axis_size)
|
|
assert not rem
|
|
idx = axis_index(axis_name) * slice_size
|
|
out = slicing.dynamic_slice_in_dim(
|
|
cts, idx, slice_size=slice_size, axis=all_gather_dimension)
|
|
return (out,) if tiled else (lax.squeeze(out, [all_gather_dimension]),)
|
|
ad.deflinear2(all_gather_invariant_p, _all_gather_invariant_transpose_rule)
|
|
|
|
def _all_gather_invariant_batched_collective(
|
|
axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_size,
|
|
tiled):
|
|
return _all_gather_batched_collective(
|
|
all_gather_invariant_p, axis_data, vals_in, dims_in, all_gather_dimension,
|
|
axis_name, None, axis_size, tiled)
|
|
batching.fancy_primitive_batchers[all_gather_invariant_p] = _all_gather_invariant_batched_collective
|
|
|
|
|
|
def _reduce_scatter_lowering(
|
|
prim, ctx, x,
|
|
*, scatter_dimension, axis_name,
|
|
axis_index_groups, axis_size, tiled):
|
|
x_aval, = ctx.avals_in
|
|
aval_out, = ctx.avals_out
|
|
scalar_aval = x_aval.update(shape=())
|
|
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
|
|
axis_index_groups)
|
|
scatter_out_shape = list(x_aval.shape)
|
|
scatter_out_shape[scatter_dimension] //= axis_size
|
|
axis_context = ctx.module_context.axis_context
|
|
is_spmd = isinstance(
|
|
axis_context,
|
|
(SPMDAxisContext, ShardingContext),
|
|
)
|
|
  if is_spmd:
    # We want to emit the reduce-scatter with global device IDs and a
    # channel ID, as otherwise it interprets the devices as replicas instead
    # of partitions - and XLA is configured with only a single replica.
    other_args: dict[str, Any] = dict(
        channel_handle=hlo.ChannelHandle.get(
            mlir.COLLECTIVE_CHANNEL_ID, mlir.DEVICE_TO_DEVICE_TYPE),
        use_global_device_ids=ir.BoolAttr.get(True))
  else:
    other_args = {}
op = hlo.ReduceScatterOp(
|
|
mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)),
|
|
x,
|
|
scatter_dimension=mlir.i64_attr(scatter_dimension),
|
|
replica_groups=_replica_groups_hlo(replica_groups),
|
|
**other_args)
|
|
scalar_type = mlir.aval_to_ir_type(scalar_aval)
|
|
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
|
|
with ir.InsertionPoint(reducer_block):
|
|
lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
|
|
reducer_ctx = ctx.replace(primitive=None,
|
|
avals_in=[scalar_aval] * 2,
|
|
avals_out=[scalar_aval])
|
|
out_nodes = lower_reducer(reducer_ctx, *reducer_block.arguments)
|
|
hlo.return_(mlir.flatten_ir_values(out_nodes))
|
|
|
|
if tiled:
|
|
return op.results
|
|
else:
|
|
out_type = mlir.aval_to_ir_type(aval_out)
|
|
return [hlo.reshape(out_type, op.result)]
|
|
|
|
|
|
def _reduce_scatter_effectful_abstract_eval(
|
|
x_aval, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled
|
|
):
|
|
if not isinstance(axis_name, (list, tuple)):
|
|
axis_name = (axis_name,)
|
|
_check_axis_names(axis_name, 'reduce_scatter')
|
|
check_unreduced_args([x_aval], axis_name, 'reduce_scatter')
|
|
new_shape = list(x_aval.shape)
|
|
scatter_dim_input_size = x_aval.shape[scatter_dimension]
|
|
if tiled:
|
|
if scatter_dim_input_size % axis_size != 0:
|
|
raise ValueError(f"tiled reduce_scatter operand scatter dimension size "
|
|
f"{scatter_dim_input_size} must be divisible by "
|
|
f"shard_count {axis_size}")
|
|
new_shape[scatter_dimension] = scatter_dim_input_size // axis_size
|
|
else:
|
|
if scatter_dim_input_size != axis_size:
|
|
raise ValueError(f"reduce_scatter operand scatter dimension size "
|
|
f"{scatter_dim_input_size} must match shard count "
|
|
f"{axis_size}")
|
|
del new_shape[scatter_dimension]
|
|
vma = collective_vma_rule('reduce_scatter', axis_name, x_aval)
|
|
return (x_aval.update(shape=new_shape,
|
|
manual_axis_type=x_aval.mat.update(varying=vma)),
|
|
{*map(core.NamedAxisEffect, axis_name)})
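
# Worked example of the shape rule above (assumed values, for exposition
# only): over an axis of size 4, a tiled reduce-scatter of an f32[8, 3]
# operand along scatter_dimension=0 yields f32[2, 3]; the untiled form
# requires the scattered dimension to equal the axis size, so f32[4, 3]
# yields f32[3] with dimension 0 removed.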
|
|
|
|
|
|
def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension,
|
|
axis_index_groups, axis_size, tiled):
|
|
return (all_gather(cts, axis_name=axis_name,
|
|
axis_index_groups=axis_index_groups,
|
|
axis=scatter_dimension, tiled=tiled),)
|
|
|
|
|
|
def _reduce_scatter_batcher(vals_in, dims_in, *, scatter_dimension, axis_name,
|
|
axis_index_groups, axis_size, tiled):
|
|
(x,), (d,) = vals_in, dims_in
|
|
if d <= scatter_dimension:
|
|
scatter_dimension += 1
|
|
  elif not tiled:  # Tiled reduce-scatter doesn't change the rank
|
|
d += 1
|
|
result = reduce_scatter_p.bind(
|
|
x,
|
|
scatter_dimension=scatter_dimension,
|
|
axis_name=axis_name,
|
|
axis_index_groups=axis_index_groups,
|
|
axis_size=axis_size,
|
|
tiled=tiled)
|
|
return result, d
|
|
|
|
def _reduce_scatter_collective(axis_data, vals_in, dims_in,
|
|
scatter_dimension, axis_name,
|
|
axis_index_groups, axis_size, tiled):
|
|
frame_size, frame_name = axis_data.size, axis_data.name
|
|
if not isinstance(axis_name, tuple):
|
|
axis_name = (axis_name,)
|
|
(x,), (d,) = vals_in, dims_in
|
|
if d is None and frame_name not in axis_name:
|
|
out = reduce_scatter_p.bind(
|
|
x, scatter_dimension=scatter_dimension, axis_name=axis_name,
|
|
axis_index_groups=axis_index_groups, axis_size=axis_size,
|
|
tiled=tiled)
|
|
return out, None
|
|
|
|
if frame_name not in axis_name:
|
|
return _reduce_scatter_batcher(
|
|
vals_in, dims_in, scatter_dimension=scatter_dimension,
|
|
axis_name=axis_name, axis_index_groups=axis_index_groups,
|
|
axis_size=axis_size, tiled=tiled)
|
|
if axis_index_groups is not None:
|
|
raise NotImplementedError("axis_index_groups not supported in vmap")
|
|
assert axis_size == frame_size, "axis size doesn't match"
|
|
if len(axis_name) > 1:
|
|
raise NotImplementedError("Please open a feature request!")
|
|
assert axis_name == (frame_name,), "batcher called with wrong axis name"
|
|
if d is batching.not_mapped:
|
|
y, dy = x * axis_size, scatter_dimension
|
|
else:
|
|
y, dy = lax.reduce(x, 0., lax.add, (d,)), scatter_dimension
|
|
if tiled:
|
|
y = _splitaxis(dy, axis_size, y)
|
|
return y, dy
|
|
|
|
|
|
reduce_scatter_p = core.Primitive("reduce_scatter")
|
|
reduce_scatter_p.def_effectful_abstract_eval(
|
|
_reduce_scatter_effectful_abstract_eval
|
|
)
|
|
ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
|
|
batching.fancy_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
|
|
|
|
mlir.register_lowering(reduce_scatter_p,
|
|
partial(_reduce_scatter_lowering, lax.add_p))
|
|
|
|
def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None,
|
|
tiled=False):
|
|
"""
|
|
Like ``psum(x, axis_name)`` but each device retains only part of the result.
|
|
|
|
For example, ``psum_scatter(x, axis_name, scatter_dimension=0, tiled=False)``
|
|
computes the same value as ``psum(x, axis_name)[axis_index(axis_name)]``, but
|
|
it is more efficient. Thus the ``psum`` result is left scattered along the
|
|
mapped axis.
|
|
|
|
One efficient algorithm for computing ``psum(x, axis_name)`` is to perform a
|
|
``psum_scatter`` followed by an ``all_gather``, essentially evaluating
|
|
``all_gather(psum_scatter(x, axis_name))``. So we can think of
|
|
``psum_scatter`` as "the first half" of a ``psum``.
|
|
|
|
Args:
|
|
x: array(s) with a mapped axis named ``axis_name``.
|
|
axis_name: hashable Python object used to name a mapped axis (see the
|
|
:func:`jax.pmap` documentation for more details).
|
|
scatter_dimension: a positional axis into which the all-reduce result along
|
|
``axis_name`` will be scattered.
|
|
axis_index_groups: optional list of lists of integers containing axis
|
|
indices. For example, for an axis of size 4,
|
|
``axis_index_groups=[[0, 1], [2, 3]]`` would run reduce-scatter over the
|
|
first two and the last two axis indices. Groups must cover all axis
|
|
indices exactly once, and all groups must be the same size.
|
|
tiled: boolean representing whether to use rank-preserving 'tiled' behavior.
|
|
When ``False`` (the default value), the size of dimension in
|
|
``scatter_dimension`` must match the size of axis ``axis_name`` (or the
|
|
group size if ``axis_index_groups`` is given). After scattering the
|
|
all-reduce result along ``scatter_dimension``, the output is squeezed by
|
|
removing ``scatter_dimension``, so the result has lower rank than the
|
|
input. When ``True``, the size of dimension in ``scatter_dimension`` must
|
|
be divisible by the size of axis ``axis_name`` (or the group size if
|
|
``axis_index_groups`` is given), and the ``scatter_dimension`` axis is
|
|
preserved (so the result has the same rank as the input).
|
|
|
|
Returns:
|
|
Array(s) with the similar shape as ``x``, except the size of dimension in
|
|
position ``scatter_dimension`` is divided by the size of axis ``axis_name``
|
|
(when ``tiled=True``), or the dimension in position ``scatter_dimension`` is
|
|
eliminated (when ``tiled=False``).
|
|
|
|
For example, with 4 XLA devices available:
|
|
|
|
>>> x = np.arange(16).reshape(4, 4)
|
|
>>> print(x)
|
|
[[ 0 1 2 3]
|
|
[ 4 5 6 7]
|
|
[ 8 9 10 11]
|
|
[12 13 14 15]]
|
|
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i'), axis_name='i')(x)
|
|
>>> print(y)
|
|
[24 28 32 36]
|
|
|
|
if using tiled:
|
|
|
|
>>> y = jax.pmap(lambda x: jax.lax.psum_scatter(x, 'i', tiled=True), axis_name='i')(x)
|
|
>>> print(y)
|
|
[[24]
|
|
[28]
|
|
[32]
|
|
[36]]
|
|
|
|
An example of using axis_index_groups:
|
|
|
|
>>> def f(x):
|
|
... return jax.lax.psum_scatter(
|
|
... x, 'i', axis_index_groups=[[0, 2], [3, 1]], tiled=True)
|
|
>>> y = jax.pmap(f, axis_name='i')(x)
|
|
>>> print(y)
|
|
[[ 8 10]
|
|
[20 22]
|
|
[12 14]
|
|
[16 18]]
|
|
"""
|
|
return _psum_scatter_is_async(x, axis_name,
|
|
scatter_dimension=scatter_dimension,
|
|
axis_index_groups=axis_index_groups,
|
|
tiled=tiled, is_async=False)
|
|
|
|
def _psum_scatter_is_async(x, axis_name, *, scatter_dimension=0,
|
|
axis_index_groups=None, tiled=False, is_async=False):
|
|
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
|
|
# TODO(yashkatariya): Remove this handling and remove_size_one_mesh_axis_from_type
|
|
# generally from JAX.
|
|
axes = _maybe_skip_one_sized_axes(axes)
|
|
if not axes:
|
|
return x
|
|
def bind(leaf):
|
|
from_ = _get_from(core.typeof(leaf), axes, 'jax.lax.psum_scatter')
|
|
if from_ == 'unreduced':
|
|
if axis_index_groups is not None:
|
|
raise NotImplementedError
|
|
return unreduced_psum_scatter(
|
|
leaf, axes, scatter_dimension=scatter_dimension, tiled=tiled,
|
|
is_async=is_async)
|
|
else:
|
|
return _psum_scatter(leaf, axes, scatter_dimension=scatter_dimension,
|
|
axis_index_groups=axis_index_groups, tiled=tiled,
|
|
is_async=is_async)
|
|
return tree_util.tree_map(bind, x)
|
|
|
|
def _psum_scatter(x, axis_name, *, scatter_dimension, axis_index_groups, tiled,
|
|
is_async):
|
|
if not isinstance(axis_name, tuple):
|
|
axis_name = (axis_name,)
|
|
if not axis_name:
|
|
return x
|
|
axis_size = _axis_size(axis_name, axis_index_groups)
|
|
axis_index_groups = _canonicalize_axis_index_groups(axis_index_groups)
|
|
def bind(leaf):
|
|
leaf = insert_collective_pvary(axis_name, leaf)
|
|
prim = reduce_scatter_start_p if is_async else reduce_scatter_p
|
|
return prim.bind(
|
|
leaf, axis_name=axis_name, scatter_dimension=scatter_dimension,
|
|
axis_index_groups=axis_index_groups, axis_size=axis_size, tiled=tiled)
|
|
return tree_util.tree_map(bind, x)
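
# Identity used throughout this file (sketch, shapes assumed for exposition):
# scattering the sum and then gathering the pieces reconstitutes a full psum,
# which is why all_gather is the transpose partner of psum_scatter:
#
#   psum(x, 'i') == all_gather(psum_scatter(x, 'i', tiled=True), 'i',
#                              tiled=True)
#
# whenever the size of x along dimension 0 is divisible by the size of 'i'.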
|
|
|
|
|
|
def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
|
|
from jax._src.shard_map import shard_map # pytype: disable=import-error
|
|
|
|
if isinstance(axis_name, tuple):
|
|
assert axis_name, 'empty axis name'
|
|
if len(axis_name) > 1:
|
|
raise NotImplementedError(
|
|
'`axis_index` translation rule does not support multiple axis names.')
|
|
axis_name, = axis_name
|
|
if axis_name not in axis_env.names:
|
|
raise NameError(f"unbound axis name: {axis_name}")
|
|
axis_context = ctx.module_context.axis_context
|
|
axis_pos = list(axis_env.names).index(axis_name)
|
|
|
|
# For partial auto, enter into a fully manual shard_map.
|
|
if (isinstance(axis_context, SPMDAxisContext) and
|
|
axis_context.manual_axes and
|
|
axis_context.manual_axes != frozenset(axis_context.mesh.axis_names)):
|
|
if axis_env.sizes[axis_pos] == 1:
|
|
return hlo.constant(ir.DenseElementsAttr.get(np.asarray(0, dtype=np.int32))) # pyrefly: ignore[no-matching-overload]
|
|
def f():
|
|
return axis_index_p.bind(axis_name=axis_name)
|
|
return mlir.lower_fun(
|
|
lambda: [shard_map(f, check_vma=False, in_specs=(),
|
|
out_specs=P())()])(ctx)[0]
|
|
|
|
nreplicas = axis_env.nreps // math.prod(axis_env.sizes)
|
|
div = mlir.ir_constant(
|
|
np.array(
|
|
nreplicas * math.prod(axis_env.sizes[axis_pos + 1 :]), dtype=np.uint32
|
|
)
|
|
)
|
|
mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
|
|
if isinstance(axis_context, (ShardingContext, SPMDAxisContext)):
|
|
device_id = hlo.partition_id()
|
|
else:
|
|
device_id = hlo.replica_id()
|
|
unsigned_index = hlo.remainder(hlo.divide(device_id, div), mod)
|
|
return hlo.convert(
|
|
ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)),
|
|
unsigned_index)
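
# Worked example of the index arithmetic above (axis names and sizes assumed,
# for exposition only): with nreps equal to the product of the axis sizes
# (so nreplicas == 1) and axis_env.sizes == (2, 4), the minor axis
# (position 1) uses div = 1, mod = 4 and the major axis (position 0) uses
# div = 4, mod = 2, so
#
#   axis_index('minor') == device_id % 4
#   axis_index('major') == (device_id // 4) % 2
#
# e.g. device 6 maps to (major, minor) == (1, 2).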
|
|
|
|
def _axis_index_lowering(ctx, *, axis_name):
|
|
return [_build_axis_index_lowering_hlo(ctx, axis_name,
|
|
ctx.module_context.axis_env)]
|
|
|
|
def _axis_index_effectful_abstract_eval(*, axis_name):
|
|
effect = {core.NamedAxisEffect(axis_name)}
|
|
axis_name = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
|
|
_check_axis_names(axis_name, 'axis_index')
|
|
mesh = get_abstract_mesh()
|
|
sharding = NamedSharding(mesh, P())
|
|
vma = ((frozenset(axis_name) if mesh._any_axis_manual else frozenset())
|
|
if config._check_vma.value else frozenset())
|
|
out_mat = core.ManualAxisType(varying=vma)
|
|
out_aval = ShapedArray((), np.int32, sharding=sharding, manual_axis_type=out_mat)
|
|
return out_aval, effect
|
|
|
|
def _axis_index_batcher(axis_data, vals_in, dims_in, *, axis_name):
|
|
axes = tuple(axis_name) if isinstance(axis_name, (tuple, list)) else (axis_name,)
|
|
if axis_data.name not in axes:
|
|
return axis_index_p.bind(axis_name=axis_name), None
|
|
return lax.iota(np.int32, axis_data.size), 0
|
|
|
|
axis_index_p = core.Primitive('axis_index')
|
|
axis_index_p.def_impl(partial(dispatch.apply_primitive, axis_index_p))
|
|
mlir.register_lowering(axis_index_p, _axis_index_lowering)
|
|
axis_index_p.def_effectful_abstract_eval(_axis_index_effectful_abstract_eval)
|
|
batching.fancy_primitive_batchers[axis_index_p] = _axis_index_batcher
|
|
|
|
######################## psum_invariant_p ####################################
|
|
|
|
def bind_psum_invariant(leaf, *, axes, axis_index_groups, is_async):
|
|
if axis_index_groups is not None:
|
|
raise NotImplementedError
|
|
axes_ = frozenset(axes)
|
|
in_vma = core.typeof(leaf).mat.varying
|
|
arg = (pvary(leaf, tuple(pbroadcast_names))
|
|
if (pbroadcast_names := axes_ - in_vma) else leaf)
|
|
prim = psum_invariant_start_p if is_async else psum_invariant_p
|
|
return prim.bind(arg, axes=axes)
|
|
|
|
psum_invariant_p = core.Primitive('psum_invariant')
|
|
|
|
def _psum_invariant_impl(arg, *, axes):
|
|
return _allreduce_impl(psum_invariant_p, lax.reduce_sum, arg, axes=axes,
|
|
axis_index_groups=None)
|
|
psum_invariant_p.def_impl(_psum_invariant_impl)
|
|
|
|
def _psum_invariant_abstract_eval(name, aval, *, axes):
|
|
assert isinstance(axes, tuple)
|
|
_check_axis_names(axes, 'psum')
|
|
if not set(axes).intersection(aval.mat.varying):
|
|
raise ValueError(
|
|
"psum is a variant->invariant collective. This means that the axis"
|
|
" names mentioned in `axes` passed to `psum` must be present in"
|
|
f" `jax.typeof(inp).mat.varying`. Got axes={axes} and"
|
|
f" jax.typeof(inp).mat.varying={aval.mat.varying}")
|
|
if any(isinstance(a, int) for a in axes):
|
|
raise ValueError(f'psum_invariant does not accept integer axes. Got {axes}')
|
|
|
|
named_axes = tuple(axis for axis in axes if not isinstance(axis, int))
|
|
core.check_avals_context_mesh([aval], name)
|
|
check_unreduced_args([aval], axes, name)
|
|
vma = frozenset(a for a in aval.mat.varying if a not in named_axes)
|
|
out_aval = aval.update(manual_axis_type=aval.mat.update(varying=vma))
|
|
return out_aval, {core.NamedAxisEffect(axis) for axis in named_axes}
|
|
psum_invariant_p.def_effectful_abstract_eval(
|
|
partial(_psum_invariant_abstract_eval, psum_invariant_p.name))
|
|
|
|
def _psum_invariant_lowering_rule(ctx, arg, *, axes):
|
|
return _allreduce_lowering(lax.add_p, lax.reduce_sum, ctx, arg, axes=axes,
|
|
axis_index_groups=None)
|
|
mlir.register_lowering(psum_invariant_p, _psum_invariant_lowering_rule)
|
|
|
|
def _psum_invariant_batching_rule(axis_data, vals_in, dims_in, axes):
|
|
return _batched_reduction_collective(
|
|
psum_invariant_p, lambda v, axis_size: axis_size * v,
|
|
axis_data, vals_in, dims_in, axes, None)
|
|
batching.fancy_primitive_batchers[psum_invariant_p] = _psum_invariant_batching_rule
|
|
|
|
def _psum_invariant_transpose_rule(cts, arg, *, axes):
|
|
assert ad.is_undefined_primal(arg)
|
|
return (core.pvary(cts, axis_name=axes),)
|
|
ad.deflinear2(psum_invariant_p, _psum_invariant_transpose_rule)
|
|
|
|
########################### pvary ##################################
|
|
|
|
def _raise_valueerror(name, arg, *, axes):
|
|
raise ValueError(f'{name} should be called under jax.shard_map.')
|
|
|
|
core.pvary_p.def_impl(partial(_raise_valueerror, 'pvary'))
|
|
mlir.register_lowering(core.pvary_p, lambda ctx, x, *, axes: [x])
|
|
|
|
def _pvary_abstract_eval(aval, *, axes):
|
|
_check_axis_names(axes, 'pvary')
|
|
check_unreduced_args([aval], axes, 'pvary')
|
|
assert isinstance(axes, tuple)
|
|
if set(axes).intersection(aval.mat.varying):
|
|
    raise ValueError(
        "pvary is an invariant->variant collective. This means that the axis"
|
|
" names mentioned in `axes` passed to `pvary` must not be present in"
|
|
f" `jax.typeof(inp).mat.varying`. Got axes={axes} and"
|
|
f" jax.typeof(inp)={aval}")
|
|
out_vma = aval.mat.varying.union(frozenset(axes))
|
|
return aval.update(sharding=aval.sharding.update(mesh=get_abstract_mesh()),
|
|
manual_axis_type=aval.mat.update(varying=out_vma))
|
|
core.pvary_p.def_abstract_eval(_pvary_abstract_eval)
|
|
|
|
def _pvary_transpose_rule(cts, arg, *, axes):
|
|
assert ad.is_undefined_primal(arg)
|
|
return (psum_invariant_p.bind(cts, axes=axes),)
|
|
ad.deflinear2(core.pvary_p, _pvary_transpose_rule)
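
# The two rules above make pvary and psum_invariant exact transposes of each
# other: transposing a broadcast of an invariant value into a varying one
# sums the cotangents, and vice versa. Sketch (axis name assumed, for
# exposition only):
#
#   y = pvary(x, ('i',))                        # Invariant -> Varying
#   ct_x = psum_invariant_p.bind(ct_y, axes=('i',))  # its transpose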
|
|
|
|
def _pvary_batcher(vals_in, dims_in, *, axes):
|
|
if any(type(axis) is int for axis in axes):
|
|
raise NotImplementedError
|
|
(x,), (d,) = vals_in, dims_in
|
|
y = core.pvary_p.bind(x, axes=axes)
|
|
return y, d
|
|
batching.primitive_batchers[core.pvary_p] = _pvary_batcher
|
|
|
|
####################### all_gather_reduced ###########################
|
|
|
|
# Varying -> Reduced collective
|
|
def all_gather_reduced(x, axis_name, *, axis: int = 0, tiled: bool = False, is_async: bool = False):
|
|
if not isinstance(axis_name, tuple):
|
|
axis_name = (axis_name,)
|
|
if not axis_name:
|
|
return x
|
|
axis_size = _axis_size(axis_name, None)
|
|
def bind(leaf):
|
|
prim = all_gather_reduced_start_p if is_async else all_gather_reduced_p
|
|
return prim.bind(
|
|
leaf,
|
|
all_gather_dimension=canonicalize_axis(
|
|
axis, np.ndim(leaf) if tiled else np.ndim(leaf) + 1),
|
|
axis_name=axis_name, axis_size=axis_size, tiled=tiled)
|
|
return tree_util.tree_map(bind, x)
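
# Sketch of where this collective sits (axis name assumed, for exposition
# only): all_gather_reduced maps a Varying value to a Reduced one, and its
# transpose below is unreduced_psum_scatter (Unreduced -> Varying), making the
# pair the reduced/unreduced analogue of all_gather / psum_scatter:
#
#   y = all_gather_reduced(x, 'i', tiled=True)             # Varying -> Reduced
#   ct_x = unreduced_psum_scatter(ct_y, 'i', tiled=True)   # Unreduced -> Varying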
|
|
|
|
all_gather_reduced_p = core.Primitive('all_gather_reduced')
|
|
|
|
def _all_gather_reduced_effectful_abstract_eval(
|
|
x_aval, *, all_gather_dimension, axis_name, axis_size, tiled
|
|
):
|
|
_check_axis_names(axis_name, 'all_gather_reduced')
|
|
if not x_aval.mat.varying:
|
|
raise ValueError('all_gather_reduced only accepts inputs that are'
|
|
f' varying. Got {x_aval.str_short(True)}')
|
|
# If the intersection between x.mat.varying and axis_name is empty, error
|
|
if not (x_aval.mat.varying & set(axis_name)):
|
|
raise ValueError(
|
|
'all_gather_reduced is a Varying -> Reduced collective. This means '
|
|
f'that the {axis_name=} passed to `all_gather_reduced` must be present '
|
|
f'in jax.typeof(x).mat.varying={x_aval.mat.varying}')
|
|
if x_aval.mat.reduced & set(axis_name):
|
|
raise ValueError(
|
|
"all_gather_reduced's input cannot be reduced across the axis_name"
|
|
f" provided. Got x={x_aval.str_short(True)} and {axis_name=}")
|
|
|
|
new_shape = list(x_aval.shape)
|
|
if tiled:
|
|
new_shape[all_gather_dimension] *= axis_size
|
|
else:
|
|
new_shape.insert(all_gather_dimension, axis_size)
|
|
|
|
new_reduced = x_aval.mat.reduced | frozenset(axis_name)
|
|
out_vma = frozenset(v for v in x_aval.mat.varying if v not in axis_name)
|
|
out_mat = x_aval.mat.update(varying=out_vma, reduced=new_reduced)
|
|
return (x_aval.update(shape=new_shape, manual_axis_type=out_mat),
|
|
{*map(core.NamedAxisEffect, axis_name)})
|
|
all_gather_reduced_p.def_effectful_abstract_eval(
|
|
_all_gather_reduced_effectful_abstract_eval)
|
|
|
|
|
|
def _all_gather_reduced_impl(x, *, all_gather_dimension, axis_name, axis_size,
|
|
tiled):
|
|
raise NotImplementedError
|
|
all_gather_reduced_p.def_impl(_all_gather_reduced_impl)
|
|
|
|
|
|
def _all_gather_reduced_lowering(
|
|
ctx, x, *, all_gather_dimension, axis_name, axis_size, tiled,
|
|
platform=None, is_async=False):
|
|
return _all_gather_lowering(
|
|
ctx, x, all_gather_dimension=all_gather_dimension, axis_name=axis_name,
|
|
axis_index_groups=None, axis_size=axis_size, tiled=tiled,
|
|
platform=platform, is_async=is_async)
|
|
|
|
mlir.register_lowering(all_gather_reduced_p, _all_gather_reduced_lowering)
|
|
for p in ("cuda", "rocm", "tpu"):
|
|
mlir.register_lowering(all_gather_reduced_p,
|
|
partial(_all_gather_reduced_lowering, platform=p),
|
|
platform=p)
|
|
|
|
def _all_gather_reduced_transpose_rule(
|
|
cts, x, *, all_gather_dimension, axis_name, axis_size, tiled):
|
|
return (unreduced_psum_scatter(cts, axis_name=axis_name,
|
|
scatter_dimension=all_gather_dimension,
|
|
tiled=tiled),)
|
|
ad.deflinear2(all_gather_reduced_p, _all_gather_reduced_transpose_rule)
|
|
|
|
def _all_gather_reduced_batched_collective(
|
|
axis_data, vals_in, dims_in, all_gather_dimension, axis_name, axis_size,
|
|
tiled):
|
|
raise NotImplementedError(
|
|
"Please file an issue at https://github.com/jax-ml/jax/issues")
|
|
batching.fancy_primitive_batchers[all_gather_reduced_p] = _all_gather_reduced_batched_collective
|
|
|
|
####################### unreduced_psum_scatter ###########################
|
|
|
|
# Unreduced -> Varying collective
|
|
def unreduced_psum_scatter(x, axis_name, *, scatter_dimension=0, tiled=False,
|
|
is_async=False):
|
|
if not isinstance(axis_name, tuple):
|
|
axis_name = (axis_name,)
|
|
if not axis_name:
|
|
return x
|
|
axis_size = _axis_size(axis_name, None)
|
|
def bind(leaf):
|
|
prim = (
|
|
unreduced_reduce_scatter_start_p
|
|
if is_async
|
|
else unreduced_reduce_scatter_p
|
|
)
|
|
return prim.bind(
|
|
leaf, axis_name=axis_name, scatter_dimension=scatter_dimension,
|
|
axis_size=axis_size, tiled=tiled)
|
|
return tree_util.tree_map(bind, x)
|
|
|
|
unreduced_reduce_scatter_p = core.Primitive('unreduced_reduce_scatter')
|
|
|
|
def _unreduced_reduce_scatter_effectful_abstract_eval(
|
|
x_aval, *, axis_name, scatter_dimension, axis_size, tiled
|
|
):
|
|
_check_axis_names(axis_name, 'reduce_scatter')
|
|
if not x_aval.mat.unreduced:
|
|
raise ValueError('unreduced_psum_scatter only accepts inputs that are'
|
|
f' unreduced. Got {x_aval.str_short(True)}')
|
|
# If intersection between x.unreduced & axis_name is empty, error
|
|
if not (x_aval.mat.unreduced & frozenset(axis_name)):
|
|
raise ValueError(
|
|
"unreduced_psum_scatter is a Unreduced -> Varying collective. This"
|
|
f" means that the {axis_name=} passed to `unreduced_psum_scatter` must"
|
|
" be present in"
|
|
f" jax.typeof(x).mat.unreduced={x_aval.mat.unreduced}"
|
|
)
|
|
if x_aval.mat.varying & set(axis_name):
|
|
raise ValueError(
|
|
"unreduced_psum_scatter's input cannot be varying across the axis_name"
|
|
f" provided. Got x={x_aval.str_short(True)} and {axis_name=}")
|
|
|
|
new_shape = list(x_aval.shape)
|
|
scatter_dim_input_size = x_aval.shape[scatter_dimension]
|
|
if tiled:
|
|
if scatter_dim_input_size % axis_size != 0:
|
|
raise ValueError(f"tiled reduce_scatter operand scatter dimension size "
|
|
f"{scatter_dim_input_size} must be divisible by "
|
|
f"shard_count {axis_size}")
|
|
new_shape[scatter_dimension] = scatter_dim_input_size // axis_size
|
|
else:
|
|
if scatter_dim_input_size != axis_size:
|
|
raise ValueError(f"reduce_scatter operand scatter dimension size "
|
|
f"{scatter_dim_input_size} must match shard count "
|
|
f"{axis_size}")
|
|
del new_shape[scatter_dimension]
|
|
|
|
out_unreduced = frozenset(i for i in x_aval.mat.unreduced
|
|
if i not in axis_name)
|
|
out_vma = x_aval.mat.varying | set(axis_name)
|
|
out_mat = x_aval.mat.update(varying=out_vma, unreduced=out_unreduced)
|
|
return (x_aval.update(shape=new_shape, manual_axis_type=out_mat),
|
|
{*map(core.NamedAxisEffect, axis_name)})
|
|
unreduced_reduce_scatter_p.def_effectful_abstract_eval(
|
|
_unreduced_reduce_scatter_effectful_abstract_eval)
|
|
|
|
|
|
def _unreduced_reduce_scatter_impl(
|
|
x, *, axis_name, scatter_dimension, axis_size, tiled):
|
|
raise NotImplementedError
|
|
unreduced_reduce_scatter_p.def_impl(_unreduced_reduce_scatter_impl)
|
|
|
|
def _unreduced_reduce_scatter_transpose_rule(
|
|
cts, x, *, axis_name, scatter_dimension, axis_size, tiled):
|
|
return (all_gather_reduced(cts, axis_name=axis_name, axis=scatter_dimension,
|
|
tiled=tiled),)
|
|
ad.deflinear2(unreduced_reduce_scatter_p, _unreduced_reduce_scatter_transpose_rule)
|
|
|
|
def _unreduced_reduce_scatter_batcher(
|
|
axis_data, vals_in, dims_in, axis_name, scatter_dimension, axis_size,
|
|
tiled):
|
|
raise NotImplementedError(
|
|
"Please file an issue at https://github.com/jax-ml/jax/issues")
|
|
batching.fancy_primitive_batchers[unreduced_reduce_scatter_p] = _unreduced_reduce_scatter_batcher
|
|
|
|
def _unreduced_reduce_scatter_lowering(
|
|
prim, ctx, x, *, axis_name, scatter_dimension, axis_size, tiled):
|
|
return _reduce_scatter_lowering(
|
|
prim, ctx, x, axis_name=axis_name, scatter_dimension=scatter_dimension,
|
|
axis_size=axis_size, tiled=tiled, axis_index_groups=None)
|
|
mlir.register_lowering(unreduced_reduce_scatter_p,
|
|
partial(_unreduced_reduce_scatter_lowering, lax.add_p))
|
|
|
|
############################## unreduced_psum ###########################
|
|
|
|
# Unreduced -> Invariant collective
|
|
def unreduced_psum(x, axis_name, is_async=False):
|
|
if not isinstance(axis_name, (tuple, list)):
|
|
axis_name = (axis_name,)
|
|
if not axis_name:
|
|
return x
|
|
prim = unreduced_psum_start_p if is_async else unreduced_psum_p
|
|
return tree_util.tree_map(
|
|
lambda leaf: prim.bind(leaf, axes=tuple(axis_name)), x)
|
|
|
|
unreduced_psum_p = core.Primitive('unreduced_psum')
|
|
|
|
def _unreduced_psum_abstract_eval(aval, *, axes):
|
|
_check_axis_names(axes, 'psum')
|
|
if not aval.mat.unreduced:
|
|
raise ValueError('unreduced_psum only accepts inputs that are'
|
|
f' unreduced. Got {aval.str_short(True)}')
|
|
# If intersection between x.unreduced & axis_name is empty, error
|
|
if not (aval.mat.unreduced & frozenset(axes)):
|
|
raise ValueError(
|
|
"unreduced_psum is a Unreduced -> Invariant collective. This"
|
|
f" means that the {axes=} passed to `unreduced_psum` must"
|
|
" be present in"
|
|
f" jax.typeof(x).mat.unreduced={aval.mat.unreduced}")
|
|
if aval.mat.varying & set(axes):
|
|
raise ValueError(
|
|
"unreduced_psum's input cannot be varying across the "
|
|
f" axis_name provided. Got x={aval.str_short(True)} and {axes=}")
|
|
|
|
if any(isinstance(a, int) for a in axes):
|
|
raise ValueError('unreduced_psum does not accept integer axis_name.'
|
|
f' Got axis_name={axes}')
|
|
|
|
core.check_avals_context_mesh([aval], 'unreduced_psum')
|
|
out_mat = aval.mat.update(unreduced=frozenset(u for u in aval.mat.unreduced
|
|
if u not in axes))
|
|
out_aval = aval.update(manual_axis_type=out_mat)
|
|
return out_aval, {core.NamedAxisEffect(axis) for axis in axes}
|
|
unreduced_psum_p.def_effectful_abstract_eval(_unreduced_psum_abstract_eval)
|
|
|
|
def _unreduced_psum_lowering(ctx, arg, *, axes):
|
|
return _allreduce_lowering(lax.add_p, lax.reduce_sum, ctx, arg,
|
|
axes=axes, axis_index_groups=None)
|
|
mlir.register_lowering(unreduced_psum_p, _unreduced_psum_lowering)
|
|
|
|
def _unreduced_psum_batcher(axis_data, vals_in, dims_in, axes):
|
|
raise NotImplementedError
|
|
batching.fancy_primitive_batchers[unreduced_psum_p] = _unreduced_psum_batcher
|
|
|
|
def _unreduced_psum_transpose_rule(cts, arg, *, axes):
|
|
assert ad.is_undefined_primal(arg)
|
|
return (preduced(cts, axis_name=axes),)
|
|
ad.deflinear2(unreduced_psum_p, _unreduced_psum_transpose_rule)
|
|
|
|
############################## preduced #################################
|
|
|
|
# Invariant -> Reduced no-op cast. It's the transpose of unreduced_psum.
|
|
def preduced(x, axis_name):
|
|
axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
|
|
if not axes:
|
|
return x
|
|
cur_mesh = get_abstract_mesh()
|
|
if not config._check_vma.value and all(a in cur_mesh.manual_axes for a in axes):
|
|
return x
|
|
new_axes = axes if cur_mesh.empty else core.order_wrt_mesh(cur_mesh, axes)
|
|
assert set(new_axes) == set(axes)
|
|
del axes
|
|
return tree_util.tree_map(lambda l: preduced_p.bind(l, axes=new_axes), x)
|
|
|
|
preduced_p = core.Primitive('preduced')
|
|
preduced_p.def_impl(partial(_raise_valueerror, 'preduced'))
|
|
mlir.register_lowering(preduced_p, lambda ctx, x, *, axes: [x])
|
|
|
|
def _preduced_abstract_eval(aval, *, axes):
|
|
assert isinstance(axes, tuple)
|
|
_check_axis_names(axes, 'preduced')
|
|
if aval.mat.varying.intersection(set(axes)):
|
|
    raise ValueError(
        "preduced is an Invariant->Reduced collective. This means that the"
|
|
" axis names mentioned in `axes` passed to `preduced` must not be"
|
|
f" present in `jax.typeof(inp).mat.varying`. Got axes={axes} and"
|
|
f" jax.typeof(inp).mat.varying={aval.mat.varying}")
|
|
if aval.mat.reduced & set(axes):
|
|
raise ValueError(
|
|
"preduced input cannot be reduced across the axis_name"
|
|
f" provided. Got x={aval.str_short(True)} and axis_name={axes}")
|
|
return aval.update(manual_axis_type=aval.mat.update(
|
|
reduced=aval.mat.reduced | frozenset(axes)))
|
|
preduced_p.def_abstract_eval(_preduced_abstract_eval)
|
|
|
|
def _preduced_transpose_rule(cts, arg, *, axes):
|
|
assert ad.is_undefined_primal(arg)
|
|
return (unreduced_psum(cts, axis_name=axes),)
|
|
ad.deflinear2(preduced_p, _preduced_transpose_rule)
|
|
|
|
def _preduced_batcher(vals_in, dims_in, *, axes):
|
|
raise NotImplementedError
|
|
batching.primitive_batchers[preduced_p] = _preduced_batcher
|
|
|
|
######################## vary_unreduced_cast #######################

# Varying -> Unreduced no-op cast
def vary_unreduced_cast(x, axis_name):
  axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
  if not axis_name:
    return x
  cur_mesh = get_abstract_mesh()
  if not config._check_vma.value and all(a in cur_mesh.manual_axes for a in axes):
    return x
  new_axes = axes if cur_mesh.empty else core.order_wrt_mesh(cur_mesh, axes)
  assert set(new_axes) == set(axes)
  del axes
  return tree_util.tree_map(
      lambda leaf: vary_unreduced_cast_p.bind(leaf, axes=new_axes), x)

vary_unreduced_cast_p = core.Primitive('vary_unreduced_cast_p')
vary_unreduced_cast_p.def_impl(partial(_raise_valueerror, 'vary_unreduced_cast'))
mlir.register_lowering(vary_unreduced_cast_p, lambda ctx, x, *, axes: [x])

def _vary_unreduced_cast_abstract_eval(aval, *, axes):
  assert isinstance(axes, tuple)
  _check_axis_names(axes, 'vary_unreduced_cast')
  check_unreduced_args([aval], axes, 'vary_unreduced_cast')
  if not aval.mat.varying:
    raise ValueError('vary_unreduced_cast only accepts inputs that are'
                     f' varying. Got {aval.str_short(True)}')
  # If the intersection between aval.mat.varying and axes is empty, error
  if not (aval.mat.varying & set(axes)):
    raise ValueError(
        "vary_unreduced_cast is a Varying->Unreduced collective. This"
        " means that the axis names mentioned in `axes` passed to"
        " `vary_unreduced_cast` must be present in"
        f" `jax.typeof(x).mat.varying`. Got axes={axes} and"
        f" jax.typeof(x).mat.varying={aval.mat.varying}")
  if aval.mat.unreduced & set(axes):
    raise ValueError(
        "vary_unreduced_cast input cannot be unreduced across the axis_name"
        f" provided. Got x={aval.str_short(True)} and axis_name={axes}")

  new_unreduced = aval.mat.unreduced | frozenset(axes)
  out_vma = frozenset(i for i in aval.mat.varying if i not in axes)
  return aval.update(manual_axis_type=aval.mat.update(
      varying=out_vma, unreduced=new_unreduced))
vary_unreduced_cast_p.def_abstract_eval(_vary_unreduced_cast_abstract_eval)

def _vary_unreduced_cast_transpose_rule(cts, x, *, axes):
  assert ad.is_undefined_primal(x)
  return (core.reduced_vary_cast(cts, axis_name=axes),)
ad.deflinear2(vary_unreduced_cast_p, _vary_unreduced_cast_transpose_rule)

def _vary_unreduced_cast_batcher(vals_in, dims_in, *, axes):
  raise NotImplementedError
batching.primitive_batchers[vary_unreduced_cast_p] = _vary_unreduced_cast_batcher

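# Note: `vary_unreduced_cast` (Varying -> Unreduced) and `reduced_vary_cast`
# (Reduced -> Varying, defined next) are likewise each other's transposes, as
# their `ad.deflinear2` rules show.
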
####################### reduced_vary_cast #############################

# Reduced -> Varying no-op cast
# Traceable defined in core.py to avoid circular imports
core.reduced_vary_cast_p.def_impl(
    partial(_raise_valueerror, 'reduced_vary_cast'))
mlir.register_lowering(core.reduced_vary_cast_p, lambda ctx, x, *, axes: [x])

def _reduced_vary_cast_abstract_eval(aval, *, axes):
  assert isinstance(axes, tuple)
  _check_axis_names(axes, 'reduced_vary_cast')
  if not aval.mat.reduced:
    raise ValueError('reduced_vary_cast only accepts inputs that are'
                     f' reduced. Got {aval.str_short(True)}')
  # If the intersection between aval.mat.reduced and axes is empty, error
  if not (aval.mat.reduced & set(axes)):
    raise ValueError(
        "reduced_vary_cast is a Reduced->Varying collective. This"
        " means that the axis names mentioned in `axes` passed to"
        " `reduced_vary_cast` must be present in"
        f" `jax.typeof(x).mat.reduced`. Got axes={axes} and"
        f" jax.typeof(x).mat.reduced={aval.mat.reduced}")
  if aval.mat.varying & set(axes):
    raise ValueError(
        "reduced_vary_cast input cannot be varying across the axis_name"
        f" provided. Got x={aval.str_short(True)} and axis_name={axes}")

  new_reduced = frozenset(i for i in aval.mat.reduced if i not in axes)
  out_vma = aval.mat.varying | frozenset(axes)
  return aval.update(manual_axis_type=aval.mat.update(
      varying=out_vma, reduced=new_reduced))
core.reduced_vary_cast_p.def_abstract_eval(_reduced_vary_cast_abstract_eval)

def _reduced_vary_cast_transpose_rule(cts, x, *, axes):
  assert ad.is_undefined_primal(x)
  return (vary_unreduced_cast(cts, axis_name=axes),)
ad.deflinear2(core.reduced_vary_cast_p, _reduced_vary_cast_transpose_rule)

def _reduced_vary_cast_batcher(vals_in, dims_in, *, axes):
  raise NotImplementedError
batching.primitive_batchers[core.reduced_vary_cast_p] = _reduced_vary_cast_batcher

################################## pcast #############################

def _get_from(aval, axes: tuple[AxisName, ...], name) -> str:
  out = set()
  for a in axes:
    if a in aval.mat.varying:
      out.add('varying')
    elif a in aval.mat.unreduced:
      out.add('unreduced')
    elif a in aval.mat.reduced:
      out.add('reduced')
    else:
      out.add('invarying')

  if len(out) > 1:
    raise ValueError(
        f"{name} can only accept axis_name which corresponds to one of"
        " varying, unreduced, reduced or invarying state of the input. Got"
        f" input type: {aval}, axes: {axes} and input state: {out}")
  o, = out
  return o


_pcast_funcs = {
    ('invarying', 'varying'): core.pvary,
    ('invarying', 'reduced'): preduced,
    ('varying', 'unreduced'): vary_unreduced_cast,
    ('reduced', 'varying'): core.reduced_vary_cast,
}

_allowed_pcast_to = {'unreduced', 'reduced', 'varying'}

def pcast(x, axis_name, *, to: str):
  if isinstance(axis_name, (set, frozenset)):
    raise TypeError(f"axis_name must be a tuple or a str. Got {axis_name}")
  axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
  if not axis_name:
    return x

  if to not in _allowed_pcast_to:
    raise ValueError(
        f"Got unexpected `to` value: {to}. Allowed `to` values are:"
        f" {_allowed_pcast_to}")

  def bind(leaf):
    from_ = _get_from(core.typeof(leaf), axes, 'jax.lax.pcast')
    func = _pcast_funcs.get((from_, to), None)
    if func is None:
      raise ValueError(f"Unsupported pcast from={from_}, {to=}")
    return func(leaf, axes)
  return tree_util.tree_map(bind, x)

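# Illustrative sketch of the dispatch above (not executed here): under a
# manual mesh axis 'i', the `to=` target selects the underlying cast based on
# the input's current state over that axis, e.g.
#
#   pcast(x, 'i', to='varying')    # invariant input -> core.pvary(x, ('i',))
#   pcast(x, 'i', to='reduced')    # invariant input -> preduced(x, ('i',))
#   pcast(y, 'i', to='unreduced')  # varying input   -> vary_unreduced_cast(...)
#   pcast(z, 'i', to='varying')    # reduced input   -> reduced_vary_cast(...)
#
# Any other (from, to) combination raises "Unsupported pcast ...".
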
######################### async ops #########################

# Asynchronous start primitives.
all_gather_start_p = core.Primitive("all_gather_start")
all_gather_reduced_start_p = core.Primitive("all_gather_reduced_start")
psum_start_p = core.Primitive("psum_start")
psum_invariant_start_p = core.Primitive("psum_invariant_start")
unreduced_psum_start_p = core.Primitive("unreduced_psum_start")
reduce_scatter_start_p = core.Primitive("reduce_scatter_start")
unreduced_reduce_scatter_start_p = core.Primitive("unreduced_reduce_scatter_start")
all_to_all_start_p = core.Primitive("all_to_all_start")
pbroadcast_start_p = core.Primitive("pbroadcast_start")
ppermute_start_p = core.Primitive("ppermute_start")

# Asynchronous start functions.
class Todo:

  def __init__(self, x, done_fun):
    self.x = x
    self.done_fun = done_fun

  def done(self):
    return self.done_fun(self.x)


def all_gather_start(*args, **kwargs):
  x = _all_gather_is_async(*args, **kwargs, is_async=True)
  return Todo(x, all_gather_done_p.bind)


def psum_start(*args, **kwargs):
  x = _psum_is_async(*args, **kwargs, is_async=True)
  return Todo(x, psum_done_p.bind)


def psum_scatter_start(*args, **kwargs):
  x = _psum_scatter_is_async(*args, **kwargs, is_async=True)
  return Todo(x, reduce_scatter_done_p.bind)


def all_to_all_start(*args, **kwargs):
  x = _all_to_all_is_async(*args, **kwargs, is_async=True)
  return Todo(x, all_to_all_done_p.bind)


def pbroadcast_start(*args, **kwargs):
  x = _pbroadcast_is_async(*args, **kwargs, is_async=True)
  return Todo(x, pbroadcast_done_p.bind)


def ppermute_start(*args, **kwargs):
  x = _ppermute_is_async(*args, **kwargs, is_async=True)
  return Todo(x, ppermute_done_p.bind)


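# Illustrative sketch of the start/done pattern (not executed here; assumes
# the `_*_is_async` helpers mirror the signatures of their synchronous
# counterparts such as `psum`):
#
#   fut = psum_start(x, 'i')        # issues the collective, returns a Todo
#   z = other_work()                # hypothetical compute overlapped with it
#   y = fut.done()                  # binds psum_done_p to resolve the future
#
# `Todo.done()` simply calls the matching `*_done_p.bind` on the pending value.

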
# Asynchronous start abstract eval.
def _start_abstract_eval(q):
  def f(*args, **kwargs):
    aval, effs = q.abstract_eval(*args, **kwargs)
    return core.AbstractTodo(aval), effs
  return f

for p, q in [
    (all_gather_start_p, all_gather_p),
    (all_gather_reduced_start_p, all_gather_reduced_p),
    (psum_start_p, psum_p),
    (psum_invariant_start_p, psum_invariant_p),
    (unreduced_psum_start_p, unreduced_psum_p),
    (reduce_scatter_start_p, reduce_scatter_p),
    (unreduced_reduce_scatter_start_p, unreduced_reduce_scatter_p),
    (all_to_all_start_p, all_to_all_p),
    (pbroadcast_start_p, pbroadcast_p),
    (ppermute_start_p, ppermute_p),
]:
  p.def_effectful_abstract_eval(_start_abstract_eval(q))

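# For instance, wrapping `_unreduced_psum_abstract_eval` (defined earlier in
# this file) means `unreduced_psum_start_p`'s rule returns
# `core.AbstractTodo(out_aval)` together with the same named-axis effects,
# where `out_aval` is whatever the synchronous rule would have produced.
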
# Asynchronous start lowering.
def _start_lowering(sync_lower):
  """Returns an async start lowering function given a synchronous lowering.

  An async StableHLO collective looks like this:

  > %f = "stablehlo.async_start"(%x) ({
  > ^bb0(%arg: tensor<2x2xf32>):
  >   %tmp = "stablehlo.all_gather"(%arg) : (tensor<2x2xf32>) ->
  >       tensor<4x2xf32>
  >   stablehlo.return %tmp : tensor<4x2xf32>
  > }) : (tensor<2x2xf32>) -> !stablehlo.future<tensor<4x2xf32>>
  > %y = "stablehlo.async_done"(%f) : (!stablehlo.future<tensor<4x2xf32>>) ->
  >     tensor<4x2xf32>

  There is an async_start op with a region that performs and returns the
  synchronous collective. _start_lowering takes in a lowering function for the
  synchronous collective and transforms it into a lowering function for the
  async collective by wrapping everything in an async_start.
  """

  def f(ctx, x, **kwargs):
    (x_aval,) = ctx.avals_in    # e.g., f32[2, 2]
    (out_aval,) = ctx.avals_out  # e.g., AbstractTodo[f32[4, 2]]
    inner_aval = out_aval.inner_aval  # e.g., f32[4, 2]
    inner_type = mlir.aval_to_ir_type(inner_aval)  # e.g., tensor<4x2xf32>
    # e.g., !stablehlo.future<tensor<4x2xf32>>
    future_type = hlo.FutureType.get([inner_type])
    async_start = hlo.AsyncStartOp(future_type, [x])
    block = async_start.regions[0].blocks.append(x.type)
    with ir.InsertionPoint(block):
      inner_ctx = ctx.replace(
          primitive=None, avals_in=[x_aval], avals_out=[inner_aval]
      )
      results = sync_lower(inner_ctx, block.arguments[0], **kwargs)
      hlo.return_(results)
    return async_start.results

  return f


def _reduce_scatter_start_lowering(ctx, x, *, tiled, **kwargs):
  if not tiled:
    # TODO(mwhittaker): When the output is not tiled, a reduce_scatter is
    # lowered to two operations: a reduce_scatter and a reshape. Lowering the
    # async version of this is tricky because we need to reshape after the
    # future is resolved.
    raise NotImplementedError(
        "lowering reduce_scatter_start with tiled=False unimplemented")
  lower = partial(_reduce_scatter_lowering, lax.add_p)
  return _start_lowering(lower)(ctx, x, tiled=tiled, **kwargs)


def _unreduced_reduce_scatter_start_lowering(ctx, x, *, tiled, **kwargs):
  if not tiled:
    msg = (
        "lowering unreduced_reduce_scatter_start with tiled=False unimplemented"
    )
    raise NotImplementedError(msg)
  lower = partial(_unreduced_reduce_scatter_lowering, lax.add_p)
  return _start_lowering(lower)(ctx, x, tiled=tiled, **kwargs)


mlir.register_lowering(
    all_gather_start_p, partial(_all_gather_lowering, is_async=True)
)
for p in ("cuda", "rocm", "tpu"):
  mlir.register_lowering(
      all_gather_start_p,
      partial(_all_gather_lowering, platform=p, is_async=True),
      platform=p,
  )
mlir.register_lowering(
    all_gather_reduced_start_p,
    partial(_all_gather_reduced_lowering, is_async=True),
)
for p in ("cuda", "rocm", "tpu"):
  mlir.register_lowering(
      all_gather_reduced_start_p,
      partial(_all_gather_reduced_lowering, platform=p, is_async=True),
      platform=p,
  )
mlir.register_lowering(
    psum_start_p,
    _start_lowering(partial(_allreduce_lowering, lax.add_p, lax.reduce_sum)),
)
mlir.register_lowering(
    psum_invariant_start_p, _start_lowering(_psum_invariant_lowering_rule)
)
mlir.register_lowering(
    unreduced_psum_start_p, _start_lowering(_unreduced_psum_lowering)
)
mlir.register_lowering(reduce_scatter_start_p, _reduce_scatter_start_lowering)
mlir.register_lowering(
    unreduced_reduce_scatter_start_p, _unreduced_reduce_scatter_start_lowering
)
mlir.register_lowering(
    all_to_all_start_p, _start_lowering(_all_to_all_lowering)
)
mlir.register_lowering(
    pbroadcast_start_p, _start_lowering(_pbroadcast_lowering), platform="gpu"
)
mlir.register_lowering(ppermute_start_p, _start_lowering(_ppermute_lowering))

# Asynchronous done primitives.
all_gather_done_p = core.Primitive("all_gather_done")
psum_done_p = core.Primitive("psum_done")
reduce_scatter_done_p = core.Primitive("reduce_scatter_done")
all_to_all_done_p = core.Primitive("all_to_all_done")
pbroadcast_done_p = core.Primitive("pbroadcast_done")
ppermute_done_p = core.Primitive("ppermute_done")

_dones_p = [
    all_gather_done_p,
    psum_done_p,
    reduce_scatter_done_p,
    all_to_all_done_p,
    pbroadcast_done_p,
    ppermute_done_p,
]

# Asynchronous done abstract eval and lowering.
def _done_abstract_eval(aval):
  if not isinstance(aval, core.AbstractTodo):
    raise TypeError(f"async done op got {aval}, want core.AbstractTodo")
  return aval.inner_aval

for p in _dones_p:
  p.def_abstract_eval(_done_abstract_eval)
  mlir.register_lowering(p, lambda ctx, x: hlo.AsyncDoneOp(x).results)
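
# Each done primitive simply unwraps its AbstractTodo operand and lowers to a
# stablehlo async_done op, matching the async_start/async_done pair shown in
# the `_start_lowering` docstring above.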