special encoding for broadcast

This commit is contained in:
Yan Da
2022-06-18 21:16:45 +08:00
parent 53cf93ce6a
commit 9d1b5e3f79
6 changed files with 248 additions and 72 deletions

View File

@@ -4,8 +4,9 @@
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
// include "mlir/IR/TensorEncoding.td"
class TritonGPU_Attr<string name, list<Trait> traits = []>
: AttrDef<TritonGPU_Dialect, name, traits>;
class TritonGPU_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>;
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
let mnemonic = "shared_layout";
@@ -104,7 +105,8 @@ And the associated TritonGPU MLIR
ArrayRefParameter<
"unsigned",
"order of axes by the rate of changing"
>:$order
>:$order,
ArrayRefParameter<"unsigned">:$broadcastAxis
// "AffineMap":$threadOrdering,
// "AffineMap":warpOrdering,
// "AffineMap":$blockOrdering,
@@ -114,6 +116,28 @@ And the associated TritonGPU MLIR
// let genVerifyDecl = 1;
}
def TritonGPUBlockedMulticastEncodingAttr
: TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
let mnemonic = "blocked_multicast_layout";
let description = [{
to be broadcasted to blocked_layout
}];
// This needs to be synced with BlockedEncoding
let parameters = (
ins
ArrayRefParameter<"unsigned">:$threadTileSize,
ArrayRefParameter<"unsigned">:$warpTileSize,
ArrayRefParameter<"unsigned">:$blockTileSize,
ArrayRefParameter<"unsigned">:$order,
// unique to broadcasted layout
ArrayRefParameter<"unsigned">:$broadcastAxis
);
// let genVerifyDecl = 1;
}
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
let mnemonic = "mma_layout";
@@ -131,7 +155,8 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
ArrayRefParameter<"unsigned">:$shapePerTile,
// TODO: should Distributed layout also
ArrayRefParameter<"unsigned">:$repetitions,
ArrayRefParameter<"unsigned">:$contigPerThread
ArrayRefParameter<"unsigned">:$contigPerThread,
ArrayRefParameter<"unsigned">:$broadcastAxis
// "AffineMap":$warpOrdering,
// "AffineMap":$blockOrdering
);
@@ -139,4 +164,28 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
// let genVerifyDecl = 1;
}
def TritonGPUMmaMulticastEncodingAttr
: TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
let mnemonic = "mma_multicast_layout";
let description = [{
To be broadcasted to mma.
}];
// This needs to be synced with MmaEncoding
let parameters = (
ins
ArrayRefParameter<"unsigned">:$fragmentPerWarp,
ArrayRefParameter<"unsigned">:$shapePerWarp,
ArrayRefParameter<"unsigned">:$warpPerTile,
ArrayRefParameter<"unsigned">:$shapePerTile,
ArrayRefParameter<"unsigned">:$repetitions,
ArrayRefParameter<"unsigned">:$contigPerThread,
// unique to broadcasted layout
ArrayRefParameter<"unsigned">:$broadcastAxis
);
// let genVerifyDecl = 1;
}
#endif

View File

@@ -23,6 +23,9 @@ class TritonGPUConversionTarget : public ConversionTarget {
TritonGPUTypeConverter &typeConverter;
public:
explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
/// update layouts & insert ConvertLayoutOp if necessary
LogicalResult refineLayouts(ModuleOp mod, int numThreads);
};
} // namespace mlir

View File

@@ -342,9 +342,13 @@ public:
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
populateSCFPatterns(typeConverter, patterns);
if(failed(applyPartialConversion(mod, target,
std::move(patterns))))
if(failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
// update layouts
// broadcast src => multicast, dst => broadcasted
if(failed(target.refineLayouts(mod, numWarps)))
return signalPassFailure();
}
};

View File

