[dnn/blocksparse] added heuristics for block-sparse dot

This commit is contained in:
Philippe Tillet
2019-07-31 17:12:36 -07:00
parent bb32ac56c9
commit f7bd976fc7
4 changed files with 32 additions and 4 deletions

View File

@@ -127,16 +127,42 @@ inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
/* Block-sparse matrix multiplication */
/* Static lookup table of tuning parameters for the block-sparse dot kernel.
 *
 * Outer key : {flag, size}. Entries listed under FPROP use flag == true,
 *             entries listed under BPROP use flag == false. size is one of
 *             32, 16 or 8 (presumably the sparsity block size -- TODO confirm
 *             against the code that performs the lookup).
 * Inner key : 32, 64 or 128 (presumably the tile size along M -- TODO confirm).
 * Value     : a params_t initialized with 13 tuning values. In every entry the
 *             third value equals the inner key and the fourth value equals the
 *             outer size; the meaning of the other slots is not visible here.
 */
static const std::map<std::pair<bool, size_t>, std::map<size_t, params_t>> bsdot_params = {
/* NOTE(review): the table covers sizes 32, 16 and 8, not only 32x32 */
/* FPROP */
/* flag = true, size = 32 */
{{true, 32}, std::map<size_t, params_t>{
{32, {2, 2, 32, 32, 2, 2, 4, 8, 32, 32, 8, 4, 16}},
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 2, 4}},
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 8, 4, 16}}
}},
/* flag = true, size = 16 */
{{true, 16}, std::map<size_t, params_t>{
{32, {4, 1, 32, 16, 1, 1, 8, 4, 4, 16, 4, 4, 8}},
{64, {4, 1, 64, 16, 2, 2, 8, 8, 16, 16, 8, 2, 16}},
{128, {4, 1, 128, 16, 4, 1, 16, 8, 8, 16, 8, 2, 16}}
}},
/* flag = true, size = 8 */
{{true, 8}, std::map<size_t, params_t>{
{32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}},
{64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 4, 2, 8}},
{128, {4, 1, 128, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}}
}},
/* BPROP */
/* flag = false, size = 32 */
{{false, 32}, std::map<size_t, params_t>{
{32, {2, 2, 32, 32, 1, 1, 8, 4, 4, 32, 8, 4, 8}},
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 4, 8}},
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 32, 4, 8}}
}},
/* flag = false, size = 16 */
{{false, 16}, std::map<size_t, params_t>{
{32, {4, 1, 32, 16, 1, 2, 4, 8, 16, 16, 16, 4, 4}},
{64, {4, 1, 64, 16, 2, 1, 8, 8, 8, 16, 16, 4, 4}},
{128, {4, 1, 128, 16, 2, 2, 32, 4, 4, 16, 16, 8, 2}}
}},
/* flag = false, size = 8 */
{{false, 8}, std::map<size_t, params_t>{
{32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 8, 4, 2}},
{64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}},
{128, {4, 1, 128, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}}
}}
};

View File

@@ -228,7 +228,7 @@ void tune::run(ir::module &mod) {
nts->set_value(1);
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}

View File

@@ -92,7 +92,7 @@ void dot::triton_c_src(std::ostream &os) const {
std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
std::string result =
R"(
const tunable int32 TM = {32, 64, 128};
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
@@ -143,11 +143,11 @@ void dot::triton_c_src(std::ostream &os) const {
int32 count = *pcount;
if(count == 0){
@checkc *pc = c;
__atomic_exch(pcount, 1);
}
else{
@checkc *pc = c + *pc;
}
__atomic_exch(pcount, 1);
__atomic_exch(plock, 0);
}
})";

View File

@@ -230,6 +230,8 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
parallel_for_each(targets, update_best, nthreads_);
}
if(best.params.empty())
throw std::runtime_error("auto-tuning didn't find valid parameters");
// std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
return best;
}