hand
This commit is contained in:
@@ -0,0 +1,580 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import collections
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Any, Union
|
||||
|
||||
from jax._src.util import use_cpp_class, cache, use_cpp_method, unzip3
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir.dialects import sdy
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.mesh import AxisType
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src import sharding as jsharding
|
||||
import numpy as np
|
||||
|
||||
Shape = tuple[int, ...]
|
||||
Device = xc.Device
|
||||
Index = tuple[slice, ...]
|
||||
XLADeviceAssignment = Sequence[Device]
|
||||
|
||||
|
||||
class AUTO:
  """Sharding marker that wraps a mesh and yields fully-open sdy shardings.

  Every dimension of the produced `SdyArray` is open (unconstrained), leaving
  the choice of sharding to the partitioner.
  """

  def __init__(self, mesh: mesh_lib.Mesh):
    self.mesh = mesh

  def _to_sdy_sharding(self, ndim: int) -> SdyArray:
    # One open, axis-free sharding per array dimension.
    open_dims = [SdyDim(axes=[], is_open=True) for _ in range(ndim)]
    return SdyArray(mesh_shape=self.mesh.shape_tuple, dim_shardings=open_dims)

  @property
  def _device_assignment(self):
    return self.mesh._flat_devices_tuple
|
||||
|
||||
|
||||
class UnspecifiedValue:
  """Sentinel type meaning "no sharding was specified"."""

  def __repr__(self):
    return "UnspecifiedValue"

# Module-level singleton; compare with `is` / isinstance.
UNSPECIFIED = UnspecifiedValue()
|
||||
|
||||
|
||||
# Name of a mesh axis (e.g. 'x'); any hashable value is accepted.
MeshAxisName = Any

