[code generation] added ternary operator

This commit is contained in:
Philippe Tillet
2019-03-01 21:53:35 -05:00
parent 08fcfbca47
commit 2467c5e504
4 changed files with 46 additions and 20 deletions

View File

@@ -79,8 +79,8 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
checkb1 = rkb < k;\
checka = checka0[:, newaxis] && checka1[newaxis, :];\
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
@checka a = *pa;\
@checkb b = *pb;\
a = checka ? *pa : 0;\
b = checkb ? *pb : 0;\
}\
checkc0 = rxc < M;\
checkc1 = ryc < N;\
@@ -220,6 +220,7 @@ int main() {
triton::codegen::vectorize vectorize(&tune);
triton::codegen::selection selection(&allocation, &tune, &buffer_info);
triton::ir::print(module, std::cout);
// tuning parameters
tune.run(module);
@@ -280,7 +281,7 @@ int main() {
manager.run(llvm_module);
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
// std::cout << src << std::endl;
std::cout << src << std::endl;
// compile machine code
CUdevice cu_device;

View File

@@ -603,23 +603,47 @@ ir::value* cast_operator::codegen(ir::module *mod) const{
}
/* Conditional expression */
// Stub for the conditional expression: none of the arguments are inspected
// and the function always returns a null value.
ir::value *conditional_expression::llvm_op(ir::builder &builder,
                                           ir::value *cond,
                                           ir::value *true_value,
                                           ir::value *false_value,
                                           const std::string &name) const
{
  return nullptr;
}
// Generates IR for `cond ? true_value : false_value`.
//
// The condition is evaluated exactly once and turned into a mask instruction.
// Result 0 of that instruction predicates the true operand, result 1
// predicates the false operand. Both operands are then brought to a common
// type/shape and combined with a merge instruction, which is returned.
ir::value *conditional_expression::codegen(ir::module *mod) const{
  ir::builder &builder = mod->get_builder();

  // Condition -> pair of predicates (one per branch).
  ir::value *pred = cond_->codegen(mod);
  ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
  ir::value *true_mask = mask->get_result(0);
  ir::value *false_mask = mask->get_result(1);

  // Generate both operands and predicate them when they are instructions.
  ir::value *true_value = true_value_->codegen(mod);
  ir::value *false_value = false_value_->codegen(mod);
  if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
    itn->set_mask_pred(true_mask);
  if(auto *itn = dynamic_cast<ir::instruction*>(false_value))
    itn->set_mask_pred(false_mask);

  // Remember the operands as they were before conversion: implicit_cast and
  // implicit_broadcast may replace true_value/false_value with new values.
  bool is_float, is_ptr, is_int, is_signed;
  ir::value *uncasted_true_value = true_value;
  ir::value *uncasted_false_value = false_value;
  implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
  implicit_broadcast(mod, true_value, false_value);

  // Walks from `from` back towards `until` through operand 0, predicating
  // every instruction met on the way with `branch_mask`. Stops early when a
  // non-instruction value is reached.
  auto mask_chain = [](ir::value *from, ir::value *until, ir::value *branch_mask) {
    ir::value *current = from;
    while(current != until) {
      auto *itn = dynamic_cast<ir::instruction*>(current);
      if(!itn)
        break;
      itn->set_mask_pred(branch_mask);
      current = itn->get_operand(0);
    }
  };
  mask_chain(true_value, uncasted_true_value, true_mask);
  mask_chain(false_value, uncasted_false_value, false_mask);

  // Select between the two predicated values.
  ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
  return result;
}
/* Assignment expression */

View File

@@ -589,10 +589,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
Value *value_false = value_tile_false->get_value(idx);
BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
if(block_done->empty())
builder.SetInsertPoint(block_done);
else
if(block_done->getTerminator())
builder.SetInsertPoint(block_done->getTerminator());
else
builder.SetInsertPoint(block_done);
PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
phi->addIncoming(value_true, block_true);
phi->addIncoming(value_false,block_false);
@@ -615,6 +615,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// splat
else if(dynamic_cast<ir::splat_inst*>(ins)) {
result->for_each([&](indices_t idx) {
set_mask_insert_pt(idx);
result->set_value(idx, llvm_value(ins->get_operand(0), builder));
});
}
@@ -703,7 +704,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
}
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
if(src->has_tile_result_or_op()) {
if(src->has_tile_result_or_op() || (src->get_mask_pred() && src->get_mask_pred()->get_type()->is_tile_ty())) {
lower_tile_instruction(src, builder);
}
else {

View File

@@ -32,8 +32,8 @@ void instruction::erase_from_parent() {
// Returns true when the instruction's own type is a tile type, or when the
// type of at least one of its operands is a tile type.
bool instruction::has_tile_result_or_op() {
  if(get_type()->is_tile_ty())
    return true;
  // A single indexed pass over the operands; stop at the first tile operand.
  for(unsigned i = 0; i < get_num_operands(); i++)
    if(get_operand(i)->get_type()->is_tile_ty())
      return true;
  return false;
}