[code generation] more optimizations

This commit is contained in:
Philippe Tillet
2019-03-02 16:03:26 -05:00
parent 2467c5e504
commit 1f30e111ec
5 changed files with 68 additions and 51 deletions

View File

@@ -47,21 +47,21 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
int32 ryb[TN] = get_global_range[TN](1);\
int32 rka[TK] = 0 ... TK;\
int32 rkb[TK] = 0 ... TK;\
int32 rxc[TM] = get_global_range[TM](0);\
int32 ryc[TN] = get_global_range[TN](1);\
int32 rxc[TM];\
int32 ryc[TN];\
fp32 C[TM, TN] = 0;\
int32 k;\
fp32* pa[TM, TK] = a + rxa[:, newaxis] + rka[newaxis, :]*M;\
fp32* pb[TN, TK] = b + ryb[:, newaxis] + rkb[newaxis, :]*K;\
fp32* pc[TM, TN] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\
fp32* pc[TM, TN];\
fp32 a[TM, TK] = *pa;\
fp32 b[TN, TK] = *pb;\
int1 checkc0[TM];\
int1 checkc1[TN];\
int1 checkc[TM, TN];\
for(k = K; k > 0; k = k - TK){\
int1 checka[TM, TK] = (k > bound);\
int1 checkb[TN, TK] = (k > bound);\
int1 checka[TM, TK];\
int1 checkb[TN, TK];\
int1 checka0[TM];\
int1 checka1[TK];\
int1 checkb0[TN];\
@@ -69,6 +69,8 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
C = dot(a, b, C);\
pa = pa + TK*M;\
pb = pb + TK*K;\
checka = k > bound;\
checkb = k > bound;\
@checka a = *pa;\
@checkb b = *pb;\
if(k > bound)\
@@ -82,6 +84,9 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
a = checka ? *pa : 0;\
b = checkb ? *pb : 0;\
}\
rxc = get_global_range[TM](0);\
ryc = get_global_range[TN](1);\
pc = c + ryc[newaxis, :]*M + rxc[:, newaxis];\
checkc0 = rxc < M;\
checkc1 = ryc < N;\
checkc = checkc0[:, newaxis] && checkc1[newaxis, :];\
@@ -231,16 +236,15 @@ int main() {
2, 8, 1,
// b0
4, 4, 1,
// c0
2, 8, 1,
// c1
4, 4, 1,
// c
2, 4, 8, 4, 1, 1,
// a1
2, 4, 1,
// b1
1, 8, 1
};
// meta-parameters
unsigned i = 0;
context.p_impl->mp_constants_[0]->set_value(params[0]);
@@ -257,21 +261,20 @@ int main() {
std::cout << "errors: " << errors.size() << std::endl;
for(auto &x: errors){
for(auto &e: x.second)
std::cout << e << std::endl;
std::cout << x.first->get_name() << " " << e << std::endl;
}
if(errors.size())
exit(EXIT_FAILURE);
// run passes
triton::ir::print(module, std::cout);
buffer_info.run(module);
shared.run(module);
liveness.run(module);
allocation.run();
barriers.run(module);
vectorize.run(module);
triton::ir::print(module, std::cout);
selection.run(module, llvm_module);
// llvm source

View File

@@ -30,6 +30,10 @@ class constant;
class global_value;
/* Module */
// Symbol-type table for one scope: maps an identifier name to its ir::type.
// module::add_new_scope() copy-constructs a new scope from the current top
// of the scope stack, so entries from the enclosing scope are carried over.
struct scope {
// identifier name -> declared IR type
std::map<std::string, ir::type*> types;
};
class module {
typedef std::pair<std::string, basic_block*> val_key_t;
friend class function;
@@ -56,15 +60,11 @@ public:
// Setters
void set_value(const std::string& name, basic_block* block, value *x);
void set_value(const std::string& name, value* x);
void set_type(const std::string& name, basic_block* block, type* x);
void set_type(const std::string& name, type* x);
void set_const(const std::string& name);
void set_continue_fn(std::function<ir::value*()> fn);
// Getters
value *get_value(const std::string& name, basic_block* block);
value *get_value(const std::string& name);
type *get_type(const std::string& name, basic_block* block);
type *get_type(const std::string& name);
std::function<ir::value*()> get_continue_fn();
// Seal block -- no more predecessors will be added
void seal_block(basic_block *block);
@@ -73,9 +73,9 @@ public:
functions_list_t &get_function_list() { return functions_; }
function *get_or_insert_function(const std::string &name, function_type *ty);
// Scope
void push_scope(const ast::compound_statement* scope) { scopes_.push(scope); }
void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); }
void pop_scope() { scopes_.pop(); }
const ast::compound_statement* get_scope() { return scopes_.top(); }
scope& get_scope() { return scopes_.top(); }
private:
@@ -91,7 +91,7 @@ private:
symbols_map_t symbols_;
std::function<ir::value*()> continue_fn_;
std::map<value*, value**> current_phi_;
std::stack<const ast::compound_statement*> scopes_;
std::stack<scope> scopes_;
};
}

