special encoding for broadcast

This commit is contained in:
Yan Da
2022-06-18 21:16:45 +08:00
parent 53cf93ce6a
commit 9d1b5e3f79
6 changed files with 248 additions and 72 deletions

View File

@@ -4,8 +4,9 @@
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
// include "mlir/IR/TensorEncoding.td" // include "mlir/IR/TensorEncoding.td"
class TritonGPU_Attr<string name, list<Trait> traits = []> class TritonGPU_Attr<string name, list<Trait> traits = [],
: AttrDef<TritonGPU_Dialect, name, traits>; string baseCppClass = "::mlir::Attribute">
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>;
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> { def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
let mnemonic = "shared_layout"; let mnemonic = "shared_layout";
@@ -104,7 +105,8 @@ And the associated TritonGPU MLIR
ArrayRefParameter< ArrayRefParameter<
"unsigned", "unsigned",
"order of axes by the rate of changing" "order of axes by the rate of changing"
>:$order >:$order,
ArrayRefParameter<"unsigned">:$broadcastAxis
// "AffineMap":$threadOrdering, // "AffineMap":$threadOrdering,
// "AffineMap":warpOrdering, // "AffineMap":warpOrdering,
// "AffineMap":$blockOrdering, // "AffineMap":$blockOrdering,
@@ -114,6 +116,28 @@ And the associated TritonGPU MLIR
// let genVerifyDecl = 1; // let genVerifyDecl = 1;
} }
def TritonGPUBlockedMulticastEncodingAttr
: TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
let mnemonic = "blocked_multicast_layout";
let description = [{
to be broadcasted to blocked_layout
}];
// This needs to be synced with BlockedEncoding
let parameters = (
ins
ArrayRefParameter<"unsigned">:$threadTileSize,
ArrayRefParameter<"unsigned">:$warpTileSize,
ArrayRefParameter<"unsigned">:$blockTileSize,
ArrayRefParameter<"unsigned">:$order,
// unique to broadcasted layout
ArrayRefParameter<"unsigned">:$broadcastAxis
);
// let genVerifyDecl = 1;
}
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> { def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
let mnemonic = "mma_layout"; let mnemonic = "mma_layout";
@@ -131,7 +155,8 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
ArrayRefParameter<"unsigned">:$shapePerTile, ArrayRefParameter<"unsigned">:$shapePerTile,
// TODO: should Distributed layout also // TODO: should Distributed layout also
ArrayRefParameter<"unsigned">:$repetitions, ArrayRefParameter<"unsigned">:$repetitions,
ArrayRefParameter<"unsigned">:$contigPerThread ArrayRefParameter<"unsigned">:$contigPerThread,
ArrayRefParameter<"unsigned">:$broadcastAxis
// "AffineMap":$warpOrdering, // "AffineMap":$warpOrdering,
// "AffineMap":$blockOrdering // "AffineMap":$blockOrdering
); );
@@ -139,4 +164,28 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
// let genVerifyDecl = 1; // let genVerifyDecl = 1;
} }
def TritonGPUMmaMulticastEncodingAttr
: TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
let mnemonic = "mma_multicast_layout";
let description = [{
To be broadcasted to mma.
}];
// This needs to be synced with MmaEncoding
let parameters = (
ins
ArrayRefParameter<"unsigned">:$fragmentPerWarp,
ArrayRefParameter<"unsigned">:$shapePerWarp,
ArrayRefParameter<"unsigned">:$warpPerTile,
ArrayRefParameter<"unsigned">:$shapePerTile,
ArrayRefParameter<"unsigned">:$repetitions,
ArrayRefParameter<"unsigned">:$contigPerThread,
// unique to broadcasted layout
ArrayRefParameter<"unsigned">:$broadcastAxis
);
// let genVerifyDecl = 1;
}
#endif #endif

View File

