Last commit before ckpt BC-breaking changes
norabelrose committed Mar 31, 2021
1 parent 6211be8 commit f3f110f
Showing 6 changed files with 16 additions and 10 deletions.
1 change: 0 additions & 1 deletion eval.py
@@ -21,7 +21,6 @@ def main(args):
tokenizer_path = Path.cwd() / 'text-vae-pretrained' / 'tokenizers' / 'yelp_polarity.json'
tokenizer = Tokenizer.from_file(str(tokenizer_path))
outputs = tokenizer.decode_batch([x.tolist() for x in outputs])
- breakpoint()


if __name__ == "__main__":
1 change: 1 addition & 0 deletions sample.py
@@ -68,6 +68,7 @@ def main(args):
setattr(options, key, value)

elif command == 's':
+ breakpoint()
output = sampler.sample(options)
samples = tokenizer.decode_batch(output.tolist())
for sample in samples:
2 changes: 1 addition & 1 deletion text_vae/funnel_transformers/funnel_ops.py
@@ -126,6 +126,7 @@ def forward(self, q: Tensor, k: PaddedTensor, v: Tensor, pos_encodings = None) -
k = self.k_head(k)
v = self.v_head(v)

+ mask = k.padding; mask = mask[:, None, None, :] if mask is not None else None
q, k, v = (rearrange(x, '... l h d -> ... h l d') for x in (q, k, v))

if self.r_w_bias is not None:
@@ -135,7 +136,6 @@ def forward(self, q: Tensor, k: PaddedTensor, v: Tensor, pos_encodings = None) -
attn_score = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5

# Perform masking
- mask = k.padding
if self.causal:
causal_mask = torch.ones(*attn_score.shape[-2:], device=attn_score.device, dtype=torch.bool).triu_(1)
mask = causal_mask if mask is None else mask | causal_mask
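Note on the masking change above: the key-padding mask is now built right after the key/value projections and broadcast from [batch, key_len] to [batch, 1, 1, key_len], so it lines up with attention scores of shape [batch, heads, query_len, key_len] and can be OR-ed with the causal mask further down the hunk. A self-contained sketch of that broadcasting pattern (shapes and tensors here are illustrative, not the repository's code):

    import torch

    batch, heads, q_len, k_len = 2, 4, 6, 6
    attn_score = torch.randn(batch, heads, q_len, k_len)

    # Hypothetical key-padding mask: True where a key position is padding
    padding = torch.zeros(batch, k_len, dtype=torch.bool)
    padding[:, -2:] = True  # pretend the last two tokens are padding

    # Broadcast [batch, k_len] -> [batch, 1, 1, k_len] so it aligns with attn_score
    mask = padding[:, None, None, :]

    # Optionally fold in a causal mask of shape [q_len, k_len]
    causal_mask = torch.ones(q_len, k_len, dtype=torch.bool).triu_(1)
    mask = mask | causal_mask  # broadcasts to [batch, 1, q_len, k_len]

    # Masked positions are excluded before the softmax
    attn_probs = attn_score.masked_fill(mask, float('-inf')).softmax(dim=-1)
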
5 changes: 1 addition & 4 deletions text_vae/funnel_transformers/funnel_transformer.py
@@ -39,9 +39,6 @@ def __init__(self, hparams: Union[FunnelTransformerHparams, DictConfig]):
if isinstance(hparams, FunnelTransformerHparams):
hparams = OmegaConf.structured(hparams)

- if not hparams.d_embedding:
- hparams.d_embedding = hparams.d_model

self.hparams = hparams

if hparams.use_convolutions:
@@ -73,7 +70,7 @@ def __init__(self, hparams: Union[FunnelTransformerHparams, DictConfig]):
def strides(self) -> List[int]:
scaling_factors = [1] + list(self.hparams.scaling_factors)
encoder_strides = cumprod(scaling_factors).tolist()
- return encoder_strides if not self.hparams.upsampling else self.hparams.upsampling[::-1]
+ return encoder_strides if not self.hparams.upsampling else encoder_strides[::-1]

# Activates cross attention for the layers specified in the list of (block index, layer index) tuples
def configure_cross_attention(self, layers: List[Tuple[int, int]]):
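On the strides fix above: the decoder's strides are now derived by reversing the encoder's cumulative scaling factors, rather than reversing hparams.upsampling itself. A rough illustration with made-up factors (cumprod as in numpy, matching the call in the original):

    from numpy import cumprod

    scaling_factors = [2, 2]  # hypothetical per-block pooling factors
    encoder_strides = cumprod([1] + scaling_factors).tolist()  # [1, 2, 4]

    # Encoder (downsampling): stride grows block by block  -> [1, 2, 4]
    # Decoder (upsampling): same strides, walked in reverse -> [4, 2, 1]
    decoder_strides = encoder_strides[::-1]
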
2 changes: 1 addition & 1 deletion text_vae/quantized_vae.py
@@ -125,7 +125,7 @@ def forward(self, batch: Dict[str, Tensor], quantize: bool = True) -> QuantizedV

return vae_state

- def decoder_block_end(self, vae_state: QuantizedVAEState, dec_state: Tensor, block_idx: int):
+ def decoder_block_end(self, vae_state: QuantizedVAEState, block_idx: int, dec_state: Tensor):
cross_attn_kv = vae_state.encoder_states[block_idx - 1] if block_idx > 0 else None
if block_idx >= len(vae_state.encoder_states):
return dec_state, cross_attn_kv
15 changes: 12 additions & 3 deletions text_vae/quantized_vae_sampler.py
@@ -46,11 +46,11 @@ def for_vae(name: str):
if not hparams_exist:
old_hparams_path.link_to(hparams_path)

- vae = QuantizedVAE.load_from_checkpoint(ckpt_path, hparams_file=str(hparams_path), strict=False)
+ vae = QuantizedVAE.load_from_checkpoint(ckpt_path, hparams_file=str(hparams_path))
return QuantizedVAESampler(vae_name=name, vae=vae) # noqa

# We've already trained the priors, so just load them
- vae = QuantizedVAE.load_from_checkpoint(ckpt_path, hparams_file=str(hparams_path), strict=False)
+ vae = QuantizedVAE.load_from_checkpoint(ckpt_path, hparams_file=str(hparams_path))

num_priors = len(vae.quantizers)
priors = {}
@@ -199,7 +199,16 @@ def sample(self, options: QuantizedVAESamplingOptions):
else:
start = None; end = None

- with torch.autograd.profiler.profile(use_cuda=vae.on_gpu) if options.profile else nullcontext():
+ profiler = torch.profiler.profile(
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA
+ ],
+ with_stack=True, record_shapes=True,
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(os.getcwd(), 'text-vae-profiling'))
+ ) if options.profile else nullcontext()
+
+ with profiler:
logits = self._raw_sample(vae, priors, options)

if end:
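On the profiling change above: the sampler moves from the legacy torch.autograd.profiler.profile(use_cuda=...) context manager to the newer torch.profiler.profile API, which can record shapes and Python stacks and write traces to a TensorBoard-readable directory via tensorboard_trace_handler. A stripped-down sketch of the same pattern (the flag, output directory, and profiled workload are placeholders, not the repository's code):

    import os
    from contextlib import nullcontext
    import torch

    profile_enabled = True  # stand-in for options.profile

    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    profiler = torch.profiler.profile(
        activities=activities,
        with_stack=True, record_shapes=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(os.path.join(os.getcwd(), 'profiling-demo'))
    ) if profile_enabled else nullcontext()

    with profiler:
        _ = torch.randn(256, 256) @ torch.randn(256, 256)  # placeholder for the sampling call
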
