This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,75 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import operator
from jax._src import api
from jax._src import numpy as jnp
from jax._src.numpy import linalg as jnp_linalg
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike
def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
"""Assign codes from a code book to a set of observations.
JAX implementation of :func:`scipy.cluster.vq.vq`.
Assigns each observation vector in ``obs`` to a code from ``code_book``
based on the nearest Euclidean distance.
Args:
obs: array of observation vectors of shape ``(M, N)``. Each row represents
a single observation. If ``obs`` is one-dimensional, then each entry is
treated as a length-1 observation.
code_book: array of codes with shape ``(K, N)``. Each row represents a single
code vector. If ``code_book`` is one-dimensional, then each entry is treated
as a length-1 code.
check_finite: unused in JAX
Returns:
A tuple of arrays ``(code, dist)``
- ``code`` is an integer array of shape ``(M,)`` containing indices ``0 <= i < K``
of the closest entry in ``code_book`` for the given entry in ``obs``.
- ``dist`` is a float array of shape ``(M,)`` containing the euclidean
distance between each observation and the nearest code.
Examples:
>>> obs = jnp.array([[1.1, 2.1, 3.1],
... [5.9, 4.8, 6.2]])
>>> code_book = jnp.array([[1., 2., 3.],
... [2., 3., 4.],
... [3., 4., 5.],
... [4., 5., 6.]])
>>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book)
>>> print(codes)
[0 3]
>>> print(distances)
[0.17320499 1.9209373 ]
"""
del check_finite # unused
check_arraylike("scipy.cluster.vq.vq", obs, code_book)
obs_arr, cb_arr = promote_dtypes_inexact(obs, code_book)
if obs_arr.ndim != cb_arr.ndim:
raise ValueError("Observation and code_book should have the same rank")
if obs_arr.ndim == 1:
obs_arr, cb_arr = obs_arr[..., None], cb_arr[..., None]
if obs_arr.ndim != 2:
raise ValueError("ndim different than 1 or 2 are not supported")
dist = api.vmap(lambda ob: jnp_linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr)
code = jnp.argmin(dist, axis=-1)
dist_min = api.vmap(operator.getitem)(dist, code)
return code, dist_min