[Triton-MLIR] tt.dot operands now must have DotOperand layout; also added prefetch pass prototype (#712)

Co-authored-by: Jokeren <kerenzhou@openai.com>
Co-authored-by: Phil Tillet <phil@openai.com>
Co-authored-by: Superjomn <yanchunwei@outlook.com>
Author: Da Yan
Date: 2022-11-10 13:57:27 +08:00
Committed by: GitHub
Parent: 8832e32683
Commit: 4946167241
29 changed files with 1227 additions and 507 deletions

lib/Dialect/TritonGPU/Transforms/Combine.cpp

@@ -12,21 +12,13 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>
using namespace mlir;
-static bool isSharedLayout(Value v) {
-if (auto tensorType = v.getType().dyn_cast<RankedTensorType>()) {
-Attribute encoding = tensorType.getEncoding();
-return encoding.isa<triton::gpu::SharedEncodingAttr>();
-}
-return false;
-}
namespace {
#include "TritonGPUCombine.inc"
@@ -37,7 +29,7 @@ namespace {
// convert(blocked, dot_operand) ->
// convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation
-// this is a hueiristics to accomodate some pattern seen in fused attention
+// this is a heuristic to accommodate some patterns seen in fused attention
// kernels.
// TODO: replace this by something more generic, i.e. layout-aware CSE
class DecomposeDotOperand : public mlir::RewritePattern {
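
Illustrative IR for the decomposition described in the comment above (layout names are invented for the example; the comment names mma as the intermediate, while the hunk below now stages through a shared layout):

// before: one conversion straight into the dot operand layout
//   %y = triton_gpu.convert_layout %x   : #blocked -> #dot_operand
// after: stage through an intermediate layout
//   %t = triton_gpu.convert_layout %x   : #blocked -> #intermediate
//   %y = triton_gpu.convert_layout %t   : #intermediate -> #dot_operand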
@@ -59,9 +51,8 @@ public:
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
auto tmpType =
RankedTensorType::get(dstType.getShape(), dstType.getElementType(),
-dstType.getEncoding()
-.cast<triton::gpu::DotOperandEncodingAttr>()
-.getParent());
+triton::gpu::SharedEncodingAttr::get(
+op->getContext(), 1, 1, 1, {1, 0}));
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
convert.getLoc(), tmpType, convert.getOperand());
auto newConvert = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -87,11 +78,12 @@ public:
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
return mlir::failure();
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
auto dstType = convert.getType().cast<RankedTensorType>();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
-if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
-return mlir::failure();
+// if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
+// return mlir::failure();
// convert to the same layout -- we can delete
if (op->getResultTypes() == op->getOperandTypes()) {
rewriter.replaceOp(op, op->getOperands());
@@ -122,8 +114,8 @@ public:
rewriter.replaceOpWithNewOp<triton::gpu::InsertSliceAsyncOp>(
op, newType, insert_slice.src(), newArg.getResult(),
insert_slice.index(), insert_slice.mask(), insert_slice.other(),
-insert_slice.cache(), insert_slice.evict(),
-insert_slice.isVolatile(), insert_slice.axis());
+insert_slice.cache(), insert_slice.evict(), insert_slice.isVolatile(),
+insert_slice.axis());
return mlir::success();
}
// cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
@@ -133,7 +125,10 @@ public:
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(),
op->getResult(0).getType().cast<RankedTensorType>().getEncoding());
-auto resType = op->getResult(0).getType().cast<RankedTensorType>();
+auto origResType = op->getResult(0).getType().cast<RankedTensorType>();
+auto resType = RankedTensorType::get(
+origResType.getShape(), origResType.getElementType(),
+extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
@@ -148,8 +143,21 @@ public:
extract_slice.static_strides());
return mlir::success();
}
// cvt(cvt(x, type1), type2) -> cvt(x, type2)
if (llvm::isa<triton::gpu::ConvertLayoutOp>(arg)) {
+auto argType = arg->getOperand(0).getType().cast<RankedTensorType>();
+if (arg->getOperand(0).getDefiningOp() &&
+!argType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
+srcType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
+!dstType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
+return mlir::failure();
+}
+auto srcShared =
+srcType.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
+if (srcShared && srcShared.getVec() > 1)
+return mlir::failure();
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, op->getResultTypes().front(), arg->getOperand(0));
return mlir::success();
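
The new vec guard keeps the cvt(cvt(x, type1), type2) fold from erasing a staging conversion that matters. A sketch of that guard as a standalone predicate (hypothetical helper), under the assumption that a shared encoding with vec > 1 denotes a vectorized staging buffer the conversion pair exists to produce:

// Folding away the inner convert is only safe when the intermediate shared
// layout carries no vectorization that downstream code depends on.
static bool safeToFoldConvertPair(RankedTensorType srcTy) {
  auto shared =
      srcTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
  return !(shared && shared.getVec() > 1);
}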
@@ -253,8 +261,8 @@ public:
if (!op)
return mlir::failure();
// we don't want to rematerialize any conversion to/from shared
-if (isSharedLayout(cvt->getResults()[0]) ||
-isSharedLayout(cvt->getOperand(0)))
+if (isSharedEncoding(cvt->getResults()[0]) ||
+isSharedEncoding(cvt->getOperand(0)))
return mlir::failure();
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
@@ -325,7 +333,6 @@ public:
for (Operation *op : tmp)
sortedValues.push_back(op->getResult(0));
-// llvm::outs() << "----\n";
BlockAndValueMapping mapping;
for (Value currOperand : sortedValues) {
// unpack information
@@ -346,7 +353,6 @@ public:
newOperand->moveAfter(currOperation);
mapping.map(currOperand, newOperand);
}
-// llvm::outs() << cvt->getParentOfType<mlir::FuncOp>() << "\n";
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
return mlir::success();
}
@@ -356,8 +362,6 @@ public:
//
// -----------------------------------------------------------------------------
-// int test = 0;
class MoveConvertOutOfLoop : public mlir::RewritePattern {
public:
MoveConvertOutOfLoop(mlir::MLIRContext *context)
@@ -435,9 +439,25 @@ public:
auto users = iterArg.value().getUsers();
// check first condition
SetVector<Type> cvtTargetTypes;
-for (auto user : users)
-if (isa<triton::gpu::ConvertLayoutOp>(user))
-cvtTargetTypes.insert(user->getResults()[0].getType());
+for (auto user : users) {
+if (isa<triton::gpu::ConvertLayoutOp>(user)) {
+auto newType =
+user->getResults()[0].getType().cast<RankedTensorType>();
+auto oldType = user->getOperand(0).getType().cast<RankedTensorType>();
+if (oldType.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
+newType.getEncoding()
+.isa<triton::gpu::DotOperandEncodingAttr>()) {
+continue;
+}
+if (newType.getEncoding().isa<triton::gpu::SharedEncodingAttr>()) {
+if (newType.getEncoding()
+.cast<triton::gpu::SharedEncodingAttr>()
+.getVec() == 1)
+continue;
+}
+cvtTargetTypes.insert(newType);
+}
+}
if (cvtTargetTypes.size() != 1)
continue;
// TODO: check second condition
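
MoveConvertOutOfLoop still requires a single conversion target type before hoisting, but two kinds of users are now exempt from the count, as the expanded loop above shows. The same exemption as a standalone predicate (hypothetical helper; the rationale comments are assumptions):

static bool exemptFromHoistingCheck(RankedTensorType oldTy,
                                    RankedTensorType newTy) {
  // shared -> dot_operand conversions are skipped (presumably left in-loop
  // for the new prefetch pass prototype to handle next to tt.dot)
  if (oldTy.getEncoding().isa<triton::gpu::SharedEncodingAttr>() &&
      newTy.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
    return true;
  // conversions into an unvectorized (vec == 1) shared layout are also skipped
  if (auto shared =
          newTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>())
    return shared.getVec() == 1;
  return false;
}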
@@ -446,6 +466,7 @@ public:
continue;
}
// check
+// llvm::outs() << "replacing " << iterArg.index() << "\n";
for (auto op : iterArg.value().getUsers()) {
auto cvt = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
if (!cvt)
@@ -597,10 +618,23 @@ public:
auto oldAcc = dotOp.getOperand(2);
auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
oldAcc.getLoc(), newRetType, oldAcc);
-// convert output
+Value a = dotOp.a();
+Value b = dotOp.b();
+auto oldAType = a.getType().cast<RankedTensorType>();
+auto oldBType = b.getType().cast<RankedTensorType>();
+auto newAType = RankedTensorType::get(
+oldAType.getShape(), oldAType.getElementType(),
+triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
+newRetType.getEncoding()));
+auto newBType = RankedTensorType::get(
+oldBType.getShape(), oldBType.getElementType(),
+triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
+newRetType.getEncoding()));
+a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
+b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
auto newDot = rewriter.create<triton::DotOp>(
-dotOp.getLoc(), newRetType, dotOp.getOperand(0), dotOp.getOperand(1),
-newAcc, dotOp.allowTF32(), dotOp.transA(), dotOp.transB());
+dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.allowTF32(),
+dotOp.transA(), dotOp.transB());
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
op, oldRetType, newDot.getResult());
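
This hunk is the headline change: both tt.dot operands are converted into DotOperandEncodingAttr layouts (opIdx 0 for A, opIdx 1 for B) whose parent is the MMA layout of the result. Illustrative IR after BlockedToMMA fires, with layout names and shapes invented for the example:

// #dot_a = #triton_gpu.dot_operand<{opIdx = 0, parent = #mma}>
// #dot_b = #triton_gpu.dot_operand<{opIdx = 1, parent = #mma}>
//   %a = triton_gpu.convert_layout %a0 : #blocked -> #dot_a
//   %b = triton_gpu.convert_layout %b0 : #blocked -> #dot_b
//   %d = tt.dot %a, %b, %acc ... -> tensor<..., #mma>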
@@ -623,7 +657,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<SimplifyConversion>(context);
-patterns.add<DecomposeDotOperand>(context);
+// patterns.add<DecomposeDotOperand>(context);
patterns.add<RematerializeBackward>(context);
patterns.add<RematerializeForward>(context);
patterns.add<MoveConvertOutOfLoop>(context);
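
With BlockedToMMA now emitting the operand conversions itself, DecomposeDotOperand is disabled above rather than removed. For context, a sketch of how such a pattern set is typically applied inside the pass's runOnOperation (standard MLIR greedy driver; `m` is assumed to be the module being rewritten, and this driver code is not part of the diff):

// Freeze the patterns and apply them to a fixed point.
if (mlir::applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
  signalPassFailure();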