From 28c250216cae840d3c59f55cb61cdece833e3e90 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Fri, 19 Jul 2019 21:32:55 -0700 Subject: [PATCH] [dnn/gemm] added some bounds checking --- examples/cpp/dot.cpp | 6 ++-- include/triton/dnn/gemm.h | 3 -- lib/codegen/tune.cpp | 2 +- lib/dnn/gemm.cpp | 60 ++++++++++++++++++++------------------- lib/runtime/jit.cpp | 2 +- 5 files changed, 36 insertions(+), 37 deletions(-) diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index b5af64615..720c872f2 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -8,15 +8,15 @@ int main() { - bool AT = true; - bool BT = false; + bool AT = false; + bool BT = true; typedef float T; std::string ty = "fp16"; size_t dt_nbytes = sizeof(T); // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); // matrix multiplication parameters - int32_t M = 65536, N = 2048, K = 2048; + int32_t M = 4096, N = 4096, K = 4096; std::vector hc(M*N); std::vector rc(M*N); std::vector ha(M*K); diff --git a/include/triton/dnn/gemm.h b/include/triton/dnn/gemm.h index 8348edf3e..f43370606 100644 --- a/include/triton/dnn/gemm.h +++ b/include/triton/dnn/gemm.h @@ -31,9 +31,6 @@ public: // clone base* clone() const; - // default params - std::vector default_params(); - // CPU reference implementation template static void cpu_ref(std::vector &c, const std::vector &a, const std::vector &b, diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 1d24b9548..2d104d8d6 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -235,7 +235,7 @@ void tune::run(ir::module &mod) { continue; if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 8, 8)); + std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 4, 8)); *params_.at(i).at("nts.d0") = *tmp; } if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ diff --git a/lib/dnn/gemm.cpp b/lib/dnn/gemm.cpp index 6a9bace7d..82fdb431b 100644 --- a/lib/dnn/gemm.cpp +++ b/lib/dnn/gemm.cpp @@ -51,6 +51,7 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel, driver::buffer *a = args[0], *b = args[1], *c = args[2]; unsigned TM = info.globals.at("TM"); unsigned TN = info.globals.at("TN"); + unsigned TK = info.globals.at("TK"); unsigned grid_0 = (M_ + TM - 1)/TM; unsigned grid_1 = (N_ + TN - 1)/TN; unsigned grid_2 = 1; @@ -67,23 +68,13 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel, kernel->setArg(6, lda); kernel->setArg(7, ldb); kernel->setArg(8, ldc); - kernel->setArg(9, locks_); - kernel->setArg(10, grid_0); - kernel->setArg(11, grid_1); + kernel->setArg(9, TK); + kernel->setArg(10, locks_); + kernel->setArg(11, grid_0); + kernel->setArg(12, grid_1); stream->enqueue(kernel, grid, {info.num_threads, 1, 1}); } -std::vector gemm::default_params() { - if(AT_ && BT_) - return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1}; - else if(AT_ && !BT_) - return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1}; - else if(!AT_ && BT_) - return {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1}; - else - return {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1}; -} - void gemm::triton_c_src(std::ostream &os) const { std::string AS0 = "TM", AS1 = "TK"; std::string BS0 = "TK", BS1 = "TN"; @@ -103,12 +94,14 @@ void gemm::triton_c_src(std::ostream &os) const { std::swap(bcb0, bcb1); std::swap(ldb0, ldb1); } + std::string AS = AS0 + ", " + AS1; + std::string BS = BS0 + ", " + BS1; std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")"; std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")"; std::string res = R"( -const tunable int32 TM = {16, 32, 64, 128}; -const tunable int32 TN = {16, 32, 64, 128}; +const tunable int32 TM = {32, 64, 128, 256}; +const tunable int32 TN = {32, 64, 128, 256}; const tunable int32 TK = {32}; const tunable int32 GZ = {1}; @@ -117,27 +110,36 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, fp32 *C, int32 M, int32 N, int32 K, )" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc, - int32 *locks, int32 grid0, int32 grid1) { - int32 rxa[TM] = get_global_range[TM](0); - int32 ryb[TN] = get_global_range[TN](1); + int32 bound, int32 *locks, int32 grid0, int32 grid1) { + int32 ridx = get_range_id(0); + int32 ridy = get_range_id(1); + int32 rxa[TM] = ridx*TM + (0 ... TM); + int32 ryb[TN] = ridy*TN + (0 ... TN); int32 rka[TK] = 0 ... TK; int32 rkb[TK] = 0 ... TK; fp32 c[TM, TN] = 0; - )" + a_ty_ + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; - )" + b_ty_ + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; - )" + a_ty_ + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa; - )" + b_ty_ + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb; - for(int32 k = K; k > TK; k = k - TK){ + )" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; + )" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; + int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(; + int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(; + )" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0; + )" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0; + for(int32 k = K; k > 0; k = k - TK){ c = dot()" + usea + ", " + useb + R"(, c); pa = pa + TK)" + lda0 + R"(; pb = pb + TK)" + ldb0 + R"(; - a = *pa; - b = *pb; + int1 checka[)" + AS + R"(] = k > bound; + int1 checkb[)" + BS + R"(] = k > bound; + @checka a = *pa; + @checkb b = *pb; } - int32 rxc[TM] = get_global_range[TM](0); - int32 ryc[TN] = get_global_range[TN](1); + int32 rxc[TM] = ridx*TM + (0 ... TM); + int32 ryc[TN] = ridy*TN + (0 ... TN); + int1 checkc0[TM] = rxc < M; + int1 checkc1[TN] = ryc < N; + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; - *pc = c; + @checkc *pc = c; } )"; os << res; diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp index 6e0f72334..c925a690c 100644 --- a/lib/runtime/jit.cpp +++ b/lib/runtime/jit.cpp @@ -49,7 +49,7 @@ void loop_nest(std::vector const & ranges, values[i--] = 0; } i = D - 1; - // Small sleep so that the thread pool doesn't grow too big + // Short sleep so that the thread pool doesn't grow too big std::this_thread::sleep_for(std::chrono::microseconds(1)); } }