This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,23 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_values
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
from jax._src.pallas.fuser.custom_fusion_lib import custom_fusion as custom_fusion
from jax._src.pallas.fuser.fusible import fusible as fusible
from jax._src.pallas.fuser.fusion import Fusion as Fusion
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,82 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for evaluating functions under certain constraints."""
import dataclasses
from typing import Any
from jax import lax
from jax._src import core
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util
from jax._src.pallas.fuser import fuser_utils
@dataclasses.dataclass
class CustomEvaluateSettings:
  """Settings controlling which primitives `_custom_evaluate_jaxpr` permits."""
  # When False, encountering `lax.transpose_p` raises a ValueError during
  # evaluation.
  allow_transpose: bool = True
def evaluate(f, *, allow_transpose: bool = True):
  """Returns a wrapper around `f` that evaluates it under constraints.

  The wrapper traces `f` to a jaxpr and interprets it eagerly via
  `_custom_evaluate_jaxpr`, which rejects disallowed higher-order primitives
  (and transposes, when `allow_transpose=False`).
  """
  settings = CustomEvaluateSettings(allow_transpose=allow_transpose)

  def evaluate_fn(*args, **kwargs):
    jaxpr, consts, _, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs)
    leaves = tree_util.tree_leaves(args)
    flat_results = _custom_evaluate_jaxpr(settings, jaxpr, consts, *leaves)
    return tree_util.tree_unflatten(out_tree, flat_results)

  return evaluate_fn
# Disallow most higher-order primitives for now.
# NOTE(review): only these three are rejected; other higher-order primitives
# fall through to a plain `bind` in `_custom_evaluate_jaxpr`.
disallowed_primitives = {lax.scan_p, lax.while_p, lax.cond_p}
def _custom_evaluate_jaxpr(
    settings: CustomEvaluateSettings, jaxpr: core.Jaxpr, consts, *args
):
  """Eagerly interprets `jaxpr`, enforcing the constraints in `settings`.

  Raises:
    NotImplementedError: if the jaxpr contains scan/while/cond.
    ValueError: if it contains a transpose and `settings.allow_transpose`
      is False.
  """
  def read(v: core.Atom) -> Any:
    return v.val if isinstance(v, core.Literal) else env[v]

  def write(v: core.Var, val: Any) -> None:
    env[v] = val

  env: dict[core.Var, Any] = {}
  util.safe_map(write, jaxpr.constvars, consts)
  util.safe_map(write, jaxpr.invars, args)
  # Track last use of each var so dead values can be dropped eagerly below.
  lu = core.last_used(jaxpr)
  for eqn in jaxpr.eqns:
    bind_params = eqn.primitive.get_bind_params(eqn.params)
    if eqn.primitive in disallowed_primitives:
      raise NotImplementedError(f'Primitive {eqn.primitive} not supported.')
    if not settings.allow_transpose and eqn.primitive is lax.transpose_p:
      raise ValueError('Transpose not allowed.')
    # Preserve the equation's original name stack / traceback while binding.
    name_stack = (
        source_info_util.current_name_stack() + eqn.source_info.name_stack
    )
    traceback = eqn.source_info.traceback
    with source_info_util.user_context(
        traceback, name_stack=name_stack
    ), eqn.ctx.manager:
      ans = eqn.primitive.bind(
          *util.safe_map(read, eqn.invars), **bind_params
      )
    if eqn.primitive.multiple_results:
      util.safe_map(write, eqn.outvars, ans)
    else:
      write(eqn.outvars[0], ans)
    core.clean_up_dead_vars(eqn, env, lu)
  return util.safe_map(read, jaxpr.outvars)
@@ -0,0 +1,264 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable, Sequence
import dataclasses
import functools
from typing import Any, Protocol
from jax._src import api_util
from jax._src import core
from jax._src import custom_api_util
from jax._src import linear_util as lu
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas.mosaic import lowering as mosaic_lowering
from jax._src.pallas import core as pallas_core
from jax._src.pallas.fuser import block_spec as block_spec_lib
# Primitive staged out by the `custom_fusion` decorator; carries both the
# plain jaxpr and (optionally) a Pallas-specific jaxpr as params.
custom_fusion_p = core.Primitive('custom_fusion')
custom_fusion_p.multiple_results = True

# Maps the output block-index transforms to the input transforms.
CustomPullBlockSpecRuleFn = Callable[[tuple[block_spec_lib.BlockIndexTransform, ...]],
                                     Sequence[block_spec_lib.BlockIndexTransform]]
# Maps the input block specs to the output block specs.
CustomPushBlockSpecRuleFn = Callable[[tuple[pallas_core.BlockSpec, ...]],
                                     tuple[pallas_core.BlockSpec, ...]]
@dataclasses.dataclass(frozen=True)
class CustomEvalContext:
  """Context handed to a `custom_fusion` eval rule inside a kernel."""
  # Block specs chosen for each output of the custom fusion.
  out_block_specs: tuple[pallas_core.BlockSpec, ...]
  # Current block indices for each output.
  out_block_indices: tuple[Any, ...]
class CustomEvalRuleFn(Protocol):
  """Signature of a user-supplied blockwise evaluation rule."""

  def __call__(
      self,
      ctx: CustomEvalContext,
      *args: Any,
  ) -> Sequence[Any]:
    ...
@custom_api_util.register_custom_decorator_type
class custom_fusion:
  """Decorator making a function usable as an opaque unit inside fusions.

  Outside the fuser the decorated function behaves normally (its jaxpr is
  evaluated directly).  Inside the fuser, user-registered rules
  (`def_eval_rule`, `def_pull_block_spec`, `def_push_block_spec`) control how
  it is evaluated blockwise, and an optional `def_pallas_impl` provides an
  alternative body to lower when running inside a Pallas kernel.
  """
  # The original user function.
  fun: Callable[..., Any]
  # Blockwise evaluation rule; must be registered before calling.
  eval_rule: CustomEvalRuleFn | None = None
  # Rule for pulling block specs backwards through this fusion; required.
  pull_block_spec_rule: CustomPullBlockSpecRuleFn | None = None
  # Optional if this custom_fusion is only used as an input fusion.
  push_block_spec_rule: CustomPushBlockSpecRuleFn | None = None
  # Optional alternative implementation to use instead of `fun` for when this
  # custom fusion is run inside a Pallas kernel.
  pallas_impl: Callable[..., Any] | None = None

  def __init__(self, fun: Callable[..., Any]):
    functools.update_wrapper(self, fun)
    self.fun = fun

  def def_pallas_impl(self, pallas_impl):
    """Registers an alternative implementation used inside Pallas kernels."""
    self.pallas_impl = pallas_impl
    return pallas_impl

  def def_pull_block_spec(
      self, pull_block_spec_rule: CustomPullBlockSpecRuleFn):
    """Registers the rule for pulling block specs through this fusion."""
    self.pull_block_spec_rule = pull_block_spec_rule
    return pull_block_spec_rule

  def def_push_block_spec(
      self, push_block_spec_rule: CustomPushBlockSpecRuleFn):
    """Registers the rule for pushing block specs through this fusion."""
    self.push_block_spec_rule = push_block_spec_rule
    return push_block_spec_rule

  def def_eval_rule(self, eval_rule: CustomEvalRuleFn):
    """Registers the blockwise evaluation rule."""
    self.eval_rule = eval_rule
    return eval_rule

  @functools.partial(api_boundary,
                     repro_api_name="jax.pallas.custom_fusion.__call__")
  def __call__(self, *args, **kwargs):
    """Traces `fun` (and `pallas_impl`, if any) and binds `custom_fusion_p`."""
    debug_fun = api_util.debug_info("custom_fusion fun", self.fun, args, kwargs)
    # TODO(jburnim): Better error messages here.
    assert self.eval_rule is not None
    assert self.pull_block_spec_rule is not None
    try:
      args = api_util.resolve_kwargs(self.fun, args, kwargs)
    except TypeError as e:
      raise TypeError(
          "The input arguments to the custom_fusion-decorated function "
          f"{debug_fun.func_name} could not be resolved to positional-only "
          f"arguments. Binding failed with the error:\n{e}"
      ) from e
    # flatten and get jaxpr
    args_flat, in_tree = tree_util.tree_flatten(args)
    in_avals = [core.typeof(x) for x in args_flat]
    flat_fun, out_tree = api_util.flatten_fun_nokwargs(
        lu.wrap_init(self.fun, debug_info=debug_fun.with_unknown_names()),
        in_tree)
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
    # if a Pallas implementation was provided, get its jaxpr
    if self.pallas_impl is not None:
      debug_pallas_impl = api_util.debug_info(
          "custom_fusion pallas_impl", self.pallas_impl, args, kwargs)
      flat_pallas_impl, pallas_out_tree = api_util.flatten_fun_nokwargs(
          lu.wrap_init(self.pallas_impl, debug_info=debug_pallas_impl),
          in_tree)
      # TODO(jburnim): Error if out_tree() and kernel_out_tree() are different?
      del pallas_out_tree
      pallas_jaxpr, _, pallas_consts = (
          pe.trace_to_jaxpr_dynamic(flat_pallas_impl, in_avals))
    else:
      pallas_jaxpr = None
      pallas_consts = []
    # debug_info for rules
    # Both const lists are passed positionally, ahead of the user args; the
    # counts below let the rules split them back apart.
    out_flat = custom_fusion_p.bind(
        *consts,
        *pallas_consts,
        *args_flat,
        jaxpr=jaxpr,
        num_consts=len(consts),
        eval_rule=self.eval_rule,
        pull_block_spec_rule=self.pull_block_spec_rule,
        push_block_spec_rule=self.push_block_spec_rule,
        pallas_jaxpr=pallas_jaxpr,
        pallas_num_consts=len(pallas_consts),
        in_tree=in_tree,
        out_tree=out_tree(),
        kernel_out_tree=out_tree())
    return tree_util.tree_unflatten(out_tree(), out_flat)
@custom_fusion_p.def_impl
def _custom_fusion_impl(
    *args,
    jaxpr: core.Jaxpr,
    num_consts: int,
    pallas_num_consts: int,
    **_):
  """Default (non-Pallas) evaluation: run the traced `fun` jaxpr."""
  # The pallas_impl consts follow `fun`'s consts; they are only needed inside
  # a Pallas kernel, so drop them here.
  consts, _, args = util.split_list(args, [num_consts, pallas_num_consts])
  return core.eval_jaxpr(jaxpr, consts, *args)

# Outside a kernel, lower `custom_fusion` exactly like its plain impl.
mlir.register_lowering(custom_fusion_p, mlir.lower_fun(
    _custom_fusion_impl, multiple_results=True))
@custom_fusion_p.def_effectful_abstract_eval
def _custom_fusion_effectful_abstract_eval(
    *args,
    jaxpr: core.Jaxpr,
    pallas_jaxpr: core.Jaxpr | None,
    **_):
  """Abstract eval: output avals and effects come from the `fun` jaxpr.

  Raises:
    NotImplementedError: if either jaxpr has effects (not yet supported).
  """
  del args
  # TODO(jburnim): Error if pallas_jaxpr has different number of outputs, or
  # different shapes and types of outputs?
  # BUG FIX: both messages below were missing their f-string prefixes, so the
  # "{...}" placeholders were emitted literally instead of interpolated.
  if jaxpr.effects:
    raise NotImplementedError(
        f"custom_fusion-decorated function {jaxpr.debug_info.func_src_info} "
        f"has effects, which is not yet supported: {jaxpr.effects}")
  if pallas_jaxpr is not None and pallas_jaxpr.effects:
    raise NotImplementedError(
        f"custom_fusion-decorated function {jaxpr.debug_info.func_src_info} "
        "has a pallas_impl with effects, which is not yet supported: "
        f"{pallas_jaxpr.effects}")
  return jaxpr.out_avals, jaxpr.effects
@block_spec_lib.register_eval_rule(custom_fusion_p)
def _custom_fusion_eval_rule(
    ctx: block_spec_lib.KernelEvalContext,
    *args,
    eval_rule: CustomEvalRuleFn,
    num_consts: int,
    pallas_num_consts: int,
    **_):
  """Evaluates a custom fusion inside a kernel via the user's eval rule."""
  # Skip the consts of both jaxprs; the user rule sees only the user args.
  user_args = args[num_consts + pallas_num_consts:]
  eval_ctx = CustomEvalContext(
      out_block_specs=ctx.out_block_specs,
      out_block_indices=ctx.get_out_block_indices(),
  )
  return eval_rule(eval_ctx, *user_args)