@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
// parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr,
SmallVector<unsigned, 2> &res,
/*SmallVector<unsigned, 2>*/auto &res,
StringRef desc) {
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr) {
@@ -36,8 +36,7 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
Attribute
TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
static Attribute parseBlocked(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
@@ -51,28 +50,7 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
SmallVector<unsigned, 2> warpTileSize;
SmallVector<unsigned, 2> blockTileSize;
SmallVector<unsigned, 2> order;
// parse an array of integers
// auto parseIntArrayAttr = [&parser](const NamedAttribute &attr,
// SmallVector<unsigned, 2> &res,
// StringRef desc) -> LogicalResult {
// auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
// if (!arrayAttr) {
// parser.emitError(parser.getNameLoc(), "expected an array for ")
// << desc;
// return failure();
// }
// for (Attribute i : arrayAttr) {
// auto intAttr = i.dyn_cast<IntegerAttr>();
// if (!intAttr) {
// parser.emitError(parser.getNameLoc(), "expected an integer value in ")
// << desc;
// return failure();
// }
// res.push_back(intAttr.getUInt());
// }
// return success();
// };
SmallVector<unsigned, 2> broadcastAxis;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "threadTileSize") {
@@ -98,20 +76,39 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
threadTileSize,
warpTileSize,
blockTileSize,
order);
order,
broadcastAxis);
}
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
static void printBlocked(AsmPrinter &printer, auto *attr) {
printer << "<{"
<< "threadTileSize = [" << getThreadTileSize() << "]"
<< ", warpTileSize = [" << getWarpTileSize() << "]"
<< ", blockTileSize = [" << getBlockTileSize() << "]"
<< ", order = [" << getOrder() << "]"
<< "threadTileSize = [" << attr->getThreadTileSize() << "]"
<< ", warpTileSize = [" << attr->getWarpTileSize() << "]"
<< ", blockTileSize = [" << attr->getBlockTileSize() << "]"
<< ", order = [" << attr->getOrder() << "]"
<< ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
<< "}>";
}
Attribute
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
parseBlocked(parser, type);
}
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printBlocked(printer, this);
}
Attribute
TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
parseBlocked(parser, type);
}
void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
printBlocked(printer, this);
}
static Attribute parseMma(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
DictionaryAttr dict;
@@ -126,6 +123,7 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
SmallVector<unsigned, 2> shapePerTile;
SmallVector<unsigned, 2> repetitions;
SmallVector<unsigned, 2> contigPerThread;
SmallVector<unsigned, 2> broadcastAxis;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "fragmentPerWarp") {
@@ -159,18 +157,37 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
warpPerTile,
shapePerTile,
repetitions,
contigPerThread);
contigPerThread,
broadcastAxis);
}
static void printMma(AsmPrinter &printer, auto *attr) {
printer << "<{"
<< "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
<< ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
<< ", warpPerTile = [" << attr->getWarpPerTile() << "]"
<< ", shapePerTile = [" << attr->getShapePerTile() << "]"
<< ", repetitions = [" << attr->getRepetitions() << "]"
<< ", contigPerThread = [" << attr->getContigPerThread() << "]"
<< "}>";
}
Attribute
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
return parseMma(parser, type);
}
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "fragmentPerWarp = [" << getFragmentPerWarp() << "]"
<< ", shapePerWarp = [" << getShapePerWarp() << "]"
<< ", warpPerTile = [" << getWarpPerTile() << "]"
<< ", shapePerTile = [" << getShapePerTile() << "]"
<< ", repetitions = [" << getRepetitions() << "]"
<< ", contigPerThread = [" << getContigPerThread() << "]"
<< "}>";
printMma(printer, this);
}
Attribute
TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
return parseMma(parser, type);
}
void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
printMma(printer, this);
}
Attribute

View File

