Kernels: more generic temporary workspace checks

This commit is contained in:
Philippe Tillet
2015-08-10 10:19:50 -07:00
parent f56eac5adb
commit f60b82af25
7 changed files with 25 additions and 9 deletions

View File

@@ -46,16 +46,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
return N*size_of(numeric_t);
}
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
/**
 * @brief Size of the temporary workspace this template needs for the given expressions.
 *
 * The first two entries of input_sizes(expressions) are taken as M and N.
 * When p_.depth is greater than 1 the result is M*N*p_.depth; otherwise no
 * temporary workspace is requested and 0 is returned.
 *
 * NOTE(review): the value is not scaled by an element size in this function;
 * presumably it is an element count -- confirm against the callers.
 *
 * @param expressions expressions whose input sizes determine M and N
 * @return M*N*p_.depth converted to unsigned int if p_.depth > 1, else 0
 */
unsigned int gemm::temporary_workspace(expressions_tuple const & expressions) const
{
  // Declare the sizes exactly once: repeating these declarations in the same
  // scope is a redefinition error.
  std::vector<int_t> const MNK = input_sizes(expressions);
  int_t const M = MNK[0];
  int_t const N = MNK[1];

  if(p_.depth <= 1)
    return 0;

  // Same value the implicit return conversion would produce; the cast only
  // makes the narrowing to the return type visible.
  return static_cast<unsigned int>(M*N*p_.depth);
}
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
if(p_.depth > 1 && M*N*p_.depth > 2e6)
return TEMPLATE_TEMPORARY_TOO_LARGE;
if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
@@ -66,8 +70,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if ( p_.kS % p_.kL == 0)
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL)
{
if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL){
if ((p_.local_fetch_0*p_.local_fetch_1) !=(p_.local_size_0*p_.local_size_1))
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
}