hand
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,218 @@
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
import os
|
||||
|
||||
_this_dir = os.path.dirname(__file__)
|
||||
|
||||
|
||||
def get_lib_dirs() -> Sequence[str]:
|
||||
"""Gets the lib directory for linking to shared libraries.
|
||||
|
||||
On some platforms, the package may need to be built specially to export
|
||||
development libraries.
|
||||
"""
|
||||
return [_this_dir]
|
||||
|
||||
|
||||
def get_include_dirs() -> Sequence[str]:
|
||||
"""Gets the include directory for compiling against exported C libraries.
|
||||
|
||||
Depending on how the package was build, development C libraries may or may
|
||||
not be present.
|
||||
"""
|
||||
return [os.path.join(_this_dir, "include")]
|
||||
|
||||
|
||||
# Perform Python level site initialization. This involves:
|
||||
# 1. Attempting to load initializer modules, specific to the distribution.
|
||||
# 2. Defining the concrete mlir.ir.Context that does site specific
|
||||
# initialization.
|
||||
# 3. Registering container classes with their respective protocols.
|
||||
#
|
||||
# Aside from just being far more convenient to do this at the Python level,
|
||||
# it is actually quite hard/impossible to have such __init__ hooks, given
|
||||
# the pybind memory model (i.e. there is not a Python reference to the object
|
||||
# in the scope of the base class __init__).
|
||||
#
|
||||
# For #1, we:
|
||||
# a. Probe for modules named '_mlirRegisterEverything' and
|
||||
# '_site_initialize_{i}', where 'i' is a number starting at zero and
|
||||
# proceeding so long as a module with the name is found.
|
||||
# b. If the module has a 'register_dialects' attribute, it will be called
|
||||
# immediately with a DialectRegistry to populate.
|
||||
# c. If the module has a 'context_init_hook', it will be added to a list
|
||||
# of callbacks that are invoked as the last step of Context
|
||||
# initialization (and passed the Context under construction).
|
||||
# d. If the module has a 'disable_multithreading' attribute, it will be
|
||||
# taken as a boolean. If it is True for any initializer, then the
|
||||
# default behavior of enabling multithreading on the context
|
||||
# will be suppressed. This complies with the original behavior of all
|
||||
# contexts being created with multithreading enabled while allowing
|
||||
# this behavior to be changed if needed (i.e. if a context_init_hook
|
||||
# explicitly sets up multithreading).
|
||||
#
|
||||
# This facility allows downstreams to customize Context creation to their
|
||||
# needs.
|
||||
|
||||
_dialect_registry = None
|
||||
_load_on_create_dialects = None
|
||||
|
||||
|
||||
def get_dialect_registry():
|
||||
global _dialect_registry
|
||||
|
||||
if _dialect_registry is None:
|
||||
from ._mlir import ir
|
||||
|
||||
_dialect_registry = ir.DialectRegistry()
|
||||
|
||||
return _dialect_registry
|
||||
|
||||
|
||||
def append_load_on_create_dialect(dialect: str):
|
||||
global _load_on_create_dialects
|
||||
if _load_on_create_dialects is None:
|
||||
_load_on_create_dialects = [dialect]
|
||||
else:
|
||||
_load_on_create_dialects.append(dialect)
|
||||
|
||||
|
||||
def get_load_on_create_dialects():
|
||||
global _load_on_create_dialects
|
||||
if _load_on_create_dialects is None:
|
||||
_load_on_create_dialects = []
|
||||
return _load_on_create_dialects
|
||||
|
||||
|
||||
def _site_initialize():
|
||||
import importlib
|
||||
import itertools
|
||||
import logging
|
||||
from ._mlir import ir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
post_init_hooks = []
|
||||
disable_multithreading = False
|
||||
# This flag disables eagerly loading all dialects. Eagerly loading is often
|
||||
# not the desired behavior (see
|
||||
# https://github.com/llvm/llvm-project/issues/56037), and the logic is that
|
||||
# if any module has this attribute set, then we don't load all (e.g., it's
|
||||
# being used in a solution where the loading is controlled).
|
||||
disable_load_all_available_dialects = False
|
||||
|
||||
def process_initializer_module(module_name):
|
||||
nonlocal disable_multithreading
|
||||
nonlocal disable_load_all_available_dialects
|
||||
try:
|
||||
m = importlib.import_module(f".{module_name}", __name__)
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
except ImportError:
|
||||
message = (
|
||||
f"Error importing mlir initializer {module_name}. This may "
|
||||
"happen in unclean incremental builds but is likely a real bug if "
|
||||
"encountered otherwise and the MLIR Python API may not function."
|
||||
)
|
||||
logger.warning(message, exc_info=True)
|
||||
return False
|
||||
|
||||
logger.debug("Initializing MLIR with module: %s", module_name)
|
||||
if hasattr(m, "register_dialects"):
|
||||
logger.debug("Registering dialects from initializer %r", m)
|
||||
m.register_dialects(get_dialect_registry())
|
||||
if hasattr(m, "context_init_hook"):
|
||||
logger.debug("Adding context init hook from %r", m)
|
||||
post_init_hooks.append(m.context_init_hook)
|
||||
if hasattr(m, "disable_multithreading"):
|
||||
if bool(m.disable_multithreading):
|
||||
logger.debug("Disabling multi-threading for context")
|
||||
disable_multithreading = True
|
||||
if hasattr(m, "disable_load_all_available_dialects"):
|
||||
disable_load_all_available_dialects = True
|
||||
return True
|
||||
|
||||
# If _mlirRegisterEverything is built, then include it as an initializer
|
||||
# module.
|
||||
init_module = None
|
||||
if process_initializer_module("_mlirRegisterEverything"):
|
||||
init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
|
||||
|
||||
# Load all _site_initialize_{i} modules, where 'i' is a number starting
|
||||
# at 0.
|
||||
for i in itertools.count():
|
||||
module_name = f"_site_initialize_{i}"
|
||||
if not process_initializer_module(module_name):
|
||||
break
|
||||
|
||||
ir._Context = ir.Context
|
||||
|
||||
class Context(ir._Context):
|
||||
def __init__(
|
||||
self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.append_dialect_registry(get_dialect_registry())
|
||||
for hook in post_init_hooks:
|
||||
hook(self)
|
||||
if disable_multithreading and thread_pool is not None:
|
||||
raise ValueError(
|
||||
"Context constructor has given thread_pool argument, "
|
||||
"but disable_multithreading flag is True. "
|
||||
"Please, set thread_pool argument to None or "
|
||||
"set disable_multithreading flag to False."
|
||||
)
|
||||
if not disable_multithreading:
|
||||
if thread_pool is None:
|
||||
self.enable_multithreading(True)
|
||||
else:
|
||||
self.set_thread_pool(thread_pool)
|
||||
if load_on_create_dialects is not None:
|
||||
logger.debug(
|
||||
"Loading all dialects from load_on_create_dialects arg %r",
|
||||
load_on_create_dialects,
|
||||
)
|
||||
for dialect in load_on_create_dialects:
|
||||
# This triggers loading the dialect into the context.
|
||||
_ = self.dialects[dialect]
|
||||
else:
|
||||
if disable_load_all_available_dialects:
|
||||
dialects = get_load_on_create_dialects()
|
||||
if dialects:
|
||||
logger.debug(
|
||||
"Loading all dialects from global load_on_create_dialects %r",
|
||||
dialects,
|
||||
)
|
||||
for dialect in dialects:
|
||||
# This triggers loading the dialect into the context.
|
||||
_ = self.dialects[dialect]
|
||||
else:
|
||||
logger.debug("Loading all available dialects")
|
||||
self.load_all_available_dialects()
|
||||
if init_module:
|
||||
logger.debug(
|
||||
"Registering translations from initializer %r", init_module
|
||||
)
|
||||
init_module.register_llvm_translations(self)
|
||||
|
||||
ir.Context = Context
|
||||
|
||||
# Register containers as Sequences, so they can be used with `match`.
|
||||
|
||||
Sequence.register(ir.BlockArgumentList)
|
||||
Sequence.register(ir.BlockList)
|
||||
Sequence.register(ir.BlockSuccessors)
|
||||
Sequence.register(ir.BlockPredecessors)
|
||||
Sequence.register(ir.OperationList)
|
||||
Sequence.register(ir.OpOperandList)
|
||||
Sequence.register(ir.OpOperands)
|
||||
Sequence.register(ir.OpResultList)
|
||||
Sequence.register(ir.OpSuccessors)
|
||||
Sequence.register(ir.RegionSequence)
|
||||
Mapping.register(ir.OpAttributeMap)
|
||||
|
||||
|
||||
_site_initialize()
|
||||
BIN
Binary file not shown.
@@ -0,0 +1,78 @@
|
||||
"""chlo main python extension"""
|
||||
|
||||
from typing import Self
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
class ComparisonDirectionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonDirection attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ComparisonTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RaggedDotDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], lhs_ragged_dimensions: Sequence[int], rhs_group_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a RaggedDotDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def lhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def lhs_ragged_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_group_dimensions(self) -> list[int]: ...
|
||||
|
||||
class PrecisionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Precision attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
Binary file not shown.
@@ -0,0 +1,22 @@
|
||||
"""Registers upstream MLIR dialects used by JAX."""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
def register_dialects(arg: ir.DialectRegistry, /) -> None: ...
|
||||
|
||||
def enter_multi_threaded_execution(arg: ir.Context, /) -> None: ...
|
||||
|
||||
def exit_multi_threaded_execution(arg: ir.Context, /) -> None: ...
|
||||
|
||||
def inlined_func_call(callee: ir.Operation, args: Sequence[ir.Value], block: ir.Block, loc: ir.Location | None = None) -> list[ir.Value]:
|
||||
"""
|
||||
Makes an inlined call to a function containing a single block with a single return op.
|
||||
"""
|
||||
|
||||
class TracebackToLocationCache:
|
||||
def __init__(self, code_to_filename: Callable, frame_limit: int, context: ir.Context | None = None) -> None: ...
|
||||
|
||||
def get(self, traceback: Traceback, /) -> ir.Location: ...
|
||||
Binary file not shown.
Binary file not shown.
+66
@@ -0,0 +1,66 @@
|
||||
"""MLIR Python Native Extension"""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
|
||||
from . import (
|
||||
ir as ir,
|
||||
passmanager as passmanager,
|
||||
rewrite as rewrite
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
U = TypeVar("U")
|
||||
|
||||
class _Globals:
|
||||
@property
|
||||
def dialect_search_modules(self) -> list[str]: ...
|
||||
|
||||
@dialect_search_modules.setter
|
||||
def dialect_search_modules(self, arg: Sequence[str], /) -> None: ...
|
||||
|
||||
def append_dialect_search_prefix(self, module_name: str) -> None: ...
|
||||
|
||||
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
|
||||
|
||||
def _register_dialect_impl(self, dialect_namespace: str, dialect_class: object, *, replace: bool = False) -> None:
|
||||
"""Testing hook for directly registering a dialect"""
|
||||
|
||||
def _register_operation_impl(self, operation_name: str, operation_class: object, *, replace: bool = False) -> None:
|
||||
"""Testing hook for directly registering an operation"""
|
||||
|
||||
def loc_tracebacks_enabled(self) -> bool: ...
|
||||
|
||||
def set_loc_tracebacks_enabled(self, arg: bool, /) -> None: ...
|
||||
|
||||
def loc_tracebacks_frame_limit(self) -> int: ...
|
||||
|
||||
def set_loc_tracebacks_frame_limit(self, arg: int | None) -> None: ...
|
||||
|
||||
def register_traceback_file_inclusion(self, arg: str, /) -> None: ...
|
||||
|
||||
def register_traceback_file_exclusion(self, arg: str, /) -> None: ...
|
||||
|
||||
globals: _Globals = ...
|
||||
|
||||
def register_dialect(dialect_class: type) -> type:
|
||||
"""Class decorator for registering a custom Dialect wrapper"""
|
||||
|
||||
def register_operation(dialect_class: type, *, replace: bool = False) -> Callable[[type[T]], type[T]]:
|
||||
"""
|
||||
Produce a class decorator for registering an Operation class as part of a dialect
|
||||
"""
|
||||
|
||||
def register_op_adaptor(op_class: type, *, replace: bool = False) -> Callable[[type[T]], type[T]]:
|
||||
"""
|
||||
Produce a class decorator for registering an OpAdaptor class for an operation.
|
||||
"""
|
||||
|
||||
def register_type_caster(typeid: _ir.TypeID, *, replace: bool = False) -> Callable[[Callable[[T], U]], Callable[[T], U]]:
|
||||
"""Register a type caster for casting MLIR types to custom user types."""
|
||||
|
||||
def register_value_caster(typeid: _ir.TypeID, *, replace: bool = False) -> Callable[[Callable[[T], U]], Callable[[T], U]]:
|
||||
"""Register a value caster for casting MLIR values to custom user values."""
|
||||
File diff suppressed because it is too large
Load Diff
+79
@@ -0,0 +1,79 @@
|
||||
"""MLIR Pass Management Bindings"""
|
||||
|
||||
from collections.abc import Callable
|
||||
import enum
|
||||
from typing import overload
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class PassDisplayMode(enum.Enum):
|
||||
LIST = 0
|
||||
|
||||
PIPELINE = 1
|
||||
|
||||
class ExternalPass:
|
||||
def signal_pass_failure(self) -> None: ...
|
||||
|
||||
class PassManager:
|
||||
def __init__(self, anchor_op: str = 'any', context: ir.Context | None = None) -> None:
|
||||
"""Create a new PassManager for the current (or provided) Context."""
|
||||
|
||||
@property
|
||||
def _CAPIPtr(self) -> object: ...
|
||||
|
||||
def _CAPICreate(self) -> object: ...
|
||||
|
||||
def _testing_release(self) -> None:
|
||||
"""Releases (leaks) the backing pass manager (testing)"""
|
||||
|
||||
def enable_ir_printing(self, print_before_all: bool = False, print_after_all: bool = True, print_module_scope: bool = False, print_after_change: bool = False, print_after_failure: bool = False, large_elements_limit: int | None = None, large_resource_limit: int | None = None, enable_debug_info: bool = False, print_generic_op_form: bool = False, tree_printing_dir_path: str | None = None) -> None:
|
||||
"""Enable IR printing, default as mlir-print-ir-after-all."""
|
||||
|
||||
def enable_verifier(self, enable: bool) -> None:
|
||||
"""Enable / disable verify-each."""
|
||||
|
||||
def enable_timing(self) -> None:
|
||||
"""Enable pass timing."""
|
||||
|
||||
def enable_statistics(self, displayMode: PassDisplayMode = PassDisplayMode.PIPELINE) -> None:
|
||||
"""Enable pass statistics."""
|
||||
|
||||
@staticmethod
|
||||
def parse(pipeline: str, context: ir.Context | None = None) -> PassManager:
|
||||
"""
|
||||
Parse a textual pass-pipeline and return a top-level PassManager that can be applied on a Module. Throw a ValueError if the pipeline can't be parsed
|
||||
"""
|
||||
|
||||
@overload
|
||||
def add(self, pipeline: str) -> None:
|
||||
"""
|
||||
Add textual pipeline elements to the pass manager. Throws a ValueError if the pipeline can't be parsed.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def add(self, run: Callable, name: str | None = None, argument: str | None = '', description: str | None = '', op_name: str | None = '') -> None:
|
||||
"""
|
||||
Add a python-defined pass to the current pipeline of the pass manager.
|
||||
|
||||
Args:
|
||||
run: A callable with signature ``(op: ir.Operation, pass_: ExternalPass) -> None``.
|
||||
Called when the pass executes. It receives the operation to be processed and
|
||||
the current ``ExternalPass`` instance.
|
||||
Use ``pass_.signal_pass_failure()`` to signal failure.
|
||||
name: The name of the pass. Defaults to ``run.__name__``.
|
||||
argument: The command-line argument for the pass. Defaults to empty.
|
||||
description: The description of the pass. Defaults to empty.
|
||||
op_name: The name of the operation this pass operates on.
|
||||
It will be a generic operation pass if not specified.
|
||||
"""
|
||||
|
||||
def run(self, operation: ir._OperationBase) -> None:
|
||||
"""
|
||||
Run the pass manager on the provided operation, raising an MLIRError on failure.
|
||||
"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Print the textual representation for this PassManager, suitable to be passed to `parse` for round-tripping.
|
||||
"""
|
||||
+241
@@ -0,0 +1,241 @@
|
||||
"""MLIR Rewrite Bindings"""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import enum
|
||||
from typing import overload
|
||||
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class GreedyRewriteStrictness(enum.Enum):
|
||||
ANY_OP = 0
|
||||
|
||||
EXISTING_AND_NEW_OPS = 1
|
||||
|
||||
EXISTING_OPS = 2
|
||||
|
||||
class GreedySimplifyRegionLevel(enum.Enum):
|
||||
DISABLED = 0
|
||||
|
||||
NORMAL = 1
|
||||
|
||||
AGGRESSIVE = 2
|
||||
|
||||
class DialectConversionFoldingMode(enum.Enum):
|
||||
NEVER = 0
|
||||
|
||||
BEFORE_PATTERNS = 1
|
||||
|
||||
AFTER_PATTERNS = 2
|
||||
|
||||
class PatternRewriter:
|
||||
@property
|
||||
def ip(self) -> ir.InsertionPoint:
|
||||
"""The current insertion point of the PatternRewriter."""
|
||||
|
||||
@overload
|
||||
def replace_op(self, op: ir._OperationBase, new_op: ir._OperationBase) -> None:
|
||||
"""Replace an operation with a new operation."""
|
||||
|
||||
@overload
|
||||
def replace_op(self, op: ir._OperationBase, values: Sequence[ir.Value]) -> None:
|
||||
"""Replace an operation with a list of values."""
|
||||
|
||||
def erase_op(self, op: ir._OperationBase) -> None:
|
||||
"""Erase an operation."""
|
||||
|
||||
class RewritePatternSet:
|
||||
def __init__(self, context: _ir.Context | None = None) -> None: ...
|
||||
|
||||
def add(self, root: object, fn: Callable, benefit: int = 1) -> None:
|
||||
"""
|
||||
Add a new rewrite pattern on the specified root operation, using
|
||||
the provided callable for matching and rewriting, and assign it
|
||||
the given benefit.
|
||||
|
||||
Args:
|
||||
root: The root operation to which this pattern applies. This may
|
||||
be either an OpView subclass or an operation name.
|
||||
fn: The callable to use for matching and rewriting, which takes
|
||||
an operation and a pattern rewriter. The match is considered
|
||||
successful iff the callable returns a falsy value.
|
||||
benefit: The benefit of the pattern, defaulting to 1.
|
||||
"""
|
||||
|
||||
def add_conversion(self, root: object, fn: Callable, type_converter: TypeConverter, benefit: int = 1) -> None:
|
||||
"""
|
||||
Add a new conversion pattern on the specified root operation,
|
||||
using the provided callable for matching and rewriting,
|
||||
and assign it the given benefit.
|
||||
|
||||
Args:
|
||||
root: The root operation to which this pattern applies.
|
||||
This may be either an OpView subclass or an operation name.
|
||||
fn: The callable to use for matching and rewriting, which takes an
|
||||
operation, its adaptor, the type converter and a pattern
|
||||
rewriter. The match is considered successful iff the callable
|
||||
returns a falsy value.
|
||||
type_converter: The type converter to convert types in the IR.
|
||||
benefit: The benefit of the pattern, defaulting to 1.
|
||||
"""
|
||||
|
||||
def freeze(self) -> FrozenRewritePatternSet:
|
||||
"""Freeze the pattern set into a frozen one."""
|
||||
|
||||
class ConversionPatternRewriter(PatternRewriter):
|
||||
def convert_region_types(self, arg0: ir.Region, arg1: TypeConverter, /) -> None: ...
|
||||
|
||||
class ConversionTarget:
|
||||
def __init__(self, context: _ir.Context | None = None) -> None: ...
|
||||
|
||||
def add_legal_op(self, *ops) -> None:
|
||||
"""Mark the given operations as legal."""
|
||||
|
||||
def add_illegal_op(self, *ops) -> None:
|
||||
"""Mark the given operations as illegal."""
|
||||
|
||||
def add_legal_dialect(self, *dialects) -> None:
|
||||
"""Mark the given dialects as legal."""
|
||||
|
||||
def add_illegal_dialect(self, *dialects) -> None:
|
||||
"""Mark the given dialect as illegal."""
|
||||
|
||||
class TypeConverter:
|
||||
def __init__(self) -> None:
|
||||
"""Create a new TypeConverter."""
|
||||
|
||||
def add_conversion(self, convert: Callable) -> None:
|
||||
"""Register a type conversion function."""
|
||||
|
||||
def convert_type(self, type: ir.Type) -> ir.Type | None:
|
||||
"""Convert the given type. Returns None if conversion fails."""
|
||||
|
||||
class PDLResultList:
|
||||
@overload
|
||||
def append(self, arg: ir.Value, /) -> None: ...
|
||||
|
||||
@overload
|
||||
def append(self, arg: ir.Operation, /) -> None: ...
|
||||
|
||||
@overload
|
||||
def append(self, arg: ir.Type, /) -> None: ...
|
||||
|
||||
@overload
|
||||
def append(self, arg: ir.Attribute, /) -> None: ...
|
||||
|
||||
class PDLModule:
|
||||
@overload
|
||||
def __init__(self, module: ir.Module) -> None:
|
||||
"""Create a PDL module from the given module."""
|
||||
|
||||
@overload
|
||||
def __init__(self, module: ir.Module) -> None: ...
|
||||
|
||||
def freeze(self) -> FrozenRewritePatternSet: ...
|
||||
|
||||
def register_rewrite_function(self, arg0: str, arg1: Callable, /) -> None: ...
|
||||
|
||||
def register_constraint_function(self, arg0: str, arg1: Callable, /) -> None: ...
|
||||
|
||||
class GreedyRewriteConfig:
|
||||
def __init__(self) -> None:
|
||||
"""Create a greedy rewrite driver config with defaults"""
|
||||
|
||||
@property
|
||||
def max_iterations(self) -> int:
|
||||
"""Maximum number of iterations"""
|
||||
|
||||
@max_iterations.setter
|
||||
def max_iterations(self, arg: int, /) -> None: ...
|
||||
|
||||
@property
|
||||
def max_num_rewrites(self) -> int:
|
||||
"""Maximum number of rewrites per iteration"""
|
||||
|
||||
@max_num_rewrites.setter
|
||||
def max_num_rewrites(self, arg: int, /) -> None: ...
|
||||
|
||||
@property
|
||||
def use_top_down_traversal(self) -> bool:
|
||||
"""Whether to use top-down traversal"""
|
||||
|
||||
@use_top_down_traversal.setter
|
||||
def use_top_down_traversal(self, arg: bool, /) -> None: ...
|
||||
|
||||
@property
|
||||
def enable_folding(self) -> bool:
|
||||
"""Enable or disable folding"""
|
||||
|
||||
@enable_folding.setter
|
||||
def enable_folding(self, arg: bool, /) -> None: ...
|
||||
|
||||
@property
|
||||
def strictness(self) -> GreedyRewriteStrictness:
|
||||
"""Rewrite strictness level"""
|
||||
|
||||
@strictness.setter
|
||||
def strictness(self, arg: GreedyRewriteStrictness, /) -> None: ...
|
||||
|
||||
@property
|
||||
def region_simplification_level(self) -> GreedySimplifyRegionLevel:
|
||||
"""Region simplification level"""
|
||||
|
||||
@region_simplification_level.setter
|
||||
def region_simplification_level(self, arg: GreedySimplifyRegionLevel, /) -> None: ...
|
||||
|
||||
@property
|
||||
def enable_constant_cse(self) -> bool:
|
||||
"""Enable or disable constant CSE"""
|
||||
|
||||
@enable_constant_cse.setter
|
||||
def enable_constant_cse(self, arg: bool, /) -> None: ...
|
||||
|
||||
class ConversionConfig:
|
||||
def __init__(self) -> None:
|
||||
"""Create a conversion config with defaults"""
|
||||
|
||||
@property
|
||||
def folding_mode(self) -> DialectConversionFoldingMode:
|
||||
"""folding behavior during dialect conversion"""
|
||||
|
||||
@folding_mode.setter
|
||||
def folding_mode(self, arg: DialectConversionFoldingMode, /) -> None: ...
|
||||
|
||||
@property
|
||||
def build_materializations(self) -> bool:
|
||||
"""
|
||||
Whether the dialect conversion attempts to build source/target materializations
|
||||
"""
|
||||
|
||||
@build_materializations.setter
|
||||
def build_materializations(self, arg: bool, /) -> None: ...
|
||||
|
||||
class FrozenRewritePatternSet:
|
||||
@property
|
||||
def _CAPIPtr(self) -> object: ...
|
||||
|
||||
def _CAPICreate(self) -> object: ...
|
||||
|
||||
@overload
|
||||
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet, config: GreedyRewriteConfig | None = None) -> None:
|
||||
"""
|
||||
Applys the given patterns to the given module greedily while folding results.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet, config: GreedyRewriteConfig | None = None) -> None:
|
||||
"""
|
||||
Applys the given patterns to the given op greedily while folding results.
|
||||
"""
|
||||
|
||||
def walk_and_apply_patterns(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None:
|
||||
"""
|
||||
Applies the given patterns to the given op by a fast walk-based driver.
|
||||
"""
|
||||
|
||||
def apply_partial_conversion(op: ir._OperationBase, target: ConversionTarget, set: FrozenRewritePatternSet, config: ConversionConfig | None = None) -> None:
|
||||
"""Applies a partial conversion on the given operation."""
|
||||
|
||||
def apply_full_conversion(op: ir._OperationBase, target: ConversionTarget, set: FrozenRewritePatternSet, config: ConversionConfig | None = None) -> None:
|
||||
"""Applies a full conversion on the given operation."""
|
||||
+61
@@ -0,0 +1,61 @@
|
||||
"""MLIR GPU Dialect"""
|
||||
|
||||
from typing import ClassVar, Final
|
||||
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class AsyncTokenType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(context: _ir.Context | None = None) -> AsyncTokenType:
|
||||
"""Gets an instance of AsyncTokenType in the same context"""
|
||||
|
||||
class ObjectAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
attr_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(target: ir.Attribute, format: int, object: bytes, properties: ir.DictAttr | None = None, kernels: ir.Attribute | None = None, context: _ir.Context | None = None) -> ObjectAttr:
|
||||
"""Gets a gpu.object from parameters."""
|
||||
|
||||
@property
|
||||
def target(self) -> ir.Attribute: ...
|
||||
|
||||
@property
|
||||
def format(self) -> int: ...
|
||||
|
||||
@property
|
||||
def object(self) -> bytes: ...
|
||||
|
||||
@property
|
||||
def properties(self) -> ir.DictAttr | None: ...
|
||||
|
||||
@property
|
||||
def kernels(self) -> ir.Attribute | None: ...
|
||||
BIN
Binary file not shown.
+207
@@ -0,0 +1,207 @@
|
||||
"""MLIR LLVM Dialect"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import ClassVar, Final
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
class StructType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get_literal(elements: Sequence[ir.Type], *, packed: bool = False, loc: _ir.Location | None = None, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
@staticmethod
|
||||
def get_literal_unchecked(elements: Sequence[ir.Type], *, packed: bool = False, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
@staticmethod
|
||||
def get_identified(name: str, *, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
@staticmethod
|
||||
def get_opaque(name: str, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
def set_body(self, elements: Sequence[ir.Type], *, packed: bool = False) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def new_identified(name: str, elements: Sequence[ir.Type], *, packed: bool = False, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
@property
|
||||
def name(self) -> str | None: ...
|
||||
|
||||
@property
|
||||
def body(self) -> object: ...
|
||||
|
||||
@property
|
||||
def packed(self) -> bool: ...
|
||||
|
||||
@property
|
||||
def opaque(self) -> bool: ...
|
||||
|
||||
class ArrayType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(element_type: ir.Type, num_elements: int) -> ArrayType: ...
|
||||
|
||||
@property
|
||||
def element_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def num_elements(self) -> int: ...
|
||||
|
||||
class PointerType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(address_space: int | None = None, *, context: _ir.Context | None = None) -> PointerType: ...
|
||||
|
||||
@property
|
||||
def address_space(self) -> int: ...
|
||||
|
||||
class FunctionType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(result_type: ir.Type, argument_types: Sequence[ir.Type], *, is_var_arg: bool = False) -> FunctionType: ...
|
||||
|
||||
@property
|
||||
def return_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def num_inputs(self) -> int: ...
|
||||
|
||||
@property
|
||||
def inputs(self) -> list: ...
|
||||
|
||||
@property
|
||||
def is_var_arg(self) -> bool: ...
|
||||
|
||||
class MDStringAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
def get(value: str, *, context: _ir.Context | None = None) -> MDStringAttr: ...
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class MDConstantAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
def get(value: ir.Attribute, *, context: _ir.Context | None = None) -> MDConstantAttr: ...
|
||||
|
||||
@property
|
||||
def value(self) -> ir.Attribute: ...
|
||||
|
||||
class MDFuncAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
def get(name: str, *, context: _ir.Context | None = None) -> MDFuncAttr: ...
|
||||
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
class MDNodeAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
def get(operands: Sequence[ir.Attribute], *, context: _ir.Context | None = None) -> MDNodeAttr: ...
|
||||
|
||||
@property
|
||||
def num_operands(self) -> int: ...
|
||||
|
||||
def __getitem__(self, arg: int, /) -> ir.Attribute: ...
|
||||
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
def translate_module_to_llvmir(module: ir.Operation) -> str: ...
|
||||
BIN
Binary file not shown.
+25
@@ -0,0 +1,25 @@
|
||||
"""MLIR NVGPU dialect."""
|
||||
|
||||
from typing import ClassVar, Final
|
||||
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class TensorMapDescriptorType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(tensor_type: ir.Type, swizzle: int, l2promo: int, oob_fill: int, interleave: int, context: _ir.Context | None = None) -> TensorMapDescriptorType:
|
||||
"""Gets an instance of TensorMapDescriptorType in the same context"""
|
||||
BIN
Binary file not shown.
+97
@@ -0,0 +1,97 @@
|
||||
"""MLIR SparseTensor dialect."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import enum
|
||||
from typing import ClassVar, Final
|
||||
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class LevelFormat(enum.IntFlag):
|
||||
_boundary_: enum.FlagBoundary = enum.FlagBoundary.KEEP
|
||||
|
||||
_flag_mask_: int = 3997696
|
||||
|
||||
_singles_mask_: int = 3997696
|
||||
|
||||
_all_bits_: int = 8388607
|
||||
|
||||
_inverted_: None = None
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
def __repr__(self, /):
|
||||
"""Return repr(self)."""
|
||||
|
||||
dense = 65536
|
||||
|
||||
n_out_of_m = 2097152
|
||||
|
||||
compressed = 262144
|
||||
|
||||
singleton = 524288
|
||||
|
||||
loose_compressed = 1048576
|
||||
|
||||
class LevelProperty(enum.Enum):
|
||||
non_ordered = 2
|
||||
|
||||
non_unique = 1
|
||||
|
||||
soa = 4
|
||||
|
||||
class EncodingAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
attr_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(lvl_types: Sequence[int], dim_to_lvl: ir.AffineMap | None, lvl_to_dim: ir.AffineMap | None, pos_width: int, crd_width: int, explicit_val: ir.Attribute | None = None, implicit_val: ir.Attribute | None = None, context: _ir.Context | None = None) -> EncodingAttr:
|
||||
"""Gets a sparse_tensor.encoding from parameters."""
|
||||
|
||||
@staticmethod
|
||||
def build_level_type(lvl_fmt: LevelFormat, properties: Sequence[LevelProperty] = [], n: int = 0, m: int = 0) -> int:
|
||||
"""Builds a sparse_tensor.encoding.level_type from parameters."""
|
||||
|
||||
@property
|
||||
def lvl_types(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def dim_to_lvl(self) -> ir.AffineMap | None: ...
|
||||
|
||||
@property
|
||||
def lvl_to_dim(self) -> ir.AffineMap | None: ...
|
||||
|
||||
@property
|
||||
def pos_width(self) -> int: ...
|
||||
|
||||
@property
|
||||
def crd_width(self) -> int: ...
|
||||
|
||||
@property
|
||||
def explicit_val(self) -> ir.Attribute | None: ...
|
||||
|
||||
@property
|
||||
def implicit_val(self) -> ir.Attribute | None: ...
|
||||
|
||||
@property
|
||||
def structured_n(self) -> int: ...
|
||||
|
||||
@property
|
||||
def structured_m(self) -> int: ...
|
||||
|
||||
@property
|
||||
def lvl_formats_enum(self) -> list[LevelFormat]: ...
|
||||
BIN
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,309 @@
|
||||
"""mlir-hlo main python extension"""
|
||||
|
||||
from typing import Self
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_mhlo_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
def register_mhlo_passes() -> None: ...
|
||||
|
||||
class TokenType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Token type."""
|
||||
|
||||
class ScatterDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, update_window_dims: Sequence[int], inserted_window_dims: Sequence[int], input_batching_dims: Sequence[int], scatter_indices_batching_dims: Sequence[int], scattered_dims_to_operand_dims: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a ScatterDimensionNumbers with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def update_window_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def inserted_window_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def input_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def scatter_indices_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def scattered_dims_to_operand_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def index_vector_dim(self) -> int: ...
|
||||
|
||||
class GatherDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, offset_dims: Sequence[int], collapsed_slice_dims: Sequence[int], operand_batching_dims: Sequence[int], start_indices_batching_dims: Sequence[int], start_index_map: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a GatherDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def offset_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def collapsed_slice_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def start_indices_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def start_index_map(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def index_vector_dim(self) -> int: ...
|
||||
|
||||
class DotDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a DotDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def lhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
class ConvDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, input_batch_dimension: int, input_feature_dimension: int, input_spatial_dimensions: Sequence[int], kernel_input_feature_dimension: int, kernel_output_feature_dimension: int, kernel_spatial_dimensions: Sequence[int], output_batch_dimension: int, output_feature_dimension: int, output_spatial_dimensions: Sequence[int], ctx: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a ConvDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_batch_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def input_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def input_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def kernel_input_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def kernel_output_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def kernel_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def output_batch_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def output_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def output_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
class OutputOperandAlias(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, output_tuple_indices: Sequence[int], operand_index: int, operand_tuple_indices: Sequence[int], ctx: ir.Context | None = None) -> Self:
|
||||
"""Creates a OutputOperandAlias attribute with the given tuple index."""
|
||||
|
||||
@property
|
||||
def output_tuple_indices(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_index(self) -> int: ...
|
||||
|
||||
@property
|
||||
def operand_tuple_indices(self) -> list[int]: ...
|
||||
|
||||
class ComparisonDirectionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonDirection attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ComparisonTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class PrecisionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Precision attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class FftTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a FftType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class DequantizeModeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a DequantizeMode attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class TransposeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Transpose attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class FusionKindAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a FusionKind attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RngDistributionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a RngDistribution attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RngAlgorithmAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a RngAlgorithm attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ChannelHandle(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, handle: int, type: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ChannelHandle attribute."""
|
||||
|
||||
@property
|
||||
def handle(self) -> int: ...
|
||||
|
||||
@property
|
||||
def channel_type(self) -> int: ...
|
||||
|
||||
class TypeExtensions(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, bounds: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a TypeExtensions with the given bounds."""
|
||||
|
||||
@property
|
||||
def bounds(self) -> list[int]: ...
|
||||
Binary file not shown.
BIN
Binary file not shown.
+295
@@ -0,0 +1,295 @@
|
||||
# Copyright 2026 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
import enum
|
||||
from mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
|
||||
def register_inliner_extensions(arg: ir.Context, /) -> None: ...
|
||||
|
||||
class BarrierType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
orders_tensor_core: bool = False, context: ir.Context | None = None
|
||||
) -> BarrierType:
|
||||
"""Creates a BarrierType."""
|
||||
|
||||
@property
|
||||
def orders_tensor_core(self) -> bool: ...
|
||||
|
||||
class TileTransformAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
tiling: Sequence[int], context: ir.Context | None = None
|
||||
) -> TileTransformAttr:
|
||||
"""Creates a TileTransformAttr with the given tiling."""
|
||||
|
||||
@property
|
||||
def tiling(self) -> ir.DenseI32ArrayAttr: ...
|
||||
|
||||
class TransposeTransformAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
permutation: Sequence[int], context: ir.Context | None = None
|
||||
) -> TransposeTransformAttr:
|
||||
"""Creates a TransposeTransformAttr with the given permutation."""
|
||||
|
||||
@property
|
||||
def permutation(self) -> ir.DenseI32ArrayAttr: ...
|
||||
|
||||
class SwizzleTransformAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
swizzle: int, context: ir.Context | None = None
|
||||
) -> SwizzleTransformAttr:
|
||||
"""Creates a SwizzleTransformAttr with the given swizzle."""
|
||||
|
||||
@property
|
||||
def swizzle(self) -> int: ...
|
||||
|
||||
class WGSplatFragLayoutAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
shape: ir.DenseI64ArrayAttr, context: ir.Context | None = None
|
||||
) -> WGSplatFragLayoutAttr:
|
||||
"""Creates a WGSplatFragLayoutAttr with the given shape."""
|
||||
|
||||
@property
|
||||
def shape(self) -> ir.DenseI64ArrayAttr: ...
|
||||
|
||||
class WGStridedFragLayoutAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
shape: ir.DenseI64ArrayAttr,
|
||||
vector_size: int,
|
||||
context: ir.Context | None = None,
|
||||
) -> WGStridedFragLayoutAttr:
|
||||
"""Creates a WGStridedFragLayoutAttr."""
|
||||
|
||||
@property
|
||||
def shape(self) -> ir.DenseI64ArrayAttr: ...
|
||||
@property
|
||||
def vector_size(self) -> int: ...
|
||||
|
||||
class ReplicatedAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(times: int, context: ir.Context | None = None) -> ReplicatedAttr:
|
||||
"""Creates a ReplicatedAttr."""
|
||||
|
||||
@property
|
||||
def times(self) -> int: ...
|
||||
|
||||
class TiledLayoutAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
tiling: ir.ArrayAttr,
|
||||
warp_dims: ir.ArrayAttr,
|
||||
lane_dims: ir.ArrayAttr,
|
||||
vector_dim: int,
|
||||
context: ir.Context | None = None,
|
||||
) -> TiledLayoutAttr:
|
||||
"""Creates a TiledLayoutAttr."""
|
||||
|
||||
@property
|
||||
def tiling(self) -> ir.ArrayAttr: ...
|
||||
@property
|
||||
def warp_dims(self) -> ir.ArrayAttr: ...
|
||||
@property
|
||||
def lane_dims(self) -> ir.ArrayAttr: ...
|
||||
@property
|
||||
def vector_dim(self) -> int: ...
|
||||
|
||||
class CopyPartitionAttrInterface(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
class CopyReplicatedAttr(CopyPartitionAttrInterface):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(context: ir.Context | None = None) -> CopyReplicatedAttr:
|
||||
"""Creates a CopyReplicatedAttr."""
|
||||
|
||||
class CopyPartitionedAttr(CopyPartitionAttrInterface):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(axis: int, context: ir.Context | None = None) -> CopyPartitionedAttr:
|
||||
"""Creates a CopyPartitionedAttr."""
|
||||
|
||||
@property
|
||||
def axis(self) -> int: ...
|
||||
|
||||
class Tiling:
|
||||
def __init__(self, tiles: Iterable) -> None: ...
|
||||
def tile_shape(self, shape: Sequence[int]) -> tuple: ...
|
||||
def untile_shape(self, shape: Sequence[int]) -> tuple: ...
|
||||
def tile_strides(self, strides: Sequence[int]) -> tuple: ...
|
||||
def tile_indices(self, indices: Sequence[int]) -> tuple: ...
|
||||
def untile_indices(self, indices: Sequence[int]) -> tuple: ...
|
||||
def tile_nested_shape_strides(
|
||||
self, shape: Sequence[Sequence[int]], strides: Sequence[Sequence[int]]
|
||||
) -> tuple: ...
|
||||
def tile_dimension(self, dim: int) -> tuple: ...
|
||||
def remove_dimension(self, dim: int) -> Tiling: ...
|
||||
def canonicalize(self) -> Tiling: ...
|
||||
@property
|
||||
def tiles(self) -> tuple: ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __eq__(self, other: object) -> bool: ...
|
||||
def __hash__(self) -> int: ...
|
||||
|
||||
class Replicated:
|
||||
def __init__(self, times: int) -> None: ...
|
||||
@property
|
||||
def times(self) -> int: ...
|
||||
@times.setter
|
||||
def times(self, arg: int, /) -> None: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __eq__(self, arg: object, /) -> bool: ...
|
||||
|
||||
class TiledLayout:
|
||||
def __init__(
|
||||
self,
|
||||
tiling: Tiling,
|
||||
warp_dims: Iterable,
|
||||
lane_dims: Iterable,
|
||||
vector_dim: int,
|
||||
_check_canonical: bool = ...,
|
||||
) -> None: ...
|
||||
@property
|
||||
def warp_dims(self) -> tuple: ...
|
||||
@property
|
||||
def lane_dims(self) -> tuple: ...
|
||||
@property
|
||||
def partitioned_warp_dims(self) -> tuple: ...
|
||||
@property
|
||||
def partitioned_lane_dims(self) -> tuple: ...
|
||||
@property
|
||||
def vector_length(self) -> int: ...
|
||||
@property
|
||||
def vector_dim(self) -> int: ...
|
||||
@property
|
||||
def tiling(self) -> Tiling: ...
|
||||
@property
|
||||
def tiled_tiling_shape(self) -> tuple: ...
|
||||
@property
|
||||
def tiled_tiling_rank(self) -> int: ...
|
||||
def warp_indices(self) -> tuple: ...
|
||||
def lane_indices(self) -> tuple: ...
|
||||
def canonicalize(self) -> TiledLayout: ...
|
||||
def registers_shape(self, shape: Sequence[int]) -> tuple: ...
|
||||
def registers_element_type(self, t: ir.Type) -> ir.Type: ...
|
||||
def shape_from_registers_shape(self, shape: Sequence[int]) -> tuple: ...
|
||||
@property
|
||||
def base_tile_shape(self) -> tuple: ...
|
||||
def remove_dimension(self, dim: int) -> TiledLayout: ...
|
||||
def reduce(self, axes: Iterable) -> TiledLayout: ...
|
||||
def thread_idxs(self, arg: Sequence[int], /) -> list: ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __eq__(self, other: object | None) -> bool: ...
|
||||
|
||||
class Rounding(enum.Enum):
|
||||
UP = 0
|
||||
|
||||
DOWN = 1
|
||||
|
||||
class TileTransform:
|
||||
def __init__(
|
||||
self, tiling: Sequence[int], rounding: Rounding | None = ...
|
||||
) -> None: ...
|
||||
def apply(self, arg: object, /) -> ir.Value: ...
|
||||
def transform_index(self, arg: Iterable, /) -> tuple: ...
|
||||
def transform_shape(self, arg: Sequence[int], /) -> tuple: ...
|
||||
def transform_strides(self, arg: Sequence[int], /) -> tuple: ...
|
||||
|
||||
class TrivialTransferPlan:
|
||||
def __init__(self) -> None: ...
|
||||
@property
|
||||
def tile_index_transforms(self) -> tuple: ...
|
||||
def select(self, arg: Iterable, /) -> ir.Value: ...
|
||||
def select_if_group(
|
||||
self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
|
||||
) -> ir.Value: ...
|
||||
|
||||
class StaggeredTransferPlan:
|
||||
def __init__(
|
||||
self, stagger: int, dim: int, size: int, group_pred: object
|
||||
) -> None: ...
|
||||
@property
|
||||
def stagger(self) -> int: ...
|
||||
@property
|
||||
def dim(self) -> int: ...
|
||||
@property
|
||||
def size(self) -> int: ...
|
||||
@property
|
||||
def group_pred(self) -> ir.Value: ...
|
||||
@property
|
||||
def tile_index_transforms(self) -> tuple: ...
|
||||
def select(self, arg: Iterable, /) -> ir.Value: ...
|
||||
def select_if_group(
|
||||
self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
|
||||
) -> ir.Value: ...
|
||||
BIN
Binary file not shown.
@@ -0,0 +1,210 @@
|
||||
"""SDY main Python extension"""
|
||||
|
||||
from typing import Self
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
class MeshAxisAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str, size: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a MeshAxisAttr with the given axis name and size."""
|
||||
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def size(self) -> int: ...
|
||||
|
||||
class MeshAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, mesh_axes: Sequence[ir.Attribute], device_ids: Sequence[int] = [], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a MeshAttr with the given mesh axes."""
|
||||
|
||||
@property
|
||||
def device_ids(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def axes(self) -> list[ir.Attribute]: ...
|
||||
|
||||
class SubAxisInfoAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, pre_size: int, size: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a SubAxisInfoAttr with the given pre-size and size."""
|
||||
|
||||
@property
|
||||
def pre_size(self) -> int: ...
|
||||
|
||||
@property
|
||||
def size(self) -> int: ...
|
||||
|
||||
class AxisRefAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str, sub_axis_info: ir.Attribute | None = None, context: ir.Context | None = None) -> Self:
|
||||
"""Creates an AxisRefAttr with the given name and SubAxisInfoAttr."""
|
||||
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def sub_axis_info(self) -> ir.Attribute | None: ...
|
||||
|
||||
class DimensionShardingAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, axes: Sequence[ir.Attribute], is_closed: bool, priority: int | None = None, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a DimensionShardingAttr with the given axes, whether it's closed, and priority.
|
||||
"""
|
||||
|
||||
@property
|
||||
def axes(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool: ...
|
||||
|
||||
@property
|
||||
def priority(self) -> int | None: ...
|
||||
|
||||
class TensorShardingAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, mesh_or_ref: str | ir.Attribute, dimension_shardings: Sequence[ir.Attribute], replicated_axes: Sequence[ir.Attribute] = [], unreduced_axes: Sequence[ir.Attribute] = [], context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a TensorShardingAttr with either an inlined mesh or mesh name, dimension shardings, and replicated axes.
|
||||
"""
|
||||
|
||||
@property
|
||||
def mesh_or_ref(self) -> ir.Attribute: ...
|
||||
|
||||
@property
|
||||
def dimension_shardings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def replicated_axes(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def unreduced_axes(self) -> list[ir.Attribute]: ...
|
||||
|
||||
class TensorShardingPerValueAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, shardings: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a TensorShardingPerValueAttr with the tensor shardings."""
|
||||
|
||||
@property
|
||||
def shardings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
class DimMappingAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, factor_indices: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a DimMappingAttr with the factor indices."""
|
||||
|
||||
@property
|
||||
def factor_indices(self) -> list[int]: ...
|
||||
|
||||
class TensorMappingAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, dim_mappings: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a TensorMappingAttr with the dim mappings."""
|
||||
|
||||
@property
|
||||
def dim_mappings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def rank(self) -> int: ...
|
||||
|
||||
class OpShardingRuleAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, factor_sizes: Sequence[int], operand_mappings: Sequence[ir.Attribute], result_mappings: Sequence[ir.Attribute], reduction_factors: Sequence[int] = [], need_replication_factors: Sequence[int] = [], permutation_factors: Sequence[int] = [], blocked_propagation_factors: Sequence[int] = [], is_custom: bool = False, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a OpShardingRuleAttr with the factor sizes and mappings for operands and results.
|
||||
"""
|
||||
|
||||
@property
|
||||
def is_custom(self) -> bool: ...
|
||||
|
||||
@property
|
||||
def factor_sizes(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_mappings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def result_mappings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def reduction_factors(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def need_replication_factors(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def permutation_factors(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def blocked_propagation_factors(self) -> list[int]: ...
|
||||
|
||||
class ManualAxesAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, manual_axes: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ManualAxesAttr with the given manual axes."""
|
||||
|
||||
def __getitem__(self, arg: int, /) -> str: ...
|
||||
|
||||
def __len__(self) -> int: ...
|
||||
Binary file not shown.
@@ -0,0 +1,23 @@
|
||||
"""MPMD main Python extension"""
|
||||
|
||||
from typing import Self
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
class UserOriginAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str, transpose_count: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a UserOriginAttr with the given user name and transpose count."""
|
||||
|
||||
@property
|
||||
def user_name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def transpose_count(self) -> int: ...
|
||||
Binary file not shown.
@@ -0,0 +1,407 @@
|
||||
"""stablehlo main python extension"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import enum
|
||||
from typing import overload, Self
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
def register_interpreter_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
def register_stablehlo_passes() -> None: ...
|
||||
|
||||
class TokenType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Token type."""
|
||||
|
||||
class FutureType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, types: Sequence[ir.Type], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Future type."""
|
||||
|
||||
@property
|
||||
def types(self) -> list[ir.Type]: ...
|
||||
|
||||
class ScatterDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, update_window_dims: Sequence[int], inserted_window_dims: Sequence[int], input_batching_dims: Sequence[int], scatter_indices_batching_dims: Sequence[int], scattered_dims_to_operand_dims: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a ScatterDimensionNumbers with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def update_window_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def inserted_window_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def input_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def scatter_indices_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def scattered_dims_to_operand_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def index_vector_dim(self) -> int: ...
|
||||
|
||||
class GatherDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, offset_dims: Sequence[int], collapsed_slice_dims: Sequence[int], operand_batching_dims: Sequence[int], start_indices_batching_dims: Sequence[int], start_index_map: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a GatherDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def offset_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def collapsed_slice_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def start_indices_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def start_index_map(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def index_vector_dim(self) -> int: ...
|
||||
|
||||
class DotAlgorithm(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, lhs_precision_type: ir.Type, rhs_precision_type: ir.Type, accumulation_type: ir.Type, lhs_component_count: int, rhs_component_count: int, num_primitive_operations: int, allow_imprecise_accumulation: bool, ctx: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a DotAlgorithm attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lhs_precision_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def rhs_precision_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def accumulation_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def lhs_component_count(self) -> int: ...
|
||||
|
||||
@property
|
||||
def rhs_component_count(self) -> int: ...
|
||||
|
||||
@property
|
||||
def num_primitive_operations(self) -> int: ...
|
||||
|
||||
@property
|
||||
def allow_imprecise_accumulation(self) -> bool: ...
|
||||
|
||||
class DotDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a DotDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def lhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
class ConvDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, input_batch_dimension: int, input_feature_dimension: int, input_spatial_dimensions: Sequence[int], kernel_input_feature_dimension: int, kernel_output_feature_dimension: int, kernel_spatial_dimensions: Sequence[int], output_batch_dimension: int, output_feature_dimension: int, output_spatial_dimensions: Sequence[int], ctx: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a ConvDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_batch_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def input_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def input_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def kernel_input_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def kernel_output_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def kernel_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def output_batch_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def output_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def output_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
class OutputOperandAlias(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, output_tuple_indices: Sequence[int], operand_index: int, operand_tuple_indices: Sequence[int], ctx: ir.Context | None = None) -> Self:
|
||||
"""Creates a OutputOperandAlias attribute with the given tuple index."""
|
||||
|
||||
@property
|
||||
def output_tuple_indices(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_index(self) -> int: ...
|
||||
|
||||
@property
|
||||
def operand_tuple_indices(self) -> list[int]: ...
|
||||
|
||||
class ComparisonDirectionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonDirection attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ComparisonTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class PrecisionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Precision attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class FftTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a FftType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class TransposeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Transpose attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RngDistributionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a RngDistribution attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RngAlgorithmAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a RngAlgorithm attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ChannelHandle(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, handle: int, type: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ChannelHandle attribute."""
|
||||
|
||||
@property
|
||||
def handle(self) -> int: ...
|
||||
|
||||
@property
|
||||
def channel_type(self) -> int: ...
|
||||
|
||||
class TypeExtensions(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, bounds: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a TypeExtensions with the given bounds."""
|
||||
|
||||
@property
|
||||
def bounds(self) -> list[int]: ...
|
||||
|
||||
class ResultAccuracyAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, atol: float, rtol: float, ulps: int, mode: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ResultAccuracyAttr with the given values."""
|
||||
|
||||
@property
|
||||
def atol(self) -> float: ...
|
||||
|
||||
@property
|
||||
def rtol(self) -> float: ...
|
||||
|
||||
@property
|
||||
def ulps(self) -> int: ...
|
||||
|
||||
@property
|
||||
def mode(self) -> str: ...
|
||||
|
||||
class ResultAccuracyModeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ResultAccuracyModeAttr with the given values."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
def get_api_version() -> int: ...
|
||||
|
||||
def get_smaller_version(version1: str, version2: str) -> str: ...
|
||||
|
||||
def get_current_version() -> str: ...
|
||||
|
||||
def get_minimum_version() -> str: ...
|
||||
|
||||
@overload
|
||||
def serialize_portable_artifact_str(module_str: str, target_version: str) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def serialize_portable_artifact_str(module_str: bytes, target_version: str) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def deserialize_portable_artifact_str(artifact_str: str) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def deserialize_portable_artifact_str(artifact_str: bytes) -> bytes: ...
|
||||
|
||||
class StablehloCompatibilityRequirement(enum.Enum):
|
||||
NONE = 0
|
||||
|
||||
WEEK_4 = 1
|
||||
|
||||
WEEK_12 = 2
|
||||
|
||||
MAX = 3
|
||||
|
||||
def get_version_from_compatibility_requirement(requirement: StablehloCompatibilityRequirement) -> str: ...
|
||||
|
||||
def serialize_portable_artifact(module: ir.Module, target: str, allow_other_dialects: bool = False) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def deserialize_portable_artifact(context: ir.Context, artifact: str) -> ir.Module: ...
|
||||
|
||||
@overload
|
||||
def deserialize_portable_artifact(context: ir.Context, artifact: bytes) -> ir.Module: ...
|
||||
|
||||
def eval_module(module: ir.Module, args: Sequence[ir.Attribute], probe_instrumentation_dir: str = '') -> list[ir.Attribute]: ...
|
||||
Binary file not shown.
@@ -0,0 +1,32 @@
|
||||
# Copyright 2026 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
|
||||
def private_has_communication(arg: ir.Operation, /) -> tuple[bool, bool]: ...
|
||||
def private_set_arg_attr(
|
||||
arg0: ir.Operation, arg1: int, arg2: str, arg3: ir.Attribute, /
|
||||
) -> None: ...
|
||||
|
||||
class Float8EXMYType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
exmy_type: ir.Type | None = None, ctx: ir.Context | None = None
|
||||
) -> Float8EXMYType: ...
|
||||
@property
|
||||
def underlying_type(self) -> ir.Type: ...
|
||||
Binary file not shown.
@@ -0,0 +1,36 @@
|
||||
# Copyright 2026 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
|
||||
|
||||
class PointerType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(pointee_type: ir.Type, address_space: int) -> PointerType:
|
||||
"""Creates a PointerType type."""
|
||||
|
||||
@property
|
||||
def pointee_type(self) -> ir.Type: ...
|
||||
@property
|
||||
def address_space(self) -> int: ...
|
||||
|
||||
def infer_reduce_op_encoding(
|
||||
arg0: ir.Attribute, arg1: int, /
|
||||
) -> ir.Attribute | None: ...
|
||||
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,285 @@
|
||||
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from enum import IntEnum, auto, IntFlag
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ..ir import register_attribute_builder
|
||||
_ods_ir = _ods_cext.ir
|
||||
|
||||
class CmpFPredicate(IntEnum):
|
||||
"""allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15"""
|
||||
|
||||
AlwaysFalse = 0
|
||||
OEQ = 1
|
||||
OGT = 2
|
||||
OGE = 3
|
||||
OLT = 4
|
||||
OLE = 5
|
||||
ONE = 6
|
||||
ORD = 7
|
||||
UEQ = 8
|
||||
UGT = 9
|
||||
UGE = 10
|
||||
ULT = 11
|
||||
ULE = 12
|
||||
UNE = 13
|
||||
UNO = 14
|
||||
AlwaysTrue = 15
|
||||
|
||||
def __str__(self):
|
||||
if self is CmpFPredicate.AlwaysFalse:
|
||||
return "false"
|
||||
if self is CmpFPredicate.OEQ:
|
||||
return "oeq"
|
||||
if self is CmpFPredicate.OGT:
|
||||
return "ogt"
|
||||
if self is CmpFPredicate.OGE:
|
||||
return "oge"
|
||||
if self is CmpFPredicate.OLT:
|
||||
return "olt"
|
||||
if self is CmpFPredicate.OLE:
|
||||
return "ole"
|
||||
if self is CmpFPredicate.ONE:
|
||||
return "one"
|
||||
if self is CmpFPredicate.ORD:
|
||||
return "ord"
|
||||
if self is CmpFPredicate.UEQ:
|
||||
return "ueq"
|
||||
if self is CmpFPredicate.UGT:
|
||||
return "ugt"
|
||||
if self is CmpFPredicate.UGE:
|
||||
return "uge"
|
||||
if self is CmpFPredicate.ULT:
|
||||
return "ult"
|
||||
if self is CmpFPredicate.ULE:
|
||||
return "ule"
|
||||
if self is CmpFPredicate.UNE:
|
||||
return "une"
|
||||
if self is CmpFPredicate.UNO:
|
||||
return "uno"
|
||||
if self is CmpFPredicate.AlwaysTrue:
|
||||
return "true"
|
||||
raise ValueError("Unknown CmpFPredicate enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("Arith_CmpFPredicateAttr", allow_existing=True)
|
||||
def _arith_cmpfpredicateattr(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
|
||||
|
||||
class CmpIPredicate(IntEnum):
|
||||
"""allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9"""
|
||||
|
||||
eq = 0
|
||||
ne = 1
|
||||
slt = 2
|
||||
sle = 3
|
||||
sgt = 4
|
||||
sge = 5
|
||||
ult = 6
|
||||
ule = 7
|
||||
ugt = 8
|
||||
uge = 9
|
||||
|
||||
def __str__(self):
|
||||
if self is CmpIPredicate.eq:
|
||||
return "eq"
|
||||
if self is CmpIPredicate.ne:
|
||||
return "ne"
|
||||
if self is CmpIPredicate.slt:
|
||||
return "slt"
|
||||
if self is CmpIPredicate.sle:
|
||||
return "sle"
|
||||
if self is CmpIPredicate.sgt:
|
||||
return "sgt"
|
||||
if self is CmpIPredicate.sge:
|
||||
return "sge"
|
||||
if self is CmpIPredicate.ult:
|
||||
return "ult"
|
||||
if self is CmpIPredicate.ule:
|
||||
return "ule"
|
||||
if self is CmpIPredicate.ugt:
|
||||
return "ugt"
|
||||
if self is CmpIPredicate.uge:
|
||||
return "uge"
|
||||
raise ValueError("Unknown CmpIPredicate enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("Arith_CmpIPredicateAttr", allow_existing=True)
|
||||
def _arith_cmpipredicateattr(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
|
||||
|
||||
class IntegerOverflowFlags(IntFlag):
|
||||
"""Integer overflow arith flags"""
|
||||
|
||||
none = 0
|
||||
nsw = 1
|
||||
nuw = 2
|
||||
|
||||
def __iter__(self):
|
||||
return iter([case for case in type(self) if (self & case) is case and self is not case])
|
||||
def __len__(self):
|
||||
return bin(self).count("1")
|
||||
|
||||
def __str__(self):
|
||||
if len(self) > 1:
|
||||
return ", ".join(map(str, self))
|
||||
if self is IntegerOverflowFlags.none:
|
||||
return "none"
|
||||
if self is IntegerOverflowFlags.nsw:
|
||||
return "nsw"
|
||||
if self is IntegerOverflowFlags.nuw:
|
||||
return "nuw"
|
||||
raise ValueError("Unknown IntegerOverflowFlags enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("Arith_IntegerOverflowFlags", allow_existing=True)
|
||||
def _arith_integeroverflowflags(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class RoundingMode(IntEnum):
|
||||
"""Floating point rounding mode"""
|
||||
|
||||
to_nearest_even = 0
|
||||
downward = 1
|
||||
upward = 2
|
||||
toward_zero = 3
|
||||
to_nearest_away = 4
|
||||
|
||||
def __str__(self):
|
||||
if self is RoundingMode.to_nearest_even:
|
||||
return "to_nearest_even"
|
||||
if self is RoundingMode.downward:
|
||||
return "downward"
|
||||
if self is RoundingMode.upward:
|
||||
return "upward"
|
||||
if self is RoundingMode.toward_zero:
|
||||
return "toward_zero"
|
||||
if self is RoundingMode.to_nearest_away:
|
||||
return "to_nearest_away"
|
||||
raise ValueError("Unknown RoundingMode enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("Arith_RoundingModeAttr", allow_existing=True)
|
||||
def _arith_roundingmodeattr(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class AtomicRMWKind(IntEnum):
|
||||
"""allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15"""
|
||||
|
||||
addf = 0
|
||||
addi = 1
|
||||
andi = 2
|
||||
assign = 3
|
||||
maximumf = 4
|
||||
maxnumf = 5
|
||||
maxs = 6
|
||||
maxu = 7
|
||||
minimumf = 8
|
||||
minnumf = 9
|
||||
mins = 10
|
||||
minu = 11
|
||||
mulf = 12
|
||||
muli = 13
|
||||
ori = 14
|
||||
xori = 15
|
||||
|
||||
def __str__(self):
|
||||
if self is AtomicRMWKind.addf:
|
||||
return "addf"
|
||||
if self is AtomicRMWKind.addi:
|
||||
return "addi"
|
||||
if self is AtomicRMWKind.andi:
|
||||
return "andi"
|
||||
if self is AtomicRMWKind.assign:
|
||||
return "assign"
|
||||
if self is AtomicRMWKind.maximumf:
|
||||
return "maximumf"
|
||||
if self is AtomicRMWKind.maxnumf:
|
||||
return "maxnumf"
|
||||
if self is AtomicRMWKind.maxs:
|
||||
return "maxs"
|
||||
if self is AtomicRMWKind.maxu:
|
||||
return "maxu"
|
||||
if self is AtomicRMWKind.minimumf:
|
||||
return "minimumf"
|
||||
if self is AtomicRMWKind.minnumf:
|
||||
return "minnumf"
|
||||
if self is AtomicRMWKind.mins:
|
||||
return "mins"
|
||||
if self is AtomicRMWKind.minu:
|
||||
return "minu"
|
||||
if self is AtomicRMWKind.mulf:
|
||||
return "mulf"
|
||||
if self is AtomicRMWKind.muli:
|
||||
return "muli"
|
||||
if self is AtomicRMWKind.ori:
|
||||
return "ori"
|
||||
if self is AtomicRMWKind.xori:
|
||||
return "xori"
|
||||
raise ValueError("Unknown AtomicRMWKind enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("AtomicRMWKindAttr", allow_existing=True)
|
||||
def _atomicrmwkindattr(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
|
||||
|
||||
class FastMathFlags(IntFlag):
|
||||
"""Floating point fast math flags"""
|
||||
|
||||
none = 0
|
||||
reassoc = 1
|
||||
nnan = 2
|
||||
ninf = 4
|
||||
nsz = 8
|
||||
arcp = 16
|
||||
contract = 32
|
||||
afn = 64
|
||||
fast = 127
|
||||
|
||||
def __iter__(self):
|
||||
return iter([case for case in type(self) if (self & case) is case and self is not case])
|
||||
def __len__(self):
|
||||
return bin(self).count("1")
|
||||
|
||||
def __str__(self):
|
||||
if len(self) > 1:
|
||||
return ",".join(map(str, self))
|
||||
if self is FastMathFlags.none:
|
||||
return "none"
|
||||
if self is FastMathFlags.reassoc:
|
||||
return "reassoc"
|
||||
if self is FastMathFlags.nnan:
|
||||
return "nnan"
|
||||
if self is FastMathFlags.ninf:
|
||||
return "ninf"
|
||||
if self is FastMathFlags.nsz:
|
||||
return "nsz"
|
||||
if self is FastMathFlags.arcp:
|
||||
return "arcp"
|
||||
if self is FastMathFlags.contract:
|
||||
return "contract"
|
||||
if self is FastMathFlags.afn:
|
||||
return "afn"
|
||||
if self is FastMathFlags.fast:
|
||||
return "fast"
|
||||
raise ValueError("Unknown FastMathFlags enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("FastMathFlags", allow_existing=True)
|
||||
def _fastmathflags(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
@register_attribute_builder("arith.Arith_FastMathAttr")
|
||||
def _arith_fastmathattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#arith.fastmath<{str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("arith.Arith_IntegerOverflowAttr")
|
||||
def _arith_integeroverflowattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#arith.overflow<{str(x)}>', context=context)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+199
@@ -0,0 +1,199 @@
|
||||
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ._ods_common import (
|
||||
equally_sized_accessor as _ods_equally_sized_accessor,
|
||||
get_default_loc_context as _ods_get_default_loc_context,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
segmented_accessor as _ods_segmented_accessor,
|
||||
)
|
||||
_ods_ir = _ods_cext.ir
|
||||
_ods_cext.globals.register_traceback_file_exclusion(__file__)
|
||||
|
||||
import builtins
|
||||
from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
|
||||
import sys as _sys
|
||||
if _sys.version_info >= (3, 12):
|
||||
from collections.abc import Buffer as _Buffer # pytype: disable=not-supported-yet
|
||||
else:
|
||||
try:
|
||||
from typing_extensions import Buffer as _Buffer
|
||||
except ImportError:
|
||||
_Buffer = _Any
|
||||
|
||||
|
||||
@_ods_cext.register_dialect
|
||||
class _Dialect(_ods_ir.Dialect):
|
||||
DIALECT_NAMESPACE = "builtin"
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class ModuleOp(_ods_ir.OpView):
|
||||
r"""
|
||||
A `module` represents a top-level container operation. It contains a single
|
||||
[graph region](../LangRef.md#control-flow-and-ssacfg-regions) containing a single block
|
||||
which can contain any operations and does not have a terminator. Operations
|
||||
within this region cannot implicitly capture values defined outside the module,
|
||||
i.e. Modules are [IsolatedFromAbove](../Traits#isolatedfromabove). Modules have
|
||||
an optional [symbol name](../SymbolsAndSymbolTables.md) which can be used to refer
|
||||
to them in operations.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
module {
|
||||
func.func @foo()
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "builtin.module"
|
||||
|
||||
_ODS_REGIONS = (1, True)
|
||||
|
||||
def __init__(self, *, sym_name: _Optional[_Union[str, _ods_ir.StringAttr]] = None, sym_visibility: _Optional[_Union[str, _ods_ir.StringAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
if sym_name is not None: attributes["sym_name"] = (sym_name if (
|
||||
isinstance(sym_name, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('SymbolNameAttr')) else
|
||||
_ods_ir.AttrBuilder.get('SymbolNameAttr')(sym_name, context=_ods_context))
|
||||
if sym_visibility is not None: attributes["sym_visibility"] = (sym_visibility if (
|
||||
isinstance(sym_visibility, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('StrAttr')) else
|
||||
_ods_ir.AttrBuilder.get('StrAttr')(sym_visibility, context=_ods_context))
|
||||
results = []
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def sym_name(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "sym_name" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["sym_name"]
|
||||
|
||||
@sym_name.setter
|
||||
def sym_name(self, value: _Optional[_ods_ir.StringAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["sym_name"] = value
|
||||
elif "sym_name" in self.operation.attributes:
|
||||
del self.operation.attributes["sym_name"]
|
||||
|
||||
@sym_name.deleter
|
||||
def sym_name(self):
|
||||
del self.operation.attributes["sym_name"]
|
||||
|
||||
@builtins.property
|
||||
def sym_visibility(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "sym_visibility" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["sym_visibility"]
|
||||
|
||||
@sym_visibility.setter
|
||||
def sym_visibility(self, value: _Optional[_ods_ir.StringAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["sym_visibility"] = value
|
||||
elif "sym_visibility" in self.operation.attributes:
|
||||
del self.operation.attributes["sym_visibility"]
|
||||
|
||||
@sym_visibility.deleter
|
||||
def sym_visibility(self):
|
||||
del self.operation.attributes["sym_visibility"]
|
||||
|
||||
@builtins.property
|
||||
def bodyRegion(self) -> _ods_ir.Region:
|
||||
return self.regions[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(ModuleOp)
|
||||
class ModuleOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "builtin.module"
|
||||
|
||||
@builtins.property
|
||||
def sym_name(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "sym_name" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["sym_name"]
|
||||
|
||||
@builtins.property
|
||||
def sym_visibility(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "sym_visibility" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
def module(*, sym_name: _Optional[_Union[str, _ods_ir.StringAttr]] = None, sym_visibility: _Optional[_Union[str, _ods_ir.StringAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> ModuleOp:
|
||||
return ModuleOp(sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class UnrealizedConversionCastOp(_ods_ir.OpView):
|
||||
r"""
|
||||
An `unrealized_conversion_cast` operation represents an unrealized
|
||||
conversion from one set of types to another, that is used to enable the
|
||||
inter-mixing of different type systems. This operation should not be
|
||||
attributed any special representational or execution semantics, and is
|
||||
generally only intended to be used to satisfy the temporary intermixing of
|
||||
type systems during the conversion of one type system to another.
|
||||
|
||||
This operation may produce results of arity 1-N, and accept as input
|
||||
operands of arity 0-N.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// An unrealized 0-1 conversion. These types of conversions are useful in
|
||||
// cases where a type is removed from the type system, but not all uses have
|
||||
// been converted. For example, imagine we have a tuple type that is
|
||||
// expanded to its element types. If only some uses of an empty tuple type
|
||||
// instance are converted we still need an instance of the tuple type, but
|
||||
// have no inputs to the unrealized conversion.
|
||||
%result = unrealized_conversion_cast to !bar.tuple_type<>
|
||||
|
||||
// An unrealized 1-1 conversion.
|
||||
%result1 = unrealized_conversion_cast %operand : !foo.type to !bar.lowered_type
|
||||
|
||||
// An unrealized 1-N conversion.
|
||||
%results2:2 = unrealized_conversion_cast %tuple_operand : !foo.tuple_type<!foo.type, !foo.type> to !foo.type, !foo.type
|
||||
|
||||
// An unrealized N-1 conversion.
|
||||
%result3 = unrealized_conversion_cast %operand, %operand : !foo.type, !foo.type to !bar.tuple_type<!foo.type, !foo.type>
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "builtin.unrealized_conversion_cast"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, outputs: _Sequence[_ods_ir.Type], inputs: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(inputs))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
results = []
|
||||
results.extend(outputs)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def inputs(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def outputs(self) -> _ods_ir.OpResultList:
|
||||
_ods_variadic_group_length = len(self.operation.results) - 1 + 1
|
||||
return self.operation.results[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@_ods_cext.register_op_adaptor(UnrealizedConversionCastOp)
|
||||
class UnrealizedConversionCastOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "builtin.unrealized_conversion_cast"
|
||||
|
||||
@builtins.property
|
||||
def inputs(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
def unrealized_conversion_cast(outputs: _Sequence[_ods_ir.Type], inputs: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, UnrealizedConversionCastOp]:
|
||||
op = UnrealizedConversionCastOp(outputs=outputs, inputs=inputs, loc=loc, ip=ip); results = op.results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
@@ -0,0 +1,400 @@
|
||||
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ._ods_common import (
|
||||
equally_sized_accessor as _ods_equally_sized_accessor,
|
||||
get_default_loc_context as _ods_get_default_loc_context,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
segmented_accessor as _ods_segmented_accessor,
|
||||
)
|
||||
_ods_ir = _ods_cext.ir
|
||||
_ods_cext.globals.register_traceback_file_exclusion(__file__)
|
||||
|
||||
import builtins
|
||||
from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
|
||||
import sys as _sys
|
||||
if _sys.version_info >= (3, 12):
|
||||
from collections.abc import Buffer as _Buffer # pytype: disable=not-supported-yet
|
||||
else:
|
||||
try:
|
||||
from typing_extensions import Buffer as _Buffer
|
||||
except ImportError:
|
||||
_Buffer = _Any
|
||||
|
||||
|
||||
@_ods_cext.register_dialect
|
||||
class _Dialect(_ods_ir.Dialect):
|
||||
DIALECT_NAMESPACE = "cf"
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class AssertOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Assert operation at runtime with single boolean operand and an error
|
||||
message attribute.
|
||||
If the argument is `true` this operation has no effect. Otherwise, the
|
||||
program execution will abort. The provided error message may be used by a
|
||||
runtime to propagate the error to the user.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
cf.assert %b, "Expected ... to be true"
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "cf.assert"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, arg: _ods_ir.Value[_ods_ir.IntegerType], msg: _Union[str, _ods_ir.StringAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.append(arg)
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["msg"] = (msg if (
|
||||
isinstance(msg, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('StrAttr')) else
|
||||
_ods_ir.AttrBuilder.get('StrAttr')(msg, context=_ods_context))
|
||||
results = []
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def arg(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
|
||||
return self.operation.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def msg(self) -> _ods_ir.StringAttr:
|
||||
return self.operation.attributes["msg"]
|
||||
|
||||
@msg.setter
|
||||
def msg(self, value: _ods_ir.StringAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["msg"] = value
|
||||
|
||||
@_ods_cext.register_op_adaptor(AssertOp)
|
||||
class AssertOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "cf.assert"
|
||||
|
||||
@builtins.property
|
||||
def arg(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
|
||||
return self.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def msg(self) -> _ods_ir.StringAttr:
|
||||
return self.attributes["msg"]
|
||||
|
||||
def assert_(arg: _ods_ir.Value[_ods_ir.IntegerType], msg: _Union[str, _ods_ir.StringAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> AssertOp:
|
||||
return AssertOp(arg=arg, msg=msg, loc=loc, ip=ip)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class BranchOp(_ods_ir.OpView):
|
||||
r"""
|
||||
The `cf.br` operation represents a direct branch operation to a given
|
||||
block. The operands of this operation are forwarded to the successor block,
|
||||
and the number and type of the operands must match the arguments of the
|
||||
target block.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
^bb2:
|
||||
%2 = call @someFn()
|
||||
cf.br ^bb3(%2 : tensor<*xf32>)
|
||||
^bb3(%3: tensor<*xf32>):
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "cf.br"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, destOperands: _Sequence[_ods_ir.Value], dest: _ods_ir.Block, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(destOperands))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
results = []
|
||||
_ods_successors = []
|
||||
_ods_successors.append(dest)
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def destOperands(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@_ods_cext.register_op_adaptor(BranchOp)
|
||||
class BranchOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "cf.br"
|
||||
|
||||
@builtins.property
|
||||
def destOperands(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
def br(dest_operands: _Sequence[_ods_ir.Value], dest: _ods_ir.Block, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> BranchOp:
|
||||
return BranchOp(destOperands=dest_operands, dest=dest, loc=loc, ip=ip)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class CondBranchOp(_ods_ir.OpView):
|
||||
r"""
|
||||
The `cf.cond_br` terminator operation represents a conditional branch on a
|
||||
boolean (1-bit integer) value. If the bit is set, then the first destination
|
||||
is jumped to; if it is false, the second destination is chosen. The count
|
||||
and types of operands must align with the arguments in the corresponding
|
||||
target blocks.
|
||||
|
||||
The MLIR conditional branch operation is not allowed to target the entry
|
||||
block for a region. The two destinations of the conditional branch operation
|
||||
are allowed to be the same.
|
||||
|
||||
The following example illustrates a function with a conditional branch
|
||||
operation that targets the same block.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
func.func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
|
||||
// Both targets are the same, operands differ
|
||||
cf.cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)
|
||||
|
||||
^bb1(%x : i32) :
|
||||
return %x : i32
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "cf.cond_br"
|
||||
|
||||
_ODS_OPERAND_SEGMENTS = [1,-1,-1,]
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, condition: _ods_ir.Value[_ods_ir.IntegerType], trueDestOperands: _Sequence[_ods_ir.Value], falseDestOperands: _Sequence[_ods_ir.Value], trueDest: _ods_ir.Block, falseDest: _ods_ir.Block, *, branch_weights: _Optional[_Union[_Sequence[int], _ods_ir.DenseI32ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.append(condition)
|
||||
operands.append(_get_op_results_or_values(trueDestOperands))
|
||||
operands.append(_get_op_results_or_values(falseDestOperands))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
if branch_weights is not None: attributes["branch_weights"] = (branch_weights if (
|
||||
isinstance(branch_weights, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('DenseI32ArrayAttr')) else
|
||||
_ods_ir.AttrBuilder.get('DenseI32ArrayAttr')(branch_weights, context=_ods_context))
|
||||
results = []
|
||||
_ods_successors = []
|
||||
_ods_successors.append(trueDest)
|
||||
_ods_successors.append(falseDest)
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def condition(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operation.operands,
|
||||
self.operation.attributes["operandSegmentSizes"], 0)
|
||||
return operand_range[0]
|
||||
|
||||
@builtins.property
|
||||
def trueDestOperands(self) -> _ods_ir.OpOperandList:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operation.operands,
|
||||
self.operation.attributes["operandSegmentSizes"], 1)
|
||||
return operand_range
|
||||
|
||||
@builtins.property
|
||||
def falseDestOperands(self) -> _ods_ir.OpOperandList:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operation.operands,
|
||||
self.operation.attributes["operandSegmentSizes"], 2)
|
||||
return operand_range
|
||||
|
||||
@builtins.property
|
||||
def branch_weights(self) -> _Optional[_ods_ir.DenseI32ArrayAttr]:
|
||||
if "branch_weights" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["branch_weights"]
|
||||
|
||||
@branch_weights.setter
|
||||
def branch_weights(self, value: _Optional[_ods_ir.DenseI32ArrayAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["branch_weights"] = value
|
||||
elif "branch_weights" in self.operation.attributes:
|
||||
del self.operation.attributes["branch_weights"]
|
||||
|
||||
@branch_weights.deleter
|
||||
def branch_weights(self):
|
||||
del self.operation.attributes["branch_weights"]
|
||||
|
||||
@_ods_cext.register_op_adaptor(CondBranchOp)
|
||||
class CondBranchOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "cf.cond_br"
|
||||
|
||||
@builtins.property
|
||||
def condition(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operands,
|
||||
self.attributes["operandSegmentSizes"], 0)
|
||||
return operand_range[0]
|
||||
|
||||
@builtins.property
|
||||
def trueDestOperands(self) -> _ods_ir.OpOperandList:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operands,
|
||||
self.attributes["operandSegmentSizes"], 1)
|
||||
return operand_range
|
||||
|
||||
@builtins.property
|
||||
def falseDestOperands(self) -> _ods_ir.OpOperandList:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operands,
|
||||
self.attributes["operandSegmentSizes"], 2)
|
||||
return operand_range
|
||||
|
||||
@builtins.property
|
||||
def branch_weights(self) -> _Optional[_ods_ir.DenseI32ArrayAttr]:
|
||||
if "branch_weights" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["branch_weights"]
|
||||
|
||||
def cond_br(condition: _ods_ir.Value[_ods_ir.IntegerType], true_dest_operands: _Sequence[_ods_ir.Value], false_dest_operands: _Sequence[_ods_ir.Value], true_dest: _ods_ir.Block, false_dest: _ods_ir.Block, *, branch_weights: _Optional[_Union[_Sequence[int], _ods_ir.DenseI32ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> CondBranchOp:
|
||||
return CondBranchOp(condition=condition, trueDestOperands=true_dest_operands, falseDestOperands=false_dest_operands, trueDest=true_dest, falseDest=false_dest, branch_weights=branch_weights, loc=loc, ip=ip)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class SwitchOp(_ods_ir.OpView):
|
||||
r"""
|
||||
The `cf.switch` terminator operation represents a switch on a signless integer
|
||||
value. If the flag matches one of the specified cases, then the
|
||||
corresponding destination is jumped to. If the flag does not match any of
|
||||
the cases, the default destination is jumped to. The count and types of
|
||||
operands must align with the arguments in the corresponding target blocks.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
cf.switch %flag : i32, [
|
||||
default: ^bb1(%a : i32),
|
||||
42: ^bb1(%b : i32),
|
||||
43: ^bb3(%c : i32)
|
||||
]
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "cf.switch"
|
||||
|
||||
_ODS_OPERAND_SEGMENTS = [1,-1,-1,]
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, flag: _ods_ir.Value[_ods_ir.IntegerType], defaultOperands: _Sequence[_ods_ir.Value], caseOperands: _Sequence[_ods_ir.Value], case_operand_segments: _Union[_Sequence[int], _ods_ir.DenseI32ArrayAttr], defaultDestination: _ods_ir.Block, caseDestinations: _Sequence[_ods_ir.Block], *, case_values: _Optional[_Union[_Any, _ods_ir.DenseIntElementsAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.append(flag)
|
||||
operands.append(_get_op_results_or_values(defaultOperands))
|
||||
operands.append(_get_op_results_or_values(caseOperands))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
if case_values is not None: attributes["case_values"] = (case_values if (
|
||||
isinstance(case_values, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('AnyIntElementsAttr')) else
|
||||
_ods_ir.AttrBuilder.get('AnyIntElementsAttr')(case_values, context=_ods_context))
|
||||
attributes["case_operand_segments"] = (case_operand_segments if (
|
||||
isinstance(case_operand_segments, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('DenseI32ArrayAttr')) else
|
||||
_ods_ir.AttrBuilder.get('DenseI32ArrayAttr')(case_operand_segments, context=_ods_context))
|
||||
results = []
|
||||
_ods_successors = []
|
||||
_ods_successors.append(defaultDestination)
|
||||
_ods_successors.extend(caseDestinations)
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def flag(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operation.operands,
|
||||
self.operation.attributes["operandSegmentSizes"], 0)
|
||||
return operand_range[0]
|
||||
|
||||
@builtins.property
|
||||
def defaultOperands(self) -> _ods_ir.OpOperandList:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operation.operands,
|
||||
self.operation.attributes["operandSegmentSizes"], 1)
|
||||
return operand_range
|
||||
|
||||
@builtins.property
|
||||
def caseOperands(self) -> _ods_ir.OpOperandList:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operation.operands,
|
||||
self.operation.attributes["operandSegmentSizes"], 2)
|
||||
return operand_range
|
||||
|
||||
@builtins.property
|
||||
def case_values(self) -> _Optional[_ods_ir.DenseIntElementsAttr]:
|
||||
if "case_values" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["case_values"]
|
||||
|
||||
@case_values.setter
|
||||
def case_values(self, value: _Optional[_ods_ir.DenseIntElementsAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["case_values"] = value
|
||||
elif "case_values" in self.operation.attributes:
|
||||
del self.operation.attributes["case_values"]
|
||||
|
||||
@case_values.deleter
|
||||
def case_values(self):
|
||||
del self.operation.attributes["case_values"]
|
||||
|
||||
@builtins.property
|
||||
def case_operand_segments(self) -> _ods_ir.DenseI32ArrayAttr:
|
||||
return self.operation.attributes["case_operand_segments"]
|
||||
|
||||
@case_operand_segments.setter
|
||||
def case_operand_segments(self, value: _ods_ir.DenseI32ArrayAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["case_operand_segments"] = value
|
||||
|
||||
@_ods_cext.register_op_adaptor(SwitchOp)
|
||||
class SwitchOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "cf.switch"
|
||||
|
||||
@builtins.property
|
||||
def flag(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operands,
|
||||
self.attributes["operandSegmentSizes"], 0)
|
||||
return operand_range[0]
|
||||
|
||||
@builtins.property
|
||||
def defaultOperands(self) -> _ods_ir.OpOperandList:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operands,
|
||||
self.attributes["operandSegmentSizes"], 1)
|
||||
return operand_range
|
||||
|
||||
@builtins.property
|
||||
def caseOperands(self) -> _ods_ir.OpOperandList:
|
||||
operand_range = _ods_segmented_accessor(
|
||||
self.operands,
|
||||
self.attributes["operandSegmentSizes"], 2)
|
||||
return operand_range
|
||||
|
||||
@builtins.property
|
||||
def case_values(self) -> _Optional[_ods_ir.DenseIntElementsAttr]:
|
||||
if "case_values" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["case_values"]
|
||||
|
||||
@builtins.property
|
||||
def case_operand_segments(self) -> _ods_ir.DenseI32ArrayAttr:
|
||||
return self.attributes["case_operand_segments"]
|
||||
|
||||
def switch(flag: _ods_ir.Value[_ods_ir.IntegerType], default_operands: _Sequence[_ods_ir.Value], case_operands: _Sequence[_ods_ir.Value], case_operand_segments: _Union[_Sequence[int], _ods_ir.DenseI32ArrayAttr], default_destination: _ods_ir.Block, case_destinations: _Sequence[_ods_ir.Block], *, case_values: _Optional[_Union[_Any, _ods_ir.DenseIntElementsAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> SwitchOp:
|
||||
return SwitchOp(flag=flag, defaultOperands=default_operands, caseOperands=case_operands, case_operand_segments=case_operand_segments, defaultDestination=default_destination, caseDestinations=case_destinations, case_values=case_values, loc=loc, ip=ip)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,600 @@
|
||||
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ._ods_common import (
|
||||
equally_sized_accessor as _ods_equally_sized_accessor,
|
||||
get_default_loc_context as _ods_get_default_loc_context,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
segmented_accessor as _ods_segmented_accessor,
|
||||
)
|
||||
_ods_ir = _ods_cext.ir
|
||||
_ods_cext.globals.register_traceback_file_exclusion(__file__)
|
||||
|
||||
import builtins
|
||||
from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
|
||||
import sys as _sys
|
||||
if _sys.version_info >= (3, 12):
|
||||
from collections.abc import Buffer as _Buffer # pytype: disable=not-supported-yet
|
||||
else:
|
||||
try:
|
||||
from typing_extensions import Buffer as _Buffer
|
||||
except ImportError:
|
||||
_Buffer = _Any
|
||||
|
||||
|
||||
@_ods_cext.register_dialect
|
||||
class _Dialect(_ods_ir.Dialect):
|
||||
DIALECT_NAMESPACE = "func"
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class CallIndirectOp(_ods_ir.OpView):
|
||||
r"""
|
||||
The `func.call_indirect` operation represents an indirect call to a value
|
||||
of function type. The operands and result types of the call must match the
|
||||
specified function type.
|
||||
|
||||
Function values can be created with the
|
||||
[`func.constant` operation](#funcconstant-constantop).
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%func = func.constant @my_func : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
|
||||
%result = func.call_indirect %func(%0, %1) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "func.call_indirect"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, results_: _Sequence[_ods_ir.Type], callee: _ods_ir.Value, callee_operands: _Sequence[_ods_ir.Value], *, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.append(callee)
|
||||
operands.extend(_get_op_results_or_values(callee_operands))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
if arg_attrs is not None: attributes["arg_attrs"] = (arg_attrs if (
|
||||
isinstance(arg_attrs, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
|
||||
_ods_ir.AttrBuilder.get('DictArrayAttr')(arg_attrs, context=_ods_context))
|
||||
if res_attrs is not None: attributes["res_attrs"] = (res_attrs if (
|
||||
isinstance(res_attrs, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
|
||||
_ods_ir.AttrBuilder.get('DictArrayAttr')(res_attrs, context=_ods_context))
|
||||
results = []
|
||||
results.extend(results_)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def callee(self) -> _ods_ir.Value:
|
||||
return self.operation.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def callee_operands(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 2 + 1
|
||||
return self.operation.operands[1:1 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "arg_attrs" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["arg_attrs"]
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["arg_attrs"] = value
|
||||
elif "arg_attrs" in self.operation.attributes:
|
||||
del self.operation.attributes["arg_attrs"]
|
||||
|
||||
@arg_attrs.deleter
|
||||
def arg_attrs(self):
|
||||
del self.operation.attributes["arg_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "res_attrs" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["res_attrs"]
|
||||
|
||||
@res_attrs.setter
|
||||
def res_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["res_attrs"] = value
|
||||
elif "res_attrs" in self.operation.attributes:
|
||||
del self.operation.attributes["res_attrs"]
|
||||
|
||||
@res_attrs.deleter
|
||||
def res_attrs(self):
|
||||
del self.operation.attributes["res_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def results_(self) -> _ods_ir.OpResultList:
|
||||
_ods_variadic_group_length = len(self.operation.results) - 1 + 1
|
||||
return self.operation.results[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@_ods_cext.register_op_adaptor(CallIndirectOp)
|
||||
class CallIndirectOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "func.call_indirect"
|
||||
|
||||
@builtins.property
|
||||
def callee(self) -> _ods_ir.Value:
|
||||
return self.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def callee_operands(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 2 + 1
|
||||
return self.operands[1:1 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "arg_attrs" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["arg_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "res_attrs" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["res_attrs"]
|
||||
|
||||
def call_indirect(results_: _Sequence[_ods_ir.Type], callee: _ods_ir.Value, callee_operands: _Sequence[_ods_ir.Value], *, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, CallIndirectOp]:
|
||||
op = CallIndirectOp(results_=results_, callee=callee, callee_operands=callee_operands, arg_attrs=arg_attrs, res_attrs=res_attrs, loc=loc, ip=ip); results = op.results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class CallOp(_ods_ir.OpView):
|
||||
r"""
|
||||
The `func.call` operation represents a direct call to a function that is
|
||||
within the same symbol scope as the call. The operands and result types of
|
||||
the call must match the specified function type. The callee is encoded as a
|
||||
symbol reference attribute named "callee".
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%2 = func.call @my_add(%0, %1) : (f32, f32) -> f32
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "func.call"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, result: _Sequence[_ods_ir.Type], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], operands_: _Sequence[_ods_ir.Value], *, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, no_inline: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(operands_))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["callee"] = (callee if (
|
||||
isinstance(callee, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
|
||||
_ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(callee, context=_ods_context))
|
||||
if arg_attrs is not None: attributes["arg_attrs"] = (arg_attrs if (
|
||||
isinstance(arg_attrs, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
|
||||
_ods_ir.AttrBuilder.get('DictArrayAttr')(arg_attrs, context=_ods_context))
|
||||
if res_attrs is not None: attributes["res_attrs"] = (res_attrs if (
|
||||
isinstance(res_attrs, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
|
||||
_ods_ir.AttrBuilder.get('DictArrayAttr')(res_attrs, context=_ods_context))
|
||||
if bool(no_inline): attributes["no_inline"] = _ods_ir.UnitAttr.get(
|
||||
_ods_get_default_loc_context(loc))
|
||||
results = []
|
||||
results.extend(result)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def operands_(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def callee(self) -> _ods_ir.FlatSymbolRefAttr:
|
||||
return self.operation.attributes["callee"]
|
||||
|
||||
@callee.setter
|
||||
def callee(self, value: _ods_ir.FlatSymbolRefAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["callee"] = value
|
||||
|
||||
@builtins.property
|
||||
def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "arg_attrs" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["arg_attrs"]
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["arg_attrs"] = value
|
||||
elif "arg_attrs" in self.operation.attributes:
|
||||
del self.operation.attributes["arg_attrs"]
|
||||
|
||||
@arg_attrs.deleter
|
||||
def arg_attrs(self):
|
||||
del self.operation.attributes["arg_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "res_attrs" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["res_attrs"]
|
||||
|
||||
@res_attrs.setter
|
||||
def res_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["res_attrs"] = value
|
||||
elif "res_attrs" in self.operation.attributes:
|
||||
del self.operation.attributes["res_attrs"]
|
||||
|
||||
@res_attrs.deleter
|
||||
def res_attrs(self):
|
||||
del self.operation.attributes["res_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def no_inline(self) -> bool:
|
||||
return "no_inline" in self.operation.attributes
|
||||
|
||||
@no_inline.setter
|
||||
def no_inline(self, value):
|
||||
if bool(value):
|
||||
self.operation.attributes["no_inline"] = _ods_ir.UnitAttr.get()
|
||||
elif "no_inline" in self.operation.attributes:
|
||||
del self.operation.attributes["no_inline"]
|
||||
|
||||
@no_inline.deleter
|
||||
def no_inline(self):
|
||||
del self.operation.attributes["no_inline"]
|
||||
|
||||
@_ods_cext.register_op_adaptor(CallOp)
|
||||
class CallOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "func.call"
|
||||
|
||||
@builtins.property
|
||||
def operands_(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def callee(self) -> _ods_ir.FlatSymbolRefAttr:
|
||||
return self.attributes["callee"]
|
||||
|
||||
@builtins.property
|
||||
def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "arg_attrs" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["arg_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "res_attrs" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["res_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def no_inline(self) -> bool:
|
||||
return "no_inline" in self.attributes
|
||||
|
||||
def call(result: _Sequence[_ods_ir.Type], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], operands_: _Sequence[_ods_ir.Value], *, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, no_inline: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, CallOp]:
|
||||
op = CallOp(result=result, callee=callee, operands_=operands_, arg_attrs=arg_attrs, res_attrs=res_attrs, no_inline=no_inline, loc=loc, ip=ip); results = op.results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class ConstantOp(_ods_ir.OpView):
|
||||
r"""
|
||||
The `func.constant` operation produces an SSA value from a symbol reference
|
||||
to a `func.func` operation
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Reference to function @myfn.
|
||||
%2 = func.constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
|
||||
|
||||
// Equivalent generic forms
|
||||
%2 = "func.constant"() { value = @myfn } : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>)
|
||||
```
|
||||
|
||||
MLIR does not allow direct references to functions in SSA operands because
|
||||
the compiler is multithreaded, and disallowing SSA values to directly
|
||||
reference a function simplifies this
|
||||
([rationale](../Rationale/Rationale.md#multithreading-the-compiler)).
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "func.constant"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, result: _ods_ir.Type, value: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["value"] = (value if (
|
||||
isinstance(value, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
|
||||
_ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(value, context=_ods_context))
|
||||
results = []
|
||||
results.append(result)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def value(self) -> _ods_ir.FlatSymbolRefAttr:
|
||||
return self.operation.attributes["value"]
|
||||
|
||||
@value.setter
|
||||
def value(self, value: _ods_ir.FlatSymbolRefAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["value"] = value
|
||||
|
||||
@_ods_cext.register_op_adaptor(ConstantOp)
|
||||
class ConstantOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "func.constant"
|
||||
|
||||
@builtins.property
|
||||
def value(self) -> _ods_ir.FlatSymbolRefAttr:
|
||||
return self.attributes["value"]
|
||||
|
||||
def constant(result: _ods_ir.Type, value: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
|
||||
return ConstantOp(result=result, value=value, loc=loc, ip=ip).result
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class FuncOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Operations within the function cannot implicitly capture values defined
|
||||
outside of the function, i.e. Functions are `IsolatedFromAbove`. All
|
||||
external references must use function arguments or attributes that establish
|
||||
a symbolic connection (e.g. symbols referenced by name via a string
|
||||
attribute like SymbolRefAttr). An external function declaration (used when
|
||||
referring to a function declared in some other module) has no body. While
|
||||
the MLIR textual form provides a nice inline syntax for function arguments,
|
||||
they are internally represented as “block arguments” to the first block in
|
||||
the region.
|
||||
|
||||
Only dialect attribute names may be specified in the attribute dictionaries
|
||||
for function arguments, results, or the function itself.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// External function definitions.
|
||||
func.func private @abort()
|
||||
func.func private @scribble(i32, i64, memref<? x 128 x f32, #layout_map0>) -> f64
|
||||
|
||||
// A function that returns its argument twice:
|
||||
func.func @count(%x: i64) -> (i64, i64)
|
||||
attributes {fruit = "banana"} {
|
||||
return %x, %x: i64, i64
|
||||
}
|
||||
|
||||
// A function with an argument attribute
|
||||
func.func private @example_fn_arg(%x: i32 {swift.self = unit})
|
||||
|
||||
// A function with a result attribute
|
||||
func.func private @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64})
|
||||
|
||||
// A function with an attribute
|
||||
func.func private @example_fn_attr() attributes {dialectName.attrName = false}
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "func.func"
|
||||
|
||||
_ODS_REGIONS = (1, True)
|
||||
|
||||
def __init__(self, sym_name: _Union[str, _ods_ir.StringAttr], function_type: _Union[_Any, _ods_ir.TypeAttr], *, sym_visibility: _Optional[_Union[str, _ods_ir.StringAttr]] = None, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, no_inline: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["sym_name"] = (sym_name if (
|
||||
isinstance(sym_name, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('SymbolNameAttr')) else
|
||||
_ods_ir.AttrBuilder.get('SymbolNameAttr')(sym_name, context=_ods_context))
|
||||
attributes["function_type"] = (function_type if (
|
||||
isinstance(function_type, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('anonymous_452')) else
|
||||
_ods_ir.AttrBuilder.get('anonymous_452')(function_type, context=_ods_context))
|
||||
if sym_visibility is not None: attributes["sym_visibility"] = (sym_visibility if (
|
||||
isinstance(sym_visibility, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('StrAttr')) else
|
||||
_ods_ir.AttrBuilder.get('StrAttr')(sym_visibility, context=_ods_context))
|
||||
if arg_attrs is not None: attributes["arg_attrs"] = (arg_attrs if (
|
||||
isinstance(arg_attrs, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
|
||||
_ods_ir.AttrBuilder.get('DictArrayAttr')(arg_attrs, context=_ods_context))
|
||||
if res_attrs is not None: attributes["res_attrs"] = (res_attrs if (
|
||||
isinstance(res_attrs, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
|
||||
_ods_ir.AttrBuilder.get('DictArrayAttr')(res_attrs, context=_ods_context))
|
||||
if bool(no_inline): attributes["no_inline"] = _ods_ir.UnitAttr.get(
|
||||
_ods_get_default_loc_context(loc))
|
||||
results = []
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def sym_name(self) -> _ods_ir.StringAttr:
|
||||
return self.operation.attributes["sym_name"]
|
||||
|
||||
@sym_name.setter
|
||||
def sym_name(self, value: _ods_ir.StringAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["sym_name"] = value
|
||||
|
||||
@builtins.property
|
||||
def function_type(self) -> _ods_ir.TypeAttr:
|
||||
return self.operation.attributes["function_type"]
|
||||
|
||||
@function_type.setter
|
||||
def function_type(self, value: _ods_ir.TypeAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["function_type"] = value
|
||||
|
||||
@builtins.property
|
||||
def sym_visibility(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "sym_visibility" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["sym_visibility"]
|
||||
|
||||
@sym_visibility.setter
|
||||
def sym_visibility(self, value: _Optional[_ods_ir.StringAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["sym_visibility"] = value
|
||||
elif "sym_visibility" in self.operation.attributes:
|
||||
del self.operation.attributes["sym_visibility"]
|
||||
|
||||
@sym_visibility.deleter
|
||||
def sym_visibility(self):
|
||||
del self.operation.attributes["sym_visibility"]
|
||||
|
||||
@builtins.property
|
||||
def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "arg_attrs" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["arg_attrs"]
|
||||
|
||||
@arg_attrs.setter
|
||||
def arg_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["arg_attrs"] = value
|
||||
elif "arg_attrs" in self.operation.attributes:
|
||||
del self.operation.attributes["arg_attrs"]
|
||||
|
||||
@arg_attrs.deleter
|
||||
def arg_attrs(self):
|
||||
del self.operation.attributes["arg_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "res_attrs" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["res_attrs"]
|
||||
|
||||
@res_attrs.setter
|
||||
def res_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["res_attrs"] = value
|
||||
elif "res_attrs" in self.operation.attributes:
|
||||
del self.operation.attributes["res_attrs"]
|
||||
|
||||
@res_attrs.deleter
|
||||
def res_attrs(self):
|
||||
del self.operation.attributes["res_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def no_inline(self) -> bool:
|
||||
return "no_inline" in self.operation.attributes
|
||||
|
||||
@no_inline.setter
|
||||
def no_inline(self, value):
|
||||
if bool(value):
|
||||
self.operation.attributes["no_inline"] = _ods_ir.UnitAttr.get()
|
||||
elif "no_inline" in self.operation.attributes:
|
||||
del self.operation.attributes["no_inline"]
|
||||
|
||||
@no_inline.deleter
|
||||
def no_inline(self):
|
||||
del self.operation.attributes["no_inline"]
|
||||
|
||||
@builtins.property
|
||||
def body(self) -> _ods_ir.Region:
|
||||
return self.regions[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(FuncOp)
|
||||
class FuncOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "func.func"
|
||||
|
||||
@builtins.property
|
||||
def sym_name(self) -> _ods_ir.StringAttr:
|
||||
return self.attributes["sym_name"]
|
||||
|
||||
@builtins.property
|
||||
def function_type(self) -> _ods_ir.TypeAttr:
|
||||
return self.attributes["function_type"]
|
||||
|
||||
@builtins.property
|
||||
def sym_visibility(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "sym_visibility" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["sym_visibility"]
|
||||
|
||||
@builtins.property
|
||||
def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "arg_attrs" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["arg_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
|
||||
if "res_attrs" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["res_attrs"]
|
||||
|
||||
@builtins.property
|
||||
def no_inline(self) -> bool:
|
||||
return "no_inline" in self.attributes
|
||||
|
||||
def func(sym_name: _Union[str, _ods_ir.StringAttr], function_type: _Union[_Any, _ods_ir.TypeAttr], *, sym_visibility: _Optional[_Union[str, _ods_ir.StringAttr]] = None, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, no_inline: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> FuncOp:
|
||||
return FuncOp(sym_name=sym_name, function_type=function_type, sym_visibility=sym_visibility, arg_attrs=arg_attrs, res_attrs=res_attrs, no_inline=no_inline, loc=loc, ip=ip)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class ReturnOp(_ods_ir.OpView):
|
||||
r"""
|
||||
The `func.return` operation represents a return operation within a function.
|
||||
The operation takes variable number of operands and produces no results.
|
||||
The operand number and types must match the signature of the function
|
||||
that contains the operation.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
func.func @foo() -> (i32, f8) {
|
||||
...
|
||||
return %0, %1 : i32, f8
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "func.return"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, operands_: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(operands_))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
results = []
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def operands_(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@_ods_cext.register_op_adaptor(ReturnOp)
|
||||
class ReturnOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "func.return"
|
||||
|
||||
@builtins.property
|
||||
def operands_(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
def return_(operands_: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> ReturnOp:
|
||||
return ReturnOp(operands_=operands_, loc=loc, ip=ip)
|
||||
@@ -0,0 +1,440 @@
|
||||
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from enum import IntEnum, auto, IntFlag
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ..ir import register_attribute_builder
|
||||
_ods_ir = _ods_cext.ir
|
||||
|
||||
class AddressSpace(IntEnum):
|
||||
"""GPU address space"""
|
||||
|
||||
Global = 1
|
||||
Workgroup = 2
|
||||
Private = 3
|
||||
Constant = 4
|
||||
|
||||
def __str__(self):
|
||||
if self is AddressSpace.Global:
|
||||
return "global"
|
||||
if self is AddressSpace.Workgroup:
|
||||
return "workgroup"
|
||||
if self is AddressSpace.Private:
|
||||
return "private"
|
||||
if self is AddressSpace.Constant:
|
||||
return "constant"
|
||||
raise ValueError("Unknown AddressSpace enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("GPU_AddressSpaceEnum", allow_existing=True)
|
||||
def _gpu_addressspaceenum(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class AllReduceOperation(IntEnum):
|
||||
"""built-in reduction operations supported by gpu.allreduce."""
|
||||
|
||||
ADD = 0
|
||||
MUL = 1
|
||||
MINUI = 2
|
||||
MINSI = 3
|
||||
MINNUMF = 4
|
||||
MAXUI = 5
|
||||
MAXSI = 6
|
||||
MAXNUMF = 7
|
||||
AND = 8
|
||||
OR = 9
|
||||
XOR = 10
|
||||
MINIMUMF = 11
|
||||
MAXIMUMF = 12
|
||||
|
||||
def __str__(self):
|
||||
if self is AllReduceOperation.ADD:
|
||||
return "add"
|
||||
if self is AllReduceOperation.MUL:
|
||||
return "mul"
|
||||
if self is AllReduceOperation.MINUI:
|
||||
return "minui"
|
||||
if self is AllReduceOperation.MINSI:
|
||||
return "minsi"
|
||||
if self is AllReduceOperation.MINNUMF:
|
||||
return "minnumf"
|
||||
if self is AllReduceOperation.MAXUI:
|
||||
return "maxui"
|
||||
if self is AllReduceOperation.MAXSI:
|
||||
return "maxsi"
|
||||
if self is AllReduceOperation.MAXNUMF:
|
||||
return "maxnumf"
|
||||
if self is AllReduceOperation.AND:
|
||||
return "and"
|
||||
if self is AllReduceOperation.OR:
|
||||
return "or"
|
||||
if self is AllReduceOperation.XOR:
|
||||
return "xor"
|
||||
if self is AllReduceOperation.MINIMUMF:
|
||||
return "minimumf"
|
||||
if self is AllReduceOperation.MAXIMUMF:
|
||||
return "maximumf"
|
||||
raise ValueError("Unknown AllReduceOperation enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("GPU_AllReduceOperation", allow_existing=True)
|
||||
def _gpu_allreduceoperation(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class BroadcastType(IntEnum):
|
||||
"""a lane to broadcast from"""
|
||||
|
||||
first_active_lane = 0
|
||||
specific_lane = 1
|
||||
|
||||
def __str__(self):
|
||||
if self is BroadcastType.first_active_lane:
|
||||
return "first_active_lane"
|
||||
if self is BroadcastType.specific_lane:
|
||||
return "specific_lane"
|
||||
raise ValueError("Unknown BroadcastType enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("GPU_BroadcastType", allow_existing=True)
|
||||
def _gpu_broadcasttype(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class CompilationTarget(IntEnum):
|
||||
"""GPU compilation format"""
|
||||
|
||||
Offload = 1
|
||||
Assembly = 2
|
||||
Binary = 3
|
||||
Fatbin = 4
|
||||
|
||||
def __str__(self):
|
||||
if self is CompilationTarget.Offload:
|
||||
return "offload"
|
||||
if self is CompilationTarget.Assembly:
|
||||
return "assembly"
|
||||
if self is CompilationTarget.Binary:
|
||||
return "bin"
|
||||
if self is CompilationTarget.Fatbin:
|
||||
return "fatbin"
|
||||
raise ValueError("Unknown CompilationTarget enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("GPU_CompilationTargetEnum", allow_existing=True)
|
||||
def _gpu_compilationtargetenum(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class Dimension(IntEnum):
|
||||
"""a dimension, either 'x', 'y', or 'z'"""
|
||||
|
||||
x = 0
|
||||
y = 1
|
||||
z = 2
|
||||
|
||||
def __str__(self):
|
||||
if self is Dimension.x:
|
||||
return "x"
|
||||
if self is Dimension.y:
|
||||
return "y"
|
||||
if self is Dimension.z:
|
||||
return "z"
|
||||
raise ValueError("Unknown Dimension enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("GPU_Dimension", allow_existing=True)
|
||||
def _gpu_dimension(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class DimensionKind(IntEnum):
|
||||
"""the possible kinds of launch dimension"""
|
||||
|
||||
Other = 0
|
||||
Block = 1
|
||||
Grid = 2
|
||||
Cluster = 3
|
||||
|
||||
def __str__(self):
|
||||
if self is DimensionKind.Other:
|
||||
return "other"
|
||||
if self is DimensionKind.Block:
|
||||
return "block"
|
||||
if self is DimensionKind.Grid:
|
||||
return "grid"
|
||||
if self is DimensionKind.Cluster:
|
||||
return "cluster"
|
||||
raise ValueError("Unknown DimensionKind enum entry.")
|
||||
|
||||
|
||||
|
||||
class Prune2To4SpMatFlag(IntEnum):
|
||||
"""pruning strategy for 2:4 sparse matrix"""
|
||||
|
||||
NONE = 0
|
||||
PRUNE_ONLY = 1
|
||||
PRUNE_AND_CHECK = 2
|
||||
|
||||
def __str__(self):
|
||||
if self is Prune2To4SpMatFlag.NONE:
|
||||
return "NONE"
|
||||
if self is Prune2To4SpMatFlag.PRUNE_ONLY:
|
||||
return "PRUNE_ONLY"
|
||||
if self is Prune2To4SpMatFlag.PRUNE_AND_CHECK:
|
||||
return "PRUNE_AND_CHECK"
|
||||
raise ValueError("Unknown Prune2To4SpMatFlag enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("GPU_Prune2To4SpMatFlag", allow_existing=True)
|
||||
def _gpu_prune2to4spmatflag(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class ShuffleMode(IntEnum):
|
||||
"""Indexing modes supported by gpu.shuffle."""
|
||||
|
||||
XOR = 0
|
||||
UP = 2
|
||||
DOWN = 1
|
||||
IDX = 3
|
||||
|
||||
def __str__(self):
|
||||
if self is ShuffleMode.XOR:
|
||||
return "xor"
|
||||
if self is ShuffleMode.UP:
|
||||
return "up"
|
||||
if self is ShuffleMode.DOWN:
|
||||
return "down"
|
||||
if self is ShuffleMode.IDX:
|
||||
return "idx"
|
||||
raise ValueError("Unknown ShuffleMode enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("GPU_ShuffleMode", allow_existing=True)
|
||||
def _gpu_shufflemode(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class SpGEMMWorkEstimationOrComputeKind(IntEnum):
|
||||
"""choose whether spgemm_work_estimation_or_compute does work estimation or compute"""
|
||||
|
||||
WORK_ESTIMATION = 0
|
||||
COMPUTE = 1
|
||||
|
||||
def __str__(self):
|
||||
if self is SpGEMMWorkEstimationOrComputeKind.WORK_ESTIMATION:
|
||||
return "WORK_ESTIMATION"
|
||||
if self is SpGEMMWorkEstimationOrComputeKind.COMPUTE:
|
||||
return "COMPUTE"
|
||||
raise ValueError("Unknown SpGEMMWorkEstimationOrComputeKind enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("GPU_SpGEMMWorkEstimationOrComputeKind", allow_existing=True)
|
||||
def _gpu_spgemmworkestimationorcomputekind(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class TransposeMode(IntEnum):
|
||||
"""transpose mode of sparse matrix supported by sparse tensor ops"""
|
||||
|
||||
NON_TRANSPOSE = 0
|
||||
TRANSPOSE = 1
|
||||
CONJUGATE_TRANSPOSE = 2
|
||||
|
||||
def __str__(self):
|
||||
if self is TransposeMode.NON_TRANSPOSE:
|
||||
return "NON_TRANSPOSE"
|
||||
if self is TransposeMode.TRANSPOSE:
|
||||
return "TRANSPOSE"
|
||||
if self is TransposeMode.CONJUGATE_TRANSPOSE:
|
||||
return "CONJUGATE_TRANSPOSE"
|
||||
raise ValueError("Unknown TransposeMode enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("GPU_TransposeMode", allow_existing=True)
|
||||
def _gpu_transposemode(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class MMAElementwiseOp(IntEnum):
|
||||
"""elementwise operation to apply to mma matrix"""
|
||||
|
||||
ADDF = 0
|
||||
MULF = 1
|
||||
SUBF = 2
|
||||
MAXF = 3
|
||||
MINF = 4
|
||||
DIVF = 5
|
||||
ADDI = 6
|
||||
MULI = 7
|
||||
SUBI = 8
|
||||
DIVS = 9
|
||||
DIVU = 10
|
||||
NEGATEF = 11
|
||||
NEGATES = 12
|
||||
EXTF = 13
|
||||
TRUNCF = 14
|
||||
|
||||
def __str__(self):
|
||||
if self is MMAElementwiseOp.ADDF:
|
||||
return "addf"
|
||||
if self is MMAElementwiseOp.MULF:
|
||||
return "mulf"
|
||||
if self is MMAElementwiseOp.SUBF:
|
||||
return "subf"
|
||||
if self is MMAElementwiseOp.MAXF:
|
||||
return "maxf"
|
||||
if self is MMAElementwiseOp.MINF:
|
||||
return "minf"
|
||||
if self is MMAElementwiseOp.DIVF:
|
||||
return "divf"
|
||||
if self is MMAElementwiseOp.ADDI:
|
||||
return "addi"
|
||||
if self is MMAElementwiseOp.MULI:
|
||||
return "muli"
|
||||
if self is MMAElementwiseOp.SUBI:
|
||||
return "subi"
|
||||
if self is MMAElementwiseOp.DIVS:
|
||||
return "divs"
|
||||
if self is MMAElementwiseOp.DIVU:
|
||||
return "divu"
|
||||
if self is MMAElementwiseOp.NEGATEF:
|
||||
return "negatef"
|
||||
if self is MMAElementwiseOp.NEGATES:
|
||||
return "negates"
|
||||
if self is MMAElementwiseOp.EXTF:
|
||||
return "extf"
|
||||
if self is MMAElementwiseOp.TRUNCF:
|
||||
return "truncf"
|
||||
raise ValueError("Unknown MMAElementwiseOp enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("MMAElementWise", allow_existing=True)
|
||||
def _mmaelementwise(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class MappingId(IntEnum):
|
||||
"""Mapping ids for loop mapping"""
|
||||
|
||||
DimX = 0
|
||||
DimY = 1
|
||||
DimZ = 2
|
||||
LinearDim0 = 3
|
||||
LinearDim1 = 4
|
||||
LinearDim2 = 5
|
||||
LinearDim3 = 6
|
||||
LinearDim4 = 7
|
||||
LinearDim5 = 8
|
||||
LinearDim6 = 9
|
||||
LinearDim7 = 10
|
||||
LinearDim8 = 11
|
||||
LinearDim9 = 12
|
||||
|
||||
def __str__(self):
|
||||
if self is MappingId.DimX:
|
||||
return "x"
|
||||
if self is MappingId.DimY:
|
||||
return "y"
|
||||
if self is MappingId.DimZ:
|
||||
return "z"
|
||||
if self is MappingId.LinearDim0:
|
||||
return "linear_dim_0"
|
||||
if self is MappingId.LinearDim1:
|
||||
return "linear_dim_1"
|
||||
if self is MappingId.LinearDim2:
|
||||
return "linear_dim_2"
|
||||
if self is MappingId.LinearDim3:
|
||||
return "linear_dim_3"
|
||||
if self is MappingId.LinearDim4:
|
||||
return "linear_dim_4"
|
||||
if self is MappingId.LinearDim5:
|
||||
return "linear_dim_5"
|
||||
if self is MappingId.LinearDim6:
|
||||
return "linear_dim_6"
|
||||
if self is MappingId.LinearDim7:
|
||||
return "linear_dim_7"
|
||||
if self is MappingId.LinearDim8:
|
||||
return "linear_dim_8"
|
||||
if self is MappingId.LinearDim9:
|
||||
return "linear_dim_9"
|
||||
raise ValueError("Unknown MappingId enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("MappingIdEnum", allow_existing=True)
|
||||
def _mappingidenum(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
|
||||
|
||||
class Processor(IntEnum):
|
||||
"""processor for loop mapping"""
|
||||
|
||||
BlockX = 0
|
||||
BlockY = 1
|
||||
BlockZ = 2
|
||||
ThreadX = 3
|
||||
ThreadY = 4
|
||||
ThreadZ = 5
|
||||
Sequential = 6
|
||||
|
||||
def __str__(self):
|
||||
if self is Processor.BlockX:
|
||||
return "block_x"
|
||||
if self is Processor.BlockY:
|
||||
return "block_y"
|
||||
if self is Processor.BlockZ:
|
||||
return "block_z"
|
||||
if self is Processor.ThreadX:
|
||||
return "thread_x"
|
||||
if self is Processor.ThreadY:
|
||||
return "thread_y"
|
||||
if self is Processor.ThreadZ:
|
||||
return "thread_z"
|
||||
if self is Processor.Sequential:
|
||||
return "sequential"
|
||||
raise ValueError("Unknown Processor enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("ProcessorEnum", allow_existing=True)
|
||||
def _processorenum(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
|
||||
|
||||
@register_attribute_builder("gpu.GPU_AddressSpaceAttr")
|
||||
def _gpu_addressspaceattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#gpu.address_space<{str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("gpu.GPU_AllReduceOperationAttr")
|
||||
def _gpu_allreduceoperationattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#gpu<all_reduce_op {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("gpu.GPU_BroadcastTypeAttr")
|
||||
def _gpu_broadcasttypeattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#gpu<broadcast {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("gpu.GPU_DimensionAttr")
|
||||
def _gpu_dimensionattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#gpu<dim {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("gpu.GPU_Prune2To4SpMatFlagAttr")
|
||||
def _gpu_prune2to4spmatflagattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#gpu<prune_2to4_spmat_flag {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("gpu.GPU_ShuffleModeAttr")
|
||||
def _gpu_shufflemodeattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#gpu<shuffle_mode {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("gpu.GPU_SpGEMMWorkEstimationOrComputeKindAttr")
|
||||
def _gpu_spgemmworkestimationorcomputekindattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#gpu<spgemm_work_estimation_or_compute_kind {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("gpu.GPU_TransposeModeAttr")
|
||||
def _gpu_transposemodeattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#gpu<mat_transpose_mode {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("gpu.MMAElementWiseAttr")
|
||||
def _mmaelementwiseattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#gpu<mma_element_wise {str(x)}>', context=context)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
+3120
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,951 @@
|
||||
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ._ods_common import (
|
||||
equally_sized_accessor as _ods_equally_sized_accessor,
|
||||
get_default_loc_context as _ods_get_default_loc_context,
|
||||
get_op_results_or_values as _get_op_results_or_values,
|
||||
segmented_accessor as _ods_segmented_accessor,
|
||||
)
|
||||
_ods_ir = _ods_cext.ir
|
||||
_ods_cext.globals.register_traceback_file_exclusion(__file__)
|
||||
|
||||
import builtins
|
||||
from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
|
||||
import sys as _sys
|
||||
if _sys.version_info >= (3, 12):
|
||||
from collections.abc import Buffer as _Buffer # pytype: disable=not-supported-yet
|
||||
else:
|
||||
try:
|
||||
from typing_extensions import Buffer as _Buffer
|
||||
except ImportError:
|
||||
_Buffer = _Any
|
||||
|
||||
|
||||
@_ods_cext.register_dialect
|
||||
class _Dialect(_ods_ir.Dialect):
|
||||
DIALECT_NAMESPACE = "mpmd"
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class AssignOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Assigns a local tensor to a mesh as fully replicated within that mesh.
|
||||
|
||||
This is a temporary op that is introduced when lowering jax ops, to move
|
||||
from local types to mesh types. These ops will be eliminated during import,
|
||||
when the inputs and results of the func op become mesh tensors.
|
||||
|
||||
The mesh name of the result type should correspond to a mesh in the
|
||||
topology, and its global type should be identical to the operand type.
|
||||
|
||||
The origin of the assign op is the origin of mesh, e.g. named_computation,
|
||||
mesh inference, etc.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.assign"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, result: _ods_ir.Type, tensor: _ods_ir.Value, *, origin: _Optional[_Union[str, _ods_ir.StringAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.append(tensor)
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
if origin is not None: attributes["origin"] = (origin if (
|
||||
isinstance(origin, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('StrAttr')) else
|
||||
_ods_ir.AttrBuilder.get('StrAttr')(origin, context=_ods_context))
|
||||
results = []
|
||||
results.append(result)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operation.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "origin" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["origin"]
|
||||
|
||||
@origin.setter
|
||||
def origin(self, value: _Optional[_ods_ir.StringAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["origin"] = value
|
||||
elif "origin" in self.operation.attributes:
|
||||
del self.operation.attributes["origin"]
|
||||
|
||||
@origin.deleter
|
||||
def origin(self):
|
||||
del self.operation.attributes["origin"]
|
||||
|
||||
@builtins.property
|
||||
def result(self) -> _ods_ir.OpResult:
|
||||
return self.operation.results[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(AssignOp)
|
||||
class AssignOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.assign"
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "origin" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["origin"]
|
||||
|
||||
def assign(result: _ods_ir.Type, tensor: _ods_ir.Value, *, origin: _Optional[_Union[str, _ods_ir.StringAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
|
||||
return AssignOp(result=result, tensor=tensor, origin=origin, loc=loc, ip=ip).result
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class BroadcastOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Allows for a tensor to be transferred (or replicated) in any mesh where it's
|
||||
used. Whenever transferred, the origin of the transfer is the current
|
||||
location of the operand.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.broadcast"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, tensor: _ods_ir.Value, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.append(tensor)
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
if results is None: results = [operands[0].type] * 1
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operation.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def result(self) -> _ods_ir.OpResult:
|
||||
return self.operation.results[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(BroadcastOp)
|
||||
class BroadcastOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.broadcast"
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operands[0]
|
||||
|
||||
def broadcast(tensor: _ods_ir.Value, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
|
||||
return BroadcastOp(tensor=tensor, results=results, loc=loc, ip=ip).result
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class CallOp(_ods_ir.OpView):
|
||||
r"""
|
||||
A function call operation. Useful to wrap the body of loops in function
|
||||
declarations to reduce code size, for example.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.call"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, result: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(tensors))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["callee"] = (callee if (
|
||||
isinstance(callee, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
|
||||
_ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(callee, context=_ods_context))
|
||||
results = []
|
||||
results.extend(result)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def callee(self) -> _ods_ir.FlatSymbolRefAttr:
|
||||
return self.operation.attributes["callee"]
|
||||
|
||||
@callee.setter
|
||||
def callee(self, value: _ods_ir.FlatSymbolRefAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["callee"] = value
|
||||
|
||||
@_ods_cext.register_op_adaptor(CallOp)
|
||||
class CallOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.call"
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def callee(self) -> _ods_ir.FlatSymbolRefAttr:
|
||||
return self.attributes["callee"]
|
||||
|
||||
def call(result: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, CallOp]:
|
||||
op = CallOp(result=result, tensors=tensors, callee=callee, loc=loc, ip=ip); results = op.results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class ForOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Returns the result of executing a body function for a fixed number of
|
||||
iterations, with the iteration index available in the body.
|
||||
|
||||
An optional unroll factor, that must divide the number of iterations,
|
||||
can be specified to unroll the body of the op by that factor, i.e. for
|
||||
unroll factor N, the body is replicated to create N copies and the number of
|
||||
iterations is reduced by a factor of 1/N. Each copy except the first uses
|
||||
the results of the previous copy instead of the block arguments, and the
|
||||
iteration index is multiplied by the unroll factor and incremented after
|
||||
every copy.
|
||||
|
||||
A for operator can accept and return any types, but the TypeID of these
|
||||
must be the same -- e.g. all tensor types or all MPMD mesh types etc. This
|
||||
allows us to use the op at various levels, sharing implementation and
|
||||
transformations.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.for"
|
||||
|
||||
_ODS_REGIONS = (1, True)
|
||||
|
||||
def __init__(self, results_: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], iterations: _Union[int, _ods_ir.IntegerAttr], *, unroll_factor: _Optional[_Union[int, _ods_ir.IntegerAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(tensors))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["iterations"] = (iterations if (
|
||||
isinstance(iterations, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('UI32Attr')) else
|
||||
_ods_ir.AttrBuilder.get('UI32Attr')(iterations, context=_ods_context))
|
||||
if unroll_factor is not None: attributes["unroll_factor"] = (unroll_factor if (
|
||||
isinstance(unroll_factor, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('UI32Attr')) else
|
||||
_ods_ir.AttrBuilder.get('UI32Attr')(unroll_factor, context=_ods_context))
|
||||
results = []
|
||||
results.extend(results_)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def iterations(self) -> _ods_ir.IntegerAttr:
|
||||
return self.operation.attributes["iterations"]
|
||||
|
||||
@iterations.setter
|
||||
def iterations(self, value: _ods_ir.IntegerAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["iterations"] = value
|
||||
|
||||
@builtins.property
|
||||
def unroll_factor(self) -> _Optional[_ods_ir.IntegerAttr]:
|
||||
if "unroll_factor" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["unroll_factor"]
|
||||
|
||||
@unroll_factor.setter
|
||||
def unroll_factor(self, value: _Optional[_ods_ir.IntegerAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["unroll_factor"] = value
|
||||
elif "unroll_factor" in self.operation.attributes:
|
||||
del self.operation.attributes["unroll_factor"]
|
||||
|
||||
@unroll_factor.deleter
|
||||
def unroll_factor(self):
|
||||
del self.operation.attributes["unroll_factor"]
|
||||
|
||||
@builtins.property
|
||||
def results_(self) -> _ods_ir.OpResultList:
|
||||
_ods_variadic_group_length = len(self.operation.results) - 1 + 1
|
||||
return self.operation.results[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def region(self) -> _ods_ir.Region:
|
||||
return self.regions[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(ForOp)
|
||||
class ForOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.for"
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def iterations(self) -> _ods_ir.IntegerAttr:
|
||||
return self.attributes["iterations"]
|
||||
|
||||
@builtins.property
|
||||
def unroll_factor(self) -> _Optional[_ods_ir.IntegerAttr]:
|
||||
if "unroll_factor" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["unroll_factor"]
|
||||
|
||||
def for_(results_: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], iterations: _Union[int, _ods_ir.IntegerAttr], *, unroll_factor: _Optional[_Union[int, _ods_ir.IntegerAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, ForOp]:
|
||||
op = ForOp(results_=results_, tensors=tensors, iterations=iterations, unroll_factor=unroll_factor, loc=loc, ip=ip); results = op.results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class FragmentCallOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Represents a call to a function that holds an MPMD fragment body, i.e. a
|
||||
computation assigned to a specific mesh in an MPMD topology, that is
|
||||
intended to be executed as an individual SPMD program fragment.
|
||||
|
||||
The mesh name of the fragment should correspond to a mesh in the topology of
|
||||
the enclosing function, and that mesh shape should match that of the callee.
|
||||
|
||||
The origin specifies the user named computations that contributed to this
|
||||
fragment call e.g. through merging.
|
||||
|
||||
The function input and result types of the callee must be the local tensor
|
||||
types of the corresponding mesh tensors of this op's operands and results
|
||||
respectively.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%2 = mpmd.fragment_call<mesh="m1", origin=[]> @my_fragment(%0, %1) :
|
||||
(mesh_tensor<...>, mesh_tensor<...>) -> mesh_tensor<...>
|
||||
```
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.fragment_call"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, result: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.ArrayAttr], mesh_name: _Union[str, _ods_ir.StringAttr], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(tensors))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["origin"] = (origin if (
|
||||
isinstance(origin, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('anonymous_847')) else
|
||||
_ods_ir.AttrBuilder.get('anonymous_847')(origin, context=_ods_context))
|
||||
attributes["mesh_name"] = (mesh_name if (
|
||||
isinstance(mesh_name, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('StrAttr')) else
|
||||
_ods_ir.AttrBuilder.get('StrAttr')(mesh_name, context=_ods_context))
|
||||
attributes["callee"] = (callee if (
|
||||
isinstance(callee, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
|
||||
_ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(callee, context=_ods_context))
|
||||
results = []
|
||||
results.extend(result)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _ods_ir.ArrayAttr:
|
||||
return self.operation.attributes["origin"]
|
||||
|
||||
@origin.setter
|
||||
def origin(self, value: _ods_ir.ArrayAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["origin"] = value
|
||||
|
||||
@builtins.property
|
||||
def mesh_name(self) -> _ods_ir.StringAttr:
|
||||
return self.operation.attributes["mesh_name"]
|
||||
|
||||
@mesh_name.setter
|
||||
def mesh_name(self, value: _ods_ir.StringAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["mesh_name"] = value
|
||||
|
||||
@builtins.property
|
||||
def callee(self) -> _ods_ir.FlatSymbolRefAttr:
|
||||
return self.operation.attributes["callee"]
|
||||
|
||||
@callee.setter
|
||||
def callee(self, value: _ods_ir.FlatSymbolRefAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["callee"] = value
|
||||
|
||||
@_ods_cext.register_op_adaptor(FragmentCallOp)
|
||||
class FragmentCallOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.fragment_call"
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _ods_ir.ArrayAttr:
|
||||
return self.attributes["origin"]
|
||||
|
||||
@builtins.property
|
||||
def mesh_name(self) -> _ods_ir.StringAttr:
|
||||
return self.attributes["mesh_name"]
|
||||
|
||||
@builtins.property
|
||||
def callee(self) -> _ods_ir.FlatSymbolRefAttr:
|
||||
return self.attributes["callee"]
|
||||
|
||||
def fragment_call(result: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.ArrayAttr], mesh_name: _Union[str, _ods_ir.StringAttr], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, FragmentCallOp]:
|
||||
op = FragmentCallOp(result=result, tensors=tensors, origin=origin, mesh_name=mesh_name, callee=callee, loc=loc, ip=ip); results = op.results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class FragmentOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Assigns a computation, i.e. a block of operations, to a specific mesh in an
|
||||
MPMD topology, that is intended to be executed as an individual SPMD program
|
||||
fragment.
|
||||
|
||||
The fragment takes and returns only mesh tensors that are assigned to the
|
||||
same mesh as the fragment.
|
||||
|
||||
The mesh name of the fragment should correspond to a mesh in the topology.
|
||||
|
||||
The fragment includes a list of origins, i.e., metadata with information re
|
||||
the original named_computations that formed this fragment, and a staged_id
|
||||
defined _iff_ it is a user defined fragment, i.e., it has a non-empty list
|
||||
of origins. The optional in_shardings specifies the sharding of the
|
||||
block arguments of a fragment, which correspond to the operands.
|
||||
The optional out_shardings specifies the shardings of the results.
|
||||
|
||||
The fragment's region shouldn't have any free variables, and the type of
|
||||
each block arguments and returned values in the region is the global tensor
|
||||
type of the corresponding mesh tensor.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.fragment"
|
||||
|
||||
_ODS_REGIONS = (1, True)
|
||||
|
||||
def __init__(self, results_: _Sequence[_ods_ir.Type], inputs: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.ArrayAttr], mesh_name: _Union[str, _ods_ir.StringAttr], *, stage_id: _Optional[_Union[int, _ods_ir.IntegerAttr]] = None, in_shardings: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, out_shardings: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(inputs))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["origin"] = (origin if (
|
||||
isinstance(origin, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('anonymous_847')) else
|
||||
_ods_ir.AttrBuilder.get('anonymous_847')(origin, context=_ods_context))
|
||||
attributes["mesh_name"] = (mesh_name if (
|
||||
isinstance(mesh_name, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('StrAttr')) else
|
||||
_ods_ir.AttrBuilder.get('StrAttr')(mesh_name, context=_ods_context))
|
||||
if stage_id is not None: attributes["stage_id"] = (stage_id if (
|
||||
isinstance(stage_id, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('I64Attr')) else
|
||||
_ods_ir.AttrBuilder.get('I64Attr')(stage_id, context=_ods_context))
|
||||
if in_shardings is not None: attributes["in_shardings"] = (in_shardings if (
|
||||
isinstance(in_shardings, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
|
||||
_ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(in_shardings, context=_ods_context))
|
||||
if out_shardings is not None: attributes["out_shardings"] = (out_shardings if (
|
||||
isinstance(out_shardings, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
|
||||
_ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(out_shardings, context=_ods_context))
|
||||
results = []
|
||||
results.extend(results_)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def inputs(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _ods_ir.ArrayAttr:
|
||||
return self.operation.attributes["origin"]
|
||||
|
||||
@origin.setter
|
||||
def origin(self, value: _ods_ir.ArrayAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["origin"] = value
|
||||
|
||||
@builtins.property
|
||||
def mesh_name(self) -> _ods_ir.StringAttr:
|
||||
return self.operation.attributes["mesh_name"]
|
||||
|
||||
@mesh_name.setter
|
||||
def mesh_name(self, value: _ods_ir.StringAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["mesh_name"] = value
|
||||
|
||||
@builtins.property
|
||||
def stage_id(self) -> _Optional[_ods_ir.IntegerAttr]:
|
||||
if "stage_id" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["stage_id"]
|
||||
|
||||
@stage_id.setter
|
||||
def stage_id(self, value: _Optional[_ods_ir.IntegerAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["stage_id"] = value
|
||||
elif "stage_id" in self.operation.attributes:
|
||||
del self.operation.attributes["stage_id"]
|
||||
|
||||
@stage_id.deleter
|
||||
def stage_id(self):
|
||||
del self.operation.attributes["stage_id"]
|
||||
|
||||
@builtins.property
|
||||
def in_shardings(self) -> _Optional[_ods_ir.Attribute]:
|
||||
if "in_shardings" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["in_shardings"]
|
||||
|
||||
@in_shardings.setter
|
||||
def in_shardings(self, value: _Optional[_ods_ir.Attribute]):
|
||||
if value is not None:
|
||||
self.operation.attributes["in_shardings"] = value
|
||||
elif "in_shardings" in self.operation.attributes:
|
||||
del self.operation.attributes["in_shardings"]
|
||||
|
||||
@in_shardings.deleter
|
||||
def in_shardings(self):
|
||||
del self.operation.attributes["in_shardings"]
|
||||
|
||||
@builtins.property
|
||||
def out_shardings(self) -> _Optional[_ods_ir.Attribute]:
|
||||
if "out_shardings" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["out_shardings"]
|
||||
|
||||
@out_shardings.setter
|
||||
def out_shardings(self, value: _Optional[_ods_ir.Attribute]):
|
||||
if value is not None:
|
||||
self.operation.attributes["out_shardings"] = value
|
||||
elif "out_shardings" in self.operation.attributes:
|
||||
del self.operation.attributes["out_shardings"]
|
||||
|
||||
@out_shardings.deleter
|
||||
def out_shardings(self):
|
||||
del self.operation.attributes["out_shardings"]
|
||||
|
||||
@builtins.property
|
||||
def results_(self) -> _ods_ir.OpResultList:
|
||||
_ods_variadic_group_length = len(self.operation.results) - 1 + 1
|
||||
return self.operation.results[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def region(self) -> _ods_ir.Region:
|
||||
return self.regions[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(FragmentOp)
|
||||
class FragmentOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.fragment"
|
||||
|
||||
@builtins.property
|
||||
def inputs(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _ods_ir.ArrayAttr:
|
||||
return self.attributes["origin"]
|
||||
|
||||
@builtins.property
|
||||
def mesh_name(self) -> _ods_ir.StringAttr:
|
||||
return self.attributes["mesh_name"]
|
||||
|
||||
@builtins.property
|
||||
def stage_id(self) -> _Optional[_ods_ir.IntegerAttr]:
|
||||
if "stage_id" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["stage_id"]
|
||||
|
||||
@builtins.property
|
||||
def in_shardings(self) -> _Optional[_ods_ir.Attribute]:
|
||||
if "in_shardings" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["in_shardings"]
|
||||
|
||||
@builtins.property
|
||||
def out_shardings(self) -> _Optional[_ods_ir.Attribute]:
|
||||
if "out_shardings" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["out_shardings"]
|
||||
|
||||
def fragment(results_: _Sequence[_ods_ir.Type], inputs: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.ArrayAttr], mesh_name: _Union[str, _ods_ir.StringAttr], *, stage_id: _Optional[_Union[int, _ods_ir.IntegerAttr]] = None, in_shardings: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, out_shardings: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, FragmentOp]:
|
||||
op = FragmentOp(results_=results_, inputs=inputs, origin=origin, mesh_name=mesh_name, stage_id=stage_id, in_shardings=in_shardings, out_shardings=out_shardings, loc=loc, ip=ip); results = op.results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class NamedComputationOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Groups a computation, i.e. a block of operations, and gives it a name and
|
||||
a transpose count via the UserOrigin attribute. This NamedComputation can be
|
||||
used to assign a mesh to the computation in MPMD or for optimizations.
|
||||
|
||||
The transpose count (default=0) denotes whether the named computation has
|
||||
been produced by a certain number of JAX AD transpose transformations.
|
||||
|
||||
The op's region shouldn't have any free variables, and the type of
|
||||
each block arguments and returned values in the region must be the same as
|
||||
the type of the inputs and the return type of the op.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.named_computation"
|
||||
|
||||
_ODS_REGIONS = (1, True)
|
||||
|
||||
def __init__(self, results_: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.Attribute], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(tensors))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["origin"] = (origin if (
|
||||
isinstance(origin, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('Mpmd_UserOrigin')) else
|
||||
_ods_ir.AttrBuilder.get('Mpmd_UserOrigin')(origin, context=_ods_context))
|
||||
results = []
|
||||
results.extend(results_)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _ods_ir.Attribute:
|
||||
return self.operation.attributes["origin"]
|
||||
|
||||
@origin.setter
|
||||
def origin(self, value: _ods_ir.Attribute):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["origin"] = value
|
||||
|
||||
@builtins.property
|
||||
def results_(self) -> _ods_ir.OpResultList:
|
||||
_ods_variadic_group_length = len(self.operation.results) - 1 + 1
|
||||
return self.operation.results[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def region(self) -> _ods_ir.Region:
|
||||
return self.regions[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(NamedComputationOp)
|
||||
class NamedComputationOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.named_computation"
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _ods_ir.Attribute:
|
||||
return self.attributes["origin"]
|
||||
|
||||
def named_computation(results_: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.Attribute], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, NamedComputationOp]:
|
||||
op = NamedComputationOp(results_=results_, tensors=tensors, origin=origin, loc=loc, ip=ip); results = op.results
|
||||
return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class NamedTensorOp(_ods_ir.OpView):
|
||||
r"""
|
||||
An identity op that associates the result of the tensor with a given name.
|
||||
This NamedTensor can be used to assign a mesh to the tensor in MPMD.
|
||||
|
||||
NOTE: this is different than TagOp in that TagOp is used for naming a tensor
|
||||
and can be used to partition that tensor. NamedTensorOp is for MPMD programs
|
||||
for tensors that may be explicitly assigned to meshes.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.named_tensor"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, tensor: _ods_ir.Value, name: _Union[str, _ods_ir.StringAttr], *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.append(tensor)
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
attributes["name"] = (name if (
|
||||
isinstance(name, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('StrAttr')) else
|
||||
_ods_ir.AttrBuilder.get('StrAttr')(name, context=_ods_context))
|
||||
if results is None: results = [operands[0].type] * 1
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operation.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def name(self) -> _ods_ir.StringAttr:
|
||||
return self.operation.attributes["name"]
|
||||
|
||||
@name.setter
|
||||
def name(self, value: _ods_ir.StringAttr):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["name"] = value
|
||||
|
||||
@builtins.property
|
||||
def result(self) -> _ods_ir.OpResult:
|
||||
return self.operation.results[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(NamedTensorOp)
|
||||
class NamedTensorOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.named_tensor"
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def name(self) -> _ods_ir.StringAttr:
|
||||
return self.attributes["name"]
|
||||
|
||||
def named_tensor(tensor: _ods_ir.Value, name: _Union[str, _ods_ir.StringAttr], *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
|
||||
return NamedTensorOp(tensor=tensor, name=name, results=results, loc=loc, ip=ip).result
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class ReduceOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Allows for a tensor to be reduced across different meshes, and then
|
||||
broadcast to wherever it needs to be used.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.reduce"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, tensors: _Sequence[_ods_ir.Value], *, reduction: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(tensors))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
if reduction is not None: attributes["reduction"] = (reduction if (
|
||||
isinstance(reduction, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('Mpmd_Reduction')) else
|
||||
_ods_ir.AttrBuilder.get('Mpmd_Reduction')(reduction, context=_ods_context))
|
||||
if results is None: results = [operands[0].type] * 1
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def reduction(self) -> _ods_ir.Attribute:
|
||||
return self.operation.attributes["reduction"]
|
||||
|
||||
@reduction.setter
|
||||
def reduction(self, value: _ods_ir.Attribute):
|
||||
if value is None:
|
||||
raise ValueError("'None' not allowed as value for mandatory attributes")
|
||||
self.operation.attributes["reduction"] = value
|
||||
|
||||
@builtins.property
|
||||
def result(self) -> _ods_ir.OpResult:
|
||||
return self.operation.results[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(ReduceOp)
|
||||
class ReduceOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.reduce"
|
||||
|
||||
@builtins.property
|
||||
def tensors(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@builtins.property
|
||||
def reduction(self) -> _ods_ir.Attribute:
|
||||
return self.attributes["reduction"]
|
||||
|
||||
def reduce(tensors: _Sequence[_ods_ir.Value], *, reduction: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
|
||||
return ReduceOp(tensors=tensors, reduction=reduction, results=results, loc=loc, ip=ip).result
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class ReturnOp(_ods_ir.OpView):
|
||||
OPERATION_NAME = "mpmd.return"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, results_: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.extend(_get_op_results_or_values(results_))
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
results = []
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def results_(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
|
||||
return self.operation.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
@_ods_cext.register_op_adaptor(ReturnOp)
|
||||
class ReturnOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.return"
|
||||
|
||||
@builtins.property
|
||||
def results_(self) -> _ods_ir.OpOperandList:
|
||||
_ods_variadic_group_length = len(self.operands) - 1 + 1
|
||||
return self.operands[0:0 + _ods_variadic_group_length]
|
||||
|
||||
def return_(results_: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> ReturnOp:
|
||||
return ReturnOp(results_=results_, loc=loc, ip=ip)
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class TransferOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Transfers a distributed tensor from one mesh to another.
|
||||
|
||||
The mesh names of the operand and result types should correspond to meshes
|
||||
in the topology, and their global types should be identical.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.transfer"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, result: _ods_ir.Type, tensor: _ods_ir.Value, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.append(tensor)
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
results = []
|
||||
results.append(result)
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operation.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def result(self) -> _ods_ir.OpResult:
|
||||
return self.operation.results[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(TransferOp)
|
||||
class TransferOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.transfer"
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operands[0]
|
||||
|
||||
def transfer(result: _ods_ir.Type, tensor: _ods_ir.Value, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
|
||||
return TransferOp(result=result, tensor=tensor, loc=loc, ip=ip).result
|
||||
|
||||
@_ods_cext.register_operation(_Dialect)
|
||||
class UnassignOp(_ods_ir.OpView):
|
||||
r"""
|
||||
Unassigns a fully replicated tensor from a mesh.
|
||||
|
||||
This is a temporary op that is introduced when lowering jax ops, to move
|
||||
from local types to mesh types. These ops will be eliminated during import,
|
||||
when the inputs and results of the func op become mesh tensors.
|
||||
|
||||
The mesh name of the operand type should correspond to a mesh in the
|
||||
topology, and its global type should be identical to the result type.
|
||||
"""
|
||||
|
||||
OPERATION_NAME = "mpmd.unassign"
|
||||
|
||||
_ODS_REGIONS = (0, True)
|
||||
|
||||
def __init__(self, tensor: _ods_ir.Value, *, origin: _Optional[_Union[str, _ods_ir.StringAttr]] = None, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
|
||||
operands = []
|
||||
attributes = {}
|
||||
regions = None
|
||||
operands.append(tensor)
|
||||
_ods_context = _ods_get_default_loc_context(loc)
|
||||
if origin is not None: attributes["origin"] = (origin if (
|
||||
isinstance(origin, _ods_ir.Attribute) or
|
||||
not _ods_ir.AttrBuilder.contains('StrAttr')) else
|
||||
_ods_ir.AttrBuilder.get('StrAttr')(origin, context=_ods_context))
|
||||
_ods_successors = None
|
||||
super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operation.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "origin" not in self.operation.attributes:
|
||||
return None
|
||||
return self.operation.attributes["origin"]
|
||||
|
||||
@origin.setter
|
||||
def origin(self, value: _Optional[_ods_ir.StringAttr]):
|
||||
if value is not None:
|
||||
self.operation.attributes["origin"] = value
|
||||
elif "origin" in self.operation.attributes:
|
||||
del self.operation.attributes["origin"]
|
||||
|
||||
@origin.deleter
|
||||
def origin(self):
|
||||
del self.operation.attributes["origin"]
|
||||
|
||||
@builtins.property
|
||||
def result(self) -> _ods_ir.OpResult:
|
||||
return self.operation.results[0]
|
||||
|
||||
@_ods_cext.register_op_adaptor(UnassignOp)
|
||||
class UnassignOpAdaptor(_ods_ir.OpAdaptor):
|
||||
OPERATION_NAME = "mpmd.unassign"
|
||||
|
||||
@builtins.property
|
||||
def tensor(self) -> _ods_ir.Value:
|
||||
return self.operands[0]
|
||||
|
||||
@builtins.property
|
||||
def origin(self) -> _Optional[_ods_ir.StringAttr]:
|
||||
if "origin" not in self.attributes:
|
||||
return None
|
||||
return self.attributes["origin"]
|
||||
|
||||
def unassign(tensor: _ods_ir.Value, *, origin: _Optional[_Union[str, _ods_ir.StringAttr]] = None, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
|
||||
return UnassignOp(tensor=tensor, origin=origin, results=results, loc=loc, ip=ip).result
|
||||
@@ -0,0 +1,147 @@
|
||||
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from enum import IntEnum, auto, IntFlag
|
||||
from ._ods_common import _cext as _ods_cext
|
||||
from ..ir import register_attribute_builder
|
||||
_ods_ir = _ods_cext.ir
|
||||
|
||||
class RcpRoundingMode(IntEnum):
|
||||
"""Rounding mode of rcp"""
|
||||
|
||||
APPROX = 0
|
||||
RN = 1
|
||||
RZ = 2
|
||||
RM = 3
|
||||
RP = 4
|
||||
|
||||
def __str__(self):
|
||||
if self is RcpRoundingMode.APPROX:
|
||||
return "approx"
|
||||
if self is RcpRoundingMode.RN:
|
||||
return "rn"
|
||||
if self is RcpRoundingMode.RZ:
|
||||
return "rz"
|
||||
if self is RcpRoundingMode.RM:
|
||||
return "rm"
|
||||
if self is RcpRoundingMode.RP:
|
||||
return "rp"
|
||||
raise ValueError("Unknown RcpRoundingMode enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("RcpRoundingMode", allow_existing=True)
|
||||
def _rcproundingmode(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class TensorMapInterleaveKind(IntEnum):
|
||||
"""Tensor map interleave layout type"""
|
||||
|
||||
INTERLEAVE_NONE = 0
|
||||
INTERLEAVE_16B = 1
|
||||
INTERLEAVE_32B = 2
|
||||
|
||||
def __str__(self):
|
||||
if self is TensorMapInterleaveKind.INTERLEAVE_NONE:
|
||||
return "none"
|
||||
if self is TensorMapInterleaveKind.INTERLEAVE_16B:
|
||||
return "interleave_16b"
|
||||
if self is TensorMapInterleaveKind.INTERLEAVE_32B:
|
||||
return "interleave_32b"
|
||||
raise ValueError("Unknown TensorMapInterleaveKind enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TensorMapInterleaveKind", allow_existing=True)
|
||||
def _tensormapinterleavekind(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class TensorMapL2PromoKind(IntEnum):
|
||||
"""Tensor map L2 promotion type"""
|
||||
|
||||
L2PROMO_NONE = 0
|
||||
L2PROMO_64B = 1
|
||||
L2PROMO_128B = 2
|
||||
L2PROMO_256B = 3
|
||||
|
||||
def __str__(self):
|
||||
if self is TensorMapL2PromoKind.L2PROMO_NONE:
|
||||
return "none"
|
||||
if self is TensorMapL2PromoKind.L2PROMO_64B:
|
||||
return "l2promo_64b"
|
||||
if self is TensorMapL2PromoKind.L2PROMO_128B:
|
||||
return "l2promo_128b"
|
||||
if self is TensorMapL2PromoKind.L2PROMO_256B:
|
||||
return "l2promo_256b"
|
||||
raise ValueError("Unknown TensorMapL2PromoKind enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TensorMapL2PromoKind", allow_existing=True)
|
||||
def _tensormapl2promokind(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class TensorMapOOBKind(IntEnum):
|
||||
"""Tensor map out-of-bounds fill type"""
|
||||
|
||||
OOB_ZERO = 0
|
||||
OOB_NAN = 1
|
||||
|
||||
def __str__(self):
|
||||
if self is TensorMapOOBKind.OOB_ZERO:
|
||||
return "zero"
|
||||
if self is TensorMapOOBKind.OOB_NAN:
|
||||
return "nan"
|
||||
raise ValueError("Unknown TensorMapOOBKind enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TensorMapOOBKind", allow_existing=True)
|
||||
def _tensormapoobkind(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class TensorMapSwizzleKind(IntEnum):
|
||||
"""Tensor map swizzling mode of shared memory banks"""
|
||||
|
||||
SWIZZLE_NONE = 0
|
||||
SWIZZLE_32B = 1
|
||||
SWIZZLE_64B = 2
|
||||
SWIZZLE_128B = 3
|
||||
|
||||
def __str__(self):
|
||||
if self is TensorMapSwizzleKind.SWIZZLE_NONE:
|
||||
return "none"
|
||||
if self is TensorMapSwizzleKind.SWIZZLE_32B:
|
||||
return "swizzle_32b"
|
||||
if self is TensorMapSwizzleKind.SWIZZLE_64B:
|
||||
return "swizzle_64b"
|
||||
if self is TensorMapSwizzleKind.SWIZZLE_128B:
|
||||
return "swizzle_128b"
|
||||
raise ValueError("Unknown TensorMapSwizzleKind enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TensorMapSwizzleKind", allow_existing=True)
|
||||
def _tensormapswizzlekind(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
@register_attribute_builder("nvgpu.RcpRoundingModeAttr")
|
||||
def _rcproundingmodeattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#nvgpu<rcp_rounding_mode {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("nvgpu.TensorMapInterleaveAttr")
|
||||
def _tensormapinterleaveattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#nvgpu<interleave {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("nvgpu.TensorMapL2PromoAttr")
|
||||
def _tensormapl2promoattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#nvgpu<l2promo {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("nvgpu.TensorMapOOBAttr")
|
||||
def _tensormapoobattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#nvgpu<oob {str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("nvgpu.TensorMapSwizzleAttr")
|
||||
def _tensormapswizzleattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#nvgpu<swizzle {str(x)}>', context=context)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,307 @@
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import (
|
||||
List as _List,
|
||||
Optional as _Optional,
|
||||
Sequence as _Sequence,
|
||||
Tuple as _Tuple,
|
||||
Type as _Type,
|
||||
Union as _Union,
|
||||
)
|
||||
|
||||
from .._mlir_libs import _mlir as _cext
|
||||
from ..ir import (
|
||||
ArrayAttr,
|
||||
Attribute,
|
||||
BoolAttr,
|
||||
DenseI64ArrayAttr,
|
||||
IntegerAttr,
|
||||
IntegerType,
|
||||
OpView,
|
||||
Operation,
|
||||
ShapedType,
|
||||
Value,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"equally_sized_accessor",
|
||||
"get_default_loc_context",
|
||||
"get_op_result_or_value",
|
||||
"get_op_results_or_values",
|
||||
"get_op_result_or_op_results",
|
||||
"segmented_accessor",
|
||||
]
|
||||
|
||||
|
||||
def segmented_accessor(elements, raw_segments, idx):
|
||||
"""
|
||||
Returns a slice of elements corresponding to the idx-th segment.
|
||||
|
||||
elements: a sliceable container (operands or results).
|
||||
raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
|
||||
sizes of the segments.
|
||||
idx: index of the segment.
|
||||
"""
|
||||
segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
|
||||
start = sum(segments[i] for i in range(idx))
|
||||
end = start + segments[idx]
|
||||
return elements[start:end]
|
||||
|
||||
|
||||
def equally_sized_accessor(
|
||||
elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic
|
||||
):
|
||||
"""
|
||||
Returns a starting position and a number of elements per variadic group
|
||||
assuming equally-sized groups and the given numbers of preceding groups.
|
||||
|
||||
elements: a sequential container.
|
||||
n_simple: the number of non-variadic groups in the container.
|
||||
n_variadic: the number of variadic groups in the container.
|
||||
n_preceding_simple: the number of non-variadic groups preceding the current
|
||||
group.
|
||||
n_preceding_variadic: the number of variadic groups preceding the current
|
||||
group.
|
||||
"""
|
||||
|
||||
total_variadic_length = len(elements) - n_simple
|
||||
# This should be enforced by the C++-side trait verifier.
|
||||
assert total_variadic_length % n_variadic == 0
|
||||
|
||||
elements_per_group = total_variadic_length // n_variadic
|
||||
start = n_preceding_simple + n_preceding_variadic * elements_per_group
|
||||
return start, elements_per_group
|
||||
|
||||
|
||||
def get_default_loc_context(location=None):
|
||||
"""
|
||||
Returns a context in which the defaulted location is created. If the location
|
||||
is None, takes the current location from the stack.
|
||||
"""
|
||||
if location is None:
|
||||
if _cext.ir.Location.current:
|
||||
return _cext.ir.Location.current.context
|
||||
return None
|
||||
return location.context
|
||||
|
||||
|
||||
def get_op_result_or_value(
|
||||
arg: _Union[
|
||||
_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
|
||||
]
|
||||
) -> _cext.ir.Value:
|
||||
"""Returns the given value or the single result of the given op.
|
||||
|
||||
This is useful to implement op constructors so that they can take other ops as
|
||||
arguments instead of requiring the caller to extract results for every op.
|
||||
Raises ValueError if provided with an op that doesn't have a single result.
|
||||
"""
|
||||
if isinstance(arg, _cext.ir.OpView):
|
||||
return arg.operation.result
|
||||
elif isinstance(arg, _cext.ir.Operation):
|
||||
return arg.result
|
||||
elif isinstance(arg, _cext.ir.OpResultList):
|
||||
return arg[0]
|
||||
else:
|
||||
assert isinstance(arg, _cext.ir.Value), f"expects Value, got {type(arg)}"
|
||||
return arg
|
||||
|
||||
|
||||
def get_op_results_or_values(
|
||||
arg: _Union[
|
||||
_cext.ir.OpView,
|
||||
_cext.ir.Operation,
|
||||
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
|
||||
]
|
||||
) -> _Union[
|
||||
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
|
||||
_cext.ir.OpResultList,
|
||||
]:
|
||||
"""Returns the given sequence of values or the results of the given op.
|
||||
|
||||
This is useful to implement op constructors so that they can take other ops as
|
||||
lists of arguments instead of requiring the caller to extract results for
|
||||
every op.
|
||||
"""
|
||||
if isinstance(arg, _cext.ir.OpView):
|
||||
return arg.operation.results
|
||||
elif isinstance(arg, _cext.ir.Operation):
|
||||
return arg.results
|
||||
else:
|
||||
return arg
|
||||
|
||||
|
||||
def get_op_result_or_op_results(
|
||||
op: _Union[_cext.ir.OpView, _cext.ir.Operation],
|
||||
) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
|
||||
results = op.results
|
||||
num_results = len(results)
|
||||
if num_results == 1:
|
||||
return results[0]
|
||||
elif num_results > 1:
|
||||
return results
|
||||
elif isinstance(op, _cext.ir.OpView):
|
||||
return op.operation
|
||||
else:
|
||||
return op
|
||||
|
||||
|
||||
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
|
||||
ResultValueT = _Union[ResultValueTypeTuple]
|
||||
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
|
||||
|
||||
StaticIntLike = _Union[int, IntegerAttr]
|
||||
ValueLike = _Union[Operation, OpView, Value]
|
||||
MixedInt = _Union[StaticIntLike, ValueLike]
|
||||
|
||||
IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
|
||||
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]
|
||||
|
||||
BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
|
||||
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]
|
||||
|
||||
MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
|
||||
|
||||
DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
|
||||
|
||||
|
||||
def _dispatch_dynamic_index_list(
|
||||
indices: _Union[DynamicIndexList, ArrayAttr],
|
||||
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
|
||||
"""Dispatches a list of indices to the appropriate form.
|
||||
|
||||
This is similar to the custom `DynamicIndexList` directive upstream:
|
||||
provided indices may be in the form of dynamic SSA values or static values,
|
||||
and they may be scalable (i.e., as a singleton list) or not. This function
|
||||
dispatches each index into its respective form. It also extracts the SSA
|
||||
values and static indices from various similar structures, respectively.
|
||||
"""
|
||||
dynamic_indices = []
|
||||
static_indices = [ShapedType.get_dynamic_size()] * len(indices)
|
||||
scalable_indices = [False] * len(indices)
|
||||
|
||||
# ArrayAttr: Extract index values.
|
||||
if isinstance(indices, ArrayAttr):
|
||||
indices = [idx for idx in indices]
|
||||
|
||||
def process_nonscalable_index(i, index):
|
||||
"""Processes any form of non-scalable index.
|
||||
|
||||
Returns False if the given index was scalable and thus remains
|
||||
unprocessed; True otherwise.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
static_indices[i] = index
|
||||
elif isinstance(index, IntegerAttr):
|
||||
static_indices[i] = index.value # pytype: disable=attribute-error
|
||||
elif isinstance(index, (Operation, Value, OpView)):
|
||||
dynamic_indices.append(index)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
# Process each index at a time.
|
||||
for i, index in enumerate(indices):
|
||||
if not process_nonscalable_index(i, index):
|
||||
# If it wasn't processed, it must be a scalable index, which is
|
||||
# provided as a _Sequence of one value, so extract and process that.
|
||||
scalable_indices[i] = True
|
||||
assert len(index) == 1
|
||||
ret = process_nonscalable_index(i, index[0])
|
||||
assert ret
|
||||
|
||||
return dynamic_indices, static_indices, scalable_indices
|
||||
|
||||
|
||||
# Dispatches `MixedValues` that all represents integers in various forms into
|
||||
# the following three categories:
|
||||
# - `dynamic_values`: a list of `Value`s, potentially from op results;
|
||||
# - `packed_values`: a value handle, potentially from an op result, associated
|
||||
# to one or more payload operations of integer type;
|
||||
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
|
||||
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
|
||||
# The input is in the form for `packed_values`, only that result is set and the
|
||||
# other two are empty. Otherwise, the input can be a mix of the other two forms,
|
||||
# and for each dynamic value, a special value is added to the `static_values`.
|
||||
def _dispatch_mixed_values(
|
||||
values: MixedValues,
|
||||
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
|
||||
dynamic_values = []
|
||||
packed_values = None
|
||||
static_values = None
|
||||
if isinstance(values, ArrayAttr):
|
||||
static_values = values
|
||||
elif isinstance(values, (Operation, Value, OpView)):
|
||||
packed_values = values
|
||||
else:
|
||||
static_values = []
|
||||
for size in values or []:
|
||||
if isinstance(size, int):
|
||||
static_values.append(size)
|
||||
else:
|
||||
static_values.append(ShapedType.get_dynamic_size())
|
||||
dynamic_values.append(size)
|
||||
static_values = DenseI64ArrayAttr.get(static_values)
|
||||
|
||||
return (dynamic_values, packed_values, static_values)
|
||||
|
||||
|
||||
def _get_value_or_attribute_value(
|
||||
value_or_attr: _Union[any, Attribute, ArrayAttr]
|
||||
) -> any:
|
||||
if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
|
||||
return value_or_attr.value
|
||||
if isinstance(value_or_attr, ArrayAttr):
|
||||
return _get_value_list(value_or_attr)
|
||||
return value_or_attr
|
||||
|
||||
|
||||
def _get_value_list(
|
||||
sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr]
|
||||
) -> _Sequence[any]:
|
||||
return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
|
||||
|
||||
|
||||
def _get_int_array_attr(
|
||||
values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
|
||||
) -> ArrayAttr:
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Turn into a Python list of Python ints.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# Make an ArrayAttr of IntegerAttrs out of it.
|
||||
return ArrayAttr.get(
|
||||
[IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
|
||||
)
|
||||
|
||||
|
||||
def _get_int_array_array_attr(
|
||||
values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
|
||||
) -> ArrayAttr:
|
||||
"""Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
|
||||
|
||||
The input has to be a collection of a collection of integers, where any
|
||||
Python _Sequence and ArrayAttr are admissible collections and Python ints and
|
||||
any IntegerAttr are admissible integers. Both levels of collections are
|
||||
turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
|
||||
If the input is None, an empty ArrayAttr is returned.
|
||||
"""
|
||||
if values is None:
|
||||
return None
|
||||
|
||||
# Make sure the outer level is a list.
|
||||
values = _get_value_list(values)
|
||||
|
||||
# The inner level is now either invalid or a mixed sequence of ArrayAttrs and
|
||||
# Sequences. Make sure the nested values are all lists.
|
||||
values = [_get_value_list(nested) for nested in values]
|
||||
|
||||
# Turn each nested list into an ArrayAttr.
|
||||
values = [_get_int_array_attr(nested) for nested in values]
|
||||
|
||||
# Turn the outer list into an ArrayAttr.
|
||||
return ArrayAttr.get(values)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user