[codegen][transform][trans] fixed incorrect replace_all_uses_with calls in optimize_trans::replace_phi

This commit is contained in:
Philippe Tillet
2019-08-07 21:50:16 -07:00
parent 7578c27d3d
commit f93099bda1
4 changed files with 8 additions and 12 deletions

View File

@@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize(); stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8); triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
// benchmark triton // benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream); double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
// benchmark cublas // benchmark cublas
// NumericT alpha = 1; // NumericT alpha = 1;
// NumericT beta = 0; // NumericT beta = 0;

View File

@@ -65,7 +65,7 @@ public:
vectorize(&tune), vectorize(&tune),
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target), selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
optimize_dot(&tune), optimize_dot(&tune),
optimize_dce(), dce(),
optimize_trans(), optimize_trans(),
alignment_info(), alignment_info(),
reassociate(&tune, &alignment_info), reassociate(&tune, &alignment_info),
@@ -73,9 +73,9 @@ public:
void target_independent(ir::module &module) { void target_independent(ir::module &module) {
optimize_dot.run(module); optimize_dot.run(module);
optimize_dce.run(module);
optimize_trans.run(module); optimize_trans.run(module);
optimize_dce.run(module); dce.run(module);
// ir::print(module, std::cout);
} }
void target_dependent(ir::module &module) { void target_dependent(ir::module &module) {
@@ -88,8 +88,7 @@ public:
shmem_barriers.run(module); shmem_barriers.run(module);
} }
vectorize.run(module); vectorize.run(module);
optimize_dce.run(module); dce.run(module);
// ir::print(module, std::cout);
} }
codegen::selection selection; codegen::selection selection;
@@ -101,7 +100,7 @@ public:
codegen::transform::shmem_barriers shmem_barriers; codegen::transform::shmem_barriers shmem_barriers;
codegen::transform::vectorize vectorize; codegen::transform::vectorize vectorize;
codegen::transform::optimize_dot optimize_dot; codegen::transform::optimize_dot optimize_dot;
codegen::transform::optimize_dce optimize_dce; codegen::transform::optimize_dce dce;
codegen::transform::optimize_trans optimize_trans; codegen::transform::optimize_trans optimize_trans;
codegen::transform::reassociate reassociate; codegen::transform::reassociate reassociate;
codegen::target* target_; codegen::target* target_;

View File

@@ -17,10 +17,9 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm)); incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm));
// create phi for transposed values // create phi for transposed values
builder.set_insert_point(phi); builder.set_insert_point(phi);
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name()); ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
for(unsigned n = 0; n < phi->get_num_incoming(); n++) for(unsigned n = 0; n < phi->get_num_incoming(); n++)
result->add_incoming(incs[n], phi->get_incoming_block(n)); result->add_incoming(incs[n], phi->get_incoming_block(n));
phi->replace_all_uses_with(result);
return result; return result;
} }
else if(auto i = dynamic_cast<ir::instruction*>(value)){ else if(auto i = dynamic_cast<ir::instruction*>(value)){
@@ -29,7 +28,6 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
it++; it++;
builder.set_insert_point(it); builder.set_insert_point(it);
ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm); ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
i->replace_all_uses_with(trans);
trans->set_operand(0, i); trans->set_operand(0, i);
return trans; return trans;
} }

View File

@@ -62,8 +62,7 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
jit->add_module(name_.c_str(), src.c_str(), best.params); jit->add_module(name_.c_str(), src.c_str(), best.params);
} }
else{ else{
// params_t params = heuristics(); params_t params = heuristics();
params_t params = {4, 2, 16, 4, 4, 16, 2, 2, 1, 1, 1, 8, 64, 8, 8, 1, 4, 2, 1};
jit->add_module(name_.c_str(), src.c_str(), params); jit->add_module(name_.c_str(), src.c_str(), params);
} }
triton::driver::kernel* kernel = jit->get_function(name_.c_str()); triton::driver::kernel* kernel = jit->get_function(name_.c_str());