Skip to content

Commit

Permalink
improve python setup script
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes committed Mar 19, 2023
1 parent 436fa10 commit cbedd7f
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 50 deletions.
2 changes: 2 additions & 0 deletions approxcdf/cpp_wrapper.pyx
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#cython: language_level=3

### BLAS and LAPACK
from scipy.linalg.cython_blas cimport (
ddot, daxpy, dgemv, dgemm, dtrmm, dtrsm
Expand Down
142 changes: 94 additions & 48 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@
import sys, os, subprocess
from os import environ

# Module-level flag read after setup() to decide whether to warn about a
# build without OpenMP. It starts optimistic and is only ever flipped to False.
found_omp = True


def set_omp_false():
    """Record that no working OpenMP flag combination was found.

    Rebinds the module-level ``found_omp`` to ``False``. A function is used
    so that code running inside the build-extension class can change the
    module global.
    """
    global found_omp
    found_omp = False

## Modify this to make the output of the compilation tests more verbose.
## The tests stay silent unless one of the verbose spellings below appears
## as its own command-line argument.
silent_tests = {"verbose", "-verbose", "--verbose"}.isdisjoint(sys.argv)

## Workaround for python<=3.9 on windows:
## 'os.EX_OK' may be missing there, so fall back to the conventional zero.
EXIT_SUCCESS = getattr(os, "EX_OK", 0)

## https://stackoverflow.com/questions/724664/python-distutils-how-to-get-a-compiler-that-is-going-to-be-used
class build_ext_subclass( build_ext ):
def build_extensions(self):
Expand All @@ -13,9 +29,9 @@ def build_extensions(self):

if is_msvc:
for e in self.extensions:
e.extra_compile_args = ['/openmp', '/O2', '/std:c++14']
e.extra_compile_args = ['/openmp', '/O2', '/GL', '/std:c++14']
else:
if not self.check_for_variable_dont_set_march() and not self.is_arch_in_cflags():
if not self.check_for_variable_dont_set_march() and not self.check_cflags_contain_arch():
self.add_march_native()
self.set_cxxstd()
self.add_openmp_linkage()
Expand All @@ -28,25 +44,37 @@ def build_extensions(self):

build_ext.build_extensions(self)

def check_cflags_contain_arch(self):
    """Tell whether the user already chose a target architecture via the environment.

    Both ``CFLAGS`` and ``CXXFLAGS`` are inspected: the extension is built
    with the C++ compiler, so an architecture pinned only in ``CXXFLAGS``
    must also prevent ``-march=native`` from being appended on top of it.

    Matching is by plain substring, so e.g. ``-mavx`` also matches
    ``-mavx512f`` and ``-march`` matches ``-march=haswell``.

    Returns
    -------
    bool
        ``True`` if any architecture-selecting flag appears in either
        variable, ``False`` otherwise (including when neither is set).
    """
    arch_flags = (
        "-march", "-mcpu", "-mtune", "-msse", "-msse2", "-msse3",
        "-mssse3", "-msse4", "-msse4a", "-msse4.1", "-msse4.2",
        "-mavx", "-mavx2", "-mavx512",
    )
    for env_var in ("CFLAGS", "CXXFLAGS"):
        flags = os.environ.get(env_var, "")
        if any(flag in flags for flag in arch_flags):
            return True
    return False

def check_for_variable_dont_set_march(self):
    """Return ``True`` when ``DONT_SET_MARCH`` is defined in the environment.

    Only the presence of the variable matters; its value (even an empty
    string) is ignored.
    """
    return os.environ.get("DONT_SET_MARCH") is not None

def add_march_native(self):
arg_march_native = "-march=native"
arg_mcpu_native = "-mcpu=native"
if self.test_supports_compile_arg(arg_march_native):
for e in self.extensions:
e.extra_compile_args.append(arg_march_native)
elif self.test_supports_compile_arg(arg_mcpu_native):
for e in self.extensions:
e.extra_compile_args.append(arg_mcpu_native)
args_march_native = ["-march=native", "-mcpu=native"]
for arg_march_native in args_march_native:
if self.test_supports_compile_arg(arg_march_native):
for e in self.extensions:
e.extra_compile_args.append(arg_march_native)
break

def add_link_time_optimization(self):
arg_lto = "-flto"
if self.test_supports_compile_arg(arg_lto):
for e in self.extensions:
e.extra_compile_args.append(arg_lto)
e.extra_link_args.append(arg_lto)
args_lto = ["-flto=auto", "-flto"]
for arg_lto in args_lto:
if self.test_supports_compile_arg(arg_lto):
for e in self.extensions:
e.extra_compile_args.append(arg_lto)
e.extra_link_args.append(arg_lto)
break

def add_no_math_errno(self):
arg_fnme = "-fno-math-errno"
Expand All @@ -60,6 +88,12 @@ def add_no_trapping_math(self):
for e in self.extensions:
e.extra_compile_args.append(arg_fntm)

def add_O3(self):
O3 = "-O3"
if self.test_supports_compile_arg(O3):
for e in self.extensions:
e.extra_compile_args.append(O3)

def set_cxxstd(self):
arg_std17 = "-std=c++17"
arg_std14 = "-std=gnu++14"
Expand All @@ -75,45 +109,59 @@ def set_cxxstd(self):

def add_openmp_linkage(self):
arg_omp1 = "-fopenmp"
arg_omp2 = "-qopenmp"
arg_omp3 = "-xopenmp"
arg_omp4 = "-fiopenmp"
arg_omp2 = "-fopenmp=libomp"
args_omp3 = ["-fopenmp=libomp", "-lomp"]
arg_omp4 = "-qopenmp"
arg_omp5 = "-xopenmp"
is_apple = sys.platform[:3].lower() == "dar"
args_apple_omp = ["-Xclang", "-fopenmp", "-lomp"]
args_apple_omp2 = ["-Xclang", "-fopenmp", "-L/usr/local/lib", "-lomp", "-I/usr/local/include"]
has_brew_omp = False
if is_apple:
res_brew_pref = subprocess.run(["brew", "--prefix", "libomp"], capture_output=silent_tests)
if res_brew_pref.returncode == EXIT_SUCCESS:
has_brew_omp = True
brew_omp_prefix = res_brew_pref.stdout.decode().strip()
args_apple_omp3 = ["-Xclang", "-fopenmp", f"-L{brew_omp_prefix}/lib", "-lomp", f"-I{brew_omp_prefix}/include"]


if self.test_supports_compile_arg(arg_omp1, with_omp=True):
for e in self.extensions:
e.extra_compile_args.append(arg_omp1)
e.extra_link_args.append(arg_omp1)
elif (sys.platform[:3].lower() == "dar") and self.test_supports_compile_arg(args_apple_omp, with_omp=True):
elif is_apple and self.test_supports_compile_arg(args_apple_omp, with_omp=True):
for e in self.extensions:
e.extra_compile_args += ["-Xclang", "-fopenmp"]
e.extra_link_args += ["-lomp"]
elif (sys.platform[:3].lower() == "dar") and self.test_supports_compile_arg(args_apple_omp2, with_omp=True):
elif is_apple and self.test_supports_compile_arg(args_apple_omp2, with_omp=True):
for e in self.extensions:
e.extra_compile_args += ["-Xclang", "-fopenmp"]
e.extra_link_args += ["-L/usr/local/lib", "-lomp"]
e.include_dirs += ["/usr/local/include"]
elif is_apple and has_brew_omp and self.test_supports_compile_arg(args_apple_omp3, with_omp=True):
for e in self.extensions:
e.extra_compile_args += ["-Xclang", "-fopenmp"]
e.extra_link_args += [f"-L{brew_omp_prefix}/lib", "-lomp"]
e.include_dirs += [f"{brew_omp_prefix}/include"]
elif self.test_supports_compile_arg(arg_omp2, with_omp=True):
for e in self.extensions:
e.extra_compile_args.append(arg_omp2)
e.extra_link_args.append(arg_omp2)
elif self.test_supports_compile_arg(arg_omp3, with_omp=True):
e.extra_compile_args += ["-fopenmp=libomp"]
e.extra_link_args += ["-fopenmp"]
elif self.test_supports_compile_arg(args_omp3, with_omp=True):
for e in self.extensions:
e.extra_compile_args.append(arg_omp3)
e.extra_link_args.append(arg_omp3)
e.extra_compile_args += ["-fopenmp=libomp"]
e.extra_link_args += ["-fopenmp", "-lomp"]
elif self.test_supports_compile_arg(arg_omp4, with_omp=True):
for e in self.extensions:
e.extra_compile_args.append(arg_omp4)
e.extra_link_args.append(arg_omp4)
elif self.test_supports_compile_arg(arg_omp5, with_omp=True):
for e in self.extensions:
e.extra_compile_args.append(arg_omp5)
e.extra_link_args.append(arg_omp5)
else:
set_omp_false()

def add_O3(self):
O3 = "-O3"
if self.test_supports_compile_arg(O3):
for e in self.extensions:
e.extra_compile_args.append(O3)

def test_supports_compile_arg(self, comm, with_omp=False):
is_supported = False
try:
Expand All @@ -132,13 +180,12 @@ def test_supports_compile_arg(self, comm, with_omp=False):
cmd = self.compiler.compiler_cxx
except Exception:
cmd = self.compiler.compiler_cxx
val_good = subprocess.call(cmd + [fname])
if with_omp:
with open(fname, "w") as ftest:
ftest.write(u"#include <omp.h>\nint main(int argc, char**argv) {return 0;}\n")
try:
val = subprocess.call(cmd + comm + [fname])
is_supported = (val == val_good)
val = subprocess.run(cmd + comm + [fname], capture_output=silent_tests).returncode
is_supported = (val == EXIT_SUCCESS)
except Exception:
is_supported = False
except Exception:
Expand All @@ -149,16 +196,6 @@ def test_supports_compile_arg(self, comm, with_omp=False):
pass
return is_supported

def is_arch_in_cflags(self):
arch_flags = '-march -mtune -msse -msse2 -msse3 -mssse3 -msse4 -msse4a -msse4.1 -msse4.2 -mavx -mavx2 -mcpu'.split()
for env_var in ("CFLAGS", "CXXFLAGS"):
if env_var in os.environ:
for flag in arch_flags:
if flag in os.environ[env_var]:
return True

return False

def add_restrict_qualifier(self):
supports_restrict = False
try:
Expand All @@ -175,12 +212,11 @@ def add_restrict_qualifier(self):
cmd = self.compiler.compiler_cxx
except Exception:
cmd = self.compiler.compiler_cxx
val_good = subprocess.call(cmd + [fname])
try:
with open(fname, "w") as ftest:
ftest.write(u"int main(int argc, char**argv) {double *__restrict x = 0; return 0;}\n")
val = subprocess.call(cmd + [fname])
supports_restrict = (val == val_good)
val = subprocess.run(cmd + comm + [fname], capture_output=silent_tests).returncode
is_supported = (val == EXIT_SUCCESS)
except Exception:
return None
except Exception:
Expand All @@ -200,7 +236,6 @@ def add_restrict_qualifier(self):
version = '0.0.1',
description = 'Approximations for fast CDF calculation of MVN distributions',
author = 'David Cortes',
author_email = '[email protected]',
url = 'https://github.com/david-cortes/approxcdf',
keywords = ['cdf', 'tvbs', 'multivariate-normal', 'mvn'],
cmdclass = {'build_ext': build_ext_subclass},
Expand All @@ -224,3 +259,14 @@ def add_restrict_qualifier(self):
define_macros = [("FOR_PYTHON", None)]
)]
)

if not found_omp:
    # The install hint is the only platform-dependent part of the message.
    if sys.platform[:3] == "dar":
        install_hint = " - for macOS: 'brew install libomp'\n"
    else:
        install_hint = " modules for your compiler. "

    omp_msg = "".join([
        "\n\n\nCould not detect OpenMP. Package will be built without multi-threading capabilities. ",
        " To enable multi-threading, first install OpenMP",
        install_hint,
        "Then reinstall this package from scratch: 'pip install --upgrade --no-deps --force-reinstall git+https://www.github.com/david-cortes/approxcdf.git'.\n",
    ])
    warnings.warn(omp_msg)
2 changes: 0 additions & 2 deletions src/plackett.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@
case it ends up giving reasonable results with an error of around 1e-8,
but still very slow. */

const int four = 4;

/* https://stackoverflow.com/questions/2937702/i-want-to-find-determinant-of-4x4-matrix-in-c-sharp */
double determinant4by4tri(const double x_tri[6])
{
Expand Down

0 comments on commit cbedd7f

Please sign in to comment.