1365 lines
57 KiB
Python
1365 lines
57 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Sequence
|
|
import contextlib
|
|
import functools
|
|
import itertools as it
|
|
from functools import partial
|
|
from typing import Any
|
|
|
|
from jax._src import config
|
|
from jax._src import linear_util as lu
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.tree_util import (tree_flatten, tree_unflatten,
|
|
register_pytree_node, PyTreeDef)
|
|
from jax._src import mesh as mesh_lib
|
|
from jax._src import core
|
|
from jax._src import source_info_util
|
|
from jax._src.ad_util import (
|
|
add_jaxvals, replace_internal_symbolic_zeros,
|
|
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval, SymbolicZero,
|
|
add_jaxvals_p, p2tz, p2cz) # noqa: F401
|
|
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs, debug_info
|
|
from jax._src.core import (Trace, Tracer, typeof, call_p, Primitive, Literal)
|
|
from jax._src.dtypes import dtype, float0
|
|
from jax._src.state.types import AbstractRef
|
|
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
|
|
weakref_lru_cache, partition_list, subs_list2,
|
|
foreach)
|
|
|
|
Array = Any
|
|
Ref = Any
|
|
zip = safe_zip
|
|
map = safe_map
|
|
def identity(x):
  """Return ``x`` unchanged."""
  return x
|
|
|
|
def _update_annotation(
    f: lu.WrappedFun,
    orig_type: tuple[core.AbstractValue, ...] | None,
    nonzeros: list[bool]
  ) -> lu.WrappedFun:
  # Extend `f`'s input-type annotation so it matches the JVP calling
  # convention: the original primal types followed by one tangent aval for
  # each input whose tangent is structurally nonzero.
  if orig_type is None:
    return f
  tan_types = [aval.to_tangent_aval() for nz, aval in zip(nonzeros, orig_type) if nz]
  return lu.annotate(f, (*orig_type, *tan_types))
|
|
|
|
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
        transform_stack=True) -> Any:
  """Transform `fun` to evaluate a Jacobian-vector product.

  The transformed function maps ``(primals, tangents)`` to
  ``(out_primals, out_tangents)``. With ``has_aux=True`` also returns the
  aux-output accessor produced by `jvp_subtrace_aux`.
  """
  if not has_aux:
    return jvpfun(jvp_subtrace(fun), instantiate, transform_stack)
  else:
    fun, aux = jvp_subtrace_aux(fun)
    return jvpfun(fun, instantiate, transform_stack), aux
|
|
|
|
@lu.transformation2
def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
  """Top-level JVP driver: sets up the trace tag and output instantiation."""
  tag = core.TraceTag()
  # Replace concrete float0-dtyped tangents with symbolic zeros (p2tz),
  # since float0 tangents carry no information.
  tangents = [p2tz(t) if not isinstance(t, Zero)
              and isinstance(typeof(t), core.ShapedArray)
              and dtype(t) == float0 else t for t in tangents]
  ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
         else contextlib.nullcontext())
  with ctx:
    out_primals, out_tangents = f(tag, primals, tangents)
  # `instantiate` may be a single bool (applied to every output) or a
  # per-output sequence; instantiated outputs get symbolic zeros materialized.
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [instantiate_zeros(t) if inst else t for t, inst
                  in zip(out_tangents, instantiate)]
  return out_primals, out_tangents
|
|
|
|
# The result of `f` should be a `FlatTree`
def linearize_subtrace_2(f: Callable, is_vjp: bool,
                         tag: core.TraceTag, nzs_in: Sequence[bool],
                         debug_info: core.DebugInfo, primals):
  """Linearize `f` at `primals` under the ambient trace (FlatTree variant).

  Returns ``(res, out_primals, aux)`` where ``res`` are the residuals the
  tangent jaxpr needs and ``aux`` packs
  ``(residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd)``.
  """
  source_info = source_info_util.current()
  with core.take_current_trace() as parent_trace:
    tangent_trace = pe.DynamicJaxprTrace(debug_info, auto_dce=True)
    tangent_trace.tag = tag
    linearize_trace = LinearizeTrace(parent_trace, tangent_trace, is_vjp)
    # Only inputs flagged nonzero in `nzs_in` get tangent tracers; the rest
    # are passed through as plain primals.
    tracers = [LinearizeTracer(linearize_trace, p,
                               tangent_trace.new_arg(typeof(p).to_tangent_aval(),
                                                     source_info))
               if nz else p
               for p, nz in zip(primals, nzs_in)]
    with core.set_current_trace(linearize_trace, check_leaks=True):
      ans = f(*tracers)
      out_primals, out_tangents = ans.map(linearize_trace.to_primal_tangent_pair).unzip2()
    del linearize_trace, ans, tracers
  nzs_out = tuple(type(t) is not Zero for t in out_tangents)
  out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
  out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info), out_tangents)
  jaxpr, consts = tangent_trace.to_jaxpr(out_tangents, debug_info.with_unknown_names(), source_info)
  # Constants that are tracers carrying this same tag are environment values
  # from the enclosing trace, not residuals; turn them into env vars.
  which_env = [(isinstance(c, pe.DynamicJaxprTracer) and
                getattr(c._trace, 'tag', None) is tag) for c in consts]
  jaxpr = pe.move_envvars(jaxpr, tuple(which_env))
  res, env = partition_list(which_env, consts)
  residual_avals = map(typeof, res)
  # Which residuals are just forwarded inputs? Check object id.
  id_map = {id(p): i for i, p in enumerate(primals)}
  in_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
  # Which residuals are already primal outputs? Check object id.
  id_map = {id(p): i for i, p in enumerate(out_primals)}
  out_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
  # Prune residuals not to include forwarded primal inputs or outputs.
  res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd) if f1 is None and f2 is None]
  aux = (residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd)
  return res, out_primals, aux
|
|
|
|
|
|
@lu.transformation_with_aux2
def linearize_subtrace(_f: Callable, _store: lu.Store, _is_vjp: bool,
                       _tag: core.TraceTag, nzs_in: Sequence[bool],
                       debug_info: core.DebugInfo, *primals, **params):
  """Linearize `_f` at `primals` (list-result twin of `linearize_subtrace_2`).

  Returns ``(*res, *out_primals)`` and stores
  ``(residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd)`` as aux.
  """
  source_info = source_info_util.current()
  with core.take_current_trace() as parent_trace:
    tangent_trace = pe.DynamicJaxprTrace(debug_info, auto_dce=True)
    tangent_trace.tag = _tag
    linearize_trace = LinearizeTrace(parent_trace, tangent_trace, _is_vjp)
    # Only inputs flagged nonzero in `nzs_in` get tangent tracers.
    tracers = [LinearizeTracer(linearize_trace, p,
                               tangent_trace.new_arg(typeof(p).to_tangent_aval(),
                                                     source_info))
               if nz else p
               for p, nz in zip(primals, nzs_in)]
    with core.set_current_trace(linearize_trace, check_leaks=True):
      ans = _f(*tracers)
      out_primals, out_tangents = unzip2(map(linearize_trace.to_primal_tangent_pair, ans))
    del linearize_trace, ans, tracers
  nzs_out = tuple(type(t) is not Zero for t in out_tangents)
  out_tangents = tuple(t for t, nz in zip(out_tangents, nzs_out) if nz)
  out_tangents = map(partial(tangent_trace.to_jaxpr_tracer, source_info=source_info),
                     out_tangents)
  jaxpr, consts = tangent_trace.to_jaxpr(
      out_tangents, debug_info.with_unknown_names(), source_info)
  # Constants that are tracers carrying this same tag are environment values
  # from the enclosing trace, not residuals; turn them into env vars.
  which_env = [(isinstance(c, pe.DynamicJaxprTracer) and
                getattr(c._trace, 'tag', None) is _tag) for c in consts]
  jaxpr = pe.move_envvars(jaxpr, tuple(which_env))
  res, env = partition_list(which_env, consts)
  residual_avals = map(typeof, res)
  # Which residuals are just forwarded inputs? Check object id.
  id_map = {id(p): i for i, p in enumerate(primals)}
  in_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
  # Which residuals are already primal outputs? Check object id.
  id_map = {id(p): i for i, p in enumerate(out_primals)}
  out_fwd: list[int | None] = [id_map.get(id(r)) for r in res]
  # Prune residuals not to include forwarded primal inputs or outputs.
  res = [p for p, f1, f2 in zip(res, in_fwd, out_fwd) if f1 is None and f2 is None]
  _store.store((residual_avals, nzs_out, jaxpr, env, in_fwd, out_fwd))
  return *res, *out_primals
|
|
|
|
# The result of `f` should be a `FlatTree`
def jvp_subtrace_2(f: Callable, tag: core.TraceTag, primals, tangents):
  """Run `f` under a fresh JVPTrace; returns (primals, tangents) FlatTrees."""
  with core.take_current_trace() as parent_trace:
    trace = JVPTrace(parent_trace, tag)
    in_tracers = [maybe_jvp_tracer(trace, x, t)
                  for x, t in zip(primals, tangents)]
    with core.set_current_trace(trace):
      ans = f(*in_tracers)
    return ans.map(trace.to_primal_tangent_pair).unzip2()
|
|
|
|
@lu.transformation2
def jvp_subtrace(f: Callable, tag: core.TraceTag, primals, tangents):
  """Run `f` under a fresh JVPTrace; returns (out_primals, out_tangents)."""
  with core.take_current_trace() as parent_trace:
    trace = JVPTrace(parent_trace, tag)
    in_tracers = [maybe_jvp_tracer(trace, x, t)
                  for x, t in zip(primals, tangents)]
    with core.set_current_trace(trace):
      ans = f(*in_tracers)
    out = unzip2(map(trace.to_primal_tangent_pair, ans))
    return out
