hand
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,75 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy import linalg as jnp_linalg
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
  """Assign codes from a code book to a set of observations.

  JAX implementation of :func:`scipy.cluster.vq.vq`.

  Each observation vector in ``obs`` is matched against every code vector in
  ``code_book``; the index of the nearest code (in Euclidean distance) and the
  corresponding distance are returned.

  Args:
    obs: array of observation vectors of shape ``(M, N)``. Each row represents
      a single observation. If ``obs`` is one-dimensional, then each entry is
      treated as a length-1 observation.
    code_book: array of codes with shape ``(K, N)``. Each row represents a single
      code vector. If ``code_book`` is one-dimensional, then each entry is treated
      as a length-1 code.
    check_finite: unused in JAX

  Returns:
    A tuple of arrays ``(code, dist)``

    - ``code`` is an integer array of shape ``(M,)`` containing indices ``0 <= i < K``
      of the closest entry in ``code_book`` for the given entry in ``obs``.
    - ``dist`` is a float array of shape ``(M,)`` containing the euclidean
      distance between each observation and the nearest code.

  Examples:
    >>> obs = jnp.array([[1.1, 2.1, 3.1],
    ...                  [5.9, 4.8, 6.2]])
    >>> code_book = jnp.array([[1., 2., 3.],
    ...                        [2., 3., 4.],
    ...                        [3., 4., 5.],
    ...                        [4., 5., 6.]])
    >>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book)
    >>> print(codes)
    [0 3]
    >>> print(distances)
    [0.17320499 1.9209373 ]
  """
  del check_finite  # unused
  check_arraylike("scipy.cluster.vq.vq", obs, code_book)
  observations, codes_arr = promote_dtypes_inexact(obs, code_book)
  if observations.ndim != codes_arr.ndim:
    raise ValueError("Observation and code_book should have the same rank")
  if observations.ndim == 1:
    # Scalar observations/codes are treated as length-1 vectors.
    observations = observations[..., None]
    codes_arr = codes_arr[..., None]
  if observations.ndim != 2:
    raise ValueError("ndim different than 1 or 2 are not supported")

  def _distances_to_codes(row):
    # Euclidean distance from one observation to every code vector.
    return jnp_linalg.norm(row[None] - codes_arr, axis=-1)

  all_dists = api.vmap(_distances_to_codes)(observations)
  assignment = jnp.argmin(all_dists, axis=-1)
  nearest_dist = api.vmap(lambda row, idx: row[idx])(all_dists, assignment)
  return assignment, nearest_dist
|
||||
@@ -0,0 +1,502 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy import fft as jnp_fft
|
||||
from jax._src.numpy.util import (
|
||||
promote_dtypes_complex, promote_dtypes_inexact, ensure_arraylike)
|
||||
from jax._src.util import canonicalize_axis, canonicalize_axis_tuple
|
||||
from jax._src.typing import Array
|
||||
|
||||
def _W4(N: int, k: Array) -> Array:
|
||||
N_arr, k = promote_dtypes_complex(N, k)
|
||||
return jnp.exp(-.5j * np.pi * k / N_arr)
|
||||
|
||||
def _dct_interleave(x: Array, axis: int) -> Array:
|
||||
v0 = lax.slice_in_dim(x, None, None, 2, axis)
|
||||
v1 = lax.rev(lax.slice_in_dim(x, 1, None, 2, axis), (axis,))
|
||||
return lax.concatenate([v0, v1], axis)
|
||||
|
||||
def _dct_ortho_norm(out: Array, axis: int) -> Array:
|
||||
factor = lax.concatenate([lax.full((1,), 4, out.dtype), lax.full((out.shape[axis] - 1,), 2, out.dtype)], 0)
|
||||
factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis])
|
||||
return out / lax.sqrt(factor * out.shape[axis])
|
||||
|
||||
# Implementation based on
|
||||
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)
|
||||
|
||||
|
||||
def dct(x: Array, type: int = 2, n: int | None = None,
        axis: int = -1, norm: str | None = None) -> Array:
  """Computes the discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.dct`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    n: integer, default = x.shape[axis]. The length of the transform.
      If larger than ``x.shape[axis]``, the input will be zero-padded, if
      smaller, the input will be truncated.
    axis: integer, default=-1. The axis along which the dct will be performed.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idct`: inverse DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:
    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x))
    [[ 6.43  3.56 -2.86]
     [-1.75  1.55 -1.4 ]
     [ 1.33 -2.01 -0.82]]

    When ``n`` smaller than ``x.shape[axis]``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=2))
    [[ 7.3  -0.57]
     [ 0.19 -0.36]
     [-0.   -1.4 ]]

    When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=2, axis=0))
    [[ 3.09  4.4  -2.81]
     [ 2.41  2.62  0.76]]

    When ``n`` larger than ``x.shape[axis]`` and ``axis=1``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dct(x, n=4, axis=1))
    [[ 6.43  4.88  0.04 -3.3 ]
     [-1.75  0.73  1.01 -2.18]
     [ 1.33 -1.05 -2.34 -0.07]]
  """
  x = ensure_arraylike("dct", x)

  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.dct: {norm=!r} is not implemented")

  if dtypes.issubdtype(x.dtype, np.complexfloating):
    # The DCT is linear, so a complex transform is the transform of the real
    # and imaginary parts computed independently.
    return lax.complex(
        dct(x.real, type=type, n=n, norm=norm, axis=axis),
        dct(x.imag, type=type, n=n, norm=norm, axis=axis),
    )

  axis = canonicalize_axis(axis, x.ndim)
  if n is not None:
    # lax.pad with a negative high-edge padding truncates, so this single call
    # handles both zero-padding (n > shape) and truncation (n < shape).
    x = lax.pad(x, jnp.array(0, x.dtype),
                [(0, n - x.shape[axis] if a == axis else 0, 0)
                 for a in range(x.ndim)])

  N = x.shape[axis]
  # Makhoul's algorithm: permute into [evens, reversed odds], take an FFT of
  # the same length, then rotate by quarter-wave twiddle factors.
  v = _dct_interleave(x, axis)
  V = jnp_fft.fft(v, axis=axis)
  k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis])
  out = V * _W4(N, k)
  out = 2 * out.real
  if norm == 'ortho':
    out = _dct_ortho_norm(out, axis)
  return out
|
||||
|
||||
|
||||
def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
  """2-D type-II DCT over two axes via a single 2-D FFT (Makhoul, 1980)."""
  axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes)
  N1, N2 = x.shape[axis1], x.shape[axis2]
  # Interleave along both axes, then take one complex FFT over both.
  v = _dct_interleave(_dct_interleave(x, axis1), axis2)
  V = jnp_fft.fftn(v, axes=axes)
  k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype),
                       [a for a in range(x.ndim) if a != axis1])
  k2 = lax.expand_dims(jnp.arange(N2, dtype=V.dtype),
                       [a for a in range(x.ndim) if a != axis2])
  # Combine the spectrum with its axis-2 reflection using quarter-wave twiddle
  # factors to recover the real 2-D DCT from the complex FFT output.
  out = _W4(N1, k1) * (_W4(N2, k2) * V + _W4(N2, -k2) * jnp.roll(jnp.flip(V, axis=axis2), shift=1, axis=axis2))
  out = 2 * out.real
  if norm == 'ortho':
    # 'ortho' normalization factorizes, so apply the 1-D scaling per axis.
    return _dct_ortho_norm(_dct_ortho_norm(out, axis1), axis2)
  return out
|
||||
|
||||
|
||||
def dctn(x: Array, type: int = 2,
         s: Sequence[int] | None = None,
         axes: Sequence[int] | None = None,
         norm: str | None = None) -> Array:
  """Computes the multidimensional discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.dctn`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    s: integer or sequence of integers. Specifies the shape of the result. If
      not specified, it will default to the shape of ``x`` along the specified
      ``axes``.
    axes: integer or sequence of integers. Specifies the axes along which the
      transform will be computed. If not given, the last ``len(s)`` axes are
      used, or all axes if ``s`` is also not specified.
    norm: string. The normalization mode: one of
      ``[None, "backward", "ortho"]``. The default is ``None``, which is
      equivalent to ``"backward"``.

  Returns:
    array containing the discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dct`: one-dimensional DCT
    - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:

    ``jax.scipy.fft.dctn`` computes the transform along both the axes by default
    when ``axes`` argument is ``None`` and ``s`` is also ``None``.

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x))
    [[ 12.01   6.2  -10.17]
     [  8.84   9.65  -3.54]
     [ 11.25  -1.54  -0.88]]

    When ``s=[2]``, the transform will be computed only along the last axis,
    with its dimension padded or truncated to size ``2``:

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2]))
    [[ 7.3  -0.57]
     [ 0.19 -0.36]
     [-0.   -1.4 ]]

    When ``s=[2]`` and ``axes=[0]``, the transform will be computed only along
    the specified axis, with its dimension padded or truncated to size ``2``:

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2], axes=[0]))
    [[ 3.09  4.4  -2.81]
     [ 2.41  2.62  0.76]]

    When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.dctn(x, s=[2, 4]))
    [[  9.36  11.23   2.12 -10.97]
     [ 11.57   5.86  -1.37  -1.58]]
  """
  x = ensure_arraylike("dctn", x)

  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.dctn: {norm=!r} is not implemented")

  if dtypes.issubdtype(x.dtype, np.complexfloating):
    # Linearity: transform real and imaginary parts independently.
    return lax.complex(
        dctn(x.real, type=type, s=s, norm=norm, axes=axes),
        dctn(x.imag, type=type, s=s, norm=norm, axes=axes),
    )

  if s is not None:
    # Accept either a bare integer or any iterable of integers for s.
    try:
      s = list(s)
    except TypeError:
      assert not isinstance(s, Sequence)
      s = [operator.index(s)]
    if len(s) > x.ndim:
      raise ValueError(
          f"s must have at most x.ndim ({x.ndim}) elements, got {len(s)}"
      )

  if axes is None:
    if s is not None:
      # Match scipy: s applies to the trailing len(s) axes.
      axes = tuple(range(x.ndim - len(s), x.ndim))
    else:
      axes = tuple(range(x.ndim))
  else:
    axes = canonicalize_axis_tuple(axes, x.ndim)

  if len(axes) == 1:
    # Single-axis case reduces to the 1-D transform.
    return dct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)

  if s is not None:
    # Pad/truncate each requested axis; negative padding truncates.
    # NOTE(review): zip silently truncates if len(s) != len(axes) — presumably
    # callers pass matching lengths; verify against scipy's error behavior.
    ns = dict(zip(axes, s))
    pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
    x = lax.pad(x, jnp.array(0, x.dtype), pads)

  if len(axes) == 2:
    return _dct2(x, axes=axes, norm=norm)

  # compose high-D DCTs from 2D and 1D DCTs:
  for axes_block in [axes[i:i+2] for i in range(0, len(axes), 2)]:
    x = dctn(x, axes=axes_block, norm=norm)
  return x
|
||||
|
||||
|
||||
def idct(x: Array, type: int = 2, n: int | None = None,
         axis: int = -1, norm: str | None = None) -> Array:
  """Computes the inverse discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.idct`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    n: integer, default = x.shape[axis]. The length of the transform.
      If larger than ``x.shape[axis]``, the input will be zero-padded, if
      smaller, the input will be truncated.
    axis: integer, default=-1. The axis along which the dct will be performed.
    norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
      The default is ``None``, which is equivalent to ``"backward"``.

  Returns:
    array containing the inverse discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dct`: DCT
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT

  Examples:

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x))
    [[ 0.78  0.41 -0.39]
     [-0.12  0.31 -0.23]
     [ 0.17 -0.3  -0.11]]

    When ``n`` smaller than ``x.shape[axis]``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x, n=2))
    [[ 1.12 -0.31]
     [ 0.04 -0.08]
     [ 0.05 -0.3 ]]

    When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x, n=2, axis=0))
    [[ 0.38  0.57 -0.45]
     [ 0.43  0.44  0.24]]

    When ``n`` larger than ``x.shape[axis]`` and ``axis=0``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idct(x, n=4, axis=0))
    [[ 0.1   0.38 -0.16]
     [ 0.28  0.18 -0.26]
     [ 0.3   0.15 -0.08]
     [ 0.13  0.3   0.29]]

    ``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result
    of ``jax.scipy.fft.dct``

    >>> x_dct = jax.scipy.fft.dct(x)
    >>> jnp.allclose(x, jax.scipy.fft.idct(x_dct))
    Array(True, dtype=bool)
  """
  x = ensure_arraylike("idct", x)

  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.idct: {norm=!r} is not implemented")

  if dtypes.issubdtype(x.dtype, np.complexfloating):
    # Linearity: invert real and imaginary parts independently.
    return lax.complex(
        idct(x.real, type=type, n=n, norm=norm, axis=axis),
        idct(x.imag, type=type, n=n, norm=norm, axis=axis)
    )

  axis = canonicalize_axis(axis, x.ndim)
  if n is not None:
    # lax.pad with a negative high-edge padding truncates, so this handles
    # both zero-padding and truncation of the transform axis.
    x = lax.pad(x, jnp.array(0, x.dtype),
                [(0, n - x.shape[axis] if a == axis else 0, 0)
                 for a in range(x.ndim)])
  N = x.shape[axis]
  x, = promote_dtypes_inexact(x)
  # The forward FFT-based DCT was built in 'backward' normalization; apply the
  # ortho scaling once (or twice for 'backward') to undo it.
  if norm is None or norm == 'backward':
    x = _dct_ortho_norm(x, axis)
  x = _dct_ortho_norm(x, axis)

  k = lax.expand_dims(jnp.arange(N, dtype=x.dtype), [a for a in range(x.ndim) if a != axis])
  # everything is complex from here...
  w4 = _W4(N, k)
  # Reuse w4 rather than recomputing _W4(N, k) a second time (the original
  # built the same complex-exponential array twice).
  x = x.astype(w4.dtype) / w4
  x = x * 2 * N

  x = jnp_fft.ifft(x, axis=axis)
  # convert back to reals..
  out = _dct_deinterleave(x.real, axis)
  return out
|
||||
|
||||
|
||||
def idctn(x: Array, type: int = 2,
          s: Sequence[int] | None = None,
          axes: Sequence[int] | None = None,
          norm: str | None = None) -> Array:
  """Computes the multidimensional inverse discrete cosine transform of the input

  JAX implementation of :func:`scipy.fft.idctn`.

  Args:
    x: array
    type: integer, default = 2. Currently only type 2 is supported.
    s: integer or sequence of integers. Specifies the shape of the result. If
      not specified, it will default to the shape of ``x`` along the specified
      ``axes``.
    axes: integer or sequence of integers. Specifies the axes along which the
      transform will be computed. If not given, the last ``len(s)`` axes are
      used, or all axes if ``s`` is also not specified.
    norm: string. The normalization mode: one of
      ``[None, "backward", "ortho"]``. The default is ``None``, which is
      equivalent to ``"backward"``.

  Returns:
    array containing the inverse discrete cosine transform of x

  See Also:
    - :func:`jax.scipy.fft.dct`: one-dimensional DCT
    - :func:`jax.scipy.fft.dctn`: multidimensional DCT
    - :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT

  Examples:

    ``jax.scipy.fft.idctn`` computes the transform along both the axes by
    default when ``axes`` argument is ``None`` and ``s`` is also ``None``.

    >>> x = jax.random.normal(jax.random.key(0), (3, 3))
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x))
    [[ 0.12  0.11 -0.15]
     [ 0.07  0.17 -0.03]
     [ 0.19 -0.07 -0.02]]

    When ``s=[2]``, the transform will be computed only along the last axis,
    with its dimension padded or truncated to size ``2``:

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x, s=[2]))
    [[ 1.12 -0.31]
     [ 0.04 -0.08]
     [ 0.05 -0.3 ]]

    When ``s=[2]`` and ``axes=[0]``, the transform will be computed only along
    the specified axis, with its dimension padded or truncated to size ``2``:

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x, s=[2], axes=[0]))
    [[ 0.38  0.57 -0.45]
     [ 0.43  0.44  0.24]]

    When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   print(jax.scipy.fft.idctn(x, s=[2, 4]))
    [[ 0.1   0.18  0.07 -0.16]
     [ 0.2   0.06 -0.03 -0.01]]

    ``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result
    of ``jax.scipy.fft.dctn``

    >>> x_dctn = jax.scipy.fft.dctn(x)
    >>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
    Array(True, dtype=bool)
  """
  x = ensure_arraylike("idctn", x)

  if type != 2:
    raise NotImplementedError('Only DCT type 2 is implemented.')
  if norm is not None and norm not in ['backward', 'ortho']:
    raise ValueError(f"jax.scipy.fft.idctn: {norm=!r} is not implemented")

  if dtypes.issubdtype(x.dtype, np.complexfloating):
    # Linearity: invert real and imaginary parts independently.
    return lax.complex(
        idctn(x.real, type=type, s=s, norm=norm, axes=axes),
        idctn(x.imag, type=type, s=s, norm=norm, axes=axes)
    )

  if s is not None:
    # Accept either a bare integer or any iterable of integers for s.
    try:
      s = list(s)
    except TypeError:
      assert not isinstance(s, Sequence)
      s = [operator.index(s)]
    if len(s) > x.ndim:
      raise ValueError(
          f"s must have at most x.ndim ({x.ndim}) elements, got {len(s)}"
      )

  if axes is None:
    if s is not None:
      # Match scipy: s applies to the trailing len(s) axes.
      axes = tuple(range(x.ndim - len(s), x.ndim))
    else:
      axes = tuple(range(x.ndim))
  else:
    axes = canonicalize_axis_tuple(axes, x.ndim)

  if len(axes) == 1:
    return idct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)

  if s is not None:
    # Pad/truncate each requested axis; negative padding truncates.
    ns = dict(zip(axes, s))
    pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
    x = lax.pad(x, jnp.array(0, x.dtype), pads)

  # compose high-D DCTs from 1D DCTs:
  for axis in axes:
    x = idct(x, axis=axis, norm=norm)
  return x
