JAX and Flax tutorials
Sep 15, 2022
Looking for high performance in deep learning? Look no further. JAX is here to help.
JAX is a Python library that delivers high performance in machine learning through XLA and just-in-time (JIT) compilation. Its API is similar to NumPy’s, with a few differences; for example, JAX arrays are immutable. JAX ships with features designed to speed up machine learning research. These include:
- Automatic differentiation
- Vectorization
- JIT compilation
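As a rough illustration of these three features, here is a minimal sketch using `jax.grad`, `jax.vmap`, and `jax.jit`. The `loss` function and the input values are made up for this example and are not taken from any of the tutorials.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical toy function: sum of squares of w * x.
    return jnp.sum((w * x) ** 2)


x = jnp.array([1.0, 2.0, 3.0])
w = 3.0

# Automatic differentiation: gradient of loss with respect to w (first argument).
grad_value = jax.grad(loss)(w, x)

# Vectorization: evaluate loss for a batch of w values without a Python loop.
# in_axes=(0, None) maps over the first argument and keeps x fixed.
ws = jnp.array([1.0, 2.0])
batched_values = jax.vmap(loss, in_axes=(0, None))(ws, x)

# JIT compilation: compile loss with XLA on first call, then reuse the compiled version.
jit_value = jax.jit(loss)(w, x)

print(grad_value, batched_values, jit_value)
```

These transformations compose, so a pattern such as `jax.jit(jax.grad(loss))` is also common.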
We have put together several tutorials to get you started with JAX and Flax (a deep learning library for JAX). They include:
- What is JAX?
- Flax vs. TensorFlow
- JAX loss functions
- Activation functions in JAX and Flax
- Optimizers in JAX and Flax
- How to load datasets in JAX using TensorFlow
- Building Convolutional Neural Networks in JAX and Flax
- Distributed training in JAX
- Using TensorBoard in JAX and Flax
- LSTM in JAX & Flax
- Elegy (High-level API for deep learning in JAX & Flax)
- Transfer learning with JAX and Flax