# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Indexing code for jax.numpy.""" from __future__ import annotations import dataclasses import enum from functools import partial import operator import string from typing import Any, NamedTuple from collections.abc import Sequence import numpy as np from jax._src import api from jax._src import array from jax._src import config from jax._src import core from jax._src import dtypes from jax._src import errors from jax._src import indexing from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lax import utils as lax_utils from jax._src.numpy import array_constructors from jax._src.numpy import einsum from jax._src.numpy import error as jnp_error from jax._src.numpy import lax_numpy from jax._src.numpy import ufuncs from jax._src.numpy import util from jax._src.partition_spec import PartitionSpec from jax._src.pjit import auto_axes from jax._src.sharding_impls import canonicalize_sharding, NamedSharding from jax._src.tree_util import tree_flatten, tree_unflatten, register_pytree_node_class from jax._src.typing import Array, ArrayLike, Index, StaticScalar from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update, unzip3 export = set_module('jax.numpy') # Internal utilities for parsing and validating NumPy-style indices. 
class IndexType(enum.Enum):
  """Classification of a single entry of a NumPy-style index expression."""
  NONE = "none"
  SLICE = "slice"
  ELLIPSIS = "ellipsis"
  INTEGER = "integer"
  BOOLEAN = "boolean"
  ARRAY = "array"
  DYNAMIC_SLICE = "dynamic_slice"

  @classmethod
  def from_index(cls, idx: Index) -> IndexType:
    """Classify ``idx`` as one of the supported index kinds.

    The checks are ordered: the first matching category wins.

    Args:
      idx: a single user-supplied index entry.

    Returns:
      The :class:`IndexType` member describing ``idx``.

    Raises:
      TypeError: if ``idx`` is a string, or is an array, sequence, or scalar
        whose dtype is neither integer nor boolean.
      IndexError: if ``idx`` is of an entirely unsupported kind.
    """
    # Singletons and slice-like objects.
    if idx is None:
      return cls.NONE
    if idx is Ellipsis:
      return cls.ELLIPSIS
    if isinstance(idx, slice):
      return cls.SLICE
    if isinstance(idx, indexing.Slice):
      return cls.DYNAMIC_SLICE

    # Integer and boolean indices, as recognized by the module-level helpers.
    if _is_integer_index(idx):
      return cls.INTEGER
    if _is_boolean_index(idx):
      return cls.BOOLEAN

    # Remaining arrays must have an integer dtype.
    if isinstance(idx, (Array, np.ndarray)):
      if not dtypes.issubdtype(idx.dtype, np.integer):
        raise TypeError(
            f"Indexer must have integer or boolean type, got indexer with type {idx.dtype}")
      return cls.ARRAY

    # Strings are Sequences too, so they must be rejected before the Sequence
    # branch below.
    if isinstance(idx, str):
      # TODO(jakevdp): this TypeError is for backward compatibility.
      # We should switch to IndexError for consistency.
      raise TypeError(f"JAX does not support string indexing; got {idx=}")

    if isinstance(idx, Sequence):
      if not idx:
        # empty indices default to float, so special-case this.
        return cls.ARRAY
      # Determine the dtype the sequence would have as an array, without
      # materializing it.
      aval = api.eval_shape(array_constructors.asarray, idx)
      if aval.dtype == bool:
        return cls.BOOLEAN
      if dtypes.issubdtype(aval.dtype, np.integer):
        return cls.ARRAY
      raise TypeError(
          f"Indexer must have integer or boolean type, got indexer with type {aval.dtype}")

    if isinstance(idx, (float, complex, np.generic)):
      raise TypeError(
          f"Indexer must have integer or boolean type, got indexer with type {np.dtype(type(idx))}")

    raise IndexError("only integers, slices (`:`), ellipsis (`...`), newaxis (`None`)"
                     f" and integer or boolean arrays are valid indices. Got {idx}")


class ParsedIndex(NamedTuple):
  """A single index entry together with its classification and the axes it uses."""
  # The index entry itself.
  index: Index
  # The kind of index, as determined by IndexType.from_index.
  typ: IndexType
  # Axes of the indexed array that this entry consumes (possibly empty).
  consumed_axes: tuple[int, ...]
