Shape error when using torchtune.modules.RotaryPositionalEmbeddings #1157
Comments
This might be better filed in the torchtune repo?
Hi @Leo-Lifeblood, thanks for creating the issue. I believe you are seeing this error because you're using the …
@ebsmothers why do we apply RoPE on the expanded key tensor? It seems wasteful, as RoPE is applied on head_dim anyway.
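The reasoning behind this question can be checked directly. RoPE rotates each adjacent pair `(x[2i], x[2i+1])` of a head vector by an angle that depends only on the position and on `i`, never on which head the vector belongs to. The pure-Python sketch below uses hypothetical helpers (`rope_rotate`, `apply_rope`, `expand_kv`), not torchtune's code, and assumes base 10000. It shows that rotating the KV heads before expanding them for grouped-query attention gives the same result as rotating after expansion:

```python
import math


def rope_rotate(vec, pos, base=10000.0):
    """Rotate adjacent pairs (vec[2i], vec[2i+1]) of one head vector by pos * theta_i."""
    h_d = len(vec)  # assumed even
    out = []
    for i in range(h_d // 2):
        angle = pos * base ** (-2 * i / h_d)  # theta_i = base^(-2i / h_d)
        c, s = math.cos(angle), math.sin(angle)
        x0, x1 = vec[2 * i], vec[2 * i + 1]
        out += [x0 * c - x1 * s, x1 * c + x0 * s]
    return out


def apply_rope(k):
    """k is a nested list shaped [seq_len][n_heads][h_d]; every head uses the same angles."""
    return [[rope_rotate(head, pos) for head in heads] for pos, heads in enumerate(k)]


def expand_kv(k, q_per_kv):
    """Repeat each KV head q_per_kv times (repeat-interleave), as done for grouped-query attention."""
    return [[head for head in heads for _ in range(q_per_kv)] for heads in k]


# 5 positions, 2 KV heads, head_dim 4; each KV head serves 4 query heads
k = [[[float(p * 100 + h * 10 + d) for d in range(4)] for h in range(2)] for p in range(5)]
rope_then_expand = expand_kv(apply_rope(k), q_per_kv=4)
expand_then_rope = apply_rope(expand_kv(k, q_per_kv=4))
print(rope_then_expand == expand_then_rope)  # True
```

With 2 KV heads and 4 query heads per KV head, rotating first does 2 head-rotations per position instead of 8.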
rope = torchtune.modules.RotaryPositionalEmbeddings(32)
The RoPE implementation somehow ends up with 1/8th of the required batch dimension:
@Leo-Lifeblood in your first comment:
you are setting RoPE's dim to 32, while your input tensor has batch_size=32, seq_len=10, num_heads=4, head_dim=8. But as you can see here, RoPE's dim should be the …
I don't follow your second example, but I suspect it's due to the same reason: if your RoPE dim is off by a factor of 8 from what's in your input data, it makes sense that the dimension inferred from a view of the RoPE cache based on your input data would be off by a factor of 8 as well.
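The mismatch can be reproduced without running torch by counting elements on both sides of the `rope_cache.view(...)` call shown in the traceback in the bug report. The hypothetical `rope_view_sizes` helper assumes the cache is stored as `[max_seq_len, dim // 2, 2]`, that `seq_len <= max_seq_len`, and that the last dimension of `x` is even:

```python
from math import prod


def rope_view_sizes(x_shape, rope_dim):
    """Element counts on both sides of rope_cache.view(...); the view only succeeds if equal.

    Returns (cache_elems, view_elems).
    """
    seq_len = x_shape[1]
    # cache[:seq_len] has shape [seq_len, rope_dim // 2, 2]
    cache_elems = seq_len * (rope_dim // 2) * 2
    # x.reshape(*x.shape[:-1], -1, 2) splits the last dim into pairs
    xshaped = (*x_shape[:-1], x_shape[-1] // 2, 2)
    # rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    view_elems = prod((1, xshaped[1], 1, xshaped[3], 2))
    return cache_elems, view_elems


# Shapes from the comment above: batch 32, seq_len 10, 4 heads, head_dim 8
print(rope_view_sizes((32, 10, 4, 8), rope_dim=32))  # (320, 80): cache is 4x too big
print(rope_view_sizes((32, 10, 4, 8), rope_dim=8))   # (80, 80): view succeeds
```

Calling `rope_view_sizes((32, 128, 512), 512)` returns `(65536, 512)`, the same two numbers as in the reported `RuntimeError`.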
OK, I have tried what you suggested, but it has not worked. I have the code below, and I'll try to explain what's wrong with it from my perspective:
From this code I get:
In my understanding, the batch size should not change here.
@Leo-Lifeblood in your most recent example, I think this line is not correct:
For RoPE, your input tensor should have shape …
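The `[b, s, n_h, n_d // 2, 2]` comment in the torchtune traceback suggests that the layer expects the heads to be split out before it is called. In PyTorch, that corresponds to the `x.view(batch_size, seq_len, self.num_heads, self.head_dim)` line that is commented out in the reported `forward`. A small pre-flight check might look like the sketch below. It uses hypothetical helpers, is pure Python, and operates on shapes only:

```python
def split_heads_shape(batch, seq_len, d_model, num_heads):
    """Shape of x after x.view(batch, seq_len, num_heads, d_model // num_heads)."""
    if d_model % num_heads != 0:
        raise ValueError(f"d_model={d_model} is not divisible by num_heads={num_heads}")
    return (batch, seq_len, num_heads, d_model // num_heads)


def check_rope_input(x_shape, rope_dim):
    """Raise ValueError unless x_shape is [b, s, n_h, h_d] with h_d == rope_dim."""
    if len(x_shape) != 4:
        raise ValueError(
            f"expected a 4-D [b, s, n_h, h_d] input, got {tuple(x_shape)}; split heads first"
        )
    if x_shape[-1] != rope_dim:
        raise ValueError(f"RoPE dim {rope_dim} does not match head_dim {x_shape[-1]}")


# Settings from the bug report: d_model=512, num_heads=8, batch 32, seq_len 128
shape = split_heads_shape(32, 128, d_model=512, num_heads=8)
print(shape)                          # (32, 128, 8, 64)
check_rope_input(shape, rope_dim=64)  # passes: RoPE is built with the per-head dim
```

Checking shapes up front replaces the opaque `view` error with a message that names the mismatched dimension.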
🐛 Describe the bug
When using the positional encoding layer, strange shape errors occur that I don't have the time or insight to resolve.
import torch
import torchtune
#max_value = max(tokenizer_causal.vocab.values()) + 1
max_value = 50
class causallm(torch.nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_layers):
        super().__init__()
        # ... (layer definitions and forward() omitted)
model = causallm(d_model=512, num_heads=8, d_ff=2048, num_layers=2)
input_ids = torch.randint(0, max_value, (32, 128)) # input tensor with batch size 32 and sequence length 128
attention_mask = torch.ones((32, 128), dtype=torch.bool) # attention mask
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
outputs = model(input_ids, attention_mask=attention_mask)
RuntimeError Traceback (most recent call last)
in <cell line: 11>()
9
10 # Forward pass
---> 11 outputs = model(input_ids, attention_mask=attention_mask)
5 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
in forward(self, x, attention_mask)
23 #batch_size = x.shape[0]
24 #x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
---> 25 x = self.pos_embeddings(x)
26
27
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/torchtune/modules/position_embeddings.py in forward(self, x, input_pos)
109 # reshape the cache for broadcasting
110 # tensor has shape [1, s, 1, n_d // 2, 2]
--> 111 rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
112
113 # tensor has shape [b, s, n_h, n_d // 2, 2]
RuntimeError: shape '[1, 128, 1, 2, 2]' is invalid for input of size 65536
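Here are the numbers in the error message, worked through. This assumes that `x` reaches the layer as the un-split `[32, 128, 512]` embedding output, since the `view` on line 24 of the cell is commented out. It also assumes the cache is stored as `[max_seq_len, dim // 2, 2]`. Under those assumptions, the 65536-element input implies `dim = 65536 / 128 = 512`, so the layer was built with `d_model` rather than the per-head dim:

```latex
\begin{aligned}
&\textbf{Reported: } \operatorname{shape}(x) = [32, 128, 512],\ d = 512 \\
&\quad \text{cache: } 128 \cdot \tfrac{512}{2} \cdot 2 = 65536,
  \qquad \text{view: } 1 \cdot 128 \cdot 1 \cdot 2 \cdot 2 = 512 \\
&\textbf{Heads split, } d = 512\textbf{: } \operatorname{shape}(x) = [32, 128, 8, 64] \\
&\quad \text{cache: } 65536,
  \qquad \text{view: } 1 \cdot 128 \cdot 1 \cdot 32 \cdot 2 = 8192 \quad (\text{off by } 8 = n_h) \\
&\textbf{Heads split, } d = 64\textbf{: } \operatorname{shape}(x) = [32, 128, 8, 64] \\
&\quad \text{cache: } 128 \cdot \tfrac{64}{2} \cdot 2 = 8192,
  \qquad \text{view: } 8192 \quad (\text{match})
\end{aligned}
```

The middle case is consistent with the factor of 8 discussed in the comments.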
Versions
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.9
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.85+-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 2
On-line CPU(s) list: 0,1
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family: 6
Model: 79
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
Stepping: 0
BogoMIPS: 4399.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB (1 instance)
L1i cache: 32 KiB (1 instance)
L2 cache: 256 KiB (1 instance)
L3 cache: 55 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0,1
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Not affected; BHI: Vulnerable (Syscall hardening enabled)
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable
Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] torch==2.3.0+cu121
[pip3] torchao==0.1
[pip3] torchaudio==2.3.0+cu121
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.18.0
[pip3] torchtune==0.1.1
[pip3] torchvision==0.18.0+cu121
[pip3] triton==2.3.0
[conda] Could not collect