From 9e9d7819124a267e538dd973104b74c20efc6acd Mon Sep 17 00:00:00 2001
From: daadaada
Date: Sun, 10 Oct 2021 16:47:11 +0800
Subject: [PATCH] [CODEGEN] Pipeline fixup (#336)

---
 lib/codegen/transform/pipeline.cc      | 45 ++++++++++++++++----------
 python/test/unit/language/test_core.py | 14 ++++++++
 2 files changed, 42 insertions(+), 17 deletions(-)

diff --git a/lib/codegen/transform/pipeline.cc b/lib/codegen/transform/pipeline.cc
index 096b7d7b6..c09cafe83 100644
--- a/lib/codegen/transform/pipeline.cc
+++ b/lib/codegen/transform/pipeline.cc
@@ -101,6 +101,15 @@ void finalize_iv_vals(ir::builder& builder, ir::basic_block* block, std::map
   }
 }
 
+struct pipeline_info_t {
+  ir::load_inst* load;
+  ir::phi_node* ptr;
+  ir::dot_inst* dot;
+
+  pipeline_info_t(ir::load_inst* load, ir::phi_node* ptr, ir::dot_inst* dot)
+    : load(load), ptr(ptr), dot(dot) {}
+};
+
 void pipeline::run(ir::module &mod) {
   if (num_stages_ <= 1)
     return;
@@ -108,14 +117,15 @@ void pipeline::run(ir::module &mod) {
   //   - the pointer is a phi node that references a value
   //     in its basic block (i.e., pointer induction variable)
   //   - the load has only a single use in a dot instruction
-  std::vector<std::pair<ir::load_inst*, ir::phi_node*>> to_pipeline;
+  std::vector<pipeline_info_t> to_pipeline;
   ir::for_each_instruction(mod, [&](ir::instruction *i){
     if(auto* load = dynamic_cast<ir::load_inst*>(i)){
       ir::phi_node* ptr = dynamic_cast<ir::phi_node*>(load->get_pointer_operand());
       auto users = load->get_users();
+      auto dot = dynamic_cast<ir::dot_inst*>(*users.begin());
       if(ptr && ptr->get_incoming_block(1) == ptr->get_parent()
-         && users.size() == 1 && dynamic_cast<ir::dot_inst*>(*users.begin()))
-        to_pipeline.push_back({load, ptr});
+         && users.size() == 1 && dot)
+        to_pipeline.push_back({load, ptr, dot});
     }});
   // do the pipelining
   std::vector<ir::load_inst*> new_loads;
@@ -123,8 +133,8 @@ void pipeline::run(ir::module &mod) {
   const int num_stages = num_stages_;
   std::vector<std::vector<std::pair<ir::phi_node*, ir::value*>>> preheader_loads; // Used to reorder loads
   for(auto info: to_pipeline){
-    ir::load_inst* load = info.first;
-    ir::phi_node* ptr = info.second;
+    ir::load_inst* load = info.load;
+    ir::phi_node* ptr = info.ptr;
     ir::basic_block* block = load->get_parent();
     ir::basic_block* header = block->get_predecessors()[0];
     auto* block_br = dynamic_cast<ir::cond_branch_inst*>(block->get_inst_list().back());
@@ -288,22 +298,23 @@ void pipeline::run(ir::module &mod) {
     std::vector<ir::instruction*> insts;
     ir::load_inst* dst;
   };
-  std::map<ir::basic_block*, move_config_t> to_move;
+  std::vector<move_config_t> to_move(to_pipeline.size());
 
   if(has_copy_async_){
-    for(ir::function* fn: mod.get_function_list())
-    for(ir::basic_block* bb: fn->blocks())
-    for(ir::instruction* inst: bb->get_inst_list()){
-      if(auto* i = dynamic_cast<ir::dot_inst*>(inst))
-        recursive_deps(i, bb, to_move[bb].insts);
-      if(auto* i = dynamic_cast<ir::load_inst*>(inst))
-        to_move[bb].dst = i;
+    for (size_t idx = 0; idx < to_pipeline.size(); ++idx) {
+      auto info = to_pipeline[idx];
+      ir::load_inst* load = info.load;
+      ir::phi_node* ptr = info.ptr;
+      ir::dot_inst* dot = info.dot;
+      ir::basic_block* bb = dot->get_parent();
+      recursive_deps(dot, bb, to_move[idx].insts);
+      to_move[idx].dst = load;
     }
 
-    for(auto& x: to_move){
-      builder.set_insert_point_after(x.second.dst);
-      for(ir::instruction* i: x.second.insts){
-        x.first->erase(i);
+    for(auto& move_config: to_move){
+      builder.set_insert_point_after(move_config.dst);
+      for(ir::instruction* i: move_config.insts){
+        i->get_parent()->erase(i);
         builder.insert(i);
       }
     }
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 3744e4d03..5c6072b3e 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -515,6 +515,20 @@ def test_dot(epilogue, device='cuda'):
     assert 'ld.global.v4' in ptx
     assert 'st.global.v4' in ptx
 
+def test_dot_without_load():
+    @triton.jit
+    def kernel(out, **meta):
+        pid = tl.program_id(axis=0)
+        a = tl.zeros((32, 32), tl.float32)
+        b = tl.zeros((32, 32), tl.float32)
+        c = tl.zeros((32, 32), tl.float32)
+        c = tl.dot(a, b)
+        pout = out + tl.arange(0, 32)[:, None]*32 + tl.arange(0, 32)[None, :]
+        tl.store(pout, c)
+
+    out = torch.ones((32,32), dtype=torch.float32, device="cuda")
+    kernel[(1,)](out)
+
 # ---------------
 # test arange
 # ---------------