✨ Add regen wikification #44

Merged
merged 7 commits on Dec 28, 2022
22 changes: 13 additions & 9 deletions zshot/linker/linker_regen/linker_regen.py
@@ -16,33 +16,37 @@

class LinkerRegen(Linker):
""" REGEN linker """
def __init__(self, max_input_len=384, max_output_len=15, num_beams=10):
def __init__(self, max_input_len=384, max_output_len=15, num_beams=10, trie=None):
"""
:param max_input_len: Max length of input
:param max_output_len: Max length of output
:param num_beams: Number of beams to use
:param trie: If a trie is given, the linker will use it to restrict the search space.
Custom entities won't be used when a trie is given.
"""
super().__init__()
self.model = None
self.tokenizer = None
self.trie = None
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.num_beams = num_beams
self.skip_set_kg = trie is not None
self.trie = trie

def set_kg(self, entities: Iterator[Entity]):
""" Set new entities

:param entities: New entities to use
"""
super().set_kg(entities)
self.load_tokenizer()
self.trie = Trie(
[
self.tokenizer(e.name, return_tensors="pt")['input_ids'][0].tolist()
for e in entities
]
)
if not self.skip_set_kg:
self.load_tokenizer()
self.trie = Trie(
[
self.tokenizer(e.name, return_tensors="pt")['input_ids'][0].tolist()
for e in entities
]
)

def load_models(self):
""" Load Model """
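With this change the linker supports two construction paths: build the trie from the entities passed to `set_kg`, or accept a prebuilt trie that fixes the search space up front (in which case `set_kg` leaves it untouched). A minimal sketch of both paths — the token-id sequences are the illustrative ones from the new test below, not real tokenizer output:

```python
from zshot.linker.linker_regen.linker_regen import LinkerRegen
from zshot.linker.linker_regen.trie import Trie

# Without a trie: set_kg() tokenizes the entity names and builds one.
linker = LinkerRegen()

# With a prebuilt trie: skip_set_kg is True, so set_kg() does nothing
# and custom entities are ignored.
trie = Trie([[794, 536, 1], [794, 357, 1]])  # illustrative token-id sequences
linker_with_trie = LinkerRegen(trie=trie)
```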
8 changes: 4 additions & 4 deletions zshot/linker/linker_regen/trie.py
@@ -1,20 +1,20 @@
from typing import List
from typing import Collection


class Trie(object):
def __init__(self, sequences: List[List[int]] = []):
def __init__(self, sequences: Collection[Collection[int]] = []):
self.trie_dict = {}
for sequence in sequences:
self.add(sequence)

def add(self, sequence: List[int]):
def add(self, sequence: Collection[int]):
trie = self.trie_dict
for idx in sequence:
if idx not in trie:
trie[idx] = {}
trie = trie[idx]

def postfix(self, prefix_sequence: List[int]):
def postfix(self, prefix_sequence: Collection[int]):
if len(prefix_sequence) == 1:
return list(self.trie_dict.keys())
trie = self.trie_dict
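For reference, the trie stores token-id sequences as nested dicts, and `postfix` returns the allowed continuations for a prefix (its tail is collapsed in the diff above, so only the length-1 case is shown). A small sketch of the data structure:

```python
from zshot.linker.linker_regen.trie import Trie

trie = Trie([[5, 7, 2], [5, 9, 2]])
trie.add([5, 7, 3])

print(trie.trie_dict)     # {5: {7: {2: {}, 3: {}}, 9: {2: {}}}}
print(trie.postfix([5]))  # a length-1 prefix returns the root keys: [5]
```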
57 changes: 57 additions & 0 deletions zshot/linker/linker_regen/utils.py
@@ -1,3 +1,18 @@
import json
import pickle
from typing import Dict, List, Optional

from huggingface_hub import hf_hub_download

from zshot.linker.linker_regen.trie import Trie
from zshot.utils.data_models import Span

REPO_ID = "ibm/regen-disambiguation"
TRIE_FILE_NAME = "wikipedia_trie.pkl"
WIKIPEDIA_MAP = "wikipedia_map_id.json"


def create_input(sentence, max_length, start_delimiter, end_delimiter):
sent_list = sentence.split(" ")
if len(sent_list) < max_length:
@@ -12,3 +27,45 @@ def create_input(sentence, max_length, start_delimiter, end_delimiter):
left_index = left_index - max(0, (half_context - (right_index - end_delimiter_index)))
print(len(sent_list[left_index:right_index]))
return " ".join(sent_list[left_index:right_index])


def load_wikipedia_trie() -> Trie:
"""
Load the Wikipedia trie from the HF Hub
:return: The Wikipedia trie
"""
wikipedia_trie_file = hf_hub_download(repo_id=REPO_ID,
repo_type='model',
filename=TRIE_FILE_NAME)
with open(wikipedia_trie_file, "rb") as f:
wikipedia_trie = pickle.load(f)
return wikipedia_trie


def load_wikipedia_mapping() -> Dict[str, str]:
"""
Load the Wikipedia page-id mapping from the HF Hub
:return: The Wikipedia mapping
"""
wikipedia_map = hf_hub_download(repo_id=REPO_ID,
repo_type='model',
filename=WIKIPEDIA_MAP)
with open(wikipedia_map, "r") as f:
wikipedia_map = json.load(f)
return wikipedia_map


def spans_to_wikipedia(spans: List[Span]) -> List[Optional[str]]:
"""
Generate Wikipedia links for the given spans
:param spans: Spans whose labels are looked up in the Wikipedia mapping
:return: The list of generated links; None for labels not in the mapping
"""
links = []
wikipedia_map = load_wikipedia_mapping()
for s in spans:
if s.label in wikipedia_map:
links.append(f"https://en.wikipedia.org/wiki?curid={wikipedia_map[s.label]}")
else:
links.append(None)
return links
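Together these helpers resolve span labels to Wikipedia URLs through the `wikipedia_map_id.json` file on the `ibm/regen-disambiguation` Hub repo. A usage sketch — note that `spans_to_wikipedia` calls `load_wikipedia_mapping` on every invocation (served from the HF cache after the first download), so the result is worth caching in hot paths:

```python
from zshot.linker.linker_regen.utils import spans_to_wikipedia
from zshot.utils.data_models import Span

spans = [
    Span(label="Surfing", start=0, end=10),             # present in the mapping
    Span(label="NotARealPageTitle", start=11, end=28),  # hypothetical, not in the mapping
]
links = spans_to_wikipedia(spans)
# -> ["https://en.wikipedia.org/wiki?curid=<page id>", None]
```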
41 changes: 39 additions & 2 deletions zshot/tests/linker/test_regen_linker.py
@@ -8,8 +8,12 @@

from zshot import PipelineConfig
from zshot.linker.linker_regen.linker_regen import LinkerRegen
from zshot.linker.linker_regen.trie import Trie
from zshot.linker.linker_regen.utils import load_wikipedia_trie, spans_to_wikipedia
from zshot.mentions_extractor import MentionsExtractorSpacy
from zshot.tests.config import EX_DOCS, EX_ENTITIES
from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor
from zshot.utils.data_models import Span

logger = logging.getLogger(__name__)

@@ -25,9 +29,9 @@ def teardown():


def test_regen_linker():
nlp = spacy.load("en_core_web_sm")
nlp = spacy.blank("en")
config = PipelineConfig(
mentions_extractor=MentionsExtractorSpacy(),
mentions_extractor=DummyMentionsExtractor(),
linker=LinkerRegen(),
entities=EX_ENTITIES
)
@@ -60,3 +64,36 @@ def test_regen_linker_pipeline():
nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
nlp.remove_pipe('zshot')
del docs, nlp, config


def test_regen_linker_wikification():
nlp = spacy.blank("en")
trie = Trie()
trie.add([794, 536, 1])
trie.add([794, 357, 1])
config = PipelineConfig(
mentions_extractor=DummyMentionsExtractor(),
linker=LinkerRegen(trie=trie),
)
nlp.add_pipe("zshot", config=config, last=True)
assert "zshot" in nlp.pipe_names

doc = nlp(EX_DOCS[1])
assert len(doc.ents) > 0
del nlp.get_pipe('zshot').mentions_extractor, nlp.get_pipe('zshot').entities, nlp.get_pipe('zshot').nlp
del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.trie, \
nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
nlp.remove_pipe('zshot')
del doc, nlp, config


def test_load_wikipedia_trie():
trie = load_wikipedia_trie()
assert len(list(trie.trie_dict.keys())) == 6952


def test_span_to_wiki():
s = Span(label="Surfing", start=0, end=10)
wiki_links = spans_to_wikipedia([s])
assert len(wiki_links) > 0
assert wiki_links[0].startswith("https://en.wikipedia.org/wiki?curid=")
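End to end, the new pieces compose into wikification: load the Wikipedia trie, constrain the linker with it, then map the predicted span labels to links. A hedged sketch along the lines of the tests above (the input sentence is illustrative, and the model and trie downloads make this expensive to run):

```python
import spacy

from zshot import PipelineConfig
from zshot.linker.linker_regen.linker_regen import LinkerRegen
from zshot.linker.linker_regen.utils import load_wikipedia_trie, spans_to_wikipedia
from zshot.mentions_extractor import MentionsExtractorSpacy
from zshot.utils.data_models import Span

nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("zshot", config=PipelineConfig(
    mentions_extractor=MentionsExtractorSpacy(),
    linker=LinkerRegen(trie=load_wikipedia_trie()),
), last=True)

doc = nlp("International Business Machines Corporation is headquartered in Armonk.")
# Convert the predicted entities to zshot Spans and resolve them to Wikipedia URLs.
spans = [Span(label=e.label_, start=e.start_char, end=e.end_char) for e in doc.ents]
print(spans_to_wikipedia(spans))
```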