[Backend] Fix for mov.u8 (#766)

Init a potential fix for mov.u8 which is not supported by ptx for now.
Use mov.u16 instead and cast it to u8.
This commit is contained in:
Keren Zhou
2022-10-12 14:32:27 -07:00
committed by GitHub
parent 71b46acc42
commit bc98aead33
2 changed files with 37 additions and 2 deletions

View File

@@ -1216,8 +1216,10 @@ void generator::visit_load_inst(ir::load_inst* x){
v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s);
}
v = bit_cast(v, IntegerType::get(*ctx_, width));
// PTX doesn't support mov.u8, so we need to use mov.u16
auto mov_width = width < 16 ? 16 : width;
asm_oss << "\n ";
asm_oss << "@!$" << n_words << " mov.u" << width;
asm_oss << "@!$" << n_words << " mov.u" << mov_width;
asm_oss << " $" << ii << ", ";
std::ios_base::fmtflags flags(asm_oss.flags());
if(ConstantInt* cst = dyn_cast<ConstantInt>(v))

View File

@@ -18,6 +18,7 @@ uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
float_dtypes = ['float16', 'float32', 'float64']
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = dtypes + ['bfloat16']
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
def _bitwidth(dtype: str) -> int:
@@ -1188,10 +1189,42 @@ def test_arange(start, device='cuda'):
# ---------------
# test load
# ---------------
@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
def test_masked_load(dtype_str, size, size_diff, device='cuda'):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested
input_size = size - size_diff
output_size = size
if dtype_str == 'bool':
input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device)
elif dtype_str in int_dtypes or dtype_str in uint_dtypes:
input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device)
else:
input = torch.rand(input_size, dtype=dtype, device=device)
output = torch.zeros((output_size,), dtype=dtype, device=device)
@triton.jit
def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
in_offsets = tl.arange(0, out_size)
# Load inputs.
x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1.0)
# Store output
output_offsets = tl.arange(0, out_size)
tl.store(out_ptr + output_offsets, x)
_kernel[(1,)](input, output, input_size, output_size)
reference_out = input
reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
triton.testing.allclose(output, reference_out)
# 'bfloat16': torch.bfloat16,
# Testing masked loads with an intermate copy to shared memory run.
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_masked_load_shared_memory(dtype, device='cuda'):
check_type_supported(dtype) # bfloat16 on cc < 80 will not be tested