2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,24 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from jax._src.tpu.linalg import (
eigh as eigh,
qdwh as qdwh,
svd as svd,
)
from jax._src import traceback_util
traceback_util.register_exclusion(os.path.dirname(__file__))
@@ -0,0 +1,699 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Symmetric (Hermitian) eigendecomposition using QDWH.
References:
Nakatsukasa, Yuji, and Nicholas J. Higham.
"Stable and efficient spectral divide and conquer algorithms for the symmetric
eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing 35,
no. 3 (2013): A1325-A1349.
https://epubs.siam.org/doi/abs/10.1137/120876605
This implementation is primarily used on TPU, but it can in principle work on
CPU and GPU also.
"""
from __future__ import annotations
from functools import partial
from typing import NamedTuple
import numpy as np
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.interpreters import mlir
from jax._src.lax import control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
from jax._src.lax.linalg import is_constant_shape
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import linalg as jnp_linalg
from jax._src.numpy import tensor_contractions
from jax._src.numpy import reductions
from jax._src.numpy import ufuncs
from jax._src.tpu.linalg import qdwh
from jax._src.tpu.linalg.stack import Stack
from jax._src.typing import Array
# QDWH-eigh is a recursive algorithm where the structure of the recursion
# is determined by the eigenspectrum. Neither JAX nor XLA can handle this kind
# of recursion, so we instead express the recursion as iteration using an
# explicit stack.
# TODO(phawkins): consider extracting _mask/_slice/_update_slice into a
# separate module.
def _round_up(i, n):
return ((i+n-1) // n) * n
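# Example (illustrative): rounding up to a multiple of the bucket granularity.
#   _round_up(33, 32)  # == 64
#   _round_up(32, 32)  # == 32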
def _mask(x, dims, alternative=0):
"""Masks `x` up to the dynamic shape `dims`.
Replaces values outside those dimensions with `alternative`. `alternative` is
broadcast with `x`.
"""
assert np.ndim(x) == len(dims)
mask: Array | None = None
for i, d in enumerate(dims):
if d is not None:
mask_dim_i = lax.broadcasted_iota(np.int32, x.shape, i) < d
mask = mask_dim_i if mask is None else (mask & mask_dim_i)
return x if mask is None else jnp.where(mask, x, alternative)
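# Example (illustrative): masking a 4 x 4 array down to a dynamic 2 x 3 block.
# Entries with row index >= 2 or column index >= 3 become `alternative`:
#   x = jnp.ones((4, 4))
#   _mask(x, (2, 3))      # top-left 2 x 3 block stays 1, the rest becomes 0
#   _mask(x, (2, None))   # None skips masking along that dimension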
def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes,
fill_value=0):
"""Similar to lax.dynamic_slice, but handles arrays with dynamic sizes.
Returns fill_value instead of clamping start_indices for those elements that
would fall outside the array.
Args:
operand: the array to slice
start_indices: the offset of the start of the slice
dynamic_slice_sizes: the true (unpadded) size of the slice
static_slice_sizes: the padded size of the slice, which must be known at
compile time. The static size must be at least as large as the dynamic size.
fill_value: value with which to replace masked-out elements.
Returns:
An array with static shape `static_slice_sizes`, padded from its true
(dynamic) size `dynamic_slice_sizes`.
"""
# We must pad the input array so the dynamic_slice is guaranteed to fall
# entirely in bounds.
padded = lax.pad(operand,
jnp.array(0, operand.dtype),
[(0, d, 0) for d in static_slice_sizes])
out = lax.dynamic_slice(padded, tuple(jnp.array(i, np.int32) for i in start_indices),
static_slice_sizes)
return _mask(out, dynamic_slice_sizes, fill_value)
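# Example (illustrative): extract a dynamic 2 x 2 block starting at (1, 1) from
# a 4 x 4 array, padded out to a static 3 x 3 result. Only the top-left 2 x 2
# entries of the result are valid; the remainder holds `fill_value`:
#   _slice(x, (1, 1), (2, 2), (3, 3))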
def _update_slice(operand, update, start_indices, update_dims):
"""
Similar to lax.dynamic_update_slice, but handles padded updates where padding
values should not overwrite existing values in the array.
Args:
operand: the array to update
update: the padded array to write
start_indices: the offset at which to write `update`.
update_dims: the true dimensions of the padded update `update`. Only values
inside the rectangle given by `update_dims` will be overwritten."""
operand_shape = operand.shape
operand = lax.pad(operand,
jnp.array(0, operand.dtype),
[(0, d, 0) for d in update.shape])
start_indices = tuple(jnp.array(i, np.int32) for i in start_indices)
t = lax.dynamic_slice(operand, start_indices, update.shape)
t = _mask(update, update_dims, t)
operand = lax.dynamic_update_slice(operand, t, start_indices)
return lax.slice(operand, [0] * operand.ndim, operand_shape)
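# Example (illustrative): write the valid 2 x 2 corner of a padded 3 x 3
# `update` into `operand` at offset (1, 1); the padding in `update` does not
# overwrite the existing values of `operand`:
#   _update_slice(operand, update, (1, 1), (2, 2))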
def _projector_subspace(P, H, n, rank, maxiter=2, swap=False):
"""Decomposes the `n x n` rank `rank` Hermitian projector `P` into
an `n x rank` isometry `V_minus` such that `P = V_minus @ V_minus.conj().T`
and an `n x (n - rank)` isometry `V_minus` such that
-(I - P) = V_plus @ V_plus.conj().T`.
The subspaces are computed using the naiive QR eigendecomposition
algorithm, which converges very quickly due to the sharp separation
between the relevant eigenvalues of the projector.
Args:
P: A rank-`rank` Hermitian projector into the space of `H`'s first `rank`
eigenpairs. `P` is padded to NxN.
H: The aforementioned Hermitian matrix, which is used to track convergence.
n: the true (dynamic) shape of `P`.
rank: Rank of `P`.
maxiter: Maximum number of iterations.
swap: If true, the two output spaces are swapped.
Returns:
V_minus, V_plus: Isometries into the eigenspaces described in the docstring.
"""
# Choose an initial guess: the `rank` largest-norm columns of P.
N, _ = P.shape
negative_column_norms = -jnp_linalg.norm(P, axis=1)
# `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
negative_column_norms = _mask(negative_column_norms, (n,), np.nan)
sort_idxs = jnp.argsort(negative_column_norms)
X = P[:, sort_idxs]
# X = X[:, :rank]
X = _mask(X, (n, rank))
H_norm = jnp_linalg.norm(H)
thresh = 10.0 * float(dtypes.finfo(X.dtype).eps) * H_norm
# First iteration skips the matmul.
def body_f_after_matmul(X):
Q, _ = jnp_linalg.qr(X, mode="complete")
# V1 = Q[:, :rank]
# V2 = Q[:, rank:]
V1 = _mask(Q, (n, rank))
V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))
# TODO: might be able to get away with lower precision here
error_matrix = tensor_contractions.dot(V2.conj().T, H)
error_matrix = tensor_contractions.dot(error_matrix, V1)
error = jnp_linalg.norm(error_matrix)
return V1, V2, error
def cond_f(args):
_, _, j, error = args
still_counting = j < maxiter
unconverged = error > thresh
return ufuncs.logical_and(still_counting, unconverged)[0]
def body_f(args):
V1, _, j, _ = args
X = tensor_contractions.dot(P, V1)
V1, V2, error = body_f_after_matmul(X)
return V1, V2, j + 1, error
V1, V2, error = body_f_after_matmul(X)
one = jnp.ones(1, dtype=np.int32)
V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
if swap:
return V2, V1
else:
return V1, V2
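# Why so few iterations suffice (a sketch): in exact arithmetic a Hermitian
# projector has eigenvalues in {0, 1}, so one multiply X <- P @ X followed by
# QR re-orthogonalization already spans range(P); the remaining iterations only
# clean up rounding error, which is why `maxiter` defaults to 2.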
def split_spectrum(H, n, split_point, V0=None):
""" The Hermitian matrix `H` is split into two matrices `H_minus`
`H_plus`, respectively sharing its eigenspaces beneath and above
its `split_point`th eigenvalue.
Returns, in addition, `V_minus` and `V_plus`, isometries such that
`Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
returned instead; this allows the overall isometries mapping from
an initial input matrix to progressively smaller blocks to be formed.
Args:
H: The Hermitian matrix to split.
split_point: The eigenvalue to split along.
V0: Matrix of isometries to be updated.
Returns:
H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
`split_point`.
V_minus: An isometry from the input space of `V0` to `H_minus`.
H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
`split_point`.
V_plus: An isometry from the input space of `V0` to `H_plus`.
rank: The dynamic size of the `H_minus` subblock.
"""
N, _ = H.shape
H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(H.dtype)
U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
I = _mask(jnp.eye(N, dtype=H.dtype), (n, n))
P_minus = -0.5 * (U - I)
rank_minus = jnp.round(jnp.trace(ufuncs.real(P_minus))).astype(np.int32)
P_plus = 0.5 * (U + I)
rank_plus = n - rank_minus
# Run subspace iteration on whichever of the projectors P_minus or P_plus has
# the smaller rank. This can save a significant amount of work when H has
# rank << n or if our estimate of the median eigenvalue is poor, because
# the subspace iteration involves computing the QR decomposition of a
# matrix of size n x rank.
swap = rank_plus < rank_minus
V_minus, V_plus = lax.cond(
swap,
lambda: _projector_subspace(P_plus, H, n, rank_plus, swap=True),
lambda: _projector_subspace(P_minus, H, n, rank_minus, swap=False),
)
H_minus = (V_minus.conj().T @ H) @ V_minus
H_plus = (V_plus.conj().T @ H) @ V_plus
if V0 is not None:
V_minus = tensor_contractions.dot(V0, V_minus)
V_plus = tensor_contractions.dot(V0, V_plus)
return H_minus, V_minus, H_plus, V_plus, rank_minus
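# Example (illustrative): if H has eigenvalues (-1, 0, 2) and split_point is
# 0.5, then rank == 2, H_minus is a 2 x 2 block with eigenvalues (-1, 0), and
# H_plus is a 1 x 1 block with eigenvalue 2.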
# To help understand the iterative version of the algorithm, the original
# recursive formulation follows.
#
# def _eigh_work(H, V=None, termination_size=128):
# """ The main work loop performing the symmetric eigendecomposition of H.
# Each step recursively computes a projector into the space of eigenvalues
# above jnp.mean(jnp.diag(H)). The result of the projections into and out of
# that space, along with the isometries accomplishing these, are then computed.
# This is performed recursively until the projections have size 1, and thus
# store an eigenvalue of the original input; the corresponding isometry is
# the related eigenvector. The results are then composed.
#
# Args:
# H: The Hermitian input.
# V: Stores the isometries projecting H into its subspaces.
# precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
#
# Returns:
# H, V: The result of the projection.
# """
# if H.shape[0] <= termination_size:
# evals, evecs = jnp_linalg.eigh(H)
# if V is not None:
# evecs = jnp.dot(V, evecs)
# return evals, evecs
#
# split_point = jnp.median(jnp.diag(H)) # TODO: Improve this?
# H_minus, V_minus, H_plus, V_plus = split_spectrum(H, split_point, V0=V)
# H_minus, V_minus = _eigh_work(H_minus, V=V_minus, termination_size=termination_size)
# H_plus, V_plus = _eigh_work(H_plus, V=V_plus, termination_size=termination_size)
#
# evals = jnp.hstack((H_minus, H_plus))
# evecs = jnp.hstack((V_minus, V_plus))
# return evals, evecs
class _Subproblem(NamedTuple):
"""Describes a subproblem of _eigh_work.
Each subproblem is a `size` x `size` Hermitian matrix, starting at `offset`
in the workspace.
"""
# The row offset of the block in the matrix of blocks.
offset: Array
# The size of the block.
size: Array
@api.jit(static_argnames=('termination_size', 'subset_by_index'))
def _eigh_work(H, n, termination_size, subset_by_index):
""" The main work loop performing the symmetric eigendecomposition of H.
Each step recursively computes a projector into the space of eigenvalues
above jnp.mean(jnp.diag(H)). The result of the projections into and out of
that space, along with the isometries accomplishing these, are then computed.
This is performed recursively until the projections have size 1, and thus
store an eigenvalue of the original input; the corresponding isometry is
the related eigenvector. The results are then composed.
Args:
H: The Hermitian input.
n: The true (dynamic) shape of H.
termination_size: The block size at or below which the recursion stops and
the base-case eigensolver is used.
subset_by_index: Optional range of eigenvalue indices to compute.
Returns:
H, V: The result of the projection.
"""
# We turn what was originally a recursive algorithm into an iterative
# algorithm with an explicit stack.
N, _ = H.shape
n = jnp.asarray(n, np.int32)
agenda = Stack.create(
N + 1, _Subproblem(jnp.array(0, np.int32), jnp.array(0, np.int32)))
agenda = agenda.push(_Subproblem(offset=jnp.array(0, np.int32), size=n))
# eigenvectors is the array in which we build the output eigenvectors.
# We initialize it with the identity matrix so the initial matrix
# multiplications in split_spectrum are the identity.
eigenvectors = jnp.eye(N, dtype=H.dtype)
# Keep a copy of the initial matrix Frobenius norm, so we know when to stop
# recursing. When the sub-matrix norm is less than eps*H0_norm, the contents are
# pure numerical noise, and we should just stop.
H0_norm = jnp_linalg.norm(_mask(H, (n, n)))
# blocks is an array representing a stack of Hermitian matrix blocks that we
# need to recursively decompose. Subproblems are different sizes, so the stack
# of blocks is ragged. Subproblems are left-aligned (i.e. starting at the 0th
# column). Here is an ASCII art picture of three blocks A, B, C, embedded
# in the larger `blocks` workspace (represented with trailing dots).
#
# A A A . . .
# A A A . . .
# A A A . . .
# B B . . . .
# B B . . . .
# C C C C . .
# C C C C . .
# C C C C . .
# C C C C . .
#
# Each step of the algorithm subdivides a block into two subblocks whose
# sizes sum to the original block size. We overwrite the original block with
# those two subblocks so we don't need any additional scratch space.
#
# At termination, "blocks" will contain 1x1 blocks (i.e., the eigenvalues) in
# its first column.
blocks = H
def base_case(B, offset, b, agenda, blocks, eigenvectors):
# Base case: for blocks under a minimum size, we cut off the recursion
# and call the TPU Jacobi eigendecomposition implementation. The Jacobi
# algorithm works well for small matrices but scales poorly, so the two
# complement each other well.
H = _slice(blocks, (offset, 0), (b, b), (B, B))
V = _slice(eigenvectors, (0, offset), (n, b), (N, B))
# We replace the masked-out part of the matrix with the identity matrix.
# We know that the TPU Jacobi eigh implementation will not alter the order
# of the eigenvalues, so we know the eigendecomposition of the original
# matrix is in the top-left corner of the eigendecomposition of the padded
# matrix.
# It is very important that the underlying eigh implementation does not sort
# the eigenvalues for this reason! This is currently not true of JAX's CPU
# and GPU eigendecompositions, and for those platforms this algorithm will
# only do the right thing if termination_size == 1.
H = _mask(H, (b, b))
eig_vecs, eig_vals = lax_linalg.eigh(H, sort_eigenvalues=False)
eig_vecs = _mask(eig_vecs, (b, b))
eig_vals = _mask(eig_vals, (b,))
eig_vecs = tensor_contractions.dot(V, eig_vecs)
eig_vals = eig_vals.astype(eig_vecs.dtype)
blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, 1))
eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b))
return agenda, blocks, eigenvectors
def recursive_case(B, offset, b, agenda, blocks, eigenvectors):
# The recursive case of the algorithm, specialized to a static block size
# of B.
H = _slice(blocks, (offset, 0), (b, b), (B, B))
def nearly_diagonal_case(agenda, blocks, eigenvectors):
blocks = _update_slice(blocks, jnp.diag(H)[:, None], (offset, 0), (b, 1))
return agenda, blocks, eigenvectors
def should_update_range(start, end, subset_by_index):
return (
True
if subset_by_index is None
else ((start < subset_by_index[1]) & (end > subset_by_index[0]))
)
def default_case(agenda, blocks, eigenvectors):
V = _slice(eigenvectors, (0, offset), (n, b), (N, B))
# TODO: Improve this?
split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), np.nan))
H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
H, b, split_point, V0=V)
# Update state for *_minus.
updated_minus_state = (
_update_slice(blocks, H_minus, (offset, 0), (rank, rank)),
_update_slice(eigenvectors, V_minus, (0, offset), (n, rank)),
agenda.push(_Subproblem(offset, rank)),
)
should_update_minus = should_update_range(
offset, offset + rank, subset_by_index
)
blocks, eigenvectors, agenda = lax.cond(
should_update_minus,
lambda: updated_minus_state,
lambda: (blocks, eigenvectors, agenda),
)
# Update state for *_plus.
updated_plus_state = (
_update_slice(
blocks, H_plus, (offset + rank, 0), (b - rank, b - rank)
),
_update_slice(
eigenvectors, V_plus, (0, offset + rank), (n, b - rank)
),
agenda.push(_Subproblem(offset + rank, (b - rank))),
)
should_update_plus = should_update_range(
offset + rank, offset + b, subset_by_index
)
blocks, eigenvectors, agenda = lax.cond(
should_update_plus,
lambda: updated_plus_state,
lambda: (blocks, eigenvectors, agenda),
)
return agenda, blocks, eigenvectors
# If the matrix is nearly diagonal or has a tiny Frobenius norm compared to
# the original input matrix, terminate the execution. This is necessary to
# handle matrices with clusters of eigenvalues, including rank deficient
# matrices. See Nakatsukasa and Higham section 5.2.
norm = jnp_linalg.norm(H)
eps = jnp.asarray(dtypes.finfo(H.dtype).eps, dtype=norm.dtype)
off_diag_norm = jnp_linalg.norm(
H - jnp.diag(jnp.diag(ufuncs.real(H)).astype(H.dtype)))
nearly_diagonal = off_diag_norm <= 5 * eps * norm
tiny = norm < eps * H0_norm
return lax.cond(
nearly_diagonal | tiny,
nearly_diagonal_case,
default_case,
agenda,
blocks,
eigenvectors,
)
def loop_cond(state):
agenda, _, _ = state
return ~agenda.empty()
# It would be wasteful to perform all computation padded up to the original
# matrix size. Instead, we form buckets of padded sizes e.g.,
# [N_0, N_1, ... N_k], aiming for a balance between compilation time
# and runtime.
cutoff = min(N, termination_size)
buckets = [cutoff]
branches = [partial(base_case, cutoff)]
if N > termination_size:
# If N > termination_size, we use the following schedule:
# 1. N_0 = N,
# 2. N_i = _round_up(int(N_{i-1} / 1.98), 32), 0 < i < k
# 3. N_k = termination_size
# The rule for N_i avoids falling back into the original large bucket when a
# split does not land exactly at the halfway point of the recursion.
buckets.append(N)
branches.append(partial(recursive_case, N))
multiplier = 1.98
granularity = 32
i = int(N / multiplier)
while i > cutoff:
bucket_size = _round_up(i, granularity)
buckets.append(bucket_size)
branches.append(partial(recursive_case, bucket_size))
i = i // 2
buckets_arr = jnp.array(buckets, dtype=np.int32)
def loop_body(state):
agenda, blocks, eigenvectors = state
(offset, b), agenda = agenda.pop()
which = jnp.where(buckets_arr < b, dtypes.iinfo(np.int32).max, buckets_arr)
choice = jnp.argmin(which)
return lax.switch(choice, branches, offset, b, agenda, blocks, eigenvectors)
_, blocks, eigenvectors = lax.while_loop(
loop_cond, loop_body, (agenda, blocks, eigenvectors))
return blocks[:, 0], eigenvectors
def eigh(
H,
*,
precision='float32',
termination_size=256,
n=None,
sort_eigenvalues=True,
subset_by_index=None,
):
"""Computes the eigendecomposition of the symmetric/Hermitian matrix H.
Args:
H: The `n x n` Hermitian input, padded to `N x N`.
precision: :class:`~jax.lax.Precision` object specifying the matmul
precision.
termination_size: Recursion ends once the blocks reach this linear size.
n: the true (dynamic) size of the matrix.
sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to
highest.
subset_by_index: Optional 2-tuple [start, end] indicating the range of
indices of eigenvalues to compute. For example, if ``subset_by_index`` =
[n-2, n], then ``eigh`` computes the two largest eigenvalues and their
eigenvectors.
Returns:
vals: The `n` eigenvalues of `H`.
vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
to numerical error.
"""
M, N = H.shape
if M != N:
raise TypeError(f"Input H of shape {H.shape} must be square.")
if n is not None and n > N:
raise ValueError('Static size must be greater than or equal to dynamic size.')
compute_slice = False
if subset_by_index is not None:
compute_slice = subset_by_index != (0, n)
if len(subset_by_index) != 2:
raise ValueError('subset_by_index must be a tuple of size 2.')
if subset_by_index[0] >= subset_by_index[1]:
raise ValueError('Got empty index range in subset_by_index.')
if subset_by_index[0] < 0:
raise ValueError('Indices in subset_by_index must be non-negative.')
range_max = N if n is None else n
if subset_by_index[1] > range_max:
raise ValueError('Index in subset_by_index[1] exceeds matrix size.')
if N <= termination_size:
if n is not None:
H = _mask(H, (n, n))
eig_vecs, eig_vals = lax_linalg.eigh(
H, lower=True, sort_eigenvalues=(sort_eigenvalues or compute_slice),
subset_by_index=None, symmetrize_input=False,
implementation=lax_linalg.EighImplementation.JACOBI,
)
if subset_by_index is not None and compute_slice:
eig_vals = eig_vals[subset_by_index[0] : subset_by_index[1]]
eig_vecs = eig_vecs[:, subset_by_index[0] : subset_by_index[1]]
return eig_vals, eig_vecs
n = N if n is None else n
with config.default_matmul_precision(precision):
eig_vals, eig_vecs = _eigh_work(
H, n, termination_size=termination_size, subset_by_index=subset_by_index
)
eig_vals = _mask(ufuncs.real(eig_vals), (n,), np.nan)
if sort_eigenvalues or compute_slice:
sort_idxs = jnp.argsort(eig_vals)
if compute_slice:
sort_idxs = sort_idxs[subset_by_index[0] : subset_by_index[1]]
eig_vals = eig_vals[sort_idxs]
eig_vecs = eig_vecs[:, sort_idxs]
return eig_vals, eig_vecs
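# Example usage (a sketch; the matrix and sizes are illustrative):
#   A = jnp.eye(8, dtype=np.float32)          # any Hermitian matrix works here
#   vals, vecs = eigh(A, termination_size=4)
#   # vals[i] is the i-th eigenvalue and vecs[:, i] its eigenvector, so that
#   # A @ vecs ~= vecs * vals up to numerical error.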
def _T(x: Array) -> Array:
return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
def _eigh_qdwh_impl(x, *, lower, sort_eigenvalues, subset_by_index):
"""QDWH-based eigendecomposition for TPU."""
*_, m, n = x.shape
assert m == n, (m, n)
termination_size = 256
if not core.is_constant_dim(m):
# TODO: maybe we can relax the check below for shape polymorphism?
raise NotImplementedError(
"Shape polymorphism for native lowering for eigh is implemented "
f"only for the batch dimensions: {x.shape}")
if m <= termination_size and (
subset_by_index is None or subset_by_index == (0, n)
):
return lax_linalg.eigh(
x, lower=lower, sort_eigenvalues=sort_eigenvalues,
symmetrize_input=False,
implementation=lax_linalg.EighImplementation.JACOBI
)
def eigh_qdwh(x):
if len(x.shape) > 2:
return control_flow.map(eigh_qdwh, x)
# We should only look at elements from the lower/upper triangle. Reflect
# that triangle into the other triangle to form a Hermitian matrix.
if lower:
mask = lax_internal._tri(bool, (n, n), 0)
else:
mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1))
if dtypes.issubdtype(x.dtype, np.complexfloating):
re = lax.select(mask, lax.real(x), _T(lax.real(x)))
if lower:
im_mask = lax_internal._tri(bool, (n, n), -1)
else:
im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0))
im = lax.imag(x)
im = lax.select(im_mask, im, lax.full_like(im, 0))
im = lax.select(mask, im, -_T(im))
x = lax.complex(re, im)
else:
x = lax.select(mask, x, _T(x))
return eigh(
x,
sort_eigenvalues=sort_eigenvalues,
termination_size=termination_size,
subset_by_index=subset_by_index,
)
eig_vals, eig_vecs = eigh_qdwh(x)
return eig_vecs, eig_vals
def _eigh_tpu_lowering(
ctx, operand, *, lower, sort_eigenvalues, subset_by_index, algorithm
):
if algorithm is None:
algorithm = lax_linalg.EighImplementation.QDWH
if algorithm == lax_linalg.EighImplementation.QR:
raise NotImplementedError("QR algorithm is not supported on TPU")
elif algorithm == lax_linalg.EighImplementation.JACOBI:
operand_aval, = ctx.avals_in
if operand_aval.shape[-1] == 0:
reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1])
return [
operand,
hlo.real(mlir.reshape(ctx, operand, reshape_aval)),
]
v_aval, w_aval = ctx.avals_out
eigvecs_type = mlir.aval_to_ir_type(v_aval)
eigvals_type = mlir.aval_to_ir_type(w_aval)
result_types = [eigvecs_type, eigvals_type]
backend_config = f"{int(lower)},{int(sort_eigenvalues)},100,1e-6"
if any(not is_constant_shape(aval_out.shape)
for aval_out in ctx.avals_out):
result_shapes = [
mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
for aval_out in ctx.avals_out
]
else:
result_shapes = None
op = mlir.custom_call(
"Eigh",
result_types=result_types,
operands=[operand],
backend_config=backend_config,
api_version=1,
result_shapes=result_shapes,
)
return op.results
elif algorithm == lax_linalg.EighImplementation.QDWH:
return mlir.lower_fun(_eigh_qdwh_impl, multiple_results=True)(
ctx, operand, lower=lower, sort_eigenvalues=sort_eigenvalues,
subset_by_index=subset_by_index)
else:
raise ValueError(f"Unknown algorithm: {algorithm}")
mlir.register_lowering(lax_linalg.eigh_p, _eigh_tpu_lowering, platform='tpu')
@@ -0,0 +1,291 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A JIT-compatible library for QDWH-based polar decomposition.
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decompositions does not require matrix
inversion. This is desirable for multicore and heterogeneous computing systems.
Reference: Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
"Optimizing Halley's iteration for computing the matrix polar decomposition."
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
https://epubs.siam.org/doi/abs/10.1137/090774999
"""
from __future__ import annotations
import functools
import numpy as np
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.lax import linalg as lax_linalg
from jax._src.numpy import linalg as jnp_linalg
from jax._src.typing import Array
# Helpers for working with padded shapes
def _mask(x, dims, alternative=0):
"""Masks `x` up to the dynamic shape `dims`.
Replaces values outside those dimensions with `alternative`. `alternative` is
broadcast with `x`.
"""
assert np.ndim(x) == len(dims)
mask: Array | None = None
for i, d in enumerate(dims):
if d is not None:
mask_dim_i = lax.broadcasted_iota(np.int32, x.shape, i) < d
mask = mask_dim_i if mask is None else (mask & mask_dim_i)
return x if mask is None else jnp.where(mask, x, alternative)
def _pad_in_dim(x, low=0, high=0, interior=0, fill_value=0, axis=0):
pads = [(0, 0, 0)] * x.ndim
pads[axis] = (low, high, interior)
return lax.pad(x, jnp.array(fill_value, x.dtype), pads)
def _dynamic_concat(a, b, m, axis=0):
"Concatenates padded arrays `a` and `b` where the true size of `a` is `m`."
if m is None:
return jnp.concatenate([a, b], axis=axis)
return lax.dynamic_update_slice_in_dim(
_pad_in_dim(a, high=b.shape[axis], axis=axis), b, m, axis)
def _use_qr(u, m, n, params):
"""QDWH iteration using QR decomposition.
Args:
u: a matrix, with static (padded) shape M x N.
m, n: the dynamic shape of the matrix, where m <= M and n <= N.
params: the QDWH parameters.
"""
a_minus_e_by_sqrt_c, sqrt_c, e = params
M, N = u.shape
y = _dynamic_concat(sqrt_c * u, jnp.eye(N, dtype=dtypes.dtype(u)), m)
q, _ = lax_linalg.qr(y, full_matrices=False)
# q1 = q[:m, :]
q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n))
# q2 = (q[m:, :]).T.conj()
q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0)
q2 = _mask(q2, (n, n)).T.conj()
return e * u + a_minus_e_by_sqrt_c * (q1 @ q2)
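# In matrix form (a sketch, following Nakatsukasa, Bai & Gygi 2010): with
#   [Q1; Q2] R = qr([sqrt(c) * U; I]),
# the update computed above is
#   U <- (b / c) * U + (1 / sqrt(c)) * (a - b / c) * (Q1 @ Q2^H),
# which applies the dynamically weighted Halley step without a matrix inverse.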
def _use_cholesky(u, m, n, params):
"""QDWH iteration using Cholesky decomposition.
Args:
u: a matrix, with static (padded) shape M x N
m, n: the dynamic shape of the matrix, where m <= M and n <= N.
params: the QDWH parameters.
"""
a_minus_e, c, e = params
_, N = u.shape
x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=dtypes.dtype(u))
# Pads the lower-right corner with the identity matrix to prevent the Cholesky
# decomposition from failing due to the matrix not being PSD if padded with
# zeros.
x = _mask(x, (n, n), jnp.eye(N, dtype=x.dtype))
# `y` is lower triangular.
y = lax_linalg.cholesky(x, symmetrize_input=False)
z = lax_linalg.triangular_solve(
y, u.T, left_side=True, lower=True, conjugate_a=True).conj()
z = lax_linalg.triangular_solve(y, z, left_side=True, lower=True,
transpose_a=True, conjugate_a=True).T.conj()
return e * u + a_minus_e * z
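# Equivalent closed form (a sketch): the Cholesky factorization and the two
# triangular solves above compute
#   U <- (b / c) * U + (a - b / c) * U @ inv(c * U^H @ U + I)
# without forming an explicit matrix inverse.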
def _qdwh(x, m, n, max_iterations, eps):
"""QR-based dynamically weighted Halley iteration for polar decomposition."""
# Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of
# norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for
# the smallest singular value of x.
if eps is None:
eps = float(dtypes.finfo(x.dtype).eps)
one_norm = jnp_linalg.norm(x, ord=1)
inf_norm = jnp_linalg.norm(x, ord=np.inf)
alpha_inverse = lax.rsqrt(one_norm) * lax.rsqrt(inf_norm)
alpha_inverse = jnp.where(one_norm == 0, 1, alpha_inverse)
u = x * alpha_inverse.astype(x.dtype)
l = eps
# Iteration tolerances.
tol_l = 10.0 * eps / 2.0
tol_norm = jnp.cbrt(tol_l)
def get_qr_params(a, b, c):
e = b / c
a_minus_e = a - e
sqrt_c = c ** (1 / 2)
return (a_minus_e / sqrt_c, sqrt_c, e)
def get_chol_params(a, b, c):
e = b / c
a_minus_e = a - e
return (a_minus_e, c, e)
CHOLESKY_CUTOFF = 100
qr_coefs = []
chol_coefs = []
k = 0
while l + tol_l < 1 and k < max_iterations:
k += 1
l2 = l * l
dd = (4 * (1 / l2 - 1) / l2) ** (1 / 3)
sqd = (1.0 + dd) ** (1 / 2)
a = sqd + (2 - dd + 2 * (2 - l2) / (l2 * sqd)) ** (1 / 2)
b = (a - 1) ** 2 / 4
c = a + b - 1
l = l * (a + b * l2) / (1 + c * l2)
if c > CHOLESKY_CUTOFF:
qr_coefs.append(get_qr_params(a, b, c))
else:
chol_coefs.append(get_chol_params(a, b, c))
def iteration(k, state, update_fn, coefs, test_convergence):
u, _ = state
if coefs is None:
# As l → 1, the coefficients a, b, c → 3, 1, 3, which is Halley's method.
params = get_chol_params(3, 1, 3)
else:
params = lax.dynamic_index_in_dim(coefs, k, keepdims=False)
u_prev = u
u = update_fn(u, m, n, params)
is_not_converged = True
if test_convergence:
is_not_converged = jnp_linalg.norm(u - u_prev) > tol_norm
return u, is_not_converged
def iterate(u, coefs, **kwargs):
if not coefs:
return u, True
coefs = jnp.array(coefs).astype(x.dtype)
body = functools.partial(iteration, coefs=coefs, **kwargs)
return lax.fori_loop(0, len(coefs), body, (u, True))
u, _ = iterate(
u, coefs=qr_coefs, update_fn=_use_qr, test_convergence=False
)
u, is_not_converged = iterate(
u, coefs=chol_coefs, update_fn=_use_cholesky, test_convergence=True
)
# If l has converged but u still has not, continue with Halley's method
# (coefs=None) until convergence.
def cond_fun(state):
k, _, is_not_converged = state
return jnp.logical_and(is_not_converged, k < max_iterations)
def body_fun(state):
k, u, is_not_converged = state
u, is_not_converged = iteration(
k,
(u, is_not_converged),
coefs=None,
update_fn=_use_cholesky,
test_convergence=True,
)
return k + 1, u, is_not_converged
k = len(qr_coefs) + len(chol_coefs)
num_iters, u, is_not_converged = lax.while_loop(
cond_fun, body_fun, (k, u, is_not_converged)
)
# Applies Newton-Schulz refinement for better accuracy.
u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u)
h = u.T.conj() @ x
h = (h + h.T.conj()) / 2
# `is_converged` records whether convergence was reached within the maximum
# number of iterations.
is_converged = jnp.logical_not(is_not_converged)
return u, h, num_iters, is_converged
# TODO: Add pivoting.
@functools.partial(
api.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps')
)
def qdwh(
x,
*,
is_hermitian: bool = False,
max_iterations: int | None = None,
eps: float | None = None,
dynamic_shape: tuple[int, int] | None = None,
):
"""QR-based dynamically weighted Halley iteration for polar decomposition.
Args:
x: A full-rank matrix, with shape `M x N`. The matrix may be padded up to
that size from a smaller true shape (``dynamic_shape``).
is_hermitian: True if `x` is Hermitian. Defaults to `False`. This parameter
is currently unused, but exists for backward compatibility.
eps: The final result will satisfy ``|x_k - x_{k-1}| < |x_k| *
(4*eps)**(1/3)`` where `x_k` is the iterate.
max_iterations: Iterations will terminate after this many steps even if the
above is unsatisfied.
dynamic_shape: the unpadded shape as an ``(m, n)`` tuple; optional.
Returns:
A four-tuple `(u, h, num_iters, is_converged)` containing the
polar decomposition `x = u * h`, the number of iterations used to compute
`u`, and `is_converged`, which is `True` when convergence is achieved
within the maximum number of iterations.
"""
# TODO: Possibly take advantage of Hermitian inputs to speed up the QDWH step.
is_hermitian = core.concrete_or_error(
bool, is_hermitian, 'The `is_hermitian` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
if max_iterations is None:
max_iterations = 10
else:
max_iterations = core.concrete_or_error(
int, max_iterations, 'The `max_iterations` argument must be statically '
'specified to use `qdwh` within JAX transformations.')
M, N = x.shape
if M < N:
raise ValueError('The input matrix of shape M x N must have M >= N.')
if dynamic_shape is not None:
m, n = dynamic_shape
x = _mask(x, (m, n))
else:
m, n = M, N
with config.default_matmul_precision('float32'):
u, h, num_iters, is_converged = _qdwh(x, m, n, max_iterations, eps)
return u, h, num_iters, is_converged
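# Example usage (a sketch; the input matrix is illustrative):
#   x = jnp.ones((6, 4), np.float32) + jnp.eye(6, 4, dtype=np.float32)
#   u, h, num_iters, is_converged = qdwh(x)
#   # x ~= u @ h, where u has orthonormal columns and h is Hermitian positive
#   # semidefinite.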
@@ -0,0 +1,79 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A bounded functional stack implementation.
Used as a helper for expressing recursive algorithms such as QDWH-eig for
eigendecomposition on TPU.
"""
from __future__ import annotations
from typing import Any
from jax._src import lax
from jax._src import numpy as jnp
from jax._src import tree_util
class Stack:
"""A bounded functional stack implementation. Elements may be pytrees."""
def __init__(self, size, data):
"""Private constructor."""
self._size = size
self._data = data
def __repr__(self):
return f"Stack({self._size}, {self._data})"
@staticmethod
def create(capacity: int, prototype: Any) -> Stack:
"""Creates a stack with size `capacity` with elements like `prototype`.
`prototype` can be any JAX pytree. This function looks only at its
structure; the specific values are ignored.
"""
return Stack(
jnp.array(0, 'int32'),
tree_util.tree_map(
lambda x: jnp.zeros((capacity,) + tuple(x.shape), x.dtype), prototype))
def empty(self) -> Any:
"""Returns true if the stack is empty."""
return self._size == 0
def push(self, elem: Any) -> Stack:
"""Pushes `elem` onto the stack, returning the updated stack."""
return Stack(
self._size + 1,
tree_util.tree_map(
lambda x, y: lax.dynamic_update_index_in_dim(x, y, self._size, 0),
self._data, elem))
def pop(self) -> tuple[Any, Stack]:
"""Pops from the stack, returning an (elem, updated stack) pair."""
elem = tree_util.tree_map(
lambda x: lax.dynamic_index_in_dim(x, self._size - 1, 0, keepdims=False),
self._data)
return elem, Stack(self._size - 1, self._data)
def flatten(self):
leaves, treedef = tree_util.tree_flatten(self._data)
return ([self._size] + leaves), treedef
@staticmethod
def unflatten(treedef, leaves):
return Stack(leaves[0], tree_util.tree_unflatten(treedef, leaves[1:]))
tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten)
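# Example usage (a sketch):
#   s = Stack.create(4, jnp.array(0, 'int32'))
#   s = s.push(jnp.array(7, 'int32'))
#   top, s = s.pop()   # top == 7
# Because Stack is registered as a pytree, a stack can be carried through
# `lax.while_loop` state, which is how _eigh_work uses it.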
@@ -0,0 +1,312 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A JIT-compatible library for QDWH-based singular value decomposition.
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
iteration implemented through QR decompositions is numerically stable and does
not require solving a linear system involving the iteration matrix or
computing its inverse. This is desirable for multicore and heterogeneous
computing systems.
References:
Nakatsukasa, Yuji, and Nicholas J. Higham.
"Stable and efficient spectral divide and conquer algorithms for the symmetric
eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing 35,
no. 3 (2013): A1325-A1349.
https://epubs.siam.org/doi/abs/10.1137/120876605
Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
"Optimizing Halley's iteration for computing the matrix polar decomposition."
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
https://epubs.siam.org/doi/abs/10.1137/090774999
"""
from __future__ import annotations
from collections.abc import Sequence
import functools
import operator
from typing import Any
import numpy as np
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import lax
from jax._src import numpy as jnp
from jax._src.interpreters import mlir
from jax._src.lax import linalg as lax_linalg
from jax._src.tpu.linalg import qdwh as tpu_qdwh
from jax._src.typing import Array
@functools.partial(api.jit, static_argnums=(1, 2, 3, 4))
def _svd_tall_and_square_input(
a: Any,
hermitian: bool,
compute_uv: bool,
max_iterations: int,
subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
"""Singular value decomposition for m x n matrix and m >= n.
Args:
a: A matrix of shape `m x n` with `m >= n`.
hermitian: True if `a` is Hermitian.
compute_uv: Whether to also compute `u` and `v` in addition to `s`.
max_iterations: The predefined maximum number of iterations of QDWH.
subset_by_index: Optional 2-tuple of start and end eigenvalue indices,
forwarded to the inner call to `eigh`.
Returns:
A 3-tuple (`u`, `s`, `v`), where `u` is a matrix of shape `m x n` with
orthonormal columns, `s` is a vector of length `n` containing the singular
values in descending order, `v` is a unitary matrix of shape `n x n`, and
`a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned.
"""
u_p, h, _, _ = tpu_qdwh.qdwh(
a, is_hermitian=hermitian, max_iterations=max_iterations
)
# TODO: Use `eigvals_only=True` when `compute_uv=False`.
v, s = lax_linalg.eigh(
h, subset_by_index=subset_by_index, sort_eigenvalues=False
)
# Singular values are non-negative by definition. But eigh could return small
# negative values, so we clamp them to zero.
s = jnp.maximum(s, 0.0)
# Sort or reorder singular values to be in descending order.
sort_idx = jnp.argsort(s, descending=True)
s_out = s[sort_idx]
if not compute_uv:
return s_out
# Reorders eigenvectors.
v_out = v[:, sort_idx]
u_out = u_p @ v_out
# Applies a correction if the `u` computed by qdwh is not unitary.
# Section 5.5 of Nakatsukasa, Yuji, and Nicholas J. Higham. "Stable and
# efficient spectral divide and conquer algorithms for the symmetric
# eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing
# 35, no. 3 (2013): A1325-A1349.
def correct_rank_deficiency(u_out):
u_out, r = lax_linalg.qr(u_out, full_matrices=False)
u_out = u_out @ jnp.diag(jnp.where(jnp.diag(r) >= 0, 1, -1))
return u_out
eps = float(dtypes.finfo(a.dtype).eps)
do_correction = s_out[-1] <= a.shape[1] * eps * s_out[0]
cond_f = lambda args: args[1]
body_f = lambda args: (correct_rank_deficiency(args[0]), False)
u_out, _ = lax.while_loop(cond_f, body_f, (u_out, do_correction))
return (u_out, s_out, v_out)
@functools.partial(api.jit, static_argnums=(1, 2, 3, 4, 5))
def svd(
a: Any,
full_matrices: bool,
compute_uv: bool = True,
hermitian: bool = False,
max_iterations: int = 10,
subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
"""Singular value decomposition.
Args:
a: A matrix of shape `m x n`.
full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
respectively. If False, the shapes are `m x k` and `k x n`, respectively,
where `k = min(m, n)`.
compute_uv: Whether to also compute `u` and `v` in addition to `s`.
hermitian: True if `a` is Hermitian.
max_iterations: The predefined maximum number of iterations of QDWH.
subset_by_index: Optional 2-tuple [start, end] indicating the range of
indices of singular components to compute. For example, if
``subset_by_index`` = [0,2], then ``svd`` computes the two largest
singular values (and their singular vectors if `compute_uv` is true).
Returns:
A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices,
`s` is a vector of length `k` containing the singular values in
non-increasing order, and `k = min(m, n)`. The shapes of `u` and `vh`
depend on the value of `full_matrices`. For `compute_uv=False`,
only `s` is returned.
"""
full_matrices = core.concrete_or_error(
bool, full_matrices, 'The `full_matrices` argument must be statically '
'specified to use `svd` within JAX transformations.')
compute_uv = core.concrete_or_error(
bool, compute_uv, 'The `compute_uv` argument must be statically '
'specified to use `svd` within JAX transformations.')
hermitian = core.concrete_or_error(
bool,
hermitian,
'The `hermitian` argument must be statically '
'specified to use `svd` within JAX transformations.',
)
max_iterations = core.concrete_or_error(
int,
max_iterations,
'The `max_iterations` argument must be statically '
'specified to use `svd` within JAX transformations.',
)
if subset_by_index is not None:
if len(subset_by_index) != 2:
raise ValueError('subset_by_index must be a tuple of size 2.')
# Make sure subset_by_index is a concrete tuple.
subset_by_index = (
operator.index(subset_by_index[0]),
operator.index(subset_by_index[1]),
)
if subset_by_index[0] >= subset_by_index[1]:
raise ValueError('Got empty index range in subset_by_index.')
if subset_by_index[0] < 0:
raise ValueError('Indices in subset_by_index must be non-negative.')
m, n = a.shape
rank = n if n < m else m
if subset_by_index[1] > rank:
raise ValueError('Index in subset_by_index[1] exceeds matrix size.')
if full_matrices and subset_by_index != (0, rank):
raise ValueError(
'full_matrices and subset_by_index cannot both be set.'
)
# By convention, eigenvalues are numbered in non-decreasing order, while
# singular values are numbered in non-increasing order, so change
# subset_by_index accordingly.
subset_by_index = (rank - subset_by_index[1], rank - subset_by_index[0])
m, n = a.shape
is_flip = False
if m < n:
a = a.T.conj()
m, n = a.shape
is_flip = True
u_out_null: Array | None
q: Array | None
if full_matrices and m > n:
q_full, a_full = lax_linalg.qr(a, pivoting=False, full_matrices=True)
q = q_full[:, :n]
u_out_null = q_full[:, n:]
a = a_full[:n, :]
elif m > 1.15 * n:
# The constant `1.15` comes from Yuji Nakatsukasa's implementation
# https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav
q, a = lax_linalg.qr(a, pivoting=False, full_matrices=False)
u_out_null = None
else:
q = None
u_out_null = None
if not compute_uv:
with config.default_matmul_precision('float32'):
return _svd_tall_and_square_input(
a, hermitian, compute_uv, max_iterations, subset_by_index
)
with config.default_matmul_precision('float32'):
u_out, s_out, v_out = _svd_tall_and_square_input(
a, hermitian, compute_uv, max_iterations, subset_by_index
)
if q is not None: # (full_matrices and m > n) or (m > 1.15 * n)
u_out = q @ u_out
if u_out_null is not None: # full_matrices and m > n
u_out = jnp.hstack((u_out, u_out_null))
is_finite = jnp.all(jnp.isfinite(a))
cond_f = lambda args: jnp.logical_not(args[0])
body_f = lambda args: (
jnp.array(True),
jnp.full_like(u_out, np.nan),
jnp.full_like(s_out, np.nan),
jnp.full_like(v_out, np.nan),
)
_, u_out, s_out, v_out = lax.while_loop(
cond_f, body_f, (is_finite, u_out, s_out, v_out)
)
if is_flip:
return (v_out, s_out, u_out.T.conj())
return (u_out, s_out, v_out.T.conj())
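# Example usage (a sketch; the input matrix is illustrative):
#   a = jnp.ones((5, 3), np.float32) + jnp.eye(5, 3, dtype=np.float32)
#   u, s, vh = svd(a, full_matrices=False)
#   # a ~= (u * s) @ vh, with s sorted in non-increasing order.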
def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None):
if algorithm is not None and algorithm != lax_linalg.SvdAlgorithm.DEFAULT:
raise NotImplementedError(
"The SVD algorithm parameter is not implemented on TPU.")
batch_dims = a.shape[:-2]
fn = functools.partial(
svd,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
for _ in range(len(batch_dims)):
fn = api.vmap(fn)
if compute_uv:
u, s, vh = fn(a)
return [s, u, vh]
else:
s = fn(a)
return [s]
def _svd_tpu_lowering_rule(
ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None
):
operand_aval, = ctx.avals_in
m, n = operand_aval.shape[-2:]
if algorithm is not None and algorithm not in [
lax_linalg.SvdAlgorithm.DEFAULT,
lax_linalg.SvdAlgorithm.POLAR,
]:
raise NotImplementedError(
'Only the POLAR (which is also DEFAULT on TPU) SVD algorithm is'
' supported on TPU.'
)
if m == 0 or n == 0:
return mlir.lower_fun(lax_linalg._empty_svd, multiple_results=True)(
ctx,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
)
return mlir.lower_fun(_svd_tpu, multiple_results=True)(
ctx,
operand,
full_matrices=full_matrices,
compute_uv=compute_uv,
subset_by_index=subset_by_index,
)
mlir.register_lowering(lax_linalg.svd_p, _svd_tpu_lowering_rule)