# TODO(jburnim): Lowering rules for SC and Mosaic GPU.
@mosaic_lowering.register_lowering_rule(custom_fusion_p)
def _custom_fusion_mosaic_lowering_rule(
    ctx: mosaic_lowering.LoweringRuleContext,
    *args,
    jaxpr: core.Jaxpr,
    num_consts: int,
    pallas_jaxpr: core.Jaxpr | None,
    pallas_num_consts: int,
    **_):
  """Lowers a custom_fusion in a Mosaic TPU kernel, preferring pallas_impl."""
  consts, pallas_consts, args = util.split_list(
      args, [num_consts, pallas_num_consts])
  # Fall back to the plain `fun` jaxpr if no Pallas implementation was given.
  if pallas_jaxpr is None:
    pallas_jaxpr = jaxpr
    pallas_consts = consts
  lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
  return mosaic_lowering.jaxpr_subcomp(
      lowering_context, pallas_jaxpr, *pallas_consts, *args)
@block_spec_lib.register_pull_block_spec_rule(custom_fusion_p)
def _custom_fusion_pull_block_spec_rule(
    ctx: block_spec_lib.PullRuleContext,
    out_block_transforms: tuple[block_spec_lib.BlockIndexTransform, ...],
    *,
    pull_block_spec_rule: CustomPullBlockSpecRuleFn,
    **_,
) -> Sequence[block_spec_lib.BlockIndexTransform]:
  """Delegates block-spec pulling to the user-registered rule."""
  del ctx  # The user rule only sees the output block transforms.
  return pull_block_spec_rule(out_block_transforms)
@block_spec_lib.register_push_block_spec_rule(custom_fusion_p)
def _custom_fusion_push_block_spec_rule(
    ctx: block_spec_lib.PushRuleContext,
    *block_specs: pallas_core.BlockSpec,
    push_block_spec_rule: CustomPushBlockSpecRuleFn,
    **_
) -> tuple[pallas_core.BlockSpec, ...]:
  """Delegates block-spec pushing to the user-registered rule."""
  del ctx  # The user rule only sees the input block specs.
  # TODO(jburnim): Better error message if push_block_spec_rule is None.
  return push_block_spec_rule(block_specs)
@block_spec_lib.register_usage_rule(custom_fusion_p)
def _custom_fusion_usage_rule(
    ctx: block_spec_lib.UsageRuleContext,
    used_out: Sequence[set[block_spec_lib.Usage]],
    *,
    jaxpr: core.Jaxpr,
    **_
) -> Sequence[set[block_spec_lib.Usage]]:
  """Propagates output usages back to the inputs through `fun`'s jaxpr."""
  del ctx
  # TODO(jburnim): Error if jaxpr.jaxpr gives different usage than pallas_jaxpr?
  read_usage_env = block_spec_lib.compute_usage(jaxpr, used_out)
  return [read_usage_env(invar) for invar in jaxpr.invars]
