2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,155 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple all-gather kernel.
This is meant to be a pedagogical example of how to write a custom collective
using Pallas. It doesn't have all possible performance optimizations and doesn't
currently handle more diverse topologies.
The kernel assumes a ring structure on a single mesh axis. It takes the local
chunk, splits it in two, and sends one half-chunk in each direction
(left and right) until every device has received both halves.
"""
from __future__ import annotations
import functools
from collections.abc import Sequence
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax._src import shard_map
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
P = jax.sharding.PartitionSpec
def get_neighbor(
idx: jax.Array, mesh: jax.sharding.Mesh, axis_name: str, *, direction: str
) -> tuple[jax.Array, ...]:
"""Helper function that computes the mesh indices of a neighbor."""
axis_names = mesh.axis_names
which_axis = axis_names.index(axis_name)
mesh_index = [
idx if i == which_axis else lax.axis_index(a)
for i, a in enumerate(axis_names)
]
axis_size = lax.axis_size(axis_name)
if direction == "right":
next_idx = lax.rem(idx + 1, axis_size)
else:
left = idx - 1
next_idx = jnp.where(left < 0, left + axis_size, left)
mesh_index[which_axis] = next_idx
return tuple(mesh_index)
def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str,
mesh: jax.sharding.Mesh):
my_id = lax.axis_index(axis_name)
  # TODO(sharadmv): could speed this up by having the first remote DMA go from
  # x_ref->o_ref immediately instead of a blocking HBM copy.
with jax.named_scope("initial_copy"):
pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait()
with jax.named_scope("neighbour_lookup"):
axis_size = lax.axis_size(axis_name)
left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left")
right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right")
with jax.named_scope("main_barrier"):
sem = pltpu.get_barrier_semaphore()
pl.semaphore_signal(sem, 1, device_id=left_neighbor)
pl.semaphore_signal(sem, 1, device_id=right_neighbor)
pl.semaphore_wait(sem, 2)
shard_size = x_ref.shape[0]
right_dma, left_dma = None, None
# Main strategy for this AG: carve up our input into two slices. Send
# each slice along each direction until they reach every device.
for i in range(axis_size - 1):
right_slot = my_id - i
right_slice = pl.ds(shard_size // 2, shard_size // 2)
slot = jnp.where(right_slot < 0, axis_size + right_slot, right_slot)
if right_dma:
with jax.named_scope("wait_right_dma"):
right_dma.wait()
right_dma = pltpu.async_remote_copy(
o_ref.at[slot, right_slice],
o_ref.at[slot, right_slice],
send_sem[1],
recv_sem[1],
device_id=right_neighbor,
)
left_slot = my_id + i
left_slice = pl.ds(0, shard_size // 2)
slot = lax.rem(left_slot, axis_size)
if left_dma:
with jax.named_scope("wait_left_dma"):
left_dma.wait()
left_dma = pltpu.async_remote_copy(
o_ref.at[slot, left_slice],
o_ref.at[slot, left_slice],
send_sem[0],
recv_sem[0],
device_id=left_neighbor,
)
with jax.named_scope("wait_all_dma"):
assert right_dma is not None
assert left_dma is not None
right_dma.wait()
left_dma.wait()
@functools.partial(
jax.jit, static_argnames=["mesh", "axis_name", "memory_space"]
)
def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str],
memory_space: pltpu.MemorySpace = pltpu.VMEM):
if isinstance(axis_name, str):
axis_name = (axis_name,)
# TODO(sharadmv): enable all gather over multiple axes
if len(axis_name) > 1:
raise NotImplementedError("Only one axis supported.")
axis_name, = axis_name
if mesh.shape[axis_name] == 1:
# We can short-circuit here if our axis size is 1
return x
def ag_local(x_shard):
axis_size = lax.axis_size(axis_name)
out_shape = jax.ShapeDtypeStruct((axis_size, *x_shard.shape), x_shard.dtype)
out = pl.pallas_call(
functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh),
out_shape=out_shape,
compiler_params=pltpu.CompilerParams(collective_id=0),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
scratch_shapes=(
(pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA),
(pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA),
),
in_specs=[pl.BlockSpec(memory_space=memory_space)],
out_specs=pl.BlockSpec(memory_space=memory_space),
),
)(x_shard)
return out.reshape((axis_size * x_shard.shape[0], *x_shard.shape[1:]))
return shard_map.shard_map(
ag_local, mesh=mesh, in_specs=P(axis_name), out_specs=P(None),
check_vma=False
)(x)
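# Example usage (illustrative sketch, not part of the library): build a 1-D
# mesh over the local devices and gather a sharded array along it. The device
# count and array shapes below are assumptions made only for this sketch.
def _all_gather_example():
  import numpy as np  # Only needed to build the device mesh for this sketch.
  devices = np.array(jax.devices())
  mesh = jax.sharding.Mesh(devices, ("x",))
  n = devices.size
  # Each device holds a (128, 128) shard of the (n * 128, 128) global array.
  x = jnp.arange(n * 128 * 128, dtype=jnp.float32).reshape(n * 128, 128)
  x = jax.device_put(x, jax.sharding.NamedSharding(mesh, P("x")))
  # The result is the full array, replicated across the mesh.
  return all_gather(x, mesh=mesh, axis_name="x")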
@@ -0,0 +1,25 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Kernels used for testing pallas_call."""
from jax.experimental import pallas as pl
def double_kernel(x_ref, y_ref):
y_ref[:] = x_ref[:] * 2
def double(x):
return pl.pallas_call(double_kernel, out_shape=x)(x)
@@ -0,0 +1,84 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example matmul TPU kernel.
See discussion in https://docs.jax.dev/en/latest/pallas/tpu/matmul.html.
"""
import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
def matmul_kernel(x_tile_ref, y_tile_ref, o_tile_ref, acc_ref):
@pl.when(pl.program_id(2) == 0)
def init():
acc_ref[...] = jnp.zeros_like(acc_ref)
acc_ref[...] = acc_ref[...] + jnp.dot(
x_tile_ref[...],
y_tile_ref[...],
preferred_element_type=acc_ref.dtype,
)
  # It is possible to make this conditional, but in general this bundle packs
  # quite well for a simple matmul kernel.
o_tile_ref[...] = acc_ref[...].astype(o_tile_ref.dtype)
@functools.partial(
jax.jit, static_argnames=["block_shape", "block_k", "debug", "out_dtype"]
)
def matmul(
x: jax.Array,
y: jax.Array,
*,
block_shape,
block_k: int = 256,
out_dtype: jnp.dtype | None = None,
debug: bool = False,
) -> jax.Array:
if out_dtype is None:
if x.dtype != y.dtype:
# TODO(tlongeri): Maybe we could use a deduction similar to jnp.dot
raise TypeError(
f"Cannot deduce output dtype for different input dtypes: {x.dtype},"
f" {y.dtype}"
)
out_dtype = x.dtype
acc_dtype = jnp.float32
if x.dtype in [jnp.int8, jnp.int4, jnp.uint8, jnp.uint4]:
acc_dtype = jnp.int32
l, r = block_shape
return pl.pallas_call(
matmul_kernel,
out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), out_dtype),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[
pl.BlockSpec((l, block_k), lambda i, _, k: (i, k)),
pl.BlockSpec((block_k, r), lambda _, j, k: (k, j)),
],
out_specs=pl.BlockSpec((l, r), lambda i, j, k: (i, j)),
grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k),
scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)],
),
compiler_params=pltpu.CompilerParams(
dimension_semantics=("parallel", "parallel", "arbitrary")),
debug=debug,
)(x, y)
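# Example usage (illustrative sketch): multiply two 512x512 float32 matrices
# with (128, 128) output blocks and a 256-wide contraction tile. The shapes and
# tile sizes are assumptions chosen to satisfy the divisibility requirements of
# the grid above.
def _matmul_example():
  x = jnp.ones((512, 512), dtype=jnp.float32)
  y = jnp.ones((512, 512), dtype=jnp.float32)
  # block_shape tiles the output; block_k tiles the contraction dimension.
  return matmul(x, y, block_shape=(128, 128), block_k=256)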
@@ -0,0 +1,15 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental.pallas.ops.tpu.megablox.ops import gmm as gmm
@@ -0,0 +1,65 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for GMM kernels."""
import re
import jax
import jax.numpy as jnp
def is_tpu() -> bool:
return "TPU" in jax.devices()[0].device_kind
def tpu_kind() -> str:
"""Query identification string for the currently attached TPU."""
return jax.devices()[0].device_kind
# Most TPU devices follow the pattern "TPU v{version}{variant}", e.g. "TPU v5p"
# TPU v7 has a different pattern (i.e. "TPU7x")
_TPU_KIND_PATTERN = re.compile(r"TPU( v)?(\d+)")
def tpu_generation() -> int:
"""Generation number of the currently attached TPU."""
if version := _TPU_KIND_PATTERN.match(tpu_kind()):
return int(version[2])
raise NotImplementedError("only TPU devices are supported")
def supports_bfloat16_matmul() -> bool:
"""Does the currently attached CPU support bfloat16 inputs?"""
return not is_tpu() or tpu_generation() >= 4
def assert_is_supported_dtype(dtype: jnp.dtype) -> None:
if dtype != jnp.bfloat16 and dtype != jnp.float32:
raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.")
def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype:
"""A type to which both input should be adapted to before dot product."""
# bf16xbf16 matmul is only supported since TPUv4 generation. In case of mixed
# input precision, we need to convert bf16 argument to fp32 beforehand.
if (
supports_bfloat16_matmul()
and lhs.dtype == jnp.bfloat16
and rhs.dtype == jnp.bfloat16
):
return jnp.bfloat16
else:
return jnp.float32
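# Illustrative sketch of the promotion rule above: mixed-precision inputs are
# promoted to float32, while matching bfloat16 inputs stay in bfloat16 on
# TPU v4 and newer.
def _select_input_dtype_example():
  a = jnp.zeros((8, 8), dtype=jnp.bfloat16)
  b = jnp.zeros((8, 8), dtype=jnp.float32)
  # Returns (bfloat16 on TPU v4+, float32).
  return select_input_dtype(a, a), select_input_dtype(a, b)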
@@ -0,0 +1,793 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grouped matrix multiplication kernels for TPU written in Pallas."""
from collections.abc import Callable
import functools
from typing import Any, Optional
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.megablox import common
import jax.numpy as jnp
partial = functools.partial
def _validate_args(
*,
lhs: jnp.ndarray,
rhs: jnp.ndarray,
group_sizes: jnp.ndarray,
expected_rhs_dims: int = 3,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]:
"""Validates the arguments for the gmm function."""
# Validate 'lhs'.
if lhs.ndim != 2:
raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs.ndim}-tensor.")
common.assert_is_supported_dtype(lhs.dtype)
# Validate 'rhs'.
if rhs.ndim != expected_rhs_dims:
raise ValueError(
f"Expected {expected_rhs_dims}-tensor for 'rhs' but got"
f" {rhs.ndim}-tensor."
)
common.assert_is_supported_dtype(rhs.dtype)
# Validate 'group_sizes'.
if group_sizes.dtype != jnp.int32:
raise ValueError(
f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype}."
)
return lhs, group_sizes, common.select_input_dtype(lhs, rhs)
def _calculate_num_tiles(x: int, tx: int) -> int:
tiles, rem = divmod(x, tx)
if rem:
raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).")
return tiles
def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]:
tiles, rem = divmod(x, tx)
if rem:
tiles += 1
return tiles, rem
GroupMetadata = Any # TODO(enriqueps): Clean this up and use a namedtuple
def make_group_metadata(
*,
group_sizes: jnp.ndarray,
m: int,
tm: int,
start_group: jnp.ndarray,
num_nonzero_groups: int,
visit_empty_groups: bool = True,
) -> GroupMetadata:
"""Create the metadata needed for grouped matmul computation.
Args:
group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
m: The number of rows in lhs.
tm: The m-dimension tile size being used.
start_group: The group in group sizes to start computing from. This is
particularly useful for when rhs num_groups is sharded.
num_nonzero_groups: Number of groups in group sizes to compute on. Useful in
combination with group_offset.
visit_empty_groups: If True, do not squeeze tiles for empty groups out of
the metadata. This is necessary for tgmm, where we at least need to zero
the output for each group.
Returns:
tuple of:
group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32
dtype. group_offsets[i] indicates the row at which group [i] starts in
        the lhs matrix and group_offsets[num_groups] = m.
group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will
work on.
m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i'
will work on.
num_tiles: The number of m-dimension tiles to execute.
"""
num_groups = group_sizes.shape[0]
end_group = start_group + num_nonzero_groups - 1
# Calculate the offset of each group, starting at zero. This metadata is
# similar to row offsets in a CSR matrix. The following properties hold:
#
# group_offsets.shape = [num_groups + 1]
# group_offsets[0] = 0
# group_offsets[num_groups] = m
#
# The row at which group 'i' starts is group_offsets[i].
group_ends = jnp.cumsum(group_sizes)
group_offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends])
# Assign a group id to each grid index.
#
# If a group starts somewhere other than the start of a tile or ends somewhere
# other than the end of a tile we need to compute that full tile. Calculate
# the number of tiles for each group by rounding their end up to the nearest
# 'tm' and their start down to the nearest 'tm'.
# (1) Round the group_ends up to the nearest multiple of 'tm'.
#
# NOTE: This does not change group_offsets[num_groups], which is m
# (because we enforce m is divisible by tm).
rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32)
# (2) Round the group_starts down to the nearest multiple of 'tm'.
group_starts = jnp.concatenate(
[jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]]
)
rounded_group_starts = group_starts // tm * tm
# (3) Calculate the number of rows in each group.
#
# NOTE: Handle zero-sized groups as a special case. If the start for a
# zero-sized group is not divisible by 'tm' its start will be rounded down and
# its end will be rounded up such that its size will become 1 tile here.
rounded_group_sizes = rounded_group_ends - rounded_group_starts
rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes)
# (4) Convert the group sizes from units of rows to unit of 'tm' sized tiles.
#
# An m-dimension tile is 'owned' by group 'i' if the first row of the tile
# belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1
  # initial partial tiles if its first row does not occur in the first row of a
# tile. The '0-th' group never has a partial tile because it always starts at
# the 0-th row.
#
# If no group has a partial tile, the total number of tiles is equal to
  # 'm // tm'. If every group except the 0-th has a partial tile, the total
# number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know that
#
# tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1
#
# Where tiles_m = m // tm.
#
# NOTE: All group sizes are divisible by 'tm' because of the rounding in steps
# (1) and (2) so this division is exact.
group_tiles = rounded_group_sizes // tm
if visit_empty_groups:
# Insert one tile for empty groups.
group_tiles = jnp.where(group_sizes == 0, 1, group_tiles)
# Create the group ids for each grid index based on the tile counts for each
# group.
#
# NOTE: This repeat(...) will pad group_ids with the final group id if
# group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
# such that we only execute the necessary number of tiles.
tiles_m = _calculate_num_tiles(m, tm)
group_ids = jnp.repeat(
jnp.arange(num_groups, dtype=jnp.int32),
group_tiles,
total_repeat_length=tiles_m + num_groups - 1,
)
# Assign an m-dimension tile id to each grid index.
#
# NOTE: Output tiles can only be re-visited consecutively. The following
# procedure guarantees that m-dimension tile indices respect this.
# (1) Calculate how many times each m-dimension tile will be visited.
#
# Each tile is guaranteed to be visited once by the group that owns the tile.
# The remaining possible visits occur when a group starts inside of a tile at
# a position other than the first row. We can calculate which m-dimension tile
# each group starts in by floor-dividing its offset with `tm` and then count
# tile visits with a histogram.
#
# To avoid double counting tile visits from the group that owns the tile,
  # filter these out by assigning their tile id to `tiles_m` (one beyond the max)
# such that they're ignored by the subsequent histogram. Also filter out any
# group which is empty.
#
# TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear.
partial_tile_mask = jnp.logical_or(
(group_offsets[:-1] % tm) == 0, group_sizes == 0
)
# Explicitly enable tiles for zero sized groups, if specified. This covers
# zero sized groups that start on a tile-aligned row and those that do not.
if visit_empty_groups:
partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask)
partial_tile_ids = jnp.where(
partial_tile_mask, tiles_m, group_offsets[:-1] // tm
)
tile_visits = (
jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0]
+ 1
)
# Create the m-dimension tile ids for each grid index based on the visit
# counts for each tile.
m_tile_ids = jnp.repeat(
jnp.arange(tiles_m, dtype=jnp.int32),
tile_visits.astype(jnp.int32),
total_repeat_length=tiles_m + num_groups - 1,
)
# Account for sharding.
#
# Find the start of the groups owned by our shard and shift the group_ids and
# m_tile_ids s.t. the metadata for our tiles are at the front of the arrays.
#
# TODO(tgale): Move this offset into the kernel to avoid these rolls.
first_tile_in_shard = (group_ids < start_group).sum()
group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0)
m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0)
# Calculate the number of tiles we need to compute for our shard.
#
# Remove tile visits that belong to a group not in our shard.
iota = jnp.arange(num_groups, dtype=jnp.int32)
active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group)
group_tiles = jnp.where(active_group_mask, group_tiles, 0)
num_tiles = group_tiles.sum()
return (group_offsets, group_ids, m_tile_ids), num_tiles
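# Worked example for make_group_metadata (illustrative; numbers chosen only for
# clarity): with group_sizes=[3, 5], m=8, tm=4, start_group=0,
# num_nonzero_groups=2 and visit_empty_groups=False the result is
#
#   group_offsets = [0, 3, 8]  # CSR-style row offsets; group_offsets[-1] == m.
#   group_ids     = [0, 1, 1]  # Grid step 0 -> group 0; steps 1, 2 -> group 1.
#   m_tile_ids    = [0, 0, 1]  # Tile 0 is visited twice because group 1 starts
#                              # at row 3, inside the tile owned by group 0.
#   num_tiles     = 3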
def _get_group_size(
*, grid_id: jnp.ndarray, group_metadata: GroupMetadata
) -> jnp.ndarray:
"""Calculate the number of rows in the current group."""
group_offsets, group_ids = group_metadata[:2]
group_id = group_ids[grid_id]
group_start = group_offsets[group_id]
group_end = group_offsets[group_id + 1]
return group_end - group_start
def _get_store_mask(
*,
grid_id: jnp.ndarray,
group_metadata: GroupMetadata,
tm: int,
tn: int,
) -> jnp.ndarray:
"""Mask for rows that belong to the current group in the current tile."""
group_offsets, group_ids, m_tile_ids = group_metadata[:3]
group_id = group_ids[grid_id]
group_start = group_offsets[group_id]
group_end = group_offsets[group_id + 1]
m_id = m_tile_ids[grid_id] * tm
iota = jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id
return jnp.logical_and(iota >= group_start, iota < group_end)
def _zero_uninitialized_memory(
out: jnp.ndarray,
*,
start_group: jnp.ndarray,
num_nonzero_groups: int,
group_metadata: GroupMetadata,
) -> jnp.ndarray:
"""Zero out uninitialized memory from output."""
group_offsets = group_metadata[0]
group_start = group_offsets[start_group]
group_end = group_offsets[start_group + num_nonzero_groups]
valid_mask = jax.lax.broadcasted_iota(jnp.int32, (out.shape[0],), 0)
valid_mask = (valid_mask >= group_start) & (valid_mask < group_end)
return jnp.where(valid_mask[:, None], out, 0)
LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]
@functools.partial(
jax.jit,
static_argnames=[
"preferred_element_type",
"tiling",
"transpose_rhs",
"interpret",
],
)
def gmm(
lhs: jnp.ndarray,
rhs: jnp.ndarray,
group_sizes: jnp.ndarray,
preferred_element_type: jnp.dtype = jnp.float32,
tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
group_offset: jnp.ndarray | None = None,
existing_out: jnp.ndarray | None = None,
transpose_rhs: bool = False,
interpret: bool = False,
) -> jnp.ndarray:
"""Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
Args:
lhs: A 2d, jnp.ndarray with shape [m, k].
rhs: A 3d, jnp.ndarray with shape [num_groups, k, n].
group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
preferred_element_type: jnp.dtype, the element type for the output matrix.
tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
group_offset: The group in group sizes to start computing from. This is
particularly useful for when rhs num_groups is sharded.
existing_out: Existing output to write to.
transpose_rhs: True if the rhs needs to be transposed.
interpret: Whether or not to run the kernel in interpret mode, helpful for
testing and debugging.
Returns:
A 2d, jnp.ndarray with shape [m, n].
"""
if existing_out is not None:
assert isinstance(existing_out, jax.Array)
expected_dtype = existing_out.dtype
if expected_dtype != preferred_element_type:
raise ValueError(
"Existing output dtype must match preferred_element_type."
)
if group_offset is None:
group_offset = jnp.array([0], dtype=jnp.int32)
else:
if group_offset.shape:
raise ValueError(
f"group_offset must be a ()-shaped array. Got: {group_offset.shape}."
)
group_offset = group_offset[None]
num_current_groups = rhs.shape[0]
num_total_groups = group_sizes.shape[0]
lhs, group_sizes, input_dtype = _validate_args(
lhs=lhs, rhs=rhs, group_sizes=group_sizes
)
# Gather shape information.
m, k, n = (lhs.shape[0], lhs.shape[1], rhs.shape[2])
if transpose_rhs:
n = rhs.shape[1]
# If tiling is callable, look up the problem dimensions in the LUT. If no tuned
  # tile dimensions are available, raise an error.
if callable(tiling):
tiling = tiling(m, k, n)
if tiling is None:
raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})")
tm, tk, tn = tiling
tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk)
tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn)
del n_rem
# Create the metadata we need for computation.
group_metadata, num_active_tiles = make_group_metadata(
group_sizes=group_sizes,
m=m,
tm=tm,
start_group=group_offset[0],
num_nonzero_groups=rhs.shape[0],
visit_empty_groups=False,
)
def kernel(
group_metadata,
group_offset,
lhs,
rhs,
existing_out,
out,
acc_scratch,
):
group_offsets, group_ids, m_tile_ids = group_metadata
del group_offsets, group_ids, group_offset
grid_id = pl.program_id(1)
k_i = pl.program_id(2)
@pl.when(k_i == 0)
def _zero_acc():
acc_scratch[...] = jnp.zeros_like(acc_scratch)
if existing_out is not None:
prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0)
is_first_processed_group = grid_id == 0
m_tile_changed = m_tile_ids[grid_id] != m_tile_ids[prev_grid_id]
first_time_seeing_out = jnp.logical_or(
is_first_processed_group, m_tile_changed
)
@pl.when(first_time_seeing_out)
def _init_out():
out[...] = existing_out[...]
def mask_k_rem(x, *, dim):
if k_rem == 0:
return x
orig_dtype = x.dtype
iota = lax.broadcasted_iota(jnp.int32, x.shape, dim)
x = x.astype(jnp.float32)
return jnp.where(iota < k_rem, x, 0).astype(orig_dtype)
def _store_accum():
mask = _get_store_mask(
grid_id=grid_id,
group_metadata=group_metadata,
tm=tm,
tn=tn,
)
to_store = acc_scratch[...]
out[...] = jax.lax.select(
mask[...], to_store, out[...].astype(jnp.float32)
).astype(preferred_element_type)
def _accum(is_last_k_tile):
if is_last_k_tile:
mask_k_rem_lhs = partial(mask_k_rem, dim=1)
mask_k_rem_rhs = partial(mask_k_rem, dim=int(transpose_rhs))
else:
mask_k_rem_lhs = lambda x: x
mask_k_rem_rhs = lambda x: x
if transpose_rhs:
dot_general_dims = (((1,), (1,)), ((), ()))
else:
dot_general_dims = (((1,), (0,)), ((), ()))
loaded_lhs = lhs[...]
loaded_rhs = rhs[...]
acc_scratch[...] += lax.dot_general(
mask_k_rem_lhs(loaded_lhs).astype(input_dtype),
mask_k_rem_rhs(loaded_rhs).astype(input_dtype),
preferred_element_type=jnp.float32,
dimension_numbers=dot_general_dims,
)
if is_last_k_tile:
_store_accum()
lax.cond(
k_i == tiles_k - 1,
partial(_accum, True),
partial(_accum, False),
)
def lhs_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
# lhs is (m, k). Load the [tm, tk] matrix for this m-tile.
group_offsets, group_ids, m_tile_ids = group_metadata
del n_i, group_offsets, group_ids, group_offset
return m_tile_ids[grid_id], k_i
def rhs_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
# rhs is (num_groups, k, n). Load the [tk, tn] matrix based on the group id
# for this m-tile.
group_offsets, group_ids, m_tile_ids = group_metadata
del group_offsets, m_tile_ids
if transpose_rhs:
k_i, n_i = n_i, k_i
# NOTE: If we're working on only a shard of the rhs we need to adjust the
# group index we load from to account for this. The group_ids are in the
# "unsharded" domain.
return group_ids[grid_id] - group_offset[0], k_i, n_i
def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
# out is (m, n). Load the [tm, tn] matrix for this m-tile.
group_offsets, group_ids, m_tile_ids = group_metadata
del k_i, group_offsets, group_ids, group_offset
return m_tile_ids[grid_id], n_i
out_block_spec = pl.BlockSpec((tm, tn), out_transform_indices)
if existing_out is None:
in_out_block_spec: Any = None
input_output_aliases = {}
else:
in_out_block_spec = out_block_spec
input_output_aliases = {6: 0}
lhs_block_spec = pl.BlockSpec((tm, tk), lhs_transform_indices)
if transpose_rhs:
rhs_block_spec = pl.BlockSpec((None, tn, tk), rhs_transform_indices)
else:
rhs_block_spec = pl.BlockSpec((None, tk, tn), rhs_transform_indices)
lhs_bytes = lhs.size * lhs.itemsize
rhs_bytes = (k * n) * rhs.itemsize # We don't read all of rhs
out_bytes = (m * n) * jnp.dtype(preferred_element_type).itemsize
max_active_tiles = group_metadata[1].size
bytes_accessed = (
(lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) + out_bytes
)
flops = 2 * m * k * n
cost_estimate = pl.CostEstimate(
flops=flops, bytes_accessed=bytes_accessed, transcendentals=0
)
call_gmm = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
in_specs=[
lhs_block_spec,
rhs_block_spec,
in_out_block_spec,
],
out_specs=out_block_spec,
grid=(tiles_n, num_active_tiles, tiles_k),
scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
),
input_output_aliases=input_output_aliases,
compiler_params=pltpu.CompilerParams(
dimension_semantics=("parallel", "arbitrary", "arbitrary")),
interpret=interpret,
cost_estimate=cost_estimate,
)
out = call_gmm(
group_metadata,
group_offset,
lhs,
rhs,
existing_out,
)
if existing_out is None and num_current_groups < num_total_groups:
out = _zero_uninitialized_memory(
out,
start_group=group_offset[0],
num_nonzero_groups=rhs.shape[0],
group_metadata=group_metadata,
)
return out
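# Example usage (illustrative sketch): three groups of 128, 256 and 128 rows.
# The shapes are assumptions that satisfy the default (128, 128, 128) tiling;
# interpret=True lets the sketch run off-TPU for testing.
def _gmm_example():
  m, k, n, num_groups = 512, 256, 256, 3
  lhs = jnp.ones((m, k), dtype=jnp.float32)
  rhs = jnp.ones((num_groups, k, n), dtype=jnp.float32)
  group_sizes = jnp.array([128, 256, 128], dtype=jnp.int32)
  # Rows [0, 128) use rhs[0], rows [128, 384) use rhs[1], the rest use rhs[2].
  return gmm(lhs, rhs, group_sizes, interpret=True)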
@functools.partial(
jax.jit,
static_argnames=[
"preferred_element_type",
"tiling",
"num_actual_groups",
"interpret",
],
)
def tgmm(
lhs: jnp.ndarray,
rhs: jnp.ndarray,
group_sizes: jnp.ndarray,
preferred_element_type: jnp.dtype = jnp.float32,
tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
group_offset: jnp.ndarray | None = None,
num_actual_groups: int | None = None,
existing_out: jnp.ndarray | None = None,
interpret: bool = False,
) -> jnp.ndarray:
"""Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].
Args:
lhs: A 2d, jnp.ndarray with shape [k, m].
rhs: A 2d, jnp.ndarray with shape [m, n].
group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
preferred_element_type: jnp.dtype, the element type for the output matrix.
tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
group_offset: The group in group sizes to start computing from. This is
particularly useful for when rhs num_groups is sharded.
num_actual_groups: For when num_groups is sharded and we should only compute
the groups that are local, starting from group_offset.
existing_out: Existing output to write to.
interpret: Whether or not to run the kernel in interpret mode, helpful for
testing and debugging.
Returns:
A 3d, jnp.ndarray with shape [num_groups, k, n].
"""
if group_offset is None:
group_offset = jnp.array([0], dtype=jnp.int32)
else:
group_offset = group_offset[None]
lhs, group_sizes, input_dtype = _validate_args(
lhs=lhs, rhs=rhs, group_sizes=group_sizes, expected_rhs_dims=2
)
# Gather shape information.
k, m, n = (lhs.shape[0], lhs.shape[1], rhs.shape[1])
num_groups = group_sizes.shape[0]
num_actual_groups = (
num_actual_groups if num_actual_groups is not None else num_groups
)
# If tiling is callable, look up the problem dimensions in the LUT. If no tuned
  # tile dimensions are available, raise an error.
if callable(tiling):
tiling = tiling(m, k, n)
if tiling is None:
raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})")
tm, tk, tn = tiling
tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk)
del k_rem
tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn)
del n_rem
# Create the metadata we need for computation.
group_metadata, num_active_tiles = make_group_metadata(
group_sizes=group_sizes,
m=m,
tm=tm,
start_group=group_offset[0],
num_nonzero_groups=num_actual_groups,
visit_empty_groups=True,
)
def kernel(
group_metadata,
group_offset,
lhs,
rhs,
existing_out,
out,
acc_scratch,
):
grid_id = pl.program_id(2)
group_offsets, group_ids, m_tile_ids = group_metadata
del group_offsets, group_offset, m_tile_ids
group = group_ids[grid_id]
prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0)
prev_group = group_ids[prev_grid_id]
group_has_changed = jnp.logical_or(grid_id == 0, prev_group != group)
@pl.when(group_has_changed)
def _zero_acc():
acc_scratch[...] = jnp.zeros_like(acc_scratch)
# We'll only do computation if our group has a nonzero number of rows in it.
dont_skip = (
_get_group_size(grid_id=grid_id, group_metadata=group_metadata) > 0
)
@pl.when(dont_skip)
def _do():
rhs_mask = _get_store_mask(
grid_id=grid_id,
group_metadata=group_metadata,
tm=tm,
tn=tn,
)
lhs_mask = _get_store_mask(
grid_id=grid_id,
group_metadata=group_metadata,
tm=tm,
tn=tk,
)
loaded_lhs = lhs[...]
loaded_rhs = rhs[...]
loaded_lhs = lax.select(
lhs_mask[...],
loaded_lhs.astype(jnp.float32),
jnp.zeros_like(lhs, jnp.float32),
).swapaxes(0, 1)
loaded_rhs = lax.select(
rhs_mask[...],
loaded_rhs.astype(jnp.float32),
jnp.zeros_like(rhs, jnp.float32),
)
acc_scratch[...] += lax.dot(
loaded_lhs.astype(input_dtype),
loaded_rhs.astype(input_dtype),
preferred_element_type=jnp.float32,
)
is_end_of_grid = grid_id == (pl.num_programs(2) - 1)
next_grid_id = jnp.where(is_end_of_grid, grid_id, grid_id + 1)
next_group = group_ids[next_grid_id]
group_is_changing = jnp.logical_or(is_end_of_grid, group != next_group)
@pl.when(group_is_changing)
def _store_accum():
to_store = acc_scratch[...]
if existing_out is not None:
to_store += existing_out[...].astype(jnp.float32)
out[...] = to_store.astype(preferred_element_type)
def lhs_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
# lhs is (m, k). Load the [tm, tk] matrix for this m-tile.
group_offsets, group_ids, m_tile_ids = group_metadata
del n_i, group_offsets, group_ids, group_offset
return m_tile_ids[grid_id], k_i
def rhs_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
# rhs is (m, n). Load the [tm, tn] matrix for this m-tile.
group_offsets, group_ids, m_tile_ids = group_metadata
del k_i, group_offsets, group_ids, group_offset
return m_tile_ids[grid_id], n_i
def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
# out is (num_groups, k, n). Load the [tk, tn] matrix based on the group id
# for this m-tile.
group_offsets, group_ids, m_tile_ids = group_metadata
del group_offsets, m_tile_ids
# NOTE: If we're working on only a shard of the output we need to adjust the
# group index we load from to account for this. The group_ids are in the
# "unsharded" domain.
return group_ids[grid_id] - group_offset[0], k_i, n_i
out_block_spec = pl.BlockSpec((None, tk, tn), out_transform_indices)
if existing_out is None:
in_out_block_spec: Any = None
input_output_aliases = {}
else:
in_out_block_spec = out_block_spec
input_output_aliases = {6: 0}
lhs_block_spec = pl.BlockSpec((tm, tk), lhs_transform_indices)
rhs_block_spec = pl.BlockSpec((tm, tn), rhs_transform_indices)
lhs_bytes = lhs.size * lhs.itemsize
rhs_bytes = rhs.size * rhs.itemsize
out_bytewidth = jnp.dtype(preferred_element_type).itemsize
out_bytes = (num_actual_groups * k * n) * out_bytewidth
bytes_accessed = (
(lhs_bytes * tiles_n) + (rhs_bytes * tiles_k) + out_bytes
)
flops = 2 * m * k * n
cost_estimate = pl.CostEstimate(
flops=flops, bytes_accessed=bytes_accessed, transcendentals=0
)
lhs = lhs.swapaxes(0, 1)
call_gmm = pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct(
(num_actual_groups, k, n), preferred_element_type
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
in_specs=[
lhs_block_spec,
rhs_block_spec,
in_out_block_spec,
],
out_specs=out_block_spec,
grid=(tiles_n, tiles_k, num_active_tiles),
scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)],
),
input_output_aliases=input_output_aliases,
compiler_params=pltpu.CompilerParams(
dimension_semantics=("parallel", "arbitrary", "arbitrary")),
interpret=interpret,
cost_estimate=cost_estimate,
)
out = call_gmm(
group_metadata,
group_offset,
lhs,
rhs,
existing_out,
)
return out
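# Example usage (illustrative sketch): tgmm accumulates a per-group [k, n]
# product and is used by the custom VJP in the ops module for the rhs gradient.
# Shapes are assumptions that satisfy the default (128, 128, 128) tiling.
def _tgmm_example():
  k, m, n = 128, 256, 128
  lhs = jnp.ones((k, m), dtype=jnp.float32)
  rhs = jnp.ones((m, n), dtype=jnp.float32)
  group_sizes = jnp.array([128, 128], dtype=jnp.int32)
  return tgmm(lhs, rhs, group_sizes, interpret=True)  # Shape (2, 128, 128).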
@@ -0,0 +1,109 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grouped matrix multiplication operations with custom VJPs."""
import jax
from jax.experimental.pallas.ops.tpu.megablox import gmm as backend
import jax.numpy as jnp
gmm = jax.custom_vjp(
backend.gmm,
nondiff_argnums=(3, 4, 7, 8),
)
def _gmm_fwd(
lhs: jnp.ndarray,
rhs: jnp.ndarray,
group_sizes: jnp.ndarray,
preferred_element_type: jnp.dtype = jnp.float32,
tiling: tuple[int, int, int] = (128, 128, 128),
group_offset: jnp.ndarray | None = None,
existing_out: jnp.ndarray | None = None,
transpose_rhs: bool = False,
interpret: bool = False,
) -> tuple[
jnp.ndarray,
tuple[
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
jnp.ndarray | None,
int,
],
]:
"""Forward function for GMM VJP."""
out = backend.gmm(
lhs,
rhs,
group_sizes,
preferred_element_type,
tiling,
group_offset,
existing_out,
transpose_rhs=transpose_rhs,
interpret=interpret,
)
return out, (lhs, rhs, group_sizes, group_offset, rhs.shape[0])
def _gmm_bwd(
preferred_element_type: jnp.dtype,
tiling: tuple[int, int, int],
transpose_rhs: bool,
interpret: bool,
residual: tuple[
jnp.ndarray,
jnp.ndarray,
jnp.ndarray,
jnp.ndarray | None,
int,
],
grad: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]:
"""Backward function for throughput GMM VJP."""
del preferred_element_type
lhs, rhs, group_sizes, group_offset, num_actual_groups = residual
grad_lhs = backend.gmm(
grad,
rhs,
group_sizes,
lhs[0].dtype,
tiling,
group_offset,
transpose_rhs=not transpose_rhs,
interpret=interpret,
)
grad_rhs = backend.tgmm(
lhs.swapaxes(0, 1),
grad,
group_sizes,
rhs.dtype,
tiling,
group_offset,
num_actual_groups,
interpret=interpret,
)
# NOTE: If the rhs transposition is fused into the forward pass we need to
# return the transpose of the rhs gradient that we calculated above.
#
# TODO(tgale, enriqueps, apaske): Fuse this transposition into the tgmm.
grad_rhs = grad_rhs.swapaxes(1, 2) if transpose_rhs else grad_rhs
return grad_lhs, grad_rhs, None, None, grad
gmm.defvjp(_gmm_fwd, _gmm_bwd)
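# Example usage (illustrative sketch): differentiate a grouped matmul with
# respect to both operands via the custom VJP defined above. Shapes and group
# sizes are assumptions.
def _gmm_grad_example():
  lhs = jnp.ones((256, 128), dtype=jnp.float32)
  rhs = jnp.ones((2, 128, 128), dtype=jnp.float32)
  group_sizes = jnp.array([128, 128], dtype=jnp.int32)
  loss = lambda l, r: jnp.sum(gmm(l, r, group_sizes))
  # grad_lhs has shape (256, 128); grad_rhs has shape (2, 128, 128).
  return jax.grad(loss, argnums=(0, 1))(lhs, rhs)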
@@ -0,0 +1,15 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as paged_attention
@@ -0,0 +1,670 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PagedAttention TPU kernel."""
from collections.abc import Sequence
import functools
from typing import Literal
import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils
import jax.numpy as jnp
import numpy as np
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
class MultiPageAsyncCopyDescriptor:
"""Descriptor for async copy of multiple K/V pages from HBM."""
def __init__(
self,
pages_hbm_ref,
scales_pages_hbm_ref,
vmem_buffer,
scales_vmem_buffer,
sem,
page_indices,
page_indices_start_offset,
num_pages_to_load,
head_index,
):
self._vmem_buffer = vmem_buffer
self._scales_vmem_buffer = scales_vmem_buffer
self._num_pages_to_load = num_pages_to_load
if head_index is not None:
self._pages_hbm_ref = pages_hbm_ref.at[head_index]
if scales_pages_hbm_ref is not None:
self._scales_pages_hbm_ref = scales_pages_hbm_ref.at[head_index]
else:
self._scales_pages_hbm_ref = None
else:
self._pages_hbm_ref = pages_hbm_ref
self._scales_pages_hbm_ref = scales_pages_hbm_ref
self._sem = sem
self._page_indices = page_indices
self._page_indices_start_offset = page_indices_start_offset
self._async_copies = [
self._make_async_copy(i) for i in range(self._num_pages_to_load)
]
if (
self._scales_pages_hbm_ref is not None
and self._scales_vmem_buffer is not None
):
self._async_copies += [
self._make_scales_async_copy(i)
for i in range(self._num_pages_to_load)
]
def _make_async_copy(self, i):
page_index = self._page_indices[self._page_indices_start_offset + i]
return pltpu.make_async_copy(
self._pages_hbm_ref.at[page_index], self._vmem_buffer.at[i], self._sem
)
def _make_scales_async_copy(self, i):
page_index = self._page_indices[self._page_indices_start_offset + i]
return pltpu.make_async_copy(
self._scales_pages_hbm_ref.at[page_index],
self._scales_vmem_buffer.at[i],
self._sem,
)
def start(self):
"""Starts the async copies."""
for async_copy in self._async_copies:
async_copy.start()
def _maybe_dequantize(self, x, x_scale, dtype=jnp.bfloat16):
if x_scale is None:
return x.astype(dtype)
return quantization_utils.from_int8(x, x_scale, dtype=dtype)
def wait_and_get_loaded(self) -> jax.Array:
"""Wait async copies and gets the loaded buffer as a jax.Array."""
for async_copy in self._async_copies:
async_copy.wait()
head_dim = self._vmem_buffer.shape[-1]
jax_array = self._vmem_buffer[...].astype(jnp.float32)
if self._scales_vmem_buffer is not None:
scales_jax_array = self._scales_vmem_buffer[...].astype(jnp.float32)
else:
scales_jax_array = None
jax_array = self._maybe_dequantize(jax_array, scales_jax_array)
return jax_array.reshape(-1, head_dim)
def paged_flash_attention_kernel(
lengths_ref,
page_indices_ref,
buffer_index_ref,
init_flag_ref,
q_ref,
k_pages_hbm_ref,
k_scales_pages_hbm_ref,
v_pages_hbm_ref,
v_scales_pages_hbm_ref,
o_ref,
m_ref,
l_ref,
k_vmem_buffer,
k_scales_vmem_buffer,
v_vmem_buffer,
v_scales_vmem_buffer,
k_sems,
v_sems,
*,
batch_size: int,
pages_per_compute_block: int,
pages_per_sequence: int,
mask_value: float,
attn_logits_soft_cap: float | None,
megacore_mode: str | None,
program_ids=(),
):
"""Pallas kernel for paged attention."""
if program_ids:
core_index, b, h, i = program_ids
else:
core_index, b, h, i = (
pl.program_id(0),
pl.program_id(1),
pl.program_id(2),
pl.program_id(3),
)
num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape
bk = page_size * pages_per_compute_block
num_cores = pl.num_programs(0)
b_step = num_cores if megacore_mode == "batch" else 1
b_start = core_index if megacore_mode == "batch" else 0
h_step = num_cores if megacore_mode == "kv_head" else 1
h_start = core_index if megacore_mode == "kv_head" else 0
h = h * h_step + h_start
b = b * b_step + b_start
length = lengths_ref[b]
def compute_block_indices(b, h, i):
def advance_b():
next_b = b + b_step
def advance_to_next_non_zero_length():
next_next_b = next_b + b_step
return lax.fori_loop(
lax.div(next_next_b, b_step),
lax.div(batch_size, b_step),
lambda _, b: jnp.where(lengths_ref[b] == 0, b + b_step, b),
next_next_b,
)
return (
lax.cond(
jnp.logical_and(
next_b < batch_size,
lengths_ref[lax.clamp(0, next_b, batch_size - 1)] == 0),
advance_to_next_non_zero_length,
lambda: next_b,
),
h_start,
0,
)
def advance_h():
next_h = h + h_step
return lax.cond(next_h < num_kv_heads, lambda: (b, next_h, 0), advance_b)
return lax.cond(i * bk < lengths_ref[b], lambda: (b, h, i), advance_h)
def create_kv_async_copy_descriptors(b, h, i, buffer_index):
page_offset = b * pages_per_sequence + i * pages_per_compute_block
pages_to_load = pages_per_compute_block
async_copy_k = MultiPageAsyncCopyDescriptor(
k_pages_hbm_ref,
k_scales_pages_hbm_ref,
k_vmem_buffer.at[buffer_index],
k_scales_vmem_buffer.at[buffer_index]
if k_scales_vmem_buffer is not None
else None,
k_sems.at[buffer_index],
page_indices_ref,
page_offset,
pages_to_load,
h,
)
async_copy_v = MultiPageAsyncCopyDescriptor(
v_pages_hbm_ref,
v_scales_pages_hbm_ref,
v_vmem_buffer.at[buffer_index],
v_scales_vmem_buffer.at[buffer_index]
if v_scales_vmem_buffer is not None
else None,
v_sems.at[buffer_index],
page_indices_ref,
page_offset,
pages_to_load,
h,
)
return async_copy_k, async_copy_v
@pl.when(i * bk < length)
def flash_attention():
init_flag = init_flag_ref[0]
init_flag_ref[0] = 0
buffer_index = buffer_index_ref[0]
next_b, next_h, next_i = compute_block_indices(b, h, i + 1)
@pl.when(init_flag)
def prefetch_first_block():
async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
b, h, i, buffer_index
)
async_copy_k.start()
async_copy_v.start()
@pl.when(i == 0)
def init():
m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
l_ref[...] = jnp.zeros_like(l_ref)
o_ref[...] = jnp.zeros_like(o_ref)
@pl.when(next_b < batch_size)
def prefetch_next_block():
next_buffer_index = jnp.where(buffer_index == 0, 1, 0)
async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors(
next_b, next_h, next_i, next_buffer_index
)
async_copy_next_k.start()
async_copy_next_v.start()
buffer_index_ref[0] = next_buffer_index
async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
b, h, i, buffer_index
)
q = q_ref[...].astype(jnp.float32)
k = async_copy_k.wait_and_get_loaded()
qk = jnp.einsum("gd,td->gt", q, k, preferred_element_type=jnp.float32)
if attn_logits_soft_cap is not None:
capped_qk = jnp.tanh(qk / attn_logits_soft_cap)
qk = capped_qk * attn_logits_soft_cap
mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length
qk = qk + jnp.where(mask, 0.0, mask_value)
m_curr = qk.max(axis=-1)
s_curr = jnp.exp(qk - m_curr[..., None])
m_prev, l_prev = m_ref[...], l_ref[...]
l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
m_next = jnp.maximum(m_prev, m_curr)
alpha = jnp.exp(m_prev - m_next)
beta = jnp.exp(m_curr - m_next)
l_next = alpha * l_prev + beta * l_curr
m_ref[...], l_ref[...] = m_next, l_next
v = async_copy_v.wait_and_get_loaded()
o_curr = jnp.einsum("gt,td->gd", s_curr, v)
o_ref[...] = (
(l_prev * alpha * o_ref[...] + beta * o_curr) / l_next
).astype(o_ref.dtype)
def paged_flash_attention_kernel_inline_seq_dim(
lengths_ref,
page_indices_ref,
buffer_index_ref,
init_flag_ref,
q_ref,
k_pages_hbm_ref,
k_scales_pages_hbm_ref,
v_pages_hbm_ref,
v_scales_pages_hbm_ref,
o_ref,
m_ref,
l_ref,
k_vmem_buffer,
k_scales_vmem_buffer,
v_vmem_buffer,
v_scales_vmem_buffer,
k_sems,
v_sems,
*,
batch_size: int,
pages_per_compute_block: int,
pages_per_sequence: int,
mask_value: float,
attn_logits_soft_cap: float | None,
megacore_mode: str | None,
):
core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)
# Initialize the output HBM buffers to avoid accessing garbage memory inside
# the kernel body below.
m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
l_ref[...] = jnp.zeros_like(l_ref)
o_ref[...] = jnp.zeros_like(o_ref)
def body(i, _):
paged_flash_attention_kernel(
lengths_ref,
page_indices_ref,
buffer_index_ref,
init_flag_ref,
q_ref,
k_pages_hbm_ref,
k_scales_pages_hbm_ref,
v_pages_hbm_ref,
v_scales_pages_hbm_ref,
o_ref,
m_ref,
l_ref,
k_vmem_buffer,
k_scales_vmem_buffer,
v_vmem_buffer,
v_scales_vmem_buffer,
k_sems,
v_sems,
batch_size=batch_size,
pages_per_compute_block=pages_per_compute_block,
pages_per_sequence=pages_per_sequence,
mask_value=mask_value,
attn_logits_soft_cap=attn_logits_soft_cap,
megacore_mode=megacore_mode,
program_ids=(core_index, b, h, i),
)
return ()
bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]
if megacore_mode == "batch":
num_cores = pl.num_programs(0)
length = lengths_ref[b * num_cores + core_index]
else:
length = lengths_ref[b]
lax.fori_loop(0, lax.div(length + bk - 1, bk), body, ())
@functools.partial(
jax.jit,
static_argnames=[
"pages_per_compute_block",
"attn_logits_soft_cap",
"mask_value",
"megacore_mode",
"inline_seq_dim",
],
)
def paged_attention(
q: jax.Array,
k_pages: jax.Array | quantization_utils.QuantizedTensor,
v_pages: jax.Array | quantization_utils.QuantizedTensor,
lengths: jax.Array,
page_indices: jax.Array,
*,
mask_value: float = DEFAULT_MASK_VALUE,
attn_logits_soft_cap: float | None = None,
pages_per_compute_block: int,
megacore_mode: str | None = None,
inline_seq_dim: bool = True,
) -> jax.Array:
"""Paged grouped query attention.
Args:
q: A [batch_size, num_q_heads, head_dim] jax.Array.
k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
    lengths: An i32[batch_size] jax.Array with the length of each example.
page_indices: A i32[batch_size, pages_per_sequence] jax.Array. Each entry
should be in the range of [0, total_num_pages), indicating where to locate
the page in `k_pages` or `v_pages`.
mask_value: The value used for padding in attention. By default it is a very
negative floating point number.
attn_logits_soft_cap: The value used for soft capping the attention logits.
    pages_per_compute_block: how many pages to process in one flash attention
      block in the pallas kernel.
megacore_mode: if set, enable megacore to parallelize the computation. Must
be one of ['kv_head', 'batch', None]. Caveat: set this only if megacore is
enabled, otherwise the kernel may hang. If you are not sure, leave it to
None.
* None: disable megacore parallelism.
* kv_head: megacore parallelism on KV heads; requires number of KV heads
divisible by 2.
* batch: megacore parallelism on batch dimension; requires batch divisible
by 2.
inline_seq_dim: whether to fuse kernel instances along the sequence dim into
one kernel.
Returns:
The output of attention([batch_size, num_q_heads, head_dim]).
"""
if isinstance(k_pages, quantization_utils.QuantizedTensor):
k_pages, k_scales_pages = k_pages.weight, k_pages.scales
assert isinstance(k_scales_pages, jax.Array) # For typing.
k_scales_pages = jnp.broadcast_to(
k_scales_pages, (*k_scales_pages.shape[:-1], k_pages.shape[-1])
)
else:
k_scales_pages = None
if isinstance(v_pages, quantization_utils.QuantizedTensor):
v_pages, v_scales_pages = v_pages.weight, v_pages.scales
assert isinstance(v_scales_pages, jax.Array) # For typing.
v_scales_pages = jnp.broadcast_to(
v_scales_pages, (*v_scales_pages.shape[:-1], v_pages.shape[-1])
)
else:
v_scales_pages = None
batch_size, num_q_heads, head_dim = q.shape
num_kv_heads, _, page_size, head_dim_k = k_pages.shape
batch_size_paged_indices, pages_per_sequence = page_indices.shape
if k_pages.shape != v_pages.shape:
raise ValueError(
f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and"
f" {v_pages.shape}"
)
if num_q_heads % num_kv_heads != 0:
raise ValueError(
"Number of Q heads must be divisible by number of KV heads. Got"
f" {num_q_heads} and {num_kv_heads}."
)
if head_dim_k != head_dim:
raise ValueError(
"head_dim of Q must be the same as that of K/V. Got"
f" {head_dim} and {head_dim_k}."
)
if pages_per_sequence % pages_per_compute_block != 0:
raise ValueError(
"pages_per_compute_block must be divisible by pages per sequence. Got"
f" {pages_per_compute_block} and {pages_per_sequence}."
)
if lengths.shape != (batch_size,):
raise ValueError("`lengths` and `q` must have the same batch size")
if batch_size_paged_indices != batch_size:
raise ValueError("`page_indices` and `q` must have the same batch size")
if lengths.dtype != jnp.int32:
raise ValueError(
f"The dtype of `lengths` must be int32. Got {lengths.dtype}"
)
# TODO(dinghua): get the actual cores per chip once there's an official API.
if megacore_mode == "kv_head":
if num_kv_heads % 2 != 0:
raise ValueError(
"number of KV heads must be even when megacore_mode is 'kv_head'"
)
num_cores = 2
elif megacore_mode == "batch":
if batch_size % 2 != 0:
raise ValueError("batch size must be even when megacore_mode is 'batch'")
num_cores = 2
elif megacore_mode is None:
num_cores = 1
else:
raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]")
num_groups = num_q_heads // num_kv_heads
if (num_groups) % 8 != 0:
    # Reshape q to hint XLA to pick a <1x128> layout; otherwise it will pick an
    # <8x128> layout for a <1x128> memref inside the kernel and error out.
q = q.reshape(batch_size, num_q_heads, 1, head_dim)
if megacore_mode == "kv_head":
q_block_spec = pl.BlockSpec(
(None, num_groups, None, head_dim),
lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0),
)
elif megacore_mode == "batch":
q_block_spec = pl.BlockSpec(
(None, num_groups, None, head_dim),
lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0),
)
else:
q_block_spec = pl.BlockSpec(
(None, num_groups, None, head_dim),
lambda core_index, b, h, *_: (b, h, 0, 0),
)
q_dtype_for_kernel_launch = jnp.float32
else:
if megacore_mode == "kv_head":
q_block_spec = pl.BlockSpec(
(None, num_groups, head_dim),
lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0),
)
elif megacore_mode == "batch":
q_block_spec = pl.BlockSpec(
(None, num_groups, head_dim),
lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0),
)
else:
q_block_spec = pl.BlockSpec(
(None, num_groups, head_dim),
lambda core_index, b, h, *_: (b, h, 0),
)
q_dtype_for_kernel_launch = q.dtype
dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
if inline_seq_dim:
kernel = paged_flash_attention_kernel_inline_seq_dim
grid = (
num_cores,
batch_size // num_cores if megacore_mode == "batch" else batch_size,
num_kv_heads // num_cores
if megacore_mode == "kv_head"
else num_kv_heads,
)
dimension_semantics = ("parallel", "arbitrary", "arbitrary")
else:
kernel = paged_flash_attention_kernel
grid = (
num_cores,
batch_size // num_cores if megacore_mode == "batch" else batch_size,
num_kv_heads // num_cores
if megacore_mode == "kv_head"
else num_kv_heads,
pages_per_sequence // pages_per_compute_block,
)
dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
if k_scales_pages is not None and v_scales_pages is not None:
in_specs = [
q_block_spec,
pl.BlockSpec(memory_space=pl.ANY),
pl.BlockSpec(memory_space=pl.ANY),
pl.BlockSpec(memory_space=pl.ANY),
pl.BlockSpec(memory_space=pl.ANY),
]
scratch_shapes = (
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
k_pages.dtype,
), # k_pages buffer
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
k_scales_pages.dtype,
), # k_scales_pages buffer
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
v_pages.dtype,
), # v_pages buffer
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
v_scales_pages.dtype,
), # v_scales_pages buffer
pltpu.SemaphoreType.DMA((2,)),
pltpu.SemaphoreType.DMA((2,)),
)
else:
in_specs = [
q_block_spec,
pl.BlockSpec(memory_space=pl.ANY),
None,
pl.BlockSpec(memory_space=pl.ANY),
None,
]
scratch_shapes = (
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
k_pages.dtype,
), # k_pages buffer
None,
pltpu.VMEM(
(
2, # For double buffering during DMA copies.
pages_per_compute_block,
page_size,
head_dim,
),
v_pages.dtype,
), # v_pages buffer
None,
pltpu.SemaphoreType.DMA((2,)),
pltpu.SemaphoreType.DMA((2,)),
)
out, _, _ = pl.pallas_call(
functools.partial(
kernel,
pages_per_sequence=pages_per_sequence,
batch_size=batch_size,
pages_per_compute_block=pages_per_compute_block,
mask_value=mask_value,
attn_logits_soft_cap=attn_logits_soft_cap,
megacore_mode=megacore_mode,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
# There are 4 scalars prefetched per kernel call: `lengths_ref`,
# `page_indices_ref`, `buffer_index_ref`, `init_flag_ref`
num_scalar_prefetch=4,
in_specs=in_specs,
out_specs=[
q_block_spec,
q_block_spec,
q_block_spec,
],
grid=grid,
scratch_shapes=scratch_shapes,
),
compiler_params=pltpu.CompilerParams(
dimension_semantics=dimension_semantics
),
out_shape=[
jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
],
)(
lengths,
page_indices.reshape(-1),
jnp.zeros((1,), jnp.int32), # buffer index
jnp.ones((1,), jnp.int32), # init flag
q.astype(q_dtype_for_kernel_launch),
k_pages,
k_scales_pages,
v_pages,
v_scales_pages,
)
return out.reshape(batch_size, num_q_heads, head_dim).astype(q.dtype)
@@ -0,0 +1,107 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NamedTuple
import jax
from jax import numpy as jnp
P = jax.sharding.PartitionSpec
MAX_INT8 = 127.5
class QuantizedTensor(NamedTuple):
"""A tensor which has been quantized to int8 and its scales.
Attributes:
    weight: The int8-quantized values.
    scales: The per-row quantization scales (absolute max over the trailing
      dimension) used by `to_int8` and `from_int8`.
"""
weight: jnp.ndarray
scales: jnp.ndarray
def to_int8(x: jnp.ndarray, h: jnp.ndarray) -> jnp.ndarray:
"""Converts a float array to an int8 array with a scale.
Args:
x: Float array.
h: Quantization scale.
Returns:
Int8 array.
"""
return jnp.int8(jnp.rint(x * (MAX_INT8 / h)))
def from_int8(
x: jnp.ndarray, h: jnp.ndarray, dtype: jnp.dtype = jnp.bfloat16
) -> jnp.ndarray:
"""Converts an int8 array to a float array with a scale.
Args:
x: Int8 array.
h: Quantization scale.
dtype: Float dtype to convert to.
Returns:
Float array.
"""
return x.astype(dtype) * h / MAX_INT8
def get_quantization_scales(x: jnp.ndarray) -> jnp.ndarray:
"""Computes the quantization scales for a float array.
These are the maximum values of the trailing dimension.
Args:
x: Float array to quantize.
Returns:
    Array with the same shape as the input, except that the trailing dimension
    is reduced to size 1 and holds the absolute max value.
"""
return jnp.max(jnp.abs(x), axis=-1, keepdims=True)
def quantize_to_int8(
x: jnp.ndarray,
) -> QuantizedTensor:
"""Quantizes a float array to an int8 QuantizedTensor.
Args:
x: Float array to quantize.
Returns:
Int8 QuantizedTensor.
"""
x_scales = get_quantization_scales(x)
return QuantizedTensor(weight=to_int8(x, x_scales), scales=x_scales)
def unquantize_from_int8(
x: QuantizedTensor,
dtype: jnp.dtype = jnp.bfloat16,
) -> jnp.ndarray:
"""Unquantizes an int8 QuantizedTensor to a float array.
Args:
x: Int8 QuantizedTensor to unquantize.
dtype: Float dtype to unquantize to.
Returns:
Float array.
"""
return from_int8(x.weight, x.scales, dtype)
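# Editor's illustrative sketch (not part of the original module): round-trip a
# bfloat16 array through the int8 helpers above. Shapes and values are
# arbitrary; the round trip is lossy at int8 precision.
def _example_int8_round_trip():
  x = jnp.linspace(-2.0, 2.0, 256, dtype=jnp.bfloat16).reshape(2, 128)
  quantized = quantize_to_int8(x)  # QuantizedTensor with int8 weight, (2, 1) scales
  restored = unquantize_from_int8(quantized, dtype=jnp.bfloat16)
  return quantized, restored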
@@ -0,0 +1,82 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX reference implementation of grouped query attention."""
import jax
from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils
import jax.numpy as jnp
MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
def grouped_query_attention_reference(
queries: jax.Array, # [batch_size, num_q_heads, head_dim]
k_pages: jax.Array, # [batch_size, num_kv_heads, max_seq_len, head_dim]
v_pages: jax.Array, # [batch_size, num_kv_heads, max_seq_len, head_dim]
seq_lens: jax.Array, # i32[batch_size]
soft_cap: float | None = None,
debug: bool = False,
) -> jax.Array: # [batch_size, num_q_heads, head_dim]
"""Grouped query attention with a single query per request."""
# Check input shapes
assert k_pages.shape == v_pages.shape
batch_size, num_q_heads, head_dim = queries.shape
batch_size2, num_kv_heads, max_seq_len, head_dim2 = k_pages.shape
assert batch_size2 == batch_size
assert head_dim2 == head_dim
# Unquantize kv pages if necessary
if isinstance(k_pages, quantization_utils.QuantizedTensor):
k_pages = quantization_utils.unquantize_from_int8(
k_pages, dtype=jnp.float32
)
if isinstance(v_pages, quantization_utils.QuantizedTensor):
v_pages = quantization_utils.unquantize_from_int8(
v_pages, dtype=jnp.float32
)
# Reshape for num_groups queries per k head
assert num_q_heads % num_kv_heads == 0
num_groups = num_q_heads // num_kv_heads
queries = queries.reshape(batch_size, num_kv_heads, num_groups, head_dim)
# Compute the dot product q*k and apply soft cap if necessary
qk = jnp.einsum(
"bhgd,bhtd->bhgt",
queries.astype(jnp.float32),
k_pages.astype(jnp.float32),
)
if soft_cap is not None and soft_cap != 0.0:
qk = jnp.tanh(qk / soft_cap) * soft_cap
assert qk.shape == (batch_size, num_kv_heads, num_groups, max_seq_len)
if debug:
jax.debug.print("qk: {qk}", qk=qk)
  # Mask out keys at or beyond each sequence's length; with a single query per
  # request (the latest token), this also serves as the causal mask.
mask = jnp.arange(max_seq_len)[None] < seq_lens[:, None]
qk += jnp.where(mask, 0.0, MASK_VALUE)[:, None, None, :]
if debug:
jax.debug.print("masked: {qk}", qk=qk)
# Generate probability distribution using softmax
probs = jax.nn.softmax(qk, axis=-1).astype(v_pages.dtype)
assert probs.shape == (batch_size, num_kv_heads, num_groups, max_seq_len)
if debug:
jax.debug.print("softmax: {probs}", probs=probs)
# Attention is probability-weighted sum of v heads
attention = jnp.einsum("bhgt,bhtd->bhgd", probs, v_pages)
assert attention.shape == (batch_size, num_kv_heads, num_groups, head_dim)
return attention.reshape(batch_size, num_q_heads, head_dim)
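# Editor's illustrative sketch (hypothetical shapes, not part of the original
# module): 2 sequences, 8 query heads sharing 2 KV heads, a max context of 16
# tokens, and head_dim 4. The second sequence only attends to its first 7 keys.
def _example_grouped_query_attention_reference():
  batch, num_q_heads, num_kv_heads, max_seq_len, head_dim = 2, 8, 2, 16, 4
  q_key, k_key, v_key = jax.random.split(jax.random.key(0), 3)
  queries = jax.random.normal(q_key, (batch, num_q_heads, head_dim), jnp.float32)
  k_pages = jax.random.normal(
      k_key, (batch, num_kv_heads, max_seq_len, head_dim), jnp.float32)
  v_pages = jax.random.normal(
      v_key, (batch, num_kv_heads, max_seq_len, head_dim), jnp.float32)
  seq_lens = jnp.array([16, 7], jnp.int32)
  return grouped_query_attention_reference(queries, k_pages, v_pages, seq_lens)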
@@ -0,0 +1,22 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental.pallas.ops.tpu.ragged_paged_attention import kernel
from jax.experimental.pallas.ops.tpu.ragged_paged_attention import tuned_block_sizes
dynamic_validate_inputs = kernel.dynamic_validate_inputs
ragged_paged_attention = kernel.ragged_paged_attention
ref_ragged_paged_attention = kernel.ref_ragged_paged_attention
static_validate_inputs = kernel.static_validate_inputs
get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
@@ -0,0 +1,899 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TPU-Friendly Ragged Paged Attention kernel.
This kernel offers a highly optimized implementation of ragged paged attention,
specifically designed for TPU and compatible with a wide range of model
specifications. It supports mixed prefill and decoding, enhancing throughput
during inference.
"""
import functools
import jax
from jax import lax
from jax._src import dtypes
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import get_tuned_block_sizes
import jax.numpy as jnp
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
class MultiPageAsyncCopyDescriptor:
"""Descriptor for async copy of multiple K/V pages from HBM."""
def __init__(
self,
pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim]
vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
sem,
page_indices_ref, # i32[max_num_seqs, pages_per_seq]
metadata, # [seq_idx, start_page_idx, end_page_idx]
):
self._vmem_buf = vmem_buf
seq_id, start_page_idx, end_page_idx = metadata
self._async_copies = []
    # TODO(jevinjiang): Only fetch the pages that are actually needed. This
    # will insert a bunch of if-ops; check the performance once a benchmarking
    # setup is available.
for i in range(vmem_buf.shape[0]):
page_idx = start_page_idx + i
page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, 0)
self._async_copies.append(
pltpu.make_async_copy(
pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]],
vmem_buf.at[i],
sem,
)
)
def start(self):
"""Starts the async copies."""
for async_copy in self._async_copies:
async_copy.start()
def wait(self):
for async_copy in self._async_copies:
async_copy.wait()
return self._vmem_buf
def ref_ragged_paged_attention(
queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
kv_lens: jax.Array, # i32[max_num_seqs]
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
num_seqs: jax.Array, # i32[1],
*,
sm_scale: float = 1.0,
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value: float | None = DEFAULT_MASK_VALUE,
k_scale: float | None = None,
v_scale: float | None = None,
):
static_validate_inputs(
queries,
kv_pages,
kv_lens,
page_indices,
cu_q_lens,
num_seqs,
sm_scale=sm_scale,
k_scale=k_scale,
v_scale=v_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
mask_value=mask_value,
)
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
_, _, num_combined_kv_heads, head_dim = kv_pages.shape
assert num_combined_kv_heads % 2 == 0
num_kv_heads = num_combined_kv_heads // 2
num_q_heads = queries.shape[1]
assert num_q_heads % num_kv_heads == 0
num_query_per_kv = num_q_heads // num_kv_heads
outputs = []
for i in range(num_seqs[0]):
q_start = cu_q_lens[i]
q_end = cu_q_lens[i + 1]
q_len = q_end - q_start
kv_len = kv_lens[i]
indices = page_indices[i]
q = queries[q_start:q_end]
k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[
:kv_len
]
v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[
:kv_len
]
if k_scale is not None:
k = k.astype(jnp.float32) * k_scale
k = k.astype(q.dtype)
if v_scale is not None:
v = v.astype(jnp.float32) * v_scale
v = v.astype(q.dtype)
k = jnp.repeat(k, num_query_per_kv, axis=1)
v = jnp.repeat(v, num_query_per_kv, axis=1)
attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32)
attn *= sm_scale
q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
jnp.int32, attn.shape, 1
)
kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
mask = q_span < kv_span
if sliding_window is not None:
mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
if soft_cap is not None:
attn = soft_cap * jnp.tanh(attn / soft_cap)
attn += jnp.where(mask, mask_value, 0.0)
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
outputs.append(out)
return jnp.concatenate(outputs, axis=0)
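# Editor's sketch of a hypothetical helper (not part of the kernel API): it
# shows how separate K and V page arrays map onto the combined layout used
# above, where even combined-head indices hold K and odd indices hold V (hence
# the 0::2 / 1::2 slicing in the reference).
def _example_interleave_kv_pages(k_pages, v_pages):
  # Both inputs: [total_num_pages, page_size, num_kv_heads, head_dim].
  assert k_pages.shape == v_pages.shape
  total_pages, page_size, num_kv_heads, head_dim = k_pages.shape
  kv = jnp.stack([k_pages, v_pages], axis=3)
  return kv.reshape(total_pages, page_size, 2 * num_kv_heads, head_dim)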
# These checks are expected to run at runtime, since they inspect array values.
def dynamic_validate_inputs(
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
kv_lens: jax.Array, # i32[max_num_seqs]
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
num_seqs: jax.Array, # i32[1]
*,
# These inputs are optional. If not specified, we will not validate them.
sm_scale: float | None = None,
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
# Kernel tuning params.
num_kv_pages_per_block: int | None = None,
num_queries_per_block: int | None = None,
vmem_limit_bytes: int | None = None,
):
static_validate_inputs(
q,
kv_pages,
kv_lens,
page_indices,
cu_q_lens,
num_seqs,
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
mask_value=mask_value,
k_scale=k_scale,
v_scale=v_scale,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
vmem_limit_bytes=vmem_limit_bytes,
)
max_num_batched_tokens = q.shape[0]
page_size = kv_pages.shape[1]
max_num_seqs, pages_per_seq = page_indices.shape
if num_seqs[0] > max_num_seqs:
raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}")
max_kv_len = jnp.max(kv_lens)
min_pages_per_seq = pl.cdiv(max_kv_len, page_size)
if pages_per_seq < min_pages_per_seq:
raise ValueError(
f"{pages_per_seq=} must be greater or equal to"
f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}."
)
if cu_q_lens[num_seqs[0]] > max_num_batched_tokens:
raise ValueError(
f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to"
f" {max_num_batched_tokens=}."
)
for i in range(num_seqs[0]):
q_len = cu_q_lens[i + 1] - cu_q_lens[i]
kv_len = kv_lens[i]
if q_len > kv_len:
raise ValueError(
f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
)
# These checks are expected to run at compile time, since they only inspect shapes and dtypes.
def static_validate_inputs(
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
kv_lens: jax.Array, # i32[max_num_seqs]
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
num_seqs: jax.Array, # i32[1]
*,
# These inputs are optional. If not specified, we will not validate them.
sm_scale: float | None = None,
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
# Kernel tuning params.
num_kv_pages_per_block: int | None = None,
num_queries_per_block: int | None = None,
vmem_limit_bytes: int | None = None,
):
_, num_q_heads, head_dim = q.shape
_, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
assert num_combined_kv_heads % 2 == 0
assert isinstance(k_scale, float) or k_scale is None
assert isinstance(v_scale, float) or v_scale is None
num_kv_heads = num_combined_kv_heads // 2
max_num_seqs, pages_per_seq = page_indices.shape
if num_seqs.shape != (1,):
raise ValueError(f"{num_seqs.shape=} must be (1,)")
if head_dim_k != head_dim:
raise ValueError(
f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}."
)
if kv_lens.shape != (max_num_seqs,):
raise ValueError(
f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where"
" `max_num_seqs` is `page_indices.shape[0]`."
)
if cu_q_lens.shape != (max_num_seqs + 1,):
raise ValueError(
f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where"
" `max_num_seqs` is `page_indices.shape[0]`."
)
if (
kv_lens.dtype != jnp.int32
or page_indices.dtype != jnp.int32
or cu_q_lens.dtype != jnp.int32
):
raise ValueError(
"The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be"
f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=},"
f" {cu_q_lens.dtype=}."
)
if num_q_heads % num_kv_heads != 0:
raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
if sliding_window is not None and sliding_window <= 0:
raise ValueError(f"{sliding_window=} must be positive.")
if soft_cap is not None and soft_cap == 0.0:
raise ValueError(f"{soft_cap=} must not be 0.0.")
if (
num_kv_pages_per_block is not None
and not 0 < num_kv_pages_per_block <= pages_per_seq
):
raise ValueError(
f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]."
)
if num_queries_per_block is not None and num_queries_per_block <= 0:
raise ValueError(f"{num_queries_per_block=} must be positive.")
if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
raise ValueError(f"{vmem_limit_bytes=} must be positive.")
del sm_scale # No constraints on sm_scale.
  del mask_value  # No constraints on mask_value.
def ragged_paged_attention_kernel(
# Prefetch
kv_lens_ref, # [max_num_seqs]
page_indices_ref, # [max_num_seqs, pages_per_seq]
cu_q_lens_ref, # [max_num_seqs + 1]
seq_buf_idx_ref,
# TODO(jevinjiang): if OOM in SMEM, consider pack to other scalar refs.
num_seqs_ref,
# Input
q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
kv_pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
# Output
o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
# Scratch
kv_bufs, # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
sems, # [2, 2]
l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
acc_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
*,
sm_scale: float,
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value: float | None = DEFAULT_MASK_VALUE,
k_scale: float | None = None,
v_scale: float | None = None,
):
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
pages_per_seq = page_indices_ref.shape[-1]
num_seqs = num_seqs_ref[0]
_, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = (
kv_bufs.shape
)
num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2
num_kv_per_blk = num_kv_pages_per_blk * page_size
num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk
heads_blk_idx, q_blk_idx = (
pl.program_id(0),
pl.program_id(1),
)
num_heads_blks = pl.num_programs(0)
init_seq_idx = seq_buf_idx_ref[0]
init_buf_idx = seq_buf_idx_ref[1]
q_len_start = q_blk_idx * num_q_per_blk
q_len_end = q_len_start + num_q_per_blk
def create_kv_async_copy_descriptors(
heads_blk_idx, seq_idx, kv_blk_idx, buf_idx
):
start_kv_page_idx = kv_blk_idx * num_kv_pages_per_blk
end_kv_page_idx = jnp.minimum(
pages_per_seq, pl.cdiv(kv_lens_ref[seq_idx], page_size)
)
metadata = (seq_idx, start_kv_page_idx, end_kv_page_idx)
heads_start = heads_blk_idx * num_combined_kv_heads_per_blk
async_copy_kv = MultiPageAsyncCopyDescriptor(
kv_pages_hbm_ref.at[
:, :, pl.ds(heads_start, num_combined_kv_heads_per_blk), :
],
kv_bufs.at[buf_idx],
sems.at[buf_idx],
page_indices_ref,
metadata,
)
return async_copy_kv
# TODO(jevinjiang): Add these to Mosaic:
# 1. Support arbitrary strided load/store for int4 and int8 dtype.
# 2. Support arbitrary strided load/store for any last dimension.
def strided_load_kv(ref, start, step):
packing = get_dtype_packing(ref.dtype)
if packing == 1:
return [ref[start::step, :]], [ref[start + 1 :: step, :]]
assert packing in (2, 4, 8)
assert step % packing == 0
k_list, v_list = [], []
b_start = start // packing
b_step = step // packing
b_ref = ref.bitcast(jnp.uint32)
b = b_ref[b_start::b_step, :]
# TODO(chengjiyao): use the general strided loading logic for bf16 after
# fixing the issue in mosaic's infer vector layout pass
if ref.dtype == jnp.bfloat16:
bk = b << 16
bv = b & jnp.uint32(0xFFFF0000)
k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16)
v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16)
k_list.append(k)
v_list.append(v)
else:
bitwidth = 32 // packing
bitcast_dst_dtype = jnp.dtype(f"uint{bitwidth}")
for i in range(0, packing, 2):
bk = b >> (i * bitwidth)
k = pltpu.bitcast(bk.astype(bitcast_dst_dtype), ref.dtype)
k_list.append(k)
bv = b >> ((i + 1) * bitwidth)
v = pltpu.bitcast(bv.astype(bitcast_dst_dtype), ref.dtype)
v_list.append(v)
return k_list, v_list
def fold_on_2nd_minor(vec):
assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32
assert len(vec.shape) >= 2
last_dim = vec.shape[-1]
packing = get_dtype_packing(vec.dtype)
if vec.shape[-2] % packing != 0:
vec = vec.astype(jnp.float32)
return vec.reshape(-1, last_dim)
@pl.when(heads_blk_idx + q_blk_idx == 0)
def prefetch_first_kv_blk():
async_copy_kv = create_kv_async_copy_descriptors(
heads_blk_idx, init_seq_idx, 0, init_buf_idx
)
async_copy_kv.start()
def is_cur_q_blk_needed(q_states):
done, cur_seq_idx, _ = q_states
should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs],
cur_seq_idx < num_seqs)
return jnp.logical_and(done == 0, should_run)
def compute_with_cur_q_blk(q_states):
done, cur_seq_idx, cur_buf_idx = q_states
q_start = cu_q_lens_ref[cur_seq_idx]
q_end = cu_q_lens_ref[cur_seq_idx + 1]
q_len = q_end - q_start
kv_len = kv_lens_ref[cur_seq_idx]
def get_next_prefetch_ids(
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
):
next_kv_blk_idx = kv_blk_idx + 1
is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len
next_kv_blk_idx = lax.select(
is_last_kv_blk,
0,
next_kv_blk_idx,
)
is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end
next_seq_idx = lax.select(
is_last_kv_blk,
lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1, cur_seq_idx),
cur_seq_idx,
)
is_last_seq = next_seq_idx == num_seqs
next_seq_idx = lax.select(
is_last_seq,
0,
next_seq_idx,
)
next_heads_blk_idx = lax.select(
is_last_seq,
heads_blk_idx + 1,
heads_blk_idx,
)
next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0)
return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx
def flash_attention(
q, # [num_q_per_blk * num_q_heads_per_kv_head, head_dim]
k, # [num_kv_per_blk, head_dim]
v, # [num_kv_per_blk, head_dim]
head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128]
head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128]
head_acc_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
*,
kv_blk_idx,
):
assert q.shape == (
num_q_per_blk * num_q_heads_per_kv_head,
head_dim,
)
assert (
k.shape
== v.shape
== (
num_kv_per_blk,
head_dim,
)
)
assert k.dtype == v.dtype
assert (
head_m_ref.shape
== head_l_ref.shape
== (
num_q_per_blk * num_q_heads_per_kv_head,
128,
)
)
assert head_acc_ref.shape == (
num_q_per_blk,
num_q_heads_per_kv_head,
head_dim,
)
kv_len_start = kv_blk_idx * num_kv_per_blk
def masked_store(ref, val, start, end, group=1):
iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group
pltpu.store(ref, val, mask=jnp.logical_and(iota >= start, iota < end))
def load_with_init(ref, init_val):
return jnp.where(
kv_blk_idx == 0, jnp.full_like(ref, init_val), ref[...]
)
      # The kv length is the contracting dim; zero out the padded tail so that
      # stale or NaN values cannot contaminate the matmul.
kv_mask = (
lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start
)
k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype)
v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype)
qk = (
jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32)
* sm_scale
)
store_start = jnp.maximum(q_start - q_len_start, 0)
store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk)
row_ids = (
(kv_len - q_len)
+ q_len_start
- q_start
+ jax.lax.broadcasted_iota(
jnp.int32,
(num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
0,
)
// num_q_heads_per_kv_head
)
col_ids = kv_len_start + jax.lax.broadcasted_iota(
jnp.int32,
(num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
1,
)
causal_mask = row_ids < col_ids
if sliding_window is not None:
causal_mask = jnp.logical_or(causal_mask,
row_ids - sliding_window >= col_ids)
if soft_cap is not None:
qk = soft_cap * jnp.tanh(qk / soft_cap)
qk += jnp.where(causal_mask, mask_value, 0.0)
m_curr = jnp.max(qk, axis=1, keepdims=True)
s_curr = jnp.exp(qk - m_curr)
qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32)
lm_store_shape = head_m_ref.shape
m_curr = jnp.broadcast_to(m_curr, lm_store_shape)
l_curr = jnp.broadcast_to(
s_curr.sum(axis=1, keepdims=True), lm_store_shape
)
m_prev = load_with_init(head_m_ref, -jnp.inf)
l_prev = load_with_init(head_l_ref, 0.0)
m_next = jnp.maximum(m_prev, m_curr)
masked_store(
head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head
)
alpha = jnp.exp(m_prev - m_next)
beta = jnp.exp(m_curr - m_next)
l_alpha = alpha * l_prev
l_next = l_alpha + beta * l_curr
l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)
masked_store(
head_l_ref,
l_next_safe,
store_start,
store_end,
num_q_heads_per_kv_head,
)
def broadcast_to_shape(arr, shape):
if arr.shape == shape:
return arr
assert len(arr.shape) == len(shape)
assert arr.shape[0] == shape[0]
assert shape[1] % arr.shape[1] == 0
# no-op concatenation.
return jnp.concatenate(
[arr for _ in range(shape[1] // arr.shape[1])], axis=1
)
o_curr = load_with_init(head_acc_ref, 0.0).reshape(-1, head_dim)
l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
beta = broadcast_to_shape(beta, qkv.shape)
l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
out = lax.div(
l_alpha * o_curr + beta * qkv,
l_next_safe,
)
masked_store(
head_acc_ref,
out.reshape(head_acc_ref.shape),
store_start,
store_end,
)
def is_valid_kv_blk_in_cur_seq(kv_states):
kv_blk_idx, _ = kv_states
return kv_blk_idx * num_kv_per_blk < kv_len
def compute_with_kv_blk_in_cur_seq(kv_states):
kv_blk_idx, cur_buf_idx = kv_states
next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = (
get_next_prefetch_ids(
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
)
)
@pl.when(next_heads_blk_idx < num_heads_blks)
def prefetch_next_kv_blk():
# TODO(jevinjiang): reuse the same buffer if it is already prefetched!
# TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and
# DMA to fixed size buffer!
next_async_copy_kv = create_kv_async_copy_descriptors(
next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx
)
next_async_copy_kv.start()
cur_async_copy_kv = create_kv_async_copy_descriptors(
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
)
kv_ref = cur_async_copy_kv.wait().reshape(
num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk,
head_dim,
)
kv_packing = get_dtype_packing(kv_ref.dtype)
# NOTE: kv_packing is divided by 2 because k and v are packed together.
kv_load_step = max(1, kv_packing // 2)
for kv_head_chunk_idx in range(0, num_kv_heads_per_blk, kv_load_step):
k_list, v_list = strided_load_kv(
kv_ref, kv_head_chunk_idx * 2, num_combined_kv_heads_per_blk
)
for step_idx in range(kv_load_step):
k = k_list[step_idx]
v = v_list[step_idx]
if k_scale is not None:
# NOTE: Conversion between arbitrary data types is not supported.
# That's why it is converted to float32 first.
k = k.astype(jnp.float32) * k_scale
k = k.astype(q_ref.dtype)
if v_scale is not None:
v = v.astype(jnp.float32) * v_scale
v = v.astype(q_ref.dtype)
kv_head_idx = kv_head_chunk_idx + step_idx
q_head_idx = kv_head_idx * num_q_heads_per_kv_head
# TODO(jevinjiang): extra handling for packed type that can start at
# unaligned position!
q = fold_on_2nd_minor(
q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :]
)
flash_attention(
q,
k,
v,
l_ref.at[kv_head_idx],
m_ref.at[kv_head_idx],
acc_ref.at[
:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :
],
kv_blk_idx=kv_blk_idx,
)
return kv_blk_idx + 1, next_buf_idx
_, next_buf_idx = lax.while_loop(
is_valid_kv_blk_in_cur_seq,
compute_with_kv_blk_in_cur_seq,
(0, cur_buf_idx), # (kv_blk_idx, buf_idx)
)
next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx)
done = lax.select(q_end < q_len_end, done, 1)
return done, next_seq_idx, next_buf_idx
_, seq_idx, buf_idx = lax.while_loop(
is_cur_q_blk_needed,
compute_with_cur_q_blk,
(0, init_seq_idx, init_buf_idx), # (done, seq_idx, buf_idx)
)
  # Reset seq_idx for the next kv_heads_blk if we have run out of seqs.
seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
seq_buf_idx_ref[1] = buf_idx
o_ref[...] = acc_ref[...].astype(q_ref.dtype)
def get_dtype_packing(dtype):
  """Returns how many values of `dtype` fit into 32 bits."""
  bits = dtypes.itemsize_bits(dtype)
  return 32 // bits
def get_min_heads_per_blk(
num_q_heads, num_combined_kv_heads, q_dtype, kv_dtype
):
q_packing = get_dtype_packing(q_dtype)
kv_packing = get_dtype_packing(kv_dtype)
def can_be_xla_fully_tiled(x, packing):
if x % packing != 0:
return False
x //= packing
return x in (1, 2, 4, 8) or x % 8 == 0
# TODO(jevinjiang): support unaligned number of heads!
if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing):
raise ValueError(
f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled."
)
assert num_combined_kv_heads % 2 == 0
num_kv_heads = num_combined_kv_heads // 2
assert num_q_heads % num_kv_heads == 0
ratio = num_q_heads // num_kv_heads
# TODO(jevinjiang): we can choose smaller tiling for packed type if large
# second minor tiling is not on.
max_combined_kv_tiling = 8 * kv_packing
min_combined_kv_heads = (
max_combined_kv_tiling
if num_combined_kv_heads % max_combined_kv_tiling == 0
else num_combined_kv_heads
)
min_q_heads = min_combined_kv_heads // 2 * ratio
if can_be_xla_fully_tiled(min_q_heads, q_packing):
return min_q_heads, min_combined_kv_heads
return num_q_heads, num_combined_kv_heads
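# Editor's sketch: with 32 bf16 query heads and 32 bf16 KV heads (64 combined
# K/V heads), the head dimension is split into blocks of 8 query heads and 16
# combined K/V heads per grid step. The numbers are illustrative only.
def _example_min_heads_per_blk():
  blk = get_min_heads_per_blk(32, 64, jnp.bfloat16, jnp.bfloat16)
  assert blk == (8, 16)
  return blk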
@functools.partial(
jax.jit,
static_argnames=[
"sm_scale",
"mask_value",
"num_kv_pages_per_block",
"num_queries_per_block",
"vmem_limit_bytes",
"sliding_window",
"soft_cap",
"k_scale",
"v_scale",
],
)
def ragged_paged_attention(
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
# TODO(jevinjiang): create a write_to_kv_cache kernel!
kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
kv_lens: jax.Array, # i32[max_num_seqs]
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
num_seqs: jax.Array, # i32[1]
*,
sm_scale: float = 1.0,
sliding_window: int | None = None,
soft_cap: float | None = None,
mask_value: float | None = DEFAULT_MASK_VALUE,
k_scale: float | None = None,
v_scale: float | None = None,
num_kv_pages_per_block: int | None = None,
num_queries_per_block: int | None = None,
vmem_limit_bytes: int | None = None,
):
"""Ragged paged attention that supports mixed prefill and decode.
Args:
q: concatenated all sequences' queries.
kv_pages: paged KV cache. Normally in HBM.
kv_lens: padded kv lengths. Only the first num_seqs values are valid.
    page_indices: for each sequence, the indices of the kv-cache pages that
      back its tokens, in order. Only the first num_seqs rows are valid.
cu_q_lens: the cumulative sum of the effective query lengths. Similar to
kv_lens, only the first num_seqs+1 values are valid.
num_seqs: the dynamic number of sequences.
sm_scale: the softmax scale which will be applied to the Q@K^T.
sliding_window: the sliding window size for the attention.
soft_cap: the logit soft cap for the attention.
mask_value: mask value for causal mask.
k_scale: the scale for the key cache.
v_scale: the scale for the value cache.
num_kv_pages_per_block: number of kv pages to be processed in one flash
attention block in the pallas kernel.
    num_queries_per_block: number of query tokens to be processed in one flash
      attention block in the pallas kernel.
vmem_limit_bytes: the vmem limit for the pallas kernel.
Returns:
The output of the attention.
"""
static_validate_inputs(
q,
kv_pages,
kv_lens,
page_indices,
cu_q_lens,
num_seqs,
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
mask_value=mask_value,
k_scale=k_scale,
v_scale=v_scale,
num_kv_pages_per_block=num_kv_pages_per_block,
num_queries_per_block=num_queries_per_block,
vmem_limit_bytes=vmem_limit_bytes,
)
if mask_value is None:
mask_value = DEFAULT_MASK_VALUE
num_q_tokens, num_q_heads, head_dim = q.shape
_, page_size, num_combined_kv_heads, _ = kv_pages.shape
assert num_combined_kv_heads % 2 == 0
num_kv_heads = num_combined_kv_heads // 2
_, pages_per_seq = page_indices.shape
num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk(
num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype
)
num_q_per_blk = num_queries_per_block
num_kv_pages_per_blk = num_kv_pages_per_block
if num_q_per_blk is None or num_kv_pages_per_blk is None:
num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes(
q.dtype,
kv_pages.dtype,
num_q_heads_per_blk,
num_combined_kv_heads_per_blk // 2,
head_dim,
page_size,
num_q_tokens,
pages_per_seq,
)
num_q_heads_per_kv_head = num_q_heads // num_kv_heads
num_q_blks = pl.cdiv(num_q_tokens, num_q_per_blk)
assert num_combined_kv_heads_per_blk % 2 == 0
num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2
assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0
num_heads_blks = num_q_heads // num_q_heads_per_blk
grid = (num_heads_blks, num_q_blks)
def q_index_map(heads_blk_idx, q_blk_idx, *_):
return (q_blk_idx, heads_blk_idx, 0)
q_block_spec = pl.BlockSpec(
(num_q_per_blk, num_q_heads_per_blk, head_dim),
q_index_map,
)
in_specs = [
q_block_spec,
pl.BlockSpec(memory_space=pl.ANY),
]
out_specs = q_block_spec
lm_scratch = pltpu.VMEM(
      # TODO(jevinjiang): 128 is used instead of 1 because Mosaic does not
      # support unaligned slicing.
(num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
jnp.float32,
)
acc_scratch = pltpu.VMEM(
(num_q_per_blk, num_q_heads_per_blk, head_dim),
jnp.float32,
)
double_buf_scratch = pltpu.VMEM(
(
2, # For double buffering during DMA copies.
num_kv_pages_per_blk,
page_size,
num_combined_kv_heads_per_blk,
head_dim,
),
kv_pages.dtype,
)
scratch_shapes = [
double_buf_scratch, # kv_bufs
pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers.
lm_scratch, # l_ref
lm_scratch, # m_ref
acc_scratch,
]
scalar_prefetches = (
kv_lens,
page_indices,
cu_q_lens,
jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx
num_seqs,
)
kernel = pl.pallas_call(
functools.partial(
ragged_paged_attention_kernel,
sm_scale=sm_scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
mask_value=mask_value,
k_scale=k_scale,
v_scale=v_scale,
),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=len(scalar_prefetches),
in_specs=in_specs,
out_specs=out_specs,
grid=grid,
scratch_shapes=scratch_shapes,
),
compiler_params=pltpu.CompilerParams(
dimension_semantics=(
"arbitrary",
"arbitrary",
),
vmem_limit_bytes=vmem_limit_bytes,
),
out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
name="ragged_paged_attention_kernel",
)
return kernel(*scalar_prefetches, q, kv_pages)
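# Editor's illustrative sketch of the expected input shapes (values are random
# and the sizes are arbitrary). The same arguments feed the TPU kernel
# `ragged_paged_attention` above; the pure-JAX reference is called here so the
# sketch also runs off-TPU.
def _example_ragged_paged_attention_inputs():
  max_num_seqs, pages_per_seq, page_size, head_dim = 2, 4, 16, 128
  num_q_heads, num_kv_heads = 8, 2
  total_pages = max_num_seqs * pages_per_seq
  q_key, kv_key = jax.random.split(jax.random.key(0))
  q = jax.random.normal(q_key, (32, num_q_heads, head_dim), jnp.float32)
  kv_pages = jax.random.normal(
      kv_key, (total_pages, page_size, 2 * num_kv_heads, head_dim), jnp.float32)
  kv_lens = jnp.array([24, 8], jnp.int32)        # two sequences
  page_indices = jnp.arange(total_pages, dtype=jnp.int32).reshape(
      max_num_seqs, pages_per_seq)
  cu_q_lens = jnp.array([0, 24, 32], jnp.int32)  # 24 + 8 query tokens
  num_seqs = jnp.array([2], jnp.int32)
  return ref_ragged_paged_attention(
      q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)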
@@ -0,0 +1,14 @@
# Copyright 2025 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
@@ -0,0 +1,210 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Philox PRNG as a Pallas kernel."""
from collections.abc import Sequence
import jax
from jax import typing
from jax._src import prng
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas.ops.tpu.random import prng_utils
Shape = Sequence[int]
BLOCK_SIZE = (256, 256)
# Philox constants. See original paper at:
# "Parallel Random Numbers: As Easy as 1, 2, 3", Salmon et. al. 2011
K_HI_32 = 0x9E3779B9
K_LO_32 = 0xBB67AE85
MUL_A = 0xCD9E8D57
MUL_B = 0xD2511F53
def mul32_hi_lo(x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
"""Multiplies 2 32-bit values and returns the hi+low bits of the result."""
xhi = x >> 16
yhi = y >> 16
xlo = x & 0xffff
ylo = y & 0xffff
xy_hi = xhi * yhi
xy_lo = xlo * ylo
cross_xy = xhi * ylo
cross_yx = xlo * yhi
carry = (cross_xy & 0xffff) + (cross_yx & 0xffff) + (xy_lo >> 16)
result_hi = xy_hi + (cross_xy >> 16) + (cross_yx >> 16) + (carry >> 16)
result_lo = (carry << 16) + (xy_lo & 0xffff)
return result_hi, result_lo
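# Editor's sketch: the 16-bit-limb multiply above should agree with the exact
# 64-bit product split into its high and low 32-bit halves. The operands are
# arbitrary.
def _example_mul32_hi_lo_check():
  x, y = 0x9E3779B9, 0xCD9E8D57
  hi, lo = mul32_hi_lo(
      jnp.array(x, dtype=jnp.uint32), jnp.array(y, dtype=jnp.uint32))
  full = x * y  # Exact product using Python integers.
  assert int(hi) == full >> 32 and int(lo) == full & 0xFFFFFFFF
  return hi, lo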
def philox_4x32(hi0, lo0, hi1, lo1, k_hi, k_lo, rounds = 10):
"""Philox 4x32 keyed hash function."""
k_hi_const = jnp.array(K_HI_32, dtype=jnp.uint32)
k_lo_const = jnp.array(K_LO_32, dtype=jnp.uint32)
mul_a = jnp.array(MUL_A, dtype=jnp.uint32)
mul_b = jnp.array(MUL_B, dtype=jnp.uint32)
for i in range(rounds):
# Compute the round.
new_hi0, new_lo0 = mul32_hi_lo(mul_a, hi1)
new_hi0 = new_hi0 ^ lo0 ^ k_hi
new_hi1, new_lo1 = mul32_hi_lo(mul_b, hi0)
new_hi1 = new_hi1 ^ lo1 ^ k_lo
hi0, lo0, hi1, lo1 = new_hi0, new_lo0, new_hi1, new_lo1
    # Bump the key on every round except the last.
if i != rounds - 1:
k_hi = k_hi + k_hi_const
k_lo = k_lo + k_lo_const
return hi0, lo0, hi1, lo1
def philox_4x32_kernel(key,
shape: Shape,
unpadded_shape: Shape,
block_size: tuple[int, int],
offset: typing.ArrayLike = 0,
fuse_output: bool = True):
"""Generates random bits using the Philox keyed hash function.
Args:
key: A Philox key of shape (2,).
shape: The shape of the output. Must be divisible by `block_size`.
unpadded_shape: If `shape` is padded, then this is the shape of the
output tensor if it were not padded. This is important for indexing
calculations within the kernel. If `shape` is not padded, then this
should be equal to `shape`.
block_size: The block size of the kernel.
offset: An optional offset to the counts.
fuse_output: Whether to fuse the output bits into a single value.
Returns:
A tensor of random bits of shape `shape` if fuse_output=True. Otherwise,
this will return a tensor of shape (2, *shape) with the first channel being
the high bits and the second channel being the low bits.
"""
shape = tuple(shape)
if np.prod(shape) > jnp.iinfo(jnp.uint32).max:
raise ValueError(
f"Shape too large: {np.prod(shape)} > {np.iinfo(jnp.uint32).max}")
if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0):
raise ValueError(
f"Shape dimension {shape[-2:]} must be divisible by {block_size}")
grid_dims = shape[:-2] + (
shape[-2] // block_size[-2], shape[-1] // block_size[1],)
offset = jnp.array(offset, dtype=jnp.uint32)
if offset.ndim != 0:
raise ValueError(f"Offset must be scalar, got {offset.shape}")
offset = jnp.reshape(offset, (1,))
def kernel(offset_ref, key_ref, out_ref):
counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
offset = prng_utils.compute_scalar_offset(
counts_idx, unpadded_shape, block_shape)
counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
counts_lo = counts_lo + offset.astype(jnp.uint32) + offset_ref[0]
counts_lo = counts_lo.astype(jnp.uint32)
# TODO(justinfu): Support hi bits on count.
_zeros = jnp.zeros_like(counts_lo)
k1 = jnp.reshape(key_ref[0, 0], (1, 1))
k2 = jnp.reshape(key_ref[0, 1], (1, 1))
o1, o2, _, _ = philox_4x32(_zeros, counts_lo, _zeros, _zeros, k1, k2)
if fuse_output:
out_bits = o1 ^ o2
out_ref[...] = out_bits.reshape(out_ref.shape)
else:
out_ref[0, ...] = o1.reshape(out_ref[0].shape)
out_ref[1, ...] = o2.reshape(out_ref[0].shape)
key = key.reshape((1, 2))
block_shape = (1,) * (len(shape)-2) + block_size
if fuse_output:
out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32)
out_spec = pl.BlockSpec(block_shape, lambda *idxs: idxs)
else:
out = jax.ShapeDtypeStruct((2,) + shape, dtype=jnp.uint32)
out_spec = pl.BlockSpec((2,) + block_shape, lambda *idxs: (0, *idxs))
return pl.pallas_call(
kernel,
in_specs=[
pl.BlockSpec(memory_space=pltpu.SMEM),
pl.BlockSpec(memory_space=pltpu.SMEM),
],
out_specs=out_spec,
grid=grid_dims,
out_shape=out,
)(offset, key)
def philox_4x32_count(key,
shape: Shape,
offset: typing.ArrayLike = 0,
fuse_output: bool = True):
"""Convenience function to call philox_4x32_kernel with padded shapes."""
if len(shape) == 0:
return philox_4x32_count(
key, (1, 1), offset=offset, fuse_output=fuse_output)[..., 0, 0]
elif len(shape) == 1:
return philox_4x32_count(
key, (1, *shape), offset=offset, fuse_output=fuse_output)[..., 0, :]
requires_pad = (
shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0)
if requires_pad:
padded_shape = tuple(shape[:-2]) + (
prng_utils.round_up(shape[-2], BLOCK_SIZE[-2]),
prng_utils.round_up(shape[-1], BLOCK_SIZE[-1]),
)
padded_result = philox_4x32_kernel(
key, padded_shape, shape,
block_size=BLOCK_SIZE, offset=offset,
fuse_output=fuse_output)
return padded_result[..., :shape[-2], :shape[-1]]
else:
return philox_4x32_kernel(key, shape, shape,
block_size=BLOCK_SIZE, offset=offset,
fuse_output=fuse_output)
def philox_split(key, shape: Shape):
"""Splits the key into two keys of the same shape."""
bits1, bits2 = philox_4x32_count(key, shape, fuse_output=False)
return jnp.stack([bits1, bits2], axis=bits1.ndim)
def philox_random_bits(key, bit_width: int, shape: Shape):
if bit_width != 32:
raise ValueError("Only 32-bit PRNG supported.")
return philox_4x32_count(key, shape, fuse_output=True)
def philox_fold_in(key, data):
assert data.ndim == 0
return philox_4x32_count(key, (), offset=data, fuse_output=False)
plphilox_prng_impl = prng.PRNGImpl(
key_shape=(2,),
seed=prng.threefry_seed,
split=philox_split,
random_bits=philox_random_bits,
fold_in=philox_fold_in,
name="pallas_philox4x32",
tag="pllox")
prng.register_prng(plphilox_prng_impl)
@@ -0,0 +1,55 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for PRNG kernels."""
from collections.abc import Sequence
from jax import lax
import jax.numpy as jnp
Shape = Sequence[int]
round_up = lambda x, y: (x + y - 1) // y * y
def blocked_iota(block_shape: Shape,
total_shape: Shape):
"""Computes a sub-block of a larger shaped iota.
Args:
block_shape: The output block shape of the iota.
total_shape: The total shape of the input tensor.
Returns:
Result of the blocked iota.
"""
iota_data = jnp.zeros(block_shape, dtype=jnp.uint32)
multiplier = 1
for dim in range(len(block_shape)-1, -1, -1):
block_mult = 1
counts_lo = lax.broadcasted_iota(
dtype=jnp.uint32, shape=block_shape, dimension=dim
)
iota_data += counts_lo * multiplier * block_mult
multiplier *= total_shape[dim]
return iota_data
def compute_scalar_offset(iteration_index,
                          total_size: Shape,
                          block_size: Shape):
  """Returns the row-major linear offset of the block at `iteration_index`."""
ndims = len(iteration_index)
dim_size = 1
total_idx = 0
for i in range(ndims-1, -1, -1):
dim_idx = iteration_index[i] * block_size[i]
total_idx += dim_idx * dim_size
dim_size *= total_size[i]
return total_idx
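# Editor's sketch: for a (1024, 1024) iota tiled into (256, 256) blocks, the
# block at grid position (1, 2) starts at row-major linear index
# 1 * 256 * 1024 + 2 * 256.
def _example_compute_scalar_offset():
  offset = compute_scalar_offset((1, 2), (1024, 1024), (256, 256))
  assert offset == 1 * 256 * 1024 + 2 * 256
  return offset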
@@ -0,0 +1,120 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the Threefry PRNG as a Pallas kernel."""
from collections.abc import Sequence
import jax
from jax._src import prng
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas.ops.tpu.random import prng_utils
Shape = Sequence[int]
BLOCK_SIZE = (256, 256)
def threefry_2x32_count(key,
shape: Shape,
unpadded_shape: Shape,
block_size: tuple[int, int]):
"""Generates random bits using the Threefry hash function.
This function is a fusion of prng.shaped_iota and prng.threefry_2x32 from
the JAX core library.
Args:
key: A threefry key of shape (2,).
shape: The shape of the output. Must be divisible by `block_size`.
unpadded_shape: If `shape` is padded, then this is the shape of the
output tensor if it were not padded. This is important for indexing
calculations within the kernel. If `shape` is not padded, then this
should be equal to `shape`.
block_size: The block size of the kernel.
Returns:
A tensor of random bits of shape `shape`.
"""
shape = tuple(shape)
if np.prod(shape) > jnp.iinfo(jnp.uint32).max:
raise ValueError(
f"Shape too large: {np.prod(shape)} > {np.iinfo(jnp.uint32).max}")
if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0):
raise ValueError(
f"Shape dimension {shape[-2:]} must be divisible by {block_size}")
grid_dims = shape[:-2] + (
shape[-2] // block_size[-2], shape[-1] // block_size[1],)
def kernel(key_ref, out_ref):
counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
offset = prng_utils.compute_scalar_offset(
counts_idx, unpadded_shape, block_shape)
counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
counts_lo = counts_lo + offset.astype(jnp.uint32)
counts_lo = counts_lo.astype(jnp.uint32)
# TODO(justinfu): Support hi bits on count.
counts_hi = jnp.zeros_like(counts_lo)
k1 = jnp.reshape(key_ref[0, 0], (1, 1))
k2 = jnp.reshape(key_ref[0, 1], (1, 1))
o1, o2 = prng.threefry2x32_p.bind(
k1, k2, counts_hi, counts_lo)
out_bits = o1 ^ o2
out_ref[...] = out_bits.reshape(out_ref.shape)
key = key.reshape((1, 2))
out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32)
block_shape = (1,) * (len(shape)-2) + block_size
result = pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
out_specs=pl.BlockSpec(block_shape, lambda *idxs: idxs),
grid=grid_dims,
out_shape=out,
)(key)
return result
def plthreefry_random_bits(key, bit_width: int, shape: Shape):
if bit_width != 32:
raise ValueError("Only 32-bit PRNG supported.")
if len(shape) == 0:
return plthreefry_random_bits(key, bit_width, (1, 1))[0, 0]
elif len(shape) == 1:
return plthreefry_random_bits(key, bit_width, (1, *shape))[0]
requires_pad = (
shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0)
if requires_pad:
padded_shape = tuple(shape[:-2]) + (
prng_utils.round_up(shape[-2], BLOCK_SIZE[-2]),
prng_utils.round_up(shape[-1], BLOCK_SIZE[-1]),
)
padded_result = threefry_2x32_count(
key, padded_shape, shape, block_size=BLOCK_SIZE)
return padded_result[..., :shape[-2], :shape[-1]]
else:
return threefry_2x32_count(key, shape, shape, block_size=BLOCK_SIZE)
plthreefry_prng_impl = prng.PRNGImpl(
key_shape=(2,),
seed=prng.threefry_seed,
split=prng.threefry_split,
random_bits=plthreefry_random_bits,
fold_in=prng.threefry_fold_in,
name="pallas_threefry2x32",
tag="plfry")
prng.register_prng(plthreefry_prng_impl)
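# Editor's usage sketch: drawing random bits with the Pallas-backed Threefry
# implementation registered above. This assumes a TPU backend (the generator is
# a Pallas TPU kernel) and that registered impl names are accepted by
# jax.random.key.
def _example_pallas_threefry_usage():
  key = jax.random.key(0, impl="pallas_threefry2x32")
  return jax.random.bits(key, (1024, 1024), dtype=jnp.uint32)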
@@ -0,0 +1,32 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes as BlockSizes
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mha_reference as make_masked_mha_reference
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mqa_reference as make_masked_mqa_reference
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha as make_splash_mha
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha_single_device as make_splash_mha_single_device
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa as make_splash_mqa
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa_single_device as make_splash_mqa_single_device
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import QKVLayout as QKVLayout
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import SegmentIds as SegmentIds
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import CausalMask as CausalMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import FullMask as FullMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import LocalMask as LocalMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_causal_mask as make_causal_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_local_attention_mask as make_local_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_random_mask as make_random_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import Mask as Mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import MultiHeadMask as MultiHeadMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import NumpyMask as NumpyMask
@@ -0,0 +1,560 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mini-mask creation library."""
from __future__ import annotations
from collections.abc import Callable, Sequence
import dataclasses
from typing import Any
import numpy as np
class Mask:
"""A base class for splash attention masks."""
@property
def shape(self) -> tuple[int, ...]:
raise NotImplementedError
def __getitem__(self, idx) -> np.ndarray:
raise NotImplementedError
def __bool__(self) -> bool:
raise NotImplementedError(
'Conversion to bool is unsupported. Could be caused by using logical'
' instead of bitwise operations on masks.'
)
def __or__(self, other: Mask) -> Mask:
if self.shape != other.shape:
raise ValueError(
f'Invalid shape for other: {other.shape}, expected: {self.shape}'
)
return LogicalOr(self, other)
def __and__(self, other: Mask) -> Mask:
if self.shape != other.shape:
raise ValueError(
f'Invalid shape for other: {other.shape}, expected: {self.shape}'
)
return LogicalAnd(self, other)
def make_causal_mask(shape: tuple[int, int], offset: int = 0) -> np.ndarray:
"""Makes a causal attention mask.
Args:
shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len).
offset: Offset of q start wrt kv. A positive offset shifts the bottom
triangle upward, a negative one shifts it downward. A negative offset
makes the first 'offset' rows of the attention matrix all 0s which leads
to undefined softmax.
Returns:
The causal mask.
"""
q_seq_len, kv_seq_len = shape
q_idx = np.arange(q_seq_len, dtype=np.int32)
kv_idx = np.arange(kv_seq_len, dtype=np.int32)
return (q_idx[:, None] + offset >= kv_idx[None, :]).astype(np.bool_)
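# Editor's sketch: with offset=-1 the first query row attends to nothing,
# illustrating the warning above about undefined softmax for negative offsets.
def _example_causal_mask_offset():
  mask = make_causal_mask((3, 3), offset=-1)
  # mask == [[False, False, False],
  #          [ True, False, False],
  #          [ True,  True, False]]
  return mask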
def make_local_attention_mask(
shape: tuple[int, int],
window_size: tuple[int | None, int | None],
*,
offset: int = 0,
) -> np.ndarray:
"""Makes a local attention mask."""
q_seq_len, kv_seq_len = shape
q_idx = np.arange(q_seq_len, dtype=np.int32)
kv_idx = np.arange(kv_seq_len, dtype=np.int32)
mask = np.ones((q_seq_len, kv_seq_len), dtype=np.bool_)
left, right = window_size
if left is not None:
mask = mask & (q_idx[:, None] - left + offset <= kv_idx[None, :])
if right is not None:
mask = mask & (q_idx[:, None] + right + offset >= kv_idx[None, :])
return mask.astype(np.bool_)
def make_chunk_attention_mask(
shape: tuple[int, int], chunk_size: int
) -> np.ndarray:
"""Makes a chunked causal attention mask.
Args:
shape: The desired shape of the mask (q_seq_len, kv_seq_len).
chunk_size: The size of the attention chunks.
  Returns:
    A boolean mask of shape `shape` where True indicates attention is allowed
    according to chunked causal rules, and False otherwise.
  Raises:
    ValueError: If chunk_size is not positive.
"""
if chunk_size <= 0:
raise ValueError('chunk_size must be positive')
q_seq_len, kv_seq_len = shape
q_idx = np.arange(q_seq_len, dtype=np.int32)
kv_idx = np.arange(kv_seq_len, dtype=np.int32)
# chunk mask calculation
same_chunk = (q_idx[:, None] // chunk_size) == (kv_idx[None, :] // chunk_size)
mask = same_chunk & (q_idx[:, None] >= kv_idx[None, :])
return mask
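# Editor's sketch: with chunk_size=3, tokens attend causally only within their
# own chunk, so the 6x6 mask is two disjoint causal triangles (shown as 0/1).
def _example_chunk_attention_mask():
  mask = make_chunk_attention_mask((6, 6), chunk_size=3)
  # mask == [[1, 0, 0, 0, 0, 0],
  #          [1, 1, 0, 0, 0, 0],
  #          [1, 1, 1, 0, 0, 0],
  #          [0, 0, 0, 1, 0, 0],
  #          [0, 0, 0, 1, 1, 0],
  #          [0, 0, 0, 1, 1, 1]]
  return mask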
def make_random_mask(
shape: tuple[int, int], sparsity: float, seed: int
) -> np.ndarray:
"""Makes a random attention mask."""
np.random.seed(seed)
return np.random.binomial(n=1, p=1.0 - sparsity, size=shape).astype(np.bool_)
@dataclasses.dataclass
class LogicalOr(Mask):
left: Mask
right: Mask
def __init__(self, left: Mask, right: Mask):
if left.shape != right.shape:
raise ValueError('Masks must have the same shape')
self.left = left
self.right = right
@property
def shape(self) -> tuple[int, ...]:
return self.left.shape
def __getitem__(self, idx) -> np.ndarray:
return self.left[idx] | self.right[idx]
def __hash__(self):
return hash((type(self),) + (self.left, self.right))
@dataclasses.dataclass
class LogicalAnd(Mask):
left: Mask
right: Mask
def __init__(self, left: Mask, right: Mask):
if left.shape != right.shape:
raise ValueError('Masks must have the same shape')
self.left = left
self.right = right
@property
def shape(self) -> tuple[int, ...]:
return self.left.shape
def __getitem__(self, idx) -> np.ndarray:
return self.left[idx] & self.right[idx]
def __hash__(self):
return hash((type(self),) + (self.left, self.right))
@dataclasses.dataclass
class MultiHeadMask(Mask):
"""Lazy multihead mask, combines multiple lazy masks one per head."""
masks: Sequence[Mask]
def __post_init__(self):
if not self.masks:
raise ValueError('Unsupported empty tuple of masks')
shape = self.masks[0].shape
for mask in self.masks[1:]:
if shape != mask.shape:
raise ValueError(
f'Unexpected mask shape, got: {mask.shape}, expected: {shape}'
)
if not all(isinstance(mask, Mask) for mask in self.masks):
raise ValueError('masks should be of type Mask')
if any(isinstance(mask, MultiHeadMask) for mask in self.masks):
raise ValueError('Nesting MultiHeadMasks is not supported')
@property
def shape(self) -> tuple[int, ...]:
return (len(self.masks),) + self.masks[0].shape
def __getitem__(self, idx) -> np.ndarray:
if len(idx) != 3:
raise NotImplementedError(f'Unsupported slice: {idx}')
head_slice = idx[0]
if isinstance(head_slice, int):
      assert 0 <= head_slice < len(self.masks)
return self.masks[head_slice][idx[1:]]
else:
slice_masks = [mask[idx[1:]] for mask in self.masks[head_slice]]
return np.stack(slice_masks)
def __eq__(self, other: object):
if not isinstance(other, type(self)):
return NotImplemented
return self.masks == other.masks
def __hash__(self):
return hash((type(self),) + tuple(hash(mask) for mask in self.masks))
class _ComputableMask(Mask):
"""Superclass for all masks that can be computed inside the kernel using a callable object.
This subclass is designed to be used with Splash Attention.
It allows the mask logic to be computed on-the-fly or fused into the attention
kernel, avoiding the memory cost of materializing the full
(sequence_length, sequence_length) boolean mask array, which can be excessive
for long sequences.
Attributes:
_shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len).
offset: Offset of q start with respect to kv. A positive offset shifts the
bottom triangle upward; a negative one shifts it downward. A negative offset
makes the first `abs(offset)` rows of the attention matrix all zeros, which
leads to an undefined softmax.
q_sequence: Indices of the Q sequence. q_sequence is reused across
__getitem__ calls, which is important for compile-time performance.
mask_function: Callable used by the SplashAttention kernel to compute the
mask on the fly rather than loading it from memory.
"""
_shape: tuple[int, int]
q_sequence: np.ndarray
mask_function: Callable[..., Any]
def __init__(
self,
shape: tuple[int, int],
mask_function: Callable[..., Any],
shard_count: int = 1,
):
self._shape = shape
self.mask_function = mask_function
q_seq_len = self.shape[0]
if q_seq_len % (shard_count * shard_count) != 0:
raise ValueError(
f'Shard count squared ({shard_count * shard_count}) must'
f' divide Q seq_len ({self.shape[0]}) evenly.'
)
self.q_sequence = np.arange(q_seq_len, dtype=np.int32)
@property
def shape(self) -> tuple[int, ...]:
return self._shape
def __getitem__(self, idx) -> np.ndarray:
if len(idx) != 2:
raise NotImplementedError(f'Unsupported slice: {idx}')
q_slice, kv_slice = idx
if not isinstance(q_slice, slice) or not isinstance(kv_slice, slice):
raise NotImplementedError(f'Unsupported slice: {idx}')
q_slice = _fill_slice(q_slice, self.shape[0])
kv_slice = _fill_slice(kv_slice, self.shape[1])
rows = self.q_sequence[q_slice]
cols = np.arange(kv_slice.start, kv_slice.stop)
return self.mask_function(rows[:, None], cols[None, :])
def __eq__(self, other: object):
raise NotImplementedError()
def __hash__(self):
raise NotImplementedError()
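# Illustrative, commented-out sketch (not part of the module): any callable
# over broadcastable (q_ids, kv_ids) index grids can back a _ComputableMask.
# The hypothetical subclass below allows causal attention to even kv
# positions only.
#
#   class _EvenKVCausalMask(_ComputableMask):
#     def __init__(self, shape):
#       super().__init__(
#           shape=shape,
#           mask_function=lambda q_ids, kv_ids: (
#               (q_ids >= kv_ids) & (kv_ids % 2 == 0)),
#       )
#
#   tile = _EvenKVCausalMask((4, 4))[0:4, 0:4]  # boolean (4, 4) tile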
class CausalMask(_ComputableMask):
"""Lazy causal mask, prevents the model from attending to future tokens.
Attributes:
offset: Offset of q start with respect to kv. A positive offset shifts the
bottom triangle upward; a negative one shifts it downward. A negative offset
makes the first `abs(offset)` rows of the attention matrix all zeros, which
leads to an undefined softmax.
"""
offset: int
def __init__(
self,
shape: tuple[int, int],
offset: int = 0,
shard_count: int = 1,
):
self.offset = offset
def causal_mask_function(q_ids, kv_ids):
# When evaluating the mask in _process_mask we typically work with numpy
# array views; skip the addition when offset == 0 to avoid materializing a
# new array.
if self.offset == 0:
return q_ids >= kv_ids
else:
return q_ids + self.offset >= kv_ids
mask_function = causal_mask_function
super().__init__(
shape=shape,
mask_function=mask_function,
shard_count=shard_count,
)
def __eq__(self, other: object):
if not isinstance(other, type(self)):
return NotImplemented
return (
self.shape == other.shape
and self.offset == other.offset
and np.array_equal(self.q_sequence, other.q_sequence)
)
def __hash__(self):
return hash((
type(self),
self.shape,
self.offset,
self.q_sequence.tobytes() if self.q_sequence is not None else None,
))
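# Illustrative, commented-out usage (not part of the module): with offset=0,
# row i attends to columns 0..i; a negative offset shifts the allowed
# triangle downward.
#
#   tile = CausalMask(shape=(4, 4))[0:4, 0:4]
#   # tile[1] is [True, True, False, False]
#   shifted = CausalMask(shape=(4, 4), offset=-1)[0:4, 0:4]
#   # shifted[0] is all False; row 1 attends only to column 0.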
class ChunkedCausalMask(_ComputableMask):
"""Lazy chunked causal mask.
Attention is causal within each chunk: tokens in chunks (0, K), (K, 2K),
(2K, 3K), ... attend to each other, but never across chunk boundaries.
Llama4 models interleave this chunked attention with global attention.
Attributes:
chunk_size: The size of each attention chunk.
"""
chunk_size: int
def __init__(
self,
shape: tuple[int, int],
chunk_size: int,
shard_count: int = 1,
):
if chunk_size <= 0:
raise ValueError('chunk_size must be positive')
self.chunk_size = chunk_size
# Define the mask function for chunk attention
def chunked_causal_mask_function(q_ids, kv_ids):
"""Computes the mask logic for the given slice indices."""
# Condition 1: Same chunk
same_chunk = (q_ids // self.chunk_size) == (kv_ids // self.chunk_size)
# Condition 2: Causal
causal = q_ids >= kv_ids
return same_chunk & causal
super().__init__(
shape=shape,
mask_function=chunked_causal_mask_function,
shard_count=shard_count,
)
def __eq__(self, other: object):
if not isinstance(other, type(self)):
return NotImplemented
return (
self.shape == other.shape
and self.chunk_size == other.chunk_size
and np.array_equal(self.q_sequence, other.q_sequence)
)
def __hash__(self):
return hash((
type(self),
self.shape,
self.chunk_size,
self.q_sequence.tobytes() if self.q_sequence is not None else None,
))
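# Illustrative, commented-out usage (not part of the module): with
# chunk_size=2, token 2 opens a new chunk and does not attend to tokens 0-1.
#
#   tile = ChunkedCausalMask(shape=(4, 4), chunk_size=2)[0:4, 0:4]
#   # tile[1] is [True, True, False, False]
#   # tile[2] is [False, False, True, False]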
class LocalMask(_ComputableMask):
"""Lazy local mask, prevents model from attending to tokens outside window.
Attributes:
window_size: Size of the two sides of the local window (None identifies no
limit for the given side).
offset: Offset of q start wrt kv. A positive offset shifts the bottom
triangle upward, a negative one shifts it downward. A negative offset
makes the first 'offset' rows of the attention matrix all 0s which leads
to undefined softmax.
"""
window_size: tuple[int | None, int | None]
offset: int
def __init__(
self,
shape: tuple[int, int],
window_size: tuple[int | None, int | None],
offset: int,
shard_count: int = 1,
):
self.window_size = window_size
self.offset = offset
def local_mask_function(q_ids, kv_ids):
"""Computes the local attention mask for the given slice indices."""
left_size, right_size = self.window_size
assert q_ids.ndim == 2
assert kv_ids.ndim == 2
if left_size is None and right_size is None:
return np.ones((q_ids.shape[0], kv_ids.shape[1]), dtype=np.bool_)
# Skip the addition when offset == 0 to avoid materializing a new array.
if self.offset != 0:
shifted_q_ids = q_ids + self.offset
else:
shifted_q_ids = q_ids
mask = None
if left_size is not None:
mask = shifted_q_ids - left_size <= kv_ids
if right_size is not None:
if mask is None:
mask = shifted_q_ids + right_size >= kv_ids
else:
mask &= shifted_q_ids + right_size >= kv_ids
return mask
super().__init__(
shape=shape,
mask_function=local_mask_function,
shard_count=shard_count,
)
def __eq__(self, other: object):
if not isinstance(other, type(self)):
return NotImplemented
return (
self.shape == other.shape
and self.window_size == other.window_size
and self.offset == other.offset
and np.array_equal(self.q_sequence, other.q_sequence)
)
def __hash__(self):
return hash((
type(self),
self.shape,
self.window_size,
self.offset,
self.q_sequence.tobytes() if self.q_sequence is not None else None,
))
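# Illustrative, commented-out usage (not part of the module): a sliding window
# of one token on each side of the query position.
#
#   tile = LocalMask(shape=(4, 4), window_size=(1, 1), offset=0)[0:4, 0:4]
#   # tile[2] is [False, True, True, True]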
@dataclasses.dataclass
class NumpyMask(Mask):
"""A mask backed by a dense numpy array."""
array: np.ndarray
def __post_init__(self):
if self.array.ndim != 2:
raise ValueError('Expected a 2-dim array')
if self.array.dtype != np.bool_:
raise ValueError('Mask must be a boolean array')
@property
def shape(self) -> tuple[int, ...]:
return self.array.shape
def __getitem__(self, idx) -> np.ndarray:
return self.array[idx]
def __eq__(self, other: object):
if not isinstance(other, type(self)):
return NotImplemented
return np.array_equal(self.array, other.array, equal_nan=True)
def __hash__(self):
return hash((type(self), self.array.tobytes()))
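# Illustrative, commented-out usage (not part of the module): wrap a dense
# boolean array, e.g. one produced by make_random_mask, behind the Mask
# interface.
#
#   dense = NumpyMask(make_random_mask((8, 8), sparsity=0.8, seed=0))
#   tile = dense[0:4, 0:4]  # plain numpy slicing under the hood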
def _fill_slice(inp_slice: slice, size: int) -> slice:
assert inp_slice.step is None or inp_slice.step == 1
start = 0 if inp_slice.start is None else inp_slice.start
stop = size if inp_slice.stop is None else inp_slice.stop
assert start >= 0
assert stop <= size
return slice(start, stop, None)
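# Illustrative, commented-out behaviour (not part of the module): open-ended
# slices are normalized against the mask size before index arithmetic.
#
#   _fill_slice(slice(None, None), 16) == slice(0, 16, None)
#   _fill_slice(slice(4, None), 16) == slice(4, 16, None)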
@dataclasses.dataclass(frozen=True)
class FullMask(Mask):
"""Lazy full mask, allows all tokens to attend to all other tokens."""
# TODO(amagni): Transform FullMask into a _ComputableMask.
_shape: tuple[int, int]
def __post_init__(self):
if not isinstance(self.shape, tuple):
raise ValueError(f'Unsupported shape type: {type(self.shape)}')
@property
def shape(self) -> tuple[int, ...]:
return self._shape
def __getitem__(self, idx) -> np.ndarray:
if len(idx) != 2:
raise NotImplementedError(f'Unsupported slice: {idx}')
i, j = idx
if not isinstance(i, slice) or not isinstance(j, slice):
raise NotImplementedError(f'Unsupported slice: {idx}')
i = _fill_slice(i, self.shape[0])
j = _fill_slice(j, self.shape[1])
return np.ones((i.stop - i.start, j.stop - j.start), dtype=np.bool_)
def __eq__(self, other: object):
if not isinstance(other, type(self)):
return NotImplemented
return self.shape == other.shape
def __hash__(self):
return hash((type(self), self.shape))
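# Illustrative, commented-out usage (not part of the module): every tile of a
# FullMask is all True.
#
#   tile = FullMask((8, 8))[0:4, 4:8]
#   # tile.shape == (4, 4) and tile.all()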