Skip to content

Commit

Permalink
Init commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jpWang committed Mar 23, 2022
1 parent e94e120 commit 738014d
Show file tree
Hide file tree
Showing 33 changed files with 3,293 additions and 2 deletions.
45 changes: 45 additions & 0 deletions LiLTfinetune/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from collections import OrderedDict

from transformers import CONFIG_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, MODEL_NAMES_MAPPING, TOKENIZER_MAPPING
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, BertConverter, RobertaConverter, XLMRobertaConverter
from transformers.models.auto.modeling_auto import auto_class_factory

from .models.LiLTRobertaLike import (
LiLTRobertaLikeConfig,
LiLTRobertaLikeForRelationExtraction,
LiLTRobertaLikeForTokenClassification,
LiLTRobertaLikeTokenizer,
LiLTRobertaLikeTokenizerFast,
)

CONFIG_MAPPING.update([("liltrobertalike", LiLTRobertaLikeConfig),])
MODEL_NAMES_MAPPING.update([("liltrobertalike", "LiLTRobertaLike"),])
TOKENIZER_MAPPING.update(
[
(LiLTRobertaLikeConfig, (LiLTRobertaLikeTokenizer, LiLTRobertaLikeTokenizerFast)),
]
)

with open('tag.txt', 'r') as tagf:
TAG = tagf.read().lower()
assert TAG == 'monolingual' or TAG == 'multilingual', 'TAG is wrong. It should be monolingual or multilingual.'
if TAG == 'monolingual':
SLOW_TO_FAST_CONVERTERS.update({"LiLTRobertaLikeTokenizer": RobertaConverter,})
elif TAG == 'multilingual':
SLOW_TO_FAST_CONVERTERS.update({"LiLTRobertaLikeTokenizer": XLMRobertaConverter,})

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.update(
[(LiLTRobertaLikeConfig, LiLTRobertaLikeForTokenClassification),]
)

MODEL_FOR_RELATION_EXTRACTION_MAPPING = OrderedDict(
[(LiLTRobertaLikeConfig, LiLTRobertaLikeForRelationExtraction),]
)

AutoModelForTokenClassification = auto_class_factory(
"AutoModelForTokenClassification", MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
)

AutoModelForRelationExtraction = auto_class_factory(
"AutoModelForRelationExtraction", MODEL_FOR_RELATION_EXTRACTION_MAPPING, head_doc="relation extraction"
)
3 changes: 3 additions & 0 deletions LiLTfinetune/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa
from .data_collator import DataCollatorForKeyValueExtraction
from .datasets import *
81 changes: 81 additions & 0 deletions LiLTfinetune/data/data_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""

task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
)
test_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
pad_to_max_length: bool = field(
default=True,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"value if set."
},
)
label_all_tokens: bool = field(
default=False,
metadata={
"help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
"one (in which case the other tokens will have a padding index)."
},
)
return_entity_level_metrics: bool = field(
default=False,
metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
)


@dataclass
class XFUNDataTrainingArguments(DataTrainingArguments):
lang: Optional[str] = field(default="en")
additional_langs: Optional[str] = field(default=None)
82 changes: 82 additions & 0 deletions LiLTfinetune/data/data_collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from dataclasses import dataclass
from typing import Optional, Union

import torch

from detectron2.structures import ImageList
from transformers import PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy


@dataclass
class DataCollatorForKeyValueExtraction:
"""
Data collator that will dynamically pad the inputs received, as well as the labels.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
"""

tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100

def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

has_image_input = "image" in features[0]
has_bbox_input = "bbox" in features[0]
if has_image_input:
image = ImageList.from_tensors([torch.tensor(feature["image"]) for feature in features], 32)
for feature in features:
del feature["image"]
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
return_tensors="pt" if labels is None else None,
)

if labels is None:
return batch

sequence_length = torch.tensor(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
if has_bbox_input:
batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
if has_bbox_input:
batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]

batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
if has_image_input:
batch["image"] = image
return batch
Empty file.
116 changes: 116 additions & 0 deletions LiLTfinetune/data/datasets/funsd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# coding=utf-8

import json
import os

import datasets

from LiLTfinetune.data.utils import load_image, normalize_bbox


logger = datasets.logging.get_logger(__name__)


_CITATION = """\
@article{Jaume2019FUNSDAD,
title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
year={2019},
volume={2},
pages={1-6}
}
"""

_DESCRIPTION = """\
https://guillaumejaume.github.io/FUNSD/
"""


class FunsdConfig(datasets.BuilderConfig):
"""BuilderConfig for FUNSD"""

def __init__(self, **kwargs):
"""BuilderConfig for FUNSD.
Args:
**kwargs: keyword arguments forwarded to super.
"""
super(FunsdConfig, self).__init__(**kwargs)


class Funsd(datasets.GeneratorBasedBuilder):
"""Conll2003 dataset."""

BUILDER_CONFIGS = [
FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
]

def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"tokens": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
}
),
supervised_keys=None,
homepage="https://guillaumejaume.github.io/FUNSD/",
citation=_CITATION,
)

def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
),
]

def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "annotations")
img_dir = os.path.join(filepath, "images")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
tokens = []
bboxes = []
ner_tags = []

file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["form"]:
words, label = item["words"], item["label"]
words = [w for w in words if w["text"].strip() != ""]
if len(words) == 0:
continue
if label == "other":
for w in words:
tokens.append(w["text"])
ner_tags.append("O")
bboxes.append(normalize_bbox(item["box"], size))
else:
tokens.append(words[0]["text"])
ner_tags.append("B-" + label.upper())
bboxes.append(normalize_bbox(item["box"], size))
for w in words[1:]:
tokens.append(w["text"])
ner_tags.append("I-" + label.upper())
bboxes.append(normalize_bbox(item["box"], size))

yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}
Loading

0 comments on commit 738014d

Please sign in to comment.