hand
This commit is contained in:
@@ -0,0 +1,526 @@
|
||||
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import jax
|
||||
from jaxlib.mlir import ir
|
||||
from jaxlib.mlir.dialects import arith
|
||||
from jaxlib.mlir.dialects import llvm
|
||||
from jaxlib.mlir.dialects import nvvm
|
||||
from jaxlib.mlir.dialects import vector
|
||||
import numpy as np
|
||||
|
||||
from . import fragmented_array as fa
|
||||
from . import mma_utils
|
||||
from . import utils
|
||||
|
||||
|
||||
# Short aliases for frequently used helpers from utils.
c = utils.c  # Constant-building helper.
bytewidth = utils.bytewidth  # Byte width of an MLIR element type.
|
||||
|
||||
|
||||
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class WGMMAAccumulator:
  """A FragmentedArray that is synchronized with the async proxy.

  This implies that it requires no additional synchronization when passed in
  as a WGMMA accumulator. In particular, when created from a
  FragmentedArray, the necessary synchronization is inserted at construction.
  """
  # Layout the user originally supplied; `value` converts back to it.
  _original_layout: fa.FragmentedLayout
  # Backing registers, kept in a WGMMA-compatible layout.
  _value: fa.FragmentedArray

  def __init__(
      self,
      *,
      _value: fa.FragmentedArray,
      _original_layout: fa.FragmentedLayout,
      _sync: bool = True,
  ):
    """Wraps `_value`; inserts a WGMMA fence unless `_sync` is False.

    `_sync=False` is only used internally (e.g. when unflattening, or when the
    value is already known to be fenced).
    """
    self._original_layout = _original_layout
    self._value = _value
    if _sync:
      # Make the registers visible to the async proxy before any WGMMA reads.
      self._value = wgmma_fence(_value)

  @property
  def value(self) -> fa.FragmentedArray:
    """Returns the accumulator contents in the layout it was created with."""
    return self._value.to_layout(self._original_layout)

  @classmethod
  def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None):
    """Creates a zero-initialized (m, n) accumulator (dtype defaults to f32).

    Raises:
      ValueError: if m is not a multiple of 64 or n is not a multiple of 8.
      TypeError: if an unsigned integer accumulator is requested.
    """
    if m % 64 or n % 8:
      raise ValueError("WGMMA requires m and n to be multiples of 64 and 8, "
                       f"got {m} and {n}")
    if is_signed is False:
      raise TypeError("PTX does not support unsigned WGMMA accumulators")
    f32 = ir.F32Type.get()
    if dtype is None:
      dtype = f32
    if isinstance(dtype, ir.IntegerType):
      zero = arith.constant(dtype, ir.IntegerAttr.get(dtype, 0))
    else:
      zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0))
    return cls.from_registers(
        fa.FragmentedArray.splat(
            zero, (m, n), fa.WGMMA_LAYOUT, is_signed=is_signed
        )
    )

  @classmethod
  def from_registers(cls, registers):
    """Creates an accumulator from a FragmentedArray in a WGMMA layout.

    32-bit accumulators are converted to WGMMA_LAYOUT_ACC_32BIT internally;
    the original layout is remembered so `value` can convert back.
    """
    original_layout = registers.layout
    if registers.layout != fa.WGMMA_LAYOUT and registers.layout != fa.WGMMA_LAYOUT_ACC_32BIT:
      raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
    if utils.bitwidth(registers.mlir_dtype) == 32:
      registers = registers.to_layout(fa.WGMMA_LAYOUT_ACC_32BIT)
    return cls(_value=registers, _original_layout=original_layout)

  def tree_flatten(self):
    # Layout is static metadata (aux); only the registers are leaves.
    return (self._value,), (self._original_layout,)

  @classmethod
  def tree_unflatten(cls, aux, value):
    # No re-fencing: flatten/unflatten must be a pure structural round-trip.
    return cls(_value=value[0], _original_layout=aux[0], _sync=False)
