Commit

day2
I think I get it now; I'll take another look tomorrow.
lgX1123 committed Nov 7, 2023
1 parent 9490f71 commit 6e364d4
Showing 1 changed file with 71 additions and 5 deletions.
76 changes: 71 additions & 5 deletions notes_code.ipynb
@@ -9,14 +9,16 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.nn.init import xavier_uniform_\n",
"import torch.nn.functional as F"
"import torch.nn.functional as F\n",
"from torch.nn.modules.normalization import LayerNorm\n",
"import copy"
]
},
{
@@ -48,6 +50,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Todo: My layernorm\n",
"\n",
"class MultiheadAttention(nn.Module):\n",
" def __init__(self, embed_dim, num_heads, dropout):\n",
@@ -64,35 +67,98 @@
" self.in_proj_bias = torch.zeros(3 * embed_dim).to(device=device, dtype=dtype)\n",
"\n",
" def forward(self, query, key, value):\n",
" \n",
" pass\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"class TransformerEncoder(nn.Module):\n",
" def __init__(self, encoder_layer, num_encoder_layers, encoder_norm):\n",
" super().__init__()\n",
" self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_encoder_layers)])\n",
" self.num_layers = num_encoder_layers\n",
" self.norm = encoder_norm\n",
"\n",
" def forward(self, src, src_mask):\n",
" # Todo: mask\n",
" \n",
" output = src\n",
"\n",
" for layer in self.layers:\n",
" output = layer(output, src_mask)\n",
" \n",
" output = self.norm(output)\n",
" \n",
" return output\n",
"\n",
"class TransformerEncoderLayer(nn.Module):\n",
" def __init__(self, d_model, nhead, dim_feedforward, dropout):\n",
" super().__init__()\n",
" self.self_attn = MultiheadAttention(d_model, nhead, dropout)\n",
"\n",
" # feedforward\n",
" self.linear1 = nn.Linear(d_model, dim_feedforward)\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.linear2 = nn.Linear(dim_feedforward, d_model)\n",
"\n",
" self.norm1 = LayerNorm(d_model)\n",
" self.norm2 = LayerNorm(d_model)\n",
" self.dropout1 = nn.Dropout(dropout)\n",
" self.dropout2 = nn.Dropout(dropout)\n",
" self.activation = F.relu\n",
"\n",
" def forward(self, src, src_mask):\n",
" pass\n",
"\n",
" # Todo: mask\n",
"\n",
" x = src\n",
"\n",
" x = self.norm1(x + self._sa_block(self.norm1(x), src_mask))\n",
" x = self.norm2(x + self._ff_block(x))\n",
"\n",
" return x\n",
"\n",
" def _sa_block(self, x, attn_mask):\n",
" x = self.self_attn(x, x, x, attn_mask)\n",
" return self.dropout1(x)\n",
" \n",
" def _ff_block(self, x):\n",
" x = self.linear2(self.dropout(self.activation(self.linear1(x))))\n",
" return self.dropout2(x)\n",
" \n",
"\n",
"\n",
"class Transformer(nn.Module):\n",
" def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers, \n",
" dim_feedforward, dropout):\n",
" super().__init__()\n",
"\n",
" # encoder\n",
" encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)\n",
" encoder_norm = LayerNorm(d_model)\n",
" self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)\n",
"\n",
" #decoder\n",
" decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)\n",
" decoder_norm = LayerNorm(d_model)\n",
" self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)\n",
"\n",
" self.d_model = d_model\n",
" self.nhead = nhead\n",
"\n",
"\n",
" \n",
" def forward(self, src, tgt, src_mask, tgt_mask):\n",
" pass\n"
" \"\"\"\n",
" src: (seq_s, batch_size, embedding)\n",
" \"\"\"\n",
" memory = self.encoder(src, src_mask)\n",
" output = self.decoder(tgt, memory, tgt_mask)\n",
"\n",
" return output\n"
]
}
],
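The `MultiheadAttention.forward` left as a TODO above is the main open piece. Below is a minimal sketch of one way it could be completed, assuming the packed `(3 * embed_dim, embed_dim)` `in_proj_weight` / `in_proj_bias` layout hinted at in `__init__` and the `(seq_len, batch, embed_dim)` input shape noted in `Transformer.forward`'s docstring. The standalone helper name, the additive float `attn_mask` convention, and the omission of an output projection are assumptions for illustration, not part of the notebook.

```python
# Hypothetical sketch, not the notebook's implementation.
import math
import torch
import torch.nn.functional as F

def multihead_attention_forward(query, key, value, num_heads,
                                in_proj_weight, in_proj_bias,
                                dropout_p=0.0, attn_mask=None, training=True):
    # query/key/value: (seq_len, batch, embed_dim), as in the docstring above.
    tgt_len, batch, embed_dim = query.shape
    head_dim = embed_dim // num_heads

    # Packed input projection: the first/second/third embed_dim rows give W_q/W_k/W_v.
    w_q, w_k, w_v = in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = in_proj_bias.chunk(3, dim=0)
    q = F.linear(query, w_q, b_q)
    k = F.linear(key, w_k, b_k)
    v = F.linear(value, w_v, b_v)

    # Fold the heads into the batch dimension: (batch * num_heads, seq_len, head_dim).
    def split_heads(t):
        return t.reshape(-1, batch * num_heads, head_dim).transpose(0, 1)
    q, k, v = split_heads(q), split_heads(k), split_heads(v)

    # Scaled dot-product attention with an optional additive float mask.
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim)
    if attn_mask is not None:
        scores = scores + attn_mask
    attn = F.dropout(F.softmax(scores, dim=-1), p=dropout_p, training=training)
    out = torch.bmm(attn, v)

    # Merge the heads back to (seq_len, batch, embed_dim).
    # A real module would normally follow this with an output projection (out_proj).
    return out.transpose(0, 1).reshape(tgt_len, batch, embed_dim)
```

This also matches the `attn_mask` argument that `_sa_block` already passes to `self.self_attn`.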

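The two `# Todo: mask` comments still need a mask convention. One common choice, sketched below under the assumption that attention scores consume an additive float mask as in the sketch above, is to expand a boolean padding mask into that form; the helper name and shapes are illustrative, not from the notebook.

```python
# Hypothetical helper, assuming an additive float mask convention.
import torch

def padding_to_additive_mask(key_padding_mask, num_heads):
    # key_padding_mask: bool tensor (batch, src_len), True where a token is padding.
    batch, src_len = key_padding_mask.shape
    mask = torch.zeros(batch, 1, src_len)
    mask = mask.masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))
    # Repeat per head so it broadcasts against (batch * num_heads, tgt_len, src_len) scores.
    return mask.repeat_interleave(num_heads, dim=0)

# Example: a batch of 2 sequences of length 5, where the last two tokens
# of the second sequence are padding.
pad = torch.tensor([[False] * 5, [False, False, False, True, True]])
print(padding_to_additive_mask(pad, num_heads=4).shape)  # torch.Size([8, 1, 5])
```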