Kernels: Fixed various corner cases for the kernel templates and BLAS

This commit is contained in:
Philippe Tillet
2015-11-25 18:42:25 -05:00
parent 6be5929b0d
commit 6fc94c0c0b
15 changed files with 107 additions and 38 deletions

View File

@@ -516,7 +516,10 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
{
string Ci = to_string((m/p_.simd_width)*(p_.local_size_0*p_.simd_width) + m%p_.simd_width);
stream << "if(" << Ci << "< M) ";
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + (beta?beta*" << "C[" << Ci << CSTRIDE1 << "]:0);" << std::endl;
if(has_depth)
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
else
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + (beta?(beta*" << "C[" << Ci << CSTRIDE1 << "]):0);" << std::endl;
}
if((n+1)%p_.simd_width==0){
stream << "C += ldc*" << p_.local_size_1*p_.simd_width - p_.simd_width + 1 << ";" << std::endl;
@@ -552,7 +555,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
stream.inc_tab();
stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl;
stream.dec_tab();
stream << "C[i*Cstride + j*ldc] = acc + beta*C[i + j*ldc];" << std::endl;
stream << "C[i*Cstride + j*ldc] = acc + beta*C[i*Cstride + j*ldc];" << std::endl;
stream.dec_tab();
stream << "}" << std::endl;
stream.dec_tab();