From 83f3b9165bd54a5a2af1b979c3a955ce0bebb780 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Thu, 8 Dec 2022 11:52:18 -0800
Subject: [PATCH] [FRONTEND][BACKEND] Fix bool and int8 load when the other operand is given (#968)

---
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp |  4 +-
 python/tests/test_core.py               | 39 +++++++++++++++++--
 python/triton/language/semantic.py      |  6 ++-
 3 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index fd4dfcf28..965f13467 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -973,8 +973,10 @@ struct LoadOpConversion
 
     if (other) {
       for (size_t ii = 0; ii < nWords; ++ii) {
+        // PTX doesn't support mov.u8, so we need to use mov.u16
+        auto movWidth = width < 16 ? 16 : width;
         PTXInstr &mov =
-            ptxBuilder.create<>("mov")->o("u" + std::to_string(width));
+            ptxBuilder.create<>("mov")->o("u" + std::to_string(movWidth));
 
         size_t size = width / valueElemNbits;
 
diff --git a/python/tests/test_core.py b/python/tests/test_core.py
index abd951055..95ffde257 100644
--- a/python/tests/test_core.py
+++ b/python/tests/test_core.py
@@ -20,6 +20,7 @@ float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + uint_dtypes + float_dtypes
 # TODO: handle bfloat16
 dtypes_with_bfloat16 = dtypes  # + ['bfloat16']
+torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes  # + ['bfloat16']
 
 
 def _bitwidth(dtype: str) -> int:
@@ -1226,9 +1227,41 @@ def test_arange(start, device='cuda'):
     z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device)
     triton.testing.assert_almost_equal(z_tri, z_ref)
 
-# # ---------------
-# # test load
-# # ---------------
+# ---------------
+# test load
+# ---------------
+
+
+@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
+def test_masked_load(dtype_str, size, size_diff, device='cuda'):
+    dtype = getattr(torch, dtype_str)
+    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
+
+    input_size = size - size_diff
+    output_size = size
+    if dtype_str == 'bool':
+        input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device)
+    elif dtype_str in int_dtypes or dtype_str in uint_dtypes:
+        input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device)
+    else:
+        input = torch.rand(input_size, dtype=dtype, device=device)
+    output = torch.zeros((output_size,), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
+        in_offsets = tl.arange(0, out_size)
+        # Load inputs.
+        x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1)
+        # Store output
+        output_offsets = tl.arange(0, out_size)
+        tl.store(out_ptr + output_offsets, x)
+
+    _kernel[(1,)](input, output, input_size, output_size)
+
+    reference_out = input
+    reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
+    triton.testing.allclose(output, reference_out)
+
 
 # # 'bfloat16': torch.bfloat16,
 # # Testing masked loads with an intermate copy to shared memory run.
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index ecacf47ec..9d8e65a50 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -738,16 +738,18 @@ def load(ptr: tl.tensor,
     if other:
         other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder)
 
-    if other:
-        other = cast(other, ptr.type.scalar.element_ty, builder)
     ptr_ty = ptr.type.scalar
     elt_ty = ptr_ty.element_ty
 
+    # treat bool* as tl.int8*
     if elt_ty == tl.int1:
         elt_ty = tl.int8
         ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
         ptr = cast(ptr, ptr_ty, builder)
 
+    if other:
+        other = cast(other, elt_ty, builder)
+
     # cache modifier
     cache = ir.CACHE_MODIFIER.NONE  # default
     if cache_modifier:
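
Note for reviewers: a minimal standalone reproduction of the case this patch
fixes (a sketch assuming a CUDA device; the kernel and tensor names below are
illustrative, not part of the patch). Before this change, a masked load of
bool/int8 data with an `other` operand either emitted the unsupported mov.u8
in PTX or cast `other` against the original tl.int1 element type; with it, the
fallback value is widened in the backend and cast to tl.int8 in the frontend
so that it matches the actual load width.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def _bool_masked_load(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
        offsets = tl.arange(0, out_size)
        # Out-of-bounds lanes take the `other` value; for bool/int8 inputs this
        # is the code path exercised by the changes above.
        x = tl.load(in_ptr + offsets, mask=offsets < in_size, other=1)
        tl.store(out_ptr + offsets, x)


    flags = torch.randint(0, 2, (100,), dtype=torch.bool, device='cuda')
    out = torch.zeros((128,), dtype=torch.bool, device='cuda')
    _bool_masked_load[(1,)](flags, out, 100, 128)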