factorization_utils.py
forked from 1x-technologies/1xgpt
import torch
import torch.nn as nn
from einops import rearrange


class FactorizedEmbedding(nn.Module):
    """ Each token's embedding is the sum of the embeddings in each factorized vocabulary. """

    def __init__(self, config):
        """
        Args:
            config: Should specify `factored_vocab_size`, `d_model`, `num_factored_vocabs`, `image_vocab_size`.
                E.g. genie.config.GenieConfig
        """
        super().__init__()
        # One embedding table per factored vocabulary; a token's embedding is the sum over these tables.
        self.factored_embeds = nn.ModuleList([nn.Embedding(config.factored_vocab_size, config.d_model)
                                              for _ in range(config.num_factored_vocabs)])
        # Learnable embedding for the mask token (id == image_vocab_size).
        self.mask_token_embed = nn.Parameter(torch.zeros(1, config.d_model))
        self.config = config
    def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
        """
        Args:
            input_ids: Shape (B, T, H*W)

        Returns:
            input embeddings: Shape (B, T, H*W, d_model)
        """
        # Initialize all positions to the mask token embedding, then fill in actual token embeddings below.
        embeds = self.mask_token_embed.repeat(input_ids.size() + (1,))
        is_not_mask = input_ids != self.config.image_vocab_size  # `image_vocab_size = factored_vocab_size**num_factored_vocabs`
        embeds[is_not_mask] = 0

        # Decompose each non-mask token id into base-`factored_vocab_size` digits and sum the corresponding embeddings.
        input_ids = input_ids.clone()
        for factored_embed in self.factored_embeds:  # TODO: no for loop
            embeds[is_not_mask] += factored_embed(input_ids[is_not_mask] % self.config.factored_vocab_size)
            input_ids //= self.config.factored_vocab_size

        return embeds
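
# Note: for non-mask ids, FactorizedEmbedding.forward is equivalent to applying `factorize_token_ids`
# (defined below) at each position and summing the per-vocabulary embeddings of the resulting factored ids.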


def factorize_token_ids(token_ids: torch.LongTensor, num_factored_vocabs: int = 2, factored_vocab_size: int = 512) -> torch.LongTensor:
    """
    Args:
        token_ids: Tensor of any shape with token id values in [0, image_vocab_size = factored_vocab_size**num_factored_vocabs = 2**18 by default).

    Returns:
        Tensor of shape token_ids.size() + (num_factored_vocabs,), where the last dimension holds the token id
        in each individual factored vocabulary, with values in [0, factored_vocab_size = 512).
    """
    powers = factored_vocab_size ** torch.arange(num_factored_vocabs, device=token_ids.device)
    return (token_ids.unsqueeze(-1) // powers) % factored_vocab_size


def unfactorize_token_ids(
    factored_token_ids: torch.LongTensor,
    num_factored_vocabs: int = 2,
    factored_vocab_size: int = 512
) -> torch.LongTensor:
    """
    Inverse of `factorize_token_ids`.
    The last dimension of `factored_token_ids` is assumed to be the vocabulary dimension.

    Returns:
        Tensor of shape factored_token_ids.size()[:-1], with unfactorized token id values in
        [0, image_vocab_size = factored_vocab_size**num_factored_vocabs).
    """
    powers = factored_vocab_size ** torch.arange(num_factored_vocabs, device=factored_token_ids.device)
    return (factored_token_ids * powers).sum(dim=-1)
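
# Example of the factorization arithmetic with the default factored_vocab_size = 512 and num_factored_vocabs = 2:
# a token id is split into its two base-512 digits,
#   token_id = 1000  ->  (1000 % 512, (1000 // 512) % 512) = (488, 1)
# and `unfactorize_token_ids` recombines them as 488 * 512**0 + 1 * 512**1 = 1000.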


def factorize_labels(labels_THW: torch.LongTensor, num_factored_vocabs: int = 2, factored_vocab_size: int = 512) \
        -> torch.LongTensor:
    """
    Simply `factorize_token_ids` followed by permuting dimensions.

    Args:
        labels_THW: Shape (B, T, H, W), values in [0, image_vocab_size = 2**18).

    Returns:
        factored_labels: Shape (B, num_factored_vocabs = 2, T, H, W), values in [0, factored_vocab_size = 512).
    """
    factored_labels = factorize_token_ids(labels_THW, num_factored_vocabs, factored_vocab_size)
    return rearrange(factored_labels, "b t h w num_factored_vocabs -> b num_factored_vocabs t h w")
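

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only). The config stub below is an assumption standing in for
    # genie.config.GenieConfig; it exposes just the fields this module reads, with arbitrary small sizes.
    from types import SimpleNamespace

    config = SimpleNamespace(
        num_factored_vocabs=2,
        factored_vocab_size=512,
        image_vocab_size=512 ** 2,
        d_model=32,
    )

    # Round trip: factorize and unfactorize are inverses for ids in [0, image_vocab_size).
    token_ids = torch.randint(0, config.image_vocab_size, (2, 4, 16))  # (B, T, H*W)
    factored = factorize_token_ids(token_ids, config.num_factored_vocabs, config.factored_vocab_size)
    assert torch.equal(
        unfactorize_token_ids(factored, config.num_factored_vocabs, config.factored_vocab_size), token_ids
    )

    # Embed a batch that also contains mask tokens (id == image_vocab_size).
    embedder = FactorizedEmbedding(config)
    input_ids = token_ids.clone()
    input_ids[:, -1] = config.image_vocab_size  # mask the last frame
    embeds = embedder(input_ids)
    print(embeds.shape)  # torch.Size([2, 4, 16, 32])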