[codegen][transform][trans] fixed incorrect replace_all_uses_with
This commit is contained in:
@@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
|
||||
stream->synchronize();
|
||||
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
|
||||
// benchmark triton
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
|
||||
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
|
||||
// benchmark cublas
|
||||
// NumericT alpha = 1;
|
||||
// NumericT beta = 0;
|
||||
|
@@ -65,7 +65,7 @@ public:
|
||||
vectorize(&tune),
|
||||
selection(&shmem_allocation, &tune, &shmem_info, &alignment_info, target),
|
||||
optimize_dot(&tune),
|
||||
optimize_dce(),
|
||||
dce(),
|
||||
optimize_trans(),
|
||||
alignment_info(),
|
||||
reassociate(&tune, &alignment_info),
|
||||
@@ -73,9 +73,9 @@ public:
|
||||
|
||||
void target_independent(ir::module &module) {
|
||||
optimize_dot.run(module);
|
||||
optimize_dce.run(module);
|
||||
optimize_trans.run(module);
|
||||
optimize_dce.run(module);
|
||||
dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
@@ -88,8 +88,7 @@ public:
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
vectorize.run(module);
|
||||
optimize_dce.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
dce.run(module);
|
||||
}
|
||||
|
||||
codegen::selection selection;
|
||||
@@ -101,7 +100,7 @@ public:
|
||||
codegen::transform::shmem_barriers shmem_barriers;
|
||||
codegen::transform::vectorize vectorize;
|
||||
codegen::transform::optimize_dot optimize_dot;
|
||||
codegen::transform::optimize_dce optimize_dce;
|
||||
codegen::transform::optimize_dce dce;
|
||||
codegen::transform::optimize_trans optimize_trans;
|
||||
codegen::transform::reassociate reassociate;
|
||||
codegen::target* target_;
|
||||
|
@@ -17,10 +17,9 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
incs.push_back(replace_phi(phi->get_incoming_value(n), builder, perm));
|
||||
// create phi for transposed values
|
||||
builder.set_insert_point(phi);
|
||||
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name());
|
||||
ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size());
|
||||
for(unsigned n = 0; n < phi->get_num_incoming(); n++)
|
||||
result->add_incoming(incs[n], phi->get_incoming_block(n));
|
||||
phi->replace_all_uses_with(result);
|
||||
return result;
|
||||
}
|
||||
else if(auto i = dynamic_cast<ir::instruction*>(value)){
|
||||
@@ -29,7 +28,6 @@ ir::value* optimize_trans::replace_phi(ir::value* value,
|
||||
it++;
|
||||
builder.set_insert_point(it);
|
||||
ir::instruction *trans = (ir::instruction*)builder.create_trans(i, perm);
|
||||
i->replace_all_uses_with(trans);
|
||||
trans->set_operand(0, i);
|
||||
return trans;
|
||||
}
|
||||
|
@@ -62,8 +62,7 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
|
||||
jit->add_module(name_.c_str(), src.c_str(), best.params);
|
||||
}
|
||||
else{
|
||||
// params_t params = heuristics();
|
||||
params_t params = {4, 2, 16, 4, 4, 16, 2, 2, 1, 1, 1, 8, 64, 8, 8, 1, 4, 2, 1};
|
||||
params_t params = heuristics();
|
||||
jit->add_module(name_.c_str(), src.c_str(), params);
|
||||
}
|
||||
triton::driver::kernel* kernel = jit->get_function(name_.c_str());
|
||||
|
Reference in New Issue
Block a user