1026 lines
38 KiB
Python
1026 lines
38 KiB
Python
# Copyright 2025 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from functools import partial, update_wrapper
|
|
import inspect
|
|
import itertools as it
|
|
from typing import Any, Hashable, Callable
|
|
|
|
from jax._src import api
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src import dtypes
|
|
from jax._src import effects
|
|
from jax._src.api_util import resolve_kwargs, infer_argnums_and_argnames
|
|
from jax._src import traceback_util
|
|
from jax._src.core import typeof
|
|
from jax._src.interpreters import ad
|
|
from jax._src.interpreters import batching
|
|
from jax._src.interpreters import partial_eval as pe
|
|
from jax._src.interpreters import remat
|
|
from jax._src.custom_derivatives import (
|
|
CustomVJPPrimal, _temporary_dtype_exception, _check_for_returned_refs)
|
|
from jax._src.errors import UnexpectedTracerError
|
|
from jax._src.state.types import AbstractRef
|
|
from jax._src import ad_util
|
|
from jax._src.util import safe_zip, safe_map, split_list, unzip2
|
|
from jax._src.tree_util import (
|
|
tree_map, tree_flatten, tree_unflatten, tree_leaves, tree_leaves_checked,
|
|
broadcast_prefix, register_static, tree_map_with_path, keystr,
|
|
tracing_registry, FlatTree)
|
|
# Shadow the builtins with length-checking variants; the originals remain
# available as unsafe_map/unsafe_zip for call sites that want builtin behavior.
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

# Loose aliases used for readability in annotations throughout this module;
# pytree-of-X structures aren't expressible statically, hence `Any`.
PyTreeOfAvals = Any  # pytree whose leaves are abstract values
PyTreeDef = Any      # a pytree treedef
LoVal = Any          # a lowered ("lojax") runtime value
HiVal = Any          # a high-level ("hijax") runtime value

# Hide this file's frames from user-facing tracebacks.
traceback_util.register_exclusion(__file__)


# Hijax extension API

# Short aliases re-exported for hijax extension authors.
Ty = core.AbstractValue
LoType = core.AbstractValue
QDD = core.QuasiDynamicData
ShapedArray = core.ShapedArray
|
|
|
|
class HiPrimitive(core.Primitive):
  """Base class for high-level ("hijax") primitives.

  Subclasses must override `abstract_eval` and `to_lojax`, and—if autodiff
  support is wanted—`jvp` (plus `transpose` when the primitive is linear in
  some inputs). Instantiating a subclass registers its `jvp`/`transpose`
  methods as the primitive's AD rules.
  """
  def __init__(self, name):
    self.name = name
    # Register this instance's methods as its own AD rules.
    ad.primitive_jvps[self] = self.jvp
    ad.primitive_transposes[self] = self.transpose

  def is_high(self, *avals, **params) -> bool:
    # Hijax primitives are "high" by definition; lowered via `to_lojax`.
    return True

  def is_effectful(self, params) -> bool:  # pyrefly: ignore[bad-override]
    return False  # default immutable

  # type checking and forward type propagation
  def abstract_eval(self, *arg_avals, **params):
    assert False, "must override"

  # lowering implements the primitive in terms of lojax inputs/outputs/ops
  def to_lojax(self, *lotypes_wrapped_in_hitypes, **params):  # pyrefly: ignore[bad-override]
    assert False, f"must override for {self}"

  # autodiff interface
  def jvp(self, primals, tangents, **params):
    assert False, "must override"

  # transposition is only required if the primitive is linear in some inputs
  def transpose(self, *args, **params):
    assert False, "must override"
|
|
|
|
AxisName = Any


class HiType(core.AbstractValue):
  """Abstract base for immutable hijax types.

  A HiType describes how a high-level value lowers to a list of lojax values
  (and back), plus optional hooks for autodiff, vmap/scan, and shard_map.
  All hooks default to "must override" assertion failures.

  NOTE(review): `MappingSpec` and `HiPspec` in the annotations below are not
  defined in this file's visible portion; annotations are lazy under
  `from __future__ import annotations`, so they need not be in scope here.
  """
  is_high = True
  has_qdd = False  # immutable

  # type equality
  def __hash__(self): assert False, "must override"
  def __eq__(self, other): assert False, "must override"

  # lowering from hijax type to lojax types
  def lo_ty(self) -> list[core.AbstractValue]:
    assert False, "must override"

  # define lowering from hijax value to lojax values and back (like pytrees)
  def lower_val(self, hi_val: HiVal) -> list[LoVal]:  # TODO(mattjj): not lovals
    assert False, "must override"
  def raise_val(self, *lo_vals: LoVal) -> HiVal:
    assert False, "must override"

  # autodiff interface
  def to_tangent_aval(self) -> HiType:
    assert False, "must override"
  def to_ct_aval(self) -> HiType:
    # By default the cotangent type coincides with the tangent type.
    return self.to_tangent_aval()
  # the next two are required if this type is itself a tangent type
  def vspace_zero(self) -> HiVal:
    assert False, "must override"
  def vspace_add(self, x: HiVal, y: HiVal) -> HiVal:
    assert False, "must override"

  # vmap interface (also needed for scan)
  def dec_rank(self, size: int | None, spec: MappingSpec) -> HiType:
    assert False, "must override"
  def inc_rank(self, size: int | None, spec: MappingSpec) -> HiType:
    assert False, "must override"

  # scan interface
  def leading_axis_spec(self) -> MappingSpec:
    assert False, "must override"

  # shard_map interface
  def shard(self, mesh, manual_axes: frozenset, check_vma: bool, spec: HiPspec
            ) -> HiType:
    assert False, "must override"
  def unshard(self, mesh, check_vma: bool, spec: HiPspec) -> HiType:
    assert False, "must override"
  def nospec(self, mesh, check_vma: bool, all_names: tuple[AxisName, ...]
             ) -> HiPspec:
    assert False, "must override"
|
|
|
|
|
|
class MutableHiType(core.AbstractValue):
  """Abstract base for mutable hijax types with quasi-dynamic data (QDD).

  Unlike `HiType`, lowering here is parameterized by a QDD snapshot, since a
  mutable value's lojax representation can change over time (e.g. a Box whose
  contents change type).
  """
  is_high = True
  has_qdd = True  # mutable and potentially type-changing
  is_writer = False
  # Expose the current QDD as a method on traced values of this type.
  type_state = core.aval_method(core.cur_qdd)

  # type equality
  def __hash__(self): assert False, "must override"
  def __eq__(self, other): assert False, "must override"

  # define lowering from (mutable) hijax type to (immutable) lojax types
  def lo_ty_qdd(self, state: QDD, /) -> list[core.AbstractValue]:  # pyrefly: ignore[bad-override]
    assert False, "must override"
  def lo_ty(self):
    assert False, "mutable hitypes should use lo_ty_qdd instead"

  # define lowering from hijax value to lojax values and back, depending on qdd
  def new_from_loval(self, state: QDD, /, *vals: LoVal) -> HiVal:
    assert False, "must override"
  def read_loval(self, state: QDD, val: HiVal, /) -> list[LoVal]:
    assert False, "must override"
  # default implementations of newer apis
  def read_loval_in(self, state, val, /):
    return self.read_loval(state, val)
  def read_loval_out(self, qdd, hi, /):
    return FlatTree.flatten(self.read_loval(qdd, hi))

  # define how to mutate/set the mutable hijax value given immutable lojax vals
  def update_from_loval(self, state: QDD, val: HiVal, /, *lo_vals: LoVal) -> None:
    assert False, "must override"
  # default implementation of newer api
  def update_from_loval2(self, state, val, lo_vals_ft, /) -> None:
    self.update_from_loval(state, val, *lo_vals_ft.unflatten())

  # autodiff interface
  def to_tangent_aval(self) -> HiType:
    assert False, "must override"

  # Subclasses should override if the cotangent type is a function of primal
  # type. For example, CT unreduced = reduced and vice-versa.
  def to_ct_aval(self) -> HiType:
    return self.to_tangent_aval()