@@ -0,0 +1,32 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic utils for fuser internals."""
from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src import tree_util
from jax._src.interpreters import partial_eval as pe
def make_jaxpr(f, *args, **kwargs):
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
flat_avals = [core.shaped_abstractify(x) for x in flat_args]
debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs)
flat_fun, out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(f, debug_info=debug_info), in_tree
)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
out_tree = out_tree_thunk()
return jaxpr, consts, in_tree, out_tree
@@ -0,0 +1,141 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fusible primitive."""
from functools import partial
from typing import Any
import jax
from jax._src import api_util
from jax._src import core as jax_core
from jax._src.interpreters import batching
from jax._src import linear_util as lu
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas.fuser import fusion as fusion_lib
# Primitive staged out by the `fusible` decorator; its jaxpr param captures
# the fusible computation so the fuser can rewrite it later.
fusible_p = jax_core.Primitive('fusible')
fusible_p.multiple_results = True
def _fusible_is_high(*_, jaxpr, **params):
del params
return jaxpr.is_high
# Report high-level-ness from the staged jaxpr's own flag.
fusible_p.is_high = _fusible_is_high
def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
  """Wraps a concrete array as a nullary Fusion that simply returns it."""
  return fusion_lib.Fusion(
      func=lambda: x,
      in_type=((), {}),
      out_type=jax.typeof(x),
  )
@partial(api_boundary, repro_api_name="fuser.fusible")
def fusible(f=None, *, output_fusion_prefix: Any = True):
  """Decorator marking `f` as a fusible computation.

  `f` is called with its array arguments wrapped as trivial input `Fusion`s
  plus one trailing pytree of `None` output fusions (shaped like
  `output_fusion_prefix`).  Calls are staged out through `fusible_p` so the
  fuser can later substitute real fusions.

  May be used as `@fusible` or `@fusible(output_fusion_prefix=...)`.
  """
  def decorator(f):
    def wrapper(*args):
      def wrapped(*args):
        # Present each concrete argument to `f` as a trivial fusion.
        in_fusions = tree_util.tree_map(_make_trivial_fusion, args)
        # Placeholder output fusions: a pytree of Nones matching the prefix.
        output_fusions = tree_util.tree_unflatten(
            tree_util.tree_structure(output_fusion_prefix),
            [None] * len(tree_util.tree_leaves(output_fusion_prefix)),
        )
        return f(*in_fusions, output_fusions)
      flat_args, in_tree = tree_util.tree_flatten(args)
      debug_info = api_util.debug_info('fusible', wrapped, args, {})
      flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
          lu.wrap_init(wrapped, debug_info=debug_info), in_tree
      )
      flat_avals = [jax_core.typeof(x) for x in flat_args]
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
      out_tree = out_tree_thunk()
      out = fusible_p.bind(
          *consts,
          *flat_args,
          jaxpr=jaxpr,
          num_consts=len(consts),
          in_tree=in_tree,
          out_tree=out_tree,
          func=f,
          output_fusion_prefix=output_fusion_prefix,
      )
      return tree_util.tree_unflatten(out_tree, out)
    return wrapper
  # Support use with and without decorator arguments.
  if f is not None:
    return decorator(f)
  return decorator
@fusible_p.def_impl
def _(*consts_and_args, jaxpr, num_consts, **_):
  """Reference implementation: evaluate the staged-out jaxpr directly."""
  consts, args = util.split_list(consts_and_args, [num_consts])
  return jax_core.eval_jaxpr(jaxpr, consts, *args)

# When not fused away, lower `fusible` like its plain implementation.
mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl))
@fusible_p.def_effectful_abstract_eval
def _(*args, jaxpr, **kwargs):
  """Output avals and effects are exactly those of the staged jaxpr."""
  del args, kwargs
  return [v.aval for v in jaxpr.outvars], jaxpr.effects
def _fusible_trivial_batching_rule(axis_data, args, dims, **kwargs):
  """Supports vmap only for a trivial (size-1) batch dimension."""
  if axis_data.size != 1:
    raise NotImplementedError('fusible does not support non-trivial batching')
  # Drop the size-1 batch dim from each mapped argument.
  unbatched_args = tuple(
      a if (d is batching.not_mapped or d is None) else a[d]
      for a, d in zip(args, dims, strict=True)
  )
  out_unbatched = fusible_p.bind(*unbatched_args, **kwargs)
  # Re-insert a leading batch dim of size 1 on every output.
  out = tuple(o[None] for o in out_unbatched)
  return out, (0,) * len(out)

batching.fancy_primitive_batchers[fusible_p] = _fusible_trivial_batching_rule
def _fusible_to_lojax(*hi_args, jaxpr, num_consts, **_):
  """Lowers a high-level fusible equation to its lo-JAX form.

  NOTE(review): relies on the experimental qdd/lo-JAX machinery in
  `partial_eval`; the behavior described here is inferred from these call
  sites and should be confirmed against that machinery.
  """
  const_in_avals = jaxpr.in_aval_qdds[:num_consts]
  # Each high-level const may expand into several lo-level values.
  num_lo_consts = sum(len(aval.lo_ty()) for aval in const_in_avals)
  lo_args = [
      lo_val
      for aval, x in util.safe_zip(jaxpr.in_aval_qdds, hi_args)
      for lo_val in (aval.read_loval(x) if aval.has_qdd else aval.lower_val(x))
  ]
  closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, lo_args[:num_lo_consts])
  lo_jaxpr = pe.lower_jaxpr2(closed_jaxpr)
  all_outs = fusible_p.bind(*lo_args, jaxpr=lo_jaxpr.jaxpr, num_consts=num_lo_consts)
  # Mutable (qdd) outputs come first; write them back into the hi-level args.
  out_mut, lo_outs = util.split_list(all_outs, [pe.num_himuts_out(jaxpr.final_aval_qdds)])
  for a, x, us in zip(jaxpr.final_aval_qdds, hi_args, out_mut):
    if a.has_qdd:
      a.aval.update_from_loval(a.qdd, x, *us)
  return pe.raise_lo_outs(jaxpr.out_avals, lo_outs)

