# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable, Sequence
import contextlib
import functools
import itertools as it
from functools import partial
from typing import Any

from jax._src import config
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import (tree_flatten, tree_unflatten,
                                register_pytree_node, PyTreeDef)
from jax._src import mesh as mesh_lib
from jax._src import core
from jax._src import source_info_util
from jax._src.ad_util import (
    add_jaxvals, replace_internal_symbolic_zeros,
    replace_rule_output_symbolic_zeros, Zero, zeros_like_aval,
    SymbolicZero, add_jaxvals_p, p2tz, p2cz)  # noqa: F401
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs, debug_info
from jax._src.core import (Trace, Tracer, typeof, call_p, Primitive, Literal)
from jax._src.dtypes import dtype, float0
from jax._src.state.types import AbstractRef
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
                           weakref_lru_cache, partition_list, subs_list2,
                           foreach)

Array = Any
Ref = Any

# Length-checked variants; mismatched lengths raise instead of truncating.
zip = safe_zip
map = safe_map


def identity(x): return x

def _update_annotation(
    f: lu.WrappedFun,
    orig_type: tuple[core.AbstractValue, ...] | None,
    nonzeros: list[bool]
  ) -> lu.WrappedFun:
  # Extend an input-type annotation with the tangent types of the
  # nonzero-tangent inputs, so the jvp-transformed function's annotation
  # matches its (primals + nonzero tangents) calling convention.
  if orig_type is None:
    return f
  tan_types = [aval.to_tangent_aval()
               for nz, aval in zip(nonzeros, orig_type) if nz]
  return lu.annotate(f, (*orig_type, *tan_types))

def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
        transform_stack=True) -> Any:
  """Apply the JVP transformation to `fun`.

  Returns a transformed function of (primals, tangents) -> (out_primals,
  out_tangents); with has_aux=True also returns a thunk for the aux output.
  """
  if not has_aux:
    return jvpfun(jvp_subtrace(fun), instantiate, transform_stack)
  else:
    fun, aux = jvp_subtrace_aux(fun)
    return jvpfun(fun, instantiate, transform_stack), aux

@lu.transformation2
def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
  tag = core.TraceTag()
  # float0 tangents carry no information; replace them with symbolic zeros.
  tangents = [p2tz(t) if not isinstance(t, Zero)
              and isinstance(typeof(t), core.ShapedArray)
              and dtype(t) == float0 else t for t in tangents]
  ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
         else contextlib.nullcontext())
  with ctx:
    out_primals, out_tangents = f(tag, primals, tangents)
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [instantiate_zeros(t) if inst else t
                  for t, inst in zip(out_tangents, instantiate)]
  return out_primals, out_tangents

# The result of `f` should be a `FlatTree`
def linearize_subtrace_2(f: Callable, is_vjp: bool, tag: core.TraceTag,
                         nzs_in: Sequence[bool],
                         debug_info: core.DebugInfo, primals):
  # Linearize `f` under an existing trace: stage out the tangent computation
  # into a jaxpr while evaluating primals in the parent trace.
  source_info = source_info_util.current()
  with core.take_current_trace() as parent_trace:
    tangent_trace = pe.DynamicJaxprTrace(debug_info, auto_dce=True)
    tangent_trace.tag = tag
    linearize_trace = LinearizeTrace(parent_trace, tangent_trace, is_vjp)
    tracers = [LinearizeTracer(linearize_trace, p,
                               tangent_trace.new_arg(
                                   typeof(p).to_tangent_aval(), source_info))
               if nz else p for p, nz in zip(primals, nzs_in)]
    with core.set_current_trace(linearize_trace, check_leaks=True):
      ans = f(*tracers)
      out_primals, out_tangents = ans.map(
          linearize_trace.to_primal_tangent_pair).unzip2()
      del linearize_trace, ans, tracers
  nzs_out = tuple(type(t) is not Zero for t in out_tangents)
  out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
  out_tangents = map(partial(tangent_trace.to_jaxpr_tracer,
                             source_info=source_info), out_tangents)
  jaxpr, consts = tangent_trace.to_jaxpr(
      out_tangents, debug_info.with_unknown_names(), source_info)
  # Constants that are tracers of this tag are environment values, not
  # residuals computed by the primal pass.
  which_env = [(isinstance(c, pe.DynamicJaxprTracer) and
                getattr(c._trace, 'tag', None) is tag) for c in consts]
  jaxpr = pe.move_envvars(jaxpr, tuple(which_env))
  res, env = partition_list(which_env, consts)
  residual_avals = map(typeof, res)
  # Which residuals are just forwarded inputs? Check object id.
  id_map = {id(p): i for i, p in enumerate(primals)}
  in_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
  # Which residuals are already primal outputs? Check object id.
  id_map = {id(p): i for i, p in enumerate(out_primals)}
  out_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
  # Prune residuals not to include forwarded primal inputs or outputs.
  res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd)
         if f1 is None and f2 is None]
  aux = (residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd)
  return res, out_primals, aux

@lu.transformation_with_aux2
def linearize_subtrace(_f: Callable, _store: lu.Store, _is_vjp: bool,
                       _tag: core.TraceTag,
                       nzs_in: Sequence[bool],
                       debug_info: core.DebugInfo,
                       *primals, **params):
  # WrappedFun twin of linearize_subtrace_2: same residual/forwarding logic,
  # but results flow through a Store and a flat (*res, *out_primals) return.
  source_info = source_info_util.current()
  with core.take_current_trace() as parent_trace:
    tangent_trace = pe.DynamicJaxprTrace(debug_info, auto_dce=True)
    tangent_trace.tag = _tag
    linearize_trace = LinearizeTrace(parent_trace, tangent_trace, _is_vjp)
    tracers = [LinearizeTracer(linearize_trace, p,
                               tangent_trace.new_arg(
                                   typeof(p).to_tangent_aval(), source_info))
               if nz else p for p, nz in zip(primals, nzs_in)]
    with core.set_current_trace(linearize_trace, check_leaks=True):
      ans = _f(*tracers)
      out_primals, out_tangents = unzip2(
          map(linearize_trace.to_primal_tangent_pair, ans))
      del linearize_trace, ans, tracers
  nzs_out = tuple(type(t) is not Zero for t in out_tangents)
  out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
  out_tangents = map(partial(tangent_trace.to_jaxpr_tracer,
                             source_info=source_info), out_tangents)
  jaxpr, consts = tangent_trace.to_jaxpr(
      out_tangents, debug_info.with_unknown_names(), source_info)
  which_env = [(isinstance(c, pe.DynamicJaxprTracer) and
                getattr(c._trace, 'tag', None) is _tag) for c in consts]
  jaxpr = pe.move_envvars(jaxpr, tuple(which_env))
  res, env = partition_list(which_env, consts)
  residual_avals = map(typeof, res)
  # Which residuals are just forwarded inputs? Check object id.
  id_map = {id(p): i for i, p in enumerate(primals)}
  in_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
  # Which residuals are already primal outputs? Check object id.
  id_map = {id(p): i for i, p in enumerate(out_primals)}
  out_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
  # Prune residuals not to include forwarded primal inputs or outputs.
  res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd)
         if f1 is None and f2 is None]
  _store.store((residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd))
  return *res, *out_primals

# The result of `f` should be a `FlatTree`
def jvp_subtrace_2(f: Callable, tag: core.TraceTag, primals, tangents):
  with core.take_current_trace() as parent_trace:
    trace = JVPTrace(parent_trace, tag)
    in_tracers = [maybe_jvp_tracer(trace, x, t)
                  for x, t in zip(primals, tangents)]
    with core.set_current_trace(trace):
      ans = f(*in_tracers)
    return ans.map(trace.to_primal_tangent_pair).unzip2()

@lu.transformation2
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
  with core.take_current_trace() as parent_trace:
    trace = JVPTrace(parent_trace, tag)
    in_tracers = [maybe_jvp_tracer(trace, x, t)
                  for x, t in zip(primals, tangents)]
    with core.set_current_trace(trace):
      ans = f(*in_tracers)
    out = unzip2(map(trace.to_primal_tangent_pair, ans))
  return out

@lu.transformation_with_aux2
def jvp_subtrace_aux(f, store, tag, primals, tangents):
  with core.take_current_trace() as parent_trace:
    trace = JVPTrace(parent_trace, tag)
    with core.set_current_trace(trace):
      ans, aux = f(*(map(partial(maybe_jvp_tracer, trace), primals, tangents)))
      out_primals, out_tangents = unzip2(
          map(trace.to_primal_tangent_pair, ans))
      # Aux output carries primal values only; strip our tracers.
      aux_primals = [x.primal
                     if isinstance(x, JVPTracer) and x._trace.tag is tag
                     else x for x in aux]
  store.store(aux_primals)
  return out_primals, out_tangents

def linearize_jaxpr(
    jaxpr: core.ClosedJaxpr,
    nonzeros: Sequence[bool],
    instantiate: bool | Sequence[bool] = False,
    allow_fwds: bool | Sequence[bool] = True,
    *,
    is_vjp: bool,
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None],
           core.ClosedJaxpr]:
  """Linearize a closed jaxpr into primal and tangent jaxprs.

  Normalizes the bool-or-sequence arguments to tuples so the cached helper
  sees hashable, canonical inputs.
  """
  if type(allow_fwds) is bool:
    allow_fwds = (allow_fwds,) * (len(jaxpr.consts) + len(jaxpr.jaxpr.invars))
  assert len(allow_fwds) == (len(jaxpr.consts) + len(jaxpr.jaxpr.invars))
  if type(instantiate) is bool:
    instantiate = (instantiate,) * len(jaxpr.jaxpr.outvars)
  assert len(instantiate) == len(jaxpr.jaxpr.outvars)
  return _linearize_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate),
                          tuple(allow_fwds), is_vjp)

@weakref_lru_cache
@source_info_util.reset_name_stack()
def _linearize_jaxpr(
    jaxpr: core.ClosedJaxpr,
    nonzeros: tuple[bool, ...],
    instantiate: tuple[bool, ...],
    allow_fwds: tuple[bool, ...],
    is_vjp: bool,
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None],
           core.ClosedJaxpr]:
  dbg = jaxpr.jaxpr.debug_info
  config.enable_checks.value and dbg.assert_arg_names(len(nonzeros))
  primal_trace = pe.DynamicJaxprTrace(dbg)
  tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True)
  tag = core.TraceTag()
  tangent_trace.tag = tag
  lin_trace = LinearizeTrace(primal_trace, tangent_trace, is_vjp=is_vjp)

  def new_arg(trace, primal_aval, nz, source_info):
    # Pair a fresh primal tracer with either a fresh tangent tracer (nonzero)
    # or a symbolic Zero of the tangent type.
    primal = primal_trace.new_arg(primal_aval, source_info)
    tangent_aval = primal_aval.to_tangent_aval()
    tangent = (tangent_trace.new_arg(tangent_aval, source_info) if nz
               else Zero(tangent_aval))
    return LinearizeTracer(trace, primal, tangent)

  source_info = source_info_util.current()
  tracers = [new_arg(lin_trace, a, nz, source_info)
             for (a, nz) in zip(jaxpr.in_aval_qdds, nonzeros)]
  in_primals = [t.primal for t in tracers]
  with core.set_current_trace(lin_trace, check_leaks=True):
    ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers)
    out_primals, out_tangents = unzip2(
        map(lin_trace.to_primal_tangent_pair, ans))
    out_tangents = [instantiate_zeros(t) if inst else t
                    for t, inst in zip(out_tangents, instantiate)]
    del lin_trace, ans, new_arg, tracers
  # pe._check_no_returned_refs(debug_info, out_tangents)
  nzs_out = [type(t) is not Zero for t in out_tangents]
  out_tangents = [tangent_trace.to_jaxpr_tracer(t, source_info)
                  for (nz, t) in zip(nzs_out, out_tangents) if nz]
  tangent_jaxpr, tangent_consts = tangent_trace.to_jaxpr(
      out_tangents, dbg.with_unknown_names(), source_info)
  tangent_trace.invalidate()
  tangent_jaxpr, tangent_consts = _dce_consts(tangent_jaxpr, tangent_consts)
  tangent_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(tangent_jaxpr))
  # Residuals that are just (allowed) forwarded inputs are recorded as index
  # forwards rather than materialized; check object identity.
  fwd_inputs = (*jaxpr.consts, *in_primals)
  id_map = {id(x): i for i, (x, a) in enumerate(zip(fwd_inputs, allow_fwds))
            if a}
  fwds = [id_map.get(id(c)) for c in tangent_consts]
  tangent_consts = [c for c, f in zip(tangent_consts, fwds) if f is None]
  del in_primals
  # pe._check_no_returned_refs(debug_info, out_primals)
  primals_and_residuals = *out_primals, *tangent_consts
  primals_and_residuals = map(
      partial(primal_trace.to_jaxpr_tracer, source_info=source_info),
      primals_and_residuals)
  primal_jaxpr, primal_consts = primal_trace.to_jaxpr(
      primals_and_residuals, dbg.with_unknown_names(), source_info)
  primal_trace.invalidate()
  primal_jaxpr, primal_consts = _dce_consts(primal_jaxpr, primal_consts)
  primal_jaxpr = core.ClosedJaxpr(primal_jaxpr, primal_consts)
  num_residuals_out = len(tangent_consts)
  return primal_jaxpr, num_residuals_out, nzs_out, fwds, tangent_jaxpr

