This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,29 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax.experimental.roofline.roofline import (
RooflineRuleContext as RooflineRuleContext,
)
from jax.experimental.roofline.roofline import RooflineShape as RooflineShape
from jax.experimental.roofline.roofline import RooflineResult as RooflineResult
from jax.experimental.roofline.roofline import roofline as roofline
from jax.experimental.roofline.roofline import register_roofline as register_roofline
from jax.experimental.roofline.roofline import (
register_standard_roofline as register_standard_roofline,
)
from jax.experimental.roofline.roofline import roofline_and_grad as roofline_and_grad
import jax.experimental.roofline.rooflines as rooflines
del rooflines
@@ -0,0 +1,386 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Protocol
from collections.abc import Callable, Sequence
import numpy as np
from absl import logging
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax._src import api
from jax._src import core
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api import make_jaxpr
from jax._src.interpreters.partial_eval import dce_jaxpr
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
from jax._src.util import foreach
from jax._src.shard_map import shard_map, shard_map_p
ShapeDtypeStructTree = Any
Specs = Any
ValidRooflineDtype = np.dtype | prng.KeyTy
map = util.safe_map
@dataclass(frozen=True, slots=True, kw_only=True)
class RooflineRuleContext:
name_stack: source_info_util.NameStack
primitive: core.Primitive
avals_in: Sequence[core.AbstractValue]
avals_out: Sequence[core.AbstractValue]
jaxpr_eqn_ctx: core.JaxprEqnContext
mesh: Mesh | AbstractMesh | None
pin_lhs_in_vmem: bool
pin_rhs_in_vmem: bool
@dataclass(frozen=True, slots=True, kw_only=True)
class RooflineShape:
shape: tuple[int, ...]
dtype: ValidRooflineDtype
@classmethod
def from_aval(cls, aval: core.AbstractValue) -> RooflineShape:
if not isinstance(aval, core.ShapedArray):
raise TypeError(f"Expected ShapedArray, got {type(aval)}.")
if not isinstance(aval.dtype, ValidRooflineDtype):
raise TypeError(
f"Expected numpy or prng.KeyTy dtype, got {type(aval.dtype)}."
)
return cls(shape=aval.shape, dtype=aval.dtype)
@property
def size(self) -> int:
return int(np.prod(self.shape))
@property
def bytes(self) -> int:
return int(self.size * self.dtype.itemsize)
@classmethod
def total_bytes(cls, avals: Sequence[core.AbstractValue]) -> int:
return sum(cls.from_aval(aval).bytes for aval in avals)
@dataclass(frozen=True, slots=True, kw_only=True)
class RooflineResult:
flops: int = 0
unfused_flops: int = 0
ici_bytes: dict[str, int] = field(default_factory=dict)
ici_latency: dict[str, int] = field(default_factory=dict)
hbm_bytes: int = 0
peak_hbm_bytes: int = 0
unfused_hbm_bytes: int = 0
@classmethod
def zeros(cls) -> RooflineResult:
return cls()
def __add__(self, other: RooflineResult) -> RooflineResult:
def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]:
return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)}
return RooflineResult(
flops=self.flops + other.flops,
unfused_flops=self.unfused_flops + other.unfused_flops,
ici_bytes=merge_ici_dicts(self.ici_bytes, other.ici_bytes),
ici_latency=merge_ici_dicts(self.ici_latency, other.ici_latency),
hbm_bytes=self.hbm_bytes + other.hbm_bytes,
peak_hbm_bytes=max(self.peak_hbm_bytes, other.peak_hbm_bytes),
unfused_hbm_bytes=self.unfused_hbm_bytes + other.unfused_hbm_bytes,
)
def __mul__(self, constant: int | float) -> RooflineResult:
return RooflineResult(
flops=int(self.flops * constant),
unfused_flops=int(self.unfused_flops * constant),
ici_bytes={k: int(v * constant) for k, v in self.ici_bytes.items()},
ici_latency={k: int(v * constant) for k, v in self.ici_latency.items()},
hbm_bytes=int(self.hbm_bytes * constant),
peak_hbm_bytes=int(self.peak_hbm_bytes * constant),
unfused_hbm_bytes=int(self.unfused_hbm_bytes * constant),
)
def __rmul__(self, constant: int | float) -> RooflineResult:
return self.__mul__(constant)
class _RooflineRule(Protocol):
def __call__(
self, ctx: RooflineRuleContext, *args: RooflineShape, **kw
) -> RooflineResult: ...
_rooflines: dict[core.Primitive, _RooflineRule] = {}
def _roofline_interpreter(
f_name: str,
jaxpr: core.Jaxpr,
mesh: Mesh | AbstractMesh | None,
*,
pin_lhs_in_vmem: bool = False,
pin_rhs_in_vmem: bool = False,
) -> RooflineResult:
name_stack = source_info_util.new_name_stack(util.wrap_name("roofline", f_name))
result = RooflineResult.zeros()
env: dict[core.Var, RooflineShape] = {}
def write(v: core.Var, node: RooflineShape):
assert node is not None
env[v] = node
def read(v: core.Atom) -> RooflineShape:
if type(v) is core.Literal:
return RooflineShape.from_aval(core.typeof(v.val))
else:
assert isinstance(v, core.Var)
return env[v]
def aval(v: core.Atom) -> core.AbstractValue:
if type(v) is core.Literal:
return core.typeof(v.val)
else:
return v.aval
def sum_bytes(shapes: Sequence[RooflineShape]) -> int:
return sum(shape.bytes for shape in shapes)
jaxpr = jaxpr.jaxpr if isinstance(jaxpr, core.ClosedJaxpr) else jaxpr
make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x))
foreach(
write,
jaxpr.constvars,
map(make_roofline_shape, jaxpr.constvars),
)
foreach(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars))
last_used = core.last_used(jaxpr)
current_hbm_bytes = sum_bytes(list(env.values()))
peak_hbm_bytes = current_hbm_bytes
for eqn in jaxpr.eqns:
source_info = eqn.source_info.replace(
name_stack=name_stack + eqn.source_info.name_stack
)
with source_info_util.user_context(
eqn.source_info.traceback, name_stack=source_info.name_stack
):
if "jaxpr" in eqn.params:
result += _roofline_interpreter(
util.wrap_name(eqn.primitive.name, f_name),
eqn.params["jaxpr"],
mesh,
pin_lhs_in_vmem=pin_lhs_in_vmem,
pin_rhs_in_vmem=pin_rhs_in_vmem,
)
elif "call_jaxpr" in eqn.params:
# Used for custom_jvp_call_p. Recursively calculates roofline result for
# all primitives in the custom function.
result += _roofline_interpreter(
util.wrap_name(eqn.primitive.name, f_name),
eqn.params['call_jaxpr'],
mesh,
pin_lhs_in_vmem=pin_lhs_in_vmem,
pin_rhs_in_vmem=pin_rhs_in_vmem,
)
elif eqn.primitive not in _rooflines:
msg = f"No roofline rule for {eqn.primitive}, skipping..."
for attr in dir(eqn):
if not attr.startswith("_"):
msg += f"\n{attr}: {getattr(eqn, attr)}"
logging.warning(msg)
else:
rule = _rooflines[eqn.primitive]
result += rule(
RooflineRuleContext(
name_stack=source_info.name_stack,
primitive=eqn.primitive,
avals_in=map(aval, eqn.invars),
avals_out=map(aval, eqn.outvars),
jaxpr_eqn_ctx=eqn.ctx,
mesh=mesh,
pin_lhs_in_vmem=pin_lhs_in_vmem,
pin_rhs_in_vmem=pin_rhs_in_vmem,
),
*map(read, eqn.invars),
**eqn.params,
)
# Add bytes for the newly-created output variables.
outvar_shapes = map(make_roofline_shape, eqn.outvars)
current_hbm_bytes += sum_bytes(outvar_shapes)
foreach(write, eqn.outvars, outvar_shapes)
# Remove bytes for the no-longer-needed input variables.
removed_shapes = [
env[v] for v in eqn.invars
if not isinstance(v, core.Literal) and last_used[v] is eqn
]
current_hbm_bytes -= sum_bytes(removed_shapes)
core.clean_up_dead_vars(eqn, env, last_used)
peak_hbm_bytes = max(peak_hbm_bytes, current_hbm_bytes)
result += RooflineResult(peak_hbm_bytes=peak_hbm_bytes)
return result
def _f_with_vjp(f: Callable):
@util.wraps(f)
def wrapped(*args):
primals, f_vjp = api.vjp(f, *args)
return f_vjp(tree_map(jnp.bfloat16, primals))
return wrapped
def roofline(
f: Callable,
mesh: Mesh | AbstractMesh | None = None,
in_specs: Specs | None = None,
out_specs: Specs | None = None,
*,
pin_lhs_in_vmem: bool = False,
pin_rhs_in_vmem: bool = False,
vjp: bool = False,
print_jaxpr: bool = False,
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult]]:
@util.wraps(f)
@traceback_util.api_boundary
def wrapped(*args):
wrapped_f = f
if in_specs is not None and out_specs is not None and mesh is not None:
wrapped_f = shard_map(wrapped_f, mesh=mesh, in_specs=in_specs,
out_specs=out_specs)
if vjp:
wrapped_f = _f_with_vjp(wrapped_f)
jaxpr, out_shapes = make_jaxpr(wrapped_f, return_shape=True)(*args)
def make_sharded_shape_dtype_struct(
shape: api.ShapeDtypeStruct, out_spec: Specs
) -> api.ShapeDtypeStruct:
assert mesh is not None
return api.ShapeDtypeStruct(
shape.shape, shape.dtype, sharding=NamedSharding(mesh, out_spec)
)
if out_specs is not None and mesh is not None:
out_specs_flat = broadcast_prefix(out_specs, out_shapes)
flat_out_shapes, treedef = tree_flatten(out_shapes)
flat_out_shapes = map(
make_sharded_shape_dtype_struct, flat_out_shapes, out_specs_flat
)
out_shapes = tree_unflatten(treedef, flat_out_shapes)
used_outputs = (True,) * len(jaxpr.jaxpr.outvars)
jaxpr, _ = dce_jaxpr(jaxpr.jaxpr, used_outputs)
shard_map_eqns = [
e for e in jaxpr.eqns if e.primitive == shard_map_p
]
if shard_map_eqns:
try:
jaxpr = shard_map_eqns[-1].params["jaxpr"]
except KeyError:
raise ValueError(f"Missing shard_map jaxpr in {jaxpr}.")
if print_jaxpr:
print(jaxpr)
return out_shapes, _roofline_interpreter(
util.fun_qual_name(f),
jaxpr,
mesh,
pin_lhs_in_vmem=pin_lhs_in_vmem,
pin_rhs_in_vmem=pin_rhs_in_vmem,
)
return wrapped
def register_roofline(prim: core.Primitive):
def register(rule: _RooflineRule):
_rooflines[prim] = rule
return rule
return register
def register_standard_roofline(prim: core.Primitive):
def standard_rule(ctx: RooflineRuleContext, *args, **kwargs):
return RooflineResult.zeros()
_rooflines[prim] = standard_rule
def roofline_and_grad(
f: Callable,
mesh: Mesh | AbstractMesh,
in_specs: Specs,
out_specs: Specs,
*,
pin_lhs_in_vmem: bool = False,
pin_rhs_in_vmem: bool = False,
print_jaxpr: bool = False,
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult, RooflineResult]]:
@util.wraps(f)
@traceback_util.api_boundary
def wrapped(*args):
primal_shapes, fwd_result = roofline(
f,
mesh,
in_specs,
out_specs,
pin_lhs_in_vmem=pin_lhs_in_vmem,
pin_rhs_in_vmem=pin_rhs_in_vmem,
print_jaxpr=print_jaxpr,
)(*args)
return (
primal_shapes,
fwd_result,
roofline(
f,
mesh,
in_specs,
out_specs,
pin_lhs_in_vmem=pin_lhs_in_vmem,
pin_rhs_in_vmem=pin_rhs_in_vmem,
vjp=True,
print_jaxpr=print_jaxpr,
)(
*tree_map(
lambda x: api.ShapeDtypeStruct(
x.shape,
jnp.int32 if x.dtype == jnp.int32 else jnp.bfloat16,
sharding=x.sharding,
),
args,
)
)[1],
)
return wrapped
@@ -0,0 +1,822 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from dataclasses import replace
import itertools as it
from collections.abc import Sequence
import numpy as np
from jax._src import api
from jax._src import ad_checkpoint
from jax._src import ad_util
from jax._src import core, util
from jax._src import dispatch
from jax._src import ops
from jax._src import pjit
from jax._src import prng
from jax._src import random
from jax._src import shard_map
from jax._src import callback
from jax._src import debugging
from jax._src.lax import (
ann,
control_flow,
convolution,
fft,
lax,
linalg,
parallel as lax_parallel,
slicing,
special,
windowed_reductions,
)
from jax.experimental import roofline
# One FMA (Fused Multiply Add) takes 2 flops to compute.
_FMA_FLOPS_FACTOR = 2
for prim in it.chain(
ad_checkpoint.__dict__.values(),
ad_util.__dict__.values(),
ann.__dict__.values(),
callback.__dict__.values(),
control_flow.__dict__.values(),
convolution.__dict__.values(),
dispatch.__dict__.values(),
fft.__dict__.values(),
lax.__dict__.values(),
linalg.__dict__.values(),
ops.__dict__.values(),
[pjit.sharding_constraint_p],
prng.__dict__.values(),
random.__dict__.values(),
shard_map.__dict__.values(),
slicing.__dict__.values(),
special.__dict__.values(),
windowed_reductions.__dict__.values(),
):
if isinstance(prim, core.Primitive):
roofline.register_standard_roofline(prim)
def _unary_p_roofline(
ctx: roofline.RooflineRuleContext,
*args,
**kw,
) -> roofline.RooflineResult:
(x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
return roofline.RooflineResult(
unfused_flops=x.size,
unfused_hbm_bytes=(
x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
),
)
roofline.register_roofline(lax.abs_p)(_unary_p_roofline)
roofline.register_roofline(lax.acos_p)(_unary_p_roofline)
roofline.register_roofline(lax.asin_p)(_unary_p_roofline)
roofline.register_roofline(lax.atan_p)(_unary_p_roofline)
roofline.register_roofline(lax.cbrt_p)(_unary_p_roofline)
roofline.register_roofline(lax.ceil_p)(_unary_p_roofline)
roofline.register_roofline(lax.conj_p)(_unary_p_roofline)
roofline.register_roofline(lax.cos_p)(_unary_p_roofline)
roofline.register_roofline(lax.cosh_p)(_unary_p_roofline)
roofline.register_roofline(lax.exp_p)(_unary_p_roofline)
roofline.register_roofline(lax.expm1_p)(_unary_p_roofline)
roofline.register_roofline(lax.floor_p)(_unary_p_roofline)
roofline.register_roofline(lax.imag_p)(_unary_p_roofline)
roofline.register_roofline(lax.integer_pow_p)(_unary_p_roofline)
roofline.register_roofline(lax.is_finite_p)(_unary_p_roofline)
roofline.register_roofline(lax.log_p)(_unary_p_roofline)
roofline.register_roofline(lax.log1p_p)(_unary_p_roofline)
roofline.register_roofline(lax.logistic_p)(_unary_p_roofline)
roofline.register_roofline(lax.neg_p)(_unary_p_roofline)
roofline.register_roofline(lax.not_p)(_unary_p_roofline)
roofline.register_roofline(lax.real_p)(_unary_p_roofline)
roofline.register_roofline(lax.round_p)(_unary_p_roofline)
roofline.register_roofline(lax.rsqrt_p)(_unary_p_roofline)
roofline.register_roofline(lax.sign_p)(_unary_p_roofline)
roofline.register_roofline(lax.sin_p)(_unary_p_roofline)
roofline.register_roofline(lax.sinh_p)(_unary_p_roofline)
roofline.register_roofline(lax.sqrt_p)(_unary_p_roofline)
roofline.register_roofline(lax.square_p)(_unary_p_roofline)
roofline.register_roofline(lax.tan_p)(_unary_p_roofline)
roofline.register_roofline(special.bessel_i0e_p)(_unary_p_roofline)
roofline.register_roofline(special.bessel_i1e_p)(_unary_p_roofline)
roofline.register_roofline(special.digamma_p)(_unary_p_roofline)
roofline.register_roofline(special.erf_inv_p)(_unary_p_roofline)
roofline.register_roofline(special.erf_p)(_unary_p_roofline)
roofline.register_roofline(special.erfc_p)(_unary_p_roofline)
roofline.register_roofline(special.lgamma_p)(_unary_p_roofline)
roofline.register_standard_roofline(core.pvary_p)
def _binary_p_roofline(
ctx: roofline.RooflineRuleContext,
*args,
**kw,
) -> roofline.RooflineResult:
lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
broadcasted_shape = [
max(l, r) for l, r in it.zip_longest(lhs.shape, rhs.shape, fillvalue=1)
]
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
return roofline.RooflineResult(
unfused_flops=int(np.prod(broadcasted_shape)),
unfused_hbm_bytes=(
lhs.dtype.itemsize * lhs.size
+ rhs.dtype.itemsize * rhs.size
+ out.dtype.itemsize * out.size
),
)
roofline.register_roofline(lax.add_p)(_binary_p_roofline)
roofline.register_roofline(lax.sub_p)(_binary_p_roofline)
roofline.register_roofline(lax.mul_p)(_binary_p_roofline)
roofline.register_roofline(lax.div_p)(_binary_p_roofline)
roofline.register_roofline(lax.rem_p)(_binary_p_roofline)
roofline.register_roofline(lax.and_p)(_binary_p_roofline)
roofline.register_roofline(lax.or_p)(_binary_p_roofline)
roofline.register_roofline(lax.xor_p)(_binary_p_roofline)
roofline.register_roofline(lax.gt_p)(_binary_p_roofline)
roofline.register_roofline(lax.lt_p)(_binary_p_roofline)
roofline.register_roofline(lax.ge_p)(_binary_p_roofline)
roofline.register_roofline(lax.le_p)(_binary_p_roofline)
roofline.register_roofline(lax.eq_p)(_binary_p_roofline)
roofline.register_roofline(lax.ne_p)(_binary_p_roofline)
roofline.register_roofline(lax.min_p)(_binary_p_roofline)
roofline.register_roofline(lax.max_p)(_binary_p_roofline)
def _cumulative_p_roofline(
ctx: roofline.RooflineRuleContext,
*args,
axis: int,
**kw,
) -> roofline.RooflineResult:
(x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
return roofline.RooflineResult(
# `cum{max, min, prod, sum}` only calculate values for one axis.
unfused_flops=x.shape[axis],
unfused_hbm_bytes=(
x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
),
)
roofline.register_roofline(control_flow.cummax_p)(_cumulative_p_roofline)
roofline.register_roofline(control_flow.cummin_p)(_cumulative_p_roofline)
roofline.register_roofline(control_flow.cumprod_p)(_cumulative_p_roofline)
roofline.register_roofline(control_flow.cumsum_p)(_cumulative_p_roofline)
@roofline.register_roofline(control_flow.cumlogsumexp_p)
def _cumlogsumexp_p_roofline(
ctx: roofline.RooflineRuleContext,
*args,
axis: int,
**kw,
) -> roofline.RooflineResult:
(x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
return roofline.RooflineResult(
# Similar to `cum{max, min, prod, sum}`, `cumlogsumexp` only calculates
# values for one axis. But for `x.shape[axis] = S`, it computes (for a
# naive implementation):
# S `exp` ops.
# S-1 `add` ops.
# 1 log op.
# Thus, the total number of flops is 2 * S.
unfused_flops=x.shape[axis] * 2,
unfused_hbm_bytes=(
x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
),
)
@roofline.register_roofline(lax.dot_general_p)
def _dot_general_roofline(
ctx: roofline.RooflineRuleContext,
*args,
dimension_numbers: lax.DotDimensionNumbers,
**kw,
) -> roofline.RooflineResult:
lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
(lhs_contract, _), (lhs_batch, _) = dimension_numbers
flops = (
_FMA_FLOPS_FACTOR
* lhs.size
* rhs.size
/ np.prod([lhs.shape[i] for i in lhs_contract])
/ np.prod([lhs.shape[i] for i in lhs_batch])
)
hbm_bytes = 0
if not ctx.pin_lhs_in_vmem:
hbm_bytes += lhs.bytes
hbm_bytes += out.bytes
if not ctx.pin_rhs_in_vmem:
hbm_bytes += rhs.bytes
return roofline.RooflineResult(
flops=int(flops),
unfused_flops=int(flops),
hbm_bytes=hbm_bytes,
unfused_hbm_bytes=hbm_bytes,
)
def _get_spatial_valid_position_count_for_one_dim(
window_dim_stride: int,
base_dilation: int,
window_dilation: int,
kernel_limit: int,
input_limit: int,
output_limit: int,
padding: tuple[int, int],
) -> int:
"""Gets the valid position count for conv for a single spatial dimension.
Args:
window_dim_stride: The stride of the window along this dimension.
base_dilation: The base dilation factor along this dimension.
window_dilation: The window dilation factor along this dimension.
kernel_limit: The size of the kernel along this dimension.
input_limit: The size of the input along this dimension.
output_limit: The size of the output along this dimension.
padding: The padding applied to the input along this dimension.
"""
padding_low = padding[0]
padding_high = padding[1]
# These two conditions will create an N^2 iteration pattern with only N
# valid elements. This is a performance optimization and produces the same
# result as the whole loop.
if (
input_limit == output_limit
and kernel_limit == output_limit
and input_limit == base_dilation
and window_dilation == 1
and max(1, input_limit - 1) == window_dim_stride
and padding_low == 0
and padding_high == 0
):
return input_limit
if (
input_limit == 1
and kernel_limit == output_limit
and window_dilation == 1
and base_dilation == 1
and window_dim_stride == 1
and padding_low == output_limit - 1
and padding_high == output_limit - 1
):
return output_limit
valid_position_count = 0
# Loop over each point in the kernel
for kernel_idx in range(kernel_limit):
# Skip loop for trivial stride and base_dilation
if window_dim_stride == 1 and base_dilation == 1:
undilated_index_base = padding_low - kernel_idx * window_dilation
upper_limit = min(
input_limit + undilated_index_base,
output_limit,
)
lower_limit = max(0, undilated_index_base)
valid_position_count += max(upper_limit - lower_limit, 0)
continue
# Loop over each point in the output
for output_idx in range(output_limit):
# Calculate lhs (input) index without taking base dilation into account
undilated_index = (
output_idx * window_dim_stride
- padding_low
+ kernel_idx * window_dilation
)
# Calculate the actual lhs (input) index after dilation
lhs_spatial_index = int(undilated_index / base_dilation)
# Skip if the lhs (input) index is to be dilated.
if undilated_index != lhs_spatial_index * base_dilation:
continue
# Skip if input index is not in bound.
if lhs_spatial_index < 0 or lhs_spatial_index >= input_limit:
continue
valid_position_count += 1
return valid_position_count
def _get_spatial_valid_position_count(
dnums: convolution.ConvDimensionNumbers,
lhs: roofline.RooflineShape,
rhs: roofline.RooflineShape,
out: roofline.RooflineShape,
window_strides: Sequence[int],
padding: Sequence[tuple[int, int]],
lhs_dilation: Sequence[int],
rhs_dilation: Sequence[int],
) -> int:
"""Gets the number of valid spatial positions for conv_general_dilated.
Args:
dnums: The dimension numbers for the convolution.
lhs: The shape of the left-hand side of the convolution.
rhs: The shape of the right-hand side of the convolution.
out: The shape of the output of the convolution.
window_strides: The stride of the window along each spatial dimension.
padding: The padding applied to the input along each spatial dimension.
lhs_dilation: The dilation factor for the left-hand side along each spatial
dimension.
rhs_dilation: The dilation factor for the right-hand side along each spatial
dimension.
"""
input_spatial_dims, kernel_spatial_dims, out_spatial_dims = (
dnums.lhs_spec[2:],
dnums.rhs_spec[2:],
dnums.out_spec[2:],
)
valid_position_counts = 1
# Loop over each spatial dimension and determine how many valid positions
# there are for each dimension.
for d in range(len(input_spatial_dims)):
valid_position_counts *= _get_spatial_valid_position_count_for_one_dim(
window_dim_stride=window_strides[d],
base_dilation=lhs_dilation[d],
window_dilation=rhs_dilation[d],
kernel_limit=rhs.shape[kernel_spatial_dims[d]],
input_limit=lhs.shape[input_spatial_dims[d]],
output_limit=out.shape[out_spatial_dims[d]],
padding=padding[d],
)
return valid_position_counts
def _calculate_conv_flops(
lhs: roofline.RooflineShape,
rhs: roofline.RooflineShape,
out: roofline.RooflineShape,
window_strides: Sequence[int],
padding: Sequence[tuple[int, int]],
lhs_dilation: Sequence[int],
rhs_dilation: Sequence[int],
dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers,
batch_group_count: int,
) -> int:
"""Calculates roofline unfused flops for Jax's conv_general_dilated primitive.
See `jax.lax.conv_general_dilated` for details on the arguments.
"""
dnums = convolution.conv_dimension_numbers(
lhs.shape, rhs.shape, dimension_numbers
)
spatial_valid_position_counts = _get_spatial_valid_position_count(
dnums, lhs, rhs, out, window_strides, padding, lhs_dilation, rhs_dilation
)
batch = lhs.shape[dnums.lhs_spec[0]]
num_output_features = out.shape[dnums.out_spec[1]]
num_input_features = rhs.shape[dnums.rhs_spec[1]]
num_output_batch = batch / batch_group_count
non_spatial_dims_factor = (
num_input_features * num_output_features * num_output_batch
)
fma_count = non_spatial_dims_factor * spatial_valid_position_counts
flops = fma_count * _FMA_FLOPS_FACTOR
return int(flops)
@roofline.register_roofline(convolution.conv_general_dilated_p)
def _conv_general_dilated_roofline(
ctx: roofline.RooflineRuleContext,
*args,
window_strides: Sequence[int],
padding: Sequence[tuple[int, int]],
lhs_dilation: Sequence[int],
rhs_dilation: Sequence[int],
dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers,
batch_group_count: int,
**kw,
) -> roofline.RooflineResult:
"""Roofline for Jax's conv_general_dilated primitive.
See `jax.lax.conv_general_dilated` for details on the arguments.
"""
lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
return roofline.RooflineResult(
unfused_flops=_calculate_conv_flops(
lhs,
rhs,
out,
window_strides,
padding,
lhs_dilation,
rhs_dilation,
dimension_numbers,
batch_group_count,
),
unfused_hbm_bytes=(
lhs.dtype.itemsize * lhs.size
+ rhs.dtype.itemsize * rhs.size
+ out.dtype.itemsize * out.size
),
)
def _return_zeros_if_one_sized_axis(
ctx: roofline.RooflineRuleContext, axes: tuple[str, ...]
) -> roofline.RooflineResult | None:
assert ctx.mesh
axes_size = np.prod([ctx.mesh.shape[axis] for axis in axes])
if axes_size > 1:
return None
return roofline.RooflineResult(
ici_bytes={axis: 0 for axis in axes},
ici_latency={axis: 0 for axis in axes},
)
def _ring_collective_roofline(
ctx: roofline.RooflineRuleContext,
*args,
axes: tuple[str, ...],
is_reduce: bool = True,
**kw,
) -> roofline.RooflineResult:
if zeros_result := _return_zeros_if_one_sized_axis(ctx, axes):
return zeros_result
assert ctx.mesh
mesh = ctx.mesh.shape
current_shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in)
if is_reduce:
current_shard_size /= np.prod([mesh[axis] for axis in axes])
# We model the slowest color as the bottleneck.
sorted_axes = sorted(axes, key=lambda x: mesh[x], reverse=True)
num_axes = len(sorted_axes)
ici_bytes = 0
# Phase split.
current_shard_size //= num_axes
for axis in sorted_axes:
axis_size = mesh[axis]
# Do phase.
ici_bytes += current_shard_size * (axis_size - 1)
# Increase shard size.
current_shard_size *= axis_size
# Bottleneck is the longest axis.
ici_latency = mesh[sorted_axes[0]] * num_axes
return roofline.RooflineResult(
ici_bytes={axis: int(ici_bytes) for axis in sorted_axes},
ici_latency={axis: int(ici_latency) for axis in sorted_axes},
)
roofline.register_roofline(lax_parallel.reduce_scatter_p)(
lambda *args, axis_name, **kw: _ring_collective_roofline(*args, axes=axis_name, **kw)
)
roofline.register_roofline(lax_parallel.all_gather_p)(
lambda *args, axis_name, **kw: _ring_collective_roofline(
*args, axes=axis_name, is_reduce=False, **kw
)
)
def _calculate_gather_flops(
mode: slicing.GatherScatterMode,
indices_size: int,
output_size: int,
) -> int:
"""Calculates roofline unfused flops for Jax's gather primitive."""
if mode == slicing.GatherScatterMode.FILL_OR_DROP:
# With FILL_OR_DROP, we have 4 steps to check whether to fill (or drop):
# 1. Check if the index is within upper bound.
# 2. Check if the index is within lower bound.
# 3. Call `and` on #1 and #2 to check the index is "in bounds".
# 4. `reduce` the result to a single boolean per window.
# Each of the steps is a single elementwise op on the indices.
index_check_flops = indices_size * 4
# Once we know whether to fill or drop (per window), there are 2 steps to
# mask the output:
# 1. Broadcast the per-window boolean to the output shape.
# 2. Choose whether to fill (from `operand`) if in-bounds, or drop if
# out-of-bounds.
# Broadcasting is free, but choosing whether to fill or drop involves an
# elementwise op the size of the output.
output_mask_flops = output_size
return index_check_flops + output_mask_flops
return 0
@roofline.register_roofline(slicing.gather_p)
def _gather_roofline(
ctx: roofline.RooflineRuleContext,
*args,
mode: slicing.GatherScatterMode,
**kw,
) -> roofline.RooflineResult:
_, indices = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
# Gather doesn't read the whole input buffer, it's equivalent to a copy the
# size of the output shape and a read of the gather indices.
unfused_hbm_bytes = (
out.dtype.itemsize * out.size * 2 + indices.dtype.itemsize * indices.size
)
return roofline.RooflineResult(
unfused_flops=_calculate_gather_flops(mode, indices.size, out.size),
unfused_hbm_bytes=unfused_hbm_bytes,
)
def _scatter_roofline(
ctx: roofline.RooflineRuleContext,
*args,
**kw,
) -> roofline.RooflineResult:
"""Roofline for Jax's `scatter*` primitives.
The `scatter` functionality itself is a simple data read and write, which
contributes 0 flops.
But, the jaxpr for each `scatter*` function (aside from `jax.lax.scatter`)
contains an `update_jaxpr` that gets applied to the operand & scattered
updates (e.g. `add` for `scatter_add`, or arbitrary unary function for
`scatter_apply`), which *does* contribute flops. This `update_jaxpr` gets
applied to every element of the scattered updates.
Thus,
flops = [# flops for `update_jaxpr` per element] * [# elements in `updates`].
To calculate # flops for `update_jaxpr`, we convert the `update_jaxpr` back to
a callable, and then call `roofline` on that callable. `update_jaxpr` does not
contain any information about input shapes or dtypes; it expects scalars. It
will therefore give us a # flops-per-element result, which we multiply by
the size of the updates to get the total flops.
"""
(_, indices, updates) = (
roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in
)
update_jaxpr = kw.get('update_jaxpr')
flops = 0
if update_jaxpr:
update_fn = lambda *inputs: core.eval_jaxpr(update_jaxpr, [], *inputs)
# Create dummy scalar inputs.
dummy_inputs = [
api.ShapeDtypeStruct((), updates.dtype) for _ in update_jaxpr.invars
]
# Calculate the flops for the `update_jaxpr` on scalar inputs.
_, roofline_result = roofline.roofline(update_fn)(*dummy_inputs)
# Multiply by the size of the updates to get the total flops.
flops = roofline_result.unfused_flops * updates.size
return roofline.RooflineResult(
unfused_flops=flops,
# Scatter accesses the equivalent of 3N update shapes (input, output, and
# updates), and the scatter indices.
unfused_hbm_bytes=(
3 * updates.dtype.itemsize * updates.size
+ indices.dtype.itemsize * indices.size
),
)
roofline.register_roofline(slicing.scatter_add_p)(_scatter_roofline)
roofline.register_roofline(slicing.scatter_max_p)(_scatter_roofline)
roofline.register_roofline(slicing.scatter_min_p)(_scatter_roofline)
roofline.register_roofline(slicing.scatter_mul_p)(_scatter_roofline)
roofline.register_roofline(slicing.scatter_sub_p)(_scatter_roofline)
# Also registers `jax.lax.scatter_apply`, which uses the `scatter_p` primitive.
roofline.register_roofline(slicing.scatter_p)(_scatter_roofline)
def _scalar_collective_roofline(
ctx: roofline.RooflineRuleContext,
*args,
axes: tuple[str, ...],
**kw,
) -> roofline.RooflineResult:
shapes = [roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in]
ctx = replace(ctx, avals_in=[core.ShapedArray((1,), shape.dtype) for shape in shapes])
return _ring_collective_roofline(ctx, *args, axes=axes, is_reduce=False, **kw)
roofline.register_roofline(lax_parallel.pmin_p)(_scalar_collective_roofline)
roofline.register_roofline(lax_parallel.pmax_p)(_scalar_collective_roofline)
@roofline.register_roofline(lax_parallel.psum_invariant_p)
def _psum2_roofline(
ctx: roofline.RooflineRuleContext,
*args,
axes: tuple[str, ...],
**kw,
) -> roofline.RooflineResult:
ring_roofline = _ring_collective_roofline(ctx, *args, axes=axes, **kw)
def double_dict(d: dict[str, int]) -> dict[str, int]:
return {k: v * 2 for k, v in d.items()}
return roofline.RooflineResult(
ici_bytes=double_dict(ring_roofline.ici_bytes),
ici_latency=double_dict(ring_roofline.ici_latency),
)
@roofline.register_roofline(lax_parallel.all_to_all_p)
def _all_to_all_roofline(
ctx: roofline.RooflineRuleContext,
*args,
axis_name: tuple[str, ...],
**kw,
) -> roofline.RooflineResult:
if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name):
return zeros_result
assert ctx.mesh
mesh = ctx.mesh.shape
size = roofline.RooflineShape.total_bytes(ctx.avals_in) * np.prod([
mesh[axis] for axis in axis_name
])
smallest_axis = sorted(axis_name, key=lambda x: mesh[x])[0]
num_axes = len(axis_name)
bisection_bw = mesh[smallest_axis] ** (num_axes - 1)
if mesh[smallest_axis] > 2:
# Times 2 because of wraparound.
bisection_bw *= 2
# Half the data needs to cross the bisection on average.
ici_bytes = size / 2 / bisection_bw
# The latency is the max number of hops across the mesh.
ici_latency = sum(mesh[axis] / 2 for axis in axis_name)
return roofline.RooflineResult(
ici_bytes={axis: int(ici_bytes) for axis in axis_name},
ici_latency={axis: int(ici_latency) for axis in axis_name},
)
@roofline.register_roofline(lax_parallel.ppermute_p)
def _ppermute_roofline(
ctx: roofline.RooflineRuleContext,
*args,
axis_name: tuple[str, ...],
perm: tuple[tuple[int, int], ...],
**kw,
) -> roofline.RooflineResult:
if zeros_result := _return_zeros_if_one_sized_axis(ctx, axis_name):
return zeros_result
assert ctx.mesh
mesh = ctx.mesh.shape
mesh_dims: list[int] = [mesh.get(axis, 1) for axis in axis_name]
shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in)
ici_contention: dict[tuple[tuple[int, ...], ...], float] = defaultdict(float)
ici_latency = 0
for src, dst in perm:
if src == dst:
continue
# Perms are linearized.
src_coords = tuple(int(i) for i in np.unravel_index(src, mesh_dims))
dst_coords = tuple(int(i) for i in np.unravel_index(dst, mesh_dims))
ici_latency_for_perm = 0
# For each dimension.
for i in range(len(axis_name)):
dim_size = mesh_dims[i]
src_pos = src_coords[i]
dst_pos = dst_coords[i]
if src_pos != dst_pos:
# Calculate distance with wraparound.
clockwise_dist = (dst_pos - src_pos) % dim_size
counter_dist = (src_pos - dst_pos) % dim_size
direction = 1 if clockwise_dist <= counter_dist else -1
curr_pos = src_pos
while curr_pos != dst_pos:
curr_coords = util.tuple_update(src_coords, i, curr_pos)
next_pos = (curr_pos + direction) % dim_size
next_coords = util.tuple_update(curr_coords, i, next_pos)
ici_contention[tuple(sorted([curr_coords, next_coords]))] += 1
curr_pos = next_pos
distance = min(clockwise_dist, counter_dist)
ici_latency_for_perm += distance
ici_latency = max(ici_latency, ici_latency_for_perm)
ici_bytes = shard_size * max(ici_contention.values(), default=0)
return roofline.RooflineResult(
ici_bytes={axis: int(ici_bytes) for axis in axis_name},
ici_latency={axis: int(ici_latency) for axis in axis_name},
)
@roofline.register_roofline(lax.reduce_sum_p)
def _reduce_sum_p_roofline(
ctx: roofline.RooflineRuleContext,
*args,
axes: tuple[int, ...],
**kw,
) -> roofline.RooflineResult:
(x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
domain_size = np.prod([x.shape[i] for i in axes])
other_axes = set(range(len(x.shape))) - set(axes)
result_size = np.prod([x.shape[i] for i in other_axes])
return roofline.RooflineResult(
# To add n values, we do n - 1 add operations, and we have to do that
# for every element in the result.
unfused_flops=int((domain_size - 1) * result_size),
# Size of input, plus output. (We assume that the output is also used
# as accumulator.)
unfused_hbm_bytes=int(x.dtype.itemsize * (x.size + result_size)),
)
@roofline.register_roofline(lax.select_n_p)
def _select_n_p_roofline(
ctx: roofline.RooflineRuleContext,
*args,
**kw,
) -> roofline.RooflineResult:
(x, *_) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
return roofline.RooflineResult(
unfused_flops=out.size,
unfused_hbm_bytes=(
x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
),
)
@roofline.register_roofline(callback.pure_callback_p)
@roofline.register_roofline(callback.io_callback_p)
def _callback_with_output_roofline(
ctx: roofline.RooflineRuleContext,
*args,
**kw,
) -> roofline.RooflineResult:
avals_in = ctx.avals_in
avals_out = ctx.avals_out
# HBM bytes for transferring inputs to host and results back to device.
hbm_bytes = roofline.RooflineShape.total_bytes(
avals_in
) + roofline.RooflineShape.total_bytes(avals_out)
# We don't have access to the `callback_func`, so we assume it contributes 0
# flops.
return roofline.RooflineResult(unfused_hbm_bytes=hbm_bytes)
@roofline.register_roofline(debugging.debug_callback_p)
def _debug_callback_roofline(
ctx: roofline.RooflineRuleContext,
*args,
**kw,
) -> roofline.RooflineResult:
avals_in = ctx.avals_in
# `debug_callback` does not return values to the JAX program, so only input
# HBM bytes are considered.
hbm_bytes = roofline.RooflineShape.total_bytes(avals_in)
# We don't have access to the `callback_func`, so we assume it contributes 0
# flops.
return roofline.RooflineResult(unfused_hbm_bytes=hbm_bytes)