1416 lines
48 KiB
Python
1416 lines
48 KiB
Python
# Copyright 2023 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Pallas-specific JAX primitives."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Callable, Hashable, Sequence
|
|
import enum
|
|
import functools
|
|
import itertools
|
|
import math
|
|
import string
|
|
from typing import Any
|
|
|
|
from jax._src import ad_util
|
|
from jax._src import api_util
|
|
from jax._src import config
|
|
from jax._src import core as jax_core
|
|
from jax._src import debugging
|
|
from jax._src import dtypes
|
|
from jax._src import effects
|
|
from jax._src import linear_util as lu
|
|
from jax._src import numpy as jnp
|
|
from jax._src import pretty_printer as pp
|
|
from jax._src import source_info_util
|
|
from jax._src import state
|
|
from jax._src import tree_util
|
|
from jax._src import typing as jax_typing
|
|
from jax._src import util
|
|
from jax._src.interpreters import ad
|
|
from jax._src.interpreters import partial_eval as pe
|
|
import jax._src.lax as lax
|
|
from jax._src.pallas import core as pallas_core
|
|
from jax._src.pallas import utils as pallas_utils
|
|
from jax._src.state import discharge as state_discharge
|
|
from jax._src.state import indexing
|
|
from jax._src.state import primitives as sp
|
|
from jax._src.state import types as state_types
|
|
from jax.interpreters import mlir
|
|
|
|
|
|
# Short aliases for the indexing types used throughout this module.
Slice = indexing.Slice
NDIndexer = indexing.NDIndexer

# Shadow the builtins with the length-checking variants; keep the original
# builtins available as unsafe_{map,zip}.
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
|
|
|
|
program_id_p = jax_core.Primitive("program_id")

def program_id(axis: int) -> jax_typing.Array:
  """Returns the kernel execution position along the given axis of the grid.

  For example, with a 2D ``grid`` in the kernel execution corresponding to the
  grid coordinates ``(1, 2)``,
  ``program_id(axis=0)`` returns ``1`` and ``program_id(axis=1)`` returns ``2``.

  The returned value is an array of shape ``()`` and dtype ``int32``.

  Args:
    axis: the axis of the grid along which to count the program.
  """
  # Binding goes through the custom bind rule registered below, which may
  # short-circuit to a concrete index when a grid environment is active.
  return program_id_p.bind(axis=axis)
|
|
|
|
def program_id_bind_with_trace(trace, _, avals, params):
  """Custom bind rule for ``program_id_p``.

  When a Pallas grid environment is active, the program index along ``axis``
  is already known, so it is returned directly without binding the primitive.
  Otherwise the axis is validated against the current axis frame and the
  primitive is bound as usual.
  """
  axis = params.pop("axis")
  grid_env = pallas_core.current_grid_env()
  if grid_env:
    # Inside a grid context the index along `axis` is concrete.
    return grid_env[axis].index
  frame = pallas_core.axis_frame()
  # Query the size of the axis to make sure it's a valid axis (and error
  # otherwise).
  _ = frame.size(axis)
  return jax_core.Primitive.bind_with_trace(program_id_p, trace, (), avals,
                                            dict(axis=axis))
# TODO(dougalm): figure out how put the grid_env context on the relevant trace
program_id_p.def_bind_with_trace(program_id_bind_with_trace)
|
|
|
|
@program_id_p.def_abstract_eval
def _program_id_abstract_eval(**_):
  # A program id is always a scalar int32, regardless of the axis.
  return jax_core.ShapedArray((), jnp.int32)
|
|
|
|
num_programs_p = jax_core.Primitive("num_programs")

def num_programs(axis: int) -> int | jax_typing.Array:
  """Returns the size of the grid along the given axis."""
  # The custom bind rule registered below may return a static Python int
  # when the grid size is known at trace time.
  return num_programs_p.bind(axis=axis)
|
|
|
|
def _num_programs_bind_with_trace(trace, _, avals, params):
  """Custom bind rule for ``num_programs_p``.

  Resolves the grid size statically from the current grid/axis environment
  when possible; only binds the primitive for dynamic grid dimensions.
  """
  axis = params.pop("axis")
  # We might be using a local grid env
  grid_env = pallas_core.current_grid_env()
  if grid_env:
    return grid_env[axis].size
  # Otherwise, we look up the size of the grid in the axis env
  frame = pallas_core.axis_frame()
  size = frame.size(axis)
  if size is pallas_core.dynamic_grid_dim:
    # The size is only known at run time, so bind the primitive.
    return jax_core.Primitive.bind_with_trace(num_programs_p, trace, (), avals,
                                              dict(axis=axis))
  return size
num_programs_p.def_bind_with_trace(_num_programs_bind_with_trace)
|
|
|
|
@num_programs_p.def_abstract_eval
def _num_programs_abstract_eval(**_):
  # A dynamic grid size is a scalar int32.
  return jax_core.ShapedArray((), jnp.int32)
|
|
|
|
multiple_of_p = jax_core.Primitive("multiple_of")
|
|
|
|
multiple_of_p.def_impl(lambda x, **_: x)
|
|
mlir.register_lowering(multiple_of_p, lambda _, x, **__: [x])
|
|
|
|
def multiple_of(x: jax_typing.Array, values: Sequence[int] | int) -> jax_typing.Array:
|
|
"""A compiler hint that asserts a value is a static multiple of another.
|
|
|
|
Note that misusing this function, such as asserting ``x`` is a multiple of
|
|
``N`` when it is not, can result in undefined behavior.
|
|
|
|
Args:
|
|
x: The input array.
|
|
values: A set of static divisors that ``x`` is a multiple of.
|
|
|
|
Returns:
|
|
A copy of ``x``.
|
|
"""
|
|
values = (values,) if isinstance(values, int) else tuple(values)
|
|
return multiple_of_p.bind(x, values=values)
|
|
|
|
@multiple_of_p.def_abstract_eval
def _multiple_of_abstract_eval(aval, **_):
  # multiple_of is a pure hint: the output aval is the input aval.
  return aval
|
|
|
|
load_p = jax_core.Primitive('masked_load')


@load_p.def_effectful_abstract_eval
def _load_abstract_eval(*avals_flat, args_tree, **_):
  """Abstract eval for ``load_p``.

  The output takes the shape/dtype of the transformed ref; the mask (if any)
  must be broadcastable to that shape. The load reads operand 0 (the ref).
  """
  ref_aval, transforms, mask_aval, _ = args_tree.unflatten(avals_flat)
  assert transforms is not None
  transformed_ref = pallas_core.TransformedRef(ref_aval, transforms)
  if mask_aval is not None:
    try:
      # pyrefly: ignore[no-matching-overload]
      jnp.broadcast_shapes(transformed_ref.shape, mask_aval.shape)
    except ValueError:
      raise ValueError(
          f"Cannot broadcast mask shape {mask_aval.shape} to load shape"
          f" {transformed_ref.shape}"
      )
  return (
      jax_core.ShapedArray(transformed_ref.shape, transformed_ref.dtype),
      {state.ReadEffect(0)},
  )
|
|
|
|
|
|
def _load_pp_rule(eqn, context, settings):
  # Pretty prints `a = load x i` as `a <- x[i]`
  y, = eqn.outvars
  x, transforms, mask, other = tree_util.tree_unflatten(
      eqn.params["args_tree"], eqn.invars
  )
  # TODO(sharadmv): pretty print mask and other
  annotation = (source_info_util.summarize(eqn.source_info)
                if settings.source_info else None)
  lhs = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
  result = [lhs, pp.text(" <- ", annotation=annotation),
            sp.pp_ref_transforms(context, x, transforms)]
  if mask is not None:
    # Masked loads additionally display ` mask=<var>`.
    result += [
        pp.text(" "),
        pp.text("mask="),
        pp.text(jax_core.pp_var(mask, context)),
    ]
  if other is not None:
    # And ` other=<var>` for the fill value of masked-out elements.
    result += [
        pp.text(" "),
        pp.text("other="),
        pp.text(jax_core.pp_var(other, context)),
    ]
  return pp.concat(result)