def _dce_consts(jaxpr, consts):
  # Drop constvars (and their consts) not needed for the jaxpr's outputs.
  jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
      jaxpr, [True] * len(jaxpr.outvars),
      [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
  return jaxpr, [c for c, used in zip(consts, used_consts) if used]

def direct_linearize(traceable, primals, *, has_aux, is_vjp):
  # Linearize by tracing once with LinearizeTrace, rather than via the
  # jvp + partial-eval path in `linearize` below.
  dbg = traceable.debug_info.with_unknown_names()
  tag = core.TraceTag()
  with core.take_current_trace() as parent_trace:
    source_info = source_info_util.current()
    tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True)
    tangents = [tangent_trace.new_arg(typeof(p).to_tangent_aval(), source_info)
                for p in primals]
    # float0 tangents carry no information; replace with symbolic zeros.
    tangents = [p2tz(t) if not isinstance(t, Zero)
                and isinstance(typeof(t), core.ShapedArray)
                and dtype(t) == float0 else t for t in tangents]
    tangent_trace.tag = tag
    lin_trace = LinearizeTrace(parent_trace, tangent_trace, is_vjp)
    tracers = [LinearizeTracer(lin_trace, p, t)
               for p, t in zip(primals, tangents)]
    tracers = [t.full_lower() for t in tracers]
    with (core.set_current_trace(lin_trace),
          source_info_util.transform_name_stack('jvp')):
      if has_aux:
        ans, aux = traceable.call_wrapped(*tracers)
        aux = [x.primal
               if type(x) is LinearizeTracer and x._trace.tag is tag else x
               for x in aux]
      else:
        ans = traceable.call_wrapped(*tracers)
        aux = None
      out_primals, out_tangents = unzip2(
          map(lin_trace.to_primal_tangent_pair, ans))
      del lin_trace, ans, tracers
  out_nzs = [type(t) is not Zero for t in out_tangents]
  out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
  out_nz_tangents = map(partial(tangent_trace.to_jaxpr_tracer,
                                source_info=source_info), out_nz_tangents)
  jaxpr, consts = tangent_trace.to_jaxpr(out_nz_tangents, dbg, source_info)
  tangent_trace.invalidate()
  config.enable_checks.value and core.check_jaxpr(jaxpr)
  jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
      jaxpr, [True] * len(jaxpr.outvars),
      [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
  consts = [c for c, used in zip(consts, used_consts) if used]
  out_tangents_pvals = [pe.PartialVal.unknown(core.typeof(t)) if nz
                        else pe.PartialVal.known(zeros_like_aval(t.aval))
                        for t, nz in zip(out_tangents, out_nzs)]
  if has_aux:
    return out_primals, out_tangents_pvals, jaxpr, consts, aux
  else:
    return out_primals, out_tangents_pvals, jaxpr, consts

def linearize(traceable: lu.WrappedFun, *primals, has_aux=False, is_vjp=False):
  """Linearize `traceable` at `primals`.

  Dispatches to direct_linearize when enabled; otherwise runs jvp and
  partially evaluates it with known primals / unknown tangents.
  """
  if config.use_direct_linearize.value:
    return direct_linearize(traceable, primals, has_aux=has_aux, is_vjp=is_vjp)
  if has_aux:
    jvpfun, aux = jvp(traceable, has_aux=True)
  else:
    jvpfun = jvp(traceable)
    aux = None
  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(typeof(p).to_tangent_aval())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  if any(not out_primal_pval.is_known()
         for out_primal_pval in out_primals_pvals):
    raise ValueError(
        "Linearization failed to produce known values for all output primals. "
        "This is typically caused by attempting to differentiate a function "
        "using an operation that does not support reverse-mode autodiff.")
  out_primals_consts = [pval.get_known() for pval in out_primals_pvals]
  if not has_aux:
    assert aux is None
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    assert aux is not None
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()

class UndefinedPrimal:
  """Placeholder for a primal value unknown during transposition."""
  __slots__ = ['aval']
  def __init__(self, aval):
    self.aval = aval
  def __repr__(self):
    return f'UndefinedPrimal({self.aval})'

def is_undefined_primal(x):
  return type(x) is UndefinedPrimal

register_pytree_node(UndefinedPrimal,
                     lambda z: ((), z.aval),
                     lambda aval, _: UndefinedPrimal(aval))

def get_primitive_transpose(p):
  try:
    return primitive_transposes[p]
  except KeyError as err:
    raise NotImplementedError(
        "Transpose rule (for reverse-mode differentiation) for '{}' "
        "not implemented".format(p)) from err

def backward_pass3(
    jaxpr: core.Jaxpr, transform_stack: bool,
    consts: Sequence[Array],
    primals_in: Sequence[Array | Ref | GradAccum],
    cotangents_in: Sequence[Array]) -> None:
  """Transpose-evaluate `jaxpr`, accumulating cotangents into GradAccums.

  Forward sweep: evaluate equations whose inputs hold no GradAccum; collect
  the rest (the linear part) in `lin_eqns`. Backward sweep: run `lin_eqns`
  in reverse, popping output cotangents and applying transpose rules.
  """
  if all(type(ct) is Zero for ct in cotangents_in) and not jaxpr.effects:
    return
  env: dict = dict(zip((*jaxpr.constvars, *jaxpr.invars),
                       (*consts, *primals_in)))

  def read(x: core.Atom) -> Array | GradAccum:
    return x.val if isinstance(x, Literal) else env[x]

  lin_eqns = []
  for eqn in jaxpr.eqns:
    # TODO(mattjj): shorten the lifetime of the reference accumulators, as it
    # is longer than necessary.
    if eqn.primitive.ref_primitive:
      v, = eqn.outvars
      lin_eqns.append(eqn)
      if eqn.primitive is core.ref_p or eqn.primitive is core.empty_ref_p:
        env[v] = RefAccum(v.aval.inner_aval)  # pyrefly: ignore[missing-attribute]
      elif eqn.primitive is core.freeze_p:
        env[v] = ValAccum(v.aval)
      elif eqn.primitive is core.accum_grad_in_ref_p:
        env[v] = RefAccum(v.aval)
      else:
        assert False
    elif any(isinstance(read(x), GradAccum) for x in eqn.invars):
      # Depends on a differentiated value: defer to the backward sweep.
      for v in eqn.outvars:
        env[v] = ValAccum(v.aval)
      lin_eqns.append(eqn)
    else:
      # Purely primal equation: evaluate eagerly.
      params = eqn.primitive.get_bind_params(eqn.params)
      with eqn.ctx.manager, _name_stack_ctx(eqn.source_info):
        ans = eqn.primitive.bind(*map(read, eqn.invars), **params)
      ans = ans if eqn.primitive.multiple_results else [ans]
      foreach(env.setdefault, eqn.outvars, ans)

  ctx = (source_info_util.transform_name_stack('transpose')
         if transform_stack else contextlib.nullcontext())
  for acc, ct in zip(map(read, jaxpr.outvars), cotangents_in):
    if isinstance(acc, GradAccum):
      acc.accum(ct)  # jaxpr.outvars can have Literals, env can have inst zeros
  with ctx:
    for eqn in lin_eqns[::-1]:
      with eqn.ctx.manager, _name_stack_ctx(eqn.source_info):
        if eqn.primitive is core.empty_ref_p:
          env.pop(eqn.outvars[0]).freeze()
        elif eqn.primitive.ref_primitive:
          ct = env.pop(eqn.outvars[0]).freeze()
          acc = read(eqn.invars[0])
          if isinstance(acc, GradAccum):
            acc.accum(ct)
        else:
          cts_in = [env.pop(v).freeze() for v in eqn.outvars]
          if not eqn.primitive.multiple_results:
            cts_in, = cts_in
          if eqn.primitive in fancy_transposes:
            # Fancy rules accumulate into the GradAccums themselves.
            rule = fancy_transposes[eqn.primitive]
            rule(cts_in, *map(read, eqn.invars), **eqn.params)
          else:
            rule = get_primitive_transpose(eqn.primitive)
            primals = map(read, eqn.invars)
            up = lambda x: (UndefinedPrimal(x.aval)
                            if isinstance(x, GradAccum) else x)
            if eqn.primitive.call_primitive:
              # TODO(mattjj,dougalm): remove this path by revising call/map trans
              cts_in_avals = [v.aval for v in eqn.outvars]
              params = dict(eqn.params)
              call_jaxpr = params.pop('call_jaxpr')
              cts_out = rule(params, call_jaxpr, map(up, primals), cts_in,
                             cts_in_avals)
            else:
              cts_out = rule(cts_in, *map(up, primals), **eqn.params)
            for x, ct in zip(primals, cts_out):
              if isinstance(x, GradAccum):
                x.accum(ct)

def _name_stack_ctx(src_info):
  stack = source_info_util.current_name_stack() + src_info.name_stack
  return source_info_util.user_context(src_info.traceback, name_stack=stack)

class GradAccum:
  """Abstract cotangent accumulator used by the backward pass."""
  aval: core.AbstractValue
  def accum(self, x) -> None:
    assert False
  def freeze(self) -> Array | Zero:
    assert False

class RefAccum(GradAccum):
  """Accumulates cotangents into a mutable ref, created lazily on first use."""
  aval: core.AbstractValue
  ref: Ref | None

  def __init__(self, aval, ref=None):
    self.aval = aval
    self.ref = ref

  def accum(self, x):
    # NOTE(review): `x is not Zero` is an identity check against the Zero
    # class object, so it holds for any instance; the isinstance check below
    # is what actually filters symbolic zeros — confirm intent.
    assert x is not Zero
    if isinstance(x, Zero) or x is None:
      return
    if self.ref is None:
      self.ref = core.new_ref(x)
    else:
      ct_check(self, x)
      self.ref.addupdate(x)

  def freeze(self):
    if self.ref is None:
      return Zero(self.aval)
    else:
      return core.freeze(self.ref)

  def inst(self):
    # Force the backing ref to exist (filled with zeros) and return self.
    if self.ref is None:
      self.ref = core.new_ref(zeros_like_aval(self.aval))
    return self

class ValAccum(GradAccum):
  """Accumulates cotangents by value with `add_tangents`."""
  aval: core.AbstractValue
  val: Array | Zero

  def __init__(self, aval, val=None):
    self.aval = aval
    self.val = Zero(aval.to_ct_aval()) if val is None else val
    ct_check(self, self.val)

  def __repr__(self):
    return f"ValAccum({self.aval})"

  def accum(self, x):
    if x is not None:
      ct_check(self, x)
      self.val = add_tangents(self.val, x)

  def freeze(self):
    return self.val

def ct_check(primal, ct):
  # Verify a cotangent's type matches the accumulator's expected cotangent
  # type; skipped entirely when backward-pass checks are disabled.
  if config.disable_bwd_checks.value:
    return
  ct_aval = ct.aval if type(ct) is Zero else typeof(ct)
  ct_aval_expected = primal.aval.to_ct_aval()
  if not core.typematch(ct_aval, ct_aval_expected, no_dtype_check=True):
    # TODO(yashkatariya, mattjj): Add primitive name here for better error?
    raise ValueError(
        f"Expected cotangent type {ct_aval_expected.str_short()} for primal "
        f"type {primal.aval.str_short()}, but got {ct_aval.str_short()}")

class NullAccum(GradAccum):
  """Accumulator that discards everything; freezing it is an error."""
  aval: core.AbstractValue
  def __init__(self, aval):
    self.aval = aval
  def __repr__(self):
    return f"NullAccum({self.aval})"
  def accum(self, x):
    return
  def freeze(self):
    assert False

fancy_transposes: dict[core.Primitive, Callable] = {}

def project_accums(args):
  # Split mixed (values + accumulators) into concrete results plus a spec
  # tuple that records how to rebuild the original structure.
  result, specs = [], []
  for x in args:
    if isinstance(x, ValAccum):
      specs.append((ValAccum, x.aval))
    elif isinstance(x, RefAccum):
      result.append(x.inst().ref)
      specs.append((RefAccum, x.aval))
    elif isinstance(x, NullAccum):
      specs.append((NullAccum, x.aval))
    else:
      result.append(x)
      specs.append((None, typeof(x)))
  return result, tuple(specs)

def unproject_accums(specs, result):
  # Inverse of project_accums: rebuild the mixed list from specs + results.
  args, result_ = [], iter(result)
  for k, aval in specs:
    if k is ValAccum:
      args.append(ValAccum(aval))
    elif k is RefAccum:
      args.append(RefAccum(aval, next(result_)))
    elif k is NullAccum:
      args.append(NullAccum(aval))
    elif k is None:
      args.append(next(result_))
    else:
      assert False
  assert next(result_, None) is None
  return args

def accum_typeof(x):
  if isinstance(x, GradAccum):
    return x.aval
  else:
    return typeof(x)

# TODO(mattjj): this is for for backward (get it?) compatibility. Remove, maybe.
def backward_pass(jaxpr, transform_stack: bool, consts, primals_in, cts_in):
  # Legacy entry point: wrap UndefinedPrimals in ValAccums, run the
  # accumulator-based backward_pass3, then read the accumulated cotangents.
  primals_in = [ValAccum(x.aval) if isinstance(x, UndefinedPrimal) else x
                for x in primals_in]
  backward_pass3(jaxpr, transform_stack, consts, primals_in, cts_in)
  return [x.freeze() if isinstance(x, ValAccum) else p2cz(x)
          for x in primals_in]

def closed_backward_pass(jaxpr: core.ClosedJaxpr, transform_stack,
                         primals_in, cotangents_in):
  return backward_pass(jaxpr.jaxpr, transform_stack, jaxpr.consts,
                       primals_in, cotangents_in)

@lu.transformation_with_aux2
def nonzero_tangent_outputs(f, store, *args, **kwargs):
  results = (_, tangents_out) = f(*args, **kwargs)
  store.store([type(r) is not Zero for r in tangents_out])
  return results


class JVPTrace(Trace):
  """Trace that evaluates primals and tangents together (forward mode)."""

  def __init__(self, parent_trace, tag):
    super().__init__()
    self.tag = tag
    self.parent_trace = parent_trace
    self.requires_low = False

  def to_primal_tangent_pair(self, val):
    # Values from other traces (or constants) get a symbolic zero tangent.
    if isinstance(val, JVPTracer) and val._trace.tag is self.tag:
      return (val.primal, val.tangent)
    else:
      tangent_zero = p2tz(val)
      return (val, tangent_zero)

  def process_primitive(self, primitive, tracers, params, /):
    primals_in, tangents_in = unzip2(
        map(self.to_primal_tangent_pair, tracers))
    # All-zero tangents (and no refs involved): skip differentiation.
    if (all(type(t) is Zero for t in tangents_in) and
        primitive is not core.ref_p and
        primitive is not core.empty_ref_p and
        not any(isinstance(typeof(x), AbstractRef) for x in primals_in)):
      avals = tuple(core.typeof(x) for x in primals_in)
      return primitive.bind_with_trace(self.parent_trace, primals_in, avals,
                                       params)
    jvp = primitive_jvps.get(primitive)
    if not jvp:
      msg = f"Differentiation rule for '{primitive}' not implemented"
      raise NotImplementedError(msg)
    with core.set_current_trace(self.parent_trace):
      primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    if primitive.multiple_results:
      return [maybe_jvp_tracer(self, x, t)
              for x, t in zip(primal_out, tangent_out)]
    else:
      return maybe_jvp_tracer(self, primal_out, tangent_out)

  def cur_qdd(self, x):
    p, _ = self.to_primal_tangent_pair(x)
    with core.set_current_trace(self.parent_trace):
      return core.cur_qdd(p)

  def process_call(self, call_primitive, f, tracers, params, /):
    assert call_primitive.multiple_results
    primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
    which_nz = [type(t) is not Zero for t in tangents]
    # None stands in for symbolic-zero tangents across the call boundary.
    tangents = [t if type(t) is not Zero else None for t in tangents]
    args, in_tree = tree_flatten((primals, tangents))
    f_jvp = jvp_subtrace(f, self.tag)
    f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp)
    f_jvp, out_tree = traceable(f_jvp, in_tree)
    update_params = call_param_updaters.get(call_primitive)
    new_params = update_params(params, which_nz) if update_params else params
    fun_and_args = _update_annotation(f_jvp.with_unknown_names(), f.in_type,
                                      which_nz)
    new_params = dict(new_params, subfuns=(fun_and_args,))
    avals = tuple(core.typeof(x) for x in args)
    result = call_primitive.bind_with_trace(self.parent_trace, args, avals,
                                            new_params)
    primal_out, tangent_out = tree_unflatten(out_tree(), result)
    tangent_out = [p2tz(p) if t is None else t
                   for p, t in zip(primal_out, tangent_out)]
    return [maybe_jvp_tracer(self, p, t)
            for p, t in zip(primal_out, tangent_out)]

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, *,
                              symbolic_zeros):
    primals_in, tangents_in = unzip2(
        map(self.to_primal_tangent_pair, tracers))
    if all(type(t) is Zero for t in tangents_in):
      avals = tuple(core.typeof(x) for x in primals_in)
      return primitive.bind_with_trace(
          self.parent_trace, tuple(primals_in), avals,
          dict(subfuns=(fun, jvp), symbolic_zeros=symbolic_zeros))
    with core.set_current_trace(self.parent_trace):
      if not symbolic_zeros:
        tangents_in = map(instantiate_zeros, tangents_in)
      else:
        tangents_in = map(replace_internal_symbolic_zeros, tangents_in)
      outs = jvp.call_wrapped(*(tuple(primals_in) + tuple(tangents_in)))
    # The custom jvp returns primals and tangents concatenated.
    primals_out, tangents_out = split_list(outs, [len(outs) // 2])
    tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out)
    return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, *,
                              out_trees, symbolic_zeros):
    primals_in, tangents_in = unzip2(
        map(self.to_primal_tangent_pair, tracers))
    if all(type(t) is Zero for t in tangents_in):
      avals = tuple(core.typeof(x) for x in primals_in)
      return primitive.bind_with_trace(
          self.parent_trace, tuple(primals_in), avals,
          dict(subfuns=(fun, fwd, bwd), out_trees=out_trees,
               symbolic_zeros=symbolic_zeros))
    fwd_in = [(p, type(t) is not Zero)
              for p, t in zip(primals_in, tangents_in)]
    fwd_in = [x for pair in fwd_in for x in pair]  # flatten
    with core.set_current_trace(self.parent_trace):
      res_and_primals_out = fwd.call_wrapped(*fwd_in)
    _, res_tree, input_fwds = out_trees()
    # Residuals that are forwarded inputs are reconstructed by index.
    num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds)
    res_out, primals_out = split_list(res_and_primals_out, [num_res_out])
    res_out_ = iter(res_out)
    res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds]
    assert next(res_out_, None) is None
    avals_out = [core.typeof(x).to_tangent_aval() for x in primals_out]
    in_zeros = [type(t) is Zero for t in tangents_in]
    nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z]
    with core.set_current_trace(self.parent_trace):
      tangents_out = custom_lin_p.bind(
          *res, *nz_tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
          out_avals=avals_out, symbolic_zeros=symbolic_zeros,
          in_zeros=in_zeros)
    return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out)

