[OPTIMIZER] Rewrite patterns for layout conversions (#64)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||
@@ -70,11 +72,15 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||
|
||||
SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
|
||||
return SliceEncodingAttr::get(getContext(), axis, *this);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Blocked Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
@@ -115,11 +121,11 @@ Attribute TritonGPUBlockedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUBlockedEncodingAttr>(
|
||||
return parser.getChecked<BlockedEncodingAttr>(
|
||||
parser.getContext(), sizePerThread, threadsPerWarp, warpsPerCTA, order);
|
||||
}
|
||||
|
||||
void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "sizePerThread = [" << getSizePerThread() << "]"
|
||||
<< ", threadsPerWarp = [" << getThreadsPerWarp() << "]"
|
||||
@@ -132,7 +138,7 @@ void TritonGPUBlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
// MMA encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute MmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
DictionaryAttr dict;
|
||||
@@ -155,22 +161,59 @@ Attribute TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUMmaEncodingAttr>(parser.getContext(),
|
||||
version, warpsPerCTA);
|
||||
return parser.getChecked<MmaEncodingAttr>(parser.getContext(), version,
|
||||
warpsPerCTA);
|
||||
}
|
||||
|
||||
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
void MmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "version = " << getVersion() << ", "
|
||||
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sliced Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
DictionaryAttr dict;
|
||||
if (parser.parseAttribute(dict).failed())
|
||||
return {};
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
unsigned dim = 0;
|
||||
Attribute parent;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "dim") {
|
||||
if (parseUInt(parser, attr, dim, "dim").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "parent") {
|
||||
if (parser.parseAttribute(parent).failed())
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
|
||||
}
|
||||
|
||||
void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "dim = " << getDim() << ", "
|
||||
<< "parent = " << getParent() << "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shared encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
Attribute SharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseLess().failed())
|
||||
return {};
|
||||
// Parse the data as a dictionary
|
||||
@@ -205,11 +248,11 @@ Attribute TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<TritonGPUSharedEncodingAttr>(
|
||||
parser.getContext(), vec, perPhase, maxPhase, order);
|
||||
return parser.getChecked<SharedEncodingAttr>(parser.getContext(), vec,
|
||||
perPhase, maxPhase, order);
|
||||
}
|
||||
|
||||
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
void SharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "vec = " << getVec() << ", perPhase = " << getPerPhase()
|
||||
<< ", maxPhase = " << getMaxPhase() << ", order = [" << getOrder()
|
||||
@@ -226,18 +269,21 @@ public:
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
||||
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
|
||||
if (auto mmaAttr = attr.dyn_cast<TritonGPUMmaEncodingAttr>()) {
|
||||
if (auto mmaAttr = attr.dyn_cast<MmaEncodingAttr>()) {
|
||||
os << "mma";
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto sharedAttr = attr.dyn_cast<TritonGPUSharedEncodingAttr>()) {
|
||||
} else if (auto sharedAttr = attr.dyn_cast<SharedEncodingAttr>()) {
|
||||
os << "shared";
|
||||
return AliasResult::FinalAlias;
|
||||
} else if (auto blockedAttr =
|
||||
attr.dyn_cast<TritonGPUBlockedEncodingAttr>()) {
|
||||
} else if (auto blockedAttr = attr.dyn_cast<BlockedEncodingAttr>()) {
|
||||
os << "blocked";
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
return OpAsmDialectInterface::getAlias(attr, os);
|
||||
} /* else if (auto sliceAttr = attr.dyn_cast<SliceEncodingAttr>()) {
|
||||
os << "slice";
|
||||
return AliasResult::FinalAlias;
|
||||
} */
|
||||
OpAsmDialectInterface::getAlias(attr, os);
|
||||
return AliasResult::FinalAlias;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -283,11 +329,15 @@ static Type getPointeeType(Type type) {
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Verification
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(CopyAsyncOp op) {
|
||||
Type resType = op.getResult().getType();
|
||||
if (auto tensorType = resType.dyn_cast<RankedTensorType>()) {
|
||||
Attribute encoding = tensorType.getEncoding();
|
||||
if (!encoding.isa<TritonGPUSharedEncodingAttr>())
|
||||
if (!encoding.isa<SharedEncodingAttr>())
|
||||
return op.emitOpError("copy_async should return a shared memory tensor");
|
||||
} else
|
||||
return op.emitOpError("copy_async should return a tensor");
|
||||
@@ -302,4 +352,4 @@ LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||
NamedAttribute attr) {
|
||||
// TODO: fill this.
|
||||
return success();
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user