Official PyTorch implementation of the Vectorized Conditional Neural Field.

jhagnberger/vcnef

Vectorized Conditional Neural Field (VCNeF)

Jan Hagnberger, Marimuthu Kalimuthu, Daniel Musekamp, Mathias Niepert

ICML Conference ICLR Conference

This repository contains the official PyTorch implementation of the VCNeF model from the ICML'24 paper,
"Vectorized Conditional Neural Fields: A Framework for Solving Time-dependent Parametric Partial Differential Equations".

Requirements

The VCNeF model requires and is tested with the following packages.

Please also see the requirements.txt file, which lists all packages needed to run the provided examples.

Usage

The following example shows how to use the VCNeF model.

import torch
from vcnef.vcnef_1d import VCNeFModel as VCNeF1DModel
from vcnef.vcnef_2d import VCNeFModel as VCNeF2DModel
from vcnef.vcnef_3d import VCNeFModel as VCNeF3DModel

model = VCNeF2DModel(num_channels=4,
                     condition_on_pde_param=True,
                     pde_param_dim=2,
                     d_model=256,
                     n_heads=8,
                     n_transformer_blocks=1,
                     n_modulation_blocks=6)

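# Random inputs; x holds the initial condition with shape (b, s_x, s_y, c)
x = torch.rand(4, 64, 64, 4)
# Spatial coordinates with shape (b, s_x, s_y, 2)
grid = torch.rand(4, 64, 64, 2)
# PDE parameters with shape (b, pde_param_dim)
pde_param = torch.rand(4, 2)
# Query times with shape (b, n_t): 20 values 0.05, 0.10, ..., 1.00 per sample
t = torch.arange(1, 21).repeat(4, 1) / 20

# Forward pass for all query times at once
y_hat = model(x, grid, pde_param, t)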
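
The example above uses random tensors for grid and t. With real data, these would hold actual spatial coordinates and query times. The following is a minimal sketch of how such inputs could be built with plain PyTorch. The helper names make_grid_2d and make_query_times are hypothetical and not part of the vcnef package, and normalizing the coordinates to [0, 1] is an assumption rather than a documented requirement of the model.

```python
import torch


def make_grid_2d(batch_size: int, s_x: int, s_y: int) -> torch.Tensor:
    """Return a regular grid of shape (batch_size, s_x, s_y, 2) with coordinates in [0, 1]."""
    xs = torch.linspace(0.0, 1.0, s_x)
    ys = torch.linspace(0.0, 1.0, s_y)
    # "ij" indexing keeps the first axis aligned with x and the second with y
    gx, gy = torch.meshgrid(xs, ys, indexing="ij")  # each of shape (s_x, s_y)
    grid = torch.stack((gx, gy), dim=-1)  # (s_x, s_y, 2)
    # Copy the same grid for every sample in the batch
    return grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)


def make_query_times(batch_size: int, n_steps: int) -> torch.Tensor:
    """Return query times of shape (batch_size, n_steps): 1/n_steps, 2/n_steps, ..., 1."""
    t = torch.arange(1, n_steps + 1) / n_steps
    return t.repeat(batch_size, 1)


grid = make_grid_2d(4, 64, 64)
t = make_query_times(4, 20)
print(grid.shape, t.shape)
```

The resulting tensors have the same shapes as grid and t in the usage example, so they could be passed to the model in their place, provided the normalization assumption matches how the model was trained.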

Files

Below is a listing of the directory structure of VCNeF.

examples.py: Contains lightweight examples of how to use VCNeF.
examples_pde_bench.py: Contains examples of how to use VCNeF with PDEBench data and the PDEBench training loop.
📂 vcnef: Contains the code for the VCNeF model.
📂 utils: Contains utils for the PDEBench example.

Dataset for PDEBench Example

To run the PDEBench example examples_pde_bench.py, you first have to download the PDEBench datasets. An overview of the available data and how to download it can be found in the PDEBench repository. To use the downloaded datasets in the example, adapt the path in base_path and the file name(s) in file_names.
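
Before starting a training run, it can help to check that the configured files are actually present. The following is a minimal sketch that uses only the Python standard library. The names base_path and file_names mirror the variables mentioned above. The helper missing_files and the path and file name values are placeholders for illustration, not the repository's defaults.

```python
from pathlib import Path


def missing_files(base_path: str, file_names: list[str]) -> list[str]:
    """Return the entries of file_names that do not exist under base_path."""
    root = Path(base_path)
    return [name for name in file_names if not (root / name).is_file()]


# Placeholder values; replace them with your download location and file names.
base_path = "/path/to/pdebench/data"
file_names = ["example_dataset.hdf5"]

missing = missing_files(base_path, file_names)
if missing:
    print(f"Missing under {base_path}: {missing}")
else:
    print("All dataset files found.")
```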

VCNeF Architecture

The following illustration shows the architecture of the VCNeF model for solving 2D time-dependent PDEs (e.g., Navier-Stokes equations).

[Figure: VCNeF architecture]

Acknowledgements

The code of VCNeF is based on the code of Linear Transformers and PDEBench. We would like to thank the authors of Linear Transformers and PDEBench for their work, which made our method possible.

License

MIT licensed, except where otherwise stated. Please see the LICENSE file.

Citation

If you find our project useful, please consider citing it.

@InProceedings{vcnef-hagnberger:2024,
  title = 	 {{V}ectorized {C}onditional {Ne}ural {F}ields: A Framework for Solving Time-dependent Parametric Partial Differential Equations},
  author =       {Hagnberger, Jan and Kalimuthu, Marimuthu and Musekamp, Daniel and Niepert, Mathias},
  booktitle = 	 {Proceedings of the 41st International Conference on Machine Learning},
  pages = 	 {17189--17223},
  year = 	 {2024},
  editor = 	 {Salakhutdinov, Ruslan and Kolter, Zico and Heller, Katherine and Weller, Adrian and Oliver, Nuria and Scarlett, Jonathan and Berkenkamp, Felix},
  volume = 	 {235},
  series = 	 {Proceedings of Machine Learning Research},
  month = 	 {21--27 Jul},
  publisher =    {PMLR},
  pdf = 	 {https://raw.githubusercontent.com/mlresearch/v235/main/assets/hagnberger24a/hagnberger24a.pdf},
  url = 	 {https://proceedings.mlr.press/v235/hagnberger24a.html},
  abstract = 	 {Transformer models are increasingly used for solving Partial Differential Equations (PDEs). Several adaptations have been proposed, all of which suffer from the typical problems of Transformers, such as quadratic memory and time complexity. Furthermore, all prevalent architectures for PDE solving lack at least one of several desirable properties of an ideal surrogate model, such as (i) generalization to PDE parameters not seen during training, (ii) spatial and temporal zero-shot super-resolution, (iii) continuous temporal extrapolation, (iv) support for 1D, 2D, and 3D PDEs, and (v) efficient inference for longer temporal rollouts. To address these limitations, we propose \emph{Vectorized Conditional Neural Fields} (VCNeFs), which represent the solution of time-dependent PDEs as neural fields. Contrary to prior methods, however, VCNeFs compute, for a set of multiple spatio-temporal query points, their solutions in parallel and model their dependencies through attention mechanisms. Moreover, VCNeF can condition the neural field on both the initial conditions and the parameters of the PDEs. An extensive set of experiments demonstrates that VCNeFs are competitive with and often outperform existing ML-based surrogate models.}
}