hand
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,430 @@
|
||||
# Copyright 2019 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Helpers for indexed updates.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.numpy import indexing
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes
|
||||
from jax._src.pjit import auto_axes
|
||||
from jax._src.sharding_impls import NamedSharding
|
||||
from jax._src.typing import Array, ArrayLike, Index
|
||||
|
||||
|
||||
def _scatter_update(x: ArrayLike, idx: Index | tuple[Index, ...],
                    y: ArrayLike, scatter_op: Callable[..., Array],
                    indices_are_sorted: bool, unique_indices: bool,
                    mode: slicing.GatherScatterMode | str | None = None, normalize_indices: bool = True,
                    out_sharding: NamedSharding | None = None):
  """Helper for indexed updates.

  Computes the value of x that would result from computing::
    x[idx] op= y
  except in a pure functional way, with no in-place updating.

  Args:
    x: ndarray to be updated.
    idx: None, an integer, a slice, an ellipsis, an ndarray with integer dtype,
      or a tuple of those indicating the locations of `x` into which to scatter-
      update the values in `y`.
    y: values to be scattered.
    scatter_op: callable, one of lax.scatter, lax.scatter_add, lax.scatter_min,
      or lax.scatter_max.
    indices_are_sorted: whether `idx` is known to be sorted
    unique_indices: whether `idx` is known to be free of duplicates
    mode: optional :class:`jax.lax.GatherScatterMode` (or its string name)
      describing how out-of-bounds indices are handled; ``None`` defers to the
      scatter op's default.
    normalize_indices: whether to normalize (e.g. wrap negative) indices before
      scattering; the segment helpers below pass ``False`` since segment ids
      are already absolute.
    out_sharding: optional sharding for the result; when given, the scatter is
      run under ``auto_axes`` over the mesh's explicit axes.

  Returns:
    An ndarray representing an updated `x` after performing the scatter-update.
  """
  x = jnp.asarray(x)
  # Fast path for a plain Python int update value that fits in x's integer
  # dtype: cast it directly to x.dtype rather than letting asarray pick one.
  if (isinstance(y, int) and np.issubdtype(x.dtype, np.integer) and
      np.iinfo(x.dtype).min <= y <= np.iinfo(x.dtype).max):
    y = jnp.asarray(y, dtype=x.dtype)
  else:
    y = jnp.asarray(y)

  # XLA gathers and scatters are very similar in structure; the scatter logic
  # is more or less a transpose of the gather equivalent.
  indexer = indexing.NDIndexer.from_raw_indices(idx, x.shape).expand_bool_indices()
  # Flatten the indexer so only its dynamic (array-valued) parts are traced;
  # the static structure travels via the treedef.
  dynamic_idx, treedef = tree_util.tree_flatten(indexer)
  dynamic_idx = tuple(dynamic_idx)
  internal_scatter = partial(
      _scatter_impl, scatter_op=scatter_op, treedef=treedef,
      indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices, mode=mode,
      normalize_indices=normalize_indices)
  if out_sharding is not None:
    return auto_axes(internal_scatter, out_sharding=out_sharding,
                     axes=out_sharding.mesh.explicit_axes
                     )(x, y, dynamic_idx)
  return internal_scatter(x, y, dynamic_idx)
