Simple assert
@@ -408,8 +408,7 @@ def TT_MakeRangeOp : TT_Op<"make_range", [NoSideEffect]> {
 // Make PrintfOp
 //
 def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
-                  Arguments<(ins StrAttr:$prefix,
-                                 Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
+                  Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
   let summary = "Device-side printf, as in CUDA for debugging";
   let description = [{
     `tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
@@ -420,4 +419,14 @@ def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
   }];
 }
 
+//
+// Make AssertOp
+//
+def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
+  let summary = "Device-side assert, as in CUDA for debugging";
+  let description = [{}];
+  let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
+  let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
+}
+
 #endif // Triton_OPS
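For illustration, here is a rough host-side analogue of what the new `tt.assert` op is meant to check (an assumption for clarity, not part of this commit): every element of $condition should be non-zero, otherwise execution aborts, with $message carried as an attribute for debugging.

import torch

def host_assert(condition: torch.Tensor, message: str) -> None:
    # Abort if any element of the condition tensor is zero/false.
    assert bool(condition.to(torch.bool).all()), message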
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -270,6 +270,46 @@ struct PrintfOpConversion
   }
 };
 
+struct AssertOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctx = rewriter.getContext();
+    auto voidTy = void_ty(ctx);
+    auto elems = getElementsFromStruct(loc, adaptor.condition(), rewriter);
+    Value ret;
+    for (auto elem : elems) {
+      auto type = elem.getType();
+      Value condition;
+      if (type.isIntOrFloat()) {
+        if (type.isSignedInteger() || type.isSignlessInteger()) {
+          condition = icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
+                                        loc, type, rewriter.getZeroAttr(type)));
+        } else {
+          condition = fcmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
+                                        loc, type, rewriter.getZeroAttr(type)));
+        }
+      } else {
+        assert(false && "Unsupported type for assert");
+        return failure();
+      }
+      // MLIR::AssertOp is lowered to a call to llvm.abort, which cannot be
+      // handled by ptxas
+      PTXBuilder builder;
+      auto &trapOp = *builder.create<PTXInstr>("trap");
+      trapOp().predicate(condition);
+      ret = builder.launch(rewriter, loc, voidTy);
+    }
+    rewriter.replaceOp(op, ret);
+    return success();
+  }
+};
+
 struct MakeRangeOpConversion
     : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {
 
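The conversion above replaces the Triton-level assert with a predicated PTX trap: each element of the lowered condition struct is compared against zero (icmp_eq for integers, fcmp_eq for floats), and the trap fires for any element that equals zero, i.e. whenever the assertion is false. A minimal Python sketch of those semantics (illustration only, not the actual C++ pattern):

def lowered_assert(condition_elements) -> None:
    # The op also carries a message attribute, but this lowering only emits the
    # predicated trap; the message itself is not printed by the generated PTX.
    for elem in condition_elements:
        if elem == 0:  # icmp_eq / fcmp_eq against a zero constant
            raise RuntimeError("device-side assert triggered")  # predicated `trap`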
@@ -524,4 +564,5 @@ void populateTritonGPUToLLVMPatterns(
   patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
   patterns.add<ReturnOpConversion>(typeConverter, benefit);
   patterns.add<PrintfOpConversion>(typeConverter, benefit);
+  patterns.add<AssertOpConversion>(typeConverter, benefit);
 }
@@ -45,6 +45,9 @@
 #define fcmp_olt(lhs, rhs)                                                     \
   rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
                                 LLVM::FCmpPredicate::olt, lhs, rhs)
+#define fcmp_eq(lhs, rhs)                                                      \
+  rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                     \
+                                LLVM::FCmpPredicate::oeq, lhs, rhs)
 #define icmp_eq(...)                                                           \
   rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
 #define icmp_ne(...)                                                           \
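The new fcmp_eq macro uses the ordered-equal predicate (oeq). Under standard IEEE semantics an ordered comparison is false when either operand is NaN, so a NaN condition element would not compare equal to zero and thus would not trigger the trap in the lowering above. A tiny Python check illustrating the same comparison behaviour (illustration only):

import math

nan = float("nan")
print(nan == 0.0)       # False: NaN never compares equal to zero
print(math.isnan(nan))  # True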
@@ -77,6 +80,7 @@
 #define f16_ty rewriter.getF16Type()
 #define bf16_ty rewriter.getBF16Type()
 #define i8_ty rewriter.getIntegerType(8)
+#define i1_ty rewriter.getI1Type()
 #define f32_ty rewriter.getF32Type()
 #define f64_ty rewriter.getF64Type()
 #define vec_ty(type, num) VectorType::get(num, type)
@@ -453,10 +453,11 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
 };
 
 struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
-  using OpConversionPattern<PrintfOp>::OpConversionPattern;
+  using OpConversionPattern<triton::PrintfOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
+  matchAndRewrite(triton::PrintfOp op,
+                  typename triton::PrintfOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.prefixAttr(),
                                                   adaptor.getOperands());
@@ -464,6 +465,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
   }
 };
 
+struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
+  using OpConversionPattern<triton::AssertOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AssertOp op,
+                  typename triton::AssertOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<triton::AssertOp>(op, adaptor.condition(),
+                                                  op.messageAttr());
+    return success();
+  }
+};
+
 void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                             RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
@@ -478,7 +492,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
       TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
       TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
-      TritonAtomicRMWPattern>(typeConverter, context);
+      TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
 }
 
 //
@@ -1261,6 +1261,14 @@ void init_triton_ir(py::module &&m) {
                                llvm::StringRef(prefix)),
                            values);
            })
+      .def("create_assert",
+           [](mlir::OpBuilder &self, mlir::Value &condition,
+              const std::string &message) -> void {
+             auto loc = self.getUnknownLoc();
+             auto messageAttr = mlir::StringAttr::get(self.getContext(),
+                                                      llvm::StringRef(message));
+             self.create<mlir::triton::AssertOp>(loc, condition, messageAttr);
+           })
       // Undef
       .def("create_undef",
            [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
@@ -52,5 +52,21 @@ def printf(data_type):
     assert_close(y, x)
 
 
-printf("float16")
-printf("int8")
+def assert2(data_type):
+    @triton.jit
+    def kernel(X, Y, BLOCK: tl.constexpr):
+        x = tl.load(X + tl.arange(0, BLOCK))
+        tl.assert2(x == 0, "x > 0")
+        tl.store(Y + tl.arange(0, BLOCK), x)
+
+    shape = (128, )
+    # limit the range of integers so that the sum does not overflow
+    x = get_tensor(shape, data_type)
+    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
+    kernel[(1,)](x, y, BLOCK=shape[0])
+    assert_close(y, x)
+
+
+#printf("float16")
+#printf("int8")
+assert2("float16")
@@ -11,6 +11,7 @@ from .core import (
     arange,
     argmin,
     argmax,
+    assert2,
    atomic_add,
    atomic_and,
    atomic_cas,
@@ -98,6 +99,7 @@ __all__ = [
     "arange",
     "argmin",
     "argmax",
+    "assert2",
    "atomic_add",
    "atomic_and",
    "atomic_cas",
@@ -1253,3 +1253,9 @@ def printf(prefix, *args, _builder=None):
     for arg in args:
         new_args.append(_to_tensor(arg, _builder))
     return semantic.printf(new_prefix, new_args, _builder)
+
+
+@builtin
+def assert2(cond, msg="", _builder=None):
+    msg = _constexpr_to_value(msg)
+    return semantic.assert2(_to_tensor(cond, _builder), msg, _builder)
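As a usage sketch, a hypothetical kernel calling the new builtin (the names bounded_copy, X, Y, and BLOCK are illustrative; msg defaults to the empty string):

import triton
import triton.language as tl

@triton.jit
def bounded_copy(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    # Trap if any loaded element falls outside the expected range.
    tl.assert2(x < 1024, "x out of range")
    tl.store(Y + tl.arange(0, BLOCK), x)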
@@ -1170,3 +1170,7 @@ def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor
     for arg in args:
         new_args.append(arg.handle)
     return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
+
+
+def assert2(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
+    return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)