|
||||
|
||||
|
||||
def _dct_deinterleave(x: Array, axis: int) -> Array:
|
||||
empty_slice = slice(None, None, None)
|
||||
ix0 = tuple(
|
||||
slice(None, math.ceil(x.shape[axis]/2), 1) if i == axis else empty_slice
|
||||
for i in range(len(x.shape)))
|
||||
ix1 = tuple(
|
||||
slice(math.ceil(x.shape[axis]/2), None, 1) if i == axis else empty_slice
|
||||
for i in range(len(x.shape)))
|
||||
v0 = x[ix0]
|
||||
v1 = lax.rev(x[ix1], (axis,))
|
||||
out = jnp.zeros(x.shape, dtype=x.dtype)
|
||||
evens = tuple(
|
||||
slice(None, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
|
||||
odds = tuple(
|
||||
slice(1, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
|
||||
out = out.at[evens].set(v0)
|
||||
out = out.at[odds].set(v1)
|
||||
return out
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial

from jax._src.api import jit
from jax._src.numpy import lax_numpy
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@jit(static_argnames=('axis',))
|
||||
def trapezoid(y: ArrayLike, x: ArrayLike | None = None, dx: ArrayLike = 1.0,
|
||||
axis: int = -1) -> Array:
|
||||
r"""
|
||||
Integrate along the given axis using the composite trapezoidal rule.
|
||||
|
||||
JAX implementation of :func:`scipy.integrate.trapezoid`
|
||||
|
||||
The trapezoidal rule approximates the integral under a curve by summing the
|
||||
areas of trapezoids formed between adjacent data points.
|
||||
|
||||
Args:
|
||||
y: array of data to integrate.
|
||||
x: optional array of sample points corresponding to the ``y`` values. If not
|
||||
provided, ``x`` defaults to equally spaced with spacing given by ``dx``.
|
||||
dx: The spacing between sample points when `x` is None (default: 1.0).
|
||||
axis: The axis along which to integrate (default: -1)
|
||||
|
||||
Returns:
|
||||
The definite integral approximated by the trapezoidal rule.
|
||||
|
||||
See also:
|
||||
:func:`jax.numpy.trapezoid`: NumPy-style API for trapezoidal integration
|
||||
|
||||
Examples:
|
||||
Integrate over a regular grid, with spacing 1.0:
|
||||
|
||||
>>> y = jnp.array([1, 2, 3, 2, 3, 2, 1])
|
||||
>>> jax.scipy.integrate.trapezoid(y, dx=1.0)
|
||||
Array(13., dtype=float32)
|
||||
|
||||
Integrate over an irregular grid:
|
||||
|
||||
>>> x = jnp.array([0, 2, 5, 7, 10, 15, 20])
|
||||
>>> jax.scipy.integrate.trapezoid(y, x)
|
||||
Array(43., dtype=float32)
|
||||
|
||||
Approximate :math:`\int_0^{2\pi} \sin^2(x)dx`, which equals :math:`\pi`:
|
||||
|
||||
>>> x = jnp.linspace(0, 2 * jnp.pi, 1000)
|
||||
>>> y = jnp.sin(x) ** 2
|
||||
>>> result = jax.scipy.integrate.trapezoid(y, x)
|
||||
>>> jnp.allclose(result, jnp.pi)
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
return lax_numpy.trapezoid(y, x, dx, axis)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,180 @@
|
||||
# Copyright 2019 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import functools
|
||||
import itertools
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
from jax._src.util import safe_zip as zip
|
||||
|
||||
|
||||
def _nonempty_prod(arrs: Sequence[Array]) -> Array:
|
||||
return functools.reduce(operator.mul, arrs)
|
||||
|
||||
def _nonempty_sum(arrs: Sequence[Array]) -> Array:
|
||||
return sum(arrs[1:], arrs[0])
|
||||
|
||||
def _mirror_index_fixer(index: Array, size: int) -> Array:
|
||||
s = size - 1 # Half-wavelength of triangular wave
|
||||
# Scaled, integer-valued version of the triangular wave |x - round(x)|
|
||||
return jnp.abs((index + s) % (2 * s) - s)
|
||||
|
||||
def _reflect_index_fixer(index: Array, size: int) -> Array:
  """Fold out-of-range indices back into [0, size) with 'reflect' boundaries.

  The boundary pixel IS repeated (``... 1 0 | 0 1 2 | 2 1 ...``); implemented
  by running the mirror fixer on a doubled, half-shifted grid.
  """
  doubled = _mirror_index_fixer(2 * index + 1, 2 * size + 1)
  return jnp.floor_divide(doubled - 1, 2)
|
||||
|
||||
# Maps each supported `mode` string to a function that maps (possibly
# out-of-range) integer indices and an axis size to in-range indices.
_INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = {
    # 'constant': indices are left untouched; out-of-range points are detected
    # separately and filled with `cval` by the caller.
    'constant': lambda index, size: index,
    'nearest': lambda index, size: jnp.clip(index, 0, size - 1),
    'wrap': lambda index, size: index % size,
    'mirror': _mirror_index_fixer,
    'reflect': _reflect_index_fixer,
}
|
||||
|
||||
|
||||
def _round_half_away_from_zero(a: Array) -> Array:
  # Integer inputs are already exact; floats go through lax.round, whose
  # default rounding method rounds halves away from zero (hence the name).
  return a if dtypes.issubdtype(a.dtype, np.integer) else lax.round(a)
|
||||
|
||||
|
||||
def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
  """Order-0 (nearest-neighbor) interpolation: a single index with weight 1."""
  nearest = _round_half_away_from_zero(coordinate).astype(np.int32)
  unit_weight = coordinate.dtype.type(1)
  return [(nearest, unit_weight)]
|
||||
|
||||
|
||||
def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
|
||||
lower = jnp.floor(coordinate)
|
||||
upper_weight = coordinate - lower
|
||||
lower_weight = 1 - upper_weight
|
||||
index = lower.astype(np.int32)
|
||||
return [(index, lower_weight), (index + 1, upper_weight)]
|
||||
|
||||
|
||||
@functools.partial(api.jit, static_argnums=(2, 3, 4))
def _map_coordinates(input: ArrayLike, coordinates: Sequence[ArrayLike],
                     order: int, mode: str, cval: ArrayLike) -> Array:
  """Core of map_coordinates: separable nearest/linear interpolation.

  For each axis, builds the list of (in-range index, validity, weight)
  contributions; the Cartesian product of these per-axis lists enumerates all
  corner points of the interpolation stencil, which are gathered and summed
  with their product weights.
  """
  input_arr = jnp.asarray(input)
  coordinate_arrs = [jnp.asarray(c) for c in coordinates]
  cval = jnp.asarray(cval, input_arr.dtype)

  if len(coordinates) != input_arr.ndim:
    raise ValueError('coordinates must be a sequence of length input.ndim, but '
                     '{} != {}'.format(len(coordinates), input_arr.ndim))

  index_fixer = _INDEX_FIXERS.get(mode)
  if index_fixer is None:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates does not yet support mode {}. '
        'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))

  if mode == 'constant':
    # Only 'constant' mode needs an out-of-range test; other modes fold every
    # index back into range, so all gathered points are valid.
    is_valid = lambda index, size: (0 <= index) & (index < size)
  else:
    is_valid = lambda index, size: True

  if order == 0:
    interp_fun = _nearest_indices_and_weights
  elif order == 1:
    interp_fun = _linear_indices_and_weights
  else:
    raise NotImplementedError(
        'jax.scipy.ndimage.map_coordinates currently requires order<=1')

  # Per-axis lists of (fixed index, validity, weight) for each stencil point.
  valid_1d_interpolations = []
  for coordinate, size in zip(coordinate_arrs, input_arr.shape):
    interp_nodes = interp_fun(coordinate)
    valid_interp = []
    for index, weight in interp_nodes:
      fixed_index = index_fixer(index, size)
      valid = is_valid(index, size)
      valid_interp.append((fixed_index, valid, weight))
    valid_1d_interpolations.append(valid_interp)

  outputs = []
  # Each `items` is one corner of the N-D interpolation stencil.
  for items in itertools.product(*valid_1d_interpolations):
    indices, validities, weights = util.unzip3(items)
    if all(valid is True for valid in validities):
      # fast path: no bounds masking needed (non-'constant' modes).
      contribution = input_arr[indices]
    else:
      all_valid = functools.reduce(operator.and_, validities)
      contribution = jnp.where(all_valid, input_arr[indices], cval)
    outputs.append(_nonempty_prod(weights) * contribution)  # pyrefly: ignore[bad-argument-type]
  result = _nonempty_sum(outputs)
  if dtypes.issubdtype(input_arr.dtype, np.integer):
    # Round (half away from zero) before casting back to the integer dtype.
    result = _round_half_away_from_zero(result)
  return result.astype(input_arr.dtype)
|
||||
|
||||
|
||||
def map_coordinates(
    input: ArrayLike, coordinates: Sequence[ArrayLike], order: int,
    mode: str = 'constant', cval: ArrayLike = 0.0,
):
  """
  Evaluate an N-dimensional array at fractional coordinates via interpolation.

  JAX implementation of :func:`scipy.ndimage.map_coordinates`.

  Each position described by ``coordinates`` produces one interpolated value
  of the input array.

  Args:
    input: N-dimensional input array from which values are interpolated.
    coordinates: length-N sequence of arrays specifying the coordinates
      at which to evaluate the interpolated values
    order: The order of interpolation. JAX supports the following:

      * 0: Nearest-neighbor
      * 1: Linear

    mode: Points outside the boundaries of the input are filled according to the given mode.
      JAX supports one of ``('constant', 'nearest', 'mirror', 'wrap', 'reflect')``. Note the
      ``'wrap'`` mode in JAX behaves as ``'grid-wrap'`` mode in SciPy, and ``'constant'``
      mode in JAX behaves as ``'grid-constant'`` mode in SciPy. This discrepancy was caused
      by a former bug in those modes in SciPy (https://github.com/scipy/scipy/issues/2640),
      which was first fixed in JAX by changing the behavior of the existing modes, and later
      on fixed in SciPy, by adding modes with new names, rather than fixing the existing
      ones, for backwards compatibility reasons. Default is 'constant'.
    cval: Value used for points outside the boundaries of the input if ``mode='constant'``
      Default is 0.0.

  Returns:
    The interpolated values at the specified coordinates.

  Examples:
    >>> input = jnp.arange(12.0).reshape(3, 4)
    >>> input
    Array([[ 0.,  1.,  2.,  3.],
           [ 4.,  5.,  6.,  7.],
           [ 8.,  9., 10., 11.]], dtype=float32)
    >>> coordinates = [jnp.array([0.5, 1.5]),
    ...                jnp.array([1.5, 2.5])]
    >>> jax.scipy.ndimage.map_coordinates(input, coordinates, order=1)
    Array([3.5, 8.5], dtype=float32)

  Note:
    Interpolation near boundaries differs from the scipy function, because JAX
    fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097.
    This function interprets the ``mode`` argument as documented by SciPy, but
    not as implemented by SciPy.
  """
  # Argument validation and the interpolation itself live in the private helper.
  interpolated = _map_coordinates(input, coordinates, order, mode, cval)
  return interpolated
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,244 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy import linalg as jnp_linalg
|
||||
from jax._src.scipy.optimize.line_search import line_search
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
# jnp.dot pinned to lax.Precision.HIGHEST so dot products in the optimizer are
# not computed at a reduced default matmul precision.
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
|
||||
|
||||
|
||||
|
||||
class LBFGSResults(NamedTuple):
  """Results from L-BFGS optimization

  Parameters:
    converged: True if minimization converged
    failed: True if non-zero status and not converged
    k: integer number of iterations of the main loop (optimisation steps)
    nfev: integer total number of objective evaluations performed.
    ngev: integer total number of jacobian evaluations
    x_k: array containing the last argument value found during the search. If
      the search converged, then this value is the argmin of the objective
      function.
    f_k: array containing the value of the objective function at `x_k`. If the
      search converged, then this is the (local) minimum of the objective
      function.
    g_k: array containing the gradient of the objective function at `x_k`. If
      the search converged the l2-norm of this tensor should be below the
      tolerance.
    s_history: array of the most recent position differences
      `s_k = x_{k+1} - x_k`, newest last.
    y_history: array of the most recent gradient differences
      `y_k = g_{k+1} - g_k`, newest last.
    rho_history: array of `1 / Re<y_k, s_k>` for the stored (s, y) pairs.
    gamma: scaling applied to the descent direction between the two loops of
      the two-loop recursion.
    status: integer describing the status:
      0 = nominal  ,  1 = max iters reached  ,  2 = max fun evals reached
      3 = max grad evals reached  ,  4 = insufficient progress (ftol)
      5 = line search failed
    ls_status: integer describing the end status of the last line search
  """
  converged: Array
  failed: Array
  k: int | Array
  nfev: int | Array
  ngev: int | Array
  x_k: Array
  f_k: Array
  g_k: Array
  s_history: Array
  y_history: Array
  rho_history: Array
  gamma: float | Array
  status: int | Array
  ls_status: int | Array
|
||||
|
||||
|
||||
def _minimize_lbfgs(
    fun: Callable,
    x0: Array,
    maxiter: float | None = None,
    norm=np.inf,
    maxcor: int = 10,
    ftol: float = 2.220446049250313e-09,
    gtol: float = 1e-05,
    maxfun: float | None = None,
    maxgrad: float | None = None,
    maxls: int = 20,
):
  """
  Minimize a function using L-BFGS

  Implements the L-BFGS algorithm from
    Algorithm 7.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 176-185
  And generalizes to complex variables from
    Sorber, L., Barel, M.V. and Lathauwer, L.D., 2012.
    "Unconstrained optimization of real functions in complex variables"
    SIAM Journal on Optimization, 22(3), pp.879-898.

  Args:
    fun: function of the form f(x) where x is a flat ndarray and returns a real scalar.
      The function should be composed of operations with vjp defined.
    x0: initial guess
    maxiter: maximum number of iterations
    norm: order of norm for convergence check. Default inf.
    maxcor: maximum number of metric corrections ("history size")
    ftol: terminates the minimization when `(f_k - f_{k+1}) < ftol`
    gtol: terminates the minimization when `|g_k|_norm < gtol`
    maxfun: maximum number of function evaluations
    maxgrad: maximum number of gradient evaluations
    maxls: maximum number of line search steps (per iteration)

  Returns:
    Optimization results.
  """
  d = len(x0)
  dtype = dtypes.dtype(x0)

  # ensure there is at least one termination condition
  if (maxiter is None) and (maxfun is None) and (maxgrad is None):
    maxiter = d * 200

  # set others to inf, such that >= is supported
  if maxiter is None:
    maxiter = np.inf
  if maxfun is None:
    maxfun = np.inf
  if maxgrad is None:
    maxgrad = np.inf

  # initial evaluation
  f_0, g_0 = api.value_and_grad(fun)(x0)
  state_initial = LBFGSResults(
      converged=jnp.array(False, dtype=bool),
      failed=jnp.array(False, dtype=bool),
      k=0,
      nfev=1,
      ngev=1,
      x_k=x0,
      f_k=f_0,
      g_k=g_0,
      # (s, y, rho) histories start zeroed; _two_loop_recursion only reads the
      # first `k` valid entries.
      s_history=jnp.zeros((maxcor, d), dtype=dtype),
      y_history=jnp.zeros((maxcor, d), dtype=dtype),
      rho_history=jnp.zeros((maxcor,), dtype=dtype),
      gamma=1.,
      status=0,
      ls_status=0,
  )

  def cond_fun(state: LBFGSResults):
    # Loop until convergence or a latched failure status.
    return (~state.converged) & (~state.failed)

  def body_fun(state: LBFGSResults):
    # find search direction
    p_k = _two_loop_recursion(state)

    # line search
    ls_results = line_search(
        f=fun,
        xk=state.x_k,
        pk=p_k,
        old_fval=state.f_k,
        gfk=state.g_k,
        maxiter=maxls,
    )

    # evaluate at next iterate
    s_k = jnp.asarray(ls_results.a_k).astype(p_k.dtype) * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = ls_results.f_k
    g_kp1 = ls_results.g_k
    y_k = g_kp1 - state.g_k
    # Real part of the curvature <y_k, s_k>; its reciprocal is the rho used in
    # the two-loop recursion.
    rho_k_inv = jnp.real(_dot(y_k, s_k))
    rho_k = jnp.reciprocal(rho_k_inv).astype(y_k.dtype)
    gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))

    # replacements for next iteration
    # NOTE: later assignments take precedence; e.g. a line-search failure (5)
    # overrides the other status codes.
    status = jnp.array(0)
    status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
    status = jnp.where(state.ngev >= maxgrad, 3, status)
    status = jnp.where(state.nfev >= maxfun, 2, status)
    status = jnp.where(state.k >= maxiter, 1, status)
    status = jnp.where(ls_results.failed, 5, status)

    converged = jnp_linalg.norm(g_kp1, ord=norm) < gtol

    # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
    state = state._replace(
        converged=converged,
        failed=(status > 0) & (~converged),
        k=state.k + 1,
        nfev=state.nfev + ls_results.nfev,
        ngev=state.ngev + ls_results.ngev,
        x_k=x_kp1.astype(state.x_k.dtype),
        f_k=f_kp1.astype(state.f_k.dtype),
        g_k=g_kp1.astype(state.g_k.dtype),
        s_history=_update_history_vectors(history=state.s_history, new=s_k),
        y_history=_update_history_vectors(history=state.y_history, new=y_k),
        rho_history=_update_history_scalars(history=state.rho_history, new=rho_k),
        gamma=gamma.astype(state.g_k.dtype),
        status=jnp.where(converged, 0, status),
        ls_status=ls_results.status,
    )

    return state

  return lax.while_loop(cond_fun, body_fun, state_initial)
|
||||
|
||||
|
||||
def _two_loop_recursion(state: LBFGSResults):
  """Compute the L-BFGS search direction via the two-loop recursion.

  Applies the inverse-Hessian approximation implied by the stored
  (s, y, rho) history to ``-conj(g_k)``, yielding the descent direction.
  Only the newest ``min(k, history size)`` entries of the history are used.
  """
  dtype = state.rho_history.dtype
  # his_size is static (the buffer length); curr_size bounds the trip count of
  # both fori_loops to the number of valid history pairs.
  his_size = len(state.rho_history)
  curr_size = jnp.where(state.k < his_size, state.k, his_size)
  q = -jnp.conj(state.g_k)
  a_his = jnp.zeros_like(state.rho_history)

  def body_fun1(j, carry):
    # First loop: walk the history newest -> oldest (i counts down from the
    # last slot), recording the alpha coefficients for the second loop.
    i = his_size - 1 - j
    _q, _a_his = carry
    a_i = state.rho_history[i] * _dot(jnp.conj(state.s_history[i]), _q).real.astype(dtype)
    _a_his = _a_his.at[i].set(a_i)
    _q = _q - a_i * jnp.conj(state.y_history[i])
    return _q, _a_his

  q, a_his = lax.fori_loop(0, curr_size, body_fun1, (q, a_his))
  # Scale by gamma (initial inverse-Hessian scaling) between the two loops.
  q = state.gamma * q

  def body_fun2(j, _q):
    # Second loop: walk the valid entries oldest -> newest.
    i = his_size - curr_size + j
    b_i = state.rho_history[i] * _dot(state.y_history[i], _q).real.astype(dtype)
    _q = _q + (a_his[i] - b_i) * state.s_history[i]
    return _q

  q = lax.fori_loop(0, curr_size, body_fun2, q)
  return q
|
||||
|
||||
|
||||
def _update_history_vectors(history, new):
|
||||
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
|
||||
return jnp.roll(history, -1, axis=0).at[-1, :].set(new)
|
||||
|
||||
|
||||
def _update_history_scalars(history, new):
|
||||
# TODO(Jakob-Unfried) use rolling buffer instead? See #6053
|
||||
return jnp.roll(history, -1, axis=0).at[-1].set(new)
|
||||
@@ -0,0 +1,188 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""The Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy import einsum as jnp_einsum
|
||||
from jax._src.numpy import linalg as jnp_linalg
|
||||
from jax._src.scipy.optimize.line_search import line_search
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
class _BFGSResults(NamedTuple):
  """Results from BFGS optimization.

  Parameters:
    converged: True if minimization converged.
    failed: True if line search failed.
    k: integer the number of iterations of the BFGS update.
    nfev: integer total number of objective evaluations performed.
    ngev: integer total number of jacobian evaluations
    nhev: integer total number of hessian evaluations
    x_k: array containing the last argument value found during the search. If
      the search converged, then this value is the argmin of the objective
      function.
    f_k: array containing the value of the objective function at `x_k`. If the
      search converged, then this is the (local) minimum of the objective
      function.
    g_k: array containing the gradient of the objective function at `x_k`. If
      the search converged the l2-norm of this tensor should be below the
      tolerance.
    H_k: array containing the inverse of the estimated Hessian.
    old_old_fval: objective value from the iterate before last; passed to the
      line search to choose its initial trial step.
    status: int describing end state.
    line_search_status: int describing line search end state (only means
      something if line search fails).
  """
  converged: bool | Array
  failed: bool | Array
  k: int | Array
  nfev: int | Array
  ngev: int | Array
  nhev: int | Array
  x_k: Array
  f_k: Array
  g_k: Array
  H_k: Array
  old_old_fval: Array
  status: int | Array
  line_search_status: int | Array
|
||||
|
||||
|
||||
# Linear-algebra primitives pinned to lax.Precision.HIGHEST so the Hessian
# update is not computed at a reduced default matmul precision.
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
_einsum = partial(jnp_einsum.einsum, precision=lax.Precision.HIGHEST)
|
||||
|
||||
|
||||
def minimize_bfgs(
    fun: Callable,
    x0: Array,
    maxiter: int | None = None,
    norm=np.inf,
    gtol: float = 1e-5,
    line_search_maxiter: int = 10,
) -> _BFGSResults:
  """Minimize a function using BFGS.

  Implements the BFGS algorithm from
    Algorithm 6.1 from Wright and Nocedal, 'Numerical Optimization', 1999, pg.
    136-143.

  Args:
    fun: function of the form f(x) where x is a flat ndarray and returns a real
      scalar. The function should be composed of operations with vjp defined.
    x0: initial guess.
    maxiter: maximum number of iterations.
    norm: order of norm for convergence check. Default inf.
    gtol: terminates minimization when |grad|_norm < g_tol.
    line_search_maxiter: maximum number of linesearch iterations.

  Returns:
    Optimization result.
  """

  if maxiter is None:
    maxiter = np.size(x0) * 200

  d = x0.shape[0]

  # Inverse-Hessian estimate starts at the identity.
  initial_H = jnp.eye(d, dtype=x0.dtype)
  f_0, g_0 = api.value_and_grad(fun)(x0)
  state = _BFGSResults(
      converged=jnp_linalg.norm(g_0, ord=norm) < gtol,
      failed=False,
      k=0,
      nfev=1,
      ngev=1,
      nhev=0,
      x_k=x0,
      f_k=f_0,
      g_k=g_0,
      H_k=initial_H,
      old_old_fval=f_0 + jnp_linalg.norm(g_0) / 2,
      status=0,
      line_search_status=0,
  )

  def cond_fun(state):
    return (jnp.logical_not(state.converged)
            & jnp.logical_not(state.failed)
            & (state.k < maxiter))

  def body_fun(state):
    # Quasi-Newton direction from the current inverse-Hessian estimate.
    p_k = -_dot(state.H_k, state.g_k)
    line_search_results = line_search(
        fun,
        state.x_k,
        p_k,
        old_fval=state.f_k,
        old_old_fval=state.old_old_fval,
        gfk=state.g_k,
        maxiter=line_search_maxiter,
    )
    state = state._replace(
        nfev=state.nfev + line_search_results.nfev,
        ngev=state.ngev + line_search_results.ngev,
        failed=line_search_results.failed,
        line_search_status=line_search_results.status,
    )
    s_k = line_search_results.a_k * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = line_search_results.f_k
    g_kp1 = line_search_results.g_k
    y_k = g_kp1 - state.g_k
    rho_k = jnp.reciprocal(_dot(y_k, s_k))

    # BFGS inverse-Hessian update:
    #   H_{k+1} = (I - rho s y^T) H_k (I - rho y s^T) + rho s s^T
    # ('ij,jk,lk' computes w @ H_k @ w.T).
    sy_k = s_k[:, np.newaxis] * y_k[np.newaxis, :]
    w = jnp.eye(d, dtype=rho_k.dtype) - rho_k * sy_k
    H_kp1 = (_einsum('ij,jk,lk', w, state.H_k, w)
             + rho_k * s_k[:, np.newaxis] * s_k[np.newaxis, :])
    # If rho_k is non-finite (zero curvature <y, s>), keep the previous
    # estimate rather than propagating inf/NaN into H.
    H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k)
    converged = jnp_linalg.norm(g_kp1, ord=norm) < gtol

    state = state._replace(
        converged=converged,
        k=state.k + 1,
        x_k=x_kp1,
        f_k=f_kp1,
        g_k=g_kp1,
        H_k=H_kp1,
        old_old_fval=state.f_k,
    )
    return state

  state = lax.while_loop(cond_fun, body_fun, state)
  status = jnp.where(
      state.converged,
      0,  # converged
      jnp.where(
          state.k == maxiter,
          1,  # max iters reached
          jnp.where(
              state.failed,
              2 + state.line_search_status,  # ls failed (+ reason)
              -1,  # undefined
          )
      )
  )
  state = state._replace(status=status)
  return state
|
||||
@@ -0,0 +1,442 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple, Callable
|
||||
from functools import partial
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import promote_dtypes_inexact
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
# jnp.dot pinned to lax.Precision.HIGHEST so directional derivatives are not
# computed at a reduced default matmul precision.
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
|
||||
|
||||
|
||||
def _cubicmin(a, fa, fpa, b, fb, c, fc):
  """Minimizer of the cubic through (a, fa), (b, fb), (c, fc) with slope fpa at a.

  May produce NaN for a degenerate fit; callers are expected to bounds-check
  the result before using it.
  """
  out_dtype = dtypes.result_type(a, fa, fpa, b, fb, c, fc)
  slope = fpa
  db = b - a
  dc = c - a
  denom = (db * dc) ** 2 * (db - dc)
  # Solve the 2x2 system for the leading coefficients A, B of
  #   p(x) = A*(x-a)**3 + B*(x-a)**2 + slope*(x-a) + fa.
  lhs = jnp.array([[dc ** 2, -db ** 2],
                   [-dc ** 3, db ** 3]], dtype=out_dtype)
  rhs = jnp.array([fb - fa - slope * db, fc - fa - slope * dc], dtype=out_dtype)
  A, B = _dot(lhs, rhs) / denom

  radical = B * B - 3. * A * slope
  # Stationary point of p via the quadratic formula ('+' branch).
  return a + (-B + jnp.sqrt(radical)) / (3. * A)
|
||||
|
||||
|
||||
def _quadmin(a, fa, fpa, b, fb):
|
||||
D = fa
|
||||
C = fpa
|
||||
db = b - a
|
||||
B = (fb - D - C * db) / (db ** 2)
|
||||
xmin = a - C / (2. * B)
|
||||
return xmin
|
||||
|
||||
|
||||
def _binary_replace(replace_bit, original_dict, new_dict, keys=None):
|
||||
if keys is None:
|
||||
keys = new_dict.keys()
|
||||
return {key: jnp.where(replace_bit, new_dict[key], original_dict[key])
|
||||
for key in keys}
|
||||
|
||||
|
||||
class _ZoomState(NamedTuple):
  """Loop carry for `_zoom`.

  `a_lo`/`a_hi` bracket the step size (not necessarily ordered);
  `a_rec`/`phi_rec` record the most recently discarded endpoint, used to seed
  cubic interpolation; `a_star`/`phi_star`/`dphi_star`/`g_star` hold the
  current best candidate; `nfev`/`ngev` count evaluations inside the zoom
  loop; `j` is the zoom iteration counter.
  """
  done: bool | Array
  failed: bool | Array
  j: int | Array
  a_lo: float | Array
  phi_lo: float | Array
  dphi_lo: float | Array
  a_hi: float | Array
  phi_hi: float | Array
  dphi_hi: float | Array
  a_rec: float | Array
  phi_rec: float | Array
  a_star: float | Array
  phi_star: float | Array
  dphi_star: float | Array
  g_star: float | Array
  nfev: int | Array
  ngev: int | Array
|
||||
|
||||
|
||||
# Predicate signature shared by the Wolfe-condition callbacks handed to `_zoom`.
ConditionFn = Callable[..., Array]
|
||||
|
||||
|
||||
def _zoom(restricted_func_and_grad, wolfe_one: ConditionFn, wolfe_two: ConditionFn, a_lo, phi_lo,
          dphi_lo, a_hi, phi_hi, dphi_hi, g_0, pass_through):
  """
  Implementation of zoom. Algorithm 3.6 from Wright and Nocedal, 'Numerical
  Optimization', 1999, pg. 59-61. Tries cubic, quadratic, and bisection methods
  of zooming.

  When `pass_through` is true the loop body never runs and the initial state
  is returned unchanged; the caller uses this to select which of two
  unconditionally-computed zoom results applies.
  """
  state = _ZoomState(
      done=False,
      failed=False,
      j=0,
      a_lo=a_lo,
      phi_lo=phi_lo,
      dphi_lo=dphi_lo,
      a_hi=a_hi,
      phi_hi=phi_hi,
      dphi_hi=dphi_hi,
      a_rec=(a_lo + a_hi) / 2.,
      phi_rec=(phi_lo + phi_hi) / 2.,
      a_star=1.,
      phi_star=phi_lo,
      dphi_star=dphi_lo,
      g_star=g_0,
      nfev=0,
      ngev=0,
  )
  # Safeguard margins: a cubic (resp. quadratic) trial point is only accepted
  # if it lies at least delta1 (resp. delta2) of the interval away from the
  # endpoints.
  delta1 = 0.2
  delta2 = 0.1

  def body(state):
    # Body of zoom algorithm. We use boolean arithmetic to avoid using jax.cond
    # so that it works on GPU/TPU.
    dalpha = (state.a_hi - state.a_lo)
    a = jnp.minimum(state.a_hi, state.a_lo)
    b = jnp.maximum(state.a_hi, state.a_lo)
    cchk = delta1 * dalpha
    qchk = delta2 * dalpha

    # This will cause the line search to stop, and since the Wolfe conditions
    # are not satisfied the minimization should stop too.
    threshold = jnp.where((dtypes.finfo(dalpha.dtype).bits < 64), 1e-5, 1e-10)
    state = state._replace(failed=state.failed | (dalpha <= threshold))

    # Cubmin is sometimes nan, though in this case the bounds check will fail.
    a_j_cubic = _cubicmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi,
                          state.phi_hi, state.a_rec, state.phi_rec)
    use_cubic = (state.j > 0) & (a_j_cubic > a + cchk) & (a_j_cubic < b - cchk)
    a_j_quad = _quadmin(state.a_lo, state.phi_lo, state.dphi_lo, state.a_hi, state.phi_hi)
    use_quad = (~use_cubic) & (a_j_quad > a + qchk) & (a_j_quad < b - qchk)
    a_j_bisection = (state.a_lo + state.a_hi) / 2.
    use_bisection = (~use_cubic) & (~use_quad)

    # Trial point preference: cubic, then quadratic, then bisection.
    a_j = jnp.where(use_cubic, a_j_cubic, state.a_rec)
    a_j = jnp.where(use_quad, a_j_quad, a_j)
    a_j = jnp.where(use_bisection, a_j_bisection, a_j)

    # TODO(jakevdp): should we use some sort of fixed-point approach here instead?
    phi_j, dphi_j, g_j = restricted_func_and_grad(a_j)
    phi_j = phi_j.astype(state.phi_lo.dtype)
    dphi_j = dphi_j.astype(state.dphi_lo.dtype)
    g_j = g_j.astype(state.g_star.dtype)
    state = state._replace(nfev=state.nfev + 1,
                           ngev=state.ngev + 1)

    # Classify the trial point: it becomes the new `hi` endpoint, satisfies
    # both Wolfe conditions (star), or becomes the new `lo` endpoint
    # (possibly after `hi` is first replaced by the old `lo`).
    hi_to_j = wolfe_one(a_j, phi_j) | (phi_j >= state.phi_lo)
    star_to_j = wolfe_two(dphi_j) & (~hi_to_j)
    hi_to_lo = (dphi_j * (state.a_hi - state.a_lo) >= 0.) & (~hi_to_j) & (~star_to_j)
    lo_to_j = (~hi_to_j) & (~star_to_j)

    state = state._replace(
        **_binary_replace(
            hi_to_j,
            state._asdict(),
            dict(
                a_hi=a_j,
                phi_hi=phi_j,
                dphi_hi=dphi_j,
                a_rec=state.a_hi,
                phi_rec=state.phi_hi,
            ),
        ),
    )

    # for termination
    state = state._replace(
        done=star_to_j | state.done,
        **_binary_replace(
            star_to_j,
            state._asdict(),
            dict(
                a_star=a_j,
                phi_star=phi_j,
                dphi_star=dphi_j,
                g_star=g_j,
            )
        ),
    )
    state = state._replace(
        **_binary_replace(
            hi_to_lo,
            state._asdict(),
            dict(
                a_hi=state.a_lo,
                phi_hi=state.phi_lo,
                dphi_hi=state.dphi_lo,
                a_rec=state.a_hi,
                phi_rec=state.phi_hi,
            ),
        ),
    )
    state = state._replace(
        **_binary_replace(
            lo_to_j & ~hi_to_lo,
            state._asdict(),
            dict(
                a_rec=state.a_lo,
                phi_rec=state.phi_lo,
            ),
        ),
    )
    state = state._replace(
        **_binary_replace(
            lo_to_j,
            state._asdict(),
            dict(
                a_lo=a_j,
                phi_lo=phi_j,
                dphi_lo=dphi_j,
            ),
        ),
    )
    state = state._replace(j=state.j + 1)
    # Choose higher cutoff for maxiter than Scipy as Jax takes longer to find
    # the same value - possibly floating point issues?
    state = state._replace(failed= state.failed | (state.j >= 30))
    return state

  state = lax.while_loop(lambda state: (~state.done) & (~pass_through) & (~state.failed),
                         body,
                         state)

  return state
|
||||
|
||||
|
||||
class _LineSearchState(NamedTuple):
  """Loop carry for `line_search`.

  `a_i1`/`phi_i1`/`dphi_i1` hold the previous trial step with its function
  value and directional derivative; `a_star`/`phi_star`/`dphi_star`/`g_star`
  hold the currently accepted result; `nfev`/`ngev` count evaluations.
  """
  done: Array
  failed: Array
  i: int | Array          # iteration counter; starts at 1
  a_i1: float | Array     # previous trial step size
  phi_i1: float | Array   # function value at previous trial step
  dphi_i1: float | Array  # directional derivative at previous trial step
  nfev: int | Array
  ngev: int | Array
  a_star: float | Array
  phi_star: Array
  dphi_star: Array
  g_star: Array
|
||||
|
||||
|
||||
class _LineSearchResults(NamedTuple):
  """Results of line search.

  Parameters:
    failed: True if the line search did NOT find a step satisfying the strong
      Wolfe criteria
    nit: integer number of iterations
    nfev: integer number of functions evaluations
    ngev: integer number of gradients evaluations
    k: integer iteration counter at exit (nit + 1)
    a_k: final step size
    f_k: final function value
    g_k: final gradient value
    status: integer end status
  """
  failed: bool | Array
  nit: int | Array
  nfev: int | Array
  ngev: int | Array
  k: int | Array
  a_k: float | Array
  f_k: Array
  g_k: Array
  status: int | Array
|
||||
|
||||
|
||||
def line_search(f, xk, pk, old_fval=None, old_old_fval=None, gfk=None, c1=1e-4,
                c2=0.9, maxiter=20):
  """Inexact line search that satisfies strong Wolfe conditions.

  Algorithm 3.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 59-61

  Args:
    f: function of the form f(x) where x is a flat ndarray and returns a real
      scalar. The function should be composed of operations with vjp defined.
    xk: current position.
    pk: direction to search in. Assumes the direction is a descent direction.
    old_fval, gfk: function value and gradient at `xk`; evaluated here if not
      supplied.
    old_old_fval: function value at the previous iterate; when given it is
      used to pick the initial trial step size (scipy-compatible argument).
    maxiter: maximum number of iterations to search
    c1, c2: Wolfe criteria constant, see ref.

  Returns: LineSearchResults
  """
  xk, pk = promote_dtypes_inexact(xk, pk)
  def restricted_func_and_grad(t):
    # phi(t) = f(xk + t * pk); returns value, directional derivative, gradient.
    t = jnp.array(t, dtype=pk.dtype)
    phi, g = api.value_and_grad(f)(xk + t * pk)
    dphi = jnp.real(_dot(g, pk))
    return phi, dphi, g

  if old_fval is None or gfk is None:
    phi_0, dphi_0, gfk = restricted_func_and_grad(0)
  else:
    phi_0 = old_fval
    dphi_0 = jnp.real(_dot(gfk, pk))
  if old_old_fval is not None:
    # Initial step from the previous function decrease, capped at 1.
    candidate_start_value = 1.01 * 2 * (phi_0 - old_old_fval) / dphi_0
    start_value = jnp.where(candidate_start_value > 1, 1.0, candidate_start_value)
  else:
    start_value = 1

  def wolfe_one(a_i, phi_i) -> Array:
    # actually negation of W1
    return phi_i > phi_0 + c1 * a_i * dphi_0

  def wolfe_two(dphi_i) -> Array:
    return jnp.abs(dphi_i) <= -c2 * dphi_0

  state = _LineSearchState(
      done=jnp.array(False, dtype=bool),
      failed=jnp.array(False, dtype=bool),
      # algorithm begins at 1 as per Wright and Nocedal, however Scipy has a
      # bug and starts at 0. See https://github.com/scipy/scipy/issues/12157
      i=1,
      a_i1=0.,
      phi_i1=phi_0,
      dphi_i1=dphi_0,
      nfev=1 if (old_fval is None or gfk is None) else 0,
      ngev=1 if (old_fval is None or gfk is None) else 0,
      a_star=0.,
      phi_star=phi_0,
      dphi_star=dphi_0,
      g_star=gfk,
  )

  def body(state) -> _LineSearchState:
    # no amax in this version, we just double as in scipy.
    # unlike original algorithm we do our next choice at the start of this loop
    a_i = jnp.where(state.i == 1, start_value, state.a_i1 * 2.)

    phi_i, dphi_i, g_i = restricted_func_and_grad(a_i)
    state = state._replace(nfev=state.nfev + 1,
                           ngev=state.ngev + 1)

    star_to_zoom1 = wolfe_one(a_i, phi_i) | ((phi_i >= state.phi_i1) & (state.i > 1))
    star_to_i = wolfe_two(dphi_i) & (~star_to_zoom1)
    star_to_zoom2 = (dphi_i >= 0.) & (~star_to_zoom1) & (~star_to_i)

    # Both zoom candidates are computed unconditionally (no lax.cond); the
    # pass_through flag (last argument) makes the unused one a no-op.
    zoom1 = _zoom(restricted_func_and_grad,
                  wolfe_one,
                  wolfe_two,
                  state.a_i1,
                  state.phi_i1,
                  state.dphi_i1,
                  a_i,
                  phi_i,
                  dphi_i,
                  gfk,
                  ~star_to_zoom1)

    state = state._replace(nfev=state.nfev + zoom1.nfev,
                           ngev=state.ngev + zoom1.ngev)

    zoom2 = _zoom(restricted_func_and_grad,
                  wolfe_one,
                  wolfe_two,
                  a_i,
                  phi_i,
                  dphi_i,
                  state.a_i1,
                  state.phi_i1,
                  state.dphi_i1,
                  gfk,
                  ~star_to_zoom2)

    state = state._replace(nfev=state.nfev + zoom2.nfev,
                           ngev=state.ngev + zoom2.ngev)

    state = state._replace(
        done=star_to_zoom1 | state.done,
        failed=(star_to_zoom1 & zoom1.failed) | state.failed,
        **_binary_replace(
            star_to_zoom1,
            state._asdict(),
            zoom1._asdict(),
            keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
        ),
    )
    state = state._replace(
        done=star_to_i | state.done,
        **_binary_replace(
            star_to_i,
            state._asdict(),
            dict(
                a_star=a_i,
                phi_star=phi_i,
                dphi_star=dphi_i,
                g_star=g_i,
            ),
        ),
    )
    state = state._replace(
        done=star_to_zoom2 | state.done,
        failed=(star_to_zoom2 & zoom2.failed) | state.failed,
        **_binary_replace(
            star_to_zoom2,
            state._asdict(),
            zoom2._asdict(),
            keys=['a_star', 'phi_star', 'dphi_star', 'g_star'],
        ),
    )
    state = state._replace(i=state.i + 1, a_i1=a_i, phi_i1=phi_i, dphi_i1=dphi_i)
    return state

  state: _LineSearchState = lax.while_loop(
      lambda state: (~state.done) & (state.i <= maxiter) & (~state.failed),
      body,
      state)

  status = jnp.where(
      state.failed,
      jnp.array(1),  # zoom failed
      jnp.where(
          state.i > maxiter,
          jnp.array(3),  # maxiter reached
          jnp.array(0),  # passed (should be)
      ),
  )
  # Step sizes which are too small causes the optimizer to get stuck with a
  # direction of zero in <64 bit mode - avoid with a floor on minimum step size.
  alpha_k = jnp.asarray(state.a_star)
  alpha_k = jnp.where((dtypes.finfo(alpha_k.dtype).bits != 64)
                      & (jnp.abs(alpha_k) < 1e-8),
                      jnp.sign(alpha_k) * 1e-8,
                      alpha_k)
  results = _LineSearchResults(
      failed=state.failed | (~state.done),
      nit=state.i - 1,  # because iterations started at 1
      nfev=state.nfev,
      ngev=state.ngev,
      k=state.i,
      a_k=alpha_k,
      f_k=state.phi_star,
      g_k=state.g_star,
      status=status,
  )
  return results
|
||||
@@ -0,0 +1,133 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.scipy.optimize.bfgs import minimize_bfgs
|
||||
from jax._src.scipy.optimize._lbfgs import _minimize_lbfgs
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
class OptimizeResults(NamedTuple):
  """Object holding optimization results.

  Parameters:
    x: final solution.
    success: ``True`` if optimization succeeded.
    status: integer solver specific return code. 0 means converged (nominal),
      1=max BFGS iters reached, 3=zoom failed, 4=saddle point reached,
      5=max line search iters reached, -1=undefined
    fun: final function value.
    jac: final jacobian array.
    hess_inv: final inverse Hessian estimate.
    nfev: integer number of function calls used.
    njev: integer number of gradient evaluations.
    nit: integer number of iterations of the optimization algorithm.
  """
  # Fields mirror scipy.optimize.OptimizeResult.  ``hess_inv`` is None for
  # solvers that do not materialize an inverse Hessian (see the L-BFGS branch
  # of ``minimize``).  Scalar fields may be traced Arrays under jit, hence
  # the ``int | Array`` / ``bool | Array`` unions.
  x: Array
  success: bool | Array
  status: int | Array
  fun: Array
  jac: Array
  hess_inv: Array | None
  nfev: int | Array
  njev: int | Array
  nit: int | Array
|
||||
|
||||
|
||||
def minimize(
    fun: Callable,
    x0: Array,
    args: tuple = (),
    *,
    method: str,
    tol: float | None = None,
    options: Mapping[str, Any] | None = None,
) -> OptimizeResults:
  """Minimization of scalar function of one or more variables.

  This API for this function matches SciPy with some minor deviations:

  - Gradients of ``fun`` are calculated automatically using JAX's autodiff
    support when required.
  - The ``method`` argument is required. You must specify a solver.
  - Various optional arguments in the SciPy interface have not yet been
    implemented.
  - Optimization results may differ from SciPy due to differences in the line
    search implementation.

  ``minimize`` supports :func:`~jax.jit` compilation. It does not yet support
  differentiation or arguments in the form of multi-dimensional arrays, but
  support for both is planned.

  Args:
    fun: the objective function to be minimized, ``fun(x, *args) -> float``,
      where ``x`` is a 1-D array with shape ``(n,)`` and ``args`` is a tuple
      of the fixed parameters needed to completely specify the function.
      ``fun`` must support differentiation.
    x0: initial guess. Array of real elements of size ``(n,)``, where ``n`` is
      the number of independent variables.
    args: extra arguments passed to the objective function.
    method: solver type. Currently only ``"BFGS"`` is supported.
    tol: tolerance for termination. For detailed control, use solver-specific
      options.
    options: a dictionary of solver options. All methods accept the following
      generic options:

      - maxiter (int): Maximum number of iterations to perform. Depending on the
        method each iteration may use several function evaluations.

  Returns:
    An :class:`OptimizeResults` object.
  """
  # NOTE(review): ``tol`` is accepted for SciPy API parity but is not
  # forwarded to either solver in the code below — confirm intended.
  if options is None:
    options = {}

  if not isinstance(args, tuple):
    msg = "args argument to jax.scipy.optimize.minimize must be a tuple, got {}"
    raise TypeError(msg.format(args))

  fun_with_args = lambda x: fun(x, *args)

  def _pack_results(results, hess_inv):
    # Shared conversion of a solver result record to OptimizeResults.
    # A run counts as successful only if it converged AND did not fail.
    success = results.converged & jnp.logical_not(results.failed)
    return OptimizeResults(x=results.x_k,
                           success=success,
                           status=results.status,
                           fun=results.f_k,
                           jac=results.g_k,
                           hess_inv=hess_inv,
                           nfev=results.nfev,
                           njev=results.ngev,
                           nit=results.k)

  solver = method.lower()  # hoisted: normalize the method name exactly once

  if solver == 'bfgs':
    results = minimize_bfgs(fun_with_args, x0, **options)
    return _pack_results(results, results.H_k)

  if solver == 'l-bfgs-experimental-do-not-rely-on-this':
    # L-BFGS does not materialize an inverse Hessian estimate.
    results = _minimize_lbfgs(fun_with_args, x0, **options)
    return _pack_results(results, None)

  raise ValueError(f"Method {method} not recognized")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,761 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from functools import partial
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy import einsum as jnp_einsum
|
||||
from jax._src.scipy import linalg as jsp_linalg
|
||||
from jax._src.tree_util import (tree_leaves, tree_map, tree_structure,
|
||||
tree_reduce, Partial)
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.util import safe_map as map
|
||||
|
||||
|
||||
# Matrix/vector products pinned to HIGHEST precision: iterative solvers are
# sensitive to accumulation error, so do not let the backend lower precision.
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
_vdot = partial(jnp.vdot, precision=lax.Precision.HIGHEST)
_einsum = partial(jnp_einsum.einsum, precision=lax.Precision.HIGHEST)
|
||||
|
||||
|
||||
# aliases for working with pytrees
|
||||
def _vdot_real_part(x, y):
  """Vector dot-product guaranteed to have a real valued result despite
  possibly complex input. Thus neglects the real-imaginary cross-terms.
  The result is a real float.
  """
  # All uses of vdot() in CG compute an operator of the form z^H M z with M
  # Hermitian positive definite, so the true result is real valued:
  # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
  real_component = _vdot(x.real, y.real)
  if not (jnp.iscomplexobj(x) or jnp.iscomplexobj(y)):
    return real_component
  return real_component + _vdot(x.imag, y.imag)
|
||||
|
||||
|
||||
def _vdot_real_tree(x, y):
  """Sum of real-part dot products over all matched leaves of ``x`` and ``y``."""
  leafwise = tree_map(_vdot_real_part, x, y)
  return sum(tree_leaves(leafwise))
|
||||
|
||||
|
||||
def _vdot_tree(x, y) -> ArrayLike:
|
||||
return sum(tree_leaves(tree_map(partial(
|
||||
jnp.vdot, precision=lax.Precision.HIGHEST), x, y)))
|
||||
|
||||
|
||||
def _norm(x):
  """Euclidean norm of a pytree viewed as a single flat real vector."""
  leaves = tree_leaves(x)
  squared_norm = sum(_vdot_real_part(leaf, leaf) for leaf in leaves)
  return jnp.sqrt(squared_norm)
|
||||
|
||||
|
||||
def _mul(scalar, tree):
|
||||
return tree_map(partial(operator.mul, scalar), tree)
|
||||
|
||||
|
||||
# Leafwise pytree arithmetic: apply the pairwise op to matching leaves.
_add = partial(tree_map, operator.add)
_sub = partial(tree_map, operator.sub)
_dot_tree = partial(tree_map, _dot)
|
||||
|
||||
|
||||
@Partial
def _identity(x):
  """Identity (no-op) preconditioner, the default ``M``; wrapped in
  ``Partial`` so the callable is JAX-pytree compatible."""
  return x
|
||||
|
||||
|
||||
def _normalize_matvec(f):
|
||||
"""Normalize an argument for computing matrix-vector products."""
|
||||
if callable(f):
|
||||
return f
|
||||
elif isinstance(f, (np.ndarray, Array)):
|
||||
if f.ndim != 2 or f.shape[0] != f.shape[1]:
|
||||
raise ValueError(
|
||||
f'linear operator must be a square matrix, but has shape: {f.shape}')
|
||||
return partial(_dot, f)
|
||||
elif hasattr(f, '__matmul__'):
|
||||
if hasattr(f, 'shape') and len(f.shape) != 2 or f.shape[0] != f.shape[1]:
|
||||
raise ValueError(
|
||||
f'linear operator must be a square matrix, but has shape: {f.shape}')
|
||||
return partial(operator.matmul, f)
|
||||
else:
|
||||
raise TypeError(
|
||||
f'linear operator must be either a function or ndarray: {f}')
|
||||
|
||||
|
||||
def _cg_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
  """Preconditioned conjugate-gradient iteration on pytrees.

  Iterates until the residual norm-squared drops below
  ``max((tol * ||b||)^2, atol^2)`` or ``maxiter`` steps elapse.  ``x0`` must
  be provided by the caller (``_isolve`` supplies zeros by default).
  """
  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.cg
  bs = _vdot_real_tree(b, b)
  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

  # https://en.wikipedia.org/wiki/Conjugate_gradient_method#The_preconditioned_conjugate_gradient_method

  def cond_fun(value):
    _, r, gamma, _, k = value
    # Without a preconditioner gamma == <r, r>, so reuse it rather than
    # recomputing the residual norm.
    rs = gamma.real if M is _identity else _vdot_real_tree(r, r)
    return (rs > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / _vdot_real_tree(p, Ap).astype(dtype)
    x_ = _add(x, _mul(alpha, p))
    r_ = _sub(r, _mul(alpha, Ap))
    z_ = M(r_)
    gamma_ = _vdot_real_tree(r_, z_).astype(dtype)
    beta_ = gamma_ / gamma
    p_ = _add(z_, _mul(beta_, p))
    return x_, r_, gamma_, p_, k + 1

  r0 = _sub(b, A(x0))
  p0 = z0 = M(r0)
  # ``dtype`` is closed over by body_fun above; it is bound here, before the
  # loop ever runs, so the forward reference is safe.
  dtype = dtypes.result_type(*tree_leaves(p0))
  gamma0 = _vdot_real_tree(r0, z0).astype(dtype)
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

  return x_final
|
||||
|
||||
|
||||
# aliases for working with pytrees
|
||||
|
||||
def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
  """Preconditioned BiCGSTAB iteration on pytrees.

  Terminates when the residual norm-squared drops below
  ``max((tol * ||b||)^2, atol^2)``, when ``maxiter`` is reached, or on
  numerical breakdown, which is signalled by setting the iteration counter
  ``k`` to a negative value (see ``body_fun``).
  """
  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.bicgstab
  bs = _vdot_real_tree(b, b)
  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

  # https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB

  def cond_fun(value):
    x, r, *_, k = value
    rs = _vdot_real_tree(r, r)
    # the last condition checks breakdown (k is made negative on breakdown)
    return (rs > atol2) & (k < maxiter) & (k >= 0)

  def body_fun(value):
    x, r, rhat, alpha, omega, rho, p, q, k = value
    rho_ = _vdot_tree(rhat, r)
    beta = rho_ / rho * alpha / omega
    p_ = _add(r, _mul(beta, _sub(p, _mul(omega, q))))
    phat = M(p_)
    q_ = A(phat)
    alpha_ = rho_ / _vdot_tree(rhat, q_)
    s = _sub(r, _mul(alpha_, q_))
    # If the half-step residual s is already below tolerance, select the
    # half-step update and skip the omega correction term.
    exit_early = _vdot_real_tree(s, s) < atol2
    shat = M(s)
    t = A(shat)
    omega_ = _vdot_tree(t, s) / _vdot_tree(t, t)  # make cases?
    x_ = tree_map(partial(jnp.where, exit_early),
                  _add(x, _mul(alpha_, phat)),
                  _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat)))
                  )
    r_ = tree_map(partial(jnp.where, exit_early),
                  s, _sub(s, _mul(omega_, t)))
    # Negative counters encode the breakdown cause; cond_fun stops on k < 0.
    k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
    k_ = jnp.where((rho_ == 0), -10, k_)
    return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

  r0 = _sub(b, A(x0))
  # Scalars start at one, in the result dtype of b (weak-typed to match).
  rho0 = alpha0 = omega0 = lax_internal._convert_element_type(
      1, *dtypes.lattice_result_type(*tree_leaves(b)))
  initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

  return x_final
|
||||
|
||||
|
||||
def _shapes(pytree):
|
||||
return map(np.shape, tree_leaves(pytree))
|
||||
|
||||
|
||||
def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
            maxiter=None, M=None, check_symmetric=False):
  """Shared driver for the iterative solvers (``cg`` / ``bicgstab``).

  Fills in defaults, validates that ``x0`` and ``b`` have matching pytree
  structure and shapes, and wraps the raw solver in
  ``lax.custom_linear_solve`` so derivatives come from implicit
  differentiation rather than differentiating through the loop.
  """
  if x0 is None:
    x0 = tree_map(jnp.zeros_like, b)

  b, x0 = api.device_put((b, x0))

  if maxiter is None:
    size = sum(bi.size for bi in tree_leaves(b))
    maxiter = 10 * size  # copied from scipy

  if M is None:
    M = _identity
  A = _normalize_matvec(A)
  M = _normalize_matvec(M)

  if tree_structure(x0) != tree_structure(b):
    raise ValueError(
        'x0 and b must have matching tree structure: '
        f'{tree_structure(x0)} vs {tree_structure(b)}')

  if _shapes(x0) != _shapes(b):
    raise ValueError(
        'arrays in x0 and b must have matching shapes: '
        f'{_shapes(x0)} vs {_shapes(b)}')

  # Bind the solver options; custom_linear_solve supplies (matvec, b).
  isolve_solve = partial(
      _isolve_solve, x0=x0, tol=tol, atol=atol, maxiter=maxiter, M=M)

  # real-valued positive-definite linear operators are symmetric
  def real_valued(x):
    return not issubclass(x.dtype.type, np.complexfloating)
  symmetric = all(map(real_valued, tree_leaves(b))) \
    if check_symmetric else False
  x = lax.custom_linear_solve(
      A, b, solve=isolve_solve, transpose_solve=isolve_solve,
      symmetric=symmetric)
  # Convergence info is not yet reported (placeholder, see public docstrings).
  info = None
  return x, info
|
||||
|
||||
|
||||
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Conjugate Gradient iteration to solve ``Ax = b``.

  The numerics of JAX's ``cg`` should exact match SciPy's ``cg`` (up to
  numerical precision), but note that the interface is slightly different: you
  need to supply the linear operator ``A`` as a function instead of a sparse
  matrix or ``LinearOperator``.

  Derivatives of ``cg`` are implemented via implicit differentiation with
  another ``cg`` solve, rather than by differentiating *through* the solver.
  They will be accurate only if both solves converge.

  Parameters
  ----------
  A: ndarray, function, or matmul-compatible object
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A`` must represent
      a hermitian, positive definite matrix, and must return array(s) with the
      same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array or tree of arrays
      Starting guess for the solution. Must have the same structure as ``b``.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``cg``.
  maxiter : integer
      Maximum number of iterations. Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : ndarray, function, or matmul-compatible object
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.cg
  jax.lax.custom_linear_solve
  """
  # check_symmetric=True lets _isolve mark the solve as symmetric for
  # real-valued inputs (CG assumes a Hermitian operator anyway).
  return _isolve(_cg_solve,
                 A=A, b=b, x0=x0, tol=tol, atol=atol,
                 maxiter=maxiter, M=M, check_symmetric=True)
|
||||
|
||||
|
||||
def _safe_normalize(x, thresh=None):
  """
  Returns the L2-normalized vector (which can be a pytree) x, and optionally
  the computed norm. If the computed norm is less than the threshold `thresh`,
  which by default is the machine precision of x's dtype, it will be
  taken to be 0, and the normalized x to be the zero vector.
  """
  norm = _norm(x)
  dtype, weak_type = dtypes.lattice_result_type(*tree_leaves(x))
  if thresh is None:
    thresh = dtypes.finfo(norm.dtype).eps
  # NOTE(review): `.astype` assumes thresh is an array / numpy scalar, which
  # holds for all in-file callers; a bare Python float would fail — confirm.
  thresh = thresh.astype(dtype).real

  use_norm = norm > thresh

  # Cast to the lattice dtype (preserving weak typing) before dividing so the
  # normalized result keeps x's dtype.
  norm_cast = lax_internal._convert_element_type(norm, dtype, weak_type)
  normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm_cast, 0.0), x)
  norm = jnp.where(use_norm, norm, 0.0)
  return normalized_x, norm
|
||||
|
||||
|
||||
def _project_on_columns(A, v):
  """
  Returns A.T.conj() @ v, i.e. the overlaps of ``v`` with each column of ``A``,
  summed across matching tree leaves.
  """
  overlaps = tree_map(
      lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v,
  )
  return tree_reduce(operator.add, overlaps)
|
||||
|
||||
|
||||
def _iterative_classical_gram_schmidt(Q, x, xnorm, max_iterations=2):
  """
  Orthogonalize x against the columns of Q. The process is repeated
  up to `max_iterations` times, or fewer if the condition
  ||r|| < (1/sqrt(2)) ||x|| is met earlier (see below for the meaning
  of r and x).

  Parameters
  ----------
  Q : array or tree of arrays
      A matrix of orthonormal columns.
  x : array or tree of arrays
      A vector. It will be replaced with a new vector q which is orthonormal
      to the columns of Q, such that x in span(col(Q), q).
  xnorm : float
      Norm of x.

  Returns
  -------
  q : array or tree of arrays
      A unit vector, orthonormal to each column of Q, such that
      x in span(col(Q), q).
  r : array
      Stores the overlaps of x with each vector in Q.
  """
  # "twice is enough"
  # http://slepc.upv.es/documentation/reports/str1.pdf

  # TODO(shoyer): consider switching to only one iteration, like SciPy?

  # This assumes that Q's leaves all have the same dimension in the last
  # axis.
  Q0 = tree_leaves(Q)[0]
  r = jnp.zeros(Q0.shape[-1], dtype=Q0.dtype)
  q = x
  xnorm_scaled = xnorm / jnp.sqrt(2.0)

  def body_function(carry):
    k, q, r, qnorm_scaled = carry
    # Project q onto col(Q), subtract the projection, accumulate overlaps.
    h = _project_on_columns(Q, q)
    Qh = tree_map(lambda X: _dot(X, h), Q)
    q = _sub(q, Qh)
    r = _add(r, h)

    def qnorm_cond(carry):
      k, not_done, _, _ = carry
      return jnp.logical_and(not_done, k < (max_iterations - 1))

    def qnorm(carry):
      # Runs at most once (it flips not_done to False): refresh the scaled
      # norm of q after the projection above, for use by cond_function.
      k, _, q, qnorm_scaled = carry
      _, qnorm = _safe_normalize(q)
      qnorm_scaled = qnorm / jnp.sqrt(2.0)
      return (k, False, q, qnorm_scaled)

    init = (k, True, q, qnorm_scaled)
    _, _, q, qnorm_scaled = lax.while_loop(qnorm_cond, qnorm, init)
    return (k + 1, q, r, qnorm_scaled)

  def cond_function(carry):
    k, _, r, qnorm_scaled = carry
    _, rnorm = _safe_normalize(r)
    return jnp.logical_and(k < (max_iterations - 1), rnorm < qnorm_scaled)

  # Always perform one orthogonalization pass, then repeat while the overlap
  # stays large relative to ||q|| (up to max_iterations passes total).
  k, q, r, qnorm_scaled = body_function((0, q, r, xnorm_scaled))
  k, q, r, _ = lax.while_loop(cond_function, body_function,
                              (k, q, r, qnorm_scaled))
  return q, r
|
||||
|
||||
|
||||
def _kth_arnoldi_iteration(k, A, M, V, H):
  """
  Performs a single (the k'th) step of the Arnoldi process. Thus,
  adds a new orthonormalized Krylov vector A(M(V[:, k])) to V[:, k+1],
  and that vector's overlaps with the existing Krylov vectors to
  H[k, :]. An eps-scaled threshold (computed below) decides when an
  invariant subspace has been found, in which case the new vector is taken
  to be the zero vector.
  """
  dtype, _ = dtypes.lattice_result_type(*tree_leaves(V))
  eps = dtypes.finfo(dtype).eps

  v = tree_map(lambda x: x[..., k], V)  # Gets V[:, k]
  v = M(A(v))
  _, v_norm_0 = _safe_normalize(v)
  v, h = _iterative_classical_gram_schmidt(V, v, v_norm_0, max_iterations=2)

  # Below tol the remainder is treated as numerically zero (breakdown).
  tol = eps * v_norm_0
  unit_v, v_norm_1 = _safe_normalize(v, thresh=tol)
  V = tree_map(lambda X, y: X.at[..., k + 1].set(y), V, unit_v)

  h = h.at[k + 1].set(v_norm_1.astype(dtype))
  H = H.at[k, :].set(h)
  # Breakdown: the new vector lay (numerically) inside the existing subspace.
  breakdown = v_norm_1 == 0.
  return V, H, breakdown
|
||||
|
||||
|
||||
def _rotate_vectors(H, i, cs, sn):
|
||||
x1 = H[i]
|
||||
y1 = H[i + 1]
|
||||
x2 = cs.conj() * x1 - sn.conj() * y1
|
||||
y2 = sn * x1 + cs * y1
|
||||
H = H.at[i].set(x2)
|
||||
H = H.at[i + 1].set(y2)
|
||||
return H
|
||||
|
||||
|
||||
def _givens_rotation(a, b):
|
||||
b_zero = abs(b) == 0
|
||||
a_lt_b = abs(a) < abs(b)
|
||||
t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
|
||||
r = lax.rsqrt(1 + abs(t) ** 2).astype(t.dtype)
|
||||
cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
|
||||
sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
|
||||
return cs, sn
|
||||
|
||||
|
||||
def _apply_givens_rotations(H_row, givens, k):
  """
  Applies the Givens rotations stored in the vectors cs and sn to the vector
  H_row. Then constructs and applies a new Givens rotation that eliminates
  H_row's k'th element.
  """
  # This call successively applies each of the
  # Givens rotations stored in givens[:, :k] to H_col.
  def apply_ith_rotation(i, H_row):
    return _rotate_vectors(H_row, i, *givens[i, :])
  R_row = lax.fori_loop(0, k, apply_ith_rotation, H_row)

  # Build the new rotation that zeroes R_row[k + 1], store it for later rows,
  # and apply it to this row.
  givens_factors = _givens_rotation(R_row[k], R_row[k + 1])
  givens = givens.at[k, :].set(givens_factors)
  R_row = _rotate_vectors(R_row, k, *givens_factors)
  return R_row, givens
|
||||
|
||||
|
||||
def _gmres_incremental(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
  """
  Implements a single restart of GMRES. The restart-dimensional Krylov subspace
  K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
  projection of the true solution into this subspace is returned.

  This implementation builds the QR factorization during the Arnoldi process.
  """
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf

  # V holds the Krylov basis: column 0 is the unit residual; the remaining
  # `restart` columns start at zero and are filled by Arnoldi steps.
  V = tree_map(
      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
      unit_residual,
  )
  dtype = dtypes.result_type(*tree_leaves(b))
  # use eye() to avoid constructing a singular matrix in case of early
  # termination
  R = jnp.eye(restart, restart + 1, dtype=dtype)

  givens = jnp.zeros((restart, 2), dtype=dtype)
  beta_vec = jnp.zeros((restart + 1), dtype=dtype)
  beta_vec = beta_vec.at[0].set(residual_norm.astype(dtype))

  def loop_cond(carry):
    k, err, _, _, _, _ = carry
    return jnp.logical_and(k < restart, err > ptol)

  def arnoldi_qr_step(carry):
    k, _, V, R, beta_vec, givens = carry
    V, H, _ = _kth_arnoldi_iteration(k, A, M, V, R)
    R_row, givens = _apply_givens_rotations(H[k, :], givens, k)
    R = R.at[k, :].set(R_row)
    beta_vec = _rotate_vectors(beta_vec, k, *givens[k, :])
    # The rotated tail entry gives a free running residual-norm estimate,
    # enabling early exit within the restart.
    err = abs(beta_vec[k + 1])
    return k + 1, err, V, R, beta_vec, givens

  carry = (0, residual_norm, V, R, beta_vec, givens)
  carry = lax.while_loop(loop_cond, arnoldi_qr_step, carry)
  k, residual_norm, V, R, beta_vec, _ = carry
  del k  # Until we figure out how to pass this to the user.

  # Back-substitute through the triangular factor, then map the coefficients
  # back to the full space via the Krylov basis.
  y = jsp_linalg.solve_triangular(R[:, :-1].T, beta_vec[:-1])
  dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

  x = _add(x0, dx)
  residual = M(_sub(b, A(x)))
  unit_residual, residual_norm = _safe_normalize(residual)
  # TODO(shoyer): "Inner loop tolerance control" on ptol, like SciPy
  return x, unit_residual, residual_norm
|
||||
|
||||
|
||||
def _lstsq(a, b):
  """Least-squares solve via the normal equations (faster than
  jsp_linalg.lstsq for these small systems)."""
  gram = _dot(a.T.conj(), a)
  rhs = _dot(a.T.conj(), b)
  return jsp_linalg.solve(gram, rhs, assume_a='pos')
|
||||
|
||||
|
||||
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
  """
  Implements a single restart of GMRES. The ``restart``-dimensional Krylov
  subspace
  K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
  projection of the true solution into this subspace is returned.

  This implementation solves a dense linear problem instead of building
  a QR factorization during the Arnoldi process.
  """
  del ptol  # unused
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  # V holds the Krylov basis: column 0 is the unit residual; the remaining
  # `restart` columns start at zero and are filled by Arnoldi steps.
  V = tree_map(
      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
      unit_residual,
  )
  dtype, weak_type = dtypes.lattice_result_type(*tree_leaves(b))
  # eye() keeps H nonsingular if the Arnoldi loop breaks down early.
  H = lax_internal._convert_element_type(
      jnp.eye(restart, restart + 1, dtype=dtype), weak_type=weak_type)

  def loop_cond(carry):
    _, _, breakdown, k = carry
    return jnp.logical_and(k < restart, jnp.logical_not(breakdown))

  def arnoldi_process(carry):
    V, H, _, k = carry
    V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
    return V, H, breakdown, k + 1

  carry = (V, H, False, 0)
  V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)

  # Solve the dense least-squares problem in the Krylov coordinates, then map
  # the coefficients back via the basis V.
  beta_vec = jnp.zeros_like(H, shape=(restart + 1,)).at[0].set(residual_norm.astype(dtype))
  y = _lstsq(H.T, beta_vec)
  dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

  x = _add(x0, dx)

  residual = M(_sub(b, A(x)))
  unit_residual, residual_norm = _safe_normalize(residual)
  return x, unit_residual, residual_norm
|
||||
|
||||
|
||||
def _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func):
  """
  The main function call wrapped by custom_linear_solve. Repeatedly calls GMRES
  to find the projected solution within the order-``restart``
  Krylov space K(A, x0, restart), using the result of the previous projection
  in place of x0 each time. Parameters are the same as in ``gmres`` except:

  atol: Tolerance for norm(A(x) - b), used between restarts.
  ptol: Tolerance for norm(M(A(x) - b)), used within a restart.
  gmres_func: A function performing a single GMRES restart.

  Returns: The solution.
  """
  residual = M(_sub(b, A(x0)))
  unit_residual, residual_norm = _safe_normalize(residual)

  def cond_fun(value):
    _, k, _, residual_norm = value
    return jnp.logical_and(k < maxiter, residual_norm > atol)

  def body_fun(value):
    # Each iteration performs one full restart starting from the current x.
    x, k, unit_residual, residual_norm = value
    x, unit_residual, residual_norm = gmres_func(
        A, b, x, unit_residual, residual_norm, ptol, restart, M)
    return x, k + 1, unit_residual, residual_norm

  initialization = (x0, 0, unit_residual, residual_norm)
  x_final, k, _, err = lax.while_loop(cond_fun, body_fun, initialization)
  _ = k  # Until we can pass this out
  _ = err
  return x_final  # , info
|
||||
|
||||
|
||||
def gmres(A, b, x0=None, *, tol=1e-5, atol=0.0, restart=20, maxiter=None,
          M=None, solve_method='batched'):
  """
  GMRES solves the linear system A x = b for x, given A and b.

  A is specified as a function performing A(vi) -> vf = A @ vi, and in principle
  need not have any particular special properties, such as symmetry. However,
  convergence is often slow for nearly symmetric operators.

  Parameters
  ----------
  A: ndarray, function, or matmul-compatible object
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A``
      must return array(s) with the same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array or tree of arrays, optional
      Starting guess for the solution. Must have the same structure as ``b``.
      If this is unspecified, zeroes are used.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's ``gmres``.
  restart : integer, optional
      Size of the Krylov subspace ("number of iterations") built between
      restarts. GMRES works by approximating the true solution x as its
      projection into a Krylov space of this dimension - this parameter
      therefore bounds the maximum accuracy achievable from any guess
      solution. Larger values increase both number of iterations and iteration
      cost, but may be necessary for convergence. The algorithm terminates
      early if convergence is achieved before the full subspace is built.
      Default is 20.
  maxiter : integer
      Maximum number of times to rebuild the size-``restart`` Krylov space
      starting from the solution found at the last iteration. If GMRES
      halts or is very slow, decreasing this parameter may help.
      Default is infinite.
  M : ndarray, function, or matmul-compatible object
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.
  solve_method : 'incremental' or 'batched'
      The 'incremental' solve method builds a QR decomposition for the Krylov
      subspace incrementally during the GMRES process using Givens rotations.
      This improves numerical stability and gives a free estimate of the
      residual norm that allows for early termination within a single "restart".
      In contrast, the 'batched' solve method solves the least squares problem
      from scratch at the end of each GMRES iteration. It does not allow for
      early termination, but has much less overhead on GPUs.

  See also
  --------
  scipy.sparse.linalg.gmres
  jax.lax.custom_linear_solve
  """

  if x0 is None:
    x0 = tree_map(jnp.zeros_like, b)
  if M is None:
    M = _identity
  A = _normalize_matvec(A)
  M = _normalize_matvec(M)

  b, x0 = api.device_put((b, x0))
  size = sum(bi.size for bi in tree_leaves(b))

  if maxiter is None:
    maxiter = 10 * size  # copied from scipy
  # The Krylov space dimension cannot exceed the problem size.
  restart = min(restart, size)

  if tree_structure(x0) != tree_structure(b):
    raise ValueError(
        'x0 and b must have matching tree structure: '
        f'{tree_structure(x0)} vs {tree_structure(b)}')

  b_norm = _norm(b)
  atol = jnp.maximum(tol * b_norm, atol)

  # Inner-restart tolerance, expressed in the preconditioned norm.
  Mb = M(b)
  Mb_norm = _norm(Mb)
  ptol = Mb_norm * jnp.minimum(1.0, atol / b_norm)

  if solve_method == 'incremental':
    gmres_func = _gmres_incremental
  elif solve_method == 'batched':
    gmres_func = _gmres_batched
  else:
    raise ValueError(f"invalid solve_method {solve_method}, must be either "
                     "'incremental' or 'batched'")

  def _solve(A, b):
    return _gmres_solve(A, b, x0, atol, ptol, restart, maxiter, M, gmres_func)
  x = lax.custom_linear_solve(A, b, solve=_solve, transpose_solve=_solve)

  # NaNs in the solution signal a failed solve; report info = -1 like scipy.
  failed = jnp.isnan(_norm(x))
  info = jnp.where(failed, -1, 0)
  return x, info
|
||||
|
||||
|
||||
def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None):
  """Use Bi-Conjugate Gradient Stable iteration to solve ``Ax = b``.

  The numerics of JAX's ``bicgstab`` should exact match SciPy's
  ``bicgstab`` (up to numerical precision), but note that the interface
  is slightly different: you need to supply the linear operator ``A`` as
  a function instead of a sparse matrix or ``LinearOperator``.

  As with ``cg``, derivatives of ``bicgstab`` are implemented via implicit
  differentiation with another ``bicgstab`` solve, rather than by
  differentiating *through* the solver. They will be accurate only if
  both solves converge.

  Parameters
  ----------
  A: ndarray, function, or matmul-compatible object
      2D array or function that calculates the linear map (matrix-vector
      product) ``Ax`` when called like ``A(x)`` or ``A @ x``. ``A`` can represent
      any general (nonsymmetric) linear operator, and function must return array(s)
      with the same structure and shape as its argument.
  b : array or tree of arrays
      Right hand side of the linear system representing a single vector. Can be
      stored as an array or Python container of array(s) with any shape.

  Returns
  -------
  x : array or tree of arrays
      The converged solution. Has the same structure as ``b``.
  info : None
      Placeholder for convergence information. In the future, JAX will report
      the number of iterations when convergence is not achieved, like SciPy.

  Other Parameters
  ----------------
  x0 : array or tree of arrays
      Starting guess for the solution. Must have the same structure as ``b``.
  tol, atol : float, optional
      Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
      We do not implement SciPy's "legacy" behavior, so JAX's tolerance will
      differ from SciPy unless you explicitly pass ``atol`` to SciPy's
      ``bicgstab``.
  maxiter : integer
      Maximum number of iterations. Iteration will stop after maxiter
      steps even if the specified tolerance has not been achieved.
  M : ndarray, function, or matmul-compatible object
      Preconditioner for A. The preconditioner should approximate the
      inverse of A. Effective preconditioning dramatically improves the
      rate of convergence, which implies that fewer iterations are needed
      to reach a given error tolerance.

  See also
  --------
  scipy.sparse.linalg.bicgstab
  jax.lax.custom_linear_solve
  """
  # All shared setup (x0 defaulting, preconditioner normalization, implicit
  # differentiation via custom_linear_solve) lives in _isolve; only the
  # solver kernel differs between cg/bicgstab.
  return _isolve(_bicgstab_solve,
                 A=A, b=b, x0=x0, tol=tol, atol=atol,
                 maxiter=maxiter, M=M)
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,463 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import re
|
||||
import typing
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy import linalg as jnp_linalg
|
||||
from jax._src.numpy import vectorize as jnp_vectorize
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
class Rotation(typing.NamedTuple):
  """Rotation in 3 dimensions.

  JAX implementation of :class:`scipy.spatial.transform.Rotation`.

  Examples:
    Construct an object describing a 90 degree rotation about the z-axis:

    >>> from jax.scipy.spatial.transform import Rotation
    >>> r = Rotation.from_euler('z', 90, degrees=True)

    Convert to a rotation vector:

    >>> r.as_rotvec()
    Array([0.       , 0.       , 1.5707964], dtype=float32)

    Convert to rotation matrix:

    >>> r.as_matrix()
    Array([[ 0.        , -0.99999994,  0.        ],
           [ 0.99999994,  0.        ,  0.        ],
           [ 0.        ,  0.        ,  0.99999994]], dtype=float32)

    Compose with another rotation:

    >>> r2 = Rotation.from_euler('x', 90, degrees=True)
    >>> r3 = r * r2
    >>> r3.as_matrix()
    Array([[0., 0., 1.],
           [1., 0., 0.],
           [0., 1., 0.]], dtype=float32)

  See the scipy :class:`~scipy.spatial.transform.Rotation` documentation for
  further examples of manipulating Rotation objects.
  """
  # Quaternion(s) in scalar-last (x, y, z, w) order: shape (4,) for a single
  # rotation, (N, 4) for a stack (see `identity`, which returns [0, 0, 0, 1]).
  quat: Array

  @classmethod
  def concatenate(cls, rotations: typing.Sequence):
    """Concatenate a sequence of `Rotation` objects."""
    return cls(jnp.concatenate([rotation.quat for rotation in rotations]))

  @classmethod
  def from_euler(cls, seq: str, angles: Array, degrees: bool = False):
    """Initialize from Euler angles."""
    num_axes = len(seq)
    if num_axes < 1 or num_axes > 3:
      raise ValueError("Expected axis specification to be a non-empty "
                       "string of upto 3 characters, got {}".format(seq))
    # Uppercase axis letters mean intrinsic (body-frame) rotations, lowercase
    # extrinsic (fixed-frame); mixing cases is rejected below.
    intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
    extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
    if not (intrinsic or extrinsic):
      raise ValueError("Expected axes from `seq` to be from ['x', 'y', "
                       "'z'] or ['X', 'Y', 'Z'], got {}".format(seq))
    if any(seq[i] == seq[i+1] for i in range(num_axes - 1)):
      raise ValueError("Expected consecutive axes to be different, "
                       "got {}".format(seq))
    angles = jnp.atleast_1d(angles)
    axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
    return cls(_elementary_quat_compose(angles, axes, intrinsic, degrees))

  @classmethod
  def from_matrix(cls, matrix: Array):
    """Initialize from rotation matrix."""
    return cls(_from_matrix(matrix))

  @classmethod
  def from_mrp(cls, mrp: Array):
    """Initialize from Modified Rodrigues Parameters (MRPs)."""
    return cls(_from_mrp(mrp))

  @classmethod
  def from_quat(cls, quat: Array):
    """Initialize from quaternions."""
    # Quaternions are normalized on entry so downstream conversions can
    # assume unit norm.
    return cls(_normalize_quaternion(quat))

  @classmethod
  def from_rotvec(cls, rotvec: Array, degrees: bool = False):
    """Initialize from rotation vectors."""
    return cls(_from_rotvec(rotvec, degrees))

  @classmethod
  def identity(cls, num: int | None = None, dtype=float):
    """Get identity rotation(s)."""
    # NOTE(review): only num=None (a single identity rotation) is supported;
    # `assert` is stripped under `python -O` — consider an explicit raise.
    assert num is None
    quat = jnp.array([0., 0., 0., 1.], dtype=dtype)
    return cls(quat)

  @classmethod
  def random(cls, random_key: Array, num: int | None = None):
    """Generate uniformly distributed rotations."""
    # Need to implement scipy.stats.special_ortho_group for this to work...
    raise NotImplementedError()

  def __getitem__(self, indexer):
    """Extract rotation(s) at given index(es) from object."""
    if self.single:
      raise TypeError("Single rotation is not subscriptable.")
    return Rotation(self.quat[indexer])

  def __len__(self):
    """Number of rotations contained in this object."""
    if self.single:
      raise TypeError('Single rotation has no len().')
    else:
      return self.quat.shape[0]

  def __mul__(self, other) -> Rotation:
    """Compose this rotation with the other."""
    return Rotation.from_quat(_compose_quat(self.quat, other.quat))

  def apply(self, vectors: Array, inverse: bool = False) -> Array:
    """Apply this rotation to one or more vectors."""
    return _apply(self.as_matrix(), vectors, inverse)

  def as_euler(self, seq: str, degrees: bool = False):
    """Represent as Euler angles."""
    if len(seq) != 3:
      raise ValueError(f"Expected 3 axes, got {seq}.")
    intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
    extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
    if not (intrinsic or extrinsic):
      raise ValueError("Expected axes from `seq` to be from "
                       "['x', 'y', 'z'] or ['X', 'Y', 'Z'], "
                       "got {}".format(seq))
    if any(seq[i] == seq[i+1] for i in range(2)):
      raise ValueError("Expected consecutive axes to be different, "
                       "got {}".format(seq))
    axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
    # The helper broadcasts a (4,) quat against (3,) axes; allow rank
    # promotion locally rather than globally.
    with config.numpy_rank_promotion('allow'):
      return _compute_euler_from_quat(self.quat, axes, extrinsic, degrees)

  def as_matrix(self) -> Array:
    """Represent as rotation matrix."""
    return _as_matrix(self.quat)

  def as_mrp(self) -> Array:
    """Represent as Modified Rodrigues Parameters (MRPs)."""
    return _as_mrp(self.quat)

  def as_rotvec(self, degrees: bool = False) -> Array:
    """Represent as rotation vectors."""
    return _as_rotvec(self.quat, degrees)

  def as_quat(self, canonical: bool=False, scalar_first: bool=False) -> Array:
    """Represent as quaternions.

    With ``canonical=True`` the sign is fixed so that w >= 0; with
    ``scalar_first=True`` the result is reordered to (w, x, y, z).
    """
    quat = _make_canonical(self.quat) if canonical else self.quat
    if scalar_first:
      return jnp.roll(quat, shift=1, axis=-1)
    return quat

  def inv(self):
    """Invert this rotation."""
    return Rotation(_inv(self.quat))

  def magnitude(self) -> Array:
    """Get the magnitude(s) of the rotation(s)."""
    return _magnitude(self.quat)

  def mean(self, weights: Array | None = None):
    """Get the mean of the rotations."""
    w = jnp.ones(self.quat.shape[0], dtype=self.quat.dtype) if weights is None else jnp.asarray(weights, dtype=self.quat.dtype)
    if w.ndim != 1:
      raise ValueError("Expected `weights` to be 1 dimensional, got "
                       "shape {}.".format(w.shape))
    if w.shape[0] != len(self):
      raise ValueError("Expected `weights` to have number of values "
                       "equal to number of rotations, got "
                       "{} values and {} rotations.".format(w.shape[0], len(self)))
    # The mean quaternion is the principal eigenvector of the weighted
    # accumulator matrix K = sum_i w_i * q_i q_i^T.
    K = jnp.dot(w[np.newaxis, :] * self.quat.T, self.quat)
    _, v = jnp_linalg.eigh(K)
    # eigh returns eigenvalues in ascending order, so the last column is
    # the eigenvector of the largest eigenvalue.
    return Rotation(v[:, -1])

  @property
  def single(self) -> bool:
    """Whether this instance represents a single rotation."""
    return self.quat.ndim == 1
|
||||
|
||||
|
||||
class Slerp(typing.NamedTuple):
  """Spherical Linear Interpolation of Rotations.

  JAX implementation of :class:`scipy.spatial.transform.Slerp`.

  Examples:
    Create a Slerp instance from a series of rotations:

    >>> import math
    >>> from jax.scipy.spatial.transform import Rotation, Slerp
    >>> rots = jnp.array([[90, 0, 0],
    ...                   [0, 45, 0],
    ...                   [0, 0, -30]])
    >>> key_rotations = Rotation.from_euler('zxy', rots, degrees=True)
    >>> key_times = [0, 1, 2]
    >>> slerp = Slerp.init(key_times, key_rotations)
    >>> times = [0, 0.5, 1, 1.5, 2]
    >>> interp_rots = slerp(times)
    >>> interp_rots.as_euler('zxy')
    Array([[ 1.5707963e+00,  0.0000000e+00,  0.0000000e+00],
           [ 8.5309029e-01,  3.8711953e-01,  1.7768645e-01],
           [-2.3841858e-07,  7.8539824e-01,  0.0000000e+00],
           [-5.6668043e-02,  3.9213133e-01, -2.8347540e-01],
           [ 0.0000000e+00,  0.0000000e+00, -5.2359891e-01]],      dtype=float32)
  """

  # Key times, shape (N,); assumed strictly increasing (see NOTE in `init`).
  times: Array
  # Differences between consecutive key times, shape (N-1,).
  timedelta: Array
  # Key rotations at times[:-1] — the left endpoint of each interval.
  rotations: Rotation
  # Relative rotation vectors carrying each key rotation to the next,
  # shape (N-1, 3).
  rotvecs: Array

  @classmethod
  def init(cls, times: Array, rotations: Rotation):
    """Build a Slerp interpolator from key ``times`` and key ``rotations``."""
    if not isinstance(rotations, Rotation):
      raise TypeError("`rotations` must be a `Rotation` instance.")
    if rotations.single or len(rotations) == 1:
      raise ValueError("`rotations` must be a sequence of at least 2 rotations.")
    times = jnp.asarray(times, dtype=rotations.quat.dtype)
    if times.ndim != 1:
      raise ValueError("Expected times to be specified in a 1 "
                       "dimensional array, got {} "
                       "dimensions.".format(times.ndim))
    if times.shape[0] != len(rotations):
      raise ValueError("Expected number of rotations to be equal to "
                       "number of timestamps given, got {} rotations "
                       "and {} timestamps.".format(len(rotations), times.shape[0]))
    timedelta = jnp.diff(times)
    # NOTE: monotonicity cannot be validated under tracing without a
    # concretization error, so strictly-increasing times are assumed.
    # if jnp.any(timedelta <= 0):  # this causes a concretization error...
    #   raise ValueError("Times must be in strictly increasing order.")
    new_rotations = Rotation(rotations.as_quat()[:-1])
    return cls(
        times=times,
        timedelta=timedelta,
        rotations=new_rotations,
        # Per-interval relative rotation: r[i].inv() * r[i+1], as a rotvec so
        # it can later be scaled by the interpolation fraction.
        rotvecs=(new_rotations.inv() * Rotation(rotations.as_quat()[1:])).as_rotvec())

  def __call__(self, times: Array):
    """Interpolate rotations."""
    compute_times = jnp.asarray(times, dtype=self.times.dtype)
    if compute_times.ndim > 1:
      raise ValueError("`times` must be at most 1-dimensional.")
    single_time = compute_times.ndim == 0
    compute_times = jnp.atleast_1d(compute_times)
    # Index of the interval containing each query time, clamped so that
    # times before the first key fall into the first interval.
    ind = jnp.maximum(jnp.searchsorted(self.times, compute_times) - 1, 0)
    # Fraction of the way through each interval.
    alpha = (compute_times - self.times[ind]) / self.timedelta[ind]
    result = (self.rotations[ind] * Rotation.from_rotvec(self.rotvecs[ind] * alpha[:, None]))
    if single_time:
      return result[0]
    return result
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(m,m),(m),()->(m)')
|
||||
def _apply(matrix: Array, vector: Array, inverse: bool) -> Array:
|
||||
return jnp.where(inverse, matrix.T, matrix) @ vector
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(m)->(n,n)')
|
||||
def _as_matrix(quat: Array) -> Array:
|
||||
x = quat[0]
|
||||
y = quat[1]
|
||||
z = quat[2]
|
||||
w = quat[3]
|
||||
x2 = x * x
|
||||
y2 = y * y
|
||||
z2 = z * z
|
||||
w2 = w * w
|
||||
xy = x * y
|
||||
zw = z * w
|
||||
xz = x * z
|
||||
yw = y * w
|
||||
yz = y * z
|
||||
xw = x * w
|
||||
return jnp.array([[+ x2 - y2 - z2 + w2, 2 * (xy - zw), 2 * (xz + yw)],
|
||||
[2 * (xy + zw), - x2 + y2 - z2 + w2, 2 * (yz - xw)],
|
||||
[2 * (xz - yw), 2 * (yz + xw), - x2 - y2 + z2 + w2]])
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(m)->(n)')
|
||||
def _as_mrp(quat: Array) -> Array:
|
||||
sign = jnp.where(quat[3] < 0, -1., 1.)
|
||||
denominator = 1. + sign * quat[3]
|
||||
return sign * quat[:3] / denominator
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(m),()->(n)')
def _as_rotvec(quat: Array, degrees: bool) -> Array:
  """Quaternion (x, y, z, w) -> rotation vector (axis scaled by angle)."""
  # Flip the sign so w >= 0, keeping the encoded angle within [0, pi].
  quat = jnp.where(quat[3] < 0, -quat, quat)
  angle = 2. * jnp.arctan2(_vector_norm(quat[:3]), quat[3])
  sq = angle * angle
  # Taylor expansion of angle / sin(angle / 2) near zero avoids 0/0.
  near_zero = 2 + sq / 12 + 7 * sq * sq / 2880
  generic = angle / jnp.sin(angle / 2)
  factor = jnp.where(angle <= 1e-3, near_zero, generic)
  factor = jnp.where(degrees, jnp.rad2deg(factor), factor)
  return factor * jnp.array(quat[:3])
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(n),(n)->(n)')
|
||||
def _compose_quat(p: Array, q: Array) -> Array:
|
||||
cross = jnp.cross(p[:3], q[:3])
|
||||
return jnp.array([p[3]*q[0] + q[3]*p[0] + cross[0],
|
||||
p[3]*q[1] + q[3]*p[1] + cross[1],
|
||||
p[3]*q[2] + q[3]*p[2] + cross[2],
|
||||
p[3]*q[3] - p[0]*q[0] - p[1]*q[1] - p[2]*q[2]])
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(m),(l),(),()->(n)')
def _compute_euler_from_quat(quat: Array, axes: Array, extrinsic: bool, degrees: bool) -> Array:
  """Extract Euler angles for the axis triple ``axes`` from a quaternion.

  NOTE(review): appears to implement the direct quaternion-to-Euler
  algorithm used by SciPy (Bernardes & Viollet) — confirm against
  scipy.spatial.transform.Rotation.as_euler.
  """
  # Intrinsic rotations reverse both the axis order and which output slot
  # receives the first/third angle.
  angle_first = jnp.where(extrinsic, 0, 2)
  angle_third = jnp.where(extrinsic, 2, 0)
  axes = jnp.where(extrinsic, axes, axes[::-1])
  i = axes[0]
  j = axes[1]
  k = axes[2]
  # Symmetric ("proper Euler") sequences repeat the first axis (e.g. zxz).
  symmetric = i == k
  # For the symmetric case, substitute the remaining third axis for k.
  k = jnp.where(symmetric, 3 - i - j, k)
  # +1 / -1 depending on whether (i, j, k) is an even permutation of (0, 1, 2).
  sign = jnp.array((i - j) * (j - k) * (k - i) // 2, dtype=quat.dtype)
  eps = 1e-7
  a = jnp.where(symmetric, quat[3], quat[3] - quat[j])
  b = jnp.where(symmetric, quat[i], quat[i] + quat[k] * sign)
  c = jnp.where(symmetric, quat[j], quat[j] + quat[3])
  d = jnp.where(symmetric, quat[k] * sign, quat[k] * sign - quat[i])
  angles = jnp.empty(3, dtype=quat.dtype)
  angles = angles.at[1].set(2 * jnp.arctan2(jnp.hypot(c, d), jnp.hypot(a, b)))
  # Degenerate (gimbal-lock) cases: middle angle ~ pi (case 2) or ~ 0 (case 1).
  case = jnp.where(jnp.abs(angles[1] - np.pi) <= eps, 2, 0)
  case = jnp.where(jnp.abs(angles[1]) <= eps, 1, case)
  half_sum = jnp.arctan2(b, a)
  half_diff = jnp.arctan2(d, c)
  angles = angles.at[0].set(jnp.where(case == 1, 2 * half_sum, 2 * half_diff * jnp.where(extrinsic, -1, 1)))  # any degenerate case
  angles = angles.at[angle_first].set(jnp.where(case == 0, half_sum - half_diff, angles[angle_first]))
  angles = angles.at[angle_third].set(jnp.where(case == 0, half_sum + half_diff, angles[angle_third]))
  angles = angles.at[angle_third].set(jnp.where(symmetric, angles[angle_third], angles[angle_third] * sign))
  angles = angles.at[1].set(jnp.where(symmetric, angles[1], angles[1] - np.pi / 2))
  # Wrap every angle into the interval [-pi, pi).
  angles = (angles + np.pi) % (2 * np.pi) - np.pi
  return jnp.where(degrees, jnp.rad2deg(angles), angles)
|
||||
|
||||
|
||||
def _elementary_basis_index(axis: str) -> int:
|
||||
if axis == 'x':
|
||||
return 0
|
||||
elif axis == 'y':
|
||||
return 1
|
||||
elif axis == 'z':
|
||||
return 2
|
||||
raise ValueError(f"Expected axis to be from ['x', 'y', 'z'], got {axis}")
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature=('(m),(m),(),()->(n)'))
def _elementary_quat_compose(angles: Array, axes: Array, intrinsic: bool, degrees: bool) -> Array:
  """Compose per-axis elementary rotations into a single quaternion."""
  angles = jnp.where(degrees, jnp.deg2rad(angles), angles)
  out = _make_elementary_quat(axes[0], angles[0])
  for step_idx in range(1, len(axes)):
    step = _make_elementary_quat(axes[step_idx], angles[step_idx])
    # Intrinsic rotations compose on the right, extrinsic on the left.
    out = jnp.where(intrinsic, _compose_quat(out, step), _compose_quat(step, out))
  return out
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature=('(m),()->(n)'))
def _from_rotvec(rotvec: Array, degrees: bool) -> Array:
  """Rotation vector -> quaternion in (x, y, z, w) order.

  The result is ``[sin(angle/2) * axis, cos(angle/2)]`` where ``angle`` is
  the norm of ``rotvec`` and ``axis`` its direction.
  """
  rotvec = jnp.where(degrees, jnp.deg2rad(rotvec), rotvec)
  angle = _vector_norm(rotvec)
  angle2 = angle * angle
  # Taylor expansion of sin(angle/2)/angle near zero avoids division by zero.
  # (Fixed a stray chained assignment `small_scale = scale = ...` where
  # `scale` was immediately overwritten below.)
  small_scale = 0.5 - angle2 / 48 + angle2 * angle2 / 3840
  large_scale = jnp.sin(angle / 2) / angle
  scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
  return jnp.hstack([scale * rotvec, jnp.cos(angle / 2)])
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature=('(m,m)->(n)'))
def _from_matrix(matrix: Array) -> Array:
  """Convert a 3x3 rotation matrix to a quaternion in (x, y, z, w) order.

  Evaluates two candidate reconstructions branchlessly and selects via
  `jnp.where`, choosing the largest of the three diagonal entries or the
  trace for numerical stability.
  """
  matrix_trace = matrix[0, 0] + matrix[1, 1] + matrix[2, 2]
  decision = jnp.array([matrix[0, 0], matrix[1, 1], matrix[2, 2], matrix_trace], dtype=matrix.dtype)
  choice = jnp.argmax(decision)
  i = choice
  j = (i + 1) % 3
  k = (j + 1) % 3
  # Reconstruction used when a diagonal entry dominates (choice in {0, 1, 2}).
  quat_012 = jnp.empty(4, dtype=matrix.dtype)
  quat_012 = quat_012.at[i].set(1 - decision[3] + 2 * matrix[i, i])
  quat_012 = quat_012.at[j].set(matrix[j, i] + matrix[i, j])
  quat_012 = quat_012.at[k].set(matrix[k, i] + matrix[i, k])
  quat_012 = quat_012.at[3].set(matrix[k, j] - matrix[j, k])
  # Reconstruction used when the trace dominates (choice == 3).
  quat_3 = jnp.empty(4, dtype=matrix.dtype)
  quat_3 = quat_3.at[0].set(matrix[2, 1] - matrix[1, 2])
  quat_3 = quat_3.at[1].set(matrix[0, 2] - matrix[2, 0])
  quat_3 = quat_3.at[2].set(matrix[1, 0] - matrix[0, 1])
  quat_3 = quat_3.at[3].set(1 + decision[3])
  quat = jnp.where(choice != 3, quat_012, quat_3)
  return _normalize_quaternion(quat)
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(m)->(n)')
|
||||
def _from_mrp(mrp: Array) -> Array:
|
||||
mrp_squared_plus_1 = jnp.dot(mrp, mrp) + 1
|
||||
return jnp.hstack([2 * mrp[:3], (2 - mrp_squared_plus_1)]) / mrp_squared_plus_1
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(n)->(n)')
|
||||
def _inv(quat: Array) -> Array:
|
||||
return quat * jnp.array([-1, -1, -1, 1], dtype=quat.dtype)
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(n)->()')
def _magnitude(quat: Array) -> Array:
  """Rotation angle in [0, pi] encoded by the quaternion."""
  vec_part = _vector_norm(quat[:3])
  scalar_part = jnp.abs(quat[3])
  return 2. * jnp.arctan2(vec_part, scalar_part)
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(),()->(n)')
|
||||
def _make_elementary_quat(axis: int, angle: Array) -> Array:
|
||||
quat = jnp.zeros(4, dtype=angle.dtype)
|
||||
quat = quat.at[3].set(jnp.cos(angle / 2.))
|
||||
quat = quat.at[axis].set(jnp.sin(angle / 2.))
|
||||
return quat
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(n)->(n)')
def _normalize_quaternion(quat: Array) -> Array:
  """Scale a quaternion to unit Euclidean norm."""
  norm = _vector_norm(quat)
  return quat / norm
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(n)->()')
|
||||
def _vector_norm(vector: Array) -> Array:
|
||||
return jnp.sqrt(jnp.dot(vector, vector))
|
||||
|
||||
|
||||
@functools.partial(jnp_vectorize.vectorize, signature='(n)->(n)')
|
||||
def _make_canonical(quat: Array) -> Array:
|
||||
is_neg = quat < 0
|
||||
is_zero = quat == 0
|
||||
|
||||
neg = (
|
||||
is_neg[3]
|
||||
| (is_zero[3] & is_neg[0])
|
||||
| (is_zero[3] & is_zero[0] & is_neg[1])
|
||||
| (is_zero[3] & is_zero[0] & is_zero[1] & is_neg[2])
|
||||
)
|
||||
|
||||
return jnp.where(neg, -quat, quat)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,13 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,311 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import check_arraylike, promote_args_inexact
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
||||
|
||||
class ModeResult(NamedTuple):
  """Result of :func:`mode`: the modal value(s) and their count(s)."""
  # Modal (most common) value(s) along the reduced axis.
  mode: Array
  # Number of occurrences of each modal value.
  count: Array
