This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,13 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,75 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import operator
from jax._src import api
from jax._src import numpy as jnp
from jax._src.numpy import linalg as jnp_linalg
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike
def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
  """Assign codes from a code book to a set of observations.

  JAX implementation of :func:`scipy.cluster.vq.vq`.

  Each row of ``obs`` is compared against every row of ``code_book`` and
  assigned the index of the code vector at the smallest Euclidean distance.

  Args:
    obs: ``(M, N)`` array of observation vectors, one per row. A
      one-dimensional input is interpreted as a sequence of length-1
      observations.
    code_book: ``(K, N)`` array of code vectors, one per row. A
      one-dimensional input is interpreted as a sequence of length-1 codes.
    check_finite: unused in JAX

  Returns:
    A tuple of arrays ``(code, dist)`` where

    - ``code`` is an integer array of shape ``(M,)`` giving, for each
      observation, the index ``0 <= i < K`` of the nearest code vector.
    - ``dist`` is a float array of shape ``(M,)`` giving the Euclidean
      distance between each observation and its assigned code.

  Examples:
    >>> obs = jnp.array([[1.1, 2.1, 3.1],
    ...                  [5.9, 4.8, 6.2]])
    >>> code_book = jnp.array([[1., 2., 3.],
    ...                        [2., 3., 4.],
    ...                        [3., 4., 5.],
    ...                        [4., 5., 6.]])
    >>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book)
    >>> print(codes)
    [0 3]
    >>> print(distances)
    [0.17320499 1.9209373 ]
  """
  del check_finite  # unused
  check_arraylike("scipy.cluster.vq.vq", obs, code_book)
  observations, codes = promote_dtypes_inexact(obs, code_book)
  if observations.ndim != codes.ndim:
    raise ValueError("Observation and code_book should have the same rank")
  if observations.ndim == 1:
    # Promote scalars to length-1 vectors so the 2D path below applies.
    observations = observations[..., None]
    codes = codes[..., None]
  if observations.ndim != 2:
    raise ValueError("ndim different than 1 or 2 are not supported")
  # (M, K) matrix of Euclidean distances between observations and codes.
  distances = api.vmap(lambda row: jnp_linalg.norm(row[None] - codes, axis=-1))(observations)
  assignment = jnp.argmin(distances, axis=-1)
  # Gather, per observation, the distance to its chosen code.
  nearest_distance = api.vmap(operator.getitem)(distances, assignment)
  return assignment, nearest_distance
@@ -0,0 +1,502 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
from functools import partial
import math
import operator
import numpy as np
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy import fft as jnp_fft
from jax._src.numpy.util import (
promote_dtypes_complex, promote_dtypes_inexact, ensure_arraylike)
from jax._src.util import canonicalize_axis, canonicalize_axis_tuple
from jax._src.typing import Array
def _W4(N: int, k: Array) -> Array:
N_arr, k = promote_dtypes_complex(N, k)
return jnp.exp(-.5j * np.pi * k / N_arr)
def _dct_interleave(x: Array, axis: int) -> Array:
v0 = lax.slice_in_dim(x, None, None, 2, axis)
v1 = lax.rev(lax.slice_in_dim(x, 1, None, 2, axis), (axis,))
return lax.concatenate([v0, v1], axis)
def _dct_ortho_norm(out: Array, axis: int) -> Array:
factor = lax.concatenate([lax.full((1,), 4, out.dtype), lax.full((out.shape[axis] - 1,), 2, out.dtype)], 0)
factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis])
return out / lax.sqrt(factor * out.shape[axis])
# Implementation based on
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)
def dct(x: Array, type: int = 2, n: int | None = None,
        axis: int = -1, norm: str | None = None) -> Array:
  """Compute the one-dimensional discrete cosine transform of ``x``.

  JAX implementation of :func:`scipy.fft.dct`.

  Args:
    x: input array.
    type: DCT type; only type 2 (the default) is currently implemented.
    n: length of the transform along ``axis``; defaults to ``x.shape[axis]``.
      The input is zero-padded up to ``n`` or truncated down to it.
    axis: the axis to transform (default: -1, the last axis).
    norm: normalization mode, one of ``[None, "backward", "ortho"]``.
      ``None`` is equivalent to ``"backward"``.

  Returns:
    array containing the discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idct`: inverse DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:
    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x))
    [[ 6.43  3.56 -2.86]
     [-1.75  1.55 -1.4 ]
     [ 1.33 -2.01 -0.82]]

    When ``n`` smaller than ``x.shape[axis]``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=2))
    [[ 7.3  -0.57]
     [ 0.19 -0.36]
     [-0.   -1.4 ]]

    When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=2, axis=0))
    [[ 3.09  4.4  -2.81]
     [ 2.41  2.62  0.76]]

    When ``n`` larger than ``x.shape[axis]`` and ``axis=1``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=4, axis=1))
    [[ 6.43  4.88  0.04 -3.3 ]
     [-1.75  0.73  1.01 -2.18]
     [ 1.33 -1.05 -2.34 -0.07]]
  """
  x = ensure_arraylike("dct", x)
  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.dct: {norm=!r} is not implemented")
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    # Transform real and imaginary channels independently.
    real_part = dct(x.real, type=type, n=n, norm=norm, axis=axis)
    imag_part = dct(x.imag, type=type, n=n, norm=norm, axis=axis)
    return lax.complex(real_part, imag_part)
  axis = canonicalize_axis(axis, x.ndim)
  if n is not None:
    # Pad with zeros (or truncate, via a negative high-pad) along ``axis``.
    pad_config = [(0, n - x.shape[axis], 0) if a == axis else (0, 0, 0)
                  for a in range(x.ndim)]
    x = lax.pad(x, jnp.array(0, x.dtype), pad_config)
  N = x.shape[axis]
  V = jnp_fft.fft(_dct_interleave(x, axis), axis=axis)
  k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype),
                      [a for a in range(x.ndim) if a != axis])
  result = 2 * (V * _W4(N, k)).real
  if norm == 'ortho':
    result = _dct_ortho_norm(result, axis)
  return result
def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
  """2-D DCT-II over the two ``axes`` via a single complex FFT (Makhoul)."""
  axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes)
  N1, N2 = x.shape[axis1], x.shape[axis2]
  # Interleave along both axes, then take one multidimensional FFT.
  V = jnp_fft.fftn(_dct_interleave(_dct_interleave(x, axis1), axis2), axes=axes)
  k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype),
                       [a for a in range(x.ndim) if a != axis1])
  k2 = lax.expand_dims(jnp.arange(N2, dtype=V.dtype),
                       [a for a in range(x.ndim) if a != axis2])
  # Combine V with its reversal along axis2 to untangle the two 1-D DCTs.
  rolled = jnp.roll(jnp.flip(V, axis=axis2), shift=1, axis=axis2)
  out = _W4(N1, k1) * (_W4(N2, k2) * V + _W4(N2, -k2) * rolled)
  out = 2 * out.real
  if norm == 'ortho':
    return _dct_ortho_norm(_dct_ortho_norm(out, axis1), axis2)
  return out
def dctn(x: Array, type: int = 2,
         s: Sequence[int] | None=None,
         axes: Sequence[int] | None = None,
         norm: str | None = None) -> Array:
  """Computes the multidimensional discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.dctn`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    s: integer or sequence of integers. Specifies the shape of the result. If
      not specified, it will default to the shape of ``x`` along the specified
      ``axes``.
    axes: integer or sequence of integers. Specifies the axes along which the
      transform will be computed. If not given, the last ``len(s)`` axes are
      used, or all axes if ``s`` is also not specified.
    norm: string. The normalization mode: one of
      ``[None, "backward", "ortho"]``. The default is ``None``, which is
      equivalent to ``"backward"``.

  Returns:
    array containing the discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dct`: one-dimensional DCT
    - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:
    ``jax.scipy.fft.dctn`` computes the transform along both the axes by default
    when ``axes`` argument is ``None`` and ``s`` is also ``None``.

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x))
    [[ 12.01   6.2  -10.17]
     [  8.84   9.65  -3.54]
     [ 11.25  -1.54  -0.88]]

    When ``s=[2]``, the transform will be computed only along the last axis,
    with its dimension padded or truncated to size ``2``:

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2]))
    [[ 7.3  -0.57]
     [ 0.19 -0.36]
     [-0.   -1.4 ]]

    When ``s=[2]`` and ``axes=[0]``, the transform will be computed only along
    the specified axis, with its dimension padded or truncated to size ``2``:

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2], axes=[0]))
    [[ 3.09  4.4  -2.81]
     [ 2.41  2.62  0.76]]

    When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2, 4]))
    [[  9.36  11.23   2.12 -10.97]
     [ 11.57   5.86  -1.37  -1.58]]
  """
  x = ensure_arraylike("dctn", x)
  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.dctn: {norm=!r} is not implemented")
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    # The DCT of a complex array is the complex combination of the DCTs of
    # its real and imaginary parts.
    return lax.complex(
      dctn(x.real, type=type, s=s, norm=norm, axes=axes),
      dctn(x.imag, type=type, s=s, norm=norm, axes=axes),
    )
  if s is not None:
    # Accept either a bare integer or a sequence of integers for ``s``.
    try:
      s = list(s)
    except TypeError:
      assert not isinstance(s, Sequence)
      s = [operator.index(s)]
    if len(s) > x.ndim:
      raise ValueError(
        f"s must have at most x.ndim ({x.ndim}) elements, got {len(s)}"
      )
  if axes is None:
    # Default: the last len(s) axes when s is given, otherwise all axes.
    if s is not None:
      axes = tuple(range(x.ndim - len(s), x.ndim))
    else:
      axes = tuple(range(x.ndim))
  else:
    axes = canonicalize_axis_tuple(axes, x.ndim)
  if len(axes) == 1:
    # Single axis: defer to the 1D implementation, which also handles ``s``
    # via its ``n`` argument.
    return dct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)
  if s is not None:
    # Zero-pad (or truncate) each transformed axis to the requested size.
    ns = dict(zip(axes, s))
    pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
    x = lax.pad(x, jnp.array(0, x.dtype), pads)
  if len(axes) == 2:
    return _dct2(x, axes=axes, norm=norm)
  # compose high-D DCTs from 2D and 1D DCTs:
  for axes_block in [axes[i:i+2] for i in range(0, len(axes), 2)]:
    x = dctn(x, axes=axes_block, norm=norm)
  return x
def idct(x: Array, type: int = 2, n: int | None = None,
         axis: int = -1, norm: str | None = None) -> Array:
  """Computes the inverse discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.idct`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    n: integer, default = x.shape[axis]. The length of the transform.
      If larger than ``x.shape[axis]``, the input will be zero-padded, if
      smaller, the input will be truncated.
    axis: integer, default=-1. The axis along which the dct will be performed.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the inverse discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dct`: DCT
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:
    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x))
    [[ 0.78  0.41 -0.39]
     [-0.12  0.31 -0.23]
     [ 0.17 -0.3  -0.11]]

    When ``n`` smaller than ``x.shape[axis]``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x, n=2))
    [[ 1.12 -0.31]
     [ 0.04 -0.08]
     [ 0.05 -0.3 ]]

    When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x, n=2, axis=0))
    [[ 0.38  0.57 -0.45]
     [ 0.43  0.44  0.24]]

    When ``n`` larger than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x, n=4, axis=0))
    [[ 0.1   0.38 -0.16]
     [ 0.28  0.18 -0.26]
     [ 0.3   0.15 -0.08]
     [ 0.13  0.3   0.29]]

    ``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result
    of ``jax.scipy.fft.dct``

    >>> x_dct = jax.scipy.fft.dct(x)
    >>> jnp.allclose(x, jax.scipy.fft.idct(x_dct))
    Array(True, dtype=bool)
  """
  x = ensure_arraylike("idct", x)
  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.idct: {norm=!r} is not implemented")
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    # Invert the real and imaginary channels independently.
    return lax.complex(
      idct(x.real, type=type, n=n, norm=norm, axis=axis),
      idct(x.imag, type=type, n=n, norm=norm, axis=axis)
    )
  axis = canonicalize_axis(axis, x.ndim)
  if n is not None:
    # Zero-pad (or truncate, via a negative high-pad) along ``axis``.
    x = lax.pad(x, jnp.array(0, x.dtype),
                [(0, n - x.shape[axis] if a == axis else 0, 0)
                 for a in range(x.ndim)])
  N = x.shape[axis]
  x, = promote_dtypes_inexact(x)
  if norm is None or norm == 'backward':
    # Rescale backward-normalized coefficients into the 'ortho' basis so the
    # single inverse formula below serves both normalization modes.
    x = _dct_ortho_norm(x, axis)
  # Second ortho rescale: together with the ``2 * N`` factor below, this
  # undoes the scaling applied by the forward DCT-II.
  x = _dct_ortho_norm(x, axis)
  k = lax.expand_dims(jnp.arange(N, dtype=x.dtype), [a for a in range(x.ndim) if a != axis])
  # everything is complex from here...
  w4 = _W4(N,k)
  x = x.astype(w4.dtype)
  x = x / (_W4(N, k))
  x = x * 2 * N
  x = jnp_fft.ifft(x, axis=axis)
  # convert back to reals..
  out = _dct_deinterleave(x.real, axis)
  return out
def idctn(x: Array, type: int = 2,
          s: Sequence[int] | None=None,
          axes: Sequence[int] | None = None,
          norm: str | None = None) -> Array:
  """Computes the multidimensional inverse discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.idctn`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    s: integer or sequence of integers. Specifies the shape of the result. If
      not specified, it will default to the shape of ``x`` along the specified
      ``axes``.
    axes: integer or sequence of integers. Specifies the axes along which the
      transform will be computed. If not given, the last ``len(s)`` axes are
      used, or all axes if ``s`` is also not specified.
    norm: string. The normalization mode: one of
      ``[None, "backward", "ortho"]``. The default is ``None``, which is
      equivalent to ``"backward"``.

  Returns:
    array containing the inverse discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dct`: one-dimensional DCT
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT

  Examples:
    ``jax.scipy.fft.idctn`` computes the transform along both the axes by
    default when ``axes`` argument is ``None`` and ``s`` is also ``None``.

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x))
    [[ 0.12  0.11 -0.15]
     [ 0.07  0.17 -0.03]
     [ 0.19 -0.07 -0.02]]

    When ``s=[2]``, the transform will be computed only along the last axis,
    with its dimension padded or truncated to size ``2``:

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x, s=[2]))
    [[ 1.12 -0.31]
     [ 0.04 -0.08]
     [ 0.05 -0.3 ]]

    When ``s=[2]`` and ``axes=[0]``, the transform will be computed only along
    the specified axis, with its dimension padded or truncated to size ``2``:

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x, s=[2], axes=[0]))
    [[ 0.38  0.57 -0.45]
     [ 0.43  0.44  0.24]]

    When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x, s=[2, 4]))
    [[ 0.1   0.18  0.07 -0.16]
     [ 0.2   0.06 -0.03 -0.01]]

    ``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result
    of ``jax.scipy.fft.dctn``

    >>> x_dctn = jax.scipy.fft.dctn(x)
    >>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
    Array(True, dtype=bool)
  """
  x = ensure_arraylike("idctn", x)
  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.idctn: {norm=!r} is not implemented")
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    # Invert the real and imaginary channels independently.
    return lax.complex(
      idctn(x.real, type=type, s=s, norm=norm, axes=axes),
      idctn(x.imag, type=type, s=s, norm=norm, axes=axes)
    )
  if s is not None:
    # Accept either a bare integer or a sequence of integers for ``s``.
    try:
      s = list(s)
    except TypeError:
      assert not isinstance(s, Sequence)
      s = [operator.index(s)]
    if len(s) > x.ndim:
      raise ValueError(
        f"s must have at most x.ndim ({x.ndim}) elements, got {len(s)}"
      )
  if axes is None:
    # Default: the last len(s) axes when s is given, otherwise all axes.
    if s is not None:
      axes = tuple(range(x.ndim - len(s), x.ndim))
    else:
      axes = tuple(range(x.ndim))
  else:
    axes = canonicalize_axis_tuple(axes, x.ndim)
  if len(axes) == 1:
    # Single axis: defer to the 1D inverse, which handles ``s`` via ``n``.
    return idct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)
  if s is not None:
    # Zero-pad (or truncate) each transformed axis to the requested size.
    ns = dict(zip(axes, s))
    pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
    x = lax.pad(x, jnp.array(0, x.dtype), pads)
  # compose high-D DCTs from 1D DCTs:
  for axis in axes:
    x = idct(x, axis=axis, norm=norm)
  return x
