[OPTIMIZER] Rewrite patterns for layout conversions (#64)
This commit is contained in:
@@ -225,7 +225,7 @@ def TT_DotOp : TT_Op<"dot", [NoSideEffect,
|
||||
// let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TT_ReduceOp : TT_Op<"reduce"> {
|
||||
def TT_ReduceOp : TT_Op<"reduce", [NoSideEffect]> {
|
||||
let summary = "reduce";
|
||||
|
||||
let arguments = (ins TT_RedOpAttr:$redOp, TT_Tensor:$operand, I32Attr:$axis);
|
||||
|
@@ -37,7 +37,7 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
|
||||
// Shared Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
|
||||
def SharedEncodingAttr : TritonGPU_Attr<"SharedEncoding"> {
|
||||
let mnemonic = "shared";
|
||||
|
||||
let description = [{
|
||||
@@ -70,9 +70,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
||||
// Distributed Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
|
||||
let mnemonic = "distributed";
|
||||
|
||||
class DistributedEncoding<string name> : TritonGPU_Attr<name> {
|
||||
let description = [{
|
||||
Distributed encodings have a layout function that is entirely characterized
|
||||
by a d-dimensional tensor L. Note that L doesn't need to have the same shape
|
||||
@@ -97,12 +95,11 @@ L(A) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Blocked Layout Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TritonGPUBlockedEncodingAttr : TritonGPU_Attr<"TritonGPUBlockedEncoding"> {
|
||||
def BlockedEncodingAttr : DistributedEncoding<"BlockedEncoding"> {
|
||||
let mnemonic = "blocked";
|
||||
|
||||
let description = [{
|
||||
@@ -174,6 +171,10 @@ for
|
||||
}]>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
SliceEncodingAttr squeeze(int axis);
|
||||
}];
|
||||
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
@@ -197,7 +198,7 @@ for
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TODO: MMAv1 and MMAv2 should be two instances of the same class
|
||||
|
||||
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
|
||||
def MmaEncodingAttr : DistributedEncoding<"MmaEncoding"> {
|
||||
let mnemonic = "mma";
|
||||
|
||||
let description = [{
|
||||
@@ -283,5 +284,34 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
|
||||
);
|
||||
}
|
||||
|
||||
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
|
||||
let mnemonic = "slice";
|
||||
|
||||
let description = [{
|
||||
TODO: improve docs
|
||||
|
||||
A = [x x x x x x x x]
|
||||
[x x x x x x x x]
|
||||
L_parent = [0 1 2 3 ]
|
||||
[4 5 6 7 ]
|
||||
[8 9 10 11]
|
||||
[12 13 14 15]
|
||||
dim = 0
|
||||
|
||||
Then the data of A would be distributed as follow between the 16 CUDA threads:
|
||||
L(A) = [ {0,4,8,12} , {1,5,9,13} , ... {3,7,11,15} ]
|
||||
|
||||
This is useful for constructing the inverse layout of an expand_dims operation during some optimization passes.
|
||||
|
||||
}];
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
"unsigned":$dim,
|
||||
// TODO: constraint here to only take distributed encodings
|
||||
"Attribute":$parent
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
@@ -20,7 +20,7 @@ class TTG_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<TritonGPU_Dialect, mnemonic, traits>;
|
||||
|
||||
def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
|
||||
[NoSideEffect, SameOperandsAndResultType]> {
|
||||
[NoSideEffect]> {
|
||||
let summary = "convert layout";
|
||||
|
||||
let arguments = (ins TT_Tensor:$src);
|
||||
@@ -65,7 +65,7 @@ def TTG_CopyAsyncOp : TTG_Op<"copy_async",
|
||||
// This is needed because Arith's Cmp ops don't
|
||||
// handle encodings
|
||||
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td#L111
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi"> {
|
||||
def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect]> {
|
||||
let summary = "integer comparison operation";
|
||||
|
||||
let description = [{}];
|
||||
|
@@ -6,6 +6,8 @@
|
||||
namespace mlir {
|
||||
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCanonicalizeLoopsPass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCoalescePass();
|
||||
|
||||
std::unique_ptr<Pass> createTritonGPUCombineOpsPass();
|
||||
|
@@ -60,6 +60,19 @@ def TritonGPUCombineOps : Pass<"tritongpu-combine", "mlir::ModuleOp"> {
|
||||
"mlir::triton::TritonDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUCanonicalizeLoops: Pass<"tritongpu-canonicalize-loops", "mlir::ModuleOp"> {
|
||||
let summary = "canonicalize scf.ForOp ops";
|
||||
|
||||
let description = [{
|
||||
This implements some optimizations that are missing in the standard scf.ForOp
|
||||
canonicalizer.
|
||||
}];
|
||||
|
||||
let constructor = "mlir::createTritonGPUCanonicalizeLoopsPass()";
|
||||
|
||||
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
|
||||
}
|
||||
|
||||
def TritonGPUVerifier : Pass<"tritongpu-verifier", "mlir::ModuleOp"> {
|
||||
let summary = "verify TritonGPU IR";
|
||||
|
||||
|
@@ -43,8 +43,7 @@ private:
|
||||
auto type = result.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
auto encoding = tensorType.getEncoding();
|
||||
if (encoding &&
|
||||
encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
if (encoding && encoding.isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
// Bytes could be a different value once we support padding or other
|
||||
// allocation policies.
|
||||
auto bytes = tensorType.getNumElements() *
|
||||
|
@@ -23,9 +23,9 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
using ::mlir::triton::gpu::TritonGPUBlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::TritonGPUMmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::TritonGPUSharedEncodingAttr;
|
||||
using ::mlir::triton::gpu::BlockedEncodingAttr;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
namespace mlir {
|
||||
namespace LLVM {
|
||||
@@ -226,7 +226,7 @@ static int64_t getLinearIndex(std::vector<int64_t> multidim_index,
|
||||
return linear_index;
|
||||
}
|
||||
|
||||
static unsigned getElemsPerThread(TritonGPUBlockedEncodingAttr layout,
|
||||
static unsigned getElemsPerThread(BlockedEncodingAttr layout,
|
||||
ArrayRef<int64_t> shape) {
|
||||
size_t rank = shape.size();
|
||||
SmallVector<unsigned> elemsPerThreadPerDim(rank);
|
||||
@@ -368,10 +368,10 @@ public:
|
||||
// Emit indices calculation within each ConversionPattern
|
||||
// TODO: [goostavz] Double confirm the redundant indices calculations will
|
||||
// be eliminated in the consequent MLIR/LLVM optimization
|
||||
SmallVector<SmallVector<Value>> emitIndicesForBlockedLayout(
|
||||
Location loc, ConversionPatternRewriter &b,
|
||||
const TritonGPUBlockedEncodingAttr &blocked_layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<SmallVector<Value>>
|
||||
emitIndicesForBlockedLayout(Location loc, ConversionPatternRewriter &b,
|
||||
const BlockedEncodingAttr &blocked_layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
auto llvmIndexTy = this->getTypeConverter()->getIndexType();
|
||||
auto cast = b.create<UnrealizedConversionCastOp>(
|
||||
loc, TypeRange{llvmIndexTy},
|
||||
@@ -483,7 +483,7 @@ Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
|
||||
auto tensorTy = resType.cast<RankedTensorType>();
|
||||
auto layout = tensorTy.getEncoding().cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto layout = tensorTy.getEncoding().cast<BlockedEncodingAttr>();
|
||||
auto srcType = typeConverter->convertType(elemType);
|
||||
auto llSrc = rewriter.create<LLVM::BitcastOp>(loc, srcType, constVal);
|
||||
|
||||
@@ -594,9 +594,9 @@ struct StoreOpConversion
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
auto getLLVMElems = [&](Value value, Value llValue,
|
||||
const TritonGPUBlockedEncodingAttr &layout)
|
||||
-> SmallVector<Value, 4> {
|
||||
auto getLLVMElems =
|
||||
[&](Value value, Value llValue,
|
||||
const BlockedEncodingAttr &layout) -> SmallVector<Value, 4> {
|
||||
auto ty = value.getType().cast<RankedTensorType>();
|
||||
auto shape = ty.getShape();
|
||||
// Here, we assume that all inputs should have a blockedLayout
|
||||
@@ -613,11 +613,11 @@ struct StoreOpConversion
|
||||
};
|
||||
|
||||
auto getLayout =
|
||||
[&](Value val) -> std::tuple<TritonGPUBlockedEncodingAttr, unsigned> {
|
||||
[&](Value val) -> std::tuple<BlockedEncodingAttr, unsigned> {
|
||||
auto ty = val.getType().cast<RankedTensorType>();
|
||||
auto shape = ty.getShape();
|
||||
// Here, we assume that all inputs should have a blockedLayout
|
||||
auto layout = ty.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto layout = ty.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
|
||||
unsigned valueElems = getElemsPerThread(layout, shape);
|
||||
|
||||
@@ -633,9 +633,8 @@ struct StoreOpConversion
|
||||
auto maskElems = getLLVMElems(mask, llMask, maskLayout);
|
||||
assert(valueElems.size() == maskElems.size());
|
||||
|
||||
auto getAlign =
|
||||
[this](Value val,
|
||||
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
|
||||
auto getAlign = [this](Value val,
|
||||
const BlockedEncodingAttr &layout) -> unsigned {
|
||||
auto axisInfo = getAxisInfo(val);
|
||||
assert(axisInfo.hasValue());
|
||||
|
||||
@@ -648,9 +647,9 @@ struct StoreOpConversion
|
||||
};
|
||||
|
||||
// get align
|
||||
auto getVec = [this, &getAlign](
|
||||
Value val,
|
||||
const TritonGPUBlockedEncodingAttr &layout) -> unsigned {
|
||||
auto getVec = [this,
|
||||
&getAlign](Value val,
|
||||
const BlockedEncodingAttr &layout) -> unsigned {
|
||||
auto axisInfo = getAxisInfo(val);
|
||||
auto contig = axisInfo->getContiguity();
|
||||
// Here order should be ordered by contiguous first, so the first element
|
||||
@@ -820,10 +819,8 @@ struct BroadcastOpConversion
|
||||
Value result = op.result();
|
||||
auto srcTy = op.src().getType().cast<RankedTensorType>();
|
||||
auto resultTy = result.getType().cast<RankedTensorType>();
|
||||
auto srcLayout =
|
||||
srcTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto srcLayout = srcTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
assert(srcLayout && (srcLayout == resultLayout) &&
|
||||
"Unexpected layout of BroadcastOp");
|
||||
auto srcShape = srcTy.getShape();
|
||||
@@ -894,8 +891,7 @@ struct ViewOpConversion
|
||||
// due to MLIR's restrictions
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().cast<RankedTensorType>();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultLayout, resultShape);
|
||||
Type elemTy =
|
||||
@@ -921,7 +917,7 @@ struct MakeRangeOpConversion
|
||||
auto rankedTy = op.result().getType().dyn_cast<RankedTensorType>();
|
||||
auto shape = rankedTy.getShape();
|
||||
auto blocked_layout =
|
||||
rankedTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
rankedTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto elemTy = rankedTy.getElementType();
|
||||
assert(elemTy.isInteger(32));
|
||||
Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.start());
|
||||
@@ -955,8 +951,7 @@ struct LoadOpConversion
|
||||
Value mask = adaptor.mask();
|
||||
Value other = adaptor.other();
|
||||
auto resultTy = op.result().getType().cast<RankedTensorType>();
|
||||
auto blockedLayout =
|
||||
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto blockedLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto shape = resultTy.getShape();
|
||||
|
||||
// TODO: Handle AxisInfo
|
||||
@@ -1166,8 +1161,7 @@ struct GEPOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::GEPOp> {
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto resultLayout = resultTy.getEncoding().dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultLayout, resultShape);
|
||||
Type elemTy =
|
||||
@@ -1206,8 +1200,8 @@ public:
|
||||
return failure();
|
||||
|
||||
Location loc = op->getLoc();
|
||||
auto resultLayout = resultTy.getEncoding()
|
||||
.template dyn_cast<TritonGPUBlockedEncodingAttr>();
|
||||
auto resultLayout =
|
||||
resultTy.getEncoding().template dyn_cast<BlockedEncodingAttr>();
|
||||
auto resultShape = resultTy.getShape();
|
||||
unsigned elems = getElemsPerThread(resultLayout, resultShape);
|
||||
Type elemTy =
|
||||
@@ -1250,17 +1244,16 @@ public:
|
||||
|
||||
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
|
||||
Attribute layout = type.getEncoding();
|
||||
if (auto blocked_layout = layout.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
if (auto blocked_layout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
unsigned numElementsPerThread =
|
||||
getElemsPerThread(blocked_layout, type.getShape());
|
||||
SmallVector<Type, 4> types(numElementsPerThread,
|
||||
convertType(type.getElementType()));
|
||||
return LLVM::LLVMStructType::getLiteral(&getContext(), types);
|
||||
} else if (auto mma_layout = layout.dyn_cast<TritonGPUMmaEncodingAttr>()) {
|
||||
} else if (auto mma_layout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
// TODO: Not implemented
|
||||
return llvm::None;
|
||||
} else if (auto shared_layout =
|
||||
layout.dyn_cast<TritonGPUSharedEncodingAttr>()) {
|
||||
} else if (auto shared_layout = layout.dyn_cast<SharedEncodingAttr>()) {
|
||||
// TODO: Not implemented
|
||||
return llvm::None;
|
||||
}
|
||||
|
@@ -156,8 +156,7 @@ struct TritonExpandDimsPattern
|
||||
Attribute _argEncoding = argType.getEncoding();
|
||||
if (!_argEncoding)
|
||||
return failure();
|
||||
auto argEncoding =
|
||||
_argEncoding.cast<triton::gpu::TritonGPUBlockedEncodingAttr>();
|
||||
auto argEncoding = _argEncoding.cast<triton::gpu::BlockedEncodingAttr>();
|
||||
// return shape
|
||||
auto retShape = argType.getShape().vec();
|
||||
retShape.insert(retShape.begin() + op.axis(), 1);
|
||||
@@ -170,10 +169,10 @@ struct TritonExpandDimsPattern
|
||||
retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.axis(), 1);
|
||||
SmallVector<unsigned, 4> retOrder(retShape.size());
|
||||
std::iota(retOrder.begin(), retOrder.end(), 0);
|
||||
triton::gpu::TritonGPUBlockedEncodingAttr retEncoding =
|
||||
triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
triton::gpu::BlockedEncodingAttr retEncoding =
|
||||
triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
|
||||
retThreadsPerWarp, retWarpsPerCTA,
|
||||
retOrder);
|
||||
// return type
|
||||
RankedTensorType retType =
|
||||
RankedTensorType::get(retShape, argType.getElementType(), retEncoding);
|
||||
@@ -201,16 +200,16 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
Value a = adaptor.a();
|
||||
Value b = adaptor.b();
|
||||
SmallVector<unsigned, 2> order{1, 0};
|
||||
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
if (!aEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
Attribute encoding =
|
||||
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(),
|
||||
aType.getElementType(), encoding);
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
}
|
||||
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
if (!bEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
Attribute encoding =
|
||||
triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(),
|
||||
bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
|
@@ -1,6 +1,8 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||
@@ -70,11 +72,15 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||
|
||||
SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
|
||||
return SliceEncodingAttr::get(getContext(), axis, *this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Blocked Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
@@ -115,11 +121,11 @@ Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
|
||||
return parser.getChecked<BlockedEncodingAttr>(
|
||||
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||
}
|
||||
|
||||
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "sizePerThread = [" << getSizePerThread() << "]"
|
||||
<< ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
|
||||
@@ -132,7 +138,7 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
// MMA encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
DictionaryAttr dict;
|
||||
@@ -155,22 +161,59 @@ Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
|
||||
version, warpsPerCTA);
|
||||
return parser.getChecked<MmaEncodingAttr>(parser.getContext(), version,
|
||||
warpsPerCTA);
|
||||
}
|
||||
|
||||
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
void MmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "version = " << getVersion() << ", "
|
||||
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sliced Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
DictionaryAttr dict;
|
||||
if (parser.parseAttribute(dict).failed())
|
||||
return {};
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
unsigned dim = 0;
|
||||
Attribute parent;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "dim") {
|
||||
if (parseUInt(parser, attr, dim, "dim").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "parent") {
|
||||
if (parser.parseAttribute(parent).failed())
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
|
||||
}
|
||||
|
||||
void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "dim = " << getDim() << ", "
|
||||
<< "parent = " << getParent() << "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
@@ -205,11 +248,11 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUSharedEncodingAttr>(
|
||||
parser.getContext(), vec, perPhase, maxPhase, order);
|
||||
return parser.getChecked<SharedEncodingAttr>(parser.getContext(), vec,
|
||||
perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
void SharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "vec = " << getVec() << ", perPhase = " << getPerPhase()
|
||||
<< ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
|
||||
@@ -226,18 +269,21 @@ public:
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
|
||||
if (auto mmaAttr = attr.dyn_cast<MmaEncodingAttr>()) {
|
||||
os << "mma";
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
|
||||
} else if (auto sharedAttr = attr.dyn_cast<SharedEncodingAttr>()) {
|
||||
os << "shared";
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto blockedAttr =
|
||||
attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
} else if (auto blockedAttr = attr.dyn_cast<BlockedEncodingAttr>()) {
|
||||
os << "blocked";
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
return OpAsmDialectInterface::getAlias(attr, os);
|
||||
} /* else if (auto sliceAttr = attr.dyn_cast<SliceEncodingAttr>()) {
|
||||
os << "slice";
|
||||
return AliasResult::FinalAlias;
|
||||
} */
|
||||
OpAsmDialectInterface::getAlias(attr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -283,11 +329,15 @@ static Type getPointeeType(Type type) {
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Verification
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(CopyAsyncOp op) {
|
||||
Type resType = op.getResult().getType();
|
||||
if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
if (!encoding.isa<TritonGPUSharedEncodingAttr>())
|
||||
if (!encoding.isa<SharedEncodingAttr>())
|
||||
return op.emitOpError("copy_async should return a shared memory tensor");
|
||||
} else
|
||||
return op.emitOpError("copy_async should return a tensor");
|
||||
@@ -302,4 +352,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
// TODO: fill this.
|
||||
return success();
|
||||
}
|
||||
}
|
@@ -4,6 +4,7 @@ add_public_tablegen_target(TritonGPUCombineIncGen)
|
||||
|
||||
add_mlir_dialect_library(TritonGPUTransforms
|
||||
Coalesce.cpp
|
||||
CanonicalizeLoops.cpp
|
||||
Combine.cpp
|
||||
Pipeline.cpp
|
||||
Verifier.cpp
|
||||
|
55
lib/Dialect/TritonGPU/Transforms/CanonicalizeLoops.cpp
Normal file
55
lib/Dialect/TritonGPU/Transforms/CanonicalizeLoops.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
|
||||
struct CanonicalizePass
|
||||
: public TritonGPUCanonicalizeLoopsBase<CanonicalizePass> {
|
||||
CanonicalizePass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
|
||||
// Canonicalize pass may have created dead code that
|
||||
// standard scf.for canonicalization cannot handle
|
||||
// as of LLVM 14. For example, the iteration arguments
|
||||
// for the pointer of the synchronous loads that are
|
||||
// discarded.
|
||||
// The following piece of code is a workaround to
|
||||
// very crudely remove dead code, by making an iteration
|
||||
// argument yield itself if it is not used to create
|
||||
// side-effects anywhere.
|
||||
getOperation()->walk([&](scf::ForOp forOp) -> void {
|
||||
for (size_t i = 0; i < forOp.getNumResults(); ++i) {
|
||||
// condition 1: no other iter arguments depend on it
|
||||
SetVector<Operation *> fwdSlice;
|
||||
mlir::getForwardSlice(forOp.getRegionIterArgs()[i], &fwdSlice);
|
||||
Operation *yieldOp = forOp.getBody()->getTerminator();
|
||||
bool noOtherDependency = std::all_of(
|
||||
yieldOp->operand_begin(), yieldOp->operand_end(), [&](Value arg) {
|
||||
return arg == yieldOp->getOperand(i) ||
|
||||
!fwdSlice.contains(arg.getDefiningOp());
|
||||
});
|
||||
// condition 2: final value is not used after the loop
|
||||
auto retVal = forOp.getResult(i);
|
||||
bool noUserAfterLoop = retVal.getUsers().empty();
|
||||
// yielding the region iter arg will cause loop canonicalization
|
||||
// to clean up the dead code
|
||||
if (noOtherDependency && noUserAfterLoop) {
|
||||
yieldOp->setOperand(i, forOp.getRegionIterArgs()[i]);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUCanonicalizeLoopsPass() {
|
||||
return std::make_unique<CanonicalizePass>();
|
||||
}
|
@@ -32,8 +32,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
unsigned perThread = std::min(alignment, 128 / numBits);
|
||||
sizePerThread[order[0]] = perThread;
|
||||
SmallVector<unsigned> dims(rank);
|
||||
std::iota(dims.begin(), dims.end(), 0);
|
||||
// create encoding
|
||||
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
|
||||
&getContext(), origType.getShape(), sizePerThread, order,
|
||||
this->numWarps);
|
||||
return encoding;
|
||||
@@ -64,15 +66,20 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
op->getLoc(), convertType(v.getType()), v));
|
||||
// convert output types
|
||||
SmallVector<Type, 4> newTypes;
|
||||
for (auto t : op->getResultTypes())
|
||||
newTypes.push_back(convertType(t));
|
||||
for (auto t : op->getResultTypes()) {
|
||||
bool is_async = std::is_same<T, triton::gpu::CopyAsyncOp>::value;
|
||||
newTypes.push_back(is_async ? t : convertType(t));
|
||||
}
|
||||
// construct new op with the new encoding
|
||||
Operation *newOp =
|
||||
builder.create<T>(op->getLoc(), newTypes, newArgs, op->getAttrs());
|
||||
// cast the results back to the original layout
|
||||
for (size_t i = 0; i < op->getNumResults(); i++) {
|
||||
auto newResult = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), op->getResult(i).getType(), newOp->getResult(i));
|
||||
Value newResult = newOp->getResult(i);
|
||||
if (newTypes[i] != op->getResultTypes()[i]) {
|
||||
newResult = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), op->getResult(i).getType(), newResult);
|
||||
}
|
||||
op->getResult(i).replaceAllUsesWith(newResult);
|
||||
}
|
||||
op->erase();
|
||||
@@ -97,6 +104,9 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
builder.setInsertionPoint(curr);
|
||||
if (auto load = dyn_cast<triton::LoadOp>(curr))
|
||||
coalesceOp<triton::LoadOp>(axisInfo, curr, load.ptr(), builder);
|
||||
if (auto load = dyn_cast<triton::gpu::CopyAsyncOp>(curr))
|
||||
coalesceOp<triton::gpu::CopyAsyncOp>(axisInfo, curr, load.ptr(),
|
||||
builder);
|
||||
if (auto store = dyn_cast<triton::StoreOp>(curr))
|
||||
coalesceOp<triton::StoreOp>(axisInfo, curr, store.ptr(), builder);
|
||||
});
|
||||
|
@@ -1,10 +1,13 @@
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
@@ -15,15 +18,414 @@ using namespace mlir;
|
||||
static bool isSharedLayout(Value v) {
|
||||
if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
return encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>();
|
||||
return encoding.isa<triton::gpu::SharedEncodingAttr>();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
#include "TritonGPUCombine.inc"
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Layout conversions can't deduce their return type automatically.
|
||||
// IIUC they are therefore not handled by DRR right now
|
||||
class SimplifyConversion : public mlir::RewritePattern {
|
||||
public:
|
||||
SimplifyConversion(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
2, context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
|
||||
return mlir::failure();
|
||||
// convert to the same layout -- we can delete
|
||||
if (op->getResultTypes() == op->getOperandTypes()) {
|
||||
rewriter.replaceOp(op, op->getOperands());
|
||||
return mlir::success();
|
||||
}
|
||||
Operation *arg = op->getOperand(0).getDefiningOp();
|
||||
// block argument
|
||||
if (!arg)
|
||||
return mlir::failure();
|
||||
// cvt(type2, cvt(type1, x)) -> cvt(type2, x)
|
||||
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, op->getResultTypes().front(), arg->getOperand(0));
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(type1, splat(type2, x)) -> splat(type1, x)
|
||||
if (auto splat = llvm::dyn_cast<triton::SplatOp>(arg)) {
|
||||
rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
|
||||
splat.src());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(type1, make_range(type2, x)) -> make_range(type1, x)
|
||||
if (auto range = llvm::dyn_cast<triton::MakeRangeOp>(arg)) {
|
||||
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||
op, op->getResultTypes(), range.start(), range.end());
|
||||
return mlir::success();
|
||||
}
|
||||
// cvt(type, constant) -> constant
|
||||
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
|
||||
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
|
||||
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
|
||||
ret.getSplatValue<Attribute>());
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
|
||||
return mlir::success();
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Layout conversions are expensive. They require going through
|
||||
// shared memory, which is orders of magnitude slower than
|
||||
// other non-i/o operations in the dialect.
|
||||
// It therefore makes sense to remove them whenever possible,
|
||||
// even if it means rematerializing all values whose definitions
|
||||
// are reachable from it without passing through any memory operation.
|
||||
class PullConversionToSource : public mlir::RewritePattern {
|
||||
public:
|
||||
PullConversionToSource(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
3, context) {}
|
||||
|
||||
void getReachableNotThroughMemOp(
|
||||
ArrayRef<Value> operands,
|
||||
SmallVectorImpl<Operation *> &postOrderRet) const {
|
||||
struct State {
|
||||
Value value;
|
||||
unsigned operandIndex;
|
||||
};
|
||||
SmallVector<State, 4> worklist;
|
||||
for (auto operand : operands)
|
||||
worklist.push_back({operand, 0});
|
||||
|
||||
while (!worklist.empty()) {
|
||||
State &state = worklist.back();
|
||||
auto *opInst = state.value.getDefiningOp();
|
||||
// Note: getDefiningOp will return nullptr if the operand is not an
|
||||
// Operation (i.e., block arguments) which is a terminator for the search.
|
||||
if (opInst == nullptr) {
|
||||
worklist.pop_back();
|
||||
continue;
|
||||
}
|
||||
// if we encounter a memory operation, then
|
||||
// we can assume it's not worth doing any
|
||||
// rematerialization: layout conversion
|
||||
// will be cheaper
|
||||
if (isa<triton::gpu::CopyAsyncOp, triton::LoadOp, triton::StoreOp>(
|
||||
opInst))
|
||||
return;
|
||||
// we don't want to rematerialize conversions
|
||||
if (isa<triton::gpu::ConvertLayoutOp, scf::YieldOp, scf::ForOp>(opInst))
|
||||
return;
|
||||
// visit operands
|
||||
if (state.operandIndex < opInst->getNumOperands()) {
|
||||
auto nextOperand = opInst->getOperand(state.operandIndex);
|
||||
++state.operandIndex;
|
||||
worklist.push_back({nextOperand, 0});
|
||||
} else {
|
||||
// Post-visit: done visiting operand, pop off stack.
|
||||
// and add to post-order result
|
||||
worklist.pop_back();
|
||||
postOrderRet.push_back(opInst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Attribute invertEncoding(Type targetType, Operation *op) const {
|
||||
RankedTensorType targetTensorType = targetType.cast<RankedTensorType>();
|
||||
if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
|
||||
return targetTensorType.getEncoding()
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.squeeze(expand_dims.axis());
|
||||
}
|
||||
return targetTensorType.getEncoding();
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *cvt,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(cvt))
|
||||
return mlir::failure();
|
||||
// constants/splat are handled separately
|
||||
Operation *op = cvt->getOperand(0).getDefiningOp();
|
||||
if (!op)
|
||||
return mlir::failure();
|
||||
if (isa<arith::ConstantOp, triton::MakeRangeOp, triton::SplatOp>(op))
|
||||
return mlir::failure();
|
||||
// DFS through all operands
|
||||
// auto filter = [](Operation *op) {
|
||||
// return !isa<triton::LoadOp, triton::StoreOp,
|
||||
// triton::gpu::ConvertLayoutOp>(op);
|
||||
// };
|
||||
|
||||
SmallVector<Operation *, 4> postOrderOps;
|
||||
getReachableNotThroughMemOp({cvt->getOperand(0)}, postOrderOps);
|
||||
if (postOrderOps.empty())
|
||||
return mlir::failure();
|
||||
|
||||
// We convert cvt(op(arg_0, arg_1, ..., arg_n))
|
||||
// into op(cvt_0(arg_0), cvt_1(arg_1), ..., cvt_n(arg_n))
|
||||
BlockAndValueMapping mapping;
|
||||
for (Value argI : op->getOperands()) {
|
||||
// Compute new argument types
|
||||
auto oldArgType = argI.getType().dyn_cast<RankedTensorType>();
|
||||
if (!oldArgType)
|
||||
continue;
|
||||
auto newEncoding = invertEncoding(cvt->getResultTypes()[0], op);
|
||||
auto newArgType = RankedTensorType::get(
|
||||
oldArgType.getShape(), oldArgType.getElementType(), newEncoding);
|
||||
// Create new argument
|
||||
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
op->getLoc(), newArgType, argI);
|
||||
cvtI->moveBefore(op);
|
||||
mapping.map(argI, cvtI);
|
||||
}
|
||||
Operation *newOp = rewriter.clone(*op, mapping);
|
||||
newOp->getResult(0).setType(cvt->getResult(0).getType());
|
||||
rewriter.replaceOp(cvt, newOp->getResults());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// This modifies the loop in-place
|
||||
bool tryLegalizeOp(Operation *op, DenseSet<Value> toPreserve,
|
||||
mlir::PatternRewriter &rewriter) {
|
||||
auto targetType = toPreserve.begin()->getType().cast<RankedTensorType>();
|
||||
auto newType = [&](RankedTensorType origType) {
|
||||
return RankedTensorType::get(origType.getShape(), origType.getElementType(),
|
||||
targetType.getEncoding());
|
||||
};
|
||||
bool hasSameTypes = op->getDialect()->getNamespace() == "arith" ||
|
||||
isa<triton::SplatOp, triton::GEPOp>(op);
|
||||
if (hasSameTypes) {
|
||||
// replace argument types
|
||||
for (auto arg : llvm::enumerate(op->getOperands())) {
|
||||
auto argType = arg.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (toPreserve.count(arg.value()) || !argType)
|
||||
continue;
|
||||
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
rewriter.getUnknownLoc(), newType(argType), arg.value());
|
||||
newArg->moveBefore(op);
|
||||
op->setOperand(arg.index(), newArg);
|
||||
}
|
||||
// replace result types
|
||||
if (!isa<triton::SplatOp>(op))
|
||||
op->getResult(0).setType(op->getOperand(0).getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
// i
|
||||
return false;
|
||||
}
|
||||
|
||||
std::pair<SmallVector<Value, 4>, scf::ForOp>
|
||||
tryConvertIterArg(scf::ForOp &forOp, mlir::PatternRewriter &rewriter, size_t i,
|
||||
Type newType) {
|
||||
auto newEncoding = newType.cast<RankedTensorType>().getEncoding();
|
||||
auto ctx = forOp.getContext();
|
||||
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
|
||||
// Rewrite init argument
|
||||
Type origType = forOp.getInitArgs()[i].getType();
|
||||
SmallVector<Value, 4> newInitArgs = forOp.getInitArgs();
|
||||
newInitArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
newInitArgs[i].getLoc(), newType, newInitArgs[i]);
|
||||
// Clone for loop
|
||||
scf::ForOp newForOp = rewriter.create<scf::ForOp>(
|
||||
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
|
||||
forOp.getStep(), newInitArgs);
|
||||
newForOp->moveBefore(forOp);
|
||||
rewriter.setInsertionPointToStart(newForOp.getBody());
|
||||
BlockAndValueMapping mapping;
|
||||
for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
// traverse all ops in the loop
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
// we clone the op
|
||||
Operation *newOp = rewriter.clone(op, mapping);
|
||||
// if any argument of this op has changed type, then the
|
||||
// new operation is not legal and we should try to
|
||||
// legalize it.
|
||||
DenseSet<Value> modifiedTypes;
|
||||
for (Value arg : op.getOperands()) {
|
||||
if (mapping.contains(arg) &&
|
||||
mapping.lookup(arg).getType() != arg.getType())
|
||||
modifiedTypes.insert(mapping.lookup(arg));
|
||||
}
|
||||
|
||||
bool shouldTryLegalize = !modifiedTypes.empty();
|
||||
if (shouldTryLegalize)
|
||||
tryLegalizeOp(newOp, modifiedTypes, rewriter);
|
||||
}
|
||||
// create yield, inserting conversions if necessary
|
||||
auto yieldOp = forOp.getBody()->getTerminator();
|
||||
SmallVector<Value, 4> newYieldArgs;
|
||||
for (Value arg : yieldOp->getOperands())
|
||||
newYieldArgs.push_back(mapping.lookup(arg));
|
||||
newYieldArgs[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
yieldOp->getLoc(), newType, newYieldArgs[i]);
|
||||
rewriter.create<scf::YieldOp>(forOp.getLoc(), newYieldArgs);
|
||||
|
||||
// replace
|
||||
SmallVector<Value, 4> newResults = newForOp->getResults();
|
||||
newResults[i] = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
rewriter.getUnknownLoc(), origType, newForOp->getResult(i));
|
||||
newResults[i].getDefiningOp()->moveAfter(newForOp);
|
||||
return {newResults, newForOp};
|
||||
}
|
||||
|
||||
class MoveArgConvertOutOfLoop : public mlir::RewritePattern {
|
||||
public:
|
||||
MoveArgConvertOutOfLoop(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(scf::ForOp::getOperationName(), 1, context) {}
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const {
|
||||
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
|
||||
auto iterArgs = forOp.getRegionIterArgs();
|
||||
for (auto iterArg : llvm::enumerate(iterArgs)) {
|
||||
for (auto op : iterArg.value().getUsers()) {
|
||||
auto currOps = mlir::getSlice(op, isInLoop);
|
||||
auto pred = [&](Operation *op) {
|
||||
return isa<triton::LoadOp, triton::StoreOp>(op);
|
||||
};
|
||||
auto isCvt = [&](Operation *op) {
|
||||
return isa<triton::gpu::ConvertLayoutOp>(op);
|
||||
};
|
||||
auto isYield = [&](Operation *op) { return isa<scf::YieldOp>(op); };
|
||||
auto opIt = std::find(currOps.begin(), currOps.end(), op);
|
||||
auto yieldIt = std::find_if(currOps.begin(), currOps.end(), isYield);
|
||||
auto fwdEndIt = std::find_if(opIt, currOps.end(), pred);
|
||||
auto bwdBeginIt = std::find_if(currOps.begin(), opIt, pred);
|
||||
auto fwdCvtIt = std::find_if(opIt, fwdEndIt, isCvt);
|
||||
auto bwdCvtIt = std::find_if(bwdBeginIt, opIt, isCvt);
|
||||
|
||||
if (fwdCvtIt != fwdEndIt) {
|
||||
auto newFor = tryConvertIterArg(forOp, rewriter, iterArg.index(),
|
||||
(*fwdCvtIt)->getResult(0).getType());
|
||||
rewriter.replaceOp(forOp, newFor.first);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
class PushConversionToSink : public mlir::RewritePattern {
|
||||
public:
|
||||
PushConversionToSink(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
2, context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *_cvtOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(_cvtOp);
|
||||
auto forOp = dyn_cast<scf::ForOp>(cvt->getParentOp());
|
||||
if (!forOp)
|
||||
return mlir::failure();
|
||||
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
auto isInLoop = [&](Operation *op) { return op->getParentOp() == forOp; };
|
||||
|
||||
SetVector<Operation *> cvtSlices;
|
||||
auto filter = [&](Operation *op) {
|
||||
return isInLoop(op) && !isa<triton::LoadOp>(op) &&
|
||||
!isa<triton::DotOp>(op) && !isa<scf::YieldOp>(op) &&
|
||||
!isa<triton::gpu::ConvertLayoutOp>(op);
|
||||
};
|
||||
mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
|
||||
if (cvtSlices.empty())
|
||||
return failure();
|
||||
// if other operands are in the loop
|
||||
// then we don't touch anything
|
||||
Operation *op = cvtSlices.front();
|
||||
for (Value _arg : op->getOperands()) {
|
||||
Operation *arg = _arg.getDefiningOp();
|
||||
if (arg && isInLoop(arg) && (arg != cvt))
|
||||
return failure();
|
||||
}
|
||||
// otherwise, we push the conversion forward
|
||||
// since we'll be able to move it out of
|
||||
// the loop once it reaches the yield op
|
||||
// op(cvt(arg_0), arg_1, ..., arg_n)
|
||||
// -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
|
||||
BlockAndValueMapping mapping;
|
||||
for (Value arg : op->getOperands()) {
|
||||
if (arg.getDefiningOp() == cvt)
|
||||
mapping.map(arg, cvt.getOperand());
|
||||
else {
|
||||
auto cvtI = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
arg.getLoc(), cvt.getOperand().getType(), arg);
|
||||
mapping.map(arg, cvtI);
|
||||
}
|
||||
}
|
||||
Operation *newOp = rewriter.clone(*op, mapping);
|
||||
newOp->getResult(0).setType(cvt.getOperand().getType());
|
||||
auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
newOp->getLoc(), cvt.getResult().getType(), newOp->getResult(0));
|
||||
rewriter.replaceOp(op, newCvt->getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
class BlockedToMMA : public mlir::RewritePattern {
|
||||
public:
|
||||
BlockedToMMA(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto dotOp = cast<triton::DotOp>(op);
|
||||
// TODO: Check data-types and SM compatibility
|
||||
auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
|
||||
if (oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
|
||||
return failure();
|
||||
// TODO: compute warpsPerCTA
|
||||
auto newRetType = RankedTensorType::get(
|
||||
oldRetType.getShape(), oldRetType.getElementType(),
|
||||
triton::gpu::MmaEncodingAttr::get(oldRetType.getContext(), 2, {2, 2}));
|
||||
auto oldAcc = dotOp.getOperand(2);
|
||||
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
oldAcc.getLoc(), newRetType, oldAcc);
|
||||
auto newDot = rewriter.create<triton::DotOp>(
|
||||
dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
|
||||
newAcc, dotOp.allowTF32());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
|
||||
op, oldRetType, newDot.getResult());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
@@ -36,9 +438,11 @@ public:
|
||||
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
|
||||
patterns.add<ConvertLayoutOptPattern>(context);
|
||||
patterns.add<CopyAsyncOptPattern>(context);
|
||||
patterns.add<RedundantConvertLayoutOptPattern>(context);
|
||||
patterns.add<SimplifyConversion>(context);
|
||||
patterns.add<PullConversionToSource>(context);
|
||||
patterns.add<PushConversionToSink>(context);
|
||||
patterns.add<MoveArgConvertOutOfLoop>(context);
|
||||
patterns.add<BlockedToMMA>(context);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
@@ -47,4 +451,4 @@ public:
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUCombineOpsPass() {
|
||||
return std::make_unique<TritonGPUCombineOpsPass>();
|
||||
}
|
||||
}
|
@@ -4,22 +4,4 @@
|
||||
include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td"
|
||||
include "triton/Dialect/Triton/IR/TritonOps.td"
|
||||
|
||||
// convert_layout(load(...), #L) => copy_async(...); barrier
|
||||
// if #L is smem_layout
|
||||
def CopyAsyncOptPattern : Pat<
|
||||
(TTG_ConvertLayoutOp:$res (TT_LoadOp $ptr, $mask, $other, $cache, $evict, $isVolatile, $isOtherUnspecified)),
|
||||
(TTG_CopyAsyncOp $ptr, $mask, $other, $cache, $evict, $isVolatile, $isOtherUnspecified),
|
||||
[(Constraint<CPred<"isSharedLayout($0)">> $res)]>;
|
||||
|
||||
// ConvertLayout(ConvertLayout(x, #L0), #L1) => ConvertLayout(x, #L1)
|
||||
def ConvertLayoutOptPattern : Pat<
|
||||
(TTG_ConvertLayoutOp (TTG_ConvertLayoutOp $x)),
|
||||
(TTG_ConvertLayoutOp $x)>;
|
||||
|
||||
// TODO: can we replace this with ConvertLayoutOp's folder?
|
||||
// ConvertLayout(x, #L) => x if x.layout() == #L
|
||||
def RedundantConvertLayoutOptPattern : Pat<
|
||||
(TTG_ConvertLayoutOp:$res $x), (replaceWithValue $x),
|
||||
[(Constraint<CPred<"$0.getType() == $1.getType()">> $res, $x)]>;
|
||||
|
||||
#endif
|
||||
|
@@ -1,8 +1,7 @@
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements loop software pipelining
|
||||
@@ -168,8 +167,7 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
if (auto tensorType = convertLayout.getResult()
|
||||
.getType()
|
||||
.dyn_cast<RankedTensorType>()) {
|
||||
if (tensorType.getEncoding()
|
||||
.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
if (tensorType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
|
||||
isCandiate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
}
|
||||
@@ -263,7 +261,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
// assert(I1 or TensorOf<[I1]>);
|
||||
OpBuilder::InsertionGuard g(builder);
|
||||
builder.setInsertionPoint(newOp);
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(
|
||||
Value splatCond = builder.create<triton::SplatOp>(
|
||||
mask.getLoc(), mask.getType(), loopCond);
|
||||
Value newMask =
|
||||
builder.create<arith::AndIOp>(mask.getLoc(), mask, splatCond);
|
||||
@@ -356,6 +354,9 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
Value loadUse = load.getUsers().begin()->getResult(0);
|
||||
mapping.lookup(loadUse).replaceAllUsesWith(
|
||||
newForOp.getRegionIterArgs()[loadIdx + idx * (numStages - 1)]);
|
||||
// delete old load and layout conversion
|
||||
mapping.lookup(loadUse).getDefiningOp()->erase();
|
||||
mapping.lookup(load).getDefiningOp()->erase();
|
||||
}
|
||||
|
||||
// 4. prefetch the next iteration
|
||||
@@ -389,7 +390,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
if (loads.contains(op->getResult(0))) {
|
||||
auto loadOp = llvm::cast<triton::LoadOp>(op);
|
||||
Value mask = loadOp.mask();
|
||||
Value splatCond = builder.create<triton::BroadcastOp>(
|
||||
Value splatCond = builder.create<triton::SplatOp>(
|
||||
mask.getLoc(), mask.getType(), nextLoopCond);
|
||||
Value newMask = builder.create<arith::AndIOp>(
|
||||
mask.getLoc(), splatCond, nextMapping.lookupOrDefault(mask));
|
||||
@@ -442,9 +443,10 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
yieldValues.push_back(
|
||||
depArgsMapping.lookup(newForOp.getRegionIterArgs()[i]));
|
||||
yieldValues.push_back(nextIV);
|
||||
|
||||
builder.setInsertionPointToEnd(newForOp.getBody());
|
||||
builder.create<scf::YieldOp>(forOp.getBody()->getTerminator()->getLoc(),
|
||||
yieldValues);
|
||||
auto test = builder.create<scf::YieldOp>(
|
||||
forOp.getBody()->getTerminator()->getLoc(), yieldValues);
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
|
@@ -29,7 +29,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
llvm::SmallVector<unsigned> order(rank);
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
llvm::SmallVector<unsigned> sizePerThread(rank, 1);
|
||||
Attribute encoding = triton::gpu::TritonGPUBlockedEncodingAttr::get(
|
||||
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
|
||||
this->context, shape, sizePerThread, order, this->numWarps);
|
||||
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
|
||||
});
|
||||
@@ -95,9 +95,8 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
||||
Attribute bEncoding =
|
||||
dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
||||
if (aEncoding &&
|
||||
aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
if (aEncoding && aEncoding.isa<triton::gpu::SharedEncodingAttr>() &&
|
||||
bEncoding && bEncoding.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return true;
|
||||
return false;
|
||||
});
|
||||
|
@@ -33,7 +33,7 @@ private:
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
if (!encoding)
|
||||
return dotOp.emitError() << name << " should have encoding";
|
||||
if (!encoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||
if (!encoding.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return dotOp.emitError() << name << " should be of shared layout";
|
||||
} else
|
||||
return dotOp.emitError()
|
||||
@@ -49,8 +49,8 @@ private:
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
if (!encoding)
|
||||
return dotOp.emitError() << name << " should have encoding";
|
||||
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
|
||||
!encoding.isa<triton::gpu::TritonGPUBlockedEncodingAttr>())
|
||||
if (!encoding.isa<triton::gpu::MmaEncodingAttr>() &&
|
||||
!encoding.isa<triton::gpu::BlockedEncodingAttr>())
|
||||
return dotOp.emitError()
|
||||
<< name << " should be of distributed layout";
|
||||
if (name == 'c')
|
||||
|
@@ -749,8 +749,9 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
|
||||
context = _triton.ir.context()
|
||||
context.load_triton()
|
||||
# create kernel prototype
|
||||
arg_types = signature.replace(' ', '').split(',')
|
||||
constants = {fn.arg_names.index(name): value for name, value in constants.items()}
|
||||
attributes = {fn.arg_names.index(name): value for name, value in attributes.items()}
|
||||
arg_types = signature.replace(' ', '').split(',')
|
||||
arg_types = [str_to_ty(x) for x in arg_types]
|
||||
prototype = triton.language.function_type([], arg_types)
|
||||
# visit kernel AST
|
||||
@@ -769,6 +770,14 @@ def make_triton_ir(fn, signature, constants=dict(), attributes=dict()):
|
||||
return ret
|
||||
|
||||
|
||||
def optimize_triton_ir(mod):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_inliner_pass()
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def make_tritongpu_ir(mod, num_warps):
|
||||
pm = _triton.ir.pass_manager(mod.context)
|
||||
pm.add_inliner_pass()
|
||||
@@ -785,7 +794,7 @@ def optimize_tritongpu_ir(mod, num_stages):
|
||||
pm.add_tritongpu_pipeline_pass(num_stages)
|
||||
pm.add_canonicalizer_pass()
|
||||
pm.add_cse_pass()
|
||||
pm.add_triton_gpu_combine_pass()
|
||||
# pm.add_triton_gpu_combine_pass()
|
||||
pm.add_triton_gpu_verifier_pass()
|
||||
pm.run(mod)
|
||||
return mod
|
||||
@@ -815,6 +824,7 @@ def compile(fn, signature, device=-1, constants=dict(), attributes=dict(), num_w
|
||||
assert output in valid_outputs, "output should be one of [%s], but get \"%s\"" % (','.join(valid_outputs), output)
|
||||
# triton-ir
|
||||
module = make_triton_ir(fn, signature, constants, attributes)
|
||||
module = optimize_triton_ir(module)
|
||||
if output == "ttir":
|
||||
return module.str()
|
||||
# tritongpu-ir
|
||||
|
175
test/TritonGPU/combine.mlir
Normal file
175
test/TritonGPU/combine.mlir
Normal file
@@ -0,0 +1,175 @@
|
||||
// RUN: triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s

#layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>

// CHECK: [[target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
// CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>

func @cst() -> tensor<1024xi32, #layout1> {
%cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: return %cst : tensor<1024xi32, [[target_layout]]>
return %1: tensor<1024xi32, #layout1>
}

func @range() -> tensor<1024xi32, #layout1> {
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: return %0 : tensor<1024xi32, [[target_layout]]>
return %1: tensor<1024xi32, #layout1>
}

func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
%0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
%1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: return %0 : tensor<1024xi32, [[target_layout]]>
return %1: tensor<1024xi32, #layout1>
}

func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
%0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
%1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
%2 = arith.muli %0, %1 : tensor<1024xi32, #layout0>
%3 = triton_gpu.convert_layout %2 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
%4 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
%5 = triton_gpu.convert_layout %2 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
%6 = arith.addi %3, %5 : tensor<1024xi32, #layout1>
return %6: tensor<1024xi32, #layout1>
// CHECK: %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
// CHECK: %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
// CHECK: %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
// CHECK: %3 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, [[target_layout]]>
// CHECK: %4 = arith.muli %2, %3 : tensor<1024xi32, [[target_layout]]>
// CHECK: %5 = arith.muli %0, %1 : tensor<1024xi32, [[target_layout]]>
// CHECK: %6 = arith.addi %4, %5 : tensor<1024xi32, [[target_layout]]>
// CHECK: return %6 : tensor<1024xi32, [[target_layout]]>
}

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>

// CHECK-LABEL: transpose
func @transpose(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
// CHECK: %cst = arith.constant dense<true> : tensor<64x64xi1, [[row_layout]]>
// CHECK: %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, [[row_layout]]>
// CHECK: %cst_1 = arith.constant dense<true> : tensor<64x64xi1, [[col_layout]]>
// CHECK: %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>
// CHECK: %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>
// CHECK: %2 = tt.expand_dims %1 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[row_layout]]}>>) -> tensor<64x1xi32, [[row_layout]]>
// CHECK: %3 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, [[row_layout]]>
// CHECK: %4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
// CHECK: %5 = arith.muli %2, %3 : tensor<64x1xi32, [[row_layout]]>
// CHECK: %6 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>
// CHECK: %7 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>
// CHECK: %8 = tt.getelementptr %4, %5 : tensor<64x1x!tt.ptr<f32>, [[row_layout]]>
// CHECK: %9 = tt.expand_dims %7 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[row_layout]]}>>) -> tensor<1x64xi32, [[row_layout]]>
// CHECK: %10 = tt.broadcast %8 : (tensor<64x1x!tt.ptr<f32>, [[row_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK: %11 = tt.broadcast %9 : (tensor<1x64xi32, [[row_layout]]>) -> tensor<64x64xi32, [[row_layout]]>
// CHECK: %12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
// CHECK: %13 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = [[col_layout]]}>>) -> tensor<64x1xi32, [[col_layout]]>
// CHECK: %14 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = [[col_layout]]}>>) -> tensor<1x64xi32, [[col_layout]]>
// CHECK: %15 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, [[col_layout]]>
// CHECK: %16 = tt.getelementptr %12, %13 : tensor<64x1x!tt.ptr<f32>, [[col_layout]]>
// CHECK: %17 = arith.muli %14, %15 : tensor<1x64xi32, [[col_layout]]>
// CHECK: %18 = tt.broadcast %16 : (tensor<64x1x!tt.ptr<f32>, [[col_layout]]>) -> tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
// CHECK: %19 = tt.broadcast %17 : (tensor<1x64xi32, [[col_layout]]>) -> tensor<64x64xi32, [[col_layout]]>
// CHECK: %20 = tt.getelementptr %10, %11 : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK: %21 = tt.load %20, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
// CHECK: %22 = tt.getelementptr %18, %19 : tensor<64x64x!tt.ptr<f32>, [[col_layout]]>
// CHECK: %23 = triton_gpu.convert_layout %21 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
// CHECK: tt.store %22, %23, %cst_1, : tensor<64x64xf32, [[col_layout]]>
// CHECK: return
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%cst_0 = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%11 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%12 = tt.getelementptr %11, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%14 = arith.muli %6, %13 : tensor<1x64xi32, #blocked2>
%15 = tt.broadcast %12 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%16 = tt.broadcast %14 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%17 = triton_gpu.convert_layout %16 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%18 = tt.getelementptr %15, %17 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%19 = triton_gpu.convert_layout %10 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
%20 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
%21 = triton_gpu.convert_layout %cst : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
%22 = tt.load %19, %20, %21 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, isOtherUnspecified = false} : tensor<64x64xf32, #blocked3>
%23 = triton_gpu.convert_layout %22 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
%24 = triton_gpu.convert_layout %18 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked4>
%25 = triton_gpu.convert_layout %23 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked4>
%26 = triton_gpu.convert_layout %cst_0 : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked4>
tt.store %24, %25, %26, : tensor<64x64xf32, #blocked4>
return
}

// CHECK-LABEL: loop
func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
// CHECK-NOT: triton_gpu.convert_layout
// CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>)
// CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]>
// CHECK-NEXT: {{.*}} = arith.addf {{.*}} : tensor<64x64xf32, [[row_layout]]>
// CHECK-NEXT: {{.*}} = tt.getelementptr {{.*}} : tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK-NEXT: scf.yield {{.*}} : tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK-NEXT: }
// CHECK-NEXT: {{.*}} = triton_gpu.convert_layout [[loop_ret]]#0 : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout_novec]]>
// CHECK-NOT: triton_gpu.convert_layout
%cst = arith.constant dense<true> : tensor<64x64xi1, #blocked1>
%cst_0 = arith.constant dense<64> : tensor<64x64xi32, #blocked1>
%c1 = arith.constant 1 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%cst_1 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked1>
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked0>
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<64xi32, #blocked0>) -> tensor<64x1xi32, #blocked1>
%2 = tt.splat %arg1 : (i32) -> tensor<64x1xi32, #blocked1>
%3 = arith.muli %1, %2 : tensor<64x1xi32, #blocked1>
%4 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%5 = tt.getelementptr %4, %3 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%6 = tt.expand_dims %0 {axis = 0 : i32} : (tensor<64xi32, #blocked0>) -> tensor<1x64xi32, #blocked2>
%7 = tt.broadcast %5 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%8 = tt.broadcast %6 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%9 = triton_gpu.convert_layout %8 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%10 = tt.getelementptr %7, %9 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%11:2 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_1, %arg7 = %10) -> (tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>) {
%23 = triton_gpu.convert_layout %arg7 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked3>
%24 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked3>
%25 = triton_gpu.convert_layout %cst_1 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked3>
%26 = tt.load %23, %24, %25 {cache = 1 : i32, evict = 1 : i32, isOtherUnspecified = false, isVolatile = false} : tensor<64x64xf32, #blocked3>
%27 = triton_gpu.convert_layout %26 : (tensor<64x64xf32, #blocked3>) -> tensor<64x64xf32, #blocked1>
%28 = arith.addf %arg6, %27 : tensor<64x64xf32, #blocked1>
%29 = tt.getelementptr %arg7, %cst_0 : tensor<64x64x!tt.ptr<f32>, #blocked1>
scf.yield %28, %29 : tensor<64x64xf32, #blocked1>, tensor<64x64x!tt.ptr<f32>, #blocked1>
}
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<64x1x!tt.ptr<f32>, #blocked1>
%13 = tt.getelementptr %12, %1 : tensor<64x1x!tt.ptr<f32>, #blocked1>
%14 = tt.splat %arg3 : (i32) -> tensor<1x64xi32, #blocked2>
%15 = arith.muli %6, %14 : tensor<1x64xi32, #blocked2>
%16 = tt.broadcast %13 : (tensor<64x1x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%17 = tt.broadcast %15 : (tensor<1x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked2>
%18 = triton_gpu.convert_layout %17 : (tensor<64x64xi32, #blocked2>) -> tensor<64x64xi32, #blocked1>
%19 = tt.getelementptr %16, %18 : tensor<64x64x!tt.ptr<f32>, #blocked1>
%20 = triton_gpu.convert_layout %19 : (tensor<64x64x!tt.ptr<f32>, #blocked1>) -> tensor<64x64x!tt.ptr<f32>, #blocked1>
%21 = triton_gpu.convert_layout %11#0 : (tensor<64x64xf32, #blocked1>) -> tensor<64x64xf32, #blocked1>
%22 = triton_gpu.convert_layout %cst : (tensor<64x64xi1, #blocked1>) -> tensor<64x64xi1, #blocked1>
tt.store %20, %21, %22, : tensor<64x64xf32, #blocked1>
return
}
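
The tests above run `triton-opt %s -tritongpu-combine 2>&1 | FileCheck %s` and verify that layout conversions get folded into, or rematerialized through, their producers. As a rough sketch of the rewrite-pattern shape involved (the C++ class name `ConvertLayoutOp` is assumed from the `triton_gpu.convert_layout` op; this is not the actual pass code), a pattern that collapses back-to-back conversions could look like:

// Sketch only: fold convert_layout(convert_layout(x)) into one conversion.
// Assumed names: triton::gpu::ConvertLayoutOp and its header path.
#include "mlir/IR/PatternMatch.h"
// #include "triton/Dialect/TritonGPU/IR/Dialect.h" // header path assumed

struct FoldConvertChain
    : public mlir::OpRewritePattern<mlir::triton::gpu::ConvertLayoutOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::triton::gpu::ConvertLayoutOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Only fire when the source itself comes from another convert_layout.
    auto producer = op->getOperand(0)
                        .getDefiningOp<mlir::triton::gpu::ConvertLayoutOp>();
    if (!producer)
      return mlir::failure();
    // Convert directly from the original value to the final layout.
    rewriter.replaceOpWithNewOp<mlir::triton::gpu::ConvertLayoutOp>(
        op, op->getResult(0).getType(), producer->getOperand(0));
    return mlir::success();
  }
};

Folding the chain A→B→C into A→C is sound because convert_layout only redistributes tensor elements across threads and never changes their values, which is also why the @cst, @range, and @splat tests expect the conversion to disappear entirely once the producer is rebuilt in the target layout.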