special encoding for broadcast
@@ -4,8 +4,9 @@
 include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
 // include "mlir/IR/TensorEncoding.td"

-class TritonGPU_Attr<string name, list<Trait> traits = []>
-  : AttrDef<TritonGPU_Dialect, name, traits>;
+class TritonGPU_Attr<string name, list<Trait> traits = [],
+                     string baseCppClass = "::mlir::Attribute">
+  : AttrDef<TritonGPU_Dialect, name, traits, baseCppClass>;

 def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
   let mnemonic = "shared_layout";
@@ -104,7 +105,8 @@ And the associated TritonGPU MLIR
     ArrayRefParameter<
       "unsigned",
       "order of axes by the rate of changing"
-    >:$order
+    >:$order,
+    ArrayRefParameter<"unsigned">:$broadcastAxis
     // "AffineMap":$threadOrdering,
     // "AffineMap":warpOrdering,
     // "AffineMap":$blockOrdering,
@@ -114,6 +116,28 @@ And the associated TritonGPU MLIR
   // let genVerifyDecl = 1;
 }

+def TritonGPUBlockedMulticastEncodingAttr
+    : TritonGPU_Attr<"TritonGPUBlockedMulticastEncoding"> {
+  let mnemonic = "blocked_multicast_layout";
+
+  let description = [{
+    to be broadcasted to blocked_layout
+  }];
+
+  // This needs to be synced with BlockedEncoding
+  let parameters = (
+    ins
+    ArrayRefParameter<"unsigned">:$threadTileSize,
+    ArrayRefParameter<"unsigned">:$warpTileSize,
+    ArrayRefParameter<"unsigned">:$blockTileSize,
+    ArrayRefParameter<"unsigned">:$order,
+    // unique to broadcasted layout
+    ArrayRefParameter<"unsigned">:$broadcastAxis
+  );
+
+  // let genVerifyDecl = 1;
+}
+
 def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
   let mnemonic = "mma_layout";

@@ -131,7 +155,8 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
     ArrayRefParameter<"unsigned">:$shapePerTile,
     // TODO: should Distributed layout also
     ArrayRefParameter<"unsigned">:$repetitions,
-    ArrayRefParameter<"unsigned">:$contigPerThread
+    ArrayRefParameter<"unsigned">:$contigPerThread,
+    ArrayRefParameter<"unsigned">:$broadcastAxis
     // "AffineMap":$warpOrdering,
     // "AffineMap":$blockOrdering
   );
@@ -139,4 +164,28 @@ def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
   // let genVerifyDecl = 1;
 }

+def TritonGPUMmaMulticastEncodingAttr
+    : TritonGPU_Attr<"TritonGPUMmaMulticastEncoding"> {
+  let mnemonic = "mma_multicast_layout";
+
+  let description = [{
+    To be broadcasted to mma.
+  }];
+
+  // This needs to be synced with MmaEncoding
+  let parameters = (
+    ins
+    ArrayRefParameter<"unsigned">:$fragmentPerWarp,
+    ArrayRefParameter<"unsigned">:$shapePerWarp,
+    ArrayRefParameter<"unsigned">:$warpPerTile,
+    ArrayRefParameter<"unsigned">:$shapePerTile,
+    ArrayRefParameter<"unsigned">:$repetitions,
+    ArrayRefParameter<"unsigned">:$contigPerThread,
+    // unique to broadcasted layout
+    ArrayRefParameter<"unsigned">:$broadcastAxis
+  );
+
+  // let genVerifyDecl = 1;
+}
+
 #endif
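For orientation, a minimal C++ sketch of how one of the new multicast encodings could be built and attached to a tensor type. The parameter order mirrors the ODS declarations above and the `::get(...)` call used later in refineLayouts; the function name, include paths, f32 element type, and concrete sizes are illustrative assumptions, not part of this commit.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Sketch only: assumes the TableGen-generated ::get matches the declared
// parameter order (threadTileSize, warpTileSize, blockTileSize, order,
// broadcastAxis), which is how refineLayouts() calls it below.
mlir::RankedTensorType makeMulticastTensorType(mlir::MLIRContext *ctx) {
  llvm::SmallVector<unsigned> threadTileSize = {1};
  llvm::SmallVector<unsigned> warpTileSize = {32};
  llvm::SmallVector<unsigned> blockTileSize = {128};
  llvm::SmallVector<unsigned> order = {0};
  llvm::SmallVector<unsigned> broadcastAxis = {0}; // axis 0 is replicated
  auto enc = mlir::triton::gpu::TritonGPUBlockedMulticastEncodingAttr::get(
      ctx, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
  // A 256-element f32 tensor carrying the multicast layout (example shape).
  return mlir::RankedTensorType::get({256}, mlir::FloatType::getF32(ctx), enc);
}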
@@ -23,6 +23,9 @@ class TritonGPUConversionTarget : public ConversionTarget {
   TritonGPUTypeConverter &typeConverter;
 public:
   explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
+
+  /// update layouts & insert ConvertLayoutOp if necessary
+  LogicalResult refineLayouts(ModuleOp mod, int numThreads);
 };

 } // namespace mlir
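The intended flow is that a conversion pass first runs the usual dialect conversion and then invokes refineLayouts as a second phase, as the pass change below does. A hedged sketch of that glue code, assuming a TritonGPUTypeConverter constructor taking the context and thread count (the constructor arguments and pattern population are guesses; only refineLayouts is introduced by this commit):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"

mlir::LogicalResult convertAndRefine(mlir::ModuleOp mod, mlir::MLIRContext &ctx,
                                     int numThreads) {
  // Assumed signature: the numThreads parameter is inferred from the
  // TritonGPUTypeConverter hunks later in this commit.
  mlir::TritonGPUTypeConverter typeConverter(&ctx, numThreads);
  mlir::TritonGPUConversionTarget target(ctx, typeConverter);

  mlir::RewritePatternSet patterns(&ctx);
  // ... populate the Triton -> TritonGPU conversion patterns here ...
  if (mlir::failed(mlir::applyPartialConversion(mod, target, std::move(patterns))))
    return mlir::failure();

  // Second phase introduced by this commit: broadcast src => multicast
  // layout, broadcast dst => blocked layout with broadcastAxis recorded.
  return target.refineLayouts(mod, numThreads);
}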
@@ -342,9 +342,13 @@ public:
     // mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
     populateSCFPatterns(typeConverter, patterns);

-    if(failed(applyPartialConversion(mod, target,
-                                     std::move(patterns))))
+    if(failed(applyPartialConversion(mod, target, std::move(patterns))))
       return signalPassFailure();
+
+    // update layouts
+    // broadcast src => multicast, dst => broadcasted
+    if(failed(target.refineLayouts(mod, numWarps)))
+      return signalPassFailure();
   }
 };

@@ -10,7 +10,7 @@ using namespace mlir::triton::gpu;
 // parse an array of integers
 static LogicalResult parseIntArrayAttr(AsmParser &parser,
                                        const NamedAttribute &attr,
-                                       SmallVector<unsigned, 2> &res,
+                                       /*SmallVector<unsigned, 2>*/auto &res,
                                        StringRef desc) {
   auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
   if (!arrayAttr) {
@@ -36,8 +36,7 @@ static LogicalResult parseIntArrayAttr(AsmParser &parser,
 #define GET_ATTRDEF_CLASSES
 #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"

-Attribute
-TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
+static Attribute parseBlocked(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   // Parse the data as a dictionary
@@ -51,28 +50,7 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
   SmallVector<unsigned, 2> warpTileSize;
   SmallVector<unsigned, 2> blockTileSize;
   SmallVector<unsigned, 2> order;
-
-  // parse an array of integers
-  // auto parseIntArrayAttr = [&parser](const NamedAttribute &attr,
-  //                                    SmallVector<unsigned, 2> &res,
-  //                                    StringRef desc) -> LogicalResult {
-  //   auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
-  //   if (!arrayAttr) {
-  //     parser.emitError(parser.getNameLoc(), "expected an array for ")
-  //       << desc;
-  //     return failure();
-  //   }
-  //   for (Attribute i : arrayAttr) {
-  //     auto intAttr = i.dyn_cast<IntegerAttr>();
-  //     if (!intAttr) {
-  //       parser.emitError(parser.getNameLoc(), "expected an integer value in ")
-  //         << desc;
-  //       return failure();
-  //     }
-  //     res.push_back(intAttr.getUInt());
-  //   }
-  //   return success();
-  // };
+  SmallVector<unsigned, 2> broadcastAxis;

   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "threadTileSize") {
@@ -98,20 +76,39 @@ TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
       threadTileSize,
       warpTileSize,
       blockTileSize,
-      order);
+      order,
+      broadcastAxis);
 }

-void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
+static void printBlocked(AsmPrinter &printer, auto *attr) {
   printer << "<{"
-          << "threadTileSize = [" << getThreadTileSize() << "]"
-          << ", warpTileSize = [" << getWarpTileSize() << "]"
-          << ", blockTileSize = [" << getBlockTileSize() << "]"
-          << ", order = [" << getOrder() << "]"
+          << "threadTileSize = [" << attr->getThreadTileSize() << "]"
+          << ", warpTileSize = [" << attr->getWarpTileSize() << "]"
+          << ", blockTileSize = [" << attr->getBlockTileSize() << "]"
+          << ", order = [" << attr->getOrder() << "]"
+          << ", broadcastAxis = [" << attr->getBroadcastAxis() << "]"
           << "}>";
 }

 Attribute
-TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
+TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
+  parseBlocked(parser, type);
+}
+
+void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
+  printBlocked(printer, this);
+}
+
+Attribute
+TritonGPUBlockedMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
+  parseBlocked(parser, type);
+}
+
+void TritonGPUBlockedMulticastEncodingAttr::print(AsmPrinter &printer) const {
+  printBlocked(printer, this);
+}
+
+static Attribute parseMma(AsmParser &parser, Type type) {
   if (parser.parseLess().failed())
     return {};
   DictionaryAttr dict;
@@ -126,6 +123,7 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
   SmallVector<unsigned, 2> shapePerTile;
   SmallVector<unsigned, 2> repetitions;
   SmallVector<unsigned, 2> contigPerThread;
+  SmallVector<unsigned, 2> broadcastAxis;

   for (const NamedAttribute &attr : dict) {
     if (attr.getName() == "fragmentPerWarp") {
@@ -159,18 +157,37 @@ TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
       warpPerTile,
       shapePerTile,
       repetitions,
-      contigPerThread);
+      contigPerThread,
+      broadcastAxis);
 }

+static void printMma(AsmPrinter &printer, auto *attr) {
+  printer << "<{"
+          << "fragmentPerWarp = [" << attr->getFragmentPerWarp() << "]"
+          << ", shapePerWarp = [" << attr->getShapePerWarp() << "]"
+          << ", warpPerTile = [" << attr->getWarpPerTile() << "]"
+          << ", shapePerTile = [" << attr->getShapePerTile() << "]"
+          << ", repetitions = [" << attr->getRepetitions() << "]"
+          << ", contigPerThread = [" << attr->getContigPerThread() << "]"
+          << "}>";
+}
+
+Attribute
+TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
+  return parseMma(parser, type);
+}
+
 void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
-  printer << "<{"
-          << "fragmentPerWarp = [" << getFragmentPerWarp() << "]"
-          << ", shapePerWarp = [" << getShapePerWarp() << "]"
-          << ", warpPerTile = [" << getWarpPerTile() << "]"
-          << ", shapePerTile = [" << getShapePerTile() << "]"
-          << ", repetitions = [" << getRepetitions() << "]"
-          << ", contigPerThread = [" << getContigPerThread() << "]"
-          << "}>";
+  printMma(printer, this);
 }

+Attribute
+TritonGPUMmaMulticastEncodingAttr::parse(AsmParser &parser, Type type) {
+  return parseMma(parser, type);
+}
+
+void TritonGPUMmaMulticastEncodingAttr::print(AsmPrinter &printer) const {
+  printMma(printer, this);
+}
+
 Attribute
@@ -1,9 +1,11 @@
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include <algorithm>

 using namespace mlir;
+using namespace mlir::triton::gpu;

 //
 // TypeConverter
@@ -41,6 +43,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
     llvm::SmallVector<unsigned> warpTileSize(rank, 1);
     llvm::SmallVector<unsigned> blockTileSize(rank);
     llvm::SmallVector<unsigned> order(rank);
+    llvm::SmallVector<unsigned> broadcastAxis;
     int remainingThreads = numThreads;
     int remainingLanes = /*warp size*/32;
     for (int64_t dim = 0; dim < rank; ++dim) {
@@ -53,7 +56,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
       // TODO: will we need repetition?
     }
     Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
-        context, threadTileSize, warpTileSize, blockTileSize, order);
+        context, threadTileSize, warpTileSize, blockTileSize, order, broadcastAxis);
     return RankedTensorType::get(shape, elementType, encoding);
   });

@@ -81,7 +84,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // NOTE: only for remapped values.
   addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) {
-    assert(inputs.size() == 1);
     llvm_unreachable("Not implemented");
     return llvm::None;
   });
@@ -124,3 +126,98 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
   });

 }
+
+// %dst = tt.broadcast %src
+// =>
+// %newSrc = convert_layout %src
+// %bcst = tt.broadcast %newSrc
+// %dst = convert_layout %bcst
+LogicalResult TritonGPUConversionTarget::refineLayouts(ModuleOp mod,
+                                                       int numThreads) {
+  // collect broadcasts
+  SmallVector<triton::BroadcastOp> broadcasts;
+  mod.walk([&](triton::BroadcastOp op) {
+    broadcasts.push_back(op);
+  });
+
+  BlockAndValueMapping mapping;
+  for (auto broadcast : broadcasts) {
+    OpBuilder builder(broadcast);
+    Value src = mapping.lookupOrDefault(broadcast.src());
+    Type originSrcType = src.getType();
+    Type originDstType = broadcast.getType();
+    auto originDstTensorType = originDstType.dyn_cast<RankedTensorType>();
+    unsigned dstRank = originDstTensorType.getRank();
+
+    // compute newSrcType & broadcastAxis
+    Type newSrcType;
+    SmallVector<unsigned> broadcastAxis;
+    bool isSrcScalar = false;
+    if (auto tensorType = originSrcType.dyn_cast<RankedTensorType>()) {
+      assert(tensorType.getRank() == dstRank &&
+             "src & dst should have same rank (verifier should catch this)");
+      for (unsigned ax = 0; ax < dstRank; ++ax)
+        if (tensorType.getShape()[ax] < originDstTensorType.getShape()[ax])
+          broadcastAxis.push_back(ax);
+
+      Attribute originSrcEnc = tensorType.getEncoding();
+      if (auto blockedEnc = originSrcEnc.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
+        auto newSrcEnc = TritonGPUBlockedMulticastEncodingAttr::get(
+          blockedEnc.getContext(),
+          blockedEnc.getThreadTileSize(),
+          blockedEnc.getWarpTileSize(),
+          blockedEnc.getBlockTileSize(),
+          blockedEnc.getOrder(),
+          broadcastAxis
+        );
+        newSrcType = RankedTensorType::get(
+          tensorType.getShape(),
+          tensorType.getElementType(),
+          newSrcEnc
+        );
+      } else
+        llvm_unreachable("src of broadcast should have blocked encoding");
+    } else {
+      for (unsigned ax = 0; ax < dstRank; ++ax)
+        broadcastAxis.push_back(ax);
+      newSrcType = originSrcType;
+      isSrcScalar = true;
+    }
+
+    // create new src
+    if (!isSrcScalar) // we don't need to convert layout for scalar values
+      src = builder.create<triton::gpu::ConvertLayoutOp>(
+        src.getLoc(), newSrcType, src
+      );
+
+    // create new broadcast
+    // compute new type (encoding)
+    auto originDstEnc = originDstTensorType.getEncoding()
+                            .dyn_cast<TritonGPUBlockedEncodingAttr>();
+    auto newEnc = TritonGPUBlockedEncodingAttr::get(
+      originDstEnc.getContext(),
+      originDstEnc.getThreadTileSize(),
+      originDstEnc.getWarpTileSize(),
+      originDstEnc.getBlockTileSize(),
+      originDstEnc.getOrder(),
+      broadcastAxis
+    );
+    auto newType = RankedTensorType::get(
+      originDstTensorType.getShape(),
+      originDstTensorType.getElementType(),
+      newEnc
+    );
+    Value newBroadcast = builder.create<triton::BroadcastOp>(
+      broadcast.getLoc(), newType, src
+    );
+    // we don't want to change the encoding of the result
+    Value newDst = builder.create<triton::gpu::ConvertLayoutOp>(
+      broadcast.getLoc(), originDstType, newBroadcast
+    );
+
+    broadcast.replaceAllUsesWith(newDst);
+    mapping.map(broadcast, newDst);
+  }
+
+  return success();
+}
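As a standalone illustration of the axis computation performed in the loop above: an axis counts as a broadcast axis whenever the source extent is strictly smaller than the destination extent, and a scalar source broadcasts along every destination axis. The helper name and the example shapes below are made up for the sketch; only the loop body mirrors the commit.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Illustration only: mirrors the broadcastAxis loop in refineLayouts above.
llvm::SmallVector<unsigned>
computeBroadcastAxis(llvm::ArrayRef<int64_t> srcShape,
                     llvm::ArrayRef<int64_t> dstShape) {
  // Assumes equal ranks, as the assert in refineLayouts does.
  llvm::SmallVector<unsigned> broadcastAxis;
  for (unsigned ax = 0; ax < dstShape.size(); ++ax)
    if (srcShape[ax] < dstShape[ax])
      broadcastAxis.push_back(ax);
  return broadcastAxis;
}

// e.g. src tensor<1x128>, dst tensor<32x128>  ->  broadcastAxis = [0]
// A scalar src has no tensor shape, so every dst axis becomes a broadcast axis.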
@@ -45,32 +45,38 @@ module {
     %c32 = arith.constant 32 : index
     %c0 = arith.constant 0 : index
     %c256_i32 = arith.constant 256 : i32
-    %cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
     %0 = tt.get_program_id {axis = 0 : i32} : i32
     %1 = arith.muli %0, %c256_i32 : i32
-    %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %4 = arith.addi %3, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %5 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %6 = "triton_gpu.cmpi"(%4, %5) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %7 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %8 = tt.getelementptr %7, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %9 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %10 = tt.getelementptr %9, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %11 = arith.index_cast %arg4 : i32 to index
-    %12:3 = scf.for %arg6 = %c0 to %11 step %c32 iter_args(%arg7 = %cst, %arg8 = %8, %arg9 = %10) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>) {
-      %15 = tt.load %arg8, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %16 = tt.load %arg9, %6, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %17 = arith.addf %15, %16 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %18 = arith.addf %arg7, %17 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %19 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %20 = tt.getelementptr %arg8, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      %21 = tt.getelementptr %arg9, %19 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-      scf.yield %18, %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %3 = tt.broadcast %1 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %4 = triton_gpu.convert_layout %3 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %5 = arith.addi %4, %2 : tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %6 = tt.broadcast %arg3 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %7 = triton_gpu.convert_layout %6 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %8 = "triton_gpu.cmpi"(%5, %7) {predicate = 2 : i64} : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) -> tensor<256xi1, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %9 = tt.broadcast %arg0 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %10 = triton_gpu.convert_layout %9 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %11 = tt.getelementptr %10, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %12 = tt.broadcast %arg1 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %13 = triton_gpu.convert_layout %12 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %14 = tt.getelementptr %13, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %15 = arith.index_cast %arg4 : i32 to index
+    %16:3 = scf.for %arg6 = %c0 to %15 step %c32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %14) -> (tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>) {
+      %20 = tt.load %arg8, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %21 = tt.load %arg9, %8, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %22 = arith.addf %20, %21 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %23 = arith.addf %arg7, %22 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %24 = tt.broadcast %arg5 : (i32) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+      %25 = triton_gpu.convert_layout %24 : (tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256xi32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %26 = tt.getelementptr %arg8, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      %27 = tt.getelementptr %arg9, %25 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+      scf.yield %23, %26, %27 : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>, tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
     }
-    %13 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    %14 = tt.getelementptr %13, %4 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
-    tt.store %14, %12#0, %6, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0]}>>
+    %17 = tt.broadcast %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>
+    %18 = triton_gpu.convert_layout %17 : (tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = [0]}>>) -> tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    %19 = tt.getelementptr %18, %5 : tensor<256x!tt.ptr<f32>, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
+    tt.store %19, %16#0, %8, : tensor<256xf32, #triton_gpu.blocked_layout<{threadTileSize = [1], warpTileSize = [32], blockTileSize = [128], order = [0], broadcastAxis = []}>>
     return
   }
 }