@@ -0,0 +1,13 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,155 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple all-gather kernel.

This is meant to be a pedagogical example of how to write a custom collective
using Pallas. It doesn't have all possible performance optimizations and
doesn't currently handle more diverse topologies.

The kernel assumes a ring structure on a single mesh axis. It takes the local
chunk, splits it in two, and sends each of the half-chunks in each direction
(left and right) until every device has received the half chunks.
"""
from __future__ import annotations
import functools

from collections.abc import Sequence

import jax
from jax import lax
from jax.experimental import pallas as pl
from jax._src import shard_map
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp


P = jax.sharding.PartitionSpec


def get_neighbor(
    idx: jax.Array, mesh: jax.sharding.Mesh, axis_name: str, *, direction: str
) -> tuple[jax.Array, ...]:
  """Helper function that computes the mesh indices of a neighbor."""
  axis_names = mesh.axis_names
  which_axis = axis_names.index(axis_name)
  mesh_index = [
      idx if i == which_axis else lax.axis_index(a)
      for i, a in enumerate(axis_names)
  ]
  axis_size = lax.axis_size(axis_name)
  if direction == "right":
    next_idx = lax.rem(idx + 1, axis_size)
  else:
    left = idx - 1
    next_idx = jnp.where(left < 0, left + axis_size, left)
  mesh_index[which_axis] = next_idx
  return tuple(mesh_index)


def ag_kernel(x_ref, o_ref, send_sem, recv_sem, *, axis_name: str,
              mesh: jax.sharding.Mesh):
  my_id = lax.axis_index(axis_name)
  # TODO(sharadmv): could speed this up by having the first remote DMA go from
  # x_ref->o_ref immediately instead of a blocking HBM copy.
  with jax.named_scope("initial_copy"):
    pltpu.async_copy(x_ref, o_ref.at[my_id], recv_sem[0]).wait()

  with jax.named_scope("neighbour_lookup"):
    axis_size = lax.axis_size(axis_name)
    left_neighbor = get_neighbor(my_id, mesh, axis_name, direction="left")
    right_neighbor = get_neighbor(my_id, mesh, axis_name, direction="right")

  with jax.named_scope("main_barrier"):
    sem = pltpu.get_barrier_semaphore()
    pl.semaphore_signal(sem, 1, device_id=left_neighbor)
    pl.semaphore_signal(sem, 1, device_id=right_neighbor)
    pl.semaphore_wait(sem, 2)

  shard_size = x_ref.shape[0]
  right_dma, left_dma = None, None
  # Main strategy for this AG: carve up our input into two slices. Send
  # each slice along each direction until they reach every device.
  for i in range(axis_size - 1):
    right_slot = my_id - i
    right_slice = pl.ds(shard_size // 2, shard_size // 2)
    slot = jnp.where(right_slot < 0, axis_size + right_slot, right_slot)
    if right_dma:
      with jax.named_scope("wait_right_dma"):
        right_dma.wait()
    right_dma = pltpu.async_remote_copy(
        o_ref.at[slot, right_slice],
        o_ref.at[slot, right_slice],
        send_sem[1],
        recv_sem[1],
        device_id=right_neighbor,
    )

    left_slot = my_id + i
    left_slice = pl.ds(0, shard_size // 2)
    slot = lax.rem(left_slot, axis_size)
    if left_dma:
      with jax.named_scope("wait_left_dma"):
        left_dma.wait()
    left_dma = pltpu.async_remote_copy(
        o_ref.at[slot, left_slice],
        o_ref.at[slot, left_slice],
        send_sem[0],
        recv_sem[0],
        device_id=left_neighbor,
    )
  with jax.named_scope("wait_all_dma"):
    assert right_dma is not None
    assert left_dma is not None
    right_dma.wait()
    left_dma.wait()


@functools.partial(
    jax.jit, static_argnames=["mesh", "axis_name", "memory_space"]
)
def all_gather(x, *, mesh: jax.sharding.Mesh, axis_name: str | Sequence[str],
               memory_space: pltpu.MemorySpace = pltpu.VMEM):
  if isinstance(axis_name, str):
    axis_name = (axis_name,)
  # TODO(sharadmv): enable all gather over multiple axes
  if len(axis_name) > 1:
    raise NotImplementedError("Only one axis supported.")
  axis_name, = axis_name
  if mesh.shape[axis_name] == 1:
    # We can short-circuit here if our axis size is 1
    return x
  def ag_local(x_shard):
    axis_size = lax.axis_size(axis_name)
    out_shape = jax.ShapeDtypeStruct((axis_size, *x_shard.shape), x_shard.dtype)
    out = pl.pallas_call(
        functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh),
        out_shape=out_shape,
        compiler_params=pltpu.CompilerParams(collective_id=0),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            scratch_shapes=(
                (pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA),
                (pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.DMA),
            ),
            in_specs=[pl.BlockSpec(memory_space=memory_space)],
            out_specs=pl.BlockSpec(memory_space=memory_space),
        ),
    )(x_shard)
    return out.reshape((axis_size * x_shard.shape[0], *x_shard.shape[1:]))

  return shard_map.shard_map(
      ag_local, mesh=mesh, in_specs=P(axis_name), out_specs=P(None),
      check_vma=False
  )(x)
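A minimal usage sketch for the kernel above (not part of the commit; the device count, array shape, and "x" axis name are illustrative assumptions). Note the shard's leading dimension must be even, since the kernel sends half-chunks in each direction:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumes this runs on a TPU slice whose devices form a ring on one axis.
mesh = Mesh(jax.devices(), axis_names=("x",))
n_dev = len(jax.devices())
x = jax.device_put(
    jnp.arange(8 * n_dev * 128, dtype=jnp.float32).reshape(8 * n_dev, 128),
    NamedSharding(mesh, PartitionSpec("x")),
)
# Each device holds an (8, 128) shard; all_gather returns the full
# (8 * n_dev, 128) array, replicated across the "x" axis.
y = all_gather(x, mesh=mesh, axis_name="x")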
@@ -0,0 +1,25 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Kernels used for testing pallas_call."""

from jax.experimental import pallas as pl


def double_kernel(x_ref, y_ref):
  y_ref[:] = x_ref[:] * 2


def double(x):
  return pl.pallas_call(double_kernel, out_shape=x)(x)
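As a quick illustration (not part of the commit), the double helper can be exercised directly; passing out_shape=x reuses the input array's shape and dtype for the output:

import jax.numpy as jnp

x = jnp.ones((8, 128), jnp.float32)
assert (double(x) == 2.0).all()  # runs the trivial Pallas kernel once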
+1715
File diff suppressed because it is too large
@@ -0,0 +1,84 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example matmul TPU kernel.

See discussion in https://docs.jax.dev/en/latest/pallas/tpu/matmul.html.
"""

import functools

import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp


def matmul_kernel(x_tile_ref, y_tile_ref, o_tile_ref, acc_ref):
  @pl.when(pl.program_id(2) == 0)
  def init():
    acc_ref[...] = jnp.zeros_like(acc_ref)

  acc_ref[...] = acc_ref[...] + jnp.dot(
      x_tile_ref[...],
      y_tile_ref[...],
      preferred_element_type=acc_ref.dtype,
  )
  # It is possible to make this conditional, but in general this bundle packs
  # quite well for a simple matmul kernel.
  o_tile_ref[...] = acc_ref[...].astype(o_tile_ref.dtype)


@functools.partial(
    jax.jit, static_argnames=["block_shape", "block_k", "debug", "out_dtype"]
)
def matmul(
    x: jax.Array,
    y: jax.Array,
    *,
    block_shape,
    block_k: int = 256,
    out_dtype: jnp.dtype | None = None,
    debug: bool = False,
) -> jax.Array:
  if out_dtype is None:
    if x.dtype != y.dtype:
      # TODO(tlongeri): Maybe we could use a deduction similar to jnp.dot
      raise TypeError(
          f"Cannot deduce output dtype for different input dtypes: {x.dtype},"
          f" {y.dtype}"
      )
    out_dtype = x.dtype
  acc_dtype = jnp.float32
  if x.dtype in [jnp.int8, jnp.int4, jnp.uint8, jnp.uint4]:
    acc_dtype = jnp.int32

  l, r = block_shape
  return pl.pallas_call(
      matmul_kernel,
      out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), out_dtype),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=0,
          in_specs=[
              pl.BlockSpec((l, block_k), lambda i, _, k: (i, k)),
              pl.BlockSpec((block_k, r), lambda _, j, k: (k, j)),
          ],
          out_specs=pl.BlockSpec((l, r), lambda i, j, k: (i, j)),
          grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k),
          scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)],
      ),
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "parallel", "arbitrary")),
      debug=debug,
  )(x, y)
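A minimal invocation sketch (not part of the commit; the shapes are assumptions chosen to divide evenly by the block sizes):

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (1024, 512), jnp.float32)
y = jax.random.normal(jax.random.PRNGKey(1), (512, 1024), jnp.float32)
# Grid = (1024 // 256, 1024 // 256, 512 // 256) = (4, 4, 2); the third
# ("arbitrary") dimension accumulates partial products into acc_ref.
z = matmul(x, y, block_shape=(256, 256), block_k=256)
assert jnp.allclose(z, x @ y, atol=1e-2)  # float32 accumulation tolerance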
@@ -0,0 +1,15 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax.experimental.pallas.ops.tpu.megablox.ops import gmm as gmm
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,65 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common utilities for GMM kernels."""

import re

import jax
import jax.numpy as jnp


def is_tpu() -> bool:
  return "TPU" in jax.devices()[0].device_kind


def tpu_kind() -> str:
  """Query identification string for the currently attached TPU."""
  return jax.devices()[0].device_kind


# Most TPU devices follow the pattern "TPU v{version}{variant}", e.g. "TPU v5p".
# TPU v7 has a different pattern (i.e. "TPU7x").
_TPU_KIND_PATTERN = re.compile(r"TPU( v)?(\d+)")


def tpu_generation() -> int:
  """Generation number of the currently attached TPU."""
  if version := _TPU_KIND_PATTERN.match(tpu_kind()):
    return int(version[2])
  raise NotImplementedError("only TPU devices are supported")


def supports_bfloat16_matmul() -> bool:
  """Does the currently attached TPU support bfloat16 inputs?"""
  return not is_tpu() or tpu_generation() >= 4


def assert_is_supported_dtype(dtype: jnp.dtype) -> None:
  if dtype != jnp.bfloat16 and dtype != jnp.float32:
    raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.")


def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype:
  """A type to which both inputs should be adapted before the dot product."""
  # bf16xbf16 matmul is only supported since the TPU v4 generation. In case of
  # mixed input precision, we need to convert the bf16 argument to fp32
  # beforehand.
  if (
      supports_bfloat16_matmul()
      and lhs.dtype == jnp.bfloat16
      and rhs.dtype == jnp.bfloat16
  ):
    return jnp.bfloat16
  else:
    return jnp.float32
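To make the mixed-precision rule concrete, a small check (not part of the commit; the behavior shown assumes a TPU v4 or newer is attached):

import jax.numpy as jnp

bf16 = jnp.zeros((8, 8), jnp.bfloat16)
f32 = jnp.zeros((8, 8), jnp.float32)
assert select_input_dtype(bf16, bf16) == jnp.bfloat16  # kept on TPU v4+
assert select_input_dtype(bf16, f32) == jnp.float32    # mixed inputs upcast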
@@ -0,0 +1,793 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Grouped matrix multiplication kernels for TPU written in Pallas."""

from collections.abc import Callable
import functools
from typing import Any, Optional

import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.pallas.ops.tpu.megablox import common
import jax.numpy as jnp


partial = functools.partial


def _validate_args(
    *,
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    group_sizes: jnp.ndarray,
    expected_rhs_dims: int = 3,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]:
  """Validates the arguments for the gmm function."""
  # Validate 'lhs'.
  if lhs.ndim != 2:
    raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs.ndim}-tensor.")
  common.assert_is_supported_dtype(lhs.dtype)

  # Validate 'rhs'.
  if rhs.ndim != expected_rhs_dims:
    raise ValueError(
        f"Expected {expected_rhs_dims}-tensor for 'rhs' but got"
        f" {rhs.ndim}-tensor."
    )
  common.assert_is_supported_dtype(rhs.dtype)

  # Validate 'group_sizes'.
  if group_sizes.dtype != jnp.int32:
    raise ValueError(
        f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype}."
    )

  return lhs, group_sizes, common.select_input_dtype(lhs, rhs)


def _calculate_num_tiles(x: int, tx: int) -> int:
  tiles, rem = divmod(x, tx)
  if rem:
    raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).")
  return tiles


def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]:
  tiles, rem = divmod(x, tx)
  if rem:
    tiles += 1
  return tiles, rem


GroupMetadata = Any  # TODO(enriqueps): Clean this up and use a namedtuple


def make_group_metadata(
    *,
    group_sizes: jnp.ndarray,
    m: int,
    tm: int,
    start_group: jnp.ndarray,
    num_nonzero_groups: int,
    visit_empty_groups: bool = True,
) -> GroupMetadata:
  """Create the metadata needed for grouped matmul computation.

  Args:
    group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
    m: The number of rows in lhs.
    tm: The m-dimension tile size being used.
    start_group: The group in group sizes to start computing from. This is
      particularly useful for when rhs num_groups is sharded.
    num_nonzero_groups: Number of groups in group sizes to compute on. Useful
      in combination with group_offset.
    visit_empty_groups: If True, do not squeeze tiles for empty groups out of
      the metadata. This is necessary for tgmm, where we at least need to zero
      the output for each group.

  Returns:
    tuple of:
      group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32
        dtype. group_offsets[i] indicates the row at which group [i] starts in
        the lhs matrix and group_offsets[num_groups] = m.
      group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
        jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will
        work on.
      m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
        jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index
        'i' will work on.
    num_tiles: The number of m-dimension tiles to execute.
  """
  num_groups = group_sizes.shape[0]
  end_group = start_group + num_nonzero_groups - 1

  # Calculate the offset of each group, starting at zero. This metadata is
  # similar to row offsets in a CSR matrix. The following properties hold:
  #
  # group_offsets.shape = [num_groups + 1]
  # group_offsets[0] = 0
  # group_offsets[num_groups] = m
  #
  # The row at which group 'i' starts is group_offsets[i].
  group_ends = jnp.cumsum(group_sizes)
  group_offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends])

  # Assign a group id to each grid index.
  #
  # If a group starts somewhere other than the start of a tile or ends
  # somewhere other than the end of a tile we need to compute that full tile.
  # Calculate the number of tiles for each group by rounding their end up to
  # the nearest 'tm' and their start down to the nearest 'tm'.

  # (1) Round the group_ends up to the nearest multiple of 'tm'.
  #
  # NOTE: This does not change group_offsets[num_groups], which is m
  # (because we enforce m is divisible by tm).
  rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32)

  # (2) Round the group_starts down to the nearest multiple of 'tm'.
  group_starts = jnp.concatenate(
      [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]]
  )
  rounded_group_starts = group_starts // tm * tm

  # (3) Calculate the number of rows in each group.
  #
  # NOTE: Handle zero-sized groups as a special case. If the start for a
  # zero-sized group is not divisible by 'tm' its start will be rounded down
  # and its end will be rounded up such that its size will become 1 tile here.
  rounded_group_sizes = rounded_group_ends - rounded_group_starts
  rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes)

  # (4) Convert the group sizes from units of rows to units of 'tm' sized
  # tiles.
  #
  # An m-dimension tile is 'owned' by group 'i' if the first row of the tile
  # belongs to group 'i'. In addition to owned tiles, each group can have 0 or
  # 1 initial partial tiles if its first row does not occur in the first row
  # of a tile. The '0-th' group never has a partial tile because it always
  # starts at the 0-th row.
  #
  # If no group has a partial tile, the total number of tiles is equal to
  # 'm // tm'. If every group has a partial tile except the 0-th group, the
  # total number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know
  # that
  #
  # tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1
  #
  # where tiles_m = m // tm.
  #
  # NOTE: All group sizes are divisible by 'tm' because of the rounding in
  # steps (1) and (2) so this division is exact.
  group_tiles = rounded_group_sizes // tm

  if visit_empty_groups:
    # Insert one tile for empty groups.
    group_tiles = jnp.where(group_sizes == 0, 1, group_tiles)

  # Create the group ids for each grid index based on the tile counts for each
  # group.
  #
  # NOTE: This repeat(...) will pad group_ids with the final group id if
  # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
  # such that we only execute the necessary number of tiles.
  tiles_m = _calculate_num_tiles(m, tm)
  group_ids = jnp.repeat(
      jnp.arange(num_groups, dtype=jnp.int32),
      group_tiles,
      total_repeat_length=tiles_m + num_groups - 1,
  )

  # Assign an m-dimension tile id to each grid index.
  #
  # NOTE: Output tiles can only be re-visited consecutively. The following
  # procedure guarantees that m-dimension tile indices respect this.

  # (1) Calculate how many times each m-dimension tile will be visited.
  #
  # Each tile is guaranteed to be visited once by the group that owns the tile.
  # The remaining possible visits occur when a group starts inside of a tile at
  # a position other than the first row. We can calculate which m-dimension
  # tile each group starts in by floor-dividing its offset with `tm` and then
  # count tile visits with a histogram.
  #
  # To avoid double counting tile visits from the group that owns the tile,
  # filter these out by assigning their tile id to `tiles_m` (one beyond the
  # max) such that they're ignored by the subsequent histogram. Also filter
  # out any group which is empty.
  #
  # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear.
  partial_tile_mask = jnp.logical_or(
      (group_offsets[:-1] % tm) == 0, group_sizes == 0
  )

  # Explicitly enable tiles for zero sized groups, if specified. This covers
  # zero sized groups that start on a tile-aligned row and those that do not.
  if visit_empty_groups:
    partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask)

  partial_tile_ids = jnp.where(
      partial_tile_mask, tiles_m, group_offsets[:-1] // tm
  )

  tile_visits = (
      jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0]
      + 1
  )

  # Create the m-dimension tile ids for each grid index based on the visit
  # counts for each tile.
  m_tile_ids = jnp.repeat(
      jnp.arange(tiles_m, dtype=jnp.int32),
      tile_visits.astype(jnp.int32),
      total_repeat_length=tiles_m + num_groups - 1,
  )

  # Account for sharding.
  #
  # Find the start of the groups owned by our shard and shift the group_ids
  # and m_tile_ids s.t. the metadata for our tiles are at the front of the
  # arrays.
  #
  # TODO(tgale): Move this offset into the kernel to avoid these rolls.
  first_tile_in_shard = (group_ids < start_group).sum()
  group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0)
  m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0)

  # Calculate the number of tiles we need to compute for our shard.
  #
  # Remove tile visits that belong to a group not in our shard.
  iota = jnp.arange(num_groups, dtype=jnp.int32)
  active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group)
  group_tiles = jnp.where(active_group_mask, group_tiles, 0)
  num_tiles = group_tiles.sum()
  return (group_offsets, group_ids, m_tile_ids), num_tiles


def _get_group_size(
    *, grid_id: jnp.ndarray, group_metadata: GroupMetadata
) -> jnp.ndarray:
  """Calculate the number of rows in the current group."""
  group_offsets, group_ids = group_metadata[:2]
  group_id = group_ids[grid_id]
  group_start = group_offsets[group_id]
  group_end = group_offsets[group_id + 1]
  return group_end - group_start


def _get_store_mask(
    *,
    grid_id: jnp.ndarray,
    group_metadata: GroupMetadata,
    tm: int,
    tn: int,
) -> jnp.ndarray:
  """Mask for rows that belong to the current group in the current tile."""
  group_offsets, group_ids, m_tile_ids = group_metadata[:3]
  group_id = group_ids[grid_id]
  group_start = group_offsets[group_id]
  group_end = group_offsets[group_id + 1]
  m_id = m_tile_ids[grid_id] * tm
  iota = jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id
  return jnp.logical_and(iota >= group_start, iota < group_end)


def _zero_uninitialized_memory(
    out: jnp.ndarray,
    *,
    start_group: jnp.ndarray,
    num_nonzero_groups: int,
    group_metadata: GroupMetadata,
) -> jnp.ndarray:
  """Zero out uninitialized memory from output."""
  group_offsets = group_metadata[0]
  group_start = group_offsets[start_group]
  group_end = group_offsets[start_group + num_nonzero_groups]
  valid_mask = jax.lax.broadcasted_iota(jnp.int32, (out.shape[0],), 0)
  valid_mask = (valid_mask >= group_start) & (valid_mask < group_end)
  return jnp.where(valid_mask[:, None], out, 0)


LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]


@functools.partial(
    jax.jit,
    static_argnames=[
        "preferred_element_type",
        "tiling",
        "transpose_rhs",
        "interpret",
    ],
)
def gmm(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    group_sizes: jnp.ndarray,
    preferred_element_type: jnp.dtype = jnp.float32,
    tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
    group_offset: jnp.ndarray | None = None,
    existing_out: jnp.ndarray | None = None,
    transpose_rhs: bool = False,
    interpret: bool = False,
) -> jnp.ndarray:
  """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.

  Args:
    lhs: A 2d, jnp.ndarray with shape [m, k].
    rhs: A 3d, jnp.ndarray with shape [num_groups, k, n].
    group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
    preferred_element_type: jnp.dtype, the element type for the output matrix.
    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
    group_offset: The group in group sizes to start computing from. This is
      particularly useful for when rhs num_groups is sharded.
    existing_out: Existing output to write to.
    transpose_rhs: True if the rhs needs to be transposed.
    interpret: Whether or not to run the kernel in interpret mode, helpful for
      testing and debugging.

  Returns:
    A 2d, jnp.ndarray with shape [m, n].
  """

  if existing_out is not None:
    assert isinstance(existing_out, jax.Array)
    expected_dtype = existing_out.dtype
    if expected_dtype != preferred_element_type:
      raise ValueError(
          "Existing output dtype must match preferred_element_type."
      )
  if group_offset is None:
    group_offset = jnp.array([0], dtype=jnp.int32)
  else:
    if group_offset.shape:
      raise ValueError(
          f"group_offset must be a ()-shaped array. Got: {group_offset.shape}."
      )
    group_offset = group_offset[None]
  num_current_groups = rhs.shape[0]
  num_total_groups = group_sizes.shape[0]
  lhs, group_sizes, input_dtype = _validate_args(
      lhs=lhs, rhs=rhs, group_sizes=group_sizes
  )

  # Gather shape information.
  m, k, n = (lhs.shape[0], lhs.shape[1], rhs.shape[2])
  if transpose_rhs:
    n = rhs.shape[1]

  # If tiling is callable, look up the problem dimensions in the LUT. If no
  # tuned tile dimensions are available, throw an error.
  if callable(tiling):
    tiling = tiling(m, k, n)

  if tiling is None:
    raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})")

  tm, tk, tn = tiling
  tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk)
  tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn)
  del n_rem

  # Create the metadata we need for computation.
  group_metadata, num_active_tiles = make_group_metadata(
      group_sizes=group_sizes,
      m=m,
      tm=tm,
      start_group=group_offset[0],
      num_nonzero_groups=rhs.shape[0],
      visit_empty_groups=False,
  )

  def kernel(
      group_metadata,
      group_offset,
      lhs,
      rhs,
      existing_out,
      out,
      acc_scratch,
  ):
    group_offsets, group_ids, m_tile_ids = group_metadata
    del group_offsets, group_ids, group_offset

    grid_id = pl.program_id(1)
    k_i = pl.program_id(2)

    @pl.when(k_i == 0)
    def _zero_acc():
      acc_scratch[...] = jnp.zeros_like(acc_scratch)

    if existing_out is not None:
      prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0)
      is_first_processed_group = grid_id == 0
      m_tile_changed = m_tile_ids[grid_id] != m_tile_ids[prev_grid_id]
      first_time_seeing_out = jnp.logical_or(
          is_first_processed_group, m_tile_changed
      )

      @pl.when(first_time_seeing_out)
      def _init_out():
        out[...] = existing_out[...]

    def mask_k_rem(x, *, dim):
      if k_rem == 0:
        return x

      orig_dtype = x.dtype
      iota = lax.broadcasted_iota(jnp.int32, x.shape, dim)
      x = x.astype(jnp.float32)
      return jnp.where(iota < k_rem, x, 0).astype(orig_dtype)

    def _store_accum():
      mask = _get_store_mask(
          grid_id=grid_id,
          group_metadata=group_metadata,
          tm=tm,
          tn=tn,
      )
      to_store = acc_scratch[...]
      out[...] = jax.lax.select(
          mask[...], to_store, out[...].astype(jnp.float32)
      ).astype(preferred_element_type)

    def _accum(is_last_k_tile):
      if is_last_k_tile:
        mask_k_rem_lhs = partial(mask_k_rem, dim=1)
        mask_k_rem_rhs = partial(mask_k_rem, dim=int(transpose_rhs))
      else:
        mask_k_rem_lhs = lambda x: x
        mask_k_rem_rhs = lambda x: x

      if transpose_rhs:
        dot_general_dims = (((1,), (1,)), ((), ()))
      else:
        dot_general_dims = (((1,), (0,)), ((), ()))

      loaded_lhs = lhs[...]
      loaded_rhs = rhs[...]
      acc_scratch[...] += lax.dot_general(
          mask_k_rem_lhs(loaded_lhs).astype(input_dtype),
          mask_k_rem_rhs(loaded_rhs).astype(input_dtype),
          preferred_element_type=jnp.float32,
          dimension_numbers=dot_general_dims,
      )

      if is_last_k_tile:
        _store_accum()

    lax.cond(
        k_i == tiles_k - 1,
        partial(_accum, True),
        partial(_accum, False),
    )

  def lhs_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
    # lhs is (m, k). Load the [tm, tk] matrix for this m-tile.
    group_offsets, group_ids, m_tile_ids = group_metadata
    del n_i, group_offsets, group_ids, group_offset
    return m_tile_ids[grid_id], k_i

  def rhs_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
    # rhs is (num_groups, k, n). Load the [tk, tn] matrix based on the group id
    # for this m-tile.
    group_offsets, group_ids, m_tile_ids = group_metadata
    del group_offsets, m_tile_ids
    if transpose_rhs:
      k_i, n_i = n_i, k_i

    # NOTE: If we're working on only a shard of the rhs we need to adjust the
    # group index we load from to account for this. The group_ids are in the
    # "unsharded" domain.
    return group_ids[grid_id] - group_offset[0], k_i, n_i

  def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
    # out is (m, n). Load the [tm, tn] matrix for this m-tile.
    group_offsets, group_ids, m_tile_ids = group_metadata
    del k_i, group_offsets, group_ids, group_offset
    return m_tile_ids[grid_id], n_i

  out_block_spec = pl.BlockSpec((tm, tn), out_transform_indices)
  if existing_out is None:
    in_out_block_spec: Any = None
    input_output_aliases = {}
  else:
    in_out_block_spec = out_block_spec
    input_output_aliases = {6: 0}

  lhs_block_spec = pl.BlockSpec((tm, tk), lhs_transform_indices)
  if transpose_rhs:
    rhs_block_spec = pl.BlockSpec((None, tn, tk), rhs_transform_indices)
  else:
    rhs_block_spec = pl.BlockSpec((None, tk, tn), rhs_transform_indices)

  lhs_bytes = lhs.size * lhs.itemsize
  rhs_bytes = (k * n) * rhs.itemsize  # We don't read all of rhs
  out_bytes = (m * n) * jnp.dtype(preferred_element_type).itemsize
  max_active_tiles = group_metadata[1].size
  bytes_accessed = (
      (lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) + out_bytes
  )
  flops = 2 * m * k * n
  cost_estimate = pl.CostEstimate(
      flops=flops, bytes_accessed=bytes_accessed, transcendentals=0
  )
  call_gmm = pl.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=2,
          in_specs=[
              lhs_block_spec,
              rhs_block_spec,
              in_out_block_spec,
          ],
          out_specs=out_block_spec,
          grid=(tiles_n, num_active_tiles, tiles_k),
          scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)],
      ),
      input_output_aliases=input_output_aliases,
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "arbitrary", "arbitrary")),
      interpret=interpret,
      cost_estimate=cost_estimate,
  )

  out = call_gmm(
      group_metadata,
      group_offset,
      lhs,
      rhs,
      existing_out,
  )
  if existing_out is None and num_current_groups < num_total_groups:
    out = _zero_uninitialized_memory(
        out,
        start_group=group_offset[0],
        num_nonzero_groups=rhs.shape[0],
        group_metadata=group_metadata,
    )
  return out


@functools.partial(
    jax.jit,
    static_argnames=[
        "preferred_element_type",
        "tiling",
        "num_actual_groups",
        "interpret",
    ],
)
def tgmm(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    group_sizes: jnp.ndarray,
    preferred_element_type: jnp.dtype = jnp.float32,
    tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128),
    group_offset: jnp.ndarray | None = None,
    num_actual_groups: int | None = None,
    existing_out: jnp.ndarray | None = None,
    interpret: bool = False,
) -> jnp.ndarray:
  """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].

  Args:
    lhs: A 2d, jnp.ndarray with shape [k, m].
    rhs: A 2d, jnp.ndarray with shape [m, n].
    group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
    preferred_element_type: jnp.dtype, the element type for the output matrix.
    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
    group_offset: The group in group sizes to start computing from. This is
      particularly useful for when rhs num_groups is sharded.
    num_actual_groups: For when num_groups is sharded and we should only
      compute the groups that are local, starting from group_offset.
    existing_out: Existing output to write to.
    interpret: Whether or not to run the kernel in interpret mode, helpful for
      testing and debugging.

  Returns:
    A 3d, jnp.ndarray with shape [num_groups, k, n].
  """
  if group_offset is None:
    group_offset = jnp.array([0], dtype=jnp.int32)
  else:
    group_offset = group_offset[None]
  lhs, group_sizes, input_dtype = _validate_args(
      lhs=lhs, rhs=rhs, group_sizes=group_sizes, expected_rhs_dims=2
  )

  # Gather shape information.
  k, m, n = (lhs.shape[0], lhs.shape[1], rhs.shape[1])
  num_groups = group_sizes.shape[0]
  num_actual_groups = (
      num_actual_groups if num_actual_groups is not None else num_groups
  )

  # If tiling is callable, look up the problem dimensions in the LUT. If no
  # tuned tile dimensions are available, throw an error.
  if callable(tiling):
    tiling = tiling(m, k, n)

  if tiling is None:
    raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})")

  tm, tk, tn = tiling
  tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk)
  del k_rem
  tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn)
  del n_rem

  # Create the metadata we need for computation.
  group_metadata, num_active_tiles = make_group_metadata(
      group_sizes=group_sizes,
      m=m,
      tm=tm,
      start_group=group_offset[0],
      num_nonzero_groups=num_actual_groups,
      visit_empty_groups=True,
  )

  def kernel(
      group_metadata,
      group_offset,
      lhs,
      rhs,
      existing_out,
      out,
      acc_scratch,
  ):
    grid_id = pl.program_id(2)
    group_offsets, group_ids, m_tile_ids = group_metadata
    del group_offsets, group_offset, m_tile_ids

    group = group_ids[grid_id]
    prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0)
    prev_group = group_ids[prev_grid_id]

    group_has_changed = jnp.logical_or(grid_id == 0, prev_group != group)

    @pl.when(group_has_changed)
    def _zero_acc():
      acc_scratch[...] = jnp.zeros_like(acc_scratch)

    # We'll only do computation if our group has a nonzero number of rows in
    # it.
    dont_skip = (
        _get_group_size(grid_id=grid_id, group_metadata=group_metadata) > 0
    )

    @pl.when(dont_skip)
    def _do():
      rhs_mask = _get_store_mask(
          grid_id=grid_id,
          group_metadata=group_metadata,
          tm=tm,
          tn=tn,
      )
      lhs_mask = _get_store_mask(
          grid_id=grid_id,
          group_metadata=group_metadata,
          tm=tm,
          tn=tk,
      )

      loaded_lhs = lhs[...]
      loaded_rhs = rhs[...]
      loaded_lhs = lax.select(
          lhs_mask[...],
          loaded_lhs.astype(jnp.float32),
          jnp.zeros_like(lhs, jnp.float32),
      ).swapaxes(0, 1)
      loaded_rhs = lax.select(
          rhs_mask[...],
          loaded_rhs.astype(jnp.float32),
          jnp.zeros_like(rhs, jnp.float32),
      )

      acc_scratch[...] += lax.dot(
          loaded_lhs.astype(input_dtype),
          loaded_rhs.astype(input_dtype),
          preferred_element_type=jnp.float32,
      )

    is_end_of_grid = grid_id == (pl.num_programs(2) - 1)
    next_grid_id = jnp.where(is_end_of_grid, grid_id, grid_id + 1)
    next_group = group_ids[next_grid_id]

    group_is_changing = jnp.logical_or(is_end_of_grid, group != next_group)

    @pl.when(group_is_changing)
    def _store_accum():
      to_store = acc_scratch[...]
      if existing_out is not None:
        to_store += existing_out[...].astype(jnp.float32)
      out[...] = to_store.astype(preferred_element_type)

  def lhs_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
    # lhs is (m, k). Load the [tm, tk] matrix for this m-tile.
    group_offsets, group_ids, m_tile_ids = group_metadata
    del n_i, group_offsets, group_ids, group_offset
    return m_tile_ids[grid_id], k_i

  def rhs_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
    # rhs is (m, n). Load the [tm, tn] matrix for this m-tile.
    group_offsets, group_ids, m_tile_ids = group_metadata
    del k_i, group_offsets, group_ids, group_offset
    return m_tile_ids[grid_id], n_i

  def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
    # out is (num_groups, k, n). Load the [tk, tn] matrix based on the group id
    # for this m-tile.
    group_offsets, group_ids, m_tile_ids = group_metadata
    del group_offsets, m_tile_ids

    # NOTE: If we're working on only a shard of the output we need to adjust
    # the group index we load from to account for this. The group_ids are in
    # the "unsharded" domain.
    return group_ids[grid_id] - group_offset[0], k_i, n_i

  out_block_spec = pl.BlockSpec((None, tk, tn), out_transform_indices)
  if existing_out is None:
    in_out_block_spec: Any = None
    input_output_aliases = {}
  else:
    in_out_block_spec = out_block_spec
    input_output_aliases = {6: 0}

  lhs_block_spec = pl.BlockSpec((tm, tk), lhs_transform_indices)
  rhs_block_spec = pl.BlockSpec((tm, tn), rhs_transform_indices)

  lhs_bytes = lhs.size * lhs.itemsize
  rhs_bytes = rhs.size * rhs.itemsize
  out_bytewidth = jnp.dtype(preferred_element_type).itemsize
  out_bytes = (num_actual_groups * k * n) * out_bytewidth
  bytes_accessed = (
      (lhs_bytes * tiles_n) + (rhs_bytes * tiles_k) + out_bytes
  )
  flops = 2 * m * k * n
  cost_estimate = pl.CostEstimate(
      flops=flops, bytes_accessed=bytes_accessed, transcendentals=0
  )
  lhs = lhs.swapaxes(0, 1)
  call_gmm = pl.pallas_call(
      kernel,
      out_shape=jax.ShapeDtypeStruct(
          (num_actual_groups, k, n), preferred_element_type
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=2,
          in_specs=[
              lhs_block_spec,
              rhs_block_spec,
              in_out_block_spec,
          ],
          out_specs=out_block_spec,
          grid=(tiles_n, tiles_k, num_active_tiles),
          scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)],
      ),
      input_output_aliases=input_output_aliases,
      compiler_params=pltpu.CompilerParams(
          dimension_semantics=("parallel", "arbitrary", "arbitrary")),
      interpret=interpret,
      cost_estimate=cost_estimate,
  )

  out = call_gmm(
      group_metadata,
      group_offset,
      lhs,
      rhs,
      existing_out,
  )
  return out
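A usage sketch for gmm (not part of the commit; the shapes, tile sizes, and group sizes are illustrative assumptions, and interpret=True is used so the sketch does not require a TPU):

import numpy as np
import jax
import jax.numpy as jnp

m, k, n, num_groups = 512, 256, 128, 4
key0, key1 = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(key0, (m, k), dtype=jnp.float32)
rhs = jax.random.normal(key1, (num_groups, k, n), dtype=jnp.float32)
# Rows 0:128 -> group 0, 128:384 -> group 1, group 2 empty, 384:512 -> group 3.
# With tm = 128 the metadata assigns m-tiles [0], [1, 2], [], [3] to the
# groups, since every boundary here happens to be tile-aligned.
sizes = [128, 256, 0, 128]
group_sizes = jnp.array(sizes, dtype=jnp.int32)

out = gmm(lhs, rhs, group_sizes, tiling=(128, 128, 128), interpret=True)

# Reference: a dense matmul per group over its row range.
offsets = np.concatenate([[0], np.cumsum(sizes)])
expected = jnp.concatenate(
    [lhs[offsets[g]:offsets[g + 1]] @ rhs[g] for g in range(num_groups)]
)
assert jnp.allclose(out, expected, atol=1e-3)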
@@ -0,0 +1,109 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Grouped matrix multiplication operations with custom VJPs."""

import jax
from jax.experimental.pallas.ops.tpu.megablox import gmm as backend
import jax.numpy as jnp


gmm = jax.custom_vjp(
    backend.gmm,
    nondiff_argnums=(3, 4, 7, 8),
)


def _gmm_fwd(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    group_sizes: jnp.ndarray,
    preferred_element_type: jnp.dtype = jnp.float32,
    tiling: tuple[int, int, int] = (128, 128, 128),
    group_offset: jnp.ndarray | None = None,
    existing_out: jnp.ndarray | None = None,
    transpose_rhs: bool = False,
    interpret: bool = False,
) -> tuple[
    jnp.ndarray,
    tuple[
        jnp.ndarray,
        jnp.ndarray,
        jnp.ndarray,
        jnp.ndarray | None,
        int,
    ],
]:
  """Forward function for GMM VJP."""
  out = backend.gmm(
      lhs,
      rhs,
      group_sizes,
      preferred_element_type,
      tiling,
      group_offset,
      existing_out,
      transpose_rhs=transpose_rhs,
      interpret=interpret,
  )
  return out, (lhs, rhs, group_sizes, group_offset, rhs.shape[0])


def _gmm_bwd(
    preferred_element_type: jnp.dtype,
    tiling: tuple[int, int, int],
    transpose_rhs: bool,
    interpret: bool,
    residual: tuple[
        jnp.ndarray,
        jnp.ndarray,
        jnp.ndarray,
        jnp.ndarray | None,
        int,
    ],
    grad: jnp.ndarray,
) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]:
  """Backward function for throughput GMM VJP."""
  del preferred_element_type
  lhs, rhs, group_sizes, group_offset, num_actual_groups = residual
  grad_lhs = backend.gmm(
      grad,
      rhs,
      group_sizes,
      lhs[0].dtype,
      tiling,
      group_offset,
      transpose_rhs=not transpose_rhs,
      interpret=interpret,
  )
  grad_rhs = backend.tgmm(
      lhs.swapaxes(0, 1),
      grad,
      group_sizes,
      rhs.dtype,
      tiling,
      group_offset,
      num_actual_groups,
      interpret=interpret,
  )

  # NOTE: If the rhs transposition is fused into the forward pass we need to
  # return the transpose of the rhs gradient that we calculated above.
  #
  # TODO(tgale, enriqueps, apaske): Fuse this transposition into the tgmm.
  grad_rhs = grad_rhs.swapaxes(1, 2) if transpose_rhs else grad_rhs
  return grad_lhs, grad_rhs, None, None, grad


gmm.defvjp(_gmm_fwd, _gmm_bwd)
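A differentiation sketch (not part of the commit), reusing the toy operands from the gmm example above. Arguments are passed positionally here because the custom_vjp wrapper's nondiff_argnums are positional (3=preferred_element_type, 4=tiling, 7=transpose_rhs, 8=interpret):

def loss(lhs, rhs):
  out = gmm(lhs, rhs, group_sizes, jnp.float32, (128, 128, 128),
            None, None, False, True)  # group_offset=None, existing_out=None
  return out.sum()

# grad_lhs flows through gmm with transpose_rhs flipped; grad_rhs through tgmm.
grad_lhs, grad_rhs = jax.grad(loss, argnums=(0, 1))(lhs, rhs)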
@@ -0,0 +1,15 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention as paged_attention
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,670 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""PagedAttention TPU kernel."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import functools
|
||||
from typing import Literal
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
|
||||
|
||||
|
||||
class MultiPageAsyncCopyDescriptor:
|
||||
"""Descriptor for async copy of multiple K/V pages from HBM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pages_hbm_ref,
|
||||
scales_pages_hbm_ref,
|
||||
vmem_buffer,
|
||||
scales_vmem_buffer,
|
||||
sem,
|
||||
page_indices,
|
||||
page_indices_start_offset,
|
||||
num_pages_to_load,
|
||||
head_index,
|
||||
):
|
||||
self._vmem_buffer = vmem_buffer
|
||||
self._scales_vmem_buffer = scales_vmem_buffer
|
||||
self._num_pages_to_load = num_pages_to_load
|
||||
if head_index is not None:
|
||||
self._pages_hbm_ref = pages_hbm_ref.at[head_index]
|
||||
if scales_pages_hbm_ref is not None:
|
||||
self._scales_pages_hbm_ref = scales_pages_hbm_ref.at[head_index]
|
||||
else:
|
||||
self._scales_pages_hbm_ref = None
|
||||
else:
|
||||
self._pages_hbm_ref = pages_hbm_ref
|
||||
self._scales_pages_hbm_ref = scales_pages_hbm_ref
|
||||
self._sem = sem
|
||||
self._page_indices = page_indices
|
||||
self._page_indices_start_offset = page_indices_start_offset
|
||||
self._async_copies = [
|
||||
self._make_async_copy(i) for i in range(self._num_pages_to_load)
|
||||
]
|
||||
if (
|
||||
self._scales_pages_hbm_ref is not None
|
||||
and self._scales_vmem_buffer is not None
|
||||
):
|
||||
self._async_copies += [
|
||||
self._make_scales_async_copy(i)
|
||||
for i in range(self._num_pages_to_load)
|
||||
]
|
||||
|
||||
def _make_async_copy(self, i):
|
||||
page_index = self._page_indices[self._page_indices_start_offset + i]
|
||||
return pltpu.make_async_copy(
|
||||
self._pages_hbm_ref.at[page_index], self._vmem_buffer.at[i], self._sem
|
||||
)
|
||||
|
||||
def _make_scales_async_copy(self, i):
|
||||
page_index = self._page_indices[self._page_indices_start_offset + i]
|
||||
return pltpu.make_async_copy(
|
||||
self._scales_pages_hbm_ref.at[page_index],
|
||||
self._scales_vmem_buffer.at[i],
|
||||
self._sem,
|
||||
)
|
||||
|
||||
def start(self):
|
||||
"""Starts the async copies."""
|
||||
for async_copy in self._async_copies:
|
||||
async_copy.start()
|
||||
|
||||
def _maybe_dequantize(self, x, x_scale, dtype=jnp.bfloat16):
|
||||
if x_scale is None:
|
||||
return x.astype(dtype)
|
||||
return quantization_utils.from_int8(x, x_scale, dtype=dtype)
|
||||
|
||||
def wait_and_get_loaded(self) -> jax.Array:
|
||||
"""Wait async copies and gets the loaded buffer as a jax.Array."""
|
||||
for async_copy in self._async_copies:
|
||||
async_copy.wait()
|
||||
head_dim = self._vmem_buffer.shape[-1]
|
||||
jax_array = self._vmem_buffer[...].astype(jnp.float32)
|
||||
if self._scales_vmem_buffer is not None:
|
||||
scales_jax_array = self._scales_vmem_buffer[...].astype(jnp.float32)
|
||||
else:
|
||||
scales_jax_array = None
|
||||
jax_array = self._maybe_dequantize(jax_array, scales_jax_array)
|
||||
return jax_array.reshape(-1, head_dim)
|
||||
|
||||
|
||||
def paged_flash_attention_kernel(
|
||||
lengths_ref,
|
||||
page_indices_ref,
|
||||
buffer_index_ref,
|
||||
init_flag_ref,
|
||||
q_ref,
|
||||
k_pages_hbm_ref,
|
||||
k_scales_pages_hbm_ref,
|
||||
v_pages_hbm_ref,
|
||||
v_scales_pages_hbm_ref,
|
||||
o_ref,
|
||||
m_ref,
|
||||
l_ref,
|
||||
k_vmem_buffer,
|
||||
k_scales_vmem_buffer,
|
||||
v_vmem_buffer,
|
||||
v_scales_vmem_buffer,
|
||||
k_sems,
|
||||
v_sems,
|
||||
*,
|
||||
batch_size: int,
|
||||
pages_per_compute_block: int,
|
||||
pages_per_sequence: int,
|
||||
mask_value: float,
|
||||
attn_logits_soft_cap: float | None,
|
||||
megacore_mode: str | None,
|
||||
program_ids=(),
|
||||
):
|
||||
"""Pallas kernel for paged attention."""
|
||||
if program_ids:
|
||||
core_index, b, h, i = program_ids
|
||||
else:
|
||||
core_index, b, h, i = (
|
||||
pl.program_id(0),
|
||||
pl.program_id(1),
|
||||
pl.program_id(2),
|
||||
pl.program_id(3),
|
||||
)
|
||||
num_kv_heads, _, page_size, _ = k_pages_hbm_ref.shape
|
||||
bk = page_size * pages_per_compute_block
|
||||
num_cores = pl.num_programs(0)
|
||||
|
||||
b_step = num_cores if megacore_mode == "batch" else 1
|
||||
b_start = core_index if megacore_mode == "batch" else 0
|
||||
h_step = num_cores if megacore_mode == "kv_head" else 1
|
||||
h_start = core_index if megacore_mode == "kv_head" else 0
|
||||
|
||||
h = h * h_step + h_start
|
||||
b = b * b_step + b_start
|
||||
length = lengths_ref[b]
|
||||
|
||||
def compute_block_indices(b, h, i):
|
||||
|
||||
def advance_b():
|
||||
next_b = b + b_step
|
||||
|
||||
def advance_to_next_non_zero_length():
|
||||
next_next_b = next_b + b_step
|
||||
return lax.fori_loop(
|
||||
lax.div(next_next_b, b_step),
|
||||
lax.div(batch_size, b_step),
|
||||
lambda _, b: jnp.where(lengths_ref[b] == 0, b + b_step, b),
|
||||
next_next_b,
|
||||
)
|
||||
|
||||
return (
|
||||
lax.cond(
|
||||
jnp.logical_and(
|
||||
next_b < batch_size,
|
||||
lengths_ref[lax.clamp(0, next_b, batch_size - 1)] == 0),
|
||||
advance_to_next_non_zero_length,
|
||||
lambda: next_b,
|
||||
),
|
||||
h_start,
|
||||
0,
|
||||
)
|
||||
|
||||
def advance_h():
|
||||
next_h = h + h_step
|
||||
return lax.cond(next_h < num_kv_heads, lambda: (b, next_h, 0), advance_b)
|
||||
|
||||
return lax.cond(i * bk < lengths_ref[b], lambda: (b, h, i), advance_h)
|
||||
|
||||
def create_kv_async_copy_descriptors(b, h, i, buffer_index):
|
||||
page_offset = b * pages_per_sequence + i * pages_per_compute_block
|
||||
pages_to_load = pages_per_compute_block
|
||||
async_copy_k = MultiPageAsyncCopyDescriptor(
|
||||
k_pages_hbm_ref,
|
||||
k_scales_pages_hbm_ref,
|
||||
k_vmem_buffer.at[buffer_index],
|
||||
k_scales_vmem_buffer.at[buffer_index]
|
||||
if k_scales_vmem_buffer is not None
|
||||
else None,
|
||||
k_sems.at[buffer_index],
|
||||
page_indices_ref,
|
||||
page_offset,
|
||||
pages_to_load,
|
||||
h,
|
||||
)
|
||||
async_copy_v = MultiPageAsyncCopyDescriptor(
|
||||
v_pages_hbm_ref,
|
||||
v_scales_pages_hbm_ref,
|
||||
v_vmem_buffer.at[buffer_index],
|
||||
v_scales_vmem_buffer.at[buffer_index]
|
||||
if v_scales_vmem_buffer is not None
|
||||
else None,
|
||||
v_sems.at[buffer_index],
|
||||
page_indices_ref,
|
||||
page_offset,
|
||||
pages_to_load,
|
||||
h,
|
||||
)
|
||||
return async_copy_k, async_copy_v
|
||||
|
||||
@pl.when(i * bk < length)
|
||||
def flash_attention():
|
||||
init_flag = init_flag_ref[0]
|
||||
init_flag_ref[0] = 0
|
||||
buffer_index = buffer_index_ref[0]
|
||||
next_b, next_h, next_i = compute_block_indices(b, h, i + 1)
|
||||
|
||||
@pl.when(init_flag)
|
||||
def prefetch_first_block():
|
||||
async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
|
||||
b, h, i, buffer_index
|
||||
)
|
||||
async_copy_k.start()
|
||||
async_copy_v.start()
|
||||
|
||||
@pl.when(i == 0)
|
||||
def init():
|
||||
m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
|
||||
l_ref[...] = jnp.zeros_like(l_ref)
|
||||
o_ref[...] = jnp.zeros_like(o_ref)
|
||||
|
||||
@pl.when(next_b < batch_size)
|
||||
def prefetch_next_block():
|
||||
next_buffer_index = jnp.where(buffer_index == 0, 1, 0)
|
||||
async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors(
|
||||
next_b, next_h, next_i, next_buffer_index
|
||||
)
|
||||
async_copy_next_k.start()
|
||||
async_copy_next_v.start()
|
||||
buffer_index_ref[0] = next_buffer_index
|
||||
|
||||
async_copy_k, async_copy_v = create_kv_async_copy_descriptors(
|
||||
b, h, i, buffer_index
|
||||
)
|
||||
q = q_ref[...].astype(jnp.float32)
|
||||
k = async_copy_k.wait_and_get_loaded()
|
||||
qk = jnp.einsum("gd,td->gt", q, k, preferred_element_type=jnp.float32)
|
||||
if attn_logits_soft_cap is not None:
|
||||
capped_qk = jnp.tanh(qk / attn_logits_soft_cap)
|
||||
qk = capped_qk * attn_logits_soft_cap
|
||||
|
||||
mask = i * bk + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length
|
||||
qk = qk + jnp.where(mask, 0.0, mask_value)
|
||||
m_curr = qk.max(axis=-1)
|
||||
|
||||
s_curr = jnp.exp(qk - m_curr[..., None])
|
||||
m_prev, l_prev = m_ref[...], l_ref[...]
|
||||
l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
|
||||
m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
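# Online softmax merge: with m_next = max(m_prev, m_curr), the identity
# exp(x - m_prev) * exp(m_prev - m_next) == exp(x - m_next) lets `alpha` and
# `beta` rescale the previous and current partial sums onto a common max
# before they are combined below.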
|
||||
m_next = jnp.maximum(m_prev, m_curr)
|
||||
alpha = jnp.exp(m_prev - m_next)
|
||||
beta = jnp.exp(m_curr - m_next)
|
||||
l_next = alpha * l_prev + beta * l_curr
|
||||
m_ref[...], l_ref[...] = m_next, l_next
|
||||
|
||||
v = async_copy_v.wait_and_get_loaded()
|
||||
o_curr = jnp.einsum("gt,td->gd", s_curr, v)
|
||||
|
||||
o_ref[...] = (
|
||||
(l_prev * alpha * o_ref[...] + beta * o_curr) / l_next
|
||||
).astype(o_ref.dtype)
|
||||
|
||||
|
||||
def paged_flash_attention_kernel_inline_seq_dim(
|
||||
lengths_ref,
|
||||
page_indices_ref,
|
||||
buffer_index_ref,
|
||||
init_flag_ref,
|
||||
q_ref,
|
||||
k_pages_hbm_ref,
|
||||
k_scales_pages_hbm_ref,
|
||||
v_pages_hbm_ref,
|
||||
v_scales_pages_hbm_ref,
|
||||
o_ref,
|
||||
m_ref,
|
||||
l_ref,
|
||||
k_vmem_buffer,
|
||||
k_scales_vmem_buffer,
|
||||
v_vmem_buffer,
|
||||
v_scales_vmem_buffer,
|
||||
k_sems,
|
||||
v_sems,
|
||||
*,
|
||||
batch_size: int,
|
||||
pages_per_compute_block: int,
|
||||
pages_per_sequence: int,
|
||||
mask_value: float,
|
||||
attn_logits_soft_cap: float | None,
|
||||
megacore_mode: str | None,
|
||||
):
|
||||
core_index, b, h = pl.program_id(0), pl.program_id(1), pl.program_id(2)
|
||||
|
||||
# Initialize the output HBM buffers to avoid accessing garbage memory inside
|
||||
# the kernel body below.
|
||||
m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
|
||||
l_ref[...] = jnp.zeros_like(l_ref)
|
||||
o_ref[...] = jnp.zeros_like(o_ref)
|
||||
|
||||
def body(i, _):
|
||||
paged_flash_attention_kernel(
|
||||
lengths_ref,
|
||||
page_indices_ref,
|
||||
buffer_index_ref,
|
||||
init_flag_ref,
|
||||
q_ref,
|
||||
k_pages_hbm_ref,
|
||||
k_scales_pages_hbm_ref,
|
||||
v_pages_hbm_ref,
|
||||
v_scales_pages_hbm_ref,
|
||||
o_ref,
|
||||
m_ref,
|
||||
l_ref,
|
||||
k_vmem_buffer,
|
||||
k_scales_vmem_buffer,
|
||||
v_vmem_buffer,
|
||||
v_scales_vmem_buffer,
|
||||
k_sems,
|
||||
v_sems,
|
||||
batch_size=batch_size,
|
||||
pages_per_compute_block=pages_per_compute_block,
|
||||
pages_per_sequence=pages_per_sequence,
|
||||
mask_value=mask_value,
|
||||
attn_logits_soft_cap=attn_logits_soft_cap,
|
||||
megacore_mode=megacore_mode,
|
||||
program_ids=(core_index, b, h, i),
|
||||
)
|
||||
return ()
|
||||
|
||||
bk = pages_per_compute_block * k_pages_hbm_ref.shape[-2]
|
||||
|
||||
if megacore_mode == "batch":
|
||||
num_cores = pl.num_programs(0)
|
||||
length = lengths_ref[b * num_cores + core_index]
|
||||
else:
|
||||
length = lengths_ref[b]
|
||||
|
||||
lax.fori_loop(0, lax.div(length + bk - 1, bk), body, ())
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=[
|
||||
"pages_per_compute_block",
|
||||
"attn_logits_soft_cap",
|
||||
"mask_value",
|
||||
"megacore_mode",
|
||||
"inline_seq_dim",
|
||||
],
|
||||
)
|
||||
def paged_attention(
|
||||
q: jax.Array,
|
||||
k_pages: jax.Array | quantization_utils.QuantizedTensor,
|
||||
v_pages: jax.Array | quantization_utils.QuantizedTensor,
|
||||
lengths: jax.Array,
|
||||
page_indices: jax.Array,
|
||||
*,
|
||||
mask_value: float = DEFAULT_MASK_VALUE,
|
||||
attn_logits_soft_cap: float | None = None,
|
||||
pages_per_compute_block: int,
|
||||
megacore_mode: str | None = None,
|
||||
inline_seq_dim: bool = True,
|
||||
) -> jax.Array:
|
||||
"""Paged grouped query attention.
|
||||
|
||||
Args:
|
||||
q: A [batch_size, num_q_heads, head_dim] jax.Array.
|
||||
k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
|
||||
v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array.
|
||||
lengths: An i32[batch_size] jax.Array containing the length of each example.
|
||||
page_indices: An i32[batch_size, pages_per_sequence] jax.Array. Each entry
|
||||
should be in the range of [0, total_num_pages), indicating where to locate
|
||||
the page in `k_pages` or `v_pages`.
|
||||
mask_value: The value used for padding in attention. By default it is a very
|
||||
negative floating point number.
|
||||
attn_logits_soft_cap: The value used for soft capping the attention logits.
|
||||
pages_per_compute_block: how many pages are processed in one flash
attention block in the pallas kernel.
|
||||
megacore_mode: if set, enable megacore to parallelize the computation. Must
|
||||
be one of ['kv_head', 'batch', None]. Caveat: set this only if megacore is
|
||||
enabled, otherwise the kernel may hang. If you are not sure, leave it as
None.
|
||||
* None: disable megacore parallelism.
|
||||
* kv_head: megacore parallelism on KV heads; requires number of KV heads
|
||||
divisible by 2.
|
||||
* batch: megacore parallelism on batch dimension; requires batch divisible
|
||||
by 2.
|
||||
inline_seq_dim: whether to fuse kernel instances along the sequence dim into
|
||||
one kernel.
|
||||
|
||||
Returns:
|
||||
The output of attention([batch_size, num_q_heads, head_dim]).
|
||||
"""
|
||||
if isinstance(k_pages, quantization_utils.QuantizedTensor):
|
||||
k_pages, k_scales_pages = k_pages.weight, k_pages.scales
|
||||
assert isinstance(k_scales_pages, jax.Array) # For typing.
|
||||
k_scales_pages = jnp.broadcast_to(
|
||||
k_scales_pages, (*k_scales_pages.shape[:-1], k_pages.shape[-1])
|
||||
)
|
||||
else:
|
||||
k_scales_pages = None
|
||||
if isinstance(v_pages, quantization_utils.QuantizedTensor):
|
||||
v_pages, v_scales_pages = v_pages.weight, v_pages.scales
|
||||
assert isinstance(v_scales_pages, jax.Array) # For typing.
|
||||
v_scales_pages = jnp.broadcast_to(
|
||||
v_scales_pages, (*v_scales_pages.shape[:-1], v_pages.shape[-1])
|
||||
)
|
||||
else:
|
||||
v_scales_pages = None
|
||||
|
||||
batch_size, num_q_heads, head_dim = q.shape
|
||||
num_kv_heads, _, page_size, head_dim_k = k_pages.shape
|
||||
batch_size_paged_indices, pages_per_sequence = page_indices.shape
|
||||
|
||||
if k_pages.shape != v_pages.shape:
|
||||
raise ValueError(
|
||||
f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and"
|
||||
f" {v_pages.shape}"
|
||||
)
|
||||
if num_q_heads % num_kv_heads != 0:
|
||||
raise ValueError(
|
||||
"Number of Q heads must be divisible by number of KV heads. Got"
|
||||
f" {num_q_heads} and {num_kv_heads}."
|
||||
)
|
||||
if head_dim_k != head_dim:
|
||||
raise ValueError(
|
||||
"head_dim of Q must be the same as that of K/V. Got"
|
||||
f" {head_dim} and {head_dim_k}."
|
||||
)
|
||||
if pages_per_sequence % pages_per_compute_block != 0:
raise ValueError(
"pages_per_sequence must be divisible by pages_per_compute_block. Got"
f" {pages_per_sequence} and {pages_per_compute_block}."
)
|
||||
if lengths.shape != (batch_size,):
|
||||
raise ValueError("`lengths` and `q` must have the same batch size")
|
||||
if batch_size_paged_indices != batch_size:
|
||||
raise ValueError("`page_indices` and `q` must have the same batch size")
|
||||
if lengths.dtype != jnp.int32:
|
||||
raise ValueError(
|
||||
f"The dtype of `lengths` must be int32. Got {lengths.dtype}"
|
||||
)
|
||||
|
||||
# TODO(dinghua): get the actual cores per chip once there's an official API.
|
||||
if megacore_mode == "kv_head":
|
||||
if num_kv_heads % 2 != 0:
|
||||
raise ValueError(
|
||||
"number of KV heads must be even when megacore_mode is 'kv_head'"
|
||||
)
|
||||
num_cores = 2
|
||||
elif megacore_mode == "batch":
|
||||
if batch_size % 2 != 0:
|
||||
raise ValueError("batch size must be even when megacore_mode is 'batch'")
|
||||
num_cores = 2
|
||||
elif megacore_mode is None:
|
||||
num_cores = 1
|
||||
else:
|
||||
raise ValueError("megacore_mode must be one of ['kv_head', 'batch', None]")
|
||||
|
||||
num_groups = num_q_heads // num_kv_heads
|
||||
if num_groups % 8 != 0:
|
||||
# Reshape q to hint XLA to pick a <1x128> layout otherwise it will pick a
|
||||
# <8x128> layout for a <1x128> memref inside the kernel and error out.
|
||||
q = q.reshape(batch_size, num_q_heads, 1, head_dim)
|
||||
if megacore_mode == "kv_head":
|
||||
q_block_spec = pl.BlockSpec(
|
||||
(None, num_groups, None, head_dim),
|
||||
lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0, 0),
|
||||
)
|
||||
elif megacore_mode == "batch":
|
||||
q_block_spec = pl.BlockSpec(
|
||||
(None, num_groups, None, head_dim),
|
||||
lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0, 0),
|
||||
)
|
||||
else:
|
||||
q_block_spec = pl.BlockSpec(
|
||||
(None, num_groups, None, head_dim),
|
||||
lambda core_index, b, h, *_: (b, h, 0, 0),
|
||||
)
|
||||
q_dtype_for_kernel_launch = jnp.float32
|
||||
else:
|
||||
if megacore_mode == "kv_head":
|
||||
q_block_spec = pl.BlockSpec(
|
||||
(None, num_groups, head_dim),
|
||||
lambda core_index, b, h, *_: (b, h * num_cores + core_index, 0),
|
||||
)
|
||||
elif megacore_mode == "batch":
|
||||
q_block_spec = pl.BlockSpec(
|
||||
(None, num_groups, head_dim),
|
||||
lambda core_index, b, h, *_: (b * num_cores + core_index, h, 0),
|
||||
)
|
||||
else:
|
||||
q_block_spec = pl.BlockSpec(
|
||||
(None, num_groups, head_dim),
|
||||
lambda core_index, b, h, *_: (b, h, 0),
|
||||
)
|
||||
q_dtype_for_kernel_launch = q.dtype
|
||||
|
||||
dimension_semantics: Sequence[Literal["parallel", "arbitrary"]]
|
||||
if inline_seq_dim:
|
||||
kernel = paged_flash_attention_kernel_inline_seq_dim
|
||||
grid = (
|
||||
num_cores,
|
||||
batch_size // num_cores if megacore_mode == "batch" else batch_size,
|
||||
num_kv_heads // num_cores
|
||||
if megacore_mode == "kv_head"
|
||||
else num_kv_heads,
|
||||
)
|
||||
dimension_semantics = ("parallel", "arbitrary", "arbitrary")
|
||||
else:
|
||||
kernel = paged_flash_attention_kernel
|
||||
grid = (
|
||||
num_cores,
|
||||
batch_size // num_cores if megacore_mode == "batch" else batch_size,
|
||||
num_kv_heads // num_cores
|
||||
if megacore_mode == "kv_head"
|
||||
else num_kv_heads,
|
||||
pages_per_sequence // pages_per_compute_block,
|
||||
)
|
||||
dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary")
|
||||
|
||||
if k_scales_pages is not None and v_scales_pages is not None:
|
||||
in_specs = [
|
||||
q_block_spec,
|
||||
pl.BlockSpec(memory_space=pl.ANY),
|
||||
pl.BlockSpec(memory_space=pl.ANY),
|
||||
pl.BlockSpec(memory_space=pl.ANY),
|
||||
pl.BlockSpec(memory_space=pl.ANY),
|
||||
]
|
||||
scratch_shapes = (
|
||||
pltpu.VMEM(
|
||||
(
|
||||
2, # For double buffering during DMA copies.
|
||||
pages_per_compute_block,
|
||||
page_size,
|
||||
head_dim,
|
||||
),
|
||||
k_pages.dtype,
|
||||
), # k_pages buffer
|
||||
pltpu.VMEM(
|
||||
(
|
||||
2, # For double buffering during DMA copies.
|
||||
pages_per_compute_block,
|
||||
page_size,
|
||||
head_dim,
|
||||
),
|
||||
k_scales_pages.dtype,
|
||||
), # k_scales_pages buffer
|
||||
pltpu.VMEM(
|
||||
(
|
||||
2, # For double buffering during DMA copies.
|
||||
pages_per_compute_block,
|
||||
page_size,
|
||||
head_dim,
|
||||
),
|
||||
v_pages.dtype,
|
||||
), # v_pages buffer
|
||||
pltpu.VMEM(
|
||||
(
|
||||
2, # For double buffering during DMA copies.
|
||||
pages_per_compute_block,
|
||||
page_size,
|
||||
head_dim,
|
||||
),
|
||||
v_scales_pages.dtype,
|
||||
), # v_scales_pages buffer
|
||||
pltpu.SemaphoreType.DMA((2,)),
|
||||
pltpu.SemaphoreType.DMA((2,)),
|
||||
)
|
||||
else:
|
||||
in_specs = [
|
||||
q_block_spec,
|
||||
pl.BlockSpec(memory_space=pl.ANY),
|
||||
None,
|
||||
pl.BlockSpec(memory_space=pl.ANY),
|
||||
None,
|
||||
]
|
||||
scratch_shapes = (
|
||||
pltpu.VMEM(
|
||||
(
|
||||
2, # For double buffering during DMA copies.
|
||||
pages_per_compute_block,
|
||||
page_size,
|
||||
head_dim,
|
||||
),
|
||||
k_pages.dtype,
|
||||
), # k_pages buffer
|
||||
None,
|
||||
pltpu.VMEM(
|
||||
(
|
||||
2, # For double buffering during DMA copies.
|
||||
pages_per_compute_block,
|
||||
page_size,
|
||||
head_dim,
|
||||
),
|
||||
v_pages.dtype,
|
||||
), # v_pages buffer
|
||||
None,
|
||||
pltpu.SemaphoreType.DMA((2,)),
|
||||
pltpu.SemaphoreType.DMA((2,)),
|
||||
)
|
||||
|
||||
out, _, _ = pl.pallas_call(
|
||||
functools.partial(
|
||||
kernel,
|
||||
pages_per_sequence=pages_per_sequence,
|
||||
batch_size=batch_size,
|
||||
pages_per_compute_block=pages_per_compute_block,
|
||||
mask_value=mask_value,
|
||||
attn_logits_soft_cap=attn_logits_soft_cap,
|
||||
megacore_mode=megacore_mode,
|
||||
),
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
# There are 4 scalars prefetched per kernel call: `lengths_ref`,
|
||||
# `page_indices_ref`, `buffer_index_ref`, `init_flag_ref`
|
||||
num_scalar_prefetch=4,
|
||||
in_specs=in_specs,
|
||||
out_specs=[
|
||||
q_block_spec,
|
||||
q_block_spec,
|
||||
q_block_spec,
|
||||
],
|
||||
grid=grid,
|
||||
scratch_shapes=scratch_shapes,
|
||||
),
|
||||
compiler_params=pltpu.CompilerParams(
|
||||
dimension_semantics=dimension_semantics
|
||||
),
|
||||
out_shape=[
|
||||
jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch),
|
||||
jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
|
||||
jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32),
|
||||
],
|
||||
)(
|
||||
lengths,
|
||||
page_indices.reshape(-1),
|
||||
jnp.zeros((1,), jnp.int32), # buffer index
|
||||
jnp.ones((1,), jnp.int32), # init flag
|
||||
q.astype(q_dtype_for_kernel_launch),
|
||||
k_pages,
|
||||
k_scales_pages,
|
||||
v_pages,
|
||||
v_scales_pages,
|
||||
)
|
||||
return out.reshape(batch_size, num_q_heads, head_dim).astype(q.dtype)
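
# A minimal invocation sketch (hypothetical toy shapes; a TPU backend is
# assumed). The shapes satisfy the validation above: 8 q heads over 2 kv
# heads, and pages_per_compute_block divides pages_per_sequence.
batch_size, num_q_heads, num_kv_heads, head_dim = 2, 8, 2, 128
page_size, pages_per_sequence = 16, 8
total_num_pages = batch_size * pages_per_sequence
q = jnp.zeros((batch_size, num_q_heads, head_dim), jnp.float32)
k_pages = jnp.zeros(
    (num_kv_heads, total_num_pages, page_size, head_dim), jnp.float32
)
v_pages = jnp.zeros_like(k_pages)
lengths = jnp.array([48, 112], jnp.int32)
page_indices = jnp.arange(total_num_pages, dtype=jnp.int32).reshape(
    batch_size, pages_per_sequence
)
out = paged_attention(
    q, k_pages, v_pages, lengths, page_indices, pages_per_compute_block=4
)
assert out.shape == (batch_size, num_q_heads, head_dim)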
|
||||
@@ -0,0 +1,107 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.

from typing import NamedTuple
import jax
from jax import numpy as jnp

P = jax.sharding.PartitionSpec
MAX_INT8 = 127.5


class QuantizedTensor(NamedTuple):
  """An int8-quantized tensor together with its quantization scales.

  Attributes:
    weight: The int8-quantized values.
    scales: The quantization scales (absolute max over the trailing
      dimension).
  """

  weight: jnp.ndarray
  scales: jnp.ndarray


def to_int8(x: jnp.ndarray, h: jnp.ndarray) -> jnp.ndarray:
  """Converts a float array to an int8 array with a scale.

  Args:
    x: Float array.
    h: Quantization scale.

  Returns:
    Int8 array.
  """
  return jnp.int8(jnp.rint(x * (MAX_INT8 / h)))


def from_int8(
    x: jnp.ndarray, h: jnp.ndarray, dtype: jnp.dtype = jnp.bfloat16
) -> jnp.ndarray:
  """Converts an int8 array to a float array with a scale.

  Args:
    x: Int8 array.
    h: Quantization scale.
    dtype: Float dtype to convert to.

  Returns:
    Float array.
  """
  return x.astype(dtype) * h / MAX_INT8


def get_quantization_scales(x: jnp.ndarray) -> jnp.ndarray:
  """Computes the quantization scales for a float array.

  These are the absolute maximum values along the trailing dimension.

  Args:
    x: Float array to quantize.

  Returns:
    An array of the same shape as the input, but with the trailing dimension
    reduced to size 1 and holding the absolute max value.
  """
  return jnp.max(jnp.abs(x), axis=-1, keepdims=True)


def quantize_to_int8(
    x: jnp.ndarray,
) -> QuantizedTensor:
  """Quantizes a float array to an int8 QuantizedTensor.

  Args:
    x: Float array to quantize.

  Returns:
    Int8 QuantizedTensor.
  """
  x_scales = get_quantization_scales(x)
  return QuantizedTensor(weight=to_int8(x, x_scales), scales=x_scales)


def unquantize_from_int8(
    x: QuantizedTensor,
    dtype: jnp.dtype = jnp.bfloat16,
) -> jnp.ndarray:
  """Unquantizes an int8 QuantizedTensor to a float array.

  Args:
    x: Int8 QuantizedTensor to unquantize.
    dtype: Float dtype to unquantize to.

  Returns:
    Float array.
  """
  return from_int8(x.weight, x.scales, dtype)
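

# A round-trip sketch of the helpers above (toy input):
x = jax.random.normal(jax.random.key(0), (4, 128), jnp.float32)
qt = quantize_to_int8(x)  # weight: int8[4, 128], scales: float32[4, 1]
x_approx = unquantize_from_int8(qt, dtype=jnp.float32)
# The per-element rounding error is bounded by roughly scales / (2 * MAX_INT8).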
|
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.

"""JAX reference implementation of grouped query attention."""

import jax
from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils
import jax.numpy as jnp

MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)


def grouped_query_attention_reference(
    queries: jax.Array,  # [batch_size, num_q_heads, head_dim]
    k_pages: jax.Array,  # [batch_size, num_kv_heads, max_seq_len, head_dim]
    v_pages: jax.Array,  # [batch_size, num_kv_heads, max_seq_len, head_dim]
    seq_lens: jax.Array,  # i32[batch_size]
    soft_cap: float | None = None,
    debug: bool = False,
) -> jax.Array:  # [batch_size, num_q_heads, head_dim]
  """Grouped query attention with a single query per request."""
  # Check input shapes
  assert k_pages.shape == v_pages.shape
  batch_size, num_q_heads, head_dim = queries.shape
  batch_size2, num_kv_heads, max_seq_len, head_dim2 = k_pages.shape
  assert batch_size2 == batch_size
  assert head_dim2 == head_dim

  # Unquantize kv pages if necessary
  if isinstance(k_pages, quantization_utils.QuantizedTensor):
    k_pages = quantization_utils.unquantize_from_int8(
        k_pages, dtype=jnp.float32
    )
  if isinstance(v_pages, quantization_utils.QuantizedTensor):
    v_pages = quantization_utils.unquantize_from_int8(
        v_pages, dtype=jnp.float32
    )

  # Reshape for num_groups queries per k head
  assert num_q_heads % num_kv_heads == 0
  num_groups = num_q_heads // num_kv_heads
  queries = queries.reshape(batch_size, num_kv_heads, num_groups, head_dim)

  # Compute the dot product q*k and apply soft cap if necessary
  qk = jnp.einsum(
      "bhgd,bhtd->bhgt",
      queries.astype(jnp.float32),
      k_pages.astype(jnp.float32),
  )
  if soft_cap is not None and soft_cap != 0.0:
    qk = jnp.tanh(qk / soft_cap) * soft_cap
  assert qk.shape == (batch_size, num_kv_heads, num_groups, max_seq_len)
  if debug:
    jax.debug.print("qk: {qk}", qk=qk)

  # Mask out kv positions beyond each sequence's length (adding dimensions
  # when necessary)
  mask = jnp.arange(max_seq_len)[None] < seq_lens[:, None]
  qk += jnp.where(mask, 0.0, MASK_VALUE)[:, None, None, :]
  if debug:
    jax.debug.print("masked: {qk}", qk=qk)

  # Generate probability distribution using softmax
  probs = jax.nn.softmax(qk, axis=-1).astype(v_pages.dtype)
  assert probs.shape == (batch_size, num_kv_heads, num_groups, max_seq_len)
  if debug:
    jax.debug.print("softmax: {probs}", probs=probs)

  # Attention is probability-weighted sum of v heads
  attention = jnp.einsum("bhgt,bhtd->bhgd", probs, v_pages)
  assert attention.shape == (batch_size, num_kv_heads, num_groups, head_dim)
  return attention.reshape(batch_size, num_q_heads, head_dim)
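

# A toy invocation sketch (hypothetical shapes):
B, Hq, Hkv, T, D = 2, 8, 2, 32, 64
q = jnp.ones((B, Hq, D), jnp.float32)
k = jnp.ones((B, Hkv, T, D), jnp.float32)
v = jnp.ones((B, Hkv, T, D), jnp.float32)
seq_lens = jnp.array([5, 32], jnp.int32)
out = grouped_query_attention_reference(q, k, v, seq_lens)
assert out.shape == (B, Hq, D)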
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.

from jax.experimental.pallas.ops.tpu.ragged_paged_attention import kernel
from jax.experimental.pallas.ops.tpu.ragged_paged_attention import tuned_block_sizes

dynamic_validate_inputs = kernel.dynamic_validate_inputs
ragged_paged_attention = kernel.ragged_paged_attention
ref_ragged_paged_attention = kernel.ref_ragged_paged_attention
static_validate_inputs = kernel.static_validate_inputs
get_tuned_block_sizes = tuned_block_sizes.get_tuned_block_sizes
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,899 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""TPU-Friendly Ragged Paged Attention kernel.
|
||||
|
||||
This kernel offers a highly optimized implementation of ragged paged attention,
|
||||
specifically designed for TPU and compatible with a wide range of model
|
||||
specifications. It supports mixed prefill and decoding, enhancing throughput
|
||||
during inference.
|
||||
"""
|
||||
import functools
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax.experimental import pallas as pl
|
||||
from jax.experimental.pallas import tpu as pltpu
|
||||
from jax.experimental.pallas.ops.tpu.ragged_paged_attention.tuned_block_sizes import get_tuned_block_sizes
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
|
||||
|
||||
|
||||
class MultiPageAsyncCopyDescriptor:
|
||||
"""Descriptor for async copy of multiple K/V pages from HBM."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads_per_blk, head_dim]
|
||||
vmem_buf, # [num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
|
||||
sem,
|
||||
page_indices_ref, # i32[max_num_seqs, pages_per_seq]
|
||||
metadata, # [seq_idx, start_page_idx, end_page_idx]
|
||||
):
|
||||
self._vmem_buf = vmem_buf
|
||||
seq_id, start_page_idx, end_page_idx = metadata
|
||||
self._async_copies = []
|
||||
# TODO(jevinjiang): Only fetch the dynamic shape when needed! This will
# insert a bunch of if-ops. Check the performance once we have a
# benchmarking setup.
|
||||
for i in range(vmem_buf.shape[0]):
|
||||
page_idx = start_page_idx + i
|
||||
page_idx = jax.lax.select(page_idx < end_page_idx, page_idx, 0)
|
||||
self._async_copies.append(
|
||||
pltpu.make_async_copy(
|
||||
pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]],
|
||||
vmem_buf.at[i],
|
||||
sem,
|
||||
)
|
||||
)
|
||||
|
||||
def start(self):
|
||||
"""Starts the async copies."""
|
||||
for async_copy in self._async_copies:
|
||||
async_copy.start()
|
||||
|
||||
def wait(self):
|
||||
for async_copy in self._async_copies:
|
||||
async_copy.wait()
|
||||
return self._vmem_buf
|
||||
|
||||
|
||||
def ref_ragged_paged_attention(
|
||||
queries: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
|
||||
kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
|
||||
kv_lens: jax.Array, # i32[max_num_seqs]
|
||||
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
|
||||
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
||||
num_seqs: jax.Array, # i32[1],
|
||||
*,
|
||||
sm_scale: float = 1.0,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
mask_value: float | None = DEFAULT_MASK_VALUE,
|
||||
k_scale: float | None = None,
|
||||
v_scale: float | None = None,
|
||||
):
|
||||
static_validate_inputs(
|
||||
queries,
|
||||
kv_pages,
|
||||
kv_lens,
|
||||
page_indices,
|
||||
cu_q_lens,
|
||||
num_seqs,
|
||||
sm_scale=sm_scale,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
mask_value=mask_value,
|
||||
)
|
||||
if mask_value is None:
|
||||
mask_value = DEFAULT_MASK_VALUE
|
||||
_, _, num_combined_kv_heads, head_dim = kv_pages.shape
|
||||
assert num_combined_kv_heads % 2 == 0
|
||||
num_kv_heads = num_combined_kv_heads // 2
|
||||
num_q_heads = queries.shape[1]
|
||||
assert num_q_heads % num_kv_heads == 0
|
||||
num_query_per_kv = num_q_heads // num_kv_heads
|
||||
outputs = []
|
||||
for i in range(num_seqs[0]):
|
||||
q_start = cu_q_lens[i]
|
||||
q_end = cu_q_lens[i + 1]
|
||||
q_len = q_end - q_start
|
||||
kv_len = kv_lens[i]
|
||||
indices = page_indices[i]
|
||||
q = queries[q_start:q_end]
|
||||
k = kv_pages[indices, :, 0::2, :].reshape(-1, num_kv_heads, head_dim)[
|
||||
:kv_len
|
||||
]
|
||||
v = kv_pages[indices, :, 1::2, :].reshape(-1, num_kv_heads, head_dim)[
|
||||
:kv_len
|
||||
]
|
||||
if k_scale is not None:
|
||||
k = k.astype(jnp.float32) * k_scale
|
||||
k = k.astype(q.dtype)
|
||||
if v_scale is not None:
|
||||
v = v.astype(jnp.float32) * v_scale
|
||||
v = v.astype(q.dtype)
|
||||
k = jnp.repeat(k, num_query_per_kv, axis=1)
|
||||
v = jnp.repeat(v, num_query_per_kv, axis=1)
|
||||
attn = jnp.einsum("qhd,khd->hqk", q, k, preferred_element_type=jnp.float32)
|
||||
attn *= sm_scale
|
||||
q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(
|
||||
jnp.int32, attn.shape, 1
|
||||
)
|
||||
kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
|
||||
mask = q_span < kv_span
|
||||
if sliding_window is not None:
|
||||
mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * jnp.tanh(attn / soft_cap)
|
||||
attn += jnp.where(mask, mask_value, 0.0)
|
||||
attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
|
||||
out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
|
||||
outputs.append(out)
|
||||
|
||||
return jnp.concatenate(outputs, axis=0)
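
# The combined KV layout assumed above keeps K in the even combined-head
# slots and V in the odd ones; a sketch of building such pages from separate
# K and V arrays of shape [pages, page_size, num_kv_heads, head_dim]:
#   kv_pages = jnp.stack([k, v], axis=3).reshape(
#       *k.shape[:2], 2 * k.shape[2], k.shape[3])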
|
||||
|
||||
|
||||
# These checks are expected to run at runtime.
|
||||
def dynamic_validate_inputs(
|
||||
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
|
||||
kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
|
||||
kv_lens: jax.Array, # i32[max_num_seqs]
|
||||
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
|
||||
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
||||
num_seqs: jax.Array, # i32[1]
|
||||
*,
|
||||
# These inputs are optional. If not specified, we will not validate them.
|
||||
sm_scale: float | None = None,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
mask_value: float | None = None,
|
||||
k_scale: float | None = None,
|
||||
v_scale: float | None = None,
|
||||
# Kernel tuning params.
|
||||
num_kv_pages_per_block: int | None = None,
|
||||
num_queries_per_block: int | None = None,
|
||||
vmem_limit_bytes: int | None = None,
|
||||
):
|
||||
static_validate_inputs(
|
||||
q,
|
||||
kv_pages,
|
||||
kv_lens,
|
||||
page_indices,
|
||||
cu_q_lens,
|
||||
num_seqs,
|
||||
sm_scale=sm_scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
mask_value=mask_value,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
num_kv_pages_per_block=num_kv_pages_per_block,
|
||||
num_queries_per_block=num_queries_per_block,
|
||||
vmem_limit_bytes=vmem_limit_bytes,
|
||||
)
|
||||
max_num_batched_tokens = q.shape[0]
|
||||
page_size = kv_pages.shape[1]
|
||||
max_num_seqs, pages_per_seq = page_indices.shape
|
||||
if num_seqs[0] > max_num_seqs:
|
||||
raise ValueError(f"{num_seqs[0]=} must be less or equal to {max_num_seqs=}")
|
||||
max_kv_len = jnp.max(kv_lens)
|
||||
min_pages_per_seq = pl.cdiv(max_kv_len, page_size)
|
||||
if pages_per_seq < min_pages_per_seq:
|
||||
raise ValueError(
|
||||
f"{pages_per_seq=} must be greater or equal to"
|
||||
f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}."
|
||||
)
|
||||
if cu_q_lens[num_seqs[0]] > max_num_batched_tokens:
|
||||
raise ValueError(
|
||||
f"Total q tokens {cu_q_lens[num_seqs[0]]} must be less or equal to"
|
||||
f" {max_num_batched_tokens=}."
|
||||
)
|
||||
for i in range(num_seqs[0]):
|
||||
q_len = cu_q_lens[i + 1] - cu_q_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
if q_len > kv_len:
|
||||
raise ValueError(
|
||||
f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
|
||||
)
|
||||
|
||||
|
||||
# These checks are expected to run at compile time.
|
||||
def static_validate_inputs(
|
||||
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
|
||||
kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
|
||||
kv_lens: jax.Array, # i32[max_num_seqs]
|
||||
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
|
||||
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
||||
num_seqs: jax.Array, # i32[1]
|
||||
*,
|
||||
# These inputs are optional. If not specified, we will not validate them.
|
||||
sm_scale: float | None = None,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
mask_value: float | None = None,
|
||||
k_scale: float | None = None,
|
||||
v_scale: float | None = None,
|
||||
# Kernel tuning params.
|
||||
num_kv_pages_per_block: int | None = None,
|
||||
num_queries_per_block: int | None = None,
|
||||
vmem_limit_bytes: int | None = None,
|
||||
):
|
||||
_, num_q_heads, head_dim = q.shape
|
||||
_, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
|
||||
assert num_combined_kv_heads % 2 == 0
|
||||
assert isinstance(k_scale, float) or k_scale is None
|
||||
assert isinstance(v_scale, float) or v_scale is None
|
||||
num_kv_heads = num_combined_kv_heads // 2
|
||||
max_num_seqs, pages_per_seq = page_indices.shape
|
||||
if num_seqs.shape != (1,):
|
||||
raise ValueError(f"{num_seqs.shape=} must be (1,)")
|
||||
if head_dim_k != head_dim:
|
||||
raise ValueError(
|
||||
f"Q head_dim {head_dim} must be the same as that of K/V {head_dim_k}."
|
||||
)
|
||||
if kv_lens.shape != (max_num_seqs,):
|
||||
raise ValueError(
|
||||
f"Expected {kv_lens.shape=} to be ({max_num_seqs},) where"
|
||||
" `max_num_seqs` is `page_indices.shape[0]`."
|
||||
)
|
||||
if cu_q_lens.shape != (max_num_seqs + 1,):
|
||||
raise ValueError(
|
||||
f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where"
|
||||
" `max_num_seqs` is `page_indices.shape[0]`."
|
||||
)
|
||||
if (
|
||||
kv_lens.dtype != jnp.int32
|
||||
or page_indices.dtype != jnp.int32
|
||||
or cu_q_lens.dtype != jnp.int32
|
||||
):
|
||||
raise ValueError(
|
||||
"The dtype of `kv_lens`, `page_indices`, and `cu_q_lens` must be"
|
||||
f" int32. Got {kv_lens.dtype=}, {page_indices.dtype=},"
|
||||
f" {cu_q_lens.dtype=}."
|
||||
)
|
||||
if num_q_heads % num_kv_heads != 0:
|
||||
raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
|
||||
if sliding_window is not None and sliding_window <= 0:
|
||||
raise ValueError(f"{sliding_window=} must be positive.")
|
||||
if soft_cap is not None and soft_cap == 0.0:
|
||||
raise ValueError(f"{soft_cap=} must not be 0.0.")
|
||||
if (
|
||||
num_kv_pages_per_block is not None
|
||||
and not 0 < num_kv_pages_per_block <= pages_per_seq
|
||||
):
|
||||
raise ValueError(
|
||||
f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}]."
|
||||
)
|
||||
if num_queries_per_block is not None and num_queries_per_block <= 0:
|
||||
raise ValueError(f"{num_queries_per_block=} must be positive.")
|
||||
if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
|
||||
raise ValueError(f"{vmem_limit_bytes=} must be positive.")
|
||||
del sm_scale  # No constraints on sm_scale.
del mask_value  # No constraints on mask_value.
|
||||
|
||||
|
||||
def ragged_paged_attention_kernel(
|
||||
# Prefetch
|
||||
kv_lens_ref, # [max_num_seqs]
|
||||
page_indices_ref, # [max_num_seqs, pages_per_seq]
|
||||
cu_q_lens_ref, # [max_num_seqs + 1]
|
||||
seq_buf_idx_ref,
|
||||
# TODO(jevinjiang): if OOM in SMEM, consider packing into other scalar refs.
|
||||
num_seqs_ref,
|
||||
# Input
|
||||
q_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
|
||||
kv_pages_hbm_ref, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
|
||||
# Output
|
||||
o_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
|
||||
# Scratch
|
||||
kv_bufs, # [2, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, head_dim]
|
||||
sems, # [2, 2]
|
||||
l_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
|
||||
m_ref, # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
|
||||
acc_ref, # [num_q_per_blk, num_q_heads_per_blk, head_dim]
|
||||
*,
|
||||
sm_scale: float,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
mask_value: float | None = DEFAULT_MASK_VALUE,
|
||||
k_scale: float | None = None,
|
||||
v_scale: float | None = None,
|
||||
):
|
||||
if mask_value is None:
|
||||
mask_value = DEFAULT_MASK_VALUE
|
||||
num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
|
||||
pages_per_seq = page_indices_ref.shape[-1]
|
||||
num_seqs = num_seqs_ref[0]
|
||||
_, num_kv_pages_per_blk, page_size, num_combined_kv_heads_per_blk, _ = (
|
||||
kv_bufs.shape
|
||||
)
|
||||
num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2
|
||||
num_kv_per_blk = num_kv_pages_per_blk * page_size
|
||||
num_q_heads_per_kv_head = num_q_heads_per_blk // num_kv_heads_per_blk
|
||||
heads_blk_idx, q_blk_idx = (
|
||||
pl.program_id(0),
|
||||
pl.program_id(1),
|
||||
)
|
||||
num_heads_blks = pl.num_programs(0)
|
||||
init_seq_idx = seq_buf_idx_ref[0]
|
||||
init_buf_idx = seq_buf_idx_ref[1]
|
||||
q_len_start = q_blk_idx * num_q_per_blk
|
||||
q_len_end = q_len_start + num_q_per_blk
|
||||
|
||||
def create_kv_async_copy_descriptors(
|
||||
heads_blk_idx, seq_idx, kv_blk_idx, buf_idx
|
||||
):
|
||||
start_kv_page_idx = kv_blk_idx * num_kv_pages_per_blk
|
||||
end_kv_page_idx = jnp.minimum(
|
||||
pages_per_seq, pl.cdiv(kv_lens_ref[seq_idx], page_size)
|
||||
)
|
||||
metadata = (seq_idx, start_kv_page_idx, end_kv_page_idx)
|
||||
heads_start = heads_blk_idx * num_combined_kv_heads_per_blk
|
||||
async_copy_kv = MultiPageAsyncCopyDescriptor(
|
||||
kv_pages_hbm_ref.at[
|
||||
:, :, pl.ds(heads_start, num_combined_kv_heads_per_blk), :
|
||||
],
|
||||
kv_bufs.at[buf_idx],
|
||||
sems.at[buf_idx],
|
||||
page_indices_ref,
|
||||
metadata,
|
||||
)
|
||||
return async_copy_kv
|
||||
|
||||
# TODO(jevinjiang): Add these to Mosaic:
|
||||
# 1. Support arbitrary strided load/store for int4 and int8 dtype.
|
||||
# 2. Support arbitrary strided load/store for any last dimension.
|
||||
def strided_load_kv(ref, start, step):
|
||||
packing = get_dtype_packing(ref.dtype)
|
||||
if packing == 1:
|
||||
return [ref[start::step, :]], [ref[start + 1 :: step, :]]
|
||||
assert packing in (2, 4, 8)
|
||||
assert step % packing == 0
|
||||
k_list, v_list = [], []
|
||||
b_start = start // packing
|
||||
b_step = step // packing
|
||||
b_ref = ref.bitcast(jnp.uint32)
|
||||
b = b_ref[b_start::b_step, :]
|
||||
|
||||
# TODO(chengjiyao): use the general strided loading logic for bf16 after
|
||||
# fixing the issue in mosaic's infer vector layout pass
|
||||
if ref.dtype == jnp.bfloat16:
|
||||
bk = b << 16
|
||||
bv = b & jnp.uint32(0xFFFF0000)
|
||||
k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16)
|
||||
v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16)
|
||||
k_list.append(k)
|
||||
v_list.append(v)
|
||||
else:
|
||||
bitwidth = 32 // packing
|
||||
bitcast_dst_dtype = jnp.dtype(f"uint{bitwidth}")
|
||||
for i in range(0, packing, 2):
|
||||
bk = b >> (i * bitwidth)
|
||||
k = pltpu.bitcast(bk.astype(bitcast_dst_dtype), ref.dtype)
|
||||
k_list.append(k)
|
||||
bv = b >> ((i + 1) * bitwidth)
|
||||
v = pltpu.bitcast(bv.astype(bitcast_dst_dtype), ref.dtype)
|
||||
v_list.append(v)
|
||||
|
||||
return k_list, v_list
|
||||
|
||||
def fold_on_2nd_minor(vec):
|
||||
assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32
|
||||
assert len(vec.shape) >= 2
|
||||
last_dim = vec.shape[-1]
|
||||
packing = get_dtype_packing(vec.dtype)
|
||||
if vec.shape[-2] % packing != 0:
|
||||
vec = vec.astype(jnp.float32)
|
||||
return vec.reshape(-1, last_dim)
|
||||
|
||||
@pl.when(heads_blk_idx + q_blk_idx == 0)
|
||||
def prefetch_first_kv_blk():
|
||||
async_copy_kv = create_kv_async_copy_descriptors(
|
||||
heads_blk_idx, init_seq_idx, 0, init_buf_idx
|
||||
)
|
||||
async_copy_kv.start()
|
||||
|
||||
def is_cur_q_blk_needed(q_states):
|
||||
done, cur_seq_idx, _ = q_states
|
||||
should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs],
|
||||
cur_seq_idx < num_seqs)
|
||||
return jnp.logical_and(done == 0, should_run)
|
||||
|
||||
def compute_with_cur_q_blk(q_states):
|
||||
done, cur_seq_idx, cur_buf_idx = q_states
|
||||
q_start = cu_q_lens_ref[cur_seq_idx]
|
||||
q_end = cu_q_lens_ref[cur_seq_idx + 1]
|
||||
q_len = q_end - q_start
|
||||
kv_len = kv_lens_ref[cur_seq_idx]
|
||||
|
||||
def get_next_prefetch_ids(
|
||||
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
|
||||
):
|
||||
next_kv_blk_idx = kv_blk_idx + 1
|
||||
is_last_kv_blk = next_kv_blk_idx * num_kv_per_blk >= kv_len
|
||||
next_kv_blk_idx = lax.select(
|
||||
is_last_kv_blk,
|
||||
0,
|
||||
next_kv_blk_idx,
|
||||
)
|
||||
is_cur_seq_end_in_cur_q_blk = q_end <= q_len_end
|
||||
next_seq_idx = lax.select(
|
||||
is_last_kv_blk,
|
||||
lax.select(is_cur_seq_end_in_cur_q_blk, cur_seq_idx + 1, cur_seq_idx),
|
||||
cur_seq_idx,
|
||||
)
|
||||
is_last_seq = next_seq_idx == num_seqs
|
||||
next_seq_idx = lax.select(
|
||||
is_last_seq,
|
||||
0,
|
||||
next_seq_idx,
|
||||
)
|
||||
next_heads_blk_idx = lax.select(
|
||||
is_last_seq,
|
||||
heads_blk_idx + 1,
|
||||
heads_blk_idx,
|
||||
)
|
||||
next_buf_idx = lax.select(cur_buf_idx == 0, 1, 0)
|
||||
return next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx
|
||||
|
||||
def flash_attention(
|
||||
q, # [num_q_per_blk * num_q_heads_per_kv_head, head_dim]
|
||||
k, # [num_kv_per_blk, head_dim]
|
||||
v, # [num_kv_per_blk, head_dim]
|
||||
head_l_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128]
|
||||
head_m_ref, # [num_q_per_blk * num_q_heads_per_kv_head, 128]
|
||||
head_acc_ref, # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
|
||||
*,
|
||||
kv_blk_idx,
|
||||
):
|
||||
assert q.shape == (
|
||||
num_q_per_blk * num_q_heads_per_kv_head,
|
||||
head_dim,
|
||||
)
|
||||
assert (
|
||||
k.shape
|
||||
== v.shape
|
||||
== (
|
||||
num_kv_per_blk,
|
||||
head_dim,
|
||||
)
|
||||
)
|
||||
assert k.dtype == v.dtype
|
||||
assert (
|
||||
head_m_ref.shape
|
||||
== head_l_ref.shape
|
||||
== (
|
||||
num_q_per_blk * num_q_heads_per_kv_head,
|
||||
128,
|
||||
)
|
||||
)
|
||||
assert head_acc_ref.shape == (
|
||||
num_q_per_blk,
|
||||
num_q_heads_per_kv_head,
|
||||
head_dim,
|
||||
)
|
||||
kv_len_start = kv_blk_idx * num_kv_per_blk
|
||||
|
||||
def masked_store(ref, val, start, end, group=1):
|
||||
iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group
|
||||
pltpu.store(ref, val, mask=jnp.logical_and(iota >= start, iota < end))
|
||||
|
||||
def load_with_init(ref, init_val):
|
||||
return jnp.where(
|
||||
kv_blk_idx == 0, jnp.full_like(ref, init_val), ref[...]
|
||||
)
|
||||
|
||||
# kv len will be the contracting dim, so we should mask out the NaNs.
|
||||
kv_mask = (
|
||||
lax.broadcasted_iota(jnp.int32, k.shape, 0) < kv_len - kv_len_start
|
||||
)
|
||||
k = jnp.where(kv_mask, k.astype(jnp.float32), 0).astype(k.dtype)
|
||||
v = jnp.where(kv_mask, v.astype(jnp.float32), 0).astype(v.dtype)
|
||||
|
||||
qk = (
|
||||
jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32)
|
||||
* sm_scale
|
||||
)
|
||||
store_start = jnp.maximum(q_start - q_len_start, 0)
|
||||
store_end = jnp.minimum(q_end - q_len_start, num_q_per_blk)
|
||||
|
||||
row_ids = (
|
||||
(kv_len - q_len)
|
||||
+ q_len_start
|
||||
- q_start
|
||||
+ jax.lax.broadcasted_iota(
|
||||
jnp.int32,
|
||||
(num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
|
||||
0,
|
||||
)
|
||||
// num_q_heads_per_kv_head
|
||||
)
|
||||
col_ids = kv_len_start + jax.lax.broadcasted_iota(
|
||||
jnp.int32,
|
||||
(num_q_per_blk * num_q_heads_per_kv_head, num_kv_per_blk),
|
||||
1,
|
||||
)
|
||||
causal_mask = row_ids < col_ids
|
||||
if sliding_window is not None:
|
||||
causal_mask = jnp.logical_or(causal_mask,
|
||||
row_ids - sliding_window >= col_ids)
|
||||
if soft_cap is not None:
|
||||
qk = soft_cap * jnp.tanh(qk / soft_cap)
|
||||
qk += jnp.where(causal_mask, mask_value, 0.0)
|
||||
m_curr = jnp.max(qk, axis=1, keepdims=True)
|
||||
s_curr = jnp.exp(qk - m_curr)
|
||||
qkv = jnp.dot(s_curr, v, preferred_element_type=jnp.float32)
|
||||
lm_store_shape = head_m_ref.shape
|
||||
m_curr = jnp.broadcast_to(m_curr, lm_store_shape)
|
||||
l_curr = jnp.broadcast_to(
|
||||
s_curr.sum(axis=1, keepdims=True), lm_store_shape
|
||||
)
|
||||
m_prev = load_with_init(head_m_ref, -jnp.inf)
|
||||
l_prev = load_with_init(head_l_ref, 0.0)
|
||||
m_next = jnp.maximum(m_prev, m_curr)
|
||||
masked_store(
|
||||
head_m_ref, m_next, store_start, store_end, num_q_heads_per_kv_head
|
||||
)
|
||||
alpha = jnp.exp(m_prev - m_next)
|
||||
beta = jnp.exp(m_curr - m_next)
|
||||
l_alpha = alpha * l_prev
|
||||
l_next = l_alpha + beta * l_curr
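# A row can be fully masked (or not yet written), in which case l_next == 0;
# substituting 1.0 below keeps the division in `out` finite for such rows.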
|
||||
l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)
|
||||
masked_store(
|
||||
head_l_ref,
|
||||
l_next_safe,
|
||||
store_start,
|
||||
store_end,
|
||||
num_q_heads_per_kv_head,
|
||||
)
|
||||
|
||||
def broadcast_to_shape(arr, shape):
|
||||
if arr.shape == shape:
|
||||
return arr
|
||||
assert len(arr.shape) == len(shape)
|
||||
assert arr.shape[0] == shape[0]
|
||||
assert shape[1] % arr.shape[1] == 0
|
||||
# no-op concatenation.
|
||||
return jnp.concatenate(
|
||||
[arr for _ in range(shape[1] // arr.shape[1])], axis=1
|
||||
)
|
||||
|
||||
o_curr = load_with_init(head_acc_ref, 0.0).reshape(-1, head_dim)
|
||||
l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
|
||||
beta = broadcast_to_shape(beta, qkv.shape)
|
||||
l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
|
||||
out = lax.div(
|
||||
l_alpha * o_curr + beta * qkv,
|
||||
l_next_safe,
|
||||
)
|
||||
masked_store(
|
||||
head_acc_ref,
|
||||
out.reshape(head_acc_ref.shape),
|
||||
store_start,
|
||||
store_end,
|
||||
)
|
||||
|
||||
def is_valid_kv_blk_in_cur_seq(kv_states):
|
||||
kv_blk_idx, _ = kv_states
|
||||
return kv_blk_idx * num_kv_per_blk < kv_len
|
||||
|
||||
def compute_with_kv_blk_in_cur_seq(kv_states):
|
||||
kv_blk_idx, cur_buf_idx = kv_states
|
||||
next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx = (
|
||||
get_next_prefetch_ids(
|
||||
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
|
||||
)
|
||||
)
|
||||
|
||||
@pl.when(next_heads_blk_idx < num_heads_blks)
|
||||
def prefetch_next_kv_blk():
|
||||
# TODO(jevinjiang): reuse the same buffer if it is already prefetched!
|
||||
# TODO(jevinjiang): only fetch effective dynamic size to hold kv_len and
|
||||
# DMA to fixed size buffer!
|
||||
next_async_copy_kv = create_kv_async_copy_descriptors(
|
||||
next_heads_blk_idx, next_seq_idx, next_kv_blk_idx, next_buf_idx
|
||||
)
|
||||
next_async_copy_kv.start()
|
||||
|
||||
cur_async_copy_kv = create_kv_async_copy_descriptors(
|
||||
heads_blk_idx, cur_seq_idx, kv_blk_idx, cur_buf_idx
|
||||
)
|
||||
kv_ref = cur_async_copy_kv.wait().reshape(
|
||||
num_kv_pages_per_blk * page_size * num_combined_kv_heads_per_blk,
|
||||
head_dim,
|
||||
)
|
||||
kv_packing = get_dtype_packing(kv_ref.dtype)
|
||||
# NOTE: kv_packing is divided by 2 because k and v are packed together.
|
||||
kv_load_step = max(1, kv_packing // 2)
|
||||
for kv_head_chunk_idx in range(0, num_kv_heads_per_blk, kv_load_step):
|
||||
k_list, v_list = strided_load_kv(
|
||||
kv_ref, kv_head_chunk_idx * 2, num_combined_kv_heads_per_blk
|
||||
)
|
||||
for step_idx in range(kv_load_step):
|
||||
k = k_list[step_idx]
|
||||
v = v_list[step_idx]
|
||||
if k_scale is not None:
|
||||
# NOTE: Conversion between arbitrary data types is not supported.
|
||||
# That's why it is converted to float32 first.
|
||||
k = k.astype(jnp.float32) * k_scale
|
||||
k = k.astype(q_ref.dtype)
|
||||
if v_scale is not None:
|
||||
v = v.astype(jnp.float32) * v_scale
|
||||
v = v.astype(q_ref.dtype)
|
||||
kv_head_idx = kv_head_chunk_idx + step_idx
|
||||
q_head_idx = kv_head_idx * num_q_heads_per_kv_head
|
||||
# TODO(jevinjiang): extra handling for packed type that can start at
|
||||
# unaligned position!
|
||||
q = fold_on_2nd_minor(
|
||||
q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :]
|
||||
)
|
||||
flash_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
l_ref.at[kv_head_idx],
|
||||
m_ref.at[kv_head_idx],
|
||||
acc_ref.at[
|
||||
:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :
|
||||
],
|
||||
kv_blk_idx=kv_blk_idx,
|
||||
)
|
||||
return kv_blk_idx + 1, next_buf_idx
|
||||
|
||||
_, next_buf_idx = lax.while_loop(
|
||||
is_valid_kv_blk_in_cur_seq,
|
||||
compute_with_kv_blk_in_cur_seq,
|
||||
(0, cur_buf_idx), # (kv_blk_idx, buf_idx)
|
||||
)
|
||||
next_seq_idx = lax.select(q_end <= q_len_end, cur_seq_idx + 1, cur_seq_idx)
|
||||
done = lax.select(q_end < q_len_end, done, 1)
|
||||
return done, next_seq_idx, next_buf_idx
|
||||
|
||||
_, seq_idx, buf_idx = lax.while_loop(
|
||||
is_cur_q_blk_needed,
|
||||
compute_with_cur_q_blk,
|
||||
(0, init_seq_idx, init_buf_idx), # (done, seq_idx, buf_idx)
|
||||
)
|
||||
# Reset seq_idx for the next kv_heads_blk if we ran out of seqs!
|
||||
seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
|
||||
seq_buf_idx_ref[1] = buf_idx
|
||||
o_ref[...] = acc_ref[...].astype(q_ref.dtype)
|
||||
|
||||
|
||||
def get_dtype_packing(dtype):
|
||||
bits = dtypes.itemsize_bits(dtype)
|
||||
return 32 // bits
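# For example: float32 -> 1, bfloat16 -> 2, int8 -> 4, int4 -> 8 elements per
# packed 32-bit word.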
|
||||
|
||||
|
||||
def get_min_heads_per_blk(
|
||||
num_q_heads, num_combined_kv_heads, q_dtype, kv_dtype
|
||||
):
|
||||
q_packing = get_dtype_packing(q_dtype)
|
||||
kv_packing = get_dtype_packing(kv_dtype)
|
||||
|
||||
def can_be_xla_fully_tiled(x, packing):
|
||||
if x % packing != 0:
|
||||
return False
|
||||
x //= packing
|
||||
return x in (1, 2, 4, 8) or x % 8 == 0
|
||||
|
||||
# TODO(jevinjiang): support unaligned number of heads!
|
||||
if not can_be_xla_fully_tiled(num_combined_kv_heads, kv_packing):
|
||||
raise ValueError(
|
||||
f"Not implemented: {num_combined_kv_heads=} can not be XLA fully tiled."
|
||||
)
|
||||
assert num_combined_kv_heads % 2 == 0
|
||||
num_kv_heads = num_combined_kv_heads // 2
|
||||
assert num_q_heads % num_kv_heads == 0
|
||||
ratio = num_q_heads // num_kv_heads
|
||||
# TODO(jevinjiang): we can choose smaller tiling for packed type if large
|
||||
# second minor tiling is not on.
|
||||
max_combined_kv_tiling = 8 * kv_packing
|
||||
min_combined_kv_heads = (
|
||||
max_combined_kv_tiling
|
||||
if num_combined_kv_heads % max_combined_kv_tiling == 0
|
||||
else num_combined_kv_heads
|
||||
)
|
||||
min_q_heads = min_combined_kv_heads // 2 * ratio
|
||||
if can_be_xla_fully_tiled(min_q_heads, q_packing):
|
||||
return min_q_heads, min_combined_kv_heads
|
||||
return num_q_heads, num_combined_kv_heads
|
||||
|
||||
|
||||
@functools.partial(
|
||||
jax.jit,
|
||||
static_argnames=[
|
||||
"sm_scale",
|
||||
"mask_value",
|
||||
"num_kv_pages_per_block",
|
||||
"num_queries_per_block",
|
||||
"vmem_limit_bytes",
|
||||
"sliding_window",
|
||||
"soft_cap",
|
||||
"k_scale",
|
||||
"v_scale",
|
||||
],
|
||||
)
|
||||
def ragged_paged_attention(
|
||||
q: jax.Array, # [max_num_batched_tokens, num_q_heads, head_dim]
|
||||
# TODO(jevinjiang): create a write_to_kv_cache kernel!
|
||||
kv_pages: jax.Array, # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
|
||||
kv_lens: jax.Array, # i32[max_num_seqs]
|
||||
page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq]
|
||||
cu_q_lens: jax.Array, # i32[max_num_seqs + 1]
|
||||
num_seqs: jax.Array, # i32[1]
|
||||
*,
|
||||
sm_scale: float = 1.0,
|
||||
sliding_window: int | None = None,
|
||||
soft_cap: float | None = None,
|
||||
mask_value: float | None = DEFAULT_MASK_VALUE,
|
||||
k_scale: float | None = None,
|
||||
v_scale: float | None = None,
|
||||
num_kv_pages_per_block: int | None = None,
|
||||
num_queries_per_block: int | None = None,
|
||||
vmem_limit_bytes: int | None = None,
|
||||
):
|
||||
"""Ragged paged attention that supports mixed prefill and decode.
|
||||
|
||||
Args:
|
||||
q: concatenated all sequences' queries.
|
||||
kv_pages: paged KV cache. Normally in HBM.
|
||||
kv_lens: padded kv lengths. Only the first num_seqs values are valid.
|
||||
page_indices: for each sequence, the indices of the pages that hold its kv
cache. Only the first num_seqs rows are valid.
|
||||
cu_q_lens: the cumulative sum of the effective query lengths. Similar to
|
||||
kv_lens, only the first num_seqs+1 values are valid.
|
||||
num_seqs: the dynamic number of sequences.
|
||||
sm_scale: the softmax scale which will be applied to the Q@K^T.
|
||||
sliding_window: the sliding window size for the attention.
|
||||
soft_cap: the logit soft cap for the attention.
|
||||
mask_value: mask value for causal mask.
|
||||
k_scale: the scale for the key cache.
|
||||
v_scale: the scale for the value cache.
|
||||
num_kv_pages_per_block: number of kv pages to be processed in one flash
|
||||
attention block in the pallas kernel.
|
||||
num_queries_per_block: number of queries to be processed in one flash
attention block in the pallas kernel.
|
||||
vmem_limit_bytes: the vmem limit for the pallas kernel.
|
||||
|
||||
Returns:
|
||||
The output of the attention.
|
||||
"""
|
||||
static_validate_inputs(
|
||||
q,
|
||||
kv_pages,
|
||||
kv_lens,
|
||||
page_indices,
|
||||
cu_q_lens,
|
||||
num_seqs,
|
||||
sm_scale=sm_scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
mask_value=mask_value,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
num_kv_pages_per_block=num_kv_pages_per_block,
|
||||
num_queries_per_block=num_queries_per_block,
|
||||
vmem_limit_bytes=vmem_limit_bytes,
|
||||
)
|
||||
if mask_value is None:
|
||||
mask_value = DEFAULT_MASK_VALUE
|
||||
num_q_tokens, num_q_heads, head_dim = q.shape
|
||||
_, page_size, num_combined_kv_heads, _ = kv_pages.shape
|
||||
assert num_combined_kv_heads % 2 == 0
|
||||
num_kv_heads = num_combined_kv_heads // 2
|
||||
_, pages_per_seq = page_indices.shape
|
||||
num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk(
|
||||
num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype
|
||||
)
|
||||
num_q_per_blk = num_queries_per_block
|
||||
num_kv_pages_per_blk = num_kv_pages_per_block
|
||||
if num_q_per_blk is None or num_kv_pages_per_blk is None:
|
||||
num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes(
|
||||
q.dtype,
|
||||
kv_pages.dtype,
|
||||
num_q_heads_per_blk,
|
||||
num_combined_kv_heads_per_blk // 2,
|
||||
head_dim,
|
||||
page_size,
|
||||
num_q_tokens,
|
||||
pages_per_seq,
|
||||
)
|
||||
num_q_heads_per_kv_head = num_q_heads // num_kv_heads
|
||||
num_q_blks = pl.cdiv(num_q_tokens, num_q_per_blk)
|
||||
assert num_combined_kv_heads_per_blk % 2 == 0
|
||||
num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2
|
||||
assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0
|
||||
num_heads_blks = num_q_heads // num_q_heads_per_blk
|
||||
grid = (num_heads_blks, num_q_blks)
|
||||
|
||||
def q_index_map(heads_blk_idx, q_blk_idx, *_):
|
||||
return (q_blk_idx, heads_blk_idx, 0)
|
||||
|
||||
q_block_spec = pl.BlockSpec(
|
||||
(num_q_per_blk, num_q_heads_per_blk, head_dim),
|
||||
q_index_map,
|
||||
)
|
||||
in_specs = [
|
||||
q_block_spec,
|
||||
pl.BlockSpec(memory_space=pl.ANY),
|
||||
]
|
||||
out_specs = q_block_spec
|
||||
lm_scratch = pltpu.VMEM(
|
||||
# TODO(jevinjiang): we use 128 instead of 1 because Mosaic does not support
# unaligned slicing!
|
||||
(num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
|
||||
jnp.float32,
|
||||
)
|
||||
acc_scratch = pltpu.VMEM(
|
||||
(num_q_per_blk, num_q_heads_per_blk, head_dim),
|
||||
jnp.float32,
|
||||
)
|
||||
double_buf_scratch = pltpu.VMEM(
|
||||
(
|
||||
2, # For double buffering during DMA copies.
|
||||
num_kv_pages_per_blk,
|
||||
page_size,
|
||||
num_combined_kv_heads_per_blk,
|
||||
head_dim,
|
||||
),
|
||||
kv_pages.dtype,
|
||||
)
|
||||
scratch_shapes = [
|
||||
double_buf_scratch, # kv_bufs
|
||||
pltpu.SemaphoreType.DMA((2,)), # Semaphores for double buffers.
|
||||
lm_scratch, # l_ref
|
||||
lm_scratch, # m_ref
|
||||
acc_scratch,
|
||||
]
|
||||
scalar_prefetches = (
|
||||
kv_lens,
|
||||
page_indices,
|
||||
cu_q_lens,
|
||||
jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx
|
||||
num_seqs,
|
||||
)
|
||||
kernel = pl.pallas_call(
|
||||
functools.partial(
|
||||
ragged_paged_attention_kernel,
|
||||
sm_scale=sm_scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
mask_value=mask_value,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
),
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=len(scalar_prefetches),
|
||||
in_specs=in_specs,
|
||||
out_specs=out_specs,
|
||||
grid=grid,
|
||||
scratch_shapes=scratch_shapes,
|
||||
),
|
||||
compiler_params=pltpu.CompilerParams(
|
||||
dimension_semantics=(
|
||||
"arbitrary",
|
||||
"arbitrary",
|
||||
),
|
||||
vmem_limit_bytes=vmem_limit_bytes,
|
||||
),
|
||||
out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
|
||||
name="ragged_paged_attention_kernel",
|
||||
)
|
||||
|
||||
return kernel(*scalar_prefetches, q, kv_pages)
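

# A mixed prefill/decode invocation sketch (hypothetical toy shapes; a TPU
# backend is assumed). Sequence 0 contributes 100 prefill tokens and
# sequence 1 a single decode token.
max_num_seqs, pages_per_seq, page_size, head_dim = 2, 8, 16, 128
num_q_heads, num_kv_heads = 8, 2
total_pages = max_num_seqs * pages_per_seq
q = jnp.zeros((128, num_q_heads, head_dim), jnp.bfloat16)
kv_pages = jnp.zeros(
    (total_pages, page_size, 2 * num_kv_heads, head_dim), jnp.bfloat16
)
kv_lens = jnp.array([100, 33], jnp.int32)
cu_q_lens = jnp.array([0, 100, 101], jnp.int32)
page_indices = jnp.arange(total_pages, dtype=jnp.int32).reshape(
    max_num_seqs, pages_per_seq
)
num_seqs = jnp.array([2], jnp.int32)
out = ragged_paged_attention(
    q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
    sm_scale=head_dim**-0.5,
)
assert out.shape == q.shape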
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,14 @@
|
||||
# Copyright 2025 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,210 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the Philox PRNG as a Pallas kernel."""

from collections.abc import Sequence

import jax
from jax import typing
from jax._src import prng
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas.ops.tpu.random import prng_utils

Shape = Sequence[int]

BLOCK_SIZE = (256, 256)

# Philox constants. See the original paper:
# "Parallel Random Numbers: As Easy as 1, 2, 3", Salmon et al. 2011.
K_HI_32 = 0x9E3779B9
K_LO_32 = 0xBB67AE85
MUL_A = 0xCD9E8D57
MUL_B = 0xD2511F53


def mul32_hi_lo(x: jax.Array, y: jax.Array) -> tuple[jax.Array, jax.Array]:
  """Multiplies two 32-bit values and returns the hi and lo bits of the result."""
  xhi = x >> 16
  yhi = y >> 16
  xlo = x & 0xffff
  ylo = y & 0xffff

  xy_hi = xhi * yhi
  xy_lo = xlo * ylo
  cross_xy = xhi * ylo
  cross_yx = xlo * yhi
  carry = (cross_xy & 0xffff) + (cross_yx & 0xffff) + (xy_lo >> 16)
  result_hi = xy_hi + (cross_xy >> 16) + (cross_yx >> 16) + (carry >> 16)
  result_lo = (carry << 16) + (xy_lo & 0xffff)
  return result_hi, result_lo
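A quick way to sanity-check `mul32_hi_lo` outside the kernel is to compare it against an exact 64-bit product (a minimal sketch, runnable on any backend; the operand values are arbitrary):

import jax.numpy as jnp

x = jnp.uint32(0xDEADBEEF)
y = jnp.uint32(0x12345678)
hi, lo = mul32_hi_lo(x, y)
full = 0xDEADBEEF * 0x12345678  # exact product, in Python integers
assert int(hi) == full >> 32
assert int(lo) == full & 0xFFFFFFFF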
def philox_4x32(hi0, lo0, hi1, lo1, k_hi, k_lo, rounds: int = 10):
  """Philox 4x32 keyed hash function."""
  k_hi_const = jnp.array(K_HI_32, dtype=jnp.uint32)
  k_lo_const = jnp.array(K_LO_32, dtype=jnp.uint32)
  mul_a = jnp.array(MUL_A, dtype=jnp.uint32)
  mul_b = jnp.array(MUL_B, dtype=jnp.uint32)

  for i in range(rounds):
    # Compute the round.
    new_hi0, new_lo0 = mul32_hi_lo(mul_a, hi1)
    new_hi0 = new_hi0 ^ lo0 ^ k_hi
    new_hi1, new_lo1 = mul32_hi_lo(mul_b, hi0)
    new_hi1 = new_hi1 ^ lo1 ^ k_lo
    hi0, lo0, hi1, lo1 = new_hi0, new_lo0, new_hi1, new_lo1

    # Raise the key on all iterations except for the last round.
    if i != rounds - 1:
      k_hi = k_hi + k_hi_const
      k_lo = k_lo + k_lo_const
  return hi0, lo0, hi1, lo1
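`philox_4x32` operates element-wise on uint32 arrays, so it can be exercised directly with plain `jnp` inputs outside of any kernel (a sketch with arbitrary key and counter values):

import jax.numpy as jnp

counters = jnp.arange(8, dtype=jnp.uint32).reshape(2, 4)
zeros = jnp.zeros_like(counters)
k_hi = jnp.uint32(0)
k_lo = jnp.uint32(42)
o0_hi, o0_lo, o1_hi, o1_lo = philox_4x32(zeros, counters, zeros, zeros, k_hi, k_lo)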
def philox_4x32_kernel(key,
                       shape: Shape,
                       unpadded_shape: Shape,
                       block_size: tuple[int, int],
                       offset: typing.ArrayLike = 0,
                       fuse_output: bool = True):
  """Generates random bits using the Philox keyed hash function.

  Args:
    key: A Philox key of shape (2,).
    shape: The shape of the output. Must be divisible by `block_size`.
    unpadded_shape: If `shape` is padded, then this is the shape of the
      output tensor if it were not padded. This is important for indexing
      calculations within the kernel. If `shape` is not padded, then this
      should be equal to `shape`.
    block_size: The block size of the kernel.
    offset: An optional offset to the counts.
    fuse_output: Whether to fuse the output bits into a single value.

  Returns:
    A tensor of random bits of shape `shape` if fuse_output=True. Otherwise,
    this will return a tensor of shape (2, *shape) with the first channel being
    the high bits and the second channel being the low bits.
  """
  shape = tuple(shape)
  if np.prod(shape) > jnp.iinfo(jnp.uint32).max:
    raise ValueError(
        f"Shape too large: {np.prod(shape)} > {jnp.iinfo(jnp.uint32).max}")

  if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0):
    raise ValueError(
        f"Shape dimension {shape[-2:]} must be divisible by {block_size}")
  grid_dims = shape[:-2] + (
      shape[-2] // block_size[-2], shape[-1] // block_size[-1],)
  offset = jnp.array(offset, dtype=jnp.uint32)
  if offset.ndim != 0:
    raise ValueError(f"Offset must be scalar, got {offset.shape}")
  offset = jnp.reshape(offset, (1,))

  def kernel(offset_ref, key_ref, out_ref):
    counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
    offset = prng_utils.compute_scalar_offset(
        counts_idx, unpadded_shape, block_shape)
    counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
    counts_lo = counts_lo + offset.astype(jnp.uint32) + offset_ref[0]
    counts_lo = counts_lo.astype(jnp.uint32)
    # TODO(justinfu): Support hi bits on count.
    _zeros = jnp.zeros_like(counts_lo)
    k1 = jnp.reshape(key_ref[0, 0], (1, 1))
    k2 = jnp.reshape(key_ref[0, 1], (1, 1))
    o1, o2, _, _ = philox_4x32(_zeros, counts_lo, _zeros, _zeros, k1, k2)
    if fuse_output:
      out_bits = o1 ^ o2
      out_ref[...] = out_bits.reshape(out_ref.shape)
    else:
      out_ref[0, ...] = o1.reshape(out_ref[0].shape)
      out_ref[1, ...] = o2.reshape(out_ref[0].shape)

  key = key.reshape((1, 2))
  block_shape = (1,) * (len(shape) - 2) + block_size
  if fuse_output:
    out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32)
    out_spec = pl.BlockSpec(block_shape, lambda *idxs: idxs)
  else:
    out = jax.ShapeDtypeStruct((2,) + shape, dtype=jnp.uint32)
    out_spec = pl.BlockSpec((2,) + block_shape, lambda *idxs: (0, *idxs))
  return pl.pallas_call(
      kernel,
      in_specs=[
          pl.BlockSpec(memory_space=pltpu.SMEM),
          pl.BlockSpec(memory_space=pltpu.SMEM),
      ],
      out_specs=out_spec,
      grid=grid_dims,
      out_shape=out,
  )(offset, key)
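On a TPU backend the kernel can be invoked directly when the requested shape is already block-aligned (a sketch with an arbitrary key; `philox_4x32_count` below handles the general, padded case):

import jax.numpy as jnp

key = jnp.array([0, 42], dtype=jnp.uint32)  # arbitrary Philox key
bits = philox_4x32_kernel(key, (512, 512), (512, 512), block_size=(256, 256))
assert bits.shape == (512, 512) and bits.dtype == jnp.uint32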
def philox_4x32_count(key,
                      shape: Shape,
                      offset: typing.ArrayLike = 0,
                      fuse_output: bool = True):
  """Convenience function to call philox_4x32_kernel with padded shapes."""
  if len(shape) == 0:
    return philox_4x32_count(
        key, (1, 1), offset=offset, fuse_output=fuse_output)[..., 0, 0]
  elif len(shape) == 1:
    return philox_4x32_count(
        key, (1, *shape), offset=offset, fuse_output=fuse_output)[..., 0, :]

  requires_pad = (
      shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0)
  if requires_pad:
    padded_shape = tuple(shape[:-2]) + (
        prng_utils.round_up(shape[-2], BLOCK_SIZE[-2]),
        prng_utils.round_up(shape[-1], BLOCK_SIZE[-1]),
    )
    padded_result = philox_4x32_kernel(
        key, padded_shape, shape,
        block_size=BLOCK_SIZE, offset=offset,
        fuse_output=fuse_output)
    return padded_result[..., :shape[-2], :shape[-1]]
  else:
    return philox_4x32_kernel(key, shape, shape,
                              block_size=BLOCK_SIZE, offset=offset,
                              fuse_output=fuse_output)
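Because `philox_4x32_count` rounds the trailing dimensions up to `BLOCK_SIZE` and slices the result back, callers can request arbitrary shapes (a sketch, again with an arbitrary key):

import jax.numpy as jnp

key = jnp.array([0, 42], dtype=jnp.uint32)
bits = philox_4x32_count(key, (300, 500))  # padded to (512, 512) internally
assert bits.shape == (300, 500)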
def philox_split(key, shape: Shape):
  """Splits the key into `shape` new keys."""
  bits1, bits2 = philox_4x32_count(key, shape, fuse_output=False)
  return jnp.stack([bits1, bits2], axis=bits1.ndim)


def philox_random_bits(key, bit_width: int, shape: Shape):
  if bit_width != 32:
    raise ValueError("Only 32-bit PRNG supported.")
  return philox_4x32_count(key, shape, fuse_output=True)


def philox_fold_in(key, data):
  assert data.ndim == 0
  return philox_4x32_count(key, (), offset=data, fuse_output=False)


plphilox_prng_impl = prng.PRNGImpl(
    key_shape=(2,),
    seed=prng.threefry_seed,
    split=philox_split,
    random_bits=philox_random_bits,
    fold_in=philox_fold_in,
    name="pallas_philox4x32",
    tag="pllox")

prng.register_prng(plphilox_prng_impl)
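Once registered, the implementation can be selected by name through the standard `jax.random` API (a sketch; assumes this module has been imported so the registration side effect has run, and a TPU backend so the Pallas kernel can execute):

import jax

key = jax.random.key(0, impl="pallas_philox4x32")
bits = jax.random.bits(key, (256, 256))  # dispatches to the Pallas kernel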
+55
@@ -0,0 +1,55 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for PRNG kernels."""

from collections.abc import Sequence

from jax import lax
import jax.numpy as jnp

Shape = Sequence[int]

round_up = lambda x, y: (x + y - 1) // y * y


def blocked_iota(block_shape: Shape,
                 total_shape: Shape):
  """Computes a sub-block of a larger shaped iota.

  Args:
    block_shape: The output block shape of the iota.
    total_shape: The total shape of the input tensor.

  Returns:
    Result of the blocked iota.
  """
  iota_data = jnp.zeros(block_shape, dtype=jnp.uint32)
  multiplier = 1
  for dim in range(len(block_shape) - 1, -1, -1):
    counts_lo = lax.broadcasted_iota(
        dtype=jnp.uint32, shape=block_shape, dimension=dim
    )
    iota_data += counts_lo * multiplier
    multiplier *= total_shape[dim]
  return iota_data


def compute_scalar_offset(iteration_index,
                          total_size: Shape,
                          block_size: Shape):
  ndims = len(iteration_index)
  dim_size = 1
  total_idx = 0
  for i in range(ndims - 1, -1, -1):
    dim_idx = iteration_index[i] * block_size[i]
    total_idx += dim_idx * dim_size
    dim_size *= total_size[i]
  return total_idx
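Together, `blocked_iota` and `compute_scalar_offset` reconstruct the global flattened iota one block at a time, which is how the PRNG kernels derive unique per-element counters. A small host-side check (a sketch):

import numpy as np

total, block = (4, 8), (2, 4)
tile = np.asarray(blocked_iota(block, total))      # block-local iota with global strides
off = compute_scalar_offset((1, 1), total, block)  # flat offset of block (1, 1)
ref = np.arange(4 * 8, dtype=np.uint32).reshape(4, 8)
assert np.array_equal(tile + off, ref[2:4, 4:8])   # block (1, 1) covers rows 2:4, cols 4:8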
@@ -0,0 +1,120 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the Threefry PRNG as a Pallas kernel."""

from collections.abc import Sequence

import jax
from jax._src import prng
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas.ops.tpu.random import prng_utils

Shape = Sequence[int]

BLOCK_SIZE = (256, 256)


def threefry_2x32_count(key,
                        shape: Shape,
                        unpadded_shape: Shape,
                        block_size: tuple[int, int]):
  """Generates random bits using the Threefry hash function.

  This function is a fusion of prng.shaped_iota and prng.threefry_2x32 from
  the JAX core library.

  Args:
    key: A threefry key of shape (2,).
    shape: The shape of the output. Must be divisible by `block_size`.
    unpadded_shape: If `shape` is padded, then this is the shape of the
      output tensor if it were not padded. This is important for indexing
      calculations within the kernel. If `shape` is not padded, then this
      should be equal to `shape`.
    block_size: The block size of the kernel.

  Returns:
    A tensor of random bits of shape `shape`.
  """
  shape = tuple(shape)
  if np.prod(shape) > jnp.iinfo(jnp.uint32).max:
    raise ValueError(
        f"Shape too large: {np.prod(shape)} > {jnp.iinfo(jnp.uint32).max}")

  if (shape[-2] % block_size[-2] != 0) or (shape[-1] % block_size[-1] != 0):
    raise ValueError(
        f"Shape dimension {shape[-2:]} must be divisible by {block_size}")
  grid_dims = shape[:-2] + (
      shape[-2] // block_size[-2], shape[-1] // block_size[-1],)

  def kernel(key_ref, out_ref):
    counts_idx = tuple(pl.program_id(i) for i in range(len(grid_dims)))
    offset = prng_utils.compute_scalar_offset(
        counts_idx, unpadded_shape, block_shape)
    counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
    counts_lo = counts_lo + offset.astype(jnp.uint32)
    counts_lo = counts_lo.astype(jnp.uint32)
    # TODO(justinfu): Support hi bits on count.
    counts_hi = jnp.zeros_like(counts_lo)
    k1 = jnp.reshape(key_ref[0, 0], (1, 1))
    k2 = jnp.reshape(key_ref[0, 1], (1, 1))
    o1, o2 = prng.threefry2x32_p.bind(
        k1, k2, counts_hi, counts_lo)
    out_bits = o1 ^ o2
    out_ref[...] = out_bits.reshape(out_ref.shape)

  key = key.reshape((1, 2))
  out = jax.ShapeDtypeStruct(shape, dtype=jnp.uint32)
  block_shape = (1,) * (len(shape) - 2) + block_size
  result = pl.pallas_call(
      kernel,
      in_specs=[pl.BlockSpec(memory_space=pltpu.SMEM)],
      out_specs=pl.BlockSpec(block_shape, lambda *idxs: idxs),
      grid=grid_dims,
      out_shape=out,
  )(key)
  return result


def plthreefry_random_bits(key, bit_width: int, shape: Shape):
  if bit_width != 32:
    raise ValueError("Only 32-bit PRNG supported.")
  if len(shape) == 0:
    return plthreefry_random_bits(key, bit_width, (1, 1))[0, 0]
  elif len(shape) == 1:
    return plthreefry_random_bits(key, bit_width, (1, *shape))[0]

  requires_pad = (
      shape[-2] % BLOCK_SIZE[-2] != 0) or (shape[-1] % BLOCK_SIZE[-1] != 0)
  if requires_pad:
    padded_shape = tuple(shape[:-2]) + (
        prng_utils.round_up(shape[-2], BLOCK_SIZE[-2]),
        prng_utils.round_up(shape[-1], BLOCK_SIZE[-1]),
    )
    padded_result = threefry_2x32_count(
        key, padded_shape, shape, block_size=BLOCK_SIZE)
    return padded_result[..., :shape[-2], :shape[-1]]
  else:
    return threefry_2x32_count(key, shape, shape, block_size=BLOCK_SIZE)


plthreefry_prng_impl = prng.PRNGImpl(
    key_shape=(2,),
    seed=prng.threefry_seed,
    split=prng.threefry_split,
    random_bits=plthreefry_random_bits,
    fold_in=prng.threefry_fold_in,
    name="pallas_threefry2x32",
    tag="plfry")

prng.register_prng(plthreefry_prng_impl)
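Since `seed`, `split`, and `fold_in` reuse the core Threefry helpers, only `random_bits` actually dispatches to the Pallas kernel; keys are created and split as usual (a sketch; assumes this module is imported and a TPU backend is available):

import jax

key = jax.random.key(0, impl="pallas_threefry2x32")
k1, k2 = jax.random.split(key)          # uses prng.threefry_split from core
bits = jax.random.bits(k1, (300, 500))  # Pallas kernel, padded internally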
+32
@@ -0,0 +1,32 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import BlockSizes as BlockSizes
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mha_reference as make_masked_mha_reference
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_masked_mqa_reference as make_masked_mqa_reference
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha as make_splash_mha
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mha_single_device as make_splash_mha_single_device
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa as make_splash_mqa
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import make_splash_mqa_single_device as make_splash_mqa_single_device
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import QKVLayout as QKVLayout
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_kernel import SegmentIds as SegmentIds
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import CausalMask as CausalMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import FullMask as FullMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import LocalMask as LocalMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_causal_mask as make_causal_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_local_attention_mask as make_local_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import make_random_mask as make_random_mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import Mask as Mask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import MultiHeadMask as MultiHeadMask
from jax.experimental.pallas.ops.tpu.splash_attention.splash_attention_mask import NumpyMask as NumpyMask
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
+2637
File diff suppressed because it is too large
+560
@@ -0,0 +1,560 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mini-mask creation library."""

from __future__ import annotations

from collections.abc import Callable, Sequence
import dataclasses
from typing import Any
import numpy as np


class Mask:
  """A base class for splash attention masks."""

  @property
  def shape(self) -> tuple[int, ...]:
    raise NotImplementedError

  def __getitem__(self, idx) -> np.ndarray:
    raise NotImplementedError

  def __bool__(self) -> bool:
    raise NotImplementedError(
        'Conversion to bool is unsupported. Could be caused by using logical'
        ' instead of bitwise operations on masks.'
    )

  def __or__(self, other: Mask) -> Mask:
    if self.shape != other.shape:
      raise ValueError(
          f'Invalid shape for other: {other.shape}, expected: {self.shape}'
      )
    return LogicalOr(self, other)

  def __and__(self, other: Mask) -> Mask:
    if self.shape != other.shape:
      raise ValueError(
          f'Invalid shape for other: {other.shape}, expected: {self.shape}'
      )
    return LogicalAnd(self, other)


def make_causal_mask(shape: tuple[int, int], offset: int = 0) -> np.ndarray:
  """Makes a causal attention mask.

  Args:
    shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len).
    offset: Offset of q start wrt kv. A positive offset shifts the bottom
      triangle upward, a negative one shifts it downward. A negative offset
      makes the first 'offset' rows of the attention matrix all 0s which leads
      to undefined softmax.

  Returns:
    The causal mask.
  """
  q_seq_len, kv_seq_len = shape
  q_idx = np.arange(q_seq_len, dtype=np.int32)
  kv_idx = np.arange(kv_seq_len, dtype=np.int32)
  return (q_idx[:, None] + offset >= kv_idx[None, :]).astype(np.bool_)
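As a concrete check, the zero-offset causal mask is exactly the lower-triangular pattern (a sketch):

import numpy as np

m = make_causal_mask((4, 4))
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [1, 1, 1, 1]]  (as booleans)
assert np.array_equal(m, np.tril(np.ones((4, 4), dtype=np.bool_)))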
def make_local_attention_mask(
    shape: tuple[int, int],
    window_size: tuple[int | None, int | None],
    *,
    offset: int = 0,
) -> np.ndarray:
  """Makes a local attention mask."""
  q_seq_len, kv_seq_len = shape
  q_idx = np.arange(q_seq_len, dtype=np.int32)
  kv_idx = np.arange(kv_seq_len, dtype=np.int32)
  mask = np.ones((q_seq_len, kv_seq_len), dtype=np.bool_)
  left, right = window_size
  if left is not None:
    mask = mask & (q_idx[:, None] - left + offset <= kv_idx[None, :])
  if right is not None:
    mask = mask & (q_idx[:, None] + right + offset >= kv_idx[None, :])
  return mask.astype(np.bool_)
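For instance, a symmetric window of one token on each side yields a tridiagonal band (a sketch):

m = make_local_attention_mask((4, 4), window_size=(1, 1))
# [[1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [0, 1, 1, 1],
#  [0, 0, 1, 1]]
assert m[0, 1] and not m[0, 2]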
def make_chunk_attention_mask(
    shape: tuple[int, int], chunk_size: int
) -> np.ndarray:
  """Makes a chunked causal attention mask.

  Args:
    shape: The desired shape of the mask (q_seq_len, kv_seq_len).
    chunk_size: The size of the attention chunks.

  Returns:
    A boolean mask of shape `shape` where True indicates attention is
    allowed according to chunked causal rules, and False otherwise.

  Raises:
    ValueError: If chunk_size is not positive.
  """
  if chunk_size <= 0:
    raise ValueError('chunk_size must be positive')

  q_seq_len, kv_seq_len = shape
  q_idx = np.arange(q_seq_len, dtype=np.int32)
  kv_idx = np.arange(kv_seq_len, dtype=np.int32)

  # A position may attend causally, but only within its own chunk.
  same_chunk = (q_idx[:, None] // chunk_size) == (kv_idx[None, :] // chunk_size)
  mask = same_chunk & (q_idx[:, None] >= kv_idx[None, :])
  return mask
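With a chunk size of 2, the lower triangle is cut at each chunk boundary (a sketch):

m = make_chunk_attention_mask((4, 4), chunk_size=2)
# [[1, 0, 0, 0],
#  [1, 1, 0, 0],
#  [0, 0, 1, 0],
#  [0, 0, 1, 1]]
assert not m[2, 1]  # token 2 cannot see chunk 0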
def make_random_mask(
    shape: tuple[int, int], sparsity: float, seed: int
) -> np.ndarray:
  """Makes a random attention mask."""
  np.random.seed(seed)
  return np.random.binomial(n=1, p=1.0 - sparsity, size=shape).astype(np.bool_)


@dataclasses.dataclass
class LogicalOr(Mask):
  left: Mask
  right: Mask

  def __init__(self, left: Mask, right: Mask):
    if left.shape != right.shape:
      raise ValueError('Masks must have the same shape')
    self.left = left
    self.right = right

  @property
  def shape(self) -> tuple[int, ...]:
    return self.left.shape

  def __getitem__(self, idx) -> np.ndarray:
    return self.left[idx] | self.right[idx]

  def __hash__(self):
    return hash((type(self),) + (self.left, self.right))


@dataclasses.dataclass
class LogicalAnd(Mask):
  left: Mask
  right: Mask

  def __init__(self, left: Mask, right: Mask):
    if left.shape != right.shape:
      raise ValueError('Masks must have the same shape')
    self.left = left
    self.right = right

  @property
  def shape(self) -> tuple[int, ...]:
    return self.left.shape

  def __getitem__(self, idx) -> np.ndarray:
    return self.left[idx] & self.right[idx]

  def __hash__(self):
    return hash((type(self),) + (self.left, self.right))
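These wrappers are what make the `|` and `&` operators on `Mask` lazy: the combination is only evaluated per requested slice. A sketch using the dense helpers above and `NumpyMask`, which is defined later in this file:

import numpy as np

causal = NumpyMask(make_causal_mask((8, 8)))
local = NumpyMask(make_local_attention_mask((8, 8), window_size=(2, None)))
combined = causal & local           # LogicalAnd; nothing is materialized yet
block = combined[0:4, 0:4]          # operand slices are combined per block
assert np.array_equal(block, causal[0:4, 0:4] & local[0:4, 0:4])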
@dataclasses.dataclass
class MultiHeadMask(Mask):
  """Lazy multihead mask, combines multiple lazy masks, one per head."""

  masks: Sequence[Mask]

  def __post_init__(self):
    if not self.masks:
      raise ValueError('Unsupported empty tuple of masks')

    shape = self.masks[0].shape
    for mask in self.masks[1:]:
      if shape != mask.shape:
        raise ValueError(
            f'Unexpected mask shape, got: {mask.shape}, expected: {shape}'
        )

    if not all(isinstance(mask, Mask) for mask in self.masks):
      raise ValueError('masks should be of type Mask')

    if any(isinstance(mask, MultiHeadMask) for mask in self.masks):
      raise ValueError('Nesting MultiHeadMasks is not supported')

  @property
  def shape(self) -> tuple[int, ...]:
    return (len(self.masks),) + self.masks[0].shape

  def __getitem__(self, idx) -> np.ndarray:
    if len(idx) != 3:
      raise NotImplementedError(f'Unsupported slice: {idx}')

    head_slice = idx[0]
    if isinstance(head_slice, int):
      assert 0 <= head_slice < len(self.masks)
      return self.masks[head_slice][idx[1:]]
    else:
      slice_masks = [mask[idx[1:]] for mask in self.masks[head_slice]]
      return np.stack(slice_masks)

  def __eq__(self, other: object):
    if not isinstance(other, type(self)):
      return NotImplemented

    return self.masks == other.masks

  def __hash__(self):
    return hash((type(self),) + tuple(hash(mask) for mask in self.masks))
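A per-head stack is indexed with a 3-tuple: an integer head index returns the underlying 2-D mask slice, while a head slice stacks the per-head slices. A sketch, using the `CausalMask` and `FullMask` classes defined later in this file:

mask = MultiHeadMask([CausalMask((8, 8)), FullMask((8, 8))])
assert mask.shape == (2, 8, 8)
head0 = mask[0, 0:4, 0:4]    # np.ndarray of shape (4, 4)
both = mask[0:2, 0:4, 0:4]   # np.ndarray of shape (2, 4, 4)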
class _ComputableMask(Mask):
  """Superclass for all masks that can be computed inside the kernel using a callable object.

  This subclass is designed to be used with Splash Attention.
  It allows the mask logic to be computed on-the-fly or fused into the attention
  kernel, avoiding the memory cost of materializing the full
  (sequence_length, sequence_length) boolean mask array, which can be excessive
  for long sequences.

  Attributes:
    _shape: Shape of the 2-dim mask: (q_seq_len, kv_seq_len).
    q_sequence: Indices of the Q sequence. q_sequence is reused across
      __getitem__ calls, which is important for compile-time performance.
    mask_function: Function used by the SplashAttention kernel to compute the
      mask rather than loading it.
  """

  _shape: tuple[int, int]
  q_sequence: np.ndarray
  mask_function: Callable[..., Any]

  def __init__(
      self,
      shape: tuple[int, int],
      mask_function: Callable[..., Any],
      shard_count: int = 1,
  ):
    self._shape = shape
    self.mask_function = mask_function
    q_seq_len = self.shape[0]

    if q_seq_len % (shard_count * shard_count) != 0:
      raise ValueError(
          f'Shard count squared ({shard_count * shard_count}) must'
          f' divide Q seq_len ({self.shape[0]}) evenly.'
      )

    self.q_sequence = np.arange(q_seq_len, dtype=np.int32)

  @property
  def shape(self) -> tuple[int, ...]:
    return self._shape

  def __getitem__(self, idx) -> np.ndarray:
    if len(idx) != 2:
      raise NotImplementedError(f'Unsupported slice: {idx}')

    q_slice, kv_slice = idx
    if not isinstance(q_slice, slice) or not isinstance(kv_slice, slice):
      raise NotImplementedError(f'Unsupported slice: {idx}')

    q_slice = _fill_slice(q_slice, self.shape[0])
    kv_slice = _fill_slice(kv_slice, self.shape[1])

    rows = self.q_sequence[q_slice]
    cols = np.arange(kv_slice.start, kv_slice.stop)

    return self.mask_function(rows[:, None], cols[None, :])

  def __eq__(self, other: object):
    raise NotImplementedError()

  def __hash__(self):
    raise NotImplementedError()


class CausalMask(_ComputableMask):
  """Lazy causal mask, prevents the model from attending to future tokens.

  Attributes:
    offset: Offset of q start wrt kv. A positive offset shifts the bottom
      triangle upward, a negative one shifts it downward. A negative offset
      makes the first 'offset' rows of the attention matrix all 0s which leads
      to undefined softmax.
  """

  offset: int

  def __init__(
      self,
      shape: tuple[int, int],
      offset: int = 0,
      shard_count: int = 1,
  ):
    self.offset = offset

    def causal_mask_function(q_ids, kv_ids):
      # When evaluating the mask in _process_mask we typically work with numpy
      # array views. Avoid the addition when possible to avoid instantiating
      # an actual array.
      if self.offset == 0:
        return q_ids >= kv_ids
      else:
        return q_ids + self.offset >= kv_ids

    mask_function = causal_mask_function

    super().__init__(
        shape=shape,
        mask_function=mask_function,
        shard_count=shard_count,
    )

  def __eq__(self, other: object):
    if not isinstance(other, type(self)):
      return NotImplemented

    return (
        self.shape == other.shape
        and self.offset == other.offset
        and np.array_equal(self.q_sequence, other.q_sequence)
    )

  def __hash__(self):
    return hash((
        type(self),
        self.shape,
        self.offset,
        self.q_sequence.tobytes() if self.q_sequence is not None else None,
    ))
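The point of the lazy classes is that only the requested block is ever materialized; even a very long sequence costs nothing until sliced (a sketch):

import numpy as np

mask = CausalMask((1 << 20, 1 << 20))  # no full 1M x 1M array is allocated
tile = mask[0:256, 0:256]              # only this block is computed
assert np.array_equal(tile, make_causal_mask((256, 256)))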
class ChunkedCausalMask(_ComputableMask):
  """Lazy chunked causal mask.

  Attention is causal within each chunk (0, K), (K, 2K), (2K, 3K), ...: tokens
  attend to each other within a chunk but not across chunks.
  Llama4 models use interleaved chunk attention along with global attention.

  Attributes:
    chunk_size: The size of each attention chunk.
  """

  chunk_size: int

  def __init__(
      self,
      shape: tuple[int, int],
      chunk_size: int,
      shard_count: int = 1,
  ):
    if chunk_size <= 0:
      raise ValueError('chunk_size must be positive')
    self.chunk_size = chunk_size

    # Define the mask function for chunked attention.
    def chunked_causal_mask_function(q_ids, kv_ids):
      """Computes the mask logic for the given slice indices."""
      # Condition 1: Same chunk.
      same_chunk = (q_ids // self.chunk_size) == (kv_ids // self.chunk_size)

      # Condition 2: Causal.
      causal = q_ids >= kv_ids

      return same_chunk & causal

    super().__init__(
        shape=shape,
        mask_function=chunked_causal_mask_function,
        shard_count=shard_count,
    )

  def __eq__(self, other: object):
    if not isinstance(other, type(self)):
      return NotImplemented

    return (
        self.shape == other.shape
        and self.chunk_size == other.chunk_size
        and np.array_equal(self.q_sequence, other.q_sequence)
    )

  def __hash__(self):
    return hash((
        type(self),
        self.shape,
        self.chunk_size,
        self.q_sequence.tobytes() if self.q_sequence is not None else None,
    ))
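The lazy chunked mask agrees with the dense `make_chunk_attention_mask` helper above (a sketch):

import numpy as np

lazy = ChunkedCausalMask((16, 16), chunk_size=4)
assert np.array_equal(lazy[0:16, 0:16], make_chunk_attention_mask((16, 16), 4))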
class LocalMask(_ComputableMask):
  """Lazy local mask, prevents model from attending to tokens outside window.

  Attributes:
    window_size: Size of the two sides of the local window (None identifies no
      limit for the given side).
    offset: Offset of q start wrt kv. A positive offset shifts the bottom
      triangle upward, a negative one shifts it downward. A negative offset
      makes the first 'offset' rows of the attention matrix all 0s which leads
      to undefined softmax.
  """

  window_size: tuple[int | None, int | None]
  offset: int

  def __init__(
      self,
      shape: tuple[int, int],
      window_size: tuple[int | None, int | None],
      offset: int,
      shard_count: int = 1,
  ):
    self.window_size = window_size
    self.offset = offset

    def local_mask_function(q_ids, kv_ids):
      """Computes the local attention mask for the given slice indices."""
      left_size, right_size = self.window_size

      assert q_ids.ndim == 2
      assert kv_ids.ndim == 2

      if left_size is None and right_size is None:
        return np.ones((q_ids.shape[0], kv_ids.shape[1]), dtype=np.bool_)

      # Avoid the addition when possible to avoid instantiating an actual
      # array.
      if self.offset != 0:
        shifted_q_ids = q_ids + self.offset
      else:
        shifted_q_ids = q_ids

      mask = None
      if left_size is not None:
        mask = shifted_q_ids - left_size <= kv_ids
      if right_size is not None:
        if mask is None:
          mask = shifted_q_ids + right_size >= kv_ids
        else:
          mask &= shifted_q_ids + right_size >= kv_ids
      return mask

    super().__init__(
        shape=shape,
        mask_function=local_mask_function,
        shard_count=shard_count,
    )

  def __eq__(self, other: object):
    if not isinstance(other, type(self)):
      return NotImplemented

    return (
        self.shape == other.shape
        and self.window_size == other.window_size
        and self.offset == other.offset
        and np.array_equal(self.q_sequence, other.q_sequence)
    )

  def __hash__(self):
    return hash((
        type(self),
        self.shape,
        self.window_size,
        self.offset,
        self.q_sequence.tobytes() if self.q_sequence is not None else None,
    ))


@dataclasses.dataclass
class NumpyMask(Mask):
  """A mask backed by a dense numpy array."""

  array: np.ndarray

  def __post_init__(self):
    if self.array.ndim != 2:
      raise ValueError('Expected a 2-dim array')

    if self.array.dtype != np.bool_:
      raise ValueError('Mask must be a boolean array')

  @property
  def shape(self) -> tuple[int, ...]:
    return self.array.shape

  def __getitem__(self, idx) -> np.ndarray:
    return self.array[idx]

  def __eq__(self, other: object):
    if not isinstance(other, type(self)):
      return NotImplemented

    return np.array_equal(self.array, other.array, equal_nan=True)

  def __hash__(self):
    return hash((type(self), self.array.tobytes()))


def _fill_slice(inp_slice: slice, size: int) -> slice:
  assert inp_slice.step is None or inp_slice.step == 1
  start = 0 if inp_slice.start is None else inp_slice.start
  stop = size if inp_slice.stop is None else inp_slice.stop
  assert start >= 0
  assert stop <= size
  return slice(start, stop, None)


@dataclasses.dataclass(frozen=True)
class FullMask(Mask):
  """Lazy full mask, allows all tokens to attend to all other tokens."""

  # TODO(amagni): Transform FullMask into a _ComputableMask.

  _shape: tuple[int, int]

  def __post_init__(self):
    if not isinstance(self.shape, tuple):
      raise ValueError(f'Unsupported shape type: {type(self.shape)}')

  @property
  def shape(self) -> tuple[int, ...]:
    return self._shape

  def __getitem__(self, idx) -> np.ndarray:
    if len(idx) != 2:
      raise NotImplementedError(f'Unsupported slice: {idx}')
    i, j = idx
    if not isinstance(i, slice) or not isinstance(j, slice):
      raise NotImplementedError(f'Unsupported slice: {idx}')
    i = _fill_slice(i, self.shape[0])
    j = _fill_slice(j, self.shape[1])
    return np.ones((i.stop - i.start, j.stop - j.start), dtype=np.bool_)

  def __eq__(self, other: object):
    if not isinstance(other, type(self)):
      return NotImplemented

    return self.shape == other.shape

  def __hash__(self):
    return hash((type(self), self.shape))
+1133
File diff suppressed because it is too large