[BACKEND] Fix layout convert for non-contiguous input (#564)

This commit is contained in:
Keren Zhou
2022-06-25 23:12:03 -07:00
committed by GitHub
parent d345ddf837
commit 87413bc925
2 changed files with 20 additions and 7 deletions

View File

@@ -2638,8 +2638,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
// Orders // Orders
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in)); analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out)); analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
auto in_ord = in_layout->get_order();
auto out_ord = out_layout->get_order();
Value *base; Value *base;
base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out))))); base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
base = bit_cast(base, ptr_ty(ty, 3)); base = bit_cast(base, ptr_ty(ty, 3));
@@ -2656,9 +2654,16 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
in_ax.push_back(axes_.at(a_axes_->get(in, d)).values); in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
out_ax.push_back(axes_.at(a_axes_->get(out, d)).values); out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
} }
in_ord = in_layout->to_mma() ? out_ord : in_ord; auto in_ord =
out_ord = out_layout->to_mma() ? in_ord : out_ord; in_layout->to_mma() ? out_layout->get_order() : in_layout->get_order();
int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]); auto out_ord =
out_layout->to_mma() ? in_layout->get_order() : out_layout->get_order();
// out_ord[0] == 0 or in_order[0] == 0 means the first dimension is
// non-contiguous. in_vec can be greater than 0 only if both out_ord[0] and
// and in_ord[0] are contiguous.
int in_vec = out_ord[0] == 0 ? 1
: in_ord[0] == 0 ? 1
: in_layout->contig_per_thread(in_ord[0]);
int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]); int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
int pad = std::max(in_vec, out_vec); int pad = std::max(in_vec, out_vec);
Value *in_ld = i32(shape[in_ord[0]] + pad); Value *in_ld = i32(shape[in_ord[0]] + pad);

View File

@@ -793,8 +793,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
@pytest.mark.parametrize("dtype_str, shape, perm", @pytest.mark.parametrize("dtype_str, shape, perm",
[(dtype, shape, perm) [(dtype, shape, perm)
for dtype in ['float32'] for dtype in ['float16', 'float32']
for shape in [(128, 128)] for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]]) for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'): def test_permute(dtype_str, shape, perm, device='cuda'):
@@ -812,18 +812,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
x = numpy_random(shape, dtype_str=dtype_str) x = numpy_random(shape, dtype_str=dtype_str)
# triton result # triton result
z_tri = to_triton(np.empty_like(x), device=device) z_tri = to_triton(np.empty_like(x), device=device)
z_tri_contiguous = to_triton(np.empty_like(x), device=device)
x_tri = to_triton(x, device=device) x_tri = to_triton(x, device=device)
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
z_tri, z_tri.stride(1), z_tri.stride(0), z_tri, z_tri.stride(1), z_tri.stride(0),
BLOCK_M=shape[0], BLOCK_N=shape[1]) BLOCK_M=shape[0], BLOCK_N=shape[1])
pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
BLOCK_M=shape[0], BLOCK_N=shape[1])
# torch result # torch result
z_ref = x.transpose(*perm) z_ref = x.transpose(*perm)
# compare # compare
triton.testing.assert_almost_equal(z_tri, z_ref) triton.testing.assert_almost_equal(z_tri, z_ref)
triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
# parse ptx to make sure ld/st are vectorized # parse ptx to make sure ld/st are vectorized
ptx = pgm.asm['ptx'] ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx assert 'st.global.v4' in ptx
ptx = pgm_contiguous.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
# --------------- # ---------------
# test dot # test dot