|
|
|
|
def register_hitype(val_cls, typeof_fn) -> None:
  """Register a concrete value class as a hijax type.

  `typeof_fn` maps an instance of `val_cls` to its abstract value; the
  canonicalization handler is the identity, i.e. hijax values pass through
  dtype canonicalization unchanged.
  """
  core.pytype_aval_mappings[val_cls] = typeof_fn
  dtypes.canonicalize_value_handlers[val_cls] = lambda x: x
|
|
|
|
def hijax_method(f):
  """Expose `f` as a method on traced values (forwards to core.aval_method)."""
  return core.aval_method(f)
|
|
|
|
|
|
# Boxes

## Box API

def new_box():
  """Create an empty Box (contents None) by binding the new_box primitive."""
  (), treedef = tree_flatten(None)
  return new_box_p.bind(treedef=treedef)
|
|
|
|
def box_get(box):
  """Read a Box's current contents, rebuilt as the stored pytree."""
  # The box's quasi-dynamic data records the leaf avals and tree structure of
  # whatever was last stored.
  tys = core.cur_qdd(box)
  leaf_vals = box_get_p.bind(box, avals=tuple(tys.leaf_avals))
  return tree_unflatten(tys.treedef, leaf_vals)
|
|
|
|
def box_set(box, val):
  """Store an arbitrary pytree `val` into `box`, replacing its contents."""
  leaves, treedef = tree_flatten(val)
  box_set_p.bind(box, *leaves, treedef=treedef)
|
|
|
|
## Box implementation

@dataclass(frozen=True)
class BoxTypeState(QDD):
  """Quasi-dynamic data for a Box: the avals and treedef of its contents."""
  leaf_avals: tuple[core.AbstractValue, ...]  # aval of each stored pytree leaf
  treedef: PyTreeDef                          # structure of the stored pytree

  def to_tangent_qdd(self):
    # Tangent box contents have the tangent types of the primal contents,
    # with the same tree structure.
    leaf_avals = tuple(a.to_tangent_aval() for a in self.leaf_avals)
    return BoxTypeState(leaf_avals, self.treedef)

  def normalize(self):
    # Normalize each leaf aval while preserving structure.
    leaf_types = tuple(a.normalize() for a in self.leaf_avals)
    return BoxTypeState(leaf_types, self.treedef)
|
|
|
|
class BoxTy(MutableHiType):
  """The (singleton) hijax type of Box values.

  All BoxTy instances compare equal; what a particular box holds is tracked
  separately in its `BoxTypeState` QDD.
  """
  has_qdd = True

  # forwarded to value
  get = core.aval_method(box_get)
  set = core.aval_method(box_set)

  # aval interface: hashability and str_short
  def __hash__(self): return hash(BoxTy)
  def __eq__(self, other): return isinstance(other, BoxTy)

  def str_short(self, short_dtypes=False, **_) -> str:  # pyrefly: ignore[bad-override]
    return 'BoxTy'

  # mutable interface
  def lo_ty_qdd(self, box_state):
    # The box lowers to the concatenation of its leaves' lojax types.
    return [lo_ty for t in box_state.leaf_avals for lo_ty in t.lo_ty()]

  def new_from_loval(self, box_state: BoxTypeState, *lo_vals) -> Box:  # pyrefly: ignore[bad-override]
    # Consume lo_vals in order, giving each leaf aval as many lojax values as
    # its own lowering requires, then rebuild the stored pytree.
    lo_vals_ = iter(lo_vals)
    hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty())))  # pyrefly: ignore[missing-attribute]
               for hi_ty in box_state.leaf_avals]
    assert next(lo_vals_, None) is None  # all inputs must be consumed
    return Box._new(tree_unflatten(box_state.treedef, hi_vals))  # will be mutated

  def read_loval(self, box_state: BoxTypeState, box) -> list:  # pyrefly: ignore[bad-override]
    # Flatten the current contents and lower each leaf to its lojax values.
    leaf_vals, treedef = tree_flatten(box_get(box))
    assert treedef == box_state.treedef
    return [lo_val for hi_ty, hi_val in zip(box_state.leaf_avals, leaf_vals)
            for lo_val in hi_ty.lower_val(hi_val)]  # pyrefly: ignore[missing-attribute]

  def update_from_loval(self, box_state: BoxTypeState, box, *lo_vals) -> None:  # pyrefly: ignore[bad-override]
    # Inverse of read_loval: re-raise the lojax values and store them back.
    lo_vals_ = iter(lo_vals)
    hi_vals = [hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty())))  # pyrefly: ignore[missing-attribute]
               for hi_ty in box_state.leaf_avals]
    assert next(lo_vals_, None) is None  # all inputs must be consumed
    box_set(box, tree_unflatten(box_state.treedef, hi_vals))

  def to_tangent_aval(self):
    # A box of tangents is itself just a box.
    return BoxTy()
|
|
|
|
# Override isinstance checks under tracing
|
|
class _BoxMeta(type):
|
|
def __instancecheck__(self, instance):
|
|
return (super().__instancecheck__(instance) or
|
|
isinstance(instance, core.Tracer) and
|
|
isinstance(core.typeof(instance), BoxTy))
|
|
|
|
class Box(metaclass=_BoxMeta):  # noqa: F811
  """A mutable container for an arbitrary pytree, usable under tracing.

  Calling `Box(x)` binds the new_box primitive and then sets `x`, so box
  creation is visible to transformations; `_new` constructs a raw instance
  without binding anything.
  """
  _val = None  # always clobbered by __new__, but pytype likes this

  # We want `Box(x)` to bind a primitive, so we override __new__ and provide a
  # raw `_new` method below.
  def __new__(cls, init_val=None):
    (), treedef = tree_flatten(None)
    box = new_box_p.bind(treedef=treedef)
    box.set(init_val)
    return box

  @classmethod
  def _new(cls, init_val):
    # Bypass __new__ (and hence primitive binding) for internal construction.
    new = super().__new__(cls)
    new._val = init_val
    return new

  def get(self):
    # Read current contents (binds the box_get primitive under tracing).
    return box_get(self)

  def set(self, val):
    # Replace contents (binds the box_set primitive under tracing).
    box_set(self, val)

  def cur_qdd(self):
    return self.type_state()

  @property
  def ty(self):
    # All boxes share the singleton BoxTy abstract type.
    return BoxTy()

  def type_state(self):
    # Snapshot the avals/structure of the current contents as a QDD.
    leaves, treedef = tree_flatten(self._val)
    leaf_avals = tuple(map(core.typeof, leaves))
    return BoxTypeState(leaf_avals, treedef)
|
|
|
|
# Let tracing compute the aval of a concrete Box (always BoxTy()).
register_hitype(Box, lambda b: b.ty)


# Effect attached to the Box primitives (new/get/set).
class BoxEffect(effects.Effect): ...
box_effect = BoxEffect()

