diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp index 9b52020f6..0a8e1e948 100644 --- a/examples/python/tensorflow/dot.cpp +++ b/examples/python/tensorflow/dot.cpp @@ -30,7 +30,7 @@ void matmul(restrict read_only align(4) fp16 *A, restrict read_only align(4) fp16 *B, align(4) fp32 *C, int32 M, int32 N, int32 K, - int32 lda, int32 ldb, int32 ldc, + multiple_of(4) int32 lda, multiple_of(4) int32 ldb, multiple_of(4) int32 ldc, int32 *locks, int32 grid0, int32 grid1) { int32 rxa[TM] = get_global_range[TM](0); int32 ryb[TN] = get_global_range[TN](1); diff --git a/include/triton/codegen/axis_info.h b/include/triton/codegen/axis_info.h new file mode 100644 index 000000000..bfc4ef322 --- /dev/null +++ b/include/triton/codegen/axis_info.h @@ -0,0 +1,39 @@ +#ifndef TDL_INCLUDE_CODEGEN_AXIS_INFO_PASS_H +#define TDL_INCLUDE_CODEGEN_AXIS_INFO_PASS_H + +#include +#include + +namespace triton { + +namespace ir { + class value; + class module; +} + +namespace codegen{ + +/* Pass object holding three per-value maps (is_constant_, max_contiguous_, multiple_of_); run(ir::module&) is its only public entry point and the populate_* helpers below return the value they computed for the given ir::value. */ class axis_info { +private: + // helpers + bool is_first_axis_unit(ir::value *x); + + // populate maps + bool populate_is_constant(ir::value *i); + unsigned populate_max_contiguous(ir::value *i); + unsigned populate_multiple_of(ir::value *i); + +public: + void run(ir::module &mod); + +private: + std::map is_constant_; + std::map max_contiguous_; + std::map multiple_of_; +}; + + +} +} + +#endif diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index cb1ab1f6d..c5f5f0605 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -5,6 +5,7 @@ #include #include "value.h" #include "constant.h" +#include namespace triton{ namespace ir{ @@ -21,6 +22,8 @@ class argument: public value{ public: static argument* create(type *ty, const std::string &name, function *parent = nullptr, unsigned arg_no = 0); + function* get_parent() const; + unsigned get_arg_no() const; private: function *parent_; @@ -53,6 +56,10 @@ public: return value_; } + bool 
is_llvm_attr() const { /* every attribute kind except multiple_of is reported as an LLVM attribute; selection::run checks this before calling addAttribute */ + return kind_ != multiple_of; + } + private: attribute_kind_t kind_; unsigned value_; @@ -89,6 +96,7 @@ public: // attributes void add_attr(unsigned arg_id, attribute attr) { attrs_[arg_id].insert(attr); } const attr_map_t &attrs() { return attrs_; } + std::set get_attributes(argument* arg) { return attrs_[arg->get_arg_no() + 1]; } /* returns by value the set stored under key arg_no + 1. NOTE(review): operator[] presumably default-inserts an empty set for an unseen key - confirm against attr_map_t */ private: module *parent_; diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 8e1da57ae..397be9d9d 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -122,6 +122,12 @@ public: bool is_int_div_rem() const; bool is_shift() const; bool is_cast() const; + bool is_int_mult() const; + bool is_int_add_sub() const; + bool is_int_div() const; + bool is_int_rem() const; + bool is_shl() const; + bool is_shr() const; /* the six predicates above are defined in lib/ir/instructions.cpp as comparisons of op_ against llop opcodes */ // Wraps void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; } diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index a9bea664b..c12bb6b23 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -17,6 +17,7 @@ #include "triton/codegen/shmem_liveness.h" #include "triton/codegen/shmem_info.h" #include "triton/codegen/shmem_barriers.h" +#include "triton/codegen/axis_info.h" #include "triton/codegen/target.h" #include "triton/codegen/vectorize.h" #include @@ -60,11 +61,13 @@ public: optimize_dot(&tune), optimize_cse(), optimize_trans(), + axis_info(), target_(target) { } void target_independent(ir::module &module) { optimize_dot.run(module); optimize_trans.run(module); + axis_info.run(module); /* runs after optimize_dot and optimize_trans */ // ir::print(module, std::cout); } @@ -88,6 +91,7 @@ public: codegen::optimize_dot optimize_dot; codegen::optimize_cse optimize_cse; codegen::optimize_trans optimize_trans; + codegen::axis_info axis_info; codegen::target* target_; }; diff --git a/lib/codegen/axis_info.cpp b/lib/codegen/axis_info.cpp new file mode 100644 index 000000000..38e2fcd9b --- /dev/null +++ 
b/lib/codegen/axis_info.cpp @@ -0,0 +1,129 @@ +#include "triton/codegen/axis_info.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/instructions.h" +#include "triton/ir/type.h" + +namespace triton { +namespace codegen{ + + +template +inline T add_to_cache(ir::value *i, T value, std::map &map) { /* insert-if-absent: an entry already stored for i is kept and returned, never overwritten */ + return map.insert(std::make_pair(i, value)).first->second; +} + + +bool axis_info::is_first_axis_unit(ir::value *x){ /* true for non-tile values and for tiles whose first shape equals 1 */ + if(x->get_type()->is_tile_ty()) + return x->get_type()->get_tile_shapes()[0]->get_value() == 1; + else + return true; +} + +bool axis_info::populate_is_constant(ir::value *v) { + // helper for the cache + auto cache = [this,v](bool value){ return add_to_cache(v, value, is_constant_); }; + // populate + if(v->get_type()->is_tile_ty()){ + if(auto *x = dynamic_cast(v)){ + bool value = populate_is_constant(x->get_operand(0)); + // check if broadcast (i.e., constant) along contiguous dimension + if(is_first_axis_unit(x->get_operand(0)) + && !is_first_axis_unit(x)) + return cache(value); + } + // otherwise the tile is not constant in the contiguous dimension + return cache(false); + } + // scalars are always constant in the contiguous dimension + return cache(true); +} + +unsigned axis_info::populate_max_contiguous(ir::value *v){ + // helper for the cache + auto cache = [this,v](unsigned value){ return add_to_cache(v, value, max_contiguous_); }; + // populate + if(v->get_type()->is_tile_ty()){ + auto shapes = v->get_type()->get_tile_shapes(); + if(dynamic_cast(v)) + return cache(shapes[0]->get_value()); + if(auto *x = dynamic_cast(v)){ + ir::value* lhs = x->get_operand(0); + ir::value* rhs = x->get_operand(1); + unsigned lhs_max_contiguous = populate_max_contiguous(lhs); + bool lhs_has_cst = populate_is_constant(lhs); + unsigned rhs_max_contiguous = populate_max_contiguous(rhs); + bool rhs_has_cst = populate_is_constant(rhs); + if(x->is_int_add_sub()){ + if(lhs_has_cst) + return 
cache(rhs_max_contiguous); + if(rhs_has_cst) + return cache(lhs_max_contiguous); + } + } + } + return cache(1); +} + +unsigned axis_info::populate_multiple_of(ir::value *v){ /* records and returns the multiple-of value for v: the multiple_of attribute for a matching first cast, a combination of both operands for the binary-operator cast, operand 0's value for the last cast, otherwise 1 */ + auto cache = [this,v](unsigned value){ return add_to_cache(v, value, multiple_of_); }; /* must target multiple_of_: writing into max_contiguous_ left multiple_of_ empty and, because add_to_cache never overwrites, made populate_max_contiguous return these values instead of its own */ + + if(auto *x = dynamic_cast(v)){ + std::set attributes = x->get_parent()->get_attributes(x); + for(auto attr: attributes){ + if(attr.get_kind() == ir::multiple_of) + return cache(attr.get_value()); + } + } + if(auto *x = dynamic_cast(v)){ + int lhs = populate_multiple_of(x->get_operand(0)); + int rhs = populate_multiple_of(x->get_operand(1)); /* NOTE(review): rhs == 0 (e.g. a multiple_of(0) attribute or a wrapped product) would make the / and % below undefined - confirm it cannot occur */ + if(x->is_int_mult()) + return cache(lhs * rhs); + if(x->is_int_add_sub()) + return cache(std::min(lhs, rhs)); + if(x->is_int_div()) + return cache(std::max(lhs / rhs, 1)); + if(x->is_int_rem()) + return cache(std::max(lhs % rhs, 1)); + if(x->is_shl()) + return cache(lhs << rhs); + if(x->is_shr()) + return cache(std::max(lhs >> rhs, 1)); + } + if(auto *x = dynamic_cast(v)){ + return cache(populate_multiple_of(x->get_operand(0))); + } + return cache(1); +} + + + +void axis_info::run(ir::module &mod) { /* three full sweeps over every instruction of every function, one per map */ + // populate constant + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i: block->get_inst_list()){ + populate_is_constant(i); + } + + // populate multiple_of + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i: block->get_inst_list()){ + populate_multiple_of(i); + } + + // populate maximum contiguous + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i: block->get_inst_list()){ + populate_max_contiguous(i); + } +} + + +} +} diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index fed49407f..ae2a3f1c1 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -1125,6 +1125,7 @@ void selection::run(ir::module &src, Module &dst) { for(auto 
attr_pair: fn->attrs()){ unsigned id = attr_pair.first; for(ir::attribute attr: attr_pair.second) + if(attr.is_llvm_attr()) dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr)); } tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn); diff --git a/lib/ir/function.cpp b/lib/ir/function.cpp index 758fd8bc3..5c7ca1e2a 100644 --- a/lib/ir/function.cpp +++ b/lib/ir/function.cpp @@ -16,6 +16,15 @@ argument *argument::create(type *ty, const std::string &name, return new argument(ty, name, parent, arg_no); } +function* argument::get_parent() const { /* plain accessor for parent_ */ + return parent_; +} + +unsigned argument::get_arg_no() const { /* plain accessor for arg_no_ */ + return arg_no_; +} + + /* function */ function::function(function_type *ty, linkage_types_t linkage, const std::string &name, module *parent) diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index 79c951f6d..9b5d37094 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -109,6 +109,31 @@ std::string binary_operator::repr_impl() const { } } +bool binary_operator::is_int_div() const { /* unsigned or signed integer division */ + return op_ == llop::UDiv || op_ == llop::SDiv; +} + +bool binary_operator::is_int_rem() const { /* unsigned or signed integer remainder */ + return op_ == llop::URem || op_ == llop::SRem; +} + +bool binary_operator::is_shl() const { + return op_ == llop::Shl; +} + +bool binary_operator::is_shr() const { /* logical or arithmetic right shift */ + return op_ == llop::LShr || op_ == llop::AShr; +} + +bool binary_operator::is_int_mult() const { + return op_ == llop::Mul; +} + +bool binary_operator::is_int_add_sub() const { /* true only when op_ is Add or Sub */ + return op_ == llop::Add || op_ == llop::Sub; /* both sides must compare against op_: a bare `llop::Sub` operand is the enumerator converted to bool, not a test of op_ */ +} + + binary_operator::binary_operator(op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next) : instruction(ty, 2, 1, name, next), op_(op){ set_operand(0, lhs);