diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index ef9597318..f938697e9 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -408,8 +408,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
 // Make PrintfOp
 //
 def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
-  Arguments<(ins StrAttr:$prefix,
-                Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
+  Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
   let summary = "Device-side printf, as in CUDA for debugging";
   let description = [{
     `tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
@@ -420,4 +419,17 @@ def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
   }];
 }
 
+//
+// Make AssertOp
+//
+def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
+  let summary = "Device-side assert, as in CUDA for debugging";
+  let description = [{
+    `tt.assert` takes a condition tensor and a literal message string.
+    The kernel is aborted if any element of the condition is zero.
+  }];
+  let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
+  let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
+}
+
 #endif // Triton_OPS
diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
index 2a360f4b1..66ddc66c7 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 2261688f0..020cdf45d 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -270,6 +270,46 @@ struct PrintfOpConversion
   }
 };
 
+// Lowers tt.assert to a predicated PTX `trap`: the kernel is aborted if any
+// element of the condition is zero. The message attribute is not emitted by
+// this lowering.
+struct AssertOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto voidTy = void_ty(rewriter.getContext());
+    auto elems = getElementsFromStruct(loc, adaptor.condition(), rewriter);
+    for (auto elem : elems) {
+      auto type = elem.getType();
+      const bool isInteger = type.isa<IntegerType>();
+      if (!isInteger && !type.isa<FloatType>())
+        return op.emitError("unsupported element type for tt.assert");
+      Value zero = rewriter.create<LLVM::ConstantOp>(
+          loc, type, rewriter.getZeroAttr(type));
+      // The assertion fails for this element when it compares equal to zero.
+      Value failed;
+      if (isInteger)
+        failed = icmp_eq(elem, zero);
+      else
+        failed = fcmp_eq(elem, zero);
+      // MLIR::AssertOp is lowered to a call to llvm.abort, which cannot be
+      // handled by ptxas, so emit a predicated `trap` instead.
+      PTXBuilder builder;
+      auto &trapOp = *builder.create("trap");
+      trapOp().predicate(failed);
+      builder.launch(rewriter, loc, voidTy);
+    }
+    // tt.assert has no results: erase it instead of replacing it.
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct MakeRangeOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
 
@@ -526,2 +566,3 @@ void populateTritonGPUToLLVMPatterns(
   patterns.add<PrintfOpConversion>(typeConverter, benefit);
+  patterns.add<AssertOpConversion>(typeConverter, benefit);
 }
\ No newline at end of file
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h
index 4d7c558fc..d8a0bd281 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.h
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h
@@ -45,6 +45,9 @@
 #define fcmp_olt(lhs, rhs)                                                     \
   rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
                                 LLVM::FCmpPredicate::olt, lhs, rhs)
+#define fcmp_eq(lhs, rhs)                                                      \
+  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
+                                LLVM::FCmpPredicate::oeq, lhs, rhs)
 #define icmp_eq(...)                                                           \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
 #define icmp_ne(...)                                                           \
@@ -77,6 +80,7 @@
 #define f16_ty rewriter.getF16Type()
 #define bf16_ty rewriter.getBF16Type()
 #define i8_ty rewriter.getIntegerType(8)
+#define i1_ty rewriter.getI1Type()
 #define f32_ty rewriter.getF32Type()
 #define f64_ty rewriter.getF64Type()
 #define vec_ty(type, num) VectorType::get(num, type)
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index ce5698289..8d1dfe3a0 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -453,10 +453,11 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
 };
 
 struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
-  using OpConversionPattern<PrintfOp>::OpConversionPattern;
+  using OpConversionPattern<triton::PrintfOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
+  matchAndRewrite(triton::PrintfOp op,
+                  typename triton::PrintfOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
                                                   adaptor.getOperands());
@@ -464,6 +465,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
   }
 };
 
+struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
+  using OpConversionPattern<triton::AssertOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AssertOp op,
+                  typename triton::AssertOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<triton::AssertOp>(op, adaptor.condition(),
+                                                  op.messageAttr());
+    return success();
+  }
+};
+
 void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                             RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
@@ -478,7 +492,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
       TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
       TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
-      TritonAtomicRMWPattern>(typeConverter, context);
+      TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
 }
 
 //
diff --git a/python/src/triton.cc b/python/src/triton.cc
index 31f754add..a3a47f8ad 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1261,6 +1261,14 @@ void init_triton_ir(py::module &&m) {
                                        llvm::StringRef(prefix)),
                  values);
            })
+      .def("create_assert",
+           [](mlir::OpBuilder &self, mlir::Value &condition,
+              const std::string &message) -> void {
+             auto loc = self.getUnknownLoc();
+             auto messageAttr = mlir::StringAttr::get(self.getContext(),
+                                                      llvm::StringRef(message));
+             self.create<mlir::triton::AssertOp>(loc, condition, messageAttr);
+           })
       // Undef
       .def("create_undef",
            [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
diff --git a/python/test/unit/language/printf_helper.py b/python/test/unit/language/printf_helper.py
index 22e1350f1..9dfef4c23 100644
--- a/python/test/unit/language/printf_helper.py
+++ b/python/test/unit/language/printf_helper.py
@@ -52,5 +52,20 @@ def printf(data_type):
     assert_close(y, x)
 
 
+def assert2(data_type):
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        tl.assert2(x == 0, "x != 0")
+        tl.store(Y + tl.arange(0, BLOCK), x)
+
+    shape = (128, )
+    x = get_tensor(shape, data_type)
+    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
+    kernel[(1,)](x, y, BLOCK=shape[0])
+    assert_close(y, x)
+
+
 printf("float16")
 printf("int8")
+assert2("float16")
diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py
index 038e26bbe..9b374a852 100644
--- a/python/triton/language/__init__.py
+++ b/python/triton/language/__init__.py
@@ -11,6 +11,7 @@ from .core import (
     arange,
     argmin,
     argmax,
+    assert2,
     atomic_add,
     atomic_and,
     atomic_cas,
@@ -98,6 +99,7 @@ __all__ = [
     "arange",
     "argmin",
     "argmax",
+    "assert2",
     "atomic_add",
     "atomic_and",
     "atomic_cas",
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index 15dd8462a..4edad9c01 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -1253,3 +1253,15 @@ def printf(prefix, *args, _builder=None):
     for arg in args:
         new_args.append(_to_tensor(arg, _builder))
     return semantic.printf(new_prefix, new_args, _builder)
+
+
+@builtin
+def assert2(cond, msg="", _builder=None):
+    """
+    Device-side assert: aborts the kernel if any element of :code:`cond` is zero.
+
+    :param cond: the condition to check
+    :param msg: the message associated with the assertion
+    """
+    msg = _constexpr_to_value(msg)
+    return semantic.assert2(_to_tensor(cond, _builder), msg, _builder)
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index d65931b48..fff3791e4 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1170,3 +1170,7 @@ def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor
     for arg in args:
         new_args.append(arg.handle)
     return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
+
+
+def assert2(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
+    return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)