# Permit box effects inside control flow and custom-derivative rules.
effects.control_flow_allowed_effects.add_type(BoxEffect)
effects.custom_derivatives_allowed_effects.add_type(BoxEffect)
|
|
|
|
class NewBox(HiPrimitive):
  """Primitive that creates a fresh, empty (None-holding) Box."""

  def is_high(self, *, treedef) -> bool: return True

  def abstract_eval(self, *, treedef):
    # NOTE(review): the `treedef` parameter is immediately shadowed by the
    # treedef of None below — a new box always starts with empty contents;
    # confirm the parameter is intentionally unused.
    leaves, treedef = tree_flatten(None)
    qdd = BoxTypeState(tuple(leaves), treedef)
    return core.AvalQDD(BoxTy(), qdd), {box_effect}

  def to_lojax(_, *, treedef):
    return Box._new(None)

  def jvp(_, primals, tangents, *, treedef):  # pyrefly: ignore[bad-override]
    assert False  # TODO

  def transpose(_, *args, treedef):
    assert False  # TODO
new_box_p = NewBox('new_box')
|
|
|
|
class BoxSet(HiPrimitive):
  """Primitive that replaces a Box's contents with the given pytree leaves."""
  multiple_results = True

  def is_high(self, *leaf_avals, treedef) -> bool: return True

  def abstract_eval(self, box_ty, *leaf_avals, treedef):
    # Setting a box updates its quasi-dynamic data to the new contents' types.
    box_ty.mutable_qdd.update(BoxTypeState(leaf_avals, treedef))
    return [], {box_effect}  # TODO better typechecking...

  def to_lojax(_, box, *leaves, treedef):
    # Concrete implementation: mutate the underlying Python object.
    box._val = tree_unflatten(treedef, leaves)
    return []

  def jvp(_, primals, tangents, *, treedef):  # pyrefly: ignore[bad-override]
    box, *vals = primals
    box_dot, *val_dots = tangents
    # A symbolic-zero box tangent means there is no tangent box to write into.
    if type(box_dot) is ad_util.Zero:
      raise Exception("can't differentiate Box.set operation, "
                      "did you forget jax.lax.stop_gradient?")
    # Set the primal box and the tangent box in parallel.
    box_set_p.bind(box, *vals, treedef=treedef)
    box_set_p.bind(box_dot, *val_dots, treedef=treedef)
    return [], []

  def transpose(_, *args, treedef):
    assert False  # TODO
box_set_p = BoxSet('box_set')
|
|
|
|
|
|
class BoxGet(HiPrimitive):
  """Primitive that reads a Box's contents as a flat list of leaves."""
  multiple_results = True

  def abstract_eval(self, box_ty, *, avals):
    # Output types were captured from the box's QDD at bind time (box_get).
    return avals, {box_effect}

  def to_lojax(_, box, *, avals):
    return tree_leaves(box._val)

  def jvp(_, primals, tangents, *, avals):  # pyrefly: ignore[bad-override]
    (box,), (box_dot,) = primals, tangents
    # Read the primal box and the tangent box; tangent leaves have the
    # tangent avals of the primal leaves.
    return (
        box_get_p.bind(box, avals=avals),
        box_get_p.bind(box_dot, avals=tuple(a.to_tangent_aval() for a in avals))
    )

  def transpose(_, *args):
    assert False  # TODO
box_get_p = BoxGet('box_get')
|
|
|
|
|
|
# === new-style hijax primitive implementation ===

class VJPHiPrimitive:
  """New-style hijax primitive: an instance is both op and rule container.

  Subclass `__init__` must set `in_avals`, `out_aval`, and `params` before
  calling `super().__init__()`. Calling an instance binds the generic
  `call_hi_primitive_p` primitive with the instance itself as a parameter;
  each transformation rule on that primitive dispatches back to the methods
  here (`expand`, `jvp`, `vjp_fwd`/`vjp_bwd`, `batch`, `dce`, `remat`, ...).
  """
  in_avals: tuple[PyTreeOfAvals, ...]
  out_aval: PyTreeOfAvals
  params: dict[str, Hashable]
  effects: frozenset[effects.Effect] = frozenset()

  def __init__(self):
    # Validate that the subclass set the required attributes first.
    if not hasattr(self, 'in_avals'):
      raise AttributeError("subclass __init__ should set `self.in_avals`")
    if not hasattr(self, 'out_aval'):
      raise AttributeError("subclass __init__ should set `self.out_aval`")
    if not hasattr(self, 'params'):
      raise AttributeError("subclass __init__ should set `self.params`")
    # `vjp_bwd` (accumulator API) and `vjp_bwd_retval` (value-returning API)
    # are alternatives; overriding both is ambiguous.
    if (type(self).vjp_bwd is not VJPHiPrimitive.vjp_bwd and
        type(self).vjp_bwd_retval is not VJPHiPrimitive.vjp_bwd_retval):
      raise AttributeError(f"subclass {type(self)} should not override both "
                           "`vjp_bwd` and `vjp_bwd_retval`")
    # Precompute flat avals and treedefs used by the flat calling convention.
    self.in_avals_flat, self.in_tree = tracing_registry.flatten(self.in_avals)
    self.out_avals_flat, self.out_tree = tracing_registry.flatten(self.out_aval)
    # Expose params as attributes for convenience (e.g. self.foo).
    self.__dict__.update(self.params)
    self.check(*self.in_avals)

  # Operation implementation in terms of lojax primitives
  def expand(self, *args):
    raise NotImplementedError(f"subclass {type(self)} must implement `expand`")

  # reverse-mode AD interface
  def vjp_fwd(self, nzs_in, /, *args):
    raise NotImplementedError(f"for grad support, subclass {type(self)} must "
                              "implement `vjp_fwd`")

  def vjp_bwd(self, res, outgrad, /, *arg_accums):
    # Default accumulator-API bwd in terms of the value-returning API: compute
    # the gradients and add each into its accumulator (if one was provided).
    args_grad = self.vjp_bwd_retval(res, outgrad)
    maybe_accum = lambda acc, v: isinstance(acc, ad.GradAccum) and acc.accum(v)
    tree_map(maybe_accum, arg_accums, args_grad)

  def vjp_bwd_retval(self, res, outgrad, /):
    # Classic API: returns values instead of using accumulators
    raise NotImplementedError(f"for grad support, subclass {type(self)} must "
                              "implement `vjp_bwd` or `vjp_bwd_retval`")

  # optional forward-mode AD interfaces
  def jvp(self, primals, tangents):
    raise NotImplementedError(f"for jvp support, subclass {type(self)} must "
                              "implement `jvp`")

  def lin(self, nzs_in, *primals):
    raise NotImplementedError(f"for linearize support, subclass {type(self)} "
                              "must implement `lin` and `linearized`")

  def linearized(self, residuals, *tangents):
    raise NotImplementedError(f"for linearize support, subclass {type(self)} "
                              "must implement `lin` and `linearized`")

  # optional transpose rule, for primitives that are linear in some inputs
  def transpose(self, out_ct, *maybe_accums):
    raise NotImplementedError(f"for transpose support, subclass {type(self)} "
                              "must implement `transpose`")

  # vmap interface
  def batch(self, axis_data, args, dims):
    # Default batching: wrap self in a VmapOf primitive carrying the dims.
    out_dim = self.batch_dim_rule(axis_data, dims)
    return VmapOf(self, axis_data, dims, out_dim)(*args), out_dim

  def batch_dim_rule(self, axis_data, dims, /):
    raise NotImplementedError(f"for vmap support, subclass {type(self)} must "
                              "implement `batch` or `batch_dim_rule`")

  # optional dce control
  def dce(self, used_outs):
    # All-or-nothing by default: drop the whole op only if no output is used.
    used_outs_flat = tree_leaves_checked(self.out_tree, used_outs)
    if not any(used_outs_flat):
      return False, False, None
    else:
      return True, True, self

  # optional remat control
  def remat(self, _policy, *args):
    return self(*args), self  # full remat by default

  def __call__(self, *args):
    # Bind the generic primitive with this instance as the sole parameter.
    args_flat = tree_leaves_checked(self.in_tree, args)
    ans_flat = call_hi_primitive_p.bind(*args_flat, _prim=self)
    return tree_unflatten(self.out_tree, ans_flat)

  def check(self, *arg_tys):
    return  # subclass can optionally override this to add checking logic

  def staging(self, trace, source_info, *args):
    # Hook for custom jaxpr staging; default stages the generic primitive.
    args_flat = tree_leaves_checked(self.in_tree, args)
    ans_flat = trace.default_process_primitive(
        call_hi_primitive_p, args_flat, dict(_prim=self), source_info)
    return tree_unflatten(self.out_tree, ans_flat)

  def __repr__(self):
    return f"{self.__class__.__name__}[{self.params}]"

  # Hash/eq on (class, params, effects) so equal-configured instances unify.
  def __hash__(self):
    return hash((self.__class__.__name__, tuple(self.params.items()), self.effects))

  def __eq__(self, other):
    return (type(self) is type(other) and self.params == other.params
            and self.effects == other.effects)
