Simple assert

This commit is contained in:
Jokeren
2023-01-05 15:04:08 -05:00
parent bc73bbb12c
commit 2920f6f50f
10 changed files with 112 additions and 7 deletions

View File

@@ -408,8 +408,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
// Make PrintfOp // Make PrintfOp
// //
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>, def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
Arguments<(ins StrAttr:$prefix, Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
let summary = "Device-side printf, as in CUDA for debugging"; let summary = "Device-side printf, as in CUDA for debugging";
let description = [{ let description = [{
`tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed. `tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
@@ -420,4 +419,14 @@ def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
}]; }];
} }
//
// Make AssertOp
//
// `tt.assert`: takes a tensor-typed `$condition` operand and a string
// `$message` attribute, and declares no results.
// The op is tagged with a MemWrite effect.
// NOTE(review): presumably the MemWrite effect keeps the result-less op from
// being removed as dead code - confirm.
// NOTE(review): `description` is still empty; it should state which condition
// values make the assert fire.
def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
let summary = "Device-side assert, as in CUDA for debugging";
let description = [{}];
let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
// Textual form: <condition> , <message> <attr-dict> : <type of condition>
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
}
#endif // Triton_OPS #endif // Triton_OPS

View File

@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"

View File

@@ -270,6 +270,46 @@ struct PrintfOpConversion
} }
}; };
// Lowers `triton::AssertOp` to inline PTX.
//
// The converted condition operand is unpacked into its individual elements
// and, for every element, a `trap` instruction is emitted that is predicated
// on that element comparing equal to zero (integer `eq` for signed/signless
// integers, ordered float `oeq` otherwise).
//
// Returns failure() when an element is neither an integer nor a float.
//
// NOTE(review): the op's `message` attribute is not consumed here, so a
// failing assert traps without reporting the message.
struct AssertOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
  using ConvertTritonGPUOpToLLVMPattern<
      triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctx = rewriter.getContext();
    auto voidTy = void_ty(ctx);
    auto elems = getElementsFromStruct(loc, adaptor.condition(), rewriter);
    for (auto elem : elems) {
      auto type = elem.getType();
      // `condition` is true when the element is zero, i.e. when the asserted
      // value is false and the kernel has to trap.
      Value condition;
      if (type.isIntOrFloat()) {
        if (type.isSignedInteger() || type.isSignlessInteger()) {
          condition = icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
                                        loc, type, rewriter.getZeroAttr(type)));
        } else {
          condition = fcmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
                                        loc, type, rewriter.getZeroAttr(type)));
        }
      } else {
        assert(false && "Unsupported type for assert");
        return failure();
      }
      // MLIR::AssertOp is lowered to a call to llvm.abort, which cannot be
      // handled by ptxas
      PTXBuilder builder;
      auto &trapOp = *builder.create<PTXInstr>("trap");
      trapOp().predicate(condition);
      builder.launch(rewriter, loc, voidTy);
    }
    // `tt.assert` produces no results, so there is nothing to replace it
    // with: erase it instead of calling replaceOp with a (possibly null)
    // value, which would not match the op's zero results.
    rewriter.eraseOp(op);
    return success();
  }
};
struct MakeRangeOpConversion struct MakeRangeOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> { : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
@@ -524,4 +564,5 @@ void populateTritonGPUToLLVMPatterns(
patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit); patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
patterns.add<ReturnOpConversion>(typeConverter, benefit); patterns.add<ReturnOpConversion>(typeConverter, benefit);
patterns.add<PrintfOpConversion>(typeConverter, benefit); patterns.add<PrintfOpConversion>(typeConverter, benefit);
patterns.add<AssertOpConversion>(typeConverter, benefit);
} }

View File

@@ -45,6 +45,9 @@
#define fcmp_olt(lhs, rhs) \ #define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \ rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs) LLVM::FCmpPredicate::olt, lhs, rhs)
// Ordered floating-point equality: creates an LLVM::FCmpOp with the `oeq`
// predicate and an i1 result type. Expands against the `rewriter` and `loc`
// names, which must be in scope at the point of use.
#define fcmp_eq(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::oeq, lhs, rhs)
#define icmp_eq(...) \ #define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__) rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \ #define icmp_ne(...) \
@@ -77,6 +80,7 @@
#define f16_ty rewriter.getF16Type() #define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type() #define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8) #define i8_ty rewriter.getIntegerType(8)
#define i1_ty rewriter.getI1Type()
#define f32_ty rewriter.getF32Type() #define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type() #define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type) #define vec_ty(type, num) VectorType::get(num, type)