jax_core.pp_eqn_rules[load_p] = _load_pp_rule
|
|
|
|
|
|
def _load_jvp(primals, tangents, args_tree, **params):
  """JVP rule for load.

  Loads the primal and the tangent ref with the same indexer and mask. The
  ``other`` tangent is instantiated (a symbolic zero becomes a concrete
  array — see ``ad_util.instantiate``) so masked-out elements still get a
  well-defined tangent value.
  """
  ref_primal, transforms, mask, other_primal = args_tree.unflatten(primals)
  ref_tangent, _, _, other_tangent = args_tree.unflatten(tangents)
  if other_tangent is not None:
    other_tangent = ad_util.instantiate(other_tangent)
  return (
      load_p.bind(
          *tree_util.tree_leaves((ref_primal, transforms, mask, other_primal)),
          args_tree=args_tree,
          **params,
      ),
      load_p.bind(
          *tree_util.tree_leaves(
              (ref_tangent, transforms, mask, other_tangent)
          ),
          args_tree=args_tree,
          **params,
      ),
  )


ad.primitive_jvps[load_p] = _load_jvp
|
|
|
|
def uninitialized_value(shape, dtype):
  """Returns an array of `shape`/`dtype` filled with a sentinel value.

  Floats get NaN, integers the smallest representable value, bools False,
  and semaphore dtypes 0.
  """
  if jnp.issubdtype(dtype, jnp.floating):
    return jnp.full(shape, jnp.nan, dtype)
  # Note: Currently semaphore is i16[], meaning this case needs to be
  # handled before the general case for integers.
  # TODO(justinfu): Handle semaphores with a custom extended dtype.
  if jnp.issubdtype(dtype, pallas_core.SEMAPHORE_INTERPRET_DTYPE):
    return jnp.full(shape, 0, dtype)
  if jnp.issubdtype(dtype, jnp.integer):
    return jnp.full(shape, jnp.iinfo(dtype).min, dtype)
  if jnp.issubdtype(dtype, jnp.bool):
    return jnp.full(shape, False, dtype)
  if jnp.issubdtype(dtype, pallas_core.semaphore_dtype):
    return jnp.full(shape, 0, dtype)
  raise NotImplementedError(dtype)
|
|
|
|
def _pad_values_to_avoid_dynamic_slice_oob_shift(value,
                                                 slice_sizes, unpad=False):
  """
  DynamicSlice and DynamicUpdateSlice adjust the start index in cases where the
  requested slice overruns the bounds of the array. This pads the array with
  uninitialised values such that the requested slice will never overrun.

  For example, if arr is [1.,2.,3.,4.] and a slice of size 4, start index 2 is
  requested then the result will be [3.,4.,NaN,NaN] after padding, rather than
  [1.,2.,3.,4.] from the unpadded array

  unpad=True performs the inverse operation
  """

  # Pad each dimension at its high end by the full slice size; negating the
  # config removes exactly that padding again.
  padding_config = tuple((0, slice_size, 0) for slice_size in slice_sizes)
  if unpad:
    padding_config = tuple((-low, -high, -interior)
                           for (low, high, interior) in padding_config)
  padding_value = uninitialized_value(shape=(), dtype=value.dtype)
  value = lax.pad(value,
                  padding_config=padding_config,
                  padding_value=padding_value)
  return value


# Convenience alias: the exact inverse of the padding above.
_unpad_values_to_avoid_dynamic_slice_oob_shift = functools.partial(
    _pad_values_to_avoid_dynamic_slice_oob_shift, unpad=True
)
|
|
|
|
|
|
@state_discharge.register_discharge_rule(load_p)
def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
  """Discharge rule for load: rewrites the stateful load as pure slicing or
  gathering on the discharged ref value. The ref is not modified, so every
  input update is None."""
  del out_avals  # Unused.
  ref, transforms, mask, other = args_tree.unflatten(args_flat)
  transforms = list(transforms)
  if not transforms or not isinstance(transforms[-1], indexing.NDIndexer):
    # Ensure the transform stack ends in an NDIndexer so the code below can
    # always split off a final indexer.
    ref_aval = state.transform_type(transforms, in_avals[0])
    assert isinstance(ref_aval, state.AbstractRef)
    transforms.append(indexing.NDIndexer.make_trivial_indexer(ref_aval.shape))
  *prev_transforms, idx = transforms
  assert isinstance(idx, NDIndexer)
  ref = state_discharge.transform_array(ref, prev_transforms)
  if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):  # pyrefly: ignore[missing-attribute]
    # All indices are slices or scalars: lower to lax.dynamic_slice.
    # TODO(ayx): support strided load/store in interpret mode.
    for s in idx.indices:
      if isinstance(s, Slice) and s.stride > 1:
        raise NotImplementedError("Unimplemented stride support.")
    indices = idx.indices
    scalar_dims = [not isinstance(s, Slice) and not s.shape for s in indices]  # pyrefly: ignore[missing-attribute]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    # fixes an inconsistency with lax.dynamic_slice where if the slice goes out
    # of bounds, it will instead move the start_index backwards so the slice
    # will fit in memory.
    ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
    idx_dtype = dtypes.default_int_dtype()
    out_ones = lax.dynamic_slice(
        ref,
        [jnp.astype(s, idx_dtype) for s in slice_starts],
        slice_sizes=slice_sizes,
    )
    # Drop the size-1 dimensions that came from scalar indices.
    out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
    out = out_ones[out_indexer]
  elif all(not isinstance(s, Slice) for s in idx.indices):
    # All indices are arrays: this is a gather, handled by NumPy-style
    # advanced indexing.
    out = ref[idx.indices]
  else:
    raise NotImplementedError
  if mask is not None and other is not None:
    # Masked-out elements take the `other` fill value.
    out = jnp.where(mask, out, other)
  return (None,) * len(in_avals), out
|
|
|
|
|
|
swap_p = jax_core.Primitive('masked_swap')


@swap_p.def_effectful_abstract_eval
def _swap_abstract_eval(*avals_flat, args_tree, **_):
  """Abstract eval for ``swap_p``.

  The output (the old value) has the transformed ref's shape/dtype, which
  must match the incoming value exactly. The swap writes operand 0 (the ref).
  """
  ref, transforms, val, mask = args_tree.unflatten(avals_flat)
  assert transforms is not None
  transformed_ref = pallas_core.TransformedRef(ref, transforms)
  expected_output_shape = transformed_ref.shape
  expected_output_dtype = transformed_ref.dtype
  if expected_output_shape != val.shape:
    raise ValueError(
        f"Invalid shape for `swap`. Ref shape: {ref.shape}. "
        f"Value shape: {val.shape}. Transforms: {transforms}. "
    )
  if expected_output_dtype != val.dtype:
    raise ValueError(
        f"Invalid dtype for `swap`. Ref dtype: {expected_output_dtype}. "
        f"Value dtype: {val.dtype}. "
    )
  return (
      jax_core.ShapedArray(expected_output_shape, expected_output_dtype),
      {state.WriteEffect(0)},
  )
