hand
This commit is contained in:
@@ -0,0 +1,87 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import core as jax_core
|
||||
from jax._src import tree_util
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
|
||||
def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs):
|
||||
del unused_kwargs
|
||||
return jax_core.jaxpr_as_fun(jaxpr)(*args)
|
||||
|
||||
|
||||
def _custom_abstract_eval(*args, jaxpr, **unused_kwargs):
|
||||
del unused_kwargs
|
||||
del args
|
||||
return jaxpr.out_avals
|
||||
|
||||
|
||||
cudnn_fusion_p = jax_core.Primitive("cudnn_fusion")
|
||||
cudnn_fusion_p.multiple_results = True
|
||||
cudnn_fusion_p.def_abstract_eval(_custom_abstract_eval)
|
||||
cudnn_fusion_p.def_impl(_cudnn_fusion_impl)
|
||||
|
||||
|
||||
def call_cudnn_fusion(f, *args, **kwargs):
|
||||
"""Creates a new cudnn_fusion corresponding to calling
|
||||
the given function f with args and kwargs."""
|
||||
jaxpr, out_shapes = api.make_jaxpr(
|
||||
functools.partial(f, **kwargs), return_shape=True
|
||||
)(*args)
|
||||
flat_args = tree_util.tree_leaves(args)
|
||||
out_tree = tree_util.tree_structure(out_shapes)
|
||||
out_flat = cudnn_fusion_p.bind(*flat_args, name=f.__name__, jaxpr=jaxpr)
|
||||
return tree_util.tree_unflatten(out_tree, out_flat)
|
||||
|
||||
|
||||
def _cudnn_fusion_stablehlo_lowering(ctx, *args, name, jaxpr):
|
||||
"""Make cudnn_fusion which calls the implementation function.
|
||||
Currently this leaks a CallOp since we're using the `core_call_lowering`
|
||||
function, but this should get cleaned up by DCE easily.
|
||||
"""
|
||||
impl = mlir.core_call_lowering(
|
||||
ctx, *args, name=name + ".impl", call_jaxpr=jaxpr
|
||||
)
|
||||
call_op = impl[0].owner
|
||||
called_fn = call_op.attributes["callee"]
|
||||
cudnn_fusion = hlo.CustomCallOp(
|
||||
[r.type for r in call_op.results],
|
||||
call_op.operands,
|
||||
call_target_name="__cudnn$fusion",
|
||||
called_computations=ir.ArrayAttr.get([called_fn]),
|
||||
)
|
||||
return cudnn_fusion.results
|
||||
|
||||
|
||||
mlir.register_lowering(
|
||||
cudnn_fusion_p, _cudnn_fusion_stablehlo_lowering, platform="cuda"
|
||||
)
|
||||
|
||||
|
||||
def cudnn_fusion(f):
|
||||
"""Makes a function become a cuDNN kernel. Relies on XLA's handling of
|
||||
custom fusions with __cudnn$fusion backend. Currently limited to GEMM
|
||||
fusions. For example - batch matmul with mixed types and addition:
|
||||
|
||||
@cudnn_fusion
|
||||
def fn(x, y, z):
|
||||
return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
|
||||
"""
|
||||
return functools.partial(call_cudnn_fusion, f)
|
||||
Reference in New Issue
Block a user