|
||||
|
||||
|
||||
@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims'])
def mode(a: ArrayLike, axis: int | None = 0, nan_policy: str = "propagate", keepdims: bool = False) -> ModeResult:
  """Compute the mode (most common value) along an axis of an array.

  JAX implementation of :func:`scipy.stats.mode`.

  Args:
    a: arraylike
    axis: int, default=0. Axis along which to compute the mode.
    nan_policy: str. JAX only supports ``"propagate"``.
    keepdims: bool, default=False. If true, reduced axes are left in the result
      with size 1.

  Returns:
    A tuple of arrays, ``(mode, count)``. ``mode`` is the array of modal values,
    and ``count`` is the number of times each value appears in the input array.

  Examples:
    >>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
    >>> mode, count = jax.scipy.stats.mode(x)
    >>> mode, count
    (Array(4, dtype=int32), Array(3, dtype=int32))

    For multi dimensional arrays, ``jax.scipy.stats.mode`` computes the ``mode``
    and the corresponding ``count`` along ``axis=0``:

    >>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
    ...                 [3, 1, 3, 2, 1, 3],
    ...                 [1, 2, 2, 3, 1, 2]])
    >>> mode, count = jax.scipy.stats.mode(x1)
    >>> mode, count
    (Array([1, 2, 1, 3, 1, 1], dtype=int32), Array([2, 2, 1, 2, 2, 1], dtype=int32))

    If ``axis=1``, ``mode`` and ``count`` will be computed along ``axis 1``.

    >>> mode, count = jax.scipy.stats.mode(x1, axis=1)
    >>> mode, count
    (Array([1, 3, 2], dtype=int32), Array([3, 3, 3], dtype=int32))

    By default, ``jax.scipy.stats.mode`` reduces the dimension of the result.
    To keep the dimensions same as that of the input array, the argument
    ``keepdims`` must be set to ``True``.

    >>> mode, count = jax.scipy.stats.mode(x1, axis=1, keepdims=True)
    >>> mode, count
    (Array([[1],
            [3],
            [2]], dtype=int32), Array([[3],
            [3],
            [3]], dtype=int32))
  """
  check_arraylike("mode", a)
  x = jnp.atleast_1d(a)

  if nan_policy not in ["propagate", "omit", "raise"]:
    raise ValueError(
        f"Illegal nan_policy value {nan_policy!r}; expected one of "
        "{'propagate', 'omit', 'raise'}"
    )
  if nan_policy == "omit":
    # TODO: return answer without nans included.
    raise NotImplementedError(
        f"Logic for `nan_policy` of {nan_policy} is not implemented"
    )
  if nan_policy == "raise":
    raise NotImplementedError(
        "In order to best JIT compile `mode`, we cannot know whether `x` contains nans. "
        "Please check if nans exist in `x` outside of the `mode` function."
    )
  if axis is not None:
    axis = canonicalize_axis(axis, x.ndim)

  # Compute the output shape up front, before `x` is flattened below, so the
  # result can be reshaped regardless of axis/keepdims combinations.
  input_shape = x.shape
  if keepdims:
    if axis is None:
      output_shape = tuple(1 for i in input_shape)
    else:
      output_shape = tuple(1 if i == axis else s for i, s in enumerate(input_shape))
  else:
    if axis is None:
      output_shape = ()
    else:
      output_shape = tuple(s for i, s in enumerate(input_shape) if i != axis)

  # axis=None means the mode over the flattened input.
  if axis is None:
    axis = 0
    x = x.ravel()

  def _mode_helper(x: Array) -> tuple[Array, Array]:
    """Helper function to return mode and count of a given array."""
    if x.size == 0:
      return (jnp.array(np.nan, dtype=dtypes.default_float_dtype()),
              jnp.array(0, dtype=dtypes.default_float_dtype()))
    else:
      # size=x.size makes `unique` jit-compatible (static output size).
      vals, counts = jnp.unique(x, return_counts=True, size=x.size)
      return vals[jnp.argmax(counts)], counts.max()

  # Move the reduced axis to the front, collapse the rest, and vmap the
  # 1-D helper over the remaining (batched) dimension.
  x = jnp.moveaxis(x, axis, 0)
  x = x.reshape(x.shape[0], math.prod(x.shape[1:]))
  vals, counts = api.vmap(_mode_helper, in_axes=1)(x)
  return ModeResult(vals.reshape(output_shape), counts.reshape(output_shape))
