Merge branch 'mlir-rewrite' of https://github.com/daadaada/mlir-rewrite into mlir-rewrite

This commit is contained in:
Yan Da
2022-06-01 10:59:20 +08:00
4 changed files with 203 additions and 36 deletions

View File

@@ -8,12 +8,14 @@ class TritonGPU_Attr<string name, list<Trait> traits = []>
: AttrDef<TritonGPU_Dialect, name, traits>;
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
let mnemonic = "shared (memory) encoding";
let mnemonic = "shared_layout";
let description = [{
An encoding for tensors whose elements may be simultaneously accessed by different warps in the programs, via shared memory.
An encoding for tensors whose elements may be simultaneously accessed by
different warps in the programs, via shared memory.
In order to avoid shared memory bank conflicts, elements may be stored in a swizzled layout.
In order to avoid shared memory bank conflicts, elements may be stored in a
swizzled layout.
For example, a swizzled row-major layout would store data as follows:
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
@@ -29,10 +31,11 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
And the associated TritonGPU MLIR
```mlir
#SMEM = #triton_gpu.encoding<{
#SMEM = #triton_gpu.shared_layout<{
vec = 2,
perPhase = 2,
maxPhase = 4
maxPhase = 4,
order = [1, 0]
}>
```
}];
@@ -40,12 +43,13 @@ And the associated TritonGPU MLIR
let parameters = (
ins
// swizzle info
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
);
}
def TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
let mnemonic = "coalesced encoding";
def TritonGPUShardedEncodingAttr : TritonGPU_Attr<"TritonGPUShardedEncoding"> {
let mnemonic = "sharded_layout";
let description = [{
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
@@ -70,7 +74,7 @@ size } ....
A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63]
And the associated TritonGPU MLIR
#SMEM = #triton_gpu.encoding<{
#LAYOUT = #triton_gpu.sharded_layout<{
threadTileSize = {2, 2}
blockTileSize = {32, 8}
}>
@@ -81,28 +85,55 @@ And the associated TritonGPU MLIR
let parameters = (
ins
ArrayRefParameter<"unsigned">:$threadTileSize,
ArrayRefParameter<"unsigned">:$blockTileSize,
// TODO: should we rename this as laneTileSize?
ArrayRefParameter<
"unsigned",
/*desc*/"size of a tile that is held by a thread"
>:$threadTileSize,
ArrayRefParameter<
"unsigned",
"size of a tile that is held by a warp"
>:$warpTileSize,
ArrayRefParameter<
"unsigned",
"size of a tile that is held by a thread block"
>:$blockTileSize,
// // TODO: It seems that we don't need this (because we can re-compute this)
// ArrayRefParameter<"unsigned">:$reptitions,
// fastest-changing axis first
ArrayRefParameter<"unsigned">:$order
ArrayRefParameter<
"unsigned",
"order of axes by the rate of changing"
>:$order
// "AffineMap":$threadOrdering,
// "AffineMap":warpOrdering,
// "AffineMap":$blockOrdering,
);
// let genVerifyDecl = 1;
}
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
let mnemonic = "mma encoding";
let mnemonic = "mma_layout";
let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];
let parameters = (
ins
// only used by Volta mma.884
ArrayRefParameter<"unsigned">:$fragmentPerWarp,
// aka shapeOfInstr (e.g., {16,8,16})
ArrayRefParameter<"unsigned">:$shapePerWarp,
// TODO: should we rename this as warpTileSize? (consistent naming with Distributed layout)
ArrayRefParameter<"unsigned">:$warpPerTile,
// TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout)
ArrayRefParameter<"unsigned">:$shapePerTile,
// TODO: should Distributed layout also
ArrayRefParameter<"unsigned">:$reptitions,
ArrayRefParameter<"unsigned">:$contigPerThread
// "AffineMap":$warpOrdering,
// "AffineMap":$blockOrdering
);
// let genVerifyDecl = 1;

View File

