hand
This commit is contained in:
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,83 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
from ..dialects._ods_common import get_op_result_or_op_results
|
||||
from ..ir import Type, InsertionPoint
|
||||
|
||||
|
||||
def op_region_builder(op, op_region, terminator=None):
    """Return a decorator that populates ``op_region`` from a Python function.

    Args:
      op: The operation owning ``op_region``; the generated builder returns
        its result(s), or the op itself if it produces none.
      op_region: The region to populate.
      terminator: Optional callable invoked with a (possibly empty) list of
        the wrapped function's return values; expected to emit the region's
        terminator op.
    """

    def builder_wrapper(body_builder):
        # Add a block with block args having types determined by type hints
        # on the wrapped function.
        if len(op_region.blocks) == 0:
            sig = inspect.signature(body_builder)
            types = [p.annotation for p in sig.parameters.values()]
            # A missing annotation shows up as inspect.Parameter.empty, which
            # is not an MLIR Type, so this one check covers both "missing" and
            # "wrong kind of" annotation. (The previous extra comparison
            # len(types) == len(sig.parameters) was vacuously true, since
            # types is built from sig.parameters itself.)
            if not all(isinstance(t, Type) for t in types):
                raise ValueError(
                    f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
                )

            op_region.blocks.append(*types)

        # Build the body by calling the wrapped function with the entry
        # block's arguments, inserting into that block.
        with InsertionPoint(op_region.blocks[0]):
            results = body_builder(*list(op_region.blocks[0].arguments))

        # Emit the terminator (if any) at the end of the *last* block, feeding
        # it the wrapped function's return value(s) normalized to a list.
        with InsertionPoint(list(op_region.blocks)[-1]):
            if terminator is not None:
                res = []
                if isinstance(results, (tuple, list)):
                    res.extend(results)
                elif results is not None:
                    res.append(results)
                terminator(res)

        return get_op_result_or_op_results(op)

    return builder_wrapper
|
||||
|
||||
|
||||
def region_op(op_constructor, terminator=None):
    """Decorator factory turning a Python function into an MLIR region Op.

    Requires an active `mlir.ir.InsertionPoint` and `mlir.ir.Location` on the
    current thread (i.e. established in a `with` block).

    Supports "naked" usage — no parentheses when the Op constructor needs no
    arguments.

    When the returned decorator is applied to a Python function, an entry
    block is created for the Op whose block-argument types come **from the
    type hints on the function's parameters**; those block arguments are then
    passed positionally to the function.

    If `terminator` is given, the decorated function's return value is handed
    to it as the final statement of the entry block. The terminator receives a
    (possibly empty) list; a terminator taking a single value should be
    wrapped as `lambda args: term(args[0])`.

    The name bound by the decorated function becomes:
    1. A single value result if the Op returns one value;
    2. An OpResultList (as a list) if the Op returns several values;
    3. The Operation itself if the Op has no results.

    See examples in tensor.py and transform.extras.
    """

    def decorated(*op_args, **op_kwargs):
        created_op = op_constructor(*op_args, **op_kwargs)
        return op_region_builder(created_op, created_op.regions[0], terminator)

    @wraps(decorated)
    def maybe_no_args(*args, **kwargs):
        # "Naked" application: the decorator was used directly on a function,
        # so the single positional callable is the body, not a ctor argument.
        applied_naked = len(args) == 1 and not kwargs and callable(args[0])
        if applied_naked:
            return decorated()(args[0])
        return decorated(*args, **kwargs)

    return maybe_no_args
