Merge pull request #92 from NatLibFi/fasttext-backend
Add fastText backend. Fixes #74
Showing 12 changed files with 246 additions and 0 deletions.
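For orientation before the diffs: the new backend registers under the id "fasttext" and is driven through the common Annif backend interface. A minimal usage sketch, with parameter values borrowed from the unit tests added in this commit; the datadir path and variable names are placeholders, not from the source:

# Hypothetical direct use of the new backend type; parameter values mirror
# the unit tests in this commit, the datadir path is a placeholder.
import annif.backend

fasttext_type = annif.backend.get_backend_type('fasttext')
ft_backend = fasttext_type(
    backend_id='fasttext',
    params={'limit': 50, 'chunksize': 1, 'dim': 100,
            'lr': 0.25, 'epoch': 20, 'loss': 'hs'},
    datadir='/tmp/annif-data')

# load_subjects() writes train.txt and trains model.bin under the datadir;
# analyze() then returns a ranked list of AnalysisHit objects.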
@@ -0,0 +1,6 @@
build:
    environment:
        python: 3.6.3
    dependencies:
        before:
            - pip install cython
@@ -0,0 +1,132 @@
"""Annif backend using the fastText classifier"""

import collections
import os.path
import annif.util
from annif.hit import AnalysisHit
import fasttext
from . import backend


class FastTextBackend(backend.AnnifBackend):
    """fastText backend for Annif"""

    name = "fasttext"
    needs_subject_index = True

    FASTTEXT_PARAMS = (
        'lr',
        'lr_update_rate',
        'dim',
        'ws',
        'epoch',
        'min_count',
        'neg',
        'word_ngrams',
        'loss',
        'bucket',
        'minn',
        'maxn',
        'thread',
        't'
    )

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        if self._model is None:
            path = os.path.join(self._get_datadir(), 'model.bin')
            self.debug('loading fastText model from {}'.format(path))
            self._model = fasttext.load_model(path)
            self.debug('loaded model {}'.format(str(self._model)))
            self.debug('dim: {}'.format(self._model.dim))
            self.debug('epoch: {}'.format(self._model.epoch))
            self.debug('loss_name: {}'.format(self._model.loss_name))

    @classmethod
    def _id_to_label(cls, subject_id):
        return "__label__{:d}".format(subject_id)

    @classmethod
    def _label_to_subject(cls, project, label):
        subject_id = int(label.replace('__label__', ''))
        return project.subjects[subject_id]

    @classmethod
    def _write_train_file(cls, doc_subjects, filename):
        with open(filename, 'w') as trainfile:
            for doc, subject_ids in doc_subjects.items():
                labels = [cls._id_to_label(sid) for sid in subject_ids]
                print(' '.join(labels), doc, file=trainfile)

    @classmethod
    def _normalize_text(cls, project, text):
        return ' '.join(project.analyzer.tokenize_words(text))

    def _create_train_file(self, subjects, project):
        self.info('creating fastText training file')

        doc_subjects = collections.defaultdict(set)
        for subject_id, subj in enumerate(subjects):
            for line in subj.text.splitlines():
                doc_subjects[line].add(subject_id)

        doc_subjects_normalized = {}
        for doc, subjs in doc_subjects.items():
            text = self._normalize_text(project, doc)
            if text != '':
                doc_subjects_normalized[text] = subjs

        annif.util.atomic_save(doc_subjects_normalized,
                               self._get_datadir(),
                               'train.txt',
                               method=self._write_train_file)

    def _create_model(self):
        self.info('creating fastText model')
        trainpath = os.path.join(self._get_datadir(), 'train.txt')
        modelpath = os.path.join(self._get_datadir(), 'model')
        params = {param: val for param, val in self.params.items()
                  if param in self.FASTTEXT_PARAMS}
        self._model = fasttext.supervised(trainpath, modelpath, **params)

    def load_subjects(self, subjects, project):
        self._create_train_file(subjects, project)
        self._create_model()
    def _analyze_chunks(self, chunktexts, project):
        limit = int(self.params['limit'])
        ft_results = self._model.predict_proba(chunktexts, limit)
        label_scores = collections.defaultdict(float)
        # sum the per-chunk scores of each label over all chunks, not just
        # the first one, since the final score is averaged over the chunks
        for chunk_results in ft_results:
            for label, score in chunk_results:
                label_scores[label] += score
        best_labels = sorted([(score, label)
                              for label, score in label_scores.items()],
                             reverse=True)

        results = []
        for score, label in best_labels[:limit]:
            subject = self._label_to_subject(project, label)
            results.append(AnalysisHit(
                uri=subject[0],
                label=subject[1],
                score=score / len(chunktexts)))
        return results

    def _analyze(self, text, project, params):
        self.initialize()
        self.debug('Analyzing text "{}..." (len={})'.format(
            text[:20], len(text)))
        sentences = project.analyzer.tokenize_sentences(text)
        self.debug('Found {} sentences'.format(len(sentences)))
        chunksize = int(params['chunksize'])
        chunktexts = []
        for i in range(0, len(sentences), chunksize):
            chunktext = ' '.join(sentences[i:i + chunksize])
            normalized = self._normalize_text(project, chunktext)
            if normalized != '':
                chunktexts.append(normalized)
        self.debug('Split sentences into {} chunks'.format(len(chunktexts)))

        return self._analyze_chunks(chunktexts, project)
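To make the training and prediction flow above concrete, here is a standalone sketch of the fastText Python package API (0.8.x-era module-level functions) that the backend calls; the file names, parameter values and sample text are illustrative assumptions, not taken from this commit:

# Standalone sketch of the fasttext package calls used by the backend above;
# paths, parameters and the sample line are illustrative assumptions.
import fasttext

# each line of train.txt pairs space-separated __label__N tokens with the
# normalized document text, e.g.:
# __label__42 __label__7 arkeologia tutkii ihmisen menneisyytta
model = fasttext.supervised('train.txt', 'model',
                            dim=100, lr=0.25, epoch=20, loss='hs')

# predict_proba() returns one list of (label, probability) pairs per input
results = model.predict_proba(['normalized chunk text'], 10)
for label, score in results[0]:
    print(label, score)

# a trained model can later be reloaded from the saved .bin file, which is
# what initialize() does
model = fasttext.load_model('model.bin')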
@@ -11,3 +11,4 @@ nltk
requests
gensim
sklearn
fasttext
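The fasttext entry added here provides the module imported by the new backend. A quick hedged smoke test of the dependency, checking only the entry points the backend actually uses:

# Sanity check that the installed package exposes the functions called by
# the backend code above (fasttext 0.8.x-style module-level API).
import fasttext
assert hasattr(fasttext, 'supervised')
assert hasattr(fasttext, 'load_model')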
@@ -0,0 +1,76 @@
"""Unit tests for the fastText backend in Annif"""

import annif
import annif.analyzer
import annif.backend
import annif.corpus
import os.path
import pytest
import unittest.mock


@pytest.fixture(scope='module')
def datadir(tmpdir_factory):
    return tmpdir_factory.mktemp('data')


@pytest.fixture(scope='module')
def subject_corpus():
    subjdir = os.path.join(
        os.path.dirname(__file__),
        'corpora',
        'archaeology',
        'subjects')
    return annif.corpus.SubjectDirectory(subjdir)


@pytest.fixture(scope='module')
def project(subject_corpus):
    proj = unittest.mock.Mock()
    proj.analyzer = annif.analyzer.get_analyzer('snowball(finnish)')
    proj.subjects = annif.corpus.SubjectIndex(subject_corpus)
    return proj


def test_fasttext_load_subjects(datadir, subject_corpus, project):
    fasttext_type = annif.backend.get_backend_type("fasttext")
    fasttext = fasttext_type(
        backend_id='fasttext',
        params={
            'limit': 50,
            'dim': 100,
            'lr': 0.25,
            'epoch': 20,
            'loss': 'hs'},
        datadir=str(datadir))

    fasttext.load_subjects(subject_corpus, project)
    assert fasttext._model is not None
    assert datadir.join('backends/fasttext/model.bin').exists()
    assert datadir.join('backends/fasttext/model.bin').size() > 0


def test_fasttext_analyze(datadir, project):
    fasttext_type = annif.backend.get_backend_type("fasttext")
    fasttext = fasttext_type(
        backend_id='fasttext',
        params={
            'limit': 50,
            'chunksize': 1,
            'dim': 100,
            'lr': 0.25,
            'epoch': 20,
            'loss': 'hs'},
        datadir=str(datadir))

    results = fasttext.analyze("""Arkeologiaa sanotaan joskus myös
        muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
        tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
        Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
        joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
        pohjaan.""", project)

    assert len(results) == 50
    assert 'http://www.yso.fi/onto/yso/p1265' in [
        result.uri for result in results]
    assert 'arkeologia' in [result.label for result in results]
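For reference, the hits inspected by the assertions above are AnalysisHit objects as constructed in the backend code; a purely illustrative instance (the score value is made up):

# Illustrative only: each result carries uri, label and score fields.
from annif.hit import AnalysisHit

hit = AnalysisHit(uri='http://www.yso.fi/onto/yso/p1265',
                  label='arkeologia',
                  score=0.85)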