From 3b92ddf7e60a0ea55d9dcfc5e3abaeca653aa152 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 31 Jul 2019 18:41:56 -0700 Subject: [PATCH] [codegen/reassociation] now recursively takes pointer arguments into account as well --- examples/python/tensorflow/blocksparse.cpp | 2 +- include/triton/runtime/jit.h | 1 + lib/codegen/reassociate.cpp | 15 +++++++++++++++ lib/codegen/tune.cpp | 8 ++++---- lib/dnn/blocksparse/dot.cpp | 6 +++--- 5 files changed, 24 insertions(+), 8 deletions(-) diff --git a/examples/python/tensorflow/blocksparse.cpp b/examples/python/tensorflow/blocksparse.cpp index 38b335689..0d37d382d 100644 --- a/examples/python/tensorflow/blocksparse.cpp +++ b/examples/python/tensorflow/blocksparse.cpp @@ -130,7 +130,7 @@ public: // create profile triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP); // blocksparse matmul - triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING); + triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING); triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks(); Tensor *tmp = nullptr; TensorShape tmp_shapes; diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index de84d1788..19fde0e84 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -77,6 +77,7 @@ public: void target_dependent(ir::module &module) { alignment_info.run(module); +// ir::print(module, std::cout); reassociate.run(module); if(target_->is_gpu()){ shmem_info.run(module); diff --git a/lib/codegen/reassociate.cpp b/lib/codegen/reassociate.cpp index bf36b2033..d0a54ec31 100644 --- a/lib/codegen/reassociate.cpp +++ b/lib/codegen/reassociate.cpp @@ -215,6 +215,21 @@ void reassociate::run(ir::module &mod) { infos[sta_ptr].dyn_ptr = (ir::getelementptr_inst*)dyn_ptr; infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr; } + 
// reassociate pointer argument + if(ir::getelementptr_inst* gepy = dynamic_cast<ir::getelementptr_inst*>(py)) + if(infos.find(gepy) != infos.end()){ + builder.set_insert_point(pz); + ir::getelementptr_inst *sta = infos[gepy].sta_ptr; + ir::getelementptr_inst *dyn = infos[gepy].dyn_ptr; + ir::value *cst = *sta->idx_begin(); + ir::value *off = *pz->idx_begin(); + ir::value *new_dyn = builder.create_gep(dyn, {off}); + ir::value *new_pz = builder.create_gep(new_dyn, {cst}, pz->get_name()); + params_->copy(new_dyn, pz); + params_->copy(new_pz, pz); + align_->copy(new_pz, pz); + pz->replace_all_uses_with(new_pz); + } // reassociate phi-node pointer if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){ // only optimize the case where py = phi pa, pz for now diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 8c351be1c..820db29b3 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -228,7 +228,7 @@ void tune::run(ir::module &mod) { nts->set_value(1); } else { - ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4); + ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2); ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4); connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++); } @@ -247,14 +247,14 @@ void tune::run(ir::module &mod) { size_t addr_space = ptr_ty->get_pointer_address_space(); if(addr_space < 4){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 8)); + std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8)); *params_.at(i).at("nts.d0") = *tmp; } } if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 8)); - std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 8)); + std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8)); + std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8)); *params_.at(i).at("nts.d0") = 
*tmp1; *params_.at(i).at("nts.d1") = *tmp2; } diff --git a/lib/dnn/blocksparse/dot.cpp b/lib/dnn/blocksparse/dot.cpp index 3ea79bc78..c7e3a9a85 100644 --- a/lib/dnn/blocksparse/dot.cpp +++ b/lib/dnn/blocksparse/dot.cpp @@ -88,8 +88,8 @@ void dot::triton_c_src(std::ostream &os) const { std::string bca1 = "newaxis, :"; std::string bcb0 = (op_ == FPROP) ? ":, newaxis" : "newaxis, :"; std::string bcb1 = (op_ == FPROP) ? "newaxis, :" : ":, newaxis"; - std::string ldb0 = (op_ == FPROP) ? "1" : "TK"; - std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ; + std::string ldb0 = (op_ == FPROP) ? "" : "*TK"; + std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ; std::string result = R"( const tunable int32 TM = {16, 32, 64, 128}; @@ -110,7 +110,7 @@ void dot::triton_c_src(std::ostream &os) const { int32 rkb[TK] = 0 ... TK; int1 checka[TM, TK] = (rxa < N)[:, newaxis]; int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda; - int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 + " + rkb[" + bcb1 + "]*" + ldb1 + R"(; + int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(; int32 *header = lut + ridy * 4; int32 offset = *(header + 0); int32 K = *(header + 1);