Committing to save TSNE code I'm about to delete
norabelrose committed May 29, 2021
1 parent e7d954c commit 70ebf9a
Showing 16 changed files with 722 additions and 207 deletions.
27 changes: 13 additions & 14 deletions hparam_presets.py
@@ -11,7 +11,6 @@
kl_annealing_steps=8000,
latent_depth=64,
lr=3e-4,
num_latent_vectors=1,
tie_embedding_weights=True,
tie_logit_weights=True,
transformer_encoder=False,
@@ -26,20 +25,20 @@
dataset_name='wikipedia',
dataset_config='20200501.en',
tokens_per_batch=50_000,
max_tokens_per_sample=12_500
min_tokens_per_sample=512,
max_tokens_per_sample=25_000
),
'model': dict(
bidirectional_encoder=True,
divide_loss_by_length=True,
d_model=1024,
d_model=2048,
d_embedding=512,
grad_clip_threshold=5.0,
init_scale=None,
kl_weight_start=0.2,
kl_annealing_steps=8000,
kl_weight_start=1.0,
kl_annealing_steps=0,
latent_depth=64,
lr=3e-4,
num_latent_vectors=1,
tie_embedding_weights=True,
tie_logit_weights=True,
transformer_encoder=False,
@@ -60,7 +59,6 @@
kl_annealing_steps=8000,
latent_depth=64,
lr=3e-4,
num_latent_vectors=1,
num_layers=3,
sparse_self_attention=False,
tie_embedding_weights=True,
@@ -78,9 +76,9 @@
init_scale=0.02,
kl_weight_start=0.2,
kl_annealing_steps=8000,
latent_depth=64,
latent_depth=512,
lr=3e-4,
num_latent_vectors=1,
num_samples=1,
num_layers=3,
sparse_self_attention=True,
tie_embedding_weights=True,
@@ -95,7 +93,7 @@
dataset_name='wikipedia',
dataset_config='20200501.en',
tokens_per_batch=50_000,
min_tokens_per_sample=160,
min_tokens_per_sample=512,
max_tokens_per_sample=25_000
),
'model': dict(
@@ -127,14 +125,15 @@
divide_loss_by_length=True,
d_model=512,
grad_checkpointing=True,
grad_clip_threshold=5.0,
grad_clip_threshold=150.0,
init_scale=0.02,
kl_weight_start=0.3,
kl_annealing_steps=4000,
kl_weight_start=1.0,
kl_weight_end=1.0,
kl_annealing_steps=0,
latent_depth=64,
# mmd=True,
lr=3e-4,
# lr_decay_steps=1_000_000,
num_latent_vectors=1,
num_layers=6,
sparse_self_attention=True,
tie_embedding_weights=True,
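For reference, here is a minimal sketch of how kl_weight_start, kl_weight_end, and kl_annealing_steps are conventionally combined into a per-step KL weight. The linear schedule and the kl_weight function name are assumptions; the repo's actual schedule is not shown in this diff.

```python
# Hypothetical helper: a linear KL annealing schedule (assumed, not from the repo).
def kl_weight(step: int, start: float, end: float, annealing_steps: int) -> float:
    if annealing_steps <= 0:
        return end  # kl_annealing_steps=0 disables annealing entirely
    progress = min(step / annealing_steps, 1.0)
    return start + progress * (end - start)
```

Under this reading, changing kl_weight_start from 0.2 to 1.0 and kl_annealing_steps from 8000 to 0 turns annealing off, so the full KL penalty applies from the first step.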
4 changes: 3 additions & 1 deletion requirements.txt
@@ -6,4 +6,6 @@ torch>=1.7.0
tokenizers>=0.9.4
omegaconf>=2.0.5
deepspeed>=0.3.13
pynvml>=8.0.4
pynvml>=8.0.4
torchtext>=0.9.1
tqdm>=4.49.0
1 change: 1 addition & 0 deletions sparse_vae/core/__init__.py
@@ -4,6 +4,7 @@
from .continuous_autoencoder import *
from .conditional_gaussian import *
from .language_model import *
from .math_utils import *
from .padded_tensor import PaddedTensor
from .perceiver import Perceiver
from .transformer import *
108 changes: 46 additions & 62 deletions sparse_vae/core/attention.py
@@ -15,13 +15,15 @@ def __init__(
num_heads: int,
causal = False,
sparse: Union[bool, SlidingWindowSparsityConfig] = False,
learned_queries: int = None
learned_queries: int = None,
# num_special_tokens: int = 1 # Number of leading special tokens exempted from causal masking
):
super().__init__()

self.causal = causal
self.d_model = d_model
self.num_heads = num_heads
# self.num_special_tokens = num_special_tokens
assert d_model % num_heads == 0, "num_heads must divide d_model evenly"

# We can use learned queries to extract a single vector or a fixed size sequence of vectors out of a sequence
@@ -36,7 +38,6 @@ def __init__(
self.output_linear = nn.Linear(d_model, d_model)

self.cache_index = 0
self.precomputed_kv = False
self.key_cache = None
self.value_cache = None

@@ -45,7 +46,6 @@ def __init__(
if isinstance(sparse, SlidingWindowSparsityConfig):
config = sparse
else:
# config = FixedSparsityConfig(num_heads=num_heads, attention='unidirectional')
config = SlidingWindowSparsityConfig(num_heads=num_heads, window_size=4)

self.sparse_attention = SparseSelfAttention(
@@ -65,34 +65,31 @@ def forward(self, q: Optional[Tensor], k: PaddedTensor, v: Tensor):
q = q + positional_encodings_like(q, self.cache_index)
q = self.q_linear(q)

# This will be True only when we're using cached keys and values with cross attention
if self.precomputed_kv:
k, v = self.key_cache, self.value_cache

# Normal case
else:
k = k + positional_encodings_like(k, self.cache_index)

k, v = self.k_linear(k), self.v_linear(v)
if self.kv_cache_length:
k, v = self._update_kv_cache(k, v)
k = k + positional_encodings_like(k, self.cache_index)
k, v = self.k_linear(k), self.v_linear(v)
if self.kv_cache_length:
k, v = self._update_kv_cache(k, v)

mask = getattr(k, 'padding', None)
q, k, v = (rearrange(x, '... l (h d) -> ... h l d', h=self.num_heads) for x in (q, k, v))
if self.sparse_attention and self.key_cache is None:

if self.causal and self.key_cache is None:
q_len = q.shape[-2]
causal_mask = torch.ones(q_len, q_len, device=q.device, dtype=torch.bool).triu(1)
else:
causal_mask = None

if self.sparse_attention and self.key_cache is None:
mask = mask * -1e7 if mask is not None else None # DeepSpeed *adds* this mask to the attn scores
attn_mask = torch.ones(q_len, q_len, device=q.device, dtype=torch.bool).triu(1) * -1e7 if self.causal else None
output = self.sparse_attention(q, k, v, attn_mask=attn_mask, key_padding_mask=mask)
causal_mask = causal_mask * -1e7 if causal_mask is not None else None
output = self.sparse_attention(q, k, v, attn_mask=causal_mask, key_padding_mask=mask)
else:
scores = q @ k.transpose(-1, -2) * k.shape[-1] ** -0.5

# Note that we only apply the upper triangular causal mask during training;
# during autoregressive decoding there's no "right context" to mask out
mask = mask[..., None, None, :] if mask is not None and mask.ndim >= 2 else mask
if self.causal and self.key_cache is None:
causal_mask = torch.ones(*scores.shape[-2:], device=scores.device, dtype=torch.bool).triu(1)
if causal_mask is not None:
mask = mask | causal_mask if mask is not None else causal_mask

if mask is not None:
@@ -109,50 +106,38 @@ def _update_kv_cache(self, k, v) -> Tuple[Tensor, Tensor]:
# Register for automatic cache cleanup when we exit the cached kv context
self.live_attention_modules.add(self)

# We're being fed new keys and values one token at a time- self-attention case
if k.shape[-2] == 1:
# When we're using sparse attention, we only need to cache the keys and values that are
# actually going to be attended to
if self.sparse_attention:
config = self.sparse_attention.sparsity_config
num_blocks = config.window_size + int(config.include_cls)
block_size = config.block
cache_length, cache_offset = num_blocks * block_size, int(config.include_cls) * block_size
else:
cache_length, cache_offset = self.kv_cache_length, 0
block_size = 0

if self.key_cache is None:
if k.shape[0] == 1:
breakpoint()
self.key_cache = k.new_zeros([k.shape[0], cache_length, k.shape[-1]])
self.value_cache = v.new_zeros([v.shape[0], cache_length, v.shape[-1]])

# We've overshot the kv cache size
if self.sparse_attention and self.cache_index >= cache_length:
local_index = self.cache_index % block_size
kv_index = cache_length - block_size + local_index

# Shift the kv cache leftward by one block, discarding the leftmost one
if local_index == 0:
self.key_cache[:, cache_offset:kv_index] = self.key_cache[:, cache_offset + block_size:].clone()
self.value_cache[:, cache_offset:kv_index] = self.value_cache[:, cache_offset + block_size:].clone()
else:
kv_index = self.cache_index

self.key_cache[:, kv_index] = k.squeeze(-2)
self.value_cache[:, kv_index] = v.squeeze(-2)
self.cache_index += 1
# When we're using sparse attention, we only need to cache the keys and values that are
# actually going to be attended to
if self.sparse_attention:
config = self.sparse_attention.sparsity_config
num_blocks = config.window_size + int(config.include_cls)
block_size = config.block
cache_length, cache_offset = num_blocks * block_size, int(config.include_cls) * block_size
else:
cache_length, cache_offset = self.kv_cache_length, 0
block_size = 0

if self.key_cache is None:
self.key_cache = k.new_zeros([k.shape[0], cache_length, k.shape[-1]])
self.value_cache = v.new_zeros([v.shape[0], cache_length, v.shape[-1]])

# We've overshot the kv cache size
if self.sparse_attention and self.cache_index >= cache_length:
local_index = self.cache_index % block_size
kv_index = cache_length - block_size + local_index

# Shift the kv cache leftward by one block, discarding the leftmost one
if local_index == 0:
self.key_cache[:, cache_offset:kv_index] = self.key_cache[:, cache_offset + block_size:].clone()
self.value_cache[:, cache_offset:kv_index] = self.value_cache[:, cache_offset + block_size:].clone()
else:
kv_index = self.cache_index

return self.key_cache[:, :kv_index + 1], self.value_cache[:, :kv_index + 1]
self.key_cache[:, kv_index] = k.squeeze(-2)
self.value_cache[:, kv_index] = v.squeeze(-2)
self.cache_index += 1

# We're being fed a bunch of keys and values all at once- cross-attention case.
else:
breakpoint()
self.key_cache = k
self.value_cache = v
self.precomputed_kv = True
return k, v
return self.key_cache[:, :kv_index + 1], self.value_cache[:, :kv_index + 1]

# Class variables
kv_cache_length: int = None
@@ -169,7 +154,6 @@ def kv_cache(cls, max_seq_length: int):
cls.kv_cache_length = None
for module in cls.live_attention_modules:
module.cache_index = 0
module.precomputed_kv = False
module.key_cache = None
module.value_cache = None

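The rewritten _update_kv_cache drops the old cross-attention branch and always maintains a block-wise sliding-window cache: once the write index passes the cache length, the cache shifts left by one block and new entries land in the final block. A standalone sketch of that shift, distilled into a hypothetical append_to_window_cache helper (the name and signature are assumptions):

```python
import torch

# Hypothetical helper distilled from _update_kv_cache above. Writes one new
# entry into a sliding-window cache of shape [batch, cache_length, d] and
# returns the index it was written to.
def append_to_window_cache(cache: torch.Tensor, index: int, entry: torch.Tensor,
                           block_size: int, offset: int = 0) -> int:
    cache_length = cache.shape[1]
    if index >= cache_length:  # we've overshot the cache size
        local = index % block_size
        kv_index = cache_length - block_size + local
        if local == 0:
            # Shift leftward by one block, discarding the oldest block while
            # preserving the leading [0:offset) region (e.g. a CLS block)
            cache[:, offset:kv_index] = cache[:, offset + block_size:].clone()
    else:
        kv_index = index
    cache[:, kv_index] = entry
    return kv_index  # valid entries are cache[:, :kv_index + 1]
```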
2 changes: 1 addition & 1 deletion sparse_vae/core/auto_select_gpu.py
@@ -1,6 +1,6 @@
# Uses pynvml to select the index of the least-used GPU with sufficient free memory.
# The `min_free_memory` argument is interpreted in gigabytes
def select_best_gpu(min_free_memory: float = 40.0) -> int:
def select_best_gpu(min_free_memory: float = 35.0) -> int:
from pynvml import (
nvmlInit,
nvmlDeviceGetCount,
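Only the signature of select_best_gpu appears in this hunk. A plausible sketch of the body, assuming it scans every device with pynvml and returns the index of the device with the most free memory, raising if none clears the threshold (the exact selection rule is an assumption):

```python
# Plausible sketch only; the committed body is not shown in this diff.
def select_best_gpu(min_free_memory: float = 35.0) -> int:
    from pynvml import (
        nvmlInit,
        nvmlDeviceGetCount,
        nvmlDeviceGetHandleByIndex,
        nvmlDeviceGetMemoryInfo,
    )
    nvmlInit()
    best_index, best_free = -1, min_free_memory * 1024 ** 3  # GB -> bytes
    for i in range(nvmlDeviceGetCount()):
        info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
        if info.free >= best_free:
            best_index, best_free = i, info.free
    if best_index < 0:
        raise RuntimeError("No GPU has enough free memory")
    return best_index
```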
11 changes: 5 additions & 6 deletions sparse_vae/core/conditional_gaussian.py
@@ -15,17 +15,16 @@ def __init__(self, in_features: int, out_features: int, zero_initialized: bool =

self.linear = linear

def forward(self, x: Tensor, temperature: float = 1.0, get_kl: bool = False) -> Union[Normal, Tuple[Normal, Tensor]]:
mu, logsigma = self.linear(x).chunk(2, dim=-1)
sigma = logsigma.exp()
def forward(self, x: Tensor, get_kl: bool = False) -> Union[Normal, Tuple[Normal, Tensor]]:
mu, logvar = self.linear(x).chunk(2, dim=-1)
var = logvar.exp()

# We do NOT validate the parameters here because this raises an error if any of the sigma values
# are exactly zero. This should yield an infinite KL divergence and therefore an infinite loss,
# but the AMP grad scaler will take care of that.
gaussian = Normal(loc=mu, scale=sigma * temperature, validate_args=False)
gaussian = Normal(loc=mu, scale=var.sqrt(), validate_args=False)
if get_kl:
# Analytical formula for the KL divergence p -> q, where p is a standard unit variance Gaussian.
kl = -0.5 + logsigma + 0.5 * (1.0 + mu ** 2) / (sigma ** 2)
kl = 0.5 * (mu ** 2 + var - logvar - 1.0)
return gaussian, kl
else:
return gaussian
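The rewritten kl line is the closed-form KL divergence from the diagonal posterior N(mu, var) to the standard normal prior, with logvar = log(var):

```latex
\mathrm{KL}\!\left(\mathcal{N}(\mu,\sigma^2)\,\middle\|\,\mathcal{N}(0,1)\right)
  = \tfrac{1}{2}\left(\mu^2 + \sigma^2 - \log\sigma^2 - 1\right)
```

which matches kl = 0.5 * (mu ** 2 + var - logvar - 1.0) elementwise; the deleted line computed the reverse direction, the KL from the standard normal to the posterior.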
26 changes: 17 additions & 9 deletions sparse_vae/core/continuous_autoencoder.py
@@ -1,7 +1,9 @@
from abc import abstractmethod
from contextlib import contextmanager

from torch.distributions import Normal
from .language_model import *
from .math_utils import pairwise_gaussian_kl
import math


@@ -57,8 +59,8 @@ def sample_z(self, encoder_out: Tensor, token_counts: Tensor, stage: str = 'trai

log_prefix = stage + '_'
self.log_dict({
log_prefix + 'kl': raw_kl.mean(),
log_prefix + 'mutual_info': self.estimate_mutual_info(q_of_z, z)
log_prefix + 'kl': raw_kl.mean()
# log_prefix + 'mutual_info': self.estimate_mutual_info(q_of_z)
})
if stage == 'val':
self.log('val_active_units', q_of_z.mean, reduce_fx=self.compute_num_active_units)
@@ -153,13 +155,19 @@ def estimate_log_prob_iw(self, q_of_z: Normal, x: Tensor, labels: Tensor, num_sa
return torch.cat(log_ws).logsumexp(dim=0) - math.log(num_samples)

@staticmethod
def estimate_mutual_info(conditional_q: Normal, z: Tensor):
unsqueezed = Normal(loc=conditional_q.loc[:, None], scale=conditional_q.scale[:, None], validate_args=False)
cross_densities = unsqueezed.log_prob(z[None]).sum(dim=-1) # [batch, batch]

# Approximate q(z) by averaging over the densities assigned to each z by the other q(z|x)s in the minibatch
marginal_q = cross_densities.logsumexp(dim=0) - math.log(cross_densities.shape[0])
return -conditional_q.entropy().sum(dim=-1).mean() - marginal_q.mean()
def estimate_mutual_info(posteriors: Normal):
# Mutual info = E_p(x)[KL(q(z|x)|q(z))]. Since q(z) = E_p(x)[q(z|x)], we can view each minibatch of
# posteriors as a mixture of Gaussians with equal component weights that approximates q(z). While there
# is no closed form solution for the KL P -> Q where Q is a Gaussian mixture, there is an analytic formula
# for an upper bound on this quantity, and we use it here.
neg_pairwise_kl = -pairwise_gaussian_kl(posteriors)
batch = neg_pairwise_kl.shape[1]
lower = -neg_pairwise_kl.logsumexp(dim=-1).mean() - math.log(batch)

tiny = torch.finfo(neg_pairwise_kl.dtype).min
loo_neg_pairwise_kl = neg_pairwise_kl.clone().fill_diagonal_(tiny)
upper = -loo_neg_pairwise_kl.logsumexp(dim=-1).mean() - math.log(batch - 1)
return (lower + upper) / 2.0

# Returns log p(x|z) summed (not averaged!) over the sequence length, for each element in the batch
def p_of_x_given_z(self, x, z, labels) -> Tensor:
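estimate_mutual_info now relies on pairwise_gaussian_kl from the new math_utils module, which this diff does not show. A plausible sketch, assuming it returns the analytic KL(q_i || q_j) for every pair of diagonal-Gaussian posteriors in the batch, summed over the latent dimension, as a [batch, batch] tensor:

```python
import torch
from torch.distributions import Normal

# Assumed implementation of math_utils.pairwise_gaussian_kl; the committed
# version is not shown in this diff. Returns kl[i, j] = KL(q_i || q_j).
def pairwise_gaussian_kl(posteriors: Normal) -> torch.Tensor:
    mu, var = posteriors.loc, posteriors.scale ** 2      # each [batch, depth]
    mu_i, var_i = mu[:, None], var[:, None]              # [batch, 1, depth]
    mu_j, var_j = mu[None], var[None]                    # [1, batch, depth]
    # Closed-form KL between diagonal Gaussians, summed over the latent dim
    kl = 0.5 * (var_i / var_j + (mu_i - mu_j) ** 2 / var_j - 1.0
                + var_j.log() - var_i.log())
    return kl.sum(dim=-1)                                # [batch, batch]
```

With that shape, the logsumexp over dim=-1 in estimate_mutual_info mixes each posterior against the whole minibatch, producing the two Gaussian-mixture KL estimates that the method averages.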