diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index eefccddb6..8dd25c94b 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -11,6 +11,10 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
 
 def TT_BoolTensor : TensorOf<[I1]>;
 
+def TT_BoolLike : AnyTypeOf<[I1, TT_BoolTensor]>;
+def TT_IntegerLike : AnyTypeOf<[TT_Int, TT_IntegerTensor]>;
+def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
+
 class TTG_Op<string mnemonic, list<Trait> traits = []> :
     Op<TritonGPU_Dialect, mnemonic, traits>;
 
@@ -38,10 +42,10 @@
 def TTG_CmpIOp : TTG_Op<"cmpi"> {
   let description = [{}];
 
   let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
-                       TT_IntegerTensor:$lhs,
-                       TT_IntegerTensor:$rhs);
+                       TT_IntegerLike:$lhs,
+                       TT_IntegerLike:$rhs);
 
-  let results = (outs TT_BoolTensor:$result);
+  let results = (outs TT_BoolLike:$result);
 }
 
@@ -50,10 +54,10 @@
 def TTG_CmpFOp : TTG_Op<"cmpf"> {
   let description = [{}];
 
   let arguments = (ins Arith_CmpFPredicateAttr:$predicate,
-                       TT_FloatTensor:$lhs,
-                       TT_FloatTensor:$rhs);
+                       TT_FloatLike:$lhs,
+                       TT_FloatLike:$rhs);
 
-  let results = (outs TT_BoolTensor:$result);
+  let results = (outs TT_BoolLike:$result);
 }
 
 #endif
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
index dc372e27c..bb1af4b24 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPU.cpp
@@ -41,6 +41,20 @@ public:
   }
 };
 
+class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
+public:
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
+    Type retType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+      op, retType, adaptor.getValue()
+    );
+    return success();
+  }
+};
+
 class ConvertArithmeticOp: public ConversionPattern {
 public:
   ConvertArithmeticOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context)
@@ -75,7 +89,8 @@ void populateArithmeticPatternsAndLegality(
   // );
   // Rewrite rule
   // patterns.add<ConvertArithmeticOp>(typeConverter, context);
-  patterns.add<ArithBinaryPattern<arith::AddIOp>,
+  patterns.add<ArithConstantPattern,
+               ArithBinaryPattern<arith::AddIOp>,
                ArithBinaryPattern<arith::SubIOp>,
                ArithBinaryPattern<arith::MulIOp>,
                ArithBinaryPattern<arith::DivUIOp>,
@@ -106,10 +121,9 @@ void populateArithmeticPatternsAndLegality(
                ArithBinaryPattern<arith::DivFOp>,
                ArithBinaryPattern<arith::RemFOp>,
                // Cmp
-               // ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
-               // ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
+               ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
+               ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>
+               // Cast Ops
                >(typeConverter, context);
 }
 
@@ -205,7 +219,7 @@ public:
   ModuleOp mod = getOperation();
   // int numThreads = mod.getAttr();
   // type converter
-  TritonGPUTypeConverter typeConverter(context, /*numThreads*/128);
+  TritonGPUTypeConverter typeConverter(context, /*numThreads*/32);
   TritonGPUConversionTarget target(*context, typeConverter);
   // rewrite patterns
   RewritePatternSet patterns(context);
diff --git a/rewrite-test/jit/matmul/matmul.py b/rewrite-test/jit/matmul/matmul.py
index 05c7173d5..1e909dd50 100644
--- a/rewrite-test/jit/matmul/matmul.py
+++ b/rewrite-test/jit/matmul/matmul.py
@@ -90,19 +90,14 @@ mod, ctx = matmul_kernel.compile_to_ttir(
     a.stride(0), a.stride(1),
     b.stride(0), b.stride(1),
     c.stride(0), c.stride(1),
-    64, 64, 32,
+    128, 128, 128,
     8,
     grid=(2,)
 )
-# assert mod.verify()
-# mod.dump()
+assert mod.verify()
+mod.dump()
 
-pm = _triton.ir.pass_manager(ctx)
-pm.add_inliner_pass()
-pm.add_triton_combine_pass()
-pm.add_canonicalizer_pass()
-pm.add_convert_triton_to_tritongpu_pass()
-pm.run(mod)
+mod = matmul_kernel.compile_ttir_to_llir(mod, ctx)
 
-# assert mod.verify()
-# mod.dump()
+assert mod.verify()
+mod.dump()
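
Note on the `ArithCmpPattern` entries enabled above: the pattern class itself lies outside the hunk context, so its body is not shown in this patch. As a rough sketch of what a type-converting comparison pattern along the lines of `ArithConstantPattern` could look like (the template shape and the `getPredicate`/`getLhs`/`getRhs` accessor names are assumptions here, not taken from this patch):

```cpp
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Sketch only: rewrite an arith comparison into its TritonGPU counterpart
// (e.g. arith::CmpIOp -> triton::gpu::CmpIOp), running the result type
// through the type converter so the boolean result picks up an encoding.
template <class SrcOp, class DstOp>
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Same recipe as ArithConstantPattern above: convert the result type,
    // then rebuild the op from the adaptor's already-converted operands.
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
                                       adaptor.getLhs(), adaptor.getRhs());
    return success();
  }
};
```

Targeting `triton_gpu.cmpi`/`triton_gpu.cmpf` rather than the arith ops is also what motivates the `TT_BoolLike`/`TT_IntegerLike`/`TT_FloatLike` relaxation in the .td hunk: once scalar comparisons flow through the same pattern, the ops must accept scalar operands and produce a scalar `i1` as well as boolean tensors.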