better masking

This commit is contained in:
Philippe Tillet
2019-02-28 23:46:11 -05:00
parent 017702590b
commit 36acf22fd3
9 changed files with 203 additions and 86 deletions

View File

@@ -209,6 +209,7 @@ int main() {
llvm::Module llvm_module("matmul", llvm_context);
triton::ir::print(module, std::cout);
// create passes
triton::codegen::buffer_info_pass buffer_info;
@@ -220,6 +221,7 @@ int main() {
triton::codegen::vectorize vectorize(&tune);
triton::codegen::selection selection(&allocation, &tune, &buffer_info);
// tuning parameters
tune.run(module);
std::vector<unsigned> params = {
@@ -246,6 +248,9 @@ int main() {
context.p_impl->mp_constants_[2]->set_value(params[2]);
for(unsigned *x: tune.get_params(module))
*x = params[3 + i++];
// constraints
std::map<triton::ir::value*, std::vector<std::string>> errors;
tune.check_constraints(module, errors);
@@ -265,12 +270,11 @@ int main() {
allocation.run();
barriers.run(module);
vectorize.run(module);
triton::ir::print(module, std::cout);
selection.run(module, llvm_module);
// llvm source
llvm::legacy::PassManager manager;
// manager.add(llvm::createPrintModulePass(llvm::outs()));
manager.add(llvm::createPrintModulePass(llvm::outs()));
manager.add(llvm::createVerifierPass(true));
manager.run(llvm_module);

View File

@@ -51,7 +51,8 @@ public:
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
// Tile-level control flow
value *create_ternary(value *cond, value *true_value, value *false_value, const std::string &name = "");
value *create_mask(value *pred, const std::string &name = "");
value *create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name = "");
// Cast instructions
value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = "");
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");

View File

@@ -16,6 +16,7 @@ class context;
// instruction classes
//===----------------------------------------------------------------------===//
class result_reference;
class instruction: public user{
public:
struct mask_info_t {
@@ -27,7 +28,7 @@ public:
protected:
// constructors
instruction(type *ty, unsigned num_ops, const std::string &name = "", instruction *next = nullptr);
instruction(type *ty, unsigned num_ops, unsigned num_results = 1, const std::string &name = "", instruction *next = nullptr);
public:
// parent
@@ -38,15 +39,33 @@ public:
// mask
void set_mask_pred(value *pred) { resize_hidden(1); set_operand(get_num_operands(), pred); }
value* get_mask_pred() const { if(get_num_hidden() == 0) return nullptr; return get_operand(get_num_operands()); }
void set_mask_else(value *x) { resize_hidden(2); set_operand(get_num_operands() + 1, x); }
value* get_mask_else() const { if(get_num_hidden() < 2) return nullptr; return get_operand(get_num_operands() + 1); }
// helpers
bool has_tile_result_or_op();
// repr
std::string repr() const { return repr_impl(); }
// results
unsigned get_num_results() const { return results_.size(); }
value* get_result(unsigned i) { return results_.at(i); }
private:
basic_block *parent_;
value *pred_;
value *mask_pred_;
std::vector<value*> results_;
};
// result reference
class result_reference: public value {
public:
result_reference(instruction *ref, unsigned arg_id, const std::string &name = "");
instruction *get_ref();
unsigned get_arg_id();
private:
instruction *ref_;
unsigned arg_id_;
};
//===----------------------------------------------------------------------===//
@@ -303,6 +322,30 @@ public:
value *get_false_value() { return get_operand(2); }
static ternary_inst* create(value *cond, value *true_value, value *false_value,
const std::string &name = "", instruction *next = nullptr);
};
// mask
class mask_inst: public instruction {
private:
std::string repr_impl() const { return "mask"; }
mask_inst(ir::value *pred, const std::string &name, instruction *next);
public:
static mask_inst* create(ir::value *pred, const std::string &name = "", instruction *next = nullptr);
};
// merge
class merge_inst: public instruction {
private:
std::string repr_impl() const { return "merge"; }
merge_inst(ir::value *mask_true, ir::value *value_true,
ir::value *mask_false, ir::value *value_false,
const std::string &name, instruction *next);
public:
static merge_inst* create(ir::value *mask_true, ir::value *value_true,
ir::value *mask_false, ir::value *value_false,
const std::string &name = "", instruction *next = nullptr);
};

View File

@@ -301,12 +301,20 @@ ir::value* compound_statement::codegen(ir::module* mod) const{
/* expression statement */
ir::value* expression_statement::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder();
ir::value *expr = expr_->codegen(mod);
if(mask_) {
ir::instruction *itn = dynamic_cast<ir::instruction*>(expr);
assert(itn);
ir::value *mask = mask_->codegen(mod);
itn->set_mask_pred(mask);
ir::value *pred = mask_->codegen(mod);
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
ir::value *true_value = expr_->codegen(mod);
ir::type *ty = true_value->get_type();
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
itn->set_mask_pred(mask->get_result(0));
if(expr->get_type()->is_void_ty())
return expr;
ir::merge_inst *merge = (ir::merge_inst*)builder.create_merge(mask->get_result(0), true_value,
mask->get_result(1), ir::undef_value::get(ty));
return merge;
}
return expr;
}
@@ -596,10 +604,18 @@ ir::value *conditional_expression::llvm_op(ir::builder &builder, ir::value *cond
}
ir::value *conditional_expression::codegen(ir::module *mod) const{
ir::builder &builder = mod->get_builder();
ir::value *cond = cond_->codegen(mod);
ir::value *true_value = true_value_->codegen(mod);
ir::value *false_value = false_value_->codegen(mod);
return llvm_op(mod->get_builder(), cond, true_value, false_value, "");
ir::value *true_value = true_value_->codegen(mod);
bool is_float, is_ptr, is_int, is_signed;
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
implicit_broadcast(mod, true_value, false_value);
ir::instruction *itn = dynamic_cast<ir::instruction*>(true_value);
assert(itn);
itn->set_mask_pred(cond);
itn->set_mask_else(false_value);
return itn;
}
/* Assignment expression */

View File

@@ -472,7 +472,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
distributed_tile *T = new distributed_tile(ty, shapes2, axes, builder, vectorize);
tmap_.insert({v, T});
// constant range
if(dynamic_cast<ir::constant*>(v)){
if(dynamic_cast<ir::constant*>(v) && !dynamic_cast<ir::undef_value*>(v)){
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
T->set_value(idx, idx[0]);
@@ -494,15 +494,21 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
std::vector<ir::value*> grids;
std::map<unsigned*, ir::value*> references;
create_grids(grids, references, fn);
for(ir::value* i: grids)
init_axes(i, builder, u_thread_warp_id, u_warp_id);
for(ir::value* i: grids){
if(auto *instr = dynamic_cast<ir::instruction*>(i))
for(unsigned r = 0; r < instr->get_num_results(); r++)
init_axes(instr->get_result(r), builder, u_thread_warp_id, u_warp_id);
else
init_axes(i, builder, u_thread_warp_id, u_warp_id);
}
// create tile
std::set<ir::value*> seen;
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list()){
if(!i->get_type()->is_tile_ty())
continue;
create_tile(i, builder, references, seen, sh_mem_ptr);
for(unsigned r = 0; r < i->get_num_results(); r++)
create_tile(i->get_result(r), builder, references, seen, sh_mem_ptr);
}
}
@@ -510,46 +516,43 @@ void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem
void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &builder) {
BasicBlock *block = builder.GetInsertBlock();
Module *module = block->getModule();
Function *function = block->getParent();
ir::value* mask_pred = ins->get_mask_pred();
LLVMContext &ctx = builder.getContext();
// helper to handle masks
auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
BasicBlock *block = builder.GetInsertBlock();
Value *result;
if(mask_pred){
// if(mask.else_value)
// std::cout << mask.else_value << std::endl;
Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
builder.CreateCondBr(llvm_mask, then_bb, done_bb);
builder.SetInsertPoint(then_bb);
result = insert_value();
builder.CreateBr(done_bb);
builder.SetInsertPoint(done_bb);
if(!ins->get_type()->is_void_ty()){
Type *ty = result->getType();
PHINode *phi = builder.CreatePHI(ty, 2);
// if(mask.else_value)
// phi->addIncoming(tmap_.at(mask.else_value)->get_value(idx), block);
// // helper to handle masks
// auto insert_masked = [&](indices_t idx, std::function<Value*()> insert_value) {
// BasicBlock *block = builder.GetInsertBlock();
// Value *result;
// if(mask_pred){
// Value *llvm_mask = tmap_.at(mask_pred)->get_value(idx);
// BasicBlock *then_bb = BasicBlock::Create(ctx, "", function);
// BasicBlock *done_bb = BasicBlock::Create(ctx, "", function);
// builder.CreateCondBr(llvm_mask, then_bb, done_bb);
// builder.SetInsertPoint(then_bb);
// result = insert_value();
// builder.CreateBr(done_bb);
// builder.SetInsertPoint(done_bb);
// if(!ins->get_type()->is_void_ty()){
// Type *ty = result->getType();
// PHINode *phi = builder.CreatePHI(ty, 2);
// if(mask_else)
// phi->addIncoming(tmap_.at(mask_else)->get_value(idx), block);
// else
phi->addIncoming(llvm::UndefValue::get(ty), block);
phi->addIncoming(result, then_bb);
return (Value*)phi;
}
}
else
result = insert_value();
return result;
};
// phi->addIncoming(llvm::UndefValue::get(ty), block);
// phi->addIncoming(result, then_bb);
// return (Value*)phi;
// }
// }
// else
// result = insert_value();
// return result;
// };
std::cout << ins->get_name() << " " << typeid(*ins).name() << std::endl;
// store
if(auto *x = dynamic_cast<ir::store_inst*>(ins)) {
distributed_tile* ptr = (distributed_tile*)tmap_.at(x->get_pointer_operand());
tile *value = tmap_.at(x->get_value_operand());
ptr->for_each([&](indices_t idx){
insert_masked(idx, [&]{ return builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); });
builder.CreateStore(value->get_value(idx), ptr->get_value(idx));
});
}
else {
@@ -570,9 +573,30 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Value *offset = builder.CreateMul(builder.getInt32(shapes[0]->get_value()), group_id);
result->for_each([&](indices_t idx){
BinaryOperator *bin = static_cast<BinaryOperator*>(idx[0]);
result->set_value(idx, insert_masked(idx, [&]{ return builder.CreateAdd(bin, offset); }));
result->set_value(idx, builder.CreateAdd(bin, offset));
});
}
// mask
else if(dynamic_cast<ir::mask_inst*>(ins)) {
// distributed_tile* pred = (distributed_tile*)ins->get_operand(0);
// BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done");
// pred->for_each([&](indices_t idx){
// BasicBlock *mask_if_bb = BasicBlock::Create(ctx, "mask_if");
// BasicBlock* mask_else_bb = BasicBlock::Create(ctx, "mask_else");
// builder.CreateCondBr(pred->get_value(idx), mask_if_bb, mask_else_bb);
// builder.SetInsertPoint(mask_if_bb);
// builder.CreateBr(mask_done_bb);
// builder.SetInsertPoint(mask_else_bb);
// builder.CreateBr(mask_done_bb);
// });
// builder.SetInsertPoint(mask_done_bb);
}
// merge
else if(dynamic_cast<ir::merge_inst*>(ins)) {
// result->for_each([&](indices_t idx){
// std::cout << "merge" << std::endl;
// });
}
// reshape
else if(dynamic_cast<ir::reshape_inst*>(ins)) {
ir::value* in = ins->get_operand(0);
@@ -589,7 +613,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// splat
else if(dynamic_cast<ir::splat_inst*>(ins)) {
result->for_each([&](indices_t idx) {
result->set_value(idx, insert_masked(idx, [&]{ return llvm_value(ins->get_operand(0), builder); }));
result->set_value(idx, llvm_value(ins->get_operand(0), builder));
});
}
// broadcast
@@ -667,7 +691,7 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
else
return llvm_value(x, builder);
};
result->set_value(idx, insert_masked(idx, [&]() { return llvm_inst(ins, value, builder); }));
result->set_value(idx, llvm_inst(ins, value, builder));
});
}
}

