This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,218 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Any, Mapping, Sequence
import os
_this_dir = os.path.dirname(__file__)
def get_lib_dirs() -> Sequence[str]:
"""Gets the lib directory for linking to shared libraries.
On some platforms, the package may need to be built specially to export
development libraries.
"""
return [_this_dir]
def get_include_dirs() -> Sequence[str]:
"""Gets the include directory for compiling against exported C libraries.
Depending on how the package was build, development C libraries may or may
not be present.
"""
return [os.path.join(_this_dir, "include")]
# Perform Python level site initialization. This involves:
# 1. Attempting to load initializer modules, specific to the distribution.
# 2. Defining the concrete mlir.ir.Context that does site specific
# initialization.
# 3. Registering container classes with their respective protocols.
#
# Aside from just being far more convenient to do this at the Python level,
# it is actually quite hard/impossible to have such __init__ hooks, given
# the pybind memory model (i.e. there is not a Python reference to the object
# in the scope of the base class __init__).
#
# For #1, we:
# a. Probe for modules named '_mlirRegisterEverything' and
# '_site_initialize_{i}', where 'i' is a number starting at zero and
# proceeding so long as a module with the name is found.
# b. If the module has a 'register_dialects' attribute, it will be called
# immediately with a DialectRegistry to populate.
# c. If the module has a 'context_init_hook', it will be added to a list
# of callbacks that are invoked as the last step of Context
# initialization (and passed the Context under construction).
# d. If the module has a 'disable_multithreading' attribute, it will be
# taken as a boolean. If it is True for any initializer, then the
# default behavior of enabling multithreading on the context
# will be suppressed. This complies with the original behavior of all
# contexts being created with multithreading enabled while allowing
# this behavior to be changed if needed (i.e. if a context_init_hook
# explicitly sets up multithreading).
#
# This facility allows downstreams to customize Context creation to their
# needs.
_dialect_registry = None
_load_on_create_dialects = None
def get_dialect_registry():
global _dialect_registry
if _dialect_registry is None:
from ._mlir import ir
_dialect_registry = ir.DialectRegistry()
return _dialect_registry
def append_load_on_create_dialect(dialect: str):
global _load_on_create_dialects
if _load_on_create_dialects is None:
_load_on_create_dialects = [dialect]
else:
_load_on_create_dialects.append(dialect)
def get_load_on_create_dialects():
global _load_on_create_dialects
if _load_on_create_dialects is None:
_load_on_create_dialects = []
return _load_on_create_dialects
def _site_initialize():
import importlib
import itertools
import logging
from ._mlir import ir
logger = logging.getLogger(__name__)
post_init_hooks = []
disable_multithreading = False
# This flag disables eagerly loading all dialects. Eagerly loading is often
# not the desired behavior (see
# https://github.com/llvm/llvm-project/issues/56037), and the logic is that
# if any module has this attribute set, then we don't load all (e.g., it's
# being used in a solution where the loading is controlled).
disable_load_all_available_dialects = False
def process_initializer_module(module_name):
nonlocal disable_multithreading
nonlocal disable_load_all_available_dialects
try:
m = importlib.import_module(f".{module_name}", __name__)
except ModuleNotFoundError:
return False
except ImportError:
message = (
f"Error importing mlir initializer {module_name}. This may "
"happen in unclean incremental builds but is likely a real bug if "
"encountered otherwise and the MLIR Python API may not function."
)
logger.warning(message, exc_info=True)
return False
logger.debug("Initializing MLIR with module: %s", module_name)
if hasattr(m, "register_dialects"):
logger.debug("Registering dialects from initializer %r", m)
m.register_dialects(get_dialect_registry())
if hasattr(m, "context_init_hook"):
logger.debug("Adding context init hook from %r", m)
post_init_hooks.append(m.context_init_hook)
if hasattr(m, "disable_multithreading"):
if bool(m.disable_multithreading):
logger.debug("Disabling multi-threading for context")
disable_multithreading = True
if hasattr(m, "disable_load_all_available_dialects"):
disable_load_all_available_dialects = True
return True
# If _mlirRegisterEverything is built, then include it as an initializer
# module.
init_module = None
if process_initializer_module("_mlirRegisterEverything"):
init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
# Load all _site_initialize_{i} modules, where 'i' is a number starting
# at 0.
for i in itertools.count():
module_name = f"_site_initialize_{i}"
if not process_initializer_module(module_name):
break
ir._Context = ir.Context
class Context(ir._Context):
def __init__(
self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
):
super().__init__(*args, **kwargs)
self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
if disable_multithreading and thread_pool is not None:
raise ValueError(
"Context constructor has given thread_pool argument, "
"but disable_multithreading flag is True. "
"Please, set thread_pool argument to None or "
"set disable_multithreading flag to False."
)
if not disable_multithreading:
if thread_pool is None:
self.enable_multithreading(True)
else:
self.set_thread_pool(thread_pool)
if load_on_create_dialects is not None:
logger.debug(
"Loading all dialects from load_on_create_dialects arg %r",
load_on_create_dialects,
)
for dialect in load_on_create_dialects:
# This triggers loading the dialect into the context.
_ = self.dialects[dialect]
else:
if disable_load_all_available_dialects:
dialects = get_load_on_create_dialects()
if dialects:
logger.debug(
"Loading all dialects from global load_on_create_dialects %r",
dialects,
)
for dialect in dialects:
# This triggers loading the dialect into the context.
_ = self.dialects[dialect]
else:
logger.debug("Loading all available dialects")
self.load_all_available_dialects()
if init_module:
logger.debug(
"Registering translations from initializer %r", init_module
)
init_module.register_llvm_translations(self)
ir.Context = Context
# Register containers as Sequences, so they can be used with `match`.
Sequence.register(ir.BlockArgumentList)
Sequence.register(ir.BlockList)
Sequence.register(ir.BlockSuccessors)
Sequence.register(ir.BlockPredecessors)
Sequence.register(ir.OperationList)
Sequence.register(ir.OpOperandList)
Sequence.register(ir.OpOperands)
Sequence.register(ir.OpResultList)
Sequence.register(ir.OpSuccessors)
Sequence.register(ir.RegionSequence)
Mapping.register(ir.OpAttributeMap)
_site_initialize()
@@ -0,0 +1,78 @@
"""chlo main python extension"""
from typing import Self
from collections.abc import Sequence
from jaxlib.mlir import ir
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
class ComparisonDirectionAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a ComparisonDirection attribute with the given value."""
@property
def value(self) -> str: ...
class ComparisonTypeAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a ComparisonType attribute with the given value."""
@property
def value(self) -> str: ...
class RaggedDotDimensionNumbers(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], lhs_ragged_dimensions: Sequence[int], rhs_group_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
"""
Creates a RaggedDotDimensionNumbers attribute with the given dimension configuration.
"""
@property
def lhs_batching_dimensions(self) -> list[int]: ...
@property
def rhs_batching_dimensions(self) -> list[int]: ...
@property
def lhs_contracting_dimensions(self) -> list[int]: ...
@property
def rhs_contracting_dimensions(self) -> list[int]: ...
@property
def lhs_ragged_dimensions(self) -> list[int]: ...
@property
def rhs_group_dimensions(self) -> list[int]: ...
class PrecisionAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a Precision attribute with the given value."""
@property
def value(self) -> str: ...
Binary file not shown.
@@ -0,0 +1,22 @@
"""Registers upstream MLIR dialects used by JAX."""
from collections.abc import Callable, Sequence
from jaxlib.mlir import ir
def register_dialects(arg: ir.DialectRegistry, /) -> None: ...
def enter_multi_threaded_execution(arg: ir.Context, /) -> None: ...
def exit_multi_threaded_execution(arg: ir.Context, /) -> None: ...
def inlined_func_call(callee: ir.Operation, args: Sequence[ir.Value], block: ir.Block, loc: ir.Location | None = None) -> list[ir.Value]:
"""
Makes an inlined call to a function containing a single block with a single return op.
"""
class TracebackToLocationCache:
def __init__(self, code_to_filename: Callable, frame_limit: int, context: ir.Context | None = None) -> None: ...
def get(self, traceback: Traceback, /) -> ir.Location: ...
Binary file not shown.
@@ -0,0 +1,66 @@
"""MLIR Python Native Extension"""
from collections.abc import Callable, Sequence
from typing import TypeVar
from . import (
ir as ir,
passmanager as passmanager,
rewrite as rewrite
)
T = TypeVar("T")
U = TypeVar("U")
class _Globals:
@property
def dialect_search_modules(self) -> list[str]: ...
@dialect_search_modules.setter
def dialect_search_modules(self, arg: Sequence[str], /) -> None: ...
def append_dialect_search_prefix(self, module_name: str) -> None: ...
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
def _register_dialect_impl(self, dialect_namespace: str, dialect_class: object, *, replace: bool = False) -> None:
"""Testing hook for directly registering a dialect"""
def _register_operation_impl(self, operation_name: str, operation_class: object, *, replace: bool = False) -> None:
"""Testing hook for directly registering an operation"""
def loc_tracebacks_enabled(self) -> bool: ...
def set_loc_tracebacks_enabled(self, arg: bool, /) -> None: ...
def loc_tracebacks_frame_limit(self) -> int: ...
def set_loc_tracebacks_frame_limit(self, arg: int | None) -> None: ...
def register_traceback_file_inclusion(self, arg: str, /) -> None: ...
def register_traceback_file_exclusion(self, arg: str, /) -> None: ...
globals: _Globals = ...
def register_dialect(dialect_class: type) -> type:
"""Class decorator for registering a custom Dialect wrapper"""
def register_operation(dialect_class: type, *, replace: bool = False) -> Callable[[type[T]], type[T]]:
"""
Produce a class decorator for registering an Operation class as part of a dialect
"""
def register_op_adaptor(op_class: type, *, replace: bool = False) -> Callable[[type[T]], type[T]]:
"""
Produce a class decorator for registering an OpAdaptor class for an operation.
"""
def register_type_caster(typeid: _ir.TypeID, *, replace: bool = False) -> Callable[[Callable[[T], U]], Callable[[T], U]]:
"""Register a type caster for casting MLIR types to custom user types."""
def register_value_caster(typeid: _ir.TypeID, *, replace: bool = False) -> Callable[[Callable[[T], U]], Callable[[T], U]]:
"""Register a value caster for casting MLIR values to custom user values."""
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,79 @@
"""MLIR Pass Management Bindings"""
from collections.abc import Callable
import enum
from typing import overload
from jaxlib.mlir import ir
class PassDisplayMode(enum.Enum):
LIST = 0
PIPELINE = 1
class ExternalPass:
def signal_pass_failure(self) -> None: ...
class PassManager:
def __init__(self, anchor_op: str = 'any', context: ir.Context | None = None) -> None:
"""Create a new PassManager for the current (or provided) Context."""
@property
def _CAPIPtr(self) -> object: ...
def _CAPICreate(self) -> object: ...
def _testing_release(self) -> None:
"""Releases (leaks) the backing pass manager (testing)"""
def enable_ir_printing(self, print_before_all: bool = False, print_after_all: bool = True, print_module_scope: bool = False, print_after_change: bool = False, print_after_failure: bool = False, large_elements_limit: int | None = None, large_resource_limit: int | None = None, enable_debug_info: bool = False, print_generic_op_form: bool = False, tree_printing_dir_path: str | None = None) -> None:
"""Enable IR printing, default as mlir-print-ir-after-all."""
def enable_verifier(self, enable: bool) -> None:
"""Enable / disable verify-each."""
def enable_timing(self) -> None:
"""Enable pass timing."""
def enable_statistics(self, displayMode: PassDisplayMode = PassDisplayMode.PIPELINE) -> None:
"""Enable pass statistics."""
@staticmethod
def parse(pipeline: str, context: ir.Context | None = None) -> PassManager:
"""
Parse a textual pass-pipeline and return a top-level PassManager that can be applied on a Module. Throw a ValueError if the pipeline can't be parsed
"""
@overload
def add(self, pipeline: str) -> None:
"""
Add textual pipeline elements to the pass manager. Throws a ValueError if the pipeline can't be parsed.
"""
@overload
def add(self, run: Callable, name: str | None = None, argument: str | None = '', description: str | None = '', op_name: str | None = '') -> None:
"""
Add a python-defined pass to the current pipeline of the pass manager.
Args:
run: A callable with signature ``(op: ir.Operation, pass_: ExternalPass) -> None``.
Called when the pass executes. It receives the operation to be processed and
the current ``ExternalPass`` instance.
Use ``pass_.signal_pass_failure()`` to signal failure.
name: The name of the pass. Defaults to ``run.__name__``.
argument: The command-line argument for the pass. Defaults to empty.
description: The description of the pass. Defaults to empty.
op_name: The name of the operation this pass operates on.
It will be a generic operation pass if not specified.
"""
def run(self, operation: ir._OperationBase) -> None:
"""
Run the pass manager on the provided operation, raising an MLIRError on failure.
"""
def __str__(self) -> str:
"""
Print the textual representation for this PassManager, suitable to be passed to `parse` for round-tripping.
"""
@@ -0,0 +1,241 @@
"""MLIR Rewrite Bindings"""
from collections.abc import Callable, Sequence
import enum
from typing import overload
from jaxlib.mlir import ir
class GreedyRewriteStrictness(enum.Enum):
ANY_OP = 0
EXISTING_AND_NEW_OPS = 1
EXISTING_OPS = 2
class GreedySimplifyRegionLevel(enum.Enum):
DISABLED = 0
NORMAL = 1
AGGRESSIVE = 2
class DialectConversionFoldingMode(enum.Enum):
NEVER = 0
BEFORE_PATTERNS = 1
AFTER_PATTERNS = 2
class PatternRewriter:
@property
def ip(self) -> ir.InsertionPoint:
"""The current insertion point of the PatternRewriter."""
@overload
def replace_op(self, op: ir._OperationBase, new_op: ir._OperationBase) -> None:
"""Replace an operation with a new operation."""
@overload
def replace_op(self, op: ir._OperationBase, values: Sequence[ir.Value]) -> None:
"""Replace an operation with a list of values."""
def erase_op(self, op: ir._OperationBase) -> None:
"""Erase an operation."""
class RewritePatternSet:
def __init__(self, context: _ir.Context | None = None) -> None: ...
def add(self, root: object, fn: Callable, benefit: int = 1) -> None:
"""
Add a new rewrite pattern on the specified root operation, using
the provided callable for matching and rewriting, and assign it
the given benefit.
Args:
root: The root operation to which this pattern applies. This may
be either an OpView subclass or an operation name.
fn: The callable to use for matching and rewriting, which takes
an operation and a pattern rewriter. The match is considered
successful iff the callable returns a falsy value.
benefit: The benefit of the pattern, defaulting to 1.
"""
def add_conversion(self, root: object, fn: Callable, type_converter: TypeConverter, benefit: int = 1) -> None:
"""
Add a new conversion pattern on the specified root operation,
using the provided callable for matching and rewriting,
and assign it the given benefit.
Args:
root: The root operation to which this pattern applies.
This may be either an OpView subclass or an operation name.
fn: The callable to use for matching and rewriting, which takes an
operation, its adaptor, the type converter and a pattern
rewriter. The match is considered successful iff the callable
returns a falsy value.
type_converter: The type converter to convert types in the IR.
benefit: The benefit of the pattern, defaulting to 1.
"""
def freeze(self) -> FrozenRewritePatternSet:
"""Freeze the pattern set into a frozen one."""
class ConversionPatternRewriter(PatternRewriter):
def convert_region_types(self, arg0: ir.Region, arg1: TypeConverter, /) -> None: ...
class ConversionTarget:
def __init__(self, context: _ir.Context | None = None) -> None: ...
def add_legal_op(self, *ops) -> None:
"""Mark the given operations as legal."""
def add_illegal_op(self, *ops) -> None:
"""Mark the given operations as illegal."""
def add_legal_dialect(self, *dialects) -> None:
"""Mark the given dialects as legal."""
def add_illegal_dialect(self, *dialects) -> None:
"""Mark the given dialect as illegal."""
class TypeConverter:
def __init__(self) -> None:
"""Create a new TypeConverter."""
def add_conversion(self, convert: Callable) -> None:
"""Register a type conversion function."""
def convert_type(self, type: ir.Type) -> ir.Type | None:
"""Convert the given type. Returns None if conversion fails."""
class PDLResultList:
@overload
def append(self, arg: ir.Value, /) -> None: ...
@overload
def append(self, arg: ir.Operation, /) -> None: ...
@overload
def append(self, arg: ir.Type, /) -> None: ...
@overload
def append(self, arg: ir.Attribute, /) -> None: ...
class PDLModule:
@overload
def __init__(self, module: ir.Module) -> None:
"""Create a PDL module from the given module."""
@overload
def __init__(self, module: ir.Module) -> None: ...
def freeze(self) -> FrozenRewritePatternSet: ...
def register_rewrite_function(self, arg0: str, arg1: Callable, /) -> None: ...
def register_constraint_function(self, arg0: str, arg1: Callable, /) -> None: ...
class GreedyRewriteConfig:
def __init__(self) -> None:
"""Create a greedy rewrite driver config with defaults"""
@property
def max_iterations(self) -> int:
"""Maximum number of iterations"""
@max_iterations.setter
def max_iterations(self, arg: int, /) -> None: ...
@property
def max_num_rewrites(self) -> int:
"""Maximum number of rewrites per iteration"""
@max_num_rewrites.setter
def max_num_rewrites(self, arg: int, /) -> None: ...
@property
def use_top_down_traversal(self) -> bool:
"""Whether to use top-down traversal"""
@use_top_down_traversal.setter
def use_top_down_traversal(self, arg: bool, /) -> None: ...
@property
def enable_folding(self) -> bool:
"""Enable or disable folding"""
@enable_folding.setter
def enable_folding(self, arg: bool, /) -> None: ...
@property
def strictness(self) -> GreedyRewriteStrictness:
"""Rewrite strictness level"""
@strictness.setter
def strictness(self, arg: GreedyRewriteStrictness, /) -> None: ...
@property
def region_simplification_level(self) -> GreedySimplifyRegionLevel:
"""Region simplification level"""
@region_simplification_level.setter
def region_simplification_level(self, arg: GreedySimplifyRegionLevel, /) -> None: ...
@property
def enable_constant_cse(self) -> bool:
"""Enable or disable constant CSE"""
@enable_constant_cse.setter
def enable_constant_cse(self, arg: bool, /) -> None: ...
class ConversionConfig:
def __init__(self) -> None:
"""Create a conversion config with defaults"""
@property
def folding_mode(self) -> DialectConversionFoldingMode:
"""folding behavior during dialect conversion"""
@folding_mode.setter
def folding_mode(self, arg: DialectConversionFoldingMode, /) -> None: ...
@property
def build_materializations(self) -> bool:
"""
Whether the dialect conversion attempts to build source/target materializations
"""
@build_materializations.setter
def build_materializations(self, arg: bool, /) -> None: ...
class FrozenRewritePatternSet:
@property
def _CAPIPtr(self) -> object: ...
def _CAPICreate(self) -> object: ...
@overload
def apply_patterns_and_fold_greedily(module: ir.Module, set: FrozenRewritePatternSet, config: GreedyRewriteConfig | None = None) -> None:
"""
Applys the given patterns to the given module greedily while folding results.
"""
@overload
def apply_patterns_and_fold_greedily(op: ir._OperationBase, set: FrozenRewritePatternSet, config: GreedyRewriteConfig | None = None) -> None:
"""
Applys the given patterns to the given op greedily while folding results.
"""
def walk_and_apply_patterns(op: ir._OperationBase, set: FrozenRewritePatternSet) -> None:
"""
Applies the given patterns to the given op by a fast walk-based driver.
"""
def apply_partial_conversion(op: ir._OperationBase, target: ConversionTarget, set: FrozenRewritePatternSet, config: ConversionConfig | None = None) -> None:
"""Applies a partial conversion on the given operation."""
def apply_full_conversion(op: ir._OperationBase, target: ConversionTarget, set: FrozenRewritePatternSet, config: ConversionConfig | None = None) -> None:
"""Applies a full conversion on the given operation."""
@@ -0,0 +1,61 @@
"""MLIR GPU Dialect"""
from typing import ClassVar, Final
from jaxlib.mlir import ir
class AsyncTokenType(ir.Type):
def __init__(self, cast_from_type: ir.Type) -> None: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
type_name: ClassVar[Final[str]] = ...
"""(arg: object, /) -> str"""
@staticmethod
def get(context: _ir.Context | None = None) -> AsyncTokenType:
"""Gets an instance of AsyncTokenType in the same context"""
class ObjectAttr(ir.Attribute):
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
@property
def type(self) -> ir.Type: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
attr_name: ClassVar[Final[str]] = ...
"""(arg: object, /) -> str"""
@staticmethod
def get(target: ir.Attribute, format: int, object: bytes, properties: ir.DictAttr | None = None, kernels: ir.Attribute | None = None, context: _ir.Context | None = None) -> ObjectAttr:
"""Gets a gpu.object from parameters."""
@property
def target(self) -> ir.Attribute: ...
@property
def format(self) -> int: ...
@property
def object(self) -> bytes: ...
@property
def properties(self) -> ir.DictAttr | None: ...
@property
def kernels(self) -> ir.Attribute | None: ...
@@ -0,0 +1,207 @@
"""MLIR LLVM Dialect"""
from collections.abc import Sequence
from typing import ClassVar, Final
from jaxlib.mlir import ir
class StructType(ir.Type):
def __init__(self, cast_from_type: ir.Type) -> None: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
type_name: ClassVar[Final[str]] = ...
"""(arg: object, /) -> str"""
@staticmethod
def get_literal(elements: Sequence[ir.Type], *, packed: bool = False, loc: _ir.Location | None = None, context: _ir.Context | None = None) -> StructType: ...
@staticmethod
def get_literal_unchecked(elements: Sequence[ir.Type], *, packed: bool = False, context: _ir.Context | None = None) -> StructType: ...
@staticmethod
def get_identified(name: str, *, context: _ir.Context | None = None) -> StructType: ...
@staticmethod
def get_opaque(name: str, context: _ir.Context | None = None) -> StructType: ...
def set_body(self, elements: Sequence[ir.Type], *, packed: bool = False) -> None: ...
@staticmethod
def new_identified(name: str, elements: Sequence[ir.Type], *, packed: bool = False, context: _ir.Context | None = None) -> StructType: ...
@property
def name(self) -> str | None: ...
@property
def body(self) -> object: ...
@property
def packed(self) -> bool: ...
@property
def opaque(self) -> bool: ...
class ArrayType(ir.Type):
def __init__(self, cast_from_type: ir.Type) -> None: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
type_name: ClassVar[Final[str]] = ...
"""(arg: object, /) -> str"""
@staticmethod
def get(element_type: ir.Type, num_elements: int) -> ArrayType: ...
@property
def element_type(self) -> ir.Type: ...
@property
def num_elements(self) -> int: ...
class PointerType(ir.Type):
def __init__(self, cast_from_type: ir.Type) -> None: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
type_name: ClassVar[Final[str]] = ...
"""(arg: object, /) -> str"""
@staticmethod
def get(address_space: int | None = None, *, context: _ir.Context | None = None) -> PointerType: ...
@property
def address_space(self) -> int: ...
class FunctionType(ir.Type):
def __init__(self, cast_from_type: ir.Type) -> None: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
type_name: ClassVar[Final[str]] = ...
"""(arg: object, /) -> str"""
@staticmethod
def get(result_type: ir.Type, argument_types: Sequence[ir.Type], *, is_var_arg: bool = False) -> FunctionType: ...
@property
def return_type(self) -> ir.Type: ...
@property
def num_inputs(self) -> int: ...
@property
def inputs(self) -> list: ...
@property
def is_var_arg(self) -> bool: ...
class MDStringAttr(ir.Attribute):
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
@property
def type(self) -> ir.Type: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
@staticmethod
def get(value: str, *, context: _ir.Context | None = None) -> MDStringAttr: ...
@property
def value(self) -> str: ...
class MDConstantAttr(ir.Attribute):
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
@property
def type(self) -> ir.Type: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
@staticmethod
def get(value: ir.Attribute, *, context: _ir.Context | None = None) -> MDConstantAttr: ...
@property
def value(self) -> ir.Attribute: ...
class MDFuncAttr(ir.Attribute):
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
@property
def type(self) -> ir.Type: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
@staticmethod
def get(name: str, *, context: _ir.Context | None = None) -> MDFuncAttr: ...
@property
def name(self) -> str: ...
class MDNodeAttr(ir.Attribute):
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
@property
def type(self) -> ir.Type: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
@staticmethod
def get(operands: Sequence[ir.Attribute], *, context: _ir.Context | None = None) -> MDNodeAttr: ...
@property
def num_operands(self) -> int: ...
def __getitem__(self, arg: int, /) -> ir.Attribute: ...
def __len__(self) -> int: ...
def translate_module_to_llvmir(module: ir.Operation) -> str: ...
@@ -0,0 +1,25 @@
"""MLIR NVGPU dialect."""
from typing import ClassVar, Final
from jaxlib.mlir import ir
class TensorMapDescriptorType(ir.Type):
def __init__(self, cast_from_type: ir.Type) -> None: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
type_name: ClassVar[Final[str]] = ...
"""(arg: object, /) -> str"""
@staticmethod
def get(tensor_type: ir.Type, swizzle: int, l2promo: int, oob_fill: int, interleave: int, context: _ir.Context | None = None) -> TensorMapDescriptorType:
"""Gets an instance of TensorMapDescriptorType in the same context"""
@@ -0,0 +1,97 @@
"""MLIR SparseTensor dialect."""
from collections.abc import Sequence
import enum
from typing import ClassVar, Final
from jaxlib.mlir import ir
class LevelFormat(enum.IntFlag):
_boundary_: enum.FlagBoundary = enum.FlagBoundary.KEEP
_flag_mask_: int = 3997696
_singles_mask_: int = 3997696
_all_bits_: int = 8388607
_inverted_: None = None
__str__ = __repr__
def __repr__(self, /):
"""Return repr(self)."""
dense = 65536
n_out_of_m = 2097152
compressed = 262144
singleton = 524288
loose_compressed = 1048576
class LevelProperty(enum.Enum):
non_ordered = 2
non_unique = 1
soa = 4
class EncodingAttr(ir.Attribute):
def __init__(self, cast_from_attr: ir.Attribute) -> None: ...
@property
def type(self) -> ir.Type: ...
static_typeid: ClassVar[Final[ir.TypeID]] = ...
"""(arg: object, /) -> ir.TypeID"""
@property
def typeid(self) -> ir.TypeID: ...
def __repr__(self) -> str: ...
attr_name: ClassVar[Final[str]] = ...
"""(arg: object, /) -> str"""
@staticmethod
def get(lvl_types: Sequence[int], dim_to_lvl: ir.AffineMap | None, lvl_to_dim: ir.AffineMap | None, pos_width: int, crd_width: int, explicit_val: ir.Attribute | None = None, implicit_val: ir.Attribute | None = None, context: _ir.Context | None = None) -> EncodingAttr:
"""Gets a sparse_tensor.encoding from parameters."""
@staticmethod
def build_level_type(lvl_fmt: LevelFormat, properties: Sequence[LevelProperty] = [], n: int = 0, m: int = 0) -> int:
"""Builds a sparse_tensor.encoding.level_type from parameters."""
@property
def lvl_types(self) -> list[int]: ...
@property
def dim_to_lvl(self) -> ir.AffineMap | None: ...
@property
def lvl_to_dim(self) -> ir.AffineMap | None: ...
@property
def pos_width(self) -> int: ...
@property
def crd_width(self) -> int: ...
@property
def explicit_val(self) -> ir.Attribute | None: ...
@property
def implicit_val(self) -> ir.Attribute | None: ...
@property
def structured_n(self) -> int: ...
@property
def structured_m(self) -> int: ...
@property
def lvl_formats_enum(self) -> list[LevelFormat]: ...
@@ -0,0 +1,309 @@
"""mlir-hlo main python extension"""
from typing import Self
from collections.abc import Sequence
from jaxlib.mlir import ir
def register_mhlo_dialect(context: ir.Context, load: bool = True) -> None: ...
def register_mhlo_passes() -> None: ...
class TokenType(ir.Type):
@staticmethod
def isinstance(other_type: ir.Type) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, context: ir.Context | None = None) -> Self:
"""Creates a Token type."""
class ScatterDimensionNumbers(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, update_window_dims: Sequence[int], inserted_window_dims: Sequence[int], input_batching_dims: Sequence[int], scatter_indices_batching_dims: Sequence[int], scattered_dims_to_operand_dims: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
"""
Creates a ScatterDimensionNumbers with the given dimension configuration.
"""
@property
def update_window_dims(self) -> list[int]: ...
@property
def inserted_window_dims(self) -> list[int]: ...
@property
def input_batching_dims(self) -> list[int]: ...
@property
def scatter_indices_batching_dims(self) -> list[int]: ...
@property
def scattered_dims_to_operand_dims(self) -> list[int]: ...
@property
def index_vector_dim(self) -> int: ...
class GatherDimensionNumbers(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, offset_dims: Sequence[int], collapsed_slice_dims: Sequence[int], operand_batching_dims: Sequence[int], start_indices_batching_dims: Sequence[int], start_index_map: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
"""
Creates a GatherDimensionNumbers attribute with the given dimension configuration.
"""
@property
def offset_dims(self) -> list[int]: ...
@property
def collapsed_slice_dims(self) -> list[int]: ...
@property
def operand_batching_dims(self) -> list[int]: ...
@property
def start_indices_batching_dims(self) -> list[int]: ...
@property
def start_index_map(self) -> list[int]: ...
@property
def index_vector_dim(self) -> int: ...
class DotDimensionNumbers(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
"""
Creates a DotDimensionNumbers attribute with the given dimension configuration.
"""
@property
def lhs_batching_dimensions(self) -> list[int]: ...
@property
def rhs_batching_dimensions(self) -> list[int]: ...
@property
def lhs_contracting_dimensions(self) -> list[int]: ...
@property
def rhs_contracting_dimensions(self) -> list[int]: ...
class ConvDimensionNumbers(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, input_batch_dimension: int, input_feature_dimension: int, input_spatial_dimensions: Sequence[int], kernel_input_feature_dimension: int, kernel_output_feature_dimension: int, kernel_spatial_dimensions: Sequence[int], output_batch_dimension: int, output_feature_dimension: int, output_spatial_dimensions: Sequence[int], ctx: ir.Context | None = None) -> Self:
"""
Creates a ConvDimensionNumbers attribute with the given dimension configuration.
"""
@property
def input_batch_dimension(self) -> int: ...
@property
def input_feature_dimension(self) -> int: ...
@property
def input_spatial_dimensions(self) -> list[int]: ...
@property
def kernel_input_feature_dimension(self) -> int: ...
@property
def kernel_output_feature_dimension(self) -> int: ...
@property
def kernel_spatial_dimensions(self) -> list[int]: ...
@property
def output_batch_dimension(self) -> int: ...
@property
def output_feature_dimension(self) -> int: ...
@property
def output_spatial_dimensions(self) -> list[int]: ...
class OutputOperandAlias(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, output_tuple_indices: Sequence[int], operand_index: int, operand_tuple_indices: Sequence[int], ctx: ir.Context | None = None) -> Self:
"""Creates a OutputOperandAlias attribute with the given tuple index."""
@property
def output_tuple_indices(self) -> list[int]: ...
@property
def operand_index(self) -> int: ...
@property
def operand_tuple_indices(self) -> list[int]: ...
class ComparisonDirectionAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a ComparisonDirection attribute with the given value."""
@property
def value(self) -> str: ...
class ComparisonTypeAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a ComparisonType attribute with the given value."""
@property
def value(self) -> str: ...
class PrecisionAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a Precision attribute with the given value."""
@property
def value(self) -> str: ...
class FftTypeAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a FftType attribute with the given value."""
@property
def value(self) -> str: ...
class DequantizeModeAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a DequantizeMode attribute with the given value."""
@property
def value(self) -> str: ...
class TransposeAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a Transpose attribute with the given value."""
@property
def value(self) -> str: ...
class FusionKindAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a FusionKind attribute with the given value."""
@property
def value(self) -> str: ...
class RngDistributionAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a RngDistribution attribute with the given value."""
@property
def value(self) -> str: ...
class RngAlgorithmAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a RngAlgorithm attribute with the given value."""
@property
def value(self) -> str: ...
class ChannelHandle(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, handle: int, type: int, context: ir.Context | None = None) -> Self:
"""Creates a ChannelHandle attribute."""
@property
def handle(self) -> int: ...
@property
def channel_type(self) -> int: ...
class TypeExtensions(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, bounds: Sequence[int], context: ir.Context | None = None) -> Self:
"""Creates a TypeExtensions with the given bounds."""
@property
def bounds(self) -> list[int]: ...
Binary file not shown.
@@ -0,0 +1,295 @@
# Copyright 2026 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Sequence
import enum
from mlir import ir
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
def register_inliner_extensions(arg: ir.Context, /) -> None: ...
class BarrierType(ir.Type):
@staticmethod
def isinstance(other_type: ir.Type) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(
orders_tensor_core: bool = False, context: ir.Context | None = None
) -> BarrierType:
"""Creates a BarrierType."""
@property
def orders_tensor_core(self) -> bool: ...
class TileTransformAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(
tiling: Sequence[int], context: ir.Context | None = None
) -> TileTransformAttr:
"""Creates a TileTransformAttr with the given tiling."""
@property
def tiling(self) -> ir.DenseI32ArrayAttr: ...
class TransposeTransformAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(
permutation: Sequence[int], context: ir.Context | None = None
) -> TransposeTransformAttr:
"""Creates a TransposeTransformAttr with the given permutation."""
@property
def permutation(self) -> ir.DenseI32ArrayAttr: ...
class SwizzleTransformAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(
swizzle: int, context: ir.Context | None = None
) -> SwizzleTransformAttr:
"""Creates a SwizzleTransformAttr with the given swizzle."""
@property
def swizzle(self) -> int: ...
class WGSplatFragLayoutAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(
shape: ir.DenseI64ArrayAttr, context: ir.Context | None = None
) -> WGSplatFragLayoutAttr:
"""Creates a WGSplatFragLayoutAttr with the given shape."""
@property
def shape(self) -> ir.DenseI64ArrayAttr: ...
class WGStridedFragLayoutAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(
shape: ir.DenseI64ArrayAttr,
vector_size: int,
context: ir.Context | None = None,
) -> WGStridedFragLayoutAttr:
"""Creates a WGStridedFragLayoutAttr."""
@property
def shape(self) -> ir.DenseI64ArrayAttr: ...
@property
def vector_size(self) -> int: ...
class ReplicatedAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(times: int, context: ir.Context | None = None) -> ReplicatedAttr:
"""Creates a ReplicatedAttr."""
@property
def times(self) -> int: ...
class TiledLayoutAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(
tiling: ir.ArrayAttr,
warp_dims: ir.ArrayAttr,
lane_dims: ir.ArrayAttr,
vector_dim: int,
context: ir.Context | None = None,
) -> TiledLayoutAttr:
"""Creates a TiledLayoutAttr."""
@property
def tiling(self) -> ir.ArrayAttr: ...
@property
def warp_dims(self) -> ir.ArrayAttr: ...
@property
def lane_dims(self) -> ir.ArrayAttr: ...
@property
def vector_dim(self) -> int: ...
class CopyPartitionAttrInterface(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
class CopyReplicatedAttr(CopyPartitionAttrInterface):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(context: ir.Context | None = None) -> CopyReplicatedAttr:
"""Creates a CopyReplicatedAttr."""
class CopyPartitionedAttr(CopyPartitionAttrInterface):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(axis: int, context: ir.Context | None = None) -> CopyPartitionedAttr:
"""Creates a CopyPartitionedAttr."""
@property
def axis(self) -> int: ...
class Tiling:
def __init__(self, tiles: Iterable) -> None: ...
def tile_shape(self, shape: Sequence[int]) -> tuple: ...
def untile_shape(self, shape: Sequence[int]) -> tuple: ...
def tile_strides(self, strides: Sequence[int]) -> tuple: ...
def tile_indices(self, indices: Sequence[int]) -> tuple: ...
def untile_indices(self, indices: Sequence[int]) -> tuple: ...
def tile_nested_shape_strides(
self, shape: Sequence[Sequence[int]], strides: Sequence[Sequence[int]]
) -> tuple: ...
def tile_dimension(self, dim: int) -> tuple: ...
def remove_dimension(self, dim: int) -> Tiling: ...
def canonicalize(self) -> Tiling: ...
@property
def tiles(self) -> tuple: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __eq__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
class Replicated:
def __init__(self, times: int) -> None: ...
@property
def times(self) -> int: ...
@times.setter
def times(self, arg: int, /) -> None: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
def __eq__(self, arg: object, /) -> bool: ...
class TiledLayout:
def __init__(
self,
tiling: Tiling,
warp_dims: Iterable,
lane_dims: Iterable,
vector_dim: int,
_check_canonical: bool = ...,
) -> None: ...
@property
def warp_dims(self) -> tuple: ...
@property
def lane_dims(self) -> tuple: ...
@property
def partitioned_warp_dims(self) -> tuple: ...
@property
def partitioned_lane_dims(self) -> tuple: ...
@property
def vector_length(self) -> int: ...
@property
def vector_dim(self) -> int: ...
@property
def tiling(self) -> Tiling: ...
@property
def tiled_tiling_shape(self) -> tuple: ...
@property
def tiled_tiling_rank(self) -> int: ...
def warp_indices(self) -> tuple: ...
def lane_indices(self) -> tuple: ...
def canonicalize(self) -> TiledLayout: ...
def registers_shape(self, shape: Sequence[int]) -> tuple: ...
def registers_element_type(self, t: ir.Type) -> ir.Type: ...
def shape_from_registers_shape(self, shape: Sequence[int]) -> tuple: ...
@property
def base_tile_shape(self) -> tuple: ...
def remove_dimension(self, dim: int) -> TiledLayout: ...
def reduce(self, axes: Iterable) -> TiledLayout: ...
def thread_idxs(self, arg: Sequence[int], /) -> list: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
def __eq__(self, other: object | None) -> bool: ...
class Rounding(enum.Enum):
UP = 0
DOWN = 1
class TileTransform:
def __init__(
self, tiling: Sequence[int], rounding: Rounding | None = ...
) -> None: ...
def apply(self, arg: object, /) -> ir.Value: ...
def transform_index(self, arg: Iterable, /) -> tuple: ...
def transform_shape(self, arg: Sequence[int], /) -> tuple: ...
def transform_strides(self, arg: Sequence[int], /) -> tuple: ...
class TrivialTransferPlan:
def __init__(self) -> None: ...
@property
def tile_index_transforms(self) -> tuple: ...
def select(self, arg: Iterable, /) -> ir.Value: ...
def select_if_group(
self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
) -> ir.Value: ...
class StaggeredTransferPlan:
def __init__(
self, stagger: int, dim: int, size: int, group_pred: object
) -> None: ...
@property
def stagger(self) -> int: ...
@property
def dim(self) -> int: ...
@property
def size(self) -> int: ...
@property
def group_pred(self) -> ir.Value: ...
@property
def tile_index_transforms(self) -> tuple: ...
def select(self, arg: Iterable, /) -> ir.Value: ...
def select_if_group(
self, arg0: int, arg1: ir.Value, arg2: ir.Value, /
) -> ir.Value: ...
@@ -0,0 +1,210 @@
"""SDY main Python extension"""
from typing import Self
from collections.abc import Sequence
from jaxlib.mlir import ir
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
class MeshAxisAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, name: str, size: int, context: ir.Context | None = None) -> Self:
"""Creates a MeshAxisAttr with the given axis name and size."""
@property
def name(self) -> str: ...
@property
def size(self) -> int: ...
class MeshAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, mesh_axes: Sequence[ir.Attribute], device_ids: Sequence[int] = [], context: ir.Context | None = None) -> Self:
"""Creates a MeshAttr with the given mesh axes."""
@property
def device_ids(self) -> list[int]: ...
@property
def axes(self) -> list[ir.Attribute]: ...
class SubAxisInfoAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, pre_size: int, size: int, context: ir.Context | None = None) -> Self:
"""Creates a SubAxisInfoAttr with the given pre-size and size."""
@property
def pre_size(self) -> int: ...
@property
def size(self) -> int: ...
class AxisRefAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, name: str, sub_axis_info: ir.Attribute | None = None, context: ir.Context | None = None) -> Self:
"""Creates an AxisRefAttr with the given name and SubAxisInfoAttr."""
@property
def name(self) -> str: ...
@property
def sub_axis_info(self) -> ir.Attribute | None: ...
class DimensionShardingAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, axes: Sequence[ir.Attribute], is_closed: bool, priority: int | None = None, context: ir.Context | None = None) -> Self:
"""
Creates a DimensionShardingAttr with the given axes, whether it's closed, and priority.
"""
@property
def axes(self) -> list[ir.Attribute]: ...
@property
def is_closed(self) -> bool: ...
@property
def priority(self) -> int | None: ...
class TensorShardingAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, mesh_or_ref: str | ir.Attribute, dimension_shardings: Sequence[ir.Attribute], replicated_axes: Sequence[ir.Attribute] = [], unreduced_axes: Sequence[ir.Attribute] = [], context: ir.Context | None = None) -> Self:
"""
Creates a TensorShardingAttr with either an inlined mesh or mesh name, dimension shardings, and replicated axes.
"""
@property
def mesh_or_ref(self) -> ir.Attribute: ...
@property
def dimension_shardings(self) -> list[ir.Attribute]: ...
@property
def replicated_axes(self) -> list[ir.Attribute]: ...
@property
def unreduced_axes(self) -> list[ir.Attribute]: ...
class TensorShardingPerValueAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, shardings: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
"""Creates a TensorShardingPerValueAttr with the tensor shardings."""
@property
def shardings(self) -> list[ir.Attribute]: ...
class DimMappingAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, factor_indices: Sequence[int], context: ir.Context | None = None) -> Self:
"""Creates a DimMappingAttr with the factor indices."""
@property
def factor_indices(self) -> list[int]: ...
class TensorMappingAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, dim_mappings: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
"""Creates a TensorMappingAttr with the dim mappings."""
@property
def dim_mappings(self) -> list[ir.Attribute]: ...
@property
def rank(self) -> int: ...
class OpShardingRuleAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, factor_sizes: Sequence[int], operand_mappings: Sequence[ir.Attribute], result_mappings: Sequence[ir.Attribute], reduction_factors: Sequence[int] = [], need_replication_factors: Sequence[int] = [], permutation_factors: Sequence[int] = [], blocked_propagation_factors: Sequence[int] = [], is_custom: bool = False, context: ir.Context | None = None) -> Self:
"""
Creates a OpShardingRuleAttr with the factor sizes and mappings for operands and results.
"""
@property
def is_custom(self) -> bool: ...
@property
def factor_sizes(self) -> list[int]: ...
@property
def operand_mappings(self) -> list[ir.Attribute]: ...
@property
def result_mappings(self) -> list[ir.Attribute]: ...
@property
def reduction_factors(self) -> list[int]: ...
@property
def need_replication_factors(self) -> list[int]: ...
@property
def permutation_factors(self) -> list[int]: ...
@property
def blocked_propagation_factors(self) -> list[int]: ...
class ManualAxesAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, manual_axes: Sequence[ir.Attribute], context: ir.Context | None = None) -> Self:
"""Creates a ManualAxesAttr with the given manual axes."""
def __getitem__(self, arg: int, /) -> str: ...
def __len__(self) -> int: ...
Binary file not shown.
@@ -0,0 +1,23 @@
"""MPMD main Python extension"""
from typing import Self
from jaxlib.mlir import ir
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
class UserOriginAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, name: str, transpose_count: int, context: ir.Context | None = None) -> Self:
"""Creates a UserOriginAttr with the given user name and transpose count."""
@property
def user_name(self) -> str: ...
@property
def transpose_count(self) -> int: ...
Binary file not shown.
@@ -0,0 +1,407 @@
"""stablehlo main python extension"""
from collections.abc import Sequence
import enum
from typing import overload, Self
from jaxlib.mlir import ir
def register_dialect(context: ir.Context, load: bool = True) -> None: ...
def register_interpreter_dialect(context: ir.Context, load: bool = True) -> None: ...
def register_stablehlo_passes() -> None: ...
class TokenType(ir.Type):
@staticmethod
def isinstance(other_type: ir.Type) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, context: ir.Context | None = None) -> Self:
"""Creates a Token type."""
class FutureType(ir.Type):
@staticmethod
def isinstance(other_type: ir.Type) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, types: Sequence[ir.Type], context: ir.Context | None = None) -> Self:
"""Creates a Future type."""
@property
def types(self) -> list[ir.Type]: ...
class ScatterDimensionNumbers(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, update_window_dims: Sequence[int], inserted_window_dims: Sequence[int], input_batching_dims: Sequence[int], scatter_indices_batching_dims: Sequence[int], scattered_dims_to_operand_dims: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
"""
Creates a ScatterDimensionNumbers with the given dimension configuration.
"""
@property
def update_window_dims(self) -> list[int]: ...
@property
def inserted_window_dims(self) -> list[int]: ...
@property
def input_batching_dims(self) -> list[int]: ...
@property
def scatter_indices_batching_dims(self) -> list[int]: ...
@property
def scattered_dims_to_operand_dims(self) -> list[int]: ...
@property
def index_vector_dim(self) -> int: ...
class GatherDimensionNumbers(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, offset_dims: Sequence[int], collapsed_slice_dims: Sequence[int], operand_batching_dims: Sequence[int], start_indices_batching_dims: Sequence[int], start_index_map: Sequence[int], index_vector_dim: int, context: ir.Context | None = None) -> Self:
"""
Creates a GatherDimensionNumbers attribute with the given dimension configuration.
"""
@property
def offset_dims(self) -> list[int]: ...
@property
def collapsed_slice_dims(self) -> list[int]: ...
@property
def operand_batching_dims(self) -> list[int]: ...
@property
def start_indices_batching_dims(self) -> list[int]: ...
@property
def start_index_map(self) -> list[int]: ...
@property
def index_vector_dim(self) -> int: ...
class DotAlgorithm(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, lhs_precision_type: ir.Type, rhs_precision_type: ir.Type, accumulation_type: ir.Type, lhs_component_count: int, rhs_component_count: int, num_primitive_operations: int, allow_imprecise_accumulation: bool, ctx: ir.Context | None = None) -> Self:
"""
Creates a DotAlgorithm attribute with the given dimension configuration.
"""
@property
def lhs_precision_type(self) -> ir.Type: ...
@property
def rhs_precision_type(self) -> ir.Type: ...
@property
def accumulation_type(self) -> ir.Type: ...
@property
def lhs_component_count(self) -> int: ...
@property
def rhs_component_count(self) -> int: ...
@property
def num_primitive_operations(self) -> int: ...
@property
def allow_imprecise_accumulation(self) -> bool: ...
class DotDimensionNumbers(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, lhs_batching_dimensions: Sequence[int], rhs_batching_dimensions: Sequence[int], lhs_contracting_dimensions: Sequence[int], rhs_contracting_dimensions: Sequence[int], context: ir.Context | None = None) -> Self:
"""
Creates a DotDimensionNumbers attribute with the given dimension configuration.
"""
@property
def lhs_batching_dimensions(self) -> list[int]: ...
@property
def rhs_batching_dimensions(self) -> list[int]: ...
@property
def lhs_contracting_dimensions(self) -> list[int]: ...
@property
def rhs_contracting_dimensions(self) -> list[int]: ...
class ConvDimensionNumbers(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, input_batch_dimension: int, input_feature_dimension: int, input_spatial_dimensions: Sequence[int], kernel_input_feature_dimension: int, kernel_output_feature_dimension: int, kernel_spatial_dimensions: Sequence[int], output_batch_dimension: int, output_feature_dimension: int, output_spatial_dimensions: Sequence[int], ctx: ir.Context | None = None) -> Self:
"""
Creates a ConvDimensionNumbers attribute with the given dimension configuration.
"""
@property
def input_batch_dimension(self) -> int: ...
@property
def input_feature_dimension(self) -> int: ...
@property
def input_spatial_dimensions(self) -> list[int]: ...
@property
def kernel_input_feature_dimension(self) -> int: ...
@property
def kernel_output_feature_dimension(self) -> int: ...
@property
def kernel_spatial_dimensions(self) -> list[int]: ...
@property
def output_batch_dimension(self) -> int: ...
@property
def output_feature_dimension(self) -> int: ...
@property
def output_spatial_dimensions(self) -> list[int]: ...
class OutputOperandAlias(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, output_tuple_indices: Sequence[int], operand_index: int, operand_tuple_indices: Sequence[int], ctx: ir.Context | None = None) -> Self:
"""Creates a OutputOperandAlias attribute with the given tuple index."""
@property
def output_tuple_indices(self) -> list[int]: ...
@property
def operand_index(self) -> int: ...
@property
def operand_tuple_indices(self) -> list[int]: ...
class ComparisonDirectionAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a ComparisonDirection attribute with the given value."""
@property
def value(self) -> str: ...
class ComparisonTypeAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a ComparisonType attribute with the given value."""
@property
def value(self) -> str: ...
class PrecisionAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a Precision attribute with the given value."""
@property
def value(self) -> str: ...
class FftTypeAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a FftType attribute with the given value."""
@property
def value(self) -> str: ...
class TransposeAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a Transpose attribute with the given value."""
@property
def value(self) -> str: ...
class RngDistributionAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a RngDistribution attribute with the given value."""
@property
def value(self) -> str: ...
class RngAlgorithmAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a RngAlgorithm attribute with the given value."""
@property
def value(self) -> str: ...
class ChannelHandle(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, handle: int, type: int, context: ir.Context | None = None) -> Self:
"""Creates a ChannelHandle attribute."""
@property
def handle(self) -> int: ...
@property
def channel_type(self) -> int: ...
class TypeExtensions(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, bounds: Sequence[int], context: ir.Context | None = None) -> Self:
"""Creates a TypeExtensions with the given bounds."""
@property
def bounds(self) -> list[int]: ...
class ResultAccuracyAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, atol: float, rtol: float, ulps: int, mode: str, context: ir.Context | None = None) -> Self:
"""Creates a ResultAccuracyAttr with the given values."""
@property
def atol(self) -> float: ...
@property
def rtol(self) -> float: ...
@property
def ulps(self) -> int: ...
@property
def mode(self) -> str: ...
class ResultAccuracyModeAttr(ir.Attribute):
@staticmethod
def isinstance(other_attribute: ir.Attribute) -> bool: ...
def __repr__(self) -> str: ...
@classmethod
def get(cls, value: str, context: ir.Context | None = None) -> Self:
"""Creates a ResultAccuracyModeAttr with the given values."""
@property
def value(self) -> str: ...
def get_api_version() -> int: ...
def get_smaller_version(version1: str, version2: str) -> str: ...
def get_current_version() -> str: ...
def get_minimum_version() -> str: ...
@overload
def serialize_portable_artifact_str(module_str: str, target_version: str) -> bytes: ...
@overload
def serialize_portable_artifact_str(module_str: bytes, target_version: str) -> bytes: ...
@overload
def deserialize_portable_artifact_str(artifact_str: str) -> bytes: ...
@overload
def deserialize_portable_artifact_str(artifact_str: bytes) -> bytes: ...
class StablehloCompatibilityRequirement(enum.Enum):
NONE = 0
WEEK_4 = 1
WEEK_12 = 2
MAX = 3
def get_version_from_compatibility_requirement(requirement: StablehloCompatibilityRequirement) -> str: ...
def serialize_portable_artifact(module: ir.Module, target: str, allow_other_dialects: bool = False) -> bytes: ...
@overload
def deserialize_portable_artifact(context: ir.Context, artifact: str) -> ir.Module: ...
@overload
def deserialize_portable_artifact(context: ir.Context, artifact: bytes) -> ir.Module: ...
def eval_module(module: ir.Module, args: Sequence[ir.Attribute], probe_instrumentation_dir: str = '') -> list[ir.Attribute]: ...
@@ -0,0 +1,32 @@
# Copyright 2026 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mlir import ir
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
def private_has_communication(arg: ir.Operation, /) -> tuple[bool, bool]: ...
def private_set_arg_attr(
arg0: ir.Operation, arg1: int, arg2: str, arg3: ir.Attribute, /
) -> None: ...
class Float8EXMYType(ir.Type):
@staticmethod
def isinstance(other_type: ir.Type) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get(
exmy_type: ir.Type | None = None, ctx: ir.Context | None = None
) -> Float8EXMYType: ...
@property
def underlying_type(self) -> ir.Type: ...
Binary file not shown.
@@ -0,0 +1,36 @@
# Copyright 2026 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jaxlib.mlir import ir
def register_dialect(context: ir.Context, load: bool = ...) -> None: ...
class PointerType(ir.Type):
@staticmethod
def isinstance(other_type: ir.Type) -> bool: ...
def __repr__(self) -> str: ...
@staticmethod
def get_static_typeid() -> ir.TypeID: ...
@staticmethod
def get(pointee_type: ir.Type, address_space: int) -> PointerType:
"""Creates a PointerType type."""
@property
def pointee_type(self) -> ir.Type: ...
@property
def address_space(self) -> int: ...
def infer_reduce_op_encoding(
arg0: ir.Attribute, arg1: int, /
) -> ir.Attribute | None: ...