# Copyright 2025 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from jax._src import dtypes from jax._src.core import ShapedArray from jax._src.lib import _jax import numpy as np # TypedInt, TypedFloat, and TypedComplex are subclasses of int, float, and # complex that carry a JAX dtype. Canonicalization forms these types from int, # float, and complex. Repeated canonicalization, including under different # jax_enable_x64 modes, preserves the dtype. # Precomputed weak scalar avals _weak_int32_aval = ShapedArray((), np.dtype(np.int32), weak_type=True) _weak_int64_aval = ShapedArray((), np.dtype(np.int64), weak_type=True) _weak_float32_aval = ShapedArray((), np.dtype(np.float32), weak_type=True) _weak_float64_aval = ShapedArray((), np.dtype(np.float64), weak_type=True) _weak_complex64_aval = ShapedArray((), np.dtype(np.complex64), weak_type=True) _weak_complex128_aval = ShapedArray((), np.dtype(np.complex128), weak_type=True) class TypedInt(int): dtype: np.dtype aval: ShapedArray def __new__(cls, value: int, dtype: np.dtype): v = super().__new__(cls, value) v.dtype = dtype if dtype == np.dtype(np.int32): v.aval = _weak_int32_aval elif dtype == np.dtype(np.int64): v.aval = _weak_int64_aval else: v.aval = ShapedArray((), dtype, weak_type=True) return v def __repr__(self): return f'TypedInt({int(self)}, dtype={self.dtype.name})' def __getnewargs__(self): return (int(self), self.dtype) class TypedFloat(float): dtype: np.dtype aval: ShapedArray def __new__(cls, value: float, dtype: np.dtype): v = super().__new__(cls, value) v.dtype = dtype if dtype == np.dtype(np.float32): v.aval = _weak_float32_aval elif dtype == np.dtype(np.float64): v.aval = _weak_float64_aval else: v.aval = ShapedArray((), dtype, weak_type=True) return v def __repr__(self): return f'TypedFloat({float(self)}, dtype={self.dtype.name})' def __str__(self): return str(float(self)) def __getnewargs__(self): return (float(self), self.dtype) class TypedComplex(complex): dtype: np.dtype aval: ShapedArray def __new__(cls, value: complex, dtype: np.dtype): v = super().__new__(cls, value) v.dtype = dtype if dtype == np.dtype(np.complex64): v.aval = _weak_complex64_aval elif dtype == np.dtype(np.complex128): v.aval = _weak_complex128_aval else: v.aval = ShapedArray((), dtype, weak_type=True) return v def __repr__(self): return f'TypedComplex({complex(self)}, dtype={self.dtype.name})' def __getnewargs__(self): return (complex(self), self.dtype) typed_scalar_types: set[type] = {TypedInt, TypedFloat, TypedComplex} class TypedNdArray(np.ndarray): """A TypedNdArray is a host-side array used by JAX during tracing. TypedNdArray is a subclass of np.ndarray that carries additional JAX type information: * its type is not canonicalized by JAX, irrespective of the jax_enable_x64 mode * it can be weakly typed. """ __slots__ = ('_aval', '_weak_type') def __new__(cls, val: np.ndarray, aval: ShapedArray | None = None): obj = np.asarray(val).view(cls) if aval is not None: obj._aval = aval return obj def __array_finalize__(self, obj): self._aval = None self._weak_type = (obj.aval.weak_type if isinstance(obj, TypedNdArray) else False) @property def aval(self) -> ShapedArray: result = self._aval if result is None: # It is possible that multiple threads might race to reach here. However # this seems safe since they will all set the same value. result = ShapedArray(self.shape, self.dtype, weak_type=self._weak_type) self._aval = result return result @property def weak_type(self) -> bool: return self.aval.weak_type @property def val(self) -> np.ndarray: return np.asarray(self) def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): inputs = tuple( np.asarray(x) if isinstance(x, TypedNdArray) else x for x in inputs ) if 'out' in kwargs: kwargs['out'] = tuple( np.asarray(x) if isinstance(x, TypedNdArray) else x for x in kwargs['out'] ) return getattr(ufunc, method)(*inputs, **kwargs) def __repr__(self): prefix = 'TypedNdArray(' if self.aval.weak_type: dtype_str = f'dtype={self.dtype.name}, weak_type=True)' else: dtype_str = f'dtype={self.dtype.name})' line_width = np.get_printoptions()['linewidth'] if self.size == 0: s = f'[], shape={self.shape}' else: s = np.array2string( np.asarray(self), prefix=prefix, suffix=',', separator=', ', max_line_width=line_width, ) last_line_len = len(s) - s.rfind('\n') + 1 sep = ' ' if last_line_len + len(dtype_str) + 1 > line_width: sep = ' ' * len(prefix) return f'{prefix}{s},{sep}{dtype_str}' def __reduce__(self): return (TypedNdArray, (np.asarray(self), self.aval.weak_type)) def __getnewargs__(self): return (np.asarray(self), self.aval.weak_type) _jax.set_typed_ndarray_type(TypedNdArray) dtypes.register_type_whose_dtype_should_not_be_canonicalized(TypedNdArray) _jax.set_typed_int_type(TypedInt) _jax.set_typed_float_type(TypedFloat) _jax.set_typed_complex_type(TypedComplex) for _typ in typed_scalar_types: dtypes.register_weak_scalar_type(_typ) dtypes.register_type_whose_dtype_should_not_be_canonicalized(_typ)