[BACKEND] Fix layout convert for non-contiguous input (#564)
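The diff touches two places: the backend's visit_layout_convert, which previously gated the input vector width only on the output order's leading axis (and, in the MMA branch, derived out_ord from an in_ord it had already overwritten), and test_permute, which now also launches the kernel with swapped input strides. For illustration only (not part of the diff), a non-contiguous input is simply a transposed view:

    import numpy as np

    x = np.empty((128, 128), dtype=np.float32)
    x_t = x.T                             # same buffer, strides reversed
    assert not x_t.flags['C_CONTIGUOUS']  # trailing axis no longer varies fastest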
@@ -2638,8 +2638,6 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
   // Orders
   analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
   analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
-  auto in_ord = in_layout->get_order();
-  auto out_ord = out_layout->get_order();
   Value *base;
   base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(out)))));
   base = bit_cast(base, ptr_ty(ty, 3));
@@ -2656,9 +2654,16 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
     in_ax.push_back(axes_.at(a_axes_->get(in, d)).values);
     out_ax.push_back(axes_.at(a_axes_->get(out, d)).values);
   }
-  in_ord = in_layout->to_mma() ? out_ord : in_ord;
-  out_ord = out_layout->to_mma() ? in_ord : out_ord;
-  int in_vec = out_ord[0] == 0 ? 1 : in_layout->contig_per_thread(in_ord[0]);
+  auto in_ord =
+      in_layout->to_mma() ? out_layout->get_order() : in_layout->get_order();
+  auto out_ord =
+      out_layout->to_mma() ? in_layout->get_order() : out_layout->get_order();
+  // out_ord[0] == 0 or in_ord[0] == 0 means the first dimension is
+  // non-contiguous. in_vec can be greater than 1 only if both out_ord[0]
+  // and in_ord[0] are contiguous.
+  int in_vec = out_ord[0] == 0 ? 1
+               : in_ord[0] == 0 ? 1
+                                : in_layout->contig_per_thread(in_ord[0]);
   int out_vec = out_ord[0] == 0 ? 1 : out_layout->contig_per_thread(out_ord[0]);
   int pad = std::max(in_vec, out_vec);
   Value *in_ld = i32(shape[in_ord[0]] + pad);
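The ternary chain above is the core of the fix. A minimal Python sketch of the same selection (hypothetical helper; the contig_per_thread callables stand in for the C++ layout queries of the same name):

    def pick_vec_widths(in_ord, out_ord, in_contig_per_thread, out_contig_per_thread):
        # Per the comment in the hunk: a leading order value of 0 flags a
        # non-contiguous first dimension, so fall back to scalar (width-1)
        # accesses; in_vec may exceed 1 only when both sides are contiguous.
        in_vec = 1 if (out_ord[0] == 0 or in_ord[0] == 0) \
            else in_contig_per_thread(in_ord[0])
        out_vec = 1 if out_ord[0] == 0 else out_contig_per_thread(out_ord[0])
        # the larger width pads the shared-memory leading dimension, mirroring
        # `int pad = std::max(in_vec, out_vec);` above
        return in_vec, out_vec, max(in_vec, out_vec)

Before the fix, in_ord[0] == 0 was never consulted, so a non-contiguous input could still be read with a vector width taken from in_layout->contig_per_thread.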
@@ -793,8 +793,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
 
 @pytest.mark.parametrize("dtype_str, shape, perm",
                          [(dtype, shape, perm)
-                          for dtype in ['float32']
-                          for shape in [(128, 128)]
+                          for dtype in ['float16', 'float32']
+                          for shape in [(64, 64), (128, 128)]
                           for perm in [(1, 0)]])
 def test_permute(dtype_str, shape, perm, device='cuda'):
 
@@ -812,18 +812,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
     x = numpy_random(shape, dtype_str=dtype_str)
     # triton result
     z_tri = to_triton(np.empty_like(x), device=device)
+    z_tri_contiguous = to_triton(np.empty_like(x), device=device)
     x_tri = to_triton(x, device=device)
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                          z_tri, z_tri.stride(1), z_tri.stride(0),
                          BLOCK_M=shape[0], BLOCK_N=shape[1])
+    pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0),
+                                    z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1),
+                                    BLOCK_M=shape[0], BLOCK_N=shape[1])
     # torch result
     z_ref = x.transpose(*perm)
     # compare
     triton.testing.assert_almost_equal(z_tri, z_ref)
+    triton.testing.assert_almost_equal(z_tri_contiguous, z_ref)
     # parse ptx to make sure ld/st are vectorized
     ptx = pgm.asm['ptx']
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
+    ptx = pgm_contiguous.asm['ptx']
+    assert 'ld.global.v4' in ptx
+    assert 'st.global.v4' in ptx
 
 # ---------------
 # test dot
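The kernel under test is defined earlier in the file and is not part of this hunk; it is essentially a 2D strided copy. A sketch of such a kernel, reconstructed for context (not the verbatim test source):

    import triton
    import triton.language as tl

    @triton.jit
    def kernel(X, stride_xm, stride_xn,
               Z, stride_zm, stride_zn,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        # sketch of the permute copy kernel used by test_permute
        off_m = tl.arange(0, BLOCK_M)
        off_n = tl.arange(0, BLOCK_N)
        # 2D pointer grids for the source and destination tiles
        Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn
        Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
        tl.store(Zs, tl.load(Xs))

Because the strides are runtime arguments, the new pgm_contiguous launch permutes the input rather than the output simply by swapping the stride pair passed for X; both launches must still compile to vectorized (v4) global loads and stores, which is what the PTX assertions verify.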