hand
This commit is contained in:
@@ -0,0 +1,502 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
import math
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy import fft as jnp_fft
|
||||
from jax._src.numpy.util import (
|
||||
promote_dtypes_complex, promote_dtypes_inexact, ensure_arraylike)
|
||||
from jax._src.util import canonicalize_axis, canonicalize_axis_tuple
|
||||
from jax._src.typing import Array
|
||||
|
||||
def _W4(N: int, k: Array) -> Array:
|
||||
N_arr, k = promote_dtypes_complex(N, k)
|
||||
return jnp.exp(-.5j * np.pi * k / N_arr)
|
||||
|
||||
def _dct_interleave(x: Array, axis: int) -> Array:
|
||||
v0 = lax.slice_in_dim(x, None, None, 2, axis)
|
||||
v1 = lax.rev(lax.slice_in_dim(x, 1, None, 2, axis), (axis,))
|
||||
return lax.concatenate([v0, v1], axis)
|
||||
|
||||
def _dct_ortho_norm(out: Array, axis: int) -> Array:
|
||||
factor = lax.concatenate([lax.full((1,), 4, out.dtype), lax.full((out.shape[axis] - 1,), 2, out.dtype)], 0)
|
||||
factor = lax.expand_dims(factor, [a for a in range(out.ndim) if a != axis])
|
||||
return out / lax.sqrt(factor * out.shape[axis])
|
||||
|
||||
# Implementation based on
|
||||
# John Makhoul: A Fast Cosine Transform in One and Two Dimensions (1980)
|
||||
|
||||
|
||||
def dct(x: Array, type: int = 2, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
"""Computes the discrete cosine transform of the input
|
||||
|
||||
JAX implementation of :func:`scipy.fft.dct`.
|
||||
|
||||
Args:
|
||||
x: array
|
||||
type: integer, default = 2. Currently only type 2 is supported.
|
||||
n: integer, default = x.shape[axis]. The length of the transform.
|
||||
If larger than ``x.shape[axis]``, the input will be zero-padded, if
|
||||
smaller, the input will be truncated.
|
||||
axis: integer, default=-1. The axis along which the dct will be performed.
|
||||
norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
|
||||
The default is ``None``, which is equivalent to ``"backward"``.
|
||||
|
||||
Returns:
|
||||
array containing the discrete cosine transform of x
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
|
||||
- :func:`jax.scipy.fft.idct`: inverse DCT
|
||||
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
|
||||
|
||||
Examples:
|
||||
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dct(x))
|
||||
[[ 6.43 3.56 -2.86]
|
||||
[-1.75 1.55 -1.4 ]
|
||||
[ 1.33 -2.01 -0.82]]
|
||||
|
||||
When ``n`` smaller than ``x.shape[axis]``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dct(x, n=2))
|
||||
[[ 7.3 -0.57]
|
||||
[ 0.19 -0.36]
|
||||
[-0. -1.4 ]]
|
||||
|
||||
When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dct(x, n=2, axis=0))
|
||||
[[ 3.09 4.4 -2.81]
|
||||
[ 2.41 2.62 0.76]]
|
||||
|
||||
When ``n`` larger than ``x.shape[axis]`` and ``axis=1``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dct(x, n=4, axis=1))
|
||||
[[ 6.43 4.88 0.04 -3.3 ]
|
||||
[-1.75 0.73 1.01 -2.18]
|
||||
[ 1.33 -1.05 -2.34 -0.07]]
|
||||
"""
|
||||
x = ensure_arraylike("dct", x)
|
||||
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
if norm is not None and norm not in ['backward', 'ortho']:
|
||||
raise ValueError(f"jax.scipy.fft.dct: {norm=!r} is not implemented")
|
||||
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
return lax.complex(
|
||||
dct(x.real, type=type, n=n, norm=norm, axis=axis),
|
||||
dct(x.imag, type=type, n=n, norm=norm, axis=axis),
|
||||
)
|
||||
|
||||
axis = canonicalize_axis(axis, x.ndim)
|
||||
if n is not None:
|
||||
x = lax.pad(x, jnp.array(0, x.dtype),
|
||||
[(0, n - x.shape[axis] if a == axis else 0, 0)
|
||||
for a in range(x.ndim)])
|
||||
|
||||
N = x.shape[axis]
|
||||
v = _dct_interleave(x, axis)
|
||||
V = jnp_fft.fft(v, axis=axis)
|
||||
k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis])
|
||||
out = V * _W4(N, k)
|
||||
out = 2 * out.real
|
||||
if norm == 'ortho':
|
||||
out = _dct_ortho_norm(out, axis)
|
||||
return out
|
||||
|
||||
|
||||
def _dct2(x: Array, axes: Sequence[int], norm: str | None) -> Array:
|
||||
axis1, axis2 = map(partial(canonicalize_axis, num_dims=x.ndim), axes)
|
||||
N1, N2 = x.shape[axis1], x.shape[axis2]
|
||||
v = _dct_interleave(_dct_interleave(x, axis1), axis2)
|
||||
V = jnp_fft.fftn(v, axes=axes)
|
||||
k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype),
|
||||
[a for a in range(x.ndim) if a != axis1])
|
||||
k2 = lax.expand_dims(jnp.arange(N2, dtype=V.dtype),
|
||||
[a for a in range(x.ndim) if a != axis2])
|
||||
out = _W4(N1, k1) * (_W4(N2, k2) * V + _W4(N2, -k2) * jnp.roll(jnp.flip(V, axis=axis2), shift=1, axis=axis2))
|
||||
out = 2 * out.real
|
||||
if norm == 'ortho':
|
||||
return _dct_ortho_norm(_dct_ortho_norm(out, axis1), axis2)
|
||||
return out
|
||||
|
||||
|
||||
def dctn(x: Array, type: int = 2,
|
||||
s: Sequence[int] | None=None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
"""Computes the multidimensional discrete cosine transform of the input
|
||||
|
||||
JAX implementation of :func:`scipy.fft.dctn`.
|
||||
|
||||
Args:
|
||||
x: array
|
||||
type: integer, default = 2. Currently only type 2 is supported.
|
||||
s: integer or sequence of integers. Specifies the shape of the result. If
|
||||
not specified, it will default to the shape of ``x`` along the specified
|
||||
``axes``.
|
||||
axes: integer or sequence of integers. Specifies the axes along which the
|
||||
transform will be computed. If not given, the last ``len(s)`` axes are
|
||||
used, or all axes if ``s`` is also not specified.
|
||||
norm: string. The normalization mode: one of
|
||||
``[None, "backward", "ortho"]``. The default is ``None``, which is
|
||||
equivalent to ``"backward"``.
|
||||
|
||||
Returns:
|
||||
array containing the discrete cosine transform of x
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
|
||||
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
|
||||
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
|
||||
|
||||
Examples:
|
||||
|
||||
``jax.scipy.fft.dctn`` computes the transform along both the axes by default
|
||||
when ``axes`` argument is ``None`` and ``s`` is also ``None``.
|
||||
|
||||
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dctn(x))
|
||||
[[ 12.01 6.2 -10.17]
|
||||
[ 8.84 9.65 -3.54]
|
||||
[ 11.25 -1.54 -0.88]]
|
||||
|
||||
When ``s=[2]``, the transform will be computed only along the last axis,
|
||||
with its dimension padded or truncated to size ``2``:
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dctn(x, s=[2]))
|
||||
[[ 7.3 -0.57]
|
||||
[ 0.19 -0.36]
|
||||
[-0. -1.4 ]]
|
||||
|
||||
When ``s=[2]`` and ``axes=[0]``, the transform will be computed only along
|
||||
the specified axis, with its dimension padded or truncated to size ``2``:
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dctn(x, s=[2], axes=[0]))
|
||||
[[ 3.09 4.4 -2.81]
|
||||
[ 2.41 2.62 0.76]]
|
||||
|
||||
When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``.
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.dctn(x, s=[2, 4]))
|
||||
[[ 9.36 11.23 2.12 -10.97]
|
||||
[ 11.57 5.86 -1.37 -1.58]]
|
||||
"""
|
||||
x = ensure_arraylike("dctn", x)
|
||||
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
if norm is not None and norm not in ['backward', 'ortho']:
|
||||
raise ValueError(f"jax.scipy.fft.dctn: {norm=!r} is not implemented")
|
||||
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
return lax.complex(
|
||||
dctn(x.real, type=type, s=s, norm=norm, axes=axes),
|
||||
dctn(x.imag, type=type, s=s, norm=norm, axes=axes),
|
||||
)
|
||||
|
||||
if s is not None:
|
||||
try:
|
||||
s = list(s)
|
||||
except TypeError:
|
||||
assert not isinstance(s, Sequence)
|
||||
s = [operator.index(s)]
|
||||
if len(s) > x.ndim:
|
||||
raise ValueError(
|
||||
f"s must have at most x.ndim ({x.ndim}) elements, got {len(s)}"
|
||||
)
|
||||
|
||||
if axes is None:
|
||||
if s is not None:
|
||||
axes = tuple(range(x.ndim - len(s), x.ndim))
|
||||
else:
|
||||
axes = tuple(range(x.ndim))
|
||||
else:
|
||||
axes = canonicalize_axis_tuple(axes, x.ndim)
|
||||
|
||||
if len(axes) == 1:
|
||||
return dct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)
|
||||
|
||||
if s is not None:
|
||||
ns = dict(zip(axes, s))
|
||||
pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
|
||||
x = lax.pad(x, jnp.array(0, x.dtype), pads)
|
||||
|
||||
if len(axes) == 2:
|
||||
return _dct2(x, axes=axes, norm=norm)
|
||||
|
||||
# compose high-D DCTs from 2D and 1D DCTs:
|
||||
for axes_block in [axes[i:i+2] for i in range(0, len(axes), 2)]:
|
||||
x = dctn(x, axes=axes_block, norm=norm)
|
||||
return x
|
||||
|
||||
|
||||
def idct(x: Array, type: int = 2, n: int | None = None,
|
||||
axis: int = -1, norm: str | None = None) -> Array:
|
||||
"""Computes the inverse discrete cosine transform of the input
|
||||
|
||||
JAX implementation of :func:`scipy.fft.idct`.
|
||||
|
||||
Args:
|
||||
x: array
|
||||
type: integer, default = 2. Currently only type 2 is supported.
|
||||
n: integer, default = x.shape[axis]. The length of the transform.
|
||||
If larger than ``x.shape[axis]``, the input will be zero-padded, if
|
||||
smaller, the input will be truncated.
|
||||
axis: integer, default=-1. The axis along which the dct will be performed.
|
||||
norm: string. The normalization mode: one of ``[None, "backward", "ortho"]``.
|
||||
The default is ``None``, which is equivalent to ``"backward"``.
|
||||
|
||||
Returns:
|
||||
array containing the inverse discrete cosine transform of x
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.fft.dct`: DCT
|
||||
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
|
||||
- :func:`jax.scipy.fft.idctn`: multidimensional inverse DCT
|
||||
|
||||
Examples:
|
||||
|
||||
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idct(x))
|
||||
[[ 0.78 0.41 -0.39]
|
||||
[-0.12 0.31 -0.23]
|
||||
[ 0.17 -0.3 -0.11]]
|
||||
|
||||
When ``n`` smaller than ``x.shape[axis]``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idct(x, n=2))
|
||||
[[ 1.12 -0.31]
|
||||
[ 0.04 -0.08]
|
||||
[ 0.05 -0.3 ]]
|
||||
|
||||
When ``n`` smaller than ``x.shape[axis]`` and ``axis=0``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idct(x, n=2, axis=0))
|
||||
[[ 0.38 0.57 -0.45]
|
||||
[ 0.43 0.44 0.24]]
|
||||
|
||||
When ``n`` larger than ``x.shape[axis]`` and ``axis=0``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idct(x, n=4, axis=0))
|
||||
[[ 0.1 0.38 -0.16]
|
||||
[ 0.28 0.18 -0.26]
|
||||
[ 0.3 0.15 -0.08]
|
||||
[ 0.13 0.3 0.29]]
|
||||
|
||||
``jax.scipy.fft.idct`` can be used to reconstruct ``x`` from the result
|
||||
of ``jax.scipy.fft.dct``
|
||||
|
||||
>>> x_dct = jax.scipy.fft.dct(x)
|
||||
>>> jnp.allclose(x, jax.scipy.fft.idct(x_dct))
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
x = ensure_arraylike("idct", x)
|
||||
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
if norm is not None and norm not in ['backward', 'ortho']:
|
||||
raise ValueError(f"jax.scipy.fft.idct: {norm=!r} is not implemented")
|
||||
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
return lax.complex(
|
||||
idct(x.real, type=type, n=n, norm=norm, axis=axis),
|
||||
idct(x.imag, type=type, n=n, norm=norm, axis=axis)
|
||||
)
|
||||
|
||||
axis = canonicalize_axis(axis, x.ndim)
|
||||
if n is not None:
|
||||
x = lax.pad(x, jnp.array(0, x.dtype),
|
||||
[(0, n - x.shape[axis] if a == axis else 0, 0)
|
||||
for a in range(x.ndim)])
|
||||
N = x.shape[axis]
|
||||
x, = promote_dtypes_inexact(x)
|
||||
if norm is None or norm == 'backward':
|
||||
x = _dct_ortho_norm(x, axis)
|
||||
x = _dct_ortho_norm(x, axis)
|
||||
|
||||
k = lax.expand_dims(jnp.arange(N, dtype=x.dtype), [a for a in range(x.ndim) if a != axis])
|
||||
# everything is complex from here...
|
||||
w4 = _W4(N,k)
|
||||
x = x.astype(w4.dtype)
|
||||
x = x / (_W4(N, k))
|
||||
x = x * 2 * N
|
||||
|
||||
x = jnp_fft.ifft(x, axis=axis)
|
||||
# convert back to reals..
|
||||
out = _dct_deinterleave(x.real, axis)
|
||||
return out
|
||||
|
||||
|
||||
def idctn(x: Array, type: int = 2,
|
||||
s: Sequence[int] | None=None,
|
||||
axes: Sequence[int] | None = None,
|
||||
norm: str | None = None) -> Array:
|
||||
"""Computes the multidimensional inverse discrete cosine transform of the input
|
||||
|
||||
JAX implementation of :func:`scipy.fft.idctn`.
|
||||
|
||||
Args:
|
||||
x: array
|
||||
type: integer, default = 2. Currently only type 2 is supported.
|
||||
s: integer or sequence of integers. Specifies the shape of the result. If
|
||||
not specified, it will default to the shape of ``x`` along the specified
|
||||
``axes``.
|
||||
axes: integer or sequence of integers. Specifies the axes along which the
|
||||
transform will be computed. If not given, the last ``len(s)`` axes are
|
||||
used, or all axes if ``s`` is also not specified.
|
||||
norm: string. The normalization mode: one of
|
||||
``[None, "backward", "ortho"]``. The default is ``None``, which is
|
||||
equivalent to ``"backward"``.
|
||||
|
||||
Returns:
|
||||
array containing the inverse discrete cosine transform of x
|
||||
|
||||
See Also:
|
||||
- :func:`jax.scipy.fft.dct`: one-dimensional DCT
|
||||
- :func:`jax.scipy.fft.dctn`: multidimensional DCT
|
||||
- :func:`jax.scipy.fft.idct`: one-dimensional inverse DCT
|
||||
|
||||
Examples:
|
||||
|
||||
``jax.scipy.fft.idctn`` computes the transform along both the axes by
|
||||
default when ``axes`` argument is ``None`` and ``s`` is also ``None``.
|
||||
|
||||
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idctn(x))
|
||||
[[ 0.12 0.11 -0.15]
|
||||
[ 0.07 0.17 -0.03]
|
||||
[ 0.19 -0.07 -0.02]]
|
||||
|
||||
When ``s=[2]``, the transform will be computed only along the last axis,
|
||||
with its dimension padded or truncated to size ``2``:
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idctn(x, s=[2]))
|
||||
[[ 1.12 -0.31]
|
||||
[ 0.04 -0.08]
|
||||
[ 0.05 -0.3 ]]
|
||||
|
||||
When ``s=[2]`` and ``axes=[0]``, the transform will be computed only along
|
||||
the specified axis, with its dimension padded or truncated to size ``2``:
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idctn(x, s=[2], axes=[0]))
|
||||
[[ 0.38 0.57 -0.45]
|
||||
[ 0.43 0.44 0.24]]
|
||||
|
||||
When ``s=[2, 4]``, shape of the transform will be ``(2, 4)``
|
||||
|
||||
>>> with jnp.printoptions(precision=2, suppress=True):
|
||||
... print(jax.scipy.fft.idctn(x, s=[2, 4]))
|
||||
[[ 0.1 0.18 0.07 -0.16]
|
||||
[ 0.2 0.06 -0.03 -0.01]]
|
||||
|
||||
``jax.scipy.fft.idctn`` can be used to reconstruct ``x`` from the result
|
||||
of ``jax.scipy.fft.dctn``
|
||||
|
||||
>>> x_dctn = jax.scipy.fft.dctn(x)
|
||||
>>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
x = ensure_arraylike("idctn", x)
|
||||
|
||||
if type != 2:
|
||||
raise NotImplementedError('Only DCT type 2 is implemented.')
|
||||
if norm is not None and norm not in ['backward', 'ortho']:
|
||||
raise ValueError(f"jax.scipy.fft.idctn: {norm=!r} is not implemented")
|
||||
|
||||
if dtypes.issubdtype(x.dtype, np.complexfloating):
|
||||
return lax.complex(
|
||||
idctn(x.real, type=type, s=s, norm=norm, axes=axes),
|
||||
idctn(x.imag, type=type, s=s, norm=norm, axes=axes)
|
||||
)
|
||||
|
||||
if s is not None:
|
||||
try:
|
||||
s = list(s)
|
||||
except TypeError:
|
||||
assert not isinstance(s, Sequence)
|
||||
s = [operator.index(s)]
|
||||
if len(s) > x.ndim:
|
||||
raise ValueError(
|
||||
f"s must have at most x.ndim ({x.ndim}) elements, got {len(s)}"
|
||||
)
|
||||
|
||||
if axes is None:
|
||||
if s is not None:
|
||||
axes = tuple(range(x.ndim - len(s), x.ndim))
|
||||
else:
|
||||
axes = tuple(range(x.ndim))
|
||||
else:
|
||||
axes = canonicalize_axis_tuple(axes, x.ndim)
|
||||
|
||||
if len(axes) == 1:
|
||||
return idct(x, n=s[0] if s is not None else None, axis=axes[0], norm=norm)
|
||||
|
||||
if s is not None:
|
||||
ns = dict(zip(axes, s))
|
||||
pads = [(0, ns[a] - x.shape[a] if a in ns else 0, 0) for a in range(x.ndim)]
|
||||
x = lax.pad(x, jnp.array(0, x.dtype), pads)
|
||||
|
||||
# compose high-D DCTs from 1D DCTs:
|
||||
for axis in axes:
|
||||
x = idct(x, axis=axis, norm=norm)
|
||||
return x
|
||||
|
||||
|
||||
def _dct_deinterleave(x: Array, axis: int) -> Array:
|
||||
empty_slice = slice(None, None, None)
|
||||
ix0 = tuple(
|
||||
slice(None, math.ceil(x.shape[axis]/2), 1) if i == axis else empty_slice
|
||||
for i in range(len(x.shape)))
|
||||
ix1 = tuple(
|
||||
slice(math.ceil(x.shape[axis]/2), None, 1) if i == axis else empty_slice
|
||||
for i in range(len(x.shape)))
|
||||
v0 = x[ix0]
|
||||
v1 = lax.rev(x[ix1], (axis,))
|
||||
out = jnp.zeros(x.shape, dtype=x.dtype)
|
||||
evens = tuple(
|
||||
slice(None, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
|
||||
odds = tuple(
|
||||
slice(1, None, 2) if i == axis else empty_slice for i in range(len(x.shape)))
|
||||
out = out.at[evens].set(v0)
|
||||
out = out.at[odds].set(v1)
|
||||
return out
|
||||
Reference in New Issue
Block a user