More conversion patterns

This commit is contained in:
Yan Da
2022-05-04 12:50:02 +08:00
parent 5f08e2fdae
commit 3ad7bee35e
4 changed files with 69 additions and 20 deletions

View File

@@ -19,8 +19,23 @@ public:
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>(
op.getOperation(), retType, op.getLhs(), op.getRhs()
Op res = rewriter.replaceOpWithNewOp<Op>(
op, retType, adaptor.getOperands()
);
return success();
}
};
template<class Op>
class ArithCmpPattern : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
Op res = rewriter.replaceOpWithNewOp<Op>(
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
);
return success();
}
@@ -51,13 +66,13 @@ void populateArithmeticPatternsAndLegality(
// non-null encoding
// --------------
MLIRContext *context = patterns.getContext();
// Legality rule
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// TODO: check above rule here
[](Operation *op){
return true;
}
);
// // Legality rule
// target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// // TODO: check above rule here
// [](Operation *op){
// return true;
// }
// );
// Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
patterns.add<ArithBinaryPattern<arith::AddIOp>,
@@ -89,7 +104,10 @@ void populateArithmeticPatternsAndLegality(
// Floating point
ArithBinaryPattern<arith::MulFOp>,
ArithBinaryPattern<arith::DivFOp>,
ArithBinaryPattern<arith::RemFOp>
ArithBinaryPattern<arith::RemFOp>,
// Cmp
ArithCmpPattern<arith::CmpIOp>,
ArithCmpPattern<arith::CmpFOp>
>(typeConverter, context);
}
@@ -103,9 +121,8 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op.getOperation(), retType, op.start(), op.end()
op, retType, adaptor.start(), adaptor.end()
);
return success();
}
@@ -118,7 +135,7 @@ struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp>
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
op.getOperation(), retType, op.src()
op, retType, adaptor.src()
);
return success();
}
@@ -131,7 +148,7 @@ struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::GEPOp>(
op.getOperation(), retType, op.ptr(), op.offset()
op, retType, adaptor.getOperands()
);
return success();
}
@@ -144,8 +161,21 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::LoadOp>(
op.getOperation(), retType,
op.ptr(), op.mask(), op.other(), op.cache(), op.evict(), op.isVolatile()
op, retType,
adaptor.ptr(), adaptor.mask(), adaptor.other(),
adaptor.cache(), adaptor.evict(), adaptor.isVolatile()
);
return success();
}
};
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
);
return success();
}
@@ -158,7 +188,8 @@ void populateTritonPatterns(
patterns.add<TritonMakeRangePattern,
TritonBroadcastPattern,
TritonGEPPattern,
TritonLoadPattern
TritonLoadPattern,
TritonStorePattern
>(typeConverter, context);
}

View File

@@ -19,7 +19,11 @@ TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type
}
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
llvm_unreachable("Not implemented");
printer << "<"
<< "threadTileSize = " << getThreadTileSize()
<< ", blockTileSize = " << getBlockTileSize()
<< ", order = " << getOrder()
<< ">";
}
mlir::Attribute

View File

@@ -26,7 +26,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
assert(numElements > numThreads);
assert(numElements % numThreads == 0);
// assert no encoding?
// or assert no encoding?
// Now we assume:
// contiguous = 1, order = 0, 1, 2, ...,
@@ -62,14 +62,26 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
if (typeConverter.isLegal(op)) {
// llvm::errs() << *op << " is dyanamically legal\n";
return true;
}
// if (typeConverter.isLegal(op->getOperandTypes()))
// llvm::errs() << *op << " is illegal with legal operands\n";
// if (typeConverter.isLegal(op->getResultTypes())) {
// llvm::errs() << *op << " is illegal with legal results\n";
// llvm::errs() << "operand0: " << op->getOperand(0) << "\n"
// << "operand1: " << op->getOperand(1) << "\n";
// }
return false;
});
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
if (typeConverter.isLegal(op))
return true;
// llvm::errs() << *op << " is illegal\n"
// << "inside ...\n"
// << *op->getParentOp() << "\n";
return false;
});