fusible_p.to_lojax = _fusible_to_lojax
@@ -0,0 +1,567 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom fusible dtypes."""
import abc
import dataclasses
import functools
import itertools as it
from typing import Any, TypeVar
from collections.abc import Callable, Sequence
import jax
from jax._src import api_util
from jax._src import core
from jax._src import custom_derivatives
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import state
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import partial_eval as pe
from jax._src.lax.control_flow import conditionals
from jax._src.pallas import core as pallas_core
from jax._src.pallas import pallas_call
from jax._src.pallas import primitives as pallas_primitives
from jax._src.pallas.fuser import block_spec
from jax._src.pallas.fuser.fusible import fusible_p
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
from jax._src.util import foreach
# Shadow map/zip with length-checked versions (standard JAX convention).
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

T = TypeVar("T")

# Physicalization rules keyed by primitive; populated throughout this file.
_physicalize_rules: dict[core.Primitive, Callable[..., Any]] = {}

# Packs physical component arrays into a single fusible-dtype value.
pack_dtype_p = core.Primitive("pack_dtype")
@pack_dtype_p.def_abstract_eval
def pack_dtype_abstract_eval(*xs, dtype):
  """Abstract eval for pack: delegates to the dtype's `abstract_pack`.

  Raises:
    ValueError: if `dtype` is not a fusible extended dtype.
  """
  if dtypes.issubdtype(dtype, FusibleElementDType):
    return dtype.abstract_pack(*xs)
  # BUG FIX: this message was missing its f-string prefix, so "{dtype}" was
  # printed literally rather than interpolated.
  raise ValueError(f"Attempted to pack non-fusion dtype: {dtype}")
def pack(*xs, dtype):
  """Packs component arrays `xs` into one value of fusible dtype `dtype`."""
  return pack_dtype_p.bind(*xs, dtype=dtype)
# Unpacks a fusible-dtype value back into its physical component arrays.
unpack_dtype_p = core.Primitive("unpack_dtype")
unpack_dtype_p.multiple_results = True
@unpack_dtype_p.def_abstract_eval
def unpack_dtype_abstract_eval(x):
  """Abstract eval for unpack: delegates to the dtype's `abstract_unpack`.

  Raises:
    NotImplementedError: for refs of fusible dtype (not yet supported).
    ValueError: if `x` does not have a fusible extended dtype.
  """
  if dtypes.issubdtype(x.dtype, FusibleElementDType):
    return x.dtype.abstract_unpack(x)
  elif isinstance(x.dtype, state.AbstractRef):
    raise NotImplementedError()
  # BUG FIX: the message was missing its f-string prefix AND referenced an
  # undefined name `dtype`; interpolate the actual dtype of `x` instead.
  raise ValueError(f"Attempted to unpack non-fusion dtype: {x.dtype}")
def unpack(x):
  """Unpacks fusible-dtype value `x` into a tuple of physical arrays."""
  return tuple(unpack_dtype_p.bind(x))
# Marker scalar type: `dtypes.issubdtype(dt, FusibleElementDType)` is how the
# pack/unpack abstract-eval rules identify fusible extended dtypes.
class FusibleElementDType(dtypes.extended):
  """Scalar dtype for fusible dtypes."""
class FusibleTyRules:
  """Extended-dtype rule set for fusible dtypes."""
  # Implicit conversion to/from other dtypes is disallowed.
  allow_conversion: bool = False
class FusionDType(dtypes.ExtendedDType, util.StrictABC):
  """Base class for fusible extended dtypes."""
  # Per-class registry of physicalization rules, keyed by primitive.
  _op_registry = {}
  _rules = FusibleTyRules
  type = FusibleElementDType

  @abc.abstractmethod
  def abstract_unpack(self, x) -> Sequence[Any]:
    """Returns the physical aval(s) that make up packed aval `x`."""
    raise NotImplementedError()

  @abc.abstractmethod
  def abstract_pack(self, *xs):
    """Returns the packed aval built from physical avals `xs`."""
    raise NotImplementedError()

  @classmethod
  def register_op(cls, primitive):
    """Returns a decorator registering a rule for `primitive` on this dtype."""
    def _register_fn(fn):
      cls._op_registry[primitive] = fn
      # BUG FIX: return the decorated function so that using
      # `@SomeDType.register_op(prim)` leaves the name bound to `fn`
      # instead of silently rebinding it to None.
      return fn
    return _register_fn

  @classmethod
  def get_op_rule(cls, primitive):
    """Returns the rule registered for `primitive`, or None."""
    return cls._op_registry.get(primitive)

  @property
  def name(self):
    return str(self)

  @abc.abstractmethod
  def pull_block_spec_one_step(self, aval_out, *args, **kwargs):
    """Pulls a block spec through a pack of this dtype."""
    raise NotImplementedError()

  @abc.abstractmethod
  def unpack_push_block_spec(self, aval_in, *args, **kwargs):
    """Pushes a block spec through an unpack of this dtype."""
    raise NotImplementedError()

  @abc.abstractmethod
  def unpack_pull_block_spec(self, aval_in, *args, **kwargs):
    """Pulls block specs through an unpack of this dtype."""
    raise NotImplementedError()
def physicalize(f):
  """Runs a function that contains fusible extended dtypes.

  Traces `f`, rewrites all fusible extended dtypes in the resulting jaxpr to
  their physical representations, then eagerly evaluates the rewritten jaxpr
  on the given arguments.
  """
  def wrapper(*args, **kwargs):
    if kwargs:
      raise NotImplementedError()
    flattened_args, treedef = jax.tree.flatten(args)
    debug_info = api_util.debug_info("physicalize", f, args, kwargs)
    wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
        lu.wrap_init(f, debug_info=debug_info), treedef
    )
    avals = [core.typeof(a) for a in flattened_args]
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
    new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
    out_flat = core.eval_jaxpr(
        new_jaxpr.jaxpr, new_jaxpr.consts, *flattened_args
    )
    return tree_util.tree_unflatten(out_tree_thunk(), out_flat)
  return wrapper
@util.weakref_lru_cache
def physicalize_closed_jaxpr(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
  """Replaces all extended dtypes with physical types in a jaxpr."""
  # Re-trace the physicalizing interpreter over the jaxpr to obtain a new
  # jaxpr whose avals are all physical.
  fun = functools.partial(physicalize_interp, jaxpr.jaxpr, jaxpr.consts)
  in_avals = [_physical_aval(aval) for aval in jaxpr.in_avals]
  # One fusible aval may expand to several physical avals, hence the flatten.
  flat_avals, treedef = tree_util.tree_flatten(in_avals)
  debug_info = api_util.debug_info("physicalize_closed_jaxpr", fun, (), {})
  wrapped_fun, _ = api_util.flatten_fun_nokwargs(
      lu.wrap_init(fun, debug_info=debug_info.with_unknown_names()), treedef
  )
  new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals)
  assert len(new_jaxpr.constvars) == len(consts), "Mismatched consts"
  return core.ClosedJaxpr(new_jaxpr, consts)
def _physical_aval(aval):
  """Maps an aval to its physical counterpart(s).

  Fusible-dtype arrays and refs expand to a tuple of physical avals; plain
  arrays are rebuilt as bare ShapedArrays; anything else passes through.
  """
  if isinstance(aval, core.ShapedArray):
    if isinstance(aval.dtype, FusionDType):
      return aval.dtype.abstract_unpack(aval)
    return core.ShapedArray(aval.shape, aval.dtype)
  if isinstance(aval, state.AbstractRef) and _is_fusion_type(aval):
    components = aval.dtype.abstract_unpack(aval.inner_aval)
    return tuple(aval.update(inner_aval=c) for c in components)
  return aval
def physicalize_jaxpr(jaxpr: core.Jaxpr) -> core.Jaxpr:
  """Replaces all extended dtypes with physical types in a jaxpr."""
  def _flat_jaxpr_eval(consts, args):
    return physicalize_interp(jaxpr, consts, *args)
  in_avals = [_physical_aval(v.aval) for v in jaxpr.invars]
  const_avals = [_physical_aval(v.aval) for v in jaxpr.constvars]
  flat_avals, treedef = jax.tree.flatten((const_avals, in_avals))
  debug_info = api_util.debug_info(
      "physicalize_jaxpr", _flat_jaxpr_eval, (const_avals, in_avals), {})
  wrapped_fun, _ = api_util.flatten_fun_nokwargs(
      lu.wrap_init(_flat_jaxpr_eval, debug_info=debug_info), treedef
  )
  new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals)
  assert not consts
  # Consts were traced as leading invars; convert them back so the result is
  # an open jaxpr with the (expanded) constvar structure.
  new_jaxpr = pe.convert_invars_to_constvars(
      new_jaxpr, len(tree_util.tree_leaves(const_avals))
  )
  return new_jaxpr
@dataclasses.dataclass
class Context:
  """Input/output avals of the equation a physicalization rule rewrites."""
  avals_in: Sequence[Any]
  avals_out: Sequence[Any]
def physicalize_interp(
    jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value
):
  """Physicalizes a jaxpr by replacing fusible dtypes with physical types."""
  # TODO: Merge into JAX core.
  env: dict[core.Var, Any] = {}

  def read_env(var: core.Atom):
    if isinstance(var, core.Literal):
      return var.val
    return env[var]

  def write_env(var: core.Var, val: Any):
    env[var] = val

  foreach(write_env, jaxpr.constvars, consts)
  assert len(jaxpr.invars) == len(
      args
  ), f"Length mismatch: {jaxpr.invars} != {args}"
  foreach(write_env, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    invals = list(map(read_env, eqn.invars))
    avals_in = tuple(x.aval for x in eqn.invars)
    # Preserve each equation's name stack / traceback in the re-traced code.
    name_stack = (
        source_info_util.current_name_stack() + eqn.source_info.name_stack
    )
    with (
        source_info_util.user_context(
            eqn.source_info.traceback, name_stack=name_stack
        ),
        eqn.ctx.manager,
    ):
      # need to check types and then invoke the correct rule.
      ctx = Context(
          avals_in=avals_in, avals_out=[var.aval for var in eqn.outvars]
      )
      custom_rule = _phys_find_rule(eqn.primitive, avals_in)
      if custom_rule:
        outvals = custom_rule(ctx, *invals, **eqn.params)
      else:
        # No fusion types involved: bind the primitive unchanged.
        bind_params = eqn.primitive.get_bind_params(eqn.params)
        outvals = eqn.primitive.bind(*invals, **bind_params)
    if eqn.primitive.multiple_results:
      assert len(outvals) == len(eqn.outvars), eqn
      foreach(write_env, eqn.outvars, outvals)
    else:
      write_env(eqn.outvars[0], outvals)
  return map(read_env, jaxpr.outvars)
def _is_fusion_type(aval: core.AbstractValue):
  """Returns whether an aval is an array (or ref) carrying a fusion dtype."""
  if not isinstance(aval, (core.ShapedArray, state.AbstractRef)):
    return False
  return isinstance(getattr(aval, 'dtype', None), FusionDType)
def _phys_find_rule(primitive, avals: Sequence[core.AbstractValue]):
  """Finds the physicalization rule for a primitive.

  An explicitly registered rule wins; otherwise dispatch to the single fusion
  dtype present in `avals` (None if there is none, error if several).
  """
  if primitive in _physicalize_rules:
    return _physicalize_rules[primitive]
  # pyrefly: ignore[missing-attribute]
  fusion_types = {aval.dtype for aval in avals if _is_fusion_type(aval)}
  if not fusion_types:
    return None
  if len(fusion_types) > 1:
    raise ValueError(f"Multiple fusion types for primitive: {fusion_types}")
  (fusion_type,) = fusion_types
  if primitive not in fusion_type._op_registry:
    raise ValueError(
        f"No implementation found for primitive {primitive} "
        f"for custom type {fusion_type}"
    )
  return fusion_type.get_op_rule(primitive)
def _assert_no_fusion_types(avals: Sequence[core.AbstractValue]):
  """Raises if any aval in `avals` carries a fusible extended dtype."""
  for aval in avals:
    if _is_fusion_type(aval):
      raise NotImplementedError(f"Fusion type found in avals: {avals}")
def _pallas_call_physicalize_rule(
    ctx: Context, *args, jaxpr, grid_mapping: pallas_core.GridMapping, **kwargs
):
  """Physicalizes the kernel jaxpr of a pallas_call.

  Fusion types may only live inside the kernel; the call's operands and
  results must already be physical.
  """
  _assert_no_fusion_types(ctx.avals_in)
  _assert_no_fusion_types(ctx.avals_out)
  with grid_mapping.trace_env():
    new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, ()))
  # Physicalization can expand scratch operands (one fusible ref becomes
  # several physical refs), so grow the scratch avals accordingly.
  if diff := len(new_jaxpr.jaxpr.invars) - len(jaxpr.invars):
    num_scratch_avals = len(grid_mapping.scratch_avals) + diff
    new_scratch_avals = tuple(v.aval for v in
                              new_jaxpr.jaxpr.invars[-num_scratch_avals:])
    grid_mapping = grid_mapping.replace(
        scratch_avals=new_scratch_avals
    )
  return pallas_call.pallas_call_p.bind(
      *args, jaxpr=new_jaxpr.jaxpr, grid_mapping=grid_mapping, **kwargs
  )
_physicalize_rules[pallas_call.pallas_call_p] = _pallas_call_physicalize_rule
def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs):
  """Physicalizes `cond` by physicalizing every branch jaxpr."""
  _assert_no_fusion_types(ctx.avals_out)
  new_branches = tuple(
      physicalize_closed_jaxpr(branch) for branch in branches
  )
  leaves = jax.tree.leaves(args)
  return conditionals.cond_p.bind(
      *leaves, branches=new_branches, **kwargs
  )
_physicalize_rules[conditionals.cond_p] = _cond_physicalize_rule
@lu.transformation2
def _physicalize_transform(f, *args):
  """Physicalizes a custom-vjp fwd whose args interleave (value, zero) pairs.

  NOTE(review): args are assumed to alternate value/zero-tangent, matching
  what `custom_derivatives.lift_fwd` produces — confirm against that helper.
  """
  vals, zeros = args[::2], args[1::2]
  assert len(vals) == len(zeros)
  # Re-interleave the (fixed) zeros with the physical values before calling f.
  wrapper = lambda *inner_vals: f(
      *it.chain.from_iterable(zip(inner_vals, zeros))
  )
  return physicalize(wrapper)(*vals)
@lu.transformation2
def _physicalize_transform_bwd(f, const_avals, *args):
  """Physicalizes a custom-vjp bwd, prepending zero cotangents for consts."""
  return [custom_derivatives.Zero(a) for a in const_avals] + list(
      physicalize(f)(*args)
  )
def _custom_vjp_call_physicalize_rule(
    ctx: Context, *args, call_jaxpr, num_consts, fwd_jaxpr_thunk, bwd, **kwargs
):
  """Physicalizes a custom_vjp call: primal, fwd and bwd are all rewritten."""
  _assert_no_fusion_types(ctx.avals_out)
  new_jaxpr = physicalize_closed_jaxpr(call_jaxpr)
  fun = lu.wrap_init(core.jaxpr_as_fun(new_jaxpr),
                     debug_info=call_jaxpr.jaxpr.debug_info)
  fwd = custom_derivatives.lift_fwd(num_consts, fwd_jaxpr_thunk)
  fwd_physicalized = _physicalize_transform(fwd)
  const_avals, _ = util.split_list(new_jaxpr.in_avals, [num_consts])
  bwd_physicalized = _physicalize_transform_bwd(bwd, const_avals)
  # Rebind with the physicalized primal/fwd/bwd sub-functions.
  kwargs['subfuns'] = (fun, fwd_physicalized, bwd_physicalized)
  return custom_derivatives.custom_vjp_call_p.bind(*args, **kwargs)
_physicalize_rules[custom_derivatives.custom_vjp_call_p] = _custom_vjp_call_physicalize_rule
def _run_state_rule(ctx: Context, *args, jaxpr, which_linear, is_initialized):
  """Physicalizes run_state's body; fusion types may not cross its boundary."""
  _assert_no_fusion_types(ctx.avals_in)
  _assert_no_fusion_types(ctx.avals_out)
  new_jaxpr = physicalize_jaxpr(jaxpr)
  return state_discharge.run_state_p.bind(
      *args,
      jaxpr=new_jaxpr,
      which_linear=which_linear,
      is_initialized=is_initialized,
  )
