This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,63 @@
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.typing import Array, ArrayLike
from jax._src.numpy.util import promote_args_inexact
def algdiv(a: ArrayLike, b: ArrayLike) -> Array:
"""
Compute ``log(gamma(a))/log(gamma(a + b))`` when ``b >= 8``.
Derived from scipy's implementation of `algdiv`_.
This differs from the scipy implementation in that it assumes a <= b
because recomputing ``a, b = jnp.minimum(a, b), jnp.maximum(a, b)`` might
be expensive and this is only called by ``betaln``.
.. _algdiv:
https://github.com/scipy/scipy/blob/c89dfc2b90d993f2a8174e57e0cbc8fbe6f3ee19/scipy/special/cdflib/algdiv.f
"""
c0 = 0.833333333333333e-01
c1 = -0.277777777760991e-02
c2 = 0.793650666825390e-03
c3 = -0.595202931351870e-03
c4 = 0.837308034031215e-03
c5 = -0.165322962780713e-02
h = a / b
c = h / (1 + h)
x = h / (1 + h)
d = b + (a - 0.5)
# Set sN = (1 - x**n)/(1 - x)
x2 = x * x
s3 = 1.0 + (x + x2)
s5 = 1.0 + (x + x2 * s3)
s7 = 1.0 + (x + x2 * s5)
s9 = 1.0 + (x + x2 * s7)
s11 = 1.0 + (x + x2 * s9)
# Set w = del(b) - del(a + b)
# where del(x) is defined by ln(gamma(x)) = (x - 0.5)*ln(x) - x + 0.5*ln(2*pi) + del(x)
t = (1.0 / b) ** 2
w = ((((c5 * s11 * t + c4 * s9) * t + c3 * s7) * t + c2 * s5) * t + c1 * s3) * t + c0
w = w * (c / b)
# Combine the results
u = d * lax.log1p(a / b)
v = a * (lax.log(b) - 1.0)
return jnp.where(u <= v, (w - v) - u, (w - u) - v)
def betaln(a: ArrayLike, b: ArrayLike) -> Array:
"""Compute the log of the beta function.
Derived from scipy's implementation of `betaln`_.
This implementation does not handle all branches of the scipy implementation, but is still much more accurate
than just doing lgamma(a) + lgamma(b) - lgamma(a + b) when inputs are large (> 1M or so).
.. _betaln:
https://github.com/scipy/scipy/blob/ef2dee592ba8fb900ff2308b9d1c79e4d6a0ad8b/scipy/special/cdflib/betaln.f
"""
a, b = promote_args_inexact("betaln", a, b)
a, b = jnp.minimum(a, b), jnp.maximum(a, b)
small_b = lax.lgamma(a) + (lax.lgamma(b) - lax.lgamma(a + b))
large_b = lax.lgamma(a) + algdiv(a, b)
return jnp.where(b < 8, small_b, large_b)
@@ -0,0 +1,173 @@
from itertools import product
import numpy as np
from jax._src import dtypes
from jax._src.numpy import (asarray, broadcast_arrays,
empty, searchsorted, where, zeros)
from jax._src.tree_util import register_pytree_node
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
def _ndim_coords_from_arrays(points, ndim=None):
"""Convert a tuple of coordinate arrays to a (..., ndim)-shaped array."""
if isinstance(points, tuple) and len(points) == 1:
# handle argument tuple
points = points[0]
if isinstance(points, tuple):
p = broadcast_arrays(*points)
for p_other in p[1:]:
if p_other.shape != p[0].shape:
raise ValueError("coordinate arrays do not have the same shape")
points = empty(p[0].shape + (len(points),), dtype=float)
for j, item in enumerate(p):
points = points.at[..., j].set(item)
else:
check_arraylike("_ndim_coords_from_arrays", points)
points = asarray(points) # SciPy: asanyarray(points)
if points.ndim == 1:
if ndim is None:
points = points.reshape(-1, 1)
else:
points = points.reshape(-1, ndim)
return points
class RegularGridInterpolator:
"""Interpolate points on a regular rectangular grid.
JAX implementation of :func:`scipy.interpolate.RegularGridInterpolator`.
Args:
points: length-N sequence of arrays specifying the grid coordinates.
values: N-dimensional array specifying the grid values.
method: interpolation method, either ``"linear"`` or ``"nearest"``.
bounds_error: not implemented by JAX
fill_value: value returned for points outside the grid, defaults to NaN.
Returns:
interpolator: callable interpolation object.
Examples:
>>> points = (jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
>>> values = jnp.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
>>> interpolate = RegularGridInterpolator(points, values, method='linear')
>>> query_points = jnp.array([[1.5, 4.5], [2.2, 5.8]])
>>> interpolate(query_points)
Array([30., 64.], dtype=float32)
"""
# Based on SciPy's implementation which in turn is originally based on an
# implementation by Johannes Buchner
def __init__(self,
points,
values,
method="linear",
bounds_error=False,
fill_value=np.nan):
if method not in ("linear", "nearest"):
raise ValueError(f"method {method!r} is not defined")
self.method = method
self.bounds_error = bounds_error
if self.bounds_error:
raise NotImplementedError("`bounds_error` takes no effect under JIT")
check_arraylike("RegularGridInterpolator", values)
if len(points) > values.ndim:
ve = f"there are {len(points)} point arrays, but values has {values.ndim} dimensions"
raise ValueError(ve)
values, = promote_dtypes_inexact(values)
if fill_value is not None:
check_arraylike("RegularGridInterpolator", fill_value)
fill_value = asarray(fill_value)
if not dtypes.can_cast(fill_value.dtype, values.dtype, casting='same_kind'):
ve = "fill_value must be either 'None' or of a type compatible with values"
raise ValueError(ve)
self.fill_value = fill_value
# TODO: assert sanity of `points` similar to SciPy but in a JIT-able way
check_arraylike("RegularGridInterpolator", *points)
self.grid = tuple(asarray(p) for p in points)
self.values = values
def __call__(self, xi, method=None):
method = self.method if method is None else method
if method not in ("linear", "nearest"):
raise ValueError(f"method {method!r} is not defined")
ndim = len(self.grid)
xi = _ndim_coords_from_arrays(xi, ndim=ndim)
if xi.shape[-1] != len(self.grid):
raise ValueError("the requested sample points xi have dimension"
f" {xi.shape[1]}, but this RegularGridInterpolator has"
f" dimension {ndim}")
xi_shape = xi.shape
xi = xi.reshape(-1, xi_shape[-1])
indices, norm_distances, out_of_bounds = self._find_indices(xi.T)
if method == "linear":
result = self._evaluate_linear(indices, norm_distances)
elif method == "nearest":
result = self._evaluate_nearest(indices, norm_distances)
else:
raise AssertionError("method must be bound")
if not self.bounds_error and self.fill_value is not None:
bc_shp = result.shape[:1] + (1,) * (result.ndim - 1)
result = where(out_of_bounds.reshape(bc_shp), self.fill_value, result)
return result.reshape(xi_shape[:-1] + self.values.shape[ndim:])
def _evaluate_linear(self, indices, norm_distances):
# slice for broadcasting over trailing dimensions in self.values
vslice = (slice(None),) + (None,) * (self.values.ndim - len(indices))
# find relevant values
# each i and i+1 represents a edge
edges = product(*[[i, i + 1] for i in indices])
values = asarray(0.)
for edge_indices in edges:
weight = asarray(1.)
for ei, i, yi in zip(edge_indices, indices, norm_distances):
weight *= where(ei == i, 1 - yi, yi)
values += self.values[edge_indices] * weight[vslice]
return values
def _evaluate_nearest(self, indices, norm_distances):
idx_res = [
where(yi <= .5, i, i + 1) for i, yi in zip(indices, norm_distances)
]
return self.values[tuple(idx_res)]
def _find_indices(self, xi):
# find relevant edges between which xi are situated
indices = []
# compute distance to lower edge in unity units
norm_distances = []
# check for out of bounds xi
out_of_bounds = zeros((xi.shape[1],), dtype=bool)
# iterate through dimensions
for x, g in zip(xi, self.grid):
i = searchsorted(g, x) - 1
i = where(i < 0, 0, i)
i = where(i > g.size - 2, g.size - 2, i)
indices.append(i)
norm_distances.append((x - g[i]) / (g[i + 1] - g[i]))
if not self.bounds_error:
out_of_bounds += x < g[0]
out_of_bounds += x > g[-1]
return indices, norm_distances, out_of_bounds
register_pytree_node(
RegularGridInterpolator,
lambda obj: ((obj.grid, obj.values, obj.fill_value),
(obj.method, obj.bounds_error)),
lambda aux, children: RegularGridInterpolator(
*children[:2],
*aux,
*children[2:]),
)
@@ -0,0 +1,116 @@
from __future__ import annotations
from collections.abc import Callable
import numpy as np
from jax._src import api
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.numpy.linalg import norm
from jax._src.scipy.linalg import rsf2csf, schur
from jax._src.typing import ArrayLike, Array
@api.jit
def _algorithm_11_1_1(F: Array, T: Array) -> tuple[Array, Array]:
# Algorithm 11.1.1 from Golub and Van Loan "Matrix Computations"
N = T.shape[0]
minden = jnp.abs(T[0, 0])
def _outer_loop(p, F_minden):
_, F, minden = lax.fori_loop(1, N-p+1, _inner_loop, (p, *F_minden))
return F, minden
def _inner_loop(i, p_F_minden):
p, F, minden = p_F_minden
j = i+p
s = T[i-1, j-1] * (F[j-1, j-1] - F[i-1, i-1])
T_row, T_col = T[i-1], T[:, j-1]
F_row, F_col = F[i-1], F[:, j-1]
ind = (jnp.arange(N) >= i) & (jnp.arange(N) < j-1)
val = (jnp.where(ind, T_row, 0) @ jnp.where(ind, F_col, 0) -
jnp.where(ind, F_row, 0) @ jnp.where(ind, T_col, 0))
s = s + val
den = T[j-1, j-1] - T[i-1, i-1]
s = jnp.where(den != 0, s / den, s)
F = F.at[i-1, j-1].set(s)
minden = jnp.minimum(minden, jnp.abs(den))
return p, F, minden
return lax.fori_loop(1, N, _outer_loop, (F, minden))
def funm(A: ArrayLike, func: Callable[[Array], Array],
disp: bool = True) -> Array | tuple[Array, Array]:
"""Evaluate a matrix-valued function
JAX implementation of :func:`scipy.linalg.funm`.
Args:
A: array of shape ``(N, N)`` for which the function is to be computed.
func: Callable object that takes a scalar argument and returns a scalar result.
Represents the function to be evaluated over the eigenvalues of A.
disp: If true (default), error information is not returned. Unlike scipy's version JAX
does not attempt to display information at runtime.
compute_expm: (N, N) array_like or None, optional.
If provided, the matrix exponential of A. This is used for improving efficiency when `func`
is the exponential function. If not provided, it is computed internally.
Defaults to None.
Returns:
Array of same shape as ``A``, containing the result of ``func`` evaluated on the
eigenvalues of ``A``.
Notes:
The returned dtype of JAX's implementation may differ from that of scipy;
specifically, in cases where all imaginary parts of the array values are
close to zero, the SciPy function may return a real-valued array, whereas
the JAX implementation will return a complex-valued array.
Examples:
Applying an arbitrary matrix function:
>>> A = jnp.array([[1., 2.], [3., 4.]])
>>> def func(x):
... return jnp.sin(x) + 2 * jnp.cos(x)
>>> jax.scipy.linalg.funm(A, func) # doctest: +SKIP
Array([[ 1.2452652 +0.j, -0.3701772 +0.j],
[-0.55526584+0.j, 0.6899995 +0.j]], dtype=complex64)
Comparing two ways of computing the matrix exponent:
>>> expA_1 = jax.scipy.linalg.funm(A, jnp.exp)
>>> expA_2 = jax.scipy.linalg.expm(A)
>>> jnp.allclose(expA_1, expA_2, rtol=1E-4)
Array(True, dtype=bool)
"""
A_arr = jnp.asarray(A)
if A_arr.ndim != 2 or A_arr.shape[0] != A_arr.shape[1]:
raise ValueError('expected square array_like input')
T, Z = schur(A_arr)
T, Z = rsf2csf(T, Z)
F = jnp.diag(func(jnp.diag(T)))
F = F.astype(T.dtype.char)
F, minden = _algorithm_11_1_1(F, T)
F = Z @ F @ Z.conj().T
if disp:
return F
if F.dtype.char.lower() == 'e':
tol = dtypes.finfo(np.float16).eps
if F.dtype.char.lower() == 'f':
tol = dtypes.finfo(np.float32).eps
else:
tol = dtypes.finfo(np.float64).eps
minden = jnp.where(minden == 0.0, tol, minden)
err = jnp.where(jnp.any(jnp.isinf(F)), np.inf, jnp.minimum(1, jnp.maximum(
tol, (tol / minden) * norm(jnp.triu(T, 1), 1))))
return F, err
@@ -0,0 +1,93 @@
"""Utility functions adopted from scipy.signal."""
from __future__ import annotations
from typing import Any
import warnings
import numpy as np
from jax._src import numpy as jnp
from jax._src.typing import Array, ArrayLike, DTypeLike
def _triage_segments(window: ArrayLike | str | tuple[Any, ...], nperseg: int | None,
input_length: int, dtype: DTypeLike) -> tuple[Array, int]:
"""
Parses window and nperseg arguments for spectrogram and _spectral_helper.
This is a helper function, not meant to be called externally.
Args:
window : string, tuple, or ndarray
If window is specified by a string or tuple and nperseg is not
specified, nperseg is set to the default of 256 and returns a window of
that length.
If instead the window is array_like and nperseg is not specified, then
nperseg is set to the length of the window. A ValueError is raised if
the user supplies both an array_like window and a value for nperseg but
nperseg does not equal the length of the window.
nperseg : int
Length of each segment
input_length: int
Length of input signal, i.e. x.shape[-1]. Used to test for errors.
dtype: dtype for window if specified as a string or tuple. Not referenced
if window is an array.
Returns:
win : ndarray
window. If function was called with string or tuple than this will hold
the actual array used as a window.
nperseg : int
Length of each segment. If window is str or tuple, nperseg is set to
256. If window is array_like, nperseg is set to the length of the window.
"""
if isinstance(window, (str, tuple)):
nperseg_int = input_length if nperseg is None else int(nperseg)
if nperseg_int > input_length:
warnings.warn(f'nperseg={nperseg_int} is greater than {input_length=},'
f' using nperseg={input_length}')
nperseg_int = input_length
if window == 'hann':
# Implement the default case without scipy
win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, np.pi, nperseg_int, endpoint=False)) ** 2
else:
# TODO(jakevdp): implement get_window() in JAX to remove optional scipy dependency
try:
from scipy.signal import get_window # pyrefly: ignore[missing-import]
except ImportError as err:
raise ImportError(f"scipy must be available to use {window=}") from err
win = get_window(window, nperseg_int)
win = jnp.array(win, dtype=dtype)
else:
win = jnp.asarray(window, dtype=dtype)
nperseg_int = win.size if nperseg is None else int(nperseg)
if win.ndim != 1:
raise ValueError('window must be 1-D')
if input_length < win.size:
raise ValueError('window is longer than input signal')
if nperseg_int != win.size:
raise ValueError("value specified for nperseg is different from length of window")
return win, nperseg_int
def _median_bias(n: int) -> Array:
"""
Returns the bias of the median of a set of periodograms relative to
the mean. See Appendix B from [1]_ for details.
Args:
n : int
Numbers of periodograms being averaged.
Returns:
bias : float
Calculated bias.
References:
.. [1] B. Allen, W.G. Anderson, P.R. Brady, D.A. Brown, J.D.E. Creighton.
"FINDCHIRP: an algorithm for detection of gravitational waves from
inspiraling compact binaries", Physical Review D 85, 2012,
:arxiv:`gr-qc/0509116`
"""
ii_2 = jnp.arange(2., n, 2)
return 1 + jnp.sum(1. / (ii_2 + 1) - 1. / ii_2)
@@ -0,0 +1,327 @@
from __future__ import annotations
import numpy as np
from jax._src import api
from jax._src import numpy as jnp
from jax._src import custom_derivatives, dtypes
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike
@api.jit
def sincospisquaredhalf(
x: Array,
) -> tuple[Array, Array]:
"""
Accurate evaluation of sin(pi * x**2 / 2) and cos(pi * x**2 / 2).
As based on the sinpi and cospi functions from SciPy, see:
- https://github.com/scipy/scipy/blob/v1.14.0/scipy/special/special/cephes/trig.h
"""
x = jnp.abs(x)
# define s = x % 2, y = x - s, then
# r = (x * x / 2) % 2
# = [(y + s)*(y + s)/2] % 2
# = [y*y/2 + s*y + s*s/2] % 2
# = [(y*y/2)%2 + (s*y + s*s/2)%2]%2
# = [0 + (s*(y+s/2))%2]%2
# = [s*(x-s/2)]%2
s = jnp.fmod(x, 2.0)
r = jnp.fmod(s * (x - s / 2), 2.0)
sinpi = jnp.where(
r < 0.5,
jnp.sin(np.pi * r),
jnp.where(
r > 1.5,
jnp.sin(np.pi * (r - 2.0)),
-jnp.sin(np.pi * (r - 1.0)),
),
)
cospi = jnp.where(
r == 0.5,
0.0,
jnp.where(r < 1.0, -jnp.sin(np.pi * (r - 0.5)), jnp.sin(np.pi * (r - 1.5))),
)
return sinpi, cospi
@custom_derivatives.custom_jvp
def fresnel(x: ArrayLike) -> tuple[Array, Array]:
r"""The Fresnel integrals
JAX implementation of :obj:`scipy.special.fresnel`.
The Fresnel integrals are defined as
.. math::
S(x) &= \int_0^x \sin(\pi t^2 /2) dt \\
C(x) &= \int_0^x \cos(\pi t^2 /2) dt.
Args:
x: arraylike, real-valued.
Returns:
Arrays containing the values of the Fresnel integrals.
Notes:
The JAX version only supports real-valued inputs, and
is based on the SciPy C++ implementation, see
`here
<https://github.com/scipy/scipy/blob/v1.14.0/scipy/special/special/cephes/fresnl.h>`_.
For ``float32`` dtypes, the implementation is directly based
on the Cephes implementation ``fresnlf``.
As for the original Cephes implementation, the accuracy
is only guaranteed in the domain [-10, 10]. Outside of
that domain, one could observe divergence between the
theoretical derivatives and the custom JVP implementation,
especially for large input values.
Finally, for half-precision data types, ``float16``
and ``bfloat16``, the array elements are upcasted to
``float32`` as the Cephes coefficients used in
series expansions would otherwise lead to poor results.
Other data types, like ``float8``, are not supported.
"""
xxa, = promote_args_inexact("fresnel", x)
original_dtype = xxa.dtype
# This part is mostly a direct translation of SciPy's C++ code,
# and the original Cephes implementation for single precision.
if dtypes.issubdtype(xxa.dtype, np.complexfloating):
raise NotImplementedError(
'Support for complex-valued inputs is not implemented yet.')
elif xxa.dtype in (np.float32, np.float16, dtypes.bfloat16):
# Single-precision Cephes coefficients
# For half-precision, series expansions have either
# produce overflow or poor accuracy.
# Upcasting to single-precision is hence needed.
xxa = xxa.astype(np.float32) # No-op for float32
fresnl_sn = jnp.array([
+1.647629463788700e-9,
-1.522754752581096e-7,
+8.424748808502400e-6,
-3.120693124703272e-4,
+7.244727626597022e-3,
-9.228055941124598e-2,
+5.235987735681432e-1,
], dtype=np.float32)
fresnl_cn = jnp.array([
+1.416802502367354e-8,
-1.157231412229871e-6,
+5.387223446683264e-5,
-1.604381798862293e-3,
+2.818489036795073e-2,
-2.467398198317899e-1,
+9.999999760004487e-1,
], dtype=np.float32)
fresnl_fn = jnp.array([
-1.903009855649792e12,
+1.355942388050252e11,
-4.158143148511033e9,
+7.343848463587323e7,
-8.732356681548485e5,
+8.560515466275470e3,
-1.032877601091159e2,
+2.999401847870011e0,
], dtype=np.float32)
fresnl_gn = jnp.array([
-1.860843997624650e11,
+1.278350673393208e10,
-3.779387713202229e8,
+6.492611570598858e6,
-7.787789623358162e4,
+8.602931494734327e2,
-1.493439396592284e1,
+9.999841934744914e-1,
], dtype=np.float32)
fresnl_cd = jnp.empty(0)
fresnl_sd = jnp.empty(0)
fresnl_fd = jnp.empty(0)
fresnl_gd = jnp.empty(0)
elif xxa.dtype == np.float64:
# Double-precision Cephes coefficients
fresnl_sn = jnp.array([
-2.99181919401019853726e3,
+7.08840045257738576863e5,
-6.29741486205862506537e7,
+2.54890880573376359104e9,
-4.42979518059697779103e10,
+3.18016297876567817986e11,
], dtype=np.float64)
fresnl_sd = jnp.array([
+1.00000000000000000000e0,
+2.81376268889994315696e2,
+4.55847810806532581675e4,
+5.17343888770096400730e6,
+4.19320245898111231129e8,
+2.24411795645340920940e10,
+6.07366389490084639049e11,
], dtype=np.float64)
fresnl_cn = jnp.array([
-4.98843114573573548651e-8,
+9.50428062829859605134e-6,
-6.45191435683965050962e-4,
+1.88843319396703850064e-2,
-2.05525900955013891793e-1,
+9.99999999999999998822e-1,
], dtype=np.float64)
fresnl_cd = jnp.array([
+3.99982968972495980367e-12,
+9.15439215774657478799e-10,
+1.25001862479598821474e-7,
+1.22262789024179030997e-5,
+8.68029542941784300606e-4,
+4.12142090722199792936e-2,
+1.00000000000000000118e0,
], dtype=np.float64)
fresnl_fn = jnp.array([
+4.21543555043677546506e-1,
+1.43407919780758885261e-1,
+1.15220955073585758835e-2,
+3.45017939782574027900e-4,
+4.63613749287867322088e-6,
+3.05568983790257605827e-8,
+1.02304514164907233465e-10,
+1.72010743268161828879e-13,
+1.34283276233062758925e-16,
+3.76329711269987889006e-20,
], dtype=np.float64)
fresnl_fd = jnp.array([
+1.00000000000000000000e0,
+7.51586398353378947175e-1,
+1.16888925859191382142e-1,
+6.44051526508858611005e-3,
+1.55934409164153020873e-4,
+1.84627567348930545870e-6,
+1.12699224763999035261e-8,
+3.60140029589371370404e-11,
+5.88754533621578410010e-14,
+4.52001434074129701496e-17,
+1.25443237090011264384e-20,
], dtype=np.float64)
fresnl_gn = jnp.array([
+5.04442073643383265887e-1,
+1.97102833525523411709e-1,
+1.87648584092575249293e-2,
+6.84079380915393090172e-4,
+1.15138826111884280931e-5,
+9.82852443688422223854e-8,
+4.45344415861750144738e-10,
+1.08268041139020870318e-12,
+1.37555460633261799868e-15,
+8.36354435630677421531e-19,
+1.86958710162783235106e-22,
], dtype=np.float64)
fresnl_gd = jnp.array([
+1.00000000000000000000e0,
+1.47495759925128324529e0,
+3.37748989120019970451e-1,
+2.53603741420338795122e-2,
+8.14679107184306179049e-4,
+1.27545075667729118702e-5,
+1.04314589657571990585e-7,
+4.60680728146520428211e-10,
+1.10273215066240270757e-12,
+1.38796531259578871258e-15,
+8.39158816283118707363e-19,
+1.86958710162783236342e-22,
], dtype=np.float64)
else:
raise NotImplementedError(
f'Support for {xxa.dtype} dtype is not implemented yet.')
assert xxa.dtype in (np.float32, np.float64)
single_precision = (xxa.dtype == np.float32)
x = jnp.abs(xxa)
x2 = x * x
# Infinite x values
s_inf = c_inf = 0.5
# Small x values
t = x2 * x2
if single_precision:
s_small = x * x2 * jnp.polyval(fresnl_sn, t)
c_small = x * jnp.polyval(fresnl_cn, t)
else:
s_small = x * x2 * jnp.polyval(fresnl_sn[:6], t) / jnp.polyval(fresnl_sd[:7], t)
c_small = x * jnp.polyval(fresnl_cn[:6], t) / jnp.polyval(fresnl_cd[:7], t)
# Large x values
sinpi, cospi = sincospisquaredhalf(x)
if single_precision:
c_large = c_inf
s_large = s_inf
else:
c_large = 0.5 + 1 / (np.pi * x) * sinpi
s_large = 0.5 - 1 / (np.pi * x) * cospi
# Other x values
t = np.pi * x2
u = 1.0 / (t * t)
t = 1.0 / t
if single_precision:
f = 1.0 - u * jnp.polyval(fresnl_fn, u)
g = t * jnp.polyval(fresnl_gn, u)
else:
f = 1.0 - u * jnp.polyval(fresnl_fn, u) / jnp.polyval(fresnl_fd, u)
g = t * jnp.polyval(fresnl_gn, u) / jnp.polyval(fresnl_gd, u)
t = np.pi * x
c_other = 0.5 + (f * sinpi - g * cospi) / t
s_other = 0.5 - (f * cospi + g * sinpi) / t
isinf = jnp.isinf(xxa)
small = x2 < 2.5625
large = x > 36974.0
s = jnp.where(
isinf, s_inf, jnp.where(small, s_small, jnp.where(large, s_large, s_other))
)
c = jnp.where(
isinf, c_inf, jnp.where(small, c_small, jnp.where(large, c_large, c_other))
)
neg = xxa < 0.0
s = jnp.where(neg, -s, s)
c = jnp.where(neg, -c, c)
if original_dtype != xxa.dtype:
s = s.astype(original_dtype)
c = c.astype(original_dtype)
return s, c
def _fresnel_jvp(primals, tangents):
x, = primals
x_dot, = tangents
result = fresnel(x)
sinpi, cospi = sincospisquaredhalf(x)
dSdx = sinpi * x_dot
dCdx = cospi * x_dot
return result, (dSdx, dCdx)
fresnel.defjvp(_fresnel_jvp)