From ae59f51c2dcc60397dda0227acbc823d59c18e35 Mon Sep 17 00:00:00 2001 From: Shintaro Iwasaki Date: Thu, 29 Sep 2022 21:36:40 -0700 Subject: [PATCH] [CODEGEN] Fix an inliner to call a function with a phi-node (#727) --- lib/codegen/transform/inline.cc | 60 +++++++++++++++++--------- python/test/unit/language/test_core.py | 35 +++++++++++++++ 2 files changed, 75 insertions(+), 20 deletions(-) diff --git a/lib/codegen/transform/inline.cc b/lib/codegen/transform/inline.cc index fa22e5354..c870a7758 100644 --- a/lib/codegen/transform/inline.cc +++ b/lib/codegen/transform/inline.cc @@ -53,36 +53,56 @@ void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder& for(size_t k = 0; k < fn->args().size(); k++) arg_map[fn->args()[k]] = callsite->ops()[k]; std::vector rpo = ir::cfg::reverse_post_order(fn); + // clone instructions for(size_t i = 0; i < new_blocks.size(); i++){ ir::basic_block* old_block = fn->blocks()[i]; ir::basic_block* new_block = new_blocks[i]; builder.set_insert_point(new_block); for(ir::instruction* old_inst: old_block->get_inst_list()){ - // clone instruction ir::instruction* new_inst = old_inst->clone(); - // replace basic block - for(size_t k = 0; k < new_blocks.size(); k++) - new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]); - // replace values - for(size_t k = 0; k < new_inst->get_num_operands(); k++){ - ir::value* op = new_inst->get_operand(k); - if(auto arg_op = dynamic_cast(op)) - new_inst->set_operand(k, arg_map.at(arg_op)); - if(auto inst_op = dynamic_cast(op)) - if(inst_map.find(inst_op) != inst_map.end()) - new_inst->set_operand(k, inst_map.at(inst_op)); - } - // `ret` instruction is a special case: - // instead of returning we need to branch to after the function call - if(ir::return_inst* ret = dynamic_cast(new_inst)){ - if(ir::value* ret_val = ret->get_return_value()) - exit_val->add_incoming(ret_val, new_block); - new_inst = ir::branch_inst::create(exit); - } inst_map[old_inst] = new_inst; 
builder.insert(new_inst); } } + // update basic blocks + for(size_t i = 0; i < new_blocks.size(); i++) { + for (ir::instruction* new_inst: new_blocks[i]->get_inst_list()) { + // replace basic use cases + for(size_t k = 0; k < new_blocks.size(); k++) + new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]); + if(ir::phi_node* phi = dynamic_cast(new_inst)) { + // additionally replace basic blocks of phi-nodes since + // replace_uses_of_with() does not replace them. + for(unsigned in = 0; in < phi->get_num_incoming(); in++) + for(size_t k = 0; k < new_blocks.size(); k++) + if (phi->get_incoming_block(in) == fn->blocks()[k]) + phi->set_incoming_block(in, new_blocks[k]); + } + } + } + // replace operands of instructions after constructing inst_map + for (auto& it: inst_map) { + ir::instruction* new_inst = it.second; + for(size_t k = 0; k < new_inst->get_num_operands(); k++) { + ir::value* op = new_inst->get_operand(k); + if(auto arg_op = dynamic_cast(op)) + new_inst->set_operand(k, arg_map.at(arg_op)); + if(auto inst_op = dynamic_cast(op)) + if(inst_map.find(inst_op) != inst_map.end()) + new_inst->set_operand(k, inst_map.at(inst_op)); + } + // handles a ret instruction. 
+ // instead of returning we need to branch to after the function call + if(ir::return_inst* ret = dynamic_cast(new_inst)) { + if(ir::value* ret_val = ret->get_return_value()) + exit_val->add_incoming(ret_val, new_inst->get_parent()); + // replace ret with branch + ir::instruction* new_br_inst = ir::branch_inst::create(exit); + builder.set_insert_point(new_inst->get_parent()); + builder.insert(new_br_inst); + new_inst->erase_from_parent(); + } + } if(exit_val->get_num_incoming() == 1) exit_val->replace_all_uses_with(exit_val->get_incoming_value(0)); // done -- make sure insert point is properly set to exit block diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d00de5be5..e73e661ab 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1376,6 +1376,41 @@ def test_constexpr_scalar_shape(): kernel[(1,)](x_tri, 32) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) +# ------------- +# test call +# ------------- + + +@triton.jit +def val_multiplier(val, i): + return val * i + + +@triton.jit +def vecmul_kernel(ptr, n_elements, rep): + pid = tl.program_id(axis=0) + offsets = pid * 128 + tl.arange(0, 128) + mask = offsets < n_elements + vec = tl.load(ptr + offsets, mask=mask) + for i in range(1, rep): + vec = val_multiplier(vec, i) + tl.store(ptr + offsets, vec, mask=mask) + + +def test_call(): + + @triton.jit + def kernel(ptr, n_elements, num1, num2): + vecmul_kernel(ptr, n_elements, num1) + vecmul_kernel(ptr, n_elements, num2) + + size = 1024 + rand_val = numpy_random((size,), dtype_str="float32") + rand_val_tri = to_triton(rand_val, device='cuda') + kernel[(size // 128,)](rand_val_tri, size, 3, 5) + + ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 + np.testing.assert_equal(to_numpy(rand_val_tri), ans) # ------------- # test if