[CODEGEN] Fixed bug that caused missing recoalescing for some transpose
operations
This commit is contained in:
committed by
Philippe Tillet
parent
0c5bd7563a
commit
0516ea96d0
@@ -19,7 +19,7 @@ namespace transform{
|
|||||||
|
|
||||||
class peephole {
|
class peephole {
|
||||||
private:
|
private:
|
||||||
bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
|
// bool rewrite_cts_cfs(ir::instruction *value, ir::builder &builder);
|
||||||
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
bool rewrite_trans_phi(ir::instruction* value, ir::builder &builder);
|
||||||
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||||
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
|
||||||
|
@@ -109,9 +109,10 @@ void coalesce::run(ir::module &mod) {
|
|||||||
// extract leading axes
|
// extract leading axes
|
||||||
std::map<int, std::vector<ir::io_inst*>> axes;
|
std::map<int, std::vector<ir::io_inst*>> axes;
|
||||||
for(ir::io_inst *i: io){
|
for(ir::io_inst *i: io){
|
||||||
if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->get_rank())
|
if(i->get_pointer_operand()->get_type()->get_tile_rank() == layout_->get(id)->get_rank()){
|
||||||
extract_ld(i, axes);
|
extract_ld(i, axes);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// update list of values to rematerialize
|
// update list of values to rematerialize
|
||||||
if(axes.empty())
|
if(axes.empty())
|
||||||
continue;
|
continue;
|
||||||
|
@@ -83,18 +83,18 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
|
//bool peephole::rewrite_cts_cfs(ir::instruction *value, ir::builder &builder){
|
||||||
auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
|
// auto cfs = dynamic_cast<ir::copy_from_shared_inst*>(value);
|
||||||
if(cfs) {
|
// if(cfs) {
|
||||||
ir::value *arg = cfs->get_operand(0);
|
// ir::value *arg = cfs->get_operand(0);
|
||||||
ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
|
// ir::copy_to_shared_inst* cts = dynamic_cast<ir::copy_to_shared_inst*>(arg);
|
||||||
if(!cts)
|
// if(!cts)
|
||||||
return false;
|
// return false;
|
||||||
cfs->replace_all_uses_with(cts->get_operand(0));
|
// cfs->replace_all_uses_with(cts->get_operand(0));
|
||||||
return true;
|
// return true;
|
||||||
}
|
// }
|
||||||
|
|
||||||
}
|
//}
|
||||||
|
|
||||||
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
|
||||||
auto x = dynamic_cast<ir::reduce_inst*>(value);
|
auto x = dynamic_cast<ir::reduce_inst*>(value);
|
||||||
@@ -196,7 +196,7 @@ void peephole::run(ir::module &mod) {
|
|||||||
continue;
|
continue;
|
||||||
bool was_modified = false;
|
bool was_modified = false;
|
||||||
was_modified = was_modified || rewrite_mult(i, builder);
|
was_modified = was_modified || rewrite_mult(i, builder);
|
||||||
was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
// was_modified = was_modified || rewrite_cts_cfs(i, builder);
|
||||||
was_modified = was_modified || rewrite_trans_phi(i, builder);
|
was_modified = was_modified || rewrite_trans_phi(i, builder);
|
||||||
was_modified = was_modified || rewrite_unit_red(i, builder);
|
was_modified = was_modified || rewrite_unit_red(i, builder);
|
||||||
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
|
||||||
|
@@ -218,7 +218,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
|||||||
codegen::transform::cts cts;
|
codegen::transform::cts cts;
|
||||||
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
|
||||||
// run passes
|
// run passes
|
||||||
// ir::print(module, std::cout);
|
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
disassociate.run(module);
|
disassociate.run(module);
|
||||||
dce.run(module);
|
dce.run(module);
|
||||||
|
Reference in New Issue
Block a user