View File

@@ -149,7 +149,7 @@ inline bool is_terminator(ir::value* x) {
/* Translation unit */
ir::value* translation_unit::codegen(ir::module *mod) const{
mod->push_scope(nullptr);
mod->add_new_scope();
decls_.codegen(mod);
return nullptr;
}
@@ -242,7 +242,7 @@ void function::bind_parameters(ir::module *mod, ir::function *fn) const{
if(id_i){
args[i]->set_name(id_i->name());
mod->set_value(id_i->name(), nullptr, args[i]);
mod->set_type(id_i->name(), nullptr, args[i]->get_type());
mod->get_scope().types[id_i->name()] = args[i]->get_type();
}
}
}
@@ -285,7 +285,7 @@ ir::value* function_definition::codegen(ir::module *mod) const{
/* Statements */
ir::value* compound_statement::codegen(ir::module* mod) const{
mod->push_scope(this);
mod->add_new_scope();
if(decls_)
decls_->codegen(mod);
if(statements_){
@@ -422,7 +422,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
}
value->set_name(name);
mod->set_value(name, value);
mod->set_type(name, ty);
mod->get_scope().types[name] = ty;
if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end())
mod->set_const(name);
return value;
@@ -649,8 +649,12 @@ ir::value *conditional_expression::codegen(ir::module *mod) const{
/* Assignment expression */
ir::value *assignment_expression::codegen(ir::module *mod) const{
ir::value *rvalue = rvalue_->codegen(mod);
if(auto *x = dynamic_cast<const named_expression*>(lvalue_))
if(auto *x = dynamic_cast<const named_expression*>(lvalue_)){
ir::type *ty = mod->get_scope().types.at(x->id()->name());
rvalue = explicit_cast(mod->get_builder(), rvalue, ty);
implicit_broadcast(mod, rvalue, ty);
mod->set_value(x->id()->name(), rvalue);
}
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){
assert(x->get_op()==DEREF);
assert(x->lvalue());

View File

@@ -214,6 +214,38 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
throw std::runtime_error("unknown conversion from ir::constant to Constant");
}
// Reassociate a tree of additions so that a constant operand, when one is
// found one level down, becomes the left operand of the outermost add:
//   (cst + x) + y  ->  cst + (x + y)
//   (x + cst) + y  ->  cst + (x + y)
//   x + (cst + y)  ->  cst + (y + x)
//   x + (y + cst)  ->  cst + (y + x)
// Both operands are reassociated recursively before the patterns are tried.
//
// V       : value to rewrite; anything that is not an Add BinaryOperator is
//           returned unchanged.
// Builder : used to emit the replacement add instructions at its current
//           insertion point.
// Returns the rewritten value, or V itself when no pattern applies.
inline Value *Reassociate(Value *V, IRBuilder<> &Builder){
  BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V);
  if(!BinOp || BinOp->getOpcode() != BinaryOperator::BinaryOps::Add)
    return V;
  Value *LHS = Reassociate(BinOp->getOperand(0), Builder);
  Value *RHS = Reassociate(BinOp->getOperand(1), Builder);
  if(BinaryOperator *BinLHS = dyn_cast<BinaryOperator>(LHS))
  if(BinLHS->getOpcode() == BinaryOperator::BinaryOps::Add){
    Value *LLHS = BinLHS->getOperand(0);
    Value *RLHS = BinLHS->getOperand(1);
    // (cst + x) + y -> cst + (x + y)
    if(isa<Constant>(LLHS))
      return Builder.CreateAdd(LLHS, Builder.CreateAdd(RLHS, RHS));
    // (x + cst) + y -> cst + (x + y)
    if(isa<Constant>(RLHS))
      return Builder.CreateAdd(RLHS, Builder.CreateAdd(LLHS, RHS));
  }
  if(BinaryOperator *BinRHS = dyn_cast<BinaryOperator>(RHS))
  if(BinRHS->getOpcode() == BinaryOperator::BinaryOps::Add){
    Value *LRHS = BinRHS->getOperand(0);
    Value *RRHS = BinRHS->getOperand(1);
    // x + (cst + y) -> cst + (y + x)
    if(isa<Constant>(LRHS))
      return Builder.CreateAdd(LRHS, Builder.CreateAdd(RRHS, LHS));
    // x + (y + cst) -> cst + (y + x)
    // The constant is the *right* operand here, so RRHS must be tested;
    // testing LRHS a second time made this case unreachable.
    if(isa<Constant>(RRHS))
      return Builder.CreateAdd(RRHS, Builder.CreateAdd(LRHS, LHS));
  }
  // No constant to hoist at this level: keep the original instruction.
  return BinOp;
}
/* convert ir::instruction to llvm::Instruction */
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, IRBuilder<> &builder) {
@@ -271,8 +303,9 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
std::transform(ii->idx_begin(), ii->idx_end(), std::back_inserter(idx_vals),
[&value](ir::value* x){ return value(x);});
Type *source_ty = type(ii->get_source_elt_ty()->get_scalar_ty());
idx_vals[0] = Reassociate(idx_vals[0], builder);
Value *arg = value(ii->get_operand(0));
return builder.Insert(GetElementPtrInst::Create(source_ty, arg, idx_vals));
return builder.Insert(GetElementPtrInst::CreateInBounds(source_ty, arg, idx_vals));
}
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
Value *ptr = value(ii->get_pointer_operand());

