[CODEGEN] Fix an inliner to call a function with a phi-node (#727)

This commit is contained in:
Shintaro Iwasaki
2022-09-29 21:36:40 -07:00
committed by GitHub
parent f45e31ba7c
commit ae59f51c2d
2 changed files with 75 additions and 20 deletions

View File

@@ -53,36 +53,56 @@ void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder&
for(size_t k = 0; k < fn->args().size(); k++)
arg_map[fn->args()[k]] = callsite->ops()[k];
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
// clone instructions
for(size_t i = 0; i < new_blocks.size(); i++){
ir::basic_block* old_block = fn->blocks()[i];
ir::basic_block* new_block = new_blocks[i];
builder.set_insert_point(new_block);
for(ir::instruction* old_inst: old_block->get_inst_list()){
// clone instruction
ir::instruction* new_inst = old_inst->clone();
// replace basic block
for(size_t k = 0; k < new_blocks.size(); k++)
new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
// replace values
for(size_t k = 0; k < new_inst->get_num_operands(); k++){
ir::value* op = new_inst->get_operand(k);
if(auto arg_op = dynamic_cast<ir::argument*>(op))
new_inst->set_operand(k, arg_map.at(arg_op));
if(auto inst_op = dynamic_cast<ir::instruction*>(op))
if(inst_map.find(inst_op) != inst_map.end())
new_inst->set_operand(k, inst_map.at(inst_op));
}
// `ret` instruction is a special case:
// instead of returning we need to branch to after the function call
if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)){
if(ir::value* ret_val = ret->get_return_value())
exit_val->add_incoming(ret_val, new_block);
new_inst = ir::branch_inst::create(exit);
}
inst_map[old_inst] = new_inst;
builder.insert(new_inst);
}
}
// update basic blocks
for(size_t i = 0; i < new_blocks.size(); i++) {
for (ir::instruction* new_inst: new_blocks[i]->get_inst_list()) {
// replace basic use cases
for(size_t k = 0; k < new_blocks.size(); k++)
new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(new_inst)) {
// additionally replace basic blocks of phi-nodes since
// replace_uses_of_with() does not replace them.
for(unsigned in = 0; in < phi->get_num_incoming(); in++)
for(size_t k = 0; k < new_blocks.size(); k++)
if (phi->get_incoming_block(in) == fn->blocks()[k])
phi->set_incoming_block(in, new_blocks[k]);
}
}
}
// replace operands of instructions after constructing inst_map
for (auto& it: inst_map) {
ir::instruction* new_inst = it.second;
for(size_t k = 0; k < new_inst->get_num_operands(); k++) {
ir::value* op = new_inst->get_operand(k);
if(auto arg_op = dynamic_cast<ir::argument*>(op))
new_inst->set_operand(k, arg_map.at(arg_op));
if(auto inst_op = dynamic_cast<ir::instruction*>(op))
if(inst_map.find(inst_op) != inst_map.end())
new_inst->set_operand(k, inst_map.at(inst_op));
}
// handles a ret instruciton.
// instead of returning we need to branch to after the function call
if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) {
if(ir::value* ret_val = ret->get_return_value())
exit_val->add_incoming(ret_val, new_inst->get_parent());
// replace ret with branch
ir::instruction* new_br_inst = ir::branch_inst::create(exit);
builder.set_insert_point(new_inst->get_parent());
builder.insert(new_br_inst);
new_inst->erase_from_parent();
}
}
if(exit_val->get_num_incoming() == 1)
exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
// done -- make sure insert point is properly set to exit block

View File

@@ -1376,6 +1376,41 @@ def test_constexpr_scalar_shape():
kernel[(1,)](x_tri, 32)
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)
# -------------
# test call
# -------------
@triton.jit
def val_multiplier(val, i):
return val * i
@triton.jit
def vecmul_kernel(ptr, n_elements, rep):
pid = tl.program_id(axis=0)
offsets = pid * 128 + tl.arange(0, 128)
mask = offsets < n_elements
vec = tl.load(ptr + offsets, mask=mask)
for i in range(1, rep):
vec = val_multiplier(vec, i)
tl.store(ptr + offsets, vec, mask=mask)
def test_call():
@triton.jit
def kernel(ptr, n_elements, num1, num2):
vecmul_kernel(ptr, n_elements, num1)
vecmul_kernel(ptr, n_elements, num2)
size = 1024
rand_val = numpy_random((size,), dtype_str="float32")
rand_val_tri = to_triton(rand_val, device='cuda')
kernel[(size // 128,)](rand_val_tri, size, 3, 5)
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
# -------------
# test if