# Copyright 2018 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from collections.abc import Callable, Iterable, Sequence import inspect import operator from functools import partial, lru_cache import re from typing import Any, NoReturn from jax._src import core from jax._src import config from jax._src import dtypes from jax._src.state.types import AbstractRef from jax._src.tree_util import ( PyTreeDef, tree_flatten, tree_unflatten, treedef_children, broadcast_prefix, prefix_errors, none_leaf_registry, broadcast_flattened_prefix_with_treedef, treedef_is_leaf, tree_structure, tracing_registry) from jax._src import linear_util as lu from jax._src.util import safe_map, HashableFunction, Unhashable, safe_zip from jax._src import traceback_util traceback_util.register_exclusion(__file__) map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip def _ensure_index(x: Any) -> int | tuple[int, ...]: """Ensure x is either an index or a tuple of indices.""" x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.") try: return operator.index(x) except TypeError: return tuple(map(operator.index, x)) def _ensure_index_tuple(x: Any) -> tuple[int, ...]: """Convert x to a tuple of indices.""" x = core.concrete_or_error(None, x, "expected a static index or sequence of indices.") try: return (operator.index(x),) except TypeError: return tuple(map(operator.index, x)) def _ensure_str(x: str) -> str: if not isinstance(x, str): raise TypeError(f"argument is not a string: {x}") return x def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]: """Convert x to a tuple of strings.""" if isinstance(x, str): return (x,) else: return tuple(map(_ensure_str, x)) @lu.transformation_with_aux2 def flatten_fun(f: Callable, store: lu.Store, in_tree: PyTreeDef, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) ans = f(*py_args, **py_kwargs) ans, out_tree = tree_flatten(ans) store.store(out_tree) return ans def apply_flat_fun(fun, io_tree, *py_args): in_tree_expected, out_tree = io_tree args, in_tree = tree_flatten((py_args, {})) if in_tree != in_tree_expected: raise TypeError(f"Expected {in_tree_expected}, got {in_tree}") ans = fun(*args) return tree_unflatten(out_tree, ans) @lu.transformation_with_aux2 def flatten_fun_nokwargs(f: Callable, store: lu.Store, in_tree: PyTreeDef, *args_flat): py_args = tree_unflatten(in_tree, args_flat) ans = f(*py_args) ans, out_tree = tree_flatten(ans) store.store(out_tree) return ans def apply_flat_fun_nokwargs(fun, io_tree, py_args): in_tree_expected, out_tree = io_tree args, in_tree = tree_flatten(py_args) if in_tree != in_tree_expected: raise TypeError(f"Expected {in_tree_expected}, got {in_tree}") ans = fun(*args) return tree_unflatten(out_tree, ans) @lu.transformation_with_aux2 def flatten_fun_nokwargs2(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) pair = f(*py_args) if not isinstance(pair, (list, tuple)) or len(pair) != 2: raise TypeError("expected function with aux output to return a two-element " f"tuple, but got type {type(pair)} with value {pair!r}") ans, aux = pair ans_flat, ans_tree = tree_flatten(ans) aux_flat, aux_tree = tree_flatten(aux) store.store((ans_tree, aux_tree)) return ans_flat, aux_flat class _HashableWithStrictTypeEquality: """Box object used when comparing static arguments as a jit key. Requires exact type equality using `is` and value equality.""" __slots__ = ["val"] def __init__(self, val): self.val = val def __hash__(self): return hash(self.val) def __eq__(self, other): return type(self.val) is type(other.val) and self.val == other.val _POSITIONAL_ARGUMENTS = ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD ) def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None: """ Validate that the argnums are sensible for a given function. For functions that accept a variable number of positions arguments (`f(..., *args)`) all positive argnums are considered valid. """ n_pos_args = 0 for param in sig.parameters.values(): if param.kind in _POSITIONAL_ARGUMENTS: n_pos_args += 1 elif param.kind is inspect.Parameter.VAR_POSITIONAL: # We can have any number of positional arguments return if argnums and (-min(argnums) > n_pos_args or max(argnums) >= n_pos_args): raise ValueError(f"Jitted function has {argnums_name}={argnums}, " f"but only accepts {n_pos_args} positional arguments.") _INVALID_KEYWORD_ARGUMENTS = ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.VAR_POSITIONAL ) _KEYWORD_ARGUMENTS = ( inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, ) def _validate_argnames( sig: inspect.Signature, argnames: tuple[str, ...], argnames_name: str ) -> None: """ Validate that the argnames are sensible for a given function. For functions that accept a variable keyword arguments (`f(..., **kwargs)`) all argnames are considered valid except those marked as position-only (`f(pos_only, /, ...)`). """ var_kwargs = False valid_kwargs: set[str] = set() invalid_kwargs: set[str] = set() for param_name, param in sig.parameters.items(): if param.kind in _KEYWORD_ARGUMENTS: valid_kwargs.add(param_name) elif param.kind is inspect.Parameter.VAR_KEYWORD: var_kwargs = True elif param.kind in _INVALID_KEYWORD_ARGUMENTS: invalid_kwargs.add(param_name) # Check whether any kwargs are invalid due to position only if invalid_argnames := (invalid_kwargs & set(argnames)): raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} " f"in {argnames_name}. These are positional-only") # Takes any kwargs if var_kwargs: return # Check that all argnames exist on function if invalid_argnames := (set(argnames) - valid_kwargs): raise ValueError(f"Jitted function has invalid argnames {invalid_argnames} " f"in {argnames_name}. Function does not take these args.") def argnums_partial(f: lu.WrappedFun, dyn_argnums: int | Sequence[int], args: Sequence, require_static_args_hashable=True): dyn_argnums = _ensure_index_tuple(dyn_argnums) dyn_argnums = _ensure_inbounds(False, len(args), dyn_argnums) fixed_args: list if require_static_args_hashable: fixed_args = [] for i, arg in enumerate(args): if i in dyn_argnums: continue if not is_hashable(arg): raise ValueError( "Non-hashable static arguments are not supported, as this can lead " f"to unexpected cache-misses. Static argument (index {i}) of type " f"{type(arg)} for function {f.__name__} is non-hashable.") fixed_args.append(_HashableWithStrictTypeEquality(arg)) else: fixed_args = [Unhashable(arg) for i, arg in enumerate(args) if i not in dyn_argnums] dyn_args = tuple(args[i] for i in dyn_argnums) return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args def prepend_static_args(f, static_args): return _prepend_static_args(f, tuple(Unhashable(arg) for arg in static_args)) @lu.transformation2 def _prepend_static_args(f, static_args, *args, **kwargs): static_args = tuple(arg.val for arg in static_args) all_args = static_args + args return f(*all_args, **kwargs) def _ensure_inbounds(allow_invalid: bool, num_args: int, argnums: Sequence[int] ) -> tuple[int, ...]: """Ensure argnum is within bounds. Also resolves negative argnums.""" result = [] for i in argnums: if i >= num_args and allow_invalid: continue if not -num_args <= i < num_args: raise ValueError( "Positional argument indices, e.g. for `static_argnums`, must have " "value greater than or equal to -len(args) and less than len(args), " f"but got value {i} for len(args) == {num_args}.") result.append(i % num_args) # Resolve negative return tuple(result) @lu.transformation2 def _argnums_partial(_fun: Callable, _dyn_argnums: Sequence[int], _fixed_args: Sequence, *dyn_args, **kwargs): sentinel = object() args = [sentinel] * (len(_fixed_args) + len(dyn_args)) for i, arg in zip(_dyn_argnums, dyn_args): args[i] = arg fixed_args_ = iter(_fixed_args) args = [next(fixed_args_).val if x is sentinel else x for x in args] assert next(fixed_args_, sentinel) is sentinel return _fun(*args, **kwargs) @lru_cache(maxsize=4096) def donation_vector(donate_argnums, donate_argnames, in_tree, kws: bool = True) -> tuple[bool, ...]: """Returns a tuple with a boolean value for each leaf in args and kwargs. What if a user specifies donate_argnums but calls the function with kwargs or vice-versa? In that case, in `resolve_argnums` using the signature of the function, the counterpart (donate_argnames or donate_argnums respectively) is calculated so when this function is called both donate_argnums and donate_argnames are available. This allows JAX to donate kwargs when only donate_argnums is specified and vice-versa. When both donate_argnums and donate_argnames are specified, only the args and kwargs specified are donated. """ res: list[bool] = [] if kws: args_tree, kwargs_tree = treedef_children(in_tree) else: args_tree, kwargs_tree = in_tree, None for i, arg in enumerate(args_tree.children()): donate = bool(i in donate_argnums) res.extend((donate,) * arg.num_leaves) if kwargs_tree is not None: for key, val in zip(kwargs_tree.node_data()[1], kwargs_tree.children()): # pyrefly: ignore[unsupported-operation] donate = key in donate_argnames res.extend((donate,) * val.num_leaves) return tuple(res) def rebase_donate_argnums(donate_argnums, static_argnums) -> tuple[int, ...]: """Shifts donate to account for static. >>> rebase_donate_argnums((3, 4), (0, 1)) (1, 2) Args: donate_argnums: An iterable of ints. static_argnums: An iterable of ints. Returns: A tuple of unique, sorted integer values based on donate_argnums with each element offset to account for static_argnums. """ if not (static_argnums or donate_argnums): return tuple(sorted(donate_argnums)) static_argnums = sorted(set(static_argnums)) donate_argnums = sorted(set(donate_argnums)) i = j = o = 0 out = [] while j < len(donate_argnums): if i < len(static_argnums) and static_argnums[i] == donate_argnums[j]: raise ValueError(f"`static_argnums` {static_argnums} and " f"`donate_argnums` {donate_argnums} cannot intersect.") if i < len(static_argnums) and static_argnums[i] < donate_argnums[j]: o += 1 i += 1 else: out.append(donate_argnums[j] - o) j += 1 return tuple(out) def is_hashable(arg): try: hash(arg) return True except TypeError: return False SENTINEL = object() def flatten_axes(name, treedef, axis_tree, *, kws=False, tupled_args=False): # given an axis spec tree axis_tree (a pytree with integers and Nones at the # leaves, i.e. the Nones are to be considered leaves) that is a tree prefix of # the given treedef, build a complete axis spec tree with the same structure # and return the flattened result axis_tree_leaves, axis_treedef = none_leaf_registry.flatten(axis_tree) try: axes = broadcast_flattened_prefix_with_treedef( axis_tree_leaves, axis_treedef, treedef) except ValueError: if kws: # if keyword arguments are included in the tree, we make adapt the error # message only to be about the positional arguments treedef, _ = treedef_children(treedef) axis_tree, _ = axis_tree hint = "" if tupled_args: hint += (f" Note that {name} that are non-trivial pytrees should always be " f"wrapped in a tuple representing the argument list.") if len(treedef.children()) == 1: try: flatten_axes(name, treedef, (axis_tree,)) except ValueError: pass # That's not the issue. else: hint += (f" In particular, you're passing in a single argument which " f"means that {name} might need to be wrapped in " f"a singleton tuple.") dummy_tree = tree_unflatten(treedef, [PytreeLeaf()] * treedef.num_leaves) errors = prefix_errors(axis_tree, dummy_tree) if errors: details = "\n ".join(e(name).args[0] for e in errors) prefix_err_msg = ( f"\n Mismatch details ({len(errors)} found):\n {details}") raise ValueError( f"{name} specification must be a tree prefix of the " f"corresponding value; {hint}{prefix_err_msg}") from None # At this point we've failed to find a tree prefix error. assert False, "unreachable code" assert len(axes) == treedef.num_leaves return axes class PytreeLeaf: def __repr__(self): return "pytree leaf" def flatten_axis_resources(what, tree, shardings, tupled_args): try: return tuple(flatten_axes(what, tree, shardings, tupled_args=tupled_args)) except ValueError: pass # Raise a tree prefix error below # Tree leaves are always valid prefixes, so if there was a prefix error as # assumed here, axis_resources must not be a leaf. assert not treedef_is_leaf(tree_structure(shardings)) # Check the type directly rather than using isinstance because of namedtuples. if tupled_args and (type(shardings) is not tuple or len(shardings) != len(tree.children())): # We know axis_resources is meant to be a tuple corresponding to the args # tuple, but while it is a non-leaf pytree, either it wasn't a tuple or it # wasn't the right length. msg = (f"{what} specification must be a tree prefix of the positional " f"arguments tuple. In particular, {what} must either be a Sharding, " "a PartitionSpec, or a tuple of length equal to the number of " "positional arguments.") # If `tree` represents an args tuple, then `axis_resources` must be a tuple. # TODO(mattjj,apaszke): disable implicit list casts, remove 'or list' below if type(shardings) is not tuple: msg += f" But {what} is not a tuple: got {type(shardings)} instead." elif len(shardings) != len(tree.children()): msg += (f" But {what} is the wrong length: got a tuple or list of length " f"{len(shardings)} for an args tuple of length " f"{len(tree.children())}.") # As an extra hint, let's check if the user just forgot to wrap # shardings in a singleton tuple. if len(tree.children()) == 1: try: flatten_axes(what, tree, (shardings,)) except ValueError: pass # That's not the issue. else: msg += (f" Given the corresponding argument being " f"passed, it looks like {what} might need to be wrapped in " f"a singleton tuple.") raise ValueError(msg) axis_tree = shardings # Because we only have the `tree` treedef and not the full pytree here, # we construct a dummy tree to compare against. Revise this in callers? dummy_tree = tree_unflatten(tree, [PytreeLeaf()] * tree.num_leaves) errors = prefix_errors(axis_tree, dummy_tree) if errors: details = "\n".join(e(what).args[0] for e in errors) raise ValueError( f"Mismatch details ({len(errors)} found):\n{details}" ) # At this point we've failed to find a tree prefix error. assert False, "Please open a bug report!" # This should be unreachable. def flat_out_axes( f: lu.WrappedFun, out_spec: Any ) -> tuple[lu.WrappedFun, Callable]: leaves, treedef = tree_flatten(out_spec) f, out_axes = _flat_out_axes(f, tuple(leaves), treedef) return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef)) @lu.transformation_with_aux2 def _flat_out_axes(_fun, _store, _leaves, _treedef, *args, **kwargs): ans = _fun(*args, **kwargs) spec = tree_unflatten(_treedef, _leaves) try: spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None)) except ValueError: e, *_ = prefix_errors(spec, ans) # TODO(mattjj): currently hardcoded for pmap; generalize to vmap in followup msg, = e('pmap out_axes').args msg += ("\n\nThe full pytree is the output of the pmapped function. Ensure " "that the `out_axes` argument to `pmap` is a pytree prefix of the " "pmapped function's output.") raise ValueError(msg) from None _store.store(spec_flat) return ans def check_callable(fun): # In Python 3.10+, the only thing stopping us from supporting staticmethods # is that we can't take weak references to them, which the C++ JIT requires. if isinstance(fun, staticmethod): raise TypeError(f"staticmethod arguments are not supported, got {fun}") if not callable(fun): raise TypeError(f"Expected a callable value, got {fun}") if inspect.isgeneratorfunction(fun): raise TypeError(f"Expected a function, got a generator function: {fun}") _POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD def infer_argnums_and_argnames( sig: inspect.Signature, argnums: int | Iterable[int] | None, argnames: str | Iterable[str] | None, ) -> tuple[tuple[int, ...], tuple[str, ...]]: """Infer missing argnums and argnames for a function with inspect.""" if argnums is None and argnames is None: return (), () if argnums is not None and argnames is not None: argnums = _ensure_index_tuple(argnums) argnames = _ensure_str_tuple(argnames) return argnums, argnames parameters = sig.parameters if argnums is None: assert argnames is not None argnames = _ensure_str_tuple(argnames) argnums = tuple( i for i, (k, param) in enumerate(parameters.items()) if param.kind == _POSITIONAL_OR_KEYWORD and k in argnames ) else: argnums = _ensure_index_tuple(argnums) argnames = tuple( k for i, (k, param) in enumerate(parameters.items()) if param.kind == _POSITIONAL_OR_KEYWORD and i in argnums ) return argnums, argnames def resolve_argnums( fun: Callable, signature: inspect.Signature | None, donate_argnums: int | Sequence[int] | None, donate_argnames: str | Iterable[str] | None, static_argnums: int | Sequence[int] | None, static_argnames: str | Iterable[str] | None, ) -> tuple[tuple[int, ...], tuple[str, ...], tuple[int, ...], tuple[str, ...]]: """Validates and completes the argnum/argname specification for a jit. * fills in any missing pieces (e.g., names given numbers, or vice versa), * validates the argument names/numbers against the function signature, * validates that donated and static arguments don't intersect. """ if signature is None: # Some built-in functions don't support signature. # See: https://github.com/python/cpython/issues/73485 # In this case no validation is done static_argnums = () if static_argnums is None else _ensure_index_tuple( static_argnums) static_argnames = () if static_argnames is None else _ensure_str_tuple( static_argnames) donate_argnums = () if donate_argnums is None else _ensure_index_tuple( donate_argnums) if donate_argnames is not None: raise ValueError(f"Getting the signature of function {fun} failed. " "Pass donate_argnums instead of donate_argnames.") assert donate_argnames is None donate_argnames = () else: # Infer argnums and argnames according to docstring # If nums is None and names is not None, then nums are inferred from the # names and vice-versa. static_argnums, static_argnames = infer_argnums_and_argnames( signature, static_argnums, static_argnames) donate_argnums, donate_argnames = infer_argnums_and_argnames( signature, donate_argnums, donate_argnames) # Validation _validate_argnums(signature, static_argnums, "static_argnums") _validate_argnames(signature, static_argnames, "static_argnames") _validate_argnums(signature, donate_argnums, "donate_argnums") _validate_argnames(signature, donate_argnames, "donate_argnames") # Compensate for static argnums absorbing args _assert_no_intersection(static_argnames, donate_argnames) return donate_argnums, donate_argnames, static_argnums, static_argnames def _assert_no_intersection(static_argnames, donate_argnames): out = set(static_argnames).intersection(set(donate_argnames)) if out: raise ValueError( "static_argnames and donate_argnames cannot intersect. Argument names " f"{out} appear in both static_argnames and donate_argnames") def resolve_kwargs(fun: Callable, args, kwargs) -> tuple[Any, ...]: """Resolve input arguments to positional following a function's signature. This will raise a TypeError if any keyword-only arguments were passed by the caller. """ if isinstance(fun, partial): # functools.partial should have an opaque signature. fun = lambda *args, **kwargs: None ba = inspect.signature(fun).bind(*args, **kwargs) ba.apply_defaults() if ba.kwargs: passed_kwargs = [k for k in ba.kwargs if k in kwargs] if passed_kwargs: raise TypeError( "The following keyword arguments could not be resolved to positions: " f"{', '.join(passed_kwargs)}" ) return ba.args def _dtype(x): try: return dtypes.result_type(x) except ValueError: return dtypes.result_type(getattr(x, 'dtype')) # This decorator exists to make it easier to monkey-patch APIs in JAX. # By default it does nothing, but it can be monkey-patched to do other things. def api_hook(fun, tag: str): return fun def debug_info( traced_for: str, fun: Callable, args: Sequence[Any], kwargs: dict[str, Any], *, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), result_paths_thunk: Callable[[], tuple[str, ...]] | core.InitialResultPaths = core.initial_result_paths, # TODO(necula): check if we really need this, e.g., to speed up tracing? sourceinfo: str | None = None, signature: inspect.Signature | None = None, ) -> core.DebugInfo: """Construct core.DebugInfo for a function given example args and kwargs. `args` and `kwargs` are example positional and keyword arguments, used with `inspect.Signature` to get the names of arguments. The arguments that are considered static for tracing purposes should be included, and designated using `static_argnums` and `static_argnames`. See docstring for linear_util.DebugInfo. """ res = getattr(fun, "__fun_debug_info__", None) if res is not None: return res if sourceinfo is None: sourceinfo = fun_sourceinfo(fun) if signature is None: signature = fun_signature(fun) arg_names = _non_static_arg_names(signature, args, kwargs, static_argnums, static_argnames) return core.DebugInfo(traced_for, sourceinfo, arg_names, result_paths_thunk) def fun_signature(fun: Callable) -> inspect.Signature | None: try: return inspect.signature(fun) except (ValueError, TypeError): return None def save_wrapped_fun_debug_info(wrapper: Callable, dbg: core.DebugInfo) -> None: setattr(wrapper, "__fun_debug_info__", dbg) _fun_name_re = re.compile(r"(?:)") # TODO(mattjj): make this function internal to this module def fun_sourceinfo(fun: Callable) -> str: # See DebugInfo.fun_src_info while isinstance(fun, partial): fun = fun.func fun = inspect.unwrap(fun) try: filename = fun.__code__.co_filename lineno = fun.__code__.co_firstlineno return f"{fun.__name__} at {filename}:{lineno}" except AttributeError: try: fun_str = str(fun) except: return "" # By contract, the function name has no spaces; also, we want to avoid # fun_sourceinfo of the form "", because it makes # lowering non-deterministic. if m := _fun_name_re.match(fun_str): return m.group(1) return "" def _non_static_arg_names(fn_signature: inspect.Signature | None, args: Sequence[Any], kwargs: dict[str, Any], static_argnums: Sequence[int], static_argnames: Sequence[str], ) -> tuple[str, ...]: """Returns the names of the non-static arguments. If the `fn_signature` is given then we get from it the names of the top-level arguments. In other cases, including when the `args` and `kwargs` do not match the signature, we use names like `args[0]`, `args[1]`, etc. """ # Use the same argument parsing as jit: positional followed by kwargs # sorted by keys. static = object() static_argnums_ = _ensure_inbounds(True, len(args), static_argnums) static_argnames_ = set(static_argnames) args_ = [static if i in static_argnums_ else x for i, x in enumerate(args)] kwargs_ = {k: static if k in static_argnames_ else x for k, x in kwargs.items()} ordered_args: Sequence[tuple[str, Any]] | None = None if fn_signature is not None: try: ba = fn_signature.bind(*args_, **kwargs_) except (ValueError, TypeError): pass else: # Do we have a **kwargs kwargs_name = next((name for name, p in fn_signature.parameters.items() if p.kind == inspect.Parameter.VAR_KEYWORD), None) # Positional argument are those not passed by keyword and not passed # by **kwargs. positional = [(name, x) for name, x in ba.arguments.items() if name not in kwargs and name != kwargs_name] # Keyword arguments are passed sorted by actual kwarg keyword sorted_kwargs = sorted(((name, x) for name, x in kwargs_.items()), key=lambda name_x: name_x[0]) sorted_kwargs = [(name if name in ba.arguments else f"{kwargs_name}['{name}']", x) for name, x in sorted_kwargs] ordered_args = positional + sorted_kwargs if ordered_args is None: positional = [("args", args_)] keyword = sorted([(f"kwargs['{name}']", x) for name, x in kwargs_.items() if x is not static], key=lambda name_x: name_x[0]) ordered_args = positional + keyword return tuple(f'{name}{lu._clean_keystr_arg_names(path)}' for name, x in ordered_args for path, l in tracing_registry.flatten_with_path(x)[0] if l is not static) # TODO(mattjj): make this function faster def check_no_aliased_ref_args(dbg_fn: Callable[[], core.DebugInfo], maybe_avals, args) -> None: assert config.mutable_array_checks.value refs: dict[int, int] = {} for i, (a, x) in enumerate(zip(maybe_avals, args)): if (isinstance(a, AbstractRef) and (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): dbg = dbg_fn() raise ValueError( "only one reference to a mutable array may be passed as an argument " f"to a function, but when tracing {dbg.func_src_info} for {dbg.traced_for} " f"the mutable array reference of type {a.str_short()} appeared at both " f"{dbg.arg_names[dup_idx] if dbg.arg_names is not None else 'unknown'} " f"and {dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'}." if dbg else f"at both flat index {dup_idx} and flat index {i}") from None def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> None: assert config.mutable_array_checks.value refs: set[int] = {id(core.get_referent(c)) for c in consts if isinstance(core.typeof(c), AbstractRef)} for i, x in enumerate(args): if id(core.get_referent(x)) in refs: a = core.shaped_abstractify(x) raise ValueError( f"when tracing {dbg.func_src_info} for {dbg.traced_for}, a mutable " f"array reference of type {a.str_short()} was both closed over and " f"passed as the argument " f"{dbg.safe_arg_names(len(args))[i]}" if dbg else "at flat index {i}") def check_no_transformed_refs_args(dbg_fn: Callable[[], core.DebugInfo], args_flat) -> None: from jax._src.state.types import TransformedRef for i, arg in enumerate(args_flat): if isinstance(arg, TransformedRef): dbg = dbg_fn() raise TypeError( f"When tracing {dbg.func_src_info} for {dbg.traced_for}, " "TransformedRefs are not allowed, but got a TransformedRef with name " f"{dbg.arg_names[i] if dbg.arg_names is not None else 'unknown'}." if dbg else "TransformedRefs are not allowed in this context.") from None class InternalFloatingPointError(Exception): name: str ty: str def __init__(self, name: str, ty: str): self.name = name self.ty = ty def maybe_recursive_nan_check( e: Exception, fun: Callable, args, kwargs ) -> NoReturn: print("Invalid nan value encountered in the output of a jax.jit " "function. Calling the de-optimized version.") try: _ = fun(*args, **kwargs) except (FloatingPointError, ZeroDivisionError) as e2: raise e2 from None else: _raise_no_nan_in_deoptimized(e) def _raise_no_nan_in_deoptimized(e) -> NoReturn: msg = (f"{str(e)}. Because " "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the " "de-optimized function (i.e., the function as if the `jit` " "decorator were removed) was called in an attempt to get a more " "precise error message. However, the de-optimized function did not " "produce invalid values during its execution. This behavior can " "result from `jit` optimizations causing the invalid value to be " "produced. It may also arise from having nan/inf literals as " "inputs or outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. " "\n\n" "It may be possible to avoid the invalid value by removing the " "`jit` decorator, at the cost of losing optimizations. " "\n\n" "If you see this error, consider opening a bug report at " "https://github.com/jax-ml/jax.") raise FloatingPointError(msg) from None