diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index eeabb6841..3c4fae3d8 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -788,7 +788,6 @@ void generator::visit_cat_inst(ir::cat_inst* x) {
   for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
     vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
   }
-// std::cout << "!" << std::endl;
 }
 
 
@@ -1660,13 +1659,17 @@ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::va
 
   std::map<indices_t, Value*> ret = vals_[D];
   std::map<std::pair<unsigned, unsigned>, Value*> has, hbs;
+  auto ord = layout_c->get_order();
   for(unsigned k = 0; k < NK; k++){
     int z = 0;
-    for(unsigned m = 0; m < shape_c[0]; m += layout_c->shape_per_cta(0))
-    for(unsigned n = 0; n < shape_c[1]; n += layout_c->shape_per_cta(1))
-    for(unsigned mm = 0; mm < layout_c->nts(0); mm++)
-    for(unsigned nn = 0; nn < layout_c->nts(1); nn++)
-    {
+    for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1]))
+    for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0]))
+    for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++)
+    for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){
+      unsigned m = (ord[0] == 1) ? i : j;
+      unsigned n = (ord[0] == 1) ? j : i;
+      unsigned mm = (ord[0] == 1) ? ii : jj;
+      unsigned nn = (ord[0] == 1) ? jj : ii;
       if(has.find({m + mm, k}) == has.end()){
         Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k));
         Value* va = load(pa);
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 98c8c34fa..6359857fe 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -455,8 +455,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
 # test dot
 # ---------------
 
-@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
-def test_dot(epilogue, device='cuda'):
+@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
+def test_dot(epilogue, dtype=torch.float32, device='cuda'):
     torch.manual_seed(0)
     # triton kernel
     @triton.jit
@@ -483,11 +483,13 @@ def test_dot(epilogue, device='cuda'):
         tl.store(Zs, z)
     # input
     M, N, K = 64, 64, 32
-    x = triton.testing.random((M, K), dtype=torch.float16, device=device)
-    y = triton.testing.random((K, N), dtype=torch.float16, device=device)
+    x = triton.testing.random((M, K), dtype=dtype, device=device)
+    y = triton.testing.random((K, N), dtype=dtype, device=device)
     # triton result
-    z = triton.testing.random((M, N), dtype=torch.float16, device=device)
+    z = triton.testing.random((M, N), dtype=dtype, device=device)
     z_tri = z.clone()
+    if epilogue == 'trans':
+        z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
     pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
                          y, y.stride(0), y.stride(1),
                          z_tri, z_tri.stride(0), z_tri.stride(1),
@@ -505,10 +507,9 @@ def test_dot(epilogue, device='cuda'):
         z_ref += z[0,:][None, :]
     z_ref = z_ref.to(torch.float16)
     # compare
-    ptx = pgm.asm['ptx']
-    # print(ptx)
    triton.testing.assert_almost_equal(z_tri, z_ref)
     # make sure ld/st are vectorized
+    ptx = pgm.asm['ptx']
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
 