-
Notifications
You must be signed in to change notification settings - Fork 8
/
setup.py
32 lines (27 loc) · 1012 Bytes
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Copyright (c) 2021-present, Zhuang AI Group.
# All rights reserved.
import torch
from setuptools import setup, Extension, find_packages
from torch.utils import cpp_extension
def _version_component(part):
    """Return the integer formed by the leading digits of a version component.

    For example ``"7"`` -> 7 and ``"0a0+git1234"`` -> 0. A component with no
    leading digits yields 0.
    """
    digits = ""
    for char in part:
        if not char.isdigit():
            break
        digits += char
    return int(digits) if digits else 0


# Extra compiler flags, keyed by compiler: "cxx" for the host C++ compiler,
# "nvcc" for the CUDA compiler. Passed to the mesa.native extension below.
compile_args = {"cxx": [], "nvcc": []}

# Compare the torch version numerically. A plain string comparison is
# lexicographic: "1.10.0" < "1.8" is True, so it would misclassify torch
# 1.10+ unless torch.__version__ happens to override `<`.
_torch_version_parts = torch.__version__.split('.')
torch_major = _version_component(_torch_version_parts[0])
torch_minor = (
    _version_component(_torch_version_parts[1])
    if len(_torch_version_parts) > 1
    else 0
)
if (torch_major, torch_minor) < (1, 8):
    # Supply the version macros on the command line for torch < 1.8.
    # NOTE(review): presumably torch >= 1.8 headers define these macros
    # themselves — confirm against the torch release in use.
    compile_args['cxx'] += ["-DTORCH_VERSION_MAJOR={}".format(torch_major)]
    compile_args['cxx'] += ["-DTORCH_VERSION_MINOR={}".format(torch_minor)]

setup(
    name='Mesa',
    version='1.0',
    packages=find_packages(),
    ext_modules=[
        # C++-only extension built from native.cpp; receives the version
        # flags assembled above.
        cpp_extension.CppExtension(
            'mesa.native',
            ['native.cpp'],
            extra_compile_args=compile_args,
        ),
        # C++/CUDA extension. NOTE(review): it does not receive
        # compile_args; presumably its sources do not need the version
        # macros — confirm.
        cpp_extension.CUDAExtension(
            'mesa.cpp_extension.quantization',
            ['mesa/cpp_extension/quantization.cc',
             'mesa/cpp_extension/quantization_cuda_kernel.cu'],
        ),
    ],
    # BuildExtension drives compilation of the torch C++/CUDA extensions.
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)