hand
This commit is contained in:
@@ -0,0 +1,330 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from ._func_ops_gen import *
|
||||
from ._func_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from ..ir import *
|
||||
from ._ods_common import (
|
||||
get_default_loc_context as _get_default_loc_context,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
|
||||
import inspect
|
||||
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
|
||||
RESULT_ATTRIBUTE_NAME = "res_attrs"
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class ConstantOp(ConstantOp):
|
||||
"""Specialization for the constant op class."""
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self.results[0].type
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class FuncOp(FuncOp):
|
||||
"""Specialization for the func op class."""
|
||||
|
||||
def __init__(
|
||||
self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
|
||||
):
|
||||
"""
|
||||
Create a FuncOp with the provided `name`, `type`, and `visibility`.
|
||||
- `name` is a string representing the function name.
|
||||
- `type` is either a FunctionType or a pair of list describing inputs and
|
||||
results.
|
||||
- `visibility` is a string matching `public`, `private`, or `nested`. None
|
||||
implies private visibility.
|
||||
- `body_builder` is an optional callback, when provided a new entry block
|
||||
is created and the callback is invoked with the new op as argument within
|
||||
an InsertionPoint context already set for the block. The callback is
|
||||
expected to insert a terminator in the block.
|
||||
"""
|
||||
sym_name = StringAttr.get(str(name))
|
||||
|
||||
# If the type is passed as a tuple, build a FunctionType on the fly.
|
||||
if isinstance(type, tuple):
|
||||
type = FunctionType.get(inputs=type[0], results=type[1])
|
||||
|
||||
type = TypeAttr.get(type)
|
||||
sym_visibility = (
|
||||
StringAttr.get(str(visibility)) if visibility is not None else None
|
||||
)
|
||||
super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
|
||||
if body_builder:
|
||||
entry_block = self.add_entry_block()
|
||||
with InsertionPoint(entry_block):
|
||||
body_builder(self)
|
||||
|
||||
@property
|
||||
def is_external(self):
|
||||
return len(self.regions[0].blocks) == 0
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0]
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return FunctionType(TypeAttr(self.attributes["function_type"]).value)
|
||||
|
||||
@property
|
||||
def visibility(self):
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
@property
|
||||
def name(self) -> StringAttr:
|
||||
return StringAttr(self.attributes["sym_name"])
|
||||
|
||||
@property
|
||||
def entry_block(self):
|
||||
if self.is_external:
|
||||
raise IndexError("External function does not have a body")
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
|
||||
"""
|
||||
Add an entry block to the function body using the function signature to
|
||||
infer block arguments.
|
||||
Returns the newly created block
|
||||
"""
|
||||
if not self.is_external:
|
||||
raise IndexError("The function already has an entry block!")
|
||||
self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
|
||||
return self.body.blocks[0]
|
||||
|
||||
@property
|
||||
def arg_attrs(self):
|
||||
if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
|
||||
return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
|
||||
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, attribute: Union[ArrayAttr, list]):
|
||||
if isinstance(attribute, ArrayAttr):
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
|
||||
else:
|
||||
self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
|
||||
attribute, context=self.context
|
||||
)
|
||||
|
||||
@property
|
||||
def arguments(self):
|
||||
return self.entry_block.arguments
|
||||
|
||||
@property
|
||||
def result_attrs(self):
|
||||
return self.attributes[RESULT_ATTRIBUTE_NAME]
|
||||
|
||||
@result_attrs.setter
|
||||
def result_attrs(self, attribute: ArrayAttr):
|
||||
self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
|
||||
|
||||
@classmethod
|
||||
def from_py_func(
|
||||
FuncOp,
|
||||
*inputs: Type,
|
||||
results: Optional[Sequence[Type]] = None,
|
||||
name: Optional[str] = None,
|
||||
):
|
||||
"""Decorator to define an MLIR FuncOp specified as a python function.
|
||||
|
||||
Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
|
||||
active for the current thread (i.e. established in a `with` block).
|
||||
|
||||
When applied as a decorator to a Python function, an entry block will
|
||||
be constructed for the FuncOp with types as specified in `*inputs`. The
|
||||
block arguments will be passed positionally to the Python function. In
|
||||
addition, if the Python function accepts keyword arguments generally or
|
||||
has a corresponding keyword argument, the following will be passed:
|
||||
* `func_op`: The `func` op being defined.
|
||||
|
||||
By default, the function name will be the Python function `__name__`. This
|
||||
can be overriden by passing the `name` argument to the decorator.
|
||||
|
||||
If `results` is not specified, then the decorator will implicitly
|
||||
insert a `ReturnOp` with the `Value`'s returned from the decorated
|
||||
function. It will also set the `FuncOp` type with the actual return
|
||||
value types. If `results` is specified, then the decorated function
|
||||
must return `None` and no implicit `ReturnOp` is added (nor are the result
|
||||
types updated). The implicit behavior is intended for simple, single-block
|
||||
cases, and users should specify result types explicitly for any complicated
|
||||
cases.
|
||||
|
||||
The decorated function can further be called from Python and will insert
|
||||
a `CallOp` at the then-current insertion point, returning either None (
|
||||
if no return values), a unary Value (for one result), or a list of Values).
|
||||
This mechanism cannot be used to emit recursive calls (by construction).
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
from . import func
|
||||
|
||||
# Introspect the callable for optional features.
|
||||
sig = inspect.signature(f)
|
||||
has_arg_func_op = False
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == param.VAR_KEYWORD:
|
||||
has_arg_func_op = True
|
||||
if param.name == "func_op" and (
|
||||
param.kind == param.POSITIONAL_OR_KEYWORD
|
||||
or param.kind == param.KEYWORD_ONLY
|
||||
):
|
||||
has_arg_func_op = True
|
||||
|
||||
# Emit the FuncOp.
|
||||
implicit_return = results is None
|
||||
symbol_name = name or f.__name__
|
||||
function_type = FunctionType.get(
|
||||
inputs=inputs, results=[] if implicit_return else results
|
||||
)
|
||||
func_op = FuncOp(name=symbol_name, type=function_type)
|
||||
with InsertionPoint(func_op.add_entry_block()):
|
||||
func_args = func_op.entry_block.arguments
|
||||
func_kwargs = {}
|
||||
if has_arg_func_op:
|
||||
func_kwargs["func_op"] = func_op
|
||||
return_values = f(*func_args, **func_kwargs)
|
||||
if not implicit_return:
|
||||
return_types = list(results)
|
||||
assert return_values is None, (
|
||||
"Capturing a python function with explicit `results=` "
|
||||
"requires that the wrapped function returns None."
|
||||
)
|
||||
else:
|
||||
# Coerce return values, add ReturnOp and rewrite func type.
|
||||
if return_values is None:
|
||||
return_values = []
|
||||
elif isinstance(return_values, tuple):
|
||||
return_values = list(return_values)
|
||||
elif isinstance(return_values, Value):
|
||||
# Returning a single value is fine, coerce it into a list.
|
||||
return_values = [return_values]
|
||||
elif isinstance(return_values, OpView):
|
||||
# Returning a single operation is fine, coerce its results a list.
|
||||
return_values = return_values.operation.results
|
||||
elif isinstance(return_values, Operation):
|
||||
# Returning a single operation is fine, coerce its results a list.
|
||||
return_values = return_values.results
|
||||
else:
|
||||
return_values = list(return_values)
|
||||
func.ReturnOp(return_values)
|
||||
# Recompute the function type.
|
||||
return_types = [v.type for v in return_values]
|
||||
function_type = FunctionType.get(
|
||||
inputs=inputs, results=return_types
|
||||
)
|
||||
func_op.attributes["function_type"] = TypeAttr.get(function_type)
|
||||
|
||||
def emit_call_op(*call_args):
|
||||
call_op = func.CallOp(
|
||||
return_types, FlatSymbolRefAttr.get(symbol_name), call_args
|
||||
)
|
||||
if return_types is None:
|
||||
return None
|
||||
elif len(return_types) == 1:
|
||||
return call_op.result
|
||||
else:
|
||||
return call_op.results
|
||||
|
||||
wrapped = emit_call_op
|
||||
wrapped.__name__ = f.__name__
|
||||
wrapped.func_op = func_op
|
||||
return wrapped
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
func = FuncOp.from_py_func
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||
class CallOp(CallOp):
|
||||
"""Specialization for the call op class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calleeOrResults: Union[FuncOp, List[Type]],
|
||||
argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
|
||||
arguments: Optional[List] = None,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
"""Creates an call operation.
|
||||
|
||||
The constructor accepts three different forms:
|
||||
|
||||
1. A function op to be called followed by a list of arguments.
|
||||
2. A list of result types, followed by the name of the function to be
|
||||
called as string, following by a list of arguments.
|
||||
3. A list of result types, followed by the name of the function to be
|
||||
called as symbol reference attribute, followed by a list of arguments.
|
||||
|
||||
For example
|
||||
|
||||
f = func.FuncOp("foo", ...)
|
||||
func.CallOp(f, [args])
|
||||
func.CallOp([result_types], "foo", [args])
|
||||
|
||||
In all cases, the location and insertion point may be specified as keyword
|
||||
arguments if not provided by the surrounding context managers.
|
||||
"""
|
||||
|
||||
# TODO: consider supporting constructor "overloads", e.g., through a custom
|
||||
# or pybind-provided metaclass.
|
||||
if isinstance(calleeOrResults, FuncOp):
|
||||
if not isinstance(argumentsOrCallee, list):
|
||||
raise ValueError(
|
||||
"when constructing a call to a function, expected "
|
||||
+ "the second argument to be a list of call arguments, "
|
||||
+ f"got {type(argumentsOrCallee)}"
|
||||
)
|
||||
if arguments is not None:
|
||||
raise ValueError(
|
||||
"unexpected third argument when constructing a call"
|
||||
+ "to a function"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
calleeOrResults.type.results,
|
||||
FlatSymbolRefAttr.get(
|
||||
calleeOrResults.name.value, context=_get_default_loc_context(loc)
|
||||
),
|
||||
argumentsOrCallee,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(argumentsOrCallee, list):
|
||||
raise ValueError(
|
||||
"when constructing a call to a function by name, "
|
||||
+ "expected the second argument to be a string or a "
|
||||
+ f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
|
||||
)
|
||||
|
||||
if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
|
||||
super().__init__(
|
||||
calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip
|
||||
)
|
||||
elif isinstance(argumentsOrCallee, str):
|
||||
super().__init__(
|
||||
calleeOrResults,
|
||||
FlatSymbolRefAttr.get(
|
||||
argumentsOrCallee, context=_get_default_loc_context(loc)
|
||||
),
|
||||
arguments,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
Reference in New Issue
Block a user