@@ -1,9 +1,11 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include <algorithm>
using namespace mlir;
using namespace mlir::triton::gpu;
//
// TypeConverter
@@ -41,6 +43,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank);
llvm::SmallVector<unsigned> order(rank);
llvm::SmallVector<unsigned> broadcastAxis;
int remainingThreads = numThreads;
int remainingLanes = /*warp size*/32;
for (int64_t dim = 0; dim < rank; ++dim) {
@@ -53,7 +56,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// TODO: will we need repetition?
}
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
context, threadTileSize, warpTileSize, blockTileSize, order);
context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
return RankedTensorType::get(shape, elementType, encoding);
});
@@ -81,7 +84,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// NOTE: only for remapped values.
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
llvm_unreachable("Not implemented");
return llvm::None;
});
@@ -124,3 +126,98 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
});
}
// %dst = tt.broadcast %src
// =>
// %newSrc = convert_layout %src
// %bcst = tt.broadcast %newSrc
// %dst = convert_layout %bcst
LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
int numThreads) {
// collect broadcasts
SmallVector<triton::BroadcastOp> broadcasts;
mod.walk([&](triton::BroadcastOp op) {
broadcasts.push_back(op);
});
BlockAndValueMapping mapping;
for (auto broadcast : broadcasts) {
OpBuilder builder(broadcast);
Value src = mapping.lookupOrDefault(broadcast.src());
Type originSrcType = src.getType();
Type originDstType = broadcast.getType();
auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
unsigned dstRank = originDstTensorType.getRank();
// compute newSrcType & broadcastAxis
Type newSrcType;
SmallVector<unsigned> broadcastAxis;
bool isSrcScalar = false;
if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
assert(tensorType.getRank() == dstRank &&
"src & dst should have same rank (verifier should catch this)");
for (unsigned ax = 0; ax < dstRank; ++ax)
if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
broadcastAxis.push_back(ax);
Attribute originSrcEnc = tensorType.getEncoding();
if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
blockedEnc.getContext(),
blockedEnc.getThreadTileSize(),
blockedEnc.getWarpTileSize(),
blockedEnc.getBlockTileSize(),
blockedEnc.getOrder(),
broadcastAxis
);
newSrcType = RankedTensorType::get(
tensorType.getShape(),
tensorType.getElementType(),
newSrcEnc
);
} else
llvm_unreachable("src of broadcast should have blocked encoding");
} else {
for (unsigned ax = 0; ax < dstRank; ++ax)
broadcastAxis.push_back(ax);
newSrcType = originSrcType;
isSrcScalar = true;
}
// create new src
if (!isSrcScalar) // we don't need to convert layout for scalar values
src = builder.create<triton::gpu::ConvertLayoutOp>(
src.getLoc(), newSrcType, src
);
// create new broadcast
// compute new type (encoding)
auto originDstEnc = originDstTensorType.getEncoding()
.dyn_cast<TritonGPUBlockedEncodingAttr>();
auto newEnc = TritonGPUBlockedEncodingAttr::get(
originDstEnc.getContext(),
originDstEnc.getThreadTileSize(),
originDstEnc.getWarpTileSize(),
originDstEnc.getBlockTileSize(),
originDstEnc.getOrder(),
broadcastAxis
);
auto newType = RankedTensorType::get(
originDstTensorType.getShape(),
originDstTensorType.getElementType(),
newEnc
);
Value newBroadcast = builder.create<triton::BroadcastOp>(
broadcast.getLoc(), newType, src
);
// we don't want to change the encoding of the result
Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
broadcast.getLoc(), originDstType, newBroadcast
);
broadcast.replaceAllUsesWith(newDst);
mapping.map(broadcast, newDst);
}
return success();
}

View File

@@ -45,32 +45,38 @@ module {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c256_i32 = arith.constant 256 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%11 = arith.index_cast %arg4 : i32 to index
%12:3 = scf.for %arg6 = %c0 to %11 step %c32 iter_args(%arg7 = %cst, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) {
%15 = tt.load %arg8, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%16 = tt.load %arg9, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%17 = arith.addf %15, %16 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%18 = arith.addf %arg7, %17 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%19 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%20 = tt.getelementptr %arg8, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%21 = tt.getelementptr %arg9, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
scf.yield %18, %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%4 = triton_gpu.convert_layout %3 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%5 = arith.addi %4, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%6 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%7 = triton_gpu.convert_layout %6 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%8 = "triton_gpu.cmpi"(%5, %7) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%9 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%10 = triton_gpu.convert_layout %9 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%11 = tt.getelementptr %10, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%12 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%13 = triton_gpu.convert_layout %12 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%14 = tt.getelementptr %13, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%15 = arith.index_cast %arg4 : i32 to index
%16:3 = scf.for %arg6 = %c0 to %15 step %c32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %14) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) {
%20 = tt.load %arg8, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%21 = tt.load %arg9, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%22 = arith.addf %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%23 = arith.addf %arg7, %22 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%25 = triton_gpu.convert_layout %24 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%26 = tt.getelementptr %arg8, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%27 = tt.getelementptr %arg9, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
scf.yield %23, %26, %27 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
}
%13 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%14 = tt.getelementptr %13, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
tt.store %14, %12#0, %6, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
%17 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
%18 = triton_gpu.convert_layout %17 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
%19 = tt.getelementptr %18, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
tt.store %19, %16#0, %8, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
return
}
}