|
|
|
|
|
|
def _swap_pp_rule(eqn, context, settings):
  """Pretty-printer rule for swap equations."""
  # Pretty prints `a = swap x v i` as `a, x[i] <- x[i], v`
  # or:
  # Pretty prints `_ = swap x v i` as `x[i] <- v`
  y, = eqn.outvars
  x, transforms, val, mask = eqn.params["args_tree"].unflatten(eqn.invars)
  x_i = sp.pp_ref_transforms(context, x, transforms)
  annotation = (source_info_util.summarize(eqn.source_info)
                if settings.source_info else None)
  if isinstance(y, jax_core.DropVar):
    # The old value is unused, so print the swap as a plain store.
    return pp.concat([
        x_i,
        pp.text(" <- ", annotation=annotation),
        pp.text(jax_core.pp_var(val, context))])
  y = jax_core.pp_vars([y], context, print_shapes=settings.print_shapes)
  result = [
      y,
      pp.text(", "),
      x_i,
      pp.text(" <- ", annotation=annotation),
      x_i,
      pp.text(", "),
      pp.text(jax_core.pp_var(val, context)),
  ]
  if mask is not None:
    # Masked swaps additionally display ` mask=<var>`.
    result += [
        pp.text(" "),
        pp.text("mask="),
        pp.text(jax_core.pp_var(mask, context)),
    ]
  return pp.concat(result)
jax_core.pp_eqn_rules[swap_p] = _swap_pp_rule
|
|
|
|
|
|
def _swap_jvp(primals, tangents, *, args_tree, **params):
  """JVP rule for swap: swaps the primal value into the primal ref and, in
  parallel, the tangent value into the tangent ref. The value tangent is
  instantiated (symbolic zero -> concrete zeros; see ``ad_util.instantiate``)
  so a zero tangent still writes zeros."""
  ref_primal, transforms, val_primal, mask = args_tree.unflatten(primals)
  ref_tangent, _, val_tangent, _ = args_tree.unflatten(tangents)
  val_tangent = ad_util.instantiate(val_tangent)
  return (
      swap_p.bind(
          *tree_util.tree_leaves((ref_primal, transforms, val_primal, mask)),
          args_tree=args_tree,
          **params,
      ),
      swap_p.bind(
          *tree_util.tree_leaves((ref_tangent, transforms, val_tangent, mask)),
          args_tree=args_tree,
          **params,
      ),
  )


ad.primitive_jvps[swap_p] = _swap_jvp
|
|
|
|
|
|
@state_discharge.register_discharge_rule(swap_p)
def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
  """Discharge rule for swap: reads the old value and computes the updated
  ref contents, returning the new ref value as the update for operand 0."""
  del out_avals  # Unused.
  ref, transforms, val, mask = args_tree.unflatten(args_flat)
  transforms = list(transforms)
  if not transforms or not isinstance(transforms[-1], indexing.NDIndexer):
    # Ensure the transform stack ends in an NDIndexer (as in the load rule).
    ref_aval = state.transform_type(transforms, in_avals[0])
    assert isinstance(ref_aval, state.AbstractRef)
    transforms.append(indexing.NDIndexer.make_trivial_indexer(ref_aval.shape))
  *prev_transforms, idx = transforms
  assert isinstance(idx, NDIndexer)
  ref = state_discharge.transform_array(ref, prev_transforms)
  if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):  # pyrefly: ignore[missing-attribute]
    # All indices are slices or scalars: use dynamic_slice/dynamic_update_slice.
    # TODO(ayx): support strided load/store in interpret mode.
    for s in idx.indices:
      if isinstance(s, Slice) and s.stride > 1:
        raise NotImplementedError("Unimplemented stride support.")
    indices = idx.indices
    scalar_dims = [
        i
        for i, s in enumerate(indices)
        if not isinstance(s, Slice) and not s.shape  # pyrefly: ignore[missing-attribute]
    ]
    slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
    slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
    # Fixes an inconsistency with lax.dynamic_update_slice where if the slice
    # goes out of bounds, it will instead move the start_index backwards so the
    # slice will fit in memory.
    ref = _pad_values_to_avoid_dynamic_slice_oob_shift(ref, slice_sizes)
    out = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
    out = jnp.squeeze(out, scalar_dims)
    if mask is not None:
      # Masked swap: where mask is False the ref keeps its old contents and
      # the returned value is the incoming `val` (i.e. the exchange only
      # happens at masked positions).
      out_ = out
      out = jnp.where(mask, out, val)
      val = jnp.where(mask, val, out_)
    val = jnp.expand_dims(val, scalar_dims)
    x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
    x_new = _unpad_values_to_avoid_dynamic_slice_oob_shift(x_new, slice_sizes)
  elif all(not isinstance(s, Slice) for s in idx.indices):
    # All indices are arrays: gather the old values, scatter the new ones.
    out = ref[idx.indices]
    if mask is not None:
      out_ = out
      out = jnp.where(mask, out, val)
      val = jnp.where(mask, val, out_)
    x_new = ref.at[idx.indices].set(val)
  else:
    raise NotImplementedError
  return (x_new,) + (None,) * (len(in_avals) - 1), out
|
|
|
|
|
|
def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None,
         eviction_policy=None, volatile=False) -> jax_typing.Array:
  """Returns an array loaded from the given index.

  If neither ``mask`` nor ``other`` is specified, this function has the same
  semantics as ``x_ref_or_view[idx]`` in JAX.

  Args:
    x_ref_or_view: The ref to load from.
    idx: The indexer to use.
    mask: An optional boolean mask specifying which indices to load.
      If mask is ``False`` and ``other`` is not given, no assumptions can
      be made about the value in the resulting array.
    other: An optional value to use for indices where mask is ``False``.
    cache_modifier: TO BE DOCUMENTED.
    eviction_policy: TO BE DOCUMENTED.
    volatile: TO BE DOCUMENTED.
  """
  ref, transforms = sp.get_ref_and_transforms(x_ref_or_view, idx, "load")
  # Flatten the (ref, transforms, mask, other) bundle; the treedef travels
  # along as a static param so the rules can unflatten it.
  flat_args, tree = tree_util.tree_flatten((ref, transforms, mask, other))
  return load_p.bind(
      *flat_args,
      args_tree=tree,
      cache_modifier=cache_modifier,
      eviction_policy=eviction_policy,
      is_volatile=volatile,
  )
|
|
|
|
def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None,
         _function_name="swap") -> jax_typing.Array:
  """Swaps the value at the given index and returns the old value.

  See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.

  Returns:
    The value stored in the ref prior to the swap.
  """
  # `_function_name` only affects error messages produced while resolving
  # the ref and transforms (e.g. "store" when called via `store`).
  ref, transforms = sp.get_ref_and_transforms(
      x_ref_or_view, idx, _function_name)
  flat_args, tree = tree_util.tree_flatten((ref, transforms, val, mask))
  return swap_p.bind(*flat_args, args_tree=tree,
                     eviction_policy=eviction_policy)
|
|
|
|
def store(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None) -> None:
  """Stores a value at the given index.

  See :func:`~jax.experimental.pallas.load` for the meaning of the arguments.
  """
  # A store is a swap whose old value is discarded.
  swap(x_ref_or_view, idx, val, mask=mask, eviction_policy=eviction_policy,
       _function_name="store")
|
|
|
|
|
|
def _handle_small(dtype: jax_typing.DTypeLike):
|
|
"""Ugly workaround to support types that don't allow automatic promotion."""
|
|
if dtype == jnp.int4:
|
|
return jnp.int8
|
|
if dtype == jnp.float8_e4m3b11fnuz:
|
|
return jnp.bfloat16
|
|
return dtype
|
|
|
|
|
|
def dot(a, b, trans_a: bool = False, trans_b: bool = False,
        allow_tf32: bool | None = None, precision=None):
  """Computes the dot product of two arrays.

  The inputs can optionally be transposed before computing the
  product. Depending on the hardware, this can be cheaper than
  computing the transpose beforehand.

  Args:
    a: The left-hand side of the dot product. Must be a 2D array of shape
      ``(M, N)`` (or ``(N, M)`` if ``trans_a``).
    b: The right-hand side of the dot product. Must be a 2D array of shape
      ``(N, K)`` (or ``(K, N)`` if ``trans_b``).
    trans_a: Whether to transpose ``a`` before the product.
    trans_b: Whether to transpose ``b`` before the product.
    allow_tf32: Whether to use tf32 precision.
      Mutually exclusive with ``precision``.
    precision: Specifies the precision of the dot product.

  See Also:
    :func:`jax.numpy.dot`
  """
  if (a.ndim != 2) or (b.ndim != 2):
    raise ValueError("`a` and `b` must be 2D arrays.")
  # The contracting dimension of each operand flips when it is transposed.
  lhs_contract_dim = 0 if trans_a else 1
  rhs_contract_dim = 0 if not trans_b else 1
  if allow_tf32 is not None:
    if precision is not None:
      raise ValueError("Only one of allow_tf32 and precision can be specified")
    precision = lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST

  # Accumulate in a wider type: int32 for integer inputs, f64 only when the
  # promoted input type is f64, and f32 otherwise.
  dtype = jnp.promote_types(_handle_small(a.dtype), _handle_small(b.dtype))
  if jnp.issubdtype(dtype, jnp.integer):
    out_dtype = jnp.int32
  elif dtype == jnp.float64:
    out_dtype = jnp.float64
  else:
    out_dtype = jnp.float32
  return lax.dot_general(
      a,
      b,
      dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
      precision=precision,
      preferred_element_type=out_dtype,
  )
