hand
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from jax._src.tpu.linalg import (
|
||||
eigh as eigh,
|
||||
qdwh as qdwh,
|
||||
svd as svd,
|
||||
)
|
||||
|
||||
from jax._src import traceback_util
|
||||
# Strip this package's internal frames from user-facing JAX tracebacks.
traceback_util.register_exclusion(os.path.dirname(__file__))
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,699 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
"""Symmetric (Hermitian) eigendecomposition using QDWH
|
||||
|
||||
References:
|
||||
Nakatsukasa, Yuji, and Nicholas J. Higham.
|
||||
"Stable and efficient spectral divide and conquer algorithms for the symmetric
|
||||
eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing 35,
|
||||
no. 3 (2013): A1325-A1349.
|
||||
https://epubs.siam.org/doi/abs/10.1137/120876605
|
||||
|
||||
This implementation is primarily used on TPU, but it can in principle work on
|
||||
CPU and GPU also.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lax import control_flow
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.lax.linalg import is_constant_shape
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy import linalg as jnp_linalg
|
||||
from jax._src.numpy import tensor_contractions
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.tpu.linalg import qdwh
|
||||
from jax._src.tpu.linalg.stack import Stack
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
# QDWH-eigh is a recursive algorithm where the structure of the recursion
|
||||
# is determined by the eigenspectrum. Neither JAX nor XLA can handle this kind
|
||||
# of recursion, so we instead express the recursion as iteration using an
|
||||
# explicit stack.
|
||||
|
||||
|
||||
# TODO(phawkins): consider extracting _mask/_slice/_update_slice into a
|
||||
# separate module.
|
||||
def _round_up(i, n):
|
||||
return ((i+n-1) // n) * n
|
||||
|
||||
def _mask(x, dims, alternative=0):
|
||||
"""Masks `x` up to the dynamic shape `dims`.
|
||||
|
||||
Replaces values outside those dimensions with `alternative`. `alternative` is
|
||||
broadcast with `x`.
|
||||
"""
|
||||
assert np.ndim(x) == len(dims)
|
||||
mask: Array | None = None
|
||||
for i, d in enumerate(dims):
|
||||
if d is not None:
|
||||
mask_dim_i = lax.broadcasted_iota(np.int32, x.shape, i) < d
|
||||
mask = mask_dim_i if mask is None else (mask & mask_dim_i)
|
||||
return x if mask is None else jnp.where(mask, x, alternative)
|
||||
|
||||
def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes,
           fill_value=0):
  """Similar to lax.dynamic_slice, but handles arrays with dynamic sizes.

  Returns fill_value instead of clamping start_indices for those elements that
  would overflow the side of the array.

  Args:
    operand: the array to slice
    start_indices: the offset of the start of the slice
    dynamic_slice_sizes: the true (unpadded) size of the slice
    static_slice_sizes: the padded size of the slice, which must be known at
      compile time. The static size must be larger than the dynamic size.
    fill_value: value with which to replace masked-out elements.
  Returns:
    An array with static shape `static_slice_sizes`, padded from its true
    (dynamic) size `dynamic_slice_sizes`.
  """
  # We must pad the input array so the dynamic_slice is guaranteed to fall
  # entirely in bounds.
  padded = lax.pad(operand,
                   jnp.array(0, operand.dtype),
                   [(0, d, 0) for d in static_slice_sizes])
  # Indices are cast to int32 so the slice works with traced dynamic offsets.
  out = lax.dynamic_slice(padded, tuple(jnp.array(i, np.int32) for i in start_indices),
                          static_slice_sizes)
  # Elements beyond the dynamic size come from the zero padding; replace them
  # with the requested fill_value.
  return _mask(out, dynamic_slice_sizes, fill_value)
|
||||
|
||||
def _update_slice(operand, update, start_indices, update_dims):
  """
  Similar to lax.dynamic_update_slice, but handles padded updates where padding
  values should not overwrite existing values in the array.

  Args:
    operand: the array to update
    update: the padded array to write
    start_indices: the offset at which to write `update`.
    update_dims: the true dimensions of the padded update `update`. Only values
      inside the rectangle given by `update_dims` will be overwritten."""
  operand_shape = operand.shape
  # Pad the operand so the dynamic_slice/dynamic_update_slice below cannot run
  # out of bounds regardless of start_indices.
  operand = lax.pad(operand,
                    jnp.array(0, operand.dtype),
                    [(0, d, 0) for d in update.shape])
  start_indices = tuple(jnp.array(i, np.int32) for i in start_indices)
  # Read the current contents at the target location, and keep them wherever
  # `update` is only padding (outside the `update_dims` rectangle).
  t = lax.dynamic_slice(operand, start_indices, update.shape)
  t = _mask(update, update_dims, t)
  operand = lax.dynamic_update_slice(operand, t, start_indices)
  # Strip the padding added above, restoring the original static shape.
  return lax.slice(operand, [0] * operand.ndim, operand_shape)
|
||||
|
||||
|
||||
def _projector_subspace(P, H, n, rank, maxiter=2, swap=False):
  """Decomposes the `n x n` rank `rank` Hermitian projector `P` into

  an `n x rank` isometry `V_minus` such that `P = V_minus @ V_minus.conj().T`
  and an `n x (n - rank)` isometry `V_plus` such that
  `I - P = V_plus @ V_plus.conj().T`.

  The subspaces are computed using the naive QR eigendecomposition
  algorithm, which converges very quickly due to the sharp separation
  between the relevant eigenvalues of the projector.

  Args:
    P: A rank-`rank` Hermitian projector into the space of `H`'s first `rank`
      eigenpairs. `P` is padded to NxN.
    H: The aforementioned Hermitian matrix, which is used to track convergence.
    n: the true (dynamic) shape of `P`.
    rank: Rank of `P`.
    maxiter: Maximum number of iterations.
    swap: If true, the two output spaces are swapped.

  Returns:
    V_minus, V_plus: Isometries into the eigenspaces described in the docstring.
  """
  # Choose an initial guess: the `rank` largest-norm columns of P.
  N, _ = P.shape
  negative_column_norms = -jnp_linalg.norm(P, axis=1)
  # `jnp.argsort` ensures NaNs sort last, so set masked-out column norms to NaN.
  negative_column_norms = _mask(negative_column_norms, (n,), np.nan)
  sort_idxs = jnp.argsort(negative_column_norms)
  X = P[:, sort_idxs]
  # X = X[:, :rank]
  X = _mask(X, (n, rank))

  # Convergence threshold scales with H's norm and the dtype's precision.
  H_norm = jnp_linalg.norm(H)
  thresh = 10.0 * float(dtypes.finfo(X.dtype).eps) * H_norm

  # First iteration skips the matmul.
  def body_f_after_matmul(X):
    # A complete QR factorization splits C^n into the column space of X
    # (first `rank` columns of Q) and its orthogonal complement.
    Q, _ = jnp_linalg.qr(X, mode="complete")
    # V1 = Q[:, :rank]
    # V2 = Q[:, rank:]
    V1 = _mask(Q, (n, rank))
    V2 = _slice(Q, (0, rank), (n, n - rank), (N, N))

    # TODO: might be able to get away with lower precision here
    # The error is the norm of the off-diagonal block V2^H H V1, which
    # vanishes exactly when V1/V2 span invariant subspaces of H.
    error_matrix = tensor_contractions.dot(V2.conj().T, H)
    error_matrix = tensor_contractions.dot(error_matrix, V1)
    error = jnp_linalg.norm(error_matrix)
    return V1, V2, error

  def cond_f(args):
    _, _, j, error = args
    still_counting = j < maxiter
    unconverged = error > thresh
    return ufuncs.logical_and(still_counting, unconverged)[0]

  def body_f(args):
    V1, _, j, _ = args
    # Subspace iteration step: multiply the current basis by the projector.
    X = tensor_contractions.dot(P, V1)
    V1, V2, error = body_f_after_matmul(X)
    return V1, V2, j + 1, error

  V1, V2, error = body_f_after_matmul(X)
  one = jnp.ones(1, dtype=np.int32)
  V1, V2, _, error = lax.while_loop(cond_f, body_f, (V1, V2, one, error))
  if swap:
    return V2, V1
  else:
    return V1, V2
|
||||
|
||||
|
||||
def split_spectrum(H, n, split_point, V0=None):
  """ The Hermitian matrix `H` is split into two matrices `H_minus`
  `H_plus`, respectively sharing its eigenspaces beneath and above
  its `split_point`th eigenvalue.

  Returns, in addition, `V_minus` and `V_plus`, isometries such that
  `Hi = Vi.conj().T @ H @ Vi`. If `V0` is not None, `V0 @ Vi` are
  returned instead; this allows the overall isometries mapping from
  an initial input matrix to progressively smaller blocks to be formed.

  Args:
    H: The Hermitian matrix to split, padded to `N x N`.
    n: The true (dynamic) size of `H`.
    split_point: The eigenvalue to split along.
    V0: Matrix of isometries to be updated.
  Returns:
    H_minus: A Hermitian matrix sharing the eigenvalues of `H` beneath
      `split_point`.
    V_minus: An isometry from the input space of `V0` to `H_minus`.
    H_plus: A Hermitian matrix sharing the eigenvalues of `H` above
      `split_point`.
    V_plus: An isometry from the input space of `V0` to `H_plus`.
    rank: The dynamic size of the minus subblock.
  """
  N, _ = H.shape
  # Shift the spectrum so that eigenvalues below `split_point` become negative.
  H_shift = H - (split_point * jnp.eye(N, dtype=split_point.dtype)).astype(H.dtype)
  # The polar factor U of the shifted Hermitian matrix is its matrix sign
  # function: +1 on the "plus" eigenspace, -1 on the "minus" eigenspace.
  U, _, _, _ = qdwh.qdwh(H_shift, is_hermitian=True, dynamic_shape=(n, n))
  I = _mask(jnp.eye(N, dtype=H.dtype), (n, n))
  # Spectral projectors onto the two half-spectra.
  P_minus = -0.5 * (U - I)
  # trace(P) equals the projector's rank; round to remove numerical noise.
  rank_minus = jnp.round(jnp.trace(ufuncs.real(P_minus))).astype(np.int32)
  P_plus = 0.5 * (U + I)
  rank_plus = n - rank_minus

  # Run subspace iteration on whichever projector P_minus or P_plus that has the
  # smallest rank. This can save a significant amount of work when H has
  # rank << n or if our estimate of the median eigenvalue is poor, because
  # the subspace iteration involves computing the QR decomposition of a
  # matrix of size n x rank.
  swap = rank_plus < rank_minus
  V_minus, V_plus = lax.cond(
      swap,
      lambda: _projector_subspace(P_plus, H, n, rank_plus, swap=True),
      lambda: _projector_subspace(P_minus, H, n, rank_minus, swap=False),
  )
  # Project H onto each invariant subspace.
  H_minus = (V_minus.conj().T @ H) @ V_minus
  H_plus = (V_plus.conj().T @ H) @ V_plus
  if V0 is not None:
    # Compose with the accumulated isometry from the original input space.
    V_minus = tensor_contractions.dot(V0, V_minus)
    V_plus = tensor_contractions.dot(V0, V_plus)
  return H_minus, V_minus, H_plus, V_plus, rank_minus
|
||||
|
||||
|
||||
# To help understand the iterative version of the algorithm, the original
|
||||
# recursive formulation follows.
|
||||
#
|
||||
# def _eigh_work(H, V=None, termination_size=128):
|
||||
# """ The main work loop performing the symmetric eigendecomposition of H.
|
||||
# Each step recursively computes a projector into the space of eigenvalues
|
||||
# above jnp.mean(jnp.diag(H)). The result of the projections into and out of
|
||||
# that space, along with the isometries accomplishing these, are then computed.
|
||||
# This is performed recursively until the projections have size 1, and thus
|
||||
# store an eigenvalue of the original input; the corresponding isometry is
|
||||
# the related eigenvector. The results are then composed.
|
||||
#
|
||||
# Args:
|
||||
# H: The Hermitian input.
|
||||
# V: Stores the isometries projecting H into its subspaces.
|
||||
# precision: :class:`~jax.lax.Precision` object specifying the matmul precision.
|
||||
#
|
||||
# Returns:
|
||||
# H, V: The result of the projection.
|
||||
# """
|
||||
# if H.shape[0] <= termination_size:
|
||||
# evals, evecs = jnp_linalg.eigh(H)
|
||||
# if V is not None:
|
||||
# evecs = jnp.dot(V, evecs)
|
||||
# return evals, evecs
|
||||
#
|
||||
# split_point = jnp.median(jnp.diag(H)) # TODO: Improve this?
|
||||
# H_minus, V_minus, H_plus, V_plus = split_spectrum(H, split_point, V0=V)
|
||||
# H_minus, V_minus = _eigh_work(H_minus, V=V_minus, termination_size=termination_size)
|
||||
# H_plus, V_plus = _eigh_work(H_plus, V=V_plus, termination_size=termination_size)
|
||||
#
|
||||
# evals = jnp.hstack((H_minus, H_plus))
|
||||
# evecs = jnp.hstack((V_minus, V_plus))
|
||||
# return evals, evecs
|
||||
|
||||
class _Subproblem(NamedTuple):
  """Describes a subproblem of _eigh_work.

  Each subproblem is a `size` x `size` Hermitian matrix, starting at `offset`
  in the workspace.
  """
  # The row offset of the block in the matrix of blocks.
  offset: Array

  # The size of the block.
  size: Array
