119 lines
3.3 KiB
Python
119 lines
3.3 KiB
Python
# Copyright 2026 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
|
|
from jax._src import core
|
|
from jax._src import tree_util
|
|
from jax._src.typing import Array
|
|
|
|
|
|
@tree_util.register_pytree_node_class
@dataclasses.dataclass
class Slice:
  """A slice with a start index and a size.

  Both start index and size can either be static, i.e. known at tracing
  and compilation time, or dynamic. The stride is always static (it is
  stored in the pytree auxiliary data) and must be a positive integer.

  Attributes:
    start: the index of the first selected element.
    size: the number of selected elements.
    stride: the step between consecutive selected elements.
  """

  start: int | Array
  size: int | Array
  stride: int = 1

  def __post_init__(self):
    # A stride of 0 would not describe a valid slice, and negative strides
    # are not supported; this matches the `step >= 1` check in `from_slice`.
    if self.stride < 1:
      raise ValueError("`stride` must be >= 1.")

  @property
  def is_dynamic_start(self):
    """Whether ``start`` is not a static dimension according to ``core.is_dim``."""
    return not core.is_dim(self.start)

  @property
  def is_dynamic_size(self):
    """Whether ``size`` is not a static dimension according to ``core.is_dim``."""
    return not core.is_dim(self.size)

  def tree_flatten(self):
    """Flattens into ``(children, aux_data)`` for pytree registration.

    Statically known ``start``/``size`` values go into the auxiliary data
    (i.e. are treated as static structure); dynamic ones become children.
    Each value occupies exactly one of the two slots, and ``None`` fills the
    other one. The stride is always auxiliary data.
    """
    if self.is_dynamic_start:
      start_child, start_aux = self.start, None
    else:
      start_child, start_aux = None, self.start
    if self.is_dynamic_size:
      size_child, size_aux = self.size, None
    else:
      size_child, size_aux = None, self.size
    return (start_child, size_child), (start_aux, size_aux, self.stride)

  @classmethod
  def tree_unflatten(cls, aux_data, children) -> Slice:
    """Inverse of ``tree_flatten``.

    For ``start`` and ``size``, the child is used when it is not ``None``;
    otherwise the static value from ``aux_data`` is used.
    """
    start, size = (
        a if a is not None else b for a, b in zip(children, aux_data[:2])
    )
    return cls(start, size, aux_data[2])

  @classmethod
  def from_slice(cls, slc: slice, size: int) -> Slice:
    """Builds a ``Slice`` equivalent to applying ``slc`` to an axis of ``size``.

    Args:
      slc: a builtin ``slice`` object.
      size: the size of the axis being sliced.

    Raises:
      ValueError: if the canonicalized step of ``slc`` is smaller than 1.
    """
    start, step, slice_size = core.canonicalize_slice(slc, size)
    if step < 1:
      raise ValueError(f"slice must have a step >= 1 (found: {step})")
    return cls(start, slice_size, step)
|
|
|
|
|
|
def dslice(
    start: int | Array | None,
    size: int | Array | None = None,
    stride: int | None = None,
) -> slice | Slice:
  """Constructs a ``Slice`` from a start index and a size.

  The semantics of ``dslice`` mirror those of the builtin ``slice`` type:

  * ``dslice(None)`` is ``:``
  * ``dslice(j)`` is ``:j``
  * ``dslice(i, j)`` is ``i:i+j``
  * ``dslice(i, j, stride)`` is ``i:i+j*stride:stride``, i.e. ``j`` elements
    starting at ``i`` and spaced ``stride`` apart

  Args:
    start: the start index, or ``None`` to select the whole axis (in which
      case ``size`` and ``stride`` are ignored). When ``size`` is omitted,
      this value is used as the size and the start index is ``0``.
    size: the number of elements to select.
    stride: the static step between selected elements; defaults to ``1``.

  Returns:
    ``slice(None)`` if ``start`` is ``None``, otherwise a ``Slice``.

  Raises:
    ValueError: if ``stride`` is not a Python ``int``.

  Examples:

    >>> x = jax.numpy.arange(10)
    >>> i = 4
    >>> x[i: i + 2]  # standard indexing requires i to be static
    Array([4, 5], dtype=int32)
    >>> x[jax.ds(i, 2)]  # equivalent which allows i to be dynamic
    Array([4, 5], dtype=int32)

    Here is an explicit example of slicing with a dynamic start index:

    >>> @jax.jit(static_argnames='size')
    ... def f(x, i, size):  # `size` must be static, `i` may be traced
    ...   return x[jax.ds(i, size)]
    ...
    >>> f(x, i, 2)
    Array([4, 5], dtype=int32)
  """
  if start is None:
    return slice(None)
  if stride is None:
    stride = 1
  if not isinstance(stride, int):
    raise ValueError("Non-static stride in `dslice`")
  if size is None:
    # Single-argument form `dslice(j)` means `:j`.
    start, size = 0, start
  return Slice(start, size, stride)


ds = dslice  # Handy alias.
|