[Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)
Co-authored-by: Jokeren <kerenzhou@openai.com> Co-authored-by: Phil Tillet <phil@openai.com> Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
@@ -57,6 +57,8 @@ unsigned getElemsPerThread(Type type) {
|
||||
return mmaLayout.getElemsPerThread(shape);
|
||||
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
|
||||
return sharedLayout.getElemsPerThread(shape);
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
return dotLayout.getElemsPerThread(shape);
|
||||
} else {
|
||||
assert(0 && "getElemsPerThread not implemented");
|
||||
return 0;
|
||||
@@ -73,6 +75,27 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
assert(mmaLayout.getVersion() == 2 &&
|
||||
"mmaLayout version = 1 is not implemented yet");
|
||||
return SmallVector<unsigned>{2, 2};
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
||||
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(parentMmaLayout.getVersion() == 2 &&
|
||||
"mmaLayout version = 1 is not implemented yet");
|
||||
auto parentShapePerCTA = getShapePerCTA(parentLayout);
|
||||
auto opIdx = dotLayout.getOpIdx();
|
||||
if (opIdx == 0) {
|
||||
return {2, 4};
|
||||
} else if (opIdx == 1) {
|
||||
return {4, 1};
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
return {};
|
||||
}
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
return {};
|
||||
}
|
||||
} else {
|
||||
assert(0 && "getSizePerThread not implemented");
|
||||
return {};
|
||||
@@ -124,6 +147,25 @@ SmallVector<unsigned> getShapePerCTA(const Attribute &layout) {
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
16 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
assert(0 && "Unexpected MMA layout version found");
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
||||
if (auto parentMmaLayout = parentLayout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(parentMmaLayout.getVersion() == 2 &&
|
||||
"mmaLayout version = 1 is not implemented yet");
|
||||
auto parentShapePerCTA = getShapePerCTA(parentLayout);
|
||||
auto opIdx = dotLayout.getOpIdx();
|
||||
if (opIdx == 0) {
|
||||
return {parentShapePerCTA[0], 16};
|
||||
} else if (opIdx == 1) {
|
||||
return {16, parentShapePerCTA[1]};
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
}
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
}
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getShapePerCTA");
|
||||
}
|
||||
@@ -136,6 +178,8 @@ SmallVector<unsigned> getOrder(const Attribute &layout) {
|
||||
blockedLayout.getOrder().end());
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
return SmallVector<unsigned>{1, 0};
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
return SmallVector<unsigned>{1, 0};
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
@@ -300,6 +344,12 @@ unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
unsigned
|
||||
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
|
||||
assert(0 && "DotOPerandEncodingAttr::getElemsPerThread not implemented");
|
||||
return 0;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Blocked Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -471,6 +521,30 @@ void SharedEncodingAttr::print(AsmPrinter &printer) const {
|
||||
<< "}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DotOperand Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Parses a dot-operand encoding written as `<{opIdx = N, parent = ATTR}>`.
///
/// The body between `<` and `>` is read as an optional attribute dictionary
/// from which the `opIdx` and `parent` entries are extracted.
///
/// \param parser the assembly parser positioned just before the `<`.
/// \param type   unused.
/// \return the constructed attribute, or a null Attribute on any syntax error
///         or when `opIdx` / `parent` is missing or malformed.
Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  NamedAttrList attrs;
  if (parser.parseOptionalAttrDict(attrs).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  // A missing or non-integer `opIdx` must yield a diagnostic; casting a null
  // or mistyped attribute unconditionally would abort the parser instead.
  auto opIdxAttr = attrs.get("opIdx").dyn_cast_or_null<IntegerAttr>();
  if (!opIdxAttr) {
    parser.emitError(parser.getCurrentLocation(),
                     "expected integer attribute 'opIdx' in dot operand "
                     "encoding");
    return {};
  }
  unsigned opIdx = opIdxAttr.getInt();

  // The layout helpers in this file assert that a parent layout is present,
  // so reject its absence here with a proper error.
  Attribute parent = attrs.get("parent");
  if (!parent) {
    parser.emitError(parser.getCurrentLocation(),
                     "expected attribute 'parent' in dot operand encoding");
    return {};
  }

  return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
                                                   parent);
}
|
||||
|
||||
/// Writes this encoding as `<{opIdx = N, parent = ATTR}>`, using the values
/// returned by getOpIdx() and getParent().
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{";
  printer << "opIdx = " << getOpIdx() << ", ";
  printer << "parent = " << getParent();
  printer << "}>";
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// InsertSliceAsyncOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -530,30 +604,6 @@ void printInsertSliceAsyncOp(OpAsmPrinter &printer,
|
||||
printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DotOperand Encoding
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Parses a dot-operand encoding written as `<{opIdx = N, parent = ATTR}>`.
///
/// The body between `<` and `>` is read as an optional attribute dictionary
/// from which the `opIdx` and `parent` entries are extracted.
///
/// \param parser the assembly parser positioned just before the `<`.
/// \param type   unused.
/// \return the constructed attribute, or a null Attribute on any syntax error
///         or when `opIdx` / `parent` is missing or malformed.
Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  NamedAttrList attrs;
  if (parser.parseOptionalAttrDict(attrs).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  // A missing or non-integer `opIdx` must yield a diagnostic; casting a null
  // or mistyped attribute unconditionally would abort the parser instead.
  auto opIdxAttr = attrs.get("opIdx").dyn_cast_or_null<IntegerAttr>();
  if (!opIdxAttr) {
    parser.emitError(parser.getCurrentLocation(),
                     "expected integer attribute 'opIdx' in dot operand "
                     "encoding");
    return {};
  }
  unsigned opIdx = opIdxAttr.getInt();

  // The layout helpers in this file assert that a parent layout is present,
  // so reject its absence here with a proper error.
  Attribute parent = attrs.get("parent");
  if (!parent) {
    parser.emitError(parser.getCurrentLocation(),
                     "expected attribute 'parent' in dot operand encoding");
    return {};
  }

  return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
                                                   parent);
}
|
||||
|
||||
/// Writes this encoding as `<{opIdx = N, parent = ATTR}>`, using the values
/// returned by getOpIdx() and getParent().
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{";
  printer << "opIdx = " << getOpIdx() << ", ";
  printer << "parent = " << getParent();
  printer << "}>";
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ASM Interface (i.e.: alias)
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -594,21 +644,32 @@ struct TritonGPUInferLayoutInterface
|
||||
|
||||
LogicalResult
|
||||
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
|
||||
Attribute &resultEncoding) const override {
|
||||
Attribute &resultEncoding,
|
||||
Optional<Location> location) const override {
|
||||
auto sliceEncoding = operandEncoding.dyn_cast<SliceEncodingAttr>();
|
||||
if (!sliceEncoding) {
|
||||
llvm::report_fatal_error(
|
||||
"ExpandDimsOp operand encoding must be SliceEncodingAttr");
|
||||
return failure();
|
||||
}
|
||||
if (sliceEncoding.getDim() != axis) {
|
||||
llvm::report_fatal_error(
|
||||
"Incompatible slice dimension for ExpandDimsOp operand");
|
||||
return failure();
|
||||
}
|
||||
if (!sliceEncoding)
|
||||
return emitOptionalError(
|
||||
location, "ExpandDimsOp operand encoding must be SliceEncodingAttr");
|
||||
if (sliceEncoding.getDim() != axis)
|
||||
return emitOptionalError(
|
||||
location, "Incompatible slice dimension for ExpandDimsOp operand");
|
||||
resultEncoding = sliceEncoding.getParent();
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Checks that an operand of a dot op carries a compatible encoding.
///
/// \param operandEncoding encoding of the operand being checked; must be a
///                        DotOperandEncodingAttr.
/// \param opIdx           operand position, compared against the encoding's
///                        own opIdx.
/// \param retEncoding     encoding of the dot result, compared against the
///                        operand encoding's parent.
/// \param location        optional location used when emitting the error.
/// \return success() when all checks pass, otherwise the result of
///         emitOptionalError with the corresponding message.
LogicalResult inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
                                 Attribute retEncoding,
                                 Optional<Location> location) const override {
  auto operandLayout = operandEncoding.dyn_cast<DotOperandEncodingAttr>();

  // Reject anything that is not a dot-operand layout up front.
  if (!operandLayout) {
    return emitOptionalError(
        location, "Dot's a/b's encoding should be of DotOperandEncodingAttr");
  }

  // The layout must describe the same operand slot it is attached to.
  if (opIdx != operandLayout.getOpIdx()) {
    return emitOptionalError(location, "Wrong opIdx");
  }

  // The layout's parent must be exactly the result encoding.
  if (retEncoding != operandLayout.getParent()) {
    return emitOptionalError(location, "Incompatible parent encoding");
  }

  return success();
}
|
||||
};
|
||||
|
||||
void TritonGPUDialect::initialize() {
|
||||
|
Reference in New Issue
Block a user