[BACKEND] Support splat constant on the DotOperandLayout (#1008)
This commit is contained in:
@@ -1227,20 +1227,20 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
|
||||
elif dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
# FIXME: Unsupported layout found in ConvertSplatLikeOp
|
||||
# def test_dot_without_load():
|
||||
# @triton.jit
|
||||
# def kernel(out):
|
||||
# pid = tl.program_id(axis=0)
|
||||
# a = tl.zeros((32, 32), tl.float32)
|
||||
# b = tl.zeros((32, 32), tl.float32)
|
||||
# c = tl.zeros((32, 32), tl.float32)
|
||||
# c = tl.dot(a, b)
|
||||
# pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
|
||||
# tl.store(pout, c)
|
||||
#
|
||||
# out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
|
||||
# kernel[(1,)](out)
|
||||
|
||||
def test_dot_without_load():
    """Run ``tl.dot`` on operands built in-kernel and verify the stored result.

    Both operands come from ``tl.zeros`` instead of a memory load, so the
    compiler sees constant-filled tensors feeding the dot (the case that the
    previously disabled version of this test flagged as unsupported in
    ConvertSplatLikeOp). The product of two zero matrices is zero, so every
    element written back must be 0.
    """
    @triton.jit
    def kernel(out):
        # Operands are created inside the kernel; nothing is loaded.
        a = tl.zeros((32, 32), tl.float32)
        b = tl.zeros((32, 32), tl.float32)
        c = tl.dot(a, b)
        # Row-major offsets of a 32x32 tile starting at `out`.
        pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
        tl.store(pout, c)

    # Start from ones so a kernel that fails to write (or writes garbage)
    # cannot accidentally match the expected all-zero result.
    out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
    kernel[(1,)](out)
    assert torch.all(out == 0)
|
||||
|
||||
# ---------------
|
||||
# test arange
|
||||
|
Reference in New Issue
Block a user