Skip to content

Commit 1c7af00

Browse files
authored
gemm int8 quantization (#5706)
* quantize gemm
* write gemm quantize scales
* update doc
* less openmp args
* x86 riscv fallback
* skip gemm vulkan int8
* fix noint8 test, fix arm bf16 test
* enable vfpv4 on neon build only
* fix gemm vulkan without C
* fp16 pack8 output
* enable elempack=8 only for asimdhp+
* tiled gemm int8 test
* opt arm64 tiles, fix asimdhp dispatch
1 parent 9b5f6a3 commit 1c7af00

24 files changed: 36265 additions (+36265) and 173 deletions (-173).

CMakeLists.txt

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,21 +162,25 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm")
162162
endif()
163163

164164
if(CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32)
165-
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
166-
set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4")
167-
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
165+
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s, _a, _b; _s = vmlaq_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM_NEON)
168166

169-
unset(CMAKE_REQUIRED_FLAGS)
170-
else()
171-
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4")
172-
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
167+
if(NCNN_COMPILER_SUPPORT_ARM_NEON)
168+
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
169+
set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4")
170+
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
173171

174-
if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4)
175-
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee")
176-
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
177-
endif()
172+
unset(CMAKE_REQUIRED_FLAGS)
173+
else()
174+
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4")
175+
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
178176

179-
unset(CMAKE_REQUIRED_FLAGS)
177+
if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4)
178+
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee")
179+
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
180+
endif()
181+
182+
unset(CMAKE_REQUIRED_FLAGS)
183+
endif()
180184
endif()
181185

182186
if(NCNN_COMPILER_SUPPORT_ARM_VFPV4 OR NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)

