Forward hooks not called when fast path is used in TransformerEncoderLayer #128413

Closed
iibrahimli opened this issue Jun 11, 2024 · 1 comment
Labels
oncall: transformer/mha Issues related to Transformers and MultiheadAttention

Comments

@iibrahimli
Contributor

iibrahimli commented Jun 11, 2024

🐛 Describe the bug

When TransformerEncoderLayer is run in evaluation mode and a few conditions are met, the fast path (a fused, optimized implementation) is used instead of calling submodules such as MultiheadAttention (e.g. self.self_attn). As a result, any forward hooks or pre-hooks registered on those submodules are never called.

import torch
from torch import nn


cache = []


# forward hook to save output
def hook(module, inputs, output):
    cache.append(output[0].detach())


enc_layer = nn.TransformerEncoderLayer(d_model=32, nhead=8, batch_first=True)
enc_layer.eval()

# register hook to get the output of the self-attention layer
handle = enc_layer.self_attn.register_forward_hook(hook)

# input tensor of shape (batch_size, seq_len, d_model)
x = torch.randn(4, 6, 32)

# forward pass
with torch.inference_mode():
    output = enc_layer(x)

# output of the self-attention layer
assert len(cache) == 1, f"Expected 1 output, got {len(cache)}"
print(cache[0].shape)

# remove hook
handle.remove()

In this example we would expect cache to contain the self-attention output, but it does not: the fast path bypasses self.self_attn entirely, so the hook never fires and the assert fails. However, if any fast-path condition is violated, e.g. by using an odd number of attention heads such as nhead=3 (with a d_model divisible by it), the fast path is not taken and the hook is called as expected. Keeping the layer in training mode also avoids the fast path, as in the sketch below.
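As a temporary workaround (a sketch, not part of the original report): the fast path is only taken when the layer is in eval mode, so staying in training mode forces the regular module-by-module path and the hook fires. Dropout is set to 0.0 here so the output matches eval-mode behavior.

import torch
from torch import nn

cache = []

def hook(module, inputs, output):
    cache.append(output[0].detach())

# dropout=0.0 so training mode produces the same output as eval mode
enc_layer = nn.TransformerEncoderLayer(d_model=32, nhead=8, dropout=0.0, batch_first=True)
enc_layer.train()  # staying in training mode disables the fast path

handle = enc_layer.self_attn.register_forward_hook(hook)

x = torch.randn(4, 6, 32)
with torch.inference_mode():
    output = enc_layer(x)

assert len(cache) == 1  # the hook fires on the regular path
print(cache[0].shape)   # torch.Size([4, 6, 32])
handle.remove()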

Versions

PyTorch version: 2.2.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.23.1
Libc version: N/A

Python version: 3.12.0 (main, Oct 5 2023, 15:44:07) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1

Versions of relevant libraries:
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.2.0.post0
[pip3] torch==2.2.0
[pip3] torchmetrics==1.3.1
[pip3] torchviz==0.0.2
[conda] numpy 1.22.3 py39h64940a9_2 conda-forge
[conda] pytorch 1.11.0 cpu_py39h03f923b_1 conda-forge
[conda] torchdata 0.3.0 pyhd8ed1ab_0 conda-forge

cc @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki

@iibrahimli changed the title from "Forward hooks not called when fast path is used in TransformerEncoderLayer and TransformerDecoderLayer" to "Forward hooks not called when fast path is used in TransformerEncoderLayer" on Jun 11, 2024
@iibrahimli
Contributor Author

My proposed solution would be to fall back from the fast path whenever forward pre-hooks or forward hooks are registered on any submodule of the layer. I have started working on it: #128415
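For context, a minimal sketch of the kind of check this implies (illustrative only, not the actual change in the PR; it relies on the private _forward_hooks / _forward_pre_hooks dicts of nn.Module and ignores global hooks):

from torch import nn

def any_hooks_registered(layer: nn.Module) -> bool:
    # True if the layer or any of its submodules has a forward hook
    # or forward pre-hook registered
    return any(
        len(m._forward_hooks) > 0 or len(m._forward_pre_hooks) > 0
        for m in layer.modules()
    )

# Inside TransformerEncoderLayer.forward, the fast path would then only be
# taken when this check is False in addition to the existing conditions;
# otherwise the layer would fall back to the regular path so hooks are honored.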

@colesbury added the oncall: transformer/mha label on Jun 12, 2024