From a78dfdf77031c860e5a379f35fb3cef63ef9b2b4 Mon Sep 17 00:00:00 2001
From: Wee Tee Soh
Date: Sat, 14 Sep 2019 22:29:24 +0800
Subject: [PATCH] updated ner with conll2003 dataset

---
 ner.py                             |  27 +++
 ner/preprocessing_funcs.py         |  89 ++-------
 ner/utils/__init__.py              |   0
 ner/utils/bpe_vocab.py             | 291 +++++++++++++++++++++++++++++
 ner/utils/misc_utils.py            |  95 ++++++++++
 ner/utils/word_char_level_vocab.py |  73 ++++++++
 translate.py                       |   2 +-
 7 files changed, 505 insertions(+), 72 deletions(-)
 create mode 100644 ner.py
 create mode 100644 ner/utils/__init__.py
 create mode 100644 ner/utils/bpe_vocab.py
 create mode 100644 ner/utils/misc_utils.py
 create mode 100644 ner/utils/word_char_level_vocab.py

diff --git a/ner.py b/ner.py
new file mode 100644
index 0000000..be35005
--- /dev/null
+++ b/ner.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sat Sep 14 18:04:45 2019
+
+@author: WT
+"""
+from ner.preprocessing_funcs import get_NER_data
+from ner.utils.misc_utils import save_as_pickle
+from argparse import ArgumentParser
+import logging
+logging.basicConfig(format='%(asctime)s [%(levelname)s]: %(message)s', \
+                    datefmt='%m/%d/%Y %I:%M:%S %p', level=logging.INFO)
+logger = logging.getLogger(__file__)
+
+if __name__=="__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--train_path", type=str, default="./data/ner/conll2003/eng.train.txt", help="Path to training data txt file")
+    parser.add_argument("--test_path", type=str, default="./data/ner/conll2003/eng.testa.txt", help="Path to test data txt file (if any)")
+    parser.add_argument("--gradient_acc_steps", type=int, default=1, help="No. of steps of gradient accumulation")
+    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipped gradient norm")
+    parser.add_argument("--num_epochs", type=int, default=1700, help="No. of epochs")
+    parser.add_argument("--lr", type=float, default=0.0031, help="learning rate")
+    parser.add_argument("--model_no", type=int, default=2, help="Model ID: (0: Graph Convolution Network (GCN), 1: BERT, 2: XLNet)")
+
+    args = parser.parse_args()
+    text = get_NER_data(args, load_extracted=False)
+    #save_as_pickle("args.pkl", args)
\ No newline at end of file
diff --git a/ner/preprocessing_funcs.py b/ner/preprocessing_funcs.py
index 933d737..bfd7c58 100644
--- a/ner/preprocessing_funcs.py
+++ b/ner/preprocessing_funcs.py
@@ -40,85 +40,32 @@ def clean_and_tokenize_text(text, table, tokenizer, clean_only=False):
         text = [w for w in text if not any(char.isdigit() for char in w)]
     return text
 
-def get_CNN_data(args, load_extracted=True):
+def get_NER_data(args, load_extracted=True):
     """
-    Extracts CNN dataset, saves then
-    returns dataframe containing body (main text) and highlights (summarized text)
+    Extracts the CoNLL-2003 NER dataset, saves it, then
+    returns the raw text lines (one token and its NER tag per line)
     table: table containing symbols to remove from text
    tokenizer: tokenizer to tokenize text into word tokens
     """
-    path = args.data_path
-    tokenizer_en = tokener()
+    train_path = args.train_path
+    if args.test_path is not None:
+        test_path = args.test_path
+    else:
+        test_path = None
+
     table = str.maketrans("", "", '"#$%&\'()*+-/:;<=>@[\\]^_`{|}~')
     if load_extracted:
-        df = load_pickle("df_unencoded_CNN.pkl")
+        df_train = load_pickle("df_train.pkl")
+        if os.path.isfile("./data/df_test.pkl"):
+            df_test = load_pickle("df_test.pkl")
+
     else:
-        logger.info("Extracting CNN stories...")
-        df = pd.DataFrame(index=[i for i in range(len(os.listdir(path)))], columns=["body", "highlights"])
-        for idx, file in 
tqdm(enumerate(os.listdir(path)), total=len(os.listdir(path))): - with open(os.path.join(path, file), encoding="utf8") as csv_file: - csv_reader = csv.reader(csv_file) - text = "" - for row in csv_reader: - text += "".join(t for t in row) - highlights = re.search("@highlight(.*)", text).group(1) - highlights = highlights.replace("@highlight", ". ") - body = text[:re.search("@highlight", text).span(0)[0]] - df.iloc[idx]["body"] = body - df.iloc[idx]["highlights"] = highlights - save_as_pickle("df_unencoded_CNN.pkl", df) - - if (args.level == "word") or (args.level == "char"): - logger.info("Tokenizing and cleaning extracted text...") - df.loc[:, "body"] = df.apply(lambda x: clean_and_tokenize_text(x["body"], table, tokenizer_en), axis=1) - df.loc[:, "highlights"] = df.apply(lambda x: clean_and_tokenize_text(x["highlights"], table, tokenizer_en), \ - axis=1) - df.loc[:, "body_length"] = df.apply(lambda x: len(x['body']), axis=1) - df.loc[:, "highlights_length"] = df.apply(lambda x: len(x['highlights']), axis=1) - df = df[(df["body_length"] > 0) & (df["highlights_length"] > 0)] - - logger.info("Limiting to max features length, building vocab and converting to id tokens...") - df = df[df["body_length"] <= args.max_features_length] - v = vocab(level=args.level) - v.build_vocab(df["body"]) - v.build_vocab(df["highlights"]) - df.loc[:, "body"] = df.apply(lambda x: v.convert_w2idx(x["body"]), axis=1) - df.loc[:, "highlights"] = df.apply(lambda x: v.convert_w2idx(x["highlights"]), axis=1) - df.loc[:, "highlights"] = df.apply(lambda x: pad_sos_eos(x["highlights"], 0, 2), axis=1) - save_as_pickle("df_encoded_CNN.pkl", df) - save_as_pickle("vocab.pkl", v) - - elif args.level == "bpe": - encoder = Encoder(vocab_size=args.bpe_vocab_size, pct_bpe=args.bpe_word_ratio, word_tokenizer=tokenizer_en.tokenize) - df.loc[:, "body"] = df.apply(lambda x: clean_and_tokenize_text(x["body"], table, tokenizer_en, clean_only=True), axis=1) - df.loc[:, "highlights"] = df.apply(lambda x: clean_and_tokenize_text(x["highlights"], table, tokenizer_en, clean_only=True), \ - axis=1) - logger.info("Training bpe, this might take a while...") - text_list = list(df["body"]) - text_list.extend(list(df["highlights"])) - encoder.fit(text_list); del text_list + logger.info("Extracting data stories...") + with open(train_path, "r", encoding="utf8") as f: + text = f.readlines() - logger.info("Tokenizing to ids and limiting to max features length...") - df.loc[:, "body"] = df.apply(lambda x: next(encoder.transform([x["body"]])), axis=1) - df.loc[:, "highlights"] = df.apply(lambda x: next(encoder.transform([x["highlights"]])), axis=1) - df.loc[:, "body_length"] = df.apply(lambda x: len(x['body']), axis=1) - df.loc[:, "highlights_length"] = df.apply(lambda x: len(x['highlights']), axis=1) - df = df[(df["body_length"] > 0) & (df["highlights_length"] > 0)] - df = df[df["body_length"] <= args.max_features_length] - - ''' - logger.info("Converting tokens to ids...") - df.loc[:, "body"] = df.apply(lambda x: next(encoder.transform(list(" ".join(t for t in x["body"])))),\ - axis=1) - df.loc[:, "highlights"] = df.apply(lambda x: next(encoder.transform(list(" ".join(t for t in x["highlights"])))),\ - axis=1) - ''' - df.loc[:, "highlights"] = df.apply(lambda x: pad_sos_eos(x["highlights"], encoder.word_vocab["__sos"], encoder.word_vocab["__eos"]),\ - axis=1) - - save_as_pickle("df_encoded_CNN.pkl", df) - encoder.save("./data/vocab.pkl") - return df + + return text class Pad_Sequence(): """ diff --git a/ner/utils/__init__.py b/ner/utils/__init__.py 
new file mode 100644 index 0000000..e69de29 diff --git a/ner/utils/bpe_vocab.py b/ner/utils/bpe_vocab.py new file mode 100644 index 0000000..029b691 --- /dev/null +++ b/ner/utils/bpe_vocab.py @@ -0,0 +1,291 @@ +# -*- coding: utf-8 -*- +""" +Adapted from https://github.com/soaxelbrooke/python-bpe/blob/master/bpe/encoder.py +""" +""" An encoder which learns byte pair encodings for white-space separated text. Can tokenize, encode, and decode. """ +from collections import Counter + +try: + from typing import Dict, Iterable, Callable, List, Any, Iterator +except ImportError: + pass + +from nltk.tokenize import wordpunct_tokenize +from tqdm import tqdm +import toolz +import json + +DEFAULT_EOW = '__eow' +DEFAULT_SOW = '__sow' +DEFAULT_UNK = '__unk' +DEFAULT_PAD = '__pad' +DEFAULT_SOS = '__sos' +DEFAULT_EOS = '__eos' + +class Encoder: + """ Encodes white-space separated text using byte-pair encoding. See https://arxiv.org/abs/1508.07909 for details. + """ + + def __init__(self, vocab_size=8192, pct_bpe=0.2, word_tokenizer=None, + silent=True, ngram_min=2, ngram_max=2, required_tokens=None, strict=False, + EOW=DEFAULT_EOW, SOW=DEFAULT_SOW, UNK=DEFAULT_UNK, PAD=DEFAULT_PAD, SOS=DEFAULT_SOS,\ + EOS=DEFAULT_EOS): + if vocab_size < 1: + raise ValueError('vocab size must be greater than 0.') + + self.EOW = EOW + self.SOW = SOW + self.eow_len = len(EOW) + self.sow_len = len(SOW) + self.UNK = UNK + self.PAD = PAD + self.SOS = SOS + self.EOS = EOS + self.required_tokens = [self.UNK, self.PAD, self.SOS, self.EOS] + self.vocab_size = vocab_size + self.pct_bpe = pct_bpe + self.word_vocab_size = max([int(vocab_size * (1 - pct_bpe)), len(self.required_tokens or [])]) + self.bpe_vocab_size = vocab_size - self.word_vocab_size + self.word_tokenizer = word_tokenizer if word_tokenizer is not None else wordpunct_tokenize + self.custom_tokenizer = word_tokenizer is not None + self.word_vocab = {} # type: Dict[str, int] + self.bpe_vocab = {} # type: Dict[str, int] + self.inverse_word_vocab = {} # type: Dict[int, str] + self.inverse_bpe_vocab = {} # type: Dict[int, str] + self._progress_bar = iter if silent else tqdm + self.ngram_min = ngram_min + self.ngram_max = ngram_max + self.strict = strict + + def mute(self): + """ Turn on silent mode """ + self._progress_bar = iter + + def unmute(self): + """ Turn off silent mode """ + self._progress_bar = tqdm + + def byte_pair_counts(self, words): + # type: (Encoder, Iterable[str]) -> Iterable[Counter] + """ Counts space separated token character pairs: + [('T h i s ', 4}] -> {'Th': 4, 'hi': 4, 'is': 4} + """ + for token, count in self._progress_bar(self.count_tokens(words).items()): + bp_counts = Counter() # type: Counter + for ngram in token.split(' '): + bp_counts[ngram] += count + for ngram_size in range(self.ngram_min, min([self.ngram_max, len(token)]) + 1): + ngrams = [''.join(ngram) for ngram in toolz.sliding_window(ngram_size, token.split(' '))] + + for ngram in ngrams: + bp_counts[''.join(ngram)] += count + + yield bp_counts + + def count_tokens(self, words): + # type: (Encoder, Iterable[str]) -> Dict[str, int] + """ Count tokens into a BPE vocab """ + token_counts = Counter(self._progress_bar(words)) + return {' '.join(token): count for token, count in token_counts.items()} + + def learn_word_vocab(self, sentences): + # type: (Encoder, Iterable[str]) -> Dict[str, int] + """ Build vocab from self.word_vocab_size most common tokens in provided sentences """ + word_counts = Counter(word for word in toolz.concat(map(self.word_tokenizer, sentences))) + for token in 
self.required_tokens: + word_counts[token] = int(2**63) + sorted_word_counts = sorted(word_counts.items(), key=lambda p: -p[1]) + return {word: idx for idx, (word, count) in enumerate(sorted_word_counts[:self.word_vocab_size])} + + def learn_bpe_vocab(self, words): + # type: (Encoder, Iterable[str]) -> Dict[str, int] + """ Learns a vocab of byte pair encodings """ + vocab = Counter() # type: Counter + for token in {self.SOW, self.EOW}: + vocab[token] = int(2**63) + for idx, byte_pair_count in enumerate(self.byte_pair_counts(words)): + for byte_pair, count in byte_pair_count.items(): + vocab[byte_pair] += count + + if (idx + 1) % 10000 == 0: + self.trim_vocab(10 * self.bpe_vocab_size, vocab) + + sorted_bpe_counts = sorted(vocab.items(), key=lambda p: -p[1])[:self.bpe_vocab_size] + return {bp: idx + self.word_vocab_size for idx, (bp, count) in enumerate(sorted_bpe_counts)} + + def fit(self, text): + # type: (Encoder, Iterable[str]) -> None + """ Learn vocab from text. """ + _text = [l.lower().strip() for l in text] + + # First, learn word vocab + self.word_vocab = self.learn_word_vocab(_text) + + remaining_words = [word for word in toolz.concat(map(self.word_tokenizer, _text)) + if word not in self.word_vocab] + self.bpe_vocab = self.learn_bpe_vocab(remaining_words) + + self.inverse_word_vocab = {idx: token for token, idx in self.word_vocab.items()} + self.inverse_bpe_vocab = {idx: token for token, idx in self.bpe_vocab.items()} + + @staticmethod + def trim_vocab(n, vocab): + # type: (int, Dict[str, int]) -> None + """ Deletes all pairs below 10 * vocab size to prevent memory problems """ + pair_counts = sorted(vocab.items(), key=lambda p: -p[1]) + pairs_to_trim = [pair for pair, count in pair_counts[n:]] + for pair in pairs_to_trim: + del vocab[pair] + + def subword_tokenize(self, word): + # type: (Encoder, str) -> List[str] + """ Tokenizes inside an unknown token using BPE """ + end_idx = min([len(word), self.ngram_max]) + sw_tokens = [self.SOW] + start_idx = 0 + + while start_idx < len(word): + subword = word[start_idx:end_idx] + if subword in self.bpe_vocab: + sw_tokens.append(subword) + start_idx = end_idx + end_idx = min([len(word), start_idx + self.ngram_max]) + elif len(subword) == 1: + sw_tokens.append(self.UNK) + start_idx = end_idx + end_idx = min([len(word), start_idx + self.ngram_max]) + else: + end_idx -= 1 + + sw_tokens.append(self.EOW) + return sw_tokens + + def tokenize(self, sentence): + # type: (Encoder, str) -> List[str] + """ Split a sentence into word and subword tokens """ + word_tokens = self.word_tokenizer(sentence.lower().strip()) + + tokens = [] + for word_token in word_tokens: + if word_token in self.word_vocab: + tokens.append(word_token) + else: + tokens.extend(self.subword_tokenize(word_token)) + + return tokens + + def transform(self, sentences, reverse=False, fixed_length=None): + # type: (Encoder, Iterable[str], bool, int) -> Iterable[List[int]] + """ Turns space separated tokens into vocab idxs """ + direction = -1 if reverse else 1 + for sentence in self._progress_bar(sentences): + encoded = [] + tokens = list(self.tokenize(sentence.lower().strip())) + for token in tokens: + if token in self.word_vocab: + encoded.append(self.word_vocab[token]) + elif token in self.bpe_vocab: + encoded.append(self.bpe_vocab[token]) + else: + encoded.append(self.word_vocab[self.UNK]) + + if fixed_length is not None: + encoded = encoded[:fixed_length] + while len(encoded) < fixed_length: + encoded.append(self.word_vocab[self.PAD]) + + yield encoded[::direction] + + def 
inverse_transform(self, rows): + # type: (Encoder, Iterable[List[int]]) -> Iterator[str] + """ Turns token indexes back into space-joined text. """ + for row in rows: + words = [] + + rebuilding_word = False + current_word = '' + for idx in row: + if self.inverse_bpe_vocab.get(idx) == self.SOW: + if rebuilding_word and self.strict: + raise ValueError('Encountered second SOW token before EOW.') + rebuilding_word = True + + elif self.inverse_bpe_vocab.get(idx) == self.EOW: + if not rebuilding_word and self.strict: + raise ValueError('Encountered EOW without matching SOW.') + rebuilding_word = False + words.append(current_word) + current_word = '' + + elif rebuilding_word and (idx in self.inverse_bpe_vocab): + current_word += self.inverse_bpe_vocab[idx] + + elif rebuilding_word and (idx in self.inverse_word_vocab): + current_word += self.inverse_word_vocab[idx] + + elif idx in self.inverse_word_vocab: + words.append(self.inverse_word_vocab[idx]) + + elif idx in self.inverse_bpe_vocab: + if self.strict: + raise ValueError("Found BPE index {} when not rebuilding word!".format(idx)) + else: + words.append(self.inverse_bpe_vocab[idx]) + + else: + raise ValueError("Got index {} that was not in word or BPE vocabs!".format(idx)) + + yield ' '.join(w for w in words if w != '') + + def vocabs_to_dict(self, dont_warn=False): + # type: (Encoder, bool) -> Dict[str, Dict[str, int]] + """ Turns vocab into dict that is json-serializeable """ + if self.custom_tokenizer and not dont_warn: + print("WARNING! You've specified a non-default tokenizer. You'll need to reassign it when you load the " + "model!") + return { + 'byte_pairs': self.bpe_vocab, + 'words': self.word_vocab, + 'kwargs': { + 'vocab_size': self.vocab_size, + 'pct_bpe': self.pct_bpe, + 'silent': self._progress_bar is iter, + 'ngram_min': self.ngram_min, + 'ngram_max': self.ngram_max, + 'required_tokens': self.required_tokens, + 'strict': self.strict, + 'EOW': self.EOW, + 'SOW': self.SOW, + 'UNK': self.UNK, + 'PAD': self.PAD, + 'SOS': self.SOS, + 'EOS': self.EOS + } + } + + def save(self, outpath, dont_warn=False): + # type: (Encoder, str, bool) -> None + """ Serializes and saves encoder to provided path """ + with open(outpath, 'w') as outfile: + json.dump(self.vocabs_to_dict(dont_warn), outfile) + + @classmethod + def from_dict(cls, vocabs): + # type: (Any, Dict[str, Dict[str, int]]) -> Encoder + """ Load encoder from dict produced with vocabs_to_dict """ + encoder = Encoder(**vocabs['kwargs']) + encoder.word_vocab = vocabs['words'] + encoder.bpe_vocab = vocabs['byte_pairs'] + + encoder.inverse_bpe_vocab = {v: k for k, v in encoder.bpe_vocab.items()} + encoder.inverse_word_vocab = {v: k for k, v in encoder.word_vocab.items()} + + return encoder + + @classmethod + def load(cls, in_path): + # type: (Any, str) -> Encoder + """ Loads an encoder from path saved with save """ + with open(in_path) as infile: + obj = json.load(infile) + return cls.from_dict(obj) \ No newline at end of file diff --git a/ner/utils/misc_utils.py b/ner/utils/misc_utils.py new file mode 100644 index 0000000..22f5332 --- /dev/null +++ b/ner/utils/misc_utils.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Jul 17 18:06:56 2019 + +@author: WT +""" + +import os +import pickle +import torch +import numpy as np +import math + +def load_pickle(filename): + completeName = os.path.join("./data/",\ + filename) + with open(completeName, 'rb') as pkl_file: + data = pickle.load(pkl_file) + return data + +def save_as_pickle(filename, data): + completeName = 
os.path.join("./data/",\ + filename) + with open(completeName, 'wb') as output: + pickle.dump(data, output) + + +class CosineWithRestarts(torch.optim.lr_scheduler._LRScheduler): + """ + Cosine annealing with restarts. + Parameters + ---------- + optimizer : torch.optim.Optimizer + T_max : int + The maximum number of iterations within the first cycle. + eta_min : float, optional (default: 0) + The minimum learning rate. + last_epoch : int, optional (default: -1) + The index of the last epoch. + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + T_max: int, + eta_min: float = 0., + last_epoch: int = -1, + factor: float = 1.) -> None: + # pylint: disable=invalid-name + self.T_max = T_max + self.eta_min = eta_min + self.factor = factor + self._last_restart: int = 0 + self._cycle_counter: int = 0 + self._cycle_factor: float = 1. + self._updated_cycle_len: int = T_max + self._initialized: bool = False + super(CosineWithRestarts, self).__init__(optimizer, last_epoch) + + def get_lr(self): + """Get updated learning rate.""" + # HACK: We need to check if this is the first time get_lr() was called, since + # we want to start with step = 0, but _LRScheduler calls get_lr with + # last_epoch + 1 when initialized. + if not self._initialized: + self._initialized = True + return self.base_lrs + + step = self.last_epoch + 1 + self._cycle_counter = step - self._last_restart + + lrs = [ + ( + self.eta_min + ((lr - self.eta_min) / 2) * + ( + np.cos( + np.pi * + ((self._cycle_counter) % self._updated_cycle_len) / + self._updated_cycle_len + ) + 1 + ) + ) for lr in self.base_lrs + ] + + if self._cycle_counter % self._updated_cycle_len == 0: + # Adjust the cycle length. + self._cycle_factor *= self.factor + self._cycle_counter = 0 + self._updated_cycle_len = int(self._cycle_factor * self.T_max) + self._last_restart = step + + return lrs + +def lrate(n, d_model, k=10, warmup_n=25000): + lr = (k/math.sqrt(d_model))*min(1/math.sqrt(n), n*warmup_n**(-1.5)) + return lr \ No newline at end of file diff --git a/ner/utils/word_char_level_vocab.py b/ner/utils/word_char_level_vocab.py new file mode 100644 index 0000000..d1770ef --- /dev/null +++ b/ner/utils/word_char_level_vocab.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Jul 30 16:35:16 2019 + +@author: WT +""" +import spacy +import re +from string import ascii_lowercase +from tqdm import tqdm + +class tokener(object): + def __init__(self, lang="en"): + d = {"en":"en_core_web_sm", "fr":"fr_core_news_sm"} + self.ob = spacy.load(d[lang]) + + def tokenize(self, sent): + sent = re.sub(r"[\*\"“”\n\\…\+\-\/\=\(\)‘•:\[\]\|’\!;]", " ", str(sent)) + sent = re.sub(r"\!+", "!", sent) + sent = re.sub(r"\,+", ",", sent) + sent = re.sub(r"\?+", "?", sent) + sent = re.sub(r"[ ]+", " ", sent) + sent = sent.lower() + sent = [token.text for token in self.ob.tokenizer(sent) if token.text != " "] + return sent + +class vocab(object): + def __init__(self, level="word", model="transformer"): + self.model = model + if model == "transformer": + self.w2idx = {"":0, "":2, "":1} + self.idx2w = {0:"", 2:"", 1:""} + self.idx = 3 + self.level = level + + elif model == "h_encoder_decoder": + self.w2idx = {"":0, "":2, "":1, "":3, "":4} + self.idx2w = {0:"", 2:"", 1:"", 3:"", 4:""} + self.idx = 5 + self.level = level + + def build_vocab(self, df_text): + if self.level == "word": + word_soup = set([word for text in df_text for word in text]) + print("Building word vocab...") + for word in tqdm(word_soup): + if word not in self.w2idx.keys(): + self.w2idx[word] = 
self.idx + self.idx += 1 + + elif self.level == "char": + self.w2idx.update({k:v for k,v in zip(ascii_lowercase, [i for i in range(self.idx, len(ascii_lowercase) + self.idx)])}) + self.idx += len(ascii_lowercase) + self.w2idx[" "] = self.idx; self.idx += 1 + self.w2idx["'"] = self.idx; self.idx += 1 + + self.idx2w.update({v:k for k,v in self.w2idx.items() if v not in self.idx2w.keys()}) + + def convert_w2idx(self, word_list): + if self.level == "word": + w = [] + for word in word_list: + w.extend([self.w2idx[word]]) + return w + + elif self.level == "char": + return [self.w2idx[c] for c in " ".join(word_list)] + + def convert_idx2w(self, idx_list): + if self.level == "word": + return [self.idx2w[idx] for idx in idx_list] + elif self.level == "char": + return [self.idx2w[idx] for idx in idx_list] \ No newline at end of file diff --git a/translate.py b/translate.py index be9c58c..19ff570 100644 --- a/translate.py +++ b/translate.py @@ -26,7 +26,7 @@ parser.add_argument("--max_encoder_len", type=int, default=80, help="Max src length") parser.add_argument("--max_decoder_len", type=int, default=80, help="Max trg length") parser.add_argument("--num_epochs", type=int, default=500, help="No of epochs") - parser.add_argument("--lr", type=float, default=0.00002, help="learning rate") + parser.add_argument("--lr", type=float, default=0.00001, help="learning rate") parser.add_argument("--gradient_acc_steps", type=int, default=1, help="Number of steps of gradient accumulation") parser.add_argument("--max_norm", type=float, default=1.0, help="Clipped gradient norm") parser.add_argument("--model_no", type=int, default=0, help="Model ID (0: Transformer)")
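
For reference, the get_NER_data() function added in this patch currently returns the raw lines of the CoNLL-2003 training file rather than a dataframe of tokens and NER tags. The sketch below shows one way those lines could be split into per-sentence token/tag lists, assuming the standard four-column CoNLL-2003 layout (word, POS, chunk, NER tag); the helper name parse_conll2003 is illustrative only and is not defined anywhere in this patch.

# Illustrative sketch only -- not part of this patch. Assumes the standard
# CoNLL-2003 four-column format: "word POS chunk NER-tag", with blank lines
# separating sentences and "-DOCSTART-" lines marking document boundaries.
def parse_conll2003(lines):
    sentences, tag_seqs = [], []
    words, tags = [], []
    for line in lines:
        line = line.strip()
        if (not line) or line.startswith("-DOCSTART-"):
            # Sentence boundary: flush the tokens collected so far.
            if words:
                sentences.append(words)
                tag_seqs.append(tags)
                words, tags = [], []
            continue
        cols = line.split()
        words.append(cols[0])   # surface token
        tags.append(cols[-1])   # NER tag (last column)
    if words:                   # flush the final sentence
        sentences.append(words)
        tag_seqs.append(tags)
    return sentences, tag_seqs

# Hypothetical usage with the raw lines returned by get_NER_data():
#     text = get_NER_data(args, load_extracted=False)
#     train_sents, train_tags = parse_conll2003(text)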