Install directly from GitHub with:
```bash
pip install git+https://github.com/mariogeiger/allegro-jax
```
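To confirm the installation worked, one option is a quick check with the standard library's `importlib.util.find_spec`. This is a minimal sketch, not part of allegro-jax. The helper name `check_install` is hypothetical, and it assumes the package is importable as `allegro_jax` and that it depends on `jax`.

```python
"""Check that allegro-jax and JAX are importable (hypothetical helper, not part of allegro-jax)."""
import importlib.util


def check_install(packages=("jax", "allegro_jax")):
    """Map each top-level package name to whether Python can locate it.

    find_spec returns None for a missing top-level package, so this does not
    import anything and does not raise when a package is absent.
    """
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    for name, found in check_install().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

Running it as a script prints one line per package, so a missing dependency shows up before any model code is written.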
You can import the Allegro layer directly with:
```python
from allegro_jax import AllegroLayer  # For Flax
from allegro_jax import AllegroHaikuLayer  # For Haiku
```
You can also import the whole model with:
```python
from allegro_jax import Allegro  # For Flax
from allegro_jax import AllegroHaiku  # For Haiku
```
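If a project needs to switch between the Flax and Haiku variants, the four exported names can be looked up from a small table. The sketch below is a hypothetical helper: `allegro_class_name` and `load_allegro_class` are not part of allegro-jax, and the sketch relies only on the class names listed above and on the standard library's `importlib.import_module`.

```python
"""Select an allegro-jax class by granularity and framework (hypothetical helper)."""
import importlib

# (kind, framework) -> class name exported by allegro_jax, as listed in this README.
_CLASS_NAMES = {
    ("layer", "flax"): "AllegroLayer",
    ("layer", "haiku"): "AllegroHaikuLayer",
    ("model", "flax"): "Allegro",
    ("model", "haiku"): "AllegroHaiku",
}


def allegro_class_name(kind: str, framework: str) -> str:
    """Return the class name for kind ('layer' or 'model') and framework ('flax' or 'haiku')."""
    try:
        return _CLASS_NAMES[(kind.lower(), framework.lower())]
    except KeyError:
        raise ValueError(
            f"unsupported combination: kind={kind!r}, framework={framework!r}"
        ) from None


def load_allegro_class(kind: str, framework: str):
    """Import allegro_jax and return the requested class.

    Raises ImportError if allegro_jax is not installed.
    """
    module = importlib.import_module("allegro_jax")
    return getattr(module, allegro_class_name(kind, framework))
```

For example, `load_allegro_class("model", "haiku")` returns `AllegroHaiku`. Constructor arguments are deliberately not covered here; `test.py` shows how the classes are actually instantiated.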
See `test.py` for an example of usage.