[dnn/blocksparse] added heuristics for block-sparse dot
This commit is contained in:
@@ -127,16 +127,42 @@ inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {
|
|||||||
/* Block-sparse matrix multiplication */
|
/* Block-sparse matrix multiplication */
|
||||||
|
|
||||||
static const std::map<std::pair<bool, size_t>, std::map<size_t, params_t>> bsdot_params = {
|
static const std::map<std::pair<bool, size_t>, std::map<size_t, params_t>> bsdot_params = {
|
||||||
/* 32x32 */
|
/* FPROP */
|
||||||
{{true, 32}, std::map<size_t, params_t>{
|
{{true, 32}, std::map<size_t, params_t>{
|
||||||
{32, {2, 2, 32, 32, 2, 2, 4, 8, 32, 32, 8, 4, 16}},
|
{32, {2, 2, 32, 32, 2, 2, 4, 8, 32, 32, 8, 4, 16}},
|
||||||
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 2, 4}},
|
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 2, 4}},
|
||||||
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 8, 4, 16}}
|
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 8, 4, 16}}
|
||||||
}},
|
}},
|
||||||
|
|
||||||
|
{{true, 16}, std::map<size_t, params_t>{
|
||||||
|
{32, {4, 1, 32, 16, 1, 1, 8, 4, 4, 16, 4, 4, 8}},
|
||||||
|
{64, {4, 1, 64, 16, 2, 2, 8, 8, 16, 16, 8, 2, 16}},
|
||||||
|
{128, {4, 1, 128, 16, 4, 1, 16, 8, 8, 16, 8, 2, 16}}
|
||||||
|
}},
|
||||||
|
|
||||||
|
{{true, 8}, std::map<size_t, params_t>{
|
||||||
|
{32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}},
|
||||||
|
{64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 4, 2, 8}},
|
||||||
|
{128, {4, 1, 128, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}}
|
||||||
|
}},
|
||||||
|
|
||||||
|
/* BPROP */
|
||||||
{{false, 32}, std::map<size_t, params_t>{
|
{{false, 32}, std::map<size_t, params_t>{
|
||||||
{32, {2, 2, 32, 32, 1, 1, 8, 4, 4, 32, 8, 4, 8}},
|
{32, {2, 2, 32, 32, 1, 1, 8, 4, 4, 32, 8, 4, 8}},
|
||||||
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 4, 8}},
|
{64, {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 4, 8}},
|
||||||
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 32, 4, 8}}
|
{128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 32, 4, 8}}
|
||||||
|
}},
|
||||||
|
|
||||||
|
{{false, 16}, std::map<size_t, params_t>{
|
||||||
|
{32, {4, 1, 32, 16, 1, 2, 4, 8, 16, 16, 16, 4, 4}},
|
||||||
|
{64, {4, 1, 64, 16, 2, 1, 8, 8, 8, 16, 16, 4, 4}},
|
||||||
|
{128, {4, 1, 128, 16, 2, 2, 32, 4, 4, 16, 16, 8, 2}}
|
||||||
|
}},
|
||||||
|
|
||||||
|
{{false, 8}, std::map<size_t, params_t>{
|
||||||
|
{32, {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 8, 4, 2}},
|
||||||
|
{64, {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}},
|
||||||
|
{128, {4, 1, 128, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}}
|
||||||
}}
|
}}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -228,7 +228,7 @@ void tune::run(ir::module &mod) {
|
|||||||
nts->set_value(1);
|
nts->set_value(1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
|
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||||
}
|
}
|
||||||
|
@@ -92,7 +92,7 @@ void dot::triton_c_src(std::ostream &os) const {
|
|||||||
std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
|
std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
|
||||||
std::string result =
|
std::string result =
|
||||||
R"(
|
R"(
|
||||||
const tunable int32 TM = {32, 64, 128};
|
const tunable int32 TM = {16, 32, 64, 128};
|
||||||
const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
|
const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
|
||||||
const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
|
const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
|
||||||
|
|
||||||
@@ -143,11 +143,11 @@ void dot::triton_c_src(std::ostream &os) const {
|
|||||||
int32 count = *pcount;
|
int32 count = *pcount;
|
||||||
if(count == 0){
|
if(count == 0){
|
||||||
@checkc *pc = c;
|
@checkc *pc = c;
|
||||||
__atomic_exch(pcount, 1);
|
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
@checkc *pc = c + *pc;
|
@checkc *pc = c + *pc;
|
||||||
}
|
}
|
||||||
|
__atomic_exch(pcount, 1);
|
||||||
__atomic_exch(plock, 0);
|
__atomic_exch(plock, 0);
|
||||||
}
|
}
|
||||||
})";
|
})";
|
||||||
|
@@ -230,6 +230,8 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
|
|||||||
parallel_for_each(targets, update_best, nthreads_);
|
parallel_for_each(targets, update_best, nthreads_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(best.params.empty())
|
||||||
|
throw std::runtime_error("auto-tuning didn't find valid parameters");
|
||||||
// std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
|
// std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
|
||||||
return best;
|
return best;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user