This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,272 @@
# Autogenerated by mlir-tblgen; don't manually edit.
from enum import IntEnum, auto, IntFlag
from jaxlib.mlir.dialects._ods_common import _cext as _ods_cext
from jaxlib.mlir.ir import register_attribute_builder
_ods_ir = _ods_cext.ir
class ContractPrecision(IntEnum):
  """Contraction precision."""

  kBF16 = 0
  kFP32 = 1

  def __str__(self):
    # Render as the lower-case mnemonic used in the textual IR.
    label = {
        ContractPrecision.kBF16: "bf16",
        ContractPrecision.kFP32: "fp32",
    }.get(self)
    if label is None:
      raise ValueError("Unknown ContractPrecision enum entry.")
    return label
@register_attribute_builder("TPU_ContractPrecision", allow_existing=True)
def _tpu_contractprecision(x, context):
  """Converts a ContractPrecision value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
class CoreType(IntEnum):
  """Core type."""

  kTc = 0
  kScScalarSubcore = 1
  kScVectorSubcore = 2

  def __str__(self):
    # Render as the snake_case mnemonic used in the textual IR.
    label = {
        CoreType.kTc: "tc",
        CoreType.kScScalarSubcore: "sc_scalar_subcore",
        CoreType.kScVectorSubcore: "sc_vector_subcore",
    }.get(self)
    if label is None:
      raise ValueError("Unknown CoreType enum entry.")
    return label
@register_attribute_builder("TPU_CoreType", allow_existing=True)
def _tpu_coretype(x, context):
  """Converts a CoreType value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
class DimensionSemantics(IntEnum):
  """Dimension semantics."""

  parallel = 0
  arbitrary = 1
  core_parallel = 2
  subcore_parallel = 3

  def __str__(self):
    # Render as the mnemonic used in the textual IR.
    label = {
        DimensionSemantics.parallel: "parallel",
        DimensionSemantics.arbitrary: "arbitrary",
        DimensionSemantics.core_parallel: "core_parallel",
        DimensionSemantics.subcore_parallel: "subcore_parallel",
    }.get(self)
    if label is None:
      raise ValueError("Unknown DimensionSemantics enum entry.")
    return label
@register_attribute_builder("TPU_DimensionSemantics", allow_existing=True)
def _tpu_dimensionsemantics(x, context):
  """Converts a DimensionSemantics value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
class MemorySpace(IntEnum):
  """Memory space."""

  kAny = 4294967295
  kVmem = 0
  kSmem = 1
  kHbm = 2
  kCmem = 3
  kSemaphoreMem = 4
  kVmemShared = 5
  kHost = 6

  def __str__(self):
    # Render as the mnemonic used in the textual IR.
    label = {
        MemorySpace.kAny: "any",
        MemorySpace.kVmem: "vmem",
        MemorySpace.kSmem: "smem",
        MemorySpace.kHbm: "hbm",
        MemorySpace.kCmem: "cmem",
        MemorySpace.kSemaphoreMem: "semaphore_mem",
        MemorySpace.kVmemShared: "vmem_shared",
        MemorySpace.kHost: "host",
    }.get(self)
    if label is None:
      raise ValueError("Unknown MemorySpace enum entry.")
    return label
@register_attribute_builder("TPU_MemorySpace", allow_existing=True)
def _tpu_memoryspace(x, context):
  """Converts a MemorySpace value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
class PackFormat(IntEnum):
  """Pack format."""

  kCompressed = 0
  kInterleaved = 1

  def __str__(self):
    # Render as the mnemonic used in the textual IR.
    label = {
        PackFormat.kCompressed: "compressed",
        PackFormat.kInterleaved: "interleaved",
    }.get(self)
    if label is None:
      raise ValueError("Unknown PackFormat enum entry.")
    return label
