hand
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,75 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import numpy as jnp
|
||||
from jax._src.numpy import linalg as jnp_linalg
|
||||
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
|
||||
def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple[Array, Array]:
|
||||
"""Assign codes from a code book to a set of observations.
|
||||
|
||||
JAX implementation of :func:`scipy.cluster.vq.vq`.
|
||||
|
||||
Assigns each observation vector in ``obs`` to a code from ``code_book``
|
||||
based on the nearest Euclidean distance.
|
||||
|
||||
Args:
|
||||
obs: array of observation vectors of shape ``(M, N)``. Each row represents
|
||||
a single observation. If ``obs`` is one-dimensional, then each entry is
|
||||
treated as a length-1 observation.
|
||||
code_book: array of codes with shape ``(K, N)``. Each row represents a single
|
||||
code vector. If ``code_book`` is one-dimensional, then each entry is treated
|
||||
as a length-1 code.
|
||||
check_finite: unused in JAX
|
||||
|
||||
Returns:
|
||||
A tuple of arrays ``(code, dist)``
|
||||
|
||||
- ``code`` is an integer array of shape ``(M,)`` containing indices ``0 <= i < K``
|
||||
of the closest entry in ``code_book`` for the given entry in ``obs``.
|
||||
- ``dist`` is a float array of shape ``(M,)`` containing the euclidean
|
||||
distance between each observation and the nearest code.
|
||||
|
||||
Examples:
|
||||
>>> obs = jnp.array([[1.1, 2.1, 3.1],
|
||||
... [5.9, 4.8, 6.2]])
|
||||
>>> code_book = jnp.array([[1., 2., 3.],
|
||||
... [2., 3., 4.],
|
||||
... [3., 4., 5.],
|
||||
... [4., 5., 6.]])
|
||||
>>> codes, distances = jax.scipy.cluster.vq.vq(obs, code_book)
|
||||
>>> print(codes)
|
||||
[0 3]
|
||||
>>> print(distances)
|
||||
[0.17320499 1.9209373 ]
|
||||
"""
|
||||
del check_finite # unused
|
||||
check_arraylike("scipy.cluster.vq.vq", obs, code_book)
|
||||
obs_arr, cb_arr = promote_dtypes_inexact(obs, code_book)
|
||||
if obs_arr.ndim != cb_arr.ndim:
|
||||
raise ValueError("Observation and code_book should have the same rank")
|
||||
if obs_arr.ndim == 1:
|
||||
obs_arr, cb_arr = obs_arr[..., None], cb_arr[..., None]
|
||||
if obs_arr.ndim != 2:
|
||||
raise ValueError("ndim different than 1 or 2 are not supported")
|
||||
dist = api.vmap(lambda ob: jnp_linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr)
|
||||
code = jnp.argmin(dist, axis=-1)
|
||||
dist_min = api.vmap(operator.getitem)(dist, code)
|
||||
return code, dist_min
|
||||
Reference in New Issue
Block a user