@@ -23,6 +23,9 @@ class TritonGPUConversionTarget : public ConversionTarget {
TritonGPUTypeConverter &typeConverter; TritonGPUTypeConverter &typeConverter;
public: public:
explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter); explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
/// update layouts & insert ConvertLayoutOp if necessary
LogicalResult refineLayouts(ModuleOp mod, int numThreads);
}; };
} // namespace mlir } // namespace mlir

View File

@@ -342,9 +342,13 @@ public:
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here? // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns); populateSCFPatterns(typeConverter, patterns);
if(failed(applyPartialConversion(mod, target, if(failed(applyPartialConversion(mod, target, std::move(patterns))))
std::move(patterns))))
return signalPassFailure(); return signalPassFailure();
// update layouts
// broadcast src => multicast, dst => broadcasted
if(failed(target.refineLayouts(mod, numWarps)))
return signalPassFailure();
} }
}; };

View File

@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
// parse an array of integers // parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser, static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr, const NamedAttribute &attr,
SmallVector<unsigned, 2> &res, /*SmallVector<unsigned, 2>*/auto &res,
StringRef desc) { StringRef desc) {
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr) { if (!arrayAttr) {
@@ -36,8 +36,7 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
#define GET_ATTRDEF_CLASSES #define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc" #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
Attribute static Attribute parseBlocked(AsmParser &parser, Type type) {
TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed()) if (parser.parseLess().failed())
return {}; return {};
// Parse the data as a dictionary // Parse the data as a dictionary
@@ -51,28 +50,7 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
SmallVector<unsigned, 2> warpTileSize; SmallVector<unsigned, 2> warpTileSize;
SmallVector<unsigned, 2> blockTileSize; SmallVector<unsigned, 2> blockTileSize;
SmallVector<unsigned, 2> order; SmallVector<unsigned, 2> order;
SmallVector<unsigned, 2> broadcastAxis;
// parse an array of integers
// auto parseIntArrayAttr = [&parser](const NamedAttribute &attr,
// SmallVector<unsigned, 2> &res,
// StringRef desc) -> LogicalResult {
// auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
// if (!arrayAttr) {
// parser.emitError(parser.getNameLoc(), "expected an array for ")
// << desc;
// return failure();
// }
// for (Attribute i : arrayAttr) {
// auto intAttr = i.dyn_cast<IntegerAttr>();
// if (!intAttr) {
// parser.emitError(parser.getNameLoc(), "expected an integer value in ")
// << desc;
// return failure();
// }
// res.push_back(intAttr.getUInt());
// }
// return success();
// };
for (const NamedAttribute &attr : dict) { for (const NamedAttribute &attr : dict) {
if (attr.getName() == "threadTileSize") { if (attr.getName() == "threadTileSize") {
@@ -98,20 +76,39 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
threadTileSize, threadTileSize,
warpTileSize, warpTileSize,
blockTileSize, blockTileSize,
order); order,
broadcastAxis);
} }
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { static void printBlocked(AsmPrinter &printer, auto *attr) {
printer << "<{" printer << "<{"
<< "threadTileSize = [" << getThreadTileSize() << "]" << "threadTileSize = [" << attr->getThreadTileSize() << "]"
<< ", warpTileSize = [" << getWarpTileSize() << "]" << ", warpTileSize = [" << attr->getWarpTileSize() << "]"
<< ", blockTileSize = [" << getBlockTileSize() << "]" << ", blockTileSize = [" << attr->getBlockTileSize() << "]"
<< ", order = [" << getOrder() << "]" << ", order = [" << attr->getOrder() << "]"
<< ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
<< "}>"; << "}>";
} }
Attribute Attribute
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) { TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
parseBlocked(parser, type);
}
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printBlocked(printer, this);
}
Attribute
TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
parseBlocked(parser, type);
}
void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
printBlocked(printer, this);
}
static Attribute parseMma(AsmParser &parser, Type type) {
if (parser.parseLess().failed()) if (parser.parseLess().failed())
return {}; return {};
DictionaryAttr dict; DictionaryAttr dict;
@@ -126,6 +123,7 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
SmallVector<unsigned, 2> shapePerTile; SmallVector<unsigned, 2> shapePerTile;
SmallVector<unsigned, 2> repetitions; SmallVector<unsigned, 2> repetitions;
SmallVector<unsigned, 2> contigPerThread; SmallVector<unsigned, 2> contigPerThread;
SmallVector<unsigned, 2> broadcastAxis;
for (const NamedAttribute &attr : dict) { for (const NamedAttribute &attr : dict) {
if (attr.getName() == "fragmentPerWarp") { if (attr.getName() == "fragmentPerWarp") {
@@ -159,18 +157,37 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
warpPerTile, warpPerTile,
shapePerTile, shapePerTile,
repetitions, repetitions,
contigPerThread); contigPerThread,
broadcastAxis);
}
static void printMma(AsmPrinter &printer, auto *attr) {
printer << "<{"
<< "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
<< ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
<< ", warpPerTile = [" << attr->getWarpPerTile() << "]"
<< ", shapePerTile = [" << attr->getShapePerTile() << "]"
<< ", repetitions = [" << attr->getRepetitions() << "]"
<< ", contigPerThread = [" << attr->getContigPerThread() << "]"
<< "}>";
}
Attribute
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
return parseMma(parser, type);
} }
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const { void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{" printMma(printer, this);
<< "fragmentPerWarp = [" << getFragmentPerWarp() << "]" }
<< ", shapePerWarp = [" << getShapePerWarp() << "]"
<< ", warpPerTile = [" << getWarpPerTile() << "]" Attribute
<< ", shapePerTile = [" << getShapePerTile() << "]" TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
<< ", repetitions = [" << getRepetitions() << "]" return parseMma(parser, type);
<< ", contigPerThread = [" << getContigPerThread() << "]" }
<< "}>";
void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
} }
Attribute Attribute

