Skip to content

Commit

Permalink
build_sam_with_ckpt
Browse files Browse the repository at this point in the history
  • Loading branch information
KabirSubbiah committed Aug 4, 2023
1 parent 2ebcd00 commit f3c4058
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ text_prompt = "wheel"
masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
```

Use with custom checkpoint:

First download a model checkpoint.

```python
from PIL import Image
from lang_sam import LangSAM

model = LangSAM("<model_type>", "<path/to/checkpoint>")
image_pil = Image.open("./assets/car.jpeg").convert("RGB")
text_prompt = "wheel"
masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
```

## Examples

![car.png](/assets/outputs/car.png)
Expand Down
17 changes: 15 additions & 2 deletions lang_sam/lang_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,24 @@ def transform_image(image) -> torch.Tensor:

class LangSAM():

def __init__(self, sam_type="vit_h"):
def __init__(self, sam_type="vit_h", ckpt_path=None):
self.sam_type = sam_type
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.build_groundingdino()
self.build_sam(sam_type)
if ckpt_path == None:
self.build_sam(sam_type)
else:
self.build_sam_with_ckpt(sam_type, ckpt_path)

def build_sam_with_ckpt(self, sam_type, ckpt_path):
try:
sam = sam_model_registry[sam_type](ckpt_path)
except:
raise ValueError(f"Problem loading SAM. Your model type: {sam_type} \
should match your checkpoint path: {ckpt_path}. Recommend calling LangSAM \
using matching model type AND checkpoint path")
sam.to(device=self.device)
self.sam = SamPredictor(sam)

def build_sam(self, sam_type):
checkpoint_url = SAM_MODELS[sam_type]
Expand Down

0 comments on commit f3c4058

Please sign in to comment.