JAX ML Frameworks
Libraries and frameworks built on JAX for neural networks, optimization, and machine learning research. Includes JAX-based neural network layers, training utilities, optimization algorithms, and domain-specific extensions (graphs, replay buffers, Earth observation). Does NOT include general Python ML frameworks, probabilistic programming languages, or domain applications built with JAX.
122 JAX ML frameworks are tracked. Six score above 70 (Verified tier). The highest-rated is google-deepmind/optax at 85/100, with 2,207 stars and 3,543,273 monthly downloads. Eight of the top 10 are actively maintained.
Get the projects as JSON (this example requests `limit=20`):

```shell
curl "https://pt-edge.onrender.com/api/v1/datasets/quality?domain=ml-frameworks&subcategory=jax-ml-frameworks&limit=20"
```

Open to everyone: 100 requests/day with no key needed, or 1,000/day with a free key.
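For scripted access, the same request can be issued from Python. The following is a minimal standard-library sketch, assuming the endpoint and query parameters shown in the curl command and that `limit` caps the number of returned records. The response schema is not described here, so the hypothetical helper `fetch_projects` returns the decoded JSON unchanged rather than guessing at field names. Passing an API key is not shown because the mechanism is not documented on this page.

```python
import json
import urllib.parse
import urllib.request

# Endpoint and query parameters copied from the curl command.
BASE_URL = "https://pt-edge.onrender.com/api/v1/datasets/quality"


def build_query_url(limit: int = 20) -> str:
    """Build the request URL for the jax-ml-frameworks subcategory.

    Assumption: `limit` sets how many records the API returns.
    """
    params = {
        "domain": "ml-frameworks",
        "subcategory": "jax-ml-frameworks",
        "limit": limit,
    }
    return f"{BASE_URL}?{urllib.parse.urlencode(params)}"


def fetch_projects(limit: int = 20, timeout: float = 10.0):
    """Hypothetical helper: GET the URL and decode the JSON body as-is."""
    with urllib.request.urlopen(build_query_url(limit), timeout=timeout) as resp:
        return json.load(resp)
```

Calling `fetch_projects(122)` would request up to 122 records in one call; keep the anonymous limit of 100 requests/day in mind when looping.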
| # | Framework | Description | Score | Tier |
|---|---|---|---|---|
| 1 | google-deepmind/optax | Optax is a gradient processing and optimization library for JAX. | 85 | Verified |
| 2 | patrick-kidger/equinox | Elegant easy-to-use neural networks + scientific computing in JAX.... | | Verified |
| 3 | explosion/thinc | 🔮 A refreshing functional take on deep learning, compatible with your... | | Verified |
| 4 | google/grain | Library for reading and processing ML training data. | | Verified |
| 5 | extropic-ai/thrml | Thermodynamic Hypergraphical Model Library in JAX | | Verified |
| 6 | patrick-kidger/optimistix | Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox.... | | Verified |
| 7 | thomaspinder/GPJax | Gaussian processes in JAX and Flax. | | Established |
| 8 | google-research/kauldron | Modular, scalable library to train ML models | | Established |
| 9 | google-deepmind/dm-haiku | JAX-based neural network library | | Established |
| 10 | google-deepmind/kfac-jax | Second Order Optimization and Curvature Estimation with K-FAC in JAX. | | Established |
| 11 | google/jaxopt | Hardware accelerated, batchable and differentiable optimizers in JAX. | | Established |
| 12 | patrick-kidger/diffrax | Numerical differential equation solvers in JAX. Autodifferentiable and... | | Established |
| 13 | MichaelTMatthews/Craftax | (Crafter + NetHack) in JAX. ICML 2024 Spotlight. | | Established |
| 14 | apax-hub/apax | A flexible and performant framework for training machine learning potentials. | | Established |
| 15 | tumaer/JAXFLUIDS | Differentiable Fluid Dynamics Package | | Established |
| 16 | google/torchax | torchax is a PyTorch frontend for JAX. It gives JAX the ability to author... | | Established |
| 17 | lockwo/distreqx | Distrax, but in Equinox. Lightweight JAX library of probability... | | Established |
| 18 | camail-official/discretax | Discretax is a lightweight collection of state space models implemented in JAX ⚡️ | | Established |
| 19 | e3nn/e3nn-jax | JAX library for E3 Equivariant Neural Networks | | Established |
| 20 | ekzhang/jax-js | JAX in JavaScript: ML library for the web, running on WebGPU & Wasm | | Established |
| 21 | apple/axlearn | An Extensible Deep Learning Library | | Established |
| 22 | Ceyron/exponax | Efficient Differentiable n-d PDE Solvers in JAX. | | Established |
| 23 | instadeepai/flashbax | ⚡ Flashbax: Accelerated Replay Buffers in JAX | | Established |
| 24 | flaport/sax | S + Autograd + XLA :: S-parameter based frequency domain circuit simulations... | | Established |
| 25 | jax-ml/bonsai | Minimal, lightweight JAX implementations of popular models. | | Established |
| 26 | sotetsuk/pgx | ♟️ Vectorized RL game environments in JAX | | Established |
| 27 | dpiras/cosmopower-jax | Differentiable cosmological emulators: the JAX version of CosmoPower | | Emerging |
| 28 | poets-ai/elegy | A High Level API for Deep Learning in JAX | | Emerging |
| 29 | GalacticDynamics/diffraxtra | Extras for Diffrax: OOP and vectorization | | Emerging |
| 30 | n2cholas/awesome-jax | JAX - A curated list of resources https://github.com/google/jax | | Emerging |
| 31 | Dicklesworthstone/model_guided_research | Systematic investigation of 11 exotic math frameworks (Lie groups, tropical... | | Emerging |
| 32 | arpastrana/jax_fdm | Auto-differentiable and hardware-accelerated force density method | | Emerging |
| 33 | BirkhoffG/jax-dataloader | PyTorch-like dataloaders for JAX. | | Emerging |
| 34 | google/trax | Trax: Deep Learning with Clear Code and Speed | | Emerging |
| 35 | bsc-quantic/tn4ml | Tensor Networks for Machine Learning | | Emerging |
| 36 | perrin-isir/xpag | A modular reinforcement learning library with JAX agents | | Emerging |
| 37 | ergodicio/adept | Automatic-Differentiation-Enabled Plasma Transport in JAX | | Emerging |
| 38 | Ceyron/pdequinox | Neural Emulator Architectures in JAX. | | Emerging |
| 39 | matthias-wright/flaxmodels | Pretrained deep learning models for JAX/Flax: StyleGAN2, GPT2, VGG, ResNet, etc. | | Emerging |
| 40 | thorben-frank/mlff | Build neural networks for machine learning force fields with JAX | | Emerging |
| 41 | srush/annotated-s4 | Implementation of https://srush.github.io/annotated-s4 | | Emerging |
| 42 | danijar/ninjax | General Modules for JAX | | Emerging |
| 43 | thebuckleylab/jpc | Flexible Inference for Predictive Coding Networks in JAX. | | Emerging |
| 44 | gordicaleksa/get-started-with-JAX | The purpose of this repo is to make it easy to get started with JAX, Flax,... | | Emerging |
| 45 | tfm000/copulax | JAX-based probability modelling | | Emerging |
| 46 | google-deepmind/jeo | Jeo: JAX model training lib for Earth Observation | | Emerging |
| 47 | francois-rozet/inox | Stainless neural networks in JAX | | Emerging |
| 48 | tinker495/JAxtar | JAxtar is a project with a JAX-native implementation of parallelizable A* &... | | Emerging |
| 49 | genjax-community/genjax | Probabilistic programming with programmable inference for parallel accelerators. | | Emerging |
| 50 | matomatical/hijax | An introduction to vanilla JAX for deep learning research | | Emerging |
| 51 | MahmudulAlam/Holographic-Reduced-Representations | Holographic Reduced Representations | | Emerging |
| 52 | shyamsn97/hyper-nn | Easy Hypernetworks in PyTorch and JAX | | Emerging |
| 53 | FLAIROx/Kinetix | Reinforcement learning on general 2D physics environments in JAX. ICLR 2025 Oral. | | Emerging |
| 54 | instadeepai/catx | 🐈⬛ Contextual bandits library for continuous action trees with smoothing in JAX | | Emerging |
| 55 | jeertmans/sampling-paths | Generative Path Candidate Sampling for Faster Point-to-Point Ray Tracing | | Emerging |
| 56 | google-deepmind/jraph | A Graph Neural Network Library in JAX | | Emerging |
| 57 | eserie/wax-ml | A Python library for machine-learning and feedback loops on streaming data | | Emerging |
| 58 | RobinKa/jaxga | Geometric Algebra package for JAX | | Emerging |
| 59 | ikostrikov/jaxrl | JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with... | | Emerging |
| 60 | liblaf/apple | 🍎 A JAX and Warp library for differentiable physics simulation, featuring... | | Emerging |
| 61 | duyongan/sunstreaker | A Keras-like framework with JAX as the backend | | Emerging |
| 62 | auto-differentiation/xad-py | High-Performance Automatic Differentiation for Python | | Emerging |
| 63 | lockwo/awesome-jax | Curated list of JAX Resources and Packages | | Emerging |
| 64 | BobMcDear/flaim | Flax Image Models - State-of-the-art pre-trained vision backbones for Flax. | | Emerging |
| 65 | AakashKumarNain/TF_JAX_tutorials | All about the fundamental blocks of TF and JAX! | | Emerging |
| 66 | wladekpal/golden-standard | Is Temporal Difference Learning the Gold Standard for Stitching in RL? Code... | | Emerging |
| 67 | AaltoML/kalman-jax | Approximate inference for Markov Gaussian processes using iterated Kalman... | | Emerging |
| 68 | mancusolab/susiepca | Scalable Ultra-Sparse Bayesian PCA | | Emerging |
| 69 | mila-iqia/torch_jax_interop | Simple tools to mix and match PyTorch and JAX - Get the best of both worlds! | | Emerging |
| 70 | m-wojnar/reinforced-lib | Reinforcement learning library | | Emerging |
| 71 | cgarciae/treex | A Pytree Module system for Deep Learning in JAX | | Emerging |
| 72 | affjljoo3581/deit3-jax | JAX/Flax implementation of DeiT and DeiT-III (ViT) | | Emerging |
| 73 | ivy-llc/mech | Mechanics functions with end-to-end support for deep learning developers,... | | Emerging |
| 74 | Twistient/HoloVec | Holographic vectors you can compute with. Bind structure, bundle sets,... | | Emerging |
| 75 | danielkelshaw/riemax | Riemannian geometry in JAX | | Emerging |
| 76 | evgenii-nikishin/omd | JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning... | | Emerging |
| 77 | XanaduAI/GradDFT | GradDFT is a JAX-based library enabling the differentiable design and... | | Experimental |
| 78 | IvanIsCoding/GNN-for-Combinatorial-Optimization | JAX + Flax implementation of "Combinatorial Optimization with... | | Experimental |
| 79 | satojkovic/vit-jax-flax | Vision Transformer from scratch (JAX/Flax). | | Experimental |
| 80 | ericjang/pt-jax | Path Tracing in JAX | | Experimental |
| 81 | google-deepmind/dks | Multi-framework implementation of Deep Kernel Shaping and Tailored... | | Experimental |
| 82 | evgenii-nikishin/rl_with_resets | JAX implementation of deep RL agents with resets from the paper "The Primacy... | | Experimental |
| 83 | HomebrewML/revlib | Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed... | | Experimental |
| 84 | ml-for-gp/jaxgptoolbox | Geometry processing utilities compatible with JAX for autodifferentiation. | | Experimental |
| 85 | omron-sinicx/jaxmapp | JAX-based implementation for multi-agent path planning (MAPP) in continuous spaces. | | Experimental |
| 86 | Anuoluwapo65/pytorch-jax-implementation | PyTorch JAX | | Experimental |
| 87 | phlippe/jax_trainer | Lightning-like training API for JAX with Flax | | Experimental |
| 88 | google-research/jestimator | Amos optimizer with JEstimator lib. | | Experimental |
| 89 | davisyoshida/haiku-mup | A port of muP to JAX/Haiku | | Experimental |
| 90 | NITHISHM2410/flax-pilot | A simplistic trainer for Flax | | Experimental |
| 91 | amoudgl/celo | Code for Celo: Training Versatile Learned Optimizers on a Compute Diet | | Experimental |
| 92 | juliuskunze/cwvae-jax | Clockwork VAEs in JAX/Flax | | Experimental |
| 93 | Auxeno/ion | A minimal neural network library for JAX | | Experimental |
| 94 | camml-lab/reax | REAX: Scalable, flexible training for JAX, inspired by the simplicity of... | | Experimental |
| 95 | alexOarga/haiku-geometric | A collection of graph neural network implementations in JAX | | Experimental |
| 96 | malbertosm/frp_rl | Explore the "frp_rl" repository to discover the Free Random Projection... | | Experimental |
| 97 | yonesuke/jaxfss | JAX/Flax implementation of finite-size scaling | | Experimental |
| 98 | cor3bit/awesome-soms | A curated list of resources for second-order stochastic optimization | | Experimental |
| 99 | phydra-labs/phydrax | Modular Physics-ML Components in JAX | | Experimental |
| 100 | graphcore-research/jax-experimental | JAX for Graphcore IPU (experimental) | | Experimental |
| 101 | pythoncrazy/jimm | JAX Image Modeling of Models contains Computer Vision/Vision Language Model... | | Experimental |
| 102 | OleksiiBevza/jaxpsmc | JAX-based Preconditioned Sequential Monte Carlo framework | | Experimental |
| 103 | ethanluoyc/magi | Reinforcement learning library in JAX. | | Experimental |
| 104 | forynski/jax-pid-nn | High-performance JAX/Flax neural network for particle identification in... | | Experimental |
| 105 | cgarciae/nnx | Neural Networks for JAX | | Experimental |
| 106 | mzguntalan/neptune | [WIP] Neptune: JAX interop-able library in Haskell. | | Experimental |
| 107 | stefanosele/GPfY | Gaussian processes with spherical harmonic features in JAX | | Experimental |
| 108 | jrajath94/jax-transformer-impl | JAX/XLA Transformer with MHA, MQA, GQA (Ainslie et al. 2023): JIT, vmap, pmap | | Experimental |
| 109 | mzguntalan/zephyr | Zephyr is a declarative neural network library on top of JAX allowing for... | | Experimental |
| 110 | norabelrose/classroom | Preference-based reinforcement learning in PyTorch and JAX with a browser-based GUI. | | Experimental |
| 111 | lweitkamp/GANs-JAX | Implementation of several Generative Adversarial Networks in JAX / Flax | | Experimental |
| 112 | yardenas/jax-dreamer | Dreamer on JAX | | Experimental |
| 113 | cifkao/jax-spectral | Short-time Fourier transform (STFT) for JAX | | Experimental |
| 114 | MizuhoAOKI/jax_playground | A collection of hands-on examples for exploring numerical algorithms with JAX. | | Experimental |
| 115 | ethanluoyc/corax | Corax: Core RL in JAX | | Experimental |
| 116 | nathanwispinski/meta-rl | A short conceptual replication of "Prefrontal cortex as a meta-reinforcement... | | Experimental |
| 117 | ASEM000/serket | The ✨Magical✨ JAX ML Library. | | Experimental |
| 118 | wolfwdavid/jax-pinn | JAX/Flax physics-informed neural network with jax2tf export; benchmark JAX... | | Experimental |
| 119 | Ceyron/trainax | Training methodologies for autoregressive neural operators/emulators in JAX. | | Experimental |
| 120 | ysngshn/ivon-optax | An Optax-based JAX implementation of the IVON optimizer for large-scale VI... | | Experimental |
| 121 | ScottAlexanderCameron/Jynx | A neural network library written in JAX | | Experimental |
| 122 | dtunai/LongConv-Jax | JAX/Flax/Linen implementation of "Simple Hardware-Efficient Long... | | Experimental |
Experimental |