|
||||
|
||||
|
||||
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
|
||||
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
|
||||
def _scatter_impl(x: ArrayLike, y: ArrayLike, dynamic_idx: tuple[Any, ...], *,
                  scatter_op: Callable[..., Array],
                  treedef: tree_util.PyTreeDef,
                  indices_are_sorted: bool, unique_indices: bool,
                  mode: slicing.GatherScatterMode | str | None, normalize_indices: bool):
  """Implementation of `_scatter_update`: applies `scatter_op` to `x` at the
  locations described by the flattened indexer `(treedef, dynamic_idx)`,
  with update values `y`. Returns the updated array, converted back to x's
  original dtype and weak-type.
  """
  # Remember x's dtype/weak-type so the result can be converted back after the
  # promotion below.
  dtype = lax.dtype(x)
  weak_type = dtypes.is_weakly_typed(x)

  if not dtypes.safe_to_cast(y, x):
    # TODO(jakevdp): change this to an error after the deprecation period.
    warnings.warn(
      "scatter inputs have incompatible types: cannot safely cast value "
      f"from dtype={lax.dtype(y)} to dtype={lax.dtype(x)} with "
      f"jax_numpy_dtype_promotion={config.numpy_dtype_promotion.value}. "
      "In future JAX releases this will result in an error.",
      FutureWarning)

  # Rebuild the NDIndexer from its static structure + dynamic leaves, then
  # lower it to gather-style dimension numbers.
  general_indexer = tree_util.tree_unflatten(treedef, dynamic_idx)
  indexer = general_indexer.to_gather(
      core.typeof(x).sharding, normalize_indices=normalize_indices)

  # Avoid calling scatter if the slice shape is empty, both as a fast path and
  # to handle cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(indexer.slice_shape):
    return x

  x, y = promote_dtypes(x, y)

  # Broadcast `y` to the slice output shape.
  y = jnp.broadcast_to(y, tuple(indexer.slice_shape),
                       out_sharding=indexer.slice_sharding)
  # Collapse any `None`/`np.newaxis` dimensions.
  y = jnp.squeeze(y, axis=indexer.newaxis_dims)
  # Dimensions indexed with a negative-step slice were gathered reversed;
  # reverse `y` to match.
  if indexer.reversed_y_dims:
    y = lax.rev(y, indexer.reversed_y_dims)

  # Scalar boolean indices add size-1 dimensions; expand x to match, and
  # squeeze them back out of the result below.
  if indexer.scalar_bool_dims:
    x = lax.expand_dims(x, indexer.scalar_bool_dims)

  # Transpose the gather dimensions into scatter dimensions (cf.
  # lax._gather_transpose_rule)
  dnums = slicing.ScatterDimensionNumbers(
      update_window_dims=indexer.dnums.offset_dims,
      inserted_window_dims=indexer.dnums.collapsed_slice_dims,
      scatter_dims_to_operand_dims=indexer.dnums.start_index_map,
      operand_batching_dims=indexer.dnums.operand_batching_dims,
      scatter_indices_batching_dims=indexer.dnums.start_indices_batching_dims,
  )
  # The indexer may itself prove sortedness/uniqueness; combine with the
  # caller-supplied hints.
  out = scatter_op(
      x, indexer.gather_indices, y, dnums,
      indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
      unique_indices=indexer.unique_indices or unique_indices,
      mode=mode)
  if indexer.scalar_bool_dims:
    out = lax.squeeze(out, indexer.scalar_bool_dims)
  # Restore the pre-promotion dtype and weak-type of x.
  return lax._convert_element_type(out, dtype, weak_type)
