[dnn/shift] fixed in leading dimensions for shift-conv operation
This commit is contained in:
@@ -125,14 +125,16 @@ void shift::init(driver::stream *stream, driver::cu_module *module) {
|
||||
void shift::enqueue(driver::stream *stream, driver::kernel *kernel,
|
||||
driver::buffer *a, driver::buffer *b, driver::buffer *c,
|
||||
size_t TM, size_t TN, size_t nthreads) {
|
||||
int32_t lda = AT_ ? K_ : M_;
|
||||
int32_t ldb = BT_ ? N_ : K_;
|
||||
kernel->setArg(0, a);
|
||||
kernel->setArg(1, b);
|
||||
kernel->setArg(2, c);
|
||||
kernel->setArg(3, M_);
|
||||
kernel->setArg(4, N_);
|
||||
kernel->setArg(5, K_);
|
||||
kernel->setArg(6, M_);
|
||||
kernel->setArg(7, N_);
|
||||
kernel->setArg(6, lda);
|
||||
kernel->setArg(7, ldb);
|
||||
kernel->setArg(8, B_);
|
||||
kernel->setArg(9, AH_);
|
||||
kernel->setArg(10, AW_);
|
||||
|
Reference in New Issue
Block a user