hand
This commit is contained in:
Vendored
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
venv/lib/python3.12/site-packages/jax/_src/third_party/scipy/__pycache__/interpolate.cpython-312.pyc
Vendored
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
BIN
Binary file not shown.
Vendored
BIN
Binary file not shown.
@@ -0,0 +1,63 @@
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
|
||||
|
||||
def algdiv(a: ArrayLike, b: ArrayLike) -> Array:
|
||||
"""
|
||||
Compute ``log(gamma(a))/log(gamma(a + b))`` when ``b >= 8``.
|
||||
|
||||
Derived from scipy's implementation of `algdiv`_.
|
||||
|
||||
This differs from the scipy implementation in that it assumes a <= b
|
||||
because recomputing ``a, b = jnp.minimum(a, b), jnp.maximum(a, b)`` might
|
||||
be expensive and this is only called by ``betaln``.
|
||||
|
||||
.. _algdiv:
|
||||
https://github.com/scipy/scipy/blob/c89dfc2b90d993f2a8174e57e0cbc8fbe6f3ee19/scipy/special/cdflib/algdiv.f
|
||||
"""
|
||||
c0 = 0.833333333333333e-01
|
||||
c1 = -0.277777777760991e-02
|
||||
c2 = 0.793650666825390e-03
|
||||
c3 = -0.595202931351870e-03
|
||||
c4 = 0.837308034031215e-03
|
||||
c5 = -0.165322962780713e-02
|
||||
h = a / b
|
||||
c = h / (1 + h)
|
||||
x = h / (1 + h)
|
||||
d = b + (a - 0.5)
|
||||
# Set sN = (1 - x**n)/(1 - x)
|
||||
x2 = x * x
|
||||
s3 = 1.0 + (x + x2)
|
||||
s5 = 1.0 + (x + x2 * s3)
|
||||
s7 = 1.0 + (x + x2 * s5)
|
||||
s9 = 1.0 + (x + x2 * s7)
|
||||
s11 = 1.0 + (x + x2 * s9)
|
||||
# Set w = del(b) - del(a + b)
|
||||
# where del(x) is defined by ln(gamma(x)) = (x - 0.5)*ln(x) - x + 0.5*ln(2*pi) + del(x)
|
||||
t = (1.0 / b) ** 2
|
||||
w = ((((c5 * s11 * t + c4 * s9) * t + c3 * s7) * t + c2 * s5) * t + c1 * s3) * t + c0
|
||||
w = w * (c / b)
|
||||
# Combine the results
|
||||
u = d * lax.log1p(a / b)
|
||||
v = a * (lax.log(b) - 1.0)
|
||||
return jnp.where(u <= v, (w - v) - u, (w - u) - v)
|
||||
|
||||
|
||||
def betaln(a: ArrayLike, b: ArrayLike) -> Array:
|
||||
"""Compute the log of the beta function.
|
||||
|
||||
Derived from scipy's implementation of `betaln`_.
|
||||
|
||||
This implementation does not handle all branches of the scipy implementation, but is still much more accurate
|
||||
than just doing lgamma(a) + lgamma(b) - lgamma(a + b) when inputs are large (> 1M or so).
|
||||
|
||||
.. _betaln:
|
||||
https://github.com/scipy/scipy/blob/ef2dee592ba8fb900ff2308b9d1c79e4d6a0ad8b/scipy/special/cdflib/betaln.f
|
||||
"""
|
||||
a, b = promote_args_inexact("betaln", a, b)
|
||||
a, b = jnp.minimum(a, b), jnp.maximum(a, b)
|
||||
small_b = lax.lgamma(a) + (lax.lgamma(b) - lax.lgamma(a + b))
|
||||
large_b = lax.lgamma(a) + algdiv(a, b)
|
||||
return jnp.where(b < 8, small_b, large_b)
|
||||
@@ -0,0 +1,173 @@
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src.numpy import (asarray, broadcast_arrays,
|
||||
empty, searchsorted, where, zeros)
|
||||
from jax._src.tree_util import register_pytree_node
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
|
||||
|
||||
|
||||
def _ndim_coords_from_arrays(points, ndim=None):
|
||||
"""Convert a tuple of coordinate arrays to a (..., ndim)-shaped array."""
|
||||
if isinstance(points, tuple) and len(points) == 1:
|
||||
# handle argument tuple
|
||||
points = points[0]
|
||||
if isinstance(points, tuple):
|
||||
p = broadcast_arrays(*points)
|
||||
for p_other in p[1:]:
|
||||
if p_other.shape != p[0].shape:
|
||||
raise ValueError("coordinate arrays do not have the same shape")
|
||||
points = empty(p[0].shape + (len(points),), dtype=float)
|
||||
for j, item in enumerate(p):
|
||||
points = points.at[..., j].set(item)
|
||||
else:
|
||||
check_arraylike("_ndim_coords_from_arrays", points)
|
||||
points = asarray(points) # SciPy: asanyarray(points)
|
||||
if points.ndim == 1:
|
||||
if ndim is None:
|
||||
points = points.reshape(-1, 1)
|
||||
else:
|
||||
points = points.reshape(-1, ndim)
|
||||
return points
|
||||
|
||||
|
||||
class RegularGridInterpolator:
|
||||
"""Interpolate points on a regular rectangular grid.
|
||||
|
||||
JAX implementation of :func:`scipy.interpolate.RegularGridInterpolator`.
|
||||
|
||||
Args:
|
||||
points: length-N sequence of arrays specifying the grid coordinates.
|
||||
values: N-dimensional array specifying the grid values.
|
||||
method: interpolation method, either ``"linear"`` or ``"nearest"``.
|
||||
bounds_error: not implemented by JAX
|
||||
fill_value: value returned for points outside the grid, defaults to NaN.
|
||||
|
||||
Returns:
|
||||
interpolator: callable interpolation object.
|
||||
|
||||
Examples:
|
||||
>>> points = (jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
|
||||
>>> values = jnp.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
|
||||
>>> interpolate = RegularGridInterpolator(points, values, method='linear')
|
||||
|
||||
>>> query_points = jnp.array([[1.5, 4.5], [2.2, 5.8]])
|
||||
>>> interpolate(query_points)
|
||||
Array([30., 64.], dtype=float32)
|
||||
"""
|
||||
# Based on SciPy's implementation which in turn is originally based on an
|
||||
# implementation by Johannes Buchner
|
||||
|
||||
def __init__(self,
|
||||
points,
|
||||
values,
|
||||
method="linear",
|
||||
bounds_error=False,
|
||||
fill_value=np.nan):
|
||||
if method not in ("linear", "nearest"):
|
||||
raise ValueError(f"method {method!r} is not defined")
|
||||
self.method = method
|
||||
self.bounds_error = bounds_error
|
||||
if self.bounds_error:
|
||||
raise NotImplementedError("`bounds_error` takes no effect under JIT")
|
||||
|
||||
check_arraylike("RegularGridInterpolator", values)
|
||||
if len(points) > values.ndim:
|
||||
ve = f"there are {len(points)} point arrays, but values has {values.ndim} dimensions"
|
||||
raise ValueError(ve)
|
||||
|
||||
values, = promote_dtypes_inexact(values)
|
||||
|
||||
if fill_value is not None:
|
||||
check_arraylike("RegularGridInterpolator", fill_value)
|
||||
fill_value = asarray(fill_value)
|
||||
if not dtypes.can_cast(fill_value.dtype, values.dtype, casting='same_kind'):
|
||||
ve = "fill_value must be either 'None' or of a type compatible with values"
|
||||
raise ValueError(ve)
|
||||
self.fill_value = fill_value
|
||||
|
||||
# TODO: assert sanity of `points` similar to SciPy but in a JIT-able way
|
||||
check_arraylike("RegularGridInterpolator", *points)
|
||||
self.grid = tuple(asarray(p) for p in points)
|
||||
self.values = values
|
||||
|
||||
def __call__(self, xi, method=None):
|
||||
method = self.method if method is None else method
|
||||
if method not in ("linear", "nearest"):
|
||||
raise ValueError(f"method {method!r} is not defined")
|
||||
|
||||
ndim = len(self.grid)
|
||||
xi = _ndim_coords_from_arrays(xi, ndim=ndim)
|
||||
if xi.shape[-1] != len(self.grid):
|
||||
raise ValueError("the requested sample points xi have dimension"
|
||||
f" {xi.shape[1]}, but this RegularGridInterpolator has"
|
||||
f" dimension {ndim}")
|
||||
|
||||
xi_shape = xi.shape
|
||||
xi = xi.reshape(-1, xi_shape[-1])
|
||||
|
||||
indices, norm_distances, out_of_bounds = self._find_indices(xi.T)
|
||||
if method == "linear":
|
||||
result = self._evaluate_linear(indices, norm_distances)
|
||||
elif method == "nearest":
|
||||
result = self._evaluate_nearest(indices, norm_distances)
|
||||
else:
|
||||
raise AssertionError("method must be bound")
|
||||
if not self.bounds_error and self.fill_value is not None:
|
||||
bc_shp = result.shape[:1] + (1,) * (result.ndim - 1)
|
||||
result = where(out_of_bounds.reshape(bc_shp), self.fill_value, result)
|
||||
|
||||
return result.reshape(xi_shape[:-1] + self.values.shape[ndim:])
|
||||
|
||||
def _evaluate_linear(self, indices, norm_distances):
|
||||
# slice for broadcasting over trailing dimensions in self.values
|
||||
vslice = (slice(None),) + (None,) * (self.values.ndim - len(indices))
|
||||
|
||||
# find relevant values
|
||||
# each i and i+1 represents a edge
|
||||
edges = product(*[[i, i + 1] for i in indices])
|
||||
values = asarray(0.)
|
||||
for edge_indices in edges:
|
||||
weight = asarray(1.)
|
||||
for ei, i, yi in zip(edge_indices, indices, norm_distances):
|
||||
weight *= where(ei == i, 1 - yi, yi)
|
||||
values += self.values[edge_indices] * weight[vslice]
|
||||
return values
|
||||
|
||||
def _evaluate_nearest(self, indices, norm_distances):
|
||||
idx_res = [
|
||||
where(yi <= .5, i, i + 1) for i, yi in zip(indices, norm_distances)
|
||||
]
|
||||
return self.values[tuple(idx_res)]
|
||||
|
||||
def _find_indices(self, xi):
|
||||
# find relevant edges between which xi are situated
|
||||
indices = []
|
||||
# compute distance to lower edge in unity units
|
||||
norm_distances = []
|
||||
# check for out of bounds xi
|
||||
out_of_bounds = zeros((xi.shape[1],), dtype=bool)
|
||||
# iterate through dimensions
|
||||
for x, g in zip(xi, self.grid):
|
||||
i = searchsorted(g, x) - 1
|
||||
i = where(i < 0, 0, i)
|
||||
i = where(i > g.size - 2, g.size - 2, i)
|
||||
indices.append(i)
|
||||
norm_distances.append((x - g[i]) / (g[i + 1] - g[i]))
|
||||
if not self.bounds_error:
|
||||
out_of_bounds += x < g[0]
|
||||
out_of_bounds += x > g[-1]
|
||||
return indices, norm_distances, out_of_bounds
|
||||
|
||||
|
||||
register_pytree_node(
|
||||
RegularGridInterpolator,
|
||||
lambda obj: ((obj.grid, obj.values, obj.fill_value),
|
||||
(obj.method, obj.bounds_error)),
|
||||
lambda aux, children: RegularGridInterpolator(
|
||||
*children[:2],
|
||||
*aux,
|
||||
*children[2:]),
|
||||
)
|
||||
@@ -0,0 +1,116 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy.linalg import norm
|
||||
from jax._src.scipy.linalg import rsf2csf, schur
|
||||
from jax._src.typing import ArrayLike, Array
|
||||
|
||||
|
||||
@api.jit
|
||||
def _algorithm_11_1_1(F: Array, T: Array) -> tuple[Array, Array]:
|
||||
# Algorithm 11.1.1 from Golub and Van Loan "Matrix Computations"
|
||||
N = T.shape[0]
|
||||
minden = jnp.abs(T[0, 0])
|
||||
|
||||
def _outer_loop(p, F_minden):
|
||||
_, F, minden = lax.fori_loop(1, N-p+1, _inner_loop, (p, *F_minden))
|
||||
return F, minden
|
||||
|
||||
def _inner_loop(i, p_F_minden):
|
||||
p, F, minden = p_F_minden
|
||||
j = i+p
|
||||
s = T[i-1, j-1] * (F[j-1, j-1] - F[i-1, i-1])
|
||||
T_row, T_col = T[i-1], T[:, j-1]
|
||||
F_row, F_col = F[i-1], F[:, j-1]
|
||||
ind = (jnp.arange(N) >= i) & (jnp.arange(N) < j-1)
|
||||
val = (jnp.where(ind, T_row, 0) @ jnp.where(ind, F_col, 0) -
|
||||
jnp.where(ind, F_row, 0) @ jnp.where(ind, T_col, 0))
|
||||
s = s + val
|
||||
den = T[j-1, j-1] - T[i-1, i-1]
|
||||
s = jnp.where(den != 0, s / den, s)
|
||||
F = F.at[i-1, j-1].set(s)
|
||||
minden = jnp.minimum(minden, jnp.abs(den))
|
||||
return p, F, minden
|
||||
|
||||
return lax.fori_loop(1, N, _outer_loop, (F, minden))
|
||||
|
||||
|
||||
def funm(A: ArrayLike, func: Callable[[Array], Array],
|
||||
disp: bool = True) -> Array | tuple[Array, Array]:
|
||||
"""Evaluate a matrix-valued function
|
||||
|
||||
JAX implementation of :func:`scipy.linalg.funm`.
|
||||
|
||||
Args:
|
||||
A: array of shape ``(N, N)`` for which the function is to be computed.
|
||||
func: Callable object that takes a scalar argument and returns a scalar result.
|
||||
Represents the function to be evaluated over the eigenvalues of A.
|
||||
disp: If true (default), error information is not returned. Unlike scipy's version JAX
|
||||
does not attempt to display information at runtime.
|
||||
compute_expm: (N, N) array_like or None, optional.
|
||||
If provided, the matrix exponential of A. This is used for improving efficiency when `func`
|
||||
is the exponential function. If not provided, it is computed internally.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Array of same shape as ``A``, containing the result of ``func`` evaluated on the
|
||||
eigenvalues of ``A``.
|
||||
|
||||
Notes:
|
||||
The returned dtype of JAX's implementation may differ from that of scipy;
|
||||
specifically, in cases where all imaginary parts of the array values are
|
||||
close to zero, the SciPy function may return a real-valued array, whereas
|
||||
the JAX implementation will return a complex-valued array.
|
||||
|
||||
Examples:
|
||||
Applying an arbitrary matrix function:
|
||||
|
||||
>>> A = jnp.array([[1., 2.], [3., 4.]])
|
||||
>>> def func(x):
|
||||
... return jnp.sin(x) + 2 * jnp.cos(x)
|
||||
>>> jax.scipy.linalg.funm(A, func) # doctest: +SKIP
|
||||
Array([[ 1.2452652 +0.j, -0.3701772 +0.j],
|
||||
[-0.55526584+0.j, 0.6899995 +0.j]], dtype=complex64)
|
||||
|
||||
Comparing two ways of computing the matrix exponent:
|
||||
|
||||
>>> expA_1 = jax.scipy.linalg.funm(A, jnp.exp)
|
||||
>>> expA_2 = jax.scipy.linalg.expm(A)
|
||||
>>> jnp.allclose(expA_1, expA_2, rtol=1E-4)
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
A_arr = jnp.asarray(A)
|
||||
if A_arr.ndim != 2 or A_arr.shape[0] != A_arr.shape[1]:
|
||||
raise ValueError('expected square array_like input')
|
||||
|
||||
T, Z = schur(A_arr)
|
||||
T, Z = rsf2csf(T, Z)
|
||||
|
||||
F = jnp.diag(func(jnp.diag(T)))
|
||||
F = F.astype(T.dtype.char)
|
||||
|
||||
F, minden = _algorithm_11_1_1(F, T)
|
||||
F = Z @ F @ Z.conj().T
|
||||
|
||||
if disp:
|
||||
return F
|
||||
|
||||
if F.dtype.char.lower() == 'e':
|
||||
tol = dtypes.finfo(np.float16).eps
|
||||
if F.dtype.char.lower() == 'f':
|
||||
tol = dtypes.finfo(np.float32).eps
|
||||
else:
|
||||
tol = dtypes.finfo(np.float64).eps
|
||||
|
||||
minden = jnp.where(minden == 0.0, tol, minden)
|
||||
err = jnp.where(jnp.any(jnp.isinf(F)), np.inf, jnp.minimum(1, jnp.maximum(
|
||||
tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))
|
||||
|
||||
return F, err
|
||||
+93
@@ -0,0 +1,93 @@
|
||||
"""Utility functions adopted from scipy.signal."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
|
||||
|
||||
def _triage_segments(window: ArrayLike | str | tuple[Any, ...], nperseg: int | None,
|
||||
input_length: int, dtype: DTypeLike) -> tuple[Array, int]:
|
||||
"""
|
||||
Parses window and nperseg arguments for spectrogram and _spectral_helper.
|
||||
This is a helper function, not meant to be called externally.
|
||||
|
||||
Args:
|
||||
window : string, tuple, or ndarray
|
||||
If window is specified by a string or tuple and nperseg is not
|
||||
specified, nperseg is set to the default of 256 and returns a window of
|
||||
that length.
|
||||
If instead the window is array_like and nperseg is not specified, then
|
||||
nperseg is set to the length of the window. A ValueError is raised if
|
||||
the user supplies both an array_like window and a value for nperseg but
|
||||
nperseg does not equal the length of the window.
|
||||
nperseg : int
|
||||
Length of each segment
|
||||
input_length: int
|
||||
Length of input signal, i.e. x.shape[-1]. Used to test for errors.
|
||||
dtype: dtype for window if specified as a string or tuple. Not referenced
|
||||
if window is an array.
|
||||
|
||||
Returns:
|
||||
win : ndarray
|
||||
window. If function was called with string or tuple than this will hold
|
||||
the actual array used as a window.
|
||||
nperseg : int
|
||||
Length of each segment. If window is str or tuple, nperseg is set to
|
||||
256. If window is array_like, nperseg is set to the length of the window.
|
||||
"""
|
||||
if isinstance(window, (str, tuple)):
|
||||
nperseg_int = input_length if nperseg is None else int(nperseg)
|
||||
if nperseg_int > input_length:
|
||||
warnings.warn(f'nperseg={nperseg_int} is greater than {input_length=},'
|
||||
f' using nperseg={input_length}')
|
||||
nperseg_int = input_length
|
||||
if window == 'hann':
|
||||
# Implement the default case without scipy
|
||||
win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, np.pi, nperseg_int, endpoint=False)) ** 2
|
||||
else:
|
||||
# TODO(jakevdp): implement get_window() in JAX to remove optional scipy dependency
|
||||
try:
|
||||
from scipy.signal import get_window # pyrefly: ignore[missing-import]
|
||||
except ImportError as err:
|
||||
raise ImportError(f"scipy must be available to use {window=}") from err
|
||||
win = get_window(window, nperseg_int)
|
||||
win = jnp.array(win, dtype=dtype)
|
||||
else:
|
||||
win = jnp.asarray(window, dtype=dtype)
|
||||
nperseg_int = win.size if nperseg is None else int(nperseg)
|
||||
if win.ndim != 1:
|
||||
raise ValueError('window must be 1-D')
|
||||
if input_length < win.size:
|
||||
raise ValueError('window is longer than input signal')
|
||||
if nperseg_int != win.size:
|
||||
raise ValueError("value specified for nperseg is different from length of window")
|
||||
return win, nperseg_int
|
||||
|
||||
|
||||
def _median_bias(n: int) -> Array:
|
||||
"""
|
||||
Returns the bias of the median of a set of periodograms relative to
|
||||
the mean. See Appendix B from [1]_ for details.
|
||||
|
||||
Args:
|
||||
n : int
|
||||
Numbers of periodograms being averaged.
|
||||
|
||||
Returns:
|
||||
bias : float
|
||||
Calculated bias.
|
||||
|
||||
References:
|
||||
.. [1] B. Allen, W.G. Anderson, P.R. Brady, D.A. Brown, J.D.E. Creighton.
|
||||
"FINDCHIRP: an algorithm for detection of gravitational waves from
|
||||
inspiraling compact binaries", Physical Review D 85, 2012,
|
||||
:arxiv:`gr-qc/0509116`
|
||||
"""
|
||||
ii_2 = jnp.arange(2., n, 2)
|
||||
return 1 + jnp.sum(1. / (ii_2 + 1) - 1. / ii_2)
|
||||
@@ -0,0 +1,327 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src import custom_derivatives, dtypes
|
||||
from jax._src.numpy.util import promote_args_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
@api.jit
|
||||
def sincospisquaredhalf(
|
||||
x: Array,
|
||||
) -> tuple[Array, Array]:
|
||||
"""
|
||||
Accurate evaluation of sin(pi * x**2 / 2) and cos(pi * x**2 / 2).
|
||||
|
||||
As based on the sinpi and cospi functions from SciPy, see:
|
||||
- https://github.com/scipy/scipy/blob/v1.14.0/scipy/special/special/cephes/trig.h
|
||||
"""
|
||||
x = jnp.abs(x)
|
||||
# define s = x % 2, y = x - s, then
|
||||
# r = (x * x / 2) % 2
|
||||
# = [(y + s)*(y + s)/2] % 2
|
||||
# = [y*y/2 + s*y + s*s/2] % 2
|
||||
# = [(y*y/2)%2 + (s*y + s*s/2)%2]%2
|
||||
# = [0 + (s*(y+s/2))%2]%2
|
||||
# = [s*(x-s/2)]%2
|
||||
s = jnp.fmod(x, 2.0)
|
||||
r = jnp.fmod(s * (x - s / 2), 2.0)
|
||||
|
||||
sinpi = jnp.where(
|
||||
r < 0.5,
|
||||
jnp.sin(np.pi * r),
|
||||
jnp.where(
|
||||
r > 1.5,
|
||||
jnp.sin(np.pi * (r - 2.0)),
|
||||
-jnp.sin(np.pi * (r - 1.0)),
|
||||
),
|
||||
)
|
||||
cospi = jnp.where(
|
||||
r == 0.5,
|
||||
0.0,
|
||||
jnp.where(r < 1.0, -jnp.sin(np.pi * (r - 0.5)), jnp.sin(np.pi * (r - 1.5))),
|
||||
)
|
||||
|
||||
return sinpi, cospi
|
||||
|
||||
|
||||
@custom_derivatives.custom_jvp
|
||||
def fresnel(x: ArrayLike) -> tuple[Array, Array]:
|
||||
r"""The Fresnel integrals
|
||||
|
||||
JAX implementation of :obj:`scipy.special.fresnel`.
|
||||
|
||||
The Fresnel integrals are defined as
|
||||
.. math::
|
||||
S(x) &= \int_0^x \sin(\pi t^2 /2) dt \\
|
||||
C(x) &= \int_0^x \cos(\pi t^2 /2) dt.
|
||||
|
||||
Args:
|
||||
x: arraylike, real-valued.
|
||||
|
||||
Returns:
|
||||
Arrays containing the values of the Fresnel integrals.
|
||||
|
||||
Notes:
|
||||
The JAX version only supports real-valued inputs, and
|
||||
is based on the SciPy C++ implementation, see
|
||||
`here
|
||||
<https://github.com/scipy/scipy/blob/v1.14.0/scipy/special/special/cephes/fresnl.h>`_.
|
||||
For ``float32`` dtypes, the implementation is directly based
|
||||
on the Cephes implementation ``fresnlf``.
|
||||
|
||||
As for the original Cephes implementation, the accuracy
|
||||
is only guaranteed in the domain [-10, 10]. Outside of
|
||||
that domain, one could observe divergence between the
|
||||
theoretical derivatives and the custom JVP implementation,
|
||||
especially for large input values.
|
||||
|
||||
Finally, for half-precision data types, ``float16``
|
||||
and ``bfloat16``, the array elements are upcasted to
|
||||
``float32`` as the Cephes coefficients used in
|
||||
series expansions would otherwise lead to poor results.
|
||||
Other data types, like ``float8``, are not supported.
|
||||
"""
|
||||
|
||||
xxa, = promote_args_inexact("fresnel", x)
|
||||
original_dtype = xxa.dtype
|
||||
|
||||
# This part is mostly a direct translation of SciPy's C++ code,
|
||||
# and the original Cephes implementation for single precision.
|
||||
|
||||
if dtypes.issubdtype(xxa.dtype, np.complexfloating):
|
||||
raise NotImplementedError(
|
||||
'Support for complex-valued inputs is not implemented yet.')
|
||||
elif xxa.dtype in (np.float32, np.float16, dtypes.bfloat16):
|
||||
# Single-precision Cephes coefficients
|
||||
|
||||
# For half-precision, series expansions have either
|
||||
# produce overflow or poor accuracy.
|
||||
# Upcasting to single-precision is hence needed.
|
||||
xxa = xxa.astype(np.float32) # No-op for float32
|
||||
|
||||
fresnl_sn = jnp.array([
|
||||
+1.647629463788700e-9,
|
||||
-1.522754752581096e-7,
|
||||
+8.424748808502400e-6,
|
||||
-3.120693124703272e-4,
|
||||
+7.244727626597022e-3,
|
||||
-9.228055941124598e-2,
|
||||
+5.235987735681432e-1,
|
||||
], dtype=np.float32)
|
||||
|
||||
fresnl_cn = jnp.array([
|
||||
+1.416802502367354e-8,
|
||||
-1.157231412229871e-6,
|
||||
+5.387223446683264e-5,
|
||||
-1.604381798862293e-3,
|
||||
+2.818489036795073e-2,
|
||||
-2.467398198317899e-1,
|
||||
+9.999999760004487e-1,
|
||||
], dtype=np.float32)
|
||||
|
||||
fresnl_fn = jnp.array([
|
||||
-1.903009855649792e12,
|
||||
+1.355942388050252e11,
|
||||
-4.158143148511033e9,
|
||||
+7.343848463587323e7,
|
||||
-8.732356681548485e5,
|
||||
+8.560515466275470e3,
|
||||
-1.032877601091159e2,
|
||||
+2.999401847870011e0,
|
||||
], dtype=np.float32)
|
||||
|
||||
fresnl_gn = jnp.array([
|
||||
-1.860843997624650e11,
|
||||
+1.278350673393208e10,
|
||||
-3.779387713202229e8,
|
||||
+6.492611570598858e6,
|
||||
-7.787789623358162e4,
|
||||
+8.602931494734327e2,
|
||||
-1.493439396592284e1,
|
||||
+9.999841934744914e-1,
|
||||
], dtype=np.float32)
|
||||
|
||||
fresnl_cd = jnp.empty(0)
|
||||
fresnl_sd = jnp.empty(0)
|
||||
fresnl_fd = jnp.empty(0)
|
||||
fresnl_gd = jnp.empty(0)
|
||||
elif xxa.dtype == np.float64:
|
||||
# Double-precision Cephes coefficients
|
||||
|
||||
fresnl_sn = jnp.array([
|
||||
-2.99181919401019853726e3,
|
||||
+7.08840045257738576863e5,
|
||||
-6.29741486205862506537e7,
|
||||
+2.54890880573376359104e9,
|
||||
-4.42979518059697779103e10,
|
||||
+3.18016297876567817986e11,
|
||||
], dtype=np.float64)
|
||||
|
||||
fresnl_sd = jnp.array([
|
||||
+1.00000000000000000000e0,
|
||||
+2.81376268889994315696e2,
|
||||
+4.55847810806532581675e4,
|
||||
+5.17343888770096400730e6,
|
||||
+4.19320245898111231129e8,
|
||||
+2.24411795645340920940e10,
|
||||
+6.07366389490084639049e11,
|
||||
], dtype=np.float64)
|
||||
|
||||
fresnl_cn = jnp.array([
|
||||
-4.98843114573573548651e-8,
|
||||
+9.50428062829859605134e-6,
|
||||
-6.45191435683965050962e-4,
|
||||
+1.88843319396703850064e-2,
|
||||
-2.05525900955013891793e-1,
|
||||
+9.99999999999999998822e-1,
|
||||
], dtype=np.float64)
|
||||
|
||||
fresnl_cd = jnp.array([
|
||||
+3.99982968972495980367e-12,
|
||||
+9.15439215774657478799e-10,
|
||||
+1.25001862479598821474e-7,
|
||||
+1.22262789024179030997e-5,
|
||||
+8.68029542941784300606e-4,
|
||||
+4.12142090722199792936e-2,
|
||||
+1.00000000000000000118e0,
|
||||
], dtype=np.float64)
|
||||
|
||||
fresnl_fn = jnp.array([
|
||||
+4.21543555043677546506e-1,
|
||||
+1.43407919780758885261e-1,
|
||||
+1.15220955073585758835e-2,
|
||||
+3.45017939782574027900e-4,
|
||||
+4.63613749287867322088e-6,
|
||||
+3.05568983790257605827e-8,
|
||||
+1.02304514164907233465e-10,
|
||||
+1.72010743268161828879e-13,
|
||||
+1.34283276233062758925e-16,
|
||||
+3.76329711269987889006e-20,
|
||||
], dtype=np.float64)
|
||||
|
||||
fresnl_fd = jnp.array([
|
||||
+1.00000000000000000000e0,
|
||||
+7.51586398353378947175e-1,
|
||||
+1.16888925859191382142e-1,
|
||||
+6.44051526508858611005e-3,
|
||||
+1.55934409164153020873e-4,
|
||||
+1.84627567348930545870e-6,
|
||||
+1.12699224763999035261e-8,
|
||||
+3.60140029589371370404e-11,
|
||||
+5.88754533621578410010e-14,
|
||||
+4.52001434074129701496e-17,
|
||||
+1.25443237090011264384e-20,
|
||||
], dtype=np.float64)
|
||||
|
||||
fresnl_gn = jnp.array([
|
||||
+5.04442073643383265887e-1,
|
||||
+1.97102833525523411709e-1,
|
||||
+1.87648584092575249293e-2,
|
||||
+6.84079380915393090172e-4,
|
||||
+1.15138826111884280931e-5,
|
||||
+9.82852443688422223854e-8,
|
||||
+4.45344415861750144738e-10,
|
||||
+1.08268041139020870318e-12,
|
||||
+1.37555460633261799868e-15,
|
||||
+8.36354435630677421531e-19,
|
||||
+1.86958710162783235106e-22,
|
||||
], dtype=np.float64)
|
||||
|
||||
fresnl_gd = jnp.array([
|
||||
+1.00000000000000000000e0,
|
||||
+1.47495759925128324529e0,
|
||||
+3.37748989120019970451e-1,
|
||||
+2.53603741420338795122e-2,
|
||||
+8.14679107184306179049e-4,
|
||||
+1.27545075667729118702e-5,
|
||||
+1.04314589657571990585e-7,
|
||||
+4.60680728146520428211e-10,
|
||||
+1.10273215066240270757e-12,
|
||||
+1.38796531259578871258e-15,
|
||||
+8.39158816283118707363e-19,
|
||||
+1.86958710162783236342e-22,
|
||||
], dtype=np.float64)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'Support for {xxa.dtype} dtype is not implemented yet.')
|
||||
|
||||
assert xxa.dtype in (np.float32, np.float64)
|
||||
single_precision = (xxa.dtype == np.float32)
|
||||
|
||||
x = jnp.abs(xxa)
|
||||
|
||||
x2 = x * x
|
||||
|
||||
# Infinite x values
|
||||
s_inf = c_inf = 0.5
|
||||
|
||||
# Small x values
|
||||
t = x2 * x2
|
||||
|
||||
if single_precision:
|
||||
s_small = x * x2 * jnp.polyval(fresnl_sn, t)
|
||||
c_small = x * jnp.polyval(fresnl_cn, t)
|
||||
else:
|
||||
s_small = x * x2 * jnp.polyval(fresnl_sn[:6], t) / jnp.polyval(fresnl_sd[:7], t)
|
||||
c_small = x * jnp.polyval(fresnl_cn[:6], t) / jnp.polyval(fresnl_cd[:7], t)
|
||||
|
||||
# Large x values
|
||||
|
||||
sinpi, cospi = sincospisquaredhalf(x)
|
||||
|
||||
if single_precision:
|
||||
c_large = c_inf
|
||||
s_large = s_inf
|
||||
else:
|
||||
c_large = 0.5 + 1 / (np.pi * x) * sinpi
|
||||
s_large = 0.5 - 1 / (np.pi * x) * cospi
|
||||
|
||||
# Other x values
|
||||
t = np.pi * x2
|
||||
u = 1.0 / (t * t)
|
||||
t = 1.0 / t
|
||||
|
||||
if single_precision:
|
||||
f = 1.0 - u * jnp.polyval(fresnl_fn, u)
|
||||
g = t * jnp.polyval(fresnl_gn, u)
|
||||
else:
|
||||
f = 1.0 - u * jnp.polyval(fresnl_fn, u) / jnp.polyval(fresnl_fd, u)
|
||||
g = t * jnp.polyval(fresnl_gn, u) / jnp.polyval(fresnl_gd, u)
|
||||
|
||||
t = np.pi * x
|
||||
c_other = 0.5 + (f * sinpi - g * cospi) / t
|
||||
s_other = 0.5 - (f * cospi + g * sinpi) / t
|
||||
|
||||
isinf = jnp.isinf(xxa)
|
||||
small = x2 < 2.5625
|
||||
large = x > 36974.0
|
||||
s = jnp.where(
|
||||
isinf, s_inf, jnp.where(small, s_small, jnp.where(large, s_large, s_other))
|
||||
)
|
||||
c = jnp.where(
|
||||
isinf, c_inf, jnp.where(small, c_small, jnp.where(large, c_large, c_other))
|
||||
)
|
||||
|
||||
neg = xxa < 0.0
|
||||
s = jnp.where(neg, -s, s)
|
||||
c = jnp.where(neg, -c, c)
|
||||
|
||||
if original_dtype != xxa.dtype:
|
||||
s = s.astype(original_dtype)
|
||||
c = c.astype(original_dtype)
|
||||
|
||||
return s, c
|
||||
|
||||
def _fresnel_jvp(primals, tangents):
|
||||
x, = primals
|
||||
x_dot, = tangents
|
||||
result = fresnel(x)
|
||||
sinpi, cospi = sincospisquaredhalf(x)
|
||||
dSdx = sinpi * x_dot
|
||||
dCdx = cospi * x_dot
|
||||
return result, (dSdx, dCdx)
|
||||
fresnel.defjvp(_fresnel_jvp)
|
||||
Reference in New Issue
Block a user