hand
This commit is contained in:
@@ -0,0 +1,49 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Modules for JAX extensions.
|
||||
|
||||
The :mod:`jax.extend` module provides modules for access to JAX
|
||||
internal machinery. See
|
||||
`JEP #15856 <https://docs.jax.dev/en/latest/jep/15856-jex.html>`_.
|
||||
|
||||
This module is not the only means by which JAX aims to be
|
||||
extensible. For example, the main JAX API offers mechanisms for
|
||||
`customizing derivatives
|
||||
<https://docs.jax.dev/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html>`_,
|
||||
`registering custom pytree definitions
|
||||
<https://docs.jax.dev/en/latest/custom_pytrees.html#pytrees-custom-pytree-nodes>`_,
|
||||
and more.
|
||||
|
||||
API policy
|
||||
----------
|
||||
|
||||
Unlike the
|
||||
`public API <https://docs.jax.dev/en/latest/api_compatibility.html>`_,
|
||||
this module offers **no compatibility guarantee** across releases.
|
||||
Breaking changes will be announced via the
|
||||
`JAX project changelog <https://docs.jax.dev/en/latest/changelog.html>`_.
|
||||
"""
|
||||
|
||||
from jax.extend import (
|
||||
backend as backend,
|
||||
core as core,
|
||||
linear_util as linear_util,
|
||||
lowering as lowering,
|
||||
mlir as mlir,
|
||||
pallas as pallas,
|
||||
random as random,
|
||||
sharding as sharding,
|
||||
source_info_util as source_info_util,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,43 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.api import (
|
||||
clear_backends as clear_backends,
|
||||
)
|
||||
from jax._src.compiler import (
|
||||
get_compile_options as get_compile_options,
|
||||
)
|
||||
from jax._src.xla_bridge import (
|
||||
backends as backends,
|
||||
backend_xla_version as backend_xla_version,
|
||||
get_backend as get_backend,
|
||||
register_backend_factory as register_backend_factory,
|
||||
)
|
||||
from jax._src.interpreters.pxla import (
|
||||
clear_in_memory_compilation_cache as clear_in_memory_compilation_cache,
|
||||
get_default_device as get_default_device,
|
||||
)
|
||||
from jax._src import (
|
||||
util as _util
|
||||
)
|
||||
register_backend_cache = _util.register_cache
|
||||
|
||||
from jax._src.lib import (
|
||||
ifrt_proxy as ifrt_proxy
|
||||
)
|
||||
|
||||
del _util
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.abstract_arrays import (
|
||||
array_types as array_types
|
||||
)
|
||||
|
||||
from jax._src.core import (
|
||||
AbstractToken as AbstractToken,
|
||||
CallPrimitive as CallPrimitive,
|
||||
ClosedJaxpr as ClosedJaxpr,
|
||||
DebugInfo as DebugInfo,
|
||||
DropVar as DropVar,
|
||||
Effect as Effect,
|
||||
Effects as Effects,
|
||||
InconclusiveDimensionOperation as InconclusiveDimensionOperation,
|
||||
Jaxpr as Jaxpr,
|
||||
JaxprEqn as JaxprEqn,
|
||||
JaxprTypeError as JaxprTypeError,
|
||||
Literal as Literal,
|
||||
Primitive as Primitive,
|
||||
Token as Token,
|
||||
TraceTag as TraceTag,
|
||||
Var as Var,
|
||||
call_impl as call_impl,
|
||||
check_jaxpr as check_jaxpr,
|
||||
concrete_or_error as concrete_or_error,
|
||||
find_top_trace as find_top_trace,
|
||||
gensym as gensym,
|
||||
get_opaque_trace_state as get_opaque_trace_state,
|
||||
jaxpr_as_fun as jaxpr_as_fun,
|
||||
jaxprs_in_params as jaxprs_in_params,
|
||||
mapped_aval as mapped_aval,
|
||||
new_jaxpr_eqn as new_jaxpr_eqn,
|
||||
no_effects as no_effects,
|
||||
nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, # noqa: F401
|
||||
primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype,
|
||||
set_current_trace as set_current_trace,
|
||||
subjaxprs as subjaxprs,
|
||||
take_current_trace as take_current_trace,
|
||||
unmapped_aval as unmapped_aval,
|
||||
unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, # noqa: F401
|
||||
unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, # noqa: F401
|
||||
unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, # noqa: F401
|
||||
valid_jaxtype as valid_jaxtype,
|
||||
)
|
||||
|
||||
from . import primitives as primitives
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,245 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.ad_checkpoint import (
|
||||
name_p as name_p,
|
||||
remat_p as remat_p,
|
||||
)
|
||||
|
||||
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
|
||||
|
||||
from jax._src.core import (
|
||||
call_p as call_p,
|
||||
closed_call_p as closed_call_p
|
||||
)
|
||||
|
||||
from jax._src.custom_derivatives import (
|
||||
custom_jvp_call_p as custom_jvp_call_p,
|
||||
custom_vjp_call_p as custom_vjp_call_p,
|
||||
)
|
||||
|
||||
from jax._src.dispatch import device_put_p as device_put_p
|
||||
|
||||
from jax._src.interpreters.ad import (
|
||||
add_jaxvals_p as add_jaxvals_p,
|
||||
custom_lin_p as custom_lin_p,
|
||||
)
|
||||
|
||||
from jax._src.lax.lax import (
|
||||
abs_p as abs_p,
|
||||
acos_p as acos_p,
|
||||
acosh_p as acosh_p,
|
||||
add_p as add_p,
|
||||
after_all_p as after_all_p,
|
||||
and_p as and_p,
|
||||
argmax_p as argmax_p,
|
||||
argmin_p as argmin_p,
|
||||
asin_p as asin_p,
|
||||
asinh_p as asinh_p,
|
||||
atan_p as atan_p,
|
||||
atan2_p as atan2_p,
|
||||
atanh_p as atanh_p,
|
||||
bitcast_convert_type_p as bitcast_convert_type_p,
|
||||
broadcast_in_dim_p as broadcast_in_dim_p,
|
||||
cbrt_p as cbrt_p,
|
||||
ceil_p as ceil_p,
|
||||
clamp_p as clamp_p,
|
||||
clz_p as clz_p,
|
||||
complex_p as complex_p,
|
||||
concatenate_p as concatenate_p,
|
||||
conj_p as conj_p,
|
||||
convert_element_type_p as convert_element_type_p,
|
||||
copy_p as copy_p,
|
||||
cos_p as cos_p,
|
||||
cosh_p as cosh_p,
|
||||
create_token_p as create_token_p,
|
||||
div_p as div_p,
|
||||
dot_general_p as dot_general_p,
|
||||
eq_p as eq_p,
|
||||
eq_to_p as eq_to_p,
|
||||
exp_p as exp_p,
|
||||
exp2_p as exp2_p,
|
||||
expm1_p as expm1_p,
|
||||
floor_p as floor_p,
|
||||
ge_p as ge_p,
|
||||
gt_p as gt_p,
|
||||
imag_p as imag_p,
|
||||
integer_pow_p as integer_pow_p,
|
||||
iota_p as iota_p,
|
||||
is_finite_p as is_finite_p,
|
||||
le_p as le_p,
|
||||
le_to_p as le_to_p,
|
||||
log1p_p as log1p_p,
|
||||
log_p as log_p,
|
||||
logistic_p as logistic_p,
|
||||
lt_p as lt_p,
|
||||
lt_to_p as lt_to_p,
|
||||
max_p as max_p,
|
||||
min_p as min_p,
|
||||
mul_p as mul_p,
|
||||
ne_p as ne_p,
|
||||
neg_p as neg_p,
|
||||
nextafter_p as nextafter_p,
|
||||
not_p as not_p,
|
||||
or_p as or_p,
|
||||
pad_p as pad_p,
|
||||
population_count_p as population_count_p,
|
||||
pow_p as pow_p,
|
||||
real_p as real_p,
|
||||
reduce_and_p as reduce_and_p,
|
||||
reduce_max_p as reduce_max_p,
|
||||
reduce_min_p as reduce_min_p,
|
||||
reduce_or_p as reduce_or_p,
|
||||
reduce_p as reduce_p,
|
||||
reduce_precision_p as reduce_precision_p,
|
||||
reduce_prod_p as reduce_prod_p,
|
||||
reduce_sum_p as reduce_sum_p,
|
||||
reduce_xor_p as reduce_xor_p,
|
||||
rem_p as rem_p,
|
||||
reshape_p as reshape_p,
|
||||
rev_p as rev_p,
|
||||
rng_bit_generator_p as rng_bit_generator_p,
|
||||
rng_uniform_p as rng_uniform_p,
|
||||
round_p as round_p,
|
||||
rsqrt_p as rsqrt_p,
|
||||
select_n_p as select_n_p,
|
||||
shift_left_p as shift_left_p,
|
||||
shift_right_arithmetic_p as shift_right_arithmetic_p,
|
||||
shift_right_logical_p as shift_right_logical_p,
|
||||
sign_p as sign_p,
|
||||
sin_p as sin_p,
|
||||
sinh_p as sinh_p,
|
||||
sort_p as sort_p,
|
||||
sqrt_p as sqrt_p,
|
||||
square_p as square_p,
|
||||
squeeze_p as squeeze_p,
|
||||
sub_p as sub_p,
|
||||
tan_p as tan_p,
|
||||
tanh_p as tanh_p,
|
||||
top_k_p as top_k_p,
|
||||
transpose_p as transpose_p,
|
||||
xor_p as xor_p,
|
||||
empty2_p as empty2_p,
|
||||
)
|
||||
|
||||
from jax._src.lax.special import (
|
||||
bessel_i0e_p as bessel_i0e_p,
|
||||
bessel_i1e_p as bessel_i1e_p,
|
||||
digamma_p as digamma_p,
|
||||
erfc_p as erfc_p,
|
||||
erf_inv_p as erf_inv_p,
|
||||
erf_p as erf_p,
|
||||
igammac_p as igammac_p,
|
||||
igamma_grad_a_p as igamma_grad_a_p,
|
||||
igamma_p as igamma_p,
|
||||
lgamma_p as lgamma_p,
|
||||
polygamma_p as polygamma_p,
|
||||
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
|
||||
zeta_p as zeta_p,
|
||||
)
|
||||
|
||||
from jax._src.lax.slicing import (
|
||||
dynamic_slice_p as dynamic_slice_p,
|
||||
dynamic_update_slice_p as dynamic_update_slice_p,
|
||||
gather_p as gather_p,
|
||||
scatter_add_p as scatter_add_p,
|
||||
scatter_max_p as scatter_max_p,
|
||||
scatter_min_p as scatter_min_p,
|
||||
scatter_mul_p as scatter_mul_p,
|
||||
scatter_p as scatter_p,
|
||||
slice_p as slice_p,
|
||||
)
|
||||
|
||||
from jax._src.lax.convolution import (
|
||||
conv_general_dilated_p as conv_general_dilated_p,
|
||||
)
|
||||
|
||||
from jax._src.lax.windowed_reductions import (
|
||||
reduce_window_max_p as reduce_window_max_p,
|
||||
reduce_window_min_p as reduce_window_min_p,
|
||||
reduce_window_p as reduce_window_p,
|
||||
reduce_window_sum_p as reduce_window_sum_p,
|
||||
select_and_gather_add_p as select_and_gather_add_p,
|
||||
select_and_scatter_p as select_and_scatter_p,
|
||||
select_and_scatter_add_p as select_and_scatter_add_p,
|
||||
)
|
||||
|
||||
from jax._src.lax.control_flow import (
|
||||
cond_p as cond_p,
|
||||
cumlogsumexp_p as cumlogsumexp_p,
|
||||
cummax_p as cummax_p,
|
||||
cummin_p as cummin_p,
|
||||
cumprod_p as cumprod_p,
|
||||
cumsum_p as cumsum_p,
|
||||
linear_solve_p as linear_solve_p,
|
||||
scan_p as scan_p,
|
||||
while_p as while_p,
|
||||
)
|
||||
|
||||
from jax._src.lax.fft import (
|
||||
fft_p as fft_p,
|
||||
)
|
||||
|
||||
from jax._src.lax.parallel import (
|
||||
all_gather_p as all_gather_p,
|
||||
all_to_all_p as all_to_all_p,
|
||||
axis_index_p as axis_index_p,
|
||||
pmax_p as pmax_p,
|
||||
pmin_p as pmin_p,
|
||||
ppermute_p as ppermute_p,
|
||||
psum_p as psum_p,
|
||||
ragged_all_to_all_p as ragged_all_to_all_p,
|
||||
)
|
||||
|
||||
from jax._src.lax.ann import (
|
||||
approx_top_k_p as approx_top_k_p
|
||||
)
|
||||
|
||||
from jax._src.lax.linalg import (
|
||||
cholesky_p as cholesky_p,
|
||||
eig_p as eig_p,
|
||||
eigh_p as eigh_p,
|
||||
hessenberg_p as hessenberg_p,
|
||||
lu_p as lu_p,
|
||||
householder_product_p as householder_product_p,
|
||||
qr_p as qr_p,
|
||||
svd_p as svd_p,
|
||||
triangular_solve_p as triangular_solve_p,
|
||||
tridiagonal_p as tridiagonal_p,
|
||||
tridiagonal_solve_p as tridiagonal_solve_p,
|
||||
schur_p as schur_p,
|
||||
)
|
||||
|
||||
from jax._src.pjit import (
|
||||
jit_p as jit_p,
|
||||
sharding_constraint_p as sharding_constraint_p,
|
||||
)
|
||||
|
||||
from jax._src.prng import (
|
||||
random_bits_p as random_bits_p,
|
||||
random_fold_in_p as random_fold_in_p,
|
||||
random_seed_p as random_seed_p,
|
||||
random_split_p as random_split_p,
|
||||
threefry2x32_p as threefry2x32_p,
|
||||
)
|
||||
|
||||
from jax._src.random import random_gamma_p as random_gamma_p
|
||||
|
||||
from jax._src.state.primitives import (
|
||||
get_p as get_p,
|
||||
swap_p as swap_p,
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.lib import _jax
|
||||
|
||||
ifrt_programs = _jax.ifrt_programs
|
||||
|
||||
del _jax
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from jax._src.linear_util import (
|
||||
StoreException as StoreException,
|
||||
WrappedFun as WrappedFun,
|
||||
cache as cache,
|
||||
merge_linear_aux as merge_linear_aux,
|
||||
transformation as transformation,
|
||||
transformation_with_aux as transformation_with_aux,
|
||||
transformation2 as transformation2,
|
||||
transformation_with_aux2 as transformation_with_aux2,
|
||||
# TODO(b/396086979): remove this once we pass debug_info everywhere.
|
||||
wrap_init as _wrap_init,
|
||||
_missing_debug_info as _missing_debug_info,
|
||||
)
|
||||
|
||||
# Version of wrap_init that does not require a DebugInfo object.
|
||||
# This usage is deprecated, use api_util.debug_info() to construct a proper
|
||||
# DebugInfo object.
|
||||
def wrap_init(f: Callable, params=None, *, debug_info=None) -> WrappedFun:
|
||||
debug_info = debug_info or _missing_debug_info("linear_util.wrap_init")
|
||||
return _wrap_init(f, params, debug_info=debug_info)
|
||||
@@ -0,0 +1,22 @@
|
||||
# Copyright 2026 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.interpreters.mlir import (
|
||||
JaxIrContext as JaxIrContext,
|
||||
LoweringRuleContext as LoweringRuleContext,
|
||||
upstream_dialects as upstream_dialects,
|
||||
)
|
||||
@@ -0,0 +1,27 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.lib import (
|
||||
_jax as _jax
|
||||
)
|
||||
from jax._src.interpreters.mlir import (
|
||||
lower_with_sharding_in_types as lower_with_sharding_in_types,
|
||||
)
|
||||
|
||||
deserialize_portable_artifact = _jax.mlir.deserialize_portable_artifact
|
||||
serialize_portable_artifact = _jax.mlir.serialize_portable_artifact
|
||||
refine_polymorphic_shapes = _jax.mlir.refine_polymorphic_shapes
|
||||
hlo_to_stablehlo = _jax.mlir.hlo_to_stablehlo
|
||||
|
||||
del _jax
|
||||
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.arith import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.builtin import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.chlo import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.func import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.math import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.memref import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.mpmd import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.scf import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.sdy import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.sparse_tensor import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.stablehlo import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.dialects.vector import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.ir import *
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
from jaxlib.mlir.passmanager import *
|
||||
@@ -0,0 +1,21 @@
|
||||
# Copyright 2026 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.pallas.core import (
|
||||
GridMapping as GridMapping,
|
||||
register_lowering_rule as register_lowering_rule,
|
||||
)
|
||||
@@ -0,0 +1,30 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.extend.random import (
|
||||
define_prng_impl as define_prng_impl,
|
||||
)
|
||||
|
||||
from jax._src.prng import (
|
||||
random_seed as random_seed,
|
||||
seed_with_impl as seed_with_impl,
|
||||
threefry2x32_p as threefry2x32_p,
|
||||
threefry_2x32 as threefry_2x32,
|
||||
threefry_prng_impl as threefry_prng_impl,
|
||||
rbg_prng_impl as rbg_prng_impl,
|
||||
unsafe_rbg_prng_impl as unsafe_rbg_prng_impl,
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO(yashkatariya): Remove this after NamedSharding supports more complicated
|
||||
# shardings like sub-axes, strided shardings, etc.
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.sharding_impls import GSPMDSharding as GSPMDSharding
|
||||
|
||||
|
||||
def get_op_sharding_from_serialized_proto(
|
||||
sharding: bytes) -> xla_client.OpSharding:
|
||||
proto = xla_client.OpSharding()
|
||||
proto.ParseFromString(sharding)
|
||||
return proto
|
||||
|
||||
|
||||
def get_hlo_sharding_from_serialized_proto(
|
||||
sharding: bytes) -> xla_client.HloSharding:
|
||||
return xla_client.HloSharding.from_proto(
|
||||
get_op_sharding_from_serialized_proto(sharding))
|
||||
|
||||
|
||||
def get_serialized_proto_from_hlo_sharding(
|
||||
sharding: xla_client.HloSharding) -> bytes:
|
||||
return sharding.to_proto().SerializeToString()
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright 2023 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.source_info_util import (
|
||||
NameStack as NameStack,
|
||||
SourceInfo as SourceInfo,
|
||||
current as current,
|
||||
current_name_stack as current_name_stack,
|
||||
extend_name_stack as extend_name_stack,
|
||||
new_name_stack as new_name_stack,
|
||||
new_source_info as new_source_info,
|
||||
register_exclusion as register_exclusion,
|
||||
reset_name_stack as reset_name_stack,
|
||||
set_name_stack as set_name_stack,
|
||||
summarize as summarize,
|
||||
transform_name_stack as transform_name_stack,
|
||||
user_context as user_context,
|
||||
)
|
||||
Reference in New Issue
Block a user