[BACKEND] Two minor bug fixes in StoreOpLowering and kernel launch; support optional `other` in LoadOpLowering (#69)
* [BACKEND] Two minor bug fixes in StoreOpLowering and kernel launch; support optional `other` in LoadOpLowering
* Clean code

Co-authored-by: goostavz <gzhu@nvidia.com>
Co-authored-by: Yan Chunwei <yanchunwei@outlook.com>
This commit is contained in:
@@ -1,7 +1,12 @@
|
||||
import torch
|
||||
from torch.testing import assert_allclose
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton.runtime as runtime
|
||||
|
||||
NUM_WARPS = 4
|
||||
BLOCK_SIZE = 256
|
||||
|
||||
# triton kernel
|
||||
|
||||
@@ -22,6 +27,31 @@ def test_vecadd_no_scf():
|
||||
z_ptrs = z_ptr + offset
|
||||
tl.store(z_ptrs, z)
|
||||
|
||||
ret = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
|
||||
ptx, shem_size, kernel_name = triton.compile(kernel, "*fp32,i32,*fp32,i32,*fp32,i32", constants={"BLOCK_SIZE_N": 256}, num_warps=NUM_WARPS, device=0, output="ptx")
|
||||
|
||||
print(ret)
|
||||
torch.zeros([10], device=torch.device('cuda'))
|
||||
device = torch.cuda.current_device()
|
||||
binary = runtime.build_kernel(kernel, "*fp32,i32,*fp32,i32,*fp32,i32",
|
||||
device=device,
|
||||
constants={"BLOCK_SIZE_N": BLOCK_SIZE},
|
||||
num_warps=NUM_WARPS,
|
||||
num_stages=3)
|
||||
grid = lambda META: (1, )
|
||||
|
||||
x = torch.randn((256,), device='cuda', dtype=torch.float32)
|
||||
y = torch.randn((256,), device='cuda', dtype=torch.float32)
|
||||
z = torch.empty((256,), device=x.device, dtype=x.dtype)
|
||||
runtime.launch_kernel(fn=kernel,
|
||||
binary=binary,
|
||||
grid=grid,
|
||||
num_warps=NUM_WARPS,
|
||||
num_stages=3,
|
||||
x_ptr=x,
|
||||
stride_xn=x.stride(0),
|
||||
y_ptr=y,
|
||||
stride_yn=y.stride(0),
|
||||
z_ptr=z,
|
||||
stride_zn=z.stride(0),
|
||||
BLOCK_SIZE_N=tl.constexpr(BLOCK_SIZE))
|
||||
golden_z = x + y
|
||||
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
|
||||
|
Reference in New Issue
Block a user