[code generation] more optimizations
This commit is contained in:
@@ -47,21 +47,21 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
|
||||
int32 ryb[TN] = get_global_range[TN](1);\
|
||||
int32 rka[TK] = 0 ... TK;\
|
||||
int32 rkb[TK] = 0 ... TK;\
|
||||
int32 rxc[TM] = get_global_range[TM](0);\
|
||||
int32 ryc[TN] = get_global_range[TN](1);\
|
||||
int32 rxc[TM];\
|
||||
int32 ryc[TN];\
|
||||
fp32 C[TM, TN] = 0;\
|
||||
int32 k;\
|
||||
fp32* pa[TM, TK] = a + rxa[:, newaxis] + rka[newaxis, :]*M;\
|
||||
fp32* pb[TN, TK] = b + ryb[:, newaxis] + rkb[newaxis, :]*K;\
|
||||
fp32* pc[TM, TN] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
|
||||
fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\
|
||||
fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\
|
||||
fp32* pc[TM, TN];\
|
||||
fp32 a[TM, TK] = *pa;\
|
||||
fp32 b[TN, TK] = *pb;\
|
||||
int1 checkc0[TM];\
|
||||
int1 checkc1[TN];\
|
||||
int1 checkc[TM, TN];\
|
||||
for(k = K; k > 0; k = k - TK){\
|
||||
int1 checka[TM, TK] = (k > bound);\
|
||||
int1 checkb[TN, TK] = (k > bound);\
|
||||
int1 checka[TM, TK];\
|
||||
int1 checkb[TN, TK];\
|
||||
int1 checka0[TM];\
|
||||
int1 checka1[TK];\
|
||||
int1 checkb0[TN];\
|
||||
@@ -69,6 +69,8 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
|
||||
C = dot(a, b, C);\
|
||||
pa = pa + TK*M;\
|
||||
pb = pb + TK*K;\
|
||||
checka = k > bound;\
|
||||
checkb = k > bound;\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
if(k > bound)\
|
||||
@@ -82,6 +84,9 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
|
||||
a = checka ? *pa : 0;\
|
||||
b = checkb ? *pb : 0;\
|
||||
}\
|
||||
rxc = get_global_range[TM](0);\
|
||||
ryc = get_global_range[TN](1);\
|
||||
pc = c + ryc[newaxis, :]*M + rxc[:, newaxis];\
|
||||
checkc0 = rxc < M;\
|
||||
checkc1 = ryc < N;\
|
||||
checkc = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
||||
@@ -231,16 +236,15 @@ int main() {
|
||||
2, 8, 1,
|
||||
// b0
|
||||
4, 4, 1,
|
||||
// c0
|
||||
2, 8, 1,
|
||||
// c1
|
||||
4, 4, 1,
|
||||
// c
|
||||
2, 4, 8, 4, 1, 1,
|
||||
// a1
|
||||
2, 4, 1,
|
||||
// b1
|
||||
1, 8, 1
|
||||
};
|
||||
|
||||
|
||||
// meta-parameters
|
||||
unsigned i = 0;
|
||||
context.p_impl->mp_constants_[0]->set_value(params[0]);
|
||||
@@ -257,21 +261,20 @@ int main() {
|
||||
std::cout << "errors: " << errors.size() << std::endl;
|
||||
for(auto &x: errors){
|
||||
for(auto &e: x.second)
|
||||
std::cout << e << std::endl;
|
||||
std::cout << x.first->get_name() << " " << e << std::endl;
|
||||
}
|
||||
if(errors.size())
|
||||
exit(EXIT_FAILURE);
|
||||
|
||||
|
||||
|
||||
// run passes
|
||||
triton::ir::print(module, std::cout);
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
vectorize.run(module);
|
||||
triton::ir::print(module, std::cout);
|
||||
selection.run(module, llvm_module);
|
||||
|
||||
// llvm source
|
||||
|
@@ -30,6 +30,10 @@ class constant;
|
||||
class global_value;
|
||||
|
||||
/* Module */
|
||||
struct scope {
|
||||
std::map<std::string, ir::type*> types;
|
||||
};
|
||||
|
||||
class module {
|
||||
typedef std::pair<std::string, basic_block*> val_key_t;
|
||||
friend class function;
|
||||
@@ -56,15 +60,11 @@ public:
|
||||
// Setters
|
||||
void set_value(const std::string& name, basic_block* block, value *x);
|
||||
void set_value(const std::string& name, value* x);
|
||||
void set_type(const std::string& name, basic_block* block, type* x);
|
||||
void set_type(const std::string& name, type* x);
|
||||
void set_const(const std::string& name);
|
||||
void set_continue_fn(std::function<ir::value*()> fn);
|
||||
// Getters
|
||||
value *get_value(const std::string& name, basic_block* block);
|
||||
value *get_value(const std::string& name);
|
||||
type *get_type(const std::string& name, basic_block* block);
|
||||
type *get_type(const std::string& name);
|
||||
std::function<ir::value*()> get_continue_fn();
|
||||
// Seal block -- no more predecessors will be added
|
||||
void seal_block(basic_block *block);
|
||||
@@ -73,9 +73,9 @@ public:
|
||||
functions_list_t &get_function_list() { return functions_; }
|
||||
function *get_or_insert_function(const std::string &name, function_type *ty);
|
||||
// Scope
|
||||
void push_scope(const ast::compound_statement* scope) { scopes_.push(scope); }
|
||||
void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); }
|
||||
void pop_scope() { scopes_.pop(); }
|
||||
const ast::compound_statement* get_scope() { return scopes_.top(); }
|
||||
scope& get_scope() { return scopes_.top(); }
|
||||
|
||||
|
||||
private:
|
||||
@@ -91,7 +91,7 @@ private:
|
||||
symbols_map_t symbols_;
|
||||
std::function<ir::value*()> continue_fn_;
|
||||
std::map<value*, value**> current_phi_;
|
||||
std::stack<const ast::compound_statement*> scopes_;
|
||||
std::stack<scope> scopes_;
|
||||
};
|
||||
|
||||
}
|
||||
|
@@ -149,7 +149,7 @@ inline bool is_terminator(ir::value* x) {
|
||||
|
||||
/* Translation unit */
|
||||
ir::value* translation_unit::codegen(ir::module *mod) const{
|
||||
mod->push_scope(nullptr);
|
||||
mod->add_new_scope();
|
||||
decls_.codegen(mod);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -242,7 +242,7 @@ void function::bind_parameters(ir::module *mod, ir::function *fn) const{
|
||||
if(id_i){
|
||||
args[i]->set_name(id_i->name());
|
||||
mod->set_value(id_i->name(), nullptr, args[i]);
|
||||
mod->set_type(id_i->name(), nullptr, args[i]->get_type());
|
||||
mod->get_scope().types[id_i->name()] = args[i]->get_type();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -285,7 +285,7 @@ ir::value* function_definition::codegen(ir::module *mod) const{
|
||||
|
||||
/* Statements */
|
||||
ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||
mod->push_scope(this);
|
||||
mod->add_new_scope();
|
||||
if(decls_)
|
||||
decls_->codegen(mod);
|
||||
if(statements_){
|
||||
@@ -422,7 +422,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
}
|
||||
value->set_name(name);
|
||||
mod->set_value(name, value);
|
||||
mod->set_type(name, ty);
|
||||
mod->get_scope().types[name] = ty;
|
||||
if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end())
|
||||
mod->set_const(name);
|
||||
return value;
|
||||
@@ -649,8 +649,12 @@ ir::value *conditional_expression::codegen(ir::module *mod) const{
|
||||
/* Assignment expression */
|
||||
ir::value *assignment_expression::codegen(ir::module *mod) const{
|
||||
ir::value *rvalue = rvalue_->codegen(mod);
|
||||
if(auto *x = dynamic_cast<const named_expression*>(lvalue_))
|
||||
if(auto *x = dynamic_cast<const named_expression*>(lvalue_)){
|
||||
ir::type *ty = mod->get_scope().types.at(x->id()->name());
|
||||
rvalue = explicit_cast(mod->get_builder(), rvalue, ty);
|
||||
implicit_broadcast(mod, rvalue, ty);
|
||||
mod->set_value(x->id()->name(), rvalue);
|
||||
}
|
||||
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){
|
||||
assert(x->get_op()==DEREF);
|
||||
assert(x->lvalue());
|
||||
|
@@ -214,6 +214,38 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) {
|
||||
throw std::runtime_error("unknown conversion from ir::constant to Constant");
|
||||
}
|
||||
|
||||
inline Value *Reassociate(Value *V, IRBuilder<> &Builder){
|
||||
BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V);
|
||||
if(BinOp)
|
||||
if(BinOp->getOpcode()==BinaryOperator::BinaryOps::Add){
|
||||
Value *LHS = Reassociate(BinOp->getOperand(0), Builder);
|
||||
Value *RHS = Reassociate(BinOp->getOperand(1), Builder);
|
||||
if(BinaryOperator *BinLHS = dyn_cast<BinaryOperator>(LHS))
|
||||
if(BinLHS->getOpcode()==BinaryOperator::BinaryOps::Add){
|
||||
Value *LLHS = BinLHS->getOperand(0);
|
||||
Value *RLHS = BinLHS->getOperand(1);
|
||||
// (cst + x) + y -> cst + (x + y)
|
||||
if(isa<Constant>(LLHS))
|
||||
return Builder.CreateAdd(LLHS, Builder.CreateAdd(RLHS, RHS));
|
||||
// (x + cst) + y -> cst + (x + y)
|
||||
if(isa<Constant>(RLHS))
|
||||
return Builder.CreateAdd(RLHS, Builder.CreateAdd(LLHS, RHS));
|
||||
}
|
||||
if(BinaryOperator *BinRHS = dyn_cast<BinaryOperator>(RHS))
|
||||
if(BinRHS->getOpcode()==BinaryOperator::BinaryOps::Add){
|
||||
Value *LRHS = BinRHS->getOperand(0);
|
||||
Value *RRHS = BinRHS->getOperand(1);
|
||||
// x + (cst + y) -> cst + (x + y)
|
||||
if(isa<Constant>(LRHS))
|
||||
return Builder.CreateAdd(LRHS, Builder.CreateAdd(RRHS, LHS));
|
||||
// x + (cst + y) -> cst + (x + y)
|
||||
if(isa<Constant>(LRHS))
|
||||
return Builder.CreateAdd(RRHS, Builder.CreateAdd(LRHS, LHS));
|
||||
}
|
||||
return BinOp;
|
||||
}
|
||||
return V;
|
||||
}
|
||||
|
||||
/* convert ir::instruction to llvm::Instruction */
|
||||
Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir::value*)> value, IRBuilder<> &builder) {
|
||||
@@ -271,8 +303,9 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
std::transform(ii->idx_begin(), ii->idx_end(), std::back_inserter(idx_vals),
|
||||
[&value](ir::value* x){ return value(x);});
|
||||
Type *source_ty = type(ii->get_source_elt_ty()->get_scalar_ty());
|
||||
idx_vals[0] = Reassociate(idx_vals[0], builder);
|
||||
Value *arg = value(ii->get_operand(0));
|
||||
return builder.Insert(GetElementPtrInst::Create(source_ty, arg, idx_vals));
|
||||
return builder.Insert(GetElementPtrInst::CreateInBounds(source_ty, arg, idx_vals));
|
||||
}
|
||||
if(ir::load_inst* ii = dynamic_cast<ir::load_inst*>(inst)){
|
||||
Value *ptr = value(ii->get_pointer_operand());
|
||||
|
@@ -29,14 +29,6 @@ void module::set_value(const std::string& name, ir::value *value){
|
||||
return set_value(name, builder_.get_insert_block(), value);
|
||||
}
|
||||
|
||||
void module::set_type(const std::string& name, ir::basic_block *block, ir::type *type){
|
||||
types_[val_key_t{name, block}] = type;
|
||||
}
|
||||
|
||||
void module::set_type(const std::string& name, ir::type *type){
|
||||
return set_type(name, builder_.get_insert_block(), type);
|
||||
}
|
||||
|
||||
void module::set_const(const std::string& name){
|
||||
const_.insert(name);
|
||||
}
|
||||
@@ -97,7 +89,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block
|
||||
ir::value *result;
|
||||
bool is_const = const_.find(name) != const_.end();
|
||||
auto &preds = block->get_predecessors();
|
||||
ir::type *ty = get_type(name, block);
|
||||
ir::type *ty = get_scope().types.at(name);
|
||||
if(block)
|
||||
if(!is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){
|
||||
incomplete_phis_[block][name] = make_phi(ty, 1, block);
|
||||
@@ -136,21 +128,6 @@ ir::value *module::get_value(const std::string& name) {
|
||||
return get_value(name, builder_.get_insert_block());
|
||||
}
|
||||
|
||||
ir::type *module::get_type(const std::string &name, basic_block *block) {
|
||||
val_key_t key(name, block);
|
||||
if(types_.find(key) != types_.end())
|
||||
return types_.at(key);
|
||||
assert(block);
|
||||
const auto& predecessors = block->get_predecessors();
|
||||
if(predecessors.empty())
|
||||
return get_type(name, nullptr);
|
||||
return get_type(name, predecessors[0]);
|
||||
}
|
||||
|
||||
ir::type *module::get_type(const std::string &name) {
|
||||
return types_.at({name, builder_.get_insert_block()});
|
||||
}
|
||||
|
||||
|
||||
void module::seal_block(ir::basic_block *block){
|
||||
for(auto &x: incomplete_phis_[block]){
|
||||
|
Reference in New Issue
Block a user