|
|
|
|
class VmapOf(VJPHiPrimitive):
  """The batched version of a `VJPHiPrimitive`.

  Wraps an inner primitive together with the batching axis data and the
  in/out mapped dimensions; each rule applies `api.vmap` to the inner
  primitive's corresponding rule.

  NOTE(review): `Static` (used in `vjp_fwd`/`vjp_bwd_retval` to smuggle
  residual axes) is defined outside this file's visible portion.
  """
  prim: core.Primitive
  axis_data: batching.AxisData
  in_dims: Any
  out_dim: Any

  def __init__(self, prim, axis_data, in_dims, out_dim):
    # The wrapper's avals are the unmapped (batched) versions of the inner
    # primitive's avals.
    unmap = lambda a, d: core.unmapped_aval(axis_data.size, d, a,
                                            axis_data.explicit_mesh_axis)
    self.in_avals = tree_map(unmap, prim.in_avals, in_dims)
    self.out_aval = tree_map(unmap, prim.out_aval, out_dim)
    self.params = dict(prim=prim, axis_data=axis_data, in_dims=in_dims,
                       out_dim=out_dim)
    super().__init__()

  @property
  def _vmap_params(self):
    # Keyword arguments shared by every api.vmap call below.
    return dict(axis_size=self.axis_data.size, axis_name=self.axis_data.name,
                spmd_axis_name=self.axis_data.spmd_name or self.axis_data.explicit_mesh_axis)

  def expand(self, *args):
    return api.vmap(self.prim.expand, in_axes=self.in_dims, out_axes=self.out_dim,  # pyrefly: ignore[missing-attribute]
                    **self._vmap_params)(*args)

  def jvp(self, primals, tangents):
    # TODO probably gonna get non-pytree-prefix errors because of sym zeros...
    return api.vmap(self.prim.jvp, in_axes=(self.in_dims, self.in_dims),  # pyrefly: ignore[missing-attribute]
                    out_axes=(self.out_dim, self.out_dim),
                    **self._vmap_params)(primals, tangents)

  def vjp_fwd(self, in_nzs, *args):
    # `store` captures the inner rule's output-nonzeros, which aren't
    # pytree outputs and so can't flow through vmap.
    store = lambda: None
    def fwd(*args):
      primal_out, res, *maybe_out_nzs = self.prim.vjp_fwd(in_nzs, *args)  # pyrefly: ignore[missing-attribute]
      store.out_nzs = maybe_out_nzs  # pyrefly: ignore[missing-attribute]
      return primal_out, res
    # Residual axes are inferred by batching and carried alongside via Static.
    (primal_out, res), (_, res_axes) = api.vmap(
        fwd, in_axes=self.in_dims, out_axes=(self.out_dim, batching.infer),
        **self._vmap_params)(*args)
    return primal_out, (res, Static(res_axes)), *store.out_nzs  # pyrefly: ignore[missing-attribute]

  def vjp_bwd_retval(self, res_, g):
    # TODO probably gonna get non-pytree-prefix errors because of sym zeros...
    res, res_axes = res_[0], res_[1].val
    # Unmapped inputs (dim None) get their cotangents summed over the axis.
    in_dims = tree_map(lambda x: batching.sum_axis if x is None else x, self.in_dims,
                       is_leaf=lambda x: x is None)
    # Symbolic-zero cotangents must carry mapped avals through the vmap.
    g = tree_map(partial(map_zero, self.axis_data), self.out_dim, g, is_leaf=lambda x: x is None)
    out = api.vmap(self.prim.vjp_bwd_retval, in_axes=(res_axes, self.out_dim),  # pyrefly: ignore[missing-attribute]
                   out_axes=in_dims, **self._vmap_params, sum_match=True)(res, g)
    # Restore unmapped avals on any symbolic-zero results.
    return tree_map(partial(unmap_zero, self.axis_data), self.in_dims, out, is_leaf=lambda x: x is None)

  def batch_dim_rule(self, axis_data, in_dims):
    # Batching an already-batched primitive: adjust the new dims to account
    # for the axis this wrapper already inserted, then shift the output dim.
    fix = lambda d, d_: d if (d is None or d_ is None) else d - (d_ < d)
    in_dims_ = tree_map(fix, in_dims, self.in_dims, is_leaf=lambda x: x is None)
    out_dim = self.prim.batch_dim_rule(axis_data, in_dims_)  # pyrefly: ignore[missing-attribute]
    return tree_map(lambda d, d_: d + (d_ < d), out_dim, self.out_dim)
|
|
|
|
def map_zero(axis_data, d, ct):
  """Give a symbolic-zero cotangent the mapped aval along dim `d`.

  Non-Zero cotangents are returned unchanged.
  """
  if not isinstance(ct, ad_util.Zero):
    return ct
  mapped_aval = core.mapped_aval(axis_data.size, d, ct.aval)
  return ad_util.Zero(mapped_aval)
|
|
|
|
def unmap_zero(axis_data, d, ct):
  """Give a symbolic-zero cotangent the unmapped aval along dim `d`.

  Non-Zero cotangents are returned unchanged.
  """
  if not isinstance(ct, ad_util.Zero):
    return ct
  unmapped_aval = core.unmapped_aval(axis_data.size, d, ct.aval,
                                     axis_data.explicit_mesh_axis)
  return ad_util.Zero(unmapped_aval)
|
|
|
|
|
|
# Generic primitive through which every VJPHiPrimitive call is staged. The
# instance itself travels in the `_prim` parameter and carries the types,
# effects, and all transformation rules.
call_hi_primitive_p = core.Primitive("call_hi_primitive")
call_hi_primitive_p.multiple_results = True
call_hi_primitive_p.is_high = lambda *args, _prim: True
call_hi_primitive_p.is_effectful = lambda params: bool(params['_prim'].effects)

@call_hi_primitive_p.def_effectful_abstract_eval
def _call_hi_primitive_abstract_eval(*_args, _prim):
  # Output avals and effects were fixed when `_prim` was constructed.
  return _prim.out_avals_flat, _prim.effects
