Skip to content

Commit

Permalink
Fix openai eval bug (#78)
Browse files Browse the repository at this point in the history
* fix bug

* make blue happy

---------

Co-authored-by: wangyuxin <[email protected]>
  • Loading branch information
wangyuxinwhy and wangyuxin committed Aug 4, 2023
1 parent 72fc675 commit 57b0dfa
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mteb-zh/mteb_zh/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,16 @@ def __init__(
self,
api_key: Optional[str] = None,
model_name: str = 'text-embedding-ada-002',
max_length: int = 4000,
) -> None:
if api_key is not None:
openai.api_key = api_key
self._client = openai.Embedding.create
self.model_name = model_name
self.max_length = max_length

def encode(self, sentences: list[str], batch_size: int = 32, **kwargs) -> list[np.ndarray]:
sentences = [sentence[: self.max_length] for sentence in sentences]
all_embeddings = []
for batch in tqdm(
generate_batch(sentences, batch_size),
Expand All @@ -128,15 +131,17 @@ def encode(self, sentences: list[str], batch_size: int = 32, **kwargs) -> list[n


class AzureModel:
def __init__(self, model_name: str = 'text-embedding-ada-002') -> None:
def __init__(self, model_name: str = 'text-embedding-ada-002', max_length: int = 4000) -> None:
openai.api_type = 'azure'
openai.api_key = os.environ['AZURE_API_KEY']
openai.api_base = os.environ['AZURE_API_BASE']
openai.api_version = '2023-03-15-preview'
self._client = openai.Embedding.create
self.model_name = model_name
self.max_length = max_length

def encode(self, sentences: list[str], batch_size: int = 32, **kwargs) -> list[np.ndarray]:
sentences = [sentence[: self.max_length] for sentence in sentences]
all_embeddings = []
for text in tqdm(sentences):
output = self._client(input=text, engine=self.model_name)
Expand Down

0 comments on commit 57b0dfa

Please sign in to comment.