|
|
|
|
reciprocal_p = jax_core.Primitive("reciprocal")


def reciprocal(x, *, approx=False, full_range=True):
  """Computes the reciprocal of an array.

  Args:
    x: The array to compute the reciprocal of.
    approx: Whether to use an approximate reciprocal.
    full_range: Whether to use the full range of the input. If False, compilers
      may produce non-IEEE compliant results for edge cases, but may be faster.
      On TPU, setting it to `False` may produce incorrect results when `x` or
      output is ±inf or NaN; or when `x` is ±1/flt_min or ±0.

  Returns:
    The reciprocal of the array.
  """
  # Both flags are forwarded as primitive params; each backend's lowering
  # decides how (or whether) to honor them.
  return reciprocal_p.bind(x, approx=approx, full_range=full_range)
|
|
|
|
|
|
@reciprocal_p.def_abstract_eval
def _reciprocal_abstract_eval(x, *, approx, full_range):
  # The reciprocal preserves shape and dtype; the flags don't affect the aval.
  del approx, full_range
  return x
|
|
|
|
|
|
def _reciprocal_lowering_rule(
    ctx: mlir.LoweringRuleContext,
    x,
    *,
    approx=False,
    full_range=True,
):
  """Fallback lowering (outside Pallas kernels) via ``jnp.reciprocal``.

  ``full_range`` has no effect here; ``approx`` mimics an approximate
  reciprocal by computing at bfloat16 precision and widening back to f32.
  """
  del full_range

  def _reciprocal(x, *, approx=False):
    if approx:
      return jnp.reciprocal(x.astype(jnp.bfloat16)).astype(jnp.float32)
    return jnp.reciprocal(x)

  return mlir.lower_fun(_reciprocal, multiple_results=False)(
      ctx, x, approx=approx
  )


mlir.register_lowering(reciprocal_p, _reciprocal_lowering_rule)
|
|
|
|
|
|
def debug_print(fmt: str, *args: jax_typing.ArrayLike):
  """Prints values from inside a Pallas kernel.

  Args:
    fmt: A format string to be included in the output. The restrictions on the
      format string depend on the backend:

      * On GPU, when using Triton, ``fmt`` must not contain any placeholders
        (``{...}``), since it is always printed before any of the values.
      * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must
        contain a placeholder for each value to be printed. Format specs and
        conversions are not supported. If a single value is provided, the value
        may be an array. Otherwise, all values must be scalars.
      * On TPU, if all inputs are scalars: If ``fmt`` contains placeholders,
        all values must be 32-bit integers. If there are no placeholders, the
        values are printed after the format string.
      * On TPU, if the input is a single vector, the vector is printed after
        the format string. The format string must end with a single placeholder
        ``{}``.
    *args: The values to print.
  """
  # Format validation is skipped here and deferred to the backend lowering,
  # since (per the docstring) each backend imposes different restrictions.
  return debugging.debug_print(fmt, *args, skip_format_check=True)
|
|
|
|
|
|
def check_debug_print_format(
    fmt: str, *args: jax_typing.ArrayLike
):
  """Validates a ``debug_print`` format string against its arguments.

  Every placeholder must be a bare ``{}`` — no field names/positions, format
  specs, or conversions — and the placeholder count must match ``len(args)``.

  Raises:
    ValueError: if a placeholder has a spec/conversion or names an argument.
    TypeError: if the number of placeholders differs from ``len(args)``.
  """
  placeholder_count = 0
  for _, field, spec, conversion in string.Formatter().parse(fmt):
    if field is None:
      # Pure literal segment with no replacement field.
      continue
    placeholder_count += 1
    if spec or conversion:
      raise ValueError(
          "The format string should not contain any format specs or conversions"
      )
    if field:
      raise ValueError(
          "The format string should not reference arguments by position or name"
      )

  if len(args) != placeholder_count:
    raise TypeError(
        f"The format string expects {placeholder_count} "
        f"argument{'' if placeholder_count == 1 else 's'}, but got {len(args)}"
    )
|
|
|
|
|
|
# All of those shenanigans are because we can't make TransformedRef a PyTree,
# because they should appear as atomic JAX values to the users.
# TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU
# inferred by the compiler.
@lu.transformation2
def wrap_with_transforms(f, transforms, *args):
  """Re-wraps each arg that has transforms into a TransformedRef before
  calling ``f`` (args and transforms are zipped positionally)."""
  new_args = tuple(
      state_types.TransformedRef(a, t) if t else a
      for a, t in zip(args, transforms)
  )
  return f(*new_args)
|
|
|
|
|
|
run_scoped_p = jax_core.Primitive("run_scoped")
run_scoped_p.multiple_results = True

def _run_scoped_is_high(*avals, jaxpr, **params):
  # run_scoped is "high" iff its body jaxpr is.
  del avals, params
  return jaxpr.is_high
run_scoped_p.is_high = _run_scoped_is_high

def _run_scoped_to_lojax(*args, jaxpr, **params):
  # Lower the body jaxpr (closing over `args` as consts) and rebind the
  # primitive with the lowered jaxpr and its consts.
  closed_hi_jaxpr = jax_core.ClosedJaxpr(jaxpr, args)
  closed_lo_jaxpr = pe.lower_jaxpr2(closed_hi_jaxpr)
  consts = closed_lo_jaxpr.consts
  return run_scoped_p.bind(*consts, jaxpr=closed_lo_jaxpr.jaxpr, **params)
run_scoped_p.to_lojax = _run_scoped_to_lojax
|
|
|
|
def run_scoped(
    f: Callable[..., Any],
    *types: Any,
    collective_axes: Hashable | tuple[Hashable, ...] = (),
    **kw_types: Any,
) -> Any:
  """Calls the function with allocated references and returns the result.

  The positional and keyword arguments describe which reference types
  to allocate for each argument. Each backend has its own set of reference
  types in addition to :class:`jax.experimental.pallas.MemoryRef`.

  When ``collective_axes`` is specified, the same allocation will be returned for
  all programs that only differ in their program ids along the collective axes.
  It is an error not to call the same ``run_scoped`` in all programs along that
  axis.
  """
  if not isinstance(collective_axes, tuple):
    # Allow a single axis to be passed without wrapping it in a tuple.
    collective_axes = (collective_axes,)
  flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
  flat_fun, out_tree_thunk = api_util.flatten_fun(
      lu.wrap_init(f,
                   debug_info=api_util.debug_info("pallas run_scoped",
                                                  f, types, kw_types)),
      in_tree)
  # We allow ref avals to be transformed references.
  ref_avals = [t.get_ref_aval() for t in flat_types]
  avals = [
      t.ref if isinstance(t, state_types.TransformedRef) else t
      for t in ref_avals
  ]
  ref_transforms = tuple(
      t.transforms if isinstance(t, state_types.TransformedRef) else ()
      for t in ref_avals
  )
  # Re-attach the transforms inside the traced function (TransformedRef is
  # not a pytree, so it cannot be threaded through the flattened args).
  flat_fun = wrap_with_transforms(flat_fun, ref_transforms)
  # Turn the function into a jaxpr. The body of run_scoped may have
  # effects (IO) on constvars (i.e. variables inherited from the
  # parent scope). Jax can't reason about effects to references that
  # are not in the invars of an operation so we just put them all
  # there.
  with config.mutable_array_checks(False):
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, avals)
  out = run_scoped_p.bind(*consts, jaxpr=jaxpr, collective_axes=collective_axes)
  return tree_util.tree_unflatten(out_tree_thunk(), out)