|
|
|
|
def _call_hi_primitive_typecheck(_ctx_factory, *in_atoms_flat, _prim):
  # Jaxpr typechecking rule: input avals must match the primitive's declared
  # flat input avals.
  in_avals = [x.aval for x in in_atoms_flat]
  if not all(map(core.typematch, in_avals, _prim.in_avals_flat)):
    raise TypeError(f"input type mismatch for {_prim}")
  # NOTE(review): called with no argument types here, unlike
  # `self.check(*self.in_avals)` at construction — relies on `check`
  # accepting zero positional args; confirm this is intended.
  _prim.check()
  return _prim.out_avals_flat, _prim.effects
core.custom_typechecks[call_hi_primitive_p] = _call_hi_primitive_typecheck
|
|
|
|
def _call_hi_primitive_staging(trace, source_info, *args_flat, _prim):
  # Staging rule: mark the frame as high and delegate to the primitive's
  # `staging` hook using pytree-structured arguments.
  trace.frame.is_high = True
  args = tree_unflatten(_prim.in_tree, args_flat)
  ans = _prim.staging(trace, source_info, *args)
  return tree_leaves_checked(_prim.out_tree, ans)
pe.custom_staging_rules[call_hi_primitive_p] = _call_hi_primitive_staging
|
|
|
|
def _call_hi_primitive_to_lojax(*args_flat, _prim):
  # Lowering rule: the primitive's `expand` implements it in lojax terms.
  args = tree_unflatten(_prim.in_tree, args_flat)
  ans = _prim.expand(*args)
  return tree_leaves_checked(_prim.out_tree, ans)
call_hi_primitive_p.to_lojax = _call_hi_primitive_to_lojax
|
|
|
|
def _call_hi_primitive_batcher(axis_data, args_flat, dims_flat, _prim):
  # Batching rule: unflatten args and dims, delegate to `_prim.batch`, then
  # re-flatten both the results and their batch dims.
  args = tree_unflatten(_prim.in_tree, args_flat)
  dims = tree_unflatten(_prim.in_tree, dims_flat)
  ans, dims = _prim.batch(axis_data, args, dims)
  ans_flat = tree_leaves_checked(_prim.out_tree, ans)
  dims_flat = _prim.out_tree.flatten_up_to(dims)
  return ans_flat, dims_flat
batching.fancy_primitive_batchers[call_hi_primitive_p] = _call_hi_primitive_batcher
|
|
|
|
def _call_hi_primitive_linearize(is_vjp, nz_in_flat, *args_flat, _prim):
  # Linearization rule: dispatch to the primitive's `vjp_fwd` (reverse mode)
  # or `lin` (forward linearize) and repackage the results flat.
  args = tree_unflatten(_prim.in_tree, args_flat)
  nzs_in = tree_unflatten(_prim.in_tree, nz_in_flat)
  if is_vjp:
    ans, residuals, *maybe_nzs_out = _prim.vjp_fwd(nzs_in, *args)
    # The "linearized" op defers to vjp_bwd transposition later.
    linearized = partial(fake_linear_op, _prim, nz_in_flat)
  else:
    ans, residuals, *maybe_nzs_out = _prim.lin(nzs_in, *args)
    linearized = partial(flatten_user_linearized, _prim)
  ans_flat = tree_leaves_checked(_prim.out_tree, ans)
  # If the rule didn't report output nonzeros, assume all outputs nonzero.
  nzs_out = maybe_nzs_out[0] if maybe_nzs_out else True
  nzs_out_flat = broadcast_prefix(nzs_out, ans)
  return ans_flat, nzs_out_flat, residuals, linearized
ad.primitive_linearizations[call_hi_primitive_p] = _call_hi_primitive_linearize
|
|
|
|
def fake_linear_op(prim, nz_in_flat, rs, *tangents):
  """Bind the opaque linearized primitive over residuals and nonzero tangents.

  Used when only `vjp_fwd`/`vjp_bwd` are available: the tangent map is left
  abstract and transposed later via `vjp_bwd`.
  """
  residuals_flat, residuals_tree = tree_flatten(rs)
  # The recorded nonzero pattern must agree with the actual tangents.
  assert nz_in_flat == [not isinstance(t, ad_util.Zero) for t in tangents]
  nz_tangents = tree_leaves(tangents)
  return call_hi_primitive_linearized_p.bind(
      *residuals_flat, *nz_tangents, residuals_tree=residuals_tree, _prim=prim,
      nz_in_flat=tuple(nz_in_flat))
|
|
|
|
def flatten_user_linearized(prim, residuals, *tangents_flat):
  """Adapt a user `linearized` rule to the flat tangent calling convention."""
  tangents = tree_unflatten(prim.in_tree, tangents_flat)
  tangents_out = prim.linearized(residuals, *tangents)
  tangents_out_flat = tree_leaves_checked(prim.out_tree, tangents_out)
  return tangents_out_flat
|
|
|
|
# Opaque primitive standing in for the linearized (tangent) map of a
# VJPHiPrimitive; it only ever gets transposed (see the rule below).
call_hi_primitive_linearized_p = core.Primitive("call_hi_primitive_linearized")
call_hi_primitive_linearized_p.multiple_results = True
call_hi_primitive_linearized_p.is_high = lambda *args, _prim, **_: True

@call_hi_primitive_linearized_p.def_abstract_eval
def _call_hi_primitive_linearized_abstract_eval(*_args, _prim, residuals_tree, nz_in_flat):
  # Tangent outputs carry the tangent types of the primal output avals.
  return [t.to_tangent_aval() for t in _prim.out_avals_flat]  # TODO(dougalm): handle nonzeros
|
|
|
|
def _call_hi_primitive_linearized_transpose(cts_flat, *args, _prim,
                                            residuals_tree, nz_in_flat):
  # Transpose of the opaque linearized op: split the flat operands back into
  # residuals and cotangent accumulators, then run the user's vjp_bwd.
  residuals_flat, accums_flat = split_list(args, [residuals_tree.num_leaves])
  residuals = tree_unflatten(residuals_tree, residuals_flat)
  # Arguments with symbolic-zero tangents got no accumulator; insert
  # NullAccum placeholders to restore positional alignment.
  accums_flat_ = iter(accums_flat)
  accums_flat = [next(accums_flat_) if nz else ad.NullAccum(aval.to_ct_aval())
                 for aval, nz in zip(_prim.in_avals_flat, nz_in_flat)]
  assert next(accums_flat_, None) is None  # all accumulators consumed
  accums = tree_unflatten(_prim.in_tree, accums_flat)
  cts = tree_unflatten(_prim.out_tree, cts_flat)
  # vjp_bwd writes into the accumulators and must return None.
  none = _prim.vjp_bwd(residuals, cts, *accums)
  assert none is None
ad.fancy_transposes[call_hi_primitive_linearized_p] = _call_hi_primitive_linearized_transpose
|
|
|
|
def _call_hi_primitive_jvp(primals, tangents, *, _prim):
  # JVP rule: delegate to `_prim.jvp` with pytree-structured (co)args.
  primals = tree_unflatten(_prim.in_tree, primals)
  tangents = tree_unflatten(_prim.in_tree, tangents)
  out_primals, out_tangents = _prim.jvp(primals, tangents)
  out_primals_flat = tree_leaves_checked(_prim.out_tree, out_primals)
  out_tangents_flat = _prim.out_tree.flatten_up_to(out_tangents)
  return out_primals_flat, out_tangents_flat
