hand
This commit is contained in:
@@ -0,0 +1,112 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Experimental GPU backend for Pallas targeting H100.
|
||||
|
||||
These APIs are highly unstable and can change weekly. Use at your own risk.
|
||||
"""
|
||||
|
||||
from jax._src.pallas.mosaic_gpu.core import Barrier as Barrier
|
||||
from jax._src.pallas.mosaic_gpu.core import BlockSpec as BlockSpec
|
||||
from jax._src.pallas.mosaic_gpu.core import ClusterBarrier as ClusterBarrier
|
||||
from jax._src.pallas.mosaic_gpu.core import CompilerParams as CompilerParams
|
||||
from jax._src.pallas.mosaic_gpu.core import kernel as kernel
|
||||
from jax._src.pallas.mosaic_gpu.core import Layout as Layout
|
||||
from jax._src.pallas.mosaic_gpu.core import layout_cast as layout_cast
|
||||
from jax._src.pallas.mosaic_gpu.core import MemorySpace as MemorySpace
|
||||
from jax._src.pallas.mosaic_gpu.core import Mesh as Mesh
|
||||
from jax._src.pallas.mosaic_gpu.core import multicast_ref as multicast_ref
|
||||
from jax._src.pallas.mosaic_gpu.core import PeerMemRef as PeerMemRef
|
||||
from jax._src.pallas.mosaic_gpu.core import RefUnion as RefUnion
|
||||
from jax._src.pallas.mosaic_gpu.core import remote_ref as remote_ref
|
||||
from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType
|
||||
from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TMEMLayout as TMEMLayout
|
||||
from jax._src.pallas.mosaic_gpu.core import transform_ref as transform_ref
|
||||
from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref
|
||||
from jax._src.pallas.mosaic_gpu.core import TransposeTransform as TransposeTransform
|
||||
from jax._src.pallas.mosaic_gpu.core import TryClusterCancelResult as TryClusterCancelResult
|
||||
from jax._src.pallas.mosaic_gpu.core import unswizzle_ref as unswizzle_ref
|
||||
from jax._src.pallas.mosaic_gpu.core import untile_ref as untile_ref
|
||||
from jax._src.pallas.mosaic_gpu.core import WarpMesh as WarpMesh
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC # noqa: F401
|
||||
from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as WGMMAAccumulatorRef
|
||||
from jax._src.pallas.mosaic_gpu.helpers import dynamic_scheduling_loop as dynamic_scheduling_loop
|
||||
from jax._src.pallas.mosaic_gpu.helpers import find_swizzle as find_swizzle
|
||||
from jax._src.pallas.mosaic_gpu.helpers import format_tcgen05_sparse_metadata as format_tcgen05_sparse_metadata
|
||||
from jax._src.pallas.mosaic_gpu.helpers import nd_loop as nd_loop
|
||||
from jax._src.pallas.mosaic_gpu.helpers import NDLoopInfo as NDLoopInfo
|
||||
from jax._src.pallas.mosaic_gpu.helpers import planar_snake as planar_snake
|
||||
from jax._src.pallas.mosaic_gpu.helpers import warp_map as warp_map
|
||||
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline as emit_pipeline
|
||||
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized
|
||||
from jax._src.pallas.mosaic_gpu.pipeline import PipelinePipeline as PipelinePipeline
|
||||
from jax._src.pallas.mosaic_gpu.primitives import async_copy_scales_to_tmem as async_copy_scales_to_tmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_tmem as async_copy_smem_to_tmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import async_copy_sparse_metadata_to_tmem as async_copy_sparse_metadata_to_tmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import async_load_tmem as async_load_tmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import async_prefetch as async_prefetch
|
||||
from jax._src.pallas.mosaic_gpu.primitives import async_store_tmem as async_store_tmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import atomic_add as atomic_add
|
||||
from jax._src.pallas.mosaic_gpu.primitives import atomic_and as atomic_and
|
||||
from jax._src.pallas.mosaic_gpu.primitives import atomic_max as atomic_max
|
||||
from jax._src.pallas.mosaic_gpu.primitives import atomic_min as atomic_min
|
||||
from jax._src.pallas.mosaic_gpu.primitives import atomic_or as atomic_or
|
||||
from jax._src.pallas.mosaic_gpu.primitives import atomic_xor as atomic_xor
|
||||
from jax._src.pallas.mosaic_gpu.primitives import barrier_arrive as barrier_arrive
|
||||
from jax._src.pallas.mosaic_gpu.primitives import barrier_wait as barrier_wait
|
||||
from jax._src.pallas.mosaic_gpu.primitives import broadcasted_iota as broadcasted_iota
|
||||
from jax._src.pallas.mosaic_gpu.primitives import commit_smem as commit_smem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import commit_smem_to_gmem_group as commit_smem_to_gmem_group
|
||||
from jax._src.pallas.mosaic_gpu.primitives import commit_tmem as commit_tmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_gmem_to_smem as copy_gmem_to_smem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import copy_smem_to_gmem as copy_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import inline_mgpu as inline_mgpu
|
||||
from jax._src.pallas.mosaic_gpu.primitives import load as load
|
||||
from jax._src.pallas.mosaic_gpu.primitives import multimem_load_reduce as multimem_load_reduce
|
||||
from jax._src.pallas.mosaic_gpu.primitives import multimem_store as multimem_store
|
||||
from jax._src.pallas.mosaic_gpu.primitives import print_layout as print_layout
|
||||
from jax._src.pallas.mosaic_gpu.primitives import query_cluster_cancel as query_cluster_cancel
|
||||
from jax._src.pallas.mosaic_gpu.primitives import RefType as RefType
|
||||
from jax._src.pallas.mosaic_gpu.primitives import semaphore_signal_multicast as semaphore_signal_multicast
|
||||
from jax._src.pallas.mosaic_gpu.primitives import semaphore_signal_parallel as semaphore_signal_parallel
|
||||
from jax._src.pallas.mosaic_gpu.primitives import SemaphoreSignal as SemaphoreSignal
|
||||
from jax._src.pallas.mosaic_gpu.primitives import set_max_registers as set_max_registers
|
||||
from jax._src.pallas.mosaic_gpu.primitives import ShapeDtypeStruct as ShapeDtypeStruct
|
||||
from jax._src.pallas.mosaic_gpu.primitives import tcgen05_commit_arrive as tcgen05_commit_arrive
|
||||
from jax._src.pallas.mosaic_gpu.primitives import tcgen05_mma as tcgen05_mma
|
||||
from jax._src.pallas.mosaic_gpu.primitives import try_cluster_cancel as try_cluster_cancel
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wait_load_tmem as wait_load_tmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem as wait_smem_to_gmem
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma as wgmma
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_load as wgmma_accumulator_load
|
||||
from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait as wgmma_wait
|
||||
from jax._src.pallas.mosaic_gpu.torch import as_torch_kernel as as_torch_kernel
|
||||
from jax._src.state.types import Transform as Transform
|
||||
from jax.experimental.mosaic.gpu.core import LoweringSemantics as LoweringSemantics
|
||||
from jax.experimental.mosaic.gpu.fragmented_array import Replicated as Replicated
|
||||
from jax.experimental.mosaic.gpu.fragmented_array import Tiling as Tiling
|
||||
from jax.experimental.mosaic.gpu.launch_context import CopyPartition as CopyPartition
|
||||
from jax.experimental.mosaic.gpu.launch_context import OOBFillMode as OOBFillMode
|
||||
|
||||
|
||||
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.GMEM`.
|
||||
GMEM = MemorySpace.GMEM
|
||||
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.SMEM`.
|
||||
SMEM = MemorySpace.SMEM
|
||||
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.TMEM`.
|
||||
TMEM = MemorySpace.TMEM
|
||||
#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.MemorySpace.REGS`.
|
||||
REGS = MemorySpace.REGS
|
||||
Reference in New Issue
Block a user