Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -994,8 +994,8 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
mHiddenSize, mInterSize, mNumExperts, mK,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mTotalTokens,
mHiddenSize, mHiddenSize, mInterSize, mNumExperts, mK,
mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
Expand All @@ -1007,8 +1007,8 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture
mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr,
mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex,
ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize,
mHiddenSize, mInterSize, mNumExperts, mK,
mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mTotalTokens,
mHiddenSize, mHiddenSize, mInterSize, mNumExperts, mK,
mWorkspace + mWorkspaceSize * (mBufferIndex % mNumWorkspaceBuffers),
mFinalOutput + mFinalOutputSize * (mBufferIndex % mNumInputBuffers),
mSourceToExpandedMap + mSourceToExpandedMapSize * mBufferIndex, parallelism_config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::gemm(__nv_fp8

template <typename ElementA, typename ElementB, typename ElementD>
void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void* mat_d, void const* mat_a,
void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t shape_n, size_t shape_k,
cudaStream_t stream, float const* scales_a, float const* scales_b)
void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t expected_m, size_t shape_n,
size_t shape_k, cudaStream_t stream, float const* scales_a, float const* scales_b)
{
constexpr bool internal_quantize_a = !std::is_same_v<ElementA, __nv_fp8_e4m3>;
constexpr bool internal_quantize_b = !std::is_same_v<ElementB, __nv_fp8_e4m3>;
Expand Down Expand Up @@ -138,21 +138,21 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
{
fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
reinterpret_cast<__nv_bfloat16 const*>(mat_b), fp8_mat_b, per_block_scales,
reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_, max_shape_m_4_align_,
reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m, max_shape_m_4_align_,
max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
}
else if constexpr (std::is_same_v<ElementA, __nv_bfloat16> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
{
fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, per_token_per_128c_scales,
nullptr, fp8_mat_b, per_block_scales, reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets,
num_problems, expected_m_, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
num_problems, expected_m, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, shape_k, stream,
internal_quantize_a, internal_quantize_b);
}
else if constexpr (std::is_same_v<ElementA, __nv_fp8_e4m3> && std::is_same_v<ElementB, __nv_fp8_e4m3>)
{
fp8_grouped_gemm_run(nullptr, fp8_mat_a, per_token_per_128c_scales,
reinterpret_cast<__nv_bfloat16 const*>(mat_b), fp8_mat_b, per_block_scales,
reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m_, max_shape_m_4_align_,
reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, expected_m, max_shape_m_4_align_,
max_shape_m_32_align_padded_, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b);
}
else
Expand All @@ -164,6 +164,15 @@ void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void*
#endif
}

// Backward-compatible moeGemm overload (no explicit expected_m parameter).
// Delegates to the primary overload, supplying the runner's stored expected_m_
// member as the expected per-problem M dimension so existing callers keep working.
// NOTE(review): expected_m_ is presumably configured elsewhere on this runner
// before this overload is called — confirm against the class's setup path.
template <typename ElementA, typename ElementB, typename ElementD>
void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::moeGemm(void* mat_d, void const* mat_a,
void const* mat_b, int64_t const* problem_m_offsets, size_t num_problems, size_t shape_n, size_t shape_k,
cudaStream_t stream, float const* scales_a, float const* scales_b)
{
moeGemm(mat_d, mat_a, mat_b, problem_m_offsets, num_problems, expected_m_, shape_n, shape_k, stream, scales_a,
scales_b);
}

template <typename ElementA, typename ElementB, typename ElementD>
void CutlassFp8BlockScaleGemmRunner<ElementA, ElementB, ElementD>::strideBatchGemm(__nv_bfloat16* mat_d, int ld_d,
int stride_d, __nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, int ld_b, int stride_b,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ class CutlassFp8BlockScaleGemmRunnerInterface
cudaStream_t stream)
= 0;

// Grouped (MoE) GEMM with an explicitly supplied expected_m: the caller passes
// the expected per-problem M dimension instead of relying on runner-internal state.
// NOTE(review): exact semantics of expected_m inferred from the paired
// implementation, which forwards it to fp8_grouped_gemm_run — confirm.
virtual void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
size_t num_problems, size_t expected_m, size_t shape_n, size_t shape_k, cudaStream_t stream,
float const* scales_a = nullptr, float const* scales_b = nullptr)
= 0;

virtual void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
size_t num_problems, size_t shape_n, size_t shape_k, cudaStream_t stream, float const* scales_a = nullptr,
float const* scales_b = nullptr)
Expand Down Expand Up @@ -95,6 +100,10 @@ class CutlassFp8BlockScaleGemmRunner : public CutlassFp8BlockScaleGemmRunnerInte
int ld_d, int shape_m, int shape_n, int shape_k, float const* scales_a, float const* scales_b,
cudaStream_t stream) override;

// Override of the explicit-expected_m moeGemm pure virtual declared in
// CutlassFp8BlockScaleGemmRunnerInterface; signature must stay in sync with it.
void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
size_t num_problems, size_t expected_m, size_t shape_n, size_t shape_k, cudaStream_t stream,
float const* scales_a = nullptr, float const* scales_b = nullptr) override;

void moeGemm(void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets,
size_t num_problems, size_t shape_n, size_t shape_k, cudaStream_t stream, float const* scales_a = nullptr,
float const* scales_b = nullptr) override;
Expand Down
Loading