Fused Rotary Embeddings (fixed) (#1108)
* Create fused_rotary_positional_embedding.cpp

* Create fused_rotary_positional_embedding.h

* Create fused_rotary_positional_embedding_cuda.cu

* Update fused_rotary_positional_embedding.h

Ports the fix from NVIDIA/apex#1750 into this branch.

* Update neox_args.py

* Update setup.py

* Update initialize.py

* Update setup.py

* Update __init__.py

* Update test_fused_kernels.py

* Update setup.py

* Create fused_rope.py

* Update fused_rotary_positional_embedding.h

* Update fused_rotary_positional_embedding.cpp

* Update fused_rotary_positional_embedding.cpp

* Update transformer.py

* Update transformer.py

Just checked, and this should work for bf16; at least, the reason I originally thought it wouldn't does not actually apply.

* Update transformer.py

* Create 125M_fused_rope.yml

* Update 125M_fused_rope.yml

* Update transformer.py

Add `self.rope_fusion = neox_args.rope_fusion` so that `ParallelSelfAttention` knows whether we're using rope fusion (a rough sketch of this gating appears after the changed-file summary below).

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

* Fix fused rope

Just needed to bring in the latest headers/sources, and call into it the right way from transformer.py.

* Add rope_fusion arg to all ymls

---------

Co-authored-by: Stella Biderman <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
Co-authored-by: Yang Zhang <[email protected]>
5 people committed Jan 5, 2024
1 parent 98716eb commit 77605ca
Showing 30 changed files with 1,374 additions and 30 deletions.
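
Before the per-file diffs, a quick orientation: the commit message describes gating the rotary-embedding path on `neox_args.rope_fusion` inside `ParallelSelfAttention` and calling into the fused kernel from `transformer.py`. The snippet below is a minimal, illustrative sketch of that kind of gating; the function names, the `fused_fn` hook, and the `[seq, batch, heads, head_dim]` layout are assumptions for illustration, not the actual `transformer.py` code.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: split the last dim in half and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(query, key, cos, sin, fused_fn=None):
    # With rope fusion enabled, a wrapper around the compiled
    # fused_rotary_positional_embedding extension would be passed as fused_fn;
    # otherwise fall back to the plain PyTorch formulation.
    if fused_fn is not None:
        return fused_fn(query), fused_fn(key)
    return (query * cos + rotate_half(query) * sin,
            key * cos + rotate_half(key) * sin)

# Unfused usage with a [seq, batch, heads, head_dim] layout.
seq, batch, heads, dim = 8, 2, 4, 16
q = torch.randn(seq, batch, heads, dim)
k = torch.randn(seq, batch, heads, dim)
angles = torch.randn(seq, 1, 1, dim)
q_rot, k_rot = apply_rope(q, k, angles.cos(), angles.sin())
```

In the fused case, `fused_fn` would be something like the autograd wrapper sketched after the `megatron/fused_kernels/__init__.py` hunk below.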
1 change: 1 addition & 0 deletions configs/1-3B.yml
@@ -20,6 +20,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/125M-json.yml
@@ -15,6 +15,7 @@

"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

"init_method": "small_init",
"output_layer_init_method": "wang_init",
1 change: 1 addition & 0 deletions configs/125M.yml
@@ -20,6 +20,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/13B.yml
@@ -20,6 +20,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/175B.yml
@@ -20,6 +20,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/19M.yml
@@ -15,6 +15,7 @@

"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/2-7B.yml
@@ -20,6 +20,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/20B.yml
@@ -30,6 +30,7 @@
"output_layer_parallelism": "column",
"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": true,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/350M.yml
@@ -20,6 +20,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/49M.yml
@@ -18,6 +18,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/6-7B.yml
@@ -20,6 +20,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/760M.yml
@@ -20,6 +20,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/800M.yml
@@ -15,6 +15,7 @@

"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
1 change: 1 addition & 0 deletions configs/bf16_125M.yml
@@ -18,6 +18,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,


# optimizer settings
1 change: 1 addition & 0 deletions configs/bnb_125M.yml
@@ -19,6 +19,7 @@
# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,


# optimizer settings
8 changes: 8 additions & 0 deletions configs/neox_arguments.md
@@ -447,6 +447,14 @@ Model Arguments



- **rope_fusion**: bool

Default = False

Enable rotary embedding fusion.



- **fp16_lm_cross_entropy**: bool

Default = False
1 change: 1 addition & 0 deletions configs/slurm_125M.yml
@@ -11,6 +11,7 @@
"no_weight_tying": true,
"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": true,
"rope_fusion": false,
"optimizer": {
"type": "Adam",
"params": {
1 change: 1 addition & 0 deletions megatron/fused_kernels/__init__.py
@@ -32,6 +32,7 @@ def load_fused_kernels():
try:
import scaled_upper_triang_masked_softmax_cuda
import scaled_masked_softmax_cuda
import fused_rotary_positional_embedding
except (ImportError, ModuleNotFoundError) as e:
print("\n")
print(e)
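
The new `import fused_rotary_positional_embedding` above is the compiled extension whose sources this commit adds; the commit also creates `megatron/fused_kernels/fused_rope.py` to wrap it for autograd. The sketch below shows roughly what such a wrapper looks like; the class name and structure are assumptions (the real fused_rope.py is ported from Apex and may differ), but the `forward`/`backward` entry points and their `(tensor, freqs, transpose_output)` signatures match the pybind11 bindings shown further down.

```python
import torch

class FusedRoPEFunc(torch.autograd.Function):
    # Illustrative autograd wrapper over the compiled extension.

    @staticmethod
    def forward(ctx, t: torch.Tensor, freqs: torch.Tensor,
                transpose_output: bool = False) -> torch.Tensor:
        # Imported lazily so the extension is only required once
        # load_fused_kernels() has built it.
        import fused_rotary_positional_embedding as fused_rope_cuda
        output = fused_rope_cuda.forward(t, freqs, transpose_output)
        ctx.save_for_backward(freqs)
        ctx.transpose_output = transpose_output
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        import fused_rotary_positional_embedding as fused_rope_cuda
        (freqs,) = ctx.saved_tensors
        grad_input = fused_rope_cuda.backward(
            grad_output, freqs, ctx.transpose_output
        )
        # No gradient flows to freqs or to the boolean flag.
        return grad_input, None, None
```

Call sites would use `FusedRoPEFunc.apply(t, freqs)`; the boolean selects whether the output is materialized in a transposed memory layout.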
139 changes: 139 additions & 0 deletions megatron/fused_kernels/fused_rotary_positional_embedding.cpp
@@ -0,0 +1,139 @@
/* coding=utf-8
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <torch/extension.h>

namespace fused_rope {

torch::Tensor fwd_cuda(const torch::Tensor& input,
const torch::Tensor& freqs,
const bool transpose_output);

torch::Tensor bwd_cuda(const torch::Tensor& output_grads,
const torch::Tensor& freqs,
const bool transpose_output);

torch::Tensor fwd_cached_cuda(const torch::Tensor& input,
const torch::Tensor& cos,
const torch::Tensor& sin,
const bool transpose_output);

torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads,
const torch::Tensor& cos,
const torch::Tensor& sin,
const bool transpose_output);

torch::Tensor fwd(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output)
{
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(input.size(0) == freqs.size(0),
"expected input and freqs tensor have the same sequence length");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(3) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");

return fwd_cuda(input, freqs, transpose_output);
}

torch::Tensor bwd(const torch::Tensor& output_grads,
const at::Tensor& freqs,
const bool transpose_output)
{
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(output_grads.size(0) == freqs.size(0),
"expected output_grads and freqs tensor have the same sequence length");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");

return bwd_cuda(output_grads, freqs, transpose_output);
}

torch::Tensor fwd_cached(const at::Tensor& input,
const at::Tensor& cos,
const at::Tensor& sin,
const bool transpose_output)
{
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
TORCH_CHECK(input.size(0) == cos.size(0),
"expected input and cos tensor have the same sequence length");
TORCH_CHECK(input.size(0) == sin.size(0),
"expected input and sin tensor have the same sequence length");
TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
"expected the second and third dims of the cos tensor equal 1");
TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
"expected the second and third dims of the sin tensor equal 1");
TORCH_CHECK(cos.size(3) == sin.size(3), "expected cos and sin tensor have the same last dim");
TORCH_CHECK(input.size(3) >= cos.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the cos tensor");
TORCH_CHECK(cos.scalar_type() == sin.scalar_type(),
"expected cos and sin tensor have the same dtype");

return fwd_cached_cuda(input, cos, sin, transpose_output);
}

torch::Tensor bwd_cached(const torch::Tensor& output_grads,
const at::Tensor& cos,
const at::Tensor& sin,
const bool transpose_output)
{
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
TORCH_CHECK(output_grads.size(0) == cos.size(0),
"expected output_grads and cos tensor have the same sequence length");
TORCH_CHECK(output_grads.size(0) == sin.size(0),
"expected output_grads and sin tensor have the same sequence length");
TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
"expected the second and third dims of the cos tensor equal 1");
TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
"expected the second and third dims of the sin tensor equal 1");
TORCH_CHECK(cos.size(3) == sin.size(3), "expected cos and sin tensor have the same last dim");
TORCH_CHECK(output_grads.size(3) >= cos.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the cos tensor");
TORCH_CHECK(cos.scalar_type() == sin.scalar_type(),
"expected cos and sin tensor have the same dtype");

return bwd_cached_cuda(output_grads, cos, sin, transpose_output);
}

} // end namespace fused_rope

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward", &fused_rope::fwd, "Fused Rotary Positional Embedding -- Forward.");
m.def("backward", &fused_rope::bwd, "Fused Rotary Positional Embedding -- Backward.");
m.def("forward_cached",
&fused_rope::fwd_cached,
"Fused Rotary Positional Embedding Cached -- Forward.");
m.def("backward_cached",
&fused_rope::bwd_cached,
"Fused Rotary Positional Embedding Cached -- Backward.");
}
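
The `TORCH_CHECK`s above pin down the calling convention for the uncached path: a 4D input whose first dim (sequence length) matches `freqs`, a 4D float32 `freqs` tensor with singleton second and third dims, and `input.size(3) >= freqs.size(3)` so that partial rotary is allowed. The snippet below is a forward-only sketch of a call that satisfies these checks, assuming the extension has been built and a CUDA device is available; the `[seq, batch, heads, head_dim]` layout, the base-10000 frequency table, and the quarter-width rotary dim are illustrative conventions, not requirements of the kernel.

```python
import torch
import fused_rotary_positional_embedding as fused_rope_cuda  # built by megatron/fused_kernels

seq, batch, heads, head_dim = 2048, 4, 16, 128
rot_dim = head_dim // 4  # rotate only part of the head dim (allowed by the checks)

# 4D input; dim 0 must equal freqs.size(0).
t = torch.randn(seq, batch, heads, head_dim, dtype=torch.bfloat16, device="cuda")

# freqs: 4D, singleton dims 1 and 2, float32, last dim <= input's last dim.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2, device="cuda").float() / rot_dim))
angles = torch.outer(torch.arange(seq, device="cuda").float(), inv_freq)
freqs = torch.cat((angles, angles), dim=-1).view(seq, 1, 1, rot_dim)

out = fused_rope_cuda.forward(t, freqs, False)  # transpose_output=False
assert out.shape == t.shape
```

Training code would go through an autograd wrapper such as the one sketched earlier so that the bound `backward` runs during backprop.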