ad.primitive_jvps[call_hi_primitive_p] = _call_hi_primitive_jvp
|
|
|
|
def _call_hi_primitive_transpose(cts_flat, *primals_flat, _prim):
  # Transpose rule for linear primitives: `_prim.transpose` accumulates
  # cotangents itself and must return None.
  cts = tree_unflatten(_prim.out_tree, cts_flat)
  primals = tree_unflatten(_prim.in_tree, primals_flat)
  none = _prim.transpose(cts, *primals)
  assert none is None
ad.fancy_transposes[call_hi_primitive_p] = _call_hi_primitive_transpose
|
|
|
|
def _call_hi_primitive_dce(used_outs_flat, eqn):
  # Dead-code-elimination rule: ask the primitive which inputs/outputs
  # survive, then rewrite the equation accordingly.
  _prim = eqn.params['_prim']
  used_out = tree_unflatten(_prim.out_tree, used_outs_flat)
  used_ins, produced_outs, new_prim = _prim.dce(used_out)
  if new_prim is None:
    # The whole equation is dead: no inputs used, equation removed.
    return [False] * len(eqn.invars), None
  # used_ins/produced_outs may be pytree prefixes; broadcast to full leaves.
  used_ins_flat = broadcast_prefix(used_ins, _prim.in_avals)
  produced_outs_flat = broadcast_prefix(produced_outs, _prim.out_aval)
  new_invars = [x for x, u in zip(eqn.invars, used_ins_flat) if u]
  new_outvars = [v for v, u in zip(eqn.outvars, produced_outs_flat) if u]
  new_eqn = eqn.replace(invars=new_invars, outvars=new_outvars,
                        params=dict(_prim=new_prim))
  return used_ins_flat, new_eqn
pe.dce_rules[call_hi_primitive_p] = _call_hi_primitive_dce
|
|
|
|
# Lowering or re-batching the opaque linearized op is unsupported; register
# the shared error-raising rule so users get an informative message.
call_hi_primitive_linearized_p.to_lojax = ad.raise_custom_vjp_error_on_jvp
batching.fancy_primitive_batchers[call_hi_primitive_linearized_p] = ad.raise_custom_vjp_error_on_jvp


def _call_hi_primitive_remat(policy, *args_flat, _prim):
  # Remat rule: delegate to `_prim.remat`, wrapping the replacement callable
  # so it speaks the flat calling convention.
  args = tree_unflatten(_prim.in_tree, args_flat)
  out, rem_ = _prim.remat(policy, *args)
  def rem(*args_flat):
    args = tree_unflatten(_prim.in_tree, args_flat)
    out = rem_(*args)
    return tree_leaves_checked(_prim.out_tree, out)
  return tree_leaves_checked(_prim.out_tree, out), rem
remat.rules[call_hi_primitive_p] = _call_hi_primitive_remat
|
|
|
|
|
|
class CustomVJPTraced(VJPHiPrimitive):
  """A traced custom_vjp function packaged as a VJPHiPrimitive.

  Wraps a traced primal computation plus user-supplied `fwd`/`bwd` rules,
  with support for symbolic zeros, static arguments, and optional remat of
  the forward pass.

  NOTE(review): `Static` and `OptRemat` are referenced below but defined
  outside this file's visible portion.
  """
  traced: Any          # the traced primal function (has .jaxpr, .out_avals)
  fwd: Any             # user forward rule: (*args) -> (out, residuals)
  bwd: Any             # user backward rule: (*static, res, out_ct) -> in_cts
  symbolic_zeros: Any  # whether rules see CustomVJPPrimal / symbolic zeros
  static_argnums: Any
  opt_remat: bool

  def __init__(self, traced, fwd, bwd, in_avals, sym_zeros, static_argnums, opt_remat):
    self.in_avals = in_avals
    self.out_aval = traced.out_avals
    self.params = dict(traced=traced, fwd=fwd, bwd=bwd, symbolic_zeros=sym_zeros,
                       static_argnums=static_argnums, opt_remat=opt_remat)
    super().__init__()

  def expand(self, *args):
    # Static arguments are baked into `traced` already; drop their slots.
    args = [x for x in args if not isinstance(x, Static)]
    return self.traced(*args)

  def vjp_fwd(self, in_nzs, *args):
    # Unwrap Static placeholders from both the nonzero pattern and the args.
    in_nzs = tuple(x.val if isinstance(x, Static) else x for x in in_nzs)
    args_ = tuple(x.val if isinstance(x, Static) else x for x in args)
    if self.symbolic_zeros:
      # Pair each primal with its perturbation flag for the user rule.
      args_ = tree_map(CustomVJPPrimal, args_, in_nzs)
    out, res = self.fwd(*args_)
    if config.mutable_array_checks.value:
      _check_for_returned_refs(self.fwd, (out, res), "fwd", tree_leaves(args),
                               self.out_tree.num_leaves)
    # The fwd rule's primal output must structurally match the primal fn's.
    if ((tree := tracing_registry.flatten(out)[1]) != self.out_tree):
      raise TypeError(_vjp_primal_fwd_tree_mismatch_err(self, tree))
    tree_map_with_path(_vjp_fwd_aval_mismatch_err, self.out_aval, out)
    if self.symbolic_zeros:
      # Separate (value, perturbed) pairs into outputs and output nonzeros.
      out_pairs_flat = tree_leaves_checked(self.out_tree, out)
      out_flat, out_nzs_flat = unzip2(
          (x.value, x.perturbed) if isinstance(x, CustomVJPPrimal) else
          (x, True) for x in out_pairs_flat)
      out_nzs = tree_unflatten(self.out_tree, out_nzs_flat)
      out = tree_unflatten(self.out_tree, out_flat)
      return out, res, out_nzs
    else:
      return out, res

  def vjp_bwd_retval(self, res, out_ct):
    static_args = tuple(x.val for x in self.in_avals if isinstance(x, Static))
    in_avals_ = tuple(x for x in self.in_avals if not isinstance(x, Static))
    leaf = lambda x: isinstance(x, ad_util.Zero)
    # Present cotangents to the user rule either as symbolic zeros or as
    # instantiated (concrete zero) values.
    if self.symbolic_zeros:
      out_ct = tree_map(ad_util.replace_internal_symbolic_zeros, out_ct, is_leaf=leaf)
    else:
      out_ct = tree_map(ad_util.instantiate, out_ct, is_leaf=leaf)
    in_cts = self.bwd(*static_args, res, out_ct)
    if isinstance(in_cts, list):
      in_cts = tuple(in_cts)
    if not isinstance(in_cts, tuple):
      raise TypeError(f"Custom VJP bwd rule {self.bwd} must produce a tuple "
                      f"but got {type(in_cts)}.")
    # One cotangent per non-static primal argument.
    if len(in_cts) != len(self.in_tree.children()) - len(self.static_argnums):
      raise ValueError(f"Custom VJP bwd rule {self.bwd} must produce a tuple "
                       "of length equal to the primal args tuple, but got "
                       f"length {len(in_cts)}")
    # None entries become symbolic zeros of the matching cotangent aval.
    in_cts = broadcast_prefix(in_cts, in_avals_, is_leaf=lambda x: x is None)
    in_cts = tree_unflatten(self.in_tree, map(_replace_none, self.in_avals_flat, in_cts))
    tree_map_with_path(_vjp_bwd_aval_mismatch_err, self.in_avals, in_cts)
    if self.symbolic_zeros:
      in_cts = tree_map(ad_util.replace_rule_output_symbolic_zeros, in_cts)
    return in_cts

  def jvp(self, primals, tangents):
    if self.symbolic_zeros: raise NotImplementedError
    zero = lambda x: isinstance(x, ad_util.Zero)
    tangents = tree_map(ad_util.instantiate, tangents, is_leaf=zero)
    if self.opt_remat:
      # Stage the fwd pass so remat can optimize residual recomputation.
      fwd_traced = api.jit(partial(self.vjp_fwd, (True,) * len(primals))).trace(*primals)
      primals_out, residuals = OptRemat(self, fwd_traced)(*primals)
    else:
      primals_out, residuals, *_ = self.vjp_fwd((True,) * len(primals), *primals)
    # Tangent map stays opaque; it will be transposed via vjp_bwd.
    tangents_out_flat = fake_linear_op(self, [True] * len(tangents), residuals, *tangents)
    tangents_out = tree_unflatten(self.out_tree, tangents_out_flat)
    return primals_out, tangents_out

  def batch_dim_rule(self, axis_data, in_dims):
    # Infer output batch dims by batching the primal jaxpr.
    in_dims_flat = self.in_tree.flatten_up_to(in_dims)
    _, out_dims = batching.batch_jaxpr2(self.traced.jaxpr, axis_data, tuple(in_dims_flat))
    return tree_unflatten(self.out_tree, out_dims)

  def check(self, *_):
    # Reject effects that custom-derivative rules can't carry.
    effs = self.traced.jaxpr.effects
    disallowed = effects.custom_derivatives_allowed_effects.filter_not_in(effs)
    if disallowed:
      raise NotImplementedError(f'Effects not supported in `custom_jvp`: {disallowed}')