|
||||
|
||||
|
||||
@api.jit(static_argnames=('termination_size', 'subset_by_index'))
def _eigh_work(H, n, termination_size, subset_by_index):
  """ The main work loop performing the symmetric eigendecomposition of H.
  Each step recursively computes a projector into the space of eigenvalues
  above jnp.mean(jnp.diag(H)). The result of the projections into and out of
  that space, along with the isometries accomplishing these, are then computed.
  This is performed recursively until the projections have size 1, and thus
  store an eigenvalue of the original input; the corresponding isometry is
  the related eigenvector. The results are then composed.

  The recursion is expressed as iteration with an explicit stack (see below),
  which is why this function can be jit-compiled; `termination_size` and
  `subset_by_index` are static so the bucket schedule is fixed at trace time.

  Args:
    H: The Hermitian input, padded to `N x N`.
    n: The true (dynamic) shape of H.
    termination_size: Blocks at or below this size use the dense base case.
    subset_by_index: Optional (start, end) range of eigenvalue indices to
      compute; subproblems entirely outside the range are skipped.

  Returns:
    H, V: The result of the projection.
  """
  # We turn what was originally a recursive algorithm into an iterative
  # algorithm with an explicit stack.
  N, _ = H.shape
  n = jnp.asarray(n, np.int32)
  agenda = Stack.create(
      N + 1, _Subproblem(jnp.array(0, np.int32), jnp.array(0, np.int32)))
  agenda = agenda.push(_Subproblem(offset=jnp.array(0, np.int32), size=n))

  # eigenvectors is the array in which we build the output eigenvectors.
  # We initialize it with the identity matrix so the initial matrix
  # multiplications in_split_spectrum_jittable are the identity.
  eigenvectors = jnp.eye(N, dtype=H.dtype)

  # Keep a copy of the initial matrix Frobenius norm, so we know when to stop
  # recursing. When the sub-matrix norm is less than eps*H0_norm, the contents
  # are pure numerical noise, and we should just stop.
  H0_norm = jnp_linalg.norm(_mask(H, (n, n)))

  # blocks is an array representing a stack of Hermitian matrix blocks that we
  # need to recursively decompose. Subproblems are different sizes, so the stack
  # of blocks is ragged. Subproblems are left-aligned (i.e. starting at the 0th
  # column). Here is an ASCII art picture of three blocks A, B, C, embedded
  # in the larger `blocks` workspace (represented with trailing dots).
  #
  # A A A . . .
  # A A A . . .
  # A A A . . .
  # B B . . . .
  # B B . . . .
  # C C C C . .
  # C C C C . .
  # C C C C . .
  # C C C C . .
  #
  # Each step of the algorithm subdivides a block into two subblocks whose
  # sizes sum to the original block size. We overwrite the original block with
  # those two subblocks so we don't need any additional scratch space.
  #
  # At termination, "blocks" will contain 1x1 blocks (i.e., the eigenvalues) in
  # its first column.
  blocks = H

  def base_case(B, offset, b, agenda, blocks, eigenvectors):
    # Base case: for blocks under a minimum size, we cutoff the recursion
    # and call the TPU Jacobi eigendecomposition implementation. The Jacobi
    # algorithm works well for small matrices but scales poorly, so the two
    # complement each other well.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))
    V = _slice(eigenvectors, (0, offset), (n, b), (N, B))

    # We replace the masked-out part of the matrix with the identity matrix.
    # We know that the TPU Jacobi eigh implementation will not alter the order
    # of the eigenvalues, so we know the eigendecomposition of the original
    # matrix is in the top-left corner of the eigendecomposition of the padded
    # matrix.
    # It is very important that the underlying eigh implementation does not
    # sort the eigenvalues for this reason! This is currently not true of JAX's
    # CPU and GPU eigendecompositions, and for those platforms this algorithm
    # will only do the right thing if termination_size == 1.
    H = _mask(H, (b, b))
    eig_vecs, eig_vals = lax_linalg.eigh(H, sort_eigenvalues=False)
    eig_vecs = _mask(eig_vecs, (b, b))
    eig_vals = _mask(eig_vals, (b,))
    # Map the local eigenvectors back through the accumulated isometry V.
    eig_vecs = tensor_contractions.dot(V, eig_vecs)

    eig_vals = eig_vals.astype(eig_vecs.dtype)
    # Write the eigenvalues into column 0 of this block's rows.
    blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, 1))
    eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b))
    return agenda, blocks, eigenvectors

  def recursive_case(B, offset, b, agenda, blocks, eigenvectors):
    # The recursive case of the algorithm, specialized to a static block size
    # of B.
    H = _slice(blocks, (offset, 0), (b, b), (B, B))

    def nearly_diagonal_case(agenda, blocks, eigenvectors):
      # Take the diagonal entries as the eigenvalues directly.
      blocks = _update_slice(blocks, jnp.diag(H)[:, None], (offset, 0), (b, 1))
      return agenda, blocks, eigenvectors

    def should_update_range(start, end, subset_by_index):
      # A subblock needs processing only if [start, end) overlaps the
      # requested eigenvalue-index range.
      return (
          True
          if subset_by_index is None
          else ((start < subset_by_index[1]) & (end > subset_by_index[0]))
      )

    def default_case(agenda, blocks, eigenvectors):
      V = _slice(eigenvectors, (0, offset), (n, b), (N, B))
      # TODO: Improve this?
      split_point = reductions.nanmedian(_mask(jnp.diag(ufuncs.real(H)), (b,), np.nan))
      H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(
          H, b, split_point, V0=V)

      # Update state for *_minus.
      updated_minus_state = (
          _update_slice(blocks, H_minus, (offset, 0), (rank, rank)),
          _update_slice(eigenvectors, V_minus, (0, offset), (n, rank)),
          agenda.push(_Subproblem(offset, rank)),
      )
      should_update_minus = should_update_range(
          offset, offset + rank, subset_by_index
      )
      blocks, eigenvectors, agenda = lax.cond(
          should_update_minus,
          lambda: updated_minus_state,
          lambda: (blocks, eigenvectors, agenda),
      )

      # Update state for *_plus.
      updated_plus_state = (
          _update_slice(
              blocks, H_plus, (offset + rank, 0), (b - rank, b - rank)
          ),
          _update_slice(
              eigenvectors, V_plus, (0, offset + rank), (n, b - rank)
          ),
          agenda.push(_Subproblem(offset + rank, (b - rank))),
      )

      should_update_plus = should_update_range(
          offset + rank, offset + b, subset_by_index
      )
      blocks, eigenvectors, agenda = lax.cond(
          should_update_plus,
          lambda: updated_plus_state,
          lambda: (blocks, eigenvectors, agenda),
      )

      return agenda, blocks, eigenvectors

    # If the matrix is nearly diagonal or has a tiny Frobenius norm compared to
    # the original input matrix, terminate the execution. This is necessary to
    # handle matrices with clusters of eigenvalues, including rank deficient
    # matrices. See Nakatsukasa and Higham section 5.2.
    norm = jnp_linalg.norm(H)
    eps = jnp.asarray(dtypes.finfo(H.dtype).eps, dtype=norm.dtype)
    off_diag_norm = jnp_linalg.norm(
        H - jnp.diag(jnp.diag(ufuncs.real(H)).astype(H.dtype)))
    nearly_diagonal = off_diag_norm <= 5 * eps * norm
    tiny = norm < eps * H0_norm
    return lax.cond(
        nearly_diagonal | tiny,
        nearly_diagonal_case,
        default_case,
        agenda,
        blocks,
        eigenvectors,
    )

  def loop_cond(state):
    agenda, _, _ = state
    return ~agenda.empty()

  # It would be wasteful to perform all computation padded up to the original
  # matrix size. Instead, we form buckets of padded sizes e.g.,
  # [N_0, N_1, ... N_k], aiming for a balance between compilation time
  # and runtime.
  cutoff = min(N, termination_size)
  buckets = [cutoff]
  branches = [partial(base_case, cutoff)]
  if N > termination_size:
    # If N > termination_size We use the following schedule:
    # 1. N_0 = N,
    # 2. N_i = _round_up(int(N_{i-1} / 1.98), 32), 0 < i < k
    # 3. N_k = termination_size
    # the rule for N_i is to avoid falling into the original large bucket
    # when not splitting exactly at the half-way point during the recursion.
    buckets.append(N)
    branches.append(partial(recursive_case, N))
    multiplier = 1.98
    granularity = 32
    i = int(N / multiplier)
    while i > cutoff:
      bucket_size = _round_up(i, granularity)
      buckets.append(bucket_size)
      branches.append(partial(recursive_case, bucket_size))
      i = i // 2
  buckets_arr = jnp.array(buckets, dtype=np.int32)

  def loop_body(state):
    agenda, blocks, eigenvectors = state
    (offset, b), agenda = agenda.pop()
    # Dispatch to the smallest bucket that can hold a block of size b.
    which = jnp.where(buckets_arr < b, dtypes.iinfo(np.int32).max, buckets_arr)
    choice = jnp.argmin(which)
    return lax.switch(choice, branches, offset, b, agenda, blocks, eigenvectors)

  _, blocks, eigenvectors = lax.while_loop(
      loop_cond, loop_body, (agenda, blocks, eigenvectors))
  # Eigenvalues end up in the first column of the workspace.
  return blocks[:, 0], eigenvectors