cmake/ncnn_add_layer.cmake

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -144,25 +144,25 @@ macro(ncnn_add_layer class)
144144
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
145145
ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__")
146146
endif()
147-
if(NCNN_AVX512VNNI)
147+
if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI)
148148
ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__")
149149
endif()
150-
if(NCNN_AVX512BF16)
150+
if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16)
151151
ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__")
152152
endif()
153-
if(NCNN_AVX512FP16)
153+
if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16)
154154
ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__")
155155
endif()
156-
if(NCNN_AVXVNNI)
156+
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
157157
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
158158
endif()
159-
if(NCNN_AVX2)
159+
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
160160
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
161161
endif()
162-
if(NCNN_XOP)
162+
if(NCNN_RUNTIME_CPU AND NCNN_XOP)
163163
ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__XOP__")
164164
endif()
165-
if(NCNN_F16C)
165+
if(NCNN_RUNTIME_CPU AND NCNN_F16C)
166166
ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__F16C__")
167167
endif()
168168
elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")
@@ -175,25 +175,25 @@ macro(ncnn_add_layer class)
175175
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
176176
ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__")
177177
endif()
178-
if(NCNN_AVX512VNNI)
178+
if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI)
179179
ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__")
180180
endif()
181-
if(NCNN_AVX512BF16)
181+
if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16)
182182
ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__")
183183
endif()
184-
if(NCNN_AVX512FP16)
184+
if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16)
185185
ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__")
186186
endif()
187-
if(NCNN_AVXVNNI)
187+
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
188188
ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 -mfma -mf16c -mavxvnni /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__")
189189
endif()
190-
if(NCNN_AVX2)
190+
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
191191
ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 -mfma -mf16c /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__")
192192
endif()
193-
if(NCNN_XOP)
193+
if(NCNN_RUNTIME_CPU AND NCNN_XOP)
194194
ncnn_add_arch_opt_source(${class} xop "/arch:AVX -mxop /D__SSSE3__ /D__SSE4_1__ /D__XOP__")
195195
endif()
196-
if(NCNN_F16C)
196+
if(NCNN_RUNTIME_CPU AND NCNN_F16C)
197197
ncnn_add_arch_opt_source(${class} f16c "/arch:AVX -mf16c /D__SSSE3__ /D__SSE4_1__ /D__F16C__")
198198
endif()
199199
else()
@@ -206,25 +206,25 @@ macro(ncnn_add_layer class)
206206
if(NCNN_RUNTIME_CPU AND NCNN_AVX)
207207
ncnn_add_arch_opt_layer(${class} avx "-mavx")
208208
endif()
209-
if(NCNN_AVX512VNNI)
209+
if(NCNN_RUNTIME_CPU AND NCNN_AVX512VNNI)
210210
ncnn_add_arch_opt_source(${class} avx512vnni "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512vnni")
211211
endif()
212-
if(NCNN_AVX512BF16)
212+
if(NCNN_RUNTIME_CPU AND NCNN_AVX512BF16)
213213
ncnn_add_arch_opt_source(${class} avx512bf16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512bf16")
214214
endif()
215-
if(NCNN_AVX512FP16)
215+
if(NCNN_RUNTIME_CPU AND NCNN_AVX512FP16)
216216
ncnn_add_arch_opt_source(${class} avx512fp16 "-mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mfma -mf16c -mavx512fp16")
217217
endif()
218-
if(NCNN_AVXVNNI)
218+
if(NCNN_RUNTIME_CPU AND NCNN_AVXVNNI)
219219
ncnn_add_arch_opt_source(${class} avxvnni "-mavx2 -mfma -mf16c -mavxvnni")
220220
endif()
221-
if(NCNN_AVX2)
221+
if(NCNN_RUNTIME_CPU AND NCNN_AVX2)
222222
ncnn_add_arch_opt_source(${class} avx2 "-mavx2 -mfma -mf16c")
223223
endif()
224-
if(NCNN_XOP)
224+
if(NCNN_RUNTIME_CPU AND NCNN_XOP)
225225
ncnn_add_arch_opt_source(${class} xop "-mavx -mxop")
226226
endif()
227-
if(NCNN_F16C)
227+
if(NCNN_RUNTIME_CPU AND NCNN_F16C)
228228
ncnn_add_arch_opt_source(${class} f16c "-mavx -mf16c")
229229
endif()
230230
endif()
@@ -254,28 +254,28 @@ macro(ncnn_add_layer class)
254254
if(NCNN_ARM82)
255255
ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC")
256256
endif()
257-
if(NCNN_ARM82DOT)
257+
if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT)
258258
ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD")
259259
endif()
260-
if(NCNN_ARM82FP16FML)
260+
if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML)
261261
ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML")
262262
endif()
263-
if(NCNN_ARM84BF16)
263+
if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16)
264264
ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC")
265265
endif()
266-
if(NCNN_ARM84I8MM)
266+
if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM)
267267
ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8")
268268
endif()
269269
# TODO add support for sve family
270-
if(NCNN_ARM86SVE)
270+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE)
271271
endif()
272-
if(NCNN_ARM86SVE2)
272+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2)
273273
endif()
274-
if(NCNN_ARM86SVEBF16)
274+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16)
275275
endif()
276-
if(NCNN_ARM86SVEI8MM)
276+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM)
277277
endif()
278-
if(NCNN_ARM86SVEF32MM)
278+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM)
279279
endif()
280280
elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")
281281
if(NCNN_VFPV4)
@@ -284,28 +284,28 @@ macro(ncnn_add_layer class)
284284
if(NCNN_ARM82)
285285
ncnn_add_arch_opt_source(${class} asimdhp "/arch:armv8.2 -march=armv8.2-a+fp16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC")
286286
endif()
287-
if(NCNN_ARM82DOT)
287+
if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT)
288288
ncnn_add_arch_opt_source(${class} asimddp "/arch:armv8.2 -march=armv8.2-a+fp16+dotprod /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD")
289289
endif()
290-
if(NCNN_ARM82FP16FML)
290+
if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML)
291291
ncnn_add_arch_opt_source(${class} asimdfhm "/arch:armv8.2 -march=armv8.2-a+fp16+fp16fml /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_FP16_FML")
292292
endif()
293-
if(NCNN_ARM84BF16)
293+
if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16)
294294
ncnn_add_arch_opt_source(${class} bf16 "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+bf16 /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_BF16_VECTOR_ARITHMETIC")
295295
endif()
296-
if(NCNN_ARM84I8MM)
296+
if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM)
297297
ncnn_add_arch_opt_source(${class} i8mm "/arch:armv8.4 -march=armv8.4-a+fp16+dotprod+i8mm /D__ARM_FEATURE_FP16_VECTOR_ARITHMETIC /D__ARM_FEATURE_DOTPROD /D__ARM_FEATURE_FP16_FML /D__ARM_FEATURE_MATMUL_INT8")
298298
endif()
299299
# TODO add support for sve family
300-
if(NCNN_ARM86SVE)
300+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE)
301301
endif()
302-
if(NCNN_ARM86SVE2)
302+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2)
303303
endif()
304-
if(NCNN_ARM86SVEBF16)
304+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16)
305305
endif()
306-
if(NCNN_ARM86SVEI8MM)
306+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM)
307307
endif()
308-
if(NCNN_ARM86SVEF32MM)
308+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM)
309309
endif()
310310
else()
311311
if(NCNN_VFPV4)
@@ -314,31 +314,31 @@ macro(ncnn_add_layer class)
314314
if(NCNN_ARM82)
315315
ncnn_add_arch_opt_source(${class} asimdhp "-march=armv8.2-a+fp16")
316316
endif()
317-
if(NCNN_ARM82DOT)
317+
if(NCNN_RUNTIME_CPU AND NCNN_ARM82DOT)
318318
ncnn_add_arch_opt_source(${class} asimddp "-march=armv8.2-a+fp16+dotprod")
319319
endif()
320-
if(NCNN_ARM82FP16FML)
320+
if(NCNN_RUNTIME_CPU AND NCNN_ARM82FP16FML)
321321
ncnn_add_arch_opt_source(${class} asimdfhm "-march=armv8.2-a+fp16+fp16fml")
322322
endif()
323-
if(NCNN_ARM84BF16)
323+
if(NCNN_RUNTIME_CPU AND NCNN_ARM84BF16)
324324
ncnn_add_arch_opt_source(${class} bf16 "-march=armv8.4-a+fp16+dotprod+bf16")
325325
endif()
326-
if(NCNN_ARM84I8MM)
326+
if(NCNN_RUNTIME_CPU AND NCNN_ARM84I8MM)
327327
ncnn_add_arch_opt_source(${class} i8mm "-march=armv8.4-a+fp16+dotprod+i8mm")
328328
endif()
329-
if(NCNN_ARM86SVE)
329+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE)
330330
ncnn_add_arch_opt_source(${class} sve "-march=armv8.6-a+fp16+dotprod+sve")
331331
endif()
332-
if(NCNN_ARM86SVE2)
332+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVE2)
333333
ncnn_add_arch_opt_source(${class} sve2 "-march=armv8.6-a+fp16+dotprod+sve2")
334334
endif()
335-
if(NCNN_ARM86SVEBF16)
335+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEBF16)
336336
ncnn_add_arch_opt_source(${class} svebf16 "-march=armv8.6-a+fp16+dotprod+sve+bf16")
337337
endif()
338-
if(NCNN_ARM86SVEI8MM)
338+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEI8MM)
339339
ncnn_add_arch_opt_source(${class} svei8mm "-march=armv8.6-a+fp16+dotprod+sve+i8mm")
340340
endif()
341-
if(NCNN_ARM86SVEF32MM)
341+
if(NCNN_RUNTIME_CPU AND NCNN_ARM86SVEF32MM)
342342
ncnn_add_arch_opt_source(${class} svef32mm "-march=armv8.6-a+fp16+dotprod+sve+f32mm")
343343
endif()
344344
endif()

docs/developer-guide/operators.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -942,15 +942,18 @@ y = (gemm(a, b) + c * beta) * alpha
942942
| 12 | output_elempack | int | 0 | |
943943
| 13 | output_elemtype | int | 0 | |
944944
| 14 | output_transpose | int| 0 | |
945+
| 18 | int8_scale_term | int | 0 | |
945946
| 20 | constant_TILE_M | int | 0 | |
946947
| 21 | constant_TILE_N | int | 0 | |
947948
| 22 | constant_TILE_K | int | 0 | |
948949

949950
| weight | type | shape |
950951
| ------------- | ----- | --------------------- |
951-
| A_data | float | [M, K] or [K, M] |
952-
| B_data | float | [N, K] or [K, N] |
952+
| A_data | float/fp16/int8 | [M, K] or [K, M] |
953+
| B_data | float/fp16/int8 | [N, K] or [K, N] |
953954
| C_data | float | [1], [M] or [N] or [1, M] or [N,1] or [N, M] |
955+
| A_data_int8_scales| float | [M] |
956+
| B_data_int8_scales| float | [1] |
954957

955958
# GridSample
956959
```

0 commit comments

Comments
 (0)