README.md
setup.cfg
setup.py
jax/__init__.py
jax/abstract_arrays.py
jax/ad_util.py
jax/api.py
jax/api_util.py
jax/config.py
jax/core.py
jax/custom_derivatives.py
jax/dlpack.py
jax/dtypes.py
jax/errors.py
jax/flatten_util.py
jax/jaxpr_util.py
jax/lax_reference.py
jax/lazy.py
jax/linear_util.py
jax/profiler.py
jax/py.typed
jax/random.py
jax/test_util.py
jax/tree_util.py
jax/util.py
jax/version.py
jax.egg-info/PKG-INFO
jax.egg-info/SOURCES.txt
jax.egg-info/dependency_links.txt
jax.egg-info/requires.txt
jax.egg-info/top_level.txt
jax/_src/__init__.py
jax/_src/dlpack.py
jax/_src/errors.py
jax/_src/pprint_util.py
jax/_src/profiler.py
jax/_src/random.py
jax/_src/source_info_util.py
jax/_src/traceback_util.py
jax/_src/util.py
jax/_src/image/__init__.py
jax/_src/image/scale.py
jax/_src/lax/__init__.py
jax/_src/lax/control_flow.py
jax/_src/lax/fft.py
jax/_src/lax/lax.py
jax/_src/lax/linalg.py
jax/_src/lax/other.py
jax/_src/lax/parallel.py
jax/_src/nn/__init__.py
jax/_src/nn/functions.py
jax/_src/nn/initializers.py
jax/_src/numpy/__init__.py
jax/_src/numpy/fft.py
jax/_src/numpy/lax_numpy.py
jax/_src/numpy/linalg.py
jax/_src/numpy/polynomial.py
jax/_src/numpy/util.py
jax/_src/numpy/vectorize.py
jax/_src/ops/__init__.py
jax/_src/ops/scatter.py
jax/_src/scipy/__init__.py
jax/_src/scipy/linalg.py
jax/_src/scipy/ndimage.py
jax/_src/scipy/signal.py
jax/_src/scipy/special.py
jax/_src/scipy/optimize/__init__.py
jax/_src/scipy/optimize/bfgs.py
jax/_src/scipy/optimize/line_search.py
jax/_src/scipy/optimize/minimize.py
jax/_src/scipy/sparse/__init__.py
jax/_src/scipy/sparse/linalg.py
jax/_src/scipy/stats/__init__.py
jax/_src/scipy/stats/bernoulli.py
jax/_src/scipy/stats/beta.py
jax/_src/scipy/stats/betabinom.py
jax/_src/scipy/stats/cauchy.py
jax/_src/scipy/stats/chi2.py
jax/_src/scipy/stats/dirichlet.py
jax/_src/scipy/stats/expon.py
jax/_src/scipy/stats/gamma.py
jax/_src/scipy/stats/geom.py
jax/_src/scipy/stats/laplace.py
jax/_src/scipy/stats/logistic.py
jax/_src/scipy/stats/multivariate_normal.py
jax/_src/scipy/stats/norm.py
jax/_src/scipy/stats/pareto.py
jax/_src/scipy/stats/poisson.py
jax/_src/scipy/stats/t.py
jax/_src/scipy/stats/uniform.py
jax/_src/third_party/__init__.py
jax/_src/third_party/numpy/__init__.py
jax/_src/third_party/numpy/linalg.py
jax/experimental/__init__.py
jax/experimental/callback.py
jax/experimental/djax.py
jax/experimental/doubledouble.py
jax/experimental/host_callback.py
jax/experimental/jet.py
jax/experimental/loops.py
jax/experimental/maps.py
jax/experimental/ode.py
jax/experimental/optimizers.py
jax/experimental/pjit.py
jax/experimental/stax.py
jax/experimental/x64_context.py
jax/experimental/jax2tf/__init__.py
jax/experimental/jax2tf/call_tf.py
jax/experimental/jax2tf/jax2tf.py
jax/experimental/jax2tf/tests/__init__.py
jax/experimental/jax2tf/tests/call_tf_test.py
jax/experimental/jax2tf/tests/control_flow_ops_test.py
jax/experimental/jax2tf/tests/jax2tf_limitations.py
jax/experimental/jax2tf/tests/jax2tf_test.py
jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py
jax/experimental/jax2tf/tests/primitive_harness.py
jax/experimental/jax2tf/tests/primitives_test.py
jax/experimental/jax2tf/tests/savedmodel_test.py
jax/experimental/jax2tf/tests/shape_poly_test.py
jax/experimental/jax2tf/tests/sharding_test.py
jax/experimental/jax2tf/tests/stax_test.py
jax/experimental/jax2tf/tests/tf_test_util.py
jax/image/__init__.py
jax/interpreters/__init__.py
jax/interpreters/ad.py
jax/interpreters/batching.py
jax/interpreters/invertible_ad.py
jax/interpreters/masking.py
jax/interpreters/partial_eval.py
jax/interpreters/pxla.py
jax/interpreters/sharded_jit.py
jax/interpreters/xla.py
jax/lax/__init__.py
jax/lax/linalg.py
jax/lib/__init__.py
jax/lib/xla_bridge.py
jax/nn/__init__.py
jax/nn/initializers.py
jax/numpy/__init__.py
jax/numpy/fft.py
jax/numpy/linalg.py
jax/ops/__init__.py
jax/scipy/__init__.py
jax/scipy/linalg.py
jax/scipy/ndimage.py
jax/scipy/signal.py
jax/scipy/special.py
jax/scipy/optimize/__init__.py
jax/scipy/sparse/__init__.py
jax/scipy/sparse/linalg.py
jax/scipy/stats/__init__.py
jax/scipy/stats/bernoulli.py
jax/scipy/stats/beta.py
jax/scipy/stats/betabinom.py
jax/scipy/stats/cauchy.py
jax/scipy/stats/chi2.py
jax/scipy/stats/dirichlet.py
jax/scipy/stats/expon.py
jax/scipy/stats/gamma.py
jax/scipy/stats/geom.py
jax/scipy/stats/laplace.py
jax/scipy/stats/logistic.py
jax/scipy/stats/multivariate_normal.py
jax/scipy/stats/norm.py
jax/scipy/stats/pareto.py
jax/scipy/stats/poisson.py
jax/scipy/stats/t.py
jax/scipy/stats/uniform.py
jax/tools/__init__.py
jax/tools/colab_tpu.py
jax/tools/jax_to_hlo.py