Skip to content

Commit

Permalink
Merge pull request #28 from KabirSubbiah/main
Browse files Browse the repository at this point in the history
build_sam_with_ckpt
  • Loading branch information
luca-medeiros committed Aug 23, 2023
2 parents 2ebcd00 + 6c5482e commit fcdb92f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 15 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ text_prompt = "wheel"
masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
```

Use with custom checkpoint:

First download a model checkpoint.

```python
from PIL import Image
from lang_sam import LangSAM

model = LangSAM("<model_type>", "<path/to/checkpoint>")
image_pil = Image.open("./assets/car.jpeg").convert("RGB")
text_prompt = "wheel"
masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
```

## Examples

![car.png](/assets/outputs/car.png)
Expand Down
43 changes: 28 additions & 15 deletions lang_sam/lang_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,24 +49,37 @@ def transform_image(image) -> torch.Tensor:

class LangSAM():

def __init__(self, sam_type="vit_h"):
def __init__(self, sam_type="vit_h", ckpt_path=None):
self.sam_type = sam_type
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.build_groundingdino()
self.build_sam(sam_type)

def build_sam(self, sam_type):
checkpoint_url = SAM_MODELS[sam_type]
try:
sam = sam_model_registry[sam_type]()
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
sam.load_state_dict(state_dict, strict=True)
except:
raise ValueError(f"Problem loading SAM please make sure you have the right model type: {sam_type} \
and a working checkpoint: {checkpoint_url}. Recommend deleting the checkpoint and \
re-downloading it.")
sam.to(device=self.device)
self.sam = SamPredictor(sam)
self.build_sam(ckpt_path)

def build_sam(self, ckpt_path):
if self.sam_type is None or ckpt_path is None:
if self.sam_type is None:
print("No sam type indicated. Using vit_h by default.")
self.sam_type = "vit_h"
checkpoint_url = SAM_MODELS[self.sam_type]
try:
sam = sam_model_registry[self.sam_type]()
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
sam.load_state_dict(state_dict, strict=True)
except:
raise ValueError(f"Problem loading SAM please make sure you have the right model type: {self.sam_type} \
and a working checkpoint: {checkpoint_url}. Recommend deleting the checkpoint and \
re-downloading it.")
sam.to(device=self.device)
self.sam = SamPredictor(sam)
else:
try:
sam = sam_model_registry[self.sam_type](ckpt_path)
except:
raise ValueError(f"Problem loading SAM. Your model type: {self.sam_type} \
should match your checkpoint path: {ckpt_path}. Recommend calling LangSAM \
using matching model type AND checkpoint path")
sam.to(device=self.device)
self.sam = SamPredictor(sam)

def build_groundingdino(self):
ckpt_repo_id = "ShilongLiu/GroundingDINO"
Expand Down

0 comments on commit fcdb92f

Please sign in to comment.