View File

@@ -1,9 +1,11 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include <algorithm> #include <algorithm>
using namespace mlir; using namespace mlir;
using namespace mlir::triton::gpu;
// //
// TypeConverter // TypeConverter
@@ -41,6 +43,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
llvm::SmallVector<unsigned> warpTileSize(rank, 1); llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank); llvm::SmallVector<unsigned> blockTileSize(rank);
llvm::SmallVector<unsigned> order(rank); llvm::SmallVector<unsigned> order(rank);
llvm::SmallVector<unsigned> broadcastAxis;
int remainingThreads = numThreads; int remainingThreads = numThreads;
int remainingLanes = /*warp size*/32; int remainingLanes = /*warp size*/32;
for (int64_t dim = 0; dim < rank; ++dim) { for (int64_t dim = 0; dim < rank; ++dim) {
@@ -53,7 +56,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// TODO: will we need repetition? // TODO: will we need repetition?
} }
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get( Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
context, threadTileSize, warpTileSize, blockTileSize, order); context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
return RankedTensorType::get(shape, elementType, encoding); return RankedTensorType::get(shape, elementType, encoding);
}); });
@@ -81,7 +84,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// NOTE: only for remapped values. // NOTE: only for remapped values.
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) { ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
llvm_unreachable("Not implemented"); llvm_unreachable("Not implemented");
return llvm::None; return llvm::None;
}); });
@@ -124,3 +126,98 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
}); });
} }
// %dst = tt.broadcast %src
// =>
// %newSrc = convert_layout %src
// %bcst = tt.broadcast %newSrc
// %dst = convert_layout %bcst
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
int numThreads) {
// collect broadcasts
SmallVector<triton::BroadcastOp> broadcasts;
mod.walk([&](triton::BroadcastOp op) {
broadcasts.push_back(op);
});
BlockAndValueMapping mapping;
for (auto broadcast : broadcasts) {
OpBuilder builder(broadcast);
Value src = mapping.lookupOrDefault(broadcast.src());
Type originSrcType = src.getType();
Type originDstType = broadcast.getType();
auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
unsigned dstRank = originDstTensorType.getRank();
// compute newSrcType & broadcastAxis
Type newSrcType;
SmallVector<unsigned> broadcastAxis;
bool isSrcScalar = false;
if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
assert(tensorType.getRank() == dstRank &&
"src & dst should have same rank (verifier should catch this)");
for (unsigned ax = 0; ax < dstRank; ++ax)
if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
broadcastAxis.push_back(ax);
Attribute originSrcEnc = tensorType.getEncoding();
if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
blockedEnc.getContext(),
blockedEnc.getThreadTileSize(),
blockedEnc.getWarpTileSize(),
blockedEnc.getBlockTileSize(),
blockedEnc.getOrder(),
broadcastAxis
);
newSrcType = RankedTensorType::get(
tensorType.getShape(),
tensorType.getElementType(),
newSrcEnc
);
} else
llvm_unreachable("src of broadcast should have blocked encoding");
} else {
for (unsigned ax = 0; ax < dstRank; ++ax)
broadcastAxis.push_back(ax);
newSrcType = originSrcType;
isSrcScalar = true;
}
// create new src
if (!isSrcScalar) // we don't need to convert layout for scalar values
src = builder.create<triton::gpu::ConvertLayoutOp>(
src.getLoc(), newSrcType, src
);
// create new broadcast
// compute new type (encoding)
auto originDstEnc = originDstTensorType.getEncoding()
.dyn_cast<TritonGPUBlockedEncodingAttr>();
auto newEnc = TritonGPUBlockedEncodingAttr::get(
originDstEnc.getContext(),
originDstEnc.getThreadTileSize(),
originDstEnc.getWarpTileSize(),
originDstEnc.getBlockTileSize(),
originDstEnc.getOrder(),
broadcastAxis
);
auto newType = RankedTensorType::get(
originDstTensorType.getShape(),
originDstTensorType.getElementType(),
newEnc
);
Value newBroadcast = builder.create<triton::BroadcastOp>(
broadcast.getLoc(), newType, src
);
// we don't want to change the encoding of the result
Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
broadcast.getLoc(), originDstType, newBroadcast
);
broadcast.replaceAllUsesWith(newDst);
mapping.map(broadcast, newDst);
}
return success();
}

