-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
70ebf9a
commit 894de25
Showing
33 changed files
with
1,924 additions
and
932 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from sparse_vae import * | ||
from datasets import Dataset | ||
from itertools import chain | ||
import sys | ||
|
||
|
||
def main(args):
    """Encode the whole corpus with a trained VAE and save the latents to disk.

    Usage: script.py <model_str> <model_name>. Loads the named checkpoint,
    runs every predict-dataloader batch through ``model.predict``, and writes
    a HuggingFace dataset with columns 'title', 'latent', and 'scale' under
    ./sparse-vae-datasets/latents/<model_str>/<model_name>.
    """
    model_str, model_name = args[1:]
    model = load_checkpoint_for_name(model_str, model_name)
    model.freeze()

    # Pick the least-loaded GPU and move the frozen model there.
    device = 'cuda:' + str(select_best_gpu())
    model = model.to(device)

    datamodule = TextDataModule(TextDataModuleHparams())
    datamodule.prepare_data()
    datamodule.setup('predict')

    loaders = datamodule.predict_dataloader()
    total = sum(len(loader) for loader in loaders)
    progress = tqdm(
        chain.from_iterable(loaders), desc="Gathering latents", unit='batch', total=total
    )
    latents = []
    scales = []
    titles = []
    for step, batch in enumerate(progress):
        # Only tensor-valued entries go to the GPU; 'title' stays on the host.
        tensors = {key: value.to(device) for key, value in batch.items() if isinstance(value, Tensor)}
        posterior = model.predict(tensors, step)
        latents.extend(posterior.mean.cpu().numpy().squeeze())
        scales.extend(posterior.scale.cpu().numpy().squeeze())
        titles.extend(''.join(chars) for chars in batch['title'])
        # Defensive stop in case the chained iterator yields more than `total`
        # batches (normally it is exhausted before this fires).
        if step >= total:
            progress.close()
            break

    print("Saving to disk...")
    save_path = Path.cwd() / 'sparse-vae-datasets' / 'latents' / model_str / model_name
    dataset = Dataset.from_dict({'title': titles, 'latent': latents, 'scale': scales})
    dataset.save_to_disk(str(save_path))
    print("Done.")


if __name__ == "__main__":
    main(sys.argv)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from sparse_vae import * | ||
from datasets import Dataset | ||
from torch.distributions import Normal, kl_divergence | ||
import torch.nn.functional as F | ||
import sys | ||
|
||
|
||
def _report_hits(dataset, scores, hit_indices):
    """Print one padded `title - score` line per hit, best first.

    ``scores`` is any iterable of scalars (0-dim tensors format via .item(),
    so they print the same as plain floats); ``hit_indices`` is a 1-D tensor
    of row indices into ``dataset``.
    """
    # HF docs guarantee this will return a dictionary of columns when passed
    # a NumPy integer array like this.
    hits = cast(Dict[str, List[str]], dataset[hit_indices.cpu().numpy()])

    max_len = max(len(x) for x in hits['title'])
    for score, title in zip(scores, hits['title']):
        print(title + " " * (max_len - len(title)) + f" - {score}")


