[BACKEND/CODE_GEN] Fixed float32 matmul problem (#380)
This commit is contained in:
@@ -788,7 +788,6 @@ void generator::visit_cat_inst(ir::cat_inst* x) {
|
||||
for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){
|
||||
vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]];
|
||||
}
|
||||
// std::cout << "!" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@@ -1660,13 +1659,17 @@ void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::va
|
||||
|
||||
std::map<indices_t, Value*> ret = vals_[D];
|
||||
std::map<std::pair<int, int>, Value*> has, hbs;
|
||||
auto ord = layout_c->get_order();
|
||||
for(unsigned k = 0; k < NK; k++){
|
||||
int z = 0;
|
||||
for(unsigned m = 0; m < shape_c[0]; m += layout_c->shape_per_cta(0))
|
||||
for(unsigned n = 0; n < shape_c[1]; n += layout_c->shape_per_cta(1))
|
||||
for(unsigned mm = 0; mm < layout_c->nts(0); mm++)
|
||||
for(unsigned nn = 0; nn < layout_c->nts(1); nn++)
|
||||
{
|
||||
for(unsigned i = 0; i < shape_c[ord[1]]; i += layout_c->shape_per_cta(ord[1]))
|
||||
for(unsigned j = 0; j < shape_c[ord[0]]; j += layout_c->shape_per_cta(ord[0]))
|
||||
for(unsigned ii = 0; ii < layout_c->nts(ord[1]); ii++)
|
||||
for(unsigned jj = 0; jj < layout_c->nts(ord[0]); jj++){
|
||||
unsigned m = (ord[0] == 1) ? i : j;
|
||||
unsigned n = (ord[0] == 1) ? j : i;
|
||||
unsigned mm = (ord[0] == 1) ? ii : jj;
|
||||
unsigned nn = (ord[0] == 1) ? jj : ii;
|
||||
if(has.find({m + mm, k}) == has.end()){
|
||||
Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k));
|
||||
Value* va = load(pa);
|
||||
|
@@ -455,8 +455,8 @@ def test_permute(dtype, shape, perm, device='cuda'):
|
||||
# test dot
|
||||
# ---------------
|
||||
|
||||
@pytest.mark.parametrize("epilogue", ['none', 'add-matrix', 'add-rows', 'add-cols'])
|
||||
def test_dot(epilogue, device='cuda'):
|
||||
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
|
||||
def test_dot(epilogue, dtype=torch.float32, device='cuda'):
|
||||
torch.manual_seed(0)
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
@@ -483,11 +483,13 @@ def test_dot(epilogue, device='cuda'):
|
||||
tl.store(Zs, z)
|
||||
# input
|
||||
M, N, K = 64, 64, 32
|
||||
x = triton.testing.random((M, K), dtype=torch.float16, device=device)
|
||||
y = triton.testing.random((K, N), dtype=torch.float16, device=device)
|
||||
x = triton.testing.random((M, K), dtype=dtype, device=device)
|
||||
y = triton.testing.random((K, N), dtype=dtype, device=device)
|
||||
# triton result
|
||||
z = triton.testing.random((M, N), dtype=torch.float16, device=device)
|
||||
z = triton.testing.random((M, N), dtype=dtype, device=device)
|
||||
z_tri = z.clone()
|
||||
if epilogue == 'trans':
|
||||
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
pgm = kernel[(1, 1)](x, x.stride(0), x.stride(1),
|
||||
y, y.stride(0), y.stride(1),
|
||||
z_tri, z_tri.stride(0), z_tri.stride(1),
|
||||
@@ -505,10 +507,9 @@ def test_dot(epilogue, device='cuda'):
|
||||
z_ref += z[0,:][None, :]
|
||||
z_ref = z_ref.to(torch.float16)
|
||||
# compare
|
||||
ptx = pgm.asm['ptx']
|
||||
# print(ptx)
|
||||
triton.testing.assert_almost_equal(z_tri, z_ref)
|
||||
# make sure ld/st are vectorized
|
||||
ptx = pgm.asm['ptx']
|
||||
assert 'ld.global.v4' in ptx
|
||||
assert 'st.global.v4' in ptx
|
||||
|
||||
|
Reference in New Issue
Block a user