hand
This commit is contained in:
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,272 @@
|
||||
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
|
||||
from enum import IntEnum, auto, IntFlag
|
||||
from jaxlib.mlir.dialects._ods_common import _cext as _ods_cext
|
||||
from jaxlib.mlir.ir import register_attribute_builder
|
||||
_ods_ir = _ods_cext.ir
|
||||
|
||||
class ContractPrecision(IntEnum):
|
||||
"""Contraction precision"""
|
||||
|
||||
kBF16 = 0
|
||||
kFP32 = 1
|
||||
|
||||
def __str__(self):
|
||||
if self is ContractPrecision.kBF16:
|
||||
return "bf16"
|
||||
if self is ContractPrecision.kFP32:
|
||||
return "fp32"
|
||||
raise ValueError("Unknown ContractPrecision enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_ContractPrecision", allow_existing=True)
|
||||
def _tpu_contractprecision(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class CoreType(IntEnum):
|
||||
"""Core type"""
|
||||
|
||||
kTc = 0
|
||||
kScScalarSubcore = 1
|
||||
kScVectorSubcore = 2
|
||||
|
||||
def __str__(self):
|
||||
if self is CoreType.kTc:
|
||||
return "tc"
|
||||
if self is CoreType.kScScalarSubcore:
|
||||
return "sc_scalar_subcore"
|
||||
if self is CoreType.kScVectorSubcore:
|
||||
return "sc_vector_subcore"
|
||||
raise ValueError("Unknown CoreType enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_CoreType", allow_existing=True)
|
||||
def _tpu_coretype(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class DimensionSemantics(IntEnum):
|
||||
"""Dimension semantics"""
|
||||
|
||||
parallel = 0
|
||||
arbitrary = 1
|
||||
core_parallel = 2
|
||||
subcore_parallel = 3
|
||||
|
||||
def __str__(self):
|
||||
if self is DimensionSemantics.parallel:
|
||||
return "parallel"
|
||||
if self is DimensionSemantics.arbitrary:
|
||||
return "arbitrary"
|
||||
if self is DimensionSemantics.core_parallel:
|
||||
return "core_parallel"
|
||||
if self is DimensionSemantics.subcore_parallel:
|
||||
return "subcore_parallel"
|
||||
raise ValueError("Unknown DimensionSemantics enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_DimensionSemantics", allow_existing=True)
|
||||
def _tpu_dimensionsemantics(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class MemorySpace(IntEnum):
|
||||
"""Memory space"""
|
||||
|
||||
kAny = 4294967295
|
||||
kVmem = 0
|
||||
kSmem = 1
|
||||
kHbm = 2
|
||||
kCmem = 3
|
||||
kSemaphoreMem = 4
|
||||
kVmemShared = 5
|
||||
kHost = 6
|
||||
|
||||
def __str__(self):
|
||||
if self is MemorySpace.kAny:
|
||||
return "any"
|
||||
if self is MemorySpace.kVmem:
|
||||
return "vmem"
|
||||
if self is MemorySpace.kSmem:
|
||||
return "smem"
|
||||
if self is MemorySpace.kHbm:
|
||||
return "hbm"
|
||||
if self is MemorySpace.kCmem:
|
||||
return "cmem"
|
||||
if self is MemorySpace.kSemaphoreMem:
|
||||
return "semaphore_mem"
|
||||
if self is MemorySpace.kVmemShared:
|
||||
return "vmem_shared"
|
||||
if self is MemorySpace.kHost:
|
||||
return "host"
|
||||
raise ValueError("Unknown MemorySpace enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_MemorySpace", allow_existing=True)
|
||||
def _tpu_memoryspace(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class PackFormat(IntEnum):
|
||||
"""Pack format"""
|
||||
|
||||
kCompressed = 0
|
||||
kInterleaved = 1
|
||||
|
||||
def __str__(self):
|
||||
if self is PackFormat.kCompressed:
|
||||
return "compressed"
|
||||
if self is PackFormat.kInterleaved:
|
||||
return "interleaved"
|
||||
raise ValueError("Unknown PackFormat enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_PackFormat", allow_existing=True)
|
||||
def _tpu_packformat(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class PipelineMode(IntEnum):
|
||||
"""Pipeline mode"""
|
||||
|
||||
kSynchronous = 1
|
||||
kDoubleBuffered = 2
|
||||
|
||||
def __str__(self):
|
||||
if self is PipelineMode.kSynchronous:
|
||||
return "synchronous"
|
||||
if self is PipelineMode.kDoubleBuffered:
|
||||
return "double_buffered"
|
||||
raise ValueError("Unknown PipelineMode enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_PipelineMode", allow_existing=True)
|
||||
def _tpu_pipelinemode(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class ReductionKind(IntEnum):
|
||||
"""Reduction kind"""
|
||||
|
||||
kSum = 0
|
||||
kMax = 1
|
||||
kMin = 2
|
||||
kArgMax = 3
|
||||
kArgMin = 4
|
||||
kFindFirstSet = 5
|
||||
|
||||
def __str__(self):
|
||||
if self is ReductionKind.kSum:
|
||||
return "sum"
|
||||
if self is ReductionKind.kMax:
|
||||
return "max"
|
||||
if self is ReductionKind.kMin:
|
||||
return "min"
|
||||
if self is ReductionKind.kArgMax:
|
||||
return "arg_max"
|
||||
if self is ReductionKind.kArgMin:
|
||||
return "arg_min"
|
||||
if self is ReductionKind.kFindFirstSet:
|
||||
return "find_first_set"
|
||||
raise ValueError("Unknown ReductionKind enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_ReductionKind", allow_existing=True)
|
||||
def _tpu_reductionkind(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class RevisitMode(IntEnum):
|
||||
"""Revisit mode"""
|
||||
|
||||
kImmediate = 0
|
||||
kAny = 1
|
||||
|
||||
def __str__(self):
|
||||
if self is RevisitMode.kImmediate:
|
||||
return "immediate"
|
||||
if self is RevisitMode.kAny:
|
||||
return "any"
|
||||
raise ValueError("Unknown RevisitMode enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_RevisitMode", allow_existing=True)
|
||||
def _tpu_revisitmode(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class RoundingMode(IntEnum):
|
||||
"""Rounding mode"""
|
||||
|
||||
kTowardsZero = 0
|
||||
kToNearestEven = 1
|
||||
|
||||
def __str__(self):
|
||||
if self is RoundingMode.kTowardsZero:
|
||||
return "towards_zero"
|
||||
if self is RoundingMode.kToNearestEven:
|
||||
return "to_nearest_even"
|
||||
raise ValueError("Unknown RoundingMode enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_RoundingMode", allow_existing=True)
|
||||
def _tpu_roundingmode(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
class VectorLayoutDim(IntEnum):
|
||||
"""allowed 32-bit signless integer cases: 0, 1, 2"""
|
||||
|
||||
tiled = 0
|
||||
lanes = 1
|
||||
sublanes = 2
|
||||
|
||||
def __str__(self):
|
||||
if self is VectorLayoutDim.tiled:
|
||||
return "tiled"
|
||||
if self is VectorLayoutDim.lanes:
|
||||
return "lanes"
|
||||
if self is VectorLayoutDim.sublanes:
|
||||
return "sublanes"
|
||||
raise ValueError("Unknown VectorLayoutDim enum entry.")
|
||||
|
||||
|
||||
|
||||
@register_attribute_builder("TPU_VectorLayoutDim", allow_existing=True)
|
||||
def _tpu_vectorlayoutdim(x, context):
|
||||
return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
@register_attribute_builder("tpu.TPU_ContractPrecisionEnum")
|
||||
def _tpu_contractprecisionenum(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#tpu.contract_precision<{str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("tpu.TPU_CoreTypeEnum")
|
||||
def _tpu_coretypeenum(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#tpu.core_type<{str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("tpu.TPU_DimensionSemanticsEnum")
|
||||
def _tpu_dimensionsemanticsenum(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#tpu.dimension_semantics<{str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("tpu.TPU_PackFormatEnum")
|
||||
def _tpu_packformatenum(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#tpu.pack_format<{str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("tpu.TPU_PipelineModeEnum")
|
||||
def _tpu_pipelinemodeenum(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#tpu.pipeline_mode<{str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("tpu.TPU_ReductionKindAttr")
|
||||
def _tpu_reductionkindattr(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#tpu.reduction_kind<{str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("tpu.TPU_RevisitModeEnum")
|
||||
def _tpu_revisitmodeenum(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#tpu.revisit_mode<{str(x)}>', context=context)
|
||||
|
||||
@register_attribute_builder("tpu.TPU_RoundingModeEnum")
|
||||
def _tpu_roundingmodeenum(x, context):
|
||||
return _ods_ir.Attribute.parse(f'#tpu.rounding_mode<{str(x)}>', context=context)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,58 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Python definitions for third_party.jax.jaxlib.mosaic.python.tpu
|
||||
|
||||
These definitions are needed internally by the tpu module.
|
||||
TODO(tlongeri): Migrate definitions to tpu module
|
||||
"""
|
||||
import collections
|
||||
import enum
|
||||
from typing import Literal
|
||||
|
||||
TargetTuple = collections.namedtuple("TargetTuple", ["sublanes", "lanes"])
|
||||
|
||||
@enum.unique
|
||||
class Direction(enum.Enum):
|
||||
SUBLANES = "sublanes"
|
||||
LANES = "lanes"
|
||||
SUBELEMENTS = "subelements"
|
||||
|
||||
def __repr__(self):
|
||||
return self.name.lower()
|
||||
SUBLANES = Direction.SUBLANES
|
||||
LANES = Direction.LANES
|
||||
SUBELEMENTS = Direction.SUBELEMENTS
|
||||
|
||||
|
||||
class Replicated(enum.Enum):
|
||||
REPLICATED = "*"
|
||||
|
||||
def __repr__(self):
|
||||
return "*"
|
||||
__str__ = __repr__
|
||||
|
||||
def __bool__(self):
|
||||
return False # Useful because we can then say `offset or 0`
|
||||
REPLICATED = Replicated.REPLICATED
|
||||
Offset = int | Literal[REPLICATED]
|
||||
|
||||
|
||||
class ImplicitDim(enum.IntEnum):
|
||||
MINOR = -1
|
||||
SECOND_MINOR = -2
|
||||
MINOR_AND_SECOND_MINOR = -3
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(int(self))
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Python bindings for the MLIR Mosaic GPU dialect.
|
||||
|
||||
Note: this file *must* be called `mosaic_gpu.py`, in order to match the dialect
|
||||
name. Otherwise, MLIR is unable to find the module during dialect search.
|
||||
"""
|
||||
|
||||
# ruff: noqa: F401
|
||||
# ruff: noqa: F403
|
||||
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_ops import *
|
||||
from jaxlib.mosaic.dialect.gpu import _mosaic_gpu_gen_ops
|
||||
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_enums import *
|
||||
from jaxlib.mlir._mlir_libs._mosaic_gpu_ext import *
|
||||
|
||||
try:
|
||||
from jaxlib.mlir.dialects._ods_common import _cext
|
||||
except ImportError:
|
||||
from mlir.dialects._ods_common import _cext
|
||||
|
||||
|
||||
# Add the parent module to the search prefix
|
||||
_cext.globals.append_dialect_search_prefix(__name__[:__name__.rfind(".")])
|
||||
|
||||
|
||||
@_cext.register_operation(_mosaic_gpu_gen_ops._Dialect, replace=True)
|
||||
class WarpMapOp(_mosaic_gpu_gen_ops.WarpMapOp): # noqa: F405
|
||||
"""An extension to the automatically generated WarpMapOp bindings."""
|
||||
|
||||
def __init__(self, operands, *, loc=None, ip=None):
|
||||
super().__init__(operands, loc=loc, ip=ip)
|
||||
args_ty = [o.type for o in self.operands_]
|
||||
self.regions[0].blocks.append(*args_ty) # Append the block.
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
def warp_map(operands, *, loc=None, ip=None) -> WarpMapOp:
|
||||
return WarpMapOp(operands, loc=loc, ip=ip)
|
||||
@@ -0,0 +1,92 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Python bindings for the MLIR TPU dialect."""
|
||||
|
||||
# ruff: noqa: F401
|
||||
# ruff: noqa: F403
|
||||
|
||||
from ._tpu_enum_gen import *
|
||||
from . import _tpu_ops_gen
|
||||
from ._tpu_ops_gen import *
|
||||
from ._tpu_ops_gen import _Dialect, VectorLoadOp, VectorStoreOp
|
||||
from jaxlib.mlir._mlir_libs._tpu_ext import *
|
||||
try:
|
||||
from jaxlib.mlir.dialects._ods_common import _cext
|
||||
except ImportError:
|
||||
from mlir.dialects._ods_common import _cext
|
||||
|
||||
|
||||
_cext.globals.append_dialect_search_prefix("jax.jaxlib.mosaic.python")
|
||||
|
||||
|
||||
@_cext.register_operation(_Dialect, replace=True)
|
||||
class TraceOp(_tpu_ops_gen.TraceOp): # noqa: F405
|
||||
"""An extension to the automatically generated TraceOp bindings."""
|
||||
|
||||
def __init__(self, results, message, level, *, loc=None, ip=None):
|
||||
super().__init__(results, message, level, loc=loc, ip=ip)
|
||||
self.regions[0].blocks.append(*[]) # Append the block.
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
@_cext.register_operation(_Dialect, replace=True)
|
||||
class RegionOp(_tpu_ops_gen.RegionOp): # noqa: F405
|
||||
"""An extension to the automatically generated RegionOp bindings."""
|
||||
|
||||
def __init__(self, results, *, loc=None, ip=None):
|
||||
super().__init__(results, loc=loc, ip=ip)
|
||||
self.regions[0].blocks.append() # Append the block.
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
return self.regions[0].blocks[0]
|
||||
|
||||
|
||||
def vector_load(
|
||||
result,
|
||||
base,
|
||||
indices,
|
||||
*,
|
||||
strides=None,
|
||||
mask=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if strides is None:
|
||||
strides = []
|
||||
return VectorLoadOp(
|
||||
result, base, indices, strides, mask=mask, loc=loc, ip=ip
|
||||
).result
|
||||
|
||||
|
||||
def vector_store(
|
||||
value_to_store,
|
||||
base,
|
||||
indices,
|
||||
*,
|
||||
strides=None,
|
||||
add=False,
|
||||
mask=None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
if strides is None:
|
||||
strides = []
|
||||
return VectorStoreOp(
|
||||
value_to_store, base, indices, strides, mask=mask, add=add, loc=loc, ip=ip
|
||||
)
|
||||
Reference in New Issue
Block a user