|
|
|
|
@lu.transformation_with_aux2
def jvp_subtrace_aux(f, store, tag, primals, tangents):
  """Like `jvp_subtrace`, but also stores the (primal parts of) aux outputs."""
  with core.take_current_trace() as parent_trace:
    trace = JVPTrace(parent_trace, tag)
    with core.set_current_trace(trace):
      ans, aux = f(*(map(partial(maybe_jvp_tracer, trace), primals, tangents)))
      out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
      # Aux outputs are not differentiated: strip our tracers to primals.
      aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag
                     else x for x in aux]
      store.store(aux_primals)
      return out_primals, out_tangents
|
|
|
|
def linearize_jaxpr(
    jaxpr: core.ClosedJaxpr,
    nonzeros: Sequence[bool],
    instantiate: bool | Sequence[bool] = False,
    allow_fwds: bool | Sequence[bool] = True,
    *,
    is_vjp: bool,
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None], core.ClosedJaxpr]:
  """Linearize a closed jaxpr.

  Normalizes the per-output `instantiate` and per-input `allow_fwds` flags to
  tuples (hashable, as required by the cached `_linearize_jaxpr`) and
  delegates. Returns
  ``(primal_jaxpr, num_residuals_out, nzs_out, fwds, tangent_jaxpr)``.
  """
  if type(allow_fwds) is bool:
    allow_fwds = (allow_fwds,) * (len(jaxpr.consts) + len(jaxpr.jaxpr.invars))
  assert len(allow_fwds) == (len(jaxpr.consts) + len(jaxpr.jaxpr.invars))
  if type(instantiate) is bool:
    instantiate = (instantiate,) * len(jaxpr.jaxpr.outvars)
  assert len(instantiate) == len(jaxpr.jaxpr.outvars)
  return _linearize_jaxpr(jaxpr, tuple(nonzeros), tuple(instantiate),
                          tuple(allow_fwds), is_vjp)
|
|
|
|
@weakref_lru_cache
@source_info_util.reset_name_stack()
def _linearize_jaxpr(
    jaxpr: core.ClosedJaxpr,
    nonzeros: tuple[bool, ...],
    instantiate: tuple[bool, ...],
    allow_fwds: tuple[bool, ...],
    is_vjp: bool,
) -> tuple[core.ClosedJaxpr, int, Sequence[bool], Sequence[int | None], core.ClosedJaxpr]:
  """Cached worker for `linearize_jaxpr`: split a jaxpr into primal and
  tangent (linear) jaxprs via a LinearizeTrace over two DynamicJaxprTraces.
  """
  dbg = jaxpr.jaxpr.debug_info
  config.enable_checks.value and dbg.assert_arg_names(len(nonzeros))
  primal_trace = pe.DynamicJaxprTrace(dbg)
  tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True)
  tag = core.TraceTag()
  tangent_trace.tag = tag
  lin_trace = LinearizeTrace(primal_trace, tangent_trace, is_vjp=is_vjp)

  def new_arg(trace, primal_aval, nz, source_info):
    # Each input gets a primal tracer; a tangent tracer only if nonzero.
    primal = primal_trace.new_arg(primal_aval, source_info)
    tangent_aval = primal_aval.to_tangent_aval()
    tangent = tangent_trace.new_arg(tangent_aval, source_info) if nz else Zero(tangent_aval)
    return LinearizeTracer(trace, primal, tangent)

  source_info = source_info_util.current()
  tracers = [new_arg(lin_trace, a, nz, source_info)
             for (a, nz) in zip(jaxpr.in_aval_qdds, nonzeros)]
  in_primals = [t.primal for t in tracers]

  with core.set_current_trace(lin_trace, check_leaks=True):
    ans = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *tracers)
    out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
    out_tangents = [instantiate_zeros(t) if inst else t
                    for t, inst in zip(out_tangents, instantiate)]
    del lin_trace, ans, new_arg, tracers

  # pe._check_no_returned_refs(debug_info, out_tangents)
  # Build the tangent (linear) jaxpr from the nonzero output tangents.
  nzs_out = [type(t) is not Zero for t in out_tangents]
  out_tangents = [tangent_trace.to_jaxpr_tracer(t, source_info)
                  for (nz, t) in zip(nzs_out, out_tangents) if nz]
  tangent_jaxpr, tangent_consts = tangent_trace.to_jaxpr(
      out_tangents, dbg.with_unknown_names(), source_info)
  tangent_trace.invalidate()
  tangent_jaxpr, tangent_consts = _dce_consts(tangent_jaxpr, tangent_consts)
  tangent_jaxpr = pe.close_jaxpr(pe.convert_constvars_jaxpr(tangent_jaxpr))

  # Tangent consts that are (allowed) forwards of jaxpr consts or primal
  # inputs need not be residuals; record their forwarding index instead.
  fwd_inputs = (*jaxpr.consts, *in_primals)
  id_map = {id(x):i for i, (x,a) in enumerate(zip(fwd_inputs, allow_fwds)) if a}
  fwds = [id_map.get(id(c)) for c in tangent_consts]
  tangent_consts = [c for c, f in zip(tangent_consts, fwds) if f is None]
  del in_primals

  # pe._check_no_returned_refs(debug_info, out_primals)
  # Build the primal jaxpr, which outputs the primals plus residuals.
  primals_and_residuals = *out_primals, *tangent_consts
  primals_and_residuals = map(partial(primal_trace.to_jaxpr_tracer, source_info=source_info),
                              primals_and_residuals)
  primal_jaxpr, primal_consts = primal_trace.to_jaxpr(
      primals_and_residuals, dbg.with_unknown_names(),
      source_info)
  primal_trace.invalidate()
  primal_jaxpr, primal_consts = _dce_consts(primal_jaxpr, primal_consts)
  primal_jaxpr = core.ClosedJaxpr(primal_jaxpr, primal_consts)

  num_residuals_out = len(tangent_consts)
  return primal_jaxpr, num_residuals_out, nzs_out, fwds, tangent_jaxpr
