@@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
BIN: Binary file not shown.
BIN: Binary file not shown.
BIN: Binary file not shown.
BIN: Binary file not shown.
BIN: Binary file not shown.
@@ -0,0 +1,32 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains Triton-specific Pallas abstractions."""

from __future__ import annotations

import dataclasses


@dataclasses.dataclass(frozen=True)
class CompilerParams:
  """Compiler parameters for Triton.

  Attributes:
    num_warps: The number of warps to use for the kernel. Each warp consists
      of 32 threads.
    num_stages: The number of stages the compiler should use for software
      pipelining loops.
  """
  num_warps: int | None = None
  num_stages: int | None = None
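
For orientation, a minimal sketch of how these parameters are consumed, assuming CompilerParams is re-exported as jax.experimental.pallas.triton.CompilerParams and forwarded via pallas_call's compiler_params argument (both assumptions; neither is shown in this diff):

# Hedged sketch: the re-export path below is assumed; only CompilerParams
# itself is defined in this commit.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(256, dtype=jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    compiler_params=plgpu.CompilerParams(num_warps=8, num_stages=2),
)(x)
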
File diff suppressed because it is too large
@@ -0,0 +1,209 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module registering a lowering rule for pallas_call on GPU."""

from __future__ import annotations

import io
import json
import zlib

import jax
from jax._src import frozen_dict
import jax._src.core as jax_core
from jax._src.interpreters import mlir
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib import triton
from jax._src.lib.mlir import ir
from jax._src.pallas import core as pallas_core
from jax._src.pallas.triton import core as triton_core
from jax._src.pallas.triton import lowering


def normalize_grid(grid: pallas_core.StaticGrid) -> tuple[int, int, int]:
  if isinstance(grid, int):
    grid = (grid,)
  elif len(grid) > 3:
    raise ValueError("`grid` should have three or fewer dimensions.")
  return tuple(grid) + (1,) * (3 - len(grid))  # pyrefly: ignore[bad-return]
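
As a quick illustration of the padding behavior (plain Python, no assumptions beyond the function above):

assert normalize_grid(8) == (8, 1, 1)
assert normalize_grid((4, 2)) == (4, 2, 1)
# normalize_grid((1, 2, 3, 4)) raises ValueError: more than three dimensions.
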
def avals_to_layouts(avals):
  return [list(reversed(range(aval.ndim))) for aval in avals]


def pallas_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *in_nodes,
    jaxpr: jax_core.Jaxpr,
    interpret: bool,
    debug: bool,
    input_output_aliases: tuple[tuple[int, int], ...],
    grid_mapping: pallas_core.GridMapping,
    mesh: pallas_core.Mesh | None,
    compiler_params: pallas_core.CompilerParams | None,
    cost_estimate: pallas_core.CostEstimate | None,
    out_avals: tuple[jax_core.AbstractValue, ...],
    metadata: frozen_dict.FrozenDict[str, str] | None,
    name: str | None,
):
  del interpret, out_avals, cost_estimate, name
  debug_info = jaxpr.debug_info
  if grid_mapping.num_dynamic_grid_bounds:
    raise NotImplementedError(
        "dynamic grid bounds not supported in the Triton backend"
    )
  if grid_mapping.num_index_operands:
    raise NotImplementedError(
        "scalar prefetch not implemented in the Triton backend"
    )
  if mesh is not None:
    raise NotImplementedError("mesh is not supported in the Triton backend")

  [lowering_platform] = ctx.platforms or ctx.module_context.platforms

  if compiler_params is None:
    triton_params = triton_core.CompilerParams()
  else:
    assert isinstance(compiler_params, triton_core.CompilerParams)
    triton_params = compiler_params

  num_warps = 4 if triton_params.num_warps is None else triton_params.num_warps
  num_stages = triton_params.num_stages
  if num_stages is None:
    num_stages = 1 if lowering_platform == "rocm" else 3

  if debug:
    print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
    print(jaxpr)
    print(f"The grid mapping for pallas_call {debug_info.func_src_info}:")
    print(grid_mapping)

  try:
    gpu_device, *_ = jax.local_devices(backend="gpu")
  except RuntimeError:
    # GPU device is not available. Fall back to the minimum CC supported by
    # Triton.
    # TODO(slebedev): Make the fallback CC configurable.
    arch_name = "8.0"
    compute_capability = 80
  else:
    arch_name = str(gpu_device.compute_capability)
    if lowering_platform == "rocm":
      compute_capability = 0
    else:
      compute_capability = int(arch_name.replace(".", ""))

  # Sanitize the name to conform to NVPTX requirements. We do this here
  # to avoid the need to fetch the new name from PTX post compilation.
  name = mlir.sanitize_name(debug_info.func_name)
  lowering_result = lowering.lower_jaxpr_to_triton_module(
      jaxpr, grid_mapping, lowering_platform, compute_capability or None
  )
  module_op = lowering_result.module.operation
  if debug:
    print(f"\nThe Triton module for pallas_call {debug_info.func_src_info}:")
    print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))

  grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
  buf = io.BytesIO()
  module_op.write_bytecode(buf)

  serialized_metadata = None
  if metadata is not None:
    serialized_metadata = json.dumps(dict(metadata))

  # TODO(b/394629193): Remove True once the bug is fixed.
  if True:
    # AOT Triton compilation is only available on jaxlib 0.5.1+.
    out_types = [
        ir.RankedTensorType.get(
            bm.array_aval.shape, mlir.dtype_to_ir_type(bm.array_aval.dtype)
        )
        for bm in grid_mapping.block_mappings_output
    ]
    backend_config = dict(
        name=ir.StringAttr.get(name),
        ir=ir.StringAttr.get(buf.getvalue()),
        num_stages=mlir.i32_attr(num_stages),
        num_warps=mlir.i32_attr(num_warps),
        grid_x=mlir.i32_attr(grid_x),
        grid_y=mlir.i32_attr(grid_y),
        grid_z=mlir.i32_attr(grid_z),
        debug=ir.BoolAttr.get(debug),
    )
    if serialized_metadata is not None:
      # This field is unstable and may be removed in the future.
      backend_config["serialized_metadata"] = ir.StringAttr.get(
          serialized_metadata
      )
    return mlir.custom_call(
        call_target_name="__gpu$xla.gpu.triton",
        result_types=out_types,
        operands=in_nodes,
        backend_config=backend_config,
        api_version=4,
        operand_layouts=avals_to_layouts(ctx.avals_in),
        result_layouts=avals_to_layouts(ctx.avals_out),
        operand_output_aliases=dict(input_output_aliases),
    ).results

  compilation_result = triton.compile(
      lowering_platform,
      buf.getvalue(),
      arch_name,
      num_warps=num_warps,
      num_ctas=1,
      num_stages=num_stages,
  )
  kernel = triton_kernel_call_lib.TritonKernel(
      debug_info.func_name,
      num_warps,
      1,
      compilation_result.smem_bytes,
      (
          compilation_result.hsaco_path
          if lowering_platform == "rocm"
          else compilation_result.asm
      ),
      module_op.get_asm(enable_debug_info=True, pretty_debug_info=True),
      compute_capability,
  )
  kernel_call = triton_kernel_call_lib.TritonKernelCall(
      kernel,
      grid_x,
      grid_y,
      grid_z,
      [triton_kernel_call_lib.create_array_parameter(0, 16)]
      * (len(ctx.avals_in) + len(ctx.avals_out)),
  )
  # TODO(b/392558289): Migrate to ``jax.ffi``.
  return mlir.custom_call(
      call_target_name="triton_kernel_call",
      result_types=mlir.flatten_ir_types(
          map(mlir.aval_to_ir_type, ctx.avals_out)
      ),
      operands=in_nodes,
      backend_config=zlib.compress(
          kernel_call.to_proto(
              debug_info.func_name,
              (serialized_metadata or "").encode(),
          )
      ),
      operand_layouts=avals_to_layouts(ctx.avals_in),
      result_layouts=avals_to_layouts(ctx.avals_out),
      operand_output_aliases=dict(input_output_aliases),
  ).results


pallas_core.register_lowering_rule(
    triton_core.CompilerParams, pallas_call_lowering, "gpu"
)
@@ -0,0 +1,705 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for GPU-specific JAX primitives."""

from __future__ import annotations

from collections.abc import Sequence
import enum
from typing import Any, TypeAlias

import jax
from jax._src import core as jax_core
from jax._src import effects
from jax._src import lax
from jax._src import state
from jax._src import tree_util
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith as arith_dialect
from jax._src.lib.mlir.dialects import gpu as gpu_dialect
from jax._src.lib.triton import dialect as tt_dialect
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives as pallas_primitives
from jax._src.pallas.triton import lowering
from jax._src.state import discharge as state_discharge
from jax._src.state import indexing
from jax._src.state import primitives as state_primitives
from jax.interpreters import mlir
import jax.numpy as jnp
import numpy as np


Ref: TypeAlias = state.AbstractRef | state.TransformedRef

Slice = indexing.Slice


def approx_tanh(x: jax.Array) -> jax.Array:
  r"""Elementwise approximate hyperbolic tangent: :math:`\mathrm{tanh}(x)`.

  See
  https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-tanh.
  """
  if x.dtype == jnp.float16:
    asm = "tanh.approx.f16 $0, $1;"
    constraint = "h"
  elif x.dtype == jnp.bfloat16:
    asm = "tanh.approx.bf16 $0, $1;"
    constraint = "h"
  elif x.dtype == jnp.float32:
    asm = "tanh.approx.f32 $0, $1;"
    constraint = "f"
  elif x.dtype == jnp.float64:
    # f64 tanh.approx is only supported on ROCm (uses __ocml_tanh_f64);
    # CUDA does not have a PTX instruction for f64 approximate tanh.
    asm = "tanh.approx.f64 $0, $1;"
    constraint = "d"
  else:
    raise TypeError(f"approx_tanh does not accept {x.dtype} arrays")

  [result] = elementwise_inline_asm(
      asm,
      args=[x],
      constraints=f"={constraint},{constraint}",
      pack=1,
      result_shape_dtypes=[jax.ShapeDtypeStruct(x.shape, x.dtype)],
  )
  return result
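
A minimal sketch of calling approx_tanh from inside a Pallas kernel, assuming it is exported as jax.experimental.pallas.triton.approx_tanh (the kernel and shapes are illustrative):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

def tanh_kernel(x_ref, o_ref):
  # Uses the hardware tanh approximation instead of the exact lax.tanh.
  o_ref[...] = plgpu.approx_tanh(x_ref[...])

x = jnp.linspace(-3.0, 3.0, 128, dtype=jnp.float32)
y = pl.pallas_call(
    tanh_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x)
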
def elementwise_inline_asm(
    asm: str,
    *,
    args: Sequence[jax.Array],
    constraints: str,
    pack: int,
    result_shape_dtypes: Sequence[jax.ShapeDtypeStruct],
) -> Sequence[jax.Array]:
  """Inline assembly applying an elementwise operation.

  Args:
    asm: The assembly code to run.
    args: The arguments to pass to the assembly code.
    constraints: LLVM inline assembly `constraints
      <https://llvm.org/docs/LangRef.html#inline-asm-constraint-string>`_.
    pack: The number of elements from each argument expected by a single
      instance of the assembly code.
    result_shape_dtypes: The shapes and dtypes of the results produced by the
      assembly code.

  Returns:
    The results produced by the assembly code.
  """
  return elementwise_inline_asm_p.bind(
      *args,
      asm=asm,
      constraints=constraints,
      pack=pack,
      result_shape_dtypes=tuple(result_shape_dtypes),
  )
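
For example, a hedged sketch wrapping the PTX ex2.approx.f32 instruction with this helper inside a Pallas kernel (the export path plgpu.elementwise_inline_asm is an assumption):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

def exp2_kernel(x_ref, o_ref):
  # One f32 register in ("f"), one f32 register out ("=f"), one element per
  # asm instance (pack=1), mirroring the approx_tanh f32 branch above.
  [o_ref[...]] = plgpu.elementwise_inline_asm(
      "ex2.approx.f32 $0, $1;",
      args=[x_ref[...]],
      constraints="=f,f",
      pack=1,
      result_shape_dtypes=[jax.ShapeDtypeStruct(x_ref.shape, x_ref.dtype)],
  )

x = jnp.arange(8, dtype=jnp.float32)
y = pl.pallas_call(
    exp2_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x)  # y is approximately 2**x
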
elementwise_inline_asm_p = jax_core.Primitive("elementwise_inline_asm_p")
elementwise_inline_asm_p.multiple_results = True


@elementwise_inline_asm_p.def_abstract_eval
def _elementwise_inline_asm_abstract_eval(
    *avals: jax_core.ShapedArray, result_shape_dtypes, **kwargs
) -> Sequence[jax_core.ShapedArray]:
  del kwargs  # Unused.
  if not all(x.shape == y.shape for x, y in zip(avals, avals[1:])):
    raise ValueError(
        "All arguments of elementwise_inline_asm must have the same shape"
    )
  return [jax_core.ShapedArray(s.shape, s.dtype) for s in result_shape_dtypes]


@lowering.register_lowering(elementwise_inline_asm_p)
def _elementwise_inline_asm_lowering(
    ctx: lowering.LoweringRuleContext,
    *args,
    asm,
    constraints,
    pack,
    result_shape_dtypes,
):
  del result_shape_dtypes  # Unused.

  if "tanh.approx" in asm:
    if ctx.context.platform == "rocm":
      return _approx_tanh_rocm_lowering(ctx, *args)
    if ctx.avals_in[0].dtype == jnp.float64:
      raise TypeError(
          "approx_tanh does not support float64 on CUDA; it is only"
          " supported on ROCm"
      )

  return tt_dialect.ElementwiseInlineAsmOp(
      [*map(mlir.aval_to_ir_type, ctx.avals_out)],
      asm,
      constraints=constraints,
      pure=True,
      packed_element=pack,
      args=args,
  ).result


def _approx_tanh_rocm_lowering(
    ctx: lowering.LoweringRuleContext,
    *args,
):
  """Lower approx_tanh for ROCm.

  AMD CDNA3 (MI300X/gfx942) does not have a hardware tanh instruction.
  See: https://github.com/triton-lang/triton/pull/7780
  """
  [arg] = args
  [out_aval] = ctx.avals_out
  in_dtype = ctx.avals_in[0].dtype

  if in_dtype == jnp.float64:
    result_type = mlir.aval_to_ir_type(out_aval)
    result = tt_dialect.extern_elementwise(
        result_type,
        list(args),
        libname="",
        libpath="",
        symbol="__ocml_tanh_f64",
        pure=True,
    )
    return [result]

  needs_cast = in_dtype in (jnp.float16, jnp.bfloat16)

  if needs_cast:
    f32_type = mlir.dtype_to_ir_type(jnp.dtype(jnp.float32))
    if out_aval.shape:
      result_type = ir.RankedTensorType.get(out_aval.shape, f32_type)
    else:
      result_type = f32_type
    arg = arith_dialect.extf(result_type, arg)
  else:
    result_type = mlir.aval_to_ir_type(out_aval)
  result = tt_dialect.extern_elementwise(
      result_type,
      [arg],
      libname="libdevice",
      libpath="",
      symbol="__triton_hip_fast_tanhf",
      pure=True,
  )

  if needs_cast:
    out_type = mlir.aval_to_ir_type(out_aval)
    result = arith_dialect.truncf(out_type, result)

  return [result]


def debug_barrier() -> None:
  """Synchronizes all kernel executions in the grid."""
  return debug_barrier_p.bind()


class BarrierEffect(jax_core.Effect):
  pass

barrier_effect = BarrierEffect()

pallas_core.kernel_local_effects.add_type(BarrierEffect)
effects.control_flow_allowed_effects.add_type(BarrierEffect)


debug_barrier_p = jax_core.Primitive("debug_barrier_p")
debug_barrier_p.multiple_results = True


@debug_barrier_p.def_effectful_abstract_eval
def _debug_barrier_abstract_eval():
  return (), {barrier_effect}


@lowering.register_lowering(debug_barrier_p)
def _debug_barrier_lowering(ctx: lowering.LoweringRuleContext):
  del ctx  # Unused.
  gpu_dialect.barrier()
  return []


def load(
    ref: Ref,
    *,
    mask: jax.Array | None = None,
    other: jax.typing.ArrayLike | None = None,
    cache_modifier: str | None = None,
    eviction_policy: str | None = None,
    volatile: bool = False,
) -> jax.Array:
  """Loads an array from the given ref.

  If neither ``mask`` nor ``other`` is specified, this function has the same
  semantics as ``ref[idx]`` in JAX.

  Args:
    ref: The ref to load from.
    mask: An optional boolean mask specifying which indices to load. If mask
      is ``False`` and ``other`` is not given, no assumptions can be made
      about the value in the resulting array.
    other: An optional value to use for indices where mask is ``False``.
    cache_modifier: TO BE DOCUMENTED.
    eviction_policy: TO BE DOCUMENTED.
    volatile: TO BE DOCUMENTED.
  """
  return pallas_primitives.load(
      ref,
      None,
      mask=mask,
      other=other,
      cache_modifier=cache_modifier,
      eviction_policy=eviction_policy,
      volatile=volatile,
  )


def store(
    ref: Ref,
    val: jax.Array,
    *,
    mask: jax.Array | None = None,
    eviction_policy: str | None = None,
) -> None:
  """Stores a value to the given ref.

  See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.
  """
  return pallas_primitives.store(
      ref,
      None,
      val,
      mask=mask,
      eviction_policy=eviction_policy,
  )
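
A hedged sketch of a masked load/store pair inside a kernel, assuming these wrappers are exported from jax.experimental.pallas.triton; the mask keeps the masked-off lanes untouched:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

def relu_head_kernel(x_ref, o_ref):
  # Only touch the first half of the block; masked-off input lanes read as
  # 0.0 (``other``), and masked-off output lanes are never written.
  mask = jnp.arange(x_ref.shape[0]) < x_ref.shape[0] // 2
  x = plgpu.load(x_ref, mask=mask, other=0.0)
  plgpu.store(o_ref, jnp.maximum(x, 0.0), mask=mask)

x = jnp.linspace(-1.0, 1.0, 16, dtype=jnp.float32)
y = pl.pallas_call(
    relu_head_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
)(x)
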
class AtomicOpType(enum.Enum):
  XCHG = "xchg"
  ADD = "add"
  MAX = "max"
  MIN = "min"
  AND = "and"
  OR = "or"
  XOR = "xor"


atomic_rmw_p = jax_core.Primitive("atomic_rmw")


def _atomic_rmw_discharge_rule(
    in_avals, out_avals, *args_flat, args_tree, atomic_type: AtomicOpType
):
  del out_avals  # Unused.
  ref, transforms, val, mask = args_tree.unflatten(args_flat)
  *prev_transforms, idx = transforms
  ref = state_discharge.transform_array(ref, prev_transforms)

  if mask is not None:
    raise NotImplementedError

  if atomic_type == AtomicOpType.ADD:
    monoid = lambda x, y: x + y
  elif atomic_type == AtomicOpType.MAX:
    monoid = jnp.maximum
  elif atomic_type == AtomicOpType.MIN:
    monoid = jnp.minimum
  else:
    raise NotImplementedError(atomic_type)

  if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
    indices = idx.indices
    scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
    val_indexer = tuple(
        None if scalar else slice(None) for scalar in scalar_dims
    )
    val = val[val_indexer]
    val = monoid(val, out_ones)
    x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
    out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
    out = out_ones[out_indexer]
  elif all(not isinstance(s, Slice) for s in idx.indices):
    out = ref[idx.indices]
    x_new = ref.at[idx.indices].set(monoid(out, val))
  else:
    raise NotImplementedError
  return (x_new,) + (None,) * (len(in_avals) - 1), out


state_discharge.register_discharge_rule(atomic_rmw_p)(
    _atomic_rmw_discharge_rule
)


@atomic_rmw_p.def_effectful_abstract_eval
def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType):
  ref, _, _, _ = args_tree.unflatten(avals_flat)
  if ref.dtype == jnp.dtype("float16") and atomic_type != AtomicOpType.ADD:
    raise ValueError(f"`atomic_{atomic_type.value}` does not support f16.")
  if ref.dtype in {
      jnp.dtype("bool"),
      jnp.dtype("int8"),
      jnp.dtype("int16"),
      jnp.bfloat16,
  }:
    raise ValueError(
        f"`atomic_{atomic_type.value}` does not support {ref.dtype}."
    )
  return pallas_primitives._swap_abstract_eval(*avals_flat, args_tree=args_tree)


def _atomic_rmw(
    x_ref_or_view,
    idx,
    val,
    *,
    mask: Any | None = None,
    atomic_type: AtomicOpType,
):
  x_ref, transforms = state_primitives.get_ref_and_transforms(
      x_ref_or_view, idx, "atomic_rmw"
  )
  args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask))
  return atomic_rmw_p.bind(
      *args_flat, args_tree=args_tree, atomic_type=atomic_type
  )


def _expand_atomic_fp_min_max(
    atomic_type: AtomicOpType,
    ptr: ir.Value,
    val: ir.Value,
    mask: ir.Value | None = None,
    semantic: tt_dialect.MemSemantic = tt_dialect.MemSemantic.ACQUIRE_RELEASE,
    sync_scope: tt_dialect.MemSyncScope = tt_dialect.MemSyncScope.GPU,
) -> ir.Value:
  """Expands floating-point atomic min/max into a pair of integer atomics.

  Relies on the IEEE-754 ordering of the bit patterns; does not handle NaNs.

  min:
    return atomic_smin(i_ptr, i_val) if i_val >= 0 else atomic_umax(i_ptr, i_val)
  max:
    return atomic_smax(i_ptr, i_val) if i_val >= 0 else atomic_umin(i_ptr, i_val)
  """
  if isinstance(ptr.type, ir.RankedTensorType):
    ptr_type = ir.RankedTensorType(ptr.type)
    element_type = tt_dialect.PointerType(ptr_type.element_type)
    result_type = ir.RankedTensorType.get(
        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
    )
  else:
    result_type = tt_dialect.PointerType(ptr.type).pointee_type

  ptr_cast = tt_dialect.bitcast(lowering._fp_bits_type(ptr.type), ptr)
  val_cast = tt_dialect.bitcast(lowering._fp_bits_type(val.type), val)

  zero = lowering._full(val_cast.type, 0)
  pos_cmp = lowering._greater_equal(val_cast, zero, signed=True)
  neg_cmp = lowering._less_than(val_cast, zero, signed=True)

  pos_mask = pos_cmp if mask is None else arith_dialect.andi(mask, pos_cmp)
  neg_mask = neg_cmp if mask is None else arith_dialect.andi(mask, neg_cmp)

  pos_op, neg_op = (
      (tt_dialect.RMWOp.MAX, tt_dialect.RMWOp.UMIN)
      if atomic_type == AtomicOpType.MAX
      else (tt_dialect.RMWOp.MIN, tt_dialect.RMWOp.UMAX)
  )

  pos_val = lowering._atomic_rmw(
      pos_op, ptr_cast, val_cast, mask=pos_mask, semantic=semantic,
      sync_scope=sync_scope,
  )
  neg_val = lowering._atomic_rmw(
      neg_op, ptr_cast, val_cast, mask=neg_mask, semantic=semantic,
      sync_scope=sync_scope,
  )
  result = arith_dialect.select(pos_cmp, pos_val, neg_val)
  return tt_dialect.bitcast(result_type, result)


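A small NumPy sketch of why the bit trick works (illustration only, not part of the lowering): for non-negative floats the IEEE-754 bit patterns order like signed ints, while for negative floats the unsigned order is reversed, which is why the negative branch uses the unsigned op of the opposite direction.

import numpy as np

def f32_sbits(x: float) -> int:
  return int(np.float32(x).view(np.int32))

def f32_ubits(x: float) -> int:
  return int(np.float32(x).view(np.uint32))

# Non-negative floats: signed bit patterns order like the floats themselves.
assert f32_sbits(1.0) < f32_sbits(2.0)
# Negative floats: unsigned bit patterns order in reverse, so a umin/umax of
# the opposite direction is used for the negative branch.
assert f32_ubits(-2.0) > f32_ubits(-1.0)
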
@lowering.register_lowering(atomic_rmw_p)
def _atomic_lowering_rule(
    ctx: lowering.LoweringRuleContext,
    *args_flat,
    args_tree,
    atomic_type: AtomicOpType,
):
  block_info, *_ = ctx.block_infos
  assert block_info is not None
  ptr, indexers, val, mask = args_tree.unflatten(args_flat)
  *_, value_aval, mask_aval = args_tree.unflatten(ctx.avals_in)
  indexers = list(indexers)
  if not indexers or not isinstance(indexers[-1], indexing.NDIndexer):
    ref_aval = state.transform_type(indexers, ctx.avals_in[0])
    assert isinstance(ref_aval, state.AbstractRef)
    indexers.append(indexing.NDIndexer.make_trivial_indexer(ref_aval.shape))
  if len(indexers) != 1:
    raise NotImplementedError("Only single indexer is supported.")
  idx = indexers[0]
  ptr = lowering._compute_pointers_from_indices(ptr, block_info, idx)
  val = lowering._ensure_ir_value(val, value_aval)
  if mask is not None:
    mask = lowering._ensure_ir_value(mask, mask_aval)
  if atomic_type == AtomicOpType.XCHG:
    op = tt_dialect.RMWOp.XCHG
  elif atomic_type == AtomicOpType.ADD:
    if isinstance(val.type, ir.IntegerType):
      op = tt_dialect.RMWOp.ADD
    else:
      op = tt_dialect.RMWOp.FADD
  elif atomic_type == AtomicOpType.MIN:
    if isinstance(val.type, ir.IntegerType):
      op = (
          tt_dialect.RMWOp.MIN
          if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
          else tt_dialect.RMWOp.UMIN
      )
    else:
      return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
  elif atomic_type == AtomicOpType.MAX:
    if isinstance(val.type, ir.IntegerType):
      op = (
          tt_dialect.RMWOp.MAX
          if jnp.issubdtype(value_aval.dtype, jnp.signedinteger)
          else tt_dialect.RMWOp.UMAX
      )
    else:
      return _expand_atomic_fp_min_max(atomic_type, ptr, val, mask=mask)
  elif atomic_type == AtomicOpType.AND:
    op = tt_dialect.RMWOp.AND
  elif atomic_type == AtomicOpType.OR:
    op = tt_dialect.RMWOp.OR
  elif atomic_type == AtomicOpType.XOR:
    op = tt_dialect.RMWOp.XOR
  else:
    raise NotImplementedError(f"unsupported atomic operation: {atomic_type}")
  return lowering._atomic_rmw(op, ptr, val, mask=mask)


def atomic_xchg(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically exchanges the given value with the value at the given index.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the update.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XCHG
  )


def atomic_add(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] += val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.ADD
  )
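
A hedged sketch of atomic_add used to build a histogram, where several grid programs may increment the same bin concurrently (kernel body only; the export path plgpu.atomic_add and the pallas_call wiring, including aliasing the output with a zero-initialized input, are assumptions not shown in this diff):

import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

def hist_kernel(bins_ref, hist_ref):
  # Each grid program owns one element; different programs may hit the same
  # bin, so the increment must be an atomic read-modify-write.
  i = pl.program_id(0)
  plgpu.atomic_add(hist_ref, (bins_ref[i],), 1)
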

def atomic_max(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] = max(x_ref_or_view[idx], val)``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MAX
  )


def atomic_min(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] = min(x_ref_or_view[idx], val)``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.MIN
  )


def atomic_and(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] &= val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.AND
  )


def atomic_or(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] |= val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.OR
  )


def atomic_xor(x_ref_or_view, idx, val, *, mask: Any | None = None):
  """Atomically computes ``x_ref_or_view[idx] ^= val``.

  Args:
    x_ref_or_view: The ref to operate on.
    idx: The indexer to use.
    mask: TO BE DOCUMENTED.

  Returns:
    The value at the given index prior to the atomic operation.
  """
  return _atomic_rmw(
      x_ref_or_view, idx, val, mask=mask, atomic_type=AtomicOpType.XOR
  )


atomic_cas_p = jax_core.Primitive("atomic_cas")


@atomic_cas_p.def_effectful_abstract_eval
def _atomic_cas_abstract_eval(ref_aval, cmp_aval, val_aval):
  if cmp_aval.dtype != val_aval.dtype or cmp_aval.shape != val_aval.shape:
    raise ValueError("cmp and val must have identical dtypes and shapes")
  if ref_aval.shape:
    raise ValueError("ref must be scalar.")
  if cmp_aval.shape:
    raise ValueError("cmp must be scalar.")
  if val_aval.shape:
    raise ValueError("val must be scalar.")
  return jax_core.ShapedArray(val_aval.shape, val_aval.dtype), {
      state.WriteEffect(0)
  }


def atomic_cas(ref, cmp, val):
  """Performs an atomic compare-and-swap on the value in the ref.

  ``val`` is written only if the current value equals ``cmp``.

  Args:
    ref: The ref to operate on.
    cmp: The expected value to compare against.
    val: The value to swap in.

  Returns:
    The value in ``ref`` prior to the atomic operation.
  """
  return atomic_cas_p.bind(ref, cmp, val)
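
A hedged sketch of claim-once semantics built on atomic_cas, matching the discharge rule below (the export path plgpu.atomic_cas and a scalar i32 flag ref visible to all programs are assumptions):

import jax.numpy as jnp
from jax.experimental.pallas import triton as plgpu

def claim_kernel(flag_ref, o_ref):
  # Exactly one program observes 0 and swaps in 1; every other program sees
  # the already-installed 1 and takes the losing branch.
  prev = plgpu.atomic_cas(flag_ref, 0, 1)
  o_ref[...] = jnp.where(prev == 0, 1.0, 0.0)
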

@state_discharge.register_discharge_rule(atomic_cas_p)
def _atomic_cas_discharge_rule(in_avals, out_avals, ref, cmp, val):
  del in_avals, out_avals
  new_val = jnp.where(ref == cmp, val, ref)
  return (new_val, None, None), ref


@lowering.register_lowering(atomic_cas_p)
def _atomic_cas_lowering_rule(ctx: lowering.LoweringRuleContext, ptr, cmp, val):
  _, cmp_aval, val_aval = ctx.avals_in
  if isinstance(ptr.type, ir.RankedTensorType):
    ptr_type = ir.RankedTensorType(ptr.type)
    element_type = tt_dialect.PointerType(ptr_type.element_type)
    result_type = ir.RankedTensorType.get(
        ptr_type.shape, element_type.pointee_type, ptr_type.encoding
    )
  else:
    result_type = tt_dialect.PointerType(ptr.type).pointee_type
  return tt_dialect.atomic_cas(
      result_type,
      ptr,
      lowering._ensure_ir_value(cmp, cmp_aval),
      lowering._ensure_ir_value(val, val_aval),
      sem=tt_dialect.MemSemantic.ACQUIRE_RELEASE,
      scope=tt_dialect.MemSyncScope.GPU,
  )


max_contiguous_p = jax_core.Primitive("max_contiguous")

max_contiguous_p.def_impl(lambda x, **_: x)
mlir.register_lowering(max_contiguous_p, lambda _, x, **__: [x])


def max_contiguous(x, values):
  """A compiler hint asserting that, per dimension, the first ``values``
  elements of ``x`` are contiguous."""
  if not isinstance(values, (list, tuple)):
    values = (values,)
  return max_contiguous_p.bind(x, values=tuple(values))


@max_contiguous_p.def_abstract_eval
def _max_contiguous_abstract_eval(aval, **_):
  return aval


@lowering.register_lowering(max_contiguous_p)
def _max_contiguous_rule(
    ctx: lowering.LoweringRuleContext, x, values: Sequence[int]
):
  [x_aval] = ctx.avals_in
  assert len(x_aval.shape) == len(values)
  lowering._set_attr(
      x,
      "tt.contiguity",
      ir.DenseIntElementsAttr.get(np.asarray(values, dtype=np.int32)),  # pyrefly: ignore[no-matching-overload]
  )
  return x
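
A hedged sketch of the hint in use, assuming it is exported as plgpu.max_contiguous; per the implementation above it is purely an assertion attached to the value (the Triton-level tt.contiguity attribute), never checked at runtime, and it can enable wider vectorized loads:

from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

def gather_kernel(idx_ref, x_ref, o_ref):
  offsets = pl.load(idx_ref, (pl.dslice(0, 128),))
  # Promise the compiler that each group of 16 consecutive offsets is
  # contiguous, so the gather below can be vectorized.
  offsets = plgpu.max_contiguous(offsets, 16)
  o_ref[...] = pl.load(x_ref, (offsets,))
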