"""
ArrayMapping specifies how an ndarray should map to mesh axes.

Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.

For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = collections.OrderedDict[MeshAxisName, int]
# Either a concrete mapping, compiler-chosen (AUTO), or left unspecified.
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTO, UnspecifiedValue]
|
||||
|
||||
|
||||
def _unpickle_named_sharding(mesh, spec, memory_kind, logical_device_ids):
  """Reconstructs a NamedSharding from pickled state (see NamedSharding.__reduce__)."""
  return NamedSharding(mesh, spec, memory_kind=memory_kind,
                       _logical_device_ids=logical_device_ids)
|
||||
|
||||
|
||||
@use_cpp_class(xc.NamedSharding)
class NamedSharding(jsharding.Sharding):
  r"""A :class:`NamedSharding` expresses sharding using named axes.

  A :class:`NamedSharding` is a pair of a :class:`Mesh` of devices and
  :class:`PartitionSpec` which describes how to shard an array across that
  mesh.

  A :class:`Mesh` is a multidimensional NumPy array of JAX devices,
  where each axis of the mesh has a name, e.g. ``'x'`` or ``'y'``.

  A :class:`PartitionSpec` is a tuple, whose elements can be a ``None``,
  a mesh axis, or a tuple of mesh axes. Each element describes how an input
  dimension is partitioned across zero or more mesh dimensions. For example,
  ``PartitionSpec('x', 'y')`` says that the first dimension of data
  is sharded across ``x`` axis of the mesh, and the second dimension is sharded
  across ``y`` axis of the mesh.

  The `Distributed arrays and automatic parallelization`_
  and `Explicit Sharding`_ tutorials have more details and diagrams that
  explain how :class:`Mesh` and :class:`PartitionSpec` are used.

  Args:
    mesh: A :class:`jax.sharding.Mesh` object.
    spec: A :class:`jax.sharding.PartitionSpec` object.
    memory_kind: A string indicating the memory kind of the sharding.

  Examples:

    >>> from jax.sharding import Mesh
    >>> from jax.sharding import PartitionSpec as P
    >>> mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
    >>> spec = P('x', 'y')
    >>> named_sharding = jax.sharding.NamedSharding(mesh, spec)

  .. _Distributed arrays and automatic parallelization: https://docs.jax.dev/en/latest/parallel.html
  .. _Explicit Sharding: https://docs.jax.dev/en/latest/parallel.html
  """

  # Device mesh (concrete) or an AbstractMesh carrying only shape/axis info.
  mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh
  # How each array dimension maps onto mesh axes.
  spec: PartitionSpec
  # Memory kind string, or None for the default memory.
  _memory_kind: str | None
  # Optional explicit device ordering used when building the HLO sharding.
  _logical_device_ids: tuple[int, ...] | None

  @use_cpp_method()
  def __init__(
      self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *,
      memory_kind: str | None = None, _logical_device_ids=None):
    self.mesh = mesh
    self.spec = spec
    self._memory_kind = memory_kind
    self._logical_device_ids = _logical_device_ids
    # Validates the spec against the mesh (duplicates, unknown axes,
    # unreduced/reduced constraints); raises on failure.
    check_pspec(self.mesh, self.spec)

  def __repr__(self):
    # Optional parts are rendered only when set, to keep the repr short.
    mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
    ldi = ('' if self._logical_device_ids is None else
           f', logical_device_ids={self._logical_device_ids}')
    mesh_repr = f"{str(self.mesh)}"
    return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})'

  def __reduce__(self):
    # Pickle via the module-level factory so the C++-backed class round-trips.
    return (_unpickle_named_sharding,
            (self.mesh, self.spec, self.memory_kind, self._logical_device_ids))

  @property
  def memory_kind(self) -> str | None:
    """The memory kind of this sharding, or None for the default memory."""
    return self._memory_kind

  @use_cpp_method()
  def __hash__(self):
    # Hash is computed lazily and cached on the instance.
    if not hasattr(self, '_hash'):
      self._hash = hash(
          (self.mesh, self.memory_kind, self.spec, self._logical_device_ids))
    return self._hash

  @use_cpp_method()
  def __eq__(self, other):
    if not isinstance(other, NamedSharding):
      return False
    if self is other:
      return True
    # Cheap field comparisons first; the mesh comparison is last since the
    # identity check usually short-circuits the (potentially costly) __eq__.
    if (self.spec != other.spec
        or self.memory_kind != other.memory_kind
        or self._logical_device_ids != other._logical_device_ids):
      return False
    return self.mesh is other.mesh or self.mesh == other.mesh

  def check_compatible_aval(self, aval_shape: Shape) -> None:
    """Raises ValueError if the spec has more entries than `aval_shape` has dims."""
    if len(aval_shape) < len(self.spec):
      extra_msg = (' For scalars the PartitionSpec should be P()'
                   if len(aval_shape) == 0 else '')
      raise ValueError(
          f"Sharding {self} is only valid for values of rank at least "
          f"{len(self.spec)}, but was applied to a value of rank "
          f"{len(aval_shape)}.{extra_msg}")

  @property
  def num_devices(self) -> int:
    """Total number of devices in the mesh."""
    return self.mesh.size

  @property
  def device_set(self) -> set[Device]:
    """Set of all devices in the mesh; unavailable for abstract meshes."""
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError(
          'device_set is not implemented for `jax.sharding.AbstractMesh`.')
    return self.mesh._flat_devices_set

  @property
  def _device_assignment(self) -> XLADeviceAssignment:
    # Flat device order of the mesh; unavailable for abstract meshes.
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('_device_assignment is not implemented for'
                       ' `jax.sharding.AbstractMesh`.')
    return self.mesh._flat_devices_tuple

  @property
  def is_fully_addressable(self) -> bool:
    """Whether every device in the mesh is addressable by this process."""
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('is_fully_addressable is not implemented for '
                       '`jax.sharding.AbstractMesh`.')
    # return False if addressable_device_list is empty.
    return self._internal_device_list.is_fully_addressable

  @property
  def _is_concrete(self) -> bool:
    # Concrete iff backed by a real device mesh rather than an AbstractMesh.
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      return False
    return True

  @property
  def addressable_devices(self) -> set[Device]:
    """Devices of the mesh that are local to this process."""
    if isinstance(self.mesh, mesh_lib.AbstractMesh):
      raise ValueError('addressable_devices is not implemented for '
                       '`jax.sharding.AbstractMesh`.')
    # Override addressable devices because there is a high chance that the mesh
    # across multiple NamedSharding objects will be the same.
    return self.mesh._local_devices_set

  @functools.cached_property
  def is_fully_replicated(self) -> bool:
    """True if the spec partitions over a total of one shard (i.e. no sharding)."""
    if self.mesh.size == 1:
      return True
    array_mapping = get_array_mapping(self.spec)
    mesh_shape = self.mesh.shape
    # Product of the sizes of all mesh axes used by the spec.
    num_partitions = 1
    for name in array_mapping:
      num_partitions *= mesh_shape[name]
    return num_partitions == 1

  @functools.cached_property
  def replicated_axes(self) -> frozenset[MeshAxisName]:
    """Mesh axes not referenced by the spec (including unreduced/reduced)."""
    flat_spec = frozenset(
        s for s in flatten_spec(self.spec)
        if s is not None and s is not PartitionSpec.UNCONSTRAINED)
    return frozenset(self.mesh.axis_names) - (
        flat_spec | self.spec.unreduced | self.spec.reduced)

  def with_memory_kind(self, kind: str) -> NamedSharding:
    """Returns a copy of this sharding with the given memory kind."""
    return self.update(memory_kind=kind)

  def update(self, **kwargs) -> NamedSharding:
    """Returns a new NamedSharding with the given fields replaced.

    Accepts `mesh`, `spec`, `memory_kind`, `_logical_device_ids`.
    NOTE(review): unrecognized kwargs are silently ignored — confirm intended.
    """
    spec = kwargs.pop("spec", self.spec)
    # Allow passing a plain tuple/list for `spec`.
    if not isinstance(spec, PartitionSpec):
      spec = PartitionSpec(*spec)
    return NamedSharding(
        mesh=kwargs.pop("mesh", self.mesh),
        spec=spec,
        memory_kind=kwargs.pop("memory_kind", self.memory_kind),
        _logical_device_ids=kwargs.pop("_logical_device_ids",
                                       self._logical_device_ids))

  def is_equivalent_to(self, other, ndim: int) -> bool:
    """Whether self and other shard a rank-`ndim` value identically."""
    # Two abstract meshes carry no devices, so skip the device comparison.
    if (isinstance(self.mesh, mesh_lib.AbstractMesh) and
        isinstance(other.mesh, mesh_lib.AbstractMesh)):
      return jsharding.common_is_equivalent_to(self, other, ndim,
                                               check_devices=False)
    return jsharding.common_is_equivalent_to(self, other, ndim)

  def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
    # Delegates to the cached module-level converter.
    return named_sharding_to_xla_hlo_sharding(self, num_dimensions)

  def _to_sdy_sharding(self, num_dimensions: int) -> SdyArray:
    """Converts this sharding to a Shardy `SdyArray` of rank `num_dimensions`."""
    # Start fully closed and empty; fill in per-dimension info from the spec.
    dim_shardings = [SdyDim(axes=[], is_open=False)
                     for _ in range(num_dimensions)]
    for i, dim_spec in enumerate(self.spec):
      if dim_spec is PartitionSpec.UNCONSTRAINED:
        dim_shardings[i].is_open = True
      elif dim_spec is None:
        # Already empty and closed sharding.
        pass
      else:
        dim_spec = dim_spec if isinstance(dim_spec, tuple) else (dim_spec,)
        dim_shardings[i].axes = dim_spec
    return SdyArray(mesh_shape=self.mesh.shape_tuple,
                    dim_shardings=dim_shardings,
                    logical_device_ids=self._logical_device_ids,
                    unreduced_axes=self.spec.unreduced)


# Present the class as part of the public jax.sharding namespace.
NamedSharding.__module__ = 'jax.sharding'
|
||||
|
||||
def flatten_spec(spec):
  """Flattens `spec` one level: tuple entries are expanded in place.

  Non-tuple entries (including None) are kept as-is. Returns a list.
  """
  flat = []
  for entry in spec:
    flat.extend(entry if isinstance(entry, tuple) else (entry,))
  return flat
|
||||
|
||||
def get_array_mapping(axis_resources):
  """Maps each mesh-axis name in `axis_resources` to its positional dimension.

  AUTO and UnspecifiedValue inputs are returned unchanged. None and
  UNCONSTRAINED entries contribute nothing. Insertion order follows the
  spec's left-to-right, major-to-minor order.
  """
  if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
    return axis_resources
  mapping = collections.OrderedDict()
  for dim, entry in enumerate(axis_resources):
    if entry is None or entry is PartitionSpec.UNCONSTRAINED:
      continue
    for axis in (entry if isinstance(entry, tuple) else (entry,)):
      mapping[axis] = dim
  return mapping
|
||||
|
||||
@dataclasses.dataclass
class SdyDim:
  """Sharding of one array dimension: mesh axes plus an open/closed flag."""
  axes: Sequence[str]
  is_open: bool

  def build(self) -> sdy.DimensionShardingAttr:
    """Builds the Shardy MLIR dimension-sharding attribute."""
    axis_refs = [sdy.AxisRefAttr.get(name) for name in self.axes]
    return sdy.DimensionShardingAttr.get(axis_refs, is_closed=not self.is_open)

  def __repr__(self):
    return f'SdyDim({self._custom_repr()})'

  def _custom_repr(self):
    # Renders e.g. {'x', 'y'}, {'x', ?} or {?}; '?' marks an open dimension.
    parts = [f"'{name}'" for name in self.axes]
    if self.is_open:
      parts.append('?')
    return '{' + ', '.join(parts) + '}'
|
||||
|
||||
def _get_axes(axes, mesh_shape):
|
||||
if not axes:
|
||||
return ()
|
||||
assert mesh_shape is not None
|
||||
# Sort wrt mesh axis names so order is deterministic and doesn't hang in
|
||||
# McJAX.
|
||||
return tuple(n for n, _ in mesh_shape if n in axes)
|
||||
|
||||
@dataclasses.dataclass(kw_only=True)
|
||||
class SdyArray:
|
||||
mesh_shape: tuple[tuple[str, int], ...] | None
|
||||
dim_shardings: Sequence[SdyDim]
|
||||
logical_device_ids: tuple[int, ...] | None = None
|
||||
replicated_axes: tuple[str, ...] = ()
|
||||
unreduced_axes: frozenset[str] = frozenset()
|
||||
|
||||
def build(self) -> sdy.TensorShardingAttr:
|
||||
if self.mesh_shape is None:
|
||||
mesh_attr = sdy.MeshAttr.get([])
|
||||
else:
|
||||
ldi = ([] if self.logical_device_ids is None else
|
||||
list(self.logical_device_ids))
|
||||
mesh_attr = sdy.MeshAttr.get(
|
||||
[sdy.MeshAxisAttr.get(name, size) for name, size in self.mesh_shape],
|
||||
ldi)
|
||||
|
||||
replicated_axes = _get_axes(self.replicated_axes, self.mesh_shape)
|
||||
unreduced_axes = _get_axes(self.unreduced_axes, self.mesh_shape)
|
||||
return sdy.TensorShardingAttr.get(
|
||||
mesh_attr,
|
||||
[dim_sharding.build() for dim_sharding in self.dim_shardings],
|
||||
replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in replicated_axes],
|
||||
unreduced_axes=[sdy.AxisRefAttr.get(axis) for axis in unreduced_axes])
|
||||
|
||||
def __repr__(self):
|
||||
dim_sharding_repr = ', '.join(
|
||||
d._custom_repr() for d in self.dim_shardings)
|
||||
device_id_repr = (f', device_ids={self.logical_device_ids}'
|
||||
if self.logical_device_ids is not None else '')
|
||||
rar = (f', replicated_axes={self.replicated_axes}'
|
||||
if self.replicated_axes else '')
|
||||
return f"SdyArray([{dim_sharding_repr}]{device_id_repr}{rar})"
|
||||
|
||||
|
||||
# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArray, mesh):
  """Opens every dim sharding when the mesh has any Auto axes; otherwise no-op.

  Unused Explicit axes are recorded as replicated so they stay pinned.
  """
  if not mesh._any_axis_auto:
    return sdy_sharding
  open_dims = []
  consumed: list[str] = []
  for dim in sdy_sharding.dim_shardings:
    open_dims.append(SdyDim(axes=dim.axes, is_open=True))
    consumed.extend(dim.axes)
  leftover = set(mesh.axis_names) - set(consumed)
  replicated = tuple(a for a in leftover
                     if mesh._name_to_type[a] == mesh_lib.AxisType.Explicit)
  return SdyArray(mesh_shape=sdy_sharding.mesh_shape,
                  dim_shardings=open_dims,
                  logical_device_ids=sdy_sharding.logical_device_ids,
                  replicated_axes=replicated)
|
||||
|
||||
def non_one_sized_spec(spec, mesh) -> PartitionSpec:
|
||||
new_spec: list[Any] = []
|
||||
for s in spec:
|
||||
if s is None or s is PartitionSpec.UNCONSTRAINED:
|
||||
new_spec.append(s)
|
||||
elif isinstance(s, tuple):
|
||||
new_spec.append(tuple(i for i in s if mesh.shape[i] != 1))
|
||||
else:
|
||||
new_spec.append(None if mesh.shape[s] == 1 else s)
|
||||
unreduced = frozenset(u for u in spec.unreduced if mesh.shape[u] != 1)
|
||||
reduced = frozenset(r for r in spec.reduced if mesh.shape[r] != 1)
|
||||
return PartitionSpec(*new_spec, unreduced=unreduced, reduced=reduced)
|
||||
|
||||
def get_non_one_sized_mesh_spec(mesh, spec):
  """Returns an AbstractMesh and spec with all size-1 mesh axes dropped."""
  trimmed_spec = non_one_sized_spec(spec, mesh)
  kept = [(size, name, ty) for size, name, ty in
          zip(mesh.axis_sizes, mesh.axis_names, mesh.axis_types)
          if size != 1]
  sizes, names, types = unzip3(kept)
  return mesh_lib.AbstractMesh(sizes, names, types), trimmed_spec
|
||||
|
||||
@cache(max_size=4096, trace_context_in_key=False)
def named_sharding_to_xla_hlo_sharding(
    self, num_dimensions: int) -> xc.HloSharding:
  """Converts a NamedSharding (`self`) to an `xc.HloSharding` of rank `num_dimensions`."""
  # Size-1 mesh axes shard nothing; drop them so the result is canonical.
  mesh, spec = get_non_one_sized_mesh_spec(self.mesh, self.spec)
  mesh_shape = mesh.shape
  array_mapping = get_array_mapping(spec)
  # Position of each mesh axis in the mesh's axis order.
  mesh_axis_pos = {name: i for i, name in enumerate(mesh.axis_names)}

  # Mesh-axis positions that need a non-REPLICATED subgroup type
  # (MANUAL for manual axes, UNREDUCED for unreduced axes).
  special_axes = {}
  if (manual_axes := frozenset(mesh.manual_axes)):
    axis_names = mesh.axis_names
    for manual_axis in manual_axes:
      special_axes[axis_names.index(manual_axis)] = xc.OpSharding.Type.MANUAL

  if (unreduced_axes := spec.unreduced):
    axis_names = mesh.axis_names
    for u in unreduced_axes:
      special_axes[axis_names.index(u)] = xc.OpSharding.Type.UNREDUCED

  # Mesh axes the spec does not use: the value is replicated (or
  # manual/unreduced) over them.
  replicated_mesh_axes = []
  for i, (axis_name, axis_val) in enumerate(mesh_shape.items()):
    if axis_name not in array_mapping:
      replicated_mesh_axes.append((i, axis_val))

  # Nothing is sharded and no special subgroup types: fully replicated.
  if len(replicated_mesh_axes) == len(mesh_shape) and not special_axes:
    return xc.HloSharding.replicate()

  mesh_permutation = []
  new_mesh_shape = [1] * num_dimensions
  # Process spec entries in positional-dimension order; axes mapped to the
  # same dimension multiply its tiling factor.
  for name, pos in sorted(array_mapping.items(), key=lambda x: x[1]):
    new_mesh_shape[pos] *= mesh_shape[name]
    mesh_permutation.append(mesh_axis_pos[name])

  last_tile_dims = []
  if replicated_mesh_axes:
    # Group unused mesh axes by subgroup type and collapse each group's size.
    axes_by_type: dict[Any, list[int]] = collections.defaultdict(list)
    size_by_type = collections.defaultdict(lambda: 1)
    assert {x[0] for x in replicated_mesh_axes}.issuperset(set(special_axes.keys()))
    for i, size in replicated_mesh_axes:
      ty = special_axes.get(i, xc.OpSharding.Type.REPLICATED)
      axes_by_type[ty].append(i)
      size_by_type[ty] *= size
    # Sort by enum value so the subgroup order is deterministic.
    for ty, axes in sorted(axes_by_type.items(), key=lambda x: x[0].value):
      last_tile_dims.append(ty)
      new_mesh_shape.append(size_by_type[ty])
      mesh_permutation.extend(axes)

  # Explanation of the parameters of `HloSharding.iota_tile`.
  # This is the HloShardingV2 format:
  # * dims: How many ways each dimension is sharded.
  #     Replicated/Manual dims are added at the end
  # * reshape_dims: This is the just the shape of the mesh.
  # * transpose_perm: This is the order in which mesh axes in PartitionSpec
  #     appear relative to mesh.axis_names order.
  # * subgroup_types: List of type of OpSharding. Type can be REPLICATED and MANUAL.
  # Let's see an example:
  #   Consider input_shape=(8, 4, 2, 2), mesh={'a': 2, 'b': 2, 'c': 2, 'd': 2}
  #   and partition_spec=P(None, ('d', 'b'), 'c').
  #   Arguments to iota_tile will be:
  #     dims = [1, 4, 2, 1, 2]  # 'a' is replicated hence `2` is at the end.
  #     reshape_dims = [2, 2, 2, 2]
  #     transpose_perm = [3, 1, 2, 0]  # 'a' is replicated hence 0 is at the end
  #     subgroup_types = [xc.OpSharding.Type.REPLICATED]
  dims = new_mesh_shape
  reshape_dims = mesh.axis_sizes
  if self._logical_device_ids is None:
    return xc.HloSharding.iota_tile(
        dims=dims, reshape_dims=reshape_dims, transpose_perm=mesh_permutation,
        subgroup_types=last_tile_dims)
  else:
    # Explicit device ordering: materialize the device ids and permute them.
    # NOTE(review): the initial .reshape(dims) is immediately replaced by
    # .reshape(reshape_dims); presumably it validates the element count —
    # confirm before simplifying.
    return xc.HloSharding.subgroup_with_device_ordering(
        np.asarray(self._logical_device_ids)
        .reshape(dims).reshape(reshape_dims).transpose(mesh_permutation)
        .reshape(dims), subgroup_types=last_tile_dims)
|
||||
|
||||
|
||||
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
  """Inverts an ArrayMapping back into a PartitionSpec.

  Mesh axes sharing a positional dimension become a tuple entry (in mapping
  order); dimensions with no axes become None.
  """
  if not array_mapping:
    return PartitionSpec()
  index_to_axes = collections.defaultdict(list)
  for axis_name, dim in array_mapping.items():
    index_to_axes[dim].append(axis_name)
  highest = max(index_to_axes)
  entries: list[MeshAxisName | None] = []
  for dim in range(highest + 1):
    names = index_to_axes.get(dim)
    if not names:
      entries.append(None)
    elif len(names) == 1:
      entries.append(names[0])
    else:
      entries.append(tuple(names))
  return PartitionSpec(*entries)
|
||||
|
||||
|
||||
@cache(max_size=128, trace_context_in_key=False)
def check_pspec(mesh, spec, _manual_axes=frozenset()):
  """Validates `spec` against `mesh`; raises on duplicate, unknown or invalid axes.

  NOTE(review): `_manual_axes` is accepted but unused in the body — possibly
  kept only as part of the cache key; confirm before removing.
  """
  _check_unique_resources(spec, "NamedSharding spec", mesh)
  _check_mesh_resource_axis(mesh, spec)
  _check_mesh_unreduced(mesh, spec)
|
||||
|
||||
class DuplicateSpecError(Exception):
  """Raised when a PartitionSpec maps one mesh axis to multiple dimensions.

  Carries the offending mesh and pspec so callers can report context.
  """

  def __init__(self, message, mesh, pspec):
    super().__init__(message)
    self.message = message
    self.mesh = mesh
    self.pspec = pspec

  def __str__(self):
    return f"{self.message}"
|
||||
|
||||
def _check_unique_resources(pspec: PartitionSpec, arg_name: str, mesh=None
                            ) -> None:
  """Raises DuplicateSpecError if any mesh axis occurs more than once in `pspec`."""
  seen: dict[MeshAxisName, int] = {}
  for entry in pspec:
    if entry is None or entry is PartitionSpec.UNCONSTRAINED:
      continue
    axes = entry if isinstance(entry, tuple) else (entry,)
    for axis in axes:
      seen[axis] = seen.get(axis, 0) + 1
  repeated = [axis for axis, count in seen.items() if count > 1]
  if repeated:
    raise DuplicateSpecError(
        message=(
            f'A single {arg_name} specification can map every mesh axis to at'
            f' most one positional dimension, but {pspec} has duplicate entries'
            f' for {mesh_lib.show_axes(repeated)}'),
        mesh=mesh, pspec=pspec)
|
||||
|
||||
def _check_mesh_resource_axis(mesh, pspec):
|
||||
for p in pspec:
|
||||
if p is PartitionSpec.UNCONSTRAINED or p is None:
|
||||
continue
|
||||
p = p if isinstance(p, tuple) else (p,)
|
||||
for r in p:
|
||||
if r not in mesh.axis_names:
|
||||
raise ValueError(
|
||||
f"Resource axis: {r} of {pspec} "
|
||||
f"is not found in mesh: {tuple(mesh.shape.keys())}.")
|
||||
if (AxisType.Auto not in mesh.axis_types and
|
||||
PartitionSpec.UNCONSTRAINED in pspec):
|
||||
raise ValueError(
|
||||
f'{pspec} cannot contain'
|
||||
' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
|
||||
f' axis_types: {mesh.axis_types}')
|
||||
|
||||
def _check_mesh_unreduced(mesh, pspec):
|
||||
for u in pspec.unreduced:
|
||||
if u not in mesh.axis_names:
|
||||
raise ValueError(
|
||||
f'Unreduced axes {u} is not found in {mesh.axis_names=}. '
|
||||
f'Got {pspec=}')
|
||||
if mesh._name_to_type[u] in (AxisType.Auto, AxisType.Manual):
|
||||
raise ValueError(
|
||||
'Unreduced axes can only refer to mesh axes that is of type'
|
||||
f' `Explicit`. Got unreduced axes: {pspec.unreduced} and'
|
||||
f' mesh: {mesh}')
|
||||
|
||||
for u in pspec.reduced:
|
||||
if u not in mesh.axis_names:
|
||||
raise ValueError(
|
||||
f'Reduced axes {u} is not found in {mesh.axis_names=}. '
|
||||
f'Got {pspec=}')
|
||||
if mesh._name_to_type[u] in (AxisType.Auto, AxisType.Manual):
|
||||
raise ValueError(
|
||||
'Reduced axes can only refer to mesh axes that is of type'
|
||||
f' `Explicit`. Got reduced axes: {pspec.reduced} and'
|
||||
f' mesh: {mesh}')
|
||||
Reference in New Issue
Block a user