More on TritonGPU conversion

Yan Da
2022-05-02 21:51:00 +08:00
parent 1428185c9c
commit 75d32e2442
7 changed files with 114 additions and 18 deletions

View File

@@ -13,7 +13,9 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp">
   let dependentDialects = ["mlir::arith::ArithmeticDialect",
                            "mlir::StandardOpsDialect",
                            // TODO: Does this pass depend on SCF?
-                           "mlir::scf::SCFDialect"];
+                           "mlir::scf::SCFDialect",
+                           "mlir::triton::TritonDialect",
+                           "mlir::triton::gpu::TritonGPUDialect"];
 }

 #endif
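Why this hunk matters: the conversion pass now creates Triton and TritonGPU ops, and MLIR only loads dialects that were declared up front, so both dialects must appear in dependentDialects. Below is a minimal hand-written sketch of what that tablegen field amounts to; the pass class name and body are placeholders, not the generated code.

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/Pass.h"
    #include "triton/Dialect/Triton/IR/Dialect.h"
    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    // Placeholder pass: the driver loads every dialect registered in
    // getDependentDialects() before runOnOperation() executes, so the pass
    // can create TritonGPU ops and attributes without hitting an
    // unregistered-dialect error.
    struct ConvertTritonToTritonGPUSketch
        : public mlir::PassWrapper<ConvertTritonToTritonGPUSketch,
                                   mlir::OperationPass<mlir::ModuleOp>> {
      void getDependentDialects(mlir::DialectRegistry &registry) const override {
        registry.insert<mlir::triton::TritonDialect,
                        mlir::triton::gpu::TritonGPUDialect>();
      }
      void runOnOperation() override { /* conversion logic */ }
    };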

View File

@@ -20,8 +20,9 @@ private:
 };

 class TritonGPUConversionTarget : public ConversionTarget {
+  TritonGPUTypeConverter &typeConverter;
 public:
-  explicit TritonGPUConversionTarget(MLIRContext &ctx);
+  explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
 };

 } // namespace mlir

View File

@@ -1,4 +1,5 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
 #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -12,7 +13,7 @@ namespace {

 class ConvertArithmeticOp: public ConversionPattern {
 public:
-  ConvertArithmeticOp(TypeConverter &typeConverter, MLIRContext *context)
+  ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
     : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
                         context) {}
@@ -21,14 +22,13 @@ public:
     Dialect* dialect = op->getDialect();
     if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
       return failure();
-    // Arithmetic op to legalize here. Create layout conversion if necessary
     return success();
   }
 };

 void populateArithmeticPatternsAndLegality(
-    TypeConverter& typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target){
+    TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
+    TritonGPUConversionTarget &target){
   // --------------
   // Add legality and rewrite pattern rules for operations
   // from the Arithmetic dialect. The basic premise is that
@@ -47,6 +47,75 @@ void populateArithmeticPatternsAndLegality(
   patterns.add<ConvertArithmeticOp>(typeConverter, context);
 }

+//
+// Triton patterns
+//
+// TODO: Do we need to put them in anonymous namespace?
+struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp> {
+  using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Type retType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
+      op.getOperation(), retType, op.start(), op.end()
+    );
+    return success();
+  }
+};
+
+struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp> {
+  using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Type retType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
+      op.getOperation(), retType, op.src()
+    );
+    return success();
+  }
+};
+
+struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
+  using OpConversionPattern<triton::GEPOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Type retType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<triton::GEPOp>(
+      op.getOperation(), retType, op.ptr(), op.offset()
+    );
+    return success();
+  }
+};
+
+struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
+  using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Type retType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<triton::LoadOp>(
+      op.getOperation(), retType,
+      op.ptr(), op.mask(), op.other(), op.cache(), op.evict(), op.isVolatile()
+    );
+    return success();
+  }
+};
+
+void populateTritonPatterns(
+  TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
+) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<TritonMakeRangePattern,
+               TritonBroadcastPattern,
+               TritonGEPPattern,
+               TritonLoadPattern
+              >(typeConverter, context);
+}
+

 class ConvertTritonToTritonGPU :
     public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
@@ -54,18 +123,19 @@ class ConvertTritonToTritonGPU :

 public:
   void runOnOperation() override {
     MLIRContext *context = &getContext();
-    TritonGPUConversionTarget target(*context);
     ModuleOp mod = getOperation();
     // int numThreads = mod.getAttr();
     // type converter
-    TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32);
+    TritonGPUTypeConverter typeConverter(context, /*numThreads*/128);
+    TritonGPUConversionTarget target(*context, typeConverter);
     // rewrite patterns
     RewritePatternSet patterns(context);
     // add rules
     populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
+    populateTritonPatterns(typeConverter, patterns);

-    if(failed(applyPartialConversion(getOperation(), target,
+    if(failed(applyPartialConversion(mod, target,
                                      std::move(patterns))))
       return signalPassFailure();
   }

View File

@@ -41,6 +41,10 @@ void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
 }

 void TritonGPUDialect::initialize() {
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
+      >();
   addOperations<
 #define GET_OP_LIST
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"

View File

@@ -11,7 +11,12 @@ using namespace mlir;
 TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                                int numThreads)
     : context(context), numThreads(numThreads) {
-  addConversion([&](RankedTensorType tensorType) -> RankedTensorType {
+  // TODO: how does MLIR pick the right conversion?
+  addConversion([](Type type) { return type; });
+  addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
+    MLIRContext *context = this->context;
+    int numThreads = this->numThreads;
+
     llvm::ArrayRef<int64_t> shape = tensorType.getShape();
     Type elementType = tensorType.getElementType();
     int64_t rank = tensorType.getRank();
@@ -45,15 +50,28 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,

 //
 // TritonGPUConversion
 //
-TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
-    : ConversionTarget(context) {
+TritonGPUConversionTarget::TritonGPUConversionTarget(
+    MLIRContext &context, TritonGPUTypeConverter &typeConverter)
+    : ConversionTarget(context), typeConverter(typeConverter) {
   addLegalDialect<triton::TritonDialect,
+                  arith::ArithmeticDialect, StandardOpsDialect,
                   scf::SCFDialect>();

   // Some ops from SCF are illegal
   addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
                scf::ReduceOp, scf::ReduceReturnOp>();

+  addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
+    if (typeConverter.isLegal(op))
+      return true;
+    return false;
+  });
+
+  addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
+    if (typeConverter.isLegal(op))
+      return true;
+    return false;
+  });
+
   // // We have requirements for the data layouts
   // addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
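Two pieces of MLIR behavior this hunk leans on. First, answering the TODO in the type converter above: registered conversions are attempted most-recently-added first, so the RankedTensorType lambda runs before the identity fallback, and the fallback only handles types the tensor lambda never claims. Second, typeConverter.isLegal(op) in the callbacks checks that every operand and result type already converts to itself; a simplified model (not the actual MLIR source):

    #include "llvm/ADT/STLExtras.h"
    #include "mlir/IR/Operation.h"
    #include "mlir/Transforms/DialectConversion.h"

    // Simplified model of TypeConverter::isLegal(Operation *): an op becomes
    // legal once none of its operand or result types need conversion.
    static bool opTypesAreLegal(mlir::Operation *op,
                                mlir::TypeConverter &converter) {
      auto legal = [&](mlir::Type t) { return converter.isLegal(t); };
      return llvm::all_of(op->getOperandTypes(), legal) &&
             llvm::all_of(op->getResultTypes(), legal);
    }

Net effect: arith and triton ops stay illegal exactly until all of their tensors carry a TritonGPU encoding, which is what drives the rewrite patterns in TritonToTritonGPU.cpp to fire.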

View File

@@ -94,8 +94,8 @@ mod, ctx = matmul_kernel.compile_to_ttir(
     8, grid=(2,)
 )

-assert mod.verify()
-mod.dump()
+# assert mod.verify()
+# mod.dump()

 pm = _triton.ir.pass_manager(ctx)
 pm.add_inliner_pass()
@@ -104,5 +104,5 @@ pm.add_canonicalizer_pass()
 pm.add_convert_triton_to_tritongpu_pass()
 pm.run(mod)

-assert mod.verify()
-mod.dump()
+# assert mod.verify()
+# mod.dump()
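For reference, the same pipeline driven from C++ instead of the Python bindings; a sketch assuming the conversion header exposes a createConvertTritonToTritonGPUPass() factory mirroring add_convert_triton_to_tritongpu_pass(). (The verify()/dump() calls being commented out above suggests the converted module is not expected to verify at this stage.)

    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"
    #include "mlir/Transforms/Passes.h"
    #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"

    // Inliner -> canonicalizer -> Triton-to-TritonGPU, as in the test above.
    mlir::LogicalResult lowerToTritonGPU(mlir::MLIRContext *ctx,
                                         mlir::ModuleOp mod) {
      mlir::PassManager pm(ctx);
      pm.addPass(mlir::createInlinerPass());
      pm.addPass(mlir::createCanonicalizerPass());
      pm.addPass(mlir::triton::createConvertTritonToTritonGPUPass());
      return pm.run(mod);
    }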

View File

@@ -1,7 +1,8 @@
+from tarfile import BLOCKSIZE
 import torch
 import triton
 import triton.language as tl
+import triton._C.libtriton.triton as _triton

 @triton.jit