More on TritonGPU conversion
This commit is contained in:
@@ -13,7 +13,9 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleO
|
|||||||
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
let dependentDialects = ["mlir::arith::ArithmeticDialect",
|
||||||
"mlir::StandardOpsDialect",
|
"mlir::StandardOpsDialect",
|
||||||
// TODO: Does this pass depend on SCF?
|
// TODO: Does this pass depend on SCF?
|
||||||
"mlir::scf::SCFDialect"];
|
"mlir::scf::SCFDialect",
|
||||||
|
"mlir::triton::TritonDialect",
|
||||||
|
"mlir::triton::gpu::TritonGPUDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@@ -20,8 +20,9 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
class TritonGPUConversionTarget : public ConversionTarget {
|
class TritonGPUConversionTarget : public ConversionTarget {
|
||||||
|
TritonGPUTypeConverter &typeConverter;
|
||||||
public:
|
public:
|
||||||
explicit TritonGPUConversionTarget(MLIRContext &ctx);
|
explicit TritonGPUConversionTarget(MLIRContext &ctx, TritonGPUTypeConverter &typeConverter);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
@@ -12,7 +13,7 @@ namespace {
|
|||||||
|
|
||||||
class ConvertArithmeticOp: public ConversionPattern {
|
class ConvertArithmeticOp: public ConversionPattern {
|
||||||
public:
|
public:
|
||||||
ConvertArithmeticOp(TypeConverter &typeConverter, MLIRContext *context)
|
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
||||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||||
context) {}
|
context) {}
|
||||||
|
|
||||||
@@ -21,14 +22,13 @@ public:
|
|||||||
Dialect* dialect = op->getDialect();
|
Dialect* dialect = op->getDialect();
|
||||||
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
||||||
return failure();
|
return failure();
|
||||||
// Arithmetic op to legalize here. Create layout conversion if necessary
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void populateArithmeticPatternsAndLegality(
|
void populateArithmeticPatternsAndLegality(
|
||||||
TypeConverter& typeConverter, RewritePatternSet &patterns,
|
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target){
|
TritonGPUConversionTarget &target){
|
||||||
// --------------
|
// --------------
|
||||||
// Add legality and rewrite pattern rules for operations
|
// Add legality and rewrite pattern rules for operations
|
||||||
// from the Arithmetic dialect. The basic premise is that
|
// from the Arithmetic dialect. The basic premise is that
|
||||||
@@ -47,6 +47,75 @@ void populateArithmeticPatternsAndLegality(
|
|||||||
patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Triton patterns
|
||||||
|
//
|
||||||
|
// TODO: Do we need to put them in anonymous namespace?
|
||||||
|
struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp> {
|
||||||
|
using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||||
|
op.getOperation(), retType, op.start(), op.end()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp> {
|
||||||
|
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
|
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
|
||||||
|
op.getOperation(), retType, op.src()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
|
||||||
|
using OpConversionPattern<triton::GEPOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(triton::GEPOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
|
rewriter.replaceOpWithNewOp<triton::GEPOp>(
|
||||||
|
op.getOperation(), retType, op.ptr(), op.offset()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||||
|
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Type retType = getTypeConverter()->convertType(op.getType());
|
||||||
|
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||||
|
op.getOperation(), retType,
|
||||||
|
op.ptr(), op.mask(), op.other(), op.cache(), op.evict(), op.isVolatile()
|
||||||
|
);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void populateTritonPatterns(
|
||||||
|
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||||
|
) {
|
||||||
|
MLIRContext *context = patterns.getContext();
|
||||||
|
patterns.add<TritonMakeRangePattern,
|
||||||
|
TritonBroadcastPattern,
|
||||||
|
TritonGEPPattern,
|
||||||
|
TritonLoadPattern
|
||||||
|
>(typeConverter, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ConvertTritonToTritonGPU :
|
class ConvertTritonToTritonGPU :
|
||||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||||
@@ -54,18 +123,19 @@ class ConvertTritonToTritonGPU :
|
|||||||
public:
|
public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
TritonGPUConversionTarget target(*context);
|
|
||||||
ModuleOp mod = getOperation();
|
ModuleOp mod = getOperation();
|
||||||
// int numThreads = mod.getAttr();
|
// int numThreads = mod.getAttr();
|
||||||
// type converter
|
// type converter
|
||||||
TritonGPUTypeConverter typeConverter(context, /*numThreads*/4*32);
|
TritonGPUTypeConverter typeConverter(context, /*numThreads*/128);
|
||||||
|
TritonGPUConversionTarget target(*context, typeConverter);
|
||||||
// rewrite patterns
|
// rewrite patterns
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
// add rules
|
// add rules
|
||||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||||
|
populateTritonPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
|
|
||||||
if(failed(applyPartialConversion(getOperation(), target,
|
if(failed(applyPartialConversion(mod, target,
|
||||||
std::move(patterns))))
|
std::move(patterns))))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
|
@@ -41,6 +41,10 @@ void TritonGPUSharedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TritonGPUDialect::initialize() {
|
void TritonGPUDialect::initialize() {
|
||||||
|
addAttributes<
|
||||||
|
#define GET_ATTRDEF_LIST
|
||||||
|
#include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.cpp.inc"
|
||||||
|
>();
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
|
||||||
|
@@ -11,7 +11,12 @@ using namespace mlir;
|
|||||||
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||||
int numThreads)
|
int numThreads)
|
||||||
: context(context), numThreads(numThreads) {
|
: context(context), numThreads(numThreads) {
|
||||||
addConversion([&](RankedTensorType tensorType) -> RankedTensorType {
|
// TODO: how does MLIR pick the right conversion?
|
||||||
|
addConversion([](Type type) { return type; });
|
||||||
|
addConversion([this](RankedTensorType tensorType) -> RankedTensorType {
|
||||||
|
MLIRContext *context = this->context;
|
||||||
|
int numThreads = this->numThreads;
|
||||||
|
|
||||||
llvm::ArrayRef<int64_t> shape = tensorType.getShape();
|
llvm::ArrayRef<int64_t> shape = tensorType.getShape();
|
||||||
Type elementType = tensorType.getElementType();
|
Type elementType = tensorType.getElementType();
|
||||||
int64_t rank = tensorType.getRank();
|
int64_t rank = tensorType.getRank();
|
||||||
@@ -45,15 +50,28 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
|||||||
//
|
//
|
||||||
// TritonGPUConversion
|
// TritonGPUConversion
|
||||||
//
|
//
|
||||||
TritonGPUConversionTarget::TritonGPUConversionTarget(MLIRContext &context)
|
TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||||
: ConversionTarget(context) {
|
MLIRContext &context, TritonGPUTypeConverter &typeConverter)
|
||||||
|
: ConversionTarget(context), typeConverter(typeConverter) {
|
||||||
addLegalDialect<triton::TritonDialect,
|
addLegalDialect<triton::TritonDialect,
|
||||||
arith::ArithmeticDialect,
|
StandardOpsDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
|
|
||||||
// Some ops from SCF are illegal
|
// Some ops from SCF are illegal
|
||||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp,
|
||||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||||
|
|
||||||
|
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
|
||||||
|
if (typeConverter.isLegal(op))
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
|
||||||
|
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
|
||||||
|
if (typeConverter.isLegal(op))
|
||||||
|
return true;
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
|
||||||
// // We have requirements for the data layouts
|
// // We have requirements for the data layouts
|
||||||
// addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
// addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
|
||||||
|
@@ -94,8 +94,8 @@ mod, ctx = matmul_kernel.compile_to_ttir(
|
|||||||
8, grid=(2,)
|
8, grid=(2,)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert mod.verify()
|
# assert mod.verify()
|
||||||
mod.dump()
|
# mod.dump()
|
||||||
|
|
||||||
pm = _triton.ir.pass_manager(ctx)
|
pm = _triton.ir.pass_manager(ctx)
|
||||||
pm.add_inliner_pass()
|
pm.add_inliner_pass()
|
||||||
@@ -104,5 +104,5 @@ pm.add_canonicalizer_pass()
|
|||||||
pm.add_convert_triton_to_tritongpu_pass()
|
pm.add_convert_triton_to_tritongpu_pass()
|
||||||
pm.run(mod)
|
pm.run(mod)
|
||||||
|
|
||||||
assert mod.verify()
|
# assert mod.verify()
|
||||||
mod.dump()
|
# mod.dump()
|
||||||
|
@@ -1,7 +1,8 @@
|
|||||||
from tarfile import BLOCKSIZE
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
import triton._C.libtriton.triton as _triton
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
|
Reference in New Issue
Block a user