|
|
|
|
def _dce_consts(jaxpr, consts):
  # DCE the jaxpr keeping all outputs and all invars, then drop the consts
  # the pruned jaxpr no longer references.
  jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
      jaxpr, [True] * len(jaxpr.outvars),
      [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
  return jaxpr, [c for c, used in zip(consts, used_consts) if used]
|
|
|
|
def direct_linearize(traceable, primals, *, has_aux, is_vjp):
  """Linearize `traceable` at `primals` by direct tracing (no partial eval).

  Returns ``(out_primals, out_tangents_pvals, jaxpr, consts[, aux])``,
  matching the interface of the partial-eval-based `linearize` path.
  """
  dbg = traceable.debug_info.with_unknown_names()
  tag = core.TraceTag()
  with core.take_current_trace() as parent_trace:
    source_info = source_info_util.current()
    tangent_trace = pe.DynamicJaxprTrace(dbg, auto_dce=True)
    tangents = [tangent_trace.new_arg(typeof(p).to_tangent_aval(), source_info) for p in primals]
    # Replace float0-dtyped tangents with symbolic zeros (no information).
    tangents = [p2tz(t) if not isinstance(t, Zero)
                and isinstance(typeof(t), core.ShapedArray)
                and dtype(t) == float0 else t for t in tangents]
    tangent_trace.tag = tag
    lin_trace = LinearizeTrace(parent_trace, tangent_trace, is_vjp)
    tracers = [LinearizeTracer(lin_trace, p, t) for p, t in zip(primals, tangents)]
    tracers = [t.full_lower() for t in tracers]
    with (core.set_current_trace(lin_trace),
          source_info_util.transform_name_stack('jvp')):
      if has_aux:
        ans, aux = traceable.call_wrapped(*tracers)
        # Aux outputs are not differentiated: strip our tracers to primals.
        aux = [x.primal if type(x) is LinearizeTracer and x._trace.tag is tag
               else x for x in aux]
      else:
        ans = traceable.call_wrapped(*tracers)
        aux = None
      out_primals, out_tangents = unzip2(map(lin_trace.to_primal_tangent_pair, ans))
    del lin_trace, ans, tracers
  out_nzs = [type(t) is not Zero for t in out_tangents]
  out_nz_tangents = [t for t, nz in zip(out_tangents, out_nzs) if nz]
  out_nz_tangents = map(partial(tangent_trace.to_jaxpr_tracer,
                                source_info=source_info), out_nz_tangents)
  jaxpr, consts = tangent_trace.to_jaxpr(out_nz_tangents, dbg, source_info)
  tangent_trace.invalidate()
  config.enable_checks.value and core.check_jaxpr(jaxpr)
  # DCE keeping all outputs and invars; drop now-unused consts.
  jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
      jaxpr, [True] * len(jaxpr.outvars),
      [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
  consts = [c for c, used in zip(consts, used_consts) if used]
  # Nonzero tangents become unknown pvals; zero ones become known zeros.
  out_tangents_pvals = [pe.PartialVal.unknown(core.typeof(t)) if nz else
                        pe.PartialVal.known(zeros_like_aval(t.aval))
                        for t, nz in zip(out_tangents, out_nzs)]
  if has_aux:
    return out_primals, out_tangents_pvals, jaxpr, consts, aux
  else:
    return out_primals, out_tangents_pvals, jaxpr, consts
|
|
|
|
def linearize(traceable: lu.WrappedFun, *primals, has_aux=False, is_vjp=False):
  """Linearize `traceable` at `primals`.

  Dispatches to `direct_linearize` when enabled by config; otherwise uses
  JVP + partial evaluation (known primals, unknown tangents).

  Raises:
    ValueError: if partial evaluation cannot produce known values for all
      output primals.
  """
  if config.use_direct_linearize.value:
    return direct_linearize(traceable, primals, has_aux=has_aux, is_vjp=is_vjp)
  if has_aux:
    jvpfun, aux = jvp(traceable, has_aux=True)
  else:
    jvpfun = jvp(traceable)
    aux = None

  # Primals are known; tangents are unknown, so partial eval peels off the
  # linear part of the computation as a jaxpr.
  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(typeof(p).to_tangent_aval())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  if any(not out_primal_pval.is_known() for out_primal_pval in out_primals_pvals):
    raise ValueError(
        "Linearization failed to produce known values for all output primals. "
        "This is typically caused by attempting to differentiate a function "
        "using an operation that does not support reverse-mode autodiff.")
  out_primals_consts = [pval.get_known() for pval in out_primals_pvals]
  if not has_aux:
    assert aux is None
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    assert aux is not None
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
|
|
|
|
|
|
class UndefinedPrimal:
  """Placeholder for a primal value that is unavailable during transposition.

  Carries only the abstract value (`aval`) of the missing primal.
  """
  __slots__ = ['aval']

  def __init__(self, aval):
    self.aval = aval

  def __repr__(self):
    return 'UndefinedPrimal({})'.format(self.aval)
|
|
|
|
def is_undefined_primal(x):
  """True iff `x` is exactly an `UndefinedPrimal` (no subclass match)."""
  return UndefinedPrimal is type(x)
|
|
|
|
# Register UndefinedPrimal as a childless pytree node whose aval is the aux
# data, so tree operations can flatten and rebuild it.
register_pytree_node(UndefinedPrimal,
                     lambda z: ((), z.aval),
                     lambda aval, _: UndefinedPrimal(aval))
|
|
|
|
def get_primitive_transpose(p):
  """Look up the registered transpose rule for primitive `p`.

  Raises:
    NotImplementedError: if no transpose rule is registered for `p`.
  """
  try:
    return primitive_transposes[p]
  except KeyError as err:
    # f-string replaces the dated str.format; the message is unchanged.
    raise NotImplementedError(
        f"Transpose rule (for reverse-mode differentiation) for '{p}' "
        "not implemented") from err
|
|
|
|
|
|
def backward_pass3(
    jaxpr: core.Jaxpr, transform_stack: bool,
    consts: Sequence[Array], primals_in: Sequence[Array | Ref | GradAccum],
    cotangents_in: Sequence[Array]) -> None:
  """Run the transpose (backward) pass over a linear jaxpr.

  Cotangents are accumulated in place into the `GradAccum`s in `primals_in`;
  nothing is returned. A forward sweep evaluates non-linear equations and
  collects the linear ones; a reverse sweep then transposes the linear
  equations in reverse order.
  """
  if all(type(ct) is Zero for ct in cotangents_in) and not jaxpr.effects:
    return

  env: dict = dict(zip((*jaxpr.constvars, *jaxpr.invars),
                       (*consts, *primals_in)))

  def read(x: core.Atom) -> Array | GradAccum:
    # Literals carry their own value; everything else lives in `env`.
    return x.val if isinstance(x, Literal) else env[x]

  # Forward sweep: linear equations are those producing refs or consuming a
  # GradAccum; all others are executed immediately.
  lin_eqns = []
  for eqn in jaxpr.eqns:
    # TODO(mattjj): shorten the lifetime of the reference accumulators, as it
    # is longer than necessary.
    if eqn.primitive.ref_primitive:
      v, = eqn.outvars
      lin_eqns.append(eqn)
      if eqn.primitive is core.ref_p or eqn.primitive is core.empty_ref_p:
        env[v] = RefAccum(v.aval.inner_aval)  # pyrefly: ignore[missing-attribute]
      elif eqn.primitive is core.freeze_p:
        env[v] = ValAccum(v.aval)
      elif eqn.primitive is core.accum_grad_in_ref_p:
        env[v] = RefAccum(v.aval)
      else:
        assert False
    elif any(isinstance(read(x), GradAccum) for x in eqn.invars):
      # Linear in at least one input: outputs get fresh accumulators.
      for v in eqn.outvars:
        env[v] = ValAccum(v.aval)
      lin_eqns.append(eqn)
    else:
      params = eqn.primitive.get_bind_params(eqn.params)
      with eqn.ctx.manager, _name_stack_ctx(eqn.source_info):
        ans = eqn.primitive.bind(*map(read, eqn.invars), **params)
      ans = ans if eqn.primitive.multiple_results else [ans]
      foreach(env.setdefault, eqn.outvars, ans)

  ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
         else contextlib.nullcontext())
  # Seed output accumulators with the incoming cotangents.
  for acc, ct in zip(map(read, jaxpr.outvars), cotangents_in):
    if isinstance(acc, GradAccum):
      acc.accum(ct)  # jaxpr.outvars can have Literals, env can have inst zeros
  with ctx:
    # Reverse sweep: transpose linear equations in reverse order.
    for eqn in lin_eqns[::-1]:
      with eqn.ctx.manager, _name_stack_ctx(eqn.source_info):
        if eqn.primitive is core.empty_ref_p:
          env.pop(eqn.outvars[0]).freeze()
        elif eqn.primitive.ref_primitive:
          # Ref primitives forward the frozen cotangent to their input.
          ct = env.pop(eqn.outvars[0]).freeze()
          acc = read(eqn.invars[0])
          if isinstance(acc, GradAccum):
            acc.accum(ct)
        else:
          cts_in = [env.pop(v).freeze() for v in eqn.outvars]
          if not eqn.primitive.multiple_results:
            cts_in, = cts_in
          if eqn.primitive in fancy_transposes:
            # Fancy rules see the accumulators directly and mutate them.
            rule = fancy_transposes[eqn.primitive]
            rule(cts_in, *map(read, eqn.invars), **eqn.params)
          else:
            # Classic rules see UndefinedPrimal markers and return cotangents,
            # which we accumulate here.
            rule = get_primitive_transpose(eqn.primitive)
            primals = map(read, eqn.invars)
            up = lambda x: UndefinedPrimal(x.aval) if isinstance(x, GradAccum) else x
            if eqn.primitive.call_primitive:
              # TODO(mattjj,dougalm): remove this path by revising call/map trans
              cts_in_avals = [v.aval for v in eqn.outvars]
              params = dict(eqn.params)
              call_jaxpr = params.pop('call_jaxpr')
              cts_out = rule(params, call_jaxpr, map(up, primals), cts_in, cts_in_avals)
            else:
              cts_out = rule(cts_in, *map(up, primals), **eqn.params)
            for x, ct in zip(primals, cts_out):
              if isinstance(x, GradAccum):
                x.accum(ct)
|
|
|
def _name_stack_ctx(src_info):
  # Context manager that installs `src_info`'s traceback and appends its
  # name stack onto the current one while executing an equation.
  stack = source_info_util.current_name_stack() + src_info.name_stack
  return source_info_util.user_context(src_info.traceback, name_stack=stack)
|
|
|
|
class GradAccum:
  """Abstract cotangent accumulator used by the backward pass."""
  # Abstract value of the primal whose cotangent is accumulated.
  aval: core.AbstractValue

  def accum(self, x) -> None:
    # Abstract: subclasses must override.
    assert False

  def freeze(self) -> Array | Zero:
    # Abstract: subclasses must override.
    assert False
|
|
|
|
class RefAccum(GradAccum):
  """Cotangent accumulator backed by a mutable ref, created lazily."""
  aval: core.AbstractValue
  ref: Ref | None

  def __init__(self, aval, ref=None):
    self.aval = aval
    self.ref = ref

  def accum(self, x):
    # NOTE(review): this asserts x is not the `Zero` *class* object itself;
    # Zero instances are deliberately skipped just below — confirm intent.
    assert x is not Zero
    if isinstance(x, Zero) or x is None:
      return
    if self.ref is None:
      # First contribution seeds the ref directly.
      self.ref = core.new_ref(x)
    else:
      ct_check(self, x)
      self.ref.addupdate(x)

  def freeze(self):
    # Symbolic zero if nothing was ever accumulated.
    if self.ref is None:
      return Zero(self.aval)
    else:
      return core.freeze(self.ref)

  def inst(self):
    # Force materialization of the ref (filled with concrete zeros).
    if self.ref is None:
      self.ref = core.new_ref(zeros_like_aval(self.aval))
    return self
|
|
|
|
class ValAccum(GradAccum):
  """Cotangent accumulator backed by a plain value, starting at symbolic zero."""
  aval: core.AbstractValue
  val: Array | Zero

  def __init__(self, aval, val=None):
    self.aval = aval
    self.val = Zero(aval.to_ct_aval()) if val is None else val
    ct_check(self, self.val)

  def __repr__(self):
    return f"ValAccum({self.aval})"

  def accum(self, x):
    # `None` contributions are skipped; others are type-checked and added.
    if x is not None:
      ct_check(self, x)
      self.val = add_tangents(self.val, x)

  def freeze(self):
    # Accumulated value; still `Zero` if nothing was accumulated.
    return self.val
|
|
|
|
def ct_check(primal, ct):
  """Check `ct` against the cotangent type expected for accumulator `primal`.

  No-op when backward checks are disabled. Dtype mismatches are tolerated
  (`no_dtype_check=True`).

  Raises:
    ValueError: on cotangent/primal type mismatch.
  """
  if config.disable_bwd_checks.value:
    return
  ct_aval = ct.aval if type(ct) is Zero else typeof(ct)
  ct_aval_expected = primal.aval.to_ct_aval()
  if not core.typematch(ct_aval, ct_aval_expected, no_dtype_check=True):
    # TODO(yashkatariya, mattjj): Add primitive name here for better error?
    raise ValueError(
        f"Expected cotangent type {ct_aval_expected.str_short()} for primal "
        f"type {primal.aval.str_short()}, but got {ct_aval.str_short()}")
|
|
|
|
class NullAccum(GradAccum):
  """Accumulator that discards every cotangent; freezing it is an error."""
  aval: core.AbstractValue

  def __init__(self, aval):
    self.aval = aval

  def __repr__(self):
    return f"NullAccum({self.aval})"

  def accum(self, x):
    # Intentionally drop the contribution.
    return

  def freeze(self):
    assert False
|
|
|
|
|
|
# Transpose rules that, unlike classic rules, are handed the inputs as read
# from the environment (including GradAccums) and accumulate cotangents
# themselves; their return value is ignored by backward_pass3.
fancy_transposes: dict[core.Primitive, Callable] = {}
|
|
|
|
def project_accums(args):
  """Split `args` into runtime values plus a reconstruction spec.

  Each element becomes a ``(kind, aval)`` spec; only `RefAccum`s (after
  forced instantiation) and plain values contribute runtime entries.
  Inverse of `unproject_accums`.
  """
  result, specs = [], []
  for x in args:
    if isinstance(x, ValAccum):
      specs.append((ValAccum, x.aval))
    elif isinstance(x, RefAccum):
      result.append(x.inst().ref)
      specs.append((RefAccum, x.aval))
    elif isinstance(x, NullAccum):
      specs.append((NullAccum, x.aval))
    else:
      result.append(x)
      specs.append((None, typeof(x)))
  return result, tuple(specs)
|
|
|
|
def unproject_accums(specs, result):
  """Rebuild the original argument list from `project_accums` outputs.

  `ValAccum`/`NullAccum` entries are recreated fresh from their avals;
  `RefAccum` and plain-value entries consume items from `result` in order.
  """
  args, result_ = [], iter(result)
  for k, aval in specs:
    if k is ValAccum:
      args.append(ValAccum(aval))
    elif k is RefAccum:
      args.append(RefAccum(aval, next(result_)))
    elif k is NullAccum:
      args.append(NullAccum(aval))
    elif k is None:
      args.append(next(result_))
    else:
      assert False
  # All runtime values must be consumed exactly.
  assert next(result_, None) is None
  return args
|
|
|
|
def accum_typeof(x):
  """Abstract value of `x`, whether it is a `GradAccum` or a regular value."""
  return x.aval if isinstance(x, GradAccum) else typeof(x)
|
|
|
|
# TODO(mattjj): this is for backward (get it?) compatibility. Remove, maybe.
def backward_pass(jaxpr, transform_stack: bool, consts, primals_in, cts_in):
  """Legacy wrapper around `backward_pass3` using `UndefinedPrimal` markers.

  Inputs marked `UndefinedPrimal` become `ValAccum`s whose accumulated
  cotangents are returned; other inputs pass through, yielding symbolic-zero
  cotangents via `p2cz`.
  """
  primals_in = [ValAccum(x.aval) if isinstance(x, UndefinedPrimal) else x
                for x in primals_in]
  backward_pass3(jaxpr, transform_stack, consts, primals_in, cts_in)
  return [x.freeze() if isinstance(x, ValAccum) else p2cz(x)
          for x in primals_in]
|
|
|
|
def closed_backward_pass(jaxpr: core.ClosedJaxpr, transform_stack,
                         primals_in, cotangents_in):
  """`backward_pass` specialized to a `ClosedJaxpr` (consts taken from it)."""
  return backward_pass(jaxpr.jaxpr, transform_stack, jaxpr.consts,
                       primals_in, cotangents_in)
|
|
|
|
|
|
@lu.transformation_with_aux2
def nonzero_tangent_outputs(f, store, *args, **kwargs):
  # Pass through `f`'s results, storing as aux which output tangents are
  # structurally nonzero.
  results = (_, tangents_out) = f(*args, **kwargs)
  store.store([type(r) is not Zero for r in tangents_out])
  return results
|
|
|
|
|
|
class JVPTrace(Trace):
  """Trace implementing forward-mode JVP on top of a parent trace."""

  def __init__(self, parent_trace, tag):
    super().__init__()
    self.tag = tag
    self.parent_trace = parent_trace
    self.requires_low = False

  def to_primal_tangent_pair(self, val):
    # Our own tracers split into (primal, tangent); anything else is treated
    # as a constant with a symbolic-zero tangent.
    if isinstance(val, JVPTracer) and val._trace.tag is self.tag:
      return (val.primal, val.tangent)
    else:
      tangent_zero = p2tz(val)
      return (val, tangent_zero)

  def process_primitive(self, primitive, tracers, params, /):
    primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
    # All-zero tangents: no JVP rule needed, bind on the parent trace
    # directly — except for ref primitives and ref-typed inputs, which
    # carry mutable state.
    if (all(type(t) is Zero for t in tangents_in) and
        primitive is not core.ref_p and primitive is not core.empty_ref_p and
        not any(isinstance(typeof(x), AbstractRef) for x in primals_in)):
      avals = tuple(core.typeof(x) for x in primals_in)
      return primitive.bind_with_trace(self.parent_trace, primals_in, avals, params)
    jvp = primitive_jvps.get(primitive)
    if not jvp:
      msg = f"Differentiation rule for '{primitive}' not implemented"
      raise NotImplementedError(msg)
    with core.set_current_trace(self.parent_trace):
      primal_out, tangent_out = jvp(primals_in, tangents_in, **params)

    if primitive.multiple_results:
      return [maybe_jvp_tracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
    else:
      return maybe_jvp_tracer(self, primal_out, tangent_out)

  def cur_qdd(self, x):
    # Query the current qdd of the primal under the parent trace.
    p, _ = self.to_primal_tangent_pair(x)
    with core.set_current_trace(self.parent_trace):
      return core.cur_qdd(p)

  def process_call(self, call_primitive, f, tracers, params, /):
    assert call_primitive.multiple_results
    primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
    which_nz = [ type(t) is not Zero for t in tangents]
    # Zero tangents travel as None and are reconstituted after the call.
    tangents = [t if type(t) is not Zero else None for t in tangents]
    args, in_tree = tree_flatten((primals, tangents))
    f_jvp = jvp_subtrace(f, self.tag)
    f_jvp, which_nz_out = nonzero_tangent_outputs(f_jvp)
    f_jvp, out_tree = traceable(f_jvp, in_tree)
    update_params = call_param_updaters.get(call_primitive)
    new_params = update_params(params, which_nz) if update_params else params
    fun_and_args = _update_annotation(f_jvp.with_unknown_names(), f.in_type, which_nz)
    new_params = dict(new_params, subfuns=(fun_and_args,))
    avals = tuple(core.typeof(x) for x in args)
    result = call_primitive.bind_with_trace(self.parent_trace, args, avals, new_params)
    primal_out, tangent_out = tree_unflatten(out_tree(), result)
    # None placeholders become symbolic zeros for the matching primal.
    tangent_out = [p2tz(p) if t is None else t
                   for p, t in zip(primal_out, tangent_out)]
    return [maybe_jvp_tracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers, /, *, symbolic_zeros):
    primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
    # All-zero tangents: run the primal function on the parent trace.
    if all(type(t) is Zero for t in tangents_in):
      avals = tuple(core.typeof(x) for x in primals_in)
      return primitive.bind_with_trace(
          self.parent_trace, tuple(primals_in), avals,
          dict(subfuns=(fun, jvp), symbolic_zeros=symbolic_zeros))
    with core.set_current_trace(self.parent_trace):
      if not symbolic_zeros:
        tangents_in = map(instantiate_zeros, tangents_in)
      else:
        tangents_in = map(replace_internal_symbolic_zeros, tangents_in)
      outs = jvp.call_wrapped(*(tuple(primals_in) + tuple(tangents_in)))

    # The custom-jvp rule returns primals then tangents, concatenated.
    primals_out, tangents_out = split_list(outs, [len(outs) // 2])
    tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out)
    return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, /, *, out_trees,
                              symbolic_zeros):
    primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
    # All-zero tangents: run the primal function on the parent trace.
    if all(type(t) is Zero for t in tangents_in):
      avals = tuple(core.typeof(x) for x in primals_in)
      return primitive.bind_with_trace(
          self.parent_trace, tuple(primals_in), avals,
          dict(subfuns=(fun, fwd, bwd), out_trees=out_trees,
               symbolic_zeros=symbolic_zeros))
    # The fwd rule takes (primal, tangent-is-nonzero) pairs, flattened.
    fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)]
    fwd_in = [x for pair in fwd_in for x in pair]  # flatten
    with core.set_current_trace(self.parent_trace):
      res_and_primals_out = fwd.call_wrapped(*fwd_in)

    _, res_tree, input_fwds = out_trees()
    num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds)
    res_out, primals_out = split_list(res_and_primals_out, [num_res_out])
    res_out_ = iter(res_out)
    # Re-interleave forwarded inputs with the fwd rule's residual outputs.
    res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds]
    assert next(res_out_, None) is None

    avals_out = [core.typeof(x).to_tangent_aval() for x in primals_out]
    in_zeros = [type(t) is Zero for t in tangents_in]
    nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z]
    with core.set_current_trace(self.parent_trace):
      tangents_out = custom_lin_p.bind(
          *res, *nz_tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
          out_avals=avals_out, symbolic_zeros=symbolic_zeros, in_zeros=in_zeros)
    return map(partial(maybe_jvp_tracer, self), primals_out, tangents_out)
