2015-04-29 15:50:57 -04:00
# include "isaac/array.h"
2015-07-11 09:36:01 -04:00
# include "isaac/backend/templates/gemm.h"
2015-04-29 15:50:57 -04:00
# include "isaac/backend/keywords.h"
2015-06-27 17:55:01 -07:00
# include "isaac/model/model.h"
# include "isaac/symbolic/preset.h"
2015-06-28 17:53:16 -07:00
# include "isaac/exception/operation_not_supported.h"
2015-04-29 15:50:57 -04:00
# include "isaac/tools/make_vector.hpp"
# include "isaac/tools/to_string.hpp"
# include "isaac/tools/miscellaneous.hpp"
2015-06-27 17:55:01 -07:00
2015-04-29 15:50:57 -04:00
namespace isaac
2015-01-12 13:20:53 -05:00
{
2015-07-11 09:36:01 -04:00
namespace templates
{
2015-01-12 13:20:53 -05:00
2015-07-11 09:36:01 -04:00
// Tuning parameters of the GEMM template.
//   simd_width                  : vector width used for loads/stores and register tiles.
//   local_size_0 / local_size_1 : work-group shape (forwarded to the base parameters, third dim = 1).
//   KL, D                       : depth of the local-memory K panel and number of K-splits.
//   ms, ks, ns                  : per-work-item register tile sizes.
//   A/B_fetching_policy         : how A and B are staged.
//   local_fetch_0 / _1          : shape of the work-group used for the local-memory fetches.
// The per-work-group tile extents are derived: mL = ms*local_size_0, nL = ns*local_size_1.
gemm_parameters::gemm_parameters(unsigned int simd_width
                          , int_t local_size_0, int_t KL, int_t local_size_1, int_t D
                          , int_t ms, int_t ks, int_t ns
                          , fetching_policy_type A_fetching_policy, fetching_policy_type B_fetching_policy
                          , int_t local_fetch_0, int_t local_fetch_1)
  : base::parameters_type(simd_width, local_size_0, local_size_1, 1)
  , kL(KL)
  , depth(D)
  , mS(ms)
  , kS(ks)
  , nS(ns)
  , A_fetching_policy(A_fetching_policy)
  , B_fetching_policy(B_fetching_policy)
  , local_fetch_0(local_fetch_0)
  , local_fetch_1(local_fetch_1)
  , mL(ms*local_size_0)
  , nL(ns*local_size_1)
{
}
2015-07-11 09:36:01 -04:00
unsigned int gemm : : lmem_usage ( expressions_tuple const & expressions ) const
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
isaac : : array_expression const & array_expression = ( * expressions . data ( ) . front ( ) ) ;
2015-01-31 22:01:48 -05:00
numeric_type numeric_t = lhs_most ( array_expression . tree ( ) , array_expression . root ( ) ) . lhs . dtype ;
2015-01-12 13:20:53 -05:00
unsigned int N = 0 ;
2015-07-09 15:03:55 -04:00
N + = p_ . kL * p_ . mL ;
N + = p_ . nL * p_ . kL ;
2015-01-12 13:20:53 -05:00
return N * size_of ( numeric_t ) ;
}
2015-07-11 09:36:01 -04:00
unsigned int gemm : : registers_usage ( expressions_tuple const & expressions ) const
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
isaac : : array_expression const & array_expression = ( * expressions . data ( ) . front ( ) ) ;
2015-01-31 22:01:48 -05:00
numeric_type numeric_t = lhs_most ( array_expression . tree ( ) , array_expression . root ( ) ) . lhs . dtype ;
2015-01-12 13:20:53 -05:00
unsigned int N = p_ . mS * p_ . nS + p_ . mS * p_ . kS + p_ . kS * p_ . nS ;
return N * size_of ( numeric_t ) ;
}
2015-07-11 09:36:01 -04:00
int gemm : : is_invalid_impl ( driver : : Device const & , expressions_tuple const & expressions ) const
2015-01-12 13:20:53 -05:00
{
2015-06-28 17:53:16 -07:00
std : : vector < int_t > MNK = input_sizes ( expressions ) ;
int_t M = MNK [ 0 ] ; int_t N = MNK [ 1 ] ;
2015-07-09 15:03:55 -04:00
if ( p_ . A_fetching_policy ! = FETCH_FROM_LOCAL | | p_ . B_fetching_policy ! = FETCH_FROM_LOCAL )
throw operation_not_supported_exception ( " Only local memory is supported for GEMM " ) ;
2015-07-18 10:24:44 -07:00
if ( p_ . depth > 1 & & M * N * p_ . depth > 2e6 )
throw operation_not_supported_exception ( " This would necessitate a temporary larger than 1MB " ) ;
2015-01-12 13:20:53 -05:00
if ( ( p_ . mS % p_ . simd_width ) > 0 | | ( p_ . nS % p_ . simd_width ) > 0 )
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE ;
2015-04-29 15:50:57 -04:00
if ( p_ . mL > 256 | | p_ . nL > 256 )
return 1 ;
if ( p_ . kS % p_ . kL = = 0 )
2015-01-12 13:20:53 -05:00
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL ;
if ( p_ . A_fetching_policy = = FETCH_FROM_LOCAL | | p_ . B_fetching_policy = = FETCH_FROM_LOCAL )
{
if ( ( p_ . local_fetch_0 * p_ . local_fetch_1 ) ! = ( p_ . local_size_0 * p_ . local_size_1 ) )
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT ;
}
if ( p_ . A_fetching_policy = = FETCH_FROM_LOCAL )
{
unsigned int bound1 = ( A_trans_ = = ' N ' ) ? p_ . kL : p_ . mL ;
unsigned int bound0 = ( A_trans_ = = ' N ' ) ? p_ . mL : p_ . kL ;
if ( p_ . local_fetch_1 > 0 & & ( bound1 % p_ . local_fetch_1 ) > 0 )
return A_trans_ = = ' N ' ? TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE : TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE ;
if ( p_ . local_fetch_0 > 0 & & ( bound0 % ( p_ . local_fetch_0 * p_ . simd_width ) ) > 0 )
return A_trans_ = = ' N ' ? TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE : TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE ;
}
if ( p_ . B_fetching_policy = = FETCH_FROM_LOCAL )
{
unsigned int bound1 = ( B_trans_ = = ' T ' ) ? p_ . kL : p_ . nL ;
unsigned int bound0 = ( B_trans_ = = ' T ' ) ? p_ . nL : p_ . kL ;
if ( p_ . local_fetch_1 > 0 & & ( bound1 % p_ . local_fetch_1 ) > 0 )
return B_trans_ = = ' T ' ? TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE : TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE ;
if ( p_ . local_fetch_0 > 0 & & ( bound0 % ( p_ . local_fetch_0 * p_ . simd_width ) ) > 0 )
return B_trans_ = = ' T ' ? TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE : TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE ;
}
return TEMPLATE_VALID ;
}
2015-07-11 09:36:01 -04:00
std : : string gemm : : generate_impl ( const char * suffix , expressions_tuple const & expressions , driver : : Device const & device , std : : vector < mapping_type > const & ) const
2015-01-12 13:20:53 -05:00
{
using std : : string ;
using tools : : to_string ;
2015-04-29 15:50:57 -04:00
driver : : backend_type backend = device . backend ( ) ;
2015-07-14 20:40:29 -07:00
bool has_depth = p_ . depth > 1 ;
2015-04-29 15:50:57 -04:00
# define VLOAD(offset, ptr) vload(p_.simd_width, sdtype, offset, ptr, backend)
# define VSTORE(value, offset, ptr) vstore(p_.simd_width, sdtype, value, offset, ptr, backend)
2015-07-09 10:52:54 -04:00
# define ASTRIDE1 string(check_bounds_?"*Astride1":"")
# define BSTRIDE1 string(check_bounds_?"*Bstride1":"")
# define CSTRIDE1 string(check_bounds_?"*Cstride1":"")
2015-01-12 13:20:53 -05:00
//////////////////
/// INIT
/// //////////////
kernel_generation_stream stream ;
2015-02-01 22:28:49 -05:00
array_expression const & st = ( * expressions . data ( ) . front ( ) ) ;
2015-01-16 07:31:39 -05:00
numeric_type dtype = lhs_most ( st . tree ( ) , st . root ( ) ) . lhs . dtype ;
2015-04-29 15:50:57 -04:00
std : : string sdtype = numeric_type_to_string ( dtype ) ;
std : : string vdtype = append_width ( sdtype , p_ . simd_width ) ;
std : : string _size_t = size_type ( device ) ;
2015-01-12 13:20:53 -05:00
//////////////////
/// DECLARATIONS
/// //////////////
2015-04-29 15:50:57 -04:00
char gemm_name [ 32 ] = { " gemm " } ;
char reduce_name [ 32 ] = { " reduce " } ;
strcat ( gemm_name , suffix ) ;
strcat ( reduce_name , suffix ) ;
2015-05-13 02:20:44 -04:00
switch ( backend )
{
# ifdef ISAAC_WITH_CUDA
case driver : : CUDA : stream < < " #include \" helper_math.h \" " < < std : : endl ; break ;
# endif
case driver : : OPENCL : stream < < " __attribute__((reqd_work_group_size( " < < p_ . local_size_0 < < " , " < < p_ . local_size_1 < < " ,1))) " < < std : : endl ; break ;
}
2015-04-29 15:50:57 -04:00
stream < < KernelPrefix ( backend ) < < " void " < < gemm_name < < " ( " < < _size_t < < " M, " < < _size_t < < " N, " < < _size_t < < " K, "
2015-07-14 20:40:29 -07:00
< < Global ( backend ) < < " " < < sdtype < < " * C, " < < _size_t < < " ldc, " < < _size_t < < " offc, " < < _size_t < < " Cstride1, "
2015-04-29 15:50:57 -04:00
< < sdtype < < " alpha, "
2015-07-14 20:40:29 -07:00
< < Global ( backend ) < < " " < < sdtype < < " * A, " < < _size_t < < " lda, " < < _size_t < < " offa, " < < _size_t < < " Astride1, "
< < Global ( backend ) < < " " < < sdtype < < " * B, " < < _size_t < < " ldb, " < < _size_t < < " offb, " < < _size_t < < " Bstride1, "
2015-04-29 15:50:57 -04:00
< < sdtype < < " beta) "
2015-01-12 13:20:53 -05:00
< < std : : endl ;
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-07-14 20:41:34 -07:00
///Declare
2015-07-18 10:24:44 -07:00
stream < < " //Declarations " < < std : : endl ;
//Block
2015-07-16 10:40:38 -04:00
stream < < sdtype < < " rC[ " < < p_ . mS < < " ][ " < < p_ . nS < < " ] = {{0}}; " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < vdtype < < " rA[ " < < p_ . kS < < " ][ " < < p_ . mS / p_ . simd_width < < " ]; " < < std : : endl ;
stream < < vdtype < < " rB[ " < < p_ . kS < < " ][ " < < p_ . nS / p_ . simd_width < < " ]; " < < std : : endl ;
2015-07-18 10:24:44 -07:00
//Pointers
2015-07-14 20:41:34 -07:00
size_t llda = ( A_trans_ = = ' N ' ) ? p_ . mL : p_ . kL ;
size_t lldb = ( B_trans_ = = ' T ' ) ? p_ . nL : p_ . kL ;
2015-07-14 20:48:52 -07:00
stream < < Local ( backend ) < < " " < < sdtype < < " lA[ " < < p_ . kL * p_ . mL < < " ]; " < < std : : endl ;
2015-07-14 20:41:34 -07:00
stream < < Local ( backend ) < < " " < < sdtype < < " lB[ " < < p_ . kL * p_ . nL < < " ]; " < < std : : endl ;
2015-07-18 10:24:44 -07:00
stream < < LocalPtr ( backend ) < < " " < < sdtype < < " * readA, * readB, * storeA, * storeB; " < < std : : endl ;
2015-07-14 20:48:52 -07:00
unsigned int npA = p_ . mL / ( A_trans_ = = ' N ' ? p_ . local_fetch_0 * p_ . simd_width : p_ . local_fetch_1 ) ;
unsigned int npB = p_ . nL / ( B_trans_ = = ' T ' ? p_ . local_fetch_0 * p_ . simd_width : p_ . local_fetch_1 ) ;
2015-07-18 10:24:44 -07:00
stream < < Global ( backend ) < < " " < < sdtype < < " * Ai[ " < < npA < < " ]; " < < std : : endl ;
stream < < Global ( backend ) < < " " < < sdtype < < " * Bi[ " < < npB < < " ]; " < < std : : endl ;
//Helpers
2015-07-18 16:06:17 -04:00
stream < < " long4 ids; " < < std : : endl ;
stream < < " int2 idT; " < < std : : endl ;
stream < < _size_t < < " idt; " < < std : : endl ;
if ( has_depth )
stream < < _size_t < < " gidz, div, offz; " < < std : : endl ;
2015-07-18 17:23:53 -04:00
stream < < " int Ky, Kx; " < < std : : endl ;
2015-07-18 10:24:44 -07:00
stream < < std : : endl ;
stream < < " //Helpers " < < std : : endl ;
2015-07-14 20:40:29 -07:00
2015-07-18 16:06:17 -04:00
stream < < " ids.x = " < < GroupIdx0 ( backend ) < < " ; " < < std : : endl ;
stream < < " ids.y = " < < GroupIdx1 ( backend ) < < " ; " < < std : : endl ;
stream < < " ids.z = " < < LocalIdx0 ( backend ) < < " ; " < < std : : endl ;
stream < < " ids.w = " < < LocalIdx1 ( backend ) < < " ; " < < std : : endl ;
2015-07-14 20:40:29 -07:00
if ( has_depth )
{
2015-07-18 16:06:17 -04:00
stream < < " gidz = " < < GroupIdx2 ( backend ) < < " ; " < < std : : endl ;
stream < < " div = (K+ " < < p_ . depth - 1 < < " )/ " < < p_ . depth < < " ; " < < std : : endl ;
stream < < " offz = div*gidz; " < < std : : endl ;
2015-07-16 13:29:07 -04:00
stream < < " K = min(K - div*gidz, div); " < < std : : endl ;
2015-04-29 15:50:57 -04:00
}
2015-07-18 16:06:17 -04:00
stream < < " idt = " < < p_ . local_size_0 < < " *ids.w + ids.z; " < < std : : endl ;
stream < < " idT.y = idt/ " < < p_ . local_fetch_0 < < " ; " < < std : : endl ;
stream < < " idT.x = idt - " < < p_ . local_fetch_0 < < " *idT.y; " < < std : : endl ;
2015-07-18 13:09:38 -04:00
stream < < " ids.x *= " < < p_ . mL < < " ; " < < std : : endl ;
stream < < " ids.y *= " < < p_ . nL < < " ; " < < std : : endl ;
stream < < " idT.x *= " < < p_ . simd_width < < " ; " < < std : : endl ;
2015-07-18 16:06:17 -04:00
stream < < " M -= ids.x; " < < std : : endl ;
stream < < " N -= ids.y; " < < std : : endl ;
2015-07-18 10:24:44 -07:00
stream < < std : : endl ;
stream < < " // Offset A " < < std : : endl ;
stream < < " A += offa; " < < std : : endl ;
2015-07-09 11:40:26 -04:00
if ( A_trans_ = = ' N ' )
2015-07-18 13:09:38 -04:00
stream < < " A += (idT.x + ids.x) " < < ASTRIDE1 < < " + idT.y*lda " < < ( has_depth ? " + offz*lda " : " " ) < < " ; " < < std : : endl ;
2015-07-09 11:40:26 -04:00
else
2015-07-18 13:09:38 -04:00
stream < < " A += idT.x " < < ASTRIDE1 < < " + idT.y*lda + ids.x*lda " < < ( has_depth ? " + offz " : " " ) < < " ; " < < std : : endl ;
2015-07-18 10:24:44 -07:00
for ( int i = 0 ; i < npA ; + + i )
stream < < " Ai[ " < < i < < " ] = A; " < < std : : endl ;
2015-07-09 15:03:55 -04:00
for ( unsigned int i = 0 ; i < npA ; i + + )
2015-07-07 23:37:53 -07:00
if ( A_trans_ = = ' N ' )
2015-07-18 16:06:17 -04:00
stream < < " if(idT.x + " < < i < < " * " < < p_ . local_fetch_0 * p_ . simd_width < < " < M) Ai[ " < < i < < " ] += " < < i * p_ . local_fetch_0 * p_ . simd_width < < ASTRIDE1 < < " ; " < < std : : endl ;
2015-07-07 23:37:53 -07:00
else
2015-07-18 16:06:17 -04:00
stream < < " if(idT.y + " < < i < < " * " < < p_ . local_fetch_1 < < " < M) Ai[ " < < i < < " ] += " < < i * p_ . local_fetch_1 < < " *lda; " < < std : : endl ;
2015-07-18 10:24:44 -07:00
stream < < " storeA = lA + idT.y* " < < llda < < " + idT.x; " < < std : : endl ;
2015-07-07 23:37:53 -07:00
2015-07-18 10:24:44 -07:00
stream < < std : : endl ;
stream < < " // Offset B " < < std : : endl ;
stream < < " B += offb; " < < std : : endl ;
if ( B_trans_ = = ' T ' )
stream < < " B += (idT.x + ids.y) " < < BSTRIDE1 < < " + idT.y*ldb " < < ( has_depth ? " + offz*ldb " : " " ) < < " ; " < < std : : endl ;
else
stream < < " B += idT.x " < < BSTRIDE1 < < " + idT.y*ldb + ids.y*ldb " < < ( has_depth ? " + offz " : " " ) < < " ; " < < std : : endl ;
for ( int i = 0 ; i < npB ; + + i )
stream < < " Bi[ " < < i < < " ] = B; " < < std : : endl ;
2015-07-09 15:03:55 -04:00
for ( unsigned int i = 0 ; i < npB ; i + + )
if ( B_trans_ = = ' T ' )
2015-07-18 16:06:17 -04:00
stream < < " if(idT.x + " < < i < < " * " < < p_ . local_fetch_0 * p_ . simd_width < < " < N) Bi[ " < < i < < " ] += " < < i * p_ . local_fetch_0 * p_ . simd_width < < BSTRIDE1 < < " ; " < < std : : endl ;
2015-07-09 15:03:55 -04:00
else
2015-07-18 16:06:17 -04:00
stream < < " if(idT.y + " < < i < < " * " < < p_ . local_fetch_1 < < " < N) Bi[ " < < i < < " ] += " < < i * p_ . local_fetch_1 < < " *ldb; " < < std : : endl ;
2015-07-18 17:23:53 -04:00
stream < < " storeB = lB + idT.y* " < < lldb < < " + idT.x; " < < std : : endl ;
2015-07-18 16:06:17 -04:00
2015-07-18 10:24:44 -07:00
stream < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < " //Outer loop " < < std : : endl ;
2015-07-18 10:24:44 -07:00
stream < < " while(K > 0) " < < std : : endl ;
stream < < " { " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream . inc_tab ( ) ;
2015-07-09 15:03:55 -04:00
stream < < LocalBarrier ( backend ) < < " ; " < < std : : endl ;
2015-01-12 13:20:53 -05:00
2015-07-18 17:23:53 -04:00
if ( A_trans_ = = ' N ' | | B_trans_ = = ' T ' )
stream < < " Ky = K - idT.y; " < < std : : endl ;
if ( A_trans_ = = ' T ' | | B_trans_ = = ' N ' )
stream < < " Kx = K - idT.x; " < < std : : endl ;
2015-07-09 15:03:55 -04:00
stream < < " //Fetch A to local memory " < < std : : endl ;
if ( A_trans_ = = ' N ' )
2015-07-14 20:40:29 -07:00
{
2015-07-09 15:03:55 -04:00
for ( int_t k = 0 ; k < p_ . kL ; k + = p_ . local_fetch_1 )
for ( int_t m = 0 ; m < p_ . mL ; m + = p_ . local_fetch_0 * p_ . simd_width )
{
std : : string mm = to_string ( m / ( p_ . simd_width * p_ . local_fetch_0 ) ) ;
std : : string kk = to_string ( k ) ;
2015-07-14 20:40:29 -07:00
string to_load = VLOAD ( " 0 " , " &Ai[ " + mm + " ][ " + kk + " *lda] " ) ;
2015-07-18 17:23:53 -04:00
to_load = " ( " + kk + " < Ky)? " + to_load + " :0 " ;
stream < < VSTORE ( to_load , " 0 " , " storeA + " + to_string ( k * llda + m ) ) < < " ; " < < std : : endl ;
2015-07-09 15:03:55 -04:00
}
2015-07-14 20:40:29 -07:00
}
2015-07-09 15:03:55 -04:00
else
2015-07-14 20:40:29 -07:00
{
2015-07-10 21:15:36 -07:00
for ( int_t k = 0 ; k < p_ . kL ; k + = p_ . local_fetch_0 * p_ . simd_width )
for ( int_t m = 0 ; m < p_ . mL ; m + = p_ . local_fetch_1 )
2015-07-08 21:09:21 -07:00
{
2015-07-10 21:15:36 -07:00
std : : string mm = to_string ( m / p_ . local_fetch_1 ) ;
std : : string kk = to_string ( k ) ;
string to_load = VLOAD ( " 0 " , " &Ai[ " + mm + " ][ " + kk + ASTRIDE1 + " ] " ) ;
2015-07-18 17:23:53 -04:00
to_load = " ( " + kk + " < Kx)? " + to_load + " :0 " ;
stream < < VSTORE ( to_load , " 0 " , " storeA + " + to_string ( m * llda + k ) ) < < " ; " < < std : : endl ;
2015-07-08 21:09:21 -07:00
}
2015-07-14 20:40:29 -07:00
}
2015-07-08 21:09:21 -07:00
2015-07-09 15:03:55 -04:00
stream < < " //Fetch B to local memory " < < std : : endl ;
if ( B_trans_ = = ' T ' )
2015-07-14 20:40:29 -07:00
{
2015-07-09 15:03:55 -04:00
for ( int_t k = 0 ; k < p_ . kL ; k + = p_ . local_fetch_1 )
for ( int_t n = 0 ; n < p_ . nL ; n + = p_ . local_fetch_0 * p_ . simd_width )
{
std : : string nn = to_string ( n / ( p_ . simd_width * p_ . local_fetch_0 ) ) ;
std : : string kk = to_string ( k ) ;
2015-07-14 20:40:29 -07:00
string to_load = VLOAD ( " 0 " , " &Bi[ " + nn + " ][ " + kk + " *ldb] " ) ;
2015-07-18 17:23:53 -04:00
to_load = " ( " + kk + " < Ky)? " + to_load + " :0 " ;
stream < < VSTORE ( to_load , " 0 " , " storeB + " + to_string ( k * lldb + n ) ) < < " ; " < < std : : endl ;
2015-07-09 15:03:55 -04:00
}
2015-07-14 20:40:29 -07:00
}
2015-07-09 15:03:55 -04:00
else
2015-07-14 20:40:29 -07:00
{
2015-07-10 21:15:36 -07:00
for ( int_t k = 0 ; k < p_ . kL ; k + = p_ . local_fetch_0 * p_ . simd_width )
for ( int_t n = 0 ; n < p_ . nL ; n + = p_ . local_fetch_1 )
2015-07-09 15:03:55 -04:00
{
2015-07-10 21:15:36 -07:00
std : : string nn = to_string ( n / p_ . local_fetch_1 ) ;
std : : string kk = to_string ( k ) ;
string to_load = VLOAD ( " 0 " , " &Bi[ " + nn + " ][ " + kk + BSTRIDE1 + " ] " ) ;
2015-07-18 17:23:53 -04:00
to_load = " ( " + kk + " < Kx)? " + to_load + " :0 " ;
stream < < VSTORE ( to_load , " 0 " , " storeB + " + to_string ( n * lldb + k ) ) < < " ; " < < std : : endl ;
2015-07-09 15:03:55 -04:00
}
2015-07-14 20:40:29 -07:00
}
2015-07-09 15:03:55 -04:00
stream < < LocalBarrier ( backend ) < < " ; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
if ( A_trans_ = = ' N ' )
2015-07-18 17:23:53 -04:00
stream < < " readA = lA + ids.z* " < < p_ . simd_width < < " ; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
else
2015-07-18 17:23:53 -04:00
stream < < " readA = lA + ids.z* " < < llda * p_ . simd_width < < " ; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
if ( B_trans_ = = ' T ' )
2015-07-18 17:23:53 -04:00
stream < < " readB = lB + ids.w* " < < p_ . simd_width < < " ; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
else
2015-07-18 17:23:53 -04:00
stream < < " readB = lB + ids.w* " < < lldb * p_ . simd_width < < " ; " < < std : : endl ;
2015-04-29 15:50:57 -04:00
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
stream < < " //Inner loop " < < std : : endl ;
2015-07-18 10:24:44 -07:00
stream < < " #pragma unroll " < < std : : endl ;
2015-07-09 13:09:01 -04:00
stream < < " for(unsigned int k = 0; k < " < < p_ . kL < < " ; k+= " < < p_ . kS < < " ){ " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream . inc_tab ( ) ;
2015-04-29 15:50:57 -04:00
stream < < " //Fetch A to registers " < < std : : endl ;
stream < < " #pragma unroll " < < std : : endl ;
stream < < " for(unsigned int kk = 0; kk < " < < p_ . kS < < " ; kk++) " < < std : : endl ;
stream < < " #pragma unroll " < < p_ . mS / p_ . simd_width < < std : : endl ;
stream < < " for(unsigned int mm = 0; mm < " < < p_ . mS / p_ . simd_width < < " ; mm++) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-07-10 21:15:36 -07:00
if ( A_trans_ = = ' N ' )
2015-07-14 20:40:29 -07:00
stream < < " rA[kk][mm] = " < < VLOAD ( " 0 " , " readA + k* " + to_string ( llda ) + " + mm* " + to_string ( p_ . local_size_0 * p_ . simd_width ) + " + kk* " + to_string ( llda ) ) < < " ; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
else
{
if ( p_ . simd_width = = 1 )
2015-07-14 20:40:29 -07:00
stream < < " rA[kk][mm] = readA[k + mm* " < < p_ . local_size_0 * llda < < " + kk " < < " ]; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
else
for ( unsigned int s = 0 ; s < p_ . simd_width ; + + s )
2015-07-14 20:40:29 -07:00
stream < < access_vector_type ( " rA[kk][mm] " , s ) < < " = readA[k + (mm* " < < p_ . simd_width * p_ . local_size_0 < < " + " < < s < < " )* " < < llda < < " + kk]; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
}
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < " //Fetch B to registers " < < std : : endl ;
stream < < " #pragma unroll " < < p_ . kS < < std : : endl ;
stream < < " for(unsigned int kk = 0; kk < " < < p_ . kS < < " ; kk++) " < < std : : endl ;
stream < < " #pragma unroll " < < p_ . nS / p_ . simd_width < < std : : endl ;
stream < < " for(unsigned int nn = 0; nn < " < < p_ . nS / p_ . simd_width < < " ; nn++) " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-07-10 21:15:36 -07:00
if ( B_trans_ = = ' T ' )
2015-07-14 20:40:29 -07:00
stream < < " rB[kk][nn] = " < < VLOAD ( " 0 " , " readB + k* " + to_string ( lldb ) + " + nn* " + to_string ( p_ . local_size_1 * p_ . simd_width ) + " + kk* " + to_string ( lldb ) ) < < " ; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
else
{
if ( p_ . simd_width = = 1 )
2015-07-14 20:40:29 -07:00
stream < < " rB[kk][nn] = readB[k " < < " + nn* " < < p_ . local_size_1 * lldb < < " + kk " < < " ]; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
else
for ( unsigned int s = 0 ; s < p_ . simd_width ; + + s )
2015-07-14 20:40:29 -07:00
stream < < access_vector_type ( " rB[kk][nn] " , s ) < < " = readB[k " < < " + (nn* " < < p_ . simd_width * p_ . local_size_1 < < " + " < < s < < " )* " < < lldb < < " + kk]; " < < std : : endl ;
2015-07-10 21:15:36 -07:00
}
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < " //FMA computations " < < std : : endl ;
for ( int_t kk = 0 ; kk < p_ . kS ; + + kk )
for ( int_t nn = 0 ; nn < p_ . nS ; + + nn )
for ( int_t mm = 0 ; mm < p_ . mS ; + + mm )
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
string res_str , lhs_str , rhs_str ;
res_str = " rC[ " + to_string ( mm ) + " ][ " + to_string ( nn ) + " ] " ;
if ( p_ . simd_width = = 1 )
lhs_str = " rA[ " + to_string ( kk ) + " ][ " + to_string ( mm ) + " ] " ;
2015-01-12 13:20:53 -05:00
else
2015-04-29 15:50:57 -04:00
lhs_str = access_vector_type ( " rA[ " + to_string ( kk ) + " ][ " + to_string ( mm / p_ . simd_width ) + " ] " , mm % p_ . simd_width ) ;
if ( p_ . simd_width = = 1 )
rhs_str = " rB[ " + to_string ( kk ) + " ][ " + to_string ( nn ) + " ] " ;
2015-01-12 13:20:53 -05:00
else
2015-04-29 15:50:57 -04:00
rhs_str = access_vector_type ( " rB[ " + to_string ( kk ) + " ][ " + to_string ( nn / p_ . simd_width ) + " ] " , nn % p_ . simd_width ) ;
stream < < res_str < < " = " < < " fma( " < < lhs_str < < " , " < < rhs_str < < " , " < < res_str < < " ); " < < std : : endl ;
2015-01-12 13:20:53 -05:00
}
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-07-09 15:03:55 -04:00
//Increment A pointers to global memory
if ( A_trans_ = = ' N ' )
for ( unsigned int i = 0 ; i < npA ; + + i )
2015-07-14 20:40:29 -07:00
stream < < " Ai[ " < < i < < " ] += " < < p_ . kL < < " *lda; " < < std : : endl ;
2015-07-09 15:03:55 -04:00
else
for ( unsigned int i = 0 ; i < npA ; + + i )
2015-07-10 15:58:47 -04:00
stream < < " Ai[ " < < i < < " ] += " < < p_ . kL < < ASTRIDE1 < < " ; " < < std : : endl ;
2015-01-12 13:20:53 -05:00
2015-07-09 15:03:55 -04:00
//Increment B pointers to global memory
if ( B_trans_ = = ' T ' )
for ( unsigned int i = 0 ; i < npB ; + + i )
2015-07-14 20:40:29 -07:00
stream < < " Bi[ " < < i < < " ] += " < < p_ . kL < < " *ldb; " < < std : : endl ;
2015-07-09 15:03:55 -04:00
else
for ( unsigned int i = 0 ; i < npB ; + + i )
2015-07-10 15:58:47 -04:00
stream < < " Bi[ " < < i < < " ] += " < < p_ . kL < < BSTRIDE1 < < " ; " < < std : : endl ;
2015-01-12 13:20:53 -05:00
2015-07-18 16:06:17 -04:00
stream < < " K -= " < < p_ . kL < < " ; " < < std : : endl ;
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-07-18 10:24:44 -07:00
stream < < std : : endl ;
stream < < " // Offset C " < < std : : endl ;
stream < < " C += offc; " < < std : : endl ;
stream < < " C += ids.x " < < CSTRIDE1 < < " ; " < < std : : endl ;
stream < < " C += ids.z* " < < p_ . simd_width < < CSTRIDE1 < < " ; " < < std : : endl ;
stream < < " C += ids.y*ldc; " < < std : : endl ;
stream < < " C += ids.w*ldc* " < < p_ . simd_width < < " ; " < < std : : endl ;
if ( has_depth )
stream < < " C += gidz*ldc*N; " < < std : : endl ;
stream < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < " //Write back C " < < std : : endl ;
2015-07-18 10:24:44 -07:00
stream < < " M -= ids.z* " < < p_ . simd_width < < " ; " < < std : : endl ;
stream < < " N -= ids.w* " < < p_ . simd_width < < " ; " < < std : : endl ;
2015-07-15 23:28:43 -07:00
stream < < " int ibm[ " < < p_ . mS < < " ]; " < < std : : endl ;
2015-04-29 15:50:57 -04:00
for ( int_t m = 0 ; m < p_ . mS ; + + m )
2015-07-15 23:28:43 -07:00
{
string Ci = to_string ( ( m / p_ . simd_width ) * ( p_ . local_size_0 * p_ . simd_width ) + m % p_ . simd_width ) ;
stream < < " ibm[ " < < m < < " ] = " < < Ci < < " < M; " < < std : : endl ;
}
2015-04-29 15:50:57 -04:00
for ( int_t n = 0 ; n < p_ . nS ; + + n )
2015-01-12 13:20:53 -05:00
{
2015-07-15 23:28:43 -07:00
string Cj = to_string ( ( n / p_ . simd_width ) * ( p_ . local_size_1 * p_ . simd_width ) + n % p_ . simd_width ) ;
for ( int_t m = 0 ; m < p_ . mS ; + + m )
stream < < " rC[ " < < m < < " ][ " < < n < < " ] *= alpha; " < < std : : endl ;
2015-07-18 10:24:44 -07:00
for ( int_t m = 0 ; m < p_ . mS ; + + m )
stream < < " ibm[ " < < m < < " ] = ibm[ " < < m < < " ] && ( " < < Cj < < " < N); " < < std : : endl ;
2015-07-15 23:28:43 -07:00
for ( int_t m = 0 ; m < p_ . mS ; + + m )
{
string Ci = to_string ( ( m / p_ . simd_width ) * ( p_ . local_size_0 * p_ . simd_width ) + m % p_ . simd_width ) ;
2015-07-18 10:24:44 -07:00
stream < < " if(ibm[ " < < m < < " ]) " ;
2015-07-16 00:29:58 -07:00
stream < < " C[ " < < Ci < < CSTRIDE1 < < " ] = rC[ " < < m < < " ][ " < < n < < " ] + select(( " < < sdtype < < " )0, C[ " < < Ci < < CSTRIDE1 < < " ], beta>0); " < < std : : endl ;
2015-07-15 23:28:43 -07:00
}
if ( ( n + 1 ) % p_ . simd_width = = 0 )
stream < < " C += ldc* " < < p_ . local_size_1 * p_ . simd_width - p_ . simd_width + 1 < < " ; " < < std : : endl ;
else
stream < < " C += ldc; " < < std : : endl ;
2015-04-29 15:50:57 -04:00
2015-01-12 13:20:53 -05:00
}
2015-07-15 23:28:43 -07:00
2015-01-12 13:20:53 -05:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
2015-04-29 15:50:57 -04:00
2015-07-14 20:40:29 -07:00
if ( has_depth )
2015-04-29 15:50:57 -04:00
{
stream < < KernelPrefix ( backend ) < < " void " < < reduce_name < < " ( " < < _size_t < < " M, " < < _size_t < < " N, " < < _size_t < < " D, "
< < Global ( backend ) < < " " < < sdtype < < " * Z, " < < _size_t < < " Zld, "
2015-07-14 20:40:29 -07:00
< < Global ( backend ) < < " " < < sdtype < < " * C, " < < _size_t < < " ldc, " < < _size_t < < " Cstart1, " < < _size_t < < " Cstart2, " < < _size_t < < " Cstride1, " < < _size_t < < " Cstride2, "
2015-04-29 15:50:57 -04:00
< < sdtype < < " beta) "
< < std : : endl ;
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
2015-07-14 20:40:29 -07:00
stream < < " C += Cstart1 + Cstart2*ldc; " < < std : : endl ;
stream < < " ldc *= Cstride2; " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream < < " for(unsigned int i = " < < GlobalIdx0 ( backend ) < < " ; i < M ; i += " < < GlobalSize0 ( backend ) < < " ) " < < std : : endl ;
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
stream < < " for(unsigned int j = " < < GlobalIdx1 ( backend ) < < " ; j < N ; j += " < < GlobalSize1 ( backend ) < < " ) " < < std : : endl ;
stream < < " { " < < std : : endl ;
stream . inc_tab ( ) ;
stream < < sdtype < < " acc = 0; " < < std : : endl ;
2015-07-10 23:16:21 -07:00
stream < < " for(unsigned int k = 0 ; k < D ; k++) " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream . inc_tab ( ) ;
stream < < " acc += Z[i + j*Zld + k*Zld*N]; " < < std : : endl ;
stream . dec_tab ( ) ;
2015-07-14 20:40:29 -07:00
stream < < " C[i*Cstride1 + j*ldc] = acc + beta*C[i + j*ldc]; " < < std : : endl ;
2015-04-29 15:50:57 -04:00
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
stream . dec_tab ( ) ;
stream < < " } " < < std : : endl ;
}
2015-07-16 00:29:58 -07:00
// if(p_.simd_width>1)
// std::cout << stream.str() << std::endl;
2015-01-12 13:20:53 -05:00
return stream . str ( ) ;
# undef VLOAD
# undef VST0RE
}
2015-07-11 09:36:01 -04:00
void gemm : : enqueue_block ( driver : : CommandQueue & /*queue*/ , int_t M , int_t N , int_t K ,
2015-04-29 15:50:57 -04:00
array const & A , array const & B , array const & C ,
2015-01-12 13:20:53 -05:00
value_scalar const & alpha , value_scalar const & beta ,
2015-04-29 15:50:57 -04:00
driver : : Program & program , const char * suffix , execution_options_type const & options )
2015-01-12 13:20:53 -05:00
{
2015-07-16 13:29:07 -04:00
using tools : : align ;
2015-04-29 15:50:57 -04:00
2015-07-16 13:29:07 -04:00
if ( M = = 0 | | N = = 0 | | K = = 0 )
return ;
2015-07-09 11:40:26 -04:00
2015-04-29 15:50:57 -04:00
char gemm_name [ 32 ] = { " gemm " } ;
char reduce_name [ 32 ] = { " reduce " } ;
strcat ( gemm_name , suffix ) ;
strcat ( reduce_name , suffix ) ;
bind_all_unique binder ;
array const * out = & C ;
std : : unique_ptr < array > tmp ;
if ( p_ . depth > 1 ) {
tmp . reset ( new array ( M , N , p_ . depth , C . dtype ( ) , C . context ( ) ) ) ;
out = tmp . get ( ) ;
}
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
driver : : Kernel gemm ( program , gemm_name ) ;
driver : : NDRange local ( p_ . local_size_0 , p_ . local_size_1 ) ;
2015-01-12 13:20:53 -05:00
2015-07-02 14:02:31 -04:00
driver : : NDRange global ( align ( align ( M , p_ . mS ) / p_ . mS , p_ . local_size_0 ) , align ( align ( N , p_ . nS ) / p_ . nS , p_ . local_size_1 ) , p_ . depth ) ;
2015-06-30 17:55:57 -04:00
2015-01-12 13:20:53 -05:00
unsigned int current_arg = 0 ;
2015-04-29 15:50:57 -04:00
set_arguments_functor helper ( binder , current_arg , gemm ) ;
gemm . setSizeArg ( current_arg + + , M ) ;
gemm . setSizeArg ( current_arg + + , N ) ;
gemm . setSizeArg ( current_arg + + , K ) ;
gemm . setArg ( current_arg + + , out - > data ( ) ) ;
gemm . setSizeArg ( current_arg + + , out - > ld ( ) * out - > stride ( ) [ 1 ] ) ;
gemm . setSizeArg ( current_arg + + , out - > start ( ) [ 0 ] + out - > start ( ) [ 1 ] * out - > ld ( ) ) ;
gemm . setSizeArg ( current_arg + + , out - > stride ( ) [ 0 ] ) ;
helper . set_arguments ( alpha . dtype ( ) , alpha . values ( ) ) ;
gemm . setArg ( current_arg + + , A . data ( ) ) ;
2015-07-08 21:09:21 -07:00
gemm . setSizeArg ( current_arg + + , A . ld ( ) * A . stride ( ) [ 1 ] ) ;
gemm . setSizeArg ( current_arg + + , ( A . start ( ) [ 0 ] + A . start ( ) [ 1 ] * A . ld ( ) ) ) ;
2015-04-29 15:50:57 -04:00
gemm . setSizeArg ( current_arg + + , A . stride ( ) [ 0 ] ) ;
gemm . setArg ( current_arg + + , B . data ( ) ) ;
2015-07-08 21:09:21 -07:00
gemm . setSizeArg ( current_arg + + , B . ld ( ) * B . stride ( ) [ 1 ] ) ;
gemm . setSizeArg ( current_arg + + , B . start ( ) [ 0 ] + B . start ( ) [ 1 ] * B . ld ( ) ) ;
2015-04-29 15:50:57 -04:00
gemm . setSizeArg ( current_arg + + , B . stride ( ) [ 0 ] ) ;
helper . set_arguments ( beta . dtype ( ) , beta . values ( ) ) ;
2015-06-23 09:38:34 -07:00
options . enqueue ( program . context ( ) , gemm , global , local ) ;
2015-04-29 15:50:57 -04:00
2015-06-30 17:55:57 -04:00
options . queue ( program . context ( ) ) . synchronize ( ) ;
2015-04-29 15:50:57 -04:00
if ( p_ . depth > 1 )
{
unsigned int current_arg = 0 ;
driver : : Kernel reduce ( program , reduce_name ) ;
driver : : NDRange local ( p_ . local_size_0 , p_ . local_size_1 ) ;
2015-07-02 14:02:31 -04:00
driver : : NDRange global ( align ( M , p_ . local_size_0 ) , align ( N , p_ . local_size_1 ) ) ;
2015-04-29 15:50:57 -04:00
set_arguments_functor helper ( binder , current_arg , reduce ) ;
reduce . setSizeArg ( current_arg + + , M ) ;
reduce . setSizeArg ( current_arg + + , N ) ;
reduce . setSizeArg ( current_arg + + , p_ . depth ) ;
reduce . setArg ( current_arg + + , out - > data ( ) ) ;
reduce . setSizeArg ( current_arg + + , out - > ld ( ) ) ;
reduce . setArg ( current_arg + + , C . data ( ) ) ;
reduce . setSizeArg ( current_arg + + , C . ld ( ) ) ;
reduce . setSizeArg ( current_arg + + , C . start ( ) [ 0 ] ) ;
reduce . setSizeArg ( current_arg + + , C . start ( ) [ 1 ] ) ;
reduce . setSizeArg ( current_arg + + , C . stride ( ) [ 0 ] ) ;
reduce . setSizeArg ( current_arg + + , C . stride ( ) [ 1 ] ) ;
helper . set_arguments ( beta . dtype ( ) , beta . values ( ) ) ;
2015-06-23 09:38:34 -07:00
options . enqueue ( program . context ( ) , reduce , global , local ) ;
2015-04-29 15:50:57 -04:00
}
2015-01-12 13:20:53 -05:00
}
2015-07-11 09:36:01 -04:00
// Returns array(M, first, second), where first = slice(s0_0, s0_1) and
// second = slice(s1_0, s1_1). When `swap` is true the two slices are passed in
// the opposite order (NOTE(review): presumably to address transposed operands
// -- confirm with callers).
array gemm::create_slice(array & M, int_t s0_0, int_t s0_1, int_t s1_0, int_t s1_1, bool swap)
{
  slice first_axis(s0_0, s0_1);
  slice second_axis(s1_0, s1_1);
  if (!swap)
    return array(M, first_axis, second_axis);
  return array(M, second_axis, first_axis);
}
2015-07-11 09:36:01 -04:00
std : : vector < int_t > gemm : : infos ( expressions_tuple const & expressions , symbolic : : preset : : gemm : : args & arguments ) const
2015-01-12 13:20:53 -05:00
{
2015-04-29 15:50:57 -04:00
isaac : : array_expression & array_expression = ( * expressions . data ( ) . front ( ) ) ;
array_expression : : container_type & array = array_expression . tree ( ) ;
2015-01-31 22:01:48 -05:00
std : : size_t root = array_expression . root ( ) ;
2015-06-27 17:55:01 -07:00
arguments = symbolic : : preset : : gemm : : check ( array , root ) ;
int_t M = arguments . C - > array - > shape ( ) [ 0 ] ;
int_t N = arguments . C - > array - > shape ( ) [ 1 ] ;
int_t K = ( A_trans_ = = ' T ' ) ? arguments . A - > array - > shape ( ) [ 0 ] : arguments . A - > array - > shape ( ) [ 1 ] ;
2015-04-29 15:50:57 -04:00
return { M , N , K } ;
2015-01-12 13:20:53 -05:00
}
2015-07-11 09:36:01 -04:00
// Constructs a GEMM template for the given tuning parameters.
// A_trans / B_trans must each be 'N' or 'T'; the pair selects the template
// type_ (NN, TN, NT or TT).
// Throws operation_not_supported_exception for any other flag value.
// (The previous bare `throw;` had no active exception to rethrow, so it
// terminated the program instead of reporting the error.)
gemm::gemm(gemm_parameters const & parameters, bool check_bounds, char A_trans, char B_trans) : base_impl<gemm, gemm_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans), check_bounds_(check_bounds)
{
  if (A_trans_ == 'N' && B_trans_ == 'N') type_ = GEMM_NN_TYPE;
  else if (A_trans_ == 'T' && B_trans_ == 'N') type_ = GEMM_TN_TYPE;
  else if (A_trans_ == 'N' && B_trans_ == 'T') type_ = GEMM_NT_TYPE;
  else if (A_trans_ == 'T' && B_trans_ == 'T') type_ = GEMM_TT_TYPE;
  else throw operation_not_supported_exception("gemm: transposition flags must each be 'N' or 'T'");
}
2015-01-12 13:20:53 -05:00
2015-07-11 09:36:01 -04:00
std : : vector < int_t > gemm : : input_sizes ( expressions_tuple const & expressions ) const
2015-01-12 13:20:53 -05:00
{
2015-06-27 17:55:01 -07:00
symbolic : : preset : : gemm : : args dummy ;
return infos ( expressions , dummy ) ;
2015-01-12 13:20:53 -05:00
}
2015-07-11 09:36:01 -04:00
void gemm : : enqueue ( driver : : CommandQueue & queue , driver : : Program & program , const char * suffix , base & fallback_base , controller < expressions_tuple > const & ctr )
2015-01-12 13:20:53 -05:00
{
using namespace tools ;
2015-07-11 09:36:01 -04:00
gemm & fallback = ( gemm & ) fallback_base ;
2015-04-29 15:50:57 -04:00
expressions_tuple const & expressions = ctr . x ( ) ;
2015-02-05 04:42:57 -05:00
2015-01-12 13:20:53 -05:00
2015-06-27 17:55:01 -07:00
symbolic : : preset : : gemm : : args args ;
std : : vector < int_t > MNK = infos ( expressions , args ) ;
2015-02-05 04:42:57 -05:00
2015-01-12 13:20:53 -05:00
int_t M = MNK [ 0 ] ;
int_t N = MNK [ 1 ] ;
int_t K = MNK [ 2 ] ;
2015-04-29 15:50:57 -04:00
//Skip if empty
if ( M = = 0 | | N = = 0 | | K = = 0 )
return ;
//Extract
2015-06-27 17:55:01 -07:00
array * pA = args . A - > array ;
array * pB = args . B - > array ;
array * pC = args . C - > array ;
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
//Check if requires fallback
int_t ldstrideA = pA - > stride ( ) [ 0 ] ;
int_t ldstrideB = pB - > stride ( ) [ 0 ] ;
int_t ldstrideC = pC - > stride ( ) [ 0 ] ;
2015-01-12 13:20:53 -05:00
2015-06-27 17:55:01 -07:00
numeric_type dtype = args . C - > dtype ;
2015-01-12 13:20:53 -05:00
2015-04-29 15:50:57 -04:00
//Enqueue
2015-01-12 13:20:53 -05:00
bool swap_A = ( A_trans_ = = ' T ' ) ;
bool swap_B = ( B_trans_ = = ' T ' ) ;
2015-04-29 15:50:57 -04:00
value_scalar beta ( 0 , dtype ) ;
2015-06-29 21:52:50 -07:00
if ( args . beta ) beta = value_scalar ( args . beta - > vscalar , dtype ) ;
2015-04-29 15:50:57 -04:00
value_scalar alpha ( 1 , dtype ) ;
2015-06-29 21:52:50 -07:00
if ( args . alpha ) alpha = value_scalar ( args . alpha - > vscalar , dtype ) ;
2015-04-29 15:50:57 -04:00
execution_options_type const & options = ctr . execution_options ( ) ;
2015-01-12 13:20:53 -05:00
2015-07-16 13:29:07 -04:00
if ( ldstrideA > 1 | | ldstrideB > 1 | | ldstrideC > 1 )
2015-07-09 11:40:26 -04:00
{
fallback . enqueue_block ( queue , M , N , K , * pA , * pB , * pC , alpha , beta , program , " fallback " , options ) ;
}
else
{
2015-07-10 15:58:47 -04:00
// std::cout << p_.local_size_0 << " " << p_.kL << " " << p_.local_size_1 << " " << p_.depth << std::endl;
2015-07-16 13:29:07 -04:00
// value_scalar _1(1, dtype);
enqueue_block ( queue , M , N , K , * pA , * pB , * pC , alpha , beta , program , suffix , options ) ;
// fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options);
2015-07-09 11:40:26 -04:00
}
2015-01-12 13:20:53 -05:00
}
//
2015-07-11 09:36:01 -04:00
// GEMM with neither A nor B transposed ('N', 'N').
// Arguments are forwarded unchanged to gemm_parameters, in the same order.
gemm_nn::gemm_nn(unsigned int simd_width,
                 int_t local_size_0, int_t KL, int_t local_size_1, int_t D,
                 int_t ms, int_t ks, int_t ns,
                 fetching_policy_type A_policy, fetching_policy_type B_policy,
                 int_t local_fetch_0, int_t local_fetch_1, bool check_bounds)
  : gemm(gemm_parameters(simd_width,
                         local_size_0, KL, local_size_1, D,
                         ms, ks, ns,
                         A_policy, B_policy,
                         local_fetch_0, local_fetch_1),
         check_bounds, 'N', 'N')
{
}
//
2015-07-11 09:36:01 -04:00
// GEMM with A transposed and B not transposed ('T', 'N').
// Arguments are forwarded unchanged to gemm_parameters, in the same order.
gemm_tn::gemm_tn(unsigned int simd_width,
                 int_t local_size_0, int_t KL, int_t local_size_1, int_t D,
                 int_t ms, int_t ks, int_t ns,
                 fetching_policy_type A_policy, fetching_policy_type B_policy,
                 int_t local_fetch_0, int_t local_fetch_1, bool check_bounds)
  : gemm(gemm_parameters(simd_width,
                         local_size_0, KL, local_size_1, D,
                         ms, ks, ns,
                         A_policy, B_policy,
                         local_fetch_0, local_fetch_1),
         check_bounds, 'T', 'N')
{
}
//
2015-07-11 09:36:01 -04:00
// GEMM with A not transposed and B transposed ('N', 'T').
// Arguments are forwarded unchanged to gemm_parameters, in the same order.
gemm_nt::gemm_nt(unsigned int simd_width,
                 int_t local_size_0, int_t KL, int_t local_size_1, int_t D,
                 int_t ms, int_t ks, int_t ns,
                 fetching_policy_type A_policy, fetching_policy_type B_policy,
                 int_t local_fetch_0, int_t local_fetch_1, bool check_bounds)
  : gemm(gemm_parameters(simd_width,
                         local_size_0, KL, local_size_1, D,
                         ms, ks, ns,
                         A_policy, B_policy,
                         local_fetch_0, local_fetch_1),
         check_bounds, 'N', 'T')
{
}
//
2015-07-11 09:36:01 -04:00
// GEMM with both A and B transposed ('T', 'T').
// Arguments are forwarded unchanged to gemm_parameters, in the same order.
gemm_tt::gemm_tt(unsigned int simd_width,
                 int_t local_size_0, int_t KL, int_t local_size_1, int_t D,
                 int_t ms, int_t ks, int_t ns,
                 fetching_policy_type A_policy, fetching_policy_type B_policy,
                 int_t local_fetch_0, int_t local_fetch_1, bool check_bounds)
  : gemm(gemm_parameters(simd_width,
                         local_size_0, KL, local_size_1, D,
                         ms, ks, ns,
                         A_policy, B_policy,
                         local_fetch_0, local_fetch_1),
         check_bounds, 'T', 'T')
{
}
2015-07-11 09:36:01 -04:00
}
}
2015-01-12 13:20:53 -05:00