|
|
|
|
def _vjp_primal_fwd_tree_mismatch_err(self, tree):
  """Build the error message for a fwd rule whose first output's pytree
  structure (`tree`) doesn't match the primal function's output structure
  (`self.out_tree`). Returns the message string; the caller raises."""
  return (f"Custom VJP fwd rule {self.fwd.__name__} for function {self.traced.fun_name} "
          "must produce a pair (list or tuple of length two) where the first "
          "element represents the primal output "
          "(equal to the output of the custom_vjp-decorated function "
          f"{self.traced.fun_name}) and the "
          "second element represents residuals (i.e. values stored from the "
          "forward pass for use on the backward pass), but "
          f"instead the fwd rule output's first element had container/pytree "
          "structure:\n"
          f""" {str(tree ).replace("'", "")}\n"""
          f"while the custom_vjp-decorated function {self.traced.fun_name} had output "
          "container/pytree structure:\n"
          f""" {str(self.out_tree).replace("'", "")}.""")
|
def _vjp_fwd_aval_mismatch_err(path, primal_aval, fwd_val):
|
|
if not core.typematch(ty := typeof(fwd_val), primal_aval):
|
|
raise TypeError(f"at {keystr(path)}, got fwd output type {ty.str_short()} "
|
|
f"which doesn't match primal output type {primal_aval.str_short()}")
|
|
|
|
def _vjp_bwd_aval_mismatch_err(path, primal_aval, ct):
  """Raise ValueError if a bwd-rule cotangent's aval doesn't match the
  expected cotangent type for the corresponding primal input.

  No-ops when bwd checks are disabled, when the cotangent is a symbolic
  Zero, when the types match (or fall under the temporary dtype exception),
  or when the expected dtype is float0.
  """
  if config.disable_bwd_checks.value:
    return
  if isinstance(ct, ad_util.Zero):
    return
  if isinstance(primal_aval, AbstractRef):
    # Compare against the value inside the ref, not the ref itself.
    primal_aval = primal_aval.inner_aval
  expected = primal_aval.to_ct_aval()
  if isinstance(ct, ad_util.SymbolicZero):
    ct_aval = ct.aval
  else:
    ct_aval = typeof(ct)
  # Equivalent to the single compound condition: raise only when none of
  # the accepting cases hold.
  if core.typematch(expected, ct_aval):
    return
  if _temporary_dtype_exception(expected, ct_aval):
    return
  if getattr(expected, 'dtype', None) is dtypes.float0:
    return
  result = f"at output{keystr(path)} " if path else ""
  raise ValueError(
      f"{result}the bwd rule produced an output of type {ct_aval.str_short()}"
      f" which doesn't match expected type {expected.str_short()}")
|
def _replace_none(primal_in_aval, maybe_ct):
|
|
if maybe_ct is None:
|
|
return ad_util.Zero(primal_in_aval.to_ct_aval())
|
|
else:
|
|
return maybe_ct
|
|
|
|
class custom_vjp3:
  """Decorator object giving a function `f` user-defined VJP rules.

  Rules are registered via `defvjp`; calling the decorated function traces
  `f` under jit (with nondiff args static) and binds a CustomVJPTraced
  primitive carrying the fwd/bwd rules.
  """

  # Populated by defvjp(); None until then so __call__ can fail clearly.
  fwd: Callable | None = None
  bwd: Callable | None = None
  symz: bool = False        # pass symbolic zeros through to the rules
  opt_remat: bool = False   # trace fwd under jit for residual-DCE (OptRemat)

  def __init__(self, f, nondiff_argnums=(), nondiff_argnames=()):
    self.f = f
    # Argument positions treated as static/nondifferentiable (by index or name).
    self.static_argnums = _set_up_nondiff(f, nondiff_argnums, nondiff_argnames)
    update_wrapper(self, f)

  def defvjp(self, fwd, bwd, *, symbolic_zeros=False, optimize_remat=False):
    """Register the forward and backward rules; returns self for chaining."""
    self.fwd = fwd
    self.bwd = bwd
    self.symz = symbolic_zeros
    self.opt_remat = optimize_remat
    return self

  def __call__(self, *args, **kwargs):
    if not self.fwd or not self.bwd:
      msg = f"No VJP defined for custom_vjp function {self.f.__name__} using defvjp."
      raise AttributeError(msg)

    args = resolve_kwargs(self.f, args, kwargs)
    # Nondiff args must be concrete values, not tracers from an outer trace.
    if any(isinstance(args[i], core.Tracer) for i in self.static_argnums):
      raise UnexpectedTracerError("custom_vjp inputs marked with nondiff_argnums "
                                  "must be static, not Tracers")
    traced = api.jit(self.f, static_argnums=(*self.static_argnums,)).trace(*args)
    # Closed-over tracers can't be handled by the custom rules; reject them.
    if any(isinstance(x, core.Tracer) for x in traced._consts):
      t = next(x for x in traced._consts if isinstance(x, core.Tracer))
      raise UnexpectedTracerError(
          f"custom_vjp-decorated function {self.f} closed over a {type(t).__name__} "
          f"of type {t.aval.str_short()}, but custom_vjp functions can't close "
          f"over Tracers. Rewrite {self.f} to take it as an explicit input.")
      # NOTE(review): unreachable placeholder left after the raise above.
      raise Exception  # TODO(mattjj):error tracer type, value type, primal name
    # Wrap static args so they travel as pytree metadata, not traced values.
    args = tuple(Static(x) if i in self.static_argnums else x for i, x in enumerate(args))
    in_avals = tree_map(typeof, args)
    prim = CustomVJPTraced(traced, self.fwd, self.bwd, in_avals, self.symz,
                           self.static_argnums, self.opt_remat)
    return prim(*args)

  # Forward vmap/transpose rule registration to the wrapped function.
  def def_vmap(self, rule, /): return self.f.def_vmap(rule)

  def def_transpose(self, rule, /): return self.f.def_transpose(rule)
