[Triton-MLIR] tt.dot
operands now must have DotOperand layout; also added prefetch pass prototype (#712)
Co-authored-by: Jokeren <kerenzhou@openai.com> Co-authored-by: Phil Tillet <phil@openai.com> Co-authored-by: Superjomn <yanchunwei@outlook.com>
This commit is contained in:
@@ -191,6 +191,20 @@ mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
|
||||
// type is the same as the accumulator
|
||||
auto accTy = operands[2].getType().cast<RankedTensorType>();
|
||||
inferredReturnTypes.push_back(accTy);
|
||||
|
||||
// verify encodings
|
||||
auto aEnc = operands[0].getType().cast<RankedTensorType>().getEncoding();
|
||||
auto bEnc = operands[1].getType().cast<RankedTensorType>().getEncoding();
|
||||
auto retEnc = accTy.getEncoding();
|
||||
if (aEnc) {
|
||||
assert(bEnc);
|
||||
Dialect &dialect = aEnc.getDialect();
|
||||
auto interface = dyn_cast<DialectInferLayoutInterface>(&dialect);
|
||||
if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed())
|
||||
return mlir::failure();
|
||||
if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed())
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -244,7 +258,7 @@ OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
||||
//-- ExpandDimsOp --
|
||||
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
MLIRContext *context, Optional<Location> loc, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
// infer shape
|
||||
@@ -260,11 +274,9 @@ mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
|
||||
Dialect &dialect = argEncoding.getDialect();
|
||||
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
|
||||
if (inferLayoutInterface
|
||||
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding)
|
||||
.failed()) {
|
||||
llvm::report_fatal_error("failed to infer layout for ExpandDimsOp");
|
||||
return mlir::failure();
|
||||
}
|
||||
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc)
|
||||
.failed())
|
||||
return emitOptionalError(loc, "failed to infer layout for ExpandDimsOp");
|
||||
}
|
||||
// create type
|
||||
auto argEltTy = argTy.getElementType();
|
||||
|
@@ -48,7 +48,8 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
|
||||
<< " has more than that";
|
||||
if ((numElements & (numElements - 1)) != 0)
|
||||
return op->emitError("Number of elements must be power-of-two, but ")
|
||||
<< *op << " doesn't follow the rule";
|
||||
<< *op << " doesn't follow the rule (" << numElements << ")"
|
||||
<< " elements";
|
||||
}
|
||||
}
|
||||
for (auto opType : op->getResultTypes()) {
|
||||
@@ -62,7 +63,8 @@ mlir::LogicalResult mlir::OpTrait::impl::verifyTensorSize(Operation *op) {
|
||||
<< " has more than that";
|
||||
if ((numElements & (numElements - 1)) != 0)
|
||||
return op->emitError("Number of elements must be power-of-two, but ")
|
||||
<< *op << " doesn't follow the rule";
|
||||
<< *op << " doesn't follow the rule (" << numElements << ")"
|
||||
<< " elements";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
|
Reference in New Issue
Block a user