Simple assert

This commit is contained in:
Jokeren
2023-01-05 15:04:08 -05:00
parent bc73bbb12c
commit 2920f6f50f
10 changed files with 112 additions and 7 deletions

View File

@@ -1261,6 +1261,14 @@ void init_triton_ir(py::module &&m) {
llvm::StringRef(prefix)),
values);
})
.def("create_assert",
[](mlir::OpBuilder &self, mlir::Value &condition,
const std::string &message) -> void {
auto loc = self.getUnknownLoc();
auto messageAttr = mlir::StringAttr::get(self.getContext(),
llvm::StringRef(message));
self.create<mlir::triton::AssertOp>(loc, condition, messageAttr);
})
// Undef
.def("create_undef",
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {

View File

@@ -52,5 +52,21 @@ def printf(data_type):
assert_close(y, x)
printf("float16")
printf("int8")
def assert2(data_type):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.assert2(x == 0, "x > 0")
tl.store(Y + tl.arange(0, BLOCK), x)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = get_tensor(shape, data_type)
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
#printf("float16")
#printf("int8")
assert2("float16")

View File

@@ -11,6 +11,7 @@ from .core import (
arange,
argmin,
argmax,
assert2,
atomic_add,
atomic_and,
atomic_cas,
@@ -98,6 +99,7 @@ __all__ = [
"arange",
"argmin",
"argmax",
"assert2",
"atomic_add",
"atomic_and",
"atomic_cas",

View File

@@ -1253,3 +1253,9 @@ def printf(prefix, *args, _builder=None):
for arg in args:
new_args.append(_to_tensor(arg, _builder))
return semantic.printf(new_prefix, new_args, _builder)
@builtin
def assert2(cond, msg="", _builder=None):
msg = _constexpr_to_value(msg)
return semantic.assert2(_to_tensor(cond, _builder), msg, _builder)

View File

@@ -1170,3 +1170,7 @@ def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor
for arg in args:
new_args.append(arg.handle)
return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
def assert2(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)