This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,218 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Any, Mapping, Sequence
import os
_this_dir = os.path.dirname(__file__)
def get_lib_dirs() -> Sequence[str]:
    """Return the directories containing shared libraries to link against.

    On some platforms, the package may need to be built specially to export
    development libraries.
    """
    lib_dirs = [_this_dir]
    return lib_dirs
def get_include_dirs() -> Sequence[str]:
    """Return the include directories for compiling against exported C libraries.

    Depending on how the package was built, development C libraries may or may
    not be present.
    """
    include_dir = os.path.join(_this_dir, "include")
    return [include_dir]
# Perform Python level site initialization. This involves:
# 1. Attempting to load initializer modules, specific to the distribution.
# 2. Defining the concrete mlir.ir.Context that does site specific
# initialization.
# 3. Registering container classes with their respective protocols.
#
# Aside from just being far more convenient to do this at the Python level,
# it is actually quite hard/impossible to have such __init__ hooks, given
# the pybind memory model (i.e. there is not a Python reference to the object
# in the scope of the base class __init__).
#
# For #1, we:
# a. Probe for modules named '_mlirRegisterEverything' and
# '_site_initialize_{i}', where 'i' is a number starting at zero and
# proceeding so long as a module with the name is found.
# b. If the module has a 'register_dialects' attribute, it will be called
# immediately with a DialectRegistry to populate.
# c. If the module has a 'context_init_hook', it will be added to a list
# of callbacks that are invoked as the last step of Context
# initialization (and passed the Context under construction).
# d. If the module has a 'disable_multithreading' attribute, it will be
# taken as a boolean. If it is True for any initializer, then the
# default behavior of enabling multithreading on the context
# will be suppressed. This complies with the original behavior of all
# contexts being created with multithreading enabled while allowing
# this behavior to be changed if needed (i.e. if a context_init_hook
# explicitly sets up multithreading).
#
# This facility allows downstreams to customize Context creation to their
# needs.
# Lazily-constructed process-wide singletons; use the accessor functions
# below rather than touching these directly.
_dialect_registry = None
_load_on_create_dialects = None
def get_dialect_registry():
    """Return the process-wide DialectRegistry, creating it on first use."""
    global _dialect_registry
    if _dialect_registry is not None:
        return _dialect_registry
    from ._mlir import ir

    _dialect_registry = ir.DialectRegistry()
    return _dialect_registry
def append_load_on_create_dialect(dialect: str):
    """Append a dialect name to the global load-on-create list."""
    global _load_on_create_dialects
    if _load_on_create_dialects is None:
        _load_on_create_dialects = []
    _load_on_create_dialects.append(dialect)
def get_load_on_create_dialects():
    """Return the global list of dialects to load on Context creation.

    Initializes the list to empty on first use; the returned list is the
    live global, so it reflects later appends.
    """
    global _load_on_create_dialects
    if _load_on_create_dialects is not None:
        return _load_on_create_dialects
    _load_on_create_dialects = []
    return _load_on_create_dialects
def _site_initialize():
    """Perform one-time, site-specific initialization of the MLIR bindings.

    Probes for optional initializer modules, installs a site-customized
    ``ir.Context`` subclass, and registers the binding container classes
    with the Sequence/Mapping protocols. See the module-level comment
    above for the full initializer-module protocol.
    """
    import importlib
    import itertools
    import logging
    from ._mlir import ir

    logger = logging.getLogger(__name__)
    # Callbacks invoked with each newly constructed Context (protocol item c).
    post_init_hooks = []
    disable_multithreading = False
    # This flag disables eagerly loading all dialects. Eagerly loading is often
    # not the desired behavior (see
    # https://github.com/llvm/llvm-project/issues/56037), and the logic is that
    # if any module has this attribute set, then we don't load all (e.g., it's
    # being used in a solution where the loading is controlled).
    disable_load_all_available_dialects = False

    def process_initializer_module(module_name):
        # Imports one initializer module and applies its hooks/flags.
        # Returns True iff the module exists (even if it only partially
        # imported with a warning).
        nonlocal disable_multithreading
        nonlocal disable_load_all_available_dialects
        try:
            m = importlib.import_module(f".{module_name}", __name__)
        except ModuleNotFoundError:
            # Absence is not an error: initializer modules are optional.
            return False
        except ImportError:
            message = (
                f"Error importing mlir initializer {module_name}. This may "
                "happen in unclean incremental builds but is likely a real bug if "
                "encountered otherwise and the MLIR Python API may not function."
            )
            logger.warning(message, exc_info=True)
            return False

        logger.debug("Initializing MLIR with module: %s", module_name)
        if hasattr(m, "register_dialects"):
            logger.debug("Registering dialects from initializer %r", m)
            m.register_dialects(get_dialect_registry())
        if hasattr(m, "context_init_hook"):
            logger.debug("Adding context init hook from %r", m)
            post_init_hooks.append(m.context_init_hook)
        if hasattr(m, "disable_multithreading"):
            if bool(m.disable_multithreading):
                logger.debug("Disabling multi-threading for context")
                disable_multithreading = True
        if hasattr(m, "disable_load_all_available_dialects"):
            disable_load_all_available_dialects = True
        return True

    # If _mlirRegisterEverything is built, then include it as an initializer
    # module.
    init_module = None
    if process_initializer_module("_mlirRegisterEverything"):
        init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)

    # Load all _site_initialize_{i} modules, where 'i' is a number starting
    # at 0.
    for i in itertools.count():
        module_name = f"_site_initialize_{i}"
        if not process_initializer_module(module_name):
            break

    # Keep the native Context class reachable as ir._Context, then shadow
    # ir.Context with the site-specific subclass defined below.
    ir._Context = ir.Context

    class Context(ir._Context):
        def __init__(
            self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
        ):
            super().__init__(*args, **kwargs)
            self.append_dialect_registry(get_dialect_registry())
            for hook in post_init_hooks:
                hook(self)
            if disable_multithreading and thread_pool is not None:
                raise ValueError(
                    "Context constructor has given thread_pool argument, "
                    "but disable_multithreading flag is True. "
                    "Please, set thread_pool argument to None or "
                    "set disable_multithreading flag to False."
                )
            if not disable_multithreading:
                if thread_pool is None:
                    self.enable_multithreading(True)
                else:
                    self.set_thread_pool(thread_pool)
            # Dialect loading: an explicit constructor argument wins over the
            # global load-on-create list, which in turn only applies when an
            # initializer disabled eager loading of all dialects.
            if load_on_create_dialects is not None:
                logger.debug(
                    "Loading all dialects from load_on_create_dialects arg %r",
                    load_on_create_dialects,
                )
                for dialect in load_on_create_dialects:
                    # This triggers loading the dialect into the context.
                    _ = self.dialects[dialect]
            else:
                if disable_load_all_available_dialects:
                    dialects = get_load_on_create_dialects()
                    if dialects:
                        logger.debug(
                            "Loading all dialects from global load_on_create_dialects %r",
                            dialects,
                        )
                        for dialect in dialects:
                            # This triggers loading the dialect into the context.
                            _ = self.dialects[dialect]
                else:
                    logger.debug("Loading all available dialects")
                    self.load_all_available_dialects()
            if init_module:
                logger.debug(
                    "Registering translations from initializer %r", init_module
                )
                init_module.register_llvm_translations(self)

    ir.Context = Context

    # Register containers as Sequences, so they can be used with `match`.
    Sequence.register(ir.BlockArgumentList)
    Sequence.register(ir.BlockList)
    Sequence.register(ir.BlockSuccessors)
    Sequence.register(ir.BlockPredecessors)
    Sequence.register(ir.OperationList)
    Sequence.register(ir.OpOperandList)
    Sequence.register(ir.OpOperands)
    Sequence.register(ir.OpResultList)
    Sequence.register(ir.OpSuccessors)
    Sequence.register(ir.RegionSequence)
    Mapping.register(ir.OpAttributeMap)


_site_initialize()
@@ -0,0 +1,78 @@
"""chlo main python extension"""
from typing import Self
from collections.abc import Sequence
from jaxlib.mlir import ir
# Registers the CHLO dialect with the given context (optionally loading it).
def register_dialect(context: ir.Context, load: bool = True) -> None: ...


class ComparisonDirectionAttr(ir.Attribute):
    """Attribute wrapping a string-valued CHLO comparison direction."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a ComparisonDirection attribute with the given value."""
    @property
    def value(self) -> str: ...


class ComparisonTypeAttr(ir.Attribute):
    """Attribute wrapping a string-valued CHLO comparison type."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a ComparisonType attribute with the given value."""
    @property
    def value(self) -> str: ...
class RaggedDotDimensionNumbers(ir.Attribute):
    """Attribute describing the dimension configuration of a CHLO ragged dot."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], lhs_ragged_dimensions: Sequence[int], rhs_group_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
        """
        Creates a RaggedDotDimensionNumbers attribute with the given dimension configuration.
        """
    @property
    def lhs_batching_dimensions(self) -> list[int]: ...
    @property
    def rhs_batching_dimensions(self) -> list[int]: ...
    @property
    def lhs_contracting_dimensions(self) -> list[int]: ...
    @property
    def rhs_contracting_dimensions(self) -> list[int]: ...
    @property
    def lhs_ragged_dimensions(self) -> list[int]: ...
    @property
    def rhs_group_dimensions(self) -> list[int]: ...
class PrecisionAttr(ir.Attribute):
    """Attribute wrapping a string-valued CHLO precision setting."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a Precision attribute with the given value."""
    @property
    def value(self) -> str: ...
Binary file not shown.
@@ -0,0 +1,22 @@
"""Registers upstream MLIR dialects used by JAX."""
from collections.abc import Callable, Sequence
from jaxlib.mlir import ir
# Registers the upstream MLIR dialects used by JAX into the given registry.
def register_dialects(arg: ir.DialectRegistry, /) -> None: ...
def enter_multi_threaded_execution(arg: ir.Context, /) -> None: ...
def exit_multi_threaded_execution(arg: ir.Context, /) -> None: ...


def inlined_func_call(callee: ir.Operation, args: Sequence[ir.Value], block: ir.Block, loc: ir.Location | None = None) -> list[ir.Value]:
    """
    Makes an inlined call to a function containing a single block with a single return op.
    """


class TracebackToLocationCache:
    """Cache mapping Python tracebacks to MLIR Locations."""
    # NOTE(review): `Traceback` below is not imported anywhere in this stub —
    # presumably a jaxlib/XLA traceback type; confirm and add the import.
    def __init__(self, code_to_filename: Callable, frame_limit: int, context: ir.Context | None = None) -> None: ...
    def get(self, traceback: Traceback, /) -> ir.Location: ...
Binary file not shown.
@@ -0,0 +1,66 @@
"""MLIR Python Native Extension"""
from collections.abc import Callable, Sequence
from typing import TypeVar
from . import (
ir as ir,
passmanager as passmanager,
rewrite as rewrite
)
T = TypeVar("T")
U = TypeVar("U")
class _Globals:
    """Process-global binding state: dialect search modules and
    location-traceback settings."""
    @property
    def dialect_search_modules(self) -> list[str]: ...
    @dialect_search_modules.setter
    def dialect_search_modules(self, arg: Sequence[str], /) -> None: ...
    def append_dialect_search_prefix(self, module_name: str) -> None: ...
    def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
    def _register_dialect_impl(self, dialect_namespace: str, dialect_class: object, *, replace: bool = False) -> None:
        """Testing hook for directly registering a dialect"""
    def _register_operation_impl(self, operation_name: str, operation_class: object, *, replace: bool = False) -> None:
        """Testing hook for directly registering an operation"""
    def loc_tracebacks_enabled(self) -> bool: ...
    def set_loc_tracebacks_enabled(self, arg: bool, /) -> None: ...
    def loc_tracebacks_frame_limit(self) -> int: ...
    def set_loc_tracebacks_frame_limit(self, arg: int | None) -> None: ...
    def register_traceback_file_inclusion(self, arg: str, /) -> None: ...
    def register_traceback_file_exclusion(self, arg: str, /) -> None: ...


# The singleton _Globals instance exposed by the native extension.
globals: _Globals = ...
def register_dialect(dialect_class: type) -> type:
    """Class decorator for registering a custom Dialect wrapper"""


def register_operation(dialect_class: type, *, replace: bool = False) -> Callable[[type[T]], type[T]]:
    """
    Produce a class decorator for registering an Operation class as part of a dialect
    """


def register_op_adaptor(op_class: type, *, replace: bool = False) -> Callable[[type[T]], type[T]]:
    """
    Produce a class decorator for registering an OpAdaptor class for an operation.
    """
# FIX: these annotations referenced `_ir.TypeID`, but this stub imports the
# bindings as `ir` (see the `from . import ir` block above); `_ir` is
# undefined, so type checkers could not resolve the parameter type.
def register_type_caster(typeid: ir.TypeID, *, replace: bool = False) -> Callable[[Callable[[T], U]], Callable[[T], U]]:
    """Register a type caster for casting MLIR types to custom user types."""


def register_value_caster(typeid: ir.TypeID, *, replace: bool = False) -> Callable[[Callable[[T], U]], Callable[[T], U]]:
    """Register a value caster for casting MLIR values to custom user values."""
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,79 @@
"""MLIR Pass Management Bindings"""
from collections.abc import Callable
import enum
from typing import overload
from jaxlib.mlir import ir
class PassDisplayMode(enum.Enum):
    """How pass statistics are displayed (per-pass list vs. merged pipeline)."""
    LIST = 0
    PIPELINE = 1


class ExternalPass:
    """Handle given to Python-defined passes; signals failure to the driver."""
    def signal_pass_failure(self) -> None: ...
class PassManager:
    """Builds and runs a pipeline of MLIR passes anchored on an operation."""
    def __init__(self, anchor_op: str = 'any', context: ir.Context | None = None) -> None:
        """Create a new PassManager for the current (or provided) Context."""
    # C-API interop handles.
    @property
    def _CAPIPtr(self) -> object: ...
    def _CAPICreate(self) -> object: ...
    def _testing_release(self) -> None:
        """Releases (leaks) the backing pass manager (testing)"""
    def enable_ir_printing(self, print_before_all: bool = False, print_after_all: bool = True, print_module_scope: bool = False, print_after_change: bool = False, print_after_failure: bool = False, large_elements_limit: int | None = None, large_resource_limit: int | None = None, enable_debug_info: bool = False, print_generic_op_form: bool = False, tree_printing_dir_path: str | None = None) -> None:
        """Enable IR printing, default as mlir-print-ir-after-all."""
    def enable_verifier(self, enable: bool) -> None:
        """Enable / disable verify-each."""
    def enable_timing(self) -> None:
        """Enable pass timing."""
    def enable_statistics(self, displayMode: PassDisplayMode = PassDisplayMode.PIPELINE) -> None:
        """Enable pass statistics."""
    @staticmethod
    def parse(pipeline: str, context: ir.Context | None = None) -> PassManager:
        """
        Parse a textual pass-pipeline and return a top-level PassManager that can be applied on a Module. Throw a ValueError if the pipeline can't be parsed
        """
    @overload
    def add(self, pipeline: str) -> None:
        """
        Add textual pipeline elements to the pass manager. Throws a ValueError if the pipeline can't be parsed.
        """
    @overload
    def add(self, run: Callable, name: str | None = None, argument: str | None = '', description: str | None = '', op_name: str | None = '') -> None:
        """
        Add a python-defined pass to the current pipeline of the pass manager.

        Args:
            run: A callable with signature ``(op: ir.Operation, pass_: ExternalPass) -> None``.
                Called when the pass executes. It receives the operation to be processed and
                the current ``ExternalPass`` instance.
                Use ``pass_.signal_pass_failure()`` to signal failure.
            name: The name of the pass. Defaults to ``run.__name__``.
            argument: The command-line argument for the pass. Defaults to empty.
            description: The description of the pass. Defaults to empty.
            op_name: The name of the operation this pass operates on.
                It will be a generic operation pass if not specified.
        """
    def run(self, operation: ir._OperationBase) -> None:
        """
        Run the pass manager on the provided operation, raising an MLIRError on failure.
        """
    def __str__(self) -> str:
        """
        Print the textual representation for this PassManager, suitable to be passed to `parse` for round-tripping.
        """
@@ -0,0 +1,241 @@
"""MLIR Rewrite Bindings"""
from collections.abc import Callable, Sequence
import enum
from typing import overload
from jaxlib.mlir import ir
class GreedyRewriteStrictness(enum.Enum):
    """Which operations the greedy rewrite driver may visit."""
    ANY_OP = 0
    EXISTING_AND_NEW_OPS = 1
    EXISTING_OPS = 2


class GreedySimplifyRegionLevel(enum.Enum):
    """Region-simplification aggressiveness for the greedy driver."""
    DISABLED = 0
    NORMAL = 1
    AGGRESSIVE = 2


class DialectConversionFoldingMode(enum.Enum):
    """When folding is attempted during dialect conversion."""
    NEVER = 0
    BEFORE_PATTERNS = 1
    AFTER_PATTERNS = 2
class PatternRewriter:
    """Rewriter handed to pattern callbacks for mutating matched IR."""
    @property
    def ip(self) -> ir.InsertionPoint:
        """The current insertion point of the PatternRewriter."""
    @overload
    def replace_op(self, op: ir._OperationBase, new_op: ir._OperationBase) -> None:
        """Replace an operation with a new operation."""
    @overload
    def replace_op(self, op: ir._OperationBase, values: Sequence[ir.Value]) -> None:
        """Replace an operation with a list of values."""
    def erase_op(self, op: ir._OperationBase) -> None:
        """Erase an operation."""
class RewritePatternSet:
    """A mutable collection of rewrite/conversion patterns."""
    # FIX: the annotation referenced `_ir.Context`, but this stub imports the
    # bindings as `ir` (see the import above); `_ir` is undefined.
    def __init__(self, context: ir.Context | None = None) -> None: ...
    def add(self, root: object, fn: Callable, benefit: int = 1) -> None:
        """
        Add a new rewrite pattern on the specified root operation, using
        the provided callable for matching and rewriting, and assign it
        the given benefit.

        Args:
            root: The root operation to which this pattern applies. This may
                be either an OpView subclass or an operation name.
            fn: The callable to use for matching and rewriting, which takes
                an operation and a pattern rewriter. The match is considered
                successful iff the callable returns a falsy value.
            benefit: The benefit of the pattern, defaulting to 1.
        """
    def add_conversion(self, root: object, fn: Callable, type_converter: TypeConverter, benefit: int = 1) -> None:
        """
        Add a new conversion pattern on the specified root operation,
        using the provided callable for matching and rewriting,
        and assign it the given benefit.

        Args:
            root: The root operation to which this pattern applies.
                This may be either an OpView subclass or an operation name.
            fn: The callable to use for matching and rewriting, which takes an
                operation, its adaptor, the type converter and a pattern
                rewriter. The match is considered successful iff the callable
                returns a falsy value.
            type_converter: The type converter to convert types in the IR.
            benefit: The benefit of the pattern, defaulting to 1.
        """
    def freeze(self) -> FrozenRewritePatternSet:
        """Freeze the pattern set into a frozen one."""
def freeze(self) -> FrozenRewritePatternSet:
"""Freeze the pattern set into a frozen one."""
class ConversionPatternRewriter(PatternRewriter):
    """PatternRewriter specialization used during dialect conversion."""
    def convert_region_types(self, arg0: ir.Region, arg1: TypeConverter, /) -> None: ...
class ConversionTarget:
    """Describes which ops/dialects are legal for a dialect conversion."""
    # FIX: the annotation referenced `_ir.Context`, but this stub imports the
    # bindings as `ir` (see the import above); `_ir` is undefined.
    def __init__(self, context: ir.Context | None = None) -> None: ...
    def add_legal_op(self, *ops) -> None:
        """Mark the given operations as legal."""
    def add_illegal_op(self, *ops) -> None:
        """Mark the given operations as illegal."""
    def add_legal_dialect(self, *dialects) -> None:
        """Mark the given dialects as legal."""
    def add_illegal_dialect(self, *dialects) -> None:
        """Mark the given dialect as illegal."""
class TypeConverter:
    """Converts types during dialect conversion via registered callbacks."""
    def __init__(self) -> None:
        """Create a new TypeConverter."""
    def add_conversion(self, convert: Callable) -> None:
        """Register a type conversion function."""
    def convert_type(self, type: ir.Type) -> ir.Type | None:
        """Convert the given type. Returns None if conversion fails."""


class PDLResultList:
    """Accumulates results produced by PDL rewrite/constraint functions."""
    @overload
    def append(self, arg: ir.Value, /) -> None: ...
    @overload
    def append(self, arg: ir.Operation, /) -> None: ...
    @overload
    def append(self, arg: ir.Type, /) -> None: ...
    @overload
    def append(self, arg: ir.Attribute, /) -> None: ...
class PDLModule:
    """Wraps a PDL pattern module and its registered native functions."""
    @overload
    def __init__(self, module: ir.Module) -> None:
        """Create a PDL module from the given module."""
    # NOTE(review): this second overload has an identical signature to the
    # first — likely generator output for two native overloads; confirm.
    @overload
    def __init__(self, module: ir.Module) -> None: ...
    def freeze(self) -> FrozenRewritePatternSet: ...
    def register_rewrite_function(self, arg0: str, arg1: Callable, /) -> None: ...
    def register_constraint_function(self, arg0: str, arg1: Callable, /) -> None: ...
class GreedyRewriteConfig:
    """Configuration knobs for the greedy pattern rewrite driver."""
    def __init__(self) -> None:
        """Create a greedy rewrite driver config with defaults"""
    @property
    def max_iterations(self) -> int:
        """Maximum number of iterations"""
    @max_iterations.setter
    def max_iterations(self, arg: int, /) -> None: ...
    @property
    def max_num_rewrites(self) -> int:
        """Maximum number of rewrites per iteration"""
    @max_num_rewrites.setter
    def max_num_rewrites(self, arg: int, /) -> None: ...
    @property
    def use_top_down_traversal(self) -> bool:
        """Whether to use top-down traversal"""
    @use_top_down_traversal.setter
    def use_top_down_traversal(self, arg: bool, /) -> None: ...
    @property
    def enable_folding(self) -> bool:
        """Enable or disable folding"""
    @enable_folding.setter
    def enable_folding(self, arg: bool, /) -> None: ...
    @property
    def strictness(self) -> GreedyRewriteStrictness:
        """Rewrite strictness level"""
    @strictness.setter
    def strictness(self, arg: GreedyRewriteStrictness, /) -> None: ...
    @property
    def region_simplification_level(self) -> GreedySimplifyRegionLevel:
        """Region simplification level"""
    @region_simplification_level.setter
    def region_simplification_level(self, arg: GreedySimplifyRegionLevel, /) -> None: ...
    @property
    def enable_constant_cse(self) -> bool:
        """Enable or disable constant CSE"""
    @enable_constant_cse.setter
    def enable_constant_cse(self, arg: bool, /) -> None: ...
class ConversionConfig:
    """Configuration knobs for the dialect conversion driver."""
    def __init__(self) -> None:
        """Create a conversion config with defaults"""
    @property
    def folding_mode(self) -> DialectConversionFoldingMode:
        """folding behavior during dialect conversion"""
    @folding_mode.setter
    def folding_mode(self, arg: DialectConversionFoldingMode, /) -> None: ...
    @property
    def build_materializations(self) -> bool:
        """
        Whether the dialect conversion attempts to build source/target materializations
        """
    @build_materializations.setter
    def build_materializations(self, arg: bool, /) -> None: ...


class FrozenRewritePatternSet:
    """Immutable pattern set, ready to be applied by the drivers below."""
    # C-API interop handles.
    @property
    def _CAPIPtr(self) -> object: ...
    def _CAPICreate(self) -> object: ...
