hand
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.sourcemap import SourceMap as SourceMap
|
||||
from jax._src.sourcemap import MappingsGenerator as MappingsGenerator
|
||||
from jax.experimental.source_mapper.common import Pass as Pass
|
||||
from jax.experimental.source_mapper.common import register_pass as register_pass
|
||||
from jax.experimental.source_mapper.common import all_passes as all_passes
|
||||
from jax.experimental.source_mapper.common import filter_passes as filter_passes
|
||||
from jax.experimental.source_mapper.common import compile_with_env as compile_with_env
|
||||
from jax.experimental.source_mapper.common import SourceMapDump as SourceMapDump
|
||||
from jax.experimental.source_mapper.generate_map import generate_sourcemaps as generate_sourcemaps
|
||||
from jax.experimental.source_mapper.mlir import create_mlir_sourcemap as create_mlir_sourcemap
|
||||
|
||||
# We import the jaxpr and hlo passes to register them.
|
||||
import jax.experimental.source_mapper.jaxpr # noqa: F401
|
||||
from jax.experimental.source_mapper.jaxpr import canonicalize_filename as canonicalize_filename
|
||||
import jax.experimental.source_mapper.hlo # noqa: F401
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,92 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Common utilities for generating source maps."""
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import re
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Sequence
|
||||
|
||||
from absl import flags
|
||||
import jax
|
||||
from jax._src import sourcemap
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SourceMapDump:
|
||||
"""A container for a source map and the paired generated code."""
|
||||
source_map: sourcemap.SourceMap
|
||||
generated_code: str
|
||||
pass_name: str
|
||||
|
||||
|
||||
class CompileFn(Protocol):
|
||||
|
||||
def __call__(self, work_dir, fn, f_args, f_kwargs, /, **kwargs) -> Any:
|
||||
...
|
||||
|
||||
|
||||
class GenerateDumpFn(Protocol):
|
||||
|
||||
def __call__(self, compile_result: Any, /, **kwargs) -> SourceMapDump:
|
||||
...
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Pass:
|
||||
name: str
|
||||
compile_fn: CompileFn
|
||||
generate_dump: GenerateDumpFn
|
||||
|
||||
|
||||
_pass_registry = {}
|
||||
|
||||
|
||||
def register_pass(pass_: Pass):
|
||||
if pass_.name in _pass_registry:
|
||||
raise ValueError(f"Pass {pass_.name} already registered")
|
||||
_pass_registry[pass_.name] = pass_
|
||||
|
||||
|
||||
def all_passes() -> Sequence[Pass]:
|
||||
return list(_pass_registry.values())
|
||||
|
||||
|
||||
def filter_passes(regex: str) -> Sequence[Pass]:
|
||||
"""Gets all registered passes whose display name matches the given regex."""
|
||||
return [
|
||||
pass_
|
||||
for pass_ in _pass_registry.values()
|
||||
if re.match(regex, pass_.name)
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def flag_env(**kwargs):
|
||||
"""A context manager for setting and restoring flags."""
|
||||
old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs}
|
||||
for kwarg, new_value in kwargs.items():
|
||||
setattr(flags.FLAGS, kwarg, new_value)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for kwarg, old_value in old_flags.items():
|
||||
setattr(flags.FLAGS, kwarg, old_value)
|
||||
|
||||
|
||||
def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
|
||||
with flag_env(**env_flags):
|
||||
jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower(
|
||||
*f_args, **f_kwargs
|
||||
).compile(compiler_flags)
|
||||
@@ -0,0 +1,57 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Generates source maps for JAX functions."""
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Protocol
|
||||
from collections.abc import Sequence
|
||||
|
||||
from jax.experimental.source_mapper import common
|
||||
|
||||
|
||||
class SourceMapGeneratorFn(Protocol):
|
||||
def __call__(self, *args, **kwargs) -> Sequence[common.SourceMapDump]:
|
||||
...
|
||||
|
||||
|
||||
def generate_sourcemaps(
|
||||
f,
|
||||
passes: Sequence[common.Pass],
|
||||
**pass_kwargs
|
||||
) -> SourceMapGeneratorFn:
|
||||
"""Generates a SourceMapBundle for the specified compiler passes.
|
||||
|
||||
Args:
|
||||
f: The function to compile.
|
||||
passes: Which compiler passes to generate sourcemaps for.
|
||||
**pass_kwargs: Keyword arguments for individual passes.
|
||||
"""
|
||||
def wrapper(*args, **kwargs) -> Sequence[common.SourceMapDump]:
|
||||
pass_results: list[common.SourceMapDump] = []
|
||||
compile_cache = {}
|
||||
with tempfile.TemporaryDirectory() as work_dir:
|
||||
for pass_to_eval in passes:
|
||||
if pass_to_eval.compile_fn not in compile_cache:
|
||||
dirname = pass_to_eval.name.replace(":", "__")
|
||||
pass_work_dir = os.path.join(work_dir, dirname)
|
||||
os.makedirs(pass_work_dir, exist_ok=False)
|
||||
compile_result = pass_to_eval.compile_fn(
|
||||
pass_work_dir, f, args, kwargs, **pass_kwargs
|
||||
)
|
||||
compile_cache[pass_to_eval.compile_fn] = compile_result
|
||||
compile_result = compile_cache[pass_to_eval.compile_fn]
|
||||
pass_results.append(pass_to_eval.generate_dump(compile_result,
|
||||
**pass_kwargs))
|
||||
return pass_results
|
||||
return wrapper
|
||||
@@ -0,0 +1,227 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Source mapping generator for HLO dialects."""
|
||||
import enum
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import sourcemap
|
||||
|
||||
from jax.experimental.source_mapper import common
|
||||
from jax.experimental.source_mapper import mlir
|
||||
|
||||
|
||||
class HloPass(enum.Enum):
|
||||
STABLE_HLO = "hlo:stable-hlo"
|
||||
ORIGINAL = "hlo:original"
|
||||
OPTIMIZED = "hlo:optimized"
|
||||
|
||||
|
||||
METADATA_REGEX = re.compile(
|
||||
r"metadata={.*op_name=\"(?P<scope>.*)\""
|
||||
r" source_file=\"(?P<src_file>.*)\""
|
||||
r" source_line=(?P<src_line>[0-9]+).*?}"
|
||||
)
|
||||
|
||||
|
||||
# TODO(justinfu): Remove when new format is the default.
|
||||
def _parse_hlo_old_format(lines: list[str]) -> sourcemap.SourceMap:
|
||||
mappings = sourcemap.MappingsGenerator()
|
||||
used_source_files = []
|
||||
for line in lines:
|
||||
mappings.new_group()
|
||||
match = METADATA_REGEX.search(line)
|
||||
if match:
|
||||
match_dict = match.groupdict()
|
||||
_ = match_dict["scope"] # Unused
|
||||
src_file = match_dict["src_file"]
|
||||
src_line = int(match_dict["src_line"])
|
||||
if src_file not in used_source_files:
|
||||
used_source_files.append(src_file)
|
||||
src_file_idx = used_source_files.index(src_file)
|
||||
src_line -= 1 # Segments are zero-indexed
|
||||
first_col = line.index(line.strip()[0])
|
||||
mappings.new_segment(first_col, src_file_idx, src_line, 0)
|
||||
mappings.new_group()
|
||||
|
||||
return sourcemap.SourceMap(
|
||||
version=3,
|
||||
sources=used_source_files,
|
||||
sources_content=[],
|
||||
mappings=mappings.mappings(),
|
||||
names=[],
|
||||
)
|
||||
|
||||
|
||||
def _parse_hlo_new_format(lines: list[str]) -> sourcemap.SourceMap:
|
||||
file_names = {}
|
||||
file_locations = {}
|
||||
stack_frames = {}
|
||||
current_section = None
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line in ["FileNames", "FunctionNames", "FileLocations", "StackFrames"]:
|
||||
current_section = line
|
||||
continue
|
||||
|
||||
if current_section == "FileNames":
|
||||
match = re.match(r"(\d+)\s+\"(.*)\"", line)
|
||||
if match:
|
||||
file_names[int(match.group(1))] = match.group(2)
|
||||
elif current_section == "FileLocations":
|
||||
# Format: 1 {file_name_id=1 function_name_id=1 line=153 end_line=153 column=2 end_column=31}
|
||||
match = re.match(r"(\d+)\s+{(.*)}", line)
|
||||
if match:
|
||||
loc_id = int(match.group(1))
|
||||
attrs = match.group(2)
|
||||
loc_data = {}
|
||||
for part in attrs.split():
|
||||
if "=" in part:
|
||||
k, v = part.split("=")
|
||||
if k not in ["file_name_id", "function_name_id", "line",
|
||||
"end_line", "column", "end_column"]:
|
||||
raise ValueError(f"Unknown attribute for FileLocations: {k}")
|
||||
loc_data[k] = int(v)
|
||||
file_locations[loc_id] = loc_data
|
||||
elif current_section == "StackFrames":
|
||||
# Format: 1 {file_location_id=1 parent_frame_id=1}
|
||||
match = re.match(r"(\d+)\s+{(.*)}", line)
|
||||
if match:
|
||||
frame_id = int(match.group(1))
|
||||
attrs = match.group(2)
|
||||
frame_data = {}
|
||||
for part in attrs.split():
|
||||
if "=" in part:
|
||||
k, v = part.split("=")
|
||||
if k not in ["file_location_id", "parent_frame_id"]:
|
||||
raise ValueError(f"Unknown attribute for StackFrames: {k}")
|
||||
frame_data[k] = int(v)
|
||||
stack_frames[frame_id] = frame_data
|
||||
|
||||
mappings = sourcemap.MappingsGenerator()
|
||||
used_source_files = []
|
||||
|
||||
for line in lines:
|
||||
mappings.new_group()
|
||||
if "metadata={" in line:
|
||||
match = re.search(r"stack_frame_id=(\d+)", line)
|
||||
if match:
|
||||
stack_frame_id = int(match.group(1))
|
||||
if stack_frame_id in stack_frames:
|
||||
frame = stack_frames[stack_frame_id]
|
||||
file_loc = file_locations.get(frame["file_location_id"])
|
||||
if file_loc:
|
||||
file_name = file_names.get(file_loc["file_name_id"])
|
||||
if file_name:
|
||||
if file_name not in used_source_files:
|
||||
used_source_files.append(file_name)
|
||||
src_file_idx = used_source_files.index(file_name)
|
||||
src_line = file_loc["line"] - 1
|
||||
first_col = line.index(line.strip()[0])
|
||||
mappings.new_segment(first_col, src_file_idx, src_line, 0)
|
||||
else:
|
||||
raise ValueError(f"Could not find mapping for {file_loc=}")
|
||||
else:
|
||||
raise ValueError(f"Could not find mapping for {stack_frame_id=}")
|
||||
mappings.new_group()
|
||||
return sourcemap.SourceMap(
|
||||
version=3,
|
||||
sources=used_source_files,
|
||||
sources_content=[],
|
||||
mappings=mappings.mappings(),
|
||||
names=[],
|
||||
)
|
||||
|
||||
|
||||
def parse_hlo_dump(text: str) -> sourcemap.SourceMap:
|
||||
lines = text.split("\n")
|
||||
if "FileNames" in text:
|
||||
return _parse_hlo_new_format(lines)
|
||||
return _parse_hlo_old_format(lines)
|
||||
|
||||
|
||||
def trace_and_lower(work_dir, f, f_args, f_kwargs, **_):
|
||||
lowered = jax.jit(lambda *args: f(*args, **f_kwargs)).lower(*f_args)
|
||||
return (lowered, work_dir)
|
||||
|
||||
|
||||
def stable_hlo_generate_dump(args: tuple[Any, str],
|
||||
**_) -> common.SourceMapDump:
|
||||
lowered, work_dir = args
|
||||
del work_dir
|
||||
hlo_text = lowered.as_text(debug_info=True)
|
||||
source_map = mlir.create_mlir_sourcemap(hlo_text)
|
||||
return common.SourceMapDump(
|
||||
source_map=source_map,
|
||||
generated_code=hlo_text,
|
||||
pass_name=HloPass.STABLE_HLO.value,
|
||||
)
|
||||
|
||||
|
||||
common.register_pass(
|
||||
common.Pass(
|
||||
name=HloPass.STABLE_HLO.value,
|
||||
compile_fn=trace_and_lower,
|
||||
generate_dump=stable_hlo_generate_dump,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def original_hlo_generate_dump(args: tuple[Any, str],
|
||||
**_) -> common.SourceMapDump:
|
||||
lowered, work_dir = args
|
||||
del work_dir
|
||||
hlo_text = lowered.as_text(dialect="hlo", debug_info=True)
|
||||
source_map = parse_hlo_dump(hlo_text)
|
||||
return common.SourceMapDump(
|
||||
source_map=source_map,
|
||||
generated_code=hlo_text,
|
||||
pass_name=HloPass.ORIGINAL.value,
|
||||
)
|
||||
|
||||
|
||||
common.register_pass(
|
||||
common.Pass(
|
||||
name=HloPass.ORIGINAL.value,
|
||||
compile_fn=trace_and_lower,
|
||||
generate_dump=original_hlo_generate_dump,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def optimized_generate_dump(args: tuple[Any, str],
|
||||
xla_compiler_flags: dict[str, Any] | None = None,
|
||||
**_) -> common.SourceMapDump:
|
||||
lowered, work_dir = args
|
||||
compilation_args = {"xla_dump_to": work_dir, **(xla_compiler_flags or {})}
|
||||
hlo_text = lowered.compile(compilation_args).as_text()
|
||||
source_map = parse_hlo_dump(hlo_text)
|
||||
return common.SourceMapDump(
|
||||
source_map=source_map,
|
||||
generated_code=hlo_text,
|
||||
pass_name=HloPass.OPTIMIZED.value,
|
||||
)
|
||||
|
||||
|
||||
common.register_pass(
|
||||
common.Pass(
|
||||
name=HloPass.OPTIMIZED.value,
|
||||
compile_fn=trace_and_lower,
|
||||
generate_dump=optimized_generate_dump,
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Source mapping generator for Jaxprs."""
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import sourcemap
|
||||
from jax.experimental.source_mapper import common
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
def compile_jaxpr(work_dir, f, f_args, f_kwargs, **_):
|
||||
del work_dir
|
||||
return jax.make_jaxpr(f)(*f_args, **f_kwargs)
|
||||
|
||||
|
||||
def canonicalize_filename(file_name: str):
|
||||
pattern = config.hlo_source_file_canonicalization_regex.value
|
||||
if pattern:
|
||||
file_name = re.sub(pattern, '', file_name)
|
||||
return file_name
|
||||
|
||||
|
||||
def make_jaxpr_dump(jaxpr: core.Jaxpr, **_) -> common.SourceMapDump:
|
||||
pprint_mappings: list[list[tuple[int, int, Any]]] = []
|
||||
pprint_str = jaxpr.pretty_print(source_map=pprint_mappings)
|
||||
used_source_files = []
|
||||
mappings = sourcemap.MappingsGenerator()
|
||||
for pprint_map_line in pprint_mappings:
|
||||
mappings.new_group()
|
||||
for pprint_segment in pprint_map_line:
|
||||
start_col, end_col, frame = pprint_segment
|
||||
del end_col
|
||||
file_name = canonicalize_filename(frame.file_name)
|
||||
if file_name not in used_source_files:
|
||||
used_source_files.append(file_name)
|
||||
file_idx = used_source_files.index(file_name)
|
||||
src_line = frame.start_line - 1 # Zero-indexed
|
||||
src_col = frame.start_column
|
||||
# A segment is a tuple of the form:
|
||||
# (generated_col, src_file_idx, src_line, src_col)
|
||||
mappings.new_segment(start_col, file_idx, src_line, src_col)
|
||||
mappings.new_group()
|
||||
source_map = sourcemap.SourceMap(
|
||||
version=3,
|
||||
sources=used_source_files,
|
||||
sources_content=[],
|
||||
mappings=mappings.mappings(),
|
||||
names=[],
|
||||
)
|
||||
return common.SourceMapDump(
|
||||
source_map=source_map,
|
||||
generated_code=pprint_str,
|
||||
pass_name='jaxpr',
|
||||
)
|
||||
|
||||
|
||||
common.register_pass(
|
||||
common.Pass(
|
||||
name='jaxpr', compile_fn=compile_jaxpr, generate_dump=make_jaxpr_dump
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,141 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities for generating source mappings for MLIR dialects."""
|
||||
import collections
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from jax._src import sourcemap
|
||||
|
||||
|
||||
# TODO(justinfu): Make a proper parser for MLIR dumps.
|
||||
LOC_REGEX = re.compile(r"loc\(#loc(?P<id>[0-9]+)\)")
|
||||
|
||||
SRC_REGEX = re.compile(
|
||||
r"#loc(?P<id>[0-9]+) ="
|
||||
r" loc\(\"(?P<file>.*)\":(?P<line>[0-9]+):(?P<col>[0-9]+)"
|
||||
r"( to (?P<endlineno>[0-9]+)?:(?P<endcolno>[0-9]+))?\)"
|
||||
)
|
||||
|
||||
SCOPED_REGEX = re.compile(
|
||||
r"#loc(?P<id>[0-9]+) = loc\(\"(?P<scope>.*)\"\(#loc(?P<tgt_id>[0-9]+)\)\)"
|
||||
)
|
||||
|
||||
CALLSITE_REGEX = re.compile(
|
||||
r"#loc(?P<id>[0-9]+) = loc\(callsite\(#loc(?P<callee>[0-9]+) at"
|
||||
r" #loc(?P<caller>[0-9]+)\)\)"
|
||||
)
|
||||
|
||||
Location = collections.namedtuple("Location", ["file", "line", "col"])
|
||||
Redirect = collections.namedtuple("Redirect", ["tgt_id"])
|
||||
|
||||
|
||||
def create_mlir_sourcemap(mlir_dump: str) -> sourcemap.SourceMap:
|
||||
mappings = sourcemap.MappingsGenerator()
|
||||
dump_lines: list[str] = mlir_dump.split("\n")
|
||||
|
||||
segment_dict, sources = parse_mlir_locations(dump_lines)
|
||||
used_sources = []
|
||||
used_sources_filenames = []
|
||||
for line in dump_lines:
|
||||
mappings.new_group()
|
||||
match = LOC_REGEX.search(line)
|
||||
if match:
|
||||
loc_id = int(match.group("id"))
|
||||
if loc_id not in segment_dict:
|
||||
# TODO(justinfu): This happens on fusion locations - need to implement.
|
||||
continue
|
||||
segment = list(segment_dict[loc_id])
|
||||
first_col = line.index(line.strip()[0])
|
||||
segment[0] = first_col
|
||||
# Remap the sourcefile index to only sourcefiles that are used.
|
||||
# This is optional but makes the mapping file smaller by pruning
|
||||
# unused sourcefiles.
|
||||
source_idx = segment[1]
|
||||
if source_idx not in used_sources:
|
||||
used_sources.append(source_idx)
|
||||
used_sources_filenames.append(sources[source_idx])
|
||||
segment[1] = used_sources.index(source_idx)
|
||||
mappings.new_segment(*segment)
|
||||
mappings.new_group()
|
||||
|
||||
return sourcemap.SourceMap(
|
||||
version=3,
|
||||
sources=used_sources_filenames,
|
||||
sources_content=[''] * len(used_sources_filenames),
|
||||
mappings=mappings.mappings(),
|
||||
names=[],
|
||||
)
|
||||
|
||||
|
||||
def parse_mlir_locations(
|
||||
mlir_dump: list[str],
|
||||
) -> tuple[dict[int, sourcemap.Segment], list[str]]:
|
||||
locations: dict[int, Location | Redirect] = {}
|
||||
source_files = []
|
||||
for line in mlir_dump:
|
||||
if line.startswith("#loc"):
|
||||
src_match = SRC_REGEX.match(line)
|
||||
if src_match:
|
||||
match_dict = src_match.groupdict()
|
||||
filename = match_dict["file"]
|
||||
locations[int(match_dict["id"])] = Location(
|
||||
file=filename,
|
||||
line=int(match_dict["line"]),
|
||||
col=int(match_dict["col"]),
|
||||
)
|
||||
if filename not in source_files:
|
||||
source_files.append(filename)
|
||||
continue
|
||||
scoped_match = SCOPED_REGEX.match(line)
|
||||
if scoped_match:
|
||||
match_dict = scoped_match.groupdict()
|
||||
locations[int(match_dict["id"])] = Redirect(
|
||||
tgt_id=int(match_dict["tgt_id"])
|
||||
)
|
||||
continue
|
||||
callsite_match = CALLSITE_REGEX.match(line)
|
||||
if callsite_match:
|
||||
match_dict = callsite_match.groupdict()
|
||||
locations[int(match_dict["id"])] = Redirect(
|
||||
tgt_id=int(match_dict["callee"])
|
||||
)
|
||||
continue
|
||||
if "loc(unknown)" in line:
|
||||
continue
|
||||
# Resolve redirects
|
||||
while True:
|
||||
new_locations: dict[int, Location | Redirect] = {}
|
||||
updated = False
|
||||
for loc_id, loc in locations.items():
|
||||
if isinstance(loc, Redirect):
|
||||
new_locations[loc_id] = locations[loc.tgt_id]
|
||||
updated = True
|
||||
else:
|
||||
new_locations[loc_id] = loc
|
||||
locations = new_locations
|
||||
if not updated:
|
||||
break
|
||||
segment_dict: dict[int, sourcemap.Segment] = {}
|
||||
for id_, loc in locations.items():
|
||||
# A segment is a tuple of the form:
|
||||
# (generated_col, src_file_idx, src_line, src_col)
|
||||
loc = cast(Location, loc)
|
||||
segment_dict[id_] = (
|
||||
0,
|
||||
source_files.index(loc.file),
|
||||
loc.line - 1, # Zero-indexed, so offset by 1.
|
||||
loc.col,
|
||||
)
|
||||
return segment_dict, source_files
|
||||
Reference in New Issue
Block a user