This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,83 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import inspect
from functools import wraps
from ..dialects._ods_common import get_op_result_or_op_results
from ..ir import Type, InsertionPoint
def op_region_builder(op, op_region, terminator=None):
def builder_wrapper(body_builder):
# Add a block with block args having types determined by type hints on the wrapped function.
if len(op_region.blocks) == 0:
sig = inspect.signature(body_builder)
types = [p.annotation for p in sig.parameters.values()]
if not (
len(types) == len(sig.parameters)
and all(isinstance(t, Type) for t in types)
):
raise ValueError(
f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
)
op_region.blocks.append(*types)
with InsertionPoint(op_region.blocks[0]):
results = body_builder(*list(op_region.blocks[0].arguments))
with InsertionPoint(list(op_region.blocks)[-1]):
if terminator is not None:
res = []
if isinstance(results, (tuple, list)):
res.extend(results)
elif results is not None:
res.append(results)
terminator(res)
return get_op_result_or_op_results(op)
return builder_wrapper
def region_op(op_constructor, terminator=None):
"""Decorator to define an MLIR Op specified as a python function.
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
active for the current thread (i.e. established in a `with` block).
Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor.
When applied as a decorator to a Python function, an entry block will
be constructed for the Op with types as specified **as type hints on the args of the function**.
The block arguments will be passed positionally to the Python function.
If a terminator is specified then the return from the decorated function will be passed
to the terminator as the last statement in the entry block. Note, the API for the terminator
is a (possibly empty) list; terminator accepting single values should be wrapped in a
`lambda args: term(args[0])`
The identifier (name) of the function will become:
1. A single value result if the Op returns a single value;
2. An OpResultList (as a list) if the Op returns multiple values;
3. The Operation if the Op returns no results.
See examples in tensor.py and transform.extras.
"""
def op_decorator(*args, **kwargs):
op = op_constructor(*args, **kwargs)
op_region = op.regions[0]
return op_region_builder(op, op_region, terminator)
@wraps(op_decorator)
def maybe_no_args(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return op_decorator()(args[0])
else:
return op_decorator(*args, **kwargs)
return maybe_no_args
@@ -0,0 +1,179 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from functools import partial
from typing import Optional, List
from ..ir import (
Attribute,
BF16Type,
ComplexType,
F16Type,
F32Type,
F64Type,
Float4E2M1FNType,
Float6E2M3FNType,
Float6E3M2FNType,
Float8E3M4Type,
Float8E4M3B11FNUZType,
Float8E4M3FNType,
Float8E4M3Type,
Float8E5M2Type,
Float8E8M0FNUType,
FloatTF32Type,
FunctionType,
IndexType,
IntegerType,
MemRefType,
NoneType,
OpaqueType,
RankedTensorType,
StridedLayoutAttr,
StringAttr,
TupleType,
Type,
UnrankedMemRefType,
UnrankedTensorType,
VectorType,
)
index = lambda: IndexType.get()
def i(width):
return IntegerType.get_signless(width)
def si(width):
return IntegerType.get_signed(width)
def ui(width):
return IntegerType.get_unsigned(width)
bool = lambda: i(1)
i8 = lambda: i(8)
i16 = lambda: i(16)
i32 = lambda: i(32)
i64 = lambda: i(64)
si8 = lambda: si(8)
si16 = lambda: si(16)
si32 = lambda: si(32)
si64 = lambda: si(64)
ui8 = lambda: ui(8)
ui16 = lambda: ui(16)
ui32 = lambda: ui(32)
ui64 = lambda: ui(64)
f16 = lambda: F16Type.get()
f32 = lambda: F32Type.get()
tf32 = lambda: FloatTF32Type.get()
f64 = lambda: F64Type.get()
bf16 = lambda: BF16Type.get()
f8E5M2 = lambda: Float8E5M2Type.get()
f8E4M3 = lambda: Float8E4M3Type.get()
f8E4M3FN = lambda: Float8E4M3FNType.get()
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
f8E3M4 = lambda: Float8E3M4Type.get()
f4E2M1FN = lambda: Float4E2M1FNType.get()
f6E2M3FN = lambda: Float6E2M3FNType.get()
f6E3M2FN = lambda: Float6E3M2FNType.get()
f8E8M0FNU = lambda: Float8E8M0FNUType.get()
none = lambda: NoneType.get()
def complex(type):
return ComplexType.get(type)
def opaque(dialect_namespace, type_data):
return OpaqueType.get(dialect_namespace, type_data)
def _shaped(*shape, element_type: Type = None, type_constructor=None):
if type_constructor is None:
raise ValueError("shaped is an abstract base class - cannot be constructed.")
if (element_type is None and shape and not isinstance(shape[-1], Type)) or (
shape and isinstance(shape[-1], Type) and element_type is not None
):
raise ValueError(
f"Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
)
if element_type is not None:
type = element_type
sizes = shape
else:
type = shape[-1]
sizes = shape[:-1]
if sizes:
return type_constructor(sizes, type)
else:
return type_constructor(type)
def vector(
*shape,
element_type: Type = None,
scalable: Optional[List[bool]] = None,
scalable_dims: Optional[List[int]] = None,
):
return _shaped(
*shape,
element_type=element_type,
type_constructor=partial(
VectorType.get, scalable=scalable, scalable_dims=scalable_dims
),
)
def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
if encoding is not None:
encoding = StringAttr.get(encoding)
if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
if encoding is not None:
raise ValueError("UnrankedTensorType does not support encoding.")
return _shaped(
*shape, element_type=element_type, type_constructor=UnrankedTensorType.get
)
return _shaped(
*shape,
element_type=element_type,
type_constructor=partial(RankedTensorType.get, encoding=encoding),
)
def memref(
*shape,
element_type: Type = None,
memory_space: Optional[int] = None,
layout: Optional[StridedLayoutAttr] = None,
):
if memory_space is not None:
memory_space = Attribute.parse(str(memory_space))
if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)):
return _shaped(
*shape,
element_type=element_type,
type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space),
)
return _shaped(
*shape,
element_type=element_type,
type_constructor=partial(
MemRefType.get, memory_space=memory_space, layout=layout
),
)
def tuple(*elements):
return TupleType.get_tuple(elements)
def function(*, inputs, results):
return FunctionType.get(inputs, results)