Skip to content

Commit

Permalink
Merge pull request #14 from luca-medeiros/f/state_caching
Browse files Browse the repository at this point in the history
  • Loading branch information
luca-medeiros committed May 16, 2023
2 parents 623d338 + 6b87006 commit 667660d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ repos:
- id: check-executables-have-shebangs
- id: check-case-conflict
- id: check-added-large-files
args: ['--maxkb=350', '--enforce-all']
args: ['--maxkb=3500', '--enforce-all']
- id: detect-private-key

- repo: https://github.com/commitizen-tools/commitizen
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Language Segment-Anything is an open-source project that combines the power of i
pip install torch torchvision
pip install -U git+https://github.com/luca-medeiros/lang-segment-anything.git
```

Or
Clone the repository and nstall the required packages:

Expand All @@ -43,12 +44,12 @@ To run the Lightning AI APP:
Use as a library:

```python
from PIL import Image
from PIL import Image
from lang_sam import LangSAM

model = LangSAM()
image_pil = Image.open('./assets/car.jpeg').convert("RGB")
text_prompt = 'wheel'
image_pil = Image.open("./assets/car.jpeg").convert("RGB")
text_prompt = "wheel"
masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
```

Expand Down
14 changes: 5 additions & 9 deletions lang_sam/lang_sam.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from urllib import request

import groundingdino.datasets.transforms as T
import numpy as np
Expand Down Expand Up @@ -57,17 +56,14 @@ def __init__(self, sam_type="vit_h"):
self.build_sam(sam_type)

def build_sam(self, sam_type):
url = SAM_MODELS[sam_type]
sam_checkpoint = os.path.join(CACHE_PATH, os.path.basename(url))
if not os.path.exists(sam_checkpoint):
if not os.path.exists(CACHE_PATH):
os.makedirs(CACHE_PATH)
request.urlretrieve(url, sam_checkpoint)
checkpoint_url = SAM_MODELS[sam_type]
try:
sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
sam = sam_model_registry[sam_type]()
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
sam.load_state_dict(state_dict, strict=True)
except:
raise ValueError(f"Problem loading SAM please make sure you have the right model type: {sam_type} \
and a working checkpoint: {sam_checkpoint}. Recommend deleting the checkpoint and \
and a working checkpoint: {checkpoint_url}. Recommend deleting the checkpoint and \
re-downloading it.")
sam.to(device=self.device)
self.sam = SamPredictor(sam)
Expand Down

0 comments on commit 667660d

Please sign in to comment.