-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
87 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from collections import Counter | ||
from typing import Callable, List | ||
|
||
import pandas as pd | ||
|
||
def _concat_token_indices(token_indices: List[int], delimiter: str = '_') -> str:
    """
    Concatenates a list of token indices into a single string.

    Args:
        token_indices (List[int]): Token indices to concatenate. Each element is converted with ``str``.
        delimiter (str, optional): Delimiter placed between consecutive indices. Defaults to '_'.

    Returns:
        str: The stringified indices joined by ``delimiter``. An empty list yields an empty string.
    """
    # map() feeds join() directly, so no throwaway list is built.
    return delimiter.join(map(str, token_indices))
|
||
def generate_token_string_histogram(token_series: pd.Series, delimiter: str = '_') -> Counter:
    """
    Generates a histogram from a Pandas Series of token index lists.

    Every element of the series is turned into a string key with ``_concat_token_indices``, and the
    occurrences of each key are counted.

    Args:
        token_series (pd.Series): Series whose elements are lists of token indices.
        delimiter (str, optional): Delimiter to use for concatenation. Defaults to '_'.

    Returns:
        Counter: Maps each concatenated token string to the number of elements that produced it.
        An empty series yields an empty Counter.
    """
    # Counting from a generator avoids materialising an intermediate Series of strings.
    return Counter(
        _concat_token_indices(token_indices, delimiter=delimiter)
        for token_indices in token_series
    )
|
||
def get_highly_duplicated_filter_func(histogram: Counter, frequency_threshold: int = 1, delimiter: str = '_') -> Callable[[List[int]], bool]:
    """
    Generates a filter function that checks if a list of token indices is highly duplicated based on a threshold.

    Args:
        histogram (Counter): Histogram of strings of token indices.
        frequency_threshold (int, optional): A sample counts as highly duplicated when its histogram
            count is strictly greater than this value. Defaults to 1.
        delimiter (str, optional): Delimiter to use for concatenation. It has to be the delimiter the
            histogram keys were built with, otherwise no key will match. Defaults to '_'.

    Returns:
        Callable[[List[int]], bool]: Filter function that checks if a list of token indices is highly
        duplicated based on the threshold.
    """
    def _highly_duplicated_filter_func(token_indices: List[int]) -> bool:
        """
        Checks if a list of token indices is highly duplicated.

        Args:
            token_indices (List[int]): List of token indices to check.

        Returns:
            bool: True if the histogram count of the list exceeds the threshold, False otherwise
            (including when the list never appears in the histogram).
        """
        token_string = _concat_token_indices(token_indices, delimiter=delimiter)
        # .get() treats unseen keys as count 0 for any mapping: it neither raises on a plain dict
        # nor inserts a new entry into a defaultdict.
        return histogram.get(token_string, 0) > frequency_threshold

    return _highly_duplicated_filter_func
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import pandas as pd | ||
|
||
from .highly_duplicated_filter import get_highly_duplicated_filter_func, generate_token_string_histogram | ||
|
||
def test_highly_duplicated_filter_on_seen_indices():
    """A sample that occurs more often than the threshold is reported as highly duplicated."""
    data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]])
    histogram = generate_token_string_histogram(data)
    threshold = 1
    filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold)

    # [4, 5, 6] occurs twice, which is above the threshold of 1.
    sample = [4, 5, 6]
    assert filter_func(sample)
|
||
def test_highly_duplicated_filter_on_unseen_indices():
    """A sample that never occurs in the histogram is not reported as highly duplicated."""
    data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]])
    histogram = generate_token_string_histogram(data)
    threshold = 1
    filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold)

    # [7, 8, 9] is absent from the data, so its count is 0.
    sample = [7, 8, 9]
    assert not filter_func(sample)
|
||
def test_highly_duplicated_filter_on_infrequent_indices():
    """A sample whose count only equals the threshold is not reported as highly duplicated."""
    data = pd.Series([[1, 2, 3], [4, 5, 6], [4, 5, 6]])
    histogram = generate_token_string_histogram(data)
    threshold = 2
    filter_func = get_highly_duplicated_filter_func(histogram, frequency_threshold=threshold)

    # [4, 5, 6] occurs twice; the comparison is strict, so a count of 2 does not exceed 2.
    sample = [4, 5, 6]
    assert not filter_func(sample)