Skip to content

Commit

Permalink
add filters
Browse files Browse the repository at this point in the history
  • Loading branch information
alvin319 committed May 29, 2023
1 parent 4128486 commit 0aaa982
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 0 deletions.
57 changes: 57 additions & 0 deletions filters/highly_duplicated_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from collections import Counter
from typing import Callable, List

import pandas as pd

def _concat_token_indices(token_indices: List[int], delimiter: str = '_') -> str:
    """
    Join token indices into one delimiter-separated string.

    Args:
        token_indices (List[int]): Token indices to join, in order.
        delimiter (str, optional): Separator placed between consecutive indices. Defaults to '_'.

    Returns:
        str: The string form of every index, joined by ``delimiter``.
    """
    # str.join only accepts strings, so each index is converted first.
    index_strings = map(str, token_indices)
    return delimiter.join(index_strings)

def generate_token_string_histogram(token_series: pd.Series, delimiter: str = '_') -> Counter:
    """
    Count how often each token-index sequence occurs in a Pandas Series.

    Every element of the series is turned into a single string by
    ``_concat_token_indices``; the histogram is keyed by those strings.

    Args:
        token_series (pd.Series): Series whose elements are lists of token indices.
        delimiter (str, optional): Delimiter passed to the concatenation helper. Defaults to '_'.

    Returns:
        Counter: Mapping from concatenated token string to its number of occurrences.
    """
    def _to_token_string(indices):
        # Binds the caller's delimiter so the helper can be applied element-wise.
        return _concat_token_indices(indices, delimiter=delimiter)

    token_strings = token_series.apply(_to_token_string)
    return Counter(token_strings)

def get_highly_duplicated_filter_func(histogram: Counter, frequency_threshold: int = 1, delimiter: str = '_') -> Callable[[List[int]], bool]:
    """
    Build a predicate that flags token-index lists occurring more often than a threshold.

    Args:
        histogram (Counter): Counts keyed by concatenated token-index strings.
        frequency_threshold (int, optional): A list is flagged only when its count is strictly
            greater than this value. Defaults to 1.
        delimiter (str, optional): Delimiter used to build the lookup key; it should match the
            one used when the histogram was generated. Defaults to '_'.

    Returns:
        Callable[[List[int]], bool]: Predicate returning True for highly duplicated token-index lists.
    """
    def _highly_duplicated_filter_func(token_indices: List[int]) -> bool:
        """
        Report whether a list of token indices exceeds the frequency threshold.

        Args:
            token_indices (List[int]): Token indices to look up in the histogram.

        Returns:
            bool: True when the histogram count is strictly above the threshold, False otherwise.
        """
        lookup_key = _concat_token_indices(token_indices, delimiter=delimiter)
        # Indexing a Counter with a missing key yields 0, so unseen lists compare as count 0.
        occurrences = histogram[lookup_key]
        return occurrences > frequency_threshold

    return _highly_duplicated_filter_func
30 changes: 30 additions & 0 deletions filters/test_highly_duplicated_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pandas as pd

from .highly_duplicated_filter import get_highly_duplicated_filter_func, generate_token_string_histogram

def test_highly_duplicated_filter_on_seen_indices():
    """A sequence that appears twice is flagged when the threshold is 1."""
    data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]])
    histogram = generate_token_string_histogram(data)
    threshold = 1
    filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold)

    sample = [4, 5, 6]
    # Identity check (not `== True`) pins the bool return rather than any truthy value.
    assert filter_func(sample) is True

def test_highly_duplicated_filter_on_unseen_indices():
    """A sequence absent from the histogram is never flagged."""
    data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]])
    histogram = generate_token_string_histogram(data)
    threshold = 1
    filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold)

    sample = [7, 8, 9]
    # Identity check (not `== False`) pins the bool return rather than any falsy value.
    assert filter_func(sample) is False

def test_highly_duplicated_filter_on_infrequent_indices():
    """A sequence whose count equals the threshold is not flagged (comparison is strict)."""
    data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]])
    histogram = generate_token_string_histogram(data)
    threshold = 2
    filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold)

    sample = [4, 5, 6]
    # Identity check (not `== False`) pins the bool return rather than any falsy value.
    assert filter_func(sample) is False

0 comments on commit 0aaa982

Please sign in to comment.