
Forward and backward Attention DNN operators implemented with LibTorch, cuDNN, and Eigen.




The part of the Transformer model implemented in this repo.

eigenMHA (eigenDNN vs cuDNN): Multi-head Attention Inference and Training implemented with Eigen.

To clone this repo,

git clone --recursive
cd eigenMHA
git clone  # clone eigen if necessary


In this repo, we use Eigen3 to implement the forward and backward passes of Multi-head Attention (MHA) in Transformer models. The repo has two branches: torch and cudnn.

The MHA implementations in this repo:

  1. a PyTorch MHA in that illustrates the MHA module we implement
  2. an Eigen MHA in both branches (with sources in ./src/eigenDNN.cpp and headers in ./include/eigenDNN.h)
  3. a LibTorch MHA in the torch branch, for comparison with eigenMHA
  4. a cuDNN MHA in the cudnn branch, for comparison with eigenMHA

branch torch

git checkout torch

In this branch, eigenDNN is compared against LibTorch running on the CPU. To build and run the project, first install LibTorch, which is needed for verification; see nnTest (a testing framework for training and running inference on Deep Neural Networks with your own library). Then run:

mkdir build && cd build
cmake ..
make -j4

branch cudnn

git checkout cudnn

In this branch, eigenDNN is compared against the Multi-head Attention APIs provided by cuDNN v8 (cudnn_samples_v8/multiHeadAttention).

To install cuDNN, see and . After copying the corresponding libraries and headers to the correct location,

mkdir build && cd build
cmake ..
make -j4

More specifically, eigenDNN reproduces what cuDNN does in the following MHA APIs.

For more details on the Attention APIs in cuDNN v8, see this CSDN article (in Chinese).

What are the variables of MHA in a Training Library?

Forward Pass of MHA

  1. Q, K, V input embeddings

$$ \mathbf{Q}_{in} \quad \mathbf{K}_{in} \quad \mathbf{V}_{in} $$

  1. Weights and biases for the linear layers of Q, K, V, and O.

$$ \mathbf{W}_{Q} \quad \mathbf{b}_{Q} $$

$$ \mathbf{W}_{K} \quad \mathbf{b}_{K} $$

$$ \mathbf{W}_{V} \quad \mathbf{b}_{V} $$

$$ \mathbf{W}_{O} \quad \mathbf{b}_{O} $$

  1. Intermediate variables
  2. Output and target

$$ \mathbf{O}_{out}\quad\mathbf{O}_{target} $$

The equations of MHA forward pass are as follows,

$$ \mathbf{Q} = \mathbf{Q}_{in}*\mathbf{W}_{Q}+\mathbf{b}_{Q} $$

$$ \mathbf{K} = \mathbf{K}_{in}*\mathbf{W}_{K}+\mathbf{b}_{K} $$

$$ \mathbf{V} = \mathbf{V}_{in}*\mathbf{W}_{V}+\mathbf{b}_{V} $$

$$ \mathbf{S} = \mathbf{Q}*\mathbf{K}^T $$

$$ \mathbf{P} = SoftmaxFWD(Mask(\mathbf{S}*\frac{1}{\sqrt{d}})) $$

$$ \mathbf{P} = DropoutFWD(\mathbf{P}) $$

$$ \mathbf{O}=\mathbf{P}*\mathbf{V} $$

$$ \mathbf{O}_{out} = \mathbf{O}*\mathbf{W}_{O}+\mathbf{b}_{O} $$
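The core of these forward equations (S, P, and O) can be sketched for a single head in plain C++. This is a minimal illustration under stated assumptions, not the eigenDNN code: one head, no batch dimension, no mask or dropout, and row-major std::vector storage instead of Eigen tensors. The Mat struct and attentionForward function are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical row-major matrix: element (r, c) is stored at data[r * cols + c].
struct Mat {
    std::size_t rows, cols;
    std::vector<float> data;
    Mat(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0f) {}
    float &at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
    float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Single-head attention without mask or dropout:
//   S = Q * K^T,  P = softmax(S / sqrt(d)) per row,  O = P * V
// Q: [Lq, d], K: [Lk, d], V: [Lk, dv]. Writes P [Lq, Lk] and returns O [Lq, dv].
Mat attentionForward(const Mat &Q, const Mat &K, const Mat &V, Mat &P) {
    assert(Q.cols == K.cols && K.rows == V.rows);
    const float scale = 1.0f / std::sqrt(static_cast<float>(Q.cols));
    P = Mat(Q.rows, K.rows);
    for (std::size_t i = 0; i < Q.rows; ++i) {
        // Scaled scores for row i; track the row max for a numerically stable softmax.
        float rowMax = -std::numeric_limits<float>::infinity();
        for (std::size_t j = 0; j < K.rows; ++j) {
            float s = 0.0f;
            for (std::size_t k = 0; k < Q.cols; ++k) s += Q.at(i, k) * K.at(j, k);
            P.at(i, j) = s * scale;
            if (P.at(i, j) > rowMax) rowMax = P.at(i, j);
        }
        float sum = 0.0f;
        for (std::size_t j = 0; j < K.rows; ++j) {
            P.at(i, j) = std::exp(P.at(i, j) - rowMax);
            sum += P.at(i, j);
        }
        for (std::size_t j = 0; j < K.rows; ++j) P.at(i, j) /= sum;
    }
    Mat O(Q.rows, V.cols);  // zero-initialized accumulator
    for (std::size_t i = 0; i < Q.rows; ++i)
        for (std::size_t j = 0; j < K.rows; ++j)
            for (std::size_t c = 0; c < V.cols; ++c)
                O.at(i, c) += P.at(i, j) * V.at(j, c);
    return O;
}
```

The projections that produce Q, K, V and the final O_out are ordinary linear layers and are left out of this sketch.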

MSE Loss

$$ loss = MSELoss(\mathbf{O}_{out},\mathbf{O}_{target}) $$

MSELoss also gives $\mathbf{grad\_O}_{out}$, the gradient of $\mathbf{O}_{out}$.

Backward Pass of MHA

  1. Gradients for output (from LayerNorm)

$$ \mathbf{grad\_O}_{out} $$

  1. Gradients for the intermediate variables
  2. Gradients for the forward input

$$ \mathbf{grad\_Q}_{in} \quad \mathbf{grad\_K}_{in} \quad \mathbf{grad\_V}_{in} $$

  1. Gradients of the weights and biases

$$ \mathbf{grad\_W}_{Q} \quad \mathbf{grad\_b}_{Q} $$

$$ \mathbf{grad\_W}_{K} \quad \mathbf{grad\_b}_{K} $$

$$ \mathbf{grad\_W}_{V} \quad \mathbf{grad\_b}_{V} $$

$$ \mathbf{grad\_W}_{O} \quad \mathbf{grad\_b}_{O} $$

The equations of MHA backward pass are as follows,

$$ \mathbf{grad\_O} = \mathbf{grad\_O}_{out}*\mathbf{W}_{O}^T $$

$$ \mathbf{grad\_W}_{O} = \mathbf{O}^T*\mathbf{grad\_O}_{out} $$

$$ \mathbf{grad\_b}_{O} = colsum(\mathbf{grad\_O}_{out}) $$

$$ \mathbf{grad\_P} = \mathbf{grad\_O}*\mathbf{V}^T $$

$$ \mathbf{grad\_V} = \mathbf{P}^T*\mathbf{grad\_O} $$

$$ \mathbf{grad\_P} = DropoutBWD(\mathbf{grad\_P}) $$

$$ \mathbf{grad\_S} = SoftmaxBWD(\mathbf{P},\mathbf{grad\_P})*\frac{1}{\sqrt{d}} $$

$$ \mathbf{grad\_Q} = \mathbf{grad\_S}*\mathbf{K} $$

$$ \mathbf{grad\_K} = \mathbf{grad\_S}^T*\mathbf{Q} $$

$$ \mathbf{grad\_Q}_{in} = \mathbf{grad\_Q}*\mathbf{W}_{Q}^T $$

$$ \mathbf{grad\_W}_{Q} = \mathbf{Q}_{in}^T*\mathbf{grad\_Q} $$

$$ \mathbf{grad\_b}_{Q} = colsum(\mathbf{grad\_Q}) $$

$$ \mathbf{grad\_K}_{in} = \mathbf{grad\_K}*\mathbf{W}_{K}^T $$

$$ \mathbf{grad\_W}_{K} = \mathbf{K}_{in}^T*\mathbf{grad\_K} $$

$$ \mathbf{grad\_b}_{K} = colsum(\mathbf{grad\_K}) $$

$$ \mathbf{grad\_V}_{in} = \mathbf{grad\_V}*\mathbf{W}_{V}^T $$

$$ \mathbf{grad\_W}_{V} = \mathbf{V}_{in}^T*\mathbf{grad\_V} $$

$$ \mathbf{grad\_b}_{V} = colsum(\mathbf{grad\_V}) $$
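SoftmaxBWD is not expanded in the equations above. Writing $\mathbf{X} = Mask(\mathbf{S}*\frac{1}{\sqrt{d}})$ for the softmax input, so that $\mathbf{P}$ (before dropout) is the row-wise softmax of $\mathbf{X}$, the standard Jacobian-vector product follows from $\partial \mathbf{P}_{ik}/\partial \mathbf{X}_{ij} = \mathbf{P}_{ik}(\delta_{kj} - \mathbf{P}_{ij})$:

$$ \mathbf{grad\_X}_{ij} = \mathbf{P}_{ij}\Big(\mathbf{grad\_P}_{ij} - \sum_{k}\mathbf{grad\_P}_{ik}\,\mathbf{P}_{ik}\Big) $$

The factor $\frac{1}{\sqrt{d}}$ in the equation for $\mathbf{grad\_S}$ then comes from the chain rule through the scaling, and masked positions, where $\mathbf{P}_{ij} = 0$, receive zero gradient.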

The components of the MHA Training Library

MSE Loss Function

The loss function, where backpropagation starts, is a basic component of any DL system.

MSE Loss.
eidnnStatus_t eidnnMSELoss(
    eidnnHandle_t handle,
    const Tensor<float, 3> &output, 
    const Tensor<float, 3> &target,
    Tensor<float, 0> &loss,
    Tensor<float, 3> &d_loss);
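A minimal sketch of what such an MSE loss computes, on flat std::vector buffers instead of Eigen tensors. The mean reduction over all elements (PyTorch's default for MSELoss) is an assumption; eidnnMSELoss may normalize differently. The function name mseLoss is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical flat-buffer MSE loss that also returns its gradient.
// Assumes mean reduction over all N elements:
//   loss      = (1/N) * sum_i (output_i - target_i)^2
//   d_loss[i] = (2/N) * (output_i - target_i)
float mseLoss(const std::vector<float> &output,
              const std::vector<float> &target,
              std::vector<float> &d_loss) {
    assert(output.size() == target.size() && !output.empty());
    const float n = static_cast<float>(output.size());
    d_loss.assign(output.size(), 0.0f);
    float loss = 0.0f;
    for (std::size_t i = 0; i < output.size(); ++i) {
        const float diff = output[i] - target[i];
        loss += diff * diff;
        d_loss[i] = 2.0f * diff / n;
    }
    return loss / n;
}
```

The returned d_loss plays the role of the gradient of O_out that seeds the backward pass.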


cuDNN has no dedicated API for linear layers.

In eigenDNN, we have

eidnnStatus_t eidnnLinearForward(eidnnHandle_t handle,
                    const Tensor<float, 3>& x, // data
                    const Tensor<float, 2>& w, // weight
                    const Tensor<float, 1>& bias, // bias
                    Tensor<float, 3>& y);
eidnnStatus_t eidnnLinearBackward(eidnnHandle_t handle,
                     const Tensor<float, 3>& dy,
                     const Tensor<float, 3>& x,
                     const Tensor<float, 2>& w,
                     Tensor<float, 3>& dx, // gradient of input data
                     Tensor<float, 2>& dw, // accumulated gradient of weight
                     Tensor<float, 1>& dbias); // accumulated gradient of bias
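A minimal sketch of the linear forward and backward passes in plain C++, following the equations above (y = x*W + b, dx = dy*W^T, dW = x^T*dy, db = colsum(dy)). It assumes the batch and sequence dimensions are flattened into rows and that w is stored as [in, out]; both layouts are assumptions, and linearForward/linearBackward are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major linear layer, y = x * w + bias (as in Q = Q_in * W_Q + b_Q).
// x: [rows, in], w: [in, out], bias: [out], y: [rows, out].
void linearForward(const std::vector<float> &x, const std::vector<float> &w,
                   const std::vector<float> &bias, std::size_t rows,
                   std::size_t in, std::size_t out, std::vector<float> &y) {
    assert(x.size() == rows * in && w.size() == in * out && bias.size() == out);
    y.assign(rows * out, 0.0f);
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t o = 0; o < out; ++o) {
            float acc = bias[o];
            for (std::size_t i = 0; i < in; ++i) acc += x[r * in + i] * w[i * out + o];
            y[r * out + o] = acc;
        }
}

// dx = dy * w^T is overwritten; dw += x^T * dy and dbias += colsum(dy)
// are accumulated, mirroring the "accumulated gradient" comments in the API.
void linearBackward(const std::vector<float> &dy, const std::vector<float> &x,
                    const std::vector<float> &w, std::size_t rows,
                    std::size_t in, std::size_t out, std::vector<float> &dx,
                    std::vector<float> &dw, std::vector<float> &dbias) {
    assert(dy.size() == rows * out && x.size() == rows * in);
    assert(w.size() == in * out && dw.size() == in * out && dbias.size() == out);
    dx.assign(rows * in, 0.0f);
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t o = 0; o < out; ++o) {
            const float g = dy[r * out + o];
            dbias[o] += g;
            for (std::size_t i = 0; i < in; ++i) {
                dx[r * in + i] += g * w[i * out + o];
                dw[i * out + o] += x[r * in + i] * g;
            }
        }
}
```

If eigenDNN instead stores weights as [out, in] (the PyTorch convention), the indexing of w would be transposed; the layout here was chosen only to match the equations in this README.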


$$ C = \beta * C + \alpha*Op_c(MatMul(Op_a(A),Op_b(B))) $$

, where $Op_m(M)$ indicates whether matrix $M$ is transposed in the forward pass.

cuDNN has no dedicated API for matrix multiplication.

In eigenDNN, we have

eidnnStatus_t eidnnStridedBatchedGemmForward(
    eidnnHandle_t handle,
    float alpha,
    float beta,
    bool trans_A, // Op_a
    bool trans_B, // Op_b
    bool trans_C, // Op_c
    const Tensor<float, 4> &A, 
    const Tensor<float, 4> &B, 
    Tensor<float, 4> &C);
eidnnStatus_t eidnnStridedBatchedGemmBackward(
    eidnnHandle_t handle,
    float alpha,
    float beta,
    bool trans_A, // Op_a
    bool trans_B, // Op_b
    bool trans_C, // Op_c
    const Tensor<float, 4> &A, // A
    const Tensor<float, 4> &B, // B
    const Tensor<float, 4> &d_C, // gradient of C
    Tensor<float, 4> &d_A, // gradient of A
    Tensor<float, 4> &d_B); // gradient of B
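The GEMM equation above can be illustrated with a naive strided-batched loop. This is a plain C++ sketch, not eigenDNN's Eigen-based implementation: it assumes row-major storage, flattens the batch and head dimensions of the 4-D tensors into a single batch count, and the name stridedBatchedGemm is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical strided-batched GEMM on row-major buffers:
//   C[b] = beta * C[b] + alpha * Op_c(Op_a(A[b]) * Op_b(B[b]))
// Op_a(A[b]) is [m, k] and Op_b(B[b]) is [k, n]. A[b] is stored as [k, m] when
// transA is set, B[b] as [n, k] when transB is set, and C[b] as [n, m] when transC is set.
void stridedBatchedGemm(float alpha, float beta, bool transA, bool transB,
                        bool transC, const std::vector<float> &A,
                        const std::vector<float> &B, std::vector<float> &C,
                        std::size_t batch, std::size_t m, std::size_t n,
                        std::size_t k) {
    assert(A.size() == batch * m * k && B.size() == batch * k * n &&
           C.size() == batch * m * n);
    for (std::size_t b = 0; b < batch; ++b) {
        const float *a = A.data() + b * m * k;   // per-batch stride m*k
        const float *bm = B.data() + b * k * n;  // per-batch stride k*n
        float *c = C.data() + b * m * n;         // per-batch stride m*n
        for (std::size_t i = 0; i < m; ++i)
            for (std::size_t j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (std::size_t p = 0; p < k; ++p) {
                    const float av = transA ? a[p * m + i] : a[i * k + p];
                    const float bv = transB ? bm[j * k + p] : bm[p * n + j];
                    acc += av * bv;
                }
                // Op_c: element (i, j) of the product goes to (j, i) when transposed.
                float &dst = transC ? c[j * m + i] : c[i * n + j];
                dst = beta * dst + alpha * acc;
            }
    }
}
```

For instance, S = Q*K^T corresponds to transA = false and transB = true. The backward pass follows from the same rule: ignoring the transpose flags, for C = alpha*A*B one has d_A = alpha*d_C*B^T and d_B = alpha*A^T*d_C.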


cuDNN has the following APIs for the softmax operation.

In eigenDNN, we have

eidnnStatus_t eidnnSoftmaxForward(eidnnHandle_t handle,
                    eidnnSoftmaxAlgorithm_t algo,
                    eidnnSoftmaxMode_t mode,
                    const Tensor<float, 4>& x,
                    Tensor<float, 4>& y);
eidnnStatus_t eidnnSoftmaxBackward(eidnnHandle_t handle,
                     eidnnSoftmaxAlgorithm_t algo,
                     eidnnSoftmaxMode_t mode,
                     const Tensor<float, 4>& y,
                     const Tensor<float, 4>& dy,
                     Tensor<float, 4>& dx);
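A minimal sketch of softmax forward and backward over the last dimension, on row-major std::vector buffers. It subtracts the row maximum for numerical stability; which eidnnSoftmaxAlgorithm_t and eidnnSoftmaxMode_t values this corresponds to is an assumption, and softmaxForward/softmaxBackward are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical softmax over the last dimension of `rows` rows of length `cols`.
void softmaxForward(const std::vector<float> &x, std::size_t rows,
                    std::size_t cols, std::vector<float> &y) {
    assert(x.size() == rows * cols && cols > 0);
    y.resize(x.size());
    for (std::size_t r = 0; r < rows; ++r) {
        const float *xr = x.data() + r * cols;
        float *yr = y.data() + r * cols;
        const float mx = *std::max_element(xr, xr + cols);  // stability shift
        float sum = 0.0f;
        for (std::size_t c = 0; c < cols; ++c) {
            yr[c] = std::exp(xr[c] - mx);
            sum += yr[c];
        }
        for (std::size_t c = 0; c < cols; ++c) yr[c] /= sum;
    }
}

// Backward needs only the forward output y and the upstream gradient dy:
//   dx[c] = y[c] * (dy[c] - sum_k dy[k] * y[k])   (per row)
void softmaxBackward(const std::vector<float> &y, const std::vector<float> &dy,
                     std::size_t rows, std::size_t cols, std::vector<float> &dx) {
    assert(y.size() == rows * cols && dy.size() == y.size());
    dx.resize(y.size());
    for (std::size_t r = 0; r < rows; ++r) {
        float dot = 0.0f;
        for (std::size_t c = 0; c < cols; ++c) dot += dy[r * cols + c] * y[r * cols + c];
        for (std::size_t c = 0; c < cols; ++c)
            dx[r * cols + c] = y[r * cols + c] * (dy[r * cols + c] - dot);
    }
}
```

Because the backward formula uses only y and dy, the input x does not need to be kept, which matches the arguments of eidnnSoftmaxBackward.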


cuDNN has the following APIs for the dropout operation.

In eigenDNN, we have

// dropout rate, 
// pointer to memory space of states (allocated by forward pass), 
// size of memory space in bytes (calculated by forward pass), 
// random seed
using eidnnDropoutDescriptor_t = std::tuple<float, void*, size_t, unsigned long long>; 
eidnnStatus_t eidnnDropoutForward(
    eidnnHandle_t                       handle,
    eidnnDropoutDescriptor_t      &dropoutDesc,
    const Tensor<float, 4>         &x, // input data
    Tensor<float, 4>               &y); // input data after dropout
eidnnStatus_t eidnnDropoutBackward(
    eidnnHandle_t                   handle,
    const eidnnDropoutDescriptor_t  dropoutDesc,
    const Tensor<float, 4>       &dy, // gradient of dropout output data
    Tensor<float, 4>             &dx); // gradient of dropout input data
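A minimal sketch of dropout forward and backward, assuming inverted dropout (kept elements scaled by 1/(1 - rate)), the usual convention. The DropoutState struct is a hypothetical stand-in for eidnnDropoutDescriptor_t: instead of an opaque state buffer it stores the keep mask explicitly, and std::mt19937_64 is used purely for illustration, so the masks will not match cuDNN's.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical descriptor stand-in: rate, seed, and the per-element keep mask
// saved by the forward pass for use in the backward pass.
struct DropoutState {
    float rate;
    unsigned long long seed;
    std::vector<unsigned char> mask;  // 1 = kept, 0 = dropped
};

// Inverted dropout: zero each element with probability `rate` and scale
// survivors by 1 / (1 - rate) so the expected value is unchanged.
void dropoutForward(DropoutState &st, const std::vector<float> &x,
                    std::vector<float> &y) {
    assert(st.rate >= 0.0f && st.rate < 1.0f);
    std::mt19937_64 gen(st.seed);
    std::bernoulli_distribution keep(1.0 - st.rate);
    const float scale = 1.0f / (1.0f - st.rate);
    st.mask.resize(x.size());
    y.resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        st.mask[i] = keep(gen) ? 1 : 0;
        y[i] = st.mask[i] ? x[i] * scale : 0.0f;
    }
}

// Backward reuses the saved mask: dx = dy * mask / (1 - rate).
void dropoutBackward(const DropoutState &st, const std::vector<float> &dy,
                     std::vector<float> &dx) {
    assert(dy.size() == st.mask.size());
    const float scale = 1.0f / (1.0f - st.rate);
    dx.resize(dy.size());
    for (std::size_t i = 0; i < dy.size(); ++i)
        dx[i] = st.mask[i] ? dy[i] * scale : 0.0f;
}
```

Saving the mask in the forward pass is what lets the backward pass drop exactly the same positions, which is the same reason the descriptor carries state memory allocated by the forward pass.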








