Merge pull request #92 from NatLibFi/fasttext-backend

Add fastText backend. Fixes #74

osma committed Apr 5, 2018
2 parents b8def22 + f6bff45 commit 4a82078

Showing 12 changed files with 246 additions and 0 deletions.
6 changes: 6 additions & 0 deletions .scrutinizer.yml
@@ -0,0 +1,6 @@
build:
    environment:
        python: 3.6.3
    dependencies:
        before:
            - pip install cython
1 change: 1 addition & 0 deletions .travis.yml
@@ -5,6 +5,7 @@ python:
- "3.6"
cache: pip
install:
- pip install cython
- pip install -r requirements.txt
- travis_wait 30 python -m nltk.downloader punkt
script:
1 change: 1 addition & 0 deletions README.md
@@ -32,6 +32,7 @@ Create a virtual environment by running:

Install dependencies and download NLTK data:

    pip install cython  # needed by fasttext; must be installed first
    pip install -r requirements.txt
    python -m nltk.downloader punkt

2 changes: 2 additions & 0 deletions annif/backend/__init__.py
@@ -5,6 +5,7 @@
from . import dummy
from . import http
from . import tfidf
from . import fasttext


_backend_types = {}
@@ -24,6 +25,7 @@ def get_backend_type(backend_type):
register_backend_type(dummy.DummyBackend)
register_backend_type(http.HTTPBackend)
register_backend_type(tfidf.TFIDFBackend)
register_backend_type(fasttext.FastTextBackend)


def _create_backends(backends_file, datadir):
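For context, `register_backend_type` and `get_backend_type` maintain a module-level registry keyed by each backend class's `name` attribute. A minimal sketch of the pattern; the lookup body is assumed, since this hunk only shows the registration calls and the `get_backend_type` signature:

    _backend_types = {}

    def register_backend_type(backend_type):
        # key the registry by the class-level name attribute
        _backend_types[backend_type.name] = backend_type

    def get_backend_type(backend_type):
        # body assumed; the hunk above shows only the signature
        try:
            return _backend_types[backend_type]
        except KeyError:
            raise ValueError("No such backend type: {}".format(backend_type))

This is how `annif.backend.get_backend_type("fasttext")` in the tests below resolves to `FastTextBackend`.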
2 changes: 2 additions & 0 deletions annif/backend/backend.py
@@ -12,6 +12,8 @@ class AnnifBackend(metaclass=abc.ABCMeta):
    non-implemented methods should be overridden in subclasses."""

    name = None
    needs_subject_index = False
    needs_subject_vectorizer = False

    def __init__(self, backend_id, params, datadir):
        """Initialize backend with specific parameters. The
132 changes: 132 additions & 0 deletions annif/backend/fasttext.py
@@ -0,0 +1,132 @@
"""Annif backend using the fastText classifier"""

import collections
import os.path
import annif.util
from annif.hit import AnalysisHit
import fasttext
from . import backend


class FastTextBackend(backend.AnnifBackend):
    """fastText backend for Annif"""

    name = "fasttext"
    needs_subject_index = True

    FASTTEXT_PARAMS = (
        'lr',
        'lr_update_rate',
        'dim',
        'ws',
        'epoch',
        'min_count',
        'neg',
        'word_ngrams',
        'loss',
        'bucket',
        'minn',
        'maxn',
        'thread',
        't'
    )

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        if self._model is None:
            path = os.path.join(self._get_datadir(), 'model.bin')
            self.debug('loading fastText model from {}'.format(path))
            self._model = fasttext.load_model(path)
            self.debug('loaded model {}'.format(str(self._model)))
            self.debug('dim: {}'.format(self._model.dim))
            self.debug('epoch: {}'.format(self._model.epoch))
            self.debug('loss_name: {}'.format(self._model.loss_name))

    @classmethod
    def _id_to_label(cls, subject_id):
        return "__label__{:d}".format(subject_id)

    @classmethod
    def _label_to_subject(cls, project, label):
        subject_id = int(label.replace('__label__', ''))
        return project.subjects[subject_id]

    @classmethod
    def _write_train_file(cls, doc_subjects, filename):
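        # Each output line holds one or more "__label__<id>" tokens followed
        # by the normalized document text, i.e. the input format expected by
        # fastText's supervised mode.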
        with open(filename, 'w') as trainfile:
            for doc, subject_ids in doc_subjects.items():
                labels = [cls._id_to_label(sid) for sid in subject_ids]
                print(' '.join(labels), doc, file=trainfile)

    @classmethod
    def _normalize_text(cls, project, text):
        return ' '.join(project.analyzer.tokenize_words(text))

    def _create_train_file(self, subjects, project):
        self.info('creating fastText training file')

        doc_subjects = collections.defaultdict(set)
        for subject_id, subj in enumerate(subjects):
            for line in subj.text.splitlines():
                doc_subjects[line].add(subject_id)

        doc_subjects_normalized = {}
        for doc, subjs in doc_subjects.items():
            text = self._normalize_text(project, doc)
            if text != '':
                doc_subjects_normalized[text] = subjs

        annif.util.atomic_save(doc_subjects_normalized,
                               self._get_datadir(),
                               'train.txt',
                               method=self._write_train_file)

    def _create_model(self):
        self.info('creating fastText model')
        trainpath = os.path.join(self._get_datadir(), 'train.txt')
        modelpath = os.path.join(self._get_datadir(), 'model')
        params = {param: val for param, val in self.params.items()
                  if param in self.FASTTEXT_PARAMS}
        self._model = fasttext.supervised(trainpath, modelpath, **params)

    def load_subjects(self, subjects, project):
        self._create_train_file(subjects, project)
        self._create_model()

    def _analyze_chunks(self, chunktexts, project):
        limit = int(self.params['limit'])
        ft_results = self._model.predict_proba(chunktexts, limit)
        label_scores = collections.defaultdict(float)
        # aggregate the scores of each label across all chunks
        for chunk_result in ft_results:
            for label, score in chunk_result:
                label_scores[label] += score
        best_labels = sorted([(score, label)
                              for label, score in label_scores.items()],
                             reverse=True)

        results = []
        for score, label in best_labels[:limit]:
            subject = self._label_to_subject(project, label)
            results.append(AnalysisHit(
                uri=subject[0],
                label=subject[1],
                score=score / len(chunktexts)))
        return results

    def _analyze(self, text, project, params):
        self.initialize()
        self.debug('Analyzing text "{}..." (len={})'.format(
            text[:20], len(text)))
        sentences = project.analyzer.tokenize_sentences(text)
        self.debug('Found {} sentences'.format(len(sentences)))
        chunksize = int(params['chunksize'])
        chunktexts = []
        for i in range(0, len(sentences), chunksize):
            chunktext = ' '.join(sentences[i:i + chunksize])
            normalized = self._normalize_text(project, chunktext)
            if normalized != '':
                chunktexts.append(normalized)
        self.debug('Split sentences into {} chunks'.format(len(chunktexts)))

        return self._analyze_chunks(chunktexts, project)
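
To make the scoring in `_analyze_chunks` concrete: fastText returns one list of (label, probability) pairs per chunk, the probabilities are summed per label across all chunks, and the totals are divided by the number of chunks so the final scores stay between 0 and 1. A minimal sketch, with made-up predictions standing in for `predict_proba()` output:

    import collections

    # hypothetical results for two chunks, mimicking predict_proba() output
    ft_results = [
        [('__label__12', 0.8), ('__label__3', 0.1)],
        [('__label__12', 0.6), ('__label__7', 0.3)],
    ]

    label_scores = collections.defaultdict(float)
    for chunk_result in ft_results:
        for label, score in chunk_result:
            label_scores[label] += score

    # average over chunks: '__label__12' scores (0.8 + 0.6) / 2 = 0.7
    scores = {label: total / len(ft_results)
              for label, total in label_scores.items()}
    assert abs(scores['__label__12'] - 0.7) < 1e-9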
3 changes: 3 additions & 0 deletions annif/backend/tfidf.py
@@ -10,7 +10,10 @@


class TFIDFBackend(backend.AnnifBackend):
    """TF-IDF vector space similarity based backend for Annif"""
    name = "tfidf"
    needs_subject_index = True
    needs_subject_vectorizer = True

    # defaults for uninitialized instances
    _index = None
8 changes: 8 additions & 0 deletions annif/project.py
@@ -112,11 +112,19 @@ def analyze(self, text, backend_params=None):
        return merged_hits

    def _create_subject_index(self, subjects):
        if True not in [be[0].needs_subject_index for be in self.backends]:
            logger.debug(
                'not creating subject index: not needed by any backend')
            return
        logger.info('creating subject index')
        self._subjects = annif.corpus.SubjectIndex(subjects)
        annif.util.atomic_save(self._subjects, self._get_datadir(), 'subjects')

    def _create_vectorizer(self, subjects):
        if True not in [
                be[0].needs_subject_vectorizer for be in self.backends]:
            logger.debug('not creating vectorizer: not needed by any backend')
            return
        logger.info('creating vectorizer')
        self._vectorizer = TfidfVectorizer(
            tokenizer=self.analyzer.tokenize_words)
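Both guards above follow the same pattern: each backend class advertises which shared resources it needs, and the project builds a resource only if some backend asks for it. A minimal sketch, assuming `self.backends` holds (backend, weight) pairs as the list comprehensions suggest, with stand-in classes for the real backends:

    class DummyBackend:        # stand-in: does not need the index
        needs_subject_index = False

    class FastTextBackend:     # stand-in: needs the index
        needs_subject_index = True

    backends = [(DummyBackend(), 1.0), (FastTextBackend(), 1.0)]

    # equivalent to the "if True not in [...]" test above
    if not any(be[0].needs_subject_index for be in backends):
        print('not creating subject index: not needed by any backend')
    else:
        print('creating subject index')  # FastTextBackend requires it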
9 changes: 9 additions & 0 deletions backends.cfg
@@ -17,3 +17,12 @@ limit=100
[tfidf-en]
type=tfidf
limit=100

[fasttext-fi]
type=fasttext
dim=500
lr=0.25
epoch=30
loss=hs
limit=100
chunksize=10
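
Of the keys in this section, only those listed in `FASTTEXT_PARAMS` (dim, lr, epoch, loss, and so on) are forwarded to `fasttext.supervised()` at training time; `limit` and `chunksize` are consumed by the backend itself during analysis. A minimal sketch of that filtering, with a hand-written dict standing in for the parsed config section:

    FASTTEXT_PARAMS = ('lr', 'lr_update_rate', 'dim', 'ws', 'epoch',
                       'min_count', 'neg', 'word_ngrams', 'loss',
                       'bucket', 'minn', 'maxn', 'thread', 't')

    # hypothetical parsed section, standing in for self.params
    config = {'type': 'fasttext', 'dim': '500', 'lr': '0.25',
              'epoch': '30', 'loss': 'hs', 'limit': '100',
              'chunksize': '10'}

    # only recognized fastText hyperparameters are passed to training
    train_params = {k: v for k, v in config.items() if k in FASTTEXT_PARAMS}
    assert 'limit' not in train_params and 'dim' in train_params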
5 changes: 5 additions & 0 deletions projects.cfg
@@ -19,3 +19,8 @@ analyzer=snowball(swedish)
language=en
backends=tfidf-en
analyzer=snowball(english)

[fasttext-fi]
language=fi
backends=fasttext-fi
analyzer=snowball(finnish)
1 change: 1 addition & 0 deletions requirements.txt
@@ -11,3 +11,4 @@ nltk
requests
gensim
sklearn
fasttext
76 changes: 76 additions & 0 deletions tests/test_backend_fasttext.py
@@ -0,0 +1,76 @@
"""Unit tests for the TF-IDF backend in Annif"""

import annif
import annif.analyzer
import annif.backend
import annif.corpus
import os.path
import pytest
import unittest.mock


@pytest.fixture(scope='module')
def datadir(tmpdir_factory):
    return tmpdir_factory.mktemp('data')


@pytest.fixture(scope='module')
def subject_corpus():
    subjdir = os.path.join(
        os.path.dirname(__file__),
        'corpora',
        'archaeology',
        'subjects')
    return annif.corpus.SubjectDirectory(subjdir)


@pytest.fixture(scope='module')
def project(subject_corpus):
    proj = unittest.mock.Mock()
    proj.analyzer = annif.analyzer.get_analyzer('snowball(finnish)')
    proj.subjects = annif.corpus.SubjectIndex(subject_corpus)
    return proj


def test_fasttext_load_subjects(datadir, subject_corpus, project):
    fasttext_type = annif.backend.get_backend_type("fasttext")
    fasttext = fasttext_type(
        backend_id='fasttext',
        params={
            'limit': 50,
            'dim': 100,
            'lr': 0.25,
            'epoch': 20,
            'loss': 'hs'},
        datadir=str(datadir))

    fasttext.load_subjects(subject_corpus, project)
    assert fasttext._model is not None
    assert datadir.join('backends/fasttext/model.bin').exists()
    assert datadir.join('backends/fasttext/model.bin').size() > 0


def test_fasttext_analyze(datadir, project):
    fasttext_type = annif.backend.get_backend_type("fasttext")
    fasttext = fasttext_type(
        backend_id='fasttext',
        params={
            'limit': 50,
            'chunksize': 1,
            'dim': 100,
            'lr': 0.25,
            'epoch': 20,
            'loss': 'hs'},
        datadir=str(datadir))

    results = fasttext.analyze("""Arkeologiaa sanotaan joskus myös
        muinaistutkimukseksi tai muinaistieteeksi. Se on humanistinen tiede
        tai oikeammin joukko tieteitä, jotka tutkivat ihmisen menneisyyttä.
        Tutkimusta tehdään analysoimalla muinaisjäännöksiä eli niitä jälkiä,
        joita ihmisten toiminta on jättänyt maaperään tai vesistöjen
        pohjaan.""", project)

    assert len(results) == 50
    assert 'http://www.yso.fi/onto/yso/p1265' in [
        result.uri for result in results]
    assert 'arkeologia' in [result.label for result in results]
