
License: arXiv.org perpetual non-exclusive license
arXiv:2304.12620v7 [cs.CV] 29 Dec 2023

Medical SAM Adapter: Adapting Segment Anything Model for Medical Image Segmentation

Junde Wu1,2,7, Wei Ji3, Yuanpei Liu8, Huazhu Fu4, Min Xu5,7, Yanwu Xu6, Yueming Jin2
Abstract

The Segment Anything Model (SAM) has recently gained popularity in image segmentation owing to its impressive performance on a wide range of segmentation tasks and its prompt-based interface. However, recent studies and individual experiments have shown that SAM underperforms on medical image segmentation because it lacks medical-specific knowledge. This raises the question of how to enhance SAM's segmentation capability for medical images. In this paper, instead of fully fine-tuning SAM, we propose the Medical SAM Adapter (Med-SA), which incorporates domain-specific medical knowledge into the segmentation model through a lightweight yet effective adaptation technique. In Med-SA, we propose Space-Depth Transpose (SD-Trans) to adapt 2D SAM to 3D medical images and the Hyper-Prompting Adapter (HyP-Adpt) to achieve prompt-conditioned adaptation. We conduct comprehensive evaluation experiments on 17 medical image segmentation tasks across various image modalities. Med-SA outperforms several state-of-the-art (SOTA) medical image segmentation methods while updating only 2% of the parameters. Our code is released at https://github.com/KidsWithTokens/Medical-SAM-Adapter.

Introduction

Very recently, the Segment Anything Model (SAM) (Kirillov et al. 2023) has gained significant attention as a powerful and versatile vision segmentation model. It can generate diverse and detailed segmentation masks based on user prompts. Despite its strong performance on natural images, many recent studies (Deng et al. 2023; Roy et al. 2023; He et al. 2023) show that it performs poorly on medical image segmentation. Making medical image segmentation interactive, as SAM does, holds immense clinical value: an interactive system can focus on the areas of interest indicated by clinicians, giving them a more immersive and personalized experience. For instance, a single fundus image often contains overlapping and intricately intertwined structures such as vessels, the optic disc, the optic cup, and the macula. Interactive segmentation can greatly help clinicians efficiently distinguish target tissues within these complex structures. Given the difficulty of acquiring large-scale annotated datasets, adopting a foundational interactive model like SAM for clinical use becomes crucial.

SAM's limited performance on medical images stems from its lack of the medical-specific knowledge needed to handle challenges such as low image contrast, ambiguous tissue boundaries, and tiny lesion regions. The state-of-the-art (SOTA) approach to this issue is to fully fine-tune the vanilla SAM model on medical data (Ma and Wang 2023), which is costly in terms of both computation and memory footprint. Moreover, it is doubtful whether full fine-tuning is necessary, as previous studies have shown that pre-trained visual models transfer well to medical images (Raghu et al. 2019; Xie and Richmond 2018).

In this paper, we attempt to adapt the well-trained SAM to medical image segmentation with minimal effort. Technically, we fine-tune the pre-trained SAM using a parameter-efficient fine-tuning (PEFT) technique called Adaption (Hu et al. 2021). Adaption is a popular and widely used technique in natural language processing (NLP) for fine-tuning large pre-trained foundation models for various downstream tasks. Its main idea is to insert lightweight Adapter modules into the original model and update only the small number of additional Adapter parameters while keeping the large pre-trained model frozen.

However, directly applying the Adaption technique to the medical scenario is not that straightforward. The first challenge arises from the image modality. Unlike natural images, many medical images are 3D, such as CT and MRI scans. It is unclear how to adapt the 2D SAM model for 3D medical image segmentation. Secondly, while Adaption has been successful in NLP, there is limited research on applying it to visual models, especially interactive visual models like SAM. In interactive visual models, user-provided visual prompts play a crucial role in the final prediction. How to incorporate Adaption with these important visual prompts remains unexplored.

To overcome these challenges, we propose a novel adaptation framework called Medical SAM Adapter (Med-SA). In Med-SA, we introduce the Space-Depth Transpose (SD-Trans) technique to achieve 2D-to-3D adaptation. In SD-Trans, we transpose the spatial dimension of the input embedding to the depth dimension, allowing the same self-attention blocks to process different dimensional information given different inputs. We then propose the Hyper-Prompting Adapter (HyP-Adpt) to enable prompt-conditioned adaptation, in which the visual prompt is used to generate a series of weights that are efficiently applied to the adaptation embedding, facilitating wide and deep prompt-adaptation interactions.

We conduct comprehensive evaluation experiments covering 17 medical image segmentation tasks across various image modalities, including CT, MRI, ultrasound, fundus, and dermoscopic images. The results demonstrate that Med-SA outperforms both SAM and fully fine-tuned SAM (MedSAM) (Ma and Wang 2023) by a significant margin. Med-SA also surpasses several SOTA methods tailored for medical image segmentation, such as nnUNet, TransUNet, UNetr, and Swin-UNetr. More importantly, Med-SA achieves this superior performance while updating only an additional 2% of SAM's parameters.

Our main contributions are summarized as follows:

  • We apply the Adaption approach to general medical image segmentation. Our framework, Med-SA, is a simple yet powerful extension of the SAM architecture that substantially enhances its capabilities for medical applications while updating a mere 2% of the total parameters.

  • We propose SD-Trans to enable the segmentation of high-dimensional (3D) medical data, addressing the challenge posed by medical image modalities.

  • We propose HyP-Adpt to facilitate prompt-conditioned adaption, acknowledging the importance of user-provided prompts in the medical domain.

  • Our extensive experiments on 17 medical image segmentation tasks across various image modalities clearly establish Med-SA's superiority over SAM and previous state-of-the-art methods. On the widely used BTCV abdominal multi-organ segmentation benchmark, Med-SA outperforms Swin-UNetr by 2.9%, vanilla SAM by 34.8%, and fully fine-tuned SAM (MedSAM) by 9.4% in average Dice.

Related Work

Interactive Segmentation

Interactive segmentation has a rich history; researchers initially treated it as an optimization problem (Grady 2006; Gulshan et al. 2010; Kim, Lee, and Lee 2010; Rother, Kolmogorov, and Blake 2004). The pioneering work DIOS (Xu et al. 2016) transformed interactive segmentation by introducing deep learning and encoding positive and negative clicks as distance maps. Subsequent studies (Li, Chen, and Koltun 2018; Liew et al. 2019) addressed uncertainty by predicting multiple candidate results and letting either a selection network or the user choose among them. CDNet (Chen et al. 2021b) further improved interactive segmentation by incorporating self-attention to produce more consistent predictions. RITM (Sofiiuk, Petrov, and Konushin 2022) and AccuracyNet (Forte et al. 2020) introduced previous masks as inputs to improve the robustness and accuracy of predictions. Recently, SAM (Kirillov et al. 2023) demonstrated the significant impact of interactive segmentation on zero-shot segmentation and highlighted its potential importance for visual foundation models. However, interactive medical image segmentation has received limited attention despite its critical role in clinical practice. For instance, a single fundus image may require segmenting multiple targets, such as vessels, the optic disc, the optic cup, and the macula, depending on the requirements and use case. Med-SA provides a strong starting point for interactive medical image segmentation, and we hope it inspires future research in this field.

Parameter-Efficient Fine-Tuning

PEFT has proven to be an efficient strategy for fine-tuning a large foundation model for a specific use (Zaken, Ravfogel, and Goldberg 2021). Compared to full fine-tuning, it keeps most of the parameters frozen and learns significantly fewer parameters, often less than 5% of the total, which enables efficient learning with faster updates. Studies have also shown that PEFT approaches often work better than full fine-tuning because they avoid catastrophic forgetting and generalize better to out-of-domain scenarios, especially in low-data regimes (Zaken, Ravfogel, and Goldberg 2021). Among PEFT strategies, Adaption (Hu et al. 2021) stands out as an effective tool for fine-tuning large foundation models for downstream tasks, not only in NLP but also in computer vision. Recent studies have shown that Adaption can be readily adopted in various downstream computer vision tasks (He et al. 2022; Chen et al. 2022). Therefore, we believe Adaption is the most fitting technique for transferring SAM to the medical domain. We anticipate that Med-SA, simple and clean yet powerful, will open up greater possibilities for the development of foundational medical models.

Method

Preliminary: SAM architecture

To begin with, we provide an overview of the SAM architecture. SAM comprises three main components: an image encoder, a prompt encoder, and a mask decoder. The image encoder is a standard Vision Transformer (ViT) pre-trained with MAE. Specifically, we use the ViT-H/16 variant, which employs 14×14 windowed attention and four equally spaced global attention blocks, as shown in Figure 1(a). The output of the image encoder is a 16× downsampled embedding of the input image. Prompts can be either sparse (points, boxes) or dense (masks). In this paper, we focus only on the sparse prompt encoder, which represents points and boxes as positional encodings summed with learned embeddings for each prompt type. The mask decoder is a Transformer decoder block modified to include a dynamic mask prediction head. The decoder uses two-way cross-attention to learn the interaction between the prompt and image embeddings. After that, SAM upsamples the image embedding, and an MLP maps the output token to a dynamic linear classifier, which predicts the target mask for the given image.

Med-SA architecture

Figure 1: Med-SA architecture. We use (b) as the encoder with standard Adapter to process 2D medical images, and (c) incorporating SD-Trans to process 3D images. Then we use (d) as the decoder with HyP-Adpt to incorporate the prompts.
Figure 2: HyP-Adpt architecture. We utilize Prompt Embedding to generate the weights that are applied to the Adapter Embedding.

Our objective is to enhance SAM's capability on medical image segmentation tasks through fine-tuning. Rather than adjusting all parameters, we keep the pre-trained SAM parameters frozen, devise an Adapter module, and integrate it at designated positions. The Adapter is a bottleneck module consisting of a down-projection, a ReLU activation, and an up-projection in sequence, as illustrated in Figure 1(b). The down-projection compresses the given embedding into a lower dimension using a simple MLP layer, while the up-projection expands the compressed embedding back to its original dimension using another MLP layer.
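To make the bottleneck concrete, here is a minimal NumPy sketch of a single Adapter. It only illustrates the down-project, ReLU, up-project flow; the dimensions (768 to 192), the zero initialization of the up-projection, and the use of single linear layers for the two "MLP layers" are our assumptions, not details taken from the Med-SA code.

```python
import numpy as np


class Adapter:
    """Bottleneck Adapter: down-projection -> ReLU -> up-projection.

    Sketch only: `dim`, `bottleneck` and the initialization are illustrative
    assumptions, not values reported in the paper.
    """

    def __init__(self, dim: int, bottleneck: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Down-projection: compress the embedding to a lower dimension.
        self.w_down = rng.normal(0.0, 0.02, size=(dim, bottleneck))
        self.b_down = np.zeros(bottleneck)
        # Up-projection: restore the original dimension. Zero init (an assumed
        # convention) makes the Adapter output zero before any training.
        self.w_up = np.zeros((bottleneck, dim))
        self.b_up = np.zeros(dim)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        h = np.maximum(x @ self.w_down + self.b_down, 0.0)  # ReLU
        return h @ self.w_up + self.b_up

    def num_params(self) -> int:
        return sum(p.size for p in (self.w_down, self.b_down, self.w_up, self.b_up))


adapter = Adapter(dim=768, bottleneck=192)
x = np.random.default_rng(1).normal(size=(4, 196, 768))  # (batch, tokens, dim)
y = adapter(x)
print(y.shape)               # (4, 196, 768)
print(adapter.num_params())  # 295872
```

Only these Adapter weights would be trained; their count (roughly 2 × dim × bottleneck per Adapter) is what PEFT adds on top of the frozen backbone.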

In the SAM encoder, we use two Adapters for each ViT block. For a standard ViT block (Figure 1(a)), the first Adapter is positioned after the multi-head attention and before the residual connection (Figure 1(b)). The second Adapter is placed in the residual path of the MLP layer that follows the multi-head attention. Immediately after the second Adapter, we scale the embedding by a scale factor $s$, following Chen et al. (2022).
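The following NumPy sketch shows where the two encoder Adapters could sit in a pre-norm ViT block. The frozen attention and MLP are random single-head stand-ins, the value s = 0.1 is arbitrary, and treating the second Adapter as a branch parallel to the MLP (in the style of AdaptFormer, Chen et al. 2022) is our reading of "placed in the residual path of the MLP layer", not a confirmed detail of Med-SA.

```python
import numpy as np

rng = np.random.default_rng(0)
DIM, HIDDEN, BOTTLENECK = 64, 256, 16  # toy sizes (assumed)


def layer_norm(x, eps=1e-6):
    # LayerNorm without learnable affine parameters (simplification).
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)


def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)


# Frozen "pre-trained" weights: random stand-ins for SAM's ViT weights.
Wq, Wk, Wv, Wo = (rng.normal(0, DIM ** -0.5, (DIM, DIM)) for _ in range(4))
W1 = rng.normal(0, DIM ** -0.5, (DIM, HIDDEN))
W2 = rng.normal(0, HIDDEN ** -0.5, (HIDDEN, DIM))


def attention(x):
    """Single-head self-attention over the token axis (stands in for MHA)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    a = softmax(q @ np.swapaxes(k, -1, -2) / np.sqrt(DIM))
    return (a @ v) @ Wo


def mlp(x):
    return np.maximum(x @ W1, 0.0) @ W2  # ReLU used instead of GELU for brevity


def make_adapter(seed):
    """Trainable bottleneck Adapter (the only non-frozen part of the block)."""
    r = np.random.default_rng(seed)
    w_down = r.normal(0, 0.02, (DIM, BOTTLENECK))
    w_up = r.normal(0, 0.02, (BOTTLENECK, DIM))
    return lambda h: np.maximum(h @ w_down, 0.0) @ w_up


adapter1, adapter2 = make_adapter(1), make_adapter(2)


def adapted_vit_block(x, s=0.1):
    # Adapter 1: after multi-head attention, before the residual addition.
    x = x + adapter1(attention(layer_norm(x)))
    # Adapter 2: parallel to the MLP on its residual path, scaled by s.
    h = layer_norm(x)
    return x + mlp(h) + s * adapter2(h)


tokens = rng.normal(size=(2, 49, DIM))  # (batch, tokens, dim)
out = adapted_vit_block(tokens)
print(out.shape)  # (2, 49, 64)
```

Because $s$ multiplies only the second Adapter's branch, it controls how strongly the adapted features perturb the frozen MLP path.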

In the SAM decoder, we incorporate three Adapters for each ViT block. The first Adapter integrates the prompt embedding; for this purpose we introduce a novel structure called the Hyper-Prompting Adapter (HyP-Adpt), described in detail in the HyP-Adpt section. The second Adapter in the decoder is deployed exactly as in the encoder, to adapt the MLP-enhanced embedding. The third Adapter is deployed after the residual connection of the image-embedding-to-prompt cross-attention. Another residual connection and layer normalization follow this Adapter to produce the final output.

SD-Trans

Adapting SAM to medical image segmentation is challenging because of the dimensional disparity between 2D images and prevalent 3D modalities such as MRI and CT. In clinical use, understanding the correlation between slices is crucial for accurate decision-making. Although SAM can be applied to each slice of a volume to obtain the final segmentation, doing so ignores the close volumetric correlation inherent in 3D medical image segmentation, as highlighted in previous studies (Hatamizadeh et al. 2022b, a; Xing et al. 2023). To address this limitation, we propose SD-Trans, inspired by image-to-video adaptation (Liu et al. 2019). Its structure is depicted in Figure 1(c).

As shown in Figure 1(c), in each block we bifurcate the attention operation into two branches: a space branch and a depth branch. For a given 3D sample with depth $D$, we feed a $D \times N \times L$ tensor into the multi-head attention of the space branch, where $N$ is the number of embeddings (tokens) per slice and $L$ is the embedding length. Here, $D$ acts as the batch dimension, i.e., the number of independent attention operations, so the interaction is applied over $N \times L$, capturing and abstracting spatial correlations as embeddings. In the depth branch, we transpose the input to obtain $N \times D \times L$ and feed it into the same multi-head attention. Although the attention mechanism is the same, the interaction now occurs over $D \times L$, enabling the model to learn and abstract depth correlations. Finally, we transpose the output of the depth branch back to its original shape and add it to the output of the space branch, thereby incorporating the depth information.
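The shape bookkeeping of SD-Trans can be sketched in NumPy as below. This is an illustration under simplifying assumptions: a single attention head without an output projection or normalization stands in for SAM's multi-head attention, and the weight names `wq`, `wk`, `wv` are hypothetical. The key point is that both branches reuse the same weights and differ only in which axis plays the role of the batch.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def shared_attention(x, wq, wk, wv):
    """Self-attention over axis 1 of x, shape (batch, tokens, L).

    Axis 0 is a batch of independent sequences; the same weights are used
    by both SD-Trans branches.
    """
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(x.shape[-1])
    return softmax(scores) @ v


def sd_trans(x, wq, wk, wv):
    """x: (D, N, L) = (depth slices, tokens per slice, embedding length)."""
    # Space branch: D independent sequences of N tokens (within-slice context).
    space = shared_attention(x, wq, wk, wv)
    # Depth branch: transpose to (N, D, L) so each token position attends
    # across the D slices with the same attention weights.
    depth = shared_attention(x.transpose(1, 0, 2), wq, wk, wv)
    # Transpose the depth result back to (D, N, L) and fuse by addition.
    return space + depth.transpose(1, 0, 2)


rng = np.random.default_rng(0)
D, N, L = 8, 16, 32
wq, wk, wv = (rng.normal(0, L ** -0.5, (L, L)) for _ in range(3))
volume = rng.normal(size=(D, N, L))
out = sd_trans(volume, wq, wk, wv)
print(out.shape)  # (8, 16, 32)
```

In the space branch alone, changing one slice cannot affect another slice's output; the depth branch is what lets information flow between slices without adding new attention weights.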

Figure 3: Visual comparison of Med-SA and SAM on abdominal multi-organ segmentation. A check mark indicates that SAM correctly found the organ, and a cross indicates that it missed the organ.
Table 1: Comparison of Med-SA with SOTA segmentation methods on the BTCV dataset, evaluated by Dice score. Best results are denoted in bold.

| Model | Param (M) | Tunable Param (M) | Spleen | R.Kid | L.Kid | Gall. | Eso. | Liver | Stom. | Aorta | IVC | Veins | Panc. | AG | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| TransUNet | 37 | 37 | 0.952 | 0.927 | 0.929 | 0.662 | 0.757 | 0.969 | 0.889 | 0.920 | 0.833 | 0.791 | 0.775 | 0.637 | 0.838 |
| EnsDiff | 32 | 32 | 0.938 | 0.931 | 0.924 | 0.772 | 0.771 | 0.967 | 0.910 | 0.869 | 0.851 | 0.802 | 0.771 | 0.745 | 0.854 |
| SegDiff | 32 | 32 | 0.954 | 0.932 | 0.926 | 0.738 | 0.763 | 0.953 | 0.927 | 0.846 | 0.833 | 0.796 | 0.782 | 0.723 | 0.847 |
| UNetr | 104 | 104 | 0.968 | 0.924 | 0.941 | 0.750 | 0.766 | 0.971 | 0.913 | 0.890 | 0.847 | 0.788 | 0.767 | 0.741 | 0.856 |
| Swin-UNetr | 138 | 138 | 0.971 | 0.936 | 0.943 | 0.794 | 0.773 | 0.975 | 0.921 | 0.892 | 0.853 | 0.812 | 0.794 | 0.765 | 0.869 |
| nnUNet | 16 | 16 | 0.942 | 0.894 | 0.910 | 0.704 | 0.723 | 0.948 | 0.824 | 0.877 | 0.782 | 0.720 | 0.680 | 0.616 | 0.802 |
| SAM 1 point | 636 | 0 | 0.518 | 0.686 | 0.791 | 0.543 | 0.584 | 0.461 | 0.562 | 0.612 | 0.402 | 0.553 | 0.511 | 0.354 | 0.548 |
| SAM 3 points | 636 | 0 | 0.622 | 0.710 | 0.812 | 0.614 | 0.605 | 0.513 | 0.673 | 0.645 | 0.483 | 0.628 | 0.564 | 0.395 | 0.631 |
| SAM BBox 0.75 | 636 | 0 | 0.415 | 0.621 | 0.678 | 0.580 | 0.595 | 0.469 | 0.521 | 0.612 | 0.539 | 0.655 | 0.588 | 0.327 | 0.550 |
| SAM BBox 0.5 | 636 | 0 | 0.346 | 0.585 | 0.592 | 0.375 | 0.426 | 0.377 | 0.451 | 0.536 | 0.392 | 0.576 | 0.426 | 0.202 | 0.440 |
| MedSAM 1 point | 636 | 636 | 0.751 | 0.814 | 0.885 | 0.766 | 0.721 | 0.901 | 0.855 | 0.872 | 0.746 | 0.771 | 0.760 | 0.705 | 0.803 |
| MedSAM 3 points | 636 | 636 | 0.758 | 0.831 | 0.889 | 0.782 | 0.733 | 0.917 | 0.858 | 0.876 | 0.755 | 0.776 | 0.763 | 0.716 | 0.820 |
| MedSAM BBox 0.75 | 636 | 636 | 0.746 | 0.842 | 0.873 | 0.772 | 0.745 | 0.897 | 0.860 | 0.889 | 0.743 | 0.745 | 0.739 | 0.701 | 0.804 |
| MedSAM BBox 0.5 | 636 | 636 | 0.621 | 0.736 | 0.801 | 0.721 | 0.715 | 0.811 | 0.714 | 0.770 | 0.622 | 0.618 | 0.630 | 0.545 | 0.692 |
| Med-SA 1 point | 636 | 13 | 0.978 | 0.935 | 0.966 | 0.823 | 0.818 | 0.981 | 0.931 | 0.915 | 0.877 | 0.811 | 0.767 | 0.809 | 0.883 |
| Med-SA 3 points | 636 | 13 | 0.980 | 0.936 | 0.968 | 0.826 | 0.821 | 0.986 | 0.934 | 0.917 | 0.878 | 0.813 | 0.771 | 0.818 | 0.887 |
| Med-SA BBox 0.5 | 636 | 13 | 0.954 | 0.910 | 0.952 | 0.810 | 0.807 | 0.975 | 0.928 | 0.912 | 0.868 | 0.809 | 0.769 | 0.813 | 0.876 |
| Med-SA BBox 0.75 | 636 | 13 | 0.985 | 0.947 | 0.975 | 0.842 | 0.808 | 0.983 | 0.942 | 0.939 | 0.899 | 0.852 | 0.790 | 0.823 | 0.898 |

HyP-Adpt

While adaptation techniques have been applied to visual models in a few previous works, their application to interactive visual models remains largely unexplored. The interactive behavior in the source task can differ significantly from that in the downstream task. Therefore, it is crucial to incorporate the visual prompt, which plays a key role in an interactive model, into the adapter. To this end, we propose HyP-Adpt, which achieves prompt-conditioned adaptation.

The idea behind HyP-Adpt is inspired by HyperNetworks (Ha, Dai, and Le 2016), which use one network to generate the weights of another network for knowledge conditioning. We adopt the high-level concept of HyperNetworks but redesign it to be applied efficiently at the feature level. Specifically, we use only projection and reshaping operations to generate a sequence of weight maps from the prompt embedding. These weight maps are then applied directly (via matrix product) to the adapter embedding. This approach enables wide and deep feature-level interaction while requiring far fewer parameters than generating an entire network.

Specifically, we perform hyper-prompting on the reduced (down-projected) embedding of the Adapter, $e^{down}$. Meanwhile, the prompt information (click location, click attribute, or bounding box location) is concatenated and reduced into a prompt embedding $e^{prompt}$. We then use $e^{prompt}$ to generate a sequence of weights; taking one of them as an example, it can be represented as:

$$w^{prompt} = Re\left(M\left(e^{prompt}\right)\right), \tag{1}$$

where $Re$ denotes reshape and $M$ denotes the MLP layer that projects $e^{prompt} \in \mathcal{R}^{N \times L}$ to $e^{prompt} \in \mathcal{R}^{N \times (L^{in} * L^{out})}$, in which $*$ is scalar multiplication. $L^{in}$ of the first weight is the length of $e^{down}$, and $L^{out}$ of the last weight is the target output length.
After that, we reshape $e^{prompt}$ from a 1D embedding into a 2D weight $w^{prompt} \in \mathcal{R}^{N \times L^{in} \times L^{out}}$ and apply it to $e^{down}$, which can be represented as:

$$e^{down}_{n+1} = ReLU\left(Norm\left(e^{down}_{n} \otimes w^{prompt}\right)\right), \tag{2}$$

where $\otimes$ is the matrix product. We normalize the elements along the length dimension and then apply the ReLU activation. We use 3 hyper-prompting layers, and the weight of each layer is projected by its own MLP layer. HyP-Adpt makes the adapter parameters conditioned on the prompt information, which makes the model more flexible across modalities and downstream tasks.
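Equations (1) and (2) can be illustrated with the hedged NumPy sketch below. The class name `HyPAdapterSketch` is hypothetical; we assume each $M$ is a single bias-free linear map, that $\otimes$ is a row-wise matrix product over the shared leading dimension $N$, and that all hidden widths equal 16. The released code may differ on every one of these points.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Norm(.) in Eq. (2): normalize along the length (last) dimension.
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


class HyPAdapterSketch:
    """Hypothetical hyper-prompting module following Eqs. (1)-(2).

    dims = [L_down, ..., L_target] gives len(dims) - 1 hyper-prompting layers,
    each with its own projection M_i (a single linear map here, by assumption).
    """

    def __init__(self, prompt_len, dims, seed=0):
        rng = np.random.default_rng(seed)
        self.dims = dims
        self.mlps = [
            rng.normal(0.0, (prompt_len * d_in) ** -0.5, (prompt_len, d_in * d_out))
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]

    def __call__(self, e_down, e_prompt):
        n = e_prompt.shape[0]
        for proj, d_in, d_out in zip(self.mlps, self.dims[:-1], self.dims[1:]):
            # Eq. (1): w_prompt = Re(M(e_prompt)), shape (N, L_in, L_out).
            w_prompt = (e_prompt @ proj).reshape(n, d_in, d_out)
            # Eq. (2): row-wise matrix product, then Norm and ReLU.
            e_down = np.maximum(layer_norm(np.einsum("ni,nio->no", e_down, w_prompt)), 0.0)
        return e_down


rng = np.random.default_rng(1)
N, L_prompt = 5, 32
hyp = HyPAdapterSketch(prompt_len=L_prompt, dims=[16, 16, 16, 16])  # 3 layers
e_down = rng.normal(size=(N, 16))          # reduced Adapter embedding
e_prompt = rng.normal(size=(N, L_prompt))  # reduced prompt embedding
out = hyp(e_down, e_prompt)
print(out.shape)  # (5, 16)
```

The prompt-dependent weights are generated on the fly, so the only trainable parameters here are the three projections (3 × 32 × 256 = 24,576 values in this toy configuration), rather than a full generated network.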

Training Strategy

Table 2: Comparison of Med-SA with SAM and SOTA segmentation methods on different image modalities. A grey background denotes methods proposed specifically for that task (or those tasks). Performance is omitted (-) if the algorithm fails on over 70% of the samples.

| Model | Param (M) | Tunable Param (M) | Optic Disc Dice | Optic Disc IoU | Optic Cup Dice | Optic Cup IoU | Brain Tumor Dice | Brain Tumor IoU | Brain Tumor HD95 | Thyroid Nodule Dice | Thyroid Nodule IoU | Melanoma Dice | Melanoma IoU |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ResUNet | 17 | 17 | 92.9 | 85.5 | 80.1 | 72.3 | 78.4 | 71.3 | 18.71 | 78.3 | 70.7 | 87.1 | 78.2 |
| BEAL | 25 | 25 | 93.7 | 86.1 | 83.5 | 74.1 | 78.8 | 71.7 | 18.53 | 78.6 | 71.6 | 86.6 | 78.0 |
| TransBTS | 39 | 39 | 94.1 | 87.2 | 85.4 | 75.7 | 87.6 | 78.44 | 12.44 | 83.8 | 75.5 | 88.1 | 80.6 |
| EnsemDiff | 32 | 32 | 94.3 | 87.8 | 84.2 | 74.4 | 88.7 | 80.9 | 10.85 | 83.9 | 75.3 | 88.2 | 80.7 |
| MTSeg | 27 | 27 | 90.3 | 83.6 | 82.3 | 73.1 | 82.2 | 74.5 | 15.74 | 82.3 | 75.2 | 87.5 | 79.7 |
| UltraUNet | 19 | 19 | 91.5 | 82.8 | 83.1 | 73.8 | 84.5 | 76.3 | 14.03 | 84.5 | 76.2 | 89.0 | 81.8 |
| FAT-Net | 75 | 75 | 91.8 | 84.8 | 80.9 | 71.5 | 79.2 | 72.8 | 17.35 | 80.8 | 73.4 | 90.7 | 83.9 |
| BAT | 88 | 88 | 92.3 | 85.8 | 82.0 | 73.2 | 79.6 | 73.5 | 15.49 | 81.7 | 74.2 | 91.2 | 84.3 |
| SegDiff | 32 | 32 | 92.6 | 85.2 | 82.5 | 71.9 | 85.7 | 77.0 | 14.31 | 81.9 | 74.8 | 87.3 | 79.4 |
| nnUNet | 16 | 16 | 94.7 | 87.3 | 84.9 | 75.1 | 88.5 | 80.6 | 11.20 | 84.2 | 76.2 | 90.8 | 83.6 |
| TransUNet | 96 | 96 | 95.0 | 87.7 | 85.6 | 75.9 | 86.6 | 79.0 | 13.74 | 83.5 | 75.1 | 89.4 | 82.2 |
| UNetr | 104 | 104 | 94.9 | 87.5 | 83.2 | 73.3 | 87.3 | 80.6 | 12.81 | 81.7 | 73.5 | 89.7 | 82.8 |
| Swin-UNetr | 138 | 138 | 95.3 | 87.9 | 84.3 | 74.5 | 88.4 | 81.8 | 11.36 | 83.5 | 74.8 | 90.2 | 83.1 |
| SAM 1 point | 636 | 0 | - | - | - | - | 63.2 | 47.6 | 32.53 | - | - | 81.6 | 70.4 |
| SAM 3 points | 636 | 0 | - | - | - | - | 71.3 | 64.5 | 28.74 | - | - | 85.8 | 77.5 |
| SAM BBox 0.5 | 636 | 0 | - | - | - | - | 51.2 | 44.6 | 38.56 | - | - | 75.3 | 64.8 |
| SAM BBox 0.75 | 636 | 0 | - | - | - | - | 74.6 | 62.1 | 27.51 | - | - | 85.7 | 74.4 |
| MedSAM 1 point | 636 | 636 | 92.9 | 85.5 | 82.1 | 73.8 | 81.5 | 74.3 | 15.68 | 81.3 | 74.7 | 86.8 | 77.5 |
| MedSAM 3 points | 636 | 636 | 93.8 | 86.2 | 82.8 | 74.2 | 82.3 | 74.8 | 15.19 | 81.6 | 75.1 | 87.5 | 78.6 |
| MedSAM BBox 0.5 | 636 | 636 | 92.6 | 85.3 | 82.0 | 75.2 | 82.0 | 74.7 | 15.05 | 82.4 | 75.5 | 88.5 | 79.2 |
| MedSAM BBox 0.75 | 636 | 636 | 94.6 | 86.7 | 82.8 | 75.9 | 83.6 | 75.6 | 14.90 | 82.8 | 75.7 | 88.9 | 79.8 |
| Med-SA 1 point | 636 | 13 | 97.4 | 89.5 | 86.8 | 78.8 | 89.1 | 81.8 | 10.38 | 86.3 | 78.7 | 92.6 | 84.1 |
| Med-SA 3 points | 636 | 13 | 97.9 | 89.8 | 87.1 | 79.0 | 89.8 | 82.3 | 10.11 | 86.7 | 79.4 | 93.4 | 84.7 |
| Med-SA BBox 0.5 | 636 | 13 | 97.6 | 89.6 | 86.4 | 78.5 | 89.5 | 81.9 | 10.35 | 86.6 | 78.9 | 92.1 | 83.0 |
| Med-SA BBox 0.75 | 636 | 13 | 98.3 | 90.1 | 87.5 | 79.9 | 90.5 | 83.0 | 9.50 | 88.4 | 80.4 | 93.0 | 84.2 |
Figure 4: Visual comparison of Med-SA and SAM on medical image segmentation with four different modalities. Top-left: optic disc and cup segmentation from the fundus image. Top-right: brain tumor segmentation from the Brain MRI. Bottom-left: melanoma segmentation from the dermoscopic image. Bottom-right: thyroid nodule segmentation from the ultrasound image.

For interactive segmentation, we employ click prompts and bounding box (BBox) prompts during the model training process. To generate BBox prompts, we adopt the same approach as SAM. However, since the original SAM paper provides limited details on click prompt generation, we have devised our own method, which we present here.

The fundamental idea behind our click prompt generation is to use positive clicks to indicate foreground regions and negative clicks to indicate background regions. We combine random and iterative click sampling strategies to train the model with these prompts. We first use random sampling to initialize the prompts and then add a few clicks using an iterative sampling procedure. This iterative strategy emulates interaction with a real user, since each new click is placed in the erroneous region of a prediction that the network produced from the previous clicks. We follow Lin et al. (2020) for random sampling and Mahadevan, Voigtlaender, and Leibe (2018) for simulating the iterative sampling process. The detailed implementation can be found in our released code.
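The two sampling stages can be sketched as follows. This is a simplified NumPy illustration, not the released implementation: `random_click` and `iterative_click` are hypothetical helpers, the initial click is drawn uniformly from the foreground, and the next click is drawn uniformly from the larger error set, whereas the cited methods use more elaborate placement rules (for example, favoring points far from region boundaries).

```python
import numpy as np


def random_click(gt, rng):
    """Initial positive click drawn uniformly from the foreground.

    Simplification: Lin et al. (2020) use a more elaborate sampling scheme.
    """
    ys, xs = np.nonzero(gt)
    i = rng.integers(len(ys))
    return int(ys[i]), int(xs[i]), True  # (row, col, is_positive)


def iterative_click(gt, pred, rng):
    """Next click placed in the error region of the current prediction.

    Missed foreground (false negatives) yields a positive click; spurious
    foreground (false positives) yields a negative click. Choosing the larger
    error set and sampling uniformly inside it is an assumed rule.
    """
    fn = gt & ~pred
    fp = pred & ~gt
    if not fn.any() and not fp.any():
        return None  # prediction already matches the target
    positive = bool(fn.sum() >= fp.sum())
    ys, xs = np.nonzero(fn if positive else fp)
    i = rng.integers(len(ys))
    return int(ys[i]), int(xs[i]), positive


rng = np.random.default_rng(0)
gt = np.zeros((8, 8), dtype=bool)
gt[2:6, 2:6] = True
pred = np.zeros_like(gt)
pred[2:6, 2:4] = True  # the current prediction misses the right half
first = random_click(gt, rng)
nxt = iterative_click(gt, pred, rng)
print(nxt[2])  # True: the error is missed foreground, so the click is positive
```

During training, the new click would be appended to the previous ones and the model re-run, repeating for a few rounds.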

Experiments

Dataset

We conducted experiments on five distinct medical image segmentation datasets, which fall into two types. The first type evaluates general segmentation performance. For this purpose, we selected abdominal multi-organ segmentation, as it represents one of the most significant challenges in medical image segmentation. We used the BTCV dataset (Fang and Yan 2020), a widely used, publicly available benchmark with twelve annotated anatomies.

The other four tasks were used to verify the model's generalization to different modalities: optic disc and optic cup segmentation on fundus images, brain tumor segmentation on brain MRI images, thyroid nodule segmentation on ultrasound images, and melanoma or nevus segmentation on dermoscopic images. For fundus image segmentation, we conducted experiments on the REFUGE2 dataset (Fang et al. 2022). For brain tumor segmentation, we conducted experiments on the BraTs 2021 dataset (Baid et al. 2021). For thyroid nodule segmentation, we used the TNMIX benchmark, a mixed dataset containing 4554 images from TNSCUI (Ma et al. 2017) and 637 images from DDTI (Pedraza et al. 2015). Finally, for melanoma or nevus segmentation, we conducted experiments on the ISIC 2019 dataset (Milton 2019). All datasets are publicly available.

Implementation Details

In this study, we implemented the Med-SA pipeline primarily following the official ViT-H SAM GitHub repository. For 2D medical image training, we followed SAM's default training settings. For 3D medical image training, we used a smaller batch size of 16. We trained the model for 40 epochs on the REFUGE2, TNMIX, and ISIC datasets and for 60 epochs on the BTCV and BraTs datasets. We chose fewer epochs than fully fine-tuned training because we observed that the model converged faster in our setting. For the interactive models, we experimented with four prompt settings: (1) one random positive point, denoted "1-point"; (2) three positive points, denoted "3-points"; (3) bounding boxes with 50% overlap with the target, denoted "BBox 0.5"; and (4) bounding boxes with 75% overlap with the target, denoted "BBox 0.75". All experiments were implemented in PyTorch and trained/tested on 4 NVIDIA A100 GPUs. We used the default settings to reproduce the comparison methods.

Table 3: An ablation study on SD-Trans and HyP-Adpt.
2D-3D Prompt-Condition BTCV OpticDisc OpticCup BrainTumor ThyroidNodule Melanoma
SD-Trans Add Concat HyP-Adpt Ave-Dice (%) Dice (%) Dice (%) Dice (%) Dice (%) Dice (%)
79.3 90.1 80.1 77.5 76.5 89.2
84.7 - - 81.7 - -
86.1 94.6 83.4 83.9 83.7 93.8
86.4 95.7 84.0 85.1 84.8 94.5
88.3 97.4 86.8 87.6 86.3 96.3

Comparing with SOTA on Abdominal Multi-organ Segmentation

To verify the general performance of Med-SA, we compare it with SOTA segmentation methods on the BTCV multi-organ segmentation dataset. The quantitative results are presented in Table 1, where we compare Med-SA with well-recognized medical image segmentation methods, including nnUNet (Isensee et al. 2021), TransUNet (Chen et al. 2021a), UNetr (Hatamizadeh et al. 2022b), Swin-UNetr (Hatamizadeh et al. 2022a), EnsDiff (Wolleb et al. 2021), and SegDiff (Amit et al. 2021), as well as vanilla SAM and fully fine-tuned SAM (MedSAM) (Ma and Wang 2023). We evaluate segmentation performance using the Dice score.
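For reference, the Dice score on a pair of binary masks can be computed as below, shown alongside IoU, which Table 2 also reports. This is a generic NumPy sketch, not the paper's evaluation script; the small `eps` term, which keeps the ratios defined for empty masks, is our assumption.

```python
import numpy as np


def dice_and_iou(pred, gt, eps=1e-7):
    """Dice = 2|P∩G| / (|P| + |G|) and IoU = |P∩G| / |P∪G| for binary masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    dice = (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)
    iou = (inter + eps) / (union + eps)
    return float(dice), float(iou)


gt = np.zeros((10, 10), dtype=bool)
gt[0:4, 0:4] = True      # 16 foreground pixels
pred = np.zeros_like(gt)
pred[0:4, 1:5] = True    # 16 pixels, 12 of them overlapping gt
dice, iou = dice_and_iou(pred, gt)
print(round(dice, 3), round(iou, 3))  # 0.75 0.6
```

For a single mask pair the two metrics are related by Dice = 2·IoU / (1 + IoU), so they rank single predictions identically.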

Table 1 shows that Med-SA achieves a significant improvement over SAM even with only a 1-point prompt. On BTCV, the 1-point Med-SA already achieves a higher average Dice (0.883) than every competing method and the best score on 8 of the 12 organs. As we provide more fine-grained prompts, the results continue to improve, reaching a final average Dice of 89.8% with BBox 0.75. This outperforms the previous SOTA (Swin-UNetr) by a significant margin of 2.9%. Notably, Swin-UNetr has 138M tunable parameters, whereas we update only 13M. Surprisingly, we even outperform the fully fine-tuned MedSAM model under every prompt setting. With the proposed SD-Trans and HyP-Adpt, Med-SA outperforms MedSAM while updating only 2% of its tunable parameters (13M vs. 636M), which highlights the effectiveness of the proposed techniques.
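The headline margins are absolute differences in average Dice (percentage points), and the "2%" is the ratio of tunable to total parameters. The short Python check below recomputes them from the Table 1 values; the dictionary and the helper `gain_in_points` are ours, for illustration only.

```python
# Average Dice and parameter counts (in millions) copied from Table 1.
avg_dice = {
    "Swin-UNetr": 0.869,
    "SAM BBox 0.75": 0.550,
    "MedSAM BBox 0.75": 0.804,
    "Med-SA BBox 0.75": 0.898,
}
tunable_medsa, total_sam = 13, 636


def gain_in_points(baseline: str, ours: str = "Med-SA BBox 0.75") -> float:
    """Absolute Dice gain of `ours` over `baseline`, in percentage points."""
    return round(100 * (avg_dice[ours] - avg_dice[baseline]), 1)


for name in ("Swin-UNetr", "SAM BBox 0.75", "MedSAM BBox 0.75"):
    print(f"{name}: +{gain_in_points(name)} points")
print(f"tunable fraction: {100 * tunable_medsa / total_sam:.1f}%")
```

The loop prints gains of 2.9, 34.8, and 9.4 points, and 13M / 636M comes to about 2.0%.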

Comparing prompts across the interactive segmentation models (SAM, MedSAM, Med-SA), we notice that 3-point prompts slightly outperform 1-point prompts, and BBox 0.75 often performs comparably to or better than 3-point prompts. However, BBox 0.5 yields subpar performance, indicating that accurate bounding box annotations are important for achieving performance gains. All interactive models, including SAM, MedSAM, and Med-SA, exhibit similar behavior across prompts, demonstrating consistency in their response to prompts.

Considering SAM's performance in Table 1, we observe that SAM's zero-shot performance is generally inferior to that of fully trained models (e.g., MedSAM (Ma and Wang 2023)) on the target medical image segmentation tasks, regardless of the prompt used. Although this comparison may seem unfair, since SAM is evaluated zero-shot against models trained on medical images, SAM has demonstrated superior zero-shot performance on natural image datasets. This indicates that SAM's zero-shot transferability is weaker for medical images than for natural images, as also observed in previous studies (Deng et al. 2023; Roy et al. 2023; He et al. 2023). This finding emphasizes the need for specific techniques to adapt SAM to medical image segmentation.

Figure 3 presents a qualitative comparison between Med-SA and SAM. It shows that Med-SA accurately segments regions that are difficult to recognize by eye, whereas SAM fails in many cases even where the organ boundaries are visually clear. This further underscores the necessity of fine-tuning a general segmentation model on medical images to achieve optimal performance.

Comparing with SOTA on Multi-modality Images

We also compared Med-SA with task-specific segmentation methods on several medical image segmentation tasks spanning different image modalities. The results are presented in Table 2. In the table, ResUnet (Yu et al. 2019) and BEAL (Wang et al. 2019) are proposed for optic disc and cup segmentation, TransBTS (Wang et al. 2021b) and EnsemDiff (Wolleb et al. 2021) for brain tumor segmentation, MTSeg (Gong et al. 2021) and UltraUNet (Chu, Zheng, and Zhou 2021) for thyroid nodule segmentation, and FAT-Net (Wu et al. 2022) and BAT (Wang et al. 2021a) for melanoma segmentation. SegDiff (Amit et al. 2021), nnUNet (Isensee et al. 2021), TransUNet (Chen et al. 2021a), UNetr (Hatamizadeh et al. 2022b), and Swin-UNetr (Hatamizadeh et al. 2022a) are proposed for general medical image segmentation. Segmentation performance is evaluated with Dice score, IoU, and HD95.
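For reference, the three metrics can be computed on binary masks as in the following minimal NumPy sketch for 2D masks. It is not the evaluation code used in the paper. HD95 is taken here as the 95th percentile of the pooled symmetric boundary-to-boundary distances (one common convention; some toolkits instead take the larger of the two directed 95th percentiles), distances are in pixels without voxel spacing, and the brute-force pairwise distance is only practical for small masks.

```python
import numpy as np


def dice(pred: np.ndarray, gt: np.ndarray) -> float:
    """Dice = 2|P ∩ G| / (|P| + |G|); defined as 1 when both masks are empty."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    denom = pred.sum() + gt.sum()
    return 1.0 if denom == 0 else float(2.0 * np.logical_and(pred, gt).sum() / denom)


def iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """IoU = |P ∩ G| / |P ∪ G|; defined as 1 when both masks are empty."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    return 1.0 if union == 0 else float(np.logical_and(pred, gt).sum() / union)


def boundary_points(mask: np.ndarray) -> np.ndarray:
    """(row, col) of foreground pixels with a 4-connected background neighbour."""
    m = np.pad(mask.astype(bool), 1, mode="constant", constant_values=False)
    core = m[1:-1, 1:-1]
    interior = core & m[:-2, 1:-1] & m[2:, 1:-1] & m[1:-1, :-2] & m[1:-1, 2:]
    return np.argwhere(core & ~interior)


def hd95(pred: np.ndarray, gt: np.ndarray) -> float:
    """95th percentile of symmetric boundary distances, in pixels."""
    p, g = boundary_points(pred), boundary_points(gt)
    if len(p) == 0 or len(g) == 0:
        return float("inf")  # undefined when either mask is empty
    d = np.linalg.norm(p[:, None, :] - g[None, :, :], axis=-1)  # pairwise distances
    # Nearest-neighbour distance in both directions, pooled before the percentile.
    return float(np.percentile(np.concatenate([d.min(axis=1), d.min(axis=0)]), 95))


# Usage: a 10x10 square and the same square shifted right by 2 pixels.
gt = np.zeros((32, 32), dtype=bool)
gt[5:15, 5:15] = True
pred = np.zeros_like(gt)
pred[5:15, 7:17] = True
print(dice(pred, gt), iou(pred, gt), hd95(pred, gt))  # Dice 0.8, IoU 2/3
```

Dice and IoU reward overlap, while HD95 measures boundary error in distance units; using both kinds of metric makes it harder for a method to look good by matching area while misplacing the boundary.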

The table shows that these task-specific methods often perform well within their own domains but degrade when applied to other domains. For example, UltraUNet achieves the previous SOTA for thyroid nodule segmentation but performs worst among the compared methods on optic disc segmentation. General methods, on the other hand, often achieve good results across most modalities but fail to outperform specialized methods on specific tasks such as brain tumor and thyroid nodule segmentation.

Turning to the interactive models, SAM and MedSAM, we observe that zero-shot SAM struggles with organs and tissues that have ambiguous boundaries in medical images, as in optic disc/cup segmentation and thyroid nodule segmentation. The fully fine-tuned MedSAM falls short on brain tumor segmentation because of its limited ability to process 3D images. In contrast, Med-SA achieves SOTA performance on all segmentation tasks, demonstrating that it generalizes across medical segmentation tasks and image modalities. On the widely used BraTS benchmark, thanks to its adaptability to 3D images, Med-SA outperforms the previous SOTA, Swin-UNetr, by 2.1% in Dice score and 1.86 in HD95 while using less than 10% of its tunable parameters.

Ablation Study

We conducted a comprehensive ablation study to validate the effectiveness of the proposed SD-Trans and HyP-Adpt. The results are presented in Table 3, where the baseline (first row) is a simple combination of SAM and the original adapter method. In the baseline setting, 3D images are treated as a sequence of 2D images and processed individually, and prompts are not involved in the adaptation. As shown in the table, our 2D-to-3D design significantly improves performance over the vanilla SAM-plus-adapter setting on both 3D benchmarks (BTCV and BrainTumor), highlighting its effectiveness. For prompt-conditioned adaptation, we compared HyP-Adpt with two simpler alternatives for combining the prompt embedding: addition and concatenation. While addition and concatenation show some effectiveness, their improvements are marginal. In contrast, the proposed HyP-Adpt yields a significant performance gain, further validating its design.
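The two ablated ideas can be illustrated with a small NumPy sketch. It is a simplified illustration under stated assumptions, not the Med-SA implementation: attention is parameter-free single-head attention, the space and depth branches are merged by addition (an assumption), and the prompt-conditioned variant is a generic hypernetwork-style adapter in the spirit of HyperNetworks (Ha, Dai, and Le 2016) that generates the adapter's down-projection from the prompt embedding. All sizes, weights, and function names are hypothetical. The final lines show the point of the 2D-to-3D design: slice-wise attention cannot propagate information across depth, whereas a transposed depth branch can.

```python
import numpy as np

rng = np.random.default_rng(0)
D, N, C, P, R = 4, 9, 16, 8, 4  # hypothetical: slices, tokens/slice, channels, prompt dim, rank


def attention(x: np.ndarray) -> np.ndarray:
    """Parameter-free scaled dot-product self-attention over axis -2, batched over leading axes."""
    scores = x @ np.swapaxes(x, -1, -2) / np.sqrt(x.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


def slice_wise(x: np.ndarray) -> np.ndarray:
    """Baseline: a (D, N, C) volume processed as D independent 2D slices."""
    return attention(x)


def space_depth(x: np.ndarray) -> np.ndarray:
    """Space branch attends within each slice; the depth branch transposes to (N, D, C)
    so the same attention runs across slices at each token position."""
    depth = np.swapaxes(attention(np.swapaxes(x, 0, 1)), 0, 1)
    return attention(x) + depth


def relu(z: np.ndarray) -> np.ndarray:
    return np.maximum(z, 0.0)


# Three ways to inject a prompt embedding p of shape (P,) into a bottleneck adapter.
W_down = rng.normal(size=(C, R)) * 0.1
W_up = rng.normal(size=(R, C)) * 0.1
W_p = rng.normal(size=(P, C)) * 0.1             # prompt -> token space ("addition")
W_down_cat = rng.normal(size=(C + P, R)) * 0.1  # wider down-projection ("concatenation")
H = rng.normal(size=(P, C * R)) * 0.1           # hypernetwork: prompt -> adapter weights


def adapter_add(x, p):
    return x + relu((x + p @ W_p) @ W_down) @ W_up


def adapter_concat(x, p):
    xp = np.concatenate([x, np.broadcast_to(p, x.shape[:-1] + (P,))], axis=-1)
    return x + relu(xp @ W_down_cat) @ W_up


def adapter_hyper(x, p):
    w_down = (p @ H).reshape(C, R)  # prompt-specific down-projection weights
    return x + relu(x @ w_down) @ W_up


x = rng.normal(size=(D, N, C))
x_edit = x.copy()
x_edit[1] += 1.0  # perturb slice 1 only
print("slice 0 changes (slice-wise): ", not np.allclose(slice_wise(x)[0], slice_wise(x_edit)[0]))
print("slice 0 changes (space-depth):", not np.allclose(space_depth(x)[0], space_depth(x_edit)[0]))
```

In this toy setting, addition and concatenation both amount to adding a prompt-dependent vector before the adapter's activation (concatenation splits into `x @ W_top + p @ W_bottom`), whereas the hypernetwork variant changes the adapter's weights themselves. This multiplicative interaction is one possible intuition for why weight-generating conditioning can be more expressive; it is not a claim established by the ablation.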

Conclusion

In this paper, we extended SAM, a powerful general-purpose segmentation model, to medical image segmentation, introducing Med-SA. By leveraging parameter-efficient adaptation with the simple yet effective SD-Trans and HyP-Adpt, we achieved substantial improvements over the original SAM model. Our approach achieves SOTA performance across 17 medical image segmentation tasks spanning 5 different image modalities. We anticipate that this work will serve as a stepping stone toward foundation models for medical image segmentation and inspire the development of novel fine-tuning techniques.

References

  • Amit et al. (2021) Amit, T.; Nachmani, E.; Shaharbany, T.; and Wolf, L. 2021. SegDiff: Image segmentation with diffusion probabilistic models. arXiv preprint arXiv:2112.00390.
  • Baid et al. (2021) Baid, U.; Ghodasara, S.; Mohan, S.; Bilello, M.; Calabrese, E.; Colak, E.; Farahani, K.; Kalpathy-Cramer, J.; Kitamura, F. C.; Pati, S.; et al. 2021. The RSNA-ASNR-MICCAI BraTS 2021 benchmark on brain tumor segmentation and radiogenomic classification. arXiv preprint arXiv:2107.02314.
  • Chen et al. (2021a) Chen, J.; Lu, Y.; Yu, Q.; Luo, X.; Adeli, E.; Wang, Y.; Lu, L.; Yuille, A. L.; and Zhou, Y. 2021a. TransUNet: Transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306.
  • Chen et al. (2022) Chen, S.; Ge, C.; Tong, Z.; Wang, J.; Song, Y.; Wang, J.; and Luo, P. 2022. AdaptFormer: Adapting vision transformers for scalable visual recognition. arXiv preprint arXiv:2205.13535.
  • Chen et al. (2021b) Chen, X.; Zhao, Z.; Yu, F.; Zhang, Y.; and Duan, M. 2021b. Conditional diffusion for interactive segmentation. In Proceedings of the IEEE/CVF International Conference on Computer Vision, 7345–7354.
  • Chu, Zheng, and Zhou (2021) Chu, C.; Zheng, J.; and Zhou, Y. 2021. Ultrasonic thyroid nodule detection method based on U-Net network. Computer Methods and Programs in Biomedicine, 199: 105906.
  • Deng et al. (2023) Deng, R.; Cui, C.; Liu, Q.; Yao, T.; Remedios, L. W.; Bao, S.; Landman, B. A.; Wheless, L. E.; Coburn, L. A.; Wilson, K. T.; et al. 2023. Segment Anything Model (SAM) for digital pathology: Assess zero-shot segmentation on whole slide imaging. arXiv preprint arXiv:2304.04155.
  • Fang et al. (2022) Fang, H.; Li, F.; Fu, H.; Sun, X.; Cao, X.; Son, J.; Yu, S.; Zhang, M.; Yuan, C.; Bian, C.; et al. 2022. REFUGE2 Challenge: Treasure for Multi-Domain Learning in Glaucoma Assessment. arXiv preprint arXiv:2202.08994.
  • Fang and Yan (2020) Fang, X.; and Yan, P. 2020. Multi-organ segmentation over partially labeled datasets with multi-scale feature abstraction. IEEE Transactions on Medical Imaging, 39(11): 3619–3629.
  • Forte et al. (2020) Forte, M.; Price, B.; Cohen, S.; Xu, N.; and Pitié, F. 2020. Getting to 99% accuracy in interactive segmentation. arXiv preprint arXiv:2003.07932.
  • Gong et al. (2021) Gong, H.; Chen, G.; Wang, R.; Xie, X.; Mao, M.; Yu, Y.; Chen, F.; and Li, G. 2021. Multi-task learning for thyroid nodule segmentation with thyroid region prior. In 2021 IEEE 18th International Symposium on Biomedical Imaging (ISBI), 257–261. IEEE.
  • Grady (2006) Grady, L. 2006. Random walks for image segmentation. IEEE Transactions on Pattern Analysis and Machine Intelligence, 28(11): 1768–1783.
  • Gulshan et al. (2010) Gulshan, V.; Rother, C.; Criminisi, A.; Blake, A.; and Zisserman, A. 2010. Geodesic star convexity for interactive image segmentation. In 2010 IEEE Computer Society Conference on Computer Vision and Pattern Recognition, 3129–3136. IEEE.
  • Ha, Dai, and Le (2016) Ha, D.; Dai, A.; and Le, Q. V. 2016. HyperNetworks. arXiv preprint arXiv:1609.09106.
  • Hatamizadeh et al. (2022a) Hatamizadeh, A.; Nath, V.; Tang, Y.; Yang, D.; Roth, H. R.; and Xu, D. 2022a. Swin UNETR: Swin transformers for semantic segmentation of brain tumors in MRI images. In International MICCAI Brainlesion Workshop, 272–284. Springer.
  • Hatamizadeh et al. (2022b) Hatamizadeh, A.; Tang, Y.; Nath, V.; Yang, D.; Myronenko, A.; Landman, B.; Roth, H. R.; and Xu, D. 2022b. UNETR: Transformers for 3D medical image segmentation. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, 574–584.
  • He et al. (2023) He, S.; Bao, R.; Li, J.; Grant, P. E.; and Ou, Y. 2023. Accuracy of Segment-Anything Model (SAM) in medical image segmentation tasks. arXiv preprint arXiv:2304.09324.
  • He et al. (2022) He, X.; Li, C.; Zhang, P.; Yang, J.; and Wang, X. E. 2022. Parameter-efficient fine-tuning for vision transformers. arXiv preprint arXiv:2203.16329.
  • Hu et al. (2021) Hu, E. J.; Shen, Y.; Wallis, P.; Allen-Zhu, Z.; Li, Y.; Wang, S.; Wang, L.; and Chen, W. 2021. LoRA: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685.
  • Isensee et al. (2021) Isensee, F.; Jaeger, P. F.; Kohl, S. A.; Petersen, J.; and Maier-Hein, K. H. 2021. nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature Methods, 18(2): 203–211.
  • Kim, Lee, and Lee (2010) Kim, T. H.; Lee, K. M.; and Lee, S. U. 2010. Nonparametric higher-order learning for interactive segmentation. In 2010 IEEE Computer Society Conference on Computer Vision and Pattern Recognition, 3201–3208. IEEE.
  • Kirillov et al. (2023) Kirillov, A.; Mintun, E.; Ravi, N.; Mao, H.; Rolland, C.; Gustafson, L.; Xiao, T.; Whitehead, S.; Berg, A. C.; Lo, W.-Y.; et al. 2023. Segment Anything. arXiv preprint arXiv:2304.02643.
  • Li, Chen, and Koltun (2018) Li, Z.; Chen, Q.; and Koltun, V. 2018. Interactive image segmentation with latent diversity. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 577–585.
  • Liew et al. (2019) Liew, J. H.; Cohen, S.; Price, B.; Mai, L.; Ong, S.-H.; and Feng, J. 2019. MultiSeg: Semantically meaningful, scale-diverse segmentations from minimal user input. In Proceedings of the IEEE/CVF International Conference on Computer Vision, 662–670.
  • Lin et al. (2020) Lin, Z.; Zhang, Z.; Chen, L.-Z.; Cheng, M.-M.; and Lu, S.-P. 2020. Interactive image segmentation with first click attention. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 13339–13348.
  • Liu et al. (2019) Liu, Y.; Lu, Z.; Li, J.; Yang, T.; and Yao, C. 2019. Deep image-to-video adaptation and fusion networks for action recognition. IEEE Transactions on Image Processing, 29: 3168–3182.
  • Ma and Wang (2023) Ma, J.; and Wang, B. 2023. Segment Anything in Medical Images. arXiv preprint arXiv:2304.12306.
  • Ma et al. (2017) Ma, J.; Wu, F.; Jiang, T.; Zhao, Q.; and Kong, D. 2017. Ultrasound image-based thyroid nodule automatic segmentation using convolutional neural networks. International Journal of Computer Assisted Radiology and Surgery, 12(11): 1895–1910.
  • Mahadevan, Voigtlaender, and Leibe (2018) Mahadevan, S.; Voigtlaender, P.; and Leibe, B. 2018. Iteratively trained interactive segmentation. arXiv preprint arXiv:1805.04398.
  • Milton (2019) Milton, M. A. A. 2019. Automated skin lesion classification using ensemble of deep neural networks in ISIC 2018: Skin lesion analysis towards melanoma detection challenge. arXiv preprint arXiv:1901.10802.
  • Pedraza et al. (2015) Pedraza, L.; Vargas, C.; Narváez, F.; Durán, O.; Muñoz, E.; and Romero, E. 2015. An open access thyroid ultrasound image database. In 10th International Symposium on Medical Information Processing and Analysis, volume 9287, 92870W. International Society for Optics and Photonics.
  • Raghu et al. (2019) Raghu, M.; Zhang, C.; Kleinberg, J.; and Bengio, S. 2019. Transfusion: Understanding transfer learning for medical imaging. Advances in Neural Information Processing Systems, 32.
  • Rother, Kolmogorov, and Blake (2004) Rother, C.; Kolmogorov, V.; and Blake, A. 2004. "GrabCut": Interactive foreground extraction using iterated graph cuts. ACM Transactions on Graphics (TOG), 23(3): 309–314.
  • Roy et al. (2023) Roy, S.; Wald, T.; Koehler, G.; Rokuss, M. R.; Disch, N.; Holzschuh, J.; Zimmerer, D.; and Maier-Hein, K. H. 2023. SAM.MD: Zero-shot medical image segmentation capabilities of the Segment Anything Model. arXiv preprint arXiv:2304.05396.
  • Sofiiuk, Petrov, and Konushin (2022) Sofiiuk, K.; Petrov, I. A.; and Konushin, A. 2022. Reviving iterative training with mask guidance for interactive segmentation. In 2022 IEEE International Conference on Image Processing (ICIP), 3141–3145. IEEE.
  • Wang et al. (2021a) Wang, J.; Wei, L.; Wang, L.; Zhou, Q.; Zhu, L.; and Qin, J. 2021a. Boundary-aware transformers for skin lesion segmentation. In Medical Image Computing and Computer Assisted Intervention–MICCAI 2021: 24th International Conference, Strasbourg, France, September 27–October 1, 2021, Proceedings, Part I, 206–216. Springer.
  • Wang et al. (2019) Wang, S.; Yu, L.; Li, K.; Yang, X.; Fu, C.-W.; and Heng, P.-A. 2019. Boundary and entropy-driven adversarial learning for fundus image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, 102–110. Springer.
  • Wang et al. (2021b) Wang, W.; Chen, C.; Ding, M.; Yu, H.; Zha, S.; and Li, J. 2021b. TransBTS: Multimodal brain tumor segmentation using transformer. In International Conference on Medical Image Computing and Computer-Assisted Intervention, 109–119. Springer.
  • Wolleb et al. (2021) Wolleb, J.; Sandkühler, R.; Bieder, F.; Valmaggia, P.; and Cattin, P. C. 2021. Diffusion Models for Implicit Image Segmentation Ensembles. arXiv preprint arXiv:2112.03145.
  • Wu et al. (2022) Wu, H.; Chen, S.; Chen, G.; Wang, W.; Lei, B.; and Wen, Z. 2022. FAT-Net: Feature adaptive transformers for automated skin lesion segmentation. Medical Image Analysis, 76: 102327.
  • Xie and Richmond (2018) Xie, Y.; and Richmond, D. 2018. Pre-training on grayscale ImageNet improves medical image classification. In Proceedings of the European Conference on Computer Vision (ECCV) Workshops.
  • Xing et al. (2023) Xing, Z.; Wan, L.; Fu, H.; Yang, G.; and Zhu, L. 2023. Diff-UNet: A Diffusion Embedded Network for Volumetric Segmentation. arXiv preprint arXiv:2303.10326.
  • Xu et al. (2016) Xu, N.; Price, B.; Cohen, S.; Yang, J.; and Huang, T. S. 2016. Deep interactive object selection. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 373–381.
  • Yu et al. (2019) Yu, S.; Xiao, D.; Frost, S.; and Kanagasingam, Y. 2019. Robust optic disc and cup segmentation with deep learning for glaucoma detection. Computerized Medical Imaging and Graphics, 74: 61–71.
  • Zaken, Ravfogel, and Goldberg (2021) Zaken, E. B.; Ravfogel, S.; and Goldberg, Y. 2021. BitFit: Simple parameter-efficient fine-tuning for transformer-based masked language-models. arXiv preprint arXiv:2106.10199.