Normalizing flow models implemented in PyTorch, TensorFlow and XLA-accelerated JAX. JAX's XLA compilation allows the code to be parallelised across multiple accelerators such as GPUs and TPUs.
The repository is under active development.
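As a rough illustration of the JAX side, the sketch below evaluates a batch log-likelihood under an elementwise affine flow using the change-of-variables formula, log p_x(x) = log p_z(f(x)) + log|det ∂f/∂x|. It then splits the batch across all local devices with `jax.pmap`. The function names (`affine_forward`, `log_prob`) and the `(log_scale, shift)` parameter layout are hypothetical, not this repository's API. Whether the repository parallelises with `pmap`, with `jit` plus sharding, or by some other route is also an assumption.

```python
import jax
import jax.numpy as jnp


def affine_forward(params, x):
    """Hypothetical elementwise affine flow mapping data x to latent z = x * exp(log_scale) + shift."""
    log_scale, shift = params
    z = x * jnp.exp(log_scale) + shift
    # The Jacobian is diagonal, so log|det J| is the sum of the log-scales
    # (identical for every sample in the batch).
    log_det = jnp.sum(log_scale)
    return z, log_det


def standard_normal_logpdf(z):
    """Log-density of a standard normal base distribution, summed over the last axis."""
    return -0.5 * jnp.sum(z ** 2 + jnp.log(2.0 * jnp.pi), axis=-1)


@jax.jit
def log_prob(params, x):
    """Change of variables: log p_x(x) = log p_z(f(x)) + log|det df/dx|."""
    z, log_det = affine_forward(params, x)
    return standard_normal_logpdf(z) + log_det


# Replicate the parameters on every device (in_axes=None) and split the data along axis 0.
parallel_log_prob = jax.pmap(log_prob, in_axes=(None, 0))

# Usage: a batch whose size is a multiple of the number of local devices.
n_dev = jax.local_device_count()
dim = 2
params = (jnp.array([0.5, -0.3]), jnp.array([0.1, 0.2]))
x = jax.random.normal(jax.random.PRNGKey(0), (4 * n_dev, dim))
# Reshape to (devices, per-device batch, dim) so each device receives an equal shard.
sharded = x.reshape(n_dev, -1, dim)
lp = parallel_log_prob(params, sharded).reshape(-1)  # one log-density per sample
print(lp.shape)
```

`pmap` requires the leading axis to equal the number of devices, hence the reshape. On a single CPU it simply runs everything on one device.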
- Linear Flow: [x] PyTorch [x] TensorFlow [x] JAX
- Non-Linear Flow: [x] PyTorch [x] TensorFlow [x] JAX
- Affine Flow: [x] PyTorch [x] TensorFlow [x] JAX
- Planar Flow: [ ] PyTorch [ ] TensorFlow [ ] JAX
- Radial Flow
- Coupling and Autoregressive Flows
- RealNVP
- Glow
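Planar Flow is still unchecked for all three backends. One possible JAX sketch follows the formulation of Rezende and Mohamed (2015): f(z) = z + û·tanh(wᵀz + b), with log-determinant log|1 + ûᵀψ(z)| and ψ(z) = (1 − tanh²(wᵀz + b))·w. Here u is reparametrised to û so that wᵀû > −1, which keeps the map invertible. The function names (`constrain_u`, `planar_forward`, `batched_planar`) and the `(u, w, b)` parameter tuple are hypothetical, not this repository's interface.

```python
import jax
import jax.numpy as jnp


def constrain_u(u, w):
    """Reparametrise u so that w.u_hat > -1, a sufficient condition for invertibility with tanh."""
    wu = jnp.dot(w, u)
    m_wu = -1.0 + jax.nn.softplus(wu)  # maps any real number into (-1, inf)
    return u + (m_wu - wu) * w / jnp.sum(w ** 2)


def planar_forward(params, z):
    """Planar flow f(z) = z + u_hat * tanh(w.z + b) for a single sample z of shape (d,)."""
    u, w, b = params
    u_hat = constrain_u(u, w)
    a = jnp.dot(w, z) + b
    f_z = z + u_hat * jnp.tanh(a)
    # psi(z) = h'(a) * w with h = tanh, so det J = 1 + u_hat . psi(z) (matrix determinant lemma).
    psi = (1.0 - jnp.tanh(a) ** 2) * w
    log_det = jnp.log(jnp.abs(1.0 + jnp.dot(u_hat, psi)))
    return f_z, log_det


# Share the parameters across the batch and map over samples along axis 0.
batched_planar = jax.jit(jax.vmap(planar_forward, in_axes=(None, 0)))

# Usage: push a batch of standard-normal samples through one planar layer.
params = (jnp.array([0.5, -1.0]), jnp.array([1.0, 2.0]), 0.1)
z0 = jax.random.normal(jax.random.PRNGKey(0), (16, 2))
z1, log_det = batched_planar(params, z0)
print(z1.shape, log_det.shape)
```

The closed-form log-determinant can be checked against `jnp.linalg.slogdet(jax.jacfwd(...))` on a single sample. It exists because it avoids the O(d³) cost of a full Jacobian determinant.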