More progress on Triton=>TritonGPU conversion (works for matmul)
This commit is contained in:
@@ -4,8 +4,8 @@
|
|||||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
// #include "mlir/IR/BlockAndValueMapping.h"
|
|
||||||
#include "../PassDetail.h"
|
#include "../PassDetail.h"
|
||||||
|
#include <llvm-6.0/llvm/Support/ErrorHandling.h>
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::triton;
|
using namespace mlir::triton;
|
||||||
@@ -155,9 +155,31 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
|||||||
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Type retType = getTypeConverter()->convertType(op.getType());
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
rewriter.replaceOpWithNewOp<triton::DotOp>(
|
// a & b must be of smem layout
|
||||||
op, retType, adaptor.a(), adaptor.b(), adaptor.c(), adaptor.allowTF32()
|
auto aType = adaptor.a().getType().cast<RankedTensorType>();
|
||||||
|
auto bType = adaptor.b().getType().cast<RankedTensorType>();
|
||||||
|
Attribute aEncoding = aType.getEncoding();
|
||||||
|
Attribute bEncoding = bType.getEncoding();
|
||||||
|
if (!aEncoding || !bEncoding)
|
||||||
|
return failure();
|
||||||
|
Value a = adaptor.a();
|
||||||
|
Value b = adaptor.b();
|
||||||
|
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||||
|
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
|
||||||
|
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
|
||||||
|
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||||
|
}
|
||||||
|
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||||
|
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1);
|
||||||
|
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
|
||||||
|
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||||
|
}
|
||||||
|
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||||
|
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
|
||||||
);
|
);
|
||||||
|
// auto newDot = rewriter.create<triton::DotOp>(op.getLoc(), retType,
|
||||||
|
// a, b, adaptor.c(), adaptor.allowTF32());
|
||||||
|
// rewriter.replaceOp(op, {newDot});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -182,7 +204,7 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
|||||||
|
|
||||||
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
|
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
|
||||||
);
|
);
|
||||||
return success();
|
return success();
|
||||||
@@ -220,26 +242,24 @@ void populateTritonPatterns(
|
|||||||
//
|
//
|
||||||
// SCF patterns
|
// SCF patterns
|
||||||
//
|
//
|
||||||
|
// This is borrowed from ConvertForOpTypes in
|
||||||
|
// SCF/Transforms/StructuralTypeConversions.cpp
|
||||||
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||||
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
||||||
// Ref: ConvertForOpTypes
|
// Ref: ConvertForOpTypes
|
||||||
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
SmallVector<Type> newResultTypes;
|
|
||||||
for (Type type : op.getResultTypes()) {
|
|
||||||
Type newType = typeConverter->convertType(type);
|
|
||||||
if (!newType)
|
|
||||||
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
|
||||||
newResultTypes.push_back(newType);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||||
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
|
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
|
||||||
newOp.getLoopBody().end());
|
newOp.getLoopBody().end());
|
||||||
|
|
||||||
// Now, update all the types.
|
// Now, update all the types.
|
||||||
|
|
||||||
// Convert the type of the entry block of the ForOp's body.
|
// Convert the types of block arguments within the given region. This
|
||||||
|
// replaces each block with a new block containing the updated signature. The
|
||||||
|
// entry block may have a special conversion if `entryConversion` is
|
||||||
|
// provided. On success, the new entry block to the region is returned for
|
||||||
|
// convenience. Otherwise, failure is returned.
|
||||||
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
|
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
|
||||||
*getTypeConverter()))) {
|
*getTypeConverter()))) {
|
||||||
return rewriter.notifyMatchFailure(op, "could not convert body types");
|
return rewriter.notifyMatchFailure(op, "could not convert body types");
|
||||||
@@ -248,11 +268,17 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
|||||||
// a BlockAndValueMapping, but this seems a bit more direct.
|
// a BlockAndValueMapping, but this seems a bit more direct.
|
||||||
newOp->setOperands(adaptor.getOperands());
|
newOp->setOperands(adaptor.getOperands());
|
||||||
// Update the result types to the new converted types.
|
// Update the result types to the new converted types.
|
||||||
|
SmallVector<Type> newResultTypes;
|
||||||
|
for (Type type : op.getResultTypes()) {
|
||||||
|
Type newType = typeConverter->convertType(type);
|
||||||
|
if (!newType)
|
||||||
|
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
||||||
|
newResultTypes.push_back(newType);
|
||||||
|
}
|
||||||
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
||||||
std::get<0>(t).setType(std::get<1>(t));
|
std::get<0>(t).setType(std::get<1>(t));
|
||||||
|
|
||||||
rewriter.replaceOp(op, newOp.getResults());
|
rewriter.replaceOp(op, newOp.getResults());
|
||||||
return success();
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -277,8 +303,7 @@ void populateSCFPatterns(
|
|||||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
|
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
|
||||||
) {
|
) {
|
||||||
MLIRContext *context = patterns.getContext();
|
MLIRContext *context = patterns.getContext();
|
||||||
patterns.add<SCFForPattern,
|
patterns.add<SCFYieldPattern, SCFForPattern
|
||||||
SCFYieldPattern
|
|
||||||
>(typeConverter, context);
|
>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -41,7 +41,11 @@ TritonGPUSharedEncodingAttr::parse(mlir::AsmParser &parser, ::mlir::Type type) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||||
llvm_unreachable("Not implemented");
|
printer << "<"
|
||||||
|
// << "threadTileSize = " << getThreadTileSize()
|
||||||
|
// << ", blockTileSize = " << getBlockTileSize()
|
||||||
|
// << ", order = " << getOrder()
|
||||||
|
<< ">";
|
||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUDialect::initialize() {
|
void TritonGPUDialect::initialize() {
|
||||||
|
@@ -50,9 +50,6 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
// materailizations
|
// materailizations
|
||||||
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||||
ValueRange inputs, Location loc) {
|
ValueRange inputs, Location loc) {
|
||||||
llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n"
|
|
||||||
<< "in: \n";
|
|
||||||
inputs[0].dyn_cast<BlockArgument>().getOwner()->getParentOp()->getParentOp()->print(llvm::errs());
|
|
||||||
llvm_unreachable("Not implemented");
|
llvm_unreachable("Not implemented");
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
});
|
});
|
||||||
@@ -63,8 +60,8 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
});
|
});
|
||||||
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
|
||||||
ValueRange inputs, Location loc) {
|
ValueRange inputs, Location loc) {
|
||||||
|
assert(inputs.size() == 1);
|
||||||
llvm_unreachable("Not implemented");
|
llvm_unreachable("Not implemented");
|
||||||
// llvm::errs() << "Trying to materialize target... " << inputs[0] << "\n";
|
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -75,13 +72,15 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||||
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||||
: ConversionTarget(context), typeConverter(typeConverter) {
|
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||||
|
// TODO: we should also verify ops of TritonGPUDialect
|
||||||
|
addLegalDialect<triton::gpu::TritonGPUDialect>();
|
||||||
|
|
||||||
// Some ops from SCF are illegal
|
// Some ops from SCF are illegal
|
||||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
||||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||||
|
|
||||||
addDynamicallyLegalDialect<arith::ArithmeticDialect,
|
addDynamicallyLegalDialect<arith::ArithmeticDialect,
|
||||||
triton::TritonDialect,
|
triton::TritonDialect,
|
||||||
triton::gpu::TritonGPUDialect,
|
|
||||||
StandardOpsDialect,
|
StandardOpsDialect,
|
||||||
scf::SCFDialect>([&](Operation *op) {
|
scf::SCFDialect>([&](Operation *op) {
|
||||||
if (typeConverter.isLegal(op))
|
if (typeConverter.isLegal(op))
|
||||||
@@ -89,14 +88,18 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
|||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
// // We have requirements for the data layouts
|
|
||||||
// addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
// We have requirements for the data layouts
|
||||||
// Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
addDynamicallyLegalOp<triton::DotOp>([this](triton::DotOp dotOp) -> bool {
|
||||||
// Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
Attribute aEncoding = dotOp.a().getType().cast<RankedTensorType>().getEncoding();
|
||||||
// if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
Attribute bEncoding = dotOp.b().getType().cast<RankedTensorType>().getEncoding();
|
||||||
// bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
if (aEncoding && aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>() &&
|
||||||
// return true;
|
bEncoding && bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>())
|
||||||
// return false;
|
return true;
|
||||||
// });
|
// TODO: we should delete this
|
||||||
|
if (this->typeConverter.isLegal(dotOp))
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user