def main(args):
    """Interactive nearest-neighbor lookup over previously saved latents.

    Usage: script.py <model_str> <model_name>. Loads the latent dataset the
    gathering script wrote, then for each title typed at the prompt reports
    the 10 nearest articles under three measures: L2 distance between
    posterior means, cosine similarity of means, and KL divergence between
    full posteriors.
    """
    model_str, model_name = args[1:]
    save_path = Path.cwd() / 'sparse-vae-datasets' / 'latents' / model_str / model_name
    dataset = Dataset.load_from_disk(str(save_path))
    titles = {title: idx for idx, title in enumerate(dataset['title'])}

    # Materialize all posteriors as one batched Normal on the chosen GPU,
    # then put the dataset back in its default (Python object) format.
    gpu_idx = select_best_gpu(min_free_memory=4.0)
    dataset.set_format('torch', device=gpu_idx)
    posteriors = Normal(loc=dataset['latent'], scale=dataset['scale'])
    dataset.reset_format()

    print("Type the title of an article to get the nearest neighbors. Type q to quit.")
    while (query := input("Article: ")) != 'q':
        article_idx = titles.get(query)
        if article_idx is None:
            print("No article found with that title. Try again.")
            continue

        posterior = Normal(loc=posteriors.loc[article_idx], scale=posteriors.scale[article_idx])

        print("\nL2 distance of means:")
        sq_dists = torch.sum((posterior.mean - posteriors.mean) ** 2, dim=-1)
        dists, hit_indices = sq_dists.topk(10, largest=False)
        _report_hits(dataset, dists, hit_indices)

        print("\nCosine similarity:")
        affinities = F.cosine_similarity(posterior.mean[None], posteriors.mean).squeeze()
        sims, hit_indices = affinities.topk(10, largest=True)
        _report_hits(dataset, sims, hit_indices)

        print("\nKL divergence:")
        kls = kl_divergence(posterior, posteriors).sum(dim=-1)
        kls, hit_indices = kls.topk(10, largest=False)
        # NOTE(review): the original converted only this branch's scores to
        # NumPy before printing; 0-dim tensors format identically, so the
        # helper prints tensors directly for all three measures.
        _report_hits(dataset, kls, hit_indices)

        print('\n')


if __name__ == "__main__":
    main(sys.argv)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from sparse_vae import * | ||
from datasets import concatenate_datasets | ||
import sys | ||
|
||
|
||
def main(args):
    """Interactively reconstruct articles with a trained VAE.

    Usage: script.py <model_str> <model_name>. Loads the named checkpoint and
    the full corpus (train + test), then for each title typed at the prompt
    encodes the article to its posterior mean and decodes a sample from it.
    """
    model_str, model_name = args[1:]
    model = load_checkpoint_for_name(model_str, model_name)
    model.freeze()
    model.eval()

    datamodule = TextDataModule(TextDataModuleHparams())
    datamodule.prepare_data()
    tokenizer = datamodule.tokenizer
    # Search over the entire corpus, not just one split.
    dataset = concatenate_datasets([datamodule.dataset['train'], datamodule.dataset['test']])
    titles = {title: idx for idx, title in enumerate(dataset['title'])}
    gpu_idx = select_best_gpu(min_free_memory=4.0)
    model = model.to(gpu_idx)

    # NOTE(review): the prompt advertises an interpolation mode ('i') that
    # this loop never handles — confirm whether that mode was dropped.
    print("Type the title of an article to get a reconstruction. Type q to quit.\nType i to switch to interpolation mode.")
    while (query := input("Article: ")) != 'q':
        article_idx = titles.get(query)
        if article_idx is None:
            print("No article found with that title. Try again.")
            continue

        # 'text' holds token ids here — presumably pre-tokenized; verify
        # against the data module.
        text = dataset[article_idx]['text']
        latent = model.predict({'token_ids': torch.tensor([text], device=gpu_idx)}, 0).loc
        sample_ids = model.sample(1024, 1, z=latent, temperature=0.7)
        reconstruction = tokenizer.decode(sample_ids.squeeze().tolist())
        print("Reconstruction:\n\n" + reconstruction)


if __name__ == "__main__":
    main(sys.argv)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
datasets>=1.1.3 | ||
einops>=0.3.0 | ||
numpy>=1.18.5 | ||
pytorch_lightning>=1.1.2 | ||
pytorch_lightning>=1.3 | ||
torch>=1.7.0 | ||
tokenizers>=0.9.4 | ||
transformers>=4.0 | ||
omegaconf>=2.0.5 | ||
deepspeed>=0.3.13 | ||
pynvml>=8.0.4 | ||
torchtext>=0.9.1 | ||
tqdm>=4.49.0 | ||
tqdm>=4.49.0 | ||
matplotlib>=3.4.2 | ||
triton>=1.0.0.dev20210625 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.