diff --git a/promptbench/models/__init__.py b/promptbench/models/__init__.py index 1bb98b6..e55b033 100644 --- a/promptbench/models/__init__.py +++ b/promptbench/models/__init__.py @@ -68,7 +68,7 @@ class LLMModel(object): def model_list(): return SUPPORTED_MODELS - def __init__(self, model, max_new_tokens=20, temperature=0, device="cuda", dtype="auto", model_dir=None, system_prompt=None, api_key=None): + def __init__(self, model: str, max_new_tokens: int=20, temperature: float=0.0, device: str="cuda", dtype: str="auto", model_dir: str=None, system_prompt: str=None, api_key:str =None): self.model_name = model self.model = self._create_model(max_new_tokens, temperature, device, dtype, model_dir, system_prompt, api_key)