hand
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from jax._src.debugger.core import breakpoint as breakpoint
|
||||
from jax._src.debugger import cli_debugger
|
||||
from jax._src.debugger import colab_debugger
|
||||
from jax._src.debugger import web_debugger
|
||||
|
||||
del cli_debugger # For registration only
|
||||
del colab_debugger # For registration only
|
||||
del web_debugger # For registration only
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,170 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import cmd
|
||||
import pprint
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, IO
|
||||
|
||||
from jax._src.debugger import core as debugger_core
|
||||
|
||||
DebuggerFrame = debugger_core.DebuggerFrame
|
||||
|
||||
class CliDebugger(cmd.Cmd):
|
||||
"""A text-based debugger."""
|
||||
prompt = '(jdb) '
|
||||
|
||||
def __init__(self, frames: list[DebuggerFrame], thread_id,
|
||||
stdin: IO[str] | None = None, stdout: IO[str] | None = None,
|
||||
completekey: str = "tab"):
|
||||
super().__init__(stdin=stdin, stdout=stdout, completekey=completekey)
|
||||
self.use_rawinput = stdin is None
|
||||
self.frames = frames
|
||||
self.frame_index = 0
|
||||
self.thread_id = thread_id
|
||||
self.intro = 'Entering jdb:'
|
||||
|
||||
def current_frame(self):
|
||||
return self.frames[self.frame_index]
|
||||
|
||||
def evaluate(self, expr):
|
||||
env = {}
|
||||
curr_frame = self.frames[self.frame_index]
|
||||
env.update(curr_frame.globals)
|
||||
env.update(curr_frame.locals)
|
||||
return eval(expr, {}, env)
|
||||
|
||||
def default(self, line):
|
||||
"""Evaluates an expression."""
|
||||
try:
|
||||
print(repr(self.evaluate(line)), file=self.stdout)
|
||||
except:
|
||||
self._error_message()
|
||||
|
||||
def print_backtrace(self):
|
||||
backtrace = []
|
||||
backtrace.append('Traceback:')
|
||||
for frame in self.frames[::-1]:
|
||||
backtrace.append(f' File "{frame.filename}", line {frame.lineno}')
|
||||
if frame.offset is None:
|
||||
backtrace.append(' <no source>')
|
||||
else:
|
||||
line = frame.source[frame.offset]
|
||||
backtrace.append(f' {line.strip()}')
|
||||
print("\n".join(backtrace), file=self.stdout)
|
||||
|
||||
def print_context(self, num_lines=2):
|
||||
curr_frame = self.frames[self.frame_index]
|
||||
context = []
|
||||
context.append(f'> {curr_frame.filename}({curr_frame.lineno})')
|
||||
for i, line in enumerate(curr_frame.source):
|
||||
assert curr_frame.offset is not None
|
||||
if (curr_frame.offset - 1 - num_lines <= i <=
|
||||
curr_frame.offset + num_lines):
|
||||
if i == curr_frame.offset:
|
||||
context.append(f'-> {line}')
|
||||
else:
|
||||
context.append(f' {line}')
|
||||
print("\n".join(context), file=self.stdout)
|
||||
|
||||
def _error_message(self):
|
||||
exc_info = sys.exc_info()[:2]
|
||||
msg = traceback.format_exception_only(*exc_info)[-1].strip()
|
||||
print('***', msg, file=self.stdout)
|
||||
|
||||
def do_p(self, arg):
|
||||
"""p expression
|
||||
Evaluates and prints the value of an expression
|
||||
"""
|
||||
try:
|
||||
print(repr(self.evaluate(arg)), file=self.stdout)
|
||||
except:
|
||||
self._error_message()
|
||||
|
||||
def do_pp(self, arg):
|
||||
"""pp expression
|
||||
Evaluates and pretty-prints the value of an expression
|
||||
"""
|
||||
try:
|
||||
print(pprint.pformat(self.evaluate(arg)), file=self.stdout)
|
||||
except:
|
||||
self._error_message()
|
||||
|
||||
def do_up(self, arg, /):
|
||||
"""u(p)
|
||||
Move up a stack frame.
|
||||
"""
|
||||
del arg # unused
|
||||
if self.frame_index == len(self.frames) - 1:
|
||||
print('At topmost frame.', file=self.stdout)
|
||||
else:
|
||||
self.frame_index += 1
|
||||
self.print_context()
|
||||
do_u = do_up
|
||||
|
||||
def do_down(self, arg, /):
|
||||
"""d(own)
|
||||
Move down a stack frame.
|
||||
"""
|
||||
del arg # unused
|
||||
if self.frame_index == 0:
|
||||
print('At bottommost frame.', file=self.stdout)
|
||||
else:
|
||||
self.frame_index -= 1
|
||||
self.print_context()
|
||||
do_d = do_down
|
||||
|
||||
def do_list(self, _):
|
||||
"""l(ist)
|
||||
List source code for the current file.
|
||||
"""
|
||||
self.print_context(num_lines=5)
|
||||
do_l = do_list
|
||||
|
||||
def do_continue(self, _):
|
||||
"""c(ont(inue))
|
||||
Continue the program's execution.
|
||||
"""
|
||||
return True
|
||||
do_c = do_cont = do_continue
|
||||
|
||||
def do_quit(self, _):
|
||||
"""q(uit)\n(exit)
|
||||
Quit the debugger. The program is given an exit command.
|
||||
"""
|
||||
sys.exit(0)
|
||||
do_q = do_EOF = do_exit = do_quit
|
||||
|
||||
def do_where(self, _):
|
||||
"""w(here)
|
||||
Prints a stack trace with the most recent frame on the bottom.
|
||||
'bt' is an alias for this command.
|
||||
"""
|
||||
self.print_backtrace()
|
||||
do_w = do_bt = do_where
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
try:
|
||||
self.cmdloop()
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
print('--KeyboardInterrupt--', file=sys.stdout)
|
||||
|
||||
def run_debugger(frames: list[DebuggerFrame], thread_id: int | None,
|
||||
**kwargs: Any):
|
||||
CliDebugger(frames, thread_id, **kwargs).run()
|
||||
debugger_core.register_debugger("cli", run_debugger, -1)
|
||||
@@ -0,0 +1,256 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for Colab-specific debugger."""
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import traceback
|
||||
|
||||
import uuid
|
||||
|
||||
from jax._src.debugger import colab_lib
|
||||
from jax._src.debugger import core as debugger_core
|
||||
from jax._src.debugger import cli_debugger
|
||||
# pytype: disable=import-error
|
||||
if colab_lib.IS_COLAB_ENABLED:
|
||||
from google.colab import output
|
||||
try:
|
||||
import pygments
|
||||
import pygments.lexers
|
||||
import pygments.formatters
|
||||
IS_PYGMENTS_ENABLED = True
|
||||
except ImportError:
|
||||
IS_PYGMENTS_ENABLED = False
|
||||
# pytype: enable=import-error
|
||||
|
||||
|
||||
class CodeViewer(colab_lib.DynamicDOMElement):
|
||||
"""A mutable DOM element that displays code as HTML."""
|
||||
|
||||
def __init__(self, code_: str, highlights: list[int], linenostart: int = 1):
|
||||
self._code = code_
|
||||
self._highlights = highlights
|
||||
self._view = colab_lib.dynamic(colab_lib.div())
|
||||
self._linenostart = linenostart
|
||||
|
||||
def render(self):
|
||||
self.update_code(
|
||||
self._code, self._highlights, linenostart=self._linenostart)
|
||||
|
||||
def clear(self):
|
||||
self._view.clear()
|
||||
|
||||
def append(self, child):
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, elem):
|
||||
self._view.update(elem)
|
||||
|
||||
def _highlight_code(self, code: str, highlights, linenostart: int):
|
||||
is_dark_mode = output.eval_js(
|
||||
'document.documentElement.matches("[theme=dark]");')
|
||||
code_style = "monokai" if is_dark_mode else "default"
|
||||
hl_color = "#4e56b7" if is_dark_mode else "#fff7c1"
|
||||
if IS_PYGMENTS_ENABLED:
|
||||
lexer = pygments.lexers.get_lexer_by_name("python")
|
||||
formatter = pygments.formatters.HtmlFormatter(
|
||||
full=False,
|
||||
hl_lines=highlights,
|
||||
linenos=True,
|
||||
linenostart=linenostart,
|
||||
style=code_style)
|
||||
if hl_color:
|
||||
formatter.style.highlight_color = hl_color
|
||||
css_ = formatter.get_style_defs()
|
||||
code = pygments.highlight(code, lexer, formatter)
|
||||
else:
|
||||
return "";
|
||||
return code, css_
|
||||
|
||||
def update_code(self, code_, highlights, *, linenostart: int = 1):
|
||||
"""Updates the code viewer to use new code."""
|
||||
self._code = code_
|
||||
self._view.clear()
|
||||
code_, css_ = self._highlight_code(self._code, highlights, linenostart)
|
||||
uuid_ = uuid.uuid4()
|
||||
code_div = colab_lib.div(
|
||||
colab_lib.css(css_),
|
||||
code_,
|
||||
id=f"code-{uuid_}",
|
||||
style=colab_lib.style({
|
||||
"max-height": "500px",
|
||||
"overflow-y": "scroll",
|
||||
"background-color": "var(--colab-border-color)",
|
||||
"padding": "5px 5px 5px 5px",
|
||||
}))
|
||||
if highlights:
|
||||
percent_scroll = highlights[0] / len(self._code.split("\n"))
|
||||
else:
|
||||
percent_scroll = 0.
|
||||
self.update(code_div)
|
||||
# Scroll to where the line is
|
||||
output.eval_js("""
|
||||
console.log("{id}")
|
||||
var elem = document.getElementById("{id}")
|
||||
var maxScrollPosition = elem.scrollHeight - elem.clientHeight;
|
||||
elem.scrollTop = maxScrollPosition * {percent_scroll}
|
||||
""".format(id=f"code-{uuid_}", percent_scroll=percent_scroll))
|
||||
|
||||
|
||||
class FramePreview(colab_lib.DynamicDOMElement):
|
||||
"""Displays information about a stack frame."""
|
||||
|
||||
def __init__(self, frame):
|
||||
super().__init__()
|
||||
self._header = colab_lib.dynamic(
|
||||
colab_lib.div(colab_lib.pre(colab_lib.code(""))))
|
||||
self._code_view = CodeViewer("", highlights=[])
|
||||
self.frame = frame
|
||||
self._file_cache = {}
|
||||
|
||||
def clear(self):
|
||||
self._header.clear()
|
||||
self._code_view.clear()
|
||||
|
||||
def append(self, child):
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, elem):
|
||||
raise NotImplementedError
|
||||
|
||||
def update_frame(self, frame):
|
||||
"""Updates the frame viewer to use a new frame."""
|
||||
self.frame = frame
|
||||
lineno = self.frame.lineno or None
|
||||
filename = self.frame.filename.strip()
|
||||
if inspect.getmodulename(filename):
|
||||
if filename not in self._file_cache:
|
||||
try:
|
||||
with open(filename) as fp:
|
||||
self._file_cache[filename] = fp.read()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
if filename in self._file_cache:
|
||||
source = self._file_cache[filename]
|
||||
highlight = lineno
|
||||
linenostart = 1
|
||||
else:
|
||||
source = "\n".join(frame.source)
|
||||
highlight = min(frame.offset + 1, len(frame.source) - 1)
|
||||
linenostart = lineno - frame.offset
|
||||
self._header.clear()
|
||||
self._header.update(
|
||||
colab_lib.div(
|
||||
colab_lib.pre(colab_lib.code(f"{html.escape(filename)}({lineno})")),
|
||||
style=colab_lib.style({
|
||||
"padding": "5px 5px 5px 5px",
|
||||
"background-color": "var(--colab-highlighted-surface-color)",
|
||||
})))
|
||||
self._code_view.update_code(source, [highlight], linenostart=linenostart)
|
||||
|
||||
def render(self):
|
||||
self.update_frame(self.frame)
|
||||
|
||||
|
||||
class DebuggerView(colab_lib.DynamicDOMElement):
|
||||
"""Main view for the Colab debugger."""
|
||||
|
||||
def __init__(self, frame, *, log_color=""):
|
||||
super().__init__()
|
||||
self._interaction_log = colab_lib.dynamic(colab_lib.div())
|
||||
self._frame_preview = FramePreview(frame)
|
||||
self._header = colab_lib.dynamic(
|
||||
colab_lib.div(
|
||||
colab_lib.span("Breakpoint"),
|
||||
style=colab_lib.style({
|
||||
"background-color": "var(--colab-secondary-surface-color)",
|
||||
"color": "var(--colab-primary-text-color)",
|
||||
"padding": "5px 5px 5px 5px",
|
||||
"font-weight": "bold",
|
||||
})))
|
||||
|
||||
def render(self):
|
||||
self._header.render()
|
||||
self._frame_preview.render()
|
||||
self._interaction_log.render()
|
||||
|
||||
def append(self, child):
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, elem):
|
||||
raise NotImplementedError
|
||||
|
||||
def clear(self):
|
||||
self._header.clear()
|
||||
self._interaction_log.clear()
|
||||
self._frame_preview.clear()
|
||||
|
||||
def update_frame(self, frame):
|
||||
self._frame_preview.update_frame(frame)
|
||||
|
||||
def write(self, text):
|
||||
self._interaction_log.append(colab_lib.pre(text))
|
||||
|
||||
def read(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def readline(self):
|
||||
with output.use_tags(["stdin"]):
|
||||
user_input = input() + "\n"
|
||||
output.clear(output_tags=["stdin"])
|
||||
return user_input
|
||||
|
||||
def isatty(self):
|
||||
return True
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
class ColabDebugger(cli_debugger.CliDebugger):
|
||||
"""A JAX debugger for a Colab environment."""
|
||||
|
||||
def __init__(self,
|
||||
frames: list[debugger_core.DebuggerFrame],
|
||||
thread_id: int):
|
||||
super().__init__(frames, thread_id)
|
||||
self._debugger_view = DebuggerView(self.current_frame())
|
||||
self.stdout = self.stdin = self._debugger_view # pyrefly: ignore[bad-assignment]
|
||||
|
||||
def do_up(self, arg, /):
|
||||
super().do_up(arg)
|
||||
self._debugger_view.update_frame(self.current_frame())
|
||||
return False
|
||||
|
||||
def do_down(self, arg, /):
|
||||
super().do_down(arg)
|
||||
self._debugger_view.update_frame(self.current_frame())
|
||||
return False
|
||||
|
||||
def run(self):
|
||||
self._debugger_view.render()
|
||||
while True:
|
||||
if not self.cmdloop():
|
||||
return
|
||||
|
||||
|
||||
def _run_debugger(frames, thread_id, **kwargs):
|
||||
try:
|
||||
ColabDebugger(frames, thread_id, **kwargs).run()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if colab_lib.IS_COLAB_ENABLED:
|
||||
debugger_core.register_debugger("colab", _run_debugger, 1)
|
||||
@@ -0,0 +1,167 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for building interfaces in Colab."""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import dataclasses
|
||||
import functools
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from typing import Any, Union
|
||||
|
||||
IS_COLAB_ENABLED = "google.colab" in sys.modules
|
||||
if IS_COLAB_ENABLED:
|
||||
# pytype: disable=import-error
|
||||
from google.colab import output
|
||||
from IPython import display
|
||||
# pytype: enable=import-error
|
||||
|
||||
|
||||
class DOMElement(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self):
|
||||
pass
|
||||
|
||||
|
||||
Element = Union[DOMElement, str]
|
||||
|
||||
|
||||
class DynamicDOMElement(DOMElement):
|
||||
"""A DOM element that can be mutated."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def render(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def append(self, child: DOMElement):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, elem: DOMElement):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def clear(self):
|
||||
pass
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DynamicDiv(DynamicDOMElement):
|
||||
"""A `div` that can be edited."""
|
||||
_uuid: str = dataclasses.field(init=False)
|
||||
_root_elem: DOMElement = dataclasses.field(init=False)
|
||||
elem: DOMElement | str
|
||||
|
||||
def __post_init__(self):
|
||||
self._uuid = str(uuid.uuid4())
|
||||
self._rendered = False
|
||||
self._root_elem = div(id=self.tag)
|
||||
|
||||
@property
|
||||
def tag(self):
|
||||
return f"tag-{self._uuid}"
|
||||
|
||||
def render(self):
|
||||
if self._rendered:
|
||||
raise ValueError("Can't call `render` twice.")
|
||||
self._root_elem.render()
|
||||
self._rendered = True
|
||||
if isinstance(self.elem, str):
|
||||
raise TypeError("Cannot render when self.elem is a string.")
|
||||
self.append(self.elem)
|
||||
|
||||
def append(self, child: DOMElement):
|
||||
if not self._rendered:
|
||||
self.render()
|
||||
with output.use_tags([self.tag]):
|
||||
with output.redirect_to_element(f"#{self.tag}"):
|
||||
child.render()
|
||||
|
||||
def update(self, elem: DOMElement):
|
||||
self.clear()
|
||||
self.elem = elem
|
||||
self.render()
|
||||
|
||||
def clear(self):
|
||||
output.clear(output_tags=[self.tag])
|
||||
self._rendered = False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StaticDOMElement(DOMElement):
|
||||
"""An immutable DOM element."""
|
||||
_uuid: str = dataclasses.field(init=False)
|
||||
name: str
|
||||
children: list[str | DOMElement]
|
||||
attrs: dict[str, str]
|
||||
|
||||
def html(self):
|
||||
attr_str = ""
|
||||
if self.attrs:
|
||||
attr_str = " " + (" ".join(
|
||||
[f"{key}=\"{value}\"" for key, value in self.attrs.items()]))
|
||||
children = []
|
||||
children = "\n".join([str(c) for c in self.children])
|
||||
return f"<{self.name}{attr_str}>{children}</{self.name}>"
|
||||
|
||||
def render(self):
|
||||
display.display(display.HTML(self.html()))
|
||||
|
||||
def attr(self, key: str) -> str:
|
||||
return self.attrs[key]
|
||||
|
||||
def __str__(self):
|
||||
return self.html()
|
||||
|
||||
def __repr__(self):
|
||||
return self.html()
|
||||
|
||||
def append(self, child: DOMElement) -> DOMElement:
|
||||
return dataclasses.replace(self, children=[*self.children, child])
|
||||
|
||||
def replace(self, **kwargs) -> DOMElement:
|
||||
return dataclasses.replace(self, **kwargs)
|
||||
|
||||
|
||||
def _style_dict_to_str(style_dict: dict[str, Any]) -> str:
|
||||
return " ".join([f"{k}: {v};" for k, v in style_dict.items()])
|
||||
|
||||
|
||||
def dynamic(elem: StaticDOMElement) -> DynamicDiv:
|
||||
return DynamicDiv(elem)
|
||||
|
||||
|
||||
def _make_elem(tag: str, *children: Element, **attrs) -> StaticDOMElement:
|
||||
"""Helper function for making DOM elements."""
|
||||
return StaticDOMElement(tag, list(children), attrs)
|
||||
|
||||
|
||||
code = functools.partial(_make_elem, "code")
|
||||
div = functools.partial(_make_elem, "div")
|
||||
li = functools.partial(_make_elem, "li")
|
||||
ol = functools.partial(_make_elem, "ol")
|
||||
pre = functools.partial(_make_elem, "pre")
|
||||
progress = functools.partial(_make_elem, "progress")
|
||||
span = functools.partial(_make_elem, "span")
|
||||
|
||||
|
||||
def css(text: str) -> StaticDOMElement:
|
||||
return StaticDOMElement("style", [text], {})
|
||||
|
||||
|
||||
def style(*args, **kwargs):
|
||||
return _style_dict_to_str(dict(*args, **kwargs))
|
||||
@@ -0,0 +1,230 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Hashable
|
||||
import dataclasses
|
||||
import inspect
|
||||
import threading
|
||||
from typing import Any, Protocol
|
||||
|
||||
from jax._src import callback
|
||||
from jax._src import core
|
||||
from jax._src import debugging
|
||||
from jax._src import traceback_util
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.lax import lax
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
class _DictWrapper:
|
||||
keys: list[Hashable]
|
||||
values: list[Any]
|
||||
|
||||
def __init__(self, keys, values):
|
||||
self._keys = keys
|
||||
self._values = values
|
||||
|
||||
def to_dict(self):
|
||||
return dict(zip(self._keys, self._values))
|
||||
|
||||
def tree_flatten(self):
|
||||
return self._values, self._keys
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, keys, values):
|
||||
return _DictWrapper(keys, values)
|
||||
|
||||
|
||||
class _CantFlatten:
|
||||
__repr__ = lambda _: "<cant_flatten>"
|
||||
cant_flatten = _CantFlatten()
|
||||
|
||||
def _safe_flatten_dict(dct: dict[Any, Any]
|
||||
) -> tuple[list[Any], tree_util.PyTreeDef]:
|
||||
# We avoid comparison between keys by just using the original order
|
||||
keys, values = [], []
|
||||
for key, value in dct.items():
|
||||
try:
|
||||
tree_util.tree_leaves(value)
|
||||
except:
|
||||
# If flattening fails, we substitute a sentinel object.
|
||||
value = cant_flatten
|
||||
keys.append(key)
|
||||
values.append(value)
|
||||
return tree_util.tree_flatten(_DictWrapper(keys, values))
|
||||
|
||||
|
||||
@tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DebuggerFrame:
|
||||
"""Encapsulates Python frame information."""
|
||||
filename: str
|
||||
locals: dict[str, Any]
|
||||
globals: dict[str, Any]
|
||||
code_context: str
|
||||
source: list[str]
|
||||
lineno: int
|
||||
offset: int | None
|
||||
|
||||
def tree_flatten(self):
|
||||
flat_locals, locals_tree = _safe_flatten_dict(self.locals)
|
||||
flat_globals, globals_tree = _safe_flatten_dict(self.globals)
|
||||
flat_vars = flat_locals + flat_globals
|
||||
is_valid = [isinstance(l, core.Tracer) for l in flat_vars]
|
||||
invalid_vars, valid_vars = util.partition_list(is_valid, flat_vars)
|
||||
return valid_vars, (is_valid, invalid_vars, locals_tree, globals_tree,
|
||||
len(flat_locals), self.filename, self.code_context,
|
||||
self.source, self.lineno, self.offset)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, info, valid_vars):
|
||||
(is_valid, invalid_vars, locals_tree, globals_tree, num_locals, filename,
|
||||
code_context, source, lineno, offset) = info
|
||||
flat_vars = util.merge_lists(is_valid, invalid_vars, valid_vars)
|
||||
flat_locals, flat_globals = util.split_list(flat_vars, [num_locals])
|
||||
locals_ = tree_util.tree_unflatten(locals_tree, flat_locals).to_dict()
|
||||
globals_ = tree_util.tree_unflatten(globals_tree, flat_globals).to_dict()
|
||||
return DebuggerFrame(filename, locals_, globals_, code_context, source,
|
||||
lineno, offset)
|
||||
|
||||
@classmethod
|
||||
def from_frameinfo(cls, frame_info) -> DebuggerFrame:
|
||||
try:
|
||||
_, start = inspect.getsourcelines(frame_info.frame)
|
||||
source = inspect.getsource(frame_info.frame).split("\n")
|
||||
# Line numbers begin at 1 but offsets begin at 0. `inspect.getsource` will
|
||||
# return a partial view of the file and a `start` indicating the line
|
||||
# number that the source code starts at. However, it's possible that
|
||||
# `start` is 0, indicating that we are at the beginning of the file. In
|
||||
# this case, `offset` is just the `lineno - 1`. If `start` is nonzero,
|
||||
# then we subtract it off from the `lineno` and don't need to subtract 1
|
||||
# since both start and lineno are 1-indexed.
|
||||
offset = frame_info.lineno - max(start, 1)
|
||||
if offset >= len(source):
|
||||
# Sometimes we don't get a valid source/offset pair. This seems to
|
||||
# happen sometimes when code uses eval(). If that happens, give up.
|
||||
source = []
|
||||
offset = None
|
||||
except OSError:
|
||||
source = []
|
||||
offset = None
|
||||
return DebuggerFrame(
|
||||
filename=frame_info.filename,
|
||||
locals=frame_info.frame.f_locals,
|
||||
globals={},
|
||||
code_context=frame_info.code_context,
|
||||
source=source,
|
||||
lineno=frame_info.lineno,
|
||||
offset=offset)
|
||||
|
||||
|
||||
class Debugger(Protocol):
|
||||
|
||||
def __call__(self, frames: list[DebuggerFrame], thread_id: int | None,
|
||||
**kwargs: Any) -> None:
|
||||
...
|
||||
_debugger_registry: dict[str, tuple[int, Debugger]] = {}
|
||||
|
||||
|
||||
def get_debugger(backend: str | None = None) -> Debugger:
|
||||
if backend is not None and backend in _debugger_registry:
|
||||
return _debugger_registry[backend][1]
|
||||
debuggers = sorted(_debugger_registry.values(), key=lambda x: -x[0])
|
||||
if not debuggers:
|
||||
raise ValueError("No debuggers registered!")
|
||||
return debuggers[0][1]
|
||||
|
||||
|
||||
def register_debugger(name: str, debugger: Debugger, priority: int) -> None:
|
||||
if name in _debugger_registry:
|
||||
raise ValueError(f"Debugger with name \"{name}\" already registered.")
|
||||
_debugger_registry[name] = (priority, debugger)
|
||||
|
||||
|
||||
debug_lock = threading.Lock()
|
||||
|
||||
|
||||
def breakpoint(*, backend: str | None = None, filter_frames: bool = True,
|
||||
num_frames: int | None = None, ordered: bool = False,
|
||||
token = None, **kwargs):
|
||||
"""Enters a breakpoint at a point in a program.
|
||||
|
||||
Args:
|
||||
backend: The debugger backend to use. By default, picks the highest priority
|
||||
debugger and in the absence of other registered debuggers, falls back to
|
||||
the CLI debugger.
|
||||
filter_frames: Whether or not to filter out JAX-internal stack frames from
|
||||
the traceback. Since some libraries, like Flax, also make use of JAX's
|
||||
stack frame filtering system, this option can also affect whether stack
|
||||
frames from libraries are filtered.
|
||||
num_frames: The number of frames above the current stack frame to make
|
||||
available for inspection in the interactive debugger.
|
||||
ordered: A keyword only argument used to indicate whether or not the
|
||||
staged out computation will enforce ordering of this ``jax.debug.breakpoint``
|
||||
with respect to other ordered ``jax.debug.breakpoint`` and ``jax.debug.print``
|
||||
calls.
|
||||
token: A keyword only argument; an alternative to ``ordered``. If used then a JAX
|
||||
array (or pytree of JAX arrays) should be passed, and the breakpoint will be run
|
||||
once its value is computed.
|
||||
This is returned unchanged, and should be passed back to the computation.
|
||||
If the return value is unused in the later computation, then the whole computation
|
||||
will be pruned and this breakpoint will not be run.
|
||||
|
||||
Returns:
|
||||
If `token` is passed, then its value is returned unchanged. Otherwise, returns
|
||||
`None`.
|
||||
"""
|
||||
if token is not None:
|
||||
if ordered:
|
||||
raise ValueError("`ordered` and `token` are mutually exclusive arguments.")
|
||||
frame_infos = inspect.stack()
|
||||
# Throw out first frame corresponding to this function
|
||||
frame_infos = frame_infos[1:]
|
||||
# Filter out internal frames
|
||||
if filter_frames:
|
||||
frames = [
|
||||
DebuggerFrame.from_frameinfo(frame_info)
|
||||
for frame_info in frame_infos
|
||||
if traceback_util.include_frame(frame_info.frame)
|
||||
]
|
||||
else:
|
||||
frames = [
|
||||
DebuggerFrame.from_frameinfo(frame_info)
|
||||
for frame_info in frame_infos
|
||||
]
|
||||
if num_frames is not None:
|
||||
frames = frames[:num_frames]
|
||||
flat_args, frames_tree = tree_util.tree_flatten(frames)
|
||||
|
||||
def _breakpoint_callback(*flat_args):
|
||||
frames = tree_util.tree_unflatten(frames_tree, flat_args)
|
||||
thread_id = None
|
||||
if threading.current_thread() is not threading.main_thread():
|
||||
thread_id = threading.get_ident()
|
||||
debugger = get_debugger(backend=backend)
|
||||
# Lock here because this could be called from multiple threads at the same
|
||||
# time.
|
||||
with debug_lock:
|
||||
debugger(frames, thread_id, **kwargs)
|
||||
|
||||
if token is None:
|
||||
debugging.debug_callback(_breakpoint_callback, *flat_args, ordered=ordered)
|
||||
else:
|
||||
def _breakpoint_callback_wrapper(x, *flat_args):
|
||||
_breakpoint_callback(*flat_args)
|
||||
return x
|
||||
token, flat_args = lax.stop_gradient((token, flat_args))
|
||||
return callback.pure_callback(_breakpoint_callback_wrapper, token, token, *flat_args)
|
||||
@@ -0,0 +1,106 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import functools
|
||||
import importlib.util
|
||||
import os
|
||||
from typing import Any
|
||||
import weakref
|
||||
|
||||
from jax._src.debugger import cli_debugger
|
||||
from jax._src.debugger import core as debugger_core
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _web_pdb_version() -> tuple[int, ...]:
|
||||
import web_pdb # pytype: disable=import-error
|
||||
return tuple(map(int, web_pdb.__version__.split(".")))
|
||||
|
||||
|
||||
_web_consoles: dict[tuple[str, int], Any] = {}
|
||||
|
||||
|
||||
@atexit.register
|
||||
def _close_debuggers():
|
||||
for console in _web_consoles.values():
|
||||
console.close()
|
||||
_web_consoles.clear()
|
||||
|
||||
|
||||
class WebDebugger(cli_debugger.CliDebugger):
|
||||
"""A web-based debugger."""
|
||||
prompt = '(jdb) '
|
||||
use_rawinput: bool = False
|
||||
|
||||
def __init__(self, frames: list[debugger_core.DebuggerFrame], thread_id,
|
||||
completekey: str = "tab", host: str = "", port: int = 5555):
|
||||
if (host, port) not in _web_consoles:
|
||||
import web_pdb # pytype: disable=import-error
|
||||
_web_consoles[host, port] = web_pdb.WebConsole(host, port, self)
|
||||
# Clobber the debugger in the web console
|
||||
_web_console = _web_consoles[host, port]
|
||||
_web_console._debugger = weakref.proxy(self)
|
||||
super().__init__(frames, thread_id, stdin=_web_console, stdout=_web_console,
|
||||
completekey=completekey)
|
||||
|
||||
def get_current_frame_data(self):
|
||||
# Constructs the info needed for the web console to display info
|
||||
current_frame = self.current_frame()
|
||||
filename = current_frame.filename
|
||||
lines = current_frame.source
|
||||
current_line = None
|
||||
if current_frame.offset is not None:
|
||||
current_line = current_frame.offset + 1
|
||||
if _web_pdb_version() < (1, 4, 4):
|
||||
return {
|
||||
'filename': filename,
|
||||
'listing': '\n'.join(lines),
|
||||
'curr_line': current_line,
|
||||
'total_lines': len(lines),
|
||||
'breaklist': [],
|
||||
}
|
||||
return {
|
||||
'dirname': os.path.dirname(os.path.abspath(filename)) + os.path.sep,
|
||||
'filename': os.path.basename(filename),
|
||||
'file_listing': '\n'.join(lines),
|
||||
'current_line': current_line,
|
||||
'breakpoints': [],
|
||||
'globals': self.get_globals(),
|
||||
'locals': self.get_locals(),
|
||||
}
|
||||
|
||||
def get_globals(self):
|
||||
current_frame = self.current_frame()
|
||||
return "\n".join(
|
||||
f"{key} = {value}"
|
||||
for key, value in sorted(current_frame.globals.items()))
|
||||
|
||||
def get_locals(self):
|
||||
current_frame = self.current_frame()
|
||||
return "\n".join(
|
||||
f"{key} = {value}"
|
||||
for key, value in sorted(current_frame.locals.items()))
|
||||
|
||||
def run(self):
|
||||
return self.cmdloop()
|
||||
|
||||
def run_debugger(frames: list[debugger_core.DebuggerFrame],
|
||||
thread_id: int | None, **kwargs: Any):
|
||||
WebDebugger(frames, thread_id, **kwargs).run()
|
||||
|
||||
|
||||
if importlib.util.find_spec("web_pdb") is not None:
|
||||
debugger_core.register_debugger("web", run_debugger, -2)
|
||||
Reference in New Issue
Block a user