def _parse_indices( indices: tuple[Index, ...], shape: tuple[int, ...], ) -> list[ParsedIndex]: """Parse indices in the context of an array shape. Args: indices: a tuple of user-supplied indices to be parsed. shape: the shape of the array being indexed. Returns: The list of parsed indices stored in :class:`ParsedIndex` objects. This list will have the same length as ``indices``. Raises: IndexError: if any unrecognized index types are present or if there are too many indices, or too many ellipses. """ # 1. go through indices to count the number of consumed dimensions. # This is required to determine the effect of any ellipses. dimensions_consumed: list[int] = [] ellipses_indices: list[int] = [] index_types: list[IndexType] = [] for i, idx in enumerate(indices): typ = IndexType.from_index(idx) index_types.append(typ) if typ == IndexType.NONE: dimensions_consumed.append(0) elif typ == IndexType.ELLIPSIS: # We don't yet know how many dimensions are consumed, so set to zero # for now and update later. dimensions_consumed.append(0) ellipses_indices.append(i) elif typ == IndexType.BOOLEAN: dimensions_consumed.append(np.ndim(idx)) # pyrefly: ignore[bad-argument-type] elif typ in [IndexType.INTEGER, IndexType.ARRAY, IndexType.SLICE, IndexType.DYNAMIC_SLICE]: dimensions_consumed.append(1) else: raise IndexError(f"Unrecognized index type: {typ}") # 2. Validate the consumed dimensions and ellipses. if len(ellipses_indices) > 1: raise IndexError("an index can only have a single ellipsis ('...')") total_consumed = sum(dimensions_consumed) if total_consumed > len(shape): raise IndexError(f"Too many indices: array is {len(shape)}-dimensional," f" but {total_consumed} were indexed") if ellipses_indices: dimensions_consumed[ellipses_indices[0]] = len(shape) - total_consumed # 3. Generate the final sequence of parsed indices. 
result: list[ParsedIndex] = [] current_dim = 0 for index, typ, n_consumed in safe_zip(indices, index_types, dimensions_consumed): consumed_axes = tuple(range(current_dim, current_dim + n_consumed)) current_dim += len(consumed_axes) result.append(ParsedIndex(index=index, typ=typ, consumed_axes=consumed_axes)) return result @register_pytree_node_class @dataclasses.dataclass(frozen=True, kw_only=True) class NDIndexer: """Object that implements NumPy-style indexing operations on top of JAX. Generally this will be constructed via the :meth:`NDIndexer.from_raw_indices` method. Attributes: shape: the shape of the array being indexed. indices: a list of :class:`ParsedIndex` objects. """ shape: tuple[int, ...] indices: list[ParsedIndex] @classmethod def from_raw_indices(cls, indices: Index | tuple[Index, ...], shape: tuple[int, ...]) -> NDIndexer: """Create an NDIndexer object from raw user-supplied indices.""" indices = eliminate_deprecated_list_indexing(indices) parsed = _parse_indices(indices, shape) return cls(shape=shape, indices=parsed) def validate_static_indices(self, normalize_indices: bool = True) -> None: """Check that all static integer indices are in-bounds. Raises an IndexError in case of out-of-bound indices """ for idx in self.indices: if idx.typ == IndexType.INTEGER: assert isinstance(idx.index, (int, np.integer)) i = operator.index(idx.index) axis, = idx.consumed_axes size = self.shape[axis] normed_idx = i + size if normalize_indices and i < 0 else i if not 0 <= normed_idx < size: raise IndexError(f"index {i} out of bounds for axis {axis} with size {size}" f" ({normalize_indices=})") def validate_slices(self) -> None: """Check that all slices have static start/stop/step values. Raises an IndexError in case of non-static entries. 
""" for position, idx in enumerate(self.indices): if idx.typ == IndexType.SLICE: assert isinstance(idx.index, slice) if not all(_is_slice_element_none_or_constant_or_symbolic(val) for val in [idx.index.start, idx.index.stop, idx.index.step]): raise IndexError("Slice entries must be static integers." f" Got {idx.index} at position {position}") @staticmethod def is_sharded(arr) -> bool: """Check whether the array is sharded.""" return isinstance(arr, array.ArrayImpl) and not arr.sharding.num_devices == 1 def has_partial_slices(self) -> bool: """Check whether the indexer contains partial slices. For sharded arrays, partial slices cannot automatically propagate sharding. """ for idx in self.indices: if idx.typ in [IndexType.INTEGER, IndexType.DYNAMIC_SLICE]: return True if idx.typ == IndexType.SLICE: slc = idx.index assert isinstance(slc, slice) axis, = idx.consumed_axes size = self.shape[axis] start, stop, step = slc.indices(self.shape[axis]) if abs(step) != 1 or abs(stop - start) != size: return True return False def expand_bool_indices(self) -> NDIndexer: """Returns a new NDIndexer with boolean indices replaced by array indices. The only exception are scalar boolean indices, which are left in-place. 
""" expanded_indices: list[ParsedIndex] = [] for position, idx in enumerate(self.indices): if idx.typ != IndexType.BOOLEAN: expanded_indices.append(idx) continue if not core.is_concrete(idx.index): # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete raise errors.NonConcreteBooleanIndexError(core.typeof(idx.index)) assert isinstance(idx.index, (bool, np.ndarray, Array, list)) if np.ndim(idx.index) == 0: # pyrefly: ignore[bad-argument-type] # Scalar booleans assert idx.consumed_axes == () expanded_indices.append(ParsedIndex(index=bool(idx.index), typ=idx.typ, consumed_axes=())) continue idx_shape = np.shape(idx.index) # pyrefly: ignore[no-matching-overload] expected_shape = [self.shape[i] for i in idx.consumed_axes] if not all(s1 in (0, s2) for s1, s2 in zip(idx_shape, expected_shape)): raise IndexError("boolean index did not match shape of indexed array in index" f" {position}: got {idx_shape}, expected {expected_shape}") expanded_indices_raw = np.where(np.asarray(idx.index)) expanded_indices.extend(ParsedIndex(index=i, typ=IndexType.ARRAY, consumed_axes=(axis,)) for i, axis in safe_zip(expanded_indices_raw, idx.consumed_axes)) return NDIndexer(shape=self.shape, indices=expanded_indices) def expand_scalar_bool_indices(self, sharding_spec: Any = None) -> tuple[NDIndexer, Any]: new_shape = list(self.shape) new_sharding_spec = list((None for _ in self.shape) if sharding_spec is None else sharding_spec) new_indices = list(self.indices) current_dim = 0 for i, idx in enumerate(self.indices): if idx.typ == IndexType.BOOLEAN and np.ndim(idx.index) == 0: # pyrefly: ignore[bad-argument-type] new_shape.insert(i, 1) new_sharding_spec.insert(i, None) new_indices[i] = ParsedIndex( np.arange(int(idx.index)), typ=IndexType.ARRAY, consumed_axes=(current_dim,)) # pyrefly: ignore[bad-argument-type] current_dim += 1 else: n_consumed = len(idx.consumed_axes) new_indices[i] = ParsedIndex( index=idx.index, typ=idx.typ, consumed_axes = tuple(range(current_dim, 
current_dim + n_consumed)) ) current_dim += n_consumed new_sharding_spec = None if sharding_spec is None else tuple(new_sharding_spec) return NDIndexer(indices=new_indices, shape=tuple(new_shape)), new_sharding_spec def convert_sequences_to_arrays(self) -> NDIndexer: new_indices = [ParsedIndex(lax_numpy.asarray(idx.index), typ=idx.typ, consumed_axes=idx.consumed_axes) if isinstance(idx.index, Sequence) else idx for idx in self.indices] return NDIndexer(indices=new_indices, shape=self.shape) def expand_ellipses(self) -> NDIndexer: """ Returns a new indexer with ellipsis and implicit trailing slices replaced by explicit empty slices. """ expanded: list[ParsedIndex] = [] consumed = 0 for idx in self.indices: consumed += len(idx.consumed_axes) if idx.typ == IndexType.ELLIPSIS: for axis in idx.consumed_axes: expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,))) else: expanded.append(idx) for axis in range(consumed, len(self.shape)): expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,))) return NDIndexer(shape=self.shape, indices=expanded) def normalize_indices(self) -> NDIndexer: new_indices: list[ParsedIndex] = [] for idx in self.indices: if idx.typ == IndexType.INTEGER: axis, = idx.consumed_axes size: ArrayLike = self.shape[axis] if isinstance(idx.index, np.unsignedinteger): normed_index: Index = idx.index else: normed_index = idx.index + size if idx.index < 0 else idx.index # pyrefly: ignore[bad-assignment, unsupported-operation] new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes)) elif idx.typ in [IndexType.ARRAY, IndexType.INTEGER]: assert isinstance(idx.index, (Array, np.ndarray)) axis, = idx.consumed_axes if dtypes.issubdtype(idx.index.dtype, np.unsignedinteger): normed_index = idx.index else: size = self.shape[axis] if core.is_constant_dim(size): size = lax._const(idx.index, size) else: size = lax.convert_element_type(core.dimension_as_value(size), 
idx.index.dtype) normed_index = lax.select(idx.index < 0, lax.add(idx.index, size), idx.index) new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes)) else: new_indices.append(idx) return NDIndexer(indices=new_indices, shape=self.shape) def to_static_slice( self, *, arr_is_sharded: bool = False, normalize_indices: bool = True, mode: str | slicing.GatherScatterMode | None) -> _StaticSliceIndexer: """Convert to StaticSliceIndexer data structure. If this is not possible, raise a ValueError, TypeError, or IndexError. """ if mode is None: parsed_mode = slicing.GatherScatterMode.PROMISE_IN_BOUNDS else: parsed_mode = slicing.GatherScatterMode.from_any(mode) if any(core.is_symbolic_dim(s) for s in self.shape): raise ValueError("mode='slice' is not valid for polymorphic shapes.") if parsed_mode not in [ slicing.GatherScatterMode.PROMISE_IN_BOUNDS, slicing.GatherScatterMode.CLIP]: raise ValueError("static_slice requires mode='promise_in_bounds' or mode='clip'") # Validation of the unmodified user indices. if parsed_mode == slicing.GatherScatterMode.PROMISE_IN_BOUNDS: self.validate_static_indices(normalize_indices=normalize_indices) self.validate_slices() # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as # opposed to x[:]) lead to incorrect sharding semantics when computed via slice. # TODO(yashkatariya): fix slice with sharding if arr_is_sharded and self.has_partial_slices(): raise ValueError("static_slice with partial slices does not support nontrivial array sharding.") for position, pidx in enumerate(self.indices): if pidx.typ in [IndexType.INTEGER, IndexType.ELLIPSIS, IndexType.SLICE, IndexType.NONE]: pass elif pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN, IndexType.DYNAMIC_SLICE]: raise TypeError("static_slice: indices must be static scalars or slices." 
f" Got index of type {type(pidx.index)} at position {position}") else: raise TypeError(f"static_slice: unrecognized index {pidx.index} at position {position}.") # Now re-iterate to generate static slices. start_indices: list[int] = [] limit_indices: list[int] = [] strides: list[int] = [] rev_axes: list[int] = [] squeeze_axes: list[int] = [] newaxis_dims: list[int] = [] expanded = self.expand_ellipses() for pidx in expanded.indices: if pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN, IndexType.ELLIPSIS]: raise RuntimeError(f"Internal: unexpected index encountered: {pidx}") elif pidx.typ == IndexType.NONE: # Expanded axes indices are based on the rank of the array after slicing # (tracked by start_indices) and squeezing (tracked by squeeze_axes), and # expand_dims inserts dimensions in order, so we must also account for # previous expanded dimensions. newaxis_dims.append(len(start_indices) - len(squeeze_axes) + len(newaxis_dims) ) elif pidx.typ == IndexType.INTEGER: assert isinstance(pidx.index, (int, np.integer)) axis, = pidx.consumed_axes if core.definitely_equal(self.shape[axis], 0): # XLA gives error when indexing into an axis of size 0 raise IndexError(f"index is out of bounds for axis {axis} with size 0") start_index = int(pidx.index) if normalize_indices and start_index < 0: start_index += self.shape[axis] # Normalization & validation have already been handled, so clip start_index # to valid range start_index = min(max(start_index, 0), self.shape[axis] - 1) start_indices.append(start_index) limit_indices.append(start_index + 1) strides.append(1) squeeze_axes.append(axis) elif pidx.typ == IndexType.SLICE: assert isinstance(pidx.index, slice) axis, = pidx.consumed_axes size = self.shape[axis] start, stop, stride = pidx.index.indices(size) if stride < 0: new_start = min(size, stop + 1 + abs(start - stop - 1) % abs(stride)) start_indices.append(new_start) limit_indices.append(max(new_start, start + 1)) strides.append(abs(stride)) rev_axes.append(axis) else: 
start_indices.append(start) limit_indices.append(max(start, stop)) strides.append(stride) else: raise TypeError(f"static_slice: unrecognized index {pidx.index}") return _StaticSliceIndexer( start_indices=start_indices, limit_indices=limit_indices, strides=None if all(s == 1 for s in strides) else strides, rev_axes=rev_axes, squeeze_axes=squeeze_axes, newaxis_dims=newaxis_dims, ) def to_dynamic_slice( self, *, arr_is_sharded: bool = False, normalize_indices: bool = True, mode: str | slicing.GatherScatterMode | None) -> _DynamicSliceIndexer: """Convert to DynamicSliceIndexer data structure. If this is not possible, raise a ValueError, TypeError, or IndexError. """ if mode is not None: parsed_mode = slicing.GatherScatterMode.from_any(mode) if parsed_mode not in [ slicing.GatherScatterMode.PROMISE_IN_BOUNDS, slicing.GatherScatterMode.CLIP]: raise ValueError("dynamic_slice requires mode='promise_in_bounds' or mode='clip'") # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as # opposed to x[:]) lead to incorrect sharding semantics when computed via slice. # TODO(yashkatariya): fix slice with sharding if arr_is_sharded and self.has_partial_slices(): raise ValueError("dynamic_slice with partial slices does not support nontrivial array sharding.") for position, pidx in enumerate(self.indices): if pidx.typ in [IndexType.INTEGER, IndexType.ELLIPSIS, IndexType.NONE]: pass elif pidx.typ == IndexType.DYNAMIC_SLICE: assert isinstance(pidx.index, indexing.Slice) if pidx.index.stride != 1: raise TypeError("dynamic_slice: only unit steps supported in slice." f" Got {pidx.index} at position {position}") elif pidx.typ == IndexType.SLICE: assert isinstance(pidx.index, slice) if pidx.index.step is not None and pidx.index.step not in [-1, 1]: raise TypeError("dynamic_slice: only unit steps supported in slice." 
f" Got {pidx.index} at position {position}") elif pidx.typ == IndexType.ARRAY: if isinstance(pidx.index, Sequence) or np.shape(pidx.index) != (): # pyrefly: ignore[no-matching-overload] raise TypeError("dynamic_slice: only scalar indices allowed." f" Got index of type {type(pidx.index)} at position {position}") elif pidx.typ == IndexType.BOOLEAN: raise TypeError("dynamic_slice: indices must be scalars or slices." f" Got index of type {type(pidx.index)} at position {position}") else: raise TypeError(f"dynamic_slice: unrecognized index {pidx.index} at position {position}.") start_indices: list[ArrayLike] = [] slice_sizes: list[int] = [] rev_axes: list[int] = [] squeeze_axes: list[int] = [] newaxis_dims: list[int] = [] expanded = self.expand_ellipses() trivial_slicing = True for pidx in expanded.indices: if pidx.typ in [IndexType.BOOLEAN, IndexType.ELLIPSIS]: raise RuntimeError(f"Internal: unexpected index encountered: {pidx}") elif pidx.typ == IndexType.NONE: # Expanded axes indices are based on the rank of the array after slicing # (tracked by start_indices) and squeezing (tracked by squeeze_axes), and # expand_dims inserts dimensions in order, so we must also account for # previous expanded dimensions. newaxis_dims.append(len(start_indices) - len(squeeze_axes) + len(newaxis_dims)) elif pidx.typ in [IndexType.INTEGER, IndexType.ARRAY]: trivial_slicing = False index = lax_numpy.asarray(pidx.index) assert index.shape == () # Validated above. 
axis, = pidx.consumed_axes if core.definitely_equal(self.shape[axis], 0): # XLA gives error when indexing into an axis of size 0 raise IndexError(f"index is out of bounds for axis {axis} with size 0") start_indices.append(index) slice_sizes.append(1) squeeze_axes.append(axis) elif pidx.typ == IndexType.SLICE: assert isinstance(pidx.index, slice) if pidx.index != slice(None): trivial_slicing = False axis, = pidx.consumed_axes size = self.shape[axis] start, stop, stride = pidx.index.indices(size) assert stride in [-1, 1] # validated above if stride < 0: new_start = stop + 1 + abs(start - stop - 1) % abs(stride) start_indices.append(new_start) slice_sizes.append(max(0, start + 1 - new_start)) rev_axes.append(axis) else: start_indices.append(start) slice_sizes.append(max(0, stop - start)) elif pidx.typ == IndexType.DYNAMIC_SLICE: assert isinstance(pidx.index, indexing.Slice) start_indices.append(pidx.index.start) slice_sizes.append(pidx.index.size) trivial_slicing = False else: raise TypeError(f"dynamic_slice: unrecognized index {pidx.index}") if len(start_indices) > 1: # We must be careful with dtypes because dynamic_slice requires all # start indices to have matching types. 
dt = lax_utils.int_dtype_for_shape(self.shape, signed=True) start_indices = [lax.convert_element_type(i, dt) for i in start_indices] return _DynamicSliceIndexer( start_indices=start_indices, slice_sizes=slice_sizes, rev_axes=rev_axes, squeeze_axes=squeeze_axes, newaxis_dims=newaxis_dims, normalize_indices=normalize_indices, trivial_slicing=trivial_slicing, ) def is_advanced_int_indexer(self): """Returns True if idx should trigger int array indexing, False otherwise.""" # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing return any(idx.typ in [IndexType.ARRAY, IndexType.BOOLEAN] and np.ndim(idx.index) > 0 # pyrefly: ignore[bad-argument-type] for idx in self.indices) def to_gather(self, x_sharding: NamedSharding | Any, normalize_indices: bool = True) -> _GatherIndexer: return _index_to_gather(self, x_sharding=x_sharding, normalize_indices=normalize_indices) def tree_flatten(self): # split dynamic and static indices def is_dynamic(i: ParsedIndex): # These index types are non-hashable and therefore must be dynamic. return i.typ in [IndexType.ARRAY, IndexType.BOOLEAN, IndexType.DYNAMIC_SLICE] raw_dynamic_indices = [i.index if is_dynamic(i) else None for i in self.indices] static_metadata = [ ParsedIndex(index=None, typ=i.typ, consumed_axes=i.consumed_axes) if is_dynamic(i) else i for i in self.indices] return raw_dynamic_indices, (self.shape, static_metadata) @classmethod def tree_unflatten(cls, aux_data, children): shape, static_metadata = aux_data indices = [idx if dyn_index is None else ParsedIndex(dyn_index, idx.typ, idx.consumed_axes) for dyn_index, idx in safe_zip(children, static_metadata)] return cls(indices=indices, shape=shape) @export def take( a: ArrayLike, indices: ArrayLike, axis: int | None = None, out: None = None, mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, fill_value: StaticScalar | None = None, ) -> Array: """Take elements from an array. 
JAX implementation of :func:`numpy.take`, implemented in terms of :func:`jax.lax.gather`. JAX's behavior differs from NumPy in the case of out-of-bound indices; see the ``mode`` parameter below. Args: a: array from which to take values. indices: N-dimensional array of integer indices of values to take from the array. axis: the axis along which to take values. If not specified, the array will be flattened before indexing is applied. mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default ``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices; the ``fill_value`` argument gives control over this value. For more discussion of ``mode`` options, see :attr:`jax.numpy.ndarray.at`. fill_value: The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans. unique_indices: If True, the implementation will assume that the indices are unique after normalization of negative indices, which lets the compiler emit more efficient code during the backward pass. If set to True and normalized indices are not unique, the result is implementation-defined and may be non-deterministic. indices_are_sorted : If True, the implementation will assume that the indices are sorted in ascending order after normalization of negative indices, which can lead to more efficient execution on some backends. If set to True and normalized indices are not sorted, the output is implementation-defined. Returns: Array of values extracted from ``a``. See also: - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax. - :func:`jax.numpy.take_along_axis`: take values along an axis Examples: >>> x = jnp.array([[1., 2., 3.], ... 
[4., 5., 6.]]) >>> indices = jnp.array([2, 0]) Passing no axis results in indexing into the flattened array: >>> jnp.take(x, indices) Array([3., 1.], dtype=float32) >>> x.ravel()[indices] # equivalent indexing syntax Array([3., 1.], dtype=float32) Passing an axis results ind applying the index to every subarray along the axis: >>> jnp.take(x, indices, axis=1) Array([[3., 1.], [6., 4.]], dtype=float32) >>> x[:, indices] # equivalent indexing syntax Array([[3., 1.], [6., 4.]], dtype=float32) Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`: >>> jnp.take(x, indices, axis=0) Array([[nan, nan, nan], [ 1., 2., 3.]], dtype=float32) >>> x.at[indices].get(mode='fill', fill_value=jnp.nan) # equivalent indexing syntax Array([[nan, nan, nan], [ 1., 2., 3.]], dtype=float32) This default out-of-bound behavior can be adjusted using the ``mode`` parameter, for example, we can instead clip to the last valid value: >>> jnp.take(x, indices, axis=0, mode='clip') Array([[4., 5., 6.], [1., 2., 3.]], dtype=float32) >>> x.at[indices].get(mode='clip') # equivalent indexing syntax Array([[4., 5., 6.], [1., 2., 3.]], dtype=float32) """ return _take(a, indices, None if axis is None else operator.index(axis), out, mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, fill_value=fill_value) @api.jit(static_argnames=('axis', 'mode', 'unique_indices', 'indices_are_sorted', 'fill_value')) def _take(a, indices, axis: int | None = None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None): if out is not None: raise NotImplementedError("The 'out' argument to jnp.take is not supported.") a, indices = util.ensure_arraylike("take", a, indices) if axis is None: a = a.ravel() axis_idx = 0 else: axis_idx = canonicalize_axis(axis, np.ndim(a)) if mode is None or mode == "fill": gather_mode = slicing.GatherScatterMode.FILL_OR_DROP # lax.gather() does not support negative indices, so we wrap them here indices = 
util._where(indices < 0, indices + a.shape[axis_idx], indices) elif mode == "raise": # TODO(phawkins): we have no way to report out of bounds errors yet. raise NotImplementedError("The 'raise' mode to jnp.take is not supported.") elif mode == "wrap": indices = ufuncs.mod(indices, lax._const(indices, a.shape[axis_idx])) gather_mode = slicing.GatherScatterMode.PROMISE_IN_BOUNDS elif mode == "clip": gather_mode = slicing.GatherScatterMode.CLIP else: raise ValueError(f"Invalid mode '{mode}' for np.take") index_dims = len(np.shape(indices)) slice_sizes = list(np.shape(a)) if slice_sizes[axis_idx] == 0: if indices.size != 0: raise IndexError("Cannot do a non-empty jnp.take() from an empty axis.") return a if indices.size == 0: out_shape = (slice_sizes[:axis_idx] + list(indices.shape) + slice_sizes[axis_idx + 1:]) return lax.full_like(a, 0, shape=out_shape) slice_sizes[axis_idx] = 1 dnums = slicing.GatherDimensionNumbers( offset_dims=tuple( list(range(axis_idx)) + list(range(axis_idx + index_dims, len(a.shape) + index_dims - 1))), collapsed_slice_dims=(axis_idx,), start_index_map=(axis_idx,)) return slicing.gather(a, indices[..., None], dimension_numbers=dnums, slice_sizes=tuple(slice_sizes), mode=gather_mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, fill_value=fill_value) def _normalize_index(index, axis_size): """Normalizes an index value in the range [-N, N) to the range [0, N).""" if dtypes.issubdtype(dtypes.dtype(index), np.unsignedinteger): return index if core.is_constant_dim(axis_size): axis_size_val = lax._const(index, axis_size) else: axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size), dtypes.dtype(index)) if isinstance(index, (int, np.integer)): return lax.add(index, axis_size_val) if index < 0 else index else: return lax.select(index < 0, lax.add(index, axis_size_val), index) @export @api.jit(static_argnames=('axis', 'mode', 'fill_value')) def take_along_axis( arr: ArrayLike, indices: ArrayLike, axis: int | 
None = -1, mode: str | slicing.GatherScatterMode | None = None, fill_value: StaticScalar | None = None, ) -> Array: """Take elements from an array. JAX implementation of :func:`numpy.take_along_axis`, implemented in terms of :func:`jax.lax.gather`. JAX's behavior differs from NumPy in the case of out-of-bound indices; see the ``mode`` parameter below. Args: a: array from which to take values. indices: array of integer indices. If ``axis`` is ``None``, must be one-dimensional. If ``axis`` is not None, must have ``a.ndim == indices.ndim``, and ``a`` must be broadcast-compatible with ``indices`` along dimensions other than ``axis``. axis: the axis along which to take values. If not specified, the array will be flattened before indexing is applied. mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default ``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices. For more discussion of ``mode`` options, see :attr:`jax.numpy.ndarray.at`. Returns: Array of values extracted from ``a``. See also: - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax. - :func:`jax.numpy.take`: take the same indices along every axis slice. Examples: >>> x = jnp.array([[1., 2., 3.], ... [4., 5., 6.]]) >>> indices = jnp.array([[0, 2], ... [1, 0]]) >>> jnp.take_along_axis(x, indices, axis=1) Array([[1., 3.], [5., 4.]], dtype=float32) >>> x[jnp.arange(2)[:, None], indices] # equivalent via indexing syntax Array([[1., 3.], [5., 4.]], dtype=float32) Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`: >>> indices = jnp.array([[1, 0, 2]]) >>> jnp.take_along_axis(x, indices, axis=0) Array([[ 4., 2., nan]], dtype=float32) >>> x.at[indices, jnp.arange(3)].get( ... mode='fill', fill_value=jnp.nan) # equivalent via indexing syntax Array([[ 4., 2., nan]], dtype=float32) ``take_along_axis`` is helpful for extracting values from multi-dimensional argsorts and arg reductions. 
For, here we compute :func:`~jax.numpy.argsort` indices along an axis, and use ``take_along_axis`` to construct the sorted array: >>> x = jnp.array([[5, 3, 4], ... [2, 7, 6]]) >>> indices = jnp.argsort(x, axis=1) >>> indices Array([[1, 2, 0], [0, 2, 1]], dtype=int32) >>> jnp.take_along_axis(x, indices, axis=1) Array([[3, 4, 5], [2, 6, 7]], dtype=int32) Similarly, we can use :func:`~jax.numpy.argmin` with ``keepdims=True`` and use ``take_along_axis`` to extract the minimum value: >>> idx = jnp.argmin(x, axis=1, keepdims=True) >>> idx Array([[1], [0]], dtype=int32) >>> jnp.take_along_axis(x, idx, axis=1) Array([[3], [2]], dtype=int32) """ a, indices = util.ensure_arraylike("take_along_axis", arr, indices) index_dtype = indices.dtype idx_shape = np.shape(indices) if not dtypes.issubdtype(index_dtype, np.integer): raise TypeError("take_along_axis indices must be of integer type, got " f"{index_dtype}") if axis is None: if np.ndim(indices) != 1: msg = "take_along_axis indices must be 1D if axis=None, got shape {}" raise ValueError(msg.format(idx_shape)) a = a.ravel() axis = 0 rank = a.ndim if rank != np.ndim(indices): msg = "indices and arr must have the same number of dimensions; {} vs. 
{}" raise ValueError(msg.format(np.ndim(indices), a.ndim)) axis_int = canonicalize_axis(axis, rank) def replace(tup, val): lst = list(tup) lst[axis_int] = val return tuple(lst) axis_size = a.shape[axis_int] arr_shape = replace(a.shape, 1) out_shape = lax.broadcast_shapes(idx_shape, arr_shape) if axis_size == 0: return lax.full(out_shape, 0, a.dtype) index_dtype = lax_utils.int_dtype_for_dim(a.shape, signed=True) indices = lax.convert_element_type(indices, index_dtype) if mode != "promise_in_bounds": indices = _normalize_index(indices, axis_size) if mode == "one_hot": from jax import nn # pytype: disable=import-error hot = nn.one_hot(indices, axis_size, dtype=np.bool_) if a.ndim == 1: return einsum.einsum("...b,b->...", hot, a, preferred_element_type=a.dtype) if axis_int > len(string.ascii_letters) - 2: raise ValueError( "One Hot indexing is only supported for up to 50 leading dimensions." ) labels = "".join([string.ascii_letters[i] for i in range(axis_int)]) eq = labels + "y...z," + labels + "z...->" + labels + "y..." return einsum.einsum( eq, hot, a, precision=lax.Precision.HIGHEST, preferred_element_type=a.dtype, ) index_dims = [i for i, idx in enumerate(idx_shape) if i == axis_int or not core.definitely_equal(idx, 1)] gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,) gather_indices = lax.reshape(indices, gather_index_shape) slice_sizes = [] offset_dims = [] start_index_map = [] collapsed_slice_dims = [] operand_batching_dims = [] start_indices_batching_dims = [] # We will squeeze the array. i is the index of the unsqueezed shape, while # new_i is the index of the squeezed shape. j is the index of the gather # indices. dims_to_squeeze = [] new_i = 0 j = 0 for i in range(rank): if i == axis_int: slice_sizes.append(1) start_index_map.append(new_i) collapsed_slice_dims.append(new_i) new_i += 1 j += 1 elif core.definitely_equal(idx_shape[i], 1): # If idx_shape[i] == 1, we can just take the entirety of the arr's axis # and avoid forming an iota index. 
offset_dims.append(i) slice_sizes.append(arr_shape[i]) new_i += 1 elif core.definitely_equal(arr_shape[i], 1): # If the array dimension is 1 but the index dimension is not, we will # squeeze this dimension. dims_to_squeeze.append(i) j += 1 else: # Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both # array and index as batching so corresponding elements are gathered. if core.definitely_equal(arr_shape[i], 0): slice_sizes.append(0) else: slice_sizes.append(1) operand_batching_dims.append(new_i) start_indices_batching_dims.append(j) new_i += 1 j += 1 # Squeeze a to remove singleton dimensions. a = lax.squeeze(a, dims_to_squeeze) dnums = slicing.GatherDimensionNumbers( offset_dims=tuple(offset_dims), collapsed_slice_dims=tuple(collapsed_slice_dims), start_index_map=tuple(start_index_map), operand_batching_dims=tuple(operand_batching_dims), start_indices_batching_dims=tuple(start_indices_batching_dims)) return slicing.gather(a, gather_indices, dnums, tuple(slice_sizes), mode="fill" if mode is None else mode, fill_value=fill_value) def _make_along_axis_idx(shape, indices, axis): if axis < 0: axis += len(shape) return tuple_update(lax_numpy.indices(shape, sparse=True), axis, indices) @export @api.jit(static_argnames=('axis', 'inplace', 'mode')) def put_along_axis( arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis: int | None, inplace: bool = True, *, mode: str | None = None, ) -> Array: """Put values into the destination array by matching 1d index and data slices. JAX implementation of :func:`numpy.put_along_axis`. The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which is not possible for JAX's immutable arrays. The JAX version returns a modified copy of the input, and adds the ``inplace`` parameter which must be set to `False`` by the user as a reminder of this API difference. Args: arr: array into which values will be put. indices: array of indices at which to put values. 
@export
@api.jit(static_argnames=('axis', 'inplace', 'mode'))
def put_along_axis(
  arr: ArrayLike,
  indices: ArrayLike,
  values: ArrayLike,
  axis: int | None,
  inplace: bool = True,
  *,
  mode: str | None = None,
) -> Array:
  """Put values into the destination array by matching 1d index and data slices.

  JAX implementation of :func:`numpy.put_along_axis`.

  The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which
  is not possible for JAX's immutable arrays. The JAX version returns a modified
  copy of the input, and adds the ``inplace`` parameter which must be set to
  ``False`` by the user as a reminder of this API difference.

  Args:
    arr: array into which values will be put.
    indices: array of indices at which to put values.
    values: array of values to put into the array.
    axis: the axis along which to put values. If not specified, the array will
      be flattened before indexing is applied.
    inplace: must be set to False to indicate that the input is not modified
      in-place, but rather a modified copy is returned.
    mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options,
      see :attr:`jax.numpy.ndarray.at`.

  Returns:
    A copy of ``arr`` with specified entries updated.

  Raises:
    ValueError: if ``inplace`` is true, if ``arr`` and ``indices`` differ in
      number of dimensions, or if ``values`` cannot be broadcast to the shape
      of ``indices``.

  See Also:
    - :func:`jax.numpy.put`: put elements into an array at given indices.
    - :func:`jax.numpy.place`: place elements into an array via boolean mask.
    - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
    - :func:`jax.numpy.take`: extract values from an array at given indices.
    - :func:`jax.numpy.take_along_axis`: extract values from an array along an axis.

  Examples:
    >>> from jax import numpy as jnp
    >>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
    >>> i = jnp.argmax(a, axis=1, keepdims=True)
    >>> print(i)
    [[1]
     [0]]
    >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
    >>> print(b)
    [[10 99 20]
     [99 40 50]]
  """
  if inplace:
    # The trailing space matters: the two literals are concatenated verbatim.
    raise ValueError(
      "jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays "
      "are immutable. Pass inplace=False to instead return an updated array.")

  arr, indices, values = util.ensure_arraylike("put_along_axis", arr, indices, values)

  # Remember the caller's view of the array so the flattening done for
  # axis=None can be undone on the result.
  original_axis = axis
  original_arr_shape = arr.shape

  if axis is None:
    arr = arr.ravel()
    axis = 0

  if arr.ndim != indices.ndim:
    raise ValueError(
      "put_along_axis arguments 'arr' and 'indices' must have same ndim. Got "
      f"{arr.ndim=} and {indices.ndim=}."
    )

  try:
    values = util._broadcast_to(values, indices.shape)
  except ValueError as err:
    raise ValueError(
      "put_along_axis argument 'values' must be broadcastable to 'indices'. Got "
      f"{values.shape=} and {indices.shape=}."
    ) from err

  idx = _make_along_axis_idx(arr.shape, indices, axis)
  result = arr.at[idx].set(values, mode=mode)

  if original_axis is None:
    result = result.reshape(original_arr_shape)

  return result
) idx = _make_along_axis_idx(arr.shape, indices, axis) result = arr.at[idx].set(values, mode=mode) if original_axis is None: result = result.reshape(original_arr_shape) return result ### Indexing def _is_integer_index(idx: Any) -> bool: return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_)) class IndexingStrategy(enum.Enum): AUTO = 'auto' GATHER = 'gather' SCATTER = 'scatter' STATIC_SLICE = 'static_slice' DYNAMIC_SLICE = 'dynamic_slice' def rewriting_take( arr: Array, idx: Index | tuple[Index, ...], *, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | slicing.GatherScatterMode | None = None, fill_value: ArrayLike | None = None, normalize_indices: bool = True, out_sharding: NamedSharding | PartitionSpec | None = None, strategy: IndexingStrategy = IndexingStrategy.AUTO, ) -> Array: # Computes arr[idx]. # All supported cases of indexing can be implemented as an XLA gather, # followed by an optional reverse and broadcast_in_dim. indexer = NDIndexer.from_raw_indices(idx, arr.shape) if not isinstance(strategy, IndexingStrategy): raise TypeError(f"Expected strategy to be IndexingStrategy; got {strategy}") if config.check_static_indices.value and (mode is None or slicing.GatherScatterMode.from_any(mode) == slicing.GatherScatterMode.PROMISE_IN_BOUNDS): indexer.validate_static_indices(normalize_indices=normalize_indices) if strategy == IndexingStrategy.STATIC_SLICE: static_slice_indexer = indexer.to_static_slice( arr_is_sharded=indexer.is_sharded(arr), normalize_indices=normalize_indices, mode=mode) return _static_slice(arr, static_slice_indexer) if strategy == IndexingStrategy.DYNAMIC_SLICE: dynamic_slice_indexer = indexer.to_dynamic_slice( arr_is_sharded=indexer.is_sharded(arr), normalize_indices=normalize_indices, mode=mode) return _dynamic_slice(arr, dynamic_slice_indexer) if strategy == IndexingStrategy.AUTO: # Attempt static slice first try: static_slice_indexer = indexer.to_static_slice( 
arr_is_sharded=indexer.is_sharded(arr), normalize_indices=normalize_indices, mode=mode) except (TypeError, ValueError, IndexError): pass else: return _static_slice(arr, static_slice_indexer) # Attempt dynamic slice next try: dynamic_slice_indexer = indexer.to_dynamic_slice( arr_is_sharded=indexer.is_sharded(arr), normalize_indices=normalize_indices, mode=mode) except (TypeError, ValueError, IndexError): pass else: return _dynamic_slice(arr, dynamic_slice_indexer) # In remaining cases, compute via gather. indexer = indexer.expand_bool_indices() dynamic_idx, treedef = tree_flatten(indexer) internal_gather = partial( _gather, treedef=treedef, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, fill_value=fill_value, normalize_indices=normalize_indices) if out_sharding is not None: out_sharding = canonicalize_sharding(out_sharding, 'take') return auto_axes(internal_gather, out_sharding=out_sharding, axes=out_sharding.mesh.explicit_axes, )(arr, dynamic_idx) return internal_gather(arr, dynamic_idx) def _static_slice(arr: Array, indexer: _StaticSliceIndexer) -> Array: """Equivalent of arr[idx] implemented in terms of static :func:`lax.slice` operations. This supports only INTEGER, ELLIPSIS, NONE, and SLICE indices, and will raise a TypeError if other indices are present. """ if indexer.is_trivial_slice(arr.shape): result = arr else: result = slicing.slice(arr, indexer.start_indices, indexer.limit_indices, indexer.strides) if indexer.rev_axes: result = lax.rev(result, indexer.rev_axes) if indexer.squeeze_axes: result = lax.squeeze(result, indexer.squeeze_axes) if indexer.newaxis_dims: result = lax.expand_dims(result, indexer.newaxis_dims) return result def _dynamic_slice(arr: Array, indexer: _DynamicSliceIndexer) -> Array: """Equivalent of arr[idx] implemented in terms of static :func:`lax.dynamic_slice`. This supports only INTEGER, ELLIPSIS, NONE, SLICE, and scalar ARRAY indices, and will raise a TypeError if other indices are present. 
""" if indexer.trivial_slicing: result = arr else: result = slicing.dynamic_slice( arr, start_indices=indexer.start_indices, slice_sizes=indexer.slice_sizes, allow_negative_indices=indexer.normalize_indices) if indexer.rev_axes: result = lax.rev(result, indexer.rev_axes) if indexer.squeeze_axes: result = lax.squeeze(result, indexer.squeeze_axes) if indexer.newaxis_dims: result = lax.expand_dims(result, indexer.newaxis_dims) return result # TODO(phawkins): re-enable jit after fixing excessive recompilation for # slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). # @api.jit(static_argnums=(1, 2)) def _gather(arr, dynamic_idx, *, treedef, indices_are_sorted, unique_indices, mode, fill_value, normalize_indices): parsed_idx = tree_unflatten(treedef, dynamic_idx) indexer = parsed_idx.to_gather(core.typeof(arr).sharding, normalize_indices=normalize_indices) jnp_error._check_precondition_oob_gather(arr.shape, indexer.gather_indices) y = arr if fill_value is not None: core.concrete_or_error(None, fill_value, "fill_value argument to indexed get()") if np.ndim(fill_value) != 0: raise ValueError("fill_value argument to indexed get() must be a scalar") if isinstance(fill_value, np.ndarray): fill_value = fill_value.item() if indexer.scalar_bool_dims: y = lax.expand_dims(y, indexer.scalar_bool_dims) # Avoid calling gather if the slice shape is empty, both as a fast path and to # handle cases like zeros(0)[array([], int32)]. if core.is_empty_shape(indexer.slice_shape): return lax.full_like(y, 0, shape=indexer.slice_shape, sharding=indexer.slice_sharding) # We avoid generating a gather when indexer.gather_indices.size is empty. 
if not core.is_empty_shape(indexer.gather_indices.shape): y = slicing.gather( y, indexer.gather_indices, indexer.dnums, indexer.gather_slice_shape, unique_indices=unique_indices or indexer.unique_indices, indices_are_sorted=indices_are_sorted or indexer.indices_are_sorted, mode=mode, fill_value=fill_value) # Reverses axes with negative strides. if indexer.reversed_y_dims: y = lax.rev(y, indexer.reversed_y_dims) # This adds np.newaxis/None dimensions. return lax.expand_dims(y, indexer.newaxis_dims) class _StaticSliceIndexer(NamedTuple): start_indices: Sequence[int] limit_indices: Sequence[int] strides: Sequence[int] | None rev_axes: Sequence[int] squeeze_axes: Sequence[int] newaxis_dims: Sequence[int] def is_trivial_slice(self, arr_shape: Sequence[int]): if self.strides is not None or len(arr_shape) != len(self.start_indices): return False return all( (start, stop) == (0, size) for start, stop, size in zip(self.start_indices, self.limit_indices, arr_shape) ) class _DynamicSliceIndexer(NamedTuple): start_indices: Sequence[ArrayLike] slice_sizes: Sequence[int] rev_axes: Sequence[int] squeeze_axes: Sequence[int] newaxis_dims: Sequence[int] trivial_slicing: bool normalize_indices: bool class _GatherIndexer(NamedTuple): # The expected shape of the slice output. slice_shape: Sequence[int] # The slice shape to pass to lax.gather(). gather_slice_shape: Sequence[int] # The gather indices to use. gather_indices: ArrayLike # A GatherDimensionNumbers object describing the gather to perform. dnums: slicing.GatherDimensionNumbers # Are the gather_indices known to be non-overlapping and/or sorted? # (In practice, these translate to "there no advanced indices", because # only advanced indices could lead to index repetition.) unique_indices: bool indices_are_sorted: bool # Slice dimensions that have negative strides, and so must be reversed after # the gather. reversed_y_dims: Sequence[int] # Keep track of any axes created by `newaxis`. 
def _index_to_gather(indexer: NDIndexer, *, x_sharding: NamedSharding | Any,
                     normalize_indices: bool = True) -> _GatherIndexer:
  """Translate a parsed N-dimensional indexer into arguments for a gather.

  Args:
    indexer: the parsed index. Its slices are validated, sequences are
      converted to arrays, and ellipses and scalar boolean indices are
      expanded before the per-index loop runs.
    x_sharding: sharding of the array being indexed. Its ``spec`` is consulted
      and ``update(spec=...)`` is used to derive the output sharding.
    normalize_indices: if true, ``indexer.normalize_indices()`` is applied
      before the gather arguments are built.

  Returns:
    A :class:`_GatherIndexer` holding the gather indices, dimension numbers,
    slice shapes, and the pre/post-processing (scalar-bool dims, reversed
    dims, newaxis dims) needed around the gather.

  Raises:
    IndexError: if a scalar index addresses an axis of size 0, or if an index
      of an unsupported type is encountered.
    RuntimeError: if a slice-typed index holds neither a ``slice`` nor an
      ``indexing.Slice``.
  """
  indexer.validate_slices()
  indexer = indexer.convert_sequences_to_arrays()

  # Contiguity of the advanced (ARRAY/INTEGER) indices is computed before
  # ellipses are expanded; it decides where their dimensions are placed below.
  is_advanced = np.nonzero(
    np.array([idx.typ in {IndexType.ARRAY, IndexType.INTEGER} for idx in indexer.indices]))
  advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)

  indexer = indexer.expand_ellipses()

  # Positions of BOOLEAN indices still present at this point; they are
  # recorded before expand_scalar_bool_indices rewrites the indexer.
  scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(indexer.indices) if i.typ == IndexType.BOOLEAN]
  indexer, x_spec = indexer.expand_scalar_bool_indices(x_sharding.spec)

  if normalize_indices:
    indexer = indexer.normalize_indices()

  # Check for advanced indexing:
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

  # The advanced indices.
  advanced_indexes: Sequence[Array] = []

  # The positions of the advanced indexing axes in `idx`.
  idx_advanced_axes: Sequence[int] = []

  # The positions of the advanced indexes in x's shape.
  # collapsed, after None axes have been removed. See below.
  x_advanced_axes: Sequence[int] = []

  if indexer.is_advanced_int_indexer():
    idx_without_none = [(i, d) for i, d in enumerate(indexer.indices) if d.typ != IndexType.NONE]
    # Each triple is (index array, position in the index tuple, position among
    # the non-None indices).
    advanced_pairs = (
      (lax_numpy.asarray(e.index), i, j)
      for j, (i, e) in enumerate(idx_without_none)
      if e.typ in [IndexType.ARRAY, IndexType.INTEGER]
    )
    advanced_indexes, idx_advanced_axes, x_advanced_axes = unzip3(advanced_pairs)

  x_axis = 0  # Current axis in x.
  y_axis = 0  # Current axis in y, before collapsing. See below.
  collapsed_y_axis = 0  # Current axis in y, after collapsing.

  # NOTE(review): several comments below say "scatter" although this function
  # builds gather arguments; presumably the wording is inherited from a shared
  # gather/scatter formulation. Confirm before rewording.

  # Scatter dimension numbers.
  offset_dims: list[int] = []
  collapsed_slice_dims: list[int] = []
  start_index_map: list[int] = []

  index_dtype = lax_utils.int_dtype_for_shape(indexer.shape, signed=True)

  # Gather indices.
  # Pairs of (array, start_dim) values. These will be broadcast into
  # gather_indices_shape, with the array dimensions aligned to start_dim, and
  # then concatenated.
  gather_indices: list[tuple[Array, int]] = []
  gather_indices_shape: list[int] = []

  # We perform three transformations to y before the scatter op, in order:
  # First, y is broadcast to slice_shape. In general `y` only need broadcast to
  # the right shape.
  slice_shape: list[int] = []

  # Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
  # indices, which the scatter cannot remove itself.
  newaxis_dims: list[int] = []

  # Finally, we reverse reversed_y_dims to handle slices with negative strides.
  reversed_y_dims: list[int] = []

  gather_slice_shape: list[int] = []
  # Per-output-dimension sharding spec entries, built alongside slice_shape.
  slice_spec = []

  for idx_pos, index in enumerate(indexer.indices):
    # Handle the advanced indices here if:
    # * the advanced indices were not contiguous and we are the start.
    # * we are at the position of the first advanced index.
    if (advanced_indexes and
        (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
         not advanced_axes_are_contiguous and idx_pos == 0)):
      advanced_index_arrs = util._broadcast_arrays(*advanced_indexes)
      shape = advanced_index_arrs[0].shape
      aia_spec = core.typeof(advanced_index_arrs[0]).sharding.spec
      ndim = len(shape)

      start_dim = len(gather_indices_shape)
      gather_indices.extend(
          (lax.convert_element_type(a, index_dtype), start_dim)
          for a in advanced_index_arrs
      )
      gather_indices_shape += shape

      assert x_advanced_axes is not None
      start_index_map.extend(x_advanced_axes)
      collapsed_slice_dims.extend(x_advanced_axes)
      slice_shape.extend(shape)
      slice_spec.extend(aia_spec)
      y_axis += ndim
      collapsed_y_axis += ndim

    # Per-index bookkeeping for advanced indexes.
    if idx_pos in idx_advanced_axes:
      x_axis += 1
      gather_slice_shape.append(1)
      continue

    if index.typ in [IndexType.INTEGER, IndexType.ARRAY] and np.ndim(index.index) == 0:  # pyrefly: ignore[bad-argument-type]
      # Basic scalar int indices
      if core.definitely_equal(indexer.shape[x_axis], 0):
        # XLA gives error when indexing into an axis of size 0
        raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
      i_converted = lax.convert_element_type(index.index, index_dtype)  # pyrefly: ignore[bad-argument-type]
      gather_indices.append((i_converted, len(gather_indices_shape)))
      collapsed_slice_dims.append(x_axis)
      gather_slice_shape.append(1)
      start_index_map.append(x_axis)
      x_axis += 1

    elif index.typ == IndexType.NONE:
      # None indexing: add a dimension.
      slice_shape.append(1)
      slice_spec.append(None)
      newaxis_dims.append(y_axis)
      y_axis += 1

    elif index.typ in [IndexType.SLICE, IndexType.DYNAMIC_SLICE]:
      # Handle static slice index.
      if isinstance(index.index, indexing.Slice):
        start, step, slice_size = index.index.start, index.index.stride, index.index.size
      elif isinstance(index.index, slice):
        start, step, slice_size = core.canonicalize_slice(index.index, indexer.shape[x_axis])
      else:
        raise RuntimeError(f"Internal: expected slice or Slice, got {type(index.index)}")
      slice_shape.append(slice_size)
      slice_spec.append(x_spec[x_axis])

      if core.definitely_equal(step, 1):
        # Optimization: avoid generating trivial gather.
        if not core.definitely_equal(slice_size, indexer.shape[x_axis]):
          gather_indices.append((lax.convert_element_type(start, index_dtype),
                                len(gather_indices_shape)))
          start_index_map.append(x_axis)
        gather_slice_shape.append(slice_size)
        offset_dims.append(collapsed_y_axis)
      else:
        # Non-unit stride: materialize the selected positions explicitly and
        # gather them one element at a time along this axis.
        indices = (lax_numpy.array(start, dtype=index_dtype) +
                   lax_numpy.array(step, dtype=index_dtype) * lax.iota(index_dtype, slice_size))
        if step < 0:
          reversed_y_dims.append(collapsed_y_axis)
          indices = lax.rev(indices, dimensions=(0,))

        gather_slice_shape.append(1)
        gather_indices.append((indices, len(gather_indices_shape)))
        start_index_map.append(x_axis)
        gather_indices_shape.append(slice_size)
        collapsed_slice_dims.append(x_axis)

      collapsed_y_axis += 1
      y_axis += 1
      x_axis += 1
    else:
      raise IndexError(f"Got unsupported indexer at position {idx_pos}: {index!r}")

  # Assemble the per-index arrays into one index array whose trailing
  # dimension enumerates the entries of start_index_map.
  if len(gather_indices) == 0:
    gather_indices_array: ArrayLike = np.zeros((0,), dtype=index_dtype)
  elif len(gather_indices) == 1:
    g, _ = gather_indices[0]
    gather_indices_array = lax.expand_dims(g, (g.ndim,))
  else:
    last_dim = len(gather_indices_shape)
    gather_indices_shape.append(1)
    gather_indices_array = lax.concatenate([
      lax.broadcast_in_dim(g, gather_indices_shape, tuple(range(i, i + g.ndim)))
      for g, i in gather_indices],
      last_dim)

  dnums = slicing.GatherDimensionNumbers(
    offset_dims = tuple(offset_dims),
    collapsed_slice_dims = tuple(sorted(collapsed_slice_dims)),
    start_index_map = tuple(start_index_map)
  )
  slice_sharding = canonicalize_sharding(x_sharding.update(spec=slice_spec), "index_to_gather")
  return _GatherIndexer(
    slice_shape=slice_shape,
    newaxis_dims=tuple(newaxis_dims),
    gather_slice_shape=gather_slice_shape,
    reversed_y_dims=reversed_y_dims,
    dnums=dnums,
    gather_indices=gather_indices_array,
    # Without advanced indices no position can be repeated or out of order.
    unique_indices=not advanced_indexes,
    indices_are_sorted=not advanced_indexes,
    scalar_bool_dims=scalar_bool_dims,
    slice_sharding=slice_sharding)
(np.ndarray, Array)) and np.ndim(x) != 0 or isinstance(x, (Sequence, slice)) or x is Ellipsis or x is None) def eliminate_deprecated_list_indexing(idx): # "Basic slicing is initiated if the selection object is a non-array, # non-tuple sequence containing slice objects, [Ellipses, or newaxis # objects]". Detects this and raises a TypeError. if not isinstance(idx, tuple): if isinstance(idx, Sequence) and not isinstance( idx, (Array, np.ndarray, str) ): # As of numpy 1.16, some non-tuple sequences of indices result in a warning, while # others are converted to arrays, based on a set of somewhat convoluted heuristics # (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343) # In JAX, we raise an informative TypeError for *all* non-tuple sequences. if any(_should_unpack_list_index(i) for i in idx): msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " "use `arr[tuple(seq)]` instead of `arr[seq]`. " "See https://github.com/jax-ml/jax/issues/4564 for more information.") else: msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " "use `arr[array(seq)]` instead of `arr[seq]`. 
" "See https://github.com/jax-ml/jax/issues/4564 for more information.") raise TypeError(msg) else: idx = (idx,) return idx def _is_boolean_index(i): try: abstract_i = core.typeof(i) except TypeError: abstract_i = None return (isinstance(abstract_i, core.ShapedArray) and dtypes.issubdtype(abstract_i.dtype, np.bool_) or isinstance(i, list) and i and all(_is_scalar(e) and dtypes.issubdtype(dtypes.dtype(e), np.bool_) for e in i)) def _is_slice_element_none_or_constant_or_symbolic(elt): """Return True if elt is a constant or None.""" if elt is None: return True if core.is_symbolic_dim(elt): return True try: return core.is_concrete(elt) except TypeError: return False def _is_scalar(x): """Checks if a Python or NumPy scalar.""" return np.isscalar(x) or ( isinstance(x, (np.ndarray, Array)) and np.ndim(x) == 0 ) @export def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *, inplace: bool = True) -> Array: """Update array elements based on a mask. JAX implementation of :func:`numpy.place`. The semantics of :func:`numpy.place` are to modify arrays in-place, which is not possible for JAX's immutable arrays. The JAX version returns a modified copy of the input, and adds the ``inplace`` parameter which must be set to `False`` by the user as a reminder of this API difference. Args: arr: array into which values will be placed. mask: boolean mask with the same size as ``arr``. vals: values to be inserted into ``arr`` at the locations indicated by mask. If too many values are supplied, they will be truncated. If not enough values are supplied, they will be repeated. inplace: must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned. Returns: A copy of ``arr`` with masked values set to entries from `vals`. See Also: - :func:`jax.numpy.put`: put elements into an array at numerical indices. 
@export
def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
          inplace: bool = True) -> Array:
  """Update array elements based on a mask.

  JAX implementation of :func:`numpy.place`.

  The semantics of :func:`numpy.place` are to modify arrays in-place, which
  is not possible for JAX's immutable arrays. The JAX version returns a modified
  copy of the input, and adds the ``inplace`` parameter which must be set to
  ``False`` by the user as a reminder of this API difference.

  Args:
    arr: array into which values will be placed.
    mask: boolean mask with the same size as ``arr``.
    vals: values to be inserted into ``arr`` at the locations indicated
      by mask. If too many values are supplied, they will be truncated.
      If not enough values are supplied, they will be repeated.
    inplace: must be set to False to indicate that the input is not modified
      in-place, but rather a modified copy is returned.

  Returns:
    A copy of ``arr`` with masked values set to entries from ``vals``.

  Raises:
    ValueError: if ``inplace`` is true, if ``arr`` and ``mask`` differ in
      size, or if ``vals`` is empty.

  See Also:
    - :func:`jax.numpy.put`: put elements into an array at numerical indices.
    - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing

  Examples:
    >>> x = jnp.zeros((3, 5), dtype=int)
    >>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
    >>> mask
    Array([[ True, False, False,  True, False],
           [False,  True, False, False,  True],
           [False, False,  True, False, False]], dtype=bool)

    Placing a scalar value:

    >>> jnp.place(x, mask, 1, inplace=False)
    Array([[1, 0, 0, 1, 0],
           [0, 1, 0, 0, 1],
           [0, 0, 1, 0, 0]], dtype=int32)

    In this case, ``jnp.place`` is similar to the masked array update syntax:

    >>> x.at[mask].set(1)
    Array([[1, 0, 0, 1, 0],
           [0, 1, 0, 0, 1],
           [0, 0, 1, 0, 0]], dtype=int32)

    ``place`` differs when placing values from an array. The array is repeated
    to fill the masked entries:

    >>> vals = jnp.array([1, 3, 5])
    >>> jnp.place(x, mask, vals, inplace=False)
    Array([[1, 0, 0, 3, 0],
           [0, 5, 0, 0, 1],
           [0, 0, 3, 0, 0]], dtype=int32)
  """
  # Reject in-place use before touching the inputs, matching the other
  # in-place-style functions in this module.
  if inplace:
    raise ValueError(
      "jax.numpy.place cannot modify arrays in-place, because JAX arrays are immutable. "
      "Pass inplace=False to instead return an updated array.")

  data, mask_arr, vals_arr = util.ensure_arraylike("place", arr, mask, vals)
  vals_arr = vals_arr.ravel()

  if data.size != mask_arr.size:
    raise ValueError("place: arr and mask must be the same size")
  if not vals_arr.size:
    raise ValueError("Cannot place values from an empty array")
  if not data.size:
    return data

  # The index array has a static length of mask.size; unused slots are filled
  # with mask.size, and the scatter below runs in 'drop' mode.
  indices = lax_numpy.where(mask_arr.ravel(), size=mask_arr.size, fill_value=mask_arr.size)[0]
  vals_arr = lax_numpy._tile_to_size(vals_arr, len(indices))
  return data.ravel().at[indices].set(vals_arr, mode='drop').reshape(data.shape)