def maybe_jvp_tracer(trace, primal, tangent):
  # Zero (or float0) tangents don't need a tracer at all.
  if (type(tangent) is Zero or
      isinstance(typeof(tangent), core.ShapedArray)
      and dtype(tangent) == float0):
    return primal
  else:
    return JVPTracer(trace, primal, tangent)

class JVPTracer(Tracer[JVPTrace]):
  """Tracer pairing a primal value with its tangent."""
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    if config.enable_checks.value:
      _primal_tangent_shapes_match(primal, tangent)
    super().__init__(trace, typeof(primal))
    self.primal = primal
    self.tangent = tangent

  def _short_repr(self):
    pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x)
    primal, tangent = pp(self.primal), pp(self.tangent)
    return f'JVPTracer({primal=!s}, {tangent=!s})'

  def cur_qdd(self):
    return core.cur_qdd(self.primal)

  def full_lower(self):
    if type(self.tangent) is Zero:
      return core.full_lower(self.primal)
    else:
      return self

  def to_concrete_value(self):
    return core.to_concrete_value(self.primal)

  def get_referent(self):
    return core.get_referent(self.primal)

  def type_state(self):
    return self.primal.type_state()

def _primal_tangent_shapes_match(primal, tangent):
  # Debug-only consistency check between a primal and its tangent.
  if type(tangent) is not Zero:
    primal_aval = typeof(primal).strip_weak_type()
    tangent_aval = typeof(tangent).strip_weak_type()
    if not isinstance(primal_aval, core.ShapedArray):
      return  # TODO(mattjj,dougalm)
    assert core.definitely_equal_shape(primal_aval.shape,
                                       tangent_aval.shape), (
        primal_aval.shape, tangent_aval.shape)
    expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(
        primal_aval.dtype)
    assert expected_tangent_dtype == tangent_aval.dtype, (
        expected_tangent_dtype, tangent_aval.dtype)
    if (not primal_aval.sharding.mesh.empty and
        not tangent_aval.sharding.mesh.empty and
        (primal_aval.sharding.mesh._any_axis_explicit or
         tangent_aval.sharding.mesh._any_axis_explicit)):
      assert primal_aval.sharding == tangent_aval.sharding, (
          primal_aval.sharding, tangent_aval.sharding)

