hand
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from jax.experimental.roofline.roofline import (
|
||||
RooflineRuleContext as RooflineRuleContext,
|
||||
)
|
||||
from jax.experimental.roofline.roofline import RooflineShape as RooflineShape
|
||||
from jax.experimental.roofline.roofline import RooflineResult as RooflineResult
|
||||
from jax.experimental.roofline.roofline import roofline as roofline
|
||||
from jax.experimental.roofline.roofline import register_roofline as register_roofline
|
||||
from jax.experimental.roofline.roofline import (
|
||||
register_standard_roofline as register_standard_roofline,
|
||||
)
|
||||
from jax.experimental.roofline.roofline import roofline_and_grad as roofline_and_grad
|
||||
|
||||
|
||||
import jax.experimental.roofline.rooflines as rooflines
|
||||
|
||||
del rooflines
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,386 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Callable, Sequence
|
||||
import numpy as np
|
||||
from absl import logging
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax.sharding import NamedSharding
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import prng
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.api import make_jaxpr
|
||||
from jax._src.interpreters.partial_eval import dce_jaxpr
|
||||
from jax._src.mesh import AbstractMesh, Mesh
|
||||
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
|
||||
from jax._src.util import foreach
|
||||
from jax._src.shard_map import shard_map, shard_map_p
|
||||
|
||||
|
||||
ShapeDtypeStructTree = Any
|
||||
Specs = Any
|
||||
ValidRooflineDtype = np.dtype | prng.KeyTy
|
||||
|
||||
map = util.safe_map
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True, kw_only=True)
|
||||
class RooflineRuleContext:
|
||||
name_stack: source_info_util.NameStack
|
||||
primitive: core.Primitive
|
||||
avals_in: Sequence[core.AbstractValue]
|
||||
avals_out: Sequence[core.AbstractValue]
|
||||
jaxpr_eqn_ctx: core.JaxprEqnContext
|
||||
mesh: Mesh | AbstractMesh | None
|
||||
pin_lhs_in_vmem: bool
|
||||
pin_rhs_in_vmem: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True, kw_only=True)
|
||||
class RooflineShape:
|
||||
shape: tuple[int, ...]
|
||||
dtype: ValidRooflineDtype
|
||||
|
||||
@classmethod
|
||||
def from_aval(cls, aval: core.AbstractValue) -> RooflineShape:
|
||||
if not isinstance(aval, core.ShapedArray):
|
||||
raise TypeError(f"Expected ShapedArray, got {type(aval)}.")
|
||||
if not isinstance(aval.dtype, ValidRooflineDtype):
|
||||
raise TypeError(
|
||||
f"Expected numpy or prng.KeyTy dtype, got {type(aval.dtype)}."
|
||||
)
|
||||
return cls(shape=aval.shape, dtype=aval.dtype)
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return int(np.prod(self.shape))
|
||||
|
||||
@property
|
||||
def bytes(self) -> int:
|
||||
return int(self.size * self.dtype.itemsize)
|
||||
|
||||
@classmethod
|
||||
def total_bytes(cls, avals: Sequence[core.AbstractValue]) -> int:
|
||||
return sum(cls.from_aval(aval).bytes for aval in avals)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True, kw_only=True)
|
||||
class RooflineResult:
|
||||
flops: int = 0
|
||||
unfused_flops: int = 0
|
||||
ici_bytes: dict[str, int] = field(default_factory=dict)
|
||||
ici_latency: dict[str, int] = field(default_factory=dict)
|
||||
hbm_bytes: int = 0
|
||||
peak_hbm_bytes: int = 0
|
||||
unfused_hbm_bytes: int = 0
|
||||
|
||||
@classmethod
|
||||
def zeros(cls) -> RooflineResult:
|
||||
return cls()
|
||||
|
||||
def __add__(self, other: RooflineResult) -> RooflineResult:
|
||||
def merge_ici_dicts(d1: dict[str, int], d2: dict[str, int]) -> dict[str, int]:
|
||||
return {k: d1.get(k, 0) + d2.get(k, 0) for k in set(d1) | set(d2)}
|
||||
|
||||
return RooflineResult(
|
||||
flops=self.flops + other.flops,
|
||||
unfused_flops=self.unfused_flops + other.unfused_flops,
|
||||
ici_bytes=merge_ici_dicts(self.ici_bytes, other.ici_bytes),
|
||||
ici_latency=merge_ici_dicts(self.ici_latency, other.ici_latency),
|
||||
hbm_bytes=self.hbm_bytes + other.hbm_bytes,
|
||||
peak_hbm_bytes=max(self.peak_hbm_bytes, other.peak_hbm_bytes),
|
||||
unfused_hbm_bytes=self.unfused_hbm_bytes + other.unfused_hbm_bytes,
|
||||
)
|
||||
|
||||
def __mul__(self, constant: int | float) -> RooflineResult:
|
||||
return RooflineResult(
|
||||
flops=int(self.flops * constant),
|
||||
unfused_flops=int(self.unfused_flops * constant),
|
||||
ici_bytes={k: int(v * constant) for k, v in self.ici_bytes.items()},
|
||||
ici_latency={k: int(v * constant) for k, v in self.ici_latency.items()},
|
||||
hbm_bytes=int(self.hbm_bytes * constant),
|
||||
peak_hbm_bytes=int(self.peak_hbm_bytes * constant),
|
||||
unfused_hbm_bytes=int(self.unfused_hbm_bytes * constant),
|
||||
)
|
||||
|
||||
def __rmul__(self, constant: int | float) -> RooflineResult:
|
||||
return self.__mul__(constant)
|
||||
|
||||
|
||||
class _RooflineRule(Protocol):
  """Call signature that every registered roofline rule must satisfy.

  A rule receives the rule context, one `RooflineShape` per equation input and
  the equation's params as keyword arguments, and returns a `RooflineResult`.
  """

  def __call__(
      self, ctx: RooflineRuleContext, *args: RooflineShape, **kw
  ) -> RooflineResult: ...


