From 13644e7ac42063a17fe6ee231e6c0464a81839dc Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Fri, 9 Dec 2022 12:52:43 +0800
Subject: [PATCH] adapt isMMAv1Row in backend (#969)

---
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  4 +--
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp       | 28 ++++++++++++-------
 python/tests/test_gemm.py                     |  8 ++++++
 3 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index f00c387ce..cd58e5858 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -434,11 +434,11 @@ section 9.7.13.4.1 for more details.
                      "Attribute":$parent), [{
       if(parent.isa<MmaEncodingAttr>() &&
          parent.cast<MmaEncodingAttr>().getVersion() == 1){
-        llvm::errs() << "DotOperand for MMAv1 must have isMMAv1Row field\n";
+        llvm::report_fatal_error("DotOperand for MMAv1 must have isMMAv1Row field");
         return {};
       }
       Attribute none;
-      return $_get(context, opIdx, parent, none);
+      return $_get(context, opIdx, parent, none);
     }]>
   ];

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index fc8e2f508..89917a494 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -3428,12 +3428,16 @@ Value ConvertLayoutOpConversion::lowerSharedToDotOperandMMA(
   } else if (!isOuter && mmaLayout.getVersion() == 1 &&
              isHMMA) { // tensor core v1
     DotOpMmaV1ConversionHelper helper(mmaLayout);
-    bool isMMAv1Row = dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
-    auto srcSharedLayout = src.getType().cast<RankedTensorType>().getEncoding().cast<SharedEncodingAttr>();
+    bool isMMAv1Row =
+        dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+    auto srcSharedLayout = src.getType()
+                               .cast<RankedTensorType>()
+                               .getEncoding()
+                               .cast<SharedEncodingAttr>();

     // Can only convert [1, 0] to row or [0, 1] to col for now
-    if((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
-       (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)){
+    if ((srcSharedLayout.getOrder()[0] == 1 && !isMMAv1Row) ||
+        (srcSharedLayout.getOrder()[0] == 0 && isMMAv1Row)) {
       llvm::errs() << "Unsupported Shared -> DotOperand[MMAv1] conversion\n";
       return Value();
     }
@@ -3550,6 +3554,14 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
                        .cast<RankedTensorType>()
                        .getEncoding()
                        .cast<MmaEncodingAttr>();
+  auto ALayout = A.getType()
+                     .cast<RankedTensorType>()
+                     .getEncoding()
+                     .cast<DotOperandEncodingAttr>();
+  auto BLayout = B.getType()
+                     .cast<RankedTensorType>()
+                     .getEncoding()
+                     .cast<DotOperandEncodingAttr>();

   auto ATensorTy = A.getType().cast<RankedTensorType>();
   auto BTensorTy = B.getType().cast<RankedTensorType>();
@@ -3561,12 +3573,8 @@ DotOpConversion::convertMMA884(triton::DotOp op, DotOpAdaptor adaptor,
   auto DShape = DTensorTy.getShape();
   auto wpt = mmaLayout.getWarpsPerCTA();

-  // TODO[Superjomn]: order cannot accessed in DotOp.
-  SmallVector<int> AOrder({1, 0});
-  SmallVector<int> BOrder({1, 0});
-
-  bool isARow = AOrder[0] != 0;
-  bool isBRow = BOrder[0] != 0;
+  bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
+  bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
   bool isAVec4 = !isARow && AShape[isARow] <= 16; // fp16*4 = 16bytes
   bool isBVec4 = isBRow && BShape[isBRow] <= 16;
   // TODO[Superjomn]: ld.v4 is not supported.
diff --git a/python/tests/test_gemm.py b/python/tests/test_gemm.py
index 8af333c70..012e0771d 100644
--- a/python/tests/test_gemm.py
+++ b/python/tests/test_gemm.py
@@ -303,9 +303,17 @@ def test_gemm_fp32(M, N, K, num_warps, block_M, block_N, block_K, allow_tf32):

     [32, 32, 32, 1, 32, 32, 32, False, False],
     [128, 32, 32, 1, 128, 32, 32, False, False],
+    [128, 32, 32, 1, 128, 32, 32, True, False],
+    [128, 32, 32, 1, 128, 32, 32, True, True],
+
     # split-K
     [16, 16, 32, 1, 16, 16, 16, False, False],
     [64, 64, 128, 1, 64, 64, 32, False, False],
+
+    [16, 16, 32, 1, 16, 16, 16, True, False],
+    [16, 16, 32, 1, 16, 16, 16, True, True],
+    [64, 64, 128, 1, 64, 64, 32, True, True],
 ])
 def test_gemm_for_mmav1(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B):
     test_gemm(SIZE_M, SIZE_N, SIZE_K, NUM_WARPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, TRANS_A, TRANS_B)
+