Simple assert

This commit is contained in:
Jokeren
2023-01-05 15:04:08 -05:00
parent bc73bbb12c
commit 2920f6f50f
10 changed files with 112 additions and 7 deletions

View File

@@ -52,5 +52,21 @@ def printf(data_type):
assert_close(y, x)
printf("float16")
printf("int8")
def assert2(data_type):
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
tl.assert2(x == 0, "x > 0")
tl.store(Y + tl.arange(0, BLOCK), x)
shape = (128, )
# limit the range of integers so that the sum does not overflow
x = get_tensor(shape, data_type)
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
kernel[(1,)](x, y, BLOCK=shape[0])
assert_close(y, x)
#printf("float16")
#printf("int8")
assert2("float16")