|
|
|
|
|
|
def maybe_jvp_tracer(trace, primal, tangent):
  # Skip tracer creation when the tangent carries no information: a symbolic
  # zero, or a float0-dtyped array. (Note: `and` binds tighter than `or`, so
  # the float0 check only applies to ShapedArray tangents.)
  if (type(tangent) is Zero or
      isinstance(typeof(tangent), core.ShapedArray)
      and dtype(tangent) == float0):
    return primal
  else:
    return JVPTracer(trace, primal, tangent)
|
|
|
|
class JVPTracer(Tracer[JVPTrace]):
  """Tracer pairing a primal value with its (typically nonzero) tangent."""
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    # Shape/dtype/sharding consistency is checked only in debug mode.
    if config.enable_checks.value:
      _primal_tangent_shapes_match(primal, tangent)
    super().__init__(trace, typeof(primal))
    self.primal = primal
    self.tangent = tangent

  def _short_repr(self):
    pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x)
    primal, tangent = pp(self.primal), pp(self.tangent)
    return f'JVPTracer({primal=!s}, {tangent=!s})'

  def cur_qdd(self):
    return core.cur_qdd(self.primal)

  def full_lower(self):
    # With a zero tangent this tracer adds nothing; lower to the primal.
    if type(self.tangent) is Zero:
      return core.full_lower(self.primal)
    else:
      return self

  def to_concrete_value(self):
    return core.to_concrete_value(self.primal)

  def get_referent(self):
    return core.get_referent(self.primal)

  def type_state(self):
    return self.primal.type_state()
