Files
opencv/venv/lib/python3.12/site-packages/jax/_src/api.py
T
2026-05-06 19:47:31 +07:00

2653 lines
108 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX user-facing transformations and utilities.
The transformations here mostly wrap internal transformations, providing
convenience flags to control behavior and handling Python containers of
arguments and outputs. The Python containers handled are pytrees (see
tree_util.py), which include nested tuples/lists/dicts, where the leaves are
arrays.
"""
from __future__ import annotations
import atexit
import collections
from collections.abc import Callable, Hashable, Iterable, Sequence
import dataclasses
from functools import partial
import inspect
from typing import (Any, Literal, Optional, TypeVar, overload,
cast, TYPE_CHECKING)
import weakref
import numpy as np
from contextlib import contextmanager
from jax._src import api_util
from jax._src import linear_util as lu
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, tree_structure, tree_transpose,
tree_leaves, Partial, PyTreeDef, keystr, generate_key_paths,
tree_flatten_with_path, equality_errors_pytreedef, register_pytree_node,
register_dataclass)
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import array
from jax._src import basearray
from jax._src import distributed
from jax._src import dtypes
from jax._src.dtypes import canonicalize_value
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import pjit
from jax._src import xla_bridge as xb
from jax._src.core import eval_jaxpr, shaped_abstractify, ShapedArray, typeof
from jax._src.api_util import (
flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial,
flatten_axes, _ensure_index, apply_flat_fun_nokwargs,
check_callable, debug_info)
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
from jax._src.mesh import get_concrete_mesh, get_abstract_mesh, Mesh
from jax._src.sharding_impls import PartitionSpec as P, NamedSharding
from jax._src.layout import Format
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src.util import unzip2, safe_map, safe_zip, wraps
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
# Handle to the C++ extension's config namespace; used below to declare
# `_post_hook_state` as a configuration value visible to the C++ fast path.
config_ext = xc._xla.config
# Hide this file's frames from user-facing tracebacks.
traceback_util.register_exclusion(__file__)
_dtype = dtypes.dtype
AxisName = Hashable
Device = xc.Device
# These TypeVars are used below to express the fact that function types
# (i.e. call signatures) are invariant under the vmap transformation.
F = TypeVar("F", bound=Callable)
T = TypeVar("T")
U = TypeVar("U")
# Shadow the builtins with the length-checking variants from jax._src.util;
# the unchecked originals remain available as unsafe_map/unsafe_zip.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Re-exported alias.
ShapeDtypeStruct = core.ShapeDtypeStruct
@api_boundary
def _nan_check_posthook(fun, args, kwargs, output):
  """Hook function called by the C++ jit/pmap to perform NaN checking.

  Args:
    fun: the jitted callable that just ran.
    args, kwargs: the call's arguments, forwarded to
      ``api_util.maybe_recursive_nan_check`` to localize where the bad value
      arose.
    output: pytree of results whose device buffers are scanned for NaN/Inf.
  """
  buffers = []
  for leaf in tree_leaves(output):
    # Only process-local (addressable) shards can be inspected here.
    if hasattr(leaf, "addressable_shards"):
      buffers.extend([shard.data for shard in leaf.addressable_shards])
  try:
    dispatch.check_special(pjit.jit_p.name, buffers)
  except api_util.InternalFloatingPointError as e:
    # This hook is only installed when one of these debug flags is enabled
    # (see _update_debug_special_global / _thread_local below).
    assert config.debug_nans.value or config.debug_infs.value
    if hasattr(fun, '_fun'):
      f = fun._fun
      if getattr(f, '_apply_primitive', False):
        # Bare primitive application: report directly, suppressing the
        # internal exception's context.
        raise FloatingPointError(f"invalid value ({e.ty}) encountered in {f.__qualname__}") from None
      # compiled_fun can only raise in this case
      api_util.maybe_recursive_nan_check(e, f, args, kwargs)
      raise AssertionError("Unreachable") from e
    else:
      # TODO(emilyaf): Shouldn't need this fallback.
      raise
# Global/thread-local slot holding the post-execution hook that the C++ jit
# fast path invokes after each call (None means "no hook"). Deliberately
# excluded from the jit cache key: toggling it must not trigger recompiles.
_post_hook_state = config_ext.Config[Optional[Callable]](
    "post_hook", None, include_in_jit_key=False
)
jax_jit.set_post_hook_state(_post_hook_state)
def _update_debug_special_global(_):
  """Install or remove the NaN/Inf post-hook at the global config level."""
  enabled = config._read("jax_debug_nans") or config._read("jax_debug_infs")
  _post_hook_state.set_global(_nan_check_posthook if enabled else None)
def _update_debug_special_thread_local(_):
  """Install or remove the NaN/Inf post-hook for the current thread."""
  # `get_local()` may return an unset sentinel, so compare against True
  # explicitly rather than relying on truthiness.
  wants_hook = (config.debug_nans.get_local() == True or
                config.debug_infs.get_local() == True)
  _post_hook_state.set_local(_nan_check_posthook if wants_hook else None)
# Re-run the hook (un)installation logic whenever either debug flag changes.
config.debug_nans._add_hooks(_update_debug_special_global,
                             _update_debug_special_thread_local)
config.debug_infs._add_hooks(_update_debug_special_global,
                             _update_debug_special_thread_local)
# Trivial vector-space dtype used as the tangent type of integer/bool primals.
float0 = dtypes.float0
class NotSpecified:
  """Sentinel type marking that the ``fun`` argument to ``jax.jit`` was omitted."""

  def __repr__(self) -> str:
    return "<not-specified>"
# Overload for the direct form: ``jit(fun, **options)`` immediately returns
# the wrapped function.
@overload
def jit(
    fun: Callable, /, *,
    in_shardings: Any = ...,
    out_shardings: Any = ...,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Iterable[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    donate_argnames: str | Iterable[str] | None = ...,
    keep_unused: bool = ...,
    device: xc.Device | None = ...,
    backend: str | None = ...,
    inline: bool = ...,
    compiler_options: dict[str, Any] | None = ...,
) -> pjit.JitWrapped:
  ...
# Overload for the decorator-factory form: ``jit(**options)`` (no ``fun``)
# returns a decorator to be applied to the function.
@overload
def jit(
    *,
    in_shardings: Any = ...,
    out_shardings: Any = ...,
    static_argnums: int | Sequence[int] | None = ...,
    static_argnames: str | Iterable[str] | None = ...,
    donate_argnums: int | Sequence[int] | None = ...,
    donate_argnames: str | Iterable[str] | None = ...,
    keep_unused: bool = ...,
    device: xc.Device | None = ...,
    backend: str | None = ...,
    inline: bool = ...,
    compiler_options: dict[str, Any] | None = ...,
) -> Callable[[Callable], pjit.JitWrapped]:
  ...
def jit(
    fun: Callable | NotSpecified = NotSpecified(), /, *,
    in_shardings: Any = sharding_impls.UNSPECIFIED,
    out_shardings: Any = sharding_impls.UNSPECIFIED,
    static_argnums: int | Sequence[int] | None = None,
    static_argnames: str | Iterable[str] | None = None,
    donate_argnums: int | Sequence[int] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: xc.Device | None = None,
    backend: str | None = None,
    inline: bool = False,
    compiler_options: dict[str, Any] | None = None,
) -> pjit.JitWrapped | Callable[[Callable], pjit.JitWrapped]:
  """Sets up ``fun`` for just-in-time compilation with XLA.

  Args:
    fun: Function to be jitted. ``fun`` should be a pure function.
      The arguments and return value of ``fun`` should be arrays, scalar, or
      (nested) standard Python containers (tuple/list/dict) thereof. Positional
      arguments indicated by ``static_argnums`` can be any hashable type. Static
      arguments are included as part of a compilation cache key, which is why
      hash and equality operators must be defined. JAX keeps a weak reference to
      ``fun`` for use as a compilation cache key, so the object ``fun`` must be
      weakly-referenceable. Starting in JAX v0.8.1, when ``fun`` is omitted,
      the return value will be a partially-evaluated function to allow the
      decorator factory pattern (see Examples below).
    in_shardings: optional, a :py:class:`Sharding` or pytree with
      :py:class:`Sharding` leaves and structure that is a tree prefix of the
      positional arguments tuple to ``fun``. If provided, the positional
      arguments passed to ``fun`` must have shardings that are compatible with
      ``in_shardings`` or an error is raised, and the compiled computation has
      input shardings corresponding to ``in_shardings``. If not provided, the
      compiled computation's input shardings are inferred from argument
      shardings.
    out_shardings: optional, a :py:class:`Sharding` or pytree with
      :py:class:`Sharding` leaves and structure that is a tree prefix of the
      output of ``fun``. If provided, it has the same effect as applying
      :py:func:`jax.lax.with_sharding_constraint` to the output of ``fun``.
    static_argnums: optional, an int or collection of ints that specify which
      positional arguments to treat as static (trace- and compile-time
      constant).
      Static arguments should be hashable, meaning both ``__hash__`` and
      ``__eq__`` are implemented, and immutable. Otherwise, they can be arbitrary
      Python objects. Calling the jitted function with different values for
      these constants will trigger recompilation. Arguments that are not
      array-like or containers thereof must be marked as static.
      If neither ``static_argnums`` nor ``static_argnames`` is provided, no
      arguments are treated as static. If ``static_argnums`` is not provided but
      ``static_argnames`` is, or vice versa, JAX uses
      :code:`inspect.signature(fun)` to find any positional arguments that
      correspond to ``static_argnames``
      (or vice versa). If both ``static_argnums`` and ``static_argnames`` are
      provided, ``inspect.signature`` is not used, and only actual
      parameters listed in either ``static_argnums`` or ``static_argnames`` will
      be treated as static.
    static_argnames: optional, a string or collection of strings specifying
      which named arguments to treat as static (compile-time constant). See the
      comment on ``static_argnums`` for details. If not
      provided but ``static_argnums`` is set, the default is based on calling
      ``inspect.signature(fun)`` to find corresponding named arguments.
    donate_argnums: optional, collection of integers to specify which positional
      argument buffers can be overwritten by the computation and marked deleted
      in the caller. It is safe to donate argument buffers if you no longer need
      them once the computation has started. In some cases XLA can make use of
      donated buffers to reduce the amount of memory needed to perform a
      computation, for example recycling one of your input buffers to store a
      result. You should not reuse buffers that you donate to a computation; JAX
      will raise an error if you try to. By default, no argument buffers are
      donated.
      If neither ``donate_argnums`` nor ``donate_argnames`` is provided, no
      arguments are donated. If ``donate_argnums`` is not provided but
      ``donate_argnames`` is, or vice versa, JAX uses
      :code:`inspect.signature(fun)` to find any positional arguments that
      correspond to ``donate_argnames``
      (or vice versa). If both ``donate_argnums`` and ``donate_argnames`` are
      provided, ``inspect.signature`` is not used, and only actual
      parameters listed in either ``donate_argnums`` or ``donate_argnames`` will
      be donated.
      For more details on buffer donation see the
      `FAQ <https://docs.jax.dev/en/latest/faq.html#buffer-donation>`_.
    donate_argnames: optional, a string or collection of strings specifying
      which named arguments are donated to the computation. See the
      comment on ``donate_argnums`` for details. If not
      provided but ``donate_argnums`` is set, the default is based on calling
      ``inspect.signature(fun)`` to find corresponding named arguments.
    keep_unused: optional boolean. If `False` (the default), arguments that JAX
      determines to be unused by `fun` *may* be dropped from resulting compiled
      XLA executables. Such arguments will not be transferred to the device nor
      provided to the underlying executable. If `True`, unused arguments will
      not be pruned.
    device: This is an experimental feature and the API is likely to change.
      Optional, the Device the jitted function will run on. (Available devices
      can be retrieved via :py:func:`jax.devices`.) The default is inherited
      from XLA's DeviceAssignment logic and is usually to use
      ``jax.devices()[0]``.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.
    inline: Optional boolean. Specify whether this function should be inlined
      into enclosing jaxprs. Default False.
    compiler_options: optional dict of compiler options, forwarded unchanged to
      the underlying ``pjit`` machinery. NOTE(review): semantics are
      backend-specific and not documented here — see ``pjit.make_jit``.

  Returns:
    A wrapped version of ``fun``, set up for just-in-time compilation.

  Examples:
    In the following example, ``selu`` can be compiled into a single fused kernel
    by XLA:

    >>> import jax
    >>>
    >>> @jax.jit
    ... def selu(x, alpha=1.67, lmbda=1.05):
    ...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
    >>>
    >>> key = jax.random.key(0)
    >>> x = jax.random.normal(key, (10,))
    >>> print(selu(x))  # doctest: +SKIP
    [-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
     -0.85743 -0.78232  0.76827  0.59566 ]

    Starting in JAX v0.8.1, :func:`jit` supports the decorator factory pattern
    for specifying optional keywords:

    >>> import jax.numpy as jnp
    >>>
    >>> @jax.jit(static_argnames=['n'])
    ... def g(x, n):
    ...   for i in range(n):
    ...     x = x ** 2
    ...   return x
    >>>
    >>> g(jnp.arange(4), 3)
    Array([   0,    1,  256, 6561], dtype=int32)

    For compatibility with older JAX versions, a common pattern is to use
    :func:`functools.partial`:

    >>> from functools import partial
    >>>
    >>> @partial(jax.jit, static_argnames=['n'])
    ... def g(x, n):
    ...   for i in range(n):
    ...     x = x ** 2
    ...   return x
    >>>
    >>> g(jnp.arange(4), 3)
    Array([   0,    1,  256, 6561], dtype=int32)
  """
  kwds = dict(
      in_shardings=in_shardings, out_shardings=out_shardings,
      static_argnums=static_argnums, static_argnames=static_argnames,
      donate_argnums=donate_argnums, donate_argnames=donate_argnames,
      keep_unused=keep_unused, device=device, backend=backend, inline=inline,
      compiler_options=compiler_options, use_resource_env=False)
  if isinstance(fun, NotSpecified):
    # Decorator-factory form: jit(**options) was called without a function,
    # so return a decorator that applies the collected options.
    return lambda fun: pjit.make_jit(fun, **kwds)
  else:
    return pjit.make_jit(fun, **kwds)
# Wrap `jit` in the standard API boundary at runtime only, so the typed
# overloads above remain what type checkers see.
if not TYPE_CHECKING:
  # TODO(slebedev): This ought to be a decorator, but it seems it makes
  # pytype ignore the overloads
  jit = api_boundary(jit, repro_api_name="jax.jit")
@contextmanager
def disable_jit(disable: bool = True):
  """Context manager that disables :py:func:`jit` behavior under its dynamic context.

  For debugging, it is useful to have a mechanism that disables :py:func:`jit`
  everywhere in a dynamic context. Note that this not only disables explicit
  uses of :func:`jit` by the user, but will also remove any implicit JIT compilation
  used by the JAX library: this includes implicit JIT computation of `body` and
  `cond` functions passed to higher-level primitives like :func:`~jax.lax.scan` and
  :func:`~jax.lax.while_loop`, JIT used in implementations of :mod:`jax.numpy` functions,
  and any other case where :func:`jit` is used within an API's implementation.
  Note however that even under `disable_jit`, individual primitive operations
  will still be compiled by XLA as in normal eager op-by-op execution.

  Values that have a data dependence on the arguments to a jitted function are
  traced and abstracted. For example, an abstract value may be a
  :py:class:`ShapedArray` instance, representing the set of all possible arrays
  with a given shape and dtype, but not representing one concrete array with
  specific values. You might notice those if you use a benign side-effecting
  operation in a jitted function, like a print:

  >>> import jax
  >>>
  >>> @jax.jit
  ... def f(x):
  ...   y = x * 2
  ...   print("Value of y is", y)
  ...   return y + 3
  ...
  >>> print(f(jax.numpy.array([1, 2, 3])))
  Value of y is JitTracer(int32[3])
  [5 7 9]

  Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,
  which represents an array with a fixed shape and type but an arbitrary value.
  The value of ``y`` is also traced. If we want to see a concrete value while
  debugging, and avoid the tracer too, we can use the :py:func:`disable_jit`
  context manager:

  >>> import jax
  >>>
  >>> with jax.disable_jit():
  ...   print(f(jax.numpy.array([1, 2, 3])))
  ...
  Value of y is [2 4 6]
  [5 7 9]
  """
  # Delegates to the corresponding config option's context manager.
  with config.disable_jit(disable):
    yield
@partial(api_boundary, repro_api_name="jax.grad")
def grad(fun: Callable, argnums: int | Sequence[int] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = ()) -> Callable:
"""Creates a function that evaluates the gradient of ``fun``.
Args:
fun: Function to be differentiated. Its arguments at positions specified by
``argnums`` should be arrays, scalars, or standard Python containers.
Argument arrays in the positions specified by ``argnums`` must be of
inexact (i.e., floating-point or complex) type. It
should return a scalar (which includes arrays with shape ``()`` but not
arrays with shape ``(1,)`` etc.)
argnums: Optional, integer or sequence of integers. Specifies which
positional argument(s) to differentiate with respect to (default 0).
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int: Optional, bool. Whether to allow differentiating with
respect to integer valued inputs. The gradient of an integer input will
have a trivial vector-space dtype (float0). Default False.
Returns:
A function with the same arguments as ``fun``, that evaluates the gradient
of ``fun``. If ``argnums`` is an integer then the gradient has the same
shape and type as the positional argument indicated by that integer. If
argnums is a tuple of integers, the gradient is a tuple of values with the
same shapes and types as the corresponding arguments. If ``has_aux`` is True
then a pair of (gradient, auxiliary_data) is returned.
For example:
>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
"""
if reduce_axes:
raise NotImplementedError("reduce_axes argument to grad is deprecated")
del reduce_axes
value_and_grad_f = value_and_grad(fun, argnums, has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int)
docstr = ("Gradient of {fun} with respect to positional argument(s) "
"{argnums}. Takes the same arguments as {fun} but returns the "
"gradient, which has the same shape as the arguments at "
"positions {argnums}.")
@wraps(fun, docstr=docstr, argnums=argnums)
@api_boundary
def grad_f(*args, **kwargs):
_, g = value_and_grad_f(*args, **kwargs)
return g
@wraps(fun, docstr=docstr, argnums=argnums)
@api_boundary
def grad_f_aux(*args, **kwargs):
(_, aux), g = value_and_grad_f(*args, **kwargs)
return g, aux
return grad_f_aux if has_aux else grad_f
@partial(api_boundary, repro_api_name="jax.value_and_grad")
def value_and_grad(fun: Callable, argnums: int | Sequence[int] = 0,
                   has_aux: bool = False, holomorphic: bool = False,
                   allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
                   ) -> Callable[..., tuple[Any, Any]]:
  """Create a function that evaluates both ``fun`` and the gradient of ``fun``.

  Args:
    fun: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, or standard Python containers. It
      should return a scalar (which includes arrays with shape ``()`` but not
      arrays with shape ``(1,)`` etc.)
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.

  Returns:
    A function with the same arguments as ``fun`` that evaluates both ``fun``
    and the gradient of ``fun`` and returns them as a pair (a two-element
    tuple). If ``argnums`` is an integer then the gradient has the same shape
    and type as the positional argument indicated by that integer. If argnums is
    a sequence of integers, the gradient is a tuple of values with the same
    shapes and types as the corresponding arguments. If ``has_aux`` is True
    then a tuple of ((value, auxiliary_data), gradient) is returned.
  """
  # Imported here to avoid a module-level import cycle with jax._src.lax.
  from jax._src.lax import lax as lax_internal  # pytype: disable=import-error
  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to grad is deprecated")
  del reduce_axes
  docstr = ("Value and gradient of {fun} with respect to positional "
            "argument(s) {argnums}. Takes the same arguments as {fun} but "
            "returns a two-element tuple where the first element is the value "
            "of {fun} and the second element is the gradient, which has the "
            "same shape as the arguments at positions {argnums}.")
  check_callable(fun)
  argnums = core.concrete_or_error(_ensure_index, argnums)
  @wraps(fun, docstr=docstr, argnums=argnums)
  @api_boundary
  def value_and_grad_f(*args, **kwargs):
    # Every requested argnum must correspond to an actually-passed positional
    # argument; fail early with a clear message otherwise.
    max_argnum = argnums if isinstance(argnums, int) else max(argnums)
    if max_argnum >= len(args):
      raise TypeError(f"differentiating with respect to {argnums=} requires at least "
                      f"{max_argnum + 1} positional arguments to be passed by the caller, "
                      f"but got only {len(args)} positional arguments.")
    dbg = debug_info('value_and_grad', fun, args, kwargs)
    f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
    # Close over the non-differentiated arguments; hashability is not required
    # since they are captured, not used as cache keys here.
    f_partial, dyn_args = argnums_partial(f, argnums, args,
                                          require_static_args_hashable=False)
    for leaf in tree_leaves(dyn_args):
      _check_input_dtype_grad(holomorphic, allow_int, leaf)
    if has_aux:
      ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    else:
      ans, vjp_py = _vjp(f_partial, *dyn_args)
      aux = None
    _check_scalar(ans)
    tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
    # Seed the backward pass with a unit cotangent matching ``ans``.
    g = vjp_py(lax_internal._one_vjp(ans))
    # For a single argnum, return the gradient itself rather than a 1-tuple.
    g = g[0] if isinstance(argnums, int) else g
    if not has_aux:
      return ans, g
    else:
      return (ans, aux), g
  return value_and_grad_f
def _check_scalar(x):
  """Raise ``TypeError`` unless ``x`` has a scalar (shape ``()``) abstract value."""
  template = "Gradient only defined for scalar-output functions. Output {}.".format
  try:
    aval = core.typeof(x)
  except TypeError as err:
    raise TypeError(template(f"was {x}")) from err
  if not isinstance(aval, ShapedArray):
    raise TypeError(template(f"had abstract value {aval}"))
  if aval.shape != ():
    raise TypeError(template(f"had shape: {aval.shape}"))
def _check_input_dtype_revderiv(name, holomorphic, allow_int, x):
  """Validate the dtype of one differentiation input for reverse-mode APIs."""
  dispatch.check_arg(x)
  aval = core.typeof(x)
  dtype = aval.dtype
  if holomorphic and not dtypes.issubdtype(dtype, np.complexfloating):
    raise TypeError(f"{name} with holomorphic=True requires inputs with complex dtype, "
                    f"but got {dtype.name}.")
  if not isinstance(aval, ShapedArray):
    return
  # Integer-like inputs (including bools and extended dtypes) are only
  # accepted when the caller opted in via allow_int.
  int_like = (dtypes.issubdtype(dtype, dtypes.extended) or
              dtypes.issubdtype(dtype, np.integer) or
              dtypes.issubdtype(dtype, np.bool_))
  if int_like:
    if not allow_int:
      raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
                      f"that is a sub-dtype of np.inexact), but got {dtype.name}. "
                      "If you want to use Boolean- or integer-valued inputs, use vjp "
                      "or set allow_int to True.")
  elif not dtypes.issubdtype(dtype, np.inexact):
    raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a "
                    f"sub-dtype of np.bool_ or np.number), but got {dtype.name}.")

# Specialization used by jax.grad / jax.value_and_grad.
_check_input_dtype_grad = partial(_check_input_dtype_revderiv, "grad")
def _check_output_dtype_revderiv(name, holomorphic, x):
  """Validate the dtype of one differentiation output for reverse-mode APIs."""
  dtype = core.typeof(x).dtype
  # Extended element types cannot be differentiated through at all.
  if dtypes.issubdtype(dtype, dtypes.extended):
    raise TypeError(
        f"{name} with output element type {dtype.name}")
  if holomorphic:
    if not dtypes.issubdtype(dtype, np.complexfloating):
      raise TypeError(f"{name} with holomorphic=True requires outputs with complex dtype, "
                      f"but got {dtype.name}.")
  elif dtypes.issubdtype(dtype, np.complexfloating):
    raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
                    f"a sub-dtype of np.floating), but got {dtype.name}. "
                    "For holomorphic differentiation, pass holomorphic=True. "
                    "For differentiation of non-holomorphic functions involving complex "
                    "outputs, use jax.vjp directly.")
  elif not dtypes.issubdtype(dtype, np.floating):
    raise TypeError(f"{name} requires real-valued outputs (output dtype that is "
                    f"a sub-dtype of np.floating), but got {dtype.name}. "
                    "For differentiation of functions with integer outputs, use "
                    "jax.vjp directly.")

# Specialization used by jax.grad / jax.value_and_grad.
_check_output_dtype_grad = partial(_check_output_dtype_revderiv, "grad")
@partial(api_boundary, repro_api_name="jax.fwd_and_bwd")
def fwd_and_bwd(
fun: Callable, argnums: int | Sequence[int], has_aux: bool = False,
jitted: bool = True,
) -> tuple[Callable, Callable]:
"""Creates functions ``fwd`` and ``bwd`` corresponding to the forward and
backward pass of a given function ``fun``. The forward function ``fwd(*args)``
functionally behaves much like ``y, fun_vjp = jax.vjp(fun, *args)``, but allows
reuse of the backward function ``bwd`` across multiple iterations, which is
useful to avoid recompilation when the forward and backward do not end up in a
single jitted function:
>>> import jax
>>>
>>> x = W = cot_out = jax.numpy.ones((4,4))
>>>
>>> def f(x, W):
... return x @ W
...
>>> f_jitted = jax.jit(f)
>>> for i in range(3):
... y, f_vjp = jax.vjp(f_jitted, x, W)
... cot_x, cot_W = f_vjp(cot_out) # not jitted
... cot_x, cot_W = jax.jit(f_vjp)(cot_out) # recompiles on every iteration
...
>>> fwd, bwd = jax.fwd_and_bwd(f, argnums=(0,1))
>>> for i in range(3):
... y, residuals = fwd(x, W)
... cot_x, cot_W = bwd(residuals, cot_out) # jitted, compiles once
...
Args:
fun: Function to produce a forward and backward of.
argnums: Integer or sequence of integers. Specifies which positional argument(s)
to differentiate with respect to.
has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
first element is considered the output of the mathematical function to be
differentiated and the second element is auxiliary data. Default False.
jitted: Optional, bool. Indicates whether to return the ``jax.jit`` of
forward and backward. Note that jit-ing only the backward but not the
forward will result in the backward recompiling on every invocation, so we
default to jit-ing both.
Returns:
The two functions, ``fwd`` and ``bwd``.
If ``has_aux`` is ``False``, ``fwd(*primals)`` returns a tuple
``(primals_out, residuals)``, where ``primals_out`` is ``fun(*primals)``.
If ``has_aux`` is ``True``, returns a ``(primals_out, residuals, aux)`` tuple
where ``aux`` is the auxiliary data returned by ``fun``.
``bwd`` is a function from ``residuals`` and a cotangent vector with the same
shape as ``primals_out`` to a tuple of cotangent vectors with the same number
and shapes as the ``primals`` designated by ``argnums``, representing the
vector-Jacobian product of ``fun`` evaluated at ``primals``.
"""
check_callable(fun)
argnums = _ensure_index(argnums)
def fwd(*args, **kwargs):
dbg = debug_info('fwd_and_bwd', fun, args, kwargs)
f = lu.wrap_init(fun, params=kwargs, debug_info=dbg)
f_partial, dyn_args = argnums_partial(
f, argnums, args, require_static_args_hashable=False)
return _vjp(f_partial, *dyn_args, has_aux=has_aux)
def bwd(f_vjp, outgrad):
g = f_vjp(outgrad)
g = g[0] if isinstance(argnums, int) else g
return g
if jitted:
fwd = jit(fwd)
bwd = jit(bwd)
return fwd, bwd
@partial(api_boundary, repro_api_name="jax.jacfwd")
def jacfwd(fun: Callable, argnums: int | Sequence[int] = 0,
           has_aux: bool = False, holomorphic: bool = False) -> Callable:
  """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.

  Args:
    fun: Function whose Jacobian is to be computed.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the Jacobian of
    ``fun`` using forward-mode automatic differentiation. If ``has_aux`` is True
    then a pair of (jacobian, auxiliary_data) is returned.

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x):
  ...   return jnp.asarray(
  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
  ...
  >>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
  [[ 1.       0.       0.     ]
   [ 0.       0.       5.     ]
   [ 0.      16.      -2.     ]
   [ 1.6209   0.       0.84147]]
  """
  check_callable(fun)
  argnums = _ensure_index(argnums)
  docstr = ("Jacobian of {fun} with respect to positional argument(s) "
            "{argnums}. Takes the same arguments as {fun} but returns the "
            "jacobian of the output with respect to the arguments at "
            "positions {argnums}.")
  @wraps(fun, docstr=docstr, argnums=argnums)
  def jacfun(*args, **kwargs):
    f = lu.wrap_init(
        fun, kwargs,
        debug_info=debug_info(
            "jacfwd", fun, args, kwargs,
            static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
    # Close over non-differentiated args; they need not be hashable here.
    f_partial, dyn_args = argnums_partial(f, argnums, args,
                                          require_static_args_hashable=False)
    tree_map(partial(_check_input_dtype_jacfwd, holomorphic), dyn_args)
    # One JVP per input basis vector; vmapping over the standard basis of the
    # input space yields the Jacobian column-by-column (out_axes=-1 stacks the
    # columns along a new trailing axis). The primal output y (and aux) is the
    # same for every basis vector, hence out_axes=None for those.
    pushfwd: Callable = partial(_jvp, f_partial, dyn_args, has_aux=has_aux)
    if has_aux:
      y, jac, aux = vmap(pushfwd, out_axes=(None, -1, None))(_std_basis(dyn_args))
    else:
      y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
      aux = None
    tree_map(partial(_check_output_dtype_jacfwd, holomorphic), y)
    # Reshape the flat column stack back to the pytree structure of the inputs.
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    jac_tree = tree_map(partial(_jacfwd_unravel, example_args), y, jac)
    if not has_aux:
      return jac_tree
    else:
      return jac_tree, aux
  return jacfun
def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
  """Validate the dtype of one input to :func:`jacfwd`."""
  dispatch.check_arg(x)
  dtype = core.typeof(x).dtype
  if dtypes.issubdtype(dtype, dtypes.extended):
    raise TypeError(
        f"jacfwd with input element type {dtype.name}")
  if holomorphic:
    if not dtypes.issubdtype(dtype, np.complexfloating):
      raise TypeError("jacfwd with holomorphic=True requires inputs with complex "
                      f"dtype, but got {dtype.name}.")
  elif not dtypes.issubdtype(dtype, np.floating):
    raise TypeError("jacfwd requires real-valued inputs (input dtype that is "
                    f"a sub-dtype of np.floating), but got {dtype.name}. "
                    "For holomorphic differentiation, pass holomorphic=True. "
                    "For differentiation of non-holomorphic functions involving "
                    "complex inputs or integer inputs, use jax.jvp directly.")
def _check_output_dtype_jacfwd(holomorphic, x):
  """With holomorphic=True, require a complex output dtype; otherwise no-op."""
  if not holomorphic:
    return
  dtype = core.typeof(x).dtype
  if not dtypes.issubdtype(dtype, np.complexfloating):
    raise TypeError("jacfwd with holomorphic=True requires outputs with complex dtype, "
                    f"but got {dtype.name}.")
@partial(api_boundary, repro_api_name="jax.jacrev")
def jacrev(fun: Callable, argnums: int | Sequence[int] = 0,
           has_aux: bool = False, holomorphic: bool = False,
           allow_int: bool = False) -> Callable:
  """Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.

  Args:
    fun: Function whose Jacobian is to be computed.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.

  Returns:
    A function with the same arguments as ``fun``, that evaluates the Jacobian of
    ``fun`` using reverse-mode automatic differentiation. If ``has_aux`` is True
    then a pair of (jacobian, auxiliary_data) is returned.

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x):
  ...   return jnp.asarray(
  ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
  ...
  >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
  [[ 1.       0.       0.     ]
   [ 0.       0.       5.     ]
   [ 0.      16.      -2.     ]
   [ 1.6209   0.       0.84147]]
  """
  check_callable(fun)
  # NOTE(review): unlike jacfwd, argnums is not normalized via _ensure_index
  # here — confirm this asymmetry is intentional.
  docstr = ("Jacobian of {fun} with respect to positional argument(s) "
            "{argnums}. Takes the same arguments as {fun} but returns the "
            "jacobian of the output with respect to the arguments at "
            "positions {argnums}.")
  @wraps(fun, docstr=docstr, argnums=argnums)
  def jacfun(*args, **kwargs):
    f = lu.wrap_init(
        fun, kwargs,
        debug_info=debug_info(
            "jacrev", fun, args, kwargs,
            static_argnums=(argnums,) if isinstance(argnums, int) else argnums))
    # Close over non-differentiated args; they need not be hashable here.
    f_partial, dyn_args = argnums_partial(f, argnums, args,
                                          require_static_args_hashable=False)
    tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
    if has_aux:
      y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)
    else:
      y, pullback = _vjp(f_partial, *dyn_args)
      aux = None
    tree_map(partial(_check_output_dtype_jacrev, holomorphic), y)
    # One VJP per output basis vector: vmapping the pullback over the standard
    # basis of the output space yields the Jacobian row-by-row.
    jac = vmap(pullback)(_std_basis(y))
    # For a single argnum, unwrap the 1-tuple of cotangents.
    jac = jac[0] if isinstance(argnums, int) else jac
    example_args = dyn_args[0] if isinstance(argnums, int) else dyn_args
    jac_tree = tree_map(partial(_jacrev_unravel, y), example_args, jac)
    # Transpose so the output pytree structure is outermost, matching the
    # Jacobian convention (output axes before input axes).
    jac_tree = tree_transpose(tree_structure(example_args), tree_structure(y), jac_tree)
    if not has_aux:
      return jac_tree
    else:
      return jac_tree, aux
  return jacfun
def jacobian(fun: Callable, argnums: int | Sequence[int] = 0,
             has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable:
  """Alias of :func:`jax.jacrev`."""
  # jax.jacobian has always been the reverse-mode Jacobian; use jax.jacfwd
  # explicitly for the forward-mode variant.
  return jacrev(fun, argnums=argnums, has_aux=has_aux,
                holomorphic=holomorphic, allow_int=allow_int)
# Specialize the shared reverse-mode dtype checks so their error messages
# report "jacrev" as the API name.
_check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev")
_check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev")
@partial(api_boundary, repro_api_name="jax.hessian")
def hessian(fun: Callable, argnums: int | Sequence[int] = 0,
            has_aux: bool = False, holomorphic: bool = False) -> Callable:
  """Hessian of ``fun`` as a dense array.
  Args:
    fun: Function whose Hessian is to be computed.  Its arguments at positions
      specified by ``argnums`` should be arrays, scalars, or standard Python
      containers thereof. It should return arrays, scalars, or standard Python
      containers thereof.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default ``0``).
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
      holomorphic. Default False.
  Returns:
    A function with the same arguments as ``fun``, that evaluates the Hessian of
    ``fun``.
  >>> import jax
  >>>
  >>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
  >>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
  [[   6.   -2.]
   [  -2. -480.]]
  :py:func:`hessian` is a generalization of the usual definition of the Hessian
  that supports nested Python containers (i.e. pytrees) as inputs and outputs.
  The tree structure of ``jax.hessian(fun)(x)`` is given by forming a tree
  product of the structure of ``fun(x)`` with a tree product of two copies of
  the structure of ``x``. A tree product of two tree structures is formed by
  replacing each leaf of the first tree with a copy of the second. For example:
  >>> import jax.numpy as jnp
  >>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
  >>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
  {'c': {'a': {'a': Array([[[ 2., 0.], [ 0., 0.]],
                           [[ 0., 0.], [ 0., 12.]]], dtype=float32),
               'b': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                           [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
         'b': {'a': Array([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                           [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
               'b': Array([[[0.      , 0.      ], [0.      , 0.      ]],
                           [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}
  Thus each leaf in the tree structure of ``jax.hessian(fun)(x)`` corresponds to
  a leaf of ``fun(x)`` and a pair of leaves of ``x``. For each leaf in
  ``jax.hessian(fun)(x)``, if the corresponding array leaf of ``fun(x)`` has
  shape ``(out_1, out_2, ...)`` and the corresponding array leaves of ``x`` have
  shape ``(in_1_1, in_1_2, ...)`` and ``(in_2_1, in_2_2, ...)`` respectively,
  then the Hessian leaf has shape ``(out_1, out_2, ..., in_1_1, in_1_2, ...,
  in_2_1, in_2_2, ...)``. In other words, the Python tree structure represents
  the block structure of the Hessian, with blocks determined by the input and
  output pytrees.
  In particular, an array is produced (with no pytrees involved) when the
  function input ``x`` and output ``fun(x)`` are each a single array, as in the
  ``g`` example above. If ``fun(x)`` has shape ``(out1, out2, ...)`` and ``x``
  has shape ``(in1, in2, ...)`` then ``jax.hessian(fun)(x)`` has shape
  ``(out1, out2, ..., in1, in2, ..., in1, in2, ...)``. To flatten pytrees into
  1D vectors, consider using :py:func:`jax.flatten_util.flatten_pytree`.
  """
  # Forward-over-reverse composition: jacfwd of jacrev.
  return jacfwd(jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic),
                argnums, has_aux=has_aux, holomorphic=holomorphic)
def _insert_pvary(basis, leaf):
  """Apply `leaf`'s varying axes to `basis` when vma checking is enabled."""
  # No-op unless the _check_vma flag is on; otherwise pvary the basis with
  # the leaf's varying-axes set (core.typeof(leaf).mat.varying).
  if config._check_vma.value:
    varying = tuple(core.typeof(leaf).mat.varying)
    return core.pvary(basis, varying)
  return basis
def _std_basis(pytree):
  """Build the standard basis over the flattened leaves of `pytree`.

  Returns a pytree with the same structure as `pytree` whose leaves carry the
  rows of the identity matrix on the flattened input space, split along a
  leading axis of total size equal to the number of input elements.
  """
  import jax.numpy as jnp  # pytype: disable=import-error
  leaves, _ = tree_flatten(pytree)
  ndim = sum(map(np.size, leaves))
  dtype = dtypes.result_type(*leaves)
  flat_basis = jnp.eye(ndim, dtype=dtype)
  axis = 1
  arr_s = [None] * flat_basis.ndim
  # Preserve each leaf's sharding spec on the non-basis dimensions. The
  # slices of `arr_s` must be unpacked so each entry of P is a single axis
  # spec; passing the list slices unsplatted (P(arr_s[:axis], ...)) would
  # hand P list-valued entries like [None] and [] instead of per-axis specs,
  # which is what the parallel code in _jacfwd_unravel does correctly.
  specs = tree_map(lambda l: P(*arr_s[:axis], *core.typeof(l).sharding.spec,
                               *arr_s[axis+1:]), pytree)
  out_pytree = _unravel_array_into_pytree(pytree, axis, None, flat_basis, specs)
  out_pytree = tree_map(_insert_pvary, out_pytree, pytree)
  return out_pytree
def _jacfwd_unravel(input_pytree, output_pytree_leaf, arr):
  """Unravel `arr`'s trailing axis into the structure of `input_pytree`."""
  # The basis dimension lives on the last axis of `arr`; normalize -1 to a
  # nonnegative index so the slice arithmetic below works.
  last = -1 % arr.ndim
  arr_spec = core.typeof(arr).sharding.spec
  leaf_specs = tree_map(
      lambda leaf: P(*arr_spec[:last], *[None] * len(np.shape(leaf)),
                     *arr_spec[last+1:]),
      input_pytree)
  return _unravel_array_into_pytree(
      input_pytree, last, output_pytree_leaf, arr, leaf_specs)
def _jacrev_unravel(output_pytree, input_pytree_leaf, arr):
  """Unravel `arr`'s leading axis into the structure of `output_pytree`."""
  # Keep arr's sharding on all but the split axis (axis 0); leaf dimensions
  # get no spec of their own.
  arr_spec = core.typeof(arr).sharding.spec
  leaf_specs = tree_map(
      lambda leaf: P(*[None] * len(np.shape(leaf)), *arr_spec[1:]),
      output_pytree)
  return _unravel_array_into_pytree(
      output_pytree, 0, input_pytree_leaf, arr, leaf_specs)
def _possible_downcast(x, example, spec):
  """Cast `x` to `example`'s dtype/weak_type with sharding `spec`, dropping
  the imaginary part when `example` is not complex."""
  from jax._src.lax import lax as lax_internal  # pytype: disable=import-error
  target_dtype = _dtype(example)
  complex_to_real = (
      dtypes.issubdtype(x.dtype, np.complexfloating)
      and not dtypes.issubdtype(target_dtype, np.complexfloating))
  if complex_to_real:
    x = x.real
  weak = dtypes.is_weakly_typed(example)
  out_sharding = NamedSharding(core.typeof(example).sharding.mesh, spec)
  return lax_internal._convert_element_type(
      x, target_dtype, weak, sharding=out_sharding)
def _unravel_array_into_pytree(pytree, axis, example, arr, specs):
  """Unravel an array into a PyTree with a given structure.
  Args:
    pytree: The pytree that provides the structure.
    axis: The parameter axis is either -1, 0, or 1.  It controls the
      resulting shapes.
    example: If specified, cast the components to the matching dtype/weak_type,
      or else use the pytree leaf type if example is None.
    arr: The array to be unraveled.
    specs: A pytree of sharding specs, matching `pytree`, applied to the
      corresponding unraveled components.
  """
  leaves, treedef = tree_flatten(pytree)
  specs, _ = tree_flatten(specs)
  shapes = [arr.shape[:axis] + np.shape(l) + arr.shape[axis+1:] for l in leaves]
  # Split points are the cumulative sizes of all but the last leaf. The
  # sizes must be materialized as a list: np.cumsum over a bare `map` object
  # wraps the iterator in a 0-d object array instead of summing the sizes.
  parts = _split(arr, np.cumsum([np.size(l) for l in leaves[:-1]]), axis)
  reshaped_parts = [
      _possible_downcast(np.reshape(x, shape),
                         leaf if example is None else example,
                         spec=spec)
      for x, shape, leaf, spec in zip(parts, shapes, leaves, specs)]
  return tree_unflatten(treedef, reshaped_parts)
def _split(x, indices, axis):
if isinstance(x, np.ndarray):
return np.split(x, indices, axis)
else:
return x._split(indices, axis)
@partial(api_boundary, repro_api_name="jax.vmap")
def vmap(fun: F,
         in_axes: int | None | Sequence[Any] = 0,
         out_axes: Any = 0,
         axis_name: AxisName | None = None,
         axis_size: int | None = None,
         spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
         sum_match: bool = False
         ) -> F:
  """Vectorizing map. Creates a function which maps ``fun`` over argument axes.
  Args:
    fun: Function to be mapped over additional axes.
    in_axes: An integer, None, or sequence of values specifying which input
      array axes to map over.
      If each positional argument to ``fun`` is an array, then ``in_axes`` can
      be an integer, a None, or a tuple of integers and Nones with length equal
      to the number of positional arguments to ``fun``. An integer or ``None``
      indicates which array axis to map over for all arguments (with ``None``
      indicating not to map any axis), and a tuple indicates which axis to map
      for each corresponding positional argument. Axis integers must be in the
      range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of
      dimensions (axes) of the corresponding input array.
      If the positional arguments to ``fun`` are container (pytree) types, ``in_axes``
      must be a sequence with length equal to the number of positional arguments to
      ``fun``, and for each argument the corresponding element of ``in_axes`` can
      be a container with a matching pytree structure specifying the mapping of its
      container elements. In other words, ``in_axes`` must be a container tree prefix
      of the positional argument tuple passed to ``fun``. See this link for more detail:
      https://docs.jax.dev/en/latest/pytrees.html#applying-optional-parameters-to-pytrees
      Either ``axis_size`` must be provided explicitly, or at least one
      positional argument must have ``in_axes`` not None. The sizes of the
      mapped input axes for all mapped positional arguments must all be equal.
      Arguments passed as keywords are always mapped over their leading axis
      (i.e. axis index 0).
      See below for examples.
    out_axes: An integer, None, or (nested) standard Python container
      (tuple/list/dict) thereof indicating where the mapped axis should appear
      in the output. All outputs with a mapped axis must have a non-None
      ``out_axes`` specification. Axis integers must be in the range ``[-ndim,
      ndim)`` for each output array, where ``ndim`` is the number of dimensions
      (axes) of the array returned by the :func:`vmap`-ed function, which is one
      more than the number of dimensions (axes) of the corresponding array
      returned by ``fun``.
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
    axis_size: Optional, an integer indicating the size of the axis to be
      mapped. If not provided, the mapped axis size is inferred from arguments.
  Returns:
    Batched/vectorized version of ``fun`` with arguments that correspond to
    those of ``fun``, but with extra array axes at positions indicated by
    ``in_axes``, and a return value that corresponds to that of ``fun``, but
    with extra array axes at positions indicated by ``out_axes``.
  For example, we can implement a matrix-matrix product using a vector dot
  product:
  >>> import jax.numpy as jnp
  >>>
  >>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
  >>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
  >>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
  Here we use ``[a,b]`` to indicate an array with shape (a,b). Here are some
  variants:
  >>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
  >>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
  >>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
  Here's an example of using container types in ``in_axes`` to specify which
  axes of the container elements to map over:
  >>> A, B, C, D = 2, 3, 4, 5
  >>> x = jnp.ones((A, B))
  >>> y = jnp.ones((B, C))
  >>> z = jnp.ones((C, D))
  >>> def foo(tree_arg):
  ...   x, (y, z) = tree_arg
  ...   return jnp.dot(x, jnp.dot(y, z))
  >>> tree = (x, (y, z))
  >>> print(foo(tree))
  [[12. 12. 12. 12. 12.]
   [12. 12. 12. 12. 12.]]
  >>> from jax import vmap
  >>> K = 6  # batch size
  >>> x = jnp.ones((K, A, B))  # batch axis in different locations
  >>> y = jnp.ones((B, K, C))
  >>> z = jnp.ones((C, D, K))
  >>> tree = (x, (y, z))
  >>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
  >>> print(vfoo(tree).shape)
  (6, 2, 5)
  Here's another example using container types in ``in_axes``, this time a
  dictionary, to specify the elements of the container to map over:
  >>> dct = {'a': 0., 'b': jnp.arange(5.)}
  >>> x = 1.
  >>> def foo(dct, x):
  ...  return dct['a'] + dct['b'] + x
  >>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
  >>> print(out)
  [1. 2. 3. 4. 5.]
  The results of a vectorized function can be mapped or unmapped. For example,
  the function below returns a pair with the first element mapped and the second
  unmapped. Only for unmapped results we can specify ``out_axes`` to be ``None``
  (to keep it unmapped).
  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
  (Array([4., 5.], dtype=float32), 8.0)
  If the ``out_axes`` is specified for an unmapped result, the result is
  broadcast across the mapped axis:
  >>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
  (Array([4., 5.], dtype=float32), Array([8., 8.], dtype=float32, weak_type=True))
  If the ``out_axes`` is specified for a mapped result, the result is transposed
  accordingly.
  Finally, here's an example using ``axis_name`` together with collectives:
  >>> xs = jnp.arange(3. * 4.).reshape(3, 4)
  >>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))
  [[12. 15. 18. 21.]
   [12. 15. 18. 21.]
   [12. 15. 18. 21.]]
  See the :py:func:`jax.pmap` docstring for more examples involving collectives.
  """
  check_callable(fun)
  docstr = ("Vectorized version of {fun}. Takes similar arguments as {fun} "
            "but with additional array axes over which {fun} is mapped.")
  if fun.__doc__:
    docstr += "\n\nOriginal documentation:\n\n"
    docstr += fun.__doc__
  # Normalize transform-time arguments: default axis name, and canonicalize
  # spmd_axis_name to a tuple.
  axis_name = core.no_axis_name if axis_name is None else axis_name
  if spmd_axis_name is not None and not isinstance(spmd_axis_name, tuple):
    spmd_axis_name = (spmd_axis_name,)
  if isinstance(in_axes, list):
    # To be a tree prefix of the positional args tuple, in_axes can never be a
    # list: if in_axes is not a leaf, it must be a tuple of trees. However,
    # in cases like these users expect tuples and lists to be treated
    # essentially interchangeably, so we canonicalize lists to tuples here
    # rather than raising an error. https://github.com/jax-ml/jax/issues/2367
    in_axes = tuple(in_axes)
  from jax._src import hijax  # pytype: disable=import-error
  # Validate in_axes/out_axes leaf types eagerly, at transform time, so users
  # get a clear error before the mapped function is ever called.
  if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}
          or isinstance(in_axes, hijax.MappingSpec)):
    raise TypeError("vmap in_axes must be an int, None, or a tuple of entries corresponding "
                    f"to the positional arguments passed to the function, but got {in_axes}.")
  if not all(type(l) in {int, *batching.spec_types} or isinstance(l, hijax.MappingSpec)
             for l in tree_leaves(in_axes)):
    raise TypeError("vmap in_axes must be an int, None, or (nested) container "
                    f"with those types as leaves, but got {in_axes}.")
  if not all(type(l) in {int, *batching.spec_types} or isinstance(l, hijax.MappingSpec)
             for l in tree_leaves(out_axes)):
    raise TypeError("vmap out_axes must be an int, None, or (nested) container "
                    f"with those types as leaves, but got {out_axes}.")
  @wraps(fun, docstr=docstr)
  @api_boundary
  def vmap_f(*args, **kwargs):
    nonlocal spmd_axis_name
    if isinstance(in_axes, tuple) and len(in_axes) != len(args):
      raise ValueError("vmap in_axes must be an int, None, or a tuple of entries corresponding "
                       "to the positional arguments passed to the function, "
                       f"but got {len(in_axes)=}, {len(args)=}")
    # Flatten args/kwargs to a list of leaves and match in_axes against the
    # same tree structure (kwargs are always mapped along axis 0).
    args_flat, in_tree  = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
    dbg = debug_info("vmap", fun, args, kwargs)
    api_util.check_no_transformed_refs_args(lambda: dbg, args_flat)
    f = lu.wrap_init(fun, debug_info=dbg)
    flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
    if config.mutable_array_checks.value:
      avals = [None if d is None or batching.is_vmappable(x) else core.typeof(x)
               for x, d in zip(args_flat, in_axes_flat)]
      api_util.check_no_aliased_ref_args(lambda: dbg, avals, args_flat)
    axis_size_ = _mapped_axis_size(
        fun, in_tree, args_flat, in_axes_flat, "vmap", axis_size=axis_size)
    # If inputs are sharded on an Explicit mesh axis along the mapped
    # dimension, that axis takes the role of spmd_axis_name; the two are
    # mutually exclusive (modulo size-1 mesh axes, which may be dropped).
    explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat)
    _check_ema_unmapped_args(explicit_mesh_axis, args_flat, in_axes_flat)
    if spmd_axis_name is not None and explicit_mesh_axis is not None:
      if config.remove_size_one_mesh_axis_from_type.value:
        mesh = get_abstract_mesh()
        spmd_axis_name = tuple(i for i in spmd_axis_name if mesh.shape[i] != 1)
      if spmd_axis_name == explicit_mesh_axis:
        spmd_axis_name = None
      else:
        raise ValueError(
            "Only one of spmd_axis_name or arrays sharded on `Explicit` mesh"
            f" axis type is allowed. Got {spmd_axis_name=} and"
            f" arrays sharded on {explicit_mesh_axis=}")
      assert spmd_axis_name is None
    try:
      axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name,
                                    explicit_mesh_axis)
      # out_axes is resolved lazily (as a thunk) because the output tree is
      # only known after tracing the function.
      out_flat, inferred_out_axes = batching.batch(
          flat_fun, axis_data, in_axes_flat,
          lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
          sum_match=sum_match
      ).call_wrapped(*args_flat)
    except batching.SpecMatchError as e:
      # Re-raise with the key path of the offending out_axes leaf for a
      # user-friendly message.
      out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)
      out_axes_full = tree_unflatten(out_tree(), out_axes_flat)
      pairs, _ = tree_flatten_with_path(out_axes_full, is_leaf=lambda x: x is None)
      path, _ = pairs[e.leaf_idx]
      raise ValueError(f'at vmap out_axes{keystr(path)}, got axis spec {e.dst} '
                       f'but output was batched on axis {e.src}') from None
    if any(d is batching.infer for d in tree_leaves(out_axes)):
      return (tree_unflatten(out_tree(), out_flat),
              tree_unflatten(out_tree(), inferred_out_axes))
    else:
      return tree_unflatten(out_tree(), out_flat)
  return cast(F, vmap_f)
def _mapped_axis_spec(args_flat, in_axes_flat):
  """Return the common sharding spec of the mapped dimension, if any.

  Raises if different mapped arguments disagree on that spec.
  """
  def spec_at(arg, axis):
    try:
      # Duck type arrays like BCOO arrays can be passed to vmap.
      return shaped_abstractify(arg).sharding.spec[axis]
    except (IndexError, TypeError, AttributeError):
      return None

  result = None
  seen = 0
  for arg, axis in zip(args_flat, in_axes_flat):
    if axis is None:
      continue
    current = spec_at(arg, axis)
    if seen and result != current:
      raise ValueError(
          "Mapped away dimension of inputs passed to vmap should be sharded"
          f" the same. Got inconsistent axis specs: {result} vs {current}")
    result = current
    seen += 1
  # Canonicalize a bare axis name to a one-element tuple.
  if result is not None and not isinstance(result, tuple):
    result = (result,)
  return result
def _check_ema_unmapped_args(ema, args_flat, in_axes_flat):
  """Reject unmapped args sharded along the explicit mesh axis being vmapped."""
  # Nothing to check unless we are vmapping over an explicit mesh axis.
  if ema is None:
    return
  for arg, in_axis in zip(args_flat, in_axes_flat):
    if in_axis is not None:
      continue
    aval = core.typeof(arg)
    arg_axes = set(sharding_impls.flatten_spec(aval.sharding.spec))
    if not arg_axes.isdisjoint(ema):
      raise ValueError(
          "Unmapped values passed to vmap cannot be sharded along the mesh"
          f" axis you are vmapping over. Got type: {aval.str_short(True)},"
          f" in_axes: {in_axis} and vmapped mesh axis: {ema}")
def _mapped_axis_size(fn, tree, vals, dims, name, axis_size=None):
  """Infer the common mapped-axis size from `vals`/`dims`, or raise a
  detailed error (naming the offending arguments) when sizes disagree.

  Args:
    fn: the user function, used only to recover parameter names for errors.
    tree: the pytree structure of the (args, kwargs) pair `vals` flattens.
    vals: flat list of argument leaves.
    dims: flat list of mapped axes, aligned with `vals` (None = unmapped).
    name: API name (e.g. "vmap") used in error messages.
    axis_size: optional explicit axis size to check against.
  """
  if not vals:
    if axis_size is not None:
      return axis_size
    args, kwargs = tree_unflatten(tree, vals)
    raise ValueError(
        f"{name} wrapped function must be passed at least one argument "
        "containing an array or axis_size must be specified, got empty "
        f"*args={args} and **kwargs={kwargs}"
    )
  def _get_axis_size(name: str, x, shape: tuple[core.AxisSize, ...], axis: int
                     ) -> core.AxisSize | None:
    try:
      return shape[axis]
    except (IndexError, TypeError) as e:
      if not core.valid_jaxtype(x) or not isinstance(axis, int):
        return None  # Suppress the check for custom vmappable types.
      min_rank = axis + 1 if axis >= 0 else -axis
      # TODO(mattjj): better error message here
      raise ValueError(
          f"{name} was requested to map its argument along axis {axis}, "
          f"which implies that its rank should be at least {min_rank}, "
          f"but is only {len(shape)} (its shape is {shape})") from e
  all_mapped_sizes = [
      None if d is None else _get_axis_size(name, x, np.shape(x), d)
      for x, d in zip(vals, dims)
  ]
  all_sizes = [s for s in all_mapped_sizes if s is not None]
  if axis_size is not None:
    all_sizes.append(axis_size)
  # Deduplicate symbolic/concrete sizes; exactly one distinct size means
  # everything is consistent.
  sizes = core.dedup_referents(all_sizes)
  if len(sizes) == 1:
    sz, = sizes
    return sz
  if not sizes:
    raise ValueError(f"{name} must have at least one non-None value in in_axes "
                     "or axis_size must be specified")
  # Inconsistent sizes: the rest of this function only builds the error
  # message, attributing each size to a named argument where possible.
  def _get_argument_type(x):
    try:
      return shaped_abstractify(x).str_short()
    except TypeError:  # Catch all for user specified objects that can't be interpreted as a data type
      return "unknown"
  msg = [f"{name} got inconsistent sizes for array axes to be mapped:\n"]
  args, kwargs = tree_unflatten(tree, vals)
  try:
    ba = inspect.signature(fn).bind(*args, **kwargs)
    signature_parameters: list[str] | None = list(ba.signature.parameters.keys())
  except (TypeError, ValueError):
    signature_parameters = None
  def arg_name(key_path):
    if signature_parameters is None:
      return f"args{keystr(key_path)}"
    # args is a tuple, so key_path[0].idx is the index into args.
    i = key_path[0].idx
    # This can happen with star arguments (*args)
    if i >= len(signature_parameters):
      return f"args{keystr(key_path)}"
    res = f"argument {signature_parameters[i]}"
    if len(key_path) > 1:
      res += keystr(key_path[1:])
    return res
  args_paths = [
      f"{arg_name(p)} of type {_get_argument_type(x)}"
      for (p, x) in generate_key_paths(args)
  ]
  kwargs_paths = [
      f"kwargs{keystr(p)} of type {_get_argument_type(x)}"
      for p, x in generate_key_paths(kwargs)
  ]
  key_paths = [*args_paths, *kwargs_paths]
  # Group sizes by frequency so the most common size is reported first.
  size_counts = collections.Counter(s for s in all_mapped_sizes if s is not None)
  (sz, ct), *other_counts = counts = size_counts.most_common()
  def _all_sizes_index(sz):
    for i, isz in enumerate(all_mapped_sizes):
      if core.definitely_equal(isz, sz): return i
    assert False, (sz, all_mapped_sizes)
  ex, *examples = (key_paths[_all_sizes_index(sz)] for sz, _ in counts)
  ax, *axs = (dims[_all_sizes_index(sz)] for sz, _ in counts)
  if axis_size is not None:
    msg.append(f"  * the `axis_size` argument was {axis_size};\n")
  if ct == 1:
    msg.append(f"  * one axis had size {sz}: axis {ax} of {ex};\n")
  else:
    msg.append(f"  * most axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n")
  for ex, ax, (sz, ct) in zip(examples, axs, other_counts):
    if ct == 1:
      msg.append(f"  * one axis had size {sz}: axis {ax} of {ex};\n")
    else:
      msg.append(f"  * some axes ({ct} of them) had size {sz}, e.g. axis {ax} of {ex};\n")
  raise ValueError(''.join(msg)[:-2])  # remove last semicolon and newline
@partial(api_boundary, repro_api_name="jax.jvp")
def jvp(
    fun: Callable, primals, tangents, has_aux: bool = False
) -> tuple[Any, ...]:
  """Computes a (forward-mode) Jacobian-vector product of ``fun``.
  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: The primal values at which the Jacobian of ``fun`` should be
      evaluated. Should be either a tuple or a list of arguments,
      and its length should be equal to the number of positional parameters of
      ``fun``.
    tangents: The tangent vector for which the Jacobian-vector product should be
      evaluated. Should be either a tuple or a list of tangents, with the same
      tree structure and array shapes as ``primals``.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
  Returns:
    If ``has_aux`` is ``False``, returns a ``(primals_out, tangents_out)`` pair,
    where ``primals_out`` is ``fun(*primals)``,
    and ``tangents_out`` is the Jacobian-vector product of
    ``function`` evaluated at ``primals`` with ``tangents``. The
    ``tangents_out`` value has the same Python tree structure and shapes as
    ``primals_out``. If ``has_aux`` is ``True``, returns a
    ``(primals_out, tangents_out, aux)`` tuple where ``aux``
    is the auxiliary data returned by ``fun``.
  For example:
  >>> import jax
  >>>
  >>> primals, tangents = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
  >>> print(primals)
  0.09983342
  >>> print(tangents)
  0.19900084
  """
  check_callable(fun)
  # Require tuple/list containers up front so passing a bare array (a common
  # mistake) yields a clear error rather than a confusing tree mismatch.
  if (not isinstance(primals, (tuple, list)) or
      not isinstance(tangents, (tuple, list))):
    raise TypeError("primal and tangent arguments to jax.jvp must be tuples or lists; "
                    f"found {type(primals).__name__} and {type(tangents).__name__}.")
  return _jvp(lu.wrap_init(fun, debug_info=debug_info("jvp", fun, primals, {})),
              primals, tangents, has_aux=has_aux)
def _jvp(fun: lu.WrappedFun, primals, tangents, has_aux=False):
  """Variant of jvp() that takes an lu.WrappedFun."""
  ps_flat, tree_def = tree_flatten(primals)
  ts_flat, tree_def_2 = tree_flatten(tangents)
  if tree_def != tree_def_2:
    raise TypeError("primal and tangent arguments to jax.jvp must have the same tree "
                    f"structure; primals have tree structure {tree_def} whereas tangents have "
                    f"tree structure {tree_def_2}.")
  # Validate each primal/tangent leaf pair: the tangent's dtype must be the
  # tangent dtype of the primal (float0 for int/bool primals) and the shapes
  # must match exactly.
  for p, t in zip(ps_flat, ts_flat):
    if not isinstance(core.typeof(p), ShapedArray): continue
    if core.primal_dtype_to_tangent_dtype(_dtype(p)) != _dtype(t):
      raise TypeError("primal and tangent arguments to jax.jvp do not match; "
                      "dtypes must be equal, or in case of int/bool primal dtype "
                      "the tangent dtype must be float0."
                      f"Got primal dtype {_dtype(p)} and so expected tangent dtype "
                      f"{core.primal_dtype_to_tangent_dtype(_dtype(p))}, but got "
                      f"tangent dtype {_dtype(t)} instead.")
    if np.shape(p) != np.shape(t):
      raise ValueError("jvp called with different primal and tangent shapes;"
                       f"Got primal shape {np.shape(p)} and tangent shape as {np.shape(t)}")
  if not has_aux:
    flat_fun, out_tree = flatten_fun_nokwargs(fun, tree_def)
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
    # out_tree is a thunk that is only valid after the call above.
    out_tree = out_tree()
    return (tree_unflatten(out_tree, out_primals),
            tree_unflatten(out_tree, out_tangents))
  else:
    flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, tree_def)
    jvp_fun, aux = ad.jvp(flat_fun, has_aux=True)
    out_primals, out_tangents = jvp_fun.call_wrapped(ps_flat, ts_flat)
    out_tree, aux_tree = out_aux_trees()
    return (tree_unflatten(out_tree, out_primals),
            tree_unflatten(out_tree, out_tangents),
            tree_unflatten(aux_tree, aux()))
# Typing overloads: without aux, linearize returns (primals_out, jvp_fn);
# with has_aux=True it returns (primals_out, jvp_fn, aux).
@overload
def linearize(fun: Callable, *primals, has_aux: Literal[False] = False
              ) -> tuple[Any, Callable]:
  ...
@overload
def linearize(fun: Callable, *primals, has_aux: Literal[True]
              ) -> tuple[Any, Callable, Any]:
  ...
@partial(api_boundary, repro_api_name="jax.linearize")
def linearize(fun: Callable, *primals, has_aux: bool = False
              ) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
  """Produces a linear approximation to ``fun`` using :py:func:`jvp` and partial eval.
  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard python container of arrays or scalars.
    primals: The primal values at which the Jacobian of ``fun`` should be
      evaluated. Should be a tuple of arrays, scalar, or standard Python
      container thereof. The length of the tuple is equal to the number of
      positional parameters of ``fun``.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the first
      element is considered the output of the mathematical function to be linearized,
      and the second is auxiliary data. Default False.
  Returns:
    If ``has_aux`` is ``False``, returns a pair where the first element is the value of
    ``f(*primals)`` and the second element is a function that evaluates the
    (forward-mode) Jacobian-vector product of ``fun`` evaluated at ``primals`` without
    re-doing the linearization work. If ``has_aux`` is ``True``, returns a
    ``(primals_out, lin_fn, aux)`` tuple where ``aux`` is the auxiliary data returned by
    ``fun``.
  In terms of values computed, :py:func:`linearize` behaves much like a curried
  :py:func:`jvp`, where these two code blocks compute the same values::
    y, out_tangent = jax.jvp(f, (x,), (in_tangent,))
    y, f_jvp = jax.linearize(f, x)
    out_tangent = f_jvp(in_tangent)
  However, the difference is that :py:func:`linearize` uses partial evaluation
  so that the function ``f`` is not re-linearized on calls to ``f_jvp``. In
  general that means the memory usage scales with the size of the computation,
  much like in reverse-mode. (Indeed, :py:func:`linearize` has a similar
  signature to :py:func:`vjp`!)
  This function is mainly useful if you want to apply ``f_jvp`` multiple times,
  i.e. to evaluate a pushforward for many different input tangent vectors at the
  same linearization point. Moreover if all the input tangent vectors are known
  at once, it can be more efficient to vectorize using :py:func:`vmap`, as in::
    pushfwd = partial(jvp, f, (x,))
    y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
  By using :py:func:`vmap` and :py:func:`jvp` together like this we avoid the stored-linearization
  memory cost that scales with the depth of the computation, which is incurred
  by both :py:func:`linearize` and :py:func:`vjp`.
  Here's a more complete example of using :py:func:`linearize`:
  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
  ...
  >>> jax.jvp(f, (2.,), (3.,))
  (Array(3.2681944, dtype=float32, weak_type=True), Array(-5.007528, dtype=float32, weak_type=True))
  >>> y, f_jvp = jax.linearize(f, 2.)
  >>> print(y)
  3.2681944
  >>> print(f_jvp(3.))
  -5.007528
  >>> print(f_jvp(4.))
  -6.676704
  """
  check_callable(fun)
  f = lu.wrap_init(fun, debug_info=debug_info("linearize", fun, primals, {}))
  primals_flat, in_tree = tree_flatten(primals)
  if has_aux:
    jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree)
  else:
    jaxtree_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
  # Partial evaluation splits the computation into known (primal) outputs and
  # a jaxpr computing the tangents; `consts` are its captured residuals.
  out_primals, out_pvals, jaxpr, consts, *maybe_aux = ad.linearize(
      jaxtree_fun, *primals_flat, has_aux=has_aux)
  if has_aux:
    out_tree, aux_tree = out_tree()
  else:
    out_tree, aux_tree = out_tree(), None
  out_primal_py = tree_unflatten(out_tree, out_primals)
  primal_avals = list(map(core.typeof, primals_flat))
  # Ensure that lifted_jvp is a PyTree
  lifted_jvp = Partial(partial(_lift_linearized, jaxpr, primal_avals,
                               (in_tree, out_tree), out_pvals), consts)
  if has_aux:
    [aux] = maybe_aux
    assert aux_tree is not None
    return out_primal_py, lifted_jvp, tree_unflatten(aux_tree, aux)
  else:
    [] = maybe_aux
    return out_primal_py, lifted_jvp
def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args):
  """Apply the linearized jaxpr to tangent args, first validating that the
  tangent avals are consistent with the primal avals recorded at
  linearization time."""
  def fun(*tangents):
    tangent_avals = list(map(core.typeof, tangents))
    for primal_aval, tangent_aval in zip(primal_avals, tangent_avals):
      expected_tangent_aval = primal_aval.to_tangent_aval()
      if not core.typecompat(expected_tangent_aval, tangent_aval):
        extra_msg = ''
        # When only the varying-axes sets (`.mat.varying`) disagree, suggest
        # the pcast applications that would reconcile them.
        if (isinstance(primal_aval, core.ShapedArray) and
            isinstance(tangent_aval, core.ShapedArray) and
            primal_aval.mat != tangent_aval.mat):
          # TODO(yashkatariya): Tweak error.
          pvary_applications = []
          if left := tangent_aval.mat.varying - primal_aval.mat.varying:
            pvary_applications.append(
                f"applying `jax.lax.pcast(..., {tuple(left)}, to='varying')` to"
                " the primal value passed to `jax.linearize`")
          if left := primal_aval.mat.varying - tangent_aval.mat.varying:
            pvary_applications.append(
                f"applying `jax.lax.pcast(..., {tuple(left)}, to='varying')` to"
                " the tangent value passed to the callable `f_jvp` returned by"
                " `jax.linearize`")
          extra_msg = " \nThis might be fixed by:\n" + "\n".join(
              f"  * {d};" for d in pvary_applications)
        raise ValueError(
            "linearized function called on tangent values inconsistent with "
            "the original primal values:\n"
            f"Got tangent aval {tangent_aval} for primal aval {primal_aval} "
            f"but expected {expected_tangent_aval}.{extra_msg}")
    tangents_out = eval_jaxpr(jaxpr, consts, *tangents)
    tangents_out_ = iter(tangents_out)
    # Merge outputs known from partial evaluation with the freshly computed
    # tangents, preserving the original output order.
    full_out = [pval.get_known() if pval.is_known() else next(tangents_out_)
                for pval in out_pvals]
    assert next(tangents_out_, None) is None
    return full_out
  return apply_flat_fun_nokwargs(fun, io_tree, py_args)
# TODO(mattjj): see similar function in custom_derivatives.py
def _temporary_dtype_exception(a, a_) -> bool:
  """Whether a float0 aval `a_` may stand in for `a` when shapes agree."""
  both_shaped = (isinstance(a, core.ShapedArray)
                 and isinstance(a_, core.ShapedArray))
  return both_shaped and a.shape == a_.shape and a_.dtype == float0
# Typing overloads: without aux, vjp returns (primals_out, vjpfun); with
# has_aux=True it returns (primals_out, vjpfun, aux).
@overload
def vjp(fun: Callable[..., T],
        *primals: Any,
        has_aux: Literal[False] = False,
        reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable]:
  ...
@overload
def vjp(fun: Callable[..., tuple[T, U]], *primals: Any,
        has_aux: Literal[True],
        reduce_axes: Sequence[AxisName] = ()) -> tuple[T, Callable, U]:
  ...
@partial(api_boundary, repro_api_name="jax.vjp")
def vjp(
    fun: Callable, *primals, has_aux: bool = False, reduce_axes=()
) -> tuple[Any, Callable] | tuple[Any, Callable, Any]:
  """Compute a (reverse-mode) vector-Jacobian product of ``fun``.

  :py:func:`grad` is implemented as a special case of :py:func:`vjp`.

  Args:
    fun: Function to be differentiated. Its arguments should be arrays, scalars,
      or standard Python containers of arrays or scalars. It should return an
      array, scalar, or standard Python container of arrays or scalars.
    primals: A sequence of primal values at which the Jacobian of ``fun``
      should be evaluated. The number of ``primals`` should be equal to the
      number of positional parameters of ``fun``. Each primal value should be
      an array, a scalar, or a pytree (standard Python containers) thereof.
    has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the
      first element is considered the output of the mathematical function to be
      differentiated and the second element is auxiliary data. Default False.
    reduce_axes: Deprecated; must be empty. Passing a non-empty value raises
      ``NotImplementedError``.

  Returns:
    If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where
    ``primals_out`` is ``fun(*primals)``. If ``has_aux`` is ``True``, returns a
    ``(primals_out, vjpfun, aux)`` tuple where ``aux`` is the auxiliary data
    returned by ``fun``.

    ``vjpfun`` is a function from a cotangent vector with the same shape as
    ``primals_out`` to a tuple of cotangent vectors with the same number and
    shapes as ``primals``, representing the vector-Jacobian product of ``fun``
    evaluated at ``primals``.

  >>> import jax
  >>>
  >>> def f(x, y):
  ...   return jax.numpy.sin(x), jax.numpy.cos(y)
  ...
  >>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
  >>> xbar, ybar = f_vjp((-0.7, 0.3))
  >>> print(xbar)
  -0.61430776
  >>> print(ybar)
  -0.2524413
  """
  # `reduce_axes` is retained in the signature for backward compatibility only.
  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to vjp is deprecated")
  del reduce_axes
  check_callable(fun)
  wrapped_fun = lu.wrap_init(
      fun, debug_info=debug_info("vjp", fun, primals, {}))
  return _vjp(wrapped_fun, *primals, has_aux=has_aux)
def _vjp(fun, *primals, has_aux=False):
  """Implementation of `jax.vjp`: linearize `fun` and package a pytree VJP."""
  # Canonicalize concrete primal values; tracers are passed through untouched.
  canon = lambda x: x if isinstance(x, core.Tracer) else canonicalize_value(x)
  primals = tree_map(canon, primals)
  primals_flat, in_tree = tree_flatten(primals)
  for arg in primals_flat:
    dispatch.check_arg(arg)
  if not has_aux:
    flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
    out_primals_flat, out_pvals, jaxpr, residuals = ad.linearize(
        flat_fun, *primals_flat, is_vjp=True)
    out_tree = out_tree()
    aux = aux_tree = None
  else:
    flat_fun, out_aux_trees = flatten_fun_nokwargs2(fun, in_tree)
    out_primals_flat, out_pvals, jaxpr, residuals, aux = ad.linearize(
        flat_fun, *primals_flat, has_aux=True, is_vjp=True)
    out_tree, aux_tree = out_aux_trees()
    del out_aux_trees
  out_known = [pval.is_known() for pval in out_pvals]
  # Split the residuals: those that are identically (by object id) one of the
  # flat primal inputs are recorded as indices into the primals, while all
  # others are carried opaquely. This avoids storing primal leaves twice.
  id_map = {id(x): i for i, x in enumerate(primals_flat)}
  used, opaque_residuals = set(), []
  spec = [used.add(id(r)) or RSpec(id_map[id(r)], True) if id(r) in id_map else
          RSpec(opaque_residuals.append(r) or (len(opaque_residuals) - 1), False)
          for r in residuals]
  # Keep only the primal leaves actually needed as residuals; replace unused
  # ones with NotNeeded sentinels so they are not kept alive by the VJP.
  args_res = tuptree_map(lambda x: x if id(x) in used else NotNeeded(),
                         in_tree, primals_flat)
  out_primal_avals = [typeof(x) for x in out_primals_flat]
  f_vjp = VJP(partial(_vjp3_callable, spec, out_known, jaxpr, out_primal_avals),
              in_tree, out_tree, list(args_res), opaque_residuals)
  out_primals = tree_unflatten(out_tree, out_primals_flat)
  if not has_aux:
    return out_primals, f_vjp
  else:
    assert aux is not None
    assert aux_tree is not None
    return out_primals, f_vjp, tree_unflatten(aux_tree, aux)
def _vjp3_callable(spec, out_known, jaxpr, out_primal_avals, in_tree, out_tree,
                   args_res, opaque_res, *maybe_ct_refs):
  """Reassemble residuals and per-input accumulators; return the bwd callable.

  `spec` says, for each residual of `jaxpr`, whether to take it from the
  primal-argument leaves (`args_res`) or from `opaque_res`. Each entry of
  `maybe_ct_refs` (when provided) selects how the corresponding input
  cotangent is produced: an existing GradAccum is used as-is, a ref is
  accumulated into in place, a DontWant suppresses it, and anything else
  yields a materialized cotangent value.
  """
  if not maybe_ct_refs:
    # Default: request a materialized cotangent value for every input.
    maybe_ct_refs_flat = [GradValue()] * in_tree.num_leaves
  else:
    maybe_ct_refs_flat, in_tree_ = tree_flatten(maybe_ct_refs)
    if in_tree != in_tree_:
      raise Exception # TODO accept isomorph tuple tree
  args_res_ = tree_leaves(args_res, is_leaf=lambda x: isinstance(x, NotNeeded))
  residuals = [args_res_[i.idx] if i.primal else opaque_res[i.idx] for i in spec]
  maybe_accums = [x if isinstance(x, ad.GradAccum) else
                  ad.RefAccum(v.aval, x) if _is_ref(x) else ad.NullAccum(v.aval)
                  if isinstance(x, DontWant) else ad.ValAccum(v.aval)
                  for v, x in zip(jaxpr.invars, maybe_ct_refs_flat)]
  return Partial(partial(_vjp3_bwd, in_tree, out_tree, out_known, jaxpr,
                         out_primal_avals), residuals, maybe_accums)
def _vjp3_bwd(in_tree, out_tree, out_known, jaxpr, out_primal_avals, residuals,
              maybe_accums, out_ct):
  """Run the backward pass on output cotangents and rebuild the input pytree."""
  cts_flat, out_tree_ = tree_flatten(out_ct)
  if out_tree != out_tree_:
    _vjp_ct_tree_error(jaxpr, out_tree, out_tree_)
  _vjp_check_ct_avals(cts_flat, out_primal_avals)
  # Outputs already known at linearization time receive no cotangent.
  cts_flat = [ct for ct, k in zip(cts_flat, out_known) if not k]
  ad.backward_pass3(jaxpr, True, residuals, maybe_accums, cts_flat)
  # Convert each accumulator into its caller-visible result: a frozen value,
  # a DidntWant sentinel, or a GradRef marker for ref-accumulated cotangents.
  arg_cts = [x.freeze() if isinstance(x, ad.ValAccum) else
             DidntWant() if isinstance(x, ad.NullAccum) else GradRef()
             for x in maybe_accums]
  arg_cts = map(ad.instantiate_zeros, arg_cts)
  return tree_unflatten(in_tree, arg_cts)
@dataclasses.dataclass(frozen=True)
class RSpec:
  """Residual spec: where to find one residual of a linearized jaxpr."""
  idx: int      # index into primal residual leaves (if `primal`) or opaque ones
  primal: bool  # True if the residual aliases a flat primal argument
def tuptree_map(f, treedef, x):
  """Map `f` over leaves `x` of `treedef`, rebuilding interior nodes as tuples."""
  return treedef.walk(lambda xs, _: tuple(xs), f, x)
def _is_ref(x):
from jax._src.state.types import AbstractRef
try:
return isinstance(typeof(x), AbstractRef)
except:
return False
_vjp_too_many_args = """
The function returned by `jax.vjp` applied to {} was called with {} arguments,
but functions returned by `jax.vjp` must be called with a single argument
corresponding to the single value returned by the function being differentiated
(even if that returned value is a tuple or other container).
For example, if we have:
def f(x):
return (x, x)
_, f_vjp = jax.vjp(f, 1.0)
the function `f` returns a single tuple as output, and so we call `f_vjp` with a
single tuple as its argument:
x_bar, = f_vjp((2.0, 2.0))
If we instead call `f_vjp(2.0, 2.0)`, with the values 'splatted out' as
arguments rather than in a tuple, this error can arise.
""".format
def _vjp_ct_tree_error(jaxpr, out_tree, ct_tree):
msg = f"""unexpected tree structure.
The argument to a VJP function returned by `jax.vjp` must match the pytree
structure of the differentiated function {jaxpr.debug_info.func_src_info}.
But the tree structures differ:
"""
msg += '\n'.join(f" * out{keystr(path)} was a {thing1} in the original "
f" output, but a {thing2} here, so {explanation}."
for path, thing1, thing2, explanation
in equality_errors_pytreedef(out_tree, ct_tree))
raise ValueError(msg)
def _vjp_check_ct_avals(cts, primal_avals):
  """Validate each cotangent's aval against the expected cotangent type."""
  # TODO(mattjj): improve this error by flattening with keys in the first place
  for ct, aval in zip(cts, primal_avals):
    ct_aval = typeof(ct)
    ct_aval_expected = aval.to_ct_aval()
    # float0 cotangents of matching shape are temporarily tolerated.
    if (not core.typecompat(ct_aval, ct_aval_expected) and
        not _temporary_dtype_exception(ct_aval, ct_aval_expected)):
      raise ValueError(
          "unexpected JAX type (e.g. shape/dtype) for argument to VJP function: "
          f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} "
          "because the corresponding output of the differentiated function had JAX type "
          f"{aval.str_short()}")
# Sentinel leaf types used by the VJP machinery (see `_vjp`, `_vjp3_callable`,
# and `_vjp3_bwd` for the places each one is produced/consumed).
@register_dataclass
@dataclasses.dataclass(frozen=True)
class NotNeeded:
  """Stands in for a primal argument leaf that is not needed as a residual."""
  pass
@dataclasses.dataclass(frozen=True)
class GradValue:
  """Requests that an input's cotangent be computed and returned as a value."""
  pass
@register_dataclass
@dataclasses.dataclass(frozen=True)
class GradRef:
  """Returned in place of a cotangent that was accumulated into a given ref."""
  pass
@dataclasses.dataclass(frozen=True)
class DontWant:
  """Marks an input whose cotangent should not be computed (NullAccum)."""
  pass
@register_dataclass
@dataclasses.dataclass(frozen=True)
class DidntWant:
  """Returned in place of a cotangent that was suppressed via DontWant."""
  pass
@dataclasses.dataclass
class VJP:
  """The callable returned by `jax.vjp`, registered as a pytree.

  Carries the backward-pass builder `fun` together with its residuals, which
  are split into leaves aliasing primal arguments (`args_res`) and opaque
  values (`opaque_residuals`) so they are visible to pytree transformations.
  """
  fun: Callable
  in_tree: PyTreeDef
  out_tree: PyTreeDef
  args_res: list[Any]
  opaque_residuals: list[Any]
  # `fun` is a partial of `_vjp3_callable`; its third bound argument is the
  # linearized jaxpr.
  jaxpr = property(lambda self: self.fun.args[2])
  def __call__(self, out_ct, *extra_args):
    # A VJP function takes exactly one (possibly pytree) cotangent argument.
    if extra_args:
      name, *_ = self.jaxpr.debug_info.func_src_info.split(' ')
      raise TypeError(_vjp_too_many_args(name, len(extra_args) + 1))
    return self.fun(self.in_tree, self.out_tree, self.args_res,
                    self.opaque_residuals)(out_ct)
  def with_refs(self, *maybe_ct_refs):
    # Variant that lets callers supply per-input accumulators/refs/sentinels.
    return self.fun(self.in_tree, self.out_tree, self.args_res,
                    self.opaque_residuals, *maybe_ct_refs)
  # Only safe to put these in cache keys if residuals aren't mutated. Beware!
  __hash__ = object.__hash__
  __eq__ = object.__eq__
# Flatten a VJP into its residual leaves; the backward function and tree
# structures are static metadata.
register_pytree_node(
    VJP,
    lambda vjp: ((vjp.args_res, vjp.opaque_residuals),
                 (vjp.fun, vjp.in_tree, vjp.out_tree)),
    lambda meta, args_res: VJP(*meta, *args_res))
@partial(api_boundary, repro_api_name="jax.linear_transpose")
def linear_transpose(fun: Callable, *primals, reduce_axes=()) -> Callable:
  """Transpose a function that is promised to be linear.

  For linear functions, this transformation is equivalent to :py:func:`vjp`, but
  avoids the overhead of computing the forward pass.

  The outputs of the transposed function will always have the exact same dtypes
  as ``primals``, even if some values are truncated (e.g., from complex to
  float, or from float64 to float32). To avoid truncation, use dtypes in
  ``primals`` that match the full range of desired outputs from the transposed
  function. Integer dtypes are not supported.

  Args:
    fun: the linear function to be transposed.
    *primals: a positional argument tuple of arrays, scalars, or (nested)
      standard Python containers (tuples, lists, dicts, namedtuples, i.e.,
      pytrees) of those types used for evaluating the shape/dtype of
      ``fun(*primals)``. These arguments may be real scalars/ndarrays, but that
      is not required: only the ``shape`` and ``dtype`` attributes are accessed.
      See below for an example. (Note that the duck-typed objects cannot be
      namedtuples because those are treated as standard Python containers.)
    reduce_axes: Deprecated; must be empty. Passing a non-empty value raises
      ``NotImplementedError``.

  Returns:
    A callable that calculates the transpose of ``fun``. Valid input into this
    function must have the same shape/dtypes/structure as the result of
    ``fun(*primals)``. Output will be a tuple, with the same
    shape/dtypes/structure as ``primals``.

  >>> import jax
  >>>
  >>> f = lambda x, y: 0.5 * x - 0.5 * y
  >>> scalar = jax.ShapeDtypeStruct(shape=(), dtype=np.dtype(np.float32))
  >>> f_transpose = jax.linear_transpose(f, scalar, scalar)
  >>> f_transpose(1.0)
  (Array(0.5, dtype=float32), Array(-0.5, dtype=float32))
  """
  if reduce_axes:
    raise NotImplementedError("reduce_axes argument to transpose is deprecated")
  del reduce_axes
  primals_flat, in_tree = tree_flatten(primals)
  flat_fun, out_tree = flatten_fun_nokwargs(
      lu.wrap_init(fun,
                   debug_info=debug_info("linear_transpose", fun, primals, {})),
      in_tree)
  in_avals = map(shaped_abstractify, primals_flat)
  in_dtypes = map(lambda a: a.dtype, in_avals)
  # Trace with all inputs unknown so the jaxpr captures the linear computation.
  in_pvals = map(pe.PartialVal.unknown, in_avals)
  jaxpr, out_pvals, const = pe.trace_to_jaxpr_nounits(flat_fun, in_pvals,
                                                      instantiate=True)
  jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars), True)
  out_avals, _ = unzip2(out_pvals)
  out_dtypes = map(lambda a: a.dtype, out_avals)
  # Transposition is only defined for all-inexact or all-integer signatures.
  if not (all(dtypes.issubdtype(d, np.inexact) for d in in_dtypes + out_dtypes)
          or all(dtypes.issubdtype(d, np.integer)
                 for d in in_dtypes + out_dtypes)):
    raise TypeError("linear_transpose only supports [float or complex] -> "
                    "[float or complex], and integer -> integer functions, "
                    f"but got {in_dtypes} -> {out_dtypes}.")

  @api_boundary
  def transposed_fun(const, out_cotangent):
    out_cts, out_tree2 = tree_flatten(out_cotangent)
    if out_tree() != out_tree2:
      raise TypeError("cotangent tree does not match function output, "
                      f"expected {out_tree()} but got {out_tree2}")
    if not all(map(core.typecheck, out_avals, out_cts)):
      raise TypeError("cotangent type does not match function output, "
                      f"expected {out_avals} but got {out_cts}")
    # All inputs are undefined primals: the backward pass transposes `jaxpr`.
    dummies = [ad.UndefinedPrimal(a) for a in in_avals]
    in_cts = ad.backward_pass(jaxpr, True, const, dummies, out_cts)
    in_cts = map(ad.instantiate_zeros, in_cts)
    return tree_unflatten(in_tree, in_cts)

  # Ensure that transposed_fun is a PyTree
  return Partial(transposed_fun, const)
