[Backend] Fix for mov.u8 (#766)
Initial fix for mov.u8, which is not currently supported by PTX: emit mov.u16 instead and cast the result back to u8.
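For context, any masked load of 8-bit values is enough to reach the predicated mov that previously came out as the unsupported mov.u8. The sketch below is a minimal reproducer, not part of the commit; it assumes Triton and PyTorch with a CUDA device, and the kernel name and BLOCK size are made up for illustration.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_masked(in_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Masked-off lanes take the `other` value; for an int8 pointer this
    # path used to emit a predicated mov.u8, which ptxas rejects.
    x = tl.load(in_ptr + offs, mask=offs < n, other=1)
    tl.store(out_ptr + offs, x)


src = torch.randint(0, 127, (13,), dtype=torch.int8, device='cuda')
dst = torch.zeros(16, dtype=torch.int8, device='cuda')
_copy_masked[(1,)](src, dst, 13, BLOCK=16)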
@@ -1216,8 +1216,10 @@ void generator::visit_load_inst(ir::load_inst* x){
         v = insert_elt(v, vals_[false_val][idxs[i + ii*size + s]], s);
       }
       v = bit_cast(v, IntegerType::get(*ctx_, width));
+      // PTX doesn't support mov.u8, so we need to use mov.u16
+      auto mov_width = width < 16 ? 16 : width;
       asm_oss << "\n ";
-      asm_oss << "@!$" << n_words << " mov.u" << width;
+      asm_oss << "@!$" << n_words << " mov.u" << mov_width;
       asm_oss << " $" << ii << ", ";
       std::ios_base::fmtflags flags(asm_oss.flags());
       if(ConstantInt* cst = dyn_cast<ConstantInt>(v))
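The generated inline assembly keeps its shape; only the width suffix of the predicated mov changes. The snippet below is a small Python mirror of the C++ expression width < 16 ? 16 : width, for illustration only: every width below 16 bits is promoted to a 16-bit move, while wider moves are left untouched and the surrounding bit-cast narrows the result back to the original width.

def mov_width(width: int) -> int:
    # Same rule as the C++ ternary: promote anything narrower than 16 bits.
    return 16 if width < 16 else width

assert [mov_width(w) for w in (8, 16, 32, 64)] == [16, 16, 32, 64]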
@@ -18,6 +18,7 @@ uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
 float_dtypes = ['float16', 'float32', 'float64']
 dtypes = int_dtypes + uint_dtypes + float_dtypes
 dtypes_with_bfloat16 = dtypes + ['bfloat16']
+torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']


 def _bitwidth(dtype: str) -> int:
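torch_dtypes feeds the new masked-load test below; it only lists dtype names that both Triton and PyTorch can represent, so uint16/uint32/uint64 are left out because PyTorch exposes only uint8 among its unsigned integer tensor types. Assuming int_dtypes is the usual ['int8', 'int16', 'int32', 'int64'] defined earlier in the same file (not visible in this hunk), the list expands as follows:

# Assumed definitions from earlier in the file (not shown in the hunk):
int_dtypes = ['int8', 'int16', 'int32', 'int64']
float_dtypes = ['float16', 'float32', 'float64']

torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
# -> ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8',
#     'float16', 'float32', 'float64', 'bfloat16']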
@@ -1188,10 +1189,42 @@ def test_arange(start, device='cuda'):
 # ---------------
 # test load
 # ---------------
+
+
+@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [1, 2, 3, 4]])
+def test_masked_load(dtype_str, size, size_diff, device='cuda'):
+    dtype = getattr(torch, dtype_str)
+    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
+
+    input_size = size - size_diff
+    output_size = size
+    if dtype_str == 'bool':
+        input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device)
+    elif dtype_str in int_dtypes or dtype_str in uint_dtypes:
+        input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device)
+    else:
+        input = torch.rand(input_size, dtype=dtype, device=device)
+    output = torch.zeros((output_size,), dtype=dtype, device=device)
+
+    @triton.jit
+    def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr):
+        in_offsets = tl.arange(0, out_size)
+        # Load inputs.
+        x = tl.load(in_ptr + in_offsets, mask=in_offsets < in_size, other=1.0)
+        # Store output
+        output_offsets = tl.arange(0, out_size)
+        tl.store(out_ptr + output_offsets, x)
+
+    _kernel[(1,)](input, output, input_size, output_size)
+
+    reference_out = input
+    reference_out = torch.cat((reference_out, torch.ones((size_diff,), dtype=dtype, device=device)))
+    triton.testing.allclose(output, reference_out)
+
+
 # 'bfloat16': torch.bfloat16,
 # Testing masked loads with an intermediate copy to shared memory run.


 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_masked_load_shared_memory(dtype, device='cuda'):
     check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested
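The reference tensor mirrors the kernel's masking behaviour: the first input_size lanes are read straight from input, and the size_diff lanes past the end of the input take the other=1.0 fill value, hence the concatenation with ones. A tiny plain-PyTorch check of that construction, with made-up sizes for illustration:

import torch

size, size_diff = 8, 3
input = torch.arange(size - size_diff, dtype=torch.int32)   # in-bounds values 0..4
fill = torch.ones((size_diff,), dtype=torch.int32)          # masked lanes -> other=1
reference_out = torch.cat((input, fill))
print(reference_out)  # tensor([0, 1, 2, 3, 4, 1, 1, 1], dtype=torch.int32)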