_physicalize_rules[state_discharge.run_state_p] = _run_state_rule
def _core_map_rule(ctx: Context, *args, jaxpr, **params):
  """Physicalizes the body of a core_map, under its mesh's axis env."""
  _assert_no_fusion_types(ctx.avals_in)
  _assert_no_fusion_types(ctx.avals_out)
  assert not jaxpr.invars
  with core.extend_axis_env_nd(params["mesh"].shape.items()):
    jaxpr = physicalize_jaxpr(jaxpr)
  return pallas_core.core_map_p.bind(*args, jaxpr=jaxpr, **params)
_physicalize_rules[pallas_core.core_map_p] = _core_map_rule
def _run_scoped_rule(ctx: Context, *args, jaxpr, **params):
  """Physicalizes run_scoped; scoped refs (constvars) may expand in number."""
  _assert_no_fusion_types(ctx.avals_out)
  jaxpr = physicalize_jaxpr(jaxpr)
  flat_args = tree_util.tree_leaves(args)
  assert len(flat_args) == len(
      jaxpr.constvars
  ), f"Length mismatch: {len(flat_args)=} != {len(jaxpr.constvars)=}"
  return pallas_primitives.run_scoped_p.bind(*flat_args, jaxpr=jaxpr, **params)
_physicalize_rules[pallas_primitives.run_scoped_p] = _run_scoped_rule
def _scan_rule(ctx: Context, *args, jaxpr, **params):
  """Physicalizes scan's body; fusion types may not enter or leave the scan."""
  _assert_no_fusion_types(ctx.avals_in)
  _assert_no_fusion_types(ctx.avals_out)
  new_jaxpr = physicalize_closed_jaxpr(jaxpr)
  return jax.lax.scan_p.bind(*args, jaxpr=new_jaxpr, **params)