|
||||
|
||||
def invert_permutation(i: Array) -> Array:
  """Helper function that inverts a permutation array."""
  positions = jnp.arange(i.size, dtype=i.dtype)
  # Scatter each element's position to the slot it maps to.
  return jnp.empty_like(i).at[i].set(positions)
|
||||
|
||||
|
||||
@api.jit(static_argnames=["method", "axis", "nan_policy"])
def rankdata(
    a: ArrayLike,
    method: str = "average",
    *,
    axis: int | None = None,
    nan_policy: str = "propagate",
) -> Array:
  """Compute the rank of data along an array axis.

  JAX implementation of :func:`scipy.stats.rankdata`.

  Ranks begin at 1, and the *method* argument controls how ties are handled.

  Args:
    a: arraylike
    method: str, default="average". Supported methods are
      ``("average", "min", "max", "dense", "ordinal")``
      For details, see the :func:`scipy.stats.rankdata` documentation.
    axis: optional integer. If not specified, the input array is flattened.
    nan_policy: str, JAX's implementation only supports ``"propagate"``.

  Returns:
    array of ranks along the specified axis.

  Examples:

    >>> x = jnp.array([10, 30, 20])
    >>> rankdata(x)
    Array([1., 3., 2.], dtype=float32)

    >>> x = jnp.array([1, 3, 2, 3])
    >>> rankdata(x)
    Array([1. , 3.5, 2. , 3.5], dtype=float32)
  """
  check_arraylike("rankdata", a)

  if nan_policy not in ["propagate", "omit", "raise"]:
    raise ValueError(
        f"Illegal nan_policy value {nan_policy!r}; expected one of "
        "{'propagate', 'omit', 'raise'}"
    )
  # "omit" and "raise" both require value-dependent control flow, which is
  # incompatible with jit tracing; only "propagate" is supported.
  if nan_policy == "omit":
    raise NotImplementedError(
        f"Logic for `nan_policy` of {nan_policy} is not implemented"
    )
  if nan_policy == "raise":
    raise NotImplementedError(
        "In order to best JIT compile `rankdata`, we cannot know whether `x` "
        "contains nans. Please check if nans exist in `x` outside of the "
        "`rankdata` function."
    )

  if method not in ("average", "min", "max", "dense", "ordinal"):
    raise ValueError(f"unknown method '{method}'")

  # For an explicit axis, rank each 1D slice independently via recursion on
  # the axis=None (flattened) path below.
  if axis is not None:
    return jnp.apply_along_axis(rankdata, axis, a, method)

  a = jnp.ravel(a)
  out_dtype = dtypes.default_float_dtype()

  def _rankdata(a: Array) -> Array:
    # Sort values together with their original positions; `inv` maps each
    # original position to its rank order (0-based) in the sorted array.
    arr, sorter = lax.sort_key_val(a, jnp.arange(a.size))
    inv = invert_permutation(sorter)

    if method == "ordinal":
      # Ordinal ranks ignore ties entirely: just 1-based sort positions.
      return (inv + 1).astype(out_dtype)
    # `obs[i]` is True where sorted value i starts a new run of equal values.
    obs = jnp.concatenate([jnp.array([True]), arr[1:] != arr[:-1]])
    # Cumulative count of distinct runs gives the 1-based dense rank.
    dense = obs.cumsum()[inv]
    if method == "dense":
      return dense.astype(out_dtype)
    # Boundaries of each run of equal values; padded with `obs.size` so the
    # static-size jnp.nonzero stays jit-compatible.
    count = jnp.nonzero(obs, size=arr.size + 1, fill_value=obs.size)[0].astype(out_dtype)
    if method == "max":
      return count[dense]
    if method == "min":
      return count[dense - 1] + 1
    if method == "average":
      # Midpoint of the min and max ranks of the tied group.
      return .5 * (count[dense] + count[dense - 1] + 1)
    raise ValueError(f"unknown method '{method}'")

  # nan_policy="propagate": any NaN in the input poisons the whole output.
  return lax.cond(jnp.any(jnp.isnan(a)),
                  lambda a: jnp.full_like(a, jnp.nan, out_dtype),
                  _rankdata, a)
