(Dilated) Vision Transformer for Gray/White Matter Segmentation
This script:
- Specifies Slurm job parameters (job name, outputs, allocated resources, etc.)
- Activates a Conda environment named `neuro`
- Sets library paths
- Runs a Python script (`minimal_monai_torchio_example.py`) with the specified parameters
- Platform: Slurm cluster
- Compute: V100 GPUs, high-memory nodes
- Framework: Conda with the `neuro` environment
- Codebase: Python, utilizing `minimal_monai_torchio_example.py`
- Dataset: located within `./data/afedorov_T1_c_atlas_data/`
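A minimal sketch of what such a submission script might look like. The resource values, Conda activation path, and the `--data_dir` flag below are illustrative assumptions, not the actual file:

```bash
#!/bin/bash
#SBATCH --job-name=vit_gwm_seg        # job name is an assumption
#SBATCH --output=slurm-%j.out         # stdout log
#SBATCH --error=slurm-%j.err          # stderr log
#SBATCH --gres=gpu:v100:1             # request one V100 GPU
#SBATCH --mem=128G                    # high-memory node (value assumed)

# Activate the Conda environment used for training
source ~/miniconda3/etc/profile.d/conda.sh   # conda install path is an assumption
conda activate neuro

# Set library paths so the environment's runtime libraries are found first
export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:$LD_LIBRARY_PATH"

# Run the example script with its parameters (the flag name is illustrative)
python minimal_monai_torchio_example.py --data_dir ./data/afedorov_T1_c_atlas_data/
```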
- Platform: Slurm cluster
- Compute: V100 GPUs, high-memory nodes
- Environment: managed via Conda, specifically the `neuro` environment
- Project Root: `/data/users2/bbaker/projects/MeshVit/neuro2`
- Logging: Slurm logs stored in `/data/users2/bbaker/projects/MeshVit/slurm/`, with error and output streams separated
- Notifications: emails sent to [email protected] on all job events
- Script Execution: uses the Python script `minimal_mongo.py` in the `training` directory
- Model: `segmenter`
- Epochs: short runs of 10 epochs for quick iteration and testing
- Classes: 3-class segmentation
- Logging & Results: stored under `../vit3d_results/`
- Dataset Configuration: custom subvolume and patch sizes set for dataset processing (`tsv`, `sv`, `ps`)
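A hedged example of how a run with this configuration might be launched. Every flag name below (`--model`, `--n_epochs`, `--n_classes`, `--tsv`, `--sv`, `--ps`, `--results_dir`) and the subvolume/patch values are assumptions for illustration, not the script's verified interface:

```bash
# Illustrative invocation only; flag names and values are assumptions.
cd /data/users2/bbaker/projects/MeshVit/neuro2/training
python minimal_mongo.py \
    --model segmenter \
    --n_epochs 10 \
    --n_classes 3 \
    --tsv 256 --sv 38 --ps 12 \
    --results_dir ../vit3d_results/
```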
Developer Notes:
- Utilizing Conda for environment management ensures consistent dependency versions across runs.
- The Slurm configuration makes it easy to leverage GPU resources and simplifies parallel execution.
- Email notifications and dedicated logging help keep tabs on training progress and issues.
- Platform: Slurm cluster
- Compute Specs: V100 GPUs, high-memory nodes
- Environment Management: Conda with the `neuro` environment
- Project Directory: `/data/users2/bbaker/projects/MeshVit/neuro2`
- Logging: Slurm logs stored in `/data/users2/bbaker/projects/MeshVit/slurm/`
- Notifications: configured for [email protected]
- Main Script: `minimal_mongo.py` within the `training` folder
- Model: `meshnet`
- Training Duration: quick iterations with 10 epochs
- Segmentation Classes: three distinct classes
- Results Directory: `../vit3d_results/`
Development Insights:
- The Conda environment ensures a consistent working environment.
- The use of Slurm streamlines GPU utilization and job management.
- Regular email notifications and log separation for quick troubleshooting and monitoring.
- Dataset processing is customized via specific subvolume and patch sizes to accommodate the data's characteristics.
- Framework: PyTorch
- Image Size: 3D volume with dimensions 38x38x38
- Device: CUDA (GPU acceleration)
- Vision Transformer: `VisionTransformer3d`
  - Patch size: 12
  - Embedding size: 128
  - Depth: 8
  - Number of heads: 3
  - Input channels: 1
- Decoder: `MaskTransformer3d`
  - Heads: 2
  - Dropout: 0.0 and 0.1
- Segmenter: combines the Vision Transformer and decoder for 3D segmentation.
- Dummy 3D data is generated and fed into the `Segmenter3d`.
- The trainable parameters of the model are then counted and printed.
Development Note: This script showcases the integration of a 3D Vision Transformer with a custom decoder for segmentation tasks. The current configuration is indicative and may need tweaking based on the dataset and desired results.
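A minimal sketch of how these pieces could be wired together. The import paths and constructor argument names below are assumptions based on the description above, not the verified MeshVit API:

```python
import torch

# Hypothetical import paths; the actual module layout may differ.
from segm.model.vit3d import VisionTransformer3d
from segm.model.decoder3d import MaskTransformer3d
from segm.model.segmenter3d import Segmenter3d

# Configuration taken from the description above.
image_size, patch_size = 38, 12
d_model, depth, n_heads, channels, n_cls = 128, 8, 3, 1, 3
device = "cuda" if torch.cuda.is_available() else "cpu"

encoder = VisionTransformer3d(
    image_size=(image_size,) * 3,  # 38x38x38 volume
    patch_size=patch_size,         # patches of size 12
    d_model=d_model,               # embedding size 128
    n_layers=depth,                # depth 8
    n_heads=n_heads,               # 3 attention heads
    channels=channels,             # single-channel input
)
decoder = MaskTransformer3d(
    n_cls=n_cls,
    patch_size=patch_size,
    d_encoder=d_model,
    n_layers=2,
    n_heads=2,                     # 2 decoder heads
    d_model=d_model,
    d_ff=4 * d_model,
    drop_path_rate=0.0,
    dropout=0.1,                   # dropouts 0.0 and 0.1
)
model = Segmenter3d(encoder, decoder, n_cls=n_cls).to(device)

# Dummy 3D volume: (batch, channels, D, H, W).
x = torch.randn(1, channels, image_size, image_size, image_size, device=device)
out = model(x)
print(out.shape)  # per-class 3D masks; exact shape depends on the implementation

# Count and print the trainable parameters, as the script does.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")
```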
This module focuses on image segmentation using transformer-based architectures. Here's a breakdown of its architecture and components:
- Framework: PyTorch.
- MaskTransformer (class):
  - Input Parameters:
    - `n_cls`: number of classes.
    - `patch_size`: size of each image patch.
    - `d_encoder`: encoder depth.
    - `n_layers`: number of transformer layers.
    - `n_heads`: number of attention heads.
    - `d_model`: depth of the model.
    - `d_ff`: depth of the feed-forward network.
    - `drop_path_rate`: stochastic-depth (drop-path) rate.
    - `dropout`: dropout value.
  - Attributes:
    - `blocks`: sequential transformer blocks.
    - `cls_emb`: embeddings for the classes.
    - `proj_dec`: projection for the decoder.
    - `proj_patch`: projection for patches.
    - `proj_classes`: projection for the classes.
    - `decoder_norm`: layer normalization for the decoder.
    - `mask_norm`: layer normalization for the mask.
  - Methods:
    - `forward`: propagates the input through the model, returning masks.
    - `get_attention_map`: retrieves the attention map for a given layer ID.
- Utilities and Blocks:
  - `Block`: basic transformer block used in the sequence.
  - `FeedForward`: basic feed-forward network block.
  - `init_weights`: helper function to initialize weights.
  - `trunc_normal_`: truncated normal initialization from `timm`.
- The `MaskTransformer` is tailored to segment images, extracting features from the encoder output and predicting segmentation masks.
- It combines traditional transformer blocks with custom projections to cater to segmentation specifics.
- Attention maps can be retrieved for specific layers, aiding in model interpretability and understanding.
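A small usage sketch, assuming the constructor parameters listed above and a `forward` that takes the encoder's patch features together with the original image size (the import path and forward signature are assumptions):

```python
import torch

from decoder import MaskTransformer  # hypothetical import path

d_encoder = d_model = 192
patch_size, n_cls = 16, 3

decoder = MaskTransformer(
    n_cls=n_cls,
    patch_size=patch_size,
    d_encoder=d_encoder,
    n_layers=2,
    n_heads=3,
    d_model=d_model,
    d_ff=4 * d_model,
    drop_path_rate=0.0,
    dropout=0.1,
)

# Fake encoder output: (batch, num_patches, d_encoder) for a 224x224 image.
im_size = (224, 224)
num_patches = (im_size[0] // patch_size) * (im_size[1] // patch_size)
patch_features = torch.randn(2, num_patches, d_encoder)

# Assumed forward signature: patch features plus the original image size.
masks = decoder(patch_features, im_size)
print(masks.shape)  # per-class masks at patch resolution, e.g. (2, 3, 14, 14)

# Attention map for a given decoder layer, useful for interpretability.
attn = decoder.get_attention_map(patch_features, layer_id=1)
print(attn.shape)
```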
This code module provides an implementation of the Vision Transformer (ViT) that leverages Linformer's efficient attention mechanism for image classification.
- Frameworks & Libraries: PyTorch, Linformer, and ViT from `vit_pytorch`.
- `build_vit` function:
  - Input Parameters:
    - `dim`: dimensionality of the token embeddings.
    - `seq_len`: length of the sequence, determined by the image and patch sizes.
    - `depth`: number of transformer layers.
    - `heads`: number of attention heads.
    - `k`: context window size for Linformer.
    - `image_size`: dimension (height/width) of the input image.
    - `patch_size`: size of each image patch.
    - `num_classes`: number of output classes.
    - `channels`: number of input channels (e.g., 3 for RGB images).
  - Description: constructs the ViT model using Linformer as the transformer backbone.
  - Return: the ViT model.
- Efficient Attention Mechanism:
  - Linformer reduces the self-attention cost from O(n^2) to O(nk) by approximating the full attention matrix with a fixed-size, rank-k projection.
- When executed as a main script, the module builds a default ViT model and prints its modules.
- The integration of Linformer in ViT paves the way for handling larger images or sequences without a significant increase in computation.
- `seq_len` represents the total number of patches plus one class token, which is used for classification in the ViT paradigm.
- Using efficient transformers such as Linformer in computer vision models lets developers harness the power of transformers while keeping computation feasible.
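A sketch of how such a `build_vit` function can be assembled from the public `linformer` and `vit_pytorch` packages. The default values are illustrative; the actual function in this codebase may differ:

```python
import torch
from linformer import Linformer
from vit_pytorch.efficient import ViT


def build_vit(dim=128, seq_len=None, depth=12, heads=8, k=64,
              image_size=224, patch_size=32, num_classes=2, channels=3):
    """Build a ViT that uses Linformer as its transformer backbone (sketch)."""
    if seq_len is None:
        # Number of patches plus one class token.
        seq_len = (image_size // patch_size) ** 2 + 1

    transformer = Linformer(
        dim=dim,          # token embedding dimension
        seq_len=seq_len,  # patches + class token
        depth=depth,      # number of transformer layers
        heads=heads,      # attention heads
        k=k,              # Linformer projection (context window) size
    )
    return ViT(
        dim=dim,
        image_size=image_size,
        patch_size=patch_size,
        num_classes=num_classes,
        transformer=transformer,
        channels=channels,
    )


if __name__ == "__main__":
    model = build_vit()
    print(model)  # prints the model's modules
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # (1, num_classes)
```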
This module provides a 3D variant of the Vision Transformer (ViT) for processing volumetric data using Linformer's efficient attention mechanism.
- Frameworks & Libraries: PyTorch, Linformer, and `einops`.
- ViT3d class:
  - Description: a 3D version of the Vision Transformer, modified to handle volumetric data.
  - Key Variables:
    - `pos_embedding`: positional embeddings for patches.
    - `cls_token`: special classification token.
    - `transformer`: the underlying transformer (Linformer in this case).
    - `mlp_head`: the MLP layer for final classification.
  - Methods:
    - `forward`: processes an input volume and returns model predictions.
- `build_vit` function:
  - Input Parameters:
    - `dim`: dimensionality of the token embeddings.
    - `seq_len`: length of the sequence, determined by the volume and patch sizes.
    - `depth`: number of transformer layers.
    - `heads`: number of attention heads.
    - `k`: context window size for Linformer.
    - `image_size`: dimension (height/width/depth) of the input volume.
    - `patch_size`: size of each image patch.
    - `num_classes`: number of output classes.
    - `channels`: number of input channels (e.g., 1 for grayscale volumes).
    - `output_shape`: the desired shape of the model's output; useful for segmentation or other spatial outputs.
  - Description: constructs the 3D ViT model using Linformer as its transformer backbone.
  - Return: the 3D ViT model.
- Efficient Attention Mechanism:
  - Linformer reduces the self-attention computation by approximating the full attention matrix with a fixed-size, rank-k projection.
- When executed as a main script, the module creates an instance of the 3D ViT model, processes a random input volume, and prints the output shape.
- The use of Linformer in the 3D ViT model allows for efficient handling of larger volumes.
- `seq_len` represents the total number of patches plus one class token, adapted for the 3D setting.
- This architecture is well suited to 3D imaging tasks such as medical image analysis, where the data is naturally volumetric.
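A hedged usage sketch of the 3D variant. The import path and the exact signature of the 3D `build_vit` are assumptions mirroring the parameter list above:

```python
import torch

from vit3d import build_vit  # hypothetical import path for the 3D module

image_size, patch_size, channels, num_classes = 32, 8, 1, 3
# seq_len: number of 3D patches plus one class token.
seq_len = (image_size // patch_size) ** 3 + 1

# Argument names follow the parameter list above; the exact signature is assumed.
model = build_vit(
    dim=128,
    seq_len=seq_len,
    depth=6,
    heads=8,
    k=64,
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    channels=channels,
    output_shape=(num_classes, image_size, image_size, image_size),
)

# Random grayscale volume: (batch, channels, D, H, W).
volume = torch.randn(1, channels, image_size, image_size, image_size)
out = model(volume)
print(out.shape)  # expected to match output_shape, plus the batch dimension
```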