Skip to content

Commit

Permalink
Parameters for hedging language are now editable (arthur-ai#100)
Browse files Browse the repository at this point in the history
* Parameters for hedging language now editable

* Black reformatting

* Added type hints and hedging lang in to_dict
  • Loading branch information
Mymoza committed Nov 28, 2023
1 parent e983a3d commit c5e453e
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions arthur_bench/scoring/hedging_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

DEFAULT_MODEL = "microsoft/deberta-v3-base"

# [TODO] need to make these editable by user
DEFAULT_HEDGE = (
"As an AI language model, I don't have personal opinions, emotions, or beliefs."
)
Expand All @@ -20,6 +19,20 @@ class HedgingLanguage(Scorer):
model output.
"""

def __init__(
self, model_type: str = DEFAULT_MODEL, hedging_language: str = DEFAULT_HEDGE
):
"""
Hedging Language score implementation.
:param model_type: the underlying language model to extract embeddings from
:param hedging_language: reference hedging language used by an llm
"""
self.hedging_language = hedging_language

with suppress_warnings("transformers"):
self.scorer = BERTScorer(lang="en", model_type=model_type)

@staticmethod
def name() -> str:
return "hedging_language"
Expand All @@ -28,12 +41,11 @@ def name() -> str:
def requires_reference() -> bool:
return False

def __init__(self):
with suppress_warnings("transformers"):
self.scorer = BERTScorer(lang="en", model_type=DEFAULT_MODEL)

def to_dict(self, warn=False):
return {"model_type": self.scorer.model_type}
return {
"model_type": self.scorer.model_type,
"hedging_language": self.hedging_language,
}

def run_batch(
self,
Expand All @@ -43,7 +55,7 @@ def run_batch(
context_batch: Optional[List[str]] = None,
) -> List[float]:
# convert reference hedge to list
reference_batch = [DEFAULT_HEDGE] * len(candidate_batch)
reference_batch = [self.hedging_language] * len(candidate_batch)

# get precision, recall, and F1 score from bert_score package
p, r, f = self.scorer.score(candidate_batch, reference_batch, verbose=False)
Expand Down

0 comments on commit c5e453e

Please sign in to comment.