|
||||
|
||||
|
||||
def eigh(
    H,
    *,
    precision='float32',
    termination_size=256,
    n=None,
    sort_eigenvalues=True,
    subset_by_index=None,
):
  """Computes the eigendecomposition of the symmetric/Hermitian matrix H.

  Args:
    H: The `n x n` Hermitian input, padded to `N x N`.
    precision: :class:`~jax.lax.Precision` object specifying the matmul
      precision.
    termination_size: Recursion ends once the blocks reach this linear size.
    n: the true (dynamic) size of the matrix.
    sort_eigenvalues: If `True`, the eigenvalues will be sorted from lowest to
      highest.
    subset_by_index: Optional 2-tuple [start, end] indicating the range of
      indices of eigenvalues to compute. For example, if ``subset_by_index`` =
      [n-2,n], then ``eigh`` computes the two largest eigenvalues and their
      eigenvectors.

  Returns:
    vals: The `n` eigenvalues of `H`.
    vecs: A unitary matrix such that `vecs[:, i]` is a normalized eigenvector
      of `H` corresponding to `vals[i]`. We have `H @ vecs = vals * vecs` up
      to numerical error.
  """
  M, N = H.shape
  if M != N:
    raise TypeError(f"Input H of shape {H.shape} must be square.")
  if n is not None and n > N:
    raise ValueError('Static size must be greater or equal to dynamic size.')

  compute_slice = False
  if subset_by_index is not None:
    # A requested range covering all n eigenvalues is the same as no slicing.
    compute_slice = subset_by_index != (0, n)
    if len(subset_by_index) != 2:
      raise ValueError('subset_by_index must be a tuple of size 2.')
    if subset_by_index[0] >= subset_by_index[1]:
      raise ValueError('Got empty index range in subset_by_index.')
    if subset_by_index[0] < 0:
      raise ValueError('Indices in subset_by_index must be non-negative.')
    range_max = N if n is None else n
    if subset_by_index[1] > range_max:
      raise ValueError('Index in subset_by_index[1] exceeds matrix size.')

  if N <= termination_size:
    # Small input: skip the divide-and-conquer entirely and call the Jacobi
    # implementation directly.
    if n is not None:
      H = _mask(H, (n, n))
    eig_vecs, eig_vals = lax_linalg.eigh(
        H, lower=True, sort_eigenvalues=(sort_eigenvalues or compute_slice),
        subset_by_index=None, symmetrize_input=False,
        implementation=lax_linalg.EighImplementation.JACOBI,
    )
    if subset_by_index is not None and compute_slice:
      eig_vals = eig_vals[subset_by_index[0] : subset_by_index[1]]
      eig_vecs = eig_vecs[:, subset_by_index[0] : subset_by_index[1]]
    return eig_vals, eig_vecs

  n = N if n is None else n
  with config.default_matmul_precision(precision):
    eig_vals, eig_vecs = _eigh_work(
        H, n, termination_size=termination_size, subset_by_index=subset_by_index
    )
  # Padded-out eigenvalues become NaN so that jnp.argsort places them last.
  eig_vals = _mask(ufuncs.real(eig_vals), (n,), np.nan)
  if sort_eigenvalues or compute_slice:
    sort_idxs = jnp.argsort(eig_vals)
    if compute_slice:
      sort_idxs = sort_idxs[subset_by_index[0] : subset_by_index[1]]
    eig_vals = eig_vals[sort_idxs]
    eig_vecs = eig_vecs[:, sort_idxs]

  return eig_vals, eig_vecs
