A unified scientific machine learning framework built on JAX/Flax NNX
Updated May 21, 2026 - Python
Titans: Learning to Memorize at Test Time
nanoGPT for Diffusion Language Models
Unified benchmarking and profiling framework for the JAX scientific ML ecosystem. Timing, GPU/energy monitoring, FLOPS counting, roofline analysis, statistical testing, regression detection, and CI integration.
Flax NNX implementation of common metrics.
End-to-end differentiable bioinformatics pipelines built on JAX/Flax NNX
A Differentiable Data Pipeline Framework for JAX
A research-focused modular generative modeling library built on JAX/Flax NNX
Molecular active learning with JAX
KGE-JAXed: A simple knowledge graph embedding library created in JAX
JAX/Flax NNX port of Karpathy's nanochat, optimized for multi-host Cloud TPU pods.
A 2-in-1 notebook-based tutorial and implementation of Manifold-Constrained Hyper-Connections using JAX
JaxNN: Foundation Models in JAX/Flax
Recax is an experimental JAX/Flax NNX recommender systems framework