Merge remote-tracking branch 'origin/bf16' into zero-3
StellaAthena committed May 10, 2021
2 parents 8e383e6 + 1f388cf commit 39cc90a
Showing 20 changed files with 1,648 additions and 403 deletions.
16 changes: 0 additions & 16 deletions megatron/fp16/__init__.py

This file was deleted.

56 changes: 0 additions & 56 deletions megatron/fp16/fp16.py

This file was deleted.

144 changes: 76 additions & 68 deletions megatron/fused_kernels/__init__.py
@@ -13,19 +13,86 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import pathlib
import subprocess
import os

from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating
# different compilation commands (with different order of architectures)
# and leading to recompilation of fused kernels.
# set it to empty string to avoid recompilation
# and assign arch flags explicitly in extra_cuda_cflags below
# Setting this param to a list has a problem of generating different
# compilation commands (with different order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicitly in
# extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""

def get_cuda_bare_metal_version(cuda_dir):

def load_fused_kernels(neox_args):
# Check if CUDA 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = _get_cuda_bare_metal_version(
cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')

# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'
_create_build_dir(buildpath)

# Helper function to build the kernels.
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=['-O3', ],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'--use_fast_math'] + extra_cuda_flags + cc_flag,
verbose=(neox_args.rank == 0)
)

# ==============
# Fused softmax.
# ==============

if neox_args.scaled_upper_triang_masked_softmax_fusion or neox_args.scaled_masked_softmax_fusion:

print('loading fused softmax kernels - this may take a minute or two...')

extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda']

# Upper triangular softmax.
sources = [srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_upper_triang_masked_softmax_cuda",
sources, extra_cuda_flags)

# Masked softmax.
sources = [srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu']
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
"scaled_masked_softmax_cuda", sources, extra_cuda_flags)

# Neox isn't using this yet - but for when we do:
# # =================================
# # Mixed precision fused layer norm.
# # =================================
# print('Loading fused layer norm...')
# extra_cuda_flags = ['-maxrregcount=50']
# sources = [srcpath / 'layer_norm_cuda.cpp',
# srcpath / 'layer_norm_cuda_kernel.cu']
# fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
# "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)


def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
output = raw_output.split()
@@ -36,69 +103,10 @@ def get_cuda_bare_metal_version(cuda_dir):

return raw_output, bare_metal_major, bare_metal_minor

def create_build_dir(buildpath):

def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")

def load_scaled_upper_triang_masked_softmax_fusion_kernel():

print(f'\nLoading scaled_upper_triang_masked_softmax fusion kernel (this may take a minute or two)...')

# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')

srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'

create_build_dir(buildpath)

scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
name='scaled_upper_triang_masked_softmax_cuda',
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + cc_flag)

def load_scaled_masked_softmax_fusion_kernel():

print(f'\nLoading scaled_masked_softmax fusion kernel (this may take a minute or two)...')

# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')

srcpath = pathlib.Path(__file__).parent.absolute()
buildpath = srcpath / 'build'

create_build_dir(buildpath)

scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
name='scaled_masked_softmax_cuda',
sources=[srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu'],
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + cc_flag)
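
The two standalone loaders above (load_scaled_upper_triang_masked_softmax_fusion_kernel and load_scaled_masked_softmax_fusion_kernel) are consolidated into the single load_fused_kernels(neox_args) entry point added in this diff. A minimal usage sketch, assuming a neox_args object that exposes rank and the two softmax-fusion flags; the actual call site is not part of this commit, so the surrounding names are illustrative:

# Illustrative sketch only; the actual call site is not part of this diff.
from types import SimpleNamespace

from megatron import fused_kernels

# Hypothetical stand-in for the real NeoX argument object: the loader only
# reads the rank (to gate verbose build output) and the two fusion flags.
neox_args = SimpleNamespace(
    rank=0,
    scaled_upper_triang_masked_softmax_fusion=True,
    scaled_masked_softmax_fusion=False,
)

# JIT-compiles the requested CUDA kernels into megatron/fused_kernels/build/
# on first use; subsequent calls reuse the cached build directory.
fused_kernels.load_fused_kernels(neox_args)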
31 changes: 31 additions & 0 deletions megatron/fused_kernels/compat.h
@@ -0,0 +1,31 @@
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/* This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */



#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
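
The header's two macros let the same extension source build against both older and newer PyTorch releases: TORCH_CHECK falls back to the pre-1.2 AT_CHECK spelling, and DATA_PTR resolves to data_ptr or data depending on whether VERSION_GE_1_3 is defined at compile time. A hedged C++ sketch of how a kernel source might use them; the function and tensor names are illustrative, not taken from this commit:

// Illustrative usage of the compat.h macros; not part of this commit.
#include <torch/extension.h>
#include "compat.h"

void check_and_read(const at::Tensor& input) {
    // Expands to AT_CHECK on PyTorch versions that predate TORCH_CHECK.
    TORCH_CHECK(input.is_contiguous(), "input must be contiguous");

    // Expands to input.data_ptr<float>() when VERSION_GE_1_3 is defined
    // (PyTorch >= 1.3) and to the older input.data<float>() otherwise.
    const float* ptr = input.DATA_PTR<float>();
    (void)ptr;  // placeholder: a real kernel would hand ptr to CUDA code
}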