|
||||
|
||||
|
||||
def _get_identity(op, dtype):
|
||||
"""Get an appropriate identity for a given operation in a given dtype."""
|
||||
if op is slicing.scatter_add:
|
||||
return 0
|
||||
elif op is slicing.scatter_mul:
|
||||
return 1
|
||||
elif op is slicing.scatter_min:
|
||||
if dtype == dtypes.bool_:
|
||||
return True
|
||||
elif dtypes.issubdtype(dtype, np.integer):
|
||||
return dtypes.iinfo(dtype).max
|
||||
return float('inf')
|
||||
elif op is slicing.scatter_max:
|
||||
if dtype == dtypes.bool_:
|
||||
return False
|
||||
elif dtypes.issubdtype(dtype, np.integer):
|
||||
return dtypes.iinfo(dtype).min
|
||||
return -float('inf')
|
||||
else:
|
||||
raise ValueError(f"Unrecognized op: {op}")
|
||||
|
||||
|
||||
def _segment_update(name: str,
|
||||
data: ArrayLike,
|
||||
segment_ids: ArrayLike,
|
||||
scatter_op: Callable,
|
||||
num_segments: int | None = None,
|
||||
indices_are_sorted: bool = False,
|
||||
unique_indices: bool = False,
|
||||
bucket_size: int | None = None,
|
||||
reducer: Callable | None = None,
|
||||
mode: slicing.GatherScatterMode | str | None = None) -> Array:
|
||||
check_arraylike(name, data, segment_ids)
|
||||
mode = slicing.GatherScatterMode.FILL_OR_DROP if mode is None else mode
|
||||
data = jnp.asarray(data)
|
||||
segment_ids = jnp.asarray(segment_ids)
|
||||
dtype = data.dtype
|
||||
if num_segments is None:
|
||||
num_segments = np.max(segment_ids) + 1
|
||||
num_segments = core.concrete_dim_or_error(num_segments, "segment_sum() `num_segments` argument.")
|
||||
if num_segments is not None and num_segments < 0:
|
||||
raise ValueError("num_segments must be non-negative.")
|
||||
|
||||
if bucket_size is None:
|
||||
out = jnp.full((num_segments,) + data.shape[1:],
|
||||
_get_identity(scatter_op, dtype), dtype=dtype)
|
||||
return _scatter_update(
|
||||
out, segment_ids, data, scatter_op, indices_are_sorted,
|
||||
unique_indices, normalize_indices=False, mode=mode)
|
||||
|
||||
# Bucketize indices and perform segment_update on each bucket to improve
|
||||
# numerical stability for operations like product and sum.
|
||||
assert reducer is not None
|
||||
num_buckets = util.ceil_of_ratio(segment_ids.size, bucket_size)
|
||||
out = jnp.full((num_buckets, num_segments) + data.shape[1:],
|
||||
_get_identity(scatter_op, dtype), dtype=dtype)
|
||||
out = _scatter_update(
|
||||
out, np.index_exp[jnp.arange(segment_ids.shape[0]) // bucket_size,
|
||||
segment_ids[None, :]],
|
||||
data, scatter_op, indices_are_sorted,
|
||||
unique_indices, normalize_indices=False, mode=mode)
|
||||
return reducer(out, axis=0).astype(dtype)
|
||||
|
||||
|
||||
def segment_sum(data: ArrayLike,
                segment_ids: ArrayLike,
                num_segments: int | None = None,
                indices_are_sorted: bool = False,
                unique_indices: bool = False,
                bucket_size: int | None = None,
                mode: slicing.GatherScatterMode | str | None = None) -> Array:
  """Computes the sum within segments of an array.

  Similar to TensorFlow's `segment_sum
  <https://www.tensorflow.org/api_docs/python/tf/math/segment_sum>`_

  Args:
    data: an array containing the values to be summed.
    segment_ids: an integer-dtype array mapping each entry of ``data`` (along
      its leading axis) to a segment. Values may be repeated and need not be
      sorted.
    num_segments: optional nonnegative int giving the number of segments. By
      default this is the smallest count covering every index in
      ``segment_ids``, i.e. ``max(segment_ids) + 1``.
      Since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_sum`` in a JIT-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: size of bucket to group indices into. ``segment_sum`` is
      performed on each bucket separately to improve numerical stability of
      addition. Default ``None`` means no bucketing.
    mode: a :class:`jax.lax.GatherScatterMode` value describing how
      out-of-bounds indices should be handled. By default, values outside of the
      range [0, num_segments) are dropped and do not contribute to the sum.

  Returns:
    An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
    segment sums.

  Examples:
    Simple 1D segment sum:

    >>> data = jnp.arange(5)
    >>> segment_ids = jnp.array([0, 0, 1, 1, 2])
    >>> segment_sum(data, segment_ids)
    Array([1, 5, 4], dtype=int32)

    Using JIT requires static `num_segments`:

    >>> from jax import jit
    >>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
    Array([1, 5, 4], dtype=int32)
  """
  # Delegate to the shared segment helper with the additive scatter op.
  return _segment_update(
      "segment_sum", data, segment_ids, slicing.scatter_add,
      num_segments=num_segments, indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices, bucket_size=bucket_size,
      reducer=reductions.sum, mode=mode)
|
||||
|
||||
|
||||
def segment_prod(data: ArrayLike,
                 segment_ids: ArrayLike,
                 num_segments: int | None = None,
                 indices_are_sorted: bool = False,
                 unique_indices: bool = False,
                 bucket_size: int | None = None,
                 mode: slicing.GatherScatterMode | str | None = None) -> Array:
  """Computes the product within segments of an array.

  Similar to TensorFlow's `segment_prod
  <https://www.tensorflow.org/api_docs/python/tf/math/segment_prod>`_

  Args:
    data: an array containing the values to be reduced.
    segment_ids: an integer-dtype array mapping each entry of ``data`` (along
      its leading axis) to a segment. Values may be repeated and need not be
      sorted.
    num_segments: optional nonnegative int giving the number of segments. By
      default this is the smallest count covering every index in
      ``segment_ids``, i.e. ``max(segment_ids) + 1``.
      Since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_prod`` in a JIT-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: size of bucket to group indices into. ``segment_prod`` is
      performed on each bucket separately to improve numerical stability.
      Default ``None`` means no bucketing.
    mode: a :class:`jax.lax.GatherScatterMode` value describing how
      out-of-bounds indices should be handled. By default, values outside of the
      range [0, num_segments) are dropped and do not contribute to the result.

  Returns:
    An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
    segment products.

  Examples:
    Simple 1D segment product:

    >>> data = jnp.arange(6)
    >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
    >>> segment_prod(data, segment_ids)
    Array([ 0, 6, 20], dtype=int32)

    Using JIT requires static `num_segments`:

    >>> from jax import jit
    >>> jit(segment_prod, static_argnums=2)(data, segment_ids, 3)
    Array([ 0, 6, 20], dtype=int32)
  """
  # Delegate to the shared segment helper with the multiplicative scatter op.
  return _segment_update(
      "segment_prod", data, segment_ids, slicing.scatter_mul,
      num_segments=num_segments, indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices, bucket_size=bucket_size,
      reducer=reductions.prod, mode=mode)
