142 lines
5.2 KiB
Python
142 lines
5.2 KiB
Python
# Copyright 2026 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from functools import partial
|
|
from typing import Callable
|
|
|
|
from jax._src import core
|
|
from jax._src import api_util
|
|
from jax._src.util import safe_map, safe_zip, unzip2, weakref_lru_cache
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.tree_util import (
|
|
FlatTree, Partial, tree_unflatten, tree_leaves_checked)
|
|
from jax._src import source_info_util
|
|
from jax._src.core import typeof
|
|
|
|
# Rebind the builtin names `map` and `zip` to jax's `safe_map` / `safe_zip`
# for the rest of this module.
map, zip = safe_map, safe_zip
|
|
|
|
# TODO
|
|
# [ ] static_argnums and static_argnames (via FlatTree)
|
|
# [ ] allow NotAvailable sentinels
|
|
# [ ] primal-output-to-residual forwarding
|
|
|
|
def remat_transform(policy, f, *args):
  """Evaluate ``f(*args)`` under a ``RematTrace`` and stage a second copy.

  Every leaf of ``args`` is wrapped in a ``RematTracer`` that pairs the
  original value with a fresh input of a new ``pe.DynamicJaxprTrace``. ``f``
  is then called with those tracers while the ``RematTrace`` is current.

  Args:
    policy: Passed through unchanged to ``RematTrace``.
    f: The callable to run.
    *args: Positional arguments for ``f``.

  Returns:
    A pair ``(out, rem)``. ``out`` is the output of ``f`` rebuilt from the
    value half of each output tracer. ``rem`` is a ``Partial`` of ``f_rem``
    whose first argument is bound to the second result of
    ``DynamicJaxprTrace.to_jaxpr`` (presumably the staged jaxpr's constants),
    each passed through ``reduce_precision``. Calling ``rem(*args)`` evaluates
    the staged jaxpr on the flattened ``args``.
  """
  debug = api_util.debug_info("remat", f, args, {})
  with core.take_current_trace() as outer_trace:
    staging_trace = pe.DynamicJaxprTrace(None)
    remat_trace = RematTrace(
        outer_trace, staging_trace, core.TraceTag(), policy)
    flat_args = FlatTree.flatten_static_argnums_argnames(args, {}, (), ())
    traced_inputs = flat_args.map(
        # pyrefly: ignore[bad-argument-type]
        lambda x: RematTracer(remat_trace, x, staging_trace.new_arg(typeof(x), None)))  # noqa F821
    with core.set_current_trace(remat_trace):
      # `args` is deliberately rebound to the tracer version of the inputs.
      args, kwargs = traced_inputs.unflatten()
      result = f(*args, **kwargs)
      debug = debug.set_result_paths(result)
      flat_result = FlatTree.flatten(result)
      del result, args, kwargs

  # Split each output into its value and its staged counterpart.
  paired_out = flat_result.map(remat_trace.to_val_tracer_pair)
  primal_out, staged_out = paired_out.unzip2()
  staged_jaxpr, res = staging_trace.to_jaxpr(
      list(staged_out), debug, source_info_util.current())
  in_tree = flat_args.tree
  out_tree = primal_out.tree
  del remat_trace, traced_inputs, staged_out

  def f_rem(res, *args):
    # `args` must flatten against the same tree as the original inputs.
    leaves = tree_leaves_checked(in_tree, (args, {}))
    outs = core.eval_jaxpr(staged_jaxpr, res, *leaves)
    return tree_unflatten(out_tree, outs)

  out = primal_out.unflatten()
  return out, Partial(f_rem, map(reduce_precision, res))
|
|
|
|
class RematTracer(core.Tracer['RematTrace']):
  """Tracer that carries a value together with a second, paired tracer.

  ``__init__`` sets two instance attributes:
    val: the wrapped value ``x``.
    tracer: the object passed in as ``jaxpr_tracer``.
  """
  _trace: RematTrace

  def __init__(self, trace, x, jaxpr_tracer):
    # The tracer's type is taken from the wrapped value.
    aval = core.typeof(x)
    super().__init__(trace, aval)
    self.tracer = jaxpr_tracer
    self.val = x
