[FRONTEND][BACKEND] Fix bool and int8 load when the other operand is given (#968)
@@ -973,8 +973,10 @@ struct LoadOpConversion
     if (other) {
       for (size_t ii = 0; ii < nWords; ++ii) {
+        // PTX doesn't support mov.u8, so we need to use mov.u16
+        auto movWidth = width < 16 ? 16 : width;
         PTXInstr &mov =
-            ptxBuilder.create<>("mov")->o("u" + std::to_string(width));
+            ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));

         size_t size = width / valueElemNbits;
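The backend change widens the `mov` that materializes the `other` fallback value: PTX has no 8-bit `mov`, so bool/int8 loads now emit `mov.u16` instead of an invalid `mov.u8`. A minimal Python sketch of the width selection the new code performs (the helper name is illustrative, not part of the source):

    def mov_width(width: int) -> int:
        # PTX mov has no 8-bit form; clamp sub-16-bit operands up to 16 bits
        # so bool/int8 fallback values are written with mov.u16.
        return 16 if width < 16 else width

    assert mov_width(8) == 16    # bool / int8 loads get widened
    assert mov_width(16) == 16   # 16-bit and wider loads keep their width
    assert mov_width(32) == 32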
@@ -20,6 +20,7 @@ float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + uint_dtypes + float_dtypes
 # TODO: handle bfloat16
 dtypes_with_bfloat16 = dtypes  # + ['bfloat16']
+torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes  # + ['bfloat16']


 def _bitwidth(dtype: str) -> int:
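`torch_dtypes` mirrors the Triton dtype lists but keeps only the unsigned width torch actually exposes (`'uint8'`) and adds `'bool'`; the new test below resolves each name with `getattr(torch, dtype_str)`. A quick sanity sketch of that lookup (using only names visible in this hunk):

    import torch

    # Every string in torch_dtypes is expected to resolve to a real torch dtype.
    for name in ['bool', 'uint8', 'float16', 'float32', 'float64']:
        assert isinstance(getattr(torch, name), torch.dtype)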
@@ -1226,9 +1227,41 @@ def test_arange(start, device='cuda'):
     z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
     triton.testing.assert_almost_equal(z_tri, z_ref)


-# # ---------------
-# # test load
-# # ---------------
+# ---------------
+# test load
+# ---------------
+
+
+@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
+def test_masked_load(dtype_str, size, size_diff, device='cuda'):
+    dtype = getattr(torch, dtype_str)
+    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
+
+    input_size = size - size_diff
+    output_size = size
+    if dtype_str == 'bool':
+        input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device)
+    elif dtype_str in int_dtypes or dtype_str in uint_dtypes:
+        input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device)
+    else:
+        input = torch.rand(input_size, dtype=dtype, device=device)
+    output = torch.zeros((output_size,), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
+        in_offsets = tl.arange(0, out_size)
+        # Load inputs.
+        x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
+        # Store output
+        output_offsets = tl.arange(0, out_size)
+        tl.store(out_ptr + output_offsets, x)
+
+    _kernel[(1,)](input, output, input_size, output_size)
+
+    reference_out = input
+    reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
+    triton.testing.allclose(output, reference_out)
+
+
 # # 'bfloat16': torch.bfloat16,
 # # Testing masked loads with an intermate copy to shared memory run.
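`test_masked_load` loads `input_size` valid elements through `output_size` lanes and expects the masked-off tail to be filled with the `other=1` fallback, for every torch dtype including the bool and int8 cases this commit fixes. A small reference sketch of the semantics being checked (the helper is illustrative, not part of the test suite):

    def masked_load_reference(values, in_size, out_size, other=1):
        # Lane i reads values[i] where the mask (i < in_size) holds and
        # falls back to `other` where it does not.
        return [values[i] if i < in_size else other for i in range(out_size)]

    # Loading 3 valid elements into 5 lanes pads the tail with the fallback:
    assert masked_load_reference([7, 8, 9], 3, 5) == [7, 8, 9, 1, 1]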
@@ -738,16 +738,18 @@ def load(ptr: tl.tensor,
     if other:
         other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)

-    if other:
-        other = cast(other, ptr.type.scalar.element_ty, builder)
     ptr_ty = ptr.type.scalar
     elt_ty = ptr_ty.element_ty

     # treat bool* as tl.int8*
     if elt_ty == tl.int1:
         elt_ty = tl.int8
         ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
         ptr = cast(ptr, ptr_ty, builder)

+    if other:
+        other = cast(other, elt_ty, builder)
+
     # cache modifier
     cache = ir.CACHE_MODIFIER.NONE  # default
     if cache_modifier:
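The frontend previously cast `other` to the pointer's original element type before the bool* → tl.int8* rewrite, so for bool loads the fallback was typed as `tl.int1` while the load itself was performed as int8. Moving the cast after the rewrite makes `other` follow `elt_ty`, the type actually loaded. A hedged end-to-end sketch of the case this fixes (kernel and tensor names are illustrative):

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_with_default(in_ptr, out_ptr, n_valid: tl.constexpr, n_total: tl.constexpr):
        offs = tl.arange(0, n_total)
        # bool pointers are loaded as int8 under the hood; the `other`
        # fallback must be cast to that same element type.
        x = tl.load(in_ptr + offs, mask=offs < n_valid, other=1)
        tl.store(out_ptr + offs, x)


    src = torch.randint(0, 2, (7,), dtype=torch.bool, device='cuda')
    dst = torch.zeros(8, dtype=torch.bool, device='cuda')
    copy_with_default[(1,)](src, dst, 7, 8)
    # The eighth lane is masked off and should come back as the fallback (True).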