[Backend] Vectorize Load/Store Ops (#86)

This PR does the following things:

- Code refactoring on Load and Store op codegen, rewrite with same logic
and share much code
- Support the vectorized load/store
This commit is contained in:
Yan Chunwei
2022-09-07 03:28:09 +08:00
committed by GitHub
parent 35e346bcff
commit a9464f4993
10 changed files with 433 additions and 295 deletions

View File

@@ -29,7 +29,6 @@ def test_empty_kernel_cubin_compile():
def test_empty_kernel_launch():
device = torch.cuda.current_device()
binary = runtime.build_kernel(empty_kernel, "*fp32,i32,i32",
device=device,
constants={"BLOCK": 256},
num_warps=4,
num_stages=3)
@@ -38,11 +37,9 @@ def test_empty_kernel_launch():
)
A = torch.zeros([1024], device="cuda")
runtime.launch_kernel(fn=empty_kernel,
binary=binary,
runtime.launch_kernel(kernel=binary,
grid=grid,
num_warps=4,
num_stages=3,
device=device,
X=A,
stride_xm=256,
BLOCK=tl.constexpr(256))

View File

@@ -5,17 +5,12 @@ import triton
import triton.language as tl
import triton.runtime as runtime
NUM_WARPS = 4
BLOCK_SIZE = 256
# triton kernel
def test_vecadd_no_scf():
def vecadd_no_scf_tester(num_warps, block_size):
@triton.jit
def kernel(x_ptr, stride_xn,
y_ptr, stride_yn,
z_ptr, stride_zn,
def kernel(x_ptr,
y_ptr,
z_ptr,
BLOCK_SIZE_N: tl.constexpr):
pid = tl.program_id(axis=0)
offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -27,37 +22,35 @@ def test_vecadd_no_scf():
z_ptrs = z_ptr + offset
tl.store(z_ptrs, z)
# TODO: add this to CI, to make sure the the compilation flow is at lease OK
# before we have GPU machines for CI.
# ptx, shem_size, kernel_name = triton.compile(kernel,
# "*fp32,i32,*fp32,i32,*fp32,i32",
# constants={"BLOCK_SIZE_N": 256},
# num_warps=NUM_WARPS,
# device=0, output="ptx")
torch.zeros([10], device=torch.device('cuda'))
device = torch.cuda.current_device()
binary = runtime.build_kernel(kernel, "*fp32,i32,*fp32,i32,*fp32,i32",
device=device,
constants={"BLOCK_SIZE_N": BLOCK_SIZE},
num_warps=NUM_WARPS,
binary = runtime.build_kernel(kernel, "*fp32,*fp32,*fp32,i32",
constants={"BLOCK_SIZE_N": block_size},
num_warps=num_warps,
num_stages=3)
grid = lambda META: (1, )
x = torch.randn((256,), device='cuda', dtype=torch.float32)
y = torch.randn((256,), device='cuda', dtype=torch.float32)
z = torch.empty((256,), device=x.device, dtype=x.dtype)
runtime.launch_kernel(fn=kernel,
binary=binary,
x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
z = torch.empty((block_size,), device=x.device, dtype=x.dtype)
assert x.shape.numel() % block_size == 0, "Only test load without mask here"
grid = lambda EA: (x.shape.numel() // block_size,)
runtime.launch_kernel(kernel=binary,
grid=grid,
num_warps=NUM_WARPS,
num_stages=3,
device=device,
x_ptr=x,
stride_xn=x.stride(0),
y_ptr=y,
stride_yn=y.stride(0),
z_ptr=z,
stride_zn=z.stride(0),
BLOCK_SIZE_N=tl.constexpr(BLOCK_SIZE))
BLOCK_SIZE_N=tl.constexpr(block_size))
golden_z = x + y
assert_allclose(z, golden_z, rtol=1e-7, atol=1e-7)
def test_vecadd_no_scf():
vecadd_no_scf_tester(num_warps=2, block_size=256)
vecadd_no_scf_tester(num_warps=1, block_size=256)
if __name__ == '__main__':
test_vecadd_no_scf()