From 9d1b5e3f797323377c6b9047c00984a103f45497 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Sat, 18 Jun 2022 21:16:45 +0800 Subject: [PATCH] special encoding for broadcast --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 57 +++++++++- .../Transforms/TritonGPUConversion.h | 3 + .../TritonToTritonGPU/TritonToTritonGPU.cpp | 8 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 99 ++++++++++------- .../Transforms/TritonGPUConversion.cpp | 101 +++++++++++++++++- rewrite-test/jit/vecadd/vecadd.mlir | 52 +++++---- 6 files changed, 248 insertions(+), 72 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 38c33c5a4..4df2482cc 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -4,8 +4,9 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" // include "mlir/IR/TensorEncoding.td" -class TritonGPU_Attr traits = []> - : AttrDef; +class TritonGPU_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef; def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> { let mnemonic = "shared_layout"; @@ -104,7 +105,8 @@ And the associated TritonGPU MLIR ArrayRefParameter< "unsigned", "order of axes by the rate of changing" - >:$order + >:$order, + ArrayRefParameter<"unsigned">:$broadcastAxis // "AffineMap":$threadOrdering, // "AffineMap":warpOrdering, // "AffineMap":$blockOrdering, @@ -114,6 +116,28 @@ And the associated TritonGPU MLIR // let genVerifyDecl = 1; } +def TritonGPUBlockedMulticastEncodingAttr + : TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> { + let mnemonic = "blocked_multicast_layout"; + + let description = [{ + to be broadcasted to blocked_layout + }]; + + // This needs to be synced with BlockedEncoding + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$threadTileSize, + ArrayRefParameter<"unsigned">:$warpTileSize, + ArrayRefParameter<"unsigned">:$blockTileSize, + ArrayRefParameter<"unsigned">:$order, + // unique to broadcasted layout + ArrayRefParameter<"unsigned">:$broadcastAxis + ); + + // let genVerifyDecl = 1; +} + def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> { let mnemonic = "mma_layout"; @@ -131,7 +155,8 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> { ArrayRefParameter<"unsigned">:$shapePerTile, // TODO: should Distributed layout also ArrayRefParameter<"unsigned">:$repetitions, - ArrayRefParameter<"unsigned">:$contigPerThread + ArrayRefParameter<"unsigned">:$contigPerThread, + ArrayRefParameter<"unsigned">:$broadcastAxis // "AffineMap":$warpOrdering, // "AffineMap":$blockOrdering ); @@ -139,4 +164,28 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> { // let genVerifyDecl = 1; } +def TritonGPUMmaMulticastEncodingAttr + : TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> { + let mnemonic = "mma_multicast_layout"; + + let description = [{ + To be broadcasted to mma. + }]; + + // This needs to be synced with MmaEncoding + let parameters = ( + ins + ArrayRefParameter<"unsigned">:$fragmentPerWarp, + ArrayRefParameter<"unsigned">:$shapePerWarp, + ArrayRefParameter<"unsigned">:$warpPerTile, + ArrayRefParameter<"unsigned">:$shapePerTile, + ArrayRefParameter<"unsigned">:$repetitions, + ArrayRefParameter<"unsigned">:$contigPerThread, + // unique to broadcasted layout + ArrayRefParameter<"unsigned">:$broadcastAxis + ); + + // let genVerifyDecl = 1; +} + #endif diff --git a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h index 2f34d71f7..fd9048570 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h +++ b/include/triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h @@ -23,6 +23,9 @@ class TritonGPUConversionTarget : public ConversionTarget { TritonGPUTypeConverter &typeConverter; public: explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter); + + /// update layouts & insert ConvertLayoutOp if necessary + LogicalResult refineLayouts(ModuleOp mod, int numThreads); }; } // namespace mlir diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 591af75b5..94bc4c696 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -342,9 +342,13 @@ public: // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? populateSCFPatterns(typeConverter, patterns); - if(failed(applyPartialConversion(mod, target, - std::move(patterns)))) + if(failed(applyPartialConversion(mod, target, std::move(patterns)))) return signalPassFailure(); + + // update layouts + // broadcast src => multicast, dst => broadcasted + if(failed(target.refineLayouts(mod, numWarps))) + return signalPassFailure(); } }; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 4380524fe..9a99e8ae9 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -10,7 +10,7 @@ using namespace mlir::triton::gpu; // parse an array of integers static LogicalResult parseIntArrayAttr(AsmParser &parser, const NamedAttribute &attr, - SmallVector &res, + /*SmallVector*/auto &res, StringRef desc) { auto arrayAttr = attr.getValue().dyn_cast(); if (!arrayAttr) { @@ -36,8 +36,7 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser, #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" -Attribute -TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { +static Attribute parseBlocked(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; // Parse the data as a dictionary @@ -51,28 +50,7 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { SmallVector warpTileSize; SmallVector blockTileSize; SmallVector order; - - // parse an array of integers - // auto parseIntArrayAttr = [&parser](const NamedAttribute &attr, - // SmallVector &res, - // StringRef desc) -> LogicalResult { - // auto arrayAttr = attr.getValue().dyn_cast(); - // if (!arrayAttr) { - // parser.emitError(parser.getNameLoc(), "expected an array for ") - // << desc; - // return failure(); - // } - // for (Attribute i : arrayAttr) { - // auto intAttr = i.dyn_cast(); - // if (!intAttr) { - // parser.emitError(parser.getNameLoc(), "expected an integer value in ") - // << desc; - // return failure(); - // } - // res.push_back(intAttr.getUInt()); - // } - // return success(); - // }; + SmallVector broadcastAxis; for (const NamedAttribute &attr : dict) { if (attr.getName() == "threadTileSize") { @@ -98,20 +76,39 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { threadTileSize, warpTileSize, blockTileSize, - order); + order, + broadcastAxis); } -void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { +static void printBlocked(AsmPrinter &printer, auto *attr) { printer << "<{" - << "threadTileSize = [" << getThreadTileSize() << "]" - << ", warpTileSize = [" << getWarpTileSize() << "]" - << ", blockTileSize = [" << getBlockTileSize() << "]" - << ", order = [" << getOrder() << "]" + << "threadTileSize = [" << attr->getThreadTileSize() << "]" + << ", warpTileSize = [" << attr->getWarpTileSize() << "]" + << ", blockTileSize = [" << attr->getBlockTileSize() << "]" + << ", order = [" << attr->getOrder() << "]" + << ", broadcastAxis = [" << attr->getBroadcastAxis() << "]" << "}>"; } Attribute -TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { +TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) { + parseBlocked(parser, type); +} + +void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { + printBlocked(printer, this); +} + +Attribute +TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) { + parseBlocked(parser, type); +} + +void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const { + printBlocked(printer, this); +} + +static Attribute parseMma(AsmParser &parser, Type type) { if (parser.parseLess().failed()) return {}; DictionaryAttr dict; @@ -126,6 +123,7 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { SmallVector shapePerTile; SmallVector repetitions; SmallVector contigPerThread; + SmallVector broadcastAxis; for (const NamedAttribute &attr : dict) { if (attr.getName() == "fragmentPerWarp") { @@ -159,18 +157,37 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { warpPerTile, shapePerTile, repetitions, - contigPerThread); + contigPerThread, + broadcastAxis); +} + +static void printMma(AsmPrinter &printer, auto *attr) { + printer << "<{" + << "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]" + << ", shapePerWarp = [" << attr->getShapePerWarp() << "]" + << ", warpPerTile = [" << attr->getWarpPerTile() << "]" + << ", shapePerTile = [" << attr->getShapePerTile() << "]" + << ", repetitions = [" << attr->getRepetitions() << "]" + << ", contigPerThread = [" << attr->getContigPerThread() << "]" + << "}>"; +} + +Attribute +TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { + return parseMma(parser, type); } void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const { - printer << "<{" - << "fragmentPerWarp = [" << getFragmentPerWarp() << "]" - << ", shapePerWarp = [" << getShapePerWarp() << "]" - << ", warpPerTile = [" << getWarpPerTile() << "]" - << ", shapePerTile = [" << getShapePerTile() << "]" - << ", repetitions = [" << getRepetitions() << "]" - << ", contigPerThread = [" << getContigPerThread() << "]" - << "}>"; + printMma(printer, this); +} + +Attribute +TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) { + return parseMma(parser, type); +} + +void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const { + printMma(printer, this); } Attribute diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 970ecbedc..bedf9f38a 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -1,9 +1,11 @@ #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "mlir/IR/BlockAndValueMapping.h" #include using namespace mlir; +using namespace mlir::triton::gpu; // // TypeConverter @@ -41,6 +43,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, llvm::SmallVector warpTileSize(rank, 1); llvm::SmallVector blockTileSize(rank); llvm::SmallVector order(rank); + llvm::SmallVector broadcastAxis; int remainingThreads = numThreads; int remainingLanes = /*warp size*/32; for (int64_t dim = 0; dim < rank; ++dim) { @@ -53,7 +56,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // TODO: will we need repetition? } Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get( - context, threadTileSize, warpTileSize, blockTileSize, order); + context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis); return RankedTensorType::get(shape, elementType, encoding); }); @@ -81,7 +84,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, // NOTE: only for remapped values. addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) { - assert(inputs.size() == 1); llvm_unreachable("Not implemented"); return llvm::None; }); @@ -124,3 +126,98 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( }); } + +// %dst = tt.broadcast %src +// => +// %newSrc = convert_layout %src +// %bcst = tt.broadcast %newSrc +// %dst = convert_layout %bcst +LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod, + int numThreads) { + // collect broadcasts + SmallVector broadcasts; + mod.walk([&](triton::BroadcastOp op) { + broadcasts.push_back(op); + }); + + BlockAndValueMapping mapping; + for (auto broadcast : broadcasts) { + OpBuilder builder(broadcast); + Value src = mapping.lookupOrDefault(broadcast.src()); + Type originSrcType = src.getType(); + Type originDstType = broadcast.getType(); + auto originDstTensorType = originDstType.dyn_cast(); + unsigned dstRank = originDstTensorType.getRank(); + + // compute newSrcType & broadcastAxis + Type newSrcType; + SmallVector broadcastAxis; + bool isSrcScalar = false; + if (auto tensorType = originSrcType.dyn_cast()) { + assert(tensorType.getRank() == dstRank && + "src & dst should have same rank (verifier should catch this)"); + for (unsigned ax = 0; ax < dstRank; ++ax) + if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax]) + broadcastAxis.push_back(ax); + + Attribute originSrcEnc = tensorType.getEncoding(); + if (auto blockedEnc = originSrcEnc.dyn_cast()) { + auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get( + blockedEnc.getContext(), + blockedEnc.getThreadTileSize(), + blockedEnc.getWarpTileSize(), + blockedEnc.getBlockTileSize(), + blockedEnc.getOrder(), + broadcastAxis + ); + newSrcType = RankedTensorType::get( + tensorType.getShape(), + tensorType.getElementType(), + newSrcEnc + ); + } else + llvm_unreachable("src of broadcast should have blocked encoding"); + } else { + for (unsigned ax = 0; ax < dstRank; ++ax) + broadcastAxis.push_back(ax); + newSrcType = originSrcType; + isSrcScalar = true; + } + + // create new src + if (!isSrcScalar) // we don't need to convert layout for scalar values + src = builder.create( + src.getLoc(), newSrcType, src + ); + + // create new broadcast + // compute new type (encoding) + auto originDstEnc = originDstTensorType.getEncoding() + .dyn_cast(); + auto newEnc = TritonGPUBlockedEncodingAttr::get( + originDstEnc.getContext(), + originDstEnc.getThreadTileSize(), + originDstEnc.getWarpTileSize(), + originDstEnc.getBlockTileSize(), + originDstEnc.getOrder(), + broadcastAxis + ); + auto newType = RankedTensorType::get( + originDstTensorType.getShape(), + originDstTensorType.getElementType(), + newEnc + ); + Value newBroadcast = builder.create( + broadcast.getLoc(), newType, src + ); + // we don't want to change the encoding of the result + Value newDst = builder.create( + broadcast.getLoc(), originDstType, newBroadcast + ); + + broadcast.replaceAllUsesWith(newDst); + mapping.map(broadcast, newDst); + } + + return success(); +} diff --git a/rewrite-test/jit/vecadd/vecadd.mlir b/rewrite-test/jit/vecadd/vecadd.mlir index 07a216925..27db90323 100644 --- a/rewrite-test/jit/vecadd/vecadd.mlir +++ b/rewrite-test/jit/vecadd/vecadd.mlir @@ -45,32 +45,38 @@ module { %c32 = arith.constant 32 : index %c0 = arith.constant 0 : index %c256_i32 = arith.constant 256 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> + %cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.muli %0, %c256_i32 : i32 - %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %7 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %9 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %11 = arith.index_cast %arg4 : i32 to index - %12:3 = scf.for %arg6 = %c0 to %11 step %c32 iter_args(%arg7 = %cst, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) { - %15 = tt.load %arg8, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %16 = tt.load %arg9, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %17 = arith.addf %15, %16 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %18 = arith.addf %arg7, %17 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %19 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %20 = tt.getelementptr %arg8, %19 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %21 = tt.getelementptr %arg9, %19 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - scf.yield %18, %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> + %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> + %4 = triton_gpu.convert_layout %3 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %5 = arith.addi %4, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %6 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> + %7 = triton_gpu.convert_layout %6 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %8 = "triton_gpu.cmpi"(%5, %7) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %9 = tt.broadcast %arg0 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> + %10 = triton_gpu.convert_layout %9 : (tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %11 = tt.getelementptr %10, %5 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %12 = tt.broadcast %arg1 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> + %13 = triton_gpu.convert_layout %12 : (tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %14 = tt.getelementptr %13, %5 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %15 = arith.index_cast %arg4 : i32 to index + %16:3 = scf.for %arg6 = %c0 to %15 step %c32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %14) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) { + %20 = tt.load %arg8, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %21 = tt.load %arg9, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %22 = arith.addf %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %23 = arith.addf %arg7, %22 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> + %25 = triton_gpu.convert_layout %24 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %26 = tt.getelementptr %arg8, %25 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %27 = tt.getelementptr %arg9, %25 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + scf.yield %23, %26, %27 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> } - %13 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - %14 = tt.getelementptr %13, %4 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> - tt.store %14, %12#0, %6, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> + %17 = tt.broadcast %arg2 : (!tt.ptr) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>> + %18 = triton_gpu.convert_layout %17 : (tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + %19 = tt.getelementptr %18, %5 : tensor<256x!tt.ptr, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> + tt.store %19, %16#0, %8, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>> return } }