More progress on Triton=>TritonGPU conversion (works for matmul)

This commit is contained in:
Yan Da
2022-05-09 21:19:53 +08:00
parent 0c5319eed9
commit 96876a46d1
3 changed files with 64 additions and 32 deletions

View File

@@ -4,8 +4,8 @@
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
// #include "mlir/IR/BlockAndValueMapping.h"
#include "../PassDetail.h"
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
using namespace mlir;
using namespace mlir::triton;
@@ -155,9 +155,31 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
/// Rewrites a `triton::DotOp` so that its result has the type produced by the
/// type converter and its `a` / `b` operands carry a
/// `TritonGPUSharedEncodingAttr` encoding.
///
/// Operands `a` and `b` that are not already in the shared encoding are routed
/// through a `triton::gpu::ConvertLayoutOp`; operand `c` and the `allowTF32`
/// attribute are forwarded unchanged.
///
/// \param op       the original dot operation being converted.
/// \param adaptor  accessor for the already-converted operands of `op`.
/// \param rewriter rewriter used to create the new ops and replace `op`.
/// \returns `success()` once `op` has been replaced; failure when the result
///          type cannot be converted, or when `a` / `b` is not a ranked tensor
///          or has no encoding attribute.
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const override {
  Type retType = getTypeConverter()->convertType(op.getType());
  if (!retType)
    return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");

  // Use dyn_cast rather than cast so an unexpected operand type fails the
  // match instead of tripping an assertion.
  auto aType = adaptor.a().getType().dyn_cast<RankedTensorType>();
  auto bType = adaptor.b().getType().dyn_cast<RankedTensorType>();
  if (!aType || !bType)
    return rewriter.notifyMatchFailure(op, "a & b must be ranked tensors");
  if (!aType.getEncoding() || !bType.getEncoding())
    return rewriter.notifyMatchFailure(op, "a & b must carry an encoding");

  // a & b must be of smem (shared) layout: insert a layout conversion for any
  // operand whose encoding is not already the shared encoding.
  auto toSharedLayout = [&](Value operand, RankedTensorType type) -> Value {
    if (type.getEncoding().isa<triton::gpu::TritonGPUSharedEncodingAttr>())
      return operand;
    // NOTE(review): the (1, 1, 1) parameters are carried over unchanged from
    // the existing code; their meaning is not visible here -- TODO confirm.
    Attribute sharedEncoding =
        triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
    auto dstType = RankedTensorType::get(type.getShape(),
                                         type.getElementType(), sharedEncoding);
    return rewriter.create<triton::gpu::ConvertLayoutOp>(operand.getLoc(),
                                                         dstType, operand);
  };
  Value a = toSharedLayout(adaptor.a(), aType);
  Value b = toSharedLayout(adaptor.b(), bType);

  rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, adaptor.c(),
                                             adaptor.allowTF32());
  return success();
}
};
@@ -182,7 +204,7 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::StoreOp>(
auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
);
return success();
@@ -220,26 +242,24 @@ void populateTritonPatterns(
//
// SCF patterns
//
// This is borrowed from ConvertForOpTypes in
// SCF/Transforms/StructuralTypeConversions.cpp
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
// Ref: ConvertForOpTypes
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> newResultTypes;
for (Type type : op.getResultTypes()) {
Type newType = typeConverter->convertType(type);
if (!newType)
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
newResultTypes.push_back(newType);
}
auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
newOp.getLoopBody().end());
// Now, update all the types.
// Convert the type of the entry block of the ForOp's body.
// Convert the types of block arguments within the given region. This
// replaces each block with a new block containing the updated signature. The
// entry block may have a special conversion if `entryConversion` is
// provided. On success, the new entry block to the region is returned for
// convenience. Otherwise, failure is returned.
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
*getTypeConverter()))) {
return rewriter.notifyMatchFailure(op, "could not convert body types");
@@ -248,11 +268,17 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
// a BlockAndValueMapping, but this seems a bit more direct.
newOp->setOperands(adaptor.getOperands());
// Update the result types to the new converted types.
SmallVector<Type> newResultTypes;
for (Type type : op.getResultTypes()) {
Type newType = typeConverter->convertType(type);
if (!newType)
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
newResultTypes.push_back(newType);
}
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
std::get<0>(t).setType(std::get<1>(t));
rewriter.replaceOp(op, newOp.getResults());
return success();
return success();
}
@@ -277,8 +303,7 @@ void populateSCFPatterns(
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
) {
MLIRContext *context = patterns.getContext();
patterns.add<SCFForPattern,
SCFYieldPattern
patterns.add<SCFYieldPattern, SCFForPattern
>(typeConverter, context);
}