@register_attribute_builder("TPU_PackFormat", allow_existing=True)
def _tpu_packformat(x, context):
  """Converts a PackFormat value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
class PipelineMode(IntEnum):
  """Pipeline mode."""

  kSynchronous = 1
  kDoubleBuffered = 2

  def __str__(self):
    # Render as the snake_case mnemonic used in the textual IR.
    label = {
        PipelineMode.kSynchronous: "synchronous",
        PipelineMode.kDoubleBuffered: "double_buffered",
    }.get(self)
    if label is None:
      raise ValueError("Unknown PipelineMode enum entry.")
    return label
@register_attribute_builder("TPU_PipelineMode", allow_existing=True)
def _tpu_pipelinemode(x, context):
  """Converts a PipelineMode value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
class ReductionKind(IntEnum):
  """Reduction kind."""

  kSum = 0
  kMax = 1
  kMin = 2
  kArgMax = 3
  kArgMin = 4
  kFindFirstSet = 5

  def __str__(self):
    # Render as the snake_case mnemonic used in the textual IR.
    label = {
        ReductionKind.kSum: "sum",
        ReductionKind.kMax: "max",
        ReductionKind.kMin: "min",
        ReductionKind.kArgMax: "arg_max",
        ReductionKind.kArgMin: "arg_min",
        ReductionKind.kFindFirstSet: "find_first_set",
    }.get(self)
    if label is None:
      raise ValueError("Unknown ReductionKind enum entry.")
    return label
@register_attribute_builder("TPU_ReductionKind", allow_existing=True)
def _tpu_reductionkind(x, context):
  """Converts a ReductionKind value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
class RevisitMode(IntEnum):
  """Revisit mode."""

  kImmediate = 0
  kAny = 1

  def __str__(self):
    # Render as the mnemonic used in the textual IR.
    label = {
        RevisitMode.kImmediate: "immediate",
        RevisitMode.kAny: "any",
    }.get(self)
    if label is None:
      raise ValueError("Unknown RevisitMode enum entry.")
    return label
@register_attribute_builder("TPU_RevisitMode", allow_existing=True)
def _tpu_revisitmode(x, context):
  """Converts a RevisitMode value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
class RoundingMode(IntEnum):
  """Rounding mode."""

  kTowardsZero = 0
  kToNearestEven = 1

  def __str__(self):
    # Render as the snake_case mnemonic used in the textual IR.
    label = {
        RoundingMode.kTowardsZero: "towards_zero",
        RoundingMode.kToNearestEven: "to_nearest_even",
    }.get(self)
    if label is None:
      raise ValueError("Unknown RoundingMode enum entry.")
    return label
@register_attribute_builder("TPU_RoundingMode", allow_existing=True)
def _tpu_roundingmode(x, context):
  """Converts a RoundingMode value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
class VectorLayoutDim(IntEnum):
  """allowed 32-bit signless integer cases: 0, 1, 2"""

  tiled = 0
  lanes = 1
  sublanes = 2

  def __str__(self):
    # Render as the mnemonic used in the textual IR.
    label = {
        VectorLayoutDim.tiled: "tiled",
        VectorLayoutDim.lanes: "lanes",
        VectorLayoutDim.sublanes: "sublanes",
    }.get(self)
    if label is None:
      raise ValueError("Unknown VectorLayoutDim enum entry.")
    return label
@register_attribute_builder("TPU_VectorLayoutDim", allow_existing=True)
def _tpu_vectorlayoutdim(x, context):
  """Converts a VectorLayoutDim value into a signless i32 IntegerAttr."""
  i32 = _ods_ir.IntegerType.get_signless(32, context=context)
  return _ods_ir.IntegerAttr.get(i32, int(x))
@register_attribute_builder("tpu.TPU_ContractPrecisionEnum")
def _tpu_contractprecisionenum(x, context):
  """Builds a #tpu.contract_precision attribute from an enum value."""
  # f-strings apply str() implicitly, so the explicit call is unnecessary.
  return _ods_ir.Attribute.parse(f"#tpu.contract_precision<{x}>", context=context)
@register_attribute_builder("tpu.TPU_CoreTypeEnum")
def _tpu_coretypeenum(x, context):
  """Builds a #tpu.core_type attribute from an enum value."""
  return _ods_ir.Attribute.parse(f"#tpu.core_type<{x}>", context=context)