call_param_updaters: dict[core.Primitive, Callable] = {}
call_linearize_param_updaters: dict[core.Primitive, Callable] = {}
call_transpose_param_updaters: dict[core.Primitive, Callable] = {}


# -------------------- Linearize trace --------------------

class LinearizeTrace(Trace):
  """Trace that evaluates primals eagerly while staging tangents to a jaxpr."""
  parent_trace: core.Trace | None
  tangent_trace: core.Trace
  is_vjp: bool
  requires_low: bool
  _name_stack_prefix_len: int

  def __init__(self, parent_trace, tangent_trace, is_vjp):
    super().__init__()
    if not hasattr(tangent_trace, "tag"):
      raise RuntimeError("Internal: LinearizeTrace.__init__ requires tangent_trace.tag to be defined.")
    self.parent_trace = parent_trace
    self.tangent_trace = tangent_trace
    self.is_vjp = is_vjp
    self.requires_low = False
    self._name_stack_prefix_len = len(source_info_util.current_name_stack())

  @property
  def tag(self) -> core.TraceTag:
    # The trace's identity tag is the tangent trace's tag.
    assert hasattr(self.tangent_trace, "tag")
    return self.tangent_trace.tag

  def _name_stack_suffix(self):
    return source_info_util.current_name_stack()[self._name_stack_prefix_len:]

  def to_primal_tangent_pair(self, val):
    if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
      return (val.primal, val.tangent)
    else:
      tangent_zero = p2tz(val)
      return (val, tangent_zero)

  def process_primitive(self, primitive, tracers, params, /):
    primals_in, tangents_in = unzip2(
        map(self.to_primal_tangent_pair, tracers))
    tangent_nzs = [type(t) is not Zero for t in tangents_in]
    # All-zero tangents (and no refs involved): skip differentiation.
    if (all(type(t) is Zero for t in tangents_in) and
        primitive is not core.ref_p and
        primitive is not core.empty_ref_p and
        not any(isinstance(typeof(x), AbstractRef) for x in primals_in)):
      avals = tuple(core.typeof(x) for x in primals_in)
      return primitive.bind_with_trace(self.parent_trace, primals_in, avals,
                                       params)
    fallback = partial(fallback_linearize_rule, primitive)
    lin = primitive_linearizations.get(primitive, fallback)
    with core.set_current_trace(self.parent_trace):
      primal_out, tangent_nzs_out, residuals, linearized = lin(
          self.is_vjp, tangent_nzs, *primals_in, **params)
    with (core.set_current_trace(self.tangent_trace),
          source_info_util.set_name_stack(self._name_stack_suffix())):
      tangent_out = linearized(residuals, *tangents_in)
    if primitive.multiple_results:
      return [maybe_linearize_tracer(self, x, nz, t)
              for x, nz, t in zip(primal_out, tangent_nzs_out, tangent_out)]
    else:
      return maybe_linearize_tracer(self, primal_out, tangent_nzs_out,
                                    tangent_out)

  def cur_qdd(self, x):
    p, _ = self.to_primal_tangent_pair(x)
    with core.set_current_trace(self.parent_trace):
      return core.cur_qdd(p)

  def process_custom_jvp_call(self, primitive, fun: lu.WrappedFun,
                              jvp: lu.WrappedFun, tracers, /, *,
                              symbolic_zeros: bool):
    primals_in, tangents_in = unzip2(
        map(self.to_primal_tangent_pair, tracers))
    if all(type(t) is Zero for t in tangents_in):
      avals = [typeof(x) for x in primals_in]
      return primitive.bind_with_trace(
          self.parent_trace, tuple(primals_in), avals,
          dict(subfuns=(fun, jvp), symbolic_zeros=symbolic_zeros))

    # Adapt the user's concatenated-outputs jvp to a (primals, tangents)
    # pair so it can be linearized like a rule.
    @partial(lu.wrap_init, debug_info=jvp.debug_info)
    def _f_jvp(primals, tangents):
      outs = jvp.call_wrapped(*primals, *tangents)
      primals_out, tangents_out = split_list(outs, [len(outs) // 2])
      return primals_out, tangents_out

    with core.set_current_trace(self.parent_trace):
      instantiate_zeros = not symbolic_zeros
      nonzeros_in = [type(t) is not Zero for t in tangents_in]
      primals_out, tangent_nzs_out, residuals, linearized = \
          linearize_from_jvp(_f_jvp, True, nonzeros_in, symbolic_zeros,
                             instantiate_zeros, primals_in, {})
    with core.set_current_trace(self.tangent_trace):
      tangents_out = linearized(residuals, *tangents_in)
      tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out)
    return [maybe_linearize_tracer(self, x, nz, t)
            for x, nz, t in zip(primals_out, tangent_nzs_out, tangents_out)]

  def process_custom_vjp_call(self, primitive, fun, fwd,
                              bwd: lu.WrappedFun, tracers, /, *,
                              out_trees: Callable[
                                  [], tuple[PyTreeDef, PyTreeDef,
                                            list[int | None]]],
                              symbolic_zeros: bool):
    primals_in, tangents_in = unzip2(
        map(self.to_primal_tangent_pair, tracers))
    if all(type(t) is Zero for t in tangents_in):
      avals = [typeof(x) for x in primals_in]
      return primitive.bind_with_trace(
          self.parent_trace, tuple(primals_in), avals,
          dict(subfuns=(fun, fwd, bwd), out_trees=out_trees,
               symbolic_zeros=symbolic_zeros))
    fwd_in = [(p, type(t) is not Zero)
              for p, t in zip(primals_in, tangents_in)]
    fwd_in_flat = [x for pair in fwd_in for x in pair]  # flatten
    with core.set_current_trace(self.parent_trace):
      res_and_primals_out = fwd.call_wrapped(*fwd_in_flat)
    _, res_tree, input_fwds = out_trees()
    # Residuals that are forwarded inputs are reconstructed by index.
    num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds)
    res_out, primals_out = split_list(res_and_primals_out, [num_res_out])
    res_out_ = iter(res_out)
    res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds]
    assert next(res_out_, None) is None
    avals_out = [core.typeof(x).to_tangent_aval() for x in primals_out]
    in_zeros = [type(t) is Zero for t in tangents_in]
    nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z]
    with core.set_current_trace(self.tangent_trace):
      tangents_out = custom_lin_p.bind(
          *res, *nz_tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
          out_avals=avals_out, symbolic_zeros=symbolic_zeros,
          in_zeros=in_zeros)
    tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
    return map(partial(maybe_linearize_tracer, self), primals_out,
               tangent_nzs_out, tangents_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers,
                   params, /):
    assert call_primitive.multiple_results
    primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
    nzs_in = tuple(type(t) is not Zero for t in tangents)
    f_primal, linearize_outs_thunk = linearize_subtrace(
        f, self.is_vjp, self.tag, nzs_in, f.debug_info)
    avals = [typeof(x) for x in primals]
    all_primal_results = call_primitive.bind_with_trace(
        self.parent_trace, primals, avals, dict(params, subfuns=(f_primal,)))
    residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = \
        linearize_outs_thunk()
    # Primal call returns the non-forwarded residuals followed by outputs.
    num_res_out = sum(f1 is None and f2 is None
                      for f1, f2 in zip(in_fwd, out_fwd))
    non_fwd_res = all_primal_results[:num_res_out]
    primals_out = all_primal_results[num_res_out:]
    residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
    update_params = call_linearize_param_updaters.get(call_primitive)
    num_new_args = len(residuals) + len(env)
    new_params = (update_params(params, num_new_args, nzs_in)
                  if update_params else params)
    num_residuals = len(residual_avals)
    f_tangent = _get_f_tangent(lin_jaxpr, num_residuals)
    nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
    new_params = dict(new_params,
                      subfuns=(lu.wrap_init(
                          f_tangent, debug_info=lin_jaxpr.debug_info),))
    args = (*residuals, *env, *nz_tangents_in)
    avals = [typeof(x) for x in args]
    nz_tangents_out = call_primitive.bind_with_trace(
        self.tangent_trace, args, avals, new_params)
    nz_tangents_out_iter = iter(nz_tangents_out)
    tangents_out = [next(nz_tangents_out_iter) if nz else p2tz(primal)
                    for nz, primal in zip(nzs_out, primals_out)]
    return map(partial(maybe_linearize_tracer, self), primals_out, nzs_out,
               tangents_out)

