186 lines
6.5 KiB
Python
186 lines
6.5 KiB
Python
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Layout utilities."""
|
|
|
|
from typing import assert_never
|
|
|
|
from jax._src.lib import mosaic_gpu_dialect as mgpu
|
|
from jax._src.lib.mlir import ir
|
|
|
|
from . import fragmented_array as fa
|
|
from . import launch_context
|
|
|
|
|
|
# TODO(b/415721295): Refine return type once minimum jaxlib version is 0.10.0
|
|
def _to_splat_fragmented_layout_attr(
    layout: fa.WGSplatFragLayout,
) -> ir.Attribute:
  """Builds the #mosaic_gpu.WGSplatFragLayout attribute for `layout`.

  The layout's shape is encoded as a dense i64 array attribute, which is then
  wrapped in a `mgpu.WGSplatFragLayoutAttr`.
  """
  return mgpu.WGSplatFragLayoutAttr.get(
      ir.DenseI64ArrayAttr.get(layout.shape)
  )
|
|
|
|
|
|
def _from_splat_fragmented_layout_attr(
    attr: ir.Attribute,
) -> fa.WGSplatFragLayout:
  """Recovers a WGSplatFragLayout from a #mosaic_gpu.WGSplatFragLayout attribute.

  `attr` must be a `mgpu.WGSplatFragLayoutAttr`; this is checked with an
  `assert`, so the check disappears when Python runs with `-O`.
  """
  # TODO(b/415721295): Refine arg type once minimum jaxlib version is 0.10.0
  assert isinstance(attr, mgpu.WGSplatFragLayoutAttr)
  splat_shape = tuple(attr.shape)
  return fa.WGSplatFragLayout(shape=splat_shape)
|
|
|
|
|
|
# TODO(b/415721295): Refine return type once minimum jaxlib version is 0.10.0
|
|
def _to_strided_fragmented_layout_attr(
    layout: fa.WGStridedFragLayout,
) -> ir.Attribute:
  """Builds the #mosaic_gpu.WGStridedFragLayout attribute for `layout`.

  The attribute records the layout's shape (as a dense i64 array attribute)
  together with its vector size.
  """
  dims_attr = ir.DenseI64ArrayAttr.get(layout.shape)
  vector_size = layout.vec_size
  return mgpu.WGStridedFragLayoutAttr.get(dims_attr, vector_size)
|
|
|
|
|
|
def _from_strided_fragmented_layout_attr(
    attr: ir.Attribute,
) -> fa.WGStridedFragLayout:
  """Recovers a WGStridedFragLayout from a #mosaic_gpu.WGStridedFragLayout attribute.

  The attribute's `shape` and `vector_size` become the layout's `shape` and
  `vec_size`. `attr` must be a `mgpu.WGStridedFragLayoutAttr`; this is only
  checked with an `assert`.
  """
  # TODO(b/415721295): Refine arg type once minimum jaxlib version is 0.10.0
  assert isinstance(attr, mgpu.WGStridedFragLayoutAttr)
  strided_shape = tuple(attr.shape)
  vector_size = attr.vector_size
  return fa.WGStridedFragLayout(shape=strided_shape, vec_size=vector_size)
|
|
|
|
|
|
# TODO(b/415721295): Refine return type once minimum jaxlib version is 0.10.0
|
|
def _to_tiled_layout_attr(
    layout: fa.TiledLayout,
) -> ir.Attribute:
  """Builds the #mosaic_gpu.TiledLayout attribute describing `layout`.

  Each tile of `layout.tiling` is encoded as an array of i64 integer
  attributes. Each warp/lane dimension is encoded as an i64 integer
  attribute, except `fa.Replicated` entries, which become
  #mosaic_gpu.Replicated attributes carrying the replication count.
  """
  i64_type = ir.IntegerType.get_signless(64)

  def i64_attr(value: int) -> ir.Attribute:
    return ir.IntegerAttr.get(i64_type, value)

  def dim_attr(dim: int | fa.Replicated) -> ir.Attribute:
    # Replicated dimensions have a dedicated attribute kind; any other entry
    # is stored as a plain integer.
    if not isinstance(dim, fa.Replicated):
      return i64_attr(dim)
    return mgpu.ReplicatedAttr.get(dim.times)

  tiling_attr = ir.ArrayAttr.get([
      ir.ArrayAttr.get([i64_attr(size) for size in tile])
      for tile in layout.tiling.tiles
  ])
  warp_dims_attr = ir.ArrayAttr.get([dim_attr(w) for w in layout.warp_dims])
  lane_dims_attr = ir.ArrayAttr.get([dim_attr(l) for l in layout.lane_dims])

  return mgpu.TiledLayoutAttr.get(
      tiling_attr,
      warp_dims_attr,
      lane_dims_attr,
      layout.vector_dim,
  )
|
|
|
|
|
|
def _from_tiled_layout_attr(
    attr: ir.Attribute,
) -> fa.TiledLayout:
  """Recovers a TiledLayout from a #mosaic_gpu.TiledLayout attribute.

  Tiles are decoded from nested arrays of integer attributes. Warp/lane
  dimensions are decoded as plain ints, except #mosaic_gpu.Replicated entries,
  which become `fa.Replicated` values. `attr` must be a
  `mgpu.TiledLayoutAttr`; this is only checked with an `assert`.
  """
  # TODO(allanrenucci): Refine arg type once minimum jaxlib version is 0.10.0
  assert isinstance(attr, mgpu.TiledLayoutAttr)

  def int_value(int_attr: ir.Attribute) -> int:
    return ir.IntegerAttr(int_attr).value

  def decode_dim(dim_attr: ir.Attribute) -> int | fa.Replicated:
    if not isinstance(dim_attr, mgpu.ReplicatedAttr):
      return int_value(dim_attr)
    return fa.Replicated(times=mgpu.ReplicatedAttr(dim_attr).times)

  tile_shapes = tuple(
      tuple(map(int_value, ir.ArrayAttr(tile_attr)))
      for tile_attr in attr.tiling
  )
  decoded_warp_dims = tuple(map(decode_dim, attr.warp_dims))
  decoded_lane_dims = tuple(map(decode_dim, attr.lane_dims))

  return fa.TiledLayout(
      tiling=fa.Tiling(tile_shapes),
      warp_dims=decoded_warp_dims,
      lane_dims=decoded_lane_dims,
      vector_dim=attr.vector_dim,
  )
|
|
|
|
|
|
def to_layout_attr(layout: fa.FragmentedLayout) -> ir.Attribute:
  """Converts a fragmented layout into the matching MLIR layout attribute.

  Splat, strided and tiled layouts are supported. Any other value reaches
  `typing.assert_never`, which raises at runtime.
  """
  if isinstance(layout, fa.WGSplatFragLayout):
    return _to_splat_fragmented_layout_attr(layout)
  if isinstance(layout, fa.WGStridedFragLayout):
    return _to_strided_fragmented_layout_attr(layout)
  if isinstance(layout, fa.TiledLayout):
    return _to_tiled_layout_attr(layout)
  assert_never(layout)
|
|
|
|
|
|
def from_layout_attr(attr: ir.Attribute) -> fa.FragmentedLayout:
  """Converts a Mosaic GPU layout attribute back into a fragmented layout.

  Raises:
    NotImplementedError: if `attr` is not a splat, strided or tiled layout
      attribute.
  """
  if isinstance(attr, mgpu.WGSplatFragLayoutAttr):
    return _from_splat_fragmented_layout_attr(attr)
  if isinstance(attr, mgpu.WGStridedFragLayoutAttr):
    return _from_strided_fragmented_layout_attr(attr)
  if isinstance(attr, mgpu.TiledLayoutAttr):
    return _from_tiled_layout_attr(attr)
  raise NotImplementedError(
      f"Unsupported layout for conversion from MLIR attribute: {attr}"
  )
|
|
|
|
|
|
def splat_is_compatible_with_tiled(
    l1: fa.WGSplatFragLayout, l2: fa.TiledLayout
) -> bool:
  """Returns whether splat layout `l1` is compatible with tiled layout `l2`.

  A splat layout is compatible with a tiled layout up to replication if the
  tiled layout's base tile can be applied to the splat's shape. That requires
  two things:

  * the shape has at least as many dimensions as the base tile; and
  * each trailing dimension of the shape is divisible by the corresponding
    base tile dimension.
  """
  shape, base_tile = l1.shape, l2.base_tile_shape
  # Tiles constrain a *suffix* of an array's dimensions (cf. `fa.Tiling`), so
  # the shape and the base tile must be aligned on their trailing dimensions.
  # Zipping from the front pairs the wrong dimensions when ranks differ, and
  # silently ignores unmatched tile dimensions when the shape is shorter.
  if len(base_tile) > len(shape):
    return False
  tiled_dims = shape[len(shape) - len(base_tile):]
  return all(d % t == 0 for d, t in zip(tiled_dims, base_tile))
|
|
|
|
|
|
def to_transform_attr(
|
|
transform: launch_context.MemRefTransform | mgpu.SwizzlingMode,
|
|
) -> ir.Attribute:
|
|
if isinstance(transform, launch_context.TileTransform):
|
|
return mgpu.TileTransformAttr.get(transform.tiling)
|
|
elif isinstance(transform, launch_context.TransposeTransform):
|
|
return mgpu.TransposeTransformAttr.get(transform.permutation)
|
|
elif isinstance(transform, mgpu.SwizzlingMode):
|
|
return mgpu.SwizzleTransformAttr.get(transform)
|
|
else:
|
|
raise NotImplementedError(f"Unsupported transform {transform}")
|
|
|
|
|
|
def from_transform_attr(
    transform: ir.Attribute,
) -> launch_context.MemRefTransform | mgpu.SwizzlingMode:
  """Converts a Mosaic GPU transform attribute back into its Python form.

  Tile and transpose attributes become `launch_context.TileTransform` and
  `launch_context.TransposeTransform`. A swizzle attribute becomes a
  `mgpu.SwizzlingMode`.

  Raises:
    NotImplementedError: if `transform` is none of these attribute kinds.
  """
  if isinstance(transform, mgpu.TileTransformAttr):
    tiling = mgpu.TileTransformAttr(transform).tiling
    return launch_context.TileTransform(tuple(tiling))
  if isinstance(transform, mgpu.TransposeTransformAttr):
    permutation = mgpu.TransposeTransformAttr(transform).permutation
    return launch_context.TransposeTransform(tuple(permutation))
  if isinstance(transform, mgpu.SwizzleTransformAttr):
    swizzle = mgpu.SwizzleTransformAttr(transform).swizzle
    return mgpu.SwizzlingMode(swizzle)
  raise NotImplementedError(f"Unsupported transform {transform}")
|