hand
This commit is contained in:
@@ -0,0 +1,83 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
from ..dialects._ods_common import get_op_result_or_op_results
|
||||
from ..ir import Type, InsertionPoint
|
||||
|
||||
|
||||
def op_region_builder(op, op_region, terminator=None):
|
||||
def builder_wrapper(body_builder):
|
||||
# Add a block with block args having types determined by type hints on the wrapped function.
|
||||
if len(op_region.blocks) == 0:
|
||||
sig = inspect.signature(body_builder)
|
||||
types = [p.annotation for p in sig.parameters.values()]
|
||||
if not (
|
||||
len(types) == len(sig.parameters)
|
||||
and all(isinstance(t, Type) for t in types)
|
||||
):
|
||||
raise ValueError(
|
||||
f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
|
||||
)
|
||||
|
||||
op_region.blocks.append(*types)
|
||||
|
||||
with InsertionPoint(op_region.blocks[0]):
|
||||
results = body_builder(*list(op_region.blocks[0].arguments))
|
||||
|
||||
with InsertionPoint(list(op_region.blocks)[-1]):
|
||||
if terminator is not None:
|
||||
res = []
|
||||
if isinstance(results, (tuple, list)):
|
||||
res.extend(results)
|
||||
elif results is not None:
|
||||
res.append(results)
|
||||
terminator(res)
|
||||
|
||||
return get_op_result_or_op_results(op)
|
||||
|
||||
return builder_wrapper
|
||||
|
||||
|
||||
def region_op(op_constructor, terminator=None):
|
||||
"""Decorator to define an MLIR Op specified as a python function.
|
||||
|
||||
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
|
||||
active for the current thread (i.e. established in a `with` block).
|
||||
|
||||
Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor.
|
||||
|
||||
When applied as a decorator to a Python function, an entry block will
|
||||
be constructed for the Op with types as specified **as type hints on the args of the function**.
|
||||
The block arguments will be passed positionally to the Python function.
|
||||
|
||||
If a terminator is specified then the return from the decorated function will be passed
|
||||
to the terminator as the last statement in the entry block. Note, the API for the terminator
|
||||
is a (possibly empty) list; terminator accepting single values should be wrapped in a
|
||||
`lambda args: term(args[0])`
|
||||
|
||||
The identifier (name) of the function will become:
|
||||
1. A single value result if the Op returns a single value;
|
||||
2. An OpResultList (as a list) if the Op returns multiple values;
|
||||
3. The Operation if the Op returns no results.
|
||||
|
||||
See examples in tensor.py and transform.extras.
|
||||
"""
|
||||
|
||||
def op_decorator(*args, **kwargs):
|
||||
op = op_constructor(*args, **kwargs)
|
||||
op_region = op.regions[0]
|
||||
|
||||
return op_region_builder(op, op_region, terminator)
|
||||
|
||||
@wraps(op_decorator)
|
||||
def maybe_no_args(*args, **kwargs):
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
return op_decorator()(args[0])
|
||||
else:
|
||||
return op_decorator(*args, **kwargs)
|
||||
|
||||
return maybe_no_args
|
||||
Reference in New Issue
Block a user