[dnn/blocksparse] added heuristics for block-sparse dot
@@ -127,16 +127,42 @@ inline params_t dot_heuristics(bool AT, bool BT, size_t M, size_t N, size_t K) {

/* Block-sparse matrix multiplication */

static const std::map<std::pair<bool, size_t>, std::map<size_t, params_t>> bsdot_params = {
  /* 32x32 */
  /* FPROP */
  {{true, 32}, std::map<size_t, params_t>{
    {32,  {2, 2, 32, 32, 2, 2, 4, 8, 32, 32, 8, 4, 16}},
    {64,  {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 2, 4}},
    {128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 8, 4, 16}}
  }},

  {{true, 16}, std::map<size_t, params_t>{
    {32,  {4, 1, 32, 16, 1, 1, 8, 4, 4, 16, 4, 4, 8}},
    {64,  {4, 1, 64, 16, 2, 2, 8, 8, 16, 16, 8, 2, 16}},
    {128, {4, 1, 128, 16, 4, 1, 16, 8, 8, 16, 8, 2, 16}}
  }},

  {{true, 8}, std::map<size_t, params_t>{
    {32,  {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}},
    {64,  {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 4, 2, 8}},
    {128, {4, 1, 128, 8, 1, 1, 4, 8, 8, 8, 4, 2, 8}}
  }},

  /* BPROP */
  {{false, 32}, std::map<size_t, params_t>{
    {32,  {2, 2, 32, 32, 1, 1, 8, 4, 4, 32, 8, 4, 8}},
    {64,  {2, 2, 64, 32, 2, 1, 16, 4, 4, 32, 16, 4, 8}},
    {128, {2, 2, 128, 32, 4, 1, 32, 4, 4, 32, 32, 4, 8}}
  }},

  {{false, 16}, std::map<size_t, params_t>{
    {32,  {4, 1, 32, 16, 1, 2, 4, 8, 16, 16, 16, 4, 4}},
    {64,  {4, 1, 64, 16, 2, 1, 8, 8, 8, 16, 16, 4, 4}},
    {128, {4, 1, 128, 16, 2, 2, 32, 4, 4, 16, 16, 8, 2}}
  }},

  {{false, 8}, std::map<size_t, params_t>{
    {32,  {4, 1, 32, 8, 1, 1, 4, 8, 8, 8, 8, 4, 2}},
    {64,  {4, 1, 64, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}},
    {128, {4, 1, 128, 8, 1, 1, 8, 8, 4, 8, 8, 4, 2}}
  }}
};
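The table maps a {transpose flag, block size} pair to per-TM parameter vectors. For illustration only (not part of the commit), a lookup against a table shaped like bsdot_params could look like the sketch below; the helper name, the params_t typedef, and the error handling are assumptions.

// Hypothetical helper, for illustration: query a bsdot_params-shaped table.
// Assumes the outer key is {transpose flag, block size}, the inner key is the
// TM tile size, and params_t is a vector of integer tuning parameters.
#include <cstddef>
#include <map>
#include <stdexcept>
#include <utility>
#include <vector>

using params_t = std::vector<int>;  // assumption; the real typedef lives in the heuristics header
using bsdot_table_t = std::map<std::pair<bool, std::size_t>, std::map<std::size_t, params_t>>;

inline params_t lookup_bsdot_params(const bsdot_table_t &table,
                                    bool trans, std::size_t block, std::size_t tm) {
  auto outer = table.find({trans, block});
  if (outer == table.end())
    throw std::runtime_error("no block-sparse heuristics for this block size");
  auto inner = outer->second.find(tm);
  if (inner == outer->second.end())
    throw std::runtime_error("no block-sparse heuristics for this tile size");
  return inner->second;
}

With the table above, lookup_bsdot_params(table, true, 16, 64) would return {4, 1, 64, 16, 2, 2, 8, 8, 16, 16, 8, 2, 16}.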
@@ -228,7 +228,7 @@ void tune::run(ir::module &mod) {
    nts->set_value(1);
  }
  else {
-   ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
+   ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
    ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
    connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
  }
@@ -92,7 +92,7 @@ void dot::triton_c_src(std::ostream &os) const {
  std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
  std::string result =
  R"(
- const tunable int32 TM = {32, 64, 128};
+ const tunable int32 TM = {16, 32, 64, 128};
  const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
  const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
@@ -143,11 +143,11 @@ void dot::triton_c_src(std::ostream &os) const {
  int32 count = *pcount;
  if(count == 0){
    @checkc *pc = c;
-   __atomic_exch(pcount, 1);
  }
  else{
    @checkc *pc = c + *pc;
  }
+ __atomic_exch(pcount, 1);
  __atomic_exch(plock, 0);
}
})";
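This hunk is the inter-block reduction of the generated Triton-C kernel: a per-tile lock (plock) serializes writers, and a per-tile counter (pcount) tells the first writer (which overwrites the output tile) apart from later writers (which accumulate into it). As a rough plain-C++ analogy of the pattern only (not the kernel itself; the explicit spin on the lock and the single-element output are simplifications):

// Plain-C++ analogy of the lock-protected tile reduction above (illustrative
// only; the kernel uses Triton-C predicated stores and __atomic_exch, and
// operates on whole tiles rather than single elements).
#include <atomic>

void accumulate_tile(float *pc, float c,
                     std::atomic<int> &plock, std::atomic<int> &pcount) {
  // spin until this tile's lock is acquired
  while (plock.exchange(1, std::memory_order_acquire) != 0) { }
  if (pcount.load(std::memory_order_relaxed) == 0)
    *pc = c;          // first writer initializes the tile
  else
    *pc += c;         // later writers accumulate into it
  pcount.store(1, std::memory_order_relaxed);  // mark the tile as initialized
  plock.store(0, std::memory_order_release);   // release the tile lock
}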
@@ -230,6 +230,8 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
    parallel_for_each(targets, update_best, nthreads_);
  }

+ if(best.params.empty())
+   throw std::runtime_error("auto-tuning didn't find valid parameters");
  // std::cout << "Autotuning done - Best performance: " << best.perf << std::endl;
  return best;
}
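The added guard makes a failed search explicit: instead of returning an empty result, autotune now throws. Caller-side handling could look like the sketch below, assuming the relevant triton headers are included; "run_bsdot_autotune", "engine", "src" and "benchmark" are placeholder names, not part of this commit.

// Illustrative caller-side handling of the new failure mode (placeholder names).
#include <iostream>
#include <stdexcept>

void run_bsdot_autotune(jit &engine, const char *src, benchmark_t benchmark) {
  try {
    jit::tune_res_t best = engine.autotune("bsdot", src, benchmark);
    // proceed with best.params / best.perf ...
  } catch (const std::runtime_error &e) {
    // no candidate configuration passed validation for this problem shape
    std::cerr << "auto-tuning failed: " << e.what() << std::endl;
  }
}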