Kernels: more generic temporary workspace checks
This commit is contained in:
@@ -265,7 +265,6 @@ void bench(isc::numeric_type dtype, std::string operation)
|
||||
#endif
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << "\n\n" << std::flush;
|
||||
}
|
||||
|
||||
if(operation.substr(0, 4)=="gemv")
|
||||
@@ -334,6 +333,8 @@ void bench(isc::numeric_type dtype, std::string operation)
|
||||
MNKs.push_back(std::make_tuple('N','N',169,384,2304));
|
||||
MNKs.push_back(std::make_tuple('N','N',169,192,1728));
|
||||
MNKs.push_back(std::make_tuple('N','N',169,128,1728));
|
||||
MNKs.push_back(std::make_tuple('N','T',256,4096,9216));
|
||||
|
||||
//AlexNet (Backward)
|
||||
MNKs.push_back(std::make_tuple('T','N',1728,128,169));
|
||||
MNKs.push_back(std::make_tuple('T','N',1728,192,169));
|
||||
|
@@ -67,6 +67,7 @@ private:
|
||||
virtual std::string generate_impl(std::string const & suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mapping) const = 0;
|
||||
public:
|
||||
base(binding_policy_t binding_policy);
|
||||
virtual unsigned int temporary_workspace(expressions_tuple const &) const;
|
||||
virtual unsigned int lmem_usage(expressions_tuple const &) const;
|
||||
virtual unsigned int registers_usage(expressions_tuple const &) const;
|
||||
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const = 0;
|
||||
|
@@ -42,6 +42,7 @@ struct gemm_parameters : public base::parameters_type
|
||||
class gemm : public base_impl<gemm, gemm_parameters>
|
||||
{
|
||||
private:
|
||||
unsigned int temporary_workspace(expressions_tuple const & expressions) const;
|
||||
unsigned int lmem_usage(expressions_tuple const & expressions) const;
|
||||
unsigned int registers_usage(expressions_tuple const & expressions) const;
|
||||
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;
|
||||
|
@@ -73,6 +73,9 @@ unsigned int base::lmem_usage(expressions_tuple const &) const
|
||||
// Default register-usage estimate for a template: this base implementation
// ignores its argument and always reports 0. The member is declared virtual
// in the class, and gemm declares its own version.
unsigned int base::registers_usage(expressions_tuple const &) const
|
||||
{ return 0; }
|
||||
|
||||
// Default temporary-workspace requirement for a template: this base
// implementation ignores its argument and always reports 0. The member is
// declared virtual, and gemm provides its own definition.
// base_impl::is_invalid compares the reported value against 2e6 and returns
// TEMPLATE_TEMPORARY_TOO_LARGE when it is exceeded.
unsigned int base::temporary_workspace(expressions_tuple const &) const
|
||||
{ return 0; }
|
||||
|
||||
// Destructor of base; the body is empty.
base::~base()
|
||||
{ }
|
||||
|
||||
@@ -81,7 +84,8 @@ std::string base::generate(std::string const & suffix, expressions_tuple const &
|
||||
expressions_tuple::data_type::const_iterator sit;
|
||||
std::vector<mapping_type>::iterator mit;
|
||||
|
||||
if(int err = is_invalid(expressions, device))
|
||||
int err = is_invalid(expressions, device);
|
||||
if(err != 0 && err != TEMPLATE_TEMPORARY_TOO_LARGE)
|
||||
throw operation_not_supported_exception("The supplied parameters for this template are invalid : err " + tools::to_string(err));
|
||||
|
||||
//Create mapping
|
||||
@@ -142,6 +146,10 @@ int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, d
|
||||
if (p_.simd_width!=1 && p_.simd_width!=2 && p_.simd_width!=3 && p_.simd_width!=4)
|
||||
return TEMPLATE_INVALID_SIMD_WIDTH;
|
||||
|
||||
//Temporary workspace
|
||||
if(temporary_workspace(expressions) > 2e6)
|
||||
return TEMPLATE_TEMPORARY_TOO_LARGE;
|
||||
|
||||
return is_invalid_impl(device, expressions);
|
||||
}
|
||||
|
||||
|
@@ -46,16 +46,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
return N*size_of(numeric_t);
|
||||
}
|
||||
|
||||
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
|
||||
// Temporary workspace reported by the gemm template for the given expressions.
// Takes M and N from the first two entries of input_sizes(expressions) and
// returns M*N*p_.depth when p_.depth > 1, otherwise 0.
// NOTE(review): no size_of factor is applied here, so the value looks like an
// element count rather than a byte count — confirm against the 2e6 limit used
// in base_impl::is_invalid.
// NOTE(review): the product is formed from int_t operands and returned as
// unsigned int — confirm M, N and depth cannot make it overflow or narrow.
// NOTE(review): the MNK / M / N declarations appear twice below; this looks
// like removed and added diff lines shown together rather than real source —
// verify against the actual file before compiling.
unsigned int gemm::temporary_workspace(expressions_tuple const & expressions) const
|
||||
{
|
||||
std::vector<int_t> MNK = input_sizes(expressions);
|
||||
int_t M = MNK[0]; int_t N = MNK[1];
|
||||
std::vector<int_t> MNK = input_sizes(expressions);
|
||||
int_t M = MNK[0]; int_t N = MNK[1];
|
||||
// Workspace is only reported when the depth parameter is greater than 1.
if(p_.depth > 1)
|
||||
return M*N*p_.depth;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
|
||||
{
|
||||
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
|
||||
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
|
||||
|
||||
if(p_.depth > 1 && M*N*p_.depth > 2e6)
|
||||
return TEMPLATE_TEMPORARY_TOO_LARGE;
|
||||
|
||||
if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
|
||||
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
|
||||
@@ -66,8 +70,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
|
||||
if ( p_.kS % p_.kL == 0)
|
||||
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
|
||||
|
||||
if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL)
|
||||
{
|
||||
if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL){
|
||||
if ((p_.local_fetch_0*p_.local_fetch_1) !=(p_.local_size_0*p_.local_size_1))
|
||||
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
|
||||
}
|
||||
|
@@ -124,6 +124,7 @@ void model::execute(controller<expressions_tuple> const & expr)
|
||||
}
|
||||
|
||||
//Execution
|
||||
// std::cout << std::endl << "Label: " << label << std::endl;
|
||||
return templates_[label]->enqueue(queue_, program, tools::to_string(label), *fallback_, expr);
|
||||
}
|
||||
|
||||
|
@@ -176,6 +176,7 @@ extern "C"
|
||||
cl_uint numCommandQueues, cl_command_queue *commandQueues,\
|
||||
cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *events)\
|
||||
{\
|
||||
/*std::cout << transA << " " <<transB << " " << M << " " << N << " " << K << std::endl;*/\
|
||||
cl_mem mA = cmA;\
|
||||
cl_mem mB = cmB;\
|
||||
if(order==clblasRowMajor){\
|
||||
|
Reference in New Issue
Block a user