This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,29 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.sourcemap import SourceMap as SourceMap
from jax._src.sourcemap import MappingsGenerator as MappingsGenerator
from jax.experimental.source_mapper.common import Pass as Pass
from jax.experimental.source_mapper.common import register_pass as register_pass
from jax.experimental.source_mapper.common import all_passes as all_passes
from jax.experimental.source_mapper.common import filter_passes as filter_passes
from jax.experimental.source_mapper.common import compile_with_env as compile_with_env
from jax.experimental.source_mapper.common import SourceMapDump as SourceMapDump
from jax.experimental.source_mapper.generate_map import generate_sourcemaps as generate_sourcemaps
from jax.experimental.source_mapper.mlir import create_mlir_sourcemap as create_mlir_sourcemap
# We import the jaxpr and hlo passes to register them.
import jax.experimental.source_mapper.jaxpr # noqa: F401
from jax.experimental.source_mapper.jaxpr import canonicalize_filename as canonicalize_filename
import jax.experimental.source_mapper.hlo # noqa: F401
@@ -0,0 +1,92 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for generating source maps."""
import contextlib
import dataclasses
import re
from typing import Any, Protocol
from collections.abc import Sequence
from absl import flags
import jax
from jax._src import sourcemap
@dataclasses.dataclass(frozen=True)
class SourceMapDump:
"""A container for a source map and the paired generated code."""
source_map: sourcemap.SourceMap
generated_code: str
pass_name: str
class CompileFn(Protocol):
def __call__(self, work_dir, fn, f_args, f_kwargs, /, **kwargs) -> Any:
...
class GenerateDumpFn(Protocol):
def __call__(self, compile_result: Any, /, **kwargs) -> SourceMapDump:
...
@dataclasses.dataclass(frozen=True)
class Pass:
name: str
compile_fn: CompileFn
generate_dump: GenerateDumpFn
_pass_registry = {}
def register_pass(pass_: Pass):
if pass_.name in _pass_registry:
raise ValueError(f"Pass {pass_.name} already registered")
_pass_registry[pass_.name] = pass_
def all_passes() -> Sequence[Pass]:
return list(_pass_registry.values())
def filter_passes(regex: str) -> Sequence[Pass]:
"""Gets all registered passes whose display name matches the given regex."""
return [
pass_
for pass_ in _pass_registry.values()
if re.match(regex, pass_.name)
]
@contextlib.contextmanager
def flag_env(**kwargs):
"""A context manager for setting and restoring flags."""
old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs}
for kwarg, new_value in kwargs.items():
setattr(flags.FLAGS, kwarg, new_value)
try:
yield
finally:
for kwarg, old_value in old_flags.items():
setattr(flags.FLAGS, kwarg, old_value)
def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
with flag_env(**env_flags):
jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower(
*f_args, **f_kwargs
).compile(compiler_flags)
@@ -0,0 +1,57 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates source maps for JAX functions."""
import os
import tempfile
from typing import Protocol
from collections.abc import Sequence
from jax.experimental.source_mapper import common
class SourceMapGeneratorFn(Protocol):
def __call__(self, *args, **kwargs) -> Sequence[common.SourceMapDump]:
...
def generate_sourcemaps(
f,
passes: Sequence[common.Pass],
**pass_kwargs
) -> SourceMapGeneratorFn:
"""Generates a SourceMapBundle for the specified compiler passes.
Args:
f: The function to compile.
passes: Which compiler passes to generate sourcemaps for.
**pass_kwargs: Keyword arguments for individual passes.
"""
def wrapper(*args, **kwargs) -> Sequence[common.SourceMapDump]:
pass_results: list[common.SourceMapDump] = []
compile_cache = {}
with tempfile.TemporaryDirectory() as work_dir:
for pass_to_eval in passes:
if pass_to_eval.compile_fn not in compile_cache:
dirname = pass_to_eval.name.replace(":", "__")
pass_work_dir = os.path.join(work_dir, dirname)
os.makedirs(pass_work_dir, exist_ok=False)
compile_result = pass_to_eval.compile_fn(
pass_work_dir, f, args, kwargs, **pass_kwargs
)
compile_cache[pass_to_eval.compile_fn] = compile_result
compile_result = compile_cache[pass_to_eval.compile_fn]
pass_results.append(pass_to_eval.generate_dump(compile_result,
**pass_kwargs))
return pass_results
return wrapper
@@ -0,0 +1,227 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Source mapping generator for HLO dialects."""
import enum
import re
from typing import Any
import jax
from jax._src import sourcemap
from jax.experimental.source_mapper import common
from jax.experimental.source_mapper import mlir
class HloPass(enum.Enum):
STABLE_HLO = "hlo:stable-hlo"
ORIGINAL = "hlo:original"
OPTIMIZED = "hlo:optimized"
METADATA_REGEX = re.compile(
r"metadata={.*op_name=\"(?P<scope>.*)\""
r" source_file=\"(?P<src_file>.*)\""
r" source_line=(?P<src_line>[0-9]+).*?}"
)
# TODO(justinfu): Remove when new format is the default.
def _parse_hlo_old_format(lines: list[str]) -> sourcemap.SourceMap:
mappings = sourcemap.MappingsGenerator()
used_source_files = []
for line in lines:
mappings.new_group()
match = METADATA_REGEX.search(line)
if match:
match_dict = match.groupdict()
_ = match_dict["scope"] # Unused
src_file = match_dict["src_file"]
src_line = int(match_dict["src_line"])
if src_file not in used_source_files:
used_source_files.append(src_file)
src_file_idx = used_source_files.index(src_file)
src_line -= 1 # Segments are zero-indexed
first_col = line.index(line.strip()[0])
mappings.new_segment(first_col, src_file_idx, src_line, 0)
mappings.new_group()
return sourcemap.SourceMap(
version=3,
sources=used_source_files,
sources_content=[],
mappings=mappings.mappings(),
names=[],
)
def _parse_hlo_new_format(lines: list[str]) -> sourcemap.SourceMap:
file_names = {}
file_locations = {}
stack_frames = {}
current_section = None
for line in lines:
line = line.strip()
if not line:
continue
if line in ["FileNames", "FunctionNames", "FileLocations", "StackFrames"]:
current_section = line
continue
if current_section == "FileNames":
match = re.match(r"(\d+)\s+\"(.*)\"", line)
if match:
file_names[int(match.group(1))] = match.group(2)
elif current_section == "FileLocations":
# Format: 1 {file_name_id=1 function_name_id=1 line=153 end_line=153 column=2 end_column=31}
match = re.match(r"(\d+)\s+{(.*)}", line)
if match:
loc_id = int(match.group(1))
attrs = match.group(2)
loc_data = {}
for part in attrs.split():
if "=" in part:
k, v = part.split("=")
if k not in ["file_name_id", "function_name_id", "line",
"end_line", "column", "end_column"]:
raise ValueError(f"Unknown attribute for FileLocations: {k}")
loc_data[k] = int(v)
file_locations[loc_id] = loc_data
elif current_section == "StackFrames":
# Format: 1 {file_location_id=1 parent_frame_id=1}
match = re.match(r"(\d+)\s+{(.*)}", line)
if match:
frame_id = int(match.group(1))
attrs = match.group(2)
frame_data = {}
for part in attrs.split():
if "=" in part:
k, v = part.split("=")
if k not in ["file_location_id", "parent_frame_id"]:
raise ValueError(f"Unknown attribute for StackFrames: {k}")
frame_data[k] = int(v)
stack_frames[frame_id] = frame_data
mappings = sourcemap.MappingsGenerator()
used_source_files = []
for line in lines:
mappings.new_group()
if "metadata={" in line:
match = re.search(r"stack_frame_id=(\d+)", line)
if match:
stack_frame_id = int(match.group(1))
if stack_frame_id in stack_frames:
frame = stack_frames[stack_frame_id]
file_loc = file_locations.get(frame["file_location_id"])
if file_loc:
file_name = file_names.get(file_loc["file_name_id"])
if file_name:
if file_name not in used_source_files:
used_source_files.append(file_name)
src_file_idx = used_source_files.index(file_name)
src_line = file_loc["line"] - 1
first_col = line.index(line.strip()[0])
mappings.new_segment(first_col, src_file_idx, src_line, 0)
else:
raise ValueError(f"Could not find mapping for {file_loc=}")
else:
raise ValueError(f"Could not find mapping for {stack_frame_id=}")
mappings.new_group()
return sourcemap.SourceMap(
version=3,
sources=used_source_files,
sources_content=[],
mappings=mappings.mappings(),
names=[],
)
def parse_hlo_dump(text: str) -> sourcemap.SourceMap:
lines = text.split("\n")
if "FileNames" in text:
return _parse_hlo_new_format(lines)
return _parse_hlo_old_format(lines)
def trace_and_lower(work_dir, f, f_args, f_kwargs, **_):
lowered = jax.jit(lambda *args: f(*args, **f_kwargs)).lower(*f_args)
return (lowered, work_dir)
def stable_hlo_generate_dump(args: tuple[Any, str],
**_) -> common.SourceMapDump:
lowered, work_dir = args
del work_dir
hlo_text = lowered.as_text(debug_info=True)
source_map = mlir.create_mlir_sourcemap(hlo_text)
return common.SourceMapDump(
source_map=source_map,
generated_code=hlo_text,
pass_name=HloPass.STABLE_HLO.value,
)
common.register_pass(
common.Pass(
name=HloPass.STABLE_HLO.value,
compile_fn=trace_and_lower,
generate_dump=stable_hlo_generate_dump,
)
)
def original_hlo_generate_dump(args: tuple[Any, str],
**_) -> common.SourceMapDump:
lowered, work_dir = args
del work_dir
hlo_text = lowered.as_text(dialect="hlo", debug_info=True)
source_map = parse_hlo_dump(hlo_text)
return common.SourceMapDump(
source_map=source_map,
generated_code=hlo_text,
pass_name=HloPass.ORIGINAL.value,
)
common.register_pass(
common.Pass(
name=HloPass.ORIGINAL.value,
compile_fn=trace_and_lower,
generate_dump=original_hlo_generate_dump,
)
)
def optimized_generate_dump(args: tuple[Any, str],
xla_compiler_flags: dict[str, Any] | None = None,
**_) -> common.SourceMapDump:
lowered, work_dir = args
compilation_args = {"xla_dump_to": work_dir, **(xla_compiler_flags or {})}
hlo_text = lowered.compile(compilation_args).as_text()
source_map = parse_hlo_dump(hlo_text)
return common.SourceMapDump(
source_map=source_map,
generated_code=hlo_text,
pass_name=HloPass.OPTIMIZED.value,
)
common.register_pass(
common.Pass(
name=HloPass.OPTIMIZED.value,
compile_fn=trace_and_lower,
generate_dump=optimized_generate_dump,
)
)
@@ -0,0 +1,78 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Source mapping generator for Jaxprs."""
import re
from typing import Any
import jax
from jax._src import config
from jax._src import core
from jax._src import source_info_util
from jax._src import sourcemap
from jax.experimental.source_mapper import common
source_info_util.register_exclusion(__file__)
def compile_jaxpr(work_dir, f, f_args, f_kwargs, **_):
del work_dir
return jax.make_jaxpr(f)(*f_args, **f_kwargs)
def canonicalize_filename(file_name: str):
pattern = config.hlo_source_file_canonicalization_regex.value
if pattern:
file_name = re.sub(pattern, '', file_name)
return file_name
def make_jaxpr_dump(jaxpr: core.Jaxpr, **_) -> common.SourceMapDump:
pprint_mappings: list[list[tuple[int, int, Any]]] = []
pprint_str = jaxpr.pretty_print(source_map=pprint_mappings)
used_source_files = []
mappings = sourcemap.MappingsGenerator()
for pprint_map_line in pprint_mappings:
mappings.new_group()
for pprint_segment in pprint_map_line:
start_col, end_col, frame = pprint_segment
del end_col
file_name = canonicalize_filename(frame.file_name)
if file_name not in used_source_files:
used_source_files.append(file_name)
file_idx = used_source_files.index(file_name)
src_line = frame.start_line - 1 # Zero-indexed
src_col = frame.start_column
# A segment is a tuple of the form:
# (generated_col, src_file_idx, src_line, src_col)
mappings.new_segment(start_col, file_idx, src_line, src_col)
mappings.new_group()
source_map = sourcemap.SourceMap(
version=3,
sources=used_source_files,
sources_content=[],
mappings=mappings.mappings(),
names=[],
)
return common.SourceMapDump(
source_map=source_map,
generated_code=pprint_str,
pass_name='jaxpr',
)
common.register_pass(
common.Pass(
name='jaxpr', compile_fn=compile_jaxpr, generate_dump=make_jaxpr_dump
)
)
@@ -0,0 +1,141 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generating source mappings for MLIR dialects."""
import collections
import re
from typing import cast
from jax._src import sourcemap
# TODO(justinfu): Make a proper parser for MLIR dumps.
LOC_REGEX = re.compile(r"loc\(#loc(?P<id>[0-9]+)\)")
SRC_REGEX = re.compile(
r"#loc(?P<id>[0-9]+) ="
r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)"
r"( to (?P<endlineno>[0-9]+)?:(?P<endcolno>[0-9]+))?\)"
)
SCOPED_REGEX = re.compile(
r"#loc(?P<id>[0-9]+) = loc\(\"(?P<scope>.*)\"\(#loc(?P<tgt_id>[0-9]+)\)\)"
)
CALLSITE_REGEX = re.compile(
r"#loc(?P<id>[0-9]+) = loc\(callsite\(#loc(?P<callee>[0-9]+) at"
r" #loc(?P<caller>[0-9]+)\)\)"
)
Location = collections.namedtuple("Location", ["file", "line", "col"])
Redirect = collections.namedtuple("Redirect", ["tgt_id"])
def create_mlir_sourcemap(mlir_dump: str) -> sourcemap.SourceMap:
mappings = sourcemap.MappingsGenerator()
dump_lines: list[str] = mlir_dump.split("\n")
segment_dict, sources = parse_mlir_locations(dump_lines)
used_sources = []
used_sources_filenames = []
for line in dump_lines:
mappings.new_group()
match = LOC_REGEX.search(line)
if match:
loc_id = int(match.group("id"))
if loc_id not in segment_dict:
# TODO(justinfu): This happens on fusion locations - need to implement.
continue
segment = list(segment_dict[loc_id])
first_col = line.index(line.strip()[0])
segment[0] = first_col
# Remap the sourcefile index to only sourcefiles that are used.
# This is optional but makes the mapping file smaller by pruning
# unused sourcefiles.
source_idx = segment[1]
if source_idx not in used_sources:
used_sources.append(source_idx)
used_sources_filenames.append(sources[source_idx])
segment[1] = used_sources.index(source_idx)
mappings.new_segment(*segment)
mappings.new_group()
return sourcemap.SourceMap(
version=3,
sources=used_sources_filenames,
sources_content=[''] * len(used_sources_filenames),
mappings=mappings.mappings(),
names=[],
)
def parse_mlir_locations(
mlir_dump: list[str],
) -> tuple[dict[int, sourcemap.Segment], list[str]]:
locations: dict[int, Location | Redirect] = {}
source_files = []
for line in mlir_dump:
if line.startswith("#loc"):
src_match = SRC_REGEX.match(line)
if src_match:
match_dict = src_match.groupdict()
filename = match_dict["file"]
locations[int(match_dict["id"])] = Location(
file=filename,
line=int(match_dict["line"]),
col=int(match_dict["col"]),
)
if filename not in source_files:
source_files.append(filename)
continue
scoped_match = SCOPED_REGEX.match(line)
if scoped_match:
match_dict = scoped_match.groupdict()
locations[int(match_dict["id"])] = Redirect(
tgt_id=int(match_dict["tgt_id"])
)
continue
callsite_match = CALLSITE_REGEX.match(line)
if callsite_match:
match_dict = callsite_match.groupdict()
locations[int(match_dict["id"])] = Redirect(
tgt_id=int(match_dict["callee"])
)
continue
if "loc(unknown)" in line:
continue
# Resolve redirects
while True:
new_locations: dict[int, Location | Redirect] = {}
updated = False
for loc_id, loc in locations.items():
if isinstance(loc, Redirect):
new_locations[loc_id] = locations[loc.tgt_id]
updated = True
else:
new_locations[loc_id] = loc
locations = new_locations
if not updated:
break
segment_dict: dict[int, sourcemap.Segment] = {}
for id_, loc in locations.items():
# A segment is a tuple of the form:
# (generated_col, src_file_idx, src_line, src_col)
loc = cast(Location, loc)
segment_dict[id_] = (
0,
source_files.index(loc.file),
loc.line - 1, # Zero-indexed, so offset by 1.
loc.col,
)
return segment_dict, source_files