Upload the model
JongYun-Kim committed Jul 22, 2023
1 parent 3d5035d commit 18f5cf0
Showing 13 changed files with 1,082 additions and 0 deletions.
445 changes: 445 additions & 0 deletions models/lazy_allocator.py

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions models/transformer_modules/decoder.py
@@ -0,0 +1,31 @@
import copy
import torch.nn as nn


class Decoder(nn.Module):

def __init__(self, decoder_block, n_layer, norm):
super(Decoder, self).__init__()
self.n_layer = n_layer
self.layers = nn.ModuleList([copy.deepcopy(decoder_block) for _ in range(self.n_layer)])
        self.norm = norm if norm is not None else nn.Identity()  # fallback; note this may break backward compatibility with checkpoints saved when norm was None
        # Prefer passing nn.Identity() explicitly instead of None; it is more readable

def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
out = tgt
for layer in self.layers:
out = layer(out, encoder_out, tgt_mask, src_tgt_mask)
out = self.norm(out)
return out # shape: (batch_size, tgt_seq_len, d_embed)


class DecoderPlaceholder(nn.Module):

def __init__(self, decoder_block=None, n_layer=None, norm=None, *args, **kwargs):
super(DecoderPlaceholder, self).__init__()
# We don't store or use the provided arguments,
# but we accept them to ensure interface compatibility.

def forward(self, tgt, encoder_out=None, tgt_mask=None, src_tgt_mask=None, *args, **kwargs):
return tgt # Just returning the input tgt as is
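
For orientation, here is a minimal sketch of how this Decoder might be assembled from the DecoderBlock and MultiHeadAttentionLayer added in the same commit. The layer sizes, the nn.Sequential feed-forward, and the import paths are illustrative assumptions, not part of the committed files.

import torch
import torch.nn as nn

from transformer_modules.decoder import Decoder
from transformer_modules.decoder_block import DecoderBlock
from transformer_modules.multi_head_attention_layer import MultiHeadAttentionLayer

d_model, h, d_ff, n_layer = 128, 8, 512, 3  # hypothetical sizes


def make_mha():
    # Each attention layer gets its own projection matrices.
    return MultiHeadAttentionLayer(
        d_model=d_model, h=h,
        q_fc=nn.Linear(d_model, d_model), kv_fc=nn.Linear(d_model, d_model),
        out_fc=nn.Linear(d_model, d_model), dr_rate=0.1,
    )


block = DecoderBlock(
    self_attention=make_mha(), cross_attention=make_mha(),
    position_ff=nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)),
    norm=nn.LayerNorm(d_model), dr_rate=0.1,
)
decoder = Decoder(decoder_block=block, n_layer=n_layer, norm=nn.LayerNorm(d_model))

tgt = torch.randn(2, 5, d_model)      # (batch_size, tgt_seq_len, d_embed)
enc_out = torch.randn(2, 7, d_model)  # (batch_size, src_seq_len, d_embed)
out = decoder(tgt, enc_out, tgt_mask=None, src_tgt_mask=None)  # (2, 5, d_model)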

132 changes: 132 additions & 0 deletions models/transformer_modules/decoder_block.py
@@ -0,0 +1,132 @@
import copy
import torch
import torch.nn as nn

from transformer_modules.residual_connection_layer import ResidualConnectionLayer, IdentityResidualLayer, \
NoResidualButSameForward


class DecoderBlock(nn.Module):

def __init__(self, self_attention, cross_attention, position_ff, norm, dr_rate=0):
super(DecoderBlock, self).__init__()
self.self_attention = self_attention
self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
self.cross_attention = cross_attention
self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
self.position_ff = position_ff
self.residual3 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)

def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
out = tgt
out = self.residual1(out, lambda out: self.self_attention(query=out, key=out, value=out, mask=tgt_mask))
out = self.residual2(out, lambda out: self.cross_attention(query=out, key=encoder_out, value=encoder_out, mask=src_tgt_mask))
out = self.residual3(out, self.position_ff)
return out


class CustomDecoderBlock(nn.Module):

def __init__(self, cross_attention, norm, self_attention=None, position_ff=None, dr_rate=0, efficient=False):
super(CustomDecoderBlock, self).__init__()
# Initialize ResidualConnectionLayers
if norm is None:
            # If norm is None (or nn.Identity()) and efficient is True, consider removing the dropout layers as well
norm = nn.Identity()

if efficient:
# self.residual1 = NoResidualButSameForward()
self.residual2 = NoResidualButSameForward()
self.residual3 = NoResidualButSameForward()
else:
# self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
self.residual3 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)

# Initialize self-attention layer and cross-attention layer
self.self_attention = self_attention
if self.self_attention is not None: print("[DecoderBlock] self_attention is NOT None, but not used here!!!!!!!")
if cross_attention is None:
raise ValueError("cross_attention is None; you must use a cross attention in this implementation")
self.cross_attention = cross_attention
# Initialize position-wise feed-forward network
self.position_ff = position_ff

def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
# tgt: (batch_size, tgt_seq_len, d_embed_context)
# encoder_out: (batch_size, src_seq_len, d_embed_input)
# tgt_mask: (batch_size, 1, tgt_seq_len, tgt_seq_len)
# src_tgt_mask: (batch_size, 1, tgt_seq_len==1, src_seq_len)

        # Cross-attention (MHA) layer, with tgt as the query and encoder_out as the key/value
# Shape: (batch_size, tgt_seq_len, d_model)
tgt = self.residual2(tgt, lambda tgt: self.cross_attention(query=tgt, key=encoder_out, value=encoder_out,
mask=src_tgt_mask))
        # Position-wise feed-forward network, applied only if position_ff was provided
# Shape: (batch_size, tgt_seq_len, d_model)
if self.position_ff is not None:
tgt = self.residual3(tgt, self.position_ff)

# Return the output tensor
# Shape: (batch_size, tgt_seq_len==1, d_embed_context)
return tgt
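
As an illustration, CustomDecoderBlock in its "efficient", cross-attention-only configuration might be constructed as follows; the sizes and the all-ones mask are arbitrary assumptions, shown only to make the expected shapes concrete.

import torch
import torch.nn as nn

from transformer_modules.decoder_block import CustomDecoderBlock
from transformer_modules.multi_head_attention_layer import MultiHeadAttentionLayer

d_model = 128  # hypothetical embedding size
cross_attn = MultiHeadAttentionLayer(
    d_model=d_model, h=8,
    q_fc=nn.Linear(d_model, d_model), kv_fc=nn.Linear(d_model, d_model),
    out_fc=nn.Linear(d_model, d_model),
)
block = CustomDecoderBlock(
    cross_attention=cross_attn,
    norm=None,         # falls back to nn.Identity() above
    position_ff=None,  # skip the feed-forward branch entirely
    efficient=True,    # use NoResidualButSameForward (presumably no skip connections)
)

tgt = torch.randn(2, 1, d_model)       # (batch_size, tgt_seq_len==1, d_embed_context)
enc_out = torch.randn(2, 7, d_model)   # (batch_size, src_seq_len, d_embed_input)
src_tgt_mask = torch.ones(2, 1, 1, 7)  # (batch_size, 1, tgt_seq_len, src_seq_len); all ones = nothing masked
out = block(tgt, enc_out, tgt_mask=None, src_tgt_mask=src_tgt_mask)  # (2, 1, d_model)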


class ProbablyAlmostUniversalDecoderBlockLol(nn.Module): # Maybe act as a universal decoder block?
def __init__(self, cross_attention, norm, self_attention=None, position_ff=None, dr_rate=0):
super(ProbablyAlmostUniversalDecoderBlockLol, self).__init__()
self.cross_attention = cross_attention

if self_attention is not None:
self.self_attention = self_attention
self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
else:
self.self_attention = lambda query, key, value, mask: query
self.residual1 = IdentityResidualLayer()

self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)

if position_ff is not None:
self.position_ff = position_ff
self.residual3 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
else:
self.position_ff = lambda x: x
self.residual3 = IdentityResidualLayer()

def forward(self, tgt, encoder_out, src_tgt_mask, tgt_mask=None):
        # tgt: (batch_size, tgt_seq_len, d_model)  # d_model is used interchangeably with d_embed here; TODO: fix this later
# encoder_out: (batch_size, src_seq_len, d_model)
# tgt_mask: (batch_size, tgt_seq_len, tgt_seq_len); may vary; check your self-attention layer's input shape
# src_tgt_mask: (batch_size, tgt_seq_len, src_seq_len)

# 1. Please preprocess your input to the self-attention layer if necessary
# [Example implementation] Compute the mean of encoder_out along the sequence length dimension
# avg_enc_out shape: (batch_size, 1, d_model)
avg_enc_out = torch.mean(encoder_out, dim=1, keepdim=True)
# Expand avg_enc_out to the same size as tgt and concatenate along the last dimension
# tgt_concat shape: (batch_size, tgt_seq_len, 2*d_model)
out_concat = torch.cat((tgt, avg_enc_out.expand_as(tgt)), dim=-1)
        # !!! Make sure the tgt_mask you pass in aligns with the dimensions of this preprocessed input

# 2. Put the preprocessed input into the self-attention layer; Be careful about the dimensions!
        # First MHA layer, with the concatenation of tgt and avg_enc_out as the query
# Shape: (batch_size, tgt_seq_len, d_model)
out = self.residual1(out_concat, lambda out: self.self_attention(query=out, key=encoder_out, value=encoder_out,
mask=tgt_mask))
# 3. Please process the output of the self-attention layer to the cross-attention layer
# Your implementation here:

# 4. Put the (preprocessed) input into the cross-attention layer; Be careful about the dimensions!
# Second MHA layer with query as the output of the first MHA layer
# Shape: (batch_size, tgt_seq_len, d_model)
out = self.residual2(out, lambda out: self.cross_attention(query=out, key=encoder_out, value=encoder_out,
mask=src_tgt_mask))

# 5. Put the output of the cross-attention layer into the position-wise feed-forward network
# Position-wise feed-forward network
# Shape: (batch_size, tgt_seq_len, d_model)
out = self.residual3(out, self.position_ff)

# Return the output tensor
# Shape: (batch_size, tgt_seq_len, d_model)
return out
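
The residual wrappers imported at the top of this file (ResidualConnectionLayer, IdentityResidualLayer, NoResidualButSameForward) live in residual_connection_layer.py, which is not visible in this diff. To clarify the residualN(x, sub_layer) call pattern used above, here is a plausible pre-norm sketch of that interface; the committed implementation may differ in detail (e.g. post-norm ordering).

import torch.nn as nn


class ResidualConnectionLayer(nn.Module):
    # Assumed pre-norm residual wrapper: x + dropout(sub_layer(norm(x)))
    def __init__(self, norm, dr_rate=0):
        super().__init__()
        self.norm = norm
        self.dropout = nn.Dropout(p=dr_rate)

    def forward(self, x, sub_layer):
        return x + self.dropout(sub_layer(self.norm(x)))


class NoResidualButSameForward(nn.Module):
    # Assumed variant: same (x, sub_layer) interface, but the skip connection is dropped
    def forward(self, x, sub_layer):
        return sub_layer(x)

IdentityResidualLayer is presumably another variant that keeps the same two-argument interface.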
20 changes: 20 additions & 0 deletions models/transformer_modules/encoder.py
@@ -0,0 +1,20 @@
import copy
import torch.nn as nn


class Encoder(nn.Module):

def __init__(self, encoder_block, n_layer, norm):
super(Encoder, self).__init__()
self.n_layer = n_layer
self.layers = nn.ModuleList([copy.deepcopy(encoder_block) for _ in range(self.n_layer)])
        self.norm = norm if norm is not None else nn.Identity()  # mirror the Decoder's fallback when no norm is given

def forward(self, src, src_mask):
# src: shape: (batch_size, src_seq_len, d_embed==d_embed_input)
# src_mask: shape: (batch_size, 1, src_seq_len, src_seq_len)
out = src
for layer in self.layers:
out = layer(out, src_mask)
out = self.norm(out)
return out # shape: (batch_size, src_seq_len, d_embed)
23 changes: 23 additions & 0 deletions models/transformer_modules/encoder_block.py
@@ -0,0 +1,23 @@
import copy
import torch.nn as nn

from transformer_modules.residual_connection_layer import ResidualConnectionLayer


class EncoderBlock(nn.Module):

def __init__(self, self_attention, position_ff, norm, dr_rate=0):
super(EncoderBlock, self).__init__()
self.self_attention = self_attention
self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)
self.position_ff = position_ff
self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), dr_rate)

def forward(self, src, src_mask):
# src: shape: (batch_size, src_seq_len, d_embed==d_embed_input)
# src_mask: shape: (batch_size, 1, src_seq_len, src_seq_len)
out = src
out = self.residual1(out, lambda out: self.self_attention(query=out, key=out, value=out, mask=src_mask))
out = self.residual2(out, self.position_ff)
# out: shape: (batch_size, src_seq_len, d_embed)
return out
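
A matching sketch for the encoder side, wiring EncoderBlock into Encoder; as in the decoder example earlier, the sizes, feed-forward module, and mask are illustrative assumptions.

import torch
import torch.nn as nn

from transformer_modules.encoder import Encoder
from transformer_modules.encoder_block import EncoderBlock
from transformer_modules.multi_head_attention_layer import MultiHeadAttentionLayer

d_model, h, d_ff = 128, 8, 512  # hypothetical sizes
self_attn = MultiHeadAttentionLayer(
    d_model=d_model, h=h,
    q_fc=nn.Linear(d_model, d_model), kv_fc=nn.Linear(d_model, d_model),
    out_fc=nn.Linear(d_model, d_model),
)
block = EncoderBlock(
    self_attention=self_attn,
    position_ff=nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)),
    norm=nn.LayerNorm(d_model),
)
encoder = Encoder(encoder_block=block, n_layer=3, norm=nn.LayerNorm(d_model))

src = torch.randn(2, 7, d_model)   # (batch_size, src_seq_len, d_embed)
src_mask = torch.ones(2, 1, 7, 7)  # (batch_size, 1, src_seq_len, src_seq_len); all ones = nothing masked
out = encoder(src, src_mask)       # (2, 7, d_model)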
1 change: 1 addition & 0 deletions models/transformer_modules/from_where.txt
@@ -0,0 +1 @@
The modules were copied from the project 'TARL_2.0' on July 17, 2023 (i.e., from the DTARL repository on my GitHub).
67 changes: 67 additions & 0 deletions models/transformer_modules/multi_head_attention_layer.py
@@ -0,0 +1,67 @@
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttentionLayer(nn.Module):

def __init__(self, d_model, h, q_fc, kv_fc, out_fc, dr_rate=0):
super(MultiHeadAttentionLayer, self).__init__()
self.d_model = d_model
self.h = h

# W^Q, W^K, W^V transform the input query, key, value to d_model dimension
self.q_fc = copy.deepcopy(q_fc) # (d_embed_query, d_model)
self.k_fc = copy.deepcopy(kv_fc) # (d_embed_key, d_model)
self.v_fc = copy.deepcopy(kv_fc) # (d_embed_value, d_model)
        # TODO: Remove copy.deepcopy() where it is not necessary.
        #       (probably not needed for q_fc, but needed for k_fc and v_fc so they do not share weights)
        #  deepcopy() in __init__() is fine; but when a copy of the input module is not needed,
        #  skipping it avoids the extra memory and initialization cost.
# self.q_fc = nn.Linear(q_fc.in_features, q_fc.out_features, bias=q_fc.bias is not None)
# self.k_fc = nn.Linear(kv_fc.in_features, kv_fc.out_features, bias=kv_fc.bias is not None)
# self.v_fc = nn.Linear(kv_fc.in_features, kv_fc.out_features, bias=kv_fc.bias is not None)

        # W^O projects the concatenated attention heads to d_embed_MHA_out (the desired output dim; usually equal to the query's embedding dim, so the shape is preserved)
        self.out_fc = out_fc  # (d_model, d_embed_MHA_out)
        self.dropout = nn.Dropout(p=dr_rate)  # with dr_rate == 0 this is effectively an identity mapping (negligible cost)

def calculate_attention(self, query, key, value, mask):
# query: (n_batch, h, seq_len_query, d_k)
# key, value: (n_batch, h, seq_len_key, d_k)
# mask: (n_batch, 1, seq_len_query, seq_len_key)
d_k = key.shape[-1]
attention_score = torch.matmul(query, key.transpose(-2, -1)) # Q x K^T
attention_score = attention_score / math.sqrt(d_k) # (n_batch, h, seq_len_query, seq_len_key)
if mask is not None:
attention_score = attention_score.masked_fill(mask == 0, -1e9)
attention_prob = F.softmax(attention_score, dim=-1) # (n_batch, h, seq_len_query, seq_len_key)
attention_prob = self.dropout(attention_prob)
        out = torch.matmul(attention_prob, value)  # (n_batch, h, seq_len_query, seq_len_key) x (n_batch, h, seq_len_key, d_k)
return out # (n_batch, h, seq_len_query, d_k)

def forward(self, *args, query, key, value, mask=None):
# query: (n_batch, seq_len_query, d_embed_query)
# key, value: (n_batch, seq_len_key, d_embed_key)
# mask: (n_batch, seq_len_query, seq_len_key)
        # return value: (n_batch, seq_len_query, d_embed_MHA_out); the output usually has the same shape as the query
n_batch = query.size(0)

def transform(x, x_fc):
out = x_fc(x) # (n_batch, seq_len_x, d_embed_x) -> (n_batch, seq_len_x, d_model)
out = out.view(n_batch, -1, self.h, self.d_model//self.h) # (n_batch, seq_len_x, h, d_k )
out = out.transpose(1, 2)
return out # (n_batch, h, seq_len_x, d_k)

query = transform(query, self.q_fc) # (n_batch, h, seq_len_query, d_k)
key = transform(key, self.k_fc) # (n_batch, h, seq_len_key, d_k)
value = transform(value, self.v_fc) # (n_batch, h, seq_len_key, d_k)

out = self.calculate_attention(query, key, value, mask) # (n_batch, h, seq_len_query, d_k)
out = out.transpose(1, 2) # (n_batch, seq_len_query, h, d_k)
out = out.contiguous().view(n_batch, -1, self.d_model) # (n_batch, seq_len_query, d_model)
out = self.out_fc(out)

return out # (n_batch, seq_len_query, d_embed_MHA_out); d_embed_MHA_out == d_embed_query in most cases.
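
To make the shape bookkeeping in this file concrete, here is a small standalone usage sketch; the sizes are arbitrary assumptions, and the query and key embedding dims are deliberately different to show that q_fc and kv_fc may project from different spaces.

import torch
import torch.nn as nn

from transformer_modules.multi_head_attention_layer import MultiHeadAttentionLayer

d_model, h = 64, 4  # d_k = d_model // h = 16
mha = MultiHeadAttentionLayer(
    d_model=d_model, h=h,
    q_fc=nn.Linear(32, d_model),    # d_embed_query = 32
    kv_fc=nn.Linear(48, d_model),   # d_embed_key = d_embed_value = 48
    out_fc=nn.Linear(d_model, 32),  # project back to d_embed_MHA_out = 32
)

query = torch.randn(2, 5, 32)   # (n_batch, seq_len_query, d_embed_query)
kv = torch.randn(2, 9, 48)      # (n_batch, seq_len_key, d_embed_key)
mask = torch.ones(2, 1, 5, 9)   # (n_batch, 1, seq_len_query, seq_len_key); broadcast over the h heads
out = mha(query=query, key=kv, value=kv, mask=mask)
print(out.shape)                # torch.Size([2, 5, 32])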