Awesome JAX

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
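
A minimal sketch of how these pieces fit together, assuming a recent `jax` install: `jax.numpy` supplies the NumPy-like API, `jax.grad` provides automatic differentiation, and `jax.jit` hands the function to XLA for compilation. The linear-model loss and the toy arrays are illustrative choices, not part of any library listed here.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a linear model, written with NumPy-style jnp calls.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad differentiates `loss` with respect to its first argument (w);
# jax.jit compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # two examples, two features
y = jnp.array([1.0, 2.0])                # targets
w = jnp.zeros(2)                         # initial weights

g = grad_loss(w, x, y)  # gradient of the loss at w = 0
```

At w = 0 the gradient is (2/n) * x.T @ (x @ w - y), which works out to [-7, -10] for these arrays. The same code runs on whichever backend (CPU, GPU, or TPU) JAX is installed with.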
      
    
    
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
    
Libraries

- Neural Network Libraries
  - Flax - Centered on flexibility and clarity.
  - Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
  - Objax - Has an object-oriented design similar to PyTorch.
  - Elegy - A framework-agnostic Trainer interface for the JAX ecosystem. Supports Flax, Haiku, and Optax.
  - Trax - “Batteries included” deep learning library focused on providing solutions for common workloads.
  - Jraph - Lightweight graph neural network library.
  - Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
  - HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
  - Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
- NumPyro - Probabilistic programming based on the Pyro library.
- Chex - Utilities to write and test reliable JAX code.
- Optax - Gradient processing and optimization library.
- RLax - Library for implementing reinforcement learning agents.
- JAX, M.D. - Accelerated, differentiable molecular dynamics.
- Coax - Turn RL papers into code, the easy way.
- SymJAX - Symbolic CPU/GPU/TPU programming.
- mcx - Express & compile probabilistic programs for performant inference.
- Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
- cvxpylayers - Construct differentiable convex optimization layers.
- TensorLy - Tensor learning made simple.
- NetKet - Machine Learning toolbox for Quantum Physics.
    
New Libraries

This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large user base yet.

- Neural Network Libraries
  - FedJAX - Federated learning in JAX, built on Optax and Haiku.
  - Equivariant MLP - Construct equivariant neural network layers.
  - jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
- jax-unirep - Library implementing the UniRep model for protein machine learning applications.
- jax-flows - Normalizing flows in JAX.
- sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
- jax-cosmo - Differentiable cosmology library.
- efax - Exponential Families in JAX.
- mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
- imax - Image augmentations and transformations.
- FlaxVision - Flax version of TorchVision.
- Oryx - Probabilistic programming language based on program transformations.
- Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
- delta PV - A photovoltaic simulator with automatic differentiation.
- jaxlie - Lie theory library for rigid body transformations and optimization.
- BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
- flaxmodels - Pretrained models for JAX/Flax.
- CR.Sparse - XLA-accelerated algorithms for sparse representations and compressive sensing.
- exojax - Automatic differentiable spectrum modeling of exoplanets/brown dwarfs compatible with JAX.
- JAXopt - Hardware-accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
- PIX - An image processing library in JAX, for JAX.
    
Models and Projects

JAX

Flax

Haiku

Trax

- Reformer - Implementation of the Reformer (efficient transformer) architecture.
    
Videos

- NeurIPS 2020: JAX Ecosystem Meetup - JAX, its use at DeepMind, and discussion between engineers, scientists, and the JAX core team.
- Introduction to JAX - Simple neural network from scratch in JAX.
- JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas - JAX’s core design, how it’s powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro — Andy Kitchen - Introduction to Bayesian modelling using NumPyro.
- JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne - JAX intro presentation in the Program Transformations for Machine Learning workshop.
- JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury - Presentation of TPU host access with demo.
- Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 - Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.
- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey - A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice.
- JAX, Flax & Transformers 🤗 - 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.
    
Papers

This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models and Projects section.

- Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. - White paper describing an early version of JAX, detailing how computation is traced and compiled.
- JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. - Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. - Uses JAX’s JIT and VMAP to achieve faster differentially private SGD than existing libraries.
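
The core idea behind the differentially private SGD paper listed above can be sketched in a few lines: `jax.vmap` turns a single-example gradient into per-example gradients in one vectorized call, each gradient is clipped to a maximum L2 norm, Gaussian noise is added to the sum, and `jax.jit` compiles the whole step. This is a simplified illustration under assumed choices (a squared-error linear model; `clipped_grad_sum`, `private_grad`, `clip_norm`, and `noise_multiplier` are hypothetical names; no privacy accounting), not the paper's implementation.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error for a single example: x is a feature vector, y a scalar.
    return (jnp.dot(x, w) - y) ** 2


def clipped_grad_sum(w, xs, ys, clip_norm):
    # vmap over the batch axis of xs and ys (w is shared) gives one gradient per example.
    per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
    # Rescale each per-example gradient so its L2 norm is at most clip_norm.
    norms = jnp.linalg.norm(per_example, axis=1, keepdims=True)
    scale = jnp.minimum(1.0, clip_norm / (norms + 1e-12))
    return jnp.sum(per_example * scale, axis=0)


def private_grad(w, xs, ys, key, clip_norm=1.0, noise_multiplier=1.0):
    # Add Gaussian noise scaled to the clipping norm, then average over the batch.
    summed = clipped_grad_sum(w, xs, ys, clip_norm)
    noise = noise_multiplier * clip_norm * jax.random.normal(key, summed.shape)
    return (summed + noise) / xs.shape[0]


# Compile the full private gradient step with XLA.
private_grad_jit = jax.jit(private_grad)

# Hypothetical toy batch: two examples with two features each.
w = jnp.zeros(2)
xs = jnp.array([[3.0, 4.0], [1.0, 0.0]])
ys = jnp.array([1.0, 0.25])
g = private_grad_jit(w, xs, ys, jax.random.PRNGKey(0))
```

Clipping each example's gradient before summing bounds how much any single example can move the update, which is the quantity the added noise is calibrated against; doing it with `vmap` avoids a Python loop over the batch.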
        
      
 
    
    
Tutorials and Blog Posts

Contributing

Contributions welcome! Read the contribution guidelines first.