Implementation of Marge, Pre-training via Paraphrasing, in Pytorch

lucidrains/marge-pytorch

Marge - Pre-training via Paraphrasing

Implementation of Marge, Pre-training via Paraphrasing, in PyTorch. It is an alternative to masked language modeling pretraining: an encoder-decoder attention network learns to reconstruct a target document from a collection of evidence documents.
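As we read the paper, the decoder does not attend to the evidence documents uniformly. Each evidence document z gets a relevance score f(x, z), the cosine similarity between encoder-produced document embeddings (the `enc_retrieval_depth` option in the usage example suggests this implementation takes them from the first token after the 4th encoder layer). The term β·f(x, z) is then added to the cross-attention logits for every token of z, so the reconstruction loss also trains the retrieval embeddings. The NumPy sketch below illustrates that bias for a single head. The function names are hypothetical, β is a fixed number here whereas the paper learns it, and we have not checked how marge-pytorch implements this internally.

```python
import numpy as np

def cosine_relevance(target_emb, evidence_embs):
    # target_emb: (d,), evidence_embs: (m, d) -> (m,) cosine similarity per evidence document
    t = target_emb / np.linalg.norm(target_emb)
    e = evidence_embs / np.linalg.norm(evidence_embs, axis=-1, keepdims=True)
    return e @ t

def relevance_biased_attention(q, k, v, doc_ids, relevance, beta=1.0):
    # q: (n_q, d) decoder queries
    # k, v: (n_k, d) keys / values for all evidence tokens, concatenated
    # doc_ids: (n_k,) which evidence document each key token belongs to
    # relevance: (m,) relevance score per evidence document
    logits = q @ k.T / np.sqrt(q.shape[-1])          # scaled dot-product logits, (n_q, n_k)
    logits = logits + beta * relevance[doc_ids]      # same bias for every token of a document
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights
```

Because the bias is shared by all tokens of a document, a more relevant document receives proportionally more attention mass, and a less relevant one is down-weighted without being masked out entirely.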

Update: Three researchers have independently reported that the repository works for them

Install

$ pip install marge-pytorch

Usage

import torch
import numpy as np
from torch.utils.data import DataLoader

from marge_pytorch import Marge, TrainingWrapper

# your documents must be tokenized and stored as a memmap of shape (num documents, seq length)

# constants
NUM_DOCS = 10000
SEQ_LEN = 1024
SHAPE = (NUM_DOCS, SEQ_LEN)

# generate mock training data
f = np.memmap('./train.dat', dtype=np.int32, mode='w+', shape=SHAPE)
f[:] = np.random.randint(0, 20000, size=SHAPE)
del f

# generate mock masking data
f = np.memmap('./train.mask.dat', dtype=bool, mode='w+', shape=SHAPE)
f[:] = np.full(SHAPE, True)
del f

# instantiate model

model = Marge(
    dim = 512,
    num_tokens = 20000,
    max_seq_len = SEQ_LEN,
    enc_depth = 12,
    enc_retrieval_depth = 4,                # defaults to 4 as in paper (take the CLS token after the 4th layer of the encoder)
    enc_heads = 8,
    enc_ff_mult = 4,
    dec_depth = 12,
    dec_heads = 8,
    dec_ff_mult = 16,                       # paper noted that decoder needs to have much bigger feed forward sizes
    distill_attn = False,                   # (experimental) will add, on top of the decoder loss, an auxiliary distillation loss as defined in https://arxiv.org/abs/2012.04584
    distill_loss_coef = 1.                  # weight of the auxiliary distillation loss
)

# wrap your model and your documents

trainer = TrainingWrapper(
    model,
    num_documents = NUM_DOCS,
    doc_seq_len = SEQ_LEN,
    num_evidence = 4,                         # number of evidence documents to retrieve for reconstructing each target document
    reindex_batch_size = 32,                  # batch size to use when reindexing
    documents_memmap_path = './train.dat',    # path to the mem-mapped documents
    masks_memmap_path = './train.mask.dat',   # if None is supplied, will assume all tokens are visible
    use_faiss_ann = True                      # set this to False if you have few documents and approximate nearest neighbor search is not needed
)

# instantiate dataloader

dl = DataLoader(trainer.dataset, batch_size=16)

# now you can train, and use the reindex method on the training wrapper at appropriate intervals

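optim = torch.optim.Adam(model.parameters(), lr = 3e-4)   # illustrative choice of optimizer and learning rate

for ind, data in enumerate(dl):
    loss = trainer(data)
    loss.backward()
    optim.step()
    optim.zero_grad()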

    # reindex and precompute knn every 10000 steps, as in paper
    if ind > 0 and ind % 10000 == 0:
        trainer.reindex()
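Reindexing recomputes a retrieval embedding for every document and refreshes each document's nearest neighbors, from which its evidence is drawn. For a small collection, exact search is easy to write. The NumPy sketch below (a hypothetical `exact_knn`, not the wrapper's internals) returns the k most cosine-similar documents for each document while excluding the document itself, since a target should not be reconstructed from a copy of itself.

```python
import numpy as np

def exact_knn(doc_embs, k):
    # doc_embs: (num_docs, d) retrieval embeddings, one per document
    # returns (num_docs, k) neighbor indices and their cosine similarities, best first
    normed = doc_embs / np.linalg.norm(doc_embs, axis=-1, keepdims=True)
    sims = normed @ normed.T                     # full cosine similarity matrix
    np.fill_diagonal(sims, -np.inf)              # a document must not retrieve itself
    idx = np.argsort(-sims, axis=-1)[:, :k]      # highest similarity first
    return idx, np.take_along_axis(sims, idx, axis=-1)
```

The similarity matrix grows quadratically with the number of documents. This is why approximate search via faiss (`use_faiss_ann = True`) pays off once the collection is large.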

Save your model after training

torch.save(model, './trained-model.pt')
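`torch.save(model, ...)` pickles the entire module, so loading it later requires `marge_pytorch` to be importable. In addition, since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which rejects pickled modules, so you would need `torch.load('./trained-model.pt', weights_only=False)`; only do that for files you trust. A common alternative is to save just the `state_dict` and rebuild the model from the same constructor arguments. The sketch below uses a small `nn.Linear` as a stand-in for `Marge`, and the file name is illustrative.

```python
import torch
from torch import nn

# stand-in for the trained model; with Marge, rebuild it with the same Marge(...) arguments
model = nn.Linear(4, 2)

# save only the parameters and buffers
torch.save(model.state_dict(), './trained-model-state.pt')

# later: construct an identically configured model, then load the weights into it
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load('./trained-model-state.pt'))
```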

Advanced

If you would like the target and evidence documents to come from different sets, pass in up to four additional keyword arguments, as shown below.

trainer = TrainingWrapper(
    model,
    num_documents = NUM_DOCS,
    doc_seq_len = SEQ_LEN,
    num_evidence = 4,
    reindex_batch_size = 32,
    documents_memmap_path = './evidence.dat',
    masks_memmap_path = './evidence.mask.dat',
    num_targets = NUM_TARGETS,                       # 1. number of target documents, with sequence length the same as the document (evidence)
    target_seq_len = SEQ_LEN,                        # 2. sequence length of target documents
    target_memmap_path = './target.dat',             # 3. path to target memmap, same as documents (evidence)
    target_masks_memmap_path = './target.mask.dat',  # 4. path to target mask memmap, same as document masks (evidence)
    use_faiss_ann = True
)
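The memmaps are plain NumPy arrays on disk: an `int32` token array of shape (num documents, seq length) and a boolean mask of the same shape, as in the mock data of the usage example. Below is a minimal sketch for writing real, already-tokenized documents. `write_token_memmaps` is a hypothetical helper, and padding with token id 0 is an assumption (use whatever pad id your tokenizer and model expect). We also take `True` in the mask to mean a visible token, consistent with the all-`True` mock mask. Call it once for the evidence set and once for the target set; the returned counts are what you pass as `num_documents` and `num_targets`.

```python
import numpy as np

def write_token_memmaps(token_lists, tokens_path, mask_path, seq_len, pad_id = 0):
    # hypothetical helper: token_lists holds already-tokenized documents (lists of ints)
    num_docs = len(token_lists)
    tokens = np.memmap(tokens_path, dtype = np.int32, mode = 'w+', shape = (num_docs, seq_len))
    mask = np.memmap(mask_path, dtype = bool, mode = 'w+', shape = (num_docs, seq_len))

    for i, toks in enumerate(token_lists):
        toks = toks[:seq_len]               # truncate documents longer than seq_len
        tokens[i] = pad_id                  # assumed pad id, fills the whole row first
        tokens[i, :len(toks)] = toks
        mask[i] = False
        mask[i, :len(toks)] = True          # True marks real (visible) tokens

    tokens.flush()
    mask.flush()
    return num_docs

# toy usage: two evidence documents and two target documents, padded to length 8
evidence_docs = [[5, 6, 7], [8, 9]]
target_docs = [[5, 6], [9]]
NUM_DOCS = write_token_memmaps(evidence_docs, './evidence.dat', './evidence.mask.dat', seq_len = 8)
NUM_TARGETS = write_token_memmaps(target_docs, './target.dat', './target.mask.dat', seq_len = 8)
```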

Sampling

You can sample from the decoder as follows:

# some random evidence from the dataset
# or provide your own in the dimensions (b x num_evidences x seq_len)
*_, evidence, mask = trainer.dataset[0:1]

# assume 1 is start token
prime = torch.tensor([[1]], dtype=torch.long).cuda()

# supply your own document similarities array (b x num_evidences)
# if not supplied, will default to 1. for all evidence
doc_similarities = torch.ones(evidence.shape[:2]).float().cuda()

# generate sample of length 1024
samples = model.generate(prime, 1024, evidence, mask = mask, similarities = doc_similarities)
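`model.generate` extends the prime one token at a time. We have not checked which decoding strategy it uses. The NumPy sketch below shows one common choice, top-k sampling. `next_token_logits` is a hypothetical stand-in for a single decoder forward pass that, in Marge, would also condition on the evidence, mask and similarities. The returned array contains only the newly generated tokens.

```python
import numpy as np

def sample_top_k(next_token_logits, prime, length, k = 10, temperature = 1.0, rng = None):
    # next_token_logits: hypothetical callable mapping a (seq,) int array to (vocab,) logits
    rng = np.random.default_rng() if rng is None else rng
    seq = list(prime)
    for _ in range(length):
        logits = np.asarray(next_token_logits(np.array(seq)), dtype = np.float64) / temperature
        kk = min(k, logits.shape[-1])
        top = np.argpartition(-logits, kk - 1)[:kk]       # indices of the kk largest logits
        probs = np.exp(logits[top] - logits[top].max())   # softmax restricted to the top kk
        probs /= probs.sum()
        seq.append(int(rng.choice(top, p = probs)))       # sample the next token
    return np.array(seq[len(prime):])
```

Restricting the softmax to the k most likely tokens avoids sampling from the long tail of unlikely tokens, while `temperature` trades diversity against fidelity.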

Citations

@misc{lewis2020pretraining,
    title={Pre-training via Paraphrasing},
    author={Mike Lewis and Marjan Ghazvininejad and Gargi Ghosh and Armen Aghajanyan and Sida Wang and Luke Zettlemoyer},
    year={2020},
    eprint={2006.15020},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
@misc{komatsuzaki2020current,
    title={Current Limitations of Language Models: What You Need is Retrieval},
    author={Aran Komatsuzaki},
    year={2020},
    eprint={2009.06857},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
@misc{izacard2020distilling,
    title={Distilling Knowledge from Reader to Retriever for Question Answering},
    author={Gautier Izacard and Edouard Grave},
    year={2020},
    eprint={2012.04584},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}