2022-06-27 11:49:19 -07:00
|
|
|
#include "triton/codegen/analysis/layout.h"
|
2021-07-27 12:38:38 -07:00
|
|
|
#include "triton/codegen/transform/cts.h"
|
|
|
|
#include "triton/ir/module.h"
|
|
|
|
#include "triton/ir/function.h"
|
|
|
|
#include "triton/ir/basic_block.h"
|
|
|
|
#include "triton/ir/instructions.h"
|
2022-06-27 11:49:19 -07:00
|
|
|
#include "triton/ir/utils.h"
|
2021-07-27 12:38:38 -07:00
|
|
|
#include <iostream>
|
|
|
|
|
|
|
|
namespace triton {
|
|
|
|
namespace codegen{
|
|
|
|
namespace transform{
|
|
|
|
|
|
|
|
|
2022-06-27 11:49:19 -07:00
|
|
|
// Returns whether operand index `op` of instruction `i` is classified as a
// shared-memory operand:
//   - dot              : operand 0 or 1
//   - copy_from_shared : operand 0
//   - trans            : operand 0
// Any other instruction / operand index combination yields false.
bool cts::is_shmem_op(ir::instruction* i, int op) {
  const auto id = i->get_id();
  if(id == ir::INST_DOT)
    return op == 0 || op == 1;
  // the remaining candidates only qualify through their first operand
  const bool first_operand_only = id == ir::INST_COPY_FROM_SHARED ||
                                  id == ir::INST_TRANS;
  return first_operand_only && op == 0;
}
|
|
|
|
|
2022-06-27 11:49:19 -07:00
|
|
|
// Returns true when `v` is an instruction whose id is trans, copy_to_shared or
// masked_load_async. Values that are not instructions, and all other
// instructions, yield false.
bool cts::is_shmem_res(ir::value* v){
  auto* inst = dynamic_cast<ir::instruction*>(v);
  if(inst == nullptr)
    return false;
  const auto id = inst->get_id();
  return id == ir::INST_TRANS
      || id == ir::INST_COPY_TO_SHARED
      || id == ir::INST_MASKED_LOAD_ASYNC;
}
|
|
|
|
|
|
|
|
|
|
|
|
// insert shared-memory copies for the operands that need them (used by cts::run)
|
2022-06-27 11:49:19 -07:00
|
|
|
// Rewrites `parent` so that it consumes a copy of its operand `x`.
//
//   parent    : instruction whose uses of `x` are replaced by the copy
//   x         : value to copy
//   builder   : creates the copy instruction; its insert point is moved
//               whenever a new copy is created
//   to_shared : true  -> the copy is built with create_copy_to_shared
//               false -> the copy is built with create_copy_from_shared
//   copies    : cache mapping an instruction value to the copy already created
//               for it, so that a value is copied at most once
//
// Cases handled, in order:
//   1. `x` is not an instruction: a fresh copy is created with the insert
//      point set at `parent` (this path does not use the cache).
//   2. `x` is a phi node: the function recurses on every incoming value with
//      the phi itself as the parent; the phi is not copied.
//   3. `to_shared` is set and `x` is already a shared-memory result
//      (see is_shmem_res): nothing to do.
//   4. otherwise: reuse the cached copy of `x`, or create one right after the
//      instruction defining `x` and record it in `copies`.
void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared, std::map<ir::value*, ir::value*>& copies) {
  auto *inst = dynamic_cast<ir::instruction*>(x);
  // not an instruction
  if(!inst) {
    builder.set_insert_point(parent);
    ir::value *copy;
    if(to_shared)
      copy = builder.create_copy_to_shared(x);
    else
      copy = builder.create_copy_from_shared(x);
    parent->replace_uses_of_with(x, copy);
    return;
  }
  // phi node
  if(auto* phi = dynamic_cast<ir::phi_node*>(x)) {
    for(unsigned n = 0; n < phi->get_num_incoming(); ++n)
      add_copy(phi, phi->get_incoming_value(n), builder, to_shared, copies);
    return;
  }
  // already in shared memory
  if(to_shared && is_shmem_res(inst))
    return;
  // copy: only materialize a new instruction when `x` has no recorded copy yet.
  // Creating it unconditionally would leave an unused copy in the IR, because
  // std::map::insert does not overwrite an existing entry.
  auto cached = copies.find(x);
  if(cached == copies.end()) {
    builder.set_insert_point_after(inst);
    ir::value *copy;
    if(to_shared)
      copy = builder.create_copy_to_shared(x);
    else
      copy = builder.create_copy_from_shared(x);
    cached = copies.insert({x, copy}).first;
  }
  parent->replace_uses_of_with(x, cached->second);
}
|
|
|
|
|
|
|
|
// Runs the pass on `mod` in two sweeps over all instructions:
//   1. collect every value that is consumed through an operand slot that
//      requires a shared-memory copy (dot operands, and operand 0 of
//      copy_from_shared / trans);
//   2. for every instruction operand found in that set, call add_copy so the
//      instruction consumes a copy_to_shared of the operand instead.
void cts::run(ir::module &mod) {
  // Precompute where copies should be added
  std::set<ir::value*> shmem_ops;
  ir::for_each_instruction(mod, [&](ir::instruction* i) {
    const auto id = i->get_id();
    if(id == ir::INST_DOT){
      ir::dot_inst* dot = dynamic_cast<ir::dot_inst*>(i);
      ir::value* lhs = i->get_operand(0);
      ir::type* ty = lhs->get_type()->get_scalar_ty();
      analysis::mma_layout* mma_lhs = layouts_->get(lhs)->to_mma();
      // TODO: V100
      // The LHS is left out of the shared-memory set only when it has an mma
      // layout, has_sm80_ is set, its scalar type is 16 bits wide and the dot
      // does not transpose A.
      bool is_lhs_shmem = !(mma_lhs && has_sm80_ && ty->get_primitive_size_in_bits() == 16 && !dot->is_trans_a());
      if(is_lhs_shmem)
        shmem_ops.insert(lhs);
      // the RHS is always added
      shmem_ops.insert(i->get_operand(1));
    }
    if(id == ir::INST_COPY_FROM_SHARED || id == ir::INST_TRANS)
      shmem_ops.insert(i->get_operand(0));
  });

  // Add shared copies
  std::map<ir::value*, ir::value*> copies;
  ir::builder &builder = mod.get_builder();
  ir::for_each_instruction(mod, [&](ir::instruction* i) {
    size_t num_op = i->get_num_operands();
    for(size_t k = 0; k < num_op; k++){
      ir::value* op = i->get_operand(k);
      // copy to shared operands
      const bool needs_shared_copy = shmem_ops.count(op) != 0;
      if(needs_shared_copy)
        add_copy(i, op, builder, true, copies);
    }
  });
}
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
2021-12-22 01:56:10 +08:00
|
|
|
}
|