@weakref_lru_cache
def _get_f_tangent(lin_jaxpr, num_residuals):
  # Build (and cache per-jaxpr) the tangent-side callable for process_call.
  def _f(*args):
    consts = args[:num_residuals]
    nz_tangents = args[num_residuals:]
    return core.eval_jaxpr(lin_jaxpr, consts, *nz_tangents)
  return _f

def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
  if is_nonzero:
    assert not type(tangent) is Zero
    return LinearizeTracer(trace, primal, tangent)
  else:
    assert type(tangent) is Zero
    return primal

def fallback_linearize_rule(_prim: core.Primitive, _is_vjp,
                            _nonzeros: Sequence[bool], *primals, **params):
  # Derive a linearization rule from the primitive's jvp rule.
  jvp = primitive_jvps.get(_prim)
  if not jvp:
    msg = f"Differentiation rule for '{_prim}' not implemented"
    raise NotImplementedError(msg)
  debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params)
  return linearize_from_jvp(lu.wrap_init(jvp, debug_info=debug_jvp),
                            _prim.multiple_results, _nonzeros, False, False,
                            primals, params)

def linearize_from_jvp(jvp: lu.WrappedFun,
                       multiple_results: bool,
                       nonzeros: Sequence[bool],
                       user_facing_symbolic_zeros: bool,
                       instantiate_input_zeros: bool,
                       primals, params):
  current_name_stack = source_info_util.current_name_stack()
  with core.take_current_trace() as parent_trace:
    trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
    tangent_avals = [typeof(p).to_tangent_aval() for p in primals]
    # map tangents with float0 dtype to symbolic zeros
    nonzeros = [nz and not (isinstance(a, core.ShapedArray)
                            and a.dtype == float0)
                for a, nz in zip(tangent_avals, nonzeros)]
    def make_zero(aval):
      if instantiate_input_zeros:
        return zeros_like_aval(aval)
      elif user_facing_symbolic_zeros:
        return SymbolicZero(aval)
else: return Zero(aval) if user_facing_symbolic_zeros: zero_type = SymbolicZero else: zero_type = Zero with core.set_current_trace(trace): tangent_args = [trace.new_arg(pe.PartialVal.unknown(a)) if nz else make_zero(a) for a, nz in zip(tangent_avals, nonzeros)] out_primals, out_tangents = jvp.call_wrapped( tuple(primals), tuple(tangent_args), **params) if not multiple_results: out_primals = [out_primals] out_tangents = [out_tangents] out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals] if any(p is None for p in out_primals): raise ValueError( "Linearization failed to produce known values for all output primals. " "This is typically caused by attempting to differentiate a function " "uses an operation that does not support reverse-mode autodiff.") out_nzs = [type(t) is not zero_type and not trace.to_jaxpr_tracer(t).is_known() for t in out_tangents] out_tangent_avals = [typeof(p).to_tangent_aval() for p in out_primals] out_nz_tracers = [trace.to_jaxpr_tracer(r) for (r, nz) in zip(out_tangents, out_nzs) if nz] in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz] jaxpr, out_consts, _ = pe.tracers_to_jaxpr( in_tracers, out_nz_tracers, trace.effect_handles, jvp.debug_info.with_unknown_names()) jaxpr, used_consts, _ = pe.dce_jaxpr_consts( jaxpr, [True] * len(jaxpr.outvars), [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars)) out_consts = [c for used, c in zip(used_consts, out_consts) if used] def linearized(residuals, *tangents): nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz] nz_tangents_out = core.eval_jaxpr(jaxpr, residuals, *nz_tangents_in) nz_tangents_out_iter = iter(nz_tangents_out) all_out_tangents = [next(nz_tangents_out_iter) if nz else Zero(aval) for (aval, nz) in zip(out_tangent_avals, out_nzs)] if multiple_results: return all_out_tangents else: out_tangent, = all_out_tangents return out_tangent if multiple_results: return out_primals, out_nzs, out_consts, linearized else: out_primal, = out_primals 
out_nz, = out_nzs return out_primal, out_nz, out_consts, linearized class LinearizeTracer(Tracer[LinearizeTrace]): __slots__ = ['primal', 'tangent'] def __init__(self, trace, primal, tangent): if config.enable_checks.value: _primal_tangent_shapes_match(primal, tangent) super().__init__(trace, typeof(primal)) self.primal = primal self.tangent = tangent def _short_repr(self): pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x) primal, tangent = pp(self.primal), typeof(self.tangent).str_short(True) return f"GradTracer({primal=!s}, typeof(tangent)={tangent!s})" def full_lower(self): if type(self.tangent) is Zero: return core.full_lower(self.primal) else: return self def to_concrete_value(self): return core.to_concrete_value(self.primal) def get_referent(self): return core.get_referent(self.primal) def cur_qdd(self): return core.cur_qdd(self.primal) # -------------------- Primitives -------------------- primitive_jvps : dict[core.Primitive, Callable] = {} primitive_transposes: dict[core.Primitive, Callable] = {} primitive_linearizations : dict[core.Primitive, Callable] = {} def deflinear(primitive, transpose_rule): primitive_jvps[primitive] = partial(linear_jvp, primitive) primitive_transposes[primitive] = partial(linear_transpose, transpose_rule) def linear_jvp(primitive, primals, tangents, **params): val_out = primitive.bind(*primals, **params) if all(type(tangent) is Zero for tangent in tangents): if primitive.multiple_results: return val_out, map(p2tz, val_out) return val_out, p2tz(val_out) else: tangents = map(instantiate_zeros, tangents) return val_out, primitive.bind(*tangents, **params) def linear_transpose(transpose_rule, cotangent, *args, **kwargs): if type(cotangent) is Zero: return [Zero(x.aval.to_tangent_aval()) if isinstance(x, UndefinedPrimal) else None for x in args] else: return transpose_rule(cotangent, **kwargs) def deflinear2(primitive, transpose_rule): primitive_jvps[primitive] = partial(linear_jvp, primitive) 
  primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule)

