From 79f833ba65b9a3349ceaa0ac045bf4942aa57ffc Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Tue, 21 Jul 2015 14:35:30 -0400 Subject: [PATCH] GEMM: Better update of Kx, Ky --- lib/backend/templates/gemm.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/lib/backend/templates/gemm.cpp b/lib/backend/templates/gemm.cpp index 4dcfe8b80..6134418fc 100644 --- a/lib/backend/templates/gemm.cpp +++ b/lib/backend/templates/gemm.cpp @@ -121,6 +121,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width std::string sdtype = numeric_type_to_string(dtype); std::string vdtype = append_width(sdtype, p_.simd_width); std::string _size_t = size_type(device); + std::string vint = append_width("int", p_.simd_width); ////////////////// /// DECLARATIONS @@ -259,17 +260,16 @@ gemm_parameters::gemm_parameters(unsigned int simd_width stream << "storeA = lA + idT.y*" << llda << " + idT.x;" << std::endl; stream << "storeB = lB + idT.y*" << lldb << " + idT.x;" << std::endl; - stream << "//Outer loop" << std::endl; - stream << "while(K > 0){" << std::endl; - stream.inc_tab(); - stream << LocalBarrier(backend) << ";" << std::endl; - if(A_trans_=='N' || B_trans_=='T') stream << "Ky = K - idT.y;" << std::endl; if(A_trans_=='T' || B_trans_=='N') stream << "Kx = K - idT.x;" << std::endl; - std::string vint = append_width("int", p_.simd_width); + stream << "//Outer loop" << std::endl; + stream << "while(K > 0){" << std::endl; + stream.inc_tab(); + stream << LocalBarrier(backend) << ";" << std::endl; + if(A_trans_=='N' || B_trans_=='T') { @@ -419,6 +419,13 @@ gemm_parameters::gemm_parameters(unsigned int simd_width stream.dec_tab(); stream << "}" << std::endl; + + stream << "K -= " << p_.kL << ";" << std::endl; + if(A_trans_=='N' || B_trans_=='T') + stream << "Ky -= " << p_.kL << ";" << std::endl; + if(A_trans_=='T' || B_trans_=='N') + stream << "Kx -= " << p_.kL << ";" << std::endl; + //Increment A pointers to global memory if (A_trans_=='N') for(unsigned int i = 0 ; i < npA ; ++i) @@ -435,7 +442,6 @@ gemm_parameters::gemm_parameters(unsigned int simd_width for(unsigned int i = 0 ; i < npB ; ++i) stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl; - stream << "K -= " << p_.kL << ";" << std::endl; stream.dec_tab(); stream << "}" << std::endl;