|
||||
|
||||
|
||||
@api.jit(static_argnames=['axis', 'nan_policy', 'keepdims'])
def sem(a: ArrayLike, axis: int | None = 0, ddof: int = 1, nan_policy: str = "propagate", *, keepdims: bool = False) -> Array:
  """Compute the standard error of the mean.

  JAX implementation of :func:`scipy.stats.sem`.

  Args:
    a: arraylike
    axis: optional integer. If not specified, the input array is flattened.
    ddof: integer, default=1. The degrees of freedom in the SEM computation.
    nan_policy: str, default="propagate". JAX supports only "propagate" and
      "omit".
    keepdims: bool, default=False. If true, reduced axes are left in the result
      with size 1.

  Returns:
    array

  Examples:
    >>> x = jnp.array([2, 4, 1, 1, 3, 4, 4, 2, 3])
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x)
    Array(0.41, dtype=float32)

    For multi dimensional arrays, ``sem`` computes standard error of mean along
    ``axis=0``:

    >>> x1 = jnp.array([[1, 2, 1, 3, 2, 1],
    ...                 [3, 1, 3, 2, 1, 3],
    ...                 [1, 2, 2, 3, 1, 2]])
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x1)
    Array([0.67, 0.33, 0.58, 0.33, 0.33, 0.58], dtype=float32)

    If ``axis=1``, standard error of mean will be computed along ``axis 1``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x1, axis=1)
    Array([0.33, 0.4 , 0.31], dtype=float32)

    If ``axis=None``, standard error of mean will be computed along all the axes.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x1, axis=None)
    Array(0.2, dtype=float32)

    By default, ``sem`` reduces the dimension of the result. To keep the
    dimensions same as that of the input array, the argument ``keepdims`` must
    be set to ``True``.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x1, axis=1, keepdims=True)
    Array([[0.33],
           [0.4 ],
           [0.31]], dtype=float32)

    Since, by default, ``nan_policy='propagate'``, ``sem`` propagates the ``nan``
    values in the result.

    >>> nan = np.nan
    >>> x2 = jnp.array([[1, 2, 3, nan, 4, 2],
    ...                 [4, 5, 4, 3, nan, 1],
    ...                 [7, nan, 8, 7, 9, nan]])
    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x2)
    Array([1.73, nan, 1.53, nan, nan, nan], dtype=float32)

    If ``nan_policy='omit'``, ``sem`` omits the ``nan`` values and computes the error
    for the remaining values along the specified axis.

    >>> with jnp.printoptions(precision=2, suppress=True):
    ...   jax.scipy.stats.sem(x2, nan_policy='omit')
    Array([1.73, 1.5 , 1.53, 2. , 2.5 , 0.5 ], dtype=float32)
  """
  # Promote to an inexact (floating) dtype so std/sqrt are well defined.
  b, = promote_args_inexact("sem", a)
  if nan_policy == "propagate":
    # SEM = std / sqrt(sample size along the reduced axis).
    size = b.size if axis is None else b.shape[axis]
    return b.std(axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(size).astype(b.dtype)
  elif nan_policy == "omit":
    # Count only non-NaN entries so the denominator matches nanstd's sample.
    count = (~jnp.isnan(b)).sum(axis, keepdims=keepdims)
    return jnp.nanstd(b, axis, ddof=ddof, keepdims=keepdims) / jnp.sqrt(count).astype(b.dtype)
  else:
    raise ValueError(f"{nan_policy} is not supported")
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import xlogy, xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Bernoulli log probability mass function.

  JAX implementation of :obj:`scipy.stats.bernoulli` ``logpmf``

  The Bernoulli probability mass function is defined as

  .. math::

     f(k) = \begin{cases}
       1 - p, & k = 0 \\
       p, & k = 1 \\
       0, & \mathrm{otherwise}
     \end{cases}

  Args:
    k: arraylike, value at which to evaluate the PMF
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset

  Returns:
    array of logpmf values

  See Also:
    - :func:`jax.scipy.stats.bernoulli.cdf`
    - :func:`jax.scipy.stats.bernoulli.pmf`
    - :func:`jax.scipy.stats.bernoulli.ppf`
  """
  k, p, loc = promote_args_inexact("bernoulli.logpmf", k, p, loc)
  zero = _lax_const(k, 0)
  one = _lax_const(k, 1)
  # Shift by loc so x is the un-shifted Bernoulli outcome.
  x = lax.sub(k, loc)
  # x*log(p) + (1-x)*log(1-p); xlogy/xlog1py return 0 (not NaN) when the
  # first argument is 0, which handles the p=0 and p=1 endpoints.
  log_probs = xlogy(x, p) + xlog1py(lax.sub(one, x), -p)
  # Outside the support {0, 1} the probability is 0, i.e. log-prob -inf.
  return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)),
                   -np.inf, log_probs)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Bernoulli probability mass function.

  JAX implementation of :obj:`scipy.stats.bernoulli` ``pmf``

  The probability mass of a Bernoulli variable is

  .. math::

     f(k) = \begin{cases}
       1 - p, & k = 0 \\
       p, & k = 1 \\
       0, & \mathrm{otherwise}
     \end{cases}

  Args:
    k: arraylike, value at which to evaluate the PMF
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset

  Returns:
    array of pmf values

  See Also:
    - :func:`jax.scipy.stats.bernoulli.cdf`
    - :func:`jax.scipy.stats.bernoulli.logpmf`
    - :func:`jax.scipy.stats.bernoulli.ppf`
  """
  # Evaluate in log space and exponentiate; out-of-support values map
  # from -inf back to an exact 0.
  log_probs = logpmf(k, p, loc)
  return jnp.exp(log_probs)
|
||||
|
||||
|
||||
def cdf(k: ArrayLike, p: ArrayLike) -> Array:
  r"""Bernoulli cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.bernoulli` ``cdf``

  The Bernoulli cumulative distribution function is defined as:

  .. math::

     f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(k, p)

  where :math:`f_{pmf}(k, p)` is the Bernoulli probability mass function
  :func:`jax.scipy.stats.bernoulli.pmf`.

  Args:
    k: arraylike, value at which to evaluate the CDF
    p: arraylike, distribution shape parameter

  Returns:
    array of cdf values

  See Also:
    - :func:`jax.scipy.stats.bernoulli.logpmf`
    - :func:`jax.scipy.stats.bernoulli.pmf`
    - :func:`jax.scipy.stats.bernoulli.ppf`
  """
  k, p = promote_args_inexact('bernoulli.cdf', k, p)
  zero, one = _lax_const(k, 0), _lax_const(k, 1)
  # Piecewise definition: invalid inputs -> NaN; k < 0 -> 0;
  # 0 <= k < 1 -> P(X=0) = 1-p; k >= 1 -> 1.
  conds = [
      jnp.isnan(k) | jnp.isnan(p) | (p < zero) | (p > one),
      lax.lt(k, zero),
      jnp.logical_and(lax.ge(k, zero), lax.lt(k, one)),
      lax.ge(k, one)
  ]
  vals = [jnp.nan, zero, one - p, one]
  # jnp.select picks the value of the first condition that holds.
  return jnp.select(conds, vals)
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, p: ArrayLike) -> Array:
  """Bernoulli percent point function.

  JAX implementation of :obj:`scipy.stats.bernoulli` ``ppf``

  The percent point function is the inverse of the cumulative
  distribution function, :func:`jax.scipy.stats.bernoulli.cdf`.

  Args:
    q: arraylike, quantile at which to evaluate the PPF
    p: arraylike, distribution shape parameter

  Returns:
    array of ppf values

  See Also:
    - :func:`jax.scipy.stats.bernoulli.cdf`
    - :func:`jax.scipy.stats.bernoulli.logpmf`
    - :func:`jax.scipy.stats.bernoulli.pmf`
  """
  q, p = promote_args_inexact('bernoulli.ppf', q, p)
  zero, one = _lax_const(q, 0), _lax_const(q, 1)
  return jnp.where(
    # NaN for any invalid probability or quantile outside [0, 1].
    jnp.isnan(q) | jnp.isnan(p) | (p < zero) | (p > one) | (q < zero) | (q > one),
    jnp.nan,
    # Inverse of the step CDF: q <= P(X=0) = 1-p maps to 0, otherwise 1.
    jnp.where(lax.le(q, one - p), zero, one)
  )
|
||||
@@ -0,0 +1,262 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import betaln, betainc, xlogy, xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
           loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta log probability distribution function.

  JAX implementation of :obj:`scipy.stats.beta` ``logpdf``.

  The pdf of the beta function is:

  .. math::

     f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1}

  where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.
  It is defined for :math:`0\le x\le 1`, :math:`a>0`, and :math:`b>0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("beta.logpdf", x, a, b, loc, scale)
  one = _lax_const(x, 1)
  zero = _lax_const(a, 0)
  # -log B(a, b): the log of the normalizing constant.
  shape_term = lax.neg(betaln(a, b))
  # Standardize x to the unit interval via (x - loc) / scale.
  y = lax.div(lax.sub(x, loc), scale)
  # (a-1)*log(y) + (b-1)*log(1-y); xlogy/xlog1py avoid NaN at the endpoints.
  log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
                            xlog1py(lax.sub(b, one), lax.neg(y)))
  # Subtract log(scale) to account for the change of variables.
  log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
  # Outside the support [loc, loc + scale] the density is 0 (logpdf -inf).
  result = jnp.where(jnp.logical_or(lax.gt(x, lax.add(loc, scale)),
                                    lax.lt(x, loc)), -np.inf, log_probs)
  # Invalid shape/scale parameters (a <= 0, b <= 0, or scale <= 0) give NaN.
  result_positive_constants = jnp.where(jnp.logical_or(jnp.logical_or(lax.le(a, zero), lax.le(b, zero)),
                                                       lax.le(scale, zero)), np.nan, result)
  return result_positive_constants
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
        loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta probability distribution function.

  JAX implementation of :obj:`scipy.stats.beta` ``pdf``.

  The pdf of the beta function is:

  .. math::

     f(x, a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1}(1-x)^{b-1}

  where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.
  It is defined for :math:`0\le x\le 1` and :math:`b>0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logpdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  # Compute in log space for numerical stability, then exponentiate.
  log_p = logpdf(x, a, b, loc, scale)
  return lax.exp(log_p)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
        loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta cumulative distribution function

  JAX implementation of :obj:`scipy.stats.beta` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y

  where :math:`f_{pdf}` is the beta distribution probability density function,
  :func:`jax.scipy.stats.beta.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values

  See Also:
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logpdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("beta.cdf", x, a, b, loc, scale)
  # Standardize x onto the unit interval, then clip so values outside the
  # support map to the CDF limits 0 and 1.
  standardized = lax.div(lax.sub(x, loc), scale)
  clipped = lax.clamp(_lax_const(x, 0), standardized, _lax_const(x, 1))
  # The beta CDF is the regularized incomplete beta function I_x(a, b).
  return betainc(a, b, clipped)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
           loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.beta` ``logcdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, a, b) = \int_{-\infty}^x f_{pdf}(y, a, b)\mathrm{d}y

  where :math:`f_{pdf}` is the beta distribution probability density function,
  :func:`jax.scipy.stats.beta.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logpdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  # Take the log of the CDF directly; below-support inputs give log(0) = -inf.
  cdf_vals = cdf(x, a, b, loc, scale)
  return lax.log(cdf_vals)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
       loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta distribution survival function.

  JAX implementation of :obj:`scipy.stats.beta` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b)

  where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function,
  :func:`jax.scipy.stats.beta.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logpdf`
    - :func:`jax.scipy.stats.beta.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("beta.sf", x, a, b, loc, scale)
  # Standardize and clip to [0, 1] so out-of-support inputs hit the limits.
  standardized = lax.div(lax.sub(x, loc), scale)
  clipped = lax.clamp(_lax_const(x, 0), standardized, _lax_const(x, 1))
  # 1 - I_x(a, b) == I_{1-x}(b, a): evaluate betainc with swapped shape
  # parameters at the reflected point for better accuracy than 1 - cdf.
  return betainc(b, a, 1 - clipped)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, a: ArrayLike, b: ArrayLike,
          loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Beta distribution log survival function.

  JAX implementation of :obj:`scipy.stats.beta` ``logsf``.

  The survival function is defined as

  .. math::

     f_{sf}(x, a, b) = 1 - f_{cdf}(x, a, b)

  where :math:`f_{cdf}(x, a, b)` is the beta cumulative distribution function,
  :func:`jax.scipy.stats.beta.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.beta.cdf`
    - :func:`jax.scipy.stats.beta.pdf`
    - :func:`jax.scipy.stats.beta.sf`
    - :func:`jax.scipy.stats.beta.logcdf`
    - :func:`jax.scipy.stats.beta.logpdf`
  """
  # Log of the survival function; above-support inputs give log(0) = -inf.
  sf_vals = sf(x, a, b, loc, scale)
  return lax.log(sf_vals)
|
||||
@@ -0,0 +1,98 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import betaln
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@api.jit
def logpmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
           loc: ArrayLike = 0) -> Array:
  r"""Beta-binomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.betabinom` ``logpmf``

  The beta-binomial distribution's probability mass function is defined as

  .. math::

     f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k+b)}{B(a,b)}

  where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is
  defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpmf values

  See Also:
    :func:`jax.scipy.stats.betabinom.pmf`
  """
  k, n, a, b, loc = promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
  # Floor non-integer k (matching scipy), then shift by loc.
  y = lax.sub(lax.floor(k), loc)
  one = _lax_const(y, 1)
  zero = _lax_const(y, 0)
  # log C(n, y) via the identity log C(n, y) = -log(n+1) - log B(n-y+1, y+1).
  combiln = lax.neg(lax.add(lax.log1p(n), betaln(lax.add(lax.sub(n,y), one), lax.add(y,one))))
  # log B(y+a, n-y+b) - log B(a, b): the beta-function ratio of the pmf.
  beta_lns = lax.sub(betaln(lax.add(y,a), lax.add(lax.sub(n,y),b)), betaln(a,b))
  log_probs = lax.add(combiln, beta_lns)
  # Degenerate n = 0 distribution: all mass at y = 0, so logpmf = 0 there.
  log_probs = jnp.where(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)), 0., log_probs)
  # Outside the support the probability is 0 (logpmf -inf).
  # NOTE(review): lax.lt(y, lax.neg(loc)) tests floor(k) < 0 rather than
  # floor(k) < loc (i.e. y < 0) — confirm this is the intended lower bound.
  y_cond = jnp.logical_or(jnp.logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, n)),
                          lax.le(lax.add(y, a), zero))
  log_probs = jnp.where(y_cond, -np.inf, log_probs)
  # Invalid distribution parameters (n < 0, a <= 0, or b <= 0) give NaN.
  n_a_b_cond = jnp.logical_or(jnp.logical_or(lax.lt(n, zero), lax.le(a, zero)), lax.le(b, zero))
  return jnp.where(n_a_b_cond, np.nan, log_probs)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, n: ArrayLike, a: ArrayLike, b: ArrayLike,
        loc: ArrayLike = 0) -> Array:
  r"""Beta-binomial probability mass function.

  JAX implementation of :obj:`scipy.stats.betabinom` ``pmf``.

  The beta-binomial distribution's probability mass function is defined as

  .. math::

     f(k, n, a, b) = {n \choose k}\frac{B(k+a,n-k+b)}{B(a,b)}

  where :math:`B(a, b)` is the :func:`~jax.scipy.special.beta` function. It is
  defined for :math:`n\ge 0`, :math:`a>0`, :math:`b>0`, and non-negative integers `k`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values

  See Also:
    :func:`jax.scipy.stats.betabinom.logpmf`
  """
  # Exponentiate the log-space pmf; -inf outside the support becomes 0.
  log_probs = logpmf(k, n, a, b, loc)
  return lax.exp(log_probs)
|
||||
@@ -0,0 +1,90 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.scipy.special import gammaln, xlogy, xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Binomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.binom` ``logpmf``.

  The binomial probability mass function is defined as

  .. math::

     f(k, n, p) = {n \choose k}p^k(1-p)^{n-k}

  for :math:`0\le p\le 1` and non-negative integers :math:`k`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpmf values.

  See Also:
    :func:`jax.scipy.stats.binom.pmf`
  """
  k, n, p, loc = promote_args_inexact("binom.logpmf", k, n, p, loc)
  # Shift by loc so y is the un-shifted binomial outcome.
  y = lax.sub(k, loc)
  zero = _lax_const(y, 0)
  # log C(n, y) = lgamma(n+1) - lgamma(y+1) - lgamma(n-y+1).
  comb_term = lax.sub(
      gammaln(n + 1),
      lax.add(gammaln(y + 1), gammaln(n - y + 1))
  )
  # y*log(p) + (n-y)*log(1-p); xlogy/xlog1py are safe at p = 0 and p = 1.
  log_linear_term = lax.add(xlogy(y, p), xlog1py(lax.sub(n, y), lax.neg(p)))
  log_probs = lax.add(comb_term, log_linear_term)
  # Degenerate cases with certain outcome (y = n = 0, or the linear term
  # vanishing exactly) have probability 1, so force logpmf = 0 there.
  y_n_cond = jnp.logical_or(jnp.logical_and(lax.eq(y, zero), lax.eq(n, zero)),
                            lax.eq(log_linear_term, zero))
  log_probs = jnp.where(y_n_cond, 0., log_probs)
  # Outside the support [loc, loc + n] the probability is 0 (logpmf -inf).
  return jnp.where(lax.ge(k, loc) & lax.lt(k, loc + n + 1), log_probs, -np.inf)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Binomial probability mass function.

  JAX implementation of :obj:`scipy.stats.binom` ``pmf``.

  The binomial probability mass function is defined as

  .. math::

     f(k, n, p) = {n \choose k}p^k(1-p)^{n-k}

  for :math:`0\le p\le 1` and non-negative integers :math:`k`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values.

  See Also:
    :func:`jax.scipy.stats.binom.logpmf`
  """
  # Exponentiate the log-space pmf; -inf outside the support becomes 0.
  log_probs = logpmf(k, n, p, loc)
  return lax.exp(log_probs)
