[AMD] Supporting fused kernels build using JIT (#1188)
* initial JIT load functions

* passing neox_args to load() as optional for easy testing

* modified headers for correct copyright statements
R0n12 committed Apr 1, 2024
1 parent 977448e commit 51a7de9
Showing 5 changed files with 149 additions and 4 deletions.
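
A minimal usage sketch of the new JIT entry point, mirroring the megatron/initialize.py and tests/model/test_fused_kernels.py changes below. This snippet is not part of the diff itself; neox_args is assumed to be an already-constructed NeoXArgs instance, and passing None is also supported (e.g. in the tests).

from megatron import fused_kernels

# JIT-compile the fused softmax and rotary-embedding kernels; neox_args is
# optional and is only used to limit verbose build output to rank 0.
fused_kernels.load(neox_args)

# Verify that the freshly built extension modules can be imported.
fused_kernels.load_fused_kernels()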
144 changes: 140 additions & 4 deletions megatron/fused_kernels/__init__.py
@@ -1,3 +1,6 @@
# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,14 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified from its original version
#

import os
import pathlib
import subprocess

from pathlib import Path

srcpath = Path(__file__).parent.absolute()
import torch
from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating different
# compilation commands (with different order of architectures) and
@@ -28,6 +32,138 @@
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def load(neox_args=None):

# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
if torch.version.hip is None:
_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
cpp_extension.CUDA_HOME
)
if int(bare_metal_major) >= 11:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if int(bare_metal_minor) >= 1:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_86,code=sm_86")
if int(bare_metal_minor) >= 4:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_87,code=sm_87")
if int(bare_metal_minor) >= 8:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_89,code=sm_89")
if int(bare_metal_major) >= 12:
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")

# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / "build"
_create_build_dir(buildpath)

# Determine verbosity
verbose = True if neox_args is None else (neox_args.rank == 0)

# Helper function to build the kernels.
def _cpp_extention_load_helper(
name, sources, extra_cuda_flags, extra_include_paths
):
if torch.version.hip is not None:
extra_cuda_cflags = ["-O3"] + extra_cuda_flags + cc_flag
else:
extra_cuda_cflags = (
["-O3", "-gencode", "arch=compute_70,code=sm_70", "--use_fast_math"]
+ extra_cuda_flags
+ cc_flag
)

return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=[
"-O3",
],
extra_cuda_cflags=extra_cuda_cflags,
extra_include_paths=extra_include_paths,
verbose=verbose,
)

# ==============
# Fused softmax.
# ==============

if torch.version.hip is not None:
extra_include_paths = [os.path.abspath(srcpath)]
else:
extra_include_paths = []

if torch.version.hip is not None:
extra_cuda_flags = [
"-D__HIP_NO_HALF_OPERATORS__=1",
"-D__HIP_NO_HALF_CONVERSIONS__=1",
]
else:
extra_cuda_flags = [
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
]

# Upper triangular softmax.
sources = [
srcpath / "scaled_upper_triang_masked_softmax.cpp",
srcpath / "scaled_upper_triang_masked_softmax_cuda.cu",
]
scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_upper_triang_masked_softmax_cuda",
sources,
extra_cuda_flags,
extra_include_paths,
)
# Masked softmax.
sources = [
srcpath / "scaled_masked_softmax.cpp",
srcpath / "scaled_masked_softmax_cuda.cu",
]
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags, extra_include_paths
)
# fused rope
sources = [
srcpath / "fused_rotary_positional_embedding.cpp",
srcpath / "fused_rotary_positional_embedding_cuda.cu",
]
fused_rotary_positional_embedding_cuda = _cpp_extention_load_helper(
"fused_rotary_positional_embedding_cuda",
sources,
extra_cuda_flags,
extra_include_paths,
)


def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]

return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")


def load_fused_kernels():
try:
import scaled_upper_triang_masked_softmax_cuda
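
For reference, a standalone illustration of how _get_cuda_bare_metal_version() above extracts the version from the nvcc -V banner. The banner text here is an assumed example for a CUDA 11.8 install, not captured output.

raw_output = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 11.8, V11.8.89\n"
)
output = raw_output.split()
release = output[output.index("release") + 1].split(".")  # ["11", "8,"]
bare_metal_major = release[0]     # "11"
bare_metal_minor = release[1][0]  # "8" (only the first digit of the minor version is kept)

With those values, load() would append -gencode flags for compute_80, compute_86, compute_87, and compute_89 on top of the baseline compute_70 target.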
2 changes: 2 additions & 0 deletions megatron/fused_kernels/scaled_masked_softmax_cuda.cu
@@ -18,7 +18,9 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
2 changes: 2 additions & 0 deletions megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
@@ -18,7 +18,9 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
1 change: 1 addition & 0 deletions megatron/initialize.py
@@ -59,6 +59,7 @@ def finish_mpu_init():
or neox_args.scaled_masked_softmax_fusion
or neox_args.rope_fusion
):
fused_kernels.load(neox_args)
fused_kernels.load_fused_kernels()

if neox_args.lazy_mpu_init:
4 changes: 4 additions & 0 deletions tests/model/test_fused_kernels.py
@@ -22,6 +22,7 @@
from transformers import BertTokenizer, GPT2Tokenizer
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from megatron.fused_kernels import load
import transformers

transformers.logging.set_verbosity(
@@ -33,6 +34,7 @@
reason="ModuleNotFoundError: No module named 'scaled_masked_softmax_cuda'"
)
def test_load_fused_kernels():
load()
try:
import scaled_masked_softmax_cuda
import scaled_upper_triang_masked_softmax_cuda
@@ -47,6 +49,7 @@ def test_load_fused_kernels():

@pytest.mark.xfail(reason="SystemExit: None")
def test_fused_softmax():
load()
from megatron.model.fused_softmax import FusedScaleMaskSoftmax, SoftmaxFusionTypes
from megatron.model.gpt2_model import (
gpt2_attention_mask_func as attention_mask_func,
@@ -149,6 +152,7 @@ def test_fused_softmax():

@pytest.mark.xfail(reason="SystemExit: None")
def test_fused_upper_triangle_mask_softmax():
load()
from megatron.model.gpt2_model import (
gpt2_attention_mask_func as attention_mask_func,
)