# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""JAX user-facing transformations and utilities.

The transformations here mostly wrap internal transformations, providing
convenience flags to control behavior and handling Python containers of
arguments and outputs. The Python containers handled are pytrees (see
tree_util.py), which include nested tuples/lists/dicts, where the leaves are
arrays.
"""
from __future__ import annotations

import atexit
import collections
from collections.abc import Callable, Hashable, Iterable, Sequence
import dataclasses
from functools import partial
import inspect
from typing import (Any, Literal, Optional, TypeVar, overload,
                    cast, TYPE_CHECKING)
import weakref

import numpy as np
from contextlib import contextmanager

from jax._src import api_util
from jax._src import linear_util as lu
from jax._src.tree_util import (
    tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
    tree_leaves, Partial, PyTreeDef, keystr, generate_key_paths,
    tree_flatten_with_path, equality_errors_pytreedef, register_pytree_node,
    register_dataclass)
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import array
from jax._src import basearray
from jax._src import distributed
from jax._src import dtypes
from jax._src.dtypes import canonicalize_value
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import pjit
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray, typeof
from jax._src.api_util import (
    flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
    flatten_axes, _ensure_index, apply_flat_fun_nokwargs,
    check_callable, debug_info)
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
from jax._src.mesh import get_concrete_mesh, get_abstract_mesh, Mesh
from jax._src.sharding_impls import PartitionSpec as P, NamedSharding
from jax._src.layout import Format
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src.util import unzip2, safe_map, safe_zip, wraps
from jax._src import util

from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla

config_ext = xc._xla.config


traceback_util.register_exclusion(__file__)

_dtype = dtypes.dtype

AxisName = Hashable

Device = xc.Device

# These TypeVars are used below to express the fact that function types
# (i.e. call signatures) are invariant under the vmap transformation.
F = TypeVar("F", bound=Callable)
T = TypeVar("T")
U = TypeVar("U")

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

ShapeDtypeStruct = core.ShapeDtypeStruct


@api_boundary
def _nan_check_posthook(fun, args, kwargs, output):
  """Hook function called by the C++ jit/pmap to perform NaN checking."""
  buffers = []
  for leaf in tree_leaves(output):
    if hasattr(leaf, "addressable_shards"):
      buffers.extend([shard.data for shard in leaf.addressable_shards])

  try:
    dispatch.check_special(pjit.jit_p.name, buffers)
  except api_util.InternalFloatingPointError as e:
    assert config.debug_nans.value or config.debug_infs.value
    if hasattr(fun, '_fun'):
      f = fun._fun
      if getattr(f, '_apply_primitive', False):
        raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
      # compiled_fun can only raise in this case
      api_util.maybe_recursive_nan_check(e, f, args, kwargs)
      raise AssertionError("Unreachable") from e
    else:
      # TODO(emilyaf): Shouldn't need this fallback.
      raise

_post_hook_state = config_ext.Config[Optional[Callable]](
    "post_hook", None, include_in_jit_key=False
)
jax_jit.set_post_hook_state(_post_hook_state)

def _update_debug_special_global(_):
  if config._read("jax_debug_nans") or config._read("jax_debug_infs"):
    _post_hook_state.set_global(_nan_check_posthook)
  else:
    _post_hook_state.set_global(None)

def _update_debug_special_thread_local(_):
  if (config.debug_nans.get_local() == True or
      config.debug_infs.get_local() == True):
    _post_hook_state.set_local(_nan_check_posthook)
  else:
    _post_hook_state.set_local(None)

config.debug_nans._add_hooks(_update_debug_special_global,
                             _update_debug_special_thread_local)
config.debug_infs._add_hooks(_update_debug_special_global,
                             _update_debug_special_thread_local)


float0 = dtypes.float0

class NotSpecified:
  """Sentinel for use in jax.jit"""
  def __repr__(self):
    return "<not-specified>"

@overload
def jit(
    fun: Callable, /, *,
    in_shardings: Any = ...,
    out_shardings: Any = ...,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Iterable[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    donate_argnames: str | Iterable[str] | None = ...,
    keep_unused: bool = ...,
    device: xc.Device | None = ...,
    backend: str | None = ...,
    inline: bool = ...,
    compiler_options: dict[str, Any] | None = ...,
) -> pjit.JitWrapped:
  ...

@overload
def jit(
    *,
    in_shardings: Any = ...,
    out_shardings: Any = ...,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Iterable[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    donate_argnames: str | Iterable[str] | None = ...,
    keep_unused: bool = ...,
    device: xc.Device | None = ...,
    backend: str | None = ...,
    inline: bool = ...,
    compiler_options: dict[str, Any] | None = ...,
) -> Callable[[Callable], pjit.JitWrapped]:
  ...

def jit(
    fun: Callable | NotSpecified = NotSpecified(), /, *,
    in_shardings: Any = sharding_impls.UNSPECIFIED,
    out_shardings: Any = sharding_impls.UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    static_argnames: str | Iterable[str] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    compiler_options: dict[str, Any] | None = None,
) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]:
  """Sets up ``fun`` for just-in-time compilation with XLA.

  Args:
    fun: Function to be jitted. ``fun`` should be a pure function.
      The arguments and return value of ``fun`` should be arrays, scalars, or
      (nested) standard Python containers (tuple/list/dict) thereof. Positional
      arguments indicated by ``static_argnums`` can be any hashable type. Static
      arguments are included as part of a compilation cache key, which is why
      hash and equality operators must be defined. JAX keeps a weak reference to
      ``fun`` for use as a compilation cache key, so the object ``fun`` must be
      weakly-referenceable. Starting in JAX v0.8.1, when ``fun`` is omitted,
      the return value will be a partially-evaluated function to allow the
      decorator factory pattern (see Examples below).
    in_shardings: optional, a :py:class:`Sharding` or pytree with
      :py:class:`Sharding` leaves and structure that is a tree prefix of the
      positional arguments tuple to ``fun``. If provided, the positional
      arguments passed to ``fun`` must have shardings that are compatible with
      ``in_shardings`` or an error is raised, and the compiled computation has
      input shardings corresponding to ``in_shardings``. If not provided, the
      compiled computation's input shardings are inferred from argument
      shardings.
    out_shardings: optional, a :py:class:`Sharding` or pytree with
      :py:class:`Sharding` leaves and structure that is a tree prefix of the
      output of ``fun``. If provided, it has the same effect as applying
      :py:func:`jax.lax.with_sharding_constraint` to the output of ``fun``.
    static_argnums: optional, an int or collection of ints that specify which
      positional arguments to treat as static (trace- and compile-time
      constant).

      Static arguments should be hashable, meaning both ``__hash__`` and
      ``__eq__`` are implemented, and immutable. Otherwise, they can be arbitrary
      Python objects. Calling the jitted function with different values for
      these constants will trigger recompilation. Arguments that are not
      array-like or containers thereof must be marked as static.

      If neither ``static_argnums`` nor ``static_argnames`` is provided, no
      arguments are treated as static. If ``static_argnums`` is not provided but
      ``static_argnames`` is, or vice versa, JAX uses
      :code:`inspect.signature(fun)` to find any positional arguments that
      correspond to ``static_argnames``
      (or vice versa). If both ``static_argnums`` and ``static_argnames`` are
      provided, ``inspect.signature`` is not used, and only actual
      parameters listed in either ``static_argnums`` or ``static_argnames`` will
      be treated as static.
    static_argnames: optional, a string or collection of strings specifying
      which named arguments to treat as static (compile-time constant). See the
      comment on ``static_argnums`` for details. If not
      provided but ``static_argnums`` is set, the default is based on calling
      ``inspect.signature(fun)`` to find corresponding named arguments.
    donate_argnums: optional, collection of integers to specify which positional
      argument buffers can be overwritten by the computation and marked deleted
      in the caller. It is safe to donate argument buffers if you no longer need
      them once the computation has started. In some cases XLA can make use of
      donated buffers to reduce the amount of memory needed to perform a
      computation, for example recycling one of your input buffers to store a
      result. You should not reuse buffers that you donate to a computation; JAX
      will raise an error if you try to. By default, no argument buffers are
      donated.

      If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
      arguments are donated. If ``donate_argnums`` is not provided but
      ``donate_argnames`` is, or vice versa, JAX uses
      :code:`inspect.signature(fun)` to find any positional arguments that
      correspond to ``donate_argnames``
      (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
      provided, ``inspect.signature`` is not used, and only actual
      parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
      be donated.

      For more details on buffer donation see the
      `FAQ <https://docs.jax.dev/en/latest/faq.html#buffer-donation>`_.
    donate_argnames: optional, a string or collection of strings specifying
      which named arguments are donated to the computation. See the
      comment on ``donate_argnums`` for details. If not
      provided but ``donate_argnums`` is set, the default is based on calling
      ``inspect.signature(fun)`` to find corresponding named arguments.
    keep_unused: optional boolean. If `False` (the default), arguments that JAX
      determines to be unused by `fun` *may* be dropped from resulting compiled
      XLA executables. Such arguments will not be transferred to the device nor
      provided to the underlying executable. If `True`, unused arguments will
      not be pruned.
    device: This is an experimental feature and the API is likely to change.
      Optional, the Device the jitted function will run on. (Available devices
      can be retrieved via :py:func:`jax.devices`.) The default is inherited
      from XLA's DeviceAssignment logic and is usually to use
      ``jax.devices()[0]``.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.
    inline: Optional boolean. Specify whether this function should be inlined
      into enclosing jaxprs. Default False.

  Returns:
    A wrapped version of ``fun``, set up for just-in-time compilation.

  Examples:
    In the following example, ``selu`` can be compiled into a single fused kernel
    by XLA:

    >>> import jax
    >>>
    >>> @jax.jit
    ... def selu(x, alpha=1.67, lmbda=1.05):
    ...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
    >>>
    >>> key = jax.random.key(0)
    >>> x = jax.random.normal(key, (10,))
    >>> print(selu(x))  # doctest: +SKIP
    [-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
     -0.85743 -0.78232  0.76827  0.59566 ]

    Starting in JAX v0.8.1, :func:`jit` supports the decorator factory pattern
    for specifying optional keywords:

    >>> import jax.numpy as jnp
    >>>
    >>> @jax.jit(static_argnames=['n'])
    ... def g(x, n):
    ...   for i in range(n):
    ...     x = x ** 2
    ...   return x
    >>>
    >>> g(jnp.arange(4), 3)
    Array([   0,    1,  256, 6561], dtype=int32)

    For compatibility with older JAX versions, a common pattern is to use
    :func:`functools.partial`:

    >>> from functools import partial
    >>>
    >>> @partial(jax.jit, static_argnames=['n'])
    ... def g(x, n):
    ...   for i in range(n):
    ...     x = x ** 2
    ...   return x
    >>>
    >>> g(jnp.arange(4), 3)
    Array([   0,    1,  256, 6561], dtype=int32)
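
    Buffer donation can be requested the same way. As a small sketch (donation
    is a performance feature honored only on some backends, so the call below
    is skipped in doctests):

    >>> @jax.jit(donate_argnums=(0,))
    ... def update(state, delta):
    ...   return state + delta
    >>>
    >>> new_state = update(jnp.zeros(3), jnp.ones(3))  # doctest: +SKIP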
  """
  kwds = dict(
      in_shardings=in_shardings, out_shardings=out_shardings,
      static_argnums=static_argnums, static_argnames=static_argnames,
      donate_argnums=donate_argnums, donate_argnames=donate_argnames,
      keep_unused=keep_unused, device=device, backend=backend, inline=inline,
      compiler_options=compiler_options, use_resource_env=False)
  if isinstance(fun, NotSpecified):
    return lambda fun: pjit.make_jit(fun, **kwds)
  else:
    return pjit.make_jit(fun, **kwds)

if not TYPE_CHECKING:
  # TODO(slebedev): This ought to be a decorator, but it seems it makes
  # pytype ignore the overloads
  jit = api_boundary(jit, repro_api_name="jax.jit")


@contextmanager
def disable_jit(disable: bool = True):
  """Context manager that disables :py:func:`jit` behavior under its dynamic context.

  For debugging, it is useful to have a mechanism that disables :py:func:`jit`
  everywhere in a dynamic context. Note that this not only disables explicit
  uses of :func:`jit` by the user, but will also remove any implicit JIT compilation
  used by the JAX library: this includes implicit JIT computation of `body` and
  `cond` functions passed to higher-level primitives like :func:`~jax.lax.scan` and
  :func:`~jax.lax.while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
  and any other case where :func:`jit` is used within an API's implementation.
  Note however that even under `disable_jit`, individual primitive operations
  will still be compiled by XLA as in normal eager op-by-op execution.

  Values that have a data dependence on the arguments to a jitted function are
  traced and abstracted. For example, an abstract value may be a
  :py:class:`ShapedArray` instance, representing the set of all possible arrays
  with a given shape and dtype, but not representing one concrete array with
  specific values. You might notice those if you use a benign side-effecting
  operation in a jitted function, like a print:

  >>> import jax
  >>>
  >>> @jax.jit
  ... def f(x):
  ...   y = x * 2
  ...   print("Value of y is", y)
  ...   return y + 3
  ...
  >>> print(f(jax.numpy.array([1, 2, 3])))
  Value of y is JitTracer(int32[3])
  [5 7 9]

  Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,
  which represents an array with a fixed shape and type but an arbitrary value.
  The value of ``y`` is also traced. If we want to see a concrete value while
  debugging, and avoid the tracer too, we can use the :py:func:`disable_jit`
  context manager:

  >>> import jax
  >>>
  >>> with jax.disable_jit():
  ...   print(f(jax.numpy.array([1, 2, 3])))
  ...
  Value of y is [2 4 6]
  [5 7 9]
  """
  with config.disable_jit(disable):
    yield


@partial(api_boundary, repro_api_name="jax.grad")
def grad(fun: Callable, argnums: int | Sequence[int] = 0,
         has_aux: bool = False, holomorphic: bool = False,
         allow_int: bool = False,
         reduce_axes: Sequence[AxisName] = ()) -> Callable:
  """Creates a function that evaluates the gradient of ``fun``.

  Args:
    fun: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, or standard Python containers.
      Argument arrays in the positions specified by ``argnums`` must be of
      inexact (i.e., floating-point or complex) type. It
      should return a scalar (which includes arrays with shape ``()`` but not
      arrays with shape ``(1,)`` etc.)
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the gradient
    of ``fun``. If ``argnums`` is an integer then the gradient has the same
    shape and type as the positional argument indicated by that integer. If
    argnums is a tuple of integers, the gradient is a tuple of values with the
    same shapes and types as the corresponding arguments. If ``has_aux`` is True
    then a pair of (gradient, auxiliary_data) is returned.

  For example:

  >>> import jax
  >>>
  >>> grad_tanh = jax.grad(jax.numpy.tanh)
  >>> print(grad_tanh(0.2))
  0.961043
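
  With ``has_aux=True``, ``fun`` may return a ``(scalar, aux)`` pair, and the
  wrapped function returns ``(gradient, aux)``; a minimal sketch:

  >>> import jax.numpy as jnp
  >>>
  >>> def loss(w):
  ...   l = (w ** 2).sum()
  ...   return l, {'squared_norm': l}  # (scalar to differentiate, aux data)
  >>>
  >>> g, aux = jax.grad(loss, has_aux=True)(jnp.arange(3.))
  >>> print(g)
  [0. 2. 4.]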
  """
  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to grad is deprecated")
  del reduce_axes
  value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
                                    holomorphic=holomorphic,
                                    allow_int=allow_int)

  docstr = ("Gradient of {fun} with respect to positional argument(s) "
            "{argnums}. Takes the same arguments as {fun} but returns the "
            "gradient, which has the same shape as the arguments at "
            "positions {argnums}.")

  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def grad_f(*args, **kwargs):
    _, g = value_and_grad_f(*args, **kwargs)
    return g

  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def grad_f_aux(*args, **kwargs):
    (_, aux), g = value_and_grad_f(*args, **kwargs)
    return g, aux

  return grad_f_aux if has_aux else grad_f


@partial(api_boundary, repro_api_name="jax.value_and_grad")
def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0,
                   has_aux: bool = False, holomorphic: bool = False,
                   allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
                   ) -> Callable[..., tuple[Any, Any]]:
  """Create a function that evaluates both ``fun`` and the gradient of ``fun``.

  Args:
    fun: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, or standard Python containers. It
      should return a scalar (which includes arrays with shape ``()`` but not
      arrays with shape ``(1,)`` etc.)
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.

  Returns:
    A function with the same arguments as ``fun`` that evaluates both ``fun``
    and the gradient of ``fun`` and returns them as a pair (a two-element
    tuple). If ``argnums`` is an integer then the gradient has the same shape
    and type as the positional argument indicated by that integer. If argnums is
    a sequence of integers, the gradient is a tuple of values with the same
    shapes and types as the corresponding arguments. If ``has_aux`` is True
    then a tuple of ((value, auxiliary_data), gradient) is returned.
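
  For example, applied to :func:`jax.numpy.tanh` (a small sketch; printed
  values are platform-dependent, so the output is skipped):

  >>> import jax
  >>>
  >>> value, grad_x = jax.value_and_grad(jax.numpy.tanh)(0.2)
  >>> print(value, grad_x)  # doctest: +SKIP
  0.19737533 0.9610429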
  """
  from jax._src.lax import lax as lax_internal  # pytype: disable=import-error

  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to grad is deprecated")
  del reduce_axes

  docstr = ("Value and gradient of {fun} with respect to positional "
            "argument(s) {argnums}. Takes the same arguments as {fun} but "
            "returns a two-element tuple where the first element is the value "
            "of {fun} and the second element is the gradient, which has the "
            "same shape as the arguments at positions {argnums}.")

  check_callable(fun)
  argnums = core.concrete_or_error(_ensure_index, argnums)

  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def value_and_grad_f(*args, **kwargs):
    max_argnum = argnums if isinstance(argnums, int) else max(argnums)
    if max_argnum >= len(args):
      raise TypeError(f"differentiating with respect to {argnums=} requires at least "
                      f"{max_argnum + 1} positional arguments to be passed by the caller, "
                      f"but got only {len(args)} positional arguments.")
    dbg = debug_info('value_and_grad', fun, args, kwargs)

    f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
    f_partial, dyn_args = argnums_partial(f, argnums, args,
                                          require_static_args_hashable=False)
    for leaf in tree_leaves(dyn_args):
      _check_input_dtype_grad(holomorphic, allow_int, leaf)
    if has_aux:
      ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    else:
      ans, vjp_py = _vjp(f_partial, *dyn_args)
      aux = None
    _check_scalar(ans)
    tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
    g = vjp_py(lax_internal._one_vjp(ans))
    g = g[0] if isinstance(argnums, int) else g
    if not has_aux:
      return ans, g
    else:
      return (ans, aux), g

  return value_and_grad_f

def _check_scalar(x):
  msg = "Gradient only defined for scalar-output functions. Output {}.".format
  try:
    aval = core.typeof(x)
  except TypeError as e:
    raise TypeError(msg(f"was {x}")) from e
  else:
    if isinstance(aval, ShapedArray):
      if aval.shape != ():
        raise TypeError(msg(f"had shape: {aval.shape}"))
    else:
      raise TypeError(msg(f"had abstract value {aval}"))

def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
  dispatch.check_arg(x)
  aval = core.typeof(x)
  if holomorphic:
    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
      raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
                      f"but got {aval.dtype.name}.")
  if isinstance(aval, ShapedArray):
    if (dtypes.issubdtype(aval.dtype, dtypes.extended) or
        dtypes.issubdtype(aval.dtype, np.integer) or
        dtypes.issubdtype(aval.dtype, np.bool_)):
      if not allow_int:
        raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
                        f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. "
                        "If you want to use Boolean- or integer-valued inputs, use vjp "
                        "or set allow_int to True.")
    elif not dtypes.issubdtype(aval.dtype, np.inexact):
      raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a "
                      f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.")
_check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")

def _check_output_dtype_revderiv(name, holomorphic, x):
  aval = core.typeof(x)
  if dtypes.issubdtype(aval.dtype, dtypes.extended):
    raise TypeError(
        f"{name} with output element type {aval.dtype.name}")
  if holomorphic:
    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
      raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
                      f"but got {aval.dtype.name}.")
  elif dtypes.issubdtype(aval.dtype, np.complexfloating):
    raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
                    f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
                    "For holomorphic differentiation, pass holomorphic=True. "
                    "For differentiation of non-holomorphic functions involving complex "
                    "outputs, use jax.vjp directly.")
  elif not dtypes.issubdtype(aval.dtype, np.floating):
    raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
                    f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
                    "For differentiation of functions with integer outputs, use "
                    "jax.vjp directly.")
_check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad")

@partial(api_boundary, repro_api_name="jax.fwd_and_bwd")
def fwd_and_bwd(
    fun: Callable, argnums: int | Sequence[int], has_aux: bool = False,
    jitted: bool = True,
) -> tuple[Callable, Callable]:
  """Creates functions ``fwd`` and ``bwd`` corresponding to the forward and
  backward pass of a given function ``fun``. The forward function ``fwd(*args)``
  functionally behaves much like ``y, fun_vjp = jax.vjp(fun, *args)``, but allows
  reuse of the backward function ``bwd`` across multiple iterations, which is
  useful to avoid recompilation when the forward and backward do not end up in a
  single jitted function:

  >>> import jax
  >>>
  >>> x = W = cot_out = jax.numpy.ones((4, 4))
  >>>
  >>> def f(x, W):
  ...   return x @ W
  ...
  >>> f_jitted = jax.jit(f)
  >>> for i in range(3):
  ...   y, f_vjp = jax.vjp(f_jitted, x, W)
  ...   cot_x, cot_W = f_vjp(cot_out)  # not jitted
  ...   cot_x, cot_W = jax.jit(f_vjp)(cot_out)  # recompiles on every iteration
  ...
  >>> fwd, bwd = jax.fwd_and_bwd(f, argnums=(0, 1))
  >>> for i in range(3):
  ...   y, residuals = fwd(x, W)
  ...   cot_x, cot_W = bwd(residuals, cot_out)  # jitted, compiles once
  ...

  Args:
    fun: Function to produce a forward and backward of.
    argnums: Integer or sequence of integers. Specifies which positional argument(s)
      to differentiate with respect to.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    jitted: Optional, bool. Indicates whether to return the ``jax.jit`` of
      forward and backward. Note that jit-ing only the backward but not the
      forward will result in the backward recompiling on every invocation, so we
      default to jit-ing both.

  Returns:
    The two functions, ``fwd`` and ``bwd``.

    If ``has_aux`` is ``False``, ``fwd(*primals)`` returns a tuple
    ``(primals_out, residuals)``, where ``primals_out`` is ``fun(*primals)``.
    If ``has_aux`` is ``True``, returns a ``(primals_out, residuals, aux)`` tuple
    where ``aux`` is the auxiliary data returned by ``fun``.

    ``bwd`` is a function from ``residuals`` and a cotangent vector with the same
    shape as ``primals_out`` to a tuple of cotangent vectors with the same number
    and shapes as the ``primals`` designated by ``argnums``, representing the
    vector-Jacobian product of ``fun`` evaluated at ``primals``.
  """
  check_callable(fun)
  argnums = _ensure_index(argnums)

  def fwd(*args, **kwargs):
    dbg = debug_info('fwd_and_bwd', fun, args, kwargs)
    f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
    f_partial, dyn_args = argnums_partial(
        f, argnums, args, require_static_args_hashable=False)
    return _vjp(f_partial, *dyn_args, has_aux=has_aux)
  def bwd(f_vjp, outgrad):
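    # `f_vjp` is the residuals object returned by `fwd`; applying it to the
    # output cotangent yields cotangents for the `argnums` arguments.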
    g = f_vjp(outgrad)
    g = g[0] if isinstance(argnums, int) else g
    return g
  if jitted:
    fwd = jit(fwd)
    bwd = jit(bwd)
  return fwd, bwd


@partial(api_boundary, repro_api_name="jax.jacfwd")
def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0,
           has_aux: bool = False, holomorphic: bool = False) -> Callable:
  """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.

  Args:
    fun: Function whose Jacobian is to be computed.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the Jacobian of
    ``fun`` using forward-mode automatic differentiation. If ``has_aux`` is True
    then a pair of (jacobian, auxiliary_data) is returned.

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x):
  ...   return jnp.asarray(
  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
  ...
  >>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
  [[ 1.       0.       0.     ]
   [ 0.       0.       5.     ]
   [ 0.      16.      -2.     ]
   [ 1.6209   0.       0.84147]]
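
  Forward mode computes the Jacobian one input column at a time, so it is
  typically preferable when a function has more outputs than inputs. A minimal
  sketch with a scalar input:

  >>> print(jax.jacfwd(lambda t: jnp.array([t, t ** 2]))(3.))
  [1. 6.]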
  """
  check_callable(fun)
  argnums = _ensure_index(argnums)

  docstr = ("Jacobian of {fun} with respect to positional argument(s) "
            "{argnums}. Takes the same arguments as {fun} but returns the "
            "jacobian of the output with respect to the arguments at "
            "positions {argnums}.")

  @wraps(fun, docstr=docstr, argnums=argnums)
  def jacfun(*args, **kwargs):
    f = lu.wrap_init(
        fun, kwargs,
        debug_info=debug_info(
            "jacfwd", fun, args, kwargs,
            static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
    f_partial, dyn_args = argnums_partial(f, argnums, args,
                                          require_static_args_hashable=False)
    tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
    pushfwd: Callable = partial(_jvp, f_partial, dyn_args, has_aux=has_aux)
    if has_aux:
      y, jac, aux = vmap(pushfwd, out_axes=(None, -1, None))(_std_basis(dyn_args))
    else:
      y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
      aux = None
    tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac)
    if not has_aux:
      return jac_tree
    else:
      return jac_tree, aux

  return jacfun

def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
  dispatch.check_arg(x)
  aval = core.typeof(x)
  if dtypes.issubdtype(aval.dtype, dtypes.extended):
    raise TypeError(
        f"jacfwd with input element type {aval.dtype.name}")
  if holomorphic:
    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
      raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
                      f"dtype, but got {aval.dtype.name}.")
  elif not dtypes.issubdtype(aval.dtype, np.floating):
    raise TypeError("jacfwd requires real-valued inputs (input dtype that is "
                    f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
                    "For holomorphic differentiation, pass holomorphic=True. "
                    "For differentiation of non-holomorphic functions involving "
                    "complex inputs or integer inputs, use jax.jvp directly.")

def _check_output_dtype_jacfwd(holomorphic, x):
  aval = core.typeof(x)
  if holomorphic:
    if not dtypes.issubdtype(aval.dtype, np.complexfloating):
      raise TypeError("jacfwd with holomorphic=True requires outputs with complex dtype, "
                      f"but got {aval.dtype.name}.")

@partial(api_boundary, repro_api_name="jax.jacrev")
def jacrev(fun: Callable, argnums: int | Sequence[int] = 0,
           has_aux: bool = False, holomorphic: bool = False,
           allow_int: bool = False) -> Callable:
  """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.

  Args:
    fun: Function whose Jacobian is to be computed.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the Jacobian of
    ``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True
    then a pair of (jacobian, auxiliary_data) is returned.

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x):
  ...   return jnp.asarray(
  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
  ...
  >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
  [[ 1.       0.       0.     ]
   [ 0.       0.       5.     ]
   [ 0.      16.      -2.     ]
   [ 1.6209   0.       0.84147]]
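
  Reverse mode computes the Jacobian one output row at a time, so it is
  typically preferable when a function has more inputs than outputs. As a
  quick sketch, for a scalar-valued function the single row is the gradient:

  >>> print(jax.jacrev(jnp.sum)(jnp.arange(3.)))
  [1. 1. 1.]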
  """
  check_callable(fun)

  docstr = ("Jacobian of {fun} with respect to positional argument(s) "
            "{argnums}. Takes the same arguments as {fun} but returns the "
            "jacobian of the output with respect to the arguments at "
            "positions {argnums}.")

  @wraps(fun, docstr=docstr, argnums=argnums)
  def jacfun(*args, **kwargs):
    f = lu.wrap_init(
        fun, kwargs,
        debug_info=debug_info(
            "jacrev", fun, args, kwargs,
            static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
    f_partial, dyn_args = argnums_partial(f, argnums, args,
                                          require_static_args_hashable=False)
    tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
    if has_aux:
      y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    else:
      y, pullback = _vjp(f_partial, *dyn_args)
      aux = None
    tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
    jac = vmap(pullback)(_std_basis(y))
    jac = jac[0] if isinstance(argnums, int) else jac
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
    jac_tree = tree_transpose(tree_structure(example_args), tree_structure(y), jac_tree)
    if not has_aux:
      return jac_tree
    else:
      return jac_tree, aux

  return jacfun


def jacobian(fun: Callable, argnums: int | Sequence[int] = 0,
             has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable:
  """Alias of :func:`jax.jacrev`."""
  return jacrev(fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int)


_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev")
_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev")


@partial(api_boundary, repro_api_name="jax.hessian")
def hessian(fun: Callable, argnums: int | Sequence[int] = 0,
            has_aux: bool = False, holomorphic: bool = False) -> Callable:
  """Hessian of ``fun`` as a dense array.

  Args:
    fun: Function whose Hessian is to be computed. Its arguments at positions
      specified by ``argnums`` should be arrays, scalars, or standard Python
      containers thereof. It should return arrays, scalars, or standard Python
      containers thereof.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the Hessian of
    ``fun``.

  >>> import jax
  >>>
  >>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
  >>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
  [[   6.   -2.]
   [  -2. -480.]]

  :py:func:`hessian` is a generalization of the usual definition of the Hessian
  that supports nested Python containers (i.e. pytrees) as inputs and outputs.
  The tree structure of ``jax.hessian(fun)(x)`` is given by forming a tree
  product of the structure of ``fun(x)`` with a tree product of two copies of
  the structure of ``x``. A tree product of two tree structures is formed by
  replacing each leaf of the first tree with a copy of the second. For example:

  >>> import jax.numpy as jnp
  >>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
  >>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
  {'c': {'a': {'a': Array([[[ 2.,  0.], [ 0.,  0.]],
                           [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
               'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                           [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
         'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                           [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
               'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                           [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}

  Thus each leaf in the tree structure of ``jax.hessian(fun)(x)`` corresponds to
  a leaf of ``fun(x)`` and a pair of leaves of ``x``. For each leaf in
  ``jax.hessian(fun)(x)``, if the corresponding array leaf of ``fun(x)`` has
  shape ``(out_1, out_2, ...)`` and the corresponding array leaves of ``x`` have
  shape ``(in_1_1, in_1_2, ...)`` and ``(in_2_1, in_2_2, ...)`` respectively,
  then the Hessian leaf has shape ``(out_1, out_2, ..., in_1_1, in_1_2, ...,
  in_2_1, in_2_2, ...)``. In other words, the Python tree structure represents
  the block structure of the Hessian, with blocks determined by the input and
  output pytrees.

  In particular, an array is produced (with no pytrees involved) when the
  function input ``x`` and output ``fun(x)`` are each a single array, as in the
  ``g`` example above. If ``fun(x)`` has shape ``(out1, out2, ...)`` and ``x``
  has shape ``(in1, in2, ...)`` then ``jax.hessian(fun)(x)`` has shape
  ``(out1, out2, ..., in1, in2, ..., in1, in2, ...)``. To flatten pytrees into
  1D vectors, consider using :py:func:`jax.flatten_util.ravel_pytree`.
  """
  return jacfwd(jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic),
                argnums, has_aux=has_aux, holomorphic=holomorphic)

def _insert_pvary(basis, leaf):
  if not config._check_vma.value:
    return basis
  return core.pvary(basis, tuple(core.typeof(leaf).mat.varying))

def _std_basis(pytree):
  import jax.numpy as jnp  # pytype: disable=import-error
  leaves, _ = tree_flatten(pytree)
  ndim = sum(map(np.size, leaves))
  dtype = dtypes.result_type(*leaves)
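  # Each row of the identity matrix below is a one-hot (co)tangent vector;
  # vmapping a JVP/VJP over these rows yields the full Jacobian.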
  flat_basis = jnp.eye(ndim, dtype=dtype)
  axis = 1
  arr_s = [None] * flat_basis.ndim
  specs = tree_map(lambda l: P(*arr_s[:axis], *core.typeof(l).sharding.spec,
                               *arr_s[axis+1:]), pytree)
  out_pytree = _unravel_array_into_pytree(pytree, axis, None, flat_basis, specs)
  out_pytree = tree_map(_insert_pvary, out_pytree, pytree)
  return out_pytree

def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr):
  axis = -1 % arr.ndim
  arr_s = core.typeof(arr).sharding.spec
  specs = tree_map(
      lambda l: P(*arr_s[:axis], *[None] * len(np.shape(l)), *arr_s[axis+1:]),
      input_pytree)
  return _unravel_array_into_pytree(
      input_pytree, axis, output_pytree_leaf, arr, specs)

def _jacrev_unravel(output_pytree, input_pytree_leaf, arr):
  specs = tree_map(
      lambda l: P(*[None] * len(np.shape(l)), *core.typeof(arr).sharding.spec[1:]),
      output_pytree)
  return _unravel_array_into_pytree(
      output_pytree, 0, input_pytree_leaf, arr, specs)

def _possible_downcast(x, example, spec):
  from jax._src.lax import lax as lax_internal  # pytype: disable=import-error
  if (dtypes.issubdtype(x.dtype, np.complexfloating) and
      not dtypes.issubdtype(_dtype(example), np.complexfloating)):
    x = x.real
  dtype = _dtype(example)
  weak_type = dtypes.is_weakly_typed(example)
  sharding = NamedSharding(core.typeof(example).sharding.mesh, spec)
  return lax_internal._convert_element_type(
      x, dtype, weak_type, sharding=sharding)

def _unravel_array_into_pytree(pytree, axis, example, arr, specs):
  """Unravel an array into a PyTree with a given structure.

  Args:
    pytree: The pytree that provides the structure.
    axis: The parameter axis is either -1, 0, or 1. It controls the
      resulting shapes.
    example: If specified, cast the components to the matching dtype/weak_type,
      or else use the pytree leaf type if example is None.
    arr: The array to be unraveled.
  """
  leaves, treedef = tree_flatten(pytree)
  specs, _ = tree_flatten(specs)
  shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis+1:] for l in leaves]
  parts = _split(arr, np.cumsum(map(np.size, leaves[:-1])), axis)
  reshaped_parts = [
      _possible_downcast(np.reshape(x, shape),
                         leaf if example is None else example,
                         spec=spec)
      for x, shape, leaf, spec in zip(parts, shapes, leaves, specs)]
  return tree_unflatten(treedef, reshaped_parts)

def _split(x, indices, axis):
  if isinstance(x, np.ndarray):
    return np.split(x, indices, axis)
  else:
    return x._split(indices, axis)


@partial(api_boundary, repro_api_name="jax.vmap")
def vmap(fun: F,
         in_axes: int | None | Sequence[Any] = 0,
         out_axes: Any = 0,
         axis_name: AxisName | None = None,
         axis_size: int | None = None,
         spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
         sum_match: bool = False
         ) -> F:
  """Vectorizing map. Creates a function which maps ``fun`` over argument axes.

  Args:
    fun: Function to be mapped over additional axes.
    in_axes: An integer, None, or sequence of values specifying which input
      array axes to map over.

      If each positional argument to ``fun`` is an array, then ``in_axes`` can
      be an integer, a None, or a tuple of integers and Nones with length equal
      to the number of positional arguments to ``fun``. An integer or ``None``
      indicates which array axis to map over for all arguments (with ``None``
      indicating not to map any axis), and a tuple indicates which axis to map
      for each corresponding positional argument. Axis integers must be in the
      range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
      dimensions (axes) of the corresponding input array.

      If the positional arguments to ``fun`` are container (pytree) types, ``in_axes``
      must be a sequence with length equal to the number of positional arguments to
      ``fun``, and for each argument the corresponding element of ``in_axes`` can
      be a container with a matching pytree structure specifying the mapping of its
      container elements. In other words, ``in_axes`` must be a container tree prefix
      of the positional argument tuple passed to ``fun``. See this link for more detail:
      https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees

      Either ``axis_size`` must be provided explicitly, or at least one
      positional argument must have ``in_axes`` not None. The sizes of the
      mapped input axes for all mapped positional arguments must all be equal.

      Arguments passed as keywords are always mapped over their leading axis
      (i.e. axis index 0).

      See below for examples.

    out_axes: An integer, None, or (nested) standard Python container
      (tuple/list/dict) thereof indicating where the mapped axis should appear
      in the output. All outputs with a mapped axis must have a non-None
      ``out_axes`` specification. Axis integers must be in the range ``[-ndim,
      ndim)`` for each output array, where ``ndim`` is the number of dimensions
      (axes) of the array returned by the :func:`vmap`-ed function, which is one
      more than the number of dimensions (axes) of the corresponding array
      returned by ``fun``.
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
    axis_size: Optional, an integer indicating the size of the axis to be
      mapped. If not provided, the mapped axis size is inferred from arguments.

  Returns:
    Batched/vectorized version of ``fun`` with arguments that correspond to
    those of ``fun``, but with extra array axes at positions indicated by
    ``in_axes``, and a return value that corresponds to that of ``fun``, but
    with extra array axes at positions indicated by ``out_axes``.

  For example, we can implement a matrix-matrix product using a vector dot
  product:

  >>> import jax.numpy as jnp
  >>> from jax import vmap
  >>>
  >>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
  >>> mv = vmap(vv, (0, None), 0)       #  ([b,a], [a]) -> [b]      (b is the mapped axis)
  >>> mm = vmap(mv, (None, 1), 1)       #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)

  Here we use ``[a,b]`` to indicate an array with shape (a,b). Here are some
  variants:

  >>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
  >>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
  >>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)

  Here's an example of using container types in ``in_axes`` to specify which
  axes of the container elements to map over:

  >>> A, B, C, D = 2, 3, 4, 5
  >>> x = jnp.ones((A, B))
  >>> y = jnp.ones((B, C))
  >>> z = jnp.ones((C, D))
  >>> def foo(tree_arg):
  ...   x, (y, z) = tree_arg
  ...   return jnp.dot(x, jnp.dot(y, z))
  >>> tree = (x, (y, z))
  >>> print(foo(tree))
  [[12. 12. 12. 12. 12.]
   [12. 12. 12. 12. 12.]]
  >>> K = 6  # batch size
  >>> x = jnp.ones((K, A, B))  # batch axis in different locations
  >>> y = jnp.ones((B, K, C))
  >>> z = jnp.ones((C, D, K))
  >>> tree = (x, (y, z))
  >>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
  >>> print(vfoo(tree).shape)
  (6, 2, 5)

  Here's another example using container types in ``in_axes``, this time a
  dictionary, to specify the elements of the container to map over:

  >>> dct = {'a': 0., 'b': jnp.arange(5.)}
  >>> x = 1.
  >>> def foo(dct, x):
  ...   return dct['a'] + dct['b'] + x
  >>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
  >>> print(out)
  [1. 2. 3. 4. 5.]

  The results of a vectorized function can be mapped or unmapped. For example,
  the function below returns a pair with the first element mapped and the second
  unmapped. Only for unmapped results can we specify ``out_axes`` to be ``None``
  (to keep it unmapped).

  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
  (Array([4., 5.], dtype=float32), 8.0)

  If the ``out_axes`` is specified for an unmapped result, the result is
  broadcast across the mapped axis:

  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
  (Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))
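
  Keyword arguments are always mapped over their leading axis; a minimal
  sketch:

  >>> print(vmap(lambda x, y: x + y)(jnp.arange(3.), y=jnp.arange(3.)))
  [0. 2. 4.]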

  If the ``out_axes`` is specified for a mapped result, the result is transposed
  accordingly.

  Finally, here's an example using ``axis_name`` together with collectives:

  >>> from jax import lax
  >>> xs = jnp.arange(3. * 4.).reshape(3, 4)
  >>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
  [[12. 15. 18. 21.]
   [12. 15. 18. 21.]
   [12. 15. 18. 21.]]

  See the :py:func:`jax.pmap` docstring for more examples involving collectives.
  """
  check_callable(fun)
  docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
            "but with additional array axes over which {fun} is mapped.")
  if fun.__doc__:
    docstr += "\n\nOriginal documentation:\n\n"
    docstr += fun.__doc__

  axis_name = core.no_axis_name if axis_name is None else axis_name
  if spmd_axis_name is not None and not isinstance(spmd_axis_name, tuple):
    spmd_axis_name = (spmd_axis_name,)

  if isinstance(in_axes, list):
    # To be a tree prefix of the positional args tuple, in_axes can never be a
    # list: if in_axes is not a leaf, it must be a tuple of trees. However,
    # in cases like these users expect tuples and lists to be treated
    # essentially interchangeably, so we canonicalize lists to tuples here
    # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
    in_axes = tuple(in_axes)

  from jax._src import hijax  # pytype: disable=import-error
  if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}
          or isinstance(in_axes, hijax.MappingSpec)):
    raise TypeError("vmap in_axes must be an int, None, or a tuple of entries corresponding "
                    f"to the positional arguments passed to the function, but got {in_axes}.")
  if not all(type(l) in {int, *batching.spec_types} or isinstance(l, hijax.MappingSpec)
             for l in tree_leaves(in_axes)):
    raise TypeError("vmap in_axes must be an int, None, or (nested) container "
                    f"with those types as leaves, but got {in_axes}.")
  if not all(type(l) in {int, *batching.spec_types} or isinstance(l, hijax.MappingSpec)
             for l in tree_leaves(out_axes)):
    raise TypeError("vmap out_axes must be an int, None, or (nested) container "
                    f"with those types as leaves, but got {out_axes}.")

  @wraps(fun, docstr=docstr)
  @api_boundary
  def vmap_f(*args, **kwargs):
    nonlocal spmd_axis_name
    if isinstance(in_axes, tuple) and len(in_axes) != len(args):
      raise ValueError("vmap in_axes must be an int, None, or a tuple of entries corresponding "
                       "to the positional arguments passed to the function, "
                       f"but got {len(in_axes)=}, {len(args)=}")

    args_flat, in_tree = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
    dbg = debug_info("vmap", fun, args, kwargs)
    api_util.check_no_transformed_refs_args(lambda: dbg, args_flat)
    f = lu.wrap_init(fun, debug_info=dbg)
    flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)

    if config.mutable_array_checks.value:
      avals = [None if d is None or batching.is_vmappable(x) else core.typeof(x)
               for x, d in zip(args_flat, in_axes_flat)]
      api_util.check_no_aliased_ref_args(lambda: dbg, avals, args_flat)

    axis_size_ = _mapped_axis_size(
        fun, in_tree, args_flat, in_axes_flat, "vmap", axis_size=axis_size)
    explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat)
    _check_ema_unmapped_args(explicit_mesh_axis, args_flat, in_axes_flat)
    if spmd_axis_name is not None and explicit_mesh_axis is not None:
      if config.remove_size_one_mesh_axis_from_type.value:
        mesh = get_abstract_mesh()
        spmd_axis_name = tuple(i for i in spmd_axis_name if mesh.shape[i] != 1)
      if spmd_axis_name == explicit_mesh_axis:
        spmd_axis_name = None
      else:
        raise ValueError(
            "Only one of spmd_axis_name or arrays sharded on `Explicit` mesh"
            f" axis type is allowed. Got {spmd_axis_name=} and"
            f" arrays sharded on {explicit_mesh_axis=}")
      assert spmd_axis_name is None
    try:
      axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name,
                                    explicit_mesh_axis)
      out_flat, inferred_out_axes = batching.batch(
          flat_fun, axis_data, in_axes_flat,
          lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
          sum_match=sum_match
      ).call_wrapped(*args_flat)
    except batching.SpecMatchError as e:
      out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)
      out_axes_full = tree_unflatten(out_tree(), out_axes_flat)
      pairs, _ = tree_flatten_with_path(out_axes_full, is_leaf=lambda x: x is None)

      path, _ = pairs[e.leaf_idx]
      raise ValueError(f'at vmap out_axes{keystr(path)}, got axis spec {e.dst} '
                       f'but output was batched on axis {e.src}') from None
    if any(d is batching.infer for d in tree_leaves(out_axes)):
      return (tree_unflatten(out_tree(), out_flat),
              tree_unflatten(out_tree(), inferred_out_axes))
    else:
      return tree_unflatten(out_tree(), out_flat)

  return cast(F, vmap_f)

def _mapped_axis_spec(args_flat, in_axes_flat):
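  # Returns the sharding spec of the mapped-away dimension, checking that all
  # mapped arguments agree on it; returns None when no spec is available.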
|
|
def _get_spec(arg, i):
|
|
try:
|
|
# Duck type arrays like BCOO arrays can be passed to vmap.
|
|
return shaped_abstractify(arg).sharding.spec[i]
|
|
except (IndexError, TypeError, AttributeError):
|
|
return None
|
|
|
|
out_spec = None
|
|
non_none_count = 0
|
|
for arg, i in zip(args_flat, in_axes_flat):
|
|
if i is not None:
|
|
spec = _get_spec(arg, i)
|
|
if non_none_count != 0 and out_spec != spec:
|
|
raise ValueError(
|
|
"Mapped away dimension of inputs passed to vmap should be sharded"
|
|
f" the same. Got inconsistent axis specs: {out_spec} vs {spec}")
|
|
out_spec = spec
|
|
non_none_count += 1
|
|
if out_spec is not None and not isinstance(out_spec, tuple):
|
|
out_spec = (out_spec,)
|
|
return out_spec
|
|
|
|
def _check_ema_unmapped_args(ema, args_flat, in_axes_flat):
|
|
if ema is None:
|
|
return
|
|
for a, i in zip(args_flat, in_axes_flat):
|
|
if i is None:
|
|
aval = core.typeof(a)
|
|
spec = set(sharding_impls.flatten_spec(aval.sharding.spec))
|
|
if any(e in spec for e in ema):
|
|
raise ValueError(
|
|
"Unmapped values passed to vmap cannot be sharded along the mesh"
|
|
f" axis you are vmapping over. Got type: {aval.str_short(True)},"
|
|
f" in_axes: {i} and vmapped mesh axis: {ema}")

def _mapped_axis_size(fn, tree, vals, dims, name, axis_size=None):
  if not vals:
    if axis_size is not None:
      return axis_size
    args, kwargs = tree_unflatten(tree, vals)
    raise ValueError(
        f"{name} wrapped function must be passed at least one argument "
        "containing an array or axis_size must be specified, got empty "
        f"*args={args} and **kwargs={kwargs}"
    )

  def _get_axis_size(name: str, x, shape: tuple[core.AxisSize, ...], axis: int
                     ) -> core.AxisSize | None:
    try:
      return shape[axis]
    except (IndexError, TypeError) as e:
      if not core.valid_jaxtype(x) or not isinstance(axis, int):
        return None  # Suppress the check for custom vmappable types.
      min_rank = axis + 1 if axis >= 0 else -axis
      # TODO(mattjj): better error message here
      raise ValueError(
          f"{name} was requested to map its argument along axis {axis}, "
          f"which implies that its rank should be at least {min_rank}, "
          f"but is only {len(shape)} (its shape is {shape})") from e

  all_mapped_sizes = [
      None if d is None else _get_axis_size(name, x, np.shape(x), d)
      for x, d in zip(vals, dims)
  ]
  all_sizes = [s for s in all_mapped_sizes if s is not None]
  if axis_size is not None:
    all_sizes.append(axis_size)
  sizes = core.dedup_referents(all_sizes)
  if len(sizes) == 1:
    sz, = sizes
    return sz
  if not sizes:
    raise ValueError(f"{name} must have at least one non-None value in in_axes "
                     "or axis_size must be specified")

  def _get_argument_type(x):
    try:
      return shaped_abstractify(x).str_short()
    except TypeError:  # Catch-all for user-specified objects that can't be interpreted as a data type.
      return "unknown"
  msg = [f"{name} got inconsistent sizes for array axes to be mapped:\n"]
  args, kwargs = tree_unflatten(tree, vals)
  try:
    ba = inspect.signature(fn).bind(*args, **kwargs)
    signature_parameters: list[str] | None = list(ba.signature.parameters.keys())
  except (TypeError, ValueError):
    signature_parameters = None

  def arg_name(key_path):
    if signature_parameters is None:
      return f"args{keystr(key_path)}"
    # args is a tuple, so key_path[0].idx is the index into args.
    i = key_path[0].idx
    # This can happen with star arguments (*args).
    if i >= len(signature_parameters):
      return f"args{keystr(key_path)}"
    res = f"argument {signature_parameters[i]}"
    if len(key_path) > 1:
      res += keystr(key_path[1:])
    return res

  args_paths = [
      f"{arg_name(p)} of type {_get_argument_type(x)}"
      for (p, x) in generate_key_paths(args)
  ]
  kwargs_paths = [
      f"kwargs{keystr(p)} of type {_get_argument_type(x)}"
      for p, x in generate_key_paths(kwargs)
  ]
  key_paths = [*args_paths, *kwargs_paths]
  size_counts = collections.Counter(s for s in all_mapped_sizes if s is not None)
  (sz, ct), *other_counts = counts = size_counts.most_common()

  def _all_sizes_index(sz):
    for i, isz in enumerate(all_mapped_sizes):
      if core.definitely_equal(isz, sz): return i
    assert False, (sz, all_mapped_sizes)

  ex, *examples = (key_paths[_all_sizes_index(sz)] for sz, _ in counts)
  ax, *axs = (dims[_all_sizes_index(sz)] for sz, _ in counts)

  if axis_size is not None:
    msg.append(f" * the `axis_size` argument was {axis_size};\n")
  if ct == 1:
    msg.append(f" * one axis had size {sz}: axis {ax} of {ex};\n")
  else:
    msg.append(f" * most axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n")
  for ex, ax, (sz, ct) in zip(examples, axs, other_counts):
    if ct == 1:
      msg.append(f" * one axis had size {sz}: axis {ax} of {ex};\n")
    else:
      msg.append(f" * some axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n")
  raise ValueError(''.join(msg)[:-2])  # remove the last semicolon and newline


@partial(api_boundary, repro_api_name="jax.jvp")
def jvp(
    fun: Callable, primals, tangents, has_aux: bool = False
) -> tuple[Any, ...]:
  """Computes a (forward-mode) Jacobian-vector product of ``fun``.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Jacobian of ``fun`` should be
      evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``.
    tangents: The tangent vector for which the Jacobian-vector product should be
      evaluated. Should be either a tuple or a list of tangents, with the same
      tree structure and array shapes as ``primals``.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.

  Returns:
    If ``has_aux`` is ``False``, returns a ``(primals_out, tangents_out)`` pair,
    where ``primals_out`` is ``fun(*primals)``,
    and ``tangents_out`` is the Jacobian-vector product of
    ``fun`` evaluated at ``primals`` with ``tangents``. The
    ``tangents_out`` value has the same Python tree structure and shapes as
    ``primals_out``. If ``has_aux`` is ``True``, returns a
    ``(primals_out, tangents_out, aux)`` tuple where ``aux``
    is the auxiliary data returned by ``fun``.

  For example:

  >>> import jax
  >>>
  >>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
  >>> print(primals)
  0.09983342
  >>> print(tangents)
  0.19900084
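
  When ``fun`` also returns auxiliary data, pass ``has_aux=True`` to receive it
  alongside the primal and tangent outputs (a small illustrative sketch):

  >>> f = lambda x: (jax.numpy.sin(x), x ** 2)  # second output is aux data
  >>> primals, tangents, aux = jax.jvp(f, (0.1,), (0.2,), has_aux=True)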
  """
  check_callable(fun)
  if (not isinstance(primals, (tuple, list)) or
      not isinstance(tangents, (tuple, list))):
    raise TypeError("primal and tangent arguments to jax.jvp must be tuples or lists; "
                    f"found {type(primals).__name__} and {type(tangents).__name__}.")
  return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})),
              primals, tangents, has_aux=has_aux)

def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
  """Variant of jvp() that takes an lu.WrappedFun."""
  ps_flat, tree_def = tree_flatten(primals)
  ts_flat, tree_def_2 = tree_flatten(tangents)
  if tree_def != tree_def_2:
    raise TypeError("primal and tangent arguments to jax.jvp must have the same tree "
                    f"structure; primals have tree structure {tree_def} whereas tangents have "
                    f"tree structure {tree_def_2}.")
  for p, t in zip(ps_flat, ts_flat):
    if not isinstance(core.typeof(p), ShapedArray): continue
    if core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t):
      raise TypeError("primal and tangent arguments to jax.jvp do not match; "
                      "dtypes must be equal, or in case of int/bool primal dtype "
                      "the tangent dtype must be float0. "
                      f"Got primal dtype {_dtype(p)} and so expected tangent dtype "
                      f"{core.primal_dtype_to_tangent_dtype(_dtype(p))}, but got "
                      f"tangent dtype {_dtype(t)} instead.")
    if np.shape(p) != np.shape(t):
      raise ValueError("jvp called with different primal and tangent shapes; "
                       f"got primal shape {np.shape(p)} and tangent shape {np.shape(t)}")

  if not has_aux:
    flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def)
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
    out_tree = out_tree()
    return (tree_unflatten(out_tree, out_primals),
            tree_unflatten(out_tree, out_tangents))
  else:
    flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, tree_def)
    jvp_fun, aux = ad.jvp(flat_fun, has_aux=True)
    out_primals, out_tangents = jvp_fun.call_wrapped(ps_flat, ts_flat)
    out_tree, aux_tree = out_aux_trees()
    return (tree_unflatten(out_tree, out_primals),
            tree_unflatten(out_tree, out_tangents),
            tree_unflatten(aux_tree, aux()))

@overload
def linearize(fun: Callable, *primals, has_aux: Literal[False] = False
              ) -> tuple[Any, Callable]:
  ...

@overload
def linearize(fun: Callable, *primals, has_aux: Literal[True]
              ) -> tuple[Any, Callable, Any]:
  ...

@partial(api_boundary, repro_api_name="jax.linearize")
def linearize(fun: Callable, *primals, has_aux: bool = False
              ) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
  """Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Jacobian of ``fun`` should be
      evaluated. Should be a tuple of arrays, scalar, or standard Python
      container thereof. The length of the tuple is equal to the number of
      positional parameters of ``fun``.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first
      element is considered the output of the mathematical function to be linearized,
      and the second is auxiliary data. Default False.

  Returns:
    If ``has_aux`` is ``False``, returns a pair where the first element is the value of
    ``f(*primals)`` and the second element is a function that evaluates the
    (forward-mode) Jacobian-vector product of ``fun`` evaluated at ``primals`` without
    re-doing the linearization work. If ``has_aux`` is ``True``, returns a
    ``(primals_out, lin_fn, aux)`` tuple where ``aux`` is the auxiliary data returned by
    ``fun``.

  In terms of values computed, :py:func:`linearize` behaves much like a curried
  :py:func:`jvp`, where these two code blocks compute the same values::

    y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

    y, f_jvp = jax.linearize(f, x)
    out_tangent = f_jvp(in_tangent)

  However, the difference is that :py:func:`linearize` uses partial evaluation
  so that the function ``f`` is not re-linearized on calls to ``f_jvp``. In
  general that means the memory usage scales with the size of the computation,
  much like in reverse-mode. (Indeed, :py:func:`linearize` has a similar
  signature to :py:func:`vjp`!)

  This function is mainly useful if you want to apply ``f_jvp`` multiple times,
  i.e. to evaluate a pushforward for many different input tangent vectors at the
  same linearization point. Moreover if all the input tangent vectors are known
  at once, it can be more efficient to vectorize using :py:func:`vmap`, as in::

    pushfwd = partial(jvp, f, (x,))
    y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

  By using :py:func:`vmap` and :py:func:`jvp` together like this we avoid the stored-linearization
  memory cost that scales with the depth of the computation, which is incurred
  by both :py:func:`linearize` and :py:func:`vjp`.

  Here's a more complete example of using :py:func:`linearize`:

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
  ...
  >>> jax.jvp(f, (2.,), (3.,))
  (Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
  >>> y, f_jvp = jax.linearize(f, 2.)
  >>> print(y)
  3.2681944
  >>> print(f_jvp(3.))
  -5.007528
  >>> print(f_jvp(4.))
  -6.676704
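
  When ``fun`` also returns auxiliary data, ``has_aux=True`` returns it as a
  third element (a small illustrative sketch; ``f_with_aux`` and its aux dict
  are hypothetical names):

  >>> def f_with_aux(x):
  ...   return 3. * jnp.sin(x) + jnp.cos(x / 2.), {'hint': x}
  ...
  >>> y, f_jvp, aux = jax.linearize(f_with_aux, 2., has_aux=True)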
  """
  check_callable(fun)
  f = lu.wrap_init(fun, debug_info=debug_info("linearize", fun, primals, {}))
  primals_flat, in_tree = tree_flatten(primals)
  if has_aux:
    jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree)
  else:
    jaxtree_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
  out_primals, out_pvals, jaxpr, consts, *maybe_aux = ad.linearize(
      jaxtree_fun, *primals_flat, has_aux=has_aux)
  if has_aux:
    out_tree, aux_tree = out_tree()
  else:
    out_tree, aux_tree = out_tree(), None
  out_primal_py = tree_unflatten(out_tree, out_primals)
  primal_avals = list(map(core.typeof, primals_flat))
  # Ensure that lifted_jvp is a PyTree
  lifted_jvp = Partial(partial(_lift_linearized, jaxpr, primal_avals,
                               (in_tree, out_tree), out_pvals), consts)
  if has_aux:
    [aux] = maybe_aux
    assert aux_tree is not None
    return out_primal_py, lifted_jvp, tree_unflatten(aux_tree, aux)
  else:
    [] = maybe_aux
    return out_primal_py, lifted_jvp

def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
  def fun(*tangents):
    tangent_avals = list(map(core.typeof, tangents))
    for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
      expected_tangent_aval = primal_aval.to_tangent_aval()
      if not core.typecompat(expected_tangent_aval, tangent_aval):
        extra_msg = ''
        if (isinstance(primal_aval, core.ShapedArray) and
            isinstance(tangent_aval, core.ShapedArray) and
            primal_aval.mat != tangent_aval.mat):
          # TODO(yashkatariya): Tweak error.
          pvary_applications = []
          if left := tangent_aval.mat.varying - primal_aval.mat.varying:
            pvary_applications.append(
                f"applying `jax.lax.pcast(..., {tuple(left)}, to='varying')` to"
                " the primal value passed to `jax.linearize`")
          if left := primal_aval.mat.varying - tangent_aval.mat.varying:
            pvary_applications.append(
                f"applying `jax.lax.pcast(..., {tuple(left)}, to='varying')` to"
                " the tangent value passed to the callable `f_jvp` returned by"
                " `jax.linearize`")
          extra_msg = " \nThis might be fixed by:\n" + "\n".join(
              f" * {d};" for d in pvary_applications)
        raise ValueError(
            "linearized function called on tangent values inconsistent with "
            "the original primal values:\n"
            f"Got tangent aval {tangent_aval} for primal aval {primal_aval} "
            f"but expected {expected_tangent_aval}.{extra_msg}")
    tangents_out = eval_jaxpr(jaxpr, consts, *tangents)
    tangents_out_ = iter(tangents_out)
    full_out = [pval.get_known() if pval.is_known() else next(tangents_out_)
                for pval in out_pvals]
    assert next(tangents_out_, None) is None
    return full_out

  return apply_flat_fun_nokwargs(fun, io_tree, py_args)

# TODO(mattjj): see similar function in custom_derivatives.py
def _temporary_dtype_exception(a, a_) -> bool:
  if isinstance(a, core.ShapedArray) and isinstance(a_, core.ShapedArray):
    return a.shape == a_.shape and a_.dtype == dtypes.float0
  return False

@overload
def vjp(fun: Callable[..., T],
        *primals: Any,
        has_aux: Literal[False] = False,
        reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable]:
  ...

@overload
def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
        has_aux: Literal[True],
        reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]:
  ...

@partial(api_boundary, repro_api_name="jax.vjp")
def vjp(
    fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
  """Compute a (reverse-mode) vector-Jacobian product of ``fun``.

  :py:func:`grad` is implemented as a special case of :py:func:`vjp`.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: A sequence of primal values at which the Jacobian of ``fun``
      should be evaluated. The number of ``primals`` should be equal to the
      number of positional parameters of ``fun``. Each primal value should be
      an array, a scalar, or a pytree (standard Python containers) thereof.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.

  Returns:
    If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where
    ``primals_out`` is ``fun(*primals)``. If ``has_aux`` is ``True``, returns a
    ``(primals_out, vjpfun, aux)`` tuple where ``aux`` is the auxiliary data
    returned by ``fun``.

    ``vjpfun`` is a function from a cotangent vector with the same shape as
    ``primals_out`` to a tuple of cotangent vectors with the same number and
    shapes as ``primals``, representing the vector-Jacobian product of ``fun``
    evaluated at ``primals``.

  >>> import jax
  >>>
  >>> def f(x, y):
  ...   return jax.numpy.sin(x), jax.numpy.cos(y)
  ...
  >>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
  >>> xbar, ybar = f_vjp((-0.7, 0.3))
  >>> print(xbar)
  -0.61430776
  >>> print(ybar)
  -0.2524413
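
  When ``fun`` also returns auxiliary data, pass ``has_aux=True`` (a small
  illustrative sketch):

  >>> def g(x):
  ...   return jax.numpy.sin(x), x ** 2  # second output is aux data
  ...
  >>> primals, g_vjp, aux = jax.vjp(g, 0.5, has_aux=True)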
  """
  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to vjp is deprecated")
  del reduce_axes
  check_callable(fun)
  wrapped_fun = lu.wrap_init(
      fun, debug_info=debug_info("vjp", fun, primals, {}))
  return _vjp(wrapped_fun, *primals, has_aux=has_aux)

def _vjp(fun, *primals, has_aux=False):
  canon = lambda x: x if isinstance(x, core.Tracer) else canonicalize_value(x)
  primals = tree_map(canon, primals)
  primals_flat, in_tree = tree_flatten(primals)
  for arg in primals_flat:
    dispatch.check_arg(arg)
  if not has_aux:
    flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
    out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize(
        flat_fun, *primals_flat, is_vjp=True)
    out_tree = out_tree()
    aux = aux_tree = None
  else:
    flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
    out_primals_flat, out_pvals, jaxpr, residuals, aux = ad.linearize(
        flat_fun, *primals_flat, has_aux=True, is_vjp=True)
    out_tree, aux_tree = out_aux_trees()
    del out_aux_trees
  out_known = [pval.is_known() for pval in out_pvals]
  id_map = {id(x): i for i, x in enumerate(primals_flat)}
  used, opaque_residuals = set(), []
  # Classify each residual: a residual that is (by object identity) one of the
  # flat primal inputs is recorded as RSpec(primal_index, primal=True) and
  # marked used; any other residual is appended to opaque_residuals and
  # recorded as RSpec(opaque_index, primal=False).
  spec = [used.add(id(r)) or RSpec(id_map[id(r)], True) if id(r) in id_map else
          RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False)
          for r in residuals]
  # Keep only the primal arguments that are actually needed as residuals.
  args_res = tuptree_map(lambda x: x if id(x) in used else NotNeeded(),
                         in_tree, primals_flat)
  out_primal_avals = [typeof(x) for x in out_primals_flat]
  f_vjp = VJP(partial(_vjp3_callable, spec, out_known, jaxpr, out_primal_avals),
              in_tree, out_tree, list(args_res), opaque_residuals)
  out_primals = tree_unflatten(out_tree, out_primals_flat)
  if not has_aux:
    return out_primals, f_vjp
  else:
    assert aux is not None
    assert aux_tree is not None
    return out_primals, f_vjp, tree_unflatten(aux_tree, aux)

def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree,
                   args_res, opaque_res, *maybe_ct_refs):
  if not maybe_ct_refs:
    maybe_ct_refs_flat = [GradValue()] * in_tree.num_leaves
  else:
    maybe_ct_refs_flat, in_tree_ = tree_flatten(maybe_ct_refs)
    if in_tree != in_tree_:
      raise Exception  # TODO: accept an isomorphic tuple tree
  args_res_ = tree_leaves(args_res, is_leaf=lambda x: isinstance(x, NotNeeded))
  residuals = [args_res_[i.idx] if i.primal else opaque_res[i.idx] for i in spec]
  maybe_accums = [x if isinstance(x, ad.GradAccum) else
                  ad.RefAccum(v.aval, x) if _is_ref(x) else ad.NullAccum(v.aval)
                  if isinstance(x, DontWant) else ad.ValAccum(v.aval)
                  for v, x in zip(jaxpr.invars, maybe_ct_refs_flat)]
  return Partial(partial(_vjp3_bwd, in_tree, out_tree, out_known, jaxpr,
                         out_primal_avals), residuals, maybe_accums)

def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals,
              maybe_accums, out_ct):
  cts_flat, out_tree_ = tree_flatten(out_ct)
  if out_tree != out_tree_:
    _vjp_ct_tree_error(jaxpr, out_tree, out_tree_)
  _vjp_check_ct_avals(cts_flat, out_primal_avals)
  cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k]
  ad.backward_pass3(jaxpr, True, residuals, maybe_accums, cts_flat)
  arg_cts = [x.freeze() if isinstance(x, ad.ValAccum) else
             DidntWant() if isinstance(x, ad.NullAccum) else GradRef()
             for x in maybe_accums]
  arg_cts = map(ad.instantiate_zeros, arg_cts)
  return tree_unflatten(in_tree, arg_cts)


@dataclasses.dataclass(frozen=True)
class RSpec:
  idx: int
  primal: bool

def tuptree_map(f, treedef, x):
  return treedef.walk(lambda xs, _: tuple(xs), f, x)

def _is_ref(x):
  from jax._src.state.types import AbstractRef
  try:
    return isinstance(typeof(x), AbstractRef)
  except Exception:
    return False


_vjp_too_many_args = """
The function returned by `jax.vjp` applied to {} was called with {} arguments,
but functions returned by `jax.vjp` must be called with a single argument
corresponding to the single value returned by the function being differentiated
(even if that returned value is a tuple or other container).

For example, if we have:

  def f(x):
    return (x, x)
  _, f_vjp = jax.vjp(f, 1.0)

the function `f` returns a single tuple as output, and so we call `f_vjp` with a
single tuple as its argument:

  x_bar, = f_vjp((2.0, 2.0))

If we instead call `f_vjp(2.0, 2.0)`, with the values 'splatted out' as
arguments rather than in a tuple, this error can arise.
""".format


def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree):
  msg = f"""unexpected tree structure.

The argument to a VJP function returned by `jax.vjp` must match the pytree
structure of the differentiated function {jaxpr.debug_info.func_src_info}.

But the tree structures differ:
"""
  msg += '\n'.join(f" * out{keystr(path)} was a {thing1} in the original "
                   f"output, but a {thing2} here, so {explanation}."
                   for path, thing1, thing2, explanation
                   in equality_errors_pytreedef(out_tree, ct_tree))
  raise ValueError(msg)


def _vjp_check_ct_avals(cts, primal_avals):
  # TODO(mattjj): improve this error by flattening with keys in the first place
  for ct, aval in zip(cts, primal_avals):
    ct_aval = typeof(ct)
    ct_aval_expected = aval.to_ct_aval()
    if (not core.typecompat(ct_aval, ct_aval_expected) and
        not _temporary_dtype_exception(ct_aval, ct_aval_expected)):
      raise ValueError(
          "unexpected JAX type (e.g. shape/dtype) for argument to VJP function: "
          f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} "
          "because the corresponding output of the differentiated function had JAX type "
          f"{aval.str_short()}")


@register_dataclass
@dataclasses.dataclass(frozen=True)
class NotNeeded:
  pass

@dataclasses.dataclass(frozen=True)
class GradValue:
  pass

@register_dataclass
@dataclasses.dataclass(frozen=True)
class GradRef:
  pass

@dataclasses.dataclass(frozen=True)
class DontWant:
  pass

@register_dataclass
@dataclasses.dataclass(frozen=True)
class DidntWant:
  pass


@dataclasses.dataclass
class VJP:
  fun: Callable
  in_tree: PyTreeDef
  out_tree: PyTreeDef
  args_res: list[Any]
  opaque_residuals: list[Any]
  jaxpr = property(lambda self: self.fun.args[2])

  def __call__(self, out_ct, *extra_args):
    if extra_args:
      name, *_ = self.jaxpr.debug_info.func_src_info.split(' ')
      raise TypeError(_vjp_too_many_args(name, len(extra_args) + 1))
    return self.fun(self.in_tree, self.out_tree, self.args_res,
                    self.opaque_residuals)(out_ct)

  def with_refs(self, *maybe_ct_refs):
    return self.fun(self.in_tree, self.out_tree, self.args_res,
                    self.opaque_residuals, *maybe_ct_refs)

  # Only safe to put these in cache keys if residuals aren't mutated. Beware!
  __hash__ = object.__hash__
  __eq__ = object.__eq__

register_pytree_node(
    VJP,
    lambda vjp: ((vjp.args_res, vjp.opaque_residuals),
                 (vjp.fun, vjp.in_tree, vjp.out_tree)),
    lambda meta, args_res: VJP(*meta, *args_res))


@partial(api_boundary, repro_api_name="jax.linear_transpose")
def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
  """Transpose a function that is promised to be linear.

  For linear functions, this transformation is equivalent to :py:func:`vjp`, but
  avoids the overhead of computing the forward pass.

  The outputs of the transposed function will always have the exact same dtypes
  as ``primals``, even if some values are truncated (e.g., from complex to
  float, or from float64 to float32). To avoid truncation, use dtypes in
  ``primals`` that match the full range of desired outputs from the transposed
  function. Inputs and outputs must be either all inexact (float or complex)
  dtypes or all integer dtypes; mixing the two is not supported.

  Args:
    fun: the linear function to be transposed.
    *primals: a positional argument tuple of arrays, scalars, or (nested)
      standard Python containers (tuples, lists, dicts, namedtuples, i.e.,
      pytrees) of those types used for evaluating the shape/dtype of
      ``fun(*primals)``. These arguments may be real scalars/ndarrays, but that
      is not required: only the ``shape`` and ``dtype`` attributes are accessed.
      See below for an example. (Note that the duck-typed objects cannot be
      namedtuples because those are treated as standard Python containers.)

  Returns:
    A callable that calculates the transpose of ``fun``. Valid input into this
    function must have the same shape/dtypes/structure as the result of
    ``fun(*primals)``. Output will be a tuple, with the same
    shape/dtypes/structure as ``primals``.

  >>> import jax
  >>>
  >>> f = lambda x, y: 0.5 * x - 0.5 * y
  >>> scalar = jax.ShapeDtypeStruct(shape=(), dtype=np.dtype(np.float32))
  >>> f_transpose = jax.linear_transpose(f, scalar, scalar)
  >>> f_transpose(1.0)
  (Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
  """
  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to transpose is deprecated")
  del reduce_axes
  primals_flat, in_tree = tree_flatten(primals)
  flat_fun, out_tree = flatten_fun_nokwargs(
      lu.wrap_init(fun,
                   debug_info=debug_info("linear_transpose", fun, primals, {})),
      in_tree)
  in_avals = map(shaped_abstractify, primals_flat)
  in_dtypes = map(lambda a: a.dtype, in_avals)

  in_pvals = map(pe.PartialVal.unknown, in_avals)
  jaxpr, out_pvals, const = pe.trace_to_jaxpr_nounits(flat_fun, in_pvals,
                                                      instantiate=True)
  jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), True)
  out_avals, _ = unzip2(out_pvals)
  out_dtypes = map(lambda a: a.dtype, out_avals)
  if not (all(dtypes.issubdtype(d, np.inexact) for d in in_dtypes + out_dtypes)
          or all(dtypes.issubdtype(d, np.integer)
                 for d in in_dtypes + out_dtypes)):
    raise TypeError("linear_transpose only supports [float or complex] -> "
                    "[float or complex], and integer -> integer functions, "
                    f"but got {in_dtypes} -> {out_dtypes}.")

  @api_boundary
  def transposed_fun(const, out_cotangent):
    out_cts, out_tree2 = tree_flatten(out_cotangent)
    if out_tree() != out_tree2:
      raise TypeError("cotangent tree does not match function output, "
                      f"expected {out_tree()} but got {out_tree2}")
    if not all(map(core.typecheck, out_avals, out_cts)):
      raise TypeError("cotangent type does not match function output, "
                      f"expected {out_avals} but got {out_cts}")
    dummies = [ad.UndefinedPrimal(a) for a in in_avals]
    in_cts = ad.backward_pass(jaxpr, True, const, dummies, out_cts)
    in_cts = map(ad.instantiate_zeros, in_cts)
    return tree_unflatten(in_tree, in_cts)

  # Ensure that transposed_fun is a PyTree
  return Partial(transposed_fun, const)


@overload
def make_jaxpr(
    fun: Callable,
    static_argnums: int | Sequence[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: Literal[False] = ...,
) -> Callable[..., core.ClosedJaxpr]:
  ...

@overload
def make_jaxpr(
    fun: Callable,
    static_argnums: int | Sequence[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: Literal[True] = ...,
) -> Callable[..., tuple[core.ClosedJaxpr, Any]]:
  ...

@partial(api_boundary, repro_api_name="jax.make_jaxpr")
def make_jaxpr(
    fun: Callable,
    static_argnums: int | Sequence[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]:
  """Create a function that returns the jaxpr of ``fun`` given example args.

  Args:
    fun: The function whose ``jaxpr`` is to be computed. Its positional
      arguments and return value should be arrays, scalars, or standard Python
      containers (tuple/list/dict) thereof.
    static_argnums: See the :py:func:`jax.jit` docstring.
    axis_env: Optional, a sequence of pairs where the first element is an axis
      name and the second element is a positive integer representing the size of
      the mapped axis with that name. This parameter is useful when lowering
      functions that involve parallel communication collectives, and it
      specifies the axis name/size environment that would be set up by
      applications of :py:func:`jax.pmap`.
    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
      wrapped function returns a pair where the first element is the
      ``ClosedJaxpr`` representation of ``fun`` and the second element is a
      pytree with the same structure as the output of ``fun`` and where the
      leaves are objects with ``shape`` and ``dtype`` attributes representing
      the corresponding types of the output leaves.

  Returns:
    A wrapped version of ``fun`` that when applied to example arguments returns
    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
    argument ``return_shape`` is ``True``, then the returned function instead
    returns a pair where the first element is the ``ClosedJaxpr``
    representation of ``fun`` and the second element is a pytree representing
    the structure, shape, dtypes, and named shapes of the output of ``fun``.

  A ``jaxpr`` is JAX's intermediate representation for program traces. The
  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

  We do not describe the semantics of the ``jaxpr`` language in detail here, but
  instead give a few examples.

  >>> import jax
  >>>
  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
  >>> print(f(3.0))
  -0.83602
  >>> jax.make_jaxpr(f)(3.0)
  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
  >>> jax.make_jaxpr(jax.grad(f))(3.0)
  { lambda ; a:f32[]. let
      b:f32[] = cos a
      c:f32[] = sin a
      _:f32[] = sin b
      d:f32[] = cos b
      e:f32[] = mul 1.0:f32[] d
      f:f32[] = neg e
      g:f32[] = mul f c
    in (g,) }
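
  With ``return_shape=True``, the wrapped function also returns the output
  shape/dtype structure (a small illustrative sketch; the exact repr of the
  second element may differ):

  >>> jaxpr, out_shape = jax.make_jaxpr(f, return_shape=True)(3.0)
  >>> out_shape  # doctest: +SKIP
  ShapeDtypeStruct(shape=(), dtype=float32)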
  """
  try:
    hash(fun)
    weakref.ref(fun)
  except TypeError:
    fun = partial(fun)

  @wraps(fun)
  @api_boundary
  def make_jaxpr_f(*args, **kwargs):
    with core.extend_axis_env_nd(axis_env or []):
      traced = jit(fun, static_argnums=static_argnums).trace(*args, **kwargs)
    # `jit` converts tracers in consts to args but `make_jaxpr` callers expect
    # consts not to be converted.
    num_consts = traced._num_consts
    if num_consts:
      jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr, num_consts)
      jaxpr = core.ClosedJaxpr(jaxpr_, traced._consts)
    else:
      jaxpr = traced.jaxpr
    if return_shape:
      return jaxpr, traced.out_info
    return jaxpr

  make_jaxpr_f.__module__ = "jax"
  if hasattr(fun, "__qualname__"):
    make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
  if hasattr(fun, "__name__"):
    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
  return make_jaxpr_f

def _infer_src_sharding(src, x, x_aval) -> Sharding | None:
  if src is not None:
    return src
  if isinstance(x, array.ArrayImpl):
    return x.sharding
  if isinstance(x, core.Tracer):
    val = x.to_concrete_value()
    if val is not None and isinstance(val, array.ArrayImpl):
      return val.sharding
  if x_aval is not core.abstract_token and x_aval.sharding.mesh.are_all_axes_explicit:
    return x_aval.sharding.update(
        memory_kind=core.mem_space_to_kind(x_aval.memory_space))
  return None


@util.cache(max_size=2048, trace_context_in_key=False)
def _check_string_compatible_sharding(s):
  """Checks if target devices are compatible with string arrays."""
  if isinstance(s, xc.Device) and s.device_kind == "cpu":
    return
  if (isinstance(s, Sharding)
      and s._internal_device_list[0].device_kind == "cpu"):
    return
  raise TypeError(
      "String arrays can only be sharded to CPU devices. Received"
      f" unsupported device or sharding: {s}")


@util.cache(max_size=2048, trace_context_in_key=False)
def _check_sharding(aval, s):
  if (s is not None and
      not isinstance(s, (xc.Device, Sharding, Format, core.MemorySpace))):
    raise ValueError(
        "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`,"
        " `jax.Device`, `Format`, `jax.memory.Space` or a pytree of these"
        f" values. Received invalid value: {s}")
  if isinstance(aval, core.ShapedArray) and aval.dtype == dtypes.string_dtype:
    _check_string_compatible_sharding(s)

  if isinstance(s, Sharding):
    if isinstance(aval, core.AbstractToken):
      aval = core.get_token_aval()
    pjit.pjit_check_aval_sharding(
        (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False
    )
    s.shard_shape(aval.shape)  # should raise an error if incompatible

def pspec_to_sharding(name, val):
  if isinstance(val, P):
    mesh = get_concrete_mesh()
    if mesh.empty:
      raise ValueError(
          "Please set a mesh via `jax.set_mesh` if a PartitionSpec is"
          f" passed to {name}")
    return NamedSharding(mesh, val)
  return val


def device_put(
    x,
    device: None | xc.Device | Sharding | P | Format | Any = None,
    *, src: None | xc.Device | Sharding | P | Format | Any = None,
    donate: bool | Any = False, may_alias: bool | None | Any = None):
  """Transfers ``x`` to ``device``.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    device: The (optional) :py:class:`Device`, :py:class:`Sharding`, or a
      (nested) :py:class:`Sharding` in standard Python container (must be a tree
      prefix of ``x``), representing the device(s) to which ``x`` should be
      transferred. If given, then the result is committed to the device(s).
    src: The (optional) :py:class:`Device`, :py:class:`Sharding`, or a (nested)
      :py:class:`Sharding` in standard Python container (must be a tree prefix
      of ``x``), representing the device(s) on which ``x`` belongs.
    donate: bool or a (nested) bool in standard Python container (must be a tree
      prefix of ``x``). If True, ``x`` can be overwritten and marked deleted in
      the caller. This is best effort: JAX will donate if possible, otherwise it
      won't. A donated input buffer will always (eventually) be deleted.
    may_alias: bool or None or a (nested) bool in standard Python container
      (must be a tree prefix of ``x``). If False, ``x`` will be copied. If True,
      ``x`` may be aliased depending on the runtime's implementation.

  Returns:
    A copy of ``x`` that resides on ``device``.

  If the ``device`` parameter is ``None``, then this operation behaves like the
  identity function if the operand is on any device already, otherwise it
  transfers the data to the default device, uncommitted.

  This function is always asynchronous, i.e. it returns immediately without
  blocking the calling Python thread until any transfers are completed.
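
  For example, committing an array to the first local device (a minimal usage
  sketch):

  >>> import jax
  >>> x = jax.numpy.arange(4.)
  >>> y = jax.device_put(x, jax.devices()[0])
  >>> y.sharding  # doctest: +SKIP
  SingleDeviceSharding(device=CpuDevice(id=0))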
  """
  with config.explicit_device_put_scope():
    x_flat, treedef = tree_flatten(x)
    x_avals = [shaped_abstractify(x) for x in x_flat]
    if (device is None or
        isinstance(device, (xc.Device, Sharding, core.MemorySpace))):
      device_flat = [device] * len(x_flat)
    else:
      device_flat = flatten_axes("device_put device", treedef, device)

    if (src is None or
        isinstance(src, (xc.Device, Sharding, core.MemorySpace))):
      src_flat = list(map(partial(_infer_src_sharding, src), x_flat, x_avals))
    else:
      src_flat = flatten_axes("device_put source", treedef, src)
      src_flat = list(map(_infer_src_sharding, src_flat, x_flat, x_avals))

    device_flat = map(partial(pspec_to_sharding, 'device_put'), device_flat)
    src_flat = map(partial(pspec_to_sharding, 'device_put'), src_flat)

    if isinstance(donate, bool):
      donate_flat = [donate] * len(x_flat)
    else:
      donate_flat = flatten_axes("device_put donate", treedef, donate)

    if isinstance(may_alias, bool):
      may_alias_flat = [may_alias] * len(x_flat)
    else:
      may_alias_flat = flatten_axes("device_put may_alias", treedef, may_alias)

    copy_semantics = []
    for m, d in zip(may_alias_flat, donate_flat):
      if m and d:
        raise ValueError('may_alias and donate cannot be True at the same time.')
      if m is None:
        m = not d
      if m and not d:
        copy_semantics.append(dispatch.ArrayCopySemantics.REUSE_INPUT)
      elif not m and d:
        copy_semantics.append(dispatch.ArrayCopySemantics.DONATE_INPUT)
      else:
        assert not m and not d
        copy_semantics.append(dispatch.ArrayCopySemantics.ALWAYS_COPY)

    dst_avals = []
    for x_aval, d in zip(x_avals, device_flat):
      aval = dispatch.update_dp_aval(x_aval, d)
      dst_avals.append(aval)
      _check_sharding(aval, d)
    if core.trace_state_clean():
      out_flat = dispatch._batched_device_put_impl(
          *x_flat, devices=device_flat, srcs=src_flat,
          copy_semantics=copy_semantics, dst_avals=dst_avals)
    else:
      out_flat = dispatch.device_put_p.bind(
          *x_flat, devices=tuple(device_flat), srcs=tuple(src_flat),
          copy_semantics=tuple(copy_semantics))
    return tree_unflatten(treedef, out_flat)


def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  # noqa: F811
  """Transfer array shards to specified devices and form Array(s).

  Args:
    shards: A sequence of arrays, scalars, or (nested) standard Python
      containers thereof representing the shards to be stacked together to form
      the output. The length of ``shards`` must equal the length of ``devices``.
    devices: A sequence of :py:class:`Device` instances representing the devices
      to which corresponding shards in ``shards`` will be transferred.

  This function is always asynchronous, i.e. it returns immediately.

  Returns:
    An Array or (nested) Python container thereof representing the
    elements of ``shards`` stacked together, with each shard backed by physical
    device memory specified by the corresponding entry in ``devices``.

  Examples:
    Passing a list of arrays for ``shards`` results in a sharded array
    containing a stacked version of the inputs:

    >>> import jax
    >>> devices = jax.local_devices()
    >>> x = [jax.numpy.ones(5) for device in devices]
    >>> y = jax.device_put_sharded(x, devices)  # doctest: +SKIP
    >>> np.allclose(y, jax.numpy.stack(x))  # doctest: +SKIP
    True

    Passing a list of nested container objects with arrays at the leaves for
    ``shards`` corresponds to stacking the shards at each leaf. This requires
    all entries in the list to have the same tree structure:

    >>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
    >>> y = jax.device_put_sharded(x, devices)  # doctest: +SKIP
    >>> type(y)  # doctest: +SKIP
    <class 'tuple'>
    >>> y0 = jax.device_put_sharded([a for a, b in x], devices)  # doctest: +SKIP
    >>> y1 = jax.device_put_sharded([b for a, b in x], devices)  # doctest: +SKIP
    >>> np.allclose(y[0], y0)  # doctest: +SKIP
    True
    >>> np.allclose(y[1], y1)  # doctest: +SKIP
    True

  See Also:
    - device_put
    - device_put_replicated
  """
  # TODO(jakevdp): provide a default for devices that considers both local
  # devices and pods
  if not isinstance(shards, Sequence):
    raise TypeError("device_put_sharded `shards` input must be a sequence; "
                    f"got {type(shards)}")
  if len(shards) != len(devices):
    raise ValueError(f"len(shards) = {len(shards)} must equal "
                     f"len(devices) = {len(devices)}.")

  def _device_put_sharded(*xs):
    avals = [core.typeof(x) for x in xs]
    if not all(a1 == a2 for a1, a2 in zip(avals[:-1], avals[1:])):
      a1, a2 = next((a1, a2) for a1, a2 in zip(avals[:-1], avals[1:])
                    if a1 != a2)
      raise ValueError("the shards passed to device_put_sharded must have "
                       f"consistent shape and dtype, but got {a1} and {a2}.")
    stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
    mesh = Mesh(np.array(devices), ("_device_put_sharded",))
    sharding = NamedSharding(mesh, P("_device_put_sharded"))
    if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended):
      return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices)
    ys = []
    for x in xs:
      if not isinstance(x, (np.ndarray, basearray.Array)):
        x = np.asarray(x)
      ys.append(x[None])
    return pxla.batched_device_put(stacked_aval, sharding, ys, list(devices))

  with config.explicit_device_put_scope():
    return tree_map(_device_put_sharded, *shards)


def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
  """Transfer array(s) to each specified device and form Array(s).

  Args:
    x: an array, scalar, or (nested) standard Python container thereof
      representing the array to be replicated to form the output.
    devices: A sequence of :py:class:`Device` instances representing the devices
      to which ``x`` will be transferred.

  This function is always asynchronous, i.e. it returns immediately.

  Returns:
    An Array or (nested) Python container thereof representing the
    value of ``x`` broadcasted along a new leading axis of size
    ``len(devices)``, with each slice along that new leading axis backed by
    memory on the device specified by the corresponding entry in ``devices``.

  Examples:
    Passing an array:

    >>> import jax
    >>> devices = jax.local_devices()
    >>> x = jax.numpy.array([1., 2., 3.])
    >>> y = jax.device_put_replicated(x, devices)  # doctest: +SKIP
    >>> np.allclose(y, jax.numpy.stack([x for _ in devices]))  # doctest: +SKIP
    True

  See Also:
    - device_put
    - device_put_sharded
  """
  if not isinstance(devices, Sequence) or not devices:
    raise ValueError("`devices` argument to `device_put_replicated` must be "
                     "a non-empty sequence.")
  def _device_put_replicated(x):
    aval = core.unmapped_aval(len(devices), 0, core.typeof(x))
    assert isinstance(aval, ShapedArray)
    if isinstance(x, (np.ndarray, basearray.Array)):
      buf = device_put(x[None], devices[0])
    else:
      buf = device_put(x, devices[0])[None]
    mesh = Mesh(np.array(devices), ("_device_put_replicated",))
    sharding = NamedSharding(mesh, P("_device_put_replicated"))
    if dtypes.issubdtype(aval.dtype, dtypes.extended):
      return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices)
    return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices)

  with config.explicit_device_put_scope():
    return tree_map(_device_put_replicated, x)


# TODO(mattjj): consider revising
def _device_get(x):
  if isinstance(x, core.Tracer):
    return x

  # Extended dtypes dispatch via their device_get rule.
  if isinstance(x, basearray.Array) and dtypes.issubdtype(x.dtype, dtypes.extended):
    bufs, tree = tree_util.dispatch_registry.flatten(x)
    return tree.unflatten(device_get(bufs))

  # Other types dispatch via their __array__ method.
  try:
    toarray = x.__array__
  except AttributeError:
    return x
  else:
    return toarray()

def device_get(x: Any):
  """Transfer ``x`` to host.

  If ``x`` is a pytree, then the individual buffers are copied in parallel.

  Args:
    x: An array, scalar, Array, or (nested) standard Python container thereof
      representing the array to be transferred to host.

  Returns:
    An array or (nested) Python container thereof representing the
    value of ``x``.

  Examples:
    Passing an Array:

    >>> import jax
    >>> x = jax.numpy.array([1., 2., 3.])
    >>> jax.device_get(x)
    array([1., 2., 3.], dtype=float32)

    Passing a scalar (has no effect):

    >>> jax.device_get(1)
    1

  See Also:
    - device_put
    - device_put_sharded
    - device_put_replicated
  """
  with config.explicit_device_get_scope():
    for y in tree_leaves(x):
      try:
        y.copy_to_host_async()
      except AttributeError:
        pass
    return tree_map(_device_get, x)


@partial(api_boundary, repro_api_name="jax.eval_shape")
def eval_shape(fun: Callable, *args, **kwargs):
  """Compute the shape/dtype of ``fun`` without any FLOPs.

  This utility function is useful for performing shape inference. Its
  input/output behavior is defined by::

    def eval_shape(fun, *args, **kwargs):
      out = fun(*args, **kwargs)
      shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
      return jax.tree_util.tree_map(shape_dtype_struct, out)

  But instead of applying ``fun`` directly, which might be expensive, it uses
  JAX's abstract interpretation machinery to evaluate the shapes without doing
  any FLOPs.

  Using :py:func:`eval_shape` can also catch shape errors, and will raise the
  same shape errors as evaluating ``fun(*args, **kwargs)``.

  Args:
    fun: The function whose output shape should be evaluated.
    *args: a positional argument tuple of arrays, scalars, or (nested) standard
      Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
      those types. Since only the ``shape`` and ``dtype`` attributes are
      accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
      that duck-types as ndarrays (note however that duck-typed objects cannot
      be namedtuples because those are treated as standard Python containers).
    **kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
      Python containers (pytrees) of those types. As in ``args``, array values
      need only be duck-typed to have ``shape`` and ``dtype`` attributes.

  Returns:
    out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
  >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
  >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
  >>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
  >>> print(out.shape)
  (2000, 1000)
  >>> print(out.dtype)
  float32

  All arguments passed via :func:`eval_shape` will be treated as dynamic;
  static arguments can be included via closure, for example using :func:`functools.partial`:

  >>> import jax
  >>> from jax import lax
  >>> from functools import partial
  >>> import jax.numpy as jnp
  >>>
  >>> x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32)
  >>> kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32)
  >>>
  >>> conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME")
  >>> out = jax.eval_shape(conv_same, x, kernel)
  >>> print(out.shape)
  (1, 32, 28, 28)
  >>> print(out.dtype)
  float32
  """
  if type(fun) is xc._xla.PjitFunction:
    return fun.trace(*args, **kwargs).out_info  # pyrefly: ignore[missing-attribute]
  try: hash(fun)
  except TypeError: fun = partial(fun)
  return jit(fun).trace(*args, **kwargs).out_info


@partial(api_boundary, repro_api_name="jax.named_call")
def named_call(
    fun: F,
    *,
    name: str | None = None,
) -> F:
  """Adds a user-specified name to a function when staging out JAX computations.

  When staging out computations for just-in-time compilation to XLA (or other
  backends such as TensorFlow) JAX runs your Python program but by default does
  not preserve any of the function names or other metadata associated with it.
  This can make debugging the staged out (and/or compiled) representation of
  your program complicated because there is limited context information for each
  operation being executed.

  ``named_call`` tells JAX to stage the given function out as a subcomputation
  with a specific name. When the staged out program is compiled with XLA these
  named subcomputations are preserved and show up in debugging utilities like
  the TensorFlow Profiler in TensorBoard. Names are also preserved when staging
  out JAX programs to TensorFlow using :func:`experimental.jax2tf.convert`.

  Args:
    fun: Function to be wrapped. This can be any Callable.
    name: Optional. The prefix to use to name all subcomputations created
      within the name scope. Uses ``fun.__name__`` if not specified.

  Returns:
    A version of ``fun`` that is wrapped in a ``named_scope``.
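
  For example (a minimal usage sketch; ``my_layer`` is a stand-in for any
  callable):

  >>> import jax
  >>>
  >>> def my_layer(x):
  ...   return x * 2
  ...
  >>> named_layer = jax.named_call(my_layer, name="my_layer")
  >>> y = jax.jit(named_layer)(jax.numpy.ones(3))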
  """
  if name is None:
    name = fun.__name__

  return source_info_util.extend_name_stack(name)(fun)


def named_scope(
    name: str,
) -> source_info_util.ExtendNameStackContextManager:
  """A context manager that adds a user-specified name to the JAX name stack.

  When staging out computations for just-in-time compilation to XLA (or other
  backends such as TensorFlow) JAX does not, by default, preserve the names
  (or other source metadata) of Python functions it encounters.
  This can make debugging the staged out (and/or compiled) representation of
  your program complicated because there is limited context information for each
  operation being executed.

  ``named_scope`` tells JAX to stage the given function with additional
  annotations on the underlying operations. JAX internally keeps track of these
  annotations in a name stack. When the staged out program is compiled with XLA
  these annotations are preserved and show up in debugging utilities like the
  TensorFlow Profiler in TensorBoard. Names are also preserved when staging out
  JAX programs to TensorFlow using :func:`experimental.jax2tf.convert`.

  Args:
    name: The prefix to use to name all operations created within the name
      scope.

  Yields:
    Yields ``None``, but enters a context in which ``name`` will be appended to
    the active name stack.

  Examples:
    ``named_scope`` can be used as a context manager inside compiled functions:

    >>> import jax
    >>>
    >>> @jax.jit
    ... def layer(w, x):
    ...   with jax.named_scope("dot_product"):
    ...     logits = w.dot(x)
    ...   with jax.named_scope("activation"):
    ...     return jax.nn.relu(logits)

    It can also be used as a decorator:

    >>> @jax.jit
    ... @jax.named_scope("layer")
    ... def layer(w, x):
    ...   logits = w.dot(x)
    ...   return jax.nn.relu(logits)
  """
  if not isinstance(name, str):
    raise TypeError("named_scope name argument must be a string.")
  return source_info_util.extend_name_stack(name)

def effects_barrier():
  """Waits until existing functions have completed any side-effects."""
  dispatch.runtime_tokens.block_until_ready()

def block_until_ready(x):
  """
  Tries to call a ``block_until_ready`` method on pytree leaves.

  Args:
    x: a pytree, usually with at least some JAX array instances at its leaves.

  Returns:
    A pytree with the same structure and values of the input, where the values
    of all JAX array leaves are ready.
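
  For example, blocking on a single array (a minimal usage sketch):

  >>> import jax
  >>> x = jax.numpy.arange(4.)
  >>> _ = jax.block_until_ready(x)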
  """
  def try_to_block(x):
    try:
      return x.block_until_ready()
    except AttributeError:
      return x

  arrays = []
  for leaf in tree_leaves(x):
    if isinstance(leaf, array.ArrayImpl):
      arrays.append(leaf)
    else:
      try_to_block(leaf)

  if not arrays:
    # `arrays` will be empty if tree_leaves(x) is empty, or if no leaf is a
    # jax.Array.
    pass
  elif len(arrays) == 1:
    # Fast path for a single array.
    try_to_block(arrays[0])
  else:
    # Optimized path for multiple arrays.
    xc.batched_block_until_ready(arrays)

  return x

def copy_to_host_async(x):
  """
  Tries to call a ``copy_to_host_async`` method on pytree leaves.

  For each leaf this method will try to call the ``copy_to_host_async`` method
  on the leaf. If the leaf is not a JAX array, or if the leaf does not have a
  ``copy_to_host_async`` method, then this method will do nothing to the leaf.

  Args:
    x: a pytree, usually with at least some JAX array instances at its leaves.

  Returns:
    A pytree with the same structure and values as the input, where host copies
    of the values of all JAX array leaves have been started.
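
  For example (a minimal usage sketch, assuming the function is re-exported as
  ``jax.copy_to_host_async``):

  >>> import jax
  >>> x = jax.numpy.arange(4.)
  >>> _ = jax.copy_to_host_async(x)  # doctest: +SKIP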
  """
  for leaf in tree_leaves(x):
    try:
      copy_fn = leaf.copy_to_host_async
    except AttributeError:
      pass
    else:
      copy_fn()

  return x


def clear_backends():
  """
  Clear all backend clients so that new backend clients can be created later.
  """
  xb._clear_backends()
  util.clear_all_caches()
  pjit._cpp_pjit_cache_fun_only.clear()
  pjit._cpp_pjit_cache_explicit_attributes.clear()
  xc._xla.PjitFunctionCache.clear_all()

@atexit.register
def clean_up():
  if xb._default_backend is not None:
    clear_backends()
    clear_caches()

  # Shut down the distributed system if it exists. Otherwise, this is a no-op.
  distributed.shutdown()

def live_arrays(platform=None):
  """Return all live arrays in the backend for ``platform``.

  If platform is None, it is the default backend.
  """
  return xb.get_backend(platform).live_arrays()

def clear_caches():
  """Clear all compilation and staging caches.

  This doesn't clear the persistent cache; to disable it (e.g. for benchmarks),
  set the jax_enable_compilation_cache config option to False.
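
  For example:

  >>> import jax
  >>> jax.clear_caches()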
  """
  # Clear all lu.cache, util.cache and util.weakref_lru_cache instances
  # (used for staging and Python-dispatch compiled executable caches).
  util.clear_all_caches()
  # Clear all C++ compiled executable caches for pjit.
  pjit._cpp_pjit_cache_fun_only.clear()
  pjit._cpp_pjit_cache_explicit_attributes.clear()
  xc._xla.PjitFunctionCache.clear_all()