Spatial Decomposition Network (SDNet) for content (anatomy) and style (modality) disentanglement

SAMPL-Weizmann/SDNet

 
 

About this repo

Spatial Decomposition Network (SDNet) for content (anatomy) and style (modality) disentanglement.

This repo provides a PyTorch implementation of the SDNet model presented in this paper. The original SDNet was implemented in Keras by the paper's first author, Agis85. This version focuses on comparing spatial and vectorized latent spaces for the anatomy encoding (several variants are included). The variants are compared on a segmentation task using the ACDC cardiac imaging dataset, as in the original paper.

Prerequisites

All development and experiments used the following setup:

  • PyTorch 1.5.1
  • CUDA 10.1
  • Python 3.7.5
  • Visdom - loss plots, images, etc.
  • Packages: nibabel, opencv-python, scikit-image (imported as skimage)
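
A quick way to confirm that the packages above are importable before training is sketched below. The import names (`torch`, `visdom`, `nibabel`, `cv2`, `skimage`) are the ones these packages commonly use, and `missing_modules` is a hypothetical helper that is not part of this repo.

```python
import importlib.util

# Import names commonly used for the packages listed above
# (an assumption; this list is not taken from the repo itself).
REQUIRED = ["torch", "visdom", "nibabel", "cv2", "skimage"]


def missing_modules(names):
    """Return the names that cannot be found on the current Python path."""
    return [n for n in names if importlib.util.find_spec(n) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("Missing:", ", ".join(missing) if missing else "none")
```

This only checks that each module can be located; it does not verify the versions listed above.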

Training

To see all the available training (hyper)parameters use:

python main.py -h

Available SDNet variants:

  1. Original architecture - UNet to encode anatomy in spatial latent variable (Variant A)
    • Gumbel Softmax is used instead of the binarization module for the UNet output, which gives smoother Dice loss convergence and a 3% increase in validation accuracy
  2. A VAE is used to encode the anatomy in a vector latent space (Variant B)
  3. A VAE is used to re-encode the spatial output of the UNet; the VAE output is then used by the segmentor and the decoder (Variant C)
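
To illustrate the Gumbel Softmax step mentioned for Variant A, here is a minimal sketch of the sampling rule for the channel logits of a single pixel, written in plain Python so it runs without a GPU stack. The function name, the default temperature `tau`, and the `hard` flag are illustrative assumptions; the exact settings used in this repo are not stated here. In PyTorch, `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=True)` performs the same sampling and additionally passes gradients through the soft sample, which this sketch cannot show.

```python
import math
import random


def gumbel_softmax(logits, tau=1.0, hard=True, rng=random):
    """Draw one Gumbel-Softmax sample from a list of channel logits."""
    eps = 1e-10
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise: g = -log(-log(u)), with u clamped away from 0 and 1.
        u = min(max(rng.random(), eps), 1.0 - eps)
        g = -math.log(-math.log(u))
        noisy.append((logit + g) / tau)
    # Numerically stable softmax over the perturbed logits.
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    if not hard:
        return soft
    # Hard mode: one-hot vector at the arg-max channel (a binary anatomy map).
    k = soft.index(max(soft))
    return [1.0 if i == k else 0.0 for i in range(len(soft))]
```

Lower values of `tau` push the soft sample closer to one-hot, while `hard=True` returns an exactly binary vector, which plays the role of the binarization step.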

SDNet architecture - Variant A (36.1M parameters)

Train the original SDNet model for 60 epochs and batch size 10 using:
python main.py --model_name sdnet --epochs 60 --batch_size 10 --data_path /path/to/ACDC/data --name visdom_experiment_name --visdom --gpu gpu_id

SDNet architecture - Variant B (8.7M parameters)

Train the 2-VAE SDNet model for 60 epochs and batch size 10 using:
python main.py --model_name sdnet2 --epochs 60 --batch_size 10 --data_path /path/to/ACDC/data --name visdom_experiment_name --visdom --gpu gpu_id

SDNet architecture - Variant C (37.2M parameters)

Train the UNet+VAE SDNet model for 60 epochs and batch size 10 using:
python main.py --model_name sdnet3 --epochs 60 --batch_size 10 --data_path /path/to/ACDC/data --name visdom_experiment_name --visdom --gpu gpu_id
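
The three training commands differ only in `--model_name`, so they can be generated from one helper. The sketch below is a convenience wrapper under that assumption; `VARIANTS` and `train_command` are hypothetical names, and only the flags shown in the commands above are used.

```python
import shlex

# Model names taken from the commands above; the variant letters follow the headings.
VARIANTS = {"A": "sdnet", "B": "sdnet2", "C": "sdnet3"}


def train_command(variant, data_path, gpu, epochs=60, batch_size=10,
                  name=None, visdom=True):
    """Build the argument list for one training run of main.py."""
    cmd = [
        "python", "main.py",
        "--model_name", VARIANTS[variant],
        "--epochs", str(epochs),
        "--batch_size", str(batch_size),
        "--data_path", data_path,
    ]
    if name is not None:
        cmd += ["--name", name]  # Visdom experiment name
    if visdom:
        cmd.append("--visdom")
    cmd += ["--gpu", str(gpu)]
    return cmd


if __name__ == "__main__":
    for variant in VARIANTS:
        cmd = train_command(variant, "/path/to/ACDC/data", gpu=0,
                            name=f"variant_{variant}")
        print(shlex.join(cmd))
```

Each returned list can be passed directly to `subprocess.run(cmd, check=True)` to launch the runs one after another.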

Test

To test the original SDNet model on the ACDC test set, use the following command:

python test.py --model_name sdnet --data_path /path/to/ACDC/data --load_weights checkpoints/path/to/saved_model_weights --gpu gpu_id

Note that this script will save the anatomy factors of each sample under the factors directory.

Results

The following table reports the results of the three variants on the ACDC test set. Note that all models were trained only on Split 0 of the training set for this proof-of-concept experiment.

             Variant A   Variant B   Variant C
Dice Score   0.78        0.48        0.36
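
For reference, the Dice score measures the overlap between a predicted and a ground-truth mask as 2|P ∩ T| / (|P| + |T|). The sketch below computes it for a single label over flattened label sequences; `dice_score` is a hypothetical helper, and how this repo aggregates the score (per class, per slice, or per volume) is not stated here.

```python
def dice_score(pred, target, label=1):
    """Dice overlap for one label between two equally sized label sequences."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    p = [x == label for x in pred]
    t = [x == label for x in target]
    intersection = sum(a and b for a, b in zip(p, t))
    denominator = sum(p) + sum(t)
    if denominator == 0:
        # Convention assumed here: the label is absent from both masks.
        return 1.0
    return 2.0 * intersection / denominator


if __name__ == "__main__":
    prediction = [0, 1, 1, 0, 1]
    ground_truth = [0, 1, 0, 0, 1]
    print(dice_score(prediction, ground_truth))  # 2 * 2 / (3 + 2) = 0.8
```

A score of 1.0 means the masks coincide and 0.0 means they do not overlap at all.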

The following examples are anatomy factors encoded by the SDNet variant A model:

To Do

Since this is an "in-progress" repository, there is more to be added:

  • Report some preliminary results on ACDC test set
  • A script to perform modality (style) traversals of any model
  • Add SPADE decoder implementation (now only AdaIN is available)

Acknowledgements

Thank you to Agis85 for the discussions and the original (Keras) implementation. Thanks also to Naoto Inoue for the PyTorch implementation of the AdaIN module.
