PlaneMVS: 3D Plane Reconstruction from Multi-view Stereo (CVPR 2022)

Project Page Paper Arxiv Video Poster Slide

This is an official PyTorch implementation of the paper:

PlaneMVS: 3D Plane Reconstruction From Multi-View Stereo

Jiachen Liu*‡1,2, Pan Ji†1, Nitin Bansal1, Changjiang Cai†‡1, Qingan Yan1, Xiaolei Huang2, Yi Xu1

1 OPPO US Research Center, 2 Penn State University

* Work done as an intern at OPPO US Research Center
† Corresponding author
‡ Code development and maintenance

🆕 Updates

  • 06/10/2024: Official code initially released per institutional approval.
  • 06/01/2022: PlaneMVS paper released, see arXiv paper.

📋 Table of Contents

  1. ✨ Highlights
  2. ⚙️ Setup
  3. 💾 Data Preparation for Training
  4. 🍀 Our Code Structure
  5. 🏋️ Model Weights
  6. ⏳ Training
  7. 📊 Testing and Evaluation
  8. ⚖️ License
  9. 🙏 Acknowledgements
  10. 📑 Citations

✨ Highlights

  • Our framework PlaneMVS supports:

    • Semantic Dense Mapping: simultaneously performs semantic plane detection and dense 3D reconstruction.
    • Multi-view inputs: takes multi-view RGB images or monocular RGB videos as input to recover 3D scale, avoiding the scale ambiguity that single-view counterparts suffer from.
  • Our implementation provides an upgraded maskrcnn_benchmark that supports the latest PyTorch versions and CUDA compilation. The code has been tested in the following environment:

    • PyTorch 2.2.0+cu121, Python 3.10, Ubuntu 22.04.4, via a virtual environment and a Docker container (see details in Setup).
  • 🗣 Why upgrade maskrcnn-benchmark?

    • PlaneMVS is implemented on top of maskrcnn-benchmark. The original benchmark has been deprecated: it can be compiled only with old PyTorch versions (e.g., 1.6.0), which in turn do not support newer GPUs. For example, compiling it with PyTorch 1.6.0 on an RTX 3090 GPU produces the following error:
    Error: NVIDIA GeForce RTX 3090 with CUDA capability sm_86 is not compatible with the current PyTorch installation. The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70 sm_75.
    
  • Our implementation supports multi-GPU DDP (Distributed Data Parallel) and DP (Data Parallel) training, as well as single-GPU inference (a minimal DDP sketch follows this list).
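
For readers new to DDP, the snippet below is a minimal, generic sketch of how a model is wrapped for distributed training; it is illustrative only and not the repo's actual training loop (the provided scripts take care of this for you).

# Minimal, generic DDP wiring sketch (illustrative; not PlaneMVS's training code).
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # Assumes the process was launched with torchrun, which sets the env vars
    # that init_process_group() reads (RANK, WORLD_SIZE, MASTER_ADDR, ...).
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    return DDP(model.cuda(local_rank), device_ids=[local_rank])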

⚙️ Setup

Our code has been successfully tested in the following environment:

  • PyTorch 2.2.0+cu121, Python 3.10, Ubuntu 22.04.4.

To set up this environment and replicate our experimental results, please follow the steps below.

Step 1: Download our code

# Clone the repo.
git clone https://github.com/oppo-us-research/PlaneMVS.git

# We assume PROJ_ROOT as the project directory.
# E.g., PROJ_ROOT=~/PlaneMVS
cd $PROJ_ROOT

Step 2: Configure a Docker environment

  • Follow the steps to install the Docker environment:
# assume you are in the PROJ_ROOT via cd $PROJ_ROOT
cd $PROJ_ROOT

## Build Docker image
cd docker && sh build.sh
# After the docker image is successfully built, 
# now you can run a Docker container
sh run.sh

## (Optional) [Useful Tips ✍]: To exit the Docker container without stopping it,
# press Ctrl+P followed by Ctrl+Q. This detaches the container, leaves the
# processes inside it running, and returns you to your system's shell.

## (Optional) [Useful Tips ✍]: Re-enter the Container
# get the container id
docker ps
#  Run as root, e.g., if you want to install some libraries via `pip`;
#  Here `d89c34efb04a` is the container id;
docker exec -it -u 0 d89c34efb04a bash
#  Run as regular user
docker exec -it d89c34efb04a bash

## (Optional) [Useful Tips ✍]: To save the container to a new docker image;
# After the pip installation, save the container to an image. 
# You can run the following from the outside of the docker container.
docker commit -m "some notes you specified" d89c34efb04a ccj/planemvs:1.1

Step 3: Compile the maskrcnn_benchmark

cd $PROJ_ROOT
# to compile the cuda code, and generate '*.so' lib
# i.e., `third_party/maskrcnn_main/build/lib/maskrcnn_benchmark/_C.cpython-310-x86_64-linux-gnu.so`
./compile.sh
  • Please note that we do not install the library at the system level (e.g., under /usr/local/*) via something like pip install maskrcnn_benchmark. Instead, we simply build the libs in the local build directory of this project.

  • Running ./compile.sh compiles the CUDA code (including maskrcnn_benchmark/csrc/cuda/*.cu) and generates the *.so library at third_party/maskrcnn_main/build/lib/maskrcnn_benchmark/_C.cpython-310-x86_64-linux-gnu.so.

  • We explicitly load this *.so, for example, as shown in roi_align.py:

    from build.lib.maskrcnn_benchmark import _C
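
If you want to try this import in your own script run from $PROJ_ROOT, the sketch below shows one way to make the locally built extension resolvable. It is an assumption-based example (the directory third_party/maskrcnn_main is taken from the paths above) rather than the repo's own loading code.

# A minimal sketch, assuming Python is run from $PROJ_ROOT: the extension is not
# pip-installed, so the directory that contains `build/` must be on sys.path
# before `from build.lib.maskrcnn_benchmark import _C` can resolve.
import sys
from pathlib import Path

third_party_dir = Path("third_party/maskrcnn_main").resolve()
sys.path.insert(0, str(third_party_dir))

from build.lib.maskrcnn_benchmark import _C  # the *.so produced by ./compile.sh
print(_C)  # sanity check: the compiled CUDA ops module is importable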

After completing these three steps, the configuration is done and you are ready to run the code.

💾 Data Preparation for Training

We assume you have already downloaded the ScanNetv2 dataset. Otherwise, you can refer to the instructions from SimpleRecon on how to download ScanNetv2.

Here we focus on how to collect the plane annotations for network training.

  • Step 1: Download the plane annotations from PlaneRCNN.
    • First download the plane annotations provided by PlaneRCNN and place the annotation directory into each corresponding scene directory of ScanNet. Please check the PlaneRCNN repo for details on training data preparation.
    • Then, unzip the raw semantic label archive (scene*_2d-label.zip) for each scene.
    • The ScanNet dataset directory is expected to have the following structure (a layout-check sketch is given after this step list):
└── scannet_data #scannet main directory
    └── scans
    │   ├── scene0000_00
    │   │   ├── annotation
    │   │   │   ├── plane_info.npy
    │   │   │   ├── planes.npy # plane parameters which are represented in the global frame;
    │   │   │   └── segmentation #  Plane segmentation for each image view;
    │   │   │       ├── 0.png
    │   │   │       ├── 1.png
    │   │   │       ├── ...
    │   │   ├── frames
    │   │   │   ├── color
    │   │   │   │   ├── 0.jpg
    │   │   │   │   ├── 1.jpg
    │   │   │   │   ├── ...
    │   │   │   ├── depth
    │   │   │   │   ├── 0.png
    │   │   │   │   ├── 1.png
    │   │   │   │   ├── ...
    │   │   │   ├── intrinsic
    │   │   │   │   ├── extrinsic_color.txt
    │   │   │   │   ├── extrinsic_depth.txt
    │   │   │   │   ├── intrinsic_color.txt
    │   │   │   │   ├── intrinsic_depth.txt
    │   │   │   └── pose
    │   │   │   │   ├── 0.txt
    │   │   │   │   ├── 1.txt
    │   │   │   │   ├── ...
    │   │   ├── label-filt
    │   │   │   ├── 0.png
    │   │   │   ├── 1.png
    │   │   │   ├── ...
    │   ...
    │   ├── scene0000_01
    │   │   ├── ...
    │   ...
    │
    └── scannetv2-labels.combined.tsv
  • Step 2: Mount the prepared ScanNet dataset to the local directory ./datasets/scannet_data via

    ln -s SCANNET_DATA_DIR $PROJ_ROOT/datasets/scannet_data
  • Step 3: Stereo Data Cleaning and Generation: for stereo data generation, please run the Python script data_preparation/stereo_pre_process.py.

          # this is the main dataloader we used for training on ScanNet;
          "scannet_stereo_train": { 
              # these <key: value> elements should follow the arguments of __init__()
              # of the corresponding Dataset class, e.g., ScannetStereoDataset defined in
              # src/datasets/scannet_stereo.py, whose __init__(...) takes the same args:
              
              "data_root": pjoin(DATA_DIR, "scannet_data/scans"),
              "data_split_file": "splits/sampled_stereo_train_files.txt",
              "cleaned_seg_dir": pjoin(DATA_DIR, "scannet_data/deltas_split_stereo_cleaned_segmentation"),
              "anchor_normal_dir": pjoin(DATA_DIR, "scannet_data/anchor_normals.npy"),
              "semantic_mapping_dir": "scannet_label_mapping",
              "pixel_plane_dir": pjoin(DATA_DIR, "sampled_pixel_planar_map"),
              "use_new_mapping": False,
              "split": "train",
              "mode": "mask"
          },
    • The train/val splits will also be generated and saved at data_preparation/scannnet_splits/cleaned_scannet_stereo_train/val_files.txt. Make sure to move them into the splits directory so that they match the data_split_file paths. We maintain a list of *.txt files for the different data splits; please see splits for more details.
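
Before training, you can optionally sanity-check that each scene folder follows the layout shown in Step 1. The helper below is a small hedged sketch based only on the directory tree above; the function name and usage are illustrative, not part of the repo.

# Hedged helper: verify that a scene folder matches the directory tree in Step 1.
from pathlib import Path

def check_scene(scene_dir: str) -> bool:
    scene = Path(scene_dir)
    required = [
        scene / "annotation" / "planes.npy",
        scene / "annotation" / "segmentation",
        scene / "frames" / "color",
        scene / "frames" / "depth",
        scene / "frames" / "intrinsic" / "intrinsic_depth.txt",
        scene / "frames" / "pose",
        scene / "label-filt",
    ]
    missing = [str(p) for p in required if not p.exists()]
    if missing:
        print(f"[{scene.name}] missing: {missing}")
    return not missing

# e.g., check_scene("datasets/scannet_data/scans/scene0000_00")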

NOTE 📢: We recommend making a soft link to mount your dataset at the local directory ./datasets. Otherwise, you will have to configure your own dataset paths by modifying src/config/paths_catalog.py.
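
If you do configure your own paths, a new entry would look roughly like the scannet_stereo_train dictionary shown in Step 3. The snippet below is a hedged sketch of such an entry; the exact catalog structure lives in src/config/paths_catalog.py and may differ, and my_scannet_stereo_val / my_val_files.txt are hypothetical names.

# Hedged sketch of registering an extra split, mirroring the keys shown above.
from os.path import join as pjoin

DATA_DIR = "datasets"  # the local dir the soft link above points into

DATASETS = {
    "my_scannet_stereo_val": {                               # hypothetical split name
        "data_root": pjoin(DATA_DIR, "scannet_data/scans"),
        "data_split_file": "splits/my_val_files.txt",        # hypothetical split file
        "anchor_normal_dir": pjoin(DATA_DIR, "scannet_data/anchor_normals.npy"),
        "semantic_mapping_dir": "scannet_label_mapping",
        "use_new_mapping": False,
        "split": "val",
        "mode": "mask",
    },
}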

🍀 Our Code Structure

Model Definition 🕸

Config Parameters 📝

The default model configuration parameters are defined in src/config/defaults.py, and the data path parameters are listed in src/config/paths_catalog.py.

The model checkpoint loader is defined at src/utils/checkpoint.py.

ScanNet Dataloader 🔃

Adding Your Own Dataset and Dataloader 💽

Adding support for training on a new dataset can be done by following src/datasets/scannet.py (the parent class) and src/data/datasets/scannet_stereo.py (the child class).

Once you have created your dataset class, it needs to be registered in a couple of files, e.g., the data path catalog src/config/paths_catalog.py mentioned above.
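
As a rough illustration (assumed API; please check the actual class signatures in the repo), a new dataset could subclass the ScanNet stereo dataset and override only what differs:

# Hedged sketch of a child dataset; the parent class path follows the files
# referenced above, and the constructor args mirror the config dict shown in
# "Data Preparation for Training".
from src.datasets.scannet_stereo import ScannetStereoDataset

class MyStereoDataset(ScannetStereoDataset):
    def __init__(self, data_root, data_split_file, split="train", mode="mask", **kwargs):
        super().__init__(data_root=data_root,
                         data_split_file=data_split_file,
                         split=split, mode=mode, **kwargs)

    def __getitem__(self, idx):
        # Reuse the parent's stereo sample loading; customize here if your
        # directory layout or annotation format differs from ScanNet's.
        return super().__getitem__(idx)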

🏋️ Model Weights

ResNet50-FPN Backbone

We first need to download the ResNet50-FPN Mask R-CNN pre-trained model (model id: 6358792), which can be found in MODEL_ZOO.md, to initialize the PlaneRCNN detection head. Please download that model and put it at the path you set for MODEL.WEIGHT in the yaml file configs/plane_cfgs/planestereo.yaml. For example, we set WEIGHT: "checkpoints/saved/mask_rcnn/e2e_mask_rcnn_R_50_FPN_1x.pth".
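
As an optional, hedged sanity check (not part of the repo's pipeline), you can confirm the downloaded weight file is readable before training; the path below is the example WEIGHT value from above.

# Optional check: make sure the downloaded Mask R-CNN weight file loads.
import torch

ckpt = torch.load("checkpoints/saved/mask_rcnn/e2e_mask_rcnn_R_50_FPN_1x.pth",
                  map_location="cpu")
# Depending on how the checkpoint was exported, it may be a raw state_dict or a
# dict wrapping one (e.g., under a "model" key); both cases are handled here.
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
print(f"loaded a checkpoint with {len(state_dict)} entries")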

Our Model Checkpoints

We provide several model checkpoints trained on ScanNetv2, varying the training data size, number of epochs, batch size, and so on.

While the model configuration is specified by the yaml file planestereo.yaml, other parameters can be changed dynamically in the training scripts, as follows:

  • Data size: 1) setting MODEL.STEREO.USE_ALL_DATA "True" loads the larger training set (#samples=67k), specified by the split file valid_stereo_train_files.txt; 2) setting MODEL.STEREO.USE_ALL_DATA "False" loads a smaller training set (#samples=20k), specified by the split file sampled_stereo_train_files.txt.

  • Training hyperparameters, e.g., NUM_EPOCHS, BATCH_SIZE, and BASE_LEARNING_RATE, can be changed in the scripts (see the config-override sketch below and the table that follows).
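
For reference, these options go through the maskrcnn-benchmark-style yacs config defined in src/config/defaults.py, and the bash scripts pass them as KEY VALUE pairs. The snippet below is a hedged programmatic equivalent; the import path src.config is an assumption, so verify it against defaults.py.

# Hedged sketch of overriding config options with a yacs-style CfgNode.
from src.config import cfg  # assumed import path; see src/config/defaults.py

cfg.merge_from_file("configs/plane_cfgs/planestereo.yaml")
cfg.merge_from_list([
    "MODEL.STEREO.USE_ALL_DATA", True,  # 67k-sample split instead of the 20k one
    "MODEL.WEIGHT", "checkpoints/saved/mask_rcnn/e2e_mask_rcnn_R_50_FPN_1x.pth",
])
cfg.freeze()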

| Experiments Config. | Training Time | Depth AbsDiff↓ | Depth SqRel↓ | Depth RMSE↓ | Depth $\delta$<1.25↑ | Detection AP$^{0.2m}$ | Detection AP$^{0.4m}$ | AP↑ | mAP↑ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| model ckpt @Exp1 (single-GPU training; sampled data, #samples=20k; BS=6, LR=0.003, Epoch=10) | 24 hrs | 0.088 | 0.027 | 0.186 | 0.925 | 0.441 | 0.516 | 0.542 | 0.491 |
| model ckpt @Exp2 (single-GPU training; all data, #samples=67k; BS=6, LR=0.003, Epoch=10) | 70 hrs | 0.083 | 0.023 | 0.175 | 0.936 | 0.459 | 0.533 | 0.553 | 0.519 |
| Our best 🎯 model ckpt @Exp3 (2-GPU training; all data, #samples=67k; BS=16, LR=0.003, Epoch=10) | 53 hrs | 0.081 | 0.022 | 0.170 | 0.939 | 0.477 | 0.551 | 0.571 | 0.541 |
| model ckpt @Exp4 (2-GPU training; sampled data, #samples=20k; BS=16, LR=0.0003, Epoch=20) | 20 hrs | 0.089 | 0.027 | 0.186 | 0.925 | 0.443 | 0.522 | 0.545 | 0.481 |

⏳ Training

Training on ScanNet

Run the Python script train_net.py for network training. We provide several bash scripts for training on ScanNet, e.g., run_train_exp03.sh.

# DDP training on 2-GPU: GPU 0 and GPU 1; 
scripts/run_train_exp03.sh 0,1 

Resume

Sometimes you may have to stop and resume training because the loss becomes NaN after a number of training iterations (see issue #33).

To resume training, make sure the file ./checkpoints/some_exp_folder/last_checkpoint exists. This last_checkpoint file automatically points to the most recent checkpoint, e.g., ./checkpoints_nfs/exp01-planemvs-epo10-bs6-dgx10/model_final.pth.

Then set IS_RESUME='true' in the bash script and run it to load the previous model weights, optimizer, and scheduler states and resume network training.
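
For reference, last_checkpoint is simply a small text file whose content is the path of the latest saved .pth file (this follows the maskrcnn-benchmark checkpointer convention; treat the details as an assumption). You can peek at it like so:

# Hedged sketch: inspect what the checkpointer will resume from.
from pathlib import Path

last = Path("./checkpoints/some_exp_folder/last_checkpoint")
if last.exists():
    print("will resume from:", last.read_text().strip())
else:
    print("no last_checkpoint found; training starts from the initial weights")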

📊 Testing and Evaluation

Now you can run test_net.py for testing and evaluation. We provide several bash scripts to run the testing on ScanNet, e.g., run_test_exp03.sh.

# single GPU inference on GPU 0; 
scripts/run_test_exp03.sh 0

# (Optional) Or if you want to save terminal output to a .txt file;
scripts/run_test_exp03.sh 0 2>&1 | tee results/tmp_results.txt

It will report depth and detection metrics like the following:

ccj@5b8fddfb3476:~/code/planemvs_proj$ ./runfiles/run_test_exp03.sh 0
[***] inference device = cuda
[****] Loading checkpoint from ./checkpoints_nfs/exp03-alldata-planemvs-epo10-bs16-dgx10/model_final.pth
building dataset using Dataset class  ScannetStereoDataset for scannet_stereo_val ...
100%|████████████████████████████| 950/950 [20:30<00:00,  1.30s/it]
Mean AP: [0.477 0.551 0.566 0.571 0.571]
Mean mAP: 0.5408594040438236
Mean PlaneSeg: [1.381 0.845 0.714]
Mean Planar Depth Metrics for Whole Img: [0.092 0.035 0.081 0.022 0.17  0.107 0.939 0.991 0.998]
Mean Pixel-Planar Depth Metrics for Whole Img: [0.095 0.036 0.083 0.023 0.174 0.109 0.936 0.991 0.998]
Mean Planar Depth Metric using gt masks for Whole Img: [0.088 0.034 0.081 0.021 0.165 0.104 0.943 0.992 0.998]
Mean Pred-Planar-Area Depth Metrics: [0.072 0.032 0.075 0.017 0.143 0.094 0.95  0.994 0.999]
Mean Gt-Planar-Area Depth Metrics: [0.078 0.033 0.076 0.018 0.149 0.097 0.948 0.993 0.999]
Mean Refined Gt-Planar-Area Depth Metrics: [0.077 0.033 0.076 0.018 0.149 0.097 0.948 0.993 0.999]
Mean Gt-Planar-Area Pixel-Planar Depth Metrics: [0.084 0.034 0.079 0.02  0.156 0.101 0.943 0.993 0.999]
========== img semantics(consider every pixels) ==========
{'Pixel Accuracy': 0.7544948156524123, 'Mean Accuracy': 0.6183070794385185, 'Frequency Weighted IoU': 0.6202744571246841, 'Mean IoU': 0.5111661750637606, 'Class IoU': {0: 0.5117866472963567, 1: 0.7441619843595424, 2: 0.46830892233623095, 3: 0.7841726233489069, 4: 0.5992126295858932, 5: 0.4093309051899402, 6: 0.5224285534707567, 7: 0.29198417923629466, 8: 0.44792802188280134, 9: 0.4074677811908966, 10: 0.3752509771344273, 11: 0.5719608757330807}}
========== Mean Plane Normal Diff: ==========
0.251
========== Mean Plane Offset Diff: ==========
0.193
========== Mean Plane n/d Diff: ==========
0.333
========== Mean Plane Geometric Diff: ==========
0.603
========== Median Plane Geometric Diff: ==========
0.285
========== Weighted Mean Plane Geometric Diff: ==========
0.258

These numbers correspond to the depth and detection validation metrics of our best 🎯 model ckpt @Exp3 shown in the table above.

⚖️ License

PlaneMVS is licensed under the MIT license. For the third-party maskrcnn-benchmark, please refer to its own MIT license.

🙏 Acknowledgements

Our work adopts code from maskrcnn-benchmark. We sincerely thank the authors for open-sourcing their project.

📑 Citations

If you find our work useful, please consider citing our paper:

@InProceedings{liu2022planemvs,
    author    = {Liu, Jiachen and Ji, Pan and Bansal, Nitin and Cai, Changjiang and Yan, Qingan and Huang, Xiaolei and Xu, Yi},
    title     = {PlaneMVS: 3D Plane Reconstruction From Multi-View Stereo},
    booktitle = {CVPR},
    month     = {June},
    year      = {2022},
    pages     = {8665-8675}
}

Please also consider citing our other MVS paper if you find it useful:

@InProceedings{cai2023riavmvs,
    author    = {Cai, Changjiang and Ji, Pan and Yan, Qingan and Xu, Yi},
    title     = {RIAV-MVS: Recurrent-Indexing an Asymmetric Volume for Multi-View Stereo},
    booktitle = {CVPR},
    month     = {June},
    year      = {2023},
    pages     = {919-928}
}

and a neural active reconstruction paper if you find it helpful:

@article{feng2024naruto,
  title={NARUTO: Neural Active Reconstruction from Uncertain Target Observations},
  author={Feng, Ziyue and Zhan, Huangying and Chen, Zheng and Yan, Qingan and Xu, Xiangyu and Cai, Changjiang and Li, Bing and Zhu, Qilun and Xu, Yi},
  journal={arXiv preprint arXiv:2402.18771},
  year={2024}
}
