Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ def write_version_file():
]


def append_flags(flags, flags_to_append):
for flag in flags_to_append:
if not flag in flags:
flags.append(flag)
return flags


def get_extensions():
build_cuda = torch.cuda.is_available() or os.getenv('FORCE_CUDA',
'0') == '1'
Expand All @@ -71,9 +78,12 @@ def get_extensions():
nvcc_flags = []
else:
nvcc_flags = nvcc_flags.split(' ')
for flag in ['--expt-extended-lambda', '-Xcompiler', '-fopenmp']:
if not flag in nvcc_flags:
nvcc_flags.append(flag)
nvcc_flags = append_flags(nvcc_flags, ['--expt-extended-lambda', '-Xcompiler'])
if openmp:
if sys.platform == 'linux':
nvcc_flags = append_flags(nvcc_flags, ['-fopenmp'])
elif sys.platform == 'win32':
nvcc_flags = append_flags(nvcc_flags, ['/openmp'])
extra_compile_args = {
'cxx': [],
'nvcc': nvcc_flags,
Expand All @@ -86,10 +96,11 @@ def get_extensions():
cxx_flags = []
else:
cxx_flags = cxx_flags.split(' ')
if sys.platform == 'linux' and openmp:
for flag in ['-fopenmp']:
if not flag in cxx_flags:
cxx_flags.append(flag)
if openmp:
if sys.platform == 'linux':
cxx_flags = append_flags(cxx_flags, ['-fopenmp'])
elif sys.platform == 'win32':
cxx_flags = append_flags(cxx_flags, ['/openmp'])
extra_compile_args = {
'cxx': cxx_flags
}
Expand Down