diff --git a/examples/matrix.cpp b/examples/matrix.cpp index 624872f9c..16f5c4434 100644 --- a/examples/matrix.cpp +++ b/examples/matrix.cpp @@ -21,11 +21,10 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, fp32* pb[TN, TK] = b + rkb[newaxis, :]*K + ryb[:, newaxis]; fp32 a[TM, TK] = *pa; fp32 b[TN, TK] = *pb; - for(int32 k = K; k > 0;){ + for(int32 k = K; k > 0; k = k - TK){ C = dot(a, b, C); pa = pa + TK*M; pb = pb + TK*K; - k = k - TK; a = *pa; b = *pb; } @@ -164,7 +163,7 @@ int main() { }; // params = {8, 2, 64, 16, 2, 64, 4, 16, 2, 2, 8, 8, 4}; - jit.autotune(src, benchmark); +// jit.autotune(src, benchmark); jit.add_module(src, params); triton::driver::kernel* kernel = jit.get_function("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul"); diff --git a/include/triton/codegen/barriers.h b/include/triton/codegen/barriers.h index 546b36893..336ec255a 100644 --- a/include/triton/codegen/barriers.h +++ b/include/triton/codegen/barriers.h @@ -26,13 +26,14 @@ private: typedef std::vector interval_vec_t; private: + interval_vec_t join(const std::vector& intervals); void insert_barrier(ir::instruction *instr, ir::builder &builder); bool intersect(const interval_vec_t &X, interval_t x); bool intersect(const interval_vec_t &X, const interval_vec_t &Y); void add_reference(ir::value *v, interval_vec_t &res); void get_read_intervals(ir::instruction *i, interval_vec_t &res); void get_written_intervals(ir::instruction *i, interval_vec_t &res); - void add(ir::basic_block *block, interval_vec_t ¬_synced, ir::builder &builder); + std::pair transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set &insert_loc); public: barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} diff --git a/include/triton/codegen/buffer_info.h b/include/triton/codegen/buffer_info.h index c9b954a58..58f140d61 100644 --- a/include/triton/codegen/buffer_info.h +++ b/include/triton/codegen/buffer_info.h @@ -19,9 +19,11 @@ public: void run(ir::module &mod); // queries bool is_double(ir::value *x); + void add_shared(ir::value *v); bool is_shared(ir::value *x); bool is_loop_latch(ir::phi_node *phi, ir::value *terminator); ir::value *get_reference(ir::value *x); + void replace(ir::value* before, ir::value *after); private: diff --git a/include/triton/ir/basic_block.h b/include/triton/ir/basic_block.h index 63de2a18b..09eb3ad64 100644 --- a/include/triton/ir/basic_block.h +++ b/include/triton/ir/basic_block.h @@ -58,6 +58,7 @@ public: // predecessors const std::vector& get_predecessors() const { return preds_; } + const std::vector& get_successors() const { return succs_; } void add_predecessor(basic_block* pred); // factory functions @@ -68,6 +69,7 @@ private: std::string name_; function *parent_; std::vector preds_; + std::vector succs_; inst_list_t inst_list_; }; diff --git a/include/triton/jit.h b/include/triton/jit.h index 93f08f280..b53884e36 100644 --- a/include/triton/jit.h +++ b/include/triton/jit.h @@ -5,6 +5,7 @@ #include #include "llvm/IR/LLVMContext.h" #include "triton/ir/context.h" +#include "triton/ir/print.h" #include "triton/driver/module.h" #include "triton/driver/kernel.h" #include "triton/codegen/selection.h" @@ -54,10 +55,12 @@ public: // generate ptx buffer_info.run(module); shared.run(module); + triton::ir::print(module, std::cout); liveness.run(module); allocation.run(); barriers.run(module); vectorize.run(module); + triton::ir::print(module, std::cout); } codegen::tune tune; diff --git a/lib/codegen/allocation.cpp b/lib/codegen/allocation.cpp index 9a3d5e39d..c8ce9f60c 100644 --- a/lib/codegen/allocation.cpp +++ b/lib/codegen/allocation.cpp @@ -29,7 +29,7 @@ void allocation::run(){ std::vector J = I; triples_map_type H; - H.insert({0, segment{0, 100}}); + H.insert({0, segment{0, 1024}}); std::vector V; std::map starts; @@ -116,6 +116,9 @@ void allocation::run(){ for(auto &x: offsets_){ allocated_size_ = std::max(allocated_size_, x.second + get_num_bytes(x.first)); } + std::cout << "Allocated: " << allocated_size_ << std::endl; + for(auto &x: offsets_) + std::cout << x.first->get_name() << " " << x.second << std::endl; } } diff --git a/lib/codegen/barriers.cpp b/lib/codegen/barriers.cpp index b84a945d8..d7b126ee0 100644 --- a/lib/codegen/barriers.cpp +++ b/lib/codegen/barriers.cpp @@ -6,6 +6,7 @@ #include "triton/ir/function.h" #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" +#include "triton/ir/cfg.h" namespace triton { @@ -62,27 +63,76 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) { } } -void barriers::add(ir::basic_block *block, interval_vec_t ¬_synced, ir::builder &builder) { +barriers::interval_vec_t barriers::join(const std::vector& intervals) { + barriers::interval_vec_t result; + for(auto x: intervals) + for(interval_t i: x) + result.push_back(i); + return result; +} + +std::pair barriers::transfer(ir::basic_block *block, + const interval_vec_t &written_to, + const interval_vec_t &read_from, + std::set& insert_loc) { ir::basic_block::inst_list_t instructions = block->get_inst_list(); + interval_vec_t new_written_to = written_to; + interval_vec_t new_read_from = read_from; for(ir::instruction *i: instructions){ interval_vec_t read, written; get_read_intervals(i, read); get_written_intervals(i, written); - if(intersect(not_synced, read)) { - not_synced.clear(); - insert_barrier(i, builder); + bool read_while_written = intersect(new_written_to, read); + bool written_while_read = intersect(new_read_from, written); + // double buffering: write and phi-node read won't intersect + if(dynamic_cast(i) && + buffer_info_->is_double(buffer_info_->get_reference(i))) + written_while_read = false; + if(read_while_written || written_while_read) { + insert_loc.insert(i); + new_written_to.clear(); + new_read_from.clear(); } - std::copy(written.begin(), written.end(), std::back_inserter(not_synced)); + std::copy(written.begin(), written.end(), std::back_inserter(new_written_to)); + std::copy(read.begin(), read.end(), std::back_inserter(new_read_from)); } + return std::make_pair(new_written_to, new_read_from); } void barriers::run(ir::module &mod) { ir::builder &builder = mod.get_builder(); for(ir::function *fn: mod.get_function_list()){ - // find barrier location - interval_vec_t not_synced; - for(ir::basic_block *block: fn->blocks()) - add(block, not_synced, builder); + std::vector rpo = ir::cfg::reverse_post_order(fn); + std::map written_to; + std::map read_from; + std::set insert_locs; + size_t n_inserted_im1 = 0; + bool done = false; + do{ + // find barrier location + for(ir::basic_block *block: rpo){ + // written to + std::vector pred_written_to; + for(ir::basic_block* pred: block->get_predecessors()) + pred_written_to.push_back(written_to[pred]); + // read from + std::vector pred_read_from; + for(ir::basic_block* pred: block->get_predecessors()) + pred_read_from.push_back(read_from[pred]); + // apply transfer function + auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs); + written_to[block] = result.first; + read_from[block] = result.second; + } + size_t n_inserted_i = insert_locs.size(); + done = (n_inserted_im1 == n_inserted_i); + n_inserted_im1 = n_inserted_i; + }while(!done); + for(ir::instruction* i: insert_locs){ + std::cout << i->get_name() << std::endl; + insert_barrier(i, builder); + } } } diff --git a/lib/codegen/buffer_info.cpp b/lib/codegen/buffer_info.cpp index 4d2a3c676..dff371a64 100644 --- a/lib/codegen/buffer_info.cpp +++ b/lib/codegen/buffer_info.cpp @@ -21,6 +21,16 @@ bool buffer_info_pass::is_loop_latch(ir::phi_node *phi, ir::value *terminator){ throw std::runtime_error("unreachable"); } +void buffer_info_pass::replace(ir::value* before, ir::value *after) { + shared_.erase(before); + shared_.insert(after); + if(refs_.find(before) != refs_.end()){ + ir::value* v = refs_.at(before); + refs_.erase(before); + refs_.insert({after, v}); + } +} + void buffer_info_pass::run(ir::module &mod) { // Find which buffers are shared for(ir::function *fn: mod.get_function_list()) diff --git a/lib/codegen/liveness.cpp b/lib/codegen/liveness.cpp index 5e1987b9e..c7c067052 100644 --- a/lib/codegen/liveness.cpp +++ b/lib/codegen/liveness.cpp @@ -11,30 +11,43 @@ namespace codegen{ // Entry point -void liveness::run(ir::module &mod) { -for(ir::function *fn: mod.get_function_list()){ - // Assigns index to each instruction - slot_index index = 0; - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *instr: block->get_inst_list()){ - index += 1; - indices_.insert({instr, index}); - } - // Liveness analysis - // Creates live intervals - for(auto i: indices_){ - ir::value *v = i.first; - if(!info_->is_shared(v) || info_->get_reference(v)) - continue; - unsigned start = i.second; - unsigned end = start; - for(ir::value *u: v->get_users()){ - start = std::min(start, indices_.at(u)); - end = std::max(end, indices_.at(u)); - } - intervals_[v] = segment{start, end}; +inline bool is_shared(ir::value* v) { + if(auto x = dynamic_cast(v)) + return true; + if(auto x = dynamic_cast(v)){ + bool res = true; + for(unsigned inc = 0; inc < x->get_num_incoming(); inc++) + res = res && is_shared(x->get_incoming_value(inc)); + return res; } + return false; } + +void liveness::run(ir::module &mod) { + for(ir::function *fn: mod.get_function_list()){ + // Assigns index to each instruction + slot_index index = 0; + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *instr: block->get_inst_list()){ + index += 1; + indices_.insert({instr, index}); + } + // Liveness analysis + // Creates live intervals + for(auto i: indices_){ + ir::value *v = i.first; + if(!info_->is_shared(v) || info_->get_reference(v)) + continue; + unsigned start = i.second; + unsigned end = start; + for(ir::value *u: v->get_users()){ + start = std::min(start, indices_.at(u)); + end = std::max(end, indices_.at(u)); + } + intervals_[v] = segment{start, end}; + } + std::cout << "Number of intervals: " << intervals_.size() << std::endl; + } } } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index aff4dfbff..7810e6540 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -748,8 +748,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & indices_t b_idx = {idx[1], builder.getInt32(K)}; Value *a = TA->get_value(a_idx); Value *b = TB->get_value(b_idx); -// a = ConstantFP::get(builder.getFloatTy(), 1); -// b = ConstantFP::get(builder.getFloatTy(), 1); res = builder.CreateCall(f_mul_add, {a, b, res}); } result->set_value(idx, res); @@ -846,6 +844,7 @@ void selection::run(ir::module &src, Module &dst) { // create grids init_grids(fn, dst_builder, sh_mem_ptr); + // iterate through block std::map last_block; for(ir::basic_block *block: fn->blocks()) { @@ -854,10 +853,10 @@ void selection::run(ir::module &src, Module &dst) { for(ir::instruction *i: block->get_inst_list()){ BasicBlock *current = dst_builder.GetInsertBlock(); bool phi_inserted = (dynamic_cast(i) || dynamic_cast(i)) && !current->empty(); - if(phi_inserted) - dst_builder.SetInsertPoint(&*current->getFirstInsertionPt()); + if(phi_inserted && current->getFirstNonPHI()) + dst_builder.SetInsertPoint(&*current->getFirstNonPHI()); lower_instruction(i, dst_builder); - if(phi_inserted) + if(phi_inserted && current->getFirstNonPHI()) dst_builder.SetInsertPoint(current); last_block[block] = dst_builder.GetInsertBlock(); } diff --git a/lib/codegen/shared_copy.cpp b/lib/codegen/shared_copy.cpp index ce6f53fbe..6c05b7807 100644 --- a/lib/codegen/shared_copy.cpp +++ b/lib/codegen/shared_copy.cpp @@ -28,6 +28,12 @@ void place_shared_copy::run(ir::module &mod) { for(ir::instruction *i: block->get_inst_list()) if(info_->is_shared(i) && !info_->is_double(i)) add_copy(i, builder); + + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i: block->get_inst_list()) + if(auto* cts = dynamic_cast(i)) + info_->replace(cts->get_operand(0), cts); } } diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index 6e3533983..8792495a2 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -109,6 +109,10 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple llvm::SmallVectorImpl &buffer, std::vector paths) { init_llvm(); +// llvm::legacy::PassManager passes; +// passes.add(llvm::createPrintModulePass(llvm::outs())); +// passes.add(llvm::createVerifierPass()); +// passes.run(*module); // create machine module->setTargetTriple(triple); std::string error; diff --git a/lib/ir/basic_block.cpp b/lib/ir/basic_block.cpp index 456f0f820..0654156a3 100644 --- a/lib/ir/basic_block.cpp +++ b/lib/ir/basic_block.cpp @@ -21,6 +21,8 @@ basic_block* basic_block::create(context &ctx, const std::string &name, function void basic_block::add_predecessor(basic_block *pred) { preds_.push_back(pred); + if(pred) + pred->succs_.push_back(this); }