Skip to content

[ACM MM 2024] SAM-MIL: A Spatial Contextual Aware Multiple Instance Learning Approach for Whole Slide Image Classification

License

Notifications You must be signed in to change notification settings

FangHeng/SAM-MIL

Repository files navigation

SAM-MIL (Updating)

Official repository of "SAM-MIL: A Spatial Contextual Aware Multiple Instance Learning Approach for Whole Slide Image Classification", ACM Multimedia 2024. [arXiv] [PDF]

TODO

  • Add the code for the training and testing of the model.
  • Add the code for the preprocessing of the datasets.
  • Add the code for the visualization of the results.
  • Improving the visualization code.
  • Add the extracted features of the datasets(Camelyon16, TCGA-NSCLC).
  • Improving README document.
  • Improving the code structure.

Preparation

We used the Segment Anything Model from the official repository to implement visual segmentation of WSIs.

  • The weights used in the experiment can be found here.
  • We used sam_vit_h as the backbone model for the segmentation task.

Preprocessing

The preprocessing code can be found in the WSI_preprocess folder.

  • We adjusted the preprocessing steps in the CLAM repository. We made modifications to the original data preprocessing.

  • PLIP model and weight can be found in this link.

1. Create patches and SAM segmentations for the WSIs.

Camelyon16:

python 01_create_patches_and_sam_segment.py --source '/path/to/your/WSI/folder' --save_dir 'path/to/save/patches' --patch_size 512 --step_size 512 --preset 'bwh_biopsy.csv' --seg --patch --stitch --use_sam --sam_checkpoint 'path/to/sam_weights.pth'

TCGA-NSCLC:

python 01_create_patches_and_sam_segment.py --source '/path/to/your/WSI/folder' --save_dir '/path/to/save/patches' --patch_size 512 --step_size 512 --preset 'tcga.csv' --seg --patch --stitch --use_sam --sam_checkpoint '/path/to/sam_weights.pth'

2. Extract the features from the patches and SAM segmentations.

Camelyon16:

python 02_extract_features_and_group_feature.py --data_h5_dir '/path/to/patches' --data_slide_dir '/path/to/WSIs' --data_segment_dir '/path/to/segments' --csv_path '/path/to/process_list_autogen.csv' --feat_dir '/path/to/save/features' --use_sam --patch_size 512 --batch_size 512 --target_patch_size 224 --slide_ext .tif

TCGA-NSCLC:

python 02_extract_features_and_group_feature.py --data_h5_dir '/path/to/patches' --data_slide_dir '/path/to/WSIs' --data_segment_dir '/path/to/segments' --csv_path '/path/to/process_list_autogen.csv' --feat_dir '/path/to/save/features' --use_sam --patch_size 512 --batch_size 512 --target_patch_size 224 --slide_ext .svs

2-1. (Optional) Generate feature from original extracted features.

If you have already extracted the features from the patches, you can use the following code to generate the features for the model input.

From features(.h5) to our model input:

python extract_features_from_h5.py --data_feat_h5_dir '/path/to/h5/features' --data_slide_dir '/path/to/WSIs' --data_segment_dir '/path/to/segments' --csv_path '/path/to/process_list_autogen.csv' --feat_dir '/path/to/save/features' --patch_size 512 --slide_ext .tif/.svs

From features(.pt) to our model input:

python extract_features_from_pt.py --data_feat_pt_dir '/path/to/pt/features' --data_slide_dir '/path/to/WSIs' --data_segment_dir '/path/to/segments' --csv_path '/path/to/process_list_autogen.csv' --feat_dir '/path/to/save/features' --patch_size 512 --slide_ext .tif/.svs

3. Generate SAM info:

python 03_extract_sam_info.py --data_feat_h5_dir '/path/to/h5/features' --data_slide_dir '/path/to/WSIs' --data_segment_dir '/path/to/segments' --csv_path '/path/to/process_list_autogen.csv' --output_dir '/path/to/save/sam_info' --data_group_dir '/path/to/seg_files' --slide_ext .tif/.svs

Folder Structure

.DATASET_ROOT//
    ├── pt_files    // The extracted features in .pt format
        ├── slide1.pt
        ├── slide2.pt
        └── ...
    ├── h5_files    // (Optional) The extracted features in .h5 format
        ├── slide1.h5
        ├── slide2.h5
        └── ...
    ├── sam_info    // The SAM information inputs
        ├── slide1.h5
        ├── slide2.h5
        └── ...
    └── labels.csv  // The labels of the slides

Feature Extraction

Some code snippets about PLIP feature are shown below:

extract_features_fp.py:

model = PLIP()
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
mean, std = (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
eval_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean = mean, std = std)])

models/plip.py

from transformers import CLIPVisionModelWithProjection

class PLIP(torch.nn.Module):
    def __init__(self):
        super(PLIPM,self).__init__()
        self.model = model = CLIPVisionModelWithProjection.from_pretrained("vinid/plip")
    def forward(self, input):
        return self.model(batch_input).image_embeds

Training

The arguments for training can be found in options.py.

Train the model:

python main.py --project=your_project --datasets=camelyon16/tcga --dataset_root=/path/to/your/dataset --model_path=/path/to/save/model --cv_fold=5 --title=your_title --model=sam --sam_mask --mask_non_group_feat --mask_by_seg_area --baseline=attn --mrh_sche --seed=2021 --mask_ratio=0.9 --select_mask --num_group=5 --group_alpha=0.5 --consistency_alpha=1000 --num_workers=0 --persistence --wandb

Citing SAM-MIL

If you find SAM-MIL useful in your research, please consider citing the following paper:

@article{fang2024sam,
  title={SAM-MIL: A Spatial Contextual Aware Multiple Instance Learning Approach for Whole Slide Image Classification},
  author={Fang, Heng and Huang, Sheng and Tang, Wenhao and Huangfu, Luwen and Liu, Bo},
  journal={arXiv preprint arXiv:2407.17689},
  year={2024}
}

About

[ACM MM 2024] SAM-MIL: A Spatial Contextual Aware Multiple Instance Learning Approach for Whole Slide Image Classification

Resources

License

Stars

Watchers

Forks