|
class OptRemat(VJPHiPrimitive):
  """Primitive wrapping a traced fwd rule so that DCE can drop residual
  computation (or undo AD entirely) when outputs are unused."""

  orig: CustomVJPTraced  # the original primitive; used to undo AD under DCE
  traced_fwd: Any        # jitted trace of the fwd rule

  def __init__(self, orig, traced_fwd):
    self.in_avals = orig.in_avals
    self.out_aval = traced_fwd.out_avals
    self.params = dict(orig=orig, traced_fwd=traced_fwd)
    super().__init__()

  def expand(self, *primals):
    # Run the traced fwd rule: produces (primal outputs, residuals).
    return self.traced_fwd(*primals)

  def dce(self, used_outs):
    """DCE rule: (keep-inputs, per-output-group usage, replacement primitive)."""
    used_primals, used_res = used_outs
    if any(tree_leaves(used_res)):
      return True, (True, True), self  # if any res used, no dce at all
    elif any(tree_leaves(used_primals)):
      return True, (True, False), self.orig  # if only primals used, undo AD
    else:
      return False, (False, False), None  # nothing used: drop entirely

  # TODO(mattjj): jvp and transpose? does anyone rely on them?
|
def _set_up_nondiff(f, argnums_, argnames) -> frozenset[int]:
|
|
argnums = set(argnums_)
|
|
if argnames:
|
|
sig = inspect.signature(f) # needed for static_argnames
|
|
argnums |= set(infer_argnums_and_argnames(sig, None, argnames)[0])
|
|
return frozenset(argnums)
|
|
|
|
@register_static
@dataclass(frozen=True)
class Static:
  """Wrapper marking a value as static: registered via `register_static` so
  the wrapped value is carried as pytree metadata rather than a leaf."""
  val: Any  # the wrapped static value; compared/hashed via dataclass eq
|
class MappingSpec: pass  # empty marker/base class; no behavior defined here
|
class HiPspec:
  """Abstract base for hi-level partition specs; subclasses must override
  all three conversion methods.

  Previously the stubs used `assert False`, which is stripped under
  `python -O` (the methods would then silently return None); raising
  NotImplementedError is robust under optimized mode.
  """

  def to_lo(self) -> HiPspec:
    raise NotImplementedError("subclasses must override to_lo")

  def to_tangent_spec(self) -> HiPspec:
    raise NotImplementedError("subclasses must override to_tangent_spec")

  def to_ct_spec(self) -> HiPspec:
    raise NotImplementedError("subclasses must override to_ct_spec")
|
# Logs


# Log operations reuse the existing box effect as their effect token
# (used in the abstract_eval rules of the log primitives below).
log_effect = box_effect
|
def log_extend(log, dct):
  """Extend `log` with the entries of dict `dct` by binding the
  log_extend primitive on the flattened values."""
  flat_vals, treedef = tree_flatten(dct)
  log_extend_p.bind(log, *flat_vals, treedef=treedef)
|
def log_append(log, key, val):
  """Append a single value under `key` (a one-element log_extend)."""
  log_extend(log, {key: [val]})
|
def log_read(log):
  """Read out the log's accumulated contents via the log_read primitive."""
  return log_read_p.bind(log)
|
class _LogMeta(type):
  """Metaclass making `isinstance(x, Log)` also accept tracers whose aval
  is a LogTy (i.e. traced Log values)."""

  def __instancecheck__(self, instance):
    if super().__instancecheck__(instance):
      return True
    return (isinstance(instance, core.Tracer) and
            isinstance(core.typeof(instance), LogTy))
|
class Log(metaclass=_LogMeta):  # noqa: F811
  """Mutable append-only log of keyed pytree values.

  `Log()` binds the new_log primitive (so construction is traceable);
  `Log._new()` builds a concrete instance directly.
  """

  _dct: dict  # dict[str, list[PyTree[Array]]]

  def __new__(cls):
    # Construction goes through the primitive, not object.__new__.
    return new_log_p.bind()

  @classmethod
  def _new(cls):
    # Concrete constructor used by lowering rules; bypasses __new__'s bind.
    new = super().__new__(cls)
    new._dct = {}
    return new

  def cur_qdd(self):
    # No quasi-dynamic data: always the empty tuple.
    return ()

  # Method aliases for the module-level log operations.
  append = log_append
  extend = log_extend
  read = log_read
|
class LogTy(MutableHiType):
  """Abstract type (aval) for Log values.

  All LogTy instances are interchangeable (hash/eq are per-type). The
  `*_loval*` methods convert between a Log and its flattened low-level
  representation; `qdd` is always the empty tuple (checked by `() = qdd`).
  """

  has_qdd = True
  is_writer = True

  # Expose the log operations as methods on traced Log values.
  append = core.aval_method(log_append)
  extend = core.aval_method(log_extend)
  read = core.aval_method(log_read)

  def __hash__(self): return hash(LogTy)

  def __eq__(self, other): return isinstance(other, LogTy)

  def str_short(self, short_dtypes=False, **_) -> str:  # pyrefly: ignore[bad-override]
    return 'Log'

  def to_tangent_aval(self):
    # A Log's tangent type is again a Log.
    return LogTy()

  def read_loval_in(self, qdd, log):
    () = qdd  # asserts qdd is the empty tuple
    # Inputs contribute no low-level values.
    return []

  def read_loval_out(self, qdd, log):
    () = qdd
    # Outputs carry the log's contents as a flattened tree.
    return FlatTree.flatten(log._dct)

  def new_from_loval(self, qdd):  # pyrefly: ignore[bad-override]
    () = qdd
    return Log._new()

  def update_from_loval2(self, qdd, log: Log, lo_ft) -> None:
    """Merge flattened updates back into the concrete log's dict."""
    () = qdd
    updates = lo_ft.unflatten()
    for k, v in updates.items():
      log._dct.setdefault(k, []).extend(v)
|
# Register Log as a high-level JAX type whose aval is always LogTy.
register_hitype(Log, lambda _: LogTy())
|
class LogExtend(HiPrimitive):
  """Effectful primitive that merges a dict of values into a Log."""

  multiple_results = True  # no results
  is_effectful = lambda *_, **__: True  # always treated as effectful

  def abstract_eval(self, log_ty, *val_tys, treedef):
    # No outputs; carries the log effect.
    return [], {log_effect}

  def to_lojax(_, log, *vals, treedef):
    # Concrete lowering: rebuild the dict and extend the log in place.
    updates = tree_unflatten(treedef, vals)
    for k, v in updates.items():
      log._dct.setdefault(k, []).extend(v)
    return []

log_extend_p = LogExtend('log_extend')
|
class NewLog(HiPrimitive):
  """Effectful primitive that creates a fresh, empty Log."""

  def is_high(self) -> bool: return True

  def abstract_eval(self):
    ty = LogTy()
    # Empty () quasi-dynamic data; creation carries the log effect.
    return core.AvalQDD(ty, ()), {log_effect}  # pyrefly: ignore[bad-argument-type]

  def to_lojax(_):
    # Concrete lowering: build a real Log instance.
    return Log._new()

new_log_p = NewLog('new_log')
|
def new_log():
  """Create a fresh, empty Log by binding the new_log primitive."""
  return new_log_p.bind()
|
class ReadLog(HiPrimitive):
  """Primitive that reads a Log's accumulated entries as flat values."""

  multiple_results = True

  def is_high(self, _) -> bool: return True

  def abstract_eval(self, log_qdd):
    # Placeholder: no abstract-eval rule yet, so log_read only works via
    # the concrete to_lojax lowering.  TODO: raise a specific error type.
    raise Exception

  def to_lojax(_, log):
    # Concrete lowering: flatten the log's dict into a list of values.
    return list(FlatTree.flatten(log._dct))

log_read_p = ReadLog('log_read')
|