|
|
|
|
def _primal_tangent_shapes_match(primal, tangent):
  # Debug-only invariant check (called under config.enable_checks): a nonzero
  # tangent must match the primal's shape, expected tangent dtype, and —
  # when explicit mesh axes are involved — sharding.
  if type(tangent) is not Zero:
    primal_aval = typeof(primal).strip_weak_type()
    tangent_aval = typeof(tangent).strip_weak_type()
    if not isinstance(primal_aval, core.ShapedArray):
      return  # TODO(mattjj,dougalm)
    assert core.definitely_equal_shape(primal_aval.shape, tangent_aval.shape), (
        primal_aval.shape, tangent_aval.shape)
    expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
    assert expected_tangent_dtype == tangent_aval.dtype, (
        expected_tangent_dtype, tangent_aval.dtype)
    if (not primal_aval.sharding.mesh.empty and
        not tangent_aval.sharding.mesh.empty and
        (primal_aval.sharding.mesh._any_axis_explicit or
         tangent_aval.sharding.mesh._any_axis_explicit)):
      assert primal_aval.sharding == tangent_aval.sharding, (
          primal_aval.sharding, tangent_aval.sharding)
|
|
|
|
# Per-call-primitive hooks that adapt the primitive's params for the
# transformed (jvp / linearize / transpose) computation; looked up via .get()
# so registration is optional.
call_param_updaters: dict[core.Primitive, Callable] = {}
call_linearize_param_updaters: dict[core.Primitive, Callable] = {}
call_transpose_param_updaters: dict[core.Primitive, Callable] = {}
|
|
|
|
# -------------------- Linearize trace --------------------
|
|
|
|
class LinearizeTrace(Trace):
  """Trace that linearizes a computation while evaluating it.

  Primal operations are executed under `parent_trace`, while the tangent
  (linear) part of each operation is staged under `tangent_trace`. With
  `is_vjp=True` this serves as the forward pass of reverse-mode autodiff.
  """
  parent_trace: core.Trace | None   # trace the primal computation runs under
  tangent_trace: core.Trace         # trace the tangent computation is staged under
  is_vjp: bool                      # True when this linearization backs a VJP
  requires_low: bool
  _name_stack_prefix_len: int       # ambient name-stack depth at trace creation

  def __init__(self, parent_trace, tangent_trace, is_vjp):
    super().__init__()
    # This trace's identity (its `tag`) is borrowed from the tangent trace.
    if not hasattr(tangent_trace, "tag"):
      raise RuntimeError("Internal: LinearizeTrace.__init__ requires tangent_trace.tag to be defined.")
    self.parent_trace = parent_trace
    self.tangent_trace = tangent_trace
    self.is_vjp = is_vjp
    self.requires_low = False
    self._name_stack_prefix_len = len(source_info_util.current_name_stack())

  @property
  def tag(self) -> core.TraceTag:
    # Shared with the tangent trace so tracers of either are recognized.
    assert hasattr(self.tangent_trace, "tag")
    return self.tangent_trace.tag

  def _name_stack_suffix(self):
    # Name-stack entries pushed since this trace was created.
    return source_info_util.current_name_stack()[self._name_stack_prefix_len:]

  def to_primal_tangent_pair(self, val):
    """Split `val` into (primal, tangent); constants get a symbolic zero tangent."""
    if isinstance(val, LinearizeTracer) and val._trace.tag is self.tag:
      return (val.primal, val.tangent)
    else:
      tangent_zero = p2tz(val)
      return (val, tangent_zero)

  def process_primitive(self, primitive, tracers, params, /):
    """Linearize one primitive application.

    Runs the primitive's linearization rule (or the JVP-based fallback)
    under the parent trace, then replays the resulting linear function on
    the input tangents under the tangent trace.
    """
    primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
    tangent_nzs = [type(t) is not Zero for t in tangents_in]
    # Fast path: all tangents are symbolic zeros and no refs are involved,
    # so there is nothing to linearize — just run the primal op.
    if (all(type(t) is Zero for t in tangents_in) and
        primitive is not core.ref_p and primitive is not core.empty_ref_p and
        not any(isinstance(typeof(x), AbstractRef) for x in primals_in)):
      avals = tuple(core.typeof(x) for x in primals_in)
      return primitive.bind_with_trace(self.parent_trace, primals_in, avals, params)
    fallback = partial(fallback_linearize_rule, primitive)
    lin = primitive_linearizations.get(primitive, fallback)
    with core.set_current_trace(self.parent_trace):
      primal_out, tangent_nzs_out, residuals, linearized = lin(
          self.is_vjp, tangent_nzs, *primals_in, **params)
    # Replay the linear part under the tangent trace, restoring only the
    # name-stack suffix added since this trace started.
    with (core.set_current_trace(self.tangent_trace),
          source_info_util.set_name_stack(self._name_stack_suffix())):
      tangent_out = linearized(residuals, *tangents_in)
    if primitive.multiple_results:
      return [maybe_linearize_tracer(self, x, nz, t)
              for x, nz, t in zip(primal_out, tangent_nzs_out, tangent_out)]
    else:
      return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out)

  def cur_qdd(self, x):
    # Query quasi-dynamic data on the primal, under the parent trace.
    p, _ = self.to_primal_tangent_pair(x)
    with core.set_current_trace(self.parent_trace):
      return core.cur_qdd(p)

  def process_custom_jvp_call(self, primitive, fun: lu.WrappedFun,
                              jvp: lu.WrappedFun, tracers, /, *,
                              symbolic_zeros: bool):
    """Linearize a custom_jvp call by tracing the user-provided jvp rule."""
    primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
    # All-zero tangents: no differentiation needed; re-bind on the parent.
    if all(type(t) is Zero for t in tangents_in):
      avals = [typeof(x) for x in primals_in]
      return primitive.bind_with_trace(
          self.parent_trace, tuple(primals_in), avals,
          dict(subfuns=(fun, jvp), symbolic_zeros=symbolic_zeros))

    # Adapt the user's jvp (flat outputs, primals then tangents) to the
    # (primals_out, tangents_out) pair convention linearize_from_jvp expects.
    @partial(lu.wrap_init, debug_info=jvp.debug_info)
    def _f_jvp(primals, tangents):
      outs = jvp.call_wrapped(*primals, *tangents)
      primals_out, tangents_out = split_list(outs, [len(outs) // 2])
      return primals_out, tangents_out

    with core.set_current_trace(self.parent_trace):
      # NOTE: local name shadows the module-level `instantiate_zeros` function.
      instantiate_zeros = not symbolic_zeros
      nonzeros_in = [type(t) is not Zero for t in tangents_in]
      primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
          _f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros,
          primals_in, {})

    with core.set_current_trace(self.tangent_trace):
      tangents_out = linearized(residuals, *tangents_in)
      tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out)
    return [maybe_linearize_tracer(self, x, nz, t)
            for x, nz, t in zip(primals_out, tangent_nzs_out, tangents_out)]

  def process_custom_vjp_call(self, primitive, fun, fwd,
                              bwd: lu.WrappedFun, tracers, /, *,
                              out_trees: Callable[[], tuple[PyTreeDef, PyTreeDef, list[int | None]]],
                              symbolic_zeros: bool):
    """Linearize a custom_vjp call: run `fwd` now, stage `bwd` via custom_lin_p."""
    primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
    if all(type(t) is Zero for t in tangents_in):
      avals = [typeof(x) for x in primals_in]
      return primitive.bind_with_trace(
          self.parent_trace, tuple(primals_in), avals,
          dict(subfuns=(fun, fwd, bwd), out_trees=out_trees,
               symbolic_zeros=symbolic_zeros))
    # fwd takes (primal, tangent-is-nonzero) pairs, flattened.
    fwd_in = [(p, type(t) is not Zero) for p, t in zip(primals_in, tangents_in)]
    fwd_in_flat = [x for pair in fwd_in for x in pair]  # flatten
    with core.set_current_trace(self.parent_trace):
      res_and_primals_out = fwd.call_wrapped(*fwd_in_flat)

    # Reconstruct residuals, re-using forwarded inputs where input_fwds says
    # an input is passed through as a residual.
    _, res_tree, input_fwds = out_trees()
    num_res_out = res_tree.num_leaves - sum(f is not None for f in input_fwds)
    res_out, primals_out = split_list(res_and_primals_out, [num_res_out])
    res_out_ = iter(res_out)
    res = [next(res_out_) if f is None else primals_in[f] for f in input_fwds]
    assert next(res_out_, None) is None  # all fwd residuals consumed
    avals_out = [core.typeof(x).to_tangent_aval() for x in primals_out]

    # Stage the backward function into the tangent computation; it is only
    # ever used via its transpose rule (see _custom_lin_transpose).
    in_zeros = [type(t) is Zero for t in tangents_in]
    nz_tangents_in = [t for z, t in zip(in_zeros, tangents_in) if not z]
    with core.set_current_trace(self.tangent_trace):
      tangents_out = custom_lin_p.bind(
          *res, *nz_tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
          out_avals=avals_out, symbolic_zeros=symbolic_zeros, in_zeros=in_zeros)
    tangent_nzs_out = [type(t) is not Zero for t in tangents_out]
    return map(partial(maybe_linearize_tracer, self), primals_out, tangent_nzs_out, tangents_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params, /):
    """Linearize through a call primitive.

    Runs the primal call (which also computes residuals) under the parent
    trace, then binds the staged tangent jaxpr as a second call under the
    tangent trace on residuals, env values, and nonzero input tangents.
    """
    assert call_primitive.multiple_results
    primals, tangents = unzip2(map(self.to_primal_tangent_pair, tracers))
    nzs_in = tuple(type(t) is not Zero for t in tangents)
    f_primal, linearize_outs_thunk = linearize_subtrace(
        f, self.is_vjp, self.tag, nzs_in, f.debug_info)

    avals = [typeof(x) for x in primals]
    all_primal_results = call_primitive.bind_with_trace(
        self.parent_trace, primals, avals, dict(params, subfuns=(f_primal,)))
    residual_avals, nzs_out, lin_jaxpr, env, in_fwd, out_fwd = linearize_outs_thunk()
    # Residuals forwarded from inputs/outputs are not emitted by the call;
    # count only the genuinely new ones, then splice the forwards back in.
    num_res_out = sum(f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd))
    non_fwd_res = all_primal_results[:num_res_out]
    primals_out = all_primal_results[num_res_out:]
    residuals = subs_list2(in_fwd, out_fwd, primals, primals_out, non_fwd_res)
    update_params = call_linearize_param_updaters.get(call_primitive)
    num_new_args = len(residuals) + len(env)
    new_params = (update_params(params, num_new_args, nzs_in)
                  if update_params else params)
    num_residuals = len(residual_avals)

    # Tangent call: evaluate the staged linear jaxpr on
    # (residuals, env, nonzero tangents) under the tangent trace.
    f_tangent = _get_f_tangent(lin_jaxpr, num_residuals)
    nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
    new_params = dict(new_params, subfuns=(lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),))
    args = (*residuals, *env, *nz_tangents_in)
    avals = [typeof(x) for x in args]
    nz_tangents_out = call_primitive.bind_with_trace(
        self.tangent_trace, args, avals, new_params)
    nz_tangents_out_iter = iter(nz_tangents_out)
    tangents_out = [next(nz_tangents_out_iter) if nz else p2tz(primal)
                    for nz, primal in zip(nzs_out, primals_out)]
    return map(partial(maybe_linearize_tracer, self), primals_out, nzs_out, tangents_out)
