From d52abc93799be0873018e32dbd6fe9c018a2734a Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Tue, 25 Jun 2019 15:06:15 -0700
Subject: [PATCH] [codegen] bugfix in alignment inference

---
 examples/python/tensorflow/dot.cpp |   6 +-
 include/triton/codegen/axis_info.h |  12 +-
 include/triton/codegen/selection.h |   6 +-
 include/triton/runtime/jit.h       |   6 +-
 lib/codegen/axis_info.cpp          | 183 +++++++++++++++++++++++------
 lib/codegen/selection.cpp          |   7 +-
 lib/driver/module.cpp              |   2 +-
 7 files changed, 170 insertions(+), 52 deletions(-)

diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp
index 0a8e1e948..8dea2337a 100644
--- a/examples/python/tensorflow/dot.cpp
+++ b/examples/python/tensorflow/dot.cpp
@@ -26,9 +26,9 @@ const tunable int32 TN = {64, 128};
 const tunable int32 TK = {16};
 const tunable int32 GZ = {1};
 
-void matmul(restrict read_only align(4) fp16 *A,
-            restrict read_only align(4) fp16 *B,
-            align(4) fp32 *C,
+void matmul(restrict read_only align(16) fp16 *A,
+            restrict read_only align(16) fp16 *B,
+            align(16) fp32 *C,
             int32 M, int32 N, int32 K,
             multiple_of(4) int32 lda, multiple_of(4) int32 ldb, multiple_of(4) int32 ldc,
             int32 *locks, int32 grid0, int32 grid1) {
diff --git a/include/triton/codegen/axis_info.h b/include/triton/codegen/axis_info.h
index bfc4ef322..9b44b01c7 100644
--- a/include/triton/codegen/axis_info.h
+++ b/include/triton/codegen/axis_info.h
@@ -16,20 +16,22 @@ namespace codegen{
 class axis_info {
 private:
   // helpers
-  bool is_first_axis_unit(ir::value *x);
+  bool is_first_axis_unit(ir::value *v);
   // populate maps
-  bool populate_is_constant(ir::value *i);
-  unsigned populate_max_contiguous(ir::value *i);
-  unsigned populate_multiple_of(ir::value *i);
+  bool populate_is_constant(ir::value *v);
+  unsigned populate_max_contiguous(ir::value *v);
+  unsigned populate_starting_multiple(ir::value *v);
 
 public:
   void run(ir::module &mod);
+  unsigned get_starting_multiple(ir::value* v) const;
+  unsigned get_max_contiguous(ir::value* v) const;
 
 private:
   std::map<ir::value*, bool> is_constant_;
   std::map<ir::value*, unsigned> max_contiguous_;
-  std::map<ir::value*, unsigned> multiple_of_;
+  std::map<ir::value*, unsigned> starting_multiple_;
 };
 
 
diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h
index 9a8149a01..7ad586058 100644
--- a/include/triton/codegen/selection.h
+++ b/include/triton/codegen/selection.h
@@ -25,6 +25,7 @@ class shmem_allocation;
 class tune;
 class shmem_info;
 class target;
+class axis_info;
 
 typedef std::vector indices_t;
 
@@ -143,8 +144,8 @@ private:
   void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder);
 
 public:
-  selection(shmem_allocation *alloc, tune *params, shmem_info *buffer_info, target *tgt)
-    : alloc_(alloc), params_(params), buffer_info_(buffer_info), tgt_(tgt){ }
+  selection(shmem_allocation *alloc, tune *params, shmem_info *buffer_info, axis_info *ax_info, target *tgt)
+    : alloc_(alloc), params_(params), buffer_info_(buffer_info), axis_info_(ax_info), tgt_(tgt){ }
 
   void run(ir::module &src, llvm::Module &dst);
 
@@ -157,6 +158,7 @@ private:
   tune *params_;
   target *tgt_;
   shmem_info *buffer_info_;
+  axis_info *axis_info_;
   std::map axes_;
   llvm::Value *sh_mem_ptr_;
   llvm::Value *offset_a_i_, *offset_a_k_;
diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h
index c12bb6b23..3b8aa606c 100644
--- a/include/triton/runtime/jit.h
+++ b/include/triton/runtime/jit.h
@@ -57,7 +57,7 @@ public:
       shmem_allocation(&shmem_liveness, &shmem_info, &tune),
       shmem_barriers(&shmem_allocation, &shmem_info),
       vectorize(&tune),
-      selection(&shmem_allocation, &tune, &shmem_info, target),
+      selection(&shmem_allocation, &tune, &shmem_info, &axis_info, target),
       optimize_dot(&tune),
       optimize_cse(),
       optimize_trans(),
@@ -67,11 +67,11 @@ public:
   void target_independent(ir::module &module) {
     optimize_dot.run(module);
     optimize_trans.run(module);
-    axis_info.run(module);
-//    ir::print(module, std::cout);
+    ir::print(module, std::cout);
   }
 
   void target_dependent(ir::module &module) {
+    axis_info.run(module);
     if(target_->is_gpu()){
       shmem_info.run(module);
       shmem_liveness.run(module);
diff --git a/lib/codegen/axis_info.cpp b/lib/codegen/axis_info.cpp
index 38e2fcd9b..be2a16c91 100644
--- a/lib/codegen/axis_info.cpp
+++ b/lib/codegen/axis_info.cpp
@@ -11,7 +11,7 @@ namespace codegen{
 
 template<class T>
 inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
-  return map.insert(std::make_pair(i, value)).first->second;
+  return map[i] = value;
 }
 
 
@@ -23,63 +23,132 @@ bool axis_info::is_first_axis_unit(ir::value *x){
 }
 
 bool axis_info::populate_is_constant(ir::value *v) {
+  if(is_constant_.find(v) != is_constant_.end())
+    return is_constant_.at(v);
   // helper for the cache
   auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); };
   // populate
-  if(v->get_type()->is_tile_ty()){
-    if(auto *x = dynamic_cast(v)){
-      bool value = populate_is_constant(x->get_operand(0));
-      // check if broadcast (i.e., constant) along contiguous dimension
-      if(is_first_axis_unit(x->get_operand(0))
-         && !is_first_axis_unit(x))
-        return cache(value);
-    }
-    // otherwise the tile is not constant in the contiguous dimension
+  if(auto *x = dynamic_cast(v)){
+    ir::value *op = x->get_operand(0);
+    populate_is_constant(op);
+    if(is_first_axis_unit(op))
+      return cache(true);
+  }
+  if(auto *x = dynamic_cast(v)){
+    bool lhs = populate_is_constant(x->get_operand(0));
+    bool rhs = populate_is_constant(x->get_operand(1));
+    return cache(lhs && rhs);
+  }
+  if(v->get_type()->is_tile_ty())
     return cache(false);
+  if(auto *x = dynamic_cast(v)){
+    // put a conservative initial value in phi node to avoid infinite recursion
+    bool result = true;
+    for(unsigned n = 0; n < x->get_num_incoming(); n++){
+      ir::value* inc = x->get_incoming_value(n);
+      if(is_constant_.find(inc) != is_constant_.end())
+        result = is_constant_.at(inc);
+    }
+    cache(result);
+    // recurse
+    for(unsigned n = 0; n < x->get_num_incoming(); n++){
+      ir::value* inc = x->get_incoming_value(n);
+      result = result && populate_is_constant(inc);
+    }
+    return cache(result);
   }
   // scalars are always constant in the contiguous dimension
   return cache(true);
 }
 
 unsigned axis_info::populate_max_contiguous(ir::value *v){
+  if(max_contiguous_.find(v) != max_contiguous_.end())
+    return max_contiguous_.at(v);
   // helper for the cache
   auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
   // populate
-  if(v->get_type()->is_tile_ty()){
-    auto shapes = v->get_type()->get_tile_shapes();
-    if(dynamic_cast(v))
-      return cache(shapes[0]->get_value());
-    if(auto *x = dynamic_cast(v)){
-      ir::value* lhs = x->get_operand(0);
-      ir::value* rhs = x->get_operand(1);
-      unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
-      bool lhs_has_cst = populate_is_constant(lhs);
-      unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
-      bool rhs_has_cst = populate_is_constant(rhs);
-      if(x->is_int_add_sub()){
-        if(lhs_has_cst)
-          return cache(rhs_max_contiguous);
-        if(rhs_has_cst)
-          return cache(lhs_max_contiguous);
-      }
+  if(!v->get_type()->is_tile_ty())
+    return cache(1);
+  auto shapes = v->get_type()->get_tile_shapes();
+  if(dynamic_cast(v))
+    return cache(shapes[0]->get_value());
+  if(dynamic_cast(v))
+    return cache(shapes[0]->get_value());
+  if(auto *x = dynamic_cast(v)){
+    ir::value *op = x->get_operand(0);
+    if(op->get_type()->is_tile_ty()){
+      auto op_shapes = op->get_type()->get_tile_shapes();
+      if(op_shapes[0] == shapes[0])
+        return cache(populate_max_contiguous(op));
     }
+    return cache(1);
+  }
+  if(auto *x = dynamic_cast(v)){
+    ir::value* lhs = x->get_operand(0);
+    ir::value* rhs = x->get_operand(1);
+    unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
+    unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
+    bool lhs_has_cst = populate_is_constant(lhs);
+    bool rhs_has_cst = populate_is_constant(rhs);
+    if(x->is_int_add_sub()){
+      if(lhs_has_cst)
+        return cache(rhs_max_contiguous);
+      if(rhs_has_cst)
+        return cache(lhs_max_contiguous);
+    }
+  }
+  if(auto *x = dynamic_cast(v)){
+    ir::value* lhs = x->get_operand(0);
+    ir::value* rhs = x->get_operand(1);
+    unsigned lhs_max_contiguous = populate_max_contiguous(lhs);
+    unsigned rhs_max_contiguous = populate_max_contiguous(rhs);
+    bool lhs_has_cst = populate_is_constant(lhs);
+    bool rhs_has_cst = populate_is_constant(rhs);
+    if(lhs_has_cst)
+      return cache(rhs_max_contiguous);
+    if(rhs_has_cst)
+      return cache(lhs_max_contiguous);
+  }
+  if(auto *x = dynamic_cast(v)){
+    // put a conservative initial value in phi node to avoid infinite recursion
+    unsigned result = 1;
+    for(unsigned n = 0; n < x->get_num_incoming(); n++){
+      ir::value* inc = x->get_incoming_value(n);
+      if(max_contiguous_.find(inc) != max_contiguous_.end())
+        result = max_contiguous_.at(inc);
+    }
+    cache(result);
+    // recurse
+    for(unsigned n = 0; n < x->get_num_incoming(); n++){
+      ir::value* inc = x->get_incoming_value(n);
+      result = std::min(result, populate_max_contiguous(inc));
+    }
+    return cache(result);
   }
   return cache(1);
 }
 
-unsigned axis_info::populate_multiple_of(ir::value *v){
-  auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); };
-
+unsigned axis_info::populate_starting_multiple(ir::value *v){
+  if(starting_multiple_.find(v) != starting_multiple_.end())
+    return starting_multiple_.at(v);
+  auto cache = [this,v](unsigned value){ return add_to_cache(v, value, starting_multiple_); };
+  // arguments
   if(auto *x = dynamic_cast(v)){
     std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
     for(auto attr: attributes){
       if(attr.get_kind() == ir::multiple_of)
         return cache(attr.get_value());
+      if(attr.get_kind() == ir::aligned){
+        ir::type* ty = x->get_type()->get_pointer_element_ty();
+        int nbits = ty->get_primitive_size_in_bits();
+        int nbytes = nbits / 8;
+        return cache(attr.get_value() / nbytes);
+      }
     }
   }
   if(auto *x = dynamic_cast(v)){
-    int lhs = populate_multiple_of(x->get_operand(0));
-    int rhs = populate_multiple_of(x->get_operand(1));
+    int lhs = populate_starting_multiple(x->get_operand(0));
+    int rhs = populate_starting_multiple(x->get_operand(1));
     if(x->is_int_mult())
       return cache(lhs * rhs);
     if(x->is_int_add_sub())
@@ -93,12 +162,52 @@ unsigned axis_info::populate_multiple_of(ir::value *v){
     if(x->is_shr())
       return cache(std::max(lhs >> rhs, 1));
   }
-  if(auto *x = dynamic_cast(v)){
-    return cache(populate_multiple_of(x->get_operand(0)));
+  if(auto *x = dynamic_cast(v)){
+    int lhs = populate_starting_multiple(x->get_operand(0));
+    int rhs = populate_starting_multiple(x->get_operand(1));
+    return cache(std::min(lhs, rhs));
   }
-  return cache(1);
+  if(auto *x = dynamic_cast(v)){
+    int op = populate_starting_multiple(x->get_operand(0));
+    return cache(op);
+  }
+  if(auto *x = dynamic_cast(v)){
+    return cache(v->get_type()->get_tile_shapes()[0]->get_value());
+  }
+  if(auto *x = dynamic_cast(v)){
+    // put a conservative initial value in phi node to avoid infinite recursion
+    unsigned result = 1;
+    for(unsigned n = 0; n < x->get_num_incoming(); n++){
+      ir::value* inc = x->get_incoming_value(n);
+      if(starting_multiple_.find(inc) != starting_multiple_.end())
+        result = starting_multiple_.at(inc);
+    }
+    cache(result);
+    // recurse
+    for(unsigned n = 0; n < x->get_num_incoming(); n++){
+      ir::value* inc = x->get_incoming_value(n);
+      result = std::min(result, populate_starting_multiple(inc));
+    }
+    return cache(result);
+  }
+  // scalars
+  if(!v->get_type()->is_tile_ty())
+    return cache(1);
+  // tiles
+  auto shapes = v->get_type()->get_tile_shapes();
+  unsigned result = 1;
+  for(unsigned i = 0; i < shapes.size() - 1; i++)
+    result *= shapes[i]->get_value();
+  return cache(result);
 }
 
+unsigned axis_info::get_starting_multiple(ir::value* v) const {
+  return starting_multiple_.at(v);
+}
+
+unsigned axis_info::get_max_contiguous(ir::value* v) const {
+  return max_contiguous_.at(v);
+}
 
 void axis_info::run(ir::module &mod) {
@@ -109,11 +218,11 @@ void axis_info::run(ir::module &mod) {
       populate_is_constant(i);
   }
 
-  // populate multiple_of
+  // populate starting multiple
   for(ir::function *fn: mod.get_function_list())
   for(ir::basic_block *block: fn->blocks())
   for(ir::instruction *i: block->get_inst_list()){
-      populate_multiple_of(i);
+      populate_starting_multiple(i);
   }
 
   // populate maximum contiguous
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index ae2a3f1c1..e98f0bcb0 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -2,6 +2,7 @@
 #include "triton/codegen/tune.h"
 #include "triton/codegen/shmem_allocation.h"
 #include "triton/codegen/target.h"
+#include "triton/codegen/axis_info.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/IRBuilder.h"
@@ -1027,7 +1028,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
     }
   }
   else if(auto *ld = dynamic_cast(ins)){
-    unsigned vector_size = result->axis(0).contiguous;
+    ir::value *ptr = ld->get_pointer_operand();
+    unsigned starting_multiple = axis_info_->get_starting_multiple(ptr);
+    unsigned max_contiguous = axis_info_->get_max_contiguous(ptr);
+    unsigned alignment = std::min(starting_multiple, max_contiguous);
+    unsigned vector_size = std::min(result->axis(0).contiguous, alignment);
     std::map packets;
     distributed_tile *TP = (distributed_tile*)tmap_.at(ld->get_pointer_operand());
     result->for_each([&](indices_t idx){
diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp
index 4ff863666..c3139ece6 100755
--- a/lib/driver/module.cpp
+++ b/lib/driver/module.cpp
@@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
 cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
 
 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
-//  std::cout << source << std::endl;
+  std::cout << source << std::endl;
   cu_context::context_switcher ctx_switch(*context);
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
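
The heart of the fix is how the load lowering in selection.cpp now derives its vectorization width from the pointer's inferred alignment. The minimal standalone C++ sketch below (not part of the patch; the function and parameter names are hypothetical stand-ins for the values the axis_info pass computes) walks through the same arithmetic: an align(16) attribute on an fp16 pointer gives a starting multiple of 16 B / 2 B = 8 elements, which is then clamped by the inferred contiguity and by the per-thread contiguous extent of the tile axis.

#include <algorithm>

// Illustrative sketch only -- not Triton API. Mirrors the patch's logic:
// the starting multiple comes from the align()/multiple_of() attribute
// (converted from bytes to elements), and the final vector width is the
// minimum of that value, the inferred max contiguous extent, and the
// per-thread contiguous extent of axis 0.
unsigned safe_vector_size(unsigned align_bytes,      // e.g. align(16) on the fp16* argument
                          unsigned elt_bytes,        // e.g. 2 for fp16
                          unsigned max_contiguous,   // inferred contiguous elements along axis 0
                          unsigned axis_contiguous)  // per-thread contiguous elements of the tile
{
  unsigned starting_multiple = align_bytes / elt_bytes;            // 16 / 2 = 8 elements
  unsigned alignment = std::min(starting_multiple, max_contiguous);
  return std::min(axis_contiguous, alignment);                     // width of the vectorized load
}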