hand
This commit is contained in:
@@ -0,0 +1,92 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Common utilities for generating source maps."""
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import re
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Sequence
|
||||
|
||||
from absl import flags
|
||||
import jax
|
||||
from jax._src import sourcemap
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SourceMapDump:
|
||||
"""A container for a source map and the paired generated code."""
|
||||
source_map: sourcemap.SourceMap
|
||||
generated_code: str
|
||||
pass_name: str
|
||||
|
||||
|
||||
class CompileFn(Protocol):
|
||||
|
||||
def __call__(self, work_dir, fn, f_args, f_kwargs, /, **kwargs) -> Any:
|
||||
...
|
||||
|
||||
|
||||
class GenerateDumpFn(Protocol):
|
||||
|
||||
def __call__(self, compile_result: Any, /, **kwargs) -> SourceMapDump:
|
||||
...
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Pass:
|
||||
name: str
|
||||
compile_fn: CompileFn
|
||||
generate_dump: GenerateDumpFn
|
||||
|
||||
|
||||
_pass_registry = {}
|
||||
|
||||
|
||||
def register_pass(pass_: Pass):
|
||||
if pass_.name in _pass_registry:
|
||||
raise ValueError(f"Pass {pass_.name} already registered")
|
||||
_pass_registry[pass_.name] = pass_
|
||||
|
||||
|
||||
def all_passes() -> Sequence[Pass]:
|
||||
return list(_pass_registry.values())
|
||||
|
||||
|
||||
def filter_passes(regex: str) -> Sequence[Pass]:
|
||||
"""Gets all registered passes whose display name matches the given regex."""
|
||||
return [
|
||||
pass_
|
||||
for pass_ in _pass_registry.values()
|
||||
if re.match(regex, pass_.name)
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def flag_env(**kwargs):
|
||||
"""A context manager for setting and restoring flags."""
|
||||
old_flags = {kwarg: getattr(flags.FLAGS, kwarg) for kwarg in kwargs}
|
||||
for kwarg, new_value in kwargs.items():
|
||||
setattr(flags.FLAGS, kwarg, new_value)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for kwarg, old_value in old_flags.items():
|
||||
setattr(flags.FLAGS, kwarg, old_value)
|
||||
|
||||
|
||||
def compile_with_env(f, f_args, f_kwargs, env_flags, compiler_flags):
|
||||
with flag_env(**env_flags):
|
||||
jax.jit(lambda *args, **kwargs: f(*args, **kwargs)).lower(
|
||||
*f_args, **f_kwargs
|
||||
).compile(compiler_flags)
|
||||
Reference in New Issue
Block a user