_physicalize_rules[jax.lax.scan_p] = _scan_rule
def _while_rule(
    ctx: Context, *args, body_jaxpr, cond_jaxpr, body_nconsts,
    cond_nconsts, **params
):
  """Physicalizes a while loop's cond and body jaxprs.

  Fusion types may appear only in the consts of cond/body; the carried state
  and the outputs must already be physical.  Physicalizing a jaxpr can expand
  one fusible const into several physical ones, so the const counts are
  recomputed from the change in invar counts.
  """
  _assert_no_fusion_types(ctx.avals_out)
  cond_avals = [v.aval for v in cond_jaxpr.jaxpr.invars]
  _, cond_in_avals = util.split_list(cond_avals, [cond_nconsts])
  _assert_no_fusion_types(cond_in_avals)
  new_cond_jaxpr = physicalize_closed_jaxpr(cond_jaxpr)
  # New const count = old count + however many invars physicalization added.
  new_num_cond_consts = (
      cond_nconsts
      + len(new_cond_jaxpr.jaxpr.invars)
      - len(cond_jaxpr.jaxpr.invars)
  )
  body_avals = [v.aval for v in body_jaxpr.jaxpr.invars]
  _, body_in_avals = util.split_list(body_avals, [body_nconsts])
  _assert_no_fusion_types(body_in_avals)
  new_body_jaxpr = physicalize_closed_jaxpr(body_jaxpr)
  new_num_body_consts = (
      body_nconsts
      + len(new_body_jaxpr.jaxpr.invars)
      - len(body_jaxpr.jaxpr.invars)
  )
  flat_args = tree_util.tree_leaves(args)
  cond_consts, body_consts, flat_args = util.split_list(
      flat_args, [new_num_cond_consts, new_num_body_consts]
  )
  assert len(flat_args) + len(body_consts) == len(
      new_body_jaxpr.jaxpr.invars), (
      f"Length mismatch: {len(flat_args) + len(body_consts)} !="
      f" {len(new_body_jaxpr.jaxpr.invars)=}"
  )
  assert len(flat_args) + len(cond_consts) == len(
      new_cond_jaxpr.jaxpr.invars), (
      f"Length mismatch: {len(flat_args) + len(cond_consts)} !="
      f" {len(new_cond_jaxpr.jaxpr.invars)=}"
  )
  return jax.lax.while_p.bind(
      *(cond_consts + body_consts + flat_args),
      body_jaxpr=new_body_jaxpr,
      cond_jaxpr=new_cond_jaxpr,
      body_nconsts=new_num_body_consts,
      cond_nconsts=new_num_cond_consts,
      **params,
  )
_physicalize_rules[jax.lax.while_p] = _while_rule
def _pack_rule(_, *args, dtype):
del dtype
return args
# Packing needs no physical op, so register the pass-through rule.
_physicalize_rules[pack_dtype_p] = _pack_rule
def _unpack_rule(_, arg):
return arg
# Unpacking needs no physical op, so register the pass-through rule.
_physicalize_rules[unpack_dtype_p] = _unpack_rule
def _swap_rule(ctx: Context, ref, val, *args, tree):
  """Physicalization rule for `swap_p`: fusion refs swap via their dtype."""
  ref_aval = ctx.avals_in[0]
  if _is_fusion_type(ref_aval):
    # The fusion dtype owns the physical representation of the ref.
    return ref_aval.dtype.swap(ref, val, *args, tree=tree)
  return state_primitives.swap_p.bind(ref, val, *args, tree=tree)
_physicalize_rules[state_primitives.swap_p] = _swap_rule
def _get_rule(ctx: Context, ref, *args, tree):
  """Physicalization rule for `get_p`: fusion refs are read via their dtype."""
  ref_aval = ctx.avals_in[0]
  if _is_fusion_type(ref_aval):
    # The fusion dtype owns the physical representation of the ref.
    return ref_aval.dtype.get(ref, *args, tree=tree)
  return state_primitives.get_p.bind(ref, *args, tree=tree)
_physicalize_rules[state_primitives.get_p] = _get_rule
@block_spec.register_eval_rule(pack_dtype_p)
def _pack_dtype_eval_rule(ctx: block_spec.KernelEvalContext, *args, dtype):
  # Delegate kernel evaluation of a pack to the fusion dtype itself.
  return dtype.pack_eval_rule(ctx, *args)
@block_spec.register_pull_block_spec_rule(pack_dtype_p)
def _pack_dtype_pull_rule(
    ctx: block_spec.PullRuleContext,
    block_spec: pallas_core.BlockSpec,  # NOTE: shadows the `block_spec` module inside this function.
    *,
    dtype: FusionDType,
):
  # Pulling a block spec backwards through a pack: the fusion dtype decides
  # how the output's spec maps onto its inputs.
  aval_out = ctx.avals_out[0]
  return dtype.pull_block_spec_one_step(aval_out, block_spec)
@block_spec.register_push_block_spec_rule(unpack_dtype_p)
def _unpack_dtype_push_rule(
    ctx: block_spec.PushRuleContext,
    block_spec: pallas_core.BlockSpec,
):
  """Pushes a block spec forward through `unpack_dtype_p` via the dtype."""
  # An unpack's input must be a shaped array with a fusion dtype; the dtype
  # itself knows how the spec transforms.
  aval_in = ctx.avals_in[0]
  assert isinstance(aval_in, core.ShapedArray)
  assert isinstance(aval_in.dtype, FusionDType), aval_in.dtype
  return aval_in.dtype.unpack_push_block_spec(aval_in, block_spec)
@block_spec.register_pull_block_spec_rule(unpack_dtype_p)
def _unpack_dtype_pull_rule(
    ctx: block_spec.PullRuleContext,  # was annotated PushRuleContext; this is a *pull* rule
    block_specs: pallas_core.BlockSpec,
):
  """Pulls block specs backward through `unpack_dtype_p` via the dtype."""
  # NOTE(review): `block_specs` is splatted below, so despite the scalar
  # annotation it is presumably a sequence of BlockSpecs — hence the pyrefly
  # suppression; confirm against the pull-rule registration contract.
  aval_in = ctx.avals_in[0]
  assert isinstance(aval_in, core.ShapedArray)
  assert isinstance(aval_in.dtype, FusionDType), aval_in.dtype
  return aval_in.dtype.unpack_pull_block_spec(aval_in, *block_specs)  # pyrefly: ignore[not-iterable]
@block_spec.register_eval_rule(unpack_dtype_p)
def _unpack_dtype_eval_rule(ctx: block_spec.KernelEvalContext, *args):
  """Evaluates `unpack_dtype_p` inside a kernel via the fusion dtype."""
  assert ctx.avals_in is not None
  aval_in = ctx.avals_in[0]
  assert isinstance(aval_in, core.ShapedArray)
  assert isinstance(aval_in.dtype, FusionDType), aval_in.dtype
  return aval_in.dtype.unpack_eval_rule(ctx, *args)  # pyrefly: ignore[missing-attribute]
