diff --git a/lib/backend/templates/mproduct.cpp b/lib/backend/templates/mproduct.cpp index da9a32f08..522ba6843 100644 --- a/lib/backend/templates/mproduct.cpp +++ b/lib/backend/templates/mproduct.cpp @@ -49,7 +49,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL) throw operation_not_supported_exception("Only local memory is supported for GEMM"); - if(p_.depth > 1 && M*N*p_.depth > 1e6) + if(p_.depth > 1 && M*N*p_.depth > 2e6) throw operation_not_supported_exception("This would necessitate a temporary larger than 1MB"); if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0) @@ -192,12 +192,12 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width if (A_trans_=='N') stream << "A += (idxT*" << p_.simd_width << " + gidx*" << p_.mL<< ")" << ASTRIDE1 << " + idyT*Ald + offz*Ald;" << std::endl; else - stream << "A += idxT" << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*Ald + offz;" << std::endl; + stream << "A += (idxT)*" << p_.simd_width << ASTRIDE1 << " + (idyT + gidx*" << p_.mL/p_.simd_width << ")*Ald + offz;" << std::endl; if(B_trans_=='T') stream << "B += (idxT*" << p_.simd_width << " + gidy*" << p_.nL << ")" << BSTRIDE1 << " + idyT*Bld + offz*Bld;" << std::endl; else - stream << "B += idxT" << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*Bld + offz;" << std::endl; + stream << "B += (idxT)*" << p_.simd_width << BSTRIDE1 << " + (idyT + gidy*" << p_.nL << ")*Bld + offz;" << std::endl; stream << "__global " << sdtype << "* Ai[" << npA << "];" << std::endl; @@ -249,7 +249,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width std::string kk = to_string(k); string to_load = VLOAD("0" ,"&Ai[" + mm +"][" + kk + "*Ald]"); if(check_bounds_) - to_load = "(block_k + idyT + " + kk + "< K)?" + to_load + ":0"; + to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0"; stream << VSTORE(to_load, "0", "lAstore + " + to_string(k*lAld+m)) << ";" << std::endl; } else @@ -258,16 +258,16 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width { std::string mm = to_string(k/p_.local_fetch_1); std::string kk = to_string(m/p_.simd_width); - string to_load = "Ai[" + mm + "][" + kk + ASTRIDE1 + "]"; + string to_load = VLOAD("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]"); if(check_bounds_) - to_load = "(block_k + idxT + " + kk + "< K)?" + to_load + ":0"; + to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0"; if(p_.simd_width==1) stream << VSTORE(to_load, "0", "lAstore + " + to_string(m*lAld+k)) << ";" << std::endl; else { - stream << vdtype << " tmpA" << k << m << " = " << to_load << ";" << std::endl; + stream << vdtype << " tmpA" << k << "_" << m << " = " << to_load << ";" << std::endl; for(unsigned int s = 0 ; s < p_.simd_width ; ++s) - stream << "lAstore[" << k + (m + s)*lAld << "]= tmpA" << k << m << ".s" << s << ";" << std::endl; + stream << "lAstore[" << k + (m + s)*lAld << "]= tmpA" << k << "_" << m << ".s" << s << ";" << std::endl; } } @@ -281,7 +281,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width std::string kk = to_string(k); string to_load = VLOAD("0", "&Bi[" + nn + "][" + kk + "*Bld]"); if(check_bounds_) - to_load = "(block_k + idyT + " + kk + "< K)?" + to_load + ":0"; + to_load = "(block_k + idyT + " + kk + "< chunk_size)?" + to_load + ":0"; stream << VSTORE(to_load, "0", "lBstore + " + to_string(k*lBld+n)) << ";" << std::endl; } else @@ -289,17 +289,17 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width for(int_t n = 0; n < p_.kL; n += p_.local_fetch_0*p_.simd_width) { std::string nn = to_string(k/p_.local_fetch_1); - std::string kk = to_string(n/p_.simd_width); - string to_load = "Bi[" + nn + "][" + kk + BSTRIDE1 + "]"; + std::string kk = to_string(n); + string to_load = VLOAD("0","&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"); if(check_bounds_) - to_load = "(block_k + idxT + " + kk + "< K)?" + to_load + ":0"; + to_load = "(block_k + idxT + " + kk + "< chunk_size)?" + to_load + ":0"; if(p_.simd_width==1) stream << VSTORE(to_load, "0", "lBstore + " + to_string(n*lBld+k)) << ";" << std::endl; else { - stream << vdtype << " tmpB" << k << n << " = " << to_load << ";" << std::endl; + stream << vdtype << " tmpB" << k << "_" << n << " = " << to_load << ";" << std::endl; for(unsigned int s = 0 ; s < p_.simd_width ; ++s) - stream << "lBstore[" << k + (n + s)*lBld << "]= tmpB" << k << n << ".s" << s << ";" << std::endl; + stream << "lBstore[" << k + (n + s)*lBld << "]= tmpB" << k << "_" << n << ".s" << s << ";" << std::endl; } } stream << LocalBarrier(backend) << ";" << std::endl; @@ -361,7 +361,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width stream << "Ai[" << i << "] += " << p_.kL << "*Ald;" << std::endl; else for(unsigned int i = 0 ; i < npA ; ++i) - stream << "Ai[" << i << "] += " << p_.kL/p_.simd_width << ASTRIDE1 << ";" << std::endl; + stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl; //Increment B pointers to global memory if (B_trans_=='T') @@ -369,7 +369,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width stream << "Bi[" << i << "] += " << p_.kL << "*Bld;" << std::endl; else for(unsigned int i = 0 ; i < npB ; ++i) - stream << "Bi[" << i << "] += " << p_.kL/p_.simd_width << BSTRIDE1 << ";" << std::endl; + stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl; stream.dec_tab(); stream << "}" << std::endl; @@ -429,8 +429,8 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width stream << "}" << std::endl; } - if(p_.simd_width>1) - std::cout << stream.str() << std::endl; +// if(p_.simd_width>1) +// std::cout << stream.str() << std::endl; return stream.str(); #undef VLOAD @@ -605,6 +605,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width } else { +// std::cout << p_.local_size_0 << " " << p_.kL << " " << p_.local_size_1 << " " << p_.depth << std::endl; value_scalar _1(1, dtype); enqueue_block(queue, M, N, lK, create_slice(*pA, 0, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, beta, program, suffix, options); fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options);