import math
import random

import pytest
import torch
from torch.testing import assert_close

import triton
import triton.language as tl


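# Vector add driven by a for loop (structured control flow), without bounds masking:
# each program walks its block in iter_size chunks and advances the pointers every iteration.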
@pytest.mark.parametrize('num_warps, block_size, iter_size', [
    [4, 256, 1],
    [4, 1024, 256],
])
def test_vecadd_scf_no_mask(num_warps, block_size, iter_size):

    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               block_size,
               iter_size: tl.constexpr):
        pid = tl.program_id(axis=0)
        for i in range(0, block_size, iter_size):
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            y_ptrs = y_ptr + offset

            x = tl.load(x_ptrs)
            y = tl.load(y_ptrs)
            z = x + y
            z_ptrs = z_ptr + offset
            tl.store(z_ptrs, z)

            x_ptr += iter_size
            y_ptr += iter_size
            z_ptr += iter_size

    x = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    y = torch.randn((block_size,), device='cuda', dtype=torch.float32)
    z = torch.empty((block_size,), device=x.device, dtype=x.dtype)

    grid = lambda EA: (x.shape.numel() // block_size,)
    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps)

    golden_z = x + y
    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)


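# Same loop structure, but every load and store is masked against num_elements,
# so shapes that are not a multiple of iter_size do not access out-of-bounds memory.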
@pytest.mark.parametrize('shape, num_warps, block_size, iter_size', [
    [(127, 3), 2, 128, 1],
    [(127, 3), 2, 128, 32],
])
def test_vecadd_scf_mask(shape, num_warps, block_size, iter_size):
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               num_elements,
               block_size: tl.constexpr,
               iter_size: tl.constexpr
               ):
        '''
        @block_size: size of a block
        @iter_size: number of elements processed per loop iteration; a block is processed over several iterations
        @num_elements: total number of elements
        '''
        pid = tl.program_id(axis=0)
        for i in range(tl.cdiv(block_size, iter_size)):
            # TODO: there is a bug here; if the offset computation is hoisted out of
            # the for loop, a GPU misaligned-access error occurs.
            offset = pid * block_size + tl.arange(0, iter_size)
            x_ptrs = x_ptr + offset
            y_ptrs = y_ptr + offset

            x = tl.load(x_ptrs, mask=offset < num_elements)
            y = tl.load(y_ptrs, mask=offset < num_elements)
            z = x + y
            z_ptrs = z_ptr + offset
            tl.store(z_ptrs, z, mask=offset < num_elements)

            x_ptr += iter_size
            y_ptr += iter_size
            z_ptr += iter_size

    x = torch.randn(shape, device='cuda', dtype=torch.float32)
    y = torch.randn(shape, device='cuda', dtype=torch.float32)
    z = torch.empty(shape, device=x.device, dtype=x.dtype)

    grid = lambda EA: (math.ceil(x.numel() / block_size),)
    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z,
                 block_size=x.shape[0], iter_size=iter_size, num_warps=num_warps,
                 num_elements=x.numel())

    golden_z = x + y
    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)


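# Shared helper: single-shot vector add with no control flow; callers choose shapes
# that either exactly fill the launch grid or rely on the out-of-bounds mask.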
def vecadd_no_scf_tester(num_warps, block_size, shape):
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               n_elements,
               block_size_N: tl.constexpr):
        pid = tl.program_id(axis=0)

        offset = pid * block_size_N + tl.arange(0, block_size_N)
        x_ptrs = x_ptr + offset
        y_ptrs = y_ptr + offset

        mask = offset < n_elements

        x = tl.load(x_ptrs, mask=mask)
        y = tl.load(y_ptrs, mask=mask)
        z = x + y
        z_ptrs = z_ptr + offset
        tl.store(z_ptrs, z, mask=mask)

    x = torch.randn(shape, device='cuda', dtype=torch.float32)
    y = torch.randn(shape, device='cuda', dtype=torch.float32)
    z = torch.empty(shape, device=x.device, dtype=x.dtype)

    grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)

    golden_z = x + y
    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)


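# Variant of the helper above whose store mask also depends on a float comparison
# on the computed values, not just on the element index.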
def vecadd_fcmp_no_scf_tester(num_warps, block_size, shape):
    '''
    Vector-add tester that uses a float comparison on the result as part of the store mask.
    '''
    @triton.jit
    def kernel(x_ptr,
               y_ptr,
               z_ptr,
               n_elements,
               block_size_N: tl.constexpr):
        pid = tl.program_id(axis=0)

        offset = pid * block_size_N + tl.arange(0, block_size_N)
        x_ptrs = x_ptr + offset
        y_ptrs = y_ptr + offset

        io_mask = offset < n_elements
        x = tl.load(x_ptrs, mask=io_mask)
        y = tl.load(y_ptrs, mask=io_mask)

        z = x + y
        val_mask = offset < n_elements and (z < 0. or z > 1.)

        z_ptrs = z_ptr + offset
        tl.store(z_ptrs, z, mask=val_mask)

    x = torch.randn(shape, device='cuda', dtype=torch.float32)
    y = torch.randn(shape, device='cuda', dtype=torch.float32)
    # z is zero-initialized so elements skipped by the store mask must match the
    # zeroed entries of the golden reference below.
    z = torch.zeros(shape, device=x.device, dtype=x.dtype)

    grid = lambda EA: (math.ceil(x.shape.numel() / block_size),)
    kernel[grid](x_ptr=x, y_ptr=y, z_ptr=z, n_elements=x.shape.numel(), block_size_N=block_size, num_warps=num_warps)

    golden_z: torch.Tensor = x + y
    gz_data = torch.flatten(golden_z)
    for i in range(golden_z.numel()):
        gz_data[i] = gz_data[i] if gz_data[i] < 0. or gz_data[i] > 1. else 0.

    assert_close(z, golden_z, rtol=1e-7, atol=1e-7)


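# Shapes that the launch grid covers exactly, so the mask never trims anything.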
@pytest.mark.parametrize('num_warps, block_size, shape', [
    [4, 256, (256,)],
    [2, 256, (256,)],
    [1, 256, (256,)],
    [4, 16, (256,)],
    [2, 64, (256,)],
    [1, 128, (256,)],
])
def test_vecadd_no_scf(num_warps, block_size, shape):
    vecadd_no_scf_tester(num_warps, block_size, shape)


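# Ragged shapes (element count not a multiple of block_size), so the mask is exercised.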
@pytest.mark.parametrize('num_warps, block_size, shape', [
    [1, 128, (256 + 1,)],
    [1, 256, (256 + 1,)],
    [2, 256, (3, 256 + 7)],
    [4, 256, (3, 256 + 7)],
])
def test_vecadd_no_scf_masked(num_warps, block_size, shape):
    vecadd_no_scf_tester(num_warps, block_size, shape)


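# Fuzz the masked path with random sizes and every power-of-two warp count that fits.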
def test_vecadd_no_scf_masked_randomly():
    random.seed(0)  # fix seed to make random test reproducible
    for i in range(10):
        num_elements = random.randint(128, 2048)
        shape = (num_elements,)
        max_warps = num_elements // 32  # floor div
        for num_warps in range(1, max_warps):
            is_power2 = num_warps & (num_warps - 1) == 0 and num_warps != 0
            if not is_power2:
                continue
            block_size = min(32, num_warps * 32)
            vecadd_no_scf_tester(num_warps, block_size, shape)


@pytest.mark.parametrize('num_warps, block_size, shape', [
    [1, 128, (256 + 1,)],
    [1, 256, (256 + 1,)],
    [2, 256, (3, 256 + 7)],
    [4, 256, (3, 256 + 7)],
])
def test_vecadd_fcmp_no_scf_masked(num_warps, block_size, shape):
    vecadd_fcmp_no_scf_tester(num_warps, block_size, shape)