# Copyright 2023 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains shared logic and abstractions for Pallas indexing ops.""" from __future__ import annotations import dataclasses import math import operator from typing import cast, Any, ClassVar, Union from jax._src import core from jax._src import pretty_printer as pp from jax._src import tree_util from jax._src.indexing import Slice, dslice, ds # noqa: F401 from jax._src.state import types as state_types # pytype: disable=import-error from jax._src.typing import Array from jax._src.util import merge_lists from jax._src.util import partition_list import numpy as np def _pp_slice(context: core.JaxprPpContext, dim, slc: Slice) -> str: start, size = slc.start, slc.size if isinstance(start, core.Var): start_str = core.pp_var(start, context) size_str = ( core.pp_var(size, context) if isinstance(size, core.Var) else str(size) ) return f"{start_str}:{start_str}+{size_str}" else: start_str = str(start) if start == 0: start_str = "" if isinstance(size, core.Var): size_str = core.pp_var(size, context) if start_str: return f"{start_str}:{start_str}+{size_str}" else: return f":{size_str}" else: _val = lambda x: x.val if isinstance(x, core.Literal) else x end = _val(start) + _val(size) end_str = "" if end == dim else str(end) return f"{start_str}:{end_str}" IntIndexer = Union[int, Array] DimIndexer = Union[IntIndexer, Slice] def unpack_ndindexer(indexer: NDIndexer) -> tuple[tuple[bool, ...], tuple[Slice, ...], tuple[IntIndexer, ...]]: # TODO(slebedev): Flip this to be ``is_slice_indexing`` and update callers. is_int_indexing = [not isinstance(i, Slice) for i in indexer.indices] slice_indexers, int_indexers = partition_list( is_int_indexing, indexer.indices) return tuple(is_int_indexing), tuple(slice_indexers), tuple(int_indexers) # pyrefly: ignore[bad-argument-type] def _maybe_concretize(x: Any): # This is roughly the same logic as core.concrete_or_error, but we avoid # calling that because constructing the ConcretizationTypeError can be # expensive as the size of the tracing context (i.e. the jaxpr) grows. return core.to_concrete_value(x) # This registry is used to allow hitypes that are being indexed to register # type transformation rules. indexer_transform_type_registry: set[type] = set() @tree_util.register_pytree_node_class @dataclasses.dataclass class NDIndexer(state_types.Transform): indices: tuple[DimIndexer, ...] shape: tuple[int, ...] int_indexer_shape: tuple[int | Array, ...] # Off by default to avoid doing validation during pytree operations. validate: bool = False def __post_init__(self): if len(self.indices) != len(self.shape): raise ValueError( f"`indices` must be the same length as `Ref` shape.: {self}." ) if not self.validate: return # We validate integer indexing shapes here for idx, s in zip(self.indices, self.shape): if isinstance(idx, Slice): start = idx.start if value := _maybe_concretize(start): if value >= s: raise ValueError(f"Out of bound slice: start={value}, dim={s}.") if size := _maybe_concretize(idx.size): if value + (size - 1) * idx.stride >= s: raise ValueError( f"Out of bound slice: start={value}, size={size}," f" stride={idx.stride}, dim={s}." ) continue # The shape of indexer integers should be broadcastable up to the # int_indexer_shape of the whole NDIndexer idx_shape = ( idx.shape if isinstance(idx, state_types.TransformedRef) else core.typeof(idx).shape ) if not idx_shape: if (value := _maybe_concretize(idx)) and value >= s: raise ValueError(f"Out of bound indexer: idx={value}, dim={s}.") # For ()-shaped indexers, we can broadcast no problm. continue # If we don't have a ()-shaped indexer, the rank must match # int_indexer_shape if len(idx_shape) != len(self.int_indexer_shape): raise ValueError( f"Indexer must have rank {len(idx_shape)}: {idx=} vs." f" {self.int_indexer_shape=}" ) # Here we check that the shapes broadcast. try: np.broadcast_shapes(idx_shape, self.int_indexer_shape) except ValueError as e: raise ValueError( f"Could not broadcast integer indexer: {idx=} vs." f" {self.int_indexer_shape=}" ) from e @property def is_dynamic_size(self): return any(isinstance(i, Slice) and i.is_dynamic_size for i in self.indices) def tree_flatten(self): flat_idx, idx_tree = tree_util.tree_flatten(self.indices) if not all(isinstance(i, int) for i in self.int_indexer_shape): return (*flat_idx, self.int_indexer_shape), (idx_tree, self.shape) else: return flat_idx, (idx_tree, self.shape, self.int_indexer_shape) @classmethod def tree_unflatten(cls, data, flat_idx): if len(data) == 3: idx_tree, shape, int_indexer_shape = data else: # The ``int_indexer_shape`` is dynamic. idx_tree, shape = data *flat_idx, int_indexer_shape = flat_idx indices = tree_util.tree_unflatten(idx_tree, flat_idx) return cls(tuple(indices), shape, int_indexer_shape) @classmethod def from_indices_shape(cls, indices, shape) -> NDIndexer: if not isinstance(indices, tuple): # TODO(slebedev): Consider requiring `indices` to be a Sequence. indices = (indices,) if num_ellipsis := sum(idx is ... for idx in indices): if num_ellipsis > 1: raise ValueError("Only one ellipsis is supported.") # Expand ... so that `indices` has the same length as `shape`. ip = next(i for i, idx in enumerate(indices) if idx is ...) indices = list(indices) indices[ip:ip+1] = [slice(None)] * (len(shape) - len(indices) + 1) indices = tuple(indices) if len(indices) > len(shape): raise ValueError("`indices` must not be longer than `shape`: " f"{indices=}, {shape=}") elif len(indices) < len(shape): # Pad `indices` to have the same length as `shape`. indices = (*indices, *[slice(None)] * (len(shape) - len(indices))) # Promote all builtin `slice`s to `Slice`. indices = tuple( Slice.from_slice(i, s) if isinstance(i, slice) else i for i, s in zip(indices, shape)) is_slice_indexing = [isinstance(i, Slice) for i in indices] if all(is_slice_indexing): return cls(indices, shape, (), validate=True) other_indexers, slice_indexers = partition_list(is_slice_indexing, indices) validate = True # We treat refs differently from scalars and arrays, because refs can have # a dynamic shape, making it impossible to statically determine the # broadcasted shape in the presence of other non-slice indexers. from jax._src.state import types as state_types # pytype: disable=import-error if ref_indexers := [ i for i in other_indexers if not isinstance(i, Slice) if isinstance(i, state_types.TransformedRef) or isinstance(core.typeof(i), state_types.AbstractRef) ]: # TODO(slebedev): Consider pushing these checks to lowering time. if len(ref_indexers) > 1: raise NotImplementedError("Multiple Ref indexers are not supported") if len(ref_indexers) != len(other_indexers): raise NotImplementedError( "Ref cannot be mixed with other non-slice indexers" ) [ref_indexer] = ref_indexers indexer_shape = ref_indexer.shape try: core.canonicalize_shape(indexer_shape) except TypeError: validate = False # The shape is dynamic. else: indexer_shapes = [core.typeof(i).shape for i in other_indexers] try: indexer_shape = np.broadcast_shapes(*indexer_shapes) except ValueError as e: # Raise a nicer error than the NumPy one. raise ValueError( "Cannot broadcast shapes for indexing: {indexer_shapes}" ) from e # Here we use the `broadcast_to` primitive instead of composing lax # primitives together because it is easier to lower in targets like # Triton/Mosaic. # # The local import avoids a circular dependency between primitives # and this module. from jax._src.state import primitives as sp # pytype: disable=import-error other_indexers = [ sp.broadcast_to(i, indexer_shape) for i in other_indexers # pyrefly: ignore[bad-argument-type] ] indices = tuple( merge_lists(is_slice_indexing, other_indexers, slice_indexers) ) return cls(indices, shape, indexer_shape, validate) @classmethod def make_trivial_indexer(cls, shape: tuple[int, ...]) -> NDIndexer: return NDIndexer.from_indices_shape( tuple(slice(0, e) for e in shape), shape, ) def get_indexer_shape(self) -> tuple[int | Array, ...]: is_int_indexing, slice_indexers, _ = unpack_ndindexer(self) slice_shape = tuple(s.size for s in slice_indexers) int_indexers_contiguous = bool( np.all(np.diff(np.where(is_int_indexing)[0]) == 1) ) if not int_indexers_contiguous: return self.int_indexer_shape + slice_shape has_int_indexers = any(is_int_indexing) if has_int_indexers: pos = is_int_indexing.index(True) return slice_shape[:pos] + self.int_indexer_shape + slice_shape[pos:] return slice_shape def get_indexer_shape_static(self) -> tuple[int, ...]: indexer_shape = self.get_indexer_shape() if any(not isinstance(d, int) for d in indexer_shape): raise ValueError("Indexer shape is not static") return cast(tuple[int, ...], indexer_shape) def transform_type(self, x: core.AbstractValue): match x: case state_types.AbstractRef(): return x.update(inner_aval=self.transform_type(x.inner_aval)) case core.ShapedArray(): self._validate_sharding(x.sharding) if self.is_dynamic_size: return DShapedArray(self.get_indexer_shape(), x.dtype, weak_type=x.weak_type) return x.update(shape=self.get_indexer_shape()) case _: if type(x) in indexer_transform_type_registry: assert hasattr(x, "transform_ndindexer") return x.transform_ndindexer(self) raise TypeError(f"Cannot transform type: {x}") def undo(self, x: core.AbstractValue): raise NotImplementedError def _validate_sharding(self, sharding): if all(p is None for p in sharding.spec): return # If there are explicit axes, we don't support changing the shape, so # we don't support int indexers and instead require all slices. if self.int_indexer_shape or not all( isinstance(idx, Slice) for idx in self.indices ): raise TypeError( "sharded ref (array reference) can only be indexed by " "slices, not integers" ) # Moreover, only allow trivial slice(None) slices on explicitly sharded # axes. Then the sharding stays the same. _, slice_indexers, _ = unpack_ndindexer(self) for i, (d, sl, s) in enumerate( zip(self.shape, slice_indexers, sharding.spec) ): if s is None: continue if not ( type(sl.start) is int and sl.start == 0 and type(sl.size) is int and sl.size == d and type(sl.stride) is int and sl.stride == 1 ): raise ValueError( "sharded ref (array reference) can only be sliced " f"along unsharded axes, but ref of shape {self.shape} " f"was sliced on axis {i}, which is sharded like {s}" ) def pretty_print(self, context: core.JaxprPpContext) -> pp.Doc: indices = [] for idx, dim in zip(self.indices, self.shape): if isinstance(idx, Slice): indices.append(_pp_slice(context, dim, idx)) else: indices.append(core.pp_var(idx, context, print_literal_dtype=False)) # pyrefly: ignore[bad-argument-type] return pp.concat([pp.text("["), pp.text(",".join(indices)), pp.text("]")]) class DShapedArray: def __init__(self, shape, dtype, weak_type=False): self.shape = shape self.dtype = core._dtype_object(dtype) self.weak_type = weak_type def lower_val(self, val): return [val] def raise_val(self, val): return val def lo_ty(self): return [self] def update(self, shape=None, dtype=None, weak_type=None): if shape is None: shape = self.shape if dtype is None: dtype = self.dtype if weak_type is None: weak_type = self.weak_type return DShapedArray(shape, dtype, weak_type) ndim = property(lambda self: len(self.shape)) size = property(lambda self: 0 if any(type(d) is int and d == 0 for d in self.shape) else math.prod(self.shape)) broadcast: ClassVar[core.aval_method | None] = None transpose: ClassVar[core.aval_method | None] = None reshape: ClassVar[core.aval_method | None] = None _iter: ClassVar[staticmethod | None] = None def __eq__(self, other): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type) def __hash__(self): return hash((self.shape, self.dtype, self.weak_type)) def __ne__(self, other): return not self == other def __repr__(self): wt_str = ", weak_type=True" if self.weak_type else "" return f'DShapedArray({self.str_short()}{wt_str})' def __str__(self): wt_str = "~" if self.weak_type else "" return f'{wt_str}{self.str_short()}' def str_short(self): return (f"DShapedArray(shape={self.shape}, dtype={self.dtype}," f" weak_type={self.weak_type})") def _len(self, _): try: return self.shape[0] except IndexError as err: raise TypeError("len() of unsized object") from err # same as numpy error def update_weak_type(self, weak_type): return self.update(weak_type=weak_type) _bool = core.concretization_function_error(bool) _int = core.concretization_function_error(int, True) _float = core.concretization_function_error(float, True) _complex = core.concretization_function_error(complex, True) _hex = core.concretization_function_error(hex) _oct = core.concretization_function_error(oct) _index = core.concretization_function_error(operator.index)