@@ -4,46 +4,181 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
using namespace mlir;
using namespace mlir::triton::gpu;
// parse an array of integers
static LogicalResult parseIntArrayAttr(AsmParser &parser,
const NamedAttribute &attr,
SmallVector<unsigned, 2> &res,
StringRef desc) {
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr) {
parser.emitError(parser.getNameLoc(), "expected an array for ")
<< desc;
return failure();
}
for (Attribute i : arrayAttr) {
auto intAttr = i.dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
<< desc;
return failure();
}
res.push_back(intAttr.getUInt());
}
return success();
};
//===----------------------------------------------------------------------===//
// Attribute methods
//===----------------------------------------------------------------------===//
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
mlir::Attribute
TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
llvm_unreachable("Not implemented");
Attribute
TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};
SmallVector<unsigned, 2> threadTileSize;
SmallVector<unsigned, 2> warpTileSize;
SmallVector<unsigned, 2> blockTileSize;
SmallVector<unsigned, 2> order;
// parse an array of integers
// auto parseIntArrayAttr = [&parser](const NamedAttribute &attr,
// SmallVector<unsigned, 2> &res,
// StringRef desc) -> LogicalResult {
// auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
// if (!arrayAttr) {
// parser.emitError(parser.getNameLoc(), "expected an array for ")
// << desc;
// return failure();
// }
// for (Attribute i : arrayAttr) {
// auto intAttr = i.dyn_cast<IntegerAttr>();
// if (!intAttr) {
// parser.emitError(parser.getNameLoc(), "expected an integer value in ")
// << desc;
// return failure();
// }
// res.push_back(intAttr.getUInt());
// }
// return success();
// };
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "threadTileSize") {
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
return {};
} else if (attr.getName() == "warpTileSize") {
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed())
return {};
} else if (attr.getName() == "blockTileSize") {
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUShardedEncodingAttr>(parser.getContext(),
threadTileSize,
warpTileSize,
blockTileSize,
order);
}
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const {
printer << "<"
<< "threadTileSize = " << getThreadTileSize()
<< ", blockTileSize = " << getBlockTileSize()
<< ", order = " << getOrder()
<< "threadTileSize = [" << getThreadTileSize() << "]"
<< ", warpTileSize = [" << getWarpTileSize() << "]"
<< ", blockTileSize = [" << getBlockTileSize() << "]"
<< ", order = [" << getOrder() << "]"
<< ">";
}
mlir::Attribute
TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
Attribute
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
llvm_unreachable("Not implemented");
}
void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
llvm_unreachable("Not implemented");
}
mlir::Attribute
TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
llvm_unreachable("Not implemented");
Attribute
TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
// Parse the data as a dictionary
DictionaryAttr dict;
if (parser.parseAttribute(dict).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned vec = 0;
unsigned perPhase = 0;
unsigned maxPhase = 0;
SmallVector<unsigned, 2> order;
auto parseUInt = [&parser](const NamedAttribute &attr,
unsigned &value,
StringRef desc) -> LogicalResult {
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
if (!intAttr) {
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
return failure();
}
value = intAttr.getUInt();
return success();
};
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "vec") {
if (parseUInt(attr, vec, "vec").failed())
return {};
} else if (attr.getName() == "perPhase") {
if (parseUInt(attr, perPhase, "perPhase").failed())
return {};
} else if (attr.getName() == "maxPhase") {
if (parseUInt(attr, maxPhase, "maxPhase").failed())
return {};
} else if (attr.getName() == "order") {
if (parseIntArrayAttr(parser, attr, order, "order").failed())
return {};
} else {
parser.emitError(parser.getNameLoc(), "unexpected key: ")
<< attr.getName().strref();
return {};
}
}
return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(),
vec,
perPhase,
maxPhase,
order);
}
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
printer << "<"
// << "threadTileSize = " << getThreadTileSize()
// << ", blockTileSize = " << getBlockTileSize()
// << ", order = " << getOrder()
<< "vec = " << getVec()
<< ", perPhase = " << getPerPhase()
<< ", order = [" << getOrder() << "]"
<< ">";
}
@@ -92,9 +227,9 @@ static Type getPointeeType(Type type) {
// verify TritonGPU ops
mlir::LogicalResult
TritonGPUDialect::verifyOperationAttribute(mlir::Operation *op,
mlir::NamedAttribute attr) {
LogicalResult
TritonGPUDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attr) {
// TODO: fill this.
return success();
}

View File

@@ -35,6 +35,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// Now we assume:
// contiguous = 1, order = 0, 1, 2, ...,
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
llvm::SmallVector<unsigned> blockTileSize(rank);
llvm::SmallVector<unsigned> order(rank);
int remainingThreads = numThreads;
@@ -45,8 +46,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
remainingThreads /= blockTileSize[dim];
// TODO: will we need repetition?
}
Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get(
context, threadTileSize, blockTileSize, order);
Attribute encoding = triton::gpu::TritonGPUShardedEncodingAttr::get(
context, threadTileSize, warpTileSize, blockTileSize, order);
return RankedTensorType::get(shape, elementType, encoding);
});

View File

@@ -50,7 +50,7 @@ private:
if (!encoding)
return dotOp.emitError() << name << " should have encoding";
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
!encoding.isa<triton::gpu::TritonGPUDistributedEncodingAttr>())
!encoding.isa<triton::gpu::TritonGPUShardedEncodingAttr>())
return dotOp.emitError() << name << " should be of distributed layout";
if (name == 'c')
cLayout = encoding;