# Registry consulted by `_roofline_interpreter`: primitive -> roofline rule.
_rooflines: dict[core.Primitive, _RooflineRule] = {}
|
||||
|
||||
|
||||
def _roofline_interpreter(
    f_name: str,
    jaxpr: core.Jaxpr,
    mesh: Mesh | AbstractMesh | None,
    *,
    pin_lhs_in_vmem: bool = False,
    pin_rhs_in_vmem: bool = False,
) -> RooflineResult:
  """Walks `jaxpr` equation by equation and accumulates roofline counters.

  For each equation:
    * if its params contain `"jaxpr"` or `"call_jaxpr"`, that inner jaxpr is
      interpreted recursively;
    * otherwise, if no rule is registered in `_rooflines` for its primitive, a
      warning is logged and the equation contributes no counters;
    * otherwise the registered rule is called and its result is added.

  In parallel, a running total of the bytes of all live variables is
  maintained (outputs are added when created, inputs are removed at their last
  use) and its maximum is reported as `peak_hbm_bytes`.

  Args:
    f_name: Name used to build the name stack for this (sub-)computation.
    jaxpr: The jaxpr to analyse; a `core.ClosedJaxpr` is unwrapped first.
    mesh: Mesh forwarded to every rule through the rule context.
    pin_lhs_in_vmem: Forwarded to every rule through the rule context.
    pin_rhs_in_vmem: Forwarded to every rule through the rule context.

  Returns:
    The sum of all rule results plus the peak live-bytes estimate.
  """
  name_stack = source_info_util.new_name_stack(util.wrap_name("roofline", f_name))

  result = RooflineResult.zeros()

  # Shapes of the variables that are currently live.
  env: dict[core.Var, RooflineShape] = {}

  def write(v: core.Var, node: RooflineShape):
    assert node is not None
    env[v] = node

  def read(v: core.Atom) -> RooflineShape:
    if type(v) is core.Literal:
      return RooflineShape.from_aval(core.typeof(v.val))
    else:
      assert isinstance(v, core.Var)
      return env[v]

  def aval(v: core.Atom) -> core.AbstractValue:
    if type(v) is core.Literal:
      return core.typeof(v.val)
    else:
      return v.aval

  def sum_bytes(shapes: Sequence[RooflineShape]) -> int:
    return sum(shape.bytes for shape in shapes)

  jaxpr = jaxpr.jaxpr if isinstance(jaxpr, core.ClosedJaxpr) else jaxpr
  make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x))
  foreach(
      write,
      jaxpr.constvars,
      map(make_roofline_shape, jaxpr.constvars),
  )
  foreach(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars))
  last_used = core.last_used(jaxpr)

  # Constants and inputs are live before the first equation runs.
  current_hbm_bytes = sum_bytes(list(env.values()))
  peak_hbm_bytes = current_hbm_bytes

  for eqn in jaxpr.eqns:
    source_info = eqn.source_info.replace(
        name_stack=name_stack + eqn.source_info.name_stack
    )
    with source_info_util.user_context(
        eqn.source_info.traceback, name_stack=source_info.name_stack
    ):
      if "jaxpr" in eqn.params:
        result += _roofline_interpreter(
            util.wrap_name(eqn.primitive.name, f_name),
            eqn.params["jaxpr"],
            mesh,
            pin_lhs_in_vmem=pin_lhs_in_vmem,
            pin_rhs_in_vmem=pin_rhs_in_vmem,
        )
      elif "call_jaxpr" in eqn.params:
        # Used for custom_jvp_call_p. Recursively calculates roofline result for
        # all primitives in the custom function.
        result += _roofline_interpreter(
            util.wrap_name(eqn.primitive.name, f_name),
            eqn.params['call_jaxpr'],
            mesh,
            pin_lhs_in_vmem=pin_lhs_in_vmem,
            pin_rhs_in_vmem=pin_rhs_in_vmem,
        )
      elif eqn.primitive not in _rooflines:
        msg = f"No roofline rule for {eqn.primitive}, skipping..."
        for attr in dir(eqn):
          if not attr.startswith("_"):
            msg += f"\n{attr}: {getattr(eqn, attr)}"
        logging.warning(msg)
      else:
        rule = _rooflines[eqn.primitive]
        result += rule(
            RooflineRuleContext(
                name_stack=source_info.name_stack,
                primitive=eqn.primitive,
                avals_in=map(aval, eqn.invars),
                avals_out=map(aval, eqn.outvars),
                jaxpr_eqn_ctx=eqn.ctx,
                mesh=mesh,
                pin_lhs_in_vmem=pin_lhs_in_vmem,
                pin_rhs_in_vmem=pin_rhs_in_vmem,
            ),
            *map(read, eqn.invars),
            **eqn.params,
        )

      # Add bytes for the newly-created output variables.
      outvar_shapes = map(make_roofline_shape, eqn.outvars)
      current_hbm_bytes += sum_bytes(outvar_shapes)
      foreach(write, eqn.outvars, outvar_shapes)

      # Remove bytes for the no-longer-needed input variables. A variable may
      # appear several times in `eqn.invars` (e.g. `mul a a`) but its buffer is
      # freed only once, so de-duplicate (order-preserving) before summing.
      dying_vars = dict.fromkeys(
          v for v in eqn.invars
          if not isinstance(v, core.Literal) and last_used[v] is eqn
      )
      current_hbm_bytes -= sum_bytes([env[v] for v in dying_vars])
      core.clean_up_dead_vars(eqn, env, last_used)

      peak_hbm_bytes = max(peak_hbm_bytes, current_hbm_bytes)

  # `RooflineResult.__add__` takes the max of `peak_hbm_bytes`, so this also
  # folds in the peaks reported by recursive calls and rules.
  result += RooflineResult(peak_hbm_bytes=peak_hbm_bytes)
  return result
|
||||
|
||||
|
||||
def _f_with_vjp(f: Callable):
  """Wraps `f` so that calling the wrapper evaluates `f`'s VJP.

  The wrapper calls `api.vjp(f, *args)`, maps the primal outputs through
  `jnp.bfloat16`, feeds them to the VJP function as cotangents, and returns
  what the VJP function returns.
  """

  @util.wraps(f)
  def wrapped(*args):
    outputs, pullback = api.vjp(f, *args)
    cotangents = tree_map(jnp.bfloat16, outputs)
    return pullback(cotangents)

  return wrapped
