More conversion patterns
This commit is contained in:
@@ -19,8 +19,23 @@ public:
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<Op>(
|
||||
op.getOperation(), retType, op.getLhs(), op.getRhs()
|
||||
Op res = rewriter.replaceOpWithNewOp<Op>(
|
||||
op, retType, adaptor.getOperands()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template<class Op>
|
||||
class ArithCmpPattern : public OpConversionPattern<Op> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = this->getTypeConverter()->convertType(op.getType());
|
||||
Op res = rewriter.replaceOpWithNewOp<Op>(
|
||||
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
@@ -51,13 +66,13 @@ void populateArithmeticPatternsAndLegality(
|
||||
// non-null encoding
|
||||
// --------------
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// Legality rule
|
||||
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
||||
// TODO: check above rule here
|
||||
[](Operation *op){
|
||||
return true;
|
||||
}
|
||||
);
|
||||
// // Legality rule
|
||||
// target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
|
||||
// // TODO: check above rule here
|
||||
// [](Operation *op){
|
||||
// return true;
|
||||
// }
|
||||
// );
|
||||
// Rewrite rule
|
||||
// patterns.add<ConvertArithmeticOp>(typeConverter, context);
|
||||
patterns.add<ArithBinaryPattern<arith::AddIOp>,
|
||||
@@ -89,7 +104,10 @@ void populateArithmeticPatternsAndLegality(
|
||||
// Floating point
|
||||
ArithBinaryPattern<arith::MulFOp>,
|
||||
ArithBinaryPattern<arith::DivFOp>,
|
||||
ArithBinaryPattern<arith::RemFOp>
|
||||
ArithBinaryPattern<arith::RemFOp>,
|
||||
// Cmp
|
||||
ArithCmpPattern<arith::CmpIOp>,
|
||||
ArithCmpPattern<arith::CmpFOp>
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
@@ -103,9 +121,8 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
|
||||
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
|
||||
op.getOperation(), retType, op.start(), op.end()
|
||||
op, retType, adaptor.start(), adaptor.end()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
@@ -118,7 +135,7 @@ struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp>
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
|
||||
op.getOperation(), retType, op.src()
|
||||
op, retType, adaptor.src()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
@@ -131,7 +148,7 @@ struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::GEPOp>(
|
||||
op.getOperation(), retType, op.ptr(), op.offset()
|
||||
op, retType, adaptor.getOperands()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
@@ -144,8 +161,21 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
op.getOperation(), retType,
|
||||
op.ptr(), op.mask(), op.other(), op.cache(), op.evict(), op.isVolatile()
|
||||
op, retType,
|
||||
adaptor.ptr(), adaptor.mask(), adaptor.other(),
|
||||
adaptor.cache(), adaptor.evict(), adaptor.isVolatile()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
|
||||
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
|
||||
);
|
||||
return success();
|
||||
}
|
||||
@@ -158,7 +188,8 @@ void populateTritonPatterns(
|
||||
patterns.add<TritonMakeRangePattern,
|
||||
TritonBroadcastPattern,
|
||||
TritonGEPPattern,
|
||||
TritonLoadPattern
|
||||
TritonLoadPattern,
|
||||
TritonStorePattern
|
||||
>(typeConverter, context);
|
||||
}
|
||||
|
||||
|
@@ -19,7 +19,11 @@ TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type
|
||||
}
|
||||
|
||||
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
llvm_unreachable("Not implemented");
|
||||
printer << "<"
|
||||
<< "threadTileSize = " << getThreadTileSize()
|
||||
<< ", blockTileSize = " << getBlockTileSize()
|
||||
<< ", order = " << getOrder()
|
||||
<< ">";
|
||||
}
|
||||
|
||||
mlir::Attribute
|
||||
|
@@ -26,7 +26,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
|
||||
assert(numElements > numThreads);
|
||||
assert(numElements % numThreads == 0);
|
||||
|
||||
// assert no encoding?
|
||||
// or assert no encoding?
|
||||
|
||||
// Now we assume:
|
||||
// contiguous = 1, order = 0, 1, 2, ...,
|
||||
@@ -62,14 +62,26 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
scf::ReduceOp, scf::ReduceReturnOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
if (typeConverter.isLegal(op)) {
|
||||
// llvm::errs() << *op << " is dyanamically legal\n";
|
||||
return true;
|
||||
}
|
||||
// if (typeConverter.isLegal(op->getOperandTypes()))
|
||||
// llvm::errs() << *op << " is illegal with legal operands\n";
|
||||
// if (typeConverter.isLegal(op->getResultTypes())) {
|
||||
// llvm::errs() << *op << " is illegal with legal results\n";
|
||||
// llvm::errs() << "operand0: " << op->getOperand(0) << "\n"
|
||||
// << "operand1: " << op->getOperand(1) << "\n";
|
||||
// }
|
||||
return false;
|
||||
});
|
||||
|
||||
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
|
||||
if (typeConverter.isLegal(op))
|
||||
return true;
|
||||
// llvm::errs() << *op << " is illegal\n"
|
||||
// << "inside ...\n"
|
||||
// << *op->getParentOp() << "\n";
|
||||
return false;
|
||||
});
|
||||
|
||||
|
Reference in New Issue
Block a user