From 3919b57ccf55b88da05bd921ba20738e49bdb9a1 Mon Sep 17 00:00:00 2001 From: plkmo Date: Sun, 14 Jun 2020 17:43:40 +0800 Subject: [PATCH] init gec --- gec.py | 81 ++ nlptoolkit/gec/__init__.py | 0 nlptoolkit/gec/infer.py | 133 +++ nlptoolkit/gec/models/__init__.py | 0 nlptoolkit/gec/models/gector/__init__.py | 0 .../gec/models/gector/bert_token_embedder.py | 270 ++++++ nlptoolkit/gec/models/gector/datareader.py | 151 ++++ nlptoolkit/gec/models/gector/gec_model.py | 322 +++++++ .../gec/models/gector/seq2labels_model.py | 193 ++++ nlptoolkit/gec/models/gector/trainer.py | 845 ++++++++++++++++++ .../gec/models/gector/utils/__init__.py | 0 nlptoolkit/gec/models/gector/utils/helpers.py | 202 +++++ .../gector/utils/prepare_clc_fce_data.py | 123 +++ .../models/gector/utils/preprocess_data.py | 488 ++++++++++ .../gec/models/gector/wordpiece_indexer.py | 444 +++++++++ nlptoolkit/gec/trainer.py | 303 +++++++ 16 files changed, 3555 insertions(+) create mode 100644 gec.py create mode 100644 nlptoolkit/gec/__init__.py create mode 100644 nlptoolkit/gec/infer.py create mode 100644 nlptoolkit/gec/models/__init__.py create mode 100644 nlptoolkit/gec/models/gector/__init__.py create mode 100644 nlptoolkit/gec/models/gector/bert_token_embedder.py create mode 100644 nlptoolkit/gec/models/gector/datareader.py create mode 100644 nlptoolkit/gec/models/gector/gec_model.py create mode 100644 nlptoolkit/gec/models/gector/seq2labels_model.py create mode 100644 nlptoolkit/gec/models/gector/trainer.py create mode 100644 nlptoolkit/gec/models/gector/utils/__init__.py create mode 100644 nlptoolkit/gec/models/gector/utils/helpers.py create mode 100644 nlptoolkit/gec/models/gector/utils/prepare_clc_fce_data.py create mode 100644 nlptoolkit/gec/models/gector/utils/preprocess_data.py create mode 100644 nlptoolkit/gec/models/gector/wordpiece_indexer.py create mode 100644 nlptoolkit/gec/trainer.py diff --git a/gec.py b/gec.py new file mode 100644 index 0000000..f62ad21 --- /dev/null +++ b/gec.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +""" +Created on Sun Aug 18 13:09:24 2019 + +@author: WT +""" +from nlptoolkit.gec.infer import infer_from_trained +from nlptoolkit.utils.misc import save_as_pickle +from argparse import ArgumentParser +import logging + +logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s', \ datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO) +logger = logging.getLogger('__file__') + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--model_no", type=int, default=0, help="0: GECToR") + parser.add_argument('--model_path', type=str, default=['./data/gec/gector/roberta_1_gector.th'], + help='Path to the model file.', nargs='+') + parser.add_argument('--vocab_path', type=str, default='./data/gec/gector/output_vocabulary/', + help='Path to the output vocabulary directory.') + #parser.add_argument('--input_file', type=str, default='./data/gec/gector/input.txt', + # help='Path to the evalset file') + #parser.add_argument('--output_file', type=str, default='./data/gec/gector/output.txt', + # help='Path to the output file') + parser.add_argument('--max_len', + type=int, + help='The max sentence length' + '(all longer will be truncated)', + default=50) + parser.add_argument('--min_len', + type=int, + help='The minimum sentence length' + '(all shorter will be returned w/o changes)', + default=3) + parser.add_argument('--batch_size', + type=int, + help='The batch size.', + default=128) + parser.add_argument('--lowercase_tokens', + type=int, + help='Whether to lowercase
tokens.', + default=0) + parser.add_argument('--transformer_model', + choices=['bert', 'gpt2', 'transformerxl', 'xlnet', 'distilbert', 'roberta', 'albert'], + help='Name of the transformer model.', + default='roberta') + parser.add_argument('--iteration_count', + type=int, + help='The number of iterations of the model.', + default=5) + parser.add_argument('--additional_confidence', + type=float, + help='How many probability to add to $KEEP token.', + default=0) + parser.add_argument('--min_probability', + type=float, + default=0.0) + parser.add_argument('--min_error_probability', + type=float, + default=0.0) + parser.add_argument('--special_tokens_fix', + type=int, + help='Whether to fix problem with [CLS], [SEP] tokens tokenization. ' + 'For reproducing reported results it should be 0 for BERT/XLNet and 1 for RoBERTa.', + default=1) + parser.add_argument('--is_ensemble', + type=int, + help='Whether to do ensembling.', + default=0) + parser.add_argument('--weights', + help='Used to calculate weighted average', nargs='+', + default=None) + args = parser.parse_args() + + save_as_pickle("args.pkl", args) + + inferer = infer_from_trained(args) + inferer.infer_from_file(input_file='./data/gec/gector/input.txt', \ + output_file='./data/gec/gector/output.txt', batch_size=32) \ No newline at end of file diff --git a/nlptoolkit/gec/__init__.py b/nlptoolkit/gec/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nlptoolkit/gec/infer.py b/nlptoolkit/gec/infer.py new file mode 100644 index 0000000..ab6d39a --- /dev/null +++ b/nlptoolkit/gec/infer.py @@ -0,0 +1,133 @@ +import argparse + +from .models.gector.utils.helpers import read_lines +from .models.gector.gec_model import GecBERTModel + +class infer_from_trained(object): + def __init__(self, args): + self.args = args + self.model = GecBERTModel(vocab_path=args.vocab_path, + model_paths=args.model_path, + max_len=args.max_len, min_len=args.min_len, + iterations=args.iteration_count, + min_error_probability=args.min_error_probability, + min_probability=args.min_error_probability, + lowercase_tokens=args.lowercase_tokens, + model_name=args.transformer_model, + special_tokens_fix=args.special_tokens_fix, + log=False, + confidence=args.additional_confidence, + is_ensemble=args.is_ensemble, + weigths=args.weights) + + def infer_from_file(self, input_file='./data/gec/gector/input.txt', \ + output_file='./data/gec/gector/output.txt', batch_size=32): + test_data = read_lines(input_file) + predictions = [] + cnt_corrections = 0 + batch = [] + for sent in test_data: + batch.append(sent.split()) + if len(batch) == batch_size: + preds, cnt = self.model.handle_batch(batch) + predictions.extend(preds) + cnt_corrections += cnt + batch = [] + if batch: + preds, cnt = self.model.handle_batch(batch) + predictions.extend(preds) + cnt_corrections += cnt + + with open(output_file, 'w') as f: + f.write("\n".join([" ".join(x) for x in predictions]) + '\n') + return cnt_corrections + + +def main(args): + # get all paths + model = GecBERTModel(vocab_path=args.vocab_path, + model_paths=args.model_path, + max_len=args.max_len, min_len=args.min_len, + iterations=args.iteration_count, + min_error_probability=args.min_error_probability, + min_probability=args.min_error_probability, + lowercase_tokens=args.lowercase_tokens, + model_name=args.transformer_model, + special_tokens_fix=args.special_tokens_fix, + log=False, + confidence=args.additional_confidence, + is_ensemble=args.is_ensemble, + weigths=args.weights) + + cnt_corrections = 
predict_for_file(args.input_file, args.output_file, model, + batch_size=args.batch_size) + # evaluate with m2 or ERRANT + print(f"Produced overall corrections: {cnt_corrections}") + + +if __name__ == '__main__': + # read parameters + parser = argparse.ArgumentParser() + parser.add_argument('--model_path', + help='Path to the model file.', nargs='+', + required=True) + parser.add_argument('--vocab_path', + help='Path to the model file.', + default='data/output_vocabulary' # to use pretrained models + ) + parser.add_argument('--input_file', + help='Path to the evalset file', + required=True) + parser.add_argument('--output_file', + help='Path to the output file', + required=True) + parser.add_argument('--max_len', + type=int, + help='The max sentence length' + '(all longer will be truncated)', + default=50) + parser.add_argument('--min_len', + type=int, + help='The minimum sentence length' + '(all longer will be returned w/o changes)', + default=3) + parser.add_argument('--batch_size', + type=int, + help='The size of hidden unit cell.', + default=128) + parser.add_argument('--lowercase_tokens', + type=int, + help='Whether to lowercase tokens.', + default=0) + parser.add_argument('--transformer_model', + choices=['bert', 'gpt2', 'transformerxl', 'xlnet', 'distilbert', 'roberta', 'albert'], + help='Name of the transformer model.', + default='roberta') + parser.add_argument('--iteration_count', + type=int, + help='The number of iterations of the model.', + default=5) + parser.add_argument('--additional_confidence', + type=float, + help='How many probability to add to $KEEP token.', + default=0) + parser.add_argument('--min_probability', + type=float, + default=0.0) + parser.add_argument('--min_error_probability', + type=float, + default=0.0) + parser.add_argument('--special_tokens_fix', + type=int, + help='Whether to fix problem with [CLS], [SEP] tokens tokenization. ' + 'For reproducing reported results it should be 0 for BERT/XLNet and 1 for RoBERTa.', + default=1) + parser.add_argument('--is_ensemble', + type=int, + help='Whether to do ensembling.', + default=0) + parser.add_argument('--weights', + help='Used to calculate weighted average', nargs='+', + default=None) + args = parser.parse_args() + main(args) diff --git a/nlptoolkit/gec/models/__init__.py b/nlptoolkit/gec/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nlptoolkit/gec/models/gector/__init__.py b/nlptoolkit/gec/models/gector/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nlptoolkit/gec/models/gector/bert_token_embedder.py b/nlptoolkit/gec/models/gector/bert_token_embedder.py new file mode 100644 index 0000000..ed7bea4 --- /dev/null +++ b/nlptoolkit/gec/models/gector/bert_token_embedder.py @@ -0,0 +1,270 @@ +"""Tweaked version of corresponding AllenNLP file""" +import logging +from copy import deepcopy +from typing import Dict + +import torch +import torch.nn.functional as F +from allennlp.modules.token_embedders.token_embedder import TokenEmbedder +from allennlp.nn import util +from transformers import AutoModel, PreTrainedModel + +logger = logging.getLogger(__name__) + + +class PretrainedBertModel: + """ + In some instances you may want to load the same BERT model twice + (e.g. to use as a token embedder and also as a pooling layer). + This factory provides a cache so that you don't actually have to load the model twice. 
+ """ + + _cache: Dict[str, PreTrainedModel] = {} + + @classmethod + def load(cls, model_name: str, cache_model: bool = True) -> PreTrainedModel: + if model_name in cls._cache: + return PretrainedBertModel._cache[model_name] + + model = AutoModel.from_pretrained(model_name) + if cache_model: + cls._cache[model_name] = model + + return model + + +class BertEmbedder(TokenEmbedder): + """ + A ``TokenEmbedder`` that produces BERT embeddings for your tokens. + Should be paired with a ``BertIndexer``, which produces wordpiece ids. + Most likely you probably want to use ``PretrainedBertEmbedder`` + for one of the named pretrained models, not this base class. + Parameters + ---------- + bert_model: ``BertModel`` + The BERT model being wrapped. + top_layer_only: ``bool``, optional (default = ``False``) + If ``True``, then only return the top layer instead of apply the scalar mix. + max_pieces : int, optional (default: 512) + The BERT embedder uses positional embeddings and so has a corresponding + maximum length for its input ids. Assuming the inputs are windowed + and padded appropriately by this length, the embedder will split them into a + large batch, feed them into BERT, and recombine the output as if it was a + longer sequence. + num_start_tokens : int, optional (default: 1) + The number of starting special tokens input to BERT (usually 1, i.e., [CLS]) + num_end_tokens : int, optional (default: 1) + The number of ending tokens input to BERT (usually 1, i.e., [SEP]) + scalar_mix_parameters: ``List[float]``, optional, (default = None) + If not ``None``, use these scalar mix parameters to weight the representations + produced by different layers. These mixing weights are not updated during + training. + """ + + def __init__( + self, + bert_model: PreTrainedModel, + top_layer_only: bool = False, + max_pieces: int = 512, + num_start_tokens: int = 1, + num_end_tokens: int = 1 + ) -> None: + super().__init__() + # self.bert_model = bert_model + self.bert_model = deepcopy(bert_model) + self.output_dim = bert_model.config.hidden_size + self.max_pieces = max_pieces + self.num_start_tokens = num_start_tokens + self.num_end_tokens = num_end_tokens + self._scalar_mix = None + + def set_weights(self, freeze): + for param in self.bert_model.parameters(): + param.requires_grad = not freeze + return + + def get_output_dim(self) -> int: + return self.output_dim + + def forward( + self, + input_ids: torch.LongTensor, + offsets: torch.LongTensor = None + ) -> torch.Tensor: + """ + Parameters + ---------- + input_ids : ``torch.LongTensor`` + The (batch_size, ..., max_sequence_length) tensor of wordpiece ids. + offsets : ``torch.LongTensor``, optional + The BERT embeddings are one per wordpiece. However it's possible/likely + you might want one per original token. In that case, ``offsets`` + represents the indices of the desired wordpiece for each original token. + Depending on how your token indexer is configured, this could be the + position of the last wordpiece for each token, or it could be the position + of the first wordpiece for each token. + For example, if you had the sentence "Definitely not", and if the corresponding + wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids + would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4]. + If offsets are provided, the returned tensor will contain only the wordpiece + embeddings at those positions, and (in particular) will contain one embedding + per token. 
If offsets are not provided, the entire tensor of wordpiece embeddings + will be returned. + """ + + batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1) + initial_dims = list(input_ids.shape[:-1]) + + # The embedder may receive an input tensor that has a sequence length longer than can + # be fit. In that case, we should expect the wordpiece indexer to create padded windows + # of length `self.max_pieces` for us, and have them concatenated into one long sequence. + # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..." + # We can then split the sequence into sub-sequences of that length, and concatenate them + # along the batch dimension so we effectively have one huge batch of partial sentences. + # This can then be fed into BERT without any sentence length issues. Keep in mind + # that the memory consumption can dramatically increase for large batches with extremely + # long sentences. + needs_split = full_seq_len > self.max_pieces + last_window_size = 0 + if needs_split: + # Split the flattened list by the window size, `max_pieces` + split_input_ids = list(input_ids.split(self.max_pieces, dim=-1)) + + # We want all sequences to be the same length, so pad the last sequence + last_window_size = split_input_ids[-1].size(-1) + padding_amount = self.max_pieces - last_window_size + split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0) + + # Now combine the sequences along the batch dimension + input_ids = torch.cat(split_input_ids, dim=0) + + input_mask = (input_ids != 0).long() + # input_ids may have extra dimensions, so we reshape down to 2-d + # before calling the BERT model and then reshape back at the end. + all_encoder_layers = self.bert_model( + input_ids=util.combine_initial_dims(input_ids), + attention_mask=util.combine_initial_dims(input_mask), + )[0] + if len(all_encoder_layers[0].shape) == 3: + all_encoder_layers = torch.stack(all_encoder_layers) + elif len(all_encoder_layers[0].shape) == 2: + all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0) + + if needs_split: + # First, unpack the output embeddings into one long sequence again + unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1) + unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2) + + # Next, select indices of the sequence such that it will result in embeddings representing the original + # sentence. To capture maximal context, the indices will be the middle part of each embedded window + # sub-sequence (plus any leftover start and final edge windows), e.g., + # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]" + # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start + # and final windows with indices [0, 1] and [14, 15] respectively. 
+ + # Find the stride as half the max pieces, ignoring the special start and end tokens + # Calculate an offset to extract the centermost embeddings of each window + stride = (self.max_pieces - self.num_start_tokens - self.num_end_tokens) // 2 + stride_offset = stride // 2 + self.num_start_tokens + + first_window = list(range(stride_offset)) + + max_context_windows = [ + i + for i in range(full_seq_len) + if stride_offset - 1 < i % self.max_pieces < stride_offset + stride + ] + + # Lookback what's left, unless it's the whole self.max_pieces window + if full_seq_len % self.max_pieces == 0: + lookback = self.max_pieces + else: + lookback = full_seq_len % self.max_pieces + + final_window_start = full_seq_len - lookback + stride_offset + stride + final_window = list(range(final_window_start, full_seq_len)) + + select_indices = first_window + max_context_windows + final_window + + initial_dims.append(len(select_indices)) + + recombined_embeddings = unpacked_embeddings[:, :, select_indices] + else: + recombined_embeddings = all_encoder_layers + + # Recombine the outputs of all layers + # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim) + # recombined = torch.cat(combined, dim=2) + input_mask = (recombined_embeddings != 0).long() + + if self._scalar_mix is not None: + mix = self._scalar_mix(recombined_embeddings, input_mask) + else: + mix = recombined_embeddings[-1] + + # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim) + + if offsets is None: + # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim) + dims = initial_dims if needs_split else input_ids.size() + return util.uncombine_initial_dims(mix, dims) + else: + # offsets is (batch_size, d1, ..., dn, orig_sequence_length) + offsets2d = util.combine_initial_dims(offsets) + # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length) + range_vector = util.get_range_vector( + offsets2d.size(0), device=util.get_device_of(mix) + ).unsqueeze(1) + # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length) + selected_embeddings = mix[range_vector, offsets2d] + + return util.uncombine_initial_dims(selected_embeddings, offsets.size()) + + +# @TokenEmbedder.register("bert-pretrained") +class PretrainedBertEmbedder(BertEmbedder): + + """ + Parameters + ---------- + pretrained_model: ``str`` + Either the name of the pretrained model to use (e.g. 'bert-base-uncased'), + or the path to the .tar.gz file with the model weights. + If the name is a key in the list of pretrained models at + https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L41 + the corresponding path will be used; otherwise it will be interpreted as a path or URL. + requires_grad : ``bool``, optional (default = False) + If True, compute gradient of BERT parameters for fine tuning. + top_layer_only: ``bool``, optional (default = ``False``) + If ``True``, then only return the top layer instead of apply the scalar mix. + scalar_mix_parameters: ``List[float]``, optional, (default = None) + If not ``None``, use these scalar mix parameters to weight the representations + produced by different layers. These mixing weights are not updated during + training. 
+ """ + + def __init__( + self, + pretrained_model: str, + requires_grad: bool = False, + top_layer_only: bool = False, + special_tokens_fix: int = 0, + ) -> None: + model = PretrainedBertModel.load(pretrained_model) + + for param in model.parameters(): + param.requires_grad = requires_grad + + super().__init__( + bert_model=model, + top_layer_only=top_layer_only + ) + + if special_tokens_fix: + try: + vocab_size = self.bert_model.embeddings.word_embeddings.num_embeddings + except AttributeError: + # reserve more space + vocab_size = self.bert_model.word_embedding.num_embeddings + 5 + self.bert_model.resize_token_embeddings(vocab_size + 1) diff --git a/nlptoolkit/gec/models/gector/datareader.py b/nlptoolkit/gec/models/gector/datareader.py new file mode 100644 index 0000000..d511afb --- /dev/null +++ b/nlptoolkit/gec/models/gector/datareader.py @@ -0,0 +1,151 @@ +"""Tweaked AllenNLP dataset reader.""" +import logging +import re +from random import random +from typing import Dict, List + +from allennlp.common.file_utils import cached_path +from allennlp.data.dataset_readers.dataset_reader import DatasetReader +from allennlp.data.fields import TextField, SequenceLabelField, MetadataField, Field +from allennlp.data.instance import Instance +from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer +from allennlp.data.tokenizers import Token +from overrides import overrides + +from utils.helpers import SEQ_DELIMETERS, START_TOKEN + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +@DatasetReader.register("seq2labels_datareader") +class Seq2LabelsDatasetReader(DatasetReader): + """ + Reads instances from a pretokenised file where each line is in the following format: + + WORD###TAG [TAB] WORD###TAG [TAB] ..... \n + + and converts it into a ``Dataset`` suitable for sequence tagging. You can also specify + alternative delimiters in the constructor. + + Parameters + ---------- + delimiters: ``dict`` + The dcitionary with all delimeters. + token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``) + We use this to define the input representation for the text. See :class:`TokenIndexer`. + Note that the `output` tags will always correspond to single token IDs based on how they + are pre-tokenised in the data file. 
+ max_len: if set than will truncate long sentences + """ + # fix broken sentences mostly in Lang8 + BROKEN_SENTENCES_REGEXP = re.compile(r'\.[a-zA-RT-Z]') + + def __init__(self, + token_indexers: Dict[str, TokenIndexer] = None, + delimeters: dict = SEQ_DELIMETERS, + skip_correct: bool = False, + skip_complex: int = 0, + lazy: bool = False, + max_len: int = None, + test_mode: bool = False, + tag_strategy: str = "keep_one", + tn_prob: float = 0, + tp_prob: float = 0, + broken_dot_strategy: str = "keep") -> None: + super().__init__(lazy) + self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} + self._delimeters = delimeters + self._max_len = max_len + self._skip_correct = skip_correct + self._skip_complex = skip_complex + self._tag_strategy = tag_strategy + self._broken_dot_strategy = broken_dot_strategy + self._test_mode = test_mode + self._tn_prob = tn_prob + self._tp_prob = tp_prob + + @overrides + def _read(self, file_path): + # if `file_path` is a URL, redirect to the cache + file_path = cached_path(file_path) + with open(file_path, "r") as data_file: + logger.info("Reading instances from lines in file at: %s", file_path) + for line in data_file: + line = line.strip("\n") + # skip blank and broken lines + if not line or (not self._test_mode and self._broken_dot_strategy == 'skip' + and self.BROKEN_SENTENCES_REGEXP.search(line) is not None): + continue + + tokens_and_tags = [pair.rsplit(self._delimeters['labels'], 1) + for pair in line.split(self._delimeters['tokens'])] + try: + tokens = [Token(token) for token, tag in tokens_and_tags] + tags = [tag for token, tag in tokens_and_tags] + except ValueError: + tokens = [Token(token[0]) for token in tokens_and_tags] + tags = None + + if tokens and tokens[0] != Token(START_TOKEN): + tokens = [Token(START_TOKEN)] + tokens + + words = [x.text for x in tokens] + if self._max_len is not None: + tokens = tokens[:self._max_len] + tags = None if tags is None else tags[:self._max_len] + instance = self.text_to_instance(tokens, tags, words) + if instance: + yield instance + + def extract_tags(self, tags: List[str]): + op_del = self._delimeters['operations'] + + labels = [x.split(op_del) for x in tags] + + comlex_flag_dict = {} + # get flags + for i in range(5): + idx = i + 1 + comlex_flag_dict[idx] = sum([len(x) > idx for x in labels]) + + if self._tag_strategy == "keep_one": + # get only first candidates for r_tags in right and the last for left + labels = [x[0] for x in labels] + elif self._tag_strategy == "merge_all": + # consider phrases as a words + pass + else: + raise Exception("Incorrect tag strategy") + + detect_tags = ["CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels] + return labels, detect_tags, comlex_flag_dict + + def text_to_instance(self, tokens: List[Token], tags: List[str] = None, + words: List[str] = None) -> Instance: # type: ignore + """ + We take `pre-tokenized` input here, because we don't have a tokenizer in this class. 
+ """ + # pylint: disable=arguments-differ + fields: Dict[str, Field] = {} + sequence = TextField(tokens, self._token_indexers) + fields["tokens"] = sequence + fields["metadata"] = MetadataField({"words": words}) + if tags is not None: + labels, detect_tags, complex_flag_dict = self.extract_tags(tags) + if self._skip_complex and complex_flag_dict[self._skip_complex] > 0: + return None + rnd = random() + # skip TN + if self._skip_correct and all(x == "CORRECT" for x in detect_tags): + if rnd > self._tn_prob: + return None + # skip TP + else: + if rnd > self._tp_prob: + return None + + fields["labels"] = SequenceLabelField(labels, sequence, + label_namespace="labels") + fields["d_tags"] = SequenceLabelField(detect_tags, sequence, + label_namespace="d_tags") + return Instance(fields) diff --git a/nlptoolkit/gec/models/gector/gec_model.py b/nlptoolkit/gec/models/gector/gec_model.py new file mode 100644 index 0000000..67ff49f --- /dev/null +++ b/nlptoolkit/gec/models/gector/gec_model.py @@ -0,0 +1,322 @@ +"""Wrapper of AllenNLP model. Fixes errors based on model predictions""" +import logging +import os +import sys +from time import time + +import torch +from allennlp.data.dataset import Batch +from allennlp.data.fields import TextField +from allennlp.data.instance import Instance +from allennlp.data.tokenizers import Token +from allennlp.data.vocabulary import Vocabulary +from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder +from allennlp.nn import util + +from .bert_token_embedder import PretrainedBertEmbedder +from .seq2labels_model import Seq2Labels +from .wordpiece_indexer import PretrainedBertIndexer +from .utils.helpers import PAD, UNK, get_target_sent_by_edits + +logging.getLogger("werkzeug").setLevel(logging.ERROR) +logger = logging.getLogger(__file__) + + +def get_weights_name(transformer_name, lowercase): + if transformer_name == 'bert' and lowercase: + return 'bert-base-uncased' + if transformer_name == 'bert' and not lowercase: + return 'bert-base-cased' + if transformer_name == 'distilbert': + if not lowercase: + print('Warning! This model was trained only on uncased sentences.') + return 'distilbert-base-uncased' + if transformer_name == 'albert': + if not lowercase: + print('Warning! This model was trained only on uncased sentences.') + return 'albert-base-v1' + if lowercase: + print('Warning! 
This model was trained only on cased sentences.') + if transformer_name == 'roberta': + return 'roberta-base' + if transformer_name == 'gpt2': + return 'gpt2' + if transformer_name == 'transformerxl': + return 'transfo-xl-wt103' + if transformer_name == 'xlnet': + return 'xlnet-base-cased' + + +class GecBERTModel(object): + def __init__(self, vocab_path=None, model_paths=None, + weigths=None, + max_len=50, + min_len=3, + lowercase_tokens=False, + log=False, + iterations=3, + min_probability=0.0, + model_name='roberta', + special_tokens_fix=1, + is_ensemble=True, + min_error_probability=0.0, + confidence=0, + resolve_cycles=False, + ): + self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.max_len = max_len + self.min_len = min_len + self.lowercase_tokens = lowercase_tokens + self.min_probability = min_probability + self.min_error_probability = min_error_probability + self.vocab = Vocabulary.from_files(vocab_path) + self.log = log + self.iterations = iterations + self.confidence = confidence + self.resolve_cycles = resolve_cycles + # set training parameters and operations + + self.indexers = [] + self.models = [] + print("Model paths:", model_paths) + for model_path in model_paths: + print("Model path:", model_path) + if is_ensemble: + model_name, special_tokens_fix = self._get_model_data(model_path) + weights_name = get_weights_name(model_name, lowercase_tokens) + self.indexers.append(self._get_indexer(weights_name, special_tokens_fix)) + model = Seq2Labels(vocab=self.vocab, + text_field_embedder=self._get_embbeder(weights_name, special_tokens_fix), + confidence=self.confidence + ).to(self.device) + if torch.cuda.is_available(): + model.load_state_dict(torch.load(model_path)) + else: + model.load_state_dict(torch.load(model_path, + map_location=torch.device('cpu'))) + model.eval() + self.models.append(model) + + @staticmethod + def _get_model_data(model_path): + model_name = model_path.split('/')[-1] + tr_model, stf = model_name.split('_')[:2] + return tr_model, int(stf) + + def _restore_model(self, input_path): + if os.path.isdir(input_path): + print("Model could not be restored from directory", file=sys.stderr) + filenames = [] + else: + filenames = [input_path] + for model_path in filenames: + try: + if torch.cuda.is_available(): + loaded_model = torch.load(model_path) + else: + loaded_model = torch.load(model_path, + map_location=lambda storage, + loc: storage) + except: + print(f"{model_path} is not valid model", file=sys.stderr) + own_state = self.model.state_dict() + for name, weights in loaded_model.items(): + if name not in own_state: + continue + try: + if len(filenames) == 1: + own_state[name].copy_(weights) + else: + own_state[name] += weights + except RuntimeError: + continue + print("Model is restored", file=sys.stderr) + + def predict(self, batches): + t11 = time() + predictions = [] + for batch, model in zip(batches, self.models): + batch = util.move_to_device(batch.as_tensor_dict(), 0 if torch.cuda.is_available() else -1) + with torch.no_grad(): + prediction = model.forward(**batch) + predictions.append(prediction) + + preds, idx, error_probs = self._convert(predictions) + t55 = time() + if self.log: + print(f"Inference time {t55 - t11}") + return preds, idx, error_probs + + def get_token_action(self, token, index, prob, sugg_token): + """Get lost of suggested actions for token.""" + # cases when we don't need to do anything + if prob < 
self.min_probability or sugg_token in [UNK, PAD, '$KEEP']: + return None + + if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE': + start_pos = index + end_pos = index + 1 + elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"): + start_pos = index + 1 + end_pos = index + 1 + + if sugg_token == "$DELETE": + sugg_token_clear = "" + elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"): + sugg_token_clear = sugg_token[:] + else: + sugg_token_clear = sugg_token[sugg_token.index('_') + 1:] + + return start_pos - 1, end_pos - 1, sugg_token_clear, prob + + def _get_embbeder(self, weigths_name, special_tokens_fix): + embedders = {'bert': PretrainedBertEmbedder( + pretrained_model=weigths_name, + requires_grad=False, + top_layer_only=True, + special_tokens_fix=special_tokens_fix) + } + text_field_embedder = BasicTextFieldEmbedder( + token_embedders=embedders, + embedder_to_indexer_map={"bert": ["bert", "bert-offsets"]}, + allow_unmatched_keys=True) + return text_field_embedder + + def _get_indexer(self, weights_name, special_tokens_fix): + bert_token_indexer = PretrainedBertIndexer( + pretrained_model=weights_name, + do_lowercase=self.lowercase_tokens, + max_pieces_per_token=5, + use_starting_offsets=True, + truncate_long_sequences=True, + special_tokens_fix=special_tokens_fix, + is_test=True + ) + return {'bert': bert_token_indexer} + + def preprocess(self, token_batch): + seq_lens = [len(sequence) for sequence in token_batch if sequence] + if not seq_lens: + return [] + max_len = min(max(seq_lens), self.max_len) + batches = [] + for indexer in self.indexers: + batch = [] + for sequence in token_batch: + tokens = sequence[:max_len] + tokens = [Token(token) for token in ['$START'] + tokens] + batch.append(Instance({'tokens': TextField(tokens, indexer)})) + batch = Batch(batch) + batch.index_instances(self.vocab) + batches.append(batch) + + return batches + + def _convert(self, data): + all_class_probs = torch.zeros_like(data[0]['class_probabilities_labels']) + error_probs = torch.zeros_like(data[0]['max_error_probability']) + for output, weight in zip(data, self.model_weights): + all_class_probs += weight * output['class_probabilities_labels'] / sum(self.model_weights) + error_probs += weight * output['max_error_probability'] / sum(self.model_weights) + + max_vals = torch.max(all_class_probs, dim=-1) + probs = max_vals[0].tolist() + idx = max_vals[1].tolist() + return probs, idx, error_probs.tolist() + + def update_final_batch(self, final_batch, pred_ids, pred_batch, + prev_preds_dict): + new_pred_ids = [] + total_updated = 0 + for i, orig_id in enumerate(pred_ids): + orig = final_batch[orig_id] + pred = pred_batch[i] + prev_preds = prev_preds_dict[orig_id] + if orig != pred and pred not in prev_preds: + final_batch[orig_id] = pred + new_pred_ids.append(orig_id) + prev_preds_dict[orig_id].append(pred) + total_updated += 1 + elif orig != pred and pred in prev_preds: + # update final batch, but stop iterations + final_batch[orig_id] = pred + total_updated += 1 + else: + continue + return final_batch, new_pred_ids, total_updated + + def postprocess_batch(self, batch, all_probabilities, all_idxs, + error_probs, + max_len=50): + all_results = [] + noop_index = self.vocab.get_token_index("$KEEP", "labels") + for tokens, probabilities, idxs, error_prob in zip(batch, + all_probabilities, + all_idxs, + error_probs): + length = min(len(tokens), max_len) + edits = [] + + # skip whole sentences if there no errors 
+ if max(idxs) == 0: + all_results.append(tokens) + continue + + # skip whole sentence if probability of correctness is not high + if error_prob < self.min_error_probability: + all_results.append(tokens) + continue + + for i in range(length): + token = tokens[i - 1] # because of START token + # skip if there is no error + if idxs[i] == noop_index: + continue + + sugg_token = self.vocab.get_token_from_index(idxs[i], + namespace='labels') + action = self.get_token_action(token, i, probabilities[i], + sugg_token) + if not action: + continue + + edits.append(action) + all_results.append(get_target_sent_by_edits(tokens, edits)) + return all_results + + def handle_batch(self, full_batch): + """ + Handle batch of requests. + """ + final_batch = full_batch[:] + batch_size = len(full_batch) + prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))} + short_ids = [i for i in range(len(full_batch)) + if len(full_batch[i]) < self.min_len] + pred_ids = [i for i in range(len(full_batch)) if i not in short_ids] + total_updates = 0 + + for n_iter in range(self.iterations): + orig_batch = [final_batch[i] for i in pred_ids] + + sequences = self.preprocess(orig_batch) + + if not sequences: + break + probabilities, idxs, error_probs = self.predict(sequences) + + pred_batch = self.postprocess_batch(orig_batch, probabilities, + idxs, error_probs) + if self.log: + print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.") + + final_batch, pred_ids, cnt = \ + self.update_final_batch(final_batch, pred_ids, pred_batch, + prev_preds_dict) + total_updates += cnt + + if not pred_ids: + break + + return final_batch, total_updates diff --git a/nlptoolkit/gec/models/gector/seq2labels_model.py b/nlptoolkit/gec/models/gector/seq2labels_model.py new file mode 100644 index 0000000..13efc88 --- /dev/null +++ b/nlptoolkit/gec/models/gector/seq2labels_model.py @@ -0,0 +1,193 @@ +"""Basic model. Predicts tags for every token""" +from typing import Dict, Optional, List, Any + +import numpy +import torch +import torch.nn.functional as F +from allennlp.data import Vocabulary +from allennlp.models.model import Model +from allennlp.modules import TimeDistributed, TextFieldEmbedder +from allennlp.nn import InitializerApplicator, RegularizerApplicator +from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits +from allennlp.training.metrics import CategoricalAccuracy +from overrides import overrides +from torch.nn.modules.linear import Linear + + +@Model.register("seq2labels") +class Seq2Labels(Model): + """ + This ``Seq2Labels`` simply encodes a sequence of text with a stacked ``Seq2SeqEncoder``, then + predicts a tag (or couple tags) for each token in the sequence. + + Parameters + ---------- + vocab : ``Vocabulary``, required + A Vocabulary, required in order to compute sizes for input/output projections. + text_field_embedder : ``TextFieldEmbedder``, required + Used to embed the ``tokens`` ``TextField`` we get as input to the model. + encoder : ``Seq2SeqEncoder`` + The encoder (with its own internal stacking) that we will use in between embedding tokens + and predicting output tags. + calculate_span_f1 : ``bool``, optional (default=``None``) + Calculate span-level F1 metrics during training. If this is ``True``, then + ``label_encoding`` is required. If ``None`` and + label_encoding is specified, this is set to ``True``. + If ``None`` and label_encoding is not specified, it defaults + to ``False``. 
+ label_encoding : ``str``, optional (default=``None``) + Label encoding to use when calculating span f1. + Valid options are "BIO", "BIOUL", "IOB1", "BMES". + Required if ``calculate_span_f1`` is true. + label_namespace : ``str``, optional (default=``labels``) + This is needed to compute the SpanBasedF1Measure metric, if desired. + Unless you did something unusual, the default value should be what you want. + verbose_metrics : ``bool``, optional (default = False) + If true, metrics will be returned per label class in addition + to the overall statistics. + initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) + Used to initialize the model parameters. + regularizer : ``RegularizerApplicator``, optional (default=``None``) + If provided, will be used to calculate the regularization penalty during training. + """ + + def __init__(self, vocab: Vocabulary, + text_field_embedder: TextFieldEmbedder, + predictor_dropout=0.0, + labels_namespace: str = "labels", + detect_namespace: str = "d_tags", + verbose_metrics: bool = False, + label_smoothing: float = 0.0, + confidence: float = 0.0, + initializer: InitializerApplicator = InitializerApplicator(), + regularizer: Optional[RegularizerApplicator] = None) -> None: + super(Seq2Labels, self).__init__(vocab, regularizer) + + self.label_namespaces = [labels_namespace, + detect_namespace] + self.text_field_embedder = text_field_embedder + self.num_labels_classes = self.vocab.get_vocab_size(labels_namespace) + self.num_detect_classes = self.vocab.get_vocab_size(detect_namespace) + self.label_smoothing = label_smoothing + self.confidence = confidence + self.incorr_index = self.vocab.get_token_index("INCORRECT", + namespace=detect_namespace) + + self._verbose_metrics = verbose_metrics + self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout)) + + self.tag_labels_projection_layer = TimeDistributed( + Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_labels_classes)) + + self.tag_detect_projection_layer = TimeDistributed( + Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_detect_classes)) + + self.metrics = {"accuracy": CategoricalAccuracy()} + + initializer(self) + + @overrides + def forward(self, # type: ignore + tokens: Dict[str, torch.LongTensor], + labels: torch.LongTensor = None, + d_tags: torch.LongTensor = None, + metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: + # pylint: disable=arguments-differ + """ + Parameters + ---------- + tokens : Dict[str, torch.LongTensor], required + The output of ``TextField.as_array()``, which should typically be passed directly to a + ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` + tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": + Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used + for the ``TokenIndexers`` when you created the ``TextField`` representing your + sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, + which knows how to combine different word representations into a single vector per + token in your input. + lables : torch.LongTensor, optional (default = None) + A torch tensor representing the sequence of integer gold class labels of shape + ``(batch_size, num_tokens)``. 
+ d_tags : torch.LongTensor, optional (default = None) + A torch tensor representing the sequence of integer gold class labels of shape + ``(batch_size, num_tokens)``. + metadata : ``List[Dict[str, Any]]``, optional, (default = None) + metadata containing the original words in the sentence to be tagged under a 'words' key. + + Returns + ------- + An output dictionary consisting of: + logits : torch.FloatTensor + A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing + unnormalised log probabilities of the tag classes. + class_probabilities : torch.FloatTensor + A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing + a distribution of the tag classes per word. + loss : torch.FloatTensor, optional + A scalar loss to be optimised. + + """ + encoded_text = self.text_field_embedder(tokens) + batch_size, sequence_length, _ = encoded_text.size() + mask = get_text_field_mask(tokens) + logits_labels = self.tag_labels_projection_layer(self.predictor_dropout(encoded_text)) + logits_d = self.tag_detect_projection_layer(encoded_text) + + class_probabilities_labels = F.softmax(logits_labels, dim=-1).view( + [batch_size, sequence_length, self.num_labels_classes]) + class_probabilities_d = F.softmax(logits_d, dim=-1).view( + [batch_size, sequence_length, self.num_detect_classes]) + error_probs = class_probabilities_d[:, :, self.incorr_index] * mask + incorr_prob = torch.max(error_probs, dim=-1)[0] + + if self.confidence > 0: + probability_change = [self.confidence] + [0] * (self.num_labels_classes - 1) + class_probabilities_labels += torch.cuda.FloatTensor(probability_change).repeat( + (batch_size, sequence_length, 1)) + + output_dict = {"logits_labels": logits_labels, + "logits_d_tags": logits_d, + "class_probabilities_labels": class_probabilities_labels, + "class_probabilities_d_tags": class_probabilities_d, + "max_error_probability": incorr_prob} + if labels is not None and d_tags is not None: + loss_labels = sequence_cross_entropy_with_logits(logits_labels, labels, mask, + label_smoothing=self.label_smoothing) + loss_d = sequence_cross_entropy_with_logits(logits_d, d_tags, mask) + for metric in self.metrics.values(): + metric(logits_labels, labels, mask.float()) + metric(logits_d, d_tags, mask.float()) + output_dict["loss"] = loss_labels + loss_d + + if metadata is not None: + output_dict["words"] = [x["words"] for x in metadata] + return output_dict + + @overrides + def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Does a simple position-wise argmax over each token, converts indices to string labels, and + adds a ``"tags"`` key to the dictionary with the result. 
+ """ + for label_namespace in self.label_namespaces: + all_predictions = output_dict[f'class_probabilities_{label_namespace}'] + all_predictions = all_predictions.cpu().data.numpy() + if all_predictions.ndim == 3: + predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])] + else: + predictions_list = [all_predictions] + all_tags = [] + + for predictions in predictions_list: + argmax_indices = numpy.argmax(predictions, axis=-1) + tags = [self.vocab.get_token_from_index(x, namespace=label_namespace) + for x in argmax_indices] + all_tags.append(tags) + output_dict[f'{label_namespace}'] = all_tags + return output_dict + + @overrides + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + metrics_to_return = {metric_name: metric.get_metric(reset) for + metric_name, metric in self.metrics.items()} + return metrics_to_return diff --git a/nlptoolkit/gec/models/gector/trainer.py b/nlptoolkit/gec/models/gector/trainer.py new file mode 100644 index 0000000..87f79c0 --- /dev/null +++ b/nlptoolkit/gec/models/gector/trainer.py @@ -0,0 +1,845 @@ +"""Tweaked version of corresponding AllenNLP file""" +import datetime +import logging +import math +import os +import time +import traceback +from typing import Dict, Optional, List, Tuple, Union, Iterable, Any + +import torch +import torch.optim.lr_scheduler +from allennlp.common import Params +from allennlp.common.checks import ConfigurationError, parse_cuda_device +from allennlp.common.tqdm import Tqdm +from allennlp.common.util import dump_metrics, gpu_memory_mb, peak_memory_mb, lazy_groups_of +from allennlp.data.instance import Instance +from allennlp.data.iterators.data_iterator import DataIterator, TensorDict +from allennlp.models.model import Model +from allennlp.nn import util as nn_util +from allennlp.training import util as training_util +from allennlp.training.checkpointer import Checkpointer +from allennlp.training.learning_rate_schedulers import LearningRateScheduler +from allennlp.training.metric_tracker import MetricTracker +from allennlp.training.momentum_schedulers import MomentumScheduler +from allennlp.training.moving_average import MovingAverage +from allennlp.training.optimizers import Optimizer +from allennlp.training.tensorboard_writer import TensorboardWriter +from allennlp.training.trainer_base import TrainerBase + +logger = logging.getLogger(__name__) + + +class Trainer(TrainerBase): + def __init__( + self, + model: Model, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + iterator: DataIterator, + train_dataset: Iterable[Instance], + validation_dataset: Optional[Iterable[Instance]] = None, + patience: Optional[int] = None, + validation_metric: str = "-loss", + validation_iterator: DataIterator = None, + shuffle: bool = True, + num_epochs: int = 20, + accumulated_batch_count: int = 1, + serialization_dir: Optional[str] = None, + num_serialized_models_to_keep: int = 20, + keep_serialized_model_every_num_seconds: int = None, + checkpointer: Checkpointer = None, + model_save_interval: float = None, + cuda_device: Union[int, List] = -1, + grad_norm: Optional[float] = None, + grad_clipping: Optional[float] = None, + learning_rate_scheduler: Optional[LearningRateScheduler] = None, + momentum_scheduler: Optional[MomentumScheduler] = None, + summary_interval: int = 100, + histogram_interval: int = None, + should_log_parameter_statistics: bool = True, + should_log_learning_rate: bool = False, + log_batch_size_period: Optional[int] = None, + moving_average: Optional[MovingAverage] = 
None, + cold_step_count: int = 0, + cold_lr: float = 1e-3, + cuda_verbose_step=None, + ) -> None: + """ + A trainer for doing supervised learning. It just takes a labeled dataset + and a ``DataIterator``, and uses the supplied ``Optimizer`` to learn the weights + for your model over some fixed number of epochs. You can also pass in a validation + dataset and enable early stopping. There are many other bells and whistles as well. + + Parameters + ---------- + model : ``Model``, required. + An AllenNLP model to be optimized. Pytorch Modules can also be optimized if + their ``forward`` method returns a dictionary with a "loss" key, containing a + scalar tensor representing the loss function to be optimized. + + If you are training your model using GPUs, your model should already be + on the correct device. (If you use `Trainer.from_params` this will be + handled for you.) + optimizer : ``torch.nn.Optimizer``, required. + An instance of a Pytorch Optimizer, instantiated with the parameters of the + model to be optimized. + iterator : ``DataIterator``, required. + A method for iterating over a ``Dataset``, yielding padded indexed batches. + train_dataset : ``Dataset``, required. + A ``Dataset`` to train on. The dataset should have already been indexed. + validation_dataset : ``Dataset``, optional, (default = None). + A ``Dataset`` to evaluate on. The dataset should have already been indexed. + patience : Optional[int] > 0, optional (default=None) + Number of epochs to be patient before early stopping: the training is stopped + after ``patience`` epochs with no improvement. If given, it must be ``> 0``. + If None, early stopping is disabled. + validation_metric : str, optional (default="loss") + Validation metric to measure for whether to stop training using patience + and whether to serialize an ``is_best`` model each epoch. The metric name + must be prepended with either "+" or "-", which specifies whether the metric + is an increasing or decreasing function. + validation_iterator : ``DataIterator``, optional (default=None) + An iterator to use for the validation set. If ``None``, then + use the training `iterator`. + shuffle: ``bool``, optional (default=True) + Whether to shuffle the instances in the iterator or not. + num_epochs : int, optional (default = 20) + Number of training epochs. + serialization_dir : str, optional (default=None) + Path to directory for saving and loading model files. Models will not be saved if + this parameter is not passed. + num_serialized_models_to_keep : ``int``, optional (default=20) + Number of previous model checkpoints to retain. Default is to keep 20 checkpoints. + A value of None or -1 means all checkpoints will be kept. + keep_serialized_model_every_num_seconds : ``int``, optional (default=None) + If num_serialized_models_to_keep is not None, then occasionally it's useful to + save models at a given interval in addition to the last num_serialized_models_to_keep. + To do so, specify keep_serialized_model_every_num_seconds as the number of seconds + between permanently saved checkpoints. Note that this option is only used if + num_serialized_models_to_keep is not None, otherwise all checkpoints are kept. + checkpointer : ``Checkpointer``, optional (default=None) + An instance of class Checkpointer to use instead of the default. If a checkpointer is specified, + the arguments num_serialized_models_to_keep and keep_serialized_model_every_num_seconds should + not be specified. 
The caller is responsible for initializing the checkpointer so that it is + consistent with serialization_dir. + model_save_interval : ``float``, optional (default=None) + If provided, then serialize models every ``model_save_interval`` + seconds within single epochs. In all cases, models are also saved + at the end of every epoch if ``serialization_dir`` is provided. + cuda_device : ``Union[int, List[int]]``, optional (default = -1) + An integer or list of integers specifying the CUDA device(s) to use. If -1, the CPU is used. + grad_norm : ``float``, optional, (default = None). + If provided, gradient norms will be rescaled to have a maximum of this value. + grad_clipping : ``float``, optional (default = ``None``). + If provided, gradients will be clipped `during the backward pass` to have an (absolute) + maximum of this value. If you are getting ``NaNs`` in your gradients during training + that are not solved by using ``grad_norm``, you may need this. + learning_rate_scheduler : ``LearningRateScheduler``, optional (default = None) + If specified, the learning rate will be decayed with respect to + this schedule at the end of each epoch (or batch, if the scheduler implements + the ``step_batch`` method). If you use :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`, + this will use the ``validation_metric`` provided to determine if learning has plateaued. + To support updating the learning rate on every batch, this can optionally implement + ``step_batch(batch_num_total)`` which updates the learning rate given the batch number. + momentum_scheduler : ``MomentumScheduler``, optional (default = None) + If specified, the momentum will be updated at the end of each batch or epoch + according to the schedule. + summary_interval: ``int``, optional, (default = 100) + Number of batches between logging scalars to tensorboard + histogram_interval : ``int``, optional, (default = ``None``) + If not None, then log histograms to tensorboard every ``histogram_interval`` batches. + When this parameter is specified, the following additional logging is enabled: + * Histograms of model parameters + * The ratio of parameter update norm to parameter norm + * Histogram of layer activations + We log histograms of the parameters returned by + ``model.get_parameters_for_histogram_tensorboard_logging``. + The layer activations are logged for any modules in the ``Model`` that have + the attribute ``should_log_activations`` set to ``True``. Logging + histograms requires a number of GPU-CPU copies during training and is typically + slow, so we recommend logging histograms relatively infrequently. + Note: only Modules that return tensors, tuples of tensors or dicts + with tensors as values currently support activation logging. + should_log_parameter_statistics : ``bool``, optional, (default = True) + Whether to send parameter statistics (mean and standard deviation + of parameters and gradients) to tensorboard. + should_log_learning_rate : ``bool``, optional, (default = False) + Whether to send parameter specific learning rate to tensorboard. + log_batch_size_period : ``int``, optional, (default = ``None``) + If defined, how often to log the average batch size. + moving_average: ``MovingAverage``, optional, (default = None) + If provided, we will maintain moving averages for all parameters. During training, we + employ a shadow variable for each parameter, which maintains the moving average. During + evaluation, we backup the original parameters and assign the moving averages to corresponding + parameters. 
Be careful that when saving the checkpoint, we will save the moving averages of + parameters. This is necessary because we want the saved model to perform as well as the validated + model if we load it later. But this may cause problems if you restart the training from checkpoint. + """ + super().__init__(serialization_dir, cuda_device) + + # I am not calling move_to_gpu here, because if the model is + # not already on the GPU then the optimizer is going to be wrong. + self.model = model + + self.iterator = iterator + self._validation_iterator = validation_iterator + self.shuffle = shuffle + self.optimizer = optimizer + self.scheduler = scheduler + self.train_data = train_dataset + self._validation_data = validation_dataset + self.accumulated_batch_count = accumulated_batch_count + self.cold_step_count = cold_step_count + self.cold_lr = cold_lr + self.cuda_verbose_step = cuda_verbose_step + + if patience is None: # no early stopping + if validation_dataset: + logger.warning( + "You provided a validation dataset but patience was set to None, " + "meaning that early stopping is disabled" + ) + elif (not isinstance(patience, int)) or patience <= 0: + raise ConfigurationError( + '{} is an invalid value for "patience": it must be a positive integer ' + "or None (if you want to disable early stopping)".format(patience) + ) + + # For tracking is_best_so_far and should_stop_early + self._metric_tracker = MetricTracker(patience, validation_metric) + # Get rid of + or - + self._validation_metric = validation_metric[1:] + + self._num_epochs = num_epochs + + if checkpointer is not None: + # We can't easily check if these parameters were passed in, so check against their default values. + # We don't check against serialization_dir since it is also used by the parent class. + if num_serialized_models_to_keep != 20 \ + or keep_serialized_model_every_num_seconds is not None: + raise ConfigurationError( + "When passing a custom Checkpointer, you may not also pass in separate checkpointer " + "args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'." + ) + self._checkpointer = checkpointer + else: + self._checkpointer = Checkpointer( + serialization_dir, + keep_serialized_model_every_num_seconds, + num_serialized_models_to_keep, + ) + + self._model_save_interval = model_save_interval + + self._grad_norm = grad_norm + self._grad_clipping = grad_clipping + + self._learning_rate_scheduler = learning_rate_scheduler + self._momentum_scheduler = momentum_scheduler + self._moving_average = moving_average + + # We keep the total batch number as an instance variable because it + # is used inside a closure for the hook which logs activations in + # ``_enable_activation_logging``. + self._batch_num_total = 0 + + self._tensorboard = TensorboardWriter( + get_batch_num_total=lambda: self._batch_num_total, + serialization_dir=serialization_dir, + summary_interval=summary_interval, + histogram_interval=histogram_interval, + should_log_parameter_statistics=should_log_parameter_statistics, + should_log_learning_rate=should_log_learning_rate, + ) + + self._log_batch_size_period = log_batch_size_period + + self._last_log = 0.0 # time of last logging + + # Enable activation logging. 
+ if histogram_interval is not None: + self._tensorboard.enable_activation_logging(self.model) + + def rescale_gradients(self) -> Optional[float]: + return training_util.rescale_gradients(self.model, self._grad_norm) + + def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor: + """ + Does a forward pass on the given batches and returns the ``loss`` value in the result. + If ``for_training`` is `True` also applies regularization penalty. + """ + if self._multiple_gpu: + output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices) + else: + assert len(batch_group) == 1 + batch = batch_group[0] + batch = nn_util.move_to_device(batch, self._cuda_devices[0]) + output_dict = self.model(**batch) + + try: + loss = output_dict["loss"] + if for_training: + loss += self.model.get_regularization_penalty() + except KeyError: + if for_training: + raise RuntimeError( + "The model you are trying to optimize does not contain a" + " 'loss' key in the output of model.forward(inputs)." + ) + loss = None + + return loss + + def _train_epoch(self, epoch: int) -> Dict[str, float]: + """ + Trains one epoch and returns metrics. + """ + logger.info("Epoch %d/%d", epoch, self._num_epochs - 1) + peak_cpu_usage = peak_memory_mb() + logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}") + gpu_usage = [] + for gpu, memory in gpu_memory_mb().items(): + gpu_usage.append((gpu, memory)) + logger.info(f"GPU {gpu} memory usage MB: {memory}") + + train_loss = 0.0 + # Set the model to "train" mode. + self.model.train() + + num_gpus = len(self._cuda_devices) + + # Get tqdm for the training batches + raw_train_generator = self.iterator(self.train_data, num_epochs=1, shuffle=self.shuffle) + train_generator = lazy_groups_of(raw_train_generator, num_gpus) + num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data) / num_gpus) + residue = num_training_batches % self.accumulated_batch_count + self._last_log = time.time() + last_save_time = time.time() + + batches_this_epoch = 0 + if self._batch_num_total is None: + self._batch_num_total = 0 + + histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging()) + + logger.info("Training") + train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_training_batches) + cumulative_batch_size = 0 + self.optimizer.zero_grad() + for batch_group in train_generator_tqdm: + batches_this_epoch += 1 + self._batch_num_total += 1 + batch_num_total = self._batch_num_total + + iter_len = self.accumulated_batch_count \ + if batches_this_epoch <= (num_training_batches - residue) else residue + + if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0: + print(f'Before forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}') + print(f'Before forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}') + try: + loss = self.batch_loss(batch_group, for_training=True) / iter_len + except RuntimeError as e: + print(e) + for x in batch_group: + all_words = [len(y['words']) for y in x['metadata']] + print(f"Total sents: {len(all_words)}. " + f"Min {min(all_words)}. 
Max {max(all_words)}") + for elem in ['labels', 'd_tags']: + tt = x[elem] + print( + f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}") + for elem in ["bert", "mask", "bert-offsets"]: + tt = x['tokens'][elem] + print( + f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}") + raise e + + if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0: + print(f'After forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}') + print(f'After forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}') + + if torch.isnan(loss): + raise ValueError("nan loss encountered") + + loss.backward() + + if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0: + print(f'After backprop - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}') + print(f'After backprop - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}') + + train_loss += loss.item() * iter_len + + del batch_group, loss + torch.cuda.empty_cache() + + if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0: + print(f'After collecting garbage - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}') + print(f'After collecting garbage - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}') + + batch_grad_norm = self.rescale_gradients() + + # This does nothing if batch_num_total is None or you are using a + # scheduler which doesn't update per batch. + if self._learning_rate_scheduler: + self._learning_rate_scheduler.step_batch(batch_num_total) + if self._momentum_scheduler: + self._momentum_scheduler.step_batch(batch_num_total) + + if self._tensorboard.should_log_histograms_this_batch(): + # get the magnitude of parameter updates for logging + # We need a copy of current parameters to compute magnitude of updates, + # and copy them to CPU so large models won't go OOM on the GPU. 
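Editor's note: the division by `iter_len` above, together with the `optimizer.step()` condition a few lines below, is plain gradient accumulation. A stripped-down sketch of the same pattern follows, using hypothetical `model`, `loader` and `accum_steps`; the trainer additionally shrinks the divisor for the final partial window via `residue`, which this sketch simplifies to a fixed `accum_steps`.

import torch
import torch.nn.functional as F

def train_with_accumulation(model, loader, optimizer, accum_steps):
    """Accumulate gradients over accum_steps mini-batches before each optimizer step."""
    model.train()
    optimizer.zero_grad()
    step = 0
    for step, (inputs, targets) in enumerate(loader, start=1):
        loss = F.cross_entropy(model(inputs), targets)
        (loss / accum_steps).backward()   # scale so the summed gradient matches one large batch
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
    if step % accum_steps != 0:           # flush a final, smaller accumulation window
        optimizer.step()
        optimizer.zero_grad()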
+ param_updates = { + name: param.detach().cpu().clone() + for name, param in self.model.named_parameters() + } + if batches_this_epoch % self.accumulated_batch_count == 0 or \ + batches_this_epoch == num_training_batches: + self.optimizer.step() + self.optimizer.zero_grad() + for name, param in self.model.named_parameters(): + param_updates[name].sub_(param.detach().cpu()) + update_norm = torch.norm(param_updates[name].view(-1)) + param_norm = torch.norm(param.view(-1)).cpu() + self._tensorboard.add_train_scalar( + "gradient_update/" + name, update_norm / (param_norm + 1e-7) + ) + else: + if batches_this_epoch % self.accumulated_batch_count == 0 or \ + batches_this_epoch == num_training_batches: + self.optimizer.step() + self.optimizer.zero_grad() + + # Update moving averages + if self._moving_average is not None: + self._moving_average.apply(batch_num_total) + + # Update the description with the latest metrics + metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch) + description = training_util.description_from_metrics(metrics) + + train_generator_tqdm.set_description(description, refresh=False) + + # Log parameter values to Tensorboard + if self._tensorboard.should_log_this_batch(): + self._tensorboard.log_parameter_and_gradient_statistics(self.model, batch_grad_norm) + self._tensorboard.log_learning_rates(self.model, self.optimizer) + + self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"]) + self._tensorboard.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()}) + + if self._tensorboard.should_log_histograms_this_batch(): + self._tensorboard.log_histograms(self.model, histogram_parameters) + + if self._log_batch_size_period: + cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group]) + cumulative_batch_size += cur_batch + if (batches_this_epoch - 1) % self._log_batch_size_period == 0: + average = cumulative_batch_size / batches_this_epoch + logger.info(f"current batch size: {cur_batch} mean batch size: {average}") + self._tensorboard.add_train_scalar("current_batch_size", cur_batch) + self._tensorboard.add_train_scalar("mean_batch_size", average) + + # Save model if needed. + if self._model_save_interval is not None and ( + time.time() - last_save_time > self._model_save_interval + ): + last_save_time = time.time() + self._save_checkpoint( + "{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time))) + ) + + metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch, reset=True) + metrics["cpu_memory_MB"] = peak_cpu_usage + for (gpu_num, memory) in gpu_usage: + metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory + return metrics + + def _validation_loss(self) -> Tuple[float, int]: + """ + Computes the validation loss. Returns it and the number of batches. + """ + logger.info("Validating") + + self.model.eval() + + # Replace parameter values with the shadow values from the moving averages. 
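Editor's note: to make the shadow-parameter swap described in the comment above (and in the class docstring) concrete, here is a minimal exponential moving average of model parameters. This is an illustration, not AllenNLP's MovingAverage class; `decay` is an assumed hyperparameter.

import torch

class SimpleEMA:
    """Keeps shadow (exponentially averaged) copies of parameters and can swap them in for evaluation."""

    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {n: p.detach().clone() for n, p in model.named_parameters()}
        self._backup = {}

    @torch.no_grad()
    def apply(self):
        """Update the shadow values, e.g. after each optimizer step."""
        for n, p in self.model.named_parameters():
            self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)

    @torch.no_grad()
    def assign_average_value(self):
        """Back up the live parameters and copy the shadow values into the model (before validation)."""
        self._backup = {n: p.detach().clone() for n, p in self.model.named_parameters()}
        for n, p in self.model.named_parameters():
            p.copy_(self.shadow[n])

    @torch.no_grad()
    def restore(self):
        """Put the live (training) parameters back after validation."""
        for n, p in self.model.named_parameters():
            p.copy_(self._backup[n])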
+ if self._moving_average is not None: + self._moving_average.assign_average_value() + + if self._validation_iterator is not None: + val_iterator = self._validation_iterator + else: + val_iterator = self.iterator + + num_gpus = len(self._cuda_devices) + + raw_val_generator = val_iterator(self._validation_data, num_epochs=1, shuffle=False) + val_generator = lazy_groups_of(raw_val_generator, num_gpus) + num_validation_batches = math.ceil( + val_iterator.get_num_batches(self._validation_data) / num_gpus + ) + val_generator_tqdm = Tqdm.tqdm(val_generator, total=num_validation_batches) + batches_this_epoch = 0 + val_loss = 0 + for batch_group in val_generator_tqdm: + + loss = self.batch_loss(batch_group, for_training=False) + if loss is not None: + # You shouldn't necessarily have to compute a loss for validation, so we allow for + # `loss` to be None. We need to be careful, though - `batches_this_epoch` is + # currently only used as the divisor for the loss function, so we can safely only + # count those batches for which we actually have a loss. If this variable ever + # gets used for something else, we might need to change things around a bit. + batches_this_epoch += 1 + val_loss += loss.detach().cpu().numpy() + + # Update the description with the latest metrics + val_metrics = training_util.get_metrics(self.model, val_loss, batches_this_epoch) + description = training_util.description_from_metrics(val_metrics) + val_generator_tqdm.set_description(description, refresh=False) + + # Now restore the original parameter values. + if self._moving_average is not None: + self._moving_average.restore() + + return val_loss, batches_this_epoch + + def train(self) -> Dict[str, Any]: + """ + Trains the supplied model with the supplied parameters. + """ + try: + epoch_counter = self._restore_checkpoint() + except RuntimeError: + traceback.print_exc() + raise ConfigurationError( + "Could not recover training from the checkpoint. Did you mean to output to " + "a different serialization directory or delete the existing serialization " + "directory?" 
+ ) + + training_util.enable_gradient_clipping(self.model, self._grad_clipping) + + logger.info("Beginning training.") + + train_metrics: Dict[str, float] = {} + val_metrics: Dict[str, float] = {} + this_epoch_val_metric: float = None + metrics: Dict[str, Any] = {} + epochs_trained = 0 + training_start_time = time.time() + + if self.cold_step_count > 0: + base_lr = self.optimizer.param_groups[0]['lr'] + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.cold_lr + self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=True) + + metrics["best_epoch"] = self._metric_tracker.best_epoch + for key, value in self._metric_tracker.best_epoch_metrics.items(): + metrics["best_validation_" + key] = value + + for epoch in range(epoch_counter, self._num_epochs): + if epoch == self.cold_step_count and epoch != 0: + for param_group in self.optimizer.param_groups: + param_group['lr'] = base_lr + self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=False) + + epoch_start_time = time.time() + train_metrics = self._train_epoch(epoch) + + # get peak of memory usage + if "cpu_memory_MB" in train_metrics: + metrics["peak_cpu_memory_MB"] = max( + metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"] + ) + for key, value in train_metrics.items(): + if key.startswith("gpu_"): + metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) + + # clear cache before validation + torch.cuda.empty_cache() + if self._validation_data is not None: + with torch.no_grad(): + # We have a validation set, so compute all the metrics on it. + val_loss, num_batches = self._validation_loss() + val_metrics = training_util.get_metrics( + self.model, val_loss, num_batches, reset=True + ) + + # Check validation metric for early stopping + this_epoch_val_metric = val_metrics[self._validation_metric] + self._metric_tracker.add_metric(this_epoch_val_metric) + + if self._metric_tracker.should_stop_early(): + logger.info("Ran out of patience. Stopping training.") + break + + self._tensorboard.log_metrics( + train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1 + ) # +1 because tensorboard doesn't like 0 + + # Create overall metrics dict + training_elapsed_time = time.time() - training_start_time + metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time)) + metrics["training_start_epoch"] = epoch_counter + metrics["training_epochs"] = epochs_trained + metrics["epoch"] = epoch + + for key, value in train_metrics.items(): + metrics["training_" + key] = value + for key, value in val_metrics.items(): + metrics["validation_" + key] = value + + # if self.cold_step_count <= epoch: + self.scheduler.step(metrics['validation_loss']) + + if self._metric_tracker.is_best_so_far(): + # Update all the best_ metrics. + # (Otherwise they just stay the same as they were.) + metrics["best_epoch"] = epoch + for key, value in val_metrics.items(): + metrics["best_validation_" + key] = value + + self._metric_tracker.best_epoch_metrics = val_metrics + + if self._serialization_dir: + dump_metrics( + os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics + ) + + # The Scheduler API is agnostic to whether your schedule requires a validation metric - + # if it doesn't, the validation metric passed here is ignored. 
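Editor's note: the `self.scheduler.step(metrics['validation_loss'])` call above is the usual ReduceLROnPlateau pattern that the class docstring mentions: the learning rate is lowered only when the monitored validation metric stops improving. A minimal, self-contained sketch with a toy model and a fabricated validation loss:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

for epoch in range(10):
    # ... run one training epoch here ...
    val_loss = 1.0 / (epoch + 1)      # stand-in for a real validation loss
    scheduler.step(val_loss)          # LR is reduced only when val_loss plateaus
    current_lr = optimizer.param_groups[0]["lr"]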
+ if self._learning_rate_scheduler: + self._learning_rate_scheduler.step(this_epoch_val_metric, epoch) + if self._momentum_scheduler: + self._momentum_scheduler.step(this_epoch_val_metric, epoch) + + self._save_checkpoint(epoch) + + epoch_elapsed_time = time.time() - epoch_start_time + logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time)) + + if epoch < self._num_epochs - 1: + training_elapsed_time = time.time() - training_start_time + estimated_time_remaining = training_elapsed_time * ( + (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1 + ) + formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining))) + logger.info("Estimated training time remaining: %s", formatted_time) + + epochs_trained += 1 + + # make sure pending events are flushed to disk and files are closed properly + # self._tensorboard.close() + + # Load the best model state before returning + best_model_state = self._checkpointer.best_model_state() + if best_model_state: + self.model.load_state_dict(best_model_state) + + return metrics + + def _save_checkpoint(self, epoch: Union[int, str]) -> None: + """ + Saves a checkpoint of the model to self._serialization_dir. + Is a no-op if self._serialization_dir is None. + + Parameters + ---------- + epoch : Union[int, str], required. + The epoch of training. If the checkpoint is saved in the middle + of an epoch, the parameter is a string with the epoch and timestamp. + """ + # If moving averages are used for parameters, we save + # the moving average values into checkpoint, instead of the current values. + if self._moving_average is not None: + self._moving_average.assign_average_value() + + # These are the training states we need to persist. + training_states = { + "metric_tracker": self._metric_tracker.state_dict(), + "optimizer": self.optimizer.state_dict(), + "batch_num_total": self._batch_num_total, + } + + # If we have a learning rate or momentum scheduler, we should persist them too. + if self._learning_rate_scheduler is not None: + training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict() + if self._momentum_scheduler is not None: + training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict() + + self._checkpointer.save_checkpoint( + model_state=self.model.state_dict(), + epoch=epoch, + training_states=training_states, + is_best_so_far=self._metric_tracker.is_best_so_far(), + ) + + # Restore the original values for parameters so that training will not be affected. + if self._moving_average is not None: + self._moving_average.restore() + + def _restore_checkpoint(self) -> int: + """ + Restores the model and training state from the last saved checkpoint. + This includes an epoch count and optimizer state, which is serialized separately + from model parameters. This function should only be used to continue training - + if you wish to load a model for inference/load parts of a model into a new + computation graph, you should use the native Pytorch functions: + `` model.load_state_dict(torch.load("/path/to/model/weights.th"))`` + + If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights, + this function will do nothing and return 0. + + Returns + ------- + epoch: int + The epoch at which to resume training, which should be one after the epoch + in the saved training state. 
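Editor's note: for readers outside AllenNLP, the resume contract described here (persist model weights, optimizer state and a running counter; resume one epoch after the saved one) can be sketched with plain torch.save/torch.load. The helper names below are hypothetical; the real Checkpointer additionally tracks the best model and a serialization schedule.

import torch

def save_checkpoint(path, model, optimizer, epoch, batch_num_total):
    torch.save(
        {
            "model_state": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "batch_num_total": batch_num_total,
        },
        path,
    )

def restore_checkpoint(path, model, optimizer):
    """Returns (epoch to resume from, batch_num_total), one past the saved epoch."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model_state"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"] + 1, state.get("batch_num_total", 0)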
+ """ + model_state, training_state = self._checkpointer.restore_checkpoint() + + if not training_state: + # No checkpoint to restore, start at 0 + return 0 + + self.model.load_state_dict(model_state) + self.optimizer.load_state_dict(training_state["optimizer"]) + if self._learning_rate_scheduler is not None \ + and "learning_rate_scheduler" in training_state: + self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"]) + if self._momentum_scheduler is not None and "momentum_scheduler" in training_state: + self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"]) + training_util.move_optimizer_to_cuda(self.optimizer) + + # Currently the ``training_state`` contains a serialized ``MetricTracker``. + if "metric_tracker" in training_state: + self._metric_tracker.load_state_dict(training_state["metric_tracker"]) + # It used to be the case that we tracked ``val_metric_per_epoch``. + elif "val_metric_per_epoch" in training_state: + self._metric_tracker.clear() + self._metric_tracker.add_metrics(training_state["val_metric_per_epoch"]) + # And before that we didn't track anything. + else: + self._metric_tracker.clear() + + if isinstance(training_state["epoch"], int): + epoch_to_return = training_state["epoch"] + 1 + else: + epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1 + + # For older checkpoints with batch_num_total missing, default to old behavior where + # it is unchanged. + batch_num_total = training_state.get("batch_num_total") + if batch_num_total is not None: + self._batch_num_total = batch_num_total + + return epoch_to_return + + # Requires custom from_params. + @classmethod + def from_params( # type: ignore + cls, + model: Model, + serialization_dir: str, + iterator: DataIterator, + train_data: Iterable[Instance], + validation_data: Optional[Iterable[Instance]], + params: Params, + validation_iterator: DataIterator = None, + ) -> "Trainer": + + patience = params.pop_int("patience", None) + validation_metric = params.pop("validation_metric", "-loss") + shuffle = params.pop_bool("shuffle", True) + num_epochs = params.pop_int("num_epochs", 20) + cuda_device = parse_cuda_device(params.pop("cuda_device", -1)) + grad_norm = params.pop_float("grad_norm", None) + grad_clipping = params.pop_float("grad_clipping", None) + lr_scheduler_params = params.pop("learning_rate_scheduler", None) + momentum_scheduler_params = params.pop("momentum_scheduler", None) + + if isinstance(cuda_device, list): + model_device = cuda_device[0] + else: + model_device = cuda_device + if model_device >= 0: + # Moving model to GPU here so that the optimizer state gets constructed on + # the right device. 
+ model = model.cuda(model_device) + + parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad] + optimizer = Optimizer.from_params(parameters, params.pop("optimizer")) + if "moving_average" in params: + moving_average = MovingAverage.from_params( + params.pop("moving_average"), parameters=parameters + ) + else: + moving_average = None + + if lr_scheduler_params: + lr_scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params) + else: + lr_scheduler = None + if momentum_scheduler_params: + momentum_scheduler = MomentumScheduler.from_params(optimizer, momentum_scheduler_params) + else: + momentum_scheduler = None + + if "checkpointer" in params: + if "keep_serialized_model_every_num_seconds" in params \ + or "num_serialized_models_to_keep" in params: + raise ConfigurationError( + "Checkpointer may be initialized either from the 'checkpointer' key or from the " + "keys 'num_serialized_models_to_keep' and 'keep_serialized_model_every_num_seconds'" + " but the passed config uses both methods." + ) + checkpointer = Checkpointer.from_params(params.pop("checkpointer")) + else: + num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20) + keep_serialized_model_every_num_seconds = params.pop_int( + "keep_serialized_model_every_num_seconds", None + ) + checkpointer = Checkpointer( + serialization_dir=serialization_dir, + num_serialized_models_to_keep=num_serialized_models_to_keep, + keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds, + ) + model_save_interval = params.pop_float("model_save_interval", None) + summary_interval = params.pop_int("summary_interval", 100) + histogram_interval = params.pop_int("histogram_interval", None) + should_log_parameter_statistics = params.pop_bool("should_log_parameter_statistics", True) + should_log_learning_rate = params.pop_bool("should_log_learning_rate", False) + log_batch_size_period = params.pop_int("log_batch_size_period", None) + + params.assert_empty(cls.__name__) + return cls( + model, + optimizer, + iterator, + train_data, + validation_data, + patience=patience, + validation_metric=validation_metric, + validation_iterator=validation_iterator, + shuffle=shuffle, + num_epochs=num_epochs, + serialization_dir=serialization_dir, + cuda_device=cuda_device, + grad_norm=grad_norm, + grad_clipping=grad_clipping, + learning_rate_scheduler=lr_scheduler, + momentum_scheduler=momentum_scheduler, + checkpointer=checkpointer, + model_save_interval=model_save_interval, + summary_interval=summary_interval, + histogram_interval=histogram_interval, + should_log_parameter_statistics=should_log_parameter_statistics, + should_log_learning_rate=should_log_learning_rate, + log_batch_size_period=log_batch_size_period, + moving_average=moving_average, + ) diff --git a/nlptoolkit/gec/models/gector/utils/__init__.py b/nlptoolkit/gec/models/gector/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nlptoolkit/gec/models/gector/utils/helpers.py b/nlptoolkit/gec/models/gector/utils/helpers.py new file mode 100644 index 0000000..05690cf --- /dev/null +++ b/nlptoolkit/gec/models/gector/utils/helpers.py @@ -0,0 +1,202 @@ +import os +from pathlib import Path + + +VOCAB_DIR = './data/gec/gector'#Path(__file__).resolve().parent.parent / "data" +PAD = "@@PADDING@@" +UNK = "@@UNKNOWN@@" +START_TOKEN = "$START" +SEQ_DELIMETERS = {"tokens": " ", + "labels": "SEPL|||SEPR", + "operations": "SEPL__SEPR"} + + +def get_verb_form_dicts(): + path_to_dict = os.path.join(VOCAB_DIR, 
"verb-form-vocab.txt") + encode, decode = {}, {} + with open(path_to_dict, encoding="utf-8") as f: + for line in f: + words, tags = line.split(":") + word1, word2 = words.split("_") + tag1, tag2 = tags.split("_") + decode_key = f"{word1}_{tag1}_{tag2.strip()}" + if decode_key not in decode: + encode[words] = tags + decode[decode_key] = word2 + return encode, decode + + +ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts() + + +def get_target_sent_by_edits(source_tokens, edits): + target_tokens = source_tokens[:] + shift_idx = 0 + for edit in edits: + start, end, label, _ = edit + target_pos = start + shift_idx + source_token = target_tokens[target_pos] if target_pos >= 0 else '' + if label == "": + del target_tokens[target_pos] + shift_idx -= 1 + elif start == end: + word = label.replace("$APPEND_", "") + target_tokens[target_pos: target_pos] = [word] + shift_idx += 1 + elif label.startswith("$TRANSFORM_"): + word = apply_reverse_transformation(source_token, label) + if word is None: + word = source_token + target_tokens[target_pos] = word + elif start == end - 1: + word = label.replace("$REPLACE_", "") + target_tokens[target_pos] = word + elif label.startswith("$MERGE_"): + target_tokens[target_pos + 1: target_pos + 1] = [label] + shift_idx += 1 + + return replace_merge_transforms(target_tokens) + + +def replace_merge_transforms(tokens): + if all(not x.startswith("$MERGE_") for x in tokens): + return tokens + + target_line = " ".join(tokens) + target_line = target_line.replace(" $MERGE_HYPHEN ", "-") + target_line = target_line.replace(" $MERGE_SPACE ", "") + return target_line.split() + + +def convert_using_case(token, smart_action): + if not smart_action.startswith("$TRANSFORM_CASE_"): + return token + if smart_action.endswith("LOWER"): + return token.lower() + elif smart_action.endswith("UPPER"): + return token.upper() + elif smart_action.endswith("CAPITAL"): + return token.capitalize() + elif smart_action.endswith("CAPITAL_1"): + return token[0] + token[1:].capitalize() + elif smart_action.endswith("UPPER_-1"): + return token[:-1].upper() + token[-1] + else: + return token + + +def convert_using_verb(token, smart_action): + key_word = "$TRANSFORM_VERB_" + if not smart_action.startswith(key_word): + raise Exception(f"Unknown action type {smart_action}") + encoding_part = f"{token}_{smart_action[len(key_word):]}" + decoded_target_word = decode_verb_form(encoding_part) + return decoded_target_word + + +def convert_using_split(token, smart_action): + key_word = "$TRANSFORM_SPLIT" + if not smart_action.startswith(key_word): + raise Exception(f"Unknown action type {smart_action}") + target_words = token.split("-") + return " ".join(target_words) + + +def convert_using_plural(token, smart_action): + if smart_action.endswith("PLURAL"): + return token + "s" + elif smart_action.endswith("SINGULAR"): + return token[:-1] + else: + raise Exception(f"Unknown action type {smart_action}") + + +def apply_reverse_transformation(source_token, transform): + if transform.startswith("$TRANSFORM"): + # deal with equal + if transform == "$KEEP": + return source_token + # deal with case + if transform.startswith("$TRANSFORM_CASE"): + return convert_using_case(source_token, transform) + # deal with verb + if transform.startswith("$TRANSFORM_VERB"): + return convert_using_verb(source_token, transform) + # deal with split + if transform.startswith("$TRANSFORM_SPLIT"): + return convert_using_split(source_token, transform) + # deal with single/plural + if transform.startswith("$TRANSFORM_AGREEMENT"): + 
return convert_using_plural(source_token, transform) + # raise exception if not find correct type + raise Exception(f"Unknown action type {transform}") + else: + return source_token + + +def read_parallel_lines(fn1, fn2): + lines1 = read_lines(fn1, skip_strip=True) + lines2 = read_lines(fn2, skip_strip=True) + assert len(lines1) == len(lines2) + out_lines1, out_lines2 = [], [] + for line1, line2 in zip(lines1, lines2): + if not line1.strip() or not line2.strip(): + continue + else: + out_lines1.append(line1) + out_lines2.append(line2) + return out_lines1, out_lines2 + + +def read_lines(fn, skip_strip=False): + if not os.path.exists(fn): + return [] + with open(fn, 'r', encoding='utf-8') as f: + lines = f.readlines() + return [s.strip() for s in lines if s.strip() or skip_strip] + + +def write_lines(fn, lines, mode='w'): + if mode == 'w' and os.path.exists(fn): + os.remove(fn) + with open(fn, encoding='utf-8', mode=mode) as f: + f.writelines(['%s\n' % s for s in lines]) + + +def decode_verb_form(original): + return DECODE_VERB_DICT.get(original) + + +def encode_verb_form(original_word, corrected_word): + decoding_request = original_word + "_" + corrected_word + decoding_response = ENCODE_VERB_DICT.get(decoding_request, "").strip() + if original_word and decoding_response: + answer = decoding_response + else: + answer = None + return answer + + +def get_weights_name(transformer_name, lowercase): + if transformer_name == 'bert' and lowercase: + return 'bert-base-uncased' + if transformer_name == 'bert' and not lowercase: + return 'bert-base-cased' + if transformer_name == 'distilbert': + if not lowercase: + print('Warning! This model was trained only on uncased sentences.') + return 'distilbert-base-uncased' + if transformer_name == 'albert': + if not lowercase: + print('Warning! This model was trained only on uncased sentences.') + return 'albert-base-v1' + if lowercase: + print('Warning! This model was trained only on cased sentences.') + if transformer_name == 'roberta': + return 'roberta-base' + if transformer_name == 'gpt2': + return 'gpt2' + if transformer_name == 'transformerxl': + return 'transfo-xl-wt103' + if transformer_name == 'xlnet': + return 'xlnet-base-cased' diff --git a/nlptoolkit/gec/models/gector/utils/prepare_clc_fce_data.py b/nlptoolkit/gec/models/gector/utils/prepare_clc_fce_data.py new file mode 100644 index 0000000..d7a6286 --- /dev/null +++ b/nlptoolkit/gec/models/gector/utils/prepare_clc_fce_data.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +""" +Convert CLC-FCE dataset (The Cambridge Learner Corpus) to the parallel sentences format. 
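Editor's note: before moving on to the data-preparation scripts, here is a quick usage sketch of `get_target_sent_by_edits` from helpers.py above. The edit tuples follow the `(start, end, label, extra)` convention the function unpacks; the concrete values are made up for the example. Importing helpers assumes the verb-form vocabulary file referenced at the top of that module is present, since it is loaded at import time.

from nlptoolkit.gec.models.gector.utils.helpers import get_target_sent_by_edits

source_tokens = ["the", "cat", "sit", "on", "mat"]
edits = [
    (2, 3, "$REPLACE_sits", None),   # replace the span [2, 3) with "sits"
    (4, 4, "$APPEND_the", None),     # insert "the" before position 4
]
print(get_target_sent_by_edits(source_tokens, edits))
# ['the', 'cat', 'sits', 'on', 'the', 'mat']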
+""" + +import argparse +import glob +import os +import re +from xml.etree import cElementTree + +from nltk.tokenize import sent_tokenize, word_tokenize +from tqdm import tqdm + + +def annotate_fce_doc(xml): + """Takes a FCE xml document and yields sentences with annotated errors.""" + result = [] + doc = cElementTree.fromstring(xml) + paragraphs = doc.findall('head/text/*/coded_answer/p') + for p in paragraphs: + text = _get_formatted_text(p) + result.append(text) + + return '\n'.join(result) + + +def _get_formatted_text(elem, ignore_tags=None): + text = elem.text or '' + ignore_tags = [tag.upper() for tag in (ignore_tags or [])] + correct = None + mistake = None + + for child in elem.getchildren(): + tag = child.tag.upper() + if tag == 'NS': + text += _get_formatted_text(child) + + elif tag == 'UNKNOWN': + text += ' UNKNOWN ' + + elif tag == 'C': + assert correct is None + correct = _get_formatted_text(child) + + elif tag == 'I': + assert mistake is None + mistake = _get_formatted_text(child) + + elif tag in ignore_tags: + pass + + else: + raise ValueError(f"Unknown tag `{child.tag}`", text) + + if correct or mistake: + correct = correct or '' + mistake = mistake or '' + if '=>' not in mistake: + text += f'{{{mistake}=>{correct}}}' + else: + text += mistake + + text += elem.tail or '' + return text + + +def convert_fce(fce_dir): + """Processes the whole FCE directory. Yields annotated documents (strings).""" + + # Ensure we got the valid dataset path + if not os.path.isdir(fce_dir): + raise UserWarning( + f"{fce_dir} is not a valid path") + + dataset_dir = os.path.join(fce_dir, 'dataset') + if not os.path.exists(dataset_dir): + raise UserWarning( + f"{fce_dir} doesn't point to a dataset's root dir") + + # Convert XML docs to the corpora format + filenames = sorted(glob.glob(os.path.join(dataset_dir, '*/*.xml'))) + + docs = [] + for filename in filenames: + with open(filename, encoding='utf-8') as f: + doc = annotate_fce_doc(f.read()) + docs.append(doc) + return docs + + +def main(): + fce = convert_fce(args.fce_dataset_path) + with open(args.output + "/fce-original.txt", 'w', encoding='utf-8') as out_original, \ + open(args.output + "/fce-applied.txt", 'w', encoding='utf-8') as out_applied: + for doc in tqdm(fce, unit='doc'): + sents = re.split(r"\n +\n", doc) + for sent in sents: + tokenized_sents = sent_tokenize(sent) + for i in range(len(tokenized_sents)): + if re.search(r"[{>][.?!]$", tokenized_sents[i]): + tokenized_sents[i + 1] = tokenized_sents[i] + " " + tokenized_sents[i + 1] + tokenized_sents[i] = "" + regexp = r'{([^{}]*?)=>([^{}]*?)}' + original = re.sub(regexp, r"\1", tokenized_sents[i]) + applied = re.sub(regexp, r"\2", tokenized_sents[i]) + # filter out nested alerts + if original != "" and applied != "" and not re.search(r"[{}=]", original) \ + and not re.search(r"[{}=]", applied): + out_original.write(" ".join(word_tokenize(original)) + "\n") + out_applied.write(" ".join(word_tokenize(applied)) + "\n") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description=( + "Convert CLC-FCE dataset to the parallel sentences format.")) + parser.add_argument('fce_dataset_path', + help='Path to the folder with the FCE dataset') + parser.add_argument('--output', + help='Path to the output folder') + args = parser.parse_args() + + main() diff --git a/nlptoolkit/gec/models/gector/utils/preprocess_data.py b/nlptoolkit/gec/models/gector/utils/preprocess_data.py new file mode 100644 index 0000000..775357c --- /dev/null +++ 
b/nlptoolkit/gec/models/gector/utils/preprocess_data.py @@ -0,0 +1,488 @@ +import argparse +import os +from difflib import SequenceMatcher + +import Levenshtein +import numpy as np +from tqdm import tqdm + +from helpers import write_lines, read_parallel_lines, encode_verb_form, \ + apply_reverse_transformation, SEQ_DELIMETERS, START_TOKEN + + +def perfect_align(t, T, insertions_allowed=0, + cost_function=Levenshtein.distance): + # dp[i, j, k] is a minimal cost of matching first `i` tokens of `t` with + # first `j` tokens of `T`, after making `k` insertions after last match of + # token from `t`. In other words t[:i] aligned with T[:j]. + + # Initialize with INFINITY (unknown) + shape = (len(t) + 1, len(T) + 1, insertions_allowed + 1) + dp = np.ones(shape, dtype=int) * int(1e9) + come_from = np.ones(shape, dtype=int) * int(1e9) + come_from_ins = np.ones(shape, dtype=int) * int(1e9) + + dp[0, 0, 0] = 0 # The only known starting point. Nothing matched to nothing. + for i in range(len(t) + 1): # Go inclusive + for j in range(len(T) + 1): # Go inclusive + for q in range(insertions_allowed + 1): # Go inclusive + if i < len(t): + # Given matched sequence of t[:i] and T[:j], match token + # t[i] with following tokens T[j:k]. + for k in range(j, len(T) + 1): + transform = \ + apply_transformation(t[i], ' '.join(T[j:k])) + if transform: + cost = 0 + else: + cost = cost_function(t[i], ' '.join(T[j:k])) + current = dp[i, j, q] + cost + if dp[i + 1, k, 0] > current: + dp[i + 1, k, 0] = current + come_from[i + 1, k, 0] = j + come_from_ins[i + 1, k, 0] = q + if q < insertions_allowed: + # Given matched sequence of t[:i] and T[:j], create + # insertion with following tokens T[j:k]. + for k in range(j, len(T) + 1): + cost = len(' '.join(T[j:k])) + current = dp[i, j, q] + cost + if dp[i, k, q + 1] > current: + dp[i, k, q + 1] = current + come_from[i, k, q + 1] = j + come_from_ins[i, k, q + 1] = q + + # Solution is in the dp[len(t), len(T), *]. Backtracking from there. 
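Editor's note: `perfect_align` above fine-grains the alignment inside a replaced span (the backtracking continues right below); at the top level, `align_sequences` later in this file first carves the sentence pair into equal/replace/insert/delete chunks with `difflib.SequenceMatcher` and only falls back to `perfect_align` for replaced spans. A quick, standard-library-only illustration of those opcodes on a toy pair:

from difflib import SequenceMatcher

source = "the cat sit on mat".split()
target = "the cat sits on the mat".split()

matcher = SequenceMatcher(None, source, target)
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
    print(tag, source[i1:i2], "->", target[j1:j2])
# equal   ['the', 'cat'] -> ['the', 'cat']
# replace ['sit'] -> ['sits']
# equal   ['on'] -> ['on']
# insert  [] -> ['the']
# equal   ['mat'] -> ['mat']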
+ alignment = [] + i = len(t) + j = len(T) + q = dp[i, j, :].argmin() + while i > 0 or q > 0: + is_insert = (come_from_ins[i, j, q] != q) and (q != 0) + j, k, q = come_from[i, j, q], j, come_from_ins[i, j, q] + if not is_insert: + i -= 1 + + if is_insert: + alignment.append(['INSERT', T[j:k], (i, i)]) + else: + alignment.append([f'REPLACE_{t[i]}', T[j:k], (i, i + 1)]) + + assert j == 0 + + return dp[len(t), len(T)].min(), list(reversed(alignment)) + + +def _split(token): + if not token: + return [] + parts = token.split() + return parts or [token] + + +def apply_merge_transformation(source_tokens, target_words, shift_idx): + edits = [] + if len(source_tokens) > 1 and len(target_words) == 1: + # check merge + transform = check_merge(source_tokens, target_words) + if transform: + for i in range(len(source_tokens) - 1): + edits.append([(shift_idx + i, shift_idx + i + 1), transform]) + return edits + + if len(source_tokens) == len(target_words) == 2: + # check swap + transform = check_swap(source_tokens, target_words) + if transform: + edits.append([(shift_idx, shift_idx + 1), transform]) + return edits + + +def is_sent_ok(sent, delimeters=SEQ_DELIMETERS): + for del_val in delimeters.values(): + if del_val in sent and del_val != " ": + return False + return True + + +def check_casetype(source_token, target_token): + if source_token.lower() != target_token.lower(): + return None + if source_token.lower() == target_token: + return "$TRANSFORM_CASE_LOWER" + elif source_token.capitalize() == target_token: + return "$TRANSFORM_CASE_CAPITAL" + elif source_token.upper() == target_token: + return "$TRANSFORM_CASE_UPPER" + elif source_token[1:].capitalize() == target_token[1:] and source_token[0] == target_token[0]: + return "$TRANSFORM_CASE_CAPITAL_1" + elif source_token[:-1].upper() == target_token[:-1] and source_token[-1] == target_token[-1]: + return "$TRANSFORM_CASE_UPPER_-1" + else: + return None + + +def check_equal(source_token, target_token): + if source_token == target_token: + return "$KEEP" + else: + return None + + +def check_split(source_token, target_tokens): + if source_token.split("-") == target_tokens: + return "$TRANSFORM_SPLIT_HYPHEN" + else: + return None + + +def check_merge(source_tokens, target_tokens): + if "".join(source_tokens) == "".join(target_tokens): + return "$MERGE_SPACE" + elif "-".join(source_tokens) == "-".join(target_tokens): + return "$MERGE_HYPHEN" + else: + return None + + +def check_swap(source_tokens, target_tokens): + if source_tokens == [x for x in reversed(target_tokens)]: + return "$MERGE_SWAP" + else: + return None + + +def check_plural(source_token, target_token): + if source_token.endswith("s") and source_token[:-1] == target_token: + return "$TRANSFORM_AGREEMENT_SINGULAR" + elif target_token.endswith("s") and source_token == target_token[:-1]: + return "$TRANSFORM_AGREEMENT_PLURAL" + else: + return None + + +def check_verb(source_token, target_token): + encoding = encode_verb_form(source_token, target_token) + if encoding: + return f"$TRANSFORM_VERB_{encoding}" + else: + return None + + +def apply_transformation(source_token, target_token): + target_tokens = target_token.split() + if len(target_tokens) > 1: + # check split + transform = check_split(source_token, target_tokens) + if transform: + return transform + checks = [check_equal, check_casetype, check_verb, check_plural] + for check in checks: + transform = check(source_token, target_token) + if transform: + return transform + return None + + +def align_sequences(source_sent, target_sent): + # 
check if sent is OK + if not is_sent_ok(source_sent) or not is_sent_ok(target_sent): + return None + source_tokens = source_sent.split() + target_tokens = target_sent.split() + matcher = SequenceMatcher(None, source_tokens, target_tokens) + diffs = list(matcher.get_opcodes()) + all_edits = [] + for diff in diffs: + tag, i1, i2, j1, j2 = diff + source_part = _split(" ".join(source_tokens[i1:i2])) + target_part = _split(" ".join(target_tokens[j1:j2])) + if tag == 'equal': + continue + elif tag == 'delete': + # delete all words separatly + for j in range(i2 - i1): + edit = [(i1 + j, i1 + j + 1), '$DELETE'] + all_edits.append(edit) + elif tag == 'insert': + # append to the previous word + for target_token in target_part: + edit = ((i1 - 1, i1), f"$APPEND_{target_token}") + all_edits.append(edit) + else: + # check merge first of all + edits = apply_merge_transformation(source_part, target_part, + shift_idx=i1) + if edits: + all_edits.extend(edits) + continue + + # normalize alignments if need (make them singleton) + _, alignments = perfect_align(source_part, target_part, + insertions_allowed=0) + for alignment in alignments: + new_shift = alignment[2][0] + edits = convert_alignments_into_edits(alignment, + shift_idx=i1 + new_shift) + all_edits.extend(edits) + + # get labels + labels = convert_edits_into_labels(source_tokens, all_edits) + # match tags to source tokens + sent_with_tags = add_labels_to_the_tokens(source_tokens, labels) + return sent_with_tags + + +def convert_edits_into_labels(source_tokens, all_edits): + # make sure that edits are flat + flat_edits = [] + for edit in all_edits: + (start, end), edit_operations = edit + if isinstance(edit_operations, list): + for operation in edit_operations: + new_edit = [(start, end), operation] + flat_edits.append(new_edit) + elif isinstance(edit_operations, str): + flat_edits.append(edit) + else: + raise Exception("Unknown operation type") + all_edits = flat_edits[:] + labels = [] + total_labels = len(source_tokens) + 1 + if not all_edits: + labels = [["$KEEP"] for x in range(total_labels)] + else: + for i in range(total_labels): + edit_operations = [x[1] for x in all_edits if x[0][0] == i - 1 + and x[0][1] == i] + if not edit_operations: + labels.append(["$KEEP"]) + else: + labels.append(edit_operations) + return labels + + +def convert_alignments_into_edits(alignment, shift_idx): + edits = [] + action, target_tokens, new_idx = alignment + source_token = action.replace("REPLACE_", "") + + # check if delete + if not target_tokens: + edit = [(shift_idx, 1 + shift_idx), "$DELETE"] + return [edit] + + # check splits + for i in range(1, len(target_tokens)): + target_token = " ".join(target_tokens[:i + 1]) + transform = apply_transformation(source_token, target_token) + if transform: + edit = [(shift_idx, shift_idx + 1), transform] + edits.append(edit) + target_tokens = target_tokens[i + 1:] + for target in target_tokens: + edits.append([(shift_idx, shift_idx + 1), f"$APPEND_{target}"]) + return edits + + transform_costs = [] + transforms = [] + for target_token in target_tokens: + transform = apply_transformation(source_token, target_token) + if transform: + cost = 0 + transforms.append(transform) + else: + cost = Levenshtein.distance(source_token, target_token) + transforms.append(None) + transform_costs.append(cost) + min_cost_idx = transform_costs.index(min(transform_costs)) + # append to the previous word + for i in range(0, min_cost_idx): + target = target_tokens[i] + edit = [(shift_idx - 1, shift_idx), f"$APPEND_{target}"] + 
edits.append(edit) + # replace/transform target word + transform = transforms[min_cost_idx] + target = transform if transform is not None \ + else f"$REPLACE_{target_tokens[min_cost_idx]}" + edit = [(shift_idx, 1 + shift_idx), target] + edits.append(edit) + # append to this word + for i in range(min_cost_idx + 1, len(target_tokens)): + target = target_tokens[i] + edit = [(shift_idx, 1 + shift_idx), f"$APPEND_{target}"] + edits.append(edit) + return edits + + +def add_labels_to_the_tokens(source_tokens, labels, delimeters=SEQ_DELIMETERS): + tokens_with_all_tags = [] + source_tokens_with_start = [START_TOKEN] + source_tokens + for token, label_list in zip(source_tokens_with_start, labels): + all_tags = delimeters['operations'].join(label_list) + comb_record = token + delimeters['labels'] + all_tags + tokens_with_all_tags.append(comb_record) + return delimeters['tokens'].join(tokens_with_all_tags) + + +def convert_data_from_raw_files(source_file, target_file, output_file, chunk_size): + tagged = [] + source_data, target_data = read_parallel_lines(source_file, target_file) + print(f"The size of raw dataset is {len(source_data)}") + cnt_total, cnt_all, cnt_tp = 0, 0, 0 + for source_sent, target_sent in tqdm(zip(source_data, target_data)): + try: + aligned_sent = align_sequences(source_sent, target_sent) + except Exception: + aligned_sent = align_sequences(source_sent, target_sent) + if source_sent != target_sent: + cnt_tp += 1 + alignments = [aligned_sent] + cnt_all += len(alignments) + try: + check_sent = convert_tagged_line(aligned_sent) + except Exception: + # debug mode + aligned_sent = align_sequences(source_sent, target_sent) + check_sent = convert_tagged_line(aligned_sent) + + if "".join(check_sent.split()) != "".join( + target_sent.split()): + # do it again for debugging + aligned_sent = align_sequences(source_sent, target_sent) + check_sent = convert_tagged_line(aligned_sent) + print(f"Incorrect pair: \n{target_sent}\n{check_sent}") + continue + if alignments: + cnt_total += len(alignments) + tagged.extend(alignments) + if len(tagged) > chunk_size: + write_lines(output_file, tagged, mode='a') + tagged = [] + + print(f"Overall extracted {cnt_total}. " + f"Original TP {cnt_tp}." 
+ f" Original TN {cnt_all - cnt_tp}") + if tagged: + write_lines(output_file, tagged, 'a') + + +def convert_labels_into_edits(labels): + all_edits = [] + for i, label_list in enumerate(labels): + if label_list == ["$KEEP"]: + continue + else: + edit = [(i - 1, i), label_list] + all_edits.append(edit) + return all_edits + + +def get_target_sent_by_levels(source_tokens, labels): + relevant_edits = convert_labels_into_edits(labels) + target_tokens = source_tokens[:] + leveled_target_tokens = {} + if not relevant_edits: + target_sentence = " ".join(target_tokens) + return leveled_target_tokens, target_sentence + max_level = max([len(x[1]) for x in relevant_edits]) + for level in range(max_level): + rest_edits = [] + shift_idx = 0 + for edits in relevant_edits: + (start, end), label_list = edits + label = label_list[0] + target_pos = start + shift_idx + source_token = target_tokens[target_pos] if target_pos >= 0 else START_TOKEN + if label == "$DELETE": + del target_tokens[target_pos] + shift_idx -= 1 + elif label.startswith("$APPEND_"): + word = label.replace("$APPEND_", "") + target_tokens[target_pos + 1: target_pos + 1] = [word] + shift_idx += 1 + elif label.startswith("$REPLACE_"): + word = label.replace("$REPLACE_", "") + target_tokens[target_pos] = word + elif label.startswith("$TRANSFORM"): + word = apply_reverse_transformation(source_token, label) + if word is None: + word = source_token + target_tokens[target_pos] = word + elif label.startswith("$MERGE_"): + # apply merge only on last stage + if level == (max_level - 1): + target_tokens[target_pos + 1: target_pos + 1] = [label] + shift_idx += 1 + else: + rest_edit = [(start + shift_idx, end + shift_idx), [label]] + rest_edits.append(rest_edit) + rest_labels = label_list[1:] + if rest_labels: + rest_edit = [(start + shift_idx, end + shift_idx), rest_labels] + rest_edits.append(rest_edit) + + leveled_tokens = target_tokens[:] + # update next step + relevant_edits = rest_edits[:] + if level == (max_level - 1): + leveled_tokens = replace_merge_transforms(leveled_tokens) + leveled_labels = convert_edits_into_labels(leveled_tokens, + relevant_edits) + leveled_target_tokens[level + 1] = {"tokens": leveled_tokens, + "labels": leveled_labels} + + target_sentence = " ".join(leveled_target_tokens[max_level]["tokens"]) + return leveled_target_tokens, target_sentence + + +def replace_merge_transforms(tokens): + if all(not x.startswith("$MERGE_") for x in tokens): + return tokens + target_tokens = tokens[:] + allowed_range = (1, len(tokens) - 1) + for i in range(len(tokens)): + target_token = tokens[i] + if target_token.startswith("$MERGE"): + if target_token.startswith("$MERGE_SWAP") and i in allowed_range: + target_tokens[i - 1] = tokens[i + 1] + target_tokens[i + 1] = tokens[i - 1] + target_tokens[i: i + 1] = [] + target_line = " ".join(target_tokens) + target_line = target_line.replace(" $MERGE_HYPHEN ", "-") + target_line = target_line.replace(" $MERGE_SPACE ", "") + return target_line.split() + + +def convert_tagged_line(line, delimeters=SEQ_DELIMETERS): + label_del = delimeters['labels'] + source_tokens = [x.split(label_del)[0] + for x in line.split(delimeters['tokens'])][1:] + labels = [x.split(label_del)[1].split(delimeters['operations']) + for x in line.split(delimeters['tokens'])] + assert len(source_tokens) + 1 == len(labels) + levels_dict, target_line = get_target_sent_by_levels(source_tokens, labels) + return target_line + + +def main(args): + convert_data_from_raw_files(args.source, args.target, args.output_file, args.chunk_size) + + 
+if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-s', '--source', + help='Path to the source file', + required=True) + parser.add_argument('-t', '--target', + help='Path to the target file', + required=True) + parser.add_argument('-o', '--output_file', + help='Path to the output file', + required=True) + parser.add_argument('--chunk_size', + type=int, + help='Dump each chunk size.', + default=1000000) + args = parser.parse_args() + main(args) diff --git a/nlptoolkit/gec/models/gector/wordpiece_indexer.py b/nlptoolkit/gec/models/gector/wordpiece_indexer.py new file mode 100644 index 0000000..6805574 --- /dev/null +++ b/nlptoolkit/gec/models/gector/wordpiece_indexer.py @@ -0,0 +1,444 @@ +"""Tweaked version of corresponding AllenNLP file""" +import logging +from collections import defaultdict +from typing import Dict, List, Callable + +from allennlp.common.util import pad_sequence_to_length +from allennlp.data.token_indexers.token_indexer import TokenIndexer +from allennlp.data.tokenizers.token import Token +from allennlp.data.vocabulary import Vocabulary +from overrides import overrides +from transformers import AutoTokenizer + +from .utils.helpers import START_TOKEN + +logger = logging.getLogger(__name__) + +# TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer. + +# This is the default list of tokens that should not be lowercased. +_NEVER_LOWERCASE = ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] + + +class WordpieceIndexer(TokenIndexer[int]): + """ + A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings). + If you are using one of the pretrained BERT models, you'll want to use the ``PretrainedBertIndexer`` + subclass rather than this base class. + + Parameters + ---------- + vocab : ``Dict[str, int]`` + The mapping {wordpiece -> id}. Note this is not an AllenNLP ``Vocabulary``. + wordpiece_tokenizer : ``Callable[[str], List[str]]`` + A function that does the actual tokenization. + namespace : str, optional (default: "wordpiece") + The namespace in the AllenNLP ``Vocabulary`` into which the wordpieces + will be loaded. + use_starting_offsets : bool, optional (default: False) + By default, the "offsets" created by the token indexer correspond to the + last wordpiece in each word. If ``use_starting_offsets`` is specified, + they will instead correspond to the first wordpiece in each word. + max_pieces : int, optional (default: 512) + The BERT embedder uses positional embeddings and so has a corresponding + maximum length for its input ids. Any inputs longer than this will + either be truncated (default), or be split apart and batched using a + sliding window. + do_lowercase : ``bool``, optional (default=``False``) + Should we lowercase the provided tokens before getting the indices? + You would need to do this if you are using an -uncased BERT model + but your DatasetReader is not lowercasing tokens (which might be the + case if you're also using other embeddings based on cased tokens). + never_lowercase: ``List[str]``, optional + Tokens that should never be lowercased. Default is + ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']. + start_tokens : ``List[str]``, optional (default=``None``) + These are prepended to the tokens provided to ``tokens_to_indices``. + end_tokens : ``List[str]``, optional (default=``None``) + These are appended to the tokens provided to ``tokens_to_indices``. + separator_token : ``str``, optional (default=``[SEP]``) + This token indicates the segments in the sequence. 
+ truncate_long_sequences : ``bool``, optional (default=``True``) + By default, long sequences will be truncated to the maximum sequence + length. Otherwise, they will be split apart and batched using a + sliding window. + token_min_padding_length : ``int``, optional (default=``0``) + See :class:`TokenIndexer`. + """ + + def __init__(self, + vocab: Dict[str, int], + bpe_ranks: Dict, + byte_encoder: Dict, + wordpiece_tokenizer: Callable[[str], List[str]], + namespace: str = "wordpiece", + use_starting_offsets: bool = False, + max_pieces: int = 512, + max_pieces_per_token: int = 3, + is_test=False, + do_lowercase: bool = False, + never_lowercase: List[str] = None, + start_tokens: List[str] = None, + end_tokens: List[str] = None, + truncate_long_sequences: bool = True, + token_min_padding_length: int = 0) -> None: + super().__init__(token_min_padding_length) + self.vocab = vocab + + # The BERT code itself does a two-step tokenization: + # sentence -> [words], and then word -> [wordpieces] + # In AllenNLP, the first step is implemented as the ``BertBasicWordSplitter``, + # and this token indexer handles the second. + self.wordpiece_tokenizer = wordpiece_tokenizer + self.max_pieces_per_token = max_pieces_per_token + self._namespace = namespace + self._added_to_vocabulary = False + self.max_pieces = max_pieces + self.use_starting_offsets = use_starting_offsets + self._do_lowercase = do_lowercase + self._truncate_long_sequences = truncate_long_sequences + self.max_pieces_per_sentence = 80 + self.is_test = is_test + self.cache = {} + self.bpe_ranks = bpe_ranks + self.byte_encoder = byte_encoder + + if self.is_test: + self.max_pieces_per_token = None + + if never_lowercase is None: + # Use the defaults + self._never_lowercase = set(_NEVER_LOWERCASE) + else: + self._never_lowercase = set(never_lowercase) + + # Convert the start_tokens and end_tokens to wordpiece_ids + self._start_piece_ids = [vocab[wordpiece] + for token in (start_tokens or []) + for wordpiece in wordpiece_tokenizer(token)] + self._end_piece_ids = [vocab[wordpiece] + for token in (end_tokens or []) + for wordpiece in wordpiece_tokenizer(token)] + + @overrides + def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]): + # If we only use pretrained models, we don't need to do anything here. + pass + + def _add_encoding_to_vocabulary(self, vocabulary: Vocabulary) -> None: + # pylint: disable=protected-access + for word, idx in self.vocab.items(): + vocabulary._token_to_index[self._namespace][word] = idx + vocabulary._index_to_token[self._namespace][idx] = word + + def get_pairs(self, word): + """Return set of symbol pairs in a word. + + Word is represented as tuple of symbols (symbols being variable-length strings). 
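Editor's note: the `get_pairs`/`bpe` pair below implements standard byte-pair-encoding merging. Here is a simplified, self-contained version of that merge loop; the `bpe_ranks` table is fabricated for illustration (real ranks come from the pretrained tokenizer's merges file), and the real method additionally caches results and handles edge cases.

def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word (a tuple of symbols)."""
    return {(word[i], word[i + 1]) for i in range(len(word) - 1)}

bpe_ranks = {("u", "g"): 0, ("h", "ug"): 1}   # hypothetical merge table: lower rank merges first

word = tuple("hug")                            # ('h', 'u', 'g')
while True:
    pairs = get_pairs(word)
    candidates = [p for p in pairs if p in bpe_ranks]
    if not candidates:
        break
    first, second = min(candidates, key=bpe_ranks.get)
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == (first, second):
            merged.append(first + second)      # apply the highest-priority merge
            i += 2
        else:
            merged.append(word[i])
            i += 1
    word = tuple(merged)

print(word)   # ('hug',)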
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = self.get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, + float( + 'inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = self.get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def bpe_tokenize(self, text): + """ Tokenize a string.""" + bpe_tokens = [] + for token in text.split(): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + @overrides + def tokens_to_indices(self, + tokens: List[Token], + vocabulary: Vocabulary, + index_name: str) -> Dict[str, List[int]]: + if not self._added_to_vocabulary: + self._add_encoding_to_vocabulary(vocabulary) + self._added_to_vocabulary = True + + # This lowercases tokens if necessary + text = (token.text.lower() + if self._do_lowercase and token.text not in self._never_lowercase + else token.text + for token in tokens) + + # Obtain a nested sequence of wordpieces, each represented by a list of wordpiece ids + token_wordpiece_ids = [] + for token in text: + if self.bpe_ranks != {}: + wps = self.bpe_tokenize(token) + else: + wps = self.wordpiece_tokenizer(token) + limited_wps = [self.vocab[wordpiece] for wordpiece in wps][:self.max_pieces_per_token] + token_wordpiece_ids.append(limited_wps) + + # Flattened list of wordpieces. In the end, the output of the model (e.g., BERT) should + # have a sequence length equal to the length of this list. However, it will first be split into + # chunks of length `self.max_pieces` so that they can be fit through the model. After packing + # and passing through the model, it should be unpacked to represent the wordpieces in this list. + flat_wordpiece_ids = [wordpiece for token in token_wordpiece_ids for wordpiece in token] + + # reduce max_pieces_per_token if piece length of sentence is bigger than max_pieces_per_sentence + # helps to fix CUDA out of memory errors meanwhile increasing batch size + while not self.is_test and len(flat_wordpiece_ids) > \ + self.max_pieces_per_sentence - len(self._start_piece_ids) - len(self._end_piece_ids): + max_pieces = max([len(row) for row in token_wordpiece_ids]) + token_wordpiece_ids = [row[:max_pieces - 1] for row in token_wordpiece_ids] + flat_wordpiece_ids = [wordpiece for token in token_wordpiece_ids for wordpiece in token] + + # The code below will (possibly) pack the wordpiece sequence into multiple sub-sequences by using a sliding + # window `window_length` that overlaps with previous windows according to the `stride`. Suppose we have + # the following sentence: "I went to the store to buy some milk". Then a sliding window of length 4 and + # stride of length 2 will split them up into: + + # "[I went to the] [to the store to] [store to buy some] [buy some milk [PAD]]". 
+
+        # offsets[i] will give us the index into wordpiece_ids
+        # for the wordpiece "corresponding to" the i-th input token.
+        offsets = []
+
+        # If we're using initial offsets, we want to start at offset = len(self._start_piece_ids)
+        # so that the first offset is the index of the first wordpiece of tokens[0].
+        # Otherwise, we want to start at len(self._start_piece_ids) - 1, so that the "previous"
+        # offset is the last wordpiece of "tokens[-1]".
+        offset = len(self._start_piece_ids) if self.use_starting_offsets else len(self._start_piece_ids) - 1
+
+        for token in token_wordpiece_ids:
+            # Truncate the sequence if specified, which depends on where the offsets are
+            next_offset = 1 if self.use_starting_offsets else 0
+            if self._truncate_long_sequences and offset >= window_length + next_offset:
+                break
+
+            # For initial offsets, the current value of ``offset`` is the start of
+            # the current wordpiece, so add it to ``offsets`` and then increment it.
+            if self.use_starting_offsets:
+                offsets.append(offset)
+                offset += len(token)
+            # For final offsets, the current value of ``offset`` is the end of
+            # the previous wordpiece, so increment it and then add it to ``offsets``.
+            else:
+                offset += len(token)
+                offsets.append(offset)
+
+        if len(flat_wordpiece_ids) <= window_length:
+            # If all the wordpieces fit, then we don't need to do anything special
+            wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids)]
+        elif self._truncate_long_sequences:
+            logger.warning("Too many wordpieces, truncating sequence. If you would like a sliding window, set "
+                           "`truncate_long_sequences` to False %s", str([token.text for token in tokens]))
+            wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids[:window_length])]
+        else:
+            # Create a sliding window of wordpieces of length `max_pieces` that advances by `stride` steps and
+            # add start/end wordpieces to each window
+            # TODO: this currently does not respect word boundaries, so words may be cut in half between windows.
+            # However, fixing that would increase complexity, as sequences would need to be padded/unpadded in the middle
+            wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids[i:i + window_length])
+                                 for i in range(0, len(flat_wordpiece_ids), stride)]
+
+            # Check for overlap in the last window. Throw it away if it is redundant.
+            last_window = wordpiece_windows[-1][1:]
+            penultimate_window = wordpiece_windows[-2]
+            if last_window == penultimate_window[-len(last_window):]:
+                wordpiece_windows = wordpiece_windows[:-1]
+
+        # Flatten the wordpiece windows
+        wordpiece_ids = [wordpiece for sequence in wordpiece_windows for wordpiece in sequence]
+
+        # Our mask should correspond to the original tokens,
+        # because calling util.get_text_field_mask on the
+        # "wordpiece_id" tokens will produce the wrong shape.
+ # However, because of the max_pieces constraint, we may + # have truncated the wordpieces; accordingly, we want the mask + # to correspond to the remaining tokens after truncation, which + # is captured by the offsets. + mask = [1 for _ in offsets] + + return {index_name: wordpiece_ids, + f"{index_name}-offsets": offsets, + "mask": mask} + + def _add_start_and_end(self, wordpiece_ids: List[int]) -> List[int]: + return self._start_piece_ids + wordpiece_ids + self._end_piece_ids + + def _extend(self, token_type_ids: List[int]) -> List[int]: + """ + Extend the token type ids by len(start_piece_ids) on the left + and len(end_piece_ids) on the right. + """ + first = token_type_ids[0] + last = token_type_ids[-1] + return ([first for _ in self._start_piece_ids] + + token_type_ids + + [last for _ in self._end_piece_ids]) + + @overrides + def get_padding_token(self) -> int: + return 0 + + @overrides + def get_padding_lengths(self, token: int) -> Dict[str, int]: # pylint: disable=unused-argument + return {} + + @overrides + def pad_token_sequence(self, + tokens: Dict[str, List[int]], + desired_num_tokens: Dict[str, int], + padding_lengths: Dict[str, int]) -> Dict[str, List[int]]: # pylint: disable=unused-argument + return {key: pad_sequence_to_length(val, desired_num_tokens[key]) + for key, val in tokens.items()} + + @overrides + def get_keys(self, index_name: str) -> List[str]: + """ + We need to override this because the indexer generates multiple keys. + """ + # pylint: disable=no-self-use + return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"] + + +class PretrainedBertIndexer(WordpieceIndexer): + # pylint: disable=line-too-long + """ + A ``TokenIndexer`` corresponding to a pretrained BERT model. + + Parameters + ---------- + pretrained_model: ``str`` + Either the name of the pretrained model to use (e.g. 'bert-base-uncased'), + or the path to the .txt file with its vocabulary. + + If the name is a key in the list of pretrained models at + https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py#L33 + the corresponding path will be used; otherwise it will be interpreted as a path or URL. + use_starting_offsets: bool, optional (default: False) + By default, the "offsets" created by the token indexer correspond to the + last wordpiece in each word. If ``use_starting_offsets`` is specified, + they will instead correspond to the first wordpiece in each word. + do_lowercase: ``bool``, optional (default = True) + Whether to lowercase the tokens before converting to wordpiece ids. + never_lowercase: ``List[str]``, optional + Tokens that should never be lowercased. Default is + ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']. + max_pieces: int, optional (default: 512) + The BERT embedder uses positional embeddings and so has a corresponding + maximum length for its input ids. Any inputs longer than this will + either be truncated (default), or be split apart and batched using a + sliding window. + truncate_long_sequences : ``bool``, optional (default=``True``) + By default, long sequences will be truncated to the maximum sequence + length. Otherwise, they will be split apart and batched using a + sliding window. 
+ """ + + def __init__(self, + pretrained_model: str, + use_starting_offsets: bool = False, + do_lowercase: bool = True, + never_lowercase: List[str] = None, + max_pieces: int = 512, + max_pieces_per_token=5, + is_test=False, + truncate_long_sequences: bool = True, + special_tokens_fix: int = 0) -> None: + if pretrained_model.endswith("-cased") and do_lowercase: + logger.warning("Your BERT model appears to be cased, " + "but your indexer is lowercasing tokens.") + elif pretrained_model.endswith("-uncased") and not do_lowercase: + logger.warning("Your BERT model appears to be uncased, " + "but your indexer is not lowercasing tokens.") + + bert_tokenizer = AutoTokenizer.from_pretrained( + pretrained_model, do_lower_case=do_lowercase, do_basic_tokenize=False) + + # to adjust all tokenizers + if hasattr(bert_tokenizer, 'encoder'): + bert_tokenizer.vocab = bert_tokenizer.encoder + if hasattr(bert_tokenizer, 'sp_model'): + bert_tokenizer.vocab = defaultdict(lambda: 1) + for i in range(bert_tokenizer.sp_model.get_piece_size()): + bert_tokenizer.vocab[bert_tokenizer.sp_model.id_to_piece(i)] = i + + if special_tokens_fix: + bert_tokenizer.add_tokens([START_TOKEN]) + bert_tokenizer.vocab[START_TOKEN] = len(bert_tokenizer) - 1 + + if "roberta" in pretrained_model: + bpe_ranks = bert_tokenizer.bpe_ranks + byte_encoder = bert_tokenizer.byte_encoder + else: + bpe_ranks = {} + byte_encoder = None + + super().__init__(vocab=bert_tokenizer.vocab, + bpe_ranks=bpe_ranks, + byte_encoder=byte_encoder, + wordpiece_tokenizer=bert_tokenizer.tokenize, + namespace="bert", + use_starting_offsets=use_starting_offsets, + max_pieces=max_pieces, + max_pieces_per_token=max_pieces_per_token, + is_test=is_test, + do_lowercase=do_lowercase, + never_lowercase=never_lowercase, + start_tokens=["[CLS]"] if not special_tokens_fix else [], + end_tokens=["[SEP]"] if not special_tokens_fix else [], + truncate_long_sequences=truncate_long_sequences) diff --git a/nlptoolkit/gec/trainer.py b/nlptoolkit/gec/trainer.py new file mode 100644 index 0000000..8a2d076 --- /dev/null +++ b/nlptoolkit/gec/trainer.py @@ -0,0 +1,303 @@ +import argparse +import os +from random import seed + +import torch +from allennlp.data.iterators import BucketIterator +from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN +from allennlp.data.vocabulary import Vocabulary +from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder + +from gector.bert_token_embedder import PretrainedBertEmbedder +from gector.datareader import Seq2LabelsDatasetReader +from gector.seq2labels_model import Seq2Labels +from gector.trainer import Trainer +from gector.wordpiece_indexer import PretrainedBertIndexer +from utils.helpers import get_weights_name + + +def fix_seed(): + torch.manual_seed(1) + torch.backends.cudnn.enabled = False + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + seed(43) + + +def get_token_indexers(model_name, max_pieces_per_token=5, lowercase_tokens=True, special_tokens_fix=0, is_test=False): + bert_token_indexer = PretrainedBertIndexer( + pretrained_model=model_name, + max_pieces_per_token=max_pieces_per_token, + do_lowercase=lowercase_tokens, + use_starting_offsets=True, + special_tokens_fix=special_tokens_fix, + is_test=is_test + ) + return {'bert': bert_token_indexer} + + +def get_token_embedders(model_name, tune_bert=False, special_tokens_fix=0): + take_grads = True if tune_bert > 0 else False + bert_token_emb = PretrainedBertEmbedder( + pretrained_model=model_name, + 
top_layer_only=True, requires_grad=take_grads,
+        special_tokens_fix=special_tokens_fix)
+
+    token_embedders = {'bert': bert_token_emb}
+    embedder_to_indexer_map = {"bert": ["bert", "bert-offsets"]}
+
+    text_field_embedder = BasicTextFieldEmbedder(token_embedders=token_embedders,
+                                                 embedder_to_indexer_map=embedder_to_indexer_map,
+                                                 allow_unmatched_keys=True)
+    return text_field_embedder
+
+
+def get_data_reader(model_name, max_len, skip_correct=False, skip_complex=0,
+                    test_mode=False, tag_strategy="keep_one",
+                    broken_dot_strategy="keep", lowercase_tokens=True,
+                    max_pieces_per_token=3, tn_prob=0, tp_prob=1, special_tokens_fix=0):
+    token_indexers = get_token_indexers(model_name,
+                                        max_pieces_per_token=max_pieces_per_token,
+                                        lowercase_tokens=lowercase_tokens,
+                                        special_tokens_fix=special_tokens_fix,
+                                        is_test=test_mode)
+    reader = Seq2LabelsDatasetReader(token_indexers=token_indexers,
+                                     max_len=max_len,
+                                     skip_correct=skip_correct,
+                                     skip_complex=skip_complex,
+                                     test_mode=test_mode,
+                                     tag_strategy=tag_strategy,
+                                     broken_dot_strategy=broken_dot_strategy,
+                                     lazy=True,
+                                     tn_prob=tn_prob,
+                                     tp_prob=tp_prob)
+    return reader
+
+
+def get_model(model_name, vocab, tune_bert=False,
+              predictor_dropout=0,
+              label_smoothing=0.0,
+              confidence=0,
+              special_tokens_fix=0):
+    token_embs = get_token_embedders(model_name, tune_bert=tune_bert, special_tokens_fix=special_tokens_fix)
+    model = Seq2Labels(vocab=vocab,
+                       text_field_embedder=token_embs,
+                       predictor_dropout=predictor_dropout,
+                       label_smoothing=label_smoothing,
+                       confidence=confidence)
+    return model
+
+
+def main(args):
+    fix_seed()
+    if not os.path.exists(args.model_dir):
+        os.mkdir(args.model_dir)
+
+    weights_name = get_weights_name(args.transformer_model, args.lowercase_tokens)
+    # read datasets
+    reader = get_data_reader(weights_name, args.max_len, skip_correct=bool(args.skip_correct),
+                             skip_complex=args.skip_complex,
+                             test_mode=False,
+                             tag_strategy=args.tag_strategy,
+                             lowercase_tokens=args.lowercase_tokens,
+                             max_pieces_per_token=args.pieces_per_token,
+                             tn_prob=args.tn_prob,
+                             tp_prob=args.tp_prob,
+                             special_tokens_fix=args.special_tokens_fix)
+    train_data = reader.read(args.train_set)
+    dev_data = reader.read(args.dev_set)
+
+    default_tokens = [DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN]
+    namespaces = ['labels', 'd_tags']
+    tokens_to_add = {x: default_tokens for x in namespaces}
+    # build vocab
+    if args.vocab_path:
+        vocab = Vocabulary.from_files(args.vocab_path)
+    else:
+        vocab = Vocabulary.from_instances(train_data,
+                                          max_vocab_size={'tokens': 30000,
+                                                          'labels': args.target_vocab_size,
+                                                          'd_tags': 2},
+                                          tokens_to_add=tokens_to_add)
+    vocab.save_to_files(os.path.join(args.model_dir, 'vocabulary'))
+
+    print("Data is loaded")
+    model = get_model(weights_name, vocab,
+                      tune_bert=args.tune_bert,
+                      predictor_dropout=args.predictor_dropout,
+                      label_smoothing=args.label_smoothing,
+                      special_tokens_fix=args.special_tokens_fix)
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.is_available():
+        if torch.cuda.device_count() > 1:
+            cuda_device = list(range(torch.cuda.device_count()))
+        else:
+            cuda_device = 0
+    else:
+        cuda_device = -1
+
+    if args.pretrain:
+        model.load_state_dict(torch.load(os.path.join(args.pretrain_folder, args.pretrain + '.th')))
+
+    model = model.to(device)
+
+    print("Model is set")
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer, factor=0.1, patience=10)
+    instances_per_epoch = None if not args.updates_per_epoch else \
+        int(args.updates_per_epoch * args.batch_size * args.accumulation_size)
+    iterator = BucketIterator(batch_size=args.batch_size,
+                              sorting_keys=[("tokens", "num_tokens")],
+                              biggest_batch_first=True,
+                              max_instances_in_memory=args.batch_size * 20000,
+                              instances_per_epoch=instances_per_epoch,
+                              )
+    iterator.index_with(vocab)
+    trainer = Trainer(model=model,
+                      optimizer=optimizer,
+                      scheduler=scheduler,
+                      iterator=iterator,
+                      train_dataset=train_data,
+                      validation_dataset=dev_data,
+                      serialization_dir=args.model_dir,
+                      patience=args.patience,
+                      num_epochs=args.n_epoch,
+                      cuda_device=cuda_device,
+                      shuffle=False,
+                      accumulated_batch_count=args.accumulation_size,
+                      cold_step_count=args.cold_steps_count,
+                      cold_lr=args.cold_lr,
+                      cuda_verbose_step=int(args.cuda_verbose_steps)
+                      if args.cuda_verbose_steps else None
+                      )
+    print("Start training")
+    trainer.train()
+
+    # Here's how to save the model.
+    out_model = os.path.join(args.model_dir, 'model.th')
+    with open(out_model, 'wb') as f:
+        torch.save(model.state_dict(), f)
+    print("Model is dumped")
+
+
+if __name__ == '__main__':
+    # read parameters
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--train_set',
+                        help='Path to the train data', required=True)
+    parser.add_argument('--dev_set',
+                        help='Path to the dev data', required=True)
+    parser.add_argument('--model_dir',
+                        help='Path to the model dir', required=True)
+    parser.add_argument('--vocab_path',
+                        help='Path to the model vocabulary directory. '
+                             'If not set, the vocabulary is built from the data.',
+                        default='')
+    parser.add_argument('--batch_size',
+                        type=int,
+                        help='The size of the batch.',
+                        default=32)
+    parser.add_argument('--max_len',
+                        type=int,
+                        help='The max sentence length '
+                             '(all longer will be truncated).',
+                        default=50)
+    parser.add_argument('--target_vocab_size',
+                        type=int,
+                        help='The size of the target vocabularies.',
+                        default=1000)
+    parser.add_argument('--n_epoch',
+                        type=int,
+                        help='The number of epochs to train the model.',
+                        default=20)
+    parser.add_argument('--patience',
+                        type=int,
+                        help='The number of epochs without any improvement'
+                             ' on the validation set before training stops.',
+                        default=3)
+    parser.add_argument('--skip_correct',
+                        type=int,
+                        help='If set, correct sentences will be skipped '
+                             'by the data reader.',
+                        default=1)
+    parser.add_argument('--skip_complex',
+                        type=int,
+                        help='If set, complex corrections will be skipped '
+                             'by the data reader.',
+                        choices=[0, 1, 2, 3, 4, 5],
+                        default=0)
+    parser.add_argument('--tune_bert',
+                        type=int,
+                        help='If greater than 0, the transformer encoder is fine-tuned.',
+                        default=1)
+    parser.add_argument('--tag_strategy',
+                        choices=['keep_one', 'merge_all'],
+                        help='The type of the data reader behaviour.',
+                        default='keep_one')
+    parser.add_argument('--accumulation_size',
+                        type=int,
+                        help='How many batches to accumulate gradients over.',
+                        default=4)
+    parser.add_argument('--lr',
+                        type=float,
+                        help='The initial learning rate.',
+                        default=1e-5)
+    parser.add_argument('--cold_steps_count',
+                        type=int,
+                        help='Train only the classifier layers for this many initial epochs.',
+                        default=4)
+    parser.add_argument('--cold_lr',
+                        type=float,
+                        help='Learning rate during cold steps.',
+                        default=1e-3)
+    parser.add_argument('--predictor_dropout',
+                        type=float,
+                        help='The value of dropout for the predictor.',
+                        default=0.0)
+    parser.add_argument('--lowercase_tokens',
+                        type=int,
+                        help='Whether to lowercase tokens.',
+                        default=0)
+    parser.add_argument('--pieces_per_token',
+                        type=int,
+                        help='The max number of pieces per token.',
+                        default=5)
+    parser.add_argument('--cuda_verbose_steps',
+                        help='Number of steps after which CUDA memory information is printed. '
+                             'Makes sense for local testing. Usually about 1000.',
+                        default=None)
+    parser.add_argument('--label_smoothing',
+                        type=float,
+                        help='The value of parameter alpha for label smoothing.',
+                        default=0.0)
+    parser.add_argument('--tn_prob',
+                        type=float,
+                        help='The probability of taking TN (error-free) sentences from the data.',
+                        default=0)
+    parser.add_argument('--tp_prob',
+                        type=float,
+                        help='The probability of taking TP (erroneous) sentences from the data.',
+                        default=1)
+    parser.add_argument('--updates_per_epoch',
+                        type=int,
+                        help='If set, each epoch will contain exactly this number of updates.',
+                        default=0)
+    parser.add_argument('--pretrain_folder',
+                        help='The name of the folder containing pretrained weights.')
+    parser.add_argument('--pretrain',
+                        help='The name of the pretrained weights file inside pretrain_folder.',
+                        default='')
+    parser.add_argument('--transformer_model',
+                        choices=['bert', 'distilbert', 'gpt2', 'roberta', 'transformerxl', 'xlnet', 'albert'],
+                        help='Name of the transformer model.',
+                        default='roberta')
+    parser.add_argument('--special_tokens_fix',
+                        type=int,
+                        help='Whether to fix the tokenization problem with [CLS]/[SEP] tokens.',
+                        default=1)
+
+    args = parser.parse_args()
+    main(args)
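For readers who prefer to call the trainer from Python rather than the command line, here is a minimal, hedged sketch. It is not part of the patch: the data paths are placeholders, it assumes the `gector` imports inside `nlptoolkit/gec/trainer.py` resolve on your `PYTHONPATH`, and it simply mirrors the argparse defaults defined above.

```python
# Illustrative only: drive trainer.main() with an argparse-style Namespace.
# All paths below are hypothetical placeholders for preprocessed GECToR-format data.
from argparse import Namespace

from nlptoolkit.gec.trainer import main  # assumes this patch's package layout is importable

args = Namespace(
    train_set='./data/gec/train.tagged',   # hypothetical preprocessed training file
    dev_set='./data/gec/dev.tagged',       # hypothetical preprocessed dev file
    model_dir='./data/gec/model_output/',  # checkpoints and vocabulary are written here
    vocab_path='',                         # empty -> build the vocabulary from the data
    batch_size=32, max_len=50, target_vocab_size=1000,
    n_epoch=20, patience=3,
    skip_correct=1, skip_complex=0, tune_bert=1,
    tag_strategy='keep_one', accumulation_size=4,
    lr=1e-5, cold_steps_count=4, cold_lr=1e-3,
    predictor_dropout=0.0, lowercase_tokens=0, pieces_per_token=5,
    cuda_verbose_steps=None, label_smoothing=0.0,
    tn_prob=0, tp_prob=1, updates_per_epoch=0,
    pretrain_folder=None, pretrain='',
    transformer_model='roberta', special_tokens_fix=1,
)

main(args)
```

Running the script directly is equivalent; only `--train_set`, `--dev_set`, and `--model_dir` are required, and every other option falls back to the defaults shown above.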