
Commit

Further cleanup and doc updates
rajcscw committed Oct 4, 2022
1 parent 6f22de5 commit f5e7b37
Showing 6 changed files with 562 additions and 17 deletions.
156 changes: 139 additions & 17 deletions README.md
@@ -1,7 +1,23 @@
# RL4LMs - A modular RL library to fine-tune language models for natural language generation tasks.

We provide easily customizable building blocks for training language models, including implementations of on-policy algorithms, reward functions, metrics, datasets and LM actor-critic policies.

Thoroughly tested and benchmarked on a comprehensive set of:
- 6 different Natural Language Generation (NLG) Tasks:
  - Summarization
  - Generative Commonsense Reasoning
  - IMDB Text Continuation
  - Table-to-text generation
  - Abstractive Question Answering
  - Machine Translation
- Different types of NLG metrics which can be used as reward functions:
  - Lexical Metrics (ROUGE, BLEU, SacreBLEU, METEOR)
  - Semantic Metrics (BERTSCORE, BLEURT)
  - Task-specific metrics
- On-policy algorithms: PPO, A2C, TRPO and the novel **NLPO (Natural Language Policy Optimization)**
- Actor-Critic Policies supporting causal LMs (e.g. GPT-2/3) and seq2seq LMs (e.g. T5, BART)

---
# Install

## Local Installation
@@ -19,26 +35,132 @@ We also provide a Dockerfile for development using docker containers containing

Optionally, CoreNLP libraries are required for certain metric computations (e.g. SPICE); these can be downloaded using the bash script `rl4lms/envs/text_generation/caption_metrics/spice`

---
# Quick Start - Train PPO/NLPO using pre-defined YAML configs
We provide a simple training API that can be invoked via `scripts/training/train_text_generation.py`, which allows you to train PPO, NLPO or a supervised model using a config file (YAML).

For example, to train T5-base on CNN/DM summarization with PPO using ROUGE-1 as the reward function, you can run:

```bash
python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.yml
```

## YAML file schema - Configuring building blocks

The config file contains the hyper-parameter settings for each of the building blocks, which are described below:

- **Dataset/Task**: Dataset containing samples with input prompts and reference sentences. Available datasets are found in the class `DataPoolRegistry` in `rl4lms/envs/text_generation/registry.py`. (See how to create your own dataset below)

```yaml
# dataset
datapool:
  id: cnn_daily_mail
  args:
    prompt_prefix: "Summarize: "
```
- **Tokenizer**: A pre-trained tokenizer that is used to (de)tokenize input and output sequences, with settings for padding and truncation.
```yaml
# tokenizer
tokenizer:
  model_name: t5-base
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False

```
- **Reward Function**: A reward function that computes token-level scores at each time step of the MDP; it is configured using an ID and arguments. Available reward functions can be found in the class `RewardFunctionRegistry` in `rl4lms/envs/text_generation/registry.py`. (See how to create your own reward function below)

```yaml
# reward function that is optimized
reward_fn:
  id: rouge
  args:
    rouge_type: "rouge1"
```

- **Environment**: Configures a gym-style environment (`rl4lms/envs/text_generation/env.py`) which simulates and generates MDP episodes. We use a vectorized environment from stable-baselines that processes `n_envs` episodes in parallel using multi-processing to compute step-wise rewards. Further parameters that can be configured are: `max_episode_length` - maximum length of an episode, `max_prompt_length` - maximum length of the input text to consider, `terminate_on_eos` - whether to terminate the episode as soon as the EOS action is performed, `prompt_truncation_side` - truncation side for the prompt text, `context_start_token` - token id of the context token (corresponds to the initial token given to the decoder in encoder-decoder models). A conceptual sketch of how these parameters shape an episode follows the config example below.

```yaml
env:
  n_envs: 10
  args:
    max_prompt_length: 512
    max_episode_length: 100
    terminate_on_eos: True
    prompt_truncation_side: "right"
    context_start_token: 0
```
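For intuition, the sketch below shows how these parameters shape a single episode. It is a conceptual illustration only (using a plain HuggingFace GPT-2 rather than the library's actual environment): the prompt is truncated to `max_prompt_length`, each step emits one token as the action, and the episode ends on EOS (when `terminate_on_eos` is set) or after `max_episode_length` steps.

```python
# Conceptual sketch of one text-generation episode -- not RL4LMs' actual env,
# which lives in rl4lms/envs/text_generation/env.py and is vectorized.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

max_prompt_length, max_episode_length, terminate_on_eos = 512, 100, True

# observation = prompt tokens, truncated to max_prompt_length
prompt_ids = tokenizer("Summarize: <article text>", truncation=True,
                       max_length=max_prompt_length).input_ids

generated = []
with torch.no_grad():
    for _ in range(max_episode_length):  # one env step = one generated token
        input_ids = torch.tensor([prompt_ids + generated])
        next_token_logits = model(input_ids).logits[0, -1]  # policy distribution
        action = torch.multinomial(next_token_logits.softmax(-1), 1).item()
        generated.append(action)
        if terminate_on_eos and action == tokenizer.eos_token_id:
            break  # EOS action terminates the episode early

# the configured reward_fn then scores the finished text against references
print(tokenizer.decode(generated))
```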

- **On-policy algorithm**: Configures the on-policy algorithm (e.g. PPO, NLPO) and its hyper-parameters, the KL-divergence controller (`kl_div`) that keeps the policy close to the original LM, and the actor-critic policy along with its generation settings.
```yaml
alg:
  id: ppo
  args:
    n_steps: 512
    batch_size: 64
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.0
  kl_div:
    coeff: 0.001
    target_kl: 0.2
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      model_name: t5-base
      apply_model_parallel: True
      prompt_truncation_side: "right"
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 50
        max_new_tokens: 100
```

- **Trainer Config**: Configures the training loop: number of training iterations (`n_iters`), evaluation and checkpoint-saving frequency, evaluation metrics, and the generation settings used during evaluation.

```yaml
# train and evaluation
train_evaluation:
  eval_batch_size: 100
  n_iters: 100
  eval_every: 10
  save_every: 1
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    - id: diversity
      args: {}
  generation_kwargs:
    do_sample: True
    top_k: 0
    temperature: 0.7
    min_length: 50
    max_new_tokens: 100
```

## Configs
Configs for training other tasks and algorithms can be found in: `scripts/training/task_configs`

---
# Custom Components (TBD)
RL4LMs provides full customizability with respect to adding new tasks/datasets, reward functions, evaluation metrics and actor-critic policies.
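As a rough illustration of the extension point for reward functions, a custom reward might look like the sketch below. The base-class import path, its `__call__` signature and the `RewardFunctionRegistry.add` call are assumptions made for illustration; consult `rl4lms/envs/text_generation/registry.py` for the actual interfaces.

```python
# Hypothetical sketch of a custom reward function; the interfaces below are
# assumed for illustration and should be checked against the registry module.
from rl4lms.envs.text_generation.reward import RewardFunction
from rl4lms.envs.text_generation.registry import RewardFunctionRegistry


class BrevityReward(RewardFunction):
    """Toy sparse reward: favour generations at or under a target length."""

    def __init__(self, target_length: int = 20) -> None:
        super().__init__()
        self._target_length = target_length

    def __call__(self, prev_observation, action, current_observation,
                 done, meta_info=None) -> float:
        if not done:
            return 0.0  # sparse reward: only score the finished generation
        n_tokens = len(current_observation.context_text.split())
        return 1.0 if n_tokens <= self._target_length else -0.1


# Registering under an id would make it addressable from a YAML config,
# e.g. reward_fn: {id: brevity}
RewardFunctionRegistry.add("brevity", BrevityReward)
```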

# Train PPO/NLPO using your own components (TBD)
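While this guide is still TBD, one pattern that relies only on the CLI documented above is to load one of the provided YAML configs, swap in the building blocks you want by id, and launch the training script on the modified file. A minimal sketch (the config path and experiment name below are only examples):

```python
# Sketch: customize a provided config and launch the documented training CLI.
import subprocess
import yaml

with open("scripts/training/task_configs/common_gen/t5_nlpo.yml") as f:
    config = yaml.safe_load(f)

# e.g. optimize a different registered reward function
config["reward_fn"] = {"id": "rouge", "args": {"rouge_type": "rouge1"}}

with open("my_custom_config.yml", "w") as f:
    yaml.safe_dump(config, f)

subprocess.run(
    [
        "python", "scripts/training/train_text_generation.py",
        "--config_path", "my_custom_config.yml",
        "--experiment_name", "common_gen_nlpo_rouge",
    ],
    check=True,
)
```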

# Custom trainer (training loop) (TBD)

---

# Logging

We support WANDB logging and warm-starting of training by storing checkpoints and other training artifacts in a user-specified path:
```bash
WANDB_API_KEY=<YOUR-WANDB-API-KEY-HERE> python scripts/training/train_text_generation.py --config_path <PATH-TO-CONFIG-FILE> --experiment_name <EXPERIMENT-NAME> --base_path_to_store_results <PATH-TO-STORE-RESULTS> --log_to_wandb
```
84 changes: 84 additions & 0 deletions scripts/training/task_configs/common_gen/t5_nlpo.yml
@@ -0,0 +1,84 @@
tokenizer:
  model_name: t5-base
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False

reward_fn:
  id: meteor
  args:
    shaping_fn: "common_gen_repeat_penalty"


datapool:
  id: commongen
  args:
    concept_end_token: '.'
    concept_separator_token: ' '
    prefix: "generate a sentence with: "

env:
  n_envs: 10
  args:
    max_prompt_length: 15
    max_episode_length: 20
    terminate_on_eos: True
    context_start_token: 0

alg:
  id: nlpo
  args:
    n_steps: 128
    batch_size: 64
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
  kl_div:
    coeff: 0.001
    target_kl: 2.0
  policy:
    id: maskable_seq2seq_lm_actor_critic_policy
    args:
      model_name: t5-base
      apply_model_parallel: True
      mask_type: "learned_top_p"
      top_mask: 0.9
      target_update_iterations: 20
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 10
        max_new_tokens: 20

train_evaluation:
  eval_batch_size: 100
  n_iters: 100
  eval_every: 10
  save_every: 20
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    # - id: bleurt
    #   args:
    #     config_name: bleurt-large-512
    - id: diversity
      args: {}
    # - id: summaCZS
    #   args:
    #     granularity: sentence
    #     use_ent: True
    #     use_con: False
    # - id: summaCConv
    #   args:
    #     granularity: sentence
  generation_kwargs:
    do_sample: True
    top_k: 50
    min_length: 10
    max_new_tokens: 20
96 changes: 96 additions & 0 deletions scripts/training/task_configs/common_gen/t5_nlpo_on_supervised.yml
@@ -0,0 +1,96 @@
tokenizer:
  model_name: t5-base
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: False

reward_fn:
  id: meteor
  # values:
  #   #- id: rouge_combined
  #   - id: meteor
  #   # - id: rouge
  #   #   args:
  #   #     rouge_type: "rouge1"
  #   # - id: spider
  #   #   args:
  #   #     spice_coeff: 0.0
  #   #     cider_coeff: 1.0
  #   # - id: spider
  #   #   args:
  #   #     spice_coeff: 1.0
  #   #     cider_coeff: 0.0
  #   # # - id: spider
  #   # #   args:
  #   # #     spice_coeff: 0.5
  #   # #     cider_coeff: 0.5


datapool:
  id: commongen
  args:
    concept_end_token: '.'
    concept_separator_token: ' '
    prefix: "generate a sentence with: "


env:
  n_envs: 10
  args:
    max_prompt_length: 20
    max_episode_length: 20
    terminate_on_eos: True
    context_start_token: 0
    prompt_truncation_side: "right"


alg:
  id: nlpo
  args:
    n_steps: 128
    batch_size: 64
    verbose: 1
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.01
  kl_div:
    coeff: 0.01
    target_kl: 1.0
  policy:
    id: maskable_seq2seq_lm_actor_critic_policy
    args:
      model_name: rajkumarrrk/t5-common-gen
      apply_model_parallel: True
      prompt_truncation_side: "right"
      mask_type: "learned_top_p"
      top_mask: 0.9
      target_update_iterations: 20
      generation_kwargs:
        do_sample: True
        top_k: 50
        min_length: 5
        max_new_tokens: 20

train_evaluation:
  eval_batch_size: 50
  n_iters: 100
  eval_every: 5
  save_every: 20
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    - id: cider
    - id: spice
    - id: diversity
      args: {}
  generation_kwargs:
    num_beams: 5
    min_length: 5
    max_new_tokens: 20
