
Code for 'From Tensor Network Quantum States to Tensorial Recurrent Neural Networks'.


cqsl/mps-rnn


From Tensor Network Quantum States to Tensorial Recurrent Neural Networks

Paper link: arXiv:2206.12363 | Phys. Rev. Research 5, L032001 (2023)

Installation

The code requires Python >= 3.9. For reference, we use Python 3.10.12. We recommend creating a fresh virtual environment and then installing the dependencies with:

pip install -r requirements.txt
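A small guard like the one below can confirm that the interpreter in the new environment meets the version requirement before installing anything. The helper name `python_version_ok` is ours and is not part of the repository.

```python
import sys

MIN_PYTHON = (3, 9)  # minimum version stated in this README


def python_version_ok(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True if (major, minor) of version_info is at least `minimum`."""
    return tuple(version_info[:2]) >= minimum


if not python_version_ok():
    found = sys.version.split()[0]
    raise SystemExit(f"Python >= {MIN_PYTHON[0]}.{MIN_PYTHON[1]} is required, found {found}")
```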

For GPU acceleration, we recommend also installing CUDA, cuDNN, and a CUDA-enabled jaxlib. Recent jaxlib releases support only CUDA 12 with cuDNN 8.9:

pip install jaxlib==0.4.28+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Alternatively, you may install jax and jaxlib 0.4.25, which support CUDA 11 and cuDNN 8.6:

pip install jax jaxlib==0.4.25+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
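After installing a CUDA-enabled jaxlib, one quick sanity check is to ask JAX which backend it actually picked up. The helper name `jax_backend_summary` is hypothetical; `jax.default_backend()` and `jax.devices()` are standard JAX calls. On a working GPU setup the backend should be `gpu`; seeing `cpu` usually means JAX fell back to a CPU-only build.

```python
import importlib.util


def jax_backend_summary():
    """Return (default backend, list of device names) if JAX is installed, else None."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return jax.default_backend(), [str(d) for d in jax.devices()]


summary = jax_backend_summary()
if summary is None:
    print("JAX is not installed in this environment")
else:
    backend, devices = summary
    print(f"default backend: {backend}; devices: {devices}")
```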

The DMRG code in dmrg/ requires Julia >= 1.6. For reference, we use Julia 1.10.3. When running it, activate the Julia environment defined by dmrg/Project.toml (for example, by starting Julia with julia --project=dmrg). This environment includes MKL, which accelerates linear algebra on Intel CPUs.
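If you drive the DMRG code from Python, a sketch like the following builds the two Julia invocations: one that installs the packages listed in dmrg/Project.toml (`Pkg.instantiate()`), and one that runs a script inside that environment. The `--project` and `--threads` flags are standard Julia command-line options; the helper names and the script path in the usage comment are hypothetical, since this README does not name the DMRG entry-point script.

```python
import shutil
import subprocess


def julia_commands(script, project="dmrg", threads=None):
    """Return (instantiate_cmd, run_cmd) for running `script` in the given project."""
    instantiate = ["julia", f"--project={project}", "-e", "using Pkg; Pkg.instantiate()"]
    run = ["julia", f"--project={project}"]
    if threads is not None:
        run.append(f"--threads={threads}")  # number of Julia threads for the run
    run.append(script)
    return instantiate, run


def run_in_dmrg_env(script, project="dmrg", threads=None):
    """Instantiate the environment, then run the script; requires julia on PATH."""
    if shutil.which("julia") is None:
        raise RuntimeError("julia is not on PATH")
    for cmd in julia_commands(script, project, threads):
        subprocess.run(cmd, check=True)


# Hypothetical usage (replace the path with an actual script from dmrg/):
# run_in_dmrg_env("dmrg/your_script.jl", threads=8)
```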

Usage

vmc.py trains a network. When performing the hierarchical initialization (HI), it automatically loads the required checkpoints. args_parser.py defines all configuration options.
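To illustrate the general idea of warm-starting one training stage from another's checkpoint, here is a minimal sketch under stated assumptions: parameters are stored as a flat `{name: array}` dict pickled to disk, and any parameter whose name and shape match the previous stage is copied while new parameters keep their fresh (here zero) initialization. The checkpoint format, parameter names, and helper functions are all hypothetical; this is not the format or logic that vmc.py uses.

```python
import os
import pickle
import tempfile

import numpy as np


def save_checkpoint(path, params):
    """Serialize a flat {name: array} parameter dict (hypothetical format)."""
    with open(path, "wb") as f:
        pickle.dump(params, f)


def load_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)


def warm_start(fresh_params, previous_params):
    """Reuse every parameter whose name and shape match the previous stage;
    keep the fresh initialization for everything else."""
    merged = {}
    for name, value in fresh_params.items():
        old = previous_params.get(name)
        if old is not None and np.shape(old) == np.shape(value):
            merged[name] = np.array(old, copy=True)
        else:
            merged[name] = value
    return merged


# Usage: stage 1 trained only "tensors"; stage 2 adds "extra", initialized to zero.
rng = np.random.default_rng(0)
stage1 = {"tensors": rng.normal(size=(4, 2, 3, 3))}
stage2_init = {
    "tensors": rng.normal(size=(4, 2, 3, 3)),
    "extra": np.zeros((4, 3, 3)),
}

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "stage1.pkl")
    save_checkpoint(ckpt, stage1)
    stage2 = warm_start(stage2_init, load_checkpoint(ckpt))

print(sorted(stage2))  # ['extra', 'tensors']
```

Initializing the new parameters to zero is one common choice because, for many architectures, it lets the enlarged model start close to the state the previous stage already reached.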

reproduce_hi.sh contains the commands for the full HI procedure. At the end, it invokes plot_hi.py to plot the energy during training.
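Variational Monte Carlo energy estimates are noisy from step to step, so when inspecting a training curve it often helps to smooth it and compare it with a reference energy. The sketch below is not plot_hi.py; the helper names, the energy values, and the reference energy are hypothetical placeholders, and plotting is left to whatever tool you prefer.

```python
def running_mean(values, window):
    """Moving average over `window` consecutive steps (len(values) - window + 1 points)."""
    if not 1 <= window <= len(values):
        raise ValueError("window must be between 1 and len(values)")
    acc = sum(values[:window])
    out = [acc / window]
    for i in range(window, len(values)):
        acc += values[i] - values[i - window]  # slide the window by one step
        out.append(acc / window)
    return out


def relative_error(energy, reference):
    """|E - E_ref| / |E_ref|, a common accuracy metric for variational energies."""
    return abs(energy - reference) / abs(reference)


# Hypothetical energies logged during training and a hypothetical reference energy.
energies = [-0.40, -0.43, -0.41, -0.44, -0.438, -0.442]
reference = -0.45
smoothed = running_mean(energies, window=3)
print([round(relative_error(e, reference), 4) for e in smoothed])
```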

run_tests.sh runs the unit tests for the ansätze.
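The ansätze in the paper are autoregressive, so one property worth testing in any such model is exact normalization: the product of the conditional probabilities, summed over all 2^N basis configurations, must equal 1. The toy model below (a logistic conditional per site with hypothetical parameters `biases` and `couplings`) is not one of the repository's ansätze, and we do not claim run_tests.sh performs this exact check; it only illustrates the idea on a system small enough to enumerate.

```python
import itertools

import numpy as np


def log_prob(config, biases, couplings):
    """log p(config) = sum_i log p(s_i | s_<i) for spins s_i in {0, 1}."""
    total = 0.0
    for i, s in enumerate(config):
        previous = 2.0 * np.asarray(config[:i], dtype=float) - 1.0  # map {0,1} -> {-1,+1}
        field = biases[i] + np.dot(couplings[i, :i], previous)
        p_up = 1.0 / (1.0 + np.exp(-field))  # conditional probability of s_i = 1
        total += np.log(p_up if s == 1 else 1.0 - p_up)
    return total


def total_probability(n_sites, biases, couplings):
    """Sum p(config) over all 2**n_sites basis states (exponential cost: tiny n only)."""
    return sum(
        np.exp(log_prob(c, biases, couplings))
        for c in itertools.product((0, 1), repeat=n_sites)
    )


rng = np.random.default_rng(0)
n_sites = 4
biases = rng.normal(size=n_sites)
couplings = rng.normal(size=(n_sites, n_sites))
norm = total_probability(n_sites, biases, couplings)
assert np.isclose(norm, 1.0), norm
```

Because each conditional is normalized by construction, this check should pass for any parameter values; a failure would point to a bug in how the conditionals are computed or combined.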