Runtime: More progress towards cuBLAS integration

This commit is contained in:
Philippe Tillet
2016-10-04 01:02:43 -04:00
parent fb9669a34d
commit ffb9548b6a
18 changed files with 170 additions and 210 deletions

View File

@@ -50,16 +50,11 @@ std::vector<int_t> infos(expression_tree const & tree, symbolic::preset::gemm::a
}
/* ------------------ CUBLAS ------------------ */
bool cublas_gemm::init()
{
return driver::dispatch::cublasinit();
}
cublas_gemm::cublas_gemm(char A_trans, char B_trans): A_trans_(A_trans), B_trans_(B_trans), init_(driver::dispatch::cublasinit())
{ }
int cublas_gemm::is_invalid(expression_tree const &, driver::Device const & device) const
{ return init_ && device.backend()==driver::CUDA; }
{ return (init_ && device.backend()==driver::CUDA)?0:-1; }
std::vector<int_t> cublas_gemm::input_sizes(expression_tree const & expressions) const
{
@@ -67,9 +62,21 @@ std::vector<int_t> cublas_gemm::input_sizes(expression_tree const & expressions)
return infos((expression_tree&)expressions, dummy, A_trans_);
}
expression_type cublas_gemm::type() const
{
if(A_trans_=='N' && B_trans_=='N')
return GEMM_NN;
else if(A_trans_=='T' && B_trans_=='N')
return GEMM_TN;
else if(A_trans_=='N' && B_trans_=='T')
return GEMM_NT;
else
return GEMM_TT;
}
void cublas_gemm::enqueue(driver::CommandQueue & queue, driver::Program const &, std::string const &, runtime::execution_handler const & control)
{
namespace drv = driver;;
namespace drv = driver;
//Get GEMM info
symbolic::preset::gemm::args args;
std::vector<int_t> MNK = infos(control.x(), args, A_trans_);
@@ -115,6 +122,19 @@ unsigned int gemm::lmem_usage(expression_tree const & expression) const
return N*size_of(expression.dtype());
}
expression_type gemm::type() const
{
if(A_trans_=='N' && B_trans_=='N')
return GEMM_NN;
else if(A_trans_=='T' && B_trans_=='N')
return GEMM_TN;
else if(A_trans_=='N' && B_trans_=='T')
return GEMM_NT;
else
return GEMM_TT;
}
unsigned int gemm::registers_usage(expression_tree const & expression) const
{
unsigned int N = mS_ * nS_ + mS_ * kS_ + kS_ * nS_;