@overload
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet, config: GreedyRewriteConfig | None = None) -> None:
    """
    Applies the given patterns to the given module greedily while folding results.
    """
@overload
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet, config: GreedyRewriteConfig | None = None) -> None:
    """
    Applies the given patterns to the given op greedily while folding results.
    """


def walk_and_apply_patterns(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None:
    """
    Applies the given patterns to the given op by a fast walk-based driver.
    """


def apply_partial_conversion(op: ir._OperationBase, target: ConversionTarget, set: FrozenRewritePatternSet, config: ConversionConfig | None = None) -> None:
    """Applies a partial conversion on the given operation."""


def apply_full_conversion(op: ir._OperationBase, target: ConversionTarget, set: FrozenRewritePatternSet, config: ConversionConfig | None = None) -> None:
    """Applies a full conversion on the given operation."""
@@ -0,0 +1,61 @@
"""MLIR GPU Dialect"""
from typing import ClassVar, Final
from jaxlib.mlir import ir
class AsyncTokenType(ir.Type):
    """GPU async token type wrapper."""
    def __init__(self, cast_from_type: ir.Type) -> None: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    type_name: ClassVar[Final[str]] = ...
    """(arg: object, /) -> str"""
    # FIX: the annotation referenced `_ir.Context`, but this stub imports the
    # bindings as `ir` (see the import above); `_ir` is undefined.
    @staticmethod
    def get(context: ir.Context | None = None) -> AsyncTokenType:
        """Gets an instance of AsyncTokenType in the same context"""
class ObjectAttr(ir.Attribute):
    """gpu.object attribute: a compiled object blob plus metadata."""
    def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
    @property
    def type(self) -> ir.Type: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    attr_name: ClassVar[Final[str]] = ...
    """(arg: object, /) -> str"""
    # FIX: the annotation referenced `_ir.Context`, but this stub imports the
    # bindings as `ir` (see the import above); `_ir` is undefined.
    @staticmethod
    def get(target: ir.Attribute, format: int, object: bytes, properties: ir.DictAttr | None = None, kernels: ir.Attribute | None = None, context: ir.Context | None = None) -> ObjectAttr:
        """Gets a gpu.object from parameters."""
    @property
    def target(self) -> ir.Attribute: ...
    @property
    def format(self) -> int: ...
    @property
    def object(self) -> bytes: ...
    @property
    def properties(self) -> ir.DictAttr | None: ...
    @property
    def kernels(self) -> ir.Attribute | None: ...
@@ -0,0 +1,207 @@
"""MLIR LLVM Dialect"""
from collections.abc import Sequence
from typing import ClassVar, Final
from jaxlib.mlir import ir
class StructType(ir.Type):
    """LLVM struct type (literal, identified, or opaque)."""
    def __init__(self, cast_from_type: ir.Type) -> None: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    type_name: ClassVar[Final[str]] = ...
    """(arg: object, /) -> str"""
    # FIX: annotations below referenced `_ir.Location`/`_ir.Context`, but this
    # stub imports the bindings as `ir` (see the import above); `_ir` is
    # undefined.
    @staticmethod
    def get_literal(elements: Sequence[ir.Type], *, packed: bool = False, loc: ir.Location | None = None, context: ir.Context | None = None) -> StructType: ...
    @staticmethod
    def get_literal_unchecked(elements: Sequence[ir.Type], *, packed: bool = False, context: ir.Context | None = None) -> StructType: ...
    @staticmethod
    def get_identified(name: str, *, context: ir.Context | None = None) -> StructType: ...
    @staticmethod
    def get_opaque(name: str, context: ir.Context | None = None) -> StructType: ...
    def set_body(self, elements: Sequence[ir.Type], *, packed: bool = False) -> None: ...
    @staticmethod
    def new_identified(name: str, elements: Sequence[ir.Type], *, packed: bool = False, context: ir.Context | None = None) -> StructType: ...
    @property
    def name(self) -> str | None: ...
    @property
    def body(self) -> object: ...
    @property
    def packed(self) -> bool: ...
    @property
    def opaque(self) -> bool: ...
class ArrayType(ir.Type):
    """LLVM array type: fixed element count of a single element type."""
    def __init__(self, cast_from_type: ir.Type) -> None: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    type_name: ClassVar[Final[str]] = ...
    """(arg: object, /) -> str"""
    @staticmethod
    def get(element_type: ir.Type, num_elements: int) -> ArrayType: ...
    @property
    def element_type(self) -> ir.Type: ...
    @property
    def num_elements(self) -> int: ...
class PointerType(ir.Type):
    """LLVM pointer type with an optional address space."""
    def __init__(self, cast_from_type: ir.Type) -> None: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    type_name: ClassVar[Final[str]] = ...
    """(arg: object, /) -> str"""
    # FIX: the annotation referenced `_ir.Context`, but this stub imports the
    # bindings as `ir` (see the import above); `_ir` is undefined.
    @staticmethod
    def get(address_space: int | None = None, *, context: ir.Context | None = None) -> PointerType: ...
    @property
    def address_space(self) -> int: ...
class FunctionType(ir.Type):
    """LLVM function type: result type, argument types, and varargs flag."""
    def __init__(self, cast_from_type: ir.Type) -> None: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    type_name: ClassVar[Final[str]] = ...
    """(arg: object, /) -> str"""
    @staticmethod
    def get(result_type: ir.Type, argument_types: Sequence[ir.Type], *, is_var_arg: bool = False) -> FunctionType: ...
    @property
    def return_type(self) -> ir.Type: ...
    @property
    def num_inputs(self) -> int: ...
    @property
    def inputs(self) -> list: ...
    @property
    def is_var_arg(self) -> bool: ...
class MDStringAttr(ir.Attribute):
    """LLVM metadata string attribute."""
    def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
    @property
    def type(self) -> ir.Type: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    # FIX: the annotation referenced `_ir.Context`, but this stub imports the
    # bindings as `ir` (see the import above); `_ir` is undefined.
    @staticmethod
    def get(value: str, *, context: ir.Context | None = None) -> MDStringAttr: ...
    @property
    def value(self) -> str: ...


class MDConstantAttr(ir.Attribute):
    """LLVM metadata constant attribute wrapping another attribute."""
    def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
    @property
    def type(self) -> ir.Type: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    # FIX: `_ir.Context` -> `ir.Context` (same undefined-name defect).
    @staticmethod
    def get(value: ir.Attribute, *, context: ir.Context | None = None) -> MDConstantAttr: ...
    @property
    def value(self) -> ir.Attribute: ...
class MDFuncAttr(ir.Attribute):
    """LLVM metadata function-reference attribute."""
    def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
    @property
    def type(self) -> ir.Type: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    # FIX: the annotation referenced `_ir.Context`, but this stub imports the
    # bindings as `ir` (see the import above); `_ir` is undefined.
    @staticmethod
    def get(name: str, *, context: ir.Context | None = None) -> MDFuncAttr: ...
    @property
    def name(self) -> str: ...


class MDNodeAttr(ir.Attribute):
    """LLVM metadata node attribute: an indexable sequence of operands."""
    def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
    @property
    def type(self) -> ir.Type: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    # FIX: `_ir.Context` -> `ir.Context` (same undefined-name defect).
    @staticmethod
    def get(operands: Sequence[ir.Attribute], *, context: ir.Context | None = None) -> MDNodeAttr: ...
    @property
    def num_operands(self) -> int: ...
    def __getitem__(self, arg: int, /) -> ir.Attribute: ...
    def __len__(self) -> int: ...
def translate_module_to_llvmir(module: ir.Operation) -> str: ...
@@ -0,0 +1,25 @@
"""MLIR NVGPU dialect."""
from typing import ClassVar, Final
from jaxlib.mlir import ir
class TensorMapDescriptorType(ir.Type):
    """NVGPU tensor-map descriptor type."""
    def __init__(self, cast_from_type: ir.Type) -> None: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    type_name: ClassVar[Final[str]] = ...
    """(arg: object, /) -> str"""
    # FIX: the annotation referenced `_ir.Context`, but this stub imports the
    # bindings as `ir` (see the import above); `_ir` is undefined.
    @staticmethod
    def get(tensor_type: ir.Type, swizzle: int, l2promo: int, oob_fill: int, interleave: int, context: ir.Context | None = None) -> TensorMapDescriptorType:
        """Gets an instance of TensorMapDescriptorType in the same context"""
@@ -0,0 +1,97 @@
"""MLIR SparseTensor dialect."""
from collections.abc import Sequence
import enum
from typing import ClassVar, Final
from jaxlib.mlir import ir
class LevelFormat(enum.IntFlag):
    """Bit-flag encoding of sparse-tensor storage level formats."""
    _boundary_: enum.FlagBoundary = enum.FlagBoundary.KEEP
    _flag_mask_: int = 3997696
    _singles_mask_: int = 3997696
    _all_bits_: int = 8388607
    _inverted_: None = None
    # NOTE(review): forward reference — __repr__ is defined below this line.
    # This only type-checks because .pyi stubs are never executed; it would
    # raise NameError in a runtime class body. Confirm against the generator.
    __str__ = __repr__
    def __repr__(self, /):
        """Return repr(self)."""
    dense = 65536
    n_out_of_m = 2097152
    compressed = 262144
    singleton = 524288
    loose_compressed = 1048576


class LevelProperty(enum.Enum):
    """Per-level storage property bits combined with a LevelFormat."""
    non_ordered = 2
    non_unique = 1
    soa = 4
class EncodingAttr(ir.Attribute):
    """sparse_tensor.encoding attribute describing sparse storage layout."""
    def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
    @property
    def type(self) -> ir.Type: ...
    static_typeid: ClassVar[Final[ir.TypeID]] = ...
    """(arg: object, /) -> ir.TypeID"""
    @property
    def typeid(self) -> ir.TypeID: ...
    def __repr__(self) -> str: ...
    attr_name: ClassVar[Final[str]] = ...
    """(arg: object, /) -> str"""
    # FIX: the annotation referenced `_ir.Context`, but this stub imports the
    # bindings as `ir` (see the import above); `_ir` is undefined.
    @staticmethod
    def get(lvl_types: Sequence[int], dim_to_lvl: ir.AffineMap | None, lvl_to_dim: ir.AffineMap | None, pos_width: int, crd_width: int, explicit_val: ir.Attribute | None = None, implicit_val: ir.Attribute | None = None, context: ir.Context | None = None) -> EncodingAttr:
        """Gets a sparse_tensor.encoding from parameters."""
    @staticmethod
    def build_level_type(lvl_fmt: LevelFormat, properties: Sequence[LevelProperty] = [], n: int = 0, m: int = 0) -> int:
        """Builds a sparse_tensor.encoding.level_type from parameters."""
    @property
    def lvl_types(self) -> list[int]: ...
    @property
    def dim_to_lvl(self) -> ir.AffineMap | None: ...
    @property
    def lvl_to_dim(self) -> ir.AffineMap | None: ...
    @property
    def pos_width(self) -> int: ...
    @property
    def crd_width(self) -> int: ...
    @property
    def explicit_val(self) -> ir.Attribute | None: ...
    @property
    def implicit_val(self) -> ir.Attribute | None: ...
    @property
    def structured_n(self) -> int: ...
    @property
    def structured_m(self) -> int: ...
    @property
    def lvl_formats_enum(self) -> list[LevelFormat]: ...
@@ -0,0 +1,309 @@
"""mlir-hlo main python extension"""
from typing import Self
from collections.abc import Sequence
from jaxlib.mlir import ir
# Registers the MHLO dialect (and, separately, its passes) with MLIR.
def register_mhlo_dialect(context: ir.Context, load: bool = True) -> None: ...
def register_mhlo_passes() -> None: ...


class TokenType(ir.Type):
    """MHLO token type."""
    @staticmethod
    def isinstance(other_type: ir.Type) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, context: ir.Context | None = None) -> Self:
        """Creates a Token type."""
class ScatterDimensionNumbers(ir.Attribute):
    """Attribute describing the dimension configuration of an MHLO scatter."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, update_window_dims: Sequence[int], inserted_window_dims: Sequence[int], input_batching_dims: Sequence[int], scatter_indices_batching_dims: Sequence[int], scattered_dims_to_operand_dims: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
        """
        Creates a ScatterDimensionNumbers with the given dimension configuration.
        """
    @property
    def update_window_dims(self) -> list[int]: ...
    @property
    def inserted_window_dims(self) -> list[int]: ...
    @property
    def input_batching_dims(self) -> list[int]: ...
    @property
    def scatter_indices_batching_dims(self) -> list[int]: ...
    @property
    def scattered_dims_to_operand_dims(self) -> list[int]: ...
    @property
    def index_vector_dim(self) -> int: ...
class GatherDimensionNumbers(ir.Attribute):
    """MHLO gather dimension-numbers attribute; properties mirror `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, offset_dims: Sequence[int], collapsed_slice_dims: Sequence[int], operand_batching_dims: Sequence[int], start_indices_batching_dims: Sequence[int], start_index_map: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
        """
        Creates a GatherDimensionNumbers attribute with the given dimension configuration.
        """
    @property
    def offset_dims(self) -> list[int]: ...
    @property
    def collapsed_slice_dims(self) -> list[int]: ...
    @property
    def operand_batching_dims(self) -> list[int]: ...
    @property
    def start_indices_batching_dims(self) -> list[int]: ...
    @property
    def start_index_map(self) -> list[int]: ...
    @property
    def index_vector_dim(self) -> int: ...
class DotDimensionNumbers(ir.Attribute):
    """MHLO dot dimension-numbers attribute; properties mirror `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
        """
        Creates a DotDimensionNumbers attribute with the given dimension configuration.
        """
    @property
    def lhs_batching_dimensions(self) -> list[int]: ...
    @property
    def rhs_batching_dimensions(self) -> list[int]: ...
    @property
    def lhs_contracting_dimensions(self) -> list[int]: ...
    @property
    def rhs_contracting_dimensions(self) -> list[int]: ...