|
||||
@@ -0,0 +1,179 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional, List
|
||||
|
||||
from ..ir import (
|
||||
Attribute,
|
||||
BF16Type,
|
||||
ComplexType,
|
||||
F16Type,
|
||||
F32Type,
|
||||
F64Type,
|
||||
Float4E2M1FNType,
|
||||
Float6E2M3FNType,
|
||||
Float6E3M2FNType,
|
||||
Float8E3M4Type,
|
||||
Float8E4M3B11FNUZType,
|
||||
Float8E4M3FNType,
|
||||
Float8E4M3Type,
|
||||
Float8E5M2Type,
|
||||
Float8E8M0FNUType,
|
||||
FloatTF32Type,
|
||||
FunctionType,
|
||||
IndexType,
|
||||
IntegerType,
|
||||
MemRefType,
|
||||
NoneType,
|
||||
OpaqueType,
|
||||
RankedTensorType,
|
||||
StridedLayoutAttr,
|
||||
StringAttr,
|
||||
TupleType,
|
||||
Type,
|
||||
UnrankedMemRefType,
|
||||
UnrankedTensorType,
|
||||
VectorType,
|
||||
)
|
||||
|
||||
# Nullary builder: index() returns the MLIR IndexType for the active context.
index = lambda: IndexType.get()
|
||||
|
||||
|
||||
def i(width):
    """Return a signless integer type of the given bit ``width``."""
    return IntegerType.get_signless(width)
|
||||
|
||||
|
||||
def si(width):
    """Return a signed integer type of the given bit ``width``."""
    return IntegerType.get_signed(width)
|
||||
|
||||
|
||||
def ui(width):
    """Return an unsigned integer type of the given bit ``width``."""
    return IntegerType.get_unsigned(width)
|
||||
|
||||
|
||||
# Nullary constructors for the common fixed-width integer and float types.
# Written as ``def``s rather than ``name = lambda: ...`` assignments (PEP 8,
# E731) so each carries a proper __name__ in tracebacks and repr; behavior
# and the exported names are unchanged. Several deliberately shadow builtins
# (``bool``) — this module's identifiers ARE the public API.


def bool():
    return i(1)


def i8():
    return i(8)


def i16():
    return i(16)


def i32():
    return i(32)


def i64():
    return i(64)


def si8():
    return si(8)


def si16():
    return si(16)


def si32():
    return si(32)


def si64():
    return si(64)


def ui8():
    return ui(8)


def ui16():
    return ui(16)


def ui32():
    return ui(32)


def ui64():
    return ui(64)


def f16():
    return F16Type.get()


def f32():
    return F32Type.get()


def tf32():
    return FloatTF32Type.get()


def f64():
    return F64Type.get()


def bf16():
    return BF16Type.get()


def f8E5M2():
    return Float8E5M2Type.get()


def f8E4M3():
    return Float8E4M3Type.get()


def f8E4M3FN():
    return Float8E4M3FNType.get()


def f8E4M3B11FNUZ():
    return Float8E4M3B11FNUZType.get()


def f8E3M4():
    return Float8E3M4Type.get()


def f4E2M1FN():
    return Float4E2M1FNType.get()


def f6E2M3FN():
    return Float6E2M3FNType.get()


def f6E3M2FN():
    return Float6E3M2FNType.get()


def f8E8M0FNU():
    return Float8E8M0FNUType.get()


def none():
    return NoneType.get()
|
||||
|
||||
|
||||
def complex(type):
    """Return an MLIR complex type with element type ``type``.

    Deliberately shadows the ``complex`` builtin — this module's names are
    the public API.
    """
    return ComplexType.get(type)
|
||||
|
||||
|
||||
def opaque(dialect_namespace, type_data):
    """Return an opaque type for ``dialect_namespace`` wrapping ``type_data``."""
    return OpaqueType.get(dialect_namespace, type_data)
