Gaussian Splatting PyTorch Lightning Implementation

A 3D Gaussian Splatting framework with various derived algorithms and an interactive web viewer

Known issues

  • Multi-GPU training can only be enabled after densification

Features

1. Installation

1.1. Clone repository

# clone repository
git clone --recursive https://github.com/yzslab/gaussian-splatting-lightning.git
cd gaussian-splatting-lightning
  • If you forgot the --recursive option, you can run the git commands below after cloning:

     git submodule sync --recursive
     git submodule update --init --recursive --force

1.2. Create virtual environment

# create virtual environment
conda create -yn gspl python=3.9 pip
conda activate gspl

1.3. Install PyTorch

  • Tested on PyTorch==2.0.1

  • You must install the build that matches the version of your nvcc (check with nvcc --version)

  • For CUDA 11.8

    pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118

1.4. Install requirements

pip install -r requirements.txt

1.5. Install optional packages

  • If you want to use nerfstudio-project/gsplat

    pip install git+https://github.com/yzslab/gsplat.git

    This command will install my modified version, which is required by LightGaussian and Mip-Splatting. If you do not need them, you can also install vanilla gsplat v0.1.12.

  • If you need SegAnyGaussian

    • gsplat (see command above)
    • pip install hdbscan scikit-learn==1.3.2 git+https://github.com/facebookresearch/segment-anything.git
    • facebookresearch/pytorch3d
    • Download the ViT-H SAM model and place it in the root directory of this repo: wget -O sam_vit_h_4b8939.pth https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

2. Training

2.1. Basic command

python main.py fit \
    --data.path DATASET_PATH \
    -n EXPERIMENT_NAME

Some dataset types can be detected automatically. You can also specify the type with the option --data.type. Possible values are: colmap, blender, nsvf, nerfies, matrixcity, phototourism.

[NOTE] By default, only checkpoint files are produced at the end of training. If you need a ply file in the vanilla 3DGS format (which can be loaded by SIBR_viewer or some WebGL/GPU based viewers), use one of the options below (a small sketch for inspecting the exported file follows them):

  • [Option 1]: Convert checkpoint file to ply: python utils/ckpt2ply.py TRAINING_OUTPUT_PATH, e.g.:
    • python utils/ckpt2ply.py outputs/lego
    • python utils/ckpt2ply.py outputs/lego/checkpoints/epoch=300-step=30000.ckpt
  • [Option 2]: Start training with option: --model.save_ply true
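
If you want to sanity-check an exported ply, the sketch below is a hypothetical example (it is not part of this repo, the path is made up, and it assumes the plyfile package, pip install plyfile); it only relies on the fact that the vanilla 3DGS format stores each Gaussian as a ply vertex with per-vertex properties.

# Hypothetical sketch: inspect a ply exported in the vanilla 3DGS format.
# Requires the `plyfile` package (`pip install plyfile`); the path is an example.
from plyfile import PlyData

ply = PlyData.read("outputs/lego/point_cloud.ply")
gaussians = ply["vertex"]
print("number of Gaussians:", gaussians.count)
print("per-Gaussian properties:", [p.name for p in gaussians.properties])
# Expect positions (x, y, z), opacity, scales, rotations and SH coefficients here.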

2.2. Some useful options

  • Run training with web viewer
python main.py fit \
    --viewer \
    ...
  • It is recommended to use the config file configs/blender.yaml when training on a blender dataset.
python main.py fit \
    --config configs/blender.yaml \
    ...
  • Train with masks (colmap dataset only; a minimal example of writing such masks follows this list)
# mask requirements:
#   * must be a single channel
#   * zero (black) marks a masked pixel (it will not be used to supervise training)
#   * the mask filename must be the image filename + '.png',
#     e.g. the mask of '001.jpg' is '001.jpg.png'
--data.params.colmap.mask_dir MASK_DIR_PATH
  • Use downsampled images (colmap dataset only)

You can use utils/image_downsample.py to downsample your images, e.g. 4x downsample: python utils/image_downsample.py PATH_TO_DIRECTORY_THAT_STORE_IMAGES --factor 4

# it will load images from `images_4` directory
--data.params.colmap.down_sample_factor 4
  • Load large dataset without OOM
--data.params.train_max_num_images_to_cache 1024
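
Regarding the mask option above, the following is a minimal hypothetical illustration (not a script shipped with this repo; the paths and the masked region are made up) of writing single-channel mask files that follow the required naming rule, where zero marks pixels excluded from supervision.

# Hypothetical example: write single-channel masks named '<image filename>.png'.
# Requires Pillow and numpy; the paths and the masked region are placeholders.
import os
import numpy as np
from PIL import Image

image_dir = "data/scene/images"
mask_dir = "data/scene/masks"   # pass this directory via --data.params.colmap.mask_dir
os.makedirs(mask_dir, exist_ok=True)

for name in os.listdir(image_dir):
    width, height = Image.open(os.path.join(image_dir, name)).size
    mask = np.full((height, width), 255, dtype=np.uint8)  # 255 = supervised pixel
    mask[: height // 4, :] = 0  # example: exclude the top quarter of every image
    Image.fromarray(mask, mode="L").save(os.path.join(mask_dir, name + ".png"))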

2.3. Use gsplat as the rasterizer

Make sure that the command `which nvcc` produces output; otherwise gsplat will be disabled automatically.

python main.py fit \
    --config configs/gsplat.yaml \
    ...

2.4. Multi-GPU training

[NOTE] Multi-GPU training can only be enabled after densification. You can start training on a single GPU, save a checkpoint after densification finishes, and then resume from that checkpoint with multi-GPU training enabled.

You will get improved PSNR and SSIM with more GPUs.

# Single GPU at the beginning
python main.py fit \
    --config ... \
    --data.path DATASET_PATH \
    --model.gaussian.optimization.densify_until_iter 15000 \
    --max_steps 15000
# Then resume, and enable multi-GPU
python main.py fit \
    --config ... \
    --trainer configs/ddp.yaml \
    --data.path DATASET_PATH \
    --max_steps 30000 \
    --ckpt_path last  # find latest checkpoint automatically, or provide a path to checkpoint file

2.5. Deformable 3D Gaussians

deform-gs-new.mp4

python main.py fit \
    --config configs/deformable_blender.yaml \
    --data.path ...

2.6. Mip-Splatting

python main.py fit \
    --config configs/mip_splatting_gsplat.yaml \
    --data.path ...

2.7. LightGaussian

  • Prune & finetune only currently

  • Train & densify & prune

    ... fit \
        --config configs/light_gaussian/train_densify_prune-gsplat.yaml \
        --data.path ...
  • Prune & finetune (make sure to use the same hparams as the input model used)

    ... fit \
        --config configs/light_gaussian/prune_finetune-gsplat.yaml \
        --data.path ... \
        ... \
        --ckpt_path YOUR_CHECKPOINT_PATH

2.8. AbsGS / EfficientGS

... fit \
    --config configs/gsplat-absgrad.yaml \
    --data.path ...

2.9. 2D Gaussian Splatting (2DGS)

  • Install diff-surfel-rasterization first

    pip install git+https://github.com/hbb1/diff-surfel-rasterization.git@3a9357f6a4b80ba319560be7965ed6a88ec951c6
  • Then start training

    ... fit \
        --config configs/vanilla_2dgs.yaml \
        --data.path ...

2.10. SegAnyGaussian

  • First, train a 3DGS scene using gsplat

    python main.py fit \
        --config configs/gsplat.yaml \
        --data.path data/Truck \
        -n Truck -v gsplat  # trained model will save to `outputs/Truck/gsplat`
  • Then generate SAM masks and their scales

    • Masks

      python utils/get_sam_masks.py data/Truck/images

      You can specify the path to SAM checkpoint via argument -c PATH_TO_SAM_CKPT

    • Scales

      python utils/get_sam_mask_scales.py outputs/Truck/gsplat

    Both the masks and scales will be saved in data/Truck/semantic; the structure of data/Truck will look like this:

    ├── images  # The images of your dataset
        ├── 000001.jpg
        ├── 000002.jpg
        ...
    ├── semantic  # Generated by `get_sam_masks.py` and `get_sam_mask_scales.py`
        ├── masks
            ├── 000001.jpg.pt
            ├── 000002.jpg.pt
            ...
        └── scales
            ├── 000001.jpg.pt
            ├── 000002.jpg.pt
            ...
    ├── sparse  # colmap sparse database
        ...
  • Train SegAnyGS

    python seganygs.py fit \
        --config configs/segany_splatting.yaml \
        --data.path data/Truck \
        --model.initialize_from outputs/Truck/gsplat \
        -n Truck -v seganygs  # save to `outputs/Truck/seganygs`

    The value of --model.initialize_from is the path to the trained 3DGS model

  • Start the web viewer to perform segmentation or clustering

    python viewer.py outputs/Truck/seganygs

    SegAnyGS-WebViewer.mp4

2.11. Reconstruct a large-scale scene with a VastGaussian-like partitioning strategy

[Images: Baseline vs. Partitioning comparison]

There is no single script that runs the whole pipeline. Please refer to the contents below for how to reconstruct a large-scale scene.

2.12. Appearance Model

With the appearance model, reconstruction quality can be improved when your images vary in appearance, such as different exposure, white balance, contrast, or even day and night.

This model assigns an extra feature vector $\boldsymbol{\ell}^{(g)}$ to each 3D Gaussian and an appearance embedding vector $\boldsymbol{\ell}^{(a)}$ to each appearance group. Both are fed to a lightweight MLP to compute the color.

$$ \mathbf{C} = f \left ( \boldsymbol{\ell}^{(g)}, \boldsymbol{\ell}^{(a)} \right ) $$

Please refer to internal/renderers/gsplat_appearance_embedding_renderer.py for more details.
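
As a rough, hypothetical sketch of this idea (this is not the code in internal/renderers/gsplat_appearance_embedding_renderer.py; the dimensions, layer sizes and names below are made up), the per-Gaussian feature and the per-group appearance embedding can be concatenated and passed through a small MLP that outputs RGB:

# Hypothetical sketch of the appearance model idea (assumptions only; refer to
# internal/renderers/gsplat_appearance_embedding_renderer.py for the real code).
import torch
from torch import nn

class AppearanceColorMLP(nn.Module):
    def __init__(self, n_appearance_groups: int, gaussian_dim: int = 16, appearance_dim: int = 16):
        super().__init__()
        # one learnable appearance embedding per appearance group
        self.appearance_embedding = nn.Embedding(n_appearance_groups, appearance_dim)
        self.mlp = nn.Sequential(
            nn.Linear(gaussian_dim + appearance_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
            nn.Sigmoid(),  # RGB in [0, 1]
        )

    def forward(self, gaussian_features: torch.Tensor, appearance_group_id: torch.Tensor) -> torch.Tensor:
        # gaussian_features: [N, gaussian_dim], one feature vector per 3D Gaussian
        # appearance_group_id: id of the appearance group of the image being rendered
        appearance = self.appearance_embedding(appearance_group_id)
        appearance = appearance.expand(gaussian_features.shape[0], -1)
        return self.mlp(torch.cat([gaussian_features, appearance], dim=-1))

# e.g. colors = AppearanceColorMLP(n_appearance_groups=100)(features, torch.tensor(3))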

[Videos: Train-head-baseline.mp4 (Baseline) vs. Train-head.mp4 (New Model); Day-and-Night.mp4]
  • First generate appearance groups (Colmap or PhotoTourism dataset only)
python utils/generate_image_apperance_groups.py PATH_TO_DATASET_DIR \
    --image \
    --name appearance_image_dedicated  # the name will be used later

The images in a group share a common appearance embedding. The command above assigns each image its own group, which means that no appearance embedding will be shared between images.

  • Then start training
python main.py fit \
    --config configs/appearance_embedding_renderer/view_dependent.yaml \
    --data.path PATH_TO_DATASET_DIR \
    --data.params.colmap.appearance_groups appearance_image_dedicated  # value here should be the same as the one provided to `--name` above

If you are using PhotoTourism dataset, please replace --data.params.colmap. with --data.params.phototourism., and specify the dataset type with --data.type phototourism.

2.13. 3DGS-MCMC

  • Install submodules/mcmc_relocation first
pip install submodules/mcmc_relocation
  • Then start training
... fit \
    --config configs/gsplat-mcmc.yaml \
    --model.density.cap_max MAX_NUM_GAUSSIANS \
    ...

MAX_NUM_GAUSSIANS is the maximum number of Gaussians that will be used.

Refer to ubc-vision/3dgs-mcmc, internal/density_controllers/mcmc_density_controller.py and internal/metrics/mcmc_metrics.py for more details.

3. Evaluation

Evaluate on validation set

python main.py validate \
    --config outputs/lego/config.yaml

On test set

python main.py test \
    --config outputs/lego/config.yaml

On train set

python main.py validate \
    --config outputs/lego/config.yaml \
    --val_train

Save the images rendered during evaluation/test

python main.py <validate or test> \
    --config outputs/lego/config.yaml \
    --save_val

Then you can find the images in outputs/lego/<val or test>.

4. Web Viewer

[Videos: transform.mp4 (Transform), animation.mp4 (Camera Path), edit.mp4 (Edit)]

4.1 Basic usage

python viewer.py TRAINING_OUTPUT_PATH
# e.g.: 
#   python viewer.py outputs/lego/
#   python viewer.py outputs/lego/checkpoints/epoch=300-step=30000.ckpt
#   python viewer.py outputs/lego/baseline/point_cloud/iteration_30000/point_cloud.ply  # only works with VanillaRenderer

4.2 Load multiple models and enable transform options

python viewer.py \
    outputs/garden \
    outputs/lego \
    outputs/Synthetic_NSVF/Palace/point_cloud/iteration_30000/point_cloud.ply \
    --enable_transform

4.3 Load model trained by other implementations

[NOTE] The commands in this section are only designed for third-party outputs.

python viewer.py \
    Deformable-3D-Gaussians/outputs/lego \
    --vanilla_deformable \
    --reorient disable  # change to enable when loading real world scene
python viewer.py \
    4DGaussians/outputs/lego \
    --vanilla_gs4d
# Install `diff-surfel-rasterization` first
pip install git+https://github.com/hbb1/diff-surfel-rasterization.git@3a9357f6a4b80ba319560be7965ed6a88ec951c6
# Then start viewer
python viewer.py \
    2d-gaussian-splatting/outputs/Truck \
    --vanilla_gs2d
python viewer.py \
    SegAnyGAussians/outputs/Truck \
    --vanilla_seganygs
python viewer.py \
    mip-splatting/outputs/bicycle \
    --vanilla_mip

5. F.A.Q.

Q: The viewer shows my scene in an unexpected orientation. How can I rotate the camera, like the U and O keys in SIBR_viewer?

A: Check the Orientation Control on the right panel, rotate the camera frustum in the scene to the orientation you want, then click Apply Up Direction.

reorient-camera-up.mp4


Alternatively, you can click the 'Reset up direction' button; the viewer will then use your current camera orientation as the reference.

  • First use mouse to rotate your camera to the orientation you want
  • Then click the 'Reset up direction' button

Q: The web viewer is slow (or low fps, far from real-time).

A: This is expected because of the overhead of transferring images over the network. You can get around 10 fps at 1080p, which is enough to inspect the reconstruction quality.

License

This repository is licensed under the MIT license, except for some third-party dependencies (e.g. files in the submodules directory) and files or code copied from other repositories, which are licensed separately.

MIT License

Copyright (c) 2023 yzslab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
