hand
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Experimental Key Reuse Checking
|
||||
-------------------------------
|
||||
|
||||
This module contains **experimental** functionality for detecting reuse of random
|
||||
keys within JAX programs. It is under active development and the APIs here are
|
||||
likely to change. The usage below requires JAX version 0.4.26 or newer.
|
||||
|
||||
Key reuse checking can be enabled using the ``jax_debug_key_reuse`` configuration.
|
||||
This can be set globally using::
|
||||
|
||||
>>> jax.config.update('jax_debug_key_reuse', True) # doctest: +SKIP
|
||||
|
||||
Or it can be enabled locally with the :func:`jax.debug_key_reuse` context manager.
|
||||
When enabled, using the same key twice will result in a :class:`~jax.errors.KeyReuseError`::
|
||||
|
||||
>>> import jax
|
||||
>>> with jax.debug_key_reuse(True):
|
||||
... key = jax.random.key(0)
|
||||
... val1 = jax.random.normal(key)
|
||||
... val2 = jax.random.normal(key) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
|
||||
|
||||
The key reuse checker is currently experimental, but in the future we will likely
|
||||
enable it by default.
|
||||
"""
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,593 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterator
|
||||
from functools import partial, reduce, total_ordering
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax.errors import KeyReuseError
|
||||
from jax.interpreters import batching, mlir
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.ad_checkpoint import remat_p
|
||||
from jax._src.debugging import debug_callback_p
|
||||
from jax._src.effects import Effect
|
||||
from jax._src.hashable_array import HashableArray
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.util import weakref_lru_cache
|
||||
|
||||
from jax._src.shard_map import shard_map_p
|
||||
import numpy as np
|
||||
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
_source_context_message = (
|
||||
'PRNG key first used at the above location was subsequently reused'
|
||||
' at the following location:')
|
||||
|
||||
def key_reuse_error_with_source_traceback(
|
||||
message: str, traceback: source_info_util.Traceback | None) -> KeyReuseError:
|
||||
err = KeyReuseError(message)
|
||||
if traceback is not None:
|
||||
filtered_tb = traceback_util.filter_traceback(traceback.as_python_traceback())
|
||||
if filtered_tb:
|
||||
context_err = KeyReuseError(_source_context_message).with_traceback(filtered_tb)
|
||||
context_err.__context__ = err.__context__
|
||||
context_err.__cause__ = err.__cause__
|
||||
context_err.__suppress_context__ = err.__suppress_context__
|
||||
err.__context__ = None
|
||||
err.__cause__ = context_err
|
||||
return err
|
||||
|
||||
|
||||
# Create Source() and Sink() objects which validate inputs, have
|
||||
# correct equality semantics, and are hashable & immutable.
|
||||
@total_ordering
|
||||
class _SourceSinkBase:
|
||||
idx: int
|
||||
mask: bool | np.ndarray
|
||||
|
||||
def __init__(self, idx: int, mask: bool | np.bool_ | np.ndarray = True):
|
||||
assert isinstance(idx, int)
|
||||
if isinstance(mask, np.ndarray):
|
||||
assert mask.dtype == np.dtype('bool')
|
||||
if np.all(mask):
|
||||
mask = True
|
||||
elif not np.any(mask):
|
||||
mask = False
|
||||
elif mask.flags.writeable:
|
||||
mask = np.array(mask, copy=True)
|
||||
mask.flags.writeable = False
|
||||
elif isinstance(mask, np.bool_):
|
||||
mask = bool(mask)
|
||||
else:
|
||||
assert isinstance(mask, bool)
|
||||
super().__setattr__("idx", idx)
|
||||
super().__setattr__("mask", mask)
|
||||
|
||||
def __setattr__(self, *args, **kwargs):
|
||||
raise ValueError(f"{self.__class__.__name__} is immutable")
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.__class__ == other.__class__
|
||||
and self.idx == other.idx
|
||||
and np.shape(self.mask) == np.shape(other.mask)
|
||||
and np.all(self.mask == other.mask))
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, Forward):
|
||||
return True
|
||||
elif isinstance(other, _SourceSinkBase):
|
||||
return ((self.__class__.__name__, self.idx)
|
||||
< (other.__class__.__name__, other.idx))
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __hash__(self):
|
||||
if isinstance(self.mask, bool):
|
||||
return hash((self.__class__, self.idx, self.mask))
|
||||
else:
|
||||
mask = np.asarray(self.mask)
|
||||
return hash((self.__class__, self.idx, mask.shape,
|
||||
tuple(mask.flatten().tolist())))
|
||||
|
||||
def __repr__(self):
|
||||
if self.mask is True:
|
||||
return f"{self.__class__.__name__}({self.idx})"
|
||||
return f"{self.__class__.__name__}({self.idx}, {self.mask})"
|
||||
|
||||
|
||||
class Sink(_SourceSinkBase):
|
||||
pass
|
||||
|
||||
|
||||
class Source(_SourceSinkBase):
|
||||
pass
|
||||
|
||||
|
||||
class Forward(NamedTuple):
|
||||
in_idx: int
|
||||
out_idx: int
|
||||
|
||||
def __repr__(self):
|
||||
return f"Forward({self.in_idx}, {self.out_idx})"
|
||||
|
||||
|
||||
# KeyReuseSignature is essentially a frozen set of Source/Sink/Forward
|
||||
# objects, with a few convenience methods related to key reuse checking.
|
||||
class KeyReuseSignature:
|
||||
_args: frozenset[Source | Sink | Forward]
|
||||
|
||||
def __init__(self, *args):
|
||||
self._args = frozenset(args)
|
||||
|
||||
def __repr__(self):
|
||||
return f"KeyReuseSignature{tuple(sorted(self._args))}"
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, KeyReuseSignature) and self._args == other._args
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._args)
|
||||
|
||||
@property
|
||||
def sinks(self) -> Iterator[Sink]:
|
||||
yield from (s for s in self._args if isinstance(s, Sink))
|
||||
|
||||
@property
|
||||
def sources(self) -> Iterator[Source]:
|
||||
yield from (s for s in self._args if isinstance(s, Source))
|
||||
|
||||
@property
|
||||
def forwards(self) -> Iterator[Forward]:
|
||||
yield from (s for s in self._args if isinstance(s, Forward))
|
||||
|
||||
def check_signature(self, *args, funcname="function", context=None):
|
||||
for sink in self.sinks:
|
||||
key = args[sink.idx]
|
||||
if not isinstance(key, prng.PRNGKeyArray):
|
||||
continue
|
||||
if np.any(key._consumed & sink.mask):
|
||||
msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
|
||||
if context:
|
||||
msg += " {context}"
|
||||
raise key_reuse_error_with_source_traceback(
|
||||
msg, None if key._source_info is None else key._source_info.traceback)
|
||||
|
||||
def update_consumption(self, args_in, args_out):
|
||||
for sink in self.sinks:
|
||||
arg = args_in[sink.idx]
|
||||
if isinstance(arg, prng.PRNGKeyArray):
|
||||
arg._consumed = arg._consumed | sink.mask
|
||||
if np.any(sink.mask):
|
||||
arg._source_info = source_info_util.current()
|
||||
for arg in args_out:
|
||||
if isinstance(arg, prng.PRNGKeyArray):
|
||||
arg._consumed = True
|
||||
for source in self.sources:
|
||||
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
|
||||
args_out[source.idx]._consumed = ~np.asarray(source.mask)
|
||||
for forward in self.forwards:
|
||||
arg_in = args_in[forward.in_idx]
|
||||
arg_out = args_out[forward.out_idx]
|
||||
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
|
||||
arg_out._consumed = arg_in._consumed
|
||||
|
||||
|
||||
class DynamicKeyReuseSignature(NamedTuple):
|
||||
signature: Callable[[core.JaxprEqn], KeyReuseSignature]
|
||||
|
||||
def dynamic_key_reuse_signature(f: Callable[[core.JaxprEqn], KeyReuseSignature]) -> DynamicKeyReuseSignature:
|
||||
return DynamicKeyReuseSignature(f)
|
||||
|
||||
def key_reuse_signature_from_eqn(eqn: core.JaxprEqn) -> KeyReuseSignature:
|
||||
if eqn.primitive in key_reuse_signatures:
|
||||
sig = key_reuse_signatures[eqn.primitive]
|
||||
if isinstance(sig, KeyReuseSignature):
|
||||
return sig
|
||||
elif isinstance(sig, DynamicKeyReuseSignature):
|
||||
return sig.signature(eqn)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unrecognized key reuse signature of type {type(sig)}: {sig}")
|
||||
else:
|
||||
return unknown_signature(eqn)
|
||||
|
||||
|
||||
def key_reuse_signature_from_primitive(prim, *args, **params):
|
||||
if prim == pjit.jit_p:
|
||||
return jaxpr_type_signature(params['jaxpr'].jaxpr)
|
||||
if prim not in key_reuse_signatures:
|
||||
# TODO(jakevdp) should we generate an unknown signature here?
|
||||
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
|
||||
sig = key_reuse_signatures[prim]
|
||||
if isinstance(sig, KeyReuseSignature):
|
||||
return sig
|
||||
elif isinstance(sig, DynamicKeyReuseSignature):
|
||||
jaxpr = jax.make_jaxpr(partial(prim.bind, **params))(*args).jaxpr
|
||||
return jaxpr_type_signature(jaxpr)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unrecognized key reuse signature of type {type(sig)}: {sig}")
|
||||
|
||||
|
||||
consume_effect = Effect()
|
||||
consume_p = core.Primitive("consume")
|
||||
consume_p.def_impl(lambda x: x)
|
||||
consume_p.def_effectful_abstract_eval(lambda x: (x, {consume_effect}))
|
||||
batching.defvectorized(consume_p)
|
||||
mlir.register_lowering(
|
||||
consume_p,
|
||||
mlir.lower_fun(lambda x: x, multiple_results=False))
|
||||
|
||||
def consume(key):
|
||||
"""Consume the key and return a consumed copy."""
|
||||
return consume_p.bind(key)
|
||||
|
||||
assert_effect = Effect()
|
||||
|
||||
assert_consumed_value_p = core.Primitive("assert_consumed_value")
|
||||
assert_consumed_value_p.def_impl(lambda x, *, value: x)
|
||||
assert_consumed_value_p.def_effectful_abstract_eval(lambda x, *, value: (x, {assert_effect}))
|
||||
batching.defvectorized(assert_consumed_value_p)
|
||||
mlir.register_lowering(
|
||||
assert_consumed_value_p,
|
||||
mlir.lower_fun(lambda x, *, value: x, multiple_results=False))
|
||||
|
||||
def assert_unconsumed(key):
|
||||
"""Assert that a key is unconsumed"""
|
||||
assert_consumed_value_p.bind(key, value=HashableArray(False))
|
||||
|
||||
def assert_consumed(key, value=True):
|
||||
"""Assert that a key is consumed"""
|
||||
assert_consumed_value_p.bind(key, value=HashableArray(value))
|
||||
|
||||
|
||||
def _check_consumed_value(eqn, consumed):
|
||||
"""Extra check for use with assert_consumed_value_p"""
|
||||
expected = eqn.params['value'].val
|
||||
if not np.all(consumed == expected):
|
||||
if np.all(expected):
|
||||
raise AssertionError(f"Expected key to be consumed in {eqn}")
|
||||
elif not np.any(expected):
|
||||
raise AssertionError(f"Expected key to not be consumed in {eqn}")
|
||||
else:
|
||||
raise AssertionError(f"Expected {expected}, got {consumed} in {eqn}")
|
||||
|
||||
|
||||
# The behavior of most primitives can be described via simple signatures.
|
||||
key_reuse_signatures: dict[core.Primitive, KeyReuseSignature | DynamicKeyReuseSignature] = {}
|
||||
|
||||
key_reuse_signatures[consume_p] = KeyReuseSignature(Sink(0), Forward(0, 0))
|
||||
key_reuse_signatures[assert_consumed_value_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[random.random_clone_p] = KeyReuseSignature(Source(0))
|
||||
key_reuse_signatures[prng.random_bits_p] = KeyReuseSignature(Sink(0))
|
||||
# TODO(jakevdp): should fold_in sink its input key?
|
||||
key_reuse_signatures[prng.random_fold_in_p] = KeyReuseSignature(Source(0))
|
||||
key_reuse_signatures[prng.random_seed_p] = KeyReuseSignature(Source(0))
|
||||
key_reuse_signatures[prng.random_split_p] = KeyReuseSignature(Sink(0), Source(0))
|
||||
key_reuse_signatures[random.random_gamma_p] = KeyReuseSignature(Sink(0))
|
||||
# TODO(jakevdp): broadcast should probably consume the input to avoid implicit duplication
|
||||
key_reuse_signatures[lax.broadcast_in_dim_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.copy_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.convert_element_type_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.reshape_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.squeeze_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[pjit.layout_constraint_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[pjit.sharding_constraint_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[pjit.reshard_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[prng.random_wrap_p] = KeyReuseSignature(Source(0))
|
||||
# TODO(jakevdp): should unwrap sink its input key?
|
||||
key_reuse_signatures[prng.random_unwrap_p] = KeyReuseSignature()
|
||||
key_reuse_signatures[debug_callback_p] = KeyReuseSignature()
|
||||
key_reuse_signatures[lax.dynamic_slice_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.dynamic_update_slice_p] = KeyReuseSignature(Sink(1), Forward(0, 0))
|
||||
key_reuse_signatures[lax.gather_p] = KeyReuseSignature(Forward(0, 0))
|
||||
key_reuse_signatures[lax.scatter_p] = KeyReuseSignature(Sink(2), Forward(0, 0))
|
||||
# Equality checks don't consume
|
||||
key_reuse_signatures[lax.eq_p] = KeyReuseSignature()
|
||||
key_reuse_signatures[lax.ne_p] = KeyReuseSignature()
|
||||
|
||||
# The default signature will Sink all key inputs, and not Source any.
|
||||
def unknown_signature(eqn):
|
||||
def is_key(var: core.Atom):
|
||||
return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key)
|
||||
return KeyReuseSignature(
|
||||
*(Sink(idx) for idx, var in enumerate(eqn.invars) if is_key(var))
|
||||
)
|
||||
|
||||
@weakref_lru_cache
|
||||
def jaxpr_type_signature(jaxpr: core.Jaxpr) -> KeyReuseSignature:
|
||||
"""Parse the jaxpr to determine key reuse signature"""
|
||||
consumed: dict[core.Atom, bool | np.ndarray] = {}
|
||||
forwards: dict[core.Atom, core.Atom] = {} # map forwarded outputs to inputs.
|
||||
|
||||
def resolve_forwards(var: core.Atom) -> core.Atom:
|
||||
if not forwards:
|
||||
return var
|
||||
for _ in range(len(forwards) + 1):
|
||||
if isinstance(var, core.Literal):
|
||||
return var
|
||||
if var in forwards:
|
||||
var = forwards[var]
|
||||
else:
|
||||
return var
|
||||
raise ValueError("forwarding cycle detected")
|
||||
|
||||
def is_key(var: core.Atom):
|
||||
return hasattr(var.aval, "dtype") and jax.dtypes.issubdtype(var.aval.dtype, jax.dtypes.prng_key)
|
||||
|
||||
def sink(var: core.Atom, mask=True):
|
||||
if not is_key(var):
|
||||
return
|
||||
var = resolve_forwards(var)
|
||||
assert not isinstance(var, core.Literal)
|
||||
if np.any(np.logical_and(consumed.get(var, False), mask)):
|
||||
return True
|
||||
consumed[var] = np.logical_or(consumed.get(var, False), mask)
|
||||
|
||||
def source(var: core.Atom, mask=False):
|
||||
if not is_key(var):
|
||||
return
|
||||
var = resolve_forwards(var)
|
||||
assert not isinstance(var, core.Literal)
|
||||
consumed[var] = mask
|
||||
|
||||
def is_consumed(var: core.Atom):
|
||||
var = resolve_forwards(var)
|
||||
if isinstance(var, core.Literal):
|
||||
return False
|
||||
return consumed.get(var, False)
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
traceback = eqn.source_info.traceback
|
||||
name_stack = source_info_util.current_name_stack() + eqn.source_info.name_stack
|
||||
with source_info_util.user_context(traceback, name_stack=name_stack):
|
||||
signature = key_reuse_signature_from_eqn(eqn)
|
||||
if eqn.primitive == assert_consumed_value_p:
|
||||
# This is a special case that goes beyond normal key reuse logic.
|
||||
_check_consumed_value(eqn, is_consumed(eqn.invars[0]))
|
||||
|
||||
for in_idx, out_idx in signature.forwards:
|
||||
forwards[eqn.outvars[out_idx]] = eqn.invars[in_idx]
|
||||
|
||||
for snk in signature.sinks:
|
||||
if not 0 <= snk.idx < len(eqn.invars):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, sink {snk.idx} out of range [0, {len(eqn.invars)}]")
|
||||
if sink(eqn.invars[snk.idx], snk.mask):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, argument {snk.idx} is already consumed.")
|
||||
for var in eqn.outvars:
|
||||
if not isinstance(var, core.Literal) and var not in forwards:
|
||||
source(var, True) # consumed unless in a Source.
|
||||
for src in signature.sources:
|
||||
if not 0 <= src.idx < len(eqn.outvars):
|
||||
raise KeyReuseError(f"In {eqn.primitive}, source {src.idx} out of range [0, {len(eqn.outvars)}]")
|
||||
source(eqn.outvars[src.idx])
|
||||
|
||||
all_inputs: list[core.Atom] = [*jaxpr.invars, *jaxpr.constvars]
|
||||
return KeyReuseSignature(
|
||||
*(Sink(i, consumed[v]) for i, v in enumerate(all_inputs)
|
||||
if is_key(v) and np.any(consumed.get(v, False))),
|
||||
*(Source(i) for i, v in enumerate(jaxpr.outvars)
|
||||
if is_key(v) and resolve_forwards(v) not in all_inputs and not consumed.get(v, False)),
|
||||
*(Forward(all_inputs.index(resolve_forwards(outvar)), idx_out)
|
||||
for idx_out, outvar in enumerate(jaxpr.outvars)
|
||||
if is_key(outvar) and resolve_forwards(outvar) in all_inputs)
|
||||
)
|
||||
|
||||
|
||||
def function_type_signature(fun: Callable[..., Any], *args: Any) -> KeyReuseSignature:
|
||||
args_flat, in_tree = tree_util.tree_flatten(args)
|
||||
in_avals_flat = [core.typeof(arg) for arg in args_flat]
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(fun,
|
||||
debug_info=api_util.debug_info("key_reuse", fun, args, {})),
|
||||
in_tree)
|
||||
jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
|
||||
return jaxpr_type_signature(jaxpr)
|
||||
|
||||
|
||||
def check_key_reuse_jaxpr(jaxpr: core.Jaxpr) -> None:
|
||||
"""Check the jaxpr for key reuse."""
|
||||
jaxpr_type_signature(jaxpr)
|
||||
|
||||
|
||||
def check_key_reuse(fun: Callable[..., Any], /, *args: Any) -> None:
|
||||
"""Function to statically check key reuse."""
|
||||
function_type_signature(fun, *args)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------
|
||||
# key reuse rules for particular primitives:
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _slice_signature(eqn):
|
||||
in_aval = eqn.invars[0].aval
|
||||
assert hasattr(in_aval, "dtype")
|
||||
if not jax.dtypes.issubdtype(in_aval.dtype, jax.dtypes.prng_key):
|
||||
return KeyReuseSignature(Forward(0, 0))
|
||||
assert hasattr(in_aval, "shape")
|
||||
if any(core.is_symbolic_dim(s) for s in in_aval.shape):
|
||||
return KeyReuseSignature(Forward(0, 0))
|
||||
start_indices = eqn.params['start_indices']
|
||||
limit_indices = eqn.params['limit_indices']
|
||||
strides = eqn.params['strides'] or (1,) * len(start_indices)
|
||||
idx = tuple(slice(*tup) for tup in util.safe_zip(start_indices, limit_indices, strides))
|
||||
sink = np.zeros(in_aval.shape, dtype=bool)
|
||||
sink[idx] = True
|
||||
return KeyReuseSignature(Sink(0, sink), Source(0))
|
||||
|
||||
key_reuse_signatures[lax.slice_p] = _slice_signature
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _concatenate_signature(eqn):
|
||||
num_vals = len(eqn.invars)
|
||||
# TODO(jakevdp): should this signature be more granular?
|
||||
if num_vals == 1:
|
||||
return KeyReuseSignature(Forward(0, 0))
|
||||
else:
|
||||
return KeyReuseSignature(*(Sink(i) for i in range(num_vals)), Source(0))
|
||||
|
||||
key_reuse_signatures[lax.concatenate_p] = _concatenate_signature
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _pjit_key_type_signature(eqn):
|
||||
return jaxpr_type_signature(eqn.params['jaxpr'].jaxpr)
|
||||
|
||||
key_reuse_signatures[pjit.jit_p] = _pjit_key_type_signature
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _shard_map_type_signature(eqn):
|
||||
return jaxpr_type_signature(eqn.params['jaxpr'])
|
||||
|
||||
key_reuse_signatures[shard_map_p] = _shard_map_type_signature
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _cond_key_type_signature(eqn):
|
||||
signatures = [jaxpr_type_signature(branch.jaxpr) for branch in eqn.params['branches']]
|
||||
sinks = defaultdict(list)
|
||||
sources = defaultdict(list)
|
||||
for sig in signatures:
|
||||
for sink in sig.sinks:
|
||||
sinks[sink.idx].append(sink.mask)
|
||||
for source in sig.sources:
|
||||
sources[source.idx].append(source.mask)
|
||||
|
||||
combined_sinks = [Sink(i + 1, reduce(np.logical_or, m)) for i, m in sinks.items()]
|
||||
combined_sources = [Source(i, reduce(np.logical_and, m)) for i, m in sources.items()]
|
||||
combined_forwards = [Forward(f.in_idx + 1, f.out_idx) for f in
|
||||
set.intersection(*(set(sig.forwards) for sig in signatures))]
|
||||
return KeyReuseSignature(*combined_sinks, *combined_sources, *combined_forwards)
|
||||
|
||||
key_reuse_signatures[lax.cond_p] = _cond_key_type_signature
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _scan_key_type_signature(eqn):
|
||||
jaxpr = eqn.params['jaxpr'].jaxpr
|
||||
num_consts = eqn.params['num_consts']
|
||||
num_carry = eqn.params['num_carry']
|
||||
signature = jaxpr_type_signature(jaxpr)
|
||||
|
||||
# scan body should not consume key in constants
|
||||
if any(np.any(s.mask) for s in signature.sinks if s.idx < num_consts):
|
||||
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
|
||||
"because key constants are repeatedly consumed:\n"
|
||||
f" {signature=}\n"
|
||||
f" {eqn=}\n"
|
||||
f" {jaxpr=}")
|
||||
|
||||
# scan carry should only consume keys that are sourced on output.
|
||||
carry_sinks = {s.idx - num_consts: s.mask for s in signature.sinks
|
||||
if 0 <= s.idx - num_consts < num_carry and np.any(s.mask)}
|
||||
carry_sources = {s.idx: s.mask for s in signature.sources
|
||||
if 0 <= s.idx < num_carry and np.any(s.mask)}
|
||||
if not set(carry_sinks).issubset(set(carry_sources)): # TODO(jakevdp): check that masks match
|
||||
raise KeyReuseError("scan body function leads to key reuse when repeatedly executed, "
|
||||
"because consumed inputs don't match sourced outputs:\n"
|
||||
f" {signature=}\n"
|
||||
f" {eqn=}\n"
|
||||
f" {jaxpr=}")
|
||||
return signature
|
||||
|
||||
key_reuse_signatures[jax.lax.scan_p] = _scan_key_type_signature
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _while_key_type_signature(eqn):
|
||||
cond_jaxpr = eqn.params['cond_jaxpr'].jaxpr
|
||||
cond_nconsts = eqn.params['cond_nconsts']
|
||||
body_jaxpr = eqn.params['body_jaxpr'].jaxpr
|
||||
body_nconsts = eqn.params['body_nconsts']
|
||||
|
||||
cond_signature = jaxpr_type_signature(cond_jaxpr)
|
||||
body_signature = jaxpr_type_signature(body_jaxpr)
|
||||
|
||||
# Error if there are sinks among consts.
|
||||
if any(np.any(s.mask) for s in cond_signature.sinks if s.idx < cond_nconsts):
|
||||
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
|
||||
f" {cond_signature=}\n"
|
||||
f" {eqn=}")
|
||||
if any(np.any(s.mask) for s in body_signature.sinks if s.idx < body_nconsts):
|
||||
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
|
||||
f" {body_signature=}\n"
|
||||
f" {eqn=}")
|
||||
|
||||
# carry should only consume keys that are sourced on output.
|
||||
body_carry_sinks = {s.idx - body_nconsts: s.mask for s in body_signature.sinks if s.idx >= body_nconsts}
|
||||
cond_carry_sinks = {s.idx - cond_nconsts: s.mask for s in cond_signature.sinks if s.idx >= cond_nconsts}
|
||||
carry_sources = {s.idx: s.mask for s in body_signature.sources}
|
||||
# TODO(jakevdp): check masks at each index?
|
||||
if not (cond_carry_sinks.keys() <= carry_sources.keys()):
|
||||
raise KeyReuseError("while_loop cond function leads to key reuse when repeatedly executed: "
|
||||
f" {cond_signature=}\n"
|
||||
f" {eqn=}")
|
||||
if not (body_carry_sinks.keys() <= carry_sources.keys()):
|
||||
raise KeyReuseError("while_loop body function leads to key reuse when repeatedly executed: "
|
||||
f" {body_signature=}\n"
|
||||
f" {eqn=}")
|
||||
if body_carry_sinks.keys() & cond_carry_sinks.keys():
|
||||
raise KeyReuseError("while_loop cond and body functions both use the same key: "
|
||||
f" {cond_signature=}\n"
|
||||
f" {body_signature=}\n"
|
||||
f" {eqn=}")
|
||||
return body_signature
|
||||
|
||||
key_reuse_signatures[jax.lax.while_p] = _while_key_type_signature
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _remat_key_type_signature(eqn):
|
||||
# The assumption here is that the non-differentiated pass contains all relevant
|
||||
# key usage, and the differentiated pass
|
||||
# 1) will only consume keys that are already consumed in the non-differentiated pass
|
||||
# 2) will never create keys
|
||||
# Therefore, the differentiated pass is a no-op.
|
||||
if eqn.params['differentiated']:
|
||||
return KeyReuseSignature()
|
||||
return jaxpr_type_signature(eqn.params['jaxpr'])
|
||||
|
||||
key_reuse_signatures[remat_p] = _remat_key_type_signature
|
||||
|
||||
|
||||
@dynamic_key_reuse_signature
|
||||
def _device_put_signature(eqn):
|
||||
num_vals = len(eqn.invars)
|
||||
return KeyReuseSignature(*(Forward(i, i) for i in range(num_vals)))
|
||||
|
||||
key_reuse_signatures[lax.device_put_p] = _device_put_signature
|
||||
|
||||
|
||||
def call_impl_with_key_reuse_checks(prim: core.Primitive, raw_impl: Callable[..., Any], *args, **kwargs) -> Any:
|
||||
if prim not in key_reuse_signatures:
|
||||
# TODO(jakevdp): should we use an unknown signature here?
|
||||
return raw_impl(*args, **kwargs)
|
||||
signature = key_reuse_signature_from_primitive(prim, *args, **kwargs)
|
||||
funcname = "jit-compiled function" if prim == pjit.jit_p else str(prim)
|
||||
consts = kwargs['jaxpr'].consts if prim == pjit.jit_p else []
|
||||
signature.check_signature(*args, *consts, funcname=funcname)
|
||||
result = raw_impl(*args, **kwargs)
|
||||
signature.update_consumption([*args, *consts], result if prim.multiple_results else [result])
|
||||
return result
|
||||
Reference in New Issue
Block a user