|
||||
|
||||
|
||||
def _shaped(*shape, element_type: Type = None, type_constructor=None):
    """Shared driver for the shaped-type builders (vector/tensor/memref).

    The element type may be supplied either via ``element_type=`` or as the
    trailing positional argument — exactly one of the two.

    Args:
      *shape: Dimension sizes, optionally followed by the element type.
      element_type: The element type, if not given positionally.
      type_constructor: Callable building the concrete type; invoked as
        ``type_constructor(sizes, element_type)`` when there are dimensions
        and ``type_constructor(element_type)`` otherwise.

    Raises:
      ValueError: If ``type_constructor`` is missing, or the element type is
        given both ways or neither way.
    """
    if type_constructor is None:
        raise ValueError("shaped is an abstract base class - cannot be constructed.")
    last_is_type = len(shape) > 0 and isinstance(shape[-1], Type)
    # Exactly one source for the element type: keyword XOR trailing positional.
    # (Previously an empty shape with no element_type fell through to an
    # uncaught IndexError on shape[-1]; it now raises this ValueError too.)
    if last_is_type == (element_type is not None):
        raise ValueError(
            "Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
        )
    if element_type is not None:
        type = element_type
        sizes = shape
    else:
        type = shape[-1]
        sizes = shape[:-1]
    if sizes:
        return type_constructor(sizes, type)
    else:
        return type_constructor(type)
|
||||
|
||||
|
||||
def vector(
    *shape,
    element_type: Type = None,
    scalable: Optional[List[bool]] = None,
    scalable_dims: Optional[List[int]] = None,
):
    """Build an MLIR VectorType.

    Dimension sizes go in ``*shape``; the element type is given either as the
    trailing positional argument or via ``element_type=``. ``scalable`` /
    ``scalable_dims`` are forwarded to ``VectorType.get`` unchanged.
    """
    make_vector = partial(
        VectorType.get, scalable=scalable, scalable_dims=scalable_dims
    )
    return _shaped(*shape, element_type=element_type, type_constructor=make_vector)
|
||||
|
||||
|
||||
def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None):
    """Build a ranked or unranked MLIR tensor type.

    With no dimensions (only an element type), an UnrankedTensorType is
    produced; otherwise a RankedTensorType. ``encoding``, when given, is
    wrapped in a StringAttr and attached to the ranked type.

    Raises:
      ValueError: If ``encoding`` is supplied for an unranked tensor.
    """
    encoding_attr = None if encoding is None else StringAttr.get(encoding)
    unranked = len(shape) == 0 or (len(shape) == 1 and isinstance(shape[-1], Type))
    if unranked:
        if encoding_attr is not None:
            raise ValueError("UnrankedTensorType does not support encoding.")
        return _shaped(
            *shape, element_type=element_type, type_constructor=UnrankedTensorType.get
        )
    ranked_ctor = partial(RankedTensorType.get, encoding=encoding_attr)
    return _shaped(*shape, element_type=element_type, type_constructor=ranked_ctor)
|
||||
|
||||
|
||||
def memref(
    *shape,
    element_type: Type = None,
    memory_space: Optional[int] = None,
    layout: Optional[StridedLayoutAttr] = None,
):
    """Build a ranked or unranked MLIR memref type.

    With no dimensions (only an element type), an UnrankedMemRefType is
    produced; otherwise a MemRefType. ``memory_space`` is parsed into an
    Attribute via its string form; ``layout`` applies to ranked memrefs only.
    """
    space_attr = None if memory_space is None else Attribute.parse(str(memory_space))
    if len(shape) == 0 or (len(shape) == 1 and isinstance(shape[-1], Type)):
        unranked_ctor = partial(UnrankedMemRefType.get, memory_space=space_attr)
        return _shaped(
            *shape, element_type=element_type, type_constructor=unranked_ctor
        )
    ranked_ctor = partial(MemRefType.get, memory_space=space_attr, layout=layout)
    return _shaped(*shape, element_type=element_type, type_constructor=ranked_ctor)
|
||||
|
||||
|
||||
def tuple(*elements):
    """Return an MLIR tuple type of ``elements``.

    Deliberately shadows the ``tuple`` builtin — this module's names are the
    public API.
    """
    return TupleType.get_tuple(elements)
|
||||
|
||||
|
||||
def function(*, inputs, results):
    """Return an MLIR FunctionType with the given ``inputs`` and ``results``."""
    return FunctionType.get(inputs, results)
|
||||
Reference in New Issue
Block a user