class ConvDimensionNumbers(ir.Attribute):
    """MHLO convolution dimension-numbers attribute; properties mirror `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, input_batch_dimension: int, input_feature_dimension: int, input_spatial_dimensions: Sequence[int], kernel_input_feature_dimension: int, kernel_output_feature_dimension: int, kernel_spatial_dimensions: Sequence[int], output_batch_dimension: int, output_feature_dimension: int, output_spatial_dimensions: Sequence[int], ctx: ir.Context | None = None) -> Self:
        """
        Creates a ConvDimensionNumbers attribute with the given dimension configuration.
        """
    @property
    def input_batch_dimension(self) -> int: ...
    @property
    def input_feature_dimension(self) -> int: ...
    @property
    def input_spatial_dimensions(self) -> list[int]: ...
    @property
    def kernel_input_feature_dimension(self) -> int: ...
    @property
    def kernel_output_feature_dimension(self) -> int: ...
    @property
    def kernel_spatial_dimensions(self) -> list[int]: ...
    @property
    def output_batch_dimension(self) -> int: ...
    @property
    def output_feature_dimension(self) -> int: ...
    @property
    def output_spatial_dimensions(self) -> list[int]: ...
class OutputOperandAlias(ir.Attribute):
    """MHLO output/operand aliasing attribute; properties mirror `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, output_tuple_indices: Sequence[int], operand_index: int, operand_tuple_indices: Sequence[int], ctx: ir.Context | None = None) -> Self:
        """Creates a OutputOperandAlias attribute with the given tuple index."""
    @property
    def output_tuple_indices(self) -> list[int]: ...
    @property
    def operand_index(self) -> int: ...
    @property
    def operand_tuple_indices(self) -> list[int]: ...
class ComparisonDirectionAttr(ir.Attribute):
    """MHLO comparison-direction attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a ComparisonDirection attribute with the given value."""
    @property
    def value(self) -> str: ...
class ComparisonTypeAttr(ir.Attribute):
    """MHLO comparison-type attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a ComparisonType attribute with the given value."""
    @property
    def value(self) -> str: ...
class PrecisionAttr(ir.Attribute):
    """MHLO precision attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a Precision attribute with the given value."""
    @property
    def value(self) -> str: ...
class FftTypeAttr(ir.Attribute):
    """MHLO FFT-type attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a FftType attribute with the given value."""
    @property
    def value(self) -> str: ...
class DequantizeModeAttr(ir.Attribute):
    """MHLO dequantize-mode attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a DequantizeMode attribute with the given value."""
    @property
    def value(self) -> str: ...
class TransposeAttr(ir.Attribute):
    """MHLO transpose attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a Transpose attribute with the given value."""
    @property
    def value(self) -> str: ...
class FusionKindAttr(ir.Attribute):
    """MHLO fusion-kind attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a FusionKind attribute with the given value."""
    @property
    def value(self) -> str: ...
class RngDistributionAttr(ir.Attribute):
    """MHLO RNG-distribution attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a RngDistribution attribute with the given value."""
    @property
    def value(self) -> str: ...
class RngAlgorithmAttr(ir.Attribute):
    """MHLO RNG-algorithm attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a RngAlgorithm attribute with the given value."""
    @property
    def value(self) -> str: ...
class ChannelHandle(ir.Attribute):
    """MHLO channel handle attribute: an integer handle plus a channel type."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    # `type` shadows the builtin, but it is the native extension's parameter name.
    def get(cls, handle: int, type: int, context: ir.Context | None = None) -> Self:
        """Creates a ChannelHandle attribute."""
    @property
    def handle(self) -> int: ...
    # Note: the `type` parameter is surfaced as `channel_type`.
    @property
    def channel_type(self) -> int: ...
class TypeExtensions(ir.Attribute):
    """MHLO type-extensions attribute carrying per-dimension bounds."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, bounds: Sequence[int], context: ir.Context | None = None) -> Self:
        """Creates a TypeExtensions with the given bounds."""
    @property
    def bounds(self) -> list[int]: ...
Binary file not shown.
@@ -0,0 +1,295 @@
# Copyright 2026 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Sequence
import enum
from mlir import ir
# Registers this extension's dialect with *context*; loads it when *load*.
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
# Registers inliner extensions on the given context (positional-only).
def register_inliner_extensions(arg: ir.Context, /) -> None: ...
class BarrierType(ir.Type):
  """Barrier type stub; carries an `orders_tensor_core` flag."""
  @staticmethod
  def isinstance(other_type: ir.Type) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(
      orders_tensor_core: bool = False, context: ir.Context | None = None
  ) -> BarrierType:
    """Creates a BarrierType."""
  @property
  def orders_tensor_core(self) -> bool: ...
class TileTransformAttr(ir.Attribute):
  """Attribute wrapping a tiling (sequence of ints)."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(
      tiling: Sequence[int], context: ir.Context | None = None
  ) -> TileTransformAttr:
    """Creates a TileTransformAttr with the given tiling."""
  @property
  def tiling(self) -> ir.DenseI32ArrayAttr: ...
class TransposeTransformAttr(ir.Attribute):
  """Attribute wrapping a permutation (sequence of ints)."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(
      permutation: Sequence[int], context: ir.Context | None = None
  ) -> TransposeTransformAttr:
    """Creates a TransposeTransformAttr with the given permutation."""
  @property
  def permutation(self) -> ir.DenseI32ArrayAttr: ...
class SwizzleTransformAttr(ir.Attribute):
  """Attribute wrapping an integer swizzle parameter."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(
      swizzle: int, context: ir.Context | None = None
  ) -> SwizzleTransformAttr:
    """Creates a SwizzleTransformAttr with the given swizzle."""
  @property
  def swizzle(self) -> int: ...
class WGSplatFragLayoutAttr(ir.Attribute):
  """Splat fragment-layout attribute parameterized by a shape."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(
      shape: ir.DenseI64ArrayAttr, context: ir.Context | None = None
  ) -> WGSplatFragLayoutAttr:
    """Creates a WGSplatFragLayoutAttr with the given shape."""
  @property
  def shape(self) -> ir.DenseI64ArrayAttr: ...
class WGStridedFragLayoutAttr(ir.Attribute):
  """Strided fragment-layout attribute: a shape plus a vector size."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(
      shape: ir.DenseI64ArrayAttr,
      vector_size: int,
      context: ir.Context | None = None,
  ) -> WGStridedFragLayoutAttr:
    """Creates a WGStridedFragLayoutAttr."""
  @property
  def shape(self) -> ir.DenseI64ArrayAttr: ...
  @property
  def vector_size(self) -> int: ...
class ReplicatedAttr(ir.Attribute):
  """Attribute carrying a replication count `times`."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(times: int, context: ir.Context | None = None) -> ReplicatedAttr:
    """Creates a ReplicatedAttr."""
  @property
  def times(self) -> int: ...
class TiledLayoutAttr(ir.Attribute):
  """Tiled-layout attribute: tiling, warp/lane dims, and a vector dim."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(
      tiling: ir.ArrayAttr,
      warp_dims: ir.ArrayAttr,
      lane_dims: ir.ArrayAttr,
      vector_dim: int,
      context: ir.Context | None = None,
  ) -> TiledLayoutAttr:
    """Creates a TiledLayoutAttr."""
  @property
  def tiling(self) -> ir.ArrayAttr: ...
  @property
  def warp_dims(self) -> ir.ArrayAttr: ...
  @property
  def lane_dims(self) -> ir.ArrayAttr: ...
  @property
  def vector_dim(self) -> int: ...
class CopyPartitionAttrInterface(ir.Attribute):
  """Attribute interface implemented by CopyReplicatedAttr/CopyPartitionedAttr."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
class CopyReplicatedAttr(CopyPartitionAttrInterface):
  """Parameterless 'replicated' copy-partitioning attribute."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(context: ir.Context | None = None) -> CopyReplicatedAttr:
    """Creates a CopyReplicatedAttr."""
class CopyPartitionedAttr(CopyPartitionAttrInterface):
  """Copy-partitioning attribute parameterized by an `axis`."""
  @staticmethod
  def isinstance(other_attribute: ir.Attribute) -> bool: ...
  def __repr__(self) -> str: ...
  @staticmethod
  def get_static_typeid() -> ir.TypeID: ...
  @staticmethod
  def get(axis: int, context: ir.Context | None = None) -> CopyPartitionedAttr:
    """Creates a CopyPartitionedAttr."""
  @property
  def axis(self) -> int: ...
class Tiling:
  """Hashable value type describing a tiling; wraps the tuples in `tiles`."""
  def __init__(self, tiles: Iterable) -> None: ...
  # Shape/stride/index transforms between untiled and tiled coordinates.
  def tile_shape(self, shape: Sequence[int]) -> tuple: ...
  def untile_shape(self, shape: Sequence[int]) -> tuple: ...
  def tile_strides(self, strides: Sequence[int]) -> tuple: ...
  def tile_indices(self, indices: Sequence[int]) -> tuple: ...
  def untile_indices(self, indices: Sequence[int]) -> tuple: ...
  def tile_nested_shape_strides(
      self, shape: Sequence[Sequence[int]], strides: Sequence[Sequence[int]]
  ) -> tuple: ...
  def tile_dimension(self, dim: int) -> tuple: ...
  def remove_dimension(self, dim: int) -> Tiling: ...
  def canonicalize(self) -> Tiling: ...
  @property
  def tiles(self) -> tuple: ...
  def __str__(self) -> str: ...
  def __repr__(self) -> str: ...
  def __eq__(self, other: object) -> bool: ...
  def __hash__(self) -> int: ...
class Replicated:
  """Mutable, hashable wrapper around a replication count `times`."""
  def __init__(self, times: int) -> None: ...
  @property
  def times(self) -> int: ...
  @times.setter
  def times(self, arg: int, /) -> None: ...
  def __repr__(self) -> str: ...
  def __hash__(self) -> int: ...
  def __eq__(self, arg: object, /) -> bool: ...
class TiledLayout:
  """Register layout built from a Tiling plus warp/lane/vector dimensions."""
  def __init__(
      self,
      tiling: Tiling,
      warp_dims: Iterable,
      lane_dims: Iterable,
      vector_dim: int,
      _check_canonical: bool = ...,
  ) -> None: ...
  @property
  def warp_dims(self) -> tuple: ...
  @property
  def lane_dims(self) -> tuple: ...
  @property
  def partitioned_warp_dims(self) -> tuple: ...
  @property
  def partitioned_lane_dims(self) -> tuple: ...
  @property
  def vector_length(self) -> int: ...
  @property
  def vector_dim(self) -> int: ...
  @property
  def tiling(self) -> Tiling: ...
  @property
  def tiled_tiling_shape(self) -> tuple: ...
  @property
  def tiled_tiling_rank(self) -> int: ...
  def warp_indices(self) -> tuple: ...
  def lane_indices(self) -> tuple: ...
  def canonicalize(self) -> TiledLayout: ...
  # Mapping between logical array shapes and per-thread register shapes.
  def registers_shape(self, shape: Sequence[int]) -> tuple: ...
  def registers_element_type(self, t: ir.Type) -> ir.Type: ...
  def shape_from_registers_shape(self, shape: Sequence[int]) -> tuple: ...
  @property
  def base_tile_shape(self) -> tuple: ...
  def remove_dimension(self, dim: int) -> TiledLayout: ...
  def reduce(self, axes: Iterable) -> TiledLayout: ...
  def thread_idxs(self, arg: Sequence[int], /) -> list: ...
  def __str__(self) -> str: ...
  def __repr__(self) -> str: ...
  def __hash__(self) -> int: ...
  # NOTE(review): annotation simplified from `object | None` — None is already
  # an `object`, so the union was redundant.
  def __eq__(self, other: object) -> bool: ...
class Rounding(enum.Enum):
  """Rounding direction accepted by TileTransform."""
  UP = 0
  DOWN = 1
class TileTransform:
  """Transform applying a tiling (with optional Rounding) to values/shapes."""
  def __init__(
      self, tiling: Sequence[int], rounding: Rounding | None = ...
  ) -> None: ...
  def apply(self, arg: object, /) -> ir.Value: ...
  def transform_index(self, arg: Iterable, /) -> tuple: ...
  def transform_shape(self, arg: Sequence[int], /) -> tuple: ...
  def transform_strides(self, arg: Sequence[int], /) -> tuple: ...
class TrivialTransferPlan:
  """Transfer plan with no staggering; identity tile-index transforms."""
  def __init__(self) -> None: ...
  @property
  def tile_index_transforms(self) -> tuple: ...
  def select(self, arg: Iterable, /) -> ir.Value: ...
  def select_if_group(
      self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
  ) -> ir.Value: ...
class StaggeredTransferPlan:
  """Transfer plan staggered by `stagger` along `dim`, gated by `group_pred`."""
  def __init__(
      self, stagger: int, dim: int, size: int, group_pred: object
  ) -> None: ...
  @property
  def stagger(self) -> int: ...
  @property
  def dim(self) -> int: ...
  @property
  def size(self) -> int: ...
  @property
  def group_pred(self) -> ir.Value: ...
  @property
  def tile_index_transforms(self) -> tuple: ...
  def select(self, arg: Iterable, /) -> ir.Value: ...
  def select_if_group(
      self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
  ) -> ir.Value: ...
@@ -0,0 +1,210 @@
"""SDY main Python extension"""
from typing import Self
from collections.abc import Sequence
from jaxlib.mlir import ir
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
class MeshAxisAttr(ir.Attribute):
    """SDY mesh-axis attribute: an axis `name` with its `size`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, name: str, size: int, context: ir.Context | None = None) -> Self:
        """Creates a MeshAxisAttr with the given axis name and size."""
    @property
    def name(self) -> str: ...
    @property
    def size(self) -> int: ...
class MeshAttr(ir.Attribute):
    """SDY mesh attribute: axes plus optional explicit device ids."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, mesh_axes: Sequence[ir.Attribute], device_ids: Sequence[int] = [], context: ir.Context | None = None) -> Self:
        """Creates a MeshAttr with the given mesh axes."""
    @property
    def device_ids(self) -> list[int]: ...
    @property
    def axes(self) -> list[ir.Attribute]: ...
class SubAxisInfoAttr(ir.Attribute):
    """SDY sub-axis info attribute: `pre_size` and `size` of a split axis."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, pre_size: int, size: int, context: ir.Context | None = None) -> Self:
        """Creates a SubAxisInfoAttr with the given pre-size and size."""
    @property
    def pre_size(self) -> int: ...
    @property
    def size(self) -> int: ...
class AxisRefAttr(ir.Attribute):
    """SDY axis reference: axis `name` plus optional SubAxisInfoAttr."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, name: str, sub_axis_info: ir.Attribute | None = None, context: ir.Context | None = None) -> Self:
        """Creates an AxisRefAttr with the given name and SubAxisInfoAttr."""
    @property
    def name(self) -> str: ...
    @property
    def sub_axis_info(self) -> ir.Attribute | None: ...
class DimensionShardingAttr(ir.Attribute):
    """SDY per-dimension sharding: axes, closed flag, optional priority."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, axes: Sequence[ir.Attribute], is_closed: bool, priority: int | None = None, context: ir.Context | None = None) -> Self:
        """
        Creates a DimensionShardingAttr with the given axes, whether it's closed, and priority.
        """
    @property
    def axes(self) -> list[ir.Attribute]: ...
    @property
    def is_closed(self) -> bool: ...
    @property
    def priority(self) -> int | None: ...
class TensorShardingAttr(ir.Attribute):
    """SDY tensor sharding: mesh (inlined or by name) plus per-dim shardings."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, mesh_or_ref: str | ir.Attribute, dimension_shardings: Sequence[ir.Attribute], replicated_axes: Sequence[ir.Attribute] = [], unreduced_axes: Sequence[ir.Attribute] = [], context: ir.Context | None = None) -> Self:
        """
        Creates a TensorShardingAttr with either an inlined mesh or mesh name, dimension shardings, and replicated axes.
        """
    @property
    def mesh_or_ref(self) -> ir.Attribute: ...
    @property
    def dimension_shardings(self) -> list[ir.Attribute]: ...
    @property
    def replicated_axes(self) -> list[ir.Attribute]: ...
    @property
    def unreduced_axes(self) -> list[ir.Attribute]: ...
class TensorShardingPerValueAttr(ir.Attribute):
    """SDY attribute holding one TensorShardingAttr per value."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, shardings: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
        """Creates a TensorShardingPerValueAttr with the tensor shardings."""
    @property
    def shardings(self) -> list[ir.Attribute]: ...
class DimMappingAttr(ir.Attribute):
    """SDY mapping of one dimension to factor indices."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, factor_indices: Sequence[int], context: ir.Context | None = None) -> Self:
        """Creates a DimMappingAttr with the factor indices."""
    @property
    def factor_indices(self) -> list[int]: ...
class TensorMappingAttr(ir.Attribute):
    """SDY tensor mapping: one DimMappingAttr per dimension (`rank` of them)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, dim_mappings: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
        """Creates a TensorMappingAttr with the dim mappings."""
    @property
    def dim_mappings(self) -> list[ir.Attribute]: ...
    @property
    def rank(self) -> int: ...
class OpShardingRuleAttr(ir.Attribute):
    """SDY op sharding rule: factor sizes plus operand/result mappings."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, factor_sizes: Sequence[int], operand_mappings: Sequence[ir.Attribute], result_mappings: Sequence[ir.Attribute], reduction_factors: Sequence[int] = [], need_replication_factors: Sequence[int] = [], permutation_factors: Sequence[int] = [], blocked_propagation_factors: Sequence[int] = [], is_custom: bool = False, context: ir.Context | None = None) -> Self:
        """
        Creates a OpShardingRuleAttr with the factor sizes and mappings for operands and results.
        """
    @property
    def is_custom(self) -> bool: ...
    @property
    def factor_sizes(self) -> list[int]: ...
    @property
    def operand_mappings(self) -> list[ir.Attribute]: ...
    @property
    def result_mappings(self) -> list[ir.Attribute]: ...
    @property
    def reduction_factors(self) -> list[int]: ...
    @property
    def need_replication_factors(self) -> list[int]: ...
    @property
    def permutation_factors(self) -> list[int]: ...
    @property
    def blocked_propagation_factors(self) -> list[int]: ...
class ManualAxesAttr(ir.Attribute):
    """SDY manual-axes attribute; indexable/sized container of axis names."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, manual_axes: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
        """Creates a ManualAxesAttr with the given manual axes."""
    def __getitem__(self, arg: int, /) -> str: ...
    def __len__(self) -> int: ...
Binary file not shown.
@@ -0,0 +1,23 @@
"""MPMD main Python extension"""
from typing import Self
from jaxlib.mlir import ir
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
class UserOriginAttr(ir.Attribute):
    """MPMD user-origin attribute: user name plus a transpose count."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    # Note: the `name` parameter is surfaced as the `user_name` property.
    def get(cls, name: str, transpose_count: int, context: ir.Context | None = None) -> Self:
        """Creates a UserOriginAttr with the given user name and transpose count."""
    @property
    def user_name(self) -> str: ...
    @property
    def transpose_count(self) -> int: ...
Binary file not shown.
@@ -0,0 +1,407 @@
"""stablehlo main python extension"""
from collections.abc import Sequence
import enum
from typing import overload, Self
from jaxlib.mlir import ir
# Registers the StableHLO dialect with *context*; loads it when *load*.
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
# Registers the StableHLO interpreter dialect with *context*.
def register_interpreter_dialect(context: ir.Context, load: bool = True) -> None: ...
# Registers the StableHLO passes (stub; side effects in the native extension).
def register_stablehlo_passes() -> None: ...
class TokenType(ir.Type):
    """StableHLO token type stub."""
    @staticmethod
    def isinstance(other_type: ir.Type) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, context: ir.Context | None = None) -> Self:
        """Creates a Token type."""
class FutureType(ir.Type):
    """StableHLO future type parameterized by a list of element types."""
    @staticmethod
    def isinstance(other_type: ir.Type) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, types: Sequence[ir.Type], context: ir.Context | None = None) -> Self:
        """Creates a Future type."""
    @property
    def types(self) -> list[ir.Type]: ...
class ScatterDimensionNumbers(ir.Attribute):
    """StableHLO scatter dimension-numbers attribute; properties mirror `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, update_window_dims: Sequence[int], inserted_window_dims: Sequence[int], input_batching_dims: Sequence[int], scatter_indices_batching_dims: Sequence[int], scattered_dims_to_operand_dims: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
        """
        Creates a ScatterDimensionNumbers with the given dimension configuration.
        """
    @property
    def update_window_dims(self) -> list[int]: ...
    @property
    def inserted_window_dims(self) -> list[int]: ...
    @property
    def input_batching_dims(self) -> list[int]: ...
    @property
    def scatter_indices_batching_dims(self) -> list[int]: ...
    @property
    def scattered_dims_to_operand_dims(self) -> list[int]: ...
    @property
    def index_vector_dim(self) -> int: ...
class GatherDimensionNumbers(ir.Attribute):
    """StableHLO gather dimension-numbers attribute; properties mirror `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, offset_dims: Sequence[int], collapsed_slice_dims: Sequence[int], operand_batching_dims: Sequence[int], start_indices_batching_dims: Sequence[int], start_index_map: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
        """
        Creates a GatherDimensionNumbers attribute with the given dimension configuration.
        """
    @property
    def offset_dims(self) -> list[int]: ...
    @property
    def collapsed_slice_dims(self) -> list[int]: ...
    @property
    def operand_batching_dims(self) -> list[int]: ...
    @property
    def start_indices_batching_dims(self) -> list[int]: ...
    @property
    def start_index_map(self) -> list[int]: ...
    @property
    def index_vector_dim(self) -> int: ...
class DotAlgorithm(ir.Attribute):
    """StableHLO dot-algorithm attribute; properties mirror `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, lhs_precision_type: ir.Type, rhs_precision_type: ir.Type, accumulation_type: ir.Type, lhs_component_count: int, rhs_component_count: int, num_primitive_operations: int, allow_imprecise_accumulation: bool, ctx: ir.Context | None = None) -> Self:
        """
        Creates a DotAlgorithm attribute with the given dimension configuration.
        """
    @property
    def lhs_precision_type(self) -> ir.Type: ...
    @property
    def rhs_precision_type(self) -> ir.Type: ...
    @property
    def accumulation_type(self) -> ir.Type: ...
    @property
    def lhs_component_count(self) -> int: ...
    @property
    def rhs_component_count(self) -> int: ...
    @property
    def num_primitive_operations(self) -> int: ...
    @property
    def allow_imprecise_accumulation(self) -> bool: ...
class DotDimensionNumbers(ir.Attribute):
    """StableHLO dot dimension-numbers attribute; properties mirror `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
        """
        Creates a DotDimensionNumbers attribute with the given dimension configuration.
        """
    @property
    def lhs_batching_dimensions(self) -> list[int]: ...
    @property
    def rhs_batching_dimensions(self) -> list[int]: ...
    @property
    def lhs_contracting_dimensions(self) -> list[int]: ...
    @property
    def rhs_contracting_dimensions(self) -> list[int]: ...
class ConvDimensionNumbers(ir.Attribute):
    """StableHLO convolution dimension-numbers attribute; mirrors `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, input_batch_dimension: int, input_feature_dimension: int, input_spatial_dimensions: Sequence[int], kernel_input_feature_dimension: int, kernel_output_feature_dimension: int, kernel_spatial_dimensions: Sequence[int], output_batch_dimension: int, output_feature_dimension: int, output_spatial_dimensions: Sequence[int], ctx: ir.Context | None = None) -> Self:
        """
        Creates a ConvDimensionNumbers attribute with the given dimension configuration.
        """
    @property
    def input_batch_dimension(self) -> int: ...
    @property
    def input_feature_dimension(self) -> int: ...
    @property
    def input_spatial_dimensions(self) -> list[int]: ...
    @property
    def kernel_input_feature_dimension(self) -> int: ...
    @property
    def kernel_output_feature_dimension(self) -> int: ...
    @property
    def kernel_spatial_dimensions(self) -> list[int]: ...
    @property
    def output_batch_dimension(self) -> int: ...
    @property
    def output_feature_dimension(self) -> int: ...
    @property
    def output_spatial_dimensions(self) -> list[int]: ...
class OutputOperandAlias(ir.Attribute):
    """StableHLO output/operand aliasing attribute; properties mirror `get`."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, output_tuple_indices: Sequence[int], operand_index: int, operand_tuple_indices: Sequence[int], ctx: ir.Context | None = None) -> Self:
        """Creates a OutputOperandAlias attribute with the given tuple index."""
    @property
    def output_tuple_indices(self) -> list[int]: ...
    @property
    def operand_index(self) -> int: ...
    @property
    def operand_tuple_indices(self) -> list[int]: ...
class ComparisonDirectionAttr(ir.Attribute):
    """StableHLO comparison-direction attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a ComparisonDirection attribute with the given value."""
    @property
    def value(self) -> str: ...
class ComparisonTypeAttr(ir.Attribute):
    """StableHLO comparison-type attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a ComparisonType attribute with the given value."""
    @property
    def value(self) -> str: ...
class PrecisionAttr(ir.Attribute):
    """StableHLO precision attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a Precision attribute with the given value."""
    @property
    def value(self) -> str: ...
class FftTypeAttr(ir.Attribute):
    """StableHLO FFT-type attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a FftType attribute with the given value."""
    @property
    def value(self) -> str: ...
class TransposeAttr(ir.Attribute):
    """StableHLO transpose attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a Transpose attribute with the given value."""
    @property
    def value(self) -> str: ...
class RngDistributionAttr(ir.Attribute):
    """StableHLO RNG-distribution attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a RngDistribution attribute with the given value."""
    @property
    def value(self) -> str: ...
class RngAlgorithmAttr(ir.Attribute):
    """StableHLO RNG-algorithm attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a RngAlgorithm attribute with the given value."""
    @property
    def value(self) -> str: ...
class ChannelHandle(ir.Attribute):
    """StableHLO channel handle: an integer handle plus a channel type."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    # `type` shadows the builtin, but it is the native extension's parameter name.
    def get(cls, handle: int, type: int, context: ir.Context | None = None) -> Self:
        """Creates a ChannelHandle attribute."""
    @property
    def handle(self) -> int: ...
    # Note: the `type` parameter is surfaced as `channel_type`.
    @property
    def channel_type(self) -> int: ...
class TypeExtensions(ir.Attribute):
    """StableHLO type-extensions attribute carrying per-dimension bounds."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, bounds: Sequence[int], context: ir.Context | None = None) -> Self:
        """Creates a TypeExtensions with the given bounds."""
    @property
    def bounds(self) -> list[int]: ...