|
||||
@@ -0,0 +1,293 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.ufuncs import arctan
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy log probability distribution function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``logpdf``.

  The Cauchy probability distribution function is

  .. math::

     f(x) = \frac{1}{\pi(1 + x^2)}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values

  See Also:
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.pdf`
    - :func:`jax.scipy.stats.cauchy.sf`
    - :func:`jax.scipy.stats.cauchy.logcdf`
    - :func:`jax.scipy.stats.cauchy.logsf`
    - :func:`jax.scipy.stats.cauchy.isf`
    - :func:`jax.scipy.stats.cauchy.ppf`
  """
  x, loc, scale = promote_args_inexact("cauchy.logpdf", x, loc, scale)
  # Standardize the evaluation point: z = (x - loc) / scale.
  z = lax.div(lax.sub(x, loc), scale)
  # logpdf = -(log(pi * scale) + log(1 + z^2)); log1p keeps small z accurate.
  log_norm = lax.log(lax.mul(_lax_const(x, np.pi), scale))
  return lax.neg(lax.add(log_norm, lax.log1p(lax.mul(z, z))))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy probability distribution function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``pdf``.

  The Cauchy probability distribution function is

  .. math::

     f(x) = \frac{1}{\pi(1 + x^2)}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values

  See Also:
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.sf`
    - :func:`jax.scipy.stats.cauchy.logcdf`
    - :func:`jax.scipy.stats.cauchy.logpdf`
    - :func:`jax.scipy.stats.cauchy.logsf`
    - :func:`jax.scipy.stats.cauchy.isf`
    - :func:`jax.scipy.stats.cauchy.ppf`
  """
  # Evaluate in log space and exponentiate for numerical stability.
  log_p = logpdf(x, loc, scale)
  return lax.exp(log_p)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``cdf``.

  The cdf has the closed form

  .. math::

     f_{cdf}(x) = \frac{1}{2} + \frac{1}{\pi}
       \arctan\left(\frac{x - \mathrm{loc}}{\mathrm{scale}}\right)

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.pdf`
    - :func:`jax.scipy.stats.cauchy.sf`
    - :func:`jax.scipy.stats.cauchy.ppf`
  """
  x, loc, scale = promote_args_inexact("cauchy.cdf", x, loc, scale)
  half = _lax_const(x, 0.5)
  unit = _lax_const(x, 1.)
  pi = _lax_const(x, np.pi)
  standardized = lax.div(lax.sub(x, loc), scale)
  # 1/2 + (1/pi) * arctan((x - loc)/scale)
  return lax.add(half, lax.mul(lax.div(unit, pi), arctan(standardized)))
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``logcdf``.

  Computes :math:`\log f_{cdf}(x)` where :math:`f_{cdf}` is the Cauchy
  cumulative distribution function, :func:`jax.scipy.stats.cauchy.cdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.logsf`
    - :func:`jax.scipy.stats.cauchy.logpdf`
  """
  cdf_values = cdf(x, loc, scale)
  return lax.log(cdf_values)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy distribution survival function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``sf``.

  The survival function is

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.cauchy.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values

  See Also:
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.isf`
    - :func:`jax.scipy.stats.cauchy.logsf`
  """
  x, loc, scale = promote_args_inexact("cauchy.sf", x, loc, scale)
  # The Cauchy density is symmetric about loc, so P(X > x) equals the cdf of
  # the mirrored distribution evaluated at -x.
  return cdf(lax.neg(x), lax.neg(loc), scale)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy distribution log survival function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``logsf``.

  Computes :math:`\log(1 - f_{cdf}(x))` where :math:`f_{cdf}(x)` is the
  cumulative distribution function, :func:`jax.scipy.stats.cauchy.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.sf`
    - :func:`jax.scipy.stats.cauchy.logcdf`
    - :func:`jax.scipy.stats.cauchy.isf`
  """
  x, loc, scale = promote_args_inexact("cauchy.logsf", x, loc, scale)
  # Symmetry: log sf(x) = log cdf at -x of the mirrored distribution.
  return logcdf(lax.neg(x), lax.neg(loc), scale)
|
||||
|
||||
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy distribution inverse survival function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``isf``.

  Returns the inverse of the survival function,
  :func:`jax.scipy.stats.cauchy.sf`.

  Args:
    q: arraylike, value at which to evaluate the ISF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of isf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.sf`
    - :func:`jax.scipy.stats.cauchy.ppf`
  """
  q, loc, scale = promote_args_inexact("cauchy.isf", q, loc, scale)
  pi = _lax_const(q, np.pi)
  half_pi = _lax_const(q, np.pi / 2)
  # Closed form: x = loc + scale * tan(pi/2 - pi*q), obtained by inverting
  # sf(x) = 1/2 - arctan((x - loc)/scale)/pi.
  standardized = lax.tan(lax.sub(half_pi, lax.mul(pi, q)))
  return lax.add(lax.mul(standardized, scale), loc)
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Cauchy distribution percent point function.

  JAX implementation of :obj:`scipy.stats.cauchy` ``ppf``.

  The percent point function is the inverse of the cumulative distribution
  function, :func:`jax.scipy.stats.cauchy.cdf`.

  Args:
    q: arraylike, value at which to evaluate the PPF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values.

  See Also:
    - :func:`jax.scipy.stats.cauchy.cdf`
    - :func:`jax.scipy.stats.cauchy.isf`
  """
  q, loc, scale = promote_args_inexact("cauchy.ppf", q, loc, scale)
  pi = _lax_const(q, np.pi)
  half_pi = _lax_const(q, np.pi / 2)
  # Closed form: x = loc + scale * tan(pi*q - pi/2), obtained by inverting
  # cdf(x) = 1/2 + arctan((x - loc)/scale)/pi.
  standardized = lax.tan(lax.sub(lax.mul(pi, q), half_pi))
  return lax.add(lax.mul(standardized, scale), loc)
|
||||
@@ -0,0 +1,267 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import gammainc, gammaincc
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square log probability distribution function.

  JAX implementation of :obj:`scipy.stats.chi2` ``logpdf``.

  The chi-square probability distribution function is:

  .. math::

     f(x, k) = \begin{cases}
       \frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\
       0 & \mathrm{otherwise}
     \end{cases}

  for :math:`k` degrees of freedom, where :math:`\Gamma` is the
  :func:`~jax.scipy.special.gamma` function. Following scipy, ``df`` denotes
  the degrees of freedom.

  Args:
    x: arraylike, value at which to evaluate the PDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.pdf`
    - :func:`jax.scipy.stats.chi2.cdf`
    - :func:`jax.scipy.stats.chi2.sf`
  """
  x, df, loc, scale = promote_args_inexact("chi2.logpdf", x, df, loc, scale)
  one = _lax_const(x, 1)
  two = _lax_const(x, 2)
  y = lax.div(lax.sub(x, loc), scale)
  half_df = lax.div(df, two)
  # Unnormalized log density: (k/2 - 1) log(y) - y/2.
  kernel = lax.sub(lax.mul(lax.sub(half_df, one), lax.log(y)), lax.div(y, two))
  # Log normalization constant: -(lgamma(k/2) + (k/2) log 2).
  norm_const = lax.neg(lax.add(lax.lgamma(half_df), lax.div(lax.mul(lax.log(two), df), two)))
  log_probs = lax.add(lax.sub(norm_const, lax.log(scale)), kernel)
  # The density is zero (log density -inf) below the support's lower edge.
  return jnp.where(lax.lt(x, loc), -np.inf, log_probs)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square probability distribution function.

  JAX implementation of :obj:`scipy.stats.chi2` ``pdf``.

  The chi-square probability distribution function is:

  .. math::

     f(x, k) = \begin{cases}
       \frac{x^{k/2-1}e^{-x/2}}{2^{k/2}\Gamma(k/2)} & x \ge 0 \\
       0 & \mathrm{otherwise}
     \end{cases}

  for :math:`k` degrees of freedom (scipy's ``df``), where :math:`\Gamma` is
  the :func:`~jax.scipy.special.gamma` function.

  Args:
    x: arraylike, value at which to evaluate the PDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.logpdf`
    - :func:`jax.scipy.stats.chi2.cdf`
    - :func:`jax.scipy.stats.chi2.sf`
  """
  # Delegate to logpdf and exponentiate.
  log_values = logpdf(x, df, loc, scale)
  return lax.exp(log_values)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.chi2` ``cdf``.

  The cdf is

  .. math::

     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is :func:`jax.scipy.stats.chi2.pdf`, and is computed
  via the regularized lower incomplete gamma function. ``df`` denotes degrees
  of freedom, following scipy.

  Args:
    x: arraylike, value at which to evaluate the CDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.pdf`
    - :func:`jax.scipy.stats.chi2.sf`
    - :func:`jax.scipy.stats.chi2.logcdf`
  """
  x, df, loc, scale = promote_args_inexact("chi2.cdf", x, df, loc, scale)
  two = _lax_const(scale, 2)
  half_df = lax.div(df, two)
  y = lax.div(lax.sub(x, loc), lax.mul(scale, two))
  # Clamp to [0, inf): gammainc's second argument must be non-negative, and
  # the cdf is exactly zero below the support's lower edge.
  y = lax.clamp(_lax_const(x, 0), y, _lax_const(x, np.inf))
  return gammainc(half_df, y)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.chi2` ``logcdf``.

  Computes :math:`\log f_{cdf}(x, k)` where :math:`f_{cdf}` is the chi-square
  cumulative distribution function, :func:`jax.scipy.stats.chi2.cdf`. ``df``
  denotes degrees of freedom, following scipy.

  Args:
    x: arraylike, value at which to evaluate the CDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values

  See Also:
    - :func:`jax.scipy.stats.chi2.cdf`
    - :func:`jax.scipy.stats.chi2.logsf`
    - :func:`jax.scipy.stats.chi2.logpdf`
  """
  cdf_values = cdf(x, df, loc, scale)
  return lax.log(cdf_values)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square survival function.

  JAX implementation of :obj:`scipy.stats.chi2` ``sf``.

  The survival function is

  .. math::

     f_{sf}(x, k) = 1 - f_{cdf}(x, k)

  where :math:`f_{cdf}(x, k)` is :func:`jax.scipy.stats.chi2.cdf`; it is
  computed directly via the regularized upper incomplete gamma function.
  ``df`` denotes degrees of freedom, following scipy.

  Args:
    x: arraylike, value at which to evaluate the SF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.cdf`
    - :func:`jax.scipy.stats.chi2.logsf`
    - :func:`jax.scipy.stats.chi2.pdf`
  """
  x, df, loc, scale = promote_args_inexact("chi2.sf", x, df, loc, scale)
  two = _lax_const(scale, 2)
  half_df = lax.div(df, two)
  y = lax.div(lax.sub(x, loc), lax.mul(scale, two))
  # Clamp to [0, inf): the sf is exactly one below the support's lower edge.
  y = lax.clamp(_lax_const(x, 0), y, _lax_const(x, np.inf))
  return gammaincc(half_df, y)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Chi-square log survival function.

  JAX implementation of :obj:`scipy.stats.chi2` ``logsf``.

  Computes :math:`\log(1 - f_{cdf}(x, k))` where :math:`f_{cdf}(x, k)` is the
  chi-square cumulative distribution function,
  :func:`jax.scipy.stats.chi2.cdf`. ``df`` denotes degrees of freedom,
  following scipy.

  Args:
    x: arraylike, value at which to evaluate the SF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.chi2.sf`
    - :func:`jax.scipy.stats.chi2.logcdf`
    - :func:`jax.scipy.stats.chi2.logpdf`
  """
  sf_values = sf(x, df, loc, scale)
  return lax.log(sf_values)
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_dtypes_inexact
|
||||
from jax._src.scipy.special import gammaln, xlogy
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def _is_simplex(x: Array) -> Array:
|
||||
x_sum = jnp.sum(x, axis=0)
|
||||
return jnp.all(x > 0, axis=0) & (abs(x_sum - 1) < 1E-6)
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, alpha: ArrayLike) -> Array:
  r"""Dirichlet log probability distribution function.

  JAX implementation of :obj:`scipy.stats.dirichlet` ``logpdf``.

  The Dirichlet probability density function is

  .. math::

     f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}

  where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta`
  function in a :math:`K`-dimensional vector space.

  Args:
    x: arraylike, value at which to evaluate the PDF
    alpha: arraylike, distribution shape parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.dirichlet.pdf`
  """
  # Promote both arguments to a common inexact dtype before the core
  # implementation runs.
  x, alpha = promote_dtypes_inexact(x, alpha)
  return _logpdf(x, alpha)
|
||||
|
||||
def _logpdf(x: Array, alpha: Array) -> Array:
|
||||
if alpha.ndim != 1:
|
||||
raise ValueError(
|
||||
f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
|
||||
)
|
||||
if x.shape[0] not in (alpha.shape[0], alpha.shape[0] - 1):
|
||||
raise ValueError(
|
||||
"`x` must have either the same number of entries as `alpha` "
|
||||
f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
|
||||
)
|
||||
one = _lax_const(x, 1)
|
||||
if x.shape[0] != alpha.shape[0]:
|
||||
x = jnp.concatenate([x, lax.sub(one, x.sum(0, keepdims=True))], axis=0)
|
||||
normalize_term = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
|
||||
if x.ndim > 1:
|
||||
alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
|
||||
log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0), normalize_term)
|
||||
return jnp.where(_is_simplex(x), log_probs, -np.inf)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, alpha: ArrayLike) -> Array:
  r"""Dirichlet probability distribution function.

  JAX implementation of :obj:`scipy.stats.dirichlet` ``pdf``.

  The Dirichlet probability density function is

  .. math::

     f(\mathbf{x}) = \frac{1}{B(\mathbf{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1}

  where :math:`B(\mathbf{\alpha})` is the :func:`~jax.scipy.special.beta`
  function in a :math:`K`-dimensional vector space.

  Args:
    x: arraylike, value at which to evaluate the PDF
    alpha: arraylike, distribution shape parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.dirichlet.logpdf`
  """
  # Evaluate in log space and exponentiate.
  log_values = logpdf(x, alpha)
  return lax.exp(log_values)
|
||||
@@ -0,0 +1,270 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential log probability distribution function.

  JAX implementation of :obj:`scipy.stats.expon` ``logpdf``.

  The Exponential probability distribution function is

  .. math::

     f(x) = \begin{cases}
       e^{-x} & x \ge 0 \\
       0 & \mathrm{otherwise}
     \end{cases}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.expon.pdf`
    :func:`jax.scipy.stats.expon.cdf`
    :func:`jax.scipy.stats.expon.sf`
  """
  x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
  standardized = lax.div(lax.sub(x, loc), scale)
  # On the support: log f = -(x - loc)/scale - log(scale).
  log_probs = lax.neg(lax.add(standardized, lax.log(scale)))
  # Below loc the density is zero, so the log density is -inf.
  return jnp.where(lax.lt(x, loc), -np.inf, log_probs)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential probability distribution function.

  JAX implementation of :obj:`scipy.stats.expon` ``pdf``.

  The Exponential probability distribution function is

  .. math::

     f(x) = \begin{cases}
       e^{-x} & x \ge 0 \\
       0 & \mathrm{otherwise}
     \end{cases}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.expon.logpdf`
    :func:`jax.scipy.stats.expon.cdf`
    :func:`jax.scipy.stats.expon.sf`
  """
  # Evaluate in log space and exponentiate.
  log_values = logpdf(x, loc, scale)
  return lax.exp(log_values)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential cumulative density function.

  JAX implementation of :obj:`scipy.stats.expon` ``cdf``.

  The cdf is

  .. math::

     f_{cdf}(x) = 1 - e^{-(x - \mathrm{loc})/\mathrm{scale}}

  for :math:`x \ge \mathrm{loc}` and zero otherwise.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    :func:`jax.scipy.stats.expon.pdf`
    :func:`jax.scipy.stats.expon.sf`
    :func:`jax.scipy.stats.expon.logcdf`
    :func:`jax.scipy.stats.expon.ppf`
  """
  x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale)
  # z = -(x - loc)/scale, so the cdf is 1 - exp(z) = -expm1(z); expm1 keeps
  # precision for small arguments.
  z = lax.div(lax.sub(loc, x), scale)
  below_support = lax.lt(x, loc)
  return jnp.where(below_support, jnp.zeros_like(z), lax.neg(lax.expm1(z)))
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential log cumulative density function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``logcdf``.
|
||||
|
||||
The cdf is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
|
||||
|
||||
where :math:`f_{pdf}` is the exponential distribution probability density function,
|
||||
:func:`jax.scipy.stats.expon.pdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
return lax.log1p(lax.neg(sf(x, loc, scale)))
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential log survival function.

  JAX implementation of :obj:`scipy.stats.expon` ``logsf``.

  The survival function is

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the exponential cumulative distribution
  function, :func:`jax.scipy.stats.expon.cdf`; its log has the closed form
  :math:`-(x - \mathrm{loc})/\mathrm{scale}` on the support.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    :func:`jax.scipy.stats.expon.sf`
    :func:`jax.scipy.stats.expon.logcdf`
    :func:`jax.scipy.stats.expon.logpdf`
  """
  x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale)
  # logsf = -(x - loc)/scale on the support; below loc, sf == 1 so logsf == 0.
  z = lax.div(lax.sub(loc, x), scale)
  return jnp.where(lax.lt(x, loc), jnp.zeros_like(z), z)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""Exponential survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.expon` ``sf``.
|
||||
|
||||
The survival function is defined as
|
||||
|
||||
.. math::
|
||||
|
||||
f_{sf}(x) = 1 - f_{cdf}(x)
|
||||
|
||||
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
|
||||
:func:`jax.scipy.stats.expon.cdf`.
|
||||
|
||||
Args:
|
||||
x: arraylike, value at which to evaluate the PDF
|
||||
loc: arraylike, distribution offset parameter
|
||||
scale: arraylike, distribution scale parameter
|
||||
|
||||
Returns:
|
||||
array of pdf values.
|
||||
|
||||
See Also:
|
||||
:func:`jax.scipy.stats.expon.cdf`
|
||||
:func:`jax.scipy.stats.expon.pdf`
|
||||
:func:`jax.scipy.stats.expon.ppf`
|
||||
:func:`jax.scipy.stats.expon.sf`
|
||||
:func:`jax.scipy.stats.expon.logcdf`
|
||||
:func:`jax.scipy.stats.expon.logpdf`
|
||||
:func:`jax.scipy.stats.expon.logsf`
|
||||
"""
|
||||
return lax.exp(logsf(x, loc, scale))
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Exponential percent point function.

  JAX implementation of :obj:`scipy.stats.expon` ``ppf``.

  The percent point function is defined as the inverse of the
  cumulative distribution function, :func:`jax.scipy.stats.expon.cdf`:

  .. math::

     f_{ppf}(q) = \mathrm{loc} - \mathrm{scale}\,\log(1 - q)

  Args:
    q: arraylike, quantile (in ``[0, 1]``) at which to evaluate the PPF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values; ``nan`` where ``q`` is outside ``[0, 1]`` or ``nan``.

  See Also:
    :func:`jax.scipy.stats.expon.cdf`
    :func:`jax.scipy.stats.expon.sf`
    :func:`jax.scipy.stats.expon.logcdf`
  """
  q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale)
  # Invert q = 1 - exp(-(x - loc)/scale)  =>  x = loc - scale * log(1 - q).
  # The previous formula, -log1p((loc - q)/scale), agreed with this only for
  # the default loc=0, scale=1; it mis-handled nonzero loc and scale != 1.
  unscaled = lax.neg(lax.log1p(lax.neg(q)))
  return jnp.where(
    jnp.isnan(q) | (q < 0) | (q > 1),
    np.nan,
    lax.add(loc, lax.mul(scale, unscaled)),
  )
|
||||
@@ -0,0 +1,237 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import gammaln, xlogy, gammainc, gammaincc
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma log probability distribution function.

  JAX implementation of :obj:`scipy.stats.gamma` ``logpdf``.

  The Gamma probability distribution is given by

  .. math::

     f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x}

  where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.
  It is defined for :math:`x \ge 0` and :math:`a > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.pdf`
    - :func:`jax.scipy.stats.gamma.cdf`
    - :func:`jax.scipy.stats.gamma.sf`
  """
  x, a, loc, scale = promote_args_inexact("gamma.logpdf", x, a, loc, scale)
  in_support = lax.ge(x, loc)
  one = _lax_const(x, 1)
  # Substitute a dummy y=1 outside the support so log/xlogy stay finite there;
  # those lanes are overwritten with -inf below.
  y = jnp.where(in_support, lax.div(lax.sub(x, loc), scale), one)
  # Unnormalized log density: (a - 1) log(y) - y, with xlogy handling y == 0.
  unnormalized = lax.sub(xlogy(lax.sub(a, one), y), y)
  # Log normalization: lgamma(a) + log(scale).
  log_norm = lax.add(gammaln(a), lax.log(scale))
  log_probs = lax.sub(unnormalized, log_norm)
  return jnp.where(in_support, log_probs, -np.inf)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma probability distribution function.

  JAX implementation of :obj:`scipy.stats.gamma` ``pdf``.

  The Gamma probability distribution is given by

  .. math::

     f(x, a) = \frac{1}{\Gamma(a)}x^{a-1}e^{-x}

  Where :math:`\Gamma(a)` is the :func:`~jax.scipy.special.gamma` function.
  It is defined for :math:`x \ge 0` and :math:`a > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.cdf`
    - :func:`jax.scipy.stats.gamma.sf`
    - :func:`jax.scipy.stats.gamma.logcdf`
    - :func:`jax.scipy.stats.gamma.logpdf`
    - :func:`jax.scipy.stats.gamma.logsf`
  """
  # Computed by exponentiating the (numerically safer) log-density.
  log_p = logpdf(x, a, loc, scale)
  return lax.exp(log_p)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.gamma` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.gamma.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.pdf`
    - :func:`jax.scipy.stats.gamma.sf`
    - :func:`jax.scipy.stats.gamma.logcdf`
    - :func:`jax.scipy.stats.gamma.logpdf`
    - :func:`jax.scipy.stats.gamma.logsf`
  """
  x, a, loc, scale = promote_args_inexact("gamma.cdf", x, a, loc, scale)
  zero = _lax_const(x, 0)
  inf = _lax_const(x, np.inf)
  # Standardize x and clip below the support so gammainc sees a valid argument.
  std_x = lax.div(lax.sub(x, loc), scale)
  return gammainc(a, lax.clamp(zero, std_x, inf))
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.gamma` ``logcdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, a) = \int_{-\infty}^x f_{pdf}(y, a)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.gamma.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.cdf`
    - :func:`jax.scipy.stats.gamma.pdf`
    - :func:`jax.scipy.stats.gamma.sf`
    - :func:`jax.scipy.stats.gamma.logpdf`
    - :func:`jax.scipy.stats.gamma.logsf`
  """
  # Defined as the log of the regular cdf.
  cdf_values = cdf(x, a, loc, scale)
  return lax.log(cdf_values)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma survival function.

  JAX implementation of :obj:`scipy.stats.gamma` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x, k) = 1 - f_{cdf}(x, k)

  where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
  :func:`jax.scipy.stats.gamma.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.cdf`
    - :func:`jax.scipy.stats.gamma.pdf`
    - :func:`jax.scipy.stats.gamma.logcdf`
    - :func:`jax.scipy.stats.gamma.logpdf`
    - :func:`jax.scipy.stats.gamma.logsf`
  """
  x, a, loc, scale = promote_args_inexact("gamma.sf", x, a, loc, scale)
  std_x = lax.div(lax.sub(x, loc), scale)
  # Below the support the survival probability is exactly 1; elsewhere it is
  # the regularized upper incomplete gamma function.
  below_support = lax.lt(std_x, _lax_const(std_x, 0))
  return jnp.where(below_support, 1, gammaincc(a, std_x))
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Gamma log survival function.

  JAX implementation of :obj:`scipy.stats.gamma` ``logsf``.

  The survival function is defined as

  .. math::

     f_{sf}(x, k) = 1 - f_{cdf}(x, k)

  where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
  :func:`jax.scipy.stats.gamma.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.gamma.cdf`
    - :func:`jax.scipy.stats.gamma.pdf`
    - :func:`jax.scipy.stats.gamma.sf`
    - :func:`jax.scipy.stats.gamma.logcdf`
    - :func:`jax.scipy.stats.gamma.logpdf`
  """
  # Defined as the log of the regular survival function.
  sf_values = sf(x, a, loc, scale)
  return lax.log(sf_values)
|
||||
@@ -0,0 +1,103 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, beta: ArrayLike) -> Array:
  r"""Generalized normal log probability distribution function.

  JAX implementation of :obj:`scipy.stats.gennorm` ``logpdf``.

  The generalized normal probability distribution function is defined as

  .. math::

     f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta)

  where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and
  :math:`\beta > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    beta: arraylike, distribution shape parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.gennorm.cdf`
    - :func:`jax.scipy.stats.gennorm.pdf`
  """
  x, beta = promote_args_inexact("gennorm.logpdf", x, beta)
  # log of the normalization constant beta / (2 * Gamma(1/beta)).
  log_norm = lax.log(.5 * beta) - lax.lgamma(1/beta)
  return log_norm - lax.abs(x)**beta
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, beta: ArrayLike) -> Array:
  r"""Generalized normal cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.gennorm` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.gennorm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    beta: arraylike, distribution shape parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.gennorm.pdf`
    - :func:`jax.scipy.stats.gennorm.logpdf`
  """
  x, beta = promote_args_inexact("gennorm.cdf", x, beta)
  # Mass between -|x| and |x|, via the regularized lower incomplete gamma;
  # the sign of x selects which half of that mass is added to 1/2.
  central_mass = lax.igamma(1/beta, lax.abs(x)**beta)
  return .5 * (1 + lax.sign(x) * central_mass)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, beta: ArrayLike) -> Array:
  r"""Generalized normal probability distribution function.

  JAX implementation of :obj:`scipy.stats.gennorm` ``pdf``.

  The generalized normal probability distribution function is defined as

  .. math::

     f(x, \beta) = \frac{\beta}{2\Gamma(1/\beta)}\exp(-|x|^\beta)

  where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and
  :math:`\beta > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    beta: arraylike, distribution shape parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.gennorm.cdf`
    - :func:`jax.scipy.stats.gennorm.logpdf`
  """
  # Computed by exponentiating the log-density.
  log_p = logpdf(x, beta)
  return lax.exp(log_p)
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import xlog1py
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Geometric log probability mass function.

  JAX implementation of :obj:`scipy.stats.geom` ``logpmf``.

  The Geometric probability mass function is given by

  .. math::

     f(k) = (1 - p)^{k-1}p

  for :math:`k\ge 1` and :math:`0 \le p \le 1`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpmf values.

  See Also:
    :func:`jax.scipy.stats.geom.pmf`
  """
  k, p, loc = promote_args_inexact("geom.logpmf", k, p, loc)
  zero = _lax_const(k, 0)
  one = _lax_const(k, 1)
  shifted = lax.sub(k, loc)
  # log pmf = (k - 1) * log(1 - p) + log(p); xlog1py keeps this stable for
  # p close to 0.
  log_pmf = xlog1py(lax.sub(shifted, one), -p) + lax.log(p)
  return jnp.where(lax.le(shifted, zero), -np.inf, log_pmf)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Geometric probability mass function.

  JAX implementation of :obj:`scipy.stats.geom` ``pmf``.

  The Geometric probability mass function is given by

  .. math::

     f(k) = (1 - p)^{k-1}p

  for :math:`k\ge 1` and :math:`0 \le p \le 1`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values.

  See Also:
    :func:`jax.scipy.stats.geom.logpmf`
  """
  # Computed by exponentiating the log-mass.
  log_pmf = logpmf(k, p, loc)
  return jnp.exp(log_pmf)
|
||||
@@ -0,0 +1,256 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.scipy.special import xlogy, xlog1py
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike,
           loc: ArrayLike = 0,
           scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) log probability distribution function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``logpdf``.

  .. math::

     f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate log(pdf)
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logpdf values (NaN where ``scale <= 0``)

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.ppf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  x, loc, scale = promote_args_inexact("gumbel_l.logpdf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  # logpdf = -log(scale) + z - exp(z), with z the standardized value.
  std = lax.div(lax.sub(x, loc), scale)
  neg_log_scale = xlogy(-1, scale)
  core = lax.sub(std, lax.exp(std))
  result = lax.add(neg_log_scale, core)
  return jnp.where(scale_ok, result, np.nan)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) probability distribution function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``pdf``.

  .. math::

     f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( \frac{x - \mu}{\beta} - \exp\left( \frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate pdf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of pdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.ppf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  # Computed by exponentiating the log-density.
  log_p = logpdf(x, loc, scale)
  return lax.exp(log_p)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike,
           loc: ArrayLike = 0,
           scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) log cumulative density function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``logcdf``.

  .. math::

     f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate log(cdf)
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logcdf values (NaN where ``scale <= 0``)

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.ppf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  x, loc, scale = promote_args_inexact("gumbel_l.logcdf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  std = lax.div(lax.sub(x, loc), scale)
  neg_exp_std = lax.neg(lax.exp(std))
  # log cdf = log(1 - exp(-exp(z))); log(-expm1(.)) is the numerically
  # stable way to evaluate this (xlog1py/log1p were not stable enough here,
  # particularly in float64 mode).
  result = lax.log(-lax.expm1(neg_exp_std))
  return jnp.where(scale_ok, result, np.nan)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) cumulative density function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``cdf``.

  .. math::

     f_{cdf}(x; \mu, \beta) = 1 - \exp\left( -\exp\left( \frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate cdf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of cdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.ppf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  # Computed by exponentiating the (numerically stable) log-cdf.
  log_c = logcdf(x, loc, scale)
  return lax.exp(log_c)
|
||||
|
||||
|
||||
def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) percent point function (inverse of CDF)

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``ppf``.

  .. math::

     F_{ppf}(p; \mu, \beta) = \mu + \beta \log\left( -\log(1 - p) \right)

  Args:
    p: ArrayLike, probability value (quantile) at which to evaluate ppf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of ppf values (NaN outside the open interval ``0 < p < 1``)

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  p, loc, scale = promote_args_inexact("gumbel_l.ppf", p, loc, scale)
  # Quantiles are only defined on the open unit interval.
  ok = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)),
                       lax.lt(p, _lax_const(p, 1)))
  # quantile = loc + scale * log(-log(1 - p)).
  # xlog1py(-1, -p) = -log1p(-p) = -log(1 - p), stable for small p; the
  # outer log has no equivalent stable special-function form, so plain
  # lax.log is used.
  t1 = xlog1py(-1, lax.neg(p))
  t = lax.mul(scale, lax.log(t1))
  quantile = lax.add(loc, t)
  return jnp.where(ok, quantile, np.nan)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) survival function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``sf``.

  .. math::

     f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta)

  Args:
    x: ArrayLike, value at which to evaluate survival function
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of sf values (1 - cdf)

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.logsf`
  """
  # Computed by exponentiating the log survival function.
  log_s = logsf(x, loc, scale)
  return jnp.exp(log_s)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Left Skewed) log survival function.

  JAX implementation of :obj:`scipy.stats.gumbel_l` ``logsf``.

  .. math::

     f_{sf}(x; \mu, \beta) = 1 - f_{cdf}(x, \mu, \beta)

  Args:
    x: ArrayLike, value at which to evaluate log survival function
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logsf values (NaN where ``scale <= 0``)

  See Also:
    - :func:`jax.scipy.stats.gumbel_l.logpdf`
    - :func:`jax.scipy.stats.gumbel_l.pdf`
    - :func:`jax.scipy.stats.gumbel_l.logcdf`
    - :func:`jax.scipy.stats.gumbel_l.cdf`
    - :func:`jax.scipy.stats.gumbel_l.sf`
  """
  x, loc, scale = promote_args_inexact("gumbel_l.logsf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  # For the left-skewed Gumbel, sf(x) = exp(-exp(z)), so logsf = -exp(z).
  std = lax.div(lax.sub(x, loc), scale)
  result = lax.neg(lax.exp(std))
  return jnp.where(scale_ok, result, np.nan)
|
||||
@@ -0,0 +1,257 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.scipy.special import xlogy
|
||||
from jax._src.nn.functions import log1mexp
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike,
           loc: ArrayLike = 0,
           scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) log probability distribution function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``logpdf``.

  .. math::

     f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate log(pdf)
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logpdf values (NaN where ``scale <= 0``)

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.ppf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  x, loc, scale = promote_args_inexact("gumbel_r.logpdf", x, loc, scale)
  # The distribution is only defined for positive scale; NaN elsewhere.
  ok = lax.gt(scale, _lax_const(scale, 0))
  z = lax.div(lax.sub(x, loc), scale)
  # logpdf = -log(beta) - (z + exp(-z))
  neg_log_scale = xlogy(-1, scale)
  t2 = lax.neg(lax.add(z, lax.exp(lax.neg(z))))
  log_pdf = lax.add(neg_log_scale, t2)
  return jnp.where(ok, log_pdf, np.nan)
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) probability distribution function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``pdf``.

  .. math::

     f_{pdf}(x; \mu, \beta) = \frac{1}{\beta} \exp\left( -\frac{x - \mu}{\beta} - \exp\left( -\frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate pdf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of pdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.ppf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  # Computed by exponentiating the log-density.
  log_p = logpdf(x, loc, scale)
  return lax.exp(log_p)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike,
           loc: ArrayLike = 0,
           scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) log cumulative density function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``logcdf``.

  .. math::

     f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate log(cdf)
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of logcdf values (NaN where ``scale <= 0``)

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.ppf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  x, loc, scale = promote_args_inexact("gumbel_r.logcdf", x, loc, scale)
  scale_ok = lax.gt(scale, _lax_const(scale, 0))
  std = lax.div(lax.sub(x, loc), scale)
  # cdf = exp(-exp(-z)), so the log collapses to -exp(-z) exactly.
  result = lax.neg(lax.exp(lax.neg(std)))
  return jnp.where(scale_ok, result, np.nan)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) cumulative density function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``cdf``.

  .. math::

     f_{cdf}(x; \mu, \beta) = \exp\left( -\exp\left( -\frac{x - \mu}{\beta} \right) \right)

  Args:
    x: ArrayLike, value at which to evaluate cdf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of cdf values

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.ppf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  # Computed by exponentiating the log-cdf.
  log_c = logcdf(x, loc, scale)
  return lax.exp(log_c)
|
||||
|
||||
|
||||
def ppf(p: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) percent point function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``ppf``.

  .. math::

     F(p; \mu, \beta) = \mu - \beta \log\left( -\log(p) \right)

  Args:
    p: ArrayLike, probability value (quantile) at which to evaluate ppf
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of ppf values (NaN outside the open interval ``0 < p < 1``)

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.sf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  p, loc, scale = promote_args_inexact("gumbel_r.ppf", p, loc, scale)
  # Quantiles are only defined on the open unit interval 0 < p < 1.
  in_open_unit = lax.bitwise_and(lax.gt(p, _lax_const(p, 0)),
                                 lax.lt(p, _lax_const(p, 1)))
  # quantile = loc - scale * log(-log(p)); xlogy(-1, p) = -log(p).
  neg_log_p = xlogy(-1, p)
  offset = lax.mul(scale, lax.log(neg_log_p))
  quantile = lax.sub(loc, offset)
  return jnp.where(in_open_unit, quantile, np.nan)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""
  Gumbel Distribution (Right Skewed) survival function.

  JAX implementation of :obj:`scipy.stats.gumbel_r` ``sf``.

  .. math::

     f_{sf}(x; \mu, \beta) = 1 - F_{cdf}(x; \mu, \beta)

  Args:
    x: ArrayLike, value at which to evaluate survival function
    loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
    scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)

  Returns:
    array of sf values (1 - cdf; NaN where ``scale <= 0``)

  See Also:
    - :func:`jax.scipy.stats.gumbel_r.logpdf`
    - :func:`jax.scipy.stats.gumbel_r.pdf`
    - :func:`jax.scipy.stats.gumbel_r.logcdf`
    - :func:`jax.scipy.stats.gumbel_r.cdf`
    - :func:`jax.scipy.stats.gumbel_r.logsf`
  """
  x, loc, scale = promote_args_inexact("gumbel_r.sf", x, loc, scale)
  ok = lax.gt(scale, _lax_const(scale, 0))
  # sf = 1 - exp(-exp(-z)) = -expm1(-exp(-z)).
  # Using expm1 avoids the catastrophic cancellation of computing
  # 1 - exp(t) directly when exp(-z) is tiny (large x), where the naive
  # form rounds the survival probability to exactly 0.
  neg_z = lax.div(lax.sub(loc, x), scale)
  _sf = lax.neg(lax.expm1(lax.neg(lax.exp(neg_z))))
  return jnp.where(ok, _sf, np.nan)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
