Commit

Ungodly number of changes

norabelrose committed Sep 26, 2021
1 parent 70ebf9a commit 894de25
Showing 33 changed files with 1,924 additions and 932 deletions.
44 changes: 44 additions & 0 deletions gather_latents.py
@@ -0,0 +1,44 @@
from sparse_vae import *
from datasets import Dataset
from itertools import chain
import sys


def main(args):
    model_str, model_name = args[1:]
    model = load_checkpoint_for_name(model_str, model_name)
    model.freeze()

    device = 'cuda:' + str(select_best_gpu())
    model = model.to(device)

    hparams = TextDataModuleHparams()
    data = TextDataModule(hparams)
    data.prepare_data()
    data.setup('predict')

    dls = data.predict_dataloader()
    total = sum(len(dl) for dl in dls)
    pbar = tqdm(
        chain.from_iterable(dls), desc="Gathering latents", unit='batch', total=total
    )
    latents, scales, titles = [], [], []
    for i, batch in enumerate(pbar):
        q_of_z = model.predict({k: v.to(device) for k, v in batch.items() if isinstance(v, Tensor)}, i)
        mean, scale = q_of_z.mean, q_of_z.scale
        latents.extend(mean.cpu().numpy().squeeze())
        scales.extend(scale.cpu().numpy().squeeze())
        titles.extend([''.join(x) for x in batch['title']])
        if i >= total:
            pbar.close()
            break

    print("Saving to disk...")
    save_path = Path.cwd() / 'sparse-vae-datasets' / 'latents' / model_str / model_name
    dataset = Dataset.from_dict({'title': titles, 'latent': latents, 'scale': scales})
    dataset.save_to_disk(str(save_path))
    print("Done.")


if __name__ == "__main__":
    main(sys.argv)
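
For orientation, a minimal sketch of how this script might be invoked and how the saved latent dataset could be reloaded. The checkpoint name 'my-run' and the invocation line are illustrative assumptions, not taken from the repository.

# Hypothetical invocation: python gather_latents.py transformer-vae my-run
from pathlib import Path
from datasets import Dataset

save_path = Path.cwd() / 'sparse-vae-datasets' / 'latents' / 'transformer-vae' / 'my-run'
latent_ds = Dataset.load_from_disk(str(save_path))
print(latent_ds.column_names)  # ['title', 'latent', 'scale'], as written by gather_latents.py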
104 changes: 68 additions & 36 deletions hparam_presets.py
@@ -2,19 +2,17 @@
    'lstm-benchmark': {
        'model': dict(
            bidirectional_encoder=True,
            divide_loss_by_length=True,
            d_model=1024,
            d_embedding=512,
            grad_clip_threshold=5.0,
            grad_clip_threshold=150.0,
            init_scale=None,
            kl_weight_start=0.2,
            kl_annealing_steps=8000,
            latent_depth=64,
            lr=3e-4,
            tie_embedding_weights=True,
            tie_logit_weights=True,
            transformer_encoder=False,
            warmup_steps=500
            transformer_encoder=False
        ),
        'trainer': dict(
            accumulate_grad_batches=2
@@ -30,59 +28,69 @@
        ),
        'model': dict(
            bidirectional_encoder=True,
            divide_loss_by_length=True,
            d_model=2048,
            d_embedding=512,
            grad_clip_threshold=5.0,
            grad_clip_threshold=150.0,
            init_scale=None,
            kl_weight_start=1.0,
            kl_annealing_steps=0,
            latent_depth=64,
            lr=3e-4,
            tie_embedding_weights=True,
            tie_logit_weights=True,
            transformer_encoder=False,
            warmup_steps=500
            transformer_encoder=False
        ),
        'trainer': dict(
            accumulate_grad_batches=2,
            val_check_interval=0.25
        )
    },
    'dense-benchmark': {
        'data': dict(
            dataset_name='wikipedia',
            dataset_config='20200501.en',
            tokens_per_batch=50_000,
            min_tokens_per_sample=512,
            max_tokens_per_sample=3_125
        ),
        'model': dict(
            divide_loss_by_length=True,
            d_model=512,
            grad_clip_threshold=5.0,
            grad_checkpointing=True,
            grad_clip_threshold=150.0,
            init_scale=0.02,
            kl_weight_start=0.2,
            kl_weight_start=0.3,
            kl_weight_end=1.0,
            kl_annealing_steps=8000,
            latent_depth=64,
            lr=3e-4,
            num_layers=3,
            num_layers=6,
            sparse_self_attention=False,
            tie_embedding_weights=True,
            warmup_steps=500
            tie_embedding_weights=True
        ),
        'trainer': dict(
            accumulate_grad_batches=2
        )
    },
    'sparse-benchmark': {
        'data': dict(
            dataset_name='wikipedia',
            dataset_config='20200501.en',
            tokens_per_batch=50_000,
            min_tokens_per_sample=512,
            max_tokens_per_sample=3_125
        ),
        'model': dict(
            divide_loss_by_length=True,
            d_model=512,
            grad_clip_threshold=5.0,
            grad_checkpointing=True,
            grad_clip_threshold=150.0,
            init_scale=0.02,
            kl_weight_start=0.2,
            kl_annealing_steps=8000,
            latent_depth=512,
            kl_weight_start=1.0,
            kl_annealing_steps=0,
            latent_depth=64,
            lr=3e-4,
            num_samples=1,
            num_layers=3,
            num_layers=6,
            sparse_self_attention=True,
            tie_embedding_weights=True,
            warmup_steps=500
            tie_embedding_weights=True
        ),
        'trainer': dict(
            accumulate_grad_batches=2
@@ -94,19 +102,17 @@
            dataset_config='20200501.en',
            tokens_per_batch=50_000,
            min_tokens_per_sample=512,
            max_tokens_per_sample=25_000
            max_tokens_per_sample=3_125
        ),
        'model': dict(
            divide_loss_by_length=True,
            d_model=512,
            grad_checkpointing=True,
            grad_clip_threshold=5.0,
            grad_clip_threshold=150.0,
            init_scale=0.02,
            lr=3e-4,
            num_layers=6,
            sparse_self_attention=True,
            tie_embedding_weights=True,
            warmup_steps=1000
            sparse_self_attention=False,
            tie_embedding_weights=True
        ),
        'trainer': dict(
            accumulate_grad_batches=2,
@@ -122,26 +128,52 @@
            max_tokens_per_sample=25_000
        ),
        'model': dict(
            divide_loss_by_length=True,
            d_model=512,
            grad_checkpointing=True,
            grad_clip_threshold=150.0,
            init_scale=0.02,
            kl_weight_start=1.0,
            attn_window_size=12,
            kl_weight_start=0.2,
            kl_weight_end=1.0,
            kl_annealing_steps=0,
            kl_annealing_steps=8000,
            latent_depth=64,
            # mmd=True,
            lr=3e-4,
            # lr_decay_steps=1_000_000,
            num_layers=6,
            sparse_self_attention=True,
            tie_embedding_weights=True,
            warmup_steps=1000
            tie_embedding_weights=True
        ),
        'trainer': dict(
            accumulate_grad_batches=2,
            val_check_interval=0.1
        )
    },
    'pg19': {
        'data': dict(
            dataset_name='pg19',
            dataset_config=None,
            tokens_per_batch=55_296,
            min_tokens_per_sample=512,
            max_tokens_per_sample=55_296
        ),
        'model': dict(
            # adam_beta1=0.95,
            d_model=512,
            grad_checkpointing=True,
            grad_clip_threshold=150.0,
            init_scale=0.02,
            attn_window_size=16,
            kl_weight_start=0.3,
            kl_weight_end=1.0,
            kl_annealing_steps=8000,
            latent_depth=64,
            lr=1e-3,
            num_layers=6,
            sparse_self_attention=True,
            tie_embedding_weights=True
        ),
        'trainer': dict(
            accumulate_grad_batches=4,
            val_check_interval=0.5
        )
    },
}
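
These presets are plain nested dicts keyed by 'data', 'model', and 'trainer'. As a hedged illustration only (this commit does not show how the CLI actually consumes them, and the top-level variable name is not visible in the hunk), applying a preset over default hyperparameter dicts could look like this:

def apply_preset(defaults: dict, preset: dict) -> dict:
    # Shallow-merge each section ('data', 'model', 'trainer') of a preset over the defaults.
    merged = {section: dict(values) for section, values in defaults.items()}
    for section, overrides in preset.items():
        merged.setdefault(section, {}).update(overrides)
    return merged

# e.g. apply_preset(defaults, presets['sparse-benchmark']), where 'presets' stands for
# whatever name hparam_presets.py binds this dict to.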
62 changes: 62 additions & 0 deletions knn.py
@@ -0,0 +1,62 @@
from sparse_vae import *
from datasets import Dataset
from torch.distributions import Normal, kl_divergence
import torch.nn.functional as F
import sys


def main(args):
    model_str, model_name = args[1:]
    save_path = Path.cwd() / 'sparse-vae-datasets' / 'latents' / model_str / model_name
    dataset = Dataset.load_from_disk(str(save_path))
    titles = {title: idx for idx, title in enumerate(dataset['title'])}

    gpu_idx = select_best_gpu(min_free_memory=4.0)
    dataset.set_format('torch', device=gpu_idx)
    posteriors = Normal(loc=dataset['latent'], scale=dataset['scale'])
    dataset.reset_format()

    print("Type the title of an article to get the nearest neighbors. Type q to quit.")
    while (query := input("Article: ")) != 'q':
        article_idx = titles.get(query)

        if article_idx is None:
            print("No article found with that title. Try again.")
        else:
            posterior = Normal(loc=posteriors.loc[article_idx], scale=posteriors.scale[article_idx])

            print("\nL2 distance of means:")
            dists = torch.sum((posterior.mean - posteriors.mean) ** 2, dim=-1)
            dists, hit_indices = dists.topk(10, largest=False)

            # HF docs guarantee this will return a dictionary when passed a NumPy array like this
            hits = cast(Dict[str, List[str]], dataset[hit_indices.cpu().numpy()])

            max_len = max(len(x) for x in hits['title'])
            for dist, title in zip(dists, hits['title']):
                print(title + " " * (max_len - len(title)) + f" - {dist}")

            print("\nCosine similarity:")
            affinities = F.cosine_similarity(posterior.mean[None], posteriors.mean).squeeze()
            dists, hit_indices = affinities.topk(10, largest=True)
            hits = cast(Dict[str, List[str]], dataset[hit_indices.cpu().numpy()])

            max_len = max(len(x) for x in hits['title'])
            for dist, title in zip(dists, hits['title']):
                print(title + " " * (max_len - len(title)) + f" - {dist}")

            print("\nKL divergence:")
            kls = kl_divergence(posterior, posteriors).sum(dim=-1)
            dists, hit_indices = kls.topk(10, largest=False)
            dists = dists.cpu().numpy()
            hits = cast(Dict[str, List[str]], dataset[hit_indices.cpu().numpy()])

            max_len = max(len(x) for x in hits['title'])
            for dist, title in zip(dists, hits['title']):
                print(title + " " * (max_len - len(title)) + f" - {dist}")

            print('\n')


if __name__ == "__main__":
    main(sys.argv)
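
The script ranks neighbors by three metrics: squared L2 distance between posterior means, cosine similarity of means, and KL divergence between the diagonal Gaussian posteriors. As a small sanity-check sketch (not part of the commit), torch's kl_divergence matches the textbook closed form for Normals:

import torch
from torch.distributions import Normal, kl_divergence

p = Normal(loc=torch.randn(4, 64), scale=torch.rand(4, 64) + 0.1)
q = Normal(loc=torch.randn(4, 64), scale=torch.rand(4, 64) + 0.1)

# Per-dimension KL(p || q) = log(s_q / s_p) + (s_p^2 + (m_p - m_q)^2) / (2 s_q^2) - 1/2
manual = (torch.log(q.scale / p.scale)
          + (p.scale ** 2 + (p.loc - q.loc) ** 2) / (2 * q.scale ** 2)
          - 0.5).sum(dim=-1)
assert torch.allclose(manual, kl_divergence(p, q).sum(dim=-1), atol=1e-5)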
38 changes: 38 additions & 0 deletions reconstruct.py
@@ -0,0 +1,38 @@
from sparse_vae import *
from datasets import concatenate_datasets
import sys


def main(args):
    model_str, model_name = args[1:]
    model = load_checkpoint_for_name(model_str, model_name)
    model.freeze()
    model.eval()

    dm = TextDataModule(TextDataModuleHparams())
    dm.prepare_data()
    dataset, tokenizer = dm.dataset, dm.tokenizer
    dataset = concatenate_datasets([dataset['train'], dataset['test']])
    titles = {title: idx for idx, title in enumerate(dataset['title'])}
    gpu_idx = select_best_gpu(min_free_memory=4.0)
    model = model.to(gpu_idx)

    print("Type the title of an article to get a reconstruction. Type q to quit.\nType i to switch to interpolation mode.")
    while True:
        query = input("Article: ")
        if query == 'q':
            return

        article_idx = titles.get(query)
        if article_idx is None:
            print("No article found with that title. Try again.")
        else:
            text = dataset[article_idx]['text']
            latent = model.predict({'token_ids': torch.tensor([text], device=gpu_idx)}, 0).loc
            reconstruction = model.sample(1024, 1, z=latent, temperature=0.7)
            reconstruction = tokenizer.decode(reconstruction.squeeze().tolist())
            print("Reconstruction:\n\n" + reconstruction)


if __name__ == "__main__":
    main(sys.argv)
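
The prompt text mentions an interpolation mode ('i'), but this hunk only implements reconstruction. Purely as an illustrative sketch of what interpolating between two encoded articles could look like, reusing the model.sample signature seen above (the helper name and its defaults are assumptions, not the repository's code):

import torch

def interpolate_latents(model, z_a, z_b, steps=5, max_length=1024, temperature=0.7):
    # Linearly blend two latent codes and decode each blend with the same sampler used above.
    samples = []
    for alpha in torch.linspace(0.0, 1.0, steps):
        z = (1 - alpha) * z_a + alpha * z_b
        samples.append(model.sample(max_length, 1, z=z, temperature=temperature))
    return samples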
8 changes: 5 additions & 3 deletions requirements.txt
@@ -1,11 +1,13 @@
datasets>=1.1.3
einops>=0.3.0
numpy>=1.18.5
pytorch_lightning>=1.1.2
pytorch_lightning>=1.3
torch>=1.7.0
tokenizers>=0.9.4
transformers>=4.0
omegaconf>=2.0.5
deepspeed>=0.3.13
pynvml>=8.0.4
torchtext>=0.9.1
tqdm>=4.49.0
tqdm>=4.49.0
matplotlib>=3.4.2
triton>=1.0.0.dev20210625
21 changes: 20 additions & 1 deletion sparse_vae/__init__.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
from .core import select_best_gpu
from .core.conditional_gaussian import *
from .core.generation import *
from .core.transformer import *
from .core.transformer_language_model import *
from .core.language_model import *
from .core.continuous_autoencoder import *
from .lstm_vae import *
@@ -23,3 +23,22 @@ def get_checkpoint_path_for_name(experiment: str, ckpt_name: str) -> Path:
    except ValueError:
        print(f"Couldn't find checkpoint at path {ckpt_path}")
        exit(1)

def load_checkpoint_for_name(experiment: str, ckpt_name: str):
    if experiment == 'lstm-vae':
        model_class = LSTMVAE
    elif experiment == 'lstm-lm':
        model_class = LSTMLanguageModel
    elif experiment == 'transformer-lm':
        model_class = TransformerLanguageModel
    elif experiment == 'transformer-vae':
        model_class = TransformerVAE
    else:
        print(f"Unrecognized model type '{experiment}'.")
        return

    path = get_checkpoint_path_for_name(experiment, ckpt_name)
    model = model_class.load_from_checkpoint(path)
    model.start_token = 2
    model.end_token = 3
    return model
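
For reference, the scripts above call this helper roughly as follows; the experiment and checkpoint names here are placeholders, not names from the repository:

model = load_checkpoint_for_name('transformer-vae', 'my-run')  # hypothetical names
model.freeze()  # as done in gather_latents.py and reconstruct.py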