diff --git a/include/triton/codegen/analysis/alignment.h b/include/triton/codegen/analysis/align.h similarity index 97% rename from include/triton/codegen/analysis/alignment.h rename to include/triton/codegen/analysis/align.h index 6ef3c0f55..6812314b7 100644 --- a/include/triton/codegen/analysis/alignment.h +++ b/include/triton/codegen/analysis/align.h @@ -13,7 +13,7 @@ namespace ir { namespace codegen{ namespace analysis{ -class alignment_info { +class align { struct cst_info { unsigned num_cst; unsigned value; diff --git a/include/triton/codegen/analysis/tune.h b/include/triton/codegen/analysis/grid.h similarity index 100% rename from include/triton/codegen/analysis/tune.h rename to include/triton/codegen/analysis/grid.h diff --git a/include/triton/codegen/analysis/shmem/liveness.h b/include/triton/codegen/analysis/liveness.h similarity index 93% rename from include/triton/codegen/analysis/shmem/liveness.h rename to include/triton/codegen/analysis/liveness.h index bec0303c0..4aa0c6dae 100644 --- a/include/triton/codegen/analysis/shmem/liveness.h +++ b/include/triton/codegen/analysis/liveness.h @@ -13,11 +13,10 @@ namespace ir{ namespace codegen{ namespace analysis{ -namespace shmem{ typedef unsigned slot_index; -class info; +class meminfo; struct segment { slot_index start; @@ -45,7 +44,7 @@ public: public: // constructor - liveness(info *info): info_(info){ } + liveness(meminfo *info): info_(info){ } // accessors const intervals_map_t& intervals() const { return intervals_; } @@ -55,7 +54,7 @@ public: void run(ir::module &mod); private: - info *info_; + meminfo *info_; has_storage_map_t has_dedicated_storage_; indices_map_t indices_; intervals_map_t intervals_; @@ -64,7 +63,6 @@ private: } } } -} #endif diff --git a/include/triton/codegen/analysis/shmem/allocation.h b/include/triton/codegen/analysis/memalloc.h similarity index 86% rename from include/triton/codegen/analysis/shmem/allocation.h rename to include/triton/codegen/analysis/memalloc.h index 243d78352..0e5b2adc9 100644 --- a/include/triton/codegen/analysis/shmem/allocation.h +++ b/include/triton/codegen/analysis/memalloc.h @@ -17,14 +17,12 @@ namespace analysis{ class grids; -namespace shmem{ - class liveness; -class info; +class meminfo; -class allocation { +class memalloc { public: - allocation(liveness *live, info *buffer_info, grids *params) + memalloc(liveness *live, meminfo *buffer_info, grids *params) : liveness_(live), buffer_info_(buffer_info), params_(params){ } // utilities @@ -44,13 +42,12 @@ private: size_t allocated_size_; // dependences liveness *liveness_; - info *buffer_info_; + meminfo *buffer_info_; grids *params_; }; } } } -} #endif diff --git a/include/triton/codegen/analysis/shmem/info.h b/include/triton/codegen/analysis/meminfo.h similarity index 95% rename from include/triton/codegen/analysis/shmem/info.h rename to include/triton/codegen/analysis/meminfo.h index 689516cb2..1b896056f 100644 --- a/include/triton/codegen/analysis/shmem/info.h +++ b/include/triton/codegen/analysis/meminfo.h @@ -15,9 +15,8 @@ namespace ir { namespace codegen{ namespace analysis{ -namespace shmem{ -class info { +class meminfo { public: void run(ir::module &mod); // queries @@ -38,6 +37,5 @@ private: } } } -} #endif diff --git a/include/triton/codegen/selection/selection.h b/include/triton/codegen/selection.h similarity index 94% rename from include/triton/codegen/selection/selection.h rename to include/triton/codegen/selection.h index 2610fefc3..0a5d84825 100644 --- a/include/triton/codegen/selection/selection.h +++ b/include/triton/codegen/selection.h @@ -5,7 +5,7 @@ #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/type.h" -#include "triton/codegen/analysis/shmem/info.h" +#include "triton/codegen/analysis/meminfo.h" namespace llvm{ @@ -45,14 +45,10 @@ namespace codegen{ namespace analysis{ class grids; -class alignment_info; +class align; +class memalloc; +class meminfo; -namespace shmem{ - -class allocation; -class info; - -} } class target; @@ -196,7 +192,7 @@ private: public: - selection(analysis::shmem::allocation *alloc, analysis::grids *params, analysis::shmem::info *buffer_info, analysis::alignment_info *alignment, target *tgt) + selection(analysis::memalloc *alloc, analysis::grids *params, analysis::meminfo *buffer_info, analysis::align *alignment, target *tgt) : alloc_(alloc), params_(params), buffer_info_(buffer_info), alignment_(alignment), tgt_(tgt){ } void run(ir::module &src, Module &dst); @@ -204,10 +200,10 @@ public: private: vmap_t vmap_; tmap_t tmap_; - analysis::shmem::allocation *alloc_; + analysis::memalloc *alloc_; analysis::grids *params_; - analysis::shmem::info *buffer_info_; - analysis::alignment_info *alignment_; + analysis::meminfo *buffer_info_; + analysis::align *alignment_; target *tgt_; std::map axes_; Value *sh_mem_ptr_; diff --git a/include/triton/codegen/selection/target.h b/include/triton/codegen/target.h similarity index 94% rename from include/triton/codegen/selection/target.h rename to include/triton/codegen/target.h index f5f8e9a7c..dc379bd0c 100644 --- a/include/triton/codegen/selection/target.h +++ b/include/triton/codegen/target.h @@ -46,6 +46,7 @@ public: virtual Value* get_local_id(Module *module, Builder& builder, unsigned ax) = 0; virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0; virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0; + virtual unsigned guaranteed_alignment() = 0; bool is_gpu() const; private: @@ -62,6 +63,7 @@ public: Value* get_local_id(Module *module, Builder& builder, unsigned ax); Value* get_block_id(Module *module, Builder& builder, unsigned ax); Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); + unsigned guaranteed_alignment() { return 16; } }; class nvidia_cu_target: public target { @@ -74,6 +76,7 @@ public: Value* get_local_id(Module *module, Builder& builder, unsigned ax); Value* get_block_id(Module *module, Builder& builder, unsigned ax); Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); + unsigned guaranteed_alignment() { return 16; } }; class cpu_target: public target { @@ -86,6 +89,7 @@ public: Value* get_local_id(Module *module, Builder& builder, unsigned ax); Value* get_block_id(Module *module, Builder& builder, unsigned ax); Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); + unsigned guaranteed_alignment() { return 1; } }; } diff --git a/include/triton/codegen/transform/shmem/barriers.h b/include/triton/codegen/transform/membar.h similarity index 79% rename from include/triton/codegen/transform/shmem/barriers.h rename to include/triton/codegen/transform/membar.h index 6352fd060..8991ac57d 100644 --- a/include/triton/codegen/transform/shmem/barriers.h +++ b/include/triton/codegen/transform/membar.h @@ -14,17 +14,15 @@ namespace ir { namespace codegen{ namespace analysis{ -namespace shmem{ -class allocation; -class info; +class memalloc; +class meminfo; -} } namespace transform{ -class shmem_barriers { +class membar { private: typedef std::pair interval_t; typedef std::vector interval_vec_t; @@ -40,12 +38,12 @@ private: std::pair transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set &insert_loc); public: - shmem_barriers(analysis::shmem::allocation *alloc, analysis::shmem::info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} + membar(analysis::memalloc *alloc, analysis::meminfo *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} void run(ir::module &mod); private: - analysis::shmem::allocation *alloc_; - analysis::shmem::info *buffer_info_; + analysis::memalloc *alloc_; + analysis::meminfo *buffer_info_; }; diff --git a/include/triton/codegen/transform/reassociate.h b/include/triton/codegen/transform/reassociate.h index 075446e6f..318884755 100644 --- a/include/triton/codegen/transform/reassociate.h +++ b/include/triton/codegen/transform/reassociate.h @@ -20,7 +20,7 @@ namespace codegen{ namespace analysis{ class grids; -class alignment_info; +class align; } namespace transform{ @@ -37,12 +37,12 @@ private: ir::value *reassociate_ptr(ir::getelementptr_inst* pz, ir::builder &builder, std::map &offsets); public: - reassociate(analysis::alignment_info* align, analysis::grids *params); + reassociate(analysis::align* align, analysis::grids *params); void run(ir::module& module); private: analysis::grids* params_; - analysis::alignment_info* align_; + analysis::align* align_; }; } diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 4a7c308eb..74af3abe2 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -61,6 +61,19 @@ public: return kind_ != multiple_of; } + std::string repr() const { + switch(kind_){ + case readonly: return ".readonly"; + case writeonly: return ".writeonly"; + case noalias: return ".noalias"; + case aligned: return ".aligned(" + std::to_string(value_) + ")"; + case multiple_of: return ".readonly"; + default: break; + } + assert(false); + return ""; + } + private: attribute_kind_t kind_; unsigned value_; diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index f0a345c81..a4fbc3710 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -687,7 +687,7 @@ private: public: static nv_static_program_idx *get(constant_range* range); constant_range* get_range() const; - std::string repr() const { return get_name(); } + std::string repr() const { return "nv_static_program_idx"; } private: constant_range *range_; diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index b0054c647..96ec35ef7 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -9,16 +9,16 @@ #include #include // codegen -#include "triton/codegen/selection/selection.h" -#include "triton/codegen/selection/target.h" -#include "triton/codegen/analysis/tune.h" -#include "triton/codegen/analysis/shmem/allocation.h" -#include "triton/codegen/analysis/shmem/liveness.h" -#include "triton/codegen/analysis/shmem/info.h" -#include "triton/codegen/analysis/alignment.h" +#include "triton/codegen/selection.h" +#include "triton/codegen/target.h" +#include "triton/codegen/analysis/grid.h" +#include "triton/codegen/analysis/memalloc.h" +#include "triton/codegen/analysis/liveness.h" +#include "triton/codegen/analysis/meminfo.h" +#include "triton/codegen/analysis/align.h" #include "triton/codegen/transform/dce.h" #include "triton/codegen/transform/peephole.h" -#include "triton/codegen/transform/shmem/barriers.h" +#include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/reassociate.h" #include "triton/codegen/transform/vectorize.h" #include "triton/lang/parser.h" diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp index 56016638b..554b3bcc3 100644 --- a/include/triton/tools/bench.hpp +++ b/include/triton/tools/bench.hpp @@ -41,8 +41,8 @@ inline double bench(std::function const & op, driver::stream * stream) while(total_time*1e-9 < 1e-3){ float norm = 1; // normalize clock if possible to reduce noise in auto-tuning -// if(auto cu_device = dynamic_cast(stream->context()->device())) -// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); + if(auto cu_device = dynamic_cast(stream->context()->device())) + norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); tmr.start(); op(); stream->synchronize(); diff --git a/lib/codegen/analysis/alignment.cc b/lib/codegen/analysis/align.cc similarity index 89% rename from lib/codegen/analysis/alignment.cc rename to lib/codegen/analysis/align.cc index 98d4a110f..85500aefb 100644 --- a/lib/codegen/analysis/alignment.cc +++ b/lib/codegen/analysis/align.cc @@ -1,4 +1,4 @@ -#include "triton/codegen/analysis/alignment.h" +#include "triton/codegen/analysis/align.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -29,14 +29,14 @@ inline T add_to_cache(ir::value *i, T value, std::map &map) { } -bool alignment_info::is_first_axis_unit(ir::value *x){ +bool align::is_first_axis_unit(ir::value *x){ if(x->get_type()->is_tile_ty()) return x->get_type()->get_tile_shapes()[0] == 1; else return true; } -alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) { +align::cst_info align::populate_is_constant(ir::value *v) { if(is_constant_.find(v) != is_constant_.end()) return is_constant_.at(v); // helper for the cache @@ -102,7 +102,7 @@ alignment_info::cst_info alignment_info::populate_is_constant(ir::value *v) { return cache({1, 0}); } -unsigned alignment_info::populate_max_contiguous(ir::value *v){ +unsigned align::populate_max_contiguous(ir::value *v){ if(max_contiguous_.find(v) != max_contiguous_.end()) return max_contiguous_.at(v); // helper for the cache @@ -181,7 +181,7 @@ unsigned alignment_info::populate_max_contiguous(ir::value *v){ return cache(1); } -unsigned alignment_info::populate_starting_multiple(ir::value *v){ +unsigned align::populate_starting_multiple(ir::value *v){ if(starting_multiple_.find(v) != starting_multiple_.end()) return starting_multiple_.at(v); auto cache = [this,v](unsigned value){ @@ -240,7 +240,19 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){ int rhs = populate_starting_multiple(x->get_operand(1)); return cache(gcd(lhs, rhs)); } - if(auto *x = dynamic_cast(v)){ + if(auto *x = dynamic_cast(v)){ + int op = populate_starting_multiple(x->get_operand(0)); + return cache(op); + } + if(auto *x = dynamic_cast(v)){ + int op = populate_starting_multiple(x->get_operand(0)); + auto shapes = x->get_type()->get_tile_shapes(); + if(shapes[0] == 1) + return cache(1); + else + return cache(op); + } + if(auto *x = dynamic_cast(v)){ int op = populate_starting_multiple(x->get_operand(0)); return cache(op); } @@ -271,22 +283,22 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){ return cache(result); } -unsigned alignment_info::get_starting_multiple(ir::value* v) const { +unsigned align::get_starting_multiple(ir::value* v) const { return starting_multiple_.at(v); } -unsigned alignment_info::get_max_contiguous(ir::value* v) const { +unsigned align::get_max_contiguous(ir::value* v) const { return max_contiguous_.at(v); } -void alignment_info::copy(ir::value *dst, ir::value *src) { +void align::copy(ir::value *dst, ir::value *src) { starting_multiple_[dst] = starting_multiple_[src]; max_contiguous_[dst] = max_contiguous_[src]; is_constant_[dst] = is_constant_[src]; } ///TODO: This doesn't seem to work in DOT-NN, DOT-TT, DOT-TN -void alignment_info::run(ir::module &mod) { +void align::run(ir::module &mod) { // populate constant for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: fn->blocks()) @@ -304,9 +316,13 @@ void alignment_info::run(ir::module &mod) { // populate maximum contiguous for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()){ + for(ir::instruction *i: block->get_inst_list()) populate_max_contiguous(i); - } + +// for(ir::function *fn: mod.get_function_list()) +// for(ir::basic_block *block: fn->blocks()) +// for(ir::instruction *i: block->get_inst_list()) +// std::cout << i->get_name() << " " << max_contiguous_.at(i) << " " << is_constant_.at(i).num_cst << " " << starting_multiple_.at(i) << std::endl; } diff --git a/lib/codegen/analysis/tune.cc b/lib/codegen/analysis/grid.cc similarity index 99% rename from lib/codegen/analysis/tune.cc rename to lib/codegen/analysis/grid.cc index 275011a7b..f90ab8822 100644 --- a/lib/codegen/analysis/tune.cc +++ b/lib/codegen/analysis/grid.cc @@ -1,6 +1,6 @@ #include #include -#include "triton/codegen/analysis/tune.h" +#include "triton/codegen/analysis/grid.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" #include "triton/ir/module.h" @@ -292,7 +292,7 @@ void grids::run(ir::module &mod) { else{ unsigned shape = shapes[0]; unsigned current = num_threads; - params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 8)); + params_.at(i).at("nts.d0")->set_value(clamp(size / num_threads, 1, 4)); params_.at(i).at("mts.d0")->set_value(clamp(current, 1, shape / params_.at(i).at("nts.d0")->get_value())); current = current / params_.at(i).at("mts.d0")->get_value(); for(size_t d = 1; d < shapes.size(); d++){ diff --git a/lib/codegen/analysis/shmem/liveness.cc b/lib/codegen/analysis/liveness.cc similarity index 89% rename from lib/codegen/analysis/shmem/liveness.cc rename to lib/codegen/analysis/liveness.cc index 617a764ed..8801235b5 100644 --- a/lib/codegen/analysis/shmem/liveness.cc +++ b/lib/codegen/analysis/liveness.cc @@ -1,5 +1,5 @@ -#include "triton/codegen/analysis/shmem/liveness.h" -#include "triton/codegen/analysis/shmem/info.h" +#include "triton/codegen/analysis/liveness.h" +#include "triton/codegen/analysis/meminfo.h" #include "triton/ir/basic_block.h" #include "triton/ir/function.h" #include "triton/ir/module.h" @@ -9,7 +9,6 @@ namespace triton{ namespace codegen{ namespace analysis{ -namespace shmem{ // Entry point void liveness::run(ir::module &mod) { @@ -41,4 +40,3 @@ void liveness::run(ir::module &mod) { } } } -} diff --git a/lib/codegen/analysis/shmem/allocation.cc b/lib/codegen/analysis/memalloc.cc similarity index 93% rename from lib/codegen/analysis/shmem/allocation.cc rename to lib/codegen/analysis/memalloc.cc index 1061c0425..5f8a4d70b 100644 --- a/lib/codegen/analysis/shmem/allocation.cc +++ b/lib/codegen/analysis/memalloc.cc @@ -1,8 +1,8 @@ #include -#include "triton/codegen/analysis/shmem/allocation.h" -#include "triton/codegen/analysis/shmem/liveness.h" -#include "triton/codegen/analysis/shmem/info.h" -#include "triton/codegen/analysis/tune.h" +#include "triton/codegen/analysis/memalloc.h" +#include "triton/codegen/analysis/liveness.h" +#include "triton/codegen/analysis/meminfo.h" +#include "triton/codegen/analysis/grid.h" #include "triton/ir/basic_block.h" #include "triton/ir/type.h" #include "triton/ir/value.h" @@ -12,9 +12,8 @@ namespace triton{ namespace codegen{ namespace analysis{ -namespace shmem{ -unsigned allocation::is_ld_padded(ir::value *x) { +unsigned memalloc::is_ld_padded(ir::value *x) { if(auto *trans = dynamic_cast(x)){ if(trans->get_perm()[0]->get_value() != 0) return 4; @@ -46,7 +45,7 @@ unsigned allocation::is_ld_padded(ir::value *x) { return 0; } -unsigned allocation::get_num_bytes(ir::value *x) { +unsigned memalloc::get_num_bytes(ir::value *x) { if(auto *red = dynamic_cast(x)){ unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; size_t axis = red->get_axis(); @@ -74,7 +73,7 @@ unsigned allocation::get_num_bytes(ir::value *x) { return num_bytes; } -void allocation::run(){ +void memalloc::run(){ using std::max; using std::min; typedef std::multimap triples_map_type; @@ -178,4 +177,3 @@ void allocation::run(){ } } } -} diff --git a/lib/codegen/analysis/shmem/info.cc b/lib/codegen/analysis/meminfo.cc similarity index 91% rename from lib/codegen/analysis/shmem/info.cc rename to lib/codegen/analysis/meminfo.cc index d16048d3b..d0b075603 100644 --- a/lib/codegen/analysis/shmem/info.cc +++ b/lib/codegen/analysis/meminfo.cc @@ -1,5 +1,5 @@ #include -#include "triton/codegen/analysis/shmem/info.h" +#include "triton/codegen/analysis/meminfo.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -10,10 +10,9 @@ namespace triton { namespace codegen{ namespace analysis{ -namespace shmem{ // run pass on module -bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ +bool meminfo::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ if(phi->get_parent() != terminator->get_parent()) return false; if(auto *br = dynamic_cast(terminator)) @@ -25,7 +24,7 @@ bool info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ throw std::runtime_error("unreachable"); } -void info::replace(ir::value* before, ir::value *after) { +void meminfo::replace(ir::value* before, ir::value *after) { shared_.erase(before); shared_.insert(after); if(refs_.find(before) != refs_.end()){ @@ -72,7 +71,7 @@ void add_copy(ir::value *x, ir::builder &builder) { } } -void info::run(ir::module &mod) { +void meminfo::run(ir::module &mod) { // Add shared copies for(ir::function *fn: mod.get_function_list()){ ir::builder builder(mod.get_context()); @@ -122,15 +121,15 @@ void info::run(ir::module &mod) { } // query double-buffered status -bool info::is_double(ir::value *x) +bool meminfo::is_double(ir::value *x) { return double_.find(x) != double_.end(); } // query shared status -bool info::is_shared(ir::value *x) +bool meminfo::is_shared(ir::value *x) { return shared_.find(x) != shared_.end(); } // get reference if any -ir::value *info::get_reference(ir::value *x) +ir::value *meminfo::get_reference(ir::value *x) { return refs_[x]; } @@ -138,4 +137,3 @@ ir::value *info::get_reference(ir::value *x) } } } -} diff --git a/lib/codegen/selection/selection.cc b/lib/codegen/selection.cc similarity index 99% rename from lib/codegen/selection/selection.cc rename to lib/codegen/selection.cc index a44a4c926..ff246f4f5 100644 --- a/lib/codegen/selection/selection.cc +++ b/lib/codegen/selection.cc @@ -1,8 +1,8 @@ -#include "triton/codegen/selection/selection.h" -#include "triton/codegen/analysis/tune.h" -#include "triton/codegen/analysis/shmem/allocation.h" -#include "triton/codegen/selection/target.h" -#include "triton/codegen/analysis/alignment.h" +#include "triton/codegen/selection.h" +#include "triton/codegen/target.h" +#include "triton/codegen/analysis/grid.h" +#include "triton/codegen/analysis/memalloc.h" +#include "triton/codegen/analysis/align.h" #include "triton/ir/context.h" #include "triton/ir/module.h" #include "triton/ir/function.h" @@ -1304,10 +1304,7 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun unsigned id = linear / vector_size; if(linear % vector_size == 0) { Value *ptr = pointers->get_value(idx); -// ConstantInt *cst = nullptr; -// if(GetElementPtrInst *gep = dyn_cast(ptr)) -// if(gep->getNumIndices() == 1) -// cst = dyn_cast(gep->idx_begin()); + ptr = builder.CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size), ptr->getType()->getPointerAddressSpace())); @@ -1326,23 +1323,28 @@ void selection::lower_masked_load(ir::masked_load_inst *x, LLVMContext &ctx, Fun ((PHINode*)current_result)->addIncoming(result_then, mask_then_bb); Value *result_false = false_values->get_value(idx); if(result_then->getType()->isVectorTy()) - result_false = builder.CreateVectorSplat(vector_size, result_false); + result_false = builder.CreateVectorSplat(vector_size, llvm::UndefValue::get(result_false->getType())); ((PHINode*)current_result)->addIncoming(result_false, current_bb); } else current_result = result_then; +// ConstantInt *cst = nullptr; +// if(GetElementPtrInst *gep = dyn_cast(ptr)) +// if(gep->getNumIndices() == 1) +// cst = dyn_cast(gep->idx_begin()); +// llvm::Value* mask = masks->get_value(idx); // std::string offset = ""; // if(cst) // offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size); // Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2); // Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); // FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false); -// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];"; -// if(false_value) +// std::string asm_str = "@$0 ld.global.nc.b32 {$1, $2, $3, $4}, [$5" + offset + "];"; +// if(false_values) // asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};"; // InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true); -// Value *result = builder.CreateCall(iasm, {mask, ptr}); +// Value *current_result = builder.CreateCall(iasm, {mask, ptr}); packets[id] = current_result; } @@ -1499,9 +1501,11 @@ void selection::run(ir::module &src, Module &dst) { for(auto attr_pair: fn->attrs()){ unsigned id = attr_pair.first; for(ir::attribute attr: attr_pair.second) - if(attr.is_llvm_attr()) + if(attr.is_llvm_attr()){ dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr)); + } } + tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn); // set metadata Metadata *md_args[] = { diff --git a/lib/codegen/selection/target.cc b/lib/codegen/target.cc similarity index 99% rename from lib/codegen/selection/target.cc rename to lib/codegen/target.cc index 3a5e35aa1..4116bcca7 100644 --- a/lib/codegen/selection/target.cc +++ b/lib/codegen/target.cc @@ -1,4 +1,4 @@ -#include "triton/codegen/selection/target.h" +#include "triton/codegen/target.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Function.h" #include "llvm/IR/Intrinsics.h" diff --git a/lib/codegen/transform/shmem/barriers.cc b/lib/codegen/transform/membar.cc similarity index 81% rename from lib/codegen/transform/shmem/barriers.cc rename to lib/codegen/transform/membar.cc index 6b66ab148..007263543 100644 --- a/lib/codegen/transform/shmem/barriers.cc +++ b/lib/codegen/transform/membar.cc @@ -2,9 +2,9 @@ #include #include -#include "triton/codegen/transform/shmem/barriers.h" -#include "triton/codegen/analysis/shmem/allocation.h" -#include "triton/codegen/analysis/shmem/info.h" +#include "triton/codegen/transform/membar.h" +#include "triton/codegen/analysis/memalloc.h" +#include "triton/codegen/analysis/meminfo.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -16,7 +16,7 @@ namespace triton { namespace codegen{ namespace transform{ -bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) { +bool membar::intersect(const interval_vec_t &X, interval_t x) { return std::any_of(X.begin(), X.end(), [&](const interval_t &y){ bool left_intersect = y.first <= x.first && x.first < y.second; bool right_intersect = y.first <= x.second && x.second < y.second; @@ -24,13 +24,13 @@ bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) { }); } -bool shmem_barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) { +bool membar::intersect(const interval_vec_t &X, const interval_vec_t &Y) { return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){ return intersect(X, y); }); } -void shmem_barriers::add_reference(ir::value *v, interval_vec_t &res){ +void membar::add_reference(ir::value *v, interval_vec_t &res){ if(buffer_info_->is_shared(v) && !dynamic_cast(v)){ unsigned offset = alloc_->get_offset(v); unsigned num_bytes = alloc_->get_num_bytes(v); @@ -38,17 +38,17 @@ void shmem_barriers::add_reference(ir::value *v, interval_vec_t &res){ } } -void shmem_barriers::get_read_intervals(ir::instruction *i, interval_vec_t &res){ +void membar::get_read_intervals(ir::instruction *i, interval_vec_t &res){ for(ir::value *op: i->ops()) add_reference(op, res); } -void shmem_barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){ +void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){ if(!dynamic_cast(i)) add_reference(i, res); } -void shmem_barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) { +void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) { if(auto *phi = dynamic_cast(instr)) { std::set incoming; for(unsigned n = 0; n < phi->get_num_incoming(); n++){ @@ -67,16 +67,16 @@ void shmem_barriers::insert_barrier(ir::instruction *instr, ir::builder &builder } } -shmem_barriers::interval_vec_t shmem_barriers::join(const std::vector& intervals) { - shmem_barriers::interval_vec_t result; +membar::interval_vec_t membar::join(const std::vector& intervals) { + membar::interval_vec_t result; for(auto x: intervals) for(interval_t i: x) result.push_back(i); return result; } -std::pair shmem_barriers::transfer(ir::basic_block *block, +std::pair membar::transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set& insert_loc) { @@ -104,7 +104,7 @@ std::pair rpo = ir::cfg::reverse_post_order(fn); diff --git a/lib/codegen/transform/reassociate.cc b/lib/codegen/transform/reassociate.cc index 532c8e186..b0f4a2e73 100644 --- a/lib/codegen/transform/reassociate.cc +++ b/lib/codegen/transform/reassociate.cc @@ -1,7 +1,8 @@ #include +#include #include "triton/codegen/transform/reassociate.h" -#include "triton/codegen/analysis/alignment.h" -#include "triton/codegen/analysis/tune.h" +#include "triton/codegen/analysis/align.h" +#include "triton/codegen/analysis/grid.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -161,7 +162,7 @@ ir::value *reassociate::reassociate_idx(ir::value *old_value, return new_value; } -reassociate::reassociate(analysis::alignment_info *align, analysis::grids* params) +reassociate::reassociate(analysis::align *align, analysis::grids* params) : params_(params), align_(align) { } @@ -209,6 +210,29 @@ void reassociate::run(ir::module &mod) { for(ir::basic_block *block: rpo){ // iterate through instruction for(ir::instruction *i: block->get_inst_list()){ + // retiling + if(ir::retile_inst *rt = dynamic_cast(i)) { + ir::value* op = rt->get_operand(0); + if(infos.find(op) != infos.end()){ + builder.set_insert_point(rt); + ir::getelementptr_inst* sta = infos.at(op).sta_ptr; + ir::value* dyn = infos.at(op).dyn_ptr; + ir::value* cst = *sta->idx_begin(); + if(dynamic_cast(rt)) { + auto shapes = rt->get_type()->get_tile_shapes(); + ir::value* ndyn = builder.create_broadcast(dyn, shapes); + ir::value* broadcast = builder.create_broadcast(cst, shapes); + ir::getelementptr_inst* nsta = (ir::getelementptr_inst*)builder.create_gep(ndyn, {broadcast}); + params_->copy(ndyn, rt); + params_->copy(nsta, rt); + params_->copy(broadcast, rt); + align_->copy(ndyn, rt); + align_->copy(nsta, rt); + align_->copy(broadcast, rt); + infos[rt] = cst_info{ndyn, nsta}; + } + } + } // getelementptr instruction if(ir::getelementptr_inst *pz = dynamic_cast(i)){ if(replaced.find(pz) != replaced.end()) diff --git a/lib/codegen/transform/vectorize.cc b/lib/codegen/transform/vectorize.cc index dbf7ee7f1..16309ffc5 100644 --- a/lib/codegen/transform/vectorize.cc +++ b/lib/codegen/transform/vectorize.cc @@ -1,5 +1,5 @@ #include "triton/codegen/transform/vectorize.h" -#include "triton/codegen/analysis/tune.h" +#include "triton/codegen/analysis/grid.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" diff --git a/lib/driver/device.cc b/lib/driver/device.cc index fceb2754e..3f82e2f33 100755 --- a/lib/driver/device.cc +++ b/lib/driver/device.cc @@ -27,7 +27,7 @@ #include #include "triton/driver/device.h" #include "triton/driver/context.h" -#include "triton/codegen/selection/target.h" +#include "triton/codegen/target.h" namespace triton { diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 5a9bfc86f..f41fdc0e5 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -223,12 +223,12 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) { static_cast*>(options["nvptx-short-ptr"])->setValue(true); // create llvm::SmallVector buffer; - module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_70", "", buffer, "", Assembly); + module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_60", "", buffer, "", Assembly); std::string result(buffer.begin(), buffer.end()); size_t start_replace = result.find(".version"); size_t end_replace = result.find('\n', start_replace); assert(start_replace != std::string::npos); - result.replace(start_replace, end_replace - start_replace, ".version 6.4"); + result.replace(start_replace, end_replace - start_replace, ".version 6.0"); return result; } @@ -245,10 +245,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo try{ dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval); }catch(exception::cuda::base const &){ -#ifdef TRITON_LOG_PTX_ERROR +//#ifdef TRITON_LOG_PTX_ERROR std::cerr << "Compilation Failed! Log: " << std::endl; std::cerr << errbuf << std::endl; -#endif +//#endif throw; } } diff --git a/lib/ir/module.cc b/lib/ir/module.cc index 3d995558e..98f171252 100644 --- a/lib/ir/module.cc +++ b/lib/ir/module.cc @@ -29,6 +29,7 @@ void module::set_value(const std::string& name, ir::basic_block *block, ir::valu if(it != metadatas_.end()){ x->set_metadata(it->second.first, it->second.second); } + value->set_name(name); } void module::set_value(const std::string& name, ir::value *value){ diff --git a/lib/ir/print.cc b/lib/ir/print.cc index 31cc15d9a..124091262 100644 --- a/lib/ir/print.cc +++ b/lib/ir/print.cc @@ -22,6 +22,18 @@ std::string get_name(ir::value *v, unsigned i) { void print(module &mod, std::ostream& os) { unsigned cnt = 0; for(ir::function *fn: mod.get_function_list()){ + os << "def " << fn->get_fn_type()->get_return_ty()->repr() << " " << fn->get_name() << "(" ; + for(ir::argument* arg: fn->args()) { + if(arg->get_arg_no() > 0) + os << ", "; + os << arg->get_type()->repr() << " " << arg->get_name(); + auto attrs = fn->get_attributes(arg); + if(attrs.size() > 0) + os << " "; + for(ir::attribute attr: attrs) + os << attr.repr() << " "; + } + os << ")" << std::endl; os << "{" << std::endl; for(ir::basic_block *block: fn->blocks()){ auto const &predecessors = block->get_predecessors(); diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index c2c691cb5..228bd69dd 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -373,8 +373,11 @@ void Generator::VisitFuncDef(FuncDef* funcDef) { for(Object* obj: type->Params()){ std::string name = obj->Name(); args[i]->set_name(name); - for(ASTNode::Attr attr: obj->GetAttrList()) - fn->add_attr(i, GenIRAttr(attr)); + if(obj->Type()->ToPointer()) + fn->add_attr(i + 1, ir::attribute(ir::aligned, 16)); + for(ASTNode::Attr attr: obj->GetAttrList()){ + fn->add_attr(i + 1, GenIRAttr(attr)); + } if(obj->IsRestrictQualified()) fn->add_attr(i, ir::attribute(ir::noalias)); mod_->set_value(name, nullptr, args[i]); diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 9b2072974..703918ba5 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -3,7 +3,7 @@ #include #include #include -#include "triton/codegen/selection/selection.h" +#include "triton/codegen/selection.h" #include "triton/runtime/function.h" #include "triton/lang/cpp.h" #include "triton/lang/parser.h" @@ -167,8 +167,6 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr bin = make_bin(*ir, stream->context(), opt); }catch(const std::runtime_error& e) { return; - }catch(const driver::exception::cuda::invalid_ptx& e) { - return; } // benchmark ir::function *tmp = ir->get_function_list()[0]; @@ -191,23 +189,31 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c std::unique_ptr target = context->device()->make_target(); // create passes codegen::analysis::grids grids(opt.num_warps); - codegen::analysis::shmem::info shmem_info; - codegen::analysis::shmem::liveness shmem_liveness(&shmem_info); - codegen::analysis::shmem::allocation shmem_allocation(&shmem_liveness, &shmem_info, &grids); - codegen::analysis::alignment_info alignment_info; - codegen::transform::shmem_barriers shmem_barriers(&shmem_allocation, &shmem_info); + codegen::analysis::meminfo shmem_info; + codegen::analysis::liveness shmem_liveness(&shmem_info); + codegen::analysis::memalloc shmem_allocation(&shmem_liveness, &shmem_info, &grids); + codegen::analysis::align alignment_info; + codegen::transform::membar shmem_barriers(&shmem_allocation, &shmem_info); codegen::transform::vectorize vectorize(&grids); codegen::transform::dce dce; codegen::transform::peephole peephole; codegen::transform::reassociate reassociate(&alignment_info, &grids); codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get()); + + // run passes peephole.run(module); dce.run(module); - grids.run(module); alignment_info.run(module); + grids.run(module); +// ir::print(module, std::cout); + reassociate.run(module); + dce.run(module); +// ir::print(module, std::cout); + peephole.run(module); + if(target->is_gpu()){ shmem_info.run(module); shmem_liveness.run(module); @@ -217,7 +223,8 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c dce.run(module); vectorize.run(module); dce.run(module); -// ir::print(module, std::cout); + + // generate llvm code llvm::LLVMContext ctx; std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx)); diff --git a/tests/bench/CMakeLists.txt b/tests/bench/CMakeLists.txt index 1f3cc3341..598dadeea 100644 --- a/tests/bench/CMakeLists.txt +++ b/tests/bench/CMakeLists.txt @@ -1,4 +1,4 @@ -foreach(PROG dot) +foreach(PROG dot copy1d copy2d) set(TARGET bench_${PROG}) add_executable(${TARGET} ${PROG}.cc) set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET}) diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index cb678ff99..3fecb8e58 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -17,7 +17,7 @@ inline size_t ceil(size_t x, size_t y) { return (x + y - 1) / y; }; -inline rt::function::grid_fn_ty grid(size_t M, size_t N) { +inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) { return [M, N](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D("TM")), ceil(N, x.D("TN"))}; @@ -42,11 +42,9 @@ std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i // create options rt::function::options_space_t opt; opt.defines.push_back({"TYPE", {ty}}); - if(AT) - opt.defines.push_back({"AT", {""}}); - if(BT) - opt.defines.push_back({"BT", {""}}); - opt.defines.push_back({"TM", {"64"}}); + opt.defines.push_back({"AT", {AT?"1":"0"}}); + opt.defines.push_back({"BT", {BT?"1":"0"}}); + opt.defines.push_back({"TM", {"128"}}); opt.defines.push_back({"TN", {"64"}}); opt.defines.push_back({"TK", {"8"}}); opt.num_warps = {4}; @@ -55,18 +53,18 @@ std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i // benchmark available libraries std::vector result; auto tflops = [&](double nanosec) { return 2.*M*N*K / nanosec * 1e-3; }; -// // cublas -// if(cublas::cublasinit()){ -// NumericT alpha(static_cast(1)); -// NumericT beta(static_cast(0)); -// cublasGemmAlgo_t fastest; -// cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest); -// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, -// &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream); -// result.push_back(tflops(cublas_ms)); -// } + // cublas + if(cublas::cublasinit()){ + NumericT alpha(static_cast(1)); + NumericT beta(static_cast(0)); + cublasGemmAlgo_t fastest = CUBLAS_GEMM_ALGO5; +// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest); + double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, + &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream); + result.push_back(tflops(cublas_ms)); + } // triton - double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream);}, stream); + double triton_ms = triton::tools::bench([&]() { function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid2d(M, N), stream);}, stream); result.push_back(tflops(triton_ms)); // done return result; @@ -79,11 +77,9 @@ int main() { // shapes to benchmark typedef std::tuple config_t; std::vector configs; - for(auto x: std::vector>{{false, false}, - {false, true}, - {true, false}}){ + for(auto x: std::vector>{{false, true}}){ std::vector tmp = { - config_t{x[0], x[1], 8192, 8192, 8192} + config_t{x[0], x[1], 2048, 2048, 2048} // config_t{x[0], x[1], 16, 2048, 2048}, // config_t{x[0], x[1], 32, 2048, 2048}, // config_t{x[0], x[1], 64, 2048, 2048}, diff --git a/tests/common/cuda/forward.h b/tests/common/cuda/forward.h index 1c12c4247..bd32adec6 100644 --- a/tests/common/cuda/forward.h +++ b/tests/common/cuda/forward.h @@ -24,11 +24,11 @@ typedef enum{ typedef enum { CUBLAS_GEMM_DFALT = -1, CUBLAS_GEMM_DEFAULT = -1, - CUBLAS_GEMM_ALGO0 = 0, - CUBLAS_GEMM_ALGO1 = 1, - CUBLAS_GEMM_ALGO2 = 2, - CUBLAS_GEMM_ALGO3 = 3, - CUBLAS_GEMM_ALGO4 = 4, + CUBLAS_GEMM_ALGO0 = 0, // maxwell_sgemm_32x128_nt + CUBLAS_GEMM_ALGO1 = 1, // maxwell_sgemm_64x64_nt + CUBLAS_GEMM_ALGO2 = 2, // maxwell_sgemm_128x32_nt + CUBLAS_GEMM_ALGO3 = 3, // maxwell_sgemm_128x64_nt + CUBLAS_GEMM_ALGO4 = 4, // maxwell_sgemm_128x128_nt CUBLAS_GEMM_ALGO5 = 5, CUBLAS_GEMM_ALGO6 = 6, CUBLAS_GEMM_ALGO7 = 7, @@ -102,4 +102,4 @@ typedef enum { CUBLAS_TENSOR_OP_MATH = 1 } cublasMath_t; -#endif \ No newline at end of file +#endif diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h index 3e636e18a..c9b3454d7 100644 --- a/tests/common/src/dot.h +++ b/tests/common/src/dot.h @@ -2,33 +2,33 @@ namespace src { const char *dot = R"( -#ifdef AT +#if AT == 1 #define USEA ^a -#define STRIDE_AK lda -#define STRIDE_AM 1 +#define STRIDE_AK 1 +#define STRIDE_AM lda #define BROADCAST_AK :, newaxis #define BROADCAST_AM newaxis, : #define SHAPE_A TK, TM #else #define USEA a -#define STRIDE_AK 1 -#define STRIDE_AM lda +#define STRIDE_AK lda +#define STRIDE_AM 1 #define BROADCAST_AK newaxis, : #define BROADCAST_AM :, newaxis #define SHAPE_A TM, TK #endif -#ifdef BT +#if BT == 1 #define USEB ^b -#define STRIDE_BK 1 -#define STRIDE_BN ldb +#define STRIDE_BK ldb +#define STRIDE_BN 1 #define BROADCAST_BK newaxis, : #define BROADCAST_BN :, newaxis #define SHAPE_B TN, TK #else #define USEB b -#define STRIDE_BK ldb -#define STRIDE_BN 1 +#define STRIDE_BK 1 +#define STRIDE_BN ldb #define BROADCAST_BK :, newaxis #define BROADCAST_BN newaxis, : #define SHAPE_B TK, TN @@ -58,17 +58,15 @@ void dot(TYPE * A, TYPE * B, TYPE * C, c += USEA @ USEB; pa = pa + TK * STRIDE_AK; pb = pb + TK * STRIDE_BK; - a = *pa; - b = *pb; + a = ((bool[SHAPE_A])(k > TK)) ? *pa : 0; + b = ((bool[SHAPE_B])(k > TK)) ? *pb : 0; } // epilogue int rxc[TM] = ridx * TM + 0 ... TM; int ryc[TN] = ridy * TN + 0 ... TN; - TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc; - bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :]; - *?(checkc) pc = c; + TYPE* pc[TM, TN] = C + ryc[newaxis, :] * ldc + rxc[:, newaxis]; + *pc = c; } - )"; } diff --git a/tests/common/util.h b/tests/common/util.h index a60050af7..d8ffef090 100644 --- a/tests/common/util.h +++ b/tests/common/util.h @@ -3,21 +3,35 @@ #ifndef _TRITON_TESTS_UTIL_H #define _TRITON_TESTS_UTIL_H +#include #include "triton/runtime/function.h" +namespace drv = triton::driver; namespace rt = triton::runtime; inline size_t ceil(size_t x, size_t y) { return (x + y - 1) / y; }; -inline rt::function::grid_fn_ty grid(size_t M, size_t N) { +inline rt::function::grid_fn_ty grid1d(size_t N) { + return [N](const rt::function::options_t& x) { + return rt::grid_t{ceil(N, x.D("TN"))}; + }; +} + +inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) { return [M, N](const rt::function::options_t& x) { return rt::grid_t{ceil(M, x.D("TM")), ceil(N, x.D("TN"))}; }; } +enum order_t { + ROWMAJOR, + COLMAJOR +}; + + namespace aux{ template struct seq{}; @@ -51,11 +65,14 @@ namespace testing { if(hc.size() != rc.size()) return false; for(size_t i = 0; i < hc.size(); i++) - if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2) + if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){ + std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; + return false; + } return true; } } -#endif \ No newline at end of file +#endif diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index f3cdae9a1..78fbc79d1 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -1,5 +1,5 @@ foreach(PROG dot) - set(TARGET test_${PROG}) + set(TARGET unit_${PROG}) add_executable(${TARGET} ${PROG}.cc) set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET}) target_link_libraries(${TARGET} triton dl) diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc index 3c9ec96d8..69b8cf2d7 100644 --- a/tests/unit/dot.cc +++ b/tests/unit/dot.cc @@ -51,8 +51,8 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K, bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){ - typedef half_float::half NumericT; - std::string ty = "half"; + typedef float NumericT; + std::string ty = "float"; size_t dt_nbytes = sizeof(NumericT); drv::context* context = stream->context(); std::vector hc(M*N); @@ -78,17 +78,15 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_ // run rt::function::options_space_t opt; opt.defines.push_back({"TYPE", {ty}}); - if(AT) - opt.defines.push_back({"AT", {""}}); - if(BT) - opt.defines.push_back({"BT", {""}}); + opt.defines.push_back({"AT", {AT?"1":"0"}}); + opt.defines.push_back({"BT", {BT?"1":"0"}}); opt.defines.push_back({"TM", {std::to_string(TM)}}); opt.defines.push_back({"TN", {std::to_string(TN)}}); opt.defines.push_back({"TK", {std::to_string(TK)}}); opt.num_warps = {nwarp}; rt::function function(src::dot, opt); try { - function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid(M, N), stream); + function({&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}, grid2d(M, N), stream); } catch (const std::runtime_error& e) { return true; }