more progress on the definition of layouts

This commit is contained in:
Da Yan
2022-05-31 11:43:21 +00:00
parent 41d338d848
commit e36a54eb86
4 changed files with 203 additions and 36 deletions

View File

@@ -4,46 +4,181 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
using namespace mlir;
using namespace mlir::triton::gpu;
// parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr,
SmallVector<unsigned, 2> &res,
StringRef desc) {
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr) {
parser.emitError(parser.getNameLoc(), "expected an array for ")
<< desc;
return failure();
}
for (Attribute i : arrayAttr) {
auto intAttr = i.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
<< desc;
return failure();
}
res.push_back(intAttr.getUInt());
}
return success();
};
//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
mlir::Attribute
TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
llvm_unreachable("Not implemented");
Attribute
TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> threadTileSize;
SmallVector<unsigned, 2> warpTileSize;
SmallVector<unsigned, 2> blockTileSize;
SmallVector<unsigned, 2> order;
// parse an array of integers
// auto parseIntArrayAttr = [&parser](const NamedAttribute &attr,
// SmallVector<unsigned, 2> &res,
// StringRef desc) -> LogicalResult {
// auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
// if (!arrayAttr) {
// parser.emitError(parser.getNameLoc(), "expected an array for ")
// << desc;
// return failure();
// }
// for (Attribute i : arrayAttr) {
// auto intAttr = i.dyn_cast<IntegerAttr>();
// if (!intAttr) {
// parser.emitError(parser.getNameLoc(), "expected an integer value in ")
// << desc;
// return failure();
// }
// res.push_back(intAttr.getUInt());
// }
// return success();
// };
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "threadTileSize") {
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
return {};
} else if (attr.getName() == "warpTileSize") {
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed())
return {};
} else if (attr.getName() == "blockTileSize") {
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUShardedEncodingAttr>(parser.getContext(),
threadTileSize,
warpTileSize,
blockTileSize,
order);
}
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<"
<< "threadTileSize = " << getThreadTileSize()
<< ", blockTileSize = " << getBlockTileSize()
<< ", order = " << getOrder()
<< "threadTileSize = [" << getThreadTileSize() << "]"
<< ", warpTileSize = [" << getWarpTileSize() << "]"
<< ", blockTileSize = [" << getBlockTileSize() << "]"
<< ", order = [" << getOrder() << "]"
<< ">";
}
mlir::Attribute
TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
Attribute
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
llvm_unreachable("Not implemented");
}
void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
llvm_unreachable("Not implemented");
}
mlir::Attribute
TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
llvm_unreachable("Not implemented");
Attribute
TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned vec = 0;
unsigned perPhase = 0;
unsigned maxPhase = 0;
SmallVector<unsigned, 2> order;
auto parseUInt = [&parser](const NamedAttribute &attr,
unsigned &value,
StringRef desc) -> LogicalResult {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
return failure();
}
value = intAttr.getUInt();
return success();
};
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "vec") {
if (parseUInt(attr, vec, "vec").failed())
return {};
} else if (attr.getName() == "perPhase") {
if (parseUInt(attr, perPhase, "perPhase").failed())
return {};
} else if (attr.getName() == "maxPhase") {
if (parseUInt(attr, maxPhase, "maxPhase").failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(),
vec,
perPhase,
maxPhase,
order);
}
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
printer << "<"
// << "threadTileSize = " << getThreadTileSize()
// << ", blockTileSize = " << getBlockTileSize()
// << ", order = " << getOrder()
<< "vec = " << getVec()
<< ", perPhase = " << getPerPhase()
<< ", order = [" << getOrder() << "]"
<< ">";
}
@@ -92,9 +227,9 @@ static Type getPointeeType(Type type) {
// verify TritonGPU ops
mlir::LogicalResult
TritonGPUDialect::verifyOperationAttribute(mlir::Operation *op,
mlir::NamedAttribute attr) {
LogicalResult
TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// TODO: fill this.
return success();
}

View File

@@ -35,6 +35,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// Now we assume:
// contiguous = 1, order = 0, 1, 2, ...,
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank);
llvm::SmallVector<unsigned> order(rank);
int remainingThreads = numThreads;
@@ -45,8 +46,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
remainingThreads /= blockTileSize[dim];
// TODO: will we need repetition?
}
Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get(
context, threadTileSize, blockTileSize, order);
Attribute encoding = triton::gpu::TritonGPUShardedEncodingAttr::get(
context, threadTileSize, warpTileSize, blockTileSize, order);
return RankedTensorType::get(shape, elementType, encoding);
});

View File

@@ -50,7 +50,7 @@ private:
if (!encoding)
return dotOp.emitError() << name << " should have encoding";
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
!encoding.isa<triton::gpu::TritonGPUDistributedEncodingAttr>())
!encoding.isa<triton::gpu::TritonGPUShardedEncodingAttr>())
return dotOp.emitError() << name << " should be of distributed layout";
if (name == 'c')
cLayout = encoding;