[CODEGEN] Pipeline fixup (#336)
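
Group each pipelined load with its pointer phi and its consuming dot in a
pipeline_info_t, and key the copy-async instruction-motion configs by
pipeline candidate instead of by basic block, so that a dot whose operands
are never loaded from memory no longer picks up a missing or unrelated load
as its move destination. Also adds a regression test for tl.dot without tl.load.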

Author: daadaada
Committed: 2021-10-10 16:47:11 +08:00 (via GitHub)
Commit: 9e9d781912
Parent: d5f20dbce0

2 changed files with 42 additions and 17 deletions


@@ -101,6 +101,15 @@ void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir:
   }
 }
 
+struct pipeline_info_t {
+  ir::load_inst* load;
+  ir::phi_node* ptr;
+  ir::dot_inst* dot;
+
+  pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
+    : load(load), ptr(ptr), dot(dot) {}
+};
+
 void pipeline::run(ir::module &mod) {
   // *Very* conservative heuristics for pre-fetching.
   // A load instruction can be pipelined if:
@@ -108,14 +117,15 @@ void pipeline::run(ir::module &mod) {
   //    in its basic block (i.e., pointer induction variable)
   //  - the load has only a single use in a dot instruction
   // As more use cases become apparent, this pass will be improved
-  std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
+  std::vector<pipeline_info_t> to_pipeline;
   ir::for_each_instruction(mod, [&](ir::instruction *i){
     if(auto* load = dynamic_cast<ir::load_inst*>(i)){
       ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
       auto users = load->get_users();
+      auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
       if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
-         && users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
-        to_pipeline.push_back({load, ptr});
+         && users.size() == 1 && dot)
+        to_pipeline.push_back({load, ptr, dot});
   }});
   // do the pipelining
   std::vector<ir::phi_node*> new_loads;
@@ -123,8 +133,8 @@ void pipeline::run(ir::module &mod) {
   const int num_stages = num_stages_;
   std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
   for(auto info: to_pipeline){
-    ir::load_inst* load = info.first;
-    ir::phi_node* ptr = info.second;
+    ir::load_inst* load = info.load;
+    ir::phi_node* ptr = info.ptr;
     ir::basic_block* block = load->get_parent();
     ir::basic_block* header = block->get_predecessors()[0];
     auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
@@ -288,22 +298,23 @@ void pipeline::run(ir::module &mod) {
     std::vector<ir::instruction*> insts;
     ir::load_inst* dst;
   };
-  std::map<ir::basic_block*, move_config_t> to_move;
+  std::vector<move_config_t> to_move(to_pipeline.size());
   if(has_copy_async_){
-    for(ir::function* fn: mod.get_function_list())
-    for(ir::basic_block* bb: fn->blocks())
-    for(ir::instruction* inst: bb->get_inst_list()){
-      if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
-        recursive_deps(i, bb, to_move[bb].insts);
-      if(auto* i = dynamic_cast<ir::load_inst*>(inst))
-        to_move[bb].dst = i;
+    for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
+      auto info = to_pipeline[idx];
+      ir::load_inst* load = info.load;
+      ir::phi_node* ptr = info.ptr;
+      ir::dot_inst* dot = info.dot;
+      ir::basic_block* bb = dot->get_parent();
+      recursive_deps(dot, bb, to_move[idx].insts);
+      to_move[idx].dst = load;
     }
   }
 
-  for(auto& x: to_move){
-    builder.set_insert_point_after(x.second.dst);
-    for(ir::instruction* i: x.second.insts){
-      x.first->erase(i);
+  for(auto& move_config: to_move){
+    builder.set_insert_point_after(move_config.dst);
+    for(ir::instruction* i: move_config.insts){
+      i->get_parent()->erase(i);
       builder.insert(i);
     }
   }
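For reference, the conservative heuristic above matches loops of the following shape: the tile pointers become phi nodes (pointer induction variables, updated at the bottom of the loop body), and each load's only user is a tl.dot. A minimal sketch in Triton's Python DSL; the kernel name, shapes, and pointer arithmetic are illustrative assumptions, not code from this commit:

import triton
import triton.language as tl

@triton.jit
def matmul_like(a_ptr, b_ptr, c_ptr, K, **meta):
    rm = tl.arange(0, 32)[:, None]       # output row offsets
    rn = tl.arange(0, 32)[None, :]       # output column offsets
    rk = tl.arange(0, 32)
    pa = a_ptr + rm * K + rk[None, :]    # becomes a phi node: updated in the loop body
    pb = b_ptr + rk[:, None] * 32 + rn   # becomes a phi node: updated in the loop body
    acc = tl.zeros((32, 32), tl.float32)
    for k in range(0, K, 32):
        a = tl.load(pa)                  # sole use is the dot below -> pipelining candidate
        b = tl.load(pb)                  # sole use is the dot below -> pipelining candidate
        acc += tl.dot(a, b)
        pa += 32                         # advance along K (row-major A)
        pb += 32 * 32                    # advance 32 rows (row-major B)
    tl.store(c_ptr + rm * 32 + rn, acc)

# launched e.g. as: matmul_like[(1,)](a, b, c, K)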


@@ -515,6 +515,20 @@ def test_dot(epilogue, device='cuda'):
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
 
+def test_dot_without_load():
+    @triton.jit
+    def kernel(out, **meta):
+        pid = tl.program_id(axis=0)
+        a = tl.zeros((32, 32), tl.float32)
+        b = tl.zeros((32, 32), tl.float32)
+        c = tl.zeros((32, 32), tl.float32)
+        c = tl.dot(a, b)
+        pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
+        tl.store(pout, c)
+
+    out = torch.ones((32,32), dtype=torch.float32, device="cuda")
+    kernel[(1,)](out)
+
 # ---------------
 # test arange
 # ---------------
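The new test is a pure smoke test for the case the fix above targets: a tl.dot whose operands are never loaded from memory. With the old block-keyed to_move map, such a block appears to leave move_config_t::dst unset, so instruction motion could dereference a null destination; keying the configs by pipeline candidate avoids that. If one also wanted to check the stored values, a hedged addition (not part of the commit) could assert that the zero-valued dot result overwrites the ones in out:

# a and b are all zeros, so the dot result must overwrite the ones in `out`
assert torch.equal(out, torch.zeros((32, 32), dtype=torch.float32, device="cuda"))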