|
||||
|
||||
|
||||
def roofline(
    f: Callable,
    mesh: Mesh | AbstractMesh | None = None,
    in_specs: Specs | None = None,
    out_specs: Specs | None = None,
    *,
    pin_lhs_in_vmem: bool = False,
    pin_rhs_in_vmem: bool = False,
    vjp: bool = False,
    print_jaxpr: bool = False,
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult]]:
  """Returns a function that traces `f` and computes its roofline counters.

  Args:
    f: The function to analyse.
    mesh: Optional mesh. `f` is wrapped in `shard_map` only when `mesh`,
      `in_specs` and `out_specs` are all given.
    in_specs: Input specs forwarded to `shard_map`.
    out_specs: Output specs forwarded to `shard_map`. When given together with
      `mesh`, they are also attached to the returned output shapes as
      `NamedSharding`s.
    pin_lhs_in_vmem: Forwarded to the roofline interpreter.
    pin_rhs_in_vmem: Forwarded to the roofline interpreter.
    vjp: If true, the analysed function is `_f_with_vjp(f)` instead of `f`.
    print_jaxpr: If true, prints the jaxpr that is about to be interpreted.

  Returns:
    A function taking the same positional arguments as `f` and returning
    `(out_shapes, roofline_result)`.
  """

  def make_sharded_shape_dtype_struct(
      shape: api.ShapeDtypeStruct, out_spec: Specs
  ) -> api.ShapeDtypeStruct:
    assert mesh is not None
    sharding = NamedSharding(mesh, out_spec)
    return api.ShapeDtypeStruct(shape.shape, shape.dtype, sharding=sharding)

  @util.wraps(f)
  @traceback_util.api_boundary
  def wrapped(*args):
    target = f
    if not (in_specs is None or out_specs is None or mesh is None):
      target = shard_map(
          target, mesh=mesh, in_specs=in_specs, out_specs=out_specs
      )
    if vjp:
      target = _f_with_vjp(target)

    closed_jaxpr, out_shapes = make_jaxpr(target, return_shape=True)(*args)

    if not (out_specs is None or mesh is None):
      # Re-attach the requested output shardings to the traced output shapes.
      out_specs_flat = broadcast_prefix(out_specs, out_shapes)
      leaves, treedef = tree_flatten(out_shapes)
      sharded_leaves = map(
          make_sharded_shape_dtype_struct, leaves, out_specs_flat
      )
      out_shapes = tree_unflatten(treedef, sharded_leaves)

    # Dead-code-eliminate while treating every output as used.
    used_outputs = (True,) * len(closed_jaxpr.jaxpr.outvars)
    jaxpr, _ = dce_jaxpr(closed_jaxpr.jaxpr, used_outputs)

    # If the top level contains shard_map equations, interpret the body of the
    # last one instead of the outer jaxpr.
    last_shard_map_eqn = None
    for eqn in jaxpr.eqns:
      if eqn.primitive == shard_map_p:
        last_shard_map_eqn = eqn
    if last_shard_map_eqn is not None:
      try:
        jaxpr = last_shard_map_eqn.params["jaxpr"]
      except KeyError:
        raise ValueError(f"Missing shard_map jaxpr in {jaxpr}.")

    if print_jaxpr:
      print(jaxpr)

    roofline_result = _roofline_interpreter(
        util.fun_qual_name(f),
        jaxpr,
        mesh,
        pin_lhs_in_vmem=pin_lhs_in_vmem,
        pin_rhs_in_vmem=pin_rhs_in_vmem,
    )
    return out_shapes, roofline_result

  return wrapped
|
||||
|
||||
|
||||
def register_roofline(prim: core.Primitive):
  """Returns a decorator that registers a roofline rule for `prim`.

  The decorator stores the rule in `_rooflines` under `prim`, replacing any
  rule already stored there, and returns the rule unchanged.
  """

  def register(rule: _RooflineRule):
    _rooflines.update({prim: rule})
    return rule

  return register
|
||||
|
||||
|
||||
def register_standard_roofline(prim: core.Primitive):
  """Registers for `prim` a rule that always returns `RooflineResult.zeros()`.

  Any rule previously stored in `_rooflines` for `prim` is replaced. Nothing
  is returned.
  """

  def standard_rule(ctx: RooflineRuleContext, *args, **kwargs):
    # The context, shapes and params are accepted but ignored.
    return RooflineResult.zeros()

  _rooflines.update({prim: standard_rule})
|
||||
|
||||
|
||||
def roofline_and_grad(
    f: Callable,
    mesh: Mesh | AbstractMesh,
    in_specs: Specs,
    out_specs: Specs,
    *,
    pin_lhs_in_vmem: bool = False,
    pin_rhs_in_vmem: bool = False,
    print_jaxpr: bool = False,
) -> Callable[..., tuple[ShapeDtypeStructTree, RooflineResult, RooflineResult]]:
  """Returns a function computing roofline counters for `f` and for its VJP.

  Args:
    f: The function to analyse.
    mesh: Mesh forwarded to `roofline`.
    in_specs: Input specs forwarded to `roofline`.
    out_specs: Output specs forwarded to `roofline`.
    pin_lhs_in_vmem: Forwarded to `roofline`.
    pin_rhs_in_vmem: Forwarded to `roofline`.
    print_jaxpr: Forwarded to `roofline` (applies to both analyses).

  Returns:
    A function taking `f`'s positional arguments and returning
    `(primal_shapes, forward_result, backward_result)`. The backward analysis
    runs `roofline(..., vjp=True)` on `ShapeDtypeStruct` stand-ins for the
    arguments, whose dtype is `int32` where the argument is `int32` and
    `bfloat16` otherwise.
  """
  shared_kwargs = dict(
      pin_lhs_in_vmem=pin_lhs_in_vmem,
      pin_rhs_in_vmem=pin_rhs_in_vmem,
      print_jaxpr=print_jaxpr,
  )

  def to_bwd_struct(x):
    dtype = jnp.int32 if x.dtype == jnp.int32 else jnp.bfloat16
    return api.ShapeDtypeStruct(x.shape, dtype, sharding=x.sharding)

  @util.wraps(f)
  @traceback_util.api_boundary
  def wrapped(*args):
    fwd_fn = roofline(f, mesh, in_specs, out_specs, **shared_kwargs)
    primal_shapes, fwd_result = fwd_fn(*args)

    bwd_fn = roofline(f, mesh, in_specs, out_specs, vjp=True, **shared_kwargs)
    bwd_result = bwd_fn(*tree_map(to_bwd_struct, args))[1]

    return primal_shapes, fwd_result, bwd_result

  return wrapped
|
||||
@@ -0,0 +1,822 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import defaultdict
|
||||
from dataclasses import replace
|
||||
import itertools as it
|
||||
from collections.abc import Sequence
|
||||
import numpy as np
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import ad_util
|
||||
from jax._src import core, util
|
||||
from jax._src import dispatch
|
||||
from jax._src import ops
|
||||
from jax._src import pjit
|
||||
from jax._src import prng
|
||||
from jax._src import random
|
||||
from jax._src import shard_map
|
||||
from jax._src import callback
|
||||
from jax._src import debugging
|
||||
from jax._src.lax import (
|
||||
ann,
|
||||
control_flow,
|
||||
convolution,
|
||||
fft,
|
||||
lax,
|
||||
linalg,
|
||||
parallel as lax_parallel,
|
||||
slicing,
|
||||
special,
|
||||
windowed_reductions,
|
||||
)
|
||||
from jax.experimental import roofline
|
||||
|
||||
# One FMA (Fused Multiply Add) takes 2 flops to compute.
_FMA_FLOPS_FACTOR = 2

# Give every `core.Primitive` found in the namespaces below the standard rule
# as a baseline; the specialised rules registered afterwards overwrite these
# entries where they exist.
for prim in it.chain.from_iterable((
    vars(ad_checkpoint).values(),
    vars(ad_util).values(),
    vars(ann).values(),
    vars(callback).values(),
    vars(control_flow).values(),
    vars(convolution).values(),
    vars(dispatch).values(),
    vars(fft).values(),
    vars(lax).values(),
    vars(linalg).values(),
    vars(ops).values(),
    [pjit.sharding_constraint_p],
    vars(prng).values(),
    vars(random).values(),
    vars(shard_map).values(),
    vars(slicing).values(),
    vars(special).values(),
    vars(windowed_reductions).values(),
)):
  if isinstance(prim, core.Primitive):
    roofline.register_standard_roofline(prim)
