From 3ed36dcb4dab7cda6c0765b1b9c65c3903120a64 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Thu, 8 Dec 2022 12:11:51 -0800
Subject: [PATCH] [BACKEND] MMA->DotOperand conversion for chain dot of float32 tensors (#962)

Co-authored-by: Philippe Tillet
---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 146 ++++++++++++------
 python/tests/test_core.py              |  25 +--
 2 files changed, 113 insertions(+), 58 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 965f13467..b3d9e172f 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -2682,63 +2682,24 @@ public:
          dstLayout.isa<SliceEncodingAttr>())) {
       return lowerDistributedToDistributed(op, adaptor, rewriter);
     }
-    // dot_op = #mma
-    // when #mma = MmaEncoding
     if (srcLayout.isa<MmaEncodingAttr>() &&
         dstLayout.isa<DotOperandEncodingAttr>()) {
-      auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
-      auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
-      if (srcMmaLayout.getWarpsPerCTA()[1] == 1 &&
-          dstDotLayout.getOpIdx() == 0 &&
-          dstDotLayout.getParent() == srcMmaLayout) {
-        // get source values
-        Location loc = op->getLoc();
-        auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
-        unsigned elems = getElemsPerThread(srcTy);
-        Type elemTy =
-            this->getTypeConverter()->convertType(srcTy.getElementType());
-        // for the destination type, we need to pack values together
-        // so they can be consumed by tensor core operations
-        unsigned vecSize =
-            std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
-        Type vecTy = vec_ty(elemTy, vecSize);
-        SmallVector<Type> types(elems / vecSize, vecTy);
-        SmallVector<Value> vecVals;
-        for (unsigned i = 0; i < elems; i += vecSize) {
-          Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
-          for (unsigned j = 0; j < vecSize; j++)
-            packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
-          vecVals.push_back(packed);
-        }
-
-        // This needs to be ordered the same way that
-        // ldmatrix.x4 would order it
-        // TODO: this needs to be refactor so we don't
-        // implicitly depends on how emitOffsetsForMMAV2
-        // is implemented
-        SmallVector<Value> reorderedVals;
-        for (unsigned i = 0; i < vecVals.size(); i += 4) {
-          reorderedVals.push_back(vecVals[i]);
-          reorderedVals.push_back(vecVals[i + 2]);
-          reorderedVals.push_back(vecVals[i + 1]);
-          reorderedVals.push_back(vecVals[i + 3]);
-        }
-
-        // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
-
-        Type structTy =
-            LLVM::LLVMStructType::getLiteral(this->getContext(), types);
-        Value view =
-            getStructFromElements(loc, reorderedVals, rewriter, structTy);
-        rewriter.replaceOp(op, view);
-        return success();
-      }
+      return lowerMmaToDotOperand(op, adaptor, rewriter);
     }
     // TODO: to be implemented
     llvm_unreachable("unsupported layout conversion");
     return failure();
   }
 
+  static bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
+                                 DotOperandEncodingAttr &dotOperandLayout) {
+    // dot_op = #mma
+    // when #mma = MmaEncoding
+    return mmaLayout.getWarpsPerCTA()[1] == 1 &&
+           dotOperandLayout.getOpIdx() == 0 &&
+           dotOperandLayout.getParent() == mmaLayout;
+  }
+
   static void storeBlockedToShared(Value src, Value llSrc,
                                    ArrayRef<Value> srcStrides,
                                    ArrayRef<Value> srcIndices, Value dst,
@@ -3003,6 +2964,11 @@ private:
   lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
                           ConversionPatternRewriter &rewriter) const;
 
+  // mma -> dot_operand
+  LogicalResult lowerMmaToDotOperand(triton::gpu::ConvertLayoutOp op,
+                                     OpAdaptor adaptor,
+                                     ConversionPatternRewriter &rewriter) const;
+
   // shared -> dot_operand if the result layout is mma
   Value lowerSharedToDotOperandMMA(
       triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
@@ -3209,6 +3175,58 @@ LogicalResult ConvertLayoutOpConversion::lowerBlockedToShared(
   return success();
 }
 
+LogicalResult ConvertLayoutOpConversion::lowerMmaToDotOperand(
+    triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+  auto srcTy = op.src().getType().cast<RankedTensorType>();
+  auto dstTy = op.result().getType().cast<RankedTensorType>();
+  auto srcLayout = srcTy.getEncoding();
+  auto dstLayout = dstTy.getEncoding();
+  auto srcMmaLayout = srcLayout.cast<MmaEncodingAttr>();
+  auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
+  if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
+    // get source values
+    auto vals = getElementsFromStruct(loc, adaptor.src(), rewriter);
+    unsigned elems = getElemsPerThread(srcTy);
+    Type elemTy = this->getTypeConverter()->convertType(srcTy.getElementType());
+    // for the destination type, we need to pack values together
+    // so they can be consumed by tensor core operations
+    unsigned vecSize =
+        std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
+    Type vecTy = vec_ty(elemTy, vecSize);
+    SmallVector<Type> types(elems / vecSize, vecTy);
+    SmallVector<Value> vecVals;
+    for (unsigned i = 0; i < elems; i += vecSize) {
+      Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
+      for (unsigned j = 0; j < vecSize; j++)
+        packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
+      vecVals.push_back(packed);
+    }
+
+    // This needs to be ordered the same way that
+    // ldmatrix.x4 would order it
+    // TODO: this needs to be refactor so we don't
+    // implicitly depends on how emitOffsetsForMMAV2
+    // is implemented
+    SmallVector<Value> reorderedVals;
+    for (unsigned i = 0; i < vecVals.size(); i += 4) {
+      reorderedVals.push_back(vecVals[i]);
+      reorderedVals.push_back(vecVals[i + 2]);
+      reorderedVals.push_back(vecVals[i + 1]);
+      reorderedVals.push_back(vecVals[i + 3]);
+    }
+
+    // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);
+
+    Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
+    Value view = getStructFromElements(loc, reorderedVals, rewriter, structTy);
+    rewriter.replaceOp(op, view);
+    return success();
+  }
+  return failure();
+}
+
 struct InsertSliceOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<tensor::InsertSliceOp> {
   using ConvertTritonGPUOpToLLVMPattern<
@@ -4625,6 +4643,34 @@ class ConvertTritonGPUToLLVM
     : public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
 
 private:
+  void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) {
+    // replace `mma -> dot_op` with `mma -> blocked -> dot_op`
+    // unless certain conditions are met
+    mod.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
+      OpBuilder builder(cvtOp);
+      auto srcType = cvtOp.getOperand().getType().cast<RankedTensorType>();
+      auto dstType = cvtOp.getType().cast<RankedTensorType>();
+      auto srcMma =
+          srcType.getEncoding().dyn_cast<triton::gpu::MmaEncodingAttr>();
+      auto dstDotOp =
+          dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
+      if (srcMma && dstDotOp &&
+          !ConvertLayoutOpConversion::isMmaToDotShortcut(srcMma, dstDotOp)) {
+        auto tmpType = RankedTensorType::get(
+            dstType.getShape(), dstType.getElementType(),
+            triton::gpu::BlockedEncodingAttr::get(
+                mod.getContext(), srcType.getShape(), getSizePerThread(srcMma),
+                getOrder(srcMma), numWarps));
+        auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
+            cvtOp.getLoc(), tmpType, cvtOp.getOperand());
+        auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
+            cvtOp.getLoc(), dstType, tmp);
+        cvtOp.replaceAllUsesWith(newConvert.getResult());
+        cvtOp.erase();
+      }
+    });
+  }
+
   void decomposeBlockedToDotOperand(ModuleOp mod) {
     // replace `blocked -> dot_op` with `blocked -> shared -> dot_op`
     // because the codegen doesn't handle `blocked -> dot_op` directly
@@ -4771,6 +4817,8 @@ public:
     // separation between 1/4 is that, step 3 is out of the scope of Dialect
     // Conversion, thus we need to make sure the smem is not revised during the
     // conversion of step 4.
+    decomposeMmaToDotOperand(mod, numWarps);
+
     decomposeBlockedToDotOperand(mod);
 
     decomposeInsertSliceAsyncOp(mod);
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index 95ffde257..3685eece0 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -1071,21 +1071,23 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 
 #
 # ---------------
-@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
-                         [(epilogue, allow_tf32, dtype)
+@pytest.mark.parametrize("M, N, K, epilogue, allow_tf32, dtype",
+                         [(*shape, epilogue, allow_tf32, dtype)
+                          for shape in [(64, 64, 64)]
                           for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot']
                           for allow_tf32 in [True, False]
-                          for dtype in ['float16']
+                          for dtype in ['float16', 'float32']
                           if not (allow_tf32 and (dtype in ['float16']))])
-def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
+def test_dot(M, N, K, epilogue, allow_tf32, dtype, device='cuda'):
     capability = torch.cuda.get_device_capability()
-    if capability[0] < 80:
+    if capability[0] < 8:
         if dtype == 'int8':
             pytest.skip("Only test int8 on devices with sm >= 80")
         elif dtype == 'float32' and allow_tf32:
             pytest.skip("Only test tf32 on devices with sm >= 80")
-    M, N, K = 64, 64, 64
+
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
+
     num_warps = 4
     trans_a, trans_b = False, False
 
@@ -1130,7 +1132,8 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
         if CHAIN_DOT:
             # tl.store(Zs, z)
             # tl.debug_barrier()
-            z = tl.dot(z.to(tl.float16), tl.load(Ws))
+            w = tl.load(Ws)
+            z = tl.dot(z.to(w.dtype), w)
         tl.store(Zs, z)
     # input
     rs = RandomState(17)
@@ -1180,14 +1183,18 @@ def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
         z_ref = np.matmul(z_ref, w)
     # compare
     # print(z_ref[:,0], z_tri[:,0])
-    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
+    if dtype == 'float32':
+        # XXX: Somehow there's a larger difference when we use float32
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)
+    else:
+        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
     # make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
     if allow_tf32:
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
-    elif dtype == 'float32':
+    elif dtype == 'float32' and allow_tf32:
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
     elif dtype == 'int8':
         assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
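
For reference, the snippet below is a minimal, self-contained sketch (not part of the patch) of the float32 chain-dot pattern that the new test_dot cases exercise: the accumulator of one tl.dot, which lives in an MMA layout, is fed directly into a second tl.dot as operand 0. The kernel name, the flat pointer arithmetic, and the single-program grid are illustrative assumptions rather than code from the Triton test suite; only the z = tl.dot(z.to(w.dtype), w) chaining and the 64x64x64 shape come from the test above.

# chain_dot_example.py -- hypothetical standalone example, not part of the patch
import torch
import triton
import triton.language as tl


@triton.jit
def chain_dot_kernel(X, Y, W, Z,
                     M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # A single program covers the whole problem, mirroring the 64x64x64 shape
    # used by test_dot. X is (M, K), Y is (K, N), W is (N, N), all row-major
    # and contiguous.
    off_m = tl.arange(0, M)
    off_n = tl.arange(0, N)
    off_k = tl.arange(0, K)
    x = tl.load(X + off_m[:, None] * K + off_k[None, :])
    y = tl.load(Y + off_k[:, None] * N + off_n[None, :])
    w = tl.load(W + off_n[:, None] * N + off_n[None, :])
    # First dot: the float32 accumulator ends up in an MMA layout.
    z = tl.dot(x, y)
    # Chain dot: the accumulator is consumed as operand 0 of a second dot,
    # i.e. the mma -> dot_operand conversion this patch handles for float32.
    z = tl.dot(z.to(w.dtype), w)
    tl.store(Z + off_m[:, None] * N + off_n[None, :], z)


def run_chain_dot(device='cuda'):
    M, N, K = 64, 64, 64
    x = torch.randn((M, K), dtype=torch.float32, device=device)
    y = torch.randn((K, N), dtype=torch.float32, device=device)
    w = torch.randn((N, N), dtype=torch.float32, device=device)
    z = torch.empty((M, N), dtype=torch.float32, device=device)
    chain_dot_kernel[(1,)](x, y, w, z, M=M, N=N, K=K)
    # Same tolerances as the float32 branch of the updated test.
    ref = (x @ y) @ w
    torch.testing.assert_close(z, ref, rtol=0.01, atol=1e-3)

When the mma -> dot_op conversion for the second dot does not satisfy isMmaToDotShortcut, the decomposeMmaToDotOperand step added above rewrites it into mma -> blocked -> dot_op, which is what makes a float32 chain dot like this lowerable.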