hand
This commit is contained in:
@@ -0,0 +1,218 @@
|
||||
# Licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
import os
|
||||
|
||||
_this_dir = os.path.dirname(__file__)
|
||||
|
||||
|
||||
def get_lib_dirs() -> Sequence[str]:
|
||||
"""Gets the lib directory for linking to shared libraries.
|
||||
|
||||
On some platforms, the package may need to be built specially to export
|
||||
development libraries.
|
||||
"""
|
||||
return [_this_dir]
|
||||
|
||||
|
||||
def get_include_dirs() -> Sequence[str]:
|
||||
"""Gets the include directory for compiling against exported C libraries.
|
||||
|
||||
Depending on how the package was build, development C libraries may or may
|
||||
not be present.
|
||||
"""
|
||||
return [os.path.join(_this_dir, "include")]
|
||||
|
||||
|
||||
# Perform Python level site initialization. This involves:
|
||||
# 1. Attempting to load initializer modules, specific to the distribution.
|
||||
# 2. Defining the concrete mlir.ir.Context that does site specific
|
||||
# initialization.
|
||||
# 3. Registering container classes with their respective protocols.
|
||||
#
|
||||
# Aside from just being far more convenient to do this at the Python level,
|
||||
# it is actually quite hard/impossible to have such __init__ hooks, given
|
||||
# the pybind memory model (i.e. there is not a Python reference to the object
|
||||
# in the scope of the base class __init__).
|
||||
#
|
||||
# For #1, we:
|
||||
# a. Probe for modules named '_mlirRegisterEverything' and
|
||||
# '_site_initialize_{i}', where 'i' is a number starting at zero and
|
||||
# proceeding so long as a module with the name is found.
|
||||
# b. If the module has a 'register_dialects' attribute, it will be called
|
||||
# immediately with a DialectRegistry to populate.
|
||||
# c. If the module has a 'context_init_hook', it will be added to a list
|
||||
# of callbacks that are invoked as the last step of Context
|
||||
# initialization (and passed the Context under construction).
|
||||
# d. If the module has a 'disable_multithreading' attribute, it will be
|
||||
# taken as a boolean. If it is True for any initializer, then the
|
||||
# default behavior of enabling multithreading on the context
|
||||
# will be suppressed. This complies with the original behavior of all
|
||||
# contexts being created with multithreading enabled while allowing
|
||||
# this behavior to be changed if needed (i.e. if a context_init_hook
|
||||
# explicitly sets up multithreading).
|
||||
#
|
||||
# This facility allows downstreams to customize Context creation to their
|
||||
# needs.
|
||||
|
||||
_dialect_registry = None
|
||||
_load_on_create_dialects = None
|
||||
|
||||
|
||||
def get_dialect_registry():
|
||||
global _dialect_registry
|
||||
|
||||
if _dialect_registry is None:
|
||||
from ._mlir import ir
|
||||
|
||||
_dialect_registry = ir.DialectRegistry()
|
||||
|
||||
return _dialect_registry
|
||||
|
||||
|
||||
def append_load_on_create_dialect(dialect: str):
|
||||
global _load_on_create_dialects
|
||||
if _load_on_create_dialects is None:
|
||||
_load_on_create_dialects = [dialect]
|
||||
else:
|
||||
_load_on_create_dialects.append(dialect)
|
||||
|
||||
|
||||
def get_load_on_create_dialects():
|
||||
global _load_on_create_dialects
|
||||
if _load_on_create_dialects is None:
|
||||
_load_on_create_dialects = []
|
||||
return _load_on_create_dialects
|
||||
|
||||
|
||||
def _site_initialize():
|
||||
import importlib
|
||||
import itertools
|
||||
import logging
|
||||
from ._mlir import ir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
post_init_hooks = []
|
||||
disable_multithreading = False
|
||||
# This flag disables eagerly loading all dialects. Eagerly loading is often
|
||||
# not the desired behavior (see
|
||||
# https://github.com/llvm/llvm-project/issues/56037), and the logic is that
|
||||
# if any module has this attribute set, then we don't load all (e.g., it's
|
||||
# being used in a solution where the loading is controlled).
|
||||
disable_load_all_available_dialects = False
|
||||
|
||||
def process_initializer_module(module_name):
|
||||
nonlocal disable_multithreading
|
||||
nonlocal disable_load_all_available_dialects
|
||||
try:
|
||||
m = importlib.import_module(f".{module_name}", __name__)
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
except ImportError:
|
||||
message = (
|
||||
f"Error importing mlir initializer {module_name}. This may "
|
||||
"happen in unclean incremental builds but is likely a real bug if "
|
||||
"encountered otherwise and the MLIR Python API may not function."
|
||||
)
|
||||
logger.warning(message, exc_info=True)
|
||||
return False
|
||||
|
||||
logger.debug("Initializing MLIR with module: %s", module_name)
|
||||
if hasattr(m, "register_dialects"):
|
||||
logger.debug("Registering dialects from initializer %r", m)
|
||||
m.register_dialects(get_dialect_registry())
|
||||
if hasattr(m, "context_init_hook"):
|
||||
logger.debug("Adding context init hook from %r", m)
|
||||
post_init_hooks.append(m.context_init_hook)
|
||||
if hasattr(m, "disable_multithreading"):
|
||||
if bool(m.disable_multithreading):
|
||||
logger.debug("Disabling multi-threading for context")
|
||||
disable_multithreading = True
|
||||
if hasattr(m, "disable_load_all_available_dialects"):
|
||||
disable_load_all_available_dialects = True
|
||||
return True
|
||||
|
||||
# If _mlirRegisterEverything is built, then include it as an initializer
|
||||
# module.
|
||||
init_module = None
|
||||
if process_initializer_module("_mlirRegisterEverything"):
|
||||
init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
|
||||
|
||||
# Load all _site_initialize_{i} modules, where 'i' is a number starting
|
||||
# at 0.
|
||||
for i in itertools.count():
|
||||
module_name = f"_site_initialize_{i}"
|
||||
if not process_initializer_module(module_name):
|
||||
break
|
||||
|
||||
ir._Context = ir.Context
|
||||
|
||||
class Context(ir._Context):
|
||||
def __init__(
|
||||
self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.append_dialect_registry(get_dialect_registry())
|
||||
for hook in post_init_hooks:
|
||||
hook(self)
|
||||
if disable_multithreading and thread_pool is not None:
|
||||
raise ValueError(
|
||||
"Context constructor has given thread_pool argument, "
|
||||
"but disable_multithreading flag is True. "
|
||||
"Please, set thread_pool argument to None or "
|
||||
"set disable_multithreading flag to False."
|
||||
)
|
||||
if not disable_multithreading:
|
||||
if thread_pool is None:
|
||||
self.enable_multithreading(True)
|
||||
else:
|
||||
self.set_thread_pool(thread_pool)
|
||||
if load_on_create_dialects is not None:
|
||||
logger.debug(
|
||||
"Loading all dialects from load_on_create_dialects arg %r",
|
||||
load_on_create_dialects,
|
||||
)
|
||||
for dialect in load_on_create_dialects:
|
||||
# This triggers loading the dialect into the context.
|
||||
_ = self.dialects[dialect]
|
||||
else:
|
||||
if disable_load_all_available_dialects:
|
||||
dialects = get_load_on_create_dialects()
|
||||
if dialects:
|
||||
logger.debug(
|
||||
"Loading all dialects from global load_on_create_dialects %r",
|
||||
dialects,
|
||||
)
|
||||
for dialect in dialects:
|
||||
# This triggers loading the dialect into the context.
|
||||
_ = self.dialects[dialect]
|
||||
else:
|
||||
logger.debug("Loading all available dialects")
|
||||
self.load_all_available_dialects()
|
||||
if init_module:
|
||||
logger.debug(
|
||||
"Registering translations from initializer %r", init_module
|
||||
)
|
||||
init_module.register_llvm_translations(self)
|
||||
|
||||
ir.Context = Context
|
||||
|
||||
# Register containers as Sequences, so they can be used with `match`.
|
||||
|
||||
Sequence.register(ir.BlockArgumentList)
|
||||
Sequence.register(ir.BlockList)
|
||||
Sequence.register(ir.BlockSuccessors)
|
||||
Sequence.register(ir.BlockPredecessors)
|
||||
Sequence.register(ir.OperationList)
|
||||
Sequence.register(ir.OpOperandList)
|
||||
Sequence.register(ir.OpOperands)
|
||||
Sequence.register(ir.OpResultList)
|
||||
Sequence.register(ir.OpSuccessors)
|
||||
Sequence.register(ir.RegionSequence)
|
||||
Mapping.register(ir.OpAttributeMap)
|
||||
|
||||
|
||||
_site_initialize()
|
||||
BIN
Binary file not shown.
@@ -0,0 +1,78 @@
|
||||
"""chlo main python extension"""
|
||||
|
||||
from typing import Self
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
class ComparisonDirectionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonDirection attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ComparisonTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RaggedDotDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], lhs_ragged_dimensions: Sequence[int], rhs_group_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a RaggedDotDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def lhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def lhs_ragged_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_group_dimensions(self) -> list[int]: ...
|
||||
|
||||
class PrecisionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Precision attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
Binary file not shown.
@@ -0,0 +1,22 @@
|
||||
"""Registers upstream MLIR dialects used by JAX."""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
def register_dialects(arg: ir.DialectRegistry, /) -> None: ...
|
||||
|
||||
def enter_multi_threaded_execution(arg: ir.Context, /) -> None: ...
|
||||
|
||||
def exit_multi_threaded_execution(arg: ir.Context, /) -> None: ...
|
||||
|
||||
def inlined_func_call(callee: ir.Operation, args: Sequence[ir.Value], block: ir.Block, loc: ir.Location | None = None) -> list[ir.Value]:
|
||||
"""
|
||||
Makes an inlined call to a function containing a single block with a single return op.
|
||||
"""
|
||||
|
||||
class TracebackToLocationCache:
|
||||
def __init__(self, code_to_filename: Callable, frame_limit: int, context: ir.Context | None = None) -> None: ...
|
||||
|
||||
def get(self, traceback: Traceback, /) -> ir.Location: ...
|
||||
Binary file not shown.
Binary file not shown.
+66
@@ -0,0 +1,66 @@
|
||||
"""MLIR Python Native Extension"""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
|
||||
from . import (
|
||||
ir as ir,
|
||||
passmanager as passmanager,
|
||||
rewrite as rewrite
|
||||
)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
U = TypeVar("U")
|
||||
|
||||
class _Globals:
|
||||
@property
|
||||
def dialect_search_modules(self) -> list[str]: ...
|
||||
|
||||
@dialect_search_modules.setter
|
||||
def dialect_search_modules(self, arg: Sequence[str], /) -> None: ...
|
||||
|
||||
def append_dialect_search_prefix(self, module_name: str) -> None: ...
|
||||
|
||||
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
|
||||
|
||||
def _register_dialect_impl(self, dialect_namespace: str, dialect_class: object, *, replace: bool = False) -> None:
|
||||
"""Testing hook for directly registering a dialect"""
|
||||
|
||||
def _register_operation_impl(self, operation_name: str, operation_class: object, *, replace: bool = False) -> None:
|
||||
"""Testing hook for directly registering an operation"""
|
||||
|
||||
def loc_tracebacks_enabled(self) -> bool: ...
|
||||
|
||||
def set_loc_tracebacks_enabled(self, arg: bool, /) -> None: ...
|
||||
|
||||
def loc_tracebacks_frame_limit(self) -> int: ...
|
||||
|
||||
def set_loc_tracebacks_frame_limit(self, arg: int | None) -> None: ...
|
||||
|
||||
def register_traceback_file_inclusion(self, arg: str, /) -> None: ...
|
||||
|
||||
def register_traceback_file_exclusion(self, arg: str, /) -> None: ...
|
||||
|
||||
globals: _Globals = ...
|
||||
|
||||
def register_dialect(dialect_class: type) -> type:
|
||||
"""Class decorator for registering a custom Dialect wrapper"""
|
||||
|
||||
def register_operation(dialect_class: type, *, replace: bool = False) -> Callable[[type[T]], type[T]]:
|
||||
"""
|
||||
Produce a class decorator for registering an Operation class as part of a dialect
|
||||
"""
|
||||
|
||||
def register_op_adaptor(op_class: type, *, replace: bool = False) -> Callable[[type[T]], type[T]]:
|
||||
"""
|
||||
Produce a class decorator for registering an OpAdaptor class for an operation.
|
||||
"""
|
||||
|
||||
def register_type_caster(typeid: _ir.TypeID, *, replace: bool = False) -> Callable[[Callable[[T], U]], Callable[[T], U]]:
|
||||
"""Register a type caster for casting MLIR types to custom user types."""
|
||||
|
||||
def register_value_caster(typeid: _ir.TypeID, *, replace: bool = False) -> Callable[[Callable[[T], U]], Callable[[T], U]]:
|
||||
"""Register a value caster for casting MLIR values to custom user values."""
|
||||
File diff suppressed because it is too large
Load Diff
+79
@@ -0,0 +1,79 @@
|
||||
"""MLIR Pass Management Bindings"""
|
||||
|
||||
from collections.abc import Callable
|
||||
import enum
|
||||
from typing import overload
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class PassDisplayMode(enum.Enum):
|
||||
LIST = 0
|
||||
|
||||
PIPELINE = 1
|
||||
|
||||
class ExternalPass:
|
||||
def signal_pass_failure(self) -> None: ...
|
||||
|
||||
class PassManager:
|
||||
def __init__(self, anchor_op: str = 'any', context: ir.Context | None = None) -> None:
|
||||
"""Create a new PassManager for the current (or provided) Context."""
|
||||
|
||||
@property
|
||||
def _CAPIPtr(self) -> object: ...
|
||||
|
||||
def _CAPICreate(self) -> object: ...
|
||||
|
||||
def _testing_release(self) -> None:
|
||||
"""Releases (leaks) the backing pass manager (testing)"""
|
||||
|
||||
def enable_ir_printing(self, print_before_all: bool = False, print_after_all: bool = True, print_module_scope: bool = False, print_after_change: bool = False, print_after_failure: bool = False, large_elements_limit: int | None = None, large_resource_limit: int | None = None, enable_debug_info: bool = False, print_generic_op_form: bool = False, tree_printing_dir_path: str | None = None) -> None:
|
||||
"""Enable IR printing, default as mlir-print-ir-after-all."""
|
||||
|
||||
def enable_verifier(self, enable: bool) -> None:
|
||||
"""Enable / disable verify-each."""
|
||||
|
||||
def enable_timing(self) -> None:
|
||||
"""Enable pass timing."""
|
||||
|
||||
def enable_statistics(self, displayMode: PassDisplayMode = PassDisplayMode.PIPELINE) -> None:
|
||||
"""Enable pass statistics."""
|
||||
|
||||
@staticmethod
|
||||
def parse(pipeline: str, context: ir.Context | None = None) -> PassManager:
|
||||
"""
|
||||
Parse a textual pass-pipeline and return a top-level PassManager that can be applied on a Module. Throw a ValueError if the pipeline can't be parsed
|
||||
"""
|
||||
|
||||
@overload
|
||||
def add(self, pipeline: str) -> None:
|
||||
"""
|
||||
Add textual pipeline elements to the pass manager. Throws a ValueError if the pipeline can't be parsed.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def add(self, run: Callable, name: str | None = None, argument: str | None = '', description: str | None = '', op_name: str | None = '') -> None:
|
||||
"""
|
||||
Add a python-defined pass to the current pipeline of the pass manager.
|
||||
|
||||
Args:
|
||||
run: A callable with signature ``(op: ir.Operation, pass_: ExternalPass) -> None``.
|
||||
Called when the pass executes. It receives the operation to be processed and
|
||||
the current ``ExternalPass`` instance.
|
||||
Use ``pass_.signal_pass_failure()`` to signal failure.
|
||||
name: The name of the pass. Defaults to ``run.__name__``.
|
||||
argument: The command-line argument for the pass. Defaults to empty.
|
||||
description: The description of the pass. Defaults to empty.
|
||||
op_name: The name of the operation this pass operates on.
|
||||
It will be a generic operation pass if not specified.
|
||||
"""
|
||||
|
||||
def run(self, operation: ir._OperationBase) -> None:
|
||||
"""
|
||||
Run the pass manager on the provided operation, raising an MLIRError on failure.
|
||||
"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Print the textual representation for this PassManager, suitable to be passed to `parse` for round-tripping.
|
||||
"""
|
||||
+241
@@ -0,0 +1,241 @@
|
||||
"""MLIR Rewrite Bindings"""
|
||||
|
||||
from collections.abc import Callable, Sequence
|
||||
import enum
|
||||
from typing import overload
|
||||
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class GreedyRewriteStrictness(enum.Enum):
|
||||
ANY_OP = 0
|
||||
|
||||
EXISTING_AND_NEW_OPS = 1
|
||||
|
||||
EXISTING_OPS = 2
|
||||
|
||||
class GreedySimplifyRegionLevel(enum.Enum):
|
||||
DISABLED = 0
|
||||
|
||||
NORMAL = 1
|
||||
|
||||
AGGRESSIVE = 2
|
||||
|
||||
class DialectConversionFoldingMode(enum.Enum):
|
||||
NEVER = 0
|
||||
|
||||
BEFORE_PATTERNS = 1
|
||||
|
||||
AFTER_PATTERNS = 2
|
||||
|
||||
class PatternRewriter:
|
||||
@property
|
||||
def ip(self) -> ir.InsertionPoint:
|
||||
"""The current insertion point of the PatternRewriter."""
|
||||
|
||||
@overload
|
||||
def replace_op(self, op: ir._OperationBase, new_op: ir._OperationBase) -> None:
|
||||
"""Replace an operation with a new operation."""
|
||||
|
||||
@overload
|
||||
def replace_op(self, op: ir._OperationBase, values: Sequence[ir.Value]) -> None:
|
||||
"""Replace an operation with a list of values."""
|
||||
|
||||
def erase_op(self, op: ir._OperationBase) -> None:
|
||||
"""Erase an operation."""
|
||||
|
||||
class RewritePatternSet:
|
||||
def __init__(self, context: _ir.Context | None = None) -> None: ...
|
||||
|
||||
def add(self, root: object, fn: Callable, benefit: int = 1) -> None:
|
||||
"""
|
||||
Add a new rewrite pattern on the specified root operation, using
|
||||
the provided callable for matching and rewriting, and assign it
|
||||
the given benefit.
|
||||
|
||||
Args:
|
||||
root: The root operation to which this pattern applies. This may
|
||||
be either an OpView subclass or an operation name.
|
||||
fn: The callable to use for matching and rewriting, which takes
|
||||
an operation and a pattern rewriter. The match is considered
|
||||
successful iff the callable returns a falsy value.
|
||||
benefit: The benefit of the pattern, defaulting to 1.
|
||||
"""
|
||||
|
||||
def add_conversion(self, root: object, fn: Callable, type_converter: TypeConverter, benefit: int = 1) -> None:
|
||||
"""
|
||||
Add a new conversion pattern on the specified root operation,
|
||||
using the provided callable for matching and rewriting,
|
||||
and assign it the given benefit.
|
||||
|
||||
Args:
|
||||
root: The root operation to which this pattern applies.
|
||||
This may be either an OpView subclass or an operation name.
|
||||
fn: The callable to use for matching and rewriting, which takes an
|
||||
operation, its adaptor, the type converter and a pattern
|
||||
rewriter. The match is considered successful iff the callable
|
||||
returns a falsy value.
|
||||
type_converter: The type converter to convert types in the IR.
|
||||
benefit: The benefit of the pattern, defaulting to 1.
|
||||
"""
|
||||
|
||||
def freeze(self) -> FrozenRewritePatternSet:
|
||||
"""Freeze the pattern set into a frozen one."""
|
||||
|
||||
class ConversionPatternRewriter(PatternRewriter):
|
||||
def convert_region_types(self, arg0: ir.Region, arg1: TypeConverter, /) -> None: ...
|
||||
|
||||
class ConversionTarget:
|
||||
def __init__(self, context: _ir.Context | None = None) -> None: ...
|
||||
|
||||
def add_legal_op(self, *ops) -> None:
|
||||
"""Mark the given operations as legal."""
|
||||
|
||||
def add_illegal_op(self, *ops) -> None:
|
||||
"""Mark the given operations as illegal."""
|
||||
|
||||
def add_legal_dialect(self, *dialects) -> None:
|
||||
"""Mark the given dialects as legal."""
|
||||
|
||||
def add_illegal_dialect(self, *dialects) -> None:
|
||||
"""Mark the given dialect as illegal."""
|
||||
|
||||
class TypeConverter:
|
||||
def __init__(self) -> None:
|
||||
"""Create a new TypeConverter."""
|
||||
|
||||
def add_conversion(self, convert: Callable) -> None:
|
||||
"""Register a type conversion function."""
|
||||
|
||||
def convert_type(self, type: ir.Type) -> ir.Type | None:
|
||||
"""Convert the given type. Returns None if conversion fails."""
|
||||
|
||||
class PDLResultList:
|
||||
@overload
|
||||
def append(self, arg: ir.Value, /) -> None: ...
|
||||
|
||||
@overload
|
||||
def append(self, arg: ir.Operation, /) -> None: ...
|
||||
|
||||
@overload
|
||||
def append(self, arg: ir.Type, /) -> None: ...
|
||||
|
||||
@overload
|
||||
def append(self, arg: ir.Attribute, /) -> None: ...
|
||||
|
||||
class PDLModule:
|
||||
@overload
|
||||
def __init__(self, module: ir.Module) -> None:
|
||||
"""Create a PDL module from the given module."""
|
||||
|
||||
@overload
|
||||
def __init__(self, module: ir.Module) -> None: ...
|
||||
|
||||
def freeze(self) -> FrozenRewritePatternSet: ...
|
||||
|
||||
def register_rewrite_function(self, arg0: str, arg1: Callable, /) -> None: ...
|
||||
|
||||
def register_constraint_function(self, arg0: str, arg1: Callable, /) -> None: ...
|
||||
|
||||
class GreedyRewriteConfig:
|
||||
def __init__(self) -> None:
|
||||
"""Create a greedy rewrite driver config with defaults"""
|
||||
|
||||
@property
|
||||
def max_iterations(self) -> int:
|
||||
"""Maximum number of iterations"""
|
||||
|
||||
@max_iterations.setter
|
||||
def max_iterations(self, arg: int, /) -> None: ...
|
||||
|
||||
@property
|
||||
def max_num_rewrites(self) -> int:
|
||||
"""Maximum number of rewrites per iteration"""
|
||||
|
||||
@max_num_rewrites.setter
|
||||
def max_num_rewrites(self, arg: int, /) -> None: ...
|
||||
|
||||
@property
|
||||
def use_top_down_traversal(self) -> bool:
|
||||
"""Whether to use top-down traversal"""
|
||||
|
||||
@use_top_down_traversal.setter
|
||||
def use_top_down_traversal(self, arg: bool, /) -> None: ...
|
||||
|
||||
@property
|
||||
def enable_folding(self) -> bool:
|
||||
"""Enable or disable folding"""
|
||||
|
||||
@enable_folding.setter
|
||||
def enable_folding(self, arg: bool, /) -> None: ...
|
||||
|
||||
@property
|
||||
def strictness(self) -> GreedyRewriteStrictness:
|
||||
"""Rewrite strictness level"""
|
||||
|
||||
@strictness.setter
|
||||
def strictness(self, arg: GreedyRewriteStrictness, /) -> None: ...
|
||||
|
||||
@property
|
||||
def region_simplification_level(self) -> GreedySimplifyRegionLevel:
|
||||
"""Region simplification level"""
|
||||
|
||||
@region_simplification_level.setter
|
||||
def region_simplification_level(self, arg: GreedySimplifyRegionLevel, /) -> None: ...
|
||||
|
||||
@property
|
||||
def enable_constant_cse(self) -> bool:
|
||||
"""Enable or disable constant CSE"""
|
||||
|
||||
@enable_constant_cse.setter
|
||||
def enable_constant_cse(self, arg: bool, /) -> None: ...
|
||||
|
||||
class ConversionConfig:
|
||||
def __init__(self) -> None:
|
||||
"""Create a conversion config with defaults"""
|
||||
|
||||
@property
|
||||
def folding_mode(self) -> DialectConversionFoldingMode:
|
||||
"""folding behavior during dialect conversion"""
|
||||
|
||||
@folding_mode.setter
|
||||
def folding_mode(self, arg: DialectConversionFoldingMode, /) -> None: ...
|
||||
|
||||
@property
|
||||
def build_materializations(self) -> bool:
|
||||
"""
|
||||
Whether the dialect conversion attempts to build source/target materializations
|
||||
"""
|
||||
|
||||
@build_materializations.setter
|
||||
def build_materializations(self, arg: bool, /) -> None: ...
|
||||
|
||||
class FrozenRewritePatternSet:
|
||||
@property
|
||||
def _CAPIPtr(self) -> object: ...
|
||||
|
||||
def _CAPICreate(self) -> object: ...
|
||||
|
||||
@overload
|
||||
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet, config: GreedyRewriteConfig | None = None) -> None:
|
||||
"""
|
||||
Applys the given patterns to the given module greedily while folding results.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet, config: GreedyRewriteConfig | None = None) -> None:
|
||||
"""
|
||||
Applys the given patterns to the given op greedily while folding results.
|
||||
"""
|
||||
|
||||
def walk_and_apply_patterns(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None:
|
||||
"""
|
||||
Applies the given patterns to the given op by a fast walk-based driver.
|
||||
"""
|
||||
|
||||
def apply_partial_conversion(op: ir._OperationBase, target: ConversionTarget, set: FrozenRewritePatternSet, config: ConversionConfig | None = None) -> None:
|
||||
"""Applies a partial conversion on the given operation."""
|
||||
|
||||
def apply_full_conversion(op: ir._OperationBase, target: ConversionTarget, set: FrozenRewritePatternSet, config: ConversionConfig | None = None) -> None:
|
||||
"""Applies a full conversion on the given operation."""
|
||||
+61
@@ -0,0 +1,61 @@
|
||||
"""MLIR GPU Dialect"""
|
||||
|
||||
from typing import ClassVar, Final
|
||||
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class AsyncTokenType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(context: _ir.Context | None = None) -> AsyncTokenType:
|
||||
"""Gets an instance of AsyncTokenType in the same context"""
|
||||
|
||||
class ObjectAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
attr_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(target: ir.Attribute, format: int, object: bytes, properties: ir.DictAttr | None = None, kernels: ir.Attribute | None = None, context: _ir.Context | None = None) -> ObjectAttr:
|
||||
"""Gets a gpu.object from parameters."""
|
||||
|
||||
@property
|
||||
def target(self) -> ir.Attribute: ...
|
||||
|
||||
@property
|
||||
def format(self) -> int: ...
|
||||
|
||||
@property
|
||||
def object(self) -> bytes: ...
|
||||
|
||||
@property
|
||||
def properties(self) -> ir.DictAttr | None: ...
|
||||
|
||||
@property
|
||||
def kernels(self) -> ir.Attribute | None: ...
|
||||
BIN
Binary file not shown.
+207
@@ -0,0 +1,207 @@
|
||||
"""MLIR LLVM Dialect"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import ClassVar, Final
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
class StructType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get_literal(elements: Sequence[ir.Type], *, packed: bool = False, loc: _ir.Location | None = None, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
@staticmethod
|
||||
def get_literal_unchecked(elements: Sequence[ir.Type], *, packed: bool = False, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
@staticmethod
|
||||
def get_identified(name: str, *, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
@staticmethod
|
||||
def get_opaque(name: str, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
def set_body(self, elements: Sequence[ir.Type], *, packed: bool = False) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def new_identified(name: str, elements: Sequence[ir.Type], *, packed: bool = False, context: _ir.Context | None = None) -> StructType: ...
|
||||
|
||||
@property
|
||||
def name(self) -> str | None: ...
|
||||
|
||||
@property
|
||||
def body(self) -> object: ...
|
||||
|
||||
@property
|
||||
def packed(self) -> bool: ...
|
||||
|
||||
@property
|
||||
def opaque(self) -> bool: ...
|
||||
|
||||
class ArrayType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(element_type: ir.Type, num_elements: int) -> ArrayType: ...
|
||||
|
||||
@property
|
||||
def element_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def num_elements(self) -> int: ...
|
||||
|
||||
class PointerType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(address_space: int | None = None, *, context: _ir.Context | None = None) -> PointerType: ...
|
||||
|
||||
@property
|
||||
def address_space(self) -> int: ...
|
||||
|
||||
class FunctionType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(result_type: ir.Type, argument_types: Sequence[ir.Type], *, is_var_arg: bool = False) -> FunctionType: ...
|
||||
|
||||
@property
|
||||
def return_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def num_inputs(self) -> int: ...
|
||||
|
||||
@property
|
||||
def inputs(self) -> list: ...
|
||||
|
||||
@property
|
||||
def is_var_arg(self) -> bool: ...
|
||||
|
||||
class MDStringAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
def get(value: str, *, context: _ir.Context | None = None) -> MDStringAttr: ...
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class MDConstantAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
def get(value: ir.Attribute, *, context: _ir.Context | None = None) -> MDConstantAttr: ...
|
||||
|
||||
@property
|
||||
def value(self) -> ir.Attribute: ...
|
||||
|
||||
class MDFuncAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
def get(name: str, *, context: _ir.Context | None = None) -> MDFuncAttr: ...
|
||||
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
class MDNodeAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@staticmethod
|
||||
def get(operands: Sequence[ir.Attribute], *, context: _ir.Context | None = None) -> MDNodeAttr: ...
|
||||
|
||||
@property
|
||||
def num_operands(self) -> int: ...
|
||||
|
||||
def __getitem__(self, arg: int, /) -> ir.Attribute: ...
|
||||
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
def translate_module_to_llvmir(module: ir.Operation) -> str: ...
|
||||
BIN
Binary file not shown.
+25
@@ -0,0 +1,25 @@
|
||||
"""MLIR NVGPU dialect."""
|
||||
|
||||
from typing import ClassVar, Final
|
||||
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class TensorMapDescriptorType(ir.Type):
|
||||
def __init__(self, cast_from_type: ir.Type) -> None: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
type_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(tensor_type: ir.Type, swizzle: int, l2promo: int, oob_fill: int, interleave: int, context: _ir.Context | None = None) -> TensorMapDescriptorType:
|
||||
"""Gets an instance of TensorMapDescriptorType in the same context"""
|
||||
BIN
Binary file not shown.
+97
@@ -0,0 +1,97 @@
|
||||
"""MLIR SparseTensor dialect."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import enum
|
||||
from typing import ClassVar, Final
|
||||
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
|
||||
class LevelFormat(enum.IntFlag):
|
||||
_boundary_: enum.FlagBoundary = enum.FlagBoundary.KEEP
|
||||
|
||||
_flag_mask_: int = 3997696
|
||||
|
||||
_singles_mask_: int = 3997696
|
||||
|
||||
_all_bits_: int = 8388607
|
||||
|
||||
_inverted_: None = None
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
def __repr__(self, /):
|
||||
"""Return repr(self)."""
|
||||
|
||||
dense = 65536
|
||||
|
||||
n_out_of_m = 2097152
|
||||
|
||||
compressed = 262144
|
||||
|
||||
singleton = 524288
|
||||
|
||||
loose_compressed = 1048576
|
||||
|
||||
class LevelProperty(enum.Enum):
|
||||
non_ordered = 2
|
||||
|
||||
non_unique = 1
|
||||
|
||||
soa = 4
|
||||
|
||||
class EncodingAttr(ir.Attribute):
|
||||
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
|
||||
|
||||
@property
|
||||
def type(self) -> ir.Type: ...
|
||||
|
||||
static_typeid: ClassVar[Final[ir.TypeID]] = ...
|
||||
"""(arg: object, /) -> ir.TypeID"""
|
||||
|
||||
@property
|
||||
def typeid(self) -> ir.TypeID: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
attr_name: ClassVar[Final[str]] = ...
|
||||
"""(arg: object, /) -> str"""
|
||||
|
||||
@staticmethod
|
||||
def get(lvl_types: Sequence[int], dim_to_lvl: ir.AffineMap | None, lvl_to_dim: ir.AffineMap | None, pos_width: int, crd_width: int, explicit_val: ir.Attribute | None = None, implicit_val: ir.Attribute | None = None, context: _ir.Context | None = None) -> EncodingAttr:
|
||||
"""Gets a sparse_tensor.encoding from parameters."""
|
||||
|
||||
@staticmethod
|
||||
def build_level_type(lvl_fmt: LevelFormat, properties: Sequence[LevelProperty] = [], n: int = 0, m: int = 0) -> int:
|
||||
"""Builds a sparse_tensor.encoding.level_type from parameters."""
|
||||
|
||||
@property
|
||||
def lvl_types(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def dim_to_lvl(self) -> ir.AffineMap | None: ...
|
||||
|
||||
@property
|
||||
def lvl_to_dim(self) -> ir.AffineMap | None: ...
|
||||
|
||||
@property
|
||||
def pos_width(self) -> int: ...
|
||||
|
||||
@property
|
||||
def crd_width(self) -> int: ...
|
||||
|
||||
@property
|
||||
def explicit_val(self) -> ir.Attribute | None: ...
|
||||
|
||||
@property
|
||||
def implicit_val(self) -> ir.Attribute | None: ...
|
||||
|
||||
@property
|
||||
def structured_n(self) -> int: ...
|
||||
|
||||
@property
|
||||
def structured_m(self) -> int: ...
|
||||
|
||||
@property
|
||||
def lvl_formats_enum(self) -> list[LevelFormat]: ...
|
||||
BIN
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,309 @@
|
||||
"""mlir-hlo main python extension"""
|
||||
|
||||
from typing import Self
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_mhlo_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
def register_mhlo_passes() -> None: ...
|
||||
|
||||
class TokenType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Token type."""
|
||||
|
||||
class ScatterDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, update_window_dims: Sequence[int], inserted_window_dims: Sequence[int], input_batching_dims: Sequence[int], scatter_indices_batching_dims: Sequence[int], scattered_dims_to_operand_dims: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a ScatterDimensionNumbers with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def update_window_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def inserted_window_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def input_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def scatter_indices_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def scattered_dims_to_operand_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def index_vector_dim(self) -> int: ...
|
||||
|
||||
class GatherDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, offset_dims: Sequence[int], collapsed_slice_dims: Sequence[int], operand_batching_dims: Sequence[int], start_indices_batching_dims: Sequence[int], start_index_map: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a GatherDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def offset_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def collapsed_slice_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def start_indices_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def start_index_map(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def index_vector_dim(self) -> int: ...
|
||||
|
||||
class DotDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a DotDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def lhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
class ConvDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, input_batch_dimension: int, input_feature_dimension: int, input_spatial_dimensions: Sequence[int], kernel_input_feature_dimension: int, kernel_output_feature_dimension: int, kernel_spatial_dimensions: Sequence[int], output_batch_dimension: int, output_feature_dimension: int, output_spatial_dimensions: Sequence[int], ctx: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a ConvDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_batch_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def input_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def input_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def kernel_input_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def kernel_output_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def kernel_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def output_batch_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def output_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def output_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
class OutputOperandAlias(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, output_tuple_indices: Sequence[int], operand_index: int, operand_tuple_indices: Sequence[int], ctx: ir.Context | None = None) -> Self:
|
||||
"""Creates a OutputOperandAlias attribute with the given tuple index."""
|
||||
|
||||
@property
|
||||
def output_tuple_indices(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_index(self) -> int: ...
|
||||
|
||||
@property
|
||||
def operand_tuple_indices(self) -> list[int]: ...
|
||||
|
||||
class ComparisonDirectionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonDirection attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ComparisonTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class PrecisionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Precision attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class FftTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a FftType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class DequantizeModeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a DequantizeMode attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class TransposeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Transpose attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class FusionKindAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a FusionKind attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RngDistributionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a RngDistribution attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RngAlgorithmAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a RngAlgorithm attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ChannelHandle(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, handle: int, type: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ChannelHandle attribute."""
|
||||
|
||||
@property
|
||||
def handle(self) -> int: ...
|
||||
|
||||
@property
|
||||
def channel_type(self) -> int: ...
|
||||
|
||||
class TypeExtensions(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, bounds: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a TypeExtensions with the given bounds."""
|
||||
|
||||
@property
|
||||
def bounds(self) -> list[int]: ...
|
||||
Binary file not shown.
BIN
Binary file not shown.
+295
@@ -0,0 +1,295 @@
|
||||
# Copyright 2026 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
import enum
|
||||
from mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
|
||||
def register_inliner_extensions(arg: ir.Context, /) -> None: ...
|
||||
|
||||
class BarrierType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
orders_tensor_core: bool = False, context: ir.Context | None = None
|
||||
) -> BarrierType:
|
||||
"""Creates a BarrierType."""
|
||||
|
||||
@property
|
||||
def orders_tensor_core(self) -> bool: ...
|
||||
|
||||
class TileTransformAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
tiling: Sequence[int], context: ir.Context | None = None
|
||||
) -> TileTransformAttr:
|
||||
"""Creates a TileTransformAttr with the given tiling."""
|
||||
|
||||
@property
|
||||
def tiling(self) -> ir.DenseI32ArrayAttr: ...
|
||||
|
||||
class TransposeTransformAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
permutation: Sequence[int], context: ir.Context | None = None
|
||||
) -> TransposeTransformAttr:
|
||||
"""Creates a TransposeTransformAttr with the given permutation."""
|
||||
|
||||
@property
|
||||
def permutation(self) -> ir.DenseI32ArrayAttr: ...
|
||||
|
||||
class SwizzleTransformAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
swizzle: int, context: ir.Context | None = None
|
||||
) -> SwizzleTransformAttr:
|
||||
"""Creates a SwizzleTransformAttr with the given swizzle."""
|
||||
|
||||
@property
|
||||
def swizzle(self) -> int: ...
|
||||
|
||||
class WGSplatFragLayoutAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
shape: ir.DenseI64ArrayAttr, context: ir.Context | None = None
|
||||
) -> WGSplatFragLayoutAttr:
|
||||
"""Creates a WGSplatFragLayoutAttr with the given shape."""
|
||||
|
||||
@property
|
||||
def shape(self) -> ir.DenseI64ArrayAttr: ...
|
||||
|
||||
class WGStridedFragLayoutAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
shape: ir.DenseI64ArrayAttr,
|
||||
vector_size: int,
|
||||
context: ir.Context | None = None,
|
||||
) -> WGStridedFragLayoutAttr:
|
||||
"""Creates a WGStridedFragLayoutAttr."""
|
||||
|
||||
@property
|
||||
def shape(self) -> ir.DenseI64ArrayAttr: ...
|
||||
@property
|
||||
def vector_size(self) -> int: ...
|
||||
|
||||
class ReplicatedAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(times: int, context: ir.Context | None = None) -> ReplicatedAttr:
|
||||
"""Creates a ReplicatedAttr."""
|
||||
|
||||
@property
|
||||
def times(self) -> int: ...
|
||||
|
||||
class TiledLayoutAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
tiling: ir.ArrayAttr,
|
||||
warp_dims: ir.ArrayAttr,
|
||||
lane_dims: ir.ArrayAttr,
|
||||
vector_dim: int,
|
||||
context: ir.Context | None = None,
|
||||
) -> TiledLayoutAttr:
|
||||
"""Creates a TiledLayoutAttr."""
|
||||
|
||||
@property
|
||||
def tiling(self) -> ir.ArrayAttr: ...
|
||||
@property
|
||||
def warp_dims(self) -> ir.ArrayAttr: ...
|
||||
@property
|
||||
def lane_dims(self) -> ir.ArrayAttr: ...
|
||||
@property
|
||||
def vector_dim(self) -> int: ...
|
||||
|
||||
class CopyPartitionAttrInterface(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
class CopyReplicatedAttr(CopyPartitionAttrInterface):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(context: ir.Context | None = None) -> CopyReplicatedAttr:
|
||||
"""Creates a CopyReplicatedAttr."""
|
||||
|
||||
class CopyPartitionedAttr(CopyPartitionAttrInterface):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(axis: int, context: ir.Context | None = None) -> CopyPartitionedAttr:
|
||||
"""Creates a CopyPartitionedAttr."""
|
||||
|
||||
@property
|
||||
def axis(self) -> int: ...
|
||||
|
||||
class Tiling:
|
||||
def __init__(self, tiles: Iterable) -> None: ...
|
||||
def tile_shape(self, shape: Sequence[int]) -> tuple: ...
|
||||
def untile_shape(self, shape: Sequence[int]) -> tuple: ...
|
||||
def tile_strides(self, strides: Sequence[int]) -> tuple: ...
|
||||
def tile_indices(self, indices: Sequence[int]) -> tuple: ...
|
||||
def untile_indices(self, indices: Sequence[int]) -> tuple: ...
|
||||
def tile_nested_shape_strides(
|
||||
self, shape: Sequence[Sequence[int]], strides: Sequence[Sequence[int]]
|
||||
) -> tuple: ...
|
||||
def tile_dimension(self, dim: int) -> tuple: ...
|
||||
def remove_dimension(self, dim: int) -> Tiling: ...
|
||||
def canonicalize(self) -> Tiling: ...
|
||||
@property
|
||||
def tiles(self) -> tuple: ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __eq__(self, other: object) -> bool: ...
|
||||
def __hash__(self) -> int: ...
|
||||
|
||||
class Replicated:
|
||||
def __init__(self, times: int) -> None: ...
|
||||
@property
|
||||
def times(self) -> int: ...
|
||||
@times.setter
|
||||
def times(self, arg: int, /) -> None: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __eq__(self, arg: object, /) -> bool: ...
|
||||
|
||||
class TiledLayout:
|
||||
def __init__(
|
||||
self,
|
||||
tiling: Tiling,
|
||||
warp_dims: Iterable,
|
||||
lane_dims: Iterable,
|
||||
vector_dim: int,
|
||||
_check_canonical: bool = ...,
|
||||
) -> None: ...
|
||||
@property
|
||||
def warp_dims(self) -> tuple: ...
|
||||
@property
|
||||
def lane_dims(self) -> tuple: ...
|
||||
@property
|
||||
def partitioned_warp_dims(self) -> tuple: ...
|
||||
@property
|
||||
def partitioned_lane_dims(self) -> tuple: ...
|
||||
@property
|
||||
def vector_length(self) -> int: ...
|
||||
@property
|
||||
def vector_dim(self) -> int: ...
|
||||
@property
|
||||
def tiling(self) -> Tiling: ...
|
||||
@property
|
||||
def tiled_tiling_shape(self) -> tuple: ...
|
||||
@property
|
||||
def tiled_tiling_rank(self) -> int: ...
|
||||
def warp_indices(self) -> tuple: ...
|
||||
def lane_indices(self) -> tuple: ...
|
||||
def canonicalize(self) -> TiledLayout: ...
|
||||
def registers_shape(self, shape: Sequence[int]) -> tuple: ...
|
||||
def registers_element_type(self, t: ir.Type) -> ir.Type: ...
|
||||
def shape_from_registers_shape(self, shape: Sequence[int]) -> tuple: ...
|
||||
@property
|
||||
def base_tile_shape(self) -> tuple: ...
|
||||
def remove_dimension(self, dim: int) -> TiledLayout: ...
|
||||
def reduce(self, axes: Iterable) -> TiledLayout: ...
|
||||
def thread_idxs(self, arg: Sequence[int], /) -> list: ...
|
||||
def __str__(self) -> str: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
def __eq__(self, other: object | None) -> bool: ...
|
||||
|
||||
class Rounding(enum.Enum):
|
||||
UP = 0
|
||||
|
||||
DOWN = 1
|
||||
|
||||
class TileTransform:
|
||||
def __init__(
|
||||
self, tiling: Sequence[int], rounding: Rounding | None = ...
|
||||
) -> None: ...
|
||||
def apply(self, arg: object, /) -> ir.Value: ...
|
||||
def transform_index(self, arg: Iterable, /) -> tuple: ...
|
||||
def transform_shape(self, arg: Sequence[int], /) -> tuple: ...
|
||||
def transform_strides(self, arg: Sequence[int], /) -> tuple: ...
|
||||
|
||||
class TrivialTransferPlan:
|
||||
def __init__(self) -> None: ...
|
||||
@property
|
||||
def tile_index_transforms(self) -> tuple: ...
|
||||
def select(self, arg: Iterable, /) -> ir.Value: ...
|
||||
def select_if_group(
|
||||
self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
|
||||
) -> ir.Value: ...
|
||||
|
||||
class StaggeredTransferPlan:
|
||||
def __init__(
|
||||
self, stagger: int, dim: int, size: int, group_pred: object
|
||||
) -> None: ...
|
||||
@property
|
||||
def stagger(self) -> int: ...
|
||||
@property
|
||||
def dim(self) -> int: ...
|
||||
@property
|
||||
def size(self) -> int: ...
|
||||
@property
|
||||
def group_pred(self) -> ir.Value: ...
|
||||
@property
|
||||
def tile_index_transforms(self) -> tuple: ...
|
||||
def select(self, arg: Iterable, /) -> ir.Value: ...
|
||||
def select_if_group(
|
||||
self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
|
||||
) -> ir.Value: ...
|
||||
BIN
Binary file not shown.
@@ -0,0 +1,210 @@
|
||||
"""SDY main Python extension"""
|
||||
|
||||
from typing import Self
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
class MeshAxisAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str, size: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a MeshAxisAttr with the given axis name and size."""
|
||||
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def size(self) -> int: ...
|
||||
|
||||
class MeshAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, mesh_axes: Sequence[ir.Attribute], device_ids: Sequence[int] = [], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a MeshAttr with the given mesh axes."""
|
||||
|
||||
@property
|
||||
def device_ids(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def axes(self) -> list[ir.Attribute]: ...
|
||||
|
||||
class SubAxisInfoAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, pre_size: int, size: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a SubAxisInfoAttr with the given pre-size and size."""
|
||||
|
||||
@property
|
||||
def pre_size(self) -> int: ...
|
||||
|
||||
@property
|
||||
def size(self) -> int: ...
|
||||
|
||||
class AxisRefAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str, sub_axis_info: ir.Attribute | None = None, context: ir.Context | None = None) -> Self:
|
||||
"""Creates an AxisRefAttr with the given name and SubAxisInfoAttr."""
|
||||
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def sub_axis_info(self) -> ir.Attribute | None: ...
|
||||
|
||||
class DimensionShardingAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, axes: Sequence[ir.Attribute], is_closed: bool, priority: int | None = None, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a DimensionShardingAttr with the given axes, whether it's closed, and priority.
|
||||
"""
|
||||
|
||||
@property
|
||||
def axes(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool: ...
|
||||
|
||||
@property
|
||||
def priority(self) -> int | None: ...
|
||||
|
||||
class TensorShardingAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, mesh_or_ref: str | ir.Attribute, dimension_shardings: Sequence[ir.Attribute], replicated_axes: Sequence[ir.Attribute] = [], unreduced_axes: Sequence[ir.Attribute] = [], context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a TensorShardingAttr with either an inlined mesh or mesh name, dimension shardings, and replicated axes.
|
||||
"""
|
||||
|
||||
@property
|
||||
def mesh_or_ref(self) -> ir.Attribute: ...
|
||||
|
||||
@property
|
||||
def dimension_shardings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def replicated_axes(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def unreduced_axes(self) -> list[ir.Attribute]: ...
|
||||
|
||||
class TensorShardingPerValueAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, shardings: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a TensorShardingPerValueAttr with the tensor shardings."""
|
||||
|
||||
@property
|
||||
def shardings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
class DimMappingAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, factor_indices: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a DimMappingAttr with the factor indices."""
|
||||
|
||||
@property
|
||||
def factor_indices(self) -> list[int]: ...
|
||||
|
||||
class TensorMappingAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, dim_mappings: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a TensorMappingAttr with the dim mappings."""
|
||||
|
||||
@property
|
||||
def dim_mappings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def rank(self) -> int: ...
|
||||
|
||||
class OpShardingRuleAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, factor_sizes: Sequence[int], operand_mappings: Sequence[ir.Attribute], result_mappings: Sequence[ir.Attribute], reduction_factors: Sequence[int] = [], need_replication_factors: Sequence[int] = [], permutation_factors: Sequence[int] = [], blocked_propagation_factors: Sequence[int] = [], is_custom: bool = False, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a OpShardingRuleAttr with the factor sizes and mappings for operands and results.
|
||||
"""
|
||||
|
||||
@property
|
||||
def is_custom(self) -> bool: ...
|
||||
|
||||
@property
|
||||
def factor_sizes(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_mappings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def result_mappings(self) -> list[ir.Attribute]: ...
|
||||
|
||||
@property
|
||||
def reduction_factors(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def need_replication_factors(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def permutation_factors(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def blocked_propagation_factors(self) -> list[int]: ...
|
||||
|
||||
class ManualAxesAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, manual_axes: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ManualAxesAttr with the given manual axes."""
|
||||
|
||||
def __getitem__(self, arg: int, /) -> str: ...
|
||||
|
||||
def __len__(self) -> int: ...
|
||||
Binary file not shown.
@@ -0,0 +1,23 @@
|
||||
"""MPMD main Python extension"""
|
||||
|
||||
from typing import Self
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
class UserOriginAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str, transpose_count: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a UserOriginAttr with the given user name and transpose count."""
|
||||
|
||||
@property
|
||||
def user_name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def transpose_count(self) -> int: ...
|
||||
Binary file not shown.
@@ -0,0 +1,407 @@
|
||||
"""stablehlo main python extension"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import enum
|
||||
from typing import overload, Self
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
def register_interpreter_dialect(context: ir.Context, load: bool = True) -> None: ...
|
||||
|
||||
def register_stablehlo_passes() -> None: ...
|
||||
|
||||
class TokenType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Token type."""
|
||||
|
||||
class FutureType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, types: Sequence[ir.Type], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Future type."""
|
||||
|
||||
@property
|
||||
def types(self) -> list[ir.Type]: ...
|
||||
|
||||
class ScatterDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, update_window_dims: Sequence[int], inserted_window_dims: Sequence[int], input_batching_dims: Sequence[int], scatter_indices_batching_dims: Sequence[int], scattered_dims_to_operand_dims: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a ScatterDimensionNumbers with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def update_window_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def inserted_window_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def input_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def scatter_indices_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def scattered_dims_to_operand_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def index_vector_dim(self) -> int: ...
|
||||
|
||||
class GatherDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, offset_dims: Sequence[int], collapsed_slice_dims: Sequence[int], operand_batching_dims: Sequence[int], start_indices_batching_dims: Sequence[int], start_index_map: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a GatherDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def offset_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def collapsed_slice_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def start_indices_batching_dims(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def start_index_map(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def index_vector_dim(self) -> int: ...
|
||||
|
||||
class DotAlgorithm(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, lhs_precision_type: ir.Type, rhs_precision_type: ir.Type, accumulation_type: ir.Type, lhs_component_count: int, rhs_component_count: int, num_primitive_operations: int, allow_imprecise_accumulation: bool, ctx: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a DotAlgorithm attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lhs_precision_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def rhs_precision_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def accumulation_type(self) -> ir.Type: ...
|
||||
|
||||
@property
|
||||
def lhs_component_count(self) -> int: ...
|
||||
|
||||
@property
|
||||
def rhs_component_count(self) -> int: ...
|
||||
|
||||
@property
|
||||
def num_primitive_operations(self) -> int: ...
|
||||
|
||||
@property
|
||||
def allow_imprecise_accumulation(self) -> bool: ...
|
||||
|
||||
class DotDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a DotDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_batching_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def lhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def rhs_contracting_dimensions(self) -> list[int]: ...
|
||||
|
||||
class ConvDimensionNumbers(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, input_batch_dimension: int, input_feature_dimension: int, input_spatial_dimensions: Sequence[int], kernel_input_feature_dimension: int, kernel_output_feature_dimension: int, kernel_spatial_dimensions: Sequence[int], output_batch_dimension: int, output_feature_dimension: int, output_spatial_dimensions: Sequence[int], ctx: ir.Context | None = None) -> Self:
|
||||
"""
|
||||
Creates a ConvDimensionNumbers attribute with the given dimension configuration.
|
||||
"""
|
||||
|
||||
@property
|
||||
def input_batch_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def input_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def input_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def kernel_input_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def kernel_output_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def kernel_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def output_batch_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def output_feature_dimension(self) -> int: ...
|
||||
|
||||
@property
|
||||
def output_spatial_dimensions(self) -> list[int]: ...
|
||||
|
||||
class OutputOperandAlias(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, output_tuple_indices: Sequence[int], operand_index: int, operand_tuple_indices: Sequence[int], ctx: ir.Context | None = None) -> Self:
|
||||
"""Creates a OutputOperandAlias attribute with the given tuple index."""
|
||||
|
||||
@property
|
||||
def output_tuple_indices(self) -> list[int]: ...
|
||||
|
||||
@property
|
||||
def operand_index(self) -> int: ...
|
||||
|
||||
@property
|
||||
def operand_tuple_indices(self) -> list[int]: ...
|
||||
|
||||
class ComparisonDirectionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonDirection attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ComparisonTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ComparisonType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class PrecisionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Precision attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class FftTypeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a FftType attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class TransposeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a Transpose attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RngDistributionAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a RngDistribution attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class RngAlgorithmAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a RngAlgorithm attribute with the given value."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
class ChannelHandle(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, handle: int, type: int, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ChannelHandle attribute."""
|
||||
|
||||
@property
|
||||
def handle(self) -> int: ...
|
||||
|
||||
@property
|
||||
def channel_type(self) -> int: ...
|
||||
|
||||
class TypeExtensions(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, bounds: Sequence[int], context: ir.Context | None = None) -> Self:
|
||||
"""Creates a TypeExtensions with the given bounds."""
|
||||
|
||||
@property
|
||||
def bounds(self) -> list[int]: ...
|
||||
|
||||
class ResultAccuracyAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, atol: float, rtol: float, ulps: int, mode: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ResultAccuracyAttr with the given values."""
|
||||
|
||||
@property
|
||||
def atol(self) -> float: ...
|
||||
|
||||
@property
|
||||
def rtol(self) -> float: ...
|
||||
|
||||
@property
|
||||
def ulps(self) -> int: ...
|
||||
|
||||
@property
|
||||
def mode(self) -> str: ...
|
||||
|
||||
class ResultAccuracyModeAttr(ir.Attribute):
|
||||
@staticmethod
|
||||
def isinstance(other_attribute: ir.Attribute) -> bool: ...
|
||||
|
||||
def __repr__(self) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def get(cls, value: str, context: ir.Context | None = None) -> Self:
|
||||
"""Creates a ResultAccuracyModeAttr with the given values."""
|
||||
|
||||
@property
|
||||
def value(self) -> str: ...
|
||||
|
||||
def get_api_version() -> int: ...
|
||||
|
||||
def get_smaller_version(version1: str, version2: str) -> str: ...
|
||||
|
||||
def get_current_version() -> str: ...
|
||||
|
||||
def get_minimum_version() -> str: ...
|
||||
|
||||
@overload
|
||||
def serialize_portable_artifact_str(module_str: str, target_version: str) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def serialize_portable_artifact_str(module_str: bytes, target_version: str) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def deserialize_portable_artifact_str(artifact_str: str) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def deserialize_portable_artifact_str(artifact_str: bytes) -> bytes: ...
|
||||
|
||||
class StablehloCompatibilityRequirement(enum.Enum):
|
||||
NONE = 0
|
||||
|
||||
WEEK_4 = 1
|
||||
|
||||
WEEK_12 = 2
|
||||
|
||||
MAX = 3
|
||||
|
||||
def get_version_from_compatibility_requirement(requirement: StablehloCompatibilityRequirement) -> str: ...
|
||||
|
||||
def serialize_portable_artifact(module: ir.Module, target: str, allow_other_dialects: bool = False) -> bytes: ...
|
||||
|
||||
@overload
|
||||
def deserialize_portable_artifact(context: ir.Context, artifact: str) -> ir.Module: ...
|
||||
|
||||
@overload
|
||||
def deserialize_portable_artifact(context: ir.Context, artifact: bytes) -> ir.Module: ...
|
||||
|
||||
def eval_module(module: ir.Module, args: Sequence[ir.Attribute], probe_instrumentation_dir: str = '') -> list[ir.Attribute]: ...
|
||||
Binary file not shown.
@@ -0,0 +1,32 @@
|
||||
# Copyright 2026 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
|
||||
def private_has_communication(arg: ir.Operation, /) -> tuple[bool, bool]: ...
|
||||
def private_set_arg_attr(
|
||||
arg0: ir.Operation, arg1: int, arg2: str, arg3: ir.Attribute, /
|
||||
) -> None: ...
|
||||
|
||||
class Float8EXMYType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get(
|
||||
exmy_type: ir.Type | None = None, ctx: ir.Context | None = None
|
||||
) -> Float8EXMYType: ...
|
||||
@property
|
||||
def underlying_type(self) -> ir.Type: ...
|
||||
Binary file not shown.
@@ -0,0 +1,36 @@
|
||||
# Copyright 2026 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jaxlib.mlir import ir
|
||||
|
||||
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
|
||||
|
||||
class PointerType(ir.Type):
|
||||
@staticmethod
|
||||
def isinstance(other_type: ir.Type) -> bool: ...
|
||||
def __repr__(self) -> str: ...
|
||||
@staticmethod
|
||||
def get_static_typeid() -> ir.TypeID: ...
|
||||
@staticmethod
|
||||
def get(pointee_type: ir.Type, address_space: int) -> PointerType:
|
||||
"""Creates a PointerType type."""
|
||||
|
||||
@property
|
||||
def pointee_type(self) -> ir.Type: ...
|
||||
@property
|
||||
def address_space(self) -> int: ...
|
||||
|
||||
def infer_reduce_op_encoding(
|
||||
arg0: ir.Attribute, arg1: int, /
|
||||
) -> ir.Attribute | None: ...
|
||||
Binary file not shown.
Reference in New Issue
Block a user