JAX ML Frameworks

Libraries and frameworks built on JAX for neural networks, optimization, and machine learning research. Includes JAX-based neural network layers, training utilities, optimization algorithms, and domain-specific extensions (graphs, replay buffers, Earth observation). Does NOT include general Python ML frameworks, probabilistic programming languages, or domain applications built with JAX.

There are 122 JAX ML frameworks tracked. 6 score above 70 (Verified tier). The highest-rated is google-deepmind/optax at 85/100 with 2,207 stars and 3,543,273 monthly downloads. 8 of the top 10 are actively maintained.
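The page does not publish its tier cut-offs. However, the scores and tiers in the ranked list are consistent with the mapping sketched here. This is a minimal sketch under that assumption: the function name tier_for is hypothetical, and so are the exact boundaries, in particular whether a score of exactly 70 counts as Verified. Applied to the top ten scores, it reproduces the "6 score above 70" figure from the summary.

```python
from collections import Counter


def tier_for(score: int) -> str:
    """Map a 0-100 quality score to a tier label.

    Hypothetical cut-offs inferred from this page's ranked list, not a
    published specification: the lowest Verified score listed is 71, the
    lowest Established is 50, and the lowest Emerging is 30.
    """
    if score > 70:   # summary describes Verified as "above 70"
        return "Verified"
    if score >= 50:  # assumed: 50 is Established, 49 is Emerging
        return "Established"
    if score >= 30:  # assumed: 30 is Emerging, 29 is Experimental
        return "Emerging"
    return "Experimental"


# Scores of the top ten projects, copied from the ranked list on this page.
top_ten = [85, 84, 80, 79, 73, 71, 69, 66, 65, 63]
counts = Counter(tier_for(s) for s in top_ten)
print(counts["Verified"])  # 6, matching the summary
```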

Get the projects as JSON (the limit parameter caps the number of results, so this request returns at most 20 of the 122):

curl "https://pt-edge.onrender.com/api/v1/datasets/quality?domain=ml-frameworks&subcategory=jax-ml-frameworks&limit=20"

Open to everyone: 100 requests/day with no key needed, or 1,000/day with a free key.
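The same request can be made from Python with only the standard library. This is a minimal sketch, not an official client. The helper names build_query_url and fetch_projects are hypothetical, and the structure of the returned JSON is not documented on this page. The page also does not say how an API key is sent, so the sketch makes anonymous requests, which fall under the 100 requests/day limit.

```python
import json
import urllib.parse
import urllib.request

# Endpoint taken from the curl example above.
BASE_URL = "https://pt-edge.onrender.com/api/v1/datasets/quality"


def build_query_url(domain: str, subcategory: str, limit: int = 20) -> str:
    """Build the query URL with the same parameters as the curl example."""
    params = {"domain": domain, "subcategory": subcategory, "limit": limit}
    return f"{BASE_URL}?{urllib.parse.urlencode(params)}"


def fetch_projects(url: str, timeout: float = 10.0):
    """Perform the GET request and decode the JSON body.

    The shape of the decoded object is an unknown here; inspect it
    before relying on any field names.
    """
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        return json.load(resp)


if __name__ == "__main__":
    # Only builds and prints the URL; call fetch_projects(url) to hit the network.
    url = build_query_url("ml-frameworks", "jax-ml-frameworks", limit=20)
    print(url)
```

Building the URL with urlencode keeps the parameters properly escaped if other domain or subcategory values are substituted.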

Rank, framework, description, score (out of 100), tier
1 google-deepmind/optax

Optax is a gradient processing and optimization library for JAX.

85
Verified
2 patrick-kidger/equinox

Elegant easy-to-use neural networks + scientific computing in JAX....

84
Verified
3 explosion/thinc

🔮 A refreshing functional take on deep learning, compatible with your...

80
Verified
4 google/grain

Library for reading and processing ML training data.

79
Verified
5 extropic-ai/thrml

Thermodynamic Hypergraphical Model Library in JAX

73
Verified
6 patrick-kidger/optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox....

71
Verified
7 thomaspinder/GPJax

Gaussian processes in JAX and Flax.

69
Established
8 google-research/kauldron

Modular, scalable library to train ML models

66
Established
9 google-deepmind/dm-haiku

JAX-based neural network library

65
Established
10 google-deepmind/kfac-jax

Second Order Optimization and Curvature Estimation with K-FAC in JAX.

63
Established
11 google/jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.

63
Established
12 patrick-kidger/diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and...

63
Established
13 MichaelTMatthews/Craftax

(Crafter + NetHack) in JAX. ICML 2024 Spotlight.

62
Established
14 apax-hub/apax

A flexible and performant framework for training machine learning potentials.

59
Established
15 tumaer/JAXFLUIDS

Differentiable Fluid Dynamics Package

59
Established
16 google/torchax

torchax is a PyTorch frontend for JAX. It gives JAX the ability to author...

58
Established
17 lockwo/distreqx

Distrax, but in equinox. Lightweight JAX library of probability...

58
Established
18 camail-official/discretax

Discretax is a lightweight collection of state space models implemented in JAX ⚡️

55
Established
19 e3nn/e3nn-jax

JAX library for E3 Equivariant Neural Networks

55
Established
20 ekzhang/jax-js

JAX in JavaScript – ML library for the web, running on WebGPU & Wasm

53
Established
21 apple/axlearn

An Extensible Deep Learning Library

53
Established
22 Ceyron/exponax

Efficient Differentiable n-d PDE Solvers in JAX.

51
Established
23 instadeepai/flashbax

⚡ Flashbax: Accelerated Replay Buffers in JAX

51
Established
24 flaport/sax

S + Autograd + XLA :: S-parameter based frequency domain circuit simulations...

51
Established
25 jax-ml/bonsai

Minimal, lightweight JAX implementations of popular models.

50
Established
26 sotetsuk/pgx

♟️ Vectorized RL game environments in JAX

50
Established
27 dpiras/cosmopower-jax

Differentiable cosmological emulators: the JAX version of CosmoPower

49
Emerging
28 poets-ai/elegy

A High Level API for Deep Learning in JAX

48
Emerging
29 GalacticDynamics/diffraxtra

Extras for Diffrax: OOP and vectorization

48
Emerging
30 n2cholas/awesome-jax

JAX - A curated list of resources https://github.com/google/jax

48
Emerging
31 Dicklesworthstone/model_guided_research

Systematic investigation of 11 exotic math frameworks (Lie groups, tropical...

46
Emerging
32 arpastrana/jax_fdm

Auto-differentiable and hardware-accelerated force density method

46
Emerging
33 BirkhoffG/jax-dataloader

Pytorch-like dataloaders for JAX.

46
Emerging
34 google/trax

Trax — Deep Learning with Clear Code and Speed

46
Emerging
35 bsc-quantic/tn4ml

Tensor Networks for Machine Learning

46
Emerging
36 perrin-isir/xpag

A modular reinforcement learning library with JAX agents

45
Emerging
37 ergodicio/adept

Automatic-Differentiation-Enabled Plasma Transport in JAX

45
Emerging
38 Ceyron/pdequinox

Neural Emulator Architectures in JAX.

43
Emerging
39 matthias-wright/flaxmodels

Pretrained deep learning models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet, etc.

43
Emerging
40 thorben-frank/mlff

Build neural networks for machine learning force fields with JAX

42
Emerging
41 srush/annotated-s4

Implementation of https://srush.github.io/annotated-s4

41
Emerging
42 danijar/ninjax

General Modules for JAX

41
Emerging
43 thebuckleylab/jpc

Flexible Inference for Predictive Coding Networks in JAX.

41
Emerging
44 gordicaleksa/get-started-with-JAX

The purpose of this repo is to make it easy to get started with JAX, Flax,...

41
Emerging
45 tfm000/copulax

JAX based probability modelling

40
Emerging
46 google-deepmind/jeo

Jeo: Jax model training lib for Earth Observation

40
Emerging
47 francois-rozet/inox

Stainless neural networks in JAX

40
Emerging
48 tinker495/JAxtar

JAxtar is a project with a JAX-native implementation of parallelizable A* &...

40
Emerging
49 genjax-community/genjax

Probabilistic programming with programmable inference for parallel accelerators.

40
Emerging
50 matomatical/hijax

An introduction to vanilla JAX for deep learning research

39
Emerging
51 MahmudulAlam/Holographic-Reduced-Representations

Holographic Reduced Representations

39
Emerging
52 shyamsn97/hyper-nn

Easy Hypernetworks in PyTorch and JAX

39
Emerging
53 FLAIROx/Kinetix

Reinforcement learning on general 2D physics environments in JAX. ICLR 2025 Oral.

38
Emerging
54 instadeepai/catx

🐈‍⬛ Contextual bandits library for continuous action trees with smoothing in JAX

37
Emerging
55 jeertmans/sampling-paths

Generative Path Candidate Sampling for Faster Point-to-Point Ray Tracing

37
Emerging
56 google-deepmind/jraph

A Graph Neural Network Library in Jax

37
Emerging
57 eserie/wax-ml

A Python library for machine-learning and feedback loops on streaming data

37
Emerging
58 RobinKa/jaxga

Geometric Algebra package for JAX

37
Emerging
59 ikostrikov/jaxrl

JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with...

37
Emerging
60 liblaf/apple

🍎 A JAX and Warp library for differentiable physics simulation, featuring...

36
Emerging
61 duyongan/sunstreaker

A Keras-like framework with JAX as its backend

36
Emerging
62 auto-differentiation/xad-py

High-Performance Automatic Differentiation for Python

34
Emerging
63 lockwo/awesome-jax

Curated list of JAX Resources and Packages

34
Emerging
64 BobMcDear/flaim

Flax Image Models - State-of-the-art pre-trained vision backbones for Flax.

33
Emerging
65 AakashKumarNain/TF_JAX_tutorials

All about the fundamental blocks of TF and JAX!

33
Emerging
66 wladekpal/golden-standard

Is Temporal Difference Learning the Gold Standard for Stitching in RL? Code...

32
Emerging
67 AaltoML/kalman-jax

Approximate inference for Markov Gaussian processes using iterated Kalman...

32
Emerging
68 mancusolab/susiepca

Scalable Ultra-Sparse Bayesian PCA

32
Emerging
69 mila-iqia/torch_jax_interop

Simple tools to mix and match PyTorch and Jax - Get the best of both worlds!

32
Emerging
70 m-wojnar/reinforced-lib

Reinforcement learning library

31
Emerging
71 cgarciae/treex

A Pytree Module system for Deep Learning in JAX

31
Emerging
72 affjljoo3581/deit3-jax

Jax/Flax implementation of DeiT and DeiT-III (ViT)

31
Emerging
73 ivy-llc/mech

Mechanics functions with end-to-end support for deep learning developers,...

31
Emerging
74 Twistient/HoloVec

Holographic vectors you can compute with. Bind structure, bundle sets,...

31
Emerging
75 danielkelshaw/riemax

Riemannian geometry in JAX

30
Emerging
76 evgenii-nikishin/omd

JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning...

30
Emerging
77 XanaduAI/GradDFT

GradDFT is a JAX-based library enabling the differentiable design and...

29
Experimental
78 IvanIsCoding/GNN-for-Combinatorial-Optimization

JAX + Flax implementation of "Combinatorial Optimization with...

29
Experimental
79 satojkovic/vit-jax-flax

Vision Transformer from scratch (JAX/Flax).

28
Experimental
80 ericjang/pt-jax

Path Tracing in JAX

28
Experimental
81 google-deepmind/dks

Multi-framework implementation of Deep Kernel Shaping and Tailored...

28
Experimental
82 evgenii-nikishin/rl_with_resets

JAX implementation of deep RL agents with resets from the paper "The Primacy...

27
Experimental
83 HomebrewML/revlib

Simple and efficient RevNet-Library for PyTorch with XLA and DeepSpeed...

27
Experimental
84 ml-for-gp/jaxgptoolbox

Geometry processing utilities compatible with jax for autodifferentiation.

27
Experimental
85 omron-sinicx/jaxmapp

JAX-based implementation for multi-agent path planning (MAPP) in continuous spaces.

27
Experimental
86 Anuoluwapo65/pytorch-jax-implementation

pytorch jax

26
Experimental
87 phlippe/jax_trainer

Lightning-like training API for JAX with Flax

26
Experimental
88 google-research/jestimator

Amos optimizer with JEstimator lib.

26
Experimental
89 davisyoshida/haiku-mup

A port of muP to JAX/Haiku

26
Experimental
90 NITHISHM2410/flax-pilot

A simplistic trainer for Flax

26
Experimental
91 amoudgl/celo

Code for Celo: Training Versatile Learned Optimizers on a Compute Diet

25
Experimental
92 juliuskunze/cwvae-jax

Clockwork VAEs in JAX/Flax

25
Experimental
93 Auxeno/ion

A minimal neural network library for JAX

25
Experimental
94 camml-lab/reax

REAX — Scalable, flexible training for JAX, inspired by the simplicity of...

25
Experimental
95 alexOarga/haiku-geometric

A collection of graph neural networks implementations in JAX

24
Experimental
96 malbertosm/frp_rl

Explore the "frp_rl" repository to discover the Free Random Projection...

24
Experimental
97 yonesuke/jaxfss

JAX/Flax implementation of finite-size scaling

24
Experimental
98 cor3bit/awesome-soms

A curated list of resources for second-order stochastic optimization

24
Experimental
99 phydra-labs/phydrax

Modular Physics-ML Components in JAX

24
Experimental
100 graphcore-research/jax-experimental

JAX for Graphcore IPU (experimental)

23
Experimental
101 pythoncrazy/jimm

JAX Image Modeling of Models contains Computer Vision/Vision Language Model...

23
Experimental
102 OleksiiBevza/jaxpsmc

JAX based Preconditioned Sequential Monte Carlo framework

23
Experimental
103 ethanluoyc/magi

Reinforcement learning library in JAX.

23
Experimental
104 forynski/jax-pid-nn

High-performance JAX/Flax neural network for particle identification in...

22
Experimental
105 cgarciae/nnx

Neural Networks for JAX

22
Experimental
106 mzguntalan/neptune

[WIP] Neptune: JAX-interoperable library in Haskell.

22
Experimental
107 stefanosele/GPfY

Gaussian processes with spherical harmonic features in JAX

22
Experimental
108 jrajath94/jax-transformer-impl

JAX/XLA Transformer with MHA, MQA, GQA (Ainslie et al. 2023) — JIT, vmap, pmap

22
Experimental
109 mzguntalan/zephyr

Zephyr is a declarative neural network library on top of JAX allowing for...

21
Experimental
110 norabelrose/classroom

Preference-based reinforcement learning in PyTorch and JAX with a browser-based GUI.

21
Experimental
111 lweitkamp/GANs-JAX

Implementation of several Generative Adversarial Networks in JAX / Flax

20
Experimental
112 yardenas/jax-dreamer

Dreamer on JAX

20
Experimental
113 cifkao/jax-spectral

Short-time Fourier transform (STFT) for JAX

20
Experimental
114 MizuhoAOKI/jax_playground

A collection of hands-on examples for exploring numerical algorithms with JAX.

19
Experimental
115 ethanluoyc/corax

Corax: Core RL in JAX

16
Experimental
116 nathanwispinski/meta-rl

A short conceptual replication of "Prefrontal cortex as a meta-reinforcement...

15
Experimental
117 ASEM000/serket

The ✨Magical✨ JAX ML Library.

15
Experimental
118 wolfwdavid/jax-pinn

JAX/Flax physics-informed neural network with jax2tf export — benchmark JAX...

14
Experimental
119 Ceyron/trainax

Training methodologies for autoregressive neural operators/emulators in JAX.

14
Experimental
120 ysngshn/ivon-optax

An Optax-based JAX implementation of the IVON optimizer for large-scale VI...

14
Experimental
121 ScottAlexanderCameron/Jynx

A neural network library written in JAX

14
Experimental
122 dtunai/LongConv-Jax

Jax/Flax/Linen implementation of "Simple Hardware-Efficient Long...

12
Experimental