class ResultAccuracyAttr(ir.Attribute):
    """StableHLO result-accuracy attribute: atol/rtol/ulps plus a mode string."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, atol: float, rtol: float, ulps: int, mode: str, context: ir.Context | None = None) -> Self:
        """Creates a ResultAccuracyAttr with the given values."""
    @property
    def atol(self) -> float: ...
    @property
    def rtol(self) -> float: ...
    @property
    def ulps(self) -> int: ...
    @property
    def mode(self) -> str: ...
class ResultAccuracyModeAttr(ir.Attribute):
    """StableHLO result-accuracy mode attribute (string-valued)."""
    @staticmethod
    def isinstance(other_attribute: ir.Attribute) -> bool: ...
    def __repr__(self) -> str: ...
    @classmethod
    def get(cls, value: str, context: ir.Context | None = None) -> Self:
        """Creates a ResultAccuracyModeAttr with the given values."""
    @property
    def value(self) -> str: ...
# Version query helpers for the serialization API.
def get_api_version() -> int: ...
# Returns the smaller of the two version strings (presumably the older
# release under the library's version ordering — confirm against bindings).
def get_smaller_version(version1: str, version2: str) -> str: ...
def get_current_version() -> str: ...
def get_minimum_version() -> str: ...
# Serializes module text (accepted as str or bytes) into a portable
# artifact targeting `target_version`.
@overload
def serialize_portable_artifact_str(module_str: str, target_version: str) -> bytes: ...
@overload
def serialize_portable_artifact_str(module_str: bytes, target_version: str) -> bytes: ...
# Deserializes a portable artifact (str or bytes) back into bytes.
@overload
def deserialize_portable_artifact_str(artifact_str: str) -> bytes: ...
@overload
def deserialize_portable_artifact_str(artifact_str: bytes) -> bytes: ...
class StablehloCompatibilityRequirement(enum.Enum):
    """Backward-compatibility window selector.

    Used by `get_version_from_compatibility_requirement` to map a
    requirement to a concrete version string.
    """
    NONE = 0
    WEEK_4 = 1
    WEEK_12 = 2
    MAX = 3
# Maps a compatibility requirement to the concrete version string that satisfies it.
def get_version_from_compatibility_requirement(requirement: StablehloCompatibilityRequirement) -> str: ...
# Serializes `module` into a portable artifact for the `target` version.
def serialize_portable_artifact(module: ir.Module, target: str, allow_other_dialects: bool = False) -> bytes: ...
# Deserializes a portable artifact (str or bytes) into an ir.Module in `context`.
@overload
def deserialize_portable_artifact(context: ir.Context, artifact: str) -> ir.Module: ...
@overload
def deserialize_portable_artifact(context: ir.Context, artifact: bytes) -> ir.Module: ...
# Evaluates `module` with attribute arguments and returns result attributes
# (presumably via the reference interpreter — confirm against the bindings).
def eval_module(module: ir.Module, args: Sequence[ir.Attribute], probe_instrumentation_dir: str = '') -> list[ir.Attribute]: ...
@@ -0,0 +1,32 @@
# Copyright 2026 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mlir import ir
# Registers this dialect with (and optionally loads it into) the given MLIR context.
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
# NOTE(review): the `private_` prefix suggests internal-use bindings; the
# tuple presumably encodes two communication-related booleans — confirm
# against the native extension module.
def private_has_communication(arg: ir.Operation, /) -> tuple[bool, bool]: ...
# Sets attribute `arg2` to `arg3` on argument index `arg1` of operation
# `arg0` — TODO confirm positional argument meanings against the extension.
def private_set_arg_attr(
    arg0: ir.Operation, arg1: int, arg2: str, arg3: ir.Attribute, /
) -> None: ...
class Float8EXMYType(ir.Type):
    """Typing stub: MLIR type wrapping an underlying float8 (ExMy) type."""
    @staticmethod
    def isinstance(other_type: ir.Type) -> bool: ...
    def __repr__(self) -> str: ...
    @staticmethod
    def get(
        exmy_type: ir.Type | None = None, ctx: ir.Context | None = None
    ) -> Float8EXMYType: ...
    # The wrapped float8 type.
    @property
    def underlying_type(self) -> ir.Type: ...
Binary file not shown.
@@ -0,0 +1,36 @@
# Copyright 2026 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jaxlib.mlir import ir
# Registers this dialect with (and optionally loads it into) the given MLIR context.
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
class PointerType(ir.Type):
    """Typing stub: MLIR pointer type parameterized by pointee type and address space."""
    @staticmethod
    def isinstance(other_type: ir.Type) -> bool: ...
    def __repr__(self) -> str: ...
    @staticmethod
    def get_static_typeid() -> ir.TypeID: ...
    @staticmethod
    def get(pointee_type: ir.Type, address_space: int) -> PointerType:
        """Creates a PointerType type."""
    @property
    def pointee_type(self) -> ir.Type: ...
    @property
    def address_space(self) -> int: ...
# Infers a result encoding attribute for a reduce operation from an input
# encoding attribute and an integer (presumably the reduced axis — TODO
# confirm against the extension); returns None when nothing can be inferred.
def infer_reduce_op_encoding(
    arg0: ir.Attribute, arg1: int, /
) -> ir.Attribute | None: ...
@@ -0,0 +1,285 @@
# Autogenerated by mlir-tblgen; don't manually edit.
from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir
class CmpFPredicate(IntEnum):
    """allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15"""
    AlwaysFalse = 0
    OEQ = 1
    OGT = 2
    OGE = 3
    OLT = 4
    OLE = 5
    ONE = 6
    ORD = 7
    UEQ = 8
    UGT = 9
    UGE = 10
    ULT = 11
    ULE = 12
    UNE = 13
    UNO = 14
    AlwaysTrue = 15

    def __str__(self):
        """Return the MLIR assembly spelling of this predicate."""
        # The spelling is the lowercased member name, except for the two
        # constant predicates, which print as "false"/"true".
        special = {
            CmpFPredicate.AlwaysFalse: "false",
            CmpFPredicate.AlwaysTrue: "true",
        }
        if isinstance(self, CmpFPredicate):
            return special.get(self, self.name.lower())
        raise ValueError("Unknown CmpFPredicate enum entry.")
@register_attribute_builder("Arith_CmpFPredicateAttr", allow_existing=True)
def _arith_cmpfpredicateattr(x, context):
    # Encode the predicate as an i64 IntegerAttr (the ODS storage type).
    i64 = _ods_ir.IntegerType.get_signless(64, context=context)
    return _ods_ir.IntegerAttr.get(i64, int(x))
class CmpIPredicate(IntEnum):
    """allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9"""
    eq = 0
    ne = 1
    slt = 2
    sle = 3
    sgt = 4
    sge = 5
    ult = 6
    ule = 7
    ugt = 8
    uge = 9

    def __str__(self):
        """Return the MLIR assembly spelling of this predicate."""
        # Member names coincide exactly with their assembly spellings.
        if isinstance(self, CmpIPredicate):
            return self.name
        raise ValueError("Unknown CmpIPredicate enum entry.")
@register_attribute_builder("Arith_CmpIPredicateAttr", allow_existing=True)
def _arith_cmpipredicateattr(x, context):
    # Encode the predicate as an i64 IntegerAttr (the ODS storage type).
    i64 = _ods_ir.IntegerType.get_signless(64, context=context)
    return _ods_ir.IntegerAttr.get(i64, int(x))
class IntegerOverflowFlags(IntFlag):
    """Integer overflow arith flags"""
    none = 0
    nsw = 1
    nuw = 2
    def __iter__(self):
        # Yields the members of this class that are strictly contained in this
        # value (`self & case is case` but not `self is case`).
        # NOTE(review): relies on Flag *class* iteration, whose semantics
        # changed in Python 3.11 (zero/alias members are no longer yielded);
        # confirm the intended behavior on the supported Python versions.
        return iter([case for case in type(self) if (self & case) is case and self is not case])
    def __len__(self):
        # Popcount of the underlying int: how many single flags are combined.
        return bin(self).count("1")
    def __str__(self):
        # Combined values print as a comma+space separated list of their parts;
        # single flags print their bare name.
        if len(self) > 1:
            return ", ".join(map(str, self))
        if self is IntegerOverflowFlags.none:
            return "none"
        if self is IntegerOverflowFlags.nsw:
            return "nsw"
        if self is IntegerOverflowFlags.nuw:
            return "nuw"
        raise ValueError("Unknown IntegerOverflowFlags enum entry.")
@register_attribute_builder("Arith_IntegerOverflowFlags", allow_existing=True)
def _arith_integeroverflowflags(x, context):
    # Encode the flag set as an i32 IntegerAttr (bitmask storage).
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))
class RoundingMode(IntEnum):
    """Floating point rounding mode"""
    to_nearest_even = 0
    downward = 1
    upward = 2
    toward_zero = 3
    to_nearest_away = 4

    def __str__(self):
        """Return the MLIR assembly spelling of this rounding mode."""
        # Member names coincide exactly with their assembly spellings.
        if isinstance(self, RoundingMode):
            return self.name
        raise ValueError("Unknown RoundingMode enum entry.")
@register_attribute_builder("Arith_RoundingModeAttr", allow_existing=True)
def _arith_roundingmodeattr(x, context):
    # Encode the rounding mode as an i32 IntegerAttr (the ODS storage type).
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))
class AtomicRMWKind(IntEnum):
    """allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15"""
    addf = 0
    addi = 1
    andi = 2
    assign = 3
    maximumf = 4
    maxnumf = 5
    maxs = 6
    maxu = 7
    minimumf = 8
    minnumf = 9
    mins = 10
    minu = 11
    mulf = 12
    muli = 13
    ori = 14
    xori = 15

    def __str__(self):
        """Return the MLIR assembly spelling of this atomic-RMW kind."""
        # Member names coincide exactly with their assembly spellings.
        if isinstance(self, AtomicRMWKind):
            return self.name
        raise ValueError("Unknown AtomicRMWKind enum entry.")
@register_attribute_builder("AtomicRMWKindAttr", allow_existing=True)
def _atomicrmwkindattr(x, context):
    # Encode the kind as an i64 IntegerAttr (the ODS storage type).
    i64 = _ods_ir.IntegerType.get_signless(64, context=context)
    return _ods_ir.IntegerAttr.get(i64, int(x))
class FastMathFlags(IntFlag):
    """Floating point fast math flags"""
    none = 0
    reassoc = 1
    nnan = 2
    ninf = 4
    nsz = 8
    arcp = 16
    contract = 32
    afn = 64
    # `fast` is the union of all individual flags (0x7f).
    fast = 127
    def __iter__(self):
        # Yields the members of this class strictly contained in this value.
        # NOTE(review): relies on Flag *class* iteration, whose semantics
        # changed in Python 3.11 (zero/alias members are no longer yielded);
        # confirm the intended behavior on the supported Python versions.
        return iter([case for case in type(self) if (self & case) is case and self is not case])
    def __len__(self):
        # Popcount of the underlying int: how many single flags are combined.
        return bin(self).count("1")
    def __str__(self):
        # Combined values print comma-separated (no space, unlike
        # IntegerOverflowFlags in this file); single flags print their name.
        if len(self) > 1:
            return ",".join(map(str, self))
        if self is FastMathFlags.none:
            return "none"
        if self is FastMathFlags.reassoc:
            return "reassoc"
        if self is FastMathFlags.nnan:
            return "nnan"
        if self is FastMathFlags.ninf:
            return "ninf"
        if self is FastMathFlags.nsz:
            return "nsz"
        if self is FastMathFlags.arcp:
            return "arcp"
        if self is FastMathFlags.contract:
            return "contract"
        if self is FastMathFlags.afn:
            return "afn"
        if self is FastMathFlags.fast:
            return "fast"
        raise ValueError("Unknown FastMathFlags enum entry.")
@register_attribute_builder("FastMathFlags", allow_existing=True)
def _fastmathflags(x, context):
    # Encode the flag set as an i32 IntegerAttr (bitmask storage).
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))
@register_attribute_builder("arith.Arith_FastMathAttr")
def _arith_fastmathattr(x, context):
    # Build the attribute by round-tripping through its textual form.
    text = "#arith.fastmath<{}>".format(x)
    return _ods_ir.Attribute.parse(text, context=context)
@register_attribute_builder("arith.Arith_IntegerOverflowAttr")
def _arith_integeroverflowattr(x, context):
    # Build the attribute by round-tripping through its textual form.
    text = "#arith.overflow<{}>".format(x)
    return _ods_ir.Attribute.parse(text, context=context)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,199 @@
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
import sys as _sys
if _sys.version_info >= (3, 12):
from collections.abc import Buffer as _Buffer # pytype: disable=not-supported-yet
else:
try:
from typing_extensions import Buffer as _Buffer
except ImportError:
_Buffer = _Any
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
    # Dialect handle the op classes below are registered against.
    DIALECT_NAMESPACE = "builtin"
@_ods_cext.register_operation(_Dialect)
class ModuleOp(_ods_ir.OpView):
    r"""
    A `module` represents a top-level container operation. It contains a single
    [graph region](../LangRef.md#control-flow-and-ssacfg-regions) containing a single block
    which can contain any operations and does not have a terminator. Operations
    within this region cannot implicitly capture values defined outside the module,
    i.e. Modules are [IsolatedFromAbove](../Traits#isolatedfromabove). Modules have
    an optional [symbol name](../SymbolsAndSymbolTables.md) which can be used to refer
    to them in operations.

    Example:

    ```mlir
    module {
      func.func @foo()
    }
    ```
    """

    OPERATION_NAME = "builtin.module"

    # One region; the trailing True relates to region variadicity in the
    # generated OpView protocol (see _ods_ir.OpView).
    _ODS_REGIONS = (1, True)

    def __init__(self, *, sym_name: _Optional[_Union[str, _ods_ir.StringAttr]] = None, sym_visibility: _Optional[_Union[str, _ods_ir.StringAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        operands = []
        attributes = {}
        regions = None
        _ods_context = _ods_get_default_loc_context(loc)
        # Each optional attribute accepts either a ready-made Attribute or a
        # plain Python value converted through the registered AttrBuilder.
        if sym_name is not None: attributes["sym_name"] = (sym_name if (
            isinstance(sym_name, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('SymbolNameAttr')) else
              _ods_ir.AttrBuilder.get('SymbolNameAttr')(sym_name, context=_ods_context))
        if sym_visibility is not None: attributes["sym_visibility"] = (sym_visibility if (
            isinstance(sym_visibility, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('StrAttr')) else
              _ods_ir.AttrBuilder.get('StrAttr')(sym_visibility, context=_ods_context))
        results = []
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

    # Optional symbol name; None when the attribute is absent.
    @builtins.property
    def sym_name(self) -> _Optional[_ods_ir.StringAttr]:
        if "sym_name" not in self.operation.attributes:
            return None
        return self.operation.attributes["sym_name"]

    @sym_name.setter
    def sym_name(self, value: _Optional[_ods_ir.StringAttr]):
        # Assigning None removes the attribute.
        if value is not None:
            self.operation.attributes["sym_name"] = value
        elif "sym_name" in self.operation.attributes:
            del self.operation.attributes["sym_name"]

    @sym_name.deleter
    def sym_name(self):
        del self.operation.attributes["sym_name"]

    # Optional symbol visibility; None when the attribute is absent.
    @builtins.property
    def sym_visibility(self) -> _Optional[_ods_ir.StringAttr]:
        if "sym_visibility" not in self.operation.attributes:
            return None
        return self.operation.attributes["sym_visibility"]

    @sym_visibility.setter
    def sym_visibility(self, value: _Optional[_ods_ir.StringAttr]):
        # Assigning None removes the attribute.
        if value is not None:
            self.operation.attributes["sym_visibility"] = value
        elif "sym_visibility" in self.operation.attributes:
            del self.operation.attributes["sym_visibility"]

    @sym_visibility.deleter
    def sym_visibility(self):
        del self.operation.attributes["sym_visibility"]

    # The module's single body region.
    @builtins.property
    def bodyRegion(self) -> _ods_ir.Region:
        return self.regions[0]
@_ods_cext.register_op_adaptor(ModuleOp)
class ModuleOpAdaptor(_ods_ir.OpAdaptor):
    """Read-only adaptor exposing `builtin.module` attributes."""

    OPERATION_NAME = "builtin.module"

    @builtins.property
    def sym_name(self) -> _Optional[_ods_ir.StringAttr]:
        if "sym_name" not in self.attributes:
            return None
        return self.attributes["sym_name"]

    @builtins.property
    def sym_visibility(self) -> _Optional[_ods_ir.StringAttr]:
        if "sym_visibility" not in self.attributes:
            return None
        return self.attributes["sym_visibility"]
def module(*, sym_name: _Optional[_Union[str, _ods_ir.StringAttr]] = None, sym_visibility: _Optional[_Union[str, _ods_ir.StringAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> ModuleOp:
    """Convenience builder for a `builtin.module` operation."""
    op = ModuleOp(sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip)
    return op
@_ods_cext.register_operation(_Dialect)
class UnrealizedConversionCastOp(_ods_ir.OpView):
    r"""
    An `unrealized_conversion_cast` operation represents an unrealized
    conversion from one set of types to another, that is used to enable the
    inter-mixing of different type systems. This operation should not be
    attributed any special representational or execution semantics, and is
    generally only intended to be used to satisfy the temporary intermixing of
    type systems during the conversion of one type system to another.

    This operation may produce results of arity 1-N, and accept as input
    operands of arity 0-N.

    Example:

    ```mlir
    // An unrealized 0-1 conversion. These types of conversions are useful in
    // cases where a type is removed from the type system, but not all uses have
    // been converted. For example, imagine we have a tuple type that is
    // expanded to its element types. If only some uses of an empty tuple type
    // instance are converted we still need an instance of the tuple type, but
    // have no inputs to the unrealized conversion.
    %result = unrealized_conversion_cast to !bar.tuple_type<>

    // An unrealized 1-1 conversion.
    %result1 = unrealized_conversion_cast %operand : !foo.type to !bar.lowered_type

    // An unrealized 1-N conversion.
    %results2:2 = unrealized_conversion_cast %tuple_operand : !foo.tuple_type<!foo.type, !foo.type> to !foo.type, !foo.type

    // An unrealized N-1 conversion.
    %result3 = unrealized_conversion_cast %operand, %operand : !foo.type, !foo.type to !bar.tuple_type<!foo.type, !foo.type>
    ```
    """

    OPERATION_NAME = "builtin.unrealized_conversion_cast"

    _ODS_REGIONS = (0, True)

    def __init__(self, outputs: _Sequence[_ods_ir.Type], inputs: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        operands = []
        attributes = {}
        regions = None
        # Accept Values or Operations (their results) as the variadic inputs.
        operands.extend(_get_op_results_or_values(inputs))
        _ods_context = _ods_get_default_loc_context(loc)
        results = []
        results.extend(outputs)
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

    # Variadic input operands (the whole operand list).
    @builtins.property
    def inputs(self) -> _ods_ir.OpOperandList:
        _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
        return self.operation.operands[0:0 + _ods_variadic_group_length]

    # Variadic results (the whole result list).
    @builtins.property
    def outputs(self) -> _ods_ir.OpResultList:
        _ods_variadic_group_length = len(self.operation.results) - 1 + 1
        return self.operation.results[0:0 + _ods_variadic_group_length]
@_ods_cext.register_op_adaptor(UnrealizedConversionCastOp)
class UnrealizedConversionCastOpAdaptor(_ods_ir.OpAdaptor):
    """Read-only adaptor exposing `builtin.unrealized_conversion_cast` operands."""

    OPERATION_NAME = "builtin.unrealized_conversion_cast"

    @builtins.property
    def inputs(self) -> _ods_ir.OpOperandList:
        _ods_variadic_group_length = len(self.operands) - 1 + 1
        return self.operands[0:0 + _ods_variadic_group_length]
def unrealized_conversion_cast(outputs: _Sequence[_ods_ir.Type], inputs: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, UnrealizedConversionCastOp]:
    """Convenience builder: returns the result list, the sole result, or the op itself."""
    op = UnrealizedConversionCastOp(outputs=outputs, inputs=inputs, loc=loc, ip=ip)
    results = op.results
    if len(results) > 1:
        return results
    if len(results) == 1:
        return results[0]
    return op
@@ -0,0 +1,400 @@
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
import sys as _sys
if _sys.version_info >= (3, 12):
from collections.abc import Buffer as _Buffer # pytype: disable=not-supported-yet
else:
try:
from typing_extensions import Buffer as _Buffer
except ImportError:
_Buffer = _Any
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
    # Dialect handle the op classes below are registered against.
    DIALECT_NAMESPACE = "cf"
@_ods_cext.register_operation(_Dialect)
class AssertOp(_ods_ir.OpView):
    r"""
    Assert operation at runtime with single boolean operand and an error
    message attribute.

    If the argument is `true` this operation has no effect. Otherwise, the
    program execution will abort. The provided error message may be used by a
    runtime to propagate the error to the user.

    Example:

    ```mlir
    cf.assert %b, "Expected ... to be true"
    ```
    """

    OPERATION_NAME = "cf.assert"

    _ODS_REGIONS = (0, True)

    def __init__(self, arg: _ods_ir.Value[_ods_ir.IntegerType], msg: _Union[str, _ods_ir.StringAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        operands = []
        attributes = {}
        regions = None
        operands.append(arg)
        _ods_context = _ods_get_default_loc_context(loc)
        # Accept a ready-made Attribute or a plain str converted through the
        # registered 'StrAttr' builder.
        attributes["msg"] = (msg if (
            isinstance(msg, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('StrAttr')) else
              _ods_ir.AttrBuilder.get('StrAttr')(msg, context=_ods_context))
        results = []
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

    # The single boolean (i1) operand being asserted.
    @builtins.property
    def arg(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
        return self.operation.operands[0]

    # Mandatory error-message attribute.
    @builtins.property
    def msg(self) -> _ods_ir.StringAttr:
        return self.operation.attributes["msg"]

    @msg.setter
    def msg(self, value: _ods_ir.StringAttr):
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["msg"] = value
@_ods_cext.register_op_adaptor(AssertOp)
class AssertOpAdaptor(_ods_ir.OpAdaptor):
    """Read-only adaptor exposing `cf.assert` operand and attribute."""

    OPERATION_NAME = "cf.assert"

    @builtins.property
    def arg(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
        return self.operands[0]

    @builtins.property
    def msg(self) -> _ods_ir.StringAttr:
        return self.attributes["msg"]
def assert_(arg: _ods_ir.Value[_ods_ir.IntegerType], msg: _Union[str, _ods_ir.StringAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> AssertOp:
    """Convenience builder for a `cf.assert` operation."""
    op = AssertOp(arg=arg, msg=msg, loc=loc, ip=ip)
    return op
@_ods_cext.register_operation(_Dialect)
class BranchOp(_ods_ir.OpView):
    r"""
    The `cf.br` operation represents a direct branch operation to a given
    block. The operands of this operation are forwarded to the successor block,
    and the number and type of the operands must match the arguments of the
    target block.

    Example:

    ```mlir
    ^bb2:
      %2 = call @someFn()
      cf.br ^bb3(%2 : tensor<*xf32>)
    ^bb3(%3: tensor<*xf32>):
    ```
    """

    OPERATION_NAME = "cf.br"

    _ODS_REGIONS = (0, True)

    def __init__(self, destOperands: _Sequence[_ods_ir.Value], dest: _ods_ir.Block, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        operands = []
        attributes = {}
        regions = None
        # Accept Values or Operations (their results) as forwarded operands.
        operands.extend(_get_op_results_or_values(destOperands))
        _ods_context = _ods_get_default_loc_context(loc)
        results = []
        _ods_successors = []
        _ods_successors.append(dest)
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

    # Variadic operands forwarded to the successor block.
    @builtins.property
    def destOperands(self) -> _ods_ir.OpOperandList:
        _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
        return self.operation.operands[0:0 + _ods_variadic_group_length]
@_ods_cext.register_op_adaptor(BranchOp)
class BranchOpAdaptor(_ods_ir.OpAdaptor):
    """Read-only adaptor exposing `cf.br` operands."""

    OPERATION_NAME = "cf.br"

    @builtins.property
    def destOperands(self) -> _ods_ir.OpOperandList:
        _ods_variadic_group_length = len(self.operands) - 1 + 1
        return self.operands[0:0 + _ods_variadic_group_length]
def br(dest_operands: _Sequence[_ods_ir.Value], dest: _ods_ir.Block, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> BranchOp:
    """Convenience builder for a `cf.br` operation."""
    op = BranchOp(destOperands=dest_operands, dest=dest, loc=loc, ip=ip)
    return op
@_ods_cext.register_operation(_Dialect)
class CondBranchOp(_ods_ir.OpView):
    r"""
    The `cf.cond_br` terminator operation represents a conditional branch on a
    boolean (1-bit integer) value. If the bit is set, then the first destination
    is jumped to; if it is false, the second destination is chosen. The count
    and types of operands must align with the arguments in the corresponding
    target blocks.

    The MLIR conditional branch operation is not allowed to target the entry
    block for a region. The two destinations of the conditional branch operation
    are allowed to be the same.

    The following example illustrates a function with a conditional branch
    operation that targets the same block.

    Example:

    ```mlir
    func.func @select(%a: i32, %b: i32, %flag: i1) -> i32 {
      // Both targets are the same, operands differ
      cf.cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32)

    ^bb1(%x : i32) :
      return %x : i32
    }
    ```
    """

    OPERATION_NAME = "cf.cond_br"

    # Operand segments: 1 condition, then two variadic (-1) operand groups for
    # the true and false destinations.
    _ODS_OPERAND_SEGMENTS = [1,-1,-1,]

    _ODS_REGIONS = (0, True)

    def __init__(self, condition: _ods_ir.Value[_ods_ir.IntegerType], trueDestOperands: _Sequence[_ods_ir.Value], falseDestOperands: _Sequence[_ods_ir.Value], trueDest: _ods_ir.Block, falseDest: _ods_ir.Block, *, branch_weights: _Optional[_Union[_Sequence[int], _ods_ir.DenseI32ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        operands = []
        attributes = {}
        regions = None
        operands.append(condition)
        # Variadic groups are appended as sub-lists (segmented operands).
        operands.append(_get_op_results_or_values(trueDestOperands))
        operands.append(_get_op_results_or_values(falseDestOperands))
        _ods_context = _ods_get_default_loc_context(loc)
        # Optional attribute: accept an Attribute or a plain int sequence
        # converted through the registered 'DenseI32ArrayAttr' builder.
        if branch_weights is not None: attributes["branch_weights"] = (branch_weights if (
            isinstance(branch_weights, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('DenseI32ArrayAttr')) else
              _ods_ir.AttrBuilder.get('DenseI32ArrayAttr')(branch_weights, context=_ods_context))
        results = []
        _ods_successors = []
        _ods_successors.append(trueDest)
        _ods_successors.append(falseDest)
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

    # The i1 condition operand (segment 0).
    @builtins.property
    def condition(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
        operand_range = _ods_segmented_accessor(
              self.operation.operands,
              self.operation.attributes["operandSegmentSizes"], 0)
        return operand_range[0]

    # Operands forwarded to the true destination (segment 1).
    @builtins.property
    def trueDestOperands(self) -> _ods_ir.OpOperandList:
        operand_range = _ods_segmented_accessor(
              self.operation.operands,
              self.operation.attributes["operandSegmentSizes"], 1)
        return operand_range

    # Operands forwarded to the false destination (segment 2).
    @builtins.property
    def falseDestOperands(self) -> _ods_ir.OpOperandList:
        operand_range = _ods_segmented_accessor(
              self.operation.operands,
              self.operation.attributes["operandSegmentSizes"], 2)
        return operand_range

    # Optional branch-weights attribute; None when absent.
    @builtins.property
    def branch_weights(self) -> _Optional[_ods_ir.DenseI32ArrayAttr]:
        if "branch_weights" not in self.operation.attributes:
            return None
        return self.operation.attributes["branch_weights"]

    @branch_weights.setter
    def branch_weights(self, value: _Optional[_ods_ir.DenseI32ArrayAttr]):
        # Assigning None removes the attribute.
        if value is not None:
            self.operation.attributes["branch_weights"] = value
        elif "branch_weights" in self.operation.attributes:
            del self.operation.attributes["branch_weights"]

    @branch_weights.deleter
    def branch_weights(self):
        del self.operation.attributes["branch_weights"]
@_ods_cext.register_op_adaptor(CondBranchOp)
class CondBranchOpAdaptor(_ods_ir.OpAdaptor):
    """Read-only adaptor exposing `cf.cond_br` operand segments and attributes."""

    OPERATION_NAME = "cf.cond_br"

    @builtins.property
    def condition(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
        operand_range = _ods_segmented_accessor(
              self.operands,
              self.attributes["operandSegmentSizes"], 0)
        return operand_range[0]

    @builtins.property
    def trueDestOperands(self) -> _ods_ir.OpOperandList:
        operand_range = _ods_segmented_accessor(
              self.operands,
              self.attributes["operandSegmentSizes"], 1)
        return operand_range

    @builtins.property
    def falseDestOperands(self) -> _ods_ir.OpOperandList:
        operand_range = _ods_segmented_accessor(
              self.operands,
              self.attributes["operandSegmentSizes"], 2)
        return operand_range

    @builtins.property
    def branch_weights(self) -> _Optional[_ods_ir.DenseI32ArrayAttr]:
        if "branch_weights" not in self.attributes:
            return None
        return self.attributes["branch_weights"]
def cond_br(condition: _ods_ir.Value[_ods_ir.IntegerType], true_dest_operands: _Sequence[_ods_ir.Value], false_dest_operands: _Sequence[_ods_ir.Value], true_dest: _ods_ir.Block, false_dest: _ods_ir.Block, *, branch_weights: _Optional[_Union[_Sequence[int], _ods_ir.DenseI32ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> CondBranchOp:
    """Convenience builder for a `cf.cond_br` operation."""
    op = CondBranchOp(condition=condition, trueDestOperands=true_dest_operands, falseDestOperands=false_dest_operands, trueDest=true_dest, falseDest=false_dest, branch_weights=branch_weights, loc=loc, ip=ip)
    return op
@_ods_cext.register_operation(_Dialect)
class SwitchOp(_ods_ir.OpView):
    r"""
    The `cf.switch` terminator operation represents a switch on a signless integer
    value. If the flag matches one of the specified cases, then the
    corresponding destination is jumped to. If the flag does not match any of
    the cases, the default destination is jumped to. The count and types of
    operands must align with the arguments in the corresponding target blocks.

    Example:

    ```mlir
    cf.switch %flag : i32, [
      default: ^bb1(%a : i32),
      42: ^bb1(%b : i32),
      43: ^bb3(%c : i32)
    ]
    ```
    """

    OPERATION_NAME = "cf.switch"

    # Operand segments: 1 flag, then two variadic (-1) groups for default and
    # case operands.
    _ODS_OPERAND_SEGMENTS = [1,-1,-1,]

    _ODS_REGIONS = (0, True)

    def __init__(self, flag: _ods_ir.Value[_ods_ir.IntegerType], defaultOperands: _Sequence[_ods_ir.Value], caseOperands: _Sequence[_ods_ir.Value], case_operand_segments: _Union[_Sequence[int], _ods_ir.DenseI32ArrayAttr], defaultDestination: _ods_ir.Block, caseDestinations: _Sequence[_ods_ir.Block], *, case_values: _Optional[_Union[_Any, _ods_ir.DenseIntElementsAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        operands = []
        attributes = {}
        regions = None
        operands.append(flag)
        # Variadic groups are appended as sub-lists (segmented operands).
        operands.append(_get_op_results_or_values(defaultOperands))
        operands.append(_get_op_results_or_values(caseOperands))
        _ods_context = _ods_get_default_loc_context(loc)
        # Optional case values: Attribute or plain value converted through the
        # registered 'AnyIntElementsAttr' builder.
        if case_values is not None: attributes["case_values"] = (case_values if (
            isinstance(case_values, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('AnyIntElementsAttr')) else
              _ods_ir.AttrBuilder.get('AnyIntElementsAttr')(case_values, context=_ods_context))
        # Mandatory attribute: per-case operand counts within the caseOperands
        # group.
        attributes["case_operand_segments"] = (case_operand_segments if (
            isinstance(case_operand_segments, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('DenseI32ArrayAttr')) else
              _ods_ir.AttrBuilder.get('DenseI32ArrayAttr')(case_operand_segments, context=_ods_context))
        results = []
        _ods_successors = []
        # Successor 0 is the default destination, followed by all case blocks.
        _ods_successors.append(defaultDestination)
        _ods_successors.extend(caseDestinations)
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)

    # The integer flag being switched on (segment 0).
    @builtins.property
    def flag(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
        operand_range = _ods_segmented_accessor(
              self.operation.operands,
              self.operation.attributes["operandSegmentSizes"], 0)
        return operand_range[0]

    # Operands forwarded to the default destination (segment 1).
    @builtins.property
    def defaultOperands(self) -> _ods_ir.OpOperandList:
        operand_range = _ods_segmented_accessor(
              self.operation.operands,
              self.operation.attributes["operandSegmentSizes"], 1)
        return operand_range

    # Flattened operands for all cases (segment 2).
    @builtins.property
    def caseOperands(self) -> _ods_ir.OpOperandList:
        operand_range = _ods_segmented_accessor(
              self.operation.operands,
              self.operation.attributes["operandSegmentSizes"], 2)
        return operand_range

    # Optional case-values attribute; None when absent.
    @builtins.property
    def case_values(self) -> _Optional[_ods_ir.DenseIntElementsAttr]:
        if "case_values" not in self.operation.attributes:
            return None
        return self.operation.attributes["case_values"]

    @case_values.setter
    def case_values(self, value: _Optional[_ods_ir.DenseIntElementsAttr]):
        # Assigning None removes the attribute.
        if value is not None:
            self.operation.attributes["case_values"] = value
        elif "case_values" in self.operation.attributes:
            del self.operation.attributes["case_values"]

    @case_values.deleter
    def case_values(self):
        del self.operation.attributes["case_values"]

    # Mandatory per-case operand segment sizes.
    @builtins.property
    def case_operand_segments(self) -> _ods_ir.DenseI32ArrayAttr:
        return self.operation.attributes["case_operand_segments"]

    @case_operand_segments.setter
    def case_operand_segments(self, value: _ods_ir.DenseI32ArrayAttr):
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["case_operand_segments"] = value
@_ods_cext.register_op_adaptor(SwitchOp)
class SwitchOpAdaptor(_ods_ir.OpAdaptor):
    """Read-only adaptor exposing `cf.switch` operand segments and attributes."""

    OPERATION_NAME = "cf.switch"

    @builtins.property
    def flag(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
        operand_range = _ods_segmented_accessor(
              self.operands,
              self.attributes["operandSegmentSizes"], 0)
        return operand_range[0]

    @builtins.property
    def defaultOperands(self) -> _ods_ir.OpOperandList:
        operand_range = _ods_segmented_accessor(
              self.operands,
              self.attributes["operandSegmentSizes"], 1)
        return operand_range

    @builtins.property
    def caseOperands(self) -> _ods_ir.OpOperandList:
        operand_range = _ods_segmented_accessor(
              self.operands,
              self.attributes["operandSegmentSizes"], 2)
        return operand_range

    @builtins.property
    def case_values(self) -> _Optional[_ods_ir.DenseIntElementsAttr]:
        if "case_values" not in self.attributes:
            return None
        return self.attributes["case_values"]

    @builtins.property
    def case_operand_segments(self) -> _ods_ir.DenseI32ArrayAttr:
        return self.attributes["case_operand_segments"]
def switch(flag: _ods_ir.Value[_ods_ir.IntegerType], default_operands: _Sequence[_ods_ir.Value], case_operands: _Sequence[_ods_ir.Value], case_operand_segments: _Union[_Sequence[int], _ods_ir.DenseI32ArrayAttr], default_destination: _ods_ir.Block, case_destinations: _Sequence[_ods_ir.Block], *, case_values: _Optional[_Union[_Any, _ods_ir.DenseIntElementsAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> SwitchOp:
  """Build a `cf.switch` op and return the constructed SwitchOp view."""
  op = SwitchOp(
      flag=flag,
      defaultOperands=default_operands,
      caseOperands=case_operands,
      case_operand_segments=case_operand_segments,
      defaultDestination=default_destination,
      caseDestinations=case_destinations,
      case_values=case_values,
      loc=loc,
      ip=ip)
  return op
# (diff-viewer artifact, not source) File diff suppressed because it is too large Load Diff
# (diff-viewer artifact, not source) @@ -0,0 +1,600 @@ -- start of generated `func` dialect bindings
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
import sys as _sys
if _sys.version_info >= (3, 12):
from collections.abc import Buffer as _Buffer # pytype: disable=not-supported-yet
else:
try:
from typing_extensions import Buffer as _Buffer
except ImportError:
_Buffer = _Any
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  # Registers the `func` dialect namespace with the bindings' dialect registry.
  DIALECT_NAMESPACE = "func"
@_ods_cext.register_operation(_Dialect)
class CallIndirectOp(_ods_ir.OpView):
  r"""
  The `func.call_indirect` operation represents an indirect call to a value
  of function type. The operands and result types of the call must match the
  specified function type.
  Function values can be created with the
  [`func.constant` operation](#funcconstant-constantop).
  Example:
  ```mlir
  %func = func.constant @my_func : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
  %result = func.call_indirect %func(%0, %1) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
  ```
  """
  OPERATION_NAME = "func.call_indirect"
  _ODS_REGIONS = (0, True)
  def __init__(self, results_: _Sequence[_ods_ir.Type], callee: _ods_ir.Value, callee_operands: _Sequence[_ods_ir.Value], *, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
    # Build a `func.call_indirect` op: the function value comes first, then
    # the variadic call arguments. Optional dict-array attributes are built
    # via AttrBuilder unless already Attribute instances.
    operands = []
    attributes = {}
    regions = None
    operands.append(callee)
    operands.extend(_get_op_results_or_values(callee_operands))
    _ods_context = _ods_get_default_loc_context(loc)
    if arg_attrs is not None: attributes["arg_attrs"] = (arg_attrs if (
        isinstance(arg_attrs, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
          _ods_ir.AttrBuilder.get('DictArrayAttr')(arg_attrs, context=_ods_context))
    if res_attrs is not None: attributes["res_attrs"] = (res_attrs if (
        isinstance(res_attrs, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
          _ods_ir.AttrBuilder.get('DictArrayAttr')(res_attrs, context=_ods_context))
    results = []
    results.extend(results_)
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
  @builtins.property
  def callee(self) -> _ods_ir.Value:
    # Operand 0 is the function value being invoked.
    return self.operation.operands[0]
  @builtins.property
  def callee_operands(self) -> _ods_ir.OpOperandList:
    # Single variadic group after the fixed callee operand: spans the rest.
    _ods_variadic_group_length = len(self.operation.operands) - 2 + 1
    return self.operation.operands[1:1 + _ods_variadic_group_length]
  @builtins.property
  def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    # Optional attribute: None when absent.
    if "arg_attrs" not in self.operation.attributes:
      return None
    return self.operation.attributes["arg_attrs"]
  @arg_attrs.setter
  def arg_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
    # Assigning None removes the optional attribute.
    if value is not None:
      self.operation.attributes["arg_attrs"] = value
    elif "arg_attrs" in self.operation.attributes:
      del self.operation.attributes["arg_attrs"]
  @arg_attrs.deleter
  def arg_attrs(self):
    del self.operation.attributes["arg_attrs"]
  @builtins.property
  def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    if "res_attrs" not in self.operation.attributes:
      return None
    return self.operation.attributes["res_attrs"]
  @res_attrs.setter
  def res_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
    if value is not None:
      self.operation.attributes["res_attrs"] = value
    elif "res_attrs" in self.operation.attributes:
      del self.operation.attributes["res_attrs"]
  @res_attrs.deleter
  def res_attrs(self):
    del self.operation.attributes["res_attrs"]
  @builtins.property
  def results_(self) -> _ods_ir.OpResultList:
    # Single variadic result group: spans all results.
    _ods_variadic_group_length = len(self.operation.results) - 1 + 1
    return self.operation.results[0:0 + _ods_variadic_group_length]
@_ods_cext.register_op_adaptor(CallIndirectOp)
class CallIndirectOpAdaptor(_ods_ir.OpAdaptor):
  # Read-only adaptor mirroring CallIndirectOp's accessors over the adaptor's
  # own operand/attribute lists.
  OPERATION_NAME = "func.call_indirect"
  @builtins.property
  def callee(self) -> _ods_ir.Value:
    # Operand 0 is the function value being invoked.
    return self.operands[0]
  @builtins.property
  def callee_operands(self) -> _ods_ir.OpOperandList:
    # Single variadic group after the fixed callee operand.
    _ods_variadic_group_length = len(self.operands) - 2 + 1
    return self.operands[1:1 + _ods_variadic_group_length]
  @builtins.property
  def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    # Optional attribute: None when absent.
    if "arg_attrs" not in self.attributes:
      return None
    return self.attributes["arg_attrs"]
  @builtins.property
  def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    if "res_attrs" not in self.attributes:
      return None
    return self.attributes["res_attrs"]
def call_indirect(results_: _Sequence[_ods_ir.Type], callee: _ods_ir.Value, callee_operands: _Sequence[_ods_ir.Value], *, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, CallIndirectOp]:
  """Build a `func.call_indirect` op.

  Returns the result list when there are multiple results, the single result
  when there is exactly one, and the op itself when there are none.
  """
  op = CallIndirectOp(
      results_=results_,
      callee=callee,
      callee_operands=callee_operands,
      arg_attrs=arg_attrs,
      res_attrs=res_attrs,
      loc=loc,
      ip=ip)
  results = op.results
  if len(results) > 1:
    return results
  if len(results) == 1:
    return results[0]
  return op
@_ods_cext.register_operation(_Dialect)
class CallOp(_ods_ir.OpView):
  r"""
  The `func.call` operation represents a direct call to a function that is
  within the same symbol scope as the call. The operands and result types of
  the call must match the specified function type. The callee is encoded as a
  symbol reference attribute named "callee".
  Example:
  ```mlir
  %2 = func.call @my_add(%0, %1) : (f32, f32) -> f32
  ```
  """
  OPERATION_NAME = "func.call"
  _ODS_REGIONS = (0, True)
  def __init__(self, result: _Sequence[_ods_ir.Type], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], operands_: _Sequence[_ods_ir.Value], *, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, no_inline: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
    # Build a `func.call` op. `callee` is stored as a mandatory
    # FlatSymbolRefAttr; a plain string is converted via AttrBuilder.
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(operands_))
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["callee"] = (callee if (
      isinstance(callee, _ods_ir.Attribute) or
      not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
        _ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(callee, context=_ods_context))
    if arg_attrs is not None: attributes["arg_attrs"] = (arg_attrs if (
        isinstance(arg_attrs, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
          _ods_ir.AttrBuilder.get('DictArrayAttr')(arg_attrs, context=_ods_context))
    if res_attrs is not None: attributes["res_attrs"] = (res_attrs if (
        isinstance(res_attrs, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
          _ods_ir.AttrBuilder.get('DictArrayAttr')(res_attrs, context=_ods_context))
    # Unit attribute: present iff truthy.
    if bool(no_inline): attributes["no_inline"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc))
    results = []
    results.extend(result)
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
  @builtins.property
  def operands_(self) -> _ods_ir.OpOperandList:
    # Single variadic operand group: spans all operands.
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]
  @builtins.property
  def callee(self) -> _ods_ir.FlatSymbolRefAttr:
    return self.operation.attributes["callee"]
  @callee.setter
  def callee(self, value: _ods_ir.FlatSymbolRefAttr):
    # Mandatory attribute: may be replaced but never set to None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["callee"] = value
  @builtins.property
  def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    # Optional attribute: None when absent.
    if "arg_attrs" not in self.operation.attributes:
      return None
    return self.operation.attributes["arg_attrs"]
  @arg_attrs.setter
  def arg_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
    if value is not None:
      self.operation.attributes["arg_attrs"] = value
    elif "arg_attrs" in self.operation.attributes:
      del self.operation.attributes["arg_attrs"]
  @arg_attrs.deleter
  def arg_attrs(self):
    del self.operation.attributes["arg_attrs"]
  @builtins.property
  def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    if "res_attrs" not in self.operation.attributes:
      return None
    return self.operation.attributes["res_attrs"]
  @res_attrs.setter
  def res_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
    if value is not None:
      self.operation.attributes["res_attrs"] = value
    elif "res_attrs" in self.operation.attributes:
      del self.operation.attributes["res_attrs"]
  @res_attrs.deleter
  def res_attrs(self):
    del self.operation.attributes["res_attrs"]
  @builtins.property
  def no_inline(self) -> bool:
    # Unit attribute: truthiness is its mere presence.
    return "no_inline" in self.operation.attributes
  @no_inline.setter
  def no_inline(self, value):
    if bool(value):
      self.operation.attributes["no_inline"] = _ods_ir.UnitAttr.get()
    elif "no_inline" in self.operation.attributes:
      del self.operation.attributes["no_inline"]
  @no_inline.deleter
  def no_inline(self):
    del self.operation.attributes["no_inline"]
@_ods_cext.register_op_adaptor(CallOp)
class CallOpAdaptor(_ods_ir.OpAdaptor):
  # Read-only adaptor mirroring CallOp's accessors.
  OPERATION_NAME = "func.call"
  @builtins.property
  def operands_(self) -> _ods_ir.OpOperandList:
    # Single variadic operand group: spans all operands.
    _ods_variadic_group_length = len(self.operands) - 1 + 1
    return self.operands[0:0 + _ods_variadic_group_length]
  @builtins.property
  def callee(self) -> _ods_ir.FlatSymbolRefAttr:
    return self.attributes["callee"]
  @builtins.property
  def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    # Optional attribute: None when absent.
    if "arg_attrs" not in self.attributes:
      return None
    return self.attributes["arg_attrs"]
  @builtins.property
  def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    if "res_attrs" not in self.attributes:
      return None
    return self.attributes["res_attrs"]
  @builtins.property
  def no_inline(self) -> bool:
    # Unit attribute: truthiness is its mere presence.
    return "no_inline" in self.attributes
def call(result: _Sequence[_ods_ir.Type], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], operands_: _Sequence[_ods_ir.Value], *, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, no_inline: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, CallOp]:
  """Build a `func.call` op.

  Returns the result list when there are multiple results, the single result
  when there is exactly one, and the op itself when there are none.
  """
  op = CallOp(
      result=result,
      callee=callee,
      operands_=operands_,
      arg_attrs=arg_attrs,
      res_attrs=res_attrs,
      no_inline=no_inline,
      loc=loc,
      ip=ip)
  results = op.results
  if len(results) > 1:
    return results
  if len(results) == 1:
    return results[0]
  return op
@_ods_cext.register_operation(_Dialect)
class ConstantOp(_ods_ir.OpView):
  r"""
  The `func.constant` operation produces an SSA value from a symbol reference
  to a `func.func` operation
  Example:
  ```mlir
  // Reference to function @myfn.
  %2 = func.constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32>
  // Equivalent generic forms
  %2 = "func.constant"() { value = @myfn } : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>)
  ```
  MLIR does not allow direct references to functions in SSA operands because
  the compiler is multithreaded, and disallowing SSA values to directly
  reference a function simplifies this
  ([rationale](../Rationale/Rationale.md#multithreading-the-compiler)).
  """
  OPERATION_NAME = "func.constant"
  _ODS_REGIONS = (0, True)
  def __init__(self, result: _ods_ir.Type, value: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
    # Build a `func.constant` op: no operands, one result, and a mandatory
    # `value` FlatSymbolRefAttr (strings are converted via AttrBuilder).
    operands = []
    attributes = {}
    regions = None
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["value"] = (value if (
      isinstance(value, _ods_ir.Attribute) or
      not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
        _ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(value, context=_ods_context))
    results = []
    results.append(result)
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
  @builtins.property
  def value(self) -> _ods_ir.FlatSymbolRefAttr:
    return self.operation.attributes["value"]
  @value.setter
  def value(self, value: _ods_ir.FlatSymbolRefAttr):
    # Mandatory attribute: may be replaced but never set to None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["value"] = value
@_ods_cext.register_op_adaptor(ConstantOp)
class ConstantOpAdaptor(_ods_ir.OpAdaptor):
  # Read-only adaptor mirroring ConstantOp's accessor.
  OPERATION_NAME = "func.constant"
  @builtins.property
  def value(self) -> _ods_ir.FlatSymbolRefAttr:
    return self.attributes["value"]
def constant(result: _ods_ir.Type, value: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
  """Build a `func.constant` op and return its single result value."""
  op = ConstantOp(result=result, value=value, loc=loc, ip=ip)
  return op.result
@_ods_cext.register_operation(_Dialect)
class FuncOp(_ods_ir.OpView):
  r"""
  Operations within the function cannot implicitly capture values defined
  outside of the function, i.e. Functions are `IsolatedFromAbove`. All
  external references must use function arguments or attributes that establish
  a symbolic connection (e.g. symbols referenced by name via a string
  attribute like SymbolRefAttr). An external function declaration (used when
  referring to a function declared in some other module) has no body. While
  the MLIR textual form provides a nice inline syntax for function arguments,
  they are internally represented as block arguments to the first block in
  the region.
  Only dialect attribute names may be specified in the attribute dictionaries
  for function arguments, results, or the function itself.
  Example:
  ```mlir
  // External function definitions.
  func.func private @abort()
  func.func private @scribble(i32, i64, memref<? x 128 x f32, #layout_map0>) -> f64
  // A function that returns its argument twice:
  func.func @count(%x: i64) -> (i64, i64)
    attributes {fruit = "banana"} {
    return %x, %x: i64, i64
  }
  // A function with an argument attribute
  func.func private @example_fn_arg(%x: i32 {swift.self = unit})
  // A function with a result attribute
  func.func private @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64})
  // A function with an attribute
  func.func private @example_fn_attr() attributes {dialectName.attrName = false}
  ```
  """
  OPERATION_NAME = "func.func"
  _ODS_REGIONS = (1, True)
  def __init__(self, sym_name: _Union[str, _ods_ir.StringAttr], function_type: _Union[_Any, _ods_ir.TypeAttr], *, sym_visibility: _Optional[_Union[str, _ods_ir.StringAttr]] = None, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, no_inline: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
    # Build a `func.func` op: no operands/results, one region (the body),
    # mandatory `sym_name` and `function_type` attributes, plus optional
    # visibility/arg/res attributes and a `no_inline` unit attribute.
    operands = []
    attributes = {}
    regions = None
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["sym_name"] = (sym_name if (
      isinstance(sym_name, _ods_ir.Attribute) or
      not _ods_ir.AttrBuilder.contains('SymbolNameAttr')) else
        _ods_ir.AttrBuilder.get('SymbolNameAttr')(sym_name, context=_ods_context))
    # 'anonymous_452' is the tblgen-generated name for this op's type
    # constraint builder.
    attributes["function_type"] = (function_type if (
      isinstance(function_type, _ods_ir.Attribute) or
      not _ods_ir.AttrBuilder.contains('anonymous_452')) else
        _ods_ir.AttrBuilder.get('anonymous_452')(function_type, context=_ods_context))
    if sym_visibility is not None: attributes["sym_visibility"] = (sym_visibility if (
        isinstance(sym_visibility, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('StrAttr')) else
          _ods_ir.AttrBuilder.get('StrAttr')(sym_visibility, context=_ods_context))
    if arg_attrs is not None: attributes["arg_attrs"] = (arg_attrs if (
        isinstance(arg_attrs, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
          _ods_ir.AttrBuilder.get('DictArrayAttr')(arg_attrs, context=_ods_context))
    if res_attrs is not None: attributes["res_attrs"] = (res_attrs if (
        isinstance(res_attrs, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('DictArrayAttr')) else
          _ods_ir.AttrBuilder.get('DictArrayAttr')(res_attrs, context=_ods_context))
    # Unit attribute: present iff truthy.
    if bool(no_inline): attributes["no_inline"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc))
    results = []
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
  @builtins.property
  def sym_name(self) -> _ods_ir.StringAttr:
    return self.operation.attributes["sym_name"]
  @sym_name.setter
  def sym_name(self, value: _ods_ir.StringAttr):
    # Mandatory attribute: may be replaced but never set to None.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["sym_name"] = value
  @builtins.property
  def function_type(self) -> _ods_ir.TypeAttr:
    return self.operation.attributes["function_type"]
  @function_type.setter
  def function_type(self, value: _ods_ir.TypeAttr):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["function_type"] = value
  @builtins.property
  def sym_visibility(self) -> _Optional[_ods_ir.StringAttr]:
    # Optional attribute: None when absent.
    if "sym_visibility" not in self.operation.attributes:
      return None
    return self.operation.attributes["sym_visibility"]
  @sym_visibility.setter
  def sym_visibility(self, value: _Optional[_ods_ir.StringAttr]):
    # Assigning None removes the optional attribute.
    if value is not None:
      self.operation.attributes["sym_visibility"] = value
    elif "sym_visibility" in self.operation.attributes:
      del self.operation.attributes["sym_visibility"]
  @sym_visibility.deleter
  def sym_visibility(self):
    del self.operation.attributes["sym_visibility"]
  @builtins.property
  def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    if "arg_attrs" not in self.operation.attributes:
      return None
    return self.operation.attributes["arg_attrs"]
  @arg_attrs.setter
  def arg_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
    if value is not None:
      self.operation.attributes["arg_attrs"] = value
    elif "arg_attrs" in self.operation.attributes:
      del self.operation.attributes["arg_attrs"]
  @arg_attrs.deleter
  def arg_attrs(self):
    del self.operation.attributes["arg_attrs"]
  @builtins.property
  def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    if "res_attrs" not in self.operation.attributes:
      return None
    return self.operation.attributes["res_attrs"]
  @res_attrs.setter
  def res_attrs(self, value: _Optional[_ods_ir.ArrayAttr]):
    if value is not None:
      self.operation.attributes["res_attrs"] = value
    elif "res_attrs" in self.operation.attributes:
      del self.operation.attributes["res_attrs"]
  @res_attrs.deleter
  def res_attrs(self):
    del self.operation.attributes["res_attrs"]
  @builtins.property
  def no_inline(self) -> bool:
    # Unit attribute: truthiness is its mere presence.
    return "no_inline" in self.operation.attributes
  @no_inline.setter
  def no_inline(self, value):
    if bool(value):
      self.operation.attributes["no_inline"] = _ods_ir.UnitAttr.get()
    elif "no_inline" in self.operation.attributes:
      del self.operation.attributes["no_inline"]
  @no_inline.deleter
  def no_inline(self):
    del self.operation.attributes["no_inline"]
  @builtins.property
  def body(self) -> _ods_ir.Region:
    # The single region declared by _ODS_REGIONS.
    return self.regions[0]
@_ods_cext.register_op_adaptor(FuncOp)
class FuncOpAdaptor(_ods_ir.OpAdaptor):
  # Read-only adaptor mirroring FuncOp's attribute accessors.
  OPERATION_NAME = "func.func"
  @builtins.property
  def sym_name(self) -> _ods_ir.StringAttr:
    return self.attributes["sym_name"]
  @builtins.property
  def function_type(self) -> _ods_ir.TypeAttr:
    return self.attributes["function_type"]
  @builtins.property
  def sym_visibility(self) -> _Optional[_ods_ir.StringAttr]:
    # Optional attribute: None when absent.
    if "sym_visibility" not in self.attributes:
      return None
    return self.attributes["sym_visibility"]
  @builtins.property
  def arg_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    if "arg_attrs" not in self.attributes:
      return None
    return self.attributes["arg_attrs"]
  @builtins.property
  def res_attrs(self) -> _Optional[_ods_ir.ArrayAttr]:
    if "res_attrs" not in self.attributes:
      return None
    return self.attributes["res_attrs"]
  @builtins.property
  def no_inline(self) -> bool:
    # Unit attribute: truthiness is its mere presence.
    return "no_inline" in self.attributes
def func(sym_name: _Union[str, _ods_ir.StringAttr], function_type: _Union[_Any, _ods_ir.TypeAttr], *, sym_visibility: _Optional[_Union[str, _ods_ir.StringAttr]] = None, arg_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, res_attrs: _Optional[_Union[_Any, _ods_ir.ArrayAttr]] = None, no_inline: _Optional[bool] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> FuncOp:
  """Build a `func.func` op and return the constructed FuncOp view."""
  op = FuncOp(
      sym_name=sym_name,
      function_type=function_type,
      sym_visibility=sym_visibility,
      arg_attrs=arg_attrs,
      res_attrs=res_attrs,
      no_inline=no_inline,
      loc=loc,
      ip=ip)
  return op
@_ods_cext.register_operation(_Dialect)
class ReturnOp(_ods_ir.OpView):
  r"""
  The `func.return` operation represents a return operation within a function.
  The operation takes variable number of operands and produces no results.
  The operand number and types must match the signature of the function
  that contains the operation.
  Example:
  ```mlir
  func.func @foo() -> (i32, f8) {
    ...
    return %0, %1 : i32, f8
  }
  ```
  """
  OPERATION_NAME = "func.return"
  _ODS_REGIONS = (0, True)
  def __init__(self, operands_: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
    # Build a `func.return` op: only a variadic operand list, no attributes,
    # results, or regions.
    operands = []
    attributes = {}
    regions = None
    operands.extend(_get_op_results_or_values(operands_))
    _ods_context = _ods_get_default_loc_context(loc)
    results = []
    _ods_successors = None
    super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
  @builtins.property
  def operands_(self) -> _ods_ir.OpOperandList:
    # Single variadic operand group: spans all operands.
    _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
    return self.operation.operands[0:0 + _ods_variadic_group_length]
@_ods_cext.register_op_adaptor(ReturnOp)
class ReturnOpAdaptor(_ods_ir.OpAdaptor):
  # Read-only adaptor mirroring ReturnOp's accessor.
  OPERATION_NAME = "func.return"
  @builtins.property
  def operands_(self) -> _ods_ir.OpOperandList:
    # Single variadic operand group: spans all operands.
    _ods_variadic_group_length = len(self.operands) - 1 + 1
    return self.operands[0:0 + _ods_variadic_group_length]
def return_(operands_: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> ReturnOp:
  """Build a `func.return` op and return the constructed ReturnOp view."""
  op = ReturnOp(operands_=operands_, loc=loc, ip=ip)
  return op
# (diff-viewer artifact, not source) @@ -0,0 +1,440 @@ -- start of generated GPU enums file
# Autogenerated by mlir-tblgen; don't manually edit.
from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir
class AddressSpace(IntEnum):
  """GPU address space"""
  Global = 1
  Workgroup = 2
  Private = 3
  Constant = 4

  def __str__(self):
    # Map each member to its MLIR keyword spelling.
    labels = {
        AddressSpace.Global: "global",
        AddressSpace.Workgroup: "workgroup",
        AddressSpace.Private: "private",
        AddressSpace.Constant: "constant",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown AddressSpace enum entry.")
    return label
@register_attribute_builder("GPU_AddressSpaceEnum", allow_existing=True)
def _gpu_addressspaceenum(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class AllReduceOperation(IntEnum):
  """built-in reduction operations supported by gpu.allreduce."""
  ADD = 0
  MUL = 1
  MINUI = 2
  MINSI = 3
  MINNUMF = 4
  MAXUI = 5
  MAXSI = 6
  MAXNUMF = 7
  AND = 8
  OR = 9
  XOR = 10
  MINIMUMF = 11
  MAXIMUMF = 12

  def __str__(self):
    # Map each member to its MLIR keyword spelling.
    labels = {
        AllReduceOperation.ADD: "add",
        AllReduceOperation.MUL: "mul",
        AllReduceOperation.MINUI: "minui",
        AllReduceOperation.MINSI: "minsi",
        AllReduceOperation.MINNUMF: "minnumf",
        AllReduceOperation.MAXUI: "maxui",
        AllReduceOperation.MAXSI: "maxsi",
        AllReduceOperation.MAXNUMF: "maxnumf",
        AllReduceOperation.AND: "and",
        AllReduceOperation.OR: "or",
        AllReduceOperation.XOR: "xor",
        AllReduceOperation.MINIMUMF: "minimumf",
        AllReduceOperation.MAXIMUMF: "maximumf",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown AllReduceOperation enum entry.")
    return label
@register_attribute_builder("GPU_AllReduceOperation", allow_existing=True)
def _gpu_allreduceoperation(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class BroadcastType(IntEnum):
  """a lane to broadcast from"""
  first_active_lane = 0
  specific_lane = 1

  def __str__(self):
    # Member names match their MLIR spelling exactly, so `name` suffices.
    labels = {
        BroadcastType.first_active_lane: "first_active_lane",
        BroadcastType.specific_lane: "specific_lane",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown BroadcastType enum entry.")
    return label
@register_attribute_builder("GPU_BroadcastType", allow_existing=True)
def _gpu_broadcasttype(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class CompilationTarget(IntEnum):
  """GPU compilation format"""
  Offload = 1
  Assembly = 2
  Binary = 3
  Fatbin = 4

  def __str__(self):
    # Map each member to its MLIR keyword spelling (note: Binary -> "bin").
    labels = {
        CompilationTarget.Offload: "offload",
        CompilationTarget.Assembly: "assembly",
        CompilationTarget.Binary: "bin",
        CompilationTarget.Fatbin: "fatbin",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown CompilationTarget enum entry.")
    return label
@register_attribute_builder("GPU_CompilationTargetEnum", allow_existing=True)
def _gpu_compilationtargetenum(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class Dimension(IntEnum):
  """a dimension, either 'x', 'y', or 'z'"""
  x = 0
  y = 1
  z = 2

  def __str__(self):
    labels = {
        Dimension.x: "x",
        Dimension.y: "y",
        Dimension.z: "z",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown Dimension enum entry.")
    return label
@register_attribute_builder("GPU_Dimension", allow_existing=True)
def _gpu_dimension(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class DimensionKind(IntEnum):
  """the possible kinds of launch dimension"""
  Other = 0
  Block = 1
  Grid = 2
  Cluster = 3

  def __str__(self):
    labels = {
        DimensionKind.Other: "other",
        DimensionKind.Block: "block",
        DimensionKind.Grid: "grid",
        DimensionKind.Cluster: "cluster",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown DimensionKind enum entry.")
    return label
class Prune2To4SpMatFlag(IntEnum):
  """pruning strategy for 2:4 sparse matrix"""
  NONE = 0
  PRUNE_ONLY = 1
  PRUNE_AND_CHECK = 2

  def __str__(self):
    # Spelled exactly like the member names.
    labels = {
        Prune2To4SpMatFlag.NONE: "NONE",
        Prune2To4SpMatFlag.PRUNE_ONLY: "PRUNE_ONLY",
        Prune2To4SpMatFlag.PRUNE_AND_CHECK: "PRUNE_AND_CHECK",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown Prune2To4SpMatFlag enum entry.")
    return label
@register_attribute_builder("GPU_Prune2To4SpMatFlag", allow_existing=True)
def _gpu_prune2to4spmatflag(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class ShuffleMode(IntEnum):
  """Indexing modes supported by gpu.shuffle."""
  # NOTE: UP/DOWN numeric values are intentionally swapped relative to
  # declaration order (UP=2, DOWN=1); they must match the ODS definition.
  XOR = 0
  UP = 2
  DOWN = 1
  IDX = 3

  def __str__(self):
    labels = {
        ShuffleMode.XOR: "xor",
        ShuffleMode.UP: "up",
        ShuffleMode.DOWN: "down",
        ShuffleMode.IDX: "idx",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown ShuffleMode enum entry.")
    return label
@register_attribute_builder("GPU_ShuffleMode", allow_existing=True)
def _gpu_shufflemode(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class SpGEMMWorkEstimationOrComputeKind(IntEnum):
  """choose whether spgemm_work_estimation_or_compute does work estimation or compute"""
  WORK_ESTIMATION = 0
  COMPUTE = 1

  def __str__(self):
    labels = {
        SpGEMMWorkEstimationOrComputeKind.WORK_ESTIMATION: "WORK_ESTIMATION",
        SpGEMMWorkEstimationOrComputeKind.COMPUTE: "COMPUTE",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown SpGEMMWorkEstimationOrComputeKind enum entry.")
    return label
@register_attribute_builder("GPU_SpGEMMWorkEstimationOrComputeKind", allow_existing=True)
def _gpu_spgemmworkestimationorcomputekind(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class TransposeMode(IntEnum):
  """transpose mode of sparse matrix supported by sparse tensor ops"""
  NON_TRANSPOSE = 0
  TRANSPOSE = 1
  CONJUGATE_TRANSPOSE = 2

  def __str__(self):
    labels = {
        TransposeMode.NON_TRANSPOSE: "NON_TRANSPOSE",
        TransposeMode.TRANSPOSE: "TRANSPOSE",
        TransposeMode.CONJUGATE_TRANSPOSE: "CONJUGATE_TRANSPOSE",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown TransposeMode enum entry.")
    return label
@register_attribute_builder("GPU_TransposeMode", allow_existing=True)
def _gpu_transposemode(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class MMAElementwiseOp(IntEnum):
  """elementwise operation to apply to mma matrix"""
  ADDF = 0
  MULF = 1
  SUBF = 2
  MAXF = 3
  MINF = 4
  DIVF = 5
  ADDI = 6
  MULI = 7
  SUBI = 8
  DIVS = 9
  DIVU = 10
  NEGATEF = 11
  NEGATES = 12
  EXTF = 13
  TRUNCF = 14

  def __str__(self):
    # Map each member to its MLIR keyword spelling (lowercased member name).
    labels = {
        MMAElementwiseOp.ADDF: "addf",
        MMAElementwiseOp.MULF: "mulf",
        MMAElementwiseOp.SUBF: "subf",
        MMAElementwiseOp.MAXF: "maxf",
        MMAElementwiseOp.MINF: "minf",
        MMAElementwiseOp.DIVF: "divf",
        MMAElementwiseOp.ADDI: "addi",
        MMAElementwiseOp.MULI: "muli",
        MMAElementwiseOp.SUBI: "subi",
        MMAElementwiseOp.DIVS: "divs",
        MMAElementwiseOp.DIVU: "divu",
        MMAElementwiseOp.NEGATEF: "negatef",
        MMAElementwiseOp.NEGATES: "negates",
        MMAElementwiseOp.EXTF: "extf",
        MMAElementwiseOp.TRUNCF: "truncf",
    }
    label = labels.get(self)
    if label is None:
      raise ValueError("Unknown MMAElementwiseOp enum entry.")
    return label
@register_attribute_builder("MMAElementWise", allow_existing=True)
def _mmaelementwise(x, context):
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
class MappingId(IntEnum):
    """Mapping ids for loop mapping."""

    DimX = 0
    DimY = 1
    DimZ = 2
    LinearDim0 = 3
    LinearDim1 = 4
    LinearDim2 = 5
    LinearDim3 = 6
    LinearDim4 = 7
    LinearDim5 = 8
    LinearDim6 = 9
    LinearDim7 = 10
    LinearDim8 = 11
    LinearDim9 = 12

    def __str__(self):
        # DimX/DimY/DimZ print as "x"/"y"/"z"; LinearDimN prints as
        # "linear_dim_N" with N = value - LinearDim0.
        val = int(self)
        if 0 <= val <= 2:
            return "xyz"[val]
        if 3 <= val <= 12:
            return "linear_dim_%d" % (val - 3)
        raise ValueError("Unknown MappingId enum entry.")
@register_attribute_builder("MappingIdEnum", allow_existing=True)
def _mappingidenum(x, context):
    # Converts a MappingId member (or any int-convertible) to a signless
    # 64-bit IntegerAttr in `context`.
    return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
class Processor(IntEnum):
    """Processor for loop mapping."""

    BlockX = 0
    BlockY = 1
    BlockZ = 2
    ThreadX = 3
    ThreadY = 4
    ThreadZ = 5
    Sequential = 6

    def __str__(self):
        # Snake-case IR spelling for each member.
        spelling = {
            Processor.BlockX: "block_x",
            Processor.BlockY: "block_y",
            Processor.BlockZ: "block_z",
            Processor.ThreadX: "thread_x",
            Processor.ThreadY: "thread_y",
            Processor.ThreadZ: "thread_z",
            Processor.Sequential: "sequential",
        }
        if self in spelling:
            return spelling[self]
        raise ValueError("Unknown Processor enum entry.")
@register_attribute_builder("ProcessorEnum", allow_existing=True)
def _processorenum(x, context):
    # Converts a Processor member (or any int-convertible) to a signless
    # 64-bit IntegerAttr in `context`.
    return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
@register_attribute_builder("gpu.GPU_AddressSpaceAttr")
def _gpu_addressspaceattr(x, context):
    # Builds the attribute by round-tripping through its textual IR form.
    return _ods_ir.Attribute.parse(f'#gpu.address_space<{str(x)}>', context=context)
@register_attribute_builder("gpu.GPU_AllReduceOperationAttr")
def _gpu_allreduceoperationattr(x, context):
    # Builds the attribute by round-tripping through its textual IR form.
    return _ods_ir.Attribute.parse(f'#gpu<all_reduce_op {str(x)}>', context=context)
@register_attribute_builder("gpu.GPU_BroadcastTypeAttr")
def _gpu_broadcasttypeattr(x, context):
    # Builds the attribute by round-tripping through its textual IR form.
    return _ods_ir.Attribute.parse(f'#gpu<broadcast {str(x)}>', context=context)
@register_attribute_builder("gpu.GPU_DimensionAttr")
def _gpu_dimensionattr(x, context):
    # Builds the attribute by round-tripping through its textual IR form.
    return _ods_ir.Attribute.parse(f'#gpu<dim {str(x)}>', context=context)
@register_attribute_builder("gpu.GPU_Prune2To4SpMatFlagAttr")
def _gpu_prune2to4spmatflagattr(x, context):
    # Builds the attribute by round-tripping through its textual IR form.
    return _ods_ir.Attribute.parse(f'#gpu<prune_2to4_spmat_flag {str(x)}>', context=context)
@register_attribute_builder("gpu.GPU_ShuffleModeAttr")
def _gpu_shufflemodeattr(x, context):
    # Builds the attribute by round-tripping through its textual IR form.
    return _ods_ir.Attribute.parse(f'#gpu<shuffle_mode {str(x)}>', context=context)
@register_attribute_builder("gpu.GPU_SpGEMMWorkEstimationOrComputeKindAttr")
def _gpu_spgemmworkestimationorcomputekindattr(x, context):
    # Builds the attribute by round-tripping through its textual IR form.
    return _ods_ir.Attribute.parse(f'#gpu<spgemm_work_estimation_or_compute_kind {str(x)}>', context=context)
@register_attribute_builder("gpu.GPU_TransposeModeAttr")
def _gpu_transposemodeattr(x, context):
    # Builds the attribute by round-tripping through its textual IR form.
    return _ods_ir.Attribute.parse(f'#gpu<mat_transpose_mode {str(x)}>', context=context)
@register_attribute_builder("gpu.MMAElementWiseAttr")
def _mmaelementwiseattr(x, context):
    # Builds the attribute by round-tripping through its textual IR form.
    return _ods_ir.Attribute.parse(f'#gpu<mma_element_wise {str(x)}>', context=context)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,951 @@
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
import sys as _sys
if _sys.version_info >= (3, 12):
from collections.abc import Buffer as _Buffer # pytype: disable=not-supported-yet
else:
try:
from typing_extensions import Buffer as _Buffer
except ImportError:
_Buffer = _Any
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
    # Registers the "mpmd" dialect namespace for all ops defined in this file.
    DIALECT_NAMESPACE = "mpmd"
@_ods_cext.register_operation(_Dialect)
class AssignOp(_ods_ir.OpView):
    r"""
    Assigns a local tensor to a mesh as fully replicated within that mesh.
    This is a temporary op that is introduced when lowering jax ops, to move
    from local types to mesh types. These ops will be eliminated during import,
    when the inputs and results of the func op become mesh tensors.
    The mesh name of the result type should correspond to a mesh in the
    topology, and its global type should be identical to the operand type.
    The origin of the assign op is the origin of mesh, e.g. named_computation,
    mesh inference, etc.
    """
    OPERATION_NAME = "mpmd.assign"
    _ODS_REGIONS = (0, True)
    def __init__(self, result: _ods_ir.Type, tensor: _ods_ir.Value, *, origin: _Optional[_Union[str, _ods_ir.StringAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an `mpmd.assign` op with one operand and one result type."""
        operands = []
        attributes = {}
        regions = None
        operands.append(tensor)
        _ods_context = _ods_get_default_loc_context(loc)
        # `origin` passes through untouched if it is already an Attribute (or no
        # 'StrAttr' builder is registered); otherwise it is converted via the
        # registered attribute builder.
        if origin is not None: attributes["origin"] = (origin if (
            isinstance(origin, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('StrAttr')) else
              _ods_ir.AttrBuilder.get('StrAttr')(origin, context=_ods_context))
        results = []
        results.append(result)
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        # Sole operand: the local tensor being assigned.
        return self.operation.operands[0]
    @builtins.property
    def origin(self) -> _Optional[_ods_ir.StringAttr]:
        # Optional attribute: absent means None.
        if "origin" not in self.operation.attributes:
            return None
        return self.operation.attributes["origin"]
    @origin.setter
    def origin(self, value: _Optional[_ods_ir.StringAttr]):
        # Assigning None removes the optional attribute.
        if value is not None:
            self.operation.attributes["origin"] = value
        elif "origin" in self.operation.attributes:
            del self.operation.attributes["origin"]
    @origin.deleter
    def origin(self):
        del self.operation.attributes["origin"]
    @builtins.property
    def result(self) -> _ods_ir.OpResult:
        return self.operation.results[0]
@_ods_cext.register_op_adaptor(AssignOp)
class AssignOpAdaptor(_ods_ir.OpAdaptor):
    # Read-only accessors mirroring AssignOp's operand/attribute getters.
    OPERATION_NAME = "mpmd.assign"
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        return self.operands[0]
    @builtins.property
    def origin(self) -> _Optional[_ods_ir.StringAttr]:
        # Optional attribute: absent means None.
        if "origin" not in self.attributes:
            return None
        return self.attributes["origin"]
def assign(result: _ods_ir.Type, tensor: _ods_ir.Value, *, origin: _Optional[_Union[str, _ods_ir.StringAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
    """Builds an `mpmd.assign` op and returns its single result value."""
    op = AssignOp(result=result, tensor=tensor, origin=origin, loc=loc, ip=ip)
    return op.result
@_ods_cext.register_operation(_Dialect)
class BroadcastOp(_ods_ir.OpView):
    r"""
    Allows for a tensor to be transferred (or replicated) in any mesh where it's
    used. Whenever transferred, the origin of the transfer is the current
    location of the operand.
    """
    OPERATION_NAME = "mpmd.broadcast"
    _ODS_REGIONS = (0, True)
    def __init__(self, tensor: _ods_ir.Value, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an `mpmd.broadcast` op; result type defaults to the operand type."""
        operands = []
        attributes = {}
        regions = None
        operands.append(tensor)
        _ods_context = _ods_get_default_loc_context(loc)
        # When no result types are given, infer the single result type from the
        # operand's type.
        if results is None: results = [operands[0].type] * 1
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        return self.operation.operands[0]
    @builtins.property
    def result(self) -> _ods_ir.OpResult:
        return self.operation.results[0]
@_ods_cext.register_op_adaptor(BroadcastOp)
class BroadcastOpAdaptor(_ods_ir.OpAdaptor):
    # Read-only accessors mirroring BroadcastOp's operand getters.
    OPERATION_NAME = "mpmd.broadcast"
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        return self.operands[0]
def broadcast(tensor: _ods_ir.Value, *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
    """Builds an `mpmd.broadcast` op and returns its single result value."""
    op = BroadcastOp(tensor=tensor, results=results, loc=loc, ip=ip)
    return op.result
@_ods_cext.register_operation(_Dialect)
class CallOp(_ods_ir.OpView):
    r"""
    A function call operation. Useful to wrap the body of loops in function
    declarations to reduce code size, for example.
    """
    OPERATION_NAME = "mpmd.call"
    _ODS_REGIONS = (0, True)
    def __init__(self, result: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an `mpmd.call` op with variadic operands/results and a callee symbol."""
        operands = []
        attributes = {}
        regions = None
        operands.extend(_get_op_results_or_values(tensors))
        _ods_context = _ods_get_default_loc_context(loc)
        # `callee` passes through untouched if already an Attribute (or no
        # 'FlatSymbolRefAttr' builder is registered); otherwise converted.
        attributes["callee"] = (callee if (
            isinstance(callee, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
              _ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(callee, context=_ods_context))
        results = []
        results.extend(result)
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands (generated
        # arithmetic: total - fixed(1) + own(1)).
        _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
        return self.operation.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def callee(self) -> _ods_ir.FlatSymbolRefAttr:
        return self.operation.attributes["callee"]
    @callee.setter
    def callee(self, value: _ods_ir.FlatSymbolRefAttr):
        # Mandatory attribute: rejects None.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["callee"] = value
@_ods_cext.register_op_adaptor(CallOp)
class CallOpAdaptor(_ods_ir.OpAdaptor):
    # Read-only accessors mirroring CallOp's operand/attribute getters.
    OPERATION_NAME = "mpmd.call"
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operands) - 1 + 1
        return self.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def callee(self) -> _ods_ir.FlatSymbolRefAttr:
        return self.attributes["callee"]
def call(result: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, CallOp]:
    """Builds an `mpmd.call` op; returns its result(s), or the op when it has none."""
    op = CallOp(result=result, tensors=tensors, callee=callee, loc=loc, ip=ip)
    produced = op.results
    if len(produced) > 1:
        return produced
    if len(produced) == 1:
        return produced[0]
    return op
@_ods_cext.register_operation(_Dialect)
class ForOp(_ods_ir.OpView):
    r"""
    Returns the result of executing a body function for a fixed number of
    iterations, with the iteration index available in the body.
    An optional unroll factor, that must divide the number of iterations,
    can be specified to unroll the body of the op by that factor, i.e. for
    unroll factor N, the body is replicated to create N copies and the number of
    iterations is reduced by a factor of 1/N. Each copy except the first uses
    the results of the previous copy instead of the block arguments, and the
    iteration index is multiplied by the unroll factor and incremented after
    every copy.
    A for operator can accept and return any types, but the TypeID of these
    must be the same -- e.g. all tensor types or all MPMD mesh types etc. This
    allows us to use the op at various levels, sharing implementation and
    transformations.
    """
    OPERATION_NAME = "mpmd.for"
    _ODS_REGIONS = (1, True)
    def __init__(self, results_: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], iterations: _Union[int, _ods_ir.IntegerAttr], *, unroll_factor: _Optional[_Union[int, _ods_ir.IntegerAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an `mpmd.for` op with a single body region."""
        operands = []
        attributes = {}
        regions = None
        operands.extend(_get_op_results_or_values(tensors))
        _ods_context = _ods_get_default_loc_context(loc)
        # Both integer attributes use the 'UI32Attr' builder unless already
        # Attributes (or no builder is registered).
        attributes["iterations"] = (iterations if (
            isinstance(iterations, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('UI32Attr')) else
              _ods_ir.AttrBuilder.get('UI32Attr')(iterations, context=_ods_context))
        if unroll_factor is not None: attributes["unroll_factor"] = (unroll_factor if (
            isinstance(unroll_factor, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('UI32Attr')) else
              _ods_ir.AttrBuilder.get('UI32Attr')(unroll_factor, context=_ods_context))
        results = []
        results.extend(results_)
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
        return self.operation.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def iterations(self) -> _ods_ir.IntegerAttr:
        return self.operation.attributes["iterations"]
    @iterations.setter
    def iterations(self, value: _ods_ir.IntegerAttr):
        # Mandatory attribute: rejects None.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["iterations"] = value
    @builtins.property
    def unroll_factor(self) -> _Optional[_ods_ir.IntegerAttr]:
        # Optional attribute: absent means None.
        if "unroll_factor" not in self.operation.attributes:
            return None
        return self.operation.attributes["unroll_factor"]
    @unroll_factor.setter
    def unroll_factor(self, value: _Optional[_ods_ir.IntegerAttr]):
        # Assigning None removes the optional attribute.
        if value is not None:
            self.operation.attributes["unroll_factor"] = value
        elif "unroll_factor" in self.operation.attributes:
            del self.operation.attributes["unroll_factor"]
    @unroll_factor.deleter
    def unroll_factor(self):
        del self.operation.attributes["unroll_factor"]
    @builtins.property
    def results_(self) -> _ods_ir.OpResultList:
        # Single variadic result group spanning all results.
        _ods_variadic_group_length = len(self.operation.results) - 1 + 1
        return self.operation.results[0:0 + _ods_variadic_group_length]
    @builtins.property
    def region(self) -> _ods_ir.Region:
        # The loop body region.
        return self.regions[0]
@_ods_cext.register_op_adaptor(ForOp)
class ForOpAdaptor(_ods_ir.OpAdaptor):
    # Read-only accessors mirroring ForOp's operand/attribute getters.
    OPERATION_NAME = "mpmd.for"
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operands) - 1 + 1
        return self.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def iterations(self) -> _ods_ir.IntegerAttr:
        return self.attributes["iterations"]
    @builtins.property
    def unroll_factor(self) -> _Optional[_ods_ir.IntegerAttr]:
        # Optional attribute: absent means None.
        if "unroll_factor" not in self.attributes:
            return None
        return self.attributes["unroll_factor"]
def for_(results_: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], iterations: _Union[int, _ods_ir.IntegerAttr], *, unroll_factor: _Optional[_Union[int, _ods_ir.IntegerAttr]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, ForOp]:
    """Builds an `mpmd.for` op; returns its result(s), or the op when it has none."""
    op = ForOp(results_=results_, tensors=tensors, iterations=iterations, unroll_factor=unroll_factor, loc=loc, ip=ip)
    produced = op.results
    if len(produced) > 1:
        return produced
    if len(produced) == 1:
        return produced[0]
    return op
@_ods_cext.register_operation(_Dialect)
class FragmentCallOp(_ods_ir.OpView):
    r"""
    Represents a call to a function that holds an MPMD fragment body, i.e. a
    computation assigned to a specific mesh in an MPMD topology, that is
    intended to be executed as an individual SPMD program fragment.
    The mesh name of the fragment should correspond to a mesh in the topology of
    the enclosing function, and that mesh shape should match that of the callee.
    The origin specifies the user named computations that contributed to this
    fragment call e.g. through merging.
    The function input and result types of the callee must be the local tensor
    types of the corresponding mesh tensors of this op's operands and results
    respectively.
    Example:
    ```mlir
    %2 = mpmd.fragment_call<mesh="m1", origin=[]> @my_fragment(%0, %1) :
        (mesh_tensor<...>, mesh_tensor<...>) -> mesh_tensor<...>
    ```
    """
    OPERATION_NAME = "mpmd.fragment_call"
    _ODS_REGIONS = (0, True)
    def __init__(self, result: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.ArrayAttr], mesh_name: _Union[str, _ods_ir.StringAttr], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an `mpmd.fragment_call` op with origin, mesh_name and callee attributes."""
        operands = []
        attributes = {}
        regions = None
        operands.extend(_get_op_results_or_values(tensors))
        _ods_context = _ods_get_default_loc_context(loc)
        # 'anonymous_847' is the tblgen-generated name of the origin array's
        # attribute builder; values pass through if already Attributes.
        attributes["origin"] = (origin if (
            isinstance(origin, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('anonymous_847')) else
              _ods_ir.AttrBuilder.get('anonymous_847')(origin, context=_ods_context))
        attributes["mesh_name"] = (mesh_name if (
            isinstance(mesh_name, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('StrAttr')) else
              _ods_ir.AttrBuilder.get('StrAttr')(mesh_name, context=_ods_context))
        attributes["callee"] = (callee if (
            isinstance(callee, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('FlatSymbolRefAttr')) else
              _ods_ir.AttrBuilder.get('FlatSymbolRefAttr')(callee, context=_ods_context))
        results = []
        results.extend(result)
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
        return self.operation.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def origin(self) -> _ods_ir.ArrayAttr:
        return self.operation.attributes["origin"]
    @origin.setter
    def origin(self, value: _ods_ir.ArrayAttr):
        # Mandatory attribute: rejects None.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["origin"] = value
    @builtins.property
    def mesh_name(self) -> _ods_ir.StringAttr:
        return self.operation.attributes["mesh_name"]
    @mesh_name.setter
    def mesh_name(self, value: _ods_ir.StringAttr):
        # Mandatory attribute: rejects None.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["mesh_name"] = value
    @builtins.property
    def callee(self) -> _ods_ir.FlatSymbolRefAttr:
        return self.operation.attributes["callee"]
    @callee.setter
    def callee(self, value: _ods_ir.FlatSymbolRefAttr):
        # Mandatory attribute: rejects None.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["callee"] = value
@_ods_cext.register_op_adaptor(FragmentCallOp)
class FragmentCallOpAdaptor(_ods_ir.OpAdaptor):
    # Read-only accessors mirroring FragmentCallOp's operand/attribute getters.
    OPERATION_NAME = "mpmd.fragment_call"
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operands) - 1 + 1
        return self.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def origin(self) -> _ods_ir.ArrayAttr:
        return self.attributes["origin"]
    @builtins.property
    def mesh_name(self) -> _ods_ir.StringAttr:
        return self.attributes["mesh_name"]
    @builtins.property
    def callee(self) -> _ods_ir.FlatSymbolRefAttr:
        return self.attributes["callee"]
def fragment_call(result: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.ArrayAttr], mesh_name: _Union[str, _ods_ir.StringAttr], callee: _Union[str, _ods_ir.FlatSymbolRefAttr], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, FragmentCallOp]:
    """Builds an `mpmd.fragment_call` op; returns its result(s), or the op when it has none."""
    op = FragmentCallOp(result=result, tensors=tensors, origin=origin, mesh_name=mesh_name, callee=callee, loc=loc, ip=ip)
    produced = op.results
    if len(produced) > 1:
        return produced
    if len(produced) == 1:
        return produced[0]
    return op
@_ods_cext.register_operation(_Dialect)
class FragmentOp(_ods_ir.OpView):
    r"""
    Assigns a computation, i.e. a block of operations, to a specific mesh in an
    MPMD topology, that is intended to be executed as an individual SPMD program
    fragment.
    The fragment takes and returns only mesh tensors that are assigned to the
    same mesh as the fragment.
    The mesh name of the fragment should correspond to a mesh in the topology.
    The fragment includes a list of origins, i.e., metadata with information re
    the original named_computations that formed this fragment, and a staged_id
    defined _iff_ it is a user defined fragment, i.e., it has a non-empty list
    of origins. The optional in_shardings specifies the sharding of the
    block arguments of a fragment, which correspond to the operands.
    The optional out_shardings specifies the shardings of the results.
    The fragment's region shouldn't have any free variables, and the type of
    each block arguments and returned values in the region is the global tensor
    type of the corresponding mesh tensor.
    """
    OPERATION_NAME = "mpmd.fragment"
    _ODS_REGIONS = (1, True)
    def __init__(self, results_: _Sequence[_ods_ir.Type], inputs: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.ArrayAttr], mesh_name: _Union[str, _ods_ir.StringAttr], *, stage_id: _Optional[_Union[int, _ods_ir.IntegerAttr]] = None, in_shardings: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, out_shardings: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an `mpmd.fragment` op with a single body region."""
        operands = []
        attributes = {}
        regions = None
        operands.extend(_get_op_results_or_values(inputs))
        _ods_context = _ods_get_default_loc_context(loc)
        # Each attribute passes through if already an Attribute (or its builder
        # is not registered); otherwise it is converted via the named builder.
        # 'anonymous_847' is the tblgen-generated name for the origin array.
        attributes["origin"] = (origin if (
            isinstance(origin, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('anonymous_847')) else
              _ods_ir.AttrBuilder.get('anonymous_847')(origin, context=_ods_context))
        attributes["mesh_name"] = (mesh_name if (
            isinstance(mesh_name, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('StrAttr')) else
              _ods_ir.AttrBuilder.get('StrAttr')(mesh_name, context=_ods_context))
        if stage_id is not None: attributes["stage_id"] = (stage_id if (
            isinstance(stage_id, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('I64Attr')) else
              _ods_ir.AttrBuilder.get('I64Attr')(stage_id, context=_ods_context))
        if in_shardings is not None: attributes["in_shardings"] = (in_shardings if (
            isinstance(in_shardings, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
              _ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(in_shardings, context=_ods_context))
        if out_shardings is not None: attributes["out_shardings"] = (out_shardings if (
            isinstance(out_shardings, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('Sdy_TensorShardingPerValue')) else
              _ods_ir.AttrBuilder.get('Sdy_TensorShardingPerValue')(out_shardings, context=_ods_context))
        results = []
        results.extend(results_)
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def inputs(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
        return self.operation.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def origin(self) -> _ods_ir.ArrayAttr:
        return self.operation.attributes["origin"]
    @origin.setter
    def origin(self, value: _ods_ir.ArrayAttr):
        # Mandatory attribute: rejects None.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["origin"] = value
    @builtins.property
    def mesh_name(self) -> _ods_ir.StringAttr:
        return self.operation.attributes["mesh_name"]
    @mesh_name.setter
    def mesh_name(self, value: _ods_ir.StringAttr):
        # Mandatory attribute: rejects None.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["mesh_name"] = value
    @builtins.property
    def stage_id(self) -> _Optional[_ods_ir.IntegerAttr]:
        # Optional attribute: absent means None.
        if "stage_id" not in self.operation.attributes:
            return None
        return self.operation.attributes["stage_id"]
    @stage_id.setter
    def stage_id(self, value: _Optional[_ods_ir.IntegerAttr]):
        # Assigning None removes the optional attribute.
        if value is not None:
            self.operation.attributes["stage_id"] = value
        elif "stage_id" in self.operation.attributes:
            del self.operation.attributes["stage_id"]
    @stage_id.deleter
    def stage_id(self):
        del self.operation.attributes["stage_id"]
    @builtins.property
    def in_shardings(self) -> _Optional[_ods_ir.Attribute]:
        # Optional attribute: absent means None.
        if "in_shardings" not in self.operation.attributes:
            return None
        return self.operation.attributes["in_shardings"]
    @in_shardings.setter
    def in_shardings(self, value: _Optional[_ods_ir.Attribute]):
        # Assigning None removes the optional attribute.
        if value is not None:
            self.operation.attributes["in_shardings"] = value
        elif "in_shardings" in self.operation.attributes:
            del self.operation.attributes["in_shardings"]
    @in_shardings.deleter
    def in_shardings(self):
        del self.operation.attributes["in_shardings"]
    @builtins.property
    def out_shardings(self) -> _Optional[_ods_ir.Attribute]:
        # Optional attribute: absent means None.
        if "out_shardings" not in self.operation.attributes:
            return None
        return self.operation.attributes["out_shardings"]
    @out_shardings.setter
    def out_shardings(self, value: _Optional[_ods_ir.Attribute]):
        # Assigning None removes the optional attribute.
        if value is not None:
            self.operation.attributes["out_shardings"] = value
        elif "out_shardings" in self.operation.attributes:
            del self.operation.attributes["out_shardings"]
    @out_shardings.deleter
    def out_shardings(self):
        del self.operation.attributes["out_shardings"]
    @builtins.property
    def results_(self) -> _ods_ir.OpResultList:
        # Single variadic result group spanning all results.
        _ods_variadic_group_length = len(self.operation.results) - 1 + 1
        return self.operation.results[0:0 + _ods_variadic_group_length]
    @builtins.property
    def region(self) -> _ods_ir.Region:
        # The fragment body region.
        return self.regions[0]
@_ods_cext.register_op_adaptor(FragmentOp)
class FragmentOpAdaptor(_ods_ir.OpAdaptor):
    # Read-only accessors mirroring FragmentOp's operand/attribute getters.
    OPERATION_NAME = "mpmd.fragment"
    @builtins.property
    def inputs(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operands) - 1 + 1
        return self.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def origin(self) -> _ods_ir.ArrayAttr:
        return self.attributes["origin"]
    @builtins.property
    def mesh_name(self) -> _ods_ir.StringAttr:
        return self.attributes["mesh_name"]
    @builtins.property
    def stage_id(self) -> _Optional[_ods_ir.IntegerAttr]:
        # Optional attribute: absent means None.
        if "stage_id" not in self.attributes:
            return None
        return self.attributes["stage_id"]
    @builtins.property
    def in_shardings(self) -> _Optional[_ods_ir.Attribute]:
        # Optional attribute: absent means None.
        if "in_shardings" not in self.attributes:
            return None
        return self.attributes["in_shardings"]
    @builtins.property
    def out_shardings(self) -> _Optional[_ods_ir.Attribute]:
        # Optional attribute: absent means None.
        if "out_shardings" not in self.attributes:
            return None
        return self.attributes["out_shardings"]
def fragment(results_: _Sequence[_ods_ir.Type], inputs: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.ArrayAttr], mesh_name: _Union[str, _ods_ir.StringAttr], *, stage_id: _Optional[_Union[int, _ods_ir.IntegerAttr]] = None, in_shardings: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, out_shardings: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, FragmentOp]:
    """Builds an `mpmd.fragment` op; returns its result(s), or the op when it has none."""
    op = FragmentOp(results_=results_, inputs=inputs, origin=origin, mesh_name=mesh_name, stage_id=stage_id, in_shardings=in_shardings, out_shardings=out_shardings, loc=loc, ip=ip)
    produced = op.results
    if len(produced) > 1:
        return produced
    if len(produced) == 1:
        return produced[0]
    return op
@_ods_cext.register_operation(_Dialect)
class NamedComputationOp(_ods_ir.OpView):
    r"""
    Groups a computation, i.e. a block of operations, and gives it a name and
    a transpose count via the UserOrigin attribute. This NamedComputation can be
    used to assign a mesh to the computation in MPMD or for optimizations.
    The transpose count (default=0) denotes whether the named computation has
    been produced by a certain number of JAX AD transpose transformations.
    The op's region shouldn't have any free variables, and the type of
    each block arguments and returned values in the region must be the same as
    the type of the inputs and the return type of the op.
    """
    OPERATION_NAME = "mpmd.named_computation"
    _ODS_REGIONS = (1, True)
    def __init__(self, results_: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.Attribute], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an `mpmd.named_computation` op with a single body region."""
        operands = []
        attributes = {}
        regions = None
        operands.extend(_get_op_results_or_values(tensors))
        _ods_context = _ods_get_default_loc_context(loc)
        # `origin` passes through if already an Attribute (or no
        # 'Mpmd_UserOrigin' builder is registered); otherwise converted.
        attributes["origin"] = (origin if (
            isinstance(origin, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('Mpmd_UserOrigin')) else
              _ods_ir.AttrBuilder.get('Mpmd_UserOrigin')(origin, context=_ods_context))
        results = []
        results.extend(results_)
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
        return self.operation.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def origin(self) -> _ods_ir.Attribute:
        return self.operation.attributes["origin"]
    @origin.setter
    def origin(self, value: _ods_ir.Attribute):
        # Mandatory attribute: rejects None.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["origin"] = value
    @builtins.property
    def results_(self) -> _ods_ir.OpResultList:
        # Single variadic result group spanning all results.
        _ods_variadic_group_length = len(self.operation.results) - 1 + 1
        return self.operation.results[0:0 + _ods_variadic_group_length]
    @builtins.property
    def region(self) -> _ods_ir.Region:
        # The named computation's body region.
        return self.regions[0]
@_ods_cext.register_op_adaptor(NamedComputationOp)
class NamedComputationOpAdaptor(_ods_ir.OpAdaptor):
    # Read-only accessors mirroring NamedComputationOp's getters.
    OPERATION_NAME = "mpmd.named_computation"
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operands) - 1 + 1
        return self.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def origin(self) -> _ods_ir.Attribute:
        return self.attributes["origin"]
def named_computation(results_: _Sequence[_ods_ir.Type], tensors: _Sequence[_ods_ir.Value], origin: _Union[_Any, _ods_ir.Attribute], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, NamedComputationOp]:
    """Builds an `mpmd.named_computation` op; returns its result(s), or the op when it has none."""
    op = NamedComputationOp(results_=results_, tensors=tensors, origin=origin, loc=loc, ip=ip)
    produced = op.results
    if len(produced) > 1:
        return produced
    if len(produced) == 1:
        return produced[0]
    return op
@_ods_cext.register_operation(_Dialect)
class NamedTensorOp(_ods_ir.OpView):
    r"""
    An identity op that associates the result of the tensor with a given name.
    This NamedTensor can be used to assign a mesh to the tensor in MPMD.
    NOTE: this is different than TagOp in that TagOp is used for naming a tensor
    and can be used to partition that tensor. NamedTensorOp is for MPMD programs
    for tensors that may be explicitly assigned to meshes.
    """
    OPERATION_NAME = "mpmd.named_tensor"
    _ODS_REGIONS = (0, True)
    def __init__(self, tensor: _ods_ir.Value, name: _Union[str, _ods_ir.StringAttr], *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an `mpmd.named_tensor` op; result type defaults to the operand type."""
        operands = []
        attributes = {}
        regions = None
        operands.append(tensor)
        _ods_context = _ods_get_default_loc_context(loc)
        # `name` passes through if already an Attribute (or no 'StrAttr'
        # builder is registered); otherwise converted.
        attributes["name"] = (name if (
            isinstance(name, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('StrAttr')) else
              _ods_ir.AttrBuilder.get('StrAttr')(name, context=_ods_context))
        # When no result types are given, infer the single result type from the
        # operand's type (identity op).
        if results is None: results = [operands[0].type] * 1
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        return self.operation.operands[0]
    @builtins.property
    def name(self) -> _ods_ir.StringAttr:
        return self.operation.attributes["name"]
    @name.setter
    def name(self, value: _ods_ir.StringAttr):
        # Mandatory attribute: rejects None.
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["name"] = value
    @builtins.property
    def result(self) -> _ods_ir.OpResult:
        return self.operation.results[0]
@_ods_cext.register_op_adaptor(NamedTensorOp)
class NamedTensorOpAdaptor(_ods_ir.OpAdaptor):
    """Adaptor mirroring NamedTensorOp's operand/attribute accessors."""
    OPERATION_NAME = "mpmd.named_tensor"
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        # The single operand being named.
        return self.operands[0]
    @builtins.property
    def name(self) -> _ods_ir.StringAttr:
        # Mandatory "name" attribute.
        return self.attributes["name"]
def named_tensor(tensor: _ods_ir.Value, name: _Union[str, _ods_ir.StringAttr], *, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
    """Creates an mpmd.named_tensor op and returns its single result."""
    op = NamedTensorOp(tensor=tensor, name=name, results=results, loc=loc, ip=ip)
    return op.result
@_ods_cext.register_operation(_Dialect)
class ReduceOp(_ods_ir.OpView):
    r"""
    Allows for a tensor to be reduced across different meshes, and then
    broadcast to wherever it needs to be used.
    """
    OPERATION_NAME = "mpmd.reduce"
    _ODS_REGIONS = (0, True)
    def __init__(self, tensors: _Sequence[_ods_ir.Value], *, reduction: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an mpmd.reduce op over the variadic `tensors` operands.

        `reduction` is optional and only attached when provided; a non-Attribute
        value is converted through the registered 'Mpmd_Reduction' builder.
        """
        operands = []
        attributes = {}
        regions = None
        operands.extend(_get_op_results_or_values(tensors))
        _ods_context = _ods_get_default_loc_context(loc)
        if reduction is not None: attributes["reduction"] = (reduction if (
            isinstance(reduction, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('Mpmd_Reduction')) else
              _ods_ir.AttrBuilder.get('Mpmd_Reduction')(reduction, context=_ods_context))
        # NOTE(review): the default result type reads operands[0] — assumes at
        # least one tensor when results is None; confirm callers never pass an
        # empty list without explicit results.
        if results is None: results = [operands[0].type] * 1
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
        return self.operation.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def reduction(self) -> _ods_ir.Attribute:
        # NOTE(review): "reduction" is optional at construction but this getter
        # indexes unconditionally, so it raises KeyError when absent.
        return self.operation.attributes["reduction"]
    @reduction.setter
    def reduction(self, value: _ods_ir.Attribute):
        if value is None:
            raise ValueError("'None' not allowed as value for mandatory attributes")
        self.operation.attributes["reduction"] = value
    @builtins.property
    def result(self) -> _ods_ir.OpResult:
        # The op's single result.
        return self.operation.results[0]
@_ods_cext.register_op_adaptor(ReduceOp)
class ReduceOpAdaptor(_ods_ir.OpAdaptor):
    """Adaptor mirroring ReduceOp's operand/attribute accessors."""
    OPERATION_NAME = "mpmd.reduce"
    @builtins.property
    def tensors(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operands) - 1 + 1
        return self.operands[0:0 + _ods_variadic_group_length]
    @builtins.property
    def reduction(self) -> _ods_ir.Attribute:
        # NOTE(review): "reduction" is optional on the op; this raises KeyError
        # when the attribute is absent.
        return self.attributes["reduction"]
def reduce(tensors: _Sequence[_ods_ir.Value], *, reduction: _Optional[_Union[_Any, _ods_ir.Attribute]] = None, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
    """Creates an mpmd.reduce op and returns its single result."""
    op = ReduceOp(tensors=tensors, reduction=reduction, results=results, loc=loc, ip=ip)
    return op.result
@_ods_cext.register_operation(_Dialect)
class ReturnOp(_ods_ir.OpView):
    """The mpmd.return op: takes a variadic group of operands, no results."""
    OPERATION_NAME = "mpmd.return"
    _ODS_REGIONS = (0, True)
    def __init__(self, results_: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an mpmd.return op yielding the `results_` values."""
        operands = []
        attributes = {}
        regions = None
        operands.extend(_get_op_results_or_values(results_))
        _ods_context = _ods_get_default_loc_context(loc)
        # The op itself produces no SSA results.
        results = []
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def results_(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operation.operands) - 1 + 1
        return self.operation.operands[0:0 + _ods_variadic_group_length]
@_ods_cext.register_op_adaptor(ReturnOp)
class ReturnOpAdaptor(_ods_ir.OpAdaptor):
    """Adaptor mirroring ReturnOp's operand accessors."""
    OPERATION_NAME = "mpmd.return"
    @builtins.property
    def results_(self) -> _ods_ir.OpOperandList:
        # Single variadic operand group spanning all operands.
        _ods_variadic_group_length = len(self.operands) - 1 + 1
        return self.operands[0:0 + _ods_variadic_group_length]
def return_(results_: _Sequence[_ods_ir.Value], *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> ReturnOp:
    """Creates an mpmd.return op and returns the op itself (it has no results)."""
    op = ReturnOp(results_=results_, loc=loc, ip=ip)
    return op
@_ods_cext.register_operation(_Dialect)
class TransferOp(_ods_ir.OpView):
    r"""
    Transfers a distributed tensor from one mesh to another.
    The mesh names of the operand and result types should correspond to meshes
    in the topology, and their global types should be identical.
    """
    OPERATION_NAME = "mpmd.transfer"
    _ODS_REGIONS = (0, True)
    def __init__(self, result: _ods_ir.Type, tensor: _ods_ir.Value, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an mpmd.transfer op.

        Unlike the identity-typed ops in this dialect, the result type must be
        given explicitly (the destination mesh differs from the source's).
        """
        operands = []
        attributes = {}
        regions = None
        operands.append(tensor)
        _ods_context = _ods_get_default_loc_context(loc)
        results = []
        results.append(result)
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        # The single operand being transferred.
        return self.operation.operands[0]
    @builtins.property
    def result(self) -> _ods_ir.OpResult:
        # The op's single result.
        return self.operation.results[0]
@_ods_cext.register_op_adaptor(TransferOp)
class TransferOpAdaptor(_ods_ir.OpAdaptor):
    """Adaptor mirroring TransferOp's operand accessors."""
    OPERATION_NAME = "mpmd.transfer"
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        # The single operand being transferred.
        return self.operands[0]
def transfer(result: _ods_ir.Type, tensor: _ods_ir.Value, *, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
    """Creates an mpmd.transfer op and returns its single result."""
    op = TransferOp(result=result, tensor=tensor, loc=loc, ip=ip)
    return op.result
@_ods_cext.register_operation(_Dialect)
class UnassignOp(_ods_ir.OpView):
    r"""
    Unassigns a fully replicated tensor from a mesh.
    This is a temporary op that is introduced when lowering jax ops, to move
    from local types to mesh types. These ops will be eliminated during import,
    when the inputs and results of the func op become mesh tensors.
    The mesh name of the operand type should correspond to a mesh in the
    topology, and its global type should be identical to the result type.
    """
    OPERATION_NAME = "mpmd.unassign"
    _ODS_REGIONS = (0, True)
    def __init__(self, tensor: _ods_ir.Value, *, origin: _Optional[_Union[str, _ods_ir.StringAttr]] = None, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None):
        """Builds an mpmd.unassign op.

        `origin` is optional and only attached when provided; a plain string is
        converted through the registered 'StrAttr' builder.
        """
        operands = []
        attributes = {}
        regions = None
        operands.append(tensor)
        _ods_context = _ods_get_default_loc_context(loc)
        if origin is not None: attributes["origin"] = (origin if (
            isinstance(origin, _ods_ir.Attribute) or
            not _ods_ir.AttrBuilder.contains('StrAttr')) else
              _ods_ir.AttrBuilder.get('StrAttr')(origin, context=_ods_context))
        _ods_successors = None
        super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, results=results, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        # The single operand being unassigned.
        return self.operation.operands[0]
    @builtins.property
    def origin(self) -> _Optional[_ods_ir.StringAttr]:
        # Optional attribute: returns None instead of raising when absent.
        if "origin" not in self.operation.attributes:
            return None
        return self.operation.attributes["origin"]
    @origin.setter
    def origin(self, value: _Optional[_ods_ir.StringAttr]):
        # Setting None removes the optional attribute if present.
        if value is not None:
            self.operation.attributes["origin"] = value
        elif "origin" in self.operation.attributes:
            del self.operation.attributes["origin"]
    @origin.deleter
    def origin(self):
        del self.operation.attributes["origin"]
    @builtins.property
    def result(self) -> _ods_ir.OpResult:
        # The op's single result.
        return self.operation.results[0]
@_ods_cext.register_op_adaptor(UnassignOp)
class UnassignOpAdaptor(_ods_ir.OpAdaptor):
    """Adaptor mirroring UnassignOp's operand/attribute accessors."""
    OPERATION_NAME = "mpmd.unassign"
    @builtins.property
    def tensor(self) -> _ods_ir.Value:
        # The single operand being unassigned.
        return self.operands[0]
    @builtins.property
    def origin(self) -> _Optional[_ods_ir.StringAttr]:
        # Optional attribute: returns None instead of raising when absent.
        if "origin" not in self.attributes:
            return None
        return self.attributes["origin"]
def unassign(tensor: _ods_ir.Value, *, origin: _Optional[_Union[str, _ods_ir.StringAttr]] = None, results: _Optional[_Sequence[_ods_ir.Type]] = None, loc: _Optional[_ods_ir.Location] = None, ip: _Optional[_ods_ir.InsertionPoint] = None) -> _ods_ir.OpResult:
    """Creates an mpmd.unassign op and returns its single result."""
    op = UnassignOp(tensor=tensor, origin=origin, results=results, loc=loc, ip=ip)
    return op.result
@@ -0,0 +1,147 @@
# Autogenerated by mlir-tblgen; don't manually edit.
from enum import IntEnum, auto, IntFlag
from ._ods_common import _cext as _ods_cext
from ..ir import register_attribute_builder
_ods_ir = _ods_cext.ir
class RcpRoundingMode(IntEnum):
    """Rounding mode of rcp."""
    APPROX = 0
    RN = 1
    RZ = 2
    RM = 3
    RP = 4
    def __str__(self):
        # Spellings used by the MLIR assembly format.
        mnemonics = {
            RcpRoundingMode.APPROX: "approx",
            RcpRoundingMode.RN: "rn",
            RcpRoundingMode.RZ: "rz",
            RcpRoundingMode.RM: "rm",
            RcpRoundingMode.RP: "rp",
        }
        text = mnemonics.get(self)
        if text is None:
            raise ValueError("Unknown RcpRoundingMode enum entry.")
        return text
@register_attribute_builder("RcpRoundingMode", allow_existing=True)
def _rcproundingmode(x, context):
    """Wraps x (an RcpRoundingMode or int-like) into a signless-i32 IntegerAttr."""
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))
class TensorMapInterleaveKind(IntEnum):
    """Tensor map interleave layout type."""
    INTERLEAVE_NONE = 0
    INTERLEAVE_16B = 1
    INTERLEAVE_32B = 2
    def __str__(self):
        # Spellings used by the MLIR assembly format.
        mnemonics = {
            TensorMapInterleaveKind.INTERLEAVE_NONE: "none",
            TensorMapInterleaveKind.INTERLEAVE_16B: "interleave_16b",
            TensorMapInterleaveKind.INTERLEAVE_32B: "interleave_32b",
        }
        text = mnemonics.get(self)
        if text is None:
            raise ValueError("Unknown TensorMapInterleaveKind enum entry.")
        return text
@register_attribute_builder("TensorMapInterleaveKind", allow_existing=True)
def _tensormapinterleavekind(x, context):
    """Wraps x (a TensorMapInterleaveKind or int-like) into a signless-i32 IntegerAttr."""
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))
class TensorMapL2PromoKind(IntEnum):
    """Tensor map L2 promotion type."""
    L2PROMO_NONE = 0
    L2PROMO_64B = 1
    L2PROMO_128B = 2
    L2PROMO_256B = 3
    def __str__(self):
        # Spellings used by the MLIR assembly format.
        mnemonics = {
            TensorMapL2PromoKind.L2PROMO_NONE: "none",
            TensorMapL2PromoKind.L2PROMO_64B: "l2promo_64b",
            TensorMapL2PromoKind.L2PROMO_128B: "l2promo_128b",
            TensorMapL2PromoKind.L2PROMO_256B: "l2promo_256b",
        }
        text = mnemonics.get(self)
        if text is None:
            raise ValueError("Unknown TensorMapL2PromoKind enum entry.")
        return text
@register_attribute_builder("TensorMapL2PromoKind", allow_existing=True)
def _tensormapl2promokind(x, context):
    """Wraps x (a TensorMapL2PromoKind or int-like) into a signless-i32 IntegerAttr."""
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))
class TensorMapOOBKind(IntEnum):
    """Tensor map out-of-bounds fill type."""
    OOB_ZERO = 0
    OOB_NAN = 1
    def __str__(self):
        # Spellings used by the MLIR assembly format.
        mnemonics = {
            TensorMapOOBKind.OOB_ZERO: "zero",
            TensorMapOOBKind.OOB_NAN: "nan",
        }
        text = mnemonics.get(self)
        if text is None:
            raise ValueError("Unknown TensorMapOOBKind enum entry.")
        return text
@register_attribute_builder("TensorMapOOBKind", allow_existing=True)
def _tensormapoobkind(x, context):
    """Wraps x (a TensorMapOOBKind or int-like) into a signless-i32 IntegerAttr."""
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))
class TensorMapSwizzleKind(IntEnum):
    """Tensor map swizzling mode of shared memory banks."""
    SWIZZLE_NONE = 0
    SWIZZLE_32B = 1
    SWIZZLE_64B = 2
    SWIZZLE_128B = 3
    def __str__(self):
        # Spellings used by the MLIR assembly format.
        mnemonics = {
            TensorMapSwizzleKind.SWIZZLE_NONE: "none",
            TensorMapSwizzleKind.SWIZZLE_32B: "swizzle_32b",
            TensorMapSwizzleKind.SWIZZLE_64B: "swizzle_64b",
            TensorMapSwizzleKind.SWIZZLE_128B: "swizzle_128b",
        }
        text = mnemonics.get(self)
        if text is None:
            raise ValueError("Unknown TensorMapSwizzleKind enum entry.")
        return text
@register_attribute_builder("TensorMapSwizzleKind", allow_existing=True)
def _tensormapswizzlekind(x, context):
    """Wraps x (a TensorMapSwizzleKind or int-like) into a signless-i32 IntegerAttr."""
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))
@register_attribute_builder("nvgpu.RcpRoundingModeAttr")
def _rcproundingmodeattr(x, context):
    """Parses an #nvgpu rcp_rounding_mode attribute from the enum's str form."""
    # str(x) is required: f-string {x} would use IntEnum.__format__ (the int).
    text = f'#nvgpu<rcp_rounding_mode {str(x)}>'
    return _ods_ir.Attribute.parse(text, context=context)
@register_attribute_builder("nvgpu.TensorMapInterleaveAttr")
def _tensormapinterleaveattr(x, context):
    """Parses an #nvgpu interleave attribute from the enum's str form."""
    text = f'#nvgpu<interleave {str(x)}>'
    return _ods_ir.Attribute.parse(text, context=context)
@register_attribute_builder("nvgpu.TensorMapL2PromoAttr")
def _tensormapl2promoattr(x, context):
    """Parses an #nvgpu l2promo attribute from the enum's str form."""
    text = f'#nvgpu<l2promo {str(x)}>'
    return _ods_ir.Attribute.parse(text, context=context)
@register_attribute_builder("nvgpu.TensorMapOOBAttr")
def _tensormapoobattr(x, context):
    """Parses an #nvgpu oob attribute from the enum's str form."""
    text = f'#nvgpu<oob {str(x)}>'
    return _ods_ir.Attribute.parse(text, context=context)
@register_attribute_builder("nvgpu.TensorMapSwizzleAttr")
def _tensormapswizzleattr(x, context):
    """Parses an #nvgpu swizzle attribute from the enum's str form."""
    text = f'#nvgpu<swizzle {str(x)}>'
    return _ods_ir.Attribute.parse(text, context=context)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,307 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import (
    Any as _Any,
    List as _List,
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
    Type as _Type,
    Union as _Union,
)
from .._mlir_libs import _mlir as _cext
from ..ir import (
ArrayAttr,
Attribute,
BoolAttr,
DenseI64ArrayAttr,
IntegerAttr,
IntegerType,
OpView,
Operation,
ShapedType,
Value,
)
__all__ = [
"equally_sized_accessor",
"get_default_loc_context",
"get_op_result_or_value",
"get_op_results_or_values",
"get_op_result_or_op_results",
"segmented_accessor",
]
def segmented_accessor(elements, raw_segments, idx):
    """
    Returns a slice of elements corresponding to the idx-th segment.

    elements: a sliceable container (operands or results).
    raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing
      sizes of the segments.
    idx: index of the segment.
    """
    segments = _cext.ir.DenseI32ArrayAttr(raw_segments)
    # Sizes of segments 0..idx; the requested segment starts where the
    # preceding ones end.
    sizes = [segments[i] for i in range(idx + 1)]
    start = sum(sizes[:-1])
    return elements[start : start + sizes[-1]]
def equally_sized_accessor(
    elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic
):
    """
    Returns a starting position and a number of elements per variadic group
    assuming equally-sized groups and the given numbers of preceding groups.

    elements: a sequential container.
    n_simple: the number of non-variadic groups in the container.
    n_variadic: the number of variadic groups in the container.
    n_preceding_simple: the number of non-variadic groups preceding the current
      group.
    n_preceding_variadic: the number of variadic groups preceding the current
      group.
    """
    variadic_total = len(elements) - n_simple
    # Equal group sizes are enforced by the C++-side trait verifier.
    assert variadic_total % n_variadic == 0
    group_size = variadic_total // n_variadic
    offset = n_preceding_simple + n_preceding_variadic * group_size
    return offset, group_size
def get_default_loc_context(location=None):
    """
    Returns a context in which the defaulted location is created. If the
    location is None, takes the current location from the stack; returns None
    when no current location is set either.
    """
    if location is not None:
        return location.context
    current = _cext.ir.Location.current
    return current.context if current else None
def get_op_result_or_value(
    arg: _Union[
        _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList
    ]
) -> _cext.ir.Value:
    """Returns the given value or the single result of the given op.

    This is useful to implement op constructors so that they can take other ops
    as arguments instead of requiring the caller to extract results for every
    op. Raises ValueError if provided with an op that doesn't have a single
    result.
    """
    if isinstance(arg, _cext.ir.OpView):
        return arg.operation.result
    if isinstance(arg, _cext.ir.Operation):
        return arg.result
    if isinstance(arg, _cext.ir.OpResultList):
        return arg[0]
    assert isinstance(arg, _cext.ir.Value), f"expects Value, got {type(arg)}"
    return arg
def get_op_results_or_values(
    arg: _Union[
        _cext.ir.OpView,
        _cext.ir.Operation,
        _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
    ]
) -> _Union[
    _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
    _cext.ir.OpResultList,
]:
    """Returns the given sequence of values or the results of the given op.

    This is useful to implement op constructors so that they can take other ops
    as lists of arguments instead of requiring the caller to extract results
    for every op.
    """
    if isinstance(arg, _cext.ir.OpView):
        return arg.operation.results
    if isinstance(arg, _cext.ir.Operation):
        return arg.results
    return arg
def get_op_result_or_op_results(
    op: _Union[_cext.ir.OpView, _cext.ir.Operation],
) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]:
    """Returns op's single result, its result list, or the op itself.

    A lone result is unwrapped; several results are returned as-is; an op with
    no results is returned as its underlying Operation (for an OpView) or
    unchanged.
    """
    results = op.results
    count = len(results)
    if count == 1:
        return results[0]
    if count > 1:
        return results
    if isinstance(op, _cext.ir.OpView):
        return op.operation
    return op
# Anything that can stand in for an SSA value: an op (whose results are used),
# an OpView wrapper, or a Value itself.
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
# A single value-like or a sequence of them (variadic operand groups).
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
# An integer known at IR-construction time, either raw or as an attribute.
StaticIntLike = _Union[int, IntegerAttr]
ValueLike = _Union[Operation, OpView, Value]
# An integer that is either static or provided dynamically as an SSA value.
MixedInt = _Union[StaticIntLike, ValueLike]
IntOrAttrList = _Sequence[_Union[IntegerAttr, int]]
OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]]
BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]]
OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]]
# Mixed static/dynamic integers: a sequence of ints/attrs/values, a packed
# ArrayAttr, or a single value-like handle.
MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike]
# Each entry is a mixed int, or a singleton sequence marking it as scalable.
DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]]
def _dispatch_dynamic_index_list(
    indices: _Union[DynamicIndexList, ArrayAttr],
) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]:
    """Dispatches a list of indices to the appropriate form.

    This is similar to the custom `DynamicIndexList` directive upstream:
    provided indices may be in the form of dynamic SSA values or static values,
    and they may be scalable (i.e., as a singleton list) or not. This function
    dispatches each index into its respective form. It also extracts the SSA
    values and static indices from various similar structures, respectively.

    Returns a (dynamic_indices, static_indices, scalable_indices) triple:
    static_indices[i] holds ShapedType.get_dynamic_size() when index i was an
    SSA value (appended to dynamic_indices instead), and scalable_indices[i]
    is True for indices that arrived wrapped in a singleton sequence.
    """
    dynamic_indices = []
    # Pre-fill with the dynamic-size sentinel; overwritten for static entries.
    static_indices = [ShapedType.get_dynamic_size()] * len(indices)
    scalable_indices = [False] * len(indices)
    # ArrayAttr: Extract index values.
    if isinstance(indices, ArrayAttr):
        indices = [idx for idx in indices]

    def process_nonscalable_index(i, index):
        """Processes any form of non-scalable index.

        Returns False if the given index was scalable and thus remains
        unprocessed; True otherwise.

        NOTE: mutates static_indices/dynamic_indices in the enclosing scope.
        """
        if isinstance(index, int):
            static_indices[i] = index
        elif isinstance(index, IntegerAttr):
            static_indices[i] = index.value  # pytype: disable=attribute-error
        elif isinstance(index, (Operation, Value, OpView)):
            dynamic_indices.append(index)
        else:
            return False
        return True

    # Process each index at a time.
    for i, index in enumerate(indices):
        if not process_nonscalable_index(i, index):
            # If it wasn't processed, it must be a scalable index, which is
            # provided as a _Sequence of one value, so extract and process that.
            scalable_indices[i] = True
            assert len(index) == 1
            ret = process_nonscalable_index(i, index[0])
            assert ret
    return dynamic_indices, static_indices, scalable_indices
# Dispatches `MixedValues` that all represents integers in various forms into
# the following three categories:
# - `dynamic_values`: a list of `Value`s, potentially from op results;
# - `packed_values`: a value handle, potentially from an op result, associated
# to one or more payload operations of integer type;
# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python
# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`.
# The input is in the form for `packed_values`, only that result is set and the
# other two are empty. Otherwise, the input can be a mix of the other two forms,
# and for each dynamic value, a special value is added to the `static_values`.
def _dispatch_mixed_values(
    values: MixedValues,
) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]:
    """Splits mixed static/dynamic integers into their three canonical forms.

    Returns (dynamic_values, packed_values, static_values). A packed handle
    leaves the other two empty; an ArrayAttr is passed through as the static
    form; otherwise dynamic entries are collected separately and a dynamic-size
    sentinel takes their place in the static DenseI64ArrayAttr.
    """
    if isinstance(values, ArrayAttr):
        return ([], None, values)
    if isinstance(values, (Operation, Value, OpView)):
        return ([], values, None)
    dynamic_values = []
    static_ints = []
    for size in values or []:
        if isinstance(size, int):
            static_ints.append(size)
        else:
            static_ints.append(ShapedType.get_dynamic_size())
            dynamic_values.append(size)
    return (dynamic_values, None, DenseI64ArrayAttr.get(static_ints))
def _get_value_or_attribute_value(
    value_or_attr: _Union[_Any, Attribute, ArrayAttr]
) -> _Any:
    """Unwraps an attribute into its Python value where possible.

    Returns `value_or_attr.value` for value-carrying attributes, a recursively
    unwrapped list for an ArrayAttr, and the input unchanged otherwise.

    Fix: the annotations used the builtin `any` (a function) where `typing.Any`
    was intended; it only worked because typing tolerates bare callables.
    """
    if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"):
        return value_or_attr.value
    if isinstance(value_or_attr, ArrayAttr):
        return _get_value_list(value_or_attr)
    return value_or_attr
def _get_value_list(
    sequence_or_array_attr: _Union[_Sequence[_Any], ArrayAttr]
) -> _Sequence[_Any]:
    """Unwraps every element of a sequence or ArrayAttr into Python values.

    Fix: the annotations used the builtin `any` where `typing.Any` was
    intended.
    """
    return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr]
def _get_int_array_attr(
    values: _Optional[_Union[ArrayAttr, IntOrAttrList]]
) -> ArrayAttr:
    """Builds an ArrayAttr of i64 IntegerAttrs from ints and/or IntegerAttrs.

    Returns None when `values` is None.
    """
    if values is None:
        return None
    # Normalize to plain Python ints, then wrap each as an i64 IntegerAttr.
    as_ints = _get_value_list(values)
    return ArrayAttr.get(
        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in as_ints]
    )
def _get_int_array_array_attr(
    values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]]
) -> ArrayAttr:
    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.

    The input has to be a collection of a collection of integers, where any
    Python _Sequence and ArrayAttr are admissible collections and Python ints and
    any IntegerAttr are admissible integers. Both levels of collections are
    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
    If the input is None, None is returned.
    """
    # (Docstring fixed: the code returns None for None input, not an empty
    # ArrayAttr as previously claimed.)
    if values is None:
        return None
    # Make sure the outer level is a list.
    values = _get_value_list(values)
    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
    # Sequences. Make sure the nested values are all lists.
    values = [_get_value_list(nested) for nested in values]
    # Turn each nested list into an ArrayAttr.
    values = [_get_int_array_attr(nested) for nested in values]
    # Turn the outer list into an ArrayAttr.
    return ArrayAttr.get(values)

Some files were not shown because too many files have changed in this diff Show More