From 793012b4c45525a09e19eb52fee53c31a804ba16 Mon Sep 17 00:00:00 2001
From: goostavz <109190422+goostavz@users.noreply.github.com>
Date: Fri, 9 Dec 2022 18:36:05 +0800
Subject: [PATCH] [Triton-MLIR][Backend] Fix mmav1 in case of numWarps > 1
 (#972)

---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp |  4 +++
 python/tests/test_gemm.py               | 30 +++++++++++++++++--
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index b3d9e172f..c4421a72e 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -2886,11 +2886,15 @@ private:
     Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
     Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
     Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
+    Value rowWarpOffset = mul(multiDimWarpId[0], _16);
+    Value colWarpOffset = mul(multiDimWarpId[1], _16);
     mmaRowIdx[0] =
         add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
             laneIdRem2);
+    mmaRowIdx[0] = add(mmaRowIdx[0], rowWarpOffset);
     mmaRowIdx[1] = add(mmaRowIdx[0], _2);
     mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
+    mmaColIdx[0] = add(mmaColIdx[0], colWarpOffset);
     mmaColIdx[1] = add(mmaColIdx[0], _1);
     mmaColIdx[2] = add(mmaColIdx[0], _8);
     mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 8af333c70..cd3a7f805 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -302,10 +302,36 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
     [32, 16, 32, 1, 32, 16, 32, False, False],
     [32, 32, 32, 1, 32, 32, 32, False, False],
     [128, 32, 32, 1, 128, 32, 32, False, False],
-
-    # split-K
+    # # split-K
     [16, 16, 32, 1, 16, 16, 16, False, False],
     [64, 64, 128, 1, 64, 64, 32, False, False],
+    # numWarps > 1
+    [32, 32, 64, 2, 32, 32, 32, False, False],
+    [64, 32, 64, 4, 64, 32, 64, False, False],
+    [128, 64, 128, 4, 128, 64, 128, False, False],
+    # [16, 16, 16, 16, 16, 16, 16, False, False],  # wpt overflow issue, hang on Volta
+    # K-Forloop
+    # [16, 16, 64, 4, 8, 8, 8, False, False],  # Wrap threads
+    [32, 32, 64, 4, 32, 32, 32, False, False],  # Single shared encoding
+    # [16, 16, 128, 4, 16, 16, 16, False, False],  # Single shared encoding and small k, hang on Volta
+    [64, 32, 128, 4, 64, 32, 64, False, False],
+    [128, 16, 128, 4, 128, 16, 32, False, False],
+    # [32, 16, 128, 4, 32, 16, 32, False, False],  # hang on Volta
+    [32, 64, 128, 4, 32, 64, 32, False, False],
+    [32, 128, 256, 4, 32, 128, 64, False, False],
+    [64, 128, 64, 4, 64, 128, 32, False, False],
+    [64, 64, 128, 4, 64, 64, 32, False, False],
+    [128, 128, 64, 4, 128, 128, 32, False, False],
+    [128, 128, 128, 4, 128, 128, 32, False, False],
+    [128, 128, 256, 4, 128, 128, 64, False, False],
+    [128, 256, 128, 4, 128, 256, 32, False, False],
+    [256, 128, 64, 4, 256, 128, 16, False, False],
+    [128, 64, 128, 4, 128, 64, 32, False, False],
+    # [16, 16, 64, 4, 16, 16, 16, False, False],  # hang on Volta
+    [32, 32, 64, 4, 32, 32, 32, False, False],
+    # trans
+    # [128, 64, 128, 4, 128, 64, 32, True, False],
+    # [128, 64, 128, 4, 128, 64, 32, False, True],
 ])
 def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
     test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
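
Reviewer sketch (illustrative, not part of the patch): the hunk in
TritonGPUToLLVM.cpp can be checked with plain integer arithmetic. The Python
snippet below mirrors the udiv/urem/mul/add builder calls from the diff;
mma_v1_indices, warp_row, and warp_col are hypothetical names standing in for
the lane computation and multiDimWarpId[0]/multiDimWarpId[1], and the 16x16
per-warp tile is the mma v1 (Volta) fragment layout this code path assumes.
It shows the bug being fixed: without the two warp offsets, every warp
computed identical mmaRowIdx/mmaColIdx, so any numWarps > 1 configuration
touched only a single warp's tile.

    # Illustrative sketch: mma v1 lane/warp index math with the LLVM value
    # builders replaced by plain integers. Names are hypothetical stand-ins.
    def mma_v1_indices(lane_id, warp_row, warp_col):
        lane_id_div16 = lane_id // 16            # laneIdDiv16
        lane_id_rem16 = lane_id % 16             # laneIdRem16
        lane_id_rem2 = lane_id % 2               # laneIdRem2
        lane_id_rem16_div8 = lane_id_rem16 // 8  # laneIdRem16Div8
        lane_id_rem16_div4_rem2 = (lane_id_rem16 // 4) % 2  # laneIdRem16Div4Rem2
        lane_id_rem4_div2 = (lane_id % 4) // 2   # laneIdRem4Div2

        row = lane_id_div16 * 8 + lane_id_rem16_div4_rem2 * 4 + lane_id_rem2
        col = lane_id_rem16_div8 * 4 + lane_id_rem4_div2 * 2

        # The fix: offset each warp by its own 16x16 tile. Without these two
        # adds, all warps produced the same indices.
        row += warp_row * 16  # rowWarpOffset = mul(multiDimWarpId[0], _16)
        col += warp_col * 16  # colWarpOffset = mul(multiDimWarpId[1], _16)
        return row, col

    # Lane 0 of a 2x2 warp grid now starts at four distinct tile origins.
    for wr in range(2):
        for wc in range(2):
            print((wr, wc), mma_v1_indices(0, wr, wc))

Running the sketch prints (0, 0), (0, 16), (16, 0), (16, 16), i.e. one tile
origin per warp, which is the per-warp tiling the new numWarps > 1 cases in
test_gemm.py exercise.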