|
|
|
|
|
|
@run_scoped_p.def_effectful_abstract_eval
def _run_scoped_abstract_eval(*args, jaxpr, collective_axes):
  """Abstract eval for run_scoped: outputs are the body jaxpr's output avals;
  effects on locally-allocated refs are filtered out."""
  del args, collective_axes
  # jaxpr will have effects for its inputs (Refs that are allocated) and for
  # constvars (closed over Refs). The effects for the allocated Refs are local
  # to the jaxpr and shouldn't propagate out.
  nonlocal_effects = {
      eff
      for eff in jaxpr.effects
      if not (
          isinstance(eff, effects.JaxprInputEffect)
          and eff.input_index >= len(jaxpr.constvars)
      )
  }
  return [v.aval for v in jaxpr.outvars], nonlocal_effects
|
|
|
|
|
|
def _run_scoped_discharge_rule(
    should_discharge,
    in_avals,
    out_avals,
    *args_flat,
    jaxpr,
    collective_axes):
  """Partial discharge rule for run_scoped.

  Discharges the requested closed-over refs inside the body; the refs
  allocated by run_scoped itself stay as refs.
  """
  del out_avals
  if collective_axes:
    raise NotImplementedError(
        "run_scoped discharge does not support collective_axes yet."
    )
  num_consts = len(args_flat)
  # discharge_state only discharges invars, not consts, so in order to
  # discharge the requested refs we need to move them to the invar set.
  jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
  num_return_values = len(jaxpr_noconst.outvars)
  discharged_body, new_consts = state_discharge.discharge_state(
      jaxpr_noconst,
      [],
      should_discharge=should_discharge + [False] * len(jaxpr.invars),
  )
  if new_consts:
    raise NotImplementedError(
        "Cannot handle new consts created by state discharge.")

  # Lowering expects that the jaxpr.consts to be the eqn.invals.
  discharged_body = pe.convert_invars_to_constvars(discharged_body, num_consts)

  # Run_scoped discharged the external variables but the scoped ones
  # are not discharged.
  out = run_scoped_p.bind(
      *args_flat, jaxpr=discharged_body, collective_axes=collective_axes
  )
  # Order of outputs:
  # (1) return values, (2) closed refs, (3) scoped refs.
  return_values = out[:num_return_values]
  ref_outputs = out[num_return_values:]
  # We update all ref values with their updated values from the discharged
  # body. For other values we leave them in place.
  updates = [
      ref_outputs.pop(0) if should and isinstance(aval, state.AbstractRef)
      else None for should, aval in zip(should_discharge, in_avals)]
  assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}'
  return updates, return_values


state_discharge.register_partial_discharge_rule(run_scoped_p)(
    _run_scoped_discharge_rule)
|
|
|
|
|
|
@functools.partial(mlir.register_lowering, run_scoped_p)
def _run_scoped_lowering_rule(ctx, *args, jaxpr, collective_axes):
  """Fallback lowering for run_scoped outside Pallas: fully discharges the
  body's state and evaluates it with uninitialized scratch values."""
  if collective_axes:
    raise ValueError(
        "run_scoped lowering outside of Pallas does not support"
        " collective_axes."
    )
  jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr)
  num_return_values = len(jaxpr_noconst.outvars)
  discharged_body, new_consts = state_discharge.discharge_state(
      jaxpr_noconst, [], should_discharge=True)
  if new_consts: raise NotImplementedError(
      "Cannot handle new consts created by state discharge.")

  def _lower_fun(*lower_fun_args):
    # Create inputs filled with uninitialized values to the body.
    num_consts = len(lower_fun_args)
    body_avals = [v.aval for v in discharged_body.invars[num_consts:]]
    init_vals = [
        # pyrefly: ignore[missing-attribute]
        uninitialized_value(aval.shape, aval.dtype) for aval in body_avals
    ]
    out = jax_core.eval_jaxpr(discharged_body, [], *lower_fun_args, *init_vals)
    # Drop the trailing (final ref value) outputs; keep the return values.
    return out[:num_return_values]

  return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)
|
|
|
|
|
|
get_global_p = jax_core.Primitive("get_global")
get_global_p.multiple_results = False
# Mark this primitive as producing (allocating) a reference.
get_global_p.ref_primitive = True
get_global_p.ref_allocating = True

def get_global(what: pallas_core.ScratchShape) -> jax_typing.Array:
  """Returns a global reference that persists across all kernel invocations.

  Each call to ``get_global`` returns a different and unique reference, but one that
  is stable across invocations of the kernel body.

  Args:
    what: The reference type to allocate. Each backend has its own set of
      reference types (e.g., :class:`jax.experimental.pallas.mosaic_gpu.SemaphoreType` for GPU).

  Example::

    sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)
    pl.semaphore_signal(sem_ref)
    pl.semaphore_wait(sem_ref)
  """
  # The ref aval is passed as a static param; the abstract eval below simply
  # returns it.
  ref_aval = what.get_ref_aval()
  return get_global_p.bind(what=ref_aval)
|
|
|
|
|
|
@get_global_p.def_abstract_eval
def _get_global_abstract_eval(*, what):
  # `what` is already the ref aval to allocate; it is also the output aval.
  return what
|
|
|
|
|
|
def _get_global_discharge_rule(in_avals, out_avals, *, what):
  """get_global cannot be discharged; interpret mode is unsupported."""
  del in_avals, out_avals, what
  raise NotImplementedError(
      "get_global discharge is not supported in interpret mode."
  )


state_discharge.register_discharge_rule(get_global_p)(
    _get_global_discharge_rule
)
|
|
|
|
|
|
def _get_ref_and_transforms(ref):
|
|
if isinstance(ref, state.TransformedRef):
|
|
return ref.ref, ref.transforms
|
|
return ref, ()
|
|
|
|
|
|
class DeviceIdType(enum.Enum):
  # NOTE(review): presumably selects how device ids are interpreted (mesh
  # coordinates vs. flat logical ids) — confirm at the use sites, which are
  # outside this chunk.
  MESH = "mesh"
  LOGICAL = "logical"
|
|
|
|
|
|
def check_sem_avals(
    sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None
):
  """Validates that an operand is a scalar semaphore Ref of an allowed dtype.

  Args:
    sem_aval: The aval of the (untransformed) semaphore Ref.
    sem_transforms_avals: Avals of indexing transforms applied to the Ref; the
      last one determines the effective shape.
    name: Operation name used in error messages (e.g. "signal", "wait").
    allowed_semaphore_types: Dtypes accepted for the semaphore. Defaults to
      the regular/barrier semaphore types plus the interpret-mode dtype.

  Raises:
    ValueError: If the operand is not a Ref, is not ()-shaped, or has a dtype
      outside ``allowed_semaphore_types``.
  """
  if allowed_semaphore_types is None:
    allowed_semaphore_types = {
        pallas_core.semaphore,
        pallas_core.barrier_semaphore,
        # For interpret mode.
        pallas_core.SEMAPHORE_INTERPRET_DTYPE,
    }
  if not isinstance(sem_aval, state.AbstractRef):
    raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}")
  # The last transform (if any) determines the shape actually addressed.
  if sem_transforms_avals:
    effective_shape = sem_transforms_avals[-1].get_indexer_shape()
  else:
    effective_shape = sem_aval.shape
  if effective_shape:
    raise ValueError(
        f"Cannot {name} on a non-()-shaped semaphore: {effective_shape}"
    )
  dtype = sem_aval.dtype
  dtype_ok = any(
      jnp.issubdtype(dtype, allowed) for allowed in allowed_semaphore_types
  )
  if not dtype_ok:
    raise ValueError(
        f"Must {name} semaphores of the following types:"
        f" {allowed_semaphore_types}. Got {dtype}."
    )
