diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index ba92843a4..961aea725 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -119,7 +119,7 @@ private: Type *make_vector_ty(Type *ty, size_t vector_size); public: - distributed_tile(Type *ty, const shapes_t& shapes, const axes_t &axes, Builder &builder, bool vectorize); + distributed_tile(Type *ty, const shapes_t& shapes, const std::vector& order, const axes_t &axes, Builder &builder, bool vectorize); void set_value(indices_t idx, Value *v); Value* get_value(indices_t idx); unsigned get_linear_index(indices_t idx); @@ -129,6 +129,7 @@ public: private: axes_t axes_; + std::vector order_; indices_map_t indices_; values_map_t values_; ordered_indices_vec_t ordered_indices_; diff --git a/include/triton/codegen/transform/coalesce.h b/include/triton/codegen/transform/coalesce.h index 9f0576af0..3d418fdb5 100644 --- a/include/triton/codegen/transform/coalesce.h +++ b/include/triton/codegen/transform/coalesce.h @@ -11,6 +11,8 @@ namespace ir { class module; class value; class io_inst; + class instruction; + class builder; } namespace codegen{ @@ -27,6 +29,7 @@ class coalesce { private: void extract_io_use(ir::value *v, std::set& result); void extract_ld(ir::io_inst *i, std::map > &result); + ir::value* rematerialize(ir::value *v, ir::builder& builder, std::map& seen); public: coalesce(analysis::align* align, triton::codegen::analysis::layout *layouts, analysis::meminfo* mem); diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index 99fc59234..3949a03db 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -158,6 +158,7 @@ void axes::run(ir::module &mod) { unsigned group_id = 0; while(!nodes_.empty()) connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++); + std::cout << "Number of axes: " << group_id << std::endl; } } diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc index d1b26a6f9..3ee256550 100644 --- a/lib/codegen/analysis/tiles.cc +++ b/lib/codegen/analysis/tiles.cc @@ -190,6 +190,8 @@ void tiles::run(ir::module &) { ); } order_[i] = order; + std::cout << "order: " << order[0] << " " << order[1] << std::endl; + } // tiling parameters for(auto x: largest_){ diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index ca93bc917..79c8214c7 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -1,4 +1,5 @@ -#include "triton/codegen/selection.h" +#include +#include "triton/codegen/selection.h" #include "triton/codegen/target.h" #include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/axes.h" @@ -28,6 +29,14 @@ using namespace llvm; /* Distributed Tile */ void distributed_tile::init_indices() { std::vector id(axes_.size(), 0); + // create iteration order + std::vector order(id.size()); + std::iota(order.begin(), order.end(), 0); + auto cmp = [&](int x, int y) { + return axes_[x].contiguous > axes_[y].contiguous; + }; + std::sort(order.begin(), order.end(), cmp); + // build size_t k = 0; while(true) { indices_t current; @@ -37,12 +46,12 @@ void distributed_tile::init_indices() { indices_[current] = sz; values_[current] = nullptr; ordered_indices_.push_back(current); - id[0]++; - while(id[k] == axes_[k].values.size()){ + id[order[0]]++; + while(id[order[k]] == axes_[order[k]].values.size()){ if(k == id.size() - 1) return; - id[k++] = 0; - id[k]++; + id[order[k++]] = 0; + id[order[k]]++; } k = 0; } @@ -54,8 +63,8 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) return VectorType::get(ty, vector_size); } -distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize) - : tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), builder_(builder) { +distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector& order, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize) + : tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), order_(order), builder_(builder) { vector_size_ = vectorize?ty_->getVectorNumElements():1; init_indices(); } @@ -767,7 +776,7 @@ void selection::create_shared_tile(ir::value *v, IRBuilder<> &builder, Value *sh for(ir::user *usr: v->get_users()) if(dynamic_cast(usr)) has_phi_user = true; - if(has_phi_user){ + if(!has_phi_user){ size_t offset = alloc_->offset(v); Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); ptr = builder.CreateBitCast(ptr, ptr_ty); @@ -791,7 +800,7 @@ void selection::create_distributed_tile(ir::value *v, IRBuilder<> &builder) { } } bool vectorize = dynamic_cast(v); - distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize); + distributed_tile *T = new distributed_tile(ty, shapes, tiles_->order(v), axes, builder, vectorize); bool is_inserted = tmap_.insert({v, T}).second; // constant range if(is_inserted && dynamic_cast(v)){ @@ -1260,8 +1269,9 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun // find vector size distributed_tile* result = (distributed_tile*)tmap_.at(x); ir::value *ptr = x->get_pointer_operand(); - unsigned alignment = alignment_->get(ptr, 0); - unsigned vector_size = std::min(result->axis(0).contiguous, alignment); + size_t ld = tiles_->order(ptr)[0]; + unsigned alignment = alignment_->get(ptr, ld); + unsigned vector_size = std::min(result->axis(ld).contiguous, alignment); distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand()); distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand()); @@ -1331,8 +1341,9 @@ void selection::lower_load(ir::load_inst *x, LLVMContext &ctx, Function *fn, IRB distributed_tile* result = (distributed_tile*)tmap_.at(x); // find vector size ir::value *ptr = x->get_pointer_operand(); - unsigned alignment = alignment_->get(ptr, 0); - unsigned vector_size = std::min(result->axis(0).contiguous, alignment); + size_t ld = tiles_->order(ptr)[0]; + unsigned alignment = alignment_->get(ptr, ld); + unsigned vector_size = std::min(result->axis(ld).contiguous, alignment); distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); // vector loads std::map packets; diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 825b6adf6..0e435e663 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -35,6 +35,31 @@ void coalesce::extract_ld(ir::io_inst* i, std::map& seen) { + if(seen.find(x) != seen.end()) + return seen.at(x); + auto i = dynamic_cast(x); + // not an instruction -- forward value + if(!i) + return x; + // already in shared memory -- forward value + if(dynamic_cast(x)){ + return x; + } + // set insert point + auto& inst_list = i->get_parent()->get_inst_list(); + auto pos = ++std::find(inst_list.begin(), inst_list.end(), i); + builder.set_insert_point(pos); + // default -- recursive clone + ir::instruction *cloned = builder.insert(i->clone()); + seen[i] = cloned; + // rematerialize operands + for(ir::value *op: cloned->ops()) + cloned->replace_uses_of_with(op, rematerialize(op, builder, seen)); + return cloned; +} + void coalesce::run(ir::module &mod) { // find values to rematerialize size_t num_groups = layout_->get_num_groups(); @@ -56,54 +81,21 @@ void coalesce::run(ir::module &mod) { remat.insert(remat.begin(), it->second.begin(), it->second.end()); } - // rematerialize values - ir::builder &builder = mod.get_builder(); for(ir::io_inst *r: remat) { - std::list> work_list; - std::map replaced; - work_list.push_back({r, nullptr}); - // rematerialize recursively - while(!work_list.empty()) { - auto pair = work_list.back(); - ir::instruction* cloned = pair.first; - ir::instruction* original = pair.second; - work_list.pop_back(); - for(ir::value *op: cloned->ops()) { - ir::instruction* i_op = dynamic_cast(op); - if(replaced.find(i_op) != replaced.end()){ - cloned->replace_uses_of_with(i_op, replaced.at(i_op)); - continue; - } - if(!i_op) - continue; - ir::type *ty = i_op->get_type(); - if(!ty->is_tile_ty()) - continue; - auto& inst_list = i_op->get_parent()->get_inst_list(); - auto it = std::find(inst_list.begin(), inst_list.end(), i_op); - it++; - builder.set_insert_point(it); - // found a load; write to shared memory and stop recursion - ir::instruction *n_op = nullptr; - if(mem_->is_shared(i_op)){ - i_op->add_use(cloned); - continue; - } - if(auto* ld = dynamic_cast(i_op)) - n_op = ir::copy_to_shared_inst::create(ld); - // not a load; rematerialize and add to worklist - else { - n_op = i_op->clone(); - work_list.push_back({n_op, i_op}); - } - n_op = builder.insert(n_op); - replaced.insert({i_op, n_op}); - mem_->copy(n_op, i_op); - if(original) - n_op->erase_use(original); - cloned->replace_uses_of_with(i_op, n_op); - } + ir::builder& builder = mod.get_builder(); + // rematerialize operands + std::map seen; + for(ir::value *op: r->ops()) + rematerialize(op, mod.get_builder(), seen); + // copy to shared if load + auto& inst_list = r->get_parent()->get_inst_list(); + auto pos = ++std::find(inst_list.begin(), inst_list.end(), r); + builder.set_insert_point(pos); + if(dynamic_cast(r)){ + ir::instruction *cts = builder.insert(ir::copy_to_shared_inst::create(r)); + r->replace_all_uses_with(cts); + cts->replace_uses_of_with(cts, r); } } } diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 1dcf4d738..d541b4d6c 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -92,10 +92,10 @@ void module::compile_llvm_module(std::unique_ptr module, const std file_type_t ft) { init_llvm(); // debug -// llvm::legacy::PassManager pm; -// pm.add(llvm::createPrintModulePass(llvm::outs())); + llvm::legacy::PassManager pm; + pm.add(llvm::createPrintModulePass(llvm::outs())); // pm.add(llvm::createVerifierPass()); -// pm.run(*module); + pm.run(*module); // create machine module->setTargetTriple(triple); std::string error; @@ -241,7 +241,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, cu_module::cu_module(driver::context * context, std::unique_ptr ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ -// std::cout << source_ << std::endl; + std::cout << source_ << std::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/lib/ir/print.cc b/lib/ir/print.cc index af2c68a2e..f88ba1f6f 100644 --- a/lib/ir/print.cc +++ b/lib/ir/print.cc @@ -48,8 +48,10 @@ void print(module &mod, std::ostream& os) { os << std::endl; for(ir::instruction *inst: block->get_inst_list()){ os << " "; - os << get_name(inst, cnt++); - os << " = "; + if(!inst->get_type()->is_void_ty()){ + os << get_name(inst, cnt++); + os << " = "; + } ir::type* type = inst->get_type(); os << inst->repr() << " " << type->repr(); ir::instruction::ops_t ops = inst->ops(); @@ -65,7 +67,6 @@ void print(module &mod, std::ostream& os) { } os << ";" << std::endl; } - os << std::endl; } os << "}" << std::endl; } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 5e40b4419..7908f8ec7 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -221,6 +221,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c axes.run(module); layouts.run(module); coalesce.run(module); +// ir::print(module, std::cout); align.run(module); dce.run(module); tiles.run(module); diff --git a/tests/bench/copy2d.cc b/tests/bench/copy2d.cc index 69e877767..c3433b2e2 100644 --- a/tests/bench/copy2d.cc +++ b/tests/bench/copy2d.cc @@ -48,7 +48,7 @@ int main() { std::vector configs; for(auto x: std::vector{COLMAJOR}){ std::vector tmp = { - config_t{2048, 2048, x} + config_t{4096, 4096, x} }; configs.insert(configs.end(), tmp.begin(), tmp.end()); } diff --git a/tests/common/src/copy.h b/tests/common/src/copy.h index 2a7dc0627..58651a84f 100644 --- a/tests/common/src/copy.h +++ b/tests/common/src/copy.h @@ -38,7 +38,7 @@ void copy2d(TYPE * X __noalias __readonly __aligned(16), int rm[TM] = ridm * TM + 0 ... TM; int rn[TN] = ridn * TN + 0 ... TN; TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; - TYPE* py[TM, TN] = Y + rm[:, newaxis] * ldy + rn[newaxis, :]; + TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldy; *py = *px; } )";