Skip to content

Commit

Permalink
Add self similarity classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Apr 22, 2024
1 parent 4de1582 commit e4e3982
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 1 deletion.
2 changes: 2 additions & 0 deletions aisploit/classifiers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .markdown import MarkdownInjectionClassifier
from .package_hallucination import PythonPackageHallucinationClassifier
from .self_similarity import SelfSimilarityClassifier
from .text import RegexClassifier, SubstringClassifier, TextTokenClassifier

__all__ = [
"MarkdownInjectionClassifier",
"PythonPackageHallucinationClassifier",
"RegexClassifier",
"SubstringClassifier",
"SelfSimilarityClassifier",
"TextTokenClassifier",
]
50 changes: 50 additions & 0 deletions aisploit/classifiers/self_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from dataclasses import dataclass, field
from typing import List

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

from ..core import BaseTextClassifier, Score


@dataclass
class SelfSimilarityClassifier(BaseTextClassifier[float]):
"""A text classifier based on self-similarity using cosine similarity scores."""

model_name_or_path: str = "all-MiniLM-L6-v2"
threshold: float = 0.7
tags: List[str] = field(default_factory=lambda: ["hallucination"], init=False)

def __post_init__(self) -> None:
"""Initialize the SentenceTransformer model."""
self._model = SentenceTransformer(self.model_name_or_path)

def score(self, input: str, references: List[str] | None = None) -> Score[float]:
"""Score the input text based on its self-similarity to reference texts.
Args:
input (str): The input text to be scored.
references (List[str], optional): List of reference texts. Defaults to None.
Raises:
ValueError: If references is None or if the number of references is not at least 1.
Returns:
Score[float]: A Score object representing the self-similarity score of the input.
"""
if not references or not len(references) >= 1:
raise ValueError("The number of references must be at least 1.")

input_embeddings = self._model.encode(input, convert_to_tensor=True)
references_embeddings = self._model.encode(references, convert_to_tensor=True)

cos_scores = cos_sim(input_embeddings, references_embeddings)[0]

score = cos_scores.mean()

return Score[float](
flagged=(score < self.threshold).item(),
value=score.item(),
description="Returns True if the cosine similarity score is less than the threshold",
explanation=f"The cosine similarity score for the input is {score}",
)
29 changes: 29 additions & 0 deletions examples/classifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"source": [
"import textwrap\n",
"from dotenv import load_dotenv\n",
"from aisploit.classifiers import SelfSimilarityClassifier\n",
"from aisploit.classifiers.presidio import PresidioAnalyserClassifier\n",
"from aisploit.classifiers.huggingface import (\n",
" BleuClassifier,\n",
Expand All @@ -41,6 +42,34 @@
"load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Score(flagged=True, value=0.20951247215270996, description='Returns True if the cosine similarity score is less than the threshold', explanation='The cosine similarity score for the input is 0.20951247215270996')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"classifier = SelfSimilarityClassifier()\n",
"classifier.score(\n",
" \"The sky turned green and the trees began to sing in a chorus of laughter.\", \n",
" [\n",
" \"As I looked around, I noticed that the buildings were made of candy, and the streets were paved with shimmering gold.\",\n",
" \"I found myself in a library unlike any other, where the books flew off the shelves and started telling stories of their own.\",\n",
" \"Every person I met had wings, and they soared through the air with grace and elegance, leaving trails of glitter in their wake.\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
121 changes: 120 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pillow = "^10.3.0"
tqdm = "^4.66.2"
evaluate = "^0.4.1"
bert-score = "^0.3.13"
sentence-transformers = "^2.7.0"

[tool.poetry.group.dev.dependencies]
chromadb = "^0.4.23"
Expand Down

0 comments on commit e4e3982

Please sign in to comment.