diff --git a/include/triton/dnn/heuristics.h b/include/triton/dnn/heuristics.h index d9bd01e75..56c23642b 100644 --- a/include/triton/dnn/heuristics.h +++ b/include/triton/dnn/heuristics.h @@ -127,16 +127,42 @@ inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) { /* Block-sparse matrix multiplication */ static const std::map, std::map> bsdot_params = { - /* 32x32 */ + /* FPROP */ {{true, 32}, std::map{ {32, {2, 2, 32, 32, 2, 2, 4, 8, 32, 32, 8, 4, 16}}, {64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 2, 4}}, {128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 8, 4, 16}} }}, + + {{true, 16}, std::map{ + {32, {4, 1, 32, 16, 1, 1, 8, 4, 4, 16, 4, 4, 8}}, + {64, {4, 1, 64, 16, 2, 2, 8, 8, 16, 16, 8, 2, 16}}, + {128, {4, 1, 128, 16, 4, 1, 16, 8, 8, 16, 8, 2, 16}} + }}, + + {{true, 8}, std::map{ + {32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}}, + {64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 4, 2, 8}}, + {128, {4, 1, 128, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}} + }}, + + /* BPROP */ {{false, 32}, std::map{ {32, {2, 2, 32, 32, 1, 1, 8, 4, 4, 32, 8, 4, 8}}, {64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 4, 8}}, {128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 32, 4, 8}} + }}, + + {{false, 16}, std::map{ + {32, {4, 1, 32, 16, 1, 2, 4, 8, 16, 16, 16, 4, 4}}, + {64, {4, 1, 64, 16, 2, 1, 8, 8, 8, 16, 16, 4, 4}}, + {128, {4, 1, 128, 16, 2, 2, 32, 4, 4, 16, 16, 8, 2}} + }}, + + {{false, 8}, std::map{ + {32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 8, 4, 2}}, + {64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}}, + {128, {4, 1, 128, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}} }} }; diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index a28fd827e..8c351be1c 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -228,7 +228,7 @@ void tune::run(ir::module &mod) { nts->set_value(1); } else { - ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2); + ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4); ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4); connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++); } diff --git a/lib/dnn/blocksparse/dot.cpp b/lib/dnn/blocksparse/dot.cpp index f38030366..3ea79bc78 100644 --- a/lib/dnn/blocksparse/dot.cpp +++ b/lib/dnn/blocksparse/dot.cpp @@ -92,7 +92,7 @@ void dot::triton_c_src(std::ostream &os) const { std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ; std::string result = R"( - const tunable int32 TM = {32, 64, 128}; + const tunable int32 TM = {16, 32, 64, 128}; const tunable int32 TN = {)" + std::to_string(BS_) + R"(}; const tunable int32 TK = {)" + std::to_string(BS_) + R"(}; @@ -143,11 +143,11 @@ void dot::triton_c_src(std::ostream &os) const { int32 count = *pcount; if(count == 0){ @checkc *pc = c; - __atomic_exch(pcount, 1); } else{ @checkc *pc = c + *pc; } + __atomic_exch(pcount, 1); __atomic_exch(plock, 0); } })"; diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp index 12eebdd3c..928ec0812 100644 --- a/lib/runtime/jit.cpp +++ b/lib/runtime/jit.cpp @@ -230,6 +230,8 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben parallel_for_each(targets, update_best, nthreads_); } + if(best.params.empty()) + throw std::runtime_error("auto-tuning didn't find valid parameters"); // std::cout << "Autotuning done - Best performance: " << best.perf << std::endl; return best; }