[CODEGEN] Fix an inliner to call a function with a phi-node (#727)
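Context for the change below: before this commit the inliner mishandled phi-nodes in a called @triton.jit function, which is what you get when the callee's body contains control flow such as a loop. A minimal sketch of the pattern, in Python/Triton; the names here are illustrative only (not from the commit), and the test added at the bottom of this commit (val_multiplier / vecmul_kernel / test_call) exercises the same shape:

import triton
import triton.language as tl

@triton.jit
def scale_loop(x, n):
    # the loop introduces a phi-node for x in this function's IR
    for i in range(1, n):
        x = x * i
    return x

@triton.jit
def scale_kernel(ptr, n):
    x = tl.load(ptr)
    # inlining scale_loop at this call site is the case the commit fixes
    tl.store(ptr, scale_loop(x, n))
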
@@ -53,36 +53,56 @@ void inliner::do_inline(ir::function* fn, ir::call_inst* callsite, ir::builder&
  for(size_t k = 0; k < fn->args().size(); k++)
    arg_map[fn->args()[k]] = callsite->ops()[k];
  std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
  // clone instructions
  for(size_t i = 0; i < new_blocks.size(); i++){
    ir::basic_block* old_block = fn->blocks()[i];
    ir::basic_block* new_block = new_blocks[i];
    builder.set_insert_point(new_block);
    for(ir::instruction* old_inst: old_block->get_inst_list()){
      // clone instruction
      ir::instruction* new_inst = old_inst->clone();
      // replace basic block
      for(size_t k = 0; k < new_blocks.size(); k++)
        new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
      // replace values
      for(size_t k = 0; k < new_inst->get_num_operands(); k++){
        ir::value* op = new_inst->get_operand(k);
        if(auto arg_op = dynamic_cast<ir::argument*>(op))
          new_inst->set_operand(k, arg_map.at(arg_op));
        if(auto inst_op = dynamic_cast<ir::instruction*>(op))
          if(inst_map.find(inst_op) != inst_map.end())
            new_inst->set_operand(k, inst_map.at(inst_op));
      }
      // `ret` instruction is a special case:
      // instead of returning we need to branch to after the function call
      if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)){
        if(ir::value* ret_val = ret->get_return_value())
          exit_val->add_incoming(ret_val, new_block);
        new_inst = ir::branch_inst::create(exit);
      }
      inst_map[old_inst] = new_inst;
      builder.insert(new_inst);
    }
  }
  // update basic blocks
  for(size_t i = 0; i < new_blocks.size(); i++) {
    for (ir::instruction* new_inst: new_blocks[i]->get_inst_list()) {
      // replace basic block uses
      for(size_t k = 0; k < new_blocks.size(); k++)
        new_inst->replace_uses_of_with(fn->blocks()[k], new_blocks[k]);
      if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(new_inst)) {
        // additionally replace basic blocks of phi-nodes since
        // replace_uses_of_with() does not replace them.
        for(unsigned in = 0; in < phi->get_num_incoming(); in++)
          for(size_t k = 0; k < new_blocks.size(); k++)
            if (phi->get_incoming_block(in) == fn->blocks()[k])
              phi->set_incoming_block(in, new_blocks[k]);
      }
    }
  }
  // replace operands of instructions after constructing inst_map
  for (auto& it: inst_map) {
    ir::instruction* new_inst = it.second;
    for(size_t k = 0; k < new_inst->get_num_operands(); k++) {
      ir::value* op = new_inst->get_operand(k);
      if(auto arg_op = dynamic_cast<ir::argument*>(op))
        new_inst->set_operand(k, arg_map.at(arg_op));
      if(auto inst_op = dynamic_cast<ir::instruction*>(op))
        if(inst_map.find(inst_op) != inst_map.end())
          new_inst->set_operand(k, inst_map.at(inst_op));
    }
    // handles a ret instruction.
    // instead of returning we need to branch to after the function call
    if(ir::return_inst* ret = dynamic_cast<ir::return_inst*>(new_inst)) {
      if(ir::value* ret_val = ret->get_return_value())
        exit_val->add_incoming(ret_val, new_inst->get_parent());
      // replace ret with branch
      ir::instruction* new_br_inst = ir::branch_inst::create(exit);
      builder.set_insert_point(new_inst->get_parent());
      builder.insert(new_br_inst);
      new_inst->erase_from_parent();
    }
  }
  if(exit_val->get_num_incoming() == 1)
    exit_val->replace_all_uses_with(exit_val->get_incoming_value(0));
  // done -- make sure insert point is properly set to exit block
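
A note on the structure above: operand replacement is deferred until inst_map contains every cloned instruction, so a phi-node whose incoming value is an instruction cloned after it (the usual situation when the callee has a loop back-edge) still gets remapped; rewriting operands while cloning, as the earlier loop does, can leave such operands pointing into the original function. The extra incoming-block loop is needed because, as its comment says, replace_uses_of_with() does not update the incoming blocks of a phi-node.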
@@ -1376,6 +1376,41 @@ def test_constexpr_scalar_shape():
    kernel[(1,)](x_tri, 32)
    np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)


# -------------
# test call
# -------------


@triton.jit
def val_multiplier(val, i):
    return val * i


@triton.jit
def vecmul_kernel(ptr, n_elements, rep):
    pid = tl.program_id(axis=0)
    offsets = pid * 128 + tl.arange(0, 128)
    mask = offsets < n_elements
    vec = tl.load(ptr + offsets, mask=mask)
    for i in range(1, rep):
        vec = val_multiplier(vec, i)
    tl.store(ptr + offsets, vec, mask=mask)


def test_call():

    @triton.jit
    def kernel(ptr, n_elements, num1, num2):
        vecmul_kernel(ptr, n_elements, num1)
        vecmul_kernel(ptr, n_elements, num2)

    size = 1024
    rand_val = numpy_random((size,), dtype_str="float32")
    rand_val_tri = to_triton(rand_val, device='cuda')
    kernel[(size // 128,)](rand_val_tri, size, 3, 5)

    ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
    np.testing.assert_equal(to_numpy(rand_val_tri), ans)
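    # note: with rep=3 the loop in vecmul_kernel multiplies each element by 1*2 and
    # with rep=5 by 1*2*3*4, so the two inlined calls scale rand_val by 48 in total,
    # which is the product `ans` spells out term by term.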


# -------------
# test if