def _dct_deinterleave(x: Array, axis: int) -> Array:
empty_slice = slice(None, None, None)
ix0 = tuple(
slice(None, math.ceil(x.shape[axis]/2), 1) if i == axis else empty_slice
for i in range(len(x.shape)))
ix1 = tuple(
slice(math.ceil(x.shape[axis]/2), None, 1) if i == axis else empty_slice
for i in range(len(x.shape)))
v0 = x[ix0]
v1 = lax.rev(x[ix1], (axis,))
out = jnp.zeros(x.shape, dtype=x.dtype)
evens = tuple(
slice(None, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
odds = tuple(
slice(1, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
out = out.at[evens].set(v0)
out = out.at[odds].set(v1)
return out
@@ -0,0 +1,67 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from functools import partial

from jax._src.api import jit
from jax._src.numpy import lax_numpy
from jax._src.typing import Array, ArrayLike
@jit(static_argnames=('axis',))
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
axis: int = -1) -> Array:
r"""
Integrate along the given axis using the composite trapezoidal rule.
JAX implementation of :func:`scipy.integrate.trapezoid`
The trapezoidal rule approximates the integral under a curve by summing the
areas of trapezoids formed between adjacent data points.
Args:
y: array of data to integrate.
x: optional array of sample points corresponding to the ``y`` values. If not
provided, ``x`` defaults to equally spaced with spacing given by ``dx``.
dx: The spacing between sample points when `x` is None (default: 1.0).
axis: The axis along which to integrate (default: -1)
Returns:
The definite integral approximated by the trapezoidal rule.
See also:
:func:`jax.numpy.trapezoid`: NumPy-style API for trapezoidal integration
Examples:
Integrate over a regular grid, with spacing 1.0:
>>> y = jnp.array([1, 2, 3, 2, 3, 2, 1])
>>> jax.scipy.integrate.trapezoid(y, dx=1.0)
Array(13., dtype=float32)
Integrate over an irregular grid:
>>> x = jnp.array([0, 2, 5, 7, 10, 15, 20])
>>> jax.scipy.integrate.trapezoid(y, x)
Array(43., dtype=float32)
Approximate :math:`\int_0^{2\pi} \sin^2(x)dx`, which equals :math:`\pi`:
>>> x = jnp.linspace(0, 2 * jnp.pi, 1000)
>>> y = jnp.sin(x) ** 2
>>> result = jax.scipy.integrate.trapezoid(y, x)
>>> jnp.allclose(result, jnp.pi)
Array(True, dtype=bool)
"""
return lax_numpy.trapezoid(y, x, dx, axis)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,180 @@
# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Sequence
import functools
import itertools
import operator
import numpy as np
from jax._src import api
from jax._src import dtypes
from jax._src import numpy as jnp
from jax._src import util
from jax._src.lax import lax
from jax._src.typing import ArrayLike, Array
from jax._src.util import safe_zip as zip
def _nonempty_prod(arrs: Sequence[Array]) -> Array:
return functools.reduce(operator.mul, arrs)
def _nonempty_sum(arrs: Sequence[Array]) -> Array:
return sum(arrs[1:], arrs[0])
def _mirror_index_fixer(index: Array, size: int) -> Array:
s = size - 1 # Half-wavelength of triangular wave
# Scaled, integer-valued version of the triangular wave |x - round(x)|
return jnp.abs((index + s) % (2 * s) - s)
def _reflect_index_fixer(index: Array, size: int) -> Array:
  """'reflect'-mode boundary: like mirror, but the edge samples themselves are
  duplicated. Implemented by mirroring on a doubled, shifted grid."""
  doubled = _mirror_index_fixer(2 * index + 1, 2 * size + 1)
  return jnp.floor_divide(doubled - 1, 2)
# Maps each supported boundary ``mode`` to a function folding an arbitrary
# integer index toward the valid range [0, size). 'constant' leaves indices
# untouched and relies on the validity mask in _map_coordinates to substitute
# ``cval`` for out-of-range points.
_INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = {
  'constant': lambda index, size: index,
  'nearest': lambda index, size: jnp.clip(index, 0, size - 1),
  'wrap': lambda index, size: index % size,
  'mirror': _mirror_index_fixer,
  'reflect': _reflect_index_fixer,
}
def _round_half_away_from_zero(a: Array) -> Array:
return a if dtypes.issubdtype(a.dtype, np.integer) else lax.round(a)
def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
  """Order-0 interpolation: a single nearest index carrying unit weight."""
  nearest = _round_half_away_from_zero(coordinate).astype(np.int32)
  unit_weight = coordinate.dtype.type(1)
  return [(nearest, unit_weight)]
def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
lower = jnp.floor(coordinate)
upper_weight = coordinate - lower
lower_weight = 1 - upper_weight
index = lower.astype(np.int32)
return [(index, lower_weight), (index + 1, upper_weight)]
@functools.partial(api.jit, static_argnums=(2, 3, 4))
def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
                     order: int, mode: str, cval: ArrayLike) -> Array:
  """Jitted core of :func:`map_coordinates`.

  Evaluates ``input`` at the (possibly fractional) ``coordinates`` using
  separable nearest-neighbor (order=0) or linear (order=1) interpolation,
  handling out-of-bounds indices according to ``mode``.
  """
  input_arr = jnp.asarray(input)
  coordinate_arrs = [jnp.asarray(c) for c in coordinates]
  # Cast cval to the input dtype so jnp.where below does not promote.
  cval = jnp.asarray(cval, input_arr.dtype)
  if len(coordinates) != input_arr.ndim:
    raise ValueError('coordinates must be a sequence of length input.ndim, but '
                     '{} != {}'.format(len(coordinates), input_arr.ndim))
  index_fixer = _INDEX_FIXERS.get(mode)
  if index_fixer is None:
    raise NotImplementedError(
      'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
      'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))
  if mode == 'constant':
    # Only 'constant' mode needs a validity mask: every other mode folds
    # out-of-range indices back into bounds via its index fixer.
    is_valid = lambda index, size: (0 <= index) & (index < size)
  else:
    is_valid = lambda index, size: True
  if order == 0:
    interp_fun = _nearest_indices_and_weights
  elif order == 1:
    interp_fun = _linear_indices_and_weights
  else:
    raise NotImplementedError(
      'jax.scipy.ndimage.map_coordinates currently requires order<=1')
  # For each dimension, gather (fixed_index, validity, weight) for every
  # interpolation node along that axis.
  valid_1d_interpolations = []
  for coordinate, size in zip(coordinate_arrs, input_arr.shape):
    interp_nodes = interp_fun(coordinate)
    valid_interp = []
    for index, weight in interp_nodes:
      fixed_index = index_fixer(index, size)
      valid = is_valid(index, size)
      valid_interp.append((fixed_index, valid, weight))
    valid_1d_interpolations.append(valid_interp)
  # The N-D interpolant is a weighted sum over the Cartesian product of the
  # per-dimension nodes (2**ndim terms for linear interpolation).
  outputs = []
  for items in itertools.product(*valid_1d_interpolations):
    indices, validities, weights = util.unzip3(items)
    if all(valid is True for valid in validities):
      # fast path: every node is statically in-bounds, no masking needed
      contribution = input_arr[indices]
    else:
      all_valid = functools.reduce(operator.and_, validities)
      contribution = jnp.where(all_valid, input_arr[indices], cval)
    outputs.append(_nonempty_prod(weights) * contribution)  # pyrefly: ignore[bad-argument-type]
  result = _nonempty_sum(outputs)
  if dtypes.issubdtype(input_arr.dtype, np.integer):
    # Round half away from zero for integer outputs, matching scipy.
    result = _round_half_away_from_zero(result)
  return result.astype(input_arr.dtype)
def map_coordinates(
    input: ArrayLike, coordinates: Sequence[ArrayLike], order: int,
    mode: str = 'constant', cval: ArrayLike = 0.0,
):
  """Evaluate an input array at fractional coordinates via interpolation.

  JAX implementation of :func:`scipy.ndimage.map_coordinates`.

  For every position described by ``coordinates``, the value of ``input`` is
  estimated by nearest-neighbor or linear interpolation and returned.

  Args:
    input: N-dimensional array from which values are interpolated.
    coordinates: length-N sequence of arrays giving, per dimension, the
      (possibly fractional) positions at which to evaluate.
    order: interpolation order; JAX supports:

      * 0: Nearest-neighbor
      * 1: Linear
    mode: how points outside the boundaries of the input are filled; one of
      ``('constant', 'nearest', 'mirror', 'wrap', 'reflect')``. Note that
      ``'wrap'`` here matches SciPy's ``'grid-wrap'`` and ``'constant'``
      matches SciPy's ``'grid-constant'``: these modes were originally buggy
      in SciPy (https://github.com/scipy/scipy/issues/2640); JAX fixed the
      behavior under the existing names, while SciPy later added fixed modes
      under new names for backwards-compatibility reasons. Default is
      'constant'.
    cval: fill value used for out-of-bounds points when ``mode='constant'``.
      Default is 0.0.

  Returns:
    The interpolated values at the specified coordinates.

  Examples:
    >>> input = jnp.arange(12.0).reshape(3, 4)
    >>> input
    Array([[ 0.,  1.,  2.,  3.],
           [ 4.,  5.,  6.,  7.],
           [ 8.,  9., 10., 11.]], dtype=float32)
    >>> coordinates = [jnp.array([0.5, 1.5]),
    ...                jnp.array([1.5, 2.5])]
    >>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
    Array([3.5, 8.5], dtype=float32)

  Note:
    Interpolation near boundaries differs from the scipy function, because JAX
    fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097.
    This function interprets the ``mode`` argument as documented by SciPy, but
    not as implemented by SciPy.
  """
  return _map_coordinates(input, coordinates, order, mode, cval)
@@ -0,0 +1,13 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,244 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import NamedTuple
import numpy as np
from jax._src import api
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy import linalg as jnp_linalg
from jax._src.scipy.optimize.line_search import line_search
from jax._src.typing import Array
# Dot product requesting the highest-available accumulation precision from
# the backend (lax.Precision.HIGHEST).
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
class LBFGSResults(NamedTuple):
  """Results from L-BFGS optimization

  Parameters:
    converged: True if minimization converged
    failed: True if non-zero status and not converged
    k: integer number of iterations of the main loop (optimisation steps)
    nfev: integer total number of objective evaluations performed.
    ngev: integer total number of jacobian evaluations
    x_k: array containing the last argument value found during the search. If
      the search converged, then this value is the argmin of the objective
      function.
    f_k: array containing the value of the objective function at `x_k`. If the
      search converged, then this is the (local) minimum of the objective
      function.
    g_k: array containing the gradient of the objective function at `x_k`. If
      the search converged the l2-norm of this tensor should be below the
      tolerance.
    s_history: (maxcor, d) array holding the most recent position differences
      used for the limited-memory metric approximation.
    y_history: (maxcor, d) array holding the most recent gradient differences
      paired with `s_history`.
    rho_history: (maxcor,) array of scalars associated with each (s, y) pair
      (in standard L-BFGS, `1 / <y, s>` — see the update step).
    gamma: scalar scaling applied to the initial inverse-Hessian estimate.
    status: integer describing the status:
      0 = nominal  ,  1 = max iters reached  ,  2 = max fun evals reached
      3 = max grad evals reached  ,  4 = insufficient progress (ftol)
      5 = line search failed
    ls_status: integer describing the end status of the last line search
  """
  converged: Array
  failed: Array
  k: int | Array
  nfev: int | Array
  ngev: int | Array
  x_k: Array
  f_k: Array
  g_k: Array
  s_history: Array
  y_history: Array
  rho_history: Array
  gamma: float | Array
  status: int | Array
  ls_status: int | Array
def _minimize_lbfgs(
    fun: Callable,
    x0: Array,
    maxiter: float | None = None,
    norm=np.inf,
    maxcor: int = 10,
    ftol: float = 2.220446049250313e-09,
    gtol: float = 1e-05,
    maxfun: float | None = None,
    maxgrad: float | None = None,
    maxls: int = 20,
):
  """
  Minimize a function using L-BFGS

  Implements the L-BFGS algorithm from
    Algorithm 7.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 176-185
  And generalizes to complex variables from
    Sorber, L., Barel, M.V. and Lathauwer, L.D., 2012.
    "Unconstrained optimization of real functions in complex variables"
    SIAM Journal on Optimization, 22(3), pp.879-898.

  Args:
    fun: function of the form f(x) where x is a flat ndarray and returns a real scalar.
      The function should be composed of operations with vjp defined.
    x0: initial guess
    maxiter: maximum number of iterations
    norm: order of norm for convergence check. Default inf.
    maxcor: maximum number of metric corrections ("history size")
    ftol: terminates the minimization when `(f_k - f_{k+1}) < ftol`
    gtol: terminates the minimization when `|g_k|_norm < gtol`
    maxfun: maximum number of function evaluations
    maxgrad: maximum number of gradient evaluations
    maxls: maximum number of line search steps (per iteration)

  Returns:
    Optimization results.
  """
  d = len(x0)
  dtype = dtypes.dtype(x0)

  # ensure there is at least one termination condition
  if (maxiter is None) and (maxfun is None) and (maxgrad is None):
    maxiter = d * 200

  # set others to inf, such that >= is supported
  if maxiter is None:
    maxiter = np.inf
  if maxfun is None:
    maxfun = np.inf
  if maxgrad is None:
    maxgrad = np.inf

  # initial evaluation
  f_0, g_0 = api.value_and_grad(fun)(x0)
  state_initial = LBFGSResults(
      converged=jnp.array(False, dtype=bool),
      failed=jnp.array(False, dtype=bool),
      k=0,
      nfev=1,
      ngev=1,
      x_k=x0,
      f_k=f_0,
      g_k=g_0,
      # history buffers start zeroed; _two_loop_recursion only reads the
      # `min(k, maxcor)` most recent entries, so the zeros are never used
      s_history=jnp.zeros((maxcor, d), dtype=dtype),
      y_history=jnp.zeros((maxcor, d), dtype=dtype),
      rho_history=jnp.zeros((maxcor,), dtype=dtype),
      gamma=1.,
      status=0,
      ls_status=0,
  )

  def cond_fun(state: LBFGSResults):
    # iterate until we either converge or hit a failure status
    return (~state.converged) & (~state.failed)

  def body_fun(state: LBFGSResults):
    # find search direction
    p_k = _two_loop_recursion(state)

    # line search
    ls_results = line_search(
        f=fun,
        xk=state.x_k,
        pk=p_k,
        old_fval=state.f_k,
        gfk=state.g_k,
        maxiter=maxls,
    )

    # evaluate at next iterate
    s_k = jnp.asarray(ls_results.a_k).astype(p_k.dtype) * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = ls_results.f_k
    g_kp1 = ls_results.g_k
    y_k = g_kp1 - state.g_k
    rho_k_inv = jnp.real(_dot(y_k, s_k))
    rho_k = jnp.reciprocal(rho_k_inv).astype(y_k.dtype)
    # re<y, s> / re<y, y>: scaling for the initial inverse-Hessian estimate
    gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))

    # replacements for next iteration; later writes take precedence, so a
    # line-search failure (5) dominates all budget-exhaustion statuses
    status = jnp.array(0)
    status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
    status = jnp.where(state.ngev >= maxgrad, 3, status)
    status = jnp.where(state.nfev >= maxfun, 2, status)
    status = jnp.where(state.k >= maxiter, 1, status)
    status = jnp.where(ls_results.failed, 5, status)

    converged = jnp_linalg.norm(g_kp1, ord=norm) < gtol

    # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
    state = state._replace(
        converged=converged,
        failed=(status > 0) & (~converged),
        k=state.k + 1,
        nfev=state.nfev + ls_results.nfev,
        ngev=state.ngev + ls_results.ngev,
        x_k=x_kp1.astype(state.x_k.dtype),
        f_k=f_kp1.astype(state.f_k.dtype),
        g_k=g_kp1.astype(state.g_k.dtype),
        s_history=_update_history_vectors(history=state.s_history, new=s_k),
        y_history=_update_history_vectors(history=state.y_history, new=y_k),
        rho_history=_update_history_scalars(history=state.rho_history, new=rho_k),
        gamma=gamma.astype(state.g_k.dtype),
        status=jnp.where(converged, 0, status),
        ls_status=ls_results.status,
    )

    return state

  return lax.while_loop(cond_fun, body_fun, state_initial)
def _two_loop_recursion(state: LBFGSResults):
  """Compute the L-BFGS search direction via the two-loop recursion.

  Applies the implicit inverse-Hessian approximation stored in the
  (s, y, rho) history to the (conjugated, negated) gradient.
  """
  dtype = state.rho_history.dtype
  his_size = len(state.rho_history)
  # number of valid history entries: the buffers fill up over the first
  # `his_size` iterations, so only the newest `min(k, his_size)` are real
  curr_size = jnp.where(state.k < his_size, state.k, his_size)
  q = -jnp.conj(state.g_k)
  a_his = jnp.zeros_like(state.rho_history)

  def body_fun1(j, carry):
    # first loop: walk the history from newest to oldest entry
    i = his_size - 1 - j
    _q, _a_his = carry
    a_i = state.rho_history[i] * _dot(jnp.conj(state.s_history[i]), _q).real.astype(dtype)
    _a_his = _a_his.at[i].set(a_i)
    _q = _q - a_i * jnp.conj(state.y_history[i])
    return _q, _a_his

  q, a_his = lax.fori_loop(0, curr_size, body_fun1, (q, a_his))
  # scale by the initial inverse-Hessian estimate gamma * I
  q = state.gamma * q

  def body_fun2(j, _q):
    # second loop: walk the history from oldest to newest entry
    i = his_size - curr_size + j
    b_i = state.rho_history[i] * _dot(state.y_history[i], _q).real.astype(dtype)
    _q = _q + (a_his[i] - b_i) * state.s_history[i]
    return _q

  q = lax.fori_loop(0, curr_size, body_fun2, q)
  return q
def _update_history_vectors(history, new):
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
return jnp.roll(history, -1, axis=0).at[-1, :].set(new)
def _update_history_scalars(history, new):
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
return jnp.roll(history, -1, axis=0).at[-1].set(new)
@@ -0,0 +1,188 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import NamedTuple
import numpy as np
from jax._src import api
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy import einsum as jnp_einsum
from jax._src.numpy import linalg as jnp_linalg
from jax._src.scipy.optimize.line_search import line_search
from jax._src.typing import Array
class _BFGSResults(NamedTuple):
  """Results from BFGS optimization.

  Parameters:
    converged: True if minimization converged.
    failed: True if line search failed.
    k: integer the number of iterations of the BFGS update.
    nfev: integer total number of objective evaluations performed.
    ngev: integer total number of jacobian evaluations
    nhev: integer total number of hessian evaluations
    x_k: array containing the last argument value found during the search. If
      the search converged, then this value is the argmin of the objective
      function.
    f_k: array containing the value of the objective function at `x_k`. If the
      search converged, then this is the (local) minimum of the objective
      function.
    g_k: array containing the gradient of the objective function at `x_k`. If
      the search converged the l2-norm of this tensor should be below the
      tolerance.
    H_k: array containing the inverse of the estimated Hessian.
    old_old_fval: objective value at the iterate before `x_k`; fed to the
      line search to pick its initial trial step length.
    status: int describing end state.
    line_search_status: int describing line search end state (only means
      something if line search fails).
  """
  converged: bool | Array
  failed: bool | Array
  k: int | Array
  nfev: int | Array
  ngev: int | Array
  nhev: int | Array
  x_k: Array
  f_k: Array
  g_k: Array
  H_k: Array
  old_old_fval: Array
  status: int | Array
  line_search_status: int | Array
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
_einsum = partial(jnp_einsum.einsum, precision=lax.Precision.HIGHEST)
def minimize_bfgs(
    fun: Callable,
    x0: Array,
    maxiter: int | None = None,
    norm=np.inf,
    gtol: float = 1e-5,
    line_search_maxiter: int = 10,
) -> _BFGSResults:
  """Minimize a function using BFGS.

  Implements the BFGS algorithm from
    Algorithm 6.1 from Wright and Nocedal, 'Numerical Optimization', 1999, pg.
    136-143.

  Args:
    fun: function of the form f(x) where x is a flat ndarray and returns a real
      scalar. The function should be composed of operations with vjp defined.
    x0: initial guess.
    maxiter: maximum number of iterations.
    norm: order of norm for convergence check. Default inf.
    gtol: terminates minimization when |grad|_norm < g_tol.
    line_search_maxiter: maximum number of linesearch iterations.

  Returns:
    Optimization result.
  """
  if maxiter is None:
    maxiter = np.size(x0) * 200

  d = x0.shape[0]
  # the inverse-Hessian estimate starts as the identity
  initial_H = jnp.eye(d, dtype=x0.dtype)
  f_0, g_0 = api.value_and_grad(fun)(x0)
  state = _BFGSResults(
      converged=jnp_linalg.norm(g_0, ord=norm) < gtol,
      failed=False,
      k=0,
      nfev=1,
      ngev=1,
      nhev=0,
      x_k=x0,
      f_k=f_0,
      g_k=g_0,
      H_k=initial_H,
      old_old_fval=f_0 + jnp_linalg.norm(g_0) / 2,
      status=0,
      line_search_status=0,
  )

  def cond_fun(state):
    return (jnp.logical_not(state.converged)
            & jnp.logical_not(state.failed)
            & (state.k < maxiter))

  def body_fun(state):
    # quasi-Newton search direction: -H g
    p_k = -_dot(state.H_k, state.g_k)
    line_search_results = line_search(
        fun,
        state.x_k,
        p_k,
        old_fval=state.f_k,
        old_old_fval=state.old_old_fval,
        gfk=state.g_k,
        maxiter=line_search_maxiter,
    )
    state = state._replace(
        nfev=state.nfev + line_search_results.nfev,
        ngev=state.ngev + line_search_results.ngev,
        failed=line_search_results.failed,
        line_search_status=line_search_results.status,
    )
    s_k = line_search_results.a_k * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = line_search_results.f_k
    g_kp1 = line_search_results.g_k
    y_k = g_kp1 - state.g_k
    rho_k = jnp.reciprocal(_dot(y_k, s_k))

    # BFGS inverse-Hessian update: H <- (I - rho s y^T) H (I - rho y s^T)
    # + rho s s^T, with the einsum computing w H w^T
    sy_k = s_k[:, np.newaxis] * y_k[np.newaxis, :]
    w = jnp.eye(d, dtype=rho_k.dtype) - rho_k * sy_k
    H_kp1 = (_einsum('ij,jk,lk', w, state.H_k, w)
             + rho_k * s_k[:, np.newaxis] * s_k[np.newaxis, :])
    # keep the old estimate when rho is inf/nan (zero curvature <y, s>)
    H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k)
    converged = jnp_linalg.norm(g_kp1, ord=norm) < gtol

    state = state._replace(
        converged=converged,
        k=state.k + 1,
        x_k=x_kp1,
        f_k=f_kp1,
        g_k=g_kp1,
        H_k=H_kp1,
        old_old_fval=state.f_k,
    )
    return state

  state = lax.while_loop(cond_fun, body_fun, state)
  status = jnp.where(
      state.converged,
      0,  # converged
      jnp.where(
          state.k == maxiter,
          1,  # max iters reached
          jnp.where(
              state.failed,
              2 + state.line_search_status,  # ls failed (+ reason)
              -1,  # undefined
          )
      )
  )
  state = state._replace(status=status)
  return state
@@ -0,0 +1,442 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import NamedTuple, Callable
from functools import partial
from jax._src import api
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src.typing import Array
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
def _cubicmin(a, fa, fpa, b, fb, c, fc):
  """Minimizer of the cubic through (a, fa) with slope fpa, (b, fb), (c, fc).

  May return nan when the interpolant has no interior minimum; callers are
  expected to bounds-check the result.
  """
  dtype = dtypes.result_type(a, fa, fpa, b, fb, c, fc)
  slope = fpa
  db, dc = b - a, c - a
  denom = (db * dc) ** 2 * (db - dc)
  coeff_mat = jnp.array([[dc ** 2, -db ** 2],
                         [-dc ** 3, db ** 3]], dtype=dtype)
  rhs = jnp.array([fb - fa - slope * db, fc - fa - slope * dc], dtype=dtype)
  A, B = jnp.dot(coeff_mat, rhs, precision=lax.Precision.HIGHEST) / denom
  radical = B * B - 3. * A * slope
  return a + (-B + jnp.sqrt(radical)) / (3. * A)
def _quadmin(a, fa, fpa, b, fb):
D = fa
C = fpa
db = b - a
B = (fb - D - C * db) / (db ** 2)
xmin = a - C / (2. * B)
return xmin
def _binary_replace(replace_bit, original_dict, new_dict, keys=None):
if keys is None:
keys = new_dict.keys()
return {key: jnp.where(replace_bit, new_dict[key], original_dict[key])
for key in keys}
class _ZoomState(NamedTuple):
  """State carried through the `_zoom` while-loop.

  `a_lo`/`a_hi` bracket the step length being refined (with the associated
  function values `phi_*` and derivatives `dphi_*`), `a_rec`/`phi_rec` hold
  the most recently discarded point (used for cubic interpolation), and the
  `*_star` fields hold the accepted step once `done` is set.
  """
  done: bool | Array
  failed: bool | Array
  j: int | Array  # zoom iteration counter
  a_lo: float | Array
  phi_lo: float | Array
  dphi_lo: float | Array
  a_hi: float | Array
  phi_hi: float | Array
  dphi_hi: float | Array
  a_rec: float | Array
  phi_rec: float | Array
  a_star: float | Array
  phi_star: float | Array
  dphi_star: float | Array
  g_star: float | Array
  nfev: int | Array  # function evaluations performed inside zoom
  ngev: int | Array  # gradient evaluations performed inside zoom
ConditionFn = Callable[..., Array]
def _zoom(restricted_func_and_grad, wolfe_one: ConditionFn, wolfe_two: ConditionFn, a_lo, phi_lo,
          dphi_lo, a_hi, phi_hi, dphi_hi, g_0, pass_through):
  """
  Implementation of zoom. Algorithm 3.6 from Wright and Nocedal, 'Numerical
  Optimization', 1999, pg. 59-61. Tries cubic, quadratic, and bisection methods
  of zooming.

  When `pass_through` is True the while-loop body never executes and the
  initial state is returned unchanged, which lets `line_search` call `_zoom`
  unconditionally and mask the result.
  """
  state = _ZoomState(
      done=False,
      failed=False,
      j=0,
      a_lo=a_lo,
      phi_lo=phi_lo,
      dphi_lo=dphi_lo,
      a_hi=a_hi,
      phi_hi=phi_hi,
      dphi_hi=dphi_hi,
      a_rec=(a_lo + a_hi) / 2.,
      phi_rec=(phi_lo + phi_hi) / 2.,
      a_star=1.,
      phi_star=phi_lo,
      dphi_star=dphi_lo,
      g_star=g_0,
      nfev=0,
      ngev=0,
  )
  # margins (as fractions of the bracket width) inside which a cubic or
  # quadratic trial point is considered acceptable
  delta1 = 0.2
  delta2 = 0.1

  def body(state):
    # Body of zoom algorithm. We use boolean arithmetic to avoid using jax.cond
    # so that it works on GPU/TPU.
    dalpha = (state.a_hi - state.a_lo)
    a = jnp.minimum(state.a_hi, state.a_lo)
    b = jnp.maximum(state.a_hi, state.a_lo)
    cchk = delta1 * dalpha
    qchk = delta2 * dalpha

    # This will cause the line search to stop, and since the Wolfe conditions
    # are not satisfied the minimization should stop too.
    threshold = jnp.where((dtypes.finfo(dalpha.dtype).bits < 64), 1e-5, 1e-10)
    state = state._replace(failed=state.failed | (dalpha <= threshold))

    # Cubmin is sometimes nan, though in this case the bounds check will fail.
    a_j_cubic = _cubicmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi,
                          state.phi_hi, state.a_rec, state.phi_rec)
    use_cubic = (state.j > 0) & (a_j_cubic > a + cchk) & (a_j_cubic < b - cchk)
    a_j_quad = _quadmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi, state.phi_hi)
    use_quad = (~use_cubic) & (a_j_quad > a + qchk) & (a_j_quad < b - qchk)
    a_j_bisection = (state.a_lo + state.a_hi) / 2.
    use_bisection = (~use_cubic) & (~use_quad)

    a_j = jnp.where(use_cubic, a_j_cubic, state.a_rec)
    a_j = jnp.where(use_quad, a_j_quad, a_j)
    a_j = jnp.where(use_bisection, a_j_bisection, a_j)

    # TODO(jakevdp): should we use some sort of fixed-point approach here instead?
    phi_j, dphi_j, g_j = restricted_func_and_grad(a_j)
    phi_j = phi_j.astype(state.phi_lo.dtype)
    dphi_j = dphi_j.astype(state.dphi_lo.dtype)
    g_j = g_j.astype(state.g_star.dtype)
    state = state._replace(nfev=state.nfev + 1,
                           ngev=state.ngev + 1)

    # a_j violates Wolfe-1 or fails to improve on phi_lo: shrink from above
    hi_to_j = wolfe_one(a_j, phi_j) | (phi_j >= state.phi_lo)
    # strong Wolfe conditions hold at a_j: accept it
    star_to_j = wolfe_two(dphi_j) & (~hi_to_j)
    # derivative sign indicates the minimum lies on the other side of a_lo
    hi_to_lo = (dphi_j * (state.a_hi - state.a_lo) >= 0.) & (~hi_to_j) & (~star_to_j)
    # otherwise a_j becomes the new lower bracket point
    lo_to_j = (~hi_to_j) & (~star_to_j)

    state = state._replace(
        **_binary_replace(
            hi_to_j,
            state._asdict(),
            dict(
                a_hi=a_j,
                phi_hi=phi_j,
                dphi_hi=dphi_j,
                a_rec=state.a_hi,
                phi_rec=state.phi_hi,
            ),
        ),
    )

    # for termination
    state = state._replace(
        done=star_to_j | state.done,
        **_binary_replace(
            star_to_j,
            state._asdict(),
            dict(
                a_star=a_j,
                phi_star=phi_j,
                dphi_star=dphi_j,
                g_star=g_j,
            )
        ),
    )
    state = state._replace(
        **_binary_replace(
            hi_to_lo,
            state._asdict(),
            dict(
                a_hi=state.a_lo,
                phi_hi=state.phi_lo,
                dphi_hi=state.dphi_lo,
                a_rec=state.a_hi,
                phi_rec=state.phi_hi,
            ),
        ),
    )
    state = state._replace(
        **_binary_replace(
            lo_to_j & ~hi_to_lo,
            state._asdict(),
            dict(
                a_rec=state.a_lo,
                phi_rec=state.phi_lo,
            ),
        ),
    )
    state = state._replace(
        **_binary_replace(
            lo_to_j,
            state._asdict(),
            dict(
                a_lo=a_j,
                phi_lo=phi_j,
                dphi_lo=dphi_j,
            ),
        ),
    )
    state = state._replace(j=state.j + 1)
    # Choose higher cutoff for maxiter than Scipy as Jax takes longer to find
    # the same value - possibly floating point issues?
    state = state._replace(failed= state.failed | (state.j >= 30))
    return state

  state = lax.while_loop(lambda state: (~state.done) & (~pass_through) & (~state.failed),
                         body,
                         state)

  return state
class _LineSearchState(NamedTuple):
  """State carried through the `line_search` while-loop.

  The `*_i1` fields hold the previous trial step (step length, function
  value, and directional derivative); the `*_star` fields accumulate the
  accepted step once `done` is set.
  """
  done: Array
  failed: Array
  i: int | Array  # trial-step counter (starts at 1)
  a_i1: float | Array
  phi_i1: float | Array
  dphi_i1: float | Array
  nfev: int | Array
  ngev: int | Array
  a_star: float | Array
  phi_star: Array
  dphi_star: Array
  g_star: Array
class _LineSearchResults(NamedTuple):
  """Results of line search.

  Parameters:
    failed: True if the strong Wolfe criteria were not satisfied
    nit: integer number of iterations
    nfev: integer number of functions evaluations
    ngev: integer number of gradients evaluations
    k: integer number of iterations
    a_k: step size
    f_k: final function value
    g_k: final gradient value
    status: integer end status
  """
  failed: bool | Array
  nit: int | Array
  nfev: int | Array
  ngev: int | Array
  k: int | Array
  a_k: int | Array
  f_k: Array
  g_k: Array
  status: bool | Array
def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
                c2=0.9, maxiter=20):
  """Inexact line search that satisfies strong Wolfe conditions.

  Algorithm 3.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61

  Args:
    f: function of the form f(x) where x is a flat ndarray and returns a real
      scalar. The function should be composed of operations with vjp defined.
    xk: starting point of the search.
    pk: direction to search in. Assumes the direction is a descent direction.
    old_fval, gfk: initial value of value_and_gradient as position.
    old_old_fval: function value at the previous point; when supplied it is
      used to choose the initial trial step length (as in scipy).
    maxiter: maximum number of iterations to search
    c1, c2: Wolfe criteria constant, see ref.

  Returns: LineSearchResults
  """
  xk, pk = promote_dtypes_inexact(xk, pk)

  def restricted_func_and_grad(t):
    # phi(t) = f(xk + t pk) and its derivative along pk
    t = jnp.array(t, dtype=pk.dtype)
    phi, g = api.value_and_grad(f)(xk + t * pk)
    dphi = jnp.real(_dot(g, pk))
    return phi, dphi, g

  if old_fval is None or gfk is None:
    phi_0, dphi_0, gfk = restricted_func_and_grad(0)
  else:
    phi_0 = old_fval
    dphi_0 = jnp.real(_dot(gfk, pk))
  if old_old_fval is not None:
    candidate_start_value = 1.01 * 2 * (phi_0 - old_old_fval) / dphi_0
    start_value = jnp.where(candidate_start_value > 1, 1.0, candidate_start_value)
  else:
    start_value = 1

  def wolfe_one(a_i, phi_i) -> Array:
    # actually negation of W1
    return phi_i > phi_0 + c1 * a_i * dphi_0

  def wolfe_two(dphi_i) -> Array:
    # W2: curvature condition
    return jnp.abs(dphi_i) <= -c2 * dphi_0

  state = _LineSearchState(
      done=jnp.array(False, dtype=bool),
      failed=jnp.array(False, dtype=bool),
      # algorithm begins at 1 as per Wright and Nocedal, however Scipy has a
      # bug and starts at 0. See https://github.com/scipy/scipy/issues/12157
      i=1,
      a_i1=0.,
      phi_i1=phi_0,
      dphi_i1=dphi_0,
      nfev=1 if (old_fval is None or gfk is None) else 0,
      ngev=1 if (old_fval is None or gfk is None) else 0,
      a_star=0.,
      phi_star=phi_0,
      dphi_star=dphi_0,
      g_star=gfk,
  )

  def body(state) -> _LineSearchState:
    # no amax in this version, we just double as in scipy.
    # unlike original algorithm we do our next choice at the start of this loop
    a_i = jnp.where(state.i == 1, start_value, state.a_i1 * 2.)

    phi_i, dphi_i, g_i = restricted_func_and_grad(a_i)
    state = state._replace(nfev=state.nfev + 1,
                           ngev=state.ngev + 1)

    star_to_zoom1 = wolfe_one(a_i, phi_i) | ((phi_i >= state.phi_i1) & (state.i > 1))
    star_to_i = wolfe_two(dphi_i) & (~star_to_zoom1)
    star_to_zoom2 = (dphi_i >= 0.) & (~star_to_zoom1) & (~star_to_i)

    # both zooms are always evaluated; `pass_through` masks the inactive one
    zoom1 = _zoom(restricted_func_and_grad,
                  wolfe_one,
                  wolfe_two,
                  state.a_i1,
                  state.phi_i1,
                  state.dphi_i1,
                  a_i,
                  phi_i,
                  dphi_i,
                  gfk,
                  ~star_to_zoom1)

    state = state._replace(nfev=state.nfev + zoom1.nfev,
                           ngev=state.ngev + zoom1.ngev)

    zoom2 = _zoom(restricted_func_and_grad,
                  wolfe_one,
                  wolfe_two,
                  a_i,
                  phi_i,
                  dphi_i,
                  state.a_i1,
                  state.phi_i1,
                  state.dphi_i1,
                  gfk,
                  ~star_to_zoom2)

    state = state._replace(nfev=state.nfev + zoom2.nfev,
                           ngev=state.ngev + zoom2.ngev)

    state = state._replace(
        done=star_to_zoom1 | state.done,
        failed=(star_to_zoom1 & zoom1.failed) | state.failed,
        **_binary_replace(
            star_to_zoom1,
            state._asdict(),
            zoom1._asdict(),
            keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
        ),
    )
    state = state._replace(
        done=star_to_i | state.done,
        **_binary_replace(
            star_to_i,
            state._asdict(),
            dict(
                a_star=a_i,
                phi_star=phi_i,
                dphi_star=dphi_i,
                g_star=g_i,
            ),
        ),
    )
    state = state._replace(
        done=star_to_zoom2 | state.done,
        failed=(star_to_zoom2 & zoom2.failed) | state.failed,
        **_binary_replace(
            star_to_zoom2,
            state._asdict(),
            zoom2._asdict(),
            keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
        ),
    )
    state = state._replace(i=state.i + 1, a_i1=a_i, phi_i1=phi_i, dphi_i1=dphi_i)
    return state

  state: _LineSearchState = lax.while_loop(
      lambda state: (~state.done) & (state.i <= maxiter) & (~state.failed),
      body,
      state)

  status = jnp.where(
      state.failed,
      jnp.array(1),  # zoom failed
      jnp.where(
          state.i > maxiter,
          jnp.array(3),  # maxiter reached
          jnp.array(0),  # passed (should be)
      ),
  )
  # Step sizes which are too small causes the optimizer to get stuck with a
  # direction of zero in <64 bit mode - avoid with a floor on minimum step size.
  alpha_k = jnp.asarray(state.a_star)
  alpha_k = jnp.where((dtypes.finfo(alpha_k.dtype).bits != 64)
                      & (jnp.abs(alpha_k) < 1e-8),
                      jnp.sign(alpha_k) * 1e-8,
                      alpha_k)
  results = _LineSearchResults(
      failed=state.failed | (~state.done),
      nit=state.i - 1,  # because iterations started at 1
      nfev=state.nfev,
      ngev=state.ngev,
      k=state.i,
      a_k=alpha_k,
      f_k=state.phi_star,
      g_k=state.g_star,
      status=status,
  )
  return results