@register_attribute_builder("tpu.TPU_DimensionSemanticsEnum")
def _tpu_dimensionsemanticsenum(x, context):
  """Builds a #tpu.dimension_semantics attribute from an enum value."""
  return _ods_ir.Attribute.parse(f"#tpu.dimension_semantics<{x}>", context=context)
@register_attribute_builder("tpu.TPU_PackFormatEnum")
def _tpu_packformatenum(x, context):
  """Builds a #tpu.pack_format attribute from an enum value."""
  return _ods_ir.Attribute.parse(f"#tpu.pack_format<{x}>", context=context)
@register_attribute_builder("tpu.TPU_PipelineModeEnum")
def _tpu_pipelinemodeenum(x, context):
  """Builds a #tpu.pipeline_mode attribute from an enum value."""
  return _ods_ir.Attribute.parse(f"#tpu.pipeline_mode<{x}>", context=context)
@register_attribute_builder("tpu.TPU_ReductionKindAttr")
def _tpu_reductionkindattr(x, context):
  """Builds a #tpu.reduction_kind attribute from an enum value."""
  return _ods_ir.Attribute.parse(f"#tpu.reduction_kind<{x}>", context=context)
@register_attribute_builder("tpu.TPU_RevisitModeEnum")
def _tpu_revisitmodeenum(x, context):
  """Builds a #tpu.revisit_mode attribute from an enum value."""
  return _ods_ir.Attribute.parse(f"#tpu.revisit_mode<{x}>", context=context)
@register_attribute_builder("tpu.TPU_RoundingModeEnum")
def _tpu_roundingmodeenum(x, context):
  """Builds a #tpu.rounding_mode attribute from an enum value."""
  return _ods_ir.Attribute.parse(f"#tpu.rounding_mode<{x}>", context=context)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,58 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python definitions for third_party.jax.jaxlib.mosaic.python.tpu
These definitions are needed internally by the tpu module.
TODO(tlongeri): Migrate definitions to tpu module
"""
import collections
import enum
from typing import Literal
# A (sublanes, lanes) pair describing a target's tile dimensions.
TargetTuple = collections.namedtuple("TargetTuple", ("sublanes", "lanes"))
@enum.unique
class Direction(enum.Enum):
  """A direction; its repr is the lower-cased member name."""

  SUBLANES = "sublanes"
  LANES = "lanes"
  SUBELEMENTS = "subelements"

  def __repr__(self):
    # Compact rendering, e.g. `sublanes` rather than `<Direction.SUBLANES: ...>`.
    return self.name.lower()
# Module-level aliases for the Direction members.
SUBLANES = Direction.SUBLANES
LANES = Direction.LANES
SUBELEMENTS = Direction.SUBELEMENTS
class Replicated(enum.Enum):
  """Singleton sentinel rendered as `*`.

  The sentinel is falsy so callers can write `offset or 0`.
  """

  REPLICATED = "*"

  def __repr__(self):
    return "*"

  __str__ = __repr__

  def __bool__(self):
    # Falsy on purpose: lets callers default with `offset or 0`.
    return False
# Module-level alias for the replicated-offset sentinel.
REPLICATED = Replicated.REPLICATED
# An offset is either a concrete integer or the REPLICATED sentinel.
Offset = int | Literal[REPLICATED]
class ImplicitDim(enum.IntEnum):
  """Implicit dimension markers; repr is the raw integer value."""

  MINOR = -1
  SECOND_MINOR = -2
  MINOR_AND_SECOND_MINOR = -3

  def __repr__(self) -> str:
    # Render as the plain integer, e.g. `-1` instead of `<ImplicitDim.MINOR: -1>`.
    return f"{int(self)}"
@@ -0,0 +1,53 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python bindings for the MLIR Mosaic GPU dialect.
Note: this file *must* be called `mosaic_gpu.py`, in order to match the dialect
name. Otherwise, MLIR is unable to find the module during dialect search.
"""
# ruff: noqa: F401
# ruff: noqa: F403
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_ops import *
from jaxlib.mosaic.dialect.gpu import _mosaic_gpu_gen_ops
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_enums import *
from jaxlib.mlir._mlir_libs._mosaic_gpu_ext import *
try:
from jaxlib.mlir.dialects._ods_common import _cext
except ImportError:
from mlir.dialects._ods_common import _cext
# Add this module's parent package to MLIR's dialect search prefixes so the
# dialect machinery can find this module (which is named after the dialect;
# see the module docstring) during dialect search.
_cext.globals.append_dialect_search_prefix(__name__[:__name__.rfind(".")])
@_cext.register_operation(_mosaic_gpu_gen_ops._Dialect, replace=True)
class WarpMapOp(_mosaic_gpu_gen_ops.WarpMapOp):  # noqa: F405
  """An extension to the automatically generated WarpMapOp bindings."""

  def __init__(self, operands, *, loc=None, ip=None):
    super().__init__(operands, loc=loc, ip=ip)
    # Create the op's entry block, with one block argument per operand.
    entry_arg_types = [operand.type for operand in self.operands_]
    self.regions[0].blocks.append(*entry_arg_types)

  @property
  def body(self):
    """The entry block of the op's single region."""
    return self.regions[0].blocks[0]
