1067 lines
36 KiB
Python
1067 lines
36 KiB
Python
# Copyright 2025 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Defines expressions and constraints over layouts."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import abc
|
|
from collections.abc import Sequence
|
|
import dataclasses
|
|
import math
|
|
from typing import Any, assert_never, final
|
|
|
|
from . import fragmented_array as fa
|
|
from . import inference_utils
|
|
from . import launch_context as lc
|
|
from . import layouts as layouts_lib
|
|
from . import tcgen05
|
|
from . import utils
|
|
|
|
|
|
# TODO(bchetioui): consider defining an interface for variable keys that carry
# shape and memory space information.
# Type alias for the identifier wrapped by `Variable`. It is deliberately left
# as `Any`; per the `Variable` docstring, keys are expected to be hashable.
VariableKey = Any
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Variable:
  """A variable is an abstract identifier.

  `key` is supposed to be hashable.
  """
  # Opaque identifier that distinguishes this variable from the others.
  key: VariableKey

  def __str__(self) -> str:
    # Compact rendering used when printing expressions and constraints.
    return f"V({self.key})"
|
|
|
|
|
|
class Constant(abc.ABC):
  """A constant is a known layout.

  This is the common base class of the concrete constant wrappers defined in
  this module (`RegisterLayout`, `TMEMLayout` and `SMEMTiling`).
  """
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class RegisterLayout(Constant):
  """Wraps a known register layout."""

  # The wrapped fragmented (register) layout.
  value: fa.FragmentedLayout

  def __str__(self) -> str:
    # All constants render as `C(<value>)`.
    return f"C({self.value})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class TMEMLayout(Constant):
  """Wraps a known TMEM layout."""

  # The wrapped TMEM layout.
  value: tcgen05.TMEMLayout

  def __str__(self) -> str:
    # All constants render as `C(<value>)`.
    return f"C({self.value})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class SMEMTiling(Constant):
  """Wraps a known SMEM Tile Transform.

  If an SMEM reference may, in principle, have transforms but should not be
  tiled, then `value` is `None`.
  """

  # The wrapped tile transform, or `None` when the reference must not be tiled.
  value: lc.TileTransform | None

  def __str__(self) -> str:
    # All constants render as `C(<value>)`; an untiled reference prints
    # `C(None)`.
    return f"C({self.value})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Reduce:
  """Expression standing for the layout of `expression` reduced along `axes`."""
  # The expression describing the layout of the reduction operand.
  expression: Expression
  # The axes that are reduced.
  axes: tuple[int, ...]
  # The rank of the shape of the input to the reduction. It is necessary to
  # know this in order to reduce `TiledLayout`s correctly.
  rank: int
  # If `True`, the axes which are reduced are left in the result as dimensions
  # with size one.
  keep_dims: bool = False

  def __str__(self) -> str:
    return (
        f"Reduce([{self.axes}], {self.expression}, rank={self.rank},"
        f" keep_dims={self.keep_dims})"
    )
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Reshape:
  """Expression standing for the layout of `expression` after a reshape.

  No `__str__` is defined here, so instances print using the dataclass-generated
  `repr`.
  """
  # The expression describing the layout of the reshape operand.
  expression: Expression
  # The shape of the operand before the reshape.
  source_shape: tuple[int, ...]
  # The shape of the result of the reshape.
  target_shape: tuple[int, ...]
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Transpose:
  """Expression standing for the transposition of `expression`."""
  # The expression describing the layout/tiling of the transpose operand.
  expression: Expression

  def __str__(self) -> str:
    return f"T({self.expression})"