@@ -0,0 +1,133 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable, Mapping
from typing import Any, NamedTuple
from jax._src import numpy as jnp
from jax._src.scipy.optimize.bfgs import minimize_bfgs
from jax._src.scipy.optimize._lbfgs import _minimize_lbfgs
from jax._src.typing import Array
class OptimizeResults(NamedTuple):
  """Object holding optimization results.

  Parameters:
    x: final solution.
    success: ``True`` if optimization succeeded.
    status: integer solver specific return code. 0 means converged (nominal),
      1=max BFGS iters reached, 3=zoom failed, 4=saddle point reached,
      5=max line search iters reached, -1=undefined
    fun: final function value.
    jac: final jacobian array.
    hess_inv: final inverse Hessian estimate, or ``None`` for solvers that do
      not maintain one (e.g. the experimental L-BFGS method).
    nfev: integer number of function calls used.
    njev: integer number of gradient evaluations.
    nit: integer number of iterations of the optimization algorithm.
  """
  x: Array
  success: bool | Array
  status: int | Array
  fun: Array
  jac: Array
  hess_inv: Array | None
  nfev: int | Array
  njev: int | Array
  nit: int | Array
def minimize(
    fun: Callable,
    x0: Array,
    args: tuple = (),
    *,
    method: str,
    tol: float | None = None,
    options: Mapping[str, Any] | None = None,
) -> OptimizeResults:
  """Minimization of scalar function of one or more variables.

  This API for this function matches SciPy with some minor deviations:

  - Gradients of ``fun`` are calculated automatically using JAX's autodiff
    support when required.
  - The ``method`` argument is required. You must specify a solver.
  - Various optional arguments in the SciPy interface have not yet been
    implemented.
  - Optimization results may differ from SciPy due to differences in the line
    search implementation.

  ``minimize`` supports :func:`~jax.jit` compilation. It does not yet support
  differentiation or arguments in the form of multi-dimensional arrays, but
  support for both is planned.

  Args:
    fun: the objective function to be minimized, ``fun(x, *args) -> float``,
      where ``x`` is a 1-D array with shape ``(n,)`` and ``args`` is a tuple
      of the fixed parameters needed to completely specify the function.
      ``fun`` must support differentiation.
    x0: initial guess. Array of real elements of size ``(n,)``, where ``n`` is
      the number of independent variables.
    args: extra arguments passed to the objective function.
    method: solver type. Currently only ``"BFGS"`` is supported.
    tol: tolerance for termination. For detailed control, use solver-specific
      options.
    options: a dictionary of solver options. All methods accept the following
      generic options:

      - maxiter (int): Maximum number of iterations to perform. Depending on the
        method each iteration may use several function evaluations.

  Returns:
    An :class:`OptimizeResults` object.
  """
  if options is None:
    options = {}
  if not isinstance(args, tuple):
    msg = "args argument to jax.scipy.optimize.minimize must be a tuple, got {}"
    raise TypeError(msg.format(args))

  # close over the fixed extra arguments so solvers see a unary objective
  fun_with_args = lambda x: fun(x, *args)

  if method.lower() == 'bfgs':
    results = minimize_bfgs(fun_with_args, x0, **options)
    success = results.converged & jnp.logical_not(results.failed)
    return OptimizeResults(x=results.x_k,
                           success=success,
                           status=results.status,
                           fun=results.f_k,
                           jac=results.g_k,
                           hess_inv=results.H_k,
                           nfev=results.nfev,
                           njev=results.ngev,
                           nit=results.k)

  if method.lower() == 'l-bfgs-experimental-do-not-rely-on-this':
    results = _minimize_lbfgs(fun_with_args, x0, **options)
    success = results.converged & jnp.logical_not(results.failed)
    # L-BFGS never forms an explicit inverse Hessian, hence hess_inv=None
    return OptimizeResults(x=results.x_k,
                           success=success,
                           status=results.status,
                           fun=results.f_k,
                           jac=results.g_k,
                           hess_inv=None,
                           nfev=results.nfev,
                           njev=results.ngev,
                           nit=results.k)

  raise ValueError(f"Method {method} not recognized")
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,13 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,761 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import operator
import numpy as np
from jax._src import api
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax import lax as lax_internal
from jax._src.numpy import einsum as jnp_einsum
from jax._src.scipy import linalg as jsp_linalg
from jax._src.tree_util import (tree_leaves, tree_map, tree_structure,
tree_reduce, Partial)
from jax._src.typing import Array, ArrayLike
from jax._src.util import safe_map as map
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
_vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
_einsum = partial(jnp_einsum.einsum, precision=lax.Precision.HIGHEST)
# aliases for working with pytrees
def _vdot_real_part(x, y):
  """Vector dot-product guaranteed to have a real valued result despite
  possibly complex input. Thus neglects the real-imaginary cross-terms.
  The result is a real float.
  """
  # all our uses of vdot() in CG are for computing an operator of the form
  #    z^H M z
  # where M is positive definite and Hermitian, so the result is real valued:
  # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
  out = _vdot(x.real, y.real)
  if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
    out = out + _vdot(x.imag, y.imag)
  return out
def _vdot_real_tree(x, y):
  # Real-valued inner product summed over all matching pytree leaves.
  return sum(tree_leaves(tree_map(_vdot_real_part, x, y)))

def _vdot_tree(x, y) -> ArrayLike:
  # Full (possibly complex) inner product summed over all matching leaves.
  return sum(tree_leaves(tree_map(partial(
    jnp.vdot, precision=lax.Precision.HIGHEST), x, y)))

def _norm(x):
  # Euclidean norm of a pytree, treated as one flat real vector.
  xs = tree_leaves(x)
  return jnp.sqrt(sum(map(_vdot_real_part, xs, xs)))

def _mul(scalar, tree):
  # Multiply every leaf of `tree` by `scalar`.
  return tree_map(partial(operator.mul, scalar), tree)

# leaf-wise pytree arithmetic
_add = partial(tree_map, operator.add)
_sub = partial(tree_map, operator.sub)
_dot_tree = partial(tree_map, _dot)

@Partial
def _identity(x):
  # Identity preconditioner; wrapped in Partial so it is a valid pytree leaf.
  return x
def _normalize_matvec(f):
"""Normalize an argument for computing matrix-vector products."""
if callable(f):
return f
elif isinstance(f, (np.ndarray, Array)):
if f.ndim != 2 or f.shape[0] != f.shape[1]:
raise ValueError(
f'linear operator must be a square matrix, but has shape: {f.shape}')
return partial(_dot, f)
elif hasattr(f, '__matmul__'):
if hasattr(f, 'shape') and len(f.shape) != 2 or f.shape[0] != f.shape[1]:
raise ValueError(
f'linear operator must be a square matrix, but has shape: {f.shape}')
return partial(operator.matmul, f)
else:
raise TypeError(
f'linear operator must be either a function or ndarray: {f}')
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
  """Preconditioned conjugate-gradient iteration over pytrees.

  `A` and `M` are matvec functions; `b` and `x0` are matching pytrees.
  Callers (`_isolve`) always supply `x0` despite the default. Terminates
  when ||r||^2 <= max(tol^2 * ||b||^2, atol^2) or after `maxiter` steps.
  """

  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
  bs = _vdot_real_tree(b, b)
  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

  # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

  def cond_fun(value):
    _, r, gamma, _, k = value
    # without preconditioning, gamma == <r, r>, saving one dot product
    rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
    return (rs > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    # `dtype` is a closure over the binding below; it is assigned before
    # the while_loop first runs
    alpha = gamma / _vdot_real_tree(p, Ap).astype(dtype)
    x_ = _add(x, _mul(alpha, p))
    r_ = _sub(r, _mul(alpha, Ap))
    z_ = M(r_)
    gamma_ = _vdot_real_tree(r_, z_).astype(dtype)
    beta_ = gamma_ / gamma
    p_ = _add(z_, _mul(beta_, p))
    return x_, r_, gamma_, p_, k + 1

  r0 = _sub(b, A(x0))
  p0 = z0 = M(r0)
  dtype = dtypes.result_type(*tree_leaves(p0))
  gamma0 = _vdot_real_tree(r0, z0).astype(dtype)
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

  return x_final
def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
  """Preconditioned BiCGSTAB iteration over pytrees.

  Unlike CG this does not require `A` to be Hermitian. The iteration
  counter `k` doubles as a breakdown flag: it is set negative when the
  recurrence degenerates, which stops the loop via `cond_fun`.
  """

  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.bicgstab
  bs = _vdot_real_tree(b, b)
  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

  # https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB

  def cond_fun(value):
    x, r, *_, k = value
    rs = _vdot_real_tree(r, r)
    # the last condition checks breakdown
    return (rs > atol2) & (k < maxiter) & (k >= 0)

  def body_fun(value):
    x, r, rhat, alpha, omega, rho, p, q, k = value
    rho_ = _vdot_tree(rhat, r)
    beta = rho_ / rho * alpha / omega
    p_ = _add(r, _mul(beta, _sub(p, _mul(omega, q))))
    phat = M(p_)
    q_ = A(phat)
    alpha_ = rho_ / _vdot_tree(rhat, q_)
    s = _sub(r, _mul(alpha_, q_))
    # if the half-step residual already satisfies the tolerance, skip the
    # stabilization step (selected leafwise via jnp.where below)
    exit_early = _vdot_real_tree(s, s) < atol2
    shat = M(s)
    t = A(shat)
    omega_ = _vdot_tree(t, s) / _vdot_tree(t, t)  # make cases?
    x_ = tree_map(partial(jnp.where, exit_early),
                  _add(x, _mul(alpha_, phat)),
                  _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat)))
                  )
    r_ = tree_map(partial(jnp.where, exit_early),
                  s, _sub(s, _mul(omega_, t)))
    # breakdown flags: -11 when omega or alpha vanished, -10 when rho vanished
    k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
    k_ = jnp.where((rho_ == 0), -10, k_)
    return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

  r0 = _sub(b, A(x0))
  rho0 = alpha0 = omega0 = lax_internal._convert_element_type(
      1, *dtypes.lattice_result_type(*tree_leaves(b)))
  initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

  return x_final
def _shapes(pytree):
return map(np.shape, tree_leaves(pytree))
def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
            maxiter=None, M=None, check_symmetric=False):
  """Shared driver for the iterative solvers (`_cg_solve`, `_bicgstab_solve`).

  Normalizes the operator arguments, validates that `x0` and `b` have
  matching pytree structure and leaf shapes, and wraps the concrete solver
  in `lax.custom_linear_solve` so derivatives are obtained by implicit
  differentiation (re-solving with the same routine).
  """
  if x0 is None:
    x0 = tree_map(jnp.zeros_like, b)
  b, x0 = api.device_put((b, x0))

  if maxiter is None:
    size = sum(bi.size for bi in tree_leaves(b))
    maxiter = 10 * size  # copied from scipy

  if M is None:
    M = _identity
  A = _normalize_matvec(A)
  M = _normalize_matvec(M)

  if tree_structure(x0) != tree_structure(b):
    raise ValueError(
        'x0 and b must have matching tree structure: '
        f'{tree_structure(x0)} vs {tree_structure(b)}')

  if _shapes(x0) != _shapes(b):
    raise ValueError(
        'arrays in x0 and b must have matching shapes: '
        f'{_shapes(x0)} vs {_shapes(b)}')

  isolve_solve = partial(
      _isolve_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)

  # real-valued positive-definite linear operators are symmetric
  def real_valued(x):
    return not issubclass(x.dtype.type, np.complexfloating)
  symmetric = all(map(real_valued, tree_leaves(b))) \
    if check_symmetric else False
  x = lax.custom_linear_solve(
      A, b, solve=isolve_solve, transpose_solve=isolve_solve,
      symmetric=symmetric)
  info = None
  return x, info
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Conjugate Gradient iteration to solve ``Ax = b``.

  The numerics of JAX's ``cg`` should exactly match SciPy's ``cg`` (up to
  numerical precision), but note that the interface is slightly different: you
  need to supply the linear operator ``A`` as a function instead of a sparse
  matrix or ``LinearOperator``.

  Derivatives of ``cg`` are implemented via implicit differentiation with
  another ``cg`` solve, rather than by differentiating *through* the solver.
  They will be accurate only if both solves converge.

  Parameters
  ----------
  A: ndarray, function, or matmul-compatible object
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A`` must represent
      a hermitian, positive definite matrix, and must return array(s) with the
      same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array or tree of arrays
      Starting guess for the solution. Must have the same structure as ``b``.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
  maxiter : integer
      Maximum number of iterations. Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : ndarray, function, or matmul-compatible object
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.cg
  jax.lax.custom_linear_solve
  """
  # check_symmetric=True lets custom_linear_solve treat a real-valued
  # (hence symmetric, given positive-definiteness) operator specially.
  return _isolve(_cg_solve,
                 A=A, b=b, x0=x0, tol=tol, atol=atol,
                 maxiter=maxiter, M=M, check_symmetric=True)
def _safe_normalize(x, thresh=None):
  """
  Returns the L2-normalized vector (which can be a pytree) x, and optionally
  the computed norm. If the computed norm is less than the threshold `thresh`,
  which by default is the machine precision of x's dtype, it will be
  taken to be 0, and the normalized x to be the zero vector.
  """
  norm = _norm(x)
  dtype, weak_type = dtypes.lattice_result_type(*tree_leaves(x))
  if thresh is None:
    thresh = dtypes.finfo(norm.dtype).eps
  # NOTE(review): assumes `thresh` is an array scalar (has `.astype`); a plain
  # Python float passed by a caller would fail here — confirm callers.
  thresh = thresh.astype(dtype).real
  use_norm = norm > thresh
  # Cast the norm to the result dtype (preserving weak typing) before division.
  norm_cast = lax_internal._convert_element_type(norm, dtype, weak_type)
  normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm_cast, 0.0), x)
  norm = jnp.where(use_norm, norm, 0.0)
  return normalized_x, norm
def _project_on_columns(A, v):
  """Compute ``A.T.conj() @ v`` for pytree-structured ``A`` and ``v``.

  Each leaf of ``A`` is contracted with the matching leaf of ``v`` over all
  leading axes, leaving the trailing (column) axis; the per-leaf results are
  then summed across the tree.
  """
  per_leaf_overlaps = tree_map(
      lambda mat, vec: _einsum("...n,...->n", mat.conj(), vec),
      A, v,
  )
  return tree_reduce(operator.add, per_leaf_overlaps)
def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
  """
  Orthogonalize x against the columns of Q. The process is repeated
  up to `max_iterations` times, or fewer if the condition
  ||r|| < (1/sqrt(2)) ||x|| is met earlier (see below for the meaning
  of r and x).

  Parameters
  ----------
  Q : array or tree of arrays
      A matrix of orthonormal columns.
  x : array or tree of arrays
      A vector. It will be replaced with a new vector q which is orthonormal
      to the columns of Q, such that x in span(col(Q), q).
  xnorm : float
      Norm of x.

  Returns
  -------
  q : array or tree of arrays
      A unit vector, orthonormal to each column of Q, such that
      x in span(col(Q), q).
  r : array
      Stores the overlaps of x with each vector in Q.
  """
  # "twice is enough"
  # http://slepc.upv.es/documentation/reports/str1.pdf

  # TODO(shoyer): consider switching to only one iteration, like SciPy?

  # This assumes that Q's leaves all have the same dimension in the last
  # axis.
  Q0 = tree_leaves(Q)[0]
  r = jnp.zeros(Q0.shape[-1], dtype=Q0.dtype)
  q = x
  xnorm_scaled = xnorm / jnp.sqrt(2.0)

  def body_function(carry):
    k, q, r, qnorm_scaled = carry
    # One classical Gram-Schmidt pass: project q onto col(Q), subtract the
    # projection, and accumulate the overlaps into r.
    h = _project_on_columns(Q, q)
    Qh = tree_map(lambda X: _dot(X, h), Q)
    q = _sub(q, Qh)
    r = _add(r, h)

    def qnorm_cond(carry):
      k, not_done, _, _ = carry
      return jnp.logical_and(not_done, k < (max_iterations - 1))

    def qnorm(carry):
      # Recompute the scaled norm of q only when another orthogonalization
      # pass may follow; the `not_done` flag makes this run at most once.
      k, _, q, qnorm_scaled = carry
      _, qnorm = _safe_normalize(q)
      qnorm_scaled = qnorm / jnp.sqrt(2.0)
      return (k, False, q, qnorm_scaled)

    init = (k, True, q, qnorm_scaled)
    _, _, q, qnorm_scaled = lax.while_loop(qnorm_cond, qnorm, init)
    return (k + 1, q, r, qnorm_scaled)

  def cond_function(carry):
    # Re-orthogonalize while the residual still overlaps col(Q) too much
    # relative to the (scaled) norm of q, up to max_iterations passes.
    k, _, r, qnorm_scaled = carry
    _, rnorm = _safe_normalize(r)
    return jnp.logical_and(k < (max_iterations - 1), rnorm < qnorm_scaled)

  # Always perform the first pass eagerly, then loop for any further passes.
  k, q, r, qnorm_scaled = body_function((0, q, r, xnorm_scaled))
  k, q, r, _ = lax.while_loop(cond_function, body_function,
                              (k, q, r, qnorm_scaled))
  return q, r
def _kth_arnoldi_iteration(k, A, M, V, H):
  """
  Performs a single (the k'th) step of the Arnoldi process. Thus,
  adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],
  and that vector's overlaps with the existing Krylov vectors to
  H[k, :]. The tolerance 'tol' sets the threshold at which an invariant
  subspace is declared to have been found, in which case the new
  vector is taken to be the zero vector.
  """
  dtype, _ = dtypes.lattice_result_type(*tree_leaves(V))
  eps = dtypes.finfo(dtype).eps

  v = tree_map(lambda x: x[..., k], V)  # Gets V[:, k]
  v = M(A(v))
  _, v_norm_0 = _safe_normalize(v)
  v, h = _iterative_classical_gram_schmidt(V, v, v_norm_0, max_iterations=2)

  # Declare breakdown (invariant subspace found) when the orthogonalized
  # vector's norm is negligible relative to its pre-orthogonalization norm;
  # _safe_normalize then returns the zero vector and a zero norm.
  tol = eps * v_norm_0
  unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
  V = tree_map(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)

  # The subdiagonal Hessenberg entry is the norm of the new basis vector.
  h = h.at[k + 1].set(v_norm_1.astype(dtype))
  H = H.at[k, :].set(h)
  breakdown = v_norm_1 == 0.
  return V, H, breakdown
def _rotate_vectors(H, i, cs, sn):
x1 = H[i]
y1 = H[i + 1]
x2 = cs.conj() * x1 - sn.conj() * y1
y2 = sn * x1 + cs * y1
H = H.at[i].set(x2)
H = H.at[i + 1].set(y2)
return H
def _givens_rotation(a, b):
b_zero = abs(b) == 0
a_lt_b = abs(a) < abs(b)
t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
r = lax.rsqrt(1 + abs(t) ** 2).astype(t.dtype)
cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
return cs, sn
def _apply_givens_rotations(H_row, givens, k):
  """
  Applies the Givens rotations stored in the vectors cs and sn to the vector
  H_row. Then constructs and applies a new Givens rotation that eliminates
  H_row's k'th element.
  """
  # This call successively applies each of the
  # Givens rotations stored in givens[:, :k] to H_col.
  def apply_ith_rotation(i, H_row):
    return _rotate_vectors(H_row, i, *givens[i, :])

  R_row = lax.fori_loop(0, k, apply_ith_rotation, H_row)

  # Build a fresh rotation that zeroes out R_row[k + 1], store it so later
  # iterations can replay it, and apply it to finish triangularizing the row.
  givens_factors = _givens_rotation(R_row[k], R_row[k + 1])
  givens = givens.at[k, :].set(givens_factors)
  R_row = _rotate_vectors(R_row, k, *givens_factors)
  return R_row, givens
def _gmres_incremental(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
  """
  Implements a single restart of GMRES. The restart-dimensional Krylov subspace
  K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
  projection of the true solution into this subspace is returned.

  This implementation builds the QR factorization during the Arnoldi process.
  """
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  # V holds the Krylov basis vectors in its trailing axis: column 0 is the
  # (unit) initial residual; the remaining `restart` columns start as zeros.
  V = tree_map(
      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
      unit_residual,
  )
  dtype = dtypes.result_type(*tree_leaves(b))
  # use eye() to avoid constructing a singular matrix in case of early
  # termination
  R = jnp.eye(restart, restart + 1, dtype=dtype)

  givens = jnp.zeros((restart, 2), dtype=dtype)
  beta_vec = jnp.zeros((restart + 1), dtype=dtype)
  beta_vec = beta_vec.at[0].set(residual_norm.astype(dtype))

  def loop_cond(carry):
    k, err, _, _, _, _ = carry
    return jnp.logical_and(k < restart, err > ptol)

  def arnoldi_qr_step(carry):
    k, _, V, R, beta_vec, givens = carry
    V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R)
    # Incrementally triangularize the new Hessenberg row with Givens
    # rotations; |beta_vec[k + 1]| is then a free residual-norm estimate.
    R_row, givens = _apply_givens_rotations(H[k, :], givens, k)
    R = R.at[k, :].set(R_row)
    beta_vec = _rotate_vectors(beta_vec, k, *givens[k, :])
    err = abs(beta_vec[k + 1])
    return k + 1, err, V, R, beta_vec, givens

  carry = (0, residual_norm, V, R, beta_vec, givens)
  carry = lax.while_loop(loop_cond, arnoldi_qr_step, carry)
  k, residual_norm, V, R, beta_vec, _ = carry
  del k  # Until we figure out how to pass this to the user.

  # Solve the already-triangular projected least-squares problem, then map the
  # coefficients back through the Krylov basis to update x.
  y = jsp_linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])
  dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

  x = _add(x0, dx)

  residual = M(_sub(b, A(x)))
  unit_residual, residual_norm = _safe_normalize(residual)
  # TODO(shoyer): "Inner loop tolerance control" on ptol, like SciPy
  return x, unit_residual, residual_norm
def _lstsq(a, b):
  """Least-squares solve of ``a @ x ~= b`` via the normal equations.

  Forms the (positive-definite) Gram matrix and solves it directly, which is
  faster than jsp_linalg.lstsq for the small systems GMRES produces.
  """
  a_conj_t = a.T.conj()
  gram = _dot(a_conj_t, a)
  rhs = _dot(a_conj_t, b)
  return jsp_linalg.solve(gram, rhs, assume_a='pos')
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
  """
  Implements a single restart of GMRES. The ``restart``-dimensional Krylov
  subspace
  K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
  projection of the true solution into this subspace is returned.

  This implementation solves a dense linear problem instead of building
  a QR factorization during the Arnoldi process.
  """
  del ptol  # unused
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  # V holds the Krylov basis vectors in its trailing axis: column 0 is the
  # (unit) initial residual; the remaining `restart` columns start as zeros.
  V = tree_map(
      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
      unit_residual,
  )
  dtype, weak_type = dtypes.lattice_result_type(*tree_leaves(b))
  # eye() keeps H nonsingular in case of early termination (breakdown).
  H = lax_internal._convert_element_type(
      jnp.eye(restart, restart + 1, dtype=dtype), weak_type=weak_type)

  def loop_cond(carry):
    _, _, breakdown, k = carry
    return jnp.logical_and(k < restart, jnp.logical_not(breakdown))

  def arnoldi_process(carry):
    V, H, _, k = carry
    V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
    return V, H, breakdown, k + 1

  carry = (V, H, False, 0)
  V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)

  # Solve the dense projected least-squares problem H.T @ y ~= beta_vec, then
  # map the coefficients back through the Krylov basis to update x.
  beta_vec = jnp.zeros_like(H, shape=(restart + 1,)).at[0].set(residual_norm.astype(dtype))
  y = _lstsq(H.T, beta_vec)
  dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

  x = _add(x0, dx)

  residual = M(_sub(b, A(x)))
  unit_residual, residual_norm = _safe_normalize(residual)
  return x, unit_residual, residual_norm
def _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func):
  """
  The main function call wrapped by custom_linear_solve. Repeatedly calls GMRES
  to find the projected solution within the order-``restart``
  Krylov space K(A, x0, restart), using the result of the previous projection
  in place of x0 each time. Parameters are the same as in ``gmres`` except:

  atol: Tolerance for norm(A(x) - b), used between restarts.
  ptol: Tolerance for norm(M(A(x) - b)), used within a restart.
  gmres_func: A function performing a single GMRES restart.

  Returns: The solution.
  """
  residual = M(_sub(b, A(x0)))
  unit_residual, residual_norm = _safe_normalize(residual)

  def cond_fun(value):
    _, k, _, residual_norm = value
    return jnp.logical_and(k < maxiter, residual_norm > atol)

  def body_fun(value):
    # Each iteration performs one full `restart`-dimensional GMRES cycle,
    # restarting from the previous cycle's solution.
    x, k, unit_residual, residual_norm = value
    x, unit_residual, residual_norm = gmres_func(
        A, b, x, unit_residual, residual_norm, ptol, restart, M)
    return x, k + 1, unit_residual, residual_norm

  initialization = (x0, 0, unit_residual, residual_norm)
  x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)
  _ = k  # Until we can pass this out
  _ = err
  return x_final  # , info
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
          M=None, solve_method='batched'):
  """
  GMRES solves the linear system A x = b for x, given A and b.

  A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
  need not have any particular special properties, such as symmetry. However,
  convergence is often slow for nearly symmetric operators.

  Parameters
  ----------
  A: ndarray, function, or matmul-compatible object
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A``
      must return array(s) with the same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array or tree of arrays, optional
      Starting guess for the solution. Must have the same structure as ``b``.
      If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
  restart : integer, optional
      Size of the Krylov subspace ("number of iterations") built between
      restarts. GMRES works by approximating the true solution x as its
      projection into a Krylov space of this dimension - this parameter
      therefore bounds the maximum accuracy achievable from any guess
      solution. Larger values increase both number of iterations and iteration
      cost, but may be necessary for convergence. The algorithm terminates
      early if convergence is achieved before the full subspace is built.
      Default is 20.
  maxiter : integer
      Maximum number of times to rebuild the size-``restart`` Krylov space
      starting from the solution found at the last iteration. If GMRES
      halts or is very slow, decreasing this parameter may help.
      Default is infinite.
  M : ndarray, function, or matmul-compatible object
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.
  solve_method : 'incremental' or 'batched'
      The 'incremental' solve method builds a QR decomposition for the Krylov
      subspace incrementally during the GMRES process using Givens rotations.
      This improves numerical stability and gives a free estimate of the
      residual norm that allows for early termination within a single "restart".
      In contrast, the 'batched' solve method solves the least squares problem
      from scratch at the end of each GMRES iteration. It does not allow for
      early termination, but has much less overhead on GPUs.

  See also
  --------
  scipy.sparse.linalg.gmres
  jax.lax.custom_linear_solve
  """
  if x0 is None:
    x0 = tree_map(jnp.zeros_like, b)
  if M is None:
    M = _identity
  A = _normalize_matvec(A)
  M = _normalize_matvec(M)

  b, x0 = api.device_put((b, x0))
  size = sum(bi.size for bi in tree_leaves(b))

  if maxiter is None:
    maxiter = 10 * size  # copied from scipy
  # The Krylov dimension cannot usefully exceed the problem size.
  restart = min(restart, size)

  if tree_structure(x0) != tree_structure(b):
    raise ValueError(
        'x0 and b must have matching tree structure: '
        f'{tree_structure(x0)} vs {tree_structure(b)}')

  b_norm = _norm(b)
  atol = jnp.maximum(tol * b_norm, atol)

  Mb = M(b)
  Mb_norm = _norm(Mb)
  # Tolerance on the *preconditioned* residual, used inside a single restart.
  # NOTE(review): b_norm == 0 (b exactly zero) makes this nan — presumably
  # callers pass nonzero b; verify.
  ptol = Mb_norm * jnp.minimum(1.0, atol / b_norm)

  if solve_method == 'incremental':
    gmres_func = _gmres_incremental
  elif solve_method == 'batched':
    gmres_func = _gmres_batched
  else:
    raise ValueError(f"invalid solve_method {solve_method}, must be either "
                     "'incremental' or 'batched'")

  def _solve(A, b):
    return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func)

  x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

  failed = jnp.isnan(_norm(x))
  info = jnp.where(failed, -1, 0)
  return x, info
def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Bi-Conjugate Gradient Stable iteration to solve ``Ax = b``.

  The numerics of JAX's ``bicgstab`` should exactly match SciPy's
  ``bicgstab`` (up to numerical precision), but note that the interface
  is slightly different: you need to supply the linear operator ``A`` as
  a function instead of a sparse matrix or ``LinearOperator``.

  As with ``cg``, derivatives of ``bicgstab`` are implemented via implicit
  differentiation with another ``bicgstab`` solve, rather than by
  differentiating *through* the solver. They will be accurate only if
  both solves converge.

  Parameters
  ----------
  A: ndarray, function, or matmul-compatible object
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A`` can represent
      any general (nonsymmetric) linear operator, and function must return array(s)
      with the same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array or tree of arrays
      Starting guess for the solution. Must have the same structure as ``b``.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's
      ``bicgstab``.
  maxiter : integer
      Maximum number of iterations. Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : ndarray, function, or matmul-compatible object
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.bicgstab
  jax.lax.custom_linear_solve
  """
  return _isolve(_bicgstab_solve,
                 A=A, b=b, x0=x0, tol=tol, atol=atol,
                 maxiter=maxiter, M=M)
