[dnn/shift] now using constant divisions

This commit is contained in:
Philippe Tillet
2019-07-16 21:05:21 -07:00
parent 07c964919c
commit a55b098e88
4 changed files with 17 additions and 13 deletions

View File

@@ -14,7 +14,7 @@ int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
auto op = triton::dnn::shift::BPROP;
auto op = triton::dnn::shift::FPROP;
// initialization
int32_t R = 3, S = 3;

View File

@@ -223,7 +223,7 @@ void tune::run(ir::module &mod) {
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
}
@@ -237,7 +237,7 @@ void tune::run(ir::module &mod) {
continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 2));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
*params_.at(i).at("nts.d0") = *tmp;
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){

View File

@@ -291,19 +291,23 @@ void shift::triton_c_src(std::ostream &os) const {
bool is_chwn = layout_ == CHWN;
auto compute_bhw = [&](std::string rx, std::string sz, std::string rkx){
std::string B = std::to_string(B_);
std::string CW = std::to_string(CW_);
std::string CH = std::to_string(CH_);
if(is_chwn) {
return R"(
int32 )" + rx + "wh[" + sz + "] = " + rkx + R"( / NB;
int32 )" + rx + "b[" + sz + "] = " + rkx + R"( % NB;
int32 )" + rx + "w[" + sz + "] = (" + rx + R"(wh % CW) + pad_w;
int32 )" + rx + "h[" + sz + "] = (" + rx + R"(wh / CW) + pad_h;)";
int32 )" + rx + "wh[" + sz + "] = " + rkx + " / " + B + R"(;
int32 )" + rx + "b[" + sz + "] = " + rkx + " % " + B + R"(;
int32 )" + rx + "w[" + sz + "] = (" + rx + "wh % " + CW + R"() + pad_w;
int32 )" + rx + "h[" + sz + "] = (" + rx + "wh / " + CW + R"() + pad_h;)";
}
else {
return R"(
int32 )" + rx + "bh[" + sz + "] = " + rkx + R"( / CW;
int32 )" + rx + "w[" + sz + "] = (" + rkx + R"( % CW) + pad_w;
int32 )" + rx + "h[" + sz + "] = (" + rx + R"(bh % CH) + pad_h;
int32 )" + rx + "b[" + sz + "] = " + rx + R"(bh / CH;)";
int32 )" + rx + "bh[" + sz + "] = " + rkx + " / " + CW + R"(;
int32 )" + rx + "w[" + sz + "] = (" + rkx + " % " + CW + R"() + pad_w;
int32 )" + rx + "h[" + sz + "] = (" + rx + "bh % " + CH + R"() + pad_h;
int32 )" + rx + "b[" + sz + "] = " + rx + "bh / " + CH + ";";
}
};

View File

@@ -43,8 +43,8 @@ void loop_nest(std::vector<size_t> const & ranges,
// size_t current = 0;
while(true){
//Execute function
// pool.enqueue([values, &f](){ f(values); });
f(values);
pool.enqueue([values, &f](){ f(values); });
// f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;