Skip to content

Commit

Permalink
fix comparison of transformers version
Browse files Browse the repository at this point in the history
  • Loading branch information
João Nadkarni committed Jan 9, 2022
1 parent c2b1a47 commit e48e73f
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/ecco/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from operator import attrgetter
import re
from ecco.util import is_partial_token, strip_tokenizer_prefix
from packaging import version


class LM(object):
Expand Down Expand Up @@ -199,7 +200,7 @@ def generate(self, input_str: str,
# Get decoder input ids
if self.model_type == 'enc-dec': # FIXME: only done because causal LMs like GPT-2 have the _prepare_decoder_input_ids_for_generation method but do not use it
assert len(input_ids.size()) == 2 # will break otherwise
if transformers.__version__ >= '4.13': # ALSO FIXME: awful hack. But seems to work?
if version.parse(transformers.__version__) >= version.parse('4.13'): # ALSO FIXME: awful hack. But seems to work?
decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids.shape[0], None, None)
else:
decoder_input_ids = self.model._prepare_decoder_input_ids_for_generation(input_ids, None, None)
Expand Down

0 comments on commit e48e73f

Please sign in to comment.