|
|
|
|
|
|
# The union of all node types that may appear in a layout expression: free
# variables, known constants, and the operators built on top of other
# expressions.
Expression = (
    Variable
    | Constant
    | Reduce
    | Reshape
    | Transpose
)
|
|
|
|
|
|
def reduce_reshape_expression(
|
|
reshape: Reshape, assignments: dict[Variable, Constant]
|
|
) -> Expression | Unsatisfiable:
|
|
reduced_expr = reduce_expression(reshape.expression, assignments)
|
|
match reduced_expr:
|
|
case Unsatisfiable():
|
|
return Unsatisfiable()
|
|
case RegisterLayout(value=layout):
|
|
match layout:
|
|
case fa.WGSplatFragLayout(shape=shape):
|
|
assert math.prod(shape) == math.prod(reshape.target_shape)
|
|
return RegisterLayout(
|
|
fa.WGSplatFragLayout(shape=reshape.target_shape)
|
|
)
|
|
case fa.WGStridedFragLayout(shape=shape, vec_size=vec_size):
|
|
assert math.prod(shape) == math.prod(reshape.target_shape)
|
|
return RegisterLayout(
|
|
fa.WGStridedFragLayout(
|
|
shape=reshape.target_shape, vec_size=vec_size
|
|
)
|
|
)
|
|
case fa.TiledLayout() as tiled_layout:
|
|
tile_shape = tiled_layout.base_tile_shape
|
|
if len(reshape.target_shape) < len(tile_shape):
|
|
return dataclasses.replace(reshape, expression=reduced_expr)
|
|
# Even if the new shape is not perfectly tilable, it is possible that
|
|
# we may be able to reshape the tiling itself in a way that is
|
|
# compatible with the new shape. We do not handle this case at the
|
|
# moment.
|
|
for ts, s in zip(tile_shape, reshape.source_shape[-len(tile_shape):], strict=True):
|
|
if s % ts != 0:
|
|
return dataclasses.replace(reshape, expression=reduced_expr)
|
|
|
|
# If minor tiled dimensions are modified, then reshaping is likely to
|
|
# not be a no-op since the strides between tiles will change,
|
|
# potentially mapping different elements to lanes and warps. We don't
|
|
# attempt to handle this case at the moment.
|
|
num_minor_tiled_dims = len(tile_shape) - 1
|
|
source_minor_tiled_dims = reshape.source_shape[-num_minor_tiled_dims:]
|
|
target_minor_tiled_dims = reshape.target_shape[-num_minor_tiled_dims:]
|
|
major_tiled_dim = tile_shape[0]
|
|
if (source_minor_tiled_dims != target_minor_tiled_dims or
|
|
reshape.target_shape[-len(tile_shape)] % major_tiled_dim != 0):
|
|
return dataclasses.replace(reshape, expression=reduced_expr)
|
|
# At this point, we now that only non-tiled dimensions and/or the
|
|
# majormost tiled dimensions may have changed. We also know that the
|
|
# majormost tiled dimension is still tilable in the new shape.
|
|
# Therefore, we can return the tiled layout as is.
|
|
return RegisterLayout(tiled_layout)
|
|
case _:
|
|
return dataclasses.replace(reshape, expression=reduced_expr)
|
|
|
|
|
|
def reduce_transpose_expression(
    transpose: Transpose, assignments: dict[Variable, Constant]
) -> Expression | Unsatisfiable:
  """Reduces a `Transpose` expression given known variable assignments.

  Only known SMEM tilings can be transposed: an untiled reference stays
  untiled, and a 2D tiling has its two dimensions swapped. Any other operand
  is left wrapped in a `Transpose` node.

  Raises:
    NotImplementedError: if the operand is an SMEM tiling that is not 2D.
  """
  operand = reduce_expression(transpose.expression, assignments)
  if isinstance(operand, Unsatisfiable):
    return Unsatisfiable()
  if not isinstance(operand, SMEMTiling):
    # Not a known SMEM tiling: keep the transpose around the reduced operand.
    return Transpose(expression=operand)

  tile_transform = operand.value
  if tile_transform is None:
    return SMEMTiling(None)

  tiling = tile_transform.tiling
  if len(tiling) != 2:
    raise NotImplementedError(
        f"Only 2D tilings are supported, got {len(tiling)}"
    )
  return SMEMTiling(lc.TileTransform(tiling[::-1]))
|
|
|
|
|
|
def reduce_reduce_expression(
    expr: Reduce, assignments: dict[Variable, Constant]
) -> Expression | Unsatisfiable:
  """Reduces a `Reduce` expression given known variable assignments.

  The operand is reduced first. When it is a known register layout for which
  the reduced layout can be computed, that layout is returned as a
  `RegisterLayout`; otherwise the `Reduce` node is kept around the reduced
  operand.
  """
  operand = reduce_expression(expr.expression, assignments)
  if isinstance(operand, Unsatisfiable):
    return Unsatisfiable()

  def keep_reduce() -> Expression:
    """We don't know how to reduce further."""
    return dataclasses.replace(expr, expression=operand)

  if not isinstance(operand, RegisterLayout):
    return keep_reduce()
  layout = operand.value

  if isinstance(layout, fa.TiledLayout):
    # TODO(allanrenucci): Add support for reducing tiled layouts when keep_dims=True.
    if expr.keep_dims:
      return keep_reduce()
    # Only the axes that fall in the tiled (trailing) dimensions affect the
    # tiling; they are re-indexed relative to the first tiled dimension.
    first_tiled_dim = expr.rank - len(layout.base_tile_shape)
    tiled_axes = [
        axis - first_tiled_dim for axis in expr.axes if axis >= first_tiled_dim
    ]
    if not tiled_axes:
      return RegisterLayout(layout)
    return RegisterLayout(layout.reduce(tiled_axes))

  if isinstance(layout, fa.WGStridedFragLayout):
    # We only support reducing leading dimensions.
    leading_axes = tuple(range(len(expr.axes)))
    if expr.axes != leading_axes:
      return keep_reduce()
    reduced_shape = utils.reduce_shape(layout.shape, expr.axes, expr.keep_dims)
    elems_per_step = layout.vec_size * fa.WARPGROUP_SIZE
    if math.prod(reduced_shape) % elems_per_step != 0:
      return keep_reduce()
    return RegisterLayout(
        fa.WGStridedFragLayout(reduced_shape, layout.vec_size)
    )

  if isinstance(layout, fa.WGSplatFragLayout):
    reduced_shape = utils.reduce_shape(layout.shape, expr.axes, expr.keep_dims)
    return RegisterLayout(fa.WGSplatFragLayout(reduced_shape))

  return keep_reduce()
|
|
|
|
|
|
def reduce_expression(
    expr: Expression, assignments: dict[Variable, Constant]
) -> Expression | Unsatisfiable:
  """Reduces an expression as much as is possible given a set of known variable assignments.

  Constants are returned unchanged, variables are replaced by their assigned
  constant when one is known, and operator nodes are dispatched to their
  dedicated reduction function.
  """
  if isinstance(expr, Constant):
    return expr
  if isinstance(expr, Variable):
    # Unassigned variables are left as they are.
    return assignments.get(expr, expr)
  if isinstance(expr, Reduce):
    return reduce_reduce_expression(expr, assignments)
  if isinstance(expr, Reshape):
    return reduce_reshape_expression(expr, assignments)
  if isinstance(expr, Transpose):
    return reduce_transpose_expression(expr, assignments)
  assert_never(expr)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Equals:
  """Constraint requiring `lhs` and `rhs` to be equal."""
  lhs: Expression
  rhs: Expression

  def holds(self) -> bool | None:
    """Returns whether the equality holds, or `None` if it can't be decided.

    Syntactically equal sides always hold. Two distinct constants never do.
    Anything else cannot be decided yet.
    """
    left, right = self.lhs, self.rhs
    if left == right:
      return True
    both_known = isinstance(left, Constant) and isinstance(right, Constant)
    return False if both_known else None

  def __str__(self):
    return f"Equals({self.lhs} == {self.rhs})"
