Sparse Transformer VAEs working on Wikipedia
norabelrose committed Apr 21, 2021
1 parent d4447fd commit a591a35
Showing 27 changed files with 1,091 additions and 31,176 deletions.
27 changes: 0 additions & 27 deletions eval.py

This file was deleted.

115 changes: 94 additions & 21 deletions hparam_presets.py
@@ -3,7 +3,6 @@
'he2019': {
'data': dict(
batch_size=32,
-batching_strategy='uniform_length',
chunking_strategy='sentence',
),
'model': dict(
@@ -28,7 +27,6 @@
'sentence-vae': {
'data': dict(
batch_size=32,
-batching_strategy='uniform_length',
chunking_strategy='sentence',
),
'model': dict(
@@ -47,55 +45,130 @@
precision=32
)
},
-'belrose-lstm': {
+'lstm-benchmark': {
'model': dict(
bidirectional_encoder=True,
divide_loss_by_length=True,
d_model=1024,
d_embedding=512,
grad_clip_threshold=5.0,
init_scale=None,
kl_weight_start=0.2,
kl_annealing_steps=8000,
latent_depth=64,
lr=3e-4,
num_latent_vectors=1,
tie_embedding_weights=True,
tie_logit_weights=True,
transformer_encoder=False,
warmup_steps=500
),
'trainer': dict(
accumulate_grad_batches=2
)
},
'lstm-wikipedia': {
'data': dict(
batch_size=64,
batching_strategy='uniform_length',
chunking_strategy='none',
dataset_name='wikipedia',
dataset_config='20200501.en',
tokens_per_batch=50_000,
max_tokens_per_sample=12_500
),
'model': dict(
bidirectional_encoder=True,
divide_loss_by_length=True,
d_model=1024,
d_embedding=512,
grad_clip_threshold=5.0,
-init_scale=0.02,
-kl_weight_start=0.3,
-kl_annealing_steps=8_000,
-latent_depth=512,
-lr=5e-4,
+init_scale=None,
+kl_weight_start=0.2,
+kl_annealing_steps=8000,
+latent_depth=64,
+lr=3e-4,
num_latent_vectors=1,
tie_embedding_weights=True,
tie_logit_weights=True,
-transformer_encoder=True,
+transformer_encoder=False,
warmup_steps=500
),
'trainer': dict(
-accumulate_grad_batches=1
+accumulate_grad_batches=2,
+val_check_interval=0.25
)
},
-'belrose-transformer': {
-'data': dict(
-batch_size=64,
-batching_strategy='uniform_length',
-chunking_strategy='none',
+'dense-benchmark': {
'model': dict(
divide_loss_by_length=True,
d_model=512,
grad_clip_threshold=5.0,
init_scale=0.02,
kl_weight_start=0.2,
kl_annealing_steps=8000,
latent_depth=64,
lr=3e-4,
num_latent_vectors=1,
num_layers=3,
sparse_self_attention=False,
tie_embedding_weights=True,
warmup_steps=500
),
'trainer': dict(
accumulate_grad_batches=2
)
},
'sparse-benchmark': {
'model': dict(
divide_loss_by_length=True,
d_model=512,
grad_clip_threshold=5.0,
init_scale=0.02,
-kl_weight_start=0.0,
-kl_annealing_steps=50_000,
+kl_weight_start=0.2,
+kl_annealing_steps=8000,
latent_depth=64,
-lr=1e-4,
+lr=3e-4,
num_latent_vectors=1,
num_layers=3,
sparse_self_attention=True,
tie_embedding_weights=True,
warmup_steps=500
),
'trainer': dict(
-accumulate_grad_batches=1
+accumulate_grad_batches=2
)
},
'wikipedia': {
'data': dict(
chunking_strategy='none',
dataset_name='wikipedia',
dataset_config='20200501.en',
tokens_per_batch=62_500,
# Stub articles (< 160 tokens) make up nearly 1/4 of the dataset and don't help
# the model learn long-range dependencies. Filtering them out forces the model to
# get used to not having the whole document inside its sliding attention window.
min_tokens_per_sample=160,
max_tokens_per_sample=12_500
),
'model': dict(
divide_loss_by_length=True,
d_model=512,
grad_checkpointing=False,
grad_clip_threshold=5.0,
init_scale=0.02,
kl_weight_start=0.2,
kl_annealing_steps=8000,
latent_depth=128,
lr=3e-4,
# lr_decay_steps=1_000_000,
num_latent_vectors=1,
num_layers=6,
sparse_self_attention=True,
tie_embedding_weights=True,
warmup_steps=1000
),
'trainer': dict(
accumulate_grad_batches=2,
val_check_interval=0.1
)
},
}
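
The presets above are plain nested dicts keyed by 'data', 'model', and 'trainer', so a run configuration can be built by layering a named preset under any manual overrides. A minimal sketch of that pattern using OmegaConf (already in requirements.txt); the module-level dict name hparam_presets and the apply_preset helper are assumptions for illustration, not necessarily the repo's actual wiring:

from typing import Optional

from omegaconf import OmegaConf

from hparam_presets import hparam_presets  # assumed name of the module-level dict


def apply_preset(name: str, overrides: Optional[dict] = None):
    # Turn the preset's nested dicts into a structured config, then let any
    # caller-supplied overrides win over the preset's defaults.
    config = OmegaConf.create(hparam_presets[name])
    if overrides:
        config = OmegaConf.merge(config, OmegaConf.create(overrides))
    return config


# Example: the 'wikipedia' preset, but with half the token budget per batch.
cfg = apply_preset('wikipedia', {'data': {'tokens_per_batch': 31_250}})
assert cfg.model.latent_depth == 128 and cfg.data.tokens_per_batch == 31_250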
7 changes: 2 additions & 5 deletions requirements.txt
@@ -5,13 +5,10 @@ pytorch_lightning>=1.1.2
torch>=1.7.0
tokenizers>=0.9.4
tqdm>=4.49.0
requests>=2.24.0
omegaconf>=2.0.5
nltk>=3.5
ray>=1.1.0
hpbandster
ConfigSpace
deepspeed>=0.3.13
pynvml>=8.0.4
survae>=0.1
torchtext>=0.9.1
pyarrow>=3.0.0
89 changes: 89 additions & 0 deletions sample-vqvae.py
@@ -0,0 +1,89 @@
import sys
from dataclasses import asdict
from pathlib import Path
from tokenizers import Tokenizer
from text_vae import *


def try_type_conversion(value, target_type):
try:
return target_type(value)
except ValueError:
print(f"Invalid input- expected value of type {target_type.__name__}.")
return None


def main(args):
    if len(args) < 2:
        print("Usage: python sample-vqvae.py <version_name>")
        return
    version_name = args[1]
    sampler = QuantizedVAESampler.for_vae(version_name)

vocab_path = Path.cwd() / 'text-vae-pretrained' / 'tokenizers' / 'yelp_polarity.json'
assert vocab_path.exists(), "Couldn't find a pretrained tokenizer for yelp_polarity"

tokenizer = Tokenizer.from_file(str(vocab_path))
options = QuantizedVAESamplingOptions()

print("Type 's' to generate a sample, or 'q' to quit. Type 'help' for a list of other commands.")
while True:
command = input()

if command.startswith('set '):
rest = command[4:]
parts = rest.split('=')

if len(parts) != 2:
print("Expected a command of the form 'set max_length=500'")
continue

parts = [part.strip() for part in parts]
key, value = parts

# For moving the model between devices
if key == 'gpu':
if value.lower() == 'none':
sampler = sampler.to('cpu')
print("Model moved to the CPU.")
else:
idx = try_type_conversion(value, int)
if idx is not None:
print(f"Moving model to GPU {idx}...")
sampler = sampler.to('cuda:' + str(idx))
print("Done.")

# For loading different VAE versions
elif key == 'version':
try:
new_sampler = QuantizedVAESampler.for_vae(value)
except AssertionError as ex:
print(f"Creating a sampler for VAE '{value}' failed with error: {ex}")
else:
sampler = new_sampler

# For changing the sampling options
else:
if not hasattr(options, key):
print(f"'{key}' is not a valid configuration option.")
continue

key_type = type(getattr(options, key))
value = try_type_conversion(value, key_type)
if value is not None:
setattr(options, key, value)

elif command == 's':
output = sampler.sample(options)
samples = tokenizer.decode_batch(output.tolist())
for sample in samples:
print(sample)

elif command == 'q':
return

elif command == 'config':
    print("Current sampling options:")
    print(asdict(options))

elif command == 'help':
    print("Commands: 's' to sample, 'q' to quit, 'config' to show the current "
          "sampling options, and 'set key=value' to change an option "
          "(special keys: 'gpu' and 'version').")

else:
    print("Not a recognized command.")


if __name__ == "__main__":
main(sys.argv)
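
For reference, the interactive 's' command above can also be reproduced programmatically. This is a sketch under the same assumptions as the script (a trained VAE checkpoint saved under a version name, the yelp_polarity tokenizer on disk, and the names below exported by text_vae); 'my-vqvae-run' is a placeholder:

from pathlib import Path
from tokenizers import Tokenizer
from text_vae import QuantizedVAESampler, QuantizedVAESamplingOptions

# Load a sampler for a trained VAE version, mirroring main() above.
sampler = QuantizedVAESampler.for_vae('my-vqvae-run')  # placeholder version name
tokenizer = Tokenizer.from_file(
    str(Path.cwd() / 'text-vae-pretrained' / 'tokenizers' / 'yelp_polarity.json'))

# Same call sequence the 's' branch runs: sample token ids, then decode them.
options = QuantizedVAESamplingOptions()
output = sampler.sample(options)
for sample in tokenizer.decode_batch(output.tolist()):
    print(sample)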