|
||||
|
||||
|
||||
def _T(x: Array) -> Array:
|
||||
return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))
|
||||
|
||||
|
||||
def _eigh_qdwh_impl(x, *, lower, sort_eigenvalues, subset_by_index):
  """QDWH-based eigendecomposition for TPU.

  Symmetrizes the input from its `lower`/upper triangle, then runs the
  QDWH divide-and-conquer `eigh` on each matrix in the (possibly batched)
  input. Returns (eigenvectors, eigenvalues).
  """
  *_, m, n = x.shape
  assert m == n, (m, n)

  termination_size = 256
  if not core.is_constant_dim(m):
    # TODO: maybe we can relax the check below for shape polymorphism?
    raise NotImplementedError(
        "Shape polymorphism for native lowering for eigh is implemented "
        f"only for the batch dimensions: {x.shape}")

  if m <= termination_size and (
      subset_by_index is None or subset_by_index == (0, n)
  ):
    # Small full-spectrum problems go straight to the Jacobi implementation.
    return lax_linalg.eigh(
        x, lower=lower, sort_eigenvalues=sort_eigenvalues,
        symmetrize_input=False,
        implementation=lax_linalg.EighImplementation.JACOBI
    )

  def eigh_qdwh(x):
    if len(x.shape) > 2:
      # Batched input: decompose each matrix independently.
      return control_flow.map(eigh_qdwh, x)

    # We should only look at elements from the lower/upper triangle. Reflects
    # that triangle into the other triangle to form a Hermitian matrix.
    if lower:
      mask = lax_internal._tri(bool, (n, n), 0)
    else:
      mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1))
    if dtypes.issubdtype(x.dtype, np.complexfloating):
      re = lax.select(mask, lax.real(x), _T(lax.real(x)))
      if lower:
        im_mask = lax_internal._tri(bool, (n, n), -1)
      else:
        im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0))
      im = lax.imag(x)
      # Zero the diagonal's imaginary part, then reflect the strict triangle
      # with a sign flip so the result is Hermitian.
      im = lax.select(im_mask, im, lax.full_like(im, 0))
      im = lax.select(mask, im, -_T(im))
      x = lax.complex(re, im)
    else:
      x = lax.select(mask, x, _T(x))

    return eigh(
        x,
        sort_eigenvalues=sort_eigenvalues,
        termination_size=termination_size,
        subset_by_index=subset_by_index,
    )

  eig_vals, eig_vecs = eigh_qdwh(x)
  # Note: (vectors, values) order, matching the eigh primitive's outputs.
  return eig_vecs, eig_vals
|
||||
|
||||
|
||||
def _eigh_tpu_lowering(
    ctx, operand, *, lower, sort_eigenvalues, subset_by_index, algorithm
):
  """MLIR lowering rule for the `eigh` primitive on TPU.

  Dispatches on `algorithm`: JACOBI emits the TPU "Eigh" custom call,
  QDWH (the default) lowers the Python QDWH implementation.
  """
  if algorithm is None:
    algorithm = lax_linalg.EighImplementation.QDWH

  if algorithm == lax_linalg.EighImplementation.QR:
    raise NotImplementedError("QR algorithm is not supported on TPU")

  elif algorithm == lax_linalg.EighImplementation.JACOBI:
    operand_aval, = ctx.avals_in
    if operand_aval.shape[-1] == 0:
      # Degenerate size-0 input: the operand itself serves as the (empty)
      # eigenvector output; drop an axis for the (empty) eigenvalues.
      reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1])
      return [
          operand,
          hlo.real(mlir.reshape(ctx, operand, reshape_aval)),
      ]

    v_aval, w_aval = ctx.avals_out
    eigvecs_type = mlir.aval_to_ir_type(v_aval)
    eigvals_type = mlir.aval_to_ir_type(w_aval)
    result_types = [eigvecs_type, eigvals_type]

    # Config for the custom call: lower flag, sort flag, then two solver
    # parameters (presumably max iterations and tolerance — confirm against
    # the TPU "Eigh" custom-call contract).
    backend_config = f"{int(lower)},{int(sort_eigenvalues)},100,1e-6"

    if any(not is_constant_shape(aval_out.shape)
           for aval_out in ctx.avals_out):
      # Shape-polymorphic outputs need explicit result shape tensors.
      result_shapes = [
          mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
          for aval_out in ctx.avals_out
      ]
    else:
      result_shapes = None
    op = mlir.custom_call(
        "Eigh",
        result_types=result_types,
        operands=[operand],
        backend_config=backend_config,
        api_version=1,
        result_shapes=result_shapes,
    )
    return op.results
  elif algorithm == lax_linalg.EighImplementation.QDWH:
    return mlir.lower_fun(_eigh_qdwh_impl, multiple_results=True)(
        ctx, operand, lower=lower, sort_eigenvalues=sort_eigenvalues,
        subset_by_index=subset_by_index)

  else:
    raise ValueError(f"Unknown algorithm: {algorithm}")