|
|
|
|
|
|
@weakref_lru_cache
def _get_f_tangent(lin_jaxpr, num_residuals):
  """Return a callable evaluating `lin_jaxpr`, treating the first
  `num_residuals` positional arguments as the jaxpr's constants.

  Cached (weakly, keyed on `lin_jaxpr`) so repeated linearizations of the
  same jaxpr reuse one function object.
  """
  def _eval_lin(*args):
    residuals = args[:num_residuals]
    nz_tangents = args[num_residuals:]
    return core.eval_jaxpr(lin_jaxpr, residuals, *nz_tangents)
  return _eval_lin
|
|
|
|
|
|
def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
  """Wrap `primal` in a LinearizeTracer unless its tangent is symbolically zero.

  The `is_nonzero` flag must agree with the tangent's actual zero-ness.
  """
  if not is_nonzero:
    assert type(tangent) is Zero
    return primal
  assert type(tangent) is not Zero
  return LinearizeTracer(trace, primal, tangent)
|
|
|
|
def fallback_linearize_rule(_prim: core.Primitive,
                            _is_vjp, _nonzeros: Sequence[bool], *primals, **params):
  """Default linearization: derive it by partially evaluating the JVP rule.

  Used for primitives with no entry in `primitive_linearizations`. Raises
  NotImplementedError when the primitive has no JVP rule either.
  """
  jvp = primitive_jvps.get(_prim)
  if jvp is None:
    raise NotImplementedError(f"Differentiation rule for '{_prim}' not implemented")
  debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params)
  wrapped_jvp = lu.wrap_init(jvp, debug_info=debug_jvp)
  return linearize_from_jvp(wrapped_jvp, _prim.multiple_results, _nonzeros,
                            False, False, primals, params)
|
|
|
|
def linearize_from_jvp(jvp: lu.WrappedFun,
                       multiple_results: bool,
                       nonzeros: Sequence[bool],
                       user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
                       primals, params):
  """Derive a linearization by partially evaluating a JVP rule.

  Traces `jvp` with known primals and unknown tangent arguments so the
  primal outputs are computed immediately while the tangent computation is
  captured as a jaxpr for later replay.

  Args:
    jvp: wrapped JVP rule, called as ``jvp(primals, tangents, **params)``.
    multiple_results: whether the rule returns lists of outputs (controls
      list-vs-scalar packing of the results).
    nonzeros: per-input flags, True where the input tangent is nonzero.
    user_facing_symbolic_zeros: present zero tangents to the rule as the
      public `SymbolicZero` type instead of the internal `Zero`.
    instantiate_input_zeros: materialize zero tangents as real zero arrays.
    primals: concrete primal inputs.
    params: primitive params forwarded to the JVP rule.

  Returns:
    ``(out_primals, out_nzs, residuals, linearized)``, where `linearized`
    maps ``(residuals, *tangents)`` to the output tangents.
  """
  current_name_stack = source_info_util.current_name_stack()
  with core.take_current_trace() as parent_trace:
    trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
    tangent_avals = [typeof(p).to_tangent_aval() for p in primals]

    # map tangents with float0 dtype to symbolic zeros
    nonzeros = [nz and not (isinstance(a, core.ShapedArray) and a.dtype == float0)
                for a, nz in zip(tangent_avals, nonzeros)]

    def make_zero(aval):
      # Representation of a zero input tangent, per the caller's flags.
      if instantiate_input_zeros:
        return zeros_like_aval(aval)
      elif user_facing_symbolic_zeros:
        return SymbolicZero(aval)
      else:
        return Zero(aval)

    if user_facing_symbolic_zeros:
      zero_type = SymbolicZero
    else:
      zero_type = Zero

    with core.set_current_trace(trace):
      # Nonzero tangents become unknown tracers; zeros stay symbolic/material.
      tangent_args = [trace.new_arg(pe.PartialVal.unknown(a)) if nz else make_zero(a)
                      for a, nz in zip(tangent_avals, nonzeros)]
      out_primals, out_tangents = jvp.call_wrapped(
          tuple(primals), tuple(tangent_args), **params)

    if not multiple_results:
      out_primals = [out_primals]
      out_tangents = [out_tangents]

    # Primal outputs must be fully known, i.e. independent of the tangents.
    out_primals = [trace.to_jaxpr_tracer(p).pval.get_known() for p in out_primals]
    if any(p is None for p in out_primals):
      # Fixed wording: previously read "a function uses an operation".
      raise ValueError(
          "Linearization failed to produce known values for all output primals. "
          "This is typically caused by attempting to differentiate a function "
          "that uses an operation that does not support reverse-mode autodiff.")

    out_nzs = [type(t) is not zero_type and not trace.to_jaxpr_tracer(t).is_known()
               for t in out_tangents]
    out_tangent_avals = [typeof(p).to_tangent_aval() for p in out_primals]
    out_nz_tracers = [trace.to_jaxpr_tracer(r)
                      for (r, nz) in zip(out_tangents, out_nzs) if nz]
    in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
    jaxpr, out_consts, _ = pe.tracers_to_jaxpr(
        in_tracers, out_nz_tracers, trace.effect_handles,
        jvp.debug_info.with_unknown_names())
    # Drop residual constants the tangent jaxpr does not actually use.
    jaxpr, used_consts, _ = pe.dce_jaxpr_consts(
        jaxpr, [True] * len(jaxpr.outvars),
        [False] * len(jaxpr.constvars) + [True] * len(jaxpr.invars))
    out_consts = [c for used, c in zip(used_consts, out_consts) if used]

  def linearized(residuals, *tangents):
    # Replay the staged tangent jaxpr on the nonzero input tangents.
    nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
    nz_tangents_out = core.eval_jaxpr(jaxpr, residuals, *nz_tangents_in)
    nz_tangents_out_iter = iter(nz_tangents_out)
    all_out_tangents = [next(nz_tangents_out_iter) if nz else Zero(aval)
                        for (aval, nz) in zip(out_tangent_avals, out_nzs)]
    if multiple_results:
      return all_out_tangents
    else:
      out_tangent, = all_out_tangents
      return out_tangent

  if multiple_results:
    return out_primals, out_nzs, out_consts, linearized
  else:
    out_primal, = out_primals
    out_nz, = out_nzs
    return out_primal, out_nz, out_consts, linearized
|
|
|
|
class LinearizeTracer(Tracer[LinearizeTrace]):
  """Tracer pairing a primal value with its (possibly staged) tangent."""
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    # Under jax_enable_checks, verify tangent shape/dtype/sharding agree
    # with the primal.
    if config.enable_checks.value:
      _primal_tangent_shapes_match(primal, tangent)
    super().__init__(trace, typeof(primal))
    self.primal = primal
    self.tangent = tangent

  def _short_repr(self):
    # Show the primal's short form but only the *type* of the tangent.
    pp = lambda x: x._short_repr() if isinstance(x, Tracer) else str(x)
    primal, tangent = pp(self.primal), typeof(self.tangent).str_short(True)
    return f"GradTracer({primal=!s}, typeof(tangent)={tangent!s})"

  def full_lower(self):
    # A tracer with a symbolically-zero tangent is just its primal.
    if type(self.tangent) is Zero:
      return core.full_lower(self.primal)
    else:
      return self

  def to_concrete_value(self):
    # Concreteness is determined by the primal alone.
    return core.to_concrete_value(self.primal)

  def get_referent(self):
    # Forward referent lookup to the primal value.
    return core.get_referent(self.primal)

  def cur_qdd(self):
    # Quasi-dynamic data lives on the primal.
    return core.cur_qdd(self.primal)
|
|
|
|
|
|
# -------------------- Primitives --------------------

