[triton-c] predicate in assignment statement now propagates to rhs
computations
This commit is contained in:
@@ -69,7 +69,7 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
|
||||
*pcount = countp1;
|
||||
}
|
||||
else {
|
||||
*pc = c + (checkc ? *pc : 0);
|
||||
@checkc *pc = c + *pc;
|
||||
*pcount = countp1;
|
||||
}
|
||||
__atomic_cas(plock, 1, 0);
|
||||
|
@@ -410,13 +410,13 @@ class statement: public block_item{
|
||||
class expression_statement: public statement{
|
||||
public:
|
||||
expression_statement(node *expr, node *mask = nullptr)
|
||||
: expr_((expression*)expr), mask_((expression*)mask){ }
|
||||
: expr_((expression*)expr), pred_((expression*)mask){ }
|
||||
|
||||
ir::value* codegen(ir::module * mod) const;
|
||||
|
||||
private:
|
||||
expression *expr_;
|
||||
expression *mask_;
|
||||
expression *pred_;
|
||||
};
|
||||
|
||||
|
||||
|
@@ -335,15 +335,15 @@ public:
|
||||
};
|
||||
|
||||
// merge
|
||||
class merge_inst: public instruction {
|
||||
class psi_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "merge"; }
|
||||
merge_inst(ir::value *mask_true, ir::value *value_true,
|
||||
psi_inst(ir::value *mask_true, ir::value *value_true,
|
||||
ir::value *mask_false, ir::value *value_false,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
static merge_inst* create(ir::value *mask_true, ir::value *value_true,
|
||||
static psi_inst* create(ir::value *mask_true, ir::value *value_true,
|
||||
ir::value *mask_false, ir::value *value_false,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
ir::value *get_mask_true() { return get_operand(0); }
|
||||
|
@@ -70,6 +70,7 @@ public:
|
||||
shmem_barriers.run(module);
|
||||
}
|
||||
vectorize.run(module);
|
||||
ir::print(module, std::cout);
|
||||
}
|
||||
|
||||
codegen::tune tune;
|
||||
|
@@ -320,23 +320,33 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
/* expression statement */
|
||||
ir::value* expression_statement::codegen(ir::module *mod) const{
|
||||
ir::builder &builder = mod->get_builder();
|
||||
if(mask_) {
|
||||
ir::value *pred = mask_->codegen(mod);
|
||||
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
|
||||
ir::value *true_value = expr_->codegen(mod);
|
||||
ir::basic_block *block = builder.get_insert_block();
|
||||
if(pred_) {
|
||||
// check that it is an assignment
|
||||
assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_);
|
||||
assert(assignment);
|
||||
|
||||
ir::type *ty = true_value->get_type();
|
||||
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
|
||||
itn->set_mask_pred(mask->get_result(0));
|
||||
// generate mask
|
||||
ir::value *pred = pred_->codegen(mod);
|
||||
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
|
||||
// generate expression
|
||||
unsigned szbegin = block->get_inst_list().size();
|
||||
ir::value *expr = expr_->codegen(mod);
|
||||
ir::basic_block::iterator begin = block->begin();
|
||||
std::advance(begin, szbegin);
|
||||
// set mask
|
||||
ir::type *ty = expr->get_type();
|
||||
for(auto it = begin; it != builder.get_insert_point(); it++)
|
||||
(*it)->set_mask_pred(mask->get_result(0));
|
||||
// if(auto *itn = dynamic_cast<ir::instruction*>(expr))
|
||||
// itn->set_mask_pred(mask->get_result(0));
|
||||
if(ty->is_void_ty())
|
||||
return true_value;
|
||||
ir::merge_inst *merge = (ir::merge_inst*)builder.create_merge(mask->get_result(0), true_value,
|
||||
return expr;
|
||||
// merge with psi
|
||||
ir::psi_inst *psi = (ir::psi_inst*)builder.create_merge(mask->get_result(0), expr,
|
||||
mask->get_result(1), ir::undef_value::get(ty));
|
||||
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
|
||||
mod->set_value(name, merge);
|
||||
return merge;
|
||||
mod->set_value(name, psi);
|
||||
return psi;
|
||||
}
|
||||
return expr_->codegen(mod);
|
||||
}
|
||||
|
@@ -690,7 +690,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
});
|
||||
}
|
||||
// merge
|
||||
else if(auto *merge = dynamic_cast<ir::merge_inst*>(ins)) {
|
||||
else if(auto *merge = dynamic_cast<ir::psi_inst*>(ins)) {
|
||||
distributed_tile* mask_tile_true = (distributed_tile*)tmap_.at(merge->get_mask_true());
|
||||
distributed_tile *value_tile_true = (distributed_tile*)tmap_.at(merge->get_value_true());
|
||||
distributed_tile* mask_tile_false = (distributed_tile*)tmap_.at(merge->get_mask_false());
|
||||
@@ -951,7 +951,7 @@ void selection::run(ir::module &src, Module &dst) {
|
||||
dst_builder.SetInsertPoint(parent);
|
||||
for(ir::instruction *i: block->get_inst_list()){
|
||||
BasicBlock *current = dst_builder.GetInsertBlock();
|
||||
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::merge_inst*>(i)) && !current->empty();
|
||||
bool phi_inserted = (dynamic_cast<ir::phi_node*>(i) || dynamic_cast<ir::psi_inst*>(i)) && !current->empty();
|
||||
if(phi_inserted && current->getFirstNonPHI())
|
||||
dst_builder.SetInsertPoint(&*current->getFirstNonPHI());
|
||||
lower_instruction(i, dst_builder);
|
||||
|
@@ -253,7 +253,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
|
||||
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
|
||||
|
||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||
std::cout << source << std::endl;
|
||||
// std::cout << source << std::endl;
|
||||
cu_context::context_switcher ctx_switch(*context);
|
||||
// JIT compile source-code
|
||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||
|
@@ -92,7 +92,7 @@ value *builder::create_mask(value *pred, const std::string &name){
|
||||
}
|
||||
|
||||
value *builder::create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name) {
|
||||
return insert(merge_inst::create(mask_true, value_true, mask_false, value_false, name));
|
||||
return insert(psi_inst::create(mask_true, value_true, mask_false, value_false, name));
|
||||
}
|
||||
|
||||
|
||||
|
@@ -334,7 +334,7 @@ mask_inst* mask_inst::create(value *pred, const std::string &name, instruction *
|
||||
}
|
||||
|
||||
// merge_inst
|
||||
merge_inst::merge_inst(value *mask_true, value *value_true,
|
||||
psi_inst::psi_inst(value *mask_true, value *value_true,
|
||||
value *mask_false, value *value_false,
|
||||
const std::string &name, instruction *next)
|
||||
: instruction(value_true->get_type(), 4, 1, name, next) {
|
||||
@@ -344,10 +344,10 @@ merge_inst::merge_inst(value *mask_true, value *value_true,
|
||||
set_operand(3, value_false);
|
||||
}
|
||||
|
||||
merge_inst* merge_inst::create(value *mask_true, value *value_true,
|
||||
psi_inst* psi_inst::create(value *mask_true, value *value_true,
|
||||
value *mask_false, value *value_false,
|
||||
const std::string &name, instruction *next) {
|
||||
return new merge_inst(mask_true, value_true, mask_false, value_false, name, next);
|
||||
return new psi_inst(mask_true, value_true, mask_false, value_false, name, next);
|
||||
}
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user