From 192be76b3c90ab9dd6ae631c4869a11fdb02a92b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 18 Aug 2022 12:49:37 -0700 Subject: [PATCH] [OPTIMIZER] Rewrite patterns for layout conversions (#64) --- include/triton/Dialect/Triton/IR/TritonOps.td | 2 +- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 44 +- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 4 +- .../Dialect/TritonGPU/Transforms/Passes.h | 2 + .../Dialect/TritonGPU/Transforms/Passes.td | 13 + lib/Analysis/Allocation.cpp | 3 +- .../TritonGPUToLLVM/TritonGPUToLLVM.cpp | 67 ++- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 23 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 90 +++- .../TritonGPU/Transforms/CMakeLists.txt | 1 + .../Transforms/CanonicalizeLoops.cpp | 55 +++ lib/Dialect/TritonGPU/Transforms/Coalesce.cpp | 20 +- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 416 +++++++++++++++++- lib/Dialect/TritonGPU/Transforms/Combine.td | 18 - lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 18 +- .../Transforms/TritonGPUConversion.cpp | 7 +- lib/Dialect/TritonGPU/Transforms/Verifier.cpp | 6 +- python/triton/compiler.py | 14 +- test/TritonGPU/combine.mlir | 175 ++++++++ 19 files changed, 851 insertions(+), 127 deletions(-) create mode 100644 lib/Dialect/TritonGPU/Transforms/CanonicalizeLoops.cpp create mode 100644 test/TritonGPU/combine.mlir diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 99e593855..e0656492a 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -225,7 +225,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect, // let hasCanonicalizer = 1; } -def TT_ReduceOp : TT_Op<"reduce"> { +def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect]> { let summary = "reduce"; let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index f5f5d5904..2277b509e 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -37,7 +37,7 @@ Right now, Triton implements two classes of layouts: shared, and distributed. // Shared Layout Encoding //===----------------------------------------------------------------------===// -def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> { +def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding"> { let mnemonic = "shared"; let description = [{ @@ -70,9 +70,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] / // Distributed Layout Encoding //===----------------------------------------------------------------------===// -class TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> { - let mnemonic = "distributed"; - +class DistributedEncoding : TritonGPU_Attr { let description = [{ Distributed encodings have a layout function that is entirely characterized by a d-dimensional tensor L. 
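Each entry of L is the set of hardware resources (e.g. CUDA thread ids) that own the tensor element at that index, as illustrated by the example below.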
Note that L doesn't need to have the same shape @@ -97,12 +95,11 @@ L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, }]; } - //===----------------------------------------------------------------------===// // Blocked Layout Encoding //===----------------------------------------------------------------------===// -def TritonGPUBlockedEncodingAttr : TritonGPU_Attr<"TritonGPUBlockedEncoding"> { +def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding"> { let mnemonic = "blocked"; let description = [{ @@ -174,6 +171,10 @@ for }]> ]; + let extraClassDeclaration = [{ + SliceEncodingAttr squeeze(int axis); + }]; + let parameters = ( ins @@ -197,7 +198,7 @@ for //===----------------------------------------------------------------------===// // TODO: MMAv1 and MMAv2 should be two instances of the same class -def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> { +def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> { let mnemonic = "mma"; let description = [{ @@ -283,5 +284,34 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: ); } +def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> { + let mnemonic = "slice"; + + let description = [{ + TODO: improve docs + + A = [x x x x x x x x] + [x x x x x x x x] + L_parent = [0 1 2 3 ] + [4 5 6 7 ] + [8 9 10 11] + [12 13 14 15] + dim = 0 + + Then the data of A would be distributed as follow between the 16 CUDA threads: + L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15} ] + + This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes. + + }]; + + let parameters = ( + ins + "unsigned":$dim, + // TODO: constraint here to only take distributed encodings + "Attribute":$parent + ); +} + #endif diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 90e61dcb5..57187cb2f 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -20,7 +20,7 @@ class TTG_Op traits = []> : Op; def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", - [NoSideEffect, SameOperandsAndResultType]> { + [NoSideEffect]> { let summary = "convert layout"; let arguments = (ins TT_Tensor:$src); @@ -65,7 +65,7 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async", // This is needed because Arith's Cmp ops don't // handle encodings // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111 -def TTG_CmpIOp : TTG_Op<"cmpi"> { +def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> { let summary = "integer comparison operation"; let description = [{}]; diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.h b/include/triton/Dialect/TritonGPU/Transforms/Passes.h index 9ed457431..0f3f055c2 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.h @@ -6,6 +6,8 @@ namespace mlir { std::unique_ptr createTritonGPUPipelinePass(int numStages = 2); +std::unique_ptr createTritonGPUCanonicalizeLoopsPass(); + std::unique_ptr createTritonGPUCoalescePass(); std::unique_ptr createTritonGPUCombineOpsPass(); diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index d28c32a98..07cda2ee5 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -60,6 +60,19 @@ def TritonGPUCombineOps : 
Pass<"tritongpu-combine", "mlir::ModuleOp"> { "mlir::triton::TritonDialect"]; } +def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> { + let summary = "canonicalize scf.ForOp ops"; + + let description = [{ + This implements some optimizations that are missing in the standard scf.ForOp + canonicalizer. + }]; + + let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"]; +} + def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> { let summary = "verify TritonGPU IR"; diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 5a5b09832..efe13f054 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -43,8 +43,7 @@ private: auto type = result.getType(); if (auto tensorType = type.dyn_cast()) { auto encoding = tensorType.getEncoding(); - if (encoding && - encoding.isa()) { + if (encoding && encoding.isa()) { // Bytes could be a different value once we support padding or other // allocation policies. auto bytes = tensorType.getNumElements() * diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index 3c8ce991e..89bff6851 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -23,9 +23,9 @@ using namespace mlir; using namespace mlir::triton; -using ::mlir::triton::gpu::TritonGPUBlockedEncodingAttr; -using ::mlir::triton::gpu::TritonGPUMmaEncodingAttr; -using ::mlir::triton::gpu::TritonGPUSharedEncodingAttr; +using ::mlir::triton::gpu::BlockedEncodingAttr; +using ::mlir::triton::gpu::MmaEncodingAttr; +using ::mlir::triton::gpu::SharedEncodingAttr; namespace mlir { namespace LLVM { @@ -226,7 +226,7 @@ static int64_t getLinearIndex(std::vector multidim_index, return linear_index; } -static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout, +static unsigned getElemsPerThread(BlockedEncodingAttr layout, ArrayRef shape) { size_t rank = shape.size(); SmallVector elemsPerThreadPerDim(rank); @@ -368,10 +368,10 @@ public: // Emit indices calculation within each ConversionPattern // TODO: [goostavz] Double confirm the redundant indices calculations will // be eliminated in the consequent MLIR/LLVM optimization - SmallVector> emitIndicesForBlockedLayout( - Location loc, ConversionPatternRewriter &b, - const TritonGPUBlockedEncodingAttr &blocked_layout, - ArrayRef shape) const { + SmallVector> + emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b, + const BlockedEncodingAttr &blocked_layout, + ArrayRef shape) const { auto llvmIndexTy = this->getTypeConverter()->getIndexType(); auto cast = b.create( loc, TypeRange{llvmIndexTy}, @@ -483,7 +483,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal, ConversionPatternRewriter &rewriter, Location loc) { auto tensorTy = resType.cast(); - auto layout = tensorTy.getEncoding().cast(); + auto layout = tensorTy.getEncoding().cast(); auto srcType = typeConverter->convertType(elemType); auto llSrc = rewriter.create(loc, srcType, constVal); @@ -594,9 +594,9 @@ struct StoreOpConversion MLIRContext *ctx = rewriter.getContext(); auto loc = op->getLoc(); - auto getLLVMElems = [&](Value value, Value llValue, - const TritonGPUBlockedEncodingAttr &layout) - -> SmallVector { + auto getLLVMElems = + [&](Value value, Value llValue, + const BlockedEncodingAttr &layout) -> SmallVector { auto ty = value.getType().cast(); auto shape = 
ty.getShape(); // Here, we assume that all inputs should have a blockedLayout @@ -613,11 +613,11 @@ struct StoreOpConversion }; auto getLayout = - [&](Value val) -> std::tuple { + [&](Value val) -> std::tuple { auto ty = val.getType().cast(); auto shape = ty.getShape(); // Here, we assume that all inputs should have a blockedLayout - auto layout = ty.getEncoding().dyn_cast(); + auto layout = ty.getEncoding().dyn_cast(); unsigned valueElems = getElemsPerThread(layout, shape); @@ -633,9 +633,8 @@ struct StoreOpConversion auto maskElems = getLLVMElems(mask, llMask, maskLayout); assert(valueElems.size() == maskElems.size()); - auto getAlign = - [this](Value val, - const TritonGPUBlockedEncodingAttr &layout) -> unsigned { + auto getAlign = [this](Value val, + const BlockedEncodingAttr &layout) -> unsigned { auto axisInfo = getAxisInfo(val); assert(axisInfo.hasValue()); @@ -648,9 +647,9 @@ struct StoreOpConversion }; // get align - auto getVec = [this, &getAlign]( - Value val, - const TritonGPUBlockedEncodingAttr &layout) -> unsigned { + auto getVec = [this, + &getAlign](Value val, + const BlockedEncodingAttr &layout) -> unsigned { auto axisInfo = getAxisInfo(val); auto contig = axisInfo->getContiguity(); // Here order should be ordered by contiguous first, so the first element @@ -820,10 +819,8 @@ struct BroadcastOpConversion Value result = op.result(); auto srcTy = op.src().getType().cast(); auto resultTy = result.getType().cast(); - auto srcLayout = - srcTy.getEncoding().dyn_cast(); - auto resultLayout = - resultTy.getEncoding().dyn_cast(); + auto srcLayout = srcTy.getEncoding().dyn_cast(); + auto resultLayout = resultTy.getEncoding().dyn_cast(); assert(srcLayout && (srcLayout == resultLayout) && "Unexpected layout of BroadcastOp"); auto srcShape = srcTy.getShape(); @@ -894,8 +891,7 @@ struct ViewOpConversion // due to MLIR's restrictions Location loc = op->getLoc(); auto resultTy = op.getType().cast(); - auto resultLayout = - resultTy.getEncoding().dyn_cast(); + auto resultLayout = resultTy.getEncoding().dyn_cast(); auto resultShape = resultTy.getShape(); unsigned elems = getElemsPerThread(resultLayout, resultShape); Type elemTy = @@ -921,7 +917,7 @@ struct MakeRangeOpConversion auto rankedTy = op.result().getType().dyn_cast(); auto shape = rankedTy.getShape(); auto blocked_layout = - rankedTy.getEncoding().dyn_cast(); + rankedTy.getEncoding().dyn_cast(); auto elemTy = rankedTy.getElementType(); assert(elemTy.isInteger(32)); Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start()); @@ -955,8 +951,7 @@ struct LoadOpConversion Value mask = adaptor.mask(); Value other = adaptor.other(); auto resultTy = op.result().getType().cast(); - auto blockedLayout = - resultTy.getEncoding().dyn_cast(); + auto blockedLayout = resultTy.getEncoding().dyn_cast(); auto shape = resultTy.getShape(); // TODO: Handle AxisInfo @@ -1166,8 +1161,7 @@ struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); auto resultTy = op.getType().dyn_cast(); - auto resultLayout = - resultTy.getEncoding().dyn_cast(); + auto resultLayout = resultTy.getEncoding().dyn_cast(); auto resultShape = resultTy.getShape(); unsigned elems = getElemsPerThread(resultLayout, resultShape); Type elemTy = @@ -1206,8 +1200,8 @@ public: return failure(); Location loc = op->getLoc(); - auto resultLayout = resultTy.getEncoding() - .template dyn_cast(); + auto resultLayout = + resultTy.getEncoding().template dyn_cast(); auto resultShape = 
resultTy.getShape(); unsigned elems = getElemsPerThread(resultLayout, resultShape); Type elemTy = @@ -1250,17 +1244,16 @@ public: llvm::Optional convertTritonTensorType(RankedTensorType type) { Attribute layout = type.getEncoding(); - if (auto blocked_layout = layout.dyn_cast()) { + if (auto blocked_layout = layout.dyn_cast()) { unsigned numElementsPerThread = getElemsPerThread(blocked_layout, type.getShape()); SmallVector types(numElementsPerThread, convertType(type.getElementType())); return LLVM::LLVMStructType::getLiteral(&getContext(), types); - } else if (auto mma_layout = layout.dyn_cast()) { + } else if (auto mma_layout = layout.dyn_cast()) { // TODO: Not implemented return llvm::None; - } else if (auto shared_layout = - layout.dyn_cast()) { + } else if (auto shared_layout = layout.dyn_cast()) { // TODO: Not implemented return llvm::None; } diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 63b37198c..128d9521a 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -156,8 +156,7 @@ struct TritonExpandDimsPattern Attribute _argEncoding = argType.getEncoding(); if (!_argEncoding) return failure(); - auto argEncoding = - _argEncoding.cast(); + auto argEncoding = _argEncoding.cast(); // return shape auto retShape = argType.getShape().vec(); retShape.insert(retShape.begin() + op.axis(), 1); @@ -170,10 +169,10 @@ struct TritonExpandDimsPattern retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1); SmallVector retOrder(retShape.size()); std::iota(retOrder.begin(), retOrder.end(), 0); - triton::gpu::TritonGPUBlockedEncodingAttr retEncoding = - triton::gpu::TritonGPUBlockedEncodingAttr::get( - getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA, - retOrder); + triton::gpu::BlockedEncodingAttr retEncoding = + triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, + retThreadsPerWarp, retWarpsPerCTA, + retOrder); // return type RankedTensorType retType = RankedTensorType::get(retShape, argType.getElementType(), retEncoding); @@ -201,16 +200,16 @@ struct TritonDotPattern : public OpConversionPattern { Value a = adaptor.a(); Value b = adaptor.b(); SmallVector order{1, 0}; - if (!aEncoding.isa()) { - Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get( - getContext(), 1, 1, 1, order); + if (!aEncoding.isa()) { + Attribute encoding = + triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order); auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding); a = rewriter.create(a.getLoc(), dstType, a); } - if (!bEncoding.isa()) { - Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get( - getContext(), 1, 1, 1, order); + if (!bEncoding.isa()) { + Attribute encoding = + triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order); auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding); b = rewriter.create(b.getLoc(), dstType, b); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 16b74587c..d6275dc89 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1,6 +1,8 @@ -#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include + #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" #include 
"triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" @@ -70,11 +72,15 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" +SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) { + return SliceEncodingAttr::get(getContext(), axis, *this); +} + //===----------------------------------------------------------------------===// // Blocked Encoding //===----------------------------------------------------------------------===// -Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { +Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; // Parse the data as a dictionary @@ -115,11 +121,11 @@ Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { } } - return parser.getChecked( + return parser.getChecked( parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order); } -void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { +void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { printer << "<{" << "sizePerThread = [" << getSizePerThread() << "]" << ", threadsPerWarp = [" << getThreadsPerWarp() << "]" @@ -132,7 +138,7 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { // MMA encoding //===----------------------------------------------------------------------===// -Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { +Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; DictionaryAttr dict; @@ -155,22 +161,59 @@ Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { } } - return parser.getChecked(parser.getContext(), - version, warpsPerCTA); + return parser.getChecked(parser.getContext(), version, + warpsPerCTA); } -void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const { +void MmaEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" << "version = " << getVersion() << ", " << "warpsPerCTA = [" << getWarpsPerCTA() << "]" << "}>"; } +//===----------------------------------------------------------------------===// +// Sliced Encoding +//===----------------------------------------------------------------------===// + +Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + // Parse the data as a dictionary + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + if (parser.parseGreater().failed()) + return {}; + + unsigned dim = 0; + Attribute parent; + + for (const NamedAttribute &attr : dict) { + if (attr.getName() == "dim") { + if (parseUInt(parser, attr, dim, "dim").failed()) + return {}; + } + if (attr.getName() == "parent") { + if (parser.parseAttribute(parent).failed()) + return {}; + } + } + + return parser.getChecked(parser.getContext(), dim, parent); +} + +void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { + printer << "<{" + << "dim = " << getDim() << ", " + << "parent = " << getParent() << "}>"; +} + //===----------------------------------------------------------------------===// // Shared encoding //===----------------------------------------------------------------------===// -Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { +Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; // Parse the data as a 
dictionary @@ -205,11 +248,11 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) { } } - return parser.getChecked( - parser.getContext(), vec, perPhase, maxPhase, order); + return parser.getChecked(parser.getContext(), vec, + perPhase, maxPhase, order); } -void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const { +void SharedEncodingAttr::print(AsmPrinter &printer) const { printer << "<{" << "vec = " << getVec() << ", perPhase = " << getPerPhase() << ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder() @@ -226,18 +269,21 @@ public: using OpAsmDialectInterface::OpAsmDialectInterface; AliasResult getAlias(Attribute attr, raw_ostream &os) const override { - if (auto mmaAttr = attr.dyn_cast()) { + if (auto mmaAttr = attr.dyn_cast()) { os << "mma"; return AliasResult::FinalAlias; - } else if (auto sharedAttr = attr.dyn_cast()) { + } else if (auto sharedAttr = attr.dyn_cast()) { os << "shared"; return AliasResult::FinalAlias; - } else if (auto blockedAttr = - attr.dyn_cast()) { + } else if (auto blockedAttr = attr.dyn_cast()) { os << "blocked"; return AliasResult::FinalAlias; - } - return OpAsmDialectInterface::getAlias(attr, os); + } /* else if (auto sliceAttr = attr.dyn_cast()) { + os << "slice"; + return AliasResult::FinalAlias; + } */ + OpAsmDialectInterface::getAlias(attr, os); + return AliasResult::FinalAlias; } }; @@ -283,11 +329,15 @@ static Type getPointeeType(Type type) { } // namespace triton } // namespace mlir +//===----------------------------------------------------------------------===// +// Verification +//===----------------------------------------------------------------------===// + static LogicalResult verify(CopyAsyncOp op) { Type resType = op.getResult().getType(); if (auto tensorType = resType.dyn_cast()) { Attribute encoding = tensorType.getEncoding(); - if (!encoding.isa()) + if (!encoding.isa()) return op.emitOpError("copy_async should return a shared memory tensor"); } else return op.emitOpError("copy_async should return a tensor"); @@ -302,4 +352,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { // TODO: fill this. return success(); -} +} \ No newline at end of file diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 9fa32b806..a5c0898d3 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_public_tablegen_target(TritonGPUCombineIncGen) add_mlir_dialect_library(TritonGPUTransforms Coalesce.cpp + CanonicalizeLoops.cpp Combine.cpp Pipeline.cpp Verifier.cpp diff --git a/lib/Dialect/TritonGPU/Transforms/CanonicalizeLoops.cpp b/lib/Dialect/TritonGPU/Transforms/CanonicalizeLoops.cpp new file mode 100644 index 000000000..6126f6ce6 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/CanonicalizeLoops.cpp @@ -0,0 +1,55 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +using namespace mlir; +using namespace mlir::triton; + +#define GEN_PASS_CLASSES +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +namespace { + +struct CanonicalizePass + : public TritonGPUCanonicalizeLoopsBase { + CanonicalizePass() = default; + + void runOnOperation() override { + + // Canonicalize pass may have created dead code that + // standard scf.for canonicalization cannot handle + // as of LLVM 14. 
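+    // (Typically loop-carried values whose only use is computing their own
+    //  value for the next iteration: they still have uses, so they never look
+    //  dead to the upstream folder.)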
For example, the iteration arguments + // for the pointer of the synchronous loads that are + // discarded. + // The following piece of code is a workaround to + // very crudely remove dead code, by making an iteration + // argument yield itself if it is not used to create + // side-effects anywhere. + getOperation()->walk([&](scf::ForOp forOp) -> void { + for (size_t i = 0; i < forOp.getNumResults(); ++i) { + // condition 1: no other iter arguments depend on it + SetVector fwdSlice; + mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice); + Operation *yieldOp = forOp.getBody()->getTerminator(); + bool noOtherDependency = std::all_of( + yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) { + return arg == yieldOp->getOperand(i) || + !fwdSlice.contains(arg.getDefiningOp()); + }); + // condition 2: final value is not used after the loop + auto retVal = forOp.getResult(i); + bool noUserAfterLoop = retVal.getUsers().empty(); + // yielding the region iter arg will cause loop canonicalization + // to clean up the dead code + if (noOtherDependency && noUserAfterLoop) { + yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]); + } + } + }); + } +}; +} // anonymous namespace + +std::unique_ptr mlir::createTritonGPUCanonicalizeLoopsPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index f24f6b2b3..e2dc9a09d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -32,8 +32,10 @@ struct CoalescePass : public TritonGPUCoalesceBase { unsigned alignment = std::min(maxMultiple, maxContig); unsigned perThread = std::min(alignment, 128 / numBits); sizePerThread[order[0]] = perThread; + SmallVector dims(rank); + std::iota(dims.begin(), dims.end(), 0); // create encoding - Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get( + Attribute encoding = triton::gpu::BlockedEncodingAttr::get( &getContext(), origType.getShape(), sizePerThread, order, this->numWarps); return encoding; @@ -64,15 +66,20 @@ struct CoalescePass : public TritonGPUCoalesceBase { op->getLoc(), convertType(v.getType()), v)); // convert output types SmallVector newTypes; - for (auto t : op->getResultTypes()) - newTypes.push_back(convertType(t)); + for (auto t : op->getResultTypes()) { + bool is_async = std::is_same::value; + newTypes.push_back(is_async ? 
t : convertType(t)); + } // construct new op with the new encoding Operation *newOp = builder.create(op->getLoc(), newTypes, newArgs, op->getAttrs()); // cast the results back to the original layout for (size_t i = 0; i < op->getNumResults(); i++) { - auto newResult = builder.create( - op->getLoc(), op->getResult(i).getType(), newOp->getResult(i)); + Value newResult = newOp->getResult(i); + if (newTypes[i] != op->getResultTypes()[i]) { + newResult = builder.create( + op->getLoc(), op->getResult(i).getType(), newResult); + } op->getResult(i).replaceAllUsesWith(newResult); } op->erase(); @@ -97,6 +104,9 @@ struct CoalescePass : public TritonGPUCoalesceBase { builder.setInsertionPoint(curr); if (auto load = dyn_cast(curr)) coalesceOp(axisInfo, curr, load.ptr(), builder); + if (auto load = dyn_cast(curr)) + coalesceOp(axisInfo, curr, load.ptr(), + builder); if (auto store = dyn_cast(curr)) coalesceOp(axisInfo, curr, store.ptr(), builder); }); diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 92b9127a3..13061e65e 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -1,10 +1,13 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" - #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -15,15 +18,414 @@ using namespace mlir; static bool isSharedLayout(Value v) { if (auto tensorType = v.getType().dyn_cast()) { Attribute encoding = tensorType.getEncoding(); - return encoding.isa(); + return encoding.isa(); } return false; } namespace { #include "TritonGPUCombine.inc" + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +// Layout conversions can't deduce their return type automatically. 
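+// (The destination encoding is chosen by whoever builds the op, so it cannot
+//  be inferred from the source operand alone.)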
+// IIUC they are therefore not handled by DRR right now +class SimplifyConversion : public mlir::RewritePattern { +public: + SimplifyConversion(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 2, context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + if (!llvm::isa(op)) + return mlir::failure(); + // convert to the same layout -- we can delete + if (op->getResultTypes() == op->getOperandTypes()) { + rewriter.replaceOp(op, op->getOperands()); + return mlir::success(); + } + Operation *arg = op->getOperand(0).getDefiningOp(); + // block argument + if (!arg) + return mlir::failure(); + // cvt(type2, cvt(type1, x)) -> cvt(type2, x) + if (llvm::isa(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), arg->getOperand(0)); + return mlir::success(); + } + // cvt(type1, splat(type2, x)) -> splat(type1, x) + if (auto splat = llvm::dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + splat.src()); + return mlir::success(); + } + // cvt(type1, make_range(type2, x)) -> make_range(type1, x) + if (auto range = llvm::dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), range.start(), range.end()); + return mlir::success(); + } + // cvt(type, constant) -> constant + if (auto cst = llvm::dyn_cast(arg)) + if (auto ret = cst.getValue().dyn_cast()) { + auto newRet = SplatElementsAttr::get(op->getResultTypes().front(), + ret.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newRet); + return mlir::success(); + } + return mlir::failure(); + } +}; + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +// Layout conversions are expensive. They require going through +// shared memory, which is orders of magnitude slower than +// other non-i/o operations in the dialect. +// It therefore makes sense to remove them whenever possible, +// even if it means rematerializing all values whose definitions +// are reachable from it without passing through any memory operation. +class PullConversionToSource : public mlir::RewritePattern { +public: + PullConversionToSource(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 3, context) {} + + void getReachableNotThroughMemOp( + ArrayRef operands, + SmallVectorImpl &postOrderRet) const { + struct State { + Value value; + unsigned operandIndex; + }; + SmallVector worklist; + for (auto operand : operands) + worklist.push_back({operand, 0}); + + while (!worklist.empty()) { + State &state = worklist.back(); + auto *opInst = state.value.getDefiningOp(); + // Note: getDefiningOp will return nullptr if the operand is not an + // Operation (i.e., block arguments) which is a terminator for the search. + if (opInst == nullptr) { + worklist.pop_back(); + continue; + } + // if we encounter a memory operation, then + // we can assume it's not worth doing any + // rematerialization: layout conversion + // will be cheaper + if (isa( + opInst)) + return; + // we don't want to rematerialize conversions + if (isa(opInst)) + return; + // visit operands + if (state.operandIndex < opInst->getNumOperands()) { + auto nextOperand = opInst->getOperand(state.operandIndex); + ++state.operandIndex; + worklist.push_back({nextOperand, 0}); + } else { + // Post-visit: done visiting operand, pop off stack. 
+ // and add to post-order result + worklist.pop_back(); + postOrderRet.push_back(opInst); + } + } + } + + Attribute invertEncoding(Type targetType, Operation *op) const { + RankedTensorType targetTensorType = targetType.cast(); + if (auto expand_dims = dyn_cast(op)) { + return targetTensorType.getEncoding() + .cast() + .squeeze(expand_dims.axis()); + } + return targetTensorType.getEncoding(); + } + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *cvt, + mlir::PatternRewriter &rewriter) const override { + if (!llvm::isa(cvt)) + return mlir::failure(); + // constants/splat are handled separately + Operation *op = cvt->getOperand(0).getDefiningOp(); + if (!op) + return mlir::failure(); + if (isa(op)) + return mlir::failure(); + // DFS through all operands + // auto filter = [](Operation *op) { + // return !isa(op); + // }; + + SmallVector postOrderOps; + getReachableNotThroughMemOp({cvt->getOperand(0)}, postOrderOps); + if (postOrderOps.empty()) + return mlir::failure(); + + // We convert cvt(op(arg_0, arg_1, ..., arg_n)) + // into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n)) + BlockAndValueMapping mapping; + for (Value argI : op->getOperands()) { + // Compute new argument types + auto oldArgType = argI.getType().dyn_cast(); + if (!oldArgType) + continue; + auto newEncoding = invertEncoding(cvt->getResultTypes()[0], op); + auto newArgType = RankedTensorType::get( + oldArgType.getShape(), oldArgType.getElementType(), newEncoding); + // Create new argument + auto cvtI = rewriter.create( + op->getLoc(), newArgType, argI); + cvtI->moveBefore(op); + mapping.map(argI, cvtI); + } + Operation *newOp = rewriter.clone(*op, mapping); + newOp->getResult(0).setType(cvt->getResult(0).getType()); + rewriter.replaceOp(cvt, newOp->getResults()); + + return mlir::success(); + } +}; + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +// This modifies the loop in-place +bool tryLegalizeOp(Operation *op, DenseSet toPreserve, + mlir::PatternRewriter &rewriter) { + auto targetType = toPreserve.begin()->getType().cast(); + auto newType = [&](RankedTensorType origType) { + return RankedTensorType::get(origType.getShape(), origType.getElementType(), + targetType.getEncoding()); + }; + bool hasSameTypes = op->getDialect()->getNamespace() == "arith" || + isa(op); + if (hasSameTypes) { + // replace argument types + for (auto arg : llvm::enumerate(op->getOperands())) { + auto argType = arg.value().getType().dyn_cast(); + if (toPreserve.count(arg.value()) || !argType) + continue; + auto newArg = rewriter.create( + rewriter.getUnknownLoc(), newType(argType), arg.value()); + newArg->moveBefore(op); + op->setOperand(arg.index(), newArg); + } + // replace result types + if (!isa(op)) + op->getResult(0).setType(op->getOperand(0).getType()); + return true; + } + + // i + return false; } +std::pair, scf::ForOp> +tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i, + Type newType) { + auto newEncoding = newType.cast().getEncoding(); + auto ctx = forOp.getContext(); + auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; + // Rewrite init argument + Type origType = forOp.getInitArgs()[i].getType(); + SmallVector newInitArgs = forOp.getInitArgs(); + newInitArgs[i] = rewriter.create( + newInitArgs[i].getLoc(), newType, newInitArgs[i]); + // Clone for loop + scf::ForOp newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), 
forOp.getUpperBound(), + forOp.getStep(), newInitArgs); + newForOp->moveBefore(forOp); + rewriter.setInsertionPointToStart(newForOp.getBody()); + BlockAndValueMapping mapping; + for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + // traverse all ops in the loop + for (Operation &op : forOp.getBody()->without_terminator()) { + // we clone the op + Operation *newOp = rewriter.clone(op, mapping); + // if any argument of this op has changed type, then the + // new operation is not legal and we should try to + // legalize it. + DenseSet modifiedTypes; + for (Value arg : op.getOperands()) { + if (mapping.contains(arg) && + mapping.lookup(arg).getType() != arg.getType()) + modifiedTypes.insert(mapping.lookup(arg)); + } + + bool shouldTryLegalize = !modifiedTypes.empty(); + if (shouldTryLegalize) + tryLegalizeOp(newOp, modifiedTypes, rewriter); + } + // create yield, inserting conversions if necessary + auto yieldOp = forOp.getBody()->getTerminator(); + SmallVector newYieldArgs; + for (Value arg : yieldOp->getOperands()) + newYieldArgs.push_back(mapping.lookup(arg)); + newYieldArgs[i] = rewriter.create( + yieldOp->getLoc(), newType, newYieldArgs[i]); + rewriter.create(forOp.getLoc(), newYieldArgs); + + // replace + SmallVector newResults = newForOp->getResults(); + newResults[i] = rewriter.create( + rewriter.getUnknownLoc(), origType, newForOp->getResult(i)); + newResults[i].getDefiningOp()->moveAfter(newForOp); + return {newResults, newForOp}; +} + +class MoveArgConvertOutOfLoop : public mlir::RewritePattern { +public: + MoveArgConvertOutOfLoop(mlir::MLIRContext *context) + : mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {} + + mlir::LogicalResult matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const { + + auto forOp = cast(op); + auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; + auto iterArgs = forOp.getRegionIterArgs(); + for (auto iterArg : llvm::enumerate(iterArgs)) { + for (auto op : iterArg.value().getUsers()) { + auto currOps = mlir::getSlice(op, isInLoop); + auto pred = [&](Operation *op) { + return isa(op); + }; + auto isCvt = [&](Operation *op) { + return isa(op); + }; + auto isYield = [&](Operation *op) { return isa(op); }; + auto opIt = std::find(currOps.begin(), currOps.end(), op); + auto yieldIt = std::find_if(currOps.begin(), currOps.end(), isYield); + auto fwdEndIt = std::find_if(opIt, currOps.end(), pred); + auto bwdBeginIt = std::find_if(currOps.begin(), opIt, pred); + auto fwdCvtIt = std::find_if(opIt, fwdEndIt, isCvt); + auto bwdCvtIt = std::find_if(bwdBeginIt, opIt, isCvt); + + if (fwdCvtIt != fwdEndIt) { + auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(), + (*fwdCvtIt)->getResult(0).getType()); + rewriter.replaceOp(forOp, newFor.first); + return success(); + } + } + } + return failure(); + } +}; + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +class PushConversionToSink : public mlir::RewritePattern { +public: + PushConversionToSink(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 2, context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *_cvtOp, + mlir::PatternRewriter &rewriter) const override { + auto cvt = cast(_cvtOp); + auto forOp = dyn_cast(cvt->getParentOp()); + if (!forOp) + return mlir::failure(); + 
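+    // The conversion sits directly inside a loop body. The idea is to keep
+    // pushing it past its first user, towards the yield, at which point it can
+    // be moved out of the loop altogether (cf. MoveArgConvertOutOfLoop above).
+    // Sketch, with layouts abbreviated and value names purely illustrative:
+    //   %y = convert_layout %x : #A -> #B
+    //   %z = arith.addf %y, %w : #B
+    // becomes
+    //   %w2 = convert_layout %w : #B -> #A
+    //   %z0 = arith.addf %x, %w2 : #A
+    //   %z  = convert_layout %z0 : #A -> #B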
auto yieldOp = cast(forOp.getBody()->getTerminator()); + auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; }; + + SetVector cvtSlices; + auto filter = [&](Operation *op) { + return isInLoop(op) && !isa(op) && + !isa(op) && !isa(op) && + !isa(op); + }; + mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter); + if (cvtSlices.empty()) + return failure(); + // if other operands are in the loop + // then we don't touch anything + Operation *op = cvtSlices.front(); + for (Value _arg : op->getOperands()) { + Operation *arg = _arg.getDefiningOp(); + if (arg && isInLoop(arg) && (arg != cvt)) + return failure(); + } + // otherwise, we push the conversion forward + // since we'll be able to move it out of + // the loop once it reaches the yield op + // op(cvt(arg_0), arg_1, ..., arg_n) + // -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n))) + BlockAndValueMapping mapping; + for (Value arg : op->getOperands()) { + if (arg.getDefiningOp() == cvt) + mapping.map(arg, cvt.getOperand()); + else { + auto cvtI = rewriter.create( + arg.getLoc(), cvt.getOperand().getType(), arg); + mapping.map(arg, cvtI); + } + } + Operation *newOp = rewriter.clone(*op, mapping); + newOp->getResult(0).setType(cvt.getOperand().getType()); + auto newCvt = rewriter.create( + newOp->getLoc(), cvt.getResult().getType(), newOp->getResult(0)); + rewriter.replaceOp(op, newCvt->getResults()); + return success(); + } +}; + +// ----------------------------------------------------------------------------- +// +// ----------------------------------------------------------------------------- + +class BlockedToMMA : public mlir::RewritePattern { +public: + BlockedToMMA(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto dotOp = cast(op); + // TODO: Check data-types and SM compatibility + auto oldRetType = dotOp.getResult().getType().cast(); + if (oldRetType.getEncoding().isa()) + return failure(); + // TODO: compute warpsPerCTA + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), + triton::gpu::MmaEncodingAttr::get(oldRetType.getContext(), 2, {2, 2})); + auto oldAcc = dotOp.getOperand(2); + auto newAcc = rewriter.create( + oldAcc.getLoc(), newRetType, oldAcc); + auto newDot = rewriter.create( + dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1), + newAcc, dotOp.allowTF32()); + + rewriter.replaceOpWithNewOp( + op, oldRetType, newDot.getResult()); + return success(); + } +}; + +} // namespace + #define GEN_PASS_CLASSES #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" @@ -36,9 +438,11 @@ public: mlir::RewritePatternSet patterns(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); @@ -47,4 +451,4 @@ public: std::unique_ptr mlir::createTritonGPUCombineOpsPass() { return std::make_unique(); -} +} \ No newline at end of file diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.td b/lib/Dialect/TritonGPU/Transforms/Combine.td index 66349ba76..6bf1b1486 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.td +++ b/lib/Dialect/TritonGPU/Transforms/Combine.td @@ -4,22 +4,4 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td" include 
"triton/Dialect/Triton/IR/TritonOps.td" -// convert_layout(load(...), #L) => copy_async(...); barrier -// if #L is smem_layout -def CopyAsyncOptPattern : Pat< - (TTG_ConvertLayoutOp:$res (TT_LoadOp $ptr, $mask, $other, $cache, $evict, $isVolatile, $isOtherUnspecified)), - (TTG_CopyAsyncOp $ptr, $mask, $other, $cache, $evict, $isVolatile, $isOtherUnspecified), - [(Constraint> $res)]>; - -// ConvertLayout(ConvertLayout(x, #L0), #L1) => ConvertLayout(x, #L1) -def ConvertLayoutOptPattern : Pat< - (TTG_ConvertLayoutOp (TTG_ConvertLayoutOp $x)), - (TTG_ConvertLayoutOp $x)>; - -// TODO: can we replace this with ConvertLayoutOp's folder? -// ConvertLayout(x, #L) => x if x.layout() == #L -def RedundantConvertLayoutOptPattern : Pat< - (TTG_ConvertLayoutOp:$res $x), (replaceWithValue $x), - [(Constraint> $res, $x)]>; - #endif diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 102299757..d44a25c71 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -1,8 +1,7 @@ +#include "mlir/IR/BlockAndValueMapping.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" -#include "mlir/IR/BlockAndValueMapping.h" - //===----------------------------------------------------------------------===// // // This file implements loop software pipelining @@ -168,8 +167,7 @@ LogicalResult LoopPipeliner::initialize() { if (auto tensorType = convertLayout.getResult() .getType() .dyn_cast()) { - if (tensorType.getEncoding() - .isa()) { + if (tensorType.getEncoding().isa()) { isCandiate = true; loadsMapping[loadOp] = convertLayout; } @@ -263,7 +261,7 @@ void LoopPipeliner::emitPrologue() { // assert(I1 or TensorOf<[I1]>); OpBuilder::InsertionGuard g(builder); builder.setInsertionPoint(newOp); - Value splatCond = builder.create( + Value splatCond = builder.create( mask.getLoc(), mask.getType(), loopCond); Value newMask = builder.create(mask.getLoc(), mask, splatCond); @@ -356,6 +354,9 @@ scf::ForOp LoopPipeliner::createNewForOp() { Value loadUse = load.getUsers().begin()->getResult(0); mapping.lookup(loadUse).replaceAllUsesWith( newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]); + // delete old load and layout conversion + mapping.lookup(loadUse).getDefiningOp()->erase(); + mapping.lookup(load).getDefiningOp()->erase(); } // 4. 
prefetch the next iteration @@ -389,7 +390,7 @@ scf::ForOp LoopPipeliner::createNewForOp() { if (loads.contains(op->getResult(0))) { auto loadOp = llvm::cast(op); Value mask = loadOp.mask(); - Value splatCond = builder.create( + Value splatCond = builder.create( mask.getLoc(), mask.getType(), nextLoopCond); Value newMask = builder.create( mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask)); @@ -442,9 +443,10 @@ scf::ForOp LoopPipeliner::createNewForOp() { yieldValues.push_back( depArgsMapping.lookup(newForOp.getRegionIterArgs()[i])); yieldValues.push_back(nextIV); + builder.setInsertionPointToEnd(newForOp.getBody()); - builder.create(forOp.getBody()->getTerminator()->getLoc(), - yieldValues); + auto test = builder.create( + forOp.getBody()->getTerminator()->getLoc(), yieldValues); return newForOp; } diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 521fa8964..4adb11143 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -29,7 +29,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, llvm::SmallVector order(rank); std::iota(order.begin(), order.end(), 0); llvm::SmallVector sizePerThread(rank, 1); - Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get( + Attribute encoding = triton::gpu::BlockedEncodingAttr::get( this->context, shape, sizePerThread, order, this->numWarps); return RankedTensorType::get(shape, tensorType.getElementType(), encoding); }); @@ -95,9 +95,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( dotOp.a().getType().cast().getEncoding(); Attribute bEncoding = dotOp.b().getType().cast().getEncoding(); - if (aEncoding && - aEncoding.isa() && - bEncoding && bEncoding.isa()) + if (aEncoding && aEncoding.isa() && + bEncoding && bEncoding.isa()) return true; return false; }); diff --git a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp index e88799927..7bf4143f6 100644 --- a/lib/Dialect/TritonGPU/Transforms/Verifier.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Verifier.cpp @@ -33,7 +33,7 @@ private: Attribute encoding = tensorType.getEncoding(); if (!encoding) return dotOp.emitError() << name << " should have encoding"; - if (!encoding.isa()) + if (!encoding.isa()) return dotOp.emitError() << name << " should be of shared layout"; } else return dotOp.emitError() @@ -49,8 +49,8 @@ private: Attribute encoding = tensorType.getEncoding(); if (!encoding) return dotOp.emitError() << name << " should have encoding"; - if (!encoding.isa() && - !encoding.isa()) + if (!encoding.isa() && + !encoding.isa()) return dotOp.emitError() << name << " should be of distributed layout"; if (name == 'c') diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 321954a8f..b70348d7f 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -749,8 +749,9 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()): context = _triton.ir.context() context.load_triton() # create kernel prototype - arg_types = signature.replace(' ', '').split(',') constants = {fn.arg_names.index(name): value for name, value in constants.items()} + attributes = {fn.arg_names.index(name): value for name, value in attributes.items()} + arg_types = signature.replace(' ', '').split(',') arg_types = [str_to_ty(x) for x in arg_types] prototype = triton.language.function_type([], arg_types) # visit kernel AST @@ -769,6 +770,14 
@@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()): return ret +def optimize_triton_ir(mod): + pm = _triton.ir.pass_manager(mod.context) + pm.add_inliner_pass() + pm.add_canonicalizer_pass() + pm.run(mod) + return mod + + def make_tritongpu_ir(mod, num_warps): pm = _triton.ir.pass_manager(mod.context) pm.add_inliner_pass() @@ -785,7 +794,7 @@ def optimize_tritongpu_ir(mod, num_stages): pm.add_tritongpu_pipeline_pass(num_stages) pm.add_canonicalizer_pass() pm.add_cse_pass() - pm.add_triton_gpu_combine_pass() + # pm.add_triton_gpu_combine_pass() pm.add_triton_gpu_verifier_pass() pm.run(mod) return mod @@ -815,6 +824,7 @@ def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_w assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output) # triton-ir module = make_triton_ir(fn, signature, constants, attributes) + module = optimize_triton_ir(module) if output == "ttir": return module.str() # tritongpu-ir diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir new file mode 100644 index 000000000..c34babdff --- /dev/null +++ b/test/TritonGPU/combine.mlir @@ -0,0 +1,175 @@ +// RUN: triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s + +#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> + +// CHECK: [[target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}> +// CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> + +func @cst() -> tensor<1024xi32, #layout1> { + %cst = arith.constant dense<0> : tensor<1024xi32, #layout0> + %1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: return %cst : tensor<1024xi32, [[target_layout]]> + return %1: tensor<1024xi32, #layout1> +} + +func @range() -> tensor<1024xi32, #layout1> { + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> + %1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: return %0 : tensor<1024xi32, [[target_layout]]> + return %1: tensor<1024xi32, #layout1> +} + +func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> { + %0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0> + %1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> + // CHECK-NOT: triton_gpu.convert_layout + // CHECK: return %0 : tensor<1024xi32, [[target_layout]]> + return %1: tensor<1024xi32, #layout1> +} + +func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> { + %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> + %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0> + %2 = arith.muli %0, %1 : tensor<1024xi32, #layout0> + %3 = triton_gpu.convert_layout %2 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1> + %4 = tt.splat %arg0 : (i32) 
-> tensor<1024xi32, #layout0>
+  %5 = triton_gpu.convert_layout %2 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
+  %6 = arith.addi %3, %5 : tensor<1024xi32, #layout1>
+  return %6 : tensor<1024xi32, #layout1>
+  // CHECK: %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %4 = arith.muli %2, %3 : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %5 = arith.muli %0, %1 : tensor<1024xi32, [[target_layout]]>
+  // CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[target_layout]]>
+  // CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
+}
+
+#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
+#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
+
+// CHECK-LABEL: transpose
+func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
+  // CHECK: %cst = arith.constant dense<true> : tensor<64x64xi1, [[row_layout]]>
+  // CHECK: %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, [[row_layout]]>
+  // CHECK: %cst_1 = arith.constant dense<true> : tensor<64x64xi1, [[col_layout]]>
+  // CHECK: %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>
+  // CHECK: %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>
+  // CHECK: %2 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>) -> tensor<64x1xi32, [[row_layout]]>
+  // CHECK: %3 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, [[row_layout]]>
+  // CHECK: %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
+  // CHECK: %5 = arith.muli %2, %3 : tensor<64x1xi32, [[row_layout]]>
+  // CHECK: %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>
+  // CHECK: %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>
+  // CHECK: %8 = tt.getelementptr %4, %5 : tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
+  // CHECK: %9 = tt.expand_dims %7 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>) -> tensor<1x64xi32, [[row_layout]]>
+  // CHECK: %10 = tt.broadcast %8 : (tensor<64x1x!tt.ptr<f32>, [[row_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
+  // CHECK: %11 = tt.broadcast %9 : (tensor<1x64xi32, [[row_layout]]>) -> tensor<64x64xi32, [[row_layout]]>
+  // CHECK: %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
+  // CHECK: %13 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>) -> tensor<64x1xi32, [[col_layout]]>
+  // CHECK: %14 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>) -> tensor<1x64xi32, [[col_layout]]>
+  // CHECK: %15 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, [[col_layout]]>
+  // CHECK: %16 = tt.getelementptr %12, %13 : tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
+  // CHECK: %17 = arith.muli %14, %15 : tensor<1x64xi32, [[col_layout]]>
+  // CHECK: %18 = tt.broadcast %16 : (tensor<64x1x!tt.ptr<f32>, [[col_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
+  // CHECK: %19 = tt.broadcast %17 : (tensor<1x64xi32, [[col_layout]]>) -> tensor<64x64xi32, [[col_layout]]>
+  // CHECK: %20 = tt.getelementptr %10, %11 : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
+  // CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
+  // CHECK: %22 = tt.getelementptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
+  // CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
+  // CHECK: tt.store %22, %23, %cst_1, : tensor<64x64xf32, [[col_layout]]>
+  // CHECK: return
+  %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
+  %cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
+  %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
+  %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
+  %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
+  %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
+  %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
+  %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
+  %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
+  %10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
+  %14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
+  %15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
+  %17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
+  %18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
+  %20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
+  %21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
+  %22 = tt.load %19, %20, %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, isOtherUnspecified = false} : tensor<64x64xf32, #blocked3>
+  %23 = triton_gpu.convert_layout %22 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
+  %24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
+  %25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
+  %26 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked4>
+  tt.store %24, %25, %26, : tensor<64x64xf32, #blocked4>
+  return
+}
+
+// CHECK-LABEL: loop
+func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
+  // CHECK-NOT: triton_gpu.convert_layout
+  // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>)
+  // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]>
+  // CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[row_layout]]>
+  // CHECK-NEXT: {{.*}} = tt.getelementptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
+  // CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
+  // CHECK-NEXT: }
+  // CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout_novec]]>
+  // CHECK-NOT: triton_gpu.convert_layout
+  %cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
+  %cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
+  %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
+  %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
+  %2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
+  %3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
+  %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
+  %7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
+  %9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
+  %10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
+    %23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
+    %24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
+    %25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
+    %26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked3>
+    %27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
+    %28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
+    %29 = tt.getelementptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+    scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
+  }
+  %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %13 = tt.getelementptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
+  %14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
+  %15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
+  %16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
+  %18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
+  %19 = tt.getelementptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
+  %21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
+  %22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
+  tt.store %20, %21, %22, : tensor<64x64xf32, #blocked1>
+  return
+}
\ No newline at end of file