From 87413bc92522f14da4860adb506a8bc96c5e3a89 Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Sat, 25 Jun 2022 23:12:03 -0700
Subject: [PATCH] [BACKEND] Fix layout convert for non-contiguous input (#564)

---
 lib/codegen/selection/generator.cc     | 15 ++++++++++-----
 python/test/unit/language/test_core.py | 12 ++++++++++--
 2 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index ebd21732b..8d95a2790 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -2638,8 +2638,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
   // Orders
   analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
   analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
-  auto in_ord = in_layout->get_order();
-  auto out_ord = out_layout->get_order();
   Value *base;
   base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
   base = bit_cast(base, ptr_ty(ty, 3));
@@ -2656,9 +2654,16 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
     in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
     out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
   }
-  in_ord = in_layout->to_mma() ? out_ord : in_ord;
-  out_ord = out_layout->to_mma() ? in_ord : out_ord;
-  int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]);
+  auto in_ord =
+      in_layout->to_mma() ? out_layout->get_order() : in_layout->get_order();
+  auto out_ord =
+      out_layout->to_mma() ? in_layout->get_order() : out_layout->get_order();
+  // out_ord[0] == 0 or in_ord[0] == 0 means the first dimension is
+  // non-contiguous. in_vec can be greater than 1 only if both out_ord[0]
+  // and in_ord[0] are contiguous.
+  int in_vec = out_ord[0] == 0 ? 1
+               : in_ord[0] == 0 ? 1
+                                : in_layout->contig_per_thread(in_ord[0]);
   int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
   int pad = std::max(in_vec, out_vec);
   Value *in_ld = i32(shape[in_ord[0]] + pad);
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index f1b4f899f..a5fb0acba 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -793,8 +793,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
 
 @pytest.mark.parametrize("dtype_str, shape, perm",
                          [(dtype, shape, perm)
-                          for dtype in ['float32']
-                          for shape in [(128, 128)]
+                          for dtype in ['float16', 'float32']
+                          for shape in [(64, 64), (128, 128)]
                           for perm in [(1, 0)]])
 def test_permute(dtype_str, shape, perm, device='cuda'):
 
@@ -812,18 +812,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
     x = numpy_random(shape, dtype_str=dtype_str)
     # triton result
     z_tri = to_triton(np.empty_like(x), device=device)
+    z_tri_contiguous = to_triton(np.empty_like(x), device=device)
     x_tri = to_triton(x, device=device)
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                          z_tri, z_tri.stride(1), z_tri.stride(0),
                          BLOCK_M=shape[0], BLOCK_N=shape[1])
+    pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
+                                    z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
+                                    BLOCK_M=shape[0], BLOCK_N=shape[1])
     # torch result
    	z_ref = x.transpose(*perm)
     # compare
     triton.testing.assert_almost_equal(z_tri, z_ref)
+    triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
     # parse ptx to make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
+    ptx = pgm_contiguous.asm['ptx']
+    assert 'ld.global.v4' in ptx
+    assert 'st.global.v4' in ptx
 
 # ---------------
 # test dot
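
Reviewer note, not part of the patch: a small numpy sketch (no Triton APIs) of what the second, "contiguous" invocation added to test_permute exercises. Passing x_tri.stride(1), x_tri.stride(0) instead of the natural order is the same as handing the kernel a transposed view, and a transposed view flips which axis is contiguous; that is the situation in which the new in_vec/out_vec guard in generator.cc falls back to width-1 accesses in the layout converter.

    import numpy as np

    x = np.zeros((128, 128), dtype=np.float32)   # same shape as the test's largest case
    xt = x.T                                     # view with stride(0)/stride(1) swapped, no copy

    print(x.strides)    # (512, 4): the last axis is the contiguous one
    print(xt.strides)   # (4, 512): after the swap, axis 0 is the contiguous one

    # the axis with the smallest byte stride is the one a vectorized access must follow
    assert int(np.argmin(x.strides)) == 1
    assert int(np.argmin(xt.strides)) == 0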