def linear_transpose2(transpose_rule, cotangent, *args, **kwargs):
  # Transpose for deflinear2-registered primitives: the rule also receives
  # the primal arguments.
  if type(cotangent) is Zero:
    return [Zero(x.aval.to_ct_aval()) if isinstance(x, UndefinedPrimal)
            else None for x in args]
  else:
    return transpose_rule(cotangent, *args, **kwargs)

def defjvp(primitive, *jvprules):
  # Register a JVP built from one per-argument rule; each rule maps that
  # argument's tangent (plus all primals) to an output tangent.
  assert isinstance(primitive, Primitive)
  assert not primitive.multiple_results
  primitive_jvps[primitive] = partial(standard_jvp, jvprules, primitive)

def standard_jvp(jvprules, primitive, primals, tangents, **params):
  val_out = primitive.bind(*primals, **params)
  # Apply each rule to its nonzero tangent and sum the contributions,
  # starting from a symbolic zero of the output type.
  tangents_out = [rule(t, *primals, **params)
                  for rule, t in zip(jvprules, tangents)
                  if rule is not None and type(t) is not Zero]
  return val_out, functools.reduce(add_tangents, tangents_out, p2tz(val_out))

def defjvp2(primitive, *jvprules):
  # Like defjvp, but the per-argument rules also receive the primal output.
  assert isinstance(primitive, Primitive)
  assert not primitive.multiple_results
  primitive_jvps[primitive] = partial(standard_jvp2, jvprules, primitive)

def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  val_out = primitive.bind(*primals, **params)
  tangents_out = (rule(t, val_out, *primals, **params)
                  for rule, t in zip(jvprules, tangents)
                  if rule is not None and type(t) is not Zero)
  tangents_out = list(tangents_out)
  return val_out, functools.reduce(add_tangents, tangents_out, p2tz(val_out))

def add_tangents(x, y):
  # Add two tangents, treating symbolic Zero as the additive identity.
  if type(x) is Zero:
    return y
  elif type(y) is Zero:
    return x
  else:
    return add_jaxvals(x, y)

def defbilinear(prim, lhs_rule, rhs_rule):
  # Register JVP and transpose rules for a primitive that is linear in each
  # of its two arguments separately (e.g. multiplication, dot products).
  assert isinstance(prim, Primitive)
  # Bilinearity: the tangent wrt one argument is the primitive applied with
  # that argument replaced by its tangent.
  lhs_jvp = lambda g, x, y, **kwargs: prim.bind(g, y, **kwargs)
  rhs_jvp = lambda g, x, y, **kwargs: prim.bind(x, g, **kwargs)
  defjvp(prim, lhs_jvp, rhs_jvp)
  fancy_transposes[prim] = partial(fancy_bilinear_transpose, lhs_rule, rhs_rule)
  # TODO(mattjj,yashkatariya): remove next line if downstream doesnt need it
  primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)

def fancy_bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  # Accumulator-style transpose: exactly one of x, y is a GradAccum; push
  # the cotangent into it unless it is a NullAccum or the cotangent is zero.
  assert isinstance(x, GradAccum) ^ isinstance(y, GradAccum)
  if isinstance(x, GradAccum):
    if type(cotangent) is not Zero and not isinstance(x, NullAccum):
      x.accum(lhs_rule(cotangent, x, y, **kwargs))
  else:
    if type(cotangent) is not Zero and not isinstance(y, NullAccum):
      y.accum(rhs_rule(cotangent, x, y, **kwargs))

def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  # Functional-style transpose: exactly one of x, y is an undefined primal;
  # return its cotangent in that slot and None in the other.
  assert is_undefined_primal(x) ^ is_undefined_primal(y)
  if is_undefined_primal(x):
    if type(cotangent) is Zero:
      return Zero(x.aval.to_ct_aval()), None
    else:
      out = lhs_rule(cotangent, x, y, **kwargs)
      return out, None
  else:
    if type(cotangent) is Zero:
      return None, Zero(y.aval.to_ct_aval())
    else:
      out = rhs_rule(cotangent, x, y, **kwargs)
      return None, out

def defjvp_zero(primitive):
  # Register a JVP declaring the primitive's derivative identically zero.
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(zero_jvp, primitive)

def zero_jvp(primitive, primals, tangents, **params):
  r = primitive.bind(*primals, **params)
  return r, p2tz(r)

deflinear2(add_jaxvals_p, lambda t, *args: (t, t))

def instantiate_zeros(tangent):
  # Materialize a symbolic Zero as a concrete zeros array; other tangents
  # pass through unchanged.
  if type(tangent) is Zero:
    if hasattr(tangent.aval, 'sharding'):
      # TODO(dougalm, yashkatariya): Delete this context manager once we figure
      # out how to ensure jaxpr arguments always have the context mesh.
      with mesh_lib.use_abstract_mesh(tangent.aval.sharding.mesh):
        return zeros_like_aval(tangent.aval)
    return zeros_like_aval(tangent.aval)
  return tangent

@lu.transformation_with_aux2
def traceable(f, store, in_tree, *primals_and_tangents):
  # Flat-calling-convention adapter for a jvp-transformed function; None on
  # the wire stands in for a zero tangent in both directions.
  primals, tangents = tree_unflatten(in_tree, primals_and_tangents)
  tangents = [p2tz(p) if t is None else t for p, t in zip(primals, tangents)]
  primals_out, tangents_out = f(primals, tangents)
  tangents_out = [None if type(t) is Zero else t for t in tangents_out]
  out_flat, out_tree = tree_flatten((primals_out, tangents_out))
  store.store(out_tree)
  return out_flat

def call_transpose_fancy(primitive, cts, *args, call_jaxpr, **params):
  # Transpose a call primitive in the accumulator style: wrap the callee's
  # backward pass so cotangents are accumulated into the GradAccum args.
  if call_jaxpr.constvars: raise NotImplementedError
  primals_ctrefs, specs = project_accums(args)
  flat_args, treedef = tree_flatten((primals_ctrefs, cts))
  # Mutable cell for smuggling the output treedef out of `transposed`.
  cell = lambda: None

  @partial(lu.wrap_init, debug_info=call_jaxpr.debug_info.with_unknown_names())
  def transposed(*flat_args):
    primals_ctrefs, cts = tree_unflatten(treedef, flat_args)
    args = unproject_accums(specs, primals_ctrefs)
    backward_pass3(call_jaxpr, False, (), args, cts)
    cts_out = [x.freeze() if isinstance(x, ValAccum) else None for x in args]
    cts_out, cell.out_tree = tree_flatten(cts_out)  # pyrefly: ignore[missing-attribute]
    return cts_out

  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, [isinstance(x, GradAccum) for x in args],
                           [type(x) is not Zero for x in cts])
  out_flat = primitive.bind(*flat_args, subfuns=(transposed,), **params)
  # Fold the callee's cotangent outputs back into the caller's accumulators.
  for x, ct in zip(args, tree_unflatten(cell.out_tree, out_flat)):  # pyrefly: ignore[missing-attribute]
    if isinstance(x, ValAccum):
      x.accum(ct)
