Tax: A fully sharded data parallel trainer in JAX

JAX is a powerful framework that I use frequently in my projects. Features like the scan primitive, the ease of computing gradients, and the ability to run the same code on both TPUs and GPUs are significant advantages. However, it isn’t as mature as PyTorch in some areas.
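To make those features concrete, here is a small, self-contained toy example (not code from the library) showing `jax.lax.scan` driving a simple recurrence and `jax.grad` differentiating through it:

```python
import jax
import jax.numpy as jnp


def rnn_cell(carry, x):
    # A toy linear recurrence: returns the new carry and a per-step output.
    new_carry = jnp.tanh(carry + x)
    return new_carry, new_carry


def loss_fn(init_carry, xs):
    # scan runs rnn_cell over the leading axis of xs without a Python loop.
    final_carry, outputs = jax.lax.scan(rnn_cell, init_carry, xs)
    return jnp.mean(outputs ** 2)


xs = jnp.ones((16, 8))               # 16 time steps, feature dim 8
init = jnp.zeros((8,))
grads = jax.grad(loss_fn)(init, xs)  # gradient w.r.t. the initial carry
print(grads.shape)                   # (8,)
```

The same code runs unchanged on CPU, GPU, or TPU, which is exactly the portability mentioned above.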
In my research, I needed a training framework that could handle medium-sized models (around 2 billion parameters), so I decided to develop a fully sharded data-parallel trainer. While the primary focus is on research rather than production code, I thought it might be helpful to share it as a library.
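As a rough illustration of the fully sharded data-parallel idea, here is a minimal sketch built on JAX's public sharding API (this is not the library's actual interface, just the underlying pattern it relies on): parameters and the batch are both sharded across a single mesh axis, and the XLA GSPMD partitioner inserts the all-gathers and reduce-scatters needed for the forward and backward pass.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis over all available devices (works on a single device too).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Toy two-layer MLP parameters; a real trainer would hold billions of these.
key = jax.random.PRNGKey(0)
params = {
    "w1": jax.random.normal(key, (512, 2048)),
    "w2": jax.random.normal(key, (2048, 512)),
}

# Shard each weight's first axis across 'data' (the FSDP-style layout) and the
# batch across the same axis (the data-parallel layout). Shapes here assume the
# device count divides the sharded dimensions.
param_sharding = jax.tree_util.tree_map(
    lambda _: NamedSharding(mesh, P("data", None)), params
)
batch_sharding = NamedSharding(mesh, P("data", None))

params = jax.device_put(params, param_sharding)
batch = jax.device_put(jnp.ones((64, 512)), batch_sharding)


def loss_fn(params, x):
    h = jnp.tanh(x @ params["w1"])
    return jnp.mean((h @ params["w2"]) ** 2)


@jax.jit
def train_step(params, x, lr=1e-3):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    # Plain SGD; grads inherit the parameter sharding, so the update stays sharded.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


params, loss = train_step(params, batch)
print(loss)
```

The point of this layout is that no device ever has to hold a full copy of the parameters, which is what makes ~2B-parameter models fit comfortably on modest accelerator pods.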
You can check out the library here. I hope others find it useful too! Please star the repository and cite it if you use it.