|
||||
|
||||
|
||||
def segment_max(data: ArrayLike,
                segment_ids: ArrayLike,
                num_segments: int | None = None,
                indices_are_sorted: bool = False,
                unique_indices: bool = False,
                bucket_size: int | None = None,
                mode: slicing.GatherScatterMode | str | None = None) -> Array:
  """Computes the maximum within segments of an array.

  Similar to TensorFlow's `segment_max
  <https://www.tensorflow.org/api_docs/python/tf/math/segment_max>`_

  Args:
    data: an array containing the values to be reduced.
    segment_ids: an integer-dtype array mapping each entry of ``data`` (along
      its leading axis) to a segment. Values may be repeated and need not be
      sorted.
    num_segments: optional nonnegative int giving the number of segments. By
      default this is the smallest count covering every index in
      ``segment_ids``, i.e. ``max(segment_ids) + 1``.
      Since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_max`` in a JIT-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: size of bucket to group indices into. ``segment_max`` is
      performed on each bucket separately. Default ``None`` means no bucketing.
    mode: a :class:`jax.lax.GatherScatterMode` value describing how
      out-of-bounds indices should be handled. By default, values outside of the
      range [0, num_segments) are dropped and do not contribute to the result.

  Returns:
    An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
    segment maximums.

  Examples:
    Simple 1D segment max:

    >>> data = jnp.arange(6)
    >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
    >>> segment_max(data, segment_ids)
    Array([1, 3, 5], dtype=int32)

    Using JIT requires static `num_segments`:

    >>> from jax import jit
    >>> jit(segment_max, static_argnums=2)(data, segment_ids, 3)
    Array([1, 3, 5], dtype=int32)
  """
  # Delegate to the shared segment helper with the max scatter op.
  return _segment_update(
      "segment_max", data, segment_ids, slicing.scatter_max,
      num_segments=num_segments, indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices, bucket_size=bucket_size,
      reducer=reductions.max, mode=mode)
|
||||
|
||||
|
||||
def segment_min(data: ArrayLike,
                segment_ids: ArrayLike,
                num_segments: int | None = None,
                indices_are_sorted: bool = False,
                unique_indices: bool = False,
                bucket_size: int | None = None,
                mode: slicing.GatherScatterMode | str | None = None) -> Array:
  """Computes the minimum within segments of an array.

  Similar to TensorFlow's `segment_min
  <https://www.tensorflow.org/api_docs/python/tf/math/segment_min>`_

  Args:
    data: an array containing the values to be reduced.
    segment_ids: an integer-dtype array mapping each entry of ``data`` (along
      its leading axis) to a segment. Values may be repeated and need not be
      sorted.
    num_segments: optional nonnegative int giving the number of segments. By
      default this is the smallest count covering every index in
      ``segment_ids``, i.e. ``max(segment_ids) + 1``.
      Since `num_segments` determines the size of the output, a static value
      must be provided to use ``segment_min`` in a JIT-compiled function.
    indices_are_sorted: whether ``segment_ids`` is known to be sorted.
    unique_indices: whether `segment_ids` is known to be free of duplicates.
    bucket_size: size of bucket to group indices into. ``segment_min`` is
      performed on each bucket separately. Default ``None`` means no bucketing.
    mode: a :class:`jax.lax.GatherScatterMode` value describing how
      out-of-bounds indices should be handled. By default, values outside of the
      range [0, num_segments) are dropped and do not contribute to the result.

  Returns:
    An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
    segment minimums.

  Examples:
    Simple 1D segment min:

    >>> data = jnp.arange(6)
    >>> segment_ids = jnp.array([0, 0, 1, 1, 2, 2])
    >>> segment_min(data, segment_ids)
    Array([0, 2, 4], dtype=int32)

    Using JIT requires static `num_segments`:

    >>> from jax import jit
    >>> jit(segment_min, static_argnums=2)(data, segment_ids, 3)
    Array([0, 2, 4], dtype=int32)
  """
  # Delegate to the shared segment helper with the min scatter op.
  return _segment_update(
      "segment_min", data, segment_ids, slicing.scatter_min,
      num_segments=num_segments, indices_are_sorted=indices_are_sorted,
      unique_indices=unique_indices, bucket_size=bucket_size,
      reducer=reductions.min, mode=mode)
|
||||
@@ -0,0 +1,104 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import overload, Literal
|
||||
|
||||
from jax._src import config
|
||||
from jax._src.lax import lax
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.numpy.reductions import _reduction_dims, Axis
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
import numpy as np
|
||||
|
||||
# The definition of logsumexp is shared between jax.nn and jax.scipy, and
|
||||
# although it matches scipy's definition, we put it here to avoid having
|
||||
# unnecessary scipy dependencies.
|
||||
|
||||
@overload
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
              keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) -> Array: ...

@overload
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
              keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) -> tuple[Array, Array]: ...

@overload
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
              keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: ...

def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
              keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]:
  r"""Log-sum-exp reduction.

  JAX implementation of :func:`scipy.special.logsumexp`.

  .. math::
    \operatorname{logsumexp} a = \log \sum_i b_i \exp a_i

  where the :math:`i` indices range over one or more dimensions to be reduced.

  Args:
    a: the input array
    axis: int or sequence of ints, default=None. Axis along which the sum to be
      computed. If None, the sum is computed along all the axes.
    b: scaling factors for the exponentials. Must be broadcastable to the shape of `a`.
    keepdims: If ``True``, the axes that are reduced are left in the output as
      dimensions of size 1.
    return_sign: If ``True``, the output will be a ``(result, sign)`` pair,
      where ``sign`` is the sign of the sums and ``result`` contains the
      logarithms of their absolute values. If ``False`` only ``result`` is
      returned and it will contain NaN values if the sums are negative.
    where: Elements to include in the reduction.

  Returns:
    Either an array ``result`` or a pair of arrays ``(result, sign)``, depending
    on the value of the ``return_sign`` argument.

  See also:
    :func:`jax.nn.logmeanexp`
  """
  if where is not None:
    # Masked-out elements are also excluded from the reductions below via
    # `where=`; zeroing them here just gives them a well-defined finite value.
    a = jnp.where(where, a, 0)
  if b is not None:
    a_arr, b_arr = promote_args_inexact("logsumexp", a, b)
    # A zero weight contributes b * exp(a) == 0; encode that as exp(-inf) so
    # the shifted computation below stays well-defined.
    a_arr = jnp.where(b_arr != 0, a_arr, -np.inf)
  else:
    a_arr, = promote_args_inexact("logsumexp", a)
    b_arr = a_arr  # for type checking
  pos_dims, dims = _reduction_dims(a_arr, axis)
  # Shift by the max (of the real part, for complex inputs) for numerical
  # stability; `initial=-inf` keeps the reduction defined when `where` masks
  # out everything.
  amax = reductions.max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf)
  # Clamp a nonfinite max to 0 so the subtraction below cannot produce nan
  # (e.g. inf - inf); stop_gradient because the shift is treated as a constant.
  amax = lax.stop_gradient(lax.select(ufuncs.isfinite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)

  exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype)))
  if b is not None:
    exp_a = lax.mul(exp_a, b_arr)
  sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where)
  sign = lax.sign(sumexp)
  # For real inputs (or whenever the sign is returned separately) take |sum|
  # so the log below is real-valued.
  if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating):
    sumexp = abs(sumexp)
  out = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype))

  if return_sign:
    return (out, sign)
  if b is not None and not np.issubdtype(out.dtype, np.complexfloating):
    # A negative sum has no real log: return nan as documented, silencing the
    # debug-nans checker since this nan is deliberate.
    with config.debug_nans(False):
      out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
  return out
|
||||
Reference in New Issue
Block a user