2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,228 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All-gather kernel implemented using Mosaic GPU."""
from collections.abc import Hashable
import functools
import itertools
import math
import jax
from jax import lax
from jax.experimental import multihost_utils
from jax.experimental import pallas as pl
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.extend import backend
import jax.numpy as jnp
def all_gather(
x: jax.Array,
*,
axis_name: Hashable,
gather_dimension: int = 0,
num_blocks: int | None = None,
tile_size: int | None = None,
vec_size: int | None = None,
) -> jax.Array:
"""Performs an all-gather operation using multimem instructions.
Args:
x: Input array. Should be sharded across the specified axis.
axis_name: Name of the mesh axis to all-gather across.
gather_dimension: Axis along which to gather.
num_blocks: Number of blocks to use. Defaults to the device core count.
tile_size: Total tile size to split across major, gather, and minor dimensions.
vec_size: Vector size for the layout. If None, automatically inferred from dtype.
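Example (a minimal sketch mirroring ``_run_example`` below; the mesh, the axis
name "x", and the sharding specs are assumptions):
  f = jax.shard_map(
      lambda s: all_gather(s, axis_name="x"),
      mesh=mesh, in_specs=P("x", None), out_specs=P("x", None), check_vma=False,
  )
  y = jax.jit(f)(x)  # x is sharded along the "x" mesh axis.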
"""
num_devices = lax.axis_size(axis_name)
input_shape = x.shape
dtype = x.dtype
ndim = len(input_shape)
if num_blocks is None:
num_blocks = backend.get_default_device().core_count
if gather_dimension < -ndim or gather_dimension >= ndim:
raise ValueError(
f"gather_dimension {gather_dimension} out of bounds for array of rank"
f" {ndim}"
)
if gather_dimension < 0:
gather_dimension += ndim
input_gather_dim = input_shape[gather_dimension]
major_dims = math.prod(input_shape[:gather_dimension])
minor_dims = math.prod(input_shape[gather_dimension+1:])
output_gather_dim = input_gather_dim * num_devices
output_shape = (
*input_shape[:gather_dimension], output_gather_dim, *input_shape[gather_dimension + 1 :],
)
if (output_size := math.prod(output_shape)) % 128:
raise ValueError("Output size must be divisible by 128")
if jnp.issubdtype(dtype, jnp.integer):
if vec_size is None:
vec_size = 1 # Integer types only support unvectorized operations
elif vec_size != 1:
raise ValueError("Integer types only support vec_size=1")
elif vec_size is None: # vec_size inference for floating point types
dtype_bits = jnp.finfo(dtype).bits
max_vec_size = min(128 // dtype_bits, output_size // 128)
if tile_size is not None:
max_vec_size_for_tile = tile_size // 128
max_vec_size = min(max_vec_size, max_vec_size_for_tile)
vec_size = 32 // dtype_bits # We don't support multimem below 32-bit
while vec_size * 2 <= max_vec_size:
vec_size *= 2
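# For example (illustrative): for bf16 (16 bits) vec_size starts at 32 // 16 = 2
# and doubles up to at most 128 // 16 = 8, subject to the output/tile size caps above.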
if math.prod(output_shape) % vec_size:
raise ValueError(
"The total number of elements in the output"
f" ({math.prod(output_shape)}) must be divisible by the vec_size"
f" ({vec_size})"
)
min_transfer_elems = 128 * vec_size
if tile_size is None:
# TODO(apaszke): 8 is just an arbitrary unrolling factor. Tune it!
unroll_factor = min(math.prod(input_shape) // min_transfer_elems, 8)
tile_size = unroll_factor * min_transfer_elems
if tile_size < min_transfer_elems:
raise ValueError(
f"{tile_size=} is smaller than minimum required"
f" {min_transfer_elems} for {vec_size=}"
)
minor_tile = math.gcd(tile_size, minor_dims)
remaining_tile = tile_size // minor_tile
gather_tile = math.gcd(remaining_tile, input_gather_dim)
major_tile = remaining_tile // gather_tile
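# For example (illustrative numbers): tile_size=8192 with minor_dims=4096 gives
# minor_tile=4096 and remaining_tile=2; an input gather dim of 2048 then gives
# gather_tile=2 and major_tile=1.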
if major_dims % major_tile != 0:
raise NotImplementedError(
f"Major dimension size ({major_dims}) must be divisible by the"
f" inferred major tile size ({major_tile}). Consider adjusting tile_size."
)
def kernel(x_ref, y_ref, done_barrier):
dev_idx = lax.axis_index(axis_name)
x_ref_3d = x_ref.reshape((major_dims, input_gather_dim, minor_dims))
y_ref_3d = y_ref.reshape((major_dims, output_gather_dim, minor_dims))
y_ref_3d = y_ref_3d.at[:, pl.ds(dev_idx * input_gather_dim, input_gather_dim), :]
major_tiles = major_dims // major_tile
gather_tiles = input_gather_dim // gather_tile
minor_tiles = minor_dims // minor_tile
# TODO(apaszke): Use a TMA pipeline
@plgpu.nd_loop((major_tiles, gather_tiles, minor_tiles), collective_axes="blocks")
def _transfer_loop(loop_info: plgpu.NDLoopInfo):
major_tile_idx, gather_tile_idx, minor_tile_idx = loop_info.index
idxs = (
pl.ds(major_tile_idx * major_tile, major_tile),
pl.ds(gather_tile_idx * gather_tile, gather_tile),
pl.ds(minor_tile_idx * minor_tile, minor_tile)
)
output_data = plgpu.layout_cast(
x_ref_3d[idxs],
plgpu.Layout.WG_STRIDED((major_tile, gather_tile, minor_tile), vec_size=vec_size)
)
plgpu.multimem_store(output_data, y_ref_3d.at[idxs], axis_name)
# Wait for everyone to finish storing into our memory before returning.
plgpu.semaphore_signal_multicast(done_barrier, collective_axes=axis_name)
pl.semaphore_wait(done_barrier, num_devices, decrement=False)
# TODO(b/448323639): We fake modify the input to ensure that XLA:GPU copies
# the operand into symmetric memory.
@pl.when(dev_idx == -1)
def _never():
x_ref[(0,) * len(x_ref.shape)] = jnp.asarray(0, x_ref.dtype)
return plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct(output_shape, dtype),
grid=(num_blocks,),
grid_names=("blocks",),
scratch_shapes=[plgpu.SemaphoreType.REGULAR],
)(x)
def _run_example():
P = jax.sharding.PartitionSpec
shape = (4 * 4096, 4 * 4096) # This shape is global!
dtype = jnp.bfloat16
shards = jax.device_count()
mesh = jax.make_mesh(
(shards,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
)
jax.set_mesh(mesh)
# We measure time per-shard and so we only need bytes per shard.
local_out_bytes = math.prod(shape) * jnp.dtype(dtype).itemsize
total_bytes = local_out_bytes
a = jax.random.normal(jax.random.key(1), shape, dtype)
a = jax.sharding.reshard(a, P("x", None))
@jax.jit
@functools.partial(jax.shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None))
def ref_fn(x):
return lax.all_gather(x, "x", axis=0, tiled=True)
ref_fn(a).block_until_ready() # Warmup.
_, ref_kernels_ms = profiler.measure(ref_fn, aggregate=False)(a)
assert ref_kernels_ms is not None
ref_time_us = sum(t * 1e3 for _, t in ref_kernels_ms)
# We take the minimum across processes to pick the runtime that didn't
# include devices waiting for other devices.
ref_time_us = min(multihost_utils.process_allgather(ref_time_us).tolist())
ref_bw = total_bytes / (ref_time_us * 1e-6) / 1e9 # GB/s
tuning_it = itertools.product(
(4, 8, 16, 32, 64, 132), # num_blocks
(1024, 2048, 4096, 8192), # tile_size
)
best_bw = 0.0
best_runtime = float("inf")
for num_blocks, tile_size in tuning_it:
@jax.jit
@functools.partial(
jax.shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None), check_vma=False
)
def kernel_fn(x):
return all_gather(x, axis_name="x", gather_dimension=0, num_blocks=num_blocks, tile_size=tile_size)
try:
_, kernels_ms = profiler.measure(kernel_fn, aggregate=False)(a)
except ValueError as e:
if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs.
continue
raise
assert kernels_ms is not None
runtime_us = sum(t * 1e3 for _, t in kernels_ms)
runtime_us = min(multihost_utils.process_allgather(runtime_us).tolist())
achieved_bw = total_bytes / (runtime_us * 1e-6) / 1e9 # GB/s
if achieved_bw > best_bw:
best_runtime = runtime_us
best_bw = achieved_bw
print(f"{num_blocks=}, {tile_size=}: {runtime_us:<7.1f}us = {achieved_bw:4.1f} GB/s")
print(f"Total bytes transferred: {total_bytes / 1e9:.2f} GB")
print(f"\tBest: {best_runtime:<7.1f}us = {best_bw:4.1f} GB/s")
print(f"\tRef: {ref_time_us:<7.1f}us = {ref_bw:4.1f} GB/s")
if __name__ == "__main__":
from jax._src import test_multiprocess as jt_multiprocess # pytype: disable=import-error
jt_multiprocess.main(shard_main=_run_example)
@@ -0,0 +1,680 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing fused attention forward and backward pass."""
from __future__ import annotations
import functools
import math
from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
import jax.numpy as jnp
import numpy as np
import dataclasses
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
@dataclasses.dataclass(frozen=True, slots=True)
class BlockSizes:
"""
Tile sizes parameterizing the attention kernel. These block sizes
should be tuned for the model and hardware for optimal performance.
Attributes:
block_q: Block size along Q sequence length for forward kernel.
block_k: Block size along KV sequence length for forward kernel.
block_q_dkv: Block size along Q sequence length for dKV backward kernel.
block_kv_dkv: Block size along KV sequence length for dKV backward kernel.
block_q_dq: Block size along Q sequence length for dQ backward kernel.
block_kv_dq: Block size along KV sequence length for dQ backward kernel.
"""
block_q: int
block_k: int
block_q_dkv: int | None = None
block_kv_dkv: int | None = None
block_q_dq: int | None = None
block_kv_dq: int | None = None
@classmethod
def get_default(cls):
return BlockSizes(
block_q=128,
block_k=128,
block_q_dkv=32,
block_kv_dkv=32,
block_q_dq=32,
block_kv_dq=32,
)
@property
def has_backward_blocks(self) -> bool:
"""Returns True if all backward blocks are specified for the fused
dq and dk/dv backwards pass.
"""
backward_blocks = [
self.block_q_dkv,
self.block_kv_dkv,
self.block_q_dq,
self.block_kv_dq,
]
return all(b is not None for b in backward_blocks)
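# Example (an illustrative, untuned configuration; block sizes should be tuned
# per model and hardware):
#   block_sizes = BlockSizes(block_q=64, block_k=64, block_q_dkv=32,
#                            block_kv_dkv=32, block_q_dq=32, block_kv_dq=32)
#   out = mha(q, k, v, segment_ids=None, causal=True, block_sizes=block_sizes)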
def mha_forward_kernel(
q_ref,
k_ref,
v_ref, # Input arrays
segment_ids_ref: jax.Array | None, # segment_id arrays
o_ref: Any, # Output
*residual_refs: Any, # Residual outputs
sm_scale: float,
causal: bool,
block_q: int,
block_k: int,
head_dim: int,
):
seq_len = k_ref.shape[0]
start_q = pl.program_id(0)
head_dim_padded = q_ref.shape[-1]
# m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf')
l_i = jnp.zeros(block_q, dtype=jnp.float32)
# o is the buffer where we accumulate the output on sram.
o = jnp.zeros((block_q, head_dim_padded), dtype=jnp.float32)
# Load q: it will stay in L1 throughout. Indices form a matrix because we
# read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
# q tile has shape [block_q, head_dim_padded], head_dim_padded >= head_dim.
curr_q_slice = pl.dslice(start_q * block_q, block_q)
head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :]
q = plgpu.load(q_ref, mask=head_mask, other=0.0)
q_segment_ids = (
None if segment_ids_ref is None else segment_ids_ref[curr_q_slice]
)
# In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size
# Bc == block_k here), and fast over blocks of q (size Br == block_q here).
# Here we only loop over blocks of kv to process entire seq_len, the loop over
# blocks of q is carried out by the grid.
def body(start_k, carry):
o_prev, m_prev, l_prev = carry
curr_k_slice = pl.dslice(start_k * block_k, block_k)
k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
qk = pl.dot(q, k.T) # [block_q, block_k]
# Scale the logits so that exp2 can be used instead of exp, based on the
# identity: e^x = 2^(x * log2(e)).
qk_scale = math.log2(math.e)
if sm_scale != 1.:
qk_scale *= sm_scale
qk *= qk_scale
# Avoids Triton crash.
# if num_heads > 2:
# qk = qk.astype(q_ref.dtype)
# qk = qk.astype(jnp.float32)
if causal or segment_ids_ref is not None:
mask = None
if segment_ids_ref is not None:
assert q_segment_ids is not None
kv_segment_ids = segment_ids_ref[curr_k_slice]
mask = segment_mask(q_segment_ids, kv_segment_ids)
if causal:
span_q = start_q * block_q + jnp.arange(block_q)
span_k = start_k * block_k + jnp.arange(block_k)
causal_mask = span_q[:, None] >= span_k[None, :]
mask = (
causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
)
# Apply mask to qk.
assert mask is not None
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
m_curr = jnp.max(qk, axis=-1)
m_next = jnp.maximum(m_prev, m_curr)
correction = jnp.exp2(m_prev - m_next)
l_prev_corr = correction * l_prev
s_curr = jnp.exp2(
qk - m_next[:, None]
) # Use m_next instead of m_curr to avoid a correction on l_curr
l_curr = s_curr.sum(axis=-1)
l_next = l_prev_corr + l_curr
o_prev_corr = correction[:, None] * o_prev
v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask)
o_curr = pl.dot(s_curr.astype(v.dtype), v)
o_next = o_prev_corr + o_curr
return o_next, m_next, l_next
if causal:
# Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q)
upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k)
else:
upper_bound = pl.cdiv(seq_len, block_k)
o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))
# We keep an unscaled version of o during the scan over seq_len. Scaling it
# by the last l_i gives us the correct final output. See section 3.1.1 in the
# FlashAttention-2 paper: https://arxiv.org/pdf/2307.08691.
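# Concretely, at this point o holds the unnormalized sum_j exp2(qk_j - m_i) * v_j
# and l_i holds sum_j exp2(qk_j - m_i), so dividing by l_i yields softmax(qk) @ v.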
o /= l_i[:, None]
if residual_refs:
lse_ref = residual_refs[0]
lse_ref[...] = m_i + jnp.log2(l_i)
# Write output to dram.
plgpu.store(o_ref.at[:, : o.shape[-1]], o.astype(o_ref.dtype), mask=head_mask)
def segment_mask(
q_segment_ids: jax.Array,
kv_segment_ids: jax.Array,
):
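"""Returns a boolean mask that is True where q and kv segment ids match.
For example (illustrative values): q ids [1, 1, 2] against kv ids [1, 2, 2]
broadcast to a [3, 3] mask that only allows attention within the same segment.
"""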
# [B, T, 1] or [T, 1]
q_segment_ids = jnp.expand_dims(q_segment_ids, axis=-1)
# [B, 1, S] or [1, S]
if kv_segment_ids.ndim == 1:
kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=0)
else:
kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=1)
return jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)
@functools.partial(
jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
)
@functools.partial(
jax.jit,
static_argnames=[
"sm_scale",
"causal",
"block_sizes",
"backward_pass_impl",
"num_warps",
"num_stages",
"grid",
"interpret",
"debug",
"return_residuals",
],
)
def mha(
q,
k,
v,
segment_ids: jnp.ndarray | None,
sm_scale: float = 1.0,
causal: bool = False,
block_sizes: BlockSizes = BlockSizes.get_default(),
backward_pass_impl: str = "triton",
num_warps: int | None = None,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
return_residuals: bool = False,
):
del backward_pass_impl
batch_size, q_seq_len, num_heads, head_dim = q.shape
kv_seq_len = k.shape[1]
block_q = min(block_sizes.block_q, q_seq_len)
block_k = min(block_sizes.block_k, kv_seq_len)
head_dim_padded = pl.next_power_of_2(head_dim)
if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]):
raise ValueError(
f"This kernel expects q, k, and v to have the same head dimension, but"
f" found {q.shape=}, {k.shape=}, {v.shape=}."
)
if q_seq_len % block_q != 0:
raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}")
if kv_seq_len % block_k != 0:
raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}")
# Heuristics.
grid_ = grid
if grid_ is None:
grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)
num_warps_ = num_warps
if num_warps_ is None:
num_warps_ = 4 if head_dim <= 64 else 8
kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
block_q=block_q, block_k=block_k,
head_dim=head_dim, causal=causal)
in_specs: list[pl.BlockSpec | None] = [
pl.BlockSpec((None, block_q, None, head_dim_padded),
lambda i, j, k: (j, i, k, 0)),
pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
lambda _, j, k: (j, 0, k, 0)),
pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
lambda _, j, k: (j, 0, k, 0)),
]
in_specs.append(
None
if segment_ids is None
else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
)
out_shape = [q]
out_specs = [pl.BlockSpec((None, block_q, None, head_dim_padded),
lambda i, j, k: (j, i, k, 0))]
if return_residuals:
out_shape.append(jax.ShapeDtypeStruct(
shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32)) # lse
out_specs.append(
pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i))) # lse
out = pl.pallas_call(
kernel,
grid=grid_,
in_specs=in_specs,
out_specs=out_specs,
compiler_params=plgpu.CompilerParams(
num_warps=num_warps_, num_stages=num_stages),
out_shape=out_shape,
debug=debug,
interpret=interpret,
name="mha_forward",
)(q, k, v, segment_ids)
return out if return_residuals else out[0]
def _mha_forward(
q,
k,
v,
segment_ids: jax.Array | None,
sm_scale: float,
causal: bool,
block_sizes: BlockSizes,
backward_pass_impl: str,
num_warps: int | None,
num_stages: int,
grid: Any,
interpret: bool,
debug: bool,
return_residuals: bool,
):
out, lse = mha(q, k, v, segment_ids=segment_ids, sm_scale=sm_scale,
causal=causal, block_sizes=block_sizes,
backward_pass_impl=backward_pass_impl,
num_warps=num_warps, num_stages=num_stages,
grid=grid, interpret=interpret, debug=debug,
return_residuals=True)
residuals = (q, k, v, segment_ids, out, lse)
ret = (out, lse) if return_residuals else out
return ret, residuals
def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int):
# load
head_mask = (jnp.arange(out_ref.shape[-1]) < head_dim)[None, :]
o = plgpu.load(out_ref, mask=head_mask, other=0.0)
do = plgpu.load(dout_ref, mask=head_mask, other=0.0)
# compute
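# delta_i = rowsum(dO_i * O_i), the D term from the FlashAttention-2 backward pass.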
delta = jnp.sum(o * do, axis=1)
# write-back
delta_ref[...] = delta.astype(delta_ref.dtype)
@jax.named_scope("preprocess_backward")
def _preprocess_backward(out, do, lse, block_q: int,
debug: bool, interpret: bool):
batch_size, seq_len, num_heads, head_dim = out.shape
head_dim_padded = pl.next_power_of_2(head_dim)
out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype)
delta = pl.pallas_call(
functools.partial(_preprocess_backward_kernel, head_dim=head_dim),
grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
in_specs=[
pl.BlockSpec((None, block_q, None, head_dim_padded),
lambda i, j, k: (j, i, k, 0)),
pl.BlockSpec((None, block_q, None, head_dim_padded),
lambda i, j, k: (j, i, k, 0)),
],
out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=3),
out_shape=out_shape,
debug=debug,
interpret=interpret,
name="mha_preprocess_backward",
)(out, do)
return delta
# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence
# length.
# Inspired by the triton tutorial: https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py
def mha_backward_kernel(
# Inputs
q_ref,
k_ref,
v_ref,
segment_ids_ref: jax.Array | None,
out_ref,
do_scaled_ref,
lse_ref,
delta_ref,
# Outputs
dq_ref,
dk_ref,
dv_ref,
*,
sm_scale: float,
causal: bool,
block_q_dkv: int,
block_kv_dkv: int,
block_q_dq: int,
block_kv_dq: int,
head_dim: int,
):
del out_ref # Not needed
q_seq_len = q_ref.shape[0]
kv_seq_len = k_ref.shape[0]
# Scan #1: dK and dV
# 1. Load a block of K and V of size (block_kv_dkv, head_dim) in SMEM.
# 2. Iterate through Q in chunks of (block_q_dkv, head_dim) to accumulate
# dK and dV.
start_k = pl.program_id(2)
curr_k_slice = pl.dslice(start_k * block_kv_dkv, block_kv_dkv)
head_dim_padded = q_ref.shape[-1]
dv = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32)
dk = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32)
head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :]
v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
span_k = start_k * block_kv_dkv + jnp.arange(block_kv_dkv)
kv_segment_ids = (
None if segment_ids_ref is None else segment_ids_ref[curr_k_slice]
)
def inner_loop_dkdv(start_q, carry):
dv, dk = carry
curr_q_slice = pl.dslice(start_q * block_q_dkv, block_q_dkv)
q = plgpu.load(q_ref.at[curr_q_slice, :], mask=head_mask, other=0.0)
qk = pl.dot(q, k.T)
qk_scale = math.log2(math.e)
if sm_scale != 1.:
qk_scale *= sm_scale
qk *= qk_scale
if causal or segment_ids_ref is not None:
mask = None
if segment_ids_ref is not None:
assert kv_segment_ids is not None
q_segment_ids = segment_ids_ref[curr_q_slice]
mask = segment_mask(q_segment_ids, kv_segment_ids)
if causal:
span_q = start_q * block_q_dkv + jnp.arange(block_q_dkv)
causal_mask = span_q[:, None] >= span_k[None, :]
mask = (
causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
)
assert mask is not None
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
lse = lse_ref[curr_q_slice]
di = delta_ref[curr_q_slice]
do = plgpu.load(
do_scaled_ref.at[curr_q_slice, :], mask=head_mask, other=0.0
)
p = jnp.exp2(qk - lse[:, None])
dv = dv + pl.dot(p.astype(do.dtype).T, do)
dp = jnp.zeros((block_q_dkv, block_kv_dkv), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
return dv, dk
lower_bound = lax.div(start_k * block_kv_dkv, block_q_dkv) if causal else 0
dv, dk = lax.fori_loop(
lower_bound, pl.cdiv(q_seq_len, block_q_dkv), inner_loop_dkdv, (dv, dk)
)
plgpu.store(
dv_ref.at[:, : dv.shape[-1]], dv.astype(dv_ref.dtype), mask=head_mask
)
plgpu.store(
dk_ref.at[:, : dk.shape[-1]], dk.astype(dk_ref.dtype), mask=head_mask
)
# Scan #2: dQ
# 1. Load a block of Q of size (block_q_dq, head_dim) in SMEM.
# 2. Iterate through K and V in chunks of (block_kv_dq, head_dim) to
# accumulate dQ.
start_q = pl.program_id(2)
curr_q_slice = pl.ds(start_q * block_q_dq, block_q_dq)
span_q = start_q * block_q_dq + jnp.arange(block_q_dq)
dq = jnp.zeros([block_q_dq, head_dim_padded], dtype=jnp.float32)
q = plgpu.load(q_ref.at[curr_q_slice, :], mask=head_mask, other=0.0)
q_segment_ids = (
None if segment_ids_ref is None else segment_ids_ref[curr_q_slice]
)
lse = lse_ref[curr_q_slice]
do = plgpu.load(do_scaled_ref.at[curr_q_slice, :], mask=head_mask, other=0.0)
di = delta_ref[curr_q_slice]
def inner_loop_dq(start_k, dq):
curr_k_slice = pl.dslice(start_k * block_kv_dq, block_kv_dq)
k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
qk = pl.dot(q, k.T)
qk_scale = math.log2(math.e)
if sm_scale != 1.:
qk_scale *= sm_scale
qk *= qk_scale
if causal or segment_ids_ref is not None:
mask = None
if segment_ids_ref is not None:
assert q_segment_ids is not None
kv_segment_ids = segment_ids_ref[curr_k_slice]
mask = segment_mask(q_segment_ids, kv_segment_ids)
if causal:
span_k = start_k * block_kv_dq + jnp.arange(block_kv_dq)
causal_mask = span_q[:, None] >= span_k[None, :]
mask = (
causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
)
assert mask is not None
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
p = jnp.exp2(qk - lse[:, None])
dp = jnp.zeros((block_q_dq, block_kv_dq), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
return dq
if causal:
upper_bound = pl.cdiv((start_q + 1) * block_q_dq, block_kv_dq)
else:
upper_bound = pl.cdiv(kv_seq_len, block_kv_dq)
dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
plgpu.store(
dq_ref.at[:, : dq.shape[-1]], dq.astype(dq_ref.dtype), mask=head_mask
)
def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
backward_pass_impl: str, num_warps: int | None,
num_stages: int, grid: Any, interpret: bool,
debug: bool, return_residuals: bool, res, do):
if return_residuals:
raise ValueError(
"Kernel differentiation is not supported if return_residuals is True.")
q, k, v, segment_ids, out, lse = res
del num_stages, grid, return_residuals
if backward_pass_impl == "xla":
return jax.vjp(
functools.partial(mha_reference, sm_scale=sm_scale, causal=causal),
q,
k,
v,
segment_ids,
)[1](do)
elif backward_pass_impl == "triton":
if not block_sizes.has_backward_blocks:
raise ValueError("Backward block sizes must all be set.")
assert block_sizes.block_q_dkv is not None
assert block_sizes.block_kv_dkv is not None
assert block_sizes.block_q_dq is not None
assert block_sizes.block_kv_dq is not None
batch_size, q_seq_len, num_heads, head_dim = q.shape
kv_seq_len = k.shape[1]
block_q = min(block_sizes.block_q, q_seq_len)
block_q_dkv = min(block_sizes.block_q_dkv, q_seq_len)
block_kv_dkv = min(block_sizes.block_kv_dkv, kv_seq_len)
block_q_dq = min(block_sizes.block_q_dq, q_seq_len)
block_kv_dq = min(block_sizes.block_kv_dq, kv_seq_len)
head_dim_padded = pl.next_power_of_2(head_dim)
if q_seq_len // block_q_dq != kv_seq_len // block_kv_dkv:
raise ValueError(
"q_seq_len and kv_seq_len must be divided into the same "
"number of blocks for the fused backward pass."
)
delta = _preprocess_backward(out, do, lse, block_q, debug, interpret)
out_shapes = [
jax.ShapeDtypeStruct(q.shape, q.dtype),
jax.ShapeDtypeStruct(k.shape, k.dtype),
jax.ShapeDtypeStruct(v.shape, v.dtype),
]
in_specs: list[pl.BlockSpec | None] = [
pl.BlockSpec((None, q_seq_len, None, head_dim_padded),
lambda i, j, _: (i, 0, j, 0)),
pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
lambda i, j, _: (i, 0, j, 0)),
pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
lambda i, j, _: (i, 0, j, 0)),
pl.BlockSpec((None, q_seq_len, None, head_dim_padded),
lambda i, j, _: (i, 0, j, 0)),
pl.BlockSpec((None, q_seq_len, None, head_dim_padded),
lambda i, j, _: (i, 0, j, 0)),
pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
]
if segment_ids is None:
in_specs.insert(3, None)
else:
in_specs.insert(3, pl.BlockSpec((None, kv_seq_len),
lambda i, j, _: (i, 0)))
grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_kv_dkv))
num_warps_ = num_warps
if num_warps_ is None:
if (
block_q_dkv * block_kv_dkv < 128 * 128
or block_q_dq * block_kv_dq < 128 * 128
):
num_warps_ = 4
else:
num_warps_ = 8
dq, dk, dv = pl.pallas_call(
functools.partial(
mha_backward_kernel,
sm_scale=sm_scale,
causal=causal,
block_q_dkv=block_q_dkv,
block_kv_dkv=block_kv_dkv,
block_q_dq=block_q_dq,
block_kv_dq=block_kv_dq,
head_dim=head_dim,
),
out_shape=out_shapes,
in_specs=in_specs,
grid=grid,
out_specs=[
pl.BlockSpec(
(None, block_q_dq, None, head_dim_padded),
lambda i, j, k: (i, k, j, 0), # dq
),
pl.BlockSpec(
(None, block_kv_dkv, None, head_dim_padded),
lambda i, j, k: (i, k, j, 0), # dk
),
pl.BlockSpec(
(None, block_kv_dkv, None, head_dim_padded),
lambda i, j, k: (i, k, j, 0), # dv
),
],
name="mha_backward",
debug=debug,
interpret=interpret,
compiler_params=plgpu.CompilerParams(
num_warps=num_warps_, num_stages=2
),
)(q, k, v, segment_ids, out, do, lse, delta)
else:
raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
return dq.astype(q.dtype), dk, dv, None
mha.defvjp(_mha_forward, _mha_backward)
@functools.partial(jax.jit, static_argnames=['sm_scale', 'causal'])
def mha_reference(
q,
k,
v,
segment_ids: jnp.ndarray | None,
sm_scale=1.0,
causal: bool = False,
):
q_seq_len = q.shape[1]
kv_seq_len = k.shape[1]
logits = jnp.einsum(
'bqhc,bkhc->bhqk', q, k, preferred_element_type=jnp.float32
)
mask = None
if segment_ids is not None:
mask = jnp.expand_dims(segment_mask(segment_ids, segment_ids), 1)
mask = jnp.broadcast_to(mask, logits.shape)
if causal:
causal_mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool))
causal_mask = jnp.broadcast_to(causal_mask, logits.shape)
mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
logits = logits if mask is None else jnp.where(mask, logits, float("-inf"))
weights = jax.nn.softmax(logits * sm_scale)
return jnp.einsum(
'bhqk,bkhc->bqhc', weights, v, preferred_element_type=jnp.float32
)
@@ -0,0 +1,911 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FlashAttention3 implementation (using Mosaic GPU as the backend)."""
import dataclasses
import functools
import itertools
import math
import jax
from jax import lax
from jax._src import test_util as jtu # noqa: F401
from jax._src.lib import cuda_versions # noqa: F401
from jax.experimental.mosaic.gpu import profiler
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Protocol, TypeVar
T = TypeVar('T')
class PipelineCallback(Protocol):
"""A callback that returns the same type as the input."""
def __call__(self, arg: T, /) -> T: ...
@dataclasses.dataclass(frozen=True)
class TuningConfig:
block_q: int
block_kv: int
max_concurrent_steps: int
use_schedule_barrier: bool = True
causal: bool = False
compute_wgs_bwd: int = 1
block_q_dkv: int | None = None
block_kv_dkv: int | None = None
block_q_dq: int | None = None
block_kv_dq: int | None = None
def __post_init__(self):
if self.block_q % 64:
raise ValueError(f"{self.block_q=} must be a multiple of 64")
if self.block_kv % 64:
raise ValueError(f"{self.block_kv=} must be a multiple of 64")
if self.max_concurrent_steps < 2:
raise ValueError(f"{self.max_concurrent_steps=} must be at least 2")
backward_blocks = [self.block_q_dkv, self.block_kv_dkv, self.block_q_dq, self.block_kv_dq]
block_is_set = [blk is not None for blk in backward_blocks]
if any(block_is_set) and not all(block_is_set):
raise ValueError(
"Backward block sizes (block_q_dkv, block_kv_dkv, block_q_dq, "
"block_kv_dq) must either all be specified or all be None."
)
@property
def has_backward_blocks(self) -> bool:
return self.block_q_dkv is not None
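# Example (an illustrative, untuned sketch): the forward pass only needs block_q
# and block_kv, e.g. TuningConfig(block_q=64, block_kv=128, max_concurrent_steps=2);
# taking gradients additionally requires the four block_*_dkv / block_*_dq sizes.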
def _attention_forward(q, k, v, config: TuningConfig, save_residuals: bool = False):
assert cuda_versions is not None
cuda_runtime_version = cuda_versions.cuda_runtime_get_version()
# TODO(pobudzey): Undo when we upgrade to cuda 12.9.1.
if config.causal and cuda_runtime_version >= 12080 and cuda_runtime_version < 12091:
raise ValueError(
"Causal masking not supported with cuda versions between 12.8.0 and"
" 12.9.1 due to a ptxas miscompilation."
)
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}")
batch_size, q_seq_len, num_q_heads, head_dim = q.shape
_, kv_seq_len, num_kv_heads, _ = k.shape
kv_shape = (batch_size, kv_seq_len, num_kv_heads, head_dim)
if k.shape != kv_shape:
raise ValueError(f"Expected {k.shape=} to be {kv_shape} (inferred from q)")
if v.shape != kv_shape:
raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)")
if (dtype := q.dtype) != k.dtype or dtype != v.dtype:
raise ValueError(f"q, k, and v should all have the same dtype, got: {q.dtype}, {k.dtype}, {v.dtype}")
if num_q_heads % num_kv_heads:
raise ValueError(f"{num_q_heads=} must be divisible by and {num_kv_heads=}")
q_heads_per_kv_head = num_q_heads // num_kv_heads
if head_dim % 64:
raise ValueError(f"{head_dim=} must be divisible by 64")
if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}")
max_concurrent_steps = min(
config.max_concurrent_steps, kv_seq_len // config.block_kv
)
block_q, block_kv = config.block_q, config.block_kv
if kv_seq_len % block_kv:
raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}")
def kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped):
batch = lax.axis_index("batch")
q_head = lax.axis_index("heads")
smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped
wg_idx = lax.axis_index("wg")
qo_smem2, k_smem, v_smem, lse_smem2 = smem_buffers
k_barriers, v_barriers, q_barriers = buffer_barriers
k_consumed_barriers, v_consumed_barriers = consumed_barriers
def perform_schedule_barrier():
plgpu.barrier_arrive(schedule_barrier)
plgpu.barrier_wait(schedule_barrier)
if config.causal:
block_q_end = (lax.axis_index("q_seq") + 1) * (2 * block_q)
block_max_kv_steps = pl.cdiv(block_q_end, jnp.array(block_kv, jnp.int32))
else:
block_max_kv_steps = kv_seq_len // block_kv
@pl.when(wg_idx < 2)
def _compute_wg():
plgpu.set_max_registers(232, action="increase")
qo_smem = qo_smem2.at[wg_idx]
lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None
q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
if config.causal:
kv_steps = pl.cdiv(q_seq_base + block_q, jnp.array(block_kv, jnp.int32))
else:
kv_steps = block_max_kv_steps
plgpu.copy_gmem_to_smem(
q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
qo_smem,
q_barriers.at[wg_idx],
)
plgpu.barrier_wait(q_barriers.at[wg_idx])
m_i = plgpu.layout_cast(
jnp.full((block_q,), -jnp.inf, dtype=jnp.float32),
plgpu.Layout.WGMMA.reduce(1),
)
l_i = plgpu.layout_cast(
jnp.full((block_q,), 0, dtype=jnp.float32),
plgpu.Layout.WGMMA.reduce(1),
)
acc = plgpu.layout_cast(
jnp.full((block_q, head_dim), 0, dtype=jnp.float32),
plgpu.Layout.WGMMA,
)
@pl.when(kv_steps > 0)
def _():
plgpu.barrier_wait(k_barriers.at[0])
pl.when(wg_idx == 1)(perform_schedule_barrier)
def kv_loop(kv_step, carry, causal: bool = False):
acc, m_i, l_i = carry
slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype))
# QK
def compute_qk(acc_ref):
plgpu.wgmma(acc_ref, qo_smem, plgpu.transpose_ref(k_smem.at[slot], (1, 0)))
perform_schedule_barrier()
return acc_ref[...]
qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32))
plgpu.barrier_arrive(k_consumed_barriers.at[slot])
if causal:
q_ids = plgpu.broadcasted_iota(jnp.int32, (block_q, block_kv), 0, layout=plgpu.Layout.WGMMA)
kv_ids = plgpu.broadcasted_iota(jnp.int32, (block_q, block_kv), 1, layout=plgpu.Layout.WGMMA)
mask = (q_ids + q_seq_base) >= (kv_ids + kv_step * block_kv)
qk = jnp.where(mask, qk, -jnp.inf)
# Softmax
# We keep m scaled by log2e to use FMA instructions when computing p.
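# (exp(x - m) == exp2(x * log2e - m * log2e), so keeping m pre-multiplied by
# log2e makes the argument of exp2 below a single multiply-add on qk.)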
log2e = math.log2(math.e)
m_ij = jnp.maximum(m_i, qk.max(axis=1) * log2e)
alpha = jnp.exp2(m_i - m_ij)
m_i = m_ij
p = jnp.exp2(qk * log2e - lax.broadcast_in_dim(m_ij, qk.shape, [0]))
acc *= lax.broadcast_in_dim(alpha, acc.shape, [0])
l_i *= alpha
p16 = p.astype(dtype)
def end_softmax_barriers():
plgpu.barrier_arrive(schedule_barrier) # Done with softmax!
plgpu.barrier_wait(v_barriers.at[slot])
plgpu.barrier_wait(schedule_barrier) # Wait until TensorCore is free.
# Can't fully explain why, but empirically the ordering here influences
# the performance of the final kernel quite significantly.
if head_dim <= 128:
l_i += p.sum(axis=1)
acc, l_i, m_i, p16 = lax.optimization_barrier((acc, l_i, m_i, p16))
end_softmax_barriers()
else:
end_softmax_barriers()
l_i += p.sum(axis=1)
# PV
def compute_pv(acc_ref):
plgpu.wgmma(acc_ref, p16, v_smem.at[slot])
wait_step = kv_step + 1
wait_slot = lax.rem(wait_step, jnp.array(max_concurrent_steps, kv_step.dtype))
@pl.when(wait_step < kv_steps)
def _wait():
plgpu.barrier_wait(k_barriers.at[wait_slot])
acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc))
plgpu.barrier_arrive(v_consumed_barriers.at[slot])
return acc, m_i, l_i
if not config.causal:
acc, m_i, l_i = lax.fori_loop(0, block_max_kv_steps, kv_loop, (acc, m_i, l_i))
else:
def epilogue_kv_loop(kv_step, _):
# This loop makes sure that all the pipelined KV data is processed, even
# if one compute wg finishes early (as can happen with causal masking).
slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype))
plgpu.barrier_arrive(k_consumed_barriers.at[slot])
plgpu.barrier_arrive(v_consumed_barriers.at[slot])
perform_schedule_barrier()
perform_schedule_barrier()
causal_kv_loop = functools.partial(kv_loop, causal=True)
full_kv_steps = lax.div(q_seq_base, jnp.array(block_kv, jnp.int32))
# With causal masking, the KV loop unrolling is split into 3 sections:
# 1. A fast path where no causal mask is needed.
acc, m_i, l_i = lax.fori_loop(0, full_kv_steps, kv_loop, (acc, m_i, l_i))
# 2. Causal masking.
acc, m_i, l_i = lax.fori_loop(full_kv_steps, kv_steps, causal_kv_loop, (acc, m_i, l_i))
# 3. Epilogue to flush the data pipeline.
lax.fori_loop(kv_steps, block_max_kv_steps, epilogue_kv_loop, None)
pl.when(wg_idx == 0)(perform_schedule_barrier)
# TODO(apaszke): Invert and multiply to avoid expensive divisions.
acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0])
qo_smem[...] = acc.astype(dtype)
if lse_smem is not None:
RCP_LN2 = 1.4426950408889634
log2 = lambda x: jnp.log(x) * RCP_LN2
lse_smem[...] = m_i + log2(l_i)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
)
if lse_smem is not None:
plgpu.copy_smem_to_gmem(
lse_smem,
lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)],
)
plgpu.wait_smem_to_gmem(0)
@pl.when(wg_idx == 2)
def _memory_wg():
plgpu.set_max_registers(40, action="decrease")
kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
for i in range(max_concurrent_steps):
s = (batch, pl.ds(i * block_kv, block_kv), kv_head)
plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i])
@pl.loop(0, block_max_kv_steps - max_concurrent_steps)
def _kv_loop(kv_step):
tma_step = kv_step + max_concurrent_steps
tma_slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype))
s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head)
plgpu.barrier_wait(k_consumed_barriers.at[tma_slot])
plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot])
plgpu.barrier_wait(v_consumed_barriers.at[tma_slot])
plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot])
def entry(q_ref, k_ref, v_ref, out_ref, lse_ref):
compute_wgs = 2
tiling = plgpu.TilingTransform((8, 64))
swizzle = plgpu.SwizzleTransform(128)
qo_scratch = plgpu.SMEM(
(compute_wgs, block_q, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)
k_scratch = plgpu.SMEM(
(max_concurrent_steps, block_kv, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)
v_scratch = plgpu.SMEM(
(max_concurrent_steps, block_kv, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)
scratch = [qo_scratch, k_scratch, v_scratch, None]
if save_residuals:
scratch[3] = plgpu.SMEM((compute_wgs, block_q), jnp.float32)
pl.run_scoped(
lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args),
scratch,
(
plgpu.Barrier(num_barriers=max_concurrent_steps),
plgpu.Barrier(num_barriers=max_concurrent_steps),
plgpu.Barrier(num_barriers=compute_wgs),
),
(plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2,
plgpu.Barrier(num_arrivals=compute_wgs),
collective_axes="wg",
)
num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
out_shape = [q, None]
if save_residuals:
# Note that we keep seq_len in the minor-most dimension so that we can do
# 1D TMAs on chunks of `block_q`.
out_shape[1] = jax.ShapeDtypeStruct(
(batch_size, num_q_heads, q_seq_len), jnp.float32
)
out, lse = plgpu.kernel(
entry,
out_shape=out_shape,
grid=(num_q_heads, num_q_tiles, batch_size),
grid_names=("heads", "q_seq", "batch"),
num_threads=3,
thread_name="wg",
compiler_params=plgpu.CompilerParams(approx_math=True),
)(q, k, v)
if save_residuals:
assert lse is not None
return out, (lse,)
return out
@partial(jax.custom_vjp, nondiff_argnums=(3, 4))
@partial(jax.jit, static_argnames=["config", "save_residuals"])
def attention(q, k, v, config: TuningConfig, save_residuals: bool = False):
return _attention_forward(q, k, v, config, save_residuals)
def _attention_fwd(q, k, v, config: TuningConfig, save_residuals: bool):
del save_residuals
out, (lse,) = _attention_forward(q, k, v, config, save_residuals=True)
return out, (q, k, v, out, lse)
def _attention_bwd(config: TuningConfig, save_residuals: bool, res, do):
del save_residuals
q, k, v, out, lse = res
if config.causal:
raise NotImplementedError("Causal attention not supported in the backwards pass yet.")
if not config.has_backward_blocks:
raise ValueError("Need to specify backward blocks.")
assert config.block_q_dq is not None
assert config.block_kv_dq is not None
assert config.block_q_dkv is not None
assert config.block_kv_dkv is not None
batch_size, q_seq_len, num_q_heads, head_dim = q.shape
_, kv_seq_len, num_kv_heads, _ = k.shape
q_heads_per_kv_head = num_q_heads // num_kv_heads
dtype = q.dtype
compute_wgs = config.compute_wgs_bwd
num_q_tiles, rem = divmod(q_seq_len, config.block_q_dq * compute_wgs)
if rem:
raise NotImplementedError(
f"{q_seq_len=} must be a multiple of {config.block_q_dq=} * {compute_wgs=}")
num_kv_tiles, rem = divmod(kv_seq_len, config.block_kv_dkv * compute_wgs)
if rem:
raise NotImplementedError(
f"{kv_seq_len=} must be a multiple of {config.block_kv_dkv=} * {compute_wgs=}")
num_q_tiles_in_dkv, rem = divmod(q_seq_len, config.block_q_dkv)
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {config.block_q_dkv=}")
num_kv_tiles_in_dq, rem = divmod(kv_seq_len, config.block_kv_dq)
if rem:
raise NotImplementedError(f"{kv_seq_len=} must be a multiple of {config.block_kv_dq=}")
tiling = plgpu.TilingTransform((8, 64))
swizzle = plgpu.SwizzleTransform(128)
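# delta = rowsum(dO * O) per (batch, head, q) position: the D term of the
# FlashAttention backward pass, computed in f32 outside the Pallas kernels.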
delta = jnp.einsum('bqhd,bqhd->bhq', out.astype(jnp.float32), do.astype(jnp.float32))
del out # Not needed anymore.
def kernel_dq(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref, dq_ref,
smem_buffers, buffer_barriers, block_q, block_kv):
batch = lax.axis_index("batch")
q_head = lax.axis_index("heads")
wg_idx = lax.axis_index("wg")
kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
q_smem2, do_smem2, lse_smem2, delta_smem2 = smem_buffers
q_barriers, do_barriers, lse_barriers, delta_barriers = buffer_barriers
def _compute_thread(pipeline_callback: PipelineCallback, /) -> None:
q_smem, do_smem, lse_smem, delta_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx], lse_smem2.at[wg_idx], delta_smem2.at[wg_idx]
q_seq_base = lax.axis_index("q_seq") * (compute_wgs * block_q) + wg_idx * block_q
q_slice = (batch, pl.ds(q_seq_base, block_q), q_head)
plgpu.copy_gmem_to_smem(q_ref.at[q_slice], q_smem, q_barriers.at[wg_idx])
plgpu.copy_gmem_to_smem(do_ref.at[q_slice], do_smem, do_barriers.at[wg_idx])
plgpu.copy_gmem_to_smem(
delta_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)],
delta_smem,
delta_barriers.at[wg_idx],
)
plgpu.copy_gmem_to_smem(
lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)],
lse_smem,
lse_barriers.at[wg_idx],
)
for buffer in buffer_barriers:
plgpu.barrier_wait(buffer.at[wg_idx])
delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA.reduce(1))
lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA.reduce(1))
dq_acc: jax.Array = plgpu.layout_cast(
jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
)
dq, _, _ = pipeline_callback((dq_acc, lse, delta))
q_smem[...] = dq.astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(q_smem, dq_ref.at[q_slice])
plgpu.wait_smem_to_gmem(0)
def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry):
q_smem, do_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx]
(dq_acc, lse, delta) = carry
def compute_s(acc_ref):
plgpu.wgmma(acc_ref, q_smem, plgpu.transpose_ref(k_smem, (1, 0)))
return acc_ref[...]
s = pl.run_scoped(compute_s, plgpu.ACC((block_q, block_kv), jnp.float32))
s *= math.log2(math.e)
p = jnp.exp2(s - lax.broadcast_in_dim(lse, (block_q, block_kv), [0]))
# dP
def compute_dp(acc_ref):
plgpu.wgmma(acc_ref, do_smem, plgpu.transpose_ref(v_smem, (1, 0)))
return acc_ref[...]
dp = pl.run_scoped(compute_dp, plgpu.ACC((block_q, block_kv), jnp.float32))
plgpu.barrier_arrive(v_consumed_barrier)
# dS
ds = p * (dp - lax.broadcast_in_dim(delta, (block_q, block_kv), [0]))
# dQ
def compute_dq(acc_ref):
plgpu.wgmma(acc_ref, ds.astype(k_ref.dtype), k_smem)
dq_acc = pl.run_state(compute_dq)(plgpu.ACC.init(dq_acc))
plgpu.barrier_arrive(k_consumed_barrier)
return (dq_acc, lse, delta)
pipeline = plgpu.emit_pipeline_warp_specialized(
kv_pipeline,
grid=(num_kv_tiles_in_dq,),
max_concurrent_steps=min([config.max_concurrent_steps, num_q_tiles]),
num_compute_wgs=compute_wgs,
memory_registers=40,
wg_axis="wg",
manual_consumed_barriers=True,
compute_context=_compute_thread,
in_specs=[
plgpu.BlockSpec( # k
block_shape=(block_kv, head_dim),
index_map=lambda i: (i, 0),
transforms=[tiling, swizzle]),
plgpu.BlockSpec( # v
block_shape=(block_kv, head_dim),
index_map=lambda i: (i, 0),
transforms=[tiling, swizzle]),
])
k_ref = k_ref.at[batch, :, kv_head, :]
v_ref = v_ref.at[batch, :, kv_head, :]
pipeline(k_ref, v_ref)
def kernel_dkv(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref,
dk_ref, dv_ref, smem_buffers, buffer_barriers, block_q: int, block_kv: int):
batch = lax.axis_index("batch")
q_head = lax.axis_index("heads")
wg_idx = lax.axis_index("wg")
(k_smem2, v_smem2) = smem_buffers
(k_barriers, v_barriers) = buffer_barriers
def _compute_thread(pipeline_callback):
k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx]
kv_seq_base = lax.axis_index("kv_seq") * (compute_wgs * block_kv) + wg_idx * block_kv
kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
plgpu.copy_gmem_to_smem(
k_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)],
k_smem,
k_barriers.at[wg_idx])
plgpu.copy_gmem_to_smem(
v_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)],
v_smem,
v_barriers.at[wg_idx])
plgpu.barrier_wait(k_barriers.at[wg_idx])
plgpu.barrier_wait(v_barriers.at[wg_idx])
dk_acc = plgpu.layout_cast(
jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
)
dv_acc = plgpu.layout_cast(
jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
)
(dk, dv) = pipeline_callback((dv_acc, dk_acc))
k_smem[...] = dk.astype(dtype)
v_smem[...] = dv.astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
k_smem,
dk_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)],
commit_group=False)
plgpu.copy_smem_to_gmem(
v_smem,
dv_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)],
commit_group=False)
plgpu.commit_smem_to_gmem_group()
plgpu.wait_smem_to_gmem(0)
def q_pipeline(_, q_smem, do_smem, lse_smem, delta_smem, q_consumed_barrier, do_consumed_barrier, lse_consumed_barrier, delta_consumed_barrier, carry):
k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx]
dk_acc, dv_acc = carry
def _compute_sT(acc_ref):
plgpu.wgmma(acc_ref, k_smem, plgpu.transpose_ref(q_smem, (1, 0)))
return acc_ref[...]
sT = pl.run_scoped(_compute_sT, plgpu.ACC((block_kv, block_q), jnp.float32))
sT *= math.log2(math.e)
lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA.reduce(0))
plgpu.barrier_arrive(lse_consumed_barrier)
pT = jnp.exp2(sT - lax.broadcast_in_dim(lse, (block_kv, block_q), [1]))
def _compute(refs):
# Combining two WGMMA calls in one block to avoid the unnecessary
# synchronization from two `wgmma.wait_group` calls.
dv_acc_ref, dpT_acc_ref = refs
plgpu.wgmma(dv_acc_ref, pT.astype(dtype), do_smem) # dV
plgpu.wgmma(dpT_acc_ref, v_smem, plgpu.transpose_ref(do_smem, (1, 0))) # dpT
zeros = plgpu.layout_cast(
jnp.full((block_kv, block_q), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
)
dv_acc, dpT = pl.run_state(_compute)((plgpu.ACC.init(dv_acc), plgpu.ACC.init(zeros)))
plgpu.barrier_arrive(do_consumed_barrier)
delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA.reduce(0))
plgpu.barrier_arrive(delta_consumed_barrier)
dsT = pT * (dpT - lax.broadcast_in_dim(delta, (block_kv, block_q), [1])) # jax-operator-types
def compute_dk(acc_ref):
plgpu.wgmma(acc_ref, dsT.astype(dtype), q_smem)
dk_acc = pl.run_state(compute_dk)(plgpu.ACC.init(dk_acc))
plgpu.barrier_arrive(q_consumed_barrier)
return (dk_acc, dv_acc)
pipeline = plgpu.emit_pipeline_warp_specialized(
q_pipeline,
grid=(num_q_tiles_in_dkv,),
max_concurrent_steps=min([config.max_concurrent_steps, num_kv_tiles]),
num_compute_wgs=compute_wgs,
memory_registers=40,
wg_axis="wg",
manual_consumed_barriers=True,
compute_context=_compute_thread,
in_specs=[
plgpu.BlockSpec( # q
block_shape=(block_q, head_dim),
index_map=lambda i: (i, 0),
transforms=[tiling, swizzle]),
plgpu.BlockSpec( # do
block_shape=(block_q, head_dim),
index_map=lambda i: (i, 0),
transforms=[tiling, swizzle]),
plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)),
plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,))
])
q_ref = q_ref.at[batch, :, q_head, :]
do_ref = do_ref.at[batch, :, q_head, :]
lse_ref = lse_ref.at[batch, q_head, :]
delta_ref = delta_ref.at[batch, q_head, :]
pipeline(q_ref, do_ref, lse_ref, delta_ref)
q_scratch = plgpu.SMEM(
(compute_wgs, config.block_q_dq, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)
do_scratch = q_scratch
lse_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32)
delta_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32)
dq = plgpu.kernel(
partial(kernel_dq, block_q=config.block_q_dq, block_kv=config.block_kv_dq),
out_shape=q,
scratch_shapes=[
(q_scratch, do_scratch, lse_scratch, delta_scratch),
(plgpu.Barrier(num_barriers=compute_wgs),) * 4
],
compiler_params=plgpu.CompilerParams(approx_math=True),
grid=(num_q_heads, num_q_tiles, batch_size),
grid_names=("heads", "q_seq", "batch"),
num_threads=compute_wgs + 1,
thread_name="wg",
)(q, k, v, do, lse, delta)
k_scratch = plgpu.SMEM(
(compute_wgs, config.block_kv_dkv, head_dim), jnp.float16,
transforms=(tiling, swizzle),
)
v_scratch = k_scratch
out_shape_kv = jax.ShapeDtypeStruct(
(batch_size, kv_seq_len, num_q_heads, head_dim), dtype=jnp.float16)
dk, dv = plgpu.kernel(
partial(kernel_dkv, block_q=config.block_q_dkv, block_kv=config.block_kv_dkv),
out_shape=[out_shape_kv, out_shape_kv],
scratch_shapes=[
(k_scratch, v_scratch),
(plgpu.Barrier(num_barriers=compute_wgs),) * 2
],
compiler_params=plgpu.CompilerParams(approx_math=True),
grid=(num_q_heads, num_kv_tiles, batch_size),
grid_names=("heads", "kv_seq", "batch"),
num_threads=compute_wgs + 1,
thread_name="wg"
)(q, k, v, do, lse, delta)
if q_heads_per_kv_head > 1:
sum_shape = (*k.shape[:-1], q_heads_per_kv_head, head_dim)
dk = dk.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dk.dtype)
dv = dv.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dv.dtype)
return dq, dk, dv
attention.defvjp(_attention_fwd, _attention_bwd)
@functools.partial(jax.jit, static_argnames=["config", "save_residuals"])
def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residuals=False):
if config.causal:
raise NotImplementedError("Causal attention is not supported with the pipeline emitter yet.")
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}")
batch_size, q_seq_len, num_q_heads, head_dim = q.shape
_, kv_seq_len, num_kv_heads, _ = k.shape
kv_shape = (batch_size, kv_seq_len, num_kv_heads, head_dim)
if k.shape != kv_shape:
raise ValueError(f"Expected {k.shape=} to be {kv_shape} (inferred from q)")
if v.shape != kv_shape:
raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)")
if (dtype := q.dtype) != k.dtype or dtype != v.dtype:
raise ValueError(f"q, k, and v should all have the same dtype, got: {q.dtype}, {k.dtype}, {v.dtype}")
if num_q_heads % num_kv_heads:
raise ValueError(f"{num_q_heads=} must be divisible by and {num_kv_heads=}")
q_heads_per_kv_head = num_q_heads // num_kv_heads
if head_dim % 64:
raise ValueError(f"{head_dim=} must be divisible by 64")
if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}")
max_concurrent_steps = min(
config.max_concurrent_steps, kv_seq_len // config.block_kv
)
compute_wgs = 2
block_q, block_kv = config.block_q, config.block_kv
num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
if rem:
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, smem_buffers, q_barriers, schedule_barrier):
batch = lax.axis_index("batch")
wg_idx = lax.axis_index("wg")
qo_smem2, lse_smem2 = smem_buffers
q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
q_head = lax.axis_index("heads")
kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
def perform_schedule_barrier():
if config.use_schedule_barrier:
plgpu.barrier_arrive(schedule_barrier)
plgpu.barrier_wait(schedule_barrier)
def _compute_thread(pipeline_callback: PipelineCallback, /) -> None:
qo_smem = qo_smem2.at[wg_idx]
lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None
m_i = jnp.full((block_q,), -jnp.inf, dtype=jnp.float32)
l_i = jnp.full((block_q,), 0, dtype=jnp.float32)
acc = jnp.full((block_q, head_dim), 0, dtype=jnp.float32)
# Q is not pipelined, so we load in with a manual DMA.
plgpu.copy_gmem_to_smem(
q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
qo_smem,
q_barriers.at[wg_idx],
)
plgpu.barrier_wait(q_barriers.at[wg_idx])
pl.when(wg_idx == 1)(perform_schedule_barrier)
final_carry = pipeline_callback((acc, m_i, l_i))
pl.when(wg_idx == 0)(perform_schedule_barrier)
acc, m_i, l_i = final_carry
acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0])
qo_smem[...] = acc.astype(dtype)
if lse_smem is not None:
RCP_LN2 = 1.4426950408889634
log2 = lambda x: jnp.log(x) * RCP_LN2
lse_smem[...] = m_i + log2(l_i)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
)
if lse_smem is not None:
plgpu.copy_smem_to_gmem(
lse_smem,
lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)],
)
plgpu.wait_smem_to_gmem(0)
def kv_pipeline(_, k_smem, v_smem,
k_consumed_barrier, v_consumed_barrier,
carry):
acc, m_i, l_i = carry
qo_smem = qo_smem2.at[wg_idx]
def compute_qk(acc_ref):
plgpu.wgmma(acc_ref, qo_smem, plgpu.transpose_ref(k_smem, (1, 0)))
perform_schedule_barrier()
return acc_ref[...]
qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32))
plgpu.barrier_arrive(k_consumed_barrier)
# Softmax
# We keep m scaled by log2e to use FMA instructions when computing p.
log2e = math.log2(math.e)
m_ij = jnp.maximum(m_i, qk.max(axis=1) * log2e)
alpha = jnp.exp2(m_i - m_ij)
m_i = m_ij
p = jnp.exp2(qk * log2e - lax.broadcast_in_dim(m_ij, qk.shape, [0]))
acc *= lax.broadcast_in_dim(alpha, acc.shape, [0])
l_i *= alpha
p16 = p.astype(dtype)
perform_schedule_barrier()
l_i += p.sum(axis=1)
# PV
def compute_pv(acc_ref):
plgpu.wgmma(acc_ref, p16, v_smem)
acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc))
plgpu.barrier_arrive(v_consumed_barrier)
return acc, m_i, l_i
pipeline = plgpu.emit_pipeline_warp_specialized(
kv_pipeline,
grid=(kv_seq_len // block_kv,),
max_concurrent_steps=max_concurrent_steps,
num_compute_wgs=compute_wgs,
memory_registers=40,
wg_axis="wg",
manual_consumed_barriers=True,
compute_context=_compute_thread,
in_specs=[
plgpu.BlockSpec( # k
block_shape=(block_kv, head_dim),
index_map=lambda i: (i, 0)),
plgpu.BlockSpec( # v
block_shape=(block_kv, head_dim),
index_map=lambda i: (i, 0)),
],
out_specs=[],
)
k_ref = k_ref.at[batch, :, kv_head, :]
v_ref = v_ref.at[batch, :, kv_head, :]
pipeline(k_ref, v_ref)
out_shape = [q, None]
if save_residuals:
out_shape[1] = jax.ShapeDtypeStruct((batch_size, num_q_heads, q_seq_len), jnp.float32)
qo_scratch = plgpu.SMEM((compute_wgs, block_q, head_dim), jnp.float16)
smem_scratch = [qo_scratch, None]
if save_residuals:
smem_scratch[1] = plgpu.SMEM((compute_wgs, block_q), jnp.float32)
out, lse = plgpu.kernel(
fa3_kernel,
grid=(num_q_heads, num_q_tiles, batch_size),
grid_names=("heads", "q_seq", "batch"),
num_threads=3,
thread_name="wg",
out_shape=out_shape,
scratch_shapes=(
tuple(smem_scratch),
plgpu.Barrier(num_barriers=compute_wgs),
plgpu.Barrier(num_arrivals=compute_wgs),),
compiler_params=plgpu.CompilerParams(
approx_math=True, lowering_semantics=plgpu.LoweringSemantics.Warpgroup,
),
)(q, k, v)
if save_residuals:
assert lse is not None
return out, (lse,)
return out
@functools.partial(jax.jit, static_argnames=["causal", "save_residuals"])
def attention_reference(q, k, v, causal=False, save_residuals=False):
batch_size, q_seq_len, num_q_heads, head_dim = q.shape
kv_seq_len, num_kv_heads = k.shape[1], k.shape[2]
q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v))
q_reshaped = q.reshape(
batch_size, q_seq_len, num_kv_heads, num_q_heads // num_kv_heads, head_dim
)
logits = jnp.einsum("bqHhc,bkHc->bqHhk", q_reshaped, k)
if causal:
mask = jnp.arange(q_seq_len)[:, None] >= jnp.arange(kv_seq_len)[None, :]
mask = jnp.broadcast_to(mask[:, None, None, :], logits.shape)
logits = jnp.where(mask, logits, -jnp.inf)
m = logits.max(axis=-1, keepdims=True)
unnormalized = jnp.exp(logits - m)
l = unnormalized.sum(axis=-1, keepdims=True)
weights = unnormalized / l
out = jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape)
if save_residuals:
log2e = math.log2(math.e)
l = l.reshape(*q.shape[:-1])
m = m.reshape(*q.shape[:-1])
lse = m * log2e + jnp.log2(l)
return out, (lse.swapaxes(-1, -2),)
else:
return out
def main(unused_argv):
num_q_heads = 16
num_kv_heads = 16
use_pipeline_emitter = False
if use_pipeline_emitter:
attention_impl = attention_with_pipeline_emitter
schedule_barrier_opts = (True, False)
else:
attention_impl = attention
schedule_barrier_opts = (True,)
problem_it = itertools.product(
(1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts, (False, True))
for batch_size, seq_len, head_dim, use_schedule_barrier, causal in problem_it:
assert cuda_versions is not None
cuda_runtime_version = cuda_versions.cuda_runtime_get_version()
# TODO(pobudzey): Undo when we upgrade to cuda 12.9.1.
if causal and cuda_runtime_version >= 12080 and cuda_runtime_version < 12091:
continue
if causal and use_pipeline_emitter:
continue
q_seq_len = kv_seq_len = seq_len
print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}"
f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} {causal=:} ====")
k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
block_q = 64
best = None
for block_kv in (256, 128, 64):
config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier, causal=causal)
try:
out, runtime_ms = profiler.measure(functools.partial(attention_impl, config=config))(q, k, v)
if seq_len < 32768:
out_ref = attention_reference(q, k, v, causal=causal)
np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
except ValueError as e:
if "exceeds available shared memory" in e.args[0]:
continue
raise
assert runtime_ms is not None
runtime_us = runtime_ms * 1e3
matmul_flops = (
4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size
)
if causal:
matmul_flops //= 2
peak_flops = 1e15 # f16 TensorCore peak = 1000 TFLOPS
optimal_time = matmul_flops / peak_flops * 1e6 # us
achieved_tc_util = optimal_time / runtime_us * 100
print(
f"block_q={block_q:<4}block_kv={block_kv:<4}: {runtime_us:<7.1f}us"
f" = {achieved_tc_util:4.1f}% TC utilization"
)
if best is None or runtime_us < best[0]:
best = (runtime_us, achieved_tc_util)
break # Remove this for full autotuning.
if best is not None:
print(f"Best: {best[0]:<7.1f}us = {best[1]:4.1f}% TC utilization")
if __name__ == "__main__":
from absl import app
import jax
jax.config.config_with_absl()
app.run(main)
@@ -0,0 +1,343 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Matrix Multiplication kernel for Blackwell GPUs."""
import dataclasses
import enum
import functools
import itertools
import statistics
import jax
from jax import lax
from jax._src import test_util as jtu # noqa: F401
from jax.experimental.mosaic.gpu import profiler
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np
class MatmulDimension(enum.IntEnum):
M = 0
N = 1
@dataclasses.dataclass(frozen=True)
class TuningConfig:
tile_m: int
tile_n: int
tile_k: int
max_concurrent_steps: int
collective: bool
epilogue_tile_n: int = 64
grid_minor_dim: MatmulDimension = MatmulDimension.N
grid_tile_width: int = 1
def matmul_kernel(a, b, config: TuningConfig):
dtype = a.dtype
if a.dtype != b.dtype:
raise ValueError(
f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}"
)
m, k = a.shape
k2, n = b.shape
if k != k2:
raise ValueError(
f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}"
)
collective = config.collective
tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k)
epilogue_tile_n = config.epilogue_tile_n
if tile_n % epilogue_tile_n != 0:
raise ValueError(
f"{tile_n=} must be divisible by {epilogue_tile_n=}"
)
block_tile_m = tile_m
block_tile_n = tile_n
if collective:
tile_m *= 2
tile_n *= 2
swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
transforms = (
plgpu.TilingTransform((8, swizzle_elems)),
plgpu.SwizzleTransform(swizzle),
)
out_swizzle = plgpu.find_swizzle(epilogue_tile_n * jnp.dtype(dtype).itemsize * 8)
out_swizzle_elems = out_swizzle // jnp.dtype(dtype).itemsize
out_transforms = (
plgpu.TilingTransform((8, out_swizzle_elems)),
plgpu.SwizzleTransform(out_swizzle),
)
if m % tile_m != 0:
raise ValueError(f"{m=} must be divisible by {tile_m=}")
if n % tile_n != 0:
raise ValueError(f"{n=} must be divisible by {tile_n=}")
if k % tile_k != 0:
raise ValueError(f"{k=} must be divisible by {tile_k=}")
m_iters = m // tile_m
n_iters = n // tile_n
k_iters = k // tile_k
max_concurrent_steps = config.max_concurrent_steps
TMA_WARP = 0
MMA_WARP = 1
COMPUTE_WG = 0
STORE_WG = 1
def kernel(a_gmem, b_gmem, out_gmem,
a_smem, b_smem, acc_tmem, acc_smem,
ab_tma_barrier, store_done_barrier, mma_done_barrier,
consumed_barrier):
wg_idx = lax.axis_index("wg")
cluster_idx = lax.axis_index("x")
is_lead_block = cluster_idx == 0
@plgpu.dynamic_scheduling_loop(grid_names=("mn_linear",), thread_axis="wg")
def mn_loop(loop_info: plgpu.NDLoopInfo):
(lin_idx,) = loop_info.index
local_index = loop_info.local_index
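# planar_snake maps the linear tile index to (m, n) tile coordinates, rasterizing in
# strips of width grid_tile_width along the minor dim to improve L2 reuse between CTAs.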
m_index, n_index = plgpu.planar_snake(
lin_idx,
(m_iters, n_iters),
config.grid_minor_dim,
config.grid_tile_width,
)
block_m_index = m_index * 2 + cluster_idx if collective else m_index
block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m)
slice_m = pl.ds(m_index * tile_m, tile_m)
slice_n = pl.ds(n_index * tile_n, tile_n)
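# Alternate between two accumulator (TMEM) slots so the MMAs for the next output tile
# can overlap with the store of the previous one.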
acc_slot = lax.rem(local_index, jnp.int32(2))
@pl.when(wg_idx == COMPUTE_WG)
def _():
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
def _per_warp():
warp_id = lax.axis_index("warp")
@pl.when(warp_id == TMA_WARP)
def _memory():
def _loop_body(ki, _):
slice_k = pl.ds(ki * tile_k, tile_k)
slot = lax.rem(ki, max_concurrent_steps)
@pl.when(jnp.logical_or(ki >= max_concurrent_steps,
local_index > 0))
def _():
plgpu.barrier_wait(consumed_barrier.at[slot])
plgpu.copy_gmem_to_smem(
a_gmem.at[slice_m, slice_k],
a_smem.at[slot],
ab_tma_barrier.at[slot],
leader_tracked=plgpu.CopyPartition.PARTITIONED(0)
if collective
else None,
collective_axes="x" if collective else None,
)
plgpu.copy_gmem_to_smem(
b_gmem.at[slice_k, slice_n],
b_smem.at[slot],
ab_tma_barrier.at[slot],
leader_tracked=plgpu.CopyPartition.PARTITIONED(1)
if collective
else None,
collective_axes="x" if collective else None,
)
lax.fori_loop(0, k_iters, _loop_body, None)
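# With two accumulator slots, we only need to wait for the store warpgroup to release
# a slot once we are two output tiles in (local_index > 1).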
@pl.when(jnp.logical_and(warp_id == MMA_WARP, local_index > 1))
def _wait_store():
plgpu.barrier_wait(store_done_barrier.at[acc_slot])
@pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block))
def _compute():
def _loop_body(ki, _):
slot = lax.rem(ki, max_concurrent_steps)
plgpu.barrier_wait(ab_tma_barrier.at[slot])
is_last_iter = ki >= k_iters - 1
acc_tmem_slice = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)]
plgpu.tcgen05_mma(
acc_tmem_slice,
a_smem.at[slot],
b_smem.at[slot],
consumed_barrier.at[slot],
accumulate=(ki > 0),
collective_axis="x" if collective else None,
)
@pl.when(is_last_iter)
def _():
plgpu.tcgen05_commit_arrive(
mma_done_barrier.at[acc_slot],
collective_axis="x" if collective else None,
)
lax.fori_loop(0, k_iters, _loop_body, None)
@pl.when(wg_idx == STORE_WG)
def _():
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
plgpu.barrier_wait(mma_done_barrier.at[acc_slot])
acc_tmem_slot = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)]
step_out_gmem = out_gmem.at[block_slice_m, slice_n]
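# Stream the accumulator out of TMEM in epilogue_tile_n-wide chunks, double-buffering
# the staging SMEM (ni % 2) so TMEM loads overlap with SMEM->GMEM copies.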
for ni in range(tile_n // epilogue_tile_n):
acc_smem_ni = acc_smem.at[ni % 2]
ni_col_slice = pl.ds(ni * epilogue_tile_n, epilogue_tile_n)
acc_smem_ni[...] = plgpu.async_load_tmem(
acc_tmem_slot.at[:, ni_col_slice]
).astype(dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(acc_smem_ni, step_out_gmem.at[:, ni_col_slice])
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
plgpu.wait_load_tmem() # Load must complete before we continue.
plgpu.barrier_arrive(store_done_barrier.at[acc_slot])
if collective:
store_done_barrier = plgpu.ClusterBarrier(
collective_axes=("x",),
num_arrivals=1,
num_barriers=2,
orders_tensor_core=True,
)
else:
store_done_barrier = plgpu.Barrier(
num_arrivals=1, num_barriers=2, orders_tensor_core=True
)
f = plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), dtype),
grid=(m_iters * n_iters,),
grid_names=("mn_linear",),
num_threads=2,
thread_name="wg",
cluster_names=("x",),
cluster=(1 + collective,),
scratch_shapes=dict(
a_smem=plgpu.SMEM(
(max_concurrent_steps, block_tile_m, tile_k),
dtype,
transforms=transforms,
),
b_smem=plgpu.SMEM(
(max_concurrent_steps, tile_k, block_tile_n),
dtype,
transforms=transforms,
),
acc_tmem=plgpu.TMEM(
(block_tile_m, tile_n * 2), jnp.float32, collective=collective
),
acc_smem=plgpu.SMEM(
(2, block_tile_m, epilogue_tile_n),
dtype,
transforms=out_transforms,
),
ab_tma_barrier=plgpu.Barrier(
num_arrivals=2, num_barriers=max_concurrent_steps
),
store_done_barrier=store_done_barrier,
mma_done_barrier=plgpu.Barrier(
num_arrivals=1, num_barriers=2, orders_tensor_core=True
),
consumed_barrier=plgpu.Barrier(
num_arrivals=1,
num_barriers=max_concurrent_steps,
orders_tensor_core=True,
),
),
)
return f(a, b)
def main(_) -> None:
problem_it = [(4096, 8192, 4096)]
for M, N, K in problem_it:
print(f"==== {M=} {N=} {K=} ====")
matmul_flops = 2 * M * N * K
peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS
a = jax.random.uniform(jax.random.key(1), (M, K), jnp.float16, -1, 1)
b = jax.random.uniform(jax.random.key(2), (K, N), jnp.float16, -1, 1)
tuning_it = itertools.product(
(128,), # tile_m
(128, 256), # tile_n
(64,), # tile_k
MatmulDimension, # grid_minor_dim
(1, 4, 8, 12, 16), # grid_tile_width
(2, 4, 6), # max_concurrent_steps
(False, True), # collective
(32,), # epilogue_tile_n
)
best_util = -float("inf")
expected = jnp.dot(a, b, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32)
for (tile_m, tile_n, tile_k, grid_minor_dim, grid_tile_width,
max_concurrent_steps, collective, epilogue_tile_n) in tuning_it:
# Only tile_n <= 128 is supported for collective MMAs
if collective and tile_n > 128:
continue
config = TuningConfig(
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
max_concurrent_steps=max_concurrent_steps,
collective=collective,
epilogue_tile_n=epilogue_tile_n,
grid_minor_dim=grid_minor_dim,
grid_tile_width=grid_tile_width,
)
if collective:
tile_m *= 2
tile_n *= 2
try:
out, runtimes_ms = profiler.measure(
functools.partial(matmul_kernel, config=config), iterations=10
)(a, b)
assert runtimes_ms is not None
runtime_ms = statistics.median(runtimes_ms)
except ValueError as e:
if ("exceeds available shared memory" in e.args[0] or
"Accumulator layout mismatch:" in e.args[0]):
# Accumulator layout mismatch triggers for tile_n=256 on some configs.
continue
raise
runtime_us = runtime_ms * 1e3
optimal_time = matmul_flops / peak_flops * 1e6 # us
achieved_tc_util = optimal_time / runtime_us * 100
if achieved_tc_util > best_util:
np.testing.assert_allclose(out, expected)
best_util = achieved_tc_util
print(
f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} "
f"{grid_minor_dim=} {grid_tile_width=} "
f"{epilogue_tile_n=} "
f"{collective=} : "
f"{runtime_us:<7.1f}us"
f" = {achieved_tc_util:4.1f}% TC utilization"
)
print(f"\tBest utilization: {best_util:4.1f}%")
_, runtimes_ms = profiler.measure(
functools.partial(
jnp.dot, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32
),
iterations=10,
)(a, b)
assert runtimes_ms is not None
runtime_ms = statistics.median(runtimes_ms)
runtime_us = runtime_ms * 1e3
optimal_time = matmul_flops / peak_flops * 1e6 # us
achieved_tc_util = optimal_time / runtime_us * 100
print(f"\tReference: {achieved_tc_util:4.1f}%")
if __name__ == "__main__":
from absl import app
jax.config.config_with_absl()
app.run(main)
@@ -0,0 +1,441 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ragged/Grouped Matrix Multiplication kernel for Blackwell GPUs."""
import dataclasses
import functools
import itertools
import math
import jax
from jax import lax
from jax._src import test_util as jtu # noqa: F401
from jax.experimental.mosaic.gpu import profiler
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
from jax.experimental.pallas.ops.gpu import blackwell_matmul_mgpu
from jax.experimental.pallas.ops.gpu import ragged_dot_mgpu
import jax.numpy as jnp
import numpy as np
from typing import Sequence
@dataclasses.dataclass(frozen=True)
class TuningConfig:
tile_m: int
tile_n: int
tile_k: int
max_concurrent_steps: int
collective: bool
grid_tile_width: int
grid_minor_dim: blackwell_matmul_mgpu.MatmulDimension
epilogue_tile_n: int = 64
def __str__(self):
return "_".join(f"{k}={v}" for k, v in dataclasses.asdict(self).items())
# TODO(justinfu): Merge with blackwell_matmul_mgpu.py
def do_matmul(a_gmem,
b_gmem,
out_gmem,
grid_indices: Sequence[jax.Array],
wg_axis: str,
collective_axes: tuple[str, ...],
local_index: jax.Array | int,
config: TuningConfig,
group_info: ragged_dot_mgpu.GroupInfo,
a_smem, b_smem, acc_tmem, acc_smem,
a_tma_barrier, b_tma_barrier, store_done_barrier, mma_done_barrier,
consumed_barrier
):
"""Compute a non-ragged matmul for a single output block."""
dtype = out_gmem.dtype
m, k = a_gmem.shape
collective = config.collective
tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k)
epilogue_tile_n = config.epilogue_tile_n
max_concurrent_steps = config.max_concurrent_steps
block_tile_m = tile_m
if collective:
tile_m *= 2
tile_n *= 2
k_iters = k // tile_k
if collective:
m_index, n_index, cluster_idx = grid_indices
block_m_index = m_index * 2 + cluster_idx
is_lead_block = cluster_idx == 0
else:
m_index, n_index = grid_indices
cluster_idx = 0
block_m_index = m_index
is_lead_block = True
wg_idx = lax.axis_index(wg_axis)
collective_axis = collective_axes[0] if collective else None
TMA_WARP = 0
MMA_WARP = 1
COMPUTE_WG = 0
STORE_WG = 1
block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m)
slice_m = pl.ds(m_index * tile_m, tile_m)
slice_n = pl.ds(n_index * tile_n, tile_n)
acc_slot = lax.rem(local_index, jnp.int32(2))
regs_layout = plgpu.Layout.TCGEN05
@pl.when(wg_idx == COMPUTE_WG)
@jax.named_scope("compute_wg")
def _():
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
def _per_warp():
warp_id = lax.axis_index("warp")
@pl.when(warp_id == TMA_WARP)
def _memory():
def _loop_body(ki, _):
slice_k = pl.ds(ki * tile_k, tile_k)
slot = lax.rem(ki, max_concurrent_steps)
@pl.when(jnp.logical_or(ki >= max_concurrent_steps,
local_index > 0))
def _():
plgpu.barrier_wait(consumed_barrier.at[slot])
plgpu.copy_gmem_to_smem(
a_gmem.at[slice_m, slice_k],
a_smem.at[slot],
a_tma_barrier.at[slot],
leader_tracked=plgpu.CopyPartition.PARTITIONED(0)
if collective
else None,
collective_axes=collective_axis,
)
plgpu.copy_gmem_to_smem(
b_gmem.at[slice_k, slice_n],
b_smem.at[slot],
b_tma_barrier.at[slot],
leader_tracked=plgpu.CopyPartition.PARTITIONED(1)
if collective
else None,
collective_axes=collective_axis,
)
lax.fori_loop(0, k_iters, _loop_body, None)
@pl.when(jnp.logical_and(warp_id == MMA_WARP, local_index > 1))
def _wait_store():
plgpu.barrier_wait(store_done_barrier.at[acc_slot])
@pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block))
def _compute():
def _loop_body(ki, _):
slot = lax.rem(ki, max_concurrent_steps)
plgpu.barrier_wait(a_tma_barrier.at[slot])
plgpu.barrier_wait(b_tma_barrier.at[slot])
is_last_iter = ki >= k_iters - 1
acc_tmem_slice = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)]
plgpu.tcgen05_mma(
acc_tmem_slice,
a_smem.at[slot],
b_smem.at[slot],
consumed_barrier.at[slot],
accumulate=(ki > 0),
collective_axis=collective_axis,
)
@pl.when(is_last_iter)
def _():
plgpu.tcgen05_commit_arrive(
mma_done_barrier.at[acc_slot],
collective_axis=collective_axis,
)
lax.fori_loop(0, k_iters, _loop_body, None)
@pl.when(wg_idx == STORE_WG)
@jax.named_scope("store_wg")
def _():
plgpu.barrier_wait(mma_done_barrier.at[acc_slot])
acc_tmem_slot = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)]
step_out_gmem = out_gmem.at[block_slice_m, slice_n]
# group_info contains start/size info relative to the logical tiling
# (tile_m), but collective matmuls use 2 CTAs per logical block, so we
# need to compute the start/size relative to the current block.
# For example, for the following parameters:
# block_tile_m=64 (tile_m=128)
# group_info.start_within_block=60
# group_info.actual_size=37
# The requested copy will be split across both blocks
# Memory: | Block 0 | Block 1 |
# |--- 64 ---|--- 64 ---|
# Copy: |-- 37 --|
# Where block 0 copies rows 60-64 (4 rows total) and block 1 copies
# the remaining rows 64-97 (33 rows total).
smem_start = group_info.start_within_block - cluster_idx * block_tile_m
smem_start = lax.max(smem_start, jnp.int32(0))
def _clamp(lo, x, hi):
return lax.max(lax.min(x, hi), lo)
block0_copy_size = _clamp(
jnp.int32(0),
block_tile_m - group_info.start_within_block,
group_info.actual_size)
block_local_size = lax.select(is_lead_block,
# block 0 copies up to end of the first block or actual_size,
# whichever comes first.
block0_copy_size,
# block 1 copies the remaining rows that block 0 did not copy.
group_info.actual_size - block0_copy_size
)
for ni in range(tile_n // epilogue_tile_n):
acc_smem[...] = plgpu.async_load_tmem(
acc_tmem_slot.at[:, pl.ds(ni * epilogue_tile_n, epilogue_tile_n)],
layout=regs_layout).astype(dtype)
plgpu.commit_smem()
cur_smem_idx = smem_start
remaining_rows = min(block_tile_m, m)
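# SMEM->GMEM copies need static sizes, so we decompose the dynamic row count into
# power-of-two chunks and issue each chunk under a pl.when guard.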
while remaining_rows > 0:
const_rows_len = 1 << int(math.log2(remaining_rows))
remaining_rows //= 2
@pl.when(block_local_size & const_rows_len != 0)
def _():
o_smem_slice = acc_smem.at[pl.ds(cur_smem_idx, const_rows_len)]
o_gref_slice = step_out_gmem.at[
pl.ds(cur_smem_idx, const_rows_len),
pl.ds(ni * epilogue_tile_n, epilogue_tile_n),
]
plgpu.copy_smem_to_gmem(o_smem_slice, o_gref_slice)
cur_smem_idx += block_local_size & const_rows_len
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
plgpu.wait_load_tmem() # Load must complete before we continue.
plgpu.barrier_arrive(store_done_barrier.at[acc_slot])
def ragged_dot_kernel(a, b, group_sizes, config: TuningConfig):
dtype = a.dtype
if a.dtype != b.dtype:
raise ValueError(
f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}"
)
m, k = a.shape
num_groups, k2, n = b.shape
if num_groups != group_sizes.shape[0]:
raise ValueError("RHS and group_sizes have incompatible shapes.")
if k != k2:
raise ValueError(
"Matmul LHS and RHS have incompatible shapes "
f"{a.shape} vs {b.shape[1:]}"
)
collective = config.collective
tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k)
block_tile_m = tile_m
block_tile_n = tile_n
if collective:
tile_m *= 2
tile_n *= 2
m_iters = m // tile_m
n_iters = n // tile_n
max_concurrent_steps = config.max_concurrent_steps
epilogue_tile_n = config.epilogue_tile_n
if tile_n % epilogue_tile_n != 0:
raise ValueError(
f"{tile_n=} must be divisible by {epilogue_tile_n=}"
)
if m % tile_m != 0:
raise ValueError(f"{m=} must be divisible by {tile_m=}")
if n % tile_n != 0:
raise ValueError(f"{n=} must be divisible by {tile_n=}")
if k % tile_k != 0:
raise ValueError(f"{k=} must be divisible by {tile_k=}")
swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
transforms = (
plgpu.TilingTransform((8, swizzle_elems)),
plgpu.SwizzleTransform(swizzle),
)
def kernel(a_gmem, b_gmem, group_sizes_gmem, out_gmem):
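# Group boundaries need not be aligned to tile_m, so each group may add one extra
# partial row tile; hence m_iters + num_groups - 1 row tiles in the linearized grid.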
linear_grid = (m_iters + num_groups - 1) * n_iters
group_sizes_regs = [group_sizes_gmem[i] for i in range(num_groups)]
cluster_idx = lax.axis_index("x")
@functools.partial(pl.run_scoped,
a_smem=plgpu.SMEM(
(max_concurrent_steps, block_tile_m, tile_k),
dtype, transforms=transforms
),
b_smem=plgpu.SMEM(
(max_concurrent_steps, tile_k, block_tile_n),
dtype, transforms=transforms
),
# Temporary SMEM used for storing accumulator output to GMEM.
acc_smem=plgpu.SMEM(
(block_tile_m, epilogue_tile_n), dtype),
# a/b_tma_barrier
a_tma_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps),
b_tma_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps),
# store_done_barrier, double-buffered
store_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=2,
orders_tensor_core=True),
# mma_done_barrier, double-buffered
mma_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=2,
orders_tensor_core=True),
# consumed_barrier
consumed_barrier=plgpu.Barrier(
num_arrivals=1,
num_barriers=max_concurrent_steps,
orders_tensor_core=True,
),
# Accumulator TMEM (double-buffered)
acc_tmem=plgpu.TMEM(
(block_tile_m, tile_n * 2), jnp.float32, collective=collective),
collective_axes=("wg",)
)
def _scoped(**ref_kwargs):
@plgpu.nd_loop(grid=(linear_grid,),
collective_axes="sm")
def mn_loop(loop_info: plgpu.NDLoopInfo):
linear_idx, = loop_info.index
local_index = loop_info.local_index
m_index, n_index = plgpu.planar_snake(
linear_idx,
(m_iters + num_groups - 1, n_iters),
config.grid_minor_dim,
config.grid_tile_width,
)
with jax.named_scope("create_group_info"):
group_info = ragged_dot_mgpu.GroupInfo.create(
group_sizes_regs, tile_m, m_index
)
do_matmul(
a_gmem,
b_gmem.at[group_info.group_id],
out_gmem,
grid_indices=(group_info.block, n_index, cluster_idx),
wg_axis="wg",
collective_axes=("x",) if collective else (),
local_index=local_index,
config=config,
group_info=group_info,
**ref_kwargs
)
num_sms = jax.local_devices()[0].core_count
compiler_params = None
f = plgpu.kernel(
kernel,
compiler_params=compiler_params,
kernel_name=f"ragged_dot_kernel_{str(config)}",
out_shape=jax.ShapeDtypeStruct((m, n), dtype),
grid=(num_sms//2,) if collective else (num_sms,),
grid_names=("sm",),
num_threads=2,
thread_name="wg",
cluster_names=("x",) if collective else (),
cluster=(2,) if collective else (),
)
return f(a, b, group_sizes)
def ragged_dot_reference(a, b, g):
return lax.ragged_dot(a, b, g, preferred_element_type=jnp.float16)
def sample_group_sizes(key: jax.Array,
num_groups: int,
num_elements: int,
alpha: float = 10.0,
):
"""Sample group sizes.
Args:
key: PRNG key.
num_groups: Number of groups to sample.
num_elements: Total number of elements to sample.
alpha: Shape parameter. The lower the alpha, the more imbalanced the
group sizes will be. As alpha approaches infinity, the group sizes
approach a uniform distribution.
Returns:
A jax.Array of shape (num_groups,) that sums to num_elements.
"""
probs_key, sample_key = jax.random.split(key)
probs = jax.random.dirichlet(probs_key, jnp.ones((num_groups,)) * alpha)
return jax.random.multinomial(
sample_key, num_elements, probs).astype(jnp.int32)
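# For example (illustrative values only), sample_group_sizes(jax.random.key(0), 4, 100)
# returns a length-4 int32 vector that sums to 100, e.g. something like [30, 22, 27, 21].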
def main(_) -> None:
M = 16 * 1024
K = 2048
N = 16 * 1024
num_groups = 16
group_sizes = sample_group_sizes(jax.random.key(0), num_groups, M, alpha=10.0)
print(f"==== {M=} {N=} {K=} {num_groups=}====")
matmul_flops = 2 * M * N * K
peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS
a = jax.random.uniform(jax.random.key(1), (M, K), jnp.float16)
b = jax.random.uniform(jax.random.key(2), (num_groups, K, N), jnp.float16)
tuning_it = itertools.product(
(128,), # tile_m
(128,), # tile_n
(64,), # tile_k
(1, 8, 12, 16), # grid_tile_width
blackwell_matmul_mgpu.MatmulDimension, # grid_minor_dim
(4, 6) # max_concurrent_steps
)
best_util = -float("inf")
for (tile_m, tile_n, tile_k, grid_tile_width, grid_minor_dim,
max_concurrent_steps,) in tuning_it:
config = TuningConfig(
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
grid_tile_width=grid_tile_width,
grid_minor_dim=grid_minor_dim,
max_concurrent_steps=max_concurrent_steps,
collective=True,
)
try:
out, runtime_ms = profiler.measure(
functools.partial(ragged_dot_kernel, config=config),
iterations=10
)(a, b, group_sizes)
assert runtime_ms is not None
runtime_ms = np.median(runtime_ms)
except ValueError as e:
if ("exceeds available shared memory" in e.args[0] or
"Accumulator layout mismatch:" in e.args[0]):
print(e.args[0])
continue
raise
expected = ragged_dot_reference(a, b, group_sizes)
np.testing.assert_allclose(out, expected)
runtime_us = runtime_ms * 1e3
optimal_time = matmul_flops / peak_flops * 1e6 # us
achieved_tc_util = optimal_time / runtime_us * 100
if achieved_tc_util > best_util:
best_util = achieved_tc_util
print(
f"{tile_m=} {tile_n=} {tile_k=} {grid_tile_width=} {grid_minor_dim=} {max_concurrent_steps=} "
f"{runtime_us:<7.1f}us"
f" = {achieved_tc_util:4.1f}% TC utilization"
)
print(f"\tBest utilization: {best_util:4.1f}%")
if __name__ == "__main__":
from absl import app
jax.config.config_with_absl()
app.run(main)
@@ -0,0 +1,290 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collective matmul kernel implemented using Mosaic GPU."""
import functools
import itertools
import jax
import os
from jax import lax
from jax.experimental import multihost_utils
from jax.experimental import pallas as pl
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.experimental.pallas.ops.gpu import hopper_matmul_mgpu
import jax.numpy as jnp
MatmulDimension = hopper_matmul_mgpu.MatmulDimension
TuningConfig = hopper_matmul_mgpu.TuningConfig
def is_nvshmem_used() -> bool:
return (
"XLA_FLAGS" in os.environ
and "--xla_gpu_experimental_enable_nvshmem" in os.environ["XLA_FLAGS"]
)
def all_gather_lhs_matmul(
lhs: jax.Array,
rhs: jax.Array,
axis_name,
*,
config: hopper_matmul_mgpu.TuningConfig,
dtype: jnp.dtype = jnp.float16,
) -> jax.Array:
if (
num_devices := jax.device_count()
) != jax.process_count() and num_devices != jax.local_device_count():
raise ValueError(
"Kernel requires either 1 process per single GPU or 1 process per all"
" GPUs."
)
if (axis_size := lax.axis_size(axis_name)) != num_devices:
raise ValueError("The kernel can only work over all devices in a Mesh.")
if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
raise NotImplementedError(f"Only f16 and bf16 are supported, got {dtype=}")
if config.cluster_dimension is not None:
raise NotImplementedError("Cluster dimension must be None for all-gather matmuls.")
m_shard, k = lhs.shape
k2, n_shard = rhs.shape
if k != k2:
raise ValueError(
f"lhs and rhs must have the same contraction size, got {k} and {k2}."
)
if (element_type := lhs.dtype) != rhs.dtype:
raise ValueError(
f"lhs and rhs must have the same element type, got {element_type} and"
f" {rhs.dtype}."
)
tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
max_concurrent_steps = config.max_concurrent_steps
if max_concurrent_steps < 2:
raise ValueError("max_concurrent_steps must be >= 2")
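# Each CTA runs two compute warpgroups, so the per-CTA tile is twice the per-WG tile
# along wg_dimension.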
cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M))
epi_tile_n = config.epi_tile_n or tile_n
epi_tile_m = config.epi_tile_m or tile_m
if tile_n % epi_tile_n != 0:
raise ValueError(f"{tile_n=} must be divisible by {epi_tile_n=}")
if tile_m % epi_tile_m != 0:
raise ValueError(f"{tile_m=} must be divisible by {epi_tile_m=}")
num_sms = jax.devices()[0].core_count # 132 for H100 SXM GPUs.
def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref):
received_sem = pl.get_global(plgpu.SemaphoreType.REGULAR)
wg_idx = lax.axis_index("wg")
dev_id = lax.axis_index(axis_name)
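# Ring all-gather: each step forwards the lhs shard to the previous device in the ring,
# i.e. (dev_id - 1) mod axis_size.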
send_dev_id = lax.rem(dev_id + axis_size - 1, jnp.int32(axis_size))
send_scratch_ref = plgpu.remote_ref(scratch_ref, send_dev_id)
def send_lhs(m_idx, n_idx, k_idx, a_smem, b_smem, send_ref, should_send):
del b_smem # Unused.
# We only send when n_idx == 0 to avoid sending the same data
# multiple times when revisiting lhs.
@pl.when(should_send & jnp.bool(n_idx == 0))
def _():
k_slice = pl.ds(k_idx * tile_k, tile_k)
m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m)
plgpu.copy_smem_to_gmem(a_smem, send_ref.at[m_slice, k_slice])
# We only delay release by 1 step, so we need to wait for the
# previous copies.
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
def device_step(lhs_source_ref, device_offset):
# Invariant: lhs_source_ref is ready to be used
next_scratch_slot = device_offset
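# The lhs shard processed at this step originally came from device
# (device_offset + dev_id) % num_devices, so write to that M slice of the gathered output.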
out_device_m_slice = pl.ds(
lax.rem(device_offset + dev_id, jnp.int32(num_devices)) * m_shard,
m_shard,
)
is_send_wg = wg_idx == 0
has_send_space = next_scratch_slot < num_devices - 1
should_send = is_send_wg & has_send_space
# This reuses the regular matmul kernel, only with the exception of
# inserting send_lhs into the pipeline.
# TODO(apaszke): This contains run_scoped inside, meaning that it will
# synchronize all threads at each device step. If we optimize the barrier
# below, then it might be better to move it out to make bubbles smaller.
hopper_matmul_mgpu.kernel(
lhs_source_ref, # Use the lhs from previous step.
rhs_ref, # Use the same rhs for all steps.
None, # No C.
out_ref.at[out_device_m_slice], # Use a slice of the output.
config=config,
pipeline_callback=functools.partial(
send_lhs,
send_ref=send_scratch_ref.at[next_scratch_slot],
should_send=should_send,
),
delay_release=1,
)
# Wait for the next scratch to arrive --- see the device loop invariant.
@pl.when(should_send)
def _signal():
# TODO(apaszke): We could do this signal a lot earlier if we better
# control the order of sends. If we tile the grid along N, then we can
# signal as soon as everyone moves on from the first column tile.
# Make sure the copy is done and signal the receiving device.
plgpu.wait_smem_to_gmem(0, wait_read_only=False)
pl.semaphore_signal(received_sem, device_id=send_dev_id)
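# The wait below blocks until all num_sms CTAs of the neighbor that sends to us have
# signaled for this step; decrement=False keeps the count cumulative across device steps.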
@pl.when(next_scratch_slot < num_devices - 1)
def _wait():
pl.semaphore_wait(received_sem, value=(device_offset + 1) * num_sms, decrement=False)
# We peel the first step to copy data directly from lhs_local_ref.
device_step(lhs_local_ref, 0)
@pl.loop(1, num_devices)
def _device_loop(device_offset):
device_step(scratch_ref.at[device_offset - 1], device_offset)
# Make sure all copies are fully done.
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
result, _ = plgpu.kernel(
kernel_body,
out_shape=[
# The output, with its M dimension all-gathered.
jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
# The scratch buffer used for the all-gather.
jax.ShapeDtypeStruct((num_devices - 1, m_shard, k), dtype),
],
grid=(num_sms,),
grid_names=("cluster_grid",),
num_threads=3,
thread_name="wg",
cluster=(1,),
cluster_names=("cluster",),
)(lhs, rhs)
return result
def _min_results_across_devices(kernels_ms: list[tuple[str, float]]) -> float:
# We choose the minimum across processes to choose the runtime that didn't
# include devices waiting for other devices.
if is_nvshmem_used():
time_us = sum(t * 1e3 for _, t in kernels_ms)
return min(multihost_utils.process_allgather(time_us).tolist())
# profiler.measure measures all devices visible to the process, so we
# need to select the minimum result of each kernel across devices.
# This code relies on the fact that with collective metadata a single kernel
# with a unique name is launched on each device.
min_values: dict[str, float] = {}
for kernel_name, t in kernels_ms:
if kernel_name not in min_values or t < min_values[kernel_name]:
min_values[kernel_name] = t
return sum(time_ms * 1e3 for time_ms in min_values.values())
def _run_example():
P = jax.sharding.PartitionSpec
m_shard = 1024
n_shard = 4096
k = 4096
dtype = jnp.bfloat16
shards = jax.device_count()
mesh = jax.make_mesh(
(shards,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
)
jax.set_mesh(mesh)
# We measure time per-shard and so we only need FLOPs per shard.
matmul_flops = 2 * (shards * m_shard) * n_shard * k
peak_flops = 990e12 # f16 TensorCore peak = 990 TFLOPS
optimal_time = matmul_flops / peak_flops * 1e6 # us
a = jax.random.normal(jax.random.key(1), (shards * m_shard, k), dtype)
b = jax.random.normal(jax.random.key(2), (k, shards * n_shard), dtype)
a = jax.sharding.reshard(a, P("x", None))
b = jax.sharding.reshard(b, P(None, "x"))
_, ref_kernels_ms = profiler.measure(jax.jit(
jax.shard_map(
lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y,
out_specs=P(None, "x"),
check_vma=False,
)
), aggregate=False)(a, b)
assert ref_kernels_ms is not None
ref_time_us = _min_results_across_devices(ref_kernels_ms)
ref_util = optimal_time / ref_time_us * 100
tuning_it = itertools.product(
(128, 256,), # tile_m
(64, 128), # tile_n
(64,), # tile_k
(4,), # max_concurrent_steps
(MatmulDimension.M, MatmulDimension.N), # grid_minor_dim
(4, 8, 16), # grid_tile_width
MatmulDimension, # wg_dimension
)
best_util = 0.0
best_runtime = float("inf")
def build_kernel(**kwargs):
return jax.jit(
jax.shard_map(
functools.partial(all_gather_lhs_matmul, **kwargs),
out_specs=P(None, "x"),
check_vma=False,
)
)
for tile_m, tile_n, tile_k, max_concurrent_steps, grid_minor_dim, grid_tile_width, wg_dimension in tuning_it:
try:
config = TuningConfig(
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
max_concurrent_steps=max_concurrent_steps,
grid_minor_dim=grid_minor_dim,
grid_tile_width=grid_tile_width,
wg_dimension=wg_dimension,
)
_, kernels_ms = profiler.measure(
build_kernel(axis_name="x", config=config, dtype=dtype),
aggregate=False,
)(a, b)
except ValueError as e:
if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs.
continue
raise
assert kernels_ms is not None
runtime_us = _min_results_across_devices(kernels_ms)
achieved_tc_util = optimal_time / runtime_us * 100
if achieved_tc_util > best_util:
best_runtime = runtime_us
best_util = achieved_tc_util
print(
f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} {grid_minor_dim=} {grid_tile_width=} {wg_dimension=}: "
f"{runtime_us:<7.1f}us"
f" = {achieved_tc_util:4.1f}% TC utilization"
)
print(f"\tBest: {best_runtime:<7.1f}us = {best_util:4.1f}% TC utilization")
print(f"\tRef: {ref_time_us:<7.1f}us = {ref_util:4.1f}% TC utilization")
if __name__ == "__main__":
if is_nvshmem_used():
from jax._src import test_multiprocess as jt_multiprocess # pytype: disable=import-error
jt_multiprocess.main(shard_main=_run_example)
else:
from jax._src.config import config as jax_config
from absl import app
jax_config.config_with_absl()
app.run(lambda _: _run_example())
@@ -0,0 +1,503 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing decode attention."""
from __future__ import annotations
import math
import functools
from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
import jax.numpy as jnp
def attn_forward_kernel(
# inputs
q_ref, # [num_heads, head_dim]
k_ref, # [k_seq_len, head_dim]
v_ref, # [k_seq_len, head_dim]
start_idx_ref, # [] (i.e., scalar)
kv_seq_len_ref, # [] (i.e., scalar)
# outputs
o_ref: Any, # [num_heads, head_dim]
*residual_refs: Any, # Residual outputs: [num_heads,], [num_heads,]
sm_scale: float,
block_k: int,
block_h: int,
num_heads: int,
):
_, head_dim = q_ref.shape
split_k_seq_len, _ = k_ref.shape
prog_i, prog_j = pl.program_id(0), pl.program_id(1)
q_slice = pl.ds(0, block_h)
q_mask = (jnp.arange(block_h) < num_heads - block_h * prog_i)[:, None]
def _compute(start_idx, kv_seq_len, o, m_i, l_i):
# Load q: it will stay in L1 throughout. Indices form a matrix because we
# read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
# q tile has shape [block_h, head_dim].
q = plgpu.load(q_ref.at[q_slice, :], mask=q_mask)
def _dot(a, b):
# if a.shape[0] == 1:
# # Use matrix vector product
# return (a.T * b).sum(axis=0, keepdims=True)
return pl.dot(a, b)
mask_indices = jnp.arange(block_k)
# Loop over blocks of kv to process entire kv seq_len.
# Grid loops over q blocks over num_heads.
def body(start_k, carry):
o_prev, m_prev, l_prev = carry
curr_k_slice = pl.ds(start_k * block_k, block_k)
k = k_ref[curr_k_slice, :]
qk = _dot(q, k.T) # [block_h, block_k]
if sm_scale != 1.0:
qk *= sm_scale # [block_h, block_k]
# apply mask if start or sequence length is specified
if start_idx_ref is not None or kv_seq_len_ref is not None:
indices = (prog_j * split_k_seq_len + start_k * block_k + mask_indices)
mask = ((indices >= start_idx) & (indices < kv_seq_len))[None, :]
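# Add a large negative (but finite) bias to masked logits; a finite value instead of
# -inf avoids NaNs when an entire row is masked out.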
qk += (~mask) * (0.7 * jnp.finfo(qk.dtype).min)
m_curr = qk.max(axis=-1)
m_next = jnp.maximum(m_prev, m_curr)
correction = jnp.exp(m_prev - m_next)
l_prev_corr = correction * l_prev
s_curr = jnp.exp(
qk - m_next[:, None]
) # Use m_next instead of m_curr to avoid a correction on l_curr
l_curr = s_curr.sum(axis=-1)
l_next = l_prev_corr + l_curr
v = v_ref[curr_k_slice, :]
o_curr = _dot(s_curr.astype(v.dtype), v)
# flash2 unscaled_o
o_next = correction[:, None] * o_prev + o_curr
return o_next, m_next, l_next
max_it = jnp.minimum(pl.cdiv((kv_seq_len - prog_j * split_k_seq_len),
block_k), split_k_seq_len // block_k)
(o, m_i, l_i) = lax.fori_loop(0, max_it, body, (o, m_i, l_i))
return o, m_i, l_i
# o is the buffer where we accumulate the output on sram.
# m_i and l_i (see FlashAttention2 paper) are updated during the k,v loop.
m_i = jnp.zeros(block_h, dtype=jnp.float32) + jnp.finfo(jnp.float32).min
l_i = jnp.zeros(block_h, dtype=jnp.float32)
o = jnp.zeros((block_h, head_dim), dtype=jnp.float32)
start_idx = split_k_seq_len * prog_j
if start_idx_ref is not None:
start_idx = jnp.maximum(start_idx, start_idx_ref[()])
kv_seq_len = (prog_j + 1) * split_k_seq_len # lower bound on actual k_seq_len
if kv_seq_len_ref is not None:
kv_seq_len = jnp.minimum(kv_seq_len, kv_seq_len_ref[()])
if start_idx_ref is None and kv_seq_len_ref is None:
o, m_i, l_i = _compute(start_idx, kv_seq_len, o, m_i, l_i)
else:
o, m_i, l_i = jax.lax.cond(
start_idx >= kv_seq_len, lambda: (o, m_i, l_i),
lambda: _compute(start_idx, kv_seq_len, o, m_i, l_i))
# Write output to dram.
if residual_refs:
l_ref, m_ref = residual_refs
vec_q_mask = q_mask.reshape(-1) if q_mask is not None else None
plgpu.store(l_ref.at[q_slice], l_i, mask=vec_q_mask)
plgpu.store(m_ref.at[q_slice], m_i, mask=vec_q_mask)
o = o.astype(o_ref.dtype)
plgpu.store(o_ref.at[q_slice, :], o, mask=q_mask)
def decode_attn_unbatched(
q, # [num_heads, head_dim]
k, # [k_seq_len, head_dim]
v, # [k_seq_len, head_dim]
start_idx, # []
kv_seq_len, # []
sm_scale: float,
block_h: int,
block_k: int,
k_splits: int,
num_warps: int | None,
num_stages: int,
grid: tuple[int, ...] | None,
interpret: bool,
debug: bool,
return_residuals: bool,
normalize_output: bool
):
num_heads, head_dim = q.shape
k_seq_len, _ = k.shape
# Pad num query heads to 16 if needed, and slice output at the end.
head_splits = pl.cdiv(num_heads, block_h)
grid_ = grid
if grid_ is None:
grid_ = (head_splits, k_splits)
assert (
k_seq_len % k_splits == 0
), f"{k_seq_len=} must be divisible by {k_splits=}"
assert k_seq_len // k_splits >= 16, (
f"{k_seq_len=} divided by {k_splits=} must be >= 16.")
assert block_k >= 16, "block_k must be >= 16"
k = k.reshape(k_splits, k_seq_len // k_splits, head_dim)
v = v.reshape(k_splits, k_seq_len // k_splits, head_dim)
split_k_seq_len = k_seq_len // k_splits
block_k = min(block_k, split_k_seq_len)
assert split_k_seq_len % block_k == 0, (
f"Sequence length ({k_seq_len=}) split by {k_splits=} must be divisible by"
f" {block_k=}")
num_warps_ = num_warps
if num_warps_ is None:
num_warps_ = 4
kernel = functools.partial(
attn_forward_kernel,
sm_scale=sm_scale,
block_k=block_k,
block_h=block_h,
num_heads=num_heads,
)
o, l, m = pl.pallas_call(
kernel,
grid=grid_,
in_specs=[
pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)),
pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)),
pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)),
]
+ [None if start_idx is None else pl.BlockSpec((), lambda i, j: ())]
+ [None if kv_seq_len is None else pl.BlockSpec((), lambda i, j: ())],
out_specs=[
pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m
],
compiler_params=plgpu.CompilerParams(
num_warps=num_warps_, num_stages=num_stages
),
out_shape=[
jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o
jax.ShapeDtypeStruct(
shape=(k_splits, num_heads), dtype=jnp.float32
), # l
jax.ShapeDtypeStruct(
shape=(k_splits, num_heads), dtype=jnp.float32
), # m
],
debug=debug,
interpret=interpret,
name="mha_forward",
)(q, k, v, start_idx, kv_seq_len)
# Final flash-attention combine across the k_splits partial results.
m_next = m.max(axis=0)
# TODO(b/389925439): This barrier is necessary to prevent NaNs/invalid
# values appearing after JIT compilation.
m_next = lax.optimization_barrier(m_next)
correction = jnp.exp(m - m_next[None])
o = o * correction[:, :, None].astype(o.dtype)
l_next = (l * correction).sum(axis=0)
eps = jnp.finfo(l_next.dtype).eps
o = o.sum(axis=0)
if normalize_output:
o /= (l_next[:, None].astype(o.dtype) + eps)
if return_residuals:
return o, (l_next, m_next)
else:
return o
@functools.partial(
jax.jit,
static_argnames=[
"sm_scale",
"block_h",
"block_k",
"k_splits",
"num_warps",
"num_stages",
"grid",
"interpret",
"debug",
"return_residuals",
"normalize_output"
],
)
def mqa(
q, # [batch_size, num_heads, head_dim]
k, # [batch_size, k_seq_len, head_dim]
v, # [batch_size, k_seq_len, head_dim]
start_idx=None, # [batch_size]
kv_seq_len=None, # [batch_size]
sm_scale: float | None = None,
block_h: int = 16,
block_k: int = 256,
k_splits: int = 16,
num_warps: int | None = None,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
return_residuals: bool = False,
normalize_output: bool = True,
):
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
bs = q.shape[0]
if start_idx is not None:
start_idx = jnp.broadcast_to(start_idx, (bs,))
if kv_seq_len is not None:
kv_seq_len = jnp.broadcast_to(kv_seq_len, (bs,))
inner = functools.partial(
decode_attn_unbatched,
sm_scale=sm_scale,
block_h=block_h,
block_k=block_k,
k_splits=k_splits,
num_warps=num_warps,
num_stages=num_stages,
grid=grid,
interpret=interpret,
debug=debug,
return_residuals=return_residuals,
normalize_output=normalize_output,
)
return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len)
@functools.partial(
jax.jit,
static_argnames=[
"sm_scale",
"block_h",
"block_k",
"k_splits",
"num_warps",
"num_stages",
"grid",
"interpret",
"debug",
"return_residuals",
"normalize_output"
],
)
def gqa(
q, # [batch_size, num_q_heads, head_dim]
k, # [batch_size, k_seq_len, num_kv_heads, head_dim]
v, # [batch_size, k_seq_len, num_kv_heads, head_dim]
start_idx=None, # [batch_size]
kv_seq_len=None, # [batch_size]
sm_scale: float | None = None,
block_h: int = 16,
block_k: int = 128,
k_splits: int = 16,
num_warps: int | None = None,
num_stages: int = 2,
grid: tuple[int, ...] | None = None,
interpret: bool = False,
debug: bool = False,
return_residuals: bool = False,
normalize_output: bool = True,
):
if not normalize_output and not return_residuals:
raise NotImplementedError(
"When normalize_output is False, attention residuals must be returned."
)
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
batch_size, q_heads, head_dim = q.shape
_k_seq_len, kv_heads = k.shape[1], k.shape[2]
assert kv_heads == v.shape[2]
assert q_heads % kv_heads == 0
if start_idx is not None:
assert start_idx.ndim in (0, 1)
start_idx = jnp.broadcast_to(jnp.asarray(start_idx)[..., None],
(batch_size, kv_heads))
if kv_seq_len is not None:
assert kv_seq_len.ndim in (0, 1)
kv_seq_len = jnp.broadcast_to(jnp.asarray(kv_seq_len)[..., None],
(batch_size, kv_heads))
q_heads_per_kv_head = q_heads // kv_heads
q_reshaped = q.reshape(batch_size, kv_heads, q_heads_per_kv_head, head_dim)
k_transposed = jnp.swapaxes(
k, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
v_transposed = jnp.swapaxes(
v, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
inner = functools.partial(
decode_attn_unbatched,
sm_scale=sm_scale,
block_h=block_h,
block_k=block_k,
k_splits=k_splits,
num_warps=num_warps,
num_stages=num_stages,
grid=grid,
interpret=interpret,
debug=debug,
return_residuals=return_residuals,
normalize_output=normalize_output,
)
with_kv_heads = jax.vmap(inner)
outputs = jax.vmap(with_kv_heads)(
q_reshaped, k_transposed, v_transposed, start_idx, kv_seq_len
)
if return_residuals:
o, (l, m) = outputs
o = o.reshape(batch_size, q_heads, head_dim)
l = l.reshape(batch_size, q_heads)
m = m.reshape(batch_size, q_heads)
return o, (l, m)
else:
o = outputs
o = o.reshape(batch_size, q_heads, head_dim)
return o
@functools.partial(
jax.jit,
static_argnames=["sm_scale", "return_residuals", "normalize_output"],
)
def mqa_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, head_dim]
v, # [bs, k_seq_len, head_dim]
start_idx=None, # [bs]
kv_seq_len=None, # [bs]
sm_scale=None,
return_residuals=False,
normalize_output=True,
):
original_dtype = q.dtype
q = q.astype(jnp.float32)
k = k.astype(jnp.float32)
bs = q.shape[0]
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32)
if sm_scale is not None and sm_scale != 1.0:
logits = logits * sm_scale
if start_idx is not None or kv_seq_len is not None:
start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
else kv_seq_len, (bs,))
mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None])
& (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
mask = mask[:, None, :]
logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
m = logits.max(axis=-1)
s = jnp.exp(logits - m[..., None])
l = s.sum(axis=-1)
if normalize_output:
s = s / l[..., None]
o = jnp.einsum("bns,bsd->bnd", s, v).astype(original_dtype)
if return_residuals:
return o, (l, m)
else:
return o
@functools.partial(jax.jit, static_argnames=["sm_scale"])
def mha_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, num_k_heads, head_dim]
v, # [bs, k_seq_len, num_v_heads, head_dim]
start_idx=None, # [bs]
kv_seq_len=None, # [bs]
sm_scale=None,
):
bs = q.shape[0]
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
assert q.shape[1] == k.shape[2]
logits = jnp.einsum("bnd,bsnd->bns", q, k).astype(jnp.float32)
if start_idx is not None or kv_seq_len is not None:
start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
else kv_seq_len, (bs,))
mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None])
& (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
mask = mask[:, None, :]
logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
return jnp.einsum("bns,bsnd->bnd", weights, v)
@functools.partial(
jax.jit,
static_argnames=["sm_scale", "return_residuals", "normalize_output"],
)
def gqa_reference(
q, # [bs, num_q_heads, head_dim]
k, # [bs, k_seq_len, num_k_heads, head_dim]
v, # [bs, k_seq_len, num_v_heads, head_dim]
start_idx=None, # [bs]
kv_seq_len=None, # [bs]
sm_scale=None,
return_residuals=False,
normalize_output=True
):
original_dtype = q.dtype
q = q.astype(jnp.float32)
k = k.astype(jnp.float32)
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
bs, num_q_heads, head_dim = q.shape
num_kv_heads = k.shape[2]
assert num_q_heads % num_kv_heads == 0
q_reshaped = q.reshape(
bs, num_kv_heads, num_q_heads // num_kv_heads, head_dim
)
k_transposed = jnp.swapaxes(
k, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
v_transposed = jnp.swapaxes(
v, 1, 2
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype(
jnp.float32
)
if sm_scale is not None and sm_scale != 1.0:
logits = logits * sm_scale
if start_idx is not None or kv_seq_len is not None:
start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
else kv_seq_len, (bs,))
mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None])
& (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
mask = mask[:, None, None, :]
logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
m = logits.max(axis=-1)
s = jnp.exp(logits - m[..., None])
l = s.sum(axis=-1)
if normalize_output:
s = s / l[..., None]
o = jnp.einsum("bkgs,bksd->bkgd", s, v_transposed).astype(original_dtype)
o = o.reshape(bs, num_q_heads, head_dim)
if return_residuals:
l = l.reshape(bs, num_q_heads)
m = m.reshape(bs, num_q_heads)
return o, (l, m)
else:
return o
@@ -0,0 +1,329 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Matrix Multiplication kernel for Hopper GPUs."""
import statistics
import dataclasses
import enum
import functools
import itertools
import jax
from jax import lax
from jax._src import test_util as jtu # noqa: F401
from jax.experimental.mosaic.gpu import profiler
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
from jax.extend import backend
import jax.numpy as jnp
import numpy as np
class MatmulDimension(enum.IntEnum):
M = 0
N = 1
def __str__(self):
return self.name
def __repr__(self):
return self.name
@dataclasses.dataclass(frozen=True)
class TuningConfig:
tile_m: int
tile_n: int
tile_k: int
max_concurrent_steps: int
epi_tile_n: int | None = 64 # This needs to be lowered for small N.
epi_tile_m: int | None = 64
grid_minor_dim: MatmulDimension = MatmulDimension.N
grid_tile_width: int = 1
wg_dimension: MatmulDimension = MatmulDimension.N
cluster_dimension: None | MatmulDimension = None
# pipeline_callback and delay_release are only used for collective matmuls.
def kernel(a_gmem, b_gmem, c_gmem, out_gmem, config: TuningConfig,
pipeline_callback=None, delay_release=0):
dtype = a_gmem.dtype
out_dtype = out_gmem.dtype
assert b_gmem.dtype == dtype
if c_gmem is not None:
assert c_gmem.dtype == out_dtype
m, k = a_gmem.shape
k2, n = b_gmem.shape
assert k == k2
tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
max_concurrent_steps = config.max_concurrent_steps
swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
transforms = (
plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle)
)
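# Two compute warpgroups per CTA split the tile along wg_dimension; a 2-CTA cluster
# (when cluster_dimension is set) doubles the tile again along that dimension.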
cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M))
cta_tile_n = tile_n * (1 + (config.wg_dimension == MatmulDimension.N))
cluster_tile_m = cta_tile_m * (1 + (config.cluster_dimension == MatmulDimension.M))
cluster_tile_n = cta_tile_n * (1 + (config.cluster_dimension == MatmulDimension.N))
if m % cluster_tile_m != 0:
raise ValueError(f"{m=} must be divisible by {cluster_tile_m} for the given config")
if n % cluster_tile_n != 0:
raise ValueError(f"{n=} must be divisible by {cluster_tile_n} for the given config")
if k % tile_k != 0:
raise ValueError(f"{k=} must be divisible by {tile_k=}")
m_iters = m // cluster_tile_m
n_iters = n // cluster_tile_n
k_iters = k // tile_k
epi_tile_m = config.epi_tile_m or tile_m
epi_tile_n = config.epi_tile_n or tile_n
# We don't need multiple slots if there's only one epilogue tile.
num_out_slots = min(2, (tile_m * tile_n) // (epi_tile_m * epi_tile_n))
out_swizzle = plgpu.find_swizzle(epi_tile_n * jnp.dtype(out_dtype).itemsize * 8)
out_swizzle_elems = out_swizzle // jnp.dtype(out_dtype).itemsize
out_transforms = (
plgpu.TilingTransform((8, out_swizzle_elems)),
plgpu.SwizzleTransform(out_swizzle),
)
def get_pipeline(pipeline_body, compute_context):
return plgpu.emit_pipeline_warp_specialized(
pipeline_body,
grid=(k_iters,),
memory_registers=40,
in_specs=[
plgpu.BlockSpec(
(cta_tile_m, tile_k),
lambda k: (0, k),
transforms=transforms,
memory_space=plgpu.SMEM,
delay_release=delay_release,
collective_axes=("cluster",)
if config.cluster_dimension == MatmulDimension.N
else (),
),
plgpu.BlockSpec(
(tile_k, cta_tile_n),
lambda k: (k, 0),
transforms=transforms,
memory_space=plgpu.SMEM,
delay_release=delay_release,
collective_axes=("cluster",)
if config.cluster_dimension == MatmulDimension.M
else (),
),
],
wg_axis="wg",
num_compute_wgs=2,
max_concurrent_steps=max_concurrent_steps,
compute_context=compute_context,
)
# Functions don't influence the allocations necessary to run the pipeline.
ignore = lambda *_, **__: None
@functools.partial(
pl.run_scoped,
pipeline_allocs=get_pipeline(ignore, ignore).get_allocations(a_gmem, b_gmem),
out_smem=plgpu.SMEM(
(2, num_out_slots, epi_tile_m, epi_tile_n),
out_dtype,
transforms=out_transforms,
),
c_barrier=None if c_gmem is None else plgpu.Barrier(num_barriers=2 * num_out_slots),
collective_axes="wg",
)
def _pipeline_scope(pipeline_allocs, out_smem, c_barrier):
wg_idx = lax.axis_index("wg")
cta_idx = lax.axis_index("cluster")
@plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
def _mn_loop(loop_info: plgpu.NDLoopInfo):
(lin_idx,) = loop_info.index
m_cluster_idx, n_cluster_idx = plgpu.planar_snake(
lin_idx,
(m_iters, n_iters),
config.grid_minor_dim,
config.grid_tile_width,
)
m_idx = m_cluster_idx
n_idx = n_cluster_idx
if config.cluster_dimension == MatmulDimension.M:
m_idx = m_cluster_idx * 2 + cta_idx
elif config.cluster_dimension == MatmulDimension.N:
n_idx = n_cluster_idx * 2 + cta_idx
cta_m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m)
cta_n_slice = pl.ds(n_idx * cta_tile_n, cta_tile_n)
if config.wg_dimension == MatmulDimension.M:
wg_m_slice = pl.ds(wg_idx * tile_m, tile_m)
wg_n_slice = slice(None)
else:
wg_m_slice = slice(None)
wg_n_slice = pl.ds(wg_idx * tile_n, tile_n)
def compute_context(eval_pipeline):
@functools.partial(
pl.run_scoped, acc_ref=plgpu.ACC((tile_m, tile_n), jnp.float32)
)
def _acc_scope(acc_ref):
eval_pipeline(acc_ref)
acc = acc_ref[...].astype(out_dtype)
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
for epi_mi in range(tile_m // epi_tile_m):
for epi_ni in range(tile_n // epi_tile_n):
epi_m_slice = slice(epi_mi * epi_tile_m, (epi_mi + 1) * epi_tile_m)
epi_n_slice = slice(epi_ni * epi_tile_n, (epi_ni + 1) * epi_tile_n)
slot = (epi_mi * (tile_n // epi_tile_n) + epi_ni) % 2
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
if c_gmem is None:
out_smem[wg_idx, slot] = acc[epi_m_slice, epi_n_slice]
else:
                # TODO: Consider using triple-buffering so as not to end up
                # issuing the copy and immediately blocking on it.
plgpu.copy_gmem_to_smem(
c_gmem.at[cta_m_slice, cta_n_slice]
.at[wg_m_slice, wg_n_slice]
.at[epi_m_slice, epi_n_slice],
out_smem.at[wg_idx, slot],
c_barrier.at[wg_idx * num_out_slots + slot],
)
plgpu.barrier_wait(c_barrier.at[wg_idx * num_out_slots + slot])
out_smem[wg_idx, slot] += acc[epi_m_slice, epi_n_slice]
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
out_smem.at[wg_idx, slot],
out_gmem.at[cta_m_slice, cta_n_slice]
.at[wg_m_slice, wg_n_slice]
.at[epi_m_slice, epi_n_slice],
)
def mma_body(idxs, a_smem, b_smem, acc_ref):
plgpu.wgmma(acc_ref, a_smem.at[wg_m_slice], b_smem.at[:, wg_n_slice])
if pipeline_callback is not None:
(k_idx,) = idxs
pipeline_callback(m_idx, n_idx, k_idx, a_smem, b_smem)
plgpu.wgmma_wait(delay_release)
return acc_ref
get_pipeline(mma_body, compute_context)(
a_gmem.at[cta_m_slice, :],
b_gmem.at[:, cta_n_slice],
allocations=pipeline_allocs,
)
# Await all transfers before we exit.
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
def matmul(a, b, c, config: TuningConfig):
dtype = a.dtype
if a.dtype != b.dtype:
raise ValueError(
f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}"
)
m, k = a.shape
k2, n = b.shape
  if k != k2:
raise ValueError(
f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}"
)
if c is None:
out_dtype = dtype
else:
if c.shape != (m, n):
raise ValueError(f"C has incompatible shape {c.shape} vs {(m, n)}")
out_dtype = c.dtype
tile_m, tile_n = config.tile_m, config.tile_n
epi_tile_n = config.epi_tile_n or tile_n
epi_tile_m = config.epi_tile_m or tile_m
config = dataclasses.replace(config, epi_tile_n=epi_tile_n, epi_tile_m=epi_tile_m)
num_sms = backend.get_default_device().core_count
cluster_size = 1 + (config.cluster_dimension is not None)
f = plgpu.kernel(
functools.partial(kernel, config=config),
out_shape=jax.ShapeDtypeStruct((m, n), out_dtype),
grid=(num_sms // cluster_size,),
grid_names=("cluster_grid",),
cluster=(cluster_size,),
cluster_names=("cluster",),
num_threads=3,
thread_name="wg",
)
return f(a, b, c)
def main(_) -> None:
problem_it = [(4096, 8192, 4096)]
for M, N, K in problem_it:
print(f"==== {M=} {N=} {K=} ====")
matmul_flops = 2 * M * N * K
peak_flops = 990e12 # f16 TensorCore peak = 990 TFLOPS
a = jax.random.uniform(jax.random.key(0), (M, K), jnp.float16)
b = jax.random.uniform(jax.random.key(1), (K, N), jnp.float16)
ref = a @ b
tuning_it = itertools.product(
(128, 256,), # tile_m
(64, 128), # tile_n
(64,), # tile_k
(4,), # max_concurrent_steps
(True,), # Tiled epilogue
(MatmulDimension.M, MatmulDimension.N), # grid_minor_dim
(4, 8, 16), # grid_tile_width
MatmulDimension, # wg_dimension
# Consider adding MatmulDimension here to try out collective TMA kernels
(None,) # cluster_dimension
)
best_util = 0.0
best_runtime = float("inf")
for tile_m, tile_n, tile_k, max_concurrent_steps, tiled_epilogue, grid_minor_dim, grid_tile_width, wg_dimension, cluster_dimension in tuning_it:
config = TuningConfig(
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
max_concurrent_steps=max_concurrent_steps,
epi_tile_n=64 if tiled_epilogue else None,
epi_tile_m=64 if tiled_epilogue else None,
grid_minor_dim=grid_minor_dim,
grid_tile_width=grid_tile_width,
wg_dimension=wg_dimension,
cluster_dimension=cluster_dimension,
)
try:
out, runtimes_ms = profiler.measure(
functools.partial(matmul, config=config), iterations=10,
)(a, b, None)
assert runtimes_ms is not None
runtime_ms = statistics.median(runtimes_ms)
except ValueError as e:
if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs.
continue
raise
np.testing.assert_allclose(out, ref)
runtime_us = runtime_ms * 1e3
optimal_time = matmul_flops / peak_flops * 1e6 # us
achieved_tc_util = optimal_time / runtime_us * 100
if achieved_tc_util > best_util:
best_runtime = runtime_us
best_util = achieved_tc_util
print(
f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} {tiled_epilogue=} {grid_minor_dim=} {grid_tile_width=} {wg_dimension=} {cluster_dimension=}:"
f" {runtime_us:<7.1f}us = {achieved_tc_util:4.1f}% TC utilization"
)
print(f"\tBest: {best_runtime:<7.1f}us = {best_util:4.1f}% TC utilization")
if __name__ == "__main__":
from absl import app
jax.config.config_with_absl()
app.run(main)
@@ -0,0 +1,343 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Matrix Multiplication kernel for Hopper GPUs."""
import statistics
import dataclasses
import enum
import functools
import itertools
import jax
from jax._src import dtypes
from jax import lax
from jax._src import test_util as jtu # noqa: F401
from jax.experimental.mosaic.gpu import profiler
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
from jax.extend import backend
import jax.numpy as jnp
import numpy as np
class MatmulDimension(enum.IntEnum):
M = 0
N = 1
def __str__(self):
return self.name
def __repr__(self):
return self.name
@dataclasses.dataclass(frozen=True)
class TuningConfig:
tile_m: int
tile_n: int
tile_k: int
max_concurrent_steps: int
  epi_tile_n: int | None = 64  # This needs to be lowered for small N.
epi_tile_m: int | None = 64
grid_minor_dim: MatmulDimension = MatmulDimension.N
grid_tile_width: int = 1
wg_dimension: MatmulDimension = MatmulDimension.N
cluster_dimension: None | MatmulDimension = None
def mixed_matmul_kernel(
a: jax.Array, b: jax.Array, *, out_dtype: jnp.dtype, config: TuningConfig
) -> jax.Array:
"""Mixed-type matrix multiplication kernel for Hopper GPUs.
Specifically, this kernel implements the function
  (a.astype(b.dtype) @ b).astype(out_dtype).
"""
if a.dtype == b.dtype:
raise ValueError(
f"Mixed matmul LHS and RHS have the same dtype {a.dtype}. For such "
"matrix multiplications, use the `hopper_matmul_mgpu` kernel instead."
)
match (a.dtype, b.dtype):
case (jnp.int8, jnp.bfloat16):
pass
case (jnp.int8, jnp.float16):
pass
case _, _:
# We do support more combinations, but we haven't benchmarked them
# yet---so we raise for the time being.
raise NotImplementedError(
f"Unbenchmarked dtype combination: {a.dtype=} and {b.dtype=}"
)
m, k = a.shape
k2, n = b.shape
if k != k2:
raise ValueError(
f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}"
)
tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
epi_tile_n = config.epi_tile_n or tile_n
epi_tile_m = config.epi_tile_m or tile_m
if tile_n % epi_tile_n != 0:
raise ValueError(f"{tile_n=} must be divisible by {epi_tile_n=}")
if tile_m % epi_tile_m != 0:
raise ValueError(f"{tile_m=} must be divisible by {epi_tile_m=}")
a_bits = dtypes.itemsize_bits(a.dtype)
b_bits = dtypes.itemsize_bits(b.dtype)
out_bits = dtypes.itemsize_bits(out_dtype)
a_swizzle = plgpu.find_swizzle(tile_k * a_bits, "lhs")
b_swizzle = plgpu.find_swizzle(tile_n * b_bits, "rhs")
out_swizzle = plgpu.find_swizzle(epi_tile_n * out_bits, "out")
a_transforms = (
plgpu.TilingTransform((8, a_swizzle * 8 // a_bits)),
plgpu.SwizzleTransform(a_swizzle),
)
b_transforms = (
plgpu.TilingTransform((8, b_swizzle * 8 // b_bits)),
plgpu.SwizzleTransform(b_swizzle),
)
out_transforms = (
plgpu.TilingTransform((8, out_swizzle * 8 // out_bits)),
plgpu.SwizzleTransform(out_swizzle),
)
max_concurrent_steps = config.max_concurrent_steps
cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M))
cta_tile_n = tile_n * (1 + (config.wg_dimension == MatmulDimension.N))
cluster_tile_m = cta_tile_m * (1 + (config.cluster_dimension == MatmulDimension.M))
cluster_tile_n = cta_tile_n * (1 + (config.cluster_dimension == MatmulDimension.N))
if m % cluster_tile_m != 0:
raise ValueError(f"{m=} must be divisible by {cluster_tile_m} for the given config")
if n % cluster_tile_n != 0:
raise ValueError(f"{n=} must be divisible by {cluster_tile_n} for the given config")
if k % tile_k != 0:
raise ValueError(f"{k=} must be divisible by {tile_k=}")
m_iters = m // cluster_tile_m
n_iters = n // cluster_tile_n
k_iters = k // tile_k
def kernel(a_gmem, b_gmem, out_gmem, out_smem):
def get_pipeline(pipeline_body, compute_context):
return plgpu.emit_pipeline_warp_specialized(
pipeline_body,
grid=(k_iters,),
memory_registers=40,
in_specs=[
plgpu.BlockSpec(
(cta_tile_m, tile_k),
lambda k: (0, k),
transforms=a_transforms,
memory_space=plgpu.SMEM,
collective_axes=("cluster",)
if config.cluster_dimension == MatmulDimension.N
else (),
),
plgpu.BlockSpec(
(tile_k, cta_tile_n),
lambda k: (k, 0),
transforms=b_transforms,
memory_space=plgpu.SMEM,
collective_axes=("cluster",)
if config.cluster_dimension == MatmulDimension.M
else (),
),
],
wg_axis="wg",
num_compute_wgs=2,
max_concurrent_steps=max_concurrent_steps,
compute_context=compute_context,
)
    # The callbacks passed to the pipeline don't influence the allocations
    # necessary to run it, so we can pass no-op placeholders here.
ignore = lambda *_, **__: None
@functools.partial(
pl.run_scoped,
pipeline_allocs=get_pipeline(ignore, ignore).get_allocations(a_gmem, b_gmem),
collective_axes="wg",
)
def _pipeline_scope(pipeline_allocs):
wg_idx = lax.axis_index("wg")
cta_idx = lax.axis_index("cluster")
@plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
def _mn_loop(loop_info: plgpu.NDLoopInfo):
(lin_idx,) = loop_info.index
m_cluster_idx, n_cluster_idx = plgpu.planar_snake(
lin_idx,
(m_iters, n_iters),
config.grid_minor_dim,
config.grid_tile_width,
)
m_idx = m_cluster_idx
n_idx = n_cluster_idx
if config.cluster_dimension == MatmulDimension.M:
m_idx = m_cluster_idx * 2 + cta_idx
elif config.cluster_dimension == MatmulDimension.N:
n_idx = n_cluster_idx * 2 + cta_idx
cta_m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m)
cta_n_slice = pl.ds(n_idx * cta_tile_n, cta_tile_n)
if config.wg_dimension == MatmulDimension.M:
wg_m_slice = pl.ds(wg_idx * tile_m, tile_m)
wg_n_slice = slice(None)
else:
wg_m_slice = slice(None)
wg_n_slice = pl.ds(wg_idx * tile_n, tile_n)
def compute_context(eval_pipeline):
@functools.partial(
pl.run_scoped, acc_ref=plgpu.ACC((tile_m, tile_n), jnp.float32)
)
def _acc_scope(acc_ref):
eval_pipeline(acc_ref)
acc = acc_ref[...].astype(out_dtype)
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
for epi_mi in range(tile_m // epi_tile_m):
for epi_ni in range(tile_n // epi_tile_n):
epi_m_slice = slice(epi_mi * epi_tile_m, (epi_mi + 1) * epi_tile_m)
epi_n_slice = slice(epi_ni * epi_tile_n, (epi_ni + 1) * epi_tile_n)
slot = (epi_mi * (tile_n // epi_tile_n) + epi_ni) % 2
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
out_smem[wg_idx, slot] = acc[epi_m_slice, epi_n_slice]
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
out_smem.at[wg_idx, slot],
out_gmem.at[cta_m_slice, cta_n_slice]
.at[wg_m_slice, wg_n_slice]
.at[epi_m_slice, epi_n_slice],
)
def mma_body(_, a_smem, b_smem, acc_ref):
with jax.named_scope("smem_load"):
a_reg = a_smem[wg_m_slice]
with jax.named_scope("dequant"):
a_reg = a_reg.astype(b.dtype)
with jax.named_scope("wgmma"):
plgpu.wgmma(acc_ref, a_reg, b_smem.at[:, wg_n_slice])
with jax.named_scope("wgmma_wait"):
plgpu.wgmma_wait(0)
return acc_ref
get_pipeline(mma_body, compute_context)(
a_gmem.at[cta_m_slice, :],
b_gmem.at[:, cta_n_slice],
allocations=pipeline_allocs,
)
# Await all transfers before we exit.
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
# We don't need multiple slots if there's only one epilogue tile.
num_out_slots = min(2, (tile_m * tile_n) // (epi_tile_m * epi_tile_n))
num_sms = backend.get_default_device().core_count
cluster_size = 1 + (config.cluster_dimension is not None)
f = plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), out_dtype),
grid=(num_sms // cluster_size,),
grid_names=("cluster_grid",),
cluster=(cluster_size,),
cluster_names=("cluster",),
num_threads=3,
thread_name="wg",
scratch_shapes=dict(
out_smem=plgpu.SMEM(
(2, num_out_slots, epi_tile_m, epi_tile_n),
out_dtype,
transforms=out_transforms,
)
),
)
return f(a, b)
def reference(
a: jax.Array, b: jax.Array, *, out_dtype: jnp.dtype
) -> jax.Array:
"""Reference implementation of a mixed-type matrix multiplication."""
return jax.numpy.dot(a, b, preferred_element_type=jnp.float32).astype(
out_dtype
)
def main(_) -> None:
problem_it = [(4096, 8192, 4096)]
for M, N, K in problem_it:
print(f"==== {M=} {N=} {K=} ====")
matmul_flops = 2 * M * N * K
peak_flops = 990e12 # f16 TensorCore peak = 990 TFLOPS
a = jax.random.randint(
jax.random.key(0), minval=-128, maxval=127, shape=(M, K), dtype=jnp.int8
)
b = jax.random.uniform(jax.random.key(1), (K, N), jnp.bfloat16)
ref = reference(a, b, out_dtype=jnp.bfloat16)
tuning_it = itertools.product(
(64, 128, 256,), # tile_m
(64, 128), # tile_n
(64, 128), # tile_k
(4,), # max_concurrent_steps
(True,), # Tiled epilogue
(MatmulDimension.M, MatmulDimension.N), # grid_minor_dim
(4, 8, 16), # grid_tile_width
MatmulDimension, # wg_dimension
# Consider adding MatmulDimension here to try out collective TMA kernels
(None,) # cluster_dimension
)
best_util = 0.0
best_runtime = float("inf")
for tile_m, tile_n, tile_k, max_concurrent_steps, tiled_epilogue, grid_minor_dim, grid_tile_width, wg_dimension, cluster_dimension in tuning_it:
config = TuningConfig(
tile_m=tile_m,
tile_n=tile_n,
tile_k=tile_k,
max_concurrent_steps=max_concurrent_steps,
epi_tile_n=64 if tiled_epilogue else None,
epi_tile_m=64 if tiled_epilogue else None,
grid_minor_dim=grid_minor_dim,
grid_tile_width=grid_tile_width,
wg_dimension=wg_dimension,
cluster_dimension=cluster_dimension,
)
try:
out, runtimes_ms = profiler.measure(
functools.partial(
mixed_matmul_kernel, out_dtype=jnp.bfloat16, config=config
),
iterations=10,
)(a, b)
assert runtimes_ms is not None
runtime_ms = statistics.median(runtimes_ms)
except ValueError as e:
if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs.
continue
raise
np.testing.assert_allclose(out, ref)
runtime_us = runtime_ms * 1e3
optimal_time = matmul_flops / peak_flops * 1e6 # us
achieved_tc_util = optimal_time / runtime_us * 100
if achieved_tc_util > best_util:
best_runtime = runtime_us
best_util = achieved_tc_util
print(
f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} {tiled_epilogue=} {grid_minor_dim=} {grid_tile_width=} {wg_dimension=} {cluster_dimension=}:"
f" {runtime_us:<7.1f}us = {achieved_tc_util:4.1f}% TC utilization"
)
print(f"\tBest: {best_runtime:<7.1f}us = {best_util:4.1f}% TC utilization")
if __name__ == "__main__":
from absl import app
jax.config.config_with_absl()
app.run(main)
@@ -0,0 +1,345 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing fused layer norm forward and backward pass."""
from __future__ import annotations
import functools
import jax
from jax import lax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
def layer_norm_forward_kernel(
x_ref, weight_ref, bias_ref, # Input arrays
o_ref, mean_ref=None, rstd_ref=None, # Output arrays
*, eps: float, block_size: int):
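  # Computes o = (x - mean(x)) * rstd * weight + bias with
  # rstd = 1 / sqrt(var(x) + eps), accumulating the row statistics in float32
  # over blocks of `block_size` columns.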
n_col = x_ref.shape[0]
def mean_body(i, acc):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = plgpu.load(
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
return acc + a
mean = lax.fori_loop(
0,
pl.cdiv(n_col, block_size),
mean_body,
init_val=jnp.zeros(block_size),
).sum()
mean /= n_col
def var_body(i, acc):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = plgpu.load(
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
a = jnp.where(mask, a - mean, 0.)
return acc + a * a
var = lax.fori_loop(
0,
pl.cdiv(n_col, block_size),
var_body,
init_val=jnp.zeros(block_size),
).sum()
var /= n_col
rstd = 1 / jnp.sqrt(var + eps)
if mean_ref is not None:
mean_ref[...] = mean.astype(mean_ref.dtype)
if rstd_ref is not None:
rstd_ref[...] = rstd.astype(rstd_ref.dtype)
@pl.loop(0, pl.cdiv(n_col, block_size))
def body(i):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
weight = plgpu.load(weight_ref.at[col_idx], mask=mask)
bias = plgpu.load(bias_ref.at[col_idx], mask=mask)
x = plgpu.load(
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_first"
).astype(jnp.float32)
out = (x - mean) * rstd * weight + bias
plgpu.store(o_ref.at[col_idx], out.astype(o_ref.dtype), mask=mask)
def layer_norm_forward(
x, weight, bias,
num_warps: int | None = None,
num_stages: int | None = 3,
eps: float = 1e-5,
backward_pass_impl: str = 'triton',
interpret: bool = False):
del num_stages
del backward_pass_impl
n = x.shape[-1]
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = [
jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype),
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
]
method = pl.pallas_call(
kernel,
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=(),
out_shape=out_shape,
debug=False,
interpret=interpret,
name="ln_forward",
)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
out, mean, rstd = method(x, weight, bias)
return out, (x, weight, bias, mean, rstd)
def layer_norm_backward_kernel_dx(
# Inputs
x_ref, weight_ref, bias_ref, do_ref,
mean_ref, rstd_ref,
# Outputs
dx_ref,
*, eps: float, block_size: int):
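  # Computes dx for layer norm. With x_hat = (x - mean) * rstd and
  # wdy = weight * dy, the input gradient is
  #   dx = (wdy - (x_hat * mean(x_hat * wdy) + mean(wdy))) * rstd,
  # where the means run over the feature dimension (mean1 and mean2 below).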
n_col = x_ref.shape[0]
def mean_body(i, acc):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = plgpu.load(
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
dout = plgpu.load(
do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
weight = plgpu.load(
weight_ref.at[col_idx],
mask=mask,
other=0.0,
eviction_policy="evict_last",
).astype(jnp.float32)
a_hat = (a - mean_ref[...]) * rstd_ref[...]
wdout = weight * dout
mean1_acc, mean2_acc = acc
return mean1_acc + a_hat * wdout, mean2_acc + wdout
mean1, mean2 = lax.fori_loop(
0,
pl.cdiv(n_col, block_size),
mean_body,
init_val=(jnp.zeros(block_size), jnp.zeros(block_size)),
)
mean1 = mean1.sum() / n_col
mean2 = mean2.sum() / n_col
@pl.loop(0, pl.cdiv(n_col, block_size))
def dx_body(i):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = plgpu.load(
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
dout = plgpu.load(
do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
weight = plgpu.load(
weight_ref.at[col_idx],
mask=mask,
other=0.0,
eviction_policy="evict_last",
).astype(jnp.float32)
a_hat = (a - mean_ref[...]) * rstd_ref[...]
wdout = weight * dout
da = (wdout - (a_hat * mean1 + mean2)) * rstd_ref[...]
plgpu.store(dx_ref.at[col_idx], da.astype(dx_ref.dtype), mask=mask)
def layer_norm_backward_kernel_dw_db(
# Inputs
x_ref, weight_ref, bias_ref, do_ref,
mean_ref, rstd_ref,
# Outputs
dw_ref, db_ref,
*, eps: float, block_m: int, block_n: int):
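  # Accumulates the weight and bias gradients over the rows handled by this
  # program: dw = sum_rows(dy * x_hat) and db = sum_rows(dy), with
  # x_hat = (x - mean) * rstd.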
m, n_col = x_ref.shape
j = pl.program_id(0)
col_idx = j * block_n + jnp.arange(block_n)
col_mask = col_idx < n_col
def body(i, acc):
row_idx = i * block_m + jnp.arange(block_m)
row_mask = row_idx < m
mask = row_mask[:, None] & col_mask[None, :]
a = plgpu.load(
x_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
).astype(jnp.float32)
dout = plgpu.load(
do_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
).astype(jnp.float32)
mean = plgpu.load(mean_ref.at[row_idx], mask=row_mask, other=0.0).astype(
jnp.float32
)
rstd = plgpu.load(rstd_ref.at[row_idx], mask=row_mask, other=0.0).astype(
jnp.float32
)
a_hat = (a - mean[:, None]) * rstd[:, None]
dw_acc_ref, db_acc_ref = acc
return dw_acc_ref + (dout * a_hat).sum(axis=0), db_acc_ref + dout.sum(
axis=0
)
dw_acc, db_acc = lax.fori_loop(
0,
pl.cdiv(m, block_m),
body,
init_val=(jnp.zeros(block_n), jnp.zeros(block_n)),
)
plgpu.store(dw_ref.at[col_idx], dw_acc.astype(dw_ref.dtype), mask=col_mask)
plgpu.store(db_ref.at[col_idx], db_acc.astype(db_ref.dtype), mask=col_mask)
def layer_norm_backward(
num_warps: int | None,
num_stages: int | None,
eps: float,
backward_pass_impl: str,
interpret: bool,
res, do):
del num_stages
x, weight, bias, mean, rstd = res
if backward_pass_impl == 'xla':
return jax.vjp(layer_norm_reference, x, weight, bias)[1](do)
*shape_prefix, n = x.shape
reshaped_x = x.reshape((-1, n))
reshaped_mean = mean.reshape((-1,))
reshaped_rstd = rstd.reshape((-1,))
reshaped_do = do.reshape((-1, n))
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
# layer_norm_backward_kernel_dx parallel over batch dims
kernel = functools.partial(layer_norm_backward_kernel_dx, eps=eps,
block_size=block_size)
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=(),
out_shape=out_shape_dx,
debug=False,
interpret=interpret,
name="ln_backward_dx",
)
method = jax.vmap(method, in_axes=(0, None, None, 0, 0, 0))
dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
dx = dx.reshape((*shape_prefix, n))
# layer_norm_backward_kernel_dw_db reduce over batch dims
# Triton heuristics
if n > 10240:
block_n = 128
block_m = 32
num_warps = 4
else:
# maximize occupancy for small N
block_n = 16
block_m = 16
num_warps = 8
kernel = functools.partial(layer_norm_backward_kernel_dw_db, eps=eps,
block_m=block_m, block_n=block_n)
out_shape_dwbias = [
jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype),
jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
]
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=grid_,
out_shape=out_shape_dwbias,
debug=False,
interpret=interpret,
name="ln_backward_dw_db",
)
dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
return dx, dw, dbias
@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7])
@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages", "eps",
                                             "backward_pass_impl",
                                             "interpret"])
def layer_norm(
x, weight, bias,
num_warps: int | None = None,
num_stages: int | None = 3,
eps: float = 1e-5,
backward_pass_impl: str = 'triton',
interpret: bool = False):
n = x.shape[-1]
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.CompilerParams(
num_warps=num_warps, num_stages=num_stages),
grid=(),
out_shape=out_shape,
debug=False,
interpret=interpret,
)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
return method(x, weight, bias)
layer_norm.defvjp(layer_norm_forward, layer_norm_backward)
@functools.partial(jax.jit, static_argnames=["eps"])
@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
def layer_norm_reference(x, weight, bias, *, eps: float = 1e-5):
mean = jnp.mean(x, axis=1)
mean2 = jnp.mean(jnp.square(x), axis=1)
var = jnp.maximum(0., mean2 - jnp.square(mean))
y = x - mean[:, None]
mul = lax.rsqrt(var + eps)
return y * mul[:, None] * weight[None] + bias[None]
@@ -0,0 +1,461 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing decode attention."""
from __future__ import annotations
import functools
import math
from typing import Any
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
import jax.numpy as jnp
import numpy as np
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
def paged_attention_kernel(
# inputs
q_ref, # [block_h, head_dim]
k_pages_ref, # [total_num_pages, page_size, head_dim]
k_scales_pages_ref, # [total_num_pages, page_size]
v_pages_ref, # [total_num_pages, page_size, head_dim]
v_scales_pages_ref, # [total_num_pages, page_size]
block_tables_ref, # [pages_per_partition]
lengths_ref, # [1]
# outputs
o_ref: Any, # [block_h, head_dim]
*residual_refs: Any, # Residual outputs: [block_h,], [block_h,]
num_heads: int,
pages_per_compute_block: int,
mask_value: float,
attn_logits_soft_cap: float | None,
):
partition_idx = pl.program_id(2)
block_h, head_dim = q_ref.shape
page_size = k_pages_ref.shape[-2]
pages_per_partition = block_tables_ref.shape[0]
block_k = pages_per_compute_block * page_size
def _compute(start_page_idx, end_page_idx, o, m_i, l_i):
q_slice = pl.ds(0, block_h)
q = q_ref[q_slice, :]
    # Loop over blocks of pages to process an entire page sequence partition.
    # The grid loops over q blocks and over num_heads.
def body(start_k, carry):
o_prev, m_prev, l_prev = carry
block_tables_slice = pl.ds(
start_k * pages_per_compute_block, pages_per_compute_block
)
block_tables = block_tables_ref[block_tables_slice]
k = k_pages_ref[block_tables].reshape(block_k, head_dim)
v = v_pages_ref[block_tables].reshape(block_k, head_dim)
if k_scales_pages_ref is not None:
# dynamic lhs quantized dot is not currently implemented
# so we cast rhs to the lhs dtype
k = k.astype(q.dtype)
uncapped_logits = pl.dot(q, k.T) # [block_h, block_k]
if k_scales_pages_ref is not None:
# k_scales_pages_ref are one per head
# they're laid out across the output dimension, so scale output
k_scale = k_scales_pages_ref[block_tables].reshape((1, block_k))
uncapped_logits *= k_scale.astype(uncapped_logits.dtype)
if attn_logits_soft_cap is not None:
logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap)
logits = logits * attn_logits_soft_cap
else:
logits = uncapped_logits
if lengths_ref is not None:
curr_start_page_idx = (
partition_idx * pages_per_partition
+ start_k * pages_per_compute_block
)
curr_start_token_idx = curr_start_page_idx * page_size
mask = jnp.arange(block_k) + curr_start_token_idx < lengths_ref[0]
mask = lax.broadcast_in_dim(mask, (block_h, block_k), (1,))
logits = jnp.where(mask, logits, mask_value)
log2e = math.log2(math.e)
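      # Online softmax update: m is the running row-wise max of the logits and
      # l is the running softmax denominator. exp(x) is computed as
      # exp2(x * log2(e)), and previous partial results are rescaled by
      # exp(m_prev - m_next) whenever the running max increases.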
m_curr = logits.max(axis=-1)
m_next = jnp.maximum(m_prev, m_curr)
correction = jnp.exp2((m_prev - m_next) * log2e)
l_prev_corr = correction * l_prev
s_curr = jnp.exp2((logits - m_next[:, None]) * log2e)
l_curr = s_curr.sum(axis=-1)
l_next = l_prev_corr + l_curr
o_prev_corr = correction[:, None] * o_prev
if v_scales_pages_ref is not None:
# v_scales are 1 per head
# they're laid out across the reduction dimension, so scale lhs
v_scale = v_scales_pages_ref[block_tables].reshape((1, block_k))
s_curr *= v_scale.astype(s_curr.dtype)
# dynamic lhs quantized dot is not currently implemented
# so we cast rhs to the lhs dtype
v = v.astype(s_curr.dtype)
o_curr = pl.dot(s_curr.astype(v.dtype), v)
o_next = o_prev_corr + o_curr
return o_next, m_next, l_next
max_it = pl.cdiv(end_page_idx - start_page_idx, pages_per_compute_block)
(o, m_i, l_i) = lax.fori_loop(0, max_it, body, (o, m_i, l_i))
return o, m_i, l_i
m_i = jnp.zeros(block_h, dtype=jnp.float32) + jnp.finfo(jnp.float32).min
l_i = jnp.zeros(block_h, dtype=jnp.float32)
o = jnp.zeros((block_h, head_dim), dtype=jnp.float32)
start_page_idx = partition_idx * pages_per_partition
end_page_idx = start_page_idx + pages_per_partition
if lengths_ref is None:
o, m_i, l_i = _compute(start_page_idx, end_page_idx, o, m_i, l_i)
else:
end_page_idx = jnp.minimum(pl.cdiv(lengths_ref[0], page_size), end_page_idx)
o, m_i, l_i = jax.lax.cond(
start_page_idx >= end_page_idx,
lambda: (o, m_i, l_i),
lambda: _compute(start_page_idx, end_page_idx, o, m_i, l_i),
)
o_ref[...] = o.astype(o_ref.dtype)
  if residual_refs:
l_ref, m_ref = residual_refs
l_ref[...] = l_i
m_ref[...] = m_i
def paged_attention_unbatched(
q: jax.Array, # [num_q_heads, head_dim]
k_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim]
v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim]
block_tables: jax.Array, # [pages_per_sequence]
lengths: jax.Array | None, # [1]
k_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size]
v_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size]
*,
block_h: int,
pages_per_compute_block: int,
k_splits: int,
num_warps: int,
num_stages: int,
interpret: bool,
debug: bool,
mask_value: float,
attn_logits_soft_cap: float | None,
) -> jax.Array:
num_q_heads, head_dim = q.shape
num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
pages_per_sequence = block_tables.shape[0]
assert (
pages_per_sequence % k_splits == 0
), f"{pages_per_sequence=} must be divisible by {k_splits=}."
pages_per_partition = pages_per_sequence // k_splits
pages_per_compute_block = min(pages_per_partition, pages_per_compute_block)
assert (
pages_per_partition % pages_per_compute_block == 0
), f"{pages_per_partition=} must de divisible by {pages_per_compute_block=}."
block_tables = block_tables.reshape(k_splits, pages_per_sequence // k_splits)
q_heads_per_kv_head = num_q_heads // num_kv_heads
q_reshaped = q.reshape(num_kv_heads, q_heads_per_kv_head, head_dim)
if q_heads_per_kv_head % block_h:
q_reshaped = jnp.pad(
q_reshaped, ((0, 0), (0, -q_heads_per_kv_head % block_h), (0, 0))
)
head_splits = pl.cdiv(q_heads_per_kv_head, block_h)
grid = (num_kv_heads, head_splits, k_splits)
kernel = functools.partial(
paged_attention_kernel,
num_heads=q_heads_per_kv_head,
pages_per_compute_block=pages_per_compute_block,
mask_value=mask_value,
attn_logits_soft_cap=attn_logits_soft_cap,
)
# set up quantization scales
if k_scales_pages is not None:
assert k_scales_pages.shape == (num_kv_heads, total_num_pages, page_size)
k_scales_spec = pl.BlockSpec((None, total_num_pages, page_size),
lambda h, i, k: (h, 0, 0))
else:
k_scales_spec = None
if v_scales_pages is not None:
assert v_scales_pages.shape == (num_kv_heads, total_num_pages, page_size)
v_scales_spec = pl.BlockSpec((None, total_num_pages, page_size),
lambda h, i, k: (h, 0, 0))
else:
v_scales_spec = None
o, l, m = pl.pallas_call(
kernel,
grid=grid,
in_specs=[
pl.BlockSpec(
(None, block_h, head_dim), lambda h, i, k: (h, i, 0)
), # q
pl.BlockSpec(
(None, total_num_pages, page_size, head_dim),
lambda h, i, k: (h, 0, 0, 0),
), # k_pages
k_scales_spec, # k_pages_scale
pl.BlockSpec(
(None, total_num_pages, page_size, head_dim),
lambda h, i, k: (h, 0, 0, 0),
), # v_pages
v_scales_spec, # v_pages_scale
pl.BlockSpec(
(None, pages_per_partition), lambda h, i, k: (k, 0)
), # block_tables
]
+ [
None if lengths is None else pl.BlockSpec((1,), lambda h, i, k: (0,))
], # lengths
out_specs=[
pl.BlockSpec(
(None, None, block_h, head_dim), lambda h, i, k: (k, h, i, 0)
          ),  # o
pl.BlockSpec((None, None, block_h), lambda h, i, k: (k, h, i)), # l
pl.BlockSpec((None, None, block_h), lambda h, i, k: (k, h, i)), # m
],
out_shape=[
jax.ShapeDtypeStruct(
(k_splits, *q_reshaped.shape), dtype=q.dtype
), # o
jax.ShapeDtypeStruct(
(k_splits, *q_reshaped.shape[:-1]), dtype=jnp.float32
), # l
jax.ShapeDtypeStruct(
(k_splits, *q_reshaped.shape[:-1]), dtype=jnp.float32
), # m
],
debug=debug,
interpret=interpret,
compiler_params=plgpu.CompilerParams(
num_warps=num_warps, num_stages=num_stages
),
name=f"paged_attention_{block_h=}_{pages_per_compute_block=}",
)(q_reshaped, k_pages, k_scales_pages, v_pages, v_scales_pages, block_tables, lengths)
if q_heads_per_kv_head % block_h:
o = o[..., :q_heads_per_kv_head, :]
l = l[..., :q_heads_per_kv_head]
m = m[..., :q_heads_per_kv_head]
# final round of flash
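  # Combine the per-split partial results: take the max of m across splits,
  # rescale each split's o and l by exp(m_split - m_max), then normalize by
  # the summed denominator l.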
m_next = m.max(axis=0)
correction = jnp.exp(m - m_next[None])
o = o * correction[..., None].astype(o.dtype)
l_next = (l * correction).sum(axis=0)
eps = jnp.finfo(l_next.dtype).eps
o = o.sum(axis=0) / ((l_next[..., None] + eps).astype(o.dtype))
o = o.reshape(q.shape).astype(q.dtype)
return o
@functools.partial(
jax.jit,
static_argnames=[
"block_h",
"pages_per_compute_block",
"k_splits",
"num_warps",
"num_stages",
"interpret",
"debug",
"mask_value",
"attn_logits_soft_cap",
],
)
def paged_attention(
q: jax.Array,
k_pages: jax.Array,
v_pages: jax.Array,
block_tables: jax.Array,
lengths: jax.Array | None,
k_scales_pages: jax.Array | None = None,
v_scales_pages: jax.Array | None = None,
*,
block_h: int = 16,
pages_per_compute_block: int = 8,
k_splits: int = 16,
num_warps: int = 8,
num_stages: int = 2,
interpret: bool = False,
debug: bool = False,
mask_value: float = DEFAULT_MASK_VALUE,
attn_logits_soft_cap: float | None = None,
) -> jax.Array:
"""Paged grouped query attention.
Args:
q: A [batch_size, num_heads, head_dim] jax.Array.
k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
block_tables: A i32[batch_size, pages_per_sequence] jax.Array. Each entry
should be in the range of [0, total_num_pages), indicating where to locate
the page in `k_pages` or `v_pages`.
lengths: A i32[batch_size] jax.Array the length of each example.
k_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array.
v_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array.
block_h: int The block size that partitions the number of head groups.
pages_per_compute_block: int The maximum number of blocks per compute block.
k_splits: int Number of partitions used to parallelize key-value sequence
pages processing.
mask_value: The value used for padding in attention. By default it is a very
negative floating point number.
attn_logits_soft_cap: The value used for soft capping the attention logits.
Returns:
The output of attention([batch_size, num_heads, head_dim]).
"""
batch_size, num_heads, head_dim = q.shape
num_kv_heads, _, _, head_dim_k = k_pages.shape
batch_size_paged_indices, _ = block_tables.shape
if k_pages.shape != v_pages.shape:
raise ValueError(
f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and"
f" {v_pages.shape}"
)
if num_heads % num_kv_heads != 0:
raise ValueError(
"Number of Q heads must be divisible by number of KV heads. Got"
f" {num_heads} and {num_kv_heads}."
)
if head_dim_k != head_dim:
raise ValueError(
"head_dim of Q must be the same as that of K/V. Got"
f" {head_dim} and {head_dim_k}."
)
if batch_size_paged_indices != batch_size:
raise ValueError("`block_tables` and `q` must have the same batch size")
if lengths is not None:
if lengths.shape != (batch_size,):
raise ValueError("`lengths` and `q` must have the same batch size")
if lengths.dtype != jnp.int32:
raise ValueError(
"The dtype of `lengths` must be int32. Got {lengths.dtype}"
)
if block_h % 16:
raise ValueError(f"block_h must divisible by 16, but is {block_h}.")
impl = functools.partial(
paged_attention_unbatched,
block_h=block_h,
pages_per_compute_block=pages_per_compute_block,
k_splits=k_splits,
num_warps=num_warps,
num_stages=num_stages,
interpret=interpret,
debug=debug,
mask_value=mask_value,
attn_logits_soft_cap=attn_logits_soft_cap,
)
o = jax.vmap(impl, (0, None, None, 0, 0, None, None), 0)(
q,
k_pages,
v_pages,
block_tables,
lengths[..., None] if lengths is not None else None,
k_scales_pages,
v_scales_pages,
)
return o
@functools.partial(
jax.jit, static_argnames=["mask_value", "attn_logits_soft_cap"]
)
def paged_attention_reference(
q: jax.Array,
k: jax.Array,
v: jax.Array,
lengths: jax.Array,
*,
mask_value: float = DEFAULT_MASK_VALUE,
attn_logits_soft_cap: float | None = None,
) -> jax.Array:
"""Grouped query attention reference implementation.
Args:
q: A [batch_size, num_heads, head_dim] jax.Array.
k: A [batch_size, kv_seq_len, num_kv_heads, head_dim] jax.Array.
v: A [batch_size, kv_seq_len, num_kv_heads, head_dim] jax.Array.
lengths: A i32[batch_size] jax.Array the length of each example.
mask_value: The value used for padding in attention. By default it is a very
negative floating point number.
attn_logits_soft_cap: The value used for soft capping the attention logits.
Returns:
The output of attention([batch_size, num_heads, head_dim]).
"""
batch_size, num_heads, head_dim = q.shape
_, kv_seq_len, num_kv_heads, _ = k.shape
q_heads_per_kv_head = num_heads // num_kv_heads
q_reshaped = q.reshape(
batch_size, num_kv_heads, q_heads_per_kv_head, head_dim
)
k_transposed = jnp.swapaxes(
k, 1, 2
) # [batch_size, num_kv_heads, kv_seq_len, head_dim]
v_transposed = jnp.swapaxes(
v, 1, 2
) # [batch_size, num_kv_heads, kv_seq_len, head_dim]
uncapped_logits = jnp.einsum(
"bkgd,bksd->bkgs", q_reshaped, k_transposed,
preferred_element_type=jnp.float32
).astype(jnp.float32)
if attn_logits_soft_cap is not None:
logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap)
logits = logits * attn_logits_soft_cap
else:
logits = uncapped_logits
if lengths is not None:
mask = jnp.arange(kv_seq_len)[None, :] < lengths[:, None]
mask = jnp.broadcast_to(mask[:, None, None, :], logits.shape)
logits = jnp.where(mask, logits, mask_value)
weights = jax.nn.softmax(logits, axis=-1)
o = jnp.einsum(
"bkgs,bksd->bkgd", weights, v_transposed.astype(jnp.float32),
preferred_element_type=jnp.float32
).astype(q.dtype)
o = o.reshape(q.shape)
return o
@@ -0,0 +1,333 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ragged dot Pallas-Mosaic-GPU implementation."""
import dataclasses
import functools
import itertools
import math
import jax
from jax import lax
from jax import numpy as jnp
from jax import random
from jax._src import test_util as jtu # noqa: F401
from jax.experimental import pallas as pl
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.pallas import mosaic_gpu as plgpu
import numpy as np
@dataclasses.dataclass(frozen=True)
class GroupInfo:
"""Information regarding the group being processed in a block."""
group_id: jax.Array
block: jax.Array
block_start: jax.Array
actual_start: jax.Array
actual_end: jax.Array
start_within_block: jax.Array
actual_size: jax.Array
@classmethod
def create(cls, group_lengths, tile, tid):
"""Get the group info for the current block."""
tile = jnp.int32(tile)
group_boundaries = [group_lengths[i] for i in range(len(group_lengths))]
# We usually only have very few groups, so we unroll the loop processing
# them. Normally we'd break out of the loop early, once we'd have found our
# boundary, but we can't do that when unrolling, so we rely on many selects
# to mask out the epilogue of the loop.
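    # Illustrative example: with group_lengths=(5, 3) and tile=4 there are
    # cdiv(8, 4) + 2 - 1 = 3 block ids:
    #   tid 0 -> group 0, block 0, rows 0..3 (actual_size=4)
    #   tid 1 -> group 0, block 1, row 4 (start_within_block=0, actual_size=1)
    #   tid 2 -> group 1, block 1, rows 5..7 (start_within_block=1, actual_size=3)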
group_end = group_start = block = group = end = jnp.array(
0, dtype=jnp.int32
)
for i, b in enumerate(group_boundaries):
# Start/end are inclusive
start = end
end = start + b
final = end - 1
start_block = lax.div(start, tile)
final_block = lax.div(final, tile)
block_end = final_block + 1
tid_begin = start_block + i
tid_end = block_end + i
# How many blocks after is our block?
this_is_group = (tid_begin <= tid) & (tid < tid_end)
block = lax.select(this_is_group, tid - tid_begin + start_block, block)
group = lax.select(this_is_group, jnp.int32(i), group)
group_start = lax.select(this_is_group, start, group_start)
group_end = lax.select(this_is_group, end, group_end)
block_start = block * tile
actual_start = jnp.maximum(group_start, block_start)
actual_end = jnp.minimum(group_end, block_start + tile)
start_within_block = actual_start - block_start
actual_size = actual_end - actual_start
return cls(
group_id=group,
block=block,
block_start=block_start,
actual_start=actual_start,
actual_end=actual_end,
start_within_block=start_within_block,
actual_size=actual_size,
)
def ragged_dot(
lhs, # (M, K)
rhs, # (G, K, N)
*,
group_sizes, # (G,)
block_m: int,
block_n: int,
block_k: int,
max_concurrent_steps: int,
grid_block_n: int,
transpose_rhs: bool = False,
load_group_sizes_to_register: bool = True,
) -> jax.Array:
if lhs.dtype != rhs.dtype:
raise NotImplementedError(
f"lhs and rhs must have the same dtype, got {lhs.dtype} and {rhs.dtype}"
)
m, k = lhs.shape
g, k2, n = rhs.shape
if transpose_rhs:
k2, n = n, k2
if group_sizes.shape[0] != g:
raise ValueError(
f"Expected group_sizes to have shape {g} but got {group_sizes.shape}"
)
if k != k2:
raise ValueError(f"lhs.shape={k} must match rhs.shape={k2}")
if k % block_k != 0:
raise ValueError(f"k={k} must be a multiple of block_k={block_k}")
def body(rows_per_expert_gmem, lhs_gmem, rhs_gmem, o_gmem):
grid_m = pl.cdiv(m, block_m) + g - 1
grid_n = pl.cdiv(n, block_n)
grid = (grid_m * grid_n,)
@plgpu.nd_loop(grid, collective_axes="sm")
def mn_loop(loop_info: plgpu.NDLoopInfo):
mi, ni = plgpu.planar_snake(
loop_info.index[0],
(grid_m, grid_n),
1,
grid_block_n,
)
group_info = GroupInfo.create(rows_per_expert_gmem, block_m, mi)
def acc_scope(acc_ref):
plgpu.emit_pipeline(
lambda _, lhs_smem, rhs_smem: plgpu.wgmma(
acc_ref,
lhs_smem,
plgpu.transpose_ref(rhs_smem, (1, 0)) if transpose_rhs else rhs_smem,
),
grid=(k // block_k,),
in_specs=[
plgpu.BlockSpec(
(block_m, block_k),
lambda k: (group_info.block, k),
delay_release=1,
),
plgpu.BlockSpec(
(block_n, block_k) if transpose_rhs else (block_k, block_n),
lambda k: (ni, k) if transpose_rhs else (k, ni),
delay_release=1,
),
],
max_concurrent_steps=max_concurrent_steps,
)(lhs_gmem, rhs_gmem.at[group_info.group_id])
return acc_ref[...]
acc = pl.run_scoped(acc_scope, plgpu.ACC((block_m, block_n)))
@functools.partial(
pl.run_scoped,
o_smem=plgpu.SMEM((block_m, block_n), dtype=o_gmem.dtype)
)
def store_scope(o_smem):
o_smem[...] = acc.astype(o_smem.dtype)
plgpu.commit_smem()
smem_start = group_info.start_within_block
remaining_rows = min(block_m, m)
# TMA descriptors need to be generated with static tile sizes along each
# axis, but we do not know at compile time how many rows we will need to
# store. We only know that the number of rows to store is bounded by
# min(block_m, m).
#
# In order to work around that, we construct a logarithmic ladder of
# TMA descriptors, where each descriptor can store 2**i rows for some
# i between 0 and log2(min(block_m, m)). This allows storing any
# number of rows we will need to store, so long as this number of rows
# is between `1` and `min(block_m, m)`.
#
# E.g., imagine we have block_m = 8, m = 16. The loop below will be
# unrolled into 4 iterations, where the first one will generate a TMA
# descriptor that can store 8 rows, the second one will generate a TMA
# descriptor that can store 4 rows, etc. all the way to 1 row.
#
# At run time, we finally know the actual number of rows we need to
# store as we go through the unrolled loop iterations. Let's imagine
# that we need to store 5 rows.
#
# The first unrolled iteration will check whether we can store 8 rows.
# Since we only need to store 5 rows, we won't store anything then.
#
# The second unrolled iteration will check whether we can store 4 rows.
# We're able to store 4 rows, and are left with a single remaining row.
#
# The fourth unrolled iteration will store the single remaining row, and
# we end up with a storing scheme as follows for our 5 rows:
#
# -----------------------------------------------------------
# 0 | |
# 1 | |
# 2 | Store 4 rows |
# 3 | |
# -----------------------------------------------------------
# 4 | Store 1 row |
# -----------------------------------------------------------
while remaining_rows > 0:
const_rows_len = 1 << int(math.log2(remaining_rows))
remaining_rows //= 2
@pl.when(group_info.actual_size & const_rows_len != 0)
def _():
o_smem_slice = o_smem.at[pl.ds(smem_start, const_rows_len)]
o_gref_slice = o_gmem.at[
pl.ds(group_info.block_start + smem_start, const_rows_len),
pl.ds(ni * block_n, block_n),
]
plgpu.copy_smem_to_gmem(o_smem_slice, o_gref_slice)
smem_start += group_info.actual_size & const_rows_len
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
  # There are 132 SMs on an H100 SXM GPU.
num_sms = 132
kernel = plgpu.kernel(
body,
out_shape=jax.ShapeDtypeStruct((m, n), lhs.dtype),
grid=(num_sms,),
grid_names=("sm",),
compiler_params=plgpu.CompilerParams(
lowering_semantics=plgpu.LoweringSemantics.Warpgroup,
),
)
return kernel(group_sizes, lhs, rhs)
def main(unused_argv):
for transpose_rhs in [False, True]:
m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
kx, ky, kz = random.split(random.key(1234), num=3)
lhs = jax.random.normal(kx, (m, k), jnp.float16)
if transpose_rhs:
rhs = jax.random.normal(ky, (num_groups, n, k), jnp.float16)
else:
rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
group_boundaries = jax.lax.sort(
jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
)
group_starts = lax.concatenate(
[jnp.array([0], dtype=jnp.int32), group_boundaries], 0
)
group_ends = lax.concatenate(
[group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
)
group_sizes = group_ends - group_starts
assert group_sizes.shape == (num_groups,)
block_m = block_n = (64, 128, 192)
block_k = (64,)
max_concurrent_steps = (2, 4, 5, 6)
grid_block_n = (1, 2, 4, 8, 16)
configs = itertools.product(
block_m, block_n, block_k, max_concurrent_steps, grid_block_n
)
names = (
"block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
)
best_runtime = float("inf")
best_kwargs: dict[str, int] = {}
for config in configs:
kwargs = dict(zip(names, config))
if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
continue
try:
f = functools.partial(
ragged_dot, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
**kwargs
)
_, runtime = profiler.measure(f)(lhs, rhs)
except ValueError as e:
if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
raise
runtime = float("inf")
# Enable this to get more detailed information.
else:
assert runtime is not None
print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
if runtime < best_runtime:
best_runtime = runtime
best_kwargs = kwargs
if not best_kwargs:
raise ValueError("No valid configuration found")
def ref_ragged_dot(lhs, rhs, group_sizes):
if transpose_rhs:
rhs = jnp.transpose(rhs, (0, 2, 1))
return jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)
ref, ref_runtime = profiler.measure(ref_ragged_dot)(
lhs, rhs, group_sizes=group_sizes
)
assert ref_runtime is not None
result = ragged_dot(
lhs, rhs, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
load_group_sizes_to_register=True,
**best_kwargs
)
np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
print(f"Transpose RHS: {transpose_rhs}")
print(
"Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
)
print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
if __name__ == "__main__":
from absl import app
jax.config.config_with_absl()
app.run(main)
@@ -0,0 +1,244 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reduce scatter kernel implemented using Mosaic GPU."""
import functools
import itertools
import math
from typing import Literal
import jax
from jax import lax
from jax.experimental import multihost_utils
from jax.experimental import pallas as pl
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.extend import backend
import jax.numpy as jnp
def reduce_scatter(
x: jax.Array,
*,
axis_name,
scatter_dimension: int | None = 0,
reduction: Literal["add", "min", "max", "and", "or", "xor"] = "add",
num_blocks: int | None = None,
tile_size: int | None = None,
vec_size: int | None = None,
) -> jax.Array:
"""Performs a reduce-scatter or all-reduce operation across devices using multimem instructions.
Args:
x: Input array. Should be sharded across the specified axis.
axis_name: Name of the mesh axis to reduce-scatter across.
scatter_dimension: Axis along which to reduce-scatter. If None, performs
all-reduce instead. Defaults to 0.
reduction: Reduction operation to perform. Supported: "add", "min", "max",
"and", "or", "xor".
vec_size: Vector size for the layout. If None, automatically inferred from dtype.
num_blocks: Number of blocks to use. Defaults to the device core count.
tile_size: Total tile size to split across major, scatter, and minor dimensions.
"""
num_devices = lax.axis_size(axis_name)
input_shape = x.shape
dtype = x.dtype
ndim = len(input_shape)
if num_blocks is None:
num_blocks = backend.get_default_device().core_count
if scatter_dimension is None:
major_dims, scatter_dim, minor_dims = 1, math.prod(input_shape), 1
output_scatter_dim = scatter_dim
output_shape = input_shape
else:
if scatter_dimension < -ndim or scatter_dimension >= ndim:
raise ValueError(
f"scatter_dimension {scatter_dimension} out of bounds for array of"
f" dimension {ndim}"
)
if scatter_dimension < 0:
scatter_dimension += ndim
scatter_dim = input_shape[scatter_dimension]
if scatter_dim % num_devices != 0:
raise ValueError(
f"Scattered dimension {scatter_dimension} of input ({scatter_dim})"
f" must be divisible by number of devices ({num_devices})"
)
major_dims = math.prod(input_shape[:scatter_dimension])
minor_dims = math.prod(input_shape[scatter_dimension+1:])
output_scatter_dim = scatter_dim // num_devices
output_shape = (
*input_shape[:scatter_dimension], output_scatter_dim, *input_shape[scatter_dimension + 1 :],
)
if (output_size := math.prod(output_shape)) % 128:
raise ValueError("Output size must be divisible by 128")
if jnp.issubdtype(dtype, jnp.integer):
if vec_size is None:
vec_size = 1 # Integer types only support unvectorized reductions
elif vec_size != 1:
raise ValueError("Integer types only support vec_size=1")
elif vec_size is None: # vec_size inference for floating point types
dtype_bits = jnp.finfo(dtype).bits
max_vec_size = min(128 // dtype_bits, output_size // 128)
if tile_size is not None:
max_vec_size_for_tile = tile_size // 128
max_vec_size = min(max_vec_size, max_vec_size_for_tile)
vec_size = 32 // dtype_bits # We don't support ld_reduce below 32-bit
while vec_size * 2 <= max_vec_size:
vec_size *= 2
if math.prod(output_shape) % vec_size:
raise ValueError(
"The total number of elements in the output"
f" ({math.prod(output_shape)}) must be divisible by the vec_size"
f" ({vec_size})"
)
min_transfer_elems = 128 * vec_size
if tile_size is None:
# TODO(apaszke): 8 is just an arbitrary unrolling factor. Tune it!
unroll_factor = min(math.prod(output_shape) // min_transfer_elems, 8)
tile_size = unroll_factor * min_transfer_elems
if tile_size < min_transfer_elems:
raise ValueError(
f"{tile_size=} is smaller than minimum required"
f" {min_transfer_elems} for {vec_size=}"
)
minor_tile = math.gcd(tile_size, minor_dims)
remaining_tile = tile_size // minor_tile
scatter_tile = math.gcd(remaining_tile, output_scatter_dim)
major_tile = remaining_tile // scatter_tile
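  # For example (illustrative numbers): with tile_size=4096, minor_dims=512 and
  # output_scatter_dim=128, this yields minor_tile=512, scatter_tile=8 and
  # major_tile=1, i.e. each tile covers 8 rows of the scattered dimension.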
if major_dims % major_tile != 0:
raise NotImplementedError(
f"Major dimension size ({major_dims}) must be divisible by the"
f" inferred major tile size ({major_tile}). Consider adjusting tile_size."
)
def kernel(x_ref, y_ref, done_barrier):
dev_idx = lax.axis_index(axis_name)
x_ref_3d = x_ref.reshape((major_dims, scatter_dim, minor_dims))
y_ref_3d = y_ref.reshape((major_dims, output_scatter_dim, minor_dims))
if scatter_dimension is not None:
dev_slice = pl.ds(dev_idx * output_scatter_dim, output_scatter_dim)
x_ref_3d = x_ref_3d.at[:, dev_slice, :]
major_tiles = major_dims // major_tile
scatter_tiles = output_scatter_dim // scatter_tile
minor_tiles = minor_dims // minor_tile
@plgpu.nd_loop((major_tiles, scatter_tiles, minor_tiles), collective_axes="blocks")
def _transfer_loop(loop_info: plgpu.NDLoopInfo):
major_tile_idx, scatter_tile_idx, minor_tile_idx = loop_info.index
idxs = (
pl.ds(major_tile_idx * major_tile, major_tile),
pl.ds(scatter_tile_idx * scatter_tile, scatter_tile),
pl.ds(minor_tile_idx * minor_tile, minor_tile)
)
y_ref_3d[idxs] = plgpu.layout_cast(
plgpu.multimem_load_reduce(
x_ref_3d.at[idxs], collective_axes=axis_name, reduction_op=reduction
),
plgpu.Layout.WG_STRIDED((major_tile, scatter_tile, minor_tile), vec_size=vec_size)
)
# Wait for everyone to finish reading the operands before we exit and potentially free them
plgpu.semaphore_signal_multicast(done_barrier, collective_axes=axis_name)
pl.semaphore_wait(done_barrier, num_devices, decrement=False)
return plgpu.kernel(
kernel,
out_shape=jax.ShapeDtypeStruct(output_shape, dtype),
grid=(num_blocks,),
grid_names=("blocks",),
scratch_shapes=[plgpu.SemaphoreType.REGULAR],
)(x)
def _run_example():
P = jax.sharding.PartitionSpec
shape = (4 * 4096, 4 * 4096) # This shape is global!
dtype = jnp.bfloat16
shards = jax.device_count()
mesh = jax.make_mesh(
(shards,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
)
jax.set_mesh(mesh)
# We measure time per-shard and so we only need bytes per shard.
local_in_bytes = math.prod(shape) / shards * jnp.dtype(dtype).itemsize
# In reduce-scatter, we send (shards - 1) / shards worth of input data to the
# switch and receive as much data as in the whole output, which is 1 / shards.
total_bytes = local_in_bytes
a = jax.random.normal(jax.random.key(1), shape, dtype)
a = jax.sharding.reshard(a, P(None, "x"))
@jax.jit
@functools.partial(jax.shard_map, mesh=mesh, in_specs=P(None, "x"), out_specs=P(None, "x"))
def ref_fn(x):
return lax.psum_scatter(x, "x", scatter_dimension=1, tiled=True)
ref_fn(a).block_until_ready() # Warmup.
_, ref_kernels_ms = profiler.measure(ref_fn, aggregate=False)(a)
assert ref_kernels_ms is not None
ref_time_us = sum(t * 1e3 for _, t in ref_kernels_ms)
# We take the minimum across processes so that the measurement excludes time
# spent waiting for other devices.
ref_time_us = min(multihost_utils.process_allgather(ref_time_us).tolist())
ref_bw = total_bytes / (ref_time_us * 1e-6) / 1e9 # GB/s
tuning_it = itertools.product(
(4, 8, 16, 32, 64, 132), # num_blocks
(1024, 2048, 4096, 8192), # tile_size
)
best_bw = 0.0
best_runtime = float("inf")
for num_blocks, tile_size in tuning_it:
try:
@jax.jit
@functools.partial(
jax.shard_map, mesh=mesh, in_specs=P(None, "x"), out_specs=P(None, None), check_vma=False
)
def kernel_fn(x):
return all_gather(x, axis_name="x", gather_dimension=1, num_blocks=num_blocks, tile_size=tile_size)
kernel_fn(a).block_until_ready() # Warmup.
_, kernels_ms = profiler.measure(kernel_fn, aggregate=False)(a)
except ValueError as e:
if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs.
continue
raise
assert kernels_ms is not None
runtime_us = sum(t * 1e3 for _, t in kernels_ms)
runtime_us = min(multihost_utils.process_allgather(runtime_us).tolist())
achieved_bw = total_bytes / (runtime_us * 1e-6) / 1e9 # GB/s
if achieved_bw > best_bw:
best_runtime = runtime_us
best_bw = achieved_bw
print(f"{num_blocks=}, {tile_size=}: {runtime_us:<7.1f}us = {achieved_bw:4.1f} GB/s")
print(f"Total bytes transferred: {total_bytes / 1e9:.2f} GB")
print(f"\tBest: {best_runtime:<7.1f}us = {best_bw:4.1f} GB/s")
print(f"\tRef: {ref_time_us:<7.1f}us = {ref_bw:4.1f} GB/s")
if __name__ == "__main__":
from jax._src import test_multiprocess as jt_multiprocess # pytype: disable=import-error
jt_multiprocess.main(shard_main=_run_example)
@@ -0,0 +1,310 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module containing rms forward and backward pass."""
from __future__ import annotations
import functools
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
import jax.numpy as jnp
def rms_norm_forward_kernel(
x_ref, weight_ref, bias_ref, # Input arrays
o_ref, rstd_ref=None, # Output arrays
*, eps: float, block_size: int):
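# Computes RMS normalization: out = x / sqrt(mean(x**2) + eps) * weight + bias.
# The kernel makes two passes over the row: one to accumulate the mean of
# squares, and one to normalize, scale and shift.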
n_col = x_ref.shape[0]
def var_body(i, acc):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = plgpu.load(
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
a = jnp.where(mask, a, 0.)
return acc + a * a
var = lax.fori_loop(
0, pl.cdiv(n_col, block_size), var_body, init_val=jnp.zeros(block_size)
).sum()
var /= n_col
rstd = 1 / jnp.sqrt(var + eps)
if rstd_ref is not None:
rstd_ref[...] = rstd.astype(rstd_ref.dtype)
@pl.loop(0, pl.cdiv(n_col, block_size))
def body(i):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
weight = plgpu.load(weight_ref.at[col_idx], mask=mask)
bias = plgpu.load(bias_ref.at[col_idx], mask=mask)
x = plgpu.load(
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_first"
).astype(jnp.float32)
out = x * rstd * weight + bias
plgpu.store(o_ref.at[col_idx], out.astype(o_ref.dtype), mask=mask)
def rms_norm_forward(
x, weight, bias,
num_warps: int | None = None,
num_stages: int | None = 3,
eps: float = 1e-5,
backward_pass_impl: str = 'triton',
interpret: bool = False):
del num_stages
del backward_pass_impl
n = x.shape[-1]
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
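# For example (illustrative numbers): for float32 with n = 512 features,
# max_fused_size = 65536 // 4 = 16384, block_size = 512 and num_warps = 2.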
kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = [
jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
]
method = pl.pallas_call(
kernel,
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=(),
out_shape=out_shape,
debug=False,
interpret=interpret,
name="rms_forward",
)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
out, rstd = method(x, weight, bias)
return out, (x, weight, bias, rstd)
def rms_norm_backward_kernel_dx(
# Inputs
x_ref, weight_ref, bias_ref, do_ref,
rstd_ref,
# Outputs
dx_ref,
*, eps: float, block_size: int):
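# Backward pass w.r.t. x. With x_hat = x * rstd and rstd = 1/sqrt(mean(x**2) + eps),
# dx = (weight * dout - x_hat * mean(x_hat * weight * dout)) * rstd,
# where the mean runs over the feature dimension; c1 below is that mean.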
n_col = x_ref.shape[0]
def mean_body(i, c1_acc):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = plgpu.load(
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
dout = plgpu.load(
do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
weight = plgpu.load(
weight_ref.at[col_idx],
mask=mask,
other=0.0,
eviction_policy="evict_last",
).astype(jnp.float32)
a_hat = a * rstd_ref[...]
wdout = weight * dout
return c1_acc + a_hat * wdout
c1 = lax.fori_loop(
0, pl.cdiv(n_col, block_size), mean_body, jnp.zeros(block_size)
)
c1 = c1.sum() / n_col
@pl.loop(0, pl.cdiv(n_col, block_size))
def dx_body(i):
col_idx = i * block_size + jnp.arange(block_size)
mask = col_idx < n_col
a = plgpu.load(
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
dout = plgpu.load(
do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
).astype(jnp.float32)
weight = plgpu.load(
weight_ref.at[col_idx],
mask=mask,
other=0.0,
eviction_policy="evict_last",
).astype(jnp.float32)
a_hat = a * rstd_ref[...]
wdout = weight * dout
da = (wdout - (a_hat * c1)) * rstd_ref[...]
plgpu.store(dx_ref.at[col_idx], da.astype(dx_ref.dtype), mask=mask)
def rms_norm_backward_kernel_dw_db(
# Inputs
x_ref, weight_ref, bias_ref, do_ref,
rstd_ref,
# Outputs
dw_ref, db_ref,
*, eps: float, block_m: int, block_n: int):
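# Accumulates the weight and bias gradients over the batch rows:
# dw = sum_rows(dout * x_hat) and db = sum_rows(dout), with x_hat = x * rstd.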
m, n_col = x_ref.shape
j = pl.program_id(0)
col_idx = j * block_n + jnp.arange(block_n)
col_mask = col_idx < n_col
def body(i, acc):
row_idx = i * block_m + jnp.arange(block_m)
row_mask = row_idx < m
mask = row_mask[:, None] & col_mask[None, :]
a = plgpu.load(
x_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
).astype(jnp.float32)
dout = plgpu.load(
do_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
).astype(jnp.float32)
rstd = plgpu.load(rstd_ref.at[row_idx], mask=row_mask, other=0.0).astype(
jnp.float32
)
a_hat = a * rstd[:, None]
dw_acc, db_acc = acc
return (dw_acc + (dout * a_hat).sum(axis=0), db_acc + dout.sum(axis=0))
dw_acc, db_acc = lax.fori_loop(
0,
pl.cdiv(m, block_m),
body,
init_val=(jnp.zeros(block_n), jnp.zeros(block_n)),
)
plgpu.store(dw_ref.at[col_idx], dw_acc.astype(dw_ref.dtype), mask=col_mask)
plgpu.store(db_ref.at[col_idx], db_acc.astype(db_ref.dtype), mask=col_mask)
def rms_norm_backward(
num_warps: int | None,
num_stages: int | None,
eps: float,
backward_pass_impl: str,
interpret: bool,
res, do):
del num_stages
x, weight, bias, rstd = res
if backward_pass_impl == 'xla':
return jax.vjp(rms_norm_reference, x, weight, bias)[1](do)
*shape_prefix, n = x.shape
reshaped_x = x.reshape((-1, n))
reshaped_rstd = rstd.reshape((-1,))
reshaped_do = do.reshape((-1, n))
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
# rms_norm_backward_kernel_dx parallel over batch dims
kernel = functools.partial(rms_norm_backward_kernel_dx, eps=eps,
block_size=block_size)
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=(),
out_shape=out_shape_dx,
debug=False,
interpret=interpret,
name="ln_backward_dx",
)
method = jax.vmap(method, in_axes=(0, None, None, 0, 0))
dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
dx = dx.reshape((*shape_prefix, n))
# rms_norm_backward_kernel_dw_db reduce over batch dims
# Triton heuristics
if n > 10240:
block_n = 128
block_m = 32
num_warps = 4
else:
# maximize occupancy for small N
block_n = 16
block_m = 16
num_warps = 8
kernel = functools.partial(rms_norm_backward_kernel_dw_db, eps=eps,
block_m=block_m, block_n=block_n)
out_shape_dwbias = [
jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype),
jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
]
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
grid=grid_,
out_shape=out_shape_dwbias,
debug=False,
interpret=interpret,
name="ln_backward_dw_db",
)
dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
return dx, dw, dbias
@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7])
@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages", "eps",
                                             "backward_pass_impl", "interpret"])
def rms_norm(
x, weight, bias,
num_warps: int | None = None,
num_stages: int | None = 3,
eps: float = 1e-5,
backward_pass_impl: str = 'triton',
interpret: bool = False):
n = x.shape[-1]
# Triton heuristics
# Less than 64KB per feature: enqueue fused kernel
max_fused_size = 65536 // x.dtype.itemsize
block_size = min(max_fused_size, pl.next_power_of_2(n))
block_size = min(max(block_size, 128), 4096)
num_warps = min(max(block_size // 256, 1), 8)
kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
block_size=block_size)
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
method = pl.pallas_call(
kernel,
compiler_params=plgpu.CompilerParams(
num_warps=num_warps, num_stages=num_stages
),
grid=(),
out_shape=out_shape,
debug=False,
interpret=interpret,
)
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
return method(x, weight, bias)
rms_norm.defvjp(rms_norm_forward, rms_norm_backward)
@functools.partial(jax.jit, static_argnames=["eps"])
@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
def rms_norm_reference(x, weight, bias, *, eps: float = 1e-5):
var = jnp.mean(jnp.square(x), axis=1)
mul = lax.rsqrt(var + eps)
return x * mul[:, None] * weight[None] + bias[None]
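# A minimal usage sketch (not part of the library API above): the shapes are
# illustrative, and the default Triton path assumes a GPU is available.
if __name__ == "__main__":
  import numpy as np
  x = jax.random.normal(jax.random.key(0), (4, 128, 512), jnp.float32)
  weight = jnp.ones((512,), jnp.float32)
  bias = jnp.zeros((512,), jnp.float32)
  out = rms_norm(x, weight, bias)
  ref = rms_norm_reference(x, weight, bias)
  np.testing.assert_allclose(out, ref, atol=1e-2, rtol=1e-2)
  print("rms_norm matches the XLA reference")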
@@ -0,0 +1,94 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pallas softmax kernel."""
import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
def _vmappable_softmax_kernel(
# inputs
input_ref,
# outputs
probs_ref,
*,
# block information
# It is assumed that block_row >= row_len
block_row: int,
):
row_len = input_ref.shape[-1]
mask = jnp.arange(block_row) < row_len
row = plgpu.load(
input_ref.at[pl.ds(0, block_row)], mask=mask, other=-float("inf")
)
row_max = jnp.max(row, axis=0)
numerator = jnp.exp((row - row_max).astype(jnp.float32))
denominator = jnp.sum(numerator, axis=0)
plgpu.store(
probs_ref.at[pl.ds(0, block_row)],
(numerator / denominator).astype(probs_ref.dtype),
mask=mask
)
@functools.partial(jax.jit, static_argnames=["axis", "num_warps", "interpret",
"debug"])
def softmax(
x: jax.Array, *, axis: int = -1, num_warps: int = 4,
interpret: bool = False, debug: bool = False
) -> jax.Array:
"""Computes the softmax of the input array along the specified axis.
Args:
x: input array
axis: the axis along which to perform the computation
num_warps: the number of warps to use for executing the Triton kernel
interpret: whether to interpret the kernel using pallas
debug: whether to use pallas in debug mode
Returns:
The result of the softmax operation over the specified axis of x.
"""
axis = axis if axis >= 0 else len(x.shape) + axis
if axis != len(x.shape) - 1:
raise NotImplementedError(
"reductions along non-trailing dimension unsupported")
row_len = x.shape[-1]
block_row = pl.next_power_of_2(row_len)
out_shape = jax.ShapeDtypeStruct(shape=(row_len,), dtype=x.dtype)
kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
f = pl.pallas_call(
kernel,
compiler_params=plgpu.CompilerParams(
num_warps=num_warps, num_stages=1),
grid=(),
out_shape=out_shape,
debug=debug,
interpret=interpret,
)
for _ in range(len(x.shape) - 1):
f = jax.vmap(f)
return f(x)
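# A minimal usage sketch (not part of the library API above): the input shape
# is illustrative, and interpret=True is used so the check can also run
# without a GPU.
if __name__ == "__main__":
  import numpy as np
  x = jax.random.normal(jax.random.key(0), (8, 256), jnp.float32)
  out = softmax(x, axis=-1, interpret=True)
  np.testing.assert_allclose(out, jax.nn.softmax(x, axis=-1), atol=1e-6, rtol=1e-6)
  print("softmax matches jax.nn.softmax")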
@@ -0,0 +1,302 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transposed ragged dot Pallas-Mosaic-GPU implementation."""
import functools
import itertools
import jax
from jax import lax
from jax import numpy as jnp
from jax import random
from jax._src import test_util as jtu # noqa: F401
from jax.experimental import pallas as pl
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.pallas import mosaic_gpu as plgpu
import numpy as np
def transposed_ragged_dot(
lhs, # (K, M)
rhs, # (K, N)
*,
group_sizes, # (G,)
block_m: int,
block_n: int,
block_k: int,
max_concurrent_steps: int,
grid_block_n: int,
) -> jax.Array:
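# For each group g with rows [start_g, end_g) along K, this computes
# out[g] = lhs[start_g:end_g, :].T @ rhs[start_g:end_g, :], so the output has
# shape (G, M, N), matching ref_transposed_ragged_dot below.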
if lhs.dtype != rhs.dtype:
raise NotImplementedError(
f"lhs and rhs must have the same dtype, got {lhs.dtype} and {rhs.dtype}"
)
k, m = lhs.shape
k2, n = rhs.shape
g = group_sizes.shape[0]
if k != k2:
raise ValueError(f"lhs.shape={k} must match rhs.shape={k2}")
if m % block_m != 0:
raise ValueError(f"m={m} must be a multiple of block_m={block_m}")
if n % block_n != 0:
raise ValueError(f"n={n} must be a multiple of block_n={block_n}")
group_sizes = group_sizes.astype(int)
group_starts = jnp.concatenate(
[jnp.zeros(1, dtype=int), jnp.cumsum(group_sizes)[:-1]]
).astype(int)
group_ends = jnp.cumsum(group_sizes)
group_block_starts = group_starts // block_k * block_k
group_block_ends = -(group_ends // -block_k) * block_k
group_num_blocks = (group_block_ends - group_block_starts) // block_k
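# For example (illustrative numbers): with block_k = 64 and
# group_sizes = [100, 60], group_starts = [0, 100] and group_ends = [100, 160],
# so group_block_starts = [0, 64], group_block_ends = [128, 192] and
# group_num_blocks = [2, 2]; the boundary blocks shared between neighboring
# groups are masked inside the kernel.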
swizzle = plgpu.find_swizzle(block_k * jnp.dtype(lhs.dtype).itemsize * 8)
swizzle_elems = swizzle // jnp.dtype(lhs.dtype).itemsize
transforms = (
plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle)
)
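# For example (illustrative): for float16 with block_k = 64, the minor
# dimension spans 64 * 2 = 128 bytes, so find_swizzle picks the 128-byte
# swizzle and swizzle_elems = 64.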
def body(
group_sizes_gmem,
group_starts_gmem,
group_ends_gmem,
group_num_blocks_gmem,
group_block_starts_gmem,
lhs_gmem,
rhs_gmem,
o_gmem,
):
grid_m = pl.cdiv(m, block_m)
grid_n = pl.cdiv(n, block_n)
@plgpu.nd_loop((g, grid_m * grid_n), collective_axes="sm")
def mn_loop(loop_info: plgpu.NDLoopInfo):
g_i = loop_info.index[0]
m_i, n_i = plgpu.planar_snake(
loop_info.index[1],
(grid_m, grid_n),
1,
grid_block_n,
)
# This slice is potentially out of bounds, but we never access the
# out-of-bounds part in emit_pipeline.
gmem_slice = pl.ds(group_block_starts_gmem[g_i], k)
def acc_scope(acc_ref):
def block_matmul(block_idx, lhs_smem, rhs_smem):
block_idx = block_idx[0]
@pl.when(block_idx == 0)
def _():
# Handles the first block of the group, where there might be
# data from the previous group at the beginning of the block.
lhs_reg = lhs_smem[...]
start_index = lax.rem(group_starts_gmem[g_i], block_k)
indices = plgpu.layout_cast(
jax.lax.broadcasted_iota(jnp.int32, (block_k, block_m), 0),
plgpu.Layout.WGMMA
)
lhs_mask = (indices >= start_index).astype(lhs_smem.dtype)
lhs_reg = lhs_reg * lhs_mask
lhs_smem[...] = lhs_reg
plgpu.commit_smem()
@pl.when(block_idx == group_num_blocks_gmem[g_i] - 1)
def _():
# Handles the last block of the group, where there might be
# data from the next group at the end of the block.
lhs_reg = lhs_smem[...]
last_index = lax.rem(group_ends_gmem[g_i] - 1, block_k)
indices = plgpu.layout_cast(
jax.lax.broadcasted_iota(jnp.int32, (block_k, block_m), 0),
plgpu.Layout.WGMMA
)
lhs_mask = (indices <= last_index).astype(lhs_smem.dtype)
lhs_reg = lhs_reg * lhs_mask
lhs_smem[...] = lhs_reg
plgpu.commit_smem()
plgpu.wgmma(acc_ref, plgpu.transpose_ref(lhs_smem, (1, 0)), rhs_smem)
if max_concurrent_steps == 1:
# Without delayed release, we won't have at least two separate
# smem blocks in flight. Therefore, we cannot rely on the implicit
# wait of wgmma to guarantee that the data in smem is ready to be
# overwritten by the next pipeline iteration.
plgpu.wgmma_wait(0)
@pl.when(group_sizes_gmem[g_i] > 0) # Skip the group if it is empty.
def _():
plgpu.emit_pipeline(
block_matmul,
grid=(group_num_blocks_gmem[g_i],),
in_specs=[
plgpu.BlockSpec(
(block_k, block_m),
lambda k_i: (k_i, m_i),
delay_release=1 if max_concurrent_steps > 1 else 0,
transforms=transforms,
),
plgpu.BlockSpec(
(block_k, block_n),
lambda k_i: (k_i, n_i),
delay_release=1 if max_concurrent_steps > 1 else 0,
transforms=transforms,
),
],
max_concurrent_steps=max_concurrent_steps,
)(lhs_gmem.at[gmem_slice, :], rhs_gmem.at[gmem_slice, :])
return acc_ref[...]
acc = pl.run_scoped(acc_scope, plgpu.ACC((block_m, block_n)))
@functools.partial(
pl.run_scoped,
o_smem=plgpu.SMEM(
(block_m, block_n),
dtype=o_gmem.dtype,
transforms=transforms,
)
)
def store_scope(o_smem):
o_smem[...] = acc.astype(o_smem.dtype)
plgpu.commit_smem()
plgpu.copy_smem_to_gmem(
o_smem, o_gmem.at[
g_i,
pl.ds(m_i * block_m, block_m),
pl.ds(n_i * block_n, block_n)
]
)
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
# There are 132 SMs on an H100 SXM GPU.
num_sms = jax.devices()[0].core_count
kernel = plgpu.kernel(
body,
out_shape=jax.ShapeDtypeStruct((g, m, n), lhs.dtype),
grid=(num_sms,),
grid_names=("sm",),
)
return kernel(
group_sizes,
group_starts,
group_ends,
group_num_blocks,
group_block_starts,
lhs,
rhs,
)
def ref_transposed_ragged_dot(lhs, rhs, group_sizes):
return jax.lax.ragged_dot_general(
lhs, rhs, group_sizes,
ragged_dot_dimension_numbers=jax.lax.RaggedDotDimensionNumbers(
dot_dimension_numbers=(((0,), (0,)), ((), ())),
lhs_ragged_dimensions=[0],
rhs_group_dimensions=[],
)
)
def main(unused_argv):
k, m, n, num_groups = 16 * 1024, 2048, 2048, 16
kx, ky, kz = random.split(random.key(1234), num=3)
lhs = jax.random.normal(kx, (k, m), jnp.float16)
rhs = jax.random.normal(ky, (k, n), jnp.float16)
group_boundaries = jax.lax.sort(
jax.random.randint(kz, (num_groups - 1,), 0, k, jnp.int32)
)
group_starts = lax.concatenate(
[jnp.array([0], dtype=jnp.int32), group_boundaries], 0
)
group_ends = lax.concatenate(
[group_boundaries, jnp.array([k], dtype=jnp.int32)], 0
)
group_sizes = group_ends - group_starts
assert group_sizes.shape == (num_groups,)
block_m = block_n = [64, 128]
block_k = [64, 128]
max_concurrent_steps = [1, 2, 4, 5, 6]
grid_block_n = [1, 2, 4, 8, 16]
configs = itertools.product(
block_m, block_n, block_k, max_concurrent_steps, grid_block_n
)
names = (
"block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n",
)
best_runtime = float("inf")
best_kwargs = {}
for config in configs:
kwargs = dict(zip(names, config))
if n % kwargs["block_n"]:
continue
try:
f = functools.partial(
transposed_ragged_dot, group_sizes=group_sizes,
**kwargs
)
_, runtime = profiler.measure(f)(lhs, rhs)
except ValueError as e:
if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
raise
runtime = float("inf")
# The following branch prints detailed per-configuration results.
else:
assert runtime is not None
print(
" ".join(f"{k}={v}" for k, v in kwargs.items()),
f"{int(runtime * 1000):.1f} us",
)
assert runtime is not None
assert best_runtime is not None
if runtime < best_runtime:
best_runtime = runtime
best_kwargs = kwargs
if not best_kwargs:
raise ValueError("No valid configuration found")
ref, ref_runtime = profiler.measure(ref_transposed_ragged_dot)(
lhs, rhs, group_sizes=group_sizes
)
result = transposed_ragged_dot(
lhs, rhs, group_sizes=group_sizes, **best_kwargs
)
assert ref_runtime is not None
tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
print(
"Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
)
print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
if __name__ == "__main__":
from absl import app
jax.config.config_with_absl()
app.run(main)