|
||||
r"""
|
||||
Gumbel Distribution (Right Skewed) log survival function.
|
||||
|
||||
JAX implementation of :obj:`scipy.stats.gumbel_r` ``logsf``.
|
||||
|
||||
Args:
|
||||
x: ArrayLike, value at which to evaluate log survival function
|
||||
loc: ArrayLike, distribution offset (:math:`\mu`) (defaulted to 0)
|
||||
scale: ArrayLike, distribution scaling (:math:`\beta`) (defaulted to 1)
|
||||
|
||||
Returns:
|
||||
array of logsf values
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.stats.gumbel_r.logpdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.pdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.logcdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.cdf`
|
||||
- :func:`jax.scipy.stats.gumbel_r.sf`
|
||||
"""
|
||||
x, loc, scale = promote_args_inexact("gumbel_r.logsf", x, loc, scale)
|
||||
ok = lax.gt(scale, _lax_const(scale, 0))
|
||||
# logsf = log(1 - exp(-exp(-z)))
|
||||
neg_z = lax.div(lax.sub(loc, x), scale)
|
||||
log_sf = log1mexp(lax.exp(neg_z))
|
||||
return jnp.where(ok, log_sf, np.nan)
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src import random
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
|
||||
from jax._src.scipy import linalg, special
|
||||
from jax._src.tree_util import register_pytree_node_class
|
||||
from jax._src.typing import Array
|
||||
|
||||
BwMethod = None | str | Array | Callable[[Any], Array]
|
||||
|
||||
@register_pytree_node_class
@dataclass(frozen=True, init=False)
class gaussian_kde:
  """Gaussian Kernel Density Estimator

  JAX implementation of :class:`scipy.stats.gaussian_kde`.

  Parameters:
    dataset: arraylike, real-valued. Data from which to estimate the distribution.
      If 1D, shape is (n_data,). If 2D, shape is (n_dimensions, n_data).
    bw_method: string, scalar, or callable. Either "scott", "silverman", a scalar
      value, or a callable function which takes ``self`` as a parameter.
    weights: arraylike, optional. Weights of the same shape as the dataset.
  """
  # Fields are set via object.__setattr__ (see _setattr) because the
  # dataclass is frozen; all five are pytree leaves (see tree_flatten).
  neff: Any        # effective sample size: 1 / sum(weights**2)
  dataset: Any     # (d, n) data array
  weights: Any     # (n,) weights, normalized to sum to 1
  covariance: Any  # bandwidth-scaled data covariance
  inv_cov: Any     # inverse of `covariance`

  def __init__(self, dataset, bw_method: BwMethod = None, weights=None):
    check_arraylike("gaussian_kde", dataset)
    dataset = jnp.atleast_2d(dataset)
    if dtypes.issubdtype(lax.dtype(dataset), np.complexfloating):
      raise NotImplementedError("gaussian_kde does not support complex data")
    if not dataset.size > 1:
      raise ValueError("`dataset` input should have multiple elements.")

    d, n = dataset.shape
    if weights is not None:
      check_arraylike("gaussian_kde", weights)
      dataset, weights = promote_dtypes_inexact(dataset, weights)
      weights = jnp.atleast_1d(weights)
      # Normalize so weights sum to 1.
      weights /= jnp.sum(weights)
      if weights.ndim != 1:
        raise ValueError("`weights` input should be one-dimensional.")
      if len(weights) != n:
        raise ValueError("`weights` input should be of length n")
    else:
      # Default: uniform weights over the n data points.
      dataset, = promote_dtypes_inexact(dataset)
      weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)

    self._setattr("dataset", dataset)
    self._setattr("weights", weights)
    neff = self._setattr("neff", 1 / jnp.sum(weights**2))

    # Bandwidth selection: Scott's rule, Silverman's rule, a fixed scalar
    # factor, or a user callable evaluated on the partially-initialized self.
    bw_method = "scott" if bw_method is None else bw_method
    if bw_method == "scott":
      factor = jnp.power(neff, -1. / (d + 4))
    elif bw_method == "silverman":
      factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
    elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
      factor = cast(Array, bw_method)
    elif callable(bw_method):
      factor = bw_method(self)
    else:
      raise ValueError(
          "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
      )

    # Kernel covariance = (data covariance) * factor**2.
    data_covariance = jnp.atleast_2d(
        jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
    data_inv_cov = jnp.linalg.inv(data_covariance)
    covariance = data_covariance * factor**2
    inv_cov = data_inv_cov / factor**2
    self._setattr("covariance", covariance)
    self._setattr("inv_cov", inv_cov)

  def _setattr(self, name, value):
    # Frozen dataclasses don't support setting attributes so we have to
    # overload that operation here as they do in the dataclass implementation
    object.__setattr__(self, name, value)
    return value

  def tree_flatten(self):
    # Pytree protocol: all five fields are dynamic leaves; no static aux data.
    return ((self.neff, self.dataset, self.weights, self.covariance,
             self.inv_cov), None)

  @classmethod
  def tree_unflatten(cls, aux_data, children):
    # Rebuild without running __init__ (children may be tracers under jit).
    del aux_data
    kde = cls.__new__(cls)
    kde._setattr("neff", children[0])
    kde._setattr("dataset", children[1])
    kde._setattr("weights", children[2])
    kde._setattr("covariance", children[3])
    kde._setattr("inv_cov", children[4])
    return kde

  @property
  def d(self):
    # Number of dimensions of the dataset.
    return self.dataset.shape[0]

  @property
  def n(self):
    # Number of data points.
    return self.dataset.shape[1]

  def evaluate(self, points):
    """Evaluate the Gaussian KDE on the given points."""
    check_arraylike("evaluate", points)
    points = self._reshape_points(points)
    result = _gaussian_kernel_eval(False, self.dataset.T, self.weights[:, None],
                                   points.T, self.inv_cov)
    return result[:, 0]

  def __call__(self, points):
    # Alias for evaluate(), matching scipy's interface.
    return self.evaluate(points)

  def integrate_gaussian(self, mean, cov):
    """Integrate the distribution weighted by a Gaussian."""
    mean = jnp.atleast_1d(jnp.squeeze(mean))
    cov = jnp.atleast_2d(cov)

    if mean.shape != (self.d,):
      raise ValueError(f"mean does not have dimension {self.d}")
    if cov.shape != (self.d, self.d):
      raise ValueError(f"covariance does not have dimension {self.d}")

    # The product of two Gaussians is Gaussian with summed covariances;
    # norm is the resulting normalization constant (via the Cholesky factor).
    chol = linalg.cho_factor(self.covariance + cov)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm
    return _gaussian_kernel_convolve(chol, norm, self.dataset, self.weights,
                                     mean)

  @api.jit
  def integrate_box_1d(self, low, high):
    """Integrate the distribution over the given limits."""
    if self.d != 1:
      raise ValueError("integrate_box_1d() only handles 1D pdfs")
    if np.ndim(low) != 0 or np.ndim(high) != 0:
      raise ValueError(
          "the limits of integration in integrate_box_1d must be scalars")
    # Standardize the limits per data point, then use the normal CDF (ndtr).
    sigma = jnp.squeeze(jnp.sqrt(self.covariance))
    low = jnp.squeeze((low - self.dataset) / sigma)
    high = jnp.squeeze((high - self.dataset) / sigma)
    return jnp.sum(self.weights * (special.ndtr(high) - special.ndtr(low)))

  def integrate_kde(self, other):
    """Integrate the product of two Gaussian KDE distributions."""
    if other.d != self.d:
      raise ValueError("KDEs are not the same dimensionality")

    chol = linalg.cho_factor(self.covariance + other.covariance)
    norm = jnp.sqrt(2 * np.pi)**self.d * jnp.prod(jnp.diag(chol[0]))
    norm = 1.0 / norm

    # vmap over the smaller KDE's points to keep the mapped axis short.
    sm, lg = (self, other) if self.n < other.n else (other, self)
    result = api.vmap(partial(_gaussian_kernel_convolve, chol, norm, lg.dataset,
                              lg.weights),
                      in_axes=1)(sm.dataset)
    return jnp.sum(result * sm.weights)

  @partial(api.jit, static_argnames=("shape",))
  def resample(self, key, shape=()):
    r"""Randomly sample a dataset from the estimated pdf

    Args:
      key: a PRNG key used as the random key.
      shape: optional, a tuple of nonnegative integers specifying the result
        batch shape; that is, the prefix of the result shape excluding the last
        axis.

    Returns:
      The resampled dataset as an array with shape `(d,) + shape`.
    """
    ind_key, eps_key = random.split(key)
    # Sample mixture components by weight, then add per-sample Gaussian noise.
    ind = random.choice(ind_key, self.n, shape=shape, p=self.weights)
    eps = random.multivariate_normal(eps_key,
                                     jnp.zeros(self.d, self.covariance.dtype),
                                     self.covariance,
                                     shape=shape,
                                     dtype=self.dataset.dtype).T
    return self.dataset[:, ind] + eps

  def pdf(self, x):
    """Probability density function"""
    return self.evaluate(x)

  def logpdf(self, x):
    """Log probability density function"""
    check_arraylike("logpdf", x)
    x = self._reshape_points(x)
    result = _gaussian_kernel_eval(True, self.dataset.T, self.weights[:, None],
                                   x.T, self.inv_cov)
    return result[:, 0]

  def integrate_box(self, low_bounds, high_bounds, maxpts=None):
    """This method is not implemented in the JAX interface."""
    del low_bounds, high_bounds, maxpts
    raise NotImplementedError(
        "only 1D box integrations are supported; use `integrate_box_1d`")

  def set_bandwidth(self, bw_method=None):
    """This method is not implemented in the JAX interface."""
    del bw_method
    raise NotImplementedError(
        "dynamically changing the bandwidth method is not supported")

  def _reshape_points(self, points):
    # Coerce `points` to shape (d, m); a 1D row of length d is interpreted
    # as a single d-dimensional point.
    if dtypes.issubdtype(lax.dtype(points), np.complexfloating):
      raise NotImplementedError(
          "gaussian_kde does not support complex coordinates")
    points = jnp.atleast_2d(points)
    d, m = points.shape
    if d != self.d:
      if d == 1 and m == self.d:
        points = jnp.reshape(points, (self.d, 1))
      else:
        raise ValueError(
            "points have dimension {}, dataset has dimension {}".format(
                d, self.d))
    return points
|
||||
|
||||
|
||||
def _gaussian_kernel_convolve(chol, norm, target, weights, mean):
|
||||
diff = target - mean[:, None]
|
||||
alpha = linalg.cho_solve(chol, diff)
|
||||
arg = 0.5 * jnp.sum(diff * alpha, axis=0)
|
||||
return norm * jnp.sum(jnp.exp(-arg) * weights)
|
||||
|
||||
|
||||
@api.jit(static_argnums=0)
|
||||
def _gaussian_kernel_eval(in_log, points, values, xi, precision):
|
||||
points, values, xi, precision = promote_dtypes_inexact(
|
||||
points, values, xi, precision)
|
||||
d = points.shape[1]
|
||||
|
||||
if xi.shape[1] != d:
|
||||
raise ValueError("points and xi must have same trailing dim")
|
||||
if precision.shape != (d, d):
|
||||
raise ValueError("precision matrix must match data dims")
|
||||
|
||||
whitening = linalg.cholesky(precision, lower=True)
|
||||
points = jnp.dot(points, whitening)
|
||||
xi = jnp.dot(xi, whitening)
|
||||
log_norm = jnp.sum(jnp.log(
|
||||
jnp.diag(whitening))) - 0.5 * d * jnp.log(2 * np.pi)
|
||||
|
||||
def kernel(x_test, x_train, y_train):
|
||||
arg = log_norm - 0.5 * jnp.sum(jnp.square(x_train - x_test))
|
||||
if in_log:
|
||||
return jnp.log(y_train) + arg
|
||||
else:
|
||||
return y_train * jnp.exp(arg)
|
||||
|
||||
reduce = special.logsumexp if in_log else jnp.sum
|
||||
reduced_kernel = lambda x: reduce(api.vmap(kernel, in_axes=(None, 0, 0))
|
||||
(x, points, values),
|
||||
axis=0)
|
||||
mapped_kernel = api.vmap(reduced_kernel)
|
||||
|
||||
return mapped_kernel(xi)
|
||||
@@ -0,0 +1,109 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Laplace log probability distribution function.

  JAX implementation of :obj:`scipy.stats.laplace` ``logpdf``.

  The Laplace probability distribution function is given by

  .. math::

     f(x) = \frac{1}{2} e^{-|x|}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.laplace.cdf`
    - :func:`jax.scipy.stats.laplace.pdf`
  """
  x, loc, scale = promote_args_inexact("laplace.logpdf", x, loc, scale)
  # log f(x) = -|x - loc| / scale - log(2 * scale)
  abs_term = lax.div(lax.abs(lax.sub(x, loc)), scale)
  log_norm = lax.log(lax.mul(_lax_const(x, 2), scale))
  return lax.neg(lax.add(abs_term, log_norm))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Laplace probability distribution function.

  JAX implementation of :obj:`scipy.stats.laplace` ``pdf``.

  The Laplace probability distribution function is given by

  .. math::

     f(x) = \frac{1}{2} e^{-|x|}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.laplace.cdf`
    - :func:`jax.scipy.stats.laplace.logpdf`
  """
  # Exponentiate the log-PDF so both entry points share one implementation.
  return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Laplace cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.laplace` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.laplace.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.laplace.pdf`
    - :func:`jax.scipy.stats.laplace.logpdf`
  """
  x, loc, scale = promote_args_inexact("laplace.cdf", x, loc, scale)
  half = _lax_const(x, 0.5)
  one = _lax_const(x, 1)
  zero = _lax_const(x, 0)
  z = lax.div(lax.sub(x, loc), scale)
  # Closed form: exp(z)/2 on the lower tail, 1 - exp(-z)/2 on the upper tail.
  lower_tail = lax.mul(half, lax.exp(z))
  upper_tail = lax.sub(one, lax.mul(half, lax.exp(lax.neg(z))))
  return lax.select(lax.le(z, zero), lower_tail, upper_tail)
|
||||
@@ -0,0 +1,204 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import expit, logit
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Logistic log probability distribution function.

  JAX implementation of :obj:`scipy.stats.logistic` ``logpdf``.

  The logistic probability distribution function is given by

  .. math::

     f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.cdf`
    - :func:`jax.scipy.stats.logistic.pdf`
    - :func:`jax.scipy.stats.logistic.sf`
    - :func:`jax.scipy.stats.logistic.isf`
    - :func:`jax.scipy.stats.logistic.ppf`
  """
  # Fixed: the docstring previously documented a shape parameter `a` that
  # does not exist in this function's signature.
  x, loc, scale = promote_args_inexact("logistic.logpdf", x, loc, scale)
  x = lax.div(lax.sub(x, loc), scale)
  two = _lax_const(x, 2)
  half_x = lax.div(x, two)
  # log f(x) = -2 * log(exp(x/2) + exp(-x/2)) - log(scale), computed stably
  # via logaddexp.
  return lax.sub(lax.mul(lax.neg(two), jnp.logaddexp(half_x, lax.neg(half_x))), lax.log(scale))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Logistic probability distribution function.

  JAX implementation of :obj:`scipy.stats.logistic` ``pdf``.

  The logistic probability distribution function is given by

  .. math::

     f(x) = \frac{e^{-x}}{(1 + e^{-x})^2}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.cdf`
    - :func:`jax.scipy.stats.logistic.sf`
    - :func:`jax.scipy.stats.logistic.isf`
    - :func:`jax.scipy.stats.logistic.logpdf`
    - :func:`jax.scipy.stats.logistic.ppf`
  """
  # Defer to the numerically-stable log-space implementation.
  return lax.exp(logpdf(x, loc, scale))
|
||||
|
||||
|
||||
def ppf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Logistic distribution percent point function.

  JAX implementation of :obj:`scipy.stats.logistic` ``ppf``.

  The percent point function is defined as the inverse of the
  cumulative distribution function, :func:`jax.scipy.stats.logistic.cdf`.

  Args:
    x: arraylike, value at which to evaluate the PPF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.cdf`
    - :func:`jax.scipy.stats.logistic.pdf`
    - :func:`jax.scipy.stats.logistic.sf`
    - :func:`jax.scipy.stats.logistic.isf`
    - :func:`jax.scipy.stats.logistic.logpdf`
  """
  x, loc, scale = promote_args_inexact("logistic.ppf", x, loc, scale)
  # Invert q = expit((x - loc) / scale):  x = loc + scale * logit(q).
  scaled_quantile = lax.mul(logit(x), scale)
  return lax.add(scaled_quantile, loc)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Logistic distribution survival function.

  JAX implementation of :obj:`scipy.stats.logistic` ``sf``

  The survival function is defined as

  .. math::

     f_{sf}(x, k) = 1 - f_{cdf}(x, k)

  where :math:`f_{cdf}(x, k)` is the cumulative distribution function,
  :func:`jax.scipy.stats.logistic.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.cdf`
    - :func:`jax.scipy.stats.logistic.pdf`
    - :func:`jax.scipy.stats.logistic.isf`
    - :func:`jax.scipy.stats.logistic.logpdf`
    - :func:`jax.scipy.stats.logistic.ppf`
  """
  x, loc, scale = promote_args_inexact("logistic.sf", x, loc, scale)
  # sf(x) = 1 - cdf(x) = expit(-(x - loc)/scale), by the symmetry of expit.
  z = lax.div(lax.sub(x, loc), scale)
  return expit(lax.neg(z))
|
||||
|
||||
|
||||
def isf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Logistic distribution inverse survival function.

  JAX implementation of :obj:`scipy.stats.logistic` ``isf``.

  Returns the inverse of the survival function,
  :func:`jax.scipy.stats.logistic.sf`.

  Args:
    x: arraylike, value at which to evaluate the ISF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of isf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.cdf`
    - :func:`jax.scipy.stats.logistic.pdf`
    - :func:`jax.scipy.stats.logistic.sf`
    - :func:`jax.scipy.stats.logistic.logpdf`
    - :func:`jax.scipy.stats.logistic.ppf`
  """
  x, loc, scale = promote_args_inexact("logistic.isf", x, loc, scale)
  # Invert q = expit(-(x - loc)/scale):  x = loc - scale * logit(q).
  neg_quantile = lax.mul(lax.neg(logit(x)), scale)
  return lax.add(neg_quantile, loc)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Logistic cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.logistic` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.logistic.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.logistic.pdf`
    - :func:`jax.scipy.stats.logistic.sf`
    - :func:`jax.scipy.stats.logistic.isf`
    - :func:`jax.scipy.stats.logistic.logpdf`
    - :func:`jax.scipy.stats.logistic.ppf`
  """
  x, loc, scale = promote_args_inexact("logistic.cdf", x, loc, scale)
  # The logistic CDF is exactly the sigmoid of the standardized value.
  z = lax.div(lax.sub(x, loc), scale)
  return expit(z)
|
||||
@@ -0,0 +1,83 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import promote_args_inexact, promote_args_numeric
|
||||
from jax._src.scipy.special import gammaln, xlogy
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
  r"""Multinomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.multinomial` ``logpdf``.

  The multinomial probability distribution is given by

  .. math::

     f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}

  with :math:`n = \sum_i x_i`.

  Args:
    x: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter

  Returns:
    array of logpmf values.

  See Also:
    :func:`jax.scipy.stats.multinomial.pmf`
  """
  p, = promote_args_inexact("multinomial.logpmf", p)
  x, n = promote_args_numeric("multinomial.logpmf", x, n)
  if not dtypes.issubdtype(x.dtype, np.integer):
    raise ValueError(f"x and n must be of integer type; got x.dtype={x.dtype}, n.dtype={n.dtype}")
  x = x.astype(p.dtype)
  n = n.astype(p.dtype)
  # log n! + sum_i (x_i log p_i - log x_i!), using gammaln(k + 1) = log k!.
  log_coeff = gammaln(n + 1)
  log_terms = jnp.sum(xlogy(x, p) - gammaln(x + 1), axis=-1)
  logprobs = log_coeff + log_terms
  # The PMF is zero (log-PMF is -inf) unless the counts sum to n.
  counts_match = jnp.equal(jnp.sum(x), n)
  return jnp.where(counts_match, logprobs, -np.inf)
|
||||
|
||||
|
||||
def pmf(x: ArrayLike, n: ArrayLike, p: ArrayLike) -> Array:
  r"""Multinomial probability mass function.

  JAX implementation of :obj:`scipy.stats.multinomial` ``pmf``.

  The multinomial probability distribution is given by

  .. math::

     f(x, n, p) = n! \prod_{i=1}^k \frac{p_i^{x_i}}{x_i!}

  with :math:`n = \sum_i x_i`.

  Args:
    x: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter

  Returns:
    array of pmf values

  See Also:
    :func:`jax.scipy.stats.multinomial.logpmf`
  """
  # Compute in log space and exponentiate for numerical stability.
  return lax.exp(logpmf(x, n, p))
|
||||
@@ -0,0 +1,103 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy import einsum as jnp_einsum
|
||||
from jax._src.numpy import vectorize as jnp_vectorize
|
||||
from jax._src.numpy.util import promote_dtypes_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike, allow_singular: None = None) -> Array:
  r"""Multivariate normal log probability distribution function.

  JAX implementation of :obj:`scipy.stats.multivariate_normal` ``logpdf``.

  The multivariate normal PDF is defined as

  .. math::

     f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)

  where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and
  :math:`k` is the rank of :math:`\Sigma`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    mean: arraylike, centroid of distribution
    cov: arraylike, covariance matrix of distribution
    allow_singular: not supported

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.multivariate_normal.pdf`
  """
  if allow_singular is not None:
    raise NotImplementedError("allow_singular argument of multivariate_normal.logpdf")
  x, mean, cov = promote_dtypes_inexact(x, mean, cov)
  if not mean.shape:
    # Scalar mean: treat as a univariate normal with variance `cov`.
    return (-1/2 * jnp.square(x - mean) / cov
            - 1/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
  else:
    n = mean.shape[-1]
    if not np.shape(cov):
      # Scalar covariance: isotropic Gaussian (cov * identity).
      y = x - mean
      return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) / cov
              - n/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
    else:
      if cov.ndim < 2 or cov.shape[-2:] != (n, n):
        raise ValueError("multivariate_normal.logpdf got incompatible shapes")
      # Full covariance: use the Cholesky factor L (cov = L L^T) to get the
      # Mahalanobis term via a triangular solve and the log-determinant from
      # L's diagonal, avoiding an explicit matrix inverse. `vectorize`
      # broadcasts the solve over any leading batch dimensions.
      L = lax.linalg.cholesky(cov)
      y = jnp_vectorize.vectorize(
          partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
          signature="(n,n),(n)->(n)"
      )(L, x - mean)
      return (-1/2 * jnp_einsum.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi)
              - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, mean: ArrayLike, cov: ArrayLike) -> Array:
  r"""Multivariate normal probability distribution function.

  JAX implementation of :obj:`scipy.stats.multivariate_normal` ``pdf``.

  The multivariate normal PDF is defined as

  .. math::

     f(x) = \frac{1}{(2\pi)^k\det\Sigma}\exp\left(-\frac{(x-\mu)^T\Sigma^{-1}(x-\mu)}{2} \right)

  where :math:`\mu` is the ``mean``, :math:`\Sigma` is the covariance matrix (``cov``), and
  :math:`k` is the rank of :math:`\Sigma`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    mean: arraylike, centroid of distribution
    cov: arraylike, covariance matrix of distribution

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.multivariate_normal.logpdf`
  """
  # Computed as exp(logpdf) for numerical stability.
  return lax.exp(logpdf(x, mean, cov))
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.special import gammaln, xlogy
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Negative-binomial log probability mass function.

  JAX implementation of :obj:`scipy.stats.nbinom` ``logpmf``.

  The negative-binomial probability mass function is given by

  .. math::

     f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k

  for :math:`k \ge 0` and :math:`0 \le p \le 1`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.nbinom.pmf`
  """
  k, n, p, loc = promote_args_inexact("nbinom.logpmf", k, n, p, loc)
  one = _lax_const(k, 1)
  y = lax.sub(k, loc)
  # log C(y+n-1, n-1) = lgamma(y+n) - lgamma(n) - lgamma(y+1)
  binom_term = lax.sub(
      lax.sub(gammaln(lax.add(y, n)), gammaln(n)), gammaln(lax.add(y, one))
  )
  # n*log(p) + y*log(1-p), with xlogy handling the 0*log(0) corner.
  power_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p)))
  result = lax.add(binom_term, power_term)
  # Support is k >= loc; outside of it the log-PMF is -inf.
  return jnp.where(lax.lt(k, loc), -np.inf, result)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, n: ArrayLike, p: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Negative-binomial probability mass function.

  JAX implementation of :obj:`scipy.stats.nbinom` ``pmf``.

  The negative-binomial probability mass function is given by

  .. math::

     f(k) = {{k+n-1} \choose {n-1}}p^n(1-p)^k

  for :math:`k \ge 0` and :math:`0 \le p \le 1`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    n: arraylike, distribution shape parameter
    p: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values.

  See Also:
    :func:`jax.scipy.stats.nbinom.logpmf`
  """
  # Computed in log space then exponentiated for numerical stability.
  return lax.exp(logpmf(k, n, p, loc))
|
||||
@@ -0,0 +1,284 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy import special
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Normal log probability distribution function.

  JAX implementation of :obj:`scipy.stats.norm` ``logpdf``.

  The pdf of the normal distribution is

  .. math::

     f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.isf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  x, loc, scale = promote_args_inexact("norm.logpdf", x, loc, scale)
  variance = lax.square(scale)
  two_pi = _lax_const(x, 2 * np.pi)
  # log of the normalization constant: log(2*pi*sigma^2).
  log_norm = lax.log(lax.mul(two_pi, variance))
  # squared standardized deviation: (x - loc)^2 / sigma^2.
  z_sq = lax.div(lax.square(lax.sub(x, loc)), variance)
  return lax.div(lax.add(log_norm, z_sq), _lax_const(x, -2))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Normal probability distribution function.

  JAX implementation of :obj:`scipy.stats.norm` ``pdf``.

  The pdf of the normal distribution is

  .. math::

     f(x) = \frac{1}{\sqrt{2\pi}}e^{-x^2/2}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.logpdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.isf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  # Exponentiate the log-density for numerical stability.
  log_densities = logpdf(x, loc, scale)
  return lax.exp(log_densities)
|
||||
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Normal cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.norm` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.norm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.logpdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.isf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  x, loc, scale = promote_args_inexact("norm.cdf", x, loc, scale)
  # Standardize, then use the standard-normal CDF (ndtr).
  standardized = lax.div(lax.sub(x, loc), scale)
  return special.ndtr(standardized)
|
||||
|
||||
|
||||
def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Normal log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.norm` ``logcdf``.

  The cdf is defined as

  .. math::

     f_{cdf}(x, k) = \int_{-\infty}^x f_{pdf}(y, k)\mathrm{d}y

  where :math:`f_{pdf}` is the probability density function,
  :func:`jax.scipy.stats.norm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logpdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.isf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  x, loc, scale = promote_args_inexact("norm.logcdf", x, loc, scale)
  # Standardize, then use log_ndtr, which is stable in the far left tail.
  standardized = lax.div(lax.sub(x, loc), scale)
  return special.log_ndtr(standardized)
|
||||
|
||||
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Normal distribution percent point function.

  JAX implementation of :obj:`scipy.stats.norm` ``ppf``.

  The percent point function is the inverse of the cumulative
  distribution function, :func:`jax.scipy.stats.norm.cdf`.

  Args:
    q: arraylike, value at which to evaluate the PPF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.logpdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.isf`
  """
  # Invert the standard-normal CDF, then shift/scale into (loc, scale).
  unit_quantile = special.ndtri(q)
  return jnp.asarray(unit_quantile * scale + loc, float)
|
||||
|
||||
|
||||
def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Normal distribution log survival function.

  JAX implementation of :obj:`scipy.stats.norm` ``logsf``.

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.norm.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.logpdf`
    - :func:`jax.scipy.stats.norm.isf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  x, loc, scale = promote_args_inexact("norm.logsf", x, loc, scale)
  # By symmetry of the normal distribution, P(X > x) = P(-X < -x),
  # so the survival function is the CDF of the reflected distribution.
  return logcdf(-x, -loc, scale)
|
||||
|
||||
|
||||
def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Normal distribution survival function.

  JAX implementation of :obj:`scipy.stats.norm` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.norm.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.logpdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.isf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  x, loc, scale = promote_args_inexact("norm.sf", x, loc, scale)
  # By symmetry of the normal distribution, P(X > x) = P(-X < -x),
  # so the survival function is the CDF of the reflected distribution.
  return cdf(-x, -loc, scale)
|
||||
|
||||
|
||||
def isf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Normal distribution inverse survival function.

  JAX implementation of :obj:`scipy.stats.norm` ``isf``.

  Returns the inverse of the survival function,
  :func:`jax.scipy.stats.norm.sf`.

  Args:
    q: arraylike, value at which to evaluate the ISF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of isf values.

  See Also:
    - :func:`jax.scipy.stats.norm.cdf`
    - :func:`jax.scipy.stats.norm.pdf`
    - :func:`jax.scipy.stats.norm.sf`
    - :func:`jax.scipy.stats.norm.logcdf`
    - :func:`jax.scipy.stats.norm.logpdf`
    - :func:`jax.scipy.stats.norm.logsf`
    - :func:`jax.scipy.stats.norm.ppf`
  """
  # sf(x) = 1 - cdf(x), so isf(q) = ppf(1 - q).
  return ppf(lax.sub(_lax_const(q, 1), q), loc, scale)
|
||||
@@ -0,0 +1,313 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto log probability distribution function.

  JAX implementation of :obj:`scipy.stats.pareto` ``logpdf``.

  The Pareto probability density function is

  .. math::

     f(x, b) = \begin{cases}
       bx^{-(b+1)} & x \ge 1\\
       0 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.pareto.logcdf`
    - :func:`jax.scipy.stats.pareto.logsf`
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.pdf`
    - :func:`jax.scipy.stats.pareto.ppf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  x, b, loc, scale = promote_args_inexact("pareto.logpdf", x, b, loc, scale)
  one = _lax_const(x, 1)
  z = lax.div(lax.sub(x, loc), scale)
  # -log_norm contributes log(b) - log(scale) to the density.
  log_norm = lax.log(lax.div(scale, b))
  log_probs = lax.neg(
      lax.add(log_norm, lax.mul(lax.add(b, one), lax.log(z)))
  )
  # The support starts at loc + scale; the density is zero below it.
  support_min = lax.add(loc, scale)
  return jnp.where(lax.lt(x, support_min), -np.inf, log_probs)
|
||||
|
||||
|
||||
def pdf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto probability distribution function.

  JAX implementation of :obj:`scipy.stats.pareto` ``pdf``.

  The Pareto probability density function is

  .. math::

     f(x, b) = \begin{cases}
       bx^{-(b+1)} & x \ge 1\\
       0 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.pareto.logcdf`
    - :func:`jax.scipy.stats.pareto.logpdf`
    - :func:`jax.scipy.stats.pareto.logsf`
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.ppf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  # Exponentiate the log-density for numerical stability.
  log_densities = logpdf(x, b, loc, scale)
  return lax.exp(log_densities)
|
||||
|
||||
|
||||
def cdf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.pareto` ``cdf``.

  The Pareto cumulative distribution function is

  .. math::

     F(x, b) = \begin{cases}
       1 - x^{-b} & x \ge 1\\
       0 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of CDF values.

  See Also:
    - :func:`jax.scipy.stats.pareto.logcdf`
    - :func:`jax.scipy.stats.pareto.logpdf`
    - :func:`jax.scipy.stats.pareto.logsf`
    - :func:`jax.scipy.stats.pareto.pdf`
    - :func:`jax.scipy.stats.pareto.ppf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  x, b, loc, scale = promote_args_inexact("pareto.cdf", x, b, loc, scale)
  one = _lax_const(x, 1)
  zero = _lax_const(x, 0)
  z = lax.div(lax.sub(x, loc), scale)
  in_support_cdf = lax.sub(one, lax.pow(z, lax.neg(b)))
  # Below the support (x < loc + scale) the CDF is identically zero.
  support_min = lax.add(loc, scale)
  return jnp.where(lax.lt(x, support_min), zero, in_support_cdf)
|
||||
|
||||
|
||||
def logcdf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.pareto` ``logcdf``.

  The Pareto cumulative distribution function is

  .. math::

     F(x, b) = \begin{cases}
       1 - x^{-b} & x \ge 1\\
       0 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logCDF values.

  See Also:
    - :func:`jax.scipy.stats.pareto.logpdf`
    - :func:`jax.scipy.stats.pareto.logsf`
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.pdf`
    - :func:`jax.scipy.stats.pareto.ppf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  x, b, loc, scale = promote_args_inexact("pareto.logcdf", x, b, loc, scale)
  z = lax.div(lax.sub(x, loc), scale)
  # log(1 - z**-b) computed via log1p for accuracy near the support edge.
  in_support_logcdf = lax.log1p(lax.neg(lax.pow(z, lax.neg(b))))
  # Below the support the CDF is 0, so its log is -inf.
  support_min = lax.add(loc, scale)
  return jnp.where(lax.lt(x, support_min), -np.inf, in_support_logcdf)
|
||||
|
||||
|
||||
def logsf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto log survival function.

  JAX implementation of :obj:`scipy.stats.pareto` ``logsf``.

  The Pareto survival function is

  .. math::

     S(x, b) = \begin{cases}
       x^{-b} & x \ge 1\\
       1 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the survival function
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of log survival function values.

  See Also:
    - :func:`jax.scipy.stats.pareto.logcdf`
    - :func:`jax.scipy.stats.pareto.logpdf`
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.pdf`
    - :func:`jax.scipy.stats.pareto.ppf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  x, b, loc, scale = promote_args_inexact("pareto.logsf", x, b, loc, scale)
  zero = _lax_const(x, 0)
  z = lax.div(lax.sub(x, loc), scale)
  # log(z**-b) = -b * log(z) within the support.
  tail_log_prob = lax.neg(lax.mul(b, lax.log(z)))
  # Below the support the survival probability is 1, so its log is 0.
  support_min = lax.add(loc, scale)
  return jnp.where(lax.lt(x, support_min), zero, tail_log_prob)
|
||||
|
||||
|
||||
def sf(
    x: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto survival function.

  JAX implementation of :obj:`scipy.stats.pareto` ``sf``.

  The Pareto survival function is

  .. math::

     S(x, b) = \begin{cases}
       x^{-b} & x \ge 1\\
       1 & x < 1
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    x: arraylike, value at which to evaluate the survival function
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of survival function values.

  See Also:
    - :func:`jax.scipy.stats.pareto.logcdf`
    - :func:`jax.scipy.stats.pareto.logpdf`
    - :func:`jax.scipy.stats.pareto.logsf`
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.pdf`
    - :func:`jax.scipy.stats.pareto.ppf`
  """
  # Exponentiate the log survival function for numerical stability.
  log_survival = logsf(x, b, loc, scale)
  return lax.exp(log_survival)
|
||||
|
||||
|
||||
def ppf(
    q: ArrayLike, b: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1
) -> Array:
  r"""Pareto percent point function (inverse CDF).

  JAX implementation of :obj:`scipy.stats.pareto` ``ppf``.

  The Pareto percent point function, the inverse of the Pareto CDF, is

  .. math::

     F^{-1}(q, b) = \begin{cases}
       (1 - q)^{-1/b} & 0 \le q < 1\\
       \text{NaN} & \text{otherwise}
     \end{cases}

  and is defined for :math:`b > 0`.

  Args:
    q: arraylike, value at which to evaluate the inverse CDF
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of percent point function values.

  See Also:
    - :func:`jax.scipy.stats.pareto.logcdf`
    - :func:`jax.scipy.stats.pareto.logpdf`
    - :func:`jax.scipy.stats.pareto.logsf`
    - :func:`jax.scipy.stats.pareto.cdf`
    - :func:`jax.scipy.stats.pareto.pdf`
    - :func:`jax.scipy.stats.pareto.sf`
  """
  q, b, loc, scale = promote_args_inexact("pareto.ppf", q, b, loc, scale)
  one = _lax_const(q, 1)
  # Invert F: x = loc + scale * (1 - q)**(-1/b).
  exponent = lax.neg(lax.div(one, b))
  quantile = lax.add(
      loc, lax.mul(scale, lax.pow(lax.sub(one, q), exponent))
  )
  # Probabilities outside [0, 1] (or NaN) have no quantile.
  invalid_q = jnp.isnan(q) | (q < 0) | (q > 1)
  return jnp.where(invalid_q, np.nan, quantile)
|
||||
@@ -0,0 +1,241 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact, promote_dtypes_inexact, ensure_arraylike
|
||||
from jax._src.scipy.special import xlogy, entr, gammaln, gammaincc
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Poisson log probability mass function.

  JAX implementation of :obj:`scipy.stats.poisson` ``logpmf``.

  The Poisson probability mass function is

  .. math::

     f(k) = e^{-\mu}\frac{\mu^k}{k!}

  and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    mu: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of logpmf values.

  See Also:
    - :func:`jax.scipy.stats.poisson.cdf`
    - :func:`jax.scipy.stats.poisson.pmf`
  """
  k, mu, loc = promote_args_inexact("poisson.logpmf", k, mu, loc)
  zero = _lax_const(k, 0)
  shifted = lax.sub(k, loc)
  # log pmf = k*log(mu) - log(k!) - mu, using xlogy so that k == 0 is exact.
  log_probs = xlogy(shifted, mu) - gammaln(shifted + 1) - mu
  # Out of support: shifted values below zero, or non-integer k.
  non_integer = lax.ne(jnp.round(k), k)
  out_of_support = jnp.logical_or(lax.lt(shifted, zero), non_integer)
  return jnp.where(out_of_support, -np.inf, log_probs)
|
||||
|
||||
|
||||
def pmf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Poisson probability mass function.

  JAX implementation of :obj:`scipy.stats.poisson` ``pmf``.

  The Poisson probability mass function is

  .. math::

     f(k) = e^{-\mu}\frac{\mu^k}{k!}

  and is defined for :math:`k \ge 0` and :math:`\mu \ge 0`.

  Args:
    k: arraylike, value at which to evaluate the PMF
    mu: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of pmf values.

  See Also:
    - :func:`jax.scipy.stats.poisson.cdf`
    - :func:`jax.scipy.stats.poisson.logpmf`
  """
  # Exponentiate the log pmf for numerical stability.
  log_probs = logpmf(k, mu, loc)
  return jnp.exp(log_probs)
|
||||
|
||||
|
||||
def cdf(k: ArrayLike, mu: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Poisson cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.poisson` ``cdf``.

  The cumulative distribution function is defined as:

  .. math::

     f_{cdf}(k, p) = \sum_{i=0}^k f_{pmf}(i, p)

  where :math:`f_{pmf}(k, p)` is the probability mass function
  :func:`jax.scipy.stats.poisson.pmf`.

  Args:
    k: arraylike, value at which to evaluate the CDF
    mu: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.poisson.pmf`
    - :func:`jax.scipy.stats.poisson.logpmf`
  """
  # Fix: the promotion label previously said "poisson.logpmf", producing
  # misleading dtype-promotion error messages from cdf.
  k, mu, loc = promote_args_inexact("poisson.cdf", k, mu, loc)
  zero = _lax_const(k, 0)
  x = lax.sub(k, loc)
  # The Poisson CDF equals the regularized upper incomplete gamma function:
  # P(X <= x) = Q(floor(x) + 1, mu).
  p = gammaincc(jnp.floor(1 + x), mu)
  # Below the support the CDF is identically zero.
  return jnp.where(lax.lt(x, zero), zero, p)
|
||||
|
||||
@api.jit
def entropy(mu: ArrayLike, loc: ArrayLike = 0) -> Array:
  r"""Shannon entropy of the Poisson distribution.

  JAX implementation of :obj:`scipy.stats.poisson` ``entropy``.

  The entropy :math:`H(X)` of a Poisson random variable
  :math:`X \sim \text{Poisson}(\mu)` is defined as:

  .. math::

     H(X) = -\sum_{k=0}^\infty p(k) \log p(k)

  where :math:`p(k) = e^{-\mu} \mu^k / k!` for
  :math:`k \geq \max(0, \lfloor \text{loc} \rfloor)`.

  This implementation uses **regime switching** for numerical stability
  and performance:

  - **Small** :math:`\mu < 10`: Direct summation over PMF with adaptive
    upper bound :math:`k \leq \mu + 20`
  - **Medium** :math:`10 \leq \mu < 100`: Summation with bound
    :math:`k \leq \mu + 10\sqrt{\mu} + 20`
  - **Large** :math:`\mu \geq 100`: Asymptotic Stirling approximation:
    :math:`H(\mu) \approx \frac{1}{2} \log(2\pi e \mu) - \frac{1}{12\mu}`

  Matches SciPy to relative error :math:`< 10^{-5}` across all regimes.

  Args:
    mu: arraylike, mean parameter of the Poisson distribution.
      Must be ``>= 0``.
    loc: arraylike, optional location parameter (default: 0).
      Accepted for API compatibility with scipy but does not
      affect the entropy

  Returns:
    Array of entropy values with shape broadcast from ``mu`` and ``loc``.
    Returns ``0`` for ``mu == 0`` (a degenerate distribution); results
    for ``mu < 0`` are undefined (NaN in practice).

  Examples:
    >>> from jax.scipy.stats import poisson
    >>> poisson.entropy(5.0)
    Array(2.204394, dtype=float32)
    >>> poisson.entropy(jax.numpy.array([1, 10, 100]))
    Array([1.3048419, 2.561407 , 3.7206903], dtype=float32)

  See Also:
    - :func:`jax.scipy.stats.poisson.pmf`
    - :func:`jax.scipy.stats.poisson.logpmf`
    - :obj:`scipy.stats.poisson`
  """
  mu, loc = ensure_arraylike("poisson.entropy", mu, loc)
  promoted_mu, promoted_loc = promote_dtypes_inexact(mu, loc)

  # NOTE: ``loc`` does not affect the entropy (entropy is translation
  # invariant); it participates only in shape broadcasting to stay
  # compatible with the scipy API.
  result_shape = jnp.broadcast_shapes(
      promoted_mu.shape,
      promoted_loc.shape
  )

  # Flatten so the regime helpers can vectorize over a 1-D batch.
  mu_flat = jnp.ravel(promoted_mu)
  zero_result = jnp.zeros_like(mu_flat)

  # Choose the computation regime based on mu value.  All branches are
  # evaluated under jit; jnp.where selects per-element results.
  result = jnp.where(
      mu_flat == 0,
      zero_result,
      jnp.where(
          mu_flat < 10,
          _entropy_small_mu(mu_flat),
          jnp.where(
              mu_flat < 100,
              _entropy_medium_mu(mu_flat),
              _entropy_large_mu(mu_flat)
          )
      )
  )

  result_mu_shape = jnp.reshape(result, promoted_mu.shape)

  # Restore original (broadcast) shape.
  return jnp.broadcast_to(result_mu_shape, result_shape)
|
||||
|
||||
def _entropy_small_mu(mu: Array) -> Array:
  """Entropy via direct PMF summation for small mu (< 10).

  Sums entr(pmf(k)) over k = 0..34, masking out terms beyond the adaptive
  cutoff k < mu + 20, which captures essentially all probability mass.
  """
  max_k = 35

  support = jnp.arange(max_k, dtype=mu.dtype)[:, None]
  probabilities = pmf(support, mu, 0)

  # Zero out terms past the per-element cutoff k < ceil(mu + 20).
  cutoffs = jnp.ceil(mu + 20).astype(support.dtype)
  within = support < cutoffs[None, :]
  kept = jnp.where(within, probabilities, 0.0)

  return jnp.sum(entr(kept), axis=0)
|
||||
|
||||
def _entropy_medium_mu(mu: Array) -> Array:
  """Entropy for medium mu (10-100): adaptive bounds based on std dev.

  Sums entr(pmf(k)) for k < mu + 10*sqrt(mu) + 20.  The static cap of
  k = 250 keeps the shape fixed for JIT; for mu < 100 the adaptive
  bound stays below 220.
  """
  max_k = 250  # Static bound for JIT. For mu < 100, upper bound < 220.

  support = jnp.arange(max_k, dtype=mu.dtype)[:, None]
  probabilities = pmf(support, mu, 0)

  # Zero out terms past the per-element cutoff k < ceil(mu + 10*sqrt(mu) + 20).
  cutoffs = jnp.ceil(mu + 10 * jnp.sqrt(mu) + 20).astype(support.dtype)
  within = support < cutoffs[None, :]
  kept = jnp.where(within, probabilities, 0.0)

  return jnp.sum(entr(kept), axis=0)
|
||||
|
||||
def _entropy_large_mu(mu: Array) -> Array:
|
||||
"""Entropy for large mu (>= 100): Asymptotic approximation.
|
||||
|
||||
Formula: H(λ) ≈ 0.5*log(2πeλ) - 1/(12λ) + O(λ^-2)
|
||||
"""
|
||||
return 0.5 * jnp.log(2 * np.pi * np.e * mu) - 1.0 / (12 * mu)
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Student's T log probability distribution function.

  JAX implementation of :obj:`scipy.stats.t` ``logpdf``.

  The Student's T probability distribution function is given by

  .. math::

     f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{-(\nu+1)/2}

  Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0`
  is the degrees of freedom (JAX follows the scipy convention of naming this ``df``).

  Args:
    x: arraylike, value at which to evaluate the PDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.t.pdf`
  """
  x, df, loc, scale = promote_args_inexact("t.logpdf", x, df, loc, scale)
  two = _lax_const(x, 2)
  z = lax.div(lax.sub(x, loc), scale)
  half_df = lax.div(df, two)
  half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))
  # Normalizer: log(sqrt(pi * df) * scale) + lgamma(df/2) - lgamma((df+1)/2).
  pi_scale_sq = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
  half_log_term = lax.div(lax.log(lax.mul(pi_scale_sq, df)), two)
  log_norm = lax.sub(lax.add(lax.lgamma(half_df), half_log_term),
                     lax.lgamma(half_df_plus_half))
  # Kernel: -((df+1)/2) * log1p(z**2 / df).
  z_sq_over_df = lax.div(lax.mul(z, z), df)
  return lax.neg(lax.add(log_norm, lax.mul(half_df_plus_half, lax.log1p(z_sq_over_df))))
|
||||
|
||||
|
||||
def pdf(x: ArrayLike, df: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Student's T probability distribution function.

  JAX implementation of :obj:`scipy.stats.t` ``pdf``.

  The Student's T probability distribution function is given by

  .. math::

     f(x, \nu) = \frac{\Gamma((\nu + 1)/2)}{\sqrt{\pi\nu}\Gamma(\nu/2)}(1 + x^2/\nu)^{-(\nu+1)/2}

  Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function, and :math:`\nu > 0`
  is the degrees of freedom (JAX follows the scipy convention of naming this ``df``).

  Args:
    x: arraylike, value at which to evaluate the PDF
    df: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.t.logpdf`
  """
  # Exponentiate the log-density for numerical stability.
  log_densities = logpdf(x, df, loc, scale)
  return lax.exp(log_densities)
|
||||
@@ -0,0 +1,296 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.scipy.stats import norm
|
||||
from jax._src.scipy.special import logsumexp, log_ndtr, ndtr
|
||||
|
||||
|
||||
def _log_diff(x, y):
|
||||
return logsumexp(
|
||||
jnp.array([x, y]),
|
||||
b=jnp.array([jnp.ones_like(x), -jnp.ones_like(y)]),
|
||||
axis=0
|
||||
)
|
||||
|
||||
|
||||
def _log_gauss_mass(a, b):
|
||||
"""Log of Gaussian probability mass within an interval"""
|
||||
a, b = jnp.array(a), jnp.array(b)
|
||||
a, b = jnp.broadcast_arrays(a, b)
|
||||
|
||||
# Note: Docstring carried over from scipy
|
||||
# Calculations in right tail are inaccurate, so we'll exploit the
|
||||
# symmetry and work only in the left tail
|
||||
case_left = b <= 0
|
||||
case_right = a > 0
|
||||
case_central = ~(case_left | case_right)
|
||||
|
||||
# By conditionally swapping arguments if we're in the right tail,
|
||||
# we only need to compile the mass_case_left graph once instead of twice.
|
||||
a_tail = jnp.where(case_right, -b, a)
|
||||
b_tail = jnp.where(case_right, -a, b)
|
||||
|
||||
mass_tail = _log_diff(log_ndtr(b_tail), log_ndtr(a_tail))
|
||||
|
||||
# Catastrophic cancellation occurs as np.exp(log_mass) approaches 1.
|
||||
# Correct for this with an alternative formulation.
|
||||
# We're not concerned with underflow here: if only one term
|
||||
# underflows, it was insignificant; if both terms underflow,
|
||||
# the result can't accurately be represented in logspace anyway
|
||||
# because sc.log1p(x) ~ x for small x.
|
||||
mass_central = jnp.log1p(-ndtr(a) - ndtr(-b))
|
||||
|
||||
out = jnp.where(case_central, mass_central, mass_tail)
|
||||
return out
|
||||
|
||||
|
||||
@api.jit
def logpdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal log probability distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``logpdf``.

  The truncated normal probability distribution is given by

  .. math::

     f(x, a, b) = \begin{cases}
       \frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\
       0 & \mathrm{otherwise}
     \end{cases}

  where :math:`a` and :math:`b` are expressed in units of standard
  deviations from zero; ``loc`` is the centroid and ``scale`` the
  standard deviation, following the scipy naming convention.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logpdf", x, a, b, loc, scale)
  # Normal log-density renormalized by the Gaussian mass inside [a, b].
  out = lax.sub(norm.logpdf(x, loc, scale), _log_gauss_mass(a, b))

  # Zero density outside the truncation interval (in standardized units).
  standardized = lax.div(lax.sub(x, loc), scale)
  out = jnp.where((standardized < a) | (standardized > b), -np.inf, out)
  # A degenerate interval (a >= b) is undefined.
  return jnp.where(a >= b, np.nan, out)
|
||||
|
||||
def pdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal probability distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``pdf``.

  The truncated normal probability distribution is given by

  .. math::

     f(x, a, b) = \begin{cases}
       \frac{1}{\sqrt{2\pi}}e^{-x^2/2} & a \le x \le b \\
       0 & \mathrm{otherwise}
     \end{cases}

  where :math:`a` and :math:`b` are expressed in units of standard
  deviations from the centroid; ``loc`` is the centroid and ``scale``
  the standard deviation, following the scipy naming convention.

  Args:
    x: arraylike, value at which to evaluate the PDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  # Exponentiate the numerically-stable log-space implementation.
  log_p = logpdf(x, a, b, loc, scale)
  return lax.exp(log_p)
|
||||
@api.jit
def logsf(x, a, b, loc=0, scale=1):
  """Truncated normal distribution log survival function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``logsf``

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.truncnorm.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logsf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logsf", x, a, b, loc, scale)
  # By symmetry of the normal distribution, the upper-tail mass of the
  # distribution truncated to [a, b] equals the lower-tail mass of the
  # reflected distribution truncated to [-b, -a].
  return logcdf(-x, -b, -a, -loc, scale)
|
||||
|
||||
@api.jit
def sf(x, a, b, loc=0, scale=1):
  """Truncated normal distribution survival function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``sf``.

  The survival function is defined as

  .. math::

     f_{sf}(x) = 1 - f_{cdf}(x)

  where :math:`f_{cdf}(x)` is the cumulative distribution function,
  :func:`jax.scipy.stats.truncnorm.cdf`.

  Args:
    x: arraylike, value at which to evaluate the SF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of sf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  # Exponentiate the numerically-stable log-space implementation.
  return lax.exp(logsf(x, a, b, loc, scale))
|
||||
|
||||
@api.jit
def logcdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal log cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``logcdf``.

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where here :math:`f_{pdf}` is the probability distribution function,
  :func:`jax.scipy.stats.truncnorm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logcdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.cdf`
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  x, a, b, loc, scale = promote_args_inexact("truncnorm.logcdf", x, a, b, loc, scale)
  x, a, b = jnp.broadcast_arrays(x, a, b)
  # Standardize x; a and b are already in standardized units.
  x = lax.div(lax.sub(x, loc), scale)
  # log of the total Gaussian mass inside the truncation interval,
  # used to renormalize both the cdf and the sf.
  lgm_ab = _log_gauss_mass(a, b)
  logcdf = _log_gauss_mass(a, x) - lgm_ab
  logsf = _log_gauss_mass(x, b) - lgm_ab

  # jnp.select takes the FIRST matching condition, so the order below
  # matters: clamp at the interval endpoints first, then prefer the
  # complementary log1p(-exp(logsf)) form when logcdf is close to 0
  # (i.e. cdf close to 1), where the direct form cancels catastrophically.
  logcdf = jnp.select(
    # third condition: avoid catastrophic cancellation (from scipy)
    [x >= b, x <= a, logcdf > -0.1, x > a],
    [0, -np.inf, jnp.log1p(-jnp.exp(logsf)), logcdf]
  )
  # A degenerate interval (a >= b) is undefined.
  logcdf = jnp.where(a >= b, np.nan, logcdf)
  return logcdf
|
||||
@api.jit
def cdf(x, a, b, loc=0, scale=1):
  r"""Truncated normal cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.truncnorm` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where here :math:`f_{pdf}` is the probability distribution function,
  :func:`jax.scipy.stats.truncnorm.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    a: arraylike, distribution shape parameter
    b: arraylike, distribution shape parameter
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.truncnorm.pdf`
    - :func:`jax.scipy.stats.truncnorm.sf`
    - :func:`jax.scipy.stats.truncnorm.logcdf`
    - :func:`jax.scipy.stats.truncnorm.logpdf`
    - :func:`jax.scipy.stats.truncnorm.logsf`
  """
  # Exponentiate the numerically-stable log-space implementation.
  log_p = logcdf(x, a, b, loc, scale)
  return lax.exp(log_p)
@@ -0,0 +1,148 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Uniform log probability distribution function.

  JAX implementation of :obj:`scipy.stats.uniform` ``logpdf``.

  The uniform distribution pdf is given by

  .. math::

     f(x) = \begin{cases}
       1 & 0 \le x \le 1 \\
       0 & \mathrm{otherwise}
     \end{cases}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of logpdf values

  See Also:
    - :func:`jax.scipy.stats.uniform.cdf`
    - :func:`jax.scipy.stats.uniform.pdf`
    - :func:`jax.scipy.stats.uniform.ppf`
  """
  x, loc, scale = promote_args_inexact("uniform.logpdf", x, loc, scale)
  # Density is the constant 1/scale on [loc, loc + scale], so its log is
  # -log(scale); outside the support the log-density is -inf.
  outside = jnp.logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc))
  return jnp.where(outside, -np.inf, lax.neg(lax.log(scale)))
|
||||
|
||||
def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Uniform probability distribution function.

  JAX implementation of :obj:`scipy.stats.uniform` ``pdf``.

  The uniform distribution pdf is given by

  .. math::

     f(x) = \begin{cases}
       1 & 0 \le x \le 1 \\
       0 & \mathrm{otherwise}
     \end{cases}

  Args:
    x: arraylike, value at which to evaluate the PDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of pdf values.

  See Also:
    - :func:`jax.scipy.stats.uniform.cdf`
    - :func:`jax.scipy.stats.uniform.logpdf`
    - :func:`jax.scipy.stats.uniform.ppf`
  """
  # Exponentiate the log-space implementation.
  log_p = logpdf(x, loc, scale)
  return lax.exp(log_p)
|
||||
|
||||
def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  r"""Uniform cumulative distribution function.

  JAX implementation of :obj:`scipy.stats.uniform` ``cdf``.

  The cdf is defined as

  .. math::

     f_{cdf} = \int_{-\infty}^x f_{pdf}(y) \mathrm{d}y

  where here :math:`f_{pdf}` is the probability distribution function,
  :func:`jax.scipy.stats.uniform.pdf`.

  Args:
    x: arraylike, value at which to evaluate the CDF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of cdf values.

  See Also:
    - :func:`jax.scipy.stats.uniform.pdf`
    - :func:`jax.scipy.stats.uniform.logpdf`
    - :func:`jax.scipy.stats.uniform.ppf`
  """
  x, loc, scale = promote_args_inexact("uniform.cdf", x, loc, scale)
  zero, one = jnp.array(0, x.dtype), jnp.array(1, x.dtype)
  upper = lax.add(loc, scale)
  below, above = lax.lt(x, loc), lax.gt(x, upper)
  inside = lax.ge(x, loc) & lax.le(x, upper)
  # Piecewise: 0 below the support, 1 above it, a linear ramp inside.
  return jnp.select([below, above, inside],
                    [zero, one, lax.div(lax.sub(x, loc), scale)])
|
||||
|
||||
def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
  """Uniform distribution percent point function.

  JAX implementation of :obj:`scipy.stats.uniform` ``ppf``.

  The percent point function is defined as the inverse of the
  cumulative distribution function, :func:`jax.scipy.stats.uniform.cdf`.

  Args:
    q: arraylike, value at which to evaluate the PPF
    loc: arraylike, distribution offset parameter
    scale: arraylike, distribution scale parameter

  Returns:
    array of ppf values.

  See Also:
    - :func:`jax.scipy.stats.uniform.cdf`
    - :func:`jax.scipy.stats.uniform.pdf`
    - :func:`jax.scipy.stats.uniform.logpdf`
  """
  q, loc, scale = promote_args_inexact("uniform.ppf", q, loc, scale)
  # Quantiles outside [0, 1] (or NaN) have no inverse: return NaN.
  invalid = jnp.isnan(q) | (q < 0) | (q > 1)
  inverse = lax.add(loc, lax.mul(scale, q))
  return jnp.where(invalid, np.nan, inverse)
||||
@@ -0,0 +1,79 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, kappa: ArrayLike) -> Array:
  r"""von Mises log probability distribution function.

  JAX implementation of :obj:`scipy.stats.vonmises` ``logpdf``.

  The von Mises probability distribution function is given by

  .. math::

     f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x}

  Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0`
  and :math:`\kappa\ge 0`, and the distribution is normalized in the interval
  :math:`-\pi \le x \le \pi`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    kappa: arraylike, distribution shape parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.vonmises.pdf`
  """
  x, kappa = promote_args_inexact('vonmises.logpdf', x, kappa)
  # bessel_i0e(k) = exp(-k) * I0(k); folding the exp(-k) factor into the
  # exponent (cos(x) - 1) keeps the computation stable for large kappa.
  log_density = kappa * (jnp.cos(x) - 1) - jnp.log(2 * np.pi * lax.bessel_i0e(kappa))
  # The shape parameter must be strictly positive.
  positive_kappa = lax.gt(kappa, _lax_const(kappa, 0))
  return jnp.where(positive_kappa, log_density, np.nan)
