[dnn/gemm]: fixed leading dimension in transposed variants
This commit is contained in:
@@ -54,6 +54,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
unsigned grid_0 = (M_ + TM - 1)/TM;
|
||||
unsigned grid_1 = (N_ + TN - 1)/TN;
|
||||
unsigned grid_2 = 1;
|
||||
+int32_t lda = AT_ ? K_ : M_;
|
||||
+int32_t ldb = BT_ ? N_ : K_;
|
||||
+int32_t ldc = M_;
|
||||
std::array<size_t, 3> grid = {grid_0, grid_1, grid_2};
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
@@ -61,9 +64,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
-kernel->setArg(6, M_);
|
||||
-kernel->setArg(7, N_);
|
||||
-kernel->setArg(8, M_);
|
||||
+kernel->setArg(6, lda);
|
||||
+kernel->setArg(7, ldb);
|
||||
+kernel->setArg(8, ldc);
|
||||
kernel->setArg(9, locks_);
|
||||
kernel->setArg(10, grid_0);
|
||||
kernel->setArg(11, grid_1);
|
||||
|
Reference in New Issue
Block a user