This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,42 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Experimental Key Reuse Checking
-------------------------------
This module contains **experimental** functionality for detecting reuse of random
keys within JAX programs. It is under active development and the APIs here are
likely to change. The usage below requires JAX version 0.4.26 or newer.
Key reuse checking can be enabled using the ``jax_debug_key_reuse`` configuration.
This can be set globally using::
>>> jax.config.update('jax_debug_key_reuse', True) # doctest: +SKIP
Or it can be enabled locally with the :func:`jax.debug_key_reuse` context manager.
When enabled, using the same key twice will result in a :class:`~jax.errors.KeyReuseError`::
>>> import jax
>>> with jax.debug_key_reuse(True):
... key = jax.random.key(0)
... val1 = jax.random.normal(key)
... val2 = jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
The key reuse checker is currently experimental, but in the future we will likely
enable it by default.
"""
@@ -0,0 +1,593 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Iterator
from functools import partial, reduce, total_ordering
from typing import Any, NamedTuple
import jax
from jax import lax
from jax import tree_util
from jax.errors import KeyReuseError
from jax.interpreters import batching, mlir
from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import prng
from jax._src import random
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.ad_checkpoint import remat_p
from jax._src.debugging import debug_callback_p
from jax._src.effects import Effect
from jax._src.hashable_array import HashableArray
from jax._src.interpreters import partial_eval as pe
from jax._src.util import weakref_lru_cache
from jax._src.shard_map import shard_map_p
import numpy as np
traceback_util.register_exclusion(__file__)
_source_context_message = (
'PRNG key first used at the above location was subsequently reused'
' at the following location:')
def key_reuse_error_with_source_traceback(
message: str, traceback: source_info_util.Traceback | None) -> KeyReuseError:
err = KeyReuseError(message)
if traceback is not None:
filtered_tb = traceback_util.filter_traceback(traceback.as_python_traceback())
if filtered_tb:
context_err = KeyReuseError(_source_context_message).with_traceback(filtered_tb)
context_err.__context__ = err.__context__
context_err.__cause__ = err.__cause__
context_err.__suppress_context__ = err.__suppress_context__
err.__context__ = None
err.__cause__ = context_err
return err
# Create Source() and Sink() objects which validate inputs, have
# correct equality semantics, and are hashable & immutable.
@total_ordering
class _SourceSinkBase:
idx: int
mask: bool | np.ndarray
def __init__(self, idx: int, mask: bool | np.bool_ | np.ndarray = True):
assert isinstance(idx, int)
if isinstance(mask, np.ndarray):
assert mask.dtype == np.dtype('bool')
if np.all(mask):
mask = True
elif not np.any(mask):
mask = False
elif mask.flags.writeable:
mask = np.array(mask, copy=True)
mask.flags.writeable = False
elif isinstance(mask, np.bool_):
mask = bool(mask)
else:
assert isinstance(mask, bool)
super().__setattr__("idx", idx)
super().__setattr__("mask", mask)
def __setattr__(self, *args, **kwargs):
raise ValueError(f"{self.__class__.__name__} is immutable")
def __eq__(self, other):
return (self.__class__ == other.__class__
and self.idx == other.idx
and np.shape(self.mask) == np.shape(other.mask)
and np.all(self.mask == other.mask))
def __lt__(self, other):
if isinstance(other, Forward):
return True
elif isinstance(other, _SourceSinkBase):
return ((self.__class__.__name__, self.idx)
< (other.__class__.__name__, other.idx))
else:
return NotImplemented
def __hash__(self):
if isinstance(self.mask, bool):
return hash((self.__class__, self.idx, self.mask))
else:
mask = np.asarray(self.mask)
return hash((self.__class__, self.idx, mask.shape,
tuple(mask.flatten().tolist())))
def __repr__(self):
if self.mask is True:
return f"{self.__class__.__name__}({self.idx})"
return f"{self.__class__.__name__}({self.idx}, {self.mask})"
class Sink(_SourceSinkBase):
pass
class Source(_SourceSinkBase):
pass
class Forward(NamedTuple):
in_idx: int
out_idx: int
def __repr__(self):
return f"Forward({self.in_idx}, {self.out_idx})"
# KeyReuseSignature is essentially a frozen set of Source/Sink/Forward
# objects, with a few convenience methods related to key reuse checking.
class KeyReuseSignature:
_args: frozenset[Source | Sink | Forward]
def __init__(self, *args):
self._args = frozenset(args)
def __repr__(self):
return f"KeyReuseSignature{tuple(sorted(self._args))}"
def __eq__(self, other):
return isinstance(other, KeyReuseSignature) and self._args == other._args
def __hash__(self):
return hash(self._args)
@property
def sinks(self) -> Iterator[Sink]:
yield from (s for s in self._args if isinstance(s, Sink))
@property
def sources(self) -> Iterator[Source]:
yield from (s for s in self._args if isinstance(s, Source))
@property
def forwards(self) -> Iterator[Forward]:
yield from (s for s in self._args if isinstance(s, Forward))
def check_signature(self, *args, funcname="function", context=None):
for sink in self.sinks:
key = args[sink.idx]
if not isinstance(key, prng.PRNGKeyArray):
continue
if np.any(key._consumed & sink.mask):
msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
if context:
msg += " {context}"
raise key_reuse_error_with_source_traceback(
msg, None if key._source_info is None else key._source_info.traceback)
def update_consumption(self, args_in, args_out):
for sink in self.sinks:
arg = args_in[sink.idx]
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = arg._consumed | sink.mask
if np.any(sink.mask):
arg._source_info = source_info_util.current()
for arg in args_out:
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = True
for source in self.sources:
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
args_out[source.idx]._consumed = ~np.asarray(source.mask)
for forward in self.forwards:
arg_in = args_in[forward.in_idx]
arg_out = args_out[forward.out_idx]
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
arg_out._consumed = arg_in._consumed
class DynamicKeyReuseSignature(NamedTuple):
signature: Callable[[core.JaxprEqn], KeyReuseSignature]
def dynamic_key_reuse_signature(f: Callable[[core.JaxprEqn], KeyReuseSignature]) -> DynamicKeyReuseSignature:
return DynamicKeyReuseSignature(f)
def key_reuse_signature_from_eqn(eqn: core.JaxprEqn) -> KeyReuseSignature:
if eqn.primitive in key_reuse_signatures:
sig = key_reuse_signatures[eqn.primitive]
if isinstance(sig, KeyReuseSignature):
return sig
elif isinstance(sig, DynamicKeyReuseSignature):
return sig.signature(eqn)
else:
raise TypeError(
f"Unrecognized key reuse signature of type {type(sig)}: {sig}")
else:
return unknown_signature(eqn)
def key_reuse_signature_from_primitive(prim, *args, **params):
if prim == pjit.jit_p:
return jaxpr_type_signature(params['jaxpr'].jaxpr)
if prim not in key_reuse_signatures:
# TODO(jakevdp) should we generate an unknown signature here?
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
sig = key_reuse_signatures[prim]
if isinstance(sig, KeyReuseSignature):
return sig
elif isinstance(sig, DynamicKeyReuseSignature):
jaxpr = jax.make_jaxpr(partial(prim.bind, **params))(*args).jaxpr
return jaxpr_type_signature(jaxpr)
else:
raise TypeError(
f"Unrecognized key reuse signature of type {type(sig)}: {sig}")
consume_effect = Effect()
consume_p = core.Primitive("consume")
consume_p.def_impl(lambda x: x)
consume_p.def_effectful_abstract_eval(lambda x: (x, {consume_effect}))
batching.defvectorized(consume_p)
mlir.register_lowering(
consume_p,
mlir.lower_fun(lambda x: x, multiple_results=False))
def consume(key):
"""Consume the key and return a consumed copy."""
return consume_p.bind(key)
assert_effect = Effect()
assert_consumed_value_p = core.Primitive("assert_consumed_value")
assert_consumed_value_p.def_impl(lambda x, *, value: x)
assert_consumed_value_p.def_effectful_abstract_eval(lambda x, *, value: (x, {assert_effect}))
batching.defvectorized(assert_consumed_value_p)
mlir.register_lowering(
assert_consumed_value_p,
mlir.lower_fun(lambda x, *, value: x, multiple_results=False))
def assert_unconsumed(key):
"""Assert that a key is unconsumed"""
assert_consumed_value_p.bind(key, value=HashableArray(False))
def assert_consumed(key, value=True):
"""Assert that a key is consumed"""
assert_consumed_value_p.bind(key, value=HashableArray(value))
def _check_consumed_value(eqn, consumed):
"""Extra check for use with assert_consumed_value_p"""
expected = eqn.params['value'].val
if not np.all(consumed == expected):
if np.all(expected):
raise AssertionError(f"Expected key to be consumed in {eqn}")
elif not np.any(expected):
raise AssertionError(f"Expected key to not be consumed in {eqn}")
else:
raise AssertionError(f"Expected {expected}, got {consumed} in {eqn}")
# The behavior of most primitives can be described via simple signatures.
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature | DynamicKeyReuseSignature] = {}
key_reuse_signatures[consume_p] = KeyReuseSignature(Sink(0), Forward(0, 0))
key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[random.random_clone_p] = KeyReuseSignature(Source(0))
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature(Sink(0))
# TODO(jakevdp): should fold_in sink its input key?
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature(Source(0))
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature(Source(0))
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature(Sink(0), Source(0))
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature(Sink(0))
# TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication
key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.copy_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.reshape_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[pjit.layout_constraint_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[pjit.sharding_constraint_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[pjit.reshard_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature(Source(0))
# TODO(jakevdp): should unwrap sink its input key?
key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature()
key_reuse_signatures[debug_callback_p] = KeyReuseSignature()
key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature(Sink(1), Forward(0, 0))
key_reuse_signatures[lax.gather_p] = KeyReuseSignature(Forward(0, 0))
key_reuse_signatures[lax.scatter_p] = KeyReuseSignature(Sink(2), Forward(0, 0))
# Equality checks don't consume
key_reuse_signatures[lax.eq_p] = KeyReuseSignature()
key_reuse_signatures[lax.ne_p] = KeyReuseSignature()
# The default signature will Sink all key inputs, and not Source any.
def unknown_signature(eqn):
def is_key(var: core.Atom):
return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key)
return KeyReuseSignature(
*(Sink(idx) for idx, var in enumerate(eqn.invars) if is_key(var))
)
@weakref_lru_cache
def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
"""Parse the jaxpr to determine key reuse signature"""
consumed: dict[core.Atom, bool | np.ndarray] = {}
forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs.
def resolve_forwards(var: core.Atom) -> core.Atom:
if not forwards:
return var
for _ in range(len(forwards) + 1):
if isinstance(var, core.Literal):
return var
if var in forwards:
var = forwards[var]
else:
return var
raise ValueError("forwarding cycle detected")
def is_key(var: core.Atom):
return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key)
def sink(var: core.Atom, mask=True):
if not is_key(var):
return
var = resolve_forwards(var)
assert not isinstance(var, core.Literal)
if np.any(np.logical_and(consumed.get(var, False), mask)):
return True
consumed[var] = np.logical_or(consumed.get(var, False), mask)
def source(var: core.Atom, mask=False):
if not is_key(var):
return
var = resolve_forwards(var)
assert not isinstance(var, core.Literal)
consumed[var] = mask
def is_consumed(var: core.Atom):
var = resolve_forwards(var)
if isinstance(var, core.Literal):
return False
return consumed.get(var, False)
for eqn in jaxpr.eqns:
traceback = eqn.source_info.traceback
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
with source_info_util.user_context(traceback, name_stack=name_stack):
signature = key_reuse_signature_from_eqn(eqn)
if eqn.primitive == assert_consumed_value_p:
# This is a special case that goes beyond normal key reuse logic.
_check_consumed_value(eqn, is_consumed(eqn.invars[0]))
for in_idx, out_idx in signature.forwards:
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
for snk in signature.sinks:
if not 0 <= snk.idx < len(eqn.invars):
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
if sink(eqn.invars[snk.idx], snk.mask):
raise KeyReuseError(f"In {eqn.primitive}, argument {snk.idx} is already consumed.")
for var in eqn.outvars:
if not isinstance(var, core.Literal) and var not in forwards:
source(var, True) # consumed unless in a Source.
for src in signature.sources:
if not 0 <= src.idx < len(eqn.outvars):
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
source(eqn.outvars[src.idx])
all_inputs: list[core.Atom] = [*jaxpr.invars, *jaxpr.constvars]
return KeyReuseSignature(
*(Sink(i, consumed[v]) for i, v in enumerate(all_inputs)
if is_key(v) and np.any(consumed.get(v, False))),
*(Source(i) for i, v in enumerate(jaxpr.outvars)
if is_key(v) and resolve_forwards(v) not in all_inputs and not consumed.get(v, False)),
*(Forward(all_inputs.index(resolve_forwards(outvar)), idx_out)
for idx_out, outvar in enumerate(jaxpr.outvars)
if is_key(outvar) and resolve_forwards(outvar) in all_inputs)
)
def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature:
args_flat, in_tree = tree_util.tree_flatten(args)
in_avals_flat = [core.typeof(arg) for arg in args_flat]
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
lu.wrap_init(fun,
debug_info=api_util.debug_info("key_reuse", fun, args, {})),
in_tree)
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
return jaxpr_type_signature(jaxpr)
def check_key_reuse_jaxpr(jaxpr: core.Jaxpr) -> None:
"""Check the jaxpr for key reuse."""
jaxpr_type_signature(jaxpr)
def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None:
"""Function to statically check key reuse."""
function_type_signature(fun, *args)
# ----------------------------------------------------------------------------------
# key reuse rules for particular primitives:
@dynamic_key_reuse_signature
def _slice_signature(eqn):
in_aval = eqn.invars[0].aval
assert hasattr(in_aval, "dtype")
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
return KeyReuseSignature(Forward(0, 0))
assert hasattr(in_aval, "shape")
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
return KeyReuseSignature(Forward(0, 0))
start_indices = eqn.params['start_indices']
limit_indices = eqn.params['limit_indices']
strides = eqn.params['strides'] or (1,) * len(start_indices)
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
sink = np.zeros(in_aval.shape, dtype=bool)
sink[idx] = True
return KeyReuseSignature(Sink(0, sink), Source(0))
key_reuse_signatures[lax.slice_p] = _slice_signature
@dynamic_key_reuse_signature
def _concatenate_signature(eqn):
num_vals = len(eqn.invars)
# TODO(jakevdp): should this signature be more granular?
if num_vals == 1:
return KeyReuseSignature(Forward(0, 0))
else:
return KeyReuseSignature(*(Sink(i) for i in range(num_vals)), Source(0))
key_reuse_signatures[lax.concatenate_p] = _concatenate_signature
@dynamic_key_reuse_signature
def _pjit_key_type_signature(eqn):
return jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)
key_reuse_signatures[pjit.jit_p] = _pjit_key_type_signature
@dynamic_key_reuse_signature
def _shard_map_type_signature(eqn):
return jaxpr_type_signature(eqn.params['jaxpr'])
key_reuse_signatures[shard_map_p] = _shard_map_type_signature
@dynamic_key_reuse_signature
def _cond_key_type_signature(eqn):
signatures = [jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']]
sinks = defaultdict(list)
sources = defaultdict(list)
for sig in signatures:
for sink in sig.sinks:
sinks[sink.idx].append(sink.mask)
for source in sig.sources:
sources[source.idx].append(source.mask)
combined_sinks = [Sink(i + 1, reduce(np.logical_or, m)) for i, m in sinks.items()]
combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()]
combined_forwards = [Forward(f.in_idx + 1, f.out_idx) for f in
set.intersection(*(set(sig.forwards) for sig in signatures))]
return KeyReuseSignature(*combined_sinks, *combined_sources, *combined_forwards)
key_reuse_signatures[lax.cond_p] = _cond_key_type_signature
@dynamic_key_reuse_signature
def _scan_key_type_signature(eqn):
jaxpr = eqn.params['jaxpr'].jaxpr
num_consts = eqn.params['num_consts']
num_carry = eqn.params['num_carry']
signature = jaxpr_type_signature(jaxpr)
# scan body should not consume key in constants
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
"because key constants are repeatedly consumed:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")
# scan carry should only consume keys that are sourced on output.
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks
if 0 <= s.idx - num_consts < num_carry and np.any(s.mask)}
carry_sources = {s.idx: s.mask for s in signature.sources
if 0 <= s.idx < num_carry and np.any(s.mask)}
if not set(carry_sinks).issubset(set(carry_sources)): # TODO(jakevdp): check that masks match
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
"because consumed inputs don't match sourced outputs:\n"
f" {signature=}\n"
f" {eqn=}\n"
f" {jaxpr=}")
return signature
key_reuse_signatures[jax.lax.scan_p] = _scan_key_type_signature
@dynamic_key_reuse_signature
def _while_key_type_signature(eqn):
cond_jaxpr = eqn.params['cond_jaxpr'].jaxpr
cond_nconsts = eqn.params['cond_nconsts']
body_jaxpr = eqn.params['body_jaxpr'].jaxpr
body_nconsts = eqn.params['body_nconsts']
cond_signature = jaxpr_type_signature(cond_jaxpr)
body_signature = jaxpr_type_signature(body_jaxpr)
# Error if there are sinks among consts.
if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts):
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
f" {cond_signature=}\n"
f" {eqn=}")
if any(np.any(s.mask) for s in body_signature.sinks if s.idx < body_nconsts):
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
f" {body_signature=}\n"
f" {eqn=}")
# carry should only consume keys that are sourced on output.
body_carry_sinks = {s.idx - body_nconsts: s.mask for s in body_signature.sinks if s.idx >= body_nconsts}
cond_carry_sinks = {s.idx - cond_nconsts: s.mask for s in cond_signature.sinks if s.idx >= cond_nconsts}
carry_sources = {s.idx: s.mask for s in body_signature.sources}
# TODO(jakevdp): check masks at each index?
if not (cond_carry_sinks.keys() <= carry_sources.keys()):
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
f" {cond_signature=}\n"
f" {eqn=}")
if not (body_carry_sinks.keys() <= carry_sources.keys()):
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
f" {body_signature=}\n"
f" {eqn=}")
if body_carry_sinks.keys() & cond_carry_sinks.keys():
raise KeyReuseError("while_loop cond and body functions both use the same key: "
f" {cond_signature=}\n"
f" {body_signature=}\n"
f" {eqn=}")
return body_signature
key_reuse_signatures[jax.lax.while_p] = _while_key_type_signature
@dynamic_key_reuse_signature
def _remat_key_type_signature(eqn):
# The assumption here is that the non-differentiated pass contains all relevant
# key usage, and the differentiated pass
# 1) will only consume keys that are already consumed in the non-differentiated pass
# 2) will never create keys
# Therefore, the differentiated pass is a no-op.
if eqn.params['differentiated']:
return KeyReuseSignature()
return jaxpr_type_signature(eqn.params['jaxpr'])
key_reuse_signatures[remat_p] = _remat_key_type_signature
@dynamic_key_reuse_signature
def _device_put_signature(eqn):
num_vals = len(eqn.invars)
return KeyReuseSignature(*(Forward(i, i) for i in range(num_vals)))
key_reuse_signatures[lax.device_put_p] = _device_put_signature
def call_impl_with_key_reuse_checks(prim: core.Primitive, raw_impl: Callable[..., Any], *args, **kwargs) -> Any:
if prim not in key_reuse_signatures:
# TODO(jakevdp): should we use an unknown signature here?
return raw_impl(*args, **kwargs)
signature = key_reuse_signature_from_primitive(prim, *args, **kwargs)
funcname = "jit-compiled function" if prim == pjit.jit_p else str(prim)
consts = kwargs['jaxpr'].consts if prim == pjit.jit_p else []
signature.check_signature(*args, *consts, funcname=funcname)
result = raw_impl(*args, **kwargs)
signature.update_consumption([*args, *consts], result if prim.multiple_results else [result])
return result