💫 Fix loading of multiple vector models #2158

Merged · 10 commits · Mar 28, 2018
Set pretrained_vectors in begin_training
honnibal committed Mar 28, 2018
commit 9bf6e93b3e61211468347fcfe6f3241c9d8d2e3b
spacy/pipeline.pyx: 6 changes (5 additions, 1 deletion)
@@ -516,6 +516,7 @@ class Tagger(Pipe):
             vocab.morphology = Morphology(vocab.strings, new_tag_map,
                                           vocab.morphology.lemmatizer,
                                           exc=vocab.morphology.exc)
+        self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
         if self.model is True:
             self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
         link_vectors_to_models(self.vocab)
@@ -910,12 +911,15 @@ class TextCategorizer(Pipe):
         self.labels.append(label)
         return 1

-    def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None):
+    def begin_training(self, gold_tuples=tuple(), pipeline=None, sgd=None,
+                       **kwargs):
         if pipeline and getattr(pipeline[0], 'name', None) == 'tensorizer':
             token_vector_width = pipeline[0].model.nO
         else:
             token_vector_width = 64
+
         if self.model is True:
+            self.cfg['pretrained_vectors'] = kwargs.get('pretrained_vectors')
             self.model = self.Model(len(self.labels), token_vector_width,
                                     **self.cfg)
             link_vectors_to_models(self.vocab)
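Both hunks follow the same pattern: begin_training now accepts the vectors name as a keyword argument and records it in the component's cfg, so the serialized model remembers which vectors table it was trained with. Below is a minimal sketch of the component-level call this enables, assuming a spaCy 2.0.x install at this commit; the setup (a blank English pipeline, a 'POSITIVE' label) is illustrative and not taken from the PR:

    # Illustrative sketch only: exercising the keyword path added in this commit.
    from spacy.lang.en import English

    nlp = English()
    textcat = nlp.create_pipe('textcat')
    textcat.add_label('POSITIVE')

    # The component records which vectors table it was built against;
    # the name is kept in cfg and is therefore serialized with the model.
    textcat.begin_training(pretrained_vectors=nlp.vocab.vectors.name)
    assert textcat.cfg['pretrained_vectors'] == nlp.vocab.vectors.name

Recording the name (rather than the vectors themselves) is what lets several models share vectors packages and still re-link to the right table when more than one set of vectors is loaded, which is the point of this PR.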
spacy/syntax/nn_parser.pyx: 2 changes (0 additions, 2 deletions)
@@ -896,7 +896,6 @@ cdef class Parser:
         # TODO: Remove this once we don't have to handle previous models
         if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
             self.cfg['pretrained_vectors'] = self.vocab.vectors.name
-        print("Create parser model", self.cfg)
         path = util.ensure_path(path)
         if self.model is True:
             self.model, cfg = self.Model(**self.cfg)
@@ -944,7 +943,6 @@ cdef class Parser:
         # TODO: Remove this once we don't have to handle previous models
         if 'pretrained_dims' in self.cfg and 'pretrained_vectors' not in self.cfg:
             self.cfg['pretrained_vectors'] = self.vocab.vectors.name
-        print("Create parser model", self.cfg)
         if self.model is True:
             self.model, cfg = self.Model(**self.cfg)
         else:
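Beyond removing the leftover debug prints, the unchanged lines around both deletions show the backward-compatibility path for models saved before cfg carried a vectors name: if an older cfg only has 'pretrained_dims', the parser falls back to the name of the vectors currently attached to the vocab. A standalone sketch of that fallback, using a made-up cfg dict and vectors name purely for illustration:

    # Hypothetical, standalone illustration of the fallback shown above;
    # 'old_cfg' and 'current_vectors_name' are invented example values.
    old_cfg = {'pretrained_dims': 300}            # cfg from a pre-existing model
    current_vectors_name = 'en_core_web_md.vectors'

    if 'pretrained_dims' in old_cfg and 'pretrained_vectors' not in old_cfg:
        # Assume the vectors currently loaded in the vocab are the right ones.
        old_cfg['pretrained_vectors'] = current_vectors_name

    assert old_cfg['pretrained_vectors'] == 'en_core_web_md.vectors'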