[BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering (#69)

* [BACKEND] two minor bugfix on StoreOpLowering and kernel launch & support optional other in LoadOpLowering

* Clean code

Co-authored-by: goostavz <gzhu@nvidia.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
This commit is contained in:
goostavz
2022-08-23 12:47:09 +08:00
committed by GitHub
parent 92ef552a54
commit de2dd04c8a
5 changed files with 70 additions and 35 deletions

View File

@@ -1,7 +1,12 @@
import torch
from torch.testing import assert_allclose
import triton
import triton.language as tl
import triton.runtime as runtime
NUM_WARPS = 4
BLOCK_SIZE = 256
# triton kernel
@@ -22,6 +27,31 @@ def test_vecadd_no_scf():
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
ptx, shem_size, kernel_name = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
print(ret)
torch.zeros([10], device=torch.device('cuda'))
device = torch.cuda.current_device()
binary = runtime.build_kernel(kernel, "*fp32,i32,*fp32,i32,*fp32,i32",
device=device,
constants={"BLOCK_SIZE_N": BLOCK_SIZE},
num_warps=NUM_WARPS,
num_stages=3)
grid = lambda META: (1, )
x = torch.randn((256,), device='cuda', dtype=torch.float32)
y = torch.randn((256,), device='cuda', dtype=torch.float32)
z = torch.empty((256,), device=x.device, dtype=x.dtype)
runtime.launch_kernel(fn=kernel,
binary=binary,
grid=grid,
num_warps=NUM_WARPS,
num_stages=3,
x_ptr=x,
stride_xn=x.stride(0),
y_ptr=y,
stride_yn=y.stride(0),
z_ptr=z,
stride_zn=z.stride(0),
BLOCK_SIZE_N=tl.constexpr(BLOCK_SIZE))
golden_z = x + y
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)