hand
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.pallas.fuser.block_spec import get_fusion_values as get_fusion_values
|
||||
from jax._src.pallas.fuser.block_spec import make_scalar_prefetch_handler as make_scalar_prefetch_handler
|
||||
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
|
||||
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
|
||||
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
|
||||
from jax._src.pallas.fuser.custom_fusion_lib import custom_fusion as custom_fusion
|
||||
from jax._src.pallas.fuser.fusible import fusible as fusible
|
||||
from jax._src.pallas.fuser.fusion import Fusion as Fusion
|
||||
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,82 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Helpers for evaluating functions under certain constraints."""
|
||||
import dataclasses
|
||||
from typing import Any
|
||||
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.pallas.fuser import fuser_utils
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomEvaluateSettings:
|
||||
allow_transpose: bool = True
|
||||
|
||||
|
||||
def evaluate(f, *, allow_transpose: bool = True):
|
||||
def wrapped(*args, **kwargs):
|
||||
jaxpr, consts, _, out_tree = fuser_utils.make_jaxpr(f, *args, **kwargs)
|
||||
settings = CustomEvaluateSettings(allow_transpose=allow_transpose)
|
||||
flat_args = tree_util.tree_leaves(args)
|
||||
out_flat = _custom_evaluate_jaxpr(settings, jaxpr, consts, *flat_args)
|
||||
return tree_util.tree_unflatten(out_tree, out_flat)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
# Disallow most higher-order primitives for now.
|
||||
disallowed_primitives = {lax.scan_p, lax.while_p, lax.cond_p}
|
||||
|
||||
|
||||
def _custom_evaluate_jaxpr(
|
||||
settings: CustomEvaluateSettings, jaxpr: core.Jaxpr, consts, *args
|
||||
):
|
||||
def read(v: core.Atom) -> Any:
|
||||
return v.val if isinstance(v, core.Literal) else env[v]
|
||||
|
||||
def write(v: core.Var, val: Any) -> None:
|
||||
env[v] = val
|
||||
|
||||
env: dict[core.Var, Any] = {}
|
||||
util.safe_map(write, jaxpr.constvars, consts)
|
||||
util.safe_map(write, jaxpr.invars, args)
|
||||
lu = core.last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
|
||||
if eqn.primitive in disallowed_primitives:
|
||||
raise NotImplementedError(f'Primitive {eqn.primitive} not supported.')
|
||||
if not settings.allow_transpose and eqn.primitive is lax.transpose_p:
|
||||
raise ValueError('Transpose not allowed.')
|
||||
name_stack = (
|
||||
source_info_util.current_name_stack() + eqn.source_info.name_stack
|
||||
)
|
||||
traceback = eqn.source_info.traceback
|
||||
with source_info_util.user_context(
|
||||
traceback, name_stack=name_stack
|
||||
), eqn.ctx.manager:
|
||||
ans = eqn.primitive.bind(
|
||||
*util.safe_map(read, eqn.invars), **bind_params
|
||||
)
|
||||
if eqn.primitive.multiple_results:
|
||||
util.safe_map(write, eqn.outvars, ans)
|
||||
else:
|
||||
write(eqn.outvars[0], ans)
|
||||
core.clean_up_dead_vars(eqn, env, lu)
|
||||
return util.safe_map(read, jaxpr.outvars)
|
||||
@@ -0,0 +1,264 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import dataclasses
|
||||
import functools
|
||||
from typing import Any, Protocol
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.pallas.mosaic import lowering as mosaic_lowering
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.fuser import block_spec as block_spec_lib
|
||||
|
||||
|
||||
custom_fusion_p = core.Primitive('custom_fusion')
|
||||
custom_fusion_p.multiple_results = True
|
||||
|
||||
CustomPullBlockSpecRuleFn = Callable[[tuple[block_spec_lib.BlockIndexTransform, ...]],
|
||||
Sequence[block_spec_lib.BlockIndexTransform]]
|
||||
|
||||
CustomPushBlockSpecRuleFn = Callable[[tuple[pallas_core.BlockSpec, ...]],
|
||||
tuple[pallas_core.BlockSpec, ...]]
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CustomEvalContext:
|
||||
out_block_specs: tuple[pallas_core.BlockSpec, ...]
|
||||
out_block_indices: tuple[Any, ...]
|
||||
|
||||
class CustomEvalRuleFn(Protocol):
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
ctx: CustomEvalContext,
|
||||
*args: Any,
|
||||
) -> Sequence[Any]:
|
||||
...
|
||||
|
||||
|
||||
@custom_api_util.register_custom_decorator_type
|
||||
class custom_fusion:
|
||||
fun: Callable[..., Any]
|
||||
|
||||
eval_rule: CustomEvalRuleFn | None = None
|
||||
|
||||
pull_block_spec_rule: CustomPullBlockSpecRuleFn | None = None
|
||||
|
||||
# Optional if this custom_fusion is only used as an input fusion.
|
||||
push_block_spec_rule: CustomPushBlockSpecRuleFn | None = None
|
||||
|
||||
# Optional alternative implementation to use instead of `fun` for when this
|
||||
# custom fusion is run inside a Pallas kernel.
|
||||
pallas_impl: Callable[..., Any] | None = None
|
||||
|
||||
def __init__(self, fun: Callable[..., Any]):
|
||||
functools.update_wrapper(self, fun)
|
||||
self.fun = fun
|
||||
|
||||
def def_pallas_impl(self, pallas_impl):
|
||||
self.pallas_impl = pallas_impl
|
||||
return pallas_impl
|
||||
|
||||
def def_pull_block_spec(
|
||||
self, pull_block_spec_rule: CustomPullBlockSpecRuleFn):
|
||||
self.pull_block_spec_rule = pull_block_spec_rule
|
||||
return pull_block_spec_rule
|
||||
|
||||
def def_push_block_spec(
|
||||
self, push_block_spec_rule: CustomPushBlockSpecRuleFn):
|
||||
self.push_block_spec_rule = push_block_spec_rule
|
||||
return push_block_spec_rule
|
||||
|
||||
def def_eval_rule(self, eval_rule: CustomEvalRuleFn):
|
||||
self.eval_rule = eval_rule
|
||||
return eval_rule
|
||||
|
||||
@functools.partial(api_boundary,
|
||||
repro_api_name="jax.pallas.custom_fusion.__call__")
|
||||
def __call__(self, *args, **kwargs):
|
||||
debug_fun = api_util.debug_info("custom_fusion fun", self.fun, args, kwargs)
|
||||
|
||||
# TODO(jburnim): Better error messages here.
|
||||
assert self.eval_rule is not None
|
||||
assert self.pull_block_spec_rule is not None
|
||||
|
||||
try:
|
||||
args = api_util.resolve_kwargs(self.fun, args, kwargs)
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
"The input arguments to the custom_fusion-decorated function "
|
||||
f"{debug_fun.func_name} could not be resolved to positional-only "
|
||||
f"arguments. Binding failed with the error:\n{e}"
|
||||
) from e
|
||||
|
||||
# flatten and get jaxpr
|
||||
args_flat, in_tree = tree_util.tree_flatten(args)
|
||||
in_avals = [core.typeof(x) for x in args_flat]
|
||||
flat_fun, out_tree = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(self.fun, debug_info=debug_fun.with_unknown_names()),
|
||||
in_tree)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
|
||||
|
||||
# if a Pallas implementation was provided, get its jaxpr
|
||||
if self.pallas_impl is not None:
|
||||
debug_pallas_impl = api_util.debug_info(
|
||||
"custom_fusion pallas_impl", self.pallas_impl, args, kwargs)
|
||||
|
||||
flat_pallas_impl, pallas_out_tree = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(self.pallas_impl, debug_info=debug_pallas_impl),
|
||||
in_tree)
|
||||
# TODO(jburnim): Error if out_tree() and kernel_out_tree() are different?
|
||||
del pallas_out_tree
|
||||
pallas_jaxpr, _, pallas_consts = (
|
||||
pe.trace_to_jaxpr_dynamic(flat_pallas_impl, in_avals))
|
||||
else:
|
||||
pallas_jaxpr = None
|
||||
pallas_consts = []
|
||||
|
||||
# debug_info for rules
|
||||
out_flat = custom_fusion_p.bind(
|
||||
*consts,
|
||||
*pallas_consts,
|
||||
*args_flat,
|
||||
jaxpr=jaxpr,
|
||||
num_consts=len(consts),
|
||||
eval_rule=self.eval_rule,
|
||||
pull_block_spec_rule=self.pull_block_spec_rule,
|
||||
push_block_spec_rule=self.push_block_spec_rule,
|
||||
pallas_jaxpr=pallas_jaxpr,
|
||||
pallas_num_consts=len(pallas_consts),
|
||||
in_tree=in_tree,
|
||||
out_tree=out_tree(),
|
||||
kernel_out_tree=out_tree())
|
||||
|
||||
return tree_util.tree_unflatten(out_tree(), out_flat)
|
||||
|
||||
|
||||
@custom_fusion_p.def_impl
|
||||
def _custom_fusion_impl(
|
||||
*args,
|
||||
jaxpr: core.Jaxpr,
|
||||
num_consts: int,
|
||||
pallas_num_consts: int,
|
||||
**_):
|
||||
consts, _, args = util.split_list(args, [num_consts, pallas_num_consts])
|
||||
return core.eval_jaxpr(jaxpr, consts, *args)
|
||||
|
||||
mlir.register_lowering(custom_fusion_p, mlir.lower_fun(
|
||||
_custom_fusion_impl, multiple_results=True))
|
||||
|
||||
|
||||
@custom_fusion_p.def_effectful_abstract_eval
|
||||
def _custom_fusion_effectful_abstract_eval(
|
||||
*args,
|
||||
jaxpr: core.Jaxpr,
|
||||
pallas_jaxpr: core.Jaxpr | None,
|
||||
**_):
|
||||
del args
|
||||
# TODO(jburnim): Error if pallas_jaxpr has different number of outputs, or
|
||||
# different shapes and types of outputs?
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError(
|
||||
"custom_fusion-decorated function {jaxpr.debug_info.func_src_info} "
|
||||
"has effects, which is not yet supported: {jaxpr.effects}")
|
||||
if pallas_jaxpr is not None and pallas_jaxpr.effects:
|
||||
raise NotImplementedError(
|
||||
"custom_fusion-decorated function {jaxpr.debug_info.func_src_info} "
|
||||
"has a pallas_impl with effects, which is not yet supported: "
|
||||
f"{pallas_jaxpr.effects}")
|
||||
return jaxpr.out_avals, jaxpr.effects
|
||||
|
||||
|
||||
@block_spec_lib.register_eval_rule(custom_fusion_p)
|
||||
def _custom_fusion_eval_rule(
|
||||
ctx: block_spec_lib.KernelEvalContext,
|
||||
*args,
|
||||
eval_rule: CustomEvalRuleFn,
|
||||
num_consts: int,
|
||||
pallas_num_consts: int,
|
||||
**_):
|
||||
args = args[num_consts + pallas_num_consts:]
|
||||
return eval_rule(CustomEvalContext(
|
||||
out_block_specs=ctx.out_block_specs,
|
||||
out_block_indices=ctx.get_out_block_indices(),
|
||||
), *args)
|
||||
|
||||
|
||||
# TODO(jburnim): Lowering rules for SC and Mosaic GPU.
|
||||
|
||||
@mosaic_lowering.register_lowering_rule(custom_fusion_p)
|
||||
def _custom_fusion_mosaic_lowering_rule(
|
||||
ctx: mosaic_lowering.LoweringRuleContext,
|
||||
*args,
|
||||
jaxpr: core.Jaxpr,
|
||||
num_consts: int,
|
||||
pallas_jaxpr: core.Jaxpr | None,
|
||||
pallas_num_consts: int,
|
||||
**_):
|
||||
consts, pallas_consts, args = util.split_list(
|
||||
args, [num_consts, pallas_num_consts])
|
||||
if pallas_jaxpr is None:
|
||||
pallas_jaxpr = jaxpr
|
||||
pallas_consts = consts
|
||||
lowering_context = ctx.lowering_context.replace(block_shapes=ctx.block_shapes)
|
||||
return mosaic_lowering.jaxpr_subcomp(
|
||||
lowering_context, pallas_jaxpr, *pallas_consts, *args)
|
||||
|
||||
|
||||
@block_spec_lib.register_pull_block_spec_rule(custom_fusion_p)
|
||||
def _custom_fusion_pull_block_spec_rule(
|
||||
ctx : block_spec_lib.PullRuleContext,
|
||||
out_block_transforms : tuple[block_spec_lib.BlockIndexTransform, ...],
|
||||
*,
|
||||
pull_block_spec_rule : CustomPullBlockSpecRuleFn,
|
||||
**_,
|
||||
) -> Sequence[block_spec_lib.BlockIndexTransform]:
|
||||
del ctx
|
||||
return pull_block_spec_rule(out_block_transforms)
|
||||
|
||||
|
||||
@block_spec_lib.register_push_block_spec_rule(custom_fusion_p)
|
||||
def _custom_fusion_push_block_spec_rule(
|
||||
ctx : block_spec_lib.PushRuleContext,
|
||||
*block_specs : pallas_core.BlockSpec,
|
||||
push_block_spec_rule : CustomPushBlockSpecRuleFn,
|
||||
**_
|
||||
) -> tuple[pallas_core.BlockSpec, ...]:
|
||||
del ctx
|
||||
# TODO(jburnim): Better error message if push_block_spec_rule is None.
|
||||
return push_block_spec_rule(block_specs)
|
||||
|
||||
|
||||
@block_spec_lib.register_usage_rule(custom_fusion_p)
|
||||
def _custom_fusion_usage_rule(
|
||||
ctx : block_spec_lib.UsageRuleContext,
|
||||
used_out: Sequence[set[block_spec_lib.Usage]],
|
||||
*,
|
||||
jaxpr: core.Jaxpr,
|
||||
**_
|
||||
) -> Sequence[set[block_spec_lib.Usage]]:
|
||||
del ctx
|
||||
# TODO(jburnim): Error if jaxpr.jaxpr gives different usage than pallas_jaxpr?
|
||||
read_usage_env = block_spec_lib.compute_usage(jaxpr, used_out)
|
||||
return util.safe_map(read_usage_env, jaxpr.invars)
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Basic utils for fuser internals."""
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import tree_util
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
|
||||
|
||||
def make_jaxpr(f, *args, **kwargs):
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
flat_avals = [core.shaped_abstractify(x) for x in flat_args]
|
||||
debug_info = api_util.debug_info('make_jaxpr', f, args, kwargs)
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(f, debug_info=debug_info), in_tree
|
||||
)
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
|
||||
out_tree = out_tree_thunk()
|
||||
return jaxpr, consts, in_tree, out_tree
|
||||
@@ -0,0 +1,141 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Fusible primitive."""
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
from jax._src import core as jax_core
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.pallas.fuser import fusion as fusion_lib
|
||||
|
||||
fusible_p = jax_core.Primitive('fusible')
|
||||
fusible_p.multiple_results = True
|
||||
|
||||
def _fusible_is_high(*_, jaxpr, **params):
|
||||
del params
|
||||
return jaxpr.is_high
|
||||
|
||||
fusible_p.is_high = _fusible_is_high
|
||||
|
||||
|
||||
def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
|
||||
return fusion_lib.Fusion(
|
||||
func=lambda: x,
|
||||
in_type=((), {}),
|
||||
out_type=jax.typeof(x),
|
||||
)
|
||||
|
||||
|
||||
@partial(api_boundary, repro_api_name="fuser.fusible")
|
||||
def fusible(f=None, *, output_fusion_prefix: Any = True):
|
||||
def decorator(f):
|
||||
def wrapper(*args):
|
||||
def wrapped(*args):
|
||||
in_fusions = tree_util.tree_map(_make_trivial_fusion, args)
|
||||
output_fusions = tree_util.tree_unflatten(
|
||||
tree_util.tree_structure(output_fusion_prefix),
|
||||
[None] * len(tree_util.tree_leaves(output_fusion_prefix)),
|
||||
)
|
||||
return f(*in_fusions, output_fusions)
|
||||
|
||||
flat_args, in_tree = tree_util.tree_flatten(args)
|
||||
debug_info = api_util.debug_info('fusible', wrapped, args, {})
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(wrapped, debug_info=debug_info), in_tree
|
||||
)
|
||||
flat_avals = [jax_core.typeof(x) for x in flat_args]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
|
||||
out_tree = out_tree_thunk()
|
||||
out = fusible_p.bind(
|
||||
*consts,
|
||||
*flat_args,
|
||||
jaxpr=jaxpr,
|
||||
num_consts=len(consts),
|
||||
in_tree=in_tree,
|
||||
out_tree=out_tree,
|
||||
func=f,
|
||||
output_fusion_prefix=output_fusion_prefix,
|
||||
)
|
||||
return tree_util.tree_unflatten(out_tree, out)
|
||||
|
||||
return wrapper
|
||||
|
||||
if f is not None:
|
||||
return decorator(f)
|
||||
return decorator
|
||||
|
||||
|
||||
@fusible_p.def_impl
|
||||
def _(*consts_and_args, jaxpr, num_consts, **_):
|
||||
consts, args = util.split_list(consts_and_args, [num_consts])
|
||||
return jax_core.eval_jaxpr(jaxpr, consts, *args)
|
||||
|
||||
|
||||
mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl))
|
||||
|
||||
|
||||
@fusible_p.def_effectful_abstract_eval
|
||||
def _(*args, jaxpr, **kwargs):
|
||||
del args, kwargs
|
||||
return [v.aval for v in jaxpr.outvars], jaxpr.effects
|
||||
|
||||
|
||||
def _fusible_trivial_batching_rule(axis_data, args, dims, **kwargs):
|
||||
if axis_data.size != 1:
|
||||
raise NotImplementedError('fusible does not support non-trivial batching')
|
||||
|
||||
unbatched_args = tuple(
|
||||
a if (d is batching.not_mapped or d is None) else a[d]
|
||||
for a, d in zip(args, dims, strict=True)
|
||||
)
|
||||
out_unbatched = fusible_p.bind(*unbatched_args, **kwargs)
|
||||
out = tuple(o[None] for o in out_unbatched)
|
||||
|
||||
return out, (0,) * len(out)
|
||||
|
||||
batching.fancy_primitive_batchers[fusible_p] = _fusible_trivial_batching_rule
|
||||
|
||||
|
||||
def _fusible_to_lojax(*hi_args, jaxpr, num_consts, **_):
|
||||
const_in_avals = jaxpr.in_aval_qdds[:num_consts]
|
||||
num_lo_consts = sum(len(aval.lo_ty()) for aval in const_in_avals)
|
||||
|
||||
lo_args = [
|
||||
lo_val
|
||||
for aval, x in util.safe_zip(jaxpr.in_aval_qdds, hi_args)
|
||||
for lo_val in (aval.read_loval(x) if aval.has_qdd else aval.lower_val(x))
|
||||
]
|
||||
|
||||
closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, lo_args[:num_lo_consts])
|
||||
|
||||
lo_jaxpr = pe.lower_jaxpr2(closed_jaxpr)
|
||||
all_outs = fusible_p.bind(*lo_args, jaxpr=lo_jaxpr.jaxpr, num_consts=num_lo_consts)
|
||||
|
||||
out_mut, lo_outs = util.split_list(all_outs, [pe.num_himuts_out(jaxpr.final_aval_qdds)])
|
||||
for a, x, us in zip(jaxpr.final_aval_qdds, hi_args, out_mut):
|
||||
if a.has_qdd:
|
||||
a.aval.update_from_loval(a.qdd, x, *us)
|
||||
return pe.raise_lo_outs(jaxpr.out_avals, lo_outs)
|
||||
|
||||
|
||||
fusible_p.to_lojax = _fusible_to_lojax
|
||||
@@ -0,0 +1,567 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Custom fusible dtypes."""
|
||||
|
||||
import abc
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools as it
|
||||
from typing import Any, TypeVar
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import state
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax.control_flow import conditionals
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas import pallas_call
|
||||
from jax._src.pallas import primitives as pallas_primitives
|
||||
from jax._src.pallas.fuser import block_spec
|
||||
from jax._src.pallas.fuser.fusible import fusible_p
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.state import primitives as state_primitives
|
||||
from jax._src.util import foreach
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
_physicalize_rules: dict[core.Primitive, Callable[..., Any]] = {}
|
||||
|
||||
pack_dtype_p = core.Primitive("pack_dtype")
|
||||
|
||||
|
||||
@pack_dtype_p.def_abstract_eval
|
||||
def pack_dtype_abstract_eval(*xs, dtype):
|
||||
if dtypes.issubdtype(dtype, FusibleElementDType):
|
||||
return dtype.abstract_pack(*xs)
|
||||
raise ValueError("Attempted to pack non-fusion dtype: {dtype}")
|
||||
|
||||
|
||||
def pack(*xs, dtype):
|
||||
return pack_dtype_p.bind(*xs, dtype=dtype)
|
||||
|
||||
|
||||
unpack_dtype_p = core.Primitive("unpack_dtype")
|
||||
unpack_dtype_p.multiple_results = True
|
||||
|
||||
|
||||
@unpack_dtype_p.def_abstract_eval
|
||||
def unpack_dtype_abstract_eval(x):
|
||||
if dtypes.issubdtype(x.dtype, FusibleElementDType):
|
||||
return x.dtype.abstract_unpack(x)
|
||||
elif isinstance(x.dtype, state.AbstractRef):
|
||||
raise NotImplementedError()
|
||||
raise ValueError("Attempted to unpack non-fusion dtype: {dtype}")
|
||||
|
||||
|
||||
def unpack(x):
|
||||
return tuple(unpack_dtype_p.bind(x))
|
||||
|
||||
|
||||
class FusibleElementDType(dtypes.extended):
|
||||
"""Scalar dtype for fusible dtypes."""
|
||||
|
||||
|
||||
class FusibleTyRules:
|
||||
allow_conversion: bool = False
|
||||
|
||||
|
||||
class FusionDType(dtypes.ExtendedDType, util.StrictABC):
|
||||
"""Base class for fusible extended dtypes."""
|
||||
|
||||
_op_registry = {}
|
||||
_rules = FusibleTyRules
|
||||
type = FusibleElementDType
|
||||
|
||||
@abc.abstractmethod
|
||||
def abstract_unpack(self, x) -> Sequence[Any]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def abstract_pack(self, *xs):
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def register_op(cls, primitive):
|
||||
def _register_fn(fn):
|
||||
cls._op_registry[primitive] = fn
|
||||
|
||||
return _register_fn
|
||||
|
||||
@classmethod
|
||||
def get_op_rule(cls, primitive):
|
||||
return cls._op_registry.get(primitive)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return str(self)
|
||||
|
||||
@abc.abstractmethod
|
||||
def pull_block_spec_one_step(self, aval_out, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def unpack_push_block_spec(self, aval_in, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def unpack_pull_block_spec(self, aval_in, *args, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def physicalize(f):
|
||||
"""Runs a function that contains fusible extended dtypes."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if kwargs:
|
||||
raise NotImplementedError()
|
||||
flattened_args, treedef = jax.tree.flatten(args)
|
||||
debug_info = api_util.debug_info("physicalize", f, args, kwargs)
|
||||
wrapped_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(f, debug_info=debug_info), treedef
|
||||
)
|
||||
avals = [core.typeof(a) for a in flattened_args]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
|
||||
new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
|
||||
out_flat = core.eval_jaxpr(
|
||||
new_jaxpr.jaxpr, new_jaxpr.consts, *flattened_args
|
||||
)
|
||||
return tree_util.tree_unflatten(out_tree_thunk(), out_flat)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@util.weakref_lru_cache
|
||||
def physicalize_closed_jaxpr(jaxpr: core.ClosedJaxpr) -> core.ClosedJaxpr:
|
||||
"""Replaces all extended dtypes with physical types in a jaxpr."""
|
||||
fun = functools.partial(physicalize_interp, jaxpr.jaxpr, jaxpr.consts)
|
||||
in_avals = [_physical_aval(aval) for aval in jaxpr.in_avals]
|
||||
flat_avals, treedef = tree_util.tree_flatten(in_avals)
|
||||
debug_info = api_util.debug_info("physicalize_closed_jaxpr", fun, (), {})
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(fun, debug_info=debug_info.with_unknown_names()), treedef
|
||||
)
|
||||
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals)
|
||||
assert len(new_jaxpr.constvars) == len(consts), "Mismatched consts"
|
||||
return core.ClosedJaxpr(new_jaxpr, consts)
|
||||
|
||||
|
||||
def _physical_aval(aval):
|
||||
if isinstance(aval, core.ShapedArray):
|
||||
if isinstance(aval.dtype, FusionDType):
|
||||
return aval.dtype.abstract_unpack(aval)
|
||||
return core.ShapedArray(aval.shape, aval.dtype)
|
||||
if isinstance(aval, state.AbstractRef):
|
||||
if _is_fusion_type(aval):
|
||||
unpacked = aval.dtype.abstract_unpack(aval.inner_aval)
|
||||
return tuple(aval.update(inner_aval=u) for u in unpacked)
|
||||
return aval
|
||||
return aval
|
||||
|
||||
|
||||
def physicalize_jaxpr(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
||||
"""Replaces all extended dtypes with physical types in a jaxpr."""
|
||||
|
||||
def _flat_jaxpr_eval(consts, args):
|
||||
return physicalize_interp(jaxpr, consts, *args)
|
||||
|
||||
in_avals = [_physical_aval(v.aval) for v in jaxpr.invars]
|
||||
const_avals = [_physical_aval(v.aval) for v in jaxpr.constvars]
|
||||
flat_avals, treedef = jax.tree.flatten((const_avals, in_avals))
|
||||
debug_info = api_util.debug_info(
|
||||
"physicalize_jaxpr", _flat_jaxpr_eval, (const_avals, in_avals), {})
|
||||
wrapped_fun, _ = api_util.flatten_fun_nokwargs(
|
||||
lu.wrap_init(_flat_jaxpr_eval, debug_info=debug_info), treedef
|
||||
)
|
||||
new_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, flat_avals)
|
||||
assert not consts
|
||||
new_jaxpr = pe.convert_invars_to_constvars(
|
||||
new_jaxpr, len(tree_util.tree_leaves(const_avals))
|
||||
)
|
||||
return new_jaxpr
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Context:
|
||||
avals_in: Sequence[Any]
|
||||
avals_out: Sequence[Any]
|
||||
|
||||
|
||||
def physicalize_interp(
|
||||
jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value
|
||||
):
|
||||
"""Physicalizes a jaxpr by replacing fusible dtypes with physical types."""
|
||||
# TODO: Merge into JAX core.
|
||||
env: dict[core.Var, Any] = {}
|
||||
|
||||
def read_env(var: core.Atom):
|
||||
if isinstance(var, core.Literal):
|
||||
return var.val
|
||||
return env[var]
|
||||
|
||||
def write_env(var: core.Var, val: Any):
|
||||
env[var] = val
|
||||
|
||||
foreach(write_env, jaxpr.constvars, consts)
|
||||
assert len(jaxpr.invars) == len(
|
||||
args
|
||||
), f"Length mismatch: {jaxpr.invars} != {args}"
|
||||
foreach(write_env, jaxpr.invars, args)
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
invals = list(map(read_env, eqn.invars))
|
||||
avals_in = tuple(x.aval for x in eqn.invars)
|
||||
name_stack = (
|
||||
source_info_util.current_name_stack() + eqn.source_info.name_stack
|
||||
)
|
||||
with (
|
||||
source_info_util.user_context(
|
||||
eqn.source_info.traceback, name_stack=name_stack
|
||||
),
|
||||
eqn.ctx.manager,
|
||||
):
|
||||
# need to check types and then invoke the correct rule.
|
||||
ctx = Context(
|
||||
avals_in=avals_in, avals_out=[var.aval for var in eqn.outvars]
|
||||
)
|
||||
custom_rule = _phys_find_rule(eqn.primitive, avals_in)
|
||||
if custom_rule:
|
||||
outvals = custom_rule(ctx, *invals, **eqn.params)
|
||||
else:
|
||||
bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
outvals = eqn.primitive.bind(*invals, **bind_params)
|
||||
|
||||
if eqn.primitive.multiple_results:
|
||||
assert len(outvals) == len(eqn.outvars), eqn
|
||||
foreach(write_env, eqn.outvars, outvals)
|
||||
else:
|
||||
write_env(eqn.outvars[0], outvals)
|
||||
|
||||
return map(read_env, jaxpr.outvars)
|
||||
|
||||
|
||||
def _is_fusion_type(aval: core.AbstractValue):
|
||||
"""Returns whether an aval is an array containing fusion types."""
|
||||
return (
|
||||
isinstance(aval, (core.ShapedArray, state.AbstractRef))
|
||||
and hasattr(aval, 'dtype')
|
||||
and isinstance(aval.dtype, FusionDType)
|
||||
)
|
||||
|
||||
|
||||
def _phys_find_rule(primitive, avals: Sequence[core.AbstractValue]):
|
||||
"""Finds the physicalization rule for a primitive."""
|
||||
if primitive in _physicalize_rules:
|
||||
return _physicalize_rules[primitive]
|
||||
|
||||
# pyrefly: ignore[missing-attribute]
|
||||
fusion_types = {aval.dtype for aval in avals if _is_fusion_type(aval)}
|
||||
if len(fusion_types) == 0:
|
||||
return None
|
||||
elif len(fusion_types) > 1:
|
||||
raise ValueError(f"Multiple fusion types for primitive: {fusion_types}")
|
||||
fusion_type = fusion_types.pop()
|
||||
if primitive not in fusion_type._op_registry:
|
||||
raise ValueError(
|
||||
f"No implementation found for primitive {primitive} "
|
||||
f"for custom type {fusion_type}"
|
||||
)
|
||||
return fusion_type.get_op_rule(primitive)
|
||||
|
||||
|
||||
def _assert_no_fusion_types(avals: Sequence[core.AbstractValue]):
|
||||
if any(_is_fusion_type(aval) for aval in avals):
|
||||
raise NotImplementedError(f"Fusion type found in avals: {avals}")
|
||||
|
||||
|
||||
def _pallas_call_physicalize_rule(
|
||||
ctx: Context, *args, jaxpr, grid_mapping: pallas_core.GridMapping, **kwargs
|
||||
):
|
||||
_assert_no_fusion_types(ctx.avals_in)
|
||||
_assert_no_fusion_types(ctx.avals_out)
|
||||
with grid_mapping.trace_env():
|
||||
new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, ()))
|
||||
if diff := len(new_jaxpr.jaxpr.invars) - len(jaxpr.invars):
|
||||
num_scratch_avals = len(grid_mapping.scratch_avals) + diff
|
||||
new_scratch_avals = tuple(v.aval for v in
|
||||
new_jaxpr.jaxpr.invars[-num_scratch_avals:])
|
||||
grid_mapping = grid_mapping.replace(
|
||||
scratch_avals=new_scratch_avals
|
||||
)
|
||||
return pallas_call.pallas_call_p.bind(
|
||||
*args, jaxpr=new_jaxpr.jaxpr, grid_mapping=grid_mapping, **kwargs
|
||||
)
|
||||
|
||||
|
||||
_physicalize_rules[pallas_call.pallas_call_p] = _pallas_call_physicalize_rule
|
||||
|
||||
|
||||
def _cond_physicalize_rule(ctx: Context, *args, branches, **kwargs):
|
||||
_assert_no_fusion_types(ctx.avals_out)
|
||||
physicalized_branches = tuple(
|
||||
physicalize_closed_jaxpr(branch) for branch in branches
|
||||
)
|
||||
flat_args = jax.tree.leaves(args)
|
||||
return conditionals.cond_p.bind(
|
||||
*flat_args, branches=physicalized_branches, **kwargs
|
||||
)
|
||||
|
||||
|
||||
_physicalize_rules[conditionals.cond_p] = _cond_physicalize_rule
|
||||
|
||||
|
||||
@lu.transformation2
|
||||
def _physicalize_transform(f, *args):
|
||||
vals, zeros = args[::2], args[1::2]
|
||||
assert len(vals) == len(zeros)
|
||||
wrapper = lambda *inner_vals: f(
|
||||
*it.chain.from_iterable(zip(inner_vals, zeros))
|
||||
)
|
||||
return physicalize(wrapper)(*vals)
|
||||
|
||||
|
||||
@lu.transformation2
|
||||
def _physicalize_transform_bwd(f, const_avals, *args):
|
||||
return [custom_derivatives.Zero(a) for a in const_avals] + list(
|
||||
physicalize(f)(*args)
|
||||
)
|
||||
|
||||
|
||||
def _custom_vjp_call_physicalize_rule(
|
||||
ctx: Context, *args, call_jaxpr, num_consts, fwd_jaxpr_thunk, bwd, **kwargs
|
||||
):
|
||||
_assert_no_fusion_types(ctx.avals_out)
|
||||
new_jaxpr = physicalize_closed_jaxpr(call_jaxpr)
|
||||
fun = lu.wrap_init(core.jaxpr_as_fun(new_jaxpr),
|
||||
debug_info=call_jaxpr.jaxpr.debug_info)
|
||||
fwd = custom_derivatives.lift_fwd(num_consts, fwd_jaxpr_thunk)
|
||||
fwd_physicalized = _physicalize_transform(fwd)
|
||||
const_avals, _ = util.split_list(new_jaxpr.in_avals, [num_consts])
|
||||
bwd_physicalized = _physicalize_transform_bwd(bwd, const_avals)
|
||||
kwargs['subfuns'] = (fun, fwd_physicalized, bwd_physicalized)
|
||||
return custom_derivatives.custom_vjp_call_p.bind(*args, **kwargs)
|
||||
|
||||
_physicalize_rules[custom_derivatives.custom_vjp_call_p] = _custom_vjp_call_physicalize_rule
|
||||
|
||||
|
||||
def _run_state_rule(ctx: Context, *args, jaxpr, which_linear, is_initialized):
|
||||
_assert_no_fusion_types(ctx.avals_in)
|
||||
_assert_no_fusion_types(ctx.avals_out)
|
||||
jaxpr = physicalize_jaxpr(jaxpr)
|
||||
return state_discharge.run_state_p.bind(
|
||||
*args,
|
||||
jaxpr=jaxpr,
|
||||
which_linear=which_linear,
|
||||
is_initialized=is_initialized,
|
||||
)
|
||||
|
||||
|
||||
_physicalize_rules[state_discharge.run_state_p] = _run_state_rule
|
||||
|
||||
|
||||
def _core_map_rule(ctx: Context, *args, jaxpr, **params):
|
||||
_assert_no_fusion_types(ctx.avals_in)
|
||||
_assert_no_fusion_types(ctx.avals_out)
|
||||
assert not jaxpr.invars
|
||||
with core.extend_axis_env_nd(params["mesh"].shape.items()):
|
||||
jaxpr = physicalize_jaxpr(jaxpr)
|
||||
return pallas_core.core_map_p.bind(*args, jaxpr=jaxpr, **params)
|
||||
|
||||
|
||||
_physicalize_rules[pallas_core.core_map_p] = _core_map_rule
|
||||
|
||||
|
||||
def _run_scoped_rule(ctx: Context, *args, jaxpr, **params):
|
||||
_assert_no_fusion_types(ctx.avals_out)
|
||||
jaxpr = physicalize_jaxpr(jaxpr)
|
||||
flat_args = tree_util.tree_leaves(args)
|
||||
assert len(flat_args) == len(
|
||||
jaxpr.constvars
|
||||
), f"Length mismatch: {len(flat_args)=} != {len(jaxpr.constvars)=}"
|
||||
return pallas_primitives.run_scoped_p.bind(*flat_args, jaxpr=jaxpr, **params)
|
||||
|
||||
|
||||
_physicalize_rules[pallas_primitives.run_scoped_p] = _run_scoped_rule
|
||||
|
||||
|
||||
def _scan_rule(ctx: Context, *args, jaxpr, **params):
|
||||
_assert_no_fusion_types(ctx.avals_in)
|
||||
_assert_no_fusion_types(ctx.avals_out)
|
||||
jaxpr = physicalize_closed_jaxpr(jaxpr)
|
||||
return jax.lax.scan_p.bind(*args, jaxpr=jaxpr, **params)
|
||||
|
||||
|
||||
_physicalize_rules[jax.lax.scan_p] = _scan_rule
|
||||
|
||||
|
||||
def _while_rule(
|
||||
ctx: Context, *args, body_jaxpr, cond_jaxpr, body_nconsts,
|
||||
cond_nconsts, **params
|
||||
):
|
||||
_assert_no_fusion_types(ctx.avals_out)
|
||||
cond_avals = [v.aval for v in cond_jaxpr.jaxpr.invars]
|
||||
_, cond_in_avals = util.split_list(cond_avals, [cond_nconsts])
|
||||
_assert_no_fusion_types(cond_in_avals)
|
||||
new_cond_jaxpr = physicalize_closed_jaxpr(cond_jaxpr)
|
||||
new_num_cond_consts = (
|
||||
cond_nconsts
|
||||
+ len(new_cond_jaxpr.jaxpr.invars)
|
||||
- len(cond_jaxpr.jaxpr.invars)
|
||||
)
|
||||
|
||||
body_avals = [v.aval for v in body_jaxpr.jaxpr.invars]
|
||||
_, body_in_avals = util.split_list(body_avals, [body_nconsts])
|
||||
_assert_no_fusion_types(body_in_avals)
|
||||
new_body_jaxpr = physicalize_closed_jaxpr(body_jaxpr)
|
||||
new_num_body_consts = (
|
||||
body_nconsts
|
||||
+ len(new_body_jaxpr.jaxpr.invars)
|
||||
- len(body_jaxpr.jaxpr.invars)
|
||||
)
|
||||
flat_args = tree_util.tree_leaves(args)
|
||||
cond_consts, body_consts, flat_args = util.split_list(
|
||||
flat_args, [new_num_cond_consts, new_num_body_consts]
|
||||
)
|
||||
assert len(flat_args) + len(body_consts) == len(
|
||||
new_body_jaxpr.jaxpr.invars), (
|
||||
f"Length mismatch: {len(flat_args) + len(body_consts)} !="
|
||||
f" {len(new_body_jaxpr.jaxpr.invars)=}"
|
||||
)
|
||||
assert len(flat_args) + len(cond_consts) == len(
|
||||
new_cond_jaxpr.jaxpr.invars), (
|
||||
f"Length mismatch: {len(flat_args) + len(cond_consts)} !="
|
||||
f" {len(new_cond_jaxpr.jaxpr.invars)=}"
|
||||
)
|
||||
return jax.lax.while_p.bind(
|
||||
*(cond_consts + body_consts + flat_args),
|
||||
body_jaxpr=new_body_jaxpr,
|
||||
cond_jaxpr=new_cond_jaxpr,
|
||||
body_nconsts=new_num_body_consts,
|
||||
cond_nconsts=new_num_cond_consts,
|
||||
**params,
|
||||
)
|
||||
|
||||
|
||||
_physicalize_rules[jax.lax.while_p] = _while_rule
|
||||
|
||||
|
||||
def _pack_rule(_, *args, dtype):
|
||||
del dtype
|
||||
return args
|
||||
|
||||
|
||||
_physicalize_rules[pack_dtype_p] = _pack_rule
|
||||
|
||||
|
||||
def _unpack_rule(_, arg):
|
||||
return arg
|
||||
|
||||
|
||||
_physicalize_rules[unpack_dtype_p] = _unpack_rule
|
||||
|
||||
|
||||
def _swap_rule(ctx: Context, ref, val, *args, tree):
|
||||
ref_aval, *_ = ctx.avals_in
|
||||
if not _is_fusion_type(ref_aval):
|
||||
return state_primitives.swap_p.bind(ref, val, *args, tree=tree)
|
||||
return ref_aval.dtype.swap(ref, val, *args, tree=tree)
|
||||
|
||||
|
||||
_physicalize_rules[state_primitives.swap_p] = _swap_rule
|
||||
|
||||
|
||||
def _get_rule(ctx: Context, ref, *args, tree):
|
||||
ref_aval, *_ = ctx.avals_in
|
||||
if not _is_fusion_type(ref_aval):
|
||||
return state_primitives.get_p.bind(ref, *args, tree=tree)
|
||||
return ref_aval.dtype.get(ref, *args, tree=tree)
|
||||
|
||||
|
||||
_physicalize_rules[state_primitives.get_p] = _get_rule
|
||||
|
||||
|
||||
@block_spec.register_eval_rule(pack_dtype_p)
|
||||
def _pack_dtype_eval_rule(ctx: block_spec.KernelEvalContext, *args, dtype):
|
||||
return dtype.pack_eval_rule(ctx, *args)
|
||||
|
||||
|
||||
@block_spec.register_pull_block_spec_rule(pack_dtype_p)
|
||||
def _pack_dtype_pull_rule(
|
||||
ctx: block_spec.PullRuleContext,
|
||||
block_spec: pallas_core.BlockSpec,
|
||||
*,
|
||||
dtype: FusionDType,
|
||||
):
|
||||
aval_out = ctx.avals_out[0]
|
||||
return dtype.pull_block_spec_one_step(aval_out, block_spec)
|
||||
|
||||
|
||||
@block_spec.register_push_block_spec_rule(unpack_dtype_p)
|
||||
def _unpack_dtype_push_rule(
|
||||
ctx: block_spec.PushRuleContext,
|
||||
block_spec: pallas_core.BlockSpec,
|
||||
):
|
||||
aval_in = ctx.avals_in[0]
|
||||
assert isinstance(aval_in, core.ShapedArray)
|
||||
assert isinstance(aval_in.dtype, FusionDType), aval_in.dtype
|
||||
return aval_in.dtype.unpack_push_block_spec(aval_in, block_spec)
|
||||
|
||||
|
||||
@block_spec.register_pull_block_spec_rule(unpack_dtype_p)
|
||||
def _unpack_dtype_pull_rule(
|
||||
ctx: block_spec.PushRuleContext,
|
||||
block_specs: pallas_core.BlockSpec,
|
||||
):
|
||||
aval_in = ctx.avals_in[0]
|
||||
assert isinstance(aval_in, core.ShapedArray)
|
||||
assert isinstance(aval_in.dtype, FusionDType), aval_in.dtype
|
||||
return aval_in.dtype.unpack_pull_block_spec(aval_in, *block_specs) # pyrefly: ignore[not-iterable]
|
||||
|
||||
|
||||
@block_spec.register_eval_rule(unpack_dtype_p)
|
||||
def _unpack_dtype_eval_rule(ctx: block_spec.KernelEvalContext, *args):
|
||||
assert ctx.avals_in is not None
|
||||
aval_in = ctx.avals_in[0]
|
||||
assert isinstance(aval_in, core.ShapedArray)
|
||||
assert isinstance(aval_in.dtype, FusionDType), aval_in.dtype
|
||||
return aval_in.dtype.unpack_eval_rule(ctx, *args) # pyrefly: ignore[missing-attribute]
|
||||
|
||||
|
||||
def _fusible_physicalize_rule(
|
||||
_, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func
|
||||
):
|
||||
consts, _ = util.split_list(consts_and_args, [num_consts])
|
||||
new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
|
||||
return fusible_p.bind(
|
||||
*consts_and_args,
|
||||
jaxpr=new_jaxpr.jaxpr,
|
||||
num_consts=num_consts,
|
||||
in_tree=in_tree,
|
||||
out_tree=out_tree,
|
||||
func=func,
|
||||
)
|
||||
|
||||
|
||||
_physicalize_rules[fusible_p] = _fusible_physicalize_rule
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Fusion classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Generic, ParamSpec, TypeVar
|
||||
from collections.abc import Callable
|
||||
|
||||
import jax
|
||||
from jax._src import util
|
||||
|
||||
safe_map = util.safe_map
|
||||
|
||||
A = ParamSpec("A")
|
||||
K = TypeVar("K")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Fusion(Generic[A, K]):
|
||||
|
||||
func: Callable[A, K]
|
||||
in_type: tuple[tuple[Any, ...], dict[str, Any]]
|
||||
out_type: Any
|
||||
|
||||
def __call__(self, *args: A.args, **kwargs: A.kwargs) -> K:
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return jax.tree.map(lambda x: x.shape, self.out_type)
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return jax.tree.map(lambda x: x.dtype, self.out_type)
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self.out_type
|
||||
|
||||
@property
|
||||
def in_shape(self):
|
||||
return jax.tree.map(lambda x: x.shape, self.in_type)
|
||||
|
||||
@property
|
||||
def in_dtype(self):
|
||||
return jax.tree.map(lambda x: x.dtype, self.in_type)
|
||||
@@ -0,0 +1,318 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Fuses a function."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import functools
|
||||
from typing import Any
|
||||
import jax
|
||||
from jax._src import api_util
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src import tree_util
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.pallas.fuser import fusible_dtype
|
||||
from jax._src.pallas.fuser import fusion as fusion_lib
|
||||
from jax._src.pallas.fuser.fusible import fusible_p
|
||||
|
||||
|
||||
@functools.partial(api_boundary, repro_api_name="fuser.fuse")
|
||||
def fuse(f=None, *, resolve_fusion_dtypes: bool = True, debug: bool = False):
|
||||
"""Fuses a function into a single fusible.
|
||||
|
||||
Args:
|
||||
f: The function to fuse.
|
||||
resolve_fusion_dtypes: (experimental) whether or not to resolve fusion
|
||||
dtypes (which don't correspond to physical dtypes)
|
||||
debug: Whether to print debug information.
|
||||
|
||||
There should be a single call to a `fusible` inside the body of `f`. `fuse`
|
||||
returns a transformed function that will fuse the surrounding computation into
|
||||
the fusible and invoke it.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
def wrapper(*args, **kwargs):
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
debug_info = api_util.debug_info("fuse", f, args, kwargs)
|
||||
flat_fun, out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(f, debug_info=debug_info), in_tree
|
||||
)
|
||||
flat_avals = [jax_core.typeof(x) for x in flat_args]
|
||||
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
|
||||
if debug:
|
||||
print("Jaxpr before fusion:")
|
||||
print(jaxpr)
|
||||
out_tree = out_tree_thunk()
|
||||
out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
|
||||
return tree_util.tree_unflatten(out_tree, out_flat)
|
||||
|
||||
if resolve_fusion_dtypes:
|
||||
wrapper = fusible_dtype.physicalize(wrapper)
|
||||
return wrapper
|
||||
|
||||
if f is not None:
|
||||
return decorator(f)
|
||||
return decorator
|
||||
|
||||
|
||||
_fusible: dict[jax_core.Primitive, Any] = {}
|
||||
|
||||
|
||||
def _construct_fusion_jaxpr(
|
||||
candidate_values, jaxpr: jax_core.Jaxpr, outvars, *invars, **kwargs
|
||||
):
|
||||
flat_outvars, out_tree = tree_util.tree_flatten(outvars)
|
||||
flat_invars, in_tree = tree_util.tree_flatten((invars, kwargs))
|
||||
new_jaxpr_no_dce = jaxpr.replace(
|
||||
outvars=flat_outvars,
|
||||
constvars=jaxpr.constvars + jaxpr.invars,
|
||||
invars=flat_invars,
|
||||
debug_info=jaxpr.debug_info.with_unknown_names()
|
||||
)
|
||||
new_jaxpr, used_consts, used_invars = pe.dce_jaxpr_consts(
|
||||
new_jaxpr_no_dce,
|
||||
[True] * len(new_jaxpr_no_dce.outvars),
|
||||
instantiate=[False] * len(new_jaxpr_no_dce.constvars)
|
||||
+ [True] * len(new_jaxpr_no_dce.invars),
|
||||
)
|
||||
assert all(used_invars), new_jaxpr_no_dce
|
||||
new_values = tuple(
|
||||
c for used, c in zip(used_consts, candidate_values, strict=True) if used
|
||||
)
|
||||
kernel_in_tree = tree_util.tree_structure((invars, kwargs))
|
||||
flat_in_type = [x.aval for x in flat_invars]
|
||||
in_type = tree_util.tree_unflatten(kernel_in_tree, flat_in_type)
|
||||
out_type = tree_util.tree_unflatten(
|
||||
out_tree,
|
||||
[x.aval for x in flat_outvars],
|
||||
)
|
||||
return new_jaxpr, new_values, in_type, out_type, out_tree
|
||||
|
||||
|
||||
def construct_input_fusion(
|
||||
candidate_values, jaxpr: jax_core.Jaxpr, outvars
|
||||
) -> fusion_lib.Fusion:
|
||||
new_jaxpr, new_values, in_type, out_type, out_tree = _construct_fusion_jaxpr(
|
||||
candidate_values, jaxpr, outvars,
|
||||
)
|
||||
|
||||
def _fn():
|
||||
out_flat = jax_core.eval_jaxpr(new_jaxpr, new_values)
|
||||
return tree_util.tree_unflatten(out_tree, out_flat)
|
||||
|
||||
return fusion_lib.Fusion(_fn, in_type, out_type)
|
||||
|
||||
|
||||
def _find_downstream(
|
||||
jaxpr: jax_core.Jaxpr, in_used: Sequence[bool]
|
||||
) -> tuple[bool, ...]:
|
||||
# TODO(sharadmv): We use partial_eval to query downstream dependencies which
|
||||
# is not an officially sanctioned way to do so, since PE is really used for
|
||||
# AD. In the future, we should have a special Jaxpr API that queries this.
|
||||
_, _, out_used, *_ = pe.partial_eval_jaxpr_custom(
|
||||
jaxpr,
|
||||
in_unknowns=in_used,
|
||||
in_inst=in_used,
|
||||
ensure_out_unknowns=False,
|
||||
ensure_out_inst=False,
|
||||
saveable=lambda *_, **__: False,
|
||||
)
|
||||
return tuple(out_used)
|
||||
|
||||
|
||||
def _construct_output_permutation(
|
||||
used: list[tuple[bool, ...]],
|
||||
) -> list[int]:
|
||||
order = []
|
||||
for u in used:
|
||||
true_vals = [i for i in range(len(u)) if u[i]]
|
||||
order.extend(true_vals)
|
||||
return [order.index(i) for i in range(len(order))]
|
||||
|
||||
|
||||
def _construct_output_fusions(
|
||||
candidate_values,
|
||||
jaxpr,
|
||||
out_tree,
|
||||
fusion_eqn_index,
|
||||
fusion_eqn_outvars, # Flat list of vars output by the fusible eqn
|
||||
fusion_eqn_out_tree, # Tree structure of the fusible eqn outputs
|
||||
output_fusion_prefix, # Pytree defining output groups
|
||||
):
|
||||
# 1. Create jaxpr_out: represents computation *after* the fusible
|
||||
# Inputs: fusion_eqn_outvars
|
||||
# Outputs: jaxpr.outvars
|
||||
jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr(
|
||||
candidate_values,
|
||||
jaxpr.replace(
|
||||
eqns=jaxpr.eqns[:fusion_eqn_index]
|
||||
+ jaxpr.eqns[fusion_eqn_index + 1 :]
|
||||
),
|
||||
tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs
|
||||
tree_util.tree_unflatten(
|
||||
fusion_eqn_out_tree, fusion_eqn_outvars
|
||||
), # Fusible outputs as inputs
|
||||
)
|
||||
|
||||
# 2. Group fusible outputs based on the mask
|
||||
unflat_fusible_outvars = jax.tree.unflatten(
|
||||
fusion_eqn_out_tree, fusion_eqn_outvars
|
||||
)
|
||||
partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to(
|
||||
unflat_fusible_outvars
|
||||
)
|
||||
|
||||
# 3. Calculate dependencies and check disjointedness
|
||||
downstream_outputs_used_masks = [] # List of bool tuples, one per group
|
||||
already_used_final_outputs = set() # Indices of final outputs already claimed
|
||||
for outvars_group in partial_flat:
|
||||
# Identify vars in this group
|
||||
used_fusible_outvars = set(jax.tree.leaves(outvars_group))
|
||||
# Create mask for jaxpr_out inputs corresponding to this group
|
||||
in_used_mask = [
|
||||
True if v in used_fusible_outvars else False for v in jaxpr_out.invars
|
||||
]
|
||||
# Trace dependencies through jaxpr_out to find which final outputs are affected
|
||||
downstream_used_mask = _find_downstream(
|
||||
jaxpr_out, in_used_mask
|
||||
) # Mask for jaxpr_out.outvars (== jaxpr.outvars)
|
||||
|
||||
# Check for overlap in final output usage across groups
|
||||
for i, used in enumerate(downstream_used_mask):
|
||||
if used:
|
||||
if i in already_used_final_outputs:
|
||||
raise ValueError(
|
||||
"Outputs must be disjoint in order to use separate output fusions"
|
||||
)
|
||||
already_used_final_outputs.add(i)
|
||||
downstream_outputs_used_masks.append(downstream_used_mask)
|
||||
|
||||
# 4. Construct output permutation needed to restore original output order
|
||||
output_permutation = _construct_output_permutation(
|
||||
downstream_outputs_used_masks
|
||||
)
|
||||
|
||||
# Construct fusions for each group by DCEing the jaxpr_out
|
||||
output_fusions: list[fusion_lib.Fusion | None] = []
|
||||
for i, outvars_group in enumerate(partial_flat):
|
||||
flat_group_vars, _ = tree_util.tree_flatten(outvars_group)
|
||||
downstream_used_mask = downstream_outputs_used_masks[i]
|
||||
|
||||
used_jaxpr_invars = [False] * len(all_values) + [
|
||||
v in flat_group_vars for v in jaxpr_out.invars
|
||||
]
|
||||
jaxpr_out_for_group, used_consts, _ = pe.dce_jaxpr_consts(
|
||||
jaxpr_out, downstream_used_mask, instantiate=used_jaxpr_invars
|
||||
)
|
||||
values_for_jaxpr = tuple(
|
||||
c for used, c in zip(used_consts, all_values, strict=True) if used
|
||||
)
|
||||
|
||||
if (
|
||||
not jaxpr_out_for_group.eqns
|
||||
and jaxpr_out_for_group.outvars == jaxpr_out_for_group.invars
|
||||
):
|
||||
output_fusions.append(None)
|
||||
continue
|
||||
|
||||
def _fn(jaxpr, vals, *args, **kwargs):
|
||||
flat_args, _ = tree_util.tree_flatten((args, kwargs))
|
||||
out_flat = jax_core.eval_jaxpr(jaxpr, vals, *flat_args)
|
||||
return tuple(out_flat)
|
||||
|
||||
fn = functools.partial(_fn, jaxpr_out_for_group, values_for_jaxpr)
|
||||
in_type = jax.tree.map(lambda x: x.aval, outvars_group)
|
||||
out_type = tuple(v.aval for v in jaxpr_out_for_group.outvars)
|
||||
fusion = fusion_lib.Fusion(
|
||||
fn,
|
||||
(in_type, {}),
|
||||
out_type,
|
||||
)
|
||||
output_fusions.append(fusion)
|
||||
|
||||
return (
|
||||
tree_util.tree_unflatten(
|
||||
tree_util.tree_structure(output_fusion_prefix), output_fusions
|
||||
),
|
||||
output_permutation,
|
||||
)
|
||||
|
||||
|
||||
def fuse_jaxpr(
|
||||
jaxpr: jax_core.Jaxpr, out_tree: tree_util.PyTreeDef, consts, *args
|
||||
):
|
||||
fusion_eqn_index = None
|
||||
|
||||
# Collect input fusions
|
||||
for i, eqn in enumerate(jaxpr.eqns):
|
||||
if eqn.primitive is fusible_p:
|
||||
fusion_eqn_index = i
|
||||
break
|
||||
if fusion_eqn_index is None:
|
||||
raise ValueError("No fusible eqn found")
|
||||
fusion_eqn = jaxpr.eqns[fusion_eqn_index]
|
||||
|
||||
# Now let's check if we need to do any fusion at all, e.g. do the outputs of
|
||||
# the jaxpr have any dependence on the fusion at all?
|
||||
candidate_values = [*consts, *args]
|
||||
independent_jaxpr, _, out_used, *_ = pe.partial_eval_jaxpr_custom(
|
||||
jaxpr.replace(
|
||||
eqns=(jaxpr.eqns[:fusion_eqn_index]
|
||||
+ jaxpr.eqns[fusion_eqn_index + 1 :]),
|
||||
constvars=jaxpr.constvars + jaxpr.invars,
|
||||
invars=fusion_eqn.outvars,
|
||||
debug_info=jaxpr.debug_info.with_unknown_names()),
|
||||
in_unknowns=[True] * len(fusion_eqn.outvars),
|
||||
in_inst=[True] * len(fusion_eqn.outvars),
|
||||
ensure_out_unknowns=False,
|
||||
ensure_out_inst=False,
|
||||
saveable=lambda *_, **__: False)
|
||||
if not any(out_used):
|
||||
# Short circuit if there is no need to run the fusible at all.
|
||||
return jax_core.eval_jaxpr(independent_jaxpr, candidate_values)
|
||||
|
||||
# Construct fusions for non-constant inputs to the fusible.
|
||||
in_fusions_flat = [
|
||||
construct_input_fusion(
|
||||
candidate_values,
|
||||
jaxpr.replace(
|
||||
eqns=jaxpr.eqns[:fusion_eqn_index],
|
||||
),
|
||||
var,
|
||||
)
|
||||
for var in fusion_eqn.invars[fusion_eqn.params["num_consts"] :]
|
||||
]
|
||||
in_fusions = tree_util.tree_unflatten(
|
||||
fusion_eqn.params["in_tree"], in_fusions_flat
|
||||
)
|
||||
output_fusions, output_permutation = _construct_output_fusions(
|
||||
candidate_values,
|
||||
jaxpr,
|
||||
out_tree,
|
||||
fusion_eqn_index,
|
||||
fusion_eqn.outvars,
|
||||
fusion_eqn.params["out_tree"],
|
||||
fusion_eqn.params["output_fusion_prefix"],
|
||||
)
|
||||
out = fusion_eqn.params["func"](*in_fusions, output_fusions)
|
||||
flat_out = jax.tree.leaves(out)
|
||||
permuted_out = [flat_out[i] for i in output_permutation]
|
||||
assert len(permuted_out) == len(jaxpr.outvars), (
|
||||
len(permuted_out),
|
||||
len(jaxpr.outvars),
|
||||
)
|
||||
return permuted_out
|
||||
Reference in New Issue
Block a user