View File

@@ -453,10 +453,11 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
}; };
struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> { struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
using OpConversionPattern<PrintfOp>::OpConversionPattern; using OpConversionPattern<triton::PrintfOp>::OpConversionPattern;
LogicalResult LogicalResult
matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor, matchAndRewrite(triton::PrintfOp op,
typename triton::PrintfOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(), rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
adaptor.getOperands()); adaptor.getOperands());
@@ -464,6 +465,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
} }
}; };
// Conversion pattern for `triton::AssertOp`: rebuilds the op from the
// adaptor's condition operand and the original op's message attribute.
struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
  using OpConversionPattern<triton::AssertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::AssertOp op,
                  typename triton::AssertOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Operand as seen through the adaptor; the attribute is taken unchanged
    // from the op being replaced.
    auto remappedCondition = adaptor.condition();
    auto message = op.messageAttr();
    rewriter.replaceOpWithNewOp<triton::AssertOp>(op, remappedCondition,
                                                  message);
    return success();
  }
};
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext(); MLIRContext *context = patterns.getContext();
@@ -478,7 +492,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern, TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern, TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern, TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
TritonAtomicRMWPattern>(typeConverter, context); TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
} }
// //

View File

@@ -1261,6 +1261,14 @@ void init_triton_ir(py::module &&m) {
llvm::StringRef(prefix)), llvm::StringRef(prefix)),
values); values);
}) })
// create_assert(condition, message): wraps `message` in a StringAttr and
// builds a triton::AssertOp from it and `condition`, using an unknown
// location. Nothing is returned to the Python caller.
.def("create_assert",
[](mlir::OpBuilder &self, mlir::Value &condition,
const std::string &message) -> void {
auto loc = self.getUnknownLoc();
auto messageAttr = mlir::StringAttr::get(self.getContext(),
llvm::StringRef(message));
self.create<mlir::triton::AssertOp>(loc, condition, messageAttr);
})
// Undef // Undef
.def("create_undef", .def("create_undef",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value { [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {

View File

@@ -52,5 +52,21 @@ def printf(data_type):
assert_close(y, x) assert_close(y, x)
printf("float16") def assert2(data_type):
printf("int8") @triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.assert2(x == 0, "x > 0")
tl.store(Y + tl.arange(0, BLOCK), x)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = get_tensor(shape, data_type)
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
#printf("float16")
#printf("int8")
assert2("float16")

View File

@@ -11,6 +11,7 @@ from .core import (
arange, arange,
argmin, argmin,
argmax, argmax,
assert2,
atomic_add, atomic_add,
atomic_and, atomic_and,
atomic_cas, atomic_cas,
@@ -98,6 +99,7 @@ __all__ = [
"arange", "arange",
"argmin", "argmin",
"argmax", "argmax",
"assert2",
"atomic_add", "atomic_add",
"atomic_and", "atomic_and",
"atomic_cas", "atomic_cas",

View File

@@ -1253,3 +1253,9 @@ def printf(prefix, *args, _builder=None):
for arg in args: for arg in args:
new_args.append(_to_tensor(arg, _builder)) new_args.append(_to_tensor(arg, _builder))
return semantic.printf(new_prefix, new_args, _builder) return semantic.printf(new_prefix, new_args, _builder)
@builtin
def assert2(cond, msg="", _builder=None):
    """
    Emit an assertion on :code:`cond` with an optional message.

    :param cond: the condition to assert; converted with `_to_tensor` before use
    :param msg: message attached to the assertion; passed through
        `_constexpr_to_value` first
    """
    message = _constexpr_to_value(msg)
    condition = _to_tensor(cond, _builder)
    return semantic.assert2(condition, message, _builder)

View File

@@ -1170,3 +1170,7 @@ def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor
for arg in args: for arg in args:
new_args.append(arg.handle) new_args.append(arg.handle)
return tl.tensor(builder.create_printf(prefix, new_args), tl.void) return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
def assert2(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
    """Create an assert through `builder` from `cond`'s handle and `msg`.

    The builder's return value is wrapped in a `tl.tensor` of type `tl.void`.
    """
    assert_handle = builder.create_assert(cond.handle, msg)
    return tl.tensor(assert_handle, tl.void)