hand
This commit is contained in:
@@ -0,0 +1,419 @@
|
||||
# Copyright 2019 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Note: import <name> as <name> is required for names to be exported.
|
||||
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
|
||||
|
||||
from jax._src.lax.lax import (
|
||||
DotDimensionNumbers as DotDimensionNumbers,
|
||||
RaggedDotDimensionNumbers as RaggedDotDimensionNumbers,
|
||||
AccuracyMode as AccuracyMode,
|
||||
Tolerance as Tolerance,
|
||||
Precision as Precision,
|
||||
PrecisionLike as PrecisionLike,
|
||||
DotAlgorithm as DotAlgorithm,
|
||||
DotAlgorithmPreset as DotAlgorithmPreset,
|
||||
RandomAlgorithm as RandomAlgorithm,
|
||||
RoundingMethod as RoundingMethod,
|
||||
abs as abs,
|
||||
abs_p as abs_p,
|
||||
acos as acos,
|
||||
acos_p as acos_p,
|
||||
acosh as acosh,
|
||||
acosh_p as acosh_p,
|
||||
add as add,
|
||||
add_p as add_p,
|
||||
after_all as after_all,
|
||||
after_all_p as after_all_p,
|
||||
and_p as and_p,
|
||||
argmax as argmax,
|
||||
argmax_p as argmax_p,
|
||||
argmin as argmin,
|
||||
argmin_p as argmin_p,
|
||||
asin as asin,
|
||||
asin_p as asin_p,
|
||||
asinh as asinh,
|
||||
asinh_p as asinh_p,
|
||||
atan as atan,
|
||||
atan_p as atan_p,
|
||||
atan2 as atan2,
|
||||
atan2_p as atan2_p,
|
||||
atanh as atanh,
|
||||
atanh_p as atanh_p,
|
||||
batch_matmul as batch_matmul,
|
||||
bitcast_convert_type as bitcast_convert_type,
|
||||
bitcast_convert_type_p as bitcast_convert_type_p,
|
||||
bitwise_and as bitwise_and,
|
||||
bitwise_not as bitwise_not,
|
||||
bitwise_or as bitwise_or,
|
||||
bitwise_xor as bitwise_xor,
|
||||
broadcast as broadcast,
|
||||
broadcast_like as broadcast_like,
|
||||
broadcast_in_dim as broadcast_in_dim,
|
||||
broadcast_in_dim_p as broadcast_in_dim_p,
|
||||
broadcast_shapes as broadcast_shapes,
|
||||
broadcast_to_rank as broadcast_to_rank,
|
||||
broadcasted_iota as broadcasted_iota,
|
||||
cbrt as cbrt,
|
||||
cbrt_p as cbrt_p,
|
||||
ceil as ceil,
|
||||
ceil_p as ceil_p,
|
||||
clamp as clamp,
|
||||
clamp_p as clamp_p,
|
||||
clz as clz,
|
||||
clz_p as clz_p,
|
||||
collapse as collapse,
|
||||
complex as complex,
|
||||
complex_p as complex_p,
|
||||
composite as composite,
|
||||
concatenate as concatenate,
|
||||
concatenate_p as concatenate_p,
|
||||
conj as conj,
|
||||
conj_p as conj_p,
|
||||
convert_element_type as convert_element_type,
|
||||
convert_element_type_p as convert_element_type_p,
|
||||
copy_p as copy_p,
|
||||
cos as cos,
|
||||
dce_sink_p as dce_sink_p,
|
||||
dce_sink as dce_sink,
|
||||
cos_p as cos_p,
|
||||
cosh as cosh,
|
||||
cosh_p as cosh_p,
|
||||
create_token as create_token,
|
||||
create_token_p as create_token_p,
|
||||
div as div,
|
||||
div_p as div_p,
|
||||
dot as dot,
|
||||
dot_general as dot_general,
|
||||
dot_general_p as dot_general_p,
|
||||
dtype as dtype,
|
||||
eq as eq,
|
||||
eq_p as eq_p,
|
||||
eq_to_p as eq_to_p,
|
||||
exp as exp,
|
||||
exp_p as exp_p,
|
||||
exp2 as exp2,
|
||||
exp2_p as exp2_p,
|
||||
expand_dims as expand_dims,
|
||||
expm1 as expm1,
|
||||
expm1_p as expm1_p,
|
||||
floor as floor,
|
||||
floor_p as floor_p,
|
||||
full as full,
|
||||
full_like as full_like,
|
||||
ge as ge,
|
||||
ge_p as ge_p,
|
||||
gt as gt,
|
||||
gt_p as gt_p,
|
||||
imag as imag,
|
||||
imag_p as imag_p,
|
||||
integer_pow as integer_pow,
|
||||
integer_pow_p as integer_pow_p,
|
||||
iota as iota,
|
||||
iota_p as iota_p,
|
||||
is_finite as is_finite,
|
||||
is_finite_p as is_finite_p,
|
||||
le as le,
|
||||
le_p as le_p,
|
||||
le_to_p as le_to_p,
|
||||
log as log,
|
||||
log1p as log1p,
|
||||
log1p_p as log1p_p,
|
||||
log_p as log_p,
|
||||
logistic as logistic,
|
||||
logistic_p as logistic_p,
|
||||
lt as lt,
|
||||
lt_p as lt_p,
|
||||
lt_to_p as lt_to_p,
|
||||
max as max,
|
||||
max_p as max_p,
|
||||
min as min,
|
||||
min_p as min_p,
|
||||
mul as mul,
|
||||
mul_p as mul_p,
|
||||
ne as ne,
|
||||
ne_p as ne_p,
|
||||
neg as neg,
|
||||
neg_p as neg_p,
|
||||
nextafter as nextafter,
|
||||
nextafter_p as nextafter_p,
|
||||
not_p as not_p,
|
||||
optimization_barrier as optimization_barrier,
|
||||
optimization_barrier_p as optimization_barrier_p,
|
||||
or_p as or_p,
|
||||
pad as pad,
|
||||
pad_p as pad_p,
|
||||
padtype_to_pads as padtype_to_pads,
|
||||
population_count as population_count,
|
||||
population_count_p as population_count_p,
|
||||
pow as pow,
|
||||
pow_p as pow_p,
|
||||
ragged_dot as ragged_dot,
|
||||
ragged_dot_general as ragged_dot_general,
|
||||
real as real,
|
||||
real_p as real_p,
|
||||
reciprocal as reciprocal,
|
||||
reduce as reduce,
|
||||
reduce_and as reduce_and,
|
||||
reduce_and_p as reduce_and_p,
|
||||
reduce_max as reduce_max,
|
||||
reduce_max_p as reduce_max_p,
|
||||
reduce_min as reduce_min,
|
||||
reduce_min_p as reduce_min_p,
|
||||
reduce_or as reduce_or,
|
||||
reduce_or_p as reduce_or_p,
|
||||
reduce_p as reduce_p,
|
||||
reduce_precision as reduce_precision,
|
||||
reduce_precision_p as reduce_precision_p,
|
||||
reduce_prod as reduce_prod,
|
||||
reduce_prod_p as reduce_prod_p,
|
||||
reduce_sum as reduce_sum,
|
||||
reduce_sum_p as reduce_sum_p,
|
||||
reduce_xor as reduce_xor,
|
||||
reduce_xor_p as reduce_xor_p,
|
||||
rem as rem,
|
||||
rem_p as rem_p,
|
||||
reshape as reshape,
|
||||
reshape_p as reshape_p,
|
||||
rev as rev,
|
||||
rev_p as rev_p,
|
||||
rng_bit_generator as rng_bit_generator,
|
||||
rng_bit_generator_p as rng_bit_generator_p,
|
||||
rng_uniform as rng_uniform,
|
||||
rng_uniform_p as rng_uniform_p,
|
||||
round as round,
|
||||
round_p as round_p,
|
||||
rsqrt as rsqrt,
|
||||
rsqrt_p as rsqrt_p,
|
||||
select as select,
|
||||
select_n as select_n,
|
||||
select_n_p as select_n_p,
|
||||
shape_as_value as shape_as_value,
|
||||
shift_left as shift_left,
|
||||
shift_left_p as shift_left_p,
|
||||
shift_right_arithmetic as shift_right_arithmetic,
|
||||
shift_right_arithmetic_p as shift_right_arithmetic_p,
|
||||
shift_right_logical as shift_right_logical,
|
||||
shift_right_logical_p as shift_right_logical_p,
|
||||
sign as sign,
|
||||
sign_p as sign_p,
|
||||
sin as sin,
|
||||
sin_p as sin_p,
|
||||
sinh as sinh,
|
||||
sinh_p as sinh_p,
|
||||
sort as sort,
|
||||
sort_key_val as sort_key_val,
|
||||
sort_p as sort_p,
|
||||
split as split,
|
||||
split_p as split_p,
|
||||
sqrt as sqrt,
|
||||
sqrt_p as sqrt_p,
|
||||
square as square,
|
||||
square_p as square_p,
|
||||
squeeze as squeeze,
|
||||
squeeze_p as squeeze_p,
|
||||
stop_gradient as stop_gradient,
|
||||
sub as sub,
|
||||
sub_p as sub_p,
|
||||
tan as tan,
|
||||
tan_p as tan_p,
|
||||
tanh as tanh,
|
||||
tanh_p as tanh_p,
|
||||
tile as tile,
|
||||
tile_p as tile_p,
|
||||
top_k as top_k,
|
||||
top_k_p as top_k_p,
|
||||
transpose as transpose,
|
||||
transpose_p as transpose_p,
|
||||
xor_p as xor_p,
|
||||
empty as empty,
|
||||
)
|
||||
from jax._src.lax.special import (
|
||||
bessel_i0e as bessel_i0e,
|
||||
bessel_i0e_p as bessel_i0e_p,
|
||||
bessel_i1e as bessel_i1e,
|
||||
bessel_i1e_p as bessel_i1e_p,
|
||||
betainc as betainc,
|
||||
digamma as digamma,
|
||||
digamma_p as digamma_p,
|
||||
erf as erf,
|
||||
erfc as erfc,
|
||||
erfc_p as erfc_p,
|
||||
erf_inv as erf_inv,
|
||||
erf_inv_p as erf_inv_p,
|
||||
erf_p as erf_p,
|
||||
igamma as igamma,
|
||||
igammac as igammac,
|
||||
igammac_p as igammac_p,
|
||||
igamma_grad_a as igamma_grad_a,
|
||||
igamma_grad_a_p as igamma_grad_a_p,
|
||||
igamma_p as igamma_p,
|
||||
lgamma as lgamma,
|
||||
lgamma_p as lgamma_p,
|
||||
polygamma as polygamma,
|
||||
polygamma_p as polygamma_p,
|
||||
random_gamma_grad as random_gamma_grad,
|
||||
regularized_incomplete_beta_p as regularized_incomplete_beta_p,
|
||||
zeta as zeta,
|
||||
zeta_p as zeta_p,
|
||||
)
|
||||
from jax._src.lax.slicing import (
|
||||
GatherDimensionNumbers as GatherDimensionNumbers,
|
||||
GatherScatterMode as GatherScatterMode,
|
||||
ScatterDimensionNumbers as ScatterDimensionNumbers,
|
||||
dynamic_index_in_dim as dynamic_index_in_dim,
|
||||
dynamic_slice as dynamic_slice,
|
||||
dynamic_slice_in_dim as dynamic_slice_in_dim,
|
||||
dynamic_slice_p as dynamic_slice_p,
|
||||
dynamic_update_index_in_dim as dynamic_update_index_in_dim,
|
||||
dynamic_update_slice as dynamic_update_slice,
|
||||
dynamic_update_slice_in_dim as dynamic_update_slice_in_dim,
|
||||
dynamic_update_slice_p as dynamic_update_slice_p,
|
||||
gather as gather,
|
||||
gather_p as gather_p,
|
||||
index_in_dim as index_in_dim,
|
||||
index_take as index_take,
|
||||
scatter as scatter,
|
||||
scatter_apply as scatter_apply,
|
||||
scatter_add as scatter_add,
|
||||
scatter_add_p as scatter_add_p,
|
||||
scatter_max as scatter_max,
|
||||
scatter_max_p as scatter_max_p,
|
||||
scatter_min as scatter_min,
|
||||
scatter_min_p as scatter_min_p,
|
||||
scatter_mul as scatter_mul,
|
||||
scatter_mul_p as scatter_mul_p,
|
||||
scatter_p as scatter_p,
|
||||
scatter_sub as scatter_sub,
|
||||
scatter_sub_p as scatter_sub_p,
|
||||
slice as slice,
|
||||
slice_in_dim as slice_in_dim,
|
||||
slice_p as slice_p,
|
||||
)
|
||||
from jax._src.lax.convolution import (
|
||||
ConvDimensionNumbers as ConvDimensionNumbers,
|
||||
ConvGeneralDilatedDimensionNumbers as ConvGeneralDilatedDimensionNumbers,
|
||||
conv as conv,
|
||||
conv_dimension_numbers as conv_dimension_numbers,
|
||||
conv_general_dilated as conv_general_dilated,
|
||||
conv_general_dilated_p as conv_general_dilated_p,
|
||||
conv_general_permutations as conv_general_permutations,
|
||||
conv_general_shape_tuple as conv_general_shape_tuple,
|
||||
conv_shape_tuple as conv_shape_tuple,
|
||||
conv_transpose as conv_transpose,
|
||||
conv_transpose_shape_tuple as conv_transpose_shape_tuple,
|
||||
conv_with_general_padding as conv_with_general_padding,
|
||||
)
|
||||
from jax._src.lax.windowed_reductions import (
|
||||
reduce_window as reduce_window,
|
||||
reduce_window_max_p as reduce_window_max_p,
|
||||
reduce_window_min_p as reduce_window_min_p,
|
||||
reduce_window_p as reduce_window_p,
|
||||
reduce_window_shape_tuple as reduce_window_shape_tuple,
|
||||
reduce_window_sum_p as reduce_window_sum_p,
|
||||
select_and_gather_add_p as select_and_gather_add_p,
|
||||
select_and_scatter_p as select_and_scatter_p,
|
||||
select_and_scatter_add_p as select_and_scatter_add_p,
|
||||
)
|
||||
from jax._src.lax.control_flow import (
|
||||
associative_scan as associative_scan,
|
||||
cond as cond,
|
||||
cond_p as cond_p,
|
||||
cumlogsumexp as cumlogsumexp,
|
||||
cumlogsumexp_p as cumlogsumexp_p,
|
||||
cummax as cummax,
|
||||
cummax_p as cummax_p,
|
||||
cummin as cummin,
|
||||
cummin_p as cummin_p,
|
||||
cumprod as cumprod,
|
||||
cumprod_p as cumprod_p,
|
||||
cumsum as cumsum,
|
||||
cumsum_p as cumsum_p,
|
||||
custom_linear_solve as custom_linear_solve,
|
||||
custom_root as custom_root,
|
||||
fori_loop as fori_loop,
|
||||
linear_solve_p as linear_solve_p,
|
||||
map as map,
|
||||
scan as scan,
|
||||
scan_p as scan_p,
|
||||
switch as switch,
|
||||
while_loop as while_loop,
|
||||
while_p as while_p,
|
||||
platform_dependent as platform_dependent,
|
||||
)
|
||||
from jax._src.lax.fft import (
|
||||
fft as fft,
|
||||
fft_p as fft_p,
|
||||
FftType as FftType,
|
||||
)
|
||||
from jax._src.lax.parallel import (
|
||||
all_gather as all_gather,
|
||||
pcast as pcast,
|
||||
all_gather_p as all_gather_p,
|
||||
all_to_all as all_to_all,
|
||||
all_to_all_p as all_to_all_p,
|
||||
axis_index as axis_index,
|
||||
axis_index_p as axis_index_p,
|
||||
axis_size as axis_size,
|
||||
pbroadcast as pbroadcast,
|
||||
pmax as pmax,
|
||||
pmax_p as pmax_p,
|
||||
pmean as pmean,
|
||||
pmin as pmin,
|
||||
pmin_p as pmin_p,
|
||||
ppermute as ppermute,
|
||||
ppermute_p as ppermute_p,
|
||||
psend as psend,
|
||||
precv as precv,
|
||||
pshuffle as pshuffle,
|
||||
psum as psum,
|
||||
psum_p as psum_p,
|
||||
psum_scatter as psum_scatter,
|
||||
pswapaxes as pswapaxes,
|
||||
ragged_all_to_all as ragged_all_to_all,
|
||||
ragged_all_to_all_p as ragged_all_to_all_p,
|
||||
)
|
||||
from jax._src.lax.other import (
|
||||
conv_general_dilated_local as conv_general_dilated_local,
|
||||
conv_general_dilated_patches as conv_general_dilated_patches
|
||||
)
|
||||
from jax._src.lax.ann import (
|
||||
approx_max_k as approx_max_k,
|
||||
approx_min_k as approx_min_k,
|
||||
approx_top_k_p as approx_top_k_p
|
||||
)
|
||||
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
|
||||
from jax.lax import linalg as linalg
|
||||
|
||||
from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
|
||||
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
|
||||
from jax._src.dispatch import device_put_p as device_put_p
|
||||
from jax._src.lax.scaled_dot import scaled_dot as scaled_dot
|
||||
|
||||
_deprecations = {
|
||||
# Deprecated in v0.8.2; finalized in v0.10.0.
|
||||
# TODO(jakevdp) remove entry in v0.11.0.
|
||||
"pvary": (
|
||||
"jax.lax.pvary was deprecated in JAX v0.8.2 and removed in JAX v0.10.0;"
|
||||
" use `jax.lax.pcast(..., to='varying')",
|
||||
None,
|
||||
),
|
||||
}
|
||||
|
||||
import typing as _typing
|
||||
if not _typing.TYPE_CHECKING:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
del _deprecation_getattr
|
||||
del _typing
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,55 @@
|
||||
# Copyright 2020 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.lax.linalg import (
|
||||
cholesky as cholesky,
|
||||
cholesky_p as cholesky_p,
|
||||
cholesky_update as cholesky_update,
|
||||
cholesky_update_p as cholesky_update_p,
|
||||
EigImplementation as EigImplementation,
|
||||
eig as eig,
|
||||
eig_p as eig_p,
|
||||
eigh as eigh,
|
||||
EighImplementation as EighImplementation,
|
||||
eigh_p as eigh_p,
|
||||
hessenberg as hessenberg,
|
||||
hessenberg_p as hessenberg_p,
|
||||
householder_product as householder_product,
|
||||
householder_product_p as householder_product_p,
|
||||
lu as lu,
|
||||
lu_p as lu_p,
|
||||
lu_pivots_to_permutation as lu_pivots_to_permutation,
|
||||
lu_pivots_to_permutation_p as lu_pivots_to_permutation_p,
|
||||
ormqr as ormqr,
|
||||
ormqr_p as ormqr_p,
|
||||
qr as qr,
|
||||
qr_p as qr_p,
|
||||
schur as schur,
|
||||
schur_p as schur_p,
|
||||
svd as svd,
|
||||
svd_p as svd_p,
|
||||
SvdAlgorithm as SvdAlgorithm,
|
||||
symmetric_product as symmetric_product,
|
||||
symmetric_product_p as symmetric_product_p,
|
||||
triangular_solve as triangular_solve,
|
||||
triangular_solve_p as triangular_solve_p,
|
||||
tridiagonal as tridiagonal,
|
||||
tridiagonal_p as tridiagonal_p,
|
||||
tridiagonal_solve as tridiagonal_solve,
|
||||
tridiagonal_solve_p as tridiagonal_solve_p,
|
||||
)
|
||||
|
||||
from jax._src.tpu.linalg.qdwh import (
|
||||
qdwh as qdwh
|
||||
)
|
||||
Reference in New Issue
Block a user