[dnn/gemm]: fixed leading dimension in transposed variants

This commit is contained in:
Philippe Tillet
2019-07-18 16:35:48 -07:00
parent f0d8306437
commit 71594da66f
3 changed files with 18 additions and 14 deletions

View File

@@ -54,6 +54,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned grid_2 = 1;
+int32_t lda = AT_ ? K_ : M_;
+int32_t ldb = BT_ ? N_ : K_;
+int32_t ldc = M_;
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
kernel->setArg(0, a);
kernel->setArg(1, b);
@@ -61,9 +64,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(3, M_);
kernel->setArg(4, N_);
kernel->setArg(5, K_);
-kernel->setArg(6, M_);
-kernel->setArg(7, N_);
-kernel->setArg(8, M_);
+kernel->setArg(6, lda);
+kernel->setArg(7, ldb);
+kernel->setArg(8, ldc);
kernel->setArg(9, locks_);
kernel->setArg(10, grid_0);
kernel->setArg(11, grid_1);