[code generation] implements hidden operands in user (e.g., mask)
This commit is contained in:
@@ -57,14 +57,25 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
|
|||||||
for(k = K; k > 0; k = k - 8){\
|
for(k = K; k > 0; k = k - 8){\
|
||||||
int1 checka[16, 8] = (k > 8);\
|
int1 checka[16, 8] = (k > 8);\
|
||||||
int1 checkb[16, 8] = (k > 8);\
|
int1 checkb[16, 8] = (k > 8);\
|
||||||
|
int1 checka0[16];\
|
||||||
|
int1 checka1[8];\
|
||||||
|
int1 checkb0[16];\
|
||||||
|
int1 checkb1[8];\
|
||||||
C = dot(a, b, C);\
|
C = dot(a, b, C);\
|
||||||
pa = pa + 8*M;\
|
pa = pa + 8*M;\
|
||||||
pb = pb + 8*K;\
|
pb = pb + 8*K;\
|
||||||
@checka a = *pa;\
|
@checka a = *pa;\
|
||||||
@checkb b = *pb;\
|
@checkb b = *pb;\
|
||||||
if(k > 8){\
|
if(k > 8)\
|
||||||
continue;\
|
continue;\
|
||||||
}\
|
checka0 = rxa < M;\
|
||||||
|
checka1 = rka < k;\
|
||||||
|
checkb0 = ryb < N;\
|
||||||
|
checkb1 = rkb < k;\
|
||||||
|
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
||||||
|
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
||||||
|
@checka a = *pa;\
|
||||||
|
@checkb b = *pb;\
|
||||||
}\
|
}\
|
||||||
@checkc *pc = C;\
|
@checkc *pc = C;\
|
||||||
}\
|
}\
|
||||||
@@ -211,7 +222,6 @@ int main() {
|
|||||||
if(errors.size())
|
if(errors.size())
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
|
|
||||||
// print
|
|
||||||
|
|
||||||
// run passes
|
// run passes
|
||||||
tdl::ir::print(module, std::cout);
|
tdl::ir::print(module, std::cout);
|
||||||
|
@@ -35,8 +35,8 @@ public:
|
|||||||
basic_block *get_parent() { return parent_; }
|
basic_block *get_parent() { return parent_; }
|
||||||
void erase_from_parent();
|
void erase_from_parent();
|
||||||
// mask
|
// mask
|
||||||
value* set_mask(value *pred, value *else_value = nullptr) { mask_ = {pred, else_value}; }
|
void set_mask_pred(value *pred) { resize_hidden(1); set_operand(get_num_operands(), pred); }
|
||||||
const mask_info_t get_mask() const { return mask_; }
|
value* get_mask_pred() const { if(get_num_hidden() == 0) return nullptr; return get_operand(get_num_operands()); }
|
||||||
// helpers
|
// helpers
|
||||||
bool has_tile_result_or_op();
|
bool has_tile_result_or_op();
|
||||||
// repr
|
// repr
|
||||||
@@ -45,7 +45,7 @@ public:
|
|||||||
private:
|
private:
|
||||||
basic_block *parent_;
|
basic_block *parent_;
|
||||||
value *pred_;
|
value *pred_;
|
||||||
mask_info_t mask_;
|
value *mask_pred_;
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -51,20 +51,23 @@ public:
|
|||||||
typedef ops_t::const_iterator const_op_iterator;
|
typedef ops_t::const_iterator const_op_iterator;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void resize_ops(unsigned n) { ops_.resize(n); }
|
void resize_ops(unsigned num_ops) { ops_.resize(num_ops + num_hidden_); num_ops_ = num_ops; }
|
||||||
|
void resize_hidden(unsigned num_hidden) { ops_.resize(num_ops_ + num_hidden); num_hidden_ = num_hidden; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// Constructor
|
// Constructor
|
||||||
user(type *ty, unsigned num_ops, const std::string &name = "")
|
user(type *ty, unsigned num_ops, const std::string &name = "")
|
||||||
: value(ty, name), ops_(num_ops){ }
|
: value(ty, name), ops_(num_ops), num_ops_(num_ops), num_hidden_(0){
|
||||||
|
}
|
||||||
|
|
||||||
// Operands
|
// Operands
|
||||||
const ops_t& ops() { return ops_; }
|
const ops_t& ops() { return ops_; }
|
||||||
op_iterator op_begin() { return ops_.begin(); }
|
op_iterator op_begin() { return ops_.begin(); }
|
||||||
op_iterator op_end() { return ops_.end(); }
|
op_iterator op_end() { return ops_.end(); }
|
||||||
void set_operand(unsigned i, value *x);
|
void set_operand(unsigned i, value *x);
|
||||||
value *get_operand(unsigned i);
|
value *get_operand(unsigned i) const;
|
||||||
unsigned get_num_operands() const ;
|
unsigned get_num_operands() const ;
|
||||||
|
unsigned get_num_hidden() const;
|
||||||
|
|
||||||
// Utils
|
// Utils
|
||||||
void replace_all_uses_with(value *target);
|
void replace_all_uses_with(value *target);
|
||||||
@@ -72,6 +75,8 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
ops_t ops_;
|
ops_t ops_;
|
||||||
|
unsigned num_ops_;
|
||||||
|
unsigned num_hidden_;
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@@ -264,7 +264,7 @@ ir::value* expression_statement::codegen(ir::module *mod) const{
|
|||||||
ir::instruction *itn = dynamic_cast<ir::instruction*>(expr);
|
ir::instruction *itn = dynamic_cast<ir::instruction*>(expr);
|
||||||
assert(itn);
|
assert(itn);
|
||||||
ir::value *mask = mask_->codegen(mod);
|
ir::value *mask = mask_->codegen(mod);
|
||||||
itn->set_mask(mask);
|
itn->set_mask_pred(mask);
|
||||||
}
|
}
|
||||||
return expr;
|
return expr;
|
||||||
}
|
}
|
||||||
|
@@ -484,16 +484,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
BasicBlock *block = builder.GetInsertBlock();
|
BasicBlock *block = builder.GetInsertBlock();
|
||||||
Module *module = block->getModule();
|
Module *module = block->getModule();
|
||||||
Function *function = block->getParent();
|
Function *function = block->getParent();
|
||||||
ir::instruction::mask_info_t mask = ins->get_mask();
|
ir::value* mask_pred = ins->get_mask_pred();
|
||||||
LLVMContext &ctx = builder.getContext();
|
LLVMContext &ctx = builder.getContext();
|
||||||
// helper to handle masks
|
// helper to handle masks
|
||||||
auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
|
auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
|
||||||
BasicBlock *block = builder.GetInsertBlock();
|
BasicBlock *block = builder.GetInsertBlock();
|
||||||
Value *result;
|
Value *result;
|
||||||
if(mask.pred){
|
if(mask_pred){
|
||||||
// if(mask.else_value)
|
// if(mask.else_value)
|
||||||
// std::cout << mask.else_value << std::endl;
|
// std::cout << mask.else_value << std::endl;
|
||||||
Value *llvm_mask = tmap_.at(mask.pred)->get_value(idx);
|
Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
|
||||||
BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
|
BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
|
||||||
BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
|
BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
|
||||||
builder.CreateCondBr(llvm_mask, then_bb, done_bb);
|
builder.CreateCondBr(llvm_mask, then_bb, done_bb);
|
||||||
|
@@ -73,7 +73,7 @@ void tune::init_c_graph(ir::instruction *v) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* Add mask constraints */
|
/* Add mask constraints */
|
||||||
if(ir::value *pred = v->get_mask().pred){
|
if(ir::value *pred = v->get_mask_pred()){
|
||||||
for(unsigned i = 0; i < shapes.size(); i++)
|
for(unsigned i = 0; i < shapes.size(); i++)
|
||||||
add_constraint({v->ops()[0], i}, {pred, i});
|
add_constraint({v->ops()[0], i}, {pred, i});
|
||||||
}
|
}
|
||||||
|
@@ -43,13 +43,17 @@ void user::set_operand(unsigned i, value *x) {
|
|||||||
x->add_use(this);
|
x->add_use(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
value* user::get_operand(unsigned i) {
|
value* user::get_operand(unsigned i) const {
|
||||||
assert(i < ops_.size() && "get_operand() out of range!");
|
assert(i < ops_.size() && "get_operand() out of range!");
|
||||||
return ops_[i];
|
return ops_[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned user::get_num_operands() const {
|
unsigned user::get_num_operands() const {
|
||||||
return ops_.size();
|
return num_ops_;
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned user::get_num_hidden() const {
|
||||||
|
return num_hidden_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void user::replace_all_uses_with(value *target) {
|
void user::replace_all_uses_with(value *target) {
|
||||||
|
Reference in New Issue
Block a user