|
||||
|
||||
|
||||
def _unary_p_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for primitives with exactly one input.

  Counts one unfused flop per input element, and unfused HBM traffic equal to
  the input bytes plus the bytes of the first output. Positional shapes and
  keyword params are ignored; everything is read from `ctx`.
  """
  [operand] = map(roofline.RooflineShape.from_aval, ctx.avals_in)
  result = roofline.RooflineShape.from_aval(ctx.avals_out[0])
  bytes_read = operand.dtype.itemsize * operand.size
  bytes_written = result.dtype.itemsize * result.size
  return roofline.RooflineResult(
      unfused_flops=operand.size,
      unfused_hbm_bytes=(bytes_read + bytes_written),
  )
|
||||
|
||||
# Single-input primitives that all share `_unary_p_roofline`.
for _unary_prim in (
    lax.abs_p,
    lax.acos_p,
    lax.asin_p,
    lax.atan_p,
    lax.cbrt_p,
    lax.ceil_p,
    lax.conj_p,
    lax.cos_p,
    lax.cosh_p,
    lax.exp_p,
    lax.expm1_p,
    lax.floor_p,
    lax.imag_p,
    lax.integer_pow_p,
    lax.is_finite_p,
    lax.log_p,
    lax.log1p_p,
    lax.logistic_p,
    lax.neg_p,
    lax.not_p,
    lax.real_p,
    lax.round_p,
    lax.rsqrt_p,
    lax.sign_p,
    lax.sin_p,
    lax.sinh_p,
    lax.sqrt_p,
    lax.square_p,
    lax.tan_p,
    special.bessel_i0e_p,
    special.bessel_i1e_p,
    special.digamma_p,
    special.erf_inv_p,
    special.erf_p,
    special.erfc_p,
    special.lgamma_p,
):
  roofline.register_roofline(_unary_prim)(_unary_p_roofline)
# Do not leave the loop variable behind in the module namespace.
del _unary_prim

roofline.register_standard_roofline(core.pvary_p)
|
||||
|
||||
def _binary_p_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for primitives with exactly two inputs.

  Unfused flops are the element count of the shape obtained by pairing the
  two input shapes dimension by dimension (padding the shorter one with 1s at
  the end) and taking the larger extent of each pair. Unfused HBM traffic is
  the bytes of both inputs plus the bytes of the first output.
  """
  lhs, rhs = map(roofline.RooflineShape.from_aval, ctx.avals_in)
  dim_pairs = it.zip_longest(lhs.shape, rhs.shape, fillvalue=1)
  broadcasted_shape = [max(lhs_dim, rhs_dim) for lhs_dim, rhs_dim in dim_pairs]
  out = roofline.RooflineShape.from_aval(ctx.avals_out[0])

  operand_bytes = lhs.dtype.itemsize * lhs.size + rhs.dtype.itemsize * rhs.size
  result_bytes = out.dtype.itemsize * out.size
  return roofline.RooflineResult(
      unfused_flops=int(np.prod(broadcasted_shape)),
      unfused_hbm_bytes=(operand_bytes + result_bytes),
  )
|
||||
|
||||
|
||||
# Two-input primitives that all share `_binary_p_roofline`.
for _binary_prim in (
    lax.add_p,
    lax.sub_p,
    lax.mul_p,
    lax.div_p,
    lax.rem_p,
    lax.and_p,
    lax.or_p,
    lax.xor_p,
    lax.gt_p,
    lax.lt_p,
    lax.ge_p,
    lax.le_p,
    lax.eq_p,
    lax.ne_p,
    lax.min_p,
    lax.max_p,
):
  roofline.register_roofline(_binary_prim)(_binary_p_roofline)
# Do not leave the loop variable behind in the module namespace.
del _binary_prim
|
||||
|
||||
def _cumulative_p_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axis: int,
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule shared by the `cum{max, min, prod, sum}` primitives.

  Args:
    ctx: Rule context; must describe exactly one input.
    *args: Ignored.
    axis: The axis the cumulative operation runs along.
    **kw: Ignored.

  Returns:
    A result whose unfused flops equal the input's extent along `axis` and
    whose unfused HBM bytes are the input bytes plus the first output's bytes.
  """
  [operand] = map(roofline.RooflineShape.from_aval, ctx.avals_in)
  result = roofline.RooflineShape.from_aval(ctx.avals_out[0])
  traffic = (
      operand.dtype.itemsize * operand.size
      + result.dtype.itemsize * result.size
  )
  # `cum{max, min, prod, sum}` only calculate values for one axis.
  axis_extent = operand.shape[axis]
  return roofline.RooflineResult(
      unfused_flops=axis_extent,
      unfused_hbm_bytes=traffic,
  )
|
||||
|
||||
# Cumulative reductions that all share `_cumulative_p_roofline`.
for _cumulative_prim in (
    control_flow.cummax_p,
    control_flow.cummin_p,
    control_flow.cumprod_p,
    control_flow.cumsum_p,
):
  roofline.register_roofline(_cumulative_prim)(_cumulative_p_roofline)
# Do not leave the loop variable behind in the module namespace.
del _cumulative_prim
|
||||
|
||||
@roofline.register_roofline(control_flow.cumlogsumexp_p)
def _cumlogsumexp_p_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axis: int,
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for `cumlogsumexp`.

  Args:
    ctx: Rule context; must describe exactly one input.
    *args: Ignored.
    axis: The axis the cumulative operation runs along.
    **kw: Ignored.

  Returns:
    A result whose unfused flops are twice the input's extent along `axis`
    and whose unfused HBM bytes are the input bytes plus the first output's
    bytes.
  """
  [operand] = map(roofline.RooflineShape.from_aval, ctx.avals_in)
  result = roofline.RooflineShape.from_aval(ctx.avals_out[0])

  # Like `cum{max, min, prod, sum}`, `cumlogsumexp` only works along one axis.
  # With S = x.shape[axis], a naive implementation performs S `exp` ops,
  # S-1 `add` ops and 1 `log` op, i.e. 2 * S flops in total.
  flop_count = operand.shape[axis] * 2
  traffic = (
      operand.dtype.itemsize * operand.size
      + result.dtype.itemsize * result.size
  )
  return roofline.RooflineResult(
      unfused_flops=flop_count,
      unfused_hbm_bytes=traffic,
  )
