
Added KLAnnealing callback and LSTMLanguageModel benchmark
norabelrose committed Jan 28, 2021
1 parent 74e4891 commit 689fb5f
Showing 13 changed files with 141 additions and 69 deletions.
3 changes: 3 additions & 0 deletions __main__.py
@@ -15,6 +15,7 @@ def main(args):
gpu_available = torch.cuda.is_available()
config = OmegaConf.create({
'aggressive_encoder_training': False,
'kl_annealing': False,
# Override Trainer defaults but still allow them to be overridden by the command line
'trainer': {
'auto_select_gpus': gpu_available,
@@ -65,6 +66,8 @@ def main(args):
callbacks = [EarlyStopping(monitor='val_loss'), UnconditionalSampler()]
if config.aggressive_encoder_training:
callbacks.append(AggressiveEncoderTraining())
if config.kl_annealing:
callbacks.append(KLAnnealing())

trainer = Trainer(**config.trainer, callbacks=callbacks)

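Note: the KLAnnealing callback registered above is among the 13 changed files, but its source is not shown in this excerpt. A minimal sketch of what such a callback might look like, assuming the autoencoder reads `hparams.kl_weight` each step (the field added to `AutoencoderHparams` further down); `anneal_steps` and `max_kl_weight` are illustrative parameters, not ones defined by the commit.

```python
# Sketch only -- the committed KLAnnealing.py is not part of this excerpt.
from dataclasses import dataclass
from pytorch_lightning.callbacks.base import Callback


@dataclass
class KLAnnealingSketch(Callback):
    anneal_steps: int = 10_000   # length of the linear warm-up (hypothetical)
    max_kl_weight: float = 1.0

    def on_train_batch_start(self, trainer, autoencoder, batch, batch_idx, dataloader_idx):
        # Linearly ramp the weight on the KL term from 0 up to max_kl_weight
        progress = min(1.0, autoencoder.global_step / self.anneal_steps)
        autoencoder.hparams.kl_weight = progress * self.max_kl_weight
```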
11 changes: 2 additions & 9 deletions benchmarks/LSTMAutoencoder.py
@@ -7,21 +7,14 @@
from typing import *
from .LSTMDecoder import LSTMDecoder
from .LSTMEncoder import LSTMEncoder
from .LSTMLanguageModel import LSTMLanguageModelHparams


@dataclass
class LSTMAutoencoderHparams(AutoencoderHparams):
class LSTMAutoencoderHparams(AutoencoderHparams, LSTMLanguageModelHparams):
enc_nh: int = 1024 # Dimensionality of the encoder's LSTM hidden state
dec_nh: int = 1024 # Dimensionality of the decoder's LSTM hidden state
dec_dropout_in: float = 0.5
dec_dropout_out: float = 0.5
ni: int = 512 # Dimensionality of the input embedding vectors
latent_depth: int = 32 # Dimensionality of the latent variable vector

vocab_size: int = 30522
cls_id: int = 101
sep_id: int = 102


class LSTMAutoencoder(Autoencoder):
"""VAE with normal prior"""
13 changes: 4 additions & 9 deletions benchmarks/LSTMEncoder.py
@@ -1,5 +1,5 @@
import torch
import torch.nn as nn
from torch.distributions import Normal


class LSTMEncoder(nn.Module):
@@ -11,12 +11,7 @@ def __init__(self, hparams):
self.nz = hparams.latent_depth

self.embed = nn.Embedding(hparams.vocab_size, hparams.ni)

self.lstm = nn.LSTM(input_size=hparams.ni,
hidden_size=hparams.enc_nh,
num_layers=1,
batch_first=True,
dropout=0)
self.lstm = nn.LSTM(input_size=hparams.ni, hidden_size=hparams.enc_nh, batch_first=True)

# dimension transformation to z (mean and logvar)
self.linear = nn.Linear(hparams.enc_nh, 2 * hparams.latent_depth, bias=False)
@@ -28,7 +23,7 @@ def reset_parameters(self):

nn.init.uniform_(self.embed.weight, -0.1, 0.1)

def sample(self, input, nsamples):
def sample(self, inputs, nsamples):
"""sampling from the encoder
Returns: Tensor1, Tuple
Tensor1: the tensor latent z with shape [batch, nsamples, nz]
@@ -37,7 +32,7 @@ def sample(self, input, nsamples):
"""

# (batch_size, nz)
distribution = self.forward(input)
distribution = self.forward(inputs)

# (batch, nsamples, nz)
z = distribution.rsample([nsamples])
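Note on the sampling above: `Distribution.rsample(sample_shape)` prepends the sample dimensions, so if `forward` returns a distribution with batch shape `(batch_size, nz)` as the inline comment suggests, the draw comes out as `(nsamples, batch, nz)` rather than the `[batch, nsamples, nz]` layout the docstring describes. A standalone sketch (not part of the commit) showing the shapes and the transpose that would reconcile them:

```python
import torch
from torch.distributions import Normal

batch, nz, nsamples = 4, 32, 8
mu, sigma = torch.zeros(batch, nz), torch.ones(batch, nz)
dist = Normal(mu, sigma)

z = dist.rsample([nsamples])   # shape: (nsamples, batch, nz)
z = z.transpose(0, 1)          # shape: (batch, nsamples, nz), matching the docstring
print(z.shape)                 # torch.Size([4, 8, 32])
```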
44 changes: 44 additions & 0 deletions benchmarks/LSTMLanguageModel.py
@@ -0,0 +1,44 @@
from dataclasses import dataclass
from omegaconf import OmegaConf
from torch import nn, Tensor
from typing import *
import pytorch_lightning as pl
import torch.nn.functional as F


@dataclass
class LSTMLanguageModelHparams:
dec_nh: int = 1024 # Dimensionality of the LSTM hidden state
dec_dropout_in: float = 0.5
dec_dropout_out: float = 0.5
ni: int = 512 # Dimensionality of the input embedding vectors

vocab_size: int = 30522
cls_id: int = 101
sep_id: int = 102


class LSTMLanguageModel(pl.LightningModule):
def __init__(self, hparams: OmegaConf):
super(LSTMLanguageModel, self).__init__()
self.save_hyperparameters(hparams)

self.embed = nn.Embedding(hparams.vocab_size, hparams.ni)
self.decoder = nn.LSTM(input_size=hparams.ni, hidden_size=hparams.dec_nh, batch_first=True)
self.logit_linear = nn.Linear(in_features=hparams.dec_nh, out_features=hparams.vocab_size)

# Returns [batch, seq_len, vocab_size] tensor of logits
def forward(self, batch: Dict[str, Tensor]) -> Tensor:
x = batch['token_ids']
x = self.embed(x)
x, _ = self.decoder(x)
return self.logit_linear(x)

def training_step(self, batch: Dict[str, Tensor], batch_index: int, val: bool = False) -> Tensor:
logits = self.forward(batch)
loss = F.cross_entropy(input=logits, target=batch['token_ids'])
self.log('train_loss' if not val else 'val_loss', loss)
return loss

def validation_step(self, batch: Dict[str, Tensor], batch_index: int) -> Tensor:
return self.training_step(batch, batch_index, val=True)
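Note on the loss above: `F.cross_entropy` expects the class (vocabulary) dimension at position 1, so passing `[batch, seq_len, vocab_size]` logits against `[batch, seq_len]` targets will not line up as written. A hedged sketch (not the committed code) of one way to compute the loss, which also shifts the targets by one position on the assumption that next-token prediction is the intended objective:

```python
# Sketch only. Flattens the logits so the vocabulary dimension is where
# F.cross_entropy expects it, and shifts targets by one position
# (assuming a next-token objective).
import torch.nn.functional as F
from torch import Tensor


def lm_loss(logits: Tensor, token_ids: Tensor) -> Tensor:
    # logits: [batch, seq_len, vocab_size]; token_ids: [batch, seq_len]
    preds = logits[:, :-1].reshape(-1, logits.size(-1))   # prediction for position t+1
    targets = token_ids[:, 1:].reshape(-1)                # the token actually at t+1
    return F.cross_entropy(preds, targets)
```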
2 changes: 1 addition & 1 deletion tests/test_FunnelTransformer.py
@@ -47,7 +47,7 @@ def test_backward_compatibility(self):
for positional_encoding_type in ("factorized", "rel_shift"):
new_model = FunnelTransformer(FunnelTransformerHparams(
positional_encoding_type=positional_encoding_type,
block_outputs_to_return=list(range(12))
return_block_outputs=True
))
new_model.load_pretrained_weights()
new_model.eval()
28 changes: 18 additions & 10 deletions text_vae/AggressiveEncoderTraining.py
@@ -1,5 +1,4 @@
from .Autoencoder import Autoencoder
from collections import deque
from dataclasses import asdict, dataclass, field
from pytorch_lightning.callbacks.base import Callback
from torch import Tensor
@@ -17,7 +16,6 @@ class AggressiveEncoderTraining(Callback):
_last_decoder_update: int = field(default=0, init=False)
_last_loss: float = field(default=0.0, init=False)
_last_mutual_info: float = field(default=0.0, init=False)
# _loss_history: deque = field(default_factory=lambda: deque(maxlen=10), init=False)

def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
self.__dict__.update(checkpointed_state)
@@ -35,11 +33,13 @@ def on_sanity_check_end(self, trainer, autoencoder: Autoencoder):
def on_train_start(self, trainer, autoencoder: Autoencoder):
autoencoder.decoder_requires_grad_(False)

def on_train_batch_end(self, trainer, autoencoder: Autoencoder, outputs: List, batch, batch_index, dataloader_idx):
def on_after_backward(self, trainer, autoencoder: Autoencoder):
if self._aggressive_stage_complete:
return

inner_loop_step = batch_index - self._last_decoder_update
cur_step = autoencoder.global_step

inner_loop_step = cur_step - self._last_decoder_update
new_loss = _to_scalar(getattr(autoencoder, 'last_loss'))

update_decoder = (
@@ -50,22 +50,30 @@ def on_train_batch_end(self, trainer, autoencoder: Autoencoder, outputs: List, b
)
autoencoder.decoder_requires_grad_(update_decoder)
if update_decoder:
self._last_decoder_update = batch_index
self._last_decoder_update = cur_step
self._last_loss = new_loss

def on_validation_end(self, trainer, autoencoder: Autoencoder):
if self._aggressive_stage_complete:
return

new_mutual_info = _to_scalar(trainer.callback_metrics['mutual_info'])
if new_mutual_info < self._last_mutual_info:
autoencoder.print("Aggressive encoder training complete.")
raw_mutual_info = trainer.callback_metrics.get('mutual_info')
if raw_mutual_info is None and trainer.current_epoch >= 4:
self.end_aggressive_training(autoencoder)
return

self._aggressive_stage_complete = True
autoencoder.decoder_requires_grad_(True)
new_mutual_info = _to_scalar(raw_mutual_info)
if new_mutual_info < self._last_mutual_info:
self.end_aggressive_training(autoencoder)
else:
self._last_mutual_info = new_mutual_info

def end_aggressive_training(self, autoencoder: Autoencoder):
autoencoder.print("Aggressive encoder training complete.")

self._aggressive_stage_complete = True
autoencoder.decoder_requires_grad_(True)

# Convenience method
def _to_scalar(x):
return x.item() if isinstance(x, Tensor) else x
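The callback relies on an `Autoencoder.decoder_requires_grad_` helper that is not part of this diff; a minimal version might look like the following (an assumption about the interface, not the committed implementation):

```python
# Hypothetical helper assumed by the callback above; the real method lives in
# text_vae/Autoencoder.py, which this excerpt does not show.
def decoder_requires_grad_(self, requires_grad: bool):
    for param in self.decoder.parameters():   # assumes the module is named `self.decoder`
        param.requires_grad_(requires_grad)
```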
1 change: 1 addition & 0 deletions text_vae/Autoencoder.py
@@ -16,6 +16,7 @@ class AutoencoderHparams(ABC):

batch_size: int = 0 # This is here just for compatibility with pl.Trainer's auto_scale_batch_size feature
grad_clip_threshold: float = 150.0
kl_weight: float = 1.0
lr: float = 1e-4
lr_decay_steps: Optional[int] = 150_000
warmup_steps: int = 1000
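The new `kl_weight` field is presumably the value a KL annealing schedule would adjust; a sketch of how a VAE training step might apply it (the commit's actual ELBO computation is not shown in this excerpt):

```python
# Illustrative only -- not the committed loss code. Weights the KL term by the
# new hyperparameter, so annealing kl_weight from 0 toward 1 softens the KL
# penalty early in training.
def elbo_loss(self, reconstruction_nll, kl_divergence):
    return reconstruction_nll + self.hparams.kl_weight * kl_divergence
```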
10 changes: 5 additions & 5 deletions text_vae/AutoencoderDataModule.py
@@ -4,9 +4,9 @@
from multiprocessing import cpu_count
from omegaconf import OmegaConf
from pathlib import Path
from tokenizers import BertWordPieceTokenizer
from tokenizers import BertWordPieceTokenizer # noqa
from torch import Tensor
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader # noqa
from .Utilities import *
import os
import numpy as np
@@ -117,13 +117,13 @@ def sentence_split(batch: Dict[str, list]) -> Dict[str, list]:
sentences = sent_tokenizer.tokenize_sents(batch['text'])
return {'text': fast_flatten(sentences)} # Chain lists of sentences together from different samples

batch = self.get_reasonable_preprocessing_batch_size()
b_sz = self.get_reasonable_preprocessing_batch_size()

print(f"Finding sentence boundaries for '{self.hparams.dataset_name}'...")
self.dataset = self.dataset.map(sentence_split, batched=True, batch_size=batch, remove_columns=nontext_cols)
self.dataset = self.dataset.map(sentence_split, batched=True, batch_size=b_sz, remove_columns=nontext_cols)

print(f"Tokenizing '{self.hparams.dataset_name}'...")
self.dataset = self.dataset.map(tokenize, batched=True, batch_size=batch * max(10, cpu_count()))
self.dataset = self.dataset.map(tokenize, batched=True, batch_size=b_sz * max(10, cpu_count()))

# This is for when we're generating batches of multiple full sentences
if max_sents > 1:
25 changes: 9 additions & 16 deletions text_vae/FunnelTransformer.py
@@ -33,10 +33,8 @@ class FunnelTransformerHparams:

positional_encoding_type: str = 'rel_shift' # 'absolute', 'absolute_decoupled', 'rel_shift' or 'factorized'

# Whether to return the pre-pooling output of each block on forward(). If a Sequence, then only the output of
# selected blocks will be returned.
block_outputs_to_return: Sequence[int] = field(default_factory=list)
use_convolutions: bool = False
# Whether to return the pre-pooling output of each block
return_block_outputs: bool = False
use_performer_attention: bool = False
use_initialization_scaling: bool = False
upsampling: bool = False # True for the "reverse" funnel transformer; e.g. a VAE decoder
@@ -109,7 +107,7 @@ def hidden_state_coroutine(self, x: Tensor, padding_mask: Tensor, *, states_to_y
if not attn_state.shared:
attn_state.configure_for_input(x.shape[1], x.dtype, x.device, padding_mask)

hidden_states = {}
hidden_states = []
layer_iter = iter(enumerate(self.layers))
q = kv = x

@@ -137,10 +135,12 @@ def hidden_state_coroutine(self, x: Tensor, padding_mask: Tensor, *, states_to_y
kv = q

# Cache block outputs if indicated
if block_idx in hparams.block_outputs_to_return:
hidden_states[block_idx] = kv
if hparams.return_block_outputs:
hidden_states.append(kv)

output = {}
if hidden_states:
output['hidden_states'] = hidden_states
if hparams.upsampling:
# Non-autoregressively generate a softmax distribution over words
output['logits'] = self.output_layer(q)
@@ -252,10 +252,7 @@ def __init__(self, hparams):
super().__init__()

d_model = hparams.d_model
if hparams.use_convolutions:
self.attention = nn.Conv1d(d_model, d_model, 3)

elif hparams.positional_encoding_type == 'absolute':
if hparams.positional_encoding_type == 'absolute':
# Softmax attention with absolute, sinusoidal positional encodings
if not hparams.use_performer_attention:
raw_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=hparams.num_heads)
@@ -284,10 +281,6 @@ def absolute_pos_attn_func(q: Tensor, k: Tensor, v: Tensor, attn_state: Attentio
# Q is different from K and V right after pooling; K and V are always the same
def forward(self, q: Tensor, kv: Tensor, attn_state: AttentionState) -> Tensor:
# These custom attention and feedforward layers have built-in residual connections
if isinstance(self.attention, nn.Conv1d):
kv = self.attention(kv.transpose(-2, -1)).transpose(-2, -1) # Channels are dim 1
else:
kv = self.attention(q, kv, kv, attn_state)

kv = self.attention(q, kv, kv, attn_state)
kv = self.feedforward(kv)
return kv
3 changes: 1 addition & 2 deletions text_vae/FunnelWithDecoder.py
@@ -8,8 +8,7 @@ def __init__(self, hparams: Union[FunnelTransformerHparams, OmegaConf], num_deco

# Make sure the first block's output is returned from the encoder so that we can
# use it in the residual connection for the decoder
if not hparams.block_outputs_to_return:
hparams.block_outputs_to_return = [0]
hparams.return_block_outputs = True

self.encoder = FunnelTransformer(hparams)
self.decoder = nn.Sequential(**[ # noqa