This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,21 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.debugger.core import breakpoint as breakpoint
from jax._src.debugger import cli_debugger
from jax._src.debugger import colab_debugger
from jax._src.debugger import web_debugger
del cli_debugger # For registration only
del colab_debugger # For registration only
del web_debugger # For registration only
@@ -0,0 +1,170 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import cmd
import pprint
import sys
import traceback
from typing import Any, IO
from jax._src.debugger import core as debugger_core
DebuggerFrame = debugger_core.DebuggerFrame
class CliDebugger(cmd.Cmd):
"""A text-based debugger."""
prompt = '(jdb) '
def __init__(self, frames: list[DebuggerFrame], thread_id,
stdin: IO[str] | None = None, stdout: IO[str] | None = None,
completekey: str = "tab"):
super().__init__(stdin=stdin, stdout=stdout, completekey=completekey)
self.use_rawinput = stdin is None
self.frames = frames
self.frame_index = 0
self.thread_id = thread_id
self.intro = 'Entering jdb:'
def current_frame(self):
return self.frames[self.frame_index]
def evaluate(self, expr):
env = {}
curr_frame = self.frames[self.frame_index]
env.update(curr_frame.globals)
env.update(curr_frame.locals)
return eval(expr, {}, env)
def default(self, line):
"""Evaluates an expression."""
try:
print(repr(self.evaluate(line)), file=self.stdout)
except:
self._error_message()
def print_backtrace(self):
backtrace = []
backtrace.append('Traceback:')
for frame in self.frames[::-1]:
backtrace.append(f' File "{frame.filename}", line {frame.lineno}')
if frame.offset is None:
backtrace.append(' <no source>')
else:
line = frame.source[frame.offset]
backtrace.append(f' {line.strip()}')
print("\n".join(backtrace), file=self.stdout)
def print_context(self, num_lines=2):
curr_frame = self.frames[self.frame_index]
context = []
context.append(f'> {curr_frame.filename}({curr_frame.lineno})')
for i, line in enumerate(curr_frame.source):
assert curr_frame.offset is not None
if (curr_frame.offset - 1 - num_lines <= i <=
curr_frame.offset + num_lines):
if i == curr_frame.offset:
context.append(f'-> {line}')
else:
context.append(f' {line}')
print("\n".join(context), file=self.stdout)
def _error_message(self):
exc_info = sys.exc_info()[:2]
msg = traceback.format_exception_only(*exc_info)[-1].strip()
print('***', msg, file=self.stdout)
def do_p(self, arg):
"""p expression
Evaluates and prints the value of an expression
"""
try:
print(repr(self.evaluate(arg)), file=self.stdout)
except:
self._error_message()
def do_pp(self, arg):
"""pp expression
Evaluates and pretty-prints the value of an expression
"""
try:
print(pprint.pformat(self.evaluate(arg)), file=self.stdout)
except:
self._error_message()
def do_up(self, arg, /):
"""u(p)
Move up a stack frame.
"""
del arg # unused
if self.frame_index == len(self.frames) - 1:
print('At topmost frame.', file=self.stdout)
else:
self.frame_index += 1
self.print_context()
do_u = do_up
def do_down(self, arg, /):
"""d(own)
Move down a stack frame.
"""
del arg # unused
if self.frame_index == 0:
print('At bottommost frame.', file=self.stdout)
else:
self.frame_index -= 1
self.print_context()
do_d = do_down
def do_list(self, _):
"""l(ist)
List source code for the current file.
"""
self.print_context(num_lines=5)
do_l = do_list
def do_continue(self, _):
"""c(ont(inue))
Continue the program's execution.
"""
return True
do_c = do_cont = do_continue
def do_quit(self, _):
"""q(uit)\n(exit)
Quit the debugger. The program is given an exit command.
"""
sys.exit(0)
do_q = do_EOF = do_exit = do_quit
def do_where(self, _):
"""w(here)
Prints a stack trace with the most recent frame on the bottom.
'bt' is an alias for this command.
"""
self.print_backtrace()
do_w = do_bt = do_where
def run(self):
while True:
try:
self.cmdloop()
break
except KeyboardInterrupt:
print('--KeyboardInterrupt--', file=sys.stdout)
def run_debugger(frames: list[DebuggerFrame], thread_id: int | None,
**kwargs: Any):
CliDebugger(frames, thread_id, **kwargs).run()
debugger_core.register_debugger("cli", run_debugger, -1)
@@ -0,0 +1,256 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for Colab-specific debugger."""
from __future__ import annotations
import html
import inspect
import traceback
import uuid
from jax._src.debugger import colab_lib
from jax._src.debugger import core as debugger_core
from jax._src.debugger import cli_debugger
# pytype: disable=import-error
if colab_lib.IS_COLAB_ENABLED:
from google.colab import output
try:
import pygments
import pygments.lexers
import pygments.formatters
IS_PYGMENTS_ENABLED = True
except ImportError:
IS_PYGMENTS_ENABLED = False
# pytype: enable=import-error
class CodeViewer(colab_lib.DynamicDOMElement):
"""A mutable DOM element that displays code as HTML."""
def __init__(self, code_: str, highlights: list[int], linenostart: int = 1):
self._code = code_
self._highlights = highlights
self._view = colab_lib.dynamic(colab_lib.div())
self._linenostart = linenostart
def render(self):
self.update_code(
self._code, self._highlights, linenostart=self._linenostart)
def clear(self):
self._view.clear()
def append(self, child):
raise NotImplementedError
def update(self, elem):
self._view.update(elem)
def _highlight_code(self, code: str, highlights, linenostart: int):
is_dark_mode = output.eval_js(
'document.documentElement.matches("[theme=dark]");')
code_style = "monokai" if is_dark_mode else "default"
hl_color = "#4e56b7" if is_dark_mode else "#fff7c1"
if IS_PYGMENTS_ENABLED:
lexer = pygments.lexers.get_lexer_by_name("python")
formatter = pygments.formatters.HtmlFormatter(
full=False,
hl_lines=highlights,
linenos=True,
linenostart=linenostart,
style=code_style)
if hl_color:
formatter.style.highlight_color = hl_color
css_ = formatter.get_style_defs()
code = pygments.highlight(code, lexer, formatter)
else:
return "";
return code, css_
def update_code(self, code_, highlights, *, linenostart: int = 1):
"""Updates the code viewer to use new code."""
self._code = code_
self._view.clear()
code_, css_ = self._highlight_code(self._code, highlights, linenostart)
uuid_ = uuid.uuid4()
code_div = colab_lib.div(
colab_lib.css(css_),
code_,
id=f"code-{uuid_}",
style=colab_lib.style({
"max-height": "500px",
"overflow-y": "scroll",
"background-color": "var(--colab-border-color)",
"padding": "5px 5px 5px 5px",
}))
if highlights:
percent_scroll = highlights[0] / len(self._code.split("\n"))
else:
percent_scroll = 0.
self.update(code_div)
# Scroll to where the line is
output.eval_js("""
console.log("{id}")
var elem = document.getElementById("{id}")
var maxScrollPosition = elem.scrollHeight - elem.clientHeight;
elem.scrollTop = maxScrollPosition * {percent_scroll}
""".format(id=f"code-{uuid_}", percent_scroll=percent_scroll))
class FramePreview(colab_lib.DynamicDOMElement):
"""Displays information about a stack frame."""
def __init__(self, frame):
super().__init__()
self._header = colab_lib.dynamic(
colab_lib.div(colab_lib.pre(colab_lib.code(""))))
self._code_view = CodeViewer("", highlights=[])
self.frame = frame
self._file_cache = {}
def clear(self):
self._header.clear()
self._code_view.clear()
def append(self, child):
raise NotImplementedError
def update(self, elem):
raise NotImplementedError
def update_frame(self, frame):
"""Updates the frame viewer to use a new frame."""
self.frame = frame
lineno = self.frame.lineno or None
filename = self.frame.filename.strip()
if inspect.getmodulename(filename):
if filename not in self._file_cache:
try:
with open(filename) as fp:
self._file_cache[filename] = fp.read()
except FileNotFoundError:
pass
if filename in self._file_cache:
source = self._file_cache[filename]
highlight = lineno
linenostart = 1
else:
source = "\n".join(frame.source)
highlight = min(frame.offset + 1, len(frame.source) - 1)
linenostart = lineno - frame.offset
self._header.clear()
self._header.update(
colab_lib.div(
colab_lib.pre(colab_lib.code(f"{html.escape(filename)}({lineno})")),
style=colab_lib.style({
"padding": "5px 5px 5px 5px",
"background-color": "var(--colab-highlighted-surface-color)",
})))
self._code_view.update_code(source, [highlight], linenostart=linenostart)
def render(self):
self.update_frame(self.frame)
class DebuggerView(colab_lib.DynamicDOMElement):
"""Main view for the Colab debugger."""
def __init__(self, frame, *, log_color=""):
super().__init__()
self._interaction_log = colab_lib.dynamic(colab_lib.div())
self._frame_preview = FramePreview(frame)
self._header = colab_lib.dynamic(
colab_lib.div(
colab_lib.span("Breakpoint"),
style=colab_lib.style({
"background-color": "var(--colab-secondary-surface-color)",
"color": "var(--colab-primary-text-color)",
"padding": "5px 5px 5px 5px",
"font-weight": "bold",
})))
def render(self):
self._header.render()
self._frame_preview.render()
self._interaction_log.render()
def append(self, child):
raise NotImplementedError
def update(self, elem):
raise NotImplementedError
def clear(self):
self._header.clear()
self._interaction_log.clear()
self._frame_preview.clear()
def update_frame(self, frame):
self._frame_preview.update_frame(frame)
def write(self, text):
self._interaction_log.append(colab_lib.pre(text))
def read(self):
raise NotImplementedError()
def readline(self):
with output.use_tags(["stdin"]):
user_input = input() + "\n"
output.clear(output_tags=["stdin"])
return user_input
def isatty(self):
return True
def flush(self):
pass
class ColabDebugger(cli_debugger.CliDebugger):
"""A JAX debugger for a Colab environment."""
def __init__(self,
frames: list[debugger_core.DebuggerFrame],
thread_id: int):
super().__init__(frames, thread_id)
self._debugger_view = DebuggerView(self.current_frame())
self.stdout = self.stdin = self._debugger_view # pyrefly: ignore[bad-assignment]
def do_up(self, arg, /):
super().do_up(arg)
self._debugger_view.update_frame(self.current_frame())
return False
def do_down(self, arg, /):
super().do_down(arg)
self._debugger_view.update_frame(self.current_frame())
return False
def run(self):
self._debugger_view.render()
while True:
if not self.cmdloop():
return
def _run_debugger(frames, thread_id, **kwargs):
try:
ColabDebugger(frames, thread_id, **kwargs).run()
except Exception:
traceback.print_exc()
if colab_lib.IS_COLAB_ENABLED:
debugger_core.register_debugger("colab", _run_debugger, 1)
@@ -0,0 +1,167 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for building interfaces in Colab."""
from __future__ import annotations
import abc
import dataclasses
import functools
import sys
import uuid
from typing import Any, Union
IS_COLAB_ENABLED = "google.colab" in sys.modules
if IS_COLAB_ENABLED:
# pytype: disable=import-error
from google.colab import output
from IPython import display
# pytype: enable=import-error
class DOMElement(metaclass=abc.ABCMeta):
@abc.abstractmethod
def render(self):
pass
Element = Union[DOMElement, str]
class DynamicDOMElement(DOMElement):
"""A DOM element that can be mutated."""
@abc.abstractmethod
def render(self):
pass
@abc.abstractmethod
def append(self, child: DOMElement):
pass
@abc.abstractmethod
def update(self, elem: DOMElement):
pass
@abc.abstractmethod
def clear(self):
pass
@dataclasses.dataclass
class DynamicDiv(DynamicDOMElement):
"""A `div` that can be edited."""
_uuid: str = dataclasses.field(init=False)
_root_elem: DOMElement = dataclasses.field(init=False)
elem: DOMElement | str
def __post_init__(self):
self._uuid = str(uuid.uuid4())
self._rendered = False
self._root_elem = div(id=self.tag)
@property
def tag(self):
return f"tag-{self._uuid}"
def render(self):
if self._rendered:
raise ValueError("Can't call `render` twice.")
self._root_elem.render()
self._rendered = True
if isinstance(self.elem, str):
raise TypeError("Cannot render when self.elem is a string.")
self.append(self.elem)
def append(self, child: DOMElement):
if not self._rendered:
self.render()
with output.use_tags([self.tag]):
with output.redirect_to_element(f"#{self.tag}"):
child.render()
def update(self, elem: DOMElement):
self.clear()
self.elem = elem
self.render()
def clear(self):
output.clear(output_tags=[self.tag])
self._rendered = False
@dataclasses.dataclass
class StaticDOMElement(DOMElement):
"""An immutable DOM element."""
_uuid: str = dataclasses.field(init=False)
name: str
children: list[str | DOMElement]
attrs: dict[str, str]
def html(self):
attr_str = ""
if self.attrs:
attr_str = " " + (" ".join(
[f"{key}=\"{value}\"" for key, value in self.attrs.items()]))
children = []
children = "\n".join([str(c) for c in self.children])
return f"<{self.name}{attr_str}>{children}</{self.name}>"
def render(self):
display.display(display.HTML(self.html()))
def attr(self, key: str) -> str:
return self.attrs[key]
def __str__(self):
return self.html()
def __repr__(self):
return self.html()
def append(self, child: DOMElement) -> DOMElement:
return dataclasses.replace(self, children=[*self.children, child])
def replace(self, **kwargs) -> DOMElement:
return dataclasses.replace(self, **kwargs)
def _style_dict_to_str(style_dict: dict[str, Any]) -> str:
return " ".join([f"{k}: {v};" for k, v in style_dict.items()])
def dynamic(elem: StaticDOMElement) -> DynamicDiv:
return DynamicDiv(elem)
def _make_elem(tag: str, *children: Element, **attrs) -> StaticDOMElement:
"""Helper function for making DOM elements."""
return StaticDOMElement(tag, list(children), attrs)
code = functools.partial(_make_elem, "code")
div = functools.partial(_make_elem, "div")
li = functools.partial(_make_elem, "li")
ol = functools.partial(_make_elem, "ol")
pre = functools.partial(_make_elem, "pre")
progress = functools.partial(_make_elem, "progress")
span = functools.partial(_make_elem, "span")
def css(text: str) -> StaticDOMElement:
return StaticDOMElement("style", [text], {})
def style(*args, **kwargs):
return _style_dict_to_str(dict(*args, **kwargs))
@@ -0,0 +1,230 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Hashable
import dataclasses
import inspect
import threading
from typing import Any, Protocol
from jax._src import callback
from jax._src import core
from jax._src import debugging
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src.lax import lax
@tree_util.register_pytree_node_class
class _DictWrapper:
keys: list[Hashable]
values: list[Any]
def __init__(self, keys, values):
self._keys = keys
self._values = values
def to_dict(self):
return dict(zip(self._keys, self._values))
def tree_flatten(self):
return self._values, self._keys
@classmethod
def tree_unflatten(cls, keys, values):
return _DictWrapper(keys, values)
class _CantFlatten:
__repr__ = lambda _: "<cant_flatten>"
cant_flatten = _CantFlatten()
def _safe_flatten_dict(dct: dict[Any, Any]
) -> tuple[list[Any], tree_util.PyTreeDef]:
# We avoid comparison between keys by just using the original order
keys, values = [], []
for key, value in dct.items():
try:
tree_util.tree_leaves(value)
except:
# If flattening fails, we substitute a sentinel object.
value = cant_flatten
keys.append(key)
values.append(value)
return tree_util.tree_flatten(_DictWrapper(keys, values))
@tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class DebuggerFrame:
"""Encapsulates Python frame information."""
filename: str
locals: dict[str, Any]
globals: dict[str, Any]
code_context: str
source: list[str]
lineno: int
offset: int | None
def tree_flatten(self):
flat_locals, locals_tree = _safe_flatten_dict(self.locals)
flat_globals, globals_tree = _safe_flatten_dict(self.globals)
flat_vars = flat_locals + flat_globals
is_valid = [isinstance(l, core.Tracer) for l in flat_vars]
invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)
return valid_vars, (is_valid, invalid_vars, locals_tree, globals_tree,
len(flat_locals), self.filename, self.code_context,
self.source, self.lineno, self.offset)
@classmethod
def tree_unflatten(cls, info, valid_vars):
(is_valid, invalid_vars, locals_tree, globals_tree, num_locals, filename,
code_context, source, lineno, offset) = info
flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars)
flat_locals, flat_globals = util.split_list(flat_vars, [num_locals])
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals).to_dict()
globals_ = tree_util.tree_unflatten(globals_tree, flat_globals).to_dict()
return DebuggerFrame(filename, locals_, globals_, code_context, source,
lineno, offset)
@classmethod
def from_frameinfo(cls, frame_info) -> DebuggerFrame:
try:
_, start = inspect.getsourcelines(frame_info.frame)
source = inspect.getsource(frame_info.frame).split("\n")
# Line numbers begin at 1 but offsets begin at 0. `inspect.getsource` will
# return a partial view of the file and a `start` indicating the line
# number that the source code starts at. However, it's possible that
# `start` is 0, indicating that we are at the beginning of the file. In
# this case, `offset` is just the `lineno - 1`. If `start` is nonzero,
# then we subtract it off from the `lineno` and don't need to subtract 1
# since both start and lineno are 1-indexed.
offset = frame_info.lineno - max(start, 1)
if offset >= len(source):
# Sometimes we don't get a valid source/offset pair. This seems to
# happen sometimes when code uses eval(). If that happens, give up.
source = []
offset = None
except OSError:
source = []
offset = None
return DebuggerFrame(
filename=frame_info.filename,
locals=frame_info.frame.f_locals,
globals={},
code_context=frame_info.code_context,
source=source,
lineno=frame_info.lineno,
offset=offset)
class Debugger(Protocol):
def __call__(self, frames: list[DebuggerFrame], thread_id: int | None,
**kwargs: Any) -> None:
...
_debugger_registry: dict[str, tuple[int, Debugger]] = {}
def get_debugger(backend: str | None = None) -> Debugger:
if backend is not None and backend in _debugger_registry:
return _debugger_registry[backend][1]
debuggers = sorted(_debugger_registry.values(), key=lambda x: -x[0])
if not debuggers:
raise ValueError("No debuggers registered!")
return debuggers[0][1]
def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
if name in _debugger_registry:
raise ValueError(f"Debugger with name \"{name}\" already registered.")
_debugger_registry[name] = (priority, debugger)
debug_lock = threading.Lock()
def breakpoint(*, backend: str | None = None, filter_frames: bool = True,
num_frames: int | None = None, ordered: bool = False,
token = None, **kwargs):
"""Enters a breakpoint at a point in a program.
Args:
backend: The debugger backend to use. By default, picks the highest priority
debugger and in the absence of other registered debuggers, falls back to
the CLI debugger.
filter_frames: Whether or not to filter out JAX-internal stack frames from
the traceback. Since some libraries, like Flax, also make use of JAX's
stack frame filtering system, this option can also affect whether stack
frames from libraries are filtered.
num_frames: The number of frames above the current stack frame to make
available for inspection in the interactive debugger.
ordered: A keyword only argument used to indicate whether or not the
staged out computation will enforce ordering of this ``jax.debug.breakpoint``
with respect to other ordered ``jax.debug.breakpoint`` and ``jax.debug.print``
calls.
token: A keyword only argument; an alternative to ``ordered``. If used then a JAX
array (or pytree of JAX arrays) should be passed, and the breakpoint will be run
once its value is computed.
This is returned unchanged, and should be passed back to the computation.
If the return value is unused in the later computation, then the whole computation
will be pruned and this breakpoint will not be run.
Returns:
If `token` is passed, then its value is returned unchanged. Otherwise, returns
`None`.
"""
if token is not None:
if ordered:
raise ValueError("`ordered` and `token` are mutually exclusive arguments.")
frame_infos = inspect.stack()
# Throw out first frame corresponding to this function
frame_infos = frame_infos[1:]
# Filter out internal frames
if filter_frames:
frames = [
DebuggerFrame.from_frameinfo(frame_info)
for frame_info in frame_infos
if traceback_util.include_frame(frame_info.frame)
]
else:
frames = [
DebuggerFrame.from_frameinfo(frame_info)
for frame_info in frame_infos
]
if num_frames is not None:
frames = frames[:num_frames]
flat_args, frames_tree = tree_util.tree_flatten(frames)
def _breakpoint_callback(*flat_args):
frames = tree_util.tree_unflatten(frames_tree, flat_args)
thread_id = None
if threading.current_thread() is not threading.main_thread():
thread_id = threading.get_ident()
debugger = get_debugger(backend=backend)
# Lock here because this could be called from multiple threads at the same
# time.
with debug_lock:
debugger(frames, thread_id, **kwargs)
if token is None:
debugging.debug_callback(_breakpoint_callback, *flat_args, ordered=ordered)
else:
def _breakpoint_callback_wrapper(x, *flat_args):
_breakpoint_callback(*flat_args)
return x
token, flat_args = lax.stop_gradient((token, flat_args))
return callback.pure_callback(_breakpoint_callback_wrapper, token, token, *flat_args)
@@ -0,0 +1,106 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import atexit
import functools
import importlib.util
import os
from typing import Any
import weakref
from jax._src.debugger import cli_debugger
from jax._src.debugger import core as debugger_core
@functools.cache
def _web_pdb_version() -> tuple[int, ...]:
import web_pdb # pytype: disable=import-error
return tuple(map(int, web_pdb.__version__.split(".")))
_web_consoles: dict[tuple[str, int], Any] = {}
@atexit.register
def _close_debuggers():
for console in _web_consoles.values():
console.close()
_web_consoles.clear()
class WebDebugger(cli_debugger.CliDebugger):
"""A web-based debugger."""
prompt = '(jdb) '
use_rawinput: bool = False
def __init__(self, frames: list[debugger_core.DebuggerFrame], thread_id,
completekey: str = "tab", host: str = "", port: int = 5555):
if (host, port) not in _web_consoles:
import web_pdb # pytype: disable=import-error
_web_consoles[host, port] = web_pdb.WebConsole(host, port, self)
# Clobber the debugger in the web console
_web_console = _web_consoles[host, port]
_web_console._debugger = weakref.proxy(self)
super().__init__(frames, thread_id, stdin=_web_console, stdout=_web_console,
completekey=completekey)
def get_current_frame_data(self):
# Constructs the info needed for the web console to display info
current_frame = self.current_frame()
filename = current_frame.filename
lines = current_frame.source
current_line = None
if current_frame.offset is not None:
current_line = current_frame.offset + 1
if _web_pdb_version() < (1, 4, 4):
return {
'filename': filename,
'listing': '\n'.join(lines),
'curr_line': current_line,
'total_lines': len(lines),
'breaklist': [],
}
return {
'dirname': os.path.dirname(os.path.abspath(filename)) + os.path.sep,
'filename': os.path.basename(filename),
'file_listing': '\n'.join(lines),
'current_line': current_line,
'breakpoints': [],
'globals': self.get_globals(),
'locals': self.get_locals(),
}
def get_globals(self):
current_frame = self.current_frame()
return "\n".join(
f"{key} = {value}"
for key, value in sorted(current_frame.globals.items()))
def get_locals(self):
current_frame = self.current_frame()
return "\n".join(
f"{key} = {value}"
for key, value in sorted(current_frame.locals.items()))
def run(self):
return self.cmdloop()
def run_debugger(frames: list[debugger_core.DebuggerFrame],
thread_id: int | None, **kwargs: Any):
WebDebugger(frames, thread_id, **kwargs).run()
if importlib.util.find_spec("web_pdb") is not None:
debugger_core.register_debugger("web", run_debugger, -2)