More conversion patterns

This commit is contained in:
Yan Da
2022-05-04 12:50:02 +08:00
parent 5f08e2fdae
commit 3ad7bee35e
4 changed files with 69 additions and 20 deletions

View File

@@ -19,8 +19,23 @@ public:
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType()); Type retType = this->getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<Op>( Op res = rewriter.replaceOpWithNewOp<Op>(
op.getOperation(), retType, op.getLhs(), op.getRhs() op, retType, adaptor.getOperands()
);
return success();
}
};
template<class Op>
class ArithCmpPattern : public OpConversionPattern<Op> {
public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = this->getTypeConverter()->convertType(op.getType());
Op res = rewriter.replaceOpWithNewOp<Op>(
op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs()
); );
return success(); return success();
} }
@@ -51,13 +66,13 @@ void populateArithmeticPatternsAndLegality(
// non-null encoding // non-null encoding
// -------------- // --------------
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
// Legality rule // // Legality rule
target.addDynamicallyLegalDialect<arith::ArithmeticDialect>( // target.addDynamicallyLegalDialect<arith::ArithmeticDialect>(
// TODO: check above rule here // // TODO: check above rule here
[](Operation *op){ // [](Operation *op){
return true; // return true;
} // }
); // );
// Rewrite rule // Rewrite rule
// patterns.add<ConvertArithmeticOp>(typeConverter, context); // patterns.add<ConvertArithmeticOp>(typeConverter, context);
patterns.add<ArithBinaryPattern<arith::AddIOp>, patterns.add<ArithBinaryPattern<arith::AddIOp>,
@@ -89,7 +104,10 @@ void populateArithmeticPatternsAndLegality(
// Floating point // Floating point
ArithBinaryPattern<arith::MulFOp>, ArithBinaryPattern<arith::MulFOp>,
ArithBinaryPattern<arith::DivFOp>, ArithBinaryPattern<arith::DivFOp>,
ArithBinaryPattern<arith::RemFOp> ArithBinaryPattern<arith::RemFOp>,
// Cmp
ArithCmpPattern<arith::CmpIOp>,
ArithCmpPattern<arith::CmpFOp>
>(typeConverter, context); >(typeConverter, context);
} }
@@ -103,9 +121,8 @@ struct TritonMakeRangePattern : public OpConversionPattern<triton::MakeRangeOp>
LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType()); Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::MakeRangeOp>( rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
op.getOperation(), retType, op.start(), op.end() op, retType, adaptor.start(), adaptor.end()
); );
return success(); return success();
} }
@@ -118,7 +135,7 @@ struct TritonBroadcastPattern : public OpConversionPattern<triton::BroadcastOp>
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType()); Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::BroadcastOp>( rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
op.getOperation(), retType, op.src() op, retType, adaptor.src()
); );
return success(); return success();
} }
@@ -131,7 +148,7 @@ struct TritonGEPPattern : public OpConversionPattern<triton::GEPOp> {
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType()); Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::GEPOp>( rewriter.replaceOpWithNewOp<triton::GEPOp>(
op.getOperation(), retType, op.ptr(), op.offset() op, retType, adaptor.getOperands()
); );
return success(); return success();
} }
@@ -144,8 +161,21 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType()); Type retType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<triton::LoadOp>( rewriter.replaceOpWithNewOp<triton::LoadOp>(
op.getOperation(), retType, op, retType,
op.ptr(), op.mask(), op.other(), op.cache(), op.evict(), op.isVolatile() adaptor.ptr(), adaptor.mask(), adaptor.other(),
adaptor.cache(), adaptor.evict(), adaptor.isVolatile()
);
return success();
}
};
struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
using OpConversionPattern<triton::StoreOp>::OpConversionPattern;
LogicalResult matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::StoreOp>(
op, adaptor.ptr(), adaptor.value(), adaptor.mask()
); );
return success(); return success();
} }
@@ -158,7 +188,8 @@ void populateTritonPatterns(
patterns.add<TritonMakeRangePattern, patterns.add<TritonMakeRangePattern,
TritonBroadcastPattern, TritonBroadcastPattern,
TritonGEPPattern, TritonGEPPattern,
TritonLoadPattern TritonLoadPattern,
TritonStorePattern
>(typeConverter, context); >(typeConverter, context);
} }

View File

@@ -19,7 +19,11 @@ TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type
} }
void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const { void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const {
llvm_unreachable("Not implemented"); printer << "<"
<< "threadTileSize = " << getThreadTileSize()
<< ", blockTileSize = " << getBlockTileSize()
<< ", order = " << getOrder()
<< ">";
} }
mlir::Attribute mlir::Attribute

View File

@@ -26,7 +26,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
assert(numElements > numThreads); assert(numElements > numThreads);
assert(numElements % numThreads == 0); assert(numElements % numThreads == 0);
// assert no encoding? // or assert no encoding?
// Now we assume: // Now we assume:
// contiguous = 1, order = 0, 1, 2, ..., // contiguous = 1, order = 0, 1, 2, ...,
@@ -62,14 +62,26 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
scf::ReduceOp, scf::ReduceReturnOp>(); scf::ReduceOp, scf::ReduceReturnOp>();
addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) { addDynamicallyLegalDialect<arith::ArithmeticDialect>([&](Operation *op) {
if (typeConverter.isLegal(op)) if (typeConverter.isLegal(op)) {
// llvm::errs() << *op << " is dynamically legal\n";
return true; return true;
}
// if (typeConverter.isLegal(op->getOperandTypes()))
// llvm::errs() << *op << " is illegal with legal operands\n";
// if (typeConverter.isLegal(op->getResultTypes())) {
// llvm::errs() << *op << " is illegal with legal results\n";
// llvm::errs() << "operand0: " << op->getOperand(0) << "\n"
// << "operand1: " << op->getOperand(1) << "\n";
// }
return false; return false;
}); });
addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) { addDynamicallyLegalDialect<triton::TritonDialect>([&](Operation *op) {
if (typeConverter.isLegal(op)) if (typeConverter.isLegal(op))
return true; return true;
// llvm::errs() << *op << " is illegal\n"
// << "inside ...\n"
// << *op->getParentOp() << "\n";
return false; return false;
}); });

View File

@@ -41,5 +41,7 @@ z = torch.empty_like(x)
# print(add_kernel[(1,)].kernel.compile_to_ttir()) # print(add_kernel[(1,)].kernel.compile_to_ttir())
# print(add_kernel.annotations) # print(add_kernel.annotations)
mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,)) mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,))
assert mod.verify()
mod.dump()
mod = add_kernel.compile_ttir_to_llir(mod, ctx) mod = add_kernel.compile_ttir_to_llir(mod, ctx)
mod.dump() mod.dump()