|
|
|
|
|
|
def _is_supported_tiled_relayout(
|
|
src: fa.TiledLayout, dst: fa.TiledLayout, bitwidth: int
|
|
) -> bool:
|
|
"""Returns whether the source->target relayout is supported for values of types with the given bitwidth."""
|
|
match src, dst:
|
|
# Transposed layouts.
|
|
case fa.WGMMA_LAYOUT, fa.WGMMA_TRANSPOSED_LAYOUT:
|
|
return True
|
|
case fa.WGMMA_TRANSPOSED_LAYOUT, fa.WGMMA_LAYOUT:
|
|
return True
|
|
case fa.TCGEN05_LAYOUT, fa.TCGEN05_TRANSPOSED_LAYOUT:
|
|
return True
|
|
case fa.TCGEN05_TRANSPOSED_LAYOUT, fa.TCGEN05_LAYOUT:
|
|
return True
|
|
# "Conversion-optimized" layouts.
|
|
case fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT:
|
|
return fa.can_relayout_wgmma_2x_to_wgmma(bitwidth)
|
|
case fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X:
|
|
return fa.can_relayout_wgmma_4x_to_wgmma_2x(bitwidth)
|
|
case fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT:
|
|
return fa.can_relayout_wgmma_4x_to_wgmma_2x(
|
|
bitwidth
|
|
) and fa.can_relayout_wgmma_2x_to_wgmma(bitwidth)
|
|
if src == fa.tmem_native_layout(
|
|
src.vector_length
|
|
) and dst == fa.tmem_native_layout(dst.vector_length):
|
|
return True
|
|
return False
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Relayout:
  """Constraint requiring `source` to be relayout-able to `target`.

  Relayout-ability is not a fundamental property of layouts here; it mirrors
  what the implementation supports. In particular, a relayout that exists in
  principle but that we never plan to support is reported as not holding.
  Modeling things this way lets the solver prune inefficient solutions.

  `bitwidth` is the bitwidth of the element type associated with the
  constraint, since some relayouts are only supported for specific bitwidths.

  When `strict` is `True`, only relayouts from splat layouts are allowed, and
  layout equality is forced otherwise.
  """

  source: Expression
  target: Expression
  bitwidth: int
  strict: bool = False

  def canonicalize(self) -> Constraint:
    """Returns a simpler equivalent constraint when one is known.

    The only valid strict relayout from a known tiled or strided layout is the
    identity, in which case the constraint is rewritten into an `Equals`.
    Otherwise, the constraint is returned unchanged.
    """
    known_source = self.source
    if (
        self.strict is True
        and isinstance(known_source, RegisterLayout)
        and isinstance(
            known_source.value, (fa.TiledLayout, fa.WGStridedFragLayout)
        )
    ):
      return Equals(lhs=known_source, rhs=self.target)
    return self

  def holds(self) -> bool | None:
    """Returns whether the relayout constraint holds.

    Returns `None` if the constraint can't be checked.
    """
    src_expr, dst_expr = self.source, self.target

    # Fast path for syntactically identical expressions.
    if src_expr == dst_expr:
      return True

    both_known = isinstance(src_expr, RegisterLayout) and isinstance(
        dst_expr, RegisterLayout
    )
    if not both_known:
      return None

    src_layout, dst_layout = src_expr.value, dst_expr.value
    if isinstance(src_layout, fa.WGSplatFragLayout):
      if isinstance(dst_layout, fa.WGStridedFragLayout):
        return src_layout.shape == dst_layout.shape
      if isinstance(dst_layout, fa.TiledLayout):
        return layouts_lib.splat_is_compatible_with_tiled(
            src_layout, dst_layout
        )
      return False

    tiled_to_tiled = isinstance(src_layout, fa.TiledLayout) and isinstance(
        dst_layout, fa.TiledLayout
    )
    if tiled_to_tiled and not self.strict:
      return _is_supported_tiled_relayout(
          src_layout, dst_layout, self.bitwidth
      )
    return False

  def __str__(self):
    return f"Relayout({self.source} ⟶ {self.target})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class IsTransferable(abc.ABC):
  """States that `source` layout must be transferable across memory spaces to `target` layout.

  This is the base class of the memory-space-specific transfer constraints;
  subclasses are expected to override `holds`.
  """
  # Expression for the layout on the source side of the transfer.
  source: Expression
  # Expression for the layout on the target side of the transfer.
  target: Expression
  # Shape associated with the transfer.
  # NOTE(review): presumably the shape of the array being transferred — confirm
  # against the code constructing these constraints.
  shape: tuple[int, ...]

  def holds(self) -> bool | None:
    """Returns whether the constraint holds.

    Returns `None` if the constraint can't be checked.
    """
    # The base class carries no memory-space-specific knowledge.
    raise NotImplementedError("Holds must be implemented by subclasses.")
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class IsTransferableTmemRegisters(IsTransferable):
  """Constraint requiring `source` to be transferable to `target` between TMEM and registers.

  Exactly one of `source` and `target` must be a TMEM layout, and the other
  must be a register layout. `bitwidth` is the bitwidth of the element type.
  """
  bitwidth: int

  def __post_init__(self):
    # Only 2D shapes and element types of at most 32 bits are handled.
    assert len(self.shape) == 2
    assert 0 < self.bitwidth <= 32

  def is_valid_tmem_transfer(
      self, tmem_layout: tcgen05.TMEMLayout, reg_layout: fa.FragmentedLayout
  ) -> bool:
    """Returns whether data can move between `tmem_layout` and `reg_layout`.

    The supported (register layout, TMEM layout) pairings are checked in
    order, and checking stops at the first one that matches.
    """
    if not isinstance(reg_layout, fa.TiledLayout):
      return False
    packing = tmem_layout.vector_length
    columns = self.shape[1]
    bitwidth = self.bitwidth
    return bool(
        (
            reg_layout == fa.TCGEN05_LAYOUT
            and tmem_layout == tcgen05.tmem_default_layout(packing)
        )
        or (
            reg_layout == tmem_layout.as_tiled_layout()
            and packing * bitwidth == 32
        )
        or (
            reg_layout == fa.TMEM_NATIVE_LAYOUT
            and tmem_layout == tcgen05.tmem_default_layout(packing)
            and ((bitwidth == 16 and packing == 1) or bitwidth == 32)
        )
        or (
            reg_layout == fa.WGMMA_LAYOUT
            and tmem_layout == tcgen05.tmem_half_lane_layout(columns, packing)
        )
        or (
            reg_layout == tcgen05.fa_m64_collective_layout(columns)
            and tmem_layout
            == tcgen05.tmem_m64_collective_layout(columns, packing)
        )
    )

  def holds(self) -> bool | None:
    """Returns whether the transfer is supported, or `None` if undecidable.

    Raises:
      ValueError: if both sides are known constants but do not form a
        TMEM <-> registers pair.
    """
    src_expr, dst_expr = self.source, self.target
    if isinstance(src_expr, RegisterLayout) and isinstance(dst_expr, TMEMLayout):
      return self.is_valid_tmem_transfer(dst_expr.value, src_expr.value)
    if isinstance(src_expr, TMEMLayout) and isinstance(dst_expr, RegisterLayout):
      return self.is_valid_tmem_transfer(src_expr.value, dst_expr.value)
    if isinstance(src_expr, Constant) and isinstance(dst_expr, Constant):
      raise ValueError(
          f"{self.source} -> {self.target} is not a TMEM <-> Registers"
          " transfer."
      )
    # At least one side is not known yet.
    return None

  def __str__(self):
    return f"IsTransferableTmemRegisters({self.source} ⟶ {self.target})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class IsTransferableSmemRegisters(IsTransferable):
  """Constraint requiring `source` to be transferable to `target` between SMEM and registers.

  Exactly one of `source` and `target` must be an SMEM tiling, and the other
  must be a register layout.
  """
  strides: tuple[int, ...]

  def _is_supported_smem_transfer(
      self,
      smem_layout: lc.TileTransform | None,
      reg_layout: fa.FragmentedLayout,
  ) -> bool:
    """Returns whether `reg_layout` can be transferred to/from `smem_layout`."""
    # TODO(b/447079781): This is way too restrictive. We need to make it more
    # precise by:
    # - Consider whether the op is annotated with optimized copies or not.
    # - If copies do not have to be optimized, always return True.
    # - If copies have to be optimized, determine if the transfer is optimal by
    #   calling fragmented_array.plan_tiled_transfer.
    if not inference_utils.is_mma_layout(reg_layout):
      # Non-MMA layouts are only accepted with untiled SMEM.
      return smem_layout is None
    if smem_layout is None or len(smem_layout.tiling) != 2:
      return False

    transposed_layouts = {
        fa.TCGEN05_TRANSPOSED_LAYOUT,
        fa.WGMMA_TRANSPOSED_LAYOUT,
    }
    # The two minor strides are in non-increasing order exactly when sorting
    # them in descending order leaves them unchanged.
    minor_strides = list(self.strides[-2:])
    strides_descending = minor_strides == sorted(minor_strides, reverse=True)
    # Transposed register layouts pair with non-descending strides, and
    # non-transposed layouts pair with descending strides.
    if strides_descending:
      return reg_layout not in transposed_layouts
    return reg_layout in transposed_layouts

  def holds(self) -> bool | None:
    """Returns whether the transfer is supported, or `None` if undecidable.

    Raises:
      ValueError: if both sides are known constants but do not form an
        SMEM <-> registers pair.
    """
    src_expr, dst_expr = self.source, self.target
    if isinstance(src_expr, SMEMTiling) and isinstance(dst_expr, RegisterLayout):
      return self._is_supported_smem_transfer(src_expr.value, dst_expr.value)
    if isinstance(src_expr, RegisterLayout) and isinstance(dst_expr, SMEMTiling):
      return self._is_supported_smem_transfer(dst_expr.value, src_expr.value)
    if isinstance(src_expr, Constant) and isinstance(dst_expr, Constant):
      raise ValueError(
          f"{self.source} -> {self.target} is not a SMEM <-> Registers"
          " transfer."
      )
    # At least one side is not known yet.
    return None

  def __str__(self):
    return f"IsTransferableSmemRegisters({self.source} ⟶ {self.target})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class NotOfType:
  """Constraint requiring that `expr` is not a layout of type `type`."""

  expr: Expression
  type: type[fa.FragmentedLayout]

  def holds(self) -> bool | None:
    """Returns whether the distinctiveness constraint holds.

    A known register layout holds iff its value is not an instance of `type`.
    Any other known constant holds trivially. Returns `None` if `expr` is not
    a constant yet, i.e. if the constraint can't be checked.
    """
    candidate = self.expr
    if isinstance(candidate, RegisterLayout):
      return not isinstance(candidate.value, self.type)
    if isinstance(candidate, Constant):
      return True
    return None

  def __str__(self):
    return f"type({self.expr}) ≠ {self.type.__name__}"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Divides:
  """Constraint requiring the `expr` tiling to be a divisor of `tiling_multiple`.

  Dimensions are matched starting from the tail: each tiled dimension of
  `expr` must divide the corresponding dimension of `tiling_multiple`.

  Leading dimensions of `tiling_multiple` that have no counterpart in `expr`
  are ignored.

  `expr` may not have more dimensions than `tiling_multiple`, so this
  constraint also bounds the rank of `expr`.
  """
  expr: Expression
  tiling_multiple: tuple[int, ...]

  def holds(self) -> bool | None:
    """Returns whether the constraint holds, or `None` if it can't be checked."""
    candidate = self.expr
    if isinstance(candidate, SMEMTiling):
      if candidate.value is None:
        # If there is no tiling, then this holds trivially.
        return True
      if not isinstance(candidate.value, lc.TileTransform):
        return None
      tiling = candidate.value.tiling
    elif isinstance(candidate, RegisterLayout):
      if not isinstance(candidate.value, fa.TiledLayout):
        return None
      tiling = candidate.value.base_tile_shape
    elif isinstance(candidate, TMEMLayout):
      tiling = candidate.value.base_tile_shape
    else:
      return None

    multiples = self.tiling_multiple
    if len(tiling) > len(multiples):
      # The rank of the tiling is larger than the rank of the constraint. This
      # is not allowed.
      return False

    # Compare trailing dimensions pairwise; `zip` stops at the shorter tiling.
    aligned = zip(reversed(tiling), reversed(multiples))
    return all(not (multiple % size) for size, multiple in aligned)

  def __str__(self):
    return f"{self.tiling_multiple} % {self.expr} == 0"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class IsValidMmaTiling:
  """Constraint requiring the `expr` SMEM tiling to be compatible with MMA requirements.

  For both tcgen05.mma and wgmma, a tiling is valid if it is of the form
    (8, swizzle_elems), with
    swizzle_elems in {s * 8 // dtype_bitwidth for s in [32, 64, 128]},
  since unswizzled tilings are not accepted by default.

  If `allow_unswizzled` is True, then
    (8, 16 * 8 // dtype_bitwidth)
  is additionally accepted as a valid tiling.
  """
  expr: Expression
  bitwidth: int
  allow_unswizzled: bool = False

  def holds(self) -> bool | None:
    """Returns whether the tiling is valid, or `None` if it can't be checked.

    Raises:
      ValueError: if `expr` is a known register or TMEM layout, which is not
        meaningful for this constraint.
    """
    candidate = self.expr
    if isinstance(candidate, SMEMTiling):
      transform = candidate.value
      if transform is None:
        # An untiled reference is never a valid MMA operand tiling.
        return False
      if not isinstance(transform, lc.TileTransform):
        return None
      swizzles = [32, 64, 128]
      if self.allow_unswizzled:
        swizzles = [16, *swizzles]
      valid_tilings = {(8, s * 8 // self.bitwidth) for s in swizzles}
      return transform.tiling in valid_tilings
    if isinstance(candidate, (RegisterLayout, TMEMLayout)):
      raise ValueError(
          f"Unexpected value {candidate} in IsValidMmaTiling constraint"
      )
    return None

  def __str__(self):
    return f"IsValidMMATiling({self.expr}, {self.bitwidth}, allow_unswizzled={self.allow_unswizzled})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class IsSupportedBroadcast:
  """Constraint requiring that `src` can be broadcasted to `dst`.

  See `FragmentedArray.broadcast_in_dim` for more details.
  """

  src: Expression
  dst: Expression
  dims: tuple[int, ...]

  def holds(self) -> bool | None:
    """Returns whether the broadcast is supported, or `None` if undecidable.

    Raises:
      ValueError: if both sides are known constants but are not both register
        layouts.
    """
    src, dst = self.src, self.dst
    if isinstance(src, RegisterLayout) and isinstance(dst, RegisterLayout):
      src_layout, dst_layout = src.value, dst.value
      both_strided = isinstance(
          src_layout, fa.WGStridedFragLayout
      ) and isinstance(dst_layout, fa.WGStridedFragLayout)
      if both_strided:
        return fa.is_supported_strided_layout_broadcast(
            src_layout, dst_layout, self.dims
        )
      # This is an intentionally loose check. We rely on the presence of a
      # `src = Reduce(dst)` constraint to enforce correctness.
      return type(src_layout) == type(dst_layout)
    if isinstance(src, Constant) and isinstance(dst, Constant):
      raise ValueError(
          f"Unexpected values {src=} {dst=} in IsSupportedBroadcast"
          " constraint"
      )
    # At least one side is not known yet.
    return None

  def __str__(self):
    return (
        f"IsSupportedBroadcast(src={self.src}, dst={self.dst},"
        f" dims={self.dims})"
    )
|
|
|
|
|
|
# The union of all constraint kinds that may appear in a `ConstraintSystem`.
# `IsTransferable` stands for all of its memory-space-specific subclasses.
Constraint = (
    Equals
    | Relayout
    | NotOfType
    | IsTransferable
    | IsValidMmaTiling
    | Divides
    | IsSupportedBroadcast
)
|
|
|
|
|
|
def reduce_constraint(
    constraint: Constraint, assignments: dict[Variable, Constant]
) -> Constraint | Unsatisfiable:
  """Reduces a constraint.

  Every expression held by the constraint is reduced using `assignments`, and
  a constraint of the same kind is rebuilt from the reduced expressions
  (`Relayout` constraints are additionally canonicalized). If any expression
  reduces to `Unsatisfiable`, so does the constraint.
  """
  if isinstance(constraint, Equals):
    # The right-hand side is only reduced if the left-hand side is satisfiable.
    new_lhs = reduce_expression(constraint.lhs, assignments)
    if isinstance(new_lhs, Unsatisfiable):
      return Unsatisfiable()
    new_rhs = reduce_expression(constraint.rhs, assignments)
    if isinstance(new_rhs, Unsatisfiable):
      return Unsatisfiable()
    return Equals(new_lhs, new_rhs)

  if isinstance(constraint, Relayout):
    new_source = reduce_expression(constraint.source, assignments)
    new_target = reduce_expression(constraint.target, assignments)
    if isinstance(new_source, Unsatisfiable) or isinstance(
        new_target, Unsatisfiable
    ):
      return Unsatisfiable()
    return dataclasses.replace(
        constraint, source=new_source, target=new_target
    ).canonicalize()

  if isinstance(constraint, NotOfType):
    new_expr = reduce_expression(constraint.expr, assignments)
    if isinstance(new_expr, Unsatisfiable):
      return Unsatisfiable()
    return NotOfType(new_expr, constraint.type)

  if isinstance(constraint, IsTransferable):
    new_source = reduce_expression(constraint.source, assignments)
    new_target = reduce_expression(constraint.target, assignments)
    if isinstance(new_source, Unsatisfiable) or isinstance(
        new_target, Unsatisfiable
    ):
      return Unsatisfiable()
    # `replace` preserves the concrete subclass and its extra fields.
    return dataclasses.replace(
        constraint, source=new_source, target=new_target
    )

  if isinstance(constraint, IsValidMmaTiling):
    new_expr = reduce_expression(constraint.expr, assignments)
    if isinstance(new_expr, Unsatisfiable):
      return Unsatisfiable()
    return dataclasses.replace(constraint, expr=new_expr)

  if isinstance(constraint, Divides):
    new_expr = reduce_expression(constraint.expr, assignments)
    if isinstance(new_expr, Unsatisfiable):
      return Unsatisfiable()
    return Divides(new_expr, constraint.tiling_multiple)

  if isinstance(constraint, IsSupportedBroadcast):
    new_src = reduce_expression(constraint.src, assignments)
    new_dst = reduce_expression(constraint.dst, assignments)
    if isinstance(new_src, Unsatisfiable) or isinstance(
        new_dst, Unsatisfiable
    ):
      return Unsatisfiable()
    return IsSupportedBroadcast(new_src, new_dst, constraint.dims)

  assert_never(constraint)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ConstraintSystem:
|
|
"""A constraint system contains a set of constraints and assignments.
|
|
|
|
Assignments assign constant values to variables in the system (bound
|
|
variables). Constraints describe relationships between variables that must be
|
|
upheld, and can be used to determine assignments for unknown (free) variables.
|
|
"""
|
|
|
|
assignments: dict[Variable, Constant] = dataclasses.field(
|
|
default_factory=dict
|
|
)
|
|
constraints: Sequence[Constraint] = dataclasses.field(default_factory=list)
|
|
|
|
def unknowns(self) -> list[Variable]:
|
|
"""Returns the list of free variables in the system."""
|
|
seen_variables: set[Variable] = set()
|
|
free_variables: list[Variable] = []
|
|
def extract_variables(expr: Expression) -> None:
|
|
match expr:
|
|
case Variable():
|
|
if expr not in seen_variables and expr not in self.assignments:
|
|
seen_variables.add(expr)
|
|
free_variables.append(expr)
|
|
case Constant():
|
|
...
|
|
case Reduce(expression=e):
|
|
extract_variables(e)
|
|
case Reshape(expression=e):
|
|
extract_variables(e)
|
|
case Transpose(expression=e):
|
|
extract_variables(e)
|
|
case _:
|
|
assert_never(never)
|
|
for constraint in self.constraints:
|
|
match constraint:
|
|
case Equals(lhs=lhs, rhs=rhs):
|
|
extract_variables(lhs)
|
|
extract_variables(rhs)
|
|
case Relayout(source=source, target=target):
|
|
extract_variables(source)
|
|
extract_variables(target)
|
|
case NotOfType(expr=expr):
|
|
extract_variables(expr)
|
|
case IsTransferable(source=source, target=target):
|
|
extract_variables(source)
|
|
extract_variables(target)
|
|
case IsValidMmaTiling(expr=expr):
|
|
extract_variables(expr)
|
|
case Divides(expr=expr):
|
|
extract_variables(expr)
|
|
case IsSupportedBroadcast(src=src, dst=dst):
|
|
extract_variables(src)
|
|
extract_variables(dst)
|
|
case _ as never:
|
|
assert_never(never)
|
|
return free_variables
|
|
|
|
def __and__(
|
|
self, other: ConstraintSystem | Unsatisfiable
|
|
) -> ConstraintSystem | Unsatisfiable:
|
|
if isinstance(other, Unsatisfiable):
|
|
return Unsatisfiable()
|
|
for variable, assignment in self.assignments.items():
|
|
if variable in other.assignments and assignment != other.assignments[variable]:
|
|
return Unsatisfiable()
|
|
return ConstraintSystem(
|
|
assignments=self.assignments | other.assignments,
|
|
constraints=[*self.constraints, *other.constraints],
|
|
)
|
|
|
|
def __str__(self):
|
|
r = "ConstraintSystem\n"
|
|
r += " assignments:\n"
|
|
for assignment, constant in self.assignments.items():
|
|
r += f" {assignment} ⟵ {constant}\n"
|
|
r += " constraints:\n"
|
|
for constraint in self.constraints:
|
|
r += f" {constraint}\n"
|
|
return r
|
|
|
|
|
|
@final
class Unsatisfiable:
  """Marker for an unsatisfiable expression, constraint, or constraint system.

  No `__eq__` is defined, so instances compare by identity; code in this module
  detects unsatisfiability with `isinstance` checks.
  """

  def __and__(self, other: ConstraintSystem | Unsatisfiable) -> Unsatisfiable:
    # Unsatisfiability is absorbing: combining with anything stays
    # unsatisfiable.
    return self
|
|
|
|
|
|
def non_splat_variables(
    constraints: Sequence[Constraint],
) -> set[Variable]:
  """Returns all the variables that are known to be distinct from a splat.

  These are the variables `v` for which a `NotOfType(v, WGSplatFragLayout)`
  constraint is present in `constraints`.
  """
  return {
      constraint.expr
      for constraint in constraints
      if isinstance(constraint, NotOfType)
      and isinstance(constraint.expr, Variable)
      and constraint.type == fa.WGSplatFragLayout
  }
|
|
|
|
|
|
def _has_relayout_of_non_splat_to_splat(constraints: Sequence[Constraint]) -> bool:
  """Returns whether the constraints imply a non-splat to splat relayout.

  Such relayouts are impossible, so detecting one lets the search stop early.

  A `False` result does not guarantee that there is no non-splat to splat
  relayout; it only means that none is known yet.
  """
  known_non_splat = non_splat_variables(constraints)
  if not known_non_splat:
    return False

  # Look for a relayout whose source is known non-splat and whose target is a
  # constant splat register layout.
  return any(
      isinstance(constraint, Relayout)
      and constraint.source in known_non_splat
      and isinstance(constraint.target, RegisterLayout)
      and isinstance(constraint.target.value, fa.WGSplatFragLayout)
      for constraint in constraints
  )
|
|
|
|
|
|
def saturate_distinct_from_splat(
    constraint_system: ConstraintSystem,
) -> ConstraintSystem | Unsatisfiable:
  """Adds transitive NotOfType constraints for all non-splat variables.

  Given `n` variables `l0`, ... `l{n-1}`, and a set of relayouts
  `{ Relayout(l{i}, l{i+1}) : 0 <= i < n-1 }`, if we also know that
  `l{0}` is not splat, then we can automatically deduce that none of
  `l0`, ..., `l{n-1}` are splat either.

  This helps us quickly conclude that a system is unsatisfiable in cases where
  a non-splat variable is transitively relaid out into a splat layout.
  """
  known_non_splat = non_splat_variables(constraint_system.constraints)
  derived: list[Constraint] = []

  # Iterate to a fixed point: each pass propagates "non-splat" across the
  # relayouts, and we stop once a full pass discovers nothing new.
  progress = bool(known_non_splat)
  while progress:
    progress = False
    for constraint in constraint_system.constraints:
      if not isinstance(constraint, Relayout):
        continue
      source, target = constraint.source, constraint.target
      if not isinstance(target, Variable):
        continue
      if source not in known_non_splat or target in known_non_splat:
        continue
      progress = True
      known_non_splat.add(target)
      derived.append(NotOfType(target, fa.WGSplatFragLayout))

  return constraint_system & ConstraintSystem(constraints=derived)
|
|
|
|
|
|
def compute_transitively_equal_vars(
|
|
system: ConstraintSystem,
|
|
) -> dict[Variable, list[Variable]]:
|
|
"""Computes all transitively equal variables in a constraint system.
|
|
|
|
The output dictionary maps each variable that appears in constraints in the
|
|
constraint system to all the variables it is transitively equal to.
|
|
"""
|
|
# The equality relations between variables form a graph where variables are
|
|
# nodes and a constraint `v1 == v2` forms an edge. All variables in a
|
|
# connected component are transitively equal. We use a Union-Find data
|
|
# structure with path compression to efficiently find these connected
|
|
# components (i.e., equivalence classes).
|
|
parent: dict[Variable, Variable] = {}
|
|
def find(v: Variable) -> Variable:
|
|
if v not in parent:
|
|
parent[v] = v
|
|
if parent[v] != v:
|
|
parent[v] = find(parent[v])
|
|
return parent[v]
|
|
|
|
def union(v1: Variable, v2: Variable):
|
|
root1 = find(v1)
|
|
root2 = find(v2)
|
|
if root1 != root2:
|
|
parent[root2] = root1
|
|
|
|
all_vars: set[Variable] = set()
|
|
for constraint in system.constraints:
|
|
match constraint:
|
|
case Equals(lhs=Variable() as lhs, rhs=Variable() as rhs):
|
|
assert isinstance(lhs, Variable) # make pytype happy
|
|
assert isinstance(rhs, Variable) # make pytype happy
|
|
all_vars.add(lhs)
|
|
all_vars.add(rhs)
|
|
union(lhs, rhs)
|
|
|
|
# Group variables by their component representative.
|
|
components: dict[Variable, list[Variable]] = {}
|
|
for v in sorted(all_vars, key=str):
|
|
root = find(v)
|
|
components.setdefault(root, []).append(v)
|
|
|
|
equal_vars: dict[Variable, list[Variable]] = {}
|
|
for component_vars in components.values():
|
|
for v in component_vars:
|
|
equal_vars[v] = [other for other in component_vars if other != v]
|
|
|
|
return equal_vars
|
|
|
|
|
|
def saturate_divides_constraints_for_equal_vars(
    system: ConstraintSystem,
) -> ConstraintSystem:
  """Saturates Divides constraints between all transitively equal vars.

  Every `Divides` constraint on a variable is copied to each variable it is
  transitively equal to, after which the `Divides` constraints applying to the
  same variable are merged.
  """
  equal_vars = compute_transitively_equal_vars(system)
  saturated: list[Constraint] = []
  for constraint in system.constraints:
    saturated.append(constraint)
    if not isinstance(constraint, Divides):
      continue
    if not isinstance(constraint.expr, Variable):
      continue
    saturated.extend(
        Divides(equal_var, constraint.tiling_multiple)
        for equal_var in equal_vars.get(constraint.expr, [])
    )
  merged = _merge_all_divides_constraints(saturated)
  return dataclasses.replace(system, constraints=merged)
|
|
|
|
|
|
def _merge_all_divides_constraints(constraints: Sequence[Constraint]) -> list[Constraint]:
  """Merges `Divides` constraints that apply to the same variable.

  Args:
    constraints: the constraints to process.

  Returns:
    A new list holding, in their original relative order, every constraint
    that is not a `Divides` over a `Variable`, followed by one merged `Divides`
    constraint per variable (in order of first appearance of the variable).
  """
  untouched: list[Constraint] = []
  merged_by_var: dict[Variable, Divides] = {}
  for candidate in constraints:
    if isinstance(candidate, Divides) and isinstance(candidate.expr, Variable):
      var = candidate.expr
      previous = merged_by_var.get(var)
      merged_by_var[var] = (
          candidate
          if previous is None
          else merge_divides_constraints(previous, candidate)
      )
    else:
      untouched.append(candidate)
  return [*untouched, *merged_by_var.values()]
|
|
|
|
|
|
def merge_divides_constraints(d0: Divides, d1: Divides) -> Divides:
  """Combines two `Divides` constraints over the same expression into one.

  Args:
    d0: the first constraint.
    d1: the second constraint; must apply to the same expression as `d0`.

  Returns:
    A `Divides` constraint on the shared expression whose tiling multiple is
    the element-wise GCD of the trailing entries of both tiling multiples.

  Raises:
    ValueError: if `d0` and `d1` apply to different expressions.
  """
  if d0.expr != d1.expr:
    raise ValueError("Divides constraints must apply to the same expression.")
  multiples0, multiples1 = d0.tiling_multiple, d1.tiling_multiple
  # When the tuples have different lengths, only the trailing entries of the
  # longer one are kept, so that both have the length of the shorter one. This
  # preserves the semantics of the Divides constraints where a tiling's rank
  # cannot exceed the size of tiling_multiple.
  common_len = min(len(multiples0), len(multiples1))
  # Slicing from an explicit start index (rather than with `[-common_len:]`)
  # also yields an empty tuple when `common_len` is 0.
  tail0 = multiples0[len(multiples0) - common_len:]
  tail1 = multiples1[len(multiples1) - common_len:]
  merged = tuple(math.gcd(m0, m1) for m0, m1 in zip(tail0, tail1, strict=True))
  return Divides(d0.expr, merged)
|
|
|
|
|
|
def _reduce_system_once(
|
|
constraint_system: ConstraintSystem,
|
|
) -> ConstraintSystem | Unsatisfiable | None:
|
|
"""Performs one reduction step over each constraint in a constraint system.
|
|
|
|
Returns:
|
|
- Unsatisfiable(): if the constraint system is unsatisfiable.
|
|
- A new constraint system if any constraint was reduced.
|
|
- None: if the constraint system is not known unsatisfiable, but hasn't been
|
|
reduced.
|
|
"""
|
|
assignments = constraint_system.assignments
|
|
constraints: list[Constraint] = []
|
|
changed = False
|
|
|
|
def try_assign(var: Variable, cst: Constant) -> bool:
|
|
if var in assignments and assignments[var] != cst:
|
|
return False
|
|
assignments[var] = cst
|
|
return True
|
|
|
|
for constraint in constraint_system.constraints:
|
|
match reduce_constraint(constraint, assignments):
|
|
case Unsatisfiable():
|
|
return Unsatisfiable()
|
|
case Equals(lhs=Variable() as var, rhs=Constant() as cst):
|
|
if not try_assign(var, cst):
|
|
return Unsatisfiable()
|
|
changed = True
|
|
case Equals(lhs=Constant() as cst, rhs=Variable() as var):
|
|
if not try_assign(var, cst):
|
|
return Unsatisfiable()
|
|
changed = True
|
|
case _ as new_constraint:
|
|
assert isinstance(new_constraint, Constraint) # make pytype happy
|
|
match new_constraint.holds():
|
|
case None:
|
|
constraints.append(new_constraint)
|
|
changed |= new_constraint != constraint
|
|
case False:
|
|
return Unsatisfiable()
|
|
case True:
|
|
changed = True
|
|
|
|
new_constraints = _merge_all_divides_constraints(constraints)
|
|
changed |= len(new_constraints) != len(constraints)
|
|
constraints = new_constraints
|
|
|
|
# Shortcut for a specific case of unsatisfiability. This shortcut
|
|
# drastically reduces the size of the search space.
|
|
if _has_relayout_of_non_splat_to_splat(constraints):
|
|
return Unsatisfiable()
|
|
|
|
if changed:
|
|
return ConstraintSystem(
|
|
assignments=assignments | constraint_system.assignments,
|
|
constraints=constraints,
|
|
)
|
|
return None
|
|
|
|
|
|
def reduce(
|
|
constraint_system: ConstraintSystem,
|
|
) -> ConstraintSystem | Unsatisfiable:
|
|
"""Reduces a constraint system until it can no longer be reduced.
|
|
|
|
Returns:
|
|
- Unsatisfiable(): if the constraint system is unsatisfiable.
|
|
- The maximally reduced constraint system otherwise.
|
|
"""
|
|
while True:
|
|
match _reduce_system_once(constraint_system):
|
|
case None:
|
|
break
|
|
case Unsatisfiable():
|
|
return Unsatisfiable()
|
|
case ConstraintSystem() as new_system:
|
|
constraint_system = new_system
|
|
case _ as never:
|
|
assert_never(never) # pyrefly: ignore[bad-argument-type] # pyrefly#2858
|
|
|
|
return constraint_system
|