Skip to content

Commit fd07ee3

Browse files
committed
Fix openmp flags
ghstack-source-id: 22bd33d Pull Request resolved: #54
1 parent ff9c39f commit fd07ee3

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

setup.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ def write_version_file():
5353
]
5454

5555

56+
def append_flags(flags, flags_to_append):
57+
for flag in flags_to_append:
58+
if not flag in flags:
59+
flags.append(flag)
60+
return flags
61+
62+
5663
def get_extensions():
5764
build_cuda = torch.cuda.is_available() or os.getenv('FORCE_CUDA',
5865
'0') == '1'
@@ -71,9 +78,12 @@ def get_extensions():
7178
nvcc_flags = []
7279
else:
7380
nvcc_flags = nvcc_flags.split(' ')
74-
for flag in ['--expt-extended-lambda', '-Xcompiler', '-fopenmp']:
75-
if not flag in nvcc_flags:
76-
nvcc_flags.append(flag)
81+
nvcc_flags = append_flags(nvcc_flags, ['--expt-extended-lambda', '-Xcompiler'])
82+
if openmp:
83+
if sys.platform == 'linux':
84+
nvcc_flags = append_flags(nvcc_flags, ['-fopenmp'])
85+
elif sys.platform == 'win32':
86+
nvcc_flags = append_flags(nvcc_flags, ['/openmp'])
7787
extra_compile_args = {
7888
'cxx': [],
7989
'nvcc': nvcc_flags,
@@ -86,10 +96,11 @@ def get_extensions():
8696
cxx_flags = []
8797
else:
8898
cxx_flags = cxx_flags.split(' ')
89-
if sys.platform == 'linux' and openmp:
90-
for flag in ['-fopenmp']:
91-
if not flag in cxx_flags:
92-
cxx_flags.append(flag)
99+
if openmp:
100+
if sys.platform == 'linux':
101+
cxx_flags = append_flags(cxx_flags, ['-fopenmp'])
102+
elif sys.platform == 'win32':
103+
cxx_flags = append_flags(cxx_flags, ['/openmp'])
93104
extra_compile_args = {
94105
'cxx': cxx_flags
95106
}

0 commit comments

Comments
 (0)