@@ -0,0 +1,13 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
BIN: 17 binary files not shown.
@@ -0,0 +1,228 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All-gather kernel implemented using Mosaic GPU."""

from collections.abc import Hashable
import functools
import itertools
import math

import jax
from jax import lax
from jax.experimental import multihost_utils
from jax.experimental import pallas as pl
from jax.experimental.mosaic.gpu import profiler
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.extend import backend
import jax.numpy as jnp


def all_gather(
    x: jax.Array,
    *,
    axis_name: Hashable,
    gather_dimension: int = 0,
    num_blocks: int | None = None,
    tile_size: int | None = None,
    vec_size: int | None = None,
) -> jax.Array:
  """Performs an all-gather operation using multimem instructions.

  Args:
    x: Input array. Should be sharded across the specified axis.
    axis_name: Name of the mesh axis to all-gather across.
    gather_dimension: Axis along which to gather.
    num_blocks: Number of blocks to use. Defaults to the device core count.
    tile_size: Total tile size to split across major, gather, and minor
      dimensions.
    vec_size: Vector size for the layout. If None, automatically inferred
      from dtype.
  """
  num_devices = lax.axis_size(axis_name)
  input_shape = x.shape
  dtype = x.dtype
  ndim = len(input_shape)

  if num_blocks is None:
    num_blocks = backend.get_default_device().core_count

  if gather_dimension < -ndim or gather_dimension >= ndim:
    raise ValueError(
        f"gather_dimension {gather_dimension} out of bounds for array of rank"
        f" {ndim}"
    )
  if gather_dimension < 0:
    gather_dimension += ndim

  input_gather_dim = input_shape[gather_dimension]
  major_dims = math.prod(input_shape[:gather_dimension])
  minor_dims = math.prod(input_shape[gather_dimension + 1:])
  output_gather_dim = input_gather_dim * num_devices
  output_shape = (
      *input_shape[:gather_dimension],
      output_gather_dim,
      *input_shape[gather_dimension + 1:],
  )

  if (output_size := math.prod(output_shape)) % 128:
    raise ValueError("Output size must be divisible by 128")
  if jnp.issubdtype(dtype, jnp.integer):
    if vec_size is None:
      vec_size = 1  # Integer types only support unvectorized operations.
    elif vec_size != 1:
      raise ValueError("Integer types only support vec_size=1")
  elif vec_size is None:  # vec_size inference for floating point types.
    dtype_bits = jnp.finfo(dtype).bits
    max_vec_size = min(128 // dtype_bits, output_size // 128)
    if tile_size is not None:
      max_vec_size_for_tile = tile_size // 128
      max_vec_size = min(max_vec_size, max_vec_size_for_tile)
    vec_size = 32 // dtype_bits  # We don't support multimem below 32-bit.
    while vec_size * 2 <= max_vec_size:
      vec_size *= 2
  if math.prod(output_shape) % vec_size:
    raise ValueError(
        "The total number of elements in the output"
        f" ({math.prod(output_shape)}) must be divisible by the vec_size"
        f" ({vec_size})"
    )

  min_transfer_elems = 128 * vec_size
  if tile_size is None:
    # TODO(apaszke): 8 is just an arbitrary unrolling factor. Tune it!
    unroll_factor = min(math.prod(input_shape) // min_transfer_elems, 8)
    tile_size = unroll_factor * min_transfer_elems
  if tile_size < min_transfer_elems:
    raise ValueError(
        f"{tile_size=} is smaller than minimum required"
        f" {min_transfer_elems} for {vec_size=}"
    )

  minor_tile = math.gcd(tile_size, minor_dims)
  remaining_tile = tile_size // minor_tile
  gather_tile = math.gcd(remaining_tile, input_gather_dim)
  major_tile = remaining_tile // gather_tile

  if major_dims % major_tile != 0:
    raise NotImplementedError(
        f"Major dimension size ({major_dims}) must be divisible by the"
        f" inferred major tile size ({major_tile}). Consider adjusting"
        " tile_size."
    )

  def kernel(x_ref, y_ref, done_barrier):
    dev_idx = lax.axis_index(axis_name)
    x_ref_3d = x_ref.reshape((major_dims, input_gather_dim, minor_dims))
    y_ref_3d = y_ref.reshape((major_dims, output_gather_dim, minor_dims))
    y_ref_3d = y_ref_3d.at[
        :, pl.ds(dev_idx * input_gather_dim, input_gather_dim), :
    ]

    major_tiles = major_dims // major_tile
    gather_tiles = input_gather_dim // gather_tile
    minor_tiles = minor_dims // minor_tile
    # TODO(apaszke): Use a TMA pipeline.
    @plgpu.nd_loop(
        (major_tiles, gather_tiles, minor_tiles), collective_axes="blocks"
    )
    def _transfer_loop(loop_info: plgpu.NDLoopInfo):
      major_tile_idx, gather_tile_idx, minor_tile_idx = loop_info.index
      idxs = (
          pl.ds(major_tile_idx * major_tile, major_tile),
          pl.ds(gather_tile_idx * gather_tile, gather_tile),
          pl.ds(minor_tile_idx * minor_tile, minor_tile),
      )
      output_data = plgpu.layout_cast(
          x_ref_3d[idxs],
          plgpu.Layout.WG_STRIDED(
              (major_tile, gather_tile, minor_tile), vec_size=vec_size
          ),
      )
      plgpu.multimem_store(output_data, y_ref_3d.at[idxs], axis_name)

    # Wait for everyone to finish storing into our memory before returning.
    plgpu.semaphore_signal_multicast(done_barrier, collective_axes=axis_name)
    pl.semaphore_wait(done_barrier, num_devices, decrement=False)

    # TODO(b/448323639): We fake modify the input to ensure that XLA:GPU copies
    # the operand into symmetric memory.
    @pl.when(dev_idx == -1)
    def _never():
      x_ref[(0,) * len(x_ref.shape)] = jnp.asarray(0, x_ref.dtype)

  return plgpu.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct(output_shape, dtype),
      grid=(num_blocks,),
      grid_names=("blocks",),
      scratch_shapes=[plgpu.SemaphoreType.REGULAR],
  )(x)


def _run_example():
  P = jax.sharding.PartitionSpec
  shape = (4 * 4096, 4 * 4096)  # This shape is global!
  dtype = jnp.bfloat16
  shards = jax.device_count()
  mesh = jax.make_mesh(
      (shards,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
  )
  jax.set_mesh(mesh)

  # We measure time per-shard and so we only need bytes per shard.
  local_out_bytes = math.prod(shape) * jnp.dtype(dtype).itemsize
  total_bytes = local_out_bytes

  a = jax.random.normal(jax.random.key(1), shape, dtype)
  a = jax.sharding.reshard(a, P("x", None))

  @jax.jit
  @functools.partial(
      jax.shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)
  )
  def ref_fn(x):
    return lax.all_gather(x, "x", axis=0, tiled=True)
  ref_fn(a).block_until_ready()  # Warmup.
  _, ref_kernels_ms = profiler.measure(ref_fn, aggregate=False)(a)
  assert ref_kernels_ms is not None
  ref_time_us = sum(t * 1e3 for _, t in ref_kernels_ms)
  # We choose the minimum across processes to choose the runtime that didn't
  # include devices waiting for other devices.
  ref_time_us = min(multihost_utils.process_allgather(ref_time_us).tolist())
  ref_bw = total_bytes / (ref_time_us * 1e-6) / 1e9  # GB/s

  tuning_it = itertools.product(
      (4, 8, 16, 32, 64, 132),  # num_blocks
      (1024, 2048, 4096, 8192),  # tile_size
  )
  best_bw = 0.0
  best_runtime = float("inf")
  for num_blocks, tile_size in tuning_it:
    @jax.jit
    @functools.partial(
        jax.shard_map, mesh=mesh, in_specs=P("x", None),
        out_specs=P("x", None), check_vma=False,
    )
    def kernel_fn(x):
      return all_gather(
          x, axis_name="x", gather_dimension=0, num_blocks=num_blocks,
          tile_size=tile_size,
      )
    try:
      _, kernels_ms = profiler.measure(kernel_fn, aggregate=False)(a)
    except ValueError as e:
      if "exceeds available shared memory" in e.args[0]:  # Ignore SMEM OOMs.
        continue
      raise
    assert kernels_ms is not None
    runtime_us = sum(t * 1e3 for _, t in kernels_ms)
    runtime_us = min(multihost_utils.process_allgather(runtime_us).tolist())
    achieved_bw = total_bytes / (runtime_us * 1e-6) / 1e9  # GB/s
    if achieved_bw > best_bw:
      best_runtime = runtime_us
      best_bw = achieved_bw
    print(f"{num_blocks=}, {tile_size=}: {runtime_us:<7.1f}us = {achieved_bw:4.1f} GB/s")

  print(f"Total bytes transferred: {total_bytes / 1e9:.2f} GB")
  print(f"\tBest: {best_runtime:<7.1f}us = {best_bw:4.1f} GB/s")
  print(f"\tRef: {ref_time_us:<7.1f}us = {ref_bw:4.1f} GB/s")


if __name__ == "__main__":
  from jax._src import test_multiprocess as jt_multiprocess  # pytype: disable=import-error
  jt_multiprocess.main(shard_main=_run_example)
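
# Editor's note: the helper below is an illustrative sketch, not part of the
# original change. It mirrors the vec_size inference and tile split performed
# in `all_gather` above for one assumed configuration (a bf16 shard of local
# shape (8, 4096), gathered over 4 devices along dimension 0), so the
# arithmetic can be checked outside the kernel.
def _tile_decomposition_example():
  num_devices = 4
  major_dims, input_gather_dim, minor_dims = 1, 8, 4096  # local (8, 4096)
  dtype_bits = 16  # bf16
  output_size = major_dims * input_gather_dim * num_devices * minor_dims
  # vec_size inference (floating-point path): start at 32 bits worth of
  # elements and double while it stays within the limits used above.
  max_vec_size = min(128 // dtype_bits, output_size // 128)
  vec_size = 32 // dtype_bits
  while vec_size * 2 <= max_vec_size:
    vec_size *= 2
  min_transfer_elems = 128 * vec_size
  local_elems = major_dims * input_gather_dim * minor_dims
  tile_size = min(local_elems // min_transfer_elems, 8) * min_transfer_elems
  # Tile split: minor dimension first, then the gather dimension, then major.
  minor_tile = math.gcd(tile_size, minor_dims)
  gather_tile = math.gcd(tile_size // minor_tile, input_gather_dim)
  major_tile = tile_size // minor_tile // gather_tile
  assert (vec_size, tile_size) == (8, 8192)
  assert (major_tile, gather_tile, minor_tile) == (1, 2, 4096)
  return major_tile, gather_tile, minor_tile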
@@ -0,0 +1,680 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module containing fused attention forward and backward pass."""
from __future__ import annotations

import functools
import math
from typing import Any

import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
import jax.numpy as jnp
import numpy as np
import dataclasses

DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)


@dataclasses.dataclass(frozen=True, slots=True)
class BlockSizes:
  """Tile sizes parameterizing the attention kernel.

  These block sizes should be tuned for the model and hardware for optimal
  performance.

  Attributes:
    block_q: Block size along Q sequence length for forward kernel.
    block_k: Block size along KV sequence length for forward kernel.
    block_q_dkv: Block size along Q sequence length for dKV backward kernel.
    block_kv_dkv: Block size along KV sequence length for dKV backward kernel.
    block_q_dq: Block size along Q sequence length for dQ backward kernel.
    block_kv_dq: Block size along KV sequence length for dQ backward kernel.
  """
  block_q: int
  block_k: int

  block_q_dkv: int | None = None
  block_kv_dkv: int | None = None
  block_q_dq: int | None = None
  block_kv_dq: int | None = None

  @classmethod
  def get_default(cls):
    return BlockSizes(
        block_q=128,
        block_k=128,
        block_q_dkv=32,
        block_kv_dkv=32,
        block_q_dq=32,
        block_kv_dq=32,
    )

  @property
  def has_backward_blocks(self) -> bool:
    """Returns True if all blocks for the dq and dk/dv backward passes are set."""
    backward_blocks = [
        self.block_q_dkv,
        self.block_kv_dkv,
        self.block_q_dq,
        self.block_kv_dq,
    ]
    return all(b is not None for b in backward_blocks)


def mha_forward_kernel(
    q_ref,
    k_ref,
    v_ref,  # Input arrays
    segment_ids_ref: jax.Array | None,  # segment_id arrays
    o_ref: Any,  # Output
    *residual_refs: Any,  # Residual outputs
    sm_scale: float,
    causal: bool,
    block_q: int,
    block_k: int,
    head_dim: int,
):
  seq_len = k_ref.shape[0]
  start_q = pl.program_id(0)
  head_dim_padded = q_ref.shape[-1]

  # m_i and l_i (see the FlashAttention paper) are updated during the k,v loop.
  m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf')
  l_i = jnp.zeros(block_q, dtype=jnp.float32)
  # o is the buffer where we accumulate the output on sram.
  o = jnp.zeros((block_q, head_dim_padded), dtype=jnp.float32)

  # Load q: it will stay in L1 throughout. Indices form a matrix because we
  # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
  # q tile has shape [block_q, head_dim_padded], head_dim_padded >= head_dim.
  curr_q_slice = pl.dslice(start_q * block_q, block_q)
  head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :]
  q = plgpu.load(q_ref, mask=head_mask, other=0.0)
  q_segment_ids = (
      None if segment_ids_ref is None else segment_ids_ref[curr_q_slice]
  )
  # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv
  # (size Bc == block_k here), and fast over blocks of q (size Br == block_q
  # here). Here we only loop over blocks of kv to process the entire seq_len;
  # the loop over blocks of q is carried out by the grid.
  def body(start_k, carry):
    o_prev, m_prev, l_prev = carry
    curr_k_slice = pl.dslice(start_k * block_k, block_k)

    k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
    qk = pl.dot(q, k.T)  # [block_q, block_k]

    # Scale the logits from the natural-log domain into the base-2 domain so
    # that exp2 can be used below, based on the identity: e^x = 2^(x * log2(e)).
    qk_scale = math.log2(math.e)
    if sm_scale != 1.:
      qk_scale *= sm_scale
    qk *= qk_scale

    # Avoids Triton crash.
    # if num_heads > 2:
    #   qk = qk.astype(q_ref.dtype)
    #   qk = qk.astype(jnp.float32)

    if causal or segment_ids_ref is not None:
      mask = None
      if segment_ids_ref is not None:
        assert q_segment_ids is not None
        kv_segment_ids = segment_ids_ref[curr_k_slice]
        mask = segment_mask(q_segment_ids, kv_segment_ids)
      if causal:
        span_q = start_q * block_q + jnp.arange(block_q)
        span_k = start_k * block_k + jnp.arange(block_k)
        causal_mask = span_q[:, None] >= span_k[None, :]
        mask = (
            causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
        )
      # Apply mask to qk.
      assert mask is not None
      qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

    m_curr = jnp.max(qk, axis=-1)
    m_next = jnp.maximum(m_prev, m_curr)
    correction = jnp.exp2(m_prev - m_next)
    l_prev_corr = correction * l_prev
    s_curr = jnp.exp2(
        qk - m_next[:, None]
    )  # Use m_next instead of m_curr to avoid a correction on l_curr
    l_curr = s_curr.sum(axis=-1)
    l_next = l_prev_corr + l_curr
    o_prev_corr = correction[:, None] * o_prev
    v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask)
    o_curr = pl.dot(s_curr.astype(v.dtype), v)

    o_next = o_prev_corr + o_curr
    return o_next, m_next, l_next
  if causal:
    # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q)
    upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k)
  else:
    upper_bound = pl.cdiv(seq_len, block_k)
  o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))

  # We keep an unscaled version of o during the scan over seq_len. Scaling it
  # by the last l_i gives us the correct final output. See section 3.1.1 in the
  # FlashAttention-2 paper: https://arxiv.org/pdf/2307.08691.
  o /= l_i[:, None]

  if residual_refs:
    lse_ref = residual_refs[0]
    lse_ref[...] = m_i + jnp.log2(l_i)
  # Write output to dram.
  plgpu.store(o_ref.at[:, : o.shape[-1]], o.astype(o_ref.dtype), mask=head_mask)


def segment_mask(
    q_segment_ids: jax.Array,
    kv_segment_ids: jax.Array,
):
  # [B, T, 1] or [T, 1]
  q_segment_ids = jnp.expand_dims(q_segment_ids, axis=-1)
  # [B, 1, S] or [1, S]
  if kv_segment_ids.ndim == 1:
    kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=0)
  else:
    kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=1)
  return jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)
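
# Editor's note: the reference below is an illustrative sketch, not part of the
# original kernel. It models, in plain NumPy, the streaming-softmax update that
# `body` in `mha_forward_kernel` performs per (block_q, block_k) tile (the
# kernel additionally pre-scales logits by log2(e) so it can use exp2).
def _streaming_softmax_reference(qk, v, block_k):
  # qk: [q_len, k_len] logits, v: [k_len, head_dim].
  q_len, k_len = qk.shape
  m = np.full((q_len,), -np.inf)     # running row max
  l = np.zeros((q_len,))             # running softmax denominator
  o = np.zeros((q_len, v.shape[1]))  # running (unscaled) output
  for start in range(0, k_len, block_k):
    blk = slice(start, start + block_k)
    m_curr = qk[:, blk].max(axis=-1)
    m_next = np.maximum(m, m_curr)
    correction = np.exp(m - m_next)  # rescales the previous partial sums
    s = np.exp(qk[:, blk] - m_next[:, None])
    l = correction * l + s.sum(axis=-1)
    o = correction[:, None] * o + s @ v[blk]
    m = m_next
  return o / l[:, None]  # equals softmax(qk, axis=-1) @ v

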
@functools.partial(
    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
)
@functools.partial(
    jax.jit,
    static_argnames=[
        "sm_scale",
        "causal",
        "block_sizes",
        "backward_pass_impl",
        "num_warps",
        "num_stages",
        "grid",
        "interpret",
        "debug",
        "return_residuals",
    ],
)
def mha(
    q,
    k,
    v,
    segment_ids: jnp.ndarray | None,
    sm_scale: float = 1.0,
    causal: bool = False,
    block_sizes: BlockSizes = BlockSizes.get_default(),
    backward_pass_impl: str = "triton",
    num_warps: int | None = None,
    num_stages: int = 2,
    grid: tuple[int, ...] | None = None,
    interpret: bool = False,
    debug: bool = False,
    return_residuals: bool = False,
):
  del backward_pass_impl
  batch_size, q_seq_len, num_heads, head_dim = q.shape
  kv_seq_len = k.shape[1]
  block_q = min(block_sizes.block_q, q_seq_len)
  block_k = min(block_sizes.block_k, kv_seq_len)
  head_dim_padded = pl.next_power_of_2(head_dim)
  if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]):
    raise ValueError(
        f"This kernel expects q, k, and v to have the same head dimension, but"
        f" found {q.shape=}, {k.shape=}, {v.shape=}."
    )
  if q_seq_len % block_q != 0:
    raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}")
  if kv_seq_len % block_k != 0:
    raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}")

  # Heuristics.
  grid_ = grid
  if grid_ is None:
    grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)

  num_warps_ = num_warps
  if num_warps_ is None:
    num_warps_ = 4 if head_dim <= 64 else 8
  kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
                             block_q=block_q, block_k=block_k,
                             head_dim=head_dim, causal=causal)

  in_specs: list[pl.BlockSpec | None] = [
      pl.BlockSpec((None, block_q, None, head_dim_padded),
                   lambda i, j, k: (j, i, k, 0)),
      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
                   lambda _, j, k: (j, 0, k, 0)),
      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
                   lambda _, j, k: (j, 0, k, 0)),
  ]
  in_specs.append(
      None
      if segment_ids is None
      else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
  )
  out_shape = [q]
  out_specs = [pl.BlockSpec((None, block_q, None, head_dim_padded),
                            lambda i, j, k: (j, i, k, 0))]
  if return_residuals:
    out_shape.append(jax.ShapeDtypeStruct(
        shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32))  # lse
    out_specs.append(
        pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)))  # lse
  out = pl.pallas_call(
      kernel,
      grid=grid_,
      in_specs=in_specs,
      out_specs=out_specs,
      compiler_params=plgpu.CompilerParams(
          num_warps=num_warps_, num_stages=num_stages),
      out_shape=out_shape,
      debug=debug,
      interpret=interpret,
      name="mha_forward",
  )(q, k, v, segment_ids)
  return out if return_residuals else out[0]


def _mha_forward(
    q,
    k,
    v,
    segment_ids: jax.Array | None,
    sm_scale: float,
    causal: bool,
    block_sizes: BlockSizes,
    backward_pass_impl: str,
    num_warps: int | None,
    num_stages: int,
    grid: Any,
    interpret: bool,
    debug: bool,
    return_residuals: bool,
):
  out, lse = mha(q, k, v, segment_ids=segment_ids, sm_scale=sm_scale,
                 causal=causal, block_sizes=block_sizes,
                 backward_pass_impl=backward_pass_impl,
                 num_warps=num_warps, num_stages=num_stages,
                 grid=grid, interpret=interpret, debug=debug,
                 return_residuals=True)
  residuals = (q, k, v, segment_ids, out, lse)
  ret = (out, lse) if return_residuals else out
  return ret, residuals


def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int):
  # load
  head_mask = (jnp.arange(out_ref.shape[-1]) < head_dim)[None, :]
  o = plgpu.load(out_ref, mask=head_mask, other=0.0)
  do = plgpu.load(dout_ref, mask=head_mask, other=0.0)
  # compute
  delta = jnp.sum(o * do, axis=1)
  # write-back
  delta_ref[...] = delta.astype(delta_ref.dtype)

@jax.named_scope("preprocess_backward")
def _preprocess_backward(out, do, lse, block_q: int,
                         debug: bool, interpret: bool):
  batch_size, seq_len, num_heads, head_dim = out.shape
  head_dim_padded = pl.next_power_of_2(head_dim)
  out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype)
  delta = pl.pallas_call(
      functools.partial(_preprocess_backward_kernel, head_dim=head_dim),
      grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
      in_specs=[
          pl.BlockSpec((None, block_q, None, head_dim_padded),
                       lambda i, j, k: (j, i, k, 0)),
          pl.BlockSpec((None, block_q, None, head_dim_padded),
                       lambda i, j, k: (j, i, k, 0)),
      ],
      out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
      compiler_params=plgpu.CompilerParams(num_warps=4, num_stages=3),
      out_shape=out_shape,
      debug=debug,
      interpret=interpret,
      name="mha_preprocess_backward",
  )(out, do)
  return delta


# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence
# length.
# Inspired by the triton tutorial: https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py
def mha_backward_kernel(
    # Inputs
    q_ref,
    k_ref,
    v_ref,
    segment_ids_ref: jax.Array | None,
    out_ref,
    do_scaled_ref,
    lse_ref,
    delta_ref,
    # Outputs
    dq_ref,
    dk_ref,
    dv_ref,
    *,
    sm_scale: float,
    causal: bool,
    block_q_dkv: int,
    block_kv_dkv: int,
    block_q_dq: int,
    block_kv_dq: int,
    head_dim: int,
):
  del out_ref  # Not needed
  q_seq_len = q_ref.shape[0]
  kv_seq_len = k_ref.shape[0]

  # Scan #1: dK and dV
  #   1. Load a block of K and V of size (block_kv_dkv, head_dim) in SMEM.
  #   2. Iterate through Q in chunks of (block_q_dkv, head_dim) to accumulate
  #      dK and dV.
  start_k = pl.program_id(2)
  curr_k_slice = pl.dslice(start_k * block_kv_dkv, block_kv_dkv)

  head_dim_padded = q_ref.shape[-1]
  dv = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32)
  dk = jnp.zeros([block_kv_dkv, head_dim_padded], dtype=jnp.float32)

  head_mask = (jnp.arange(head_dim_padded) < head_dim)[None, :]
  v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
  k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
  span_k = start_k * block_kv_dkv + jnp.arange(block_kv_dkv)
  kv_segment_ids = (
      None if segment_ids_ref is None else segment_ids_ref[curr_k_slice]
  )

  def inner_loop_dkdv(start_q, carry):
    dv, dk = carry
    curr_q_slice = pl.dslice(start_q * block_q_dkv, block_q_dkv)

    q = plgpu.load(q_ref.at[curr_q_slice, :], mask=head_mask, other=0.0)
    qk = pl.dot(q, k.T)
    qk_scale = math.log2(math.e)
    if sm_scale != 1.:
      qk_scale *= sm_scale
    qk *= qk_scale

    if causal or segment_ids_ref is not None:
      mask = None
      if segment_ids_ref is not None:
        assert kv_segment_ids is not None
        q_segment_ids = segment_ids_ref[curr_q_slice]
        mask = segment_mask(q_segment_ids, kv_segment_ids)

      if causal:
        span_q = start_q * block_q_dkv + jnp.arange(block_q_dkv)
        causal_mask = span_q[:, None] >= span_k[None, :]
        mask = (
            causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
        )
      assert mask is not None
      qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

    lse = lse_ref[curr_q_slice]
    di = delta_ref[curr_q_slice]
    do = plgpu.load(
        do_scaled_ref.at[curr_q_slice, :], mask=head_mask, other=0.0
    )

    p = jnp.exp2(qk - lse[:, None])
    dv = dv + pl.dot(p.astype(do.dtype).T, do)
    dp = jnp.zeros((block_q_dkv, block_kv_dkv), dtype=jnp.float32) - di[:, None]
    dp = dp + pl.dot(do, v.T)
    ds = p * dp
    if sm_scale != 1.0:
      ds = ds * sm_scale
    dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)

    return dv, dk

  lower_bound = lax.div(start_k * block_kv_dkv, block_q_dkv) if causal else 0
  dv, dk = lax.fori_loop(
      lower_bound, pl.cdiv(q_seq_len, block_q_dkv), inner_loop_dkdv, (dv, dk)
  )
  plgpu.store(
      dv_ref.at[:, : dv.shape[-1]], dv.astype(dv_ref.dtype), mask=head_mask
  )
  plgpu.store(
      dk_ref.at[:, : dk.shape[-1]], dk.astype(dk_ref.dtype), mask=head_mask
  )

  # Scan #2: dQ
  #   1. Load a block of Q of size (block_q_dq, head_dim) in SMEM.
  #   2. Iterate through K and V in chunks of (block_kv_dq, head_dim) to
  #      accumulate dQ.
  start_q = pl.program_id(2)
  curr_q_slice = pl.ds(start_q * block_q_dq, block_q_dq)
  span_q = start_q * block_q_dq + jnp.arange(block_q_dq)
  dq = jnp.zeros([block_q_dq, head_dim_padded], dtype=jnp.float32)

  q = plgpu.load(q_ref.at[curr_q_slice, :], mask=head_mask, other=0.0)
  q_segment_ids = (
      None if segment_ids_ref is None else segment_ids_ref[curr_q_slice]
  )
  lse = lse_ref[curr_q_slice]
  do = plgpu.load(do_scaled_ref.at[curr_q_slice, :], mask=head_mask, other=0.0)
  di = delta_ref[curr_q_slice]

  def inner_loop_dq(start_k, dq):
    curr_k_slice = pl.dslice(start_k * block_kv_dq, block_kv_dq)
    k = plgpu.load(k_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)
    v = plgpu.load(v_ref.at[curr_k_slice, :], mask=head_mask, other=0.0)

    qk = pl.dot(q, k.T)
    qk_scale = math.log2(math.e)
    if sm_scale != 1.:
      qk_scale *= sm_scale
    qk *= qk_scale

    if causal or segment_ids_ref is not None:
      mask = None
      if segment_ids_ref is not None:
        assert q_segment_ids is not None
        kv_segment_ids = segment_ids_ref[curr_k_slice]
        mask = segment_mask(q_segment_ids, kv_segment_ids)

      if causal:
        span_k = start_k * block_kv_dq + jnp.arange(block_kv_dq)
        causal_mask = span_q[:, None] >= span_k[None, :]
        mask = (
            causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
        )
      assert mask is not None
      qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

    p = jnp.exp2(qk - lse[:, None])
    dp = jnp.zeros((block_q_dq, block_kv_dq), dtype=jnp.float32) - di[:, None]
    dp = dp + pl.dot(do, v.T)
    ds = p * dp
    if sm_scale != 1.0:
      ds = ds * sm_scale

    dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)

    return dq

  if causal:
    upper_bound = pl.cdiv((start_q + 1) * block_q_dq, block_kv_dq)
  else:
    upper_bound = pl.cdiv(kv_seq_len, block_kv_dq)

  dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
  plgpu.store(
      dq_ref.at[:, : dq.shape[-1]], dq.astype(dq_ref.dtype), mask=head_mask
  )


def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
                  backward_pass_impl: str, num_warps: int | None,
                  num_stages: int, grid: Any, interpret: bool,
                  debug: bool, return_residuals: bool, res, do):
  if return_residuals:
    raise ValueError(
        "Kernel differentiation is not supported if return_residuals is True.")
  q, k, v, segment_ids, out, lse = res
  del num_stages, grid, return_residuals

  if backward_pass_impl == "xla":
    return jax.vjp(
        functools.partial(mha_reference, sm_scale=sm_scale, causal=causal),
        q,
        k,
        v,
        segment_ids,
    )[1](do)
  elif backward_pass_impl == "triton":
    if not block_sizes.has_backward_blocks:
      raise ValueError("Backward block sizes must all be set.")

    assert block_sizes.block_q_dkv is not None
    assert block_sizes.block_kv_dkv is not None
    assert block_sizes.block_q_dq is not None
    assert block_sizes.block_kv_dq is not None

    batch_size, q_seq_len, num_heads, head_dim = q.shape
    kv_seq_len = k.shape[1]
    block_q = min(block_sizes.block_q, q_seq_len)
    block_q_dkv = min(block_sizes.block_q_dkv, q_seq_len)
    block_kv_dkv = min(block_sizes.block_kv_dkv, kv_seq_len)
    block_q_dq = min(block_sizes.block_q_dq, q_seq_len)
    block_kv_dq = min(block_sizes.block_kv_dq, kv_seq_len)
    head_dim_padded = pl.next_power_of_2(head_dim)

    if q_seq_len // block_q_dq != kv_seq_len // block_kv_dkv:
      raise ValueError(
          "q_seq_len and kv_seq_len must be divided into the same "
          "number of blocks for the fused backward pass."
      )

    delta = _preprocess_backward(out, do, lse, block_q, debug, interpret)
    out_shapes = [
        jax.ShapeDtypeStruct(q.shape, q.dtype),
        jax.ShapeDtypeStruct(k.shape, k.dtype),
        jax.ShapeDtypeStruct(v.shape, v.dtype),
    ]

    in_specs: list[pl.BlockSpec | None] = [
        pl.BlockSpec((None, q_seq_len, None, head_dim_padded),
                     lambda i, j, _: (i, 0, j, 0)),
        pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
                     lambda i, j, _: (i, 0, j, 0)),
        pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
                     lambda i, j, _: (i, 0, j, 0)),
        pl.BlockSpec((None, q_seq_len, None, head_dim_padded),
                     lambda i, j, _: (i, 0, j, 0)),
        pl.BlockSpec((None, q_seq_len, None, head_dim_padded),
                     lambda i, j, _: (i, 0, j, 0)),
        pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
        pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
    ]
    if segment_ids is None:
      in_specs.insert(3, None)
    else:
      in_specs.insert(3, pl.BlockSpec((None, kv_seq_len),
                                      lambda i, j, _: (i, 0)))

    grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_kv_dkv))
    num_warps_ = num_warps
    if num_warps_ is None:
      if (
          block_q_dkv * block_kv_dkv < 128 * 128
          or block_q_dq * block_kv_dq < 128 * 128
      ):
        num_warps_ = 4
      else:
        num_warps_ = 8

    dq, dk, dv = pl.pallas_call(
        functools.partial(
            mha_backward_kernel,
            sm_scale=sm_scale,
            causal=causal,
            block_q_dkv=block_q_dkv,
            block_kv_dkv=block_kv_dkv,
            block_q_dq=block_q_dq,
            block_kv_dq=block_kv_dq,
            head_dim=head_dim,
        ),
        out_shape=out_shapes,
        in_specs=in_specs,
        grid=grid,
        out_specs=[
            pl.BlockSpec(
                (None, block_q_dq, None, head_dim_padded),
                lambda i, j, k: (i, k, j, 0),  # dq
            ),
            pl.BlockSpec(
                (None, block_kv_dkv, None, head_dim_padded),
                lambda i, j, k: (i, k, j, 0),  # dk
            ),
            pl.BlockSpec(
                (None, block_kv_dkv, None, head_dim_padded),
                lambda i, j, k: (i, k, j, 0),  # dv
            ),
        ],
        name="mha_backward",
        debug=debug,
        interpret=interpret,
        compiler_params=plgpu.CompilerParams(
            num_warps=num_warps_, num_stages=2
        ),
    )(q, k, v, segment_ids, out, do, lse, delta)
  else:
    raise ValueError(
        f"Invalid backward pass implementation: {backward_pass_impl}")
  return dq.astype(q.dtype), dk, dv, None


mha.defvjp(_mha_forward, _mha_backward)


@functools.partial(jax.jit, static_argnames=['sm_scale', 'causal'])
def mha_reference(
    q,
    k,
    v,
    segment_ids: jnp.ndarray | None,
    sm_scale=1.0,
    causal: bool = False,
):
  q_seq_len = q.shape[1]
  kv_seq_len = k.shape[1]
  logits = jnp.einsum(
      'bqhc,bkhc->bhqk', q, k, preferred_element_type=jnp.float32
  )
  mask = None
  if segment_ids is not None:
    mask = jnp.expand_dims(segment_mask(segment_ids, segment_ids), 1)
    mask = jnp.broadcast_to(mask, logits.shape)
  if causal:
    causal_mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool))
    causal_mask = jnp.broadcast_to(causal_mask, logits.shape)
    mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
  logits = logits if mask is None else jnp.where(mask, logits, float("-inf"))
  weights = jax.nn.softmax(logits * sm_scale)
  return jnp.einsum(
      'bhqk,bkhc->bqhc', weights, v, preferred_element_type=jnp.float32
  )
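
# Editor's note: an illustrative usage sketch, not part of the original file.
# Shapes here are hypothetical; q/k/v are laid out as (batch, seq_len,
# num_heads, head_dim), and the default BlockSizes require sequence lengths
# that are multiples of 128. Running the Triton kernels needs a CUDA GPU;
# interpret=True falls back to the (slow) Pallas interpreter.
def _mha_usage_example(interpret: bool = False):
  keys = jax.random.split(jax.random.key(0), 3)
  shape = (1, 256, 2, 64)  # (batch, seq_len, num_heads, head_dim)
  q, k, v = (jax.random.normal(key, shape, jnp.float16) for key in keys)
  out = mha(q, k, v, segment_ids=None, causal=True, interpret=interpret)
  ref = mha_reference(q, k, v, segment_ids=None, causal=True)
  max_err = jnp.max(jnp.abs(out.astype(jnp.float32) - ref.astype(jnp.float32)))
  # Gradients flow through the custom VJP defined above (mha.defvjp).
  dq = jax.grad(
      lambda q: mha(q, k, v, segment_ids=None, causal=True,
                    interpret=interpret).astype(jnp.float32).sum()
  )(q)
  return max_err, dq.shape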
@@ -0,0 +1,911 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FlashAttention3 implementation (using Mosaic GPU as the backend)."""

import dataclasses
import functools
import itertools
import math
import jax
from jax import lax
from jax._src import test_util as jtu  # noqa: F401
from jax._src.lib import cuda_versions  # noqa: F401
from jax.experimental.mosaic.gpu import profiler
import jax.experimental.pallas as pl
import jax.experimental.pallas.mosaic_gpu as plgpu
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Protocol, TypeVar


T = TypeVar('T')

class PipelineCallback(Protocol):
  """A callback that returns the same type as the input."""
  def __call__(self, arg: T, /) -> T: ...


@dataclasses.dataclass(frozen=True)
class TuningConfig:
  block_q: int
  block_kv: int
  max_concurrent_steps: int
  use_schedule_barrier: bool = True
  causal: bool = False
  compute_wgs_bwd: int = 1

  block_q_dkv: int | None = None
  block_kv_dkv: int | None = None
  block_q_dq: int | None = None
  block_kv_dq: int | None = None

  def __post_init__(self):
    if self.block_q % 64:
      raise ValueError(f"{self.block_q=} must be a multiple of 64")
    if self.block_kv % 64:
      raise ValueError(f"{self.block_kv=} must be a multiple of 64")
    if self.max_concurrent_steps < 2:
      raise ValueError(f"{self.max_concurrent_steps=} must be at least 2")

    backward_blocks = [
        self.block_q_dkv, self.block_kv_dkv, self.block_q_dq, self.block_kv_dq
    ]
    block_is_set = [blk is not None for blk in backward_blocks]
    if any(block_is_set) and not all(block_is_set):
      raise ValueError(
          "Backward block sizes (block_q_dkv, block_kv_dkv, block_q_dq, "
          "block_kv_dq) must either all be specified or all be None."
      )

  @property
  def has_backward_blocks(self) -> bool:
    return self.block_q_dkv is not None


def _attention_forward(q, k, v, config: TuningConfig, save_residuals: bool = False):
  assert cuda_versions is not None
  cuda_runtime_version = cuda_versions.cuda_runtime_get_version()
  # TODO(pobudzey): Undo when we upgrade to cuda 12.9.1.
  if config.causal and 12080 <= cuda_runtime_version < 12091:
    raise ValueError(
        "Causal masking not supported with cuda versions between 12.8.0 and"
        " 12.9.1 due to a ptxas miscompilation."
    )
  if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
    raise ValueError(
        f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}")
  batch_size, q_seq_len, num_q_heads, head_dim = q.shape
  _, kv_seq_len, num_kv_heads, _ = k.shape
  kv_shape = (batch_size, kv_seq_len, num_kv_heads, head_dim)
  if k.shape != kv_shape:
    raise ValueError(f"Expected {k.shape=} to be {kv_shape} (inferred from q)")
  if v.shape != kv_shape:
    raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)")
  if (dtype := q.dtype) != k.dtype or dtype != v.dtype:
    raise ValueError(
        "q, k, and v should all have the same dtype, got:"
        f" {q.dtype}, {k.dtype}, {v.dtype}")
  if num_q_heads % num_kv_heads:
    raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
  q_heads_per_kv_head = num_q_heads // num_kv_heads
  if head_dim % 64:
    raise ValueError(f"{head_dim=} must be divisible by 64")
  if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
    raise NotImplementedError(
        f"Only f16 and bf16 are supported, got dtype: {dtype}")

  max_concurrent_steps = min(
      config.max_concurrent_steps, kv_seq_len // config.block_kv
  )
  block_q, block_kv = config.block_q, config.block_kv
  if kv_seq_len % block_kv:
    raise ValueError(f"{kv_seq_len=} must be a multiple of {block_kv=}")

  def kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, scoped):
    batch = lax.axis_index("batch")
    q_head = lax.axis_index("heads")
    smem_buffers, buffer_barriers, consumed_barriers, schedule_barrier = scoped
    wg_idx = lax.axis_index("wg")
    qo_smem2, k_smem, v_smem, lse_smem2 = smem_buffers
    k_barriers, v_barriers, q_barriers = buffer_barriers
    k_consumed_barriers, v_consumed_barriers = consumed_barriers
    def perform_schedule_barrier():
      plgpu.barrier_arrive(schedule_barrier)
      plgpu.barrier_wait(schedule_barrier)

    if config.causal:
      block_q_end = (lax.axis_index("q_seq") + 1) * (2 * block_q)
      block_max_kv_steps = pl.cdiv(block_q_end, jnp.array(block_kv, jnp.int32))
    else:
      block_max_kv_steps = kv_seq_len // block_kv

    @pl.when(wg_idx < 2)
    def _compute_wg():
      plgpu.set_max_registers(232, action="increase")
      qo_smem = qo_smem2.at[wg_idx]
      lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None
      q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q

      if config.causal:
        kv_steps = pl.cdiv(q_seq_base + block_q, jnp.array(block_kv, jnp.int32))
      else:
        kv_steps = block_max_kv_steps

      plgpu.copy_gmem_to_smem(
          q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
          qo_smem,
          q_barriers.at[wg_idx],
      )
      plgpu.barrier_wait(q_barriers.at[wg_idx])

      m_i = plgpu.layout_cast(
          jnp.full((block_q,), -jnp.inf, dtype=jnp.float32),
          plgpu.Layout.WGMMA.reduce(1),
      )
      l_i = plgpu.layout_cast(
          jnp.full((block_q,), 0, dtype=jnp.float32),
          plgpu.Layout.WGMMA.reduce(1),
      )
      acc = plgpu.layout_cast(
          jnp.full((block_q, head_dim), 0, dtype=jnp.float32),
          plgpu.Layout.WGMMA,
      )

      @pl.when(kv_steps > 0)
      def _():
        plgpu.barrier_wait(k_barriers.at[0])

      pl.when(wg_idx == 1)(perform_schedule_barrier)
      def kv_loop(kv_step, carry, causal: bool = False):
        acc, m_i, l_i = carry
        slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype))

        # QK
        def compute_qk(acc_ref):
          plgpu.wgmma(acc_ref, qo_smem, plgpu.transpose_ref(k_smem.at[slot], (1, 0)))
          perform_schedule_barrier()
          return acc_ref[...]
        qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32))
        plgpu.barrier_arrive(k_consumed_barriers.at[slot])

        if causal:
          q_ids = plgpu.broadcasted_iota(
              jnp.int32, (block_q, block_kv), 0, layout=plgpu.Layout.WGMMA)
          kv_ids = plgpu.broadcasted_iota(
              jnp.int32, (block_q, block_kv), 1, layout=plgpu.Layout.WGMMA)
          mask = (q_ids + q_seq_base) >= (kv_ids + kv_step * block_kv)
          qk = jnp.where(mask, qk, -jnp.inf)

        # Softmax
        # We keep m scaled by log2e to use FMA instructions when computing p.
        log2e = math.log2(math.e)
        m_ij = jnp.maximum(m_i, qk.max(axis=1) * log2e)
        alpha = jnp.exp2(m_i - m_ij)
        m_i = m_ij
        p = jnp.exp2(qk * log2e - lax.broadcast_in_dim(m_ij, qk.shape, [0]))
        acc *= lax.broadcast_in_dim(alpha, acc.shape, [0])
        l_i *= alpha
        p16 = p.astype(dtype)

        def end_softmax_barriers():
          plgpu.barrier_arrive(schedule_barrier)  # Done with softmax!
          plgpu.barrier_wait(v_barriers.at[slot])
          plgpu.barrier_wait(schedule_barrier)  # Wait until TensorCore is free.
        # Can't fully explain why, but empirically the ordering here influences
        # the performance of the final kernel quite significantly.
        if head_dim <= 128:
          l_i += p.sum(axis=1)
          acc, l_i, m_i, p16 = lax.optimization_barrier((acc, l_i, m_i, p16))
          end_softmax_barriers()
        else:
          end_softmax_barriers()
          l_i += p.sum(axis=1)

        # PV
        def compute_pv(acc_ref):
          plgpu.wgmma(acc_ref, p16, v_smem.at[slot])

          wait_step = kv_step + 1
          wait_slot = lax.rem(wait_step, jnp.array(max_concurrent_steps, kv_step.dtype))
          @pl.when(wait_step < kv_steps)
          def _wait():
            plgpu.barrier_wait(k_barriers.at[wait_slot])
        acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc))
        plgpu.barrier_arrive(v_consumed_barriers.at[slot])
        return acc, m_i, l_i

      if not config.causal:
        acc, m_i, l_i = lax.fori_loop(0, block_max_kv_steps, kv_loop, (acc, m_i, l_i))
      else:
        def epilogue_kv_loop(kv_step, _):
          # This loop makes sure that all the pipelined KV data is processed,
          # even if one compute wg finishes early like with causal masking.
          slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype))
          plgpu.barrier_arrive(k_consumed_barriers.at[slot])
          plgpu.barrier_arrive(v_consumed_barriers.at[slot])
          perform_schedule_barrier()
          perform_schedule_barrier()

        causal_kv_loop = functools.partial(kv_loop, causal=True)
        full_kv_steps = lax.div(q_seq_base, jnp.array(block_kv, jnp.int32))
        # With causal masking, the KV loop unrolling is split in 3 sections:
        # 1. A fast path where no causal mask is needed.
        acc, m_i, l_i = lax.fori_loop(0, full_kv_steps, kv_loop, (acc, m_i, l_i))
        # 2. Causal masking.
        acc, m_i, l_i = lax.fori_loop(full_kv_steps, kv_steps, causal_kv_loop, (acc, m_i, l_i))
        # 3. Epilogue to flush the data pipeline.
        lax.fori_loop(kv_steps, block_max_kv_steps, epilogue_kv_loop, None)
      pl.when(wg_idx == 0)(perform_schedule_barrier)

      # TODO(apaszke): Invert and multiply to avoid expensive divisions.
      acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0])
      qo_smem[...] = acc.astype(dtype)
      if lse_smem is not None:
        RCP_LN2 = 1.4426950408889634
        log2 = lambda x: jnp.log(x) * RCP_LN2
        lse_smem[...] = m_i + log2(l_i)
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(
          qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
      )
      if lse_smem is not None:
        plgpu.copy_smem_to_gmem(
            lse_smem,
            lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)],
        )
      plgpu.wait_smem_to_gmem(0)

    @pl.when(wg_idx == 2)
    def _memory_wg():
      plgpu.set_max_registers(40, action="decrease")
      kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
      for i in range(max_concurrent_steps):
        s = (batch, pl.ds(i * block_kv, block_kv), kv_head)
        plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[i], k_barriers.at[i])
        plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[i], v_barriers.at[i])

      @pl.loop(0, block_max_kv_steps - max_concurrent_steps)
      def _kv_loop(kv_step):
        tma_step = kv_step + max_concurrent_steps
        tma_slot = lax.rem(kv_step, jnp.array(max_concurrent_steps, kv_step.dtype))
        s = (batch, pl.ds(tma_step * block_kv, block_kv), kv_head)
        plgpu.barrier_wait(k_consumed_barriers.at[tma_slot])
        plgpu.copy_gmem_to_smem(k_ref.at[s], k_smem.at[tma_slot], k_barriers.at[tma_slot])
        plgpu.barrier_wait(v_consumed_barriers.at[tma_slot])
        plgpu.copy_gmem_to_smem(v_ref.at[s], v_smem.at[tma_slot], v_barriers.at[tma_slot])

  def entry(q_ref, k_ref, v_ref, out_ref, lse_ref):
    compute_wgs = 2
    tiling = plgpu.TilingTransform((8, 64))
    swizzle = plgpu.SwizzleTransform(128)
    qo_scratch = plgpu.SMEM(
        (compute_wgs, block_q, head_dim), jnp.float16,
        transforms=(tiling, swizzle),
    )
    k_scratch = plgpu.SMEM(
        (max_concurrent_steps, block_kv, head_dim), jnp.float16,
        transforms=(tiling, swizzle),
    )
    v_scratch = plgpu.SMEM(
        (max_concurrent_steps, block_kv, head_dim), jnp.float16,
        transforms=(tiling, swizzle),
    )
    scratch = [qo_scratch, k_scratch, v_scratch, None]
    if save_residuals:
      scratch[3] = plgpu.SMEM((compute_wgs, block_q), jnp.float32)
    pl.run_scoped(
        lambda *args: kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, args),
        scratch,
        (
            plgpu.Barrier(num_barriers=max_concurrent_steps),
            plgpu.Barrier(num_barriers=max_concurrent_steps),
            plgpu.Barrier(num_barriers=compute_wgs),
        ),
        (plgpu.Barrier(num_arrivals=compute_wgs, num_barriers=max_concurrent_steps),) * 2,
        plgpu.Barrier(num_arrivals=compute_wgs),
        collective_axes="wg",
    )

  num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
  if rem:
    raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")

  out_shape = [q, None]
  if save_residuals:
    # Note that we keep seq_len in the minor-most dimension so that we can do
    # 1D TMAs on chunks of `block_q`.
    out_shape[1] = jax.ShapeDtypeStruct(
        (batch_size, num_q_heads, q_seq_len), jnp.float32
    )

  out, lse = plgpu.kernel(
      entry,
      out_shape=out_shape,
      grid=(num_q_heads, num_q_tiles, batch_size),
      grid_names=("heads", "q_seq", "batch"),
      num_threads=3,
      thread_name="wg",
      compiler_params=plgpu.CompilerParams(approx_math=True),
  )(q, k, v)

  if save_residuals:
    assert lse is not None
    return out, (lse,)

  return out

@partial(jax.custom_vjp, nondiff_argnums=(3, 4))
@partial(jax.jit, static_argnames=["config", "save_residuals"])
def attention(q, k, v, config: TuningConfig, save_residuals: bool = False):
  return _attention_forward(q, k, v, config, save_residuals)

def _attention_fwd(q, k, v, config: TuningConfig, save_residuals: bool):
  del save_residuals

  out, (lse,) = _attention_forward(q, k, v, config, save_residuals=True)
  return out, (q, k, v, out, lse)
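
# Editor's note: an illustrative usage sketch, not part of the original file.
# The sizes are hypothetical and chosen to satisfy the checks in
# `_attention_forward` (f16/bf16 inputs, head_dim a multiple of 64, q_seq_len
# a multiple of 2 * block_q, kv_seq_len a multiple of block_kv). The kernel is
# built on Mosaic GPU warpgroup MMA, so this only runs on a capable GPU.
def _attention_usage_example():
  config = TuningConfig(block_q=64, block_kv=64, max_concurrent_steps=2)
  keys = jax.random.split(jax.random.key(0), 3)
  shape = (1, 4096, 2, 64)  # (batch, seq_len, num_heads, head_dim)
  q, k, v = (jax.random.normal(key, shape, jnp.float16) for key in keys)
  out = attention(q, k, v, config)  # plain forward pass
  _, (lse,) = attention(q, k, v, config, save_residuals=True)
  return out.shape, lse.shape  # ((1, 4096, 2, 64), (1, 2, 4096))
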
def _attention_bwd(config: TuningConfig, save_residuals: bool, res, do):
|
||||
del save_residuals
|
||||
q, k, v, out, lse = res
|
||||
|
||||
if config.causal:
|
||||
raise NotImplementedError("Causal attention not supported in the backwards pass yet.")
|
||||
|
||||
if not config.has_backward_blocks:
|
||||
raise ValueError("Need to specify backward blocks.")
|
||||
|
||||
assert config.block_q_dq is not None
|
||||
assert config.block_kv_dq is not None
|
||||
assert config.block_q_dkv is not None
|
||||
assert config.block_kv_dkv is not None
|
||||
|
||||
batch_size, q_seq_len, num_q_heads, head_dim = q.shape
|
||||
_, kv_seq_len, num_kv_heads, _ = k.shape
|
||||
q_heads_per_kv_head = num_q_heads // num_kv_heads
|
||||
dtype = q.dtype
|
||||
compute_wgs = config.compute_wgs_bwd
|
||||
|
||||
num_q_tiles, rem = divmod(q_seq_len, config.block_q_dq * compute_wgs)
|
||||
if rem:
|
||||
raise NotImplementedError(
|
||||
f"{q_seq_len=} must be a multiple of {config.block_q_dq=} * {compute_wgs=}")
|
||||
|
||||
num_kv_tiles, rem = divmod(kv_seq_len, config.block_kv_dkv * compute_wgs)
|
||||
if rem:
|
||||
raise NotImplementedError(
|
||||
f"{kv_seq_len=} must be a multiple of {config.block_kv_dkv=} * {compute_wgs=}")
|
||||
|
||||
num_q_tiles_in_dkv, rem = divmod(q_seq_len, config.block_q_dkv)
|
||||
if rem:
|
||||
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {config.block_q_dkv=}")
|
||||
|
||||
num_kv_tiles_in_dq, rem = divmod(kv_seq_len, config.block_kv_dq)
|
||||
if rem:
|
||||
raise NotImplementedError(f"{kv_seq_len=} must be a multiple of {config.block_kv_dq=}")
|
||||
|
||||
tiling = plgpu.TilingTransform((8, 64))
|
||||
swizzle = plgpu.SwizzleTransform(128)
|
||||
|
||||
delta = jnp.einsum('bqhd,bqhd->bhq', out.astype(jnp.float32), do.astype(jnp.float32))
|
||||
del out # Not needed anymore.
|
||||
|
||||
def kernel_dq(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref, dq_ref,
|
||||
smem_buffers, buffer_barriers, block_q, block_kv):
|
||||
batch = lax.axis_index("batch")
|
||||
q_head = lax.axis_index("heads")
|
||||
wg_idx = lax.axis_index("wg")
|
||||
kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
|
||||
q_smem2, do_smem2, lse_smem2, delta_smem2 = smem_buffers
|
||||
q_barriers, do_barriers, lse_barriers, delta_barriers = buffer_barriers
|
||||
def _compute_thread(pipeline_callback: PipelineCallback, /) -> None:
|
||||
q_smem, do_smem, lse_smem, delta_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx], lse_smem2.at[wg_idx], delta_smem2.at[wg_idx]
|
||||
q_seq_base = lax.axis_index("q_seq") * (compute_wgs * block_q) + wg_idx * block_q
|
||||
q_slice = (batch, pl.ds(q_seq_base, block_q), q_head)
|
||||
plgpu.copy_gmem_to_smem(q_ref.at[q_slice], q_smem, q_barriers.at[wg_idx])
|
||||
plgpu.copy_gmem_to_smem(do_ref.at[q_slice], do_smem, do_barriers.at[wg_idx])
|
||||
plgpu.copy_gmem_to_smem(
|
||||
delta_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)],
|
||||
delta_smem,
|
||||
delta_barriers.at[wg_idx],
|
||||
)
|
||||
plgpu.copy_gmem_to_smem(
|
||||
lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)],
|
||||
lse_smem,
|
||||
lse_barriers.at[wg_idx],
|
||||
)
|
||||
for buffer in buffer_barriers:
|
||||
plgpu.barrier_wait(buffer.at[wg_idx])
|
||||
|
||||
delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA.reduce(1))
|
||||
lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA.reduce(1))
|
||||
dq_acc: jax.Array = plgpu.layout_cast(
|
||||
jnp.full((block_q, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
|
||||
)
|
||||
dq, _, _ = pipeline_callback((dq_acc, lse, delta))
|
||||
q_smem[...] = dq.astype(dtype)
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(q_smem, dq_ref.at[q_slice])
|
||||
plgpu.wait_smem_to_gmem(0)
|
||||
|
||||
def kv_pipeline(_, k_smem, v_smem, k_consumed_barrier, v_consumed_barrier, carry):
|
||||
q_smem, do_smem = q_smem2.at[wg_idx], do_smem2.at[wg_idx]
|
||||
(dq_acc, lse, delta) = carry
|
||||
|
||||
def compute_s(acc_ref):
|
||||
plgpu.wgmma(acc_ref, q_smem, plgpu.transpose_ref(k_smem, (1, 0)))
|
||||
return acc_ref[...]
|
||||
|
||||
s = pl.run_scoped(compute_s, plgpu.ACC((block_q, block_kv), jnp.float32))
|
||||
s *= math.log2(math.e)
|
||||
p = jnp.exp2(s - lax.broadcast_in_dim(lse, (block_q, block_kv), [0]))
|
||||
|
||||
# dP
|
||||
def compute_dp(acc_ref):
|
||||
plgpu.wgmma(acc_ref, do_smem, plgpu.transpose_ref(v_smem, (1, 0)))
|
||||
return acc_ref[...]
|
||||
|
||||
dp = pl.run_scoped(compute_dp, plgpu.ACC((block_q, block_kv), jnp.float32))
|
||||
plgpu.barrier_arrive(v_consumed_barrier)
|
||||
|
||||
# dS
|
||||
ds = p * (dp - lax.broadcast_in_dim(delta, (block_q, block_kv), [0]))
|
||||
|
||||
# dQ
|
||||
def compute_dq(acc_ref):
|
||||
plgpu.wgmma(acc_ref, ds.astype(k_ref.dtype), k_smem)
|
||||
|
||||
dq_acc = pl.run_state(compute_dq)(plgpu.ACC.init(dq_acc))
|
||||
plgpu.barrier_arrive(k_consumed_barrier)
|
||||
|
||||
return (dq_acc, lse, delta)
|
||||
|
||||
pipeline = plgpu.emit_pipeline_warp_specialized(
|
||||
kv_pipeline,
|
||||
grid=(num_kv_tiles_in_dq,),
|
||||
max_concurrent_steps=min([config.max_concurrent_steps, num_q_tiles]),
|
||||
num_compute_wgs=compute_wgs,
|
||||
memory_registers=40,
|
||||
wg_axis="wg",
|
||||
manual_consumed_barriers=True,
|
||||
compute_context=_compute_thread,
|
||||
in_specs=[
|
||||
plgpu.BlockSpec( # k
|
||||
block_shape=(block_kv, head_dim),
|
||||
index_map=lambda i: (i, 0),
|
||||
transforms=[tiling, swizzle]),
|
||||
plgpu.BlockSpec( # v
|
||||
block_shape=(block_kv, head_dim),
|
||||
index_map=lambda i: (i, 0),
|
||||
transforms=[tiling, swizzle]),
|
||||
])
|
||||
k_ref = k_ref.at[batch, :, kv_head, :]
|
||||
v_ref = v_ref.at[batch, :, kv_head, :]
|
||||
pipeline(k_ref, v_ref)
|
||||
|
||||
def kernel_dkv(q_ref, k_ref, v_ref, do_ref, lse_ref, delta_ref,
|
||||
dk_ref, dv_ref, smem_buffers, buffer_barriers, block_q: int, block_kv: int):
|
||||
batch = lax.axis_index("batch")
|
||||
q_head = lax.axis_index("heads")
|
||||
wg_idx = lax.axis_index("wg")
|
||||
(k_smem2, v_smem2) = smem_buffers
|
||||
(k_barriers, v_barriers) = buffer_barriers
|
||||
|
||||
def _compute_thread(pipeline_callback):
|
||||
k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx]
|
||||
kv_seq_base = lax.axis_index("kv_seq") * (compute_wgs * block_kv) + wg_idx * block_kv
|
||||
kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
|
||||
plgpu.copy_gmem_to_smem(
|
||||
k_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)],
|
||||
k_smem,
|
||||
k_barriers.at[wg_idx])
|
||||
plgpu.copy_gmem_to_smem(
|
||||
v_ref.at[(batch, pl.ds(kv_seq_base, block_kv), kv_head)],
|
||||
v_smem,
|
||||
v_barriers.at[wg_idx])
|
||||
plgpu.barrier_wait(k_barriers.at[wg_idx])
|
||||
plgpu.barrier_wait(v_barriers.at[wg_idx])
|
||||
dk_acc = plgpu.layout_cast(
|
||||
jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
|
||||
)
|
||||
dv_acc = plgpu.layout_cast(
|
||||
jnp.full((block_kv, head_dim), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
|
||||
)
|
||||
(dk, dv) = pipeline_callback((dk_acc, dv_acc))
|
||||
k_smem[...] = dk.astype(dtype)
|
||||
v_smem[...] = dv.astype(dtype)
|
||||
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(
|
||||
k_smem,
|
||||
dk_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)],
|
||||
commit_group=False)
|
||||
plgpu.copy_smem_to_gmem(
|
||||
v_smem,
|
||||
dv_ref.at[(batch, pl.ds(kv_seq_base, block_kv), q_head)],
|
||||
commit_group=False)
|
||||
plgpu.commit_smem_to_gmem_group()
|
||||
plgpu.wait_smem_to_gmem(0)
|
||||
|
||||
def q_pipeline(_, q_smem, do_smem, lse_smem, delta_smem, q_consumed_barrier, do_consumed_barrier, lse_consumed_barrier, delta_consumed_barrier, carry):
|
||||
k_smem, v_smem = k_smem2.at[wg_idx], v_smem2.at[wg_idx]
|
||||
dk_acc, dv_acc = carry
|
||||
|
||||
def _compute_sT(acc_ref):
|
||||
plgpu.wgmma(acc_ref, k_smem, plgpu.transpose_ref(q_smem, (1, 0)))
|
||||
return acc_ref[...]
|
||||
sT = pl.run_scoped(_compute_sT, plgpu.ACC((block_kv, block_q), jnp.float32))
|
||||
sT *= math.log2(math.e)
|
||||
|
||||
lse = plgpu.load(lse_smem, (), layout=plgpu.Layout.WGMMA.reduce(0))
|
||||
plgpu.barrier_arrive(lse_consumed_barrier)
|
||||
pT = jnp.exp2(sT - lax.broadcast_in_dim(lse, (block_kv, block_q), [1]))
|
||||
|
||||
def _compute(refs):
|
||||
# Combining two WGMMA calls in one block to avoid the unnecessary
|
||||
# synchronization from two `wgmma.wait_group` calls.
|
||||
dv_acc_ref, dpT_acc_ref = refs
|
||||
plgpu.wgmma(dv_acc_ref, pT.astype(dtype), do_smem) # dV
|
||||
plgpu.wgmma(dpT_acc_ref, v_smem, plgpu.transpose_ref(do_smem, (1, 0))) # dpT
|
||||
|
||||
zeros = plgpu.layout_cast(
|
||||
jnp.full((block_kv, block_q), 0, dtype=jnp.float32), plgpu.Layout.WGMMA,
|
||||
)
|
||||
dv_acc, dpT = pl.run_state(_compute)((plgpu.ACC.init(dv_acc), plgpu.ACC.init(zeros)))
|
||||
plgpu.barrier_arrive(do_consumed_barrier)
|
||||
|
||||
delta = plgpu.load(delta_smem, (), layout=plgpu.Layout.WGMMA.reduce(0))
|
||||
plgpu.barrier_arrive(delta_consumed_barrier)
|
||||
|
||||
dsT = pT * (dpT - lax.broadcast_in_dim(delta, (block_kv, block_q), [1])) # jax-operator-types
|
||||
|
||||
def compute_dk(acc_ref):
|
||||
plgpu.wgmma(acc_ref, dsT.astype(dtype), q_smem)
|
||||
|
||||
dk_acc = pl.run_state(compute_dk)(plgpu.ACC.init(dk_acc))
|
||||
plgpu.barrier_arrive(q_consumed_barrier)
|
||||
|
||||
return (dk_acc, dv_acc)
|
||||
|
||||
pipeline = plgpu.emit_pipeline_warp_specialized(
|
||||
q_pipeline,
|
||||
grid=(num_q_tiles_in_dkv,),
|
||||
max_concurrent_steps=min([config.max_concurrent_steps, num_kv_tiles]),
|
||||
num_compute_wgs=compute_wgs,
|
||||
memory_registers=40,
|
||||
wg_axis="wg",
|
||||
manual_consumed_barriers=True,
|
||||
compute_context=_compute_thread,
|
||||
in_specs=[
|
||||
plgpu.BlockSpec( # q
|
||||
block_shape=(block_q, head_dim),
|
||||
index_map=lambda i: (i, 0),
|
||||
transforms=[tiling, swizzle]),
|
||||
plgpu.BlockSpec( # do
|
||||
block_shape=(block_q, head_dim),
|
||||
index_map=lambda i: (i, 0),
|
||||
transforms=[tiling, swizzle]),
|
||||
plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,)),
|
||||
plgpu.BlockSpec(block_shape=(block_q,), index_map=lambda i: (i,))
|
||||
])
|
||||
q_ref = q_ref.at[batch, :, q_head, :]
|
||||
do_ref = do_ref.at[batch, :, q_head, :]
|
||||
lse_ref = lse_ref.at[batch, q_head, :]
|
||||
delta_ref = delta_ref.at[batch, q_head, :]
|
||||
pipeline(q_ref, do_ref, lse_ref, delta_ref)
|
||||
|
||||
q_scratch = plgpu.SMEM(
|
||||
(compute_wgs, config.block_q_dq, head_dim), jnp.float16,
|
||||
transforms=(tiling, swizzle),
|
||||
)
|
||||
do_scratch = q_scratch
|
||||
lse_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32)
|
||||
delta_scratch = plgpu.SMEM((compute_wgs, config.block_q_dq), jnp.float32)
|
||||
dq = plgpu.kernel(
|
||||
partial(kernel_dq, block_q=config.block_q_dq, block_kv=config.block_kv_dq),
|
||||
out_shape=q,
|
||||
scratch_shapes=[
|
||||
(q_scratch, do_scratch, lse_scratch, delta_scratch),
|
||||
(plgpu.Barrier(num_barriers=compute_wgs),) * 4
|
||||
],
|
||||
compiler_params=plgpu.CompilerParams(approx_math=True),
|
||||
grid=(num_q_heads, num_q_tiles, batch_size),
|
||||
grid_names=("heads", "q_seq", "batch"),
|
||||
num_threads=compute_wgs + 1,
|
||||
thread_name="wg",
|
||||
)(q, k, v, do, lse, delta)
|
||||
|
||||
k_scratch = plgpu.SMEM(
|
||||
(compute_wgs, config.block_kv_dkv, head_dim), jnp.float16,
|
||||
transforms=(tiling, swizzle),
|
||||
)
|
||||
v_scratch = k_scratch
|
||||
out_shape_kv = jax.ShapeDtypeStruct(
|
||||
(batch_size, kv_seq_len, num_q_heads, head_dim), dtype=jnp.float16)
|
||||
dk, dv = plgpu.kernel(
|
||||
partial(kernel_dkv, block_q=config.block_q_dkv, block_kv=config.block_kv_dkv),
|
||||
out_shape=[out_shape_kv, out_shape_kv],
|
||||
scratch_shapes=[
|
||||
(k_scratch, v_scratch),
|
||||
(plgpu.Barrier(num_barriers=compute_wgs),) * 2
|
||||
],
|
||||
compiler_params=plgpu.CompilerParams(approx_math=True),
|
||||
grid=(num_q_heads, num_kv_tiles, batch_size),
|
||||
grid_names=("heads", "kv_seq", "batch"),
|
||||
num_threads=compute_wgs + 1,
|
||||
thread_name="wg"
|
||||
)(q, k, v, do, lse, delta)
|
||||
|
||||
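# With grouped-query attention each KV head serves q_heads_per_kv_head query
# heads, and the kernel above wrote one dk/dv slice per query head, so the
# per-query-head contributions are summed (in f32) back down to the KV heads.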
if q_heads_per_kv_head > 1:
|
||||
sum_shape = (*k.shape[:-1], q_heads_per_kv_head, head_dim)
|
||||
dk = dk.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dk.dtype)
|
||||
dv = dv.reshape(sum_shape).astype(jnp.float32).sum(axis=-2).astype(dv.dtype)
|
||||
|
||||
return dq, dk, dv
|
||||
|
||||
attention.defvjp(_attention_fwd, _attention_bwd)
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=["config", "save_residuals"])
|
||||
def attention_with_pipeline_emitter(q, k, v, config: TuningConfig, save_residuals=False):
|
||||
if config.causal:
|
||||
raise NotImplementedError("Causal attention is not supported with the pipeline emitter yet.")
|
||||
if q.ndim != 4 or k.ndim != 4 or v.ndim != 4:
|
||||
raise ValueError(f"q, k, and v should all be 4D, got: {q.ndim=}, {k.ndim=}, {v.ndim=}")
|
||||
batch_size, q_seq_len, num_q_heads, head_dim = q.shape
|
||||
_, kv_seq_len, num_kv_heads, _ = k.shape
|
||||
kv_shape = (batch_size, kv_seq_len, num_kv_heads, head_dim)
|
||||
if k.shape != kv_shape:
|
||||
raise ValueError(f"Expected {k.shape=} to be {kv_shape} (inferred from q)")
|
||||
if v.shape != kv_shape:
|
||||
raise ValueError(f"Expected {v.shape=} to be {kv_shape} (inferred from q)")
|
||||
if (dtype := q.dtype) != k.dtype or dtype != v.dtype:
|
||||
raise ValueError(f"q, k, and v should all have the same dtype, got: {q.dtype}, {k.dtype}, {v.dtype}")
|
||||
if num_q_heads % num_kv_heads:
|
||||
raise ValueError(f"{num_q_heads=} must be divisible by and {num_kv_heads=}")
|
||||
q_heads_per_kv_head = num_q_heads // num_kv_heads
|
||||
if head_dim % 64:
|
||||
raise ValueError(f"{head_dim=} must be divisible by 64")
|
||||
if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
|
||||
raise NotImplementedError(f"Only f16 and bf16 are supported, got dtype: {dtype}")
|
||||
|
||||
max_concurrent_steps = min(
|
||||
config.max_concurrent_steps, kv_seq_len // config.block_kv
|
||||
)
|
||||
compute_wgs = 2
|
||||
block_q, block_kv = config.block_q, config.block_kv
|
||||
num_q_tiles, rem = divmod(q_seq_len, block_q * 2)
|
||||
if rem:
|
||||
raise NotImplementedError(f"{q_seq_len=} must be a multiple of {block_q * 2=}")
|
||||
|
||||
def fa3_kernel(q_ref, k_ref, v_ref, out_ref, lse_ref, smem_buffers, q_barriers, schedule_barrier):
|
||||
batch = lax.axis_index("batch")
|
||||
wg_idx = lax.axis_index("wg")
|
||||
qo_smem2, lse_smem2 = smem_buffers
|
||||
q_seq_base = lax.axis_index("q_seq") * (2 * block_q) + wg_idx * block_q
|
||||
q_head = lax.axis_index("heads")
|
||||
kv_head = lax.div(q_head, jnp.array(q_heads_per_kv_head, q_head.dtype))
|
||||
|
||||
def perform_schedule_barrier():
|
||||
if config.use_schedule_barrier:
|
||||
plgpu.barrier_arrive(schedule_barrier)
|
||||
plgpu.barrier_wait(schedule_barrier)
|
||||
|
||||
def _compute_thread(pipeline_callback: PipelineCallback, /) -> None:
|
||||
qo_smem = qo_smem2.at[wg_idx]
|
||||
lse_smem = lse_smem2.at[wg_idx] if lse_smem2 is not None else None
|
||||
m_i = jnp.full((block_q,), -jnp.inf, dtype=jnp.float32)
|
||||
l_i = jnp.full((block_q,), 0, dtype=jnp.float32)
|
||||
acc = jnp.full((block_q, head_dim), 0, dtype=jnp.float32)
|
||||
# Q is not pipelined, so we load in with a manual DMA.
|
||||
plgpu.copy_gmem_to_smem(
|
||||
q_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
|
||||
qo_smem,
|
||||
q_barriers.at[wg_idx],
|
||||
)
|
||||
plgpu.barrier_wait(q_barriers.at[wg_idx])
|
||||
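# The schedule barrier staggers the two compute warpgroups: wg 1 crosses it
# before entering the pipeline and wg 0 crosses it after, so inside
# kv_pipeline one warpgroup issues WGMMA while the other runs softmax
# (the FA3-style ping-pong schedule).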
pl.when(wg_idx == 1)(perform_schedule_barrier)
|
||||
final_carry = pipeline_callback((acc, m_i, l_i))
|
||||
pl.when(wg_idx == 0)(perform_schedule_barrier)
|
||||
acc, m_i, l_i = final_carry
|
||||
acc /= lax.broadcast_in_dim(l_i, (block_q, head_dim), [0])
|
||||
qo_smem[...] = acc.astype(dtype)
|
||||
if lse_smem is not None:
|
||||
RCP_LN2 = 1.4426950408889634
|
||||
log2 = lambda x: jnp.log(x) * RCP_LN2
|
||||
lse_smem[...] = m_i + log2(l_i)
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(
|
||||
qo_smem, out_ref.at[batch, pl.ds(q_seq_base, block_q), q_head],
|
||||
)
|
||||
if lse_smem is not None:
|
||||
plgpu.copy_smem_to_gmem(
|
||||
lse_smem,
|
||||
lse_ref.at[batch, q_head, pl.ds(q_seq_base, block_q)],
|
||||
)
|
||||
plgpu.wait_smem_to_gmem(0)
|
||||
|
||||
def kv_pipeline(_, k_smem, v_smem,
|
||||
k_consumed_barrier, v_consumed_barrier,
|
||||
carry):
|
||||
acc, m_i, l_i = carry
|
||||
qo_smem = qo_smem2.at[wg_idx]
|
||||
def compute_qk(acc_ref):
|
||||
plgpu.wgmma(acc_ref, qo_smem, plgpu.transpose_ref(k_smem, (1, 0)))
|
||||
perform_schedule_barrier()
|
||||
return acc_ref[...]
|
||||
qk = pl.run_scoped(compute_qk, plgpu.ACC((block_q, block_kv), jnp.float32))
|
||||
plgpu.barrier_arrive(k_consumed_barrier)
|
||||
|
||||
# Softmax
|
||||
# We keep m scaled by log2e to use FMA instructions when computing p.
|
||||
log2e = math.log2(math.e)
|
||||
m_ij = jnp.maximum(m_i, qk.max(axis=1) * log2e)
|
||||
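# Online softmax: when the running maximum grows, the previously accumulated
# acc and l_i are rescaled by exp2(m_old - m_new); exp2 is used because m is
# kept in base-2 (already multiplied by log2(e)).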
alpha = jnp.exp2(m_i - m_ij)
|
||||
m_i = m_ij
|
||||
p = jnp.exp2(qk * log2e - lax.broadcast_in_dim(m_ij, qk.shape, [0]))
|
||||
acc *= lax.broadcast_in_dim(alpha, acc.shape, [0])
|
||||
l_i *= alpha
|
||||
p16 = p.astype(dtype)
|
||||
perform_schedule_barrier()
|
||||
l_i += p.sum(axis=1)
|
||||
# PV
|
||||
def compute_pv(acc_ref):
|
||||
plgpu.wgmma(acc_ref, p16, v_smem)
|
||||
acc = pl.run_state(compute_pv)(plgpu.ACC.init(acc))
|
||||
plgpu.barrier_arrive(v_consumed_barrier)
|
||||
return acc, m_i, l_i
|
||||
pipeline = plgpu.emit_pipeline_warp_specialized(
|
||||
kv_pipeline,
|
||||
grid=(kv_seq_len // block_kv,),
|
||||
max_concurrent_steps=max_concurrent_steps,
|
||||
num_compute_wgs=compute_wgs,
|
||||
memory_registers=40,
|
||||
wg_axis="wg",
|
||||
manual_consumed_barriers=True,
|
||||
compute_context=_compute_thread,
|
||||
in_specs=[
|
||||
plgpu.BlockSpec( # k
|
||||
block_shape=(block_kv, head_dim),
|
||||
index_map=lambda i: (i, 0)),
|
||||
plgpu.BlockSpec( # v
|
||||
block_shape=(block_kv, head_dim),
|
||||
index_map=lambda i: (i, 0)),
|
||||
],
|
||||
out_specs=[],
|
||||
)
|
||||
k_ref = k_ref.at[batch, :, kv_head, :]
|
||||
v_ref = v_ref.at[batch, :, kv_head, :]
|
||||
pipeline(k_ref, v_ref)
|
||||
|
||||
out_shape = [q, None]
|
||||
if save_residuals:
|
||||
out_shape[1] = jax.ShapeDtypeStruct((batch_size, num_q_heads, q_seq_len), jnp.float32)
|
||||
|
||||
qo_scratch = plgpu.SMEM((compute_wgs, block_q, head_dim), jnp.float16)
|
||||
smem_scratch = [qo_scratch, None]
|
||||
if save_residuals:
|
||||
smem_scratch[1] = plgpu.SMEM((compute_wgs, block_q), jnp.float32)
|
||||
|
||||
out, lse = plgpu.kernel(
|
||||
fa3_kernel,
|
||||
grid=(num_q_heads, num_q_tiles, batch_size),
|
||||
grid_names=("heads", "q_seq", "batch"),
|
||||
num_threads=3,
|
||||
thread_name="wg",
|
||||
out_shape=out_shape,
|
||||
scratch_shapes=(
|
||||
tuple(smem_scratch),
|
||||
plgpu.Barrier(num_barriers=compute_wgs),
|
||||
plgpu.Barrier(num_arrivals=compute_wgs),),
|
||||
compiler_params=plgpu.CompilerParams(
|
||||
approx_math=True, lowering_semantics=plgpu.LoweringSemantics.Warpgroup,
|
||||
),
|
||||
)(q, k, v)
|
||||
|
||||
if save_residuals:
|
||||
assert lse is not None
|
||||
return out, (lse,)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=["causal", "save_residuals"])
|
||||
def attention_reference(q, k, v, causal=False, save_residuals=False):
|
||||
batch_size, q_seq_len, num_q_heads, head_dim = q.shape
|
||||
kv_seq_len, num_kv_heads = k.shape[1], k.shape[2]
|
||||
q, k, v = map(lambda x: x.astype(jnp.float32), (q, k, v))
|
||||
q_reshaped = q.reshape(
|
||||
batch_size, q_seq_len, num_kv_heads, num_q_heads // num_kv_heads, head_dim
|
||||
)
|
||||
logits = jnp.einsum("bqHhc,bkHc->bqHhk", q_reshaped, k)
|
||||
|
||||
if causal:
|
||||
mask = jnp.arange(q_seq_len)[:, None] >= jnp.arange(kv_seq_len)[None, :]
|
||||
mask = jnp.broadcast_to(mask[:, None, None, :], logits.shape)
|
||||
logits = jnp.where(mask, logits, -jnp.inf)
|
||||
|
||||
m = logits.max(axis=-1, keepdims=True)
|
||||
unnormalized = jnp.exp(logits - m)
|
||||
l = unnormalized.sum(axis=-1, keepdims=True)
|
||||
weights = unnormalized / l
|
||||
out = jnp.einsum("bqHhk,bkHc->bqHhc", weights, v).reshape(*q.shape)
|
||||
|
||||
if save_residuals:
|
||||
log2e = math.log2(math.e)
|
||||
l = l.reshape(*q.shape[:-1])
|
||||
m = m.reshape(*q.shape[:-1])
|
||||
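# Return LSE in base-2 units (m * log2(e) + log2(l)) to match what the kernel
# stores for its residuals.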
lse = m * log2e + jnp.log2(l)
|
||||
return out, (lse.swapaxes(-1, -2),)
|
||||
else:
|
||||
return out
|
||||
|
||||
def main(unused_argv):
|
||||
num_q_heads = 16
|
||||
num_kv_heads = 16
|
||||
use_pipeline_emitter = False
|
||||
if use_pipeline_emitter:
|
||||
attention_impl = attention_with_pipeline_emitter
|
||||
schedule_barrier_opts = (True, False)
|
||||
else:
|
||||
attention_impl = attention
|
||||
schedule_barrier_opts = (True,)
|
||||
|
||||
problem_it = itertools.product(
|
||||
(1,), (4096, 32768,), (64, 128, 256,), schedule_barrier_opts, (False, True))
|
||||
for batch_size, seq_len, head_dim, use_schedule_barrier, causal in problem_it:
|
||||
assert cuda_versions is not None
|
||||
cuda_runtime_version = cuda_versions.cuda_runtime_get_version()
|
||||
# TODO(pobudzey): Undo when we upgrade to cuda 12.9.1.
|
||||
if causal and cuda_runtime_version >= 12080 and cuda_runtime_version < 12091:
|
||||
continue
|
||||
|
||||
if causal and use_pipeline_emitter:
|
||||
continue
|
||||
q_seq_len = kv_seq_len = seq_len
|
||||
print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}"
|
||||
f"{num_q_heads=:<4} {head_dim=:<6} {use_schedule_barrier=:} {causal=:} ====")
|
||||
k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
|
||||
q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
|
||||
k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
|
||||
v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
|
||||
block_q = 64
|
||||
best = None
|
||||
for block_kv in (256, 128, 64):
|
||||
config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2, use_schedule_barrier=use_schedule_barrier, causal=causal)
|
||||
try:
|
||||
out, runtime_ms = profiler.measure(functools.partial(attention_impl, config=config))(q, k, v)
|
||||
if seq_len < 32768:
|
||||
out_ref = attention_reference(q, k, v, causal=causal)
|
||||
np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
|
||||
except ValueError as e:
|
||||
if "exceeds available shared memory" in e.args[0]:
|
||||
continue
|
||||
raise
|
||||
assert runtime_ms is not None
|
||||
runtime_us = runtime_ms * 1e3
|
||||
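# 4 = two matmuls per attention layer (QK^T and PV) x 2 FLOPs per
# multiply-accumulate.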
matmul_flops = (
|
||||
4 * q_seq_len * kv_seq_len * head_dim * num_q_heads * batch_size
|
||||
)
|
||||
if causal:
|
||||
matmul_flops //= 2
|
||||
peak_flops = 1e15 # f16 TensorCore peak = 1000 TFLOPS
|
||||
optimal_time = matmul_flops / peak_flops * 1e6 # us
|
||||
achieved_tc_util = optimal_time / runtime_us * 100
|
||||
print(
|
||||
f"block_q={block_q:<4}block_kv={block_kv:<4}: {runtime_us:<7.1f}us"
|
||||
f" = {achieved_tc_util:4.1f}% TC utilization"
|
||||
)
|
||||
if best is None or runtime_us < best[0]:
|
||||
best = (runtime_us, achieved_tc_util)
|
||||
break # Remove this for full autotuning.
|
||||
if best is not None:
|
||||
print(f"Best: {best[0]:<7.1f}us = {best[1]:4.1f}% TC utilization")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from absl import app
|
||||
import jax
|
||||
jax.config.config_with_absl()
|
||||
app.run(main)
|
||||
+343
@@ -0,0 +1,343 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Matrix Multiplication kernel for Blackwell GPUs."""
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import statistics
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu # noqa: F401
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
import jax.experimental.pallas as pl
|
||||
import jax.experimental.pallas.mosaic_gpu as plgpu
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
class MatmulDimension(enum.IntEnum):
|
||||
M = 0
|
||||
N = 1
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TuningConfig:
|
||||
tile_m: int
|
||||
tile_n: int
|
||||
tile_k: int
|
||||
max_concurrent_steps: int
|
||||
collective: bool
|
||||
epilogue_tile_n: int = 64
|
||||
grid_minor_dim: MatmulDimension = MatmulDimension.N
|
||||
grid_tile_width: int = 1
|
||||
|
||||
|
||||
def matmul_kernel(a, b, config: TuningConfig):
|
||||
dtype = a.dtype
|
||||
if a.dtype != b.dtype:
|
||||
raise ValueError(
|
||||
f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}"
|
||||
)
|
||||
m, k = a.shape
|
||||
k2, n = b.shape
|
||||
if k != k2:
|
||||
raise ValueError(
|
||||
f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}"
|
||||
)
|
||||
collective = config.collective
|
||||
tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k)
|
||||
epilogue_tile_n = config.epilogue_tile_n
|
||||
if tile_n % epilogue_tile_n != 0:
|
||||
raise ValueError(
|
||||
f"{tile_n=} must be divisible by {epilogue_tile_n=}"
|
||||
)
|
||||
block_tile_m = tile_m
|
||||
block_tile_n = tile_n
|
||||
if collective:
|
||||
tile_m *= 2
|
||||
tile_n *= 2
|
||||
swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
|
||||
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
|
||||
transforms = (
|
||||
plgpu.TilingTransform((8, swizzle_elems)),
|
||||
plgpu.SwizzleTransform(swizzle),
|
||||
)
|
||||
out_swizzle = plgpu.find_swizzle(epilogue_tile_n * jnp.dtype(dtype).itemsize * 8)
|
||||
out_swizzle_elems = out_swizzle // jnp.dtype(dtype).itemsize
|
||||
out_transforms = (
|
||||
plgpu.TilingTransform((8, out_swizzle_elems)),
|
||||
plgpu.SwizzleTransform(out_swizzle),
|
||||
)
|
||||
if m % tile_m != 0:
|
||||
raise ValueError(f"{m=} must be divisible by {tile_m=}")
|
||||
if n % tile_n != 0:
|
||||
raise ValueError(f"{n=} must be divisible by {tile_n=}")
|
||||
if k % tile_k != 0:
|
||||
raise ValueError(f"{k=} must be divisible by {tile_k=}")
|
||||
m_iters = m // tile_m
|
||||
n_iters = n // tile_n
|
||||
k_iters = k // tile_k
|
||||
max_concurrent_steps = config.max_concurrent_steps
|
||||
|
||||
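# Role split: within the compute warpgroup one warp issues the TMA loads and
# another drives the MMA instructions, while the second warpgroup is
# dedicated to storing results from TMEM back to GMEM.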
TMA_WARP = 0
|
||||
MMA_WARP = 1
|
||||
COMPUTE_WG = 0
|
||||
STORE_WG = 1
|
||||
|
||||
def kernel(a_gmem, b_gmem, out_gmem,
|
||||
a_smem, b_smem, acc_tmem, acc_smem,
|
||||
ab_tma_barrier, store_done_barrier, mma_done_barrier,
|
||||
consumed_barrier):
|
||||
wg_idx = lax.axis_index("wg")
|
||||
cluster_idx = lax.axis_index("x")
|
||||
is_lead_block = cluster_idx == 0
|
||||
|
||||
@plgpu.dynamic_scheduling_loop(grid_names=("mn_linear",), thread_axis="wg")
|
||||
def mn_loop(loop_info: plgpu.NDLoopInfo):
|
||||
(lin_idx,) = loop_info.index
|
||||
local_index = loop_info.local_index
|
||||
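# planar_snake maps the linear grid index to (m, n) tiles, walking the minor
# dimension in stripes of grid_tile_width; this keeps concurrently scheduled
# CTAs on nearby output tiles, which is intended to improve L2 reuse of A/B.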
m_index, n_index = plgpu.planar_snake(
|
||||
lin_idx,
|
||||
(m_iters, n_iters),
|
||||
config.grid_minor_dim,
|
||||
config.grid_tile_width,
|
||||
)
|
||||
block_m_index = m_index * 2 + cluster_idx if collective else m_index
|
||||
|
||||
block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m)
|
||||
slice_m = pl.ds(m_index * tile_m, tile_m)
|
||||
slice_n = pl.ds(n_index * tile_n, tile_n)
|
||||
acc_slot = lax.rem(local_index, jnp.int32(2))
|
||||
|
||||
@pl.when(wg_idx == COMPUTE_WG)
|
||||
def _():
|
||||
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
|
||||
def _per_warp():
|
||||
warp_id = lax.axis_index("warp")
|
||||
@pl.when(warp_id == TMA_WARP)
|
||||
def _memory():
|
||||
def _loop_body(ki, _):
|
||||
slice_k = pl.ds(ki * tile_k, tile_k)
|
||||
slot = lax.rem(ki, max_concurrent_steps)
|
||||
@pl.when(jnp.logical_or(ki >= max_concurrent_steps,
|
||||
local_index > 0))
|
||||
def _():
|
||||
plgpu.barrier_wait(consumed_barrier.at[slot])
|
||||
plgpu.copy_gmem_to_smem(
|
||||
a_gmem.at[slice_m, slice_k],
|
||||
a_smem.at[slot],
|
||||
ab_tma_barrier.at[slot],
|
||||
leader_tracked=plgpu.CopyPartition.PARTITIONED(0)
|
||||
if collective
|
||||
else None,
|
||||
collective_axes="x" if collective else None,
|
||||
)
|
||||
plgpu.copy_gmem_to_smem(
|
||||
b_gmem.at[slice_k, slice_n],
|
||||
b_smem.at[slot],
|
||||
ab_tma_barrier.at[slot],
|
||||
leader_tracked=plgpu.CopyPartition.PARTITIONED(1)
|
||||
if collective
|
||||
else None,
|
||||
collective_axes="x" if collective else None,
|
||||
)
|
||||
lax.fori_loop(0, k_iters, _loop_body, None)
|
||||
|
||||
@pl.when(jnp.logical_and(warp_id == MMA_WARP, local_index > 1))
|
||||
def _wait_store():
|
||||
plgpu.barrier_wait(store_done_barrier.at[acc_slot])
|
||||
@pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block))
|
||||
def _compute():
|
||||
def _loop_body(ki, _):
|
||||
slot = lax.rem(ki, max_concurrent_steps)
|
||||
plgpu.barrier_wait(ab_tma_barrier.at[slot])
|
||||
|
||||
is_last_iter = ki >= k_iters - 1
|
||||
acc_tmem_slice = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)]
|
||||
plgpu.tcgen05_mma(
|
||||
acc_tmem_slice,
|
||||
a_smem.at[slot],
|
||||
b_smem.at[slot],
|
||||
consumed_barrier.at[slot],
|
||||
accumulate=(ki > 0),
|
||||
collective_axis="x" if collective else None,
|
||||
)
|
||||
@pl.when(is_last_iter)
|
||||
def _():
|
||||
plgpu.tcgen05_commit_arrive(
|
||||
mma_done_barrier.at[acc_slot],
|
||||
collective_axis="x" if collective else None,
|
||||
)
|
||||
|
||||
lax.fori_loop(0, k_iters, _loop_body, None)
|
||||
|
||||
@pl.when(wg_idx == STORE_WG)
|
||||
def _():
|
||||
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
|
||||
plgpu.barrier_wait(mma_done_barrier.at[acc_slot])
|
||||
acc_tmem_slot = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)]
|
||||
step_out_gmem = out_gmem.at[block_slice_m, slice_n]
|
||||
for ni in range(tile_n // epilogue_tile_n):
|
||||
acc_smem_ni = acc_smem.at[ni % 2]
|
||||
ni_col_slice = pl.ds(ni * epilogue_tile_n, epilogue_tile_n)
|
||||
acc_smem_ni[...] = plgpu.async_load_tmem(
|
||||
acc_tmem_slot.at[:, ni_col_slice]
|
||||
).astype(dtype)
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(acc_smem_ni, step_out_gmem.at[:, ni_col_slice])
|
||||
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
|
||||
plgpu.wait_load_tmem() # Load must complete before we continue.
|
||||
plgpu.barrier_arrive(store_done_barrier.at[acc_slot])
|
||||
|
||||
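# The accumulator TMEM is double-buffered (two acc slots), so the store-done
# and MMA-done barriers come in pairs: the MMA into one slot can overlap the
# epilogue store of the other. For collective (2-CTA) MMAs the store-done
# signal must be visible across the cluster, hence the ClusterBarrier.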
if collective:
|
||||
store_done_barrier = plgpu.ClusterBarrier(
|
||||
collective_axes=("x",),
|
||||
num_arrivals=1,
|
||||
num_barriers=2,
|
||||
orders_tensor_core=True,
|
||||
)
|
||||
else:
|
||||
store_done_barrier = plgpu.Barrier(
|
||||
num_arrivals=1, num_barriers=2, orders_tensor_core=True
|
||||
)
|
||||
f = plgpu.kernel(
|
||||
kernel,
|
||||
out_shape=jax.ShapeDtypeStruct((m, n), dtype),
|
||||
grid=(m_iters * n_iters,),
|
||||
grid_names=("mn_linear",),
|
||||
num_threads=2,
|
||||
thread_name="wg",
|
||||
cluster_names=("x",),
|
||||
cluster=(1 + collective,),
|
||||
scratch_shapes=dict(
|
||||
a_smem=plgpu.SMEM(
|
||||
(max_concurrent_steps, block_tile_m, tile_k),
|
||||
dtype,
|
||||
transforms=transforms,
|
||||
),
|
||||
b_smem=plgpu.SMEM(
|
||||
(max_concurrent_steps, tile_k, block_tile_n),
|
||||
dtype,
|
||||
transforms=transforms,
|
||||
),
|
||||
acc_tmem=plgpu.TMEM(
|
||||
(block_tile_m, tile_n * 2), jnp.float32, collective=collective
|
||||
),
|
||||
acc_smem=plgpu.SMEM(
|
||||
(2, block_tile_m, epilogue_tile_n),
|
||||
dtype,
|
||||
transforms=out_transforms,
|
||||
),
|
||||
ab_tma_barrier=plgpu.Barrier(
|
||||
num_arrivals=2, num_barriers=max_concurrent_steps
|
||||
),
|
||||
store_done_barrier=store_done_barrier,
|
||||
mma_done_barrier=plgpu.Barrier(
|
||||
num_arrivals=1, num_barriers=2, orders_tensor_core=True
|
||||
),
|
||||
consumed_barrier=plgpu.Barrier(
|
||||
num_arrivals=1,
|
||||
num_barriers=max_concurrent_steps,
|
||||
orders_tensor_core=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
return f(a, b)
|
||||
|
||||
|
||||
def main(_) -> None:
|
||||
problem_it = [(4096, 8192, 4096)]
|
||||
for M, N, K in problem_it:
|
||||
print(f"==== {M=} {N=} {K=} ====")
|
||||
matmul_flops = 2 * M * N * K
|
||||
peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS
|
||||
a = jax.random.uniform(jax.random.key(1), (M, K), jnp.float16, -1, 1)
|
||||
b = jax.random.uniform(jax.random.key(2), (K, N), jnp.float16, -1, 1)
|
||||
tuning_it = itertools.product(
|
||||
(128,), # tile_m
|
||||
(128, 256), # tile_n
|
||||
(64,), # tile_k
|
||||
MatmulDimension, # grid_minor_dim
|
||||
(1, 4, 8, 12, 16), # grid_tile_width
|
||||
(2, 4, 6), # max_concurrent_steps
|
||||
(False, True), # collective
|
||||
(32,), # epilogue_tile_n
|
||||
)
|
||||
best_util = -float("inf")
|
||||
expected = jnp.dot(a, b, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32)
|
||||
for (tile_m, tile_n, tile_k, grid_minor_dim, grid_tile_width,
|
||||
max_concurrent_steps, collective, epilogue_tile_n) in tuning_it:
|
||||
# Only N <= 128 are supported for collective MMAs
|
||||
if collective and tile_n > 128:
|
||||
continue
|
||||
config = TuningConfig(
|
||||
tile_m=tile_m,
|
||||
tile_n=tile_n,
|
||||
tile_k=tile_k,
|
||||
max_concurrent_steps=max_concurrent_steps,
|
||||
collective=collective,
|
||||
epilogue_tile_n=epilogue_tile_n,
|
||||
grid_minor_dim=grid_minor_dim,
|
||||
grid_tile_width=grid_tile_width,
|
||||
)
|
||||
if collective:
|
||||
tile_m *= 2
|
||||
tile_n *= 2
|
||||
try:
|
||||
out, runtimes_ms = profiler.measure(
|
||||
functools.partial(matmul_kernel, config=config), iterations=10
|
||||
)(a, b)
|
||||
assert runtimes_ms is not None
|
||||
runtime_ms = statistics.median(runtimes_ms)
|
||||
except ValueError as e:
|
||||
if ("exceeds available shared memory" in e.args[0] or
|
||||
"Accumulator layout mismatch:" in e.args[0]):
|
||||
# Accumulator layout mismatch triggers for tile_n=256 on some configs.
|
||||
continue
|
||||
raise
|
||||
runtime_us = runtime_ms * 1e3
|
||||
optimal_time = matmul_flops / peak_flops * 1e6 # us
|
||||
achieved_tc_util = optimal_time / runtime_us * 100
|
||||
if achieved_tc_util > best_util:
|
||||
np.testing.assert_allclose(out, expected)
|
||||
best_util = achieved_tc_util
|
||||
print(
|
||||
f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} "
|
||||
f"{grid_minor_dim=} {grid_tile_width=} "
|
||||
f"{epilogue_tile_n=} "
|
||||
f"{collective=} : "
|
||||
f"{runtime_us:<7.1f}us"
|
||||
f" = {achieved_tc_util:4.1f}% TC utilization"
|
||||
)
|
||||
print(f"\tBest utilization: {best_util:4.1f}%")
|
||||
_, runtimes_ms = profiler.measure(
|
||||
functools.partial(
|
||||
jnp.dot, precision=jax.lax.DotAlgorithmPreset.F16_F16_F32
|
||||
),
|
||||
iterations=10,
|
||||
)(a, b)
|
||||
assert runtimes_ms is not None
|
||||
runtime_ms = statistics.median(runtimes_ms)
|
||||
runtime_us = runtime_ms * 1e3
|
||||
optimal_time = matmul_flops / peak_flops * 1e6 # us
|
||||
achieved_tc_util = optimal_time / runtime_us * 100
|
||||
print(f"\tReference: {achieved_tc_util:4.1f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from absl import app
|
||||
|
||||
jax.config.config_with_absl()
|
||||
app.run(main)
|
||||
+441
@@ -0,0 +1,441 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Ragged/Grouped Matrix Multiplication kernel for Blackwell GPUs."""
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu # noqa: F401
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
import jax.experimental.pallas as pl
|
||||
import jax.experimental.pallas.mosaic_gpu as plgpu
|
||||
from jax.experimental.pallas.ops.gpu import blackwell_matmul_mgpu
|
||||
from jax.experimental.pallas.ops.gpu import ragged_dot_mgpu
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TuningConfig:
|
||||
tile_m: int
|
||||
tile_n: int
|
||||
tile_k: int
|
||||
max_concurrent_steps: int
|
||||
collective: bool
|
||||
grid_tile_width: int
|
||||
grid_minor_dim: blackwell_matmul_mgpu.MatmulDimension
|
||||
epilogue_tile_n: int = 64
|
||||
|
||||
def __str__(self):
|
||||
return "_".join(f"{k}={v}" for k, v in dataclasses.asdict(self).items())
|
||||
|
||||
|
||||
# TODO(justinfu): Merge with blackwell_matmul_mgpu.py
|
||||
def do_matmul(a_gmem,
|
||||
b_gmem,
|
||||
out_gmem,
|
||||
grid_indices: Sequence[jax.Array],
|
||||
wg_axis: str,
|
||||
collective_axes: tuple[str, ...],
|
||||
local_index: jax.Array | int,
|
||||
config: TuningConfig,
|
||||
group_info: ragged_dot_mgpu.GroupInfo,
|
||||
a_smem, b_smem, acc_tmem, acc_smem,
|
||||
a_tma_barrier, b_tma_barrier, store_done_barrier, mma_done_barrier,
|
||||
consumed_barrier
|
||||
):
|
||||
"""Compute a non-ragged matmul for a single output block."""
|
||||
dtype = out_gmem.dtype
|
||||
m, k = a_gmem.shape
|
||||
collective = config.collective
|
||||
tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k)
|
||||
epilogue_tile_n = config.epilogue_tile_n
|
||||
max_concurrent_steps = config.max_concurrent_steps
|
||||
block_tile_m = tile_m
|
||||
if collective:
|
||||
tile_m *= 2
|
||||
tile_n *= 2
|
||||
k_iters = k // tile_k
|
||||
|
||||
if collective:
|
||||
m_index, n_index, cluster_idx = grid_indices
|
||||
block_m_index = m_index * 2 + cluster_idx
|
||||
is_lead_block = cluster_idx == 0
|
||||
else:
|
||||
m_index, n_index = grid_indices
|
||||
cluster_idx = 0
|
||||
block_m_index = m_index
|
||||
is_lead_block = True
|
||||
wg_idx = lax.axis_index(wg_axis)
|
||||
collective_axis = collective_axes[0] if collective else None
|
||||
|
||||
TMA_WARP = 0
|
||||
MMA_WARP = 1
|
||||
COMPUTE_WG = 0
|
||||
STORE_WG = 1
|
||||
|
||||
block_slice_m = pl.ds(block_m_index * block_tile_m, block_tile_m)
|
||||
slice_m = pl.ds(m_index * tile_m, tile_m)
|
||||
slice_n = pl.ds(n_index * tile_n, tile_n)
|
||||
acc_slot = lax.rem(local_index, jnp.int32(2))
|
||||
regs_layout = plgpu.Layout.TCGEN05
|
||||
|
||||
@pl.when(wg_idx == COMPUTE_WG)
|
||||
@jax.named_scope("compute_wg")
|
||||
def _():
|
||||
@pl.core_map(plgpu.WarpMesh(axis_name="warp"))
|
||||
def _per_warp():
|
||||
warp_id = lax.axis_index("warp")
|
||||
@pl.when(warp_id == TMA_WARP)
|
||||
def _memory():
|
||||
def _loop_body(ki, _):
|
||||
slice_k = pl.ds(ki * tile_k, tile_k)
|
||||
slot = lax.rem(ki, max_concurrent_steps)
|
||||
@pl.when(jnp.logical_or(ki >= max_concurrent_steps,
|
||||
local_index > 0))
|
||||
def _():
|
||||
plgpu.barrier_wait(consumed_barrier.at[slot])
|
||||
plgpu.copy_gmem_to_smem(
|
||||
a_gmem.at[slice_m, slice_k],
|
||||
a_smem.at[slot],
|
||||
a_tma_barrier.at[slot],
|
||||
leader_tracked=plgpu.CopyPartition.PARTITIONED(0)
|
||||
if collective
|
||||
else None,
|
||||
collective_axes=collective_axis,
|
||||
)
|
||||
plgpu.copy_gmem_to_smem(
|
||||
b_gmem.at[slice_k, slice_n],
|
||||
b_smem.at[slot],
|
||||
b_tma_barrier.at[slot],
|
||||
leader_tracked=plgpu.CopyPartition.PARTITIONED(1)
|
||||
if collective
|
||||
else None,
|
||||
collective_axes=collective_axis,
|
||||
)
|
||||
lax.fori_loop(0, k_iters, _loop_body, None)
|
||||
|
||||
@pl.when(jnp.logical_and(warp_id == MMA_WARP, local_index > 1))
|
||||
def _wait_store():
|
||||
plgpu.barrier_wait(store_done_barrier.at[acc_slot])
|
||||
@pl.when(jnp.logical_and(warp_id == MMA_WARP, is_lead_block))
|
||||
def _compute():
|
||||
def _loop_body(ki, _):
|
||||
slot = lax.rem(ki, max_concurrent_steps)
|
||||
plgpu.barrier_wait(a_tma_barrier.at[slot])
|
||||
plgpu.barrier_wait(b_tma_barrier.at[slot])
|
||||
|
||||
is_last_iter = ki >= k_iters - 1
|
||||
acc_tmem_slice = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)]
|
||||
plgpu.tcgen05_mma(
|
||||
acc_tmem_slice,
|
||||
a_smem.at[slot],
|
||||
b_smem.at[slot],
|
||||
consumed_barrier.at[slot],
|
||||
accumulate=(ki > 0),
|
||||
collective_axis=collective_axis,
|
||||
)
|
||||
@pl.when(is_last_iter)
|
||||
def _():
|
||||
plgpu.tcgen05_commit_arrive(
|
||||
mma_done_barrier.at[acc_slot],
|
||||
collective_axis=collective_axis,
|
||||
)
|
||||
|
||||
lax.fori_loop(0, k_iters, _loop_body, None)
|
||||
|
||||
@pl.when(wg_idx == STORE_WG)
|
||||
@jax.named_scope("store_wg")
|
||||
def _():
|
||||
plgpu.barrier_wait(mma_done_barrier.at[acc_slot])
|
||||
acc_tmem_slot = acc_tmem.at[:, pl.ds(acc_slot * tile_n, tile_n)]
|
||||
step_out_gmem = out_gmem.at[block_slice_m, slice_n]
|
||||
# group_info contains start/size info relative to the logical
|
||||
# tiling (tile_m), but collective matmuls use 2 CTAs per
|
||||
# logical block, so we need to compute the start/size relative to the
|
||||
# current block.
|
||||
# For example, for the following parameters:
|
||||
# block_tile_m=64 (tile_m=128)
|
||||
# group_info.start_within_block=60
|
||||
# group_info.actual_size=37
|
||||
# The requested copy will be split across both blocks
|
||||
# Memory: | Block 0 | Block 1 |
|
||||
# |--- 64 ---|--- 64 ---|
|
||||
# Copy: |-- 37 --|
|
||||
# Where block 0 copies rows 60-64 (4 rows total) and block 1 copies
|
||||
# the remaining rows 64-97 (33 rows total).
|
||||
smem_start = group_info.start_within_block - cluster_idx * block_tile_m
|
||||
smem_start = lax.max(smem_start, jnp.int32(0))
|
||||
def _clamp(min, x, max):
|
||||
return lax.max(lax.min(x, max), min)
|
||||
block0_copy_size = _clamp(
|
||||
jnp.int32(0),
|
||||
block_tile_m - group_info.start_within_block,
|
||||
group_info.actual_size)
|
||||
block_local_size = lax.select(is_lead_block,
|
||||
# block 0 copies up to end of the first block or actual_size,
|
||||
# whichever comes first.
|
||||
block0_copy_size,
|
||||
# block 1 copies the remaining rows that block 0 did not copy.
|
||||
group_info.actual_size - block0_copy_size
|
||||
)
|
||||
for ni in range(tile_n // epilogue_tile_n):
|
||||
acc_smem[...] = plgpu.async_load_tmem(
|
||||
acc_tmem_slot.at[:, pl.ds(ni * epilogue_tile_n, epilogue_tile_n)],
|
||||
layout=regs_layout).astype(dtype)
|
||||
plgpu.commit_smem()
|
||||
cur_smem_idx = smem_start
|
||||
remaining_rows = min(block_tile_m, m)
|
||||
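# SMEM->GMEM copies need static slice sizes, so the dynamic row count
# (block_local_size) is decomposed into power-of-two chunks: each candidate
# chunk is issued under pl.when only if the corresponding bit is set.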
while remaining_rows > 0:
|
||||
const_rows_len = 1 << int(math.log2(remaining_rows))
|
||||
remaining_rows //= 2
|
||||
@pl.when(block_local_size & const_rows_len != 0)
|
||||
def _():
|
||||
o_smem_slice = acc_smem.at[pl.ds(cur_smem_idx, const_rows_len)]
|
||||
o_gref_slice = step_out_gmem.at[
|
||||
pl.ds(cur_smem_idx, const_rows_len),
|
||||
pl.ds(ni * epilogue_tile_n, epilogue_tile_n),
|
||||
]
|
||||
plgpu.copy_smem_to_gmem(o_smem_slice, o_gref_slice)
|
||||
cur_smem_idx += block_local_size & const_rows_len
|
||||
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
|
||||
plgpu.wait_load_tmem() # Load must complete before we continue.
|
||||
plgpu.barrier_arrive(store_done_barrier.at[acc_slot])
|
||||
|
||||
|
||||
def ragged_dot_kernel(a, b, group_sizes, config: TuningConfig):
|
||||
dtype = a.dtype
|
||||
if a.dtype != b.dtype:
|
||||
raise ValueError(
|
||||
f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}"
|
||||
)
|
||||
m, k = a.shape
|
||||
num_groups, k2, n = b.shape
|
||||
if num_groups != group_sizes.shape[0]:
|
||||
raise ValueError("RHS and group_sizes have incompatible shapes.")
|
||||
if k != k2:
|
||||
raise ValueError(
|
||||
"Matmul LHS and RHS have incompatible shapes "
|
||||
f"{a.shape} vs {b.shape[1:]}"
|
||||
)
|
||||
collective = config.collective
|
||||
tile_m, tile_n, tile_k = (config.tile_m, config.tile_n, config.tile_k)
|
||||
block_tile_m = tile_m
|
||||
block_tile_n = tile_n
|
||||
if collective:
|
||||
tile_m *= 2
|
||||
tile_n *= 2
|
||||
m_iters = m // tile_m
|
||||
n_iters = n // tile_n
|
||||
|
||||
max_concurrent_steps = config.max_concurrent_steps
|
||||
epilogue_tile_n = config.epilogue_tile_n
|
||||
if tile_n % epilogue_tile_n != 0:
|
||||
raise ValueError(
|
||||
f"{tile_n=} must be divisible by {epilogue_tile_n=}"
|
||||
)
|
||||
|
||||
if m % tile_m != 0:
|
||||
raise ValueError(f"{m=} must be divisible by {tile_m=}")
|
||||
if n % tile_n != 0:
|
||||
raise ValueError(f"{n=} must be divisible by {tile_n=}")
|
||||
if k % tile_k != 0:
|
||||
raise ValueError(f"{k=} must be divisible by {tile_k=}")
|
||||
swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
|
||||
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
|
||||
transforms = (
|
||||
plgpu.TilingTransform((8, swizzle_elems)),
|
||||
plgpu.SwizzleTransform(swizzle),
|
||||
)
|
||||
|
||||
def kernel(a_gmem, b_gmem, group_sizes_gmem, out_gmem):
|
||||
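# Each ragged group may start in the middle of an m tile, so in the worst
# case every group boundary adds one extra partially-filled tile; the grid
# therefore reserves m_iters + num_groups - 1 tiles along m.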
linear_grid = (m_iters + num_groups - 1) * n_iters
|
||||
group_sizes_regs = [group_sizes_gmem[i] for i in range(num_groups)]
|
||||
cluster_idx = lax.axis_index("x")
|
||||
|
||||
@functools.partial(pl.run_scoped,
|
||||
a_smem=plgpu.SMEM(
|
||||
(max_concurrent_steps, block_tile_m, tile_k),
|
||||
dtype, transforms=transforms
|
||||
),
|
||||
b_smem=plgpu.SMEM(
|
||||
(max_concurrent_steps, tile_k, block_tile_n),
|
||||
dtype, transforms=transforms
|
||||
),
|
||||
# Temporary SMEM used for storing accumulator output to GMEM.
|
||||
acc_smem=plgpu.SMEM(
|
||||
(block_tile_m, epilogue_tile_n), dtype),
|
||||
# a/b_tma_barrier
|
||||
a_tma_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps),
|
||||
b_tma_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=max_concurrent_steps),
|
||||
# store_done_barrier, double-buffered
|
||||
store_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=2,
|
||||
orders_tensor_core=True),
|
||||
# mma_done_barrier, double-buffered
|
||||
mma_done_barrier=plgpu.Barrier(num_arrivals=1, num_barriers=2,
|
||||
orders_tensor_core=True),
|
||||
# consumed_barrier
|
||||
consumed_barrier=plgpu.Barrier(
|
||||
num_arrivals=1,
|
||||
num_barriers=max_concurrent_steps,
|
||||
orders_tensor_core=True,
|
||||
),
|
||||
# Accumulator TMEM (double-buffered)
|
||||
acc_tmem=plgpu.TMEM(
|
||||
(block_tile_m, tile_n * 2), jnp.float32, collective=collective),
|
||||
collective_axes=("wg",)
|
||||
)
|
||||
def _scoped(**ref_kwargs):
|
||||
@plgpu.nd_loop(grid=(linear_grid,),
|
||||
collective_axes="sm")
|
||||
def mn_loop(loop_info: plgpu.NDLoopInfo):
|
||||
linear_idx, = loop_info.index
|
||||
local_index = loop_info.local_index
|
||||
m_index, n_index = plgpu.planar_snake(
|
||||
linear_idx,
|
||||
(m_iters + num_groups - 1, n_iters),
|
||||
config.grid_minor_dim,
|
||||
config.grid_tile_width,
|
||||
)
|
||||
with jax.named_scope("create_group_info"):
|
||||
group_info = ragged_dot_mgpu.GroupInfo.create(
|
||||
group_sizes_regs, tile_m, m_index
|
||||
)
|
||||
do_matmul(
|
||||
a_gmem,
|
||||
b_gmem.at[group_info.group_id],
|
||||
out_gmem,
|
||||
grid_indices=(group_info.block, n_index, cluster_idx),
|
||||
wg_axis="wg",
|
||||
collective_axes=("x",) if collective else (),
|
||||
local_index=local_index,
|
||||
config=config,
|
||||
group_info=group_info,
|
||||
**ref_kwargs
|
||||
)
|
||||
|
||||
num_sms = jax.local_devices()[0].core_count
|
||||
compiler_params = None
|
||||
f = plgpu.kernel(
|
||||
kernel,
|
||||
compiler_params=compiler_params,
|
||||
kernel_name=f"ragged_dot_kernel_{str(config)}",
|
||||
out_shape=jax.ShapeDtypeStruct((m, n), dtype),
|
||||
grid=(num_sms//2,) if collective else (num_sms,),
|
||||
grid_names=("sm",),
|
||||
num_threads=2,
|
||||
thread_name="wg",
|
||||
cluster_names=("x",) if collective else (),
|
||||
cluster=(2,) if collective else (),
|
||||
)
|
||||
return f(a, b, group_sizes)
|
||||
|
||||
|
||||
def ragged_dot_reference(a, b, g):
|
||||
return lax.ragged_dot(a, b, g, preferred_element_type=jnp.float16)
|
||||
|
||||
|
||||
def sample_group_sizes(key: jax.Array,
|
||||
num_groups: int,
|
||||
num_elements: int,
|
||||
alpha: float = 10.0,
|
||||
):
|
||||
"""Sample group sizes.
|
||||
|
||||
Args:
|
||||
key: PRNG key.
|
||||
num_groups: Number of groups to sample.
|
||||
num_elements: Total number of elements to sample.
|
||||
alpha: Shape parameter. The lower the alpha, the more imbalanced the
|
||||
group sizes will be. As alpha approaches infinity, the group sizes
|
||||
approach a uniform distribution.
|
||||
|
||||
Returns:
|
||||
A jax.Array of shape (num_groups,) that sums to num_elements.
|
||||
"""
|
||||
probs_key, sample_key = jax.random.split(key)
|
||||
probs = jax.random.dirichlet(probs_key, jnp.ones((num_groups,)) * alpha)
|
||||
return jax.random.multinomial(
|
||||
sample_key, num_elements, probs).astype(jnp.int32)
|
||||
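# A minimal usage sketch (hypothetical values, mirroring main() below):
#   sizes = sample_group_sizes(jax.random.key(0), num_groups=16,
#                              num_elements=16 * 1024, alpha=10.0)
#   assert int(sizes.sum()) == 16 * 1024  # multinomial sizes sum to the total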
|
||||
|
||||
def main(_) -> None:
|
||||
M = 16 * 1024
|
||||
K = 2048
|
||||
N = 16 * 1024
|
||||
num_groups = 16
|
||||
group_sizes = sample_group_sizes(jax.random.key(0), num_groups, M, alpha=10.0)
|
||||
|
||||
print(f"==== {M=} {N=} {K=} {num_groups=}====")
|
||||
matmul_flops = 2 * M * N * K
|
||||
peak_flops = 2.25e15 # f16 TensorCore peak = 2250 TFLOPS
|
||||
a = jax.random.uniform(jax.random.key(1), (M, K), jnp.float16)
|
||||
b = jax.random.uniform(jax.random.key(2), (num_groups, K, N), jnp.float16)
|
||||
|
||||
tuning_it = itertools.product(
|
||||
(128,), # tile_m
|
||||
(128,), # tile_n
|
||||
(64,), # tile_k
|
||||
(1, 8, 12, 16), # grid_tile_width
|
||||
blackwell_matmul_mgpu.MatmulDimension, # grid_minor_dim
|
||||
(4, 6) # max_concurrent_steps
|
||||
)
|
||||
best_util = -float("inf")
|
||||
for (tile_m, tile_n, tile_k, grid_tile_width, grid_minor_dim,
|
||||
max_concurrent_steps,) in tuning_it:
|
||||
config = TuningConfig(
|
||||
tile_m=tile_m,
|
||||
tile_n=tile_n,
|
||||
tile_k=tile_k,
|
||||
grid_tile_width=grid_tile_width,
|
||||
grid_minor_dim=grid_minor_dim,
|
||||
max_concurrent_steps=max_concurrent_steps,
|
||||
collective=True,
|
||||
)
|
||||
try:
|
||||
out, runtime_ms = profiler.measure(
|
||||
functools.partial(ragged_dot_kernel, config=config),
|
||||
iterations=10
|
||||
)(a, b, group_sizes)
|
||||
assert runtime_ms is not None
runtime_ms = np.median(runtime_ms)
|
||||
except ValueError as e:
|
||||
if ("exceeds available shared memory" in e.args[0] or
|
||||
"Accumulator layout mismatch:" in e.args[0]):
|
||||
print(e.args[0])
|
||||
continue
|
||||
raise
|
||||
expected = ragged_dot_reference(a, b, group_sizes)
|
||||
np.testing.assert_allclose(out, expected)
|
||||
|
||||
runtime_us = runtime_ms * 1e3
|
||||
optimal_time = matmul_flops / peak_flops * 1e6 # us
|
||||
achieved_tc_util = optimal_time / runtime_us * 100
|
||||
if achieved_tc_util > best_util:
|
||||
best_util = achieved_tc_util
|
||||
print(
|
||||
f"{tile_m=} {tile_n=} {tile_k=} {grid_tile_width=} {grid_minor_dim=} {max_concurrent_steps=} "
|
||||
f"{runtime_us:<7.1f}us"
|
||||
f" = {achieved_tc_util:4.1f}% TC utilization"
|
||||
)
|
||||
print(f"\tBest utilization: {best_util:4.1f}%")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from absl import app
|
||||
|
||||
jax.config.config_with_absl()
|
||||
app.run(main)
|
||||
+290
@@ -0,0 +1,290 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A collective matmul kernel implemented using Mosaic GPU."""
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
|
||||
import jax
|
||||
import os
|
||||
from jax import lax
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
from jax.experimental.pallas import mosaic_gpu as plgpu
|
||||
from jax.experimental.pallas.ops.gpu import hopper_matmul_mgpu
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
MatmulDimension = hopper_matmul_mgpu.MatmulDimension
|
||||
TuningConfig = hopper_matmul_mgpu.TuningConfig
|
||||
|
||||
|
||||
def is_nvshmem_used() -> bool:
|
||||
return (
|
||||
"XLA_FLAGS" in os.environ
|
||||
and "--xla_gpu_experimental_enable_nvshmem" in os.environ["XLA_FLAGS"]
|
||||
)
|
||||
|
||||
def all_gather_lhs_matmul(
|
||||
lhs: jax.Array,
|
||||
rhs: jax.Array,
|
||||
axis_name,
|
||||
*,
|
||||
config: hopper_matmul_mgpu.TuningConfig,
|
||||
dtype: jnp.dtype = jnp.float16,
|
||||
) -> jax.Array:
|
||||
if (
|
||||
num_devices := jax.device_count()
|
||||
) != jax.process_count() and num_devices != jax.local_device_count():
|
||||
raise ValueError(
|
||||
"Kernel requires either 1 process per single GPU or 1 process per all"
|
||||
" GPUs."
|
||||
)
|
||||
if (axis_size := lax.axis_size(axis_name)) != num_devices:
|
||||
raise ValueError("The kernel can only work over all devices in a Mesh.")
|
||||
if jnp.dtype(dtype) not in map(jnp.dtype, [jnp.float16, jnp.bfloat16]):
|
||||
raise NotImplementedError(f"Only f16 and bf16 are supported, got {dtype=}")
|
||||
if config.cluster_dimension is not None:
|
||||
raise NotImplementedError("Cluster dimension must be None for all-gather matmuls.")
|
||||
|
||||
m_shard, k = lhs.shape
|
||||
k2, n_shard = rhs.shape
|
||||
if k != k2:
|
||||
raise ValueError(
|
||||
f"lhs and rhs must have the same contraction size, got {k} and {k2}."
|
||||
)
|
||||
if (element_type := lhs.dtype) != rhs.dtype:
|
||||
raise ValueError(
|
||||
f"lhs and rhs must have the same element type, got {element_type} and"
|
||||
f" {rhs.dtype}."
|
||||
)
|
||||
tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
|
||||
max_concurrent_steps = config.max_concurrent_steps
|
||||
if max_concurrent_steps < 2:
|
||||
raise ValueError("max_concurrent_steps must be >= 2")
|
||||
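# When the two compute warpgroups split the M dimension, a single CTA covers
# two tile_m row blocks, so the lhs sends below operate on cta_tile_m rows.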
cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M))
|
||||
|
||||
epi_tile_n = config.epi_tile_n or tile_n
|
||||
epi_tile_m = config.epi_tile_m or tile_m
|
||||
if tile_n % epi_tile_n != 0:
|
||||
raise ValueError(f"{tile_n=} must be divisible by {epi_tile_n=}")
|
||||
if tile_m % epi_tile_m != 0:
|
||||
raise ValueError(f"{tile_m=} must be divisible by {epi_tile_m=}")
|
||||
|
||||
num_sms = jax.devices()[0].core_count # 132 for H100 SXM GPUs.
|
||||
|
||||
def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref):
|
||||
received_sem = pl.get_global(plgpu.SemaphoreType.REGULAR)
|
||||
wg_idx = lax.axis_index("wg")
|
||||
dev_id = lax.axis_index(axis_name)
|
||||
send_dev_id = lax.rem(dev_id + axis_size - 1, jnp.int32(axis_size))
|
||||
send_scratch_ref = plgpu.remote_ref(scratch_ref, send_dev_id)
|
||||
|
||||
def send_lhs(m_idx, n_idx, k_idx, a_smem, b_smem, send_ref, should_send):
|
||||
del b_smem # Unused.
|
||||
# We only send when n_idx == 0 to avoid sending the same data
|
||||
# multiple times when revisiting lhs.
|
||||
@pl.when(should_send & jnp.bool(n_idx == 0))
|
||||
def _():
|
||||
k_slice = pl.ds(k_idx * tile_k, tile_k)
|
||||
m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m)
|
||||
plgpu.copy_smem_to_gmem(a_smem, send_ref.at[m_slice, k_slice])
|
||||
# We only delay release by 1 step, so we need to wait for the
|
||||
# previous copies.
|
||||
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
|
||||
|
||||
def device_step(lhs_source_ref, device_offset):
|
||||
# Invariant: lhs_source_ref is ready to be used
|
||||
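# Ring schedule: at step device_offset each device computes with the lhs
# shard it currently holds while forwarding that same shard to send_dev_id's
# scratch slot, so after num_devices steps every device has consumed every
# shard and produced the full all-gathered output.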
next_scratch_slot = device_offset
|
||||
out_device_m_slice = pl.ds(
|
||||
lax.rem(device_offset + dev_id, jnp.int32(num_devices)) * m_shard,
|
||||
m_shard,
|
||||
)
|
||||
is_send_wg = wg_idx == 0
|
||||
has_send_space = next_scratch_slot < num_devices - 1
|
||||
should_send = is_send_wg & has_send_space
|
||||
|
||||
# This reuses the regular matmul kernel, only with the exception of
|
||||
# inserting send_lhs into the pipeline.
|
||||
# TODO(apaszke): This contains run_scoped inside, meaning that it will
|
||||
# synchronize all threads at each device step. If we optimize the barrier
|
||||
# below, then it might be better to move it out to make bubbles smaller.
|
||||
hopper_matmul_mgpu.kernel(
|
||||
lhs_source_ref, # Use the lhs from previous step.
|
||||
rhs_ref, # Use the same rhs for all steps.
|
||||
None, # No C.
|
||||
out_ref.at[out_device_m_slice], # Use a slice of the output.
|
||||
config=config,
|
||||
pipeline_callback=functools.partial(
|
||||
send_lhs,
|
||||
send_ref=send_scratch_ref.at[next_scratch_slot],
|
||||
should_send=should_send,
|
||||
),
|
||||
delay_release=1,
|
||||
)
|
||||
|
||||
# Wait for the next scratch to arrive --- see the device loop invariant.
|
||||
@pl.when(should_send)
|
||||
def _signal():
|
||||
# TODO(apaszke): We could do this signal a lot earlier if we better
|
||||
# control the order of sends. If we tile the grid along N, then we can
|
||||
# signal as soon as everyone moves on from the first column tile.
|
||||
# Make sure the copy is done and signal the receiving device.
|
||||
plgpu.wait_smem_to_gmem(0, wait_read_only=False)
|
||||
pl.semaphore_signal(received_sem, device_id=send_dev_id)
|
||||
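# Each of the num_sms programs on the sending device signals once per step,
# so we wait until the cumulative count (decrement=False) reaches
# (device_offset + 1) * num_sms before touching the newly received slot.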
@pl.when(next_scratch_slot < num_devices - 1)
|
||||
def _wait():
|
||||
pl.semaphore_wait(received_sem, value=(device_offset + 1) * num_sms, decrement=False)
|
||||
|
||||
# We peel the first step to copy data directly from lhs_local_ref.
|
||||
device_step(lhs_local_ref, 0)
|
||||
@pl.loop(1, num_devices)
|
||||
def _device_loop(device_offset):
|
||||
device_step(scratch_ref.at[device_offset - 1], device_offset)
|
||||
# Make sure all copies are fully done.
|
||||
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
|
||||
|
||||
result, _ = plgpu.kernel(
|
||||
kernel_body,
|
||||
out_shape=[
|
||||
# The output, with its M dimension all-gathered.
|
||||
jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
|
||||
# The scratch buffer used for the all-gather.
|
||||
jax.ShapeDtypeStruct((num_devices - 1, m_shard, k), dtype),
|
||||
],
|
||||
grid=(num_sms,),
|
||||
grid_names=("cluster_grid",),
|
||||
num_threads=3,
|
||||
thread_name="wg",
|
||||
cluster=(1,),
|
||||
cluster_names=("cluster",),
|
||||
)(lhs, rhs)
|
||||
return result
|
||||
|
||||
|
||||
def _min_results_across_devices(kernels_ms: list[tuple[str, float]]) -> float:
|
||||
# We choose the minimum across processes to choose the runtime that didn't
|
||||
# include devices waiting for other devices.
|
||||
if is_nvshmem_used():
|
||||
time_us = sum(t * 1e3 for _, t in kernels_ms)
|
||||
return min(multihost_utils.process_allgather(time_us).tolist())
|
||||
|
||||
# profiler.measure measures all devices visible to the process, so we
|
||||
# need to select the minimum result of each kernel across devices.
|
||||
# This code relies on the fact that with collective metadata a single kernel
|
||||
# with unique name is launched on each device.
|
||||
min_values: dict[str, float] = {}
|
||||
for kernel_name, t in kernels_ms:
|
||||
if kernel_name not in min_values or t < min_values[kernel_name]:
|
||||
min_values[kernel_name] = t
|
||||
return sum(time_ms * 1e3 for time_ms in min_values.values())
|
||||
|
||||
|
||||
def _run_example():
|
||||
P = jax.sharding.PartitionSpec
|
||||
m_shard = 1024
|
||||
n_shard = 4096
|
||||
k = 4096
|
||||
dtype = jnp.bfloat16
|
||||
shards = jax.device_count()
|
||||
mesh = jax.make_mesh(
|
||||
(shards,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
|
||||
)
|
||||
jax.set_mesh(mesh)
|
||||
|
||||
# We measure time per-shard and so we only need FLOPs per shard.
|
||||
matmul_flops = 2 * (shards * m_shard) * n_shard * k
|
||||
peak_flops = 990e12 # f16 TensorCore peak = 990 TFLOPS
|
||||
optimal_time = matmul_flops / peak_flops * 1e6 # us
|
||||
a = jax.random.normal(jax.random.key(1), (shards * m_shard, k), dtype)
|
||||
b = jax.random.normal(jax.random.key(2), (k, shards * n_shard), dtype)
|
||||
a = jax.sharding.reshard(a, P("x", None))
|
||||
b = jax.sharding.reshard(b, P(None, "x"))
|
||||
_, ref_kernels_ms = profiler.measure(jax.jit(
|
||||
jax.shard_map(
|
||||
lambda x, y: lax.all_gather(x, "x", axis=0, tiled=True) @ y,
|
||||
out_specs=P(None, "x"),
|
||||
check_vma=False,
|
||||
)
|
||||
), aggregate=False)(a, b)
|
||||
|
||||
assert ref_kernels_ms is not None
|
||||
ref_time_us = _min_results_across_devices(ref_kernels_ms)
|
||||
ref_util = optimal_time / ref_time_us * 100
|
||||
|
||||
tuning_it = itertools.product(
|
||||
(128, 256,), # tile_m
|
||||
(64, 128), # tile_n
|
||||
(64,), # tile_k
|
||||
(4,), # max_concurrent_steps
|
||||
(MatmulDimension.M, MatmulDimension.N), # grid_minor_dim
|
||||
(4, 8, 16), # grid_tile_width
|
||||
MatmulDimension, # wg_dimension
|
||||
)
|
||||
best_util = 0.0
|
||||
best_runtime = float("inf")
|
||||
def build_kernel(**kwargs):
|
||||
return jax.jit(
|
||||
jax.shard_map(
|
||||
functools.partial(all_gather_lhs_matmul, **kwargs),
|
||||
out_specs=P(None, "x"),
|
||||
check_vma=False,
|
||||
)
|
||||
)
|
||||
|
||||
for tile_m, tile_n, tile_k, max_concurrent_steps, grid_minor_dim, grid_tile_width, wg_dimension in tuning_it:
|
||||
try:
|
||||
config = TuningConfig(
|
||||
tile_m=tile_m,
|
||||
tile_n=tile_n,
|
||||
tile_k=tile_k,
|
||||
max_concurrent_steps=max_concurrent_steps,
|
||||
grid_minor_dim=grid_minor_dim,
|
||||
grid_tile_width=grid_tile_width,
|
||||
wg_dimension=wg_dimension,
|
||||
)
|
||||
_, kernels_ms = profiler.measure(
|
||||
build_kernel(axis_name="x", config=config, dtype=dtype),
|
||||
aggregate=False,
|
||||
)(a, b)
|
||||
except ValueError as e:
|
||||
if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs.
|
||||
continue
|
||||
raise
|
||||
assert kernels_ms is not None
|
||||
runtime_us = _min_results_across_devices(kernels_ms)
|
||||
achieved_tc_util = optimal_time / runtime_us * 100
|
||||
if achieved_tc_util > best_util:
|
||||
best_runtime = runtime_us
|
||||
best_util = achieved_tc_util
|
||||
print(
|
||||
f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} {grid_minor_dim=} {grid_tile_width=} {wg_dimension=}: "
|
||||
f"{runtime_us:<7.1f}us"
|
||||
f" = {achieved_tc_util:4.1f}% TC utilization"
|
||||
)
|
||||
print(f"\tBest: {best_runtime:<7.1f}us = {best_util:4.1f}% TC utilization")
|
||||
print(f"\tRef: {ref_time_us:<7.1f}us = {ref_util:4.1f}% TC utilization")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if is_nvshmem_used():
|
||||
from jax._src import test_multiprocess as jt_multiprocess # pytype: disable=import-error
|
||||
jt_multiprocess.main(shard_main=_run_example)
|
||||
else:
|
||||
from jax._src.config import config as jax_config
|
||||
from absl import app
|
||||
jax_config.config_with_absl()
|
||||
app.run(lambda _: _run_example())
|
||||
+503
@@ -0,0 +1,503 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Module containing decode attention."""
|
||||
from __future__ import annotations
|
||||
import math
|
||||
|
||||
import functools
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
import jax.numpy as jnp
|
||||
|
||||
def attn_forward_kernel(
|
||||
# inputs
|
||||
q_ref, # [num_heads, head_dim]
|
||||
k_ref, # [k_seq_len, head_dim]
|
||||
v_ref, # [k_seq_len, head_dim]
|
||||
start_idx_ref, # [] (i.e., scalar)
|
||||
kv_seq_len_ref, # [] (i.e., scalar)
|
||||
# outputs
|
||||
o_ref: Any, # [num_heads, head_dim]
|
||||
*residual_refs: Any, # Residual outputs: [num_heads,], [num_heads,]
|
||||
sm_scale: float,
|
||||
block_k: int,
|
||||
block_h: int,
|
||||
num_heads: int,
|
||||
):
|
||||
_, head_dim = q_ref.shape
|
||||
split_k_seq_len, _ = k_ref.shape
|
||||
prog_i, prog_j = pl.program_id(0), pl.program_id(1)
|
||||
q_slice = pl.ds(0, block_h)
|
||||
q_mask = (jnp.arange(block_h) < num_heads - block_h * prog_i)[:, None]
|
||||
|
||||
def _compute(start_idx, kv_seq_len, o, m_i, l_i):
|
||||
# Load q: it will stay in L1 throughout. Indices form a matrix because we
|
||||
# read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
|
||||
# q tile has shape [block_h, head_dim].
|
||||
q = plgpu.load(q_ref.at[q_slice, :], mask=q_mask)
|
||||
|
||||
def _dot(a, b):
|
||||
# if a.shape[0] == 1:
|
||||
# # Use matrix vector product
|
||||
# return (a.T * b).sum(axis=0, keepdims=True)
|
||||
return pl.dot(a, b)
|
||||
|
||||
mask_indices = jnp.arange(block_k)
|
||||
|
||||
# Loop over blocks of kv to process the entire kv seq_len.
|
||||
# Grid loops over q blocks over num_heads.
|
||||
def body(start_k, carry):
|
||||
o_prev, m_prev, l_prev = carry
|
||||
curr_k_slice = pl.ds(start_k * block_k, block_k)
|
||||
|
||||
k = k_ref[curr_k_slice, :]
|
||||
qk = _dot(q, k.T) # [block_h, block_k]
|
||||
if sm_scale != 1.0:
|
||||
qk *= sm_scale # [block_h, block_k]
|
||||
|
||||
# apply mask if start or sequence length is specified
|
||||
if start_idx_ref is not None or kv_seq_len_ref is not None:
|
||||
indices = (prog_j * split_k_seq_len + start_k * block_k + mask_indices)
|
||||
mask = ((indices >= start_idx) & (indices < kv_seq_len))[None, :]
|
||||
qk += (~mask) * (0.7 * jnp.finfo(qk.dtype).min)
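# The mask is folded in with a large-but-finite negative value rather than
# -inf, which keeps the subsequent max/exp arithmetic free of inf - inf = NaN
# when an entire row is masked out.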
|
||||
|
||||
m_curr = qk.max(axis=-1)
|
||||
m_next = jnp.maximum(m_prev, m_curr)
|
||||
correction = jnp.exp(m_prev - m_next)
|
||||
l_prev_corr = correction * l_prev
|
||||
s_curr = jnp.exp(
|
||||
qk - m_next[:, None]
|
||||
) # Use m_next instead of m_curr to avoid a correction on l_curr
|
||||
l_curr = s_curr.sum(axis=-1)
|
||||
l_next = l_prev_corr + l_curr
|
||||
v = v_ref[curr_k_slice, :]
|
||||
o_curr = _dot(s_curr.astype(v.dtype), v)
|
||||
|
||||
# flash2 unscaled_o
|
||||
o_next = correction[:, None] * o_prev + o_curr
|
||||
return o_next, m_next, l_next
|
||||
|
||||
max_it = jnp.minimum(pl.cdiv((kv_seq_len - prog_j * split_k_seq_len),
|
||||
block_k), split_k_seq_len // block_k)
|
||||
(o, m_i, l_i) = lax.fori_loop(0, max_it, body, (o, m_i, l_i))
|
||||
return o, m_i, l_i
|
||||
|
||||
# o is the buffer where we accumulate the output on sram.
|
||||
# m_i and l_i (see FlashAttention2 paper) are updated during the k,v loop.
|
||||
m_i = jnp.zeros(block_h, dtype=jnp.float32) + jnp.finfo(jnp.float32).min
|
||||
l_i = jnp.zeros(block_h, dtype=jnp.float32)
|
||||
o = jnp.zeros((block_h, head_dim), dtype=jnp.float32)
|
||||
|
||||
start_idx = split_k_seq_len * prog_j
|
||||
if start_idx_ref is not None:
|
||||
start_idx = jnp.maximum(start_idx, start_idx_ref[()])
|
||||
kv_seq_len = (prog_j + 1) * split_k_seq_len # lower bound on actual k_seq_len
|
||||
if kv_seq_len_ref is not None:
|
||||
kv_seq_len = jnp.minimum(kv_seq_len, kv_seq_len_ref[()])
|
||||
|
||||
if start_idx_ref is None and kv_seq_len_ref is None:
|
||||
o, m_i, l_i = _compute(start_idx, kv_seq_len, o, m_i, l_i)
|
||||
else:
|
||||
o, m_i, l_i = jax.lax.cond(
|
||||
start_idx >= kv_seq_len, lambda: (o, m_i, l_i),
|
||||
lambda: _compute(start_idx, kv_seq_len, o, m_i, l_i))
|
||||
|
||||
# Write output to dram.
|
||||
if residual_refs:
|
||||
l_ref, m_ref = residual_refs
|
||||
vec_q_mask = q_mask.reshape(-1) if q_mask is not None else None
|
||||
plgpu.store(l_ref.at[q_slice], l_i, mask=vec_q_mask)
|
||||
plgpu.store(m_ref.at[q_slice], m_i, mask=vec_q_mask)
|
||||
o = o.astype(o_ref.dtype)
|
||||
plgpu.store(o_ref.at[q_slice, :], o, mask=q_mask)
|
||||
|
||||
|
||||
def decode_attn_unbatched(
|
||||
q, # [num_heads, head_dim]
|
||||
k, # [k_seq_len, head_dim]
|
||||
v, # [k_seq_len, head_dim]
|
||||
start_idx, # []
|
||||
kv_seq_len, # []
|
||||
sm_scale: float,
|
||||
block_h: int,
|
||||
block_k: int,
|
||||
k_splits: int,
|
||||
num_warps: int | None,
|
||||
num_stages: int,
|
||||
grid: tuple[int, ...] | None,
|
||||
interpret: bool,
|
||||
debug: bool,
|
||||
return_residuals: bool,
|
||||
normalize_output: bool
|
||||
):
|
||||
num_heads, head_dim = q.shape
|
||||
k_seq_len, _ = k.shape
|
||||
# Pad num query heads to 16 if needed, and slice output at the end.
|
||||
head_splits = pl.cdiv(num_heads, block_h)
|
||||
grid_ = grid
|
||||
if grid_ is None:
|
||||
grid_ = (head_splits, k_splits)
|
||||
|
||||
assert (
|
||||
k_seq_len % k_splits == 0
|
||||
), f"{k_seq_len=} must be divisible by {k_splits=}"
|
||||
assert k_seq_len // k_splits >= 16, (
|
||||
f"{k_seq_len=} divided by {k_splits=} must be >= 16.")
|
||||
assert block_k >= 16, "block_k must be >= 16"
|
||||
k = k.reshape(k_splits, k_seq_len // k_splits, head_dim)
|
||||
v = v.reshape(k_splits, k_seq_len // k_splits, head_dim)
|
||||
split_k_seq_len = k_seq_len // k_splits
|
||||
block_k = min(block_k, split_k_seq_len)
|
||||
assert split_k_seq_len % block_k == 0, (
|
||||
f"Sequence length ({k_seq_len=}) split by {k_splits=} must by divisible by"
|
||||
f" {block_k=}")
|
||||
num_warps_ = num_warps
|
||||
if num_warps_ is None:
|
||||
num_warps_ = 4
|
||||
kernel = functools.partial(
|
||||
attn_forward_kernel,
|
||||
sm_scale=sm_scale,
|
||||
block_k=block_k,
|
||||
block_h=block_h,
|
||||
num_heads=num_heads,
|
||||
)
|
||||
|
||||
o, l, m = pl.pallas_call(
|
||||
kernel,
|
||||
grid=grid_,
|
||||
in_specs=[
|
||||
pl.BlockSpec((block_h, head_dim), lambda i, j: (i, 0)),
|
||||
pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)),
|
||||
pl.BlockSpec((None, split_k_seq_len, head_dim), lambda i, j: (j, 0, 0)),
|
||||
]
|
||||
+ [None if start_idx is None else pl.BlockSpec((), lambda i, j: ())]
|
||||
+ [None if kv_seq_len is None else pl.BlockSpec((), lambda i, j: ())],
|
||||
out_specs=[
|
||||
pl.BlockSpec((None, block_h, head_dim), lambda i, j: (j, i, 0)), # o
|
||||
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l
|
||||
pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m
|
||||
],
|
||||
compiler_params=plgpu.CompilerParams(
|
||||
num_warps=num_warps_, num_stages=num_stages
|
||||
),
|
||||
out_shape=[
|
||||
jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o
|
||||
jax.ShapeDtypeStruct(
|
||||
shape=(k_splits, num_heads), dtype=jnp.float32
|
||||
), # l
|
||||
jax.ShapeDtypeStruct(
|
||||
shape=(k_splits, num_heads), dtype=jnp.float32
|
||||
), # m
|
||||
],
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
name="mha_forward",
|
||||
)(q, k, v, start_idx, kv_seq_len)
|
||||
|
||||
# final round of flash
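# Each k split j produced an unnormalized partial output o_j along with its row
# max m_j and softmax denominator l_j. The partials are merged with the usual
# log-sum-exp rescaling:
#   m = max_j m_j,  l = sum_j exp(m_j - m) * l_j,  o = sum_j exp(m_j - m) * o_j
# and o is divided by l at the end when normalize_output is set.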
|
||||
m_next = m.max(axis=0)
|
||||
# TODO(b/389925439): This barrier is necessary to prevent NaNs/invalid
|
||||
# values appearing after JIT compilation.
|
||||
m_next = lax.optimization_barrier(m_next)
|
||||
correction = jnp.exp(m - m_next[None])
|
||||
o = o * correction[:, :, None].astype(o.dtype)
|
||||
l_next = (l * correction).sum(axis=0)
|
||||
eps = jnp.finfo(l_next.dtype).eps
|
||||
o = o.sum(axis=0)
|
||||
if normalize_output:
|
||||
o /= (l_next[:, None].astype(o.dtype) + eps)
|
||||
if return_residuals:
|
||||
return o, (l_next, m_next)
|
||||
else:
|
||||
return o
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=[
|
||||
"sm_scale",
|
||||
"block_h",
|
||||
"block_k",
|
||||
"k_splits",
|
||||
"num_warps",
|
||||
"num_stages",
|
||||
"grid",
|
||||
"interpret",
|
||||
"debug",
|
||||
"return_residuals",
|
||||
"normalize_output"
|
||||
],
|
||||
)
|
||||
def mqa(
|
||||
q, # [batch_size, num_heads, head_dim]
|
||||
k, # [batch_size, k_seq_len, head_dim]
|
||||
v, # [batch_size, k_seq_len, head_dim]
|
||||
start_idx=None, # [batch_size]
|
||||
kv_seq_len=None, # [batch_size]
|
||||
sm_scale: float | None = None,
|
||||
block_h: int = 16,
|
||||
block_k: int = 256,
|
||||
k_splits: int = 16,
|
||||
num_warps: int | None = None,
|
||||
num_stages: int = 2,
|
||||
grid: tuple[int, ...] | None = None,
|
||||
interpret: bool = False,
|
||||
debug: bool = False,
|
||||
return_residuals: bool = False,
|
||||
normalize_output: bool = True,
|
||||
):
|
||||
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
|
||||
bs = q.shape[0]
|
||||
if start_idx is not None:
|
||||
start_idx = jnp.broadcast_to(start_idx, (bs,))
|
||||
if kv_seq_len is not None:
|
||||
kv_seq_len = jnp.broadcast_to(kv_seq_len, (bs,))
|
||||
inner = functools.partial(
|
||||
decode_attn_unbatched,
|
||||
sm_scale=sm_scale,
|
||||
block_h=block_h,
|
||||
block_k=block_k,
|
||||
k_splits=k_splits,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
grid=grid,
|
||||
interpret=interpret,
|
||||
debug=debug,
|
||||
return_residuals=return_residuals,
|
||||
normalize_output=normalize_output,
|
||||
)
|
||||
return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len)
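# Illustrative usage sketch (shapes below are assumptions, not part of the API):
#   q = jnp.zeros((2, 16, 128), jnp.float16)    # [batch, num_heads, head_dim]
#   k = jnp.zeros((2, 4096, 128), jnp.float16)  # [batch, k_seq_len, head_dim]
#   v = jnp.zeros((2, 4096, 128), jnp.float16)
#   out = mqa(q, k, v, kv_seq_len=jnp.array([4096, 2048]))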
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=[
|
||||
"sm_scale",
|
||||
"block_h",
|
||||
"block_k",
|
||||
"k_splits",
|
||||
"num_warps",
|
||||
"num_stages",
|
||||
"grid",
|
||||
"interpret",
|
||||
"debug",
|
||||
"return_residuals",
|
||||
"normalize_output"
|
||||
],
|
||||
)
|
||||
def gqa(
|
||||
q, # [batch_size, num_q_heads, head_dim]
|
||||
k, # [batch_size, k_seq_len, num_kv_heads, head_dim]
|
||||
v, # [batch_size, k_seq_len, num_kv_heads, head_dim]
|
||||
start_idx=None, # [batch_size]
|
||||
kv_seq_len=None, # [batch_size]
|
||||
sm_scale: float | None = None,
|
||||
block_h: int = 16,
|
||||
block_k: int = 128,
|
||||
k_splits: int = 16,
|
||||
num_warps: int | None = None,
|
||||
num_stages: int = 2,
|
||||
grid: tuple[int, ...] | None = None,
|
||||
interpret: bool = False,
|
||||
debug: bool = False,
|
||||
return_residuals: bool = False,
|
||||
normalize_output: bool = True,
|
||||
):
|
||||
if not normalize_output and not return_residuals:
|
||||
raise NotImplementedError(
|
||||
"When normalize_output is False, attention residuals must be returned."
|
||||
)
|
||||
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
|
||||
batch_size, q_heads, head_dim = q.shape
|
||||
_k_seq_len, kv_heads = k.shape[1], k.shape[2]
|
||||
assert kv_heads == v.shape[2]
|
||||
assert q_heads % kv_heads == 0
|
||||
if start_idx is not None:
|
||||
assert start_idx.ndim in (0, 1)
|
||||
start_idx = jnp.broadcast_to(jnp.asarray(start_idx)[..., None],
|
||||
(batch_size, kv_heads))
|
||||
if kv_seq_len is not None:
|
||||
assert kv_seq_len.ndim in (0, 1)
|
||||
kv_seq_len = jnp.broadcast_to(jnp.asarray(kv_seq_len)[..., None],
|
||||
(batch_size, kv_heads))
|
||||
q_heads_per_kv_head = q_heads // kv_heads
|
||||
q_reshaped = q.reshape(batch_size, kv_heads, q_heads_per_kv_head, head_dim)
|
||||
k_transposed = jnp.swapaxes(
|
||||
k, 1, 2
|
||||
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
|
||||
v_transposed = jnp.swapaxes(
|
||||
v, 1, 2
|
||||
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
|
||||
inner = functools.partial(
|
||||
decode_attn_unbatched,
|
||||
sm_scale=sm_scale,
|
||||
block_h=block_h,
|
||||
block_k=block_k,
|
||||
k_splits=k_splits,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
grid=grid,
|
||||
interpret=interpret,
|
||||
debug=debug,
|
||||
return_residuals=return_residuals,
|
||||
normalize_output=normalize_output,
|
||||
)
|
||||
with_kv_heads = jax.vmap(inner)
|
||||
outputs = jax.vmap(with_kv_heads)(
|
||||
q_reshaped, k_transposed, v_transposed, start_idx, kv_seq_len
|
||||
)
|
||||
if return_residuals:
|
||||
o, (l, m) = outputs
|
||||
o = o.reshape(batch_size, q_heads, head_dim)
|
||||
l = l.reshape(batch_size, q_heads)
|
||||
m = m.reshape(batch_size, q_heads)
|
||||
return o, (l, m)
|
||||
else:
|
||||
o = outputs
|
||||
o = o.reshape(batch_size, q_heads, head_dim)
|
||||
return o
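# Illustrative usage sketch (shapes are assumptions): 32 query heads sharing
# 8 KV heads, i.e. 4 query heads per KV head.
#   q = jnp.zeros((2, 32, 128), jnp.float16)       # [batch, num_q_heads, head_dim]
#   k = jnp.zeros((2, 4096, 8, 128), jnp.float16)  # [batch, k_seq_len, num_kv_heads, head_dim]
#   v = jnp.zeros((2, 4096, 8, 128), jnp.float16)
#   out = gqa(q, k, v)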
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=["sm_scale", "return_residuals", "normalize_output"],
|
||||
)
|
||||
def mqa_reference(
|
||||
q, # [bs, num_q_heads, head_dim]
|
||||
k, # [bs, k_seq_len, head_dim]
|
||||
v, # [bs, k_seq_len, head_dim]
|
||||
start_idx=None, # [bs]
|
||||
kv_seq_len=None, # [bs]
|
||||
sm_scale=None,
|
||||
return_residuals=False,
|
||||
normalize_output=True,
|
||||
):
|
||||
original_dtype = q.dtype
|
||||
q = q.astype(jnp.float32)
|
||||
k = k.astype(jnp.float32)
|
||||
bs = q.shape[0]
|
||||
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
|
||||
logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32)
|
||||
if sm_scale is not None and sm_scale != 1.0:
|
||||
logits = logits * sm_scale
|
||||
if start_idx is not None or kv_seq_len is not None:
|
||||
start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
|
||||
kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
|
||||
else kv_seq_len, (bs,))
|
||||
mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None])
|
||||
& (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
|
||||
mask = mask[:, None, :]
|
||||
logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
|
||||
|
||||
m = logits.max(axis=-1)
|
||||
s = jnp.exp(logits - m[..., None])
|
||||
l = s.sum(axis=-1)
|
||||
if normalize_output:
|
||||
s = s / l[..., None]
|
||||
o = jnp.einsum("bns,bsd->bnd", s, v).astype(original_dtype)
|
||||
|
||||
if return_residuals:
|
||||
return o, (l, m)
|
||||
else:
|
||||
return o
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=["sm_scale"])
|
||||
def mha_reference(
|
||||
q, # [bs, num_q_heads, head_dim]
|
||||
k, # [bs, k_seq_len, num_k_heads, head_dim]
|
||||
v, # [bs, k_seq_len, num_v_heads, head_dim]
|
||||
start_idx=None, # [bs]
|
||||
kv_seq_len=None, # [bs]
|
||||
sm_scale=None,
|
||||
):
|
||||
bs = q.shape[0]
|
||||
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
|
||||
assert q.shape[1] == k.shape[2]
|
||||
logits = jnp.einsum("bnd,bsnd->bns", q, k).astype(jnp.float32)
|
||||
if start_idx is not None or kv_seq_len is not None:
|
||||
start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
|
||||
kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
|
||||
else kv_seq_len, (bs,))
|
||||
mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None])
|
||||
& (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
|
||||
mask = mask[:, None, :]
|
||||
logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
|
||||
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
|
||||
return jnp.einsum("bns,bsnd->bnd", weights, v)
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=["sm_scale", "return_residuals", "normalize_output"],
|
||||
)
|
||||
def gqa_reference(
|
||||
q, # [bs, num_q_heads, head_dim]
|
||||
k, # [bs, k_seq_len, num_k_heads, head_dim]
|
||||
v, # [bs, k_seq_len, num_v_heads, head_dim]
|
||||
start_idx=None, # [bs]
|
||||
kv_seq_len=None, # [bs]
|
||||
sm_scale=None,
|
||||
return_residuals=False,
|
||||
normalize_output=True
|
||||
):
|
||||
original_dtype = q.dtype
|
||||
q = q.astype(jnp.float32)
|
||||
k = k.astype(jnp.float32)
|
||||
sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
|
||||
bs, num_q_heads, head_dim = q.shape
|
||||
num_kv_heads = k.shape[2]
|
||||
assert num_q_heads % num_kv_heads == 0
|
||||
q_reshaped = q.reshape(
|
||||
bs, num_kv_heads, num_q_heads // num_kv_heads, head_dim
|
||||
)
|
||||
k_transposed = jnp.swapaxes(
|
||||
k, 1, 2
|
||||
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
|
||||
v_transposed = jnp.swapaxes(
|
||||
v, 1, 2
|
||||
) # [batch_size, num_kv_heads, k_seq_len, head_dim]
|
||||
logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype(
|
||||
jnp.float32
|
||||
)
|
||||
if sm_scale is not None and sm_scale != 1.0:
|
||||
logits = logits * sm_scale
|
||||
if start_idx is not None or kv_seq_len is not None:
|
||||
start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
|
||||
kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
|
||||
else kv_seq_len, (bs,))
|
||||
mask = ((jnp.arange(k.shape[1])[None, :] >= start_idx[:, None])
|
||||
& (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
|
||||
mask = mask[:, None, None, :]
|
||||
logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
|
||||
|
||||
m = logits.max(axis=-1)
|
||||
s = jnp.exp(logits - m[..., None])
|
||||
l = s.sum(axis=-1)
|
||||
if normalize_output:
|
||||
s = s / l[..., None]
|
||||
o = jnp.einsum("bkgs,bksd->bkgd", s, v_transposed).astype(original_dtype)
|
||||
o = o.reshape(bs, num_q_heads, head_dim)
|
||||
|
||||
if return_residuals:
|
||||
l = l.reshape(bs, num_q_heads)
|
||||
m = m.reshape(bs, num_q_heads)
|
||||
return o, (l, m)
|
||||
else:
|
||||
return o
|
||||
+329
@@ -0,0 +1,329 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Matrix Multiplication kernel for Hopper GPUs."""
|
||||
import statistics
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu # noqa: F401
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
import jax.experimental.pallas as pl
|
||||
import jax.experimental.pallas.mosaic_gpu as plgpu
|
||||
from jax.extend import backend
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MatmulDimension(enum.IntEnum):
|
||||
M = 0
|
||||
N = 1
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TuningConfig:
|
||||
tile_m: int
|
||||
tile_n: int
|
||||
tile_k: int
|
||||
max_concurrent_steps: int
|
||||
epi_tile_n: int | None = 64  # This needs to be lowered for small N.
|
||||
epi_tile_m: int | None = 64
|
||||
grid_minor_dim: MatmulDimension = MatmulDimension.N
|
||||
grid_tile_width: int = 1
|
||||
wg_dimension: MatmulDimension = MatmulDimension.N
|
||||
cluster_dimension: None | MatmulDimension = None
|
||||
|
||||
|
||||
# pipeline_callback and delay_release are only used for collective matmuls.
|
||||
def kernel(a_gmem, b_gmem, c_gmem, out_gmem, config: TuningConfig,
|
||||
pipeline_callback=None, delay_release=0):
|
||||
dtype = a_gmem.dtype
|
||||
out_dtype = out_gmem.dtype
|
||||
assert b_gmem.dtype == dtype
|
||||
if c_gmem is not None:
|
||||
assert c_gmem.dtype == out_dtype
|
||||
m, k = a_gmem.shape
|
||||
k2, n = b_gmem.shape
|
||||
assert k == k2
|
||||
tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
|
||||
max_concurrent_steps = config.max_concurrent_steps
|
||||
swizzle = plgpu.find_swizzle(tile_k * jnp.dtype(dtype).itemsize * 8)
|
||||
swizzle_elems = swizzle // jnp.dtype(dtype).itemsize
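# E.g. (assuming f16/bf16 inputs and tile_k=64): 64 elements * 2 bytes = 128
# bytes per row, which selects the 128-byte swizzle, i.e. swizzle_elems == 64.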
|
||||
transforms = (
|
||||
plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle)
|
||||
)
|
||||
|
||||
cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M))
|
||||
cta_tile_n = tile_n * (1 + (config.wg_dimension == MatmulDimension.N))
|
||||
cluster_tile_m = cta_tile_m * (1 + (config.cluster_dimension == MatmulDimension.M))
|
||||
cluster_tile_n = cta_tile_n * (1 + (config.cluster_dimension == MatmulDimension.N))
|
||||
if m % cluster_tile_m != 0:
|
||||
raise ValueError(f"{m=} must be divisible by {cluster_tile_m} for the given config")
|
||||
if n % cluster_tile_n != 0:
|
||||
raise ValueError(f"{n=} must be divisible by {cluster_tile_n} for the given config")
|
||||
if k % tile_k != 0:
|
||||
raise ValueError(f"{k=} must be divisible by {tile_k=}")
|
||||
m_iters = m // cluster_tile_m
|
||||
n_iters = n // cluster_tile_n
|
||||
k_iters = k // tile_k
|
||||
|
||||
epi_tile_m = config.epi_tile_m or tile_m
|
||||
epi_tile_n = config.epi_tile_n or tile_n
|
||||
# We don't need multiple slots if there's only one epilogue tile.
|
||||
num_out_slots = min(2, (tile_m * tile_n) // (epi_tile_m * epi_tile_n))
|
||||
out_swizzle = plgpu.find_swizzle(epi_tile_n * jnp.dtype(out_dtype).itemsize * 8)
|
||||
out_swizzle_elems = out_swizzle // jnp.dtype(out_dtype).itemsize
|
||||
out_transforms = (
|
||||
plgpu.TilingTransform((8, out_swizzle_elems)),
|
||||
plgpu.SwizzleTransform(out_swizzle),
|
||||
)
|
||||
|
||||
def get_pipeline(pipeline_body, compute_context):
|
||||
return plgpu.emit_pipeline_warp_specialized(
|
||||
pipeline_body,
|
||||
grid=(k_iters,),
|
||||
memory_registers=40,
|
||||
in_specs=[
|
||||
plgpu.BlockSpec(
|
||||
(cta_tile_m, tile_k),
|
||||
lambda k: (0, k),
|
||||
transforms=transforms,
|
||||
memory_space=plgpu.SMEM,
|
||||
delay_release=delay_release,
|
||||
collective_axes=("cluster",)
|
||||
if config.cluster_dimension == MatmulDimension.N
|
||||
else (),
|
||||
),
|
||||
plgpu.BlockSpec(
|
||||
(tile_k, cta_tile_n),
|
||||
lambda k: (k, 0),
|
||||
transforms=transforms,
|
||||
memory_space=plgpu.SMEM,
|
||||
delay_release=delay_release,
|
||||
collective_axes=("cluster",)
|
||||
if config.cluster_dimension == MatmulDimension.M
|
||||
else (),
|
||||
),
|
||||
],
|
||||
wg_axis="wg",
|
||||
num_compute_wgs=2,
|
||||
max_concurrent_steps=max_concurrent_steps,
|
||||
compute_context=compute_context,
|
||||
)
|
||||
|
||||
# Functions don't influence the allocations necessary to run the pipeline.
|
||||
ignore = lambda *_, **__: None
|
||||
@functools.partial(
|
||||
pl.run_scoped,
|
||||
pipeline_allocs=get_pipeline(ignore, ignore).get_allocations(a_gmem, b_gmem),
|
||||
out_smem=plgpu.SMEM(
|
||||
(2, num_out_slots, epi_tile_m, epi_tile_n),
|
||||
out_dtype,
|
||||
transforms=out_transforms,
|
||||
),
|
||||
c_barrier=None if c_gmem is None else plgpu.Barrier(num_barriers=2 * num_out_slots),
|
||||
collective_axes="wg",
|
||||
)
|
||||
def _pipeline_scope(pipeline_allocs, out_smem, c_barrier):
|
||||
wg_idx = lax.axis_index("wg")
|
||||
cta_idx = lax.axis_index("cluster")
|
||||
@plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
|
||||
def _mn_loop(loop_info: plgpu.NDLoopInfo):
|
||||
(lin_idx,) = loop_info.index
|
||||
m_cluster_idx, n_cluster_idx = plgpu.planar_snake(
|
||||
lin_idx,
|
||||
(m_iters, n_iters),
|
||||
config.grid_minor_dim,
|
||||
config.grid_tile_width,
|
||||
)
|
||||
m_idx = m_cluster_idx
|
||||
n_idx = n_cluster_idx
|
||||
if config.cluster_dimension == MatmulDimension.M:
|
||||
m_idx = m_cluster_idx * 2 + cta_idx
|
||||
elif config.cluster_dimension == MatmulDimension.N:
|
||||
n_idx = n_cluster_idx * 2 + cta_idx
|
||||
cta_m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m)
|
||||
cta_n_slice = pl.ds(n_idx * cta_tile_n, cta_tile_n)
|
||||
if config.wg_dimension == MatmulDimension.M:
|
||||
wg_m_slice = pl.ds(wg_idx * tile_m, tile_m)
|
||||
wg_n_slice = slice(None)
|
||||
else:
|
||||
wg_m_slice = slice(None)
|
||||
wg_n_slice = pl.ds(wg_idx * tile_n, tile_n)
|
||||
|
||||
def compute_context(eval_pipeline):
|
||||
@functools.partial(
|
||||
pl.run_scoped, acc_ref=plgpu.ACC((tile_m, tile_n), jnp.float32)
|
||||
)
|
||||
def _acc_scope(acc_ref):
|
||||
eval_pipeline(acc_ref)
|
||||
acc = acc_ref[...].astype(out_dtype)
|
||||
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
|
||||
for epi_mi in range(tile_m // epi_tile_m):
|
||||
for epi_ni in range(tile_n // epi_tile_n):
|
||||
epi_m_slice = slice(epi_mi * epi_tile_m, (epi_mi + 1) * epi_tile_m)
|
||||
epi_n_slice = slice(epi_ni * epi_tile_n, (epi_ni + 1) * epi_tile_n)
|
||||
slot = (epi_mi * (tile_n // epi_tile_n) + epi_ni) % 2
|
||||
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
|
||||
if c_gmem is None:
|
||||
out_smem[wg_idx, slot] = acc[epi_m_slice, epi_n_slice]
|
||||
else:
|
||||
# TODO: Consider using triple-buffering so as not to end up issuing
|
||||
# the copy and immediately blocking on it
|
||||
plgpu.copy_gmem_to_smem(
|
||||
c_gmem.at[cta_m_slice, cta_n_slice]
|
||||
.at[wg_m_slice, wg_n_slice]
|
||||
.at[epi_m_slice, epi_n_slice],
|
||||
out_smem.at[wg_idx, slot],
|
||||
c_barrier.at[wg_idx * num_out_slots + slot],
|
||||
)
|
||||
plgpu.barrier_wait(c_barrier.at[wg_idx * num_out_slots + slot])
|
||||
out_smem[wg_idx, slot] += acc[epi_m_slice, epi_n_slice]
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(
|
||||
out_smem.at[wg_idx, slot],
|
||||
out_gmem.at[cta_m_slice, cta_n_slice]
|
||||
.at[wg_m_slice, wg_n_slice]
|
||||
.at[epi_m_slice, epi_n_slice],
|
||||
)
|
||||
|
||||
def mma_body(idxs, a_smem, b_smem, acc_ref):
|
||||
plgpu.wgmma(acc_ref, a_smem.at[wg_m_slice], b_smem.at[:, wg_n_slice])
|
||||
if pipeline_callback is not None:
|
||||
(k_idx,) = idxs
|
||||
pipeline_callback(m_idx, n_idx, k_idx, a_smem, b_smem)
|
||||
plgpu.wgmma_wait(delay_release)
|
||||
return acc_ref
|
||||
|
||||
get_pipeline(mma_body, compute_context)(
|
||||
a_gmem.at[cta_m_slice, :],
|
||||
b_gmem.at[:, cta_n_slice],
|
||||
allocations=pipeline_allocs,
|
||||
)
|
||||
# Await all transfers before we exit.
|
||||
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
|
||||
|
||||
|
||||
def matmul(a, b, c, config: TuningConfig):
|
||||
dtype = a.dtype
|
||||
if a.dtype != b.dtype:
|
||||
raise ValueError(
|
||||
f"Matmul LHS and RHS have incompatible dtypes {a.dtype} vs {b.dtype}"
|
||||
)
|
||||
m, k = a.shape
|
||||
k2, n = b.shape
|
||||
|
||||
if k != k2:
|
||||
raise ValueError(
|
||||
f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}"
|
||||
)
|
||||
if c is None:
|
||||
out_dtype = dtype
|
||||
else:
|
||||
if c.shape != (m, n):
|
||||
raise ValueError(f"C has incompatible shape {c.shape} vs {(m, n)}")
|
||||
out_dtype = c.dtype
|
||||
tile_m, tile_n = config.tile_m, config.tile_n
|
||||
epi_tile_n = config.epi_tile_n or tile_n
|
||||
epi_tile_m = config.epi_tile_m or tile_m
|
||||
config = dataclasses.replace(config, epi_tile_n=epi_tile_n, epi_tile_m=epi_tile_m)
|
||||
|
||||
num_sms = backend.get_default_device().core_count
|
||||
cluster_size = 1 + (config.cluster_dimension is not None)
|
||||
f = plgpu.kernel(
|
||||
functools.partial(kernel, config=config),
|
||||
out_shape=jax.ShapeDtypeStruct((m, n), out_dtype),
|
||||
grid=(num_sms // cluster_size,),
|
||||
grid_names=("cluster_grid",),
|
||||
cluster=(cluster_size,),
|
||||
cluster_names=("cluster",),
|
||||
num_threads=3,
|
||||
thread_name="wg",
|
||||
)
|
||||
return f(a, b, c)
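# Illustrative usage sketch (sizes and config values below are assumptions):
#   a = jnp.zeros((4096, 4096), jnp.float16)
#   b = jnp.zeros((4096, 8192), jnp.float16)
#   cfg = TuningConfig(tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4)
#   out = matmul(a, b, None, cfg)  # c=None means there is no input to accumulate into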
|
||||
|
||||
|
||||
def main(_) -> None:
|
||||
problem_it = [(4096, 8192, 4096)]
|
||||
for M, N, K in problem_it:
|
||||
print(f"==== {M=} {N=} {K=} ====")
|
||||
matmul_flops = 2 * M * N * K
|
||||
peak_flops = 990e12 # f16 TensorCore peak = 990 TFLOPS
|
||||
a = jax.random.uniform(jax.random.key(0), (M, K), jnp.float16)
|
||||
b = jax.random.uniform(jax.random.key(1), (K, N), jnp.float16)
|
||||
ref = a @ b
|
||||
tuning_it = itertools.product(
|
||||
(128, 256,), # tile_m
|
||||
(64, 128), # tile_n
|
||||
(64,), # tile_k
|
||||
(4,), # max_concurrent_steps
|
||||
(True,), # Tiled epilogue
|
||||
(MatmulDimension.M, MatmulDimension.N), # grid_minor_dim
|
||||
(4, 8, 16), # grid_tile_width
|
||||
MatmulDimension, # wg_dimension
|
||||
# Consider adding MatmulDimension here to try out collective TMA kernels
|
||||
(None,) # cluster_dimension
|
||||
)
|
||||
best_util = 0.0
|
||||
best_runtime = float("inf")
|
||||
for tile_m, tile_n, tile_k, max_concurrent_steps, tiled_epilogue, grid_minor_dim, grid_tile_width, wg_dimension, cluster_dimension in tuning_it:
|
||||
config = TuningConfig(
|
||||
tile_m=tile_m,
|
||||
tile_n=tile_n,
|
||||
tile_k=tile_k,
|
||||
max_concurrent_steps=max_concurrent_steps,
|
||||
epi_tile_n=64 if tiled_epilogue else None,
|
||||
epi_tile_m=64 if tiled_epilogue else None,
|
||||
grid_minor_dim=grid_minor_dim,
|
||||
grid_tile_width=grid_tile_width,
|
||||
wg_dimension=wg_dimension,
|
||||
cluster_dimension=cluster_dimension,
|
||||
)
|
||||
try:
|
||||
out, runtimes_ms = profiler.measure(
|
||||
functools.partial(matmul, config=config), iterations=10,
|
||||
)(a, b, None)
|
||||
assert runtimes_ms is not None
|
||||
runtime_ms = statistics.median(runtimes_ms)
|
||||
except ValueError as e:
|
||||
if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs.
|
||||
continue
|
||||
raise
|
||||
np.testing.assert_allclose(out, ref)
|
||||
runtime_us = runtime_ms * 1e3
|
||||
optimal_time = matmul_flops / peak_flops * 1e6 # us
|
||||
achieved_tc_util = optimal_time / runtime_us * 100
|
||||
if achieved_tc_util > best_util:
|
||||
best_runtime = runtime_us
|
||||
best_util = achieved_tc_util
|
||||
print(
|
||||
f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} {tiled_epilogue=} {grid_minor_dim=} {grid_tile_width=} {wg_dimension=} {cluster_dimension=}:"
|
||||
f" {runtime_us:<7.1f}us = {achieved_tc_util:4.1f}% TC utilization"
|
||||
)
|
||||
print(f"\tBest: {best_runtime:<7.1f}us = {best_util:4.1f}% TC utilization")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from absl import app
|
||||
|
||||
jax.config.config_with_absl()
|
||||
app.run(main)
|
||||
+343
@@ -0,0 +1,343 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Matrix Multiplication kernel for Hopper GPUs."""
|
||||
import statistics
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import itertools
|
||||
import jax
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax._src import test_util as jtu # noqa: F401
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
import jax.experimental.pallas as pl
|
||||
import jax.experimental.pallas.mosaic_gpu as plgpu
|
||||
from jax.extend import backend
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MatmulDimension(enum.IntEnum):
|
||||
M = 0
|
||||
N = 1
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TuningConfig:
|
||||
tile_m: int
|
||||
tile_n: int
|
||||
tile_k: int
|
||||
max_concurrent_steps: int
|
||||
epi_tile_n: int | None = 64  # This needs to be lowered for small N.
|
||||
epi_tile_m: int | None = 64
|
||||
grid_minor_dim: MatmulDimension = MatmulDimension.N
|
||||
grid_tile_width: int = 1
|
||||
wg_dimension: MatmulDimension = MatmulDimension.N
|
||||
cluster_dimension: None | MatmulDimension = None
|
||||
|
||||
|
||||
def mixed_matmul_kernel(
|
||||
a: jax.Array, b: jax.Array, *, out_dtype: jnp.dtype, config: TuningConfig
|
||||
) -> jax.Array:
|
||||
"""Mixed-type matrix multiplication kernel for Hopper GPUs.
|
||||
|
||||
Specifically, this kernel implements the function
|
||||
(a.astype(b.dtype) @ b).astype(out_dtype).
|
||||
"""
|
||||
if a.dtype == b.dtype:
|
||||
raise ValueError(
|
||||
f"Mixed matmul LHS and RHS have the same dtype {a.dtype}. For such "
|
||||
"matrix multiplications, use the `hopper_matmul_mgpu` kernel instead."
|
||||
)
|
||||
match (a.dtype, b.dtype):
|
||||
case (jnp.int8, jnp.bfloat16):
|
||||
pass
|
||||
case (jnp.int8, jnp.float16):
|
||||
pass
|
||||
case _, _:
|
||||
# We do support more combinations, but we haven't benchmarked them
|
||||
# yet---so we raise for the time being.
|
||||
raise NotImplementedError(
|
||||
f"Unbenchmarked dtype combination: {a.dtype=} and {b.dtype=}"
|
||||
)
|
||||
m, k = a.shape
|
||||
k2, n = b.shape
|
||||
if k != k2:
|
||||
raise ValueError(
|
||||
f"Matmul LHS and RHS have incompatible shapes {a.shape} vs {b.shape}"
|
||||
)
|
||||
|
||||
tile_m, tile_n, tile_k = config.tile_m, config.tile_n, config.tile_k
|
||||
epi_tile_n = config.epi_tile_n or tile_n
|
||||
epi_tile_m = config.epi_tile_m or tile_m
|
||||
if tile_n % epi_tile_n != 0:
|
||||
raise ValueError(f"{tile_n=} must be divisible by {epi_tile_n=}")
|
||||
if tile_m % epi_tile_m != 0:
|
||||
raise ValueError(f"{tile_m=} must be divisible by {epi_tile_m=}")
|
||||
|
||||
a_bits = dtypes.itemsize_bits(a.dtype)
|
||||
b_bits = dtypes.itemsize_bits(b.dtype)
|
||||
out_bits = dtypes.itemsize_bits(out_dtype)
|
||||
|
||||
a_swizzle = plgpu.find_swizzle(tile_k * a_bits, "lhs")
|
||||
b_swizzle = plgpu.find_swizzle(tile_n * b_bits, "rhs")
|
||||
out_swizzle = plgpu.find_swizzle(epi_tile_n * out_bits, "out")
|
||||
|
||||
a_transforms = (
|
||||
plgpu.TilingTransform((8, a_swizzle * 8 // a_bits)),
|
||||
plgpu.SwizzleTransform(a_swizzle),
|
||||
)
|
||||
b_transforms = (
|
||||
plgpu.TilingTransform((8, b_swizzle * 8 // b_bits)),
|
||||
plgpu.SwizzleTransform(b_swizzle),
|
||||
)
|
||||
out_transforms = (
|
||||
plgpu.TilingTransform((8, out_swizzle * 8 // out_bits)),
|
||||
plgpu.SwizzleTransform(out_swizzle),
|
||||
)
|
||||
|
||||
max_concurrent_steps = config.max_concurrent_steps
|
||||
cta_tile_m = tile_m * (1 + (config.wg_dimension == MatmulDimension.M))
|
||||
cta_tile_n = tile_n * (1 + (config.wg_dimension == MatmulDimension.N))
|
||||
cluster_tile_m = cta_tile_m * (1 + (config.cluster_dimension == MatmulDimension.M))
|
||||
cluster_tile_n = cta_tile_n * (1 + (config.cluster_dimension == MatmulDimension.N))
|
||||
if m % cluster_tile_m != 0:
|
||||
raise ValueError(f"{m=} must be divisible by {cluster_tile_m} for the given config")
|
||||
if n % cluster_tile_n != 0:
|
||||
raise ValueError(f"{n=} must be divisible by {cluster_tile_n} for the given config")
|
||||
if k % tile_k != 0:
|
||||
raise ValueError(f"{k=} must be divisible by {tile_k=}")
|
||||
m_iters = m // cluster_tile_m
|
||||
n_iters = n // cluster_tile_n
|
||||
k_iters = k // tile_k
|
||||
|
||||
def kernel(a_gmem, b_gmem, out_gmem, out_smem):
|
||||
|
||||
def get_pipeline(pipeline_body, compute_context):
|
||||
return plgpu.emit_pipeline_warp_specialized(
|
||||
pipeline_body,
|
||||
grid=(k_iters,),
|
||||
memory_registers=40,
|
||||
in_specs=[
|
||||
plgpu.BlockSpec(
|
||||
(cta_tile_m, tile_k),
|
||||
lambda k: (0, k),
|
||||
transforms=a_transforms,
|
||||
memory_space=plgpu.SMEM,
|
||||
collective_axes=("cluster",)
|
||||
if config.cluster_dimension == MatmulDimension.N
|
||||
else (),
|
||||
),
|
||||
plgpu.BlockSpec(
|
||||
(tile_k, cta_tile_n),
|
||||
lambda k: (k, 0),
|
||||
transforms=b_transforms,
|
||||
memory_space=plgpu.SMEM,
|
||||
collective_axes=("cluster",)
|
||||
if config.cluster_dimension == MatmulDimension.M
|
||||
else (),
|
||||
),
|
||||
],
|
||||
wg_axis="wg",
|
||||
num_compute_wgs=2,
|
||||
max_concurrent_steps=max_concurrent_steps,
|
||||
compute_context=compute_context,
|
||||
)
|
||||
|
||||
# Functions don't influence the allocations necessary to run the pipeline.
|
||||
ignore = lambda *_, **__: None
|
||||
@functools.partial(
|
||||
pl.run_scoped,
|
||||
pipeline_allocs=get_pipeline(ignore, ignore).get_allocations(a_gmem, b_gmem),
|
||||
collective_axes="wg",
|
||||
)
|
||||
def _pipeline_scope(pipeline_allocs):
|
||||
wg_idx = lax.axis_index("wg")
|
||||
cta_idx = lax.axis_index("cluster")
|
||||
@plgpu.nd_loop((m_iters * n_iters,), collective_axes="cluster_grid")
|
||||
def _mn_loop(loop_info: plgpu.NDLoopInfo):
|
||||
(lin_idx,) = loop_info.index
|
||||
m_cluster_idx, n_cluster_idx = plgpu.planar_snake(
|
||||
lin_idx,
|
||||
(m_iters, n_iters),
|
||||
config.grid_minor_dim,
|
||||
config.grid_tile_width,
|
||||
)
|
||||
m_idx = m_cluster_idx
|
||||
n_idx = n_cluster_idx
|
||||
if config.cluster_dimension == MatmulDimension.M:
|
||||
m_idx = m_cluster_idx * 2 + cta_idx
|
||||
elif config.cluster_dimension == MatmulDimension.N:
|
||||
n_idx = n_cluster_idx * 2 + cta_idx
|
||||
cta_m_slice = pl.ds(m_idx * cta_tile_m, cta_tile_m)
|
||||
cta_n_slice = pl.ds(n_idx * cta_tile_n, cta_tile_n)
|
||||
if config.wg_dimension == MatmulDimension.M:
|
||||
wg_m_slice = pl.ds(wg_idx * tile_m, tile_m)
|
||||
wg_n_slice = slice(None)
|
||||
else:
|
||||
wg_m_slice = slice(None)
|
||||
wg_n_slice = pl.ds(wg_idx * tile_n, tile_n)
|
||||
|
||||
def compute_context(eval_pipeline):
|
||||
@functools.partial(
|
||||
pl.run_scoped, acc_ref=plgpu.ACC((tile_m, tile_n), jnp.float32)
|
||||
)
|
||||
def _acc_scope(acc_ref):
|
||||
eval_pipeline(acc_ref)
|
||||
acc = acc_ref[...].astype(out_dtype)
|
||||
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
|
||||
for epi_mi in range(tile_m // epi_tile_m):
|
||||
for epi_ni in range(tile_n // epi_tile_n):
|
||||
epi_m_slice = slice(epi_mi * epi_tile_m, (epi_mi + 1) * epi_tile_m)
|
||||
epi_n_slice = slice(epi_ni * epi_tile_n, (epi_ni + 1) * epi_tile_n)
|
||||
slot = (epi_mi * (tile_n // epi_tile_n) + epi_ni) % 2
|
||||
plgpu.wait_smem_to_gmem(1, wait_read_only=True)
|
||||
out_smem[wg_idx, slot] = acc[epi_m_slice, epi_n_slice]
|
||||
plgpu.commit_smem()
|
||||
plgpu.copy_smem_to_gmem(
|
||||
out_smem.at[wg_idx, slot],
|
||||
out_gmem.at[cta_m_slice, cta_n_slice]
|
||||
.at[wg_m_slice, wg_n_slice]
|
||||
.at[epi_m_slice, epi_n_slice],
|
||||
)
|
||||
|
||||
def mma_body(_, a_smem, b_smem, acc_ref):
|
||||
with jax.named_scope("smem_load"):
|
||||
a_reg = a_smem[wg_m_slice]
|
||||
with jax.named_scope("dequant"):
|
||||
a_reg = a_reg.astype(b.dtype)
|
||||
with jax.named_scope("wgmma"):
|
||||
plgpu.wgmma(acc_ref, a_reg, b_smem.at[:, wg_n_slice])
|
||||
with jax.named_scope("wgmma_wait"):
|
||||
plgpu.wgmma_wait(0)
|
||||
return acc_ref
|
||||
|
||||
get_pipeline(mma_body, compute_context)(
|
||||
a_gmem.at[cta_m_slice, :],
|
||||
b_gmem.at[:, cta_n_slice],
|
||||
allocations=pipeline_allocs,
|
||||
)
|
||||
# Await all transfers before we exit.
|
||||
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
|
||||
|
||||
# We don't need multiple slots if there's only one epilogue tile.
|
||||
num_out_slots = min(2, (tile_m * tile_n) // (epi_tile_m * epi_tile_n))
|
||||
num_sms = backend.get_default_device().core_count
|
||||
cluster_size = 1 + (config.cluster_dimension is not None)
|
||||
f = plgpu.kernel(
|
||||
kernel,
|
||||
out_shape=jax.ShapeDtypeStruct((m, n), out_dtype),
|
||||
grid=(num_sms // cluster_size,),
|
||||
grid_names=("cluster_grid",),
|
||||
cluster=(cluster_size,),
|
||||
cluster_names=("cluster",),
|
||||
num_threads=3,
|
||||
thread_name="wg",
|
||||
scratch_shapes=dict(
|
||||
out_smem=plgpu.SMEM(
|
||||
(2, num_out_slots, epi_tile_m, epi_tile_n),
|
||||
out_dtype,
|
||||
transforms=out_transforms,
|
||||
)
|
||||
),
|
||||
)
|
||||
return f(a, b)
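# Illustrative usage sketch (shapes and config values below are assumptions);
# the int8 LHS is dequantized to the RHS dtype on the fly inside the pipeline.
#   a = jnp.zeros((4096, 4096), jnp.int8)
#   b = jnp.zeros((4096, 8192), jnp.bfloat16)
#   cfg = TuningConfig(tile_m=128, tile_n=128, tile_k=64, max_concurrent_steps=4)
#   out = mixed_matmul_kernel(a, b, out_dtype=jnp.bfloat16, config=cfg)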
|
||||
|
||||
|
||||
def reference(
|
||||
a: jax.Array, b: jax.Array, *, out_dtype: jnp.dtype
|
||||
) -> jax.Array:
|
||||
"""Reference implementation of a mixed-type matrix multiplication."""
|
||||
return jax.numpy.dot(a, b, preferred_element_type=jnp.float32).astype(
|
||||
out_dtype
|
||||
)
|
||||
|
||||
|
||||
def main(_) -> None:
|
||||
problem_it = [(4096, 8192, 4096)]
|
||||
for M, N, K in problem_it:
|
||||
print(f"==== {M=} {N=} {K=} ====")
|
||||
matmul_flops = 2 * M * N * K
|
||||
peak_flops = 990e12 # f16 TensorCore peak = 990 TFLOPS
|
||||
a = jax.random.randint(
|
||||
jax.random.key(0), minval=-128, maxval=127, shape=(M, K), dtype=jnp.int8
|
||||
)
|
||||
b = jax.random.uniform(jax.random.key(1), (K, N), jnp.bfloat16)
|
||||
ref = reference(a, b, out_dtype=jnp.bfloat16)
|
||||
tuning_it = itertools.product(
|
||||
(64, 128, 256,), # tile_m
|
||||
(64, 128), # tile_n
|
||||
(64, 128), # tile_k
|
||||
(4,), # max_concurrent_steps
|
||||
(True,), # Tiled epilogue
|
||||
(MatmulDimension.M, MatmulDimension.N), # grid_minor_dim
|
||||
(4, 8, 16), # grid_tile_width
|
||||
MatmulDimension, # wg_dimension
|
||||
# Consider adding MatmulDimension here to try out collective TMA kernels
|
||||
(None,) # cluster_dimension
|
||||
)
|
||||
best_util = 0.0
|
||||
best_runtime = float("inf")
|
||||
for tile_m, tile_n, tile_k, max_concurrent_steps, tiled_epilogue, grid_minor_dim, grid_tile_width, wg_dimension, cluster_dimension in tuning_it:
|
||||
config = TuningConfig(
|
||||
tile_m=tile_m,
|
||||
tile_n=tile_n,
|
||||
tile_k=tile_k,
|
||||
max_concurrent_steps=max_concurrent_steps,
|
||||
epi_tile_n=64 if tiled_epilogue else None,
|
||||
epi_tile_m=64 if tiled_epilogue else None,
|
||||
grid_minor_dim=grid_minor_dim,
|
||||
grid_tile_width=grid_tile_width,
|
||||
wg_dimension=wg_dimension,
|
||||
cluster_dimension=cluster_dimension,
|
||||
)
|
||||
try:
|
||||
out, runtimes_ms = profiler.measure(
|
||||
functools.partial(
|
||||
mixed_matmul_kernel, out_dtype=jnp.bfloat16, config=config
|
||||
),
|
||||
iterations=10,
|
||||
)(a, b)
|
||||
assert runtimes_ms is not None
|
||||
runtime_ms = statistics.median(runtimes_ms)
|
||||
except ValueError as e:
|
||||
if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs.
|
||||
continue
|
||||
raise
|
||||
np.testing.assert_allclose(out, ref)
|
||||
runtime_us = runtime_ms * 1e3
|
||||
optimal_time = matmul_flops / peak_flops * 1e6 # us
|
||||
achieved_tc_util = optimal_time / runtime_us * 100
|
||||
if achieved_tc_util > best_util:
|
||||
best_runtime = runtime_us
|
||||
best_util = achieved_tc_util
|
||||
print(
|
||||
f"{tile_m=} {tile_n=} {tile_k=} {max_concurrent_steps=} {tiled_epilogue=} {grid_minor_dim=} {grid_tile_width=} {wg_dimension=} {cluster_dimension=}:"
|
||||
f" {runtime_us:<7.1f}us = {achieved_tc_util:4.1f}% TC utilization"
|
||||
)
|
||||
print(f"\tBest: {best_runtime:<7.1f}us = {best_util:4.1f}% TC utilization")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from absl import app
|
||||
|
||||
jax.config.config_with_absl()
|
||||
app.run(main)
|
||||
@@ -0,0 +1,345 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Module containing fused layer norm forward and backward pass."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
|
||||
def layer_norm_forward_kernel(
|
||||
x_ref, weight_ref, bias_ref, # Input arrays
|
||||
o_ref, mean_ref=None, rstd_ref=None, # Output arrays
|
||||
*, eps: float, block_size: int):
|
||||
n_col = x_ref.shape[0]
|
||||
|
||||
def mean_body(i, acc):
|
||||
col_idx = i * block_size + jnp.arange(block_size)
|
||||
mask = col_idx < n_col
|
||||
a = plgpu.load(
|
||||
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
return acc + a
|
||||
|
||||
mean = lax.fori_loop(
|
||||
0,
|
||||
pl.cdiv(n_col, block_size),
|
||||
mean_body,
|
||||
init_val=jnp.zeros(block_size),
|
||||
).sum()
|
||||
mean /= n_col
|
||||
|
||||
def var_body(i, acc):
|
||||
col_idx = i * block_size + jnp.arange(block_size)
|
||||
mask = col_idx < n_col
|
||||
a = plgpu.load(
|
||||
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
a = jnp.where(mask, a - mean, 0.)
|
||||
return acc + a * a
|
||||
|
||||
var = lax.fori_loop(
|
||||
0,
|
||||
pl.cdiv(n_col, block_size),
|
||||
var_body,
|
||||
init_val=jnp.zeros(block_size),
|
||||
).sum()
|
||||
var /= n_col
|
||||
rstd = 1 / jnp.sqrt(var + eps)
|
||||
if mean_ref is not None:
|
||||
mean_ref[...] = mean.astype(mean_ref.dtype)
|
||||
if rstd_ref is not None:
|
||||
rstd_ref[...] = rstd.astype(rstd_ref.dtype)
|
||||
|
||||
@pl.loop(0, pl.cdiv(n_col, block_size))
|
||||
def body(i):
|
||||
col_idx = i * block_size + jnp.arange(block_size)
|
||||
mask = col_idx < n_col
|
||||
weight = plgpu.load(weight_ref.at[col_idx], mask=mask)
|
||||
bias = plgpu.load(bias_ref.at[col_idx], mask=mask)
|
||||
x = plgpu.load(
|
||||
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_first"
|
||||
).astype(jnp.float32)
|
||||
out = (x - mean) * rstd * weight + bias
|
||||
plgpu.store(o_ref.at[col_idx], out.astype(o_ref.dtype), mask=mask)
|
||||
|
||||
|
||||
def layer_norm_forward(
|
||||
x, weight, bias,
|
||||
num_warps: int | None = None,
|
||||
num_stages: int | None = 3,
|
||||
eps: float = 1e-5,
|
||||
backward_pass_impl: str = 'triton',
|
||||
interpret: bool = False):
|
||||
del num_stages
|
||||
del backward_pass_impl
|
||||
n = x.shape[-1]
|
||||
# Triton heuristics
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
max_fused_size = 65536 // x.dtype.itemsize
|
||||
block_size = min(max_fused_size, pl.next_power_of_2(n))
|
||||
block_size = min(max(block_size, 128), 4096)
|
||||
num_warps = min(max(block_size // 256, 1), 8)
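# E.g. (assuming float32 inputs and n = 8192): max_fused_size = 65536 // 4 =
# 16384, block_size = min(16384, 8192) = 8192 and is then clamped to 4096, so
# num_warps = min(max(4096 // 256, 1), 8) = 8.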
|
||||
|
||||
kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
|
||||
block_size=block_size)
|
||||
out_shape = [
|
||||
jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
|
||||
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype),
|
||||
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
|
||||
]
|
||||
method = pl.pallas_call(
|
||||
kernel,
|
||||
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
|
||||
grid=(),
|
||||
out_shape=out_shape,
|
||||
debug=False,
|
||||
interpret=interpret,
|
||||
name="ln_forward",
|
||||
)
|
||||
|
||||
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
|
||||
out, mean, rstd = method(x, weight, bias)
|
||||
return out, (x, weight, bias, mean, rstd)
|
||||
|
||||
|
||||
def layer_norm_backward_kernel_dx(
|
||||
# Inputs
|
||||
x_ref, weight_ref, bias_ref, do_ref,
|
||||
mean_ref, rstd_ref,
|
||||
# Outputs
|
||||
dx_ref,
|
||||
*, eps: float, block_size: int):
|
||||
n_col = x_ref.shape[0]
|
||||
|
||||
def mean_body(i, acc):
|
||||
col_idx = i * block_size + jnp.arange(block_size)
|
||||
mask = col_idx < n_col
|
||||
a = plgpu.load(
|
||||
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
dout = plgpu.load(
|
||||
do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
weight = plgpu.load(
|
||||
weight_ref.at[col_idx],
|
||||
mask=mask,
|
||||
other=0.0,
|
||||
eviction_policy="evict_last",
|
||||
).astype(jnp.float32)
|
||||
a_hat = (a - mean_ref[...]) * rstd_ref[...]
|
||||
wdout = weight * dout
|
||||
mean1_acc, mean2_acc = acc
|
||||
return mean1_acc + a_hat * wdout, mean2_acc + wdout
|
||||
mean1, mean2 = lax.fori_loop(
|
||||
0,
|
||||
pl.cdiv(n_col, block_size),
|
||||
mean_body,
|
||||
init_val=(jnp.zeros(block_size), jnp.zeros(block_size)),
|
||||
)
|
||||
mean1 = mean1.sum() / n_col
|
||||
mean2 = mean2.sum() / n_col
|
||||
|
||||
@pl.loop(0, pl.cdiv(n_col, block_size))
|
||||
def dx_body(i):
|
||||
col_idx = i * block_size + jnp.arange(block_size)
|
||||
mask = col_idx < n_col
|
||||
a = plgpu.load(
|
||||
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
dout = plgpu.load(
|
||||
do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
weight = plgpu.load(
|
||||
weight_ref.at[col_idx],
|
||||
mask=mask,
|
||||
other=0.0,
|
||||
eviction_policy="evict_last",
|
||||
).astype(jnp.float32)
|
||||
a_hat = (a - mean_ref[...]) * rstd_ref[...]
|
||||
wdout = weight * dout
|
||||
da = (wdout - (a_hat * mean1 + mean2)) * rstd_ref[...]
|
||||
plgpu.store(dx_ref.at[col_idx], da.astype(dx_ref.dtype), mask=mask)
|
||||
|
||||
|
||||
def layer_norm_backward_kernel_dw_db(
|
||||
# Inputs
|
||||
x_ref, weight_ref, bias_ref, do_ref,
|
||||
mean_ref, rstd_ref,
|
||||
# Outputs
|
||||
dw_ref, db_ref,
|
||||
*, eps: float, block_m: int, block_n: int):
|
||||
m, n_col = x_ref.shape
|
||||
j = pl.program_id(0)
|
||||
col_idx = j * block_n + jnp.arange(block_n)
|
||||
col_mask = col_idx < n_col
|
||||
|
||||
def body(i, acc):
|
||||
row_idx = i * block_m + jnp.arange(block_m)
|
||||
row_mask = row_idx < m
|
||||
mask = row_mask[:, None] & col_mask[None, :]
|
||||
a = plgpu.load(
|
||||
x_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
|
||||
).astype(jnp.float32)
|
||||
dout = plgpu.load(
|
||||
do_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
|
||||
).astype(jnp.float32)
|
||||
mean = plgpu.load(mean_ref.at[row_idx], mask=row_mask, other=0.0).astype(
|
||||
jnp.float32
|
||||
)
|
||||
rstd = plgpu.load(rstd_ref.at[row_idx], mask=row_mask, other=0.0).astype(
|
||||
jnp.float32
|
||||
)
|
||||
a_hat = (a - mean[:, None]) * rstd[:, None]
|
||||
dw_acc_ref, db_acc_ref = acc
|
||||
return dw_acc_ref + (dout * a_hat).sum(axis=0), db_acc_ref + dout.sum(
|
||||
axis=0
|
||||
)
|
||||
|
||||
dw_acc, db_acc = lax.fori_loop(
|
||||
0,
|
||||
pl.cdiv(m, block_m),
|
||||
body,
|
||||
init_val=(jnp.zeros(block_n), jnp.zeros(block_n)),
|
||||
)
|
||||
plgpu.store(dw_ref.at[col_idx], dw_acc.astype(dw_ref.dtype), mask=col_mask)
|
||||
plgpu.store(db_ref.at[col_idx], db_acc.astype(db_ref.dtype), mask=col_mask)
|
||||
|
||||
|
||||
def layer_norm_backward(
|
||||
num_warps: int | None,
|
||||
num_stages: int | None,
|
||||
eps: float,
|
||||
backward_pass_impl: str,
|
||||
interpret: bool,
|
||||
res, do):
|
||||
del num_stages
|
||||
x, weight, bias, mean, rstd = res
|
||||
if backward_pass_impl == 'xla':
|
||||
return jax.vjp(layer_norm_reference, x, weight, bias)[1](do)
|
||||
|
||||
*shape_prefix, n = x.shape
|
||||
reshaped_x = x.reshape((-1, n))
|
||||
reshaped_mean = mean.reshape((-1,))
|
||||
reshaped_rstd = rstd.reshape((-1,))
|
||||
reshaped_do = do.reshape((-1, n))
|
||||
# Triton heuristics
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
max_fused_size = 65536 // x.dtype.itemsize
|
||||
block_size = min(max_fused_size, pl.next_power_of_2(n))
|
||||
block_size = min(max(block_size, 128), 4096)
|
||||
num_warps = min(max(block_size // 256, 1), 8)
|
||||
|
||||
# layer_norm_backward_kernel_dx parallel over batch dims
|
||||
kernel = functools.partial(layer_norm_backward_kernel_dx, eps=eps,
|
||||
block_size=block_size)
|
||||
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
|
||||
method = pl.pallas_call(
|
||||
kernel,
|
||||
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
|
||||
grid=(),
|
||||
out_shape=out_shape_dx,
|
||||
debug=False,
|
||||
interpret=interpret,
|
||||
name="ln_backward_dx",
|
||||
)
|
||||
|
||||
method = jax.vmap(method, in_axes=(0, None, None, 0, 0, 0))
|
||||
dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
|
||||
dx = dx.reshape((*shape_prefix, n))
|
||||
|
||||
# layer_norm_backward_kernel_dw_db reduce over batch dims
|
||||
# Triton heuristics
|
||||
if n > 10240:
|
||||
block_n = 128
|
||||
block_m = 32
|
||||
num_warps = 4
|
||||
else:
|
||||
# maximize occupancy for small N
|
||||
block_n = 16
|
||||
block_m = 16
|
||||
num_warps = 8
|
||||
kernel = functools.partial(layer_norm_backward_kernel_dw_db, eps=eps,
|
||||
block_m=block_m, block_n=block_n)
|
||||
out_shape_dwbias = [
|
||||
jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype),
|
||||
jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
|
||||
]
|
||||
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
|
||||
method = pl.pallas_call(
|
||||
kernel,
|
||||
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
|
||||
grid=grid_,
|
||||
out_shape=out_shape_dwbias,
|
||||
debug=False,
|
||||
interpret=interpret,
|
||||
name="ln_backward_dw_db",
|
||||
)
|
||||
dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_mean, reshaped_rstd)
|
||||
return dx, dw, dbias
|
||||
|
||||
|
||||
@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7])
|
||||
@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages",
|
||||
"num_stages", "eps",
|
||||
"backward_pass_impl",
|
||||
"interpret"])
|
||||
def layer_norm(
|
||||
x, weight, bias,
|
||||
num_warps: int | None = None,
|
||||
num_stages: int | None = 3,
|
||||
eps: float = 1e-5,
|
||||
backward_pass_impl: str = 'triton',
|
||||
interpret: bool = False):
|
||||
n = x.shape[-1]
|
||||
# Triton heuristics
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
max_fused_size = 65536 // x.dtype.itemsize
|
||||
block_size = min(max_fused_size, pl.next_power_of_2(n))
|
||||
block_size = min(max(block_size, 128), 4096)
|
||||
num_warps = min(max(block_size // 256, 1), 8)
|
||||
|
||||
kernel = functools.partial(layer_norm_forward_kernel, eps=eps,
|
||||
block_size=block_size)
|
||||
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
|
||||
method = pl.pallas_call(
|
||||
kernel,
|
||||
compiler_params=plgpu.CompilerParams(
|
||||
num_warps=num_warps, num_stages=num_stages),
|
||||
grid=(),
|
||||
out_shape=out_shape,
|
||||
debug=False,
|
||||
interpret=interpret,
|
||||
)
|
||||
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
|
||||
return method(x, weight, bias)
|
||||
layer_norm.defvjp(layer_norm_forward, layer_norm_backward)
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=["eps"])
|
||||
@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
|
||||
def layer_norm_reference(x, weight, bias, *, eps: float = 1e-5):
|
||||
mean = jnp.mean(x, axis=1)
|
||||
mean2 = jnp.mean(jnp.square(x), axis=1)
|
||||
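# Variance computed as E[x^2] - E[x]^2, clamped at zero to guard against small negative values caused by floating-point cancellation.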
var = jnp.maximum(0., mean2 - jnp.square(mean))
|
||||
y = x - mean[:, None]
|
||||
mul = lax.rsqrt(var + eps)
|
||||
return y * mul[:, None] * weight[None] + bias[None]
|
||||
@@ -0,0 +1,461 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Module containing decode attention."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
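# Large negative (but finite) constant used to mask out padded logits; keeping it finite avoids the NaNs that -inf would produce in the kernel's max/exp arithmetic.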
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
|
||||
|
||||
|
||||
def paged_attention_kernel(
|
||||
# inputs
|
||||
q_ref, # [block_h, head_dim]
|
||||
k_pages_ref, # [total_num_pages, page_size, head_dim]
|
||||
k_scales_pages_ref, # [total_num_pages, page_size]
|
||||
v_pages_ref, # [total_num_pages, page_size, head_dim]
|
||||
v_scales_pages_ref, # [total_num_pages, page_size]
|
||||
block_tables_ref, # [pages_per_partition]
|
||||
lengths_ref, # [1]
|
||||
# outputs
|
||||
o_ref: Any, # [block_h, head_dim]
|
||||
*residual_refs: Any, # Residual outputs: [block_h,], [block_h,]
|
||||
num_heads: int,
|
||||
pages_per_compute_block: int,
|
||||
mask_value: float,
|
||||
attn_logits_soft_cap: float | None,
|
||||
):
|
||||
partition_idx = pl.program_id(2)
|
||||
block_h, head_dim = q_ref.shape
|
||||
page_size = k_pages_ref.shape[-2]
|
||||
pages_per_partition = block_tables_ref.shape[0]
|
||||
block_k = pages_per_compute_block * page_size
|
||||
|
||||
def _compute(start_page_idx, end_page_idx, o, m_i, l_i):
|
||||
q_slice = pl.ds(0, block_h)
|
||||
q = q_ref[q_slice, :]
|
||||
|
||||
# Loop over blocks of pages to process an entire page sequence partition.
|
||||
# Grid loops over q blocks over num_heads.
|
||||
def body(start_k, carry):
|
||||
o_prev, m_prev, l_prev = carry
|
||||
|
||||
block_tables_slice = pl.ds(
|
||||
start_k * pages_per_compute_block, pages_per_compute_block
|
||||
)
|
||||
block_tables = block_tables_ref[block_tables_slice]
|
||||
k = k_pages_ref[block_tables].reshape(block_k, head_dim)
|
||||
v = v_pages_ref[block_tables].reshape(block_k, head_dim)
|
||||
if k_scales_pages_ref is not None:
|
||||
# dynamic lhs quantized dot is not currently implemented
|
||||
# so we cast rhs to the lhs dtype
|
||||
k = k.astype(q.dtype)
|
||||
uncapped_logits = pl.dot(q, k.T) # [block_h, block_k]
|
||||
if k_scales_pages_ref is not None:
|
||||
# k_scales_pages_ref are one per head
|
||||
# they're laid out across the output dimension, so scale output
|
||||
k_scale = k_scales_pages_ref[block_tables].reshape((1, block_k))
|
||||
uncapped_logits *= k_scale.astype(uncapped_logits.dtype)
|
||||
if attn_logits_soft_cap is not None:
|
||||
logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap)
|
||||
logits = logits * attn_logits_soft_cap
|
||||
else:
|
||||
logits = uncapped_logits
|
||||
|
||||
if lengths_ref is not None:
|
||||
curr_start_page_idx = (
|
||||
partition_idx * pages_per_partition
|
||||
+ start_k * pages_per_compute_block
|
||||
)
|
||||
curr_start_token_idx = curr_start_page_idx * page_size
|
||||
|
||||
mask = jnp.arange(block_k) + curr_start_token_idx < lengths_ref[0]
|
||||
mask = lax.broadcast_in_dim(mask, (block_h, block_k), (1,))
|
||||
logits = jnp.where(mask, logits, mask_value)
|
||||
|
||||
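# Online softmax update (flash-attention style): keep a running row max m and normalizer l, rescale the previous accumulator by exp(m_prev - m_next), then add the current block's contribution. exp2 is used because it lowers to the GPU's native exponential instruction.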
log2e = math.log2(math.e)
|
||||
m_curr = logits.max(axis=-1)
|
||||
m_next = jnp.maximum(m_prev, m_curr)
|
||||
correction = jnp.exp2((m_prev - m_next) * log2e)
|
||||
l_prev_corr = correction * l_prev
|
||||
s_curr = jnp.exp2((logits - m_next[:, None]) * log2e)
|
||||
l_curr = s_curr.sum(axis=-1)
|
||||
l_next = l_prev_corr + l_curr
|
||||
o_prev_corr = correction[:, None] * o_prev
|
||||
if v_scales_pages_ref is not None:
|
||||
# v_scales are 1 per head
|
||||
# they're laid out across the reduction dimension, so scale lhs
|
||||
v_scale = v_scales_pages_ref[block_tables].reshape((1, block_k))
|
||||
s_curr *= v_scale.astype(s_curr.dtype)
|
||||
# dynamic lhs quantized dot is not currently implemented
|
||||
# so we cast rhs to the lhs dtype
|
||||
v = v.astype(s_curr.dtype)
|
||||
o_curr = pl.dot(s_curr.astype(v.dtype), v)
|
||||
|
||||
o_next = o_prev_corr + o_curr
|
||||
return o_next, m_next, l_next
|
||||
|
||||
max_it = pl.cdiv(end_page_idx - start_page_idx, pages_per_compute_block)
|
||||
(o, m_i, l_i) = lax.fori_loop(0, max_it, body, (o, m_i, l_i))
|
||||
|
||||
return o, m_i, l_i
|
||||
|
||||
m_i = jnp.zeros(block_h, dtype=jnp.float32) + jnp.finfo(jnp.float32).min
|
||||
l_i = jnp.zeros(block_h, dtype=jnp.float32)
|
||||
o = jnp.zeros((block_h, head_dim), dtype=jnp.float32)
|
||||
|
||||
start_page_idx = partition_idx * pages_per_partition
|
||||
end_page_idx = start_page_idx + pages_per_partition
|
||||
|
||||
if lengths_ref is None:
|
||||
o, m_i, l_i = _compute(start_page_idx, end_page_idx, o, m_i, l_i)
|
||||
else:
|
||||
end_page_idx = jnp.minimum(pl.cdiv(lengths_ref[0], page_size), end_page_idx)
|
||||
|
||||
o, m_i, l_i = jax.lax.cond(
|
||||
start_page_idx >= end_page_idx,
|
||||
lambda: (o, m_i, l_i),
|
||||
lambda: _compute(start_page_idx, end_page_idx, o, m_i, l_i),
|
||||
)
|
||||
|
||||
o_ref[...] = o.astype(o_ref.dtype)
|
||||
|
||||
if residual_refs is not None:
|
||||
l_ref, m_ref = residual_refs
|
||||
l_ref[...] = l_i
|
||||
m_ref[...] = m_i
|
||||
|
||||
|
||||
def paged_attention_unbatched(
|
||||
q: jax.Array, # [num_q_heads, head_dim]
|
||||
k_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim]
|
||||
v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim]
|
||||
block_tables: jax.Array, # [pages_per_sequence]
|
||||
lengths: jax.Array | None, # [1]
|
||||
k_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size]
|
||||
v_scales_pages: jax.Array | None = None, # [num_kv_heads, total_num_pages, page_size]
|
||||
*,
|
||||
block_h: int,
|
||||
pages_per_compute_block: int,
|
||||
k_splits: int,
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
interpret: bool,
|
||||
debug: bool,
|
||||
mask_value: float,
|
||||
attn_logits_soft_cap: float | None,
|
||||
) -> jax.Array:
|
||||
num_q_heads, head_dim = q.shape
|
||||
num_kv_heads, total_num_pages, page_size, _ = k_pages.shape
|
||||
pages_per_sequence = block_tables.shape[0]
|
||||
|
||||
assert (
|
||||
pages_per_sequence % k_splits == 0
|
||||
), f"{pages_per_sequence=} must be divisible by {k_splits=}."
|
||||
|
||||
pages_per_partition = pages_per_sequence // k_splits
|
||||
pages_per_compute_block = min(pages_per_partition, pages_per_compute_block)
|
||||
|
||||
assert (
|
||||
pages_per_partition % pages_per_compute_block == 0
|
||||
), f"{pages_per_partition=} must be divisible by {pages_per_compute_block=}."
|
||||
|
||||
block_tables = block_tables.reshape(k_splits, pages_per_sequence // k_splits)
|
||||
|
||||
q_heads_per_kv_head = num_q_heads // num_kv_heads
|
||||
q_reshaped = q.reshape(num_kv_heads, q_heads_per_kv_head, head_dim)
|
||||
|
||||
if q_heads_per_kv_head % block_h:
|
||||
q_reshaped = jnp.pad(
|
||||
q_reshaped, ((0, 0), (0, -q_heads_per_kv_head % block_h), (0, 0))
|
||||
)
|
||||
|
||||
head_splits = pl.cdiv(q_heads_per_kv_head, block_h)
|
||||
grid = (num_kv_heads, head_splits, k_splits)
|
||||
kernel = functools.partial(
|
||||
paged_attention_kernel,
|
||||
num_heads=q_heads_per_kv_head,
|
||||
pages_per_compute_block=pages_per_compute_block,
|
||||
mask_value=mask_value,
|
||||
attn_logits_soft_cap=attn_logits_soft_cap,
|
||||
)
|
||||
# set up quantization scales
|
||||
if k_scales_pages is not None:
|
||||
assert k_scales_pages.shape == (num_kv_heads, total_num_pages, page_size)
|
||||
k_scales_spec = pl.BlockSpec((None, total_num_pages, page_size),
|
||||
lambda h, i, k: (h, 0, 0))
|
||||
else:
|
||||
k_scales_spec = None
|
||||
if v_scales_pages is not None:
|
||||
assert v_scales_pages.shape == (num_kv_heads, total_num_pages, page_size)
|
||||
v_scales_spec = pl.BlockSpec((None, total_num_pages, page_size),
|
||||
lambda h, i, k: (h, 0, 0))
|
||||
else:
|
||||
v_scales_spec = None
|
||||
|
||||
o, l, m = pl.pallas_call(
|
||||
kernel,
|
||||
grid=grid,
|
||||
in_specs=[
|
||||
pl.BlockSpec(
|
||||
(None, block_h, head_dim), lambda h, i, k: (h, i, 0)
|
||||
), # q
|
||||
pl.BlockSpec(
|
||||
(None, total_num_pages, page_size, head_dim),
|
||||
lambda h, i, k: (h, 0, 0, 0),
|
||||
), # k_pages
|
||||
k_scales_spec, # k_pages_scale
|
||||
pl.BlockSpec(
|
||||
(None, total_num_pages, page_size, head_dim),
|
||||
lambda h, i, k: (h, 0, 0, 0),
|
||||
), # v_pages
|
||||
v_scales_spec, # v_pages_scale
|
||||
pl.BlockSpec(
|
||||
(None, pages_per_partition), lambda h, i, k: (k, 0)
|
||||
), # block_tables
|
||||
]
|
||||
+ [
|
||||
None if lengths is None else pl.BlockSpec((1,), lambda h, i, k: (0,))
|
||||
], # lengths
|
||||
out_specs=[
|
||||
pl.BlockSpec(
|
||||
(None, None, block_h, head_dim), lambda h, i, k: (k, h, i, 0)
|
||||
), # q
|
||||
pl.BlockSpec((None, None, block_h), lambda h, i, k: (k, h, i)), # l
|
||||
pl.BlockSpec((None, None, block_h), lambda h, i, k: (k, h, i)), # m
|
||||
],
|
||||
out_shape=[
|
||||
jax.ShapeDtypeStruct(
|
||||
(k_splits, *q_reshaped.shape), dtype=q.dtype
|
||||
), # o
|
||||
jax.ShapeDtypeStruct(
|
||||
(k_splits, *q_reshaped.shape[:-1]), dtype=jnp.float32
|
||||
), # l
|
||||
jax.ShapeDtypeStruct(
|
||||
(k_splits, *q_reshaped.shape[:-1]), dtype=jnp.float32
|
||||
), # m
|
||||
],
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
compiler_params=plgpu.CompilerParams(
|
||||
num_warps=num_warps, num_stages=num_stages
|
||||
),
|
||||
name=f"paged_attention_{block_h=}_{pages_per_compute_block=}",
|
||||
)(q_reshaped, k_pages, k_scales_pages, v_pages, v_scales_pages, block_tables, lengths)
|
||||
|
||||
if q_heads_per_kv_head % block_h:
|
||||
o = o[..., :q_heads_per_kv_head, :]
|
||||
l = l[..., :q_heads_per_kv_head]
|
||||
m = m[..., :q_heads_per_kv_head]
|
||||
|
||||
# Final flash-attention combination across the k_splits partitions: rescale each partial output by exp(m - max(m)) and normalize by the combined softmax denominator.
|
||||
m_next = m.max(axis=0)
|
||||
correction = jnp.exp(m - m_next[None])
|
||||
o = o * correction[..., None].astype(o.dtype)
|
||||
l_next = (l * correction).sum(axis=0)
|
||||
eps = jnp.finfo(l_next.dtype).eps
|
||||
o = o.sum(axis=0) / ((l_next[..., None] + eps).astype(o.dtype))
|
||||
|
||||
o = o.reshape(q.shape).astype(q.dtype)
|
||||
return o
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=[
|
||||
"block_h",
|
||||
"pages_per_compute_block",
|
||||
"k_splits",
|
||||
"num_warps",
|
||||
"num_stages",
|
||||
"interpret",
|
||||
"debug",
|
||||
"mask_value",
|
||||
"attn_logits_soft_cap",
|
||||
],
|
||||
)
|
||||
def paged_attention(
|
||||
q: jax.Array,
|
||||
k_pages: jax.Array,
|
||||
v_pages: jax.Array,
|
||||
block_tables: jax.Array,
|
||||
lengths: jax.Array | None,
|
||||
k_scales_pages: jax.Array | None = None,
|
||||
v_scales_pages: jax.Array | None = None,
|
||||
*,
|
||||
block_h: int = 16,
|
||||
pages_per_compute_block: int = 8,
|
||||
k_splits: int = 16,
|
||||
num_warps: int = 8,
|
||||
num_stages: int = 2,
|
||||
interpret: bool = False,
|
||||
debug: bool = False,
|
||||
mask_value: float = DEFAULT_MASK_VALUE,
|
||||
attn_logits_soft_cap: float | None = None,
|
||||
) -> jax.Array:
|
||||
"""Paged grouped query attention.
|
||||
|
||||
Args:
|
||||
q: A [batch_size, num_heads, head_dim] jax.Array.
|
||||
k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
|
||||
v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
|
||||
block_tables: A i32[batch_size, pages_per_sequence] jax.Array. Each entry
|
||||
should be in the range of [0, total_num_pages), indicating where to locate
|
||||
the page in `k_pages` or `v_pages`.
|
||||
lengths: An i32[batch_size] jax.Array with the length of each example.
|
||||
k_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array.
|
||||
v_scales_pages: A [num_kv_heads, total_num_pages, page_size] jax.Array.
|
||||
block_h: int The block size that partitions the number of head groups.
|
||||
pages_per_compute_block: int The maximum number of pages processed per compute block.
|
||||
k_splits: int Number of partitions used to parallelize key-value sequence
|
||||
pages processing.
|
||||
mask_value: The value used for padding in attention. By default it is a very
|
||||
negative floating point number.
|
||||
attn_logits_soft_cap: The value used for soft capping the attention logits.
|
||||
|
||||
Returns:
|
||||
The output of attention([batch_size, num_heads, head_dim]).
|
||||
"""
|
||||
batch_size, num_heads, head_dim = q.shape
|
||||
num_kv_heads, _, _, head_dim_k = k_pages.shape
|
||||
batch_size_paged_indices, _ = block_tables.shape
|
||||
|
||||
if k_pages.shape != v_pages.shape:
|
||||
raise ValueError(
|
||||
f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and"
|
||||
f" {v_pages.shape}"
|
||||
)
|
||||
if num_heads % num_kv_heads != 0:
|
||||
raise ValueError(
|
||||
"Number of Q heads must be divisible by number of KV heads. Got"
|
||||
f" {num_heads} and {num_kv_heads}."
|
||||
)
|
||||
if head_dim_k != head_dim:
|
||||
raise ValueError(
|
||||
"head_dim of Q must be the same as that of K/V. Got"
|
||||
f" {head_dim} and {head_dim_k}."
|
||||
)
|
||||
if batch_size_paged_indices != batch_size:
|
||||
raise ValueError("`block_tables` and `q` must have the same batch size")
|
||||
if lengths is not None:
|
||||
if lengths.shape != (batch_size,):
|
||||
raise ValueError("`lengths` and `q` must have the same batch size")
|
||||
if lengths.dtype != jnp.int32:
|
||||
raise ValueError(
|
||||
f"The dtype of `lengths` must be int32. Got {lengths.dtype}"
|
||||
)
|
||||
|
||||
if block_h % 16:
|
||||
raise ValueError(f"block_h must be divisible by 16, but is {block_h}.")
|
||||
|
||||
impl = functools.partial(
|
||||
paged_attention_unbatched,
|
||||
block_h=block_h,
|
||||
pages_per_compute_block=pages_per_compute_block,
|
||||
k_splits=k_splits,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
interpret=interpret,
|
||||
debug=debug,
|
||||
mask_value=mask_value,
|
||||
attn_logits_soft_cap=attn_logits_soft_cap,
|
||||
)
|
||||
|
||||
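# lengths is expanded to [batch_size, 1] so that, after vmapping over the batch axis, each unbatched call receives a length of shape (1,), matching the kernel's lengths_ref block.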
o = jax.vmap(impl, (0, None, None, 0, 0, None, None), 0)(
|
||||
q,
|
||||
k_pages,
|
||||
v_pages,
|
||||
block_tables,
|
||||
lengths[..., None] if lengths is not None else None,
|
||||
k_scales_pages,
|
||||
v_scales_pages,
|
||||
)
|
||||
|
||||
return o
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit, static_argnames=["mask_value", "attn_logits_soft_cap"]
|
||||
)
|
||||
def paged_attention_reference(
|
||||
q: jax.Array,
|
||||
k: jax.Array,
|
||||
v: jax.Array,
|
||||
lengths: jax.Array,
|
||||
*,
|
||||
mask_value: float = DEFAULT_MASK_VALUE,
|
||||
attn_logits_soft_cap: float | None = None,
|
||||
) -> jax.Array:
|
||||
"""Grouped query attention reference implementation.
|
||||
|
||||
Args:
|
||||
q: A [batch_size, num_heads, head_dim] jax.Array.
|
||||
k: A [batch_size, kv_seq_len, num_kv_heads, head_dim] jax.Array.
|
||||
v: A [batch_size, kv_seq_len, num_kv_heads, head_dim] jax.Array.
|
||||
lengths: An i32[batch_size] jax.Array with the length of each example.
|
||||
mask_value: The value used for padding in attention. By default it is a very
|
||||
negative floating point number.
|
||||
attn_logits_soft_cap: The value used for soft capping the attention logits.
|
||||
|
||||
Returns:
|
||||
The output of attention([batch_size, num_heads, head_dim]).
|
||||
"""
|
||||
batch_size, num_heads, head_dim = q.shape
|
||||
_, kv_seq_len, num_kv_heads, _ = k.shape
|
||||
|
||||
q_heads_per_kv_head = num_heads // num_kv_heads
|
||||
q_reshaped = q.reshape(
|
||||
batch_size, num_kv_heads, q_heads_per_kv_head, head_dim
|
||||
)
|
||||
k_transposed = jnp.swapaxes(
|
||||
k, 1, 2
|
||||
) # [batch_size, num_kv_heads, kv_seq_len, head_dim]
|
||||
v_transposed = jnp.swapaxes(
|
||||
v, 1, 2
|
||||
) # [batch_size, num_kv_heads, kv_seq_len, head_dim]
|
||||
|
||||
uncapped_logits = jnp.einsum(
|
||||
"bkgd,bksd->bkgs", q_reshaped, k_transposed,
|
||||
preferred_element_type=jnp.float32
|
||||
).astype(jnp.float32)
|
||||
|
||||
if attn_logits_soft_cap is not None:
|
||||
logits = jnp.tanh(uncapped_logits / attn_logits_soft_cap)
|
||||
logits = logits * attn_logits_soft_cap
|
||||
else:
|
||||
logits = uncapped_logits
|
||||
|
||||
if lengths is not None:
|
||||
mask = jnp.arange(kv_seq_len)[None, :] < lengths[:, None]
|
||||
mask = jnp.broadcast_to(mask[:, None, None, :], logits.shape)
|
||||
logits = jnp.where(mask, logits, mask_value)
|
||||
|
||||
weights = jax.nn.softmax(logits, axis=-1)
|
||||
o = jnp.einsum(
|
||||
"bkgs,bksd->bkgd", weights, v_transposed.astype(jnp.float32),
|
||||
preferred_element_type=jnp.float32
|
||||
).astype(q.dtype)
|
||||
o = o.reshape(q.shape)
|
||||
|
||||
return o
|
||||
@@ -0,0 +1,333 @@
|
||||
# Copyright 2025 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Ragged dot Pallas-Mosaic-GPU implementation."""
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import random
|
||||
from jax._src import test_util as jtu # noqa: F401
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
from jax.experimental.pallas import mosaic_gpu as plgpu
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class GroupInfo:
|
||||
"""Information regarding the group being processed in a block."""
|
||||
|
||||
group_id: jax.Array
|
||||
block: jax.Array
|
||||
block_start: jax.Array
|
||||
actual_start: jax.Array
|
||||
actual_end: jax.Array
|
||||
start_within_block: jax.Array
|
||||
actual_size: jax.Array
|
||||
|
||||
@classmethod
|
||||
def create(cls, group_lengths, tile, tid):
|
||||
"""Get the group info for the current block."""
|
||||
|
||||
tile = jnp.int32(tile)
|
||||
group_boundaries = [group_lengths[i] for i in range(len(group_lengths))]
|
||||
|
||||
# We usually only have very few groups, so we unroll the loop processing
|
||||
# them. Normally we'd break out of the loop early, once we'd have found our
|
||||
# boundary, but we can't do that when unrolling, so we rely on many selects
|
||||
# to mask out the epilogue of the loop.
|
||||
group_end = group_start = block = group = end = jnp.array(
|
||||
0, dtype=jnp.int32
|
||||
)
|
||||
|
||||
for i, b in enumerate(group_boundaries):
|
||||
# Start/end are inclusive
|
||||
start = end
|
||||
end = start + b
|
||||
final = end - 1
|
||||
start_block = lax.div(start, tile)
|
||||
final_block = lax.div(final, tile)
|
||||
block_end = final_block + 1
|
||||
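# Offsetting the tid range by the group index i gives each group its own program id for a block that straddles a group boundary, which is why the launch grid carries g - 1 extra row blocks.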
tid_begin = start_block + i
|
||||
tid_end = block_end + i
|
||||
# How many blocks after is our block?
|
||||
this_is_group = (tid_begin <= tid) & (tid < tid_end)
|
||||
block = lax.select(this_is_group, tid - tid_begin + start_block, block)
|
||||
group = lax.select(this_is_group, jnp.int32(i), group)
|
||||
group_start = lax.select(this_is_group, start, group_start)
|
||||
group_end = lax.select(this_is_group, end, group_end)
|
||||
|
||||
block_start = block * tile
|
||||
actual_start = jnp.maximum(group_start, block_start)
|
||||
actual_end = jnp.minimum(group_end, block_start + tile)
|
||||
start_within_block = actual_start - block_start
|
||||
actual_size = actual_end - actual_start
|
||||
return cls(
|
||||
group_id=group,
|
||||
block=block,
|
||||
block_start=block_start,
|
||||
actual_start=actual_start,
|
||||
actual_end=actual_end,
|
||||
start_within_block=start_within_block,
|
||||
actual_size=actual_size,
|
||||
)
|
||||
|
||||
|
||||
def ragged_dot(
|
||||
lhs, # (M, K)
|
||||
rhs, # (G, K, N)
|
||||
*,
|
||||
group_sizes, # (G,)
|
||||
block_m: int,
|
||||
block_n: int,
|
||||
block_k: int,
|
||||
max_concurrent_steps: int,
|
||||
grid_block_n: int,
|
||||
transpose_rhs: bool = False,
|
||||
load_group_sizes_to_register: bool = True,
|
||||
) -> jax.Array:
|
||||
if lhs.dtype != rhs.dtype:
|
||||
raise NotImplementedError(
|
||||
f"lhs and rhs must have the same dtype, got {lhs.dtype} and {rhs.dtype}"
|
||||
)
|
||||
m, k = lhs.shape
|
||||
g, k2, n = rhs.shape
|
||||
|
||||
if transpose_rhs:
|
||||
k2, n = n, k2
|
||||
|
||||
if group_sizes.shape[0] != g:
|
||||
raise ValueError(
|
||||
f"Expected group_sizes to have shape {g} but got {group_sizes.shape}"
|
||||
)
|
||||
|
||||
if k != k2:
|
||||
raise ValueError(f"Contraction dim of lhs ({k}) must match that of rhs ({k2})")
|
||||
|
||||
if k % block_k != 0:
|
||||
raise ValueError(f"k={k} must be a multiple of block_k={block_k}")
|
||||
|
||||
def body(rows_per_expert_gmem, lhs_gmem, rhs_gmem, o_gmem):
|
||||
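# Reserve one extra row block per group boundary: a block_m tile that straddles two groups is visited once for each of them (see GroupInfo.create).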
grid_m = pl.cdiv(m, block_m) + g - 1
|
||||
grid_n = pl.cdiv(n, block_n)
|
||||
grid = (grid_m * grid_n,)
|
||||
|
||||
@plgpu.nd_loop(grid, collective_axes="sm")
|
||||
def mn_loop(loop_info: plgpu.NDLoopInfo):
|
||||
mi, ni = plgpu.planar_snake(
|
||||
loop_info.index[0],
|
||||
(grid_m, grid_n),
|
||||
1,
|
||||
grid_block_n,
|
||||
)
|
||||
group_info = GroupInfo.create(rows_per_expert_gmem, block_m, mi)
|
||||
|
||||
def acc_scope(acc_ref):
|
||||
plgpu.emit_pipeline(
|
||||
lambda _, lhs_smem, rhs_smem: plgpu.wgmma(
|
||||
acc_ref,
|
||||
lhs_smem,
|
||||
plgpu.transpose_ref(rhs_smem, (1, 0)) if transpose_rhs else rhs_smem,
|
||||
),
|
||||
grid=(k // block_k,),
|
||||
in_specs=[
|
||||
plgpu.BlockSpec(
|
||||
(block_m, block_k),
|
||||
lambda k: (group_info.block, k),
|
||||
delay_release=1,
|
||||
),
|
||||
plgpu.BlockSpec(
|
||||
(block_n, block_k) if transpose_rhs else (block_k, block_n),
|
||||
lambda k: (ni, k) if transpose_rhs else (k, ni),
|
||||
delay_release=1,
|
||||
),
|
||||
],
|
||||
max_concurrent_steps=max_concurrent_steps,
|
||||
)(lhs_gmem, rhs_gmem.at[group_info.group_id])
|
||||
return acc_ref[...]
|
||||
|
||||
acc = pl.run_scoped(acc_scope, plgpu.ACC((block_m, block_n)))
|
||||
|
||||
@functools.partial(
|
||||
pl.run_scoped,
|
||||
o_smem=plgpu.SMEM((block_m, block_n), dtype=o_gmem.dtype)
|
||||
)
|
||||
def store_scope(o_smem):
|
||||
o_smem[...] = acc.astype(o_smem.dtype)
|
||||
plgpu.commit_smem()
|
||||
|
||||
smem_start = group_info.start_within_block
|
||||
remaining_rows = min(block_m, m)
|
||||
# TMA descriptors need to be generated with static tile sizes along each
|
||||
# axis, but we do not know at compile time how many rows we will need to
|
||||
# store. We only know that the number of rows to store is bounded by
|
||||
# min(block_m, m).
|
||||
#
|
||||
# In order to work around that, we construct a logarithmic ladder of
|
||||
# TMA descriptors, where each descriptor can store 2**i rows for some
|
||||
# i between 0 and log2(min(block_m, m)). This allows storing any
|
||||
# number of rows we will need to store, so long as this number of rows
|
||||
# is between `1` and `min(block_m, m)`.
|
||||
#
|
||||
# E.g., imagine we have block_m = 8, m = 16. The loop below will be
|
||||
# unrolled into 4 iterations, where the first one will generate a TMA
|
||||
# descriptor that can store 8 rows, the second one will generate a TMA
|
||||
# descriptor that can store 4 rows, etc. all the way to 1 row.
|
||||
#
|
||||
# At run time, we finally know the actual number of rows we need to
|
||||
# store as we go through the unrolled loop iterations. Let's imagine
|
||||
# that we need to store 5 rows.
|
||||
#
|
||||
# The first unrolled iteration will check whether we can store 8 rows.
|
||||
# Since we only need to store 5 rows, we won't store anything then.
|
||||
#
|
||||
# The second unrolled iteration will check whether we can store 4 rows.
|
||||
# We're able to store 4 rows, and are left with a single remaining row.
|
||||
#
|
||||
# The third unrolled iteration (2 rows) is skipped, since only one row remains.
# The fourth unrolled iteration will store the single remaining row, and
|
||||
# we end up with a storing scheme as follows for our 5 rows:
|
||||
#
|
||||
# -----------------------------------------------------------
|
||||
# 0 | |
|
||||
# 1 | |
|
||||
# 2 | Store 4 rows |
|
||||
# 3 | |
|
||||
# -----------------------------------------------------------
|
||||
# 4 | Store 1 row |
|
||||
# -----------------------------------------------------------
|
||||
while remaining_rows > 0:
|
||||
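# const_rows_len is the largest power of two not exceeding remaining_rows; the store below fires iff that bit is set in the (dynamic) actual_size, so the unrolled ladder can cover any row count up to min(block_m, m).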
const_rows_len = 1 << int(math.log2(remaining_rows))
|
||||
remaining_rows //= 2
|
||||
|
||||
@pl.when(group_info.actual_size & const_rows_len != 0)
|
||||
def _():
|
||||
o_smem_slice = o_smem.at[pl.ds(smem_start, const_rows_len)]
|
||||
o_gref_slice = o_gmem.at[
|
||||
pl.ds(group_info.block_start + smem_start, const_rows_len),
|
||||
pl.ds(ni * block_n, block_n),
|
||||
]
|
||||
plgpu.copy_smem_to_gmem(o_smem_slice, o_gref_slice)
|
||||
|
||||
smem_start += group_info.actual_size & const_rows_len
|
||||
plgpu.wait_smem_to_gmem(0, wait_read_only=True)
|
||||
|
||||
# There are 132 SMs on an H100 SXM GPU.
|
||||
num_sms = 132
|
||||
kernel = plgpu.kernel(
|
||||
body,
|
||||
out_shape=jax.ShapeDtypeStruct((m, n), lhs.dtype),
|
||||
grid=(num_sms,),
|
||||
grid_names=("sm",),
|
||||
compiler_params=plgpu.CompilerParams(
|
||||
lowering_semantics=plgpu.LoweringSemantics.Warpgroup,
|
||||
),
|
||||
)
|
||||
return kernel(group_sizes, lhs, rhs)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
for transpose_rhs in [False, True]:
|
||||
m, k, n, num_groups = 16 * 1024, 2048, 16 * 1024, 16
|
||||
kx, ky, kz = random.split(random.key(1234), num=3)
|
||||
|
||||
lhs = jax.random.normal(kx, (m, k), jnp.float16)
|
||||
if transpose_rhs:
|
||||
rhs = jax.random.normal(ky, (num_groups, n, k), jnp.float16)
|
||||
else:
|
||||
rhs = jax.random.normal(ky, (num_groups, k, n), jnp.float16)
|
||||
group_boundaries = jax.lax.sort(
|
||||
jax.random.randint(kz, (num_groups - 1,), 0, m, jnp.int32)
|
||||
)
|
||||
group_starts = lax.concatenate(
|
||||
[jnp.array([0], dtype=jnp.int32), group_boundaries], 0
|
||||
)
|
||||
group_ends = lax.concatenate(
|
||||
[group_boundaries, jnp.array([m], dtype=jnp.int32)], 0
|
||||
)
|
||||
group_sizes = group_ends - group_starts
|
||||
assert group_sizes.shape == (num_groups,)
|
||||
|
||||
block_m = block_n = (64, 128, 192)
|
||||
block_k = (64,)
|
||||
max_concurrent_steps = (2, 4, 5, 6)
|
||||
grid_block_n = (1, 2, 4, 8, 16)
|
||||
configs = itertools.product(
|
||||
block_m, block_n, block_k, max_concurrent_steps, grid_block_n
|
||||
)
|
||||
names = (
|
||||
"block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n"
|
||||
)
|
||||
best_runtime = float("inf")
|
||||
best_kwargs: dict[str, int] = {}
|
||||
for config in configs:
|
||||
kwargs = dict(zip(names, config))
|
||||
if n % (kwargs["grid_block_n"] * kwargs["block_n"]):
|
||||
continue
|
||||
try:
|
||||
f = functools.partial(
|
||||
ragged_dot, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
|
||||
**kwargs
|
||||
)
|
||||
_, runtime = profiler.measure(f)(lhs, rhs)
|
||||
except ValueError as e:
|
||||
if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
|
||||
raise
|
||||
runtime = float("inf")
|
||||
# Enable this to get more detailed information.
|
||||
else:
|
||||
assert runtime is not None
|
||||
print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000))
|
||||
if runtime < best_runtime:
|
||||
best_runtime = runtime
|
||||
best_kwargs = kwargs
|
||||
if not best_kwargs:
|
||||
raise ValueError("No valid configuration found")
|
||||
|
||||
def ref_ragged_dot(lhs, rhs, group_sizes):
|
||||
if transpose_rhs:
|
||||
rhs = jnp.transpose(rhs, (0, 2, 1))
|
||||
return jax.lax.ragged_dot(lhs, rhs, group_sizes=group_sizes)
|
||||
|
||||
ref, ref_runtime = profiler.measure(ref_ragged_dot)(
|
||||
lhs, rhs, group_sizes=group_sizes
|
||||
)
|
||||
assert ref_runtime is not None
|
||||
result = ragged_dot(
|
||||
lhs, rhs, group_sizes=group_sizes, transpose_rhs=transpose_rhs,
|
||||
load_group_sizes_to_register=True,
|
||||
**best_kwargs
|
||||
)
|
||||
np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)
|
||||
|
||||
tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
|
||||
ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
|
||||
print(f"Transpose RHS: {transpose_rhs}")
|
||||
print(
|
||||
"Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
|
||||
)
|
||||
print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
|
||||
print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from absl import app
|
||||
|
||||
jax.config.config_with_absl()
|
||||
app.run(main)
|
||||
+244
@@ -0,0 +1,244 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Reduce scatter kernel implemented using Mosaic GPU."""
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
from jax.experimental.pallas import mosaic_gpu as plgpu
|
||||
from jax.extend import backend
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def reduce_scatter(
|
||||
x: jax.Array,
|
||||
*,
|
||||
axis_name,
|
||||
scatter_dimension: int | None = 0,
|
||||
reduction: Literal["add", "min", "max", "and", "or", "xor"] = "add",
|
||||
num_blocks: int | None = None,
|
||||
tile_size: int | None = None,
|
||||
vec_size: int | None = None,
|
||||
) -> jax.Array:
|
||||
"""Performs a reduce-scatter or all-reduce operation across devices using multimem instructions.
|
||||
|
||||
Args:
|
||||
x: Input array. Should be sharded across the specified axis.
|
||||
axis_name: Name of the mesh axis to reduce-scatter across.
|
||||
scatter_dimension: Axis along which to reduce-scatter. If None, performs
|
||||
all-reduce instead. Defaults to 0.
|
||||
reduction: Reduction operation to perform. Supported: "add", "min", "max",
|
||||
"and", "or", "xor".
|
||||
vec_size: Vector size for the layout. If None, automatically inferred from dtype.
|
||||
num_blocks: Number of blocks to use. Defaults to the device core count.
|
||||
tile_size: Total tile size to split across major, scatter, and minor dimensions.
|
||||
"""
|
||||
num_devices = lax.axis_size(axis_name)
|
||||
input_shape = x.shape
|
||||
dtype = x.dtype
|
||||
ndim = len(input_shape)
|
||||
|
||||
if num_blocks is None:
|
||||
num_blocks = backend.get_default_device().core_count
|
||||
|
||||
if scatter_dimension is None:
|
||||
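# All-reduce case: treat the whole flattened array as the scatter dimension, so every device reduces the full input and keeps an output of the same shape as the input.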
major_dims, scatter_dim, minor_dims = 1, math.prod(input_shape), 1
|
||||
output_scatter_dim = scatter_dim
|
||||
output_shape = input_shape
|
||||
else:
|
||||
if scatter_dimension < -ndim or scatter_dimension >= ndim:
|
||||
raise ValueError(
|
||||
f"scatter_dimension {scatter_dimension} out of bounds for array of"
|
||||
f" dimension {ndim}"
|
||||
)
|
||||
if scatter_dimension < 0:
|
||||
scatter_dimension += ndim
|
||||
|
||||
scatter_dim = input_shape[scatter_dimension]
|
||||
if scatter_dim % num_devices != 0:
|
||||
raise ValueError(
|
||||
f"Scattered dimension {scatter_dimension} of input ({scatter_dim})"
|
||||
f" must be divisible by number of devices ({num_devices})"
|
||||
)
|
||||
|
||||
major_dims = math.prod(input_shape[:scatter_dimension])
|
||||
minor_dims = math.prod(input_shape[scatter_dimension+1:])
|
||||
output_scatter_dim = scatter_dim // num_devices
|
||||
output_shape = (
|
||||
*input_shape[:scatter_dimension], output_scatter_dim, *input_shape[scatter_dimension + 1 :],
|
||||
)
|
||||
|
||||
if (output_size := math.prod(output_shape)) % 128:
|
||||
raise ValueError("Output size must be divisible by 128")
|
||||
if jnp.issubdtype(dtype, jnp.integer):
|
||||
if vec_size is None:
|
||||
vec_size = 1 # Integer types only support unvectorized reductions
|
||||
elif vec_size != 1:
|
||||
raise ValueError("Integer types only support vec_size=1")
|
||||
elif vec_size is None: # vec_size inference for floating point types
|
||||
dtype_bits = jnp.finfo(dtype).bits
|
||||
max_vec_size = min(128 // dtype_bits, output_size // 128)
|
||||
if tile_size is not None:
|
||||
max_vec_size_for_tile = tile_size // 128
|
||||
max_vec_size = min(max_vec_size, max_vec_size_for_tile)
|
||||
vec_size = 32 // dtype_bits # We don't support ld_reduce below 32-bit
|
||||
while vec_size * 2 <= max_vec_size:
|
||||
vec_size *= 2
|
||||
if math.prod(output_shape) % vec_size:
|
||||
raise ValueError(
|
||||
"The total number of elements in the output"
|
||||
f" ({math.prod(output_shape)}) must be divisible by the vec_size"
|
||||
f" ({vec_size})"
|
||||
)
|
||||
|
||||
min_transfer_elems = 128 * vec_size
|
||||
if tile_size is None:
|
||||
# TODO(apaszke): 8 is just an arbitrary unrolling factor. Tune it!
|
||||
unroll_factor = min(math.prod(output_shape) // min_transfer_elems, 8)
|
||||
tile_size = unroll_factor * min_transfer_elems
|
||||
if tile_size < min_transfer_elems:
|
||||
raise ValueError(
|
||||
f"{tile_size=} is smaller than minimum required"
|
||||
f" {min_transfer_elems} for {vec_size=}"
|
||||
)
|
||||
|
||||
minor_tile = math.gcd(tile_size, minor_dims)
|
||||
remaining_tile = tile_size // minor_tile
|
||||
scatter_tile = math.gcd(remaining_tile, output_scatter_dim)
|
||||
major_tile = remaining_tile // scatter_tile
|
||||
|
||||
if major_dims % major_tile != 0:
|
||||
raise NotImplementedError(
|
||||
f"Major dimension size ({major_dims}) must be divisible by the"
|
||||
f" inferred major tile size ({major_tile}). Consider adjusting tile_size."
|
||||
)
|
||||
|
||||
def kernel(x_ref, y_ref, done_barrier):
|
||||
dev_idx = lax.axis_index(axis_name)
|
||||
x_ref_3d = x_ref.reshape((major_dims, scatter_dim, minor_dims))
|
||||
y_ref_3d = y_ref.reshape((major_dims, output_scatter_dim, minor_dims))
|
||||
|
||||
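# For a reduce-scatter, each device only reduces the slice of the scatter dimension that ends up in its own output shard.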
if scatter_dimension is not None:
|
||||
dev_slice = pl.ds(dev_idx * output_scatter_dim, output_scatter_dim)
|
||||
x_ref_3d = x_ref_3d.at[:, dev_slice, :]
|
||||
|
||||
major_tiles = major_dims // major_tile
|
||||
scatter_tiles = output_scatter_dim // scatter_tile
|
||||
minor_tiles = minor_dims // minor_tile
|
||||
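# The blocks cooperatively sweep the (major, scatter, minor) tile grid; each multimem load-reduce fetches the matching tile from every device in the collective axis and returns the already-reduced result.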
@plgpu.nd_loop((major_tiles, scatter_tiles, minor_tiles), collective_axes="blocks")
|
||||
def _transfer_loop(loop_info: plgpu.NDLoopInfo):
|
||||
major_tile_idx, scatter_tile_idx, minor_tile_idx = loop_info.index
|
||||
idxs = (
|
||||
pl.ds(major_tile_idx * major_tile, major_tile),
|
||||
pl.ds(scatter_tile_idx * scatter_tile, scatter_tile),
|
||||
pl.ds(minor_tile_idx * minor_tile, minor_tile)
|
||||
)
|
||||
|
||||
y_ref_3d[idxs] = plgpu.layout_cast(
|
||||
plgpu.multimem_load_reduce(
|
||||
x_ref_3d.at[idxs], collective_axes=axis_name, reduction_op=reduction
|
||||
),
|
||||
plgpu.Layout.WG_STRIDED((major_tile, scatter_tile, minor_tile), vec_size=vec_size)
|
||||
)
|
||||
|
||||
# Wait for everyone to finish reading the operands before we exit and potentially free them
|
||||
plgpu.semaphore_signal_multicast(done_barrier, collective_axes=axis_name)
|
||||
pl.semaphore_wait(done_barrier, num_devices, decrement=False)
|
||||
|
||||
return plgpu.kernel(
|
||||
kernel,
|
||||
out_shape=jax.ShapeDtypeStruct(output_shape, dtype),
|
||||
grid=(num_blocks,),
|
||||
grid_names=("blocks",),
|
||||
scratch_shapes=[plgpu.SemaphoreType.REGULAR],
|
||||
)(x)
|
||||
|
||||
|
||||
def _run_example():
|
||||
P = jax.sharding.PartitionSpec
|
||||
shape = (4 * 4096, 4 * 4096) # This shape is global!
|
||||
dtype = jnp.bfloat16
|
||||
shards = jax.device_count()
|
||||
mesh = jax.make_mesh(
|
||||
(shards,), ("x",), axis_types=(jax.sharding.AxisType.Explicit,)
|
||||
)
|
||||
jax.set_mesh(mesh)
|
||||
|
||||
# We measure time per-shard and so we only need bytes per shard.
|
||||
local_in_bytes = math.prod(shape) / shards * jnp.dtype(dtype).itemsize
|
||||
# In reduce-scatter, we send (shards - 1) / shards worth of input data to the
|
||||
# switch and receive as much data as in the whole output, which is 1 / shards.
|
||||
total_bytes = local_in_bytes
|
||||
|
||||
a = jax.random.normal(jax.random.key(1), shape, dtype)
|
||||
a = jax.sharding.reshard(a, P(None, "x"))
|
||||
|
||||
@jax.jit
|
||||
@functools.partial(jax.shard_map, mesh=mesh, in_specs=P(None, "x"), out_specs=P(None, "x"))
|
||||
def ref_fn(x):
|
||||
return lax.psum_scatter(x, "x", scatter_dimension=1, tiled=True)
|
||||
ref_fn(a).block_until_ready() # Warmup.
|
||||
_, ref_kernels_ms = profiler.measure(ref_fn, aggregate=False)(a)
|
||||
assert ref_kernels_ms is not None
|
||||
ref_time_us = sum(t * 1e3 for _, t in ref_kernels_ms)
|
||||
# We choose the minimum across processes to choose the runtime that didn't
|
||||
# include devices waiting for other devices.
|
||||
ref_time_us = min(multihost_utils.process_allgather(ref_time_us).tolist())
|
||||
ref_bw = total_bytes / (ref_time_us * 1e-6) / 1e9 # GB/s
|
||||
|
||||
tuning_it = itertools.product(
|
||||
(4, 8, 16, 32, 64, 132), # num_blocks
|
||||
(1024, 2048, 4096, 8192), # tile_size
|
||||
)
|
||||
best_bw = 0.0
|
||||
best_runtime = float("inf")
|
||||
for num_blocks, tile_size in tuning_it:
|
||||
try:
|
||||
@jax.jit
|
||||
@functools.partial(
|
||||
jax.shard_map, mesh=mesh, in_specs=P(None, "x"), out_specs=P(None, "x"), check_vma=False
|
||||
)
|
||||
def kernel_fn(x):
|
||||
return reduce_scatter(x, axis_name="x", scatter_dimension=1, num_blocks=num_blocks, tile_size=tile_size)
|
||||
kernel_fn(a).block_until_ready() # Warmup.
|
||||
_, kernels_ms = profiler.measure(kernel_fn, aggregate=False)(a)
|
||||
except ValueError as e:
|
||||
if "exceeds available shared memory" in e.args[0]: # Ignore SMEM OOMs.
|
||||
continue
|
||||
raise
|
||||
assert kernels_ms is not None
|
||||
runtime_us = sum(t * 1e3 for _, t in kernels_ms)
|
||||
runtime_us = min(multihost_utils.process_allgather(runtime_us).tolist())
|
||||
achieved_bw = total_bytes / (runtime_us * 1e-6) / 1e9 # GB/s
|
||||
if achieved_bw > best_bw:
|
||||
best_runtime = runtime_us
|
||||
best_bw = achieved_bw
|
||||
print(f"{num_blocks=}, {tile_size=}: {runtime_us:<7.1f}us = {achieved_bw:4.1f} GB/s")
|
||||
|
||||
print(f"Total bytes transferred: {total_bytes / 1e9:.2f} GB")
|
||||
print(f"\tBest: {best_runtime:<7.1f}us = {best_bw:4.1f} GB/s")
|
||||
print(f"\tRef: {ref_time_us:<7.1f}us = {ref_bw:4.1f} GB/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from jax._src import test_multiprocess as jt_multiprocess # pytype: disable=import-error
|
||||
jt_multiprocess.main(shard_main=_run_example)
|
||||
@@ -0,0 +1,310 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Module containing rms forward and backward pass."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
import jax.numpy as jnp
|
||||
|
||||
def rms_norm_forward_kernel(
|
||||
x_ref, weight_ref, bias_ref, # Input arrays
|
||||
o_ref, rstd_ref=None, # Output arrays
|
||||
*, eps: float, block_size: int):
|
||||
n_col = x_ref.shape[0]
|
||||
|
||||
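# First pass over the row: accumulate the sum of squares block by block to form rstd = 1 / sqrt(mean(x**2) + eps).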
def var_body(i, acc):
|
||||
col_idx = i * block_size + jnp.arange(block_size)
|
||||
mask = col_idx < n_col
|
||||
a = plgpu.load(
|
||||
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
a = jnp.where(mask, a, 0.)
|
||||
return acc + a * a
|
||||
|
||||
var = lax.fori_loop(
|
||||
0, pl.cdiv(n_col, block_size), var_body, init_val=jnp.zeros(block_size)
|
||||
).sum()
|
||||
var /= n_col
|
||||
rstd = 1 / jnp.sqrt(var + eps)
|
||||
if rstd_ref is not None:
|
||||
rstd_ref[...] = rstd.astype(rstd_ref.dtype)
|
||||
|
||||
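# Second pass: re-read the row block by block and write out x * rstd * weight + bias.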
@pl.loop(0, pl.cdiv(n_col, block_size))
|
||||
def body(i):
|
||||
col_idx = i * block_size + jnp.arange(block_size)
|
||||
mask = col_idx < n_col
|
||||
weight = plgpu.load(weight_ref.at[col_idx], mask=mask)
|
||||
bias = plgpu.load(bias_ref.at[col_idx], mask=mask)
|
||||
x = plgpu.load(
|
||||
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_first"
|
||||
).astype(jnp.float32)
|
||||
out = x * rstd * weight + bias
|
||||
plgpu.store(o_ref.at[col_idx], out.astype(o_ref.dtype), mask=mask)
|
||||
|
||||
|
||||
def rms_norm_forward(
|
||||
x, weight, bias,
|
||||
num_warps: int | None = None,
|
||||
num_stages: int | None = 3,
|
||||
eps: float = 1e-5,
|
||||
backward_pass_impl: str = 'triton',
|
||||
interpret: bool = False):
|
||||
del num_stages
|
||||
del backward_pass_impl
|
||||
n = x.shape[-1]
|
||||
# Triton heuristics
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
max_fused_size = 65536 // x.dtype.itemsize
|
||||
block_size = min(max_fused_size, pl.next_power_of_2(n))
|
||||
block_size = min(max(block_size, 128), 4096)
|
||||
num_warps = min(max(block_size // 256, 1), 8)
|
||||
|
||||
kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
|
||||
block_size=block_size)
|
||||
out_shape = [
|
||||
jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype),
|
||||
jax.ShapeDtypeStruct(shape=(), dtype=x.dtype)
|
||||
]
|
||||
method = pl.pallas_call(
|
||||
kernel,
|
||||
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
|
||||
grid=(),
|
||||
out_shape=out_shape,
|
||||
debug=False,
|
||||
interpret=interpret,
|
||||
name="rms_forward",
|
||||
)
|
||||
|
||||
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
|
||||
out, rstd = method(x, weight, bias)
|
||||
return out, (x, weight, bias, rstd)
|
||||
|
||||
|
||||
def rms_norm_backward_kernel_dx(
|
||||
# Inputs
|
||||
x_ref, weight_ref, bias_ref, do_ref,
|
||||
rstd_ref,
|
||||
# Outputs
|
||||
dx_ref,
|
||||
*, eps: float, block_size: int):
|
||||
n_col = x_ref.shape[0]
|
||||
|
||||
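# RMSNorm input gradient: dx = rstd * (w * dout - x_hat * mean(x_hat * w * dout)) with x_hat = x * rstd. The first loop accumulates the mean term c1; the second applies it elementwise.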
def mean_body(i, c1_acc):
|
||||
col_idx = i * block_size + jnp.arange(block_size)
|
||||
mask = col_idx < n_col
|
||||
a = plgpu.load(
|
||||
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
dout = plgpu.load(
|
||||
do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
weight = plgpu.load(
|
||||
weight_ref.at[col_idx],
|
||||
mask=mask,
|
||||
other=0.0,
|
||||
eviction_policy="evict_last",
|
||||
).astype(jnp.float32)
|
||||
a_hat = a * rstd_ref[...]
|
||||
wdout = weight * dout
|
||||
return c1_acc + a_hat * wdout
|
||||
|
||||
c1 = lax.fori_loop(
|
||||
0, pl.cdiv(n_col, block_size), mean_body, jnp.zeros(block_size)
|
||||
)
|
||||
c1 = c1.sum() / n_col
|
||||
|
||||
@pl.loop(0, pl.cdiv(n_col, block_size))
|
||||
def dx_body(i):
|
||||
col_idx = i * block_size + jnp.arange(block_size)
|
||||
mask = col_idx < n_col
|
||||
a = plgpu.load(
|
||||
x_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
dout = plgpu.load(
|
||||
do_ref.at[col_idx], mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).astype(jnp.float32)
|
||||
weight = plgpu.load(
|
||||
weight_ref.at[col_idx],
|
||||
mask=mask,
|
||||
other=0.0,
|
||||
eviction_policy="evict_last",
|
||||
).astype(jnp.float32)
|
||||
a_hat = a * rstd_ref[...]
|
||||
wdout = weight * dout
|
||||
da = (wdout - (a_hat * c1)) * rstd_ref[...]
|
||||
plgpu.store(dx_ref.at[col_idx], da.astype(dx_ref.dtype), mask=mask)
|
||||
|
||||
|
||||
def rms_norm_backward_kernel_dw_db(
|
||||
# Inputs
|
||||
x_ref, weight_ref, bias_ref, do_ref,
|
||||
rstd_ref,
|
||||
# Outputs
|
||||
dw_ref, db_ref,
|
||||
*, eps: float, block_m: int, block_n: int):
|
||||
m, n_col = x_ref.shape
|
||||
j = pl.program_id(0)
|
||||
col_idx = j * block_n + jnp.arange(block_n)
|
||||
col_mask = col_idx < n_col
|
||||
|
||||
def body(i, acc):
|
||||
row_idx = i * block_m + jnp.arange(block_m)
|
||||
row_mask = row_idx < m
|
||||
mask = row_mask[:, None] & col_mask[None, :]
|
||||
a = plgpu.load(
|
||||
x_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
|
||||
).astype(jnp.float32)
|
||||
dout = plgpu.load(
|
||||
do_ref.at[row_idx[:, None], col_idx[None]], mask=mask, other=0.0
|
||||
).astype(jnp.float32)
|
||||
rstd = plgpu.load(rstd_ref.at[row_idx], mask=row_mask, other=0.0).astype(
|
||||
jnp.float32
|
||||
)
|
||||
a_hat = a * rstd[:, None]
|
||||
dw_acc, db_acc = acc
|
||||
return (dw_acc + (dout * a_hat).sum(axis=0), db_acc + dout.sum(axis=0))
|
||||
|
||||
dw_acc, db_acc = lax.fori_loop(
|
||||
0,
|
||||
pl.cdiv(m, block_m),
|
||||
body,
|
||||
init_val=(jnp.zeros(block_n), jnp.zeros(block_n)),
|
||||
)
|
||||
plgpu.store(dw_ref.at[col_idx], dw_acc.astype(dw_ref.dtype), mask=col_mask)
|
||||
plgpu.store(db_ref.at[col_idx], db_acc.astype(db_ref.dtype), mask=col_mask)
|
||||
|
||||
|
||||
def rms_norm_backward(
|
||||
num_warps: int | None,
|
||||
num_stages: int | None,
|
||||
eps: float,
|
||||
backward_pass_impl: str,
|
||||
interpret: bool,
|
||||
res, do):
|
||||
del num_stages
|
||||
x, weight, bias, rstd = res
|
||||
if backward_pass_impl == 'xla':
|
||||
return jax.vjp(rms_norm_reference, x, weight, bias)[1](do)
|
||||
|
||||
*shape_prefix, n = x.shape
|
||||
reshaped_x = x.reshape((-1, n))
|
||||
reshaped_rstd = rstd.reshape((-1,))
|
||||
reshaped_do = do.reshape((-1, n))
|
||||
# Triton heuristics
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
max_fused_size = 65536 // x.dtype.itemsize
|
||||
block_size = min(max_fused_size, pl.next_power_of_2(n))
|
||||
block_size = min(max(block_size, 128), 4096)
|
||||
num_warps = min(max(block_size // 256, 1), 8)
|
||||
|
||||
# rms_norm_backward_kernel_dx parallel over batch dims
|
||||
kernel = functools.partial(rms_norm_backward_kernel_dx, eps=eps,
|
||||
block_size=block_size)
|
||||
out_shape_dx = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
|
||||
method = pl.pallas_call(
|
||||
kernel,
|
||||
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
|
||||
grid=(),
|
||||
out_shape=out_shape_dx,
|
||||
debug=False,
|
||||
interpret=interpret,
|
||||
name="ln_backward_dx",
|
||||
)
|
||||
|
||||
method = jax.vmap(method, in_axes=(0, None, None, 0, 0))
|
||||
dx = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
|
||||
dx = dx.reshape((*shape_prefix, n))
|
||||
|
||||
# rms_norm_backward_kernel_dw_db reduce over batch dims
|
||||
# Triton heuristics
|
||||
if n > 10240:
|
||||
block_n = 128
|
||||
block_m = 32
|
||||
num_warps = 4
|
||||
else:
|
||||
# maximize occupancy for small N
|
||||
block_n = 16
|
||||
block_m = 16
|
||||
num_warps = 8
|
||||
kernel = functools.partial(rms_norm_backward_kernel_dw_db, eps=eps,
|
||||
block_m=block_m, block_n=block_n)
|
||||
out_shape_dwbias = [
|
||||
jax.ShapeDtypeStruct(shape=weight.shape, dtype=weight.dtype),
|
||||
jax.ShapeDtypeStruct(shape=bias.shape, dtype=bias.dtype)
|
||||
]
|
||||
grid_ = (pl.cdiv(reshaped_x.shape[1], block_n),)
|
||||
method = pl.pallas_call(
|
||||
kernel,
|
||||
compiler_params=plgpu.CompilerParams(num_warps=num_warps),
|
||||
grid=grid_,
|
||||
out_shape=out_shape_dwbias,
|
||||
debug=False,
|
||||
interpret=interpret,
|
||||
name="ln_backward_dw_db",
|
||||
)
|
||||
dw, dbias = method(reshaped_x, weight, bias, reshaped_do, reshaped_rstd)
|
||||
return dx, dw, dbias
|
||||
|
||||
|
||||
@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7])
|
||||
@functools.partial(jax.jit, static_argnames=["num_warps", "num_stages",
|
||||
"eps",
|
||||
"backward_pass_impl",
|
||||
"interpret"])
|
||||
def rms_norm(
|
||||
x, weight, bias,
|
||||
num_warps: int | None = None,
|
||||
num_stages: int | None = 3,
|
||||
eps: float = 1e-5,
|
||||
backward_pass_impl: str = 'triton',
|
||||
interpret: bool = False):
|
||||
n = x.shape[-1]
|
||||
# Triton heuristics
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
max_fused_size = 65536 // x.dtype.itemsize
|
||||
block_size = min(max_fused_size, pl.next_power_of_2(n))
|
||||
block_size = min(max(block_size, 128), 4096)
|
||||
num_warps = min(max(block_size // 256, 1), 8)
|
||||
|
||||
kernel = functools.partial(rms_norm_forward_kernel, eps=eps,
|
||||
block_size=block_size)
|
||||
out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype)
|
||||
method = pl.pallas_call(
|
||||
kernel,
|
||||
compiler_params=plgpu.CompilerParams(
|
||||
num_warps=num_warps, num_stages=num_stages
|
||||
),
|
||||
grid=(),
|
||||
out_shape=out_shape,
|
||||
debug=False,
|
||||
interpret=interpret,
|
||||
)
|
||||
method = jax.vmap(jax.vmap(method, in_axes=(0, None, None)), in_axes=(0, None, None))
|
||||
return method(x, weight, bias)
|
||||
rms_norm.defvjp(rms_norm_forward, rms_norm_backward)
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=["eps"])
|
||||
@functools.partial(jax.vmap, in_axes=(0, None, None), out_axes=0)
|
||||
def rms_norm_reference(x, weight, bias, *, eps: float = 1e-5):
|
||||
var = jnp.mean(jnp.square(x), axis=1)
|
||||
mul = lax.rsqrt(var + eps)
|
||||
return x * mul[:, None] * weight[None] + bias[None]
|
||||
@@ -0,0 +1,94 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Pallas softmax kernel."""
|
||||
import functools
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import triton as plgpu
|
||||
|
||||
|
||||
def _vmappable_softmax_kernel(
|
||||
# inputs
|
||||
input_ref,
|
||||
# outputs
|
||||
probs_ref,
|
||||
*,
|
||||
# block information
|
||||
# It is assumed that block_row >= row_len
|
||||
block_row: int,
|
||||
):
|
||||
row_len = input_ref.shape[-1]
|
||||
|
||||
mask = jnp.arange(block_row) < row_len
|
||||
row = plgpu.load(
|
||||
input_ref.at[pl.ds(0, block_row)], mask=mask, other=-float("inf")
|
||||
)
|
||||
|
||||
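# Numerically stable softmax: subtract the row max before exponentiating. Padded lanes were loaded as -inf, so they contribute exp(-inf) = 0 to the denominator.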
row_max = jnp.max(row, axis=0)
|
||||
numerator = jnp.exp((row - row_max).astype(jnp.float32))
|
||||
denominator = jnp.sum(numerator, axis=0)
|
||||
|
||||
plgpu.store(
|
||||
probs_ref.at[pl.ds(0, block_row)],
|
||||
(numerator / denominator).astype(probs_ref.dtype),
|
||||
mask=mask
|
||||
)
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnames=["axis", "num_warps", "interpret",
|
||||
"debug"])
|
||||
def softmax(
|
||||
x: jax.Array, *, axis: int = -1, num_warps: int = 4,
|
||||
interpret: bool = False, debug: bool = False
|
||||
) -> jax.Array:
|
||||
"""Computes the softmax of the input array along the specified axis.
|
||||
|
||||
Args:
|
||||
x: input array
|
||||
axis: the axis along which to perform the computation
|
||||
num_warps: the number of warps to use for executing the Triton kernel
|
||||
interpret: whether to interpret the kernel using pallas
|
||||
debug: whether to use pallas in debug mode
|
||||
|
||||
Returns:
|
||||
The result of the softmax operation over the specified axis of x.
|
||||
"""
|
||||
axis = axis if axis >= 0 else len(x.shape) + axis
|
||||
if axis != len(x.shape) - 1:
|
||||
raise NotImplementedError(
|
||||
"reductions along non-trailing dimension unsupported")
|
||||
|
||||
row_len = x.shape[-1]
|
||||
|
||||
block_row = pl.next_power_of_2(row_len)
|
||||
out_shape = jax.ShapeDtypeStruct(shape=(row_len,), dtype=x.dtype)
|
||||
|
||||
kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row)
|
||||
f = pl.pallas_call(
|
||||
kernel,
|
||||
compiler_params=plgpu.CompilerParams(
|
||||
num_warps=num_warps, num_stages=1),
|
||||
grid=(),
|
||||
out_shape=out_shape,
|
||||
debug=debug,
|
||||
interpret=interpret,
|
||||
)
|
||||
|
||||
for _ in range(len(x.shape) - 1):
|
||||
f = jax.vmap(f)
|
||||
|
||||
return f(x)
|
||||
+302
@@ -0,0 +1,302 @@
|
||||
# Copyright 2025 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Transposed ragged dot Pallas-Mosaic-GPU implementation."""
|
||||
|
||||
import functools
|
||||
import itertools
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax import random
|
||||
from jax._src import test_util as jtu # noqa: F401
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.mosaic.gpu import profiler
|
||||
from jax.experimental.pallas import mosaic_gpu as plgpu
|
||||
import numpy as np
|
||||
|
||||
|
||||
def transposed_ragged_dot(
|
||||
lhs, # (K, M)
|
||||
rhs, # (K, N)
|
||||
*,
|
||||
group_sizes, # (G,)
|
||||
block_m: int,
|
||||
block_n: int,
|
||||
block_k: int,
|
||||
max_concurrent_steps: int,
|
||||
grid_block_n: int,
|
||||
) -> jax.Array:
|
||||
if lhs.dtype != rhs.dtype:
|
||||
raise NotImplementedError(
|
||||
f"lhs and rhs must have the same dtype, got {lhs.dtype} and {rhs.dtype}"
|
||||
)
|
||||
k, m = lhs.shape
|
||||
k2, n = rhs.shape
|
||||
g = group_sizes.shape[0]
|
||||
|
||||
if k != k2:
|
||||
raise ValueError(f"Contraction dim of lhs ({k}) must match that of rhs ({k2})")
|
||||
|
||||
if m % block_m != 0:
|
||||
raise ValueError(f"m={m} must be a multiple of block_m={block_m}")
|
||||
if n % block_n != 0:
|
||||
raise ValueError(f"n={n} must be a multiple of block_n={block_n}")
|
||||
|
||||
group_sizes = group_sizes.astype(int)
|
||||
group_starts = jnp.concatenate(
|
||||
[jnp.zeros(1, dtype=int), jnp.cumsum(group_sizes)[:-1]]
|
||||
).astype(int)
|
||||
group_ends = jnp.cumsum(group_sizes)
|
||||
group_block_starts = group_starts // block_k * block_k
|
||||
group_block_ends = -(group_ends // -block_k) * block_k
|
||||
group_num_blocks = (group_block_ends - group_block_starts) // block_k
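  # Each group is processed in whole blocks of `block_k` rows aligned to the
  # block grid. E.g. (illustrative numbers) with block_k=64, a group covering
  # rows [70, 200) gets group_block_start=64, group_block_end=256 and
  # group_num_blocks=3; rows of the partial first/last blocks that belong to
  # neighbouring groups are masked out inside the kernel.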

  swizzle = plgpu.find_swizzle(block_k * jnp.dtype(lhs.dtype).itemsize * 8)
  swizzle_elems = swizzle // jnp.dtype(lhs.dtype).itemsize
  transforms = (
      plgpu.TilingTransform((8, swizzle_elems)), plgpu.SwizzleTransform(swizzle)
  )

  def body(
      group_sizes_gmem,
      group_starts_gmem,
      group_ends_gmem,
      group_num_blocks_gmem,
      group_block_starts_gmem,
      lhs_gmem,
      rhs_gmem,
      o_gmem,
  ):

    grid_m = pl.cdiv(m, block_m)
    grid_n = pl.cdiv(n, block_n)

    @plgpu.nd_loop((g, grid_m * grid_n), collective_axes="sm")
    def mn_loop(loop_info: plgpu.NDLoopInfo):
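      # One persistent block runs per SM (the "sm" grid axis below), and
      # nd_loop with collective_axes="sm" splits the (group, tile) iteration
      # space across them, so each block processes its own subset of tiles.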
      g_i = loop_info.index[0]
      m_i, n_i = plgpu.planar_snake(
          loop_info.index[1],
          (grid_m, grid_n),
          1,
          grid_block_n,
      )
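      # planar_snake turns the flat tile index into (m, n) tile coordinates,
      # traversing tiles in bands of `grid_block_n` along n; this ordering is
      # meant to keep consecutive tiles close together for better L2 reuse.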

      # This slice is potentially out of bounds, but we never access the
      # out-of-bounds part in emit_pipeline.
      gmem_slice = pl.ds(group_block_starts_gmem[g_i], k)

      def acc_scope(acc_ref):
        def block_matmul(block_idx, lhs_smem, rhs_smem):
          block_idx = block_idx[0]

          @pl.when(block_idx == 0)
          def _():
            # Handles the first block of the group, where there might be
            # data from the previous group at the beginning of the block.
            lhs_reg = lhs_smem[...]
            start_index = lax.rem(group_starts_gmem[g_i], block_k)
            indices = plgpu.layout_cast(
                jax.lax.broadcasted_iota(jnp.int32, (block_k, block_m), 0),
                plgpu.Layout.WGMMA
            )
            lhs_mask = (indices >= start_index).astype(lhs_smem.dtype)
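            # E.g. (illustrative) with block_k=64 and a group starting at row
            # 70: start_index = 6, so the first 6 rows of this K-block (which
            # still belong to the previous group) are zeroed before the MMA.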

            lhs_reg = lhs_reg * lhs_mask
            lhs_smem[...] = lhs_reg
            plgpu.commit_smem()

          @pl.when(block_idx == group_num_blocks_gmem[g_i] - 1)
          def _():
            # Handles the last block of the group, where there might be
            # data from the next group at the end of the block.
            lhs_reg = lhs_smem[...]
            last_index = lax.rem(group_ends_gmem[g_i] - 1, block_k)
            indices = plgpu.layout_cast(
                jax.lax.broadcasted_iota(jnp.int32, (block_k, block_m), 0),
                plgpu.Layout.WGMMA
            )
            lhs_mask = (indices <= last_index).astype(lhs_smem.dtype)

            lhs_reg = lhs_reg * lhs_mask
            lhs_smem[...] = lhs_reg
            plgpu.commit_smem()

          plgpu.wgmma(acc_ref, plgpu.transpose_ref(lhs_smem, (1, 0)), rhs_smem)
          if max_concurrent_steps == 1:
            # Without delayed release, we won't have at least two separate
            # smem blocks in flight. Therefore, we cannot rely on the implicit
            # wait of wgmma to guarantee that the data in smem is ready to be
            # overwritten by the next pipeline iteration.
            plgpu.wgmma_wait(0)

        @pl.when(group_sizes_gmem[g_i] > 0)  # Skip the group if it is empty.
        def _():
          plgpu.emit_pipeline(
              block_matmul,
              grid=(group_num_blocks_gmem[g_i],),
              in_specs=[
                  plgpu.BlockSpec(
                      (block_k, block_m),
                      lambda k_i: (k_i, m_i),
                      delay_release=1 if max_concurrent_steps > 1 else 0,
                      transforms=transforms,
                  ),
                  plgpu.BlockSpec(
                      (block_k, block_n),
                      lambda k_i: (k_i, n_i),
                      delay_release=1 if max_concurrent_steps > 1 else 0,
                      transforms=transforms,
                  ),
              ],
              max_concurrent_steps=max_concurrent_steps,
          )(lhs_gmem.at[gmem_slice, :], rhs_gmem.at[gmem_slice, :])

        return acc_ref[...]

      acc = pl.run_scoped(acc_scope, plgpu.ACC((block_m, block_n)))
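      # `acc` now holds this (block_m, block_n) output tile in registers; it
      # is staged through SMEM below and copied asynchronously to GMEM.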

      @functools.partial(
          pl.run_scoped,
          o_smem=plgpu.SMEM(
              (block_m, block_n),
              dtype=o_gmem.dtype,
              transforms=transforms,
          )
      )
      def store_scope(o_smem):
        o_smem[...] = acc.astype(o_smem.dtype)
        plgpu.commit_smem()
        plgpu.copy_smem_to_gmem(
            o_smem, o_gmem.at[
                g_i,
                pl.ds(m_i * block_m, block_m),
                pl.ds(n_i * block_n, block_n)
            ]
        )
        plgpu.wait_smem_to_gmem(0, wait_read_only=True)

  # There are 132 SMs on an H100 SXM GPU.
  num_sms = jax.devices()[0].core_count
  kernel = plgpu.kernel(
      body,
      out_shape=jax.ShapeDtypeStruct((g, m, n), lhs.dtype),
      grid=(num_sms,),
      grid_names=("sm",),
  )
  return kernel(
      group_sizes,
      group_starts,
      group_ends,
      group_num_blocks,
      group_block_starts,
      lhs,
      rhs,
  )


def ref_transposed_ragged_dot(lhs, rhs, group_sizes):
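  """XLA reference: ragged_dot_general contracting over the ragged K axis.

  For each group g this computes lhs[start_g:end_g].T @ rhs[start_g:end_g],
  matching the (G, M, N) output of the Pallas kernel above.
  """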
  return jax.lax.ragged_dot_general(
      lhs, rhs, group_sizes,
      ragged_dot_dimension_numbers=jax.lax.RaggedDotDimensionNumbers(
          dot_dimension_numbers=(((0,), (0,)), ((), ())),
          lhs_ragged_dimensions=[0],
          rhs_group_dimensions=[],
      )
  )


def main(unused_argv):
  k, m, n, num_groups = 16 * 1024, 2048, 2048, 16
  kx, ky, kz = random.split(random.key(1234), num=3)

  lhs = jax.random.normal(kx, (k, m), jnp.float16)
  rhs = jax.random.normal(ky, (k, n), jnp.float16)
  group_boundaries = jax.lax.sort(
      jax.random.randint(kz, (num_groups - 1,), 0, k, jnp.int32)
  )
  group_starts = lax.concatenate(
      [jnp.array([0], dtype=jnp.int32), group_boundaries], 0
  )
  group_ends = lax.concatenate(
      [group_boundaries, jnp.array([k], dtype=jnp.int32)], 0
  )
  group_sizes = group_ends - group_starts
  assert group_sizes.shape == (num_groups,)

  block_m = block_n = [64, 128]
  block_k = [64, 128]
  max_concurrent_steps = [1, 2, 4, 5, 6]
  grid_block_n = [1, 2, 4, 8, 16]

  configs = itertools.product(
      block_m, block_n, block_k, max_concurrent_steps, grid_block_n
  )
  names = (
      "block_m", "block_n", "block_k", "max_concurrent_steps", "grid_block_n",
  )
  best_runtime = float("inf")
  best_kwargs = {}
  for config in configs:
    kwargs = dict(zip(names, config))
    if n % kwargs["block_n"]:
      continue
    try:
      f = functools.partial(
          transposed_ragged_dot, group_sizes=group_sizes,
          **kwargs
      )
      _, runtime = profiler.measure(f)(lhs, rhs)
    except ValueError as e:
      if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
        raise
      runtime = float("inf")
    # Enable this to get more detailed information.
    else:
      assert runtime is not None
      print(
          " ".join(f"{k}={v}" for k, v in kwargs.items()),
          f"{runtime * 1000:.1f} us",
      )
    assert runtime is not None
    assert best_runtime is not None
    if runtime < best_runtime:
      best_runtime = runtime
      best_kwargs = kwargs
  if not best_kwargs:
    raise ValueError("No valid configuration found")

  ref, ref_runtime = profiler.measure(ref_transposed_ragged_dot)(
      lhs, rhs, group_sizes=group_sizes
  )
  result = transposed_ragged_dot(
      lhs, rhs, group_sizes=group_sizes, **best_kwargs
  )

  assert ref_runtime is not None
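  # FLOP count: every one of the K rows of lhs/rhs participates in exactly one
  # group, so the total work is 2*K*M*N FLOPs regardless of how K is split.
  # The measured runtime is in milliseconds, hence the / 1e3 to get seconds.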
  tflops = float(2 * k * m * n) / (best_runtime / 1e3) / 1e12
  ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12
  print(
      "Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())
  )
  print(f"Kernel: {best_runtime * 1000:.1f} us = {tflops:.1f} TFLOPS")
  print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS")
  np.testing.assert_allclose(result, ref, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
  from absl import app

  jax.config.config_with_absl()
  app.run(main)