|
||||
|
||||
def pdf(x: ArrayLike, kappa: ArrayLike) -> Array:
  r"""von Mises probability distribution function.

  JAX implementation of :obj:`scipy.stats.vonmises` ``pdf``.

  The von Mises probability distribution function is given by

  .. math::

     f(x, \kappa) = \frac{1}{2\pi I_0(\kappa)}e^{\kappa\cos x}

  Where :math:`I_0` is the modified Bessel function :func:`~jax.scipy.special.i0`
  and :math:`\kappa\ge 0`, and the distribution is normalized in the interval
  :math:`-\pi \le x \le \pi`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    kappa: arraylike, distribution shape parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.vonmises.logpdf`
  """
  # Exponentiate the numerically-stable log-space implementation.
  log_p = logpdf(x, kappa)
  return lax.exp(log_p)
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax.lax import _const as _lax_const
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def logpdf(x: ArrayLike, c: ArrayLike) -> Array:
  r"""Wrapped Cauchy log probability distribution function.

  JAX implementation of :obj:`scipy.stats.wrapcauchy` ``logpdf``.

  The wrapped Cauchy probability distribution function is given by

  .. math::

     f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)}

  for :math:`0<c<1`, and where normalization is on the domain :math:`0\le x\le 2\pi`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    c: arraylike, distribution shape parameter

  Returns:
    array of logpdf values.

  See Also:
    :func:`jax.scipy.stats.wrapcauchy.pdf`
  """
  x, c = promote_args_inexact('wrapcauchy.logpdf', x, c)
  # log of the density: log(1 - c^2) - log(2*pi) - log(1 + c^2 - 2c cos x).
  log_density = (
      jnp.log(1 - c * c)
      - jnp.log(2 * np.pi)
      - jnp.log(1 + c * c - 2 * c * jnp.cos(x))
  )
  # The density is normalized on [0, 2*pi]; outside it is zero.
  in_domain = lax.ge(x, _lax_const(x, 0)) & lax.le(x, _lax_const(x, np.pi * 2))
  # The shape parameter is only defined for 0 < c < 1.
  valid_c = lax.gt(c, _lax_const(c, 0)) & lax.lt(c, _lax_const(c, 1))
  return jnp.where(valid_c, jnp.where(in_domain, log_density, -np.inf), np.nan)
|
||||
|
||||
def pdf(x: ArrayLike, c: ArrayLike) -> Array:
  r"""Wrapped Cauchy probability distribution function.

  JAX implementation of :obj:`scipy.stats.wrapcauchy` ``pdf``.

  The wrapped Cauchy probability distribution function is given by

  .. math::

     f(x, c) = \frac{1-c^2}{2\pi(1+c^2-2c\cos x)}

  for :math:`0<c<1`, and where normalization is on the domain :math:`0\le x\le 2\pi`.

  Args:
    x: arraylike, value at which to evaluate the PDF
    c: arraylike, distribution shape parameter

  Returns:
    array of pdf values.

  See Also:
    :func:`jax.scipy.stats.wrapcauchy.logpdf`
  """
  # Exponentiate the log-space implementation.
  log_p = logpdf(x, c)
  return lax.exp(log_p)
Reference in New Issue
Block a user