hand
This commit is contained in:
@@ -0,0 +1,185 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Layout utilities."""
|
||||
|
||||
from typing import assert_never
|
||||
|
||||
from jax._src.lib import mosaic_gpu_dialect as mgpu
|
||||
from jax._src.lib.mlir import ir
|
||||
|
||||
from . import fragmented_array as fa
|
||||
from . import launch_context
|
||||
|
||||
|
||||
# TODO(b/415721295): Refine return type once minimum jaxlib version is 0.10.0
|
||||
def _to_splat_fragmented_layout_attr(
|
||||
layout: fa.WGSplatFragLayout,
|
||||
) -> ir.Attribute:
|
||||
"""Constructs a #mosaic_gpu.WGSplatFragLayout attribute from a WGSplatFragLayout."""
|
||||
shape = ir.DenseI64ArrayAttr.get(layout.shape)
|
||||
return mgpu.WGSplatFragLayoutAttr.get(shape)
|
||||
|
||||
|
||||
def _from_splat_fragmented_layout_attr(
|
||||
attr: ir.Attribute,
|
||||
) -> fa.WGSplatFragLayout:
|
||||
# TODO(b/415721295): Refine arg type once minimum jaxlib version is 0.10.0
|
||||
assert isinstance(attr, mgpu.WGSplatFragLayoutAttr)
|
||||
return fa.WGSplatFragLayout(shape=tuple(attr.shape))
|
||||
|
||||
|
||||
# TODO(b/415721295): Refine return type once minimum jaxlib version is 0.10.0
|
||||
def _to_strided_fragmented_layout_attr(
|
||||
layout: fa.WGStridedFragLayout,
|
||||
) -> ir.Attribute:
|
||||
"""Constructs a #mosaic_gpu.WGStridedFragLayout attribute from a WGStridedFragLayout."""
|
||||
shape = ir.DenseI64ArrayAttr.get(layout.shape)
|
||||
return mgpu.WGStridedFragLayoutAttr.get(shape, layout.vec_size)
|
||||
|
||||
|
||||
def _from_strided_fragmented_layout_attr(
|
||||
attr: ir.Attribute,
|
||||
) -> fa.WGStridedFragLayout:
|
||||
"""Constructs a WGStridedFragLayout from a #mosaic_gpu.WGStridedFragLayout attribute."""
|
||||
# TODO(b/415721295): Refine arg type once minimum jaxlib version is 0.10.0
|
||||
assert isinstance(attr, mgpu.WGStridedFragLayoutAttr)
|
||||
return fa.WGStridedFragLayout(
|
||||
shape=tuple(attr.shape),
|
||||
vec_size=attr.vector_size,
|
||||
)
|
||||
|
||||
|
||||
# TODO(b/415721295): Refine return type once minimum jaxlib version is 0.10.0
|
||||
def _to_tiled_layout_attr(
|
||||
layout: fa.TiledLayout,
|
||||
) -> ir.Attribute:
|
||||
"""Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout."""
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
|
||||
def _int_or_replicated(d: int | fa.Replicated) -> ir.Attribute:
|
||||
if isinstance(d, fa.Replicated):
|
||||
return mgpu.ReplicatedAttr.get(d.times)
|
||||
return ir.IntegerAttr.get(i64, d)
|
||||
|
||||
def _tile_attr(tile):
|
||||
return ir.ArrayAttr.get([ir.IntegerAttr.get(i64, d) for d in tile])
|
||||
|
||||
tiling_attr = ir.ArrayAttr.get(
|
||||
[_tile_attr(tile) for tile in layout.tiling.tiles]
|
||||
)
|
||||
warp_dims_attr = ir.ArrayAttr.get(
|
||||
[_int_or_replicated(d) for d in layout.warp_dims]
|
||||
)
|
||||
lane_dims_attr = ir.ArrayAttr.get(
|
||||
[_int_or_replicated(d) for d in layout.lane_dims]
|
||||
)
|
||||
|
||||
return mgpu.TiledLayoutAttr.get(
|
||||
tiling_attr, warp_dims_attr, lane_dims_attr, layout.vector_dim
|
||||
)
|
||||
|
||||
|
||||
def _from_tiled_layout_attr(
|
||||
attr: ir.Attribute,
|
||||
) -> fa.TiledLayout:
|
||||
"""Constructs a TiledLayout from a #mosaic_gpu.TiledLayout attribute."""
|
||||
# TODO(allanrenucci): Refine arg type once minimum jaxlib version is 0.10.0
|
||||
assert isinstance(attr, mgpu.TiledLayoutAttr)
|
||||
|
||||
def _from_int_or_replicated_attr(d_attr: ir.Attribute) -> int | fa.Replicated:
|
||||
if isinstance(d_attr, mgpu.ReplicatedAttr):
|
||||
return fa.Replicated(times=mgpu.ReplicatedAttr(d_attr).times)
|
||||
return ir.IntegerAttr(d_attr).value
|
||||
|
||||
tiles = tuple(
|
||||
tuple(ir.IntegerAttr(d).value for d in ir.ArrayAttr(tile))
|
||||
for tile in attr.tiling
|
||||
)
|
||||
warp_dims = tuple(_from_int_or_replicated_attr(d) for d in attr.warp_dims)
|
||||
lane_dims = tuple(_from_int_or_replicated_attr(d) for d in attr.lane_dims)
|
||||
|
||||
return fa.TiledLayout(
|
||||
tiling=fa.Tiling(tiles),
|
||||
warp_dims=warp_dims,
|
||||
lane_dims=lane_dims,
|
||||
vector_dim=attr.vector_dim,
|
||||
)
|
||||
|
||||
|
||||
def to_layout_attr(layout: fa.FragmentedLayout) -> ir.Attribute:
|
||||
"""Constructs an MLIR attribute that corresponds to the given layout."""
|
||||
match layout:
|
||||
case fa.WGSplatFragLayout():
|
||||
return _to_splat_fragmented_layout_attr(layout)
|
||||
case fa.WGStridedFragLayout():
|
||||
return _to_strided_fragmented_layout_attr(layout)
|
||||
case fa.TiledLayout():
|
||||
return _to_tiled_layout_attr(layout)
|
||||
case _:
|
||||
assert_never(layout)
|
||||
|
||||
|
||||
def from_layout_attr(attr: ir.Attribute) -> fa.FragmentedLayout:
|
||||
"""Constructs a layout from an MLIR attribute."""
|
||||
if isinstance(attr, mgpu.WGSplatFragLayoutAttr):
|
||||
return _from_splat_fragmented_layout_attr(attr)
|
||||
elif isinstance(attr, mgpu.WGStridedFragLayoutAttr):
|
||||
return _from_strided_fragmented_layout_attr(attr)
|
||||
elif isinstance(attr, mgpu.TiledLayoutAttr):
|
||||
return _from_tiled_layout_attr(attr)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported layout for conversion from MLIR attribute: {attr}"
|
||||
)
|
||||
|
||||
|
||||
def splat_is_compatible_with_tiled(
|
||||
l1: fa.WGSplatFragLayout, l2: fa.TiledLayout
|
||||
) -> bool:
|
||||
# A splat layout is compatible with a tiled layout up to replication if each
|
||||
# dimension in the shape of the splat layout is divisible by the corresponding
|
||||
# dimension in the base tile shape.
|
||||
s1, s2 = l1.shape, l2.base_tile_shape
|
||||
return all(d1 % d2 == 0 for d1, d2 in zip(s1, s2))
|
||||
|
||||
|
||||
def to_transform_attr(
|
||||
transform: launch_context.MemRefTransform | mgpu.SwizzlingMode,
|
||||
) -> ir.Attribute:
|
||||
if isinstance(transform, launch_context.TileTransform):
|
||||
return mgpu.TileTransformAttr.get(transform.tiling)
|
||||
elif isinstance(transform, launch_context.TransposeTransform):
|
||||
return mgpu.TransposeTransformAttr.get(transform.permutation)
|
||||
elif isinstance(transform, mgpu.SwizzlingMode):
|
||||
return mgpu.SwizzleTransformAttr.get(transform)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported transform {transform}")
|
||||
|
||||
|
||||
def from_transform_attr(
|
||||
transform: ir.Attribute,
|
||||
) -> launch_context.MemRefTransform | mgpu.SwizzlingMode:
|
||||
if isinstance(transform, mgpu.TileTransformAttr):
|
||||
return launch_context.TileTransform(
|
||||
tuple(mgpu.TileTransformAttr(transform).tiling)
|
||||
)
|
||||
elif isinstance(transform, mgpu.TransposeTransformAttr):
|
||||
return launch_context.TransposeTransform(
|
||||
tuple(mgpu.TransposeTransformAttr(transform).permutation)
|
||||
)
|
||||
elif isinstance(transform, mgpu.SwizzleTransformAttr):
|
||||
return mgpu.SwizzlingMode(mgpu.SwizzleTransformAttr(transform).swizzle)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported transform {transform}")
|
||||
Reference in New Issue
Block a user