
Commit

Preparing to make this the new master
norabelrose committed May 2, 2021
1 parent 72f94c2 commit e7d954c
Showing 46 changed files with 536 additions and 1,882 deletions.
88 changes: 31 additions & 57 deletions hparam_presets.py
@@ -1,50 +1,4 @@
# Groups of hyperparameters for reproducing papers
hparam_presets = {
'he2019': {
'data': dict(
batch_size=32,
chunking_strategy='sentence',
),
'model': dict(
adam_beta1=0.0, # 0.0
adam_beta2=0.0,
decoder_input_dropout=0.5,
decoder_output_dropout=0.5,
divide_loss_by_length=False,
grad_clip_threshold=5.0,
init_scale=0.01,
lr=1.0,
lr_plateau_patience=2,
tie_embedding_weights=False,
warmup_steps=0
),
'trainer': dict(
accumulate_grad_batches=1,
precision=32 # Diverges without this
)
},
# From https://github.com/timbmg/Sentence-VAE/
'sentence-vae': {
'data': dict(
batch_size=32,
chunking_strategy='sentence',
),
'model': dict(
d_model=256,
d_embedding=300,
divide_loss_by_length=False,
decoder_input_dropout=0.5,
init_scale=None, # Default PyTorch initialization
latent_depth=16,
lr=1e-3,
tie_embedding_weights=True,
warmup_steps=0
),
'trainer': dict(
accumulate_grad_batches=1,
precision=32
)
},
'lstm-benchmark': {
'model': dict(
bidirectional_encoder=True,
@@ -69,7 +23,6 @@
},
'lstm-wikipedia': {
'data': dict(
chunking_strategy='none',
dataset_name='wikipedia',
dataset_config='20200501.en',
tokens_per_batch=50_000,
@@ -137,27 +90,48 @@
accumulate_grad_batches=2
)
},
'wikipedia': {
'nonvae-wikipedia': {
'data': dict(
chunking_strategy='none',
dataset_name='wikipedia',
dataset_config='20200501.en',
tokens_per_batch=31_250,
# Stub articles (< 160 tokens) make up nearly 1/4 of the dataset and don't help
# the model learn long-range dependencies. This way we force the model to get used
# to not having the whole document within its sliding attention window
tokens_per_batch=50_000,
min_tokens_per_sample=160,
max_tokens_per_sample=12_500
max_tokens_per_sample=25_000
),
'model': dict(
divide_loss_by_length=True,
d_model=512,
grad_checkpointing=True,
grad_clip_threshold=5.0,
init_scale=0.02,
kl_weight_start=0.2,
kl_annealing_steps=8000,
latent_depth=128,
lr=3e-4,
num_layers=6,
sparse_self_attention=True,
tie_embedding_weights=True,
warmup_steps=1000
),
'trainer': dict(
accumulate_grad_batches=2,
val_check_interval=0.1
)
},
'wikipedia': {
'data': dict(
dataset_name='wikipedia',
dataset_config='20200501.en',
tokens_per_batch=50_000,
min_tokens_per_sample=512,
max_tokens_per_sample=25_000
),
'model': dict(
divide_loss_by_length=True,
d_model=512,
grad_checkpointing=True,
grad_clip_threshold=5.0,
init_scale=0.02,
kl_weight_start=0.3,
kl_annealing_steps=4000,
latent_depth=64,
lr=3e-4,
# lr_decay_steps=1_000_000,
num_latent_vectors=1,
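
For context, the presets above are plain nested dicts, so a run script can shallow-merge one of them over its defaults. The sketch below illustrates that under stated assumptions: only hparam_presets and the 'lstm-wikipedia' key come from hparam_presets.py, while defaults and apply_preset are hypothetical helpers, not part of this commit.

# Hypothetical usage sketch: merge a named preset into a default run config.
# `defaults` and `apply_preset` are illustrative assumptions; only
# `hparam_presets` and the 'lstm-wikipedia' key come from hparam_presets.py.
from copy import deepcopy
from hparam_presets import hparam_presets  # assumes the repo root is on sys.path

defaults = {
    'data': dict(batch_size=64, chunking_strategy='none'),
    'model': dict(lr=1e-4, warmup_steps=500),
    'trainer': dict(accumulate_grad_batches=1),
}

def apply_preset(config: dict, preset: dict) -> dict:
    # Each preset is keyed by config group ('data', 'model', 'trainer');
    # values in the preset override the corresponding defaults.
    merged = deepcopy(config)
    for group, overrides in preset.items():
        merged.setdefault(group, {}).update(overrides)
    return merged

config = apply_preset(defaults, hparam_presets['lstm-wikipedia'])
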
7 changes: 1 addition & 6 deletions requirements.txt
@@ -4,11 +4,6 @@ numpy>=1.18.5
pytorch_lightning>=1.1.2
torch>=1.7.0
tokenizers>=0.9.4
tqdm>=4.49.0
omegaconf>=2.0.5
nltk>=3.5
deepspeed>=0.3.13
pynvml>=8.0.4
survae>=0.1
torchtext>=0.9.1
pyarrow>=3.0.0
pynvml>=8.0.4
89 changes: 0 additions & 89 deletions sample-vqvae.py

This file was deleted.

2 changes: 1 addition & 1 deletion sample.py
@@ -1,4 +1,4 @@
from text_vae import *
from sparse_vae import *
import sys


25 changes: 25 additions & 0 deletions sparse_vae/__init__.py
@@ -0,0 +1,25 @@
# Sort of in order of how many internal dependencies each file has
from .batch_generation import *
from .core import select_best_gpu
from .core.conditional_gaussian import *
from .core.generation import *
from .core.transformer import *
from .core.language_model import *
from .core.continuous_autoencoder import *
from .lstm_vae import *
from .lstm_language_model import *
from .text_data_module import *
from .transformer_vae import *

# Useful utility to have
from pathlib import Path

def get_checkpoint_path_for_name(experiment: str, ckpt_name: str) -> Path:
ckpt_path = Path.cwd() / 'sparse-vae-logs' / experiment / ckpt_name / "checkpoints"
try:
# Open the most recent checkpoint
ckpt = max(ckpt_path.glob('*.ckpt'), key=lambda file: file.lstat().st_mtime)
return ckpt
except ValueError:
print(f"Couldn't find checkpoint at path {ckpt_path}")
exit(1)
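
A short usage note on the helper above, assuming training logs land in sparse-vae-logs/<experiment>/<ckpt_name>/checkpoints/ as the function expects; the experiment and checkpoint names below are placeholders, not names taken from this commit.

# Hypothetical call to get_checkpoint_path_for_name; 'wikipedia' and 'version_0'
# are placeholder names for an experiment directory and a checkpoint subdirectory.
ckpt = get_checkpoint_path_for_name('wikipedia', 'version_0')
print(f"Most recent checkpoint: {ckpt}")
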
File renamed without changes.
3 changes: 1 addition & 2 deletions text_vae/core/__init__.py → sparse_vae/core/__init__.py
@@ -1,11 +1,10 @@
from .activation_offload import ActivationOffloadFunction, offload
from .auto_select_gpu import select_best_gpu
from .attention import *
from .continuous_autoencoder import *
from .quantizer import *
from .conditional_gaussian import *
from .language_model import *
from .padded_tensor import PaddedTensor
from .perceiver import Perceiver
from .transformer import *
from .transformer_layer import *
from .utilities import *