@export
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
        mode: str | None = None, *, inplace: bool = True) -> Array:
  """Put elements into an array at given indices.

  JAX implementation of :func:`numpy.put`.

  The semantics of :func:`numpy.put` are to modify arrays in-place, which
  is not possible for JAX's immutable arrays. The JAX version returns a modified
  copy of the input, and adds the ``inplace`` parameter which must be set to
  ``False`` by the user as a reminder of this API difference.

  Args:
    a: array into which values will be placed.
    ind: array of indices over the flattened array at which to put values.
    v: array of values to put into the array.
    mode: string specifying how to handle out-of-bound indices. Supported values:

      - ``"clip"`` (default): clip out-of-bound indices to the final index.
      - ``"wrap"``: wrap out-of-bound indices to the beginning of the array.

    inplace: must be set to False to indicate that the input is not modified
      in-place, but rather a modified copy is returned.

  Returns:
    A copy of ``a`` with specified entries updated.

  See Also:
    - :func:`jax.numpy.place`: place elements into an array via boolean mask.
    - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
    - :func:`jax.numpy.take`: extract values from an array at given indices.

  Examples:
    >>> x = jnp.zeros(5, dtype=int)
    >>> indices = jnp.array([0, 2, 4])
    >>> values = jnp.array([10, 20, 30])
    >>> jnp.put(x, indices, values, inplace=False)
    Array([10,  0, 20,  0, 30], dtype=int32)

    This is equivalent to the following :attr:`jax.numpy.ndarray.at` indexing syntax:

    >>> x.at[indices].set(values)
    Array([10,  0, 20,  0, 30], dtype=int32)

    There are two modes for handling out-of-bound indices. By default they are
    clipped:

    >>> indices = jnp.array([0, 2, 6])
    >>> jnp.put(x, indices, values, inplace=False, mode='clip')
    Array([10,  0, 20,  0, 30], dtype=int32)

    Alternatively, they can be wrapped to the beginning of the array:

    >>> jnp.put(x, indices, values, inplace=False, mode='wrap')
    Array([10, 30, 20,  0,  0], dtype=int32)

    For N-dimensional inputs, the indices refer to the flattened array:

    >>> x = jnp.zeros((3, 5), dtype=int)
    >>> indices = jnp.array([0, 7, 14])
    >>> jnp.put(x, indices, values, inplace=False)
    Array([[10,  0,  0,  0,  0],
           [ 0,  0, 20,  0,  0],
           [ 0,  0,  0,  0, 30]], dtype=int32)
  """
  if inplace:
    raise ValueError(
      "jax.numpy.put cannot modify arrays in-place, because JAX arrays are immutable. "
      "Pass inplace=False to instead return an updated array.")

  arr, flat_ind, _ = util.ensure_arraylike("put", a, ind, v)
  flat_ind = flat_ind.ravel()
  flat_vals = lax_numpy.ravel(v)

  # Nothing to write: the mode is not inspected in this case.
  if not (arr.size and flat_ind.size and flat_vals.size):
    return arr

  # Repeat or truncate the values so there is exactly one per index.
  flat_vals = lax_numpy._tile_to_size(flat_vals, len(flat_ind))

  if mode == "raise":
    raise NotImplementedError("The 'raise' mode to jnp.put is not supported.")
  if mode not in (None, "clip", "wrap"):
    raise ValueError(f"mode should be one of 'wrap' or 'clip'; got {mode=}")

  if mode is None:
    scatter_mode = "drop"
  else:
    # "clip" and "wrap" bring every index into range before the scatter.
    scatter_mode = "promise_in_bounds"
    if mode == "clip":
      flat_ind = lax_numpy.clip(flat_ind, 0, arr.size - 1)
    else:
      flat_ind = flat_ind % arr.size

  nd_index = lax_numpy.unravel_index(flat_ind, arr.shape)
  return arr.at[nd_index].set(flat_vals, mode=scatter_mode)