
Language modeling via stochastic processes

[Paper] [Open Review] [Long Video]

ICLR Oral 2022

Rose E Wang, Esin Durmus, Noah Goodman, Tatsunori Hashimoto

Introduction

Abstract: Modern language models can generate high-quality short texts. However, they often meander or are incoherent when generating longer texts. These issues arise from the next-token-only language modeling objective. To address these issues, we introduce Time Control (TC), a language model that implicitly plans via a latent stochastic process. TC does this by learning a representation which maps the dynamics of how text changes in a document to the dynamics of a stochastic process of interest. Using this representation, the language model can generate text by first implicitly generating a document plan via a stochastic process, and then generating text that is consistent with this latent plan. Compared to domain-specific methods and fine-tuning GPT2 across a variety of text domains, TC improves performance on text infilling and discourse coherence. On long text generation settings, TC preserves the text structure both in terms of ordering (up to +40% better) and text length consistency (up to +17% better). Human evaluators also prefer TC's output 28.6% more than the baselines.


TLDR

Our key contribution is training the encoder to follow a stochastic process of interest via contrastive learning; in our setting, we recover a latent Brownian bridge process. The implementation of this objective can be found in language_modeling_via_stochastic_processes/src/objectives/brownian_bridge.py
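
For intuition, here is a minimal, self-contained sketch of a Brownian bridge contrastive objective in the spirit of that file. It is not the repo's exact implementation: the function names, the negative-sampling scheme, and the handling of the bridge variance are illustrative assumptions. Given encoder latents z_0, z_t, z_T from document positions 0 < t < T, the true midpoint z_t should score higher under the bridge density (mean (1 - t/T) z_0 + (t/T) z_T, variance t(T - t)/T) than negatives drawn from other documents or positions.

    # Minimal sketch of a Brownian bridge contrastive objective -- illustrative
    # only, not the exact code in brownian_bridge.py.
    import torch

    def bridge_log_score(z_0, z_t, z_T, t, T, eps=1e-6):
        """Unnormalized log-density of z_t under the bridge pinned at z_0 and z_T."""
        alpha = t / T
        mean = (1 - alpha) * z_0 + alpha * z_T      # bridge mean: linear interpolation
        var = t * (T - t) / T + eps                 # bridge variance, largest mid-document
        return -((z_t - mean) ** 2).sum(-1) / (2 * var)

    def contrastive_bridge_loss(z_0, z_t, z_T, t, T, negatives):
        """InfoNCE-style loss: the true z_t should outscore negative latents."""
        pos = bridge_log_score(z_0, z_t, z_T, t, T)                  # (batch,)
        neg = torch.stack([bridge_log_score(z_0, n, z_T, t, T)
                           for n in negatives])                      # (num_neg, batch)
        logits = torch.cat([pos.unsqueeze(0), neg], dim=0)           # (1 + num_neg, batch)
        return -torch.log_softmax(logits, dim=0)[0].mean()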

Installation

  1. Create a new virtual environment: conda create -n lm_via_sp python=3.8
  2. In this repository, run:
pip install -e . # Installing the right libraries for the virtual environment
cd language_modeling_via_stochastic_processes/transformers
pip install -e . # Installing transformers locally; I modified their GPT2 module to take in our learned embeddings for decoding.
  3. Make sure you have a wandb account!
  4. Change the filepaths in language_modeling_via_stochastic_processes/src/constants.py to correspond to your own setup.

Datasets

This repo contains all but two of the datasets; instructions for obtaining the remaining two (Wikihow and Recipe NLG) are below. The other four datasets are already included in the repo.

TicketTalk

The TicketTalk dataset used in this paper is already included in the repo (all of the json files).

Encoder

An example command for training the Brownian bridge encoder on Wikisection with a 32-dimensional latent space:

python train_encoder.py --config-name=brownian_bridge wandb_settings.exp_dir=wikisection_tc32 data_params.name=wikisection model_params.latent_dim=32

Multi-node, multi-GPU training

master_ip=192.168.1.10
Machine 1: CUDA_VISIBLE_DEVICES="0,1,2,3" MASTER_ADDR=$master_ip MASTER_PORT=45321 WORLD_SIZE=3 NODE_RANK=0 LOCAL_RANK=0 python main.py
Machine 2: CUDA_VISIBLE_DEVICES="0,1,2,3" MASTER_ADDR=$master_ip MASTER_PORT=45321 WORLD_SIZE=3 NODE_RANK=1 LOCAL_RANK=0 python main.py
Machine 3: CUDA_VISIBLE_DEVICES="0,1,2,3" MASTER_ADDR=$master_ip MASTER_PORT=45321 WORLD_SIZE=3 NODE_RANK=2 LOCAL_RANK=0 python main.py
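
For reference, a typical PyTorch Lightning entry point matching this launch pattern might look like the sketch below. This is an assumption about main.py, not its actual contents; Lightning typically picks up MASTER_ADDR, MASTER_PORT, and NODE_RANK from the environment variables set above.

    # Hypothetical sketch of a multi-node entry point; the real main.py may differ.
    import pytorch_lightning as pl

    def train(model, datamodule):
        trainer = pl.Trainer(
            accelerator="gpu",
            devices=4,        # four GPUs per machine (CUDA_VISIBLE_DEVICES="0,1,2,3")
            num_nodes=3,      # three machines, matching NODE_RANK=0,1,2
            strategy="ddp",   # distributed data parallel across machines
        )
        trainer.fit(model, datamodule=datamodule)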

More information is under language_modeling_via_stochastic_processes/models/README.md.

Decoder

For training the decoder, you'll need to be in the directory language_modeling_via_stochastic_processes/transformers/examples/pytorch/language-modeling/.

The script for training the decoder can be found at language_modeling_via_stochastic_processes/transformers/examples/pytorch/language-modeling/final_experiments.sh

An example command for training a decoder with the Brownian bridge encoder on Wikisection:

python run_time_clm.py --model_name_or_path gpt2 --dataset_name wikisection --do_train --do_eval --per_device_eval_batch_size=1 --per_device_train_batch_size=1 --save_total_limit=1 --load_best_model_at_end=True --overwrite_output_dir --num_train_epochs=10 --seed=1 --encoder_filepath=${path2repo}/language_modeling_via_stochastic_processes/models/wikisection/tc32/epoch=99-step=21999.ckpt --latent_dim=32 --output_dir LM_wikisection_32 --evaluation_strategy=steps --eval_steps=1000 --use_contrastive_embeddings

TLDR: The main change is in language_modeling_via_stochastic_processes/transformers/src/transformers/models/gpt2/modeling_time_gpt2.py. Specifically, when decoding, I use the learned embeddings in the model's forward pass:

    # in forward(self, ...), lines 917-922
    cl_embeds = self._get_cl_embeddings(raw_text=raw_text,
                                        cl_feats=cl_feats,
                                        seq_cl_feats=seq_cl_feats,
                                        input_ids=input_ids,
                                        seq_len=inputs_embeds.shape[1])
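    # add the learned bridge embeddings to every token's hidden state before decoding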
    hidden_states = hidden_states + cl_embeds
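
Conceptually, a helper like _get_cl_embeddings projects the latent bridge features into GPT-2's hidden size and broadcasts them across the sequence so they can be added token-wise. The sketch below is an illustration of that idea under assumed shapes, not the actual method in modeling_time_gpt2.py:

    # Illustrative only: assumes latent_dim=32 and GPT-2 hidden_size=768.
    import torch
    import torch.nn as nn

    class LatentToHidden(nn.Module):
        def __init__(self, latent_dim: int = 32, hidden_size: int = 768):
            super().__init__()
            self.proj = nn.Linear(latent_dim, hidden_size)

        def forward(self, cl_feats: torch.Tensor, seq_len: int) -> torch.Tensor:
            # cl_feats: (batch, latent_dim) latents from the Brownian bridge encoder
            h = self.proj(cl_feats)                          # (batch, hidden_size)
            return h.unsqueeze(1).expand(-1, seq_len, -1)    # (batch, seq_len, hidden_size)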

Generation

For generating texts, you'll need to be in the directory language_modeling_via_stochastic_processes/transformers/examples/pytorch/text-generation/.

The script for generating long texts can be found at language_modeling_via_stochastic_processes/transformers/examples/pytorch/text-generation/final_experiments.sh

An example command for generating long Wikisection texts with the Brownian bridge encoder:

python run_decoding_from_embeddings.py --model_type=gpt2 --model_name_or_path=${path2repo}/language_modeling_via_stochastic_processes/transformers/examples/pytorch/language-modeling/LM_wikisection_32/ --prompt="<|endoftext|>" --num_return_sequences=1 --num_intervals=1000 --method=sample --stop_token="<|endoftext|>" --dataset_name=wikisection --encoder_filepath=${path2repo}/language_modeling_via_stochastic_processes/models/wikisection/tc32/epoch=99-step=75299.ckpt --latent_dim=32 --project=LM_wikisection --no_eos --label=LM_wikisection_32 --seed=0
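
At a high level, generation first samples a latent plan as a Brownian bridge between a start latent and an end latent, and the decoder is then conditioned on that trajectory. The sketch below shows one standard way to sample such a trajectory sequentially; it is illustrative and not the exact logic inside run_decoding_from_embeddings.py:

    # Sequentially sample a unit-variance Brownian bridge path pinned at z_0 and z_T.
    import torch

    def sample_bridge_path(z_0: torch.Tensor, z_T: torch.Tensor, num_steps: int):
        path = [z_0]
        z_t = z_0
        for t in range(num_steps - 1):
            remaining = num_steps - t
            mean = z_t + (z_T - z_t) / remaining          # drift toward the pinned endpoint
            var = (remaining - 1) / remaining             # variance shrinks to 0 near the end
            z_t = mean + var ** 0.5 * torch.randn_like(z_0)
            path.append(z_t)
        path.append(z_T)
        return torch.stack(path)                          # (num_steps + 1, latent_dim)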