|
|
|
|
|
|
def _transform_semaphore(ref_value, transforms, ref_aval):
|
|
"""Helper function for indexing into a semaphore during state_discharge."""
|
|
if ref_value.shape == ref_aval.shape:
|
|
return state_discharge.transform_array(ref_value, transforms)
|
|
elif len(ref_value.shape) == 0:
|
|
return ref_value
|
|
else:
|
|
raise ValueError(
|
|
f"Semaphore value shape {ref_value.shape} does not match aval shape"
|
|
f" {ref_aval.shape}"
|
|
)
|
|
|
|
|
|
# Primitive backing ``semaphore_read`` (defined below); yields one scalar.
semaphore_read_p = jax_core.Primitive("semaphore_read")
semaphore_read_p.multiple_results = False
|
|
|
|
|
|
def semaphore_read(sem_or_view) -> jax_typing.Array:
  """Reads the value of a semaphore.

  Args:
    sem_or_view: A Ref (or view) representing a semaphore.

  Returns:
    A scalar Array containing the value of the semaphore.
  """
  ref, transforms = _get_ref_and_transforms(sem_or_view)
  flat_args, args_tree = tree_util.tree_flatten([ref, transforms])
  return semaphore_read_p.bind(*flat_args, args_tree=args_tree)
|
|
|
|
@semaphore_read_p.def_abstract_eval
def _semaphore_read_abstract_eval(*avals, args_tree):
  # A semaphore read always produces a scalar int32, regardless of inputs.
  del avals, args_tree
  return jax_core.ShapedArray((), jnp.dtype("int32"))
|
|
|
|
def _semaphore_read_discharge_rule(
    in_avals, out_avals, *flat_args, args_tree
):
  """Interpret-mode rule: returns the semaphore's stored value as int32."""
  del out_avals
  ref, transforms = args_tree.unflatten(flat_args)
  value = _transform_semaphore(ref, transforms, in_avals[0])
  # Reading is side-effect free on the inputs, hence all-None new values.
  return (None,) * len(in_avals), value.astype(jnp.int32)


state_discharge.register_discharge_rule(semaphore_read_p)(
    _semaphore_read_discharge_rule
)
|
|
|
|
|
|
# Accepted forms of a device id: a single logical index, per-mesh-axis
# coordinates (tuple), a mapping from mesh axis name(s) to an index along
# that axis, or ``None`` meaning "local device". See ``device_id_to_logical``
# below for how these are normalized.
DeviceId = (
    int
    | jax_typing.Array
    | None
    | tuple[int | jax_typing.Array, ...]
    | dict[Any, int | jax_typing.Array]
)
|
|
|
|
class SemaphoreEffect(effects.Effect):
  """Effect attached to semaphore signal/wait operations."""
  pass

# Singleton effect instance used by the semaphore abstract-eval rules below.
sem_effect = SemaphoreEffect()

# Semaphore effects are permitted under control flow and custom derivatives,
# and are considered local to a kernel.
effects.control_flow_allowed_effects.add_type(SemaphoreEffect)
effects.custom_derivatives_allowed_effects.add_type(SemaphoreEffect)
pallas_core.kernel_local_effects.add_type(SemaphoreEffect)
|
|
|
|
|
|
# Primitive backing ``semaphore_signal`` (defined below); effect-only, so it
# produces no output values.
semaphore_signal_p = jax_core.Primitive('semaphore_signal')
semaphore_signal_p.multiple_results = True
|
|
|
|
|
|
def semaphore_signal(
    sem_or_view,
    inc: int | jax_typing.Array = 1,
    *,
    device_id: DeviceId = None,
    device_id_type: DeviceIdType = DeviceIdType.MESH,
    core_index: int | jax_typing.Array | None = None,
):
  """Increments the value of a semaphore.

  When ``device_id`` is given the signal is performed remotely, with
  ``sem_or_view`` referring to a Ref located on another device. The remote
  Ref is assumed to already be allocated (e.g. through the proper use of
  barriers); otherwise this operation may result in undefined behavior.

  Args:
    sem_or_view: A Ref (or view) representing a semaphore.
    inc: The value to increment by.
    device_id (optional): Specifies which device to signal.
      If not specified, ``sem_or_view`` is assumed to be local.
    device_id_type (optional): The format in which
      ``device_id`` should be specified.
    core_index (optional): If on a multi-core device,
      specifies which core to signal.
  """
  ref, transforms = _get_ref_and_transforms(sem_or_view)
  increment = jnp.asarray(inc, dtype=jnp.int32)
  flat_args, args_tree = tree_util.tree_flatten(
      [ref, transforms, increment, device_id, core_index]
  )
  semaphore_signal_p.bind(
      *flat_args,
      args_tree=args_tree,
      device_id_type=device_id_type,
  )
|
|
|
|
|
|
@semaphore_signal_p.def_effectful_abstract_eval
def _semaphore_signal_abstract_eval(
    *avals,
    args_tree,
    device_id_type: DeviceIdType,
):
  """Validates signal operands and derives the operation's effects."""
  sem_aval, sem_transforms_avals, value_aval, device_id_aval, _ = (
      tree_util.tree_unflatten(args_tree, avals)
  )
  check_sem_avals(sem_aval, sem_transforms_avals, "signal")
  if value_aval.dtype != jnp.dtype("int32"):
    raise ValueError(f"Must signal an int32 value, but got {value_aval.dtype}")
  effs: set[effects.Effect] = {sem_effect}
  if device_id_aval is not None:
    for aval in tree_util.tree_leaves(device_id_aval):
      if aval.dtype != jnp.dtype("int32"):
        raise ValueError(
            f"`device_id`s must be an int32 value, but got {aval.dtype}"
        )
    if device_id_type is DeviceIdType.MESH and isinstance(device_id_aval, dict):
      # Addressing by mesh axis name(s) makes the signal depend on those
      # named axes.
      for key in device_id_aval:
        axis_names = key if isinstance(key, tuple) else (key,)
        for axis_name in axis_names:
          effs.add(jax_core.NamedAxisEffect(axis_name))
    else:
      # Any other remote-addressing mode is cross-device communication.
      effs.add(pallas_core.comms_effect)
  return [], effs
|
|
|
|
def _semaphore_signal_pp_eqn(
    eqn: jax_core.JaxprEqn,
    context: jax_core.JaxprPpContext,
    settings: jax_core.JaxprPpSettings,
):
  """Pretty-prints a ``semaphore_signal`` equation."""
  del settings
  sem, sem_transforms, value, device_ids, _ = tree_util.tree_unflatten(
      eqn.params["args_tree"], eqn.invars
  )
  out = pp.concat([
      pp.text("semaphore_signal"),
      pp.text(" "),
      sp.pp_ref_transforms(context, sem, sem_transforms),
      pp.text(" "),
      pp.text(jax_core.pp_var(value, context)),
  ])
  if device_ids is None:
    return out
  flat_device_ids = tree_util.tree_leaves(device_ids)
  if not flat_device_ids:
    return out
  # Append the device ids, space-separated.
  id_docs = []
  for i, device_id in enumerate(flat_device_ids):
    if i:
      id_docs.append(pp.text(" "))
    id_docs.append(pp.text(jax_core.pp_var(device_id, context)))
  return pp.concat([out, pp.concat(id_docs)])


jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn
|
|
|
|
|
|
def _semaphore_signal_discharge_rule(
    in_avals, out_avals, *flat_args, args_tree, device_id_type
):
  """Interpret-mode rule: adds ``inc`` to the semaphore's stored value."""
  del out_avals, device_id_type
  ref, transforms, inc, device_id, core_index = args_tree.unflatten(flat_args)
  if device_id is not None:
    raise NotImplementedError("Remote signal not implemented.")
  if core_index is not None:
    raise NotImplementedError("Multiple core support not implemented.")
  current = _transform_semaphore(ref, transforms, in_avals[0])
  delta = inc.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE)
  _, updated = state_discharge.transform_swap_array(
      ref, transforms, current + delta
  )
  # Only the semaphore ref (first input) is updated.
  return (updated,) + (None,) * (len(in_avals) - 1), ()


state_discharge.register_discharge_rule(semaphore_signal_p)(
    _semaphore_signal_discharge_rule
)
|
|
|
|
|
|
# Primitive backing ``semaphore_wait`` (defined below); effect-only, so it
# produces no output values.
semaphore_wait_p = jax_core.Primitive('semaphore_wait')
semaphore_wait_p.multiple_results = True
|
|
|
|
|
|
def semaphore_wait(
    sem_or_view, value: int | jax_typing.Array = 1, *, decrement: bool = True
):
  """Blocks execution of the current thread until a semaphore reaches a value.

  Args:
    sem_or_view: A Ref (or view) representing a semaphore.
    value: The target value that the semaphore should reach before unblocking.
    decrement: Whether to decrement the value of the semaphore after
      a successful wait.
  """
  ref, transforms = _get_ref_and_transforms(sem_or_view)
  target = jnp.asarray(value, dtype=jnp.int32)
  flat_args, args_tree = tree_util.tree_flatten(
      [ref, transforms, target, decrement]
  )
  semaphore_wait_p.bind(*flat_args, args_tree=args_tree)
|
|
|
|
@semaphore_wait_p.def_effectful_abstract_eval
def _semaphore_wait_abstract_eval(*avals, args_tree):
  """Validates wait operands and reports the semaphore effect.

  Raises:
    ValueError: If the operand is not a scalar semaphore Ref or the wait
      value is not an int32.
  """
  sem_aval, sem_transforms_avals, value_aval, _ = tree_util.tree_unflatten(
      args_tree, avals
  )
  check_sem_avals(sem_aval, sem_transforms_avals, "wait")
  if value_aval.dtype != jnp.dtype("int32"):
    # Include the offending dtype, matching _semaphore_signal_abstract_eval.
    raise ValueError(f"Must wait an int32 value, but got {value_aval.dtype}")
  return [], {sem_effect}
|
|
|
|
def _semaphore_wait_pp_eqn(
    eqn: jax_core.JaxprEqn,
    context: jax_core.JaxprPpContext,
    settings: jax_core.JaxprPpSettings,
):
  """Pretty-prints a ``semaphore_wait`` equation."""
  del settings
  sem, sem_transforms, value, decrement = tree_util.tree_unflatten(
      eqn.params["args_tree"], eqn.invars
  )
  docs = [pp.text("semaphore_wait")]
  if decrement:
    # Mark waits that also decrement the semaphore.
    docs.append(pp.text("[dec]"))
  docs.extend([
      pp.text(" "),
      sp.pp_ref_transforms(context, sem, sem_transforms),
      pp.text(" "),
      pp.text(jax_core.pp_var(value, context)),
  ])
  return pp.concat(docs)


jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn
|
|
|
|
def _semaphore_wait_discharge_rule(
    in_avals, out_avals, *flat_args, args_tree
):
  """Interpret-mode rule: optionally subtracts ``value`` from the semaphore."""
  del out_avals
  ref, transforms, value, decrement = args_tree.unflatten(flat_args)
  current = _transform_semaphore(ref, transforms, in_avals[0])
  amount = value.astype(pallas_core.SEMAPHORE_INTERPRET_DTYPE)
  if decrement:
    _, updated = state_discharge.transform_swap_array(
        ref, transforms, current - amount
    )
  else:
    # A non-decrementing wait leaves the stored value unchanged.
    updated = current
  return (updated,) + (None,) * (len(in_avals) - 1), ()


state_discharge.register_discharge_rule(semaphore_wait_p)(
    _semaphore_wait_discharge_rule
)
|
|
|
|
|
|
def _device_id_dict_to_mesh(mesh_context: pallas_utils.MeshInfo | None, device_id_dict, get_axis_index):
|
|
if mesh_context is None:
|
|
mesh_axis_sizes = {}
|
|
else:
|
|
mesh_axis_sizes = dict(
|
|
zip(mesh_context.axis_names, mesh_context.mesh_shape)
|
|
)
|
|
physical_axis_dict = {}
|
|
# Handle joint axes (i.e., one logical axis over >1 physical axes)
|
|
for axis_name, idx in device_id_dict.items():
|
|
if isinstance(axis_name, tuple) and any(
|
|
a in mesh_axis_sizes for a in axis_name
|
|
):
|
|
if not all(a in mesh_axis_sizes for a in axis_name):
|
|
raise NotImplementedError(
|
|
f"{axis_name} mixes JAX mesh and Pallas mesh grid axes"
|
|
)
|
|
axes_dimensions = [mesh_axis_sizes[name] for name in axis_name]
|
|
for axis_index, axis_name in enumerate(axis_name):
|
|
axis_size = mesh_axis_sizes[axis_name]
|
|
inner_mesh_size = math.prod(axes_dimensions[axis_index + 1 :])
|
|
|
|
# Fast path for power of 2s
|
|
if inner_mesh_size & (inner_mesh_size - 1) == 0:
|
|
shift_len = (inner_mesh_size & -inner_mesh_size).bit_length() - 1
|
|
partial_device_idx = idx >> shift_len
|
|
else:
|
|
partial_device_idx = idx // inner_mesh_size
|
|
|
|
if axis_size & (axis_size - 1) == 0:
|
|
device_idx = partial_device_idx & jnp.asarray(
|
|
axis_size - 1, dtype=partial_device_idx.dtype
|
|
)
|
|
else:
|
|
device_idx = lax.rem(partial_device_idx, axis_size)
|
|
physical_axis_dict[axis_name] = device_idx
|
|
else:
|
|
physical_axis_dict[axis_name] = idx
|
|
device_id = []
|
|
for axis_name in mesh_axis_sizes:
|
|
if axis_name in physical_axis_dict:
|
|
device_id.append(physical_axis_dict[axis_name])
|
|
else:
|
|
device_id.append(get_axis_index(axis_name))
|
|
non_mesh_axes = {
|
|
k: v
|
|
for k, v in physical_axis_dict.items()
|
|
if k not in mesh_axis_sizes
|
|
}
|
|
return tuple(device_id), non_mesh_axes
|
|
|
|
|
|
def device_id_to_logical(
    mesh_context: pallas_utils.MeshInfo | None,
    device_id: Any,
    device_id_type: DeviceIdType,
    get_axis_index: Callable[[Any], Any],
) -> tuple[Any | None, dict[Any, Any]]:
  """Normalizes a device id into a logical device id and axes that don't correspond to JAX mesh axes.

  The indexing implied by the returned axis dict should be handled by the
  caller. If there are no cross-device operations, then the returned logical
  device id will be None.
  """
  non_mesh_axes = {}
  if isinstance(device_id, dict):
    # Dict device ids are only meaningful as mesh coordinates.
    if device_id_type is not DeviceIdType.MESH:
      raise ValueError(
          "`device_id_type` must be MESH if `device_id` is a dict,"
          f" got: {device_id_type = }."
      )
    device_id, non_mesh_axes = _device_id_dict_to_mesh(
        mesh_context, device_id, get_axis_index
    )
  if device_id_type is DeviceIdType.LOGICAL:
    # Already a flat device index; pass it through.
    return device_id, non_mesh_axes
  if device_id_type is DeviceIdType.MESH:
    # MESH: ``device_id`` holds per-axis mesh coordinates; flatten them with
    # the mesh strides.
    device_ids = tree_util.tree_leaves(device_id)
    mesh_strides: tuple[int, ...]
    mesh_strides = () if mesh_context is None else mesh_context.mesh_strides
    if len(device_ids) != len(mesh_strides):
      raise ValueError(
          "Number of device ids must match the number of mesh axes, but got"
          f" {len(device_ids)} ids for a {len(mesh_strides)}D mesh."
      )
    if not device_ids:
      # If there are no device ids, then it is purely local communication.
      return None, non_mesh_axes
    logical = sum(coord * stride for coord, stride in zip(device_ids, mesh_strides))
    return logical, non_mesh_axes
  raise NotImplementedError(f"Unsupported device id type: {device_id_type}")
