# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for JAX debugging primitives and related functionality."""
from __future__ import annotations

from collections.abc import Callable, Sequence
import copy
from functools import partial
import importlib.util
import logging
import string
import sys
from typing import Any, Union

import numpy as np

from jax._src import api
from jax._src import callback as cb
from jax._src import config
from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import lax
from jax._src import mesh as mesh_lib
from jax._src import shard_map
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy as jnp
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
    NamedSharding, PartitionSpec as P, parse_flatten_op_sharding)
from jax._src.state import discharge as state_discharge

logger = logging.getLogger(__name__)


class DebugEffect(effects.Effect):
  __str__ = lambda self: "Debug"
debug_effect = DebugEffect()


class OrderedDebugEffect(effects.Effect):
  __str__ = lambda self: "OrderedDebug"

ordered_debug_effect = OrderedDebugEffect()
effects.ordered_effects.add_type(OrderedDebugEffect)
effects.lowerable_effects.add_type(DebugEffect)
effects.lowerable_effects.add_type(OrderedDebugEffect)
effects.control_flow_allowed_effects.add_type(DebugEffect)
effects.control_flow_allowed_effects.add_type(OrderedDebugEffect)
effects.remat_allowed_effects.add_type(DebugEffect)
effects.remat_allowed_effects.add_type(OrderedDebugEffect)
effects.custom_derivatives_allowed_effects.add_type(DebugEffect)
effects.custom_derivatives_allowed_effects.add_type(OrderedDebugEffect)
effects.partial_eval_kept_effects.add_type(DebugEffect)
effects.partial_eval_kept_effects.add_type(OrderedDebugEffect)
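
# Note on what the ordered variant buys you: ordered debug effects are
# threaded through the computation via tokens, so two
# `jax.debug.print(..., ordered=True)` calls inside a single jitted function
# produce output in program order, whereas unordered debug prints may be
# reordered by the compiler.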

# `debug_callback_p` is the main primitive for staging out Python callbacks.
debug_callback_p = core.Primitive('debug_callback')
debug_callback_p.multiple_results = True

map, unsafe_map = util.safe_map, map

@debug_callback_p.def_impl
def debug_callback_impl(*args, callback: Callable[..., Any],
                        effect: DebugEffect, partitioned: bool):
  del effect, partitioned
  try:
    cpu_device, *_ = xla_bridge.local_devices(backend="cpu")
  except RuntimeError as e:
    raise RuntimeError(
        "jax.debug.callback failed to find a local CPU device to place the"
        " inputs on. Make sure \"cpu\" is listed in --jax_platforms or the"
        " JAX_PLATFORMS environment variable."
    ) from e
  args = api.device_put(args, cpu_device)
  with (config.default_device(cpu_device),
        sharding_impls._internal_use_concrete_mesh(mesh_lib.empty_concrete_mesh),
        mesh_lib.use_abstract_mesh(mesh_lib.empty_abstract_mesh)):
    try:
      callback(*args)
    except BaseException:
      logger.exception("jax.debug.callback failed")
      raise
  return ()

@debug_callback_p.def_effectful_abstract_eval
def debug_callback_abstract_eval(*flat_avals, callback: Callable[..., Any],
                                 effect: DebugEffect, partitioned: bool):
  del flat_avals, callback, partitioned
  return [], {effect}


def debug_batching_rule(args, dims, *, primitive, **params):
  """Unrolls the debug callback across the mapped axis."""
  axis_size = next(x.shape[i] for x, i in zip(args, dims)
                   if i is not None)
  # TODO(sharadmv): implement in terms of a rolled loop instead of unrolled.
  def get_arg_at_dim(i, dim, arg):
    if dim is batching.not_mapped:
      # Broadcast unmapped argument
      return arg
    return lax.index_in_dim(arg, i, axis=dim, keepdims=False)
  outs = []
  for i in range(axis_size):
    args_idx = map(partial(get_arg_at_dim, i), dims, args)
    outs.append(primitive.bind(*args_idx, **params))
  outs = [jnp.stack(xs) for xs in zip(*outs)]
  return outs, (0,) * len(outs)
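
# A sketch of the unrolled behavior: under
# `jax.vmap(lambda x: jax.debug.print("{x}", x=x))(jnp.arange(3))`, the rule
# above binds the primitive once per element of the mapped axis, so 0, 1 and 2
# are printed by three separate callbacks rather than one vectorized call.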


batching.primitive_batchers[debug_callback_p] = partial(
    debug_batching_rule, primitive=debug_callback_p
)

def debug_callback_jvp_rule(primals, tangents, **params):
  return debug_callback_p.bind(*primals, **params), []
ad.primitive_jvps[debug_callback_p] = debug_callback_jvp_rule

def debug_callback_transpose_rule(_, *flat_args, callback: Callable[..., Any],
                                  effect: DebugEffect, partitioned):
  del callback, effect, partitioned
  return [None for _ in flat_args]
ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule

def _debug_callback_partial_auto(axis_context, *args, **params):
  partial_auto = list(set(axis_context.mesh.axis_names) - axis_context.manual_axes)
  def f():
    idx = lax.axis_index(*partial_auto)
    return lax.cond(idx == 0,
                    lambda: debug_callback_p.bind(*args, **params),
                    lambda: [])
  return shard_map.shard_map(f, in_specs=(), out_specs=[])()

def debug_callback_lowering(ctx, *args, effect, partitioned, callback, **params):
  axis_context = ctx.module_context.axis_context
  if isinstance(axis_context, sharding_impls.SPMDAxisContext):
    # We're in a shard_map, which might be partial-manual or full-manual.
    partial_auto = set(axis_context.mesh.axis_names) - axis_context.manual_axes
    if partial_auto:
      # If we have partial manual / partial auto sharding, we gather and
      # conditionally run the callback.
      lower = partial(
          _debug_callback_partial_auto,
          axis_context,
          effect=effect,
          partitioned=partitioned,
          callback=callback,
          **params,
      )
      return mlir.lower_fun(lower)(ctx, *args)
    elif set(axis_context.manual_axes) == set(axis_context.mesh.axis_names):
      # If we have fully manual sharding during lowering, that means the JAX
      # program has per-device semantics, so we run the callback on each device.
      if config.use_shardy_partitioner.value:
        sharding = cb._get_sdy_array_list_for_callbacks(ctx.avals_out)
      else:
        sharding = xc.OpSharding()
        sharding.type = xc.OpSharding.Type.MANUAL
    else:
      assert False  # Unreachable
  elif isinstance(axis_context, sharding_impls.ShardingContext):
    # If we have fully automatic sharding during lowering, that means the JAX
    # program has bulk array semantics, so we run the callback with a MAXIMAL
    # sharding and hence execute it only once on the full logical value.
    if config.use_shardy_partitioner.value:
      sharding = sharding_impls.SdyArrayList([
          sharding_impls.SdyArray(
              mesh_shape=(), dim_shardings=[], logical_device_ids=(0,))])
    else:
      sharding = xc.OpSharding()
      sharding.type = xc.OpSharding.Type.MAXIMAL
      sharding.tile_assignment_dimensions = [1]
      sharding.tile_assignment_devices = [0]
  else:
    # When there's no SPMD partitioning going on, don't annotate a sharding.
    sharding = None

  def _callback(*flat_args):
    debug_callback_p.impl(
        *flat_args,
        effect=effect,
        partitioned=partitioned,
        callback=callback,
        **params,
    )
    return ()
  if effects.ordered_effects.contains(effect):
    token = ctx.tokens_in.get(effect)
    result, token, _ = cb.emit_python_callback(
        ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out,
        has_side_effect=True, returns_token=True, partitioned=partitioned)
    ctx.set_tokens_out(mlir.TokenSet({effect: token}))
  else:
    result, _, _ = cb.emit_python_callback(
        ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out,
        has_side_effect=True, returns_token=True, partitioned=partitioned,
        sharding=sharding)
  return result
mlir.register_lowering(debug_callback_p, debug_callback_lowering,
                       platform="cpu")
mlir.register_lowering(
    debug_callback_p, debug_callback_lowering, platform="gpu")
# Debug callbacks use channel IDs on TPU, which require non-caching.
mlir.register_lowering(
    debug_callback_p, debug_callback_lowering, platform="tpu",
    cacheable=False)


def _debug_partial_eval_custom(saveable, unks_in, inst_in, eqn, primitive):
  # The default behavior for effectful primitives is to not stage them out if
  # possible. For debug callbacks, we actually want them to be staged out in
  # order to provide more information to the user. This rule bypasses
  # partial_eval's regular behavior to do that. Specifically, we will stage
  # out the callback if:
  # 1) the policy says debug_callbacks are not saveable, or
  # 2) the policy says debug_callbacks are saveable BUT all of the input
  #    values are instantiated.
  # The purpose is to call back with as much information as possible while
  # avoiding unnecessarily staging out other values.
  if any(unks_in):
    # The usual case (if we have any unknowns, we need to stage it out)
    res = [v for v, inst in zip(eqn.invars, inst_in) if not inst]
    return None, eqn, [], [], res
  if saveable(primitive, *[v.aval for v in eqn.invars], **eqn.params):
    # The policy is telling us we can save the debug callback.
    if all(inst_in):
      # If all of the inputs are instantiated, we also stage out the
      # debug_callback.
      return eqn, eqn, [], [], []
    else:
      # If any are not instantiated, we don't do any extra staging to avoid
      # affecting the computation.
      return eqn, None, [], [], []
  # If the policy says we can't save the debug callback, we follow it and
  # stage out the debug callback.
  return eqn, eqn, [], [], []


pe.partial_eval_jaxpr_custom_rules[debug_callback_p] = partial(
    _debug_partial_eval_custom, primitive=debug_callback_p
)
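
# A practical consequence of staging the callback into both jaxprs above: a
# `jax.debug.print` inside a `jax.checkpoint`/`jax.remat`-ed function may fire
# on both the forward pass and the rematerialized backward pass.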

@state_discharge.register_discharge_rule(debug_callback_p)
def _debug_callback_state_discharge_rule(
    in_avals, out_avals, *args, effect, partitioned, callback, **params
):
  del in_avals, out_avals  # Unused.
  out = debug_callback_p.bind(
      *args, effect=effect, partitioned=partitioned, callback=callback, **params
  )
  return args, out


def _split_callback_args(args, kwargs):
  flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
  static_args, dyn_args = {}, []
  for i, a in enumerate(flat_args):
    try:
      core.shaped_abstractify(a)
      dyn_args.append(a)
    except (AssertionError, TypeError):
      static_args[i] = a
  return in_tree, dyn_args, static_args
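
# For illustration: with `args = (x, "tag")` for a JAX array `x`,
# `core.shaped_abstractify` succeeds for `x` but raises for the string, so `x`
# goes to `dyn_args` while `"tag"` is recorded in `static_args` under its flat
# position; `merge_callback_args` below re-interleaves the two before the user
# callback runs.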

def merge_callback_args(in_tree, dyn_args, static_args):
  static_args_dict = dict(static_args)
  all_args = [None] * (len(static_args) + len(dyn_args))
  di = iter(dyn_args)
  for i in range(len(all_args)):
    if i in static_args_dict:
      all_args[i] = static_args_dict[i]
    else:
      all_args[i] = next(di)
  assert next(di, None) is None
  return tree_util.tree_unflatten(in_tree, all_args)


def _make_flat_callback(in_tree, callback, static_args):
  def _flat_callback(*dyn_args):
    args, kwargs = merge_callback_args(in_tree, dyn_args, static_args)
    callback(*args, **kwargs)
    return ()
  return _flat_callback


debug_print_p = core.Primitive("debug_print")
debug_print_p.multiple_results = True


@debug_print_p.def_impl
def debug_print_impl(
    *args: Any,
    fmt: str,
    ordered,
    partitioned,
    in_tree,
    static_args,
    np_printoptions,
    has_placeholders,
    logging_record,
):
  callback = partial(
      _format_print_callback, fmt, dict(np_printoptions), has_placeholders,
      logging_record,
  )
  callback = _make_flat_callback(in_tree, callback, static_args)
  effect = ordered_debug_effect if ordered else debug_effect
  debug_callback_impl(
      *args, callback=callback, effect=effect, partitioned=partitioned
  )
  return ()


@debug_print_p.def_effectful_abstract_eval
def debug_print_abstract_eval(*avals: Any, fmt: str, ordered, **kwargs):
  del avals, fmt, kwargs  # Unused.
  effect = ordered_debug_effect if ordered else debug_effect
  return [], {effect}


batching.primitive_batchers[debug_print_p] = partial(
    debug_batching_rule, primitive=debug_print_p
)


def debug_print_jvp_rule(primals, tangents, **params):
  return debug_print_p.bind(*primals, **params), []


ad.primitive_jvps[debug_print_p] = debug_print_jvp_rule


def debug_print_transpose_rule(_, *args, **kwargs):
  del kwargs
  return [None for _ in args]


ad.primitive_transposes[debug_print_p] = debug_print_transpose_rule


def debug_print_lowering_rule(
    ctx,
    *dyn_args,
    fmt,
    ordered,
    partitioned,
    in_tree,
    static_args,
    np_printoptions,
    has_placeholders,
    logging_record,
):
  callback = partial(
      _format_print_callback,
      fmt,
      dict(np_printoptions),
      has_placeholders,
      logging_record,
  )
  callback = _make_flat_callback(in_tree, callback, static_args)
  effect = ordered_debug_effect if ordered else debug_effect
  return debug_callback_lowering(
      ctx, *dyn_args, effect=effect, partitioned=partitioned, callback=callback
  )


mlir.register_lowering(debug_print_p, debug_print_lowering_rule, platform="cpu")
mlir.register_lowering(debug_print_p, debug_print_lowering_rule, platform="gpu")
mlir.register_lowering(
    debug_print_p, debug_print_lowering_rule, platform="tpu", cacheable=False
)

pe.partial_eval_jaxpr_custom_rules[debug_print_p] = partial(
    _debug_partial_eval_custom, primitive=debug_print_p
)


@state_discharge.register_discharge_rule(debug_print_p)
def _debug_print_state_discharge_rule(in_avals, out_avals, *args, **kwargs):
  del in_avals, out_avals  # Unused.
  out = debug_print_p.bind(*args, **kwargs)
  return args, out


def debug_callback(
    callback: Callable[..., None],
    *args: Any,
    ordered: bool = False,
    partitioned: bool = False,
    **kwargs: Any,
) -> None:
  """Calls a stageable Python callback.

  For more explanation, see `External Callbacks`_.

  ``jax.debug.callback`` enables you to pass in a Python function that can be
  called inside of a staged JAX program. A ``jax.debug.callback`` follows the
  existing *pure* operational semantics of JAX transformations, which are
  therefore unaware of side effects. This means the effect could be dropped,
  duplicated, or potentially reordered in the presence of higher-order
  primitives and transformations.

  We want this behavior because we'd like ``jax.debug.callback`` to be
  "innocuous", i.e. we want these primitives to change the JAX computation as
  little as possible while revealing as much about them as possible, such as
  which parts of the computation are duplicated or dropped.
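
  A small usage sketch, where ``log_mean`` stands in for any user function::

    def log_mean(x):
      print("mean:", x.mean())

    @jax.jit
    def f(x):
      jax.debug.callback(log_mean, x)
      return x * 2.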

  Args:
    callback: A Python callable returning None.
    *args: The positional arguments to the callback.
    ordered: A keyword only argument used to indicate whether or not the
      staged out computation will enforce ordering of this callback w.r.t.
      other ordered callbacks.
    partitioned: If True, the callback is invoked with local shards only; this
      option avoids an all-gather of the operands. If False, the callback is
      invoked with the logical (global) operands; this option requires an
      all-gather of the operands first.
    **kwargs: The keyword arguments to the callback.

  Returns:
    None

  See Also:
    - :func:`jax.experimental.io_callback`: callback designed for impure functions.
    - :func:`jax.pure_callback`: callback designed for pure functions.
    - :func:`jax.debug.print`: callback designed for printing.

  .. _External Callbacks: https://docs.jax.dev/en/latest/notebooks/external_callbacks.html
  """
  if not callable(callback):
    raise TypeError("first argument to jax.debug.callback must be callable, "
                    f"but got an object of type {type(callback)}")
  in_tree, dyn_args, static_args = _split_callback_args(args, kwargs)

  def _flat_callback(*dyn_args):
    all_args = [None] * (len(static_args) + len(dyn_args))
    di = iter(dyn_args)
    for i in range(len(all_args)):
      if i in static_args:
        all_args[i] = static_args[i]
      else:
        all_args[i] = next(di)
    assert next(di, None) is None
    args, kwargs = tree_util.tree_unflatten(in_tree, all_args)
    callback(*args, **kwargs)
    return ()

  effect = ordered_debug_effect if ordered else debug_effect
  debug_callback_p.bind(
      *dyn_args, callback=_flat_callback, effect=effect, partitioned=partitioned
  )

class _DebugPrintFormatChecker(string.Formatter):

  def format_field(self, value, format_spec):
    del value, format_spec
    return ""  # No formatting is done.

  def check_unused_args(self, used_args, args, kwargs):
    unused_args = [arg for i, arg in enumerate(args) if i not in used_args]
    unused_kwargs = [k for k in kwargs if k not in used_args]
    if unused_args:
      raise ValueError(
          f"Unused positional arguments to `jax.debug.print`: {unused_args}")
    if unused_kwargs:
      raise ValueError(
          f"Unused keyword arguments to `jax.debug.print`: {unused_kwargs}. "
          "You may be passing an f-string (i.e., `f\"{x}\"`) into "
          "`jax.debug.print` and instead should pass in a regular string.")

formatter = _DebugPrintFormatChecker()


def _format_print_callback(
    fmt: str, np_printoptions, has_placeholders, logging_record, *args, **kwargs
):
  if has_placeholders:
    with np.printoptions(**np_printoptions):
      msg = fmt.format(*args, **kwargs)
  else:
    assert not kwargs, "Format without placeholders should not have kwargs."
    msg = " ".join((fmt, *(str(a) for a in args)))
  if logging_record:
    logging_record = copy.copy(logging_record)
    logging_record.msg = msg
    logger.handle(logging_record)
  else:
    sys.stdout.write(msg + "\n")


def _make_logging_record(level):
  si = source_info_util.current()
  user_frame = source_info_util.user_frame(si.traceback)

  file_name = "(unknown file)"
  line_no = 0
  if user_frame:
    file_name = user_frame.file_name
    line_no = user_frame.start_line
  args = ()
  return logger.makeRecord(
      logger.name, level, file_name, line_no, "", args, None
  )


def debug_print(
    fmt: str,
    *args,
    ordered: bool = False,
    partitioned: bool = False,
    skip_format_check: bool = False,
    _use_logging: bool = False,
    **kwargs,
) -> None:
  """Prints values and works in staged out JAX functions.

  This function does *not* work with f-strings because formatting is delayed.
  So instead of ``jax.debug.print(f"hello {bar}")``, write
  ``jax.debug.print("hello {bar}", bar=bar)``.
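
  For example, inside a jitted function::

    @jax.jit
    def f(x):
      jax.debug.print("x: {x}", x=x)
      return x * 2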

  This function is a thin convenience wrapper around :func:`jax.debug.callback`.
  The implementation is essentially::

    def debug_print(fmt: str, *args, **kwargs):
      jax.debug.callback(
          lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
          *args, **kwargs)

  It may be useful to call :func:`jax.debug.callback` directly instead of this
  convenience wrapper. For example, to get debug printing in logs, you might
  use :func:`jax.debug.callback` together with ``logging.log``.

  Args:
    fmt: A format string, e.g. ``"hello {x}"``, that will be used to format
      input arguments, like ``str.format``. See the Python docs on `string
      formatting <https://docs.python.org/3/library/stdtypes.html#str.format>`_
      and `format string syntax
      <https://docs.python.org/3/library/string.html#formatstrings>`_.
    *args: A list of positional arguments to be formatted, as if passed to
      ``fmt.format``.
    ordered: A keyword only argument used to indicate whether or not the staged
      out computation will enforce ordering of this ``jax.debug.print`` w.r.t.
      other ordered ``jax.debug.print`` calls.
    partitioned: If True, then print local shards only; this option avoids an
      all-gather of the operands. If False, print with logical operands; this
      option requires an all-gather of the operands first.
    skip_format_check: If True, the format string is not checked. This is
      useful when calling this function from inside a Pallas TPU kernel, where
      scalar args will be printed after the format string.
    **kwargs: Additional keyword arguments to be formatted, as if passed to
      ``fmt.format``.
  """
  if not skip_format_check:
    # Check that we provide the correct arguments to be formatted.
    formatter.format(fmt, *args, **kwargs)
  has_placeholders = False
  if fmt:
    _, field_name, *_ = next(iter(string.Formatter().parse(fmt)))
    has_placeholders = field_name is not None
  in_tree, dyn_args, static_args = _split_callback_args(args, kwargs)
  static_args = tuple(static_args.items())
  np_printoptions = tuple(np.get_printoptions().items())

  debug_print_p.bind(
      *dyn_args,
      fmt=fmt,
      ordered=ordered,
      partitioned=partitioned,
      in_tree=in_tree,
      static_args=static_args,
      np_printoptions=np_printoptions,
      has_placeholders=has_placeholders,
      logging_record=(_make_logging_record(logging.INFO) if _use_logging
                      else None),
  )


debug_log = partial(debug_print, _use_logging=True)

# Sharding visualization

inspect_sharding_p = core.Primitive("inspect_sharding")
inspect_sharding_p.multiple_results = True
dispatch.prim_requires_devices_during_lowering.add(inspect_sharding_p)

def _inspect_sharding_impl(value, *, callback):
  callback(value.sharding)
  return []
inspect_sharding_p.def_impl(_inspect_sharding_impl)

def _inspect_sharding_abstract_eval(aval, **_):
  del aval
  # Having an effectful abstract eval keeps the primitive from being DCE'd.
  return [], {debug_effect}
inspect_sharding_p.def_effectful_abstract_eval(_inspect_sharding_abstract_eval)

def _inspect_sharding_batching_rule(args, _, *, callback):
  value, = args
  inspect_sharding_p.bind(value, callback=callback)
  return [], []
batching.primitive_batchers[inspect_sharding_p] = (
    _inspect_sharding_batching_rule)

def _inspect_sharding_jvp_rule(primals, _, **params):
  return inspect_sharding_p.bind(*primals, **params), []
ad.primitive_jvps[inspect_sharding_p] = _inspect_sharding_jvp_rule

_INSPECT_SHARDING_CALL_NAME = "InspectSharding"

def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
                                    callback):

  mesh = mesh_lib.thread_resources.env.physical_mesh
  axis_context = ctx.module_context.axis_context

  if isinstance(axis_context, sharding_impls.ShardingContext):
    devices = axis_context.device_assignment
    if devices is None:
      raise AssertionError(
          'Please file a bug at https://github.com/jax-ml/jax/issues')
    am = axis_context.abstract_mesh
    if am is not None:
      mesh = mesh_lib.Mesh(np.array(devices).reshape(am.axis_sizes),
                           am.axis_names)
  elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
    mesh = axis_context.mesh
    devices = axis_context.mesh._flat_devices_tuple
  else:
    raise NotImplementedError(type(axis_context))
  assert devices is not None

  # If we have a nontrivial parallel computation, we need to wait until the
  # SPMD partitioner calls back with the `HloSharding`.
  def _hlo_sharding_callback(hlo_sharding: xc.HloSharding):
    if mesh.empty:
      return callback(
          sharding_impls.GSPMDSharding(devices, hlo_sharding))
    pspec = (P() if hlo_sharding.is_manual() else
             parse_flatten_op_sharding(hlo_sharding, mesh)[0])
    return callback(NamedSharding(mesh, pspec))

  if len(devices) == 1:
    # If we only have one device in our computation, we can construct a
    # replicated HloSharding and call it right now.
    _hlo_sharding_callback(sharding_impls.replicated_hlo_sharding)
    return []

  key = xc.encode_inspect_sharding_callback(_hlo_sharding_callback)
  # We need to make sure `_hlo_sharding_callback` is still alive when the SPMD
  # partitioner runs, so we keep it alive by attaching it to the executable.
  ctx.module_context.add_keepalive(_hlo_sharding_callback)

  hlo.CustomCallOp([value.type], [value],
                   call_target_name=ir.StringAttr.get(
                       _INSPECT_SHARDING_CALL_NAME),
                   has_side_effect=ir.BoolAttr.get(True),
                   api_version=mlir.i32_attr(1),
                   called_computations=ir.ArrayAttr.get([]),
                   backend_config=ir.StringAttr.get(key),
                   operand_layouts=None,
                   result_layouts=None)
  return []
mlir.register_lowering(inspect_sharding_p, _inspect_sharding_lowering_rule)

def _slice_to_chunk_idx(size: int, slc: slice) -> int:
  if slc.start is None and slc.stop is None:
    return 0
  slice_size = slc.stop - slc.start
  assert slc.start % slice_size == 0
  assert size % slice_size == 0
  return slc.start // slice_size
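
# For example, `_slice_to_chunk_idx(8, slice(4, 8))` returns 1: the slice
# selects the second of two equal-sized chunks along an axis of size 8.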

def _raise_to_slice(slc: slice | int):
  if isinstance(slc, int):
    return slice(slc, slc + 1)
  return slc

Color = Union[tuple[float, float, float], str]
ColorMap = Callable[[float], tuple[float, float, float, float]]

def _canonicalize_color(color: Color) -> str:
  if isinstance(color, str):
    return color
  r, g, b = (int(a * 255) for a in color)
  return f"#{r:02X}{g:02X}{b:02X}"

def _get_text_color(color: str) -> str:
  r, g, b = map(lambda x: int(x, 16), (color[1:3], color[3:5], color[5:7]))
  # Approximate the perceived brightness of the background with the BT.601
  # luma weights; use black text on light backgrounds, white on dark ones.
  if (r * 0.299 + g * 0.587 + b * 0.114) > 186:
    return "#000000"
  return "#ffffff"

def make_color_iter(color_map, num_rows, num_cols):
  num_colors = num_rows * num_cols
  color_values = np.linspace(0, 1, num_colors)
  idx = 0
  for _ in range(num_colors):
    yield color_map(color_values[idx])
    idx = (idx + num_colors // 2 + bool(num_colors % 2 == 0)) % num_colors
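
# The index hops by roughly half the palette on each step so that neighboring
# cells in the grid get visually distant colors; with five colors, for
# example, the visit order is 0, 2, 4, 1, 3.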

def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
                       use_color: bool = True, scale: float = 1.,
                       min_width: int = 9, max_width: int = 80,
                       color_map: ColorMap | None = None):
  """Visualizes a ``Sharding`` using ``rich``."""
  if not importlib.util.find_spec("rich"):
    raise ValueError("`visualize_sharding` requires `rich` to be installed.")

  # These imports are local so that they don't affect JAX import times.
  # pytype: disable=import-error
  import rich.align
  import rich.console
  import rich.box
  import rich.padding
  import rich.style
  import rich.table
  # pytype: enable=import-error

  if len(shape) > 2 or len(shape) < 1:
    raise ValueError(
        "`visualize_sharding` only works for shapes with 1 or 2 dimensions.")
  console = rich.console.Console(width=max_width)
  use_color = use_color and console.color_system is not None
  if use_color and not color_map:
    try:
      import matplotlib as mpl  # pytype: disable=import-error
      color_map = mpl.colormaps["tab20b"]
    except ModuleNotFoundError:
      use_color = False

  base_height = int(10 * scale)
  aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0]
  base_width = int(base_height * aspect_ratio)
  height_to_width_ratio = 2.5

  # Grab the device kind from the first device.
  device_kind = next(iter(sharding.device_set)).platform.upper()

  device_indices_map = sharding.devices_indices_map(tuple(shape))
  slices: dict[tuple[int, ...], set[int]] = {}
  heights: dict[tuple[int, ...], float | None] = {}
  widths: dict[tuple[int, ...], float] = {}

  for i, (dev, slcs) in enumerate(device_indices_map.items()):
    assert slcs is not None
    slcs = tuple(map(_raise_to_slice, slcs))
    chunk_idxs = tuple(map(_slice_to_chunk_idx, shape, slcs))
    if slcs is None:
      raise NotImplementedError
    if len(slcs) == 2:
      vert, horiz = slcs
      vert_size = ((vert.stop - vert.start) if vert.stop is not None
                   else shape[0])
      horiz_size = ((horiz.stop - horiz.start) if horiz.stop is not None
                    else shape[1])
      chunk_height = vert_size / shape[0]
      chunk_width = horiz_size / shape[1]
      heights[chunk_idxs] = chunk_height
      widths[chunk_idxs] = chunk_width
    else:
      # In the 1D case, we set the height to 1.
      horiz, = slcs
      vert = slice(0, 1, None)
      horiz_size = (
          (horiz.stop - horiz.start) if horiz.stop is not None else shape[0])
      chunk_idxs = (0, *chunk_idxs)
      heights[chunk_idxs] = None
      widths[chunk_idxs] = horiz_size / shape[0]
    slices.setdefault(chunk_idxs, set()).add(dev.id)
  num_rows = max(a[0] for a in slices.keys()) + 1
  if len(list(slices.keys())[0]) == 1:
    num_cols = 1
  else:
    num_cols = max(a[1] for a in slices.keys()) + 1

  color_iter = make_color_iter(color_map, num_rows, num_cols)
  table = rich.table.Table(show_header=False, show_lines=not use_color,
                           padding=0,
                           highlight=not use_color, pad_edge=False,
                           box=rich.box.SQUARE if not use_color else None)
  for i in range(num_rows):
    col = []
    for j in range(num_cols):
      entry = f"{device_kind} " + ",".join([str(s) for s in sorted(slices[i, j])])
      width, maybe_height = widths[i, j], heights[i, j]
      width = int(width * base_width * height_to_width_ratio)
      if maybe_height is None:
        height = 1
      else:
        height = int(maybe_height * base_height)
      width = min(max(width, min_width), max_width)
      left_padding, remainder = divmod(width - len(entry) - 2, 2)
      right_padding = left_padding + remainder
      top_padding, remainder = divmod(height - 2, 2)
      bottom_padding = top_padding + remainder
      if use_color:
        color = _canonicalize_color(next(color_iter)[:3])
        text_color = _get_text_color(color)
        top_padding += 1
        bottom_padding += 1
        left_padding += 1
        right_padding += 1
      else:
        color = None
        text_color = None
      padding = (
          max(top_padding, 0),
          max(right_padding, 0),
          max(bottom_padding, 0),
          max(left_padding, 0),
      )
      col.append(
          rich.padding.Padding(
              rich.align.Align(entry, "center", vertical="middle"), padding,
              style=rich.style.Style(bgcolor=color,
                                     color=text_color)))
    table.add_row(*col)
  console.print(table, end='\n\n')

def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]):
  """Enables inspecting array sharding inside JIT-ted functions.

  This function, when provided with a Pytree of arrays, calls back with each of
  their shardings and works in ``jax.jit``-ted computations, enabling inspecting
  the chosen intermediate shardings.

  The policy for when ``callback`` is called is *as early as possible* when the
  sharding information is available. This means that if
  ``inspect_array_sharding`` is called without any transformations, the
  callback will happen immediately since we have the array and its sharding
  readily available. Inside of a ``jax.jit``, the callback will happen at
  lowering time at the earliest, meaning you can trigger it using the AOT API
  (``jax.jit(f).lower(...)``); when the sharding is determined by XLA's SPMD
  partitioner, the callback instead happens *at compile time*, which you can
  trigger with ``jax.jit(f).lower(...).compile()``. In all cases, the callback
  will be triggered by running the function, since running a function entails
  lowering and compiling it first. However, once the function is compiled and
  cached, the callback will no longer occur.

  This function is experimental and its behavior may change in the future.

  Args:
    value: A Pytree of JAX arrays.
    callback: A callable that takes in a ``Sharding`` and doesn't return a value.

  In the following example, we print out the sharding of an intermediate value
  in a ``jax.jit``-ted computation:

  >>> import jax
  >>> import jax.numpy as jnp
  >>> from jax.sharding import Mesh, PartitionSpec
  >>>
  >>> x = jnp.arange(8, dtype=jnp.float32)
  >>> def f_(x):
  ...   x = jnp.sin(x)
  ...   jax.debug.inspect_array_sharding(x, callback=print)
  ...   return jnp.square(x)
  >>> f = jax.jit(f_, in_shardings=PartitionSpec('dev'),
  ...             out_shardings=PartitionSpec('dev'))
  >>> with jax.set_mesh(Mesh(jax.devices(), ('dev',))):
  ...   f.lower(x).compile()  # doctest: +SKIP
  ...
  NamedSharding(mesh={'dev': 8}, partition_spec=PartitionSpec(('dev',),))
  """
  def _inspect(val):
    inspect_sharding_p.bind(val, callback=callback)
  tree_util.tree_map(_inspect, value)

def visualize_array_sharding(arr, **kwargs):
  """Visualizes an array's sharding."""
  def _visualize(sharding):
    return visualize_sharding(arr.shape, sharding, **kwargs)
  inspect_array_sharding(arr, callback=_visualize)
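
# Typical usage (a sketch): `jax.debug.visualize_array_sharding(x)` renders a
# `rich` table showing which devices hold which shards of `x`; it works both on
# concrete arrays and inside `jax.jit` via `inspect_array_sharding` above.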


# TODO(mattjj): working around an apparent XLA or PjRt bug, remove eventually
def _debug_callback_eager_rule(
    mesh,
    *args,
    callback: Callable[..., Any],
    effect: DebugEffect,
    partitioned: bool,
):
  del effect
  with core.eval_context():
    all_blocks = zip(*map(list, args))
    for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks):
      callback(*blocks)
  return []
shard_map.eager_rules[debug_callback_p] = _debug_callback_eager_rule


def _debug_print_eager_rule(
    mesh,
    *args,
    fmt: str,
    ordered,
    partitioned,
    in_tree,
    static_args,
    np_printoptions,
    has_placeholders,
    logging_record,
):
  del ordered, partitioned
  callback = partial(
      _format_print_callback, fmt, dict(np_printoptions), has_placeholders,
      logging_record,
  )
  callback = _make_flat_callback(in_tree, callback, static_args)
  with core.eval_context():
    all_blocks = zip(*map(list, args))
    for (idx, device), blocks in zip(np.ndenumerate(mesh.devices), all_blocks):
      callback(*blocks)
  return []


shard_map.eager_rules[debug_print_p] = _debug_print_eager_rule