From 3ad7bee35ee74d3c38e8d00b900720e83ed8a691 Mon Sep 17 00:00:00 2001 From: Yan Da Date: Wed, 4 May 2022 12:50:02 +0800 Subject: [PATCH] More conversion patterns --- .../TritonToTritonGPU/TritonToTritonGPU.cpp | 65 ++++++++++++++----- lib/Dialect/TritonGPU/IR/Dialect.cpp | 6 +- .../Transforms/TritonGPUConversion.cpp | 16 ++++- rewrite-test/jit/vecadd.py | 2 + 4 files changed, 69 insertions(+), 20 deletions(-) diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index 7be905fab..bf35f393d 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -19,8 +19,23 @@ public: LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp( - op.getOperation(), retType, op.getLhs(), op.getRhs() + Op res = rewriter.replaceOpWithNewOp( + op, retType, adaptor.getOperands() + ); + return success(); + } +}; + +template +class ArithCmpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type retType = this->getTypeConverter()->convertType(op.getType()); + Op res = rewriter.replaceOpWithNewOp( + op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs() ); return success(); } @@ -51,13 +66,13 @@ void populateArithmeticPatternsAndLegality( // non-null encoding // -------------- MLIRContext *context = patterns.getContext(); - // Legality rule - target.addDynamicallyLegalDialect( - // TODO: check above rule here - [](Operation *op){ - return true; - } - ); + // // Legality rule + // target.addDynamicallyLegalDialect( + // // TODO: check above rule here + // [](Operation *op){ + // return true; + // } + 
// ); // Rewrite rule // patterns.add(typeConverter, context); patterns.add, @@ -89,7 +104,10 @@ void populateArithmeticPatternsAndLegality( // Floating point ArithBinaryPattern, ArithBinaryPattern, - ArithBinaryPattern + ArithBinaryPattern, + // Cmp + ArithCmpPattern, + ArithCmpPattern >(typeConverter, context); } @@ -103,9 +121,8 @@ struct TritonMakeRangePattern : public OpConversionPattern LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp( - op.getOperation(), retType, op.start(), op.end() + op, retType, adaptor.start(), adaptor.end() ); return success(); } @@ -118,7 +135,7 @@ struct TritonBroadcastPattern : public OpConversionPattern ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( - op.getOperation(), retType, op.src() + op, retType, adaptor.src() ); return success(); } @@ -131,7 +148,7 @@ struct TritonGEPPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( - op.getOperation(), retType, op.ptr(), op.offset() + op, retType, adaptor.getOperands() ); return success(); } @@ -144,8 +161,21 @@ struct TritonLoadPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp( - op.getOperation(), retType, - op.ptr(), op.mask(), op.other(), op.cache(), op.evict(), op.isVolatile() + op, retType, + adaptor.ptr(), adaptor.mask(), adaptor.other(), + adaptor.cache(), adaptor.evict(), adaptor.isVolatile() + ); + return success(); + } +}; + +struct TritonStorePattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult 
matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, adaptor.ptr(), adaptor.value(), adaptor.mask() ); return success(); } @@ -158,7 +188,8 @@ void populateTritonPatterns( patterns.add(typeConverter, context); } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index dd877d046..c1fb8a44a 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -19,7 +19,11 @@ TritonGPUDistributedEncodingAttr::parse(mlir::AsmParser &parser, mlir::Type type } void TritonGPUDistributedEncodingAttr::print(mlir::AsmPrinter &printer) const { - llvm_unreachable("Not implemented"); + printer << "<" + << "threadTileSize = " << getThreadTileSize() + << ", blockTileSize = " << getBlockTileSize() + << ", order = " << getOrder() + << ">"; } mlir::Attribute diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 36162c1b2..b596f77b9 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -26,7 +26,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, assert(numElements > numThreads); assert(numElements % numThreads == 0); - // assert no encoding? + // or assert no encoding? 
// Now we assume: // contiguous = 1, order = 0, 1, 2, ..., @@ -62,14 +62,26 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( scf::ReduceOp, scf::ReduceReturnOp>(); addDynamicallyLegalDialect([&](Operation *op) { - if (typeConverter.isLegal(op)) + if (typeConverter.isLegal(op)) { + // llvm::errs() << *op << " is dynamically legal\n"; return true; + } + // if (typeConverter.isLegal(op->getOperandTypes())) + // llvm::errs() << *op << " is illegal with legal operands\n"; + // if (typeConverter.isLegal(op->getResultTypes())) { + // llvm::errs() << *op << " is illegal with legal results\n"; + // llvm::errs() << "operand0: " << op->getOperand(0) << "\n" + // << "operand1: " << op->getOperand(1) << "\n"; + // } return false; }); addDynamicallyLegalDialect([&](Operation *op) { if (typeConverter.isLegal(op)) return true; + // llvm::errs() << *op << " is illegal\n" + // << "inside ...\n" + // << *op->getParentOp() << "\n"; return false; }); diff --git a/rewrite-test/jit/vecadd.py b/rewrite-test/jit/vecadd.py index 42a95424a..758b99572 100644 --- a/rewrite-test/jit/vecadd.py +++ b/rewrite-test/jit/vecadd.py @@ -41,5 +41,7 @@ z = torch.empty_like(x) # print(add_kernel[(1,)].kernel.compile_to_ttir()) # print(add_kernel.annotations) mod, ctx = add_kernel.compile_to_ttir(x, y, z, size, BLOCK_SIZE=256, grid=(1,)) +assert mod.verify() +mod.dump() mod = add_kernel.compile_ttir_to_llir(mod, ctx) mod.dump()