Files
2026-05-06 19:47:31 +07:00

1695 lines
66 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Indexing code for jax.numpy."""
from __future__ import annotations
import dataclasses
import enum
from functools import partial
import operator
import string
from typing import Any, NamedTuple
from collections.abc import Sequence
import numpy as np
from jax._src import api
from jax._src import array
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import errors
from jax._src import indexing
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lax import utils as lax_utils
from jax._src.numpy import array_constructors
from jax._src.numpy import einsum
from jax._src.numpy import error as jnp_error
from jax._src.numpy import lax_numpy
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.partition_spec import PartitionSpec
from jax._src.pjit import auto_axes
from jax._src.sharding_impls import canonicalize_sharding, NamedSharding
from jax._src.tree_util import tree_flatten, tree_unflatten, register_pytree_node_class
from jax._src.typing import Array, ArrayLike, Index, StaticScalar
from jax._src.util import canonicalize_axis, safe_zip, set_module, tuple_update, unzip3
# Decorator applied to the public functions below (e.g. ``take``). Presumably it
# sets their ``__module__`` to 'jax.numpy' so they are reported as part of the
# public ``jax.numpy`` namespace -- confirm against jax._src.util.set_module.
export = set_module('jax.numpy')
# Internal utilities for parsing and validating NumPy-style indices.
class IndexType(enum.Enum):
  """Classification of a single user-supplied NumPy-style index."""
  NONE = "none"
  SLICE = "slice"
  ELLIPSIS = "ellipsis"
  INTEGER = "integer"
  BOOLEAN = "boolean"
  ARRAY = "array"
  DYNAMIC_SLICE = "dynamic_slice"

  @classmethod
  def from_index(cls, idx: Index) -> IndexType:
    """Classify ``idx`` as one of the supported index kinds.

    Args:
      idx: a single entry of a user-supplied index tuple.

    Returns:
      The :class:`IndexType` member describing ``idx``.

    Raises:
      TypeError: for string indices, and for array, sequence or scalar indices
        whose dtype is neither integer nor boolean.
      IndexError: for any other unsupported index object.
    """
    # Singletons and slice objects are recognized by identity / type alone.
    if idx is None:
      return cls.NONE
    if idx is Ellipsis:
      return cls.ELLIPSIS
    if isinstance(idx, slice):
      return cls.SLICE
    if isinstance(idx, indexing.Slice):
      return cls.DYNAMIC_SLICE
    # Scalar integers and boolean masks are checked before generic arrays.
    if _is_integer_index(idx):
      return cls.INTEGER
    if _is_boolean_index(idx):
      return cls.BOOLEAN
    # Remaining concrete arrays must carry an integer dtype.
    if isinstance(idx, (Array, np.ndarray)):
      if not dtypes.issubdtype(idx.dtype, np.integer):
        raise TypeError(
          f"Indexer must have integer or boolean type, got indexer with type {idx.dtype}")
      return cls.ARRAY
    if isinstance(idx, str):
      # TODO(jakevdp): this TypeError is for backward compatibility.
      # We should switch to IndexError for consistency.
      raise TypeError(f"JAX does not support string indexing; got {idx=}")
    if isinstance(idx, Sequence):
      # An empty sequence would be inferred as a float array, so it is treated
      # as an (empty) integer index explicitly.
      if not idx:
        return cls.ARRAY
      seq_dtype = api.eval_shape(array_constructors.asarray, idx).dtype
      if seq_dtype == bool:
        return cls.BOOLEAN
      if dtypes.issubdtype(seq_dtype, np.integer):
        return cls.ARRAY
      raise TypeError(
        f"Indexer must have integer or boolean type, got indexer with type {seq_dtype}")
    if isinstance(idx, (float, complex, np.generic)):
      raise TypeError(
        f"Indexer must have integer or boolean type, got indexer with type {np.dtype(type(idx))}")
    raise IndexError("only integers, slices (`:`), ellipsis (`...`), newaxis (`None`)"
                     f" and integer or boolean arrays are valid indices. Got {idx}")
class ParsedIndex(NamedTuple):
  """Structure for tracking an indexer parsed within the context of an array shape."""
  # The index object: either the user-supplied value, or a transformed value
  # produced by an NDIndexer pass (e.g. a normalized integer or an expanded mask).
  index: Index
  # The classification of ``index`` (see IndexType.from_index).
  typ: IndexType
  # Positions of the array axes this index consumes. Empty for ``None`` and for
  # scalar booleans; for an ellipsis it covers every axis not claimed by the
  # other indices.
  consumed_axes: tuple[int, ...]
def _parse_indices(
    indices: tuple[Index, ...],
    shape: tuple[int, ...],
) -> list[ParsedIndex]:
  """Classify each index and assign it the array axes it consumes.

  Args:
    indices: a tuple of user-supplied indices to be parsed.
    shape: the shape of the array being indexed.

  Returns:
    A list of :class:`ParsedIndex` objects, one per entry of ``indices`` and in
    the same order.

  Raises:
    IndexError: if any index is unrecognized, if more than one ellipsis is
      present, or if the indices consume more axes than ``shape`` has.
  """
  # Pass 1: classify every entry and count the axes it consumes. The ellipsis
  # count is unknown until all other entries are counted, so use 0 for now.
  kinds: list[IndexType] = []
  n_axes: list[int] = []
  ellipsis_positions: list[int] = []
  for pos, raw in enumerate(indices):
    kind = IndexType.from_index(raw)
    kinds.append(kind)
    if kind is IndexType.NONE:
      n_axes.append(0)
    elif kind is IndexType.ELLIPSIS:
      n_axes.append(0)
      ellipsis_positions.append(pos)
    elif kind is IndexType.BOOLEAN:
      n_axes.append(np.ndim(raw))  # pyrefly: ignore[bad-argument-type]
    elif kind in (IndexType.INTEGER, IndexType.ARRAY, IndexType.SLICE, IndexType.DYNAMIC_SLICE):
      n_axes.append(1)
    else:
      raise IndexError(f"Unrecognized index type: {kind}")

  # Pass 2: validate, then let the (single) ellipsis absorb the leftover axes.
  if len(ellipsis_positions) > 1:
    raise IndexError("an index can only have a single ellipsis ('...')")
  ndim = len(shape)
  n_explicit = sum(n_axes)
  if n_explicit > ndim:
    raise IndexError(f"Too many indices: array is {ndim}-dimensional,"
                     f" but {n_explicit} were indexed")
  if ellipsis_positions:
    n_axes[ellipsis_positions[0]] = ndim - n_explicit

  # Pass 3: hand out consecutive axis ranges in index order.
  parsed: list[ParsedIndex] = []
  next_axis = 0
  for raw, kind, count in safe_zip(indices, kinds, n_axes):
    axes = tuple(range(next_axis, next_axis + count))
    parsed.append(ParsedIndex(index=raw, typ=kind, consumed_axes=axes))
    next_axis += count
  return parsed
@register_pytree_node_class
@dataclasses.dataclass(frozen=True, kw_only=True)
class NDIndexer:
  """Object that implements NumPy-style indexing operations on top of JAX.

  Generally this will be constructed via the :meth:`NDIndexer.from_raw_indices`
  method. The transformation methods (``expand_bool_indices``,
  ``expand_ellipses``, ``normalize_indices``, ...) each return a new
  ``NDIndexer`` rather than modifying ``self``.

  ``NDIndexer`` is registered as a pytree. Indices of type ``ARRAY``,
  ``BOOLEAN`` and ``DYNAMIC_SLICE`` are its children. The shape and the
  remaining (hashable) indices are carried as static auxiliary data.

  Attributes:
    shape: the shape of the array being indexed.
    indices: a list of :class:`ParsedIndex` objects.
  """
  shape: tuple[int, ...]
  indices: list[ParsedIndex]

  @classmethod
  def from_raw_indices(cls, indices: Index | tuple[Index, ...], shape: tuple[int, ...]) -> NDIndexer:
    """Create an NDIndexer object from raw user-supplied indices."""
    indices = eliminate_deprecated_list_indexing(indices)
    parsed = _parse_indices(indices, shape)
    return cls(shape=shape, indices=parsed)

  def validate_static_indices(self, normalize_indices: bool = True) -> None:
    """Check that all static integer indices are in-bounds.

    Args:
      normalize_indices: if True, negative integers count from the end of the
        axis before the bounds check is applied.

    Raises:
      IndexError: in case of out-of-bound indices.
    """
    for idx in self.indices:
      if idx.typ == IndexType.INTEGER:
        assert isinstance(idx.index, (int, np.integer))
        i = operator.index(idx.index)
        axis, = idx.consumed_axes
        size = self.shape[axis]
        normed_idx = i + size if normalize_indices and i < 0 else i
        if not 0 <= normed_idx < size:
          raise IndexError(f"index {i} out of bounds for axis {axis} with size {size}"
                           f" ({normalize_indices=})")

  def validate_slices(self) -> None:
    """Check that all slices have static start/stop/step values.

    Raises:
      IndexError: in case of non-static entries.
    """
    for position, idx in enumerate(self.indices):
      if idx.typ == IndexType.SLICE:
        assert isinstance(idx.index, slice)
        if not all(_is_slice_element_none_or_constant_or_symbolic(val)
                   for val in [idx.index.start, idx.index.stop, idx.index.step]):
          raise IndexError("Slice entries must be static integers."
                           f" Got {idx.index} at position {position}")

  @staticmethod
  def is_sharded(arr) -> bool:
    """Return True if ``arr`` is an ``ArrayImpl`` spread over more than one device."""
    return isinstance(arr, array.ArrayImpl) and not arr.sharding.num_devices == 1

  def has_partial_slices(self) -> bool:
    """Check whether the indexer contains partial slices.

    Integer indices, dynamic slices, and static slices that do not cover their
    whole axis with unit step count as partial. For sharded arrays, partial
    slices cannot automatically propagate sharding.
    """
    for idx in self.indices:
      if idx.typ in [IndexType.INTEGER, IndexType.DYNAMIC_SLICE]:
        return True
      if idx.typ == IndexType.SLICE:
        slc = idx.index
        assert isinstance(slc, slice)
        axis, = idx.consumed_axes
        size = self.shape[axis]
        start, stop, step = slc.indices(size)
        # A full slice, forward or reversed, has unit step and spans all `size` elements.
        if abs(step) != 1 or abs(stop - start) != size:
          return True
    return False

  def expand_bool_indices(self) -> NDIndexer:
    """Return a new NDIndexer with boolean indices replaced by array indices.

    An N-dimensional boolean mask becomes N integer ``ARRAY`` indices (via
    ``np.where``), one per consumed axis. The only exception is scalar boolean
    indices, which are left in place (normalized to Python ``bool``).

    Raises:
      NonConcreteBooleanIndexError: if a boolean index is not concrete.
      IndexError: if a mask's shape does not match the indexed axes.
    """
    expanded_indices: list[ParsedIndex] = []
    for position, idx in enumerate(self.indices):
      if idx.typ != IndexType.BOOLEAN:
        expanded_indices.append(idx)
        continue
      if not core.is_concrete(idx.index):
        # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
        raise errors.NonConcreteBooleanIndexError(core.typeof(idx.index))
      assert isinstance(idx.index, (bool, np.ndarray, Array, list))
      if np.ndim(idx.index) == 0:  # pyrefly: ignore[bad-argument-type]
        # Scalar booleans
        assert idx.consumed_axes == ()
        expanded_indices.append(ParsedIndex(index=bool(idx.index), typ=idx.typ, consumed_axes=()))
        continue
      idx_shape = np.shape(idx.index)  # pyrefly: ignore[no-matching-overload]
      expected_shape = [self.shape[i] for i in idx.consumed_axes]
      # Each mask dimension must equal the indexed dimension; zero-length mask
      # dimensions are also accepted.
      if not all(s1 in (0, s2) for s1, s2 in zip(idx_shape, expected_shape)):
        raise IndexError("boolean index did not match shape of indexed array in index"
                         f" {position}: got {idx_shape}, expected {expected_shape}")
      # np.where on a concrete mask yields one integer array per mask dimension.
      expanded_indices_raw = np.where(np.asarray(idx.index))
      expanded_indices.extend(ParsedIndex(index=i, typ=IndexType.ARRAY, consumed_axes=(axis,))
                              for i, axis in safe_zip(expanded_indices_raw, idx.consumed_axes))
    return NDIndexer(shape=self.shape, indices=expanded_indices)

  def expand_scalar_bool_indices(self, sharding_spec: Any = None) -> tuple[NDIndexer, Any]:
    """Rewrite scalar boolean indices as integer arrays over new size-1 axes.

    Each scalar boolean index ``b`` becomes ``np.arange(int(b))`` (``[0]`` for
    True, ``[]`` for False). A size-1 entry is inserted into the shape, and a
    ``None`` entry into the sharding spec. The ``consumed_axes`` of all indices
    are renumbered to account for the new axes.

    Args:
      sharding_spec: optional per-axis sharding spec of the indexed array.

    Returns:
      A pair ``(indexer, sharding_spec)``. ``sharding_spec`` is ``None`` if
      the input spec was ``None``, otherwise a tuple.
    """
    new_shape = list(self.shape)
    new_sharding_spec = list((None for _ in self.shape) if sharding_spec is None else sharding_spec)
    new_indices = list(self.indices)
    current_dim = 0
    for i, idx in enumerate(self.indices):
      if idx.typ == IndexType.BOOLEAN and np.ndim(idx.index) == 0:  # pyrefly: ignore[bad-argument-type]
        # NOTE(review): the size-1 axis is inserted at the index *position* `i`,
        # while the rewritten index consumes axis `current_dim`. These differ
        # when an earlier index consumes other than one axis (e.g. `None` or an
        # ellipsis). Confirm that callers account for this.
        new_shape.insert(i, 1)
        new_sharding_spec.insert(i, None)
        new_indices[i] = ParsedIndex(
          np.arange(int(idx.index)), typ=IndexType.ARRAY, consumed_axes=(current_dim,))  # pyrefly: ignore[bad-argument-type]
        current_dim += 1
      else:
        n_consumed = len(idx.consumed_axes)
        new_indices[i] = ParsedIndex(
          index=idx.index,
          typ=idx.typ,
          consumed_axes=tuple(range(current_dim, current_dim + n_consumed))
        )
        current_dim += n_consumed
    new_sharding_spec = None if sharding_spec is None else tuple(new_sharding_spec)
    return NDIndexer(indices=new_indices, shape=tuple(new_shape)), new_sharding_spec

  def convert_sequences_to_arrays(self) -> NDIndexer:
    """Return a new NDIndexer with every ``Sequence`` index converted to an array."""
    new_indices = [ParsedIndex(lax_numpy.asarray(idx.index), typ=idx.typ, consumed_axes=idx.consumed_axes)
                   if isinstance(idx.index, Sequence) else idx for idx in self.indices]
    return NDIndexer(indices=new_indices, shape=self.shape)

  def expand_ellipses(self) -> NDIndexer:
    """Replace the ellipsis and implicit trailing axes with explicit full slices.

    Every axis consumed by the ellipsis, and every axis not consumed by any
    index, gets its own ``slice(None)`` index.
    """
    expanded: list[ParsedIndex] = []
    consumed = 0
    for idx in self.indices:
      consumed += len(idx.consumed_axes)
      if idx.typ == IndexType.ELLIPSIS:
        for axis in idx.consumed_axes:
          expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,)))
      else:
        expanded.append(idx)
    for axis in range(consumed, len(self.shape)):
      expanded.append(ParsedIndex(index=slice(None), typ=IndexType.SLICE, consumed_axes=(axis,)))
    return NDIndexer(shape=self.shape, indices=expanded)

  def normalize_indices(self) -> NDIndexer:
    """Return a new NDIndexer with negative integer and array indices wrapped.

    Negative entries ``i`` are replaced by ``i + size`` of the consumed axis.
    Unsigned indices are left untouched because they cannot be negative.
    Other index kinds are passed through unchanged.
    """
    new_indices: list[ParsedIndex] = []
    for idx in self.indices:
      if idx.typ == IndexType.INTEGER:
        axis, = idx.consumed_axes
        size: ArrayLike = self.shape[axis]
        if isinstance(idx.index, np.unsignedinteger):
          normed_index: Index = idx.index
        else:
          normed_index = idx.index + size if idx.index < 0 else idx.index  # pyrefly: ignore[bad-assignment, unsupported-operation]
        new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes))
      elif idx.typ == IndexType.ARRAY:
        assert isinstance(idx.index, (Array, np.ndarray))
        axis, = idx.consumed_axes
        if dtypes.issubdtype(idx.index.dtype, np.unsignedinteger):
          normed_index = idx.index
        else:
          size = self.shape[axis]
          if core.is_constant_dim(size):
            size = lax._const(idx.index, size)
          else:
            # Symbolic (shape-polymorphic) size: materialize it in the index dtype.
            size = lax.convert_element_type(core.dimension_as_value(size),
                                            idx.index.dtype)
          normed_index = lax.select(idx.index < 0, lax.add(idx.index, size), idx.index)
        new_indices.append(ParsedIndex(normed_index, typ=idx.typ, consumed_axes=idx.consumed_axes))
      else:
        new_indices.append(idx)
    return NDIndexer(indices=new_indices, shape=self.shape)

  def to_static_slice(
      self, *,
      arr_is_sharded: bool = False,
      normalize_indices: bool = True,
      mode: str | slicing.GatherScatterMode | None) -> _StaticSliceIndexer:
    """Convert to StaticSliceIndexer data structure.

    Args:
      arr_is_sharded: whether the indexed array is sharded across devices.
      normalize_indices: whether negative integer indices count from the end.
      mode: ``None`` (treated as ``promise_in_bounds``), ``promise_in_bounds``
        or ``clip``.

    Raises:
      ValueError, TypeError, IndexError: if conversion is not possible.
    """
    if mode is None:
      parsed_mode = slicing.GatherScatterMode.PROMISE_IN_BOUNDS
    else:
      parsed_mode = slicing.GatherScatterMode.from_any(mode)

    if any(core.is_symbolic_dim(s) for s in self.shape):
      raise ValueError("mode='slice' is not valid for polymorphic shapes.")

    if parsed_mode not in [
        slicing.GatherScatterMode.PROMISE_IN_BOUNDS, slicing.GatherScatterMode.CLIP]:
      raise ValueError("static_slice requires mode='promise_in_bounds' or mode='clip'")

    # Validation of the unmodified user indices.
    if parsed_mode == slicing.GatherScatterMode.PROMISE_IN_BOUNDS:
      self.validate_static_indices(normalize_indices=normalize_indices)
    self.validate_slices()

    # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as
    # opposed to x[:]) lead to incorrect sharding semantics when computed via slice.
    # TODO(yashkatariya): fix slice with sharding
    if arr_is_sharded and self.has_partial_slices():
      raise ValueError("static_slice with partial slices does not support nontrivial array sharding.")

    for position, pidx in enumerate(self.indices):
      if pidx.typ in [IndexType.INTEGER, IndexType.ELLIPSIS, IndexType.SLICE, IndexType.NONE]:
        pass
      elif pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN, IndexType.DYNAMIC_SLICE]:
        raise TypeError("static_slice: indices must be static scalars or slices."
                        f" Got index of type {type(pidx.index)} at position {position}")
      else:
        raise TypeError(f"static_slice: unrecognized index {pidx.index} at position {position}.")

    # Now re-iterate to generate static slices.
    start_indices: list[int] = []
    limit_indices: list[int] = []
    strides: list[int] = []
    rev_axes: list[int] = []
    squeeze_axes: list[int] = []
    newaxis_dims: list[int] = []

    expanded = self.expand_ellipses()
    for pidx in expanded.indices:
      if pidx.typ in [IndexType.ARRAY, IndexType.BOOLEAN, IndexType.ELLIPSIS]:
        raise RuntimeError(f"Internal: unexpected index encountered: {pidx}")
      elif pidx.typ == IndexType.NONE:
        # Expanded axes indices are based on the rank of the array after slicing
        # (tracked by start_indices) and squeezing (tracked by squeeze_axes), and
        # expand_dims inserts dimensions in order, so we must also account for
        # previous expanded dimensions.
        newaxis_dims.append(len(start_indices) - len(squeeze_axes) + len(newaxis_dims))
      elif pidx.typ == IndexType.INTEGER:
        assert isinstance(pidx.index, (int, np.integer))
        axis, = pidx.consumed_axes
        if core.definitely_equal(self.shape[axis], 0):
          # XLA gives error when indexing into an axis of size 0
          raise IndexError(f"index is out of bounds for axis {axis} with size 0")
        start_index = int(pidx.index)
        if normalize_indices and start_index < 0:
          start_index += self.shape[axis]
        # Normalization & validation have already been handled, so clip start_index
        # to valid range
        start_index = min(max(start_index, 0), self.shape[axis] - 1)
        start_indices.append(start_index)
        limit_indices.append(start_index + 1)
        strides.append(1)
        squeeze_axes.append(axis)
      elif pidx.typ == IndexType.SLICE:
        assert isinstance(pidx.index, slice)
        axis, = pidx.consumed_axes
        size = self.shape[axis]
        start, stop, stride = pidx.index.indices(size)
        if stride < 0:
          # Slice the equivalent ascending range with a positive stride and
          # reverse the axis afterwards. `new_start` is the smallest element the
          # slice selects; it is clamped to `size` so an empty slice stays a
          # valid, empty range.
          new_start = min(size, stop + 1 + abs(start - stop - 1) % abs(stride))
          start_indices.append(new_start)
          limit_indices.append(max(new_start, start + 1))
          strides.append(abs(stride))
          rev_axes.append(axis)
        else:
          start_indices.append(start)
          limit_indices.append(max(start, stop))
          strides.append(stride)
      else:
        raise TypeError(f"static_slice: unrecognized index {pidx.index}")
    return _StaticSliceIndexer(
      start_indices=start_indices,
      limit_indices=limit_indices,
      strides=None if all(s == 1 for s in strides) else strides,
      rev_axes=rev_axes,
      squeeze_axes=squeeze_axes,
      newaxis_dims=newaxis_dims,
    )

  def to_dynamic_slice(
      self, *,
      arr_is_sharded: bool = False,
      normalize_indices: bool = True,
      mode: str | slicing.GatherScatterMode | None) -> _DynamicSliceIndexer:
    """Convert to DynamicSliceIndexer data structure.

    Only scalar integer indices, unit-step slices, dynamic slices with unit
    stride, ``None`` and an ellipsis are supported.

    Args:
      arr_is_sharded: whether the indexed array is sharded across devices.
      normalize_indices: forwarded to the resulting indexer.
      mode: ``None``, ``promise_in_bounds`` or ``clip``.

    Raises:
      ValueError, TypeError, IndexError: if conversion is not possible.
    """
    if mode is not None:
      parsed_mode = slicing.GatherScatterMode.from_any(mode)
      if parsed_mode not in [
          slicing.GatherScatterMode.PROMISE_IN_BOUNDS, slicing.GatherScatterMode.CLIP]:
        raise ValueError("dynamic_slice requires mode='promise_in_bounds' or mode='clip'")

    # For sharded inputs, indexing (like x[0]) and partial slices (like x[:2] as
    # opposed to x[:]) lead to incorrect sharding semantics when computed via slice.
    # TODO(yashkatariya): fix slice with sharding
    if arr_is_sharded and self.has_partial_slices():
      raise ValueError("dynamic_slice with partial slices does not support nontrivial array sharding.")

    for position, pidx in enumerate(self.indices):
      if pidx.typ in [IndexType.INTEGER, IndexType.ELLIPSIS, IndexType.NONE]:
        pass
      elif pidx.typ == IndexType.DYNAMIC_SLICE:
        assert isinstance(pidx.index, indexing.Slice)
        if pidx.index.stride != 1:
          raise TypeError("dynamic_slice: only unit steps supported in slice."
                          f" Got {pidx.index} at position {position}")
      elif pidx.typ == IndexType.SLICE:
        assert isinstance(pidx.index, slice)
        if pidx.index.step is not None and pidx.index.step not in [-1, 1]:
          raise TypeError("dynamic_slice: only unit steps supported in slice."
                          f" Got {pidx.index} at position {position}")
      elif pidx.typ == IndexType.ARRAY:
        if isinstance(pidx.index, Sequence) or np.shape(pidx.index) != ():  # pyrefly: ignore[no-matching-overload]
          raise TypeError("dynamic_slice: only scalar indices allowed."
                          f" Got index of type {type(pidx.index)} at position {position}")
      elif pidx.typ == IndexType.BOOLEAN:
        raise TypeError("dynamic_slice: indices must be scalars or slices."
                        f" Got index of type {type(pidx.index)} at position {position}")
      else:
        raise TypeError(f"dynamic_slice: unrecognized index {pidx.index} at position {position}.")

    start_indices: list[ArrayLike] = []
    slice_sizes: list[int] = []
    rev_axes: list[int] = []
    squeeze_axes: list[int] = []
    newaxis_dims: list[int] = []

    expanded = self.expand_ellipses()
    # Stays True only if every index is None or a full `slice(None)`.
    trivial_slicing = True
    for pidx in expanded.indices:
      if pidx.typ in [IndexType.BOOLEAN, IndexType.ELLIPSIS]:
        raise RuntimeError(f"Internal: unexpected index encountered: {pidx}")
      elif pidx.typ == IndexType.NONE:
        # Expanded axes indices are based on the rank of the array after slicing
        # (tracked by start_indices) and squeezing (tracked by squeeze_axes), and
        # expand_dims inserts dimensions in order, so we must also account for
        # previous expanded dimensions.
        newaxis_dims.append(len(start_indices) - len(squeeze_axes) + len(newaxis_dims))
      elif pidx.typ in [IndexType.INTEGER, IndexType.ARRAY]:
        trivial_slicing = False
        index = lax_numpy.asarray(pidx.index)
        assert index.shape == ()  # Validated above.
        axis, = pidx.consumed_axes
        if core.definitely_equal(self.shape[axis], 0):
          # XLA gives error when indexing into an axis of size 0
          raise IndexError(f"index is out of bounds for axis {axis} with size 0")
        start_indices.append(index)
        slice_sizes.append(1)
        squeeze_axes.append(axis)
      elif pidx.typ == IndexType.SLICE:
        assert isinstance(pidx.index, slice)
        if pidx.index != slice(None):
          trivial_slicing = False
        axis, = pidx.consumed_axes
        size = self.shape[axis]
        start, stop, stride = pidx.index.indices(size)
        assert stride in [-1, 1]  # validated above
        if stride < 0:
          # With |stride| == 1 this reduces to stop + 1: the smallest selected
          # element. The axis is reversed after slicing.
          new_start = stop + 1 + abs(start - stop - 1) % abs(stride)
          start_indices.append(new_start)
          slice_sizes.append(max(0, start + 1 - new_start))
          rev_axes.append(axis)
        else:
          start_indices.append(start)
          slice_sizes.append(max(0, stop - start))
      elif pidx.typ == IndexType.DYNAMIC_SLICE:
        assert isinstance(pidx.index, indexing.Slice)
        start_indices.append(pidx.index.start)
        slice_sizes.append(pidx.index.size)
        trivial_slicing = False
      else:
        raise TypeError(f"dynamic_slice: unrecognized index {pidx.index}")

    if len(start_indices) > 1:
      # We must be careful with dtypes because dynamic_slice requires all
      # start indices to have matching types.
      dt = lax_utils.int_dtype_for_shape(self.shape, signed=True)
      start_indices = [lax.convert_element_type(i, dt) for i in start_indices]

    return _DynamicSliceIndexer(
      start_indices=start_indices,
      slice_sizes=slice_sizes,
      rev_axes=rev_axes,
      squeeze_axes=squeeze_axes,
      newaxis_dims=newaxis_dims,
      normalize_indices=normalize_indices,
      trivial_slicing=trivial_slicing,
    )

  def is_advanced_int_indexer(self):
    """Returns True if idx should trigger int array indexing, False otherwise."""
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
    return any(idx.typ in [IndexType.ARRAY, IndexType.BOOLEAN] and np.ndim(idx.index) > 0  # pyrefly: ignore[bad-argument-type]
               for idx in self.indices)

  def to_gather(self, x_sharding: NamedSharding | Any,
                normalize_indices: bool = True) -> _GatherIndexer:
    """Convert to a gather indexer via ``_index_to_gather``."""
    return _index_to_gather(self, x_sharding=x_sharding, normalize_indices=normalize_indices)

  def tree_flatten(self):
    """Split into dynamic children and static ``(shape, metadata)`` aux data."""
    def is_dynamic(i: ParsedIndex):
      # These index types are non-hashable and therefore must be dynamic.
      return i.typ in [IndexType.ARRAY, IndexType.BOOLEAN, IndexType.DYNAMIC_SLICE]

    # Static positions hold None in the children; dynamic positions hold None
    # in the metadata's `index` field.
    raw_dynamic_indices = [i.index if is_dynamic(i) else None for i in self.indices]
    static_metadata = [
      ParsedIndex(index=None, typ=i.typ, consumed_axes=i.consumed_axes) if is_dynamic(i) else i
      for i in self.indices]
    return raw_dynamic_indices, (self.shape, static_metadata)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    """Inverse of :meth:`tree_flatten`."""
    shape, static_metadata = aux_data
    indices = [idx if dyn_index is None else ParsedIndex(dyn_index, idx.typ, idx.consumed_axes)
               for dyn_index, idx in safe_zip(children, static_metadata)]
    return cls(indices=indices, shape=shape)
@export
def take(
    a: ArrayLike,
    indices: ArrayLike,
    axis: int | None = None,
    out: None = None,
    mode: str | None = None,
    unique_indices: bool = False,
    indices_are_sorted: bool = False,
    fill_value: StaticScalar | None = None,
) -> Array:
  """Take elements from an array.

  JAX implementation of :func:`numpy.take`, implemented in terms of
  :func:`jax.lax.gather`. JAX's behavior differs from NumPy in the case
  of out-of-bound indices; see the ``mode`` parameter below.

  Args:
    a: array from which to take values.
    indices: N-dimensional array of integer indices of values to take from the array.
    axis: the axis along which to take values. If not specified, the array will
      be flattened before indexing is applied.
    out: not supported by JAX; must be ``None``.
    mode: Out-of-bounds indexing mode: ``"fill"``, ``"clip"`` or ``"wrap"``.
      The default ``mode="fill"`` returns invalid values (e.g. NaN) for
      out-of-bounds indices; the ``fill_value`` argument gives control over
      this value. ``mode="clip"`` clips indices to the valid range, and
      ``mode="wrap"`` wraps indices modulo the axis size. ``mode="raise"`` is
      not supported. For more discussion of ``mode`` options, see
      :attr:`jax.numpy.ndarray.at`.
    unique_indices: If True, the implementation will assume that the indices are unique
      after normalization of negative indices, which lets the compiler emit more efficient
      code during the backward pass. If set to True and normalized indices are not unique,
      the result is implementation-defined and may be non-deterministic.
    indices_are_sorted : If True, the implementation will assume that the indices are
      sorted in ascending order after normalization of negative indices, which can lead
      to more efficient execution on some backends. If set to True and normalized indices
      are not sorted, the output is implementation-defined.
    fill_value: The fill value to return for out-of-bounds slices when mode is 'fill'.
      Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for
      signed types, the largest positive value for unsigned types, and True for booleans.

  Returns:
    Array of values extracted from ``a``.

  Raises:
    NotImplementedError: if ``out`` is not None, or if ``mode="raise"``.
    ValueError: if ``mode`` is not one of the supported values.

  See also:
    - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
    - :func:`jax.numpy.take_along_axis`: take values along an axis

  Examples:
    >>> x = jnp.array([[1., 2., 3.],
    ...                [4., 5., 6.]])
    >>> indices = jnp.array([2, 0])

    Passing no axis results in indexing into the flattened array:

    >>> jnp.take(x, indices)
    Array([3., 1.], dtype=float32)
    >>> x.ravel()[indices]  # equivalent indexing syntax
    Array([3., 1.], dtype=float32)

    Passing an axis results in applying the index to every subarray along the axis:

    >>> jnp.take(x, indices, axis=1)
    Array([[3., 1.],
           [6., 4.]], dtype=float32)
    >>> x[:, indices]  # equivalent indexing syntax
    Array([[3., 1.],
           [6., 4.]], dtype=float32)

    Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:

    >>> jnp.take(x, indices, axis=0)
    Array([[nan, nan, nan],
           [ 1.,  2.,  3.]], dtype=float32)
    >>> x.at[indices].get(mode='fill', fill_value=jnp.nan)  # equivalent indexing syntax
    Array([[nan, nan, nan],
           [ 1.,  2.,  3.]], dtype=float32)

    This default out-of-bound behavior can be adjusted using the ``mode`` parameter, for
    example, we can instead clip to the last valid value:

    >>> jnp.take(x, indices, axis=0, mode='clip')
    Array([[4., 5., 6.],
           [1., 2., 3.]], dtype=float32)
    >>> x.at[indices].get(mode='clip')  # equivalent indexing syntax
    Array([[4., 5., 6.],
           [1., 2., 3.]], dtype=float32)
  """
  return _take(a, indices, None if axis is None else operator.index(axis), out,
               mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
               fill_value=fill_value)
@api.jit(static_argnames=('axis', 'mode', 'unique_indices', 'indices_are_sorted', 'fill_value'))
def _take(a, indices, axis: int | None = None, out=None, mode=None,
unique_indices=False, indices_are_sorted=False, fill_value=None):
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.take is not supported.")
a, indices = util.ensure_arraylike("take", a, indices)
if axis is None:
a = a.ravel()
axis_idx = 0
else:
axis_idx = canonicalize_axis(axis, np.ndim(a))
if mode is None or mode == "fill":
gather_mode = slicing.GatherScatterMode.FILL_OR_DROP
# lax.gather() does not support negative indices, so we wrap them here
indices = util._where(indices < 0, indices + a.shape[axis_idx], indices)
elif mode == "raise":
# TODO(phawkins): we have no way to report out of bounds errors yet.
raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
elif mode == "wrap":
indices = ufuncs.mod(indices, lax._const(indices, a.shape[axis_idx]))
gather_mode = slicing.GatherScatterMode.PROMISE_IN_BOUNDS
elif mode == "clip":
gather_mode = slicing.GatherScatterMode.CLIP
else:
raise ValueError(f"Invalid mode '{mode}' for np.take")
index_dims = len(np.shape(indices))
slice_sizes = list(np.shape(a))
if slice_sizes[axis_idx] == 0:
if indices.size != 0:
raise IndexError("Cannot do a non-empty jnp.take() from an empty axis.")
return a
if indices.size == 0:
out_shape = (slice_sizes[:axis_idx] + list(indices.shape) +
slice_sizes[axis_idx + 1:])
return lax.full_like(a, 0, shape=out_shape)
slice_sizes[axis_idx] = 1
dnums = slicing.GatherDimensionNumbers(
offset_dims=tuple(
list(range(axis_idx)) +
list(range(axis_idx + index_dims, len(a.shape) + index_dims - 1))),
collapsed_slice_dims=(axis_idx,),
start_index_map=(axis_idx,))
return slicing.gather(a, indices[..., None], dimension_numbers=dnums,
slice_sizes=tuple(slice_sizes),
mode=gather_mode, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, fill_value=fill_value)
def _normalize_index(index, axis_size):
"""Normalizes an index value in the range [-N, N) to the range [0, N)."""
if dtypes.issubdtype(dtypes.dtype(index), np.unsignedinteger):
return index
if core.is_constant_dim(axis_size):
axis_size_val = lax._const(index, axis_size)
else:
axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size),
dtypes.dtype(index))
if isinstance(index, (int, np.integer)):
return lax.add(index, axis_size_val) if index < 0 else index
else:
return lax.select(index < 0, lax.add(index, axis_size_val), index)
@export
@api.jit(static_argnames=('axis', 'mode', 'fill_value'))
def take_along_axis(
arr: ArrayLike,
indices: ArrayLike,
axis: int | None = -1,
mode: str | slicing.GatherScatterMode | None = None,
fill_value: StaticScalar | None = None,
) -> Array:
"""Take elements from an array.
JAX implementation of :func:`numpy.take_along_axis`, implemented in
terms of :func:`jax.lax.gather`. JAX's behavior differs from NumPy
in the case of out-of-bound indices; see the ``mode`` parameter below.
Args:
a: array from which to take values.
indices: array of integer indices. If ``axis`` is ``None``, must be one-dimensional.
If ``axis`` is not None, must have ``a.ndim == indices.ndim``, and ``a`` must be
broadcast-compatible with ``indices`` along dimensions other than ``axis``.
axis: the axis along which to take values. If not specified, the array will
be flattened before indexing is applied.
mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default
``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices.
For more discussion of ``mode`` options, see :attr:`jax.numpy.ndarray.at`.
Returns:
Array of values extracted from ``a``.
See also:
- :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
- :func:`jax.numpy.take`: take the same indices along every axis slice.
Examples:
>>> x = jnp.array([[1., 2., 3.],
... [4., 5., 6.]])
>>> indices = jnp.array([[0, 2],
... [1, 0]])
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1., 3.],
[5., 4.]], dtype=float32)
>>> x[jnp.arange(2)[:, None], indices] # equivalent via indexing syntax
Array([[1., 3.],
[5., 4.]], dtype=float32)
Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:
>>> indices = jnp.array([[1, 0, 2]])
>>> jnp.take_along_axis(x, indices, axis=0)
Array([[ 4., 2., nan]], dtype=float32)
>>> x.at[indices, jnp.arange(3)].get(
... mode='fill', fill_value=jnp.nan) # equivalent via indexing syntax
Array([[ 4., 2., nan]], dtype=float32)
``take_along_axis`` is helpful for extracting values from multi-dimensional
argsorts and arg reductions. For, here we compute :func:`~jax.numpy.argsort`
indices along an axis, and use ``take_along_axis`` to construct the sorted
array:
>>> x = jnp.array([[5, 3, 4],
... [2, 7, 6]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 2, 0],
[0, 2, 1]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[3, 4, 5],
[2, 6, 7]], dtype=int32)
Similarly, we can use :func:`~jax.numpy.argmin` with ``keepdims=True`` and
use ``take_along_axis`` to extract the minimum value:
>>> idx = jnp.argmin(x, axis=1, keepdims=True)
>>> idx
Array([[1],
[0]], dtype=int32)
>>> jnp.take_along_axis(x, idx, axis=1)
Array([[3],
[2]], dtype=int32)
"""
a, indices = util.ensure_arraylike("take_along_axis", arr, indices)
index_dtype = indices.dtype
idx_shape = np.shape(indices)
if not dtypes.issubdtype(index_dtype, np.integer):
raise TypeError("take_along_axis indices must be of integer type, got "
f"{index_dtype}")
if axis is None:
if np.ndim(indices) != 1:
msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
raise ValueError(msg.format(idx_shape))
a = a.ravel()
axis = 0
rank = a.ndim
if rank != np.ndim(indices):
msg = "indices and arr must have the same number of dimensions; {} vs. {}"
raise ValueError(msg.format(np.ndim(indices), a.ndim))
axis_int = canonicalize_axis(axis, rank)
def replace(tup, val):
lst = list(tup)
lst[axis_int] = val
return tuple(lst)
axis_size = a.shape[axis_int]
arr_shape = replace(a.shape, 1)
out_shape = lax.broadcast_shapes(idx_shape, arr_shape)
if axis_size == 0:
return lax.full(out_shape, 0, a.dtype)
index_dtype = lax_utils.int_dtype_for_dim(a.shape, signed=True)
indices = lax.convert_element_type(indices, index_dtype)
if mode != "promise_in_bounds":
indices = _normalize_index(indices, axis_size)
if mode == "one_hot":
from jax import nn # pytype: disable=import-error
hot = nn.one_hot(indices, axis_size, dtype=np.bool_)
if a.ndim == 1:
return einsum.einsum("...b,b->...", hot, a, preferred_element_type=a.dtype)
if axis_int > len(string.ascii_letters) - 2:
raise ValueError(
"One Hot indexing is only supported for up to 50 leading dimensions."
)
labels = "".join([string.ascii_letters[i] for i in range(axis_int)])
eq = labels + "y...z," + labels + "z...->" + labels + "y..."
return einsum.einsum(
eq,
hot,
a,
precision=lax.Precision.HIGHEST,
preferred_element_type=a.dtype,
)
index_dims = [i for i, idx in enumerate(idx_shape) if i == axis_int or not core.definitely_equal(idx, 1)]
gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,)
gather_indices = lax.reshape(indices, gather_index_shape)
slice_sizes = []
offset_dims = []
start_index_map = []
collapsed_slice_dims = []
operand_batching_dims = []
start_indices_batching_dims = []
# We will squeeze the array. i is the index of the unsqueezed shape, while
# new_i is the index of the squeezed shape. j is the index of the gather
# indices.
dims_to_squeeze = []
new_i = 0
j = 0
for i in range(rank):
if i == axis_int:
slice_sizes.append(1)
start_index_map.append(new_i)
collapsed_slice_dims.append(new_i)
new_i += 1
j += 1
elif core.definitely_equal(idx_shape[i], 1):
# If idx_shape[i] == 1, we can just take the entirety of the arr's axis
# and avoid forming an iota index.
offset_dims.append(i)
slice_sizes.append(arr_shape[i])
new_i += 1
elif core.definitely_equal(arr_shape[i], 1):
# If the array dimension is 1 but the index dimension is not, we will
# squeeze this dimension.
dims_to_squeeze.append(i)
j += 1
else:
# Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both
# array and index as batching so corresponding elements are gathered.
if core.definitely_equal(arr_shape[i], 0):
slice_sizes.append(0)
else:
slice_sizes.append(1)
operand_batching_dims.append(new_i)
start_indices_batching_dims.append(j)
new_i += 1
j += 1
# Squeeze a to remove singleton dimensions.
a = lax.squeeze(a, dims_to_squeeze)
dnums = slicing.GatherDimensionNumbers(
offset_dims=tuple(offset_dims),
collapsed_slice_dims=tuple(collapsed_slice_dims),
start_index_map=tuple(start_index_map),
operand_batching_dims=tuple(operand_batching_dims),
start_indices_batching_dims=tuple(start_indices_batching_dims))
return slicing.gather(a, gather_indices, dnums, tuple(slice_sizes),
mode="fill" if mode is None else mode, fill_value=fill_value)
def _make_along_axis_idx(shape, indices, axis):
if axis < 0:
axis += len(shape)
return tuple_update(lax_numpy.indices(shape, sparse=True), axis, indices)
@export
@api.jit(static_argnames=('axis', 'inplace', 'mode'))
def put_along_axis(
    arr: ArrayLike,
    indices: ArrayLike,
    values: ArrayLike,
    axis: int | None,
    inplace: bool = True,
    *,
    mode: str | None = None,
) -> Array:
  """Put values into the destination array by matching 1d index and data slices.
  JAX implementation of :func:`numpy.put_along_axis`.
  The semantics of :func:`numpy.put_along_axis` are to modify arrays in-place, which
  is not possible for JAX's immutable arrays. The JAX version returns a modified
  copy of the input, and adds the ``inplace`` parameter which must be set to
  ``False`` by the user as a reminder of this API difference.
  Args:
    arr: array into which values will be put.
    indices: array of indices at which to put values. Must have the same number
      of dimensions as ``arr`` (or be 1D when ``axis`` is None).
    values: array of values to put into the array; must be broadcastable to
      the shape of ``indices``.
    axis: the axis along which to put values. If not specified, the array will
      be flattened before indexing is applied, and the result is reshaped back
      to the original shape of ``arr``.
    inplace: must be set to False to indicate that the input is not modified
      in-place, but rather a modified copy is returned.
    mode: Out-of-bounds indexing mode. For more discussion of ``mode`` options,
      see :attr:`jax.numpy.ndarray.at`.
  Returns:
    A copy of ``arr`` with specified entries updated.
  Raises:
    ValueError: if ``inplace`` is True, if ``arr`` and ``indices`` have
      different numbers of dimensions, or if ``values`` cannot be broadcast to
      the shape of ``indices``.
  See Also:
    - :func:`jax.numpy.put`: put elements into an array at given indices.
    - :func:`jax.numpy.place`: place elements into an array via boolean mask.
    - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
    - :func:`jax.numpy.take`: extract values from an array at given indices.
    - :func:`jax.numpy.take_along_axis`: extract values from an array along an axis.
  Examples:
    >>> from jax import numpy as jnp
    >>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
    >>> i = jnp.argmax(a, axis=1, keepdims=True)
    >>> print(i)
    [[1]
     [0]]
    >>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
    >>> print(b)
    [[10 99 20]
     [99 40 50]]
  """
  if inplace:
    raise ValueError(
      "jax.numpy.put_along_axis cannot modify arrays in-place, because JAX arrays "
      "are immutable. Pass inplace=False to instead return an updated array.")
  arr, indices, values = util.ensure_arraylike("put_along_axis", arr, indices, values)
  original_axis = axis
  original_arr_shape = arr.shape
  if axis is None:
    # Index into the flattened array; the original shape is restored below.
    arr = arr.ravel()
    axis = 0
  if arr.ndim != indices.ndim:
    raise ValueError(
      "put_along_axis arguments 'arr' and 'indices' must have same ndim. Got "
      f"{arr.ndim=} and {indices.ndim=}."
    )
  try:
    values = util._broadcast_to(values, indices.shape)
  except ValueError as err:
    raise ValueError(
      "put_along_axis argument 'values' must be broadcastable to 'indices'. Got "
      f"{values.shape=} and {indices.shape=}."
    ) from err
  # `indices` along `axis`, open-mesh iotas (jnp.indices(..., sparse=True))
  # along every other axis, so each index selects within its own 1D slice.
  idx = _make_along_axis_idx(arr.shape, indices, axis)
  result = arr.at[idx].set(values, mode=mode)
  if original_axis is None:
    result = result.reshape(original_arr_shape)
  return result
### Indexing
def _is_integer_index(idx: Any) -> bool:
return isinstance(idx, (int, np.integer)) and not isinstance(idx, (bool, np.bool_))
class IndexingStrategy(enum.Enum):
  """How :func:`rewriting_take` lowers ``arr[idx]``."""
  # Try a static slice, then a dynamic slice, and finally fall back to gather.
  AUTO = 'auto'
  # Always lower to a gather.
  GATHER = 'gather'
  # Not consumed by rewriting_take.
  SCATTER = 'scatter'
  # Require a static lax.slice lowering.
  STATIC_SLICE = 'static_slice'
  # Require a lax.dynamic_slice lowering.
  DYNAMIC_SLICE = 'dynamic_slice'
def rewriting_take(
    arr: Array,
    idx: Index | tuple[Index, ...], *,
    indices_are_sorted: bool = False,
    unique_indices: bool = False,
    mode: str | slicing.GatherScatterMode | None = None,
    fill_value: ArrayLike | None = None,
    normalize_indices: bool = True,
    out_sharding: NamedSharding | PartitionSpec | None = None,
    strategy: IndexingStrategy = IndexingStrategy.AUTO,
) -> Array:
  """Compute ``arr[idx]``, lowering to the cheapest applicable primitive.

  ``strategy`` selects the lowering:

  - ``STATIC_SLICE`` / ``DYNAMIC_SLICE`` force the corresponding slice lowering.
  - ``AUTO`` tries a static slice, then a dynamic slice. A conversion that raises
    ``TypeError``, ``ValueError`` or ``IndexError`` falls through to the next.
  - Otherwise (and as the final fallback) a general gather is emitted. If
    ``out_sharding`` is given, it is wrapped in ``auto_axes``.

  Raises:
    TypeError: if ``strategy`` is not an ``IndexingStrategy``.
  """
  indexer = NDIndexer.from_raw_indices(idx, arr.shape)
  if not isinstance(strategy, IndexingStrategy):
    raise TypeError(f"Expected strategy to be IndexingStrategy; got {strategy}")
  if config.check_static_indices.value:
    if (mode is None or slicing.GatherScatterMode.from_any(mode)
        == slicing.GatherScatterMode.PROMISE_IN_BOUNDS):
      indexer.validate_static_indices(normalize_indices=normalize_indices)

  def slice_options():
    # Built on demand, so is_sharded() runs only when a slice lowering is tried.
    return dict(arr_is_sharded=indexer.is_sharded(arr),
                normalize_indices=normalize_indices,
                mode=mode)

  if strategy == IndexingStrategy.STATIC_SLICE:
    return _static_slice(arr, indexer.to_static_slice(**slice_options()))
  if strategy == IndexingStrategy.DYNAMIC_SLICE:
    return _dynamic_slice(arr, indexer.to_dynamic_slice(**slice_options()))
  if strategy == IndexingStrategy.AUTO:
    candidates = ((indexer.to_static_slice, _static_slice),
                  (indexer.to_dynamic_slice, _dynamic_slice))
    for to_slice, apply_slice in candidates:
      try:
        slice_indexer = to_slice(**slice_options())
      except (TypeError, ValueError, IndexError):
        continue
      return apply_slice(arr, slice_indexer)

  # General case: a single gather.
  gather_indexer = indexer.expand_bool_indices()
  dynamic_idx, treedef = tree_flatten(gather_indexer)
  internal_gather = partial(
      _gather, treedef=treedef,
      indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
      mode=mode, fill_value=fill_value, normalize_indices=normalize_indices)
  if out_sharding is None:
    return internal_gather(arr, dynamic_idx)
  out_sharding = canonicalize_sharding(out_sharding, 'take')
  sharded_gather = auto_axes(internal_gather, out_sharding=out_sharding,
                             axes=out_sharding.mesh.explicit_axes)
  return sharded_gather(arr, dynamic_idx)
def _static_slice(arr: Array, indexer: _StaticSliceIndexer) -> Array:
  """Compute ``arr[idx]`` using a static :func:`lax.slice` plus fix-ups.

  Only INTEGER, ELLIPSIS, NONE, and SLICE indices are supported. Other index
  types are rejected with a TypeError.
  """
  out = arr
  # A full-extent, unstrided slice is a no-op, so skip emitting it.
  if not indexer.is_trivial_slice(arr.shape):
    out = slicing.slice(out, indexer.start_indices, indexer.limit_indices,
                        indexer.strides)
  # Fix-ups, in order: reverse rev_axes, drop squeeze_axes, insert newaxis_dims.
  if indexer.rev_axes:
    out = lax.rev(out, indexer.rev_axes)
  if indexer.squeeze_axes:
    out = lax.squeeze(out, indexer.squeeze_axes)
  if indexer.newaxis_dims:
    out = lax.expand_dims(out, indexer.newaxis_dims)
  return out
def _dynamic_slice(arr: Array, indexer: _DynamicSliceIndexer) -> Array:
  """Compute ``arr[idx]`` using :func:`lax.dynamic_slice` plus fix-ups.

  Supports INTEGER, ELLIPSIS, NONE, SLICE, and scalar ARRAY indices. Other
  index types are rejected with a TypeError.
  """
  out = arr
  if not indexer.trivial_slicing:
    # Negative start indices are only accepted when index normalization is on.
    out = slicing.dynamic_slice(
        out,
        start_indices=indexer.start_indices,
        slice_sizes=indexer.slice_sizes,
        allow_negative_indices=indexer.normalize_indices)
  # Fix-ups, in order: reverse rev_axes, drop squeeze_axes, insert newaxis_dims.
  if indexer.rev_axes:
    out = lax.rev(out, indexer.rev_axes)
  if indexer.squeeze_axes:
    out = lax.squeeze(out, indexer.squeeze_axes)
  if indexer.newaxis_dims:
    out = lax.expand_dims(out, indexer.newaxis_dims)
  return out
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @api.jit(static_argnums=(1, 2))
def _gather(arr, dynamic_idx, *, treedef, indices_are_sorted,
            unique_indices, mode, fill_value, normalize_indices):
  """Compute ``arr[idx]`` with a single :func:`lax.gather` plus fix-ups.

  Args:
    arr: the array being indexed.
    dynamic_idx: flattened leaves of the indexer, rebuilt using ``treedef``.
    treedef: tree structure returned by ``tree_flatten`` of the indexer.
    indices_are_sorted: caller hint, OR-ed with the lowered indexer's own flag.
    unique_indices: caller hint, OR-ed with the lowered indexer's own flag.
    mode: out-of-bounds mode forwarded to :func:`slicing.gather`.
    fill_value: optional concrete scalar forwarded to :func:`slicing.gather`.
    normalize_indices: forwarded to the indexer's ``to_gather`` lowering.

  Raises:
    ValueError: if ``fill_value`` is not a scalar.
  """
  lowered = tree_unflatten(treedef, dynamic_idx).to_gather(
      core.typeof(arr).sharding, normalize_indices=normalize_indices)
  jnp_error._check_precondition_oob_gather(arr.shape, lowered.gather_indices)
  if fill_value is not None:
    # fill_value must be concrete and scalar; 0-d NumPy arrays are unwrapped.
    core.concrete_or_error(None, fill_value,
                           "fill_value argument to indexed get()")
    if np.ndim(fill_value) != 0:
      raise ValueError("fill_value argument to indexed get() must be a scalar")
    if isinstance(fill_value, np.ndarray):
      fill_value = fill_value.item()
  out = arr
  if lowered.scalar_bool_dims:
    out = lax.expand_dims(out, lowered.scalar_bool_dims)
  # Empty result: skip the gather entirely. This is both a fast path and
  # required for cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(lowered.slice_shape):
    return lax.full_like(out, 0, shape=lowered.slice_shape,
                         sharding=lowered.slice_sharding)
  # No gather is emitted when there are no gather indices at all.
  if not core.is_empty_shape(lowered.gather_indices.shape):
    out = slicing.gather(
        out, lowered.gather_indices, lowered.dnums, lowered.gather_slice_shape,
        unique_indices=unique_indices or lowered.unique_indices,
        indices_are_sorted=indices_are_sorted or lowered.indices_are_sorted,
        mode=mode, fill_value=fill_value)
  # Undo the ascending-order gather for negative-stride slices.
  if lowered.reversed_y_dims:
    out = lax.rev(out, lowered.reversed_y_dims)
  # Insert the np.newaxis / None dimensions.
  return lax.expand_dims(out, lowered.newaxis_dims)
class _StaticSliceIndexer(NamedTuple):
start_indices: Sequence[int]
limit_indices: Sequence[int]
strides: Sequence[int] | None
rev_axes: Sequence[int]
squeeze_axes: Sequence[int]
newaxis_dims: Sequence[int]
def is_trivial_slice(self, arr_shape: Sequence[int]):
if self.strides is not None or len(arr_shape) != len(self.start_indices):
return False
return all(
(start, stop) == (0, size)
for start, stop, size in zip(self.start_indices, self.limit_indices, arr_shape)
)
class _DynamicSliceIndexer(NamedTuple):
  """Lowering of an index expression to ``lax.dynamic_slice`` plus fix-ups."""
  # Start index per axis, passed as `start_indices` to slicing.dynamic_slice.
  # Typed ArrayLike, so presumably may hold traced values — TODO confirm.
  start_indices: Sequence[ArrayLike]
  # Static extent per axis, passed as `slice_sizes` to slicing.dynamic_slice.
  slice_sizes: Sequence[int]
  # Axes reversed with lax.rev after slicing.
  rev_axes: Sequence[int]
  # Axes removed with lax.squeeze after reversing.
  squeeze_axes: Sequence[int]
  # Dimensions inserted with lax.expand_dims last.
  newaxis_dims: Sequence[int]
  # When True, the dynamic_slice is skipped and the input is used as-is.
  trivial_slicing: bool
  # Forwarded as `allow_negative_indices` to slicing.dynamic_slice.
  normalize_indices: bool
class _GatherIndexer(NamedTuple):
  """Arguments for lowering an index expression to one ``lax.gather``.

  Returned by ``_index_to_gather``.
  """
  # The expected shape of the slice output.
  slice_shape: Sequence[int]
  # The slice shape to pass to lax.gather().
  gather_slice_shape: Sequence[int]
  # The gather indices to use.
  gather_indices: ArrayLike
  # A GatherDimensionNumbers object describing the gather to perform.
  dnums: slicing.GatherDimensionNumbers
  # Are the gather_indices known to be non-overlapping and/or sorted?
  # (In practice, these translate to "there are no advanced indices", because
  # only advanced indices could lead to index repetition.)
  unique_indices: bool
  indices_are_sorted: bool
  # Slice dimensions that have negative strides, and so must be reversed after
  # the gather.
  reversed_y_dims: Sequence[int]
  # Keep track of any axes created by `newaxis`. These must be inserted for
  # gathers and eliminated for scatters.
  newaxis_dims: Sequence[int]
  # Keep track of dimensions with scalar bool indices. These must be inserted
  # for gathers before performing other index operations.
  scalar_bool_dims: Sequence[int]
  # The expected sharding of the slice output.
  slice_sharding: NamedSharding | None = None
def _index_to_gather(indexer: NDIndexer, *, x_sharding: NamedSharding | Any,
                     normalize_indices: bool = True) -> _GatherIndexer:
  """Lower a parsed NumPy-style index to the arguments of one ``lax.gather``.

  Args:
    indexer: the parsed index expression.
    x_sharding: sharding of the array being indexed. Its ``spec`` seeds the
      output spec, and ``x_sharding.update(spec=...)`` builds the output
      sharding.
    normalize_indices: if True, ``indexer.normalize_indices()`` is applied
      before lowering (presumably wrapping negative indices — confirm in
      NDIndexer).

  Returns:
    A ``_GatherIndexer``. The caller must apply the gather, then reverse
    ``reversed_y_dims`` and insert ``newaxis_dims``.

  Raises:
    IndexError: for a scalar integer index into an axis of size 0, or for an
      unsupported index type.
    RuntimeError: if a SLICE/DYNAMIC_SLICE entry holds neither a ``slice`` nor
      an ``indexing.Slice`` (internal error).
  """
  indexer.validate_slices()
  indexer = indexer.convert_sequences_to_arrays()
  # Positions of integer/array ("advanced") indices. Note: computed before
  # expand_ellipses(), so positions refer to the unexpanded index tuple.
  is_advanced = np.nonzero(
    np.array([idx.typ in {IndexType.ARRAY, IndexType.INTEGER} for idx in indexer.indices]))
  # True when advanced indices sit at consecutive positions. Also True with
  # zero or one advanced index, since np.all of an empty array is True.
  advanced_axes_are_contiguous = np.all(np.diff(is_advanced) == 1)
  indexer = indexer.expand_ellipses()
  # Positions (after ellipsis expansion) of scalar boolean indices.
  scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(indexer.indices) if i.typ == IndexType.BOOLEAN]
  indexer, x_spec = indexer.expand_scalar_bool_indices(x_sharding.spec)
  if normalize_indices:
    indexer = indexer.normalize_indices()
  # Check for advanced indexing:
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
  # The advanced indices.
  advanced_indexes: Sequence[Array] = []
  # The positions of the advanced indexing axes in `idx`.
  idx_advanced_axes: Sequence[int] = []
  # The positions of the advanced indexes in x's shape.
  # collapsed, after None axes have been removed. See below.
  x_advanced_axes: Sequence[int] = []
  if indexer.is_advanced_int_indexer():
    # Enumerating only non-None entries makes `j` the corresponding axis of x.
    idx_without_none = [(i, d) for i, d in enumerate(indexer.indices) if d.typ != IndexType.NONE]
    advanced_pairs = (
      (lax_numpy.asarray(e.index), i, j)
      for j, (i, e) in enumerate(idx_without_none)
      if e.typ in [IndexType.ARRAY, IndexType.INTEGER]
    )
    advanced_indexes, idx_advanced_axes, x_advanced_axes = unzip3(advanced_pairs)
  x_axis = 0 # Current axis in x.
  y_axis = 0 # Current axis in y, before collapsing. See below.
  collapsed_y_axis = 0 # Current axis in y, after collapsing.
  # Gather dimension numbers.
  offset_dims: list[int] = []
  collapsed_slice_dims: list[int] = []
  start_index_map: list[int] = []
  index_dtype = lax_utils.int_dtype_for_shape(indexer.shape, signed=True)
  # Gather indices.
  # Pairs of (array, start_dim) values. These will be broadcast into
  # gather_indices_shape, with the array dimensions aligned to start_dim, and
  # then concatenated.
  gather_indices: list[tuple[Array, int]] = []
  gather_indices_shape: list[int] = []
  # We perform three transformations to y before the scatter op, in order:
  # First, y is broadcast to slice_shape. In general `y` only need broadcast to
  # the right shape.
  # NOTE(review): this description is written from the scatter perspective.
  # For gathers (see _gather), the gather output is reversed along
  # reversed_y_dims and then newaxis_dims are inserted.
  slice_shape: list[int] = []
  # Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
  # indices, which the scatter cannot remove itself.
  newaxis_dims: list[int] = []
  # Finally, we reverse reversed_y_dims to handle slices with negative strides.
  reversed_y_dims: list[int] = []
  # Per-axis slice sizes passed to lax.gather.
  gather_slice_shape: list[int] = []
  # PartitionSpec entries for each output dimension.
  slice_spec = []
  for idx_pos, index in enumerate(indexer.indices):
    # Handle the advanced indices here if:
    # * the advanced indices were not contiguous and we are the start.
    # * we are at the position of the first advanced index.
    if (advanced_indexes and
        (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
         not advanced_axes_are_contiguous and idx_pos == 0)):
      # All advanced indices are broadcast together and contribute one block
      # of output dimensions with their common broadcast shape.
      advanced_index_arrs = util._broadcast_arrays(*advanced_indexes)
      shape = advanced_index_arrs[0].shape
      aia_spec = core.typeof(advanced_index_arrs[0]).sharding.spec
      ndim = len(shape)
      start_dim = len(gather_indices_shape)
      gather_indices.extend(
        (lax.convert_element_type(a, index_dtype), start_dim)
        for a in advanced_index_arrs
      )
      gather_indices_shape += shape
      assert x_advanced_axes is not None
      start_index_map.extend(x_advanced_axes)
      collapsed_slice_dims.extend(x_advanced_axes)
      slice_shape.extend(shape)
      slice_spec.extend(aia_spec)
      y_axis += ndim
      collapsed_y_axis += ndim
    # Per-index bookkeeping for advanced indexes.
    if idx_pos in idx_advanced_axes:
      x_axis += 1
      gather_slice_shape.append(1)
      continue
    if index.typ in [IndexType.INTEGER, IndexType.ARRAY] and np.ndim(index.index) == 0: # pyrefly: ignore[bad-argument-type]
      # Basic scalar int indices
      if core.definitely_equal(indexer.shape[x_axis], 0):
        # XLA gives error when indexing into an axis of size 0
        raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
      i_converted = lax.convert_element_type(index.index, index_dtype) # pyrefly: ignore[bad-argument-type]
      gather_indices.append((i_converted, len(gather_indices_shape)))
      collapsed_slice_dims.append(x_axis)
      gather_slice_shape.append(1)
      start_index_map.append(x_axis)
      x_axis += 1
    elif index.typ == IndexType.NONE:
      # None indexing: add a dimension.
      slice_shape.append(1)
      slice_spec.append(None)
      newaxis_dims.append(y_axis)
      y_axis += 1
    elif index.typ in [IndexType.SLICE, IndexType.DYNAMIC_SLICE]:
      # Handle static slice index.
      if isinstance(index.index, indexing.Slice):
        start, step, slice_size = index.index.start, index.index.stride, index.index.size
      elif isinstance(index.index, slice):
        start, step, slice_size = core.canonicalize_slice(index.index, indexer.shape[x_axis])
      else:
        raise RuntimeError(f"Internal: expected slice or Slice, got {type(index.index)}")
      slice_shape.append(slice_size)
      slice_spec.append(x_spec[x_axis])
      if core.definitely_equal(step, 1):
        # Optimization: avoid generating trivial gather.
        # Unit stride becomes a contiguous window (offset dim). A start index
        # is only needed when the window is smaller than the whole axis.
        if not core.definitely_equal(slice_size, indexer.shape[x_axis]):
          gather_indices.append((lax.convert_element_type(start, index_dtype),
                                len(gather_indices_shape)))
          start_index_map.append(x_axis)
        gather_slice_shape.append(slice_size)
        offset_dims.append(collapsed_y_axis)
      else:
        # Non-unit stride: gather each element individually via an iota of
        # positions start + step * k.
        indices = (lax_numpy.array(start, dtype=index_dtype) +
                   lax_numpy.array(step, dtype=index_dtype) * lax.iota(index_dtype, slice_size))
        if step < 0:
          # Gather in ascending index order and record this output axis so
          # the caller reverses it after the gather.
          reversed_y_dims.append(collapsed_y_axis)
          indices = lax.rev(indices, dimensions=(0,))
        gather_slice_shape.append(1)
        gather_indices.append((indices, len(gather_indices_shape)))
        start_index_map.append(x_axis)
        gather_indices_shape.append(slice_size)
        collapsed_slice_dims.append(x_axis)
      collapsed_y_axis += 1
      y_axis += 1
      x_axis += 1
    else:
      raise IndexError(f"Got unsupported indexer at position {idx_pos}: {index!r}")
  # Assemble the final index array, whose trailing dimension is the index
  # vector:
  # - no indices: an empty (0,) array;
  # - one index: append a size-1 trailing dim;
  # - several: broadcast each to the full batch shape plus a trailing 1, then
  #   concatenate along that trailing dim.
  if len(gather_indices) == 0:
    gather_indices_array: ArrayLike = np.zeros((0,), dtype=index_dtype)
  elif len(gather_indices) == 1:
    g, _ = gather_indices[0]
    gather_indices_array = lax.expand_dims(g, (g.ndim,))
  else:
    last_dim = len(gather_indices_shape)
    gather_indices_shape.append(1)
    gather_indices_array = lax.concatenate([
      lax.broadcast_in_dim(g, gather_indices_shape, tuple(range(i, i + g.ndim)))
      for g, i in gather_indices],
      last_dim)
  dnums = slicing.GatherDimensionNumbers(
    offset_dims = tuple(offset_dims),
    collapsed_slice_dims = tuple(sorted(collapsed_slice_dims)),
    start_index_map = tuple(start_index_map)
  )
  slice_sharding = canonicalize_sharding(x_sharding.update(spec=slice_spec),
                                         "index_to_gather")
  # Uniqueness/sortedness are only guaranteed when no advanced indices exist.
  return _GatherIndexer(
    slice_shape=slice_shape,
    newaxis_dims=tuple(newaxis_dims),
    gather_slice_shape=gather_slice_shape,
    reversed_y_dims=reversed_y_dims,
    dnums=dnums,
    gather_indices=gather_indices_array,
    unique_indices=not advanced_indexes,
    indices_are_sorted=not advanced_indexes,
    scalar_bool_dims=scalar_bool_dims,
    slice_sharding=slice_sharding)
def _should_unpack_list_index(x):
"""Helper for eliminate_deprecated_list_indexing."""
return (isinstance(x, (np.ndarray, Array))
and np.ndim(x) != 0
or isinstance(x, (Sequence, slice))
or x is Ellipsis or x is None)
def eliminate_deprecated_list_indexing(idx):
  """Normalize ``idx`` to a tuple, rejecting non-tuple sequence indices.

  Tuples pass through unchanged. Arrays, strings, and other non-sequence
  objects are wrapped in a 1-tuple.

  Raises:
    TypeError: for any other (non-tuple) sequence, such as a list. NumPy
      applies convoluted heuristics here (see
      https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343);
      JAX rejects all such sequences. The message suggests ``tuple(seq)``
      when an element looks like a per-axis index, else ``array(seq)``.
  """
  if isinstance(idx, tuple):
    return idx
  if not isinstance(idx, Sequence) or isinstance(idx, (Array, np.ndarray, str)):
    return (idx,)
  if any(_should_unpack_list_index(i) for i in idx):
    msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
           "use `arr[tuple(seq)]` instead of `arr[seq]`. "
           "See https://github.com/jax-ml/jax/issues/4564 for more information.")
  else:
    msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
           "use `arr[array(seq)]` instead of `arr[seq]`. "
           "See https://github.com/jax-ml/jax/issues/4564 for more information.")
  raise TypeError(msg)
def _is_boolean_index(i):
try:
abstract_i = core.typeof(i)
except TypeError:
abstract_i = None
return (isinstance(abstract_i, core.ShapedArray) and dtypes.issubdtype(abstract_i.dtype, np.bool_)
or isinstance(i, list) and i and all(_is_scalar(e)
and dtypes.issubdtype(dtypes.dtype(e), np.bool_) for e in i))
def _is_slice_element_none_or_constant_or_symbolic(elt):
"""Return True if elt is a constant or None."""
if elt is None: return True
if core.is_symbolic_dim(elt): return True
try:
return core.is_concrete(elt)
except TypeError:
return False
def _is_scalar(x):
"""Checks if a Python or NumPy scalar."""
return np.isscalar(x) or (
isinstance(x, (np.ndarray, Array))
and np.ndim(x) == 0
)
@export
def place(arr: ArrayLike, mask: ArrayLike, vals: ArrayLike, *,
          inplace: bool = True) -> Array:
  """Update array elements based on a mask.
  JAX implementation of :func:`numpy.place`.
  :func:`numpy.place` modifies its input in-place, which is impossible for
  JAX's immutable arrays. The JAX version instead returns an updated copy,
  and requires the caller to pass ``inplace=False`` as an explicit
  acknowledgement of this difference.
  Args:
    arr: array into which values will be placed.
    mask: boolean mask with the same size as ``arr``.
    vals: values to be inserted into ``arr`` at the locations indicated
      by mask. If too many values are supplied, they will be truncated.
      If not enough values are supplied, they will be repeated.
    inplace: must be set to False to indicate that the input is not modified
      in-place, but rather a modified copy is returned.
  Returns:
    A copy of ``arr`` with masked values set to entries from ``vals``.
  Raises:
    ValueError: if ``inplace`` is True, if ``arr`` and ``mask`` differ in
      size, or if ``vals`` is empty.
  See Also:
    - :func:`jax.numpy.put`: put elements into an array at numerical indices.
    - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing
  Examples:
    >>> x = jnp.zeros((3, 5), dtype=int)
    >>> mask = (jnp.arange(x.size) % 3 == 0).reshape(x.shape)
    >>> mask
    Array([[ True, False, False,  True, False],
           [False,  True, False, False,  True],
           [False, False,  True, False, False]], dtype=bool)
    Placing a scalar value:
    >>> jnp.place(x, mask, 1, inplace=False)
    Array([[1, 0, 0, 1, 0],
           [0, 1, 0, 0, 1],
           [0, 0, 1, 0, 0]], dtype=int32)
    In this case, ``jnp.place`` is similar to the masked array update syntax:
    >>> x.at[mask].set(1)
    Array([[1, 0, 0, 1, 0],
           [0, 1, 0, 0, 1],
           [0, 0, 1, 0, 0]], dtype=int32)
    ``place`` differs when placing values from an array. The array is repeated
    to fill the masked entries:
    >>> vals = jnp.array([1, 3, 5])
    >>> jnp.place(x, mask, vals, inplace=False)
    Array([[1, 0, 0, 3, 0],
           [0, 5, 0, 0, 1],
           [0, 0, 3, 0, 0]], dtype=int32)
  """
  data, mask_arr, vals_arr = util.ensure_arraylike("place", arr, mask, vals)
  flat_vals = vals_arr.ravel()
  if inplace:
    raise ValueError(
      "jax.numpy.place cannot modify arrays in-place, because JAX arrays are immutable. "
      "Pass inplace=False to instead return an updated array.")
  if data.size != mask_arr.size:
    raise ValueError("place: arr and mask must be the same size")
  if not flat_vals.size:
    raise ValueError("Cannot place values from an empty array")
  if not data.size:
    return data
  # Flat positions of the True mask entries, padded to a static length with
  # mask_arr.size. That padding is out of range and is dropped by the scatter.
  (positions,) = lax_numpy.where(mask_arr.ravel(), size=mask_arr.size,
                                 fill_value=mask_arr.size)
  fill = lax_numpy._tile_to_size(flat_vals, len(positions))
  updated = data.ravel().at[positions].set(fill, mode='drop')
  return updated.reshape(data.shape)
@export
def put(a: ArrayLike, ind: ArrayLike, v: ArrayLike,
        mode: str | None = None, *, inplace: bool = True) -> Array:
  """Put elements into an array at given indices.
  JAX implementation of :func:`numpy.put`.
  :func:`numpy.put` modifies its input in-place, which is impossible for
  JAX's immutable arrays. The JAX version instead returns an updated copy,
  and requires the caller to pass ``inplace=False`` as an explicit
  acknowledgement of this difference.
  Args:
    a: array into which values will be placed.
    ind: array of indices over the flattened array at which to put values.
    v: array of values to put into the array. Values are repeated or
      truncated to match the number of indices.
    mode: string specifying how to handle out-of-bound indices. Supported values:
      - ``"clip"`` (default): clip out-of-bound indices to the final index.
      - ``"wrap"``: wrap out-of-bound indices to the beginning of the array.
    inplace: must be set to False to indicate that the input is not modified
      in-place, but rather a modified copy is returned.
  Returns:
    A copy of ``a`` with specified entries updated.
  Raises:
    ValueError: if ``inplace`` is True or ``mode`` is unrecognized.
    NotImplementedError: if ``mode="raise"``.
  See Also:
    - :func:`jax.numpy.place`: place elements into an array via boolean mask.
    - :func:`jax.numpy.ndarray.at`: array updates using NumPy-style indexing.
    - :func:`jax.numpy.take`: extract values from an array at given indices.
  Examples:
    >>> x = jnp.zeros(5, dtype=int)
    >>> indices = jnp.array([0, 2, 4])
    >>> values = jnp.array([10, 20, 30])
    >>> jnp.put(x, indices, values, inplace=False)
    Array([10,  0, 20,  0, 30], dtype=int32)
    This is equivalent to the following :attr:`jax.numpy.ndarray.at` indexing syntax:
    >>> x.at[indices].set(values)
    Array([10,  0, 20,  0, 30], dtype=int32)
    There are two modes for handling out-of-bound indices. By default they are
    clipped:
    >>> indices = jnp.array([0, 2, 6])
    >>> jnp.put(x, indices, values, inplace=False, mode='clip')
    Array([10,  0, 20,  0, 30], dtype=int32)
    Alternatively, they can be wrapped to the beginning of the array:
    >>> jnp.put(x, indices, values, inplace=False, mode='wrap')
    Array([10, 30, 20,  0,  0], dtype=int32)
    For N-dimensional inputs, the indices refer to the flattened array:
    >>> x = jnp.zeros((3, 5), dtype=int)
    >>> indices = jnp.array([0, 7, 14])
    >>> jnp.put(x, indices, values, inplace=False)
    Array([[10,  0,  0,  0,  0],
           [ 0,  0, 20,  0,  0],
           [ 0,  0,  0,  0, 30]], dtype=int32)
  """
  if inplace:
    raise ValueError(
      "jax.numpy.put cannot modify arrays in-place, because JAX arrays are immutable. "
      "Pass inplace=False to instead return an updated array.")
  arr, ind_arr, _ = util.ensure_arraylike("put", a, ind, v)
  flat_ind = ind_arr.ravel()
  flat_v = lax_numpy.ravel(v)
  # Nothing to write (or nowhere to write it): return the input unchanged.
  if not arr.size or not flat_ind.size or not flat_v.size:
    return arr
  # Repeat/truncate values so there is exactly one per index.
  flat_v = lax_numpy._tile_to_size(flat_v, len(flat_ind))
  if mode is None:
    scatter_mode = "drop"
  elif mode == "clip":
    flat_ind = lax_numpy.clip(flat_ind, 0, arr.size - 1)
    scatter_mode = "promise_in_bounds"
  elif mode == "wrap":
    flat_ind = flat_ind % arr.size
    scatter_mode = "promise_in_bounds"
  elif mode == "raise":
    raise NotImplementedError("The 'raise' mode to jnp.put is not supported.")
  else:
    raise ValueError(f"mode should be one of 'wrap' or 'clip'; got {mode=}")
  # Convert flat positions to per-axis coordinates for an N-d scatter.
  target = lax_numpy.unravel_index(flat_ind, arr.shape)
  return arr.at[target].set(flat_v, mode=scatter_mode)