diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 9e3431ffc..eefccddb6 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -3,11 +3,14 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" // NoSideEffect include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +def TT_BoolTensor : TensorOf<[I1]>; + class TTG_Op traits = []> : Op; @@ -28,4 +31,29 @@ def TTG_AsyncWaitOp : TTG_Op<"async_wait"> { // def TTG_CopyAsyncOp : TTG_Op<"copy_async"> {} +// Port Arith_CmpIOp & Arith_CmpFOp to TritonGPU. +def TTG_CmpIOp : TTG_Op<"cmpi"> { + let summary = "integer comparison operation"; + + let description = [{}]; + + let arguments = (ins Arith_CmpIPredicateAttr:$predicate, + TT_IntegerTensor:$lhs, + TT_IntegerTensor:$rhs); + + let results = (outs TT_BoolTensor:$result); +} + +def TTG_CmpFOp : TTG_Op<"cmpf"> { + let summary = "floating-point comparison operation"; + + let description = [{}]; + + let arguments = (ins Arith_CmpFPredicateAttr:$predicate, + TT_FloatTensor:$lhs, + TT_FloatTensor:$rhs); + + let results = (outs TT_BoolTensor:$result); +} + #endif diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp index bf35f393d..dc372e27c 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp @@ -26,15 +26,15 @@ public: } }; -template -class ArithCmpPattern : public OpConversionPattern { +template +class ArithCmpPattern : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, + LogicalResult matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = this->getTypeConverter()->convertType(op.getType()); - Op res = rewriter.replaceOpWithNewOp( + DstOp res = rewriter.replaceOpWithNewOp( op, retType, adaptor.getPredicate(), adaptor.getLhs(), adaptor.getRhs() ); return success(); @@ -106,8 +106,10 @@ void populateArithmeticPatternsAndLegality( ArithBinaryPattern, ArithBinaryPattern, // Cmp - ArithCmpPattern, - ArithCmpPattern + // ArithCmpPattern, + // ArithCmpPattern + ArithCmpPattern, + ArithCmpPattern >(typeConverter, context); } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index a8f9e4880..e5416bf19 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -13,14 +13,14 @@ namespace triton { static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); if (auto tensorType = type.dyn_cast()) - return RankedTensorType::get(tensorType.getShape(), i1Type); + return RankedTensorType::get(tensorType.getShape(), i1Type, tensorType.getEncoding()); return Type(); } static Type getI32SameShape(Type type) { auto i32Type = IntegerType::get(type.getContext(), 32); if (auto tensorType = type.dyn_cast()) - return RankedTensorType::get(tensorType.getShape(), i32Type); + return RankedTensorType::get(tensorType.getShape(), i32Type, tensorType.getEncoding()); return Type(); } @@ -29,7 +29,7 @@ static Type getPointerTypeFromTensor(Type type) { Type elementType = tensorType.getElementType(); auto shape = tensorType.getShape(); PointerType ptrType = PointerType::get(elementType, 1); - return RankedTensorType::get(shape, ptrType); + return RankedTensorType::get(shape, ptrType, tensorType.getEncoding()); } return Type(); } diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index b596f77b9..c4fc36861 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -53,35 +53,27 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, TritonGPUConversionTarget::TritonGPUConversionTarget( MLIRContext &context, TritonGPUTypeConverter &typeConverter) : ConversionTarget(context), typeConverter(typeConverter) { - addLegalDialect(); + addLegalDialect(); // Some ops from SCF are illegal addIllegalOp(); addDynamicallyLegalDialect([&](Operation *op) { - if (typeConverter.isLegal(op)) { - // llvm::errs() << *op << " is dyanamically legal\n"; + if (typeConverter.isLegal(op)) return true; - } - // if (typeConverter.isLegal(op->getOperandTypes())) - // llvm::errs() << *op << " is illegal with legal operands\n"; - // if (typeConverter.isLegal(op->getResultTypes())) { - // llvm::errs() << *op << " is illegal with legal results\n"; - // llvm::errs() << "operand0: " << op->getOperand(0) << "\n" - // << "operand1: " << op->getOperand(1) << "\n"; - // } return false; }); addDynamicallyLegalDialect([&](Operation *op) { if (typeConverter.isLegal(op)) return true; - // llvm::errs() << *op << " is illegal\n" - // << "inside ...\n" - // << *op->getParentOp() << "\n"; + return false; + }); + + addDynamicallyLegalDialect([&](Operation *op) { + if (typeConverter.isLegal(op)) + return true; return false; });