diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index f93b53886..da6399573 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -178,7 +178,7 @@ public: class machine_layout_distributed_t: public machine_layout_t { public: machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, - std::map& axes, + analysis::axes *a_axes, std::map& axes, analysis::layout_t* layout); tile* create(ir::value *v); @@ -186,6 +186,7 @@ public: Builder *builder_; target *tgt_; Type *ty_; + analysis::axes *a_axes_; std::map& axes_; analysis::layout_t* layout_; }; @@ -195,7 +196,7 @@ class machine_layout_hmma_884_t: public machine_layout_distributed_t { public: machine_layout_hmma_884_t(Module *mod, Builder *builder, target *tgt, Type *ty, - std::map& axes, + analysis::axes *a_axes, std::map& axes, analysis::layout_hmma_884_t* layout); Value *offset_a_i_, *offset_a_k_; Value *offset_b_j_, *offset_b_k_; @@ -209,7 +210,7 @@ class machine_layout_scanline_t: public machine_layout_distributed_t { public: machine_layout_scanline_t(Module *mod, Builder *builder, target *tgt, Type *ty, - std::map& axes, + analysis::axes *a_axes, std::map& axes, analysis::layout_scanline_t* layout); }; @@ -230,6 +231,7 @@ private: public: generator(Module *dst, + analysis::axes *a_axes, target *tgt, analysis::layout *layouts, analysis::align *alignment, @@ -298,6 +300,7 @@ private: Module *mod_; std::map machine_layouts_; + analysis::axes *a_axes_; std::map axes_; std::map vmap_; std::map tmap_; @@ -319,10 +322,10 @@ class selection{ public: selection(analysis::liveness* liveness, analysis::allocation *alloc, - analysis::align *alignment, + analysis::align *alignment, analysis::axes *axes, analysis::layout *layouts, target *tgt, unsigned num_warps) : liveness_(liveness), alloc_(alloc), - alignment_(alignment), layouts_(layouts), + alignment_(alignment), a_axes_(axes), layouts_(layouts), tgt_(tgt), num_warps_(num_warps){ } void run(ir::module &src, Module &dst); @@ -330,6 +333,7 @@ public: private: analysis::liveness *liveness_; analysis::allocation *alloc_; + analysis::axes *a_axes_; analysis::layout *layouts_; analysis::align *alignment_; target *tgt_; diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h new file mode 100644 index 000000000..76ec88b90 --- /dev/null +++ b/include/triton/codegen/selection/generator.h @@ -0,0 +1,169 @@ +#pragma once + +#ifndef _TRITON_SELECTION_GENERATOR_H_ +#define _TRITON_SELECTION_GENERATOR_H_ + +#include "triton/ir/visitor.h" +#include "triton/codegen/analysis/layout.h" +#include "triton/codegen/selection/machine_value.h" +#include + +// forward +namespace llvm{ + class Type; + class Value; + class Instruction; + class Constant; + class LLVMContext; + class Module; + class ConstantFolder; + class IRBuilderDefaultInserter; + template + class IRBuilder; + class ArrayType; + class Function; +} + +namespace triton{ +namespace codegen{ + +// forward +namespace analysis{ +class liveness; +class tiles; +class align; +class allocation; +class cts; +class axes; +class layout; +} +// typedef +typedef llvm::IRBuilder Builder; +typedef llvm::LLVMContext LLVMContext; +typedef llvm::Type Type; +typedef llvm::Value Value; +typedef llvm::Module Module; +typedef llvm::Instruction Instruction; +typedef llvm::Constant Constant; +typedef llvm::ArrayType ArrayType; +typedef llvm::Function Function; +typedef std::vector indices_t; +// forward +class machine_layout_t; +class tile; +class shared_tile; +class distributed_tile; +class target; + +} +} + +namespace triton{ +namespace codegen{ + + +class generator: public ir::visitor, public analysis::layout_visitor { +private: + void for_each(ir::value *x, const std::function& fn); + Value* get_value(ir::value *x, const indices_t& idx); + void set_value(ir::value *x, const indices_t& idx, Value* v); + + void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK); + void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add); + void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, + Type *c_ty, Function *f_mul_add); + + void finalize_shared_layout(analysis::layout_shared_t*); + void finalize_function(ir::function*); + void finalize_phi_node(ir::phi_node*); + +public: + generator(Module *dst, + analysis::axes *a_axes, + target *tgt, + analysis::layout *layouts, + analysis::align *alignment, + analysis::allocation *alloc, + unsigned num_warps); + + void visit_value(ir::value* v); + + void visit_phi_node(ir::phi_node*); + void visit_binary_operator(ir::binary_operator*); + void visit_getelementptr_inst(ir::getelementptr_inst*); + + void visit_icmp_inst(ir::icmp_inst*); + void visit_fcmp_inst(ir::fcmp_inst*); + void visit_cast_inst(ir::cast_inst*); + + void visit_return_inst(ir::return_inst*); + void visit_cond_branch_inst(ir::cond_branch_inst*); + void visit_uncond_branch_inst(ir::uncond_branch_inst*); + + + void visit_unmasked_load_inst(ir::unmasked_load_inst*); + void visit_masked_load_inst(ir::masked_load_inst*); + void visit_unmasked_store_inst(ir::unmasked_store_inst*); + void visit_masked_store_inst(ir::masked_store_inst*); + + void visit_reshape_inst(ir::reshape_inst*); + void visit_splat_inst(ir::splat_inst*); + void visit_broadcast_inst(ir::broadcast_inst*); + void visit_downcast_inst(ir::downcast_inst*); + + void visit_get_program_id_inst(ir::get_program_id_inst*); + void visit_get_num_program_inst(ir::get_num_program_inst*); + void visit_atomic_cas_inst(ir::atomic_cas_inst*); + void visit_atomic_exch_inst(ir::atomic_exch_inst*); + void visit_atomic_add_inst(ir::atomic_add_inst*); + void visit_dot_inst(ir::dot_inst*); + void visit_trans_inst(ir::trans_inst*); + void visit_sqrt_inst(ir::sqrt_inst*); + void visit_reduce_inst(ir::reduce_inst*); + void visit_select_inst(ir::select_inst*); + + void visit_copy_to_shared_inst(ir::copy_to_shared_inst*); + void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); + void visit_barrier_inst(ir::barrier_inst*); + void visit_make_range_dyn(ir::make_range_dyn*); + void visit_make_range(ir::make_range*); + + void visit_make_range_sta(ir::make_range_sta*); + void visit_undef_value(ir::undef_value*); + void visit_constant_int(ir::constant_int*); + void visit_constant_fp(ir::constant_fp*); + void visit_alloc_const(ir::alloc_const*); + + void visit_function(ir::function*); + void visit_basic_block(ir::basic_block*); + void visit_argument(ir::argument*); + + void visit_layout_hmma_884(analysis::layout_hmma_884_t*); + void visit_layout_scanline(analysis::layout_scanline_t*); + void visit_layout_shared(analysis::layout_shared_t*); + +private: + LLVMContext *ctx_; + Builder* builder_; + Module *mod_; + + std::map machine_layouts_; + analysis::axes *a_axes_; + std::map axes_; + std::map vmap_; + std::map tmap_; + target *tgt_; + analysis::layout *layouts_; + analysis::align *alignment_; + analysis::allocation *alloc_; + Value *sh_mem_ptr_; + unsigned num_warps_; + + std::set seen_; +}; + +} +} + +#endif diff --git a/include/triton/codegen/selection/machine_layout.h b/include/triton/codegen/selection/machine_layout.h new file mode 100644 index 000000000..a3b453995 --- /dev/null +++ b/include/triton/codegen/selection/machine_layout.h @@ -0,0 +1,138 @@ +#pragma once + +#ifndef _TRITON_SELECTION_MACHINE_LAYOUT_H_ +#define _TRITON_SELECTION_MACHINE_LAYOUT_H_ + +#include +#include "triton/codegen/analysis/layout.h" + +namespace llvm{ + class Type; + class Value; + class Instruction; + class Constant; + class LLVMContext; + class Module; + class ConstantFolder; + class IRBuilderDefaultInserter; + template + class IRBuilder; + class ArrayType; + class Function; +} + +namespace triton{ + +namespace ir{ +class value; +} + +namespace codegen{ + +namespace analysis{ +class liveness; +class tiles; +class align; +class allocation; +class cts; +class axes; +class layout; +} + +typedef llvm::IRBuilder Builder; +typedef llvm::LLVMContext LLVMContext; +typedef llvm::Type Type; +typedef llvm::Value Value; +typedef llvm::Module Module; +typedef llvm::Instruction Instruction; +typedef llvm::Constant Constant; +typedef llvm::ArrayType ArrayType; +typedef llvm::Function Function; + +class distributed_axis; +class machine_layout_t; +class tile; +class shared_tile; +class distributed_tile; +class target; + +} +} + +namespace triton{ +namespace codegen{ + + +class machine_layout_t { +public: + virtual tile* create(ir::value *v) = 0; +}; + +class machine_layout_shared_t: public machine_layout_t { +public: + machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr, analysis::layout_t* layout, + std::map& vmap, + std::map& tmap); + + tile* create(ir::value *v); + + Module *mod_; + Builder *builder_; + target *tgt_; + analysis::allocation* alloc_; + Value *&sh_mem_ptr_; + analysis::layout_t* layout_; + std::map& vmap_; + std::map& tmap_; + + Value *offset_; + Value *ptr_; + Value *pre_ptr_; + Value *next_ptr_; + +}; + +class machine_layout_distributed_t: public machine_layout_t { +public: + machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, + analysis::axes *a_axes, std::map& axes, + analysis::layout_t* layout); + + tile* create(ir::value *v); + Module *mod_; + Builder *builder_; + target *tgt_; + Type *ty_; + analysis::axes *a_axes_; + std::map& axes_; + analysis::layout_t* layout_; +}; + + +class machine_layout_hmma_884_t: public machine_layout_distributed_t { +public: + machine_layout_hmma_884_t(Module *mod, Builder *builder, + target *tgt, Type *ty, + analysis::axes *a_axes, std::map& axes, + analysis::layout_hmma_884_t* layout); + Value *offset_a_i_, *offset_a_k_; + Value *offset_b_j_, *offset_b_k_; + unsigned pack_size_0_; + unsigned pack_size_1_; + unsigned num_packs_0_; + unsigned num_packs_1_; +}; + +class machine_layout_scanline_t: public machine_layout_distributed_t { +public: + machine_layout_scanline_t(Module *mod, Builder *builder, + target *tgt, Type *ty, + analysis::axes *a_axes, std::map& axes, + analysis::layout_scanline_t* layout); +}; + +} +} + +#endif diff --git a/include/triton/codegen/selection/machine_value.h b/include/triton/codegen/selection/machine_value.h new file mode 100644 index 000000000..508881fd3 --- /dev/null +++ b/include/triton/codegen/selection/machine_value.h @@ -0,0 +1,153 @@ +#pragma once + +#ifndef _TRITON_SELECTION_MACHINE_VALUE_H_ +#define _TRITON_SELECTION_MACHINE_VALUE_H_ + +#include +#include +#include + +namespace llvm{ + class Type; + class Value; + class Instruction; + class Constant; + class LLVMContext; + class Module; + class ConstantFolder; + class IRBuilderDefaultInserter; + template + class IRBuilder; + class ArrayType; + class Function; +} + +namespace triton{ +namespace codegen{ + typedef llvm::IRBuilder Builder; + typedef llvm::LLVMContext LLVMContext; + typedef llvm::Type Type; + typedef llvm::Value Value; + typedef llvm::Module Module; + typedef llvm::Instruction Instruction; + typedef llvm::Constant Constant; + typedef llvm::ArrayType ArrayType; + typedef llvm::Function Function; +} +} + +namespace triton{ +namespace codegen{ + +namespace analysis{ +class liveness; +class tiles; +class align; +class allocation; +class cts; +class axes; +class layout; +} + +class distributed_axis; +class machine_layout_t; +class tile; +class shared_tile; +class distributed_tile; +class target; +typedef std::vector indices_t; + +} +} + +namespace triton{ +namespace codegen{ + +struct distributed_axis { + int contiguous; + std::vector values; + Value* thread_id; +}; + +class tile { +protected: + typedef std::vector shapes_t; + +public: + tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ } + virtual void set_value(indices_t idx, Value *v) = 0; + virtual Value* get_value(indices_t idx) = 0; + Type *get_ty() const { return ty_; } + shapes_t get_shapes() const { return shapes_; } + +protected: + Type *ty_; + shapes_t shapes_; +}; + +class shared_tile: public tile { +private: + void extract_constant(Value *arg, Value *&non_cst, Value *&cst); + void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx); + + +public: + shared_tile(Type* ty, const shapes_t &shapes, const std::vector &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector& perm = {}); + void set_vector_size(unsigned vector_size); + void set_return_mode(bool return_vector); + void set_value(indices_t, Value *); + Value* get_ptr_to(indices_t idx); + Value* get_value(indices_t idx); + Value* get_pointer() { return ptr_; } + Value* get_offset() { return offset_; } + const std::vector& get_perm() { return perm_; } + const std::vector& get_order() { return order_; } + static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector& perm, const std::vector& order, indices_t idx); + +private: + Value *ptr_; + bool return_vector_; + Builder &builder_; + Value *offset_; + std::map ptr_cache_; + unsigned vector_size_; + std::vector order_; + std::vector perm_; +}; + +// Distribtued tile +class distributed_tile: public tile{ + typedef std::vector axes_t; + typedef std::vector ordered_indices_vec_t; + typedef std::map indices_map_t; + typedef std::map values_map_t; + +private: + void init_indices(); + Type *make_vector_ty(Type *ty, size_t vector_size); + +public: + distributed_tile(Type *ty, const shapes_t& shapes, const std::vector& order, const axes_t &axes, Builder &builder, bool vectorize); + void set_value(indices_t idx, Value *v); + Value* get_value(indices_t idx); + const std::vector& get_order() { return order_; } + unsigned get_linear_index(indices_t idx); + indices_t get_ordered_indices(unsigned id); + void for_each(std::function fn); + const distributed_axis &axis(unsigned dim) { return axes_.at(dim); } + +private: + axes_t axes_; + std::vector order_; + indices_map_t indices_; + values_map_t values_; + ordered_indices_vec_t ordered_indices_; + size_t vector_size_; + Builder &builder_; +}; + +} +} + +#endif diff --git a/include/triton/codegen/selection/selection.h b/include/triton/codegen/selection/selection.h new file mode 100644 index 000000000..a2b88247f --- /dev/null +++ b/include/triton/codegen/selection/selection.h @@ -0,0 +1,70 @@ +#pragma once + +#ifndef _TRITON_SELECTION_SELECTION_H_ +#define _TRITON_SELECTION_SELECTION_H_ + +#include + +namespace llvm{ + class Module; + class Value; +} + + +namespace triton{ + +namespace ir{ +class value; +class module; +} + +namespace codegen{ +// typedef +typedef llvm::Module Module; +typedef llvm::Value Value; +// forward +namespace analysis{ +class liveness; +class align; +class allocation; +class axes; +class layout; +} +class target; +class tile; + +} +} + +namespace triton{ +namespace codegen{ + +// Selection pass +class selection{ + typedef std::map vmap_t; + typedef std::map tmap_t; + +public: + selection(analysis::liveness* liveness, analysis::allocation *alloc, + analysis::align *alignment, analysis::axes *axes, + analysis::layout *layouts, target *tgt, unsigned num_warps) + : liveness_(liveness), alloc_(alloc), + alignment_(alignment), a_axes_(axes), layouts_(layouts), + tgt_(tgt), num_warps_(num_warps){ } + + void run(ir::module &src, Module &dst); + +private: + analysis::liveness *liveness_; + analysis::allocation *alloc_; + analysis::axes *a_axes_; + analysis::layout *layouts_; + analysis::align *alignment_; + target *tgt_; + unsigned num_warps_; +}; + +} +} + +#endif diff --git a/lib/codegen/selection.cc b/lib/codegen/selection/generator.cc similarity index 69% rename from lib/codegen/selection.cc rename to lib/codegen/selection/generator.cc index a4daa2d50..831719b73 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection/generator.cc @@ -1,9 +1,8 @@ -#include -#include "triton/codegen/selection.h" +#include +#include "triton/codegen/selection/generator.h" +#include "triton/codegen/selection/machine_layout.h" +#include "triton/codegen/selection/machine_value.h" #include "triton/codegen/target.h" -#include "triton/codegen/analysis/liveness.h" -#include "triton/codegen/analysis/layout.h" -#include "triton/codegen/analysis/axes.h" #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/transform/coalesce.h" @@ -12,12 +11,8 @@ #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/type.h" -#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/Transforms/Scalar/EarlyCSE.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/InlineAsm.h" @@ -27,198 +22,6 @@ namespace codegen{ using namespace llvm; -/* Distributed Tile */ -void distributed_tile::init_indices() { - std::vector id(axes_.size(), 0); - // create iteration order - std::vector order(id.size()); - std::iota(order.begin(), order.end(), 0); - auto cmp = [&](int x, int y) { - return axes_[x].contiguous > axes_[y].contiguous; - }; - std::sort(order.begin(), order.end(), cmp); - // build - size_t k = 0; - while(true) { - indices_t current; - for(size_t d = 0; d < id.size(); d++) - current.push_back(axes_[d].values[id[d]]); - size_t sz = indices_.size(); - indices_[current] = sz; - values_[current] = nullptr; - ordered_indices_.push_back(current); - id[order[0]]++; - while(id[order[k]] == axes_[order[k]].values.size()){ - if(k == id.size() - 1) - return; - id[order[k++]] = 0; - id[order[k]]++; - } - k = 0; - } -} - -llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) { - if(vector_size == 1) - return ty; - return VectorType::get(ty, vector_size); -} - -distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector& order, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize) - : tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), order_(order), builder_(builder) { - vector_size_ = vectorize?ty_->getVectorNumElements():1; - init_indices(); -} - -void distributed_tile::set_value(indices_t idx, Value *x) { - assert(x->getType() == ty_ && "cannot set a value of different type"); - Value *&result = values_[idx]; - assert(!result && "value cannot be set twice"); - result = x; -} - -Value* distributed_tile::get_value(indices_t idx) { - Value *result = values_.at(idx); - assert(result && "value has not been set"); - return result; -} - -unsigned distributed_tile::get_linear_index(indices_t idx) { - return indices_[idx]; -} - -indices_t distributed_tile::get_ordered_indices(unsigned id) { - return ordered_indices_.at(id); -} - - -void distributed_tile::for_each(std::function fn) { - for(unsigned i = 0; i < ordered_indices_.size(); i++){ - if(i % vector_size_ == 0) - fn(ordered_indices_[i]); - } -} - -/* Shared Tile */ -void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) { - BinaryOperator *bin_op = dyn_cast(arg); - Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0); - if(dyn_cast(arg)){ - cst = arg; - non_cst = _0; - return; - } - if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){ - non_cst = arg; - cst = _0; - return; - } - Constant *cst_lhs = dyn_cast(bin_op->getOperand(0)); - Constant *cst_rhs = dyn_cast(bin_op->getOperand(1)); - if(cst_lhs && cst_rhs){ - cst = arg; - non_cst = _0; - } - else if(cst_lhs){ - cst = cst_lhs; - non_cst = bin_op->getOperand(1); - } - else if(cst_rhs){ - cst = cst_rhs; - non_cst = bin_op->getOperand(0); - } - else{ - non_cst = arg; - cst = _0; - } -} - -void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) { - non_cst_idx.clear(); - cst_idx.clear(); - for(Value *idx: arg_idx){ - Value *non_cst, *cst; - extract_constant(idx, non_cst, cst); - non_cst_idx.push_back(non_cst); - cst_idx.push_back(cst); - } -} - - -Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector& perm, const std::vector& order, indices_t idx) { - // strides - std::vector strides(order.size()); - strides[order[0]] = builder.getInt32(1); - for(size_t i = 1; i < idx.size(); i++) - strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]])); - // result - Value *result = builder.getInt32(0); - for(size_t i = 0; i < strides.size(); i++) - result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i])); - return result; -} - -shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector& perm): - tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){ - return_vector_ = false; - if(perm_.empty()){ - perm_.resize(shapes.size()); - std::iota(perm_.begin(), perm_.end(), 0); - } -} - -void shared_tile::set_value(indices_t idx, Value *value) { - Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx)); - unsigned addr_space = ptr->getType()->getPointerAddressSpace(); - ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space)); - builder_.CreateStore(value, ptr); -} - -void shared_tile::set_vector_size(unsigned vector_size) { - vector_size_ = vector_size; -} - -void shared_tile::set_return_mode(bool return_vector){ - return_vector_ = return_vector; -} - - -Value* shared_tile::get_value(indices_t idx) { - indices_t non_cst_idx, cst_idx; - extract_constant(idx, non_cst_idx, cst_idx); - Value *&base_ptr = ptr_cache_[non_cst_idx]; - unsigned vector_size = vector_size_; - Type *ty = ty_; - if(ty->isHalfTy() && (vector_size % 2 == 0)){ - ty = IntegerType::get(ty->getContext(), 32); - vector_size = vector_size / 2; - } - if(base_ptr == nullptr){ -// BasicBlock* store = builder_.GetInsertBlock(); -// if(!non_cst_idx.empty()) -// if(isa(non_cst_idx.front())){ -// builder_.SetInsertPoint((Instruction*)non_cst_idx.front()); -// } - base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx)); - if(vector_size_ > 1){ - Type *vec_ty = VectorType::get(ty, vector_size); - Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace()); - base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty); - } -// builder_.SetInsertPoint(store); - } - Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx); - Value *div = offset; - if(vector_size_ > 1) - div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_)); - Value *ptr = builder_.CreateGEP(base_ptr, div); - Value *result = builder_.CreateLoad(ptr); - if(return_vector_ == false && vector_size_ > 1) { - Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_)); - result = builder_.CreateExtractElement(result, rem); - } - return result; -} llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) { using llop = llvm::Instruction::BinaryOps; @@ -306,7 +109,7 @@ llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) { } -Type *type(ir::type *ty, LLVMContext &ctx) { +inline Type *type(ir::type *ty, LLVMContext &ctx) { // function if(auto* tt = dynamic_cast(ty)){ Type *return_ty = type(tt->get_return_ty(), ctx); @@ -344,34 +147,17 @@ Type *type(ir::type *ty, LLVMContext &ctx) { } -/* ------------------- - * ---- Init Axes ---- - * ------------------- */ - -// Grid construction -std::vector delinearize(Value *trailing, const std::vector& order, std::vector &shapes, IRBuilder<> &builder){ - size_t dim = shapes.size(); - std::vector result(dim); - for(unsigned k = 0; k < dim - 1; k++){ - Constant *dim_k = builder.getInt32(shapes[order[k]]); - Value *rem = builder.CreateURem(trailing, dim_k); - trailing = builder.CreateUDiv(trailing, dim_k); - result[order[k]] = rem; +inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) { + switch(attr.get_kind()){ + case ir::noalias: return llvm::Attribute::get(ctx, llvm::Attribute::NoAlias); + case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly); + case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly); + case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value()); + default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute"); } - result[order[dim - 1]] = trailing; - return result; } -inline int32_t ceil(int32_t num, int32_t div){ - return (num + div - 1)/div; -} - -/* ------------------- - * ---- Init Tiles ---- - * ------------------- */ - - -bool is_trans(ir::value *v) { +inline bool is_trans(ir::value *v) { if(dynamic_cast(v)) { return true; } @@ -386,34 +172,9 @@ bool is_trans(ir::value *v) { -/* ---------------------------- - * ---- Generate LLVM code ---- - * ---------------------------- */ - -inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) { - switch(attr.get_kind()){ - case ir::noalias: return llvm::Attribute::get(ctx, llvm::Attribute::NoAlias); - case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly); - case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly); - case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value()); - default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute"); - } -} - - -void selection::run(ir::module &src, Module &dst) { - // create tile - generator gen(&dst, tgt_, layouts_, alignment_, alloc_, num_warps_ ); - - for(ir::alloc_const *x: src.allocs()) - gen.visit_value(x); - for(ir::function *fn: src.get_function_list()) - gen.visit_value(fn); -} - - generator::generator(Module *dst, + analysis::axes *a_axes, target *tgt, analysis::layout *layouts, analysis::align *alignment, @@ -421,7 +182,7 @@ generator::generator(Module *dst, unsigned num_warps) : ctx_(&dst->getContext()), mod_(dst), builder_(new Builder(dst->getContext())), - tgt_(tgt), + a_axes_(a_axes), tgt_(tgt), layouts_(layouts), alignment_(alignment), alloc_(alloc), num_warps_(num_warps) { @@ -1163,14 +924,12 @@ void generator::visit_function(ir::function* fn) { - - void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) { - machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout); + machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); } void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) { - machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), axes_, layout); + machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); } void generator::visit_layout_shared(analysis::layout_shared_t* layout) { @@ -1215,240 +974,6 @@ void generator::set_value(ir::value *x, const indices_t& idx, Value* v) { } - -machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, - Value *&sh_mem_ptr, analysis::layout_t *layout, - std::map& vmap, - std::map& tmap) - : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) { - - auto order = layout_->order; - auto shapes = layout_->shapes; - shapes[order[0]] += layout_->pad; - - Type* ty = type(layout_->ty, builder_->getContext()); - - PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace()); - // double-buffered - if(layout_->double_buffer) { - BasicBlock *current = builder_->GetInsertBlock(); - auto info = *layout_->double_buffer; - ir::phi_node *phi = info.phi; - BasicBlock *parent = (BasicBlock*)vmap_.at(phi->get_parent()); - if(parent->empty()) - builder_->SetInsertPoint(parent); - else - builder_->SetInsertPoint(&*parent->getFirstNonPHI()); - // create pointers - ptr_ = builder_->CreatePHI(ptr_ty, 2); - pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_))); - pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType()); - offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2); - next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr"); - builder_->SetInsertPoint(current); - } - else{ - size_t offset = alloc_->offset(layout_); - ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset)); - ptr_ = builder_->CreateBitCast(ptr_, ptr_ty); - } -} - - -tile* machine_layout_shared_t::create(ir::value *v) { - auto order = layout_->order; - auto shapes = layout_->shapes; - shapes[order[0]] += layout_->pad; - Type* ty = type(layout_->ty, builder_->getContext()); - // double-buffered - if(layout_->double_buffer) { - if(v == layout_->double_buffer->phi) - return new shared_tile(ty, shapes, order, ptr_, *builder_, offset_); - if(v == layout_->double_buffer->latch) - return new shared_tile(ty, shapes, order, next_ptr_, *builder_); - return new shared_tile(ty, shapes, order, pre_ptr_, *builder_); - } - else { - return new shared_tile(ty, shapes, order, ptr_, *builder_); - } -} - -machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, - std::map& axes, - analysis::layout_t *layout) - : mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), axes_(axes), layout_(layout) { - -} - -tile *machine_layout_distributed_t::create(ir::value *v) { - Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext()); - const auto &shapes = v->get_type()->get_tile_shapes(); - std::vector axes(shapes.size()); - for(size_t d = 0; d < shapes.size(); d++){ - if(shapes[d] > 1){ - unsigned x = layout_->axes[d]; - axes[d] = axes_.at(x); - } - else{ - axes[d].contiguous = 1; - axes[d].values = {builder_->getInt32(0)}; - } - } - return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false); -} - -machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder, - target *tgt, Type *ty, - std::map& axes, - analysis::layout_hmma_884_t* layout) - : machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) { - - Value *warp_size = builder_->getInt32(32); - Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); - Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); - Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); - - const auto& shapes = layout->shapes; - if(shapes.size() > 3) - throw std::runtime_error("unsupported"); - - bool is_batched = shapes.size() >= 3; - - Value *_1 = builder_->getInt32(1); - Value *_2 = builder_->getInt32(2); - Value *_3 = builder_->getInt32(3); - Value *_4 = builder_->getInt32(4); - Value *_16 = builder_->getInt32(16); - - // fragments per warp - unsigned fpw_0 = layout->fpw.at(0); - unsigned fpw_1 = layout->fpw.at(1); - unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1; - // warps per tile - unsigned wpt_0 = layout->wpt.at(0); - unsigned wpt_1 = layout->wpt.at(1); - unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1; - // hmma warp tile size - unsigned hmma_wts_0 = fpw_0 * 8; - unsigned hmma_wts_1 = fpw_1 * 8; - unsigned hmma_wts_2 = is_batched ? fpw_2 : 1; - // hmma block tile size - unsigned hmma_bts_0 = hmma_wts_0 * wpt_0; - unsigned hmma_bts_1 = hmma_wts_1 * wpt_1; - unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1; - // number of repetition - unsigned num_rep_0 = shapes[0] / hmma_bts_0; - unsigned num_rep_1 = shapes[1] / hmma_bts_1; - unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1; - // size of each pack (interleaving) - pack_size_0_ = std::min(num_rep_0, 1); - pack_size_1_ = std::min(num_rep_1, 1); - // number of packs (interleaving) - num_packs_0_ = num_rep_0 / pack_size_0_; - num_packs_1_ = num_rep_1 / pack_size_1_; - - /* intra warp offset */ - // offset of quad in pair - Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)), - builder_->getInt32(fpw_0 * pack_size_0_)); - Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)), - builder_->getInt32(fpw_1 * pack_size_1_)); - - // Quad pair id - Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4); - Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4); - pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0)); - pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0)); - pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1)); - // Quad pair offset - Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_)); - Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_)); - - /* inter warp offset */ - Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0)); - Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0)); - Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1)); - Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1)); - Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_)); - Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_)); - - /* offsets */ - // a offset - offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a)); - offset_a_k_ = builder_->CreateAnd(u_thread_id, _3); - // b offsets - offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b)); - offset_b_k_ = builder_->CreateAnd(u_thread_id, _3); - - // c offsets - Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_); - Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2), - builder_->CreateAdd(warp_offset_j, pair_b_off)); - - /* indices */ - // i indices - std::vector idx_i; - for(unsigned pack = 0; pack < num_packs_0_; pack++) - for(unsigned ii = 0; ii < pack_size_0_; ii++) - for(unsigned i = 0; i < 2; i++){ - idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2))); - } - // j indices - std::vector idx_j; - for(unsigned pack = 0; pack < num_packs_1_; pack++) - for(unsigned jj = 0; jj < pack_size_1_; jj++) - for(unsigned j = 0; j < 2; j++){ - idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_))); - idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1))); - } - // z indices - std::vector idx_z; - for(unsigned pack = 0; pack < num_rep_2; pack++) - idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2))); - - - /* axes */ - axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0}; - axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1}; - if(is_batched) - axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2}; -} - - -machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder, - target *tgt, Type *ty, - std::map &axes, - analysis::layout_scanline_t* layout) - : machine_layout_distributed_t(mod, builder, tgt, ty, axes, layout) { - - Value *warp_size = builder_->getInt32(32); - Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); - Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); - Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); - - auto order = layout->order; - const auto& shapes = layout->shapes; - size_t dim = shapes.size(); - std::vector nts = layout->nts; - std::vector mts = layout->mts; - Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id); - std::vector thread_id = delinearize(full_thread_id, order, mts, *builder_); - // Create axes - for(unsigned k = 0; k < dim; k++) { - std::string str_k = std::to_string(k); - Value *contiguous_k = builder_->getInt32(nts[k]); - Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k); - unsigned per_block = nts[k] * mts[k]; - unsigned per_thread = nts[k] * shapes[k] / per_block; - std::vector idx_list(per_thread); - for(unsigned n = 0 ; n < per_thread; n++){ - unsigned offset = n / nts[k] * per_block + n % nts[k]; - idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); - } - axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]}; - } -} - void generator::finalize_shared_layout(analysis::layout_shared_t *shared) { if(shared->double_buffer) { auto info = *shared->double_buffer; diff --git a/lib/codegen/selection/machine_layout.cc b/lib/codegen/selection/machine_layout.cc new file mode 100644 index 000000000..1e6f0d5da --- /dev/null +++ b/lib/codegen/selection/machine_layout.cc @@ -0,0 +1,308 @@ +#include +#include "triton/codegen/selection/machine_layout.h" +#include "triton/codegen/selection/machine_value.h" +#include "triton/codegen/analysis/allocation.h" +#include "triton/codegen/analysis/axes.h" +#include "triton/codegen/target.h" +#include "triton/ir/instructions.h" +#include "triton/ir/type.h" +#include "llvm/IR/IRBuilder.h" + +namespace triton{ +namespace codegen{ + +using namespace llvm; + +inline Type *type(ir::type *ty, LLVMContext &ctx) { + // function + if(auto* tt = dynamic_cast(ty)){ + Type *return_ty = type(tt->get_return_ty(), ctx); + std::vector param_tys; + std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), + [&ctx](ir::type* t){ return type(t, ctx);}); + return FunctionType::get(return_ty, param_tys, false); + } + // pointer + if(ty->is_pointer_ty()){ + Type *elt_ty = type(ty->get_pointer_element_ty(), ctx); + unsigned addr_space = ty->get_pointer_address_space(); + return PointerType::get(elt_ty, addr_space); + } + // integer + if(ty->is_integer_ty()){ + unsigned bitwidth = ty->get_integer_bitwidth(); + return IntegerType::get(ctx, bitwidth); + } + // primitive types + switch(ty->get_type_id()){ + case ir::type::VoidTyID: return Type::getVoidTy(ctx); + case ir::type::HalfTyID: return Type::getHalfTy(ctx); + case ir::type::FloatTyID: return Type::getFloatTy(ctx); + case ir::type::DoubleTyID: return Type::getDoubleTy(ctx); + case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx); + case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx); + case ir::type::LabelTyID: return Type::getLabelTy(ctx); + case ir::type::MetadataTyID: return Type::getMetadataTy(ctx); + case ir::type::TokenTyID: return Type::getTokenTy(ctx); + default: break; + } + // unknown type + throw std::runtime_error("unknown conversion from ir::type to Type"); +} + +// Grid construction +inline std::vector delinearize(Value *trailing, const std::vector& order, std::vector &shapes, IRBuilder<> &builder){ + size_t dim = shapes.size(); + std::vector result(dim); + for(unsigned k = 0; k < dim - 1; k++){ + Constant *dim_k = builder.getInt32(shapes[order[k]]); + Value *rem = builder.CreateURem(trailing, dim_k); + trailing = builder.CreateUDiv(trailing, dim_k); + result[order[k]] = rem; + } + result[order[dim - 1]] = trailing; + return result; +} + +inline int32_t ceil(int32_t num, int32_t div){ + return (num + div - 1)/div; +} + + + +machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, + Value *&sh_mem_ptr, analysis::layout_t *layout, + std::map& vmap, + std::map& tmap) + : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) { + + auto order = layout_->order; + auto shapes = layout_->shapes; + shapes[order[0]] += layout_->pad; + + Type* ty = type(layout_->ty, builder_->getContext()); + + PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace()); + // double-buffered + if(layout_->double_buffer) { + BasicBlock *current = builder_->GetInsertBlock(); + auto info = *layout_->double_buffer; + ir::phi_node *phi = info.phi; + BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent())); + if(parent->empty()) + builder_->SetInsertPoint(parent); + else + builder_->SetInsertPoint(&*parent->getFirstNonPHI()); + // create pointers + ptr_ = builder_->CreatePHI(ptr_ty, 2); + pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_))); + pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType()); + offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2); + next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr"); + builder_->SetInsertPoint(current); + } + else{ + size_t offset = alloc_->offset(layout_); + ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset)); + ptr_ = builder_->CreateBitCast(ptr_, ptr_ty); + } +} + + +tile* machine_layout_shared_t::create(ir::value *v) { + auto order = layout_->order; + auto shapes = layout_->shapes; + shapes[order[0]] += layout_->pad; + Type* ty = type(layout_->ty, builder_->getContext()); + // double-buffered + if(layout_->double_buffer) { + if(v == layout_->double_buffer->phi) + return new shared_tile(ty, shapes, order, ptr_, *builder_, offset_); + if(v == layout_->double_buffer->latch) + return new shared_tile(ty, shapes, order, next_ptr_, *builder_); + return new shared_tile(ty, shapes, order, pre_ptr_, *builder_); + } + else { + return new shared_tile(ty, shapes, order, ptr_, *builder_); + } +} + +machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, + analysis::axes *a_axes, std::map& axes, + analysis::layout_t *layout) + : mod_(mod), builder_(builder), tgt_(tgt), ty_(ty), a_axes_(a_axes), axes_(axes), layout_(layout) { + +} + +tile *machine_layout_distributed_t::create(ir::value *v) { + Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext()); + const auto &shapes = v->get_type()->get_tile_shapes(); + std::vector axes(shapes.size()); + for(size_t d = 0; d < shapes.size(); d++){ + if(shapes[d] > 1){ + unsigned x = a_axes_->get(v, d); + axes[d] = axes_.at(x); + } + else{ + axes[d].contiguous = 1; + axes[d].values = {builder_->getInt32(0)}; + } + } + return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false); +} + +machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder, + target *tgt, Type *ty, analysis::axes *a_axes, + std::map& axes, + analysis::layout_hmma_884_t* layout) + : machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) { + + Value *warp_size = builder_->getInt32(32); + Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); + Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); + Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); + + const auto& shapes = layout->shapes; + if(shapes.size() > 3) + throw std::runtime_error("unsupported"); + + bool is_batched = shapes.size() >= 3; + + Value *_1 = builder_->getInt32(1); + Value *_2 = builder_->getInt32(2); + Value *_3 = builder_->getInt32(3); + Value *_4 = builder_->getInt32(4); + Value *_16 = builder_->getInt32(16); + + // fragments per warp + unsigned fpw_0 = layout->fpw.at(0); + unsigned fpw_1 = layout->fpw.at(1); + unsigned fpw_2 = is_batched ? layout->fpw.at(2) : 1; + // warps per tile + unsigned wpt_0 = layout->wpt.at(0); + unsigned wpt_1 = layout->wpt.at(1); + unsigned wpt_2 = is_batched ? layout->wpt.at(2) : 1; + // hmma warp tile size + unsigned hmma_wts_0 = fpw_0 * 8; + unsigned hmma_wts_1 = fpw_1 * 8; + unsigned hmma_wts_2 = is_batched ? fpw_2 : 1; + // hmma block tile size + unsigned hmma_bts_0 = hmma_wts_0 * wpt_0; + unsigned hmma_bts_1 = hmma_wts_1 * wpt_1; + unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1; + // number of repetition + unsigned num_rep_0 = shapes[0] / hmma_bts_0; + unsigned num_rep_1 = shapes[1] / hmma_bts_1; + unsigned num_rep_2 = is_batched ? shapes[2] / hmma_bts_2 : 1; + // size of each pack (interleaving) + pack_size_0_ = std::min(num_rep_0, 1); + pack_size_1_ = std::min(num_rep_1, 1); + // number of packs (interleaving) + num_packs_0_ = num_rep_0 / pack_size_0_; + num_packs_1_ = num_rep_1 / pack_size_1_; + + /* intra warp offset */ + // offset of quad in pair + Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)), + builder_->getInt32(fpw_0 * pack_size_0_)); + Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)), + builder_->getInt32(fpw_1 * pack_size_1_)); + + // Quad pair id + Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4); + Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4); + pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0)); + pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0)); + pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1)); + // Quad pair offset + Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_)); + Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_)); + + /* inter warp offset */ + Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0)); + Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0)); + Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1)); + Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1)); + Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_)); + Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_)); + + /* offsets */ + // a offset + offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a)); + offset_a_k_ = builder_->CreateAnd(u_thread_id, _3); + // b offsets + offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b)); + offset_b_k_ = builder_->CreateAnd(u_thread_id, _3); + + // c offsets + Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_); + Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2), + builder_->CreateAdd(warp_offset_j, pair_b_off)); + + /* indices */ + // i indices + std::vector idx_i; + for(unsigned pack = 0; pack < num_packs_0_; pack++) + for(unsigned ii = 0; ii < pack_size_0_; ii++) + for(unsigned i = 0; i < 2; i++){ + idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2))); + } + // j indices + std::vector idx_j; + for(unsigned pack = 0; pack < num_packs_1_; pack++) + for(unsigned jj = 0; jj < pack_size_1_; jj++) + for(unsigned j = 0; j < 2; j++){ + idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_))); + idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1))); + } + // z indices + std::vector idx_z; + for(unsigned pack = 0; pack < num_rep_2; pack++) + idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2))); + + + /* axes */ + axes_[layout->axes[0]] = distributed_axis{1, idx_i, warp_id_0}; + axes_[layout->axes[1]] = distributed_axis{1, idx_j, warp_id_1}; + if(is_batched) + axes_[layout->axes[2]] = distributed_axis{1, idx_z, warp_id_2}; +} + + +machine_layout_scanline_t::machine_layout_scanline_t(Module *mod, Builder *builder, + target *tgt, Type *ty, + analysis::axes *a_axes, std::map &axes, + analysis::layout_scanline_t* layout) + : machine_layout_distributed_t(mod, builder, tgt, ty, a_axes, axes, layout) { + + Value *warp_size = builder_->getInt32(32); + Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); + Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); + Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); + + auto order = layout->order; + const auto& shapes = layout->shapes; + size_t dim = shapes.size(); + std::vector nts = layout->nts; + std::vector mts = layout->mts; + Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id); + std::vector thread_id = delinearize(full_thread_id, order, mts, *builder_); + // Create axes + for(unsigned k = 0; k < dim; k++) { + std::string str_k = std::to_string(k); + Value *contiguous_k = builder_->getInt32(nts[k]); + Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k); + unsigned per_block = nts[k] * mts[k]; + unsigned per_thread = nts[k] * shapes[k] / per_block; + std::vector idx_list(per_thread); + for(unsigned n = 0 ; n < per_thread; n++){ + unsigned offset = n / nts[k] * per_block + n % nts[k]; + idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); + } + axes_[layout->axes[k]] = distributed_axis{nts[k], idx_list, thread_id[k]}; + } +} + + +} +} diff --git a/lib/codegen/selection/machine_value.cc b/lib/codegen/selection/machine_value.cc new file mode 100644 index 000000000..bd4237043 --- /dev/null +++ b/lib/codegen/selection/machine_value.cc @@ -0,0 +1,206 @@ +#include +#include "llvm/IR/IRBuilder.h" +#include "triton/codegen/selection/machine_value.h" + +namespace triton{ +namespace codegen{ + +using namespace llvm; + +/* Distributed Tile */ +void distributed_tile::init_indices() { + std::vector id(axes_.size(), 0); + // create iteration order + std::vector order(id.size()); + std::iota(order.begin(), order.end(), 0); + auto cmp = [&](int x, int y) { + return axes_[x].contiguous > axes_[y].contiguous; + }; + std::sort(order.begin(), order.end(), cmp); + // build + size_t k = 0; + while(true) { + indices_t current; + for(size_t d = 0; d < id.size(); d++) + current.push_back(axes_[d].values[id[d]]); + size_t sz = indices_.size(); + indices_[current] = sz; + values_[current] = nullptr; + ordered_indices_.push_back(current); + id[order[0]]++; + while(id[order[k]] == axes_[order[k]].values.size()){ + if(k == id.size() - 1) + return; + id[order[k++]] = 0; + id[order[k]]++; + } + k = 0; + } +} + +llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) { + if(vector_size == 1) + return ty; + return VectorType::get(ty, vector_size); +} + +distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector& order, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize) + : tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), order_(order), builder_(builder) { + vector_size_ = vectorize?ty_->getVectorNumElements():1; + init_indices(); +} + +void distributed_tile::set_value(indices_t idx, Value *x) { + assert(x->getType() == ty_ && "cannot set a value of different type"); + Value *&result = values_[idx]; + assert(!result && "value cannot be set twice"); + result = x; +} + +Value* distributed_tile::get_value(indices_t idx) { + Value *result = values_.at(idx); + assert(result && "value has not been set"); + return result; +} + +unsigned distributed_tile::get_linear_index(indices_t idx) { + return indices_[idx]; +} + +indices_t distributed_tile::get_ordered_indices(unsigned id) { + return ordered_indices_.at(id); +} + + +void distributed_tile::for_each(std::function fn) { + for(unsigned i = 0; i < ordered_indices_.size(); i++){ + if(i % vector_size_ == 0) + fn(ordered_indices_[i]); + } +} + +/* Shared Tile */ +void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) { + BinaryOperator *bin_op = dyn_cast(arg); + Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0); + if(dyn_cast(arg)){ + cst = arg; + non_cst = _0; + return; + } + if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){ + non_cst = arg; + cst = _0; + return; + } + Constant *cst_lhs = dyn_cast(bin_op->getOperand(0)); + Constant *cst_rhs = dyn_cast(bin_op->getOperand(1)); + if(cst_lhs && cst_rhs){ + cst = arg; + non_cst = _0; + } + else if(cst_lhs){ + cst = cst_lhs; + non_cst = bin_op->getOperand(1); + } + else if(cst_rhs){ + cst = cst_rhs; + non_cst = bin_op->getOperand(0); + } + else{ + non_cst = arg; + cst = _0; + } +} + +void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) { + non_cst_idx.clear(); + cst_idx.clear(); + for(Value *idx: arg_idx){ + Value *non_cst, *cst; + extract_constant(idx, non_cst, cst); + non_cst_idx.push_back(non_cst); + cst_idx.push_back(cst); + } +} + + +Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector& perm, const std::vector& order, indices_t idx) { + // strides + std::vector strides(order.size()); + strides[order[0]] = builder.getInt32(1); + for(size_t i = 1; i < idx.size(); i++) + strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]])); + // result + Value *result = builder.getInt32(0); + for(size_t i = 0; i < strides.size(); i++) + result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i])); + return result; +} + +shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector& perm): + tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){ + return_vector_ = false; + if(perm_.empty()){ + perm_.resize(shapes.size()); + std::iota(perm_.begin(), perm_.end(), 0); + } +} + +void shared_tile::set_value(indices_t idx, Value *value) { + Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx)); + unsigned addr_space = ptr->getType()->getPointerAddressSpace(); + ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space)); + builder_.CreateStore(value, ptr); +} + +void shared_tile::set_vector_size(unsigned vector_size) { + vector_size_ = vector_size; +} + +void shared_tile::set_return_mode(bool return_vector){ + return_vector_ = return_vector; +} + + +Value* shared_tile::get_value(indices_t idx) { + indices_t non_cst_idx, cst_idx; + extract_constant(idx, non_cst_idx, cst_idx); + Value *&base_ptr = ptr_cache_[non_cst_idx]; + unsigned vector_size = vector_size_; + Type *ty = ty_; + if(ty->isHalfTy() && (vector_size % 2 == 0)){ + ty = IntegerType::get(ty->getContext(), 32); + vector_size = vector_size / 2; + } + if(base_ptr == nullptr){ +// BasicBlock* store = builder_.GetInsertBlock(); +// if(!non_cst_idx.empty()) +// if(isa(non_cst_idx.front())){ +// builder_.SetInsertPoint((Instruction*)non_cst_idx.front()); +// } + base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx)); + if(vector_size_ > 1){ + Type *vec_ty = VectorType::get(ty, vector_size); + Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace()); + base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty); + } +// builder_.SetInsertPoint(store); + } + Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx); + Value *div = offset; + if(vector_size_ > 1) + div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_)); + Value *ptr = builder_.CreateGEP(base_ptr, div); + Value *result = builder_.CreateLoad(ptr); + if(return_vector_ == false && vector_size_ > 1) { + Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_)); + result = builder_.CreateExtractElement(result, rem); + } + return result; +} + + + +} +} diff --git a/lib/codegen/selection/selection.cc b/lib/codegen/selection/selection.cc new file mode 100644 index 000000000..49fa1b714 --- /dev/null +++ b/lib/codegen/selection/selection.cc @@ -0,0 +1,20 @@ +#include +#include "triton/codegen/selection/selection.h" +#include "triton/codegen/selection/generator.h" +#include "triton/ir/module.h" + +namespace triton{ +namespace codegen{ + +using namespace llvm; + +void selection::run(ir::module &src, Module &dst) { + generator gen(&dst, a_axes_, tgt_, layouts_, alignment_, alloc_, num_warps_ ); + for(ir::alloc_const *x: src.allocs()) + gen.visit_alloc_const(x); + for(ir::function *fn: src.get_function_list()) + gen.visit_function(fn); +} + +} +} diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 68fe7fe01..b83ea8442 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -217,7 +217,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c codegen::transform::reassociate reassociate(&align); codegen::transform::coalesce coalesce(&align, &layouts); codegen::transform::cts cts; - codegen::selection selection(&liveness, &allocation, &align, &layouts, target.get(), opt.num_warps); + codegen::selection selection(&liveness, &allocation, &align, &axes, &layouts, target.get(), opt.num_warps); // run passes // ir::print(module, std::cout); peephole.run(module);