Simple assert
This commit is contained in:
@@ -1261,6 +1261,14 @@ void init_triton_ir(py::module &&m) {
|
||||
llvm::StringRef(prefix)),
|
||||
values);
|
||||
})
|
||||
.def("create_assert",
|
||||
[](mlir::OpBuilder &self, mlir::Value &condition,
|
||||
const std::string &message) -> void {
|
||||
auto loc = self.getUnknownLoc();
|
||||
auto messageAttr = mlir::StringAttr::get(self.getContext(),
|
||||
llvm::StringRef(message));
|
||||
self.create<mlir::triton::AssertOp>(loc, condition, messageAttr);
|
||||
})
|
||||
// Undef
|
||||
.def("create_undef",
|
||||
[](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
|
||||
|
@@ -52,5 +52,21 @@ def printf(data_type):
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
printf("float16")
|
||||
printf("int8")
|
||||
def assert2(data_type):
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
tl.assert2(x == 0, "x > 0")
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
shape = (128, )
|
||||
# limit the range of integers so that the sum does not overflow
|
||||
x = get_tensor(shape, data_type)
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
kernel[(1,)](x, y, BLOCK=shape[0])
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
#printf("float16")
|
||||
#printf("int8")
|
||||
assert2("float16")
|
@@ -11,6 +11,7 @@ from .core import (
|
||||
arange,
|
||||
argmin,
|
||||
argmax,
|
||||
assert2,
|
||||
atomic_add,
|
||||
atomic_and,
|
||||
atomic_cas,
|
||||
@@ -98,6 +99,7 @@ __all__ = [
|
||||
"arange",
|
||||
"argmin",
|
||||
"argmax",
|
||||
"assert2",
|
||||
"atomic_add",
|
||||
"atomic_and",
|
||||
"atomic_cas",
|
||||
|
@@ -1253,3 +1253,9 @@ def printf(prefix, *args, _builder=None):
|
||||
for arg in args:
|
||||
new_args.append(_to_tensor(arg, _builder))
|
||||
return semantic.printf(new_prefix, new_args, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def assert2(cond, msg="", _builder=None):
|
||||
msg = _constexpr_to_value(msg)
|
||||
return semantic.assert2(_to_tensor(cond, _builder), msg, _builder)
|
||||
|
@@ -1170,3 +1170,7 @@ def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor
|
||||
for arg in args:
|
||||
new_args.append(arg.handle)
|
||||
return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
|
||||
|
||||
|
||||
def assert2(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
|
||||
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
|
||||
|
Reference in New Issue
Block a user