# Rule registries, keyed by primitive:
#   primitive_jvps: forward-mode (JVP) rules.
#   primitive_transposes: transpose rules used by the backward pass.
#   primitive_linearizations: direct linearization rules; primitives absent
#     here fall back to tracing their JVP (see fallback_linearize_rule).
primitive_jvps : dict[core.Primitive, Callable] = {}
primitive_transposes: dict[core.Primitive, Callable] = {}
primitive_linearizations : dict[core.Primitive, Callable] = {}
|
|
|
|
def deflinear(primitive, transpose_rule):
  """Register `primitive` as linear: its JVP re-binds it on the tangents,
  and `transpose_rule` (called with the cotangent and params only) handles
  transposition."""
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)
|
|
|
|
def linear_jvp(primitive, primals, tangents, **params):
  """JVP rule for a linear primitive: apply the primitive to the tangents."""
  val_out = primitive.bind(*primals, **params)
  if all(type(tangent) is Zero for tangent in tangents):
    # All tangents symbolically zero -> output tangent(s) are zero too.
    if primitive.multiple_results:
      return val_out, map(p2tz, val_out)
    return val_out, p2tz(val_out)
  else:
    # By linearity, f applied to the tangents *is* the output tangent;
    # mixed zero/nonzero inputs require materializing the zeros first.
    tangents = map(instantiate_zeros, tangents)
    return val_out, primitive.bind(*tangents, **params)
|
|
|
|
def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  """Transpose wrapper for `deflinear` rules; short-circuits zero cotangents.

  For a zero cotangent, every undefined (differentiated) argument gets a
  symbolic zero cotangent and defined arguments get None.
  """
  if type(cotangent) is not Zero:
    return transpose_rule(cotangent, **kwargs)
  cts_in = []
  for arg in args:
    if isinstance(arg, UndefinedPrimal):
      cts_in.append(Zero(arg.aval.to_tangent_aval()))
    else:
      cts_in.append(None)
  return cts_in
|
|
|
|
|
|
def deflinear2(primitive, transpose_rule):
  """Like `deflinear`, but the transpose rule also receives the primal args."""
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule)
|
|
|
|
def linear_transpose2(transpose_rule, cotangent, *args, **kwargs):
  """Transpose wrapper for `deflinear2` rules; short-circuits zero cotangents.

  Unlike `linear_transpose`, the rule also receives the primal arguments.
  """
  if type(cotangent) is not Zero:
    return transpose_rule(cotangent, *args, **kwargs)
  out = []
  for arg in args:
    ct = Zero(arg.aval.to_ct_aval()) if isinstance(arg, UndefinedPrimal) else None
    out.append(ct)
  return out
|
|
|
|
|
|
def defjvp(primitive, *jvprules):
  """Register one JVP rule per positional argument of a single-output
  primitive. Each rule is called as ``rule(tangent, *primals, **params)``;
  pass None for non-differentiable arguments."""
  assert isinstance(primitive, Primitive)
  assert not primitive.multiple_results
  primitive_jvps[primitive] = partial(standard_jvp, jvprules, primitive)
|
|
|
|
|
|
def standard_jvp(jvprules, primitive, primals, tangents, **params):
  """JVP assembled from per-argument rules: sum of each rule's contribution.

  Rules that are None or whose input tangent is a symbolic zero are skipped.
  """
  val_out = primitive.bind(*primals, **params)
  contributions = [rule(t, *primals, **params)
                   for rule, t in zip(jvprules, tangents)
                   if rule is not None and type(t) is not Zero]
  tangent_out = p2tz(val_out)
  for contribution in contributions:
    tangent_out = add_tangents(tangent_out, contribution)
  return val_out, tangent_out
|
|
|
|
def defjvp2(primitive, *jvprules):
  """Like `defjvp`, but each rule also receives the primal output:
  ``rule(tangent, val_out, *primals, **params)``."""
  assert isinstance(primitive, Primitive)
  assert not primitive.multiple_results
  primitive_jvps[primitive] = partial(standard_jvp2, jvprules, primitive)
|
|
|
|
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  """JVP from per-argument rules that also see the primal output.

  Each non-None rule with a nonzero tangent contributes
  ``rule(t, val_out, *primals, **params)``; contributions are summed with
  `add_tangents`, starting from a symbolic zero.
  """
  val_out = primitive.bind(*primals, **params)
  # Build the list directly instead of a generator immediately materialized
  # with list(); identical evaluation order, one less indirection.
  tangents_out = [rule(t, val_out, *primals, **params)
                  for rule, t in zip(jvprules, tangents)
                  if rule is not None and type(t) is not Zero]
  return val_out, functools.reduce(add_tangents, tangents_out, p2tz(val_out))
|
|
|
|
def add_tangents(x, y):
  """Add two tangents, treating symbolic `Zero` as the additive identity."""
  x_is_zero = type(x) is Zero
  if x_is_zero:
    return y
  y_is_zero = type(y) is Zero
  return x if y_is_zero else add_jaxvals(x, y)
|
|
|
|
def defbilinear(prim, lhs_rule, rhs_rule):
  """Register rules for a bilinear primitive.

  The JVP binds `prim` itself with one tangent substituted per term;
  `lhs_rule`/`rhs_rule` transpose with respect to the first/second argument."""
  assert isinstance(prim, Primitive)
  lhs_jvp = lambda g, x, y, **kwargs: prim.bind(g, y, **kwargs)
  rhs_jvp = lambda g, x, y, **kwargs: prim.bind(x, g, **kwargs)
  defjvp(prim, lhs_jvp, rhs_jvp)
  fancy_transposes[prim] = partial(fancy_bilinear_transpose, lhs_rule, rhs_rule)
  # TODO(mattjj,yashkatariya): remove next line if downstream doesnt need it
  primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
|
|
|
|
def fancy_bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  """Accumulator-style transpose for a bilinear primitive.

  Exactly one of `x`, `y` is a GradAccum (the side being differentiated);
  the cotangent contribution is accumulated into it in place. Nothing is
  returned."""
  assert isinstance(x, GradAccum) ^ isinstance(y, GradAccum)
  if isinstance(x, GradAccum):
    # Skip work for symbolic-zero cotangents or discarded (NullAccum) grads.
    if type(cotangent) is not Zero and not isinstance(x, NullAccum):
      x.accum(lhs_rule(cotangent, x, y, **kwargs))
  else:
    if type(cotangent) is not Zero and not isinstance(y, NullAccum):
      y.accum(rhs_rule(cotangent, x, y, **kwargs))
|
|
|
|
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  """Tuple-returning transpose for a bilinear primitive.

  Exactly one of `x`, `y` is an undefined primal; returns a (ct_x, ct_y)
  pair with None in the defined slot and, for a zero cotangent, a symbolic
  zero in the undefined slot (the rule is not called in that case).
  """
  x_is_undef = is_undefined_primal(x)
  assert x_is_undef ^ is_undefined_primal(y)
  ct_is_zero = type(cotangent) is Zero
  if x_is_undef:
    ct_x = (Zero(x.aval.to_ct_aval()) if ct_is_zero
            else lhs_rule(cotangent, x, y, **kwargs))
    return ct_x, None
  ct_y = (Zero(y.aval.to_ct_aval()) if ct_is_zero
          else rhs_rule(cotangent, x, y, **kwargs))
  return None, ct_y
|
|
|
|
def defjvp_zero(primitive):
  """Register `primitive` as having identically-zero derivatives."""
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(zero_jvp, primitive)
|
|
|
|
def zero_jvp(primitive, primals, tangents, **params):
  """JVP for a zero-derivative primitive: output tangent is symbolically zero."""
  r = primitive.bind(*primals, **params)
  return r, p2tz(r)
|
|
|
|
# Transpose of addition: the output cotangent flows unchanged to both addends.
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))
|
|
|
|
|
|
def instantiate_zeros(tangent):
  """Materialize a symbolic `Zero` as a concrete zero array; pass other
  values through unchanged."""
  if type(tangent) is Zero:
    if hasattr(tangent.aval, 'sharding'):
      # TODO(dougalm, yashkatariya): Delete this context manager once we figure
      # out how to ensure jaxpr arguments always have the context mesh.
      with mesh_lib.use_abstract_mesh(tangent.aval.sharding.mesh):
        return zeros_like_aval(tangent.aval)
    return zeros_like_aval(tangent.aval)
  return tangent
|
|
|
|
@lu.transformation_with_aux2
def traceable(f, store, in_tree, *primals_and_tangents):
  """Adapt a jvp-style function `f` to a flat calling convention.

  Flat inputs are unflattened into (primals, tangents); `None` tangents
  encode symbolic zeros in both directions. The output pytree structure is
  stashed in `store` as the auxiliary result."""
  primals, tangents = tree_unflatten(in_tree, primals_and_tangents)
  tangents = [p2tz(p) if t is None else t
              for p, t in zip(primals, tangents)]
  primals_out, tangents_out = f(primals, tangents)
  # Encode symbolic-zero output tangents as None for flattening.
  tangents_out = [None if type(t) is Zero else t for t in tangents_out]
  out_flat, out_tree = tree_flatten((primals_out, tangents_out))
  store.store(out_tree)
  return out_flat
|
|
|
|
def call_transpose_fancy(primitive, cts, *args, call_jaxpr, **params):
  """Transpose a call primitive using accumulator-style ("fancy") transposition.

  The transposed body is wrapped as a subfunction and re-bound through
  `primitive`; afterwards the subcall's cotangent outputs are accumulated
  back into the caller-provided ValAccum slots in `args`."""
  if call_jaxpr.constvars: raise NotImplementedError
  # Split accumulators out of args so only flattenable values cross the call
  # boundary; `specs` records how to reconstitute them inside.
  primals_ctrefs, specs = project_accums(args)
  flat_args, treedef = tree_flatten((primals_ctrefs, cts))
  cell = lambda: None  # mutable holder for the out_tree discovered below

  @partial(lu.wrap_init, debug_info=call_jaxpr.debug_info.with_unknown_names())
  def transposed(*flat_args):
    primals_ctrefs, cts = tree_unflatten(treedef, flat_args)
    args = unproject_accums(specs, primals_ctrefs)
    backward_pass3(call_jaxpr, False, (), args, cts)
    # Freeze local accumulators into concrete cotangent outputs.
    cts_out = [x.freeze() if isinstance(x, ValAccum) else None for x in args]
    cts_out, cell.out_tree = tree_flatten(cts_out) # pyrefly: ignore[missing-attribute]
    return cts_out

  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, [isinstance(x, GradAccum) for x in args],
                           [type(x) is not Zero for x in cts])

  out_flat = primitive.bind(*flat_args, subfuns=(transposed,), **params)
  # Accumulate the subcall's cotangents into the caller's accumulators.
  for x, ct in zip(args, tree_unflatten(cell.out_tree, out_flat)): # pyrefly: ignore[missing-attribute]
    if isinstance(x, ValAccum): x.accum(ct)