def _fusible_physicalize_rule(
    ctx, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func
):
  """Physicalization rule for `fusible_p`: physicalizes the inner jaxpr."""
  del ctx  # Unused.
  # The leading `num_consts` operands are the inner jaxpr's constants.
  consts = list(consts_and_args[:num_consts])
  physical = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
  return fusible_p.bind(
      *consts_and_args,
      jaxpr=physical.jaxpr,
      num_consts=num_consts,
      in_tree=in_tree,
      out_tree=out_tree,
      func=func,
  )
_physicalize_rules[fusible_p] = _fusible_physicalize_rule
@@ -0,0 +1,60 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fusion classes."""
from __future__ import annotations
import dataclasses
from typing import Any, Generic, ParamSpec, TypeVar
from collections.abc import Callable
import jax
from jax._src import util
safe_map = util.safe_map
A = ParamSpec("A")
K = TypeVar("K")
@dataclasses.dataclass
class Fusion(Generic[A, K]):
func: Callable[A, K]
in_type: tuple[tuple[Any, ...], dict[str, Any]]
out_type: Any
def __call__(self, *args: A.args, **kwargs: A.kwargs) -> K:
return self.func(*args, **kwargs)
@property
def shape(self):
return jax.tree.map(lambda x: x.shape, self.out_type)
@property
def dtype(self):
return jax.tree.map(lambda x: x.dtype, self.out_type)
@property
def type(self):
return self.out_type
@property
def in_shape(self):
return jax.tree.map(lambda x: x.shape, self.in_type)
@property
def in_dtype(self):
return jax.tree.map(lambda x: x.dtype, self.in_type)
@@ -0,0 +1,318 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fuses a function."""
from collections.abc import Sequence
import functools
from typing import Any
import jax
from jax._src import api_util
from jax._src import core as jax_core
from jax._src import linear_util as lu
from jax._src.traceback_util import api_boundary
from jax._src import tree_util
from jax._src.interpreters import partial_eval as pe
from jax._src.pallas.fuser import fusible_dtype
from jax._src.pallas.fuser import fusion as fusion_lib
from jax._src.pallas.fuser.fusible import fusible_p
@functools.partial(api_boundary, repro_api_name="fuser.fuse")
def fuse(f=None, *, resolve_fusion_dtypes: bool = True, debug: bool = False):
  """Fuses a function into a single fusible.

  There should be a single call to a `fusible` inside the body of `f`. `fuse`
  returns a transformed function that will fuse the surrounding computation
  into the fusible and invoke it.

  Args:
    f: The function to fuse.
    resolve_fusion_dtypes: (experimental) whether or not to resolve fusion
      dtypes (which don't correspond to physical dtypes)
    debug: Whether to print debug information.
  """

  def decorator(fn):
    def wrapper(*args, **kwargs):
      # Trace `fn` to a jaxpr on the flattened, concrete arguments.
      flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
      dbg = api_util.debug_info("fuse", fn, args, kwargs)
      flat_fun, out_tree_thunk = api_util.flatten_fun(
          lu.wrap_init(fn, debug_info=dbg), in_tree
      )
      avals = [jax_core.typeof(x) for x in flat_args]
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
      if debug:
        print("Jaxpr before fusion:")
        print(jaxpr)
      out_tree = out_tree_thunk()
      flat_out = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
      return tree_util.tree_unflatten(out_tree, flat_out)

    if resolve_fusion_dtypes:
      return fusible_dtype.physicalize(wrapper)
    return wrapper

  # Support both bare `@fuse` and parameterized `@fuse(...)` usage.
  return decorator if f is None else decorator(f)
_fusible: dict[jax_core.Primitive, Any] = {}


def _construct_fusion_jaxpr(
    candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs
):
  """Builds a pruned jaxpr computing `outvars` from `invars`/`kwargs`.

  The original jaxpr's constvars and invars are all demoted to constvars and
  the result is DCE'd, so only the entries of `candidate_values` that are
  actually needed survive.

  Args:
    candidate_values: concrete values matching `jaxpr.constvars + jaxpr.invars`.
    jaxpr: the source jaxpr.
    outvars: a pytree of jaxpr vars that become the new outputs.
    *invars: pytrees of jaxpr vars that become the new positional inputs.
    **kwargs: pytrees of jaxpr vars that become the new keyword inputs.

  Returns:
    A `(new_jaxpr, new_values, in_type, out_type, out_tree)` tuple, where
    `new_values` are the surviving constants and `in_type`/`out_type` are
    avals arranged in the input/output pytree structures.
  """
  flat_outvars, out_tree = tree_util.tree_flatten(outvars)
  flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs))
  # Demote every original const *and* input to a constvar: anything not
  # reachable from `flat_outvars` is removed by the DCE pass below.
  new_jaxpr_no_dce = jaxpr.replace(
      outvars=flat_outvars,
      constvars=jaxpr.constvars + jaxpr.invars,
      invars=flat_invars,
      debug_info=jaxpr.debug_info.with_unknown_names()
  )
  new_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
      new_jaxpr_no_dce,
      [True] * len(new_jaxpr_no_dce.outvars),
      instantiate=[False] * len(new_jaxpr_no_dce.constvars)
      + [True] * len(new_jaxpr_no_dce.invars),
  )
  assert all(used_invars), new_jaxpr_no_dce
  new_values = tuple(
      c for used, c in zip(used_consts, candidate_values, strict=True) if used
  )
  flat_in_type = [x.aval for x in flat_invars]
  # `in_tree` (from the tree_flatten above) is exactly the structure of
  # `(invars, kwargs)` — no need to recompute it with tree_structure.
  in_type = tree_util.tree_unflatten(in_tree, flat_in_type)
  out_type = tree_util.tree_unflatten(
      out_tree,
      [x.aval for x in flat_outvars],
  )
  return new_jaxpr, new_values, in_type, out_type, out_tree
def construct_input_fusion(
    candidate_values, jaxpr: jax_core.Jaxpr, outvars
) -> fusion_lib.Fusion:
  """Builds a nullary Fusion computing `outvars` from the jaxpr's prefix."""
  fused_jaxpr, fused_consts, in_type, out_type, out_tree = (
      _construct_fusion_jaxpr(candidate_values, jaxpr, outvars)
  )

  def _evaluate():
    # Evaluate the pruned jaxpr on its captured constants and rebuild the
    # caller-visible output pytree.
    flat = jax_core.eval_jaxpr(fused_jaxpr, fused_consts)
    return tree_util.tree_unflatten(out_tree, flat)

  return fusion_lib.Fusion(_evaluate, in_type, out_type)
def _find_downstream(
    jaxpr: jax_core.Jaxpr, in_used: Sequence[bool]
) -> tuple[bool, ...]:
  """Returns a mask of which jaxpr outputs depend on the masked-in inputs."""
  # TODO(sharadmv): We use partial_eval to query downstream dependencies which
  # is not an officially sanctioned way to do so, since PE is really used for
  # AD. In the future, we should have a special Jaxpr API that queries this.
  # Marking an input "unknown" taints everything computed from it; the
  # returned `out_used` mask marks the tainted (i.e. dependent) outputs.
  _, _, out_used, *_ = pe.partial_eval_jaxpr_custom(
      jaxpr,
      in_unknowns=in_used,
      in_inst=in_used,
      ensure_out_unknowns=False,
      ensure_out_inst=False,
      saveable=lambda *_, **__: False,
  )
  return tuple(out_used)