|
||||
|
||||
|
||||
@roofline.register_roofline(lax.dot_general_p)
def _dot_general_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    dimension_numbers: lax.DotDimensionNumbers,
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule for `dot_general`.

  Args:
    ctx: Rule context; must describe exactly two inputs (lhs, rhs).
    *args: Ignored.
    dimension_numbers: `((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))`;
      only the lhs entries are used.
    **kw: Ignored.

  Returns:
    A result whose `flops` and `unfused_flops` both hold the FMA-based flop
    count, and whose `hbm_bytes` and `unfused_hbm_bytes` both hold the output
    bytes plus the bytes of each operand that is not pinned in VMEM
    (`ctx.pin_lhs_in_vmem` / `ctx.pin_rhs_in_vmem`).
  """
  lhs, rhs = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
  out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
  (lhs_contract, _), (lhs_batch, _) = dimension_numbers

  # `lhs.size * rhs.size` counts the contracting and batch extents twice, so
  # they are divided out once. The arithmetic is done on Python ints: the
  # division is exact (the divisor is a factor of `lhs.size`), it does not
  # lose precision for very large operands, and a zero-sized contracting or
  # batch dimension yields 0 flops instead of a NaN that `int()` rejects.
  contract_size = int(np.prod([lhs.shape[i] for i in lhs_contract]))
  batch_size = int(np.prod([lhs.shape[i] for i in lhs_batch]))
  shared_size = contract_size * batch_size
  if shared_size == 0:
    flops = 0
  else:
    flops = _FMA_FLOPS_FACTOR * lhs.size * rhs.size // shared_size

  hbm_bytes = 0
  if not ctx.pin_lhs_in_vmem:
    hbm_bytes += lhs.bytes
  hbm_bytes += out.bytes
  if not ctx.pin_rhs_in_vmem:
    hbm_bytes += rhs.bytes

  return roofline.RooflineResult(
      flops=int(flops),
      unfused_flops=int(flops),
      hbm_bytes=hbm_bytes,
      unfused_hbm_bytes=hbm_bytes,
  )
|
||||
|
||||
|
||||
def _get_spatial_valid_position_count_for_one_dim(
    window_dim_stride: int,
    base_dilation: int,
    window_dilation: int,
    kernel_limit: int,
    input_limit: int,
    output_limit: int,
    padding: tuple[int, int],
) -> int:
  """Gets the valid position count for conv for a single spatial dimension.

  A (kernel index, output index) pair is valid when the input position it
  maps to is a real element: it must not fall into a hole created by base
  dilation and must lie inside `[0, input_limit)`.

  Args:
    window_dim_stride: The stride of the window along this dimension.
    base_dilation: The base dilation factor along this dimension.
    window_dilation: The window dilation factor along this dimension.
    kernel_limit: The size of the kernel along this dimension.
    input_limit: The size of the input along this dimension.
    output_limit: The size of the output along this dimension.
    padding: The (low, high) padding applied to the input along this dimension.

  Returns:
    The number of valid (kernel index, output index) pairs.
  """
  pad_lo = padding[0]
  pad_hi = padding[1]
  unit_window_dilation = window_dilation == 1

  # The next two configurations iterate over N^2 pairs of which only N are
  # valid; answering them directly is purely a performance shortcut and gives
  # the same result as the full loop.
  if (
      pad_lo == 0
      and pad_hi == 0
      and unit_window_dilation
      and input_limit == output_limit
      and kernel_limit == output_limit
      and input_limit == base_dilation
      and max(1, input_limit - 1) == window_dim_stride
  ):
    return input_limit

  if (
      input_limit == 1
      and unit_window_dilation
      and base_dilation == 1
      and window_dim_stride == 1
      and kernel_limit == output_limit
      and pad_lo == output_limit - 1
      and pad_hi == output_limit - 1
  ):
    return output_limit

  total = 0

  if window_dim_stride == 1 and base_dilation == 1:
    # Trivial stride and base dilation: for each kernel point the valid output
    # indices form one contiguous range, so no inner loop is needed.
    for k in range(kernel_limit):
      shift = pad_lo - k * window_dilation
      range_end = min(input_limit + shift, output_limit)
      range_start = max(0, shift)
      total += max(range_end - range_start, 0)
    return total

  for k in range(kernel_limit):
    for o in range(output_limit):
      # Input position before base dilation is taken into account.
      position = o * window_dim_stride - pad_lo + k * window_dilation
      # Candidate input index once base dilation is undone.
      input_index = int(position / base_dilation)
      lands_on_element = position == input_index * base_dilation
      in_bounds = 0 <= input_index < input_limit
      if lands_on_element and in_bounds:
        total += 1
  return total
|
||||
|
||||
|
||||
def _get_spatial_valid_position_count(
    dnums: convolution.ConvDimensionNumbers,
    lhs: roofline.RooflineShape,
    rhs: roofline.RooflineShape,
    out: roofline.RooflineShape,
    window_strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int],
    rhs_dilation: Sequence[int],
) -> int:
  """Gets the number of valid spatial positions for conv_general_dilated.

  Args:
    dnums: The dimension numbers for the convolution.
    lhs: The shape of the left-hand side of the convolution.
    rhs: The shape of the right-hand side of the convolution.
    out: The shape of the output of the convolution.
    window_strides: The stride of the window along each spatial dimension.
    padding: The padding applied to the input along each spatial dimension.
    lhs_dilation: The dilation factor for the left-hand side along each spatial
      dimension.
    rhs_dilation: The dilation factor for the right-hand side along each spatial
      dimension.

  Returns:
    The product, over all spatial dimensions, of the per-dimension valid
    position counts (1 when there are no spatial dimensions).
  """
  # The first two entries of each spec are skipped; the rest are treated as
  # the spatial dimensions.
  input_spatial_dims = dnums.lhs_spec[2:]
  kernel_spatial_dims = dnums.rhs_spec[2:]
  out_spatial_dims = dnums.out_spec[2:]

  total = 1
  for d, input_dim in enumerate(input_spatial_dims):
    per_dim_count = _get_spatial_valid_position_count_for_one_dim(
        window_dim_stride=window_strides[d],
        base_dilation=lhs_dilation[d],
        window_dilation=rhs_dilation[d],
        kernel_limit=rhs.shape[kernel_spatial_dims[d]],
        input_limit=lhs.shape[input_dim],
        output_limit=out.shape[out_spatial_dims[d]],
        padding=padding[d],
    )
    total *= per_dim_count

  return total
|
||||
|
||||
|
||||
def _calculate_conv_flops(
    lhs: roofline.RooflineShape,
    rhs: roofline.RooflineShape,
    out: roofline.RooflineShape,
    window_strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int],
    rhs_dilation: Sequence[int],
    dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers,
    batch_group_count: int,
) -> int:
  """Calculates roofline unfused flops for Jax's conv_general_dilated primitive.

  The flop count is the FMA flop factor times the number of valid spatial
  positions times (input features * output features * batch /
  `batch_group_count`), truncated to an int.

  See `jax.lax.conv_general_dilated` for details on the arguments.
  """
  dnums = convolution.conv_dimension_numbers(
      lhs.shape, rhs.shape, dimension_numbers
  )

  # Index 0 of `lhs_spec` is read as the batch dimension; index 1 of
  # `rhs_spec` / `out_spec` as the input / output feature dimension.
  batch = lhs.shape[dnums.lhs_spec[0]]
  input_features = rhs.shape[dnums.rhs_spec[1]]
  output_features = out.shape[dnums.out_spec[1]]

  valid_positions = _get_spatial_valid_position_count(
      dnums, lhs, rhs, out, window_strides, padding, lhs_dilation, rhs_dilation
  )

  output_batch = batch / batch_group_count
  return int(
      input_features
      * output_features
      * output_batch
      * valid_positions
      * _FMA_FLOPS_FACTOR
  )
|
||||
|
||||
|
||||
@roofline.register_roofline(convolution.conv_general_dilated_p)
def _conv_general_dilated_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    window_strides: Sequence[int],
    padding: Sequence[tuple[int, int]],
    lhs_dilation: Sequence[int],
    rhs_dilation: Sequence[int],
    dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers,
    batch_group_count: int,
    **kw,
) -> roofline.RooflineResult:
  """Roofline for Jax's conv_general_dilated primitive.

  Reports unfused flops from `_calculate_conv_flops` and unfused HBM bytes
  equal to the bytes of both inputs plus the bytes of the first output.

  See `jax.lax.conv_general_dilated` for details on the arguments.
  """
  lhs, rhs = map(roofline.RooflineShape.from_aval, ctx.avals_in)
  out = roofline.RooflineShape.from_aval(ctx.avals_out[0])

  conv_flops = _calculate_conv_flops(
      lhs,
      rhs,
      out,
      window_strides,
      padding,
      lhs_dilation,
      rhs_dilation,
      dimension_numbers,
      batch_group_count,
  )
  operand_bytes = lhs.dtype.itemsize * lhs.size + rhs.dtype.itemsize * rhs.size
  result_bytes = out.dtype.itemsize * out.size

  return roofline.RooflineResult(
      unfused_flops=conv_flops,
      unfused_hbm_bytes=(operand_bytes + result_bytes),
  )
|
||||
|
||||
|
||||
def _return_zeros_if_one_sized_axis(
|
||||
ctx: roofline.RooflineRuleContext, axes: tuple[str, ...]
|
||||
) -> roofline.RooflineResult | None:
|
||||
assert ctx.mesh
|
||||
axes_size = np.prod([ctx.mesh.shape[axis] for axis in axes])
|
||||
if axes_size > 1:
|
||||
return None
|
||||
return roofline.RooflineResult(
|
||||
ici_bytes={axis: 0 for axis in axes},
|
||||
ici_latency={axis: 0 for axis in axes},
|
||||
)
|
||||
|
||||
|
||||
def _ring_collective_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axes: tuple[str, ...],
    is_reduce: bool = True,
    **kw,
) -> roofline.RooflineResult:
  """Shared roofline model for ring-style collectives over `axes`.

  Args:
    ctx: Rule context; `ctx.mesh` must be set.
    *args: Ignored.
    axes: Names of the mesh axes the collective runs over.
    is_reduce: When true, the starting shard size is the input byte size
      divided by the product of the axis sizes; otherwise it is the full input
      byte size.
    **kw: Ignored.

  Returns:
    A `RooflineResult` in which every axis in `axes` is assigned the same
    `ici_bytes` and the same `ici_latency`. If the axes together have size one
    (or less), the all-zero result from `_return_zeros_if_one_sized_axis` is
    returned instead.
  """
  if trivial_result := _return_zeros_if_one_sized_axis(ctx, axes):
    return trivial_result

  assert ctx.mesh
  mesh_shape = ctx.mesh.shape

  shard_bytes = roofline.RooflineShape.total_bytes(ctx.avals_in)
  if is_reduce:
    shard_bytes /= np.prod([mesh_shape[name] for name in axes])

  # The slowest color is modeled as the bottleneck, so walk the axes from the
  # largest to the smallest.
  axes_by_size = sorted(axes, key=lambda name: mesh_shape[name], reverse=True)
  num_axes = len(axes_by_size)

  # Phase split: the data is divided evenly between the axes.
  shard_bytes //= num_axes

  bytes_moved = 0
  for name in axes_by_size:
    ring_size = mesh_shape[name]
    # One phase sends the current shard to each of the other ring members.
    bytes_moved += shard_bytes * (ring_size - 1)
    # After the phase the shard has grown by the size of the ring.
    shard_bytes *= ring_size

  # The longest axis bounds the latency.
  hops = mesh_shape[axes_by_size[0]] * num_axes

  return roofline.RooflineResult(
      ici_bytes={name: int(bytes_moved) for name in axes_by_size},
      ici_latency={name: int(hops) for name in axes_by_size},
  )
|
||||
|
||||
|
||||
# Both primitives carry their mesh axes in the `axis_name` parameter, whereas
# the shared ring rule expects them as `axes`; the lambdas do the renaming.

# reduce_scatter uses the ring rule's default, `is_reduce=True`.
roofline.register_roofline(lax_parallel.reduce_scatter_p)(
    lambda *operands, axis_name, **params: _ring_collective_roofline(
        *operands, axes=axis_name, **params
    )
)

# all_gather does not reduce, so the ring rule is told `is_reduce=False`.
roofline.register_roofline(lax_parallel.all_gather_p)(
    lambda *operands, axis_name, **params: _ring_collective_roofline(
        *operands, axes=axis_name, is_reduce=False, **params
    )
)
|
||||
|
||||
|
||||
def _calculate_gather_flops(
    mode: slicing.GatherScatterMode,
    indices_size: int,
    output_size: int,
) -> int:
  """Returns the unfused flop count attributed to a `gather`.

  Args:
    mode: The gather's out-of-bounds handling mode.
    indices_size: Number of elements in the gather indices.
    output_size: Number of elements in the gather output.

  Returns:
    `4 * indices_size + output_size` for `FILL_OR_DROP`, and 0 for every other
    mode.
  """
  if mode != slicing.GatherScatterMode.FILL_OR_DROP:
    return 0

  # Deciding whether to fill or drop takes four elementwise passes over the
  # indices:
  #   1. compare each index against the upper bound,
  #   2. compare each index against the lower bound,
  #   3. `and` the two comparisons to get "in bounds",
  #   4. `reduce` to a single boolean per window.
  bounds_check_flops = 4 * indices_size

  # The per-window boolean is then broadcast to the output shape (free) and
  # used to choose, element by element, between the gathered value and the
  # fill value. That choice is one elementwise op over the output.
  masking_flops = output_size

  return bounds_check_flops + masking_flops
|
||||
|
||||
|
||||
@roofline.register_roofline(slicing.gather_p)
def _gather_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    mode: slicing.GatherScatterMode,
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule registered for the `gather` primitive.

  Flops come from `_calculate_gather_flops`. Positional `args` and the
  remaining keyword parameters are accepted but not used.
  """
  # Two operands are expected; only the indices contribute to the estimate.
  _, indices = map(roofline.RooflineShape.from_aval, ctx.avals_in)
  out = roofline.RooflineShape.from_aval(ctx.avals_out[0])

  # A gather does not touch the whole operand buffer. It is modeled as a copy
  # the size of the output (one read plus one write) and a read of the
  # indices.
  copy_bytes = out.dtype.itemsize * out.size * 2
  index_bytes = indices.dtype.itemsize * indices.size

  return roofline.RooflineResult(
      unfused_flops=_calculate_gather_flops(mode, indices.size, out.size),
      unfused_hbm_bytes=copy_bytes + index_bytes,
  )
