[CI] run clang-format (#24)
This commit is contained in:
@@ -2,14 +2,16 @@
|
||||
#define TRITON_CONVERSION_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
namespace mlir{
|
||||
namespace triton{
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
|
@@ -1,42 +1,42 @@
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "../PassDetail.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPU.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||
#include "../PassDetail.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace {
|
||||
|
||||
// Conversion pattern for a single arith op type `Op`.
//
// matchAndRewrite converts the op's result type with the pattern's type
// converter and replaces the op with a new `Op` built from that converted
// type and the adaptor's (already converted) operands. It always returns
// success().
//
// NOTE(review): the value returned by convertType() is used without a check,
// and the local `res` is assigned but never read.
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The statements are the same.)
template<class Op>
|
||||
class ArithGenericPattern : public OpConversionPattern<Op> {
|
||||
template <class Op> class ArithGenericPattern : public OpConversionPattern<Op> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
Op res = rewriter.replaceOpWithNewOp<Op>(
|
||||
op, retType, adaptor.getOperands()
|
||||
);
|
||||
Op res =
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
return success();
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
// Conversion pattern that replaces a comparison op `SrcOp` with a `DstOp`.
//
// matchAndRewrite converts the source op's result type with the pattern's
// type converter, then replaces the op with a `DstOp` built from that type
// plus the adaptor's predicate, lhs and rhs. It always returns success().
//
// NOTE(review): the value returned by convertType() is used without a check,
// and the local `res` is assigned but never read.
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The statements are the same.)
template<class SrcOp, class DstOp>
|
||||
template <class SrcOp, class DstOp>
|
||||
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
|
||||
public:
|
||||
using OpConversionPattern<SrcOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
DstOp res = rewriter.replaceOpWithNewOp<DstOp>(
|
||||
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
|
||||
);
|
||||
DstOp res =
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
||||
adaptor.getLhs(), adaptor.getRhs());
|
||||
return success();
|
||||
|
||||
}
|
||||
};
|
||||
@@ -45,36 +45,40 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
|
||||
public:
|
||||
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
||||
assert(value);
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, retType, value.reshape(retType) // This is a hack. We just want to add encoding
|
||||
op, retType,
|
||||
value.reshape(retType) // This is a hack. We just want to add encoding
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Type-agnostic conversion pattern: it is constructed with
// MatchAnyOpTypeTag() and benefit 1, so it is not tied to one op class.
//
// matchAndRewrite returns failure() when the op's dialect is not
// arith::ArithmeticDialect and success() otherwise.
//
// NOTE(review): the success path never touches `rewriter` or `operands`, so
// no IR is changed when it "succeeds" - confirm this is intentional before
// registering the pattern.
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The statements are the same.)
class ConvertArithmeticOp: public ConversionPattern {
|
||||
class ConvertArithmeticOp : public ConversionPattern {
|
||||
public:
|
||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter,
|
||||
MLIRContext *context)
|
||||
: ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
|
||||
context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
Dialect* dialect = op->getDialect();
|
||||
if(dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Dialect *dialect = op->getDialect();
|
||||
if (dialect->getTypeID() != mlir::TypeID::get<arith::ArithmeticDialect>())
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateArithmeticPatternsAndLegality(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target){
|
||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
TritonGPUConversionTarget &target) {
|
||||
// --------------
|
||||
// Add legality and rewrite pattern rules for operations
|
||||
// from the Arithmetic dialect. The basic premise is that
|
||||
@@ -91,59 +95,49 @@ void populateArithmeticPatternsAndLegality(
|
||||
// );
|
||||
// Rewrite rule
|
||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
patterns.add<ArithConstantPattern,
|
||||
ArithGenericPattern<arith::AddIOp>,
|
||||
ArithGenericPattern<arith::SubIOp>,
|
||||
ArithGenericPattern<arith::MulIOp>,
|
||||
ArithGenericPattern<arith::DivUIOp>,
|
||||
ArithGenericPattern<arith::DivSIOp>,
|
||||
ArithGenericPattern<arith::CeilDivUIOp>,
|
||||
ArithGenericPattern<arith::CeilDivSIOp>,
|
||||
ArithGenericPattern<arith::FloorDivSIOp>,
|
||||
ArithGenericPattern<arith::RemUIOp>,
|
||||
ArithGenericPattern<arith::RemSIOp>,
|
||||
ArithGenericPattern<arith::AndIOp>,
|
||||
ArithGenericPattern<arith::OrIOp>,
|
||||
ArithGenericPattern<arith::XOrIOp>,
|
||||
ArithGenericPattern<arith::ShLIOp>,
|
||||
ArithGenericPattern<arith::ShRUIOp>,
|
||||
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::AddFOp>,
|
||||
ArithGenericPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
ArithGenericPattern<arith::MaxFOp>,
|
||||
ArithGenericPattern<arith::MaxSIOp>,
|
||||
ArithGenericPattern<arith::MaxUIOp>,
|
||||
ArithGenericPattern<arith::MinFOp>,
|
||||
ArithGenericPattern<arith::MinSIOp>,
|
||||
ArithGenericPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::MulFOp>,
|
||||
ArithGenericPattern<arith::DivFOp>,
|
||||
ArithGenericPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
ArithGenericPattern<arith::TruncIOp>,
|
||||
ArithGenericPattern<arith::TruncFOp>
|
||||
>(typeConverter, context);
|
||||
patterns.add<
|
||||
ArithConstantPattern, ArithGenericPattern<arith::AddIOp>,
|
||||
ArithGenericPattern<arith::SubIOp>, ArithGenericPattern<arith::MulIOp>,
|
||||
ArithGenericPattern<arith::DivUIOp>, ArithGenericPattern<arith::DivSIOp>,
|
||||
ArithGenericPattern<arith::CeilDivUIOp>,
|
||||
ArithGenericPattern<arith::CeilDivSIOp>,
|
||||
ArithGenericPattern<arith::FloorDivSIOp>,
|
||||
ArithGenericPattern<arith::RemUIOp>, ArithGenericPattern<arith::RemSIOp>,
|
||||
ArithGenericPattern<arith::AndIOp>, ArithGenericPattern<arith::OrIOp>,
|
||||
ArithGenericPattern<arith::XOrIOp>, ArithGenericPattern<arith::ShLIOp>,
|
||||
ArithGenericPattern<arith::ShRUIOp>,
|
||||
ArithGenericPattern<arith::ShRSIOp>, // NegFOp
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::AddFOp>, ArithGenericPattern<arith::SubFOp>,
|
||||
// MaxMin
|
||||
ArithGenericPattern<arith::MaxFOp>, ArithGenericPattern<arith::MaxSIOp>,
|
||||
ArithGenericPattern<arith::MaxUIOp>, ArithGenericPattern<arith::MinFOp>,
|
||||
ArithGenericPattern<arith::MinSIOp>, ArithGenericPattern<arith::MinUIOp>,
|
||||
// Floating point
|
||||
ArithGenericPattern<arith::MulFOp>, ArithGenericPattern<arith::DivFOp>,
|
||||
ArithGenericPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
||||
// Cast Ops
|
||||
ArithGenericPattern<arith::TruncIOp>,
|
||||
ArithGenericPattern<arith::TruncFOp>>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
// Triton patterns
|
||||
//
|
||||
// TODO: Do we need to put them in anonymous namespace?
|
||||
// Conversion pattern for triton::MakeRangeOp.
//
// matchAndRewrite converts the op's result type with the pattern's type
// converter and replaces the op with a new triton::MakeRangeOp built from
// that type and the adaptor's start/end. It always returns success().
//
// NOTE(review): the value returned by convertType() is used without a check.
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The statements are the same.)
struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp> {
|
||||
struct TritonMakeRangePattern
|
||||
: public OpConversionPattern<triton::MakeRangeOp> {
|
||||
using OpConversionPattern<triton::MakeRangeOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||
op, retType, adaptor.start(), adaptor.end()
|
||||
);
|
||||
op, retType, adaptor.start(), adaptor.end());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -151,8 +145,9 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
|
||||
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
// a & b must be of smem layout
|
||||
auto aType = adaptor.a().getType().cast<RankedTensorType>();
|
||||
@@ -165,18 +160,21 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
Value b = adaptor.b();
|
||||
SmallVector<unsigned, 2> order{1, 0};
|
||||
if (!aEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(), aType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(),
|
||||
aType.getElementType(), encoding);
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
}
|
||||
if (!bEncoding.isa<triton::gpu::TritonGPUSharedEncodingAttr>()) {
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(), bType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::TritonGPUSharedEncodingAttr::get(
|
||||
getContext(), 1, 1, 1, order);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(),
|
||||
bType.getElementType(), encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
auto newDot = rewriter.replaceOpWithNewOp<triton::DotOp>(
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32()
|
||||
);
|
||||
op, retType, a, b, adaptor.c(), adaptor.allowTF32());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -184,14 +182,13 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
// Conversion pattern for triton::LoadOp.
//
// matchAndRewrite converts the op's result type with the pattern's type
// converter and replaces the op with a new triton::LoadOp built from that
// type and the adaptor's ptr, mask, other, cache, evict and isVolatile
// values. It always returns success().
//
// NOTE(review): the value returned by convertType() is used without a check.
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The statements are the same.)
struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
using OpConversionPattern<triton::LoadOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op, retType,
|
||||
adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
||||
adaptor.cache(), adaptor.evict(), adaptor.isVolatile()
|
||||
);
|
||||
op, retType, adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
||||
adaptor.cache(), adaptor.evict(), adaptor.isVolatile());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -199,11 +196,11 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
// Conversion pattern for triton::StoreOp.
//
// matchAndRewrite replaces the op with a new triton::StoreOp built from the
// adaptor's ptr, value and mask. No result type is converted here. It always
// returns success().
//
// NOTE(review): the local `newOp` is assigned but never read.
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The statements are the same.)
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
|
||||
);
|
||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -212,12 +209,11 @@ template <class Op>
|
||||
struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<Op>(
|
||||
op, retType, adaptor.getOperands()
|
||||
);
|
||||
rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -225,30 +221,25 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
|
||||
// Conversion pattern for triton::ReduceOp.
//
// matchAndRewrite converts the op's result type with the pattern's type
// converter and replaces the op with a new triton::ReduceOp built from that
// type and the adaptor's redOp, operand and axis. It always returns
// success().
//
// NOTE(review): the value returned by convertType() is used without a check,
// and the local `newOp` is assigned but never read.
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The statements are the same.)
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
||||
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::ReduceOp>(
|
||||
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis()
|
||||
);
|
||||
op, retType, adaptor.redOp(), adaptor.operand(), adaptor.axis());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Registers the Triton-dialect conversion patterns into `patterns`.
//
// The MLIRContext is taken from `patterns`, and every pattern is constructed
// with (typeConverter, context). The registered set is TritonGenericPattern
// for ReshapeOp, SplatOp, BroadcastOp and GEPOp, plus TritonReducePattern,
// TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern and
// TritonStorePattern.
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The pattern list is the same.)
void populateTritonPatterns(
|
||||
TritonGPUTypeConverter& typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<TritonGenericPattern<triton::ReshapeOp>,
|
||||
TritonGenericPattern<triton::SplatOp>,
|
||||
TritonGenericPattern<triton::BroadcastOp>,
|
||||
TritonGenericPattern<triton::GEPOp>,
|
||||
TritonReducePattern,
|
||||
TritonMakeRangePattern,
|
||||
TritonDotPattern,
|
||||
TritonLoadPattern,
|
||||
TritonStorePattern
|
||||
>(typeConverter, context);
|
||||
TritonGenericPattern<triton::GEPOp>, TritonReducePattern,
|
||||
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
|
||||
TritonStorePattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
//
|
||||
@@ -259,17 +250,19 @@ void populateTritonPatterns(
|
||||
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
||||
// Ref: ConvertForOpTypes
|
||||
LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp =
|
||||
cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
||||
rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
|
||||
newOp.getLoopBody().end());
|
||||
|
||||
// Now, update all the types.
|
||||
|
||||
// Convert the types of block arguments within the given region. This
|
||||
// replaces each block with a new block containing the updated signature. The
|
||||
// entry block may have a special conversion if `entryConversion` is
|
||||
// replaces each block with a new block containing the updated signature.
|
||||
// The entry block may have a special conversion if `entryConversion` is
|
||||
// provided. On success, the new entry block to the region is returned for
|
||||
// convenience. Otherwise, failure is returned.
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(),
|
||||
@@ -299,33 +292,27 @@ struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
||||
// Conversion pattern for scf::YieldOp.
//
// matchAndRewrite replaces the op with a new scf::YieldOp that yields the
// adaptor's (already converted) operands. It always returns success().
// The commented-out statements inside the body are an earlier
// create-then-erase variant that is no longer active.
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The statements are the same.)
struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
|
||||
using OpConversionPattern<scf::YieldOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LogicalResult
|
||||
matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
|
||||
// rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
|
||||
// op.erase();
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(
|
||||
op, adaptor.getOperands()
|
||||
);
|
||||
rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Registers the SCF conversion patterns (SCFYieldPattern and SCFForPattern)
// into `patterns`. The MLIRContext is taken from `patterns`, and both
// patterns are constructed with (typeConverter, context).
// (Lines reformatted by clang-format appear twice below: old layout first,
// then new layout. The pattern list is the same.)
void populateSCFPatterns(
|
||||
TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns
|
||||
) {
|
||||
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<SCFYieldPattern, SCFForPattern
|
||||
>(typeConverter, context);
|
||||
patterns.add<SCFYieldPattern, SCFForPattern>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
||||
class ConvertTritonToTritonGPU :
|
||||
public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
class ConvertTritonToTritonGPU
|
||||
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
||||
public:
|
||||
ConvertTritonToTritonGPU(int numWarps) {
|
||||
this->numWarps = numWarps;
|
||||
}
|
||||
ConvertTritonToTritonGPU(int numWarps) { this->numWarps = numWarps; }
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
@@ -339,21 +326,21 @@ public:
|
||||
// add rules
|
||||
populateArithmeticPatternsAndLegality(typeConverter, patterns, target);
|
||||
populateTritonPatterns(typeConverter, patterns);
|
||||
// TODO: can we use
|
||||
// TODO: can we use
|
||||
// mlir::scf::populateSCFStructuralTypeConversionsAndLegality(...) here?
|
||||
populateSCFPatterns(typeConverter, patterns);
|
||||
|
||||
if(failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// update layouts
|
||||
// broadcast src => multicast, dst => broadcasted
|
||||
if(failed(target.refineLayouts(mod, numWarps)))
|
||||
if (failed(target.refineLayouts(mod, numWarps)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps) {
|
||||
|
Reference in New Issue
Block a user