|
||||
|
||||
|
||||
# Install the lowering above as the TPU implementation of the eigh primitive.
mlir.register_lowering(lax_linalg.eigh_p, _eigh_tpu_lowering, platform='tpu')
|
||||
@@ -0,0 +1,291 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
"""A JIT-compatible library for QDWH-based polar decomposition.
|
||||
|
||||
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
|
||||
iteration implemented through QR decmopositions does not require matrix
|
||||
inversion. This is desirable for multicore and heterogeneous computing systems.
|
||||
|
||||
Reference: Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
|
||||
"Optimizing Halley's iteration for computing the matrix polar decomposition."
|
||||
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
|
||||
https://epubs.siam.org/doi/abs/10.1137/090774999
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.numpy import linalg as jnp_linalg
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
# Helpers for working with padded shapes
|
||||
def _mask(x, dims, alternative=0):
|
||||
"""Masks `x` up to the dynamic shape `dims`.
|
||||
|
||||
Replaces values outside those dimensions with `alternative`. `alternative` is
|
||||
broadcast with `x`.
|
||||
"""
|
||||
assert np.ndim(x) == len(dims)
|
||||
mask: Array | None = None
|
||||
for i, d in enumerate(dims):
|
||||
if d is not None:
|
||||
mask_dim_i = lax.broadcasted_iota(np.int32, x.shape, i) < d
|
||||
mask = mask_dim_i if mask is None else (mask & mask_dim_i)
|
||||
return x if mask is None else jnp.where(mask, x, alternative)
|
||||
|
||||
def _pad_in_dim(x, low=0, high=0, interior=0, fill_value=0, axis=0):
|
||||
pads = [(0, 0, 0)] * x.ndim
|
||||
pads[axis] = (low, high, interior)
|
||||
return lax.pad(x, jnp.array(fill_value, x.dtype), pads)
|
||||
|
||||
def _dynamic_concat(a, b, m, axis=0):
|
||||
"Concatenates padded arrays `a` and `b` where the true size of `a` is `m`."
|
||||
if m is None:
|
||||
return jnp.concatenate([a, b], axis=axis)
|
||||
return lax.dynamic_update_slice_in_dim(
|
||||
_pad_in_dim(a, high=b.shape[axis], axis=axis), b, m, axis)
|
||||
|
||||
|
||||
def _use_qr(u, m, n, params):
  """One QDWH iteration step computed via a QR decomposition.

  Evaluates `u_next = e * u + (a - e)/sqrt(c) * (q1 @ q2^H)`, where
  `[q1; q2]` is the thin QR factor of `[sqrt(c) * u; I]`. This applies the
  dynamically weighted Halley update without forming a matrix inverse, and
  is used when the weighting coefficient `c` is large (see `_qdwh`).

  Args:
    u: a matrix, with static (padded) shape M x N.
    m, n: the dynamic shape of the matrix, where m <= M and n <= N.
    params: the QDWH parameters `((a - b/c) / sqrt(c), sqrt(c), b/c)` as
      produced by `get_qr_params` in `_qdwh`.

  Returns:
    The updated iterate, with the same padded shape as `u`.
  """
  a_minus_e_by_sqrt_c, sqrt_c, e = params
  M, N = u.shape

  # Stack `sqrt(c) * u` on top of the N x N identity; `m` gives the true row
  # count of `u` within its padding.
  y = _dynamic_concat(sqrt_c * u, jnp.eye(N, dtype=dtypes.dtype(u)), m)
  q, _ = lax_linalg.qr(y, full_matrices=False)
  # q1 = q[:m, :]
  q1 = _mask(lax.slice(q, (0, 0), (M, N)), (m, n))
  # q2 = (q[m:, :]).T.conj()
  q2 = lax.dynamic_slice_in_dim(q, m, N, axis=0)
  q2 = _mask(q2, (n, n)).T.conj()
  return e * u + a_minus_e_by_sqrt_c * (q1 @ q2)
|
||||
|
||||
|
||||
def _use_cholesky(u, m, n, params):
  """One QDWH iteration step computed via a Cholesky decomposition.

  Evaluates `u_next = e * u + (a - e) * u @ (c * u^H u + I)^{-1}` using a
  Cholesky factorization of `c * u^H u + I` followed by two triangular
  solves. Used when the weighting coefficient `c` is small (see
  `CHOLESKY_CUTOFF` in `_qdwh`).

  Args:
    u: a matrix, with static (padded) shape M x N.
    m, n: the dynamic shape of the matrix, where m <= M and n <= N. (`m` is
      unused here; it is accepted so the signature matches `_use_qr`.)
    params: the QDWH parameters `(a - b/c, c, b/c)` as produced by
      `get_chol_params` in `_qdwh`.

  Returns:
    The updated iterate, with the same padded shape as `u`.
  """
  a_minus_e, c, e = params
  _, N = u.shape
  x = c * (u.T.conj() @ u) + jnp.eye(N, dtype=dtypes.dtype(u))
  # Pads the lower-right corner with the identity matrix to prevent the Cholesky
  # decomposition from failing due to the matrix not being PSD if padded with
  # zeros.
  x = _mask(x, (n, n), jnp.eye(N, dtype=x.dtype))

  # `y` is lower triangular.
  y = lax_linalg.cholesky(x, symmetrize_input=False)

  # First solve: z = y^{-1} u^H (conjugation keeps complex inputs correct).
  z = lax_linalg.triangular_solve(
      y, u.T, left_side=True, lower=True, conjugate_a=True).conj()

  # Second solve against y^H, so that overall z = u @ (c u^H u + I)^{-1}.
  z = lax_linalg.triangular_solve(y, z, left_side=True, lower=True,
                                  transpose_a=True, conjugate_a=True).T.conj()

  return e * u + a_minus_e * z
|
||||
|
||||
|
||||
def _qdwh(x, m, n, max_iterations, eps):
  """QR-based dynamically weighted Halley iteration for polar decomposition.

  Args:
    x: the padded input matrix, assumed already masked to its dynamic shape.
    m, n: the dynamic (true) shape of `x` within its static padding.
    max_iterations: the maximum total number of QDWH iterations.
    eps: floating-point epsilon used to derive tolerances; if None, the
      machine epsilon of `x.dtype` is used.

  Returns:
    A 4-tuple `(u, h, num_iters, is_converged)` where `x ~= u @ h`, `u` has
    orthonormal columns, `h` is Hermitian, `num_iters` is the number of
    iterations performed, and `is_converged` indicates convergence within
    `max_iterations`.
  """

  # Estimates `alpha` and `beta = alpha * l`, where `alpha` is an estimate of
  # norm(x, 2) such that `alpha >= norm(x, 2)` and `beta` is a lower bound for
  # the smallest singular value of x.
  if eps is None:
    eps = float(dtypes.finfo(x.dtype).eps)
  one_norm = jnp_linalg.norm(x, ord=1)
  inf_norm = jnp_linalg.norm(x, ord=np.inf)
  alpha_inverse = lax.rsqrt(one_norm) * lax.rsqrt(inf_norm)
  # Guard the all-zero matrix: rsqrt(0) would give inf.
  alpha_inverse = jnp.where(one_norm == 0, 1, alpha_inverse)
  u = x * alpha_inverse.astype(x.dtype)

  # `l` is the running lower bound on the smallest singular value of the
  # scaled iterate; `eps` serves as a conservative initial underestimate.
  l = eps

  # Iteration tolerances.
  tol_l = 10.0 * eps / 2.0
  tol_norm = jnp.cbrt(tol_l)

  def get_qr_params(a, b, c):
    # Rearranged coefficients in the form consumed by `_use_qr`.
    e = b / c
    a_minus_e = a - e
    sqrt_c = c ** (1 / 2)
    return (a_minus_e / sqrt_c, sqrt_c, e)

  def get_chol_params(a, b, c):
    # Rearranged coefficients in the form consumed by `_use_cholesky`.
    e = b / c
    a_minus_e = a - e
    return (a_minus_e, c, e)

  # Iterations whose weighting coefficient `c` exceeds this threshold use the
  # QR-based update; the remainder use the Cholesky-based update.
  CHOLESKY_CUTOFF = 100

  # The coefficient recurrence depends only on `l`, which evolves through
  # plain Python floats, so the per-iteration parameters are fully
  # precomputed at trace time.
  qr_coefs = []
  chol_coefs = []
  k = 0
  while l + tol_l < 1 and k < max_iterations:
    k += 1
    l2 = l * l
    # Dynamically weighted Halley coefficients (a, b, c); see Nakatsukasa,
    # Bai & Gygi (2010), cited in the module docstring.
    dd = (4 * (1 / l2 - 1) / l2) ** (1 / 3)
    sqd = (1.0 + dd) ** (1 / 2)
    a = sqd + (2 - dd + 2 * (2 - l2) / (l2 * sqd)) ** (1 / 2)
    b = (a - 1) ** 2 / 4
    c = a + b - 1
    # Update of the lower bound on the smallest singular value.
    l = l * (a + b * l2) / (1 + c * l2)
    if c > CHOLESKY_CUTOFF:
      qr_coefs.append(get_qr_params(a, b, c))
    else:
      chol_coefs.append(get_chol_params(a, b, c))

  def iteration(k, state, update_fn, coefs, test_convergence):
    """Runs one update step, optionally measuring convergence."""
    u, _ = state

    if coefs is None:
      # As l → 1, the coefficients a, b, c → 3, 1, 3, which is Halley's method.
      params = get_chol_params(3, 1, 3)
    else:
      params = lax.dynamic_index_in_dim(coefs, k, keepdims=False)

    u_prev = u
    u = update_fn(u, m, n, params)

    is_not_converged = True
    if test_convergence:
      is_not_converged = jnp_linalg.norm(u - u_prev) > tol_norm
    return u, is_not_converged

  def iterate(u, coefs, **kwargs):
    """Applies `iteration` once per row of `coefs` via a fori_loop."""
    if not coefs:
      return u, True
    coefs = jnp.array(coefs).astype(x.dtype)
    body = functools.partial(iteration, coefs=coefs, **kwargs)
    return lax.fori_loop(0, len(coefs), body, (u, True))

  # QR-based iterations run first (they correspond to the early, large-`c`
  # steps), followed by the Cholesky-based iterations.
  u, _ = iterate(
      u, coefs=qr_coefs, update_fn=_use_qr, test_convergence=False
  )
  u, is_not_converged = iterate(
      u, coefs=chol_coefs, update_fn=_use_cholesky, test_convergence=True
  )

  # If l has converged but u still has not, continue with Halley's method
  # (coef = None) until convergence.
  def cond_fun(state):
    k, _, is_not_converged = state
    return jnp.logical_and(is_not_converged, k < max_iterations)

  def body_fun(state):
    k, u, is_not_converged = state
    u, is_not_converged = iteration(
        k,
        (u, is_not_converged),
        coefs=None,
        update_fn=_use_cholesky,
        test_convergence=True,
    )
    return k + 1, u, is_not_converged

  k = len(qr_coefs) + len(chol_coefs)
  num_iters, u, is_not_converged = lax.while_loop(
      cond_fun, body_fun, (k, u, is_not_converged)
  )

  # Applies Newton-Schulz refinement for better accuracy.
  u = 1.5 * u - 0.5 * u @ (u.T.conj() @ u)

  # Recover the Hermitian factor h = u^H x and symmetrize it to remove
  # rounding error.
  h = u.T.conj() @ x
  h = (h + h.T.conj()) / 2

  # Converged within the maximum number of iterations.
  is_converged = jnp.logical_not(is_not_converged)

  return u, h, num_iters, is_converged
|
||||
|
||||
|
||||
# TODO: Add pivoting.
|
||||
@functools.partial(
    api.jit, static_argnames=('is_hermitian', 'max_iterations', 'eps')
)
def qdwh(
    x,
    *,
    is_hermitian: bool = False,
    max_iterations: int | None = None,
    eps: float | None = None,
    dynamic_shape: tuple[int, int] | None = None,
):
  """QR-based dynamically weighted Halley iteration for polar decomposition.

  Computes the polar decomposition `x = u * h` of an `M x N` matrix with
  `M >= N`.

  Args:
    x: a full-rank matrix of static shape `M x N`, possibly padded up to that
      size from a smaller true shape (``dynamic_shape``).
    is_hermitian: True if `x` is Hermitian. Defaults to `False`. Currently
      unused; kept for backward compatibility.
    max_iterations: iterations stop after this many steps even if the
      tolerance below is unsatisfied. Defaults to 10.
    eps: the final result satisfies
      ``|x_k - x_k-1| < |x_k| * (4*eps)**(1/3)`` where `x_k` is the iterate.
      Defaults to the machine epsilon of `x.dtype`.
    dynamic_shape: the unpadded shape as an ``(m, n)`` tuple; optional.

  Returns:
    A four-tuple ``(u, h, num_iters, is_converged)``: the polar factors with
    ``x = u * h``, the number of iterations used to compute `u`, and whether
    convergence was achieved within the maximum number of iterations.
  """
  # TODO: Possibly take advantage of Hermitian inputs to speed up the QDWH step.
  is_hermitian = core.concrete_or_error(
      bool,
      is_hermitian,
      'The `is_hermitian` argument must be statically '
      'specified to use `qdwh` within JAX transformations.',
  )

  if max_iterations is None:
    max_iterations = 10
  else:
    max_iterations = core.concrete_or_error(
        int,
        max_iterations,
        'The `max_iterations` argument must be statically '
        'specified to use `qdwh` within JAX transformations.',
    )

  num_rows, num_cols = x.shape
  if num_rows < num_cols:
    raise ValueError('The input matrix of shape M x N must have M >= N.')

  if dynamic_shape is None:
    m, n = num_rows, num_cols
  else:
    # Zero out the padding region so it cannot perturb the iteration.
    m, n = dynamic_shape
    x = _mask(x, (m, n))

  with config.default_matmul_precision('float32'):
    result = _qdwh(x, m, n, max_iterations, eps)

  return result
|
||||
@@ -0,0 +1,79 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
"""A bounded functional stack implementation.
|
||||
|
||||
Used as a helper for expressing recursive algorithms such as QDWH-eig for
|
||||
Eigendecomposition on TPU.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src import tree_util
|
||||
|
||||
|
||||
class Stack:
  """A bounded functional stack implementation. Elements may be pytrees."""

  def __init__(self, size, data):
    """Private constructor; build stacks with `Stack.create` instead."""
    self._size = size  # Number of live elements, an int32 scalar.
    self._data = data  # Pytree of buffers with a leading `capacity` axis.

  def __repr__(self):
    return f"Stack({self._size}, {self._data})"

  @staticmethod
  def create(capacity: int, prototype: Any) -> Stack:
    """Creates a stack with size `capacity` with elements like `prototype`.

    `prototype` can be any JAX pytree. Only its structure (leaf shapes and
    dtypes) matters; the specific values are ignored.
    """
    storage = tree_util.tree_map(
        lambda leaf: jnp.zeros((capacity,) + tuple(leaf.shape), leaf.dtype),
        prototype)
    return Stack(jnp.array(0, 'int32'), storage)

  def empty(self) -> Any:
    """Returns true if the stack is empty."""
    return self._size == 0

  def push(self, elem: Any) -> Stack:
    """Pushes `elem` onto the stack, returning the updated stack."""
    updated = tree_util.tree_map(
        lambda buf, leaf: lax.dynamic_update_index_in_dim(
            buf, leaf, self._size, 0),
        self._data, elem)
    return Stack(self._size + 1, updated)

  def pop(self) -> tuple[Any, Stack]:
    """Pops from the stack, returning an (elem, updated stack) pair."""
    top = tree_util.tree_map(
        lambda buf: lax.dynamic_index_in_dim(
            buf, self._size - 1, 0, keepdims=False),
        self._data)
    return top, Stack(self._size - 1, self._data)

  def flatten(self):
    """Pytree flatten rule: leaves are `[size, *data leaves]`."""
    leaves, treedef = tree_util.tree_flatten(self._data)
    return ([self._size] + leaves), treedef

  @staticmethod
  def unflatten(treedef, leaves):
    """Pytree unflatten rule; inverse of `flatten`."""
    return Stack(leaves[0], tree_util.tree_unflatten(treedef, leaves[1:]))

