Skip to content

Commit

Permalink
Added polygon filter method to reduce number of polygon given gven fr…
Browse files Browse the repository at this point in the history
…om SAM predictions. (#142)
  • Loading branch information
samz5320 authored Aug 4, 2023
1 parent 5846726 commit 27baddc
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
24 changes: 24 additions & 0 deletions label_anything/sam/filter_poly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json
import math

class NearNeighborRemover:
def __init__(self,distance_threshold):
self.distance_threshold = distance_threshold

def calculate_distance(self, point1, point2):
x1, y1 = point1
x2, y2 = point2
return math.sqrt((x2 - x1)**2 + (y2 - y1)**2)

def remove_near_neighbors(self, points):
filtered_points = [points[0]] # Add the first point to the filtered list
for i in range(1, len(points)):
# Calculate the distance between the current point and the last added point
distance = self.calculate_distance(points[i], filtered_points[-1])
# If the distance is above the threshold, add the current point to the filtered list
if distance >= self.distance_threshold:
filtered_points.append(points[i])
return filtered_points



5 changes: 3 additions & 2 deletions label_anything/sam/mmdetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from label_studio_ml.utils import (DATA_UNDEFINED_NAME, get_image_size,
get_single_tag_keys)
from label_studio_tools.core.utils.io import get_data_dir
from filter_poly import NearNeighborRemover
# from mmdet.apis import inference_detector, init_detector

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -268,7 +269,7 @@ def predict(self, tasks, **kwargs):
points.append([float(x)/original_width*100,
float(y)/original_height * 100])
points_list.extend(points)

filterd_points=NearNeighborRemover(distance_threshold=0.4).remove_near_neighbors(points_list) # remove near neighbors (increase distance_threshold to reduce more points)
# interval = points_list.__len__()//128

# points_list = points_list[::points_list.__len__()//40]
Expand All @@ -279,7 +280,7 @@ def predict(self, tasks, **kwargs):
"original_height": original_height,
# "image_rotation": 0,
"value": {
"points": points_list,
"points": filterd_points,
"polygonlabels": [output_label],
},
"type": "polygonlabels",
Expand Down

0 comments on commit 27baddc

Please sign in to comment.