[Triton-MLIR][Backend] Fix mmav1 in case of numWarps > 1 (#972)
This commit is contained in:
@@ -2886,11 +2886,15 @@ private:
|
|||||||
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
|
Value laneIdRem16Div4 = udiv(laneIdRem16, _4);
|
||||||
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
|
Value laneIdRem16Div4Rem2 = urem(laneIdRem16Div4, _2);
|
||||||
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
|
Value laneIdRem4Div2 = udiv(urem(laneId, _4), _2);
|
||||||
|
Value rowWarpOffset = mul(multiDimWarpId[0], _16);
|
||||||
|
Value colWarpOffset = mul(multiDimWarpId[1], _16);
|
||||||
mmaRowIdx[0] =
|
mmaRowIdx[0] =
|
||||||
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
|
add(add(mul(laneIdDiv16, _8), mul(laneIdRem16Div4Rem2, _4)),
|
||||||
laneIdRem2);
|
laneIdRem2);
|
||||||
|
mmaRowIdx[0] = add(mmaRowIdx[0], rowWarpOffset);
|
||||||
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
|
mmaRowIdx[1] = add(mmaRowIdx[0], _2);
|
||||||
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
|
mmaColIdx[0] = add(mul(laneIdRem16Div8, _4), mul(laneIdRem4Div2, _2));
|
||||||
|
mmaColIdx[0] = add(mmaColIdx[0], colWarpOffset);
|
||||||
mmaColIdx[1] = add(mmaColIdx[0], _1);
|
mmaColIdx[1] = add(mmaColIdx[0], _1);
|
||||||
mmaColIdx[2] = add(mmaColIdx[0], _8);
|
mmaColIdx[2] = add(mmaColIdx[0], _8);
|
||||||
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
|
mmaColIdx[3] = add(mmaColIdx[0], idx_val(9));
|
||||||
|
@@ -302,10 +302,36 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):
|
|||||||
[32, 16, 32, 1, 32, 16, 32, False, False],
|
[32, 16, 32, 1, 32, 16, 32, False, False],
|
||||||
[32, 32, 32, 1, 32, 32, 32, False, False],
|
[32, 32, 32, 1, 32, 32, 32, False, False],
|
||||||
[128, 32, 32, 1, 128, 32, 32, False, False],
|
[128, 32, 32, 1, 128, 32, 32, False, False],
|
||||||
|
# # split-K
|
||||||
# split-K
|
|
||||||
[16, 16, 32, 1, 16, 16, 16, False, False],
|
[16, 16, 32, 1, 16, 16, 16, False, False],
|
||||||
[64, 64, 128, 1, 64, 64, 32, False, False],
|
[64, 64, 128, 1, 64, 64, 32, False, False],
|
||||||
|
# numWarps > 1
|
||||||
|
[32, 32, 64, 2, 32, 32, 32, False, False],
|
||||||
|
[64, 32, 64, 4, 64, 32, 64, False, False],
|
||||||
|
[128, 64, 128, 4, 128, 64, 128, False, False],
|
||||||
|
# [16, 16, 16, 16, 16, 16, 16, False, False], # wpt overflow issue, hang on Volta
|
||||||
|
# K-Forloop
|
||||||
|
# [16, 16, 64, 4, 8, 8, 8, False, False], # Wrap threads
|
||||||
|
[32, 32, 64, 4, 32, 32, 32, False, False], # Single shared encoding
|
||||||
|
# [16, 16, 128, 4, 16, 16, 16, False, False], # Single shared encoding and small k, hang on Volta
|
||||||
|
[64, 32, 128, 4, 64, 32, 64, False, False],
|
||||||
|
[128, 16, 128, 4, 128, 16, 32, False, False],
|
||||||
|
# [32, 16, 128, 4, 32, 16, 32, False, False], # hang on Volta
|
||||||
|
[32, 64, 128, 4, 32, 64, 32, False, False],
|
||||||
|
[32, 128, 256, 4, 32, 128, 64, False, False],
|
||||||
|
[64, 128, 64, 4, 64, 128, 32, False, False],
|
||||||
|
[64, 64, 128, 4, 64, 64, 32, False, False],
|
||||||
|
[128, 128, 64, 4, 128, 128, 32, False, False],
|
||||||
|
[128, 128, 128, 4, 128, 128, 32, False, False],
|
||||||
|
[128, 128, 256, 4, 128, 128, 64, False, False],
|
||||||
|
[128, 256, 128, 4, 128, 256, 32, False, False],
|
||||||
|
[256, 128, 64, 4, 256, 128, 16, False, False],
|
||||||
|
[128, 64, 128, 4, 128, 64, 32, False, False],
|
||||||
|
# [16, 16, 64, 4, 16, 16, 16, False, False], # hang on Volta
|
||||||
|
[32, 32, 64, 4, 32, 32, 32, False, False],
|
||||||
|
# trans
|
||||||
|
# [128, 64, 128, 4, 128, 64, 32, True, False],
|
||||||
|
# [128, 64, 128, 4, 128, 64, 32, False, True],
|
||||||
])
|
])
|
||||||
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
|
def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
|
||||||
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
|
test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
|
||||||
|
Reference in New Issue
Block a user