Kernels: more generic temporary workspace checks
This commit is contained in:
@@ -46,16 +46,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
return N*size_of(numeric_t);
|
||||
}
|
||||
|
||||
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
|
||||
unsigned int gemm::temporary_workspace(expressions_tuple const & expressions) const
|
||||
{
|
||||
std::vector<int_t> MNK = input_sizes(expressions);
|
||||
int_t M = MNK[0]; int_t N = MNK[1];
|
||||
std::vector<int_t> MNK = input_sizes(expressions);
|
||||
int_t M = MNK[0]; int_t N = MNK[1];
|
||||
if(p_.depth > 1)
|
||||
return M*N*p_.depth;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
{
|
||||
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
|
||||
if(p_.depth > 1 && M*N*p_.depth > 2e6)
|
||||
return TEMPLATE_TEMPORARY_TOO_LARGE;
|
||||
|
||||
if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
|
||||
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
|
||||
@@ -66,8 +70,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
if ( p_.kS % p_.kL == 0)
|
||||
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
|
||||
|
||||
if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL)
|
||||
{
|
||||
if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL){
|
||||
if ((p_.local_fetch_0*p_.local_fetch_1) !=(p_.local_size_0*p_.local_size_1))
|
||||
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
|
||||
}
|
||||
|
Reference in New Issue
Block a user