def warp_map(operands, *, loc=None, ip=None) -> WarpMapOp:
  """Convenience wrapper constructing a WarpMapOp over `operands`."""
  op = WarpMapOp(operands, loc=loc, ip=ip)
  return op
@@ -0,0 +1,92 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python bindings for the MLIR TPU dialect."""
# ruff: noqa: F401
# ruff: noqa: F403
from ._tpu_enum_gen import *
from . import _tpu_ops_gen
from ._tpu_ops_gen import *
from ._tpu_ops_gen import _Dialect, VectorLoadOp, VectorStoreOp
from jaxlib.mlir._mlir_libs._tpu_ext import *
try:
from jaxlib.mlir.dialects._ods_common import _cext
except ImportError:
from mlir.dialects._ods_common import _cext
# Register "jax.jaxlib.mosaic.python" as a dialect search prefix so MLIR can
# locate the TPU dialect's Python module during dialect search.
_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")
@_cext.register_operation(_Dialect, replace=True)
class TraceOp(_tpu_ops_gen.TraceOp):  # noqa: F405
  """An extension to the automatically generated TraceOp bindings."""

  def __init__(self, results, message, level, *, loc=None, ip=None):
    super().__init__(results, message, level, loc=loc, ip=ip)
    # Create the op's entry block (it takes no block arguments).
    self.regions[0].blocks.append()

  @property
  def body(self):
    """The entry block of the op's single region."""
    return self.regions[0].blocks[0]
@_cext.register_operation(_Dialect, replace=True)
class RegionOp(_tpu_ops_gen.RegionOp):  # noqa: F405
  """An extension to the automatically generated RegionOp bindings."""

  def __init__(self, results, *, loc=None, ip=None):
    super().__init__(results, loc=loc, ip=ip)
    # Create the op's entry block (it takes no block arguments).
    self.regions[0].blocks.append()

  @property
  def body(self):
    """The entry block of the op's single region."""
    return self.regions[0].blocks[0]
def vector_load(
    result,
    base,
    indices,
    *,
    strides=None,
    mask=None,
    loc=None,
    ip=None,
):
  """Builds a VectorLoadOp and returns its result value.

  `strides` defaults to an empty list when not provided.
  """
  actual_strides = [] if strides is None else strides
  op = VectorLoadOp(
      result,
      base,
      indices,
      actual_strides,
      mask=mask,
      loc=loc,
      ip=ip,
  )
  return op.result
def vector_store(
    value_to_store,
    base,
    indices,
    *,
    strides=None,
    add=False,
    mask=None,
    loc=None,
    ip=None,
):
  """Builds a VectorStoreOp and returns the op itself.

  `strides` defaults to an empty list when not provided.
  """
  actual_strides = [] if strides is None else strides
  return VectorStoreOp(
      value_to_store,
      base,
      indices,
      actual_strides,
      mask=mask,
      add=add,
      loc=loc,
      ip=ip,
  )