An attempt at reproducing the "exponential-map-sum-of-radial-flow" (EMSRE) results in Table 1 of Normalizing Flows on Tori and Spheres by Rezende et al., implemented in JAX.
The target density on
To train the flow with a given choice of hyperparameters, e.g. N=1 and K=12:
$ python3 main.py --N=1 --K=12
Comparison of the authors' results (average) with ours (a single run, as is), reported as theirs / ours:
Model | KL | ESS |
---|---|---|
EMSRE(N=1, K=12) | 0.82 / 0.78 | 42 % / 42 % |
EMSRE(N=6, K=5) | 0.19 / 0.19 | 75 % / 82 % |
The N=24, K=1 scenario was also attempted, but without success within the 20,000 iterations claimed by the authors.
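
For readers unfamiliar with the two metrics in the table, the following is a minimal sketch of how KL and ESS can be estimated from flow samples with self-normalised importance weights. It assumes access to the flow's log density and an unnormalised log target density evaluated at samples drawn from the flow. The function name `kl_and_ess` and its arguments are hypothetical and are not taken from `main.py`; the estimator actually used in this repository or by the authors may differ in detail.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def kl_and_ess(log_q, log_p_unnorm):
    """Estimate KL(q || p) and the relative ESS from samples x_i ~ q.

    Hypothetical helper, not part of the repository.
    log_q:        flow log density at its own samples, shape (S,)
    log_p_unnorm: unnormalised target log density at the same samples, shape (S,)
    """
    n = log_q.shape[0]
    log_w = log_p_unnorm - log_q  # log importance weights
    # Normalising constant estimated by importance sampling: Z ~ mean(w)
    log_z = logsumexp(log_w) - jnp.log(n)
    # KL(q || p) = E_q[log q - log p_unnorm] + log Z
    kl = -jnp.mean(log_w) + log_z
    # ESS / S = (sum w)^2 / (S * sum w^2), evaluated in log space for stability
    log_ess = 2.0 * logsumexp(log_w) - logsumexp(2.0 * log_w)
    ess_fraction = jnp.exp(log_ess) / n
    return kl, ess_fraction


# Sanity check: if q matches p up to a constant, all weights are equal,
# so KL should be close to 0 and the ESS fraction close to 1.
log_p = jnp.array([-1.0, -2.0, -0.5, -3.0])
kl, ess = kl_and_ess(log_p - 0.7, log_p)
print(float(kl), float(ess))
```

Working in log space with `logsumexp` avoids overflow when individual weights are very large or very small, which is common early in training.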