AUTHORS
LICENSE
README.md
pyproject.toml
setup.py
jax/__init__.py
jax/ad_checkpoint.py
jax/api_util.py
jax/cloud_tpu_init.py
jax/collect_profile.py
jax/core.py
jax/custom_batching.py
jax/custom_derivatives.py
jax/custom_transpose.py
jax/debug.py
jax/distributed.py
jax/dlpack.py
jax/dtypes.py
jax/errors.py
jax/export.py
jax/flatten_util.py
jax/monitoring.py
jax/profiler.py
jax/py.typed
jax/random.py
jax/sharding.py
jax/stages.py
jax/test_util.py
jax/tree.py
jax/tree_util.py
jax/typing.py
jax/util.py
jax/version.py
jax.egg-info/PKG-INFO
jax.egg-info/SOURCES.txt
jax.egg-info/dependency_links.txt
jax.egg-info/not-zip-safe
jax.egg-info/requires.txt
jax.egg-info/top_level.txt
jax/_src/__init__.py
jax/_src/abstract_arrays.py
jax/_src/ad_checkpoint.py
jax/_src/ad_util.py
jax/_src/api.py
jax/_src/api_util.py
jax/_src/array.py
jax/_src/basearray.py
jax/_src/basearray.pyi
jax/_src/blocked_sampler.py
jax/_src/cache_key.py
jax/_src/callback.py
jax/_src/checkify.py
jax/_src/cloud_tpu_init.py
jax/_src/compilation_cache.py
jax/_src/compilation_cache_interface.py
jax/_src/compiler.py
jax/_src/compute_on.py
jax/_src/config.py
jax/_src/core.py
jax/_src/custom_api_util.py
jax/_src/custom_batching.py
jax/_src/custom_derivatives.py
jax/_src/custom_partitioning.py
jax/_src/custom_transpose.py
jax/_src/debugging.py
jax/_src/deprecations.py
jax/_src/dispatch.py
jax/_src/distributed.py
jax/_src/dlpack.py
jax/_src/dtypes.py
jax/_src/earray.py
jax/_src/effects.py
jax/_src/environment_info.py
jax/_src/errors.py
jax/_src/flatten_util.py
jax/_src/hardware_utils.py
jax/_src/jaxpr_util.py
jax/_src/lax_reference.py
jax/_src/layout.py
jax/_src/lazy_loader.py
jax/_src/linear_util.py
jax/_src/logging_config.py
jax/_src/lru_cache.py
jax/_src/mesh.py
jax/_src/mesh_utils.py
jax/_src/monitoring.py
jax/_src/op_shardings.py
jax/_src/partition_spec.py
jax/_src/path.py
jax/_src/pickle_util.py
jax/_src/pjit.py
jax/_src/pretty_printer.py
jax/_src/prng.py
jax/_src/profiler.py
jax/_src/public_test_util.py
jax/_src/random.py
jax/_src/shard_alike.py
jax/_src/sharding.py
jax/_src/sharding_impls.py
jax/_src/sharding_specs.py
jax/_src/source_info_util.py
jax/_src/sourcemap.py
jax/_src/stages.py
jax/_src/test_util.py
jax/_src/tpu_custom_call.py
jax/_src/traceback_util.py
jax/_src/tree.py
jax/_src/tree_util.py
jax/_src/typing.py
jax/_src/util.py
jax/_src/xla_bridge.py
jax/_src/xla_metadata.py
jax/_src/clusters/__init__.py
jax/_src/clusters/cloud_tpu_cluster.py
jax/_src/clusters/cluster.py
jax/_src/clusters/k8s_cluster.py
jax/_src/clusters/mpi4py_cluster.py
jax/_src/clusters/ompi_cluster.py
jax/_src/clusters/slurm_cluster.py
jax/_src/cudnn/__init__.py
jax/_src/cudnn/fused_attention_stablehlo.py
jax/_src/cudnn/fusion.py
jax/_src/debugger/__init__.py
jax/_src/debugger/cli_debugger.py
jax/_src/debugger/colab_debugger.py
jax/_src/debugger/colab_lib.py
jax/_src/debugger/core.py
jax/_src/debugger/web_debugger.py
jax/_src/export/__init__.py
jax/_src/export/_export.py
jax/_src/export/serialization.py
jax/_src/export/serialization_generated.py
jax/_src/export/shape_poly.py
jax/_src/export/shape_poly_decision.py
jax/_src/extend/__init__.py
jax/_src/extend/ffi.py
jax/_src/extend/random.py
jax/_src/image/__init__.py
jax/_src/image/scale.py
jax/_src/internal_test_util/__init__.py
jax/_src/internal_test_util/deprecation_module.py
jax/_src/internal_test_util/export_back_compat_test_util.py
jax/_src/internal_test_util/lax_test_util.py
jax/_src/internal_test_util/test_harnesses.py
jax/_src/internal_test_util/lazy_loader_module/__init__.py
jax/_src/internal_test_util/lazy_loader_module/lazy_test_submodule.py
jax/_src/interpreters/__init__.py
jax/_src/interpreters/ad.py
jax/_src/interpreters/batching.py
jax/_src/interpreters/mlir.py
jax/_src/interpreters/partial_eval.py
jax/_src/interpreters/pxla.py
jax/_src/interpreters/xla.py
jax/_src/lax/__init__.py
jax/_src/lax/ann.py
jax/_src/lax/convolution.py
jax/_src/lax/eigh.py
jax/_src/lax/fft.py
jax/_src/lax/lax.py
jax/_src/lax/linalg.py
jax/_src/lax/other.py
jax/_src/lax/parallel.py
jax/_src/lax/qdwh.py
jax/_src/lax/slicing.py
jax/_src/lax/special.py
jax/_src/lax/stack.py
jax/_src/lax/svd.py
jax/_src/lax/utils.py
jax/_src/lax/windowed_reductions.py
jax/_src/lax/control_flow/__init__.py
jax/_src/lax/control_flow/common.py
jax/_src/lax/control_flow/conditionals.py
jax/_src/lax/control_flow/for_loop.py
jax/_src/lax/control_flow/loops.py
jax/_src/lax/control_flow/solves.py
jax/_src/lib/__init__.py
jax/_src/lib/mosaic_gpu.py
jax/_src/lib/triton.py
jax/_src/lib/mlir/__init__.py
jax/_src/lib/mlir/dialects/__init__.py
jax/_src/nn/__init__.py
jax/_src/nn/functions.py
jax/_src/nn/initializers.py
jax/_src/numpy/__init__.py
jax/_src/numpy/array_api_metadata.py
jax/_src/numpy/array_methods.py
jax/_src/numpy/fft.py
jax/_src/numpy/index_tricks.py
jax/_src/numpy/lax_numpy.py
jax/_src/numpy/linalg.py
jax/_src/numpy/polynomial.py
jax/_src/numpy/reductions.py
jax/_src/numpy/setops.py
jax/_src/numpy/ufunc_api.py
jax/_src/numpy/ufuncs.py
jax/_src/numpy/util.py
jax/_src/numpy/vectorize.py
jax/_src/ops/__init__.py
jax/_src/ops/scatter.py
jax/_src/ops/special.py
jax/_src/pallas/__init__.py
jax/_src/pallas/core.py
jax/_src/pallas/pallas_call.py
jax/_src/pallas/primitives.py
jax/_src/pallas/utils.py
jax/_src/pallas/mosaic/__init__.py
jax/_src/pallas/mosaic/core.py
jax/_src/pallas/mosaic/error_handling.py
jax/_src/pallas/mosaic/lowering.py
jax/_src/pallas/mosaic/pallas_call_registration.py
jax/_src/pallas/mosaic/pipeline.py
jax/_src/pallas/mosaic/primitives.py
jax/_src/pallas/mosaic/random.py
jax/_src/pallas/mosaic/verification.py
jax/_src/pallas/mosaic_gpu/__init__.py
jax/_src/pallas/mosaic_gpu/core.py
jax/_src/pallas/mosaic_gpu/lowering.py
jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
jax/_src/pallas/mosaic_gpu/primitives.py
jax/_src/pallas/triton/__init__.py
jax/_src/pallas/triton/core.py
jax/_src/pallas/triton/lowering.py
jax/_src/pallas/triton/pallas_call_registration.py
jax/_src/pallas/triton/primitives.py
jax/_src/scipy/__init__.py
jax/_src/scipy/fft.py
jax/_src/scipy/integrate.py
jax/_src/scipy/linalg.py
jax/_src/scipy/ndimage.py
jax/_src/scipy/signal.py
jax/_src/scipy/special.py
jax/_src/scipy/cluster/__init__.py
jax/_src/scipy/cluster/vq.py
jax/_src/scipy/interpolate/__init__.py
jax/_src/scipy/optimize/__init__.py
jax/_src/scipy/optimize/_lbfgs.py
jax/_src/scipy/optimize/bfgs.py
jax/_src/scipy/optimize/line_search.py
jax/_src/scipy/optimize/minimize.py
jax/_src/scipy/sparse/__init__.py
jax/_src/scipy/sparse/linalg.py
jax/_src/scipy/spatial/__init__.py
jax/_src/scipy/spatial/transform.py
jax/_src/scipy/stats/__init__.py
jax/_src/scipy/stats/_core.py
jax/_src/scipy/stats/bernoulli.py
jax/_src/scipy/stats/beta.py
jax/_src/scipy/stats/betabinom.py
jax/_src/scipy/stats/binom.py
jax/_src/scipy/stats/cauchy.py
jax/_src/scipy/stats/chi2.py
jax/_src/scipy/stats/dirichlet.py
jax/_src/scipy/stats/expon.py
jax/_src/scipy/stats/gamma.py
jax/_src/scipy/stats/gennorm.py
jax/_src/scipy/stats/geom.py
jax/_src/scipy/stats/kde.py
jax/_src/scipy/stats/laplace.py
jax/_src/scipy/stats/logistic.py
jax/_src/scipy/stats/multinomial.py
jax/_src/scipy/stats/multivariate_normal.py
jax/_src/scipy/stats/nbinom.py
jax/_src/scipy/stats/norm.py
jax/_src/scipy/stats/pareto.py
jax/_src/scipy/stats/poisson.py
jax/_src/scipy/stats/t.py
jax/_src/scipy/stats/truncnorm.py
jax/_src/scipy/stats/uniform.py
jax/_src/scipy/stats/vonmises.py
jax/_src/scipy/stats/wrapcauchy.py
jax/_src/state/__init__.py
jax/_src/state/discharge.py
jax/_src/state/indexing.py
jax/_src/state/primitives.py
jax/_src/state/types.py
jax/_src/state/utils.py
jax/_src/third_party/__init__.py
jax/_src/third_party/scipy/__init__.py
jax/_src/third_party/scipy/betaln.py
jax/_src/third_party/scipy/interpolate.py
jax/_src/third_party/scipy/linalg.py
jax/_src/third_party/scipy/signal_helper.py
jax/_src/third_party/scipy/special.py
jax/example_libraries/__init__.py
jax/example_libraries/optimizers.py
jax/example_libraries/stax.py
jax/experimental/__init__.py
jax/experimental/attrs.py
jax/experimental/checkify.py
jax/experimental/compute_on.py
jax/experimental/custom_partitioning.py
jax/experimental/host_callback.py
jax/experimental/jet.py
jax/experimental/layout.py
jax/experimental/mesh_utils.py
jax/experimental/multihost_utils.py
jax/experimental/ode.py
jax/experimental/pjit.py
jax/experimental/profiler.py
jax/experimental/rnn.py
jax/experimental/serialize_executable.py
jax/experimental/shard_alike.py
jax/experimental/shard_map.py
jax/experimental/topologies.py
jax/experimental/x64_context.py
jax/experimental/xla_metadata.py
jax/experimental/array_api/__init__.py
jax/experimental/array_serialization/__init__.py
jax/experimental/array_serialization/serialization.py
jax/experimental/array_serialization/serialization_test.py
jax/experimental/compilation_cache/__init__.py
jax/experimental/compilation_cache/compilation_cache.py
jax/experimental/export/__init__.py
jax/experimental/jax2tf/__init__.py
jax/experimental/jax2tf/call_tf.py
jax/experimental/jax2tf/impl_no_xla.py
jax/experimental/jax2tf/jax2tf.py
jax/experimental/jax2tf/examples/__init__.py
jax/experimental/jax2tf/examples/keras_reuse_main.py
jax/experimental/jax2tf/examples/keras_reuse_main_test.py
jax/experimental/jax2tf/examples/mnist_lib.py
jax/experimental/jax2tf/examples/saved_model_lib.py
jax/experimental/jax2tf/examples/saved_model_main.py
jax/experimental/jax2tf/examples/saved_model_main_test.py
jax/experimental/jax2tf/examples/serving/__init__.py
jax/experimental/jax2tf/examples/serving/model_server_request.py
jax/experimental/jax2tf/tests/__init__.py
jax/experimental/jax2tf/tests/back_compat_tf_test.py
jax/experimental/jax2tf/tests/call_tf_test.py
jax/experimental/jax2tf/tests/control_flow_ops_test.py
jax/experimental/jax2tf/tests/converters.py
jax/experimental/jax2tf/tests/cross_compilation_check.py
jax/experimental/jax2tf/tests/jax2tf_limitations.py
jax/experimental/jax2tf/tests/jax2tf_test.py
jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py
jax/experimental/jax2tf/tests/model_harness.py
jax/experimental/jax2tf/tests/models_test_main.py
jax/experimental/jax2tf/tests/primitives_test.py
jax/experimental/jax2tf/tests/savedmodel_test.py
jax/experimental/jax2tf/tests/shape_poly_test.py
jax/experimental/jax2tf/tests/sharding_test.py
jax/experimental/jax2tf/tests/tf_test_util.py
jax/experimental/key_reuse/__init__.py
jax/experimental/key_reuse/_core.py
jax/experimental/mosaic/__init__.py
jax/experimental/mosaic/dialects.py
jax/experimental/mosaic/gpu/__init__.py
jax/experimental/mosaic/gpu/core.py
jax/experimental/mosaic/gpu/fragmented_array.py
jax/experimental/mosaic/gpu/profiler.py
jax/experimental/mosaic/gpu/utils.py
jax/experimental/mosaic/gpu/wgmma.py
jax/experimental/pallas/__init__.py
jax/experimental/pallas/gpu.py
jax/experimental/pallas/tpu.py
jax/experimental/pallas/ops/__init__.py
jax/experimental/pallas/ops/gpu/__init__.py
jax/experimental/pallas/ops/gpu/attention.py
jax/experimental/pallas/ops/gpu/decode_attention.py
jax/experimental/pallas/ops/gpu/layer_norm.py
jax/experimental/pallas/ops/gpu/rms_norm.py
jax/experimental/pallas/ops/gpu/softmax.py
jax/experimental/pallas/ops/tpu/__init__.py
jax/experimental/pallas/ops/tpu/all_gather.py
jax/experimental/pallas/ops/tpu/example_kernel.py
jax/experimental/pallas/ops/tpu/flash_attention.py
jax/experimental/pallas/ops/tpu/matmul.py
jax/experimental/pallas/ops/tpu/megablox/__init__.py
jax/experimental/pallas/ops/tpu/megablox/common.py
jax/experimental/pallas/ops/tpu/megablox/gmm.py
jax/experimental/pallas/ops/tpu/megablox/ops.py
jax/experimental/pallas/ops/tpu/paged_attention/__init__.py
jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py
jax/experimental/pallas/ops/tpu/paged_attention/quantization_utils.py
jax/experimental/pallas/ops/tpu/splash_attention/__init__.py
jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask.py
jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_mask_info.py
jax/experimental/sparse/__init__.py
jax/experimental/sparse/_base.py
jax/experimental/sparse/_lowerings.py
jax/experimental/sparse/ad.py
jax/experimental/sparse/api.py
jax/experimental/sparse/bcoo.py
jax/experimental/sparse/bcsr.py
jax/experimental/sparse/coo.py
jax/experimental/sparse/csr.py
jax/experimental/sparse/linalg.py
jax/experimental/sparse/nm.py
jax/experimental/sparse/random.py
jax/experimental/sparse/test_util.py
jax/experimental/sparse/transform.py
jax/experimental/sparse/util.py
jax/extend/__init__.py
jax/extend/backend.py
jax/extend/ffi.py
jax/extend/ifrt_programs.py
jax/extend/linear_util.py
jax/extend/random.py
jax/extend/source_info_util.py
jax/extend/core/__init__.py
jax/extend/core/primitives.py
jax/extend/mlir/__init__.py
jax/extend/mlir/ir.py
jax/extend/mlir/passmanager.py
jax/extend/mlir/dialects/__init__.py
jax/extend/mlir/dialects/arith.py
jax/extend/mlir/dialects/builtin.py
jax/extend/mlir/dialects/chlo.py
jax/extend/mlir/dialects/func.py
jax/extend/mlir/dialects/math.py
jax/extend/mlir/dialects/memref.py
jax/extend/mlir/dialects/scf.py
jax/extend/mlir/dialects/sdy.py
jax/extend/mlir/dialects/sparse_tensor.py
jax/extend/mlir/dialects/stablehlo.py
jax/extend/mlir/dialects/vector.py
jax/image/__init__.py
jax/interpreters/__init__.py
jax/interpreters/ad.py
jax/interpreters/batching.py
jax/interpreters/mlir.py
jax/interpreters/partial_eval.py
jax/interpreters/pxla.py
jax/interpreters/xla.py
jax/lax/__init__.py
jax/lax/linalg.py
jax/lib/__init__.py
jax/lib/xla_bridge.py
jax/lib/xla_client.py
jax/lib/xla_extension.py
jax/nn/__init__.py
jax/nn/initializers.py
jax/numpy/__init__.py
jax/numpy/__init__.pyi
jax/numpy/fft.py
jax/numpy/linalg.py
jax/ops/__init__.py
jax/scipy/__init__.py
jax/scipy/fft.py
jax/scipy/integrate.py
jax/scipy/linalg.py
jax/scipy/ndimage.py
jax/scipy/signal.py
jax/scipy/special.py
jax/scipy/cluster/__init__.py
jax/scipy/cluster/vq.py
jax/scipy/interpolate/__init__.py
jax/scipy/optimize/__init__.py
jax/scipy/sparse/__init__.py
jax/scipy/sparse/linalg.py
jax/scipy/spatial/__init__.py
jax/scipy/spatial/transform.py
jax/scipy/stats/__init__.py
jax/scipy/stats/bernoulli.py
jax/scipy/stats/beta.py
jax/scipy/stats/betabinom.py
jax/scipy/stats/binom.py
jax/scipy/stats/cauchy.py
jax/scipy/stats/chi2.py
jax/scipy/stats/dirichlet.py
jax/scipy/stats/expon.py
jax/scipy/stats/gamma.py
jax/scipy/stats/gennorm.py
jax/scipy/stats/geom.py
jax/scipy/stats/laplace.py
jax/scipy/stats/logistic.py
jax/scipy/stats/multinomial.py
jax/scipy/stats/multivariate_normal.py
jax/scipy/stats/nbinom.py
jax/scipy/stats/norm.py
jax/scipy/stats/pareto.py
jax/scipy/stats/poisson.py
jax/scipy/stats/t.py
jax/scipy/stats/truncnorm.py
jax/scipy/stats/uniform.py
jax/scipy/stats/vonmises.py
jax/scipy/stats/wrapcauchy.py
jax/tools/__init__.py
jax/tools/build_utils.py
jax/tools/colab_tpu.py
jax/tools/jax_to_ir.py
jax/tools/pgo_nsys_converter.py