def _construct_output_permutation(
used: list[tuple[bool, ...]],
) -> list[int]:
order = []
for u in used:
true_vals = [i for i in range(len(u)) if u[i]]
order.extend(true_vals)
return [order.index(i) for i in range(len(order))]
def _construct_output_fusions(
    candidate_values,
    jaxpr,
    out_tree,
    fusion_eqn_index,
    fusion_eqn_outvars,  # Flat list of vars output by the fusible eqn
    fusion_eqn_out_tree,  # Tree structure of the fusible eqn outputs
    output_fusion_prefix,  # Pytree defining output groups
):
  """Builds one output Fusion per group in `output_fusion_prefix`.

  Each group consumes a subset of the fusible's outputs and must produce a
  disjoint subset of the jaxpr's final outputs (a ValueError is raised
  otherwise). Returns the pytree of per-group fusions (None for groups whose
  outputs pass through untouched) together with the permutation that restores
  the original output order.
  """
  # 1. Create jaxpr_out: represents computation *after* the fusible
  # Inputs: fusion_eqn_outvars
  # Outputs: jaxpr.outvars
  jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr(
      candidate_values,
      jaxpr.replace(
          eqns=jaxpr.eqns[:fusion_eqn_index]
          + jaxpr.eqns[fusion_eqn_index + 1 :]
      ),
      tree_util.tree_unflatten(out_tree, jaxpr.outvars),  # Original outputs
      tree_util.tree_unflatten(
          fusion_eqn_out_tree, fusion_eqn_outvars
      ),  # Fusible outputs as inputs
  )
  # 2. Group fusible outputs based on the mask
  unflat_fusible_outvars = jax.tree.unflatten(
      fusion_eqn_out_tree, fusion_eqn_outvars
  )
  partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to(
      unflat_fusible_outvars
  )
  # 3. Calculate dependencies and check disjointedness
  downstream_outputs_used_masks = []  # List of bool tuples, one per group
  already_used_final_outputs = set()  # Indices of final outputs already claimed
  for outvars_group in partial_flat:
    # Identify vars in this group
    used_fusible_outvars = set(jax.tree.leaves(outvars_group))
    # Create mask for jaxpr_out inputs corresponding to this group
    in_used_mask = [
        True if v in used_fusible_outvars else False for v in jaxpr_out.invars
    ]
    # Trace dependencies through jaxpr_out to find which final outputs are affected
    downstream_used_mask = _find_downstream(
        jaxpr_out, in_used_mask
    )  # Mask for jaxpr_out.outvars (== jaxpr.outvars)
    # Check for overlap in final output usage across groups
    for i, used in enumerate(downstream_used_mask):
      if used:
        if i in already_used_final_outputs:
          raise ValueError(
              "Outputs must be disjoint in order to use separate output fusions"
          )
        already_used_final_outputs.add(i)
    downstream_outputs_used_masks.append(downstream_used_mask)
  # 4. Construct output permutation needed to restore original output order
  output_permutation = _construct_output_permutation(
      downstream_outputs_used_masks
  )
  # Construct fusions for each group by DCEing the jaxpr_out
  output_fusions: list[fusion_lib.Fusion | None] = []
  for i, outvars_group in enumerate(partial_flat):
    flat_group_vars, _ = tree_util.tree_flatten(outvars_group)
    downstream_used_mask = downstream_outputs_used_masks[i]
    # Keep only this group's inputs alive (constants are never instantiated).
    used_jaxpr_invars = [False] * len(all_values) + [
        v in flat_group_vars for v in jaxpr_out.invars
    ]
    jaxpr_out_for_group, used_consts, _ = pe.dce_jaxpr_consts(
        jaxpr_out, downstream_used_mask, instantiate=used_jaxpr_invars
    )
    values_for_jaxpr = tuple(
        c for used, c in zip(used_consts, all_values, strict=True) if used
    )
    if (
        not jaxpr_out_for_group.eqns
        and jaxpr_out_for_group.outvars == jaxpr_out_for_group.invars
    ):
      # Identity group: the fusible's outputs pass straight through, so no
      # output fusion is needed (signalled by None).
      output_fusions.append(None)
      continue
    def _fn(jaxpr, vals, *args, **kwargs):
      flat_args, _ = tree_util.tree_flatten((args, kwargs))
      out_flat = jax_core.eval_jaxpr(jaxpr, vals, *flat_args)
      return tuple(out_flat)
    # partial() binds this group's jaxpr/values now, avoiding the classic
    # late-binding-closure-in-a-loop pitfall.
    fn = functools.partial(_fn, jaxpr_out_for_group, values_for_jaxpr)
    in_type = jax.tree.map(lambda x: x.aval, outvars_group)
    out_type = tuple(v.aval for v in jaxpr_out_for_group.outvars)
    fusion = fusion_lib.Fusion(
        fn,
        (in_type, {}),
        out_type,
    )
    output_fusions.append(fusion)
  return (
      tree_util.tree_unflatten(
          tree_util.tree_structure(output_fusion_prefix), output_fusions
      ),
      output_permutation,
  )
def fuse_jaxpr(
    jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args
):
  """Locates the single fusible eqn in `jaxpr` and invokes it with fusions.

  The computation before the fusible becomes its input fusions and the
  computation after it becomes its output fusions; the fusible's user
  function is then called with both.

  Raises:
    ValueError: if `jaxpr` contains no fusible equation.
  """
  fusion_eqn_index = None
  # Collect input fusions
  for i, eqn in enumerate(jaxpr.eqns):
    if eqn.primitive is fusible_p:
      fusion_eqn_index = i
      break
  if fusion_eqn_index is None:
    raise ValueError("No fusible eqn found")
  fusion_eqn = jaxpr.eqns[fusion_eqn_index]
  # Now let's check if we need to do any fusion at all, e.g. do the outputs of
  # the jaxpr have any dependence on the fusion at all?
  candidate_values = [*consts, *args]
  # Treat the fusible's outputs as unknown inputs: `out_used` then marks the
  # jaxpr outputs that actually depend on the fusible.
  independent_jaxpr, _, out_used, *_ = pe.partial_eval_jaxpr_custom(
      jaxpr.replace(
          eqns=(jaxpr.eqns[:fusion_eqn_index]
                + jaxpr.eqns[fusion_eqn_index + 1 :]),
          constvars=jaxpr.constvars + jaxpr.invars,
          invars=fusion_eqn.outvars,
          debug_info=jaxpr.debug_info.with_unknown_names()),
      in_unknowns=[True] * len(fusion_eqn.outvars),
      in_inst=[True] * len(fusion_eqn.outvars),
      ensure_out_unknowns=False,
      ensure_out_inst=False,
      saveable=lambda *_, **__: False)
  if not any(out_used):
    # Short circuit if there is no need to run the fusible at all.
    return jax_core.eval_jaxpr(independent_jaxpr, candidate_values)
  # Construct fusions for non-constant inputs to the fusible.
  in_fusions_flat = [
      construct_input_fusion(
          candidate_values,
          jaxpr.replace(
              eqns=jaxpr.eqns[:fusion_eqn_index],
          ),
          var,
      )
      for var in fusion_eqn.invars[fusion_eqn.params["num_consts"] :]
  ]
  in_fusions = tree_util.tree_unflatten(
      fusion_eqn.params["in_tree"], in_fusions_flat
  )
  output_fusions, output_permutation = _construct_output_fusions(
      candidate_values,
      jaxpr,
      out_tree,
      fusion_eqn_index,
      fusion_eqn.outvars,
      fusion_eqn.params["out_tree"],
      fusion_eqn.params["output_fusion_prefix"],
  )
  # Invoke the user's fusible function with input and output fusions.
  out = fusion_eqn.params["func"](*in_fusions, output_fusions)
  flat_out = jax.tree.leaves(out)
  # Restore the original jaxpr output order.
  permuted_out = [flat_out[i] for i in output_permutation]
  assert len(permuted_out) == len(jaxpr.outvars), (
      len(permuted_out),
      len(jaxpr.outvars),
  )
  return permuted_out