Merge branch 'mlir-rewrite' of https://github.com/daadaada/mlir-rewrite into mlir-rewrite
This commit is contained in:
@@ -8,12 +8,14 @@ class TritonGPU_Attr<string name, list<Trait> traits = []>
|
|||||||
: AttrDef<TritonGPU_Dialect, name, traits>;
|
: AttrDef<TritonGPU_Dialect, name, traits>;
|
||||||
|
|
||||||
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
|
def TritonGPUSharedEncodingAttr : TritonGPU_Attr<"TritonGPUSharedEncoding"> {
|
||||||
let mnemonic = "shared (memory) encoding";
|
let mnemonic = "shared_layout";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
An encoding for tensors whose elements may be simultaneously accessed by different warps in the programs, via shared memory.
|
An encoding for tensors whose elements may be simultaneously accessed by
|
||||||
|
different warps in the programs, via shared memory.
|
||||||
|
|
||||||
In order to avoid shared memory bank conflicts, elements may be stored in a swizzled layout.
|
In order to avoid shared memory bank conflicts, elements may be stored in a
|
||||||
|
swizzled layout.
|
||||||
For example, a swizzled row-major layout stores would store data as follows:
|
For example, a swizzled row-major layout stores would store data as follows:
|
||||||
|
|
||||||
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
|
A_{0, 0} A_{0, 1} A_{0, 2} A_{0, 3} ... [phase 0] \ per_phase = 2
|
||||||
@@ -29,10 +31,11 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... [phase 1] /
|
|||||||
And the associated TritonGPU MLIR
|
And the associated TritonGPU MLIR
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
#SMEM = #triton_gpu.encoding<{
|
#SMEM = #triton_gpu.shared_layout<{
|
||||||
vec = 2,
|
vec = 2,
|
||||||
perPhase = 2,
|
perPhase = 2,
|
||||||
maxPhase = 4
|
maxPhase = 4,
|
||||||
|
order = [1, 0]
|
||||||
}>
|
}>
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
@@ -40,12 +43,13 @@ And the associated TritonGPU MLIR
|
|||||||
let parameters = (
|
let parameters = (
|
||||||
ins
|
ins
|
||||||
// swizzle info
|
// swizzle info
|
||||||
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase
|
"unsigned":$vec, "unsigned":$perPhase, "unsigned":$maxPhase,
|
||||||
|
ArrayRefParameter<"unsigned", "order of axes by the rate of changing">:$order
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
def TritonGPUDistributedEncodingAttr : TritonGPU_Attr<"TritonGPUDistributedEncoding"> {
|
def TritonGPUShardedEncodingAttr : TritonGPU_Attr<"TritonGPUShardedEncoding"> {
|
||||||
let mnemonic = "coalesced encoding";
|
let mnemonic = "sharded_layout";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
|
An encoding where each warp owns a contiguous portion of the target tensor. This is typically the kind of data layout
|
||||||
@@ -70,7 +74,7 @@ size } ....
|
|||||||
A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63]
|
A_{63, 0}[T60] A_{63, 1}[T60] ... A_{63, 6}[T63] A_{63, 7}[T63] A_{63, 8}[T60] A_{63, 9}[T60] ... A_{63, 14}[T63] A_{63, 15}[T63]
|
||||||
|
|
||||||
And the associated TritonGPU MLIR
|
And the associated TritonGPU MLIR
|
||||||
#SMEM = #triton_gpu.encoding<{
|
#LAYOUT = #triton_gpu.sharded_layout<{
|
||||||
threadTileSize = {2, 2}
|
threadTileSize = {2, 2}
|
||||||
blockTileSize = {32, 8}
|
blockTileSize = {32, 8}
|
||||||
}>
|
}>
|
||||||
@@ -81,28 +85,55 @@ And the associated TritonGPU MLIR
|
|||||||
|
|
||||||
let parameters = (
|
let parameters = (
|
||||||
ins
|
ins
|
||||||
ArrayRefParameter<"unsigned">:$threadTileSize,
|
// TODO: should we rename this as laneTileSize?
|
||||||
ArrayRefParameter<"unsigned">:$blockTileSize,
|
ArrayRefParameter<
|
||||||
|
"unsigned",
|
||||||
|
/*desc*/"size of a tile that is holded by a thread"
|
||||||
|
>:$threadTileSize,
|
||||||
|
ArrayRefParameter<
|
||||||
|
"unsigned",
|
||||||
|
"size of the a tile that is holded by a warp"
|
||||||
|
>:$warpTileSize,
|
||||||
|
ArrayRefParameter<
|
||||||
|
"unsigned",
|
||||||
|
"size of a tile that is holded by a thread block"
|
||||||
|
>:$blockTileSize,
|
||||||
|
// // TODO: It seems that we don't need this (because we can re-compute this)
|
||||||
|
// ArrayRefParameter<"unsigned">:$reptitions,
|
||||||
// fastest-changing axis first
|
// fastest-changing axis first
|
||||||
ArrayRefParameter<"unsigned">:$order
|
ArrayRefParameter<
|
||||||
|
"unsigned",
|
||||||
|
"order of axes by the rate of changing"
|
||||||
|
>:$order
|
||||||
|
// "AffineMap":$threadOrdering,
|
||||||
|
// "AffineMap":warpOrdering,
|
||||||
|
// "AffineMap":$blockOrdering,
|
||||||
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// let genVerifyDecl = 1;
|
// let genVerifyDecl = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
|
def TritonGPUMmaEncodingAttr : TritonGPU_Attr<"TritonGPUMmaEncoding"> {
|
||||||
let mnemonic = "mma encoding";
|
let mnemonic = "mma_layout";
|
||||||
|
|
||||||
let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];
|
let description = [{TODO: I think we may be able to implement it as a special-case of Distributed encoding with maybe one more warpTileSize attribute!}];
|
||||||
|
|
||||||
let parameters = (
|
let parameters = (
|
||||||
ins
|
ins
|
||||||
|
// only used by Volta mma.884
|
||||||
ArrayRefParameter<"unsigned">:$fragmentPerWarp,
|
ArrayRefParameter<"unsigned">:$fragmentPerWarp,
|
||||||
|
// aka shapeOfInstr (e.g., {16,8,16})
|
||||||
ArrayRefParameter<"unsigned">:$shapePerWarp,
|
ArrayRefParameter<"unsigned">:$shapePerWarp,
|
||||||
|
// TODO: should we rename this as warpTileSize? (consistent naming with Distributed layout)
|
||||||
ArrayRefParameter<"unsigned">:$warpPerTile,
|
ArrayRefParameter<"unsigned">:$warpPerTile,
|
||||||
|
// TODO: should we rename this as blockTileSize? (consistent naming with Distributed layout)
|
||||||
ArrayRefParameter<"unsigned">:$shapePerTile,
|
ArrayRefParameter<"unsigned">:$shapePerTile,
|
||||||
|
// TODO: should Distributed layout also
|
||||||
ArrayRefParameter<"unsigned">:$reptitions,
|
ArrayRefParameter<"unsigned">:$reptitions,
|
||||||
ArrayRefParameter<"unsigned">:$contigPerThread
|
ArrayRefParameter<"unsigned">:$contigPerThread
|
||||||
|
// "AffineMap":$warpOrdering,
|
||||||
|
// "AffineMap":$blockOrdering
|
||||||
);
|
);
|
||||||
|
|
||||||
// let genVerifyDecl = 1;
|
// let genVerifyDecl = 1;
|
||||||
|
@@ -4,46 +4,181 @@
|
|||||||
|
|
||||||
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
using namespace mlir::triton::gpu;
|
using namespace mlir::triton::gpu;
|
||||||
|
|
||||||
|
// parse an array of integers
|
||||||
|
static LogicalResult parseIntArrayAttr(AsmParser &parser,
|
||||||
|
const NamedAttribute &attr,
|
||||||
|
SmallVector<unsigned, 2> &res,
|
||||||
|
StringRef desc) {
|
||||||
|
auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
|
||||||
|
if (!arrayAttr) {
|
||||||
|
parser.emitError(parser.getNameLoc(), "expected an array for ")
|
||||||
|
<< desc;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
for (Attribute i : arrayAttr) {
|
||||||
|
auto intAttr = i.dyn_cast<IntegerAttr>();
|
||||||
|
if (!intAttr) {
|
||||||
|
parser.emitError(parser.getNameLoc(), "expected an integer value in ")
|
||||||
|
<< desc;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
res.push_back(intAttr.getUInt());
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Attribute methods
|
// Attribute methods
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
#define GET_ATTRDEF_CLASSES
|
#define GET_ATTRDEF_CLASSES
|
||||||
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||||
|
|
||||||
mlir::Attribute
|
Attribute
|
||||||
TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
|
TritonGPUShardedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||||
llvm_unreachable("Not implemented");
|
if (parser.parseLess().failed())
|
||||||
|
return {};
|
||||||
|
// Parse the data as a dictionary
|
||||||
|
DictionaryAttr dict;
|
||||||
|
if (parser.parseAttribute(dict).failed())
|
||||||
|
return {};
|
||||||
|
if (parser.parseGreater().failed())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
SmallVector<unsigned, 2> threadTileSize;
|
||||||
|
SmallVector<unsigned, 2> warpTileSize;
|
||||||
|
SmallVector<unsigned, 2> blockTileSize;
|
||||||
|
SmallVector<unsigned, 2> order;
|
||||||
|
|
||||||
|
// parse an array of integers
|
||||||
|
// auto parseIntArrayAttr = [&parser](const NamedAttribute &attr,
|
||||||
|
// SmallVector<unsigned, 2> &res,
|
||||||
|
// StringRef desc) -> LogicalResult {
|
||||||
|
// auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
|
||||||
|
// if (!arrayAttr) {
|
||||||
|
// parser.emitError(parser.getNameLoc(), "expected an array for ")
|
||||||
|
// << desc;
|
||||||
|
// return failure();
|
||||||
|
// }
|
||||||
|
// for (Attribute i : arrayAttr) {
|
||||||
|
// auto intAttr = i.dyn_cast<IntegerAttr>();
|
||||||
|
// if (!intAttr) {
|
||||||
|
// parser.emitError(parser.getNameLoc(), "expected an integer value in ")
|
||||||
|
// << desc;
|
||||||
|
// return failure();
|
||||||
|
// }
|
||||||
|
// res.push_back(intAttr.getUInt());
|
||||||
|
// }
|
||||||
|
// return success();
|
||||||
|
// };
|
||||||
|
|
||||||
|
for (const NamedAttribute &attr : dict) {
|
||||||
|
if (attr.getName() == "threadTileSize") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, threadTileSize, "thread tile size").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "warpTileSize") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, warpTileSize, "warp tile size").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "blockTileSize") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, blockTileSize, "block tile size").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "order") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, order, "order").failed())
|
||||||
|
return {};
|
||||||
|
} else {
|
||||||
|
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||||
|
<< attr.getName().strref();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parser.getChecked<TritonGPUShardedEncodingAttr>(parser.getContext(),
|
||||||
|
threadTileSize,
|
||||||
|
warpTileSize,
|
||||||
|
blockTileSize,
|
||||||
|
order);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
void TritonGPUShardedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||||
printer << "<"
|
printer << "<"
|
||||||
<< "threadTileSize = " << getThreadTileSize()
|
<< "threadTileSize = [" << getThreadTileSize() << "]"
|
||||||
<< ", blockTileSize = " << getBlockTileSize()
|
<< ", warpTileSize = [" << getWarpTileSize() << "]"
|
||||||
<< ", order = " << getOrder()
|
<< ", blockTileSize = [" << getBlockTileSize() << "]"
|
||||||
|
<< ", order = [" << getOrder() << "]"
|
||||||
<< ">";
|
<< ">";
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Attribute
|
Attribute
|
||||||
TritonGPUMmaEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
|
TritonGPUMmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||||
llvm_unreachable("Not implemented");
|
llvm_unreachable("Not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUMmaEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
void TritonGPUMmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||||
llvm_unreachable("Not implemented");
|
llvm_unreachable("Not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Attribute
|
Attribute
|
||||||
TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
|
TritonGPUSharedEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||||
llvm_unreachable("Not implemented");
|
if (parser.parseLess().failed())
|
||||||
|
return {};
|
||||||
|
// Parse the data as a dictionary
|
||||||
|
DictionaryAttr dict;
|
||||||
|
if (parser.parseAttribute(dict).failed())
|
||||||
|
return {};
|
||||||
|
if (parser.parseGreater().failed())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
unsigned vec = 0;
|
||||||
|
unsigned perPhase = 0;
|
||||||
|
unsigned maxPhase = 0;
|
||||||
|
SmallVector<unsigned, 2> order;
|
||||||
|
|
||||||
|
auto parseUInt = [&parser](const NamedAttribute &attr,
|
||||||
|
unsigned &value,
|
||||||
|
StringRef desc) -> LogicalResult {
|
||||||
|
auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
|
||||||
|
if (!intAttr) {
|
||||||
|
parser.emitError(parser.getNameLoc(), "expected an integer ") << desc;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
value = intAttr.getUInt();
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const NamedAttribute &attr : dict) {
|
||||||
|
if (attr.getName() == "vec") {
|
||||||
|
if (parseUInt(attr, vec, "vec").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "perPhase") {
|
||||||
|
if (parseUInt(attr, perPhase, "perPhase").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "maxPhase") {
|
||||||
|
if (parseUInt(attr, maxPhase, "maxPhase").failed())
|
||||||
|
return {};
|
||||||
|
} else if (attr.getName() == "order") {
|
||||||
|
if (parseIntArrayAttr(parser, attr, order, "order").failed())
|
||||||
|
return {};
|
||||||
|
} else {
|
||||||
|
parser.emitError(parser.getNameLoc(), "unexpected key: ")
|
||||||
|
<< attr.getName().strref();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parser.getChecked<TritonGPUSharedEncodingAttr>(parser.getContext(),
|
||||||
|
vec,
|
||||||
|
perPhase,
|
||||||
|
maxPhase,
|
||||||
|
order);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
void TritonGPUSharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||||
printer << "<"
|
printer << "<"
|
||||||
// << "threadTileSize = " << getThreadTileSize()
|
<< "vec = " << getVec()
|
||||||
// << ", blockTileSize = " << getBlockTileSize()
|
<< ", perPhase = " << getPerPhase()
|
||||||
// << ", order = " << getOrder()
|
<< ", order = [" << getOrder() << "]"
|
||||||
<< ">";
|
<< ">";
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,9 +227,9 @@ static Type getPointeeType(Type type) {
|
|||||||
|
|
||||||
|
|
||||||
// verify TritonGPU ops
|
// verify TritonGPU ops
|
||||||
mlir::LogicalResult
|
LogicalResult
|
||||||
TritonGPUDialect::verifyOperationAttribute(mlir::Operation *op,
|
TritonGPUDialect::verifyOperationAttribute(Operation *op,
|
||||||
mlir::NamedAttribute attr) {
|
NamedAttribute attr) {
|
||||||
// TODO: fill this.
|
// TODO: fill this.
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@@ -35,6 +35,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
// Now we assume:
|
// Now we assume:
|
||||||
// contiguous = 1, order = 0, 1, 2, ...,
|
// contiguous = 1, order = 0, 1, 2, ...,
|
||||||
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
|
llvm::SmallVector<unsigned> threadTileSize(rank, 1); // naive layout
|
||||||
|
llvm::SmallVector<unsigned> warpTileSize(rank, 1);
|
||||||
llvm::SmallVector<unsigned> blockTileSize(rank);
|
llvm::SmallVector<unsigned> blockTileSize(rank);
|
||||||
llvm::SmallVector<unsigned> order(rank);
|
llvm::SmallVector<unsigned> order(rank);
|
||||||
int remainingThreads = numThreads;
|
int remainingThreads = numThreads;
|
||||||
@@ -45,8 +46,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
remainingThreads /= blockTileSize[dim];
|
remainingThreads /= blockTileSize[dim];
|
||||||
// TODO: will we need repetition?
|
// TODO: will we need repetition?
|
||||||
}
|
}
|
||||||
Attribute encoding = triton::gpu::TritonGPUDistributedEncodingAttr::get(
|
Attribute encoding = triton::gpu::TritonGPUShardedEncodingAttr::get(
|
||||||
context, threadTileSize, blockTileSize, order);
|
context, threadTileSize, warpTileSize, blockTileSize, order);
|
||||||
return RankedTensorType::get(shape, elementType, encoding);
|
return RankedTensorType::get(shape, elementType, encoding);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@@ -50,7 +50,7 @@ private:
|
|||||||
if (!encoding)
|
if (!encoding)
|
||||||
return dotOp.emitError() << name << " should have encoding";
|
return dotOp.emitError() << name << " should have encoding";
|
||||||
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
|
if (!encoding.isa<triton::gpu::TritonGPUMmaEncodingAttr>() &&
|
||||||
!encoding.isa<triton::gpu::TritonGPUDistributedEncodingAttr>())
|
!encoding.isa<triton::gpu::TritonGPUShardedEncodingAttr>())
|
||||||
return dotOp.emitError() << name << " should be of distributed layout";
|
return dotOp.emitError() << name << " should be of distributed layout";
|
||||||
if (name == 'c')
|
if (name == 'c')
|
||||||
cLayout = encoding;
|
cLayout = encoding;
|
||||||
|
Reference in New Issue
Block a user