From bc98aead33892cd6e959268a77f10c17721e9eae Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Wed, 12 Oct 2022 14:32:27 -0700
Subject: [PATCH] [Backend] Fix for mov.u8 (#766)

Initial potential fix for mov.u8, which is not currently supported by PTX.
Use mov.u16 instead and cast it to u8.
---
 lib/codegen/selection/generator.cc     |  4 ++-
 python/test/unit/language/test_core.py | 35 +++++++++++++++++++++++++-
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 415adaab2..1cc461e5c 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -1216,8 +1216,10 @@ void generator::visit_load_inst(ir::load_inst* x){
         v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s);
       }
       v = bit_cast(v, IntegerType::get(*ctx_, width));
+      // PTX doesn't support mov.u8, so we need to use mov.u16
+      auto mov_width = width < 16 ? 16 : width;
       asm_oss << "\n        ";
-      asm_oss << "@!$" << n_words << " mov.u" << width;
+      asm_oss << "@!$" << n_words << " mov.u" << mov_width;
       asm_oss << " $" << ii << ", ";
       std::ios_base::fmtflags flags(asm_oss.flags());
       if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 82a69cd42..d7d9130d5 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -18,6 +18,7 @@ uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + uint_dtypes + float_dtypes
 dtypes_with_bfloat16 = dtypes + ['bfloat16']
+torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
 
 
 def _bitwidth(dtype: str) -> int:
@@ -1188,10 +1189,42 @@ def test_arange(start, device='cuda'):
 # ---------------
 # test load
 # ---------------
+
+
+@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
+def test_masked_load(dtype_str, size, size_diff, device='cuda'):
+    dtype = getattr(torch, dtype_str)
+    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
+
+    input_size = size - size_diff
+    output_size = size
+    if dtype_str == 'bool':
+        input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device)
+    elif dtype_str in int_dtypes or dtype_str in uint_dtypes:
+        input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device)
+    else:
+        input = torch.rand(input_size, dtype=dtype, device=device)
+    output = torch.zeros((output_size,), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
+        in_offsets = tl.arange(0, out_size)
+        # Load inputs.
+        x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1.0)
+        # Store output
+        output_offsets = tl.arange(0, out_size)
+        tl.store(out_ptr + output_offsets, x)
+
+    _kernel[(1,)](input, output, input_size, output_size)
+
+    reference_out = input
+    reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
+    triton.testing.allclose(output, reference_out)
+
+
 # 'bfloat16': torch.bfloat16,
 # Testing masked loads with an intermate copy to shared memory run.
-
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_masked_load_shared_memory(dtype, device='cuda'):
     check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
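
For context, a minimal standalone C++ sketch of the width promotion applied above, not part of the patch and not Triton's actual generator.cc; the helper name and the simplified "0x0" default operand are illustrative. It only demonstrates the rule the fix introduces: PTX has no mov.u8, so the predicated default-value mov for words narrower than 16 bits is emitted as mov.u16.

    // Illustrative sketch only: mirrors the mov_width promotion in the hunk above.
    #include <algorithm>
    #include <iostream>
    #include <sstream>
    #include <string>

    std::string make_default_mov(int n_words, int word_width_bits, int word_index) {
      // PTX has no mov.u8, so promote sub-16-bit widths to 16.
      int mov_width = std::max(word_width_bits, 16);
      std::ostringstream asm_oss;
      asm_oss << "@!$" << n_words << " mov.u" << mov_width
              << " $" << word_index << ", 0x0;";
      return asm_oss.str();
    }

    int main() {
      // A masked 8-bit load: the emitted instruction is mov.u16, not mov.u8.
      std::cout << make_default_mov(4, 8, 0) << "\n";   // @!$4 mov.u16 $0, 0x0;
      std::cout << make_default_mov(4, 32, 1) << "\n";  // @!$4 mov.u32 $1, 0x0;
    }

Since only the low bits of the packed word are consumed, writing the default value with a wider mov does not change the loaded result.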