fancy_transposes[core.call_p] = partial(call_transpose_fancy, call_p)

def _closed_call_transpose(ct, *args, call_jaxpr, **params):
  # Closed calls carry consts; convert them into leading args and reuse the
  # open-call transpose.
  jaxpr_, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  jaxpr_ = pe.convert_constvars_jaxpr(jaxpr_)
  call_transpose_fancy(core.closed_call_p, ct, *consts, *args,
                       call_jaxpr=jaxpr_, **params)
fancy_transposes[core.closed_call_p] = _closed_call_transpose

def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
              instantiate: bool | Sequence[bool]
              ) -> tuple[core.ClosedJaxpr, list[bool]]:
  # JVP-transform a closed jaxpr; returns the transformed jaxpr and the
  # output-tangent nonzero pattern. `instantiate` may be a single bool
  # broadcast over all outputs.
  if type(instantiate) is bool:
    instantiate = (instantiate,) * len(jaxpr.out_avals)
  return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate))

@weakref_lru_cache
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
               instantiate: Sequence[bool]):
  # Cached worker for jvp_jaxpr, weakly keyed on the jaxpr.
  assert len(jaxpr.in_avals) == len(nonzeros)
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
                   debug_info=jaxpr.jaxpr.debug_info.with_unknown_names())
  f_jvp, out_nonzeros = f_jvp_traceable(
      jvp(f, instantiate=instantiate, transform_stack=False), nonzeros)
  # Inputs are all primals followed by tangents for the nonzero positions.
  tangent_avals = [aval.to_tangent_aval()
                   for aval, nz in zip(jaxpr.in_aval_qdds, nonzeros) if nz]
  avals_in = list(it.chain(jaxpr.in_aval_qdds, tangent_avals))
  jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(
      f_jvp, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()

@lu.transformation_with_aux2
def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents):
  # Flat-calling-convention adapter: arguments are primals followed by only
  # the nonzero tangents; outputs likewise, with the output nonzero pattern
  # reported through the aux store.
  num_primals = len(nonzeros)
  primals = list(primals_and_nztangents[:num_primals])
  nonzero_tangents = iter(primals_and_nztangents[num_primals:])
  tangents = [next(nonzero_tangents) if nz else p2tz(p)
              for p, nz in zip(primals, nonzeros)]
  primals_out, tangents_out = f(primals, tangents)
  out_nonzeros = [type(t) is not Zero for t in tangents_out]
  nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero]
  store.store(out_nonzeros)
  return list(primals_out) + nonzero_tangents_out

def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in,
                      primals_out, tangents_out):
  # Permute a jvp jaxpr's binders from [all primals..., all tangents...]
  # into interleaved primal-group/tangent-group order, keeping debug info
  # and effect indices consistent. The *_in/*_out args are group sizes.
  new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
  new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
  if jaxpr.jaxpr.debug_info.arg_names is None:
    new_arg_names = None
  else:
    new_arg_names = tuple(_perm(primals_in, tangents_in,
                                jaxpr.jaxpr.debug_info.arg_names))
  if jaxpr.jaxpr.debug_info.result_paths is None:
    new_result_paths = None
  else:
    new_result_paths = tuple(_perm(primals_out, tangents_out,
                                   jaxpr.jaxpr.debug_info.result_paths))
  new_debug_info = jaxpr.jaxpr.debug_info._replace(
      arg_names=new_arg_names, result_paths=new_result_paths)
  constvars = jaxpr.jaxpr.constvars
  # Effects index into the binder list; renumber them for the new order.
  new_effects = pe._renumber_effects(
      (*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars),
      jaxpr.jaxpr.effects)
  new_jaxpr = jaxpr.jaxpr.replace(
      constvars=constvars, invars=new_invars, outvars=new_outvars,
      effects=new_effects, debug_info=new_debug_info)
  return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)

def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int],
          lst: Sequence[Any]) -> Sequence[Any]:
  # Split lst into its primal and tangent halves, group each by the given
  # counts, then interleave the groups.
  n = sum(primal_counts)
  primals, tangents = lst[:n], lst[n:]
  primal_groups = split_list(primals, primal_counts[:-1])
  tangent_groups = split_list(tangents, tangent_counts[:-1])
  return _interleave(primal_groups, tangent_groups)

def _interleave(xs, ys):
  # Flatten alternating groups: x0..., y0..., x1..., y1..., ...
  assert len(xs) == len(ys)
  return [e for pair in zip(xs, ys) for l in pair for e in l]

# custom_lin_p is staged out while linearizing custom_vjp calls; its
# transpose invokes the user's bwd rule, and it cannot be evaluated
# forwards (see raise_custom_vjp_error_on_jvp).
custom_lin_p: core.Primitive = core.Primitive('custom_lin')
custom_lin_p.def_abstract_eval(lambda *_, out_avals, **__: out_avals)
custom_lin_p.multiple_results = True

def raise_custom_vjp_error_on_jvp(*_, **__):
  raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
                  "function.")
custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)

def _custom_lin_transpose(cts_out, *invals, num_res, bwd: lu.WrappedFun,
                          out_avals, symbolic_zeros, in_zeros):
  # Transpose rule for custom_lin_p: call the user's bwd with the residuals
  # and output cotangents; residual slots receive no cotangent (None).
  res, _ = split_list(invals, [num_res])
  if symbolic_zeros:
    cts_out = map(replace_internal_symbolic_zeros, cts_out)
  else:
    cts_out = map(instantiate_zeros, cts_out)
  cts_in = bwd.call_wrapped(*res, *cts_out)
  cts_in = map(replace_rule_output_symbolic_zeros, cts_in)
  nz_cts_in, _ = partition_list(in_zeros, cts_in)
  return [None] * num_res + nz_cts_in
primitive_transposes[custom_lin_p] = _custom_lin_transpose

def _custom_lin_pp_rule(eqn: core.JaxprEqn, context: core.JaxprPpContext,
                        settings: core.JaxprPpSettings) -> core.pp.Doc:
  # Pretty-print custom_lin compactly: drop out_avals and show only the bwd
  # function's name.
  params = dict(eqn.params)
  params.pop("out_avals")
  params["bwd"] = params.pop("bwd").debug_info.func_name
  return core._pp_eqn(eqn.replace(params=params), context, settings)
core.pp_eqn_rules[custom_lin_p] = _custom_lin_pp_rule

class CustomVJPException(Exception):
  def __init__(self):
    # TODO(mattjj): track source provenance on AD tracers, improve error
    msg = ("Detected differentiation of a custom_vjp function with respect to "
           "a closed-over value. That isn't supported because the custom VJP "
           "rule only specifies how to differentiate the custom_vjp function "
           "with respect to explicit input parameters. Try passing the "
           "closed-over value into the custom_vjp function as an argument, and "
           "adapting the custom_vjp fwd and bwd rules.")
    super().__init__(msg)

# TODO(mattjj): remove this vestigial dict
reducing_transposes: dict[core.Primitive, Callable] = {}

# TODO(mattjj): remove this old code, used by something downstream
def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
  # Legacy functional-style call transpose (cf. call_transpose_fancy).
  if isinstance(call_jaxpr, core.ClosedJaxpr):
    call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  else:
    consts = ()
  all_args, in_treedef = tree_flatten((consts, args, ct))
  fun = lu.hashable_partial(
      lu.wrap_init(backward_pass, debug_info=call_jaxpr.debug_info),
      call_jaxpr, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_treedef)
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, map(is_undefined_primal, args),
                           [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(*all_args, **dict(params, subfuns=(fun,)))
  return tree_unflatten(out_tree(), out_flat)