fancy_transposes[core.call_p] = partial(call_transpose_fancy, call_p)
|
|
|
|
def _closed_call_transpose(ct, *args, call_jaxpr, **params):
  """Transpose `closed_call` by converting its consts to explicit args and
  deferring to `call_transpose_fancy`."""
  jaxpr_, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  jaxpr_ = pe.convert_constvars_jaxpr(jaxpr_)
  call_transpose_fancy(core.closed_call_p, ct, *consts, *args,
                       call_jaxpr=jaxpr_, **params)
fancy_transposes[core.closed_call_p] = _closed_call_transpose
|
|
|
|
def jvp_jaxpr(jaxpr: core.ClosedJaxpr, nonzeros: Sequence[bool],
              instantiate: bool | Sequence[bool]
              ) -> tuple[core.ClosedJaxpr, list[bool]]:
  """JVP-transform a closed jaxpr.

  Normalizes `instantiate` to one flag per output and dispatches to the
  weakly-cached implementation."""
  inst = instantiate
  if type(inst) is bool:
    # Broadcast a scalar flag across all outputs.
    inst = (inst,) * len(jaxpr.out_avals)
  return _jvp_jaxpr(jaxpr, tuple(nonzeros), tuple(inst))
|
|
|
|
@weakref_lru_cache
def _jvp_jaxpr(jaxpr: core.ClosedJaxpr,
               nonzeros: Sequence[bool], instantiate: Sequence[bool]):
  """Cached worker for `jvp_jaxpr`.

  Traces the jaxpr's JVP into a new jaxpr taking
  (primals..., nonzero tangents...) and returns it together with the thunked
  per-output nonzero-tangent flags."""
  assert len(jaxpr.in_avals) == len(nonzeros)
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr),
                   debug_info=jaxpr.jaxpr.debug_info.with_unknown_names())
  f_jvp, out_nonzeros = f_jvp_traceable(
      jvp(f, instantiate=instantiate, transform_stack=False), nonzeros)
  # Tangent inputs exist only for the declared-nonzero primal inputs.
  tangent_avals = [aval.to_tangent_aval()
                   for aval, nz in zip(jaxpr.in_aval_qdds, nonzeros) if nz]
  avals_in = list(it.chain(jaxpr.in_aval_qdds, tangent_avals))
  jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(
      f_jvp, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
|
|
|
|
@lu.transformation_with_aux2
def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents):
  """Flatten a jvp function to the (primals..., nonzero tangents...) ->
  (primals_out..., nonzero tangents_out...) convention, storing the output
  nonzero flags as the auxiliary result."""
  num_primals = len(nonzeros)
  primals = list(primals_and_nztangents[:num_primals])
  nonzero_tangents = iter(primals_and_nztangents[num_primals:])
  # Re-insert symbolic zeros where the caller declared zero tangents.
  tangents = [next(nonzero_tangents) if nz else p2tz(p)
              for p, nz in zip(primals, nonzeros)]
  primals_out, tangents_out = f(primals, tangents)
  out_nonzeros = [type(t) is not Zero for t in tangents_out]
  nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero]
  store.store(out_nonzeros)
  return list(primals_out) + nonzero_tangents_out
|
|
|
|
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
  """Permute a jvp jaxpr's binders from grouped ([primals..., tangents...])
  to interleaved order, keeping debug info and effects consistent.

  The *_in/*_out arguments are group sizes (counts), not values; see `_perm`."""
  new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
  new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
  # Permute arg names / result paths the same way as the binders they label.
  if jaxpr.jaxpr.debug_info.arg_names is None:
    new_arg_names = None
  else:
    new_arg_names = tuple(_perm(primals_in, tangents_in,
                                jaxpr.jaxpr.debug_info.arg_names))
  if jaxpr.jaxpr.debug_info.result_paths is None:
    new_result_paths = None
  else:
    new_result_paths = tuple(_perm(primals_out, tangents_out,
                                   jaxpr.jaxpr.debug_info.result_paths))
  new_debug_info = jaxpr.jaxpr.debug_info._replace(
      arg_names=new_arg_names, result_paths=new_result_paths)
  constvars = jaxpr.jaxpr.constvars
  # Effects reference binder positions, so renumber them for the new order.
  new_effects = pe._renumber_effects(
      (*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars),
      jaxpr.jaxpr.effects)
  new_jaxpr = jaxpr.jaxpr.replace(
      constvars=constvars, invars=new_invars, outvars=new_outvars,
      effects=new_effects, debug_info=new_debug_info)
  return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
|
|
|
|
def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int],
          lst: Sequence[Any]) -> Sequence[Any]:
  """Reorder `lst` from [all primal groups, all tangent groups] to
  interleaved [p-group 1, t-group 1, p-group 2, t-group 2, ...] order."""
  num_primal = sum(primal_counts)
  primal_part, tangent_part = lst[:num_primal], lst[num_primal:]
  grouped_primals = split_list(primal_part, primal_counts[:-1])
  grouped_tangents = split_list(tangent_part, tangent_counts[:-1])
  return _interleave(grouped_primals, grouped_tangents)
|
|
|
|
def _interleave(xs, ys):
|
|
assert len(xs) == len(ys)
|
|
return [e for pair in zip(xs, ys) for l in pair for e in l]
|
|
|
|
|
|
# custom_lin_p stages a custom_vjp's backward function into the tangent
# computation; it is only ever transposed (see _custom_lin_transpose below),
# never evaluated directly.
custom_lin_p: core.Primitive = core.Primitive('custom_lin')
custom_lin_p.def_abstract_eval(lambda *_, out_avals, **__: out_avals)
custom_lin_p.multiple_results = True

def raise_custom_vjp_error_on_jvp(*_, **__):
  """Impl for custom_lin_p: evaluating it (e.g. under forward-mode autodiff)
  is an error, since custom_vjp only defines a backward rule."""
  raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
                  "function.")
custom_lin_p.def_impl(raise_custom_vjp_error_on_jvp)
|
|
|
|
def _custom_lin_transpose(cts_out, *invals, num_res,
                          bwd: lu.WrappedFun, out_avals,
                          symbolic_zeros, in_zeros):
  """Transpose rule for custom_lin_p: runs the user's bwd function.

  `invals` is (residuals..., nonzero input tangents...). Returns None for
  each residual slot followed by the cotangents for the nonzero tangents."""
  res, _ = split_list(invals, [num_res])
  if symbolic_zeros:
    # User opted into symbolic zeros: present them with the public type.
    cts_out = map(replace_internal_symbolic_zeros, cts_out)
  else:
    cts_out = map(instantiate_zeros, cts_out)
  cts_in = bwd.call_wrapped(*res, *cts_out)
  cts_in = map(replace_rule_output_symbolic_zeros, cts_in)
  # Keep only cotangents for inputs whose tangents were nonzero.
  nz_cts_in, _ = partition_list(in_zeros, cts_in)
  return [None] * num_res + nz_cts_in
primitive_transposes[custom_lin_p] = _custom_lin_transpose
|
|
|
|
def _custom_lin_pp_rule(eqn: core.JaxprEqn, context: core.JaxprPpContext,
                        settings: core.JaxprPpSettings) -> core.pp.Doc:
  """Pretty-print custom_lin eqns compactly: drop out_avals and show only
  the bwd function's name instead of the whole wrapped function."""
  params = dict(eqn.params)
  params.pop("out_avals")
  params["bwd"] = params.pop("bwd").debug_info.func_name
  return core._pp_eqn(eqn.replace(params=params), context, settings)
core.pp_eqn_rules[custom_lin_p] = _custom_lin_pp_rule
|
|
|
|
class CustomVJPException(Exception):
  """Raised when a custom_vjp function is differentiated with respect to a
  value it closes over rather than an explicit argument."""

  # TODO(mattjj): track source provenance on AD tracers, improve error
  _MSG = ("Detected differentiation of a custom_vjp function with respect to "
          "a closed-over value. That isn't supported because the custom VJP "
          "rule only specifies how to differentiate the custom_vjp function "
          "with respect to explicit input parameters. Try passing the "
          "closed-over value into the custom_vjp function as an argument, and "
          "adapting the custom_vjp fwd and bwd rules.")

  def __init__(self):
    super().__init__(self._MSG)
|
|
|
|
# TODO(mattjj): remove this vestigial dict
reducing_transposes: dict[core.Primitive, Callable] = {}  # vestigial; see TODO
|
|
|
|
# TODO(mattjj): remove this old code, used by something downstream
def call_transpose(primitive, params, call_jaxpr: core.Jaxpr, args, ct, _):
  """Legacy transpose for call primitives via `backward_pass`.

  Flattens (consts, args, ct), wraps the backward pass as a subfunction, and
  re-binds it through `primitive`, unflattening the resulting cotangents."""
  if isinstance(call_jaxpr, core.ClosedJaxpr):
    call_jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
  else:
    consts = ()
  all_args, in_treedef = tree_flatten((consts, args, ct))
  fun = lu.hashable_partial(
      lu.wrap_init(backward_pass, debug_info=call_jaxpr.debug_info),
      call_jaxpr, False)
  fun, out_tree = flatten_fun_nokwargs(fun, in_treedef)
  # Let the primitive rewrite its params for the changed argument list.
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    params = update_params(params, map(is_undefined_primal, args),
                           [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(*all_args, **dict(params, subfuns=(fun,)))
  return tree_unflatten(out_tree(), out_flat)
|