View File

@@ -45,32 +45,38 @@ module {
%c32 = arith.constant 32 : index %c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index %c0 = arith.constant 0 : index
%c256_i32 = arith.constant 256 : i32 %c256_i32 = arith.constant 256 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%0 = tt.get_program_id {axis = 0 : i32} : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32 %1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %4 = triton_gpu.convert_layout %3 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %5 = arith.addi %4, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %6 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %7 = triton_gpu.convert_layout %6 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %8 = "triton_gpu.cmpi"(%5, %7) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %9 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %10 = triton_gpu.convert_layout %9 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%11 = arith.index_cast %arg4 : i32 to index %11 = tt.getelementptr %10, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%12:3 = scf.for %arg6 = %c0 to %11 step %c32 iter_args(%arg7 = %cst, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) { %12 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%15 = tt.load %arg8, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %13 = triton_gpu.convert_layout %12 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%16 = tt.load %arg9, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %14 = tt.getelementptr %13, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%17 = arith.addf %15, %16 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %15 = arith.index_cast %arg4 : i32 to index
%18 = arith.addf %arg7, %17 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %16:3 = scf.for %arg6 = %c0 to %15 step %c32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %14) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) {
%19 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %20 = tt.load %arg8, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%20 = tt.getelementptr %arg8, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %21 = tt.load %arg9, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%21 = tt.getelementptr %arg9, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %22 = arith.addf %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
scf.yield %18, %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %23 = arith.addf %arg7, %22 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%25 = triton_gpu.convert_layout %24 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%26 = tt.getelementptr %arg8, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%27 = tt.getelementptr %arg9, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
scf.yield %23, %26, %27 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
} }
%13 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %17 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%14 = tt.getelementptr %13, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %18 = triton_gpu.convert_layout %17 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
tt.store %14, %12#0, %6, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>> %19 = tt.getelementptr %18, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
tt.store %19, %16#0, %8, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
return return
} }
} }