diff --git a/lib/backend/templates/gemm.cpp b/lib/backend/templates/gemm.cpp index 4c0788364..c64e0971f 100644 --- a/lib/backend/templates/gemm.cpp +++ b/lib/backend/templates/gemm.cpp @@ -168,10 +168,12 @@ gemm_parameters::gemm_parameters(unsigned int simd_width stream << "B += offb;" << std::endl; stream << "C += offc;" << std::endl; - stream << "size_t gidx = " << GroupIdx0(backend) << ";" << std::endl; - stream << "size_t gidy = " << GroupIdx1(backend) << ";" << std::endl; - stream << "size_t idx = " << LocalIdx0(backend) << ";" << std::endl; - stream << "size_t idy = " << LocalIdx1(backend) << ";" << std::endl; + stream << "int4 ids = (int4)(" << GroupIdx0(backend) << "," << GroupIdx1(backend) << "," << LocalIdx0(backend) << "," << LocalIdx1(backend) << ");" << std::endl; + + stream << "size_t idt = " << p_.local_size_0 << "*ids.w + ids.z;" << std::endl; + stream << "int2 idT;" << std::endl; + stream << "idT.y = idt/" << p_.local_fetch_0 << ";" << std::endl; + stream << "idT.x = idt - " << p_.local_fetch_0 << "*idT.y;" << std::endl; if(has_depth) { @@ -181,39 +183,38 @@ gemm_parameters::gemm_parameters(unsigned int simd_width stream << "K = min(K - div*gidz, div);" << std::endl; } - stream << std::endl; - stream << "size_t idt = " << p_.local_size_0 << "*idy + idx;" << std::endl; - stream << "size_t idxT = idt % " << p_.local_fetch_0 << ";" << std::endl; - stream << "size_t idyT = idt / " << p_.local_fetch_0 << ";" << std::endl; - stream << std::endl; + stream << "ids.x *= " << p_.mL << ";" << std::endl; + stream << "ids.y *= " << p_.nL << ";" << std::endl; + + stream << "idT.x *= " << p_.simd_width << ";" << std::endl; if (A_trans_=='N') - stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*lda" << (has_depth?"+ offz*lda":"") << ";" << std::endl; + stream << "A += (idT.x + ids.x)" << ASTRIDE1 << " + idT.y*lda" << (has_depth?"+ offz*lda":"") << ";" << std::endl; else - stream << "A += idxT*" << p_.simd_width << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*lda" << (has_depth?"+ offz":"") << ";" << std::endl; + stream << "A += idT.x" << ASTRIDE1 << " + idT.y*lda + ids.x*lda" << (has_depth?"+ offz":"") << ";" << std::endl; if(B_trans_=='T') - stream << "B += (idxT*" << p_.simd_width << " + gidy*" << p_.nL << ")" << BSTRIDE1 << " + idyT*ldb" << (has_depth?"+ offz*ldb":"") << ";" << std::endl; + stream << "B += (idT.x + ids.y)" << BSTRIDE1 << " + idT.y*ldb" << (has_depth?"+ offz*ldb":"") << ";" << std::endl; else - stream << "B += idxT*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*ldb" << (has_depth?"+ offz":"") << ";" << std::endl; + stream << "B += idT.x" << BSTRIDE1 << " + idT.y*ldb + ids.y*ldb" << (has_depth?"+ offz":"") << ";" << std::endl; stream << "for(unsigned int i = 0 ; i < " << npA << " ; ++i) Ai[i] = A;" << std::endl; stream << "for(unsigned int i = 0 ; i < " << npB << " ; ++i) Bi[i] = B;" << std::endl; for(unsigned int i = 0 ; i < npA ; i++ ) if (A_trans_=='N') - stream << "if(gidx*" << p_.mL << " + idxT*" << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << ASTRIDE1 << ";" << std::endl; + stream << "if(ids.x + idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < M) Ai[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << ASTRIDE1 << ";" << std::endl; else - stream << "if(gidx*" << p_.mL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += " << i*p_.local_fetch_1 << "*lda;" << std::endl; + stream << "if(ids.x + idT.y + " << i << "*" << p_.local_fetch_1 << " < M) Ai[" << i << "] += " << i*p_.local_fetch_1 << "*lda;" << std::endl; for(unsigned int i = 0 ; i < npB ; i++ ) if (B_trans_=='T') - stream << "if(gidy*" << p_.nL << " + idxT* " << p_.simd_width << " + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << BSTRIDE1 << ";" << std::endl; + stream << "if(ids.y + idT.x + " << i << "*" << p_.local_fetch_0*p_.simd_width << " < N) Bi[" << i << "] += " << i*p_.local_fetch_0*p_.simd_width << BSTRIDE1 << ";" << std::endl; else - stream << "if(gidy*" << p_.nL << " + idyT + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl; + stream << "if(ids.y + idT.y + " << i << "*" << p_.local_fetch_1 << " < N) Bi[" << i << "] += " << i*p_.local_fetch_1 << "*ldb;" << std::endl; - stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idyT*" << llda << " + idxT*" << p_.simd_width << ";" << std::endl; - stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idyT*" << lldb << " + idxT*" << p_.simd_width << ";" << std::endl; + stream << LocalPtr(backend) << " " << sdtype << "* lAstore = lA + idT.y*" << llda << " + idT.x;" << std::endl; + stream << LocalPtr(backend) << " " << sdtype << "* lBstore = lB + idT.y*" << lldb << " + idT.x;" << std::endl; stream << "//Outer loop" << std::endl; stream << "for(long block_k=K; block_k > 0 ; block_k-=" << p_.kL << "){" << std::endl; @@ -229,7 +230,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width std::string mm = to_string(m/(p_.simd_width*p_.local_fetch_0)); std::string kk = to_string(k); string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*lda]"); - to_load = "(idyT + " + kk + "< block_k)?" + to_load + ":0"; + to_load = "(idT.y + " + kk + "< block_k)?" + to_load + ":0"; stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*llda+m)) << ";" << std::endl; } } @@ -241,7 +242,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width std::string mm = to_string(m/p_.local_fetch_1); std::string kk = to_string(k); string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]"); - to_load = "(idxT + " + kk + "< block_k)?" + to_load + ":0"; + to_load = "(idT.x + " + kk + "< block_k)?" + to_load + ":0"; stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*llda+k)) << ";" << std::endl; } } @@ -255,7 +256,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width std::string nn = to_string(n/(p_.simd_width*p_.local_fetch_0)); std::string kk = to_string(k); string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*ldb]"); - to_load = "(idyT + " + kk + "< block_k)?" + to_load + ":0"; + to_load = "(idT.y + " + kk + "< block_k)?" + to_load + ":0"; stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lldb+n)) << ";" << std::endl; } } @@ -267,21 +268,21 @@ gemm_parameters::gemm_parameters(unsigned int simd_width std::string nn = to_string(n/p_.local_fetch_1); std::string kk = to_string(k); string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"); - to_load = "(idxT + " + kk + "< block_k)?" + to_load + ":0"; + to_load = "(idT.x + " + kk + "< block_k)?" + to_load + ":0"; stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lldb+k)) << ";" << std::endl; } } stream << LocalBarrier(backend) << ";" << std::endl; if(A_trans_=='N') - stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << p_.simd_width << ";" << std::endl; + stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << p_.simd_width << ";" << std::endl; else - stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + idx*" << llda*p_.simd_width << ";" << std::endl; + stream << LocalPtr(backend) << " " << sdtype << "* readA = lA + ids.z*" << llda*p_.simd_width << ";" << std::endl; if(B_trans_=='T') - stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << p_.simd_width << ";" << std::endl; + stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << p_.simd_width << ";" << std::endl; else - stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + idy*" << lldb*p_.simd_width << ";" << std::endl; + stream << LocalPtr(backend) << " " << sdtype << "* readB = lB + ids.w*" << lldb*p_.simd_width << ";" << std::endl; stream << "//Inner loop" << std::endl; @@ -371,8 +372,8 @@ gemm_parameters::gemm_parameters(unsigned int simd_width stream << "//Write back C" << std::endl; - stream << "size_t offx = (gidx*" << p_.mL << " + idx*" << p_.simd_width << ")" << ";" << std::endl; - stream << "size_t offy = (gidy*" << p_.nL << " + idy*" << p_.simd_width << ");" << std::endl; + stream << "size_t offx = (ids.x + ids.z*" << p_.simd_width << ")" << ";" << std::endl; + stream << "size_t offy = (ids.y + ids.w*" << p_.simd_width << ");" << std::endl; stream << "C += " << "offx" << CSTRIDE1 << " + offy*ldc" << (has_depth?" + gidz*ldc*N;":"") << ";" << std::endl; stream << "N -= offy;" << std::endl; stream << "M -= offx;" << std::endl; diff --git a/tests/linalg/gemm.cpp b/tests/linalg/gemm.cpp index c4f4b839c..e01770ee8 100644 --- a/tests/linalg/gemm.cpp +++ b/tests/linalg/gemm.cpp @@ -101,9 +101,9 @@ void test_impl(T epsilon, ad::driver::Context const & ctx) int_t N = 256; int_t K = 293; - int_t SUBM = 64; - int_t SUBN = 64; - int_t SUBK = 64; + int_t SUBM = 7; + int_t SUBN = 13; + int_t SUBK = 41; { INIT_MATRIX(M, SUBM, 5, 1, N, SUBN, 7, 1, cC, C, ctx);