more progress on the definition of layouts
This commit is contained in:
@@ -4,46 +4,181 @@
|
||||
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton::gpu;
|
||||
|
||||
// parse an array of integers
|
||||
/// Converts the array-of-integers value carried by `attr` into `res`.
///
/// \param parser  Parser used to report diagnostics (at its name location).
/// \param attr    Named attribute whose value must be an ArrayAttr containing
///                only IntegerAttr elements.
/// \param res     Output vector; each converted element is appended. Elements
///                converted before a failure are left in place.
/// \param desc    Human-readable field name appended to diagnostics.
/// \returns success() if every element was converted, otherwise failure()
///          after emitting an error.
static LogicalResult parseIntArrayAttr(AsmParser &parser,
                                       const NamedAttribute &attr,
                                       SmallVector<unsigned, 2> &res,
                                       StringRef desc) {
  auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
  if (!arrayAttr) {
    parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
    return failure();
  }

  // Widest value representable by the `unsigned` elements of `res`.
  const unsigned maxBits = 8 * sizeof(unsigned);

  for (Attribute element : arrayAttr) {
    auto intAttr = element.dyn_cast<IntegerAttr>();
    if (!intAttr) {
      parser.emitError(parser.getNameLoc(), "expected an integer value in ")
          << desc;
      return failure();
    }

    // IntegerAttr::getUInt() asserts unless the type is explicitly unsigned,
    // but integers written in the textual form are signless. Read the raw
    // APInt instead and validate sign and width ourselves.
    const llvm::APInt raw = intAttr.getValue();
    const bool isNegative =
        !intAttr.getType().isUnsignedInteger() && raw.isNegative();
    if (isNegative || raw.getActiveBits() > maxBits) {
      parser.emitError(parser.getNameLoc(), "integer value out of range in ")
          << desc;
      return failure();
    }
    res.push_back(static_cast<unsigned>(raw.getZExtValue()));
  }
  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Attribute methods
|
||||
//===----------------------------------------------------------------------===//
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||
|
||||
mlir::Attribute
|
||||
TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
|
||||
llvm_unreachable("Not implemented");
|
||||
/// Parses a sharded encoding written as `<{key = [ints], ...}>`.
///
/// The payload between `<` and `>` is parsed as a DictionaryAttr. Recognized
/// keys are `threadTileSize`, `warpTileSize`, `blockTileSize` and `order`;
/// each is converted with parseIntArrayAttr. A key that is absent leaves the
/// corresponding list empty; any other key is rejected with an error.
///
/// \returns the attribute produced by getChecked, or a null Attribute when
///          parsing fails.
Attribute
TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  // Parse the data as a dictionary
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  SmallVector<unsigned, 2> threadTileSize;
  SmallVector<unsigned, 2> warpTileSize;
  SmallVector<unsigned, 2> blockTileSize;
  SmallVector<unsigned, 2> order;

  // Dispatch every dictionary entry to the list it populates.
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "threadTileSize") {
      if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size")
              .failed())
        return {};
    } else if (attr.getName() == "warpTileSize") {
      if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size")
              .failed())
        return {};
    } else if (attr.getName() == "blockTileSize") {
      if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size")
              .failed())
        return {};
    } else if (attr.getName() == "order") {
      if (parseIntArrayAttr(parser, attr, order, "order").failed())
        return {};
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  return parser.getChecked<TritonGPUShardedEncodingAttr>(parser.getContext(),
                                                         threadTileSize,
                                                         warpTileSize,
                                                         blockTileSize,
                                                         order);
}
|
||||
|
||||
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
/// Prints the sharded encoding as
/// `<{threadTileSize = [...], warpTileSize = [...], blockTileSize = [...],
/// order = [...]}>`.
///
/// Each key is emitted exactly once, and the payload is wrapped in `{ }`
/// because TritonGPUShardedEncodingAttr::parse reads it back as a
/// DictionaryAttr.
void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{"
          << "threadTileSize = [" << getThreadTileSize() << "]"
          << ", warpTileSize = [" << getWarpTileSize() << "]"
          << ", blockTileSize = [" << getBlockTileSize() << "]"
          << ", order = [" << getOrder() << "]"
          << "}>";
}
|
||||
|
||||
mlir::Attribute
|
||||
TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
|
||||
/// Parsing of the MMA encoding is not implemented yet.
///
/// This path is reachable from user-written IR, so it reports a regular
/// parser diagnostic and returns a null Attribute rather than using
/// llvm_unreachable, which is undefined behaviour in optimized builds.
Attribute
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
  parser.emitError(parser.getNameLoc(),
                   "parsing of the mma encoding is not implemented");
  return {};
}
|
||||
|
||||
void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
/// Printing of the MMA encoding is not implemented yet.
///
/// Aborts with a fatal error in every build mode. llvm_unreachable would be
/// undefined behaviour in optimized builds, although this path is hit
/// whenever IR carrying the attribute is printed.
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
  llvm::report_fatal_error("TritonGPUMmaEncodingAttr::print: Not implemented");
}
|
||||
|
||||
mlir::Attribute
|
||||
TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
|
||||
llvm_unreachable("Not implemented");
|
||||
/// Parses a shared encoding written as `<{key = value, ...}>`.
///
/// The payload between `<` and `>` is parsed as a DictionaryAttr. Recognized
/// keys are the integers `vec`, `perPhase` and `maxPhase`, and the integer
/// array `order`. Integer keys that are absent keep their initial value of 0
/// and an absent `order` stays empty; any other key is rejected with an error.
///
/// \returns the attribute produced by getChecked, or a null Attribute when
///          parsing fails.
Attribute
TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  // Parse the data as a dictionary
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  unsigned vec = 0;
  unsigned perPhase = 0;
  unsigned maxPhase = 0;
  SmallVector<unsigned, 2> order;

  // Converts a single integer attribute into an `unsigned`.
  auto parseUInt = [&parser](const NamedAttribute &attr,
                             unsigned &value,
                             StringRef desc) -> LogicalResult {
    auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
    if (!intAttr) {
      parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
      return failure();
    }
    // IntegerAttr::getUInt() asserts unless the type is explicitly unsigned,
    // but integers written in the textual form are signless. Read the raw
    // APInt instead and validate sign and width ourselves.
    const llvm::APInt raw = intAttr.getValue();
    const bool isNegative =
        !intAttr.getType().isUnsignedInteger() && raw.isNegative();
    if (isNegative || raw.getActiveBits() > 8 * sizeof(unsigned)) {
      parser.emitError(parser.getNameLoc(), "integer value out of range for ")
          << desc;
      return failure();
    }
    value = static_cast<unsigned>(raw.getZExtValue());
    return success();
  };

  // Dispatch every dictionary entry to the field it populates.
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "vec") {
      if (parseUInt(attr, vec, "vec").failed())
        return {};
    } else if (attr.getName() == "perPhase") {
      if (parseUInt(attr, perPhase, "perPhase").failed())
        return {};
    } else if (attr.getName() == "maxPhase") {
      if (parseUInt(attr, maxPhase, "maxPhase").failed())
        return {};
    } else if (attr.getName() == "order") {
      if (parseIntArrayAttr(parser, attr, order, "order").failed())
        return {};
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(),
                                                        vec,
                                                        perPhase,
                                                        maxPhase,
                                                        order);
}
|
||||
|
||||
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
/// Prints the shared encoding as
/// `<{vec = V, perPhase = P, maxPhase = M, order = [...]}>`.
///
/// Every field accepted by TritonGPUSharedEncodingAttr::parse is emitted
/// (including `maxPhase`), and the payload is wrapped in `{ }` because the
/// parser reads it back as a DictionaryAttr.
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{"
          << "vec = " << getVec()
          << ", perPhase = " << getPerPhase()
          << ", maxPhase = " << getMaxPhase()
          << ", order = [" << getOrder() << "]"
          << "}>";
}
|
||||
|
||||
@@ -92,9 +227,9 @@ static Type getPointeeType(Type type) {
|
||||
|
||||
|
||||
// verify TritonGPU ops
|
||||
mlir::LogicalResult
|
||||
TritonGPUDialect::verifyOperationAttribute(mlir::Operation *op,
|
||||
mlir::NamedAttribute attr) {
|
||||
/// Dialect hook for verifying a TritonGPU attribute attached to `op`.
///
/// Currently a placeholder: no checks are performed and every attribute is
/// accepted.
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
                                                         NamedAttribute attr) {
  // TODO: implement the actual verification.
  return success();
}
|
||||
|
@@ -35,6 +35,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
// Now we assume:
|
||||
// contiguous = 1, order = 0, 1, 2, ...,
|
||||
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
|
||||
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
|
||||
llvm::SmallVector<unsigned> blockTileSize(rank);
|
||||
llvm::SmallVector<unsigned> order(rank);
|
||||
int remainingThreads = numThreads;
|
||||
@@ -45,8 +46,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
remainingThreads /= blockTileSize[dim];
|
||||
// TODO: will we need repetition?
|
||||
}
|
||||
Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get(
|
||||
context, threadTileSize, blockTileSize, order);
|
||||
Attribute encoding = triton::gpu::TritonGPUShardedEncodingAttr::get(
|
||||
context, threadTileSize, warpTileSize, blockTileSize, order);
|
||||
return RankedTensorType::get(shape, elementType, encoding);
|
||||
});
|
||||
|
||||
|
@@ -50,7 +50,7 @@ private:
|
||||
if (!encoding)
|
||||
return dotOp.emitError() << name << " should have encoding";
|
||||
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
|
||||
!encoding.isa<triton::gpu::TritonGPUDistributedEncodingAttr>())
|
||||
!encoding.isa<triton::gpu::TritonGPUShardedEncodingAttr>())
|
||||
return dotOp.emitError() << name << " should be of distributed layout";
|
||||
if (name == 'c')
|
||||
cLayout = encoding;
|
||||
|
Reference in New Issue
Block a user