|
|
|
|
|
|
# Primitive backing ``delay`` (defined below); effect-only, so it produces
# no output values.
delay_p = jax_core.Primitive("delay")
delay_p.multiple_results = True
|
|
|
|
|
|
class DelayEffect(effects.Effect):
  """Effect attached to the ``delay`` primitive."""
  pass

# Singleton effect instance used by the delay abstract-eval rule below.
delay_effect = DelayEffect()

# Delays are permitted under control flow and are local to a kernel.
effects.control_flow_allowed_effects.add_type(DelayEffect)
pallas_core.kernel_local_effects.add_type(DelayEffect)
|
|
|
|
|
|
@delay_p.def_effectful_abstract_eval
def _delay_abstract_eval(nanos):
  # Only the effect matters; there are no outputs.
  del nanos
  return [], {delay_effect}
|
|
|
|
|
|
def delay(nanos: int | jax_typing.Array) -> None:
  """Sleeps for the given number of nanoseconds."""
  delay_p.bind(nanos)
|
|
|
|
|
|
# Internal primitive used by ``_jaxpr_call`` (defined below) to invoke a
# kernel jaxpr from inside ``emit_pipeline``.
jaxpr_call_p = jax_core.Primitive("jaxpr_call")
jaxpr_call_p.multiple_results = True
|
|
|
|
|
|
@jaxpr_call_p.def_effectful_abstract_eval
def _jaxpr_call_abstract_eval(*args, jaxpr: jax_core.Jaxpr, **params):
  del args, params  # Unused.
  # Input effects are only meaningful inside this ``jaxpr_call``; keep just
  # the effects that are visible from the outside.
  outer_effects = {
      eff
      for eff in jaxpr.effects
      if not isinstance(eff, effects.JaxprInputEffect)
  }
  return jaxpr.out_avals, outer_effects
|
|
|
|
|
|
def _jaxpr_call_pp_eqn(
    eqn: jax_core.JaxprEqn,
    context: jax_core.JaxprPpContext,
    settings: jax_core.JaxprPpSettings,
):
  """Pretty-prints a ``jaxpr_call`` equation with its (transformed) refs."""
  ref_treedefs = eqn.params["ref_treedefs"]
  # The leading invars are the flattened refs; the rest are program ids.
  num_ref_leaves = sum(treedef.num_leaves for treedef in ref_treedefs)
  flat_refs, _ = util.split_list(eqn.invars, [num_ref_leaves])
  per_ref_leaves = util.split_list(
      flat_refs,
      [treedef.num_leaves for treedef in ref_treedefs[: len(ref_treedefs) - 1]],
  )
  trailer = []
  for treedef, leaves in zip(ref_treedefs, per_ref_leaves):
    ref = treedef.unflatten(leaves)
    if isinstance(ref, tuple):
      # A (ref, transforms) pair comes from a TransformedRef.
      ref, transforms = ref
    else:
      transforms = []
    trailer.append(pp.text(" "))
    trailer.append(sp.pp_ref_transforms(context, ref, transforms))
  return pp.concat([
      pp.text("jaxpr_call"),
      pp.text("["),
      jax_core.pp_kv_pair("jaxpr", eqn.params["jaxpr"], context, settings),
      pp.text("]"),
      pp.concat(trailer),
  ])


jax_core.pp_eqn_rules[jaxpr_call_p] = _jaxpr_call_pp_eqn
|
|
|
|
|
|
@state_discharge.register_partial_discharge_rule(jaxpr_call_p)
def _jaxpr_call_discharge(
    flat_should_discharge,
    in_avals,
    out_avals,
    *flat_args,
    jaxpr,
    ref_treedefs,
    program_ids_treedef,
):
  """Partially discharges the refs of a ``jaxpr_call``.

  Discharges the inner jaxpr for the requested refs and rebinds the
  primitive; the discharged jaxpr returns the new values of the discharged
  refs after its original outputs.
  """
  del in_avals, out_avals  # Unused.
  # Regroup the per-leaf discharge flags by ref; a ref is discharged if any
  # of its leaves is marked.
  flat_should_discharge = util.split_list(
      flat_should_discharge,
      [treedef.num_leaves for treedef in ref_treedefs[: len(ref_treedefs) - 1]],
  )
  should_discharge = [*map(any, flat_should_discharge)]
  discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
      jaxpr, (), should_discharge=should_discharge
  )
  assert not discharged_consts
  outs = jaxpr_call_p.bind(
      *flat_args,
      jaxpr=discharged_jaxpr,
      ref_treedefs=tuple(ref_treedefs),
      program_ids_treedef=program_ids_treedef,
  )
  # Outputs beyond the original outvars are the new values of the discharged
  # refs, in order; consume them lazily while rebuilding the per-input list.
  discharged_outs_it = iter(outs[len(jaxpr.outvars) :])
  new_in_vals = (
      tuple(
          itertools.chain.from_iterable(
              # Each ref contributes one entry per leaf: its new value if
              # discharged, otherwise None.
              [next(discharged_outs_it) if discharged else None]
              * ref_treedefs[idx].num_leaves
              for idx, discharged in enumerate(should_discharge)
          )
      )
      # Program-id inputs are never discharged.
      + (None,) * program_ids_treedef.num_leaves
  )
  return new_in_vals, outs[: len(jaxpr.outvars)]
|
|
|
|
|
|
def _jaxpr_call(
    jaxpr: jax_core.Jaxpr,
    *refs: state_types.AbstractRef | state_types.TransformedRef,
    program_ids: Sequence[jax_typing.Array | None],
) -> Sequence[jax_typing.Array]:
  """Internal primitive for calling a kernel jaxpr inside ``emit_pipeline``.

  This is *not* a general purpose primitive. In particular, it assumes that
  the transformed references have been indexed.

  Args:
    jaxpr: The jaxpr to call.
    *refs: The references to pass into the jaxpr.
    program_ids: The loop-bound program IDs to pass into the jaxpr, or None if
      the program ID corresponds to a parallel dimension.

  Returns:
    The outputs of the jaxpr.
  """
  assert not jaxpr.outvars
  flat_refs = []
  ref_treedefs = []
  entry: Any
  for entry in refs:
    if isinstance(entry, state_types.TransformedRef):
      # Only fully-indexed transformed refs are accepted here.
      if not isinstance(entry.transforms[-1], indexing.NDIndexer):
        raise ValueError(
            "TransformedRef must have been indexed before passing into"
            f" jaxpr_call. Got {entry}."
        )
      entry = (entry.ref, entry.transforms)
    leaves, treedef = tree_util.tree_flatten(entry)
    flat_refs.extend(leaves)
    ref_treedefs.append(treedef)
  flat_program_ids, program_ids_treedef = tree_util.tree_flatten(program_ids)
  return jaxpr_call_p.bind(
      *flat_refs,
      *flat_program_ids,
      jaxpr=jaxpr,
      ref_treedefs=tuple(ref_treedefs),
      program_ids_treedef=program_ids_treedef,
  )
|