Skip to content

Commit

Permalink
make sure last token can still predict eos
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 7, 2021
1 parent 7582af4 commit c32f724
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def __init__(
self.text_seq_len = text_seq_len
self.image_seq_len = image_seq_len

seq_len = text_seq_len + image_seq_len
total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
self.total_tokens = total_tokens

Expand All @@ -271,7 +272,7 @@ def __init__(
nn.Linear(dim, self.total_tokens),
)

seq_range = torch.arange(text_seq_len + image_seq_len)
seq_range = torch.arange(seq_len)
logits_range = torch.arange(total_tokens)

seq_range = rearrange(seq_range, 'n -> () n ()')
Expand All @@ -280,7 +281,7 @@ def __init__(
logits_mask = (
((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
(logits_range >= (total_tokens - 1))
((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
)

self.register_buffer('logits_mask', logits_mask)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
version = '0.0.15',
version = '0.0.16',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit c32f724

Please sign in to comment.