View File

@@ -29,14 +29,6 @@ void module::set_value(const std::string& name, ir::value *value){
return set_value(name, builder_.get_insert_block(), value);
}
// Record `type` as the type of `name` as seen from `block`.
// The entry is keyed on the (name, block) pair, so the same name can be
// registered separately for different basic blocks.
void module::set_type(const std::string& name, ir::basic_block *block, ir::type *type){
types_[val_key_t{name, block}] = type;
}
// Convenience overload: record `type` for `name` in the builder's current
// insertion block.
void module::set_type(const std::string& name, ir::type *type){
return set_type(name, builder_.get_insert_block(), type);
}
// Mark `name` as constant by adding it to const_.
void module::set_const(const std::string& name){
const_.insert(name);
}
@@ -97,7 +89,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
ir::value *result;
bool is_const = const_.find(name) != const_.end();
auto &preds = block->get_predecessors();
ir::type *ty = get_type(name, block);
ir::type *ty = get_scope().types.at(name);
if(block)
if(!is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
incomplete_phis_[block][name] = make_phi(ty, 1, block);
@@ -136,21 +128,6 @@ ir::value *module::get_value(const std::string& name) {
return get_value(name, builder_.get_insert_block());
}
// Look up the type registered for `name` as seen from `block`.
// If no entry exists for (name, block), the search continues in the first
// predecessor of `block` only; other predecessors are not consulted. When
// `block` has no predecessors, the lookup falls back to the entry stored
// under a null block.
// `block` must be non-null once the direct lookup has missed (checked by
// assert only).
ir::type *module::get_type(const std::string &name, basic_block *block) {
val_key_t key(name, block);
// direct hit for this exact (name, block) pair
if(types_.find(key) != types_.end())
return types_.at(key);
assert(block);
const auto& predecessors = block->get_predecessors();
// no predecessor: fall back to the entry registered with a null block
if(predecessors.empty())
return get_type(name, nullptr);
// otherwise walk up through the first predecessor
return get_type(name, predecessors[0]);
}
// Look up the type registered for `name` in the builder's current insertion
// block. Unlike the two-argument overload, this does an exact-key lookup via
// types_.at() and does not walk predecessor blocks.
ir::type *module::get_type(const std::string &name) {
return types_.at({name, builder_.get_insert_block()});
}
void module::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block]){