|
|
|
|
class RematTrace(core.Trace):
  """Trace that runs every primitive twice: on a parent trace and a jaxpr trace.

  ``__init__`` stores:
    parent_trace: trace made current while the first copy of each primitive
      is computed.
    jaxpr_trace: trace made current while the second copy is computed.
    tag: identity token used to recognise this trace's own ``RematTracer``s.
    policy: passed as the first argument to every entry of ``rules``.
  """

  def __init__(self, parent_trace, jaxpr_trace, tag, policy):
    super().__init__()
    self.parent_trace = parent_trace
    self.jaxpr_trace = jaxpr_trace
    self.tag = tag
    self.policy = policy
    # NOTE(review): the meaning of `requires_low` is defined by `core.Trace`
    # and is not visible in this file.
    self.requires_low = False

  def to_val_tracer_pair(self, x):
    """Return ``(x.val, x.tracer)`` for our own tracers, else ``(x, x)``.

    A tracer counts as "ours" when it is a ``RematTracer`` whose trace carries
    the very same ``tag`` object as this trace.
    """
    is_ours = isinstance(x, RematTracer) and x._trace.tag is self.tag
    return (x.val, x.tracer) if is_ours else (x, x)

  def process_primitive(self, prim, tracers, params, /):
    """Evaluate ``prim`` on both traces and wrap the results.

    With an entry in ``rules`` the rule is called under the parent trace as
    ``rule(policy, *in_vals, **params)`` and must return ``(out, rem)``;
    ``rem`` is then called under the jaxpr trace. Without an entry,
    ``prim.bind`` is called under each trace in turn.

    Returns:
      One ``RematTracer``, or a list of them when ``prim.multiple_results``
      is truthy.
    """
    primal_ins, rem_ins = unzip2(map(self.to_val_tracer_pair, tracers))
    if prim not in rules:
      # No rule registered: bind the primitive as-is under each trace.
      with core.set_current_trace(self.parent_trace):
        primal_out = prim.bind(*primal_ins, **params)
      with core.set_current_trace(self.jaxpr_trace):
        rem_out = prim.bind(*rem_ins, **params)
    else:
      with core.set_current_trace(self.parent_trace):
        primal_out, rem = rules[prim](self.policy, *primal_ins, **params)
      with core.set_current_trace(self.jaxpr_trace):
        rem_out = rem(*rem_ins)
    if not prim.multiple_results:
      return RematTracer(self, primal_out, rem_out)
    return map(partial(RematTracer, self), primal_out, rem_out)
|
|
|
|
def reduce_precision(x):
|
|
if (h := reduce_precision_handlers.get(type(t := core.typeof(x)))):
|
|
return h(t, x)
|
|
return x
|
|
|
|
rules: dict[core.Primitive, Callable] = {}
|
|
reduce_precision_handlers: dict[type, Callable] = {}
|
|
|
|
|
|
def remat_jaxpr(jaxpr, policy):
  """Call ``_remat_jaxpr`` with ``policy`` converted to a ``frozenset``.

  Args:
    jaxpr: Forwarded unchanged to ``_remat_jaxpr``.
    policy: An iterable; it is frozen before being forwarded (presumably so
      it can serve as a cache key for the cached ``_remat_jaxpr``).

  Returns:
    Whatever ``_remat_jaxpr`` returns.
  """
  frozen_policy = frozenset(policy)
  return _remat_jaxpr(jaxpr, frozen_policy)
|
|
|
|
@weakref_lru_cache
def _remat_jaxpr(jaxpr, policy):
  """Re-trace ``jaxpr`` under a ``RematTrace`` into a forward and a remat jaxpr.

  Two ``pe.DynamicJaxprTrace``s are created: ``fwd_trace`` serves as the
  ``RematTrace``'s parent trace and ``rem_trace`` (built with
  ``auto_dce=True``) as its jaxpr trace. ``jaxpr`` is evaluated once with the
  ``RematTrace`` current.

  Args:
    jaxpr: Object exposing ``.jaxpr``, ``.consts`` and ``.in_aval_qdds``.
    policy: Passed through to ``RematTrace``.

  Returns:
    ``(fwd_jaxpr, rem_jaxpr, num_rem_consts)``. ``fwd_jaxpr`` is a
    ``core.ClosedJaxpr`` whose outputs are the evaluated outputs followed by
    the ``num_rem_consts`` constants of the remat jaxpr. ``rem_jaxpr`` is the
    remat jaxpr after ``pe.convert_constvars_jaxpr`` and ``pe.close_jaxpr``.
  """
  debug = jaxpr.jaxpr.debug_info
  fwd_trace = pe.DynamicJaxprTrace(debug)
  rem_trace = pe.DynamicJaxprTrace(debug, auto_dce=True)
  tag = core.TraceTag()
  remat_trace = RematTrace(fwd_trace, rem_trace, tag, policy)
  rem_trace.tag = tag
  src = source_info_util.current()

  def new_arg(aval):
    # One fresh input per trace, paired up in a single RematTracer.
    fwd_arg = fwd_trace.new_arg(aval, src)
    rem_arg = rem_trace.new_arg(aval, src)
    return RematTracer(remat_trace, fwd_arg, rem_arg)  # noqa: F821

  in_tracers = map(new_arg, jaxpr.in_aval_qdds)
  with core.set_current_trace(remat_trace, check_leaks=True):
    outs = core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *in_tracers)
    out_primals, out_rem = unzip2(map(remat_trace.to_val_tracer_pair, outs))
    # Drop every local reference to the trace before the leak-checked scope
    # exits.
    del remat_trace, outs, new_arg, in_tracers

  rem_open, rem_consts = rem_trace.to_jaxpr(
      out_rem, debug.with_unknown_names(), src)
  rem_converted = pe.convert_constvars_jaxpr(rem_open)
  rem_jaxpr = pe.close_jaxpr(rem_converted)
  rem_trace.invalidate()

  # The remat jaxpr's constants become extra outputs of the forward jaxpr.
  rem_consts = [
      fwd_trace.to_jaxpr_tracer(c, source_info=src) for c in rem_consts]
  fwd_outs = [*out_primals, *rem_consts]
  fwd_open, fwd_consts = fwd_trace.to_jaxpr(
      fwd_outs, debug.with_unknown_names(), src)
  fwd_trace.invalidate()
  fwd_jaxpr = core.ClosedJaxpr(fwd_open, fwd_consts)
  return fwd_jaxpr, rem_jaxpr, len(rem_consts)
|