View File

@@ -70,15 +70,10 @@ void tune::init_c_graph(ir::instruction *v) {
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)){
for(unsigned i = 0; i < shapes.size(); i ++)
for(ir::value* op: v->ops())
add_constraint({v, i}, {op, i});
}
/* Add mask constraints */
if(ir::value *pred = v->get_mask_pred()){
for(unsigned i = 0; i < shapes.size(); i++)
add_constraint({v->ops()[0], i}, {pred, i});
for(unsigned k = 0; k < v->get_num_results(); k++)
for(unsigned i = 0; i < shapes.size(); i ++)
for(ir::value* op: v->ops())
add_constraint({v->get_result(k), i}, {op, i});
}
}

View File

@@ -72,8 +72,12 @@ value *builder::create_ret_void() {
// tile-level control-flow instructions
//===----------------------------------------------------------------------===//
value *builder::create_ternary(value *cond, value *true_value, value *false_value, const std::string &name){
return insert(ternary_inst::create(cond, true_value, false_value, name));
value *builder::create_mask(value *pred, const std::string &name){
return insert(mask_inst::create(pred, name));
}
value *builder::create_merge(value *mask_true, value *value_true, value *mask_false, value *value_false, const std::string &name) {
return insert(merge_inst::create(mask_true, value_true, mask_false, value_false, name));
}

View File

@@ -11,7 +11,7 @@ namespace ir{
// instruction classes
//===----------------------------------------------------------------------===//
instruction::instruction(type *ty, unsigned num_ops, const std::string &name, instruction *next)
instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const std::string &name, instruction *next)
: user(ty, num_ops, name) {
if(next){
basic_block *block = next->get_parent();
@@ -19,6 +19,11 @@ instruction::instruction(type *ty, unsigned num_ops, const std::string &name, in
auto it = std::find(block->begin(), block->end(), next);
block->get_inst_list().insert(it, next);
}
if(num_results == 1)
results_.push_back(this);
else
for(unsigned i = 0; i < num_results; i++)
results_.push_back(new result_reference(this, i));
}
void instruction::erase_from_parent() {
@@ -32,12 +37,17 @@ bool instruction::has_tile_result_or_op() {
return result;
}
// result reference
result_reference::result_reference(instruction *ref, unsigned arg_id, const std::string &name)
: value(ref->get_type(), name), arg_id_(arg_id){ }
//===----------------------------------------------------------------------===//
// phi_node classes
//===----------------------------------------------------------------------===//
phi_node::phi_node(type *ty, unsigned num_reserved, std::string const &name, instruction *next)
: instruction(ty, 0, name, next) {
: instruction(ty, 0, 1, name, next) {
blocks_.reserve(num_reserved);
}
@@ -98,7 +108,7 @@ std::string binary_operator::repr_impl() const {
}
binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next)
: instruction(ty, 2, name, next), op_(op){
: instruction(ty, 2, 1, name, next), op_(op){
set_operand(0, lhs);
set_operand(1, rhs);
}
@@ -165,7 +175,7 @@ std::string cmp_inst::repr_impl() const {
}
cmp_inst::cmp_inst(type *ty, cmp_inst::pred_t pred, value *lhs, value *rhs, const std::string &name, instruction *next)
: instruction(ty, 2, name, next), pred_(pred) {
: instruction(ty, 2, 1, name, next), pred_(pred) {
set_operand(0, lhs);
set_operand(1, rhs);
}
@@ -205,7 +215,7 @@ fcmp_inst* fcmp_inst::create(pred_t pred, value *lhs, value *rhs, const std::str
//===----------------------------------------------------------------------===//
unary_inst::unary_inst(type *ty, value *v, const std::string &name, instruction *next)
: instruction(ty, 1, name, next) {
: instruction(ty, 1, 1, name, next) {
set_operand(0, v);
}
@@ -275,7 +285,7 @@ cast_inst *cast_inst::create_integer_cast(value *arg, type *ty, bool is_signed,
// return_inst
return_inst::return_inst(context &ctx, value *ret_val, instruction *next)
: terminator_inst(type::get_void_ty(ctx), ret_val!=nullptr, "", next){
: terminator_inst(type::get_void_ty(ctx), ret_val!=nullptr, 0, "", next){
if(ret_val)
set_operand(0, ret_val);
}
@@ -298,40 +308,54 @@ branch_inst* branch_inst::create(value *cond, basic_block *if_dst, basic_block *
// uncond_branch_inst
uncond_branch_inst::uncond_branch_inst(basic_block *dst, instruction *next)
: branch_inst(type::get_void_ty(dst->get_context()), 1, "", next){
: branch_inst(type::get_void_ty(dst->get_context()), 1, 0, "", next){
set_operand(0, dst);
}
// cond_branch_inst
cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next)
: branch_inst(type::get_void_ty(if_dst->get_context()), 3, "", next){
: branch_inst(type::get_void_ty(if_dst->get_context()), 3, 0, "", next){
assert(cond->get_type()->is_integer_ty(1) && "May only branch on boolean predicates!");
set_operand(0, if_dst);
set_operand(1, else_dst);
set_operand(2, cond);
}
// ternary_inst
ternary_inst::ternary_inst(value *cond, value *true_value, value *false_value, const std::string &name, instruction *next)
: instruction(true_value->get_type(), 3) {
assert(true_value->get_type() == false_value->get_type());
set_operand(0, cond);
set_operand(1, true_value);
set_operand(2, false_value);
// mask_inst
mask_inst::mask_inst(value *pred, const std::string &name, instruction *next)
: instruction(pred->get_type(), 1, 2, name, next) {
set_operand(0, pred);
}
ternary_inst *ternary_inst::create(value *cond, value *true_value, value *false_value,
const std::string &name, instruction *next) {
return new ternary_inst(cond, true_value, false_value, name, next);
mask_inst* mask_inst::create(value *pred, const std::string &name, instruction *next) {
return new mask_inst(pred, name, next);
}
// merge_inst
merge_inst::merge_inst(value *mask_true, value *value_true,
value *mask_false, value *value_false,
const std::string &name, instruction *next)
: instruction(value_true->get_type(), 4, 1, name, next) {
set_operand(0, mask_true);
set_operand(1, value_true);
set_operand(2, mask_false);
set_operand(3, value_false);
}
merge_inst* merge_inst::create(value *mask_true, value *value_true,
value *mask_false, value *value_false,
const std::string &name, instruction *next) {
return new merge_inst(mask_true, value_true, mask_false, value_false, name, next);
}
//===----------------------------------------------------------------------===//
// getelementptr_inst classes
//===----------------------------------------------------------------------===//
getelementptr_inst::getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value *> &idx, const std::string &name, instruction *next)
: instruction(get_return_type(pointee_ty, ptr, idx), 1 + idx.size(), name, next),
: instruction(get_return_type(pointee_ty, ptr, idx), 1 + idx.size(), 1, name, next),
source_elt_ty(pointee_ty),
res_elt_ty(get_indexed_type(pointee_ty, idx)){
type *expected_ty = ((pointer_type*)(get_type()->get_scalar_ty()))->get_element_ty();
@@ -407,7 +431,7 @@ load_inst* load_inst::create(value *ptr, const std::string &name, instruction *n
// store
store_inst::store_inst(value *ptr, value *v, const std::string &name, instruction *next)
: instruction(type::get_void_ty(ptr->get_type()->get_context()), 2, name, next) {
: instruction(type::get_void_ty(ptr->get_type()->get_context()), 2, 1, name, next) {
set_operand(0, ptr);
set_operand(1, v);
}
@@ -465,7 +489,7 @@ instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shape
matmul_inst::matmul_inst(value *A, value *B, value *C,
const std::string &name, instruction *next)
: builtin_inst(C->get_type(), 3, name, next) {
: builtin_inst(C->get_type(), 3, 0, name, next) {
set_operand(0, A);
set_operand(1, B);
set_operand(2, C);
@@ -481,7 +505,7 @@ instruction *matmul_inst::create(value *A, value *B, value *C,
//===----------------------------------------------------------------------===//
get_global_range_inst::get_global_range_inst(type *ty, unsigned axis,
const std::string &name, instruction *next)
: builtin_inst(ty, 0, name, next), axis_(axis) {
: builtin_inst(ty, 0, 1, name, next), axis_(axis) {
}
@@ -506,7 +530,7 @@ vectorize_inst* vectorize_inst::create(value *arg, const std::string &name, inst
barrier_inst::barrier_inst(context &ctx, const std::string &name,
instruction *next)
: instruction(type::get_void_ty(ctx), 0, name, next){ }
: instruction(type::get_void_ty(ctx), 0, 1, name, next){ }
barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) {
return new barrier_inst(ctx, name, next);

View File

@@ -37,13 +37,19 @@ void print(module &mod, std::ostream& os) {
os << " ";
if(ir::value *pred = inst->get_mask_pred())
os << "@" << get_name(pred, cnt++) << " ";
if(!inst->get_type()->is_void_ty())
os << get_name(inst, cnt++) << " = ";
unsigned num_results = inst->get_num_results();
for(unsigned i = 0; i < num_results; i++){
os << get_name(inst->get_result(i), cnt++);
if(i < num_results - 1)
os << ", ";
else
os << " = ";
}
os << inst->repr();
ir::instruction::ops_t ops = inst->ops();
size_t num_ops = inst->get_num_operands();
if(num_ops > 0)
os << " ";
os << " ";;
for(unsigned i = 0; i < num_ops; i++)
os << get_name(ops[i], cnt++) << (i < num_ops - 1?", ":"");
os << ";" << std::endl;