hand
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(os.path.dirname(__file__))
|
||||
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,61 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from jax._src.interpreters.ad import (
|
||||
JVPTrace as JVPTrace,
|
||||
JVPTracer as JVPTracer,
|
||||
UndefinedPrimal as UndefinedPrimal,
|
||||
Zero as Zero,
|
||||
add_jaxvals as add_jaxvals,
|
||||
add_jaxvals_p as add_jaxvals_p,
|
||||
add_tangents as add_tangents,
|
||||
defbilinear as defbilinear,
|
||||
defjvp as defjvp,
|
||||
defjvp2 as defjvp2,
|
||||
deflinear as deflinear,
|
||||
deflinear2 as deflinear2,
|
||||
get_primitive_transpose as get_primitive_transpose,
|
||||
instantiate_zeros as instantiate_zeros,
|
||||
is_undefined_primal as is_undefined_primal,
|
||||
jvp as jvp,
|
||||
linearize as linearize,
|
||||
primitive_jvps as primitive_jvps,
|
||||
primitive_transposes as primitive_transposes,
|
||||
zeros_like_aval as zeros_like_aval,
|
||||
)
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Deprecated in v0.9.0; finalized in v0.10.0.
|
||||
# TODO(jakevdp) remove entry in v0.11.0.
|
||||
"reducing_transposes": (
|
||||
(
|
||||
"jax.interpreters.ad.reducing_transposes was deprecated in v0.9.0."
|
||||
" and removed in v0.10.0. It has been unused since JAX v0.4.38."
|
||||
),
|
||||
None,
|
||||
),
|
||||
}
|
||||
|
||||
import typing
|
||||
if not typing.TYPE_CHECKING:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del typing
|
||||
@@ -0,0 +1,48 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.interpreters.batching import (
|
||||
axis_primitive_batchers as axis_primitive_batchers,
|
||||
bdim_at_front as bdim_at_front,
|
||||
broadcast as broadcast,
|
||||
defbroadcasting as defbroadcasting,
|
||||
defreducer as defreducer,
|
||||
defvectorized as defvectorized,
|
||||
fancy_primitive_batchers as fancy_primitive_batchers,
|
||||
not_mapped as not_mapped,
|
||||
primitive_batchers as primitive_batchers,
|
||||
register_vmappable as register_vmappable,
|
||||
unregister_vmappable as unregister_vmappable,
|
||||
)
|
||||
|
||||
|
||||
_deprecations = {
|
||||
# Deprecated in JAX v0.7.1; removed in JAX v0.10.0.
|
||||
# TODO(jakevdp):remove this for JAX v0.11.0
|
||||
"NotMapped": (
|
||||
"jax.interpreters.batching.NotMapped is deprecated.",
|
||||
None,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
import typing as _typing
|
||||
if not _typing.TYPE_CHECKING:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright 2021 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.interpreters.mlir import (
|
||||
AxisContext as AxisContext,
|
||||
ConstantHandler as ConstantHandler,
|
||||
DEVICE_TO_DEVICE_TYPE as DEVICE_TO_DEVICE_TYPE,
|
||||
LoweringParameters as LoweringParameters,
|
||||
LoweringResult as LoweringResult,
|
||||
LoweringRule as LoweringRule,
|
||||
LoweringRuleContext as LoweringRuleContext,
|
||||
ModuleContext as ModuleContext,
|
||||
RECV_FROM_HOST_TYPE as RECV_FROM_HOST_TYPE,
|
||||
SEND_TO_HOST_TYPE as SEND_TO_HOST_TYPE,
|
||||
ShapePolyLoweringState as ShapePolyLoweringState,
|
||||
Token as Token,
|
||||
TokenSet as TokenSet,
|
||||
Value as Value,
|
||||
call_lowering as _call_lowering, # noqa: F401
|
||||
_lowerings as _lowerings,
|
||||
_platform_specific_lowerings as _platform_specific_lowerings,
|
||||
aval_to_ir_type as aval_to_ir_type,
|
||||
aval_to_ir_types as aval_to_ir_types,
|
||||
core_call_lowering as core_call_lowering,
|
||||
dense_int_array as dense_int_array,
|
||||
dense_int_elements as dense_int_elements,
|
||||
dtype_to_ir_type as dtype_to_ir_type,
|
||||
flatten_ir_types as flatten_ir_types,
|
||||
flatten_ir_values as flatten_ir_values,
|
||||
unflatten_ir_values_like_types as unflatten_ir_values_like_types,
|
||||
i32_attr as i32_attr,
|
||||
i64_attr as i64_attr,
|
||||
ir as ir,
|
||||
ir_attribute as ir_attribute,
|
||||
ir_constant as ir_constant,
|
||||
ir_constants as ir_constants,
|
||||
ir_type_handlers as ir_type_handlers,
|
||||
jaxpr_subcomp as jaxpr_subcomp,
|
||||
lower_fun as lower_fun,
|
||||
lower_jaxpr_to_fun as lower_jaxpr_to_fun,
|
||||
lower_jaxpr_to_module as lower_jaxpr_to_module,
|
||||
make_ir_context as make_ir_context,
|
||||
merge_mlir_modules as merge_mlir_modules,
|
||||
module_to_bytecode as module_to_bytecode,
|
||||
module_to_string as module_to_string,
|
||||
register_constant_handler as register_constant_handler,
|
||||
register_lowering as register_lowering,
|
||||
shape_tensor as shape_tensor,
|
||||
token_type as token_type,
|
||||
)
|
||||
|
||||
from jax._src.mesh import Mesh as Mesh
|
||||
from jax._src.sharding_impls import (
|
||||
MeshAxisName as MeshAxisName,
|
||||
ReplicaAxisContext as ReplicaAxisContext,
|
||||
SPMDAxisContext as SPMDAxisContext,
|
||||
ShardingContext as ShardingContext,
|
||||
)
|
||||
from jax._src.effects import lowerable_effects as lowerable_effects
|
||||
|
||||
|
||||
# TODO(dsuo): Temporarily maintain symbols related to callback lowering for sake
|
||||
# of public APIs.
|
||||
from jax._src.callback import (
|
||||
emit_python_callback as emit_python_callback,
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from jax._src.interpreters.partial_eval import (
|
||||
DynamicJaxprTracer as DynamicJaxprTracer,
|
||||
JaxprTracer as JaxprTracer,
|
||||
PartialVal as PartialVal,
|
||||
Val as Val,
|
||||
custom_partial_eval_rules as custom_partial_eval_rules,
|
||||
dce_jaxpr as dce_jaxpr,
|
||||
dce_jaxpr_call_rule as dce_jaxpr_call_rule,
|
||||
dce_jaxpr_closed_call_rule as dce_jaxpr_closed_call_rule,
|
||||
dce_jaxpr_consts as dce_jaxpr_consts,
|
||||
dce_rules as dce_rules,
|
||||
partial_eval_jaxpr_custom_rules as partial_eval_jaxpr_custom_rules,
|
||||
trace_to_jaxpr_dynamic as trace_to_jaxpr_dynamic,
|
||||
trace_to_jaxpr_nounits as trace_to_jaxpr_nounits,
|
||||
)
|
||||
@@ -0,0 +1,205 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.interpreters import pxla as _deprecated_pxla
|
||||
from jax._src import mesh as _deprecated_mesh
|
||||
from jax._src import op_shardings as _deprecated_op_shardings
|
||||
from jax._src import sharding_impls as _deprecated_sharding_impls
|
||||
|
||||
from jax._src.interpreters.pxla import (
|
||||
create_compile_options as create_compile_options,
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# deprecated as of JAX v0.8.2 (Dec 2025)
|
||||
"Index": (
|
||||
"jax.interpreters.pxla.Index is deprecated as of JAX v0.8.2.",
|
||||
_deprecated_pxla.Index,
|
||||
),
|
||||
"MeshAxisName": (
|
||||
(
|
||||
"jax.interpreters.pxla.MeshAxisName is deprecated as of JAX v0.8.2."
|
||||
" Use jax.sharding.Mesh axis names directly."
|
||||
),
|
||||
_deprecated_pxla.MeshAxisName,
|
||||
),
|
||||
"MeshComputation": (
|
||||
"jax.interpreters.pxla.MeshComputation is deprecated as of JAX v0.8.2.",
|
||||
_deprecated_pxla.MeshComputation,
|
||||
),
|
||||
"MeshExecutable": (
|
||||
"jax.interpreters.pxla.MeshExecutable is deprecated as of JAX v0.8.2.",
|
||||
_deprecated_pxla.MeshExecutable,
|
||||
),
|
||||
"global_aval_to_result_handler": (
|
||||
(
|
||||
"jax.interpreters.pxla.global_aval_to_result_handler is deprecated"
|
||||
" as of JAX v0.8.2."
|
||||
),
|
||||
_deprecated_pxla.global_aval_to_result_handler,
|
||||
),
|
||||
"global_avals_to_results_handler": (
|
||||
(
|
||||
"jax.interpreters.pxla.global_avals_to_results_handler is"
|
||||
" deprecated as of JAX v0.8.2."
|
||||
),
|
||||
_deprecated_pxla.global_avals_to_results_handler,
|
||||
),
|
||||
"global_result_handlers": (
|
||||
(
|
||||
"jax.interpreters.pxla.global_result_handlers is deprecated as of"
|
||||
" JAX v0.8.2."
|
||||
),
|
||||
_deprecated_pxla.global_result_handlers,
|
||||
),
|
||||
"thread_resources": (
|
||||
(
|
||||
"jax.interpreters.pxla.thread_resources is deprecated as of JAX"
|
||||
" v0.8.2. Please switch to using `with jax.set_mesh(mesh)` instead"
|
||||
" of `with mesh:` and then use `jax.sharding.get_abstract_mesh()`"
|
||||
" to get the current mesh."
|
||||
),
|
||||
_deprecated_mesh.thread_resources,
|
||||
),
|
||||
"are_hlo_shardings_equal": (
|
||||
(
|
||||
"jax.interpreters.pxla.are_hlo_shardings_equal is deprecated as of"
|
||||
" JAX v0.8.2."
|
||||
),
|
||||
_deprecated_op_shardings.are_hlo_shardings_equal,
|
||||
),
|
||||
"is_hlo_sharding_replicated": (
|
||||
(
|
||||
"jax.interpreters.pxla.is_hlo_sharding_replicated is deprecated as"
|
||||
" of JAX v0.8.2."
|
||||
),
|
||||
_deprecated_op_shardings.is_hlo_sharding_replicated,
|
||||
),
|
||||
"op_sharding_to_indices": (
|
||||
(
|
||||
"jax.interpreters.pxla.op_sharding_to_indices is deprecated as of"
|
||||
" JAX v0.8.2."
|
||||
),
|
||||
_deprecated_op_shardings.op_sharding_to_indices,
|
||||
),
|
||||
"ArrayMapping": (
|
||||
"jax.interpreters.pxla.ArrayMapping is deprecated as of JAX v0.8.2.",
|
||||
_deprecated_sharding_impls.ArrayMapping,
|
||||
),
|
||||
"_UNSPECIFIED": (
|
||||
"jax.interpreters.pxla._UNSPECIFIED is deprecated as of JAX v0.8.2.",
|
||||
_deprecated_sharding_impls.UNSPECIFIED,
|
||||
),
|
||||
"array_mapping_to_axis_resources": (
|
||||
(
|
||||
"jax.interpreters.pxla.array_mapping_to_axis_resources is"
|
||||
" deprecated as of JAX v0.8.2."
|
||||
),
|
||||
_deprecated_sharding_impls.array_mapping_to_axis_resources,
|
||||
),
|
||||
# Deprecated as of JAX v0.8.2; finalized in JAX v0.10.0; remove in v0.11.0.
|
||||
"MapTracer": (
|
||||
"jax.interpreters.pxla.MapTracer was removed in JAX v0.10.0.",
|
||||
None,
|
||||
),
|
||||
"PmapExecutable": (
|
||||
"jax.interpreters.pxla.PmapExecutable was removed in JAX v0.10.0.",
|
||||
None,
|
||||
),
|
||||
"parallel_callable": (
|
||||
(
|
||||
"jax.interpreters.pxla.parallel_callable was removed in JAX"
|
||||
" v0.10.0."
|
||||
),
|
||||
None,
|
||||
),
|
||||
"shard_args": (
|
||||
"jax.interpreters.pxla.shard_args was removed in JAX v0.10.0.",
|
||||
None,
|
||||
),
|
||||
"Chunked": (
|
||||
(
|
||||
"jax.interpreters.pxla.Chunked was removed in JAX v0.10.0."
|
||||
" Please use `jax.shard_map` instead of `jax.pmap`."
|
||||
),
|
||||
None,
|
||||
),
|
||||
"NoSharding": (
|
||||
(
|
||||
"jax.interpreters.pxla.NoSharding was removed in JAX v0.10.0."
|
||||
" Please use `jax.shard_map` instead of `jax.pmap`."
|
||||
),
|
||||
None,
|
||||
),
|
||||
"Replicated": (
|
||||
(
|
||||
"jax.interpreters.pxla.Replicated was removed in JAX v0.10.0."
|
||||
" Please use `jax.shard_map` instead of `jax.pmap`."
|
||||
),
|
||||
None,
|
||||
),
|
||||
"ShardedAxis": (
|
||||
(
|
||||
"jax.interpreters.pxla.ShardedAxis was removed in JAX v0.10.0."
|
||||
" Please use `jax.shard_map` instead of `jax.pmap`."
|
||||
),
|
||||
None,
|
||||
),
|
||||
"ShardingSpec": (
|
||||
(
|
||||
"jax.interpreters.pxla.ShardingSpec was removed in JAX v0.10.0."
|
||||
" Please use `jax.shard_map` instead of `jax.pmap`."
|
||||
),
|
||||
None,
|
||||
),
|
||||
"Unstacked": (
|
||||
(
|
||||
"jax.interpreters.pxla.Unstacked was removed in JAX v0.10.0."
|
||||
" Please use `jax.shard_map` instead of `jax.pmap`."
|
||||
),
|
||||
None,
|
||||
),
|
||||
"spec_to_indices": (
|
||||
(
|
||||
"jax.interpreters.pxla.spec_to_indices was removed in JAX"
|
||||
" v0.10.0. Please use `jax.shard_map` instead of `jax.pmap`."
|
||||
),
|
||||
None,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if _typing.TYPE_CHECKING:
|
||||
Index = _deprecated_pxla.Index
|
||||
MeshAxisName = _deprecated_pxla.MeshAxisName
|
||||
MeshComputation = _deprecated_pxla.MeshComputation
|
||||
MeshExecutable = _deprecated_pxla.MeshExecutable
|
||||
global_aval_to_result_handler = _deprecated_pxla.global_aval_to_result_handler
|
||||
global_avals_to_results_handler = _deprecated_pxla.global_avals_to_results_handler
|
||||
global_result_handlers = _deprecated_pxla.global_result_handlers
|
||||
thread_resources = _deprecated_mesh.thread_resources
|
||||
are_hlo_shardings_equal = _deprecated_op_shardings.are_hlo_shardings_equal
|
||||
is_hlo_sharding_replicated = _deprecated_op_shardings.is_hlo_sharding_replicated
|
||||
op_sharding_to_indices = _deprecated_op_shardings.op_sharding_to_indices
|
||||
ArrayMapping = _deprecated_sharding_impls.ArrayMapping
|
||||
_UNSPECIFIED = _deprecated_sharding_impls.UNSPECIFIED
|
||||
array_mapping_to_axis_resources = _deprecated_sharding_impls.array_mapping_to_axis_resources
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
@@ -0,0 +1,27 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ["apply_primitive", "canonicalize_dtype_handlers", "Backend"]
|
||||
|
||||
from jax._src.dtypes import (
|
||||
canonicalize_value_handlers as canonicalize_dtype_handlers
|
||||
)
|
||||
|
||||
from jax._src.dispatch import (
|
||||
apply_primitive as apply_primitive,
|
||||
)
|
||||
|
||||
from jax._src.lib import xla_client as _xc
|
||||
Backend = _xc._xla.Client
|
||||
del _xc
|
||||
Reference in New Issue
Block a user