# Overloads: with `return_shape=True` the wrapped function also returns the
# output shape/dtype pytree alongside the ClosedJaxpr.
@overload
def make_jaxpr(
    fun: Callable,
    static_argnums: int | Sequence[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: Literal[False] = ...,
) -> Callable[..., core.ClosedJaxpr]:
  ...
@overload
def make_jaxpr(
    fun: Callable,
    static_argnums: int | Sequence[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: Literal[True] = ...,
) -> Callable[..., tuple[core.ClosedJaxpr, Any]]:
  ...
# Fix: repro_api_name previously had a typo ("jax.make_japr").
@partial(api_boundary, repro_api_name="jax.make_jaxpr")
def make_jaxpr(
    fun: Callable,
    static_argnums: int | Sequence[int] = (),
    axis_env: Sequence[tuple[AxisName, int]] | None = None,
    return_shape: bool = False,
) -> Callable[..., core.ClosedJaxpr | tuple[core.ClosedJaxpr, Any]]:
  """Create a function that returns the jaxpr of ``fun`` given example args.

  Args:
    fun: The function whose ``jaxpr`` is to be computed. Its positional
      arguments and return value should be arrays, scalars, or standard Python
      containers (tuple/list/dict) thereof.
    static_argnums: See the :py:func:`jax.jit` docstring.
    axis_env: Optional, a sequence of pairs where the first element is an axis
      name and the second element is a positive integer representing the size of
      the mapped axis with that name. This parameter is useful when lowering
      functions that involve parallel communication collectives, and it
      specifies the axis name/size environment that would be set up by
      applications of :py:func:`jax.pmap`.
    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
      wrapped function returns a pair where the first element is the
      ``ClosedJaxpr`` representation of ``fun`` and the second element is a
      pytree with the same structure as the output of ``fun`` and where the
      leaves are objects with ``shape`` and ``dtype`` attributes representing
      the corresponding types of the output leaves.

  Returns:
    A wrapped version of ``fun`` that when applied to example arguments returns
    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
    argument ``return_shape`` is ``True``, then the returned function instead
    returns a pair where the first element is the ``ClosedJaxpr``
    representation of ``fun`` and the second element is a pytree representing
    the structure, shape, dtypes, and named shapes of the output of ``fun``.

  A ``jaxpr`` is JAX's intermediate representation for program traces. The
  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.

  We do not describe the semantics of the ``jaxpr`` language in detail here, but
  instead give a few examples.

  >>> import jax
  >>>
  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
  >>> print(f(3.0))
  -0.83602
  >>> jax.make_jaxpr(f)(3.0)
  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
  >>> jax.make_jaxpr(jax.grad(f))(3.0)
  { lambda ; a:f32[]. let
      b:f32[] = cos a
      c:f32[] = sin a
      _:f32[] = sin b
      d:f32[] = cos b
      e:f32[] = mul 1.0:f32[] d
      f:f32[] = neg e
      g:f32[] = mul f c
    in (g,) }
  """
  # `jit` needs a hashable, weakref-able callable; wrap in a partial if not.
  try:
    hash(fun)
    weakref.ref(fun)
  except TypeError:
    fun = partial(fun)

  @wraps(fun)
  @api_boundary
  def make_jaxpr_f(*args, **kwargs):
    with core.extend_axis_env_nd(axis_env or []):
      traced = jit(fun, static_argnums=static_argnums).trace(*args, **kwargs)
    # `jit` converts tracers in consts to args but `make_jaxpr` callers expect
    # consts not to be converted.
    num_consts = traced._num_consts
    if num_consts:
      jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr, num_consts)
      jaxpr = core.ClosedJaxpr(jaxpr_, traced._consts)
    else:
      jaxpr = traced.jaxpr
    if return_shape:
      return jaxpr, traced.out_info
    return jaxpr

  make_jaxpr_f.__module__ = "jax"
  if hasattr(fun, "__qualname__"):
    make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
  if hasattr(fun, "__name__"):
    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
  return make_jaxpr_f
def _infer_src_sharding(src, x, x_aval) -> Sharding | None:
  """Infer the source sharding for `device_put` when `src` was not given."""
  if src is not None:
    return src
  # Concrete arrays (or tracers backed by a concrete value) carry a sharding.
  if isinstance(x, array.ArrayImpl):
    return x.sharding
  if isinstance(x, core.Tracer):
    val = x.to_concrete_value()
    if val is not None and isinstance(val, array.ArrayImpl):
      return val.sharding
  # Fall back to the aval's sharding when all its mesh axes are explicit.
  if x_aval is not core.abstract_token and x_aval.sharding.mesh.are_all_axes_explicit:
    return x_aval.sharding.update(
        memory_kind=core.mem_space_to_kind(x_aval.memory_space))
  return None
@util.cache(max_size=2048, trace_context_in_key=False)
def _check_string_compatible_sharding(s):
  """Checks if target devices are compatible with string arrays."""
  # String arrays are only supported on CPU devices.
  if isinstance(s, xc.Device) and s.device_kind == "cpu":
    return
  if (isinstance(s, Sharding)
      and s._internal_device_list[0].device_kind == "cpu"):
    return
  raise TypeError(
      "String arrays can only be sharded to CPU devices. Received"
      f" unsupported device or sharding: {s}")
@util.cache(max_size=2048, trace_context_in_key=False)
def _check_sharding(aval, s):
  """Validate one `device_put` target `s` against the aval being transferred."""
  if (s is not None and
      not isinstance(s, (xc.Device, Sharding, Format, core.MemorySpace))):
    raise ValueError(
        "`jax.device_put` only accepts `None`, `jax.sharding.Sharding`,"
        " `jax.Device`, `Format`, `jax.memory.Space` or a pytree of these"
        f" values. Received invalid value: {s}")
  # String-dtype arrays have additional device restrictions.
  if isinstance(aval, core.ShapedArray) and aval.dtype == dtypes.string_dtype:
    _check_string_compatible_sharding(s)
  if isinstance(s, Sharding):
    if isinstance(aval, core.AbstractToken):
      aval = core.get_token_aval()
    pjit.pjit_check_aval_sharding(
        (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False
    )
    s.shard_shape(aval.shape)  # should raise an Error if incompatible
def pspec_to_sharding(name, val):
  """Turn a PartitionSpec into a NamedSharding under the ambient mesh.

  Values that are not PartitionSpecs pass through unchanged. Raises a
  ValueError when a PartitionSpec is given but no concrete mesh is set.
  """
  if not isinstance(val, P):
    return val
  mesh = get_concrete_mesh()
  if mesh.empty:
    raise ValueError(
        "Please set a mesh via `jax.set_mesh` if a PartitionSpec is"
        f" passed to {name}")
  return NamedSharding(mesh, val)
def device_put(
    x,
    device: None | xc.Device | Sharding | P | Format | Any = None,
    *, src: None | xc.Device | Sharding | P | Format | Any = None,
    donate: bool | Any = False, may_alias: bool | None | Any = None):
  """Transfers ``x`` to ``device``.

  Args:
    x: An array, scalar, or (nested) standard Python container thereof.
    device: The (optional) :py:class:`Device`, :py:class:`Sharding`, or a
      (nested) :py:class:`Sharding` in standard Python container (must be a tree
      prefix of ``x``), representing the device(s) to which ``x`` should be
      transferred. If given, then the result is committed to the device(s).
    src: The (optional) :py:class:`Device`, :py:class:`Sharding`, or a (nested)
      :py:class:`Sharding` in standard Python container (must be a tree prefix
      of ``x``), representing the device(s) on which ``x`` belongs.
    donate: bool or a (nested) bool in standard Python container (must be a tree
      prefix of ``x``). If True, ``x`` can be overwritten and marked deleted in
      the caller. This is best effort. JAX will donate if possible, otherwise it
      won't. The input buffer (in the future) will always be deleted if donated.
    may_alias: bool or None or a (nested) bool in standard Python container
      (must be a tree prefix of ``x``). If False, `x` will be copied. If true,
      `x` may be aliased depending on the runtime's implementation.

  Returns:
    A copy of ``x`` that resides on ``device``.

  If the ``device`` parameter is ``None``, then this operation behaves like the
  identity function if the operand is on any device already, otherwise it
  transfers the data to the default device, uncommitted.

  This function is always asynchronous, i.e. returns immediately without
  blocking the calling Python thread until any transfers are completed.
  """
  with config.explicit_device_put_scope():
    x_flat, treedef = tree_flatten(x)
    x_avals = [shaped_abstractify(x) for x in x_flat]
    # Broadcast scalar `device` / `src` to every leaf; otherwise flatten the
    # given tree prefix against `x`'s structure.
    if (device is None or
        isinstance(device, (xc.Device, Sharding, core.MemorySpace))):
      device_flat = [device] * len(x_flat)
    else:
      device_flat = flatten_axes("device_put device", treedef, device)
    if (src is None or
        isinstance(src, (xc.Device, Sharding, core.MemorySpace))):
      src_flat = list(map(partial(_infer_src_sharding, src), x_flat, x_avals))
    else:
      src_flat = flatten_axes("device_put source", treedef, src)
      src_flat = list(map(_infer_src_sharding, src_flat, x_flat, x_avals))
    device_flat = map(partial(pspec_to_sharding, 'device_put'), device_flat)
    src_flat = map(partial(pspec_to_sharding, 'device_put'), src_flat)
    if isinstance(donate, bool):
      donate_flat = [donate] * len(x_flat)
    else:
      donate_flat = flatten_axes("device_put donate", treedef, donate)
    if isinstance(may_alias, bool):
      may_alias_flat = [may_alias] * len(x_flat)
    else:
      may_alias_flat = flatten_axes("device_put may_alias", treedef, may_alias)
    # Resolve (may_alias, donate) pairs into per-leaf copy semantics.
    copy_semantics = []
    for m, d in zip(may_alias_flat, donate_flat):
      if m and d:
        raise ValueError('may_alias and donate cannot be True at the same time.')
      if m is None:
        m = not d  # default: alias unless donating
      if m and not d:
        copy_semantics.append(dispatch.ArrayCopySemantics.REUSE_INPUT)
      elif not m and d:
        copy_semantics.append(dispatch.ArrayCopySemantics.DONATE_INPUT)
      else:
        assert not m and not d
        copy_semantics.append(dispatch.ArrayCopySemantics.ALWAYS_COPY)
    dst_avals = []
    for x_aval, d in zip(x_avals, device_flat):
      aval = dispatch.update_dp_aval(x_aval, d)
      dst_avals.append(aval)
      _check_sharding(aval, d)
    # Outside of tracing, transfer eagerly; under a trace, bind the primitive.
    if core.trace_state_clean():
      out_flat = dispatch._batched_device_put_impl(
          *x_flat, devices=device_flat, srcs=src_flat,
          copy_semantics=copy_semantics, dst_avals=dst_avals)
    else:
      out_flat = dispatch.device_put_p.bind(
          *x_flat, devices=tuple(device_flat), srcs=tuple(src_flat),
          copy_semantics=tuple(copy_semantics))
    return tree_unflatten(treedef, out_flat)
def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  # noqa: F811
  """Transfer array shards to specified devices and form Array(s).

  Args:
    shards: A sequence of arrays, scalars, or (nested) standard Python
      containers thereof representing the shards to be stacked together to form
      the output. The length of ``shards`` must equal the length of ``devices``.
    devices: A sequence of :py:class:`Device` instances representing the devices
      to which corresponding shards in ``shards`` will be transferred.

  This function is always asynchronous, i.e. returns immediately.

  Returns:
    A Array or (nested) Python container thereof representing the
    elements of ``shards`` stacked together, with each shard backed by physical
    device memory specified by the corresponding entry in ``devices``.

  Examples:
    Passing a list of arrays for ``shards`` results in a sharded array
    containing a stacked version of the inputs:

    >>> import jax
    >>> devices = jax.local_devices()
    >>> x = [jax.numpy.ones(5) for device in devices]
    >>> y = jax.device_put_sharded(x, devices)  # doctest: +SKIP
    >>> np.allclose(y, jax.numpy.stack(x))  # doctest: +SKIP
    True

    Passing a list of nested container objects with arrays at the leaves for
    ``shards`` corresponds to stacking the shards at each leaf. This requires
    all entries in the list to have the same tree structure:

    >>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
    >>> y = jax.device_put_sharded(x, devices)  # doctest: +SKIP
    >>> type(y)  # doctest: +SKIP
    <class 'tuple'>
    >>> y0 = jax.device_put_sharded([a for a, b in x], devices)  # doctest: +SKIP
    >>> y1 = jax.device_put_sharded([b for a, b in x], devices)  # doctest: +SKIP
    >>> np.allclose(y[0], y0)  # doctest: +SKIP
    True
    >>> np.allclose(y[1], y1)  # doctest: +SKIP
    True

  See Also:
    - device_put
    - device_put_replicated
  """
  # TODO(jakevdp): provide a default for devices that considers both local
  # devices and pods
  if not isinstance(shards, Sequence):
    raise TypeError("device_put_sharded `shards` input must be a sequence; "
                    f"got {type(shards)}")
  if len(shards) != len(devices):
    raise ValueError(f"len(shards) = {len(shards)} must equal "
                     f"len(devices) = {len(devices)}.")

  def _device_put_sharded(*xs):
    # All shards at a given leaf must agree on shape and dtype.
    avals = [core.typeof(x) for x in xs]
    if not all(a1 == a2 for a1, a2 in zip(avals[:-1], avals[1:])):
      a1, a2 = next((a1, a2) for a1, a2 in zip(avals[:-1], avals[1:])
                    if a1 != a2)
      raise ValueError("the shards passed to device_put_sharded must have "
                       f"consistent shape and dtype, but got {a1} and {a2}.")
    # Stack along a new leading axis sharded one-shard-per-device.
    stacked_aval = avals[0].update(shape=(len(devices),) + avals[0].shape)
    mesh = Mesh(np.array(devices), ("_device_put_sharded",))
    sharding = NamedSharding(mesh, P("_device_put_sharded"))
    if dtypes.issubdtype(stacked_aval.dtype, dtypes.extended):
      return stacked_aval.dtype._rules.device_put_sharded(xs, stacked_aval, sharding, devices)
    ys = []
    for x in xs:
      if not isinstance(x, (np.ndarray, basearray.Array)):
        x = np.asarray(x)
      ys.append(x[None])  # add the leading stacking axis
    return pxla.batched_device_put(stacked_aval, sharding, ys, list(devices))

  with config.explicit_device_put_scope():
    return tree_map(_device_put_sharded, *shards)
def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
  """Transfer array(s) to each specified device and form Array(s).

  Args:
    x: an array, scalar, or (nested) standard Python container thereof
      representing the array to be replicated to form the output.
    devices: A sequence of :py:class:`Device` instances representing the devices
      to which ``x`` will be transferred.

  This function is always asynchronous, i.e. returns immediately.

  Returns:
    An Array or (nested) Python container thereof representing the
    value of ``x`` broadcasted along a new leading axis of size
    ``len(devices)``, with each slice along that new leading axis backed by
    memory on the device specified by the corresponding entry in ``devices``.

  Examples:
    Passing an array:

    >>> import jax
    >>> devices = jax.local_devices()
    >>> x = jax.numpy.array([1., 2., 3.])
    >>> y = jax.device_put_replicated(x, devices)  # doctest: +SKIP
    >>> np.allclose(y, jax.numpy.stack([x for _ in devices]))  # doctest: +SKIP
    True

  See Also:
    - device_put
    - device_put_sharded
  """
  if not isinstance(devices, Sequence) or not devices:
    raise ValueError("`devices` argument to `device_put_replicated must be "
                     "a non-empty sequence.")

  def _device_put_replicated(x):
    # Aval with a new leading axis of size len(devices).
    aval = core.unmapped_aval(len(devices), 0, core.typeof(x))
    assert isinstance(aval, ShapedArray)
    # Transfer one copy to the first device, then replicate it.
    if isinstance(x, (np.ndarray, basearray.Array)):
      buf = device_put(x[None], devices[0])
    else:
      buf = device_put(x, devices[0])[None]
    mesh = Mesh(np.array(devices), ("_device_put_replicated",))
    sharding = NamedSharding(mesh, P("_device_put_replicated"))
    if dtypes.issubdtype(aval.dtype, dtypes.extended):
      return aval.dtype._rules.device_put_replicated(buf, aval, sharding, devices)
    return pxla.batched_device_put(aval, sharding, [buf] * len(devices), devices)

  with config.explicit_device_put_scope():
    return tree_map(_device_put_replicated, x)
# TODO(mattjj): consider revising
def _device_get(x):
if isinstance(x, core.Tracer):
return x
# Extended dtypes dispatch via their device_get rule.
if isinstance(x, basearray.Array) and dtypes.issubdtype(x.dtype, dtypes.extended):
bufs, tree = tree_util.dispatch_registry.flatten(x)
return tree.unflatten(device_get(bufs))
# Other types dispatch via their __array__ method.
try:
toarray = x.__array__
except AttributeError:
return x
else:
return toarray()
def device_get(x: Any):
  """Transfer ``x`` to host.

  If ``x`` is a pytree, then the individual buffers are copied in parallel.

  Args:
    x: An array, scalar, Array or (nested) standard Python container thereof
      representing the array to be transferred to host.

  Returns:
    An array or (nested) Python container thereof representing the
    value of ``x``. Scalars and other non-array leaves are returned unchanged.

  See Also:
    - device_put
    - device_put_sharded
    - device_put_replicated
  """
  with config.explicit_device_get_scope():
    # Kick off all device-to-host copies first so they can overlap; leaves
    # without a copy_to_host_async method are simply skipped.
    for leaf in tree_leaves(x):
      try:
        leaf.copy_to_host_async()
      except AttributeError:
        pass
    return tree_map(_device_get, x)
@partial(api_boundary, repro_api_name="jax.eval_shape")
def eval_shape(fun: Callable, *args, **kwargs):
  """Compute the shape/dtype of ``fun`` without any FLOPs.

  This utility function is useful for performing shape inference. Its
  input/output behavior is defined by::

    def eval_shape(fun, *args, **kwargs):
      out = fun(*args, **kwargs)
      shape_dtype_struct = lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype)
      return jax.tree_util.tree_map(shape_dtype_struct, out)

  But instead of applying ``fun`` directly, which might be expensive, it uses
  JAX's abstract interpretation machinery to evaluate the shapes without doing
  any FLOPs.

  Using :py:func:`eval_shape` can also catch shape errors, and will raise same
  shape errors as evaluating ``fun(*args, **kwargs)``.

  Args:
    fun: The function whose output shape should be evaluated.
    *args: a positional argument tuple of arrays, scalars, or (nested) standard
      Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of
      those types. Since only the ``shape`` and ``dtype`` attributes are
      accessed, one can use :class:`jax.ShapeDtypeStruct` or another container
      that duck-types as ndarrays (note however that duck-typed objects cannot
      be namedtuples because those are treated as standard Python containers).
    **kwargs: a keyword argument dict of arrays, scalars, or (nested) standard
      Python containers (pytrees) of those types. As in ``args``, array values
      need only be duck-typed to have ``shape`` and ``dtype`` attributes.

  Returns:
    out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves.

  For example:

  >>> import jax
  >>> import jax.numpy as jnp
  >>>
  >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
  >>> A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
  >>> x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
  >>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
  >>> print(out.shape)
  (2000, 1000)
  >>> print(out.dtype)
  float32

  All arguments passed via :func:`eval_shape` will be treated as dynamic;
  static arguments can be included via closure, for example using :func:`functools.partial`:

  >>> import jax
  >>> from jax import lax
  >>> from functools import partial
  >>> import jax.numpy as jnp
  >>>
  >>> x = jax.ShapeDtypeStruct((1, 1, 28, 28), jnp.float32)
  >>> kernel = jax.ShapeDtypeStruct((32, 1, 3, 3), jnp.float32)
  >>>
  >>> conv_same = partial(lax.conv_general_dilated, window_strides=(1, 1), padding="SAME")
  >>> out = jax.eval_shape(conv_same, x, kernel)
  >>> print(out.shape)
  (1, 32, 28, 28)
  >>> print(out.dtype)
  float32
  """
  # Fast path: an already-jitted function can report out_info directly.
  if type(fun) is xc._xla.PjitFunction:
    return fun.trace(*args, **kwargs).out_info  # pyrefly: ignore[missing-attribute]
  # `jit` requires a hashable callable; wrap in a partial if necessary.
  try: hash(fun)
  except TypeError: fun = partial(fun)
  return jit(fun).trace(*args, **kwargs).out_info
@partial(api_boundary, repro_api_name="jax.named_call")
def named_call(
    fun: F,
    *,
    name: str | None = None,
) -> F:
  """Adds a user specified name to a function when staging out JAX computations.

  When staging out computations for just-in-time compilation to XLA (or other
  backends such as TensorFlow) JAX runs your Python program but by default does
  not preserve any of the function names or other metadata associated with it.
  This can make debugging the staged out (and/or compiled) representation of
  your program complicated because there is limited context information for each
  operation being executed.

  `named_call` tells JAX to stage the given function out as a subcomputation
  with a specific name. When the staged out program is compiled with XLA these
  named subcomputations are preserved and show up in debugging utilities like
  the TensorFlow Profiler in TensorBoard. Names are also preserved when staging
  out JAX programs to TensorFlow using :func:`experimental.jax2tf.convert`.

  Args:
    fun: Function to be wrapped. This can be any Callable.
    name: Optional. The prefix to use to name all sub computations created
      within the name scope. Use the fun.__name__ if not specified.

  Returns:
    A version of ``fun`` that is wrapped in a ``named_scope``.
  """
  scope_name = fun.__name__ if name is None else name
  return source_info_util.extend_name_stack(scope_name)(fun)
def named_scope(
    name: str,
) -> source_info_util.ExtendNameStackContextManager:
  """A context manager that adds a user specified name to the JAX name stack.

  By default, JAX does not preserve the names (or other source metadata) of
  the Python functions it encounters while staging out computations for
  just-in-time compilation to XLA (or other backends such as TensorFlow).
  This makes the staged out (and/or compiled) representation of your program
  harder to debug, since each operation carries limited context information.

  ``named_scope`` tells JAX to stage the given function with additional
  annotations on the underlying operations. JAX internally keeps track of
  these annotations in a name stack. When the staged out program is compiled
  with XLA these annotations are preserved and show up in debugging utilities
  like the TensorFlow Profiler in TensorBoard. Names are also preserved when
  staging out JAX programs to TensorFlow using
  :func:`experimental.jax2tf.convert`.

  Args:
    name: The prefix to use to name all operations created within the name
      scope.

  Yields:
    Yields ``None``, but enters a context in which `name` will be appended to
    the active name stack.

  Examples:
    ``named_scope`` can be used as a context manager inside compiled functions:

    >>> import jax
    >>>
    >>> @jax.jit
    ... def layer(w, x):
    ...   with jax.named_scope("dot_product"):
    ...     logits = w.dot(x)
    ...   with jax.named_scope("activation"):
    ...     return jax.nn.relu(logits)

    It can also be used as a decorator:

    >>> @jax.jit
    ... @jax.named_scope("layer")
    ... def layer(w, x):
    ...   logits = w.dot(x)
    ...   return jax.nn.relu(logits)
  """
  # Success path first; anything that is not a str is rejected up front.
  if isinstance(name, str):
    return source_info_util.extend_name_stack(name)
  raise TypeError("named_scope name argument must be a string.")
def effects_barrier():
  """Waits until existing functions have completed any side-effects."""
  # Effectful dispatches register runtime tokens; blocking on them ensures
  # all previously launched side-effects have finished.
  tokens = dispatch.runtime_tokens
  tokens.block_until_ready()
def block_until_ready(x):
  """
  Tries to call a ``block_until_ready`` method on pytree leaves.

  Args:
    x: a pytree, usually with at least some JAX array instances at its leaves.

  Returns:
    A pytree with the same structure and values of the input, where the values
    of all JAX array leaves are ready.
  """
  def _block_one(leaf):
    # Best-effort: leaves without a ``block_until_ready`` method pass through.
    try:
      return leaf.block_until_ready()
    except AttributeError:
      return leaf

  jax_arrays = []
  for leaf in tree_leaves(x):
    if isinstance(leaf, array.ArrayImpl):
      # Collect jax.Array leaves so they can be blocked on together below.
      jax_arrays.append(leaf)
    else:
      _block_one(leaf)

  if len(jax_arrays) == 1:
    # Single array: block on it directly.
    _block_one(jax_arrays[0])
  elif jax_arrays:
    # Several arrays: the batched C++ call is cheaper than one-by-one.
    xc.batched_block_until_ready(jax_arrays)
  # else: no jax.Array leaves at all (empty pytree or only foreign leaves).
  return x
def copy_to_host_async(x):
  """
  Tries to call a ``copy_to_host_async`` method on pytree leaves.

  For each leaf this method will try to call the ``copy_to_host_async`` method
  on the leaf. If the leaf is not a JAX array, or if the leaf does not have a
  ``copy_to_host_async`` method, then this method will do nothing to the leaf.

  Args:
    x: a pytree, usually with at least some JAX array instances at its leaves.

  Returns:
    A pytree with the same structure and values of the input, where the host
    copy of the values of all JAX array leaves are started.
  """
  # Sentinel distinguishes "attribute missing" from any attribute value, so
  # only the attribute lookup itself is treated as best-effort.
  _missing = object()
  for leaf in tree_leaves(x):
    start_copy = getattr(leaf, "copy_to_host_async", _missing)
    if start_copy is not _missing:
      start_copy()
  return x
def clear_backends():
  """
  Clear all backend clients so that new backend clients can be created later.
  """
  xb._clear_backends()
  # The staging and compiled-executable caches hold references into the old
  # backend clients. Clearing them is exactly what clear_caches() does, so
  # delegate to it rather than duplicating the same four calls here (keeps
  # the two code paths from drifting apart).
  clear_caches()
@atexit.register
def clean_up():
  """Interpreter-exit hook: tear down backends, caches, and distributed state."""
  # Only release backend resources if a backend was ever initialized; order
  # matters here (drop backend clients and cached executables first).
  if xb._default_backend is not None:
    clear_backends()
    clear_caches()
  # Shut down distributed system if it exists. Otherwise, this is a no-op.
  distributed.shutdown()
def live_arrays(platform=None):
  """Return all live arrays in the backend for `platform`.

  If platform is None, it is the default backend.
  """
  backend = xb.get_backend(platform)
  return backend.live_arrays()
def clear_caches():
  """Clear all compilation and staging caches.

  This doesn't clear the persistent cache; to disable it (e.g. for benchmarks),
  set the jax_enable_compilation_cache config option to False.
  """
  # Tracing/staging caches: every lu.cache, util.cache and
  # util.weakref_lru_cache instance registered process-wide.
  util.clear_all_caches()
  # Compiled-executable caches held by the C++ pjit fast path.
  for cpp_pjit_cache in (pjit._cpp_pjit_cache_fun_only,
                         pjit._cpp_pjit_cache_explicit_attributes):
    cpp_pjit_cache.clear()
  xc._xla.PjitFunctionCache.clear_all()