diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 369d0ac3e..20a911387 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -47,21 +47,21 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32 int32 ryb[TN] = get_global_range[TN](1);\ int32 rka[TK] = 0 ... TK;\ int32 rkb[TK] = 0 ... TK;\ - int32 rxc[TM] = get_global_range[TM](0);\ - int32 ryc[TN] = get_global_range[TN](1);\ + int32 rxc[TM];\ + int32 ryc[TN];\ fp32 C[TM, TN] = 0;\ int32 k;\ - fp32* pa[TM, TK] = a + rxa[:, newaxis] + rka[newaxis, :]*M;\ - fp32* pb[TN, TK] = b + ryb[:, newaxis] + rkb[newaxis, :]*K;\ - fp32* pc[TM, TN] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\ + fp32* pa[TM, TK] = a + rka[newaxis, :]*M + rxa[:, newaxis];\ + fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis];\ + fp32* pc[TM, TN];\ fp32 a[TM, TK] = *pa;\ fp32 b[TN, TK] = *pb;\ int1 checkc0[TM];\ int1 checkc1[TN];\ int1 checkc[TM, TN];\ for(k = K; k > 0; k = k - TK){\ - int1 checka[TM, TK] = (k > bound);\ - int1 checkb[TN, TK] = (k > bound);\ + int1 checka[TM, TK];\ + int1 checkb[TN, TK];\ int1 checka0[TM];\ int1 checka1[TK];\ int1 checkb0[TN];\ @@ -69,6 +69,8 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32 C = dot(a, b, C);\ pa = pa + TK*M;\ pb = pb + TK*K;\ + checka = k > bound;\ + checkb = k > bound;\ @checka a = *pa;\ @checkb b = *pb;\ if(k > bound)\ @@ -82,6 +84,9 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32 a = checka ? *pa : 0;\ b = checkb ? 
*pb : 0;\ }\ + rxc = get_global_range[TM](0);\ + ryc = get_global_range[TN](1);\ + pc = c + ryc[newaxis, :]*M + rxc[:, newaxis];\ checkc0 = rxc < M;\ checkc1 = ryc < N;\ checkc = checkc0[:, newaxis] && checkc1[newaxis, :];\ @@ -231,16 +236,15 @@ int main() { 2, 8, 1, // b0 4, 4, 1, - // c0 - 2, 8, 1, - // c1 - 4, 4, 1, + // c + 2, 4, 8, 4, 1, 1, // a1 2, 4, 1, // b1 1, 8, 1 }; + // meta-parameters unsigned i = 0; context.p_impl->mp_constants_[0]->set_value(params[0]); @@ -257,21 +261,20 @@ int main() { std::cout << "errors: " << errors.size() << std::endl; for(auto &x: errors){ for(auto &e: x.second) - std::cout << e << std::endl; + std::cout << x.first->get_name() << " " << e << std::endl; } if(errors.size()) exit(EXIT_FAILURE); + // run passes - triton::ir::print(module, std::cout); buffer_info.run(module); shared.run(module); liveness.run(module); allocation.run(); barriers.run(module); vectorize.run(module); - triton::ir::print(module, std::cout); selection.run(module, llvm_module); // llvm source diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 633e356f2..e4026f8b6 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -30,6 +30,10 @@ class constant; class global_value; /* Module */ +struct scope { + std::map types; +}; + class module { typedef std::pair val_key_t; friend class function; @@ -56,15 +60,11 @@ public: // Setters void set_value(const std::string& name, basic_block* block, value *x); void set_value(const std::string& name, value* x); - void set_type(const std::string& name, basic_block* block, type* x); - void set_type(const std::string& name, type* x); void set_const(const std::string& name); void set_continue_fn(std::function fn); // Getters value *get_value(const std::string& name, basic_block* block); value *get_value(const std::string& name); - type *get_type(const std::string& name, basic_block* block); - type *get_type(const std::string& name); std::function get_continue_fn(); // Seal block -- 
no more predecessors will be added void seal_block(basic_block *block); @@ -73,9 +73,9 @@ public: functions_list_t &get_function_list() { return functions_; } function *get_or_insert_function(const std::string &name, function_type *ty); // Scope - void push_scope(const ast::compound_statement* scope) { scopes_.push(scope); } + void add_new_scope() { if(scopes_.empty()) scopes_.push(scope()); else scopes_.push(scope(get_scope())); } void pop_scope() { scopes_.pop(); } - const ast::compound_statement* get_scope() { return scopes_.top(); } + scope& get_scope() { return scopes_.top(); } private: @@ -91,7 +91,7 @@ private: symbols_map_t symbols_; std::function continue_fn_; std::map current_phi_; - std::stack scopes_; + std::stack scopes_; }; } diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 9745ff0e2..d70955d98 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -149,7 +149,7 @@ inline bool is_terminator(ir::value* x) { /* Translation unit */ ir::value* translation_unit::codegen(ir::module *mod) const{ - mod->push_scope(nullptr); + mod->add_new_scope(); decls_.codegen(mod); return nullptr; } @@ -242,7 +242,7 @@ void function::bind_parameters(ir::module *mod, ir::function *fn) const{ if(id_i){ args[i]->set_name(id_i->name()); mod->set_value(id_i->name(), nullptr, args[i]); - mod->set_type(id_i->name(), nullptr, args[i]->get_type()); + mod->get_scope().types[id_i->name()] = args[i]->get_type(); } } } @@ -285,7 +285,7 @@ ir::value* function_definition::codegen(ir::module *mod) const{ /* Statements */ ir::value* compound_statement::codegen(ir::module* mod) const{ - mod->push_scope(this); + mod->add_new_scope(); if(decls_) decls_->codegen(mod); if(statements_){ @@ -422,7 +422,7 @@ ir::value* initializer::codegen(ir::module * mod) const{ } value->set_name(name); mod->set_value(name, value); - mod->set_type(name, ty); + mod->get_scope().types[name] = ty; if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end()) mod->set_const(name); 
return value; @@ -649,8 +649,12 @@ ir::value *conditional_expression::codegen(ir::module *mod) const{ /* Assignment expression */ ir::value *assignment_expression::codegen(ir::module *mod) const{ ir::value *rvalue = rvalue_->codegen(mod); - if(auto *x = dynamic_cast(lvalue_)) + if(auto *x = dynamic_cast(lvalue_)){ + ir::type *ty = mod->get_scope().types.at(x->id()->name()); + rvalue = explicit_cast(mod->get_builder(), rvalue, ty); + implicit_broadcast(mod, rvalue, ty); mod->set_value(x->id()->name(), rvalue); + } else if(auto* x = dynamic_cast(lvalue_)){ assert(x->get_op()==DEREF); assert(x->lvalue()); diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index d0f0a1310..32a713428 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -214,6 +214,39 @@ Constant *selection::llvm_constant(ir::constant *cst, LLVMContext &ctx) { throw std::runtime_error("unknown conversion from ir::constant to Constant"); } +/* Hoist a constant operand of nested adds outwards: (cst + x) + y, (x + cst) + y, x + (cst + y) and x + (y + cst) become cst + (x + y); any other value is returned unchanged */ +inline Value *Reassociate(Value *V, IRBuilder<> &Builder){ + BinaryOperator *BinOp = dyn_cast(V); + if(BinOp) + if(BinOp->getOpcode()==BinaryOperator::BinaryOps::Add){ + Value *LHS = Reassociate(BinOp->getOperand(0), Builder); + Value *RHS = Reassociate(BinOp->getOperand(1), Builder); + if(BinaryOperator *BinLHS = dyn_cast(LHS)) + if(BinLHS->getOpcode()==BinaryOperator::BinaryOps::Add){ + Value *LLHS = BinLHS->getOperand(0); + Value *RLHS = BinLHS->getOperand(1); + // (cst + x) + y -> cst + (x + y) + if(isa(LLHS)) + return Builder.CreateAdd(LLHS, Builder.CreateAdd(RLHS, RHS)); + // (x + cst) + y -> cst + (x + y) + if(isa(RLHS)) + return Builder.CreateAdd(RLHS, Builder.CreateAdd(LLHS, RHS)); + } + if(BinaryOperator *BinRHS = dyn_cast(RHS)) + if(BinRHS->getOpcode()==BinaryOperator::BinaryOps::Add){ + Value *LRHS = BinRHS->getOperand(0); + Value *RRHS = BinRHS->getOperand(1); + // x + (cst + y) -> cst + (x + y) + if(isa(LRHS)) + return Builder.CreateAdd(LRHS, Builder.CreateAdd(RRHS, LHS)); + // x + (y + cst) -> cst + (x + y) + 
if(isa(RRHS)) + return Builder.CreateAdd(RRHS, Builder.CreateAdd(LRHS, LHS)); + } + return BinOp; + } + return V; +} /* convert ir::instruction to llvm::Instruction */ Instruction *selection::llvm_inst(ir::instruction *inst, std::function value, IRBuilder<> &builder) { @@ -271,8 +304,9 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::functionidx_begin(), ii->idx_end(), std::back_inserter(idx_vals), [&value](ir::value* x){ return value(x);}); Type *source_ty = type(ii->get_source_elt_ty()->get_scalar_ty()); + idx_vals[0] = Reassociate(idx_vals[0], builder); Value *arg = value(ii->get_operand(0)); - return builder.Insert(GetElementPtrInst::Create(source_ty, arg, idx_vals)); + return builder.Insert(GetElementPtrInst::CreateInBounds(source_ty, arg, idx_vals)); } if(ir::load_inst* ii = dynamic_cast(inst)){ Value *ptr = value(ii->get_pointer_operand()); diff --git a/lib/ir/module.cpp b/lib/ir/module.cpp index e5764d010..14f1337e1 100644 --- a/lib/ir/module.cpp +++ b/lib/ir/module.cpp @@ -29,14 +29,6 @@ void module::set_value(const std::string& name, ir::value *value){ return set_value(name, builder_.get_insert_block(), value); } -void module::set_type(const std::string& name, ir::basic_block *block, ir::type *type){ - types_[val_key_t{name, block}] = type; -} - -void module::set_type(const std::string& name, ir::type *type){ - return set_type(name, builder_.get_insert_block(), type); -} - void module::set_const(const std::string& name){ const_.insert(name); } @@ -97,7 +89,7 @@ ir::value *module::get_value_recursive(const std::string& name, ir::basic_block ir::value *result; bool is_const = const_.find(name) != const_.end(); auto &preds = block->get_predecessors(); - ir::type *ty = get_type(name, block); + ir::type *ty = get_scope().types.at(name); if(block) if(!is_const && sealed_blocks_.find(block) == sealed_blocks_.end()){ incomplete_phis_[block][name] = make_phi(ty, 1, block); @@ -136,21 +128,6 @@ ir::value *module::get_value(const std::string& name) { 
return get_value(name, builder_.get_insert_block()); } -ir::type *module::get_type(const std::string &name, basic_block *block) { - val_key_t key(name, block); - if(types_.find(key) != types_.end()) - return types_.at(key); - assert(block); - const auto& predecessors = block->get_predecessors(); - if(predecessors.empty()) - return get_type(name, nullptr); - return get_type(name, predecessors[0]); -} - -ir::type *module::get_type(const std::string &name) { - return types_.at({name, builder_.get_insert_block()}); -} - void module::seal_block(ir::basic_block *block){ for(auto &x: incomplete_phis_[block]){