[CODEGEN] Pipeline fixup (#336)
This commit is contained in:
		@@ -101,6 +101,15 @@ void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map<ir:
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct pipeline_info_t {
 | 
			
		||||
  ir::load_inst* load;
 | 
			
		||||
  ir::phi_node* ptr;
 | 
			
		||||
  ir::dot_inst* dot;
 | 
			
		||||
 | 
			
		||||
  pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
 | 
			
		||||
    : load(load), ptr(ptr), dot(dot) {}
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
void pipeline::run(ir::module &mod) {
 | 
			
		||||
  // *Very* conservative heuristics for pre-fetching.
 | 
			
		||||
  // A load instruction can be pipelined if:
 | 
			
		||||
@@ -108,14 +117,15 @@ void pipeline::run(ir::module &mod) {
 | 
			
		||||
  //     in its basic block (i.e., pointer induction variable)
 | 
			
		||||
  //   - the load has only  a single use in a dot instruction
 | 
			
		||||
  // As more use cases become apparent, this pass will be improved
 | 
			
		||||
  std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
 | 
			
		||||
  std::vector<pipeline_info_t> to_pipeline;
 | 
			
		||||
  ir::for_each_instruction(mod, [&](ir::instruction *i){
 | 
			
		||||
    if(auto* load = dynamic_cast<ir::load_inst*>(i)){
 | 
			
		||||
      ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
 | 
			
		||||
      auto users = load->get_users();
 | 
			
		||||
      auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
 | 
			
		||||
      if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
 | 
			
		||||
         && users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
 | 
			
		||||
        to_pipeline.push_back({load, ptr});
 | 
			
		||||
         && users.size() == 1 && dot)
 | 
			
		||||
        to_pipeline.push_back({load, ptr, dot});
 | 
			
		||||
    }});
 | 
			
		||||
  // do the pipelining
 | 
			
		||||
  std::vector<ir::phi_node*> new_loads;
 | 
			
		||||
@@ -123,8 +133,8 @@ void pipeline::run(ir::module &mod) {
 | 
			
		||||
  const int num_stages = num_stages_;
 | 
			
		||||
  std::vector<std::pair<ir::phi_node*, std::vector<ir::value*>>> preheader_loads; // Used to reorder loads
 | 
			
		||||
  for(auto info: to_pipeline){
 | 
			
		||||
    ir::load_inst* load = info.first;
 | 
			
		||||
    ir::phi_node* ptr   = info.second;
 | 
			
		||||
    ir::load_inst* load = info.load;
 | 
			
		||||
    ir::phi_node* ptr   = info.ptr;
 | 
			
		||||
    ir::basic_block* block = load->get_parent();
 | 
			
		||||
    ir::basic_block* header = block->get_predecessors()[0];
 | 
			
		||||
    auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
 | 
			
		||||
@@ -288,22 +298,23 @@ void pipeline::run(ir::module &mod) {
 | 
			
		||||
    std::vector<ir::instruction*> insts;
 | 
			
		||||
    ir::load_inst* dst;
 | 
			
		||||
  };
 | 
			
		||||
  std::map<ir::basic_block*, move_config_t> to_move;
 | 
			
		||||
  std::vector<move_config_t> to_move(to_pipeline.size());
 | 
			
		||||
 | 
			
		||||
  if(has_copy_async_){
 | 
			
		||||
    for(ir::function* fn: mod.get_function_list())
 | 
			
		||||
    for(ir::basic_block* bb: fn->blocks())
 | 
			
		||||
    for(ir::instruction* inst: bb->get_inst_list()){
 | 
			
		||||
      if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
 | 
			
		||||
        recursive_deps(i, bb, to_move[bb].insts);
 | 
			
		||||
      if(auto* i = dynamic_cast<ir::load_inst*>(inst))
 | 
			
		||||
        to_move[bb].dst = i;
 | 
			
		||||
    for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
 | 
			
		||||
      auto info = to_pipeline[idx];
 | 
			
		||||
      ir::load_inst* load = info.load;
 | 
			
		||||
      ir::phi_node* ptr = info.ptr;
 | 
			
		||||
      ir::dot_inst* dot = info.dot;
 | 
			
		||||
      ir::basic_block* bb = dot->get_parent();
 | 
			
		||||
      recursive_deps(dot, bb, to_move[idx].insts);
 | 
			
		||||
      to_move[idx].dst = load;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for(auto& x: to_move){
 | 
			
		||||
      builder.set_insert_point_after(x.second.dst);
 | 
			
		||||
      for(ir::instruction* i: x.second.insts){
 | 
			
		||||
        x.first->erase(i);
 | 
			
		||||
    for(auto& move_config: to_move){
 | 
			
		||||
      builder.set_insert_point_after(move_config.dst);
 | 
			
		||||
      for(ir::instruction* i: move_config.insts){
 | 
			
		||||
        i->get_parent()->erase(i);
 | 
			
		||||
        builder.insert(i);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
@@ -515,6 +515,20 @@ def test_dot(epilogue, device='cuda'):
 | 
			
		||||
    assert 'ld.global.v4' in ptx
 | 
			
		||||
    assert 'st.global.v4' in ptx
 | 
			
		||||
 | 
			
		||||
def test_dot_without_load():
 | 
			
		||||
    @triton.jit
 | 
			
		||||
    def kernel(out, **meta):
 | 
			
		||||
        pid = tl.program_id(axis=0)
 | 
			
		||||
        a = tl.zeros((32, 32), tl.float32)
 | 
			
		||||
        b = tl.zeros((32, 32), tl.float32)
 | 
			
		||||
        c = tl.zeros((32, 32), tl.float32)
 | 
			
		||||
        c = tl.dot(a, b)
 | 
			
		||||
        pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
 | 
			
		||||
        tl.store(pout, c)
 | 
			
		||||
        
 | 
			
		||||
    out = torch.ones((32,32), dtype=torch.float32, device="cuda")
 | 
			
		||||
    kernel[(1,)](out)
 | 
			
		||||
 | 
			
		||||
# ---------------
 | 
			
		||||
# test arange
 | 
			
		||||
# ---------------
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user