[code generation] added ternary operator
This commit is contained in:
@@ -79,8 +79,8 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
|
||||
checkb1 = rkb < k;\
|
||||
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
||||
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
a = checka ? *pa : 0;\
|
||||
b = checkb ? *pb : 0;\
|
||||
}\
|
||||
checkc0 = rxc < M;\
|
||||
checkc1 = ryc < N;\
|
||||
@@ -220,6 +220,7 @@ int main() {
|
||||
triton::codegen::vectorize vectorize(&tune);
|
||||
triton::codegen::selection selection(&allocation, &tune, &buffer_info);
|
||||
|
||||
triton::ir::print(module, std::cout);
|
||||
|
||||
// tuning parameters
|
||||
tune.run(module);
|
||||
@@ -280,7 +281,7 @@ int main() {
|
||||
manager.run(llvm_module);
|
||||
|
||||
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
|
||||
// std::cout << src << std::endl;
|
||||
std::cout << src << std::endl;
|
||||
|
||||
// compile machine code
|
||||
CUdevice cu_device;
|
||||
|
@@ -603,23 +603,47 @@ ir::value* cast_operator::codegen(ir::module *mod) const{
|
||||
}
|
||||
|
||||
/* Conditional expression */
|
||||
ir::value *conditional_expression::llvm_op(ir::builder &builder, ir::value *cond, ir::value *true_value, ir::value *false_value, const std::string &name) const{
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ir::value *conditional_expression::codegen(ir::module *mod) const{
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::value *cond = cond_->codegen(mod);
|
||||
ir::value *false_value = false_value_->codegen(mod);
|
||||
ir::value *pred = cond_->codegen(mod);
|
||||
ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
|
||||
ir::value *true_mask = mask->get_result(0);
|
||||
ir::value *false_mask = mask->get_result(1);
|
||||
ir::value *true_value = true_value_->codegen(mod);
|
||||
ir::value *false_value = false_value_->codegen(mod);
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
|
||||
itn->set_mask_pred(true_mask);
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(false_value))
|
||||
itn->set_mask_pred(false_mask);
|
||||
bool is_float, is_ptr, is_int, is_signed;
|
||||
ir::value *uncasted_true_value = true_value;
|
||||
ir::value *uncasted_false_value = false_value;
|
||||
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
||||
implicit_broadcast(mod, true_value, false_value);
|
||||
ir::instruction *itn = dynamic_cast<ir::instruction*>(true_value);
|
||||
assert(itn);
|
||||
itn->set_mask_pred(cond);
|
||||
itn->set_mask_else(false_value);
|
||||
return itn;
|
||||
{
|
||||
ir::value *current = true_value;
|
||||
while(current != uncasted_true_value) {
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
|
||||
itn->set_mask_pred(true_mask);
|
||||
current = itn->get_operand(0);
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
}
|
||||
{
|
||||
ir::value *current = false_value;
|
||||
while(current != uncasted_false_value) {
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
|
||||
itn->set_mask_pred(false_mask);
|
||||
current = itn->get_operand(0);
|
||||
}
|
||||
else
|
||||
break;
|
||||
}
|
||||
}
|
||||
ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
|
||||
return result;
|
||||
}
|
||||
|
||||
/* Assignment expression */
|
||||
|
@@ -589,10 +589,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
BasicBlock *block_false = pmap_.at({mask_tile_false, idx});
|
||||
Value *value_false = value_tile_false->get_value(idx);
|
||||
BasicBlock *block_done = last_block_.at({mask_tile_true, idx});
|
||||
if(block_done->empty())
|
||||
builder.SetInsertPoint(block_done);
|
||||
else
|
||||
if(block_done->getTerminator())
|
||||
builder.SetInsertPoint(block_done->getTerminator());
|
||||
else
|
||||
builder.SetInsertPoint(block_done);
|
||||
PHINode *phi = builder.CreatePHI(value_true->getType(), 2);
|
||||
phi->addIncoming(value_true, block_true);
|
||||
phi->addIncoming(value_false,block_false);
|
||||
@@ -615,6 +615,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
// splat
|
||||
else if(dynamic_cast<ir::splat_inst*>(ins)) {
|
||||
result->for_each([&](indices_t idx) {
|
||||
set_mask_insert_pt(idx);
|
||||
result->set_value(idx, llvm_value(ins->get_operand(0), builder));
|
||||
});
|
||||
}
|
||||
@@ -703,7 +704,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
}
|
||||
|
||||
void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
||||
if(src->has_tile_result_or_op()) {
|
||||
if(src->has_tile_result_or_op() || (src->get_mask_pred() && src->get_mask_pred()->get_type()->is_tile_ty())) {
|
||||
lower_tile_instruction(src, builder);
|
||||
}
|
||||
else {
|
||||
|
@@ -32,8 +32,8 @@ void instruction::erase_from_parent() {
|
||||
|
||||
bool instruction::has_tile_result_or_op() {
|
||||
bool result = get_type()->is_tile_ty();
|
||||
for(ir::value *v: ops())
|
||||
result |= v->get_type()->is_tile_ty();
|
||||
for(unsigned i = 0; i < get_num_operands(); i++)
|
||||
result |= get_operand(i)->get_type()->is_tile_ty();
|
||||
return result;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user