|
||||
|
||||
|
||||
def _supported_wgmma_types(dtype, abtype) -> bool:
  """Returns whether WGMMA supports accumulator `dtype` with operand `abtype`."""
  # Narrow float operand types that may accumulate into f16 (and also f32).
  f16_acc_types = (ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType)
  if isinstance(dtype, ir.F32Type):
    # f32 accumulation additionally admits tf32 and bf16 operands.
    return isinstance(abtype, (ir.FloatTF32Type, ir.BF16Type, *f16_acc_types))
  if isinstance(dtype, ir.F16Type):
    return isinstance(abtype, f16_acc_types)
  if (
      isinstance(dtype, ir.IntegerType)
      and dtype.width == 32
      and dtype.is_signless
  ):
    # s32 accumulation requires integer operands.
    return isinstance(abtype, ir.IntegerType)
  return False
|
||||
|
||||
|
||||
def wgmma_m64(
    acc: np.ndarray,  # of register Values
    a,
    b_descriptor: ir.Value,
    a_transpose: bool | None,
    b_transpose: bool,
    a_k_stride: int | None,
    b_k_stride: int,
    n: int,
    swizzle: int,
    element_type: ir.Type,
):
  """Issues the wgmma.mma_async PTX instructions for one M=64 group.

  Builds an LLVM inline-asm call per K chunk (K per instruction is
  32 // bytewidth(element_type)), threading the accumulator registers through
  each call via aliased input/output constraints.

  Args:
    acc: np.ndarray of per-thread accumulator register Values.
    a: either a fa.FragmentedArray (A in registers) or an i64 A descriptor.
    b_descriptor: i64 WGMMA descriptor for B.
    a_transpose / a_k_stride: must be None when A is in registers; otherwise
      a_k_stride advances the A descriptor between K chunks (multiple of 16).
    b_transpose / b_k_stride: transpose flag and per-chunk byte stride for B.
    n: instruction N dimension (multiple of 8).
    swizzle: shared-memory swizzle in bytes; determines K chunks per group.
    element_type: operand element type used to pick the PTX type suffix.

  Returns:
    np.ndarray of updated accumulator registers, same shape as `acc`.
  """
  out_ty = ir.VectorType(acc.flat[0].type).element_type
  if not _supported_wgmma_types(out_ty, element_type):
    raise ValueError(f"Unsupported wgmma types {(out_ty, element_type)=}")
  if n % 8:
    raise ValueError

  bf16 = ir.BF16Type.get()
  f16 = ir.F16Type.get()
  i8 = ir.IntegerType.get_signless(8)
  i32 = ir.IntegerType.get_signless(32)
  i64 = ir.IntegerType.get_signless(64)
  f8e5m2 = ir.Float8E5M2Type.get()
  f8e4m3fn = ir.Float8E4M3FNType.get()
  if b_k_stride % 16:
    raise ValueError
  # Only 16-bit types support transposes
  supports_transpose = bytewidth(element_type) == 2
  if not supports_transpose and (a_transpose or b_transpose):
    raise ValueError("Only f16 WGMMA supports transposes")
  if a_in_regs := isinstance(a, fa.FragmentedArray):
    if a.mlir_dtype not in {bf16, f16, i8, f8e5m2, f8e4m3fn}:
      raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
    # Column count must be equal to swizzle // bytewidth.
    elt_bytewidth = utils.bytewidth(element_type)
    swizzle_elems = swizzle // elt_bytewidth
    if a.shape != (64, swizzle_elems):
      raise ValueError("Unsupported A register array shape")
    if a.layout not in {fa.WGMMA_LAYOUT, fa.WGMMA_LAYOUT_8BIT}:
      raise ValueError("Unsupported A register array layout")
    if a_k_stride is not None or a_transpose is not None:
      raise ValueError("Unsupported WGMMA features with A in registers")
  else:
    # A comes as a descriptor: stride and transpose flag are mandatory.
    if a_k_stride is None or a_k_stride % 16:
      raise ValueError
    if a_transpose is None:
      raise ValueError

  # Pick register representation for the accumulator, depending on its width.
  if isinstance(out_ty, ir.F32Type) or out_ty == i32:
    # 32-bit accumulators: one register per element (n // 2 per thread).
    num_acc_regs = n // 2
    out_ty_field = ir.VectorType.get((1,), out_ty)
    acc_regs = list(acc.flat)
    assert acc_regs[0].type == ir.VectorType.get((1,), out_ty)
    to_acc_vec_regs = lambda regs: np.array(regs).reshape(acc.shape)
    acc_constraint = "r" if isinstance(out_ty, ir.IntegerType) else "f"
  elif isinstance(out_ty, ir.F16Type):
    # f16 accumulators are packed two-per-i32 register (n // 4 per thread).
    num_acc_regs = n // 4
    out_ty_field = i32
    acc_regs = [_as_i32_reg(reg) for reg in acc.flat]
    vec_ty = ir.VectorType(acc.flat[0].type)
    to_acc_vec_regs = lambda regs: np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape)
    acc_constraint = "r"
  else:
    raise ValueError(
        f"WGMMA instruction only supports f32, f16 and s32 out (got {out_ty})")

  # Number of immediate operands after use_out: scales and transpose flags.
  if supports_transpose:
    num_imm_regs = 4
  elif out_ty == i32:
    num_imm_regs = 0
  else:
    num_imm_regs = 2

  if a_in_regs:
    a_reg_constraints = ["r"] * 4  # 4x (b)f16x2 or s8x4 registers
    if supports_transpose:
      num_imm_regs -= 1  # transpose not supported for a in registers
  else:
    a_reg_constraints = ["l"]  # descriptor
  # Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
  # Seems like it's not actually documented in LLVM IR docs.
  reg_constraints_list = (
      [f"={acc_constraint}"] * num_acc_regs  # accumulator registers
      + [str(i) for i in range(num_acc_regs)]  # we alias outputs as inputs, too.
      + a_reg_constraints  # a descriptor / registers
      + ["l"] * 1  # b descriptor
      + ["n"] * (1 + num_imm_regs)  # literal constants
  )
  reg_constraints = ",".join(reg_constraints_list)
  reg_count = itertools.count()

  def take_regs(n):
    # Yields the next n "$i" asm operand placeholders.
    return (f"${i}" for i in itertools.islice(reg_count, n))

  acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}"
  for _ in take_regs(num_acc_regs):  # Ignore next entries: aliasing.
    pass
  if a_in_regs:
    a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}"
  else:
    a_regs, = take_regs(1)
  b_desc_reg, use_out_reg = take_regs(2)
  # Immediate regs (scale, ...).
  imm_regs = "".join(f", {r}" for r in take_regs(num_imm_regs))
  assert next(reg_count) == len(reg_constraints_list)
  # K elements covered by a single wgmma instruction.
  k_instr = 32 // bytewidth(element_type)
  # Translate the MLIR element type into the PTX type suffix.
  el_ty = str(element_type)
  if isinstance(element_type, ir.Float8E5M2Type):
    el_ty = "e5m2"
  elif isinstance(element_type, ir.Float8E4M3FNType):
    el_ty = "e4m3"
  elif isinstance(element_type, ir.IntegerType):
    # TODO(bchetioui): add u8 support in the future. Currently we always assume
    # that 8-bit integers are s8, and we would need to change the signature of
    # `wgmma` to indicate whether the input should be treated as signed or not.
    el_ty = "s8"

  out_ty_str = str(out_ty)
  if out_ty == i32:
    out_ty_str = "s32"

  wgmma_instr = (
      f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty_str}.{el_ty}.{el_ty} "
      f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p{imm_regs};"
  )
  # Predicate p carries the use_out flag (wgmma takes it as a predicate).
  ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"

  def lc(x):
    # i32 LLVM constant (used for the immediate asm operands).
    return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result

  use_out = scale_a = scale_b = lc(1)
  if out_ty == i32:
    # Integer wgmma takes no scale immediates.
    imms = [use_out]
  else:
    imms = [use_out, scale_a, scale_b]

  if supports_transpose and a_transpose is not None:
    imms += [lc(int(a_transpose)), lc(int(b_transpose))]
  elif supports_transpose:
    imms += [lc(int(b_transpose))]

  assert len(imms) == num_imm_regs + 1  # +1 for the use_out_reg in setp.ne.b32

  # Sanity-check the accumulator register array shape for this layout.
  expected_dim = 10 if utils.bitwidth(out_ty) == 32 else 9
  expected_regs_per_tile = 4 if utils.bitwidth(out_ty) == 32 else 2
  if acc.ndim != expected_dim or acc.shape[0] != 1 or math.prod(acc.shape[2:]) != expected_regs_per_tile:
    raise ValueError(acc.shape)
  acc_struct_type = ir.Type.parse(
      f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>"
  )
  # One instruction per K chunk within the swizzled group.
  for i in range((swizzle // bytewidth(element_type)) // k_instr):
    # Slice out the relevant part of A or advance the A descriptor.
    if a_in_regs:
      a_slice = a[:, (i * k_instr) : ((i + 1) * k_instr)]
      a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
    else:
      if i > 0:
        assert a_k_stride is not None
        # Descriptor addresses are encoded shifted right by 4 bits.
        a = _llvm_add(
            a,
            llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
        )
      a_args = [a]
    # Advance the B descriptor.
    if i > 0:
      b_descriptor = _llvm_add(
          b_descriptor,
          llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
      )
    assert len(a_args) == len(a_reg_constraints)
    acc_struct = llvm.inline_asm(
        acc_struct_type,
        [*acc_regs, *a_args, b_descriptor, *imms],
        ptx,
        reg_constraints,
        asm_dialect=0,
        has_side_effects=True,
    )
    assert isinstance(acc_struct, ir.Value)
    # Unpack the result struct back into the register list for the next chunk.
    acc_regs = [
        llvm.extractvalue(out_ty_field, acc_struct, [i]) for i in range(len(acc_regs))
    ]
  return to_acc_vec_regs(acc_regs)
|
||||
|
||||
|
||||
def wgmma(
    acc: WGMMAAccumulator,
    a: fa.FragmentedArray | ir.Value,
    b: ir.Value,
    *,
    swizzle: int = 128,
):
  """Perform acc += a @ b using the WGMMA instruction.

  `a` may be passed in registers, or as a memref. `b` must be a memref.

  The expected (logical) memref shapes are:
    a: (m // tile_m, k // tile_k, tile_m, tile_k)
    b: (k // tile_k, n // tile_n, tile_k, tile_n).

  While the shapes may be physically transposed, when considering the row-major
  physical shape, the tile dimensions must be the two minor dimensions and must
  have the shape (8, S) where S = swizzle // bytewidth(element_type).

  Returns:
    A new WGMMAAccumulator holding the updated registers (no re-fencing).

  Raises:
    ValueError: on shape/type mismatches between a, b and the accumulator.
    NotImplementedError: for unsupported swizzle/element-type combinations.
  """
  if swizzle == 16:
    raise NotImplementedError("No swizzle is not supported")
  # Step 1. Establish the shape and element type of the operation.
  if not isinstance(b.type, ir.MemRefType):
    raise ValueError(f"B must be a memref, got: {b.type}")
  bf16 = ir.BF16Type.get()
  f32 = ir.F32Type.get()
  f16 = ir.F16Type.get()
  i32 = ir.IntegerType.get_signless(32)
  i8 = ir.IntegerType.get_signless(8)
  f8e5m2 = ir.Float8E5M2Type.get()
  f8e4m3fn = ir.Float8E4M3FNType.get()
  (k, n), element_type = mma_utils.tiled_memref_shape(b)
  if a_in_regs := isinstance(a, fa.FragmentedArray):
    m, k2 = a.shape
    element_type2 = a.mlir_dtype
    if element_type2 not in {f16, bf16, i8, f8e5m2, f8e4m3fn}:
      raise ValueError(
          "Only f16, bf16, i8, f8e5m2, f8e4m3fn are supported for A "
          f"in registers, got {element_type2}"
      )
    if element_type2 == i8 and swizzle == 32:
      # TODO(bchetioui): relax this when ptxas is fixed. As of ptxas 12.8,
      # optimizations eliminate MMA instructions, leading to only the first tile
      # of the result being computed correctly.
      raise NotImplementedError("swizzle=32 not supported for s8 lhs in registers")
  elif isinstance(a.type, ir.MemRefType):
    (m, k2), element_type2 = mma_utils.tiled_memref_shape(a)
  else:
    raise ValueError(f"Unsupported A type: {type(a)}")
  if k != k2:
    raise ValueError(
        "WGMMA requires A and B to have the same contraction dimension (K),"
        f" got: {k2} and {k}"
    )
  if element_type != element_type2:
    raise ValueError(
        "WGMMA requires A and B to have the same element type, got:"
        f" {element_type2} and {element_type}"
    )
  if acc._value.shape != (m, n):
    raise ValueError(
        f"Accumulator shape mismatch: expected {(m, n)}, got {acc._value.shape}"
    )
  # Validate the accumulator dtype against the operand element type.
  if element_type == f32 or element_type == ir.BF16Type.get():
    if acc._value.mlir_dtype != f32:
      raise ValueError(
          f"WGMMA with element type {element_type} only supports accumulators"
          f" of type f32, but got: {acc._value.mlir_dtype}"
      )
  elif any(
      isinstance(element_type, t)
      for t in {ir.F16Type, ir.Float8E5M2Type, ir.Float8E4M3FNType}
  ):
    if acc._value.mlir_dtype != f16 and acc._value.mlir_dtype != f32:
      raise ValueError(
          f"WGMMA with element type {element_type} only supports accumulators "
          f"of type f32 or f16, but got: {acc._value.mlir_dtype}"
      )
  elif element_type == i8:
    if a_in_regs and not a.is_signed:  # pyrefly: ignore[missing-attribute]
      raise NotImplementedError("WGMMA with lhs of type u8")
    if acc._value.mlir_dtype != i32 or not acc._value.is_signed:
      raise ValueError(
          f"WGMMA with element type {element_type} only supports accumulators "
          f"of type s32, but got: {acc._value.mlir_dtype}"
      )
  else:
    raise NotImplementedError(f"Unsupported element type: {element_type}")

  # Step 2. Decide on the instruction shapes we'll use. Note that with swizzles,
  # instructions must be issued in groups of the same width as the swizzle.
  m_group_elems = 64  # Hopper has a fixed M instruction shape.
  k_group_elems = swizzle // utils.bytewidth(element_type)
  if n > 256 or n % 8:
    raise ValueError(f"N must be a multiple of 8 and <= 256, got: {n}")
  n_group_elems = n  # We assume only one N group below.
  if m % m_group_elems:
    raise ValueError(f"M must be a multiple of {m_group_elems}, got: {m}")
  if k % k_group_elems:
    raise ValueError(f"K must be a multiple of {k_group_elems}, got: {k}")
  m_groups = m // m_group_elems
  k_groups = k // k_group_elems
  # TODO(apaszke): Require users to bitcast input refs to tf32 before WGMMA.
  wgmma_element_type = (
      ir.FloatTF32Type.get() if element_type == ir.F32Type.get() else element_type
  )

  # Step 3. Compute the operand descriptors.
  if a_in_regs:
    # A is already in registers: no descriptor, no strides, no transpose.
    a_desc_base = a_m_group_stride = a_k_group_stride = None
    a_instr_params = dict(a_transpose=None, a_k_stride=None)
  else:
    assert isinstance(a, ir.Value)
    (
        (a_desc_base, a_k_instr_stride),
        (a_m_group_stride, a_k_group_stride),
        a_fastest,
    ) = mma_utils.create_descriptor(
        a,
        swizzle=swizzle,
        large_tile=(m_group_elems, k_group_elems),
        group_size=(m_group_elems, k_group_elems),
        logical_k_major=False,
    )
    assert not a_k_instr_stride[0]  # We'd need separate a/b swizzles.
    a_k_instr_stride = a_k_instr_stride[1][0]
    a_instr_params = dict(a_transpose=a_fastest != mma_utils.Dim.K,
                          a_k_stride=a_k_instr_stride)
  (
      (b_desc_base, b_k_instr_stride),
      (b_n_group_stride, b_k_group_stride),
      b_fastest,
  ) = mma_utils.create_descriptor(
      b,
      swizzle=swizzle,
      large_tile=(k_group_elems,) * 2,  # It's not a typo that we use k for n.
      group_size=(k_group_elems, n_group_elems),
      logical_k_major=True,
  )
  assert not b_k_instr_stride[0]  # We'd need separate a/b swizzles.
  b_k_instr_stride = b_k_instr_stride[1][0]
  del b_n_group_stride  # We only support one N group.

  # Step 4. Issue the instructions.
  if a_in_regs:
    assert isinstance(a, fa.FragmentedArray)
    a = wgmma_fence(a)  # Make sure the registers are ready.

  i64 = ir.IntegerType.get_signless(64)
  new_acc_regs = acc._value.registers.copy()
  for mi in range(m_groups):
    for ki in range(k_groups):
      # Select the (mi, ki) group of A: a register slice, or an offset
      # descriptor derived from the group strides.
      if a_in_regs:
        assert isinstance(a, fa.FragmentedArray)
        a_mk = a[
            mi * m_group_elems : (mi + 1) * m_group_elems,
            ki * k_group_elems : (ki + 1) * k_group_elems,
        ]
      else:
        assert a_m_group_stride is not None and a_k_group_stride is not None
        a_group_offset = mi * a_m_group_stride + ki * a_k_group_stride
        a_mk = _llvm_add(
            a_desc_base, c(mma_utils.encode_addr(a_group_offset), i64),
        )
      b_k = _llvm_add(
          b_desc_base, c(mma_utils.encode_addr(ki * b_k_group_stride), i64)
      )
      new_acc_regs[mi : mi + 1] = wgmma_m64(
          new_acc_regs[mi : mi + 1],
          a_mk,
          b_k,
          swizzle=swizzle,
          n=n_group_elems,
          element_type=wgmma_element_type,
          b_transpose=b_fastest != mma_utils.Dim.K,
          b_k_stride=b_k_instr_stride,
          **a_instr_params,
      )
  # The inputs were already fenced, so no re-sync is needed (_sync=False).
  return WGMMAAccumulator(
      _value=fa.FragmentedArray(
          _registers=new_acc_regs,
          _layout=acc._value.layout,
          _is_signed=acc._value.is_signed,
      ),
      _original_layout=acc._original_layout,
      _sync=False,
  )
|
||||
|
||||
|
||||
def wgmma_fence(array: fa.FragmentedArray) -> fa.FragmentedArray:
  """Fences the array construction from WGMMA instructions.

  LLVM treats in-register computation as pure and can move it after the fence,
  which is explicitly disallowed by the PTX programming model. For that reason,
  we insert an LLVM optimization barrier before the fence.
  """
  # NOTE: The order here is load-bearing — the barrier must precede the fence.
  array = fa.optimization_barrier(array)
  nvvm.wgmma_fence_aligned()
  return array
|
||||
|
||||
|
||||
def _as_i32_reg(v):
  """Reinterprets the single-element vector `v` as a scalar i32 value."""
  i32_ty = ir.IntegerType.get_signless(32)
  as_i32_vec = vector.bitcast(ir.VectorType.get((1,), i32_ty), v)
  return llvm.extractelement(as_i32_vec, _lc(0))
|
||||
|
||||
|
||||
def _lc(x):
  """Builds an LLVM i32 constant with value `x`."""
  i32_ty = ir.IntegerType.get_signless(32)
  attr = ir.IntegerAttr.get(i32_ty, x)
  return llvm.ConstantOp(i32_ty, attr).result
|
||||
|
||||
|
||||
def _llvm_add(x, y):
  """Emits an LLVM integer add of `x` and `y` with no overflow flags."""
  no_overflow = llvm.IntegerOverflowFlags.none
  return llvm.add(x, y, overflow_flags=no_overflow)
|
||||
|
||||
|
||||
def _unpack_i32(vec_ty, r):
  """Bitcasts the scalar i32 register `r` into a vector of type `vec_ty`."""
  i32_ty = ir.IntegerType.get_signless(32)
  as_i32_vec = vector.broadcast(ir.VectorType.get((1,), i32_ty), r)
  return vector.bitcast(vec_ty, as_i32_vec)
|
||||
Reference in New Issue
Block a user