Progress reproducing "Lagging Inference Networks"
norabelrose committed Apr 4, 2021
1 parent 3abbe89 commit 12508f0
Showing 15 changed files with 230 additions and 207 deletions.
21 changes: 21 additions & 0 deletions hparam_presets.py
@@ -0,0 +1,21 @@
# Groups of hyperparameters for reproducing papers
hparam_presets = {
'he2019': {
'data': dict(
batch_size=32,
batching_strategy='uniform_length',
chunking_strategy='sentence',
),
'model': dict(
adam_beta1=0.0, # 0.0
adam_beta2=0.0,
init_scale=0.01,
lr=1.0,
lr_plateau_patience=2,
warmup_steps=0
),
'trainer': dict(
accumulate_grad_batches=1
)
}
}
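
As an illustrative sketch only (the merge mechanism is not part of this commit, and the base config layout is assumed), a preset such as 'he2019' could be overlaid onto an OmegaConf config like this:

from omegaconf import OmegaConf
from hparam_presets import hparam_presets

# Hypothetical base config exposing the three sections the presets touch
config = OmegaConf.create({'data': {}, 'model': {}, 'trainer': {}})

# Overlay the He et al. 2019 settings section by section
for section, overrides in hparam_presets['he2019'].items():
    config[section] = OmegaConf.merge(config[section], overrides)

print(OmegaConf.to_yaml(config))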
14 changes: 14 additions & 0 deletions text_vae/FeedforwardTransformerVAE.py
@@ -0,0 +1,14 @@
from .core import ContinuousVAE, ContinuousVAEHparams
from omegaconf import DictConfig
from typing import *


class FeedforwardTransformerVAEHparams(ContinuousVAEHparams):
pass


class FeedforwardTransformerVAE(ContinuousVAE):
def __init__(self, hparams: DictConfig):
super().__init__(hparams)


2 changes: 1 addition & 1 deletion text_vae/__init__.py
@@ -5,7 +5,7 @@
from .core.generation import *
from .core.transformer import *
from .core.language_model import *
from .core.autoencoder import *
from .core.continuous_autoencoder import *
from .funnel_transformer import *
from .lstm_autoencoder import *
from .lstm_language_model import *
2 changes: 1 addition & 1 deletion text_vae/core/__init__.py
@@ -1,6 +1,6 @@
from .auto_select_gpu import select_best_gpu
from .attention import *
from .autoencoder import *
from .continuous_autoencoder import *
from .quantizer import *
from .conditional_gaussian import *
from .language_model import *
9 changes: 5 additions & 4 deletions text_vae/core/conditional_gaussian.py
@@ -4,15 +4,16 @@


class ConditionalGaussian(nn.Module):
def __init__(self, in_features: int, out_features: int, zero_initialized: bool = False):
def __init__(self, in_features: int, out_features: int, zero_initialized: bool = False, bias: bool = True):
super(ConditionalGaussian, self).__init__()

linear = nn.Linear(in_features, out_features * 2)
linear = nn.Linear(in_features, out_features * 2, bias=bias)
if zero_initialized:
linear.bias.data.zero_()
linear.weight.data.zero_()
if bias:
linear.bias.data.zero_()

self.linear = nn.Sequential(nn.GELU(), linear)
self.linear = linear

def forward(self, x: Tensor, temperature: float = 1.0, get_kl: bool = False) -> Union[Normal, Tuple[Normal, Tensor]]:
mu, logsigma = self.linear(x).chunk(2, dim=-1)
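
A minimal usage sketch for the updated ConditionalGaussian (batch and feature sizes are made up for illustration; with the default get_kl=False the forward pass returns a torch.distributions.Normal):

import torch
from text_vae.core.conditional_gaussian import ConditionalGaussian

cond = ConditionalGaussian(in_features=512, out_features=64, zero_initialized=True, bias=False)
hidden = torch.randn(32, 512)   # e.g. pooled encoder states for a batch of 32
q_of_z = cond(hidden)           # Normal with loc and scale of shape [32, 64]
z = q_of_z.rsample()            # reparameterized sample of the latent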
text_vae/core/{autoencoder.py → continuous_autoencoder.py}
@@ -1,4 +1,3 @@
from abc import abstractmethod
from contextlib import contextmanager
from torch.distributions import Normal
from .language_model import *
@@ -21,14 +20,14 @@ class ContinuousVAEHparams(LanguageModelHparams, ABC):
early_stopping_metric: str = 'val_loss' # For continuous VAEs we should monitor the whole loss, not just the NLL

class ContinuousVAE(LanguageModel, ABC):
def __init__(self, hparams: DictConfig):
super(LanguageModel, self).__init__(hparams)
self.decoder_frozen = False

def configure_callbacks(self):
callbacks = super().configure_callbacks()
return callbacks + [ReconstructionSampler()]

def setup(self, stage: str):
super().setup(stage)
self.decoder_frozen = False

# Performs a backward pass for both the encoder and decoder networks, using the DReG gradient estimator for the
# encoder parameters. Returns the IWAE importance-weighted estimate of log p(x).
# See Tucker et al. 2018 (https://arxiv.org/pdf/1810.04152.pdf) for derivation.
@@ -62,9 +61,9 @@ def perform_backward(log_w_value, normalized_w_values, loss_weight = 1.0, retain
self.manual_backward(decoder_loss * loss_weight, retain_graph=retain_graph)
return iwae

log_cond = self.log_prob(x, z, labels) # log p(x|z_i) for all z_i
log_joint = log_p_of_z + log_cond # log p(x, z_i) for all z_i
log_w = log_joint - log_q_of_z # log w_i = log [p(x|z_i)/q(z_i|x)]
log_cond = self.p_of_x_given_z(x, z, labels) # log p(x|z_i) for all z_i
log_joint = log_p_of_z + log_cond # log p(x, z_i) for all z_i
log_w = log_joint - log_q_of_z # log w_i = log [p(x|z_i)/q(z_i|x)]
normalized_w = log_w.softmax(dim=0)

s_hat_weight = 1.0 # The weight placed on the standard DReG loss terms
@@ -98,7 +97,7 @@ def estimate_log_prob_iw(self, q_of_z: Normal, x: Tensor, labels: Tensor, num_sa
log_ws = []

for z, log_p, log_q in zip(latents, log_p_of_z, log_q_of_z):
log_joint = log_p + self.log_prob(x, z, labels)
log_joint = log_p + self.p_of_x_given_z(x, z, labels)
log_ws += [log_joint - log_q]

log_ws = torch.cat(log_ws) # [chunks * samples, batch]
@@ -114,8 +113,8 @@ def estimate_mutual_info(conditional_q: Normal, z: Tensor):
return -conditional_q.entropy().sum(dim=-1).mean() - marginal_q.mean()

# Should return p(x|z)
@abstractmethod
def log_prob(self, x, z, labels) -> Tensor:
# @abstractmethod
def p_of_x_given_z(self, x, z, labels) -> Tensor:
raise NotImplementedError

# Called by AggressiveEncoderTraining callback
@@ -124,7 +123,7 @@ def decoder_requires_grad_(self, requires_grad: bool):
for param in self.decoder_params():
param.requires_grad = requires_grad

@abstractmethod
# @abstractmethod
def decoder_params(self) -> Iterable[nn.Parameter]:
raise NotImplementedError

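
For reference, the quantity the log_w / normalized_w code above works with is the importance-weighted (IWAE) estimate of log p(x), with the DReG correction of Tucker et al. 2018 used for the encoder gradient. Sketched from the cited paper rather than transcribed from this code:

\log p(x) \approx \log \frac{1}{K} \sum_{i=1}^{K} w_i,
\qquad w_i = \frac{p(x \mid z_i)\, p(z_i)}{q(z_i \mid x)}, \quad z_i \sim q(z \mid x)

\nabla_{\phi} \approx \sum_{i=1}^{K} \tilde{w}_i^{\,2}\,
\frac{\partial \log w_i}{\partial z_i}\,\frac{\partial z_i}{\partial \phi},
\qquad \tilde{w}_i = \frac{w_i}{\sum_{j} w_j}

where \phi are the encoder (inference network) parameters and the z_i are reparameterized samples from q(z|x).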
133 changes: 78 additions & 55 deletions text_vae/core/language_model.py
@@ -8,7 +8,7 @@
from omegaconf import DictConfig
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch import nn, Tensor
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from typing import *
from .padded_tensor import PaddedTensor
from ..train_callbacks import UnconditionalSampler
@@ -17,19 +17,26 @@
@dataclass
class LanguageModelHparams(ABC):
batch_size: int = 0 # Just for compatibility with pl.Trainer's auto_scale_batch_size feature
grad_clip_threshold: float = 150.0
grad_clip_threshold: float = 5.0
init_scale: float = 0.02 # Std. deviation of Gaussian used to initialize the weights
lr: float = 2.5e-4
lr_decay_steps: Optional[int] = 150_000
lr_plateau_patience: Optional[int] = None # If non-None, ReduceLROnPlateau scheduler is used w/ specified patience
warmup_steps: int = 1000

# If beta2 is set to 0.0, we just use SGD (with momentum set to beta1) - useful for replicating papers that use it
adam_beta1: float = 0.9
adam_beta2: float = 0.999
adam_eps: float = 1e-6 # We need to use 1e-6 since 1e-8 underflows to 0 on fp16
weight_decay: float = 0.01

vocab_size: int = 30522
start_token: Optional[int] = None # If None, it's read off the datamodule's Tokenizer object
end_token: Optional[int] = None

# Whether to divide the loss by the number of tokens in each sequence. This is pretty much always done
# for vanilla language models, but often not done for text VAEs.
divide_loss_by_length: bool = True
early_stopping_metric: str = 'val_nll'
log_samples: bool = True

@@ -66,25 +73,45 @@ def setup(self, stage: str):
self.end_token = vocab['[SEP]']

def configure_optimizers(self, lr: float = None, params = None):
try:
from deepspeed.ops.adam import FusedAdam as Adam
except ImportError:
print("Couldn't import fused Adam kernel from DeepSpeed, falling back on PyTorch version.")
from torch.optim import Adam

adam = Adam(
params or self.parameters(),
lr=lr or self.hparams.lr,
betas=(self.hparams.adam_beta1, 0.999),
weight_decay=self.hparams.weight_decay,
eps=self.hparams.adam_eps
)
lr_lambda = get_cosine_decay_with_warmup_schedule(self.hparams.lr_decay_steps, self.hparams.warmup_steps)

return [adam], [{
'scheduler': LambdaLR(adam, lr_lambda),
'interval': 'step'
}]
beta1, beta2 = self.hparams.adam_beta1, self.hparams.adam_beta2
if beta2 == 0.0:
opt = torch.optim.SGD(
params or self.parameters(),
lr=lr or self.hparams.lr,
momentum=beta1
)
else:
try:
from deepspeed.ops.adam import FusedAdam as Adam
except ImportError:
print("Couldn't import fused Adam kernel from DeepSpeed, falling back on PyTorch version.")
from torch.optim import Adam

opt = Adam(
params or self.parameters(),
lr=lr or self.hparams.lr,
betas=(beta1, beta2),
weight_decay=self.hparams.weight_decay,
eps=self.hparams.adam_eps
)

# This is mainly just here to reproduce the Lagging Inference Networks paper (He et al. 2019) which uses a
# ReduceLROnPlateau-type learning rate schedule
on_plateau_patience = self.hparams.lr_plateau_patience
if on_plateau_patience is not None:
lr_dict = {
'scheduler': ReduceLROnPlateau(opt, factor=0.5, patience=on_plateau_patience),
'monitor': self.hparams.early_stopping_metric,
'interval': 'epoch'
}
else:
lr_lambda = get_cosine_decay_with_warmup_schedule(self.hparams.lr_decay_steps, self.hparams.warmup_steps)
lr_dict = {
'scheduler': LambdaLR(opt, lr_lambda),
'interval': 'step'
}

return [opt], [lr_dict]

def initialize_weights(self):
# Default BERT weight initialization
@@ -94,49 +121,45 @@ def initialize_weights(self):

if hasattr(module, 'weight'):
module.weight.data.normal_(0.0, self.hparams.init_scale)
if getattr(module, 'bias', None) is not None:
module.bias.data.zero_()

@staticmethod
def ppl_from_nll(nll: Tensor, batch: Dict[str, PaddedTensor]):
# Normalize the NLL using the number of words, not number of tokens, so that it's comparable across
# different subword vocabularies
word_counts = batch.get('num_words')
if word_counts is not None:
token_counts = batch['num_tokens'] if 'num_tokens' in batch else (~batch['token_ids'].padding).sum(dim=-1)
ppl_loss = nll * token_counts / word_counts
else:
ppl_loss = nll

# For some reason it's convention to use base 2 for perplexity scores
return 2 ** (ppl_loss / math.log(2))

def stats_from_logits(self, logits: Tensor, batch: Dict[str, PaddedTensor], autoregressive: bool = False):
ground_truth = batch.get('labels') or batch['token_ids']
if autoregressive:
logits = logits[:, :-1] # Remove final [SEP] token
ground_truth = ground_truth[:, 1:] # Remove initial [CLS] token

log_probs = logits.log_softmax(dim=-1)
loss = F.nll_loss(input=log_probs.flatten(end_dim=-2), target=ground_truth.flatten(), ignore_index=0)
loss = loss.view(*logits[:-2]) # Add extra batch / MC sample dimension(s) if needed

# For some reason it's convention to use base 2 for perplexity scores
ppl = self.ppl_from_nll(loss, batch)
entropy = -(log_probs * log_probs.exp()).sum(dim=-1).mean()
return loss, ppl, entropy
bias = getattr(module, 'bias', None)
bias = getattr(bias, 'data', None)
if bias is not None:
bias.zero_()

# If reduce_batch == False, then this method will not reduce across any dimensions other than sequence length
def stats_from_logits(self, logits: Tensor, labels: Tensor, word_counts: Tensor = None, reduce_batch: bool = True):
nll = F.cross_entropy(input=logits.flatten(end_dim=-2), target=labels.flatten(), ignore_index=0, reduction='none')
nll_sum = nll.view(*logits.shape[:-1]).sum(dim=-1) # Add batch dim(s) back; sum across sequence length

# Divide by the number of non-padding tokens
nll = nll_sum / labels.ne(0).sum(dim=-1) if self.hparams.divide_loss_by_length else nll_sum

if reduce_batch:
# Perplexity is normalized using the number of words, not the number of tokens, so that it's comparable
# across different subword vocabularies
if word_counts is not None:
per_word_nll = (nll_sum / word_counts).mean()
ppl = 2 ** (per_word_nll / math.log(2))
return nll.mean(), ppl.mean()
else:
return nll.mean()
else:
return nll

# These implementations are used by LSTMLanguageModel and TransformerLanguageModel, but are overriden by others
def training_step(self, batch: Dict[str, Tensor], batch_index: int, val: bool = False) -> Tensor:
logits = self.forward(batch)
loss, ppl, entropy = self.stats_from_logits(logits, batch, autoregressive=True)
logits = self.forward(batch) # Remove initial [CLS] token from the labels; autoregressive by default
loss, ppl = self.stats_from_logits(logits, batch['token_ids'][..., 1:], word_counts = batch['num_words'])

# Log the entropy of the model's probability distribution over words to see how confident it is
self.log('pred_entropy', entropy)
self.log('train_nll' if not val else 'val_nll', loss, on_step=not val, on_epoch=val)
self.log('train_ppl' if not val else 'val_ppl', ppl, on_step=not val, on_epoch=val)
return loss

def on_after_backward(self):
grad_norm = nn.utils.clip_grad_norm_(self.parameters(), self.hparams.grad_clip_threshold)
self.log('grad_norm', grad_norm, on_step=True)

def validation_step(self, batch: Dict[str, Tensor], batch_index: int) -> Tensor:
return self.training_step(batch, batch_index, val=True)

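
One note on the perplexity computation in the new stats_from_logits: the cross-entropy is in nats, so the base-2 expression reduces to exp of the per-word NLL. A small illustrative check (not part of the commit):

import math

nll_per_word = 4.2                                # example per-word NLL in nats
ppl_base2 = 2 ** (nll_per_word / math.log(2))     # formula used in stats_from_logits
ppl_e = math.exp(nll_per_word)                    # equivalent closed form
assert math.isclose(ppl_base2, ppl_e)             # 2 ** (x / ln 2) == e ** x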
20 changes: 3 additions & 17 deletions text_vae/data_utils.py
@@ -1,31 +1,17 @@
from typing import *
from dataclasses import dataclass
from itertools import chain
from math import ceil
from tokenizers import Tokenizer
from torch import Tensor
import random
import torch


# Flattens large lists faster than list(itertools.chain.from_iterable(x))
def fast_flatten(original: List[List]) -> List:
# Pre-allocate memory
total_size = sum(len(x) for x in original)
output = [None] * total_size

cur_idx = 0
for x in original:
next_idx = cur_idx + len(x)
output[cur_idx:next_idx] = x
cur_idx = next_idx

return output


# Convert text into WordPiece tokens, while also saving some important stats. This is set up as a freestanding
# function to avoid an annoying crash that happens when dill, a HuggingFace dependency, tries to pickle the function
def tokenize(batch: Dict[str, list], tokenizer: Tokenizer, should_chunk: bool, min_tokens: int,
max_tokens: int) -> Dict[str, list]:
max_tokens: int) -> Dict[str, list]:
if should_chunk:
# Tokenizer has had .enable_truncation(max_tokens) called on it
encodings = tokenizer.encode_batch(batch['text'])
@@ -34,7 +20,7 @@ def tokenize(batch: Dict[str, list], tokenizer: Tokenizer, should_chunk: bool, m
if len(sample[-1].ids) < min_tokens: # Only the last sequence might possibly be too short
encodings.pop()

encodings = fast_flatten(encodings)
encodings = list(chain.from_iterable(encodings))
else:
encodings = tokenizer.encode_batch(batch['text'])
encodings = [x for x in encodings if min_tokens <= len(x.ids) <= max_tokens]
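
The removed fast_flatten helper pre-allocated its output list; the commit swaps it for the standard itertools one-liner, which yields the same result. A tiny illustrative check (not part of the commit):

from itertools import chain

nested = [[1, 2], [3], [4, 5, 6]]
flat = list(chain.from_iterable(nested))   # replacement now used in tokenize()
assert flat == [1, 2, 3, 4, 5, 6]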