[CODEGEN] Pipeline fixup (#336)

This commit is contained in:
daadaada
2021-10-10 16:47:11 +08:00
committed by GitHub
parent d5f20dbce0
commit 9e9d781912
2 changed files with 42 additions and 17 deletions

View File

@@ -515,6 +515,20 @@ def test_dot(epilogue, device='cuda'):
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
def test_dot_without_load():
    """Run a kernel whose ``tl.dot`` operands are built in-kernel, not loaded.

    Both operands are 32x32 float32 tiles created with ``tl.zeros``. Their
    product is stored over an output buffer pre-filled with ones, so a
    correct run leaves every element of that buffer equal to zero.
    """
    # Only `#` comments are used inside the jitted function so that the
    # statements seen by the JIT front end are exactly the original ones.
    @triton.jit
    def kernel(out, **meta):
        # NOTE(review): `pid` and the first assignment to `c` are never read.
        # They are kept unchanged because this test targets code generation
        # and they may be part of the scenario -- confirm before removing.
        pid = tl.program_id(axis=0)
        a = tl.zeros((32, 32), tl.float32)
        b = tl.zeros((32, 32), tl.float32)
        c = tl.zeros((32, 32), tl.float32)
        c = tl.dot(a, b)
        # Row-major offsets into the 32x32 output: row * 32 + col.
        pout = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
        tl.store(pout, c)

    # Start from ones so that an unwritten element is distinguishable from
    # the expected zero result.
    out = torch.ones((32, 32), dtype=torch.float32, device="cuda")
    kernel[(1,)](out)
    # zeros @ zeros == zeros exactly, so an exact comparison is safe here.
    # Reading the result back also makes a failed launch show up in this test.
    assert torch.all(out == 0)
# ---------------
# test arange
# ---------------