Kernels: more generic temporary workspace checks

This commit is contained in:
Philippe Tillet
2015-08-10 10:19:50 -07:00
parent f56eac5adb
commit f60b82af25
7 changed files with 25 additions and 9 deletions

View File

@@ -265,7 +265,6 @@ void bench(isc::numeric_type dtype, std::string operation)
#endif
std::cout << std::endl;
}
std::cout << "\n\n" << std::flush;
}
if(operation.substr(0, 4)=="gemv")
@@ -334,6 +333,8 @@ void bench(isc::numeric_type dtype, std::string operation)
MNKs.push_back(std::make_tuple('N','N',169,384,2304));
MNKs.push_back(std::make_tuple('N','N',169,192,1728));
MNKs.push_back(std::make_tuple('N','N',169,128,1728));
MNKs.push_back(std::make_tuple('N','T',256,4096,9216));
//AlexNet (Backward)
MNKs.push_back(std::make_tuple('T','N',1728,128,169));
MNKs.push_back(std::make_tuple('T','N',1728,192,169));

View File

@@ -67,6 +67,7 @@ private:
virtual std::string generate_impl(std::string const & suffix, expressions_tuple const & expressions, driver::Device const & device, std::vector<mapping_type> const & mapping) const = 0;
public:
base(binding_policy_t binding_policy);
virtual unsigned int temporary_workspace(expressions_tuple const &) const;
virtual unsigned int lmem_usage(expressions_tuple const &) const;
virtual unsigned int registers_usage(expressions_tuple const &) const;
virtual std::vector<int_t> input_sizes(expressions_tuple const & expressions) const = 0;

View File

@@ -42,6 +42,7 @@ struct gemm_parameters : public base::parameters_type
class gemm : public base_impl<gemm, gemm_parameters>
{
private:
unsigned int temporary_workspace(expressions_tuple const & expressions) const;
unsigned int lmem_usage(expressions_tuple const & expressions) const;
unsigned int registers_usage(expressions_tuple const & expressions) const;
int is_invalid_impl(driver::Device const &, expressions_tuple const &) const;

View File

@@ -73,6 +73,9 @@ unsigned int base::lmem_usage(expressions_tuple const &) const
// Default register-usage estimate for a template: always 0.
// The expressions argument is ignored here; classes such as gemm declare
// their own registers_usage() for a template-specific value.
unsigned int base::registers_usage(expressions_tuple const &) const
{
  return 0;
}
// Default temporary-workspace requirement for a template: always 0.
// The expressions argument is ignored here. base_impl::is_invalid() compares
// the returned value against 2e6 and reports TEMPLATE_TEMPORARY_TOO_LARGE
// when it is exceeded, so this default never triggers that error.
unsigned int base::temporary_workspace(expressions_tuple const &) const
{
  return 0;
}
// Out-of-line destructor definition; the body is intentionally empty.
base::~base()
{
}
@@ -81,7 +84,8 @@ std::string base::generate(std::string const & suffix, expressions_tuple const &
expressions_tuple::data_type::const_iterator sit;
std::vector<mapping_type>::iterator mit;
if(int err = is_invalid(expressions, device))
int err = is_invalid(expressions, device);
if(err != 0 && err != TEMPLATE_TEMPORARY_TOO_LARGE)
throw operation_not_supported_exception("The supplied parameters for this template are invalid : err " + tools::to_string(err));
//Create mapping
@@ -142,6 +146,10 @@ int base_impl<TType, PType>::is_invalid(expressions_tuple const & expressions, d
if (p_.simd_width!=1 && p_.simd_width!=2 && p_.simd_width!=3 && p_.simd_width!=4)
return TEMPLATE_INVALID_SIMD_WIDTH;
//Temporary workspace
if(temporary_workspace(expressions) > 2e6)
return TEMPLATE_TEMPORARY_TOO_LARGE;
return is_invalid_impl(device, expressions);
}

View File

@@ -46,16 +46,20 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
return N*size_of(numeric_t);
}
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const & expressions) const
// Temporary workspace required by this gemm configuration.
//
// expressions: the expressions being evaluated; only used to obtain the
//              problem sizes through input_sizes().
// Returns M*N*depth when p_.depth > 1 (M and N being the first two entries
// returned by input_sizes()), and 0 otherwise. No element size is applied
// to the product in this function.
//
// NOTE(review): the product is computed in int_t and converted to
// unsigned int on return -- confirm the value always fits.
unsigned int gemm::temporary_workspace(expressions_tuple const & expressions) const
{
  // Sizes are declared exactly once; a second declaration of MNK/M/N in the
  // same scope would be a redefinition error.
  std::vector<int_t> MNK = input_sizes(expressions);
  int_t M = MNK[0];
  int_t N = MNK[1];
  if(p_.depth > 1)
    return M*N*p_.depth;
  return 0;
}
int gemm::is_invalid_impl(driver::Device const &, expressions_tuple const &) const
{
if(p_.A_fetching_policy!=FETCH_FROM_LOCAL || p_.B_fetching_policy!=FETCH_FROM_LOCAL)
return TEMPLATE_INVALID_FETCHING_POLICY_TYPE;
if(p_.depth > 1 && M*N*p_.depth > 2e6)
return TEMPLATE_TEMPORARY_TOO_LARGE;
if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
@@ -66,8 +70,7 @@ gemm_parameters::gemm_parameters(unsigned int simd_width
if ( p_.kS % p_.kL == 0)
return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL)
{
if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL){
if ((p_.local_fetch_0*p_.local_fetch_1) !=(p_.local_size_0*p_.local_size_1))
return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
}

View File

@@ -124,6 +124,7 @@ void model::execute(controller<expressions_tuple> const & expr)
}
//Execution
// std::cout << std::endl << "Label: " << label << std::endl;
return templates_[label]->enqueue(queue_, program, tools::to_string(label), *fallback_, expr);
}

View File

@@ -176,6 +176,7 @@ extern "C"
cl_uint numCommandQueues, cl_command_queue *commandQueues,\
cl_uint numEventsInWaitList, const cl_event *eventWaitList, cl_event *events)\
{\
/*std::cout << transA << " " <<transB << " " << M << " " << N << " " << K << std::endl;*/\
cl_mem mA = cmA;\
cl_mem mB = cmB;\
if(order==clblasRowMajor){\