From b5aafb0dab6d8d5404581b9b6e1d53acbcccd426 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 21 Dec 2022 12:52:32 -0800 Subject: [PATCH 1/7] [FRONTEND] Fix 3d indexing (#1006) --- python/test/unit/language/test_core.py | 7 +++---- python/triton/language/core.py | 4 +--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 203995bc8..ac381d50a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -491,10 +491,9 @@ def make_ptr_str(name, shape): # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` @pytest.mark.parametrize("expr, dtype_str", [ (f'x[{s}]', d) - for s in ['None, :', ':, None'] - # FIXME: 3d indexing doesn't work - #'None, :, :', - # ':, :, None'] + for s in ['None, :', ':, None', + 'None, :, :', + ':, :, None'] for d in ['int32', 'uint32', 'uint16'] ]) def test_index1d(expr, dtype_str, device='cuda'): diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 2abc82b0c..15dd8462a 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -596,11 +596,9 @@ class tensor: if isinstance(slices, slice): slices = [slices] ret = self - n_inserted = 0 for dim, sl in enumerate(slices): if isinstance(sl, constexpr) and sl.value is None: - ret = semantic.expand_dims(ret, dim + n_inserted, _builder) - n_inserted += 1 + ret = semantic.expand_dims(ret, dim, _builder) elif sl == slice(None, None, None): pass else: From 925d3d7f98bac449a775cf183f173bb5168418d9 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Wed, 21 Dec 2022 17:57:33 -0800 Subject: [PATCH 2/7] [FRONTEND] Export `broadcast` and `broadcast_to` in `triton.language` (#1007) --- python/triton/language/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 4618794a0..0e7875e42 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -20,6 +20,8 @@ from .core import ( atomic_xor, bfloat16, block_type, + broadcast, + broadcast_to, cat, cdiv, constexpr, @@ -105,6 +107,8 @@ __all__ = [ "atomic_xor", "bfloat16", "block_type", + "broadcast", + "broadcast_to", "builtin", "cat", "cdiv", From fd2da4aff608ffada0840d17a71baf33bb7e6c88 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Thu, 22 Dec 2022 00:48:46 -0800 Subject: [PATCH 3/7] [BACKEND] Support splat constant on the DotOperandLayout (#1008) --- .../TritonGPUToLLVM/ViewOpToLLVM.cpp | 39 +++++++++++++++++++ python/test/unit/language/test_core.py | 28 ++++++------- 2 files changed, 53 insertions(+), 14 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 21dccb963..7c33f795c 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -4,11 +4,13 @@ using namespace mlir; using namespace mlir::triton; +using ::mlir::LLVM::DotOpFMAConversionHelper; using ::mlir::LLVM::DotOpMmaV1ConversionHelper; using ::mlir::LLVM::DotOpMmaV2ConversionHelper; using ::mlir::LLVM::getElementsFromStruct; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; using ::mlir::LLVM::getStructFromElements; +using ::mlir::LLVM::MMA16816ConversionHelper; using ::mlir::triton::gpu::getElemsPerThread; struct SplatOpConversion @@ -38,6 +40,11 @@ struct SplatOpConversion 
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); return getStructFromElements(loc, elems, rewriter, structTy); + } else if (auto dotLayout = + tensorTy.getEncoding() + .dyn_cast()) { + return convertSplatLikeOpWithDotOperandLayout( + dotLayout, resType, elemType, constVal, typeConverter, rewriter, loc); } else if (auto mmaLayout = tensorTy.getEncoding().dyn_cast()) { return convertSplatLikeOpWithMmaLayout( @@ -48,6 +55,38 @@ struct SplatOpConversion return {}; } + static Value convertSplatLikeOpWithDotOperandLayout( + const triton::gpu::DotOperandEncodingAttr &layout, Type resType, + Type elemType, Value constVal, TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, Location loc) { + auto tensorTy = resType.cast(); + auto shape = tensorTy.getShape(); + auto parent = layout.getParent(); + int numElems{}; + if (auto mmaLayout = parent.dyn_cast()) { + if (mmaLayout.isAmpere()) { + numElems = layout.getOpIdx() == 0 + ? MMA16816ConversionHelper::getANumElemsPerThread( + tensorTy, mmaLayout.getWarpsPerCTA()[0]) + : MMA16816ConversionHelper::getBNumElemsPerThread( + tensorTy, mmaLayout.getWarpsPerCTA()[1]); + } else if (mmaLayout.isVolta()) { + DotOpMmaV1ConversionHelper helper(mmaLayout); + numElems = layout.getOpIdx() == 0 + ? helper.numElemsPerThreadA(shape, {0, 1}) + : helper.numElemsPerThreadB(shape, {0, 1}); + } + } else if (auto blockedLayout = parent.dyn_cast()) { + numElems = DotOpFMAConversionHelper::getNumElemsPerThread(shape, layout); + } else { + assert(false && "Unsupported layout found"); + } + auto structTy = LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), SmallVector(numElems, elemType)); + return getStructFromElements(loc, SmallVector(numElems, constVal), + rewriter, structTy); + } + static Value convertSplatLikeOpWithMmaLayout( const MmaEncodingAttr &layout, Type resType, Type elemType, Value constVal, TypeConverter *typeConverter, diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ac381d50a..862687f6f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1227,20 +1227,20 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi elif dtype == 'int8': assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx -# FIXME: Unsupported layout found in ConvertSplatLikeOp -# def test_dot_without_load(): -# @triton.jit -# def kernel(out): -# pid = tl.program_id(axis=0) -# a = tl.zeros((32, 32), tl.float32) -# b = tl.zeros((32, 32), tl.float32) -# c = tl.zeros((32, 32), tl.float32) -# c = tl.dot(a, b) -# pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] -# tl.store(pout, c) -# -# out = torch.ones((32, 32), dtype=torch.float32, device="cuda") -# kernel[(1,)](out) + +def test_dot_without_load(): + @triton.jit + def kernel(out): + pid = tl.program_id(axis=0) + a = tl.zeros((32, 32), tl.float32) + b = tl.zeros((32, 32), tl.float32) + c = tl.zeros((32, 32), tl.float32) + c = tl.dot(a, b) + pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] + tl.store(pout, c) + + out = torch.ones((32, 32), dtype=torch.float32, device="cuda") + kernel[(1,)](out) # --------------- # test arange From 2ba74d27291efe014be47cc75814df972e0073a7 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Wed, 28 Dec 2022 12:24:01 +0800 Subject: [PATCH 4/7] [OPTIMIZER] Update the versionMinor in MMA layout for volta (#1014) Continue the work https://github.com/openai/triton/pull/990 # Background The 
`versionMinor` in MmaEncodingAttr holds some states of DotOp's operands in Volta, while such operands will be modified by some patterns, making the states out-of-date. This PR helps to correct the states. # Implementation It adds three new patterns: 1. `CollectMmaToUpdateForVolta` helps to collect and build a map holding the MmaEncodingAttr instances with wrong states and create new correct ones for them, 2. `UpdateMMAVersionMinorForVolta` helps to replace the Ops generating the wrong MmaEncodingAttr instances with new correct ones, currently it supports the following Ops a. `convert_layout[X -> mma]` b. `arith.constant SplatAttr : !tensor` c. `dot ... : !tensor` # Limitation This PR chooses the mapping way to bypass the IR walk complexity from the circular dependency between dot_operand[parent] and mma. We use the MmaEncodingAttr instance as the mapping key, but there might be multiple DotOp holding different DotOprand(IsMMAv1Row) that have the same wrong MmaEncodingAttr instance. To make each DotOp's (wrong) MmaEncodingAttr unique, we might need an ID field to MmaEncodingAttr. --- lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h | 44 ++-- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 227 +++++++++++++++++- test/TritonGPU/combine.mlir | 36 ++- 3 files changed, 281 insertions(+), 26 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h index fa94ba3f3..2a360f4b1 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h +++ b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h @@ -204,7 +204,12 @@ struct DotOpMmaV1ConversionHelper { offA[i] = add(mul(offA0I, strideA0), mul(offA1, strideA1)); } - Type f16x2Ty = vec_ty(f16_ty, 2); + Type elemX2Ty = vec_ty(f16_ty, 2); + Type elemPtrTy = ptr_ty(f16_ty); + if (tensorTy.getElementType().isBF16()) { + elemX2Ty = vec_ty(i16_ty, 2); + elemPtrTy = ptr_ty(i16_ty); + } // prepare arguments SmallVector ptrA(numPtrA); @@ -213,30 +218,28 @@ struct DotOpMmaV1ConversionHelper { for (int i = 0; i < numPtrA; i++) ptrA[i] = gep(ptr_ty(f16_ty), smemBase, offA[i]); - Type f16PtrTy = ptr_ty(f16_ty); - auto ld = [&](decltype(has) &vals, int m, int k, Value val0, Value val1) { vals[{m, k}] = {val0, val1}; }; auto loadA = [&](int m, int k) { int offidx = (isARow ? k / 4 : m) % numPtrA; - Value thePtrA = gep(f16PtrTy, smemBase, offA[offidx]); + Value thePtrA = gep(elemPtrTy, smemBase, offA[offidx]); int stepAM = isARow ? m : m / numPtrA * numPtrA; int stepAK = isARow ? 
k / (numPtrA * vecA) * (numPtrA * vecA) : k; Value offset = add(mul(i32_val(stepAM * strideRepM), strideAM), mul(i32_val(stepAK), strideAK)); - Value pa = gep(f16PtrTy, thePtrA, offset); + Value pa = gep(elemPtrTy, thePtrA, offset); Type aPtrTy = ptr_ty(vec_ty(i32_ty, std::max(vecA / 2, 1)), 3); Value ha = load(bitcast(pa, aPtrTy)); // record lds that needs to be moved - Value ha00 = bitcast(extract_element(ha, i32_val(0)), f16x2Ty); - Value ha01 = bitcast(extract_element(ha, i32_val(1)), f16x2Ty); + Value ha00 = bitcast(extract_element(ha, i32_val(0)), elemX2Ty); + Value ha01 = bitcast(extract_element(ha, i32_val(1)), elemX2Ty); ld(has, m, k, ha00, ha01); if (vecA > 4) { - Value ha10 = bitcast(extract_element(ha, i32_val(2)), f16x2Ty); - Value ha11 = bitcast(extract_element(ha, i32_val(3)), f16x2Ty); + Value ha10 = bitcast(extract_element(ha, i32_val(2)), elemX2Ty); + Value ha11 = bitcast(extract_element(ha, i32_val(3)), elemX2Ty); if (isARow) ld(has, m, k + 4, ha10, ha11); else @@ -256,7 +259,7 @@ struct DotOpMmaV1ConversionHelper { elems.push_back(item.second.second); } - Type resTy = struct_ty(SmallVector(elems.size(), f16x2Ty)); + Type resTy = struct_ty(SmallVector(elems.size(), elemX2Ty)); Value res = getStructFromElements(loc, elems, rewriter, resTy); return res; } @@ -319,8 +322,12 @@ struct DotOpMmaV1ConversionHelper { offB[i] = add(mul(offB0I, strideB0), mul(offB1, strideB1)); } - Type f16PtrTy = ptr_ty(f16_ty); - Type f16x2Ty = vec_ty(f16_ty, 2); + Type elemPtrTy = ptr_ty(f16_ty); + Type elemX2Ty = vec_ty(f16_ty, 2); + if (tensorTy.getElementType().isBF16()) { + elemPtrTy = ptr_ty(i16_ty); + elemX2Ty = vec_ty(i16_ty, 2); + } SmallVector ptrB(numPtrB); ValueTable hbs; @@ -339,17 +346,17 @@ struct DotOpMmaV1ConversionHelper { int stepBK = isBRow ? 
K : K / (numPtrB * vecB) * (numPtrB * vecB); Value offset = add(mul(i32_val(stepBN * strideRepN), strideBN), mul(i32_val(stepBK), strideBK)); - Value pb = gep(f16PtrTy, thePtrB, offset); + Value pb = gep(elemPtrTy, thePtrB, offset); Value hb = load(bitcast(pb, ptr_ty(vec_ty(i32_ty, std::max(vecB / 2, 1)), 3))); // record lds that needs to be moved - Value hb00 = bitcast(extract_element(hb, i32_val(0)), f16x2Ty); - Value hb01 = bitcast(extract_element(hb, i32_val(1)), f16x2Ty); + Value hb00 = bitcast(extract_element(hb, i32_val(0)), elemX2Ty); + Value hb01 = bitcast(extract_element(hb, i32_val(1)), elemX2Ty); ld(hbs, n, K, hb00, hb01); if (vecB > 4) { - Value hb10 = bitcast(extract_element(hb, i32_val(2)), f16x2Ty); - Value hb11 = bitcast(extract_element(hb, i32_val(3)), f16x2Ty); + Value hb10 = bitcast(extract_element(hb, i32_val(2)), elemX2Ty); + Value hb11 = bitcast(extract_element(hb, i32_val(3)), elemX2Ty); if (isBRow) ld(hbs, n + 1, K, hb10, hb11); else @@ -369,8 +376,7 @@ struct DotOpMmaV1ConversionHelper { elems.push_back(item.second.first); elems.push_back(item.second.second); } - Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2); - Type resTy = struct_ty(SmallVector(elems.size(), fp16x2Ty)); + Type resTy = struct_ty(SmallVector(elems.size(), elemX2Ty)); Value res = getStructFromElements(loc, elems, rewriter, resTy); return res; } diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 7dcdc0162..5e391e3c1 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -22,6 +22,10 @@ using namespace mlir; namespace { #include "TritonGPUCombine.inc" +using triton::DotOp; +using triton::gpu::ConvertLayoutOp; +using triton::gpu::DotOperandEncodingAttr; +using triton::gpu::MmaEncodingAttr; // ----------------------------------------------------------------------------- // @@ -1019,6 +1023,7 @@ public: dstDotOperandLayout.getIsMMAv1Row().cast().getValue(); if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row)) return failure(); + auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row); auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get( op->getContext(), dstDotOperandLayout.getOpIdx(), @@ -1060,7 +1065,8 @@ public: auto dotOp = cast(op); // TODO: Check data-types and SM compatibility auto oldRetType = dotOp.getResult().getType().cast(); - if (oldRetType.getEncoding().isa()) + if (!oldRetType.getEncoding() || + oldRetType.getEncoding().isa()) return failure(); auto AType = dotOp.getOperand(0).getType().cast(); @@ -1170,7 +1176,8 @@ public: for (size_t i = 0; i < newInitArgs.size(); i++) { auto initArg = newInitArgs[i]; auto regionArg = forOp.getRegionIterArgs()[i]; - if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType()) { + if (newInitArgs[i].getType() != forOp.getRegionIterArgs()[i].getType() || + newInitArgs[i].getType() != forOp.getResultTypes()[i]) { shouldRematerialize = true; break; } @@ -1186,15 +1193,207 @@ public: BlockAndValueMapping mapping; for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); for (Operation &op : forOp.getBody()->getOperations()) { - Operation *newOp = rewriter.clone(op, mapping); + rewriter.clone(op, mapping); } rewriter.replaceOp(forOp, newForOp.getResults()); return success(); } }; +// This pattern collects the wrong Mma those need to update and create the right 
+// ones for each. +class CollectMmaToUpdateForVolta : public mlir::RewritePattern { + DenseMap &mmaToUpdate; + +public: + CollectMmaToUpdateForVolta( + mlir::MLIRContext *ctx, + DenseMap &mmaToUpdate) + : mlir::RewritePattern(triton::DotOp::getOperationName(), 1, ctx), + mmaToUpdate(mmaToUpdate) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + + auto dotOp = cast(op); + auto *ctx = dotOp->getContext(); + auto AT = dotOp.a().getType().cast(); + auto BT = dotOp.b().getType().cast(); + auto DT = dotOp.d().getType().cast(); + if (!DT.getEncoding()) + return failure(); + auto mmaLayout = DT.getEncoding().dyn_cast(); + if (!(mmaLayout && mmaLayout.isVolta())) + return failure(); + + // Has processed. + if (mmaToUpdate.count(mmaLayout)) + return failure(); + + auto dotOperandA = AT.getEncoding().cast(); + auto dotOperandB = BT.getEncoding().cast(); + bool isARow = dotOperandA.getIsMMAv1Row().cast().getValue(); + bool isBRow = dotOperandB.getIsMMAv1Row().cast().getValue(); + auto [isARow_, isBRow_, isAVec4, isBVec4] = + mmaLayout.decodeVoltaLayoutStates(); + if (isARow_ == isARow && isBRow_ == isBRow) { + return failure(); // No need to update + } + + auto newMmaLayout = MmaEncodingAttr::get( + ctx, mmaLayout.getVersionMajor(), mmaLayout.getWarpsPerCTA(), + AT.getShape(), BT.getShape(), isARow, isBRow); + + // Collect the wrong MMA Layouts, and mark need to update. + mmaToUpdate.try_emplace(mmaLayout, newMmaLayout); + + return failure(); + } +}; + +// Correct the versionMinor field in MmaEncodingAttr for Volta. +class UpdateMMAVersionMinorForVolta : public mlir::RewritePattern { + const DenseMap &mmaToUpdate; + enum class Kind { + kUnk, + kCvtToMma, + kCvtToDotOp, + kDot, + kConstant, + }; + mutable Kind rewriteKind{Kind::kUnk}; + +public: + UpdateMMAVersionMinorForVolta( + mlir::MLIRContext *ctx, llvm::StringRef opName, + const DenseMap &mmaToUpdate) + : RewritePattern(opName, 1 /*benefit*/, ctx), mmaToUpdate(mmaToUpdate) {} + + LogicalResult match(Operation *op) const override { + MmaEncodingAttr mma; + if (mmaToUpdate.empty()) + return failure(); + if (op->getNumResults() != 1) + return failure(); + auto tensorTy = op->getResult(0).getType().dyn_cast(); + if (!tensorTy) + return failure(); + + // ConvertLayoutOp + if (auto cvt = llvm::dyn_cast(op)) { + // cvt X -> dot_operand + if (auto dotOperand = + tensorTy.getEncoding().dyn_cast()) { + mma = dotOperand.getParent().dyn_cast(); + rewriteKind = Kind::kCvtToDotOp; + if (mma && mmaToUpdate.count(mma)) + return success(); + } + if ((mma = tensorTy.getEncoding().dyn_cast())) { + // cvt X -> mma + rewriteKind = Kind::kCvtToMma; + if (mma && mmaToUpdate.count(mma)) + return success(); + } + } else if (auto dot = llvm::dyn_cast(op)) { + // DotOp + mma = dot.d() + .getType() + .cast() + .getEncoding() + .dyn_cast(); + rewriteKind = Kind::kDot; + } else if (auto constant = llvm::dyn_cast(op)) { + // ConstantOp + mma = tensorTy.getEncoding().dyn_cast(); + rewriteKind = Kind::kConstant; + } + + return success(mma && mmaToUpdate.count(mma)); + } + + void rewrite(Operation *op, PatternRewriter &rewriter) const override { + switch (rewriteKind) { + case Kind::kDot: + rewriteDot(op, rewriter); + break; + case Kind::kConstant: + rewriteConstant(op, rewriter); + break; + case Kind::kCvtToDotOp: + rewriteCvtDotOp(op, rewriter); + break; + case Kind::kCvtToMma: + rewriteCvtToMma(op, rewriter); + break; + default: + llvm::report_fatal_error("Not supported rewrite kind"); + } + } + +private: + 
void rewriteCvtDotOp(Operation *op, PatternRewriter &rewriter) const { + auto *ctx = op->getContext(); + auto cvt = llvm::cast(op); + auto tensorTy = cvt.result().getType().cast(); + auto dotOperand = tensorTy.getEncoding().cast(); + MmaEncodingAttr newMma = + mmaToUpdate.lookup(dotOperand.getParent().cast()); + auto newDotOperand = DotOperandEncodingAttr::get( + ctx, dotOperand.getOpIdx(), newMma, dotOperand.getIsMMAv1Row()); + auto newTensorTy = RankedTensorType::get( + tensorTy.getShape(), tensorTy.getElementType(), newDotOperand); + rewriter.replaceOpWithNewOp(op, newTensorTy, + cvt.getOperand()); + } + + void rewriteDot(Operation *op, PatternRewriter &rewriter) const { + auto *ctx = op->getContext(); + auto dot = llvm::cast(op); + auto tensorTy = dot.d().getType().cast(); + auto mma = tensorTy.getEncoding().cast(); + auto newMma = mmaToUpdate.lookup(mma); + auto newTensorTy = RankedTensorType::get(tensorTy.getShape(), + tensorTy.getElementType(), newMma); + rewriter.replaceOpWithNewOp(op, newTensorTy, dot.a(), dot.b(), + dot.c(), dot.allowTF32()); + } + + void rewriteCvtToMma(Operation *op, PatternRewriter &rewriter) const { + auto *ctx = op->getContext(); + auto cvt = llvm::cast(op); + auto tensorTy = cvt.result().getType().cast(); + auto mma = tensorTy.getEncoding().cast(); + auto newMma = mmaToUpdate.lookup(mma); + auto newTensorTy = RankedTensorType::get(tensorTy.getShape(), + tensorTy.getElementType(), newMma); + rewriter.replaceOpWithNewOp(op, newTensorTy, + cvt.getOperand()); + } + + void rewriteConstant(Operation *op, PatternRewriter &rewriter) const { + auto *ctx = op->getContext(); + auto constant = llvm::cast(op); + auto tensorTy = constant.getResult().getType().dyn_cast(); + auto mma = tensorTy.getEncoding().cast(); + auto newMma = mmaToUpdate.lookup(mma); + auto newTensorTy = RankedTensorType::get(tensorTy.getShape(), + tensorTy.getElementType(), newMma); + if (auto attr = constant.getValue().dyn_cast()) { + auto newRet = + SplatElementsAttr::get(newTensorTy, attr.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newTensorTy, newRet); + return; + } + + assert(false && "Not supported ConstantOp value type"); + } +}; + } // namespace #define GEN_PASS_CLASSES @@ -1229,6 +1428,28 @@ public: signalPassFailure(); } + llvm::DenseMap mmaToUpdate; + { + mlir::RewritePatternSet patterns(context); + patterns.add(context, mmaToUpdate); + if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) + signalPassFailure(); + } + { + mlir::RewritePatternSet patterns(context); + patterns.add( + context, DotOp::getOperationName(), mmaToUpdate); + patterns.add( + context, ConvertLayoutOp::getOperationName(), mmaToUpdate); + patterns.add( + context, arith::ConstantOp::getOperationName(), mmaToUpdate); + mlir::GreedyRewriteConfig config; + config.useTopDownTraversal = true; + + if (applyPatternsAndFoldGreedily(m, std::move(patterns), config).failed()) + signalPassFailure(); + } + mlir::RewritePatternSet loopFixup(context); loopFixup.add(context); if (applyPatternsAndFoldGreedily(m, std::move(loopFixup)).failed()) { diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index b4d2da376..a040873d8 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file -tritongpu-combine 2>&1 | FileCheck %s #layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #layout1 = 
#triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> @@ -7,7 +7,6 @@ // CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> // CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> - func @cst() -> tensor<1024xi32, #layout1> { %cst = arith.constant dense<0> : tensor<1024xi32, #layout0> %1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> @@ -62,9 +61,9 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { // CHECK-LABEL: transpose func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { // CHECK-NOT: triton_gpu.convert_layout - // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]> + // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]> // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]> - // CHECK: tt.store {{.*}}, [[cvt_val]], %cst_1 : tensor<64x64xf32, [[col_layout]]> + // CHECK: tt.store {{.*}}, [[cvt_val]], {{%cst.*}} : tensor<64x64xf32, [[col_layout]]> // CHECK: return %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1> %cst_0 = arith.constant dense : tensor<64x64xi1, #blocked1> @@ -184,3 +183,32 @@ func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr return } + + +// ----- + +// check the UpdateMMAVersionMinorForVolta pattern +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}> +#mma0 = #triton_gpu.mma<{versionMajor=1, versionMinor=0, warpsPerCTA=[1,1]}> +// Here, the isMMAv1Row of a and b's dot_operands mismatch #mma0's versionMinor, +// and the pattern should update the versionMinor. 
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, isMMAv1Row=true}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, isMMAv1Row=false}> +// It creates a new MMA layout to fit with $a and $b's dot_operand +// CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 11, warpsPerCTA = [1, 1]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: dot_mmav1 + func @dot_mmav1(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) -> tensor<16x16xf32, #blocked0> { + %C = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #blocked0> + %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_a> + %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #dot_operand_b> + %CC = triton_gpu.convert_layout %C : (tensor<16x16xf32, #blocked0>) -> tensor<16x16xf32, #mma0> + + // CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, %cst {allowTF32 = true} : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[new_mma]], isMMAv1Row = true}>> * tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[new_mma]], isMMAv1Row = true}>> -> tensor<16x16xf32, [[new_mma]]> + %D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> + %res = triton_gpu.convert_layout %D : (tensor<16x16xf32, #mma0>) -> tensor<16x16xf32, #blocked0> + + return %res : tensor<16x16xf32, #blocked0> + } +} From 1d3029faf8e4bff85b1d203028cf72203ad8f7e5 Mon Sep 17 00:00:00 2001 From: goostavz <109190422+goostavz@users.noreply.github.com> Date: Fri, 30 Dec 2022 03:19:59 +0800 Subject: [PATCH 5/7] [Backend] Add value cache in emitting indices calculation and some refinement (#1018) 1, add explicit value cache in emitting indices calculation; 2, move the indices calculation emitting logics into ConvertTritonGPUOpToLLVMPatternBase to avoid the redundant build cost by templates. 
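The caches are keyed on a (layout attribute, tensor shape) pair and are consulted before any index IR is emitted; on a miss the computation is emitted once at a remembered insertion point and recorded for reuse. A condensed sketch of that lookup-or-compute flow, boiled down from the new emitBaseIndexForLayout in TritonGPUToLLVMBase.h further down in this patch (the MMA branches, assertions and the analogous emitIndices path are elided, so treat it as illustrative rather than the exact code):

    SmallVector<Value>
    emitBaseIndexForLayout(Location loc, ConversionPatternRewriter &rewriter,
                           const Attribute &layout, ArrayRef<int64_t> shape) const {
      IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));
      auto *cache = indexCacheInfo.baseIndexCache; // shared by every pattern in the pass
      if (cache->count(key) > 0)
        return cache->lookup(key); // hit: reuse the SSA values emitted earlier
      // miss: emit the index computation once, at a remembered insertion point,
      // then record both the result and the updated insertion point
      ConversionPatternRewriter::InsertionGuard guard(rewriter);
      restoreInsertionPointIfSet(indexCacheInfo.indexInsertPoint, rewriter);
      SmallVector<Value> result;
      if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>())
        result = emitBaseIndexForBlockedLayout(loc, rewriter, blocked, shape);
      // ... MMA v1 / v2 layouts handled analogously ...
      cache->insert(std::make_pair(key, result));
      *indexCacheInfo.indexInsertPoint = rewriter.saveInsertionPoint();
      return result;
    }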
Refer to the discussion in this thread by @LyricZhao : https://triton-lang.slack.com/archives/C042VBSQWNS/p1671336755922969 --- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 15 +- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.h | 4 +- .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 31 +- .../TritonGPUToLLVM/LoadStoreOpToLLVM.h | 12 +- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 16 +- .../TritonGPUToLLVM/ReduceOpToLLVM.h | 11 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 24 +- .../TritonGPUToLLVM/TritonGPUToLLVM.h | 11 +- .../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 323 ++++++++++++------ .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp | 19 +- test/Conversion/tritongpu_to_llvm.mlir | 69 +++- 11 files changed, 355 insertions(+), 180 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 1a85ca5da..cd18ed751 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -53,9 +53,6 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef srcStrides, auto wordTy = vec_ty(elemTy, minVec); auto elemPtrTy = ptr_ty(elemTy); - // TODO: [goostavz] We should make a cache for the calculation of - // emitBaseIndexForBlockedLayout in case backend compiler not being able to - // optimize that SmallVector srcShapePerCTA = getShapePerCTA(srcBlockedLayout); SmallVector reps{ceil(srcShape[0], srcShapePerCTA[0]), ceil(srcShape[1], srcShapePerCTA[1])}; @@ -182,7 +179,7 @@ private: unsigned rank = shape.size(); if (auto blockedLayout = layout.dyn_cast()) { auto multiDimOffsetFirstElem = - emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); + emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape); SmallVector multiDimOffset(rank); SmallVector multiDimElemId = getMultiDimIndex( elemId, getSizePerThread(layout), getOrder(layout)); @@ -501,8 +498,8 @@ private: auto srcStrides = getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter); - auto srcIndices = emitBaseIndexForBlockedLayout(loc, rewriter, - srcBlockedLayout, srcShape); + auto srcIndices = + emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape); storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst, smemBase, elemTy, loc, rewriter); @@ -680,7 +677,9 @@ private: void populateConvertLayoutOpToLLVMPatterns( mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, PatternBenefit benefit) { + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) { patterns.add(typeConverter, allocation, smem, - benefit); + indexCacheInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h index ebf943b6f..ec435b2ab 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h @@ -19,6 +19,8 @@ void storeBlockedToShared(Value src, Value llSrc, ArrayRef srcStrides, void populateConvertLayoutOpToLLVMPatterns( mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, PatternBenefit benefit); + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit); #endif 
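Note that every populate*Patterns entry point in this patch gains the same IndexCacheInfo & parameter; the conversion pass owns the two cache maps plus the shared insertion point and hands a single IndexCacheInfo to each pattern family, roughly as TritonGPUToLLVMPass.cpp does later in this diff:

    OpBuilder::InsertPoint indexInsertPoint;
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
        &baseIndexCache, &indexCache, &indexInsertPoint};

    populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
                                    axisInfoAnalysis, &allocation, smem,
                                    indexCacheInfo, /*benefit=*/10);
    populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
                                      axisInfoAnalysis, &allocation, smem,
                                      indexCacheInfo, /*benefit=*/10);

Because the maps are members of the pass, cached indices are shared across patterns within a single conversion rather than globally.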
diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index 97ce9457a..92b11a94c 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -14,7 +14,7 @@ using ::mlir::triton::gpu::getElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; // Contains some helper functions for both Load and Store conversions. -struct LoadStoreConversionBase : public ConvertTritonGPUOpToLLVMPatternBase { +struct LoadStoreConversionBase { explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass) : axisAnalysisPass(axisAnalysisPass) {} @@ -640,7 +640,7 @@ struct InsertSliceOpConversion auto llSrc = adaptor.source(); auto srcIndices = - emitBaseIndexForBlockedLayout(loc, rewriter, srcLayout, srcShape); + emitBaseIndexForLayout(loc, rewriter, srcLayout, srcShape); storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase, elemTy, loc, rewriter); // Barrier is not necessary. @@ -657,12 +657,12 @@ struct InsertSliceAsyncOpConversion using ConvertTritonGPUOpToLLVMPattern< triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern; - InsertSliceAsyncOpConversion(LLVMTypeConverter &converter, - const Allocation *allocation, Value smem, - AxisInfoAnalysis &axisAnalysisPass, - PatternBenefit benefit) + InsertSliceAsyncOpConversion( + LLVMTypeConverter &converter, const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( - converter, allocation, smem, benefit), + converter, allocation, smem, indexCacheInfo, benefit), LoadStoreConversionBase(axisAnalysisPass) {} LogicalResult @@ -865,12 +865,12 @@ struct InsertSliceAsyncOpConversion } }; -void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit) { +void populateLoadStoreOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) { patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, allocation, smem, @@ -878,7 +878,8 @@ void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, patterns.add(typeConverter, allocation, smem, axisInfoAnalysis, benefit); patterns.add(typeConverter, allocation, smem, - benefit); + indexCacheInfo, benefit); patterns.add(typeConverter, allocation, smem, - axisInfoAnalysis, benefit); + indexCacheInfo, axisInfoAnalysis, + benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h index 96c2f1afd..b5042019e 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h @@ -6,11 +6,11 @@ using namespace mlir; using namespace mlir::triton; -void populateLoadStoreOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit); +void 
populateLoadStoreOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 5f055fa6f..69abd889b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -164,7 +164,7 @@ private: auto srcValues = getElementsFromStruct(loc, adaptor.operand(), rewriter); SmallVector> offset = - emitOffsetForBlockedLayout(srcLayout, srcShape); + emitOffsetForLayout(srcLayout, srcShape); std::map, Value> accs; std::map, Value> accIndices; @@ -479,10 +479,12 @@ private: } }; -void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit) { - patterns.add(typeConverter, allocation, smem, benefit); +void populateReduceOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) { + patterns.add(typeConverter, allocation, smem, + indexCacheInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h index fc4c5145c..f2c0af463 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h @@ -6,10 +6,11 @@ using namespace mlir; using namespace mlir::triton; -void populateReduceOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit); +void populateReduceOpToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit); #endif \ No newline at end of file diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 3fcb83d95..2261688f0 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -63,6 +63,7 @@ struct BroadcastOpConversion auto srcShape = srcTy.getShape(); auto resultShape = resultTy.getShape(); unsigned rank = srcTy.getRank(); + assert(rank == resultTy.getRank()); auto order = triton::gpu::getOrder(srcLayout); auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape); @@ -272,9 +273,13 @@ struct PrintfOpConversion struct MakeRangeOpConversion : public ConvertTritonGPUOpToLLVMPattern { - MakeRangeOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertTritonGPUOpToLLVMPattern(converter, - benefit) {} + MakeRangeOpConversion( + LLVMTypeConverter &converter, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern( + converter, /*Allocation*/ nullptr, Value{}, indexCacheInfo, + benefit) {} LogicalResult 
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, @@ -500,11 +505,12 @@ struct AsyncWaitOpConversion } }; -void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit) { +void populateTritonGPUToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit) { patterns.add(typeConverter, benefit); patterns.add(typeConverter, allocation, smem, benefit); @@ -515,7 +521,7 @@ void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); + patterns.add(typeConverter, indexCacheInfo, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); } \ No newline at end of file diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h index e96330176..2a6e22bf0 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h @@ -6,10 +6,11 @@ using namespace mlir; using namespace mlir::triton; -void populateTritonGPUToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, - PatternBenefit benefit); +void populateTritonGPUToLLVMPatterns( + mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + int numWarps, AxisInfoAnalysis &axisInfoAnalysis, + const Allocation *allocation, Value smem, + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, + PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 6020c9617..7f11a762e 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -18,7 +18,6 @@ using ::mlir::LLVM::SharedMemoryObject; using ::mlir::triton::gpu::BlockedEncodingAttr; using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; - // FuncOpConversion/FuncOpConversionBase is borrowed from // https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L276 // since it is not exposed on header files in mlir v14 @@ -128,7 +127,60 @@ protected: } }; -struct ConvertTritonGPUOpToLLVMPatternBase { +using IndexCacheKeyT = std::pair>; + +struct CacheKeyDenseMapInfo { + static IndexCacheKeyT getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return std::make_pair( + mlir::Attribute(static_cast(pointer)), + SmallVector{}); + } + static IndexCacheKeyT getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return std::make_pair( + mlir::Attribute(static_cast(pointer)), + SmallVector{std::numeric_limits::max()}); + } + static unsigned getHashValue(IndexCacheKeyT key) { + return llvm::hash_combine( + mlir::hash_value(key.first), + llvm::hash_combine_range(key.second.begin(), key.second.end())); + } + static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) { + return LHS == RHS; + } +}; + 
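// CacheKeyDenseMapInfo above is what allows the (layout attribute, shape) pair to be
// used as a llvm::DenseMap key: DenseMap needs reserved empty/tombstone sentinel keys
// plus hashing and equality for its key type, and this pair type has no default
// DenseMapInfo (SmallVector does not provide one), hence the explicit third template
// argument wherever the caches are declared. The sentinels are synthesized from the
// reserved pointer values of the pointer DenseMapInfo, so they can never equal a real
// layout attribute. The caches themselves live in TritonGPUToLLVMPass.cpp further
// down, roughly:
//   DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo> baseIndexCache;
//   DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>, CacheKeyDenseMapInfo> indexCache;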
+class ConvertTritonGPUOpToLLVMPatternBase { +public: + // Two levels of value cache in emitting indices calculation: + // Key: pair + struct IndexCacheInfo { + DenseMap, CacheKeyDenseMapInfo> + *baseIndexCache; + DenseMap>, + CacheKeyDenseMapInfo> *indexCache; + OpBuilder::InsertPoint *indexInsertPoint; + }; + + explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter) + : converter(&typeConverter) {} + + explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter, + const Allocation *allocation, + Value smem) + : converter(&typeConverter), allocation(allocation), smem(smem) {} + + explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter, + const Allocation *allocation, + Value smem, + IndexCacheInfo indexCacheInfo) + : converter(&typeConverter), indexCacheInfo(indexCacheInfo), + allocation(allocation), smem(smem) {} + + LLVMTypeConverter *getTypeConverter() const { return converter; } + static Value getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, @@ -139,25 +191,6 @@ struct ConvertTritonGPUOpToLLVMPatternBase { LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types); return getStructFromElements(loc, elems, rewriter, structTy); } -}; - -template -class ConvertTritonGPUOpToLLVMPattern - : public ConvertOpToLLVMPattern, - public ConvertTritonGPUOpToLLVMPatternBase { -public: - using OpAdaptor = typename SourceOp::Adaptor; - - explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit) {} - - explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, - const Allocation *allocation, - Value smem, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit), - allocation(allocation), smem(smem) {} Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const { auto llvmIndexTy = this->getTypeConverter()->getIndexType(); @@ -169,6 +202,23 @@ public: return threadId; } + // ----------------------------------------------------------------------- + // Shared memory utilities + // ----------------------------------------------------------------------- + template + Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter, + T value) const { + + auto ptrTy = LLVM::LLVMPointerType::get( + this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); + auto bufferId = allocation->getBufferId(value); + assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); + size_t offset = allocation->getOffset(bufferId); + Value offVal = idx_val(offset); + Value base = gep(ptrTy, smem, offVal); + return base; + } + // ----------------------------------------------------------------------- // Utilities // ----------------------------------------------------------------------- @@ -242,6 +292,116 @@ public: return ret; } + struct SmallVectorKeyInfo { + static unsigned getHashValue(const SmallVector &key) { + return llvm::hash_combine_range(key.begin(), key.end()); + } + static bool isEqual(const SmallVector &lhs, + const SmallVector &rhs) { + return lhs == rhs; + } + static SmallVector getEmptyKey() { + return SmallVector(); + } + static SmallVector getTombstoneKey() { + return {std::numeric_limits::max()}; + } + }; + + // ----------------------------------------------------------------------- + // Get offsets / indices for any layout + // ----------------------------------------------------------------------- + + SmallVector 
emitBaseIndexForLayout(Location loc, + ConversionPatternRewriter &rewriter, + const Attribute &layout, + ArrayRef shape) const { + IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape)); + auto cache = indexCacheInfo.baseIndexCache; + assert(cache && "baseIndexCache is nullptr"); + auto insertPt = indexCacheInfo.indexInsertPoint; + if (cache->count(key) > 0) { + return cache->lookup(key); + } else { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + restoreInsertionPointIfSet(insertPt, rewriter); + SmallVector result; + if (auto blockedLayout = layout.dyn_cast()) { + result = + emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); + } else if (auto mmaLayout = layout.dyn_cast()) { + if (mmaLayout.isVolta()) + result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape); + if (mmaLayout.isAmpere()) + result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape); + } else { + llvm_unreachable("unsupported emitBaseIndexForLayout"); + } + cache->insert(std::make_pair(key, result)); + *insertPt = rewriter.saveInsertionPoint(); + return result; + } + } + + SmallVector> + emitOffsetForLayout(const Attribute &layout, ArrayRef shape) const { + if (auto blockedLayout = layout.dyn_cast()) + return emitOffsetForBlockedLayout(blockedLayout, shape); + if (auto mmaLayout = layout.dyn_cast()) { + if (mmaLayout.isVolta()) + return emitOffsetForMmaLayoutV1(mmaLayout, shape); + if (mmaLayout.isAmpere()) + return emitOffsetForMmaLayoutV2(mmaLayout, shape); + } + llvm_unreachable("unsupported emitOffsetForLayout"); + } + + // ----------------------------------------------------------------------- + // Emit indices + // ----------------------------------------------------------------------- + SmallVector> emitIndices(Location loc, + ConversionPatternRewriter &b, + const Attribute &layout, + ArrayRef shape) const { + IndexCacheKeyT key(layout, llvm::to_vector(shape)); + auto cache = indexCacheInfo.indexCache; + assert(cache && "indexCache is nullptr"); + auto insertPt = indexCacheInfo.indexInsertPoint; + if (cache->count(key) > 0) { + return cache->lookup(key); + } else { + ConversionPatternRewriter::InsertionGuard guard(b); + restoreInsertionPointIfSet(insertPt, b); + SmallVector> result; + if (auto blocked = layout.dyn_cast()) { + result = emitIndicesForDistributedLayout(loc, b, blocked, shape); + } else if (auto mma = layout.dyn_cast()) { + result = emitIndicesForDistributedLayout(loc, b, mma, shape); + } else if (auto slice = layout.dyn_cast()) { + result = emitIndicesForSliceLayout(loc, b, slice, shape); + } else { + llvm_unreachable( + "emitIndices for layouts other than blocked & slice not " + "implemented yet"); + } + cache->insert(std::make_pair(key, result)); + *insertPt = b.saveInsertionPoint(); + return result; + } + } + +private: + void restoreInsertionPointIfSet(OpBuilder::InsertPoint *insertPt, + ConversionPatternRewriter &rewriter) const { + if (insertPt->isSet()) { + rewriter.restoreInsertionPoint(*insertPt); + } else { + auto func = + rewriter.getInsertionPoint()->getParentOfType(); + rewriter.setInsertionPointToStart(&func.getBody().front()); + } + } + // ----------------------------------------------------------------------- // Blocked layout indices // ----------------------------------------------------------------------- @@ -411,38 +571,6 @@ public: return ret; } - // ----------------------------------------------------------------------- - // Get offsets / indices for any layout - // 
----------------------------------------------------------------------- - - SmallVector emitBaseIndexForLayout(Location loc, - ConversionPatternRewriter &rewriter, - const Attribute &layout, - ArrayRef shape) const { - if (auto blockedLayout = layout.dyn_cast()) - return emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape); - if (auto mmaLayout = layout.dyn_cast()) { - if (mmaLayout.isVolta()) - return emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape); - if (mmaLayout.isAmpere()) - return emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape); - } - llvm_unreachable("unsupported emitBaseIndexForLayout"); - } - - SmallVector> - emitOffsetForLayout(const Attribute &layout, ArrayRef shape) const { - if (auto blockedLayout = layout.dyn_cast()) - return emitOffsetForBlockedLayout(blockedLayout, shape); - if (auto mmaLayout = layout.dyn_cast()) { - if (mmaLayout.isVolta()) - return emitOffsetForMmaLayoutV1(mmaLayout, shape); - if (mmaLayout.isAmpere()) - return emitOffsetForMmaLayoutV2(mmaLayout, shape); - } - llvm_unreachable("unsupported emitOffsetForLayout"); - } - // Emit indices calculation within each ConversionPattern, and returns a // [elemsPerThread X rank] index matrix. @@ -470,22 +598,6 @@ public: return multiDimIdx; } - struct SmallVectorKeyInfo { - static unsigned getHashValue(const SmallVector &key) { - return llvm::hash_combine_range(key.begin(), key.end()); - } - static bool isEqual(const SmallVector &lhs, - const SmallVector &rhs) { - return lhs == rhs; - } - static SmallVector getEmptyKey() { - return SmallVector(); - } - static SmallVector getTombstoneKey() { - return {std::numeric_limits::max()}; - } - }; - SmallVector> emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter, const SliceEncodingAttr &sliceLayout, @@ -505,46 +617,45 @@ public: return resultIndices; } - // ----------------------------------------------------------------------- - // Emit indices - // ----------------------------------------------------------------------- - SmallVector> emitIndices(Location loc, - ConversionPatternRewriter &b, - const Attribute &layout, - ArrayRef shape) const { - if (auto blocked = layout.dyn_cast()) { - return emitIndicesForDistributedLayout(loc, b, blocked, shape); - } else if (auto mma = layout.dyn_cast()) { - return emitIndicesForDistributedLayout(loc, b, mma, shape); - } else if (auto slice = layout.dyn_cast()) { - return emitIndicesForSliceLayout(loc, b, slice, shape); - } else { - assert(0 && "emitIndices for layouts other than blocked & slice not " - "implemented yet"); - return {}; - } - } - - // ----------------------------------------------------------------------- - // Shared memory utilities - // ----------------------------------------------------------------------- - template - Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter, - T value) const { - - auto ptrTy = LLVM::LLVMPointerType::get( - this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); - auto bufferId = allocation->getBufferId(value); - assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); - size_t offset = allocation->getOffset(bufferId); - Value offVal = idx_val(offset); - Value base = gep(ptrTy, smem, offVal); - return base; - } - protected: + LLVMTypeConverter *converter; const Allocation *allocation; Value smem; + IndexCacheInfo indexCacheInfo; +}; + +template +class ConvertTritonGPUOpToLLVMPattern + : public ConvertOpToLLVMPattern, + public ConvertTritonGPUOpToLLVMPatternBase { +public: + 
using OpAdaptor = typename SourceOp::Adaptor; + + explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {} + + explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const Allocation *allocation, + Value smem, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {} + + explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const Allocation *allocation, + Value smem, + IndexCacheInfo indexCacheInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem, + indexCacheInfo) {} + +protected: + LLVMTypeConverter *getTypeConverter() const { + return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter(); + } }; #endif diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 797cd6f6d..897ab913d 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -170,16 +170,20 @@ public: // We set a higher benefit here to ensure triton's patterns runs before // arith patterns for some encoding not supported by the community // patterns. + OpBuilder::InsertPoint indexInsertPoint; + ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{ + &baseIndexCache, &indexCache, &indexInsertPoint}; + RewritePatternSet patterns(context); // Normal conversions populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, - /*benefit=*/10); + indexCacheInfo, /*benefit=*/10); // ConvertLayoutOp populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, - /*benefit=*/10); + indexCacheInfo, /*benefit=*/10); // DotOp populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, @@ -191,11 +195,11 @@ public: // LoadStoreOp populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, - /*benefit=*/10); + indexCacheInfo, /*benefit=*/10); // ReduceOp populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, - /*benefit=*/10); + indexCacheInfo, /*benefit=*/10); // ViewOp populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps, axisInfoAnalysis, &allocation, smem, @@ -215,6 +219,13 @@ public: private: Value smem; + using IndexCacheKeyT = std::pair>; + DenseMap, CacheKeyDenseMapInfo> + baseIndexCache; + DenseMap>, + CacheKeyDenseMapInfo> + indexCache; + int computeCapability{}; void initSharedMemory(size_t size, diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index abc5e9a31..e49a231bf 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -997,20 +997,61 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { // ----- #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { - -func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { - // CHECK: nvvm.read.ptx.sreg.nctaid.x - // CHECK: nvvm.read.ptx.sreg.nctaid.y - // CHECK: 
nvvm.read.ptx.sreg.nctaid.z - %blockdimx = tt.get_num_programs {axis=0:i32} : i32 - %blockdimy = tt.get_num_programs {axis=1:i32} : i32 - %blockdimz = tt.get_num_programs {axis=2:i32} : i32 - %v0 = arith.addi %blockdimx, %blockdimy : i32 - %v1 = arith.addi %v0, %blockdimz : i32 - %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0> - tt.store %a, %0 : tensor<32xi32, #blocked0> - - return + func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { + // CHECK: nvvm.read.ptx.sreg.nctaid.x + // CHECK: nvvm.read.ptx.sreg.nctaid.y + // CHECK: nvvm.read.ptx.sreg.nctaid.z + %blockdimx = tt.get_num_programs {axis=0:i32} : i32 + %blockdimy = tt.get_num_programs {axis=1:i32} : i32 + %blockdimz = tt.get_num_programs {axis=2:i32} : i32 + %v0 = arith.addi %blockdimx, %blockdimy : i32 + %v1 = arith.addi %v0, %blockdimz : i32 + %0 = tt.splat %v1 : (i32) -> tensor<32xi32, #blocked0> + tt.store %a, %0 : tensor<32xi32, #blocked0> + + return + } } +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: test_index_cache + func @test_index_cache() { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + // CHECK-NOT: nvvm.read.ptx.sreg.tid.x + %1 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0> + return + } } + +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: test_base_index_cache + func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> + // CHECK-NOT: nvvm.read.ptx.sreg.tid.x + %1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> + return + } +} + +// ----- +#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}> +#shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +module attributes {"triton_gpu.num-warps" = 1 : i32} { + // CHECK-LABEL: test_index_cache_different_block + func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) { + // CHECK: nvvm.read.ptx.sreg.tid.x + %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> + scf.if %arg1 { + // CHECK-NOT: nvvm.read.ptx.sreg.tid.x + %1 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0> + } + return + } +} \ No newline at end of file From 194ba103b19898fe7b80732d3df6f19a30c5a9ec Mon Sep 17 00:00:00 2001 From: fdrocha <99990201+fdrocha@users.noreply.github.com> Date: Thu, 29 Dec 2022 23:10:34 +0000 Subject: [PATCH 6/7] [BUILD] Fixed error when compiling in systems with multiple versions of python installed (#1019) --- python/setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/setup.py b/python/setup.py index 3adcc6db8..4fda02659 100644 --- a/python/setup.py +++ b/python/setup.py @@ -141,10 +141,10 @@ class CMakeBuild(build_ext): "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, "-DTRITON_BUILD_TUTORIALS=OFF", 
"-DTRITON_BUILD_PYTHON_MODULE=ON", - # '-DPYTHON_EXECUTABLE=' + sys.executable, - '-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON', + "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, + "-DCMAKE_VERBOSE_MAKEFILE:BOOL=ON", "-DPYTHON_INCLUDE_DIRS=" + python_include_dir, - "-DLLVM_EXTERNAL_LIT=" + lit_dir + "-DLLVM_EXTERNAL_LIT=" + lit_dir, ] + thirdparty_cmake_args # configuration From 0e8590f1c9d3c3301187d2b450714f8ca8b4eeda Mon Sep 17 00:00:00 2001 From: goostavz <109190422+goostavz@users.noreply.github.com> Date: Sat, 31 Dec 2022 03:29:58 +0800 Subject: [PATCH 7/7] [BACKEND] Add generic support of convert_layout from distributed to shared (#1025) --- include/triton/Dialect/TritonGPU/IR/Dialect.h | 2 + .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 154 ++++++------------ .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.h | 10 +- .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 7 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 5 + 5 files changed, 68 insertions(+), 110 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 8c24a5777..3ccd2da32 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -39,6 +39,8 @@ SmallVector getShapePerCTA(const Attribute &layout); SmallVector getOrder(const Attribute &layout); +bool isaDistributedLayout(const Attribute &layout); + } // namespace gpu } // namespace triton } // namespace mlir diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index cd18ed751..5f4c803e8 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -9,10 +9,12 @@ using ::mlir::LLVM::getStridesFromShapeAndOrder; using ::mlir::LLVM::getStructFromElements; using ::mlir::LLVM::MMA16816ConversionHelper; using ::mlir::triton::gpu::DotOperandEncodingAttr; +using ::mlir::triton::gpu::getContigPerThread; using ::mlir::triton::gpu::getElemsPerThread; using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getSizePerThread; +using ::mlir::triton::gpu::isaDistributedLayout; using ::mlir::triton::gpu::SharedEncodingAttr; bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout, @@ -24,108 +26,63 @@ bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout, dotOperandLayout.getParent() == mmaLayout; } -void storeBlockedToShared(Value src, Value llSrc, ArrayRef srcStrides, - ArrayRef srcIndices, Value dst, Value smemBase, - Type elemTy, Location loc, - ConversionPatternRewriter &rewriter) { +void storeDistributedToShared(Value src, Value llSrc, + ArrayRef dstStrides, + ArrayRef> srcIndices, + Value dst, Value smemBase, Type elemTy, + Location loc, + ConversionPatternRewriter &rewriter) { auto srcTy = src.getType().cast(); auto srcShape = srcTy.getShape(); - assert(srcShape.size() == 2 && "Unexpected rank of insertSlice"); - + assert(srcShape.size() == 2 && "Unexpected rank of storeDistributedToShared"); auto dstTy = dst.getType().cast(); - auto srcBlockedLayout = srcTy.getEncoding().cast(); + auto srcDistributedLayout = srcTy.getEncoding(); + if (auto mmaLayout = srcDistributedLayout.dyn_cast()) { + assert((!mmaLayout.isVolta()) && + "ConvertLayout MMAv1->Shared is not suppported yet"); + } auto dstSharedLayout = dstTy.getEncoding().cast(); - auto inOrd = srcBlockedLayout.getOrder(); + auto inOrd = getOrder(srcDistributedLayout); auto outOrd = dstSharedLayout.getOrder(); - if (inOrd != outOrd) - 
llvm_unreachable( - "blocked -> shared with different order not yet implemented"); unsigned inVec = - inOrd == outOrd ? srcBlockedLayout.getSizePerThread()[inOrd[0]] : 1; + inOrd == outOrd ? getContigPerThread(srcDistributedLayout)[inOrd[0]] : 1; unsigned outVec = dstSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); unsigned perPhase = dstSharedLayout.getPerPhase(); unsigned maxPhase = dstSharedLayout.getMaxPhase(); unsigned numElems = getElemsPerThread(srcTy); + assert(numElems == srcIndices.size()); auto inVals = getElementsFromStruct(loc, llSrc, rewriter); - auto srcAccumSizeInThreads = - product(srcBlockedLayout.getSizePerThread()); auto wordTy = vec_ty(elemTy, minVec); auto elemPtrTy = ptr_ty(elemTy); - - SmallVector srcShapePerCTA = getShapePerCTA(srcBlockedLayout); - SmallVector reps{ceil(srcShape[0], srcShapePerCTA[0]), - ceil(srcShape[1], srcShapePerCTA[1])}; - - // Visit each input value in the order they are placed in inVals - // - // Please note that the order was not awaring of blockLayout.getOrder(), - // thus the adjacent elems may not belong to a same word. This could be - // improved if we update the elements order by emitIndicesForBlockedLayout() - SmallVector wordsInEachRep(2); - wordsInEachRep[0] = inOrd[0] == 0 - ? srcBlockedLayout.getSizePerThread()[0] / minVec - : srcBlockedLayout.getSizePerThread()[0]; - wordsInEachRep[1] = inOrd[0] == 0 - ? srcBlockedLayout.getSizePerThread()[1] - : srcBlockedLayout.getSizePerThread()[1] / minVec; Value outVecVal = i32_val(outVec); Value minVecVal = i32_val(minVec); - auto numWordsEachRep = product(wordsInEachRep); - SmallVector wordVecs(numWordsEachRep); + Value word; for (unsigned i = 0; i < numElems; ++i) { - if (i % srcAccumSizeInThreads == 0) { - // start of a replication - for (unsigned w = 0; w < numWordsEachRep; ++w) { - wordVecs[w] = undef(wordTy); - } - } - unsigned linearIdxInNanoTile = i % srcAccumSizeInThreads; - auto multiDimIdxInNanoTile = getMultiDimIndex( - linearIdxInNanoTile, srcBlockedLayout.getSizePerThread(), inOrd); - unsigned pos = multiDimIdxInNanoTile[inOrd[0]] % minVec; - multiDimIdxInNanoTile[inOrd[0]] /= minVec; - auto wordVecIdx = - getLinearIndex(multiDimIdxInNanoTile, wordsInEachRep, inOrd); - wordVecs[wordVecIdx] = - insert_element(wordTy, wordVecs[wordVecIdx], inVals[i], i32_val(pos)); + if (i % minVec == 0) + word = undef(wordTy); + word = insert_element(wordTy, word, inVals[i], i32_val(i % minVec)); + if (i % minVec == minVec - 1) { + // step 1: recover the multidim_index from the index of + SmallVector multiDimIdx = srcIndices[i]; + SmallVector dbgVal = srcIndices[i]; - if (i % srcAccumSizeInThreads == srcAccumSizeInThreads - 1) { - // end of replication, store the vectors into shared memory - unsigned linearRepIdx = i / srcAccumSizeInThreads; - auto multiDimRepIdx = - getMultiDimIndex(linearRepIdx, reps, inOrd); - for (unsigned linearWordIdx = 0; linearWordIdx < numWordsEachRep; - ++linearWordIdx) { - // step 1: recover the multidim_index from the index of - // input_elements - auto multiDimWordIdx = - getMultiDimIndex(linearWordIdx, wordsInEachRep, inOrd); - SmallVector multiDimIdx(2); - auto wordOffset0 = multiDimRepIdx[0] * srcShapePerCTA[0] + - multiDimWordIdx[0] * (inOrd[0] == 0 ? minVec : 1); - auto wordOffset1 = multiDimRepIdx[1] * srcShapePerCTA[1] + - multiDimWordIdx[1] * (inOrd[0] == 1 ? 
minVec : 1); - multiDimIdx[0] = add(srcIndices[0], i32_val(wordOffset0)); - multiDimIdx[1] = add(srcIndices[1], i32_val(wordOffset1)); + // step 2: do swizzling + Value remained = urem(multiDimIdx[outOrd[0]], outVecVal); + multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal); + Value off_1 = mul(multiDimIdx[outOrd[1]], dstStrides[outOrd[1]]); + Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase)); + phaseId = urem(phaseId, i32_val(maxPhase)); + Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId); + off_0 = mul(off_0, outVecVal); + remained = udiv(remained, minVecVal); + off_0 = add(off_0, mul(remained, minVecVal)); + Value offset = add(off_1, mul(off_0, dstStrides[outOrd[0]])); - // step 2: do swizzling - Value remained = urem(multiDimIdx[outOrd[0]], outVecVal); - multiDimIdx[outOrd[0]] = udiv(multiDimIdx[outOrd[0]], outVecVal); - Value off_1 = mul(multiDimIdx[outOrd[1]], srcStrides[outOrd[1]]); - Value phaseId = udiv(multiDimIdx[outOrd[1]], i32_val(perPhase)); - phaseId = urem(phaseId, i32_val(maxPhase)); - Value off_0 = xor_(multiDimIdx[outOrd[0]], phaseId); - off_0 = mul(off_0, outVecVal); - remained = udiv(remained, minVecVal); - off_0 = add(off_0, mul(remained, minVecVal)); - Value offset = add(off_1, off_0); - - // step 3: store - Value smemAddr = gep(elemPtrTy, smemBase, offset); - smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3)); - store(wordVecs[linearWordIdx], smemAddr); - } + // step 3: store + Value smemAddr = gep(elemPtrTy, smemBase, offset); + smemAddr = bitcast(smemAddr, ptr_ty(wordTy, 3)); + store(word, smemAddr); } } } @@ -145,20 +102,15 @@ public: auto dstTy = dst.getType().cast(); Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - if (srcLayout.isa() && + if (isaDistributedLayout(srcLayout) && dstLayout.isa()) { - return lowerBlockedToShared(op, adaptor, rewriter); + return lowerDistributedToShared(op, adaptor, rewriter); } if (srcLayout.isa() && dstLayout.isa()) { return lowerSharedToDotOperand(op, adaptor, rewriter); } - if ((srcLayout.isa() || - srcLayout.isa() || - srcLayout.isa()) && - (dstLayout.isa() || - dstLayout.isa() || - dstLayout.isa())) { + if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) { return lowerDistributedToDistributed(op, adaptor, rewriter); } if (srcLayout.isa() && @@ -476,8 +428,8 @@ private: // Swizzling in shared memory to avoid bank conflict. Normally used for // A/B operands of dots. 
LogicalResult - lowerBlockedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + lowerDistributedToShared(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { auto loc = op.getLoc(); Value src = op.src(); Value dst = op.result(); @@ -487,22 +439,20 @@ private: auto dstShape = dstTy.getShape(); assert(srcShape.size() == 2 && "Unexpected rank of ConvertLayout(blocked->shared)"); - auto srcBlockedLayout = srcTy.getEncoding().cast(); + auto srcLayout = srcTy.getEncoding(); auto dstSharedLayout = dstTy.getEncoding().cast(); - auto inOrd = srcBlockedLayout.getOrder(); + auto inOrd = getOrder(srcLayout); auto outOrd = dstSharedLayout.getOrder(); Value smemBase = getSharedMemoryBase(loc, rewriter, dst); auto elemTy = getTypeConverter()->convertType(srcTy.getElementType()); auto elemPtrTy = ptr_ty(getTypeConverter()->convertType(elemTy), 3); smemBase = bitcast(smemBase, elemPtrTy); - auto srcStrides = - getStridesFromShapeAndOrder(srcShape, inOrd, loc, rewriter); - auto srcIndices = - emitBaseIndexForLayout(loc, rewriter, srcBlockedLayout, srcShape); - storeBlockedToShared(src, adaptor.src(), srcStrides, srcIndices, dst, - smemBase, elemTy, loc, rewriter); - + auto dstStrides = + getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape); + storeDistributedToShared(src, adaptor.src(), dstStrides, srcIndices, dst, + smemBase, elemTy, loc, rewriter); auto smemObj = SharedMemoryObject(smemBase, dstShape, outOrd, loc, rewriter); auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h index ec435b2ab..d5b866a44 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h @@ -11,10 +11,12 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr; bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout, DotOperandEncodingAttr &dotOperandLayout); -void storeBlockedToShared(Value src, Value llSrc, ArrayRef srcStrides, - ArrayRef srcIndices, Value dst, Value smemBase, - Type elemPtrTy, Location loc, - ConversionPatternRewriter &rewriter); +void storeDistributedToShared(Value src, Value llSrc, + ArrayRef srcStrides, + ArrayRef> srcIndices, + Value dst, Value smemBase, Type elemPtrTy, + Location loc, + ConversionPatternRewriter &rewriter); void populateConvertLayoutOpToLLVMPatterns( mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index 92b11a94c..a06910277 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -639,10 +639,9 @@ struct InsertSliceOpConversion auto smemBase = gep(elemPtrTy, smemObj.base, offset); auto llSrc = adaptor.source(); - auto srcIndices = - emitBaseIndexForLayout(loc, rewriter, srcLayout, srcShape); - storeBlockedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase, - elemTy, loc, rewriter); + auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape); + storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase, + elemTy, loc, rewriter); // Barrier is not necessary. // The membar pass knows that it writes to shared memory and will handle it // properly. 
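As a reading aid for the storeDistributedToShared rewrite above: the sketch below is a plain-Python model of what the new per-element loop does, namely packing every minVec consecutive values into one word and computing the XOR-swizzled shared-memory offset for that word (the "step 2: do swizzling" block). It is illustrative only; the function names and the flat parameter list (out_ord, strides, out_vec, min_vec, per_phase, max_phase) are assumptions chosen to mirror the values used in the C++, not APIs from this repository.

```python
def swizzled_word_offset(idx, out_ord, strides, out_vec, min_vec,
                         per_phase, max_phase):
    """Model of the swizzled offset computed in storeDistributedToShared.

    idx     -- multi-dim index (e.g. [row, col]) of an element of the word
    out_ord -- shared-memory order; out_ord[0] is the contiguous dim
    strides -- shared-memory strides, indexed like `idx`
    """
    fast, slow = out_ord[0], out_ord[1]
    remained = idx[fast] % out_vec        # position inside an outVec group
    vec_idx = idx[fast] // out_vec        # which outVec-wide vector
    off_1 = idx[slow] * strides[slow]
    phase = (idx[slow] // per_phase) % max_phase
    off_0 = (vec_idx ^ phase) * out_vec   # XOR swizzle against bank conflicts
    off_0 += (remained // min_vec) * min_vec
    return off_1 + off_0 * strides[fast]


def store_words(in_vals, indices, out_ord, strides,
                out_vec, min_vec, per_phase, max_phase):
    """Pack every minVec consecutive values into one word and 'store' it at
    the swizzled offset, mirroring the `i % minVec == minVec - 1` branch."""
    shared = {}
    word = []
    for i, v in enumerate(in_vals):
        word.append(v)
        if i % min_vec == min_vec - 1:
            off = swizzled_word_offset(indices[i], out_ord, strides,
                                       out_vec, min_vec, per_phase, max_phase)
            shared[off] = word
            word = []
    return shared
```

With a layout like the #shared0 used in the lit tests earlier in this series (vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]), the XOR of the vector index with the phase staggers consecutive rows across different shared-memory banks, which is the purpose of the swizzle.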
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index d671f377d..a4573ca82 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -254,6 +254,11 @@ SmallVector getOrder(const Attribute &layout) { } }; +bool isaDistributedLayout(const Attribute &layout) { + return layout.isa() || layout.isa() || + layout.isa(); +} + } // namespace gpu } // namespace triton } // namespace mlir
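To close out the series, here is a rough Python model of how ConvertLayoutOpConversion dispatches after this last patch, with the new isaDistributedLayout predicate folding the three distributed layout kinds (blocked, mma, slice) into a single check. The string kinds and handler names are illustrative assumptions; the real pass matches on MLIR encoding attributes.

```python
# Illustrative model only: the handler names refer to the C++ methods in
# ConvertLayoutOpToLLVM.cpp; the real code works on attributes, not strings.
DISTRIBUTED_KINDS = {"blocked", "mma", "slice"}  # mirrors isaDistributedLayout

def is_distributed_layout(kind: str) -> bool:
    return kind in DISTRIBUTED_KINDS

def pick_lowering(src_kind: str, dst_kind: str) -> str:
    """Choose a lowering path the way ConvertLayoutOpConversion now does."""
    if is_distributed_layout(src_kind) and dst_kind == "shared":
        # Generic path added by this patch; previously only "blocked" qualified.
        return "lowerDistributedToShared"
    if src_kind == "shared" and dst_kind == "dot_operand":
        return "lowerSharedToDotOperand"
    if is_distributed_layout(src_kind) and is_distributed_layout(dst_kind):
        return "lowerDistributedToDistributed"
    raise NotImplementedError(f"convert_layout {src_kind} -> {dst_kind}")

# e.g. an MMA-layout tensor can now be written straight to shared memory
assert pick_lowering("mma", "shared") == "lowerDistributedToShared"
```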