|
||||
|
||||
|
||||
def _scatter_roofline(
|
||||
ctx: roofline.RooflineRuleContext,
|
||||
*args,
|
||||
**kw,
|
||||
) -> roofline.RooflineResult:
|
||||
"""Roofline for Jax's `scatter*` primitives.
|
||||
|
||||
The `scatter` functionality itself is a simple data read and write, which
|
||||
contributes 0 flops.
|
||||
|
||||
But, the jaxpr for each `scatter*` function (aside from `jax.lax.scatter`)
|
||||
contains an `update_jaxpr` that gets applied to the operand & scattered
|
||||
updates (e.g. `add` for `scatter_add`, or arbitrary unary function for
|
||||
`scatter_apply`), which *does* contribute flops. This `update_jaxpr` gets
|
||||
applied to every element of the scattered updates.
|
||||
|
||||
Thus,
|
||||
flops = [# flops for `update_jaxpr` per element] * [# elements in `updates`].
|
||||
|
||||
To calculate # flops for `update_jaxpr`, we convert the `update_jaxpr` back to
|
||||
a callable, and then call `roofline` on that callable. `update_jaxpr` does not
|
||||
contain any information about input shapes or dtypes; it expects scalars. It
|
||||
will therefore give us a # flops-per-element result, which we multiply by
|
||||
the size of the updates to get the total flops.
|
||||
"""
|
||||
(_, indices, updates) = (
|
||||
roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in
|
||||
)
|
||||
|
||||
update_jaxpr = kw.get('update_jaxpr')
|
||||
|
||||
flops = 0
|
||||
if update_jaxpr:
|
||||
update_fn = lambda *inputs: core.eval_jaxpr(update_jaxpr, [], *inputs)
|
||||
# Create dummy scalar inputs.
|
||||
dummy_inputs = [
|
||||
api.ShapeDtypeStruct((), updates.dtype) for _ in update_jaxpr.invars
|
||||
]
|
||||
# Calculate the flops for the `update_jaxpr` on scalar inputs.
|
||||
_, roofline_result = roofline.roofline(update_fn)(*dummy_inputs)
|
||||
# Multiply by the size of the updates to get the total flops.
|
||||
flops = roofline_result.unfused_flops * updates.size
|
||||
|
||||
return roofline.RooflineResult(
|
||||
unfused_flops=flops,
|
||||
# Scatter accesses the equivalent of 3N update shapes (input, output, and
|
||||
# updates), and the scatter indices.
|
||||
unfused_hbm_bytes=(
|
||||
3 * updates.dtype.itemsize * updates.size
|
||||
+ indices.dtype.itemsize * indices.size
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# All scatter primitives share `_scatter_roofline`.
for _scatter_prim in (
    slicing.scatter_add_p,
    slicing.scatter_max_p,
    slicing.scatter_min_p,
    slicing.scatter_mul_p,
    slicing.scatter_sub_p,
    # Also registers `jax.lax.scatter_apply`, which uses the `scatter_p`
    # primitive.
    slicing.scatter_p,
):
  roofline.register_roofline(_scatter_prim)(_scatter_roofline)
# Do not leave the loop variable behind as a module attribute.
del _scatter_prim
|
||||
|
||||
def _scalar_collective_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axes: tuple[str, ...],
    **kw,
) -> roofline.RooflineResult:
  """Ring-collective roofline evaluated on one-element stand-ins.

  Each input aval is swapped for a `(1,)`-shaped array of the same dtype, and
  the resulting context is handed to `_ring_collective_roofline` with
  `is_reduce=False`.
  """
  one_element_avals = [
      core.ShapedArray((1,), roofline.RooflineShape.from_aval(aval).dtype)
      for aval in ctx.avals_in
  ]
  scalar_ctx = replace(ctx, avals_in=one_element_avals)
  return _ring_collective_roofline(
      scalar_ctx, *args, axes=axes, is_reduce=False, **kw
  )
|
||||
|
||||
|
||||
# pmin and pmax share the scalar collective rule.
for _scalar_prim in (lax_parallel.pmin_p, lax_parallel.pmax_p):
  roofline.register_roofline(_scalar_prim)(_scalar_collective_roofline)
# Do not leave the loop variable behind as a module attribute.
del _scalar_prim
|
||||
|
||||
|
||||
@roofline.register_roofline(lax_parallel.psum_invariant_p)
def _psum2_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axes: tuple[str, ...],
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule registered for the `psum_invariant` primitive.

  The cost is twice that of `_ring_collective_roofline` called with its
  default `is_reduce`: both the per-axis bytes and the per-axis latency of
  the ring result are doubled.
  """
  ring = _ring_collective_roofline(ctx, *args, axes=axes, **kw)

  return roofline.RooflineResult(
      ici_bytes={axis: 2 * nbytes for axis, nbytes in ring.ici_bytes.items()},
      ici_latency={axis: 2 * hops for axis, hops in ring.ici_latency.items()},
  )
|
||||
|
||||
|
||||
@roofline.register_roofline(lax_parallel.all_to_all_p)
def _all_to_all_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axis_name: tuple[str, ...],
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule registered for the `all_to_all` primitive.

  Traffic is estimated from a bisection-bandwidth argument using the smallest
  participating axis; latency is the sum of half of each axis size. Every
  axis in `axis_name` is assigned the same bytes and the same latency. When
  the axes together have size one (or less) the all-zero result from
  `_return_zeros_if_one_sized_axis` is returned.
  """
  if trivial_result := _return_zeros_if_one_sized_axis(ctx, axis_name):
    return trivial_result

  assert ctx.mesh
  mesh = ctx.mesh.shape
  axis_sizes = [mesh[axis] for axis in axis_name]

  total_size = roofline.RooflineShape.total_bytes(ctx.avals_in) * np.prod(
      axis_sizes
  )

  # The cut is sized by the smallest participating axis.
  smallest_size = min(axis_sizes)
  bisection_bw = smallest_size ** (len(axis_name) - 1)
  if smallest_size > 2:
    # Wraparound links double the bandwidth across the cut.
    bisection_bw *= 2

  # On average, half of the data has to cross the bisection.
  bytes_across = total_size / 2 / bisection_bw

  # Latency is the maximum number of hops across the mesh.
  max_hops = sum(axis_size / 2 for axis_size in axis_sizes)

  return roofline.RooflineResult(
      ici_bytes={axis: int(bytes_across) for axis in axis_name},
      ici_latency={axis: int(max_hops) for axis in axis_name},
  )
|
||||
|
||||
|
||||
@roofline.register_roofline(lax_parallel.ppermute_p)
def _ppermute_roofline(
    ctx: roofline.RooflineRuleContext,
    *args,
    axis_name: tuple[str, ...],
    perm: tuple[tuple[int, int], ...],
    **kw,
) -> roofline.RooflineResult:
  """Roofline rule registered for the `ppermute` primitive.

  Each `(src, dst)` pair in `perm` is a pair of linearized device indices over
  the axes in `axis_name`. For every pair, the path is walked one axis at a
  time, taking the shorter way around (wraparound allowed), and each link it
  crosses is counted.

  Returns:
    A `RooflineResult` in which every axis in `axis_name` is assigned
    `ici_bytes` equal to the input byte size times the count of the most-used
    link, and `ici_latency` equal to the largest hop count of any single pair.
    When the axes together have size one (or less) the all-zero result from
    `_return_zeros_if_one_sized_axis` is returned.
  """
  if trivial_result := _return_zeros_if_one_sized_axis(ctx, axis_name):
    return trivial_result

  assert ctx.mesh
  mesh = ctx.mesh.shape
  mesh_dims: list[int] = [mesh.get(axis, 1) for axis in axis_name]
  shard_size = roofline.RooflineShape.total_bytes(ctx.avals_in)

  # Number of transfers crossing each link; a link is keyed by the sorted pair
  # of coordinates of its two endpoints.
  link_usage: dict[tuple[tuple[int, ...], ...], float] = defaultdict(float)
  worst_hops = 0

  for src, dst in perm:
    if src == dst:
      continue
    # Perms are linearized, so recover the per-axis coordinates.
    src_coords = tuple(int(c) for c in np.unravel_index(src, mesh_dims))
    dst_coords = tuple(int(c) for c in np.unravel_index(dst, mesh_dims))

    hops_for_pair = 0

    per_dim = zip(mesh_dims, src_coords, dst_coords)
    for dim, (dim_size, src_pos, dst_pos) in enumerate(per_dim):
      if src_pos == dst_pos:
        continue

      # Distances in both directions, with wraparound.
      clockwise_dist = (dst_pos - src_pos) % dim_size
      counter_dist = (src_pos - dst_pos) % dim_size
      step = 1 if clockwise_dist <= counter_dist else -1

      # Walk hop by hop along this axis; the coordinates on all other axes
      # are taken from the source.
      pos = src_pos
      while pos != dst_pos:
        here = util.tuple_update(src_coords, dim, pos)
        following_pos = (pos + step) % dim_size
        there = util.tuple_update(here, dim, following_pos)
        link_usage[tuple(sorted([here, there]))] += 1
        pos = following_pos

      hops_for_pair += min(clockwise_dist, counter_dist)

    worst_hops = max(worst_hops, hops_for_pair)

  # The busiest link determines how much data has to be serialized.
  busiest_link_bytes = shard_size * max(link_usage.values(), default=0)
  return roofline.RooflineResult(
      ici_bytes={axis: int(busiest_link_bytes) for axis in axis_name},
      ici_latency={axis: int(worst_hops) for axis in axis_name},
  )
|
||||
|
||||
|
||||
@roofline.register_roofline(lax.reduce_sum_p)
|
||||
def _reduce_sum_p_roofline(
|
||||
ctx: roofline.RooflineRuleContext,
|
||||
*args,
|
||||
axes: tuple[int, ...],
|
||||
**kw,
|
||||
) -> roofline.RooflineResult:
|
||||
(x,) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
|
||||
domain_size = np.prod([x.shape[i] for i in axes])
|
||||
other_axes = set(range(len(x.shape))) - set(axes)
|
||||
result_size = np.prod([x.shape[i] for i in other_axes])
|
||||
|
||||
return roofline.RooflineResult(
|
||||
# To add n values, we do n - 1 add operations, and we have to do that
|
||||
# for every element in the result.
|
||||
unfused_flops=int((domain_size - 1) * result_size),
|
||||
# Size of input, plus output. (We assume that the output is also used
|
||||
# as accumulator.)
|
||||
unfused_hbm_bytes=int(x.dtype.itemsize * (x.size + result_size)),
|
||||
)
|
||||
|
||||
@roofline.register_roofline(lax.select_n_p)
|
||||
def _select_n_p_roofline(
|
||||
ctx: roofline.RooflineRuleContext,
|
||||
*args,
|
||||
**kw,
|
||||
) -> roofline.RooflineResult:
|
||||
(x, *_) = (roofline.RooflineShape.from_aval(aval) for aval in ctx.avals_in)
|
||||
out = roofline.RooflineShape.from_aval(ctx.avals_out[0])
|
||||
|
||||
return roofline.RooflineResult(
|
||||
unfused_flops=out.size,
|
||||
unfused_hbm_bytes=(
|
||||
x.dtype.itemsize * x.size + out.dtype.itemsize * out.size
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@roofline.register_roofline(callback.pure_callback_p)
|
||||
@roofline.register_roofline(callback.io_callback_p)
|
||||
def _callback_with_output_roofline(
|
||||
ctx: roofline.RooflineRuleContext,
|
||||
*args,
|
||||
**kw,
|
||||
) -> roofline.RooflineResult:
|
||||
avals_in = ctx.avals_in
|
||||
avals_out = ctx.avals_out
|
||||
# HBM bytes for transferring inputs to host and results back to device.
|
||||
hbm_bytes = roofline.RooflineShape.total_bytes(
|
||||
avals_in
|
||||
) + roofline.RooflineShape.total_bytes(avals_out)
|
||||
# We don't have access to the `callback_func`, so we assume it contributes 0
|
||||
# flops.
|
||||
return roofline.RooflineResult(unfused_hbm_bytes=hbm_bytes)
|
||||
|
||||
|
||||
@roofline.register_roofline(debugging.debug_callback_p)
|
||||
def _debug_callback_roofline(
|
||||
ctx: roofline.RooflineRuleContext,
|
||||
*args,
|
||||
**kw,
|
||||
) -> roofline.RooflineResult:
|
||||
avals_in = ctx.avals_in
|
||||
# `debug_callback` does not return values to the JAX program, so only input
|
||||
# HBM bytes are considered.
|
||||
hbm_bytes = roofline.RooflineShape.total_bytes(avals_in)
|
||||
# We don't have access to the `callback_func`, so we assume it contributes 0
|
||||
# flops.
|
||||
return roofline.RooflineResult(unfused_hbm_bytes=hbm_bytes)
|
||||
Reference in New Issue
Block a user