[codegen][transform][trans] fixed incorrect replace_all_uses_with

This commit is contained in:
Philippe Tillet
2019-08-07 21:50:16 -07:00
parent 7578c27d3d
commit f93099bda1
4 changed files with 8 additions and 12 deletions

View File

@@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
// benchmark cublas
// NumericT alpha = 1;
// NumericT beta = 0;

View File

@@ -65,7 +65,7 @@ public:
vectorize(&tune),
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
optimize_dot(&tune),
optimize_dce(),
dce(),
optimize_trans(),
alignment_info(),
reassociate(&tune, &alignment_info),
@@ -73,9 +73,9 @@ public:
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_dce.run(module);
optimize_trans.run(module);
optimize_dce.run(module);
dce.run(module);
// ir::print(module, std::cout);
}
void target_dependent(ir::module &module) {
@@ -88,8 +88,7 @@ public:
shmem_barriers.run(module);
}
vectorize.run(module);
optimize_dce.run(module);
// ir::print(module, std::cout);
dce.run(module);
}
codegen::selection selection;
@@ -101,7 +100,7 @@ public:
codegen::transform::shmem_barriers shmem_barriers;
codegen::transform::vectorize vectorize;
codegen::transform::optimize_dot optimize_dot;
codegen::transform::optimize_dce optimize_dce;
codegen::transform::optimize_dce dce;
codegen::transform::optimize_trans optimize_trans;
codegen::transform::reassociate reassociate;
codegen::target* target_;

View File

@@ -17,10 +17,9 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm));
// create phi for transposed values
builder.set_insert_point(phi);
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name());
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
result->add_incoming(incs[n], phi->get_incoming_block(n));
phi->replace_all_uses_with(result);
return result;
}
else if(auto i = dynamic_cast<ir::instruction*>(value)){
@@ -29,7 +28,6 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
it++;
builder.set_insert_point(it);
ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
i->replace_all_uses_with(trans);
trans->set_operand(0, i);
return trans;
}

View File

@@ -62,8 +62,7 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else{
// params_t params = heuristics();
params_t params = {4, 2, 16, 4, 4, 16, 2, 2, 1, 1, 1, 8, 64, 8, 8, 1, 4, 2, 1};
params_t params = heuristics();
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());