diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index ef73a7581..20b5bc72f 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int stream->synchronize(); triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8); // benchmark triton - double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream); + double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream); // benchmark cublas // NumericT alpha = 1; // NumericT beta = 0; diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index 1cc8c929a..f43f94e8f 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -65,7 +65,7 @@ public: vectorize(&tune), selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target), optimize_dot(&tune), - optimize_dce(), + dce(), optimize_trans(), alignment_info(), reassociate(&tune, &alignment_info), @@ -73,9 +73,9 @@ public: void target_independent(ir::module &module) { optimize_dot.run(module); - optimize_dce.run(module); optimize_trans.run(module); - optimize_dce.run(module); + dce.run(module); +// ir::print(module, std::cout); } void target_dependent(ir::module &module) { @@ -88,8 +88,7 @@ public: shmem_barriers.run(module); } vectorize.run(module); - optimize_dce.run(module); -// ir::print(module, std::cout); + dce.run(module); } codegen::selection selection; @@ -101,7 +100,7 @@ public: codegen::transform::shmem_barriers shmem_barriers; codegen::transform::vectorize vectorize; codegen::transform::optimize_dot optimize_dot; - codegen::transform::optimize_dce optimize_dce; + codegen::transform::optimize_dce dce; codegen::transform::optimize_trans optimize_trans; codegen::transform::reassociate reassociate; codegen::target* target_; diff --git a/lib/codegen/transform/trans.cpp b/lib/codegen/transform/trans.cpp index 43cba99b7..4edfa6a59 100644 --- a/lib/codegen/transform/trans.cpp +++ b/lib/codegen/transform/trans.cpp @@ -17,10 +17,9 @@ ir::value* optimize_trans::replace_phi(ir::value* value, incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm)); // create phi for transposed values builder.set_insert_point(phi); - ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name()); + ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size()); for(unsigned n = 0; n < phi->get_num_incoming(); n++) result->add_incoming(incs[n], phi->get_incoming_block(n)); - phi->replace_all_uses_with(result); return result; } else if(auto i = dynamic_cast(value)){ @@ -29,7 +28,6 @@ ir::value* optimize_trans::replace_phi(ir::value* value, it++; builder.set_insert_point(it); ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm); - i->replace_all_uses_with(trans); trans->set_operand(0, i); return trans; } diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp index dae023eef..86d031564 100644 --- a/lib/dnn/base.cpp +++ b/lib/dnn/base.cpp @@ -62,8 +62,7 @@ std::pair base::get_profile_impl(driver::stream *stream, std::v jit->add_module(name_.c_str(), src.c_str(), best.params); } else{ -// params_t params = heuristics(); - params_t params = {4, 2, 16, 4, 4, 16, 2, 2, 1, 1, 1, 8, 64, 8, 8, 1, 4, 2, 1}; + params_t params = heuristics(); jit->add_module(name_.c_str(), src.c_str(), params); } triton::driver::kernel* kernel = jit->get_function(name_.c_str());