[BACKEND/CODE_GEN] Fixed float32 matmul problem (#380)

This commit is contained in:
Philippe Tillet
2021-11-30 22:00:56 -08:00
committed by GitHub
parent c86ad9c9ab
commit 8ec9f037bb
2 changed files with 17 additions and 13 deletions

View File

@@ -788,7 +788,6 @@ void generator::visit_cat_inst(ir::cat_inst* x) {
for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){ for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]]; vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
} }
// std::cout << "!" << std::endl;
} }
@@ -1660,13 +1659,17 @@ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::va
std::map<indices_t, Value*> ret = vals_[D]; std::map<indices_t, Value*> ret = vals_[D];
std::map<std::pair<int, int>, Value*> has, hbs; std::map<std::pair<int, int>, Value*> has, hbs;
auto ord = layout_c->get_order();
for(unsigned k = 0; k < NK; k++){ for(unsigned k = 0; k < NK; k++){
int z = 0; int z = 0;
for(unsigned m = 0; m < shape_c[0]; m += layout_c->shape_per_cta(0)) for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1]))
for(unsigned n = 0; n < shape_c[1]; n += layout_c->shape_per_cta(1)) for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0]))
for(unsigned mm = 0; mm < layout_c->nts(0); mm++) for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++)
for(unsigned nn = 0; nn < layout_c->nts(1); nn++) for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){
{ unsigned m = (ord[0] == 1) ? i : j;
unsigned n = (ord[0] == 1) ? j : i;
unsigned mm = (ord[0] == 1) ? ii : jj;
unsigned nn = (ord[0] == 1) ? jj : ii;
if(has.find({m + mm, k}) == has.end()){ if(has.find({m + mm, k}) == has.end()){
Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k)); Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k));
Value* va = load(pa); Value* va = load(pa);

View File

@@ -455,8 +455,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
# test dot # test dot
# --------------- # ---------------
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols']) @pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'): def test_dot(epilogue, dtype=torch.float32, device='cuda'):
torch.manual_seed(0) torch.manual_seed(0)
# triton kernel # triton kernel
@triton.jit @triton.jit
@@ -483,11 +483,13 @@ def test_dot(epilogue, device='cuda'):
tl.store(Zs, z) tl.store(Zs, z)
# input # input
M, N, K = 64, 64, 32 M, N, K = 64, 64, 32
x = triton.testing.random((M, K), dtype=torch.float16, device=device) x = triton.testing.random((M, K), dtype=dtype, device=device)
y = triton.testing.random((K, N), dtype=torch.float16, device=device) y = triton.testing.random((K, N), dtype=dtype, device=device)
# triton result # triton result
z = triton.testing.random((M, N), dtype=torch.float16, device=device) z = triton.testing.random((M, N), dtype=dtype, device=device)
z_tri = z.clone() z_tri = z.clone()
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1), pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
y, y.stride(0), y.stride(1), y, y.stride(0), y.stride(1),
z_tri, z_tri.stride(0), z_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1),
@@ -505,10 +507,9 @@ def test_dot(epilogue, device='cuda'):
z_ref += z[0,:][None, :] z_ref += z[0,:][None, :]
z_ref = z_ref.to(torch.float16) z_ref = z_ref.to(torch.float16)
# compare # compare
ptx = pgm.asm['ptx']
# print(ptx)
triton.testing.assert_almost_equal(z_tri, z_ref) triton.testing.assert_almost_equal(z_tri, z_ref)
# make sure ld/st are vectorized # make sure ld/st are vectorized
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx assert 'st.global.v4' in ptx