forked from JongYun-Kim/lazy_flocking_rl
Commit 18f5cf0 (1 parent: 3d5035d)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 13 changed files with 1,082 additions and 0 deletions. (Large diffs are not rendered by default.)
@@ -0,0 +1,31 @@
import copy
import torch.nn as nn


class Decoder(nn.Module):

    def __init__(self, decoder_block, n_layer, norm):
        super(Decoder, self).__init__()
        self.n_layer = n_layer
        self.layers = nn.ModuleList([copy.deepcopy(decoder_block) for _ in range(self.n_layer)])
        # Use nn.Identity() rather than None as the placeholder norm: it is more readable,
        # though it may break backward compatibility with code that checked for norm == None.
        self.norm = norm if norm is not None else nn.Identity()

    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        out = tgt
        for layer in self.layers:
            out = layer(out, encoder_out, tgt_mask, src_tgt_mask)
        out = self.norm(out)
        return out  # shape: (batch_size, tgt_seq_len, d_embed)


class DecoderPlaceholder(nn.Module):

    def __init__(self, decoder_block=None, n_layer=None, norm=None, *args, **kwargs):
        super(DecoderPlaceholder, self).__init__()
        # The arguments are not stored or used; they are accepted only for interface compatibility.

    def forward(self, tgt, encoder_out=None, tgt_mask=None, src_tgt_mask=None, *args, **kwargs):
        return tgt  # just returns the input tgt as is
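For a quick sanity check, here is a minimal usage sketch (illustrative only, not part of the commit); it shows that DecoderPlaceholder accepts the Decoder interface but acts as a no-op:

import torch

dec = DecoderPlaceholder(decoder_block=None, n_layer=3, norm=None)
tgt = torch.randn(2, 1, 32)  # (batch_size, tgt_seq_len, d_embed); the sizes here are arbitrary
assert torch.equal(dec(tgt), tgt)  # the placeholder returns tgt unchanged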
@@ -0,0 +1,132 @@
import copy
import torch
import torch.nn as nn

from transformer_modules.residual_connection_layer import ResidualConnectionLayer, IdentityResidualLayer, \
    NoResidualButSameForward


class DecoderBlock(nn.Module):

    def __init__(self, self_attention, cross_attention, position_ff, norm, dr_rate=0):
        super(DecoderBlock, self).__init__()
        self.self_attention = self_attention
        self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
        self.cross_attention = cross_attention
        self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
        self.position_ff = position_ff
        self.residual3 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)

    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        out = tgt
        out = self.residual1(out, lambda out: self.self_attention(query=out, key=out, value=out, mask=tgt_mask))
        out = self.residual2(out, lambda out: self.cross_attention(query=out, key=encoder_out, value=encoder_out, mask=src_tgt_mask))
        out = self.residual3(out, self.position_ff)
        return out


class CustomDecoderBlock(nn.Module):

    def __init__(self, cross_attention, norm, self_attention=None, position_ff=None, dr_rate=0, efficient=False):
        super(CustomDecoderBlock, self).__init__()
        # Initialize the residual connection layers
        if norm is None:
            # If norm is None (i.e. nn.Identity()) and efficient==True, consider dropping the dropout layers too
            norm = nn.Identity()

        if efficient:
            # self.residual1 = NoResidualButSameForward()
            self.residual2 = NoResidualButSameForward()
            self.residual3 = NoResidualButSameForward()
        else:
            # self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
            self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
            self.residual3 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)

        # Initialize the self-attention and cross-attention layers
        self.self_attention = self_attention
        if self.self_attention is not None:
            print("[DecoderBlock] self_attention is not None, but it is not used in this block!")
        if cross_attention is None:
            raise ValueError("cross_attention is None; this implementation requires a cross-attention layer")
        self.cross_attention = cross_attention
        # Initialize the position-wise feed-forward network
        self.position_ff = position_ff

    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        # tgt: (batch_size, tgt_seq_len, d_embed_context)
        # encoder_out: (batch_size, src_seq_len, d_embed_input)
        # tgt_mask: (batch_size, 1, tgt_seq_len, tgt_seq_len); unused here since self-attention is skipped
        # src_tgt_mask: (batch_size, 1, tgt_seq_len==1, src_seq_len)

        # Cross-attention MHA layer with tgt as the query and encoder_out as key/value
        # Shape: (batch_size, tgt_seq_len, d_model)
        tgt = self.residual2(tgt, lambda tgt: self.cross_attention(query=tgt, key=encoder_out, value=encoder_out,
                                                                   mask=src_tgt_mask))
        # Position-wise feed-forward network, applied only if position_ff was provided
        # Shape: (batch_size, tgt_seq_len, d_model)
        if self.position_ff is not None:
            tgt = self.residual3(tgt, self.position_ff)

        # Return the output tensor
        # Shape: (batch_size, tgt_seq_len==1, d_embed_context)
        return tgt


class ProbablyAlmostUniversalDecoderBlockLol(nn.Module):  # Maybe acts as a universal decoder block?
    def __init__(self, cross_attention, norm, self_attention=None, position_ff=None, dr_rate=0):
        super(ProbablyAlmostUniversalDecoderBlockLol, self).__init__()
        self.cross_attention = cross_attention

        if self_attention is not None:
            self.self_attention = self_attention
            self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
        else:
            self.self_attention = lambda query, key, value, mask: query
            self.residual1 = IdentityResidualLayer()

        self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)

        if position_ff is not None:
            self.position_ff = position_ff
            self.residual3 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
        else:
            self.position_ff = lambda x: x
            self.residual3 = IdentityResidualLayer()

    def forward(self, tgt, encoder_out, tgt_mask=None, src_tgt_mask=None):
        # Note: the argument order matches Decoder's positional call, layer(out, encoder_out, tgt_mask, src_tgt_mask);
        # the original (tgt, encoder_out, src_tgt_mask, tgt_mask) order would silently swap the masks.
        # tgt: (batch_size, tgt_seq_len, d_model)  # d_model used interchangeably with d_embed; TODO: fix this later
        # encoder_out: (batch_size, src_seq_len, d_model)
        # tgt_mask: (batch_size, tgt_seq_len, tgt_seq_len); may vary; check your self-attention layer's input shape
        # src_tgt_mask: (batch_size, tgt_seq_len, src_seq_len)

        # 1. Preprocess your input to the self-attention layer if necessary.
        # [Example implementation] Compute the mean of encoder_out along the sequence-length dimension
        # avg_enc_out shape: (batch_size, 1, d_model)
        avg_enc_out = torch.mean(encoder_out, dim=1, keepdim=True)
        # Expand avg_enc_out to the same size as tgt and concatenate along the last dimension
        # out_concat shape: (batch_size, tgt_seq_len, 2*d_model)
        out_concat = torch.cat((tgt, avg_enc_out.expand_as(tgt)), dim=-1)
        # !!! Make sure that tgt_mask aligns with the dimensions of the preprocessed input

        # 2. Feed the preprocessed input into the self-attention layer; be careful about the dimensions!
        # First MHA layer with the concatenation of tgt and avg_enc_out as the query
        # Shape: (batch_size, tgt_seq_len, d_model)
        out = self.residual1(out_concat, lambda out: self.self_attention(query=out, key=encoder_out, value=encoder_out,
                                                                         mask=tgt_mask))
        # 3. Process the output of the self-attention layer for the cross-attention layer
        # Your implementation here:

        # 4. Feed the (preprocessed) input into the cross-attention layer; be careful about the dimensions!
        # Second MHA layer with the output of the first MHA layer as the query
        # Shape: (batch_size, tgt_seq_len, d_model)
        out = self.residual2(out, lambda out: self.cross_attention(query=out, key=encoder_out, value=encoder_out,
                                                                   mask=src_tgt_mask))

        # 5. Feed the output of the cross-attention layer into the position-wise feed-forward network
        # Shape: (batch_size, tgt_seq_len, d_model)
        out = self.residual3(out, self.position_ff)

        # Return the output tensor
        # Shape: (batch_size, tgt_seq_len, d_model)
        return out
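A minimal wiring sketch for CustomDecoderBlock (illustrative only, not part of the commit): it assumes the MultiHeadAttentionLayer defined later in this commit, illustrative dimensions, and that ResidualConnectionLayer (not shown in this excerpt; see the sketch after the EncoderBlock file) is callable as layer(x, sublayer):

import torch
import torch.nn as nn

d_model, h = 64, 4
cross_attn = MultiHeadAttentionLayer(d_model, h,
                                     q_fc=nn.Linear(d_model, d_model),
                                     kv_fc=nn.Linear(d_model, d_model),
                                     out_fc=nn.Linear(d_model, d_model))
block = CustomDecoderBlock(cross_attention=cross_attn, norm=nn.LayerNorm(d_model),
                           position_ff=nn.Sequential(nn.Linear(d_model, 4 * d_model),
                                                     nn.ReLU(),
                                                     nn.Linear(4 * d_model, d_model)))

tgt = torch.randn(8, 1, d_model)        # single-query decoder input: (batch_size, tgt_seq_len==1, d_model)
enc_out = torch.randn(8, 10, d_model)   # encoder output over 10 source entities
src_tgt_mask = torch.ones(8, 1, 1, 10)  # (batch_size, 1, tgt_seq_len==1, src_seq_len)
out = block(tgt, enc_out, tgt_mask=None, src_tgt_mask=src_tgt_mask)  # (8, 1, d_model)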
@@ -0,0 +1,20 @@
import copy
import torch.nn as nn


class Encoder(nn.Module):

    def __init__(self, encoder_block, n_layer, norm):
        super(Encoder, self).__init__()
        self.n_layer = n_layer
        self.layers = nn.ModuleList([copy.deepcopy(encoder_block) for _ in range(self.n_layer)])
        self.norm = norm if norm is not None else nn.Identity()  # mirrors Decoder: nn.Identity() as the placeholder

    def forward(self, src, src_mask):
        # src: shape: (batch_size, src_seq_len, d_embed==d_embed_input)
        # src_mask: shape: (batch_size, 1, src_seq_len, src_seq_len)
        out = src
        for layer in self.layers:
            out = layer(out, src_mask)
        out = self.norm(out)
        return out  # shape: (batch_size, src_seq_len, d_embed)
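For reference, here is one way a src_mask of the documented shape could be built from a padding indicator. This is a hypothetical helper, not part of the commit; the project may construct its masks elsewhere:

import torch

def make_src_mask(is_real):
    # is_real: (batch_size, src_seq_len) bool, True for valid (non-padded) positions
    mask = is_real.unsqueeze(1) & is_real.unsqueeze(2)  # (batch_size, src_seq_len, src_seq_len)
    return mask.unsqueeze(1)  # (batch_size, 1, src_seq_len, src_seq_len), as Encoder.forward() expects

# Example: a batch of 2 where the last source position of the second sample is padding
is_real = torch.tensor([[True, True, True], [True, True, False]])
src_mask = make_src_mask(is_real)  # (2, 1, 3, 3)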
@@ -0,0 +1,23 @@
import copy
import torch.nn as nn

from transformer_modules.residual_connection_layer import ResidualConnectionLayer


class EncoderBlock(nn.Module):

    def __init__(self, self_attention, position_ff, norm, dr_rate=0):
        super(EncoderBlock, self).__init__()
        self.self_attention = self_attention
        self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
        self.position_ff = position_ff
        self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)

    def forward(self, src, src_mask):
        # src: shape: (batch_size, src_seq_len, d_embed==d_embed_input)
        # src_mask: shape: (batch_size, 1, src_seq_len, src_seq_len)
        out = src
        out = self.residual1(out, lambda out: self.self_attention(query=out, key=out, value=out, mask=src_mask))
        out = self.residual2(out, self.position_ff)
        # out: shape: (batch_size, src_seq_len, d_embed)
        return out
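The imported ResidualConnectionLayer is not shown in this commit excerpt. Based on how it is called, layer(x, sublayer), it presumably looks something like the following pre-norm sketch; this is an assumption about residual_connection_layer.py, and the actual file may differ (e.g., post-norm, or a different dropout placement):

import torch.nn as nn

class ResidualConnectionLayer(nn.Module):
    # Hypothetical reconstruction: pre-norm residual wrapper, out = x + dropout(sublayer(norm(x)))
    def __init__(self, norm, dr_rate=0):
        super(ResidualConnectionLayer, self).__init__()
        self.norm = norm
        self.dropout = nn.Dropout(p=dr_rate)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))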
@@ -0,0 +1 @@
These modules were copied from the project 'TARL_2.0' on July 17, 2023 (i.e., from the DTARL repository on my GitHub).
@@ -0,0 +1,67 @@
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttentionLayer(nn.Module):

    def __init__(self, d_model, h, q_fc, kv_fc, out_fc, dr_rate=0):
        super(MultiHeadAttentionLayer, self).__init__()
        self.d_model = d_model
        self.h = h

        # W^Q, W^K, W^V project the input query, key, and value to d_model dimensions
        self.q_fc = copy.deepcopy(q_fc)   # (d_embed_query, d_model)
        self.k_fc = copy.deepcopy(kv_fc)  # (d_embed_key, d_model)
        self.v_fc = copy.deepcopy(kv_fc)  # (d_embed_value, d_model)
        # TODO: Remove the copy.deepcopy() calls if they are not necessary
        #  (maybe not necessary for q_fc, but necessary for k_fc and v_fc so they get separate weights).
        #  Using copy.deepcopy() in __init__() is fine, but when a copy of the input object is not needed,
        #  it should be avoided for better performance.
        # self.q_fc = nn.Linear(q_fc.in_features, q_fc.out_features, bias=q_fc.bias is not None)
        # self.k_fc = nn.Linear(kv_fc.in_features, kv_fc.out_features, bias=kv_fc.bias is not None)
        # self.v_fc = nn.Linear(kv_fc.in_features, kv_fc.out_features, bias=kv_fc.bias is not None)

        # W^O projects the attention vectors to d_embed_MHA_out dimensions
        # (the desired output dim; usually equal to the query's embedding dim, so shapes are preserved)
        self.out_fc = out_fc  # (d_model, d_embed_MHA_out)
        self.dropout = nn.Dropout(p=dr_rate)  # if dr_rate == 0, effectively an identity mapping

    def calculate_attention(self, query, key, value, mask):
        # query: (n_batch, h, seq_len_query, d_k)
        # key, value: (n_batch, h, seq_len_key, d_k)
        # mask: (n_batch, 1, seq_len_query, seq_len_key)
        d_k = key.shape[-1]
        attention_score = torch.matmul(query, key.transpose(-2, -1))  # Q x K^T
        attention_score = attention_score / math.sqrt(d_k)  # (n_batch, h, seq_len_query, seq_len_key)
        if mask is not None:
            attention_score = attention_score.masked_fill(mask == 0, -1e9)
        attention_prob = F.softmax(attention_score, dim=-1)  # (n_batch, h, seq_len_query, seq_len_key)
        attention_prob = self.dropout(attention_prob)
        # (n_batch, h, seq_len_query, seq_len_key) x (n_batch, h, seq_len_key, d_k)
        out = torch.matmul(attention_prob, value)
        return out  # (n_batch, h, seq_len_query, d_k)

    def forward(self, *args, query, key, value, mask=None):
        # query: (n_batch, seq_len_query, d_embed_query)
        # key, value: (n_batch, seq_len_key, d_embed_key)
        # mask: must be broadcastable to (n_batch, h, seq_len_query, seq_len_key),
        #       e.g. (n_batch, 1, seq_len_query, seq_len_key)
        # return value: (n_batch, seq_len_query, d_embed_MHA_out); usually d_embed_MHA_out == d_embed_query
        n_batch = query.size(0)

        def transform(x, x_fc):
            out = x_fc(x)  # (n_batch, seq_len_x, d_embed_x) -> (n_batch, seq_len_x, d_model)
            out = out.view(n_batch, -1, self.h, self.d_model // self.h)  # (n_batch, seq_len_x, h, d_k)
            out = out.transpose(1, 2)
            return out  # (n_batch, h, seq_len_x, d_k)

        query = transform(query, self.q_fc)  # (n_batch, h, seq_len_query, d_k)
        key = transform(key, self.k_fc)      # (n_batch, h, seq_len_key, d_k)
        value = transform(value, self.v_fc)  # (n_batch, h, seq_len_key, d_k)

        out = self.calculate_attention(query, key, value, mask)  # (n_batch, h, seq_len_query, d_k)
        out = out.transpose(1, 2)  # (n_batch, seq_len_query, h, d_k)
        out = out.contiguous().view(n_batch, -1, self.d_model)  # (n_batch, seq_len_query, d_model)
        out = self.out_fc(out)

        return out  # (n_batch, seq_len_query, d_embed_MHA_out)
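A quick shape check of the layer (the dimensions below are illustrative assumptions). Note that query, key, value, and mask must be passed as keyword arguments because of the *args in forward():

import torch
import torch.nn as nn

d_model, h = 64, 4
mha = MultiHeadAttentionLayer(d_model, h,
                              q_fc=nn.Linear(32, d_model),   # d_embed_query = 32
                              kv_fc=nn.Linear(48, d_model),  # d_embed_key = d_embed_value = 48
                              out_fc=nn.Linear(d_model, 32)) # d_embed_MHA_out = 32
q = torch.randn(2, 5, 32)
kv = torch.randn(2, 7, 48)
mask = torch.ones(2, 1, 5, 7)  # broadcastable to (n_batch, h, seq_len_query, seq_len_key)
out = mha(query=q, key=kv, value=kv, mask=mask)
print(out.shape)  # torch.Size([2, 5, 32])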