Skip to content

Commit

Permalink
Merge pull request #2 from artyaltanzaya/main
Browse files Browse the repository at this point in the history
Cmdclass install doesn't run on pip install.
  • Loading branch information
capjamesg committed Jun 22, 2023
2 parents 095bba7 + 2a03c5f commit 316e10f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 28 deletions.
2 changes: 1 addition & 1 deletion autodistill_detic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from autodistill_detic.detic_model import DETIC

__version__ = "0.1.1"
__version__ = "0.1.3"
42 changes: 36 additions & 6 deletions autodistill_detic/detic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
import supervision as sv
import torch
from autodistill.detection import CaptionOntology, DetectionBaseModel

import subprocess
import argparse
import multiprocessing as mp
import os
import sys

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger

VOCAB = "custom"
CONFIDENCE_THRESHOLD = 0.3
Expand Down Expand Up @@ -66,6 +63,40 @@ def load_detic_model(ontology):
HOME = os.path.expanduser("~")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def check_dependencies():
# Create the ~/.cache/autodistill directory if it doesn't exist
autodistill_dir = os.path.expanduser("~/.cache/autodistill")
os.makedirs(autodistill_dir, exist_ok=True)

os.chdir(autodistill_dir)

try:
import detectron2
except ImportError:
subprocess.run(["pip", "install", "git+https://github.com/facebookresearch/detectron2.git"])

# Check if Detic is installed
detic_path = os.path.join(autodistill_dir, "Detic")
if not os.path.isdir(detic_path):
subprocess.run(["git", "clone", "https://github.com/facebookresearch/Detic.git", "--recurse-submodules"])

os.chdir(detic_path)

subprocess.run(["pip", "install", "-r", "requirements.txt"])

models_dir = os.path.join(detic_path, "models")
os.makedirs(models_dir, exist_ok=True)

model_url = "https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth"
model_path = os.path.join(models_dir, "Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth")
subprocess.run(["wget", model_url, "-O", model_path])

check_dependencies()

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger

@dataclass
class DETIC(DetectionBaseModel):
ontology: CaptionOntology
Expand All @@ -75,7 +106,6 @@ def __init__(self, ontology: CaptionOntology):
original_dir = os.getcwd()

sys.path.insert(0, HOME + "/.cache/autodistill/Detic/third_party/CenterNet2/")

sys.path.insert(0, HOME + "/.cache/autodistill/Detic/")
os.chdir(HOME + "/.cache/autodistill/Detic/")

Expand Down Expand Up @@ -113,4 +143,4 @@ def predict(self, input: str) -> sv.Detections:
xyxy=np.array(final_pred_boxes),
class_id=np.array(final_pred_classes),
confidence=np.array(final_pred_scores),
)
)
22 changes: 1 addition & 21 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,6 @@
with open("README.md", "r") as fh:
long_description = fh.read()

class AutodistillDetic(install):
def run(self):
install.run(self)
installation_commands = """
mkdir -p ~/.cache/autodistill/ &&
cd ~/.cache/autodistill/ &&
pip install 'git+https://github.com/facebookresearch/detectron2.git' &&
git clone https://github.com/facebookresearch/Detic.git --recurse-submodules &&
cd Detic &&
pip install -r requirements.txt &&
mkdir models &&
wget https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth -O models/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth
"""

subprocess.run(installation_commands, shell=True, executable="/bin/bash")


setuptools.setup(
name="autodistill_detic",
version=version,
Expand All @@ -43,9 +26,6 @@ def run(self):
"numpy",
"autodistill",
],
cmdclass={
'install': AutodistillDetic,
},
packages=find_packages(exclude=("tests",)),
extras_require={
"dev": ["flake8", "black==22.3.0", "isort", "twine", "pytest", "wheel"],
Expand All @@ -56,4 +36,4 @@ def run(self):
"Operating System :: OS Independent",
],
python_requires=">=3.7",
)
)

0 comments on commit 316e10f

Please sign in to comment.