diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 3bc1f2f6a..7bc14b08f 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -24,6 +24,8 @@ class layout { typedef std::map > graph_t; private: + // create edge + void connect(ir::value *x, ir::value *y); // connected components void connected_components(node_t x, std::set &nodes, graph_t &graph, unsigned id); // list the axes of the given value diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index 3949a03db..99fc59234 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -158,7 +158,6 @@ void axes::run(ir::module &mod) { unsigned group_id = 0; while(!nodes_.empty()) connected_components(*nodes_.begin(), nodes_, dependencies_, group_id++); - std::cout << "Number of axes: " << group_id << std::endl; } } diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index a6eade0b2..0f376b4fc 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -53,6 +53,27 @@ const std::vector& layout::values(unsigned id) const size_t layout::get_num_groups() const { return values_.size(); } +void layout::connect(ir::value *x, ir::value *y) { + if(x == y) + return; + if(!x->get_type()->is_tile_ty()) + return; + if(!y->get_type()->is_tile_ty()) + return; + std::set x_axes = axes_of(x); + std::set y_axes = axes_of(y); + std::set common; + std::set_intersection(x_axes.begin(), x_axes.end(), + y_axes.begin(), y_axes.end(), + std::inserter(common, common.begin())); + if(!common.empty()){ + nodes_.insert(x); + nodes_.insert(y); + dependencies_[x].insert(y); + dependencies_[y].insert(x); + } +} + // run void layout::run(ir::module &mod) { nodes_.clear(); @@ -63,26 +84,12 @@ void layout::run(ir::module &mod) { for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i : block->get_inst_list()) { - // skip scalars - if(!i->get_type()->is_tile_ty()) - continue; - // add an edge between i and the operands that share an axis - std::set i_axes = axes_of(i); - nodes_.insert(i); - for(ir::value* op: i->ops()){ - if(!op->get_type()->is_tile_ty()) - continue; - nodes_.insert(op); - std::set op_axes = axes_of(op); - std::set common; - std::set_intersection(i_axes.begin(), i_axes.end(), - op_axes.begin(), op_axes.end(), - std::inserter(common, common.begin())); - if(!common.empty() || !op->get_type()->is_tile_ty()){ - dependencies_[i].insert(op); - dependencies_[op].insert(i); + for(ir::value* opx: i->ops()) + for(ir::value* opy: i->ops()){ + connect(i, opx); + connect(opx, opy); } - } + } // Grids unsigned group_id = 0; diff --git a/lib/codegen/analysis/tiles.cc b/lib/codegen/analysis/tiles.cc index 3ee256550..d1b26a6f9 100644 --- a/lib/codegen/analysis/tiles.cc +++ b/lib/codegen/analysis/tiles.cc @@ -190,8 +190,6 @@ void tiles::run(ir::module &) { ); } order_[i] = order; - std::cout << "order: " << order[0] << " " << order[1] << std::endl; - } // tiling parameters for(auto x: largest_){ diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 762fd90db..c6592a59c 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -1035,11 +1035,17 @@ void selection::lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Functio } void selection::lower_copy_to_shared(ir::copy_to_shared_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { - shared_tile* result = (shared_tile*)tmap_.at(x); + unsigned vector_size = 1; + auto x_order = tiles_->order(x); ir::value *arg = x->get_operand(0); + auto arg_order = tiles_->order(arg); + // tiles + shared_tile* result = (shared_tile*)tmap_.at(x); distributed_tile* in = (distributed_tile*)tmap_.at(arg); - size_t ld = tiles_->order(arg)[0]; - unsigned vector_size = in->axis(ld).contiguous; + if(x_order == arg_order){ + size_t ld = arg_order[0]; + vector_size = std::min(tiles_->nts(x, ld),tiles_->nts(arg, ld)); + } std::map packets; in->for_each([&](indices_t idx){ diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index b0d1a3521..873f7a9f5 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -51,6 +51,11 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder, auto& inst_list = i->get_parent()->get_inst_list(); auto pos = ++std::find(inst_list.begin(), inst_list.end(), i); builder.set_insert_point(pos); + if(dynamic_cast(x)){ + ir::value *ret = builder.insert(ir::copy_to_shared_inst::create(x)); +// x->replace_all_uses_with(ret); + return ret; + } // default -- recursive clone ir::instruction *cloned = builder.insert(i->clone()); seen[i] = cloned; @@ -97,6 +102,9 @@ void coalesce::run(ir::module &mod) { r->replace_all_uses_with(cts); cts->replace_uses_of_with(cts, r); } + else{ + + } } } diff --git a/lib/driver/module.cc b/lib/driver/module.cc index d541b4d6c..0bf85c84f 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -92,10 +92,10 @@ void module::compile_llvm_module(std::unique_ptr module, const std file_type_t ft) { init_llvm(); // debug - llvm::legacy::PassManager pm; - pm.add(llvm::createPrintModulePass(llvm::outs())); +// llvm::legacy::PassManager pm; +// pm.add(llvm::createPrintModulePass(llvm::outs())); // pm.add(llvm::createVerifierPass()); - pm.run(*module); +// pm.run(*module); // create machine module->setTargetTriple(triple); std::string error; @@ -241,7 +241,6 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, cu_module::cu_module(driver::context * context, std::unique_ptr ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ - std::cout << source_ << std::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 7db1e1af1..04977966d 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -220,7 +220,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c axes.run(module); layouts.run(module); coalesce.run(module); -// ir::print(module, std::cout); + dce.run(module); align.run(module); dce.run(module); tiles.run(module); diff --git a/tests/bench/copy2d.cc b/tests/bench/copy2d.cc index c3433b2e2..6ee7f5496 100644 --- a/tests/bench/copy2d.cc +++ b/tests/bench/copy2d.cc @@ -11,19 +11,21 @@ #include "cuda/cublas.h" -std::vector do_bench(drv::stream* stream, int32_t M, int32_t N, order_t order){ +std::vector do_bench(drv::stream* stream, int32_t M, int32_t N, order_t order_x, order_t order_y){ typedef float NumericT; std::string ty = "float"; size_t dt_nbytes = sizeof(NumericT); drv::context* context = stream->context(); - int32_t ld = order == ROWMAJOR ? N : M; // create inputs auto dx = std::unique_ptr(drv::buffer::create(context, M*N*dt_nbytes)); auto dy = std::unique_ptr(drv::buffer::create(context, M*N*dt_nbytes)); // create options rt::function::options_space_t opt; opt.defines.push_back({"TYPE", {ty}}); - opt.defines.push_back({"ORDER", {order==ROWMAJOR?"ROWMAJOR":"COLMAJOR"}}); + opt.defines.push_back({"STRIDE_XM", {(order_x == ROWMAJOR)?"M":"1"}}); + opt.defines.push_back({"STRIDE_XN", {(order_x == ROWMAJOR)?"1":"N"}}); + opt.defines.push_back({"STRIDE_YM", {(order_y == ROWMAJOR)?"M":"1"}}); + opt.defines.push_back({"STRIDE_YN", {(order_y == ROWMAJOR)?"1":"N"}}); opt.defines.push_back({"TM", {"32"}}); opt.defines.push_back({"TN", {"32"}}); opt.num_warps = {4}; @@ -33,7 +35,7 @@ std::vector do_bench(drv::stream* stream, int32_t M, int32_t N, order_t std::vector result; auto gbps = [&](double ns) { return 2*M*N*dt_nbytes / (ns * 1e-9) * 1e-9; }; // triton - double triton_ns = triton::tools::bench([&]() { function({&*dx, &*dy, M, N, ld, ld}, grid2d(M, N), stream);}, stream); + double triton_ns = triton::tools::bench([&]() { function({&*dx, &*dy, M, N}, grid2d(M, N), stream);}, stream); result.push_back(gbps(triton_ns)); // done return result; @@ -44,21 +46,20 @@ int main() { auto context = triton::driver::backend::contexts::get_default(); triton::driver::stream* stream = triton::driver::stream::create(context); // shapes to benchmark - typedef std::tuple config_t; - std::vector configs; - for(auto x: std::vector{COLMAJOR}){ - std::vector tmp = { - config_t{4096, 4096, x} - }; - configs.insert(configs.end(), tmp.begin(), tmp.end()); - } + typedef std::tuple config_t; + std::vector configs = { + {4096, 4096, ROWMAJOR, ROWMAJOR}, + {4096, 4096, COLMAJOR, ROWMAJOR}, + {4096, 4096, ROWMAJOR, COLMAJOR}, + {4096, 4096, COLMAJOR, COLMAJOR}, + }; // does the work int32_t M, N; - order_t ord; + order_t ord_x, ord_y; for(const auto& c: configs){ - std::tie(M, N, ord) = c; - std::cout << "// " << M << ", " << N << ", " << ord << std::flush; - for(auto perf: do_bench(stream, M, N, ord)) + std::tie(M, N, ord_x, ord_y) = c; + std::cout << "// " << M << ", " << N << ", " << ord_x << ", " << ord_y << std::flush; + for(auto perf: do_bench(stream, M, N, ord_x, ord_y)) std::cout << ", " << perf << std::flush; std::cout << std::endl; } diff --git a/tests/common/src/copy.h b/tests/common/src/copy.h index b1d571b51..8b0f5d9dc 100644 --- a/tests/common/src/copy.h +++ b/tests/common/src/copy.h @@ -16,29 +16,16 @@ void copy1d(TYPE * X __noalias __readonly __aligned(16), const char *copy2d = R"( -#if ORDER == ROWMAJOR -#define STRIDE_XM ldx -#define STRIDE_XN 1 -#define STRIDE_YM ldy -#define STRIDE_YN 1 -#else -#define STRIDE_XM 1 -#define STRIDE_XN ldx -#define STRIDE_YM 1 -#define STRIDE_YN ldy -#endif - void copy2d(TYPE * X __noalias __readonly __aligned(16), TYPE * Y __noalias __writeonly __aligned(16), - int M, int N, - int ldx __multipleof(8), - int ldy __multipleof(8)) { + int M __multipleof(8), + int N __multipleof(8)) { int ridm = get_program_id(0); int ridn = get_program_id(1); int rm[TM] = ridm * TM + 0 ... TM; int rn[TN] = ridn * TN + 0 ... TN; - TYPE* px[TM, TN] = X + rm[:, newaxis] * ldx + rn[newaxis, :] ; - TYPE* py[TM, TN] = Y + rm[:, newaxis] + rn[newaxis, :] * ldy; + TYPE* px[TM, TN] = X + rm[:, newaxis] * STRIDE_XM + rn[newaxis, :] * STRIDE_XN; + TYPE* py[TM, TN] = Y + rm[:, newaxis] * STRIDE_YM + rn[newaxis, :] * STRIDE_YN; *py = *px; } )";