This commit is contained in:
2026-05-06 19:47:31 +07:00
parent 94d8682530
commit 12dbb7731b
9963 changed files with 2747894 additions and 0 deletions
@@ -0,0 +1,18 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from jax._src import traceback_util
traceback_util.register_exclusion(os.path.dirname(__file__))
@@ -0,0 +1,61 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from __future__ import annotations
from jax._src.interpreters.ad import (
JVPTrace as JVPTrace,
JVPTracer as JVPTracer,
UndefinedPrimal as UndefinedPrimal,
Zero as Zero,
add_jaxvals as add_jaxvals,
add_jaxvals_p as add_jaxvals_p,
add_tangents as add_tangents,
defbilinear as defbilinear,
defjvp as defjvp,
defjvp2 as defjvp2,
deflinear as deflinear,
deflinear2 as deflinear2,
get_primitive_transpose as get_primitive_transpose,
instantiate_zeros as instantiate_zeros,
is_undefined_primal as is_undefined_primal,
jvp as jvp,
linearize as linearize,
primitive_jvps as primitive_jvps,
primitive_transposes as primitive_transposes,
zeros_like_aval as zeros_like_aval,
)
_deprecations = {
# Deprecated in v0.9.0; finalized in v0.10.0.
# TODO(jakevdp) remove entry in v0.11.0.
"reducing_transposes": (
(
"jax.interpreters.ad.reducing_transposes was deprecated in v0.9.0."
" and removed in v0.10.0. It has been unused since JAX v0.4.38."
),
None,
),
}
import typing
if not typing.TYPE_CHECKING:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
@@ -0,0 +1,48 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.interpreters.batching import (
axis_primitive_batchers as axis_primitive_batchers,
bdim_at_front as bdim_at_front,
broadcast as broadcast,
defbroadcasting as defbroadcasting,
defreducer as defreducer,
defvectorized as defvectorized,
fancy_primitive_batchers as fancy_primitive_batchers,
not_mapped as not_mapped,
primitive_batchers as primitive_batchers,
register_vmappable as register_vmappable,
unregister_vmappable as unregister_vmappable,
)
_deprecations = {
# Deprecated in JAX v0.7.1; removed in JAX v0.10.0.
# TODO(jakevdp):remove this for JAX v0.11.0
"NotMapped": (
"jax.interpreters.batching.NotMapped is deprecated.",
None,
),
}
import typing as _typing
if not _typing.TYPE_CHECKING:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
@@ -0,0 +1,77 @@
# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.interpreters.mlir import (
AxisContext as AxisContext,
ConstantHandler as ConstantHandler,
DEVICE_TO_DEVICE_TYPE as DEVICE_TO_DEVICE_TYPE,
LoweringParameters as LoweringParameters,
LoweringResult as LoweringResult,
LoweringRule as LoweringRule,
LoweringRuleContext as LoweringRuleContext,
ModuleContext as ModuleContext,
RECV_FROM_HOST_TYPE as RECV_FROM_HOST_TYPE,
SEND_TO_HOST_TYPE as SEND_TO_HOST_TYPE,
ShapePolyLoweringState as ShapePolyLoweringState,
Token as Token,
TokenSet as TokenSet,
Value as Value,
call_lowering as _call_lowering, # noqa: F401
_lowerings as _lowerings,
_platform_specific_lowerings as _platform_specific_lowerings,
aval_to_ir_type as aval_to_ir_type,
aval_to_ir_types as aval_to_ir_types,
core_call_lowering as core_call_lowering,
dense_int_array as dense_int_array,
dense_int_elements as dense_int_elements,
dtype_to_ir_type as dtype_to_ir_type,
flatten_ir_types as flatten_ir_types,
flatten_ir_values as flatten_ir_values,
unflatten_ir_values_like_types as unflatten_ir_values_like_types,
i32_attr as i32_attr,
i64_attr as i64_attr,
ir as ir,
ir_attribute as ir_attribute,
ir_constant as ir_constant,
ir_constants as ir_constants,
ir_type_handlers as ir_type_handlers,
jaxpr_subcomp as jaxpr_subcomp,
lower_fun as lower_fun,
lower_jaxpr_to_fun as lower_jaxpr_to_fun,
lower_jaxpr_to_module as lower_jaxpr_to_module,
make_ir_context as make_ir_context,
merge_mlir_modules as merge_mlir_modules,
module_to_bytecode as module_to_bytecode,
module_to_string as module_to_string,
register_constant_handler as register_constant_handler,
register_lowering as register_lowering,
shape_tensor as shape_tensor,
token_type as token_type,
)
from jax._src.mesh import Mesh as Mesh
from jax._src.sharding_impls import (
MeshAxisName as MeshAxisName,
ReplicaAxisContext as ReplicaAxisContext,
SPMDAxisContext as SPMDAxisContext,
ShardingContext as ShardingContext,
)
from jax._src.effects import lowerable_effects as lowerable_effects
# TODO(dsuo): Temporarily maintain symbols related to callback lowering for sake
# of public APIs.
from jax._src.callback import (
emit_python_callback as emit_python_callback,
)
@@ -0,0 +1,30 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.interpreters.partial_eval import (
DynamicJaxprTracer as DynamicJaxprTracer,
JaxprTracer as JaxprTracer,
PartialVal as PartialVal,
Val as Val,
custom_partial_eval_rules as custom_partial_eval_rules,
dce_jaxpr as dce_jaxpr,
dce_jaxpr_call_rule as dce_jaxpr_call_rule,
dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule,
dce_jaxpr_consts as dce_jaxpr_consts,
dce_rules as dce_rules,
partial_eval_jaxpr_custom_rules as partial_eval_jaxpr_custom_rules,
trace_to_jaxpr_dynamic as trace_to_jaxpr_dynamic,
trace_to_jaxpr_nounits as trace_to_jaxpr_nounits,
)
@@ -0,0 +1,205 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.interpreters import pxla as _deprecated_pxla
from jax._src import mesh as _deprecated_mesh
from jax._src import op_shardings as _deprecated_op_shardings
from jax._src import sharding_impls as _deprecated_sharding_impls
from jax._src.interpreters.pxla import (
create_compile_options as create_compile_options,
)
_deprecations = {
# deprecated as of JAX v0.8.2 (Dec 2025)
"Index": (
"jax.interpreters.pxla.Index is deprecated as of JAX v0.8.2.",
_deprecated_pxla.Index,
),
"MeshAxisName": (
(
"jax.interpreters.pxla.MeshAxisName is deprecated as of JAX v0.8.2."
" Use jax.sharding.Mesh axis names directly."
),
_deprecated_pxla.MeshAxisName,
),
"MeshComputation": (
"jax.interpreters.pxla.MeshComputation is deprecated as of JAX v0.8.2.",
_deprecated_pxla.MeshComputation,
),
"MeshExecutable": (
"jax.interpreters.pxla.MeshExecutable is deprecated as of JAX v0.8.2.",
_deprecated_pxla.MeshExecutable,
),
"global_aval_to_result_handler": (
(
"jax.interpreters.pxla.global_aval_to_result_handler is deprecated"
" as of JAX v0.8.2."
),
_deprecated_pxla.global_aval_to_result_handler,
),
"global_avals_to_results_handler": (
(
"jax.interpreters.pxla.global_avals_to_results_handler is"
" deprecated as of JAX v0.8.2."
),
_deprecated_pxla.global_avals_to_results_handler,
),
"global_result_handlers": (
(
"jax.interpreters.pxla.global_result_handlers is deprecated as of"
" JAX v0.8.2."
),
_deprecated_pxla.global_result_handlers,
),
"thread_resources": (
(
"jax.interpreters.pxla.thread_resources is deprecated as of JAX"
" v0.8.2. Please switch to using `with jax.set_mesh(mesh)` instead"
" of `with mesh:` and then use `jax.sharding.get_abstract_mesh()`"
" to get the current mesh."
),
_deprecated_mesh.thread_resources,
),
"are_hlo_shardings_equal": (
(
"jax.interpreters.pxla.are_hlo_shardings_equal is deprecated as of"
" JAX v0.8.2."
),
_deprecated_op_shardings.are_hlo_shardings_equal,
),
"is_hlo_sharding_replicated": (
(
"jax.interpreters.pxla.is_hlo_sharding_replicated is deprecated as"
" of JAX v0.8.2."
),
_deprecated_op_shardings.is_hlo_sharding_replicated,
),
"op_sharding_to_indices": (
(
"jax.interpreters.pxla.op_sharding_to_indices is deprecated as of"
" JAX v0.8.2."
),
_deprecated_op_shardings.op_sharding_to_indices,
),
"ArrayMapping": (
"jax.interpreters.pxla.ArrayMapping is deprecated as of JAX v0.8.2.",
_deprecated_sharding_impls.ArrayMapping,
),
"_UNSPECIFIED": (
"jax.interpreters.pxla._UNSPECIFIED is deprecated as of JAX v0.8.2.",
_deprecated_sharding_impls.UNSPECIFIED,
),
"array_mapping_to_axis_resources": (
(
"jax.interpreters.pxla.array_mapping_to_axis_resources is"
" deprecated as of JAX v0.8.2."
),
_deprecated_sharding_impls.array_mapping_to_axis_resources,
),
# Deprecated as of JAX v0.8.2; finalized in JAX v0.10.0; remove in v0.11.0.
"MapTracer": (
"jax.interpreters.pxla.MapTracer was removed in JAX v0.10.0.",
None,
),
"PmapExecutable": (
"jax.interpreters.pxla.PmapExecutable was removed in JAX v0.10.0.",
None,
),
"parallel_callable": (
(
"jax.interpreters.pxla.parallel_callable was removed in JAX"
" v0.10.0."
),
None,
),
"shard_args": (
"jax.interpreters.pxla.shard_args was removed in JAX v0.10.0.",
None,
),
"Chunked": (
(
"jax.interpreters.pxla.Chunked was removed in JAX v0.10.0."
" Please use `jax.shard_map` instead of `jax.pmap`."
),
None,
),
"NoSharding": (
(
"jax.interpreters.pxla.NoSharding was removed in JAX v0.10.0."
" Please use `jax.shard_map` instead of `jax.pmap`."
),
None,
),
"Replicated": (
(
"jax.interpreters.pxla.Replicated was removed in JAX v0.10.0."
" Please use `jax.shard_map` instead of `jax.pmap`."
),
None,
),
"ShardedAxis": (
(
"jax.interpreters.pxla.ShardedAxis was removed in JAX v0.10.0."
" Please use `jax.shard_map` instead of `jax.pmap`."
),
None,
),
"ShardingSpec": (
(
"jax.interpreters.pxla.ShardingSpec was removed in JAX v0.10.0."
" Please use `jax.shard_map` instead of `jax.pmap`."
),
None,
),
"Unstacked": (
(
"jax.interpreters.pxla.Unstacked was removed in JAX v0.10.0."
" Please use `jax.shard_map` instead of `jax.pmap`."
),
None,
),
"spec_to_indices": (
(
"jax.interpreters.pxla.spec_to_indices was removed in JAX"
" v0.10.0. Please use `jax.shard_map` instead of `jax.pmap`."
),
None,
),
}
import typing as _typing
if _typing.TYPE_CHECKING:
Index = _deprecated_pxla.Index
MeshAxisName = _deprecated_pxla.MeshAxisName
MeshComputation = _deprecated_pxla.MeshComputation
MeshExecutable = _deprecated_pxla.MeshExecutable
global_aval_to_result_handler = _deprecated_pxla.global_aval_to_result_handler
global_avals_to_results_handler = _deprecated_pxla.global_avals_to_results_handler
global_result_handlers = _deprecated_pxla.global_result_handlers
thread_resources = _deprecated_mesh.thread_resources
are_hlo_shardings_equal = _deprecated_op_shardings.are_hlo_shardings_equal
is_hlo_sharding_replicated = _deprecated_op_shardings.is_hlo_sharding_replicated
op_sharding_to_indices = _deprecated_op_shardings.op_sharding_to_indices
ArrayMapping = _deprecated_sharding_impls.ArrayMapping
_UNSPECIFIED = _deprecated_sharding_impls.UNSPECIFIED
array_mapping_to_axis_resources = _deprecated_sharding_impls.array_mapping_to_axis_resources
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
@@ -0,0 +1,27 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ["apply_primitive", "canonicalize_dtype_handlers", "Backend"]
from jax._src.dtypes import (
canonicalize_value_handlers as canonicalize_dtype_handlers
)
from jax._src.dispatch import (
apply_primitive as apply_primitive,
)
from jax._src.lib import xla_client as _xc
Backend = _xc._xla.Client
del _xc