[OPTIMIZER] Rewrite patterns for layout conversions (#64)

This commit is contained in:
Philippe Tillet
2022-08-18 12:49:37 -07:00
committed by GitHub
parent e0bedeb44c
commit 192be76b3c
19 changed files with 851 additions and 127 deletions

View File

@@ -1,6 +1,8 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <numeric>
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/TypeSwitch.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
@@ -70,11 +72,15 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
return SliceEncodingAttr::get(getContext(), axis, *this);
}
//===----------------------------------------------------------------------===//
// Blocked Encoding
//===----------------------------------------------------------------------===//
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
@@ -115,11 +121,11 @@ Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
}
}
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
return parser.getChecked<BlockedEncodingAttr>(
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
}
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "sizePerThread = [" << getSizePerThread() << "]"
<< ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
@@ -132,7 +138,7 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
// MMA encoding
//===----------------------------------------------------------------------===//
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
DictionaryAttr dict;
@@ -155,22 +161,59 @@ Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
}
}
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
version, warpsPerCTA);
return parser.getChecked<MmaEncodingAttr>(parser.getContext(), version,
warpsPerCTA);
}
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
void MmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "version = " << getVersion() << ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< "}>";
}
//===----------------------------------------------------------------------===//
// Sliced Encoding
//===----------------------------------------------------------------------===//
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned dim = 0;
Attribute parent;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "dim") {
if (parseUInt(parser, attr, dim, "dim").failed())
return {};
}
if (attr.getName() == "parent") {
if (parser.parseAttribute(parent).failed())
return {};
}
}
return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
}
void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<{"
<< "dim = " << getDim() << ", "
<< "parent = " << getParent() << "}>";
}
//===----------------------------------------------------------------------===//
// Shared encoding
//===----------------------------------------------------------------------===//
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
@@ -205,11 +248,11 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
}
}
return parser.getChecked<TritonGPUSharedEncodingAttr>(
parser.getContext(), vec, perPhase, maxPhase, order);
return parser.getChecked<SharedEncodingAttr>(parser.getContext(), vec,
perPhase, maxPhase, order);
}
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
void SharedEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "vec = " << getVec() << ", perPhase = " << getPerPhase()
<< ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
@@ -226,18 +269,21 @@ public:
using OpAsmDialectInterface::OpAsmDialectInterface;
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
if (auto mmaAttr = attr.dyn_cast<MmaEncodingAttr>()) {
os << "mma";
return AliasResult::FinalAlias;
} else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
} else if (auto sharedAttr = attr.dyn_cast<SharedEncodingAttr>()) {
os << "shared";
return AliasResult::FinalAlias;
} else if (auto blockedAttr =
attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
} else if (auto blockedAttr = attr.dyn_cast<BlockedEncodingAttr>()) {
os << "blocked";
return AliasResult::FinalAlias;
}
return OpAsmDialectInterface::getAlias(attr, os);
} /* else if (auto sliceAttr = attr.dyn_cast<SliceEncodingAttr>()) {
os << "slice";
return AliasResult::FinalAlias;
} */
OpAsmDialectInterface::getAlias(attr, os);
return AliasResult::FinalAlias;
}
};
@@ -283,11 +329,15 @@ static Type getPointeeType(Type type) {
} // namespace triton
} // namespace mlir
//===----------------------------------------------------------------------===//
// Verification
//===----------------------------------------------------------------------===//
static LogicalResult verify(CopyAsyncOp op) {
Type resType = op.getResult().getType();
if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
Attribute encoding = tensorType.getEncoding();
if (!encoding.isa<TritonGPUSharedEncodingAttr>())
if (!encoding.isa<SharedEncodingAttr>())
return op.emitOpError("copy_async should return a shared memory tensor");
} else
return op.emitOpError("copy_async should return a tensor");
@@ -302,4 +352,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// TODO: fill this.
return success();
}
}