
JAX

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
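The three ingredients named above (a NumPy-like API, automatic differentiation, and XLA compilation) can be seen together in a few lines of plain JAX. This is a minimal sketch assuming a recent `jax` install; the toy loss and data are made up for illustration:

```python
import jax
import jax.numpy as jnp

# NumPy-like API: an ordinary mean-squared-error loss for a toy linear model y ≈ w * x.
def loss(w, x, y):
    return jnp.mean((w * x - y) ** 2)

# Illustrative data generated with w = 2.
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

# Automatic differentiation: jax.grad differentiates loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# XLA compilation: jax.jit traces the gradient function and compiles it for the default backend.
fast_grad = jax.jit(grad_loss)

print(fast_grad(1.0, x, y))  # d(loss)/dw evaluated at w = 1
```

Analytically the derivative is 2(w - 2) · mean(x²), so at w = 1 it equals -28/3 (about -9.33), and it vanishes at the true weight w = 2.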

This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!

Libraries

  • Neural Network Libraries
    • Flax - Centered on flexibility and clarity.
    • Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    • Objax - Has an object-oriented design similar to PyTorch.
    • Elegy - A High-Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
    • Trax - "Batteries included" deep learning library focused on providing solutions for common workloads.
    • Jraph - Lightweight graph neural network library.
    • Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
    • HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    • Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
    • Scenic - A JAX Library for Computer Vision Research and Beyond.
  • Levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.
  • EasyLM - LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
  • NumPyro - Probabilistic programming based on the Pyro library.
  • Chex - Utilities to write and test reliable JAX code.
  • Optax - Gradient processing and optimization library.
  • RLax - Library for implementing reinforcement learning agents.
  • JAX, M.D. - Accelerated, differentiable molecular dynamics.
  • Coax - Turn RL papers into code, the easy way.
  • Distrax - Reimplementation of a subset of TensorFlow Probability, containing probability distributions and bijectors.
  • cvxpylayers - Construct differentiable convex optimization layers.
  • TensorLy - Tensor learning made simple.
  • NetKet - Machine Learning toolbox for Quantum Physics.
  • Fortuna - AWS library for Uncertainty Quantification in Deep Learning.
  • BlackJAX - Library of samplers for JAX.
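Optimization libraries such as Optax package gradient transformations behind a common interface. The underlying idea can be sketched in plain JAX as a hand-written SGD step over a dictionary of parameters. This deliberately does not use Optax's own API; the linear model, data, learning rate, and step count are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

# Parameters of a tiny linear model, stored as a dict (a PyTree).
params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}

# Illustrative data generated by w = 3, b = 1.
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = 3.0 * x + 1.0

def loss(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    # jax.grad returns gradients with the same dict structure as params.
    grads = jax.grad(loss)(params, x, y)
    # Plain gradient descent, p <- p - lr * g, applied leaf by leaf.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

for _ in range(500):
    params = sgd_step(params, x, y)

print(params)  # should approach w = 3, b = 1
```

A library like Optax generalizes the `tree_map` update line (momentum, Adam, gradient clipping, and so on) while keeping the same pattern of gradients that mirror the parameter structure.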

New Libraries

This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large user base yet.

  • Neural Network Libraries
    • FedJAX - Federated learning in JAX, built on Optax and Haiku.
    • Equivariant MLP - Construct equivariant neural network layers.
    • jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
    • Parallax - Immutable Torch Modules for JAX.
  • jax-unirep - Library implementing the UniRep model for protein machine learning applications.
  • jax-flows - Normalizing flows in JAX.
  • sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
  • jax-cosmo - Differentiable cosmology library.
  • efax - Exponential Families in JAX.
  • mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
  • imax - Image augmentations and transformations.
  • FlaxVision - Flax version of TorchVision.
  • Oryx - Probabilistic programming language based on program transformations.
  • Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
  • delta PV - A photovoltaic simulator with automatic differentiation.
  • jaxlie - Lie theory library for rigid body transformations and optimization.
  • BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
  • flaxmodels - Pretrained models for JAX/Flax.
  • CR.Sparse - XLA accelerated algorithms for sparse representations and compressive sensing.
  • exojax - Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
  • JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
  • PIX - PIX is an image processing library in JAX, for JAX.
  • bayex - Bayesian Optimization powered by JAX.
  • JaxDF - Framework for differentiable simulators with arbitrary discretizations.
  • tree-math - Convert functions that operate on arrays into functions that operate on PyTrees.
  • jax-models - Implementations of research papers originally without code or code written with frameworks other than JAX.
  • PGMax - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
  • EvoJAX - Hardware-Accelerated Neuroevolution.
  • evosax - JAX-Based Evolution Strategies.
  • SymJAX - Symbolic CPU/GPU/TPU programming.
  • mcx - Express & compile probabilistic programs for performant inference.
  • Einshape - DSL-based reshaping library for JAX and other frameworks.
  • ALX - Open-source library for distributed matrix factorization using Alternating Least Squares, described in the paper ALX: Large Scale Matrix Factorization on TPUs.
  • Diffrax - Numerical differential equation solvers in JAX.
  • tinygp - The tiniest of Gaussian process libraries in JAX.
  • gymnax - Reinforcement Learning Environments with the well-known gym API.
  • Mctx - Monte Carlo tree search algorithms in native JAX.
  • KFAC-JAX - Second Order Optimization with Approximate Curvature for NNs.
  • TF2JAX - Convert functions/graphs to JAX functions.
  • jwave - A library for differentiable acoustic simulations.
  • GPJax - Gaussian processes in JAX.
  • Jumanji - A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.
  • Eqxvision - Equinox version of Torchvision.
  • JAXFit - Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).
  • econpizza - Solve macroeconomic models with heterogeneous agents using JAX.
  • SPU - A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).
  • jax-tqdm - Add a tqdm progress bar to JAX scans and loops.
  • safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗safetensors.
  • Kernex - Differentiable stencil decorators in JAX.
  • MaxText - A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.
  • Pax - A JAX-based machine learning framework for training large scale models.
  • Praxis - The layer library for Pax with a goal to be usable by other JAX-based ML projects.
  • purejaxrl - Vectorisable, end-to-end RL algorithms in JAX.
  • Lorax - Automatically apply LoRA to JAX models (Flax, Haiku, etc.)
  • SCICO - Scientific computational imaging in JAX.
  • Spyx - Spiking Neural Networks in JAX for machine learning on neuromorphic hardware.
  • BrainPy - Brain Dynamics Programming in Python.
  • OTT-JAX - Optimal transport tools in JAX.
  • QDax - Quality Diversity optimization in JAX.
  • JAX Toolbox - Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.
  • Pgx - Vectorized board game environments for RL with an AlphaZero example.
  • XLB - A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.
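Several entries above (for example JAXopt, purejaxrl, and Pgx) describe themselves as batchable or vectorised. In plain JAX, this kind of batching is typically expressed with `jax.vmap`, which turns a function written for a single example into one that runs over a whole batch. A minimal sketch; the `step` function and its clipped-counter dynamics are hypothetical and not taken from any listed library:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example function: a counter moved by an action and clipped to [0, 10].
def step(state, action):
    return jnp.clip(state + action, 0, 10)

states = jnp.array([0, 5, 10])
actions = jnp.array([1, -2, 3])

# vmap maps step over the leading axis of both arguments; jit then compiles the batched version.
batched_step = jax.jit(jax.vmap(step))

print(batched_step(states, actions))  # one result per (state, action) pair
```

Each pair is stepped independently, giving 1, 3, and 10 (the last one clipped from 13), with no Python loop over the batch.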

Models and Projects

JAX

Flax

Haiku

Trax

  • Reformer - Implementation of the Reformer (efficient transformer) architecture.

NumPyro

Videos

Papers

This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models and Projects section.

Tutorials and Blog Posts

Books

  • JAX in Action - A hands-on guide to using JAX for deep learning and other mathematically-intensive applications.

Community

Contributing

Contributions welcome! Read the contribution guidelines first.