Runtime: More progress towards cuBLAS integration
This commit is contained in:
@@ -50,16 +50,11 @@ std::vector<int_t> infos(expression_tree const & tree, symbolic::preset::gemm::a
|
||||
}
|
||||
|
||||
/* ------------------ CUBLAS ------------------ */
|
||||
bool cublas_gemm::init()
|
||||
{
|
||||
return driver::dispatch::cublasinit();
|
||||
}
|
||||
|
||||
cublas_gemm::cublas_gemm(char A_trans, char B_trans): A_trans_(A_trans), B_trans_(B_trans), init_(driver::dispatch::cublasinit())
|
||||
{ }
|
||||
|
||||
int cublas_gemm::is_invalid(expression_tree const &, driver::Device const & device) const
|
||||
{ return init_ && device.backend()==driver::CUDA; }
|
||||
{ return (init_ && device.backend()==driver::CUDA)?0:-1; }
|
||||
|
||||
std::vector<int_t> cublas_gemm::input_sizes(expression_tree const & expressions) const
|
||||
{
|
||||
@@ -67,9 +62,21 @@ std::vector<int_t> cublas_gemm::input_sizes(expression_tree const & expressions)
|
||||
return infos((expression_tree&)expressions, dummy, A_trans_);
|
||||
}
|
||||
|
||||
expression_type cublas_gemm::type() const
|
||||
{
|
||||
if(A_trans_=='N' && B_trans_=='N')
|
||||
return GEMM_NN;
|
||||
else if(A_trans_=='T' && B_trans_=='N')
|
||||
return GEMM_TN;
|
||||
else if(A_trans_=='N' && B_trans_=='T')
|
||||
return GEMM_NT;
|
||||
else
|
||||
return GEMM_TT;
|
||||
}
|
||||
|
||||
void cublas_gemm::enqueue(driver::CommandQueue & queue, driver::Program const &, std::string const &, runtime::execution_handler const & control)
|
||||
{
|
||||
namespace drv = driver;;
|
||||
namespace drv = driver;
|
||||
//Get GEMM info
|
||||
symbolic::preset::gemm::args args;
|
||||
std::vector<int_t> MNK = infos(control.x(), args, A_trans_);
|
||||
@@ -115,6 +122,19 @@ unsigned int gemm::lmem_usage(expression_tree const & expression) const
|
||||
return N*size_of(expression.dtype());
|
||||
}
|
||||
|
||||
expression_type gemm::type() const
|
||||
{
|
||||
if(A_trans_=='N' && B_trans_=='N')
|
||||
return GEMM_NN;
|
||||
else if(A_trans_=='T' && B_trans_=='N')
|
||||
return GEMM_TN;
|
||||
else if(A_trans_=='N' && B_trans_=='T')
|
||||
return GEMM_NT;
|
||||
else
|
||||
return GEMM_TT;
|
||||
}
|
||||
|
||||
|
||||
unsigned int gemm::registers_usage(expression_tree const & expression) const
|
||||
{
|
||||
unsigned int N = mS_ * nS_ + mS_ * kS_ + kS_ * nS_;
|
||||
|
Reference in New Issue
Block a user