[BACKEND] Added support for scalars in LoadOp / StoreOp / ElementwiseOp (#814)
Also fixed various errors that showed up in `test_core.py`, and added more TODOs for open (hopefully relatively minor) issues
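The re-enabled tests below lean directly on that scalar support: `tl.load` / `tl.store` on single-element pointers plus elementwise arithmetic on the loaded scalars (see `test_tuples` and `test_cast` in the diff). A minimal sketch of the pattern, for orientation only — it is not part of this diff, and the kernel name `scalar_add` is made up:

import torch
import triton
import triton.language as tl


@triton.jit
def scalar_add(X, Y, Z):
    # No tl.arange / block offsets: X, Y and Z each point at a single element,
    # so the load, the add and the store all operate on scalars.
    x = tl.load(X)      # scalar load
    y = tl.load(Y)
    tl.store(Z, x + y)  # scalar elementwise op feeding a scalar store


x = torch.tensor([1.3], device='cuda')
y = torch.tensor([1.9], device='cuda')
z = torch.empty_like(x)
scalar_add[(1, )](x, y, z, num_warps=1)
assert torch.allclose(z, x + y)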
@@ -556,45 +556,45 @@ def make_ptr_str(name, shape):
 # # ---------------


-# @triton.jit
-# def fn(a, b):
-#     return a + b, \
-#         a - b, \
-#         a * b
+@triton.jit
+def fn(a, b):
+    return a + b, \
+        a - b, \
+        a * b


-# def test_tuples():
-#     device = 'cuda'
+def test_tuples():
+    device = 'cuda'

-#     @triton.jit
-#     def with_fn(X, Y, A, B, C):
-#         x = tl.load(X)
-#         y = tl.load(Y)
-#         a, b, c = fn(x, y)
-#         tl.store(A, a)
-#         tl.store(B, b)
-#         tl.store(C, c)
+    @triton.jit
+    def with_fn(X, Y, A, B, C):
+        x = tl.load(X)
+        y = tl.load(Y)
+        a, b, c = fn(x, y)
+        tl.store(A, a)
+        tl.store(B, b)
+        tl.store(C, c)

-#     @triton.jit
-#     def without_fn(X, Y, A, B, C):
-#         x = tl.load(X)
-#         y = tl.load(Y)
-#         a, b, c = x + y, x - y, x * y
-#         tl.store(A, a)
-#         tl.store(B, b)
-#         tl.store(C, c)
+    @triton.jit
+    def without_fn(X, Y, A, B, C):
+        x = tl.load(X)
+        y = tl.load(Y)
+        a, b, c = x + y, x - y, x * y
+        tl.store(A, a)
+        tl.store(B, b)
+        tl.store(C, c)

-#     x = torch.tensor([1.3], device=device, dtype=torch.float32)
-#     y = torch.tensor([1.9], device=device, dtype=torch.float32)
-#     a_tri = torch.tensor([0], device=device, dtype=torch.float32)
-#     b_tri = torch.tensor([0], device=device, dtype=torch.float32)
-#     c_tri = torch.tensor([0], device=device, dtype=torch.float32)
-#     for kernel in [with_fn, without_fn]:
-#         kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
-#     a_ref, b_ref, c_ref = x + y, x - y, x * y
-#     assert a_tri == a_ref
-#     assert b_tri == b_ref
-#     assert c_tri == c_ref
+    x = torch.tensor([1.3], device=device, dtype=torch.float32)
+    y = torch.tensor([1.9], device=device, dtype=torch.float32)
+    a_tri = torch.tensor([0], device=device, dtype=torch.float32)
+    b_tri = torch.tensor([0], device=device, dtype=torch.float32)
+    c_tri = torch.tensor([0], device=device, dtype=torch.float32)
+    for kernel in [with_fn, without_fn]:
+        kernel[(1, )](x, y, a_tri, b_tri, c_tri, num_warps=1)
+    a_ref, b_ref, c_ref = x + y, x - y, x * y
+    assert a_tri == a_ref
+    assert b_tri == b_ref
+    assert c_tri == c_ref


 # # ---------------
@@ -709,75 +709,77 @@ def make_ptr_str(name, shape):
 # # ---------------


-# @pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
-#     (dtype_x, dtype_z, False)
-#     for dtype_x in dtypes
-#     for dtype_z in dtypes
-# ] + [
-#     ('float32', 'bfloat16', False),
-#     ('bfloat16', 'float32', False),
-#     ('float32', 'int32', True),
-#     ('float32', 'int1', False),
-# ] + [
-#     (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
-# ] + [
-#     (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
-# ])
-# def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
-#     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
-#     x0 = 43 if dtype_x in int_dtypes else 43.5
-#     if dtype_x in float_dtypes and dtype_z == 'int1':
-#         x0 = 0.5
-#     if dtype_x.startswith('bfloat'):
-#         x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
-#     else:
-#         x = np.array([x0], dtype=getattr(np, dtype_x))
-#         x_tri = to_triton(x)
+@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [
+    (dtype_x, dtype_z, False)
+    for dtype_x in dtypes
+    for dtype_z in dtypes
+] + [
+    # TODO:
+    # ('float32', 'bfloat16', False),
+    # ('bfloat16', 'float32', False),
+    ('float32', 'int32', True),
+    # TODO:
+    # ('float32', 'int1', False),
+] + [
+    (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64]
+] + [
+    (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64]
+])
+def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
+    # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
+    x0 = 43 if dtype_x in int_dtypes else 43.5
+    if dtype_x in float_dtypes and dtype_z == 'int1':
+        x0 = 0.5
+    if dtype_x.startswith('bfloat'):
+        x_tri = torch.tensor([x0], dtype=getattr(torch, dtype_x), device=device)
+    else:
+        x = np.array([x0], dtype=getattr(np, dtype_x))
+        x_tri = to_triton(x)

-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, Z, BITCAST: tl.constexpr):
-#         x = tl.load(X)
-#         z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
-#         tl.store(Z, z)
+    # triton kernel
+    @triton.jit
+    def kernel(X, Z, BITCAST: tl.constexpr):
+        x = tl.load(X)
+        z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
+        tl.store(Z, z)

-#     dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
-#     # triton result
-#     if dtype_z.startswith('bfloat'):
-#         z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
-#     else:
-#         z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
-#     kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
-#     # torch result
-#     if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
-#         assert bitcast is False
-#         z_ref = x_tri.to(z_tri.dtype)
-#         assert z_tri == z_ref
-#     else:
-#         if bitcast:
-#             z_ref = x.view(getattr(np, dtype_z_np))
-#         else:
-#             z_ref = x.astype(getattr(np, dtype_z_np))
-#         assert to_numpy(z_tri) == z_ref
+    dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
+    # triton result
+    if dtype_z.startswith('bfloat'):
+        z_tri = torch.empty((1,), dtype=getattr(torch, dtype_z), device=device)
+    else:
+        z_tri = to_triton(np.empty((1, ), dtype=getattr(np, dtype_z_np)), device=device)
+    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast)
+    # torch result
+    if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'):
+        assert bitcast is False
+        z_ref = x_tri.to(z_tri.dtype)
+        assert z_tri == z_ref
+    else:
+        if bitcast:
+            z_ref = x.view(getattr(np, dtype_z_np))
+        else:
+            z_ref = x.astype(getattr(np, dtype_z_np))
+        assert to_numpy(z_tri) == z_ref


-# def test_store_bool():
-#     """Tests that boolean True is stored as 1"""
-#     @triton.jit
-#     def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-#         offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-#         mask = offsets < n_elements
-#         input = tl.load(input_ptr + offsets, mask=mask)
-#         output = input
-#         tl.store(output_ptr + offsets, output, mask=mask)
+def test_store_bool():
+    """Tests that boolean True is stored as 1"""
+    @triton.jit
+    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        input = tl.load(input_ptr + offsets, mask=mask)
+        output = input
+        tl.store(output_ptr + offsets, output, mask=mask)

-#     src = torch.tensor([True, False], dtype=torch.bool, device='cuda')
-#     n_elements = src.numel()
-#     dst = torch.empty_like(src)
-#     grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
-#     copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024)
+    src = torch.tensor([True, False], dtype=torch.bool, device='cuda')
+    n_elements = src.numel()
+    dst = torch.empty_like(src)
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    copy_kernel[grid](src, dst, n_elements, BLOCK_SIZE=1024)

-#     assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()
+    assert (to_numpy(src).view('uint8') == to_numpy(dst).view('uint8')).all()


 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -990,48 +992,49 @@ def make_ptr_str(name, shape):
 # # ---------------


-# @pytest.mark.parametrize("dtype_str, shape, perm",
-#                          [(dtype, shape, perm)
-#                           for dtype in ['bfloat16', 'float16', 'float32']
-#                           for shape in [(64, 64), (128, 128)]
-#                           for perm in [(1, 0)]])
-# def test_permute(dtype_str, shape, perm, device='cuda'):
-#     check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
+@pytest.mark.parametrize("dtype_str, shape, perm",
+                         [(dtype, shape, perm)
+                          # TODO: bfloat16
+                          for dtype in ['float16', 'float32']
+                          for shape in [(64, 64), (128, 128)]
+                          for perm in [(1, 0)]])
+def test_permute(dtype_str, shape, perm, device='cuda'):
+    check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested

-#     # triton kernel
-#     @triton.jit
-#     def kernel(X, stride_xm, stride_xn,
-#                Z, stride_zm, stride_zn,
-#                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
-#         off_m = tl.arange(0, BLOCK_M)
-#         off_n = tl.arange(0, BLOCK_N)
-#         Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
-#         Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
-#         tl.store(Zs, tl.load(Xs))
-#     # input
-#     x = numpy_random(shape, dtype_str=dtype_str)
-#     # triton result
-#     z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
-#     z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
-#     x_tri = to_triton(x, device=device, dst_type=dtype_str)
-#     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
-#                          z_tri, z_tri.stride(1), z_tri.stride(0),
-#                          BLOCK_M=shape[0], BLOCK_N=shape[1])
-#     pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
-#                                     z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
-#                                     BLOCK_M=shape[0], BLOCK_N=shape[1])
-#     # numpy result
-#     z_ref = x.transpose(*perm)
-#     # compare
-#     triton.testing.assert_almost_equal(z_tri, z_ref)
-#     triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
-#     # parse ptx to make sure ld/st are vectorized
-#     ptx = pgm.asm['ptx']
-#     assert 'ld.global.v4' in ptx
-#     assert 'st.global.v4' in ptx
-#     ptx = pgm_contiguous.asm['ptx']
-#     assert 'ld.global.v4' in ptx
-#     assert 'st.global.v4' in ptx
+    # triton kernel
+    @triton.jit
+    def kernel(X, stride_xm, stride_xn,
+               Z, stride_zm, stride_zn,
+               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
+        off_m = tl.arange(0, BLOCK_M)
+        off_n = tl.arange(0, BLOCK_N)
+        Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
+        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
+        tl.store(Zs, tl.load(Xs))
+    # input
+    x = numpy_random(shape, dtype_str=dtype_str)
+    # triton result
+    z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str)
+    x_tri = to_triton(x, device=device, dst_type=dtype_str)
+    pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
+                         z_tri, z_tri.stride(1), z_tri.stride(0),
+                         BLOCK_M=shape[0], BLOCK_N=shape[1])
+    pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
+                                    z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
+                                    BLOCK_M=shape[0], BLOCK_N=shape[1])
+    # numpy result
+    z_ref = x.transpose(*perm)
+    # compare
+    triton.testing.assert_almost_equal(z_tri, z_ref)
+    triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
+    # parse ptx to make sure ld/st are vectorized
+    ptx = pgm.asm['ptx']
+    assert 'ld.global.v4' in ptx
+    assert 'st.global.v4' in ptx
+    ptx = pgm_contiguous.asm['ptx']
+    assert 'ld.global.v4' in ptx
+    assert 'st.global.v4' in ptx

 # # ---------------
 # # test dot