tree_util.register_pytree_node(Stack, Stack.flatten, Stack.unflatten)
|
||||
@@ -0,0 +1,312 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License
|
||||
|
||||
"""A JIT-compatible library for QDWH-based singular value decomposition.
|
||||
|
||||
QDWH is short for QR-based dynamically weighted Halley iteration. The Halley
|
||||
iteration implemented through QR decompositions is numerically stable and does
|
||||
not require solving a linear system involving the iteration matrix or
|
||||
computing its inversion. This is desirable for multicore and heterogeneous
|
||||
computing systems.
|
||||
|
||||
References:
|
||||
Nakatsukasa, Yuji, and Nicholas J. Higham.
|
||||
"Stable and efficient spectral divide and conquer algorithms for the symmetric
|
||||
eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing 35,
|
||||
no. 3 (2013): A1325-A1349.
|
||||
https://epubs.siam.org/doi/abs/10.1137/120876605
|
||||
|
||||
Nakatsukasa, Yuji, Zhaojun Bai, and François Gygi.
|
||||
"Optimizing Halley's iteration for computing the matrix polar decomposition."
|
||||
SIAM Journal on Matrix Analysis and Applications 31, no. 5 (2010): 2700-2720.
|
||||
https://epubs.siam.org/doi/abs/10.1137/090774999
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import functools
|
||||
import operator
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import lax
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src.tpu.linalg import qdwh as tpu_qdwh
|
||||
from jax._src.typing import Array
|
||||
|
||||
|
||||
@functools.partial(api.jit, static_argnums=(1, 2, 3, 4))
def _svd_tall_and_square_input(
    a: Any,
    hermitian: bool,
    compute_uv: bool,
    max_iterations: int,
    subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
  """Singular value decomposition for m x n matrix and m >= n.

  Computes the SVD via the polar decomposition `a = u_p @ h` (QDWH) followed
  by an eigendecomposition of the Hermitian factor `h`.

  Args:
    a: A matrix of shape `m x n` with `m >= n`.
    hermitian: True if `a` is Hermitian.
    compute_uv: Whether to also compute `u` and `v` in addition to `s`.
    max_iterations: The predefined maximum number of iterations of QDWH.
    subset_by_index: Optional 2-tuple of eigenvalue indices (in the eigh
      non-decreasing convention) selecting which components to compute;
      None computes all of them.

  Returns:
    A 3-tuple (`u`, `s`, `v`), where `u` is a unitary matrix of shape `m x n`,
    `s` is vector of length `n` containing the singular values in the
    descending order, `v` is a unitary matrix of shape `n x n`, and
    `a = (u * s) @ v.T.conj()`. For `compute_uv=False`, only `s` is returned.
  """

  u_p, h, _, _ = tpu_qdwh.qdwh(
      a, is_hermitian=hermitian, max_iterations=max_iterations
  )

  # TODO: Uses `eigvals_only=True` if `compute_uv=False`.
  v, s = lax_linalg.eigh(
      h, subset_by_index=subset_by_index, sort_eigenvalues=False
  )

  # Singular values are non-negative by definition. But eigh could return small
  # negative values, so we clamp them to zero.
  s = jnp.maximum(s, 0.0)

  # Sort or reorder singular values to be in descending order.
  sort_idx = jnp.argsort(s, descending=True)
  s_out = s[sort_idx]

  if not compute_uv:
    return s_out

  # Reorders eigenvectors.
  v_out = v[:, sort_idx]
  u_out = u_p @ v_out

  # Makes correction if computed `u` from qdwh is not unitary.
  # Section 5.5 of Nakatsukasa, Yuji, and Nicholas J. Higham. "Stable and
  # efficient spectral divide and conquer algorithms for the symmetric
  # eigenvalue decomposition and the SVD." SIAM Journal on Scientific Computing
  # 35, no. 3 (2013): A1325-A1349.
  def correct_rank_deficiency(u_out):
    # Re-orthonormalize via QR, flipping column signs so the result matches
    # the sign convention of the uncorrected factor.
    u_out, r = lax_linalg.qr(u_out, full_matrices=False)
    u_out = u_out @ jnp.diag(jnp.where(jnp.diag(r) >= 0, 1, -1))
    return u_out

  # `a` is treated as (nearly) rank deficient when its smallest singular
  # value is tiny relative to the largest, scaled by the column count.
  eps = float(dtypes.finfo(a.dtype).eps)
  do_correction = s_out[-1] <= a.shape[1] * eps * s_out[0]
  # A while_loop whose body runs at most once (it sets its own predicate to
  # False) acts as a traced conditional: the correction is applied only when
  # `do_correction` is True.
  cond_f = lambda args: args[1]
  body_f = lambda args: (correct_rank_deficiency(args[0]), False)
  u_out, _ = lax.while_loop(cond_f, body_f, (u_out, do_correction))
  return (u_out, s_out, v_out)
|
||||
|
||||
|
||||
@functools.partial(api.jit, static_argnums=(1, 2, 3, 4, 5))
def svd(
    a: Any,
    full_matrices: bool,
    compute_uv: bool = True,
    hermitian: bool = False,
    max_iterations: int = 10,
    subset_by_index: tuple[int, int] | None = None,
) -> Any | Sequence[Any]:
  """Singular value decomposition.

  Args:
    a: A matrix of shape `m x n`.
    full_matrices: If True, `u` and `vh` have the shapes `m x m` and `n x n`,
      respectively. If False, the shapes are `m x k` and `k x n`, respectively,
      where `k = min(m, n)`.
    compute_uv: Whether to also compute `u` and `v` in addition to `s`.
    hermitian: True if `a` is Hermitian.
    max_iterations: The predefined maximum number of iterations of QDWH.
    subset_by_index: Optional 2-tuple ``[start, end]`` indicating the range of
      indices of singular components to compute. For example, if
      ``subset_by_index = [0, 2]``, then ``svd`` computes the two largest
      singular values (and their singular vectors if `compute_uv` is true).

  Returns:
    A 3-tuple (`u`, `s`, `vh`), where `u` and `vh` are unitary matrices,
    `s` is vector of length `k` containing the singular values in the
    non-increasing order, and `k = min(m, n)`. The shapes of `u` and `vh`
    depend on the value of `full_matrices`. For `compute_uv=False`,
    only `s` is returned.
  """
  full_matrices = core.concrete_or_error(
      bool, full_matrices, 'The `full_matrices` argument must be statically '
      'specified to use `svd` within JAX transformations.')

  compute_uv = core.concrete_or_error(
      bool, compute_uv, 'The `compute_uv` argument must be statically '
      'specified to use `svd` within JAX transformations.')

  hermitian = core.concrete_or_error(
      bool,
      hermitian,
      'The `hermitian` argument must be statically '
      'specified to use `svd` within JAX transformations.',
  )

  max_iterations = core.concrete_or_error(
      int,
      max_iterations,
      'The `max_iterations` argument must be statically '
      'specified to use `svd` within JAX transformations.',
  )

  if subset_by_index is not None:
    if len(subset_by_index) != 2:
      raise ValueError('subset_by_index must be a tuple of size 2.')
    # Make sure subset_by_index is a concrete tuple.
    subset_by_index = (
        operator.index(subset_by_index[0]),
        operator.index(subset_by_index[1]),
    )
    if subset_by_index[0] >= subset_by_index[1]:
      raise ValueError('Got empty index range in subset_by_index.')
    if subset_by_index[0] < 0:
      raise ValueError('Indices in subset_by_index must be non-negative.')
    m, n = a.shape
    rank = n if n < m else m
    if subset_by_index[1] > rank:
      raise ValueError('Index in subset_by_index[1] exceeds matrix size.')
    if full_matrices and subset_by_index != (0, rank):
      # NOTE(review): the message reads "cannot be both be set" — a grammar
      # typo, but it is a runtime string and is left unchanged here.
      raise ValueError(
          'full_matrices and subset_by_index cannot be both be set.'
      )
    # By convention, eigenvalues are numbered in non-decreasing order, while
    # singular values are numbered non-increasing order, so change
    # subset_by_index accordingly.
    subset_by_index = (rank - subset_by_index[1], rank - subset_by_index[0])

  # Work on a tall (m >= n) matrix; if the input is wide, operate on its
  # conjugate transpose and swap the outputs back at the end.
  m, n = a.shape
  is_flip = False
  if m < n:
    a = a.T.conj()
    m, n = a.shape
    is_flip = True

  u_out_null: Array | None
  q: Array | None

  if full_matrices and m > n:
    # Full QR: the trailing m - n columns of `q_full` span the null space
    # needed to complete `u` to an m x m matrix.
    q_full, a_full = lax_linalg.qr(a, pivoting=False, full_matrices=True)
    q = q_full[:, :n]
    u_out_null = q_full[:, n:]
    a = a_full[:n, :]
  elif m > 1.15 * n:
    # For sufficiently tall inputs, reduce to an n x n problem via thin QR.
    # The constant `1.15` comes from Yuji Nakatsukasa's implementation
    # https://www.mathworks.com/matlabcentral/fileexchange/36830-symmetric-eigenvalue-decomposition-and-the-svd?s_tid=FX_rc3_behav
    q, a = lax_linalg.qr(a, pivoting=False, full_matrices=False)
    u_out_null = None
  else:
    q = None
    u_out_null = None

  if not compute_uv:
    with config.default_matmul_precision('float32'):
      return _svd_tall_and_square_input(
          a, hermitian, compute_uv, max_iterations, subset_by_index
      )

  with config.default_matmul_precision('float32'):
    u_out, s_out, v_out = _svd_tall_and_square_input(
        a, hermitian, compute_uv, max_iterations, subset_by_index
    )
    if q is not None:  # (full_matrices and m > n) or (m > 1.15 * n)
      # Undo the QR reduction applied above.
      u_out = q @ u_out

    if u_out_null is not None:  # full_matrices and m > n
      u_out = jnp.hstack((u_out, u_out_null))

  # A while_loop whose body runs at most once acts as a traced conditional:
  # when `a` contains any non-finite entries, the outputs are replaced with
  # NaNs.
  is_finite = jnp.all(jnp.isfinite(a))
  cond_f = lambda args: jnp.logical_not(args[0])
  body_f = lambda args: (
      jnp.array(True),
      jnp.full_like(u_out, np.nan),
      jnp.full_like(s_out, np.nan),
      jnp.full_like(v_out, np.nan),
  )
  _, u_out, s_out, v_out = lax.while_loop(
      cond_f, body_f, (is_finite, u_out, s_out, v_out)
  )

  if is_flip:
    # The SVD of a^H is (v, s, u^H) in terms of the SVD of a.
    return (v_out, s_out, u_out.T.conj())

  return (u_out, s_out, v_out.T.conj())
|
||||
|
||||
|
||||
def _svd_tpu(a, *, full_matrices, compute_uv, subset_by_index, algorithm=None):
  """Computes (possibly batched) SVD by vmapping `svd` over batch dims.

  Returns a list `[s, u, vh]` when `compute_uv` is set and `[s]` otherwise.
  Only the DEFAULT algorithm selection is supported.
  """
  if algorithm is not None and algorithm != lax_linalg.SvdAlgorithm.DEFAULT:
    raise NotImplementedError(
        "The SVD algorithm parameter is not implemented on TPU.")

  unbatched = functools.partial(
      svd,
      full_matrices=full_matrices,
      compute_uv=compute_uv,
      subset_by_index=subset_by_index,
  )
  # Wrap in one vmap per leading batch dimension.
  num_batch_dims = len(a.shape[:-2])
  for _ in range(num_batch_dims):
    unbatched = api.vmap(unbatched)

  if not compute_uv:
    return [unbatched(a)]

  u, s, vh = unbatched(a)
  return [s, u, vh]
|
||||
|
||||
|
||||
def _svd_tpu_lowering_rule(
    ctx, operand, *, full_matrices, compute_uv, subset_by_index, algorithm=None
):
  """MLIR lowering rule for the SVD primitive on TPU.

  Delegates to `_svd_tpu`, with a dedicated path for zero-sized inputs.
  Only the POLAR (= DEFAULT on TPU) algorithm is accepted.
  """
  in_aval, = ctx.avals_in
  num_rows, num_cols = in_aval.shape[-2:]

  supported = (lax_linalg.SvdAlgorithm.DEFAULT, lax_linalg.SvdAlgorithm.POLAR)
  if algorithm is not None and algorithm not in supported:
    raise NotImplementedError(
        'Only the POLAR (which is also DEFAULT on TPU) SVD algorithm is'
        ' supported on TPU.'
    )

  # Zero-sized matrices short-circuit to the empty-SVD helper.
  if num_rows == 0 or num_cols == 0:
    lowered_empty = mlir.lower_fun(lax_linalg._empty_svd, multiple_results=True)
    return lowered_empty(
        ctx,
        operand,
        full_matrices=full_matrices,
        compute_uv=compute_uv,
    )

  lowered = mlir.lower_fun(_svd_tpu, multiple_results=True)
  return lowered(
      ctx,
      operand,
      full_matrices=full_matrices,
      compute_uv=compute_uv,
      subset_by_index=subset_by_index,
  )
|
||||
|
||||
# Register the QDWH-based SVD implementation as the lowering for the SVD
# primitive. NOTE(review): unlike the eigh registration in this change, no
# platform='tpu' is passed here — confirm that this is intentional.
mlir.register_lowering(lax_linalg.svd_p, _svd_tpu_lowering_rule)
|
||||
Reference in New Issue
Block a user