[dnn/shift] fixed in leading dimensions for shift-conv operation

This commit is contained in:
Philippe Tillet
2019-07-05 17:17:22 -07:00
parent c666f71fd6
commit 3e49dbe6ab
3 changed files with 14 additions and 18 deletions

View File

@@ -125,14 +125,16 @@ void shift::init(driver::stream *stream, driver::cu_module *module) {
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a, driver::buffer *b, driver::buffer *c,
size_t TM, size_t TN, size_t nthreads) {
int32_t lda = AT_ ? K_ : M_;
int32_t ldb = BT_ ? N_ : K_;
kernel->setArg(0, a);
kernel->setArg(1, b);
kernel->setArg(2, c);
kernel->setArg(3, M_);
kernel->setArg(4, N_);
kernel->setArg(5, K_);
kernel->setArg(6, M_);
kernel->setArg(7, N_);
kernel->setArg(6, lda);
kernel->setArg(7, ldb);
kernel->setArg(8, B_);
kernel->setArg(9, AH_);
kernel->setArg(10, AW_);