# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass
from functools import partial, update_wrapper
import inspect
import itertools as it
from typing import Any, Hashable, Callable

from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import dtypes
from jax._src import effects
from jax._src.api_util import resolve_kwargs, infer_argnums_and_argnames
from jax._src import traceback_util
from jax._src.core import typeof
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import remat
from jax._src.custom_derivatives import (
    CustomVJPPrimal, _temporary_dtype_exception, _check_for_returned_refs)
from jax._src.errors import UnexpectedTracerError
from jax._src.state.types import AbstractRef
from jax._src import ad_util
from jax._src.util import safe_zip, safe_map, split_list, unzip2
from jax._src.tree_util import (
    tree_map, tree_flatten, tree_unflatten, tree_leaves, tree_leaves_checked,
    broadcast_prefix, register_static, tree_map_with_path, keystr,
    tracing_registry, FlatTree)

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

PyTreeOfAvals = Any
PyTreeDef = Any
LoVal = Any
HiVal = Any

traceback_util.register_exclusion(__file__)

# Hijax extension API

Ty = core.AbstractValue
LoType = core.AbstractValue
QDD = core.QuasiDynamicData
ShapedArray = core.ShapedArray

class HiPrimitive(core.Primitive):
  def __init__(self, name):
    self.name = name
    ad.primitive_jvps[self] = self.jvp
    ad.primitive_transposes[self] = self.transpose

  def is_high(self, *avals, **params) -> bool:
    return True

  def is_effectful(self, params) -> bool:  # pyrefly: ignore[bad-override]
    return False  # default immutable

  # type checking and forward type propagation
  def abstract_eval(self, *arg_avals, **params):
    assert False, "must override"

  # lowering implements the primitive in terms of lojax inputs/outputs/ops
  def to_lojax(self, *lotypes_wrapped_in_hitypes, **params):  # pyrefly: ignore[bad-override]
    assert False, f"must override for {self}"

  # autodiff interface
  def jvp(self, primals, tangents, **params):
    assert False, "must override"

  # transposition is only required if the primitive is linear in some inputs
  def transpose(self, *args, **params):
    assert False, "must override"
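
# A minimal sketch of a HiPrimitive subclass (illustrative only, not part of
# this module's API; the names `Double` and `double_p` are hypothetical).
# `abstract_eval` returns an output aval plus an effects set, `to_lojax`
# implements the op in terms of lojax values, and `jvp` defines the tangent
# rule by re-binding the primitive:
#
#   class Double(HiPrimitive):
#     def abstract_eval(self, aval):
#       return aval, set()          # output type, no effects
#
#     def to_lojax(self, x):
#       return x + x                # lowering in terms of lojax ops
#
#     def jvp(self, primals, tangents):
#       (x,), (t,) = primals, tangents
#       return double_p.bind(x), double_p.bind(t)
#
#   double_p = Double('double')
#   # y = double_p.bind(x)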

AxisName = Any

class HiType(core.AbstractValue):
  is_high = True
  has_qdd = False  # immutable

  # type equality
  def __hash__(self):
    assert False, "must override"

  def __eq__(self, other):
    assert False, "must override"

  # lowering from hijax type to lojax types
  def lo_ty(self) -> list[core.AbstractValue]:
    assert False, "must override"

  # define lowering from hijax value to lojax values and back (like pytrees)
  def lower_val(self, hi_val: HiVal) -> list[LoVal]:  # TODO(mattjj): not lovals
    assert False, "must override"

  def raise_val(self, *lo_vals: LoVal) -> HiVal:
    assert False, "must override"

  # autodiff interface
  def to_tangent_aval(self) -> HiType:
    assert False, "must override"

  def to_ct_aval(self) -> HiType:
    return self.to_tangent_aval()

  # the next two are required if this type is itself a tangent type
  def vspace_zero(self) -> HiVal:
    assert False, "must override"

  def vspace_add(self, x: HiVal, y: HiVal) -> HiVal:
    assert False, "must override"

  # vmap interface (also needed for scan)
  def dec_rank(self, size: int | None, spec: MappingSpec) -> HiType:
    assert False, "must override"

  def inc_rank(self, size: int | None, spec: MappingSpec) -> HiType:
    assert False, "must override"

  # scan interface
  def leading_axis_spec(self) -> MappingSpec:
    assert False, "must override"

  # shard_map interface
  def shard(self, mesh, manual_axes: frozenset, check_vma: bool, spec: HiPspec
            ) -> HiType:
    assert False, "must override"

  def unshard(self, mesh, check_vma: bool, spec: HiPspec) -> HiType:
    assert False, "must override"

  def nospec(self, mesh, check_vma: bool, all_names: tuple[AxisName, ...]
             ) -> HiPspec:
    assert False, "must override"

class MutableHiType(core.AbstractValue):
  is_high = True
  has_qdd = True  # mutable and potentially type-changing
  is_writer = False

  type_state = core.aval_method(core.cur_qdd)

  # type equality
  def __hash__(self):
    assert False, "must override"

  def __eq__(self, other):
    assert False, "must override"

  # define lowering from (mutable) hijax type to (immutable) lojax types
  def lo_ty_qdd(self, state: QDD, /) -> list[core.AbstractValue]:  # pyrefly: ignore[bad-override]
    assert False, "must override"

  def lo_ty(self):
    assert False, "mutable hitypes should use lo_ty_qdd instead"

  # define lowering from hijax value to lojax values and back, depending on qdd
  def new_from_loval(self, state: QDD, /, *vals: LoVal) -> HiVal:
    assert False, "must override"

  def read_loval(self, state: QDD, val: HiVal, /) -> list[LoVal]:
    assert False, "must override"

  # default implementations of newer apis
  def read_loval_in(self, state, val, /):
    return self.read_loval(state, val)

  def read_loval_out(self, qdd, hi, /):
    return FlatTree.flatten(self.read_loval(qdd, hi))

  # define how to mutate/set the mutable hijax value given immutable lojax vals
  def update_from_loval(self, state: QDD, val: HiVal, /, *lo_vals: LoVal) -> None:
    assert False, "must override"

  # default implementation of newer api
  def update_from_loval2(self, state, val, lo_vals_ft, /) -> None:
    self.update_from_loval(state, val, *lo_vals_ft.unflatten())

  # autodiff interface
  def to_tangent_aval(self) -> HiType:
    assert False, "must override"

  # Subclasses should override if the cotangent type is a function of primal
  # type. For example, CT unreduced = reduced and vice-versa.
  def to_ct_aval(self) -> HiType:
    return self.to_tangent_aval()

def register_hitype(val_cls, typeof_fn) -> None:
  core.pytype_aval_mappings[val_cls] = typeof_fn
  dtypes.canonicalize_value_handlers[val_cls] = lambda x: x

def hijax_method(f):
  return core.aval_method(f)
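
# A minimal sketch of an immutable hijax type (illustrative only; `PairVal`
# and `PairTy` are hypothetical names, and the example assumes float32 scalar
# leaves). The type says how a hi-value lowers to a list of lojax values and
# how it is rebuilt from them, much like a pytree registration:
#
#   import numpy as np
#
#   @dataclass(frozen=True)
#   class PairVal:
#     first: Any
#     second: Any
#
#   class PairTy(HiType):
#     def __hash__(self): return hash(PairTy)
#     def __eq__(self, other): return isinstance(other, PairTy)
#     def lo_ty(self):
#       return [ShapedArray((), np.dtype('float32'))] * 2
#     def lower_val(self, hi_val):
#       return [hi_val.first, hi_val.second]
#     def raise_val(self, first, second):
#       return PairVal(first, second)
#     def to_tangent_aval(self):
#       return PairTy()
#
#   register_hitype(PairVal, lambda _: PairTy())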

# Boxes

## Box API

def new_box():
  (), treedef = tree_flatten(None)
  return new_box_p.bind(treedef=treedef)

def box_get(box):
  tys = core.cur_qdd(box)
  leaf_vals = box_get_p.bind(box, avals=tuple(tys.leaf_avals))
  return tree_unflatten(tys.treedef, leaf_vals)

def box_set(box, val):
  leaves, treedef = tree_flatten(val)
  box_set_p.bind(box, *leaves, treedef=treedef)

## Box implementation

@dataclass(frozen=True)
class BoxTypeState(QDD):
  leaf_avals: tuple[core.AbstractValue, ...]
  treedef: PyTreeDef

  def to_tangent_qdd(self):
    leaf_avals = tuple(a.to_tangent_aval() for a in self.leaf_avals)
    return BoxTypeState(leaf_avals, self.treedef)

  def normalize(self):
    leaf_types = tuple(a.normalize() for a in self.leaf_avals)
    return BoxTypeState(leaf_types, self.treedef)

class BoxTy(MutableHiType):
  has_qdd = True

  # forwarded to value
  get = core.aval_method(box_get)
  set = core.aval_method(box_set)

  # aval interface: hashability and str_short
  def __hash__(self):
    return hash(BoxTy)

  def __eq__(self, other):
    return isinstance(other, BoxTy)

  def str_short(self, short_dtypes=False, **_) -> str:  # pyrefly: ignore[bad-override]
    return 'BoxTy'

  # mutable interface
  def lo_ty_qdd(self, box_state):
    return [lo_ty for t in box_state.leaf_avals for lo_ty in t.lo_ty()]

  def new_from_loval(self, box_state: BoxTypeState, *lo_vals) -> Box:  # pyrefly: ignore[bad-override]
    lo_vals_ = iter(lo_vals)
    hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty())))  # pyrefly: ignore[missing-attribute]
               for hi_ty in box_state.leaf_avals]
    assert next(lo_vals_, None) is None
    return Box._new(tree_unflatten(box_state.treedef, hi_vals))  # will be mutated

  def read_loval(self, box_state: BoxTypeState, box) -> list:  # pyrefly: ignore[bad-override]
    leaf_vals, treedef = tree_flatten(box_get(box))
    assert treedef == box_state.treedef
    return [lo_val for hi_ty, hi_val in zip(box_state.leaf_avals, leaf_vals)
            for lo_val in hi_ty.lower_val(hi_val)]  # pyrefly: ignore[missing-attribute]

  def update_from_loval(self, box_state: BoxTypeState, box, *lo_vals) -> None:  # pyrefly: ignore[bad-override]
    lo_vals_ = iter(lo_vals)
    hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty())))  # pyrefly: ignore[missing-attribute]
               for hi_ty in box_state.leaf_avals]
    assert next(lo_vals_, None) is None
    box_set(box, tree_unflatten(box_state.treedef, hi_vals))

  def to_tangent_aval(self):
    return BoxTy()

# Override isinstance checks under tracing
class _BoxMeta(type):
  def __instancecheck__(self, instance):
    return (super().__instancecheck__(instance) or
            isinstance(instance, core.Tracer) and
            isinstance(core.typeof(instance), BoxTy))

class Box(metaclass=_BoxMeta):  # noqa: F811
  _val = None  # always clobbered by __new__, but pytype likes this

  # We want `Box(x)` to bind a primitive, so we override __new__ and provide a
  # raw `_new` method below.
  def __new__(cls, init_val=None):
    (), treedef = tree_flatten(None)
    box = new_box_p.bind(treedef=treedef)
    box.set(init_val)
    return box

  @classmethod
  def _new(cls, init_val):
    new = super().__new__(cls)
    new._val = init_val
    return new

  def get(self):
    return box_get(self)

  def set(self, val):
    box_set(self, val)

  def cur_qdd(self):
    return self.type_state()

  @property
  def ty(self):
    return BoxTy()

  def type_state(self):
    leaves, treedef = tree_flatten(self._val)
    leaf_avals = tuple(map(core.typeof, leaves))
    return BoxTypeState(leaf_avals, treedef)

register_hitype(Box, lambda b: b.ty)
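
# Example Box usage (an illustrative sketch; whether a Box may cross a `jit`
# boundary depends on the state of hijax support in the surrounding tracing
# machinery):
#
#   import jax.numpy as jnp
#
#   box = Box(jnp.float32(0.))      # Box(x) binds new_box_p, then sets x
#
#   def step(box, x):
#     box.set(box.get() + x)        # read-modify-write as a traced effect
#
#   step(box, jnp.float32(1.))
#   # box.get() == 1.0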

class BoxEffect(effects.Effect):
  ...

box_effect = BoxEffect()

effects.control_flow_allowed_effects.add_type(BoxEffect)
effects.custom_derivatives_allowed_effects.add_type(BoxEffect)

class NewBox(HiPrimitive):
  def is_high(self, *, treedef) -> bool:
    return True

  def abstract_eval(self, *, treedef):
    leaves, treedef = tree_flatten(None)
    qdd = BoxTypeState(tuple(leaves), treedef)
    return core.AvalQDD(BoxTy(), qdd), {box_effect}

  def to_lojax(_, *, treedef):
    return Box._new(None)

  def jvp(_, primals, tangents, *, treedef):  # pyrefly: ignore[bad-override]
    assert False  # TODO

  def transpose(_, *args, treedef):
    assert False  # TODO

new_box_p = NewBox('new_box')

class BoxSet(HiPrimitive):
  multiple_results = True

  def is_high(self, *leaf_avals, treedef) -> bool:
    return True

  def abstract_eval(self, box_ty, *leaf_avals, treedef):
    box_ty.mutable_qdd.update(BoxTypeState(leaf_avals, treedef))
    return [], {box_effect}  # TODO better typechecking...

  def to_lojax(_, box, *leaves, treedef):
    box._val = tree_unflatten(treedef, leaves)
    return []

  def jvp(_, primals, tangents, *, treedef):  # pyrefly: ignore[bad-override]
    box, *vals = primals
    box_dot, *val_dots = tangents
    if type(box_dot) is ad_util.Zero:
      raise Exception("can't differentiate Box.set operation, "
                      "did you forget jax.lax.stop_gradient?")
    box_set_p.bind(box, *vals, treedef=treedef)
    box_set_p.bind(box_dot, *val_dots, treedef=treedef)
    return [], []

  def transpose(_, *args, treedef):
    assert False  # TODO

box_set_p = BoxSet('box_set')

class BoxGet(HiPrimitive):
  multiple_results = True

  def abstract_eval(self, box_ty, *, avals):
    return avals, {box_effect}

  def to_lojax(_, box, *, avals):
    return tree_leaves(box._val)

  def jvp(_, primals, tangents, *, avals):  # pyrefly: ignore[bad-override]
    (box,), (box_dot,) = primals, tangents
    return (
        box_get_p.bind(box, avals=avals),
        box_get_p.bind(box_dot, avals=tuple(a.to_tangent_aval() for a in avals)))

  def transpose(_, *args):
    assert False  # TODO

box_get_p = BoxGet('box_get')

# === new-style hijax primitive implementation ===

class VJPHiPrimitive:
  in_avals: tuple[PyTreeOfAvals, ...]
  out_aval: PyTreeOfAvals
  params: dict[str, Hashable]
  effects: frozenset[effects.Effect] = frozenset()

  def __init__(self):
    if not hasattr(self, 'in_avals'):
      raise AttributeError("subclass __init__ should set `self.in_avals`")
    if not hasattr(self, 'out_aval'):
      raise AttributeError("subclass __init__ should set `self.out_aval`")
    if not hasattr(self, 'params'):
      raise AttributeError("subclass __init__ should set `self.params`")
    if (type(self).vjp_bwd is not VJPHiPrimitive.vjp_bwd and
        type(self).vjp_bwd_retval is not VJPHiPrimitive.vjp_bwd_retval):
      raise AttributeError(f"subclass {type(self)} should not override both "
                           "`vjp_bwd` and `vjp_bwd_retval`")
    self.in_avals_flat, self.in_tree = tracing_registry.flatten(self.in_avals)
    self.out_avals_flat, self.out_tree = tracing_registry.flatten(self.out_aval)
    self.__dict__.update(self.params)
    self.check(*self.in_avals)

  # Operation implementation in terms of lojax primitives
  def expand(self, *args):
    raise NotImplementedError(f"subclass {type(self)} must implement `expand`")

  # reverse-mode AD interface
  def vjp_fwd(self, nzs_in, /, *args):
    raise NotImplementedError(f"for grad support, subclass {type(self)} must "
                              "implement `vjp_fwd`")

  def vjp_bwd(self, res, outgrad, /, *arg_accums):
    args_grad = self.vjp_bwd_retval(res, outgrad)
    maybe_accum = lambda acc, v: isinstance(acc, ad.GradAccum) and acc.accum(v)
    tree_map(maybe_accum, arg_accums, args_grad)

  def vjp_bwd_retval(self, res, outgrad, /):
    # Classic API: returns values instead of using accumulators
    raise NotImplementedError(f"for grad support, subclass {type(self)} must "
                              "implement `vjp_bwd` or `vjp_bwd_retval`")

  # optional forward-mode AD interfaces
  def jvp(self, primals, tangents):
    raise NotImplementedError(f"for jvp support, subclass {type(self)} must "
                              "implement `jvp`")

  def lin(self, nzs_in, *primals):
    raise NotImplementedError(f"for linearize support, subclass {type(self)} "
                              "must implement `lin` and `linearized`")

  def linearized(self, residuals, *tangents):
    raise NotImplementedError(f"for linearize support, subclass {type(self)} "
                              "must implement `lin` and `linearized`")

  # optional transpose rule, for primitives that are linear in some inputs
  def transpose(self, out_ct, *maybe_accums):
    raise NotImplementedError(f"for transpose support, subclass {type(self)} "
                              "must implement `transpose`")

  # vmap interface
  def batch(self, axis_data, args, dims):
    out_dim = self.batch_dim_rule(axis_data, dims)
    return VmapOf(self, axis_data, dims, out_dim)(*args), out_dim

  def batch_dim_rule(self, axis_data, dims, /):
    raise NotImplementedError(f"for vmap support, subclass {type(self)} must "
                              "implement `batch` or `batch_dim_rule`")

  # optional dce control
  def dce(self, used_outs):
    used_outs_flat = tree_leaves_checked(self.out_tree, used_outs)
    if not any(used_outs_flat):
      return False, False, None
    else:
      return True, True, self

  # optional remat control
  def remat(self, _policy, *args):
    return self(*args), self  # full remat by default

  def __call__(self, *args):
    args_flat = tree_leaves_checked(self.in_tree, args)
    ans_flat = call_hi_primitive_p.bind(*args_flat, _prim=self)
    return tree_unflatten(self.out_tree, ans_flat)

  def check(self, *arg_tys):
    return  # subclass can optionally override this to add checking logic

  def staging(self, trace, source_info, *args):
    args_flat = tree_leaves_checked(self.in_tree, args)
    ans_flat = trace.default_process_primitive(
        call_hi_primitive_p, args_flat, dict(_prim=self), source_info)
    return tree_unflatten(self.out_tree, ans_flat)

  def __repr__(self):
    return f"{self.__class__.__name__}[{self.params}]"
f"{self.__class__.__name__}[{self.params}]" def __hash__(self): return hash((self.__class__.__name__, tuple(self.params.items()), self.effects)) def __eq__(self, other): return (type(self) is type(other) and self.params == other.params and self.effects == other.effects) class VmapOf(VJPHiPrimitive): prim: core.Primitive axis_data: batching.AxisData in_dims: Any out_dim: Any def __init__(self, prim, axis_data, in_dims, out_dim): unmap = lambda a, d: core.unmapped_aval(axis_data.size, d, a, axis_data.explicit_mesh_axis) self.in_avals = tree_map(unmap, prim.in_avals, in_dims) self.out_aval = tree_map(unmap, prim.out_aval, out_dim) self.params = dict(prim=prim, axis_data=axis_data, in_dims=in_dims, out_dim=out_dim) super().__init__() @property def _vmap_params(self): return dict(axis_size=self.axis_data.size, axis_name=self.axis_data.name, spmd_axis_name=self.axis_data.spmd_name or self.axis_data.explicit_mesh_axis) def expand(self, *args): return api.vmap(self.prim.expand, in_axes=self.in_dims, out_axes=self.out_dim, # pyrefly: ignore[missing-attribute] **self._vmap_params)(*args) def jvp(self, primals, tangents): # TODO probably gonna get non-pytree-prefix errors because of sym zeros... return api.vmap(self.prim.jvp, in_axes=(self.in_dims, self.in_dims), # pyrefly: ignore[missing-attribute] out_axes=(self.out_dim, self.out_dim), **self._vmap_params)(primals, tangents) def vjp_fwd(self, in_nzs, *args): store = lambda: None def fwd(*args): primal_out, res, *maybe_out_nzs = self.prim.vjp_fwd(in_nzs, *args) # pyrefly: ignore[missing-attribute] store.out_nzs = maybe_out_nzs # pyrefly: ignore[missing-attribute] return primal_out, res (primal_out, res), (_, res_axes) = api.vmap( fwd, in_axes=self.in_dims, out_axes=(self.out_dim, batching.infer), **self._vmap_params)(*args) return primal_out, (res, Static(res_axes)), *store.out_nzs # pyrefly: ignore[missing-attribute] def vjp_bwd_retval(self, res_, g): # TODO probably gonna get non-pytree-prefix errors because of sym zeros... 

class VmapOf(VJPHiPrimitive):
  prim: core.Primitive
  axis_data: batching.AxisData
  in_dims: Any
  out_dim: Any

  def __init__(self, prim, axis_data, in_dims, out_dim):
    unmap = lambda a, d: core.unmapped_aval(axis_data.size, d, a,
                                            axis_data.explicit_mesh_axis)
    self.in_avals = tree_map(unmap, prim.in_avals, in_dims)
    self.out_aval = tree_map(unmap, prim.out_aval, out_dim)
    self.params = dict(prim=prim, axis_data=axis_data, in_dims=in_dims,
                       out_dim=out_dim)
    super().__init__()

  @property
  def _vmap_params(self):
    return dict(axis_size=self.axis_data.size, axis_name=self.axis_data.name,
                spmd_axis_name=self.axis_data.spmd_name or
                self.axis_data.explicit_mesh_axis)

  def expand(self, *args):
    return api.vmap(self.prim.expand, in_axes=self.in_dims,
                    out_axes=self.out_dim,  # pyrefly: ignore[missing-attribute]
                    **self._vmap_params)(*args)

  def jvp(self, primals, tangents):
    # TODO probably gonna get non-pytree-prefix errors because of sym zeros...
    return api.vmap(self.prim.jvp,
                    in_axes=(self.in_dims, self.in_dims),  # pyrefly: ignore[missing-attribute]
                    out_axes=(self.out_dim, self.out_dim),
                    **self._vmap_params)(primals, tangents)

  def vjp_fwd(self, in_nzs, *args):
    store = lambda: None
    def fwd(*args):
      primal_out, res, *maybe_out_nzs = self.prim.vjp_fwd(in_nzs, *args)  # pyrefly: ignore[missing-attribute]
      store.out_nzs = maybe_out_nzs  # pyrefly: ignore[missing-attribute]
      return primal_out, res
    (primal_out, res), (_, res_axes) = api.vmap(
        fwd, in_axes=self.in_dims, out_axes=(self.out_dim, batching.infer),
        **self._vmap_params)(*args)
    return primal_out, (res, Static(res_axes)), *store.out_nzs  # pyrefly: ignore[missing-attribute]

  def vjp_bwd_retval(self, res_, g):
    # TODO probably gonna get non-pytree-prefix errors because of sym zeros...
    res, res_axes = res_[0], res_[1].val
    in_dims = tree_map(lambda x: batching.sum_axis if x is None else x,
                       self.in_dims, is_leaf=lambda x: x is None)
    g = tree_map(partial(map_zero, self.axis_data), self.out_dim, g,
                 is_leaf=lambda x: x is None)
    out = api.vmap(self.prim.vjp_bwd_retval,
                   in_axes=(res_axes, self.out_dim),  # pyrefly: ignore[missing-attribute]
                   out_axes=in_dims, **self._vmap_params,
                   sum_match=True)(res, g)
    return tree_map(partial(unmap_zero, self.axis_data), self.in_dims, out,
                    is_leaf=lambda x: x is None)

  def batch_dim_rule(self, axis_data, in_dims):
    fix = lambda d, d_: d if (d is None or d_ is None) else d - (d_ < d)
    in_dims_ = tree_map(fix, in_dims, self.in_dims, is_leaf=lambda x: x is None)
    out_dim = self.prim.batch_dim_rule(axis_data, in_dims_)  # pyrefly: ignore[missing-attribute]
    return tree_map(lambda d, d_: d + (d_ < d), out_dim, self.out_dim)

def map_zero(axis_data, d, ct):
  if isinstance(ct, ad_util.Zero):
    return ad_util.Zero(core.mapped_aval(axis_data.size, d, ct.aval))
  return ct

def unmap_zero(axis_data, d, ct):
  if isinstance(ct, ad_util.Zero):
    return ad_util.Zero(core.unmapped_aval(axis_data.size, d, ct.aval,
                                           axis_data.explicit_mesh_axis))
  return ct

call_hi_primitive_p = core.Primitive("call_hi_primitive")
call_hi_primitive_p.multiple_results = True
call_hi_primitive_p.is_high = lambda *args, _prim: True
call_hi_primitive_p.is_effectful = lambda params: bool(params['_prim'].effects)

@call_hi_primitive_p.def_effectful_abstract_eval
def _call_hi_primitive_abstract_eval(*_args, _prim):
  return _prim.out_avals_flat, _prim.effects

def _call_hi_primitive_typecheck(_ctx_factory, *in_atoms_flat, _prim):
  in_avals = [x.aval for x in in_atoms_flat]
  if not all(map(core.typematch, in_avals, _prim.in_avals_flat)):
    raise TypeError(f"input type mismatch for {_prim}")
  _prim.check()
  return _prim.out_avals_flat, _prim.effects
core.custom_typechecks[call_hi_primitive_p] = _call_hi_primitive_typecheck

def _call_hi_primitive_staging(trace, source_info, *args_flat, _prim):
  trace.frame.is_high = True
  args = tree_unflatten(_prim.in_tree, args_flat)
  ans = _prim.staging(trace, source_info, *args)
  return tree_leaves_checked(_prim.out_tree, ans)
pe.custom_staging_rules[call_hi_primitive_p] = _call_hi_primitive_staging

def _call_hi_primitive_to_lojax(*args_flat, _prim):
  args = tree_unflatten(_prim.in_tree, args_flat)
  ans = _prim.expand(*args)
  return tree_leaves_checked(_prim.out_tree, ans)
call_hi_primitive_p.to_lojax = _call_hi_primitive_to_lojax

def _call_hi_primitive_batcher(axis_data, args_flat, dims_flat, _prim):
  args = tree_unflatten(_prim.in_tree, args_flat)
  dims = tree_unflatten(_prim.in_tree, dims_flat)
  ans, dims = _prim.batch(axis_data, args, dims)
  ans_flat = tree_leaves_checked(_prim.out_tree, ans)
  dims_flat = _prim.out_tree.flatten_up_to(dims)
  return ans_flat, dims_flat
batching.fancy_primitive_batchers[call_hi_primitive_p] = _call_hi_primitive_batcher

def _call_hi_primitive_linearize(is_vjp, nz_in_flat, *args_flat, _prim):
  args = tree_unflatten(_prim.in_tree, args_flat)
  nzs_in = tree_unflatten(_prim.in_tree, nz_in_flat)
  if is_vjp:
    ans, residuals, *maybe_nzs_out = _prim.vjp_fwd(nzs_in, *args)
    linearized = partial(fake_linear_op, _prim, nz_in_flat)
  else:
    ans, residuals, *maybe_nzs_out = _prim.lin(nzs_in, *args)
    linearized = partial(flatten_user_linearized, _prim)
  ans_flat = tree_leaves_checked(_prim.out_tree, ans)
  nzs_out = maybe_nzs_out[0] if maybe_nzs_out else True
  nzs_out_flat = broadcast_prefix(nzs_out, ans)
  return ans_flat, nzs_out_flat, residuals, linearized
ad.primitive_linearizations[call_hi_primitive_p] = _call_hi_primitive_linearize

def fake_linear_op(prim, nz_in_flat, rs, *tangents):
  residuals_flat, residuals_tree = tree_flatten(rs)
  assert nz_in_flat == [not isinstance(t, ad_util.Zero) for t in tangents]
  nz_tangents = tree_leaves(tangents)
  return call_hi_primitive_linearized_p.bind(
      *residuals_flat, *nz_tangents, residuals_tree=residuals_tree,
      _prim=prim, nz_in_flat=tuple(nz_in_flat))

def flatten_user_linearized(prim, residuals, *tangents_flat):
  tangents = tree_unflatten(prim.in_tree, tangents_flat)
  tangents_out = prim.linearized(residuals, *tangents)
  tangents_out_flat = tree_leaves_checked(prim.out_tree, tangents_out)
  return tangents_out_flat

call_hi_primitive_linearized_p = core.Primitive("call_hi_primitive_linearized")
call_hi_primitive_linearized_p.multiple_results = True
call_hi_primitive_linearized_p.is_high = lambda *args, _prim, **_: True

@call_hi_primitive_linearized_p.def_abstract_eval
def _call_hi_primitive_linearized_abstract_eval(*_args, _prim, residuals_tree,
                                                nz_in_flat):
  return [t.to_tangent_aval() for t in _prim.out_avals_flat]

# TODO(dougalm): handle nonzeros
def _call_hi_primitive_linearized_transpose(cts_flat, *args, _prim,
                                            residuals_tree, nz_in_flat):
  residuals_flat, accums_flat = split_list(args, [residuals_tree.num_leaves])
  residuals = tree_unflatten(residuals_tree, residuals_flat)
  accums_flat_ = iter(accums_flat)
  accums_flat = [next(accums_flat_) if nz else ad.NullAccum(aval.to_ct_aval())
                 for aval, nz in zip(_prim.in_avals_flat, nz_in_flat)]
  assert next(accums_flat_, None) is None
  accums = tree_unflatten(_prim.in_tree, accums_flat)
  cts = tree_unflatten(_prim.out_tree, cts_flat)
  none = _prim.vjp_bwd(residuals, cts, *accums)
  assert none is None
ad.fancy_transposes[call_hi_primitive_linearized_p] = _call_hi_primitive_linearized_transpose

def _call_hi_primitive_jvp(primals, tangents, *, _prim):
  primals = tree_unflatten(_prim.in_tree, primals)
  tangents = tree_unflatten(_prim.in_tree, tangents)
  out_primals, out_tangents = _prim.jvp(primals, tangents)
  out_primals_flat = tree_leaves_checked(_prim.out_tree, out_primals)
  out_tangents_flat = _prim.out_tree.flatten_up_to(out_tangents)
  return out_primals_flat, out_tangents_flat
ad.primitive_jvps[call_hi_primitive_p] = _call_hi_primitive_jvp

def _call_hi_primitive_transpose(cts_flat, *primals_flat, _prim):
  cts = tree_unflatten(_prim.out_tree, cts_flat)
  primals = tree_unflatten(_prim.in_tree, primals_flat)
  none = _prim.transpose(cts, *primals)
  assert none is None
ad.fancy_transposes[call_hi_primitive_p] = _call_hi_primitive_transpose

def _call_hi_primitive_dce(used_outs_flat, eqn):
  _prim = eqn.params['_prim']
  used_out = tree_unflatten(_prim.out_tree, used_outs_flat)
  used_ins, produced_outs, new_prim = _prim.dce(used_out)
  if new_prim is None:
    return [False] * len(eqn.invars), None
  used_ins_flat = broadcast_prefix(used_ins, _prim.in_avals)
  produced_outs_flat = broadcast_prefix(produced_outs, _prim.out_aval)
  new_invars = [x for x, u in zip(eqn.invars, used_ins_flat) if u]
  new_outvars = [v for v, u in zip(eqn.outvars, produced_outs_flat) if u]
  new_eqn = eqn.replace(invars=new_invars, outvars=new_outvars,
                        params=dict(_prim=new_prim))
  return used_ins_flat, new_eqn
pe.dce_rules[call_hi_primitive_p] = _call_hi_primitive_dce

call_hi_primitive_linearized_p.to_lojax = ad.raise_custom_vjp_error_on_jvp
batching.fancy_primitive_batchers[call_hi_primitive_linearized_p] = (
    ad.raise_custom_vjp_error_on_jvp)

def _call_hi_primitive_remat(policy, *args_flat, _prim):
  args = tree_unflatten(_prim.in_tree, args_flat)
  out, rem_ = _prim.remat(policy, *args)
  def rem(*args_flat):
    args = tree_unflatten(_prim.in_tree, args_flat)
    out = rem_(*args)
    return tree_leaves_checked(_prim.out_tree, out)
  return tree_leaves_checked(_prim.out_tree, out), rem
remat.rules[call_hi_primitive_p] = _call_hi_primitive_remat

class CustomVJPTraced(VJPHiPrimitive):
  traced: Any
  fwd: Any
  bwd: Any
  symbolic_zeros: Any
  static_argnums: Any
  opt_remat: bool

  def __init__(self, traced, fwd, bwd, in_avals, sym_zeros, static_argnums,
               opt_remat):
    self.in_avals = in_avals
    self.out_aval = traced.out_avals
    self.params = dict(traced=traced, fwd=fwd, bwd=bwd,
                       symbolic_zeros=sym_zeros, static_argnums=static_argnums,
                       opt_remat=opt_remat)
    super().__init__()

  def expand(self, *args):
    args = [x for x in args if not isinstance(x, Static)]
    return self.traced(*args)

  def vjp_fwd(self, in_nzs, *args):
    in_nzs = tuple(x.val if isinstance(x, Static) else x for x in in_nzs)
    args_ = tuple(x.val if isinstance(x, Static) else x for x in args)
    if self.symbolic_zeros:
      args_ = tree_map(CustomVJPPrimal, args_, in_nzs)
    out, res = self.fwd(*args_)
    if config.mutable_array_checks.value:
      _check_for_returned_refs(self.fwd, (out, res), "fwd",
                               tree_leaves(args), self.out_tree.num_leaves)
    if ((tree := tracing_registry.flatten(out)[1]) != self.out_tree):
      raise TypeError(_vjp_primal_fwd_tree_mismatch_err(self, tree))
    tree_map_with_path(_vjp_fwd_aval_mismatch_err, self.out_aval, out)
    if self.symbolic_zeros:
      out_pairs_flat = tree_leaves_checked(self.out_tree, out)
      out_flat, out_nzs_flat = unzip2(
          (x.value, x.perturbed) if isinstance(x, CustomVJPPrimal) else (x, True)
          for x in out_pairs_flat)
      out_nzs = tree_unflatten(self.out_tree, out_nzs_flat)
      out = tree_unflatten(self.out_tree, out_flat)
      return out, res, out_nzs
    else:
      return out, res

  def vjp_bwd_retval(self, res, out_ct):
    static_args = tuple(x.val for x in self.in_avals if isinstance(x, Static))
    in_avals_ = tuple(x for x in self.in_avals if not isinstance(x, Static))
    leaf = lambda x: isinstance(x, ad_util.Zero)
    if self.symbolic_zeros:
      out_ct = tree_map(ad_util.replace_internal_symbolic_zeros, out_ct,
                        is_leaf=leaf)
    else:
      out_ct = tree_map(ad_util.instantiate, out_ct, is_leaf=leaf)
    in_cts = self.bwd(*static_args, res, out_ct)
    if isinstance(in_cts, list):
      in_cts = tuple(in_cts)
    if not isinstance(in_cts, tuple):
      raise TypeError(f"Custom VJP bwd rule {self.bwd} must produce a tuple "
                      f"but got {type(in_cts)}.")
    if len(in_cts) != len(self.in_tree.children()) - len(self.static_argnums):
      raise ValueError(f"Custom VJP bwd rule {self.bwd} must produce a tuple "
                       "of length equal to the primal args tuple, but got "
                       f"length {len(in_cts)}")
    in_cts = broadcast_prefix(in_cts, in_avals_, is_leaf=lambda x: x is None)
    in_cts = tree_unflatten(self.in_tree,
                            map(_replace_none, self.in_avals_flat, in_cts))
    tree_map_with_path(_vjp_bwd_aval_mismatch_err, self.in_avals, in_cts)
    if self.symbolic_zeros:
      in_cts = tree_map(ad_util.replace_rule_output_symbolic_zeros, in_cts)
    return in_cts

  def jvp(self, primals, tangents):
    if self.symbolic_zeros:
      raise NotImplementedError
    zero = lambda x: isinstance(x, ad_util.Zero)
    tangents = tree_map(ad_util.instantiate, tangents, is_leaf=zero)
    if self.opt_remat:
      fwd_traced = api.jit(
          partial(self.vjp_fwd, (True,) * len(primals))).trace(*primals)
      primals_out, residuals = OptRemat(self, fwd_traced)(*primals)
    else:
      primals_out, residuals, *_ = self.vjp_fwd((True,) * len(primals),
                                                *primals)
    tangents_out_flat = fake_linear_op(self, [True] * len(tangents),
                                       residuals, *tangents)
    tangents_out = tree_unflatten(self.out_tree, tangents_out_flat)
    return primals_out, tangents_out

  def batch_dim_rule(self, axis_data, in_dims):
    in_dims_flat = self.in_tree.flatten_up_to(in_dims)
    _, out_dims = batching.batch_jaxpr2(self.traced.jaxpr, axis_data,
                                        tuple(in_dims_flat))
    return tree_unflatten(self.out_tree, out_dims)

  def check(self, *_):
    effs = self.traced.jaxpr.effects
    disallowed = effects.custom_derivatives_allowed_effects.filter_not_in(effs)
    if disallowed:
      raise NotImplementedError(
          f'Effects not supported in `custom_vjp`: {disallowed}')

def _vjp_primal_fwd_tree_mismatch_err(self, tree):
  return (f"Custom VJP fwd rule {self.fwd.__name__} for function "
          f"{self.traced.fun_name} must produce a pair (list or tuple of "
          "length two) where the first element represents the primal output "
          "(equal to the output of the custom_vjp-decorated function "
          f"{self.traced.fun_name}) and the second element represents "
          "residuals (i.e. values stored from the forward pass for use on "
          "the backward pass), but instead the fwd rule output's first "
          "element had container/pytree structure:\n"
          f"""    {str(tree).replace("'", "")}\n"""
          f"while the custom_vjp-decorated function {self.traced.fun_name} "
          "had output container/pytree structure:\n"
          f"""    {str(self.out_tree).replace("'", "")}.""")

def _vjp_fwd_aval_mismatch_err(path, primal_aval, fwd_val):
  if not core.typematch(ty := typeof(fwd_val), primal_aval):
    raise TypeError(f"at {keystr(path)}, got fwd output type {ty.str_short()} "
                    "which doesn't match primal output type "
                    f"{primal_aval.str_short()}")

def _vjp_bwd_aval_mismatch_err(path, primal_aval, ct):
  if config.disable_bwd_checks.value:
    return
  if isinstance(ct, ad_util.Zero):
    return
  if isinstance(primal_aval, AbstractRef):
    primal_aval = primal_aval.inner_aval
  expected = primal_aval.to_ct_aval()
  ct_aval = ct.aval if isinstance(ct, ad_util.SymbolicZero) else typeof(ct)
  if (not core.typematch(expected, ct_aval) and
      not _temporary_dtype_exception(expected, ct_aval) and
      getattr(expected, 'dtype', None) is not dtypes.float0):
    result = f"at output{keystr(path)} " if path else ""
    raise ValueError(
        f"{result}the bwd rule produced an output of type "
        f"{ct_aval.str_short()} which doesn't match expected type "
        f"{expected.str_short()}")

def _replace_none(primal_in_aval, maybe_ct):
  if maybe_ct is None:
    return ad_util.Zero(primal_in_aval.to_ct_aval())
  else:
    return maybe_ct

class custom_vjp3:
  fwd: Callable | None = None
  bwd: Callable | None = None
  symz: bool = False
  opt_remat: bool = False

  def __init__(self, f, nondiff_argnums=(), nondiff_argnames=()):
    self.f = f
    self.static_argnums = _set_up_nondiff(f, nondiff_argnums, nondiff_argnames)
    update_wrapper(self, f)

  def defvjp(self, fwd, bwd, *, symbolic_zeros=False, optimize_remat=False):
    self.fwd = fwd
    self.bwd = bwd
    self.symz = symbolic_zeros
    self.opt_remat = optimize_remat
    return self

  def __call__(self, *args, **kwargs):
    if not self.fwd or not self.bwd:
      msg = (f"No VJP defined for custom_vjp function {self.f.__name__} "
             "using defvjp.")
      raise AttributeError(msg)
    args = resolve_kwargs(self.f, args, kwargs)
    if any(isinstance(args[i], core.Tracer) for i in self.static_argnums):
      raise UnexpectedTracerError(
          "custom_vjp inputs marked with nondiff_argnums must be static, "
          "not Tracers")
    traced = api.jit(self.f, static_argnums=(*self.static_argnums,)).trace(*args)
    if any(isinstance(x, core.Tracer) for x in traced._consts):
      t = next(x for x in traced._consts if isinstance(x, core.Tracer))
      # TODO(mattjj): error tracer type, value type, primal name
      raise UnexpectedTracerError(
          f"custom_vjp-decorated function {self.f} closed over a "
          f"{type(t).__name__} of type {t.aval.str_short()}, but custom_vjp "
          f"functions can't close over Tracers. Rewrite {self.f} to take it "
          "as an explicit input.")
    args = tuple(Static(x) if i in self.static_argnums else x
                 for i, x in enumerate(args))
    in_avals = tree_map(typeof, args)
    prim = CustomVJPTraced(traced, self.fwd, self.bwd, in_avals, self.symz,
                           self.static_argnums, self.opt_remat)
    return prim(*args)

  def def_vmap(self, rule, /):
    return self.f.def_vmap(rule)

  def def_transpose(self, rule, /):
    return self.f.def_transpose(rule)
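
# Example custom_vjp3 usage, mirroring the classic jax.custom_vjp API (an
# illustrative sketch; `f`, `f_fwd`, and `f_bwd` are hypothetical names):
#
#   import jax.numpy as jnp
#
#   @custom_vjp3
#   def f(x):
#     return jnp.sin(x)
#
#   def f_fwd(x):
#     return jnp.sin(x), jnp.cos(x)   # (primal output, residuals)
#
#   def f_bwd(cos_x, g):
#     return (cos_x * g,)             # tuple with one cotangent per primal arg
#
#   f.defvjp(f_fwd, f_bwd)
#   # jax.grad(f)(0.5) == jnp.cos(0.5)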

class OptRemat(VJPHiPrimitive):
  orig: CustomVJPTraced
  traced_fwd: Any

  def __init__(self, orig, traced_fwd):
    self.in_avals = orig.in_avals
    self.out_aval = traced_fwd.out_avals
    self.params = dict(orig=orig, traced_fwd=traced_fwd)
    super().__init__()

  def expand(self, *primals):
    return self.traced_fwd(*primals)

  def dce(self, used_outs):
    used_primals, used_res = used_outs
    if any(tree_leaves(used_res)):
      return True, (True, True), self  # if any res used, no dce at all
    elif any(tree_leaves(used_primals)):
      return True, (True, False), self.orig  # if only primals used, undo AD
    else:
      return False, (False, False), None

  # TODO(mattjj): jvp and transpose? does anyone rely on them?

def _set_up_nondiff(f, argnums_, argnames) -> frozenset[int]:
  argnums = set(argnums_)
  if argnames:
    sig = inspect.signature(f)  # needed for static_argnames
    argnums |= set(infer_argnums_and_argnames(sig, None, argnames)[0])
  return frozenset(argnums)

@register_static
@dataclass(frozen=True)
class Static:
  val: Any

class MappingSpec:
  pass

class HiPspec:
  def to_lo(self) -> HiPspec:
    assert False, "must override"

  def to_tangent_spec(self) -> HiPspec:
    assert False, "must override"

  def to_ct_spec(self) -> HiPspec:
    assert False, "must override"

# Logs

log_effect = box_effect

def log_extend(log, dct):
  leaves, treedef = tree_flatten(dct)
  log_extend_p.bind(log, *leaves, treedef=treedef)

def log_append(log, key, val):
  log_extend(log, {key: [val]})

def log_read(log):
  return log_read_p.bind(log)

class _LogMeta(type):
  def __instancecheck__(self, instance):
    return (super().__instancecheck__(instance) or
            isinstance(instance, core.Tracer) and
            isinstance(core.typeof(instance), LogTy))

class Log(metaclass=_LogMeta):  # noqa: F811
  _dct: dict  # dict[str, list[PyTree[Array]]]

  def __new__(cls):
    return new_log_p.bind()

  @classmethod
  def _new(cls):
    new = super().__new__(cls)
    new._dct = {}
    return new

  def cur_qdd(self):
    return ()

  append = log_append
  extend = log_extend
  read = log_read

class LogTy(MutableHiType):
  has_qdd = True
  is_writer = True

  append = core.aval_method(log_append)
  extend = core.aval_method(log_extend)
  read = core.aval_method(log_read)

  def __hash__(self):
    return hash(LogTy)

  def __eq__(self, other):
    return isinstance(other, LogTy)

  def str_short(self, short_dtypes=False, **_) -> str:  # pyrefly: ignore[bad-override]
    return 'Log'

  def to_tangent_aval(self):
    return LogTy()

  def read_loval_in(self, qdd, log):
    () = qdd
    return []

  def read_loval_out(self, qdd, log):
    () = qdd
    return FlatTree.flatten(log._dct)

  def new_from_loval(self, qdd):  # pyrefly: ignore[bad-override]
    () = qdd
    return Log._new()

  def update_from_loval2(self, qdd, log: Log, lo_ft) -> None:
    () = qdd
    updates = lo_ft.unflatten()
    for k, v in updates.items():
      log._dct.setdefault(k, []).extend(v)

register_hitype(Log, lambda _: LogTy())

class LogExtend(HiPrimitive):
  multiple_results = True  # no results
  is_effectful = lambda *_, **__: True

  def abstract_eval(self, log_ty, *val_tys, treedef):
    return [], {log_effect}

  def to_lojax(_, log, *vals, treedef):
    updates = tree_unflatten(treedef, vals)
    for k, v in updates.items():
      log._dct.setdefault(k, []).extend(v)
    return []

log_extend_p = LogExtend('log_extend')

class NewLog(HiPrimitive):
  def is_high(self) -> bool:
    return True

  def abstract_eval(self):
    ty = LogTy()
    return core.AvalQDD(ty, ()), {log_effect}  # pyrefly: ignore[bad-argument-type]

  def to_lojax(_):
    return Log._new()

new_log_p = NewLog('new_log')

def new_log():
  return new_log_p.bind()

class ReadLog(HiPrimitive):
  multiple_results = True

  def is_high(self, _) -> bool:
    return True

  def abstract_eval(self, log_qdd):
    raise Exception

  def to_lojax(_, log):
    return list(FlatTree.flatten(log._dct))

log_read_p = ReadLog('log_read')
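
# Example Log usage (an illustrative sketch; eager and traced behavior depend
# on the surrounding hijax support, and `loss_fn` is a hypothetical name).
# A Log is a write-mostly container: `log_append`/`log_extend` record values
# as effects, and the accumulated dict can be read back afterwards:
#
#   import jax.numpy as jnp
#
#   log = new_log()
#
#   def loss_fn(x, log):
#     loss = jnp.sum(x * x)
#     log_append(log, 'loss', loss)   # records under key 'loss'
#     return loss
#
#   loss_fn(jnp.arange(3.), log)
#   # log._dct accumulates {'loss': [...]} entries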