JAX and Flax tutorials

Derrick Mwiti
1 min read · Sep 15, 2022

Looking for high performance when training and running deep learning networks? Look no further: JAX is here to help.

JAX is a Python library that delivers high performance for machine learning through XLA (Accelerated Linear Algebra) and Just-In-Time (JIT) compilation. Its API closely mirrors NumPy's, with a few differences. JAX ships with features aimed at speeding up machine learning research, including:

  • Automatic differentiation
  • Vectorization
  • JIT compilation
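A minimal sketch of the NumPy-like API and the three features listed above, assuming JAX is installed (`pip install jax`). The toy functions `loss` and `predict` and the sample data are hypothetical illustrations, not code taken from the tutorials.

```python
import jax
import jax.numpy as jnp

# jax.numpy mirrors much of the NumPy API.
x = jnp.arange(5.0)

# One difference from NumPy: JAX arrays are immutable, so `x[0] = 10.0`
# raises an error. The functional .at[] syntax returns a new array instead.
y = x.at[0].set(10.0)

# Hypothetical toy loss: mean squared error of a one-parameter linear model.
def loss(w, xs, ys):
    preds = xs * w
    return jnp.mean((preds - ys) ** 2)

# Automatic differentiation: jax.grad returns a function that computes
# d(loss)/d(w), differentiating with respect to the first argument.
grad_loss = jax.grad(loss)

# Hypothetical per-example prediction function.
def predict(w, x):
    return w * x

# Vectorization: jax.vmap maps `predict` over axis 0 of the second
# argument while broadcasting the weight (in_axes=None for `w`).
batched_predict = jax.vmap(predict, in_axes=(None, 0))

# JIT compilation: jax.jit traces the gradient function and compiles it with XLA.
fast_grad = jax.jit(grad_loss)

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])

print(y)                           # new array with the first element replaced
print(grad_loss(1.0, xs, ys))      # gradient of the loss at w = 1.0
print(batched_predict(2.0, xs))    # predictions for the whole batch
print(fast_grad(1.0, xs, ys))      # same gradient, computed by compiled code
```

Because these features are composable function transformations, `jax.jit(jax.grad(loss))` simply wraps one transformation in another. The first call to `fast_grad` pays the compilation cost; later calls with inputs of the same shape and dtype reuse the compiled version.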

We have put together several tutorials to get you started with JAX and Flax (a neural network library built on JAX). They include:
