/*
 * Copyright (c) 2015, PHILIPPE TILLET. All rights reserved.
 *
 * This file is part of ISAAC.
 *
 * ISAAC is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 * MA 02110-1301 USA
 */
|
2015-12-21 17:04:09 -05:00
|
|
|
|
2015-04-29 15:50:57 -04:00
|
|
|
#include "isaac/array.h"
|
2016-04-10 13:13:16 -04:00
|
|
|
#include "isaac/jit/syntax/expression/preset.h"
|
|
|
|
#include "isaac/jit/syntax/engine/process.h"
|
2016-09-30 23:04:50 -04:00
|
|
|
#include "isaac/jit/generation/gemm.h"
|
2016-04-10 13:13:16 -04:00
|
|
|
#include "isaac/jit/generation/engine/keywords.h"
|
2016-04-02 18:19:33 -04:00
|
|
|
#include "isaac/exception/api.h"
|
2015-08-06 16:14:33 -07:00
|
|
|
#include "tools/arguments.hpp"
|
|
|
|
#include "tools/vector_types.hpp"
|
|
|
|
|
2016-04-02 18:19:33 -04:00
|
|
|
|
2015-08-06 19:34:26 -07:00
|
|
|
#include <string>
|
2016-04-02 18:19:33 -04:00
|
|
|
#include "isaac/tools/cpp/align.hpp"
|
2015-06-27 17:55:01 -07:00
|
|
|
|
2015-04-29 15:50:57 -04:00
|
|
|
namespace isaac
|
2015-01-12 13:20:53 -05:00
|
|
|
{
|
2015-07-11 09:36:01 -04:00
|
|
|
namespace templates
|
|
|
|
{
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
gemm_parameters::gemm_parameters(unsigned int vwidth
|
2016-07-02 12:06:05 -07:00
|
|
|
, unsigned int ls0, unsigned int KL, unsigned int ls1, unsigned int D
|
2015-12-17 00:51:04 -05:00
|
|
|
, unsigned int ms, unsigned int ks, unsigned int ns
|
2016-07-02 12:06:05 -07:00
|
|
|
, fetch_type Afetch, fetch_type Bfetch
|
|
|
|
, unsigned int lf0, unsigned int lf1): base::parameters_type(vwidth, ls0, ls1, 1),
|
|
|
|
kL(KL), depth(D), mS(ms), kS(ks), nS(ns), Afetch(Afetch), Bfetch(Bfetch),
|
|
|
|
lf0(lf0), lf1(lf1),
|
|
|
|
mL(ms*ls0), nL(ns*ls1)
|
2015-08-17 18:01:17 -07:00
|
|
|
{
|
|
|
|
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
unsigned int gemm::lmem_usage(expression_tree const & expression) const
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
|
|
|
unsigned int N = 0;
|
|
|
|
N += p_.kL * p_.mL;
|
|
|
|
N += p_.nL * p_.kL;
|
2016-04-02 18:19:33 -04:00
|
|
|
return N*size_of(expression.dtype());
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
unsigned int gemm::registers_usage(expression_tree const & expression) const
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
|
|
|
unsigned int N = p_.mS * p_.nS + p_.mS * p_.kS + p_.kS * p_.nS;
|
2016-04-02 18:19:33 -04:00
|
|
|
return N*size_of(expression.dtype());
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2015-06-28 17:53:16 -07:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
unsigned int gemm::temporary_workspace(expression_tree const & expressions) const
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
|
|
|
std::vector<int_t> MNK = input_sizes(expressions);
|
|
|
|
int_t M = MNK[0]; int_t N = MNK[1];
|
|
|
|
if(p_.depth > 1)
|
|
|
|
return M*N*p_.depth;
|
|
|
|
return 0;
|
|
|
|
}
|
2015-08-21 13:06:20 -04:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
int gemm::is_invalid_impl(driver::Device const &, expression_tree const &) const
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
if(p_.Afetch!=FETCH_FROM_LOCAL || p_.Bfetch!=FETCH_FROM_LOCAL)
|
2015-12-17 00:51:04 -05:00
|
|
|
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-07-02 12:06:05 -07:00
|
|
|
if ((p_.mS % p_.vwidth) > 0 || (p_.nS % p_.vwidth) > 0)
|
2015-12-17 00:51:04 -05:00
|
|
|
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
|
2015-04-29 15:50:57 -04:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
if(p_.mL > 256 || p_.nL > 256)
|
|
|
|
return TEMPLATE_BLOCK_SIZE_TOO_LARGE;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
if ( p_.kS % p_.kL == 0)
|
|
|
|
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-07-02 12:06:05 -07:00
|
|
|
if (p_.Afetch==FETCH_FROM_LOCAL || p_.Bfetch==FETCH_FROM_LOCAL){
|
|
|
|
if ((p_.lf0*p_.lf1) !=(p_.ls0*p_.ls1))
|
2015-12-17 00:51:04 -05:00
|
|
|
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
|
|
|
|
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-07-02 12:06:05 -07:00
|
|
|
if (p_.Afetch==FETCH_FROM_LOCAL)
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
|
|
|
unsigned int bound1 = (A_trans_=='N')?p_.kL:p_.mL;
|
|
|
|
unsigned int bound0 = (A_trans_=='N')?p_.mL:p_.kL;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-07-02 12:06:05 -07:00
|
|
|
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
2015-12-17 00:51:04 -05:00
|
|
|
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-07-02 12:06:05 -07:00
|
|
|
if (p_.lf0>0 && (bound0 % (p_.lf0*p_.vwidth)) > 0)
|
2015-12-17 00:51:04 -05:00
|
|
|
return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2016-07-02 12:06:05 -07:00
|
|
|
if (p_.Bfetch==FETCH_FROM_LOCAL)
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
|
|
|
unsigned int bound1 = (B_trans_=='T')?p_.kL:p_.nL;
|
|
|
|
unsigned int bound0 = (B_trans_=='T')?p_.nL:p_.kL;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-07-02 12:06:05 -07:00
|
|
|
if (p_.lf1>0 && (bound1 % p_.lf1)> 0)
|
2015-12-17 00:51:04 -05:00
|
|
|
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-07-02 12:06:05 -07:00
|
|
|
if (p_.lf0>0 && (bound0 % (p_.lf0*p_.vwidth)) > 0)
|
2015-12-17 00:51:04 -05:00
|
|
|
return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
return TEMPLATE_VALID;
|
|
|
|
}
|
2015-11-29 16:13:14 -05:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
std::string gemm::generate_impl(std::string const & suffix, expression_tree const & tree, driver::Device const & device, symbolic::symbols_table const &) const
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
|
|
|
using std::string;
|
|
|
|
using tools::to_string;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
driver::backend_type backend = device.backend();
|
|
|
|
bool has_depth = p_.depth > 1;
|
2016-07-02 12:06:05 -07:00
|
|
|
#define VLOAD(offset, ptr) vload(p_.vwidth, sdtype, offset, ptr, "1", backend, true)
|
|
|
|
#define VLOAD_MISALIGNED(offset, ptr) vload(p_.vwidth, sdtype, offset, ptr, "1", backend, false)
|
|
|
|
#define VSTORE(value, offset, ptr) vstore(p_.vwidth, sdtype, value, offset, ptr, "1", backend)
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
symbolic::preset::gemm::args args;
|
2016-04-10 16:31:29 -04:00
|
|
|
infos(tree, args);
|
|
|
|
std::string ASTRIDE1 = (args.A->ld[0] > 1)?"*Astride1":"";
|
|
|
|
std::string BSTRIDE1 = (args.B->ld[0] > 1)?"*Bstride1":"";
|
|
|
|
std::string CSTRIDE1 = (args.C->ld[0] > 1)?"*Cstride1":"";
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
//////////////////
|
|
|
|
/// INIT
|
|
|
|
/// //////////////
|
2016-04-02 18:19:33 -04:00
|
|
|
kernel_generation_stream stream(backend);
|
|
|
|
numeric_type dtype = tree.dtype();
|
2015-12-17 00:51:04 -05:00
|
|
|
std::string sdtype = to_string(dtype);
|
2016-07-02 12:06:05 -07:00
|
|
|
std::string vdtype = append_width(sdtype, p_.vwidth);
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
//////////////////
|
|
|
|
/// DECLARATIONS
|
|
|
|
/// //////////////
|
2016-09-30 23:04:50 -04:00
|
|
|
std::string gemm_name = "gemm";
|
2015-12-17 00:51:04 -05:00
|
|
|
std::string reduce_name = "reduce";
|
2015-07-27 11:37:19 -07:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
gemm_name += suffix;
|
2015-12-17 00:51:04 -05:00
|
|
|
reduce_name += suffix;
|
2015-04-29 15:50:57 -04:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
switch(backend)
|
|
|
|
{
|
|
|
|
case driver::OPENCL:
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << " __attribute__((reqd_work_group_size(" << p_.ls0 << "," << p_.ls1 << ",1)))" << std::endl;
|
2016-04-09 23:51:30 -04:00
|
|
|
break;
|
|
|
|
default:
|
|
|
|
break;
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2015-05-13 02:20:44 -04:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
stream << "$KERNEL void gemm" << suffix << "($SIZE_T M, $SIZE_T N, $SIZE_T K, "
|
2016-04-02 18:19:33 -04:00
|
|
|
<< "$GLOBAL " << sdtype << "* C, $SIZE_T ldc, $SIZE_T offc, $SIZE_T Cstride1, "
|
2015-12-17 00:51:04 -05:00
|
|
|
<< sdtype << " alpha,"
|
2016-04-02 18:19:33 -04:00
|
|
|
<< "$GLOBAL " << sdtype << "* A, $SIZE_T lda, $SIZE_T offa, $SIZE_T Astride1,"
|
|
|
|
<< "$GLOBAL " << sdtype << "* B, $SIZE_T ldb, $SIZE_T offb, $SIZE_T Bstride1,"
|
2015-12-17 00:51:04 -05:00
|
|
|
<< sdtype << " beta)"
|
|
|
|
<< std::endl;
|
|
|
|
stream << "{" << std::endl;
|
|
|
|
stream.inc_tab();
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
///Declare
|
|
|
|
stream << "//blocks" << std::endl;
|
|
|
|
stream << sdtype << " rC[" << p_.mS << "][" << p_.nS << "] = {{0}};" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << vdtype << " rA[" << p_.kS << "][" << p_.mS/p_.vwidth << "];" << std::endl;
|
|
|
|
stream << vdtype << " rB[" << p_.kS << "][" << p_.nS/p_.vwidth << "];" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << std::endl;
|
|
|
|
|
|
|
|
stream << "//pointers" << std::endl;
|
|
|
|
size_t llda = (A_trans_=='N')?p_.mL:p_.kL;
|
|
|
|
size_t lldb = (B_trans_=='T')?p_.nL:p_.kL;
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "$LOCAL " << sdtype << " lA[" << p_.kL*p_.mL << "];" << std::endl;
|
|
|
|
stream << "$LOCAL " << sdtype << " lB[" << p_.kL*p_.nL << "];" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
unsigned int npA = p_.mL/(A_trans_=='N'?p_.lf0*p_.vwidth:p_.lf1);
|
|
|
|
unsigned int npB = p_.nL/(B_trans_=='T'?p_.lf0*p_.vwidth:p_.lf1);
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "$GLOBAL " << sdtype << "* Ai[" << npA << "];" << std::endl;
|
|
|
|
stream << "$GLOBAL " << sdtype << "* Bi[" << npB << "];" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << std::endl;
|
|
|
|
|
|
|
|
stream << "//identifiers" << std::endl;
|
|
|
|
stream << "int2 idT;" << std::endl;
|
|
|
|
stream << "int idt;" << std::endl;
|
2015-07-18 16:06:17 -04:00
|
|
|
if(has_depth)
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "int gidz, div, offz;" << std::endl;
|
|
|
|
stream << "uint4 ids;" << std::endl;
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "ids.x = $GROUP_IDX_0;" << std::endl;
|
|
|
|
stream << "ids.y = $GROUP_IDX_1;" << std::endl;
|
|
|
|
stream << "ids.z = $LOCAL_IDX_0;" << std::endl;
|
|
|
|
stream << "ids.w = $LOCAL_IDX_1;" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << std::endl;
|
|
|
|
|
|
|
|
stream << "//offsets" << std::endl;
|
|
|
|
stream << "A += offa;" << std::endl;
|
|
|
|
stream << "B += offb;" << std::endl;
|
|
|
|
stream << "C += offc;" << std::endl;
|
2015-07-14 20:40:29 -07:00
|
|
|
|
2015-07-27 11:37:19 -07:00
|
|
|
if(has_depth)
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "gidz = $GROUP_IDX_2;" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "div = (K+" << p_.depth-1 << ")/" << p_.depth << ";" << std::endl;
|
|
|
|
stream << "offz = div*gidz;" << std::endl;
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "K = min(K - div*gidz, ($SIZE_T)div);" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2015-07-27 11:37:19 -07:00
|
|
|
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "idt = " << p_.ls0 << "*ids.w + ids.z;" << std::endl;
|
|
|
|
stream << "idT.y = idt/" << p_.lf0 << ";" << std::endl;
|
|
|
|
stream << "idT.x = idt - " << p_.lf0 << "*idT.y;" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << std::endl;
|
2015-07-27 11:37:19 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "//Adjust pointers and bounds per work-item" << std::endl;
|
|
|
|
stream << "ids.x *= " << p_.mL << ";" << std::endl;
|
|
|
|
stream << "ids.y *= " << p_.nL << ";" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "idT.x *= " << p_.vwidth << ";" << std::endl;
|
2015-07-27 11:37:19 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "M -= ids.x;" << std::endl;
|
|
|
|
if(A_trans_=='N')
|
|
|
|
stream << "M -= idT.x;" << std::endl;
|
2015-07-27 11:37:19 -07:00
|
|
|
else
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "M -= idT.y;" << std::endl;
|
2015-07-27 11:37:19 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "N -= ids.y;" << std::endl;
|
|
|
|
if(B_trans_=='T')
|
|
|
|
stream << "N -= idT.x;" << std::endl;
|
|
|
|
else
|
|
|
|
stream << "N -= idT.y;" << std::endl;
|
2015-07-20 22:51:39 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
if (A_trans_=='N')
|
2015-07-20 22:51:39 -07:00
|
|
|
{
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "A += ids.x" << ASTRIDE1 << ";" << std::endl;
|
|
|
|
stream << "A += idT.y*lda;" << std::endl;
|
|
|
|
if(has_depth)
|
|
|
|
stream << "A += offz*lda;" << std::endl;
|
2015-07-21 14:44:10 -04:00
|
|
|
|
2015-11-29 16:13:14 -05:00
|
|
|
}
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
2015-07-20 22:51:39 -07:00
|
|
|
{
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "A += ids.x*lda;" << std::endl;
|
|
|
|
stream << "A += idT.x" << ASTRIDE1 << ";" << std::endl;
|
|
|
|
if(has_depth)
|
|
|
|
stream << "A += offz;" << std::endl;
|
|
|
|
}
|
|
|
|
|
|
|
|
if(B_trans_=='T')
|
|
|
|
{
|
|
|
|
stream << "B += ids.y" << BSTRIDE1 << ";" << std::endl;
|
|
|
|
stream << "B += idT.y*ldb;" << std::endl;
|
|
|
|
if(has_depth)
|
|
|
|
stream << "B += offz*ldb;" << std::endl;
|
2015-07-20 22:51:39 -07:00
|
|
|
}
|
2015-07-21 14:44:10 -04:00
|
|
|
else
|
2015-07-20 22:51:39 -07:00
|
|
|
{
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "B += ids.y*ldb;" << std::endl;
|
|
|
|
stream << "B += idT.x" << BSTRIDE1 << ";" << std::endl;
|
|
|
|
if(has_depth)
|
|
|
|
stream << "B += offz;" << std::endl;
|
2015-11-29 16:13:14 -05:00
|
|
|
}
|
2015-07-21 14:35:22 -04:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "#pragma unroll" << std::endl;
|
|
|
|
stream << "for(int i = 0 ; i < " << npA << " ; ++i){" << std::endl;
|
|
|
|
stream.inc_tab();
|
|
|
|
stream << "Ai[i] = A;" << std::endl;
|
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
|
|
|
stream << std::endl;
|
|
|
|
|
|
|
|
stream << "#pragma unroll" << std::endl;
|
|
|
|
stream << "for(int i = 0 ; i < " << npB << " ; ++i){" << std::endl;
|
|
|
|
stream.inc_tab();
|
|
|
|
stream << "Bi[i] = B;" << std::endl;
|
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
|
|
|
stream << std::endl;
|
|
|
|
|
|
|
|
for(unsigned int i = 0 ; i < npA ; i++ )
|
|
|
|
if (A_trans_=='N')
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < M", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + ASTRIDE1 + ")", "0") << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "Ai[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < M", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*lda)", "0") << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
|
|
|
|
for(unsigned int i = 0 ; i < npB ; i++ )
|
|
|
|
if (B_trans_=='T')
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf0*p_.vwidth) + " < N", "(int)((idT.x + " + to_string(i*p_.lf0*p_.vwidth) + ")" + BSTRIDE1 + ")", "0") << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "Bi[" << i << "] += " << Select(backend, to_string(i*p_.lf1) + " < N", "(int)((idT.y + " + to_string(i*p_.lf1) + ")*ldb)", "0") << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
|
|
|
|
stream << std::endl;
|
|
|
|
stream << "//Outer loop" << std::endl;
|
|
|
|
stream << "while(K >=" << p_.kL << ")" << std::endl;
|
|
|
|
stream << "{" << std::endl;
|
|
|
|
stream.inc_tab();
|
|
|
|
|
|
|
|
|
|
|
|
auto fetch_to_lds = [&](bool last_iteration)
|
2015-07-21 14:35:22 -04:00
|
|
|
{
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "$LOCAL_BARRIER;" << std::endl;
|
|
|
|
stream << "$LOCAL_PTR " << sdtype << "* ldsA = lA + idT.y*" << llda << " + idT.x;" << std::endl;
|
|
|
|
stream << "$LOCAL_PTR " << sdtype << "* ldsB = lB + idT.y*" << lldb << " + idT.x;" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
|
|
|
|
stream << "//Fetch A to local memory" << std::endl;
|
|
|
|
if (A_trans_=='N')
|
2015-07-21 14:35:22 -04:00
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
|
|
|
|
for(unsigned int m = 0; m < p_.mL; m += p_.lf0*p_.vwidth)
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
std::string mm = to_string(m/(p_.vwidth*p_.lf0));
|
2015-12-17 00:51:04 -05:00
|
|
|
std::string kk = to_string(k);
|
|
|
|
if(last_iteration)
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "ldsA[" << k*llda + m + s << "] = (condy" << k << " && " << s << "< M)? Ai[" << mm << "][" << k << "*lda + " << s << "] : 0;" << std::endl;
|
|
|
|
else
|
|
|
|
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Ai[" + mm +"][" + kk + "*lda]"), "0", "ldsA + " + to_string(k*llda+m)) << ";" << std::endl;
|
|
|
|
}
|
2015-07-22 17:46:50 -07:00
|
|
|
}
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
2015-07-22 17:46:50 -07:00
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
|
|
|
for(unsigned int m = 0; m < p_.mL; m += p_.lf1)
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
std::string mm = to_string(m/p_.lf1);
|
2015-12-17 00:51:04 -05:00
|
|
|
std::string kk = to_string(k);
|
|
|
|
if(last_iteration)
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "ldsA[" << m*llda + k + s << "] = condx" << k + s << "? Ai[" << mm << "][" << k + s << ASTRIDE1 << "] : 0;" << std::endl;
|
|
|
|
|
|
|
|
else
|
|
|
|
stream << VSTORE(VLOAD_MISALIGNED("0", "&Ai[" + mm + "][" + kk + ASTRIDE1 + "]"), "0", "ldsA + " + to_string(m*llda+k)) << ";" << std::endl;
|
|
|
|
}
|
2015-07-09 15:03:55 -04:00
|
|
|
}
|
2015-07-08 21:09:21 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "//Fetch B to local memory" << std::endl;
|
|
|
|
if (B_trans_=='T')
|
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
|
|
|
|
for(unsigned int n = 0; n < p_.nL; n += p_.lf0*p_.vwidth)
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
std::string nn = to_string(n/(p_.vwidth*p_.lf0));
|
2015-12-17 00:51:04 -05:00
|
|
|
std::string kk = to_string(k);
|
|
|
|
if(last_iteration)
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "ldsB[" << k*lldb + n + s << "] = (condy" << k << " && " << s << "< N)? Bi[" << nn << "][" << kk << "*ldb +" << s << "] : 0;" << std::endl;
|
|
|
|
else
|
|
|
|
stream << VSTORE(VLOAD_MISALIGNED("0" ,"&Bi[" + nn +"][" + kk + "*ldb]"), "0", "ldsB + " + to_string(k*lldb+n)) << ";" << std::endl;
|
|
|
|
}
|
|
|
|
}
|
2015-07-22 17:46:50 -07:00
|
|
|
else
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int k = 0; k < p_.kL; k += p_.lf0*p_.vwidth)
|
|
|
|
for(unsigned int n = 0; n < p_.nL; n += p_.lf1)
|
2015-12-17 00:51:04 -05:00
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
std::string nn = to_string(n/p_.lf1);
|
2015-12-17 00:51:04 -05:00
|
|
|
std::string kk = to_string(k);
|
|
|
|
if(last_iteration)
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "ldsB[" << n*lldb + k + s << "] = condx" << k + s << "? Bi[" << nn << "][" << k + s << BSTRIDE1 << "] : 0;" << std::endl;
|
|
|
|
|
|
|
|
else
|
|
|
|
stream << VSTORE(VLOAD_MISALIGNED("0", "&Bi[" + nn + "][" + kk + BSTRIDE1 + "]"), "0", "ldsB + " + to_string(n*lldb+k)) << ";" << std::endl;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if(A_trans_=='N')
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "ldsA = lA + ids.z*" << p_.vwidth << ";" << std::endl;
|
2015-07-22 17:46:50 -07:00
|
|
|
else
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "ldsA = lA + ids.z*" << llda*p_.vwidth << ";" << std::endl;
|
2015-07-22 17:46:50 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
if(B_trans_=='T')
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "ldsB = lB + ids.w*" << p_.vwidth << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "ldsB = lB + ids.w*" << lldb*p_.vwidth << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "$LOCAL_BARRIER;" << std::endl;
|
2016-09-27 23:44:22 -04:00
|
|
|
std::string bound = last_iteration?"K":tools::to_string(p_.kL);
|
|
|
|
size_t ks = last_iteration?1:p_.kS;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "//Inner loop" << std::endl;
|
2016-09-27 23:44:22 -04:00
|
|
|
stream << "for(unsigned int k = 0; k < " << bound << "; k+=" << ks << "){" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream.inc_tab();
|
|
|
|
|
|
|
|
stream << "//Fetch A to registers" << std::endl;
|
|
|
|
stream << "#pragma unroll" << std::endl;
|
2016-09-27 23:44:22 -04:00
|
|
|
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "#pragma unroll " << p_.mS/p_.vwidth << std::endl;
|
|
|
|
stream << "for(unsigned int mm = 0; mm < " << p_.mS/p_.vwidth << "; mm++)" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "{" << std::endl;
|
|
|
|
stream.inc_tab();
|
|
|
|
if(A_trans_=='N')
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "rA[kk][mm] = " << VLOAD("0", "ldsA + k*" + to_string(llda) + " + mm*" + to_string(p_.ls0*p_.vwidth) + "+ kk*" + to_string(llda)) << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
if(p_.vwidth==1)
|
|
|
|
stream << "rA[kk][mm] = ldsA[k + mm*" << p_.ls0*llda << "+ kk" << "];" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
|
|
|
stream << access_vector_type("rA[kk][mm]", s) << " = ldsA[k + (mm*" << p_.vwidth*p_.ls0 << " + " << s << ")*" << llda << "+ kk];" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2015-07-09 15:03:55 -04:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
|
|
|
|
|
|
|
stream << "//Fetch B to registers" << std::endl;
|
2016-09-27 23:44:22 -04:00
|
|
|
stream << "#pragma unroll " << ks << std::endl;
|
|
|
|
stream << "for(unsigned int kk = 0; kk < " << ks << "; kk++)" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "#pragma unroll " << p_.nS/p_.vwidth << std::endl;
|
|
|
|
stream << "for(unsigned int nn = 0; nn < " << p_.nS/p_.vwidth << "; nn++)" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "{" << std::endl;
|
|
|
|
stream.inc_tab();
|
|
|
|
if(B_trans_=='T')
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "rB[kk][nn] = " << VLOAD("0", "ldsB + k*" + to_string(lldb) + " + nn*" + to_string(p_.ls1*p_.vwidth) + "+ kk*" + to_string(lldb)) << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
if(p_.vwidth==1)
|
|
|
|
stream << "rB[kk][nn] = ldsB[k" << " + nn*" << p_.ls1*lldb << "+ kk" << "];" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
|
|
|
stream << access_vector_type("rB[kk][nn]", s) << " = ldsB[k" << " + (nn*" << p_.vwidth*p_.ls1 << " + " << s << ")*" << lldb << "+ kk];" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
2015-07-10 21:15:36 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "//FMA computations" << std::endl;
|
2016-09-27 23:44:22 -04:00
|
|
|
stream << "#pragma unroll" << std::endl;
|
|
|
|
stream << "for(unsigned int kk = 0 ; kk < " << ks << "; ++kk){" << std::endl;
|
|
|
|
stream.inc_tab();
|
2015-12-17 00:51:04 -05:00
|
|
|
for(unsigned int nn=0; nn < p_.nS; ++nn)
|
2016-04-02 18:19:33 -04:00
|
|
|
for(unsigned int mm=0; mm < p_.mS; ++mm){
|
2015-12-17 00:51:04 -05:00
|
|
|
string res_str, lhs_str, rhs_str;
|
|
|
|
res_str = "rC[" + to_string(mm) + "][" + to_string(nn) + "]";
|
2016-07-02 12:06:05 -07:00
|
|
|
if (p_.vwidth==1)
|
2016-09-27 23:44:22 -04:00
|
|
|
lhs_str = "rA[kk][" + to_string(mm) + "]";
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
2016-09-27 23:44:22 -04:00
|
|
|
lhs_str = access_vector_type("rA[kk][" + to_string(mm/p_.vwidth) + "]", mm%p_.vwidth);
|
2016-07-02 12:06:05 -07:00
|
|
|
if (p_.vwidth==1)
|
2016-09-27 23:44:22 -04:00
|
|
|
rhs_str = "rB[kk]["+to_string(nn)+"]";
|
2015-12-17 00:51:04 -05:00
|
|
|
else
|
2016-09-27 23:44:22 -04:00
|
|
|
rhs_str = access_vector_type("rB[kk]["+to_string(nn/p_.vwidth)+"]", nn%p_.vwidth);
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << res_str << "= $MAD(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2016-09-27 23:44:22 -04:00
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
|
|
|
stream << "K -= " << p_.kL << ";" << std::endl;
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
//Increment A pointers to global memory
|
|
|
|
if (A_trans_=='N')
|
|
|
|
for(unsigned int i = 0 ; i < npA ; ++i)
|
|
|
|
stream << "Ai[" << i << "] += " << p_.kL << "*lda;" << std::endl;
|
|
|
|
else
|
|
|
|
for(unsigned int i = 0 ; i < npA ; ++i)
|
|
|
|
stream << "Ai[" << i << "] += " << p_.kL << ASTRIDE1 << ";" << std::endl;
|
2015-07-22 17:46:50 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
//Increment B pointers to global memory
|
|
|
|
if (B_trans_=='T')
|
|
|
|
for(unsigned int i = 0 ; i < npB ; ++i)
|
|
|
|
stream << "Bi[" << i << "] += " << p_.kL << "*ldb;" << std::endl;
|
|
|
|
else
|
|
|
|
for(unsigned int i = 0 ; i < npB ; ++i)
|
|
|
|
stream << "Bi[" << i << "] += " << p_.kL << BSTRIDE1 << ";" << std::endl;
|
|
|
|
};
|
|
|
|
fetch_to_lds(false);
|
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
2015-07-21 14:35:30 -04:00
|
|
|
|
2015-08-24 23:09:23 -04:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
if(A_trans_=='N' || B_trans_=='T')
|
2015-08-24 23:09:23 -04:00
|
|
|
{
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "int Ky = K - idT.y;" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int k = 0; k < p_.kL; k += p_.lf1)
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "int condy" << k << " = " << k << " < Ky;" << std::endl;
|
2015-11-29 16:13:14 -05:00
|
|
|
}
|
2015-12-17 00:51:04 -05:00
|
|
|
|
|
|
|
if(A_trans_=='T' || B_trans_=='N')
|
|
|
|
{
|
|
|
|
stream << "int Kx = K - idT.x;" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
for(unsigned int k = 0 ; k < p_.kL ; k += p_.lf0*p_.vwidth)
|
|
|
|
for(unsigned int s = 0 ; s < p_.vwidth ; ++s)
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "int condx" << k + s << " = " << k + s << " < Kx;" << std::endl;
|
2015-08-24 23:09:23 -04:00
|
|
|
}
|
2015-12-17 00:51:04 -05:00
|
|
|
fetch_to_lds(true);
|
2015-07-23 11:20:50 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "//Write back C" << std::endl;
|
|
|
|
stream << "M += ids.x;" << std::endl;
|
|
|
|
if(A_trans_=='N')
|
|
|
|
stream << "M += idT.x;" << std::endl;
|
|
|
|
else
|
|
|
|
stream << "M += idT.y;" << std::endl;
|
2015-07-23 11:20:50 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
if(B_trans_=='T')
|
|
|
|
stream << "N += idT.x;" << std::endl;
|
|
|
|
else
|
|
|
|
stream << "N += idT.y;" << std::endl;
|
|
|
|
stream << "N += ids.y;" << std::endl;
|
2015-04-29 15:50:57 -04:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "C += ids.x" << CSTRIDE1 << ";" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "C += ids.z*" << p_.vwidth << CSTRIDE1 << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "C += ids.y*ldc;" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "C += ids.w*" << p_.vwidth << "*ldc;" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
if(has_depth)
|
|
|
|
stream << "C += gidz*ldc*N;" << std::endl;
|
2015-07-15 23:28:43 -07:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "M -= ids.x;" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "M -= ids.z*" << p_.vwidth << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
|
|
|
|
stream << "N -= ids.y;" << std::endl;
|
2016-07-02 12:06:05 -07:00
|
|
|
stream << "N -= ids.w*" << p_.vwidth << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
|
|
|
|
for(unsigned int n=0; n < p_.nS; ++n)
|
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
string Cj = to_string((n/p_.vwidth)*(p_.ls1*p_.vwidth) + n%p_.vwidth);
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "if(" << Cj << " >= N) return;" << std::endl;
|
|
|
|
for(unsigned int m=0; m < p_.mS; ++m)
|
|
|
|
stream << "rC[" << m << "][" << n << "] *= alpha;" << std::endl;
|
|
|
|
for(unsigned int m=0; m < p_.mS; ++m)
|
|
|
|
{
|
2016-07-02 12:06:05 -07:00
|
|
|
string Ci = to_string((m/p_.vwidth)*(p_.ls0*p_.vwidth) + m%p_.vwidth);
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "if(" << Ci << "< M) ";
|
|
|
|
if(has_depth)
|
|
|
|
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "];" << std::endl;
|
|
|
|
else
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "C[" << Ci << CSTRIDE1 << "] = rC[" << m << "][" << n << "] + ((beta != (" << sdtype << ")0)?(beta*" << "C[" << Ci << CSTRIDE1 << "]):0);" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2016-07-02 12:06:05 -07:00
|
|
|
if((n+1)%p_.vwidth==0){
|
|
|
|
stream << "C += ldc*" << p_.ls1*p_.vwidth - p_.vwidth + 1 << ";" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
|
|
|
else{
|
|
|
|
stream << "C += ldc;" << std::endl;
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-11-29 16:13:14 -05:00
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
|
|
|
|
if(has_depth)
|
|
|
|
{
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "$KERNEL void reduce" << suffix << "($SIZE_T M, $SIZE_T N, $SIZE_T D, "
|
|
|
|
<< "$GLOBAL " << sdtype << "* Z, $SIZE_T Zld,"
|
|
|
|
<< "$GLOBAL " << sdtype << "* C, $SIZE_T ldc, $SIZE_T Cstart, $SIZE_T Cstride,"
|
2015-12-17 00:51:04 -05:00
|
|
|
<< sdtype << " beta)"
|
|
|
|
<< std::endl;
|
|
|
|
stream << "{" << std::endl;
|
|
|
|
stream.inc_tab();
|
|
|
|
|
|
|
|
stream << "C += Cstart;" << std::endl;
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "for(unsigned int i = $GLOBAL_IDX_0 ; i < M ; i += $GLOBAL_SIZE_0)" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "{" << std::endl;
|
|
|
|
stream.inc_tab();
|
2016-04-02 18:19:33 -04:00
|
|
|
stream << "for(unsigned int j = $GLOBAL_IDX_1 ; j < N ; j += $GLOBAL_SIZE_1)" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream << "{" << std::endl;
|
|
|
|
stream.inc_tab();
|
|
|
|
stream << sdtype << " acc = 0;" << std::endl;
|
|
|
|
stream << "for(unsigned int k = 0 ; k < D ; k++)" << std::endl;
|
|
|
|
stream.inc_tab();
|
|
|
|
stream << "acc += Z[i + j*Zld + k*Zld*N];" << std::endl;
|
|
|
|
stream.dec_tab();
|
2016-07-05 21:27:18 -07:00
|
|
|
stream << "C[i*Cstride + j*ldc] = acc + ((beta != (" << sdtype << ")0)?(beta*C[i*Cstride + j*ldc]):0);" << std::endl;
|
2015-12-17 00:51:04 -05:00
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
|
|
|
|
|
|
|
stream.dec_tab();
|
|
|
|
stream << "}" << std::endl;
|
|
|
|
}
|
|
|
|
|
|
|
|
return stream.str();
|
|
|
|
|
|
|
|
#undef VLOAD
|
2015-11-29 16:13:14 -05:00
|
|
|
#undef VST0RE
|
2015-12-17 00:51:04 -05:00
|
|
|
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
/**
 * Binds the arguments of the "gemm<suffix>" kernel and enqueues it; when
 * p_.depth > 1 a second "reduce<suffix>" kernel is enqueued afterwards.
 *
 * Returns immediately, without touching the queue, if any of M, N or K is zero.
 *
 * When p_.depth is 1 the output slot of the gemm kernel is bound to C itself.
 * Otherwise it is bound to the workspace buffer (with M, 0 and 1 in the three
 * size slots that follow) and the reduce kernel receives both the workspace and C.
 */
void gemm::enqueue_block(driver::CommandQueue & queue, int_t M, int_t N, int_t K,
                         expression_tree::node const & A, expression_tree::node const & B, expression_tree::node const & C,
                         value_scalar const & alpha, value_scalar const & beta,
                         driver::Program const & program, std::string const & suffix, runtime::execution_options_type const & options)
{
  using tools::align;

  // Empty problem: nothing to launch
  if(M==0 || N==0 || K==0)
    return;

  driver::backend_type const backend = queue.context().backend();

  // Appends four consecutive arguments for one matrix operand:
  // backend-specific buffer handle, ld[1], array start, ld[0].
  auto bind_operand = [backend](driver::Kernel & kernel, unsigned int & slot, expression_tree::node const & x)
  {
    if(backend==driver::OPENCL)
      kernel.setArg(slot++, x.array.handle.cl);
    else
      kernel.setArg(slot++, x.array.handle.cu);
    kernel.setSizeArg(slot++, x.ld[1]);
    kernel.setSizeArg(slot++, x.array.start);
    kernel.setSizeArg(slot++, x.ld[0]);
  };

  // ---- Main kernel ----
  std::string const gemm_name = "gemm" + suffix;
  driver::Kernel gemm_kernel(program, gemm_name.c_str());
  driver::NDRange local(p_.ls0, p_.ls1, 1);
  driver::NDRange global(align(align(M,p_.mS)/p_.mS, p_.ls0), align(align(N,p_.nS)/p_.nS, p_.ls1), p_.depth);

  driver::Buffer& workspace = driver::backend::workspaces::get(options.queue(queue.context()));

  unsigned int gemm_arg = 0;
  gemm_kernel.setSizeArg(gemm_arg++, M);
  gemm_kernel.setSizeArg(gemm_arg++, N);
  gemm_kernel.setSizeArg(gemm_arg++, K);

  // Output slot: the workspace unless the depth is exactly 1
  if(p_.depth!=1)
  {
    gemm_kernel.setArg(gemm_arg++, workspace);
    gemm_kernel.setSizeArg(gemm_arg++, M);
    gemm_kernel.setSizeArg(gemm_arg++, 0);
    gemm_kernel.setSizeArg(gemm_arg++, 1);
  }
  else
  {
    bind_operand(gemm_kernel, gemm_arg, C);
  }

  gemm_kernel.setArg(gemm_arg++, alpha);
  bind_operand(gemm_kernel, gemm_arg, A);
  bind_operand(gemm_kernel, gemm_arg, B);
  gemm_kernel.setArg(gemm_arg++, beta);
  options.enqueue(program.context(), gemm_kernel, global, local);

  // ---- Reduction kernel (only for depth > 1) ----
  // Argument order follows the generated signature:
  // (M, N, D, Z, Zld, C, ldc, Cstart, Cstride, beta)
  if(p_.depth > 1)
  {
    std::string const reduce_name = "reduce" + suffix;
    unsigned int reduce_arg = 0;
    driver::Kernel reduce_kernel(program, reduce_name.c_str());
    driver::NDRange reduce_local(p_.ls0, p_.ls1);
    driver::NDRange reduce_global(align(M, p_.ls0), align(N, p_.ls1));

    reduce_kernel.setSizeArg(reduce_arg++, M);
    reduce_kernel.setSizeArg(reduce_arg++, N);
    reduce_kernel.setSizeArg(reduce_arg++, p_.depth);
    reduce_kernel.setArg(reduce_arg++, workspace);
    reduce_kernel.setSizeArg(reduce_arg++, M);
    bind_operand(reduce_kernel, reduce_arg, C);
    reduce_kernel.setArg(reduce_arg++, beta);
    options.enqueue(program.context(), reduce_kernel, reduce_global, reduce_local);
  }
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
/**
 * Extracts the GEMM operands of an expression tree and returns the problem sizes.
 *
 * @param tree       expression tree to inspect
 * @param arguments  overwritten with the result of symbolic::preset::gemm::check()
 * @return {M, N, K}: M and N are the two extents of C; K is extent 0 of A when
 *         A_trans_ is 'T' and extent 1 of A otherwise
 */
std::vector<int_t> gemm::infos(expression_tree const & tree, symbolic::preset::gemm::args& arguments) const
{
  std::size_t const root_index = tree.root();
  arguments = symbolic::preset::gemm::check(tree.data(), root_index);

  bool const A_transposed = (A_trans_=='T');
  int_t const rows  = arguments.C->shape[0];
  int_t const cols  = arguments.C->shape[1];
  int_t const inner = A_transposed ? arguments.A->shape[0] : arguments.A->shape[1];

  return std::vector<int_t>{rows, cols, inner};
}
|
2015-11-29 16:13:14 -05:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
/**
 * Builds a GEMM template for the given parameters and transposition flags.
 *
 * @param parameters  tuning parameters forwarded to base_impl
 * @param A_trans     'N' or 'T'
 * @param B_trans     'N' or 'T'
 * @throws std::invalid_argument if either flag is neither 'N' nor 'T'
 */
gemm::gemm(gemm_parameters const & parameters, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters), A_trans_(A_trans), B_trans_(B_trans)
{
  if(A_trans_=='N' && B_trans_=='N') type_ = GEMM_NN;
  else if(A_trans_=='T' && B_trans_=='N') type_ = GEMM_TN;
  else if(A_trans_=='N' && B_trans_=='T') type_ = GEMM_NT;
  else if(A_trans_=='T' && B_trans_=='T') type_ = GEMM_TT;
  else
  {
    // A bare `throw;` here would call std::terminate (there is no exception
    // being handled), so raise a real, catchable error instead.
    throw std::invalid_argument(std::string("gemm: invalid transposition flags '")
                                + A_trans_ + "', '" + B_trans_ + "' (expected 'N' or 'T')");
  }
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
/**
 * Returns the sizes computed by infos() for the given expression tree.
 *
 * @param expressions  expression tree to inspect
 * @return the vector returned by infos(), i.e. {M, N, K}
 */
std::vector<int_t> gemm::input_sizes(expression_tree const & expressions) const
{
  // infos() requires an output slot for the parsed operands; they are discarded here.
  symbolic::preset::gemm::args unused_args;
  // infos() takes the tree by const reference, so no cast is needed.
  return infos(expressions, unused_args);
}
|
2015-02-05 04:42:57 -05:00
|
|
|
|
2016-09-30 23:04:50 -04:00
|
|
|
/**
 * Entry point for launching the GEMM described by control.x().
 *
 * Obtains the operands and {M, N, K} through infos(), returns early when any
 * of the three sizes is zero, and otherwise forwards everything to enqueue_block().
 */
void gemm::enqueue(driver::CommandQueue & queue, driver::Program const & program, std::string const & suffix, runtime::execution_handler const & control)
{
  expression_tree const & tree = control.x();

  symbolic::preset::gemm::args operands;
  std::vector<int_t> const sizes = infos(tree, operands);
  int_t const M = sizes[0];
  int_t const N = sizes[1];
  int_t const K = sizes[2];

  // Only launch when every dimension is non-zero
  bool const has_work = (M!=0 && N!=0 && K!=0);
  if(!has_work)
    return;

  enqueue_block(queue, M, N, K,
                *operands.A, *operands.B, *operands.C,
                operands.alpha, operands.beta,
                program, suffix, control.execution_options());
}
|
2015-12-17 00:51:04 -05:00
|
|
|
|
|
|
|
//
|
2016-09-30 23:04:50 -04:00
|
|
|
/**
 * GEMM template with transposition flags ('N', 'N').
 * Packs all arguments into a gemm_parameters object and forwards it to gemm.
 */
gemm_nn::gemm_nn(unsigned int simd,
                 int_t ls0, int_t KL, int_t ls1, int_t D,
                 int_t ms, int_t ks, int_t ns,
                 fetch_type Afetch, fetch_type Bfetch,
                 int_t lf0, int_t lf1)
  : gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'N', 'N')
{ }
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-12-17 00:51:04 -05:00
|
|
|
//
|
2016-09-30 23:04:50 -04:00
|
|
|
/**
 * GEMM template with transposition flags ('T', 'N').
 * Packs all arguments into a gemm_parameters object and forwards it to gemm.
 */
gemm_tn::gemm_tn(unsigned int simd,
                 int_t ls0, int_t KL, int_t ls1, int_t D,
                 int_t ms, int_t ks, int_t ns,
                 fetch_type Afetch, fetch_type Bfetch,
                 int_t lf0, int_t lf1)
  : gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'T', 'N')
{
}
|
|
|
|
|
|
|
|
//
|
2016-09-30 23:04:50 -04:00
|
|
|
/**
 * GEMM template with transposition flags ('N', 'T').
 * Packs all arguments into a gemm_parameters object and forwards it to gemm.
 */
gemm_nt::gemm_nt(unsigned int simd,
                 int_t ls0, int_t KL, int_t ls1, int_t D,
                 int_t ms, int_t ks, int_t ns,
                 fetch_type Afetch, fetch_type Bfetch,
                 int_t lf0, int_t lf1)
  : gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'N', 'T')
{
}
|
|
|
|
|
|
|
|
//
|
2016-09-30 23:04:50 -04:00
|
|
|
/**
 * GEMM template with transposition flags ('T', 'T').
 * Packs all arguments into a gemm_parameters object and forwards it to gemm.
 */
gemm_tt::gemm_tt(unsigned int simd,
                 int_t ls0, int_t KL, int_t ls1, int_t D,
                 int_t ms, int_t ks, int_t ns,
                 fetch_type Afetch, fetch_type Bfetch,
                 int_t lf0, int_t lf1)
  : gemm(gemm_parameters(simd, ls0, KL, ls1, D, ms, ks, ns, Afetch, Bfetch, lf0, lf1), 'T', 'T')
{
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|
2015-07-11 09:36:01 -04:00
|
|
|
}
|
|
|
|
}
|
2015-01-12 13:20:53 -05:00
|
|
|
|