@@ -53,6 +53,13 @@ def write_version_file():
5353]
5454
5555
56+ def append_flags (flags , flags_to_append ):
57+ for flag in flags_to_append :
58+ if not flag in flags :
59+ flags .append (flag )
60+ return flags
61+
62+
5663def get_extensions ():
5764 build_cuda = torch .cuda .is_available () or os .getenv ('FORCE_CUDA' ,
5865 '0' ) == '1'
@@ -71,9 +78,12 @@ def get_extensions():
7178 nvcc_flags = []
7279 else :
7380 nvcc_flags = nvcc_flags .split (' ' )
74- for flag in ['--expt-extended-lambda' , '-Xcompiler' , '-fopenmp' ]:
75- if not flag in nvcc_flags :
76- nvcc_flags .append (flag )
81+ nvcc_flags = append_flags (nvcc_flags , ['--expt-extended-lambda' , '-Xcompiler' ])
82+ if openmp :
83+ if sys .platform == 'linux' :
84+ nvcc_flags = append_flags (nvcc_flags , ['-fopenmp' ])
85+ elif sys .platform == 'win32' :
86+ nvcc_flags = append_flags (nvcc_flags , ['/openmp' ])
7787 extra_compile_args = {
7888 'cxx' : [],
7989 'nvcc' : nvcc_flags ,
@@ -86,10 +96,11 @@ def get_extensions():
8696 cxx_flags = []
8797 else :
8898 cxx_flags = cxx_flags .split (' ' )
89- if sys .platform == 'linux' and openmp :
90- for flag in ['-fopenmp' ]:
91- if not flag in cxx_flags :
92- cxx_flags .append (flag )
99+ if openmp :
100+ if sys .platform == 'linux' :
101+ cxx_flags = append_flags (cxx_flags , ['-fopenmp' ])
102+ elif sys .platform == 'win32' :
103+ cxx_flags = append_flags (cxx_flags , ['/openmp' ])
93104 extra_compile_args = {
94105 'cxx' : cxx_flags
95106 }
0 commit comments