From 71594da66ff6c41d7c6c27605b4627a5957efd9c Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 18 Jul 2019 16:35:48 -0700 Subject: [PATCH] [dnn/gemm]: fixed leading dimension in transposed variants --- examples/cpp/dot.cpp | 22 +++++++++++----------- lib/codegen/selection.cpp | 1 + lib/dnn/gemm.cpp | 9 ++++++--- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 2e790c2f5..5068dfbde 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -8,12 +8,12 @@ int main() { - bool AT = false; + bool AT = true; bool BT = false; // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); // matrix multiplication parameters - int32_t M = 1024, N = 1024, K = 1024; + int32_t M = 64, N = 128, K = 128; std::vector hc(M*N); std::vector rc(M*N); std::vector ha(M*K); @@ -33,14 +33,14 @@ int main() { stream->write(db, true, 0, hb); stream->write(dc, true, 0, hc); stream->synchronize(); - triton::dnn::gemm gemm(M, N, K, AT, BT, "fp16", "fp16", 4, 4); - gemm.enqueue(stream, {da, db, dc}, true); -// stream->read(dc, true, 0, hc); -// gemm.cpu_ref(rc, ha, hb); -// for(size_t i = 0; i < M*N; i++) -// if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ -// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; -// exit(EXIT_FAILURE); -// } + triton::dnn::gemm gemm(M, N, K, AT, BT, "fp32", "fp32", 4, 4); + gemm.enqueue(stream, {da, db, dc}, false); + stream->read(dc, true, 0, hc); + gemm.cpu_ref(rc, ha, hb); + for(size_t i = 0; i < M*N; i++) + if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ + std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; + exit(EXIT_FAILURE); + } std::cout << "Pass!" << std::endl; } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 04a413b32..3238efdfa 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -1149,6 +1149,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & unsigned alignment = std::min(starting_multiple, max_contiguous); unsigned vector_size = std::min(result->axis(0).contiguous, alignment); // vector_size = result->axis(0).contiguous; +// vector_size = 1; std::map packets; distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand()); result->for_each([&](indices_t idx){ diff --git a/lib/dnn/gemm.cpp b/lib/dnn/gemm.cpp index 05a47e41f..222173c61 100644 --- a/lib/dnn/gemm.cpp +++ b/lib/dnn/gemm.cpp @@ -54,6 +54,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel, unsigned grid_0 = (M_ + TM - 1)/TM; unsigned grid_1 = (N_ + TN - 1)/TN; unsigned grid_2 = 1; + int32_t lda = AT_ ? K_ : M_; + int32_t ldb = BT_ ? N_ : K_; + int32_t ldc = M_; std::array grid = {grid_0, grid_1, grid_2}; kernel->setArg(0, a); kernel->setArg(1, b); @@ -61,9 +64,9 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel, kernel->setArg(3, M_); kernel->setArg(4, N_); kernel->setArg(5, K_); - kernel->setArg(6, M_); - kernel->setArg(7, N_); - kernel->setArg(8, M_); + kernel->setArg(6, lda); + kernel->setArg(7, ldb); + kernel->setArg(8, ldc); kernel->setArg(9, locks_); kernel->setArg(10, grid_0); kernel->setArg(11, grid_1);