more improvements and regressions
This commit is contained in:
@@ -8,7 +8,17 @@ namespace triton {
|
||||
namespace codegen{
|
||||
|
||||
// Returns true when `v` is a transposition whose permutation swaps the first
// two axes and keeps every remaining axis in place, i.e. (1, 0, 2, 3, ...).
// Any other value -- including a trans_inst with a different permutation --
// yields false.
inline bool is_trans(ir::value *v){
  auto *x = dynamic_cast<ir::trans_inst*>(v);
  if(!x)
    return false;
  std::vector<ir::constant_int*> perm = x->get_perm();
  // with fewer than two axes there is nothing to swap, and indexing ref[1]
  // below would be out of bounds
  if(perm.size() < 2)
    return false;
  // reference permutation: identity with its first two entries exchanged
  std::vector<ir::constant_int*> ref;
  ref.reserve(perm.size());
  ir::type *int32_ty = ir::type::get_int32_ty(v->get_type()->get_context());
  for(size_t i = 0; i < perm.size(); i++)
    ref.push_back(ir::constant_int::get(int32_ty, i));
  std::swap(ref[0], ref[1]);
  // true if perm == ref; elements are compared as pointers
  // NOTE(review): presumably constant_int::get returns the same object for
  // equal (type, value) pairs -- confirm, otherwise this never matches
  return std::equal(perm.begin(), perm.end(), ref.begin());
}
|
||||
|
||||
inline bool is_hmma(ir::value *v){
|
||||
@@ -28,7 +38,6 @@ inline bool is_hmma(ir::value *v){
|
||||
|
||||
void optimize_dot::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
std::vector<ir::instruction*> to_delete;
|
||||
// iterate
|
||||
for(ir::function *fn: mod.get_function_list())
|
||||
for(ir::basic_block *block: fn->blocks())
|
||||
@@ -47,15 +56,12 @@ void optimize_dot::run(ir::module &mod) {
|
||||
ir::value *BB = B;
|
||||
if(trans_a){
|
||||
AA = ((ir::trans_inst*)A)->get_operand(0);
|
||||
to_delete.push_back((ir::instruction*)A);
|
||||
}
|
||||
if(trans_b){
|
||||
BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||
to_delete.push_back((ir::instruction*)B);
|
||||
}
|
||||
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
|
||||
dot->replace_all_uses_with(dot_atbt);
|
||||
to_delete.push_back(dot);
|
||||
}
|
||||
else{
|
||||
// dot(op(a), trans(b))
|
||||
@@ -63,28 +69,24 @@ void optimize_dot::run(ir::module &mod) {
|
||||
ir::value* BB = ((ir::trans_inst*)B)->get_operand(0);
|
||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
||||
dot->replace_all_uses_with(NT);
|
||||
to_delete.push_back((ir::instruction*)B);
|
||||
to_delete.push_back(dot);
|
||||
}
|
||||
// dot(op(a), b)
|
||||
if(!trans_b){
|
||||
// create permutations
|
||||
size_t size = B->get_type()->get_tile_shapes().size();
|
||||
std::vector<ir::constant_int*> perm(size);
|
||||
ir::type *int32_ty = ir::type::get_int32_ty(B->get_type()->get_context());
|
||||
for(size_t i = 0; i < size; i++)
|
||||
perm[i] = ir::constant_int::get(int32_ty, i);
|
||||
std::swap(perm[0], perm[1]);
|
||||
// replace NN -> NT (trans)
|
||||
ir::value* BB = builder.create_trans(B, perm);
|
||||
ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BB, D));
|
||||
dot->replace_all_uses_with(NT);
|
||||
to_delete.push_back(dot);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(ir::instruction* i: to_delete)
|
||||
i->erase_from_parent();
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user