@@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,463 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools
import re
import typing
import numpy as np
from jax._src import config
from jax._src import numpy as jnp
from jax._src.numpy import linalg as jnp_linalg
from jax._src.numpy import vectorize as jnp_vectorize
from jax._src.typing import Array
class Rotation(typing.NamedTuple):
  """Rotation in 3 dimensions.

  JAX implementation of :class:`scipy.spatial.transform.Rotation`.

  Examples:
    Construct an object describing a 90 degree rotation about the z-axis:

    >>> from jax.scipy.spatial.transform import Rotation
    >>> r = Rotation.from_euler('z', 90, degrees=True)

    Convert to a rotation vector:

    >>> r.as_rotvec()
    Array([0.       , 0.       , 1.5707964], dtype=float32)

    Convert to rotation matrix:

    >>> r.as_matrix()
    Array([[ 0.        , -0.99999994,  0.        ],
           [ 0.99999994,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  0.99999994]], dtype=float32)

    Compose with another rotation:

    >>> r2 = Rotation.from_euler('x', 90, degrees=True)
    >>> r3 = r * r2
    >>> r3.as_matrix()
    Array([[0., 0., 1.],
           [1., 0., 0.],
           [0., 1., 0.]], dtype=float32)

  See the scipy :class:`~scipy.spatial.transform.Rotation` documentation for
  further examples of manipulating Rotation objects.
  """

  # Scalar-last quaternion(s) (x, y, z, w): shape (4,) for a single rotation,
  # (N, 4) for a stack (see the `single` property).
  quat: Array

  @classmethod
  def concatenate(cls, rotations: typing.Sequence):
    """Concatenate a sequence of `Rotation` objects."""
    return cls(jnp.concatenate([rotation.quat for rotation in rotations]))

  @classmethod
  def from_euler(cls, seq: str, angles: Array, degrees: bool = False):
    """Initialize from Euler angles."""
    num_axes = len(seq)
    if num_axes < 1 or num_axes > 3:
      raise ValueError("Expected axis specification to be a non-empty "
                       "string of upto 3 characters, got {}".format(seq))
    # Uppercase axis letters denote intrinsic (rotating-frame) rotations,
    # lowercase extrinsic (fixed-frame); the two must not be mixed.
    intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
    extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
    if not (intrinsic or extrinsic):
      raise ValueError("Expected axes from `seq` to be from ['x', 'y', "
                       "'z'] or ['X', 'Y', 'Z'], got {}".format(seq))
    if any(seq[i] == seq[i+1] for i in range(num_axes - 1)):
      raise ValueError("Expected consecutive axes to be different, "
                       "got {}".format(seq))
    angles = jnp.atleast_1d(angles)
    axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
    return cls(_elementary_quat_compose(angles, axes, intrinsic, degrees))

  @classmethod
  def from_matrix(cls, matrix: Array):
    """Initialize from rotation matrix."""
    return cls(_from_matrix(matrix))

  @classmethod
  def from_mrp(cls, mrp: Array):
    """Initialize from Modified Rodrigues Parameters (MRPs)."""
    return cls(_from_mrp(mrp))

  @classmethod
  def from_quat(cls, quat: Array):
    """Initialize from quaternions."""
    # Stored quaternions are always kept unit-norm.
    return cls(_normalize_quaternion(quat))

  @classmethod
  def from_rotvec(cls, rotvec: Array, degrees: bool = False):
    """Initialize from rotation vectors."""
    return cls(_from_rotvec(rotvec, degrees))

  @classmethod
  def identity(cls, num: int | None = None, dtype=float):
    """Get identity rotation(s)."""
    # NOTE(review): stacks of identity rotations are not supported; also this
    # assert is stripped under `python -O` — consider an explicit raise.
    assert num is None
    quat = jnp.array([0., 0., 0., 1.], dtype=dtype)
    return cls(quat)

  @classmethod
  def random(cls, random_key: Array, num: int | None = None):
    """Generate uniformly distributed rotations."""
    # Need to implement scipy.stats.special_ortho_group for this to work...
    raise NotImplementedError()

  def __getitem__(self, indexer):
    """Extract rotation(s) at given index(es) from object."""
    if self.single:
      raise TypeError("Single rotation is not subscriptable.")
    return Rotation(self.quat[indexer])

  def __len__(self):
    """Number of rotations contained in this object."""
    if self.single:
      raise TypeError('Single rotation has no len().')
    else:
      return self.quat.shape[0]

  def __mul__(self, other) -> Rotation:
    """Compose this rotation with the other."""
    return Rotation.from_quat(_compose_quat(self.quat, other.quat))

  def apply(self, vectors: Array, inverse: bool = False) -> Array:
    """Apply this rotation to one or more vectors."""
    return _apply(self.as_matrix(), vectors, inverse)

  def as_euler(self, seq: str, degrees: bool = False):
    """Represent as Euler angles."""
    if len(seq) != 3:
      raise ValueError(f"Expected 3 axes, got {seq}.")
    # Same intrinsic/extrinsic convention as `from_euler`: uppercase axis
    # letters are intrinsic, lowercase extrinsic.
    intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
    extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
    if not (intrinsic or extrinsic):
      raise ValueError("Expected axes from `seq` to be from "
                       "['x', 'y', 'z'] or ['X', 'Y', 'Z'], "
                       "got {}".format(seq))
    if any(seq[i] == seq[i+1] for i in range(2)):
      raise ValueError("Expected consecutive axes to be different, "
                       "got {}".format(seq))
    axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
    with config.numpy_rank_promotion('allow'):
      return _compute_euler_from_quat(self.quat, axes, extrinsic, degrees)

  def as_matrix(self) -> Array:
    """Represent as rotation matrix."""
    return _as_matrix(self.quat)

  def as_mrp(self) -> Array:
    """Represent as Modified Rodrigues Parameters (MRPs)."""
    return _as_mrp(self.quat)

  def as_rotvec(self, degrees: bool = False) -> Array:
    """Represent as rotation vectors."""
    return _as_rotvec(self.quat, degrees)

  def as_quat(self, canonical: bool=False, scalar_first: bool=False) -> Array:
    """Represent as quaternions."""
    quat = _make_canonical(self.quat) if canonical else self.quat
    if scalar_first:
      # Internally scalar-last (x, y, z, w); roll to (w, x, y, z) on request.
      return jnp.roll(quat, shift=1, axis=-1)
    return quat

  def inv(self):
    """Invert this rotation."""
    return Rotation(_inv(self.quat))

  def magnitude(self) -> Array:
    """Get the magnitude(s) of the rotation(s)."""
    return _magnitude(self.quat)

  def mean(self, weights: Array | None = None):
    """Get the mean of the rotations."""
    w = jnp.ones(self.quat.shape[0], dtype=self.quat.dtype) if weights is None else jnp.asarray(weights, dtype=self.quat.dtype)
    if w.ndim != 1:
      raise ValueError("Expected `weights` to be 1 dimensional, got "
                       "shape {}.".format(w.shape))
    if w.shape[0] != len(self):
      raise ValueError("Expected `weights` to have number of values "
                       "equal to number of rotations, got "
                       "{} values and {} rotations.".format(w.shape[0], len(self)))
    # The mean is the eigenvector of the weighted sum of quaternion outer
    # products associated with the largest eigenvalue; eigh returns
    # eigenvalues in ascending order, hence column -1.
    K = jnp.dot(w[np.newaxis, :] * self.quat.T, self.quat)
    _, v = jnp_linalg.eigh(K)
    return Rotation(v[:, -1])

  @property
  def single(self) -> bool:
    """Whether this instance represents a single rotation."""
    return self.quat.ndim == 1
class Slerp(typing.NamedTuple):
  """Spherical Linear Interpolation of Rotations.

  JAX implementation of :class:`scipy.spatial.transform.Slerp`.

  Examples:
    Create a Slerp instance from a series of rotations:

    >>> import jax.numpy as jnp
    >>> from jax.scipy.spatial.transform import Rotation, Slerp
    >>> rots = jnp.array([[90, 0, 0],
    ...                   [0, 45, 0],
    ...                   [0, 0, -30]])
    >>> key_rotations = Rotation.from_euler('zxy', rots, degrees=True)
    >>> key_times = [0, 1, 2]
    >>> slerp = Slerp.init(key_times, key_rotations)
    >>> times = [0, 0.5, 1, 1.5, 2]
    >>> interp_rots = slerp(times)
    >>> interp_rots.as_euler('zxy')
    Array([[ 1.5707963e+00,  0.0000000e+00,  0.0000000e+00],
           [ 8.5309029e-01,  3.8711953e-01,  1.7768645e-01],
           [-2.3841858e-07,  7.8539824e-01,  0.0000000e+00],
           [-5.6668043e-02,  3.9213133e-01, -2.8347540e-01],
           [ 0.0000000e+00,  0.0000000e+00, -5.2359891e-01]],      dtype=float32)
  """

  # Key-frame times, shape (N,).
  times: Array
  # Differences between consecutive key times, shape (N - 1,).
  timedelta: Array
  # Key rotations at times[:-1] (left endpoint of each interval).
  rotations: Rotation
  # Relative rotation vectors carrying rotations[i] to rotations[i + 1].
  rotvecs: Array

  @classmethod
  def init(cls, times: Array, rotations: Rotation):
    """Build a Slerp interpolator from key times and key rotations."""
    if not isinstance(rotations, Rotation):
      raise TypeError("`rotations` must be a `Rotation` instance.")
    if rotations.single or len(rotations) == 1:
      raise ValueError("`rotations` must be a sequence of at least 2 rotations.")
    times = jnp.asarray(times, dtype=rotations.quat.dtype)
    if times.ndim != 1:
      raise ValueError("Expected times to be specified in a 1 "
                       "dimensional array, got {} "
                       "dimensions.".format(times.ndim))
    if times.shape[0] != len(rotations):
      raise ValueError("Expected number of rotations to be equal to "
                       "number of timestamps given, got {} rotations "
                       "and {} timestamps.".format(len(rotations), times.shape[0]))
    timedelta = jnp.diff(times)
    # if jnp.any(timedelta <= 0):  # this causes a concretization error...
    #   raise ValueError("Times must be in strictly increasing order.")
    new_rotations = Rotation(rotations.as_quat()[:-1])
    return cls(
        times=times,
        timedelta=timedelta,
        rotations=new_rotations,
        rotvecs=(new_rotations.inv() * Rotation(rotations.as_quat()[1:])).as_rotvec())

  def __call__(self, times: Array):
    """Interpolate rotations."""
    compute_times = jnp.asarray(times, dtype=self.times.dtype)
    if compute_times.ndim > 1:
      raise ValueError("`times` must be at most 1-dimensional.")
    single_time = compute_times.ndim == 0
    compute_times = jnp.atleast_1d(compute_times)
    # Index of the key-frame interval containing each query time, clamped so
    # times at/before the first key frame use the first interval.
    ind = jnp.maximum(jnp.searchsorted(self.times, compute_times) - 1, 0)
    # Fraction of the way through interval `ind`; NOTE(review): query times
    # outside the key-frame range extrapolate (alpha outside [0, 1]).
    alpha = (compute_times - self.times[ind]) / self.timedelta[ind]
    result = (self.rotations[ind] * Rotation.from_rotvec(self.rotvecs[ind] * alpha[:, None]))
    if single_time:
      return result[0]
    return result
@functools.partial(jnp_vectorize.vectorize, signature='(m,m),(m),()->(m)')
def _apply(matrix: Array, vector: Array, inverse: bool) -> Array:
return jnp.where(inverse, matrix.T, matrix) @ vector
@functools.partial(jnp_vectorize.vectorize, signature='(m)->(n,n)')
def _as_matrix(quat: Array) -> Array:
x = quat[0]
y = quat[1]
z = quat[2]
w = quat[3]
x2 = x * x
y2 = y * y
z2 = z * z
w2 = w * w
xy = x * y
zw = z * w
xz = x * z
yw = y * w
yz = y * z
xw = x * w
return jnp.array([[+ x2 - y2 - z2 + w2, 2 * (xy - zw), 2 * (xz + yw)],
[2 * (xy + zw), - x2 + y2 - z2 + w2, 2 * (yz - xw)],
[2 * (xz - yw), 2 * (yz + xw), - x2 - y2 + z2 + w2]])
@functools.partial(jnp_vectorize.vectorize, signature='(m)->(n)')
def _as_mrp(quat: Array) -> Array:
sign = jnp.where(quat[3] < 0, -1., 1.)
denominator = 1. + sign * quat[3]
return sign * quat[:3] / denominator
@functools.partial(jnp_vectorize.vectorize, signature='(m),()->(n)')
def _as_rotvec(quat: Array, degrees: bool) -> Array:
  """Convert a scalar-last quaternion to a rotation vector (axis * angle)."""
  quat = jnp.where(quat[3] < 0, -quat, quat)  # w > 0 to ensure 0 <= angle <= pi
  angle = 2. * jnp.arctan2(_vector_norm(quat[:3]), quat[3])
  angle2 = angle * angle
  # Series expansion of angle / sin(angle / 2) near zero, to avoid dividing
  # by a vanishing sine for small rotations.
  small_scale = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
  large_scale = angle / jnp.sin(angle / 2)
  scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
  scale = jnp.where(degrees, jnp.rad2deg(scale), scale)
  return scale * jnp.array(quat[:3])
@functools.partial(jnp_vectorize.vectorize, signature='(n),(n)->(n)')
def _compose_quat(p: Array, q: Array) -> Array:
cross = jnp.cross(p[:3], q[:3])
return jnp.array([p[3]*q[0] + q[3]*p[0] + cross[0],
p[3]*q[1] + q[3]*p[1] + cross[1],
p[3]*q[2] + q[3]*p[2] + cross[2],
p[3]*q[3] - p[0]*q[0] - p[1]*q[1] - p[2]*q[2]])
@functools.partial(jnp_vectorize.vectorize, signature='(m),(l),(),()->(n)')
def _compute_euler_from_quat(quat: Array, axes: Array, extrinsic: bool, degrees: bool) -> Array:
  """Extract Euler angles about `axes` from a scalar-last quaternion.

  Branchless (jnp.where-based) adaptation of the quaternion-to-Euler
  algorithm, handling both symmetric (first axis == third axis) and
  asymmetric axis sequences, including their gimbal-lock degeneracies.
  """
  # Intrinsic rotations read the axis sequence (and angle slots) in reverse.
  angle_first = jnp.where(extrinsic, 0, 2)
  angle_third = jnp.where(extrinsic, 2, 0)
  axes = jnp.where(extrinsic, axes, axes[::-1])
  i = axes[0]
  j = axes[1]
  k = axes[2]
  symmetric = i == k
  # For symmetric sequences, substitute the axis not used by the first two.
  k = jnp.where(symmetric, 3 - i - j, k)
  # +1 if (i, j, k) is a cyclic permutation of (0, 1, 2), else -1.
  sign = jnp.array((i - j) * (j - k) * (k - i) // 2, dtype=quat.dtype)
  eps = 1e-7
  a = jnp.where(symmetric, quat[3], quat[3] - quat[j])
  b = jnp.where(symmetric, quat[i], quat[i] + quat[k] * sign)
  c = jnp.where(symmetric, quat[j], quat[j] + quat[3])
  d = jnp.where(symmetric, quat[k] * sign, quat[k] * sign - quat[i])
  angles = jnp.empty(3, dtype=quat.dtype)
  angles = angles.at[1].set(2 * jnp.arctan2(jnp.hypot(c, d), jnp.hypot(a, b)))
  # case 1 / case 2: gimbal lock (second angle at 0 or pi); case 0: regular.
  case = jnp.where(jnp.abs(angles[1] - np.pi) <= eps, 2, 0)
  case = jnp.where(jnp.abs(angles[1]) <= eps, 1, case)
  half_sum = jnp.arctan2(b, a)
  half_diff = jnp.arctan2(d, c)
  angles = angles.at[0].set(jnp.where(case == 1, 2 * half_sum, 2 * half_diff * jnp.where(extrinsic, -1, 1)))  # any degenerate case
  angles = angles.at[angle_first].set(jnp.where(case == 0, half_sum - half_diff, angles[angle_first]))
  angles = angles.at[angle_third].set(jnp.where(case == 0, half_sum + half_diff, angles[angle_third]))
  angles = angles.at[angle_third].set(jnp.where(symmetric, angles[angle_third], angles[angle_third] * sign))
  angles = angles.at[1].set(jnp.where(symmetric, angles[1], angles[1] - np.pi / 2))
  # Wrap all angles into [-pi, pi).
  angles = (angles + np.pi) % (2 * np.pi) - np.pi
  return jnp.where(degrees, jnp.rad2deg(angles), angles)
def _elementary_basis_index(axis: str) -> int:
if axis == 'x':
return 0
elif axis == 'y':
return 1
elif axis == 'z':
return 2
raise ValueError(f"Expected axis to be from ['x', 'y', 'z'], got {axis}")
@functools.partial(jnp_vectorize.vectorize, signature=('(m),(m),(),()->(n)'))
def _elementary_quat_compose(angles: Array, axes: Array, intrinsic: bool, degrees: bool) -> Array:
  """Compose single-axis rotations about `axes` by `angles` into one quaternion."""
  angles = jnp.where(degrees, jnp.deg2rad(angles), angles)
  result = _make_elementary_quat(axes[0], angles[0])
  for idx in range(1, len(axes)):
    quat = _make_elementary_quat(axes[idx], angles[idx])
    # Intrinsic rotations compose on the right; extrinsic on the left.
    result = jnp.where(intrinsic, _compose_quat(result, quat), _compose_quat(quat, result))
  return result
@functools.partial(jnp_vectorize.vectorize, signature=('(m),()->(n)'))
def _from_rotvec(rotvec: Array, degrees: bool) -> Array:
  """Convert a rotation vector to a quaternion [x, y, z, w].

  Args:
    rotvec: rotation vector; its norm is the rotation angle and its
      direction is the rotation axis.
    degrees: if true, interpret the magnitude of ``rotvec`` in degrees.

  Returns:
    quaternion of shape (4,) in scalar-last convention.
  """
  rotvec = jnp.where(degrees, jnp.deg2rad(rotvec), rotvec)
  angle = _vector_norm(rotvec)
  angle2 = angle * angle
  # For small angles, use the Taylor expansion of sin(angle/2)/angle to
  # avoid the 0/0 indeterminate form at angle == 0.
  # (The original chained assignment `small_scale = scale = ...` bound
  # `scale` redundantly; it was overwritten unconditionally below.)
  small_scale = 0.5 - angle2 / 48 + angle2 * angle2 / 3840
  large_scale = jnp.sin(angle / 2) / angle
  scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
  return jnp.hstack([scale * rotvec, jnp.cos(angle / 2)])
@functools.partial(jnp_vectorize.vectorize, signature=('(m,m)->(n)'))
def _from_matrix(matrix: Array) -> Array:
  """Convert a 3x3 rotation matrix to a quaternion [x, y, z, w].

  Selects among four algebraically equivalent extraction formulas — pivoting
  on one of the three diagonal entries or on the trace, whichever is largest
  — and normalizes the result (the same branch selection used by
  ``scipy.spatial.transform.Rotation.from_matrix``).
  """
  matrix_trace = matrix[0, 0] + matrix[1, 1] + matrix[2, 2]
  decision = jnp.array([matrix[0, 0], matrix[1, 1], matrix[2, 2], matrix_trace], dtype=matrix.dtype)
  # choice in {0, 1, 2}: pivot on diagonal entry i; choice == 3: pivot on trace.
  choice = jnp.argmax(decision)
  i = choice
  j = (i + 1) % 3
  k = (j + 1) % 3
  # Candidate quaternion when a diagonal entry is the pivot.
  quat_012 = jnp.empty(4, dtype=matrix.dtype)
  quat_012 = quat_012.at[i].set(1 - decision[3] + 2 * matrix[i, i])
  quat_012 = quat_012.at[j].set(matrix[j, i] + matrix[i, j])
  quat_012 = quat_012.at[k].set(matrix[k, i] + matrix[i, k])
  quat_012 = quat_012.at[3].set(matrix[k, j] - matrix[j, k])
  # Candidate quaternion when the trace is the pivot.
  quat_3 = jnp.empty(4, dtype=matrix.dtype)
  quat_3 = quat_3.at[0].set(matrix[2, 1] - matrix[1, 2])
  quat_3 = quat_3.at[1].set(matrix[0, 2] - matrix[2, 0])
  quat_3 = quat_3.at[2].set(matrix[1, 0] - matrix[0, 1])
  quat_3 = quat_3.at[3].set(1 + decision[3])
  quat = jnp.where(choice != 3, quat_012, quat_3)
  return _normalize_quaternion(quat)
@functools.partial(jnp_vectorize.vectorize, signature='(m)->(n)')
def _from_mrp(mrp: Array) -> Array:
mrp_squared_plus_1 = jnp.dot(mrp, mrp) + 1
return jnp.hstack([2 * mrp[:3], (2 - mrp_squared_plus_1)]) / mrp_squared_plus_1
@functools.partial(jnp_vectorize.vectorize, signature='(n)->(n)')
def _inv(quat: Array) -> Array:
return quat * jnp.array([-1, -1, -1, 1], dtype=quat.dtype)
@functools.partial(jnp_vectorize.vectorize, signature='(n)->()')
def _magnitude(quat: Array) -> Array:
  """Return the rotation angle (in [0, pi]) encoded by quaternion ``quat``."""
  sin_half = _vector_norm(quat[:3])
  cos_half = jnp.abs(quat[3])
  return 2. * jnp.arctan2(sin_half, cos_half)
@functools.partial(jnp_vectorize.vectorize, signature='(),()->(n)')
def _make_elementary_quat(axis: int, angle: Array) -> Array:
quat = jnp.zeros(4, dtype=angle.dtype)
quat = quat.at[3].set(jnp.cos(angle / 2.))
quat = quat.at[axis].set(jnp.sin(angle / 2.))
return quat
@functools.partial(jnp_vectorize.vectorize, signature='(n)->(n)')
def _normalize_quaternion(quat: Array) -> Array:
  """Rescale ``quat`` to unit Euclidean norm."""
  norm = _vector_norm(quat)
  return quat / norm
@functools.partial(jnp_vectorize.vectorize, signature='(n)->()')
def _vector_norm(vector: Array) -> Array:
return jnp.sqrt(jnp.dot(vector, vector))
@functools.partial(jnp_vectorize.vectorize, signature='(n)->(n)')
def _make_canonical(quat: Array) -> Array:
is_neg = quat < 0
is_zero = quat == 0
neg = (
is_neg[3]
| (is_zero[3] & is_neg[0])
| (is_zero[3] & is_zero[0] & is_neg[1])
| (is_zero[3] & is_zero[0] & is_zero[1] & is_neg[2])
)
return jnp.where(neg, -quat, quat)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,13 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,311 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
from typing import NamedTuple
import numpy as np
from jax._src import api
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import check_arraylike, promote_args_inexact
from jax._src.typing import ArrayLike, Array
from jax._src.util import canonicalize_axis
class ModeResult(NamedTuple):
  """Named result returned by :func:`mode`."""
  mode: Array   # modal (most common) value(s) along the reduced axis
  count: Array  # number of occurrences of each modal value
@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims'])
def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult:
  """Compute the mode (most common value) along an axis of an array.

  JAX implementation of :func:`scipy.stats.mode`.

  Args:
    a: arraylike
    axis: int, default=0. Axis along which to compute the mode.
    nan_policy: str. JAX only supports ``"propagate"``.
    keepdims: bool, default=False. If true, reduced axes are left in the result
      with size 1.

  Returns:
    A tuple of arrays, ``(mode, count)``. ``mode`` is the array of modal values,
    and ``count`` is the number of times each value appears in the input array.

  Examples:
    >>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
    >>> mode, count = jax.scipy.stats.mode(x)
    >>> mode, count
    (Array(4, dtype=int32), Array(3, dtype=int32))

    For multi dimensional arrays, ``jax.scipy.stats.mode`` computes the ``mode``
    and the corresponding ``count`` along ``axis=0``:

    >>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
    ...                 [3, 1, 3, 2, 1, 3],
    ...                 [1, 2, 2, 3, 1, 2]])
    >>> mode, count = jax.scipy.stats.mode(x1)
    >>> mode, count
    (Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))

    If ``axis=1``, ``mode`` and ``count`` will be computed along ``axis 1``.

    >>> mode, count = jax.scipy.stats.mode(x1, axis=1)
    >>> mode, count
    (Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))

    By default, ``jax.scipy.stats.mode`` reduces the dimension of the result.
    To keep the dimensions same as that of the input array, the argument
    ``keepdims`` must be set to ``True``.

    >>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
    >>> mode, count
    (Array([[1],
            [3],
            [2]], dtype=int32), Array([[3],
            [3],
            [3]], dtype=int32))
  """
  check_arraylike("mode", a)
  x = jnp.atleast_1d(a)
  if nan_policy not in ["propagate", "omit", "raise"]:
    raise ValueError(
        f"Illegal nan_policy value {nan_policy!r}; expected one of "
        "{'propagate', 'omit', 'raise'}"
    )
  if nan_policy == "omit":
    # TODO: return answer without nans included.
    raise NotImplementedError(
        f"Logic for `nan_policy` of {nan_policy} is not implemented"
    )
  if nan_policy == "raise":
    raise NotImplementedError(
        "In order to best JIT compile `mode`, we cannot know whether `x` contains nans. "
        "Please check if nans exist in `x` outside of the `mode` function."
    )
  if axis is not None:
    axis = canonicalize_axis(axis, x.ndim)
  # Compute the static output shape before any flattening/reshaping of `x`.
  input_shape = x.shape
  if keepdims:
    if axis is None:
      output_shape = tuple(1 for i in input_shape)
    else:
      output_shape = tuple(1 if i == axis else s for i, s in enumerate(input_shape))
  else:
    if axis is None:
      output_shape = ()
    else:
      output_shape = tuple(s for i, s in enumerate(input_shape) if i != axis)
  if axis is None:
    # Flattened input: reduce over everything as a single axis-0 column.
    axis = 0
    x = x.ravel()
  def _mode_helper(x: Array) -> tuple[Array, Array]:
    """Helper function to return mode and count of a given array."""
    if x.size == 0:
      return (jnp.array(np.nan, dtype=dtypes.default_float_dtype()),
              jnp.array(0, dtype=dtypes.default_float_dtype()))
    else:
      # size=x.size makes the unique computation statically shaped for jit.
      vals, counts = jnp.unique(x, return_counts=True, size=x.size)
      return vals[jnp.argmax(counts)], counts.max()
  # Move the reduced axis to the front, collapse all other axes into one
  # batch dimension, and vmap the helper over that batch dimension.
  x = jnp.moveaxis(x, axis, 0)
  x = x.reshape(x.shape[0], math.prod(x.shape[1:]))
  vals, counts = api.vmap(_mode_helper, in_axes=1)(x)
  return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape))
def invert_permutation(i: Array) -> Array:
  """Return the permutation array that undoes the permutation ``i``."""
  positions = jnp.arange(i.size, dtype=i.dtype)
  return jnp.empty_like(i).at[i].set(positions)
@api.jit(static_argnames=["method", "axis", "nan_policy"])
def rankdata(
    a: ArrayLike,
    method: str = "average",
    *,
    axis: int | None = None,
    nan_policy: str = "propagate",
) -> Array:
  """Compute the rank of data along an array axis.

  JAX implementation of :func:`scipy.stats.rankdata`.

  Ranks begin at 1, and the *method* argument controls how ties are handled.

  Args:
    a: arraylike
    method: str, default="average". Supported methods are
      ``("average", "min", "max", "dense", "ordinal")``
      For details, see the :func:`scipy.stats.rankdata` documentation.
    axis: optional integer. If not specified, the input array is flattened.
    nan_policy: str, JAX's implementation only supports ``"propagate"``.

  Returns:
    array of ranks along the specified axis.

  Examples:
    >>> x = jnp.array([10, 30, 20])
    >>> rankdata(x)
    Array([1., 3., 2.], dtype=float32)

    >>> x = jnp.array([1, 3, 2, 3])
    >>> rankdata(x)
    Array([1. , 3.5, 2. , 3.5], dtype=float32)
  """
  check_arraylike("rankdata", a)
  if nan_policy not in ["propagate", "omit", "raise"]:
    raise ValueError(
        f"Illegal nan_policy value {nan_policy!r}; expected one of "
        "{'propagate', 'omit', 'raise'}"
    )
  if nan_policy == "omit":
    raise NotImplementedError(
        f"Logic for `nan_policy` of {nan_policy} is not implemented"
    )
  if nan_policy == "raise":
    raise NotImplementedError(
        "In order to best JIT compile `rankdata`, we cannot know whether `x` "
        "contains nans. Please check if nans exist in `x` outside of the "
        "`rankdata` function."
    )
  if method not in ("average", "min", "max", "dense", "ordinal"):
    raise ValueError(f"unknown method '{method}'")
  if axis is not None:
    # Rank each 1D slice along `axis` independently.
    return jnp.apply_along_axis(rankdata, axis, a, method)
  a = jnp.ravel(a)
  out_dtype = dtypes.default_float_dtype()
  def _rankdata(a: Array) -> Array:
    # `sorter` holds the indices that sort `a`; `inv` maps each element of
    # `a` to its position in the sorted order.
    arr, sorter = lax.sort_key_val(a, jnp.arange(a.size))
    inv = invert_permutation(sorter)
    if method == "ordinal":
      return (inv + 1).astype(out_dtype)
    # `obs` marks the first occurrence of each distinct value in sorted
    # order; its cumsum gives 1-based dense ranks (ties share a rank, no gaps).
    obs = jnp.concatenate([jnp.array([True]), arr[1:] != arr[:-1]])
    dense = obs.cumsum()[inv]
    if method == "dense":
      return dense.astype(out_dtype)
    # `count[d]` is the sorted index where the d-th distinct value starts;
    # padded with obs.size so count[dense] yields each tie group's upper bound.
    count = jnp.nonzero(obs, size=arr.size + 1, fill_value=obs.size)[0].astype(out_dtype)
    if method == "max":
      return count[dense]
    if method == "min":
      return count[dense - 1] + 1
    if method == "average":
      return .5 * (count[dense] + count[dense - 1] + 1)
    raise ValueError(f"unknown method '{method}'")
  # nan_policy="propagate": if any element is NaN, every rank is NaN.
  return lax.cond(jnp.any(jnp.isnan(a)),
                  lambda a: jnp.full_like(a, jnp.nan, out_dtype),
                  _rankdata, a)
@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims'])
def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "propagate", *, keepdims: bool = False) -> Array:
  """Compute the standard error of the mean.

  JAX implementation of :func:`scipy.stats.sem`.

  Args:
    a: arraylike
    axis: optional integer. If not specified, the input array is flattened.
    ddof: integer, default=1. The degrees of freedom in the SEM computation.
    nan_policy: str, default="propagate". JAX supports only "propagate" and
      "omit".
    keepdims: bool, default=False. If true, reduced axes are left in the result
      with size 1.

  Returns:
    array

  Examples:
    >>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x)
    Array(0.41, dtype=float32)

    For multi dimensional arrays, ``sem`` computes standard error of mean along
    ``axis=0``:

    >>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
    ...                 [3, 1, 3, 2, 1, 3],
    ...                 [1, 2, 2, 3, 1, 2]])
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x1)
    Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32)

    If ``axis=1``, standard error of mean will be computed along ``axis 1``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x1, axis=1)
    Array([0.33, 0.4 , 0.31], dtype=float32)

    If ``axis=None``, standard error of mean will be computed along all the axes.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x1, axis=None)
    Array(0.2, dtype=float32)

    By default, ``sem`` reduces the dimension of the result. To keep the
    dimensions same as that of the input array, the argument ``keepdims`` must
    be set to ``True``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x1, axis=1, keepdims=True)
    Array([[0.33],
           [0.4 ],
           [0.31]], dtype=float32)

    Since, by default, ``nan_policy='propagate'``, ``sem`` propagates the ``nan``
    values in the result.

    >>> nan = np.nan
    >>> x2 = jnp.array([[1, 2, 3, nan, 4, 2],
    ...                 [4, 5, 4, 3, nan, 1],
    ...                 [7, nan, 8, 7, 9, nan]])
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x2)
    Array([1.73,  nan, 1.53,  nan,  nan,  nan], dtype=float32)

    If ``nan_policy='omit'``, ``sem`` omits the ``nan`` values and computes the error
    for the remaining values along the specified axis.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x2, nan_policy='omit')
    Array([1.73, 1.5 , 1.53, 2.  , 2.5 , 0.5 ], dtype=float32)
  """
  b, = promote_args_inexact("sem", a)
  if nan_policy == "propagate":
    # SEM = std / sqrt(N) over the reduced axis (or the full array).
    size = b.size if axis is None else b.shape[axis]
    return b.std(axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(size).astype(b.dtype)
  elif nan_policy == "omit":
    # Count only non-NaN elements so the denominator matches nanstd's sample.
    count = (~jnp.isnan(b)).sum(axis, keepdims=keepdims)
    return jnp.nanstd(b, axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(count).astype(b.dtype)
  else:
    raise ValueError(f"{nan_policy} is not supported")
@@ -0,0 +1,159 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import xlogy, xlog1py
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Bernoulli log probability mass function.

  JAX implementation of :obj:`scipy.stats.bernoulli` ``logpmf``

  The Bernoulli probability mass function is defined as

  .. math::

     f(k) = \begin{cases}
       1 - p, & k = 0 \\
       p, & k = 1 \\
       0, & \mathrm{otherwise}
     \end{cases}

  Args:
    k: arraylike, value at which to evaluate the PMF
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset

  Returns:
    array of logpmf values

  See Also:
    - :func:`jax.scipy.stats.bernoulli.cdf`
    - :func:`jax.scipy.stats.bernoulli.pmf`
    - :func:`jax.scipy.stats.bernoulli.ppf`
  """
  k, p, loc = promote_args_inexact("bernoulli.logpmf", k, p, loc)
  zero = _lax_const(k, 0)
  one = _lax_const(k, 1)
  x = lax.sub(k, loc)
  # log pmf = x log(p) + (1 - x) log(1 - p); xlogy/xlog1py implement the
  # 0 * log(0) = 0 convention so p in {0, 1} yields finite values on-support.
  log_probs = xlogy(x, p) + xlog1py(lax.sub(one, x), -p)
  # Outside the support {0, 1} the probability is 0, i.e. log prob is -inf.
  return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)),
                   -np.inf, log_probs)
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Bernoulli probability mass function.

  JAX implementation of :obj:`scipy.stats.bernoulli` ``pmf``, computed by
  exponentiating :func:`jax.scipy.stats.bernoulli.logpmf`.

  The Bernoulli probability mass function is defined as

  .. math::

     f(k) = \begin{cases}
       1 - p, & k = 0 \\
       p, & k = 1 \\
       0, & \mathrm{otherwise}
     \end{cases}

  Args:
    k: arraylike, value at which to evaluate the PMF
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset

  Returns:
    array of pmf values

  See Also:
    - :func:`jax.scipy.stats.bernoulli.cdf`
    - :func:`jax.scipy.stats.bernoulli.logpmf`
    - :func:`jax.scipy.stats.bernoulli.ppf`
  """
  log_probs = logpmf(k, p, loc)
  return jnp.exp(log_probs)
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
  r"""Bernoulli cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.bernoulli` ``cdf``

  The Bernoulli cumulative distribution function is defined as:

  .. math::

     f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(i, p)

  where :math:`f_{pmf}(i, p)` is the Bernoulli probability mass function
  :func:`jax.scipy.stats.bernoulli.pmf`.

  Args:
    k: arraylike, value at which to evaluate the CDF
    p: arraylike, distribution shape parameter

  Returns:
    array of cdf values

  See Also:
    - :func:`jax.scipy.stats.bernoulli.logpmf`
    - :func:`jax.scipy.stats.bernoulli.pmf`
    - :func:`jax.scipy.stats.bernoulli.ppf`
  """
  k, p = promote_args_inexact('bernoulli.cdf', k, p)
  zero, one = _lax_const(k, 0), _lax_const(k, 1)
  # jnp.select returns the value of the FIRST matching condition, so the
  # NaN / invalid-parameter case must be listed first.
  conds = [
      jnp.isnan(k) | jnp.isnan(p) | (p < zero) | (p > one),
      lax.lt(k, zero),
      jnp.logical_and(lax.ge(k, zero), lax.lt(k, one)),
      lax.ge(k, one)
  ]
  vals = [jnp.nan, zero, one - p, one]
  return jnp.select(conds, vals)
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
  """Bernoulli percent point function.

  JAX implementation of :obj:`scipy.stats.bernoulli` ``ppf``

  The percent point function is the inverse of the cumulative
  distribution function, :func:`jax.scipy.stats.bernoulli.cdf`.

  Args:
    q: arraylike, quantile at which to evaluate the PPF
    p: arraylike, distribution shape parameter

  Returns:
    array of ppf values

  See Also:
    - :func:`jax.scipy.stats.bernoulli.cdf`
    - :func:`jax.scipy.stats.bernoulli.logpmf`
    - :func:`jax.scipy.stats.bernoulli.pmf`
  """
  q, p = promote_args_inexact('bernoulli.ppf', q, p)
  zero, one = _lax_const(q, 0), _lax_const(q, 1)
  # NaN for out-of-range q or p; otherwise ppf(q) is 0 for q <= 1 - p
  # (the cdf value at k = 0) and 1 above it.
  return jnp.where(
    jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one),
    jnp.nan,
    jnp.where(lax.le(q, one - p), zero, one)
  )
@@ -0,0 +1,262 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import betaln, betainc, xlogy, xlog1py
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
           loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta log probability distribution function.

  JAX implementation of :obj:`scipy.stats.beta` ``logpdf``.

  The pdf of the beta function is:

  .. math::

     f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1}

  where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.
  It is defined for :math:`0\le x\le 1`, :math:`a>0`, and :math:`b>0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
  one = _lax_const(x, 1)
  zero = _lax_const(a, 0)
  # -log B(a, b): the log of the normalizing constant.
  shape_term = lax.neg(betaln(a, b))
  y = lax.div(lax.sub(x, loc), scale)
  # (a-1) log(y) + (b-1) log(1-y), with the 0*log(0) = 0 convention.
  log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
                            xlog1py(lax.sub(b, one), lax.neg(y)))
  log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
  # -inf outside the support [loc, loc + scale].
  result = jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)),
                                    lax.lt(x, loc)), -np.inf, log_probs)
  # NaN for invalid parameters: a <= 0, b <= 0, or scale <= 0.
  result_positive_constants = jnp.where(jnp.logical_or(jnp.logical_or(lax.le(a, zero), lax.le(b, zero)),
                                                       lax.le(scale, zero)), np.nan, result)
  return result_positive_constants
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
        loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta probability distribution function.

  JAX implementation of :obj:`scipy.stats.beta` ``pdf``, computed by
  exponentiating :func:`jax.scipy.stats.beta.logpdf`.

  The pdf of the beta function is:

  .. math::

     f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1}

  where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.
  It is defined for :math:`0\le x\le 1`, :math:`a>0`, and :math:`b>0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logpdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  log_density = logpdf(x, a, b, loc, scale)
  return lax.exp(log_density)
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
        loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta cumulative distribution function

  JAX implementation of :obj:`scipy.stats.beta` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y

  where :math:`f_{pdf}` is the beta distribution probability density function,
  :func:`jax.scipy.stats.beta.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values

  See Also:
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logpdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale)
  # Standardize x to (x - loc) / scale and clamp to [0, 1] so that values
  # outside the support map to cdf 0 or 1; betainc is the regularized
  # incomplete beta function, i.e. the standard beta cdf.
  return betainc(
    a,
    b,
    lax.clamp(
      _lax_const(x, 0),
      lax.div(lax.sub(x, loc), scale),
      _lax_const(x, 1),
    )
  )
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
           loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.beta` ``logcdf``, computed as the
  logarithm of :func:`jax.scipy.stats.beta.cdf`.

  The cdf is defined as

  .. math::

     f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y

  where :math:`f_{pdf}` is the beta distribution probability density function,
  :func:`jax.scipy.stats.beta.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logpdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  cdf_values = cdf(x, a, b, loc, scale)
  return lax.log(cdf_values)
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
       loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta distribution survival function.

  JAX implementation of :obj:`scipy.stats.beta` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b)

  where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function,
  :func:`jax.scipy.stats.beta.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logpdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
  # Uses the symmetry I_y(a, b) = 1 - I_{1-y}(b, a) of the regularized
  # incomplete beta function: sf(x) = betainc(b, a, 1 - y) with y the
  # standardized argument clamped to [0, 1].
  return betainc(
    b,
    a,
    1 - lax.clamp(
      _lax_const(x, 0),
      lax.div(lax.sub(x, loc), scale),
      _lax_const(x, 1),
    )
  )
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
          loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta distribution log survival function.

  JAX implementation of :obj:`scipy.stats.beta` ``logsf``, computed as the
  logarithm of :func:`jax.scipy.stats.beta.sf`.

  The survival function is defined as

  .. math::

     f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b)

  where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function,
  :func:`jax.scipy.stats.beta.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logpdf`
  """
  survival = sf(x, a, b, loc, scale)
  return lax.log(survival)
@@ -0,0 +1,98 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
from jax._src import api
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import betaln
from jax._src.typing import Array, ArrayLike
@api.jit
def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
           loc: ArrayLike = 0) -> Array:
  r"""Beta-binomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.betabinom` ``logpmf``

  The beta-binomial distribution's probability mass function is defined as

  .. math::

     f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k+b)}{B(a,b)}

  where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is
  defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpmf values

  See Also:
    :func:`jax.scipy.stats.betabinom.pmf`
  """
  k, n, a, b, loc = promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
  y = lax.sub(lax.floor(k), loc)
  one = _lax_const(y, 1)
  zero = _lax_const(y, 0)
  # log C(n, y) expressed via the beta function:
  # C(n, y) = 1 / ((n + 1) * B(n - y + 1, y + 1)).
  combiln = lax.neg(lax.add(lax.log1p(n), betaln(lax.add(lax.sub(n,y), one), lax.add(y,one))))
  # log(B(y + a, n - y + b) / B(a, b)).
  beta_lns = lax.sub(betaln(lax.add(y,a), lax.add(lax.sub(n,y),b)), betaln(a,b))
  log_probs = lax.add(combiln, beta_lns)
  # n == 0 with k at the offset is the certain event: log prob 0.
  log_probs = jnp.where(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)), 0., log_probs)
  # -inf (zero probability) outside the support.
  y_cond = jnp.logical_or(jnp.logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, n)),
                          lax.le(lax.add(y, a), zero))
  log_probs = jnp.where(y_cond, -np.inf, log_probs)
  # NaN for invalid shape parameters: n < 0, a <= 0, or b <= 0.
  n_a_b_cond = jnp.logical_or(jnp.logical_or(lax.lt(n, zero), lax.le(a, zero)), lax.le(b, zero))
  return jnp.where(n_a_b_cond, np.nan, log_probs)
def pmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
        loc: ArrayLike = 0) -> Array:
  r"""Beta-binomial probability mass function.

  JAX implementation of :obj:`scipy.stats.betabinom` ``pmf``.

  The beta-binomial distribution's probability mass function is defined as

  .. math::

     f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k+b)}{B(a,b)}

  where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is
  defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values

  See Also:
    :func:`jax.scipy.stats.betabinom.logpmf`
  """
  return lax.exp(logpmf(k, n, a, b, loc))
@@ -0,0 +1,90 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import promote_args_inexact
from jax._src.lax.lax import _const as _lax_const
from jax._src.scipy.special import gammaln, xlogy, xlog1py
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Binomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.binom` ``logpmf``.

  The binomial probability mass function is defined as

  .. math::

     f(k, n, p) = {n \choose k}p^k(1-p)^{n-k}

  for :math:`0\le p\le 1` and non-negative integers :math:`k`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpmf values.

  See Also:
    :func:`jax.scipy.stats.binom.pmf`
  """
  k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc)
  y = lax.sub(k, loc)
  zero = _lax_const(y, 0)
  # log C(n, y) via gammaln: lgamma(n+1) - lgamma(y+1) - lgamma(n-y+1).
  comb_term = lax.sub(
    gammaln(n + 1),
    lax.add(gammaln(y + 1), gammaln(n - y + 1))
  )
  # y log(p) + (n - y) log(1 - p), with the 0 * log(0) = 0 convention.
  log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p)))
  log_probs = lax.add(comb_term, log_linear_term)
  # Force log prob 0 (probability 1) in the deterministic cases: n == 0 with
  # k at the offset, or a linear term that is exactly zero.
  y_n_cond = jnp.logical_or(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)),
                            lax.eq(log_linear_term, zero))
  log_probs = jnp.where(y_n_cond, 0., log_probs)
  # Zero probability (-inf) outside the support loc <= k <= loc + n.
  return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -np.inf)
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Binomial probability mass function.

  JAX implementation of :obj:`scipy.stats.binom` ``pmf``, computed by
  exponentiating :func:`jax.scipy.stats.binom.logpmf`.

  The binomial probability mass function is defined as

  .. math::

     f(k, n, p) = {n \choose k}p^k(1-p)^{n-k}

  for :math:`0\le p\le 1` and non-negative integers :math:`k`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values.

  See Also:
    :func:`jax.scipy.stats.binom.logpmf`
  """
  log_probs = logpmf(k, n, p, loc)
  return lax.exp(log_probs)
@@ -0,0 +1,293 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.ufuncs import arctan
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy log probability distribution function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``logpdf``.

  The Cauchy probability distribution function is

  .. math::

     f(x) = \frac{1}{\pi(1 + x^2)}

  shifted by ``loc`` and stretched by ``scale``.

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values

  See Also:
    - :func:`jax.scipy.stats.cauchy.pdf`
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.sf`
  """
  x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale)
  z = lax.div(lax.sub(x, loc), scale)
  # log f = -[log(pi * scale) + log(1 + z**2)]; log1p keeps small z accurate.
  log_norm = lax.log(lax.mul(_lax_const(x, np.pi), scale))
  return lax.neg(lax.add(log_norm, lax.log1p(lax.mul(z, z))))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy probability distribution function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``pdf``.

  The Cauchy probability distribution function is

  .. math::

     f(x) = \frac{1}{\pi(1 + x^2)}

  Evaluated by exponentiating the log-density for numerical stability.

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values

  See Also:
    - :func:`jax.scipy.stats.cauchy.logpdf`
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.sf`
  """
  log_probs = logpdf(x, loc, scale)
  return lax.exp(log_probs)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where :math:`f_{pdf}` is the Cauchy probability distribution function,
  :func:`jax.scipy.stats.cauchy.pdf`. In closed form this is
  :math:`1/2 + \arctan(z)/\pi` with :math:`z = (x - loc)/scale`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.pdf`
    - :func:`jax.scipy.stats.cauchy.sf`
    - :func:`jax.scipy.stats.cauchy.ppf`
  """
  x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale)
  z = lax.div(lax.sub(x, loc), scale)
  half = _lax_const(x, 0.5)
  inv_pi = lax.div(_lax_const(x, 1.), _lax_const(x, np.pi))
  return lax.add(half, lax.mul(inv_pi, arctan(z)))
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``logcdf``

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where :math:`f_{pdf}` is the Cauchy probability distribution function,
  :func:`jax.scipy.stats.cauchy.pdf`; this function returns its logarithm.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.logsf`
    - :func:`jax.scipy.stats.cauchy.logpdf`
  """
  probs = cdf(x, loc, scale)
  return lax.log(probs)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy distribution log survival function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.cauchy.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values

  See Also:
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.isf`
    - :func:`jax.scipy.stats.cauchy.logsf`
  """
  x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
  # The Cauchy density is symmetric about loc, so sf(x; loc) == cdf(-x; -loc).
  return cdf(lax.neg(x), lax.neg(loc), scale)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy distribution log survival function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``logsf``

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.cauchy.cdf`; this function returns its logarithm.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.sf`
    - :func:`jax.scipy.stats.cauchy.logcdf`
    - :func:`jax.scipy.stats.cauchy.isf`
  """
  x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
  # Symmetry about loc: log sf(x; loc) == log cdf(-x; -loc).
  return logcdf(lax.neg(x), lax.neg(loc), scale)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy distribution inverse survival function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``isf``.

  Returns the inverse of the survival function,
  :func:`jax.scipy.stats.cauchy.sf`, i.e.
  :math:`loc + scale \cdot \tan(\pi/2 - \pi q)`.

  Args:
    q: arraylike, value at which to evaluate the ISF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of isf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.sf`
    - :func:`jax.scipy.stats.cauchy.ppf`
    - :func:`jax.scipy.stats.cauchy.cdf`
  """
  q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale)
  # isf(q) = loc + scale * tan(pi/2 - pi*q) (the cotangent of pi*q).
  angle = lax.sub(_lax_const(q, np.pi / 2), lax.mul(_lax_const(q, np.pi), q))
  return lax.add(lax.mul(lax.tan(angle), scale), loc)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy distribution percent point function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``ppf``.

  The percent point function is the inverse of the cumulative
  distribution function, :func:`jax.scipy.stats.cauchy.cdf`, i.e.
  :math:`loc + scale \cdot \tan(\pi q - \pi/2)`.

  Args:
    q: arraylike, value at which to evaluate the PPF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.isf`
    - :func:`jax.scipy.stats.cauchy.sf`
  """
  q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale)
  # ppf(q) = loc + scale * tan(pi*q - pi/2), mapping (0, 1) onto the real line.
  angle = lax.sub(lax.mul(_lax_const(q, np.pi), q), _lax_const(q, np.pi / 2))
  return lax.add(lax.mul(lax.tan(angle), scale), loc)
@@ -0,0 +1,267 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import gammainc, gammaincc
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square log probability distribution function.

  JAX implementation of :obj:`scipy.stats.chi2` ``logpdf``.

  The chi-square probability distribution function is given by:

  .. math::

     f(x, k) = \begin{cases}
       \frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\
       0 & \mathrm{otherwise}
     \end{cases}

  for :math:`k` degrees of freedom, and where :math:`\Gamma` is the
  :func:`~jax.scipy.special.gamma` function. JAX follows the scipy
  convention of using ``df`` to denote degrees of freedom.

  Args:
    x: arraylike, value at which to evaluate the PDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.pdf`
    - :func:`jax.scipy.stats.chi2.cdf`
    - :func:`jax.scipy.stats.chi2.sf`
  """
  x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
  one = _lax_const(x, 1)
  two = _lax_const(x, 2)
  y = lax.div(lax.sub(x, loc), scale)
  half_df = lax.div(df, two)
  # Unnormalized log-density: (df/2 - 1) * log(y) - y/2.
  kernel = lax.sub(lax.mul(lax.sub(half_df, one), lax.log(y)), lax.div(y, two))
  # log of the normalizing constant 2^{df/2} * Gamma(df/2); log(scale) accounts
  # for the change of variables y = (x - loc)/scale.
  log_norm = lax.add(lax.lgamma(half_df),
                     lax.div(lax.mul(lax.log(two), df), two))
  result = lax.sub(kernel, lax.add(log_norm, lax.log(scale)))
  # Density is zero (log-density -inf) below the support.
  return jnp.where(lax.lt(x, loc), -np.inf, result)
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square probability distribution function.

  JAX implementation of :obj:`scipy.stats.chi2` ``pdf``.

  The chi-square probability distribution function is given by:

  .. math::

     f(x, k) = \begin{cases}
       \frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\
       0 & \mathrm{otherwise}
     \end{cases}

  for :math:`k` degrees of freedom, where :math:`\Gamma` is the
  :func:`~jax.scipy.special.gamma` function. JAX follows the scipy
  convention of using ``df`` to denote degrees of freedom. Computed by
  exponentiating the log-density for numerical stability.

  Args:
    x: arraylike, value at which to evaluate the PDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.logpdf`
    - :func:`jax.scipy.stats.chi2.cdf`
    - :func:`jax.scipy.stats.chi2.sf`
  """
  log_probs = logpdf(x, df, loc, scale)
  return lax.exp(log_probs)
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.chi2` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.chi2.pdf`. It is computed as the regularized lower
  incomplete gamma function :math:`P(k/2, y/2)` with
  :math:`y = (x - loc)/scale`. JAX follows the scipy convention of using
  ``df`` to denote degrees of freedom.

  Args:
    x: arraylike, value at which to evaluate the CDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.pdf`
    - :func:`jax.scipy.stats.chi2.sf`
    - :func:`jax.scipy.stats.chi2.logcdf`
  """
  x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale)
  two = _lax_const(scale, 2)
  # Standardized argument y/2 == (x - loc)/(2*scale), clamped into [0, inf)
  # so points below the support yield a CDF of zero.
  arg = lax.div(lax.sub(x, loc), lax.mul(scale, two))
  arg = lax.clamp(_lax_const(x, 0), arg, _lax_const(x, np.inf))
  return gammainc(lax.div(df, two), arg)
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.chi2` ``logcdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.chi2.pdf`; this function returns its logarithm.
  JAX follows the scipy convention of using ``df`` to denote degrees of
  freedom.

  Args:
    x: arraylike, value at which to evaluate the CDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values

  See Also:
    - :func:`jax.scipy.stats.chi2.cdf`
    - :func:`jax.scipy.stats.chi2.logsf`
    - :func:`jax.scipy.stats.chi2.logpdf`
  """
  probs = cdf(x, df, loc, scale)
  return lax.log(probs)
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square survival function.

  JAX implementation of :obj:`scipy.stats.chi2` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x, k) = 1 - f_{cdf}(x, k)

  where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
  :func:`jax.scipy.stats.chi2.cdf`. It is computed as the regularized upper
  incomplete gamma function :math:`Q(k/2, y/2)` with
  :math:`y = (x - loc)/scale`. JAX follows the scipy convention of using
  ``df`` to denote degrees of freedom.

  Args:
    x: arraylike, value at which to evaluate the SF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.cdf`
    - :func:`jax.scipy.stats.chi2.logsf`
    - :func:`jax.scipy.stats.chi2.pdf`
  """
  x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
  two = _lax_const(scale, 2)
  # Standardized argument y/2 == (x - loc)/(2*scale), clamped into [0, inf)
  # so points below the support yield a survival probability of one.
  arg = lax.div(lax.sub(x, loc), lax.mul(scale, two))
  arg = lax.clamp(_lax_const(x, 0), arg, _lax_const(x, np.inf))
  return gammaincc(lax.div(df, two), arg)
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square log survival function.

  JAX implementation of :obj:`scipy.stats.chi2` ``logsf``.

  The survival function is defined as

  .. math::

     f_{sf}(x, k) = 1 - f_{cdf}(x, k)

  where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
  :func:`jax.scipy.stats.chi2.cdf`; this function returns its logarithm.
  JAX follows the scipy convention of using ``df`` to denote degrees of
  freedom.

  Args:
    x: arraylike, value at which to evaluate the SF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.sf`
    - :func:`jax.scipy.stats.chi2.logcdf`
    - :func:`jax.scipy.stats.chi2.logpdf`
  """
  survival = sf(x, df, loc, scale)
  return lax.log(survival)
@@ -0,0 +1,100 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
def _is_simplex(x: Array) -> Array:
x_sum = jnp.sum(x, axis=0)
return jnp.all(x > 0, axis=0) & (abs(x_sum - 1) < 1E-6)
def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array:
  r"""Dirichlet log probability distribution function.

  JAX implementation of :obj:`scipy.stats.dirichlet` ``logpdf``.

  The Dirichlet probability density function is

  .. math::

     f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}

  where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta`
  function in a :math:`K`-dimensional vector space.

  Args:
    x: arraylike, value at which to evaluate the PDF
    alpha: arraylike, distribution shape parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.dirichlet.pdf`
  """
  # Promote to a common inexact dtype before dispatching to the
  # shape-checked implementation.
  x, alpha = promote_dtypes_inexact(x, alpha)
  return _logpdf(x, alpha)
def _logpdf(x: Array, alpha: Array) -> Array:
  # Core Dirichlet log-density; assumes x and alpha already share an inexact
  # dtype (see logpdf). x may carry trailing batch dimensions.
  if alpha.ndim != 1:
    raise ValueError(
        f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
    )
  if x.shape[0] not in (alpha.shape[0], alpha.shape[0] - 1):
    raise ValueError(
        "`x` must have either the same number of entries as `alpha` "
        f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
    )
  one = _lax_const(x, 1)
  if x.shape[0] != alpha.shape[0]:
    # Final coordinate omitted by the caller: reconstruct it so the point
    # sums to one.
    last = lax.sub(one, x.sum(0, keepdims=True))
    x = jnp.concatenate([x, last], axis=0)
  # log B(alpha) = sum_i lgamma(alpha_i) - lgamma(sum_i alpha_i)
  log_beta = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
  if x.ndim > 1:
    # Broadcast alpha across the batch dimensions of x.
    alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
  log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0), log_beta)
  # Points off the simplex have zero density.
  return jnp.where(_is_simplex(x), log_probs, -np.inf)
def pdf(x: ArrayLike, alpha: ArrayLike) -> Array:
  r"""Dirichlet probability distribution function.

  JAX implementation of :obj:`scipy.stats.dirichlet` ``pdf``.

  The Dirichlet probability density function is

  .. math::

     f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}

  where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta`
  function in a :math:`K`-dimensional vector space. Computed by
  exponentiating the log-density for numerical stability.

  Args:
    x: arraylike, value at which to evaluate the PDF
    alpha: arraylike, distribution shape parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.dirichlet.logpdf`
  """
  log_probs = logpdf(x, alpha)
  return lax.exp(log_probs)
@@ -0,0 +1,270 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential log probability distribution function.

  JAX implementation of :obj:`scipy.stats.expon` ``logpdf``.

  The Exponential probability distribution function is

  .. math::

     f(x) = \begin{cases}
       e^{-x} & x \ge 0 \\
       0 & \mathrm{otherwise}
     \end{cases}

  shifted by ``loc`` and stretched by ``scale``.

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.expon.pdf`
    :func:`jax.scipy.stats.expon.cdf`
    :func:`jax.scipy.stats.expon.sf`
  """
  x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
  z = lax.div(lax.sub(x, loc), scale)
  # log f = -z - log(scale) on the support; density is zero below it.
  log_probs = lax.neg(lax.add(z, lax.log(scale)))
  return jnp.where(lax.lt(x, loc), -np.inf, log_probs)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential probability distribution function.

  JAX implementation of :obj:`scipy.stats.expon` ``pdf``.

  The Exponential probability distribution function is

  .. math::

     f(x) = \begin{cases}
       e^{-x} & x \ge 0 \\
       0 & \mathrm{otherwise}
     \end{cases}

  Computed by exponentiating the log-density for numerical stability.

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.expon.logpdf`
    :func:`jax.scipy.stats.expon.cdf`
    :func:`jax.scipy.stats.expon.sf`
  """
  log_probs = logpdf(x, loc, scale)
  return lax.exp(log_probs)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential cumulative density function.

  JAX implementation of :obj:`scipy.stats.expon` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y

  where :math:`f_{pdf}` is the exponential distribution probability density
  function, :func:`jax.scipy.stats.expon.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    :func:`jax.scipy.stats.expon.pdf`
    :func:`jax.scipy.stats.expon.sf`
    :func:`jax.scipy.stats.expon.ppf`
  """
  x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale)
  z = lax.div(lax.sub(loc, x), scale)  # -(x - loc)/scale
  # -expm1(-y) == 1 - exp(-y), computed stably for small y; zero below the
  # support.
  in_support = lax.neg(lax.expm1(z))
  return jnp.where(lax.lt(x, loc), jnp.zeros_like(z), in_support)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential log cumulative density function.

  JAX implementation of :obj:`scipy.stats.expon` ``logcdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y

  where :math:`f_{pdf}` is the exponential distribution probability density
  function, :func:`jax.scipy.stats.expon.pdf`; this function returns its
  logarithm.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values.

  See Also:
    :func:`jax.scipy.stats.expon.cdf`
    :func:`jax.scipy.stats.expon.logsf`
    :func:`jax.scipy.stats.expon.logpdf`
  """
  # log cdf = log(1 - sf); log1p keeps accuracy where sf is small.
  survival = sf(x, loc, scale)
  return lax.log1p(lax.neg(survival))
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential log survival function.

  JAX implementation of :obj:`scipy.stats.expon` ``logsf``.

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the exponential cumulative distribution
  function, :func:`jax.scipy.stats.expon.cdf`; this function returns its
  logarithm, which is :math:`-(x - loc)/scale` on the support.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    :func:`jax.scipy.stats.expon.sf`
    :func:`jax.scipy.stats.expon.logcdf`
    :func:`jax.scipy.stats.expon.logpdf`
  """
  x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale)
  z = lax.div(lax.sub(loc, x), scale)  # -(x - loc)/scale
  # Below the support sf == 1, so log sf == 0.
  return jnp.where(lax.lt(x, loc), jnp.zeros_like(z), z)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential survival function.

  JAX implementation of :obj:`scipy.stats.expon` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the exponential cumulative distribution
  function, :func:`jax.scipy.stats.expon.cdf`. Computed by exponentiating
  the log survival function.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    :func:`jax.scipy.stats.expon.cdf`
    :func:`jax.scipy.stats.expon.logsf`
    :func:`jax.scipy.stats.expon.pdf`
  """
  log_survival = logsf(x, loc, scale)
  return lax.exp(log_survival)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential percent point function.

  JAX implementation of :obj:`scipy.stats.expon` ``ppf``.

  The percent point function is the inverse of the cumulative
  distribution function, :func:`jax.scipy.stats.expon.cdf`:

  .. math::

     f_{ppf}(q) = \mathrm{loc} - \mathrm{scale}\cdot\log(1 - q)

  Args:
    q: arraylike, quantile at which to evaluate the PPF; valid for
      :math:`0 \le q \le 1`
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values; NaN where ``q`` is NaN or outside ``[0, 1]``.

  See Also:
    :func:`jax.scipy.stats.expon.cdf`
    :func:`jax.scipy.stats.expon.sf`
    :func:`jax.scipy.stats.expon.logcdf`
  """
  q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale)
  # ppf(q) = loc - scale * log1p(-q). The previous implementation computed
  # -log1p((loc - q) / scale), folding loc and scale into the log1p argument,
  # which is only correct for loc=0, scale=1; applying them as an affine
  # transform of the standardized quantile matches scipy.
  unscaled = lax.neg(lax.log1p(lax.neg(q)))
  return jnp.where(
      jnp.isnan(q) | (q < 0) | (q > 1),
      np.nan,
      lax.add(loc, lax.mul(scale, unscaled)),
  )
@@ -0,0 +1,237 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import gammaln, xlogy, gammainc, gammaincc
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma log probability distribution function.

  JAX implementation of :obj:`scipy.stats.gamma` ``logpdf``.

  The Gamma probability distribution is given by

  .. math::

     f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x}

  where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.
  It is defined for :math:`x \ge 0` and :math:`a > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.pdf`
    - :func:`jax.scipy.stats.gamma.cdf`
    - :func:`jax.scipy.stats.gamma.sf`
  """
  x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale)
  in_support = lax.ge(x, loc)
  one = _lax_const(x, 1)
  # Substitute y=1 outside the support to keep xlogy free of NaNs; the final
  # where() restores -inf there.
  y = jnp.where(in_support, lax.div(lax.sub(x, loc), scale), one)
  unnormalized = lax.sub(xlogy(lax.sub(a, one), y), y)
  # Normalizer: log Gamma(a) plus log(scale) for the change of variables.
  log_norm = lax.add(gammaln(a), lax.log(scale))
  return jnp.where(in_support, lax.sub(unnormalized, log_norm), -np.inf)
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma probability distribution function.

  JAX implementation of :obj:`scipy.stats.gamma` ``pdf``.

  The Gamma probability distribution is given by

  .. math::

     f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x}

  where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.
  It is defined for :math:`x \ge 0` and :math:`a > 0`. Computed by
  exponentiating the log-density for numerical stability.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.logpdf`
    - :func:`jax.scipy.stats.gamma.cdf`
    - :func:`jax.scipy.stats.gamma.sf`
  """
  log_probs = logpdf(x, a, loc, scale)
  return lax.exp(log_probs)
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.gamma` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.gamma.pdf`. It is the regularized lower incomplete
  gamma function :math:`P(a, y)` with :math:`y = (x - loc)/scale`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.pdf`
    - :func:`jax.scipy.stats.gamma.sf`
    - :func:`jax.scipy.stats.gamma.logcdf`
  """
  x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale)
  y = lax.div(lax.sub(x, loc), scale)
  # Clamp into [0, inf) so points below the support yield a CDF of zero.
  y = lax.clamp(_lax_const(x, 0), y, _lax_const(x, np.inf))
  return gammainc(a, y)
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.gamma` ``logcdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.gamma.pdf`; this function returns its logarithm.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.cdf`
    - :func:`jax.scipy.stats.gamma.logsf`
    - :func:`jax.scipy.stats.gamma.logpdf`
  """
  probs = cdf(x, a, loc, scale)
  return lax.log(probs)
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma survival function.

  JAX implementation of :obj:`scipy.stats.gamma` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x, k) = 1 - f_{cdf}(x, k)

  where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
  :func:`jax.scipy.stats.gamma.cdf`. It is the regularized upper incomplete
  gamma function :math:`Q(a, y)` with :math:`y = (x - loc)/scale`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.cdf`
    - :func:`jax.scipy.stats.gamma.logsf`
    - :func:`jax.scipy.stats.gamma.pdf`
  """
  x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
  y = lax.div(lax.sub(x, loc), scale)
  # Below the support the survival probability is one.
  return jnp.where(lax.lt(y, _lax_const(y, 0)), 1, gammaincc(a, y))
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma log survival function.

  JAX implementation of :obj:`scipy.stats.gamma` ``logsf``.

  The survival function is defined as

  .. math::

     f_{sf}(x, k) = 1 - f_{cdf}(x, k)

  where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
  :func:`jax.scipy.stats.gamma.cdf`; this function returns its logarithm.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.sf`
    - :func:`jax.scipy.stats.gamma.logcdf`
    - :func:`jax.scipy.stats.gamma.logpdf`
  """
  survival = sf(x, a, loc, scale)
  return lax.log(survival)
@@ -0,0 +1,103 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src import lax
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, beta: ArrayLike) -> Array:
  r"""Generalized normal log probability distribution function.

  JAX implementation of :obj:`scipy.stats.gennorm` ``logpdf``.

  The generalized normal PDF is

  .. math::
     f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta)

  where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function and
  :math:`\beta > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    beta: arraylike, distribution shape parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.gennorm.cdf`
    - :func:`jax.scipy.stats.gennorm.pdf`
  """
  x, beta = promote_args_inexact("gennorm.logpdf", x, beta)
  # log of the normalizing constant beta / (2 * Gamma(1/beta)).
  log_normalizer = lax.log(.5 * beta) - lax.lgamma(1 / beta)
  return log_normalizer - lax.abs(x) ** beta
def cdf(x: ArrayLike, beta: ArrayLike) -> Array:
  r"""Generalized normal cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.gennorm` ``cdf``.

  The cdf is

  .. math::
     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.gennorm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    beta: arraylike, distribution shape parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.gennorm.pdf`
    - :func:`jax.scipy.stats.gennorm.logpdf`
  """
  x, beta = promote_args_inexact("gennorm.cdf", x, beta)
  # P(1/beta, |x|^beta) gives 2*|F(x) - 1/2|; the sign of x selects the side.
  half_mass = lax.igamma(1 / beta, lax.abs(x) ** beta)
  return .5 * (1 + lax.sign(x) * half_mass)
def pdf(x: ArrayLike, beta: ArrayLike) -> Array:
  r"""Generalized normal probability distribution function.

  JAX implementation of :obj:`scipy.stats.gennorm` ``pdf``.

  The generalized normal PDF is

  .. math::
     f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta)

  where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function and
  :math:`\beta > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    beta: arraylike, distribution shape parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.gennorm.cdf`
    - :func:`jax.scipy.stats.gennorm.logpdf`
  """
  # Exponentiate the log-density for numerical consistency with logpdf.
  log_density = logpdf(x, beta)
  return lax.exp(log_density)
@@ -0,0 +1,81 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import xlog1py
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Geometric log probability mass function.

  JAX implementation of :obj:`scipy.stats.geom` ``logpmf``.

  The geometric PMF is

  .. math::
     f(k) = (1 - p)^{k-1}p

  for :math:`k\ge 1` and :math:`0 \le p \le 1`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpmf values.

  See Also:
    :func:`jax.scipy.stats.geom.pmf`
  """
  k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc)
  one = _lax_const(k, 1)
  shifted = lax.sub(k, loc)
  # (k - 1) * log1p(-p) + log(p); xlog1py keeps the k == 1, p == 1 case finite.
  log_pmf = xlog1py(lax.sub(shifted, one), -p) + lax.log(p)
  # Outside the support (k - loc <= 0) the probability is zero.
  return jnp.where(lax.le(shifted, _lax_const(k, 0)), -np.inf, log_pmf)
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Geometric probability mass function.

  JAX implementation of :obj:`scipy.stats.geom` ``pmf``.

  The geometric PMF is

  .. math::
     f(k) = (1 - p)^{k-1}p

  for :math:`k\ge 1` and :math:`0 \le p \le 1`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values.

  See Also:
    :func:`jax.scipy.stats.geom.logpmf`
  """
  # Exponentiate the log-mass; values outside the support map exp(-inf) -> 0.
  log_mass = logpmf(k, p, loc)
  return jnp.exp(log_mass)
@@ -0,0 +1,256 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax._src.scipy.special import xlogy, xlog1py
def logpdf(x: ArrayLike,
           loc: ArrayLike = 0,
           scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) log probability distribution function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``logpdf``.

  .. math::
    f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate log(pdf)
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logpdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.ppf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  x, loc, scale = promote_args_inexact("gumbel_l.logpdf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  z = lax.div(lax.sub(x, loc), scale)
  # logpdf = -log(scale) + z - exp(z)
  log_pdf = lax.add(xlogy(-1, scale), lax.sub(z, lax.exp(z)))
  # A non-positive scale is invalid; return NaN there, matching scipy.
  return jnp.where(scale_ok, log_pdf, np.nan)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) probability distribution function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``pdf``.

  .. math::
    f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate pdf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of pdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.ppf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  # Exponentiate the log-density computed by logpdf.
  log_density = logpdf(x, loc, scale)
  return lax.exp(log_density)
def logcdf(x: ArrayLike,
           loc: ArrayLike = 0,
           scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) log cumulative density function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``logcdf``.

  .. math::
    f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate log(cdf)
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logcdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.ppf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  x, loc, scale = promote_args_inexact("gumbel_l.logcdf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  z = lax.div(lax.sub(x, loc), scale)
  # cdf = 1 - exp(-exp(z)); expm1 keeps precision when exp(z) is tiny, and
  # this log(-expm1(...)) form is stable where log1p-based forms are not.
  log_cdf = lax.log(-lax.expm1(lax.neg(lax.exp(z))))
  return jnp.where(scale_ok, log_cdf, np.nan)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) cumulative density function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``cdf``.

  .. math::
    f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate cdf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of cdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.ppf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  # Exponentiate the stable log-cdf computed by logcdf.
  log_cdf = logcdf(x, loc, scale)
  return lax.exp(log_cdf)
def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) percent point function (inverse of CDF)

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``ppf``.

  .. math::
    F_{ppf}(p; \mu, \beta) = \mu + \beta \log\left( -\log(1 - p) \right)

  Args:
    p: ArrayLike, probability value (quantile) at which to evaluate ppf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of ppf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  p, loc, scale = promote_args_inexact("gumbel_l.ppf", p, loc, scale)
  # The quantile is only defined on the open interval 0 < p < 1.
  in_open_unit = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)),
                                 lax.lt(p, _lax_const(p, 1)))
  # quantile = loc + scale * log(-log(1 - p)); xlog1py(-1, -p) = -log1p(-p)
  # evaluates -log(1 - p) stably for p near 0.
  inner = xlog1py(-1, lax.neg(p))
  quantile = lax.add(loc, lax.mul(scale, lax.log(inner)))
  return jnp.where(in_open_unit, quantile, np.nan)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) survival function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``sf``.

  .. math::
    f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta)

  Args:
    x: ArrayLike, value at which to evaluate survival function
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of sf values (1 - cdf)

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
  """
  # Exponentiate the log survival function computed by logsf.
  log_survival = logsf(x, loc, scale)
  return jnp.exp(log_survival)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) log survival function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``logsf``.

  .. math::
    f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta)

  Args:
    x: ArrayLike, value at which to evaluate log survival function
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logsf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  x, loc, scale = promote_args_inexact("gumbel_l.logsf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  # sf = exp(-exp(z)), so the log survival function is exactly -exp(z).
  z = lax.div(lax.sub(x, loc), scale)
  return jnp.where(scale_ok, lax.neg(lax.exp(z)), np.nan)
@@ -0,0 +1,257 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
from jax._src.scipy.special import xlogy
from jax._src.nn.functions import log1mexp
def logpdf(x: ArrayLike,
           loc: ArrayLike = 0,
           scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) log probability distribution function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``logpdf``.

  .. math::
    f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate log(pdf)
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logpdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.ppf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  x, loc, scale = promote_args_inexact("gumbel_r.logpdf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  z = lax.div(lax.sub(x, loc), scale)
  # logpdf = -log(scale) - (z + exp(-z))
  shape_term = lax.neg(lax.add(z, lax.exp(lax.neg(z))))
  log_pdf = lax.add(xlogy(-1, scale), shape_term)
  # A non-positive scale is invalid; return NaN there, matching scipy.
  return jnp.where(scale_ok, log_pdf, np.nan)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) probability distribution function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``pdf``.

  .. math::
    f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate pdf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of pdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.ppf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  # Exponentiate the log-density computed by logpdf.
  log_density = logpdf(x, loc, scale)
  return lax.exp(log_density)
def logcdf(x: ArrayLike,
           loc: ArrayLike = 0,
           scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) log cumulative density function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``logcdf``.

  .. math::
    f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate log(cdf)
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logcdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.ppf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  x, loc, scale = promote_args_inexact("gumbel_r.logcdf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  z = lax.div(lax.sub(x, loc), scale)
  # cdf = exp(-exp(-z)), so the log cdf is exactly -exp(-z).
  log_cdf = lax.neg(lax.exp(lax.neg(z)))
  return jnp.where(scale_ok, log_cdf, np.nan)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) cumulative density function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``cdf``.

  .. math::
    f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate cdf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of cdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.ppf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  # Exponentiate the log-cdf computed by logcdf.
  log_cdf = logcdf(x, loc, scale)
  return lax.exp(log_cdf)
def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) percent point function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``ppf``.

  .. math::
    F(p; \mu, \beta) = \mu - \beta \log\left( -\log(p) \right)

  Args:
    p: ArrayLike, probability value (quantile) at which to evaluate ppf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of ppf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  p, loc, scale = promote_args_inexact("gumbel_r.ppf", p, loc, scale)
  # The quantile is only defined on the open interval 0 < p < 1.
  in_open_unit = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)),
                                 lax.lt(p, _lax_const(p, 1)))
  # quantile = loc - scale * log(-log(p)); xlogy(-1, p) computes -log(p).
  inner = xlogy(-1, p)
  quantile = lax.sub(loc, lax.mul(scale, lax.log(inner)))
  return jnp.where(in_open_unit, quantile, np.nan)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) survival function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``sf``.

  .. math::
    f_{sf}(x; \mu, \beta) = 1 - F_{cdf}(x; \mu, \beta)

  Args:
    x: ArrayLike, value at which to evaluate survival function
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of sf values (1 - cdf)

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  x, loc, scale = promote_args_inexact("gumbel_r.sf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  # sf = 1 - exp(-exp(-z)), built from -z = (loc - x) / scale directly.
  neg_z = lax.div(lax.sub(loc, x), scale)
  cdf_value = lax.exp(lax.neg(lax.exp(neg_z)))
  survival = lax.sub(_lax_const(x, 1), cdf_value)
  return jnp.where(scale_ok, survival, np.nan)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""
Gumbel Distribution (Right Skewed) log survival function.
JAX implementation of :obj:`scipy.stats.gumbel_r` ``logsf``.
Args:
x: ArrayLike, value at which to evaluate log survival function
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
Returns:
array of logsf values
See Also:
- :func:`jax.scipy.stats.gumbel_r.logpdf`
- :func:`jax.scipy.stats.gumbel_r.pdf`
- :func:`jax.scipy.stats.gumbel_r.logcdf`
- :func:`jax.scipy.stats.gumbel_r.cdf`
- :func:`jax.scipy.stats.gumbel_r.sf`
"""
x, loc, scale = promote_args_inexact("gumbel_r.logsf", x, loc, scale)
ok = lax.gt(scale, _lax_const(scale, 0))
# logsf = log(1 - exp(-exp(-z)))
neg_z = lax.div(lax.sub(loc, x), scale)
log_sf = log1mexp(lax.exp(neg_z))
return jnp.where(ok, log_sf, np.nan)
@@ -0,0 +1,284 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, cast
import numpy as np
from jax._src import api
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src import random
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
from jax._src.scipy import linalg, special
from jax._src.tree_util import register_pytree_node_class
from jax._src.typing import Array
# Bandwidth specification accepted by gaussian_kde: None (defaults to "scott"),
# a rule name ("scott"/"silverman"), a scalar factor, or a callable that
# receives the estimator instance and returns the factor.
BwMethod = None | str | Array | Callable[[Any], Array]
@register_pytree_node_class
@dataclass(frozen=True, init=False)
class gaussian_kde:
  """Gaussian Kernel Density Estimator

  JAX implementation of :class:`scipy.stats.gaussian_kde`.

  Parameters:
    dataset: arraylike, real-valued. Data from which to estimate the distribution.
      If 1D, shape is (n_data,). If 2D, shape is (n_dimensions, n_data).
    bw_method: string, scalar, or callable. Either "scott", "silverman", a scalar
      value, or a callable function which takes ``self`` as a parameter.
    weights: arraylike, optional. Weights of the same shape as the dataset.
  """
  # Effective sample size, 1 / sum(w_i^2); equals n for uniform weights.
  neff: Any
  # (d, n) data array the density is estimated from.
  dataset: Any
  # (n,) normalized sample weights (sum to 1).
  weights: Any
  # Kernel covariance: data covariance times the squared bandwidth factor.
  covariance: Any
  # Inverse of the kernel covariance (precision matrix).
  inv_cov: Any
  def __init__(self, dataset, bw_method: BwMethod = None, weights=None):
    """Validate inputs, normalize weights, and compute the kernel bandwidth."""
    check_arraylike("gaussian_kde", dataset)
    # 1D input is treated as a single-dimension dataset of shape (1, n).
    dataset = jnp.atleast_2d(dataset)
    if dtypes.issubdtype(lax.dtype(dataset), np.complexfloating):
      raise NotImplementedError("gaussian_kde does not support complex data")
    if not dataset.size > 1:
      raise ValueError("`dataset` input should have multiple elements.")
    d, n = dataset.shape
    if weights is not None:
      check_arraylike("gaussian_kde", weights)
      dataset, weights = promote_dtypes_inexact(dataset, weights)
      weights = jnp.atleast_1d(weights)
      # Normalize so the weights sum to 1.
      weights /= jnp.sum(weights)
      if weights.ndim != 1:
        raise ValueError("`weights` input should be one-dimensional.")
      if len(weights) != n:
        raise ValueError("`weights` input should be of length n")
    else:
      # Default to uniform weights over the n samples.
      dataset, = promote_dtypes_inexact(dataset)
      weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)
    # The dataclass is frozen, so attributes are set via object.__setattr__
    # through the _setattr helper.
    self._setattr("dataset", dataset)
    self._setattr("weights", weights)
    neff = self._setattr("neff", 1 / jnp.sum(weights**2))
    bw_method = "scott" if bw_method is None else bw_method
    # Bandwidth factor per the requested rule (Scott/Silverman), a literal
    # scalar, or a user-supplied callable receiving this estimator.
    if bw_method == "scott":
      factor = jnp.power(neff, -1. / (d + 4))
    elif bw_method == "silverman":
      factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
    elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
      factor = cast(Array, bw_method)
    elif callable(bw_method):
      factor = bw_method(self)
    else:
      raise ValueError(
          "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
      )
    # Weighted, unbiased data covariance scaled by the bandwidth factor.
    data_covariance = jnp.atleast_2d(
        jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
    data_inv_cov = jnp.linalg.inv(data_covariance)
    covariance = data_covariance * factor**2
    inv_cov = data_inv_cov / factor**2
    self._setattr("covariance", covariance)
    self._setattr("inv_cov", inv_cov)
  def _setattr(self, name, value):
    # Frozen dataclasses don't support setting attributes so we have to
    # overload that operation here as they do in the dataclass implementation
    object.__setattr__(self, name, value)
    return value
  def tree_flatten(self):
    # Pytree protocol: all attributes are leaves; no static aux data.
    return ((self.neff, self.dataset, self.weights, self.covariance,
             self.inv_cov), None)
  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # Rebuild without running __init__ (children may be tracers).
    del aux_data
    kde = cls.__new__(cls)
    kde._setattr("neff", children[0])
    kde._setattr("dataset", children[1])
    kde._setattr("weights", children[2])
    kde._setattr("covariance", children[3])
    kde._setattr("inv_cov", children[4])
    return kde
  @property
  def d(self):
    """Number of dimensions of the dataset."""
    return self.dataset.shape[0]
  @property
  def n(self):
    """Number of data points in the dataset."""
    return self.dataset.shape[1]
  def evaluate(self, points):
    """Evaluate the Gaussian KDE on the given points."""
    check_arraylike("evaluate", points)
    points = self._reshape_points(points)
    result = _gaussian_kernel_eval(False, self.dataset.T, self.weights[:, None],
                                   points.T, self.inv_cov)
    return result[:, 0]
  def __call__(self, points):
    """Alias for :meth:`evaluate`."""
    return self.evaluate(points)
  def integrate_gaussian(self, mean, cov):
    """Integrate the distribution weighted by a Gaussian."""
    mean = jnp.atleast_1d(jnp.squeeze(mean))
    cov = jnp.atleast_2d(cov)
    if mean.shape != (self.d,):
      raise ValueError(f"mean does not have dimension {self.d}")
    if cov.shape != (self.d, self.d):
      raise ValueError(f"covariance does not have dimension {self.d}")
    # The product of two Gaussians has combined covariance; factor it once.
    chol = linalg.cho_factor(self.covariance + cov)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm
    return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights,
                                     mean)
  @api.jit
  def integrate_box_1d(self, low, high):
    """Integrate the distribution over the given limits."""
    if self.d != 1:
      raise ValueError("integrate_box_1d() only handles 1D pdfs")
    if np.ndim(low) != 0 or np.ndim(high) != 0:
      raise ValueError(
          "the limits of integration in integrate_box_1d must be scalars")
    # Standardize limits per sample, then use the normal CDF difference.
    sigma = jnp.squeeze(jnp.sqrt(self.covariance))
    low = jnp.squeeze((low - self.dataset) / sigma)
    high = jnp.squeeze((high - self.dataset) / sigma)
    return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low)))
  def integrate_kde(self, other):
    """Integrate the product of two Gaussian KDE distributions."""
    if other.d != self.d:
      raise ValueError("KDEs are not the same dimensionality")
    chol = linalg.cho_factor(self.covariance + other.covariance)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm
    # Iterate over the smaller dataset for efficiency.
    sm, lg = (self, other) if self.n < other.n else (other, self)
    result = api.vmap(partial(_gaussian_kernel_convolve, chol, norm, lg.dataset,
                              lg.weights),
                      in_axes=1)(sm.dataset)
    return jnp.sum(result * sm.weights)
  @partial(api.jit, static_argnames=("shape",))
  def resample(self, key, shape=()):
    r"""Randomly sample a dataset from the estimated pdf

    Args:
      key: a PRNG key used as the random key.
      shape: optional, a tuple of nonnegative integers specifying the result
        batch shape; that is, the prefix of the result shape excluding the last
        axis.

    Returns:
      The resampled dataset as an array with shape `(d,) + shape`.
    """
    # Pick data points by weight, then jitter each by kernel-covariance noise.
    ind_key, eps_key = random.split(key)
    ind = random.choice(ind_key, self.n, shape=shape, p=self.weights)
    eps = random.multivariate_normal(eps_key,
                                     jnp.zeros(self.d, self.covariance.dtype),
                                     self.covariance,
                                     shape=shape,
                                     dtype=self.dataset.dtype).T
    return self.dataset[:, ind] + eps
  def pdf(self, x):
    """Probability density function"""
    return self.evaluate(x)
  def logpdf(self, x):
    """Log probability density function"""
    check_arraylike("logpdf", x)
    x = self._reshape_points(x)
    result = _gaussian_kernel_eval(True, self.dataset.T, self.weights[:, None],
                                   x.T, self.inv_cov)
    return result[:, 0]
  def integrate_box(self, low_bounds, high_bounds, maxpts=None):
    """This method is not implemented in the JAX interface."""
    del low_bounds, high_bounds, maxpts
    raise NotImplementedError(
        "only 1D box integrations are supported; use `integrate_box_1d`")
  def set_bandwidth(self, bw_method=None):
    """This method is not implemented in the JAX interface."""
    del bw_method
    raise NotImplementedError(
        "dynamically changing the bandwidth method is not supported")
  def _reshape_points(self, points):
    # Coerce input to shape (d, m); a (1, d) row is interpreted as a single
    # d-dimensional point and transposed to (d, 1).
    if dtypes.issubdtype(lax.dtype(points), np.complexfloating):
      raise NotImplementedError(
          "gaussian_kde does not support complex coordinates")
    points = jnp.atleast_2d(points)
    d, m = points.shape
    if d != self.d:
      if d == 1 and m == self.d:
        points = jnp.reshape(points, (self.d, 1))
      else:
        raise ValueError(
            "points have dimension {}, dataset has dimension {}".format(
                d, self.d))
    return points
def _gaussian_kernel_convolve(chol, norm, target, weights, mean):
diff = target - mean[:, None]
alpha = linalg.cho_solve(chol, diff)
arg = 0.5 * jnp.sum(diff * alpha, axis=0)
return norm * jnp.sum(jnp.exp(-arg) * weights)
@api.jit(static_argnums=0)
def _gaussian_kernel_eval(in_log, points, values, xi, precision):
points, values, xi, precision = promote_dtypes_inexact(
points, values, xi, precision)
d = points.shape[1]
if xi.shape[1] != d:
raise ValueError("points and xi must have same trailing dim")
if precision.shape != (d, d):
raise ValueError("precision matrix must match data dims")
whitening = linalg.cholesky(precision, lower=True)
points = jnp.dot(points, whitening)
xi = jnp.dot(xi, whitening)
log_norm = jnp.sum(jnp.log(
jnp.diag(whitening))) - 0.5 * d * jnp.log(2 * np.pi)
def kernel(x_test, x_train, y_train):
arg = log_norm - 0.5 * jnp.sum(jnp.square(x_train - x_test))
if in_log:
return jnp.log(y_train) + arg
else:
return y_train * jnp.exp(arg)
reduce = special.logsumexp if in_log else jnp.sum
reduced_kernel = lambda x: reduce(api.vmap(kernel, in_axes=(None, 0, 0))
(x, points, values),
axis=0)
mapped_kernel = api.vmap(reduced_kernel)
return mapped_kernel(xi)
@@ -0,0 +1,109 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Laplace log probability distribution function.

  JAX implementation of :obj:`scipy.stats.laplace` ``logpdf``.

  The Laplace PDF is

  .. math::
     f(x) = \frac{1}{2} e^{-|x|}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.laplace.cdf`
    - :func:`jax.scipy.stats.laplace.pdf`
  """
  x, loc, scale = promote_args_inexact("laplace.logpdf", x, loc, scale)
  two = _lax_const(x, 2)
  # logpdf = -|x - loc| / scale - log(2 * scale)
  scaled_abs_dev = lax.div(lax.abs(lax.sub(x, loc)), scale)
  return lax.neg(lax.add(scaled_abs_dev, lax.log(lax.mul(two, scale))))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Laplace probability distribution function.

  JAX implementation of :obj:`scipy.stats.laplace` ``pdf``.

  The Laplace PDF is

  .. math::
     f(x) = \frac{1}{2} e^{-|x|}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.laplace.cdf`
    - :func:`jax.scipy.stats.laplace.logpdf`
  """
  # Exponentiate the log-density computed by logpdf.
  log_density = logpdf(x, loc, scale)
  return lax.exp(log_density)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Laplace cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.laplace` ``cdf``.

  The cdf is defined as

  .. math::
     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.laplace.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.laplace.pdf`
    - :func:`jax.scipy.stats.laplace.logpdf`
  """
  x, loc, scale = promote_args_inexact("laplace.cdf", x, loc, scale)
  z = lax.div(lax.sub(x, loc), scale)
  half = _lax_const(x, 0.5)
  one = _lax_const(x, 1)
  zero = _lax_const(x, 0)
  # Closed-form Laplace CDF: exp(z)/2 on the left tail, 1 - exp(-z)/2 on the
  # right; both branches agree (= 1/2) at z == 0.
  lower_tail = lax.mul(half, lax.exp(z))
  upper_tail = lax.sub(one, lax.mul(half, lax.exp(lax.neg(z))))
  return lax.select(lax.le(z, zero), lower_tail, upper_tail)
@@ -0,0 +1,204 @@
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import expit, logit
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Logistic log probability distribution function.

  JAX implementation of :obj:`scipy.stats.logistic` ``logpdf``.

  The logistic probability distribution function is given by

  .. math::
     f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.cdf`
    - :func:`jax.scipy.stats.logistic.pdf`
    - :func:`jax.scipy.stats.logistic.sf`
    - :func:`jax.scipy.stats.logistic.isf`
    - :func:`jax.scipy.stats.logistic.ppf`
  """
  x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale)
  x = lax.div(lax.sub(x, loc), scale)
  two = _lax_const(x, 2)
  half_x = lax.div(x, two)
  # log f(z) = -2 * log(e^{z/2} + e^{-z/2}), computed stably via logaddexp;
  # the trailing -log(scale) accounts for the change of variables
  # z = (x - loc) / scale.
  return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Logistic probability distribution function.

  JAX implementation of :obj:`scipy.stats.logistic` ``pdf``.

  The logistic probability distribution function is given by

  .. math::
     f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.cdf`
    - :func:`jax.scipy.stats.logistic.logpdf`
    - :func:`jax.scipy.stats.logistic.ppf`
  """
  # Defined in terms of the numerically stable log-density.
  log_density = logpdf(x, loc, scale)
  return lax.exp(log_density)
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Logistic distribution percent point function.

  JAX implementation of :obj:`scipy.stats.logistic` ``ppf``.

  The percent point function is defined as the inverse of the
  cumulative distribution function, :func:`jax.scipy.stats.logistic.cdf`.

  Args:
    x: arraylike, value at which to evaluate the PPF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.cdf`
    - :func:`jax.scipy.stats.logistic.isf`
    - :func:`jax.scipy.stats.logistic.sf`
  """
  x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale)
  # The standard-logistic quantile is the logit; rescale and shift it.
  quantile = logit(x)
  return lax.add(lax.mul(quantile, scale), loc)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Logistic distribution survival function.

  JAX implementation of :obj:`scipy.stats.logistic` ``sf``.

  The survival function is defined as

  .. math::
     f_{sf}(x, k) = 1 - f_{cdf}(x, k)

  where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
  :func:`jax.scipy.stats.logistic.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.cdf`
    - :func:`jax.scipy.stats.logistic.isf`
  """
  x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale)
  # 1 - cdf(z) = sigmoid(-z) for the logistic distribution.
  z = lax.div(lax.sub(x, loc), scale)
  return expit(lax.neg(z))
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Logistic distribution inverse survival function.

  JAX implementation of :obj:`scipy.stats.logistic` ``isf``.

  Returns the inverse of the survival function,
  :func:`jax.scipy.stats.logistic.sf`.

  Args:
    x: arraylike, value at which to evaluate the ISF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of isf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.sf`
    - :func:`jax.scipy.stats.logistic.ppf`
  """
  x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale)
  # sf is symmetric to cdf, so its inverse is the negated logit.
  neg_quantile = lax.neg(logit(x))
  return lax.add(lax.mul(neg_quantile, scale), loc)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Logistic cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.logistic` ``cdf``.

  The cdf is defined as

  .. math::
     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.logistic.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.pdf`
    - :func:`jax.scipy.stats.logistic.sf`
    - :func:`jax.scipy.stats.logistic.ppf`
  """
  x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale)
  # The logistic CDF is exactly the sigmoid of the standardized value.
  z = lax.div(lax.sub(x, loc), scale)
  return expit(z)
@@ -0,0 +1,83 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import promote_args_inexact, promote_args_numeric
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
  r"""Multinomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.multinomial` ``logpdf``.

  The multinomial probability distribution is given by

  .. math::
     f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}

  with :math:`n = \sum_i x_i`.

  Args:
    x: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter

  Returns:
    array of logpmf values.

  See Also:
    :func:`jax.scipy.stats.multinomial.pmf`
  """
  p, = promote_args_inexact("multinomial.logpmf", p)
  x, n = promote_args_numeric("multinomial.logpmf", x, n)
  if not dtypes.issubdtype(x.dtype, np.integer):
    raise ValueError(f"x and n must be of integer type; got x.dtype={x.dtype}, n.dtype={n.dtype}")
  # Cast counts to the probability dtype for the gamma-function arithmetic.
  x, n = x.astype(p.dtype), n.astype(p.dtype)
  # log f = log n! + sum_i (x_i log p_i - log x_i!), with the sum over the
  # trailing axis of x.
  unnormalized = jnp.sum(xlogy(x, p) - gammaln(x + 1), axis=-1)
  log_probs = gammaln(n + 1) + unnormalized
  # The pmf is zero (log-pmf is -inf) unless the counts sum to n.
  counts_match = jnp.equal(jnp.sum(x), n)
  return jnp.where(counts_match, log_probs, -np.inf)
def pmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
  r"""Multinomial probability mass function.

  JAX implementation of :obj:`scipy.stats.multinomial` ``pmf``.

  The multinomial probability distribution is given by

  .. math::
     f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}

  with :math:`n = \sum_i x_i`.

  Args:
    x: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter

  Returns:
    array of pmf values

  See Also:
    :func:`jax.scipy.stats.multinomial.logpmf`
  """
  # Exponentiate the log-pmf so the zero-probability branch maps to 0.
  log_mass = logpmf(x, n, p)
  return lax.exp(log_mass)
@@ -0,0 +1,103 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy import einsum as jnp_einsum
from jax._src.numpy import vectorize as jnp_vectorize
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> Array:
  r"""Multivariate normal log probability distribution function.

  JAX implementation of :obj:`scipy.stats.multivariate_normal` ``logpdf``.

  The multivariate normal PDF is defined as

  .. math::
     f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)

  where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and
  :math:`k` is the rank of :math:`\Sigma`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    mean: arraylike, centroid of distribution
    cov: arraylike, covariance matrix of distribution
    allow_singular: not supported

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.multivariate_normal.pdf`
  """
  if allow_singular is not None:
    # scipy falls back to a pseudo-inverse for singular covariances; that
    # path is not implemented here, so only the default None is accepted.
    raise NotImplementedError("allow_singular argument of multivariate_normal.logpdf")
  x, mean, cov = promote_dtypes_inexact(x, mean, cov)
  if not mean.shape:
    # Scalar mean: degenerates to a univariate normal with variance `cov`.
    return (-1/2 * jnp.square(x - mean) / cov
            - 1/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
  else:
    n = mean.shape[-1]
    if not np.shape(cov):
      # Scalar covariance: isotropic case, Sigma = cov * I_n.
      y = x - mean
      return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) / cov
              - n/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
    else:
      if cov.ndim < 2 or cov.shape[-2:] != (n, n):
        raise ValueError("multivariate_normal.logpdf got incompatible shapes")
      # Full covariance: factor Sigma = L L^T (Cholesky) so the quadratic
      # form and log-determinant are obtained without an explicit inverse.
      # The triangular solve is vectorized over any leading batch dims.
      L = lax.linalg.cholesky(cov)
      y = jnp_vectorize.vectorize(
          partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
          signature="(n,n),(n)->(n)"
      )(L, x - mean)
      # log det Sigma = 2 * sum(log diag(L)); the 1/2 factor cancels the 2.
      return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi)
              - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array:
  r"""Multivariate normal probability distribution function.

  JAX implementation of :obj:`scipy.stats.multivariate_normal` ``pdf``.

  The multivariate normal PDF is defined as

  .. math::
     f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)

  where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and
  :math:`k` is the rank of :math:`\Sigma`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    mean: arraylike, centroid of distribution
    cov: arraylike, covariance matrix of distribution

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.multivariate_normal.logpdf`
  """
  # Exponentiate the (numerically stable) log-density.
  return lax.exp(logpdf(x, mean, cov))
@@ -0,0 +1,86 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.special import gammaln, xlogy
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Negative-binomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.nbinom` ``logpmf``.

  The negative-binomial probability mass function is given by

  .. math::
     f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k

  for :math:`k \ge 0` and :math:`0 \le p \le 1`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.nbinom.pmf`
  """
  k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc)
  one = _lax_const(k, 1)
  y = lax.sub(k, loc)
  # Binomial coefficient C(y + n - 1, y) expressed via log-gamma functions.
  log_comb = lax.sub(
      lax.sub(gammaln(lax.add(y, n)), gammaln(n)), gammaln(lax.add(y, one))
  )
  # p^n (1-p)^y in log space; xlogy keeps 0 * log(0) == 0.
  log_kernel = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p)))
  # Values below the support (k < loc) get zero probability.
  return jnp.where(lax.lt(k, loc), -np.inf, lax.add(log_comb, log_kernel))
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Negative-binomial probability mass function.

  JAX implementation of :obj:`scipy.stats.nbinom` ``pmf``.

  The negative-binomial probability mass function is given by

  .. math::
     f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k

  for :math:`k \ge 0` and :math:`0 \le p \le 1`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values.

  See Also:
    :func:`jax.scipy.stats.nbinom.logpmf`
  """
  # Exponentiate the log-pmf; out-of-support values (-inf) map to 0.
  log_mass = logpmf(k, n, p, loc)
  return lax.exp(log_mass)
@@ -0,0 +1,284 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy import special
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Normal log probability distribution function.

  JAX implementation of :obj:`scipy.stats.norm` ``logpdf``.

  The normal distribution pdf is given by

  .. math::
     f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.isf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
  # log f(x) = -((x-loc)^2/var + log(2*pi*var)) / 2, with var = scale^2.
  var = lax.square(scale)
  log_norm = lax.log(lax.mul(_lax_const(x, 2 * np.pi), var))
  squared_dev = lax.div(lax.square(lax.sub(x, loc)), var)
  return lax.div(lax.add(log_norm, squared_dev), _lax_const(x, -2))
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Normal probability distribution function.

  JAX implementation of :obj:`scipy.stats.norm` ``pdf``.

  The normal distribution pdf is given by

  .. math::
     f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.logpdf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  # Defined via the numerically stable log-density.
  log_density = logpdf(x, loc, scale)
  return lax.exp(log_density)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Normal cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.norm` ``cdf``.

  The cdf is defined as

  .. math::
     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.norm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  x, loc, scale = promote_args_inexact("norm.cdf", x, loc, scale)
  # ndtr is the standard-normal CDF; evaluate it at the standardized value.
  z = lax.div(lax.sub(x, loc), scale)
  return special.ndtr(z)
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Normal log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.norm` ``logcdf``.

  The cdf is defined as

  .. math::
     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.norm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  x, loc, scale = promote_args_inexact("norm.logcdf", x, loc, scale)
  # log_ndtr evaluates log(ndtr(z)) stably in the far-left tail.
  z = lax.div(lax.sub(x, loc), scale)
  return special.log_ndtr(z)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Normal distribution percent point function.

  JAX implementation of :obj:`scipy.stats.norm` ``ppf``.

  The percent point function is defined as the inverse of the
  cumulative distribution function, :func:`jax.scipy.stats.norm.cdf`.

  Args:
    q: arraylike, value at which to evaluate the PPF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.isf`
    - :func:`jax.scipy.stats.norm.sf`
  """
  # ndtri inverts the standard-normal CDF; shift and rescale the result.
  quantile = special.ndtri(q) * scale + loc
  return jnp.asarray(quantile, float)
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Normal distribution log survival function.

  JAX implementation of :obj:`scipy.stats.norm` ``logsf``.

  The survival function is defined as

  .. math::
     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.norm.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.isf`
  """
  x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale)
  # By symmetry of the normal distribution, sf(x; loc) = cdf(-x; -loc).
  return logcdf(lax.neg(x), lax.neg(loc), scale)
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Normal distribution survival function.

  JAX implementation of :obj:`scipy.stats.norm` ``sf``.

  The survival function is defined as

  .. math::
     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.norm.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.isf`
  """
  x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale)
  # By symmetry of the normal distribution, sf(x; loc) = cdf(-x; -loc).
  return cdf(lax.neg(x), lax.neg(loc), scale)
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Normal distribution inverse survival function.

  JAX implementation of :obj:`scipy.stats.norm` ``isf``.

  Returns the inverse of the survival function,
  :func:`jax.scipy.stats.norm.sf`.

  Args:
    q: arraylike, value at which to evaluate the ISF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of isf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.logpdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  # sf(x) = q  <=>  cdf(x) = 1 - q, so the ISF is the PPF of 1 - q.
  return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)
@@ -0,0 +1,313 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto log probability distribution function.

  JAX implementation of :obj:`scipy.stats.pareto` ``logpdf``.

  The Pareto probability density function is given by

  .. math::
     f(x, b) = \begin{cases}
       bx^{-(b+1)} & x \ge 1\\
       0 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.pdf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale)
  one = _lax_const(x, 1)
  z = lax.div(lax.sub(x, loc), scale)
  # log f = log(b/scale) - (b+1) log(z), written as -(log(scale/b) + ...).
  log_norm = lax.log(lax.div(scale, b))
  log_density = lax.neg(
      lax.add(log_norm, lax.mul(lax.add(b, one), lax.log(z)))
  )
  # Support starts at loc + scale; anything below has zero probability.
  return jnp.where(lax.lt(x, lax.add(loc, scale)), -np.inf, log_density)
def pdf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto probability distribution function.

  JAX implementation of :obj:`scipy.stats.pareto` ``pdf``.

  The Pareto probability density function is given by

  .. math::
     f(x, b) = \begin{cases}
       bx^{-(b+1)} & x \ge 1\\
       0 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.logpdf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  # Exponentiate the log-density; -inf outside the support maps to 0.
  log_density = logpdf(x, b, loc, scale)
  return lax.exp(log_density)
def cdf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.pareto` ``cdf``.

  The Pareto cumulative distribution function is given by

  .. math::
     F(x, b) = \begin{cases}
       1 - x^{-b} & x \ge 1\\
       0 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of CDF values.

  See Also:
    - :func:`jax.scipy.stats.pareto.logcdf`
    - :func:`jax.scipy.stats.pareto.pdf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  x, b, loc, scale = promote_args_inexact("pareto.cdf", x, b, loc, scale)
  one = _lax_const(x, 1)
  zero = _lax_const(x, 0)
  z = lax.div(lax.sub(x, loc), scale)
  # F(z) = 1 - z^{-b} on the support; clamp to 0 below loc + scale.
  in_support = lax.sub(one, lax.pow(z, lax.neg(b)))
  return jnp.where(lax.lt(x, lax.add(loc, scale)), zero, in_support)
def logcdf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.pareto` ``logcdf``.

  The Pareto cumulative distribution function is given by

  .. math::
     F(x, b) = \begin{cases}
       1 - x^{-b} & x \ge 1\\
       0 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logCDF values.

  See Also:
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.logpdf`
    - :func:`jax.scipy.stats.pareto.logsf`
  """
  x, b, loc, scale = promote_args_inexact("pareto.logcdf", x, b, loc, scale)
  z = lax.div(lax.sub(x, loc), scale)
  # log(1 - z^{-b}) computed via log1p for accuracy near the lower bound.
  in_support = lax.log1p(lax.neg(lax.pow(z, lax.neg(b))))
  return jnp.where(lax.lt(x, lax.add(loc, scale)), -np.inf, in_support)
def logsf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto log survival function.

  JAX implementation of :obj:`scipy.stats.pareto` ``logsf``.

  The Pareto survival function is given by

  .. math::
     S(x, b) = \begin{cases}
       x^{-b} & x \ge 1\\
       1 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the survival function
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of log survival function values.

  See Also:
    - :func:`jax.scipy.stats.pareto.logcdf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  x, b, loc, scale = promote_args_inexact("pareto.logsf", x, b, loc, scale)
  zero = _lax_const(x, 0)
  z = lax.div(lax.sub(x, loc), scale)
  # log S(z) = -b log(z) on the support; below it, S == 1 so log S == 0.
  in_support = lax.neg(lax.mul(b, lax.log(z)))
  return jnp.where(lax.lt(x, lax.add(loc, scale)), zero, in_support)
def sf(
x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
r"""Pareto survival function.
JAX implementation of :obj:`scipy.stats.pareto` ``sf``.
The Pareto survival function is given by
.. math::
S(x, b) = \begin{cases}
x^{-b} & x \ge 1\\
1 & x < 1
\end{cases}
and is defined for :math:`b > 0`.
Args:
x: arraylike, value at which to evaluate the survival function
b: arraylike, distribution shape parameter
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of survival function values.
See Also:
- :func:`jax.scipy.stats.pareto.logcdf`
- :func:`jax.scipy.stats.pareto.logpdf`
- :func:`jax.scipy.stats.pareto.logsf`
- :func:`jax.scipy.stats.pareto.cdf`
- :func:`jax.scipy.stats.pareto.pdf`
- :func:`jax.scipy.stats.pareto.ppf`
"""
return lax.exp(logsf(x, b, loc, scale))
def ppf(
    q: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto percent point function (inverse CDF).

  JAX implementation of :obj:`scipy.stats.pareto` ``ppf``.

  The Pareto percent point function is the inverse of the Pareto CDF, and is
  given by

  .. math::
     F^{-1}(q, b) = \begin{cases}
       (1 - q)^{-1/b} & 0 \le q < 1\\
       \text{NaN} & \text{otherwise}
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    q: arraylike, value at which to evaluate the inverse CDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of percent point function values.

  See Also:
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  q, b, loc, scale = promote_args_inexact("pareto.ppf", q, b, loc, scale)
  one = _lax_const(q, 1)
  # Invert F(z) = 1 - z^{-b}:  z = (1 - q)^{-1/b}, then undo the affine map.
  standard_quantile = lax.pow(lax.sub(one, q), lax.neg(lax.div(one, b)))
  quantiles = lax.add(loc, lax.mul(scale, standard_quantile))
  # Probabilities outside [0, 1] (or NaN inputs) yield NaN, matching scipy.
  invalid = jnp.isnan(q) | (q < 0) | (q > 1)
  return jnp.where(invalid, np.nan, quantiles)
@@ -0,0 +1,241 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import api
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact, ensure_arraylike
from jax._src.scipy.special import xlogy, entr, gammaln, gammaincc
from jax._src.typing import Array, ArrayLike
def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Poisson log probability mass function.

  JAX implementation of :obj:`scipy.stats.poisson` ``logpmf``.

  The Poisson probability mass function is given by

  .. math::

     f(k) = e^{-\mu}\frac{\mu^k}{k!}

  and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    mu: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpmf values.

  See Also:
    - :func:`jax.scipy.stats.poisson.cdf`
    - :func:`jax.scipy.stats.poisson.pmf`
  """
  k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
  shifted = lax.sub(k, loc)
  # Log-pmf via xlogy, which returns 0 for xlogy(0, 0) instead of NaN.
  values = xlogy(shifted, mu) - gammaln(shifted + 1) - mu
  # Outside the support (negative or non-integer k - loc) the pmf is 0.
  out_of_support = lax.lt(shifted, _lax_const(k, 0)) | lax.ne(jnp.round(k), k)
  return jnp.where(out_of_support, -np.inf, values)
def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Poisson probability mass function.

  JAX implementation of :obj:`scipy.stats.poisson` ``pmf``.

  The Poisson probability mass function is given by

  .. math::

     f(k) = e^{-\mu}\frac{\mu^k}{k!}

  and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    mu: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values.

  See Also:
    - :func:`jax.scipy.stats.poisson.cdf`
    - :func:`jax.scipy.stats.poisson.logpmf`
  """
  # Exponentiate the log-pmf; this keeps the two implementations consistent.
  log_values = logpmf(k, mu, loc)
  return jnp.exp(log_values)
def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Poisson cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.poisson` ``cdf``.

  The cumulative distribution function is defined as:

  .. math::

     f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(i, p)

  where :math:`f_{pmf}(i, p)` is the probability mass function
  :func:`jax.scipy.stats.poisson.pmf`.

  Args:
    k: arraylike, value at which to evaluate the CDF
    mu: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.poisson.pmf`
    - :func:`jax.scipy.stats.poisson.logpmf`
  """
  # Fix: the promotion context name was copy-pasted as "poisson.logpmf",
  # which produced misleading dtype-promotion error messages.
  k, mu, loc = promote_args_inexact("poisson.cdf", k, mu, loc)
  zero = _lax_const(k, 0)
  x = lax.sub(k, loc)
  # Poisson CDF via the regularized upper incomplete gamma function:
  # P(X <= k) = Q(floor(k) + 1, mu).
  p = gammaincc(jnp.floor(1 + x), mu)
  # Below the support the CDF is 0.
  return jnp.where(lax.lt(x, zero), zero, p)
@api.jit
def entropy(mu: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Shannon entropy of the Poisson distribution.

  JAX implementation of :obj:`scipy.stats.poisson` ``entropy``.

  The entropy :math:`H(X)` of a Poisson random variable
  :math:`X \sim \text{Poisson}(\mu)` is defined as:

  .. math::
     H(X) = -\sum_{k=0}^\infty p(k) \log p(k)

  where :math:`p(k) = e^{-\mu} \mu^k / k!`.

  This implementation uses **regime switching** for numerical stability
  and performance:

  - **Small** :math:`\mu < 10`: Direct summation over PMF with adaptive
    upper bound :math:`k \leq \mu + 20`
  - **Medium** :math:`10 \leq \mu < 100`: Summation with bound
    :math:`k \leq \mu + 10\sqrt{\mu} + 20`
  - **Large** :math:`\mu \geq 100`: Asymptotic Stirling approximation:
    :math:`H(\mu) \approx \frac{1}{2} \log(2\pi e \mu) - \frac{1}{12\mu}`

  Matches SciPy to relative error :math:`< 10^{-5}` across all regimes.

  Args:
    mu: arraylike, mean parameter of the Poisson distribution.
      Should be nonnegative.
    loc: arraylike, optional location parameter (default: 0).
      Accepted for API compatibility with scipy but does not
      affect the entropy.

  Returns:
    Array of entropy values with shape broadcast from ``mu`` and ``loc``.
    Returns ``0`` for ``mu == 0`` (a degenerate distribution); negative
    ``mu`` yields ``NaN`` via propagation through the PMF summation.

  Examples:
    >>> import jax.numpy as jnp
    >>> from jax.scipy.stats import poisson
    >>> poisson.entropy(5.0)
    Array(2.204394, dtype=float32)
    >>> poisson.entropy(jnp.array([1, 10, 100]))
    Array([1.3048419, 2.561407 , 3.7206903], dtype=float32)

  See Also:
    - :func:`jax.scipy.stats.poisson.pmf`
    - :func:`jax.scipy.stats.poisson.logpmf`
    - :obj:`scipy.stats.poisson`
  """
  mu, loc = ensure_arraylike("poisson.entropy", mu, loc)
  promoted_mu, promoted_loc = promote_dtypes_inexact(mu, loc)
  # NOTE: loc does not affect the entropy (entropy is translation
  # invariant); it is accepted only for scipy API compatibility, but it
  # still participates in output-shape broadcasting below.
  result_shape = jnp.broadcast_shapes(
    promoted_mu.shape,
    promoted_loc.shape
  )
  # Flatten mu so the regime dispatch works on a 1-D batch: each helper
  # broadcasts a summation axis (k) against this flat mu axis.
  mu_flat = jnp.ravel(promoted_mu)
  zero_result = jnp.zeros_like(mu_flat)
  # Choose the computation regime based on mu value. Under jit all
  # branches of jnp.where are evaluated; each helper masks its own terms.
  result = jnp.where(
    mu_flat == 0,
    zero_result,
    jnp.where(
      mu_flat < 10,
      _entropy_small_mu(mu_flat),
      jnp.where(
        mu_flat < 100,
        _entropy_medium_mu(mu_flat),
        _entropy_large_mu(mu_flat)
      )
    )
  )
  # Restore the original mu shape, then broadcast against loc's shape.
  result_mu_shape = jnp.reshape(result, promoted_mu.shape)
  return jnp.broadcast_to(result_mu_shape, result_shape)
def _entropy_small_mu(mu: Array) -> Array:
  """Entropy via direct PMF summation for small mu (< 10).

  Sums ``entr(pmf(k))`` over k = 0..34, masking terms beyond the adaptive
  cutoff k <= ceil(mu + 20), which captures essentially all of the mass.
  """
  # k runs along a new leading axis; mu values along the trailing axis.
  ks = jnp.arange(35, dtype=mu.dtype)[:, None]
  pmf_values = pmf(ks, mu, 0)
  cutoff = jnp.ceil(mu + 20).astype(ks.dtype)
  within_bound = ks < cutoff[None, :]
  contributions = entr(jnp.where(within_bound, pmf_values, 0.0))
  return jnp.sum(contributions, axis=0)
def _entropy_medium_mu(mu: Array) -> Array:
  """Entropy for medium mu (10-100) via masked PMF summation.

  Adaptive cutoff k <= ceil(mu + 10*sqrt(mu) + 20), capped at a static
  k = 250 for JIT compatibility (for mu < 100 the cutoff stays below 220).
  """
  # Static bound keeps the array shape fixed under jit.
  ks = jnp.arange(250, dtype=mu.dtype)[:, None]
  pmf_values = pmf(ks, mu, 0)
  cutoff = jnp.ceil(mu + 10 * jnp.sqrt(mu) + 20).astype(ks.dtype)
  within_bound = ks < cutoff[None, :]
  contributions = entr(jnp.where(within_bound, pmf_values, 0.0))
  return jnp.sum(contributions, axis=0)
def _entropy_large_mu(mu: Array) -> Array:
  """Entropy for large mu (>= 100) via the asymptotic Stirling expansion.

  H(mu) ~ 0.5 * log(2*pi*e*mu) - 1/(12*mu) + O(mu**-2).
  """
  leading_term = 0.5 * jnp.log(2 * np.pi * np.e * mu)
  correction = 1.0 / (12 * mu)
  return leading_term - correction
@@ -0,0 +1,89 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Student's T log probability distribution function.

  JAX implementation of :obj:`scipy.stats.t` ``logpdf``.

  The Student's T probability distribution function is given by

  .. math::

     f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{-(\nu+1)/2}

  Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0`
  is the degrees of freedom (JAX follows the scipy convention of naming this ``df``).

  Args:
    x: arraylike, value at which to evaluate the PDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.t.pdf`
  """
  x, df, loc, scale = promote_args_inexact("t.logpdf", x, df, loc, scale)
  two = _lax_const(x, 2)
  z = lax.div(lax.sub(x, loc), scale)
  half_df = lax.div(df, two)
  half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))
  # Log normalization constant, including the scale factor:
  # lgamma(df/2) + 0.5*log(pi * df * scale^2) - lgamma((df+1)/2).
  scaled_pi = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
  half_log_norm = lax.div(lax.log(lax.mul(scaled_pi, df)), two)
  log_norm = lax.sub(lax.add(lax.lgamma(half_df), half_log_norm),
                     lax.lgamma(half_df_plus_half))
  # Kernel: ((df+1)/2) * log(1 + z^2/df), via log1p for accuracy near 0.
  log_kernel = lax.mul(half_df_plus_half,
                       lax.log1p(lax.div(lax.mul(z, z), df)))
  return lax.neg(lax.add(log_norm, log_kernel))
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Student's T probability distribution function.

  JAX implementation of :obj:`scipy.stats.t` ``pdf``.

  The Student's T probability distribution function is given by

  .. math::

     f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{-(\nu+1)/2}

  Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0`
  is the degrees of freedom (JAX follows the scipy convention of naming this ``df``).

  Args:
    x: arraylike, value at which to evaluate the PDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.t.logpdf`
  """
  # Exponentiate the log-pdf so the two implementations stay consistent.
  log_values = logpdf(x, df, loc, scale)
  return lax.exp(log_values)
@@ -0,0 +1,296 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import api
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.util import promote_args_inexact
from jax._src.scipy.stats import norm
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
def _log_diff(x, y):
  """Compute log(exp(x) - exp(y)) stably via a signed logsumexp."""
  stacked = jnp.array([x, y])
  # Weights +1/-1 turn the sum into a difference inside logsumexp.
  signs = jnp.array([jnp.ones_like(x), -jnp.ones_like(y)])
  return logsumexp(stacked, b=signs, axis=0)
def _log_gauss_mass(a, b):
  """Log of Gaussian probability mass within an interval"""
  a, b = jnp.broadcast_arrays(jnp.array(a), jnp.array(b))
  # Note carried over from scipy: calculations in the right tail are
  # inaccurate, so exploit the symmetry of the normal distribution and
  # evaluate everything in the left tail.
  in_left_tail = b <= 0
  in_right_tail = a > 0
  in_center = ~(in_left_tail | in_right_tail)
  # Conditionally reflect right-tail intervals into the left tail so a
  # single left-tail computation (and compiled graph) covers both tails.
  lo = jnp.where(in_right_tail, -b, a)
  hi = jnp.where(in_right_tail, -a, b)
  tail_mass = _log_diff(log_ndtr(hi), log_ndtr(lo))
  # As exp(log_mass) approaches 1 the difference above suffers
  # catastrophic cancellation; use the complementary formulation for the
  # central case. Underflow is not a concern: if only one term
  # underflows it was insignificant, and if both underflow the result
  # cannot be represented accurately in log space anyway because
  # log1p(x) ~ x for small x.
  center_mass = jnp.log1p(-ndtr(a) - ndtr(-b))
  return jnp.where(in_center, center_mass, tail_mass)
@api.jit
def logpdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal log probability distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``logpdf``.

  The truncated normal probability distribution is given by

  .. math::

     f(x, a, b) = \begin{cases}
       \frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\
       0 & \mathrm{otherwise}
     \end{cases}

  where :math:`a` and :math:`b` are effectively specified in number of
  standard deviations from zero. JAX uses the scipy nomenclature
  of ``loc`` for the centroid and ``scale`` for the standard deviation.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
  # Normal log-density renormalized by the log-mass of [a, b].
  result = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))
  # a and b are in standardized units; standardize x for the support test.
  standardized = lax.div(lax.sub(x, loc), scale)
  outside_support = (standardized < a) | (standardized > b)
  result = jnp.where(outside_support, -np.inf, result)
  # An empty interval (a >= b) is an invalid parameterization.
  return jnp.where(a >= b, np.nan, result)
def pdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal probability distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``pdf``.

  The truncated normal probability distribution is given by

  .. math::

     f(x, a, b) = \begin{cases}
       \frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\
       0 & \mathrm{otherwise}
     \end{cases}

  where :math:`a` and :math:`b` are effectively specified in number of
  standard deviations from the centroid. JAX uses the scipy nomenclature
  of ``loc`` for the centroid and ``scale`` for the standard deviation.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  # Exponentiate the log-pdf so the two implementations stay consistent.
  log_values = logpdf(x, a, b, loc, scale)
  return lax.exp(log_values)
@api.jit
def logsf(x, a, b, loc=0, scale=1):
  """Truncated normal distribution log survival function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``logsf``

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.truncnorm.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
  # By symmetry, the survival function of X on [a, b] equals the CDF of
  # the reflected variable -X truncated to [-b, -a] and centered at -loc.
  return logcdf(lax.neg(x), lax.neg(b), lax.neg(a), lax.neg(loc), scale)
@api.jit
def sf(x, a, b, loc=0, scale=1):
  """Truncated normal distribution survival function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``sf``

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.truncnorm.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  # Exponentiate the log survival function for consistency with logsf.
  log_values = logsf(x, a, b, loc, scale)
  return lax.exp(log_values)
@api.jit
def logcdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``logcdf``.

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where here :math:`f_{pdf}` is the probability distribution function,
  :func:`jax.scipy.stats.truncnorm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
  x, a, b = jnp.broadcast_arrays(x, a, b)
  # Standardize x; a and b are already in standardized (z-score) units.
  x = lax.div(lax.sub(x, loc), scale)
  lgm_ab = _log_gauss_mass(a, b)
  # Conditional log-probabilities of [a, x] and [x, b] within [a, b].
  logcdf = _log_gauss_mass(a, x) - lgm_ab
  logsf = _log_gauss_mass(x, b) - lgm_ab
  # jnp.select takes the FIRST matching condition, so the order matters:
  # clamp at the support edges first, then (condition carried over from
  # scipy) switch to the complementary survival-based form when the cdf
  # is close to 1, where exp(logcdf) would suffer catastrophic
  # cancellation inside a direct computation.
  logcdf = jnp.select(
    # third condition: avoid catastrophic cancellation (from scipy)
    [x >= b, x <= a, logcdf > -0.1, x > a],
    [0, -np.inf, jnp.log1p(-jnp.exp(logsf)), logcdf]
  )
  # An empty interval (a >= b) is an invalid parameterization.
  logcdf = jnp.where(a >= b, np.nan, logcdf)
  return logcdf
@api.jit
def cdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where here :math:`f_{pdf}` is the probability distribution function,
  :func:`jax.scipy.stats.truncnorm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  # Exponentiate the log-cdf so the two implementations stay consistent.
  log_values = logcdf(x, a, b, loc, scale)
  return lax.exp(log_values)
@@ -0,0 +1,148 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import promote_args_inexact
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Uniform log probability distribution function.

  JAX implementation of :obj:`scipy.stats.uniform` ``logpdf``.

  The uniform distribution pdf is given by

  .. math::

     f(x) = \begin{cases}
       1 & 0 \le x \le 1 \\
       0 & \mathrm{otherwise}
     \end{cases}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values

  See Also:
    - :func:`jax.scipy.stats.uniform.cdf`
    - :func:`jax.scipy.stats.uniform.pdf`
    - :func:`jax.scipy.stats.uniform.ppf`
  """
  x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale)
  # Inside the support the density is 1/scale, so the log-density is
  # constant at -log(scale).
  inside_value = lax.neg(lax.log(scale))
  upper = lax.add(loc, scale)
  out_of_support = jnp.logical_or(lax.gt(x, upper), lax.lt(x, loc))
  return jnp.where(out_of_support, -np.inf, inside_value)
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Uniform probability distribution function.

  JAX implementation of :obj:`scipy.stats.uniform` ``pdf``.

  The uniform distribution pdf is given by

  .. math::

     f(x) = \begin{cases}
       1 & 0 \le x \le 1 \\
       0 & \mathrm{otherwise}
     \end{cases}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.uniform.cdf`
    - :func:`jax.scipy.stats.uniform.logpdf`
    - :func:`jax.scipy.stats.uniform.ppf`
  """
  # Exponentiate the log-pdf so the two implementations stay consistent.
  log_values = logpdf(x, loc, scale)
  return lax.exp(log_values)
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Uniform cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.uniform` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where here :math:`f_{pdf}` is the probability distribution function,
  :func:`jax.scipy.stats.uniform.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.uniform.pdf`
    - :func:`jax.scipy.stats.uniform.logpdf`
    - :func:`jax.scipy.stats.uniform.ppf`
  """
  x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale)
  zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype)
  upper = lax.add(loc, scale)
  # Clamp to 0 below the support and 1 above it; interpolate linearly
  # within. jnp.select keeps NaN inputs mapped to the default (0),
  # matching the original behavior.
  conditions = [lax.lt(x, loc),
                lax.gt(x, upper),
                lax.ge(x, loc) & lax.le(x, upper)]
  branches = [zero, one, lax.div(lax.sub(x, loc), scale)]
  return jnp.select(conditions, branches)
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Uniform distribution percent point function.

  JAX implementation of :obj:`scipy.stats.uniform` ``ppf``.

  The percent point function is defined as the inverse of the
  cumulative distribution function, :func:`jax.scipy.stats.uniform.cdf`.

  Args:
    q: arraylike, value at which to evaluate the PPF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values.

  See Also:
    - :func:`jax.scipy.stats.uniform.cdf`
    - :func:`jax.scipy.stats.uniform.pdf`
    - :func:`jax.scipy.stats.uniform.logpdf`
  """
  q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale)
  # The inverse CDF is the affine map loc + scale * q; quantiles outside
  # [0, 1] (and NaN) have no preimage and map to NaN.
  invalid = jnp.isnan(q) | (q < 0) | (q > 1)
  values = lax.add(loc, lax.mul(scale, q))
  return jnp.where(invalid, np.nan, values)
@@ -0,0 +1,79 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array:
  r"""von Mises log probability distribution function.

  JAX implementation of :obj:`scipy.stats.vonmises` ``logpdf``.

  The von Mises probability distribution function is given by

  .. math::

     f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x}

  Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0`
  and :math:`\kappa\ge 0`, and the distribution is normalized in the interval
  :math:`-\pi \le x \le \pi`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    kappa: arraylike, distribution shape parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.vonmises.pdf`
  """
  x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa)
  # Use the exponentially scaled Bessel function i0e(k) = exp(-k)*i0(k):
  # k*cos(x) - log(2*pi*i0(k)) == k*(cos(x) - 1) - log(2*pi*i0e(k)),
  # which avoids overflow of i0 for large kappa.
  log_density = kappa * (jnp.cos(x) - 1) - jnp.log(2 * np.pi * lax.bessel_i0e(kappa))
  return jnp.where(lax.gt(kappa, _lax_const(kappa, 0)), log_density, np.nan)
def pdf(x: ArrayLike, kappa: ArrayLike) -> Array:
  r"""von Mises probability distribution function.

  JAX implementation of :obj:`scipy.stats.vonmises` ``pdf``.

  The von Mises probability distribution function is given by

  .. math::

     f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x}

  Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0`
  and :math:`\kappa\ge 0`, and the distribution is normalized in the interval
  :math:`-\pi \le x \le \pi`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    kappa: arraylike, distribution shape parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.vonmises.logpdf`
  """
  # Exponentiate the log-pdf so the two implementations stay consistent.
  log_values = logpdf(x, kappa)
  return lax.exp(log_values)
@@ -0,0 +1,82 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax.lax import _const as _lax_const
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
  r"""Wrapped Cauchy log probability distribution function.

  JAX implementation of :obj:`scipy.stats.wrapcauchy` ``logpdf``.

  The wrapped Cauchy probability distribution function is given by

  .. math::

     f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)}

  for :math:`0<c<1`, and where normalization is on the domain :math:`0\le x\le 2\pi`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    c: arraylike, distribution shape parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.wrapcauchy.pdf`
  """
  x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
  # log f = log(1 - c^2) - log(2*pi) - log(1 + c^2 - 2c*cos(x))
  log_density = (
    jnp.log(1 - c * c)
    - jnp.log(2 * np.pi)
    - jnp.log(1 + c * c - 2 * c * jnp.cos(x))
  )
  in_support = lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, np.pi * 2))
  valid_shape = lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1))
  # Zero density outside [0, 2*pi]; NaN for an invalid shape parameter.
  return jnp.where(valid_shape,
                   jnp.where(in_support, log_density, -np.inf),
                   np.nan)
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
  r"""Wrapped Cauchy probability distribution function.

  JAX implementation of :obj:`scipy.stats.wrapcauchy` ``pdf``.

  The wrapped Cauchy probability distribution function is given by

  .. math::

     f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)}

  for :math:`0<c<1`, and where normalization is on the domain :math:`0\le x\le 2\pi`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    c: arraylike, distribution shape parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.wrapcauchy.logpdf`
  """
  # Exponentiate the log-pdf so the two implementations stay consistent.
  log_values = logpdf(x, c)
  return lax.exp(log_values)