diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h deleted file mode 100644 index da6399573..000000000 --- a/include/triton/codegen/selection.h +++ /dev/null @@ -1,346 +0,0 @@ -#ifndef TDL_INCLUDE_CODEGEN_SELECTION_H -#define TDL_INCLUDE_CODEGEN_SELECTION_H - -#include "triton/ir/context.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/type.h" -#include "triton/ir/visitor.h" -#include "triton/codegen/analysis/layout.h" -#include "triton/codegen/transform/cts.h" - - -namespace llvm{ - class Type; - class Value; - class Instruction; - class Constant; - class LLVMContext; - class Module; - class ConstantFolder; - class IRBuilderDefaultInserter; - template - class IRBuilder; - class ArrayType; - class Function; -} - - -// typedefs -namespace triton{ -namespace codegen{ - typedef llvm::IRBuilder Builder; - typedef llvm::LLVMContext LLVMContext; - typedef llvm::Type Type; - typedef llvm::Value Value; - typedef llvm::Module Module; - typedef llvm::Instruction Instruction; - typedef llvm::Constant Constant; - typedef llvm::ArrayType ArrayType; - typedef llvm::Function Function; -} -} - -namespace triton{ -namespace codegen{ - -namespace analysis{ -class liveness; -class tiles; -class align; -class allocation; -class cts; -class axes; -class layout; -} - -namespace transform{ -class coalesce; -} - -class target; - -typedef std::vector indices_t; - -struct distributed_axis { - int contiguous; - std::vector values; - Value* thread_id; -}; - -class tile { -protected: - typedef std::vector shapes_t; - -public: - tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ } - virtual void set_value(indices_t idx, Value *v) = 0; - virtual Value* get_value(indices_t idx) = 0; - Type *get_ty() const { return ty_; } - shapes_t get_shapes() const { return shapes_; } - -protected: - Type *ty_; - shapes_t shapes_; -}; - -class shared_tile: public tile { -private: - void extract_constant(Value *arg, Value *&non_cst, Value *&cst); - void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx); - - -public: - shared_tile(Type* ty, const shapes_t &shapes, const std::vector &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector& perm = {}); - void set_vector_size(unsigned vector_size); - void set_return_mode(bool return_vector); - void set_value(indices_t, Value *); - Value* get_ptr_to(indices_t idx); - Value* get_value(indices_t idx); - Value* get_pointer() { return ptr_; } - Value* get_offset() { return offset_; } - const std::vector& get_perm() { return perm_; } - const std::vector& get_order() { return order_; } - static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector& perm, const std::vector& order, indices_t idx); - -private: - Value *ptr_; - bool return_vector_; - Builder &builder_; - Value *offset_; - std::map ptr_cache_; - unsigned vector_size_; - std::vector order_; - std::vector perm_; -}; - -// Distribtued tile -class distributed_tile: public tile{ - typedef std::vector axes_t; - typedef std::vector ordered_indices_vec_t; - typedef std::map indices_map_t; - typedef std::map values_map_t; - -private: - void init_indices(); - Type *make_vector_ty(Type *ty, size_t vector_size); - -public: - distributed_tile(Type *ty, const shapes_t& shapes, const std::vector& order, const axes_t &axes, Builder &builder, bool vectorize); - void set_value(indices_t idx, Value *v); - Value* get_value(indices_t idx); - const std::vector& get_order() { return order_; } - unsigned get_linear_index(indices_t idx); - indices_t get_ordered_indices(unsigned id); - void for_each(std::function fn); - const distributed_axis &axis(unsigned dim) { return axes_.at(dim); } - -private: - axes_t axes_; - std::vector order_; - indices_map_t indices_; - values_map_t values_; - ordered_indices_vec_t ordered_indices_; - size_t vector_size_; - Builder &builder_; -}; - -class machine_layout_t { -public: - virtual tile* create(ir::value *v) = 0; -}; - -class machine_layout_shared_t: public machine_layout_t { -public: - machine_layout_shared_t(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr, analysis::layout_t* layout, - std::map& vmap, - std::map& tmap); - - tile* create(ir::value *v); - - Module *mod_; - Builder *builder_; - target *tgt_; - analysis::allocation* alloc_; - Value *&sh_mem_ptr_; - analysis::layout_t* layout_; - std::map& vmap_; - std::map& tmap_; - - Value *offset_; - Value *ptr_; - Value *pre_ptr_; - Value *next_ptr_; - -}; - -class machine_layout_distributed_t: public machine_layout_t { -public: - machine_layout_distributed_t(Module *mod, Builder *builder, target *tgt, Type *ty, - analysis::axes *a_axes, std::map& axes, - analysis::layout_t* layout); - - tile* create(ir::value *v); - Module *mod_; - Builder *builder_; - target *tgt_; - Type *ty_; - analysis::axes *a_axes_; - std::map& axes_; - analysis::layout_t* layout_; -}; - - -class machine_layout_hmma_884_t: public machine_layout_distributed_t { -public: - machine_layout_hmma_884_t(Module *mod, Builder *builder, - target *tgt, Type *ty, - analysis::axes *a_axes, std::map& axes, - analysis::layout_hmma_884_t* layout); - Value *offset_a_i_, *offset_a_k_; - Value *offset_b_j_, *offset_b_k_; - unsigned pack_size_0_; - unsigned pack_size_1_; - unsigned num_packs_0_; - unsigned num_packs_1_; -}; - -class machine_layout_scanline_t: public machine_layout_distributed_t { -public: - machine_layout_scanline_t(Module *mod, Builder *builder, - target *tgt, Type *ty, - analysis::axes *a_axes, std::map& axes, - analysis::layout_scanline_t* layout); -}; - -class generator: public ir::visitor, public analysis::layout_visitor { -private: - void for_each(ir::value *x, const std::function& fn); - Value* get_value(ir::value *x, const indices_t& idx); - void set_value(ir::value *x, const indices_t& idx, Value* v); - - void visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK); - void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add); - void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, - Type *c_ty, Function *f_mul_add); - - void finalize_shared_layout(analysis::layout_shared_t*); - void finalize_function(ir::function*); - void finalize_phi_node(ir::phi_node*); - -public: - generator(Module *dst, - analysis::axes *a_axes, - target *tgt, - analysis::layout *layouts, - analysis::align *alignment, - analysis::allocation *alloc, - unsigned num_warps); - - void visit_value(ir::value* v); - - void visit_phi_node(ir::phi_node*); - void visit_binary_operator(ir::binary_operator*); - void visit_getelementptr_inst(ir::getelementptr_inst*); - - void visit_icmp_inst(ir::icmp_inst*); - void visit_fcmp_inst(ir::fcmp_inst*); - void visit_cast_inst(ir::cast_inst*); - - void visit_return_inst(ir::return_inst*); - void visit_cond_branch_inst(ir::cond_branch_inst*); - void visit_uncond_branch_inst(ir::uncond_branch_inst*); - - - void visit_unmasked_load_inst(ir::unmasked_load_inst*); - void visit_masked_load_inst(ir::masked_load_inst*); - void visit_unmasked_store_inst(ir::unmasked_store_inst*); - void visit_masked_store_inst(ir::masked_store_inst*); - - void visit_reshape_inst(ir::reshape_inst*); - void visit_splat_inst(ir::splat_inst*); - void visit_broadcast_inst(ir::broadcast_inst*); - void visit_downcast_inst(ir::downcast_inst*); - - void visit_get_program_id_inst(ir::get_program_id_inst*); - void visit_get_num_program_inst(ir::get_num_program_inst*); - void visit_atomic_cas_inst(ir::atomic_cas_inst*); - void visit_atomic_exch_inst(ir::atomic_exch_inst*); - void visit_atomic_add_inst(ir::atomic_add_inst*); - void visit_dot_inst(ir::dot_inst*); - void visit_trans_inst(ir::trans_inst*); - void visit_sqrt_inst(ir::sqrt_inst*); - void visit_reduce_inst(ir::reduce_inst*); - void visit_select_inst(ir::select_inst*); - - void visit_copy_to_shared_inst(ir::copy_to_shared_inst*); - void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); - void visit_barrier_inst(ir::barrier_inst*); - void visit_make_range_dyn(ir::make_range_dyn*); - void visit_make_range(ir::make_range*); - - void visit_make_range_sta(ir::make_range_sta*); - void visit_undef_value(ir::undef_value*); - void visit_constant_int(ir::constant_int*); - void visit_constant_fp(ir::constant_fp*); - void visit_alloc_const(ir::alloc_const*); - - void visit_function(ir::function*); - void visit_basic_block(ir::basic_block*); - void visit_argument(ir::argument*); - - void visit_layout_hmma_884(analysis::layout_hmma_884_t*); - void visit_layout_scanline(analysis::layout_scanline_t*); - void visit_layout_shared(analysis::layout_shared_t*); - -private: - LLVMContext *ctx_; - std::unique_ptr builder_; - Module *mod_; - - std::map machine_layouts_; - analysis::axes *a_axes_; - std::map axes_; - std::map vmap_; - std::map tmap_; - target *tgt_; - analysis::layout *layouts_; - analysis::align *alignment_; - analysis::allocation *alloc_; - Value *sh_mem_ptr_; - unsigned num_warps_; - - std::set seen_; -}; - - -// Selection pass -class selection{ - typedef std::map vmap_t; - typedef std::map tmap_t; - -public: - selection(analysis::liveness* liveness, analysis::allocation *alloc, - analysis::align *alignment, analysis::axes *axes, - analysis::layout *layouts, target *tgt, unsigned num_warps) - : liveness_(liveness), alloc_(alloc), - alignment_(alignment), a_axes_(axes), layouts_(layouts), - tgt_(tgt), num_warps_(num_warps){ } - - void run(ir::module &src, Module &dst); - -private: - analysis::liveness *liveness_; - analysis::allocation *alloc_; - analysis::axes *a_axes_; - analysis::layout *layouts_; - analysis::align *alignment_; - target *tgt_; - unsigned num_warps_; -}; - -} -} - -#endif diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index 76ec88b90..3e6c0bacb 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -79,12 +79,11 @@ private: void finalize_phi_node(ir::phi_node*); public: - generator(Module *dst, - analysis::axes *a_axes, - target *tgt, + generator(analysis::axes *a_axes, analysis::layout *layouts, analysis::align *alignment, analysis::allocation *alloc, + target *tgt, unsigned num_warps); void visit_value(ir::value* v); @@ -143,6 +142,8 @@ public: void visit_layout_scanline(analysis::layout_scanline_t*); void visit_layout_shared(analysis::layout_shared_t*); + void visit(ir::module &, llvm::Module &); + private: LLVMContext *ctx_; Builder* builder_; diff --git a/include/triton/codegen/selection/selection.h b/include/triton/codegen/selection/selection.h deleted file mode 100644 index a2b88247f..000000000 --- a/include/triton/codegen/selection/selection.h +++ /dev/null @@ -1,70 +0,0 @@ -#pragma once - -#ifndef _TRITON_SELECTION_SELECTION_H_ -#define _TRITON_SELECTION_SELECTION_H_ - -#include - -namespace llvm{ - class Module; - class Value; -} - - -namespace triton{ - -namespace ir{ -class value; -class module; -} - -namespace codegen{ -// typedef -typedef llvm::Module Module; -typedef llvm::Value Value; -// forward -namespace analysis{ -class liveness; -class align; -class allocation; -class axes; -class layout; -} -class target; -class tile; - -} -} - -namespace triton{ -namespace codegen{ - -// Selection pass -class selection{ - typedef std::map vmap_t; - typedef std::map tmap_t; - -public: - selection(analysis::liveness* liveness, analysis::allocation *alloc, - analysis::align *alignment, analysis::axes *axes, - analysis::layout *layouts, target *tgt, unsigned num_warps) - : liveness_(liveness), alloc_(alloc), - alignment_(alignment), a_axes_(axes), layouts_(layouts), - tgt_(tgt), num_warps_(num_warps){ } - - void run(ir::module &src, Module &dst); - -private: - analysis::liveness *liveness_; - analysis::allocation *alloc_; - analysis::axes *a_axes_; - analysis::layout *layouts_; - analysis::align *alignment_; - target *tgt_; - unsigned num_warps_; -}; - -} -} - -#endif diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index e312cfded..fa06544f8 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -9,7 +9,7 @@ #include #include // codegen -#include "triton/codegen/selection.h" +#include "triton/ir/context.h" #include "triton/codegen/target.h" #include "triton/lang/parser.h" #include "triton/runtime/arg.h" diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 831719b73..022c51ed7 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -109,18 +109,18 @@ llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) { } -inline Type *type(ir::type *ty, LLVMContext &ctx) { +inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) { // function if(auto* tt = dynamic_cast(ty)){ - Type *return_ty = type(tt->get_return_ty(), ctx); + Type *return_ty = llvm_type(tt->get_return_ty(), ctx); std::vector param_tys; std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), - [&ctx](ir::type* t){ return type(t, ctx);}); + [&ctx](ir::type* t){ return llvm_type(t, ctx);}); return FunctionType::get(return_ty, param_tys, false); } // pointer if(ty->is_pointer_ty()){ - Type *elt_ty = type(ty->get_pointer_element_ty(), ctx); + Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx); unsigned addr_space = ty->get_pointer_address_space(); return PointerType::get(elt_ty, addr_space); } @@ -173,29 +173,14 @@ inline bool is_trans(ir::value *v) { -generator::generator(Module *dst, - analysis::axes *a_axes, - target *tgt, - analysis::layout *layouts, - analysis::align *alignment, - analysis::allocation *alloc, - unsigned num_warps) - : ctx_(&dst->getContext()), mod_(dst), - builder_(new Builder(dst->getContext())), - a_axes_(a_axes), tgt_(tgt), - layouts_(layouts), alignment_(alignment), alloc_(alloc), - num_warps_(num_warps) { - - if(tgt_->is_gpu()) - if(unsigned alloc_size = alloc_->allocated_size()){ - Type *int_8_ty = Type::getInt8Ty(*ctx_); - ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size); - Type *ptr_ty = PointerType::get(int_8_ty, 3); - GlobalVariable *sh_mem_array = - new GlobalVariable(*dst, array_ty, false, GlobalVariable::ExternalLinkage, - nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); - sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty); - } +generator::generator(analysis::axes *a_axes, + analysis::layout *layouts, + analysis::align *alignment, + analysis::allocation *alloc, + target *tgt, + unsigned num_warps) + : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), + tgt_(tgt), num_warps_(num_warps) { } @@ -226,7 +211,7 @@ void generator::visit_value(ir::value* v) { } void generator::visit_phi_node(ir::phi_node* phi) { - Type *ty = type(phi->get_type()->get_scalar_ty(), *ctx_); + Type *ty = llvm_type(phi->get_type()->get_scalar_ty(), *ctx_); unsigned num_ops = phi->get_num_operands(); for_each(phi, [&](indices_t idx){ set_value(phi, idx, builder_->Insert(PHINode::Create(ty, num_ops))); @@ -248,7 +233,7 @@ void generator::visit_getelementptr_inst(ir::getelementptr_inst* gep) { std::vector idx_vals; std::transform(gep->idx_begin(), gep->idx_end(), std::back_inserter(idx_vals), [&](ir::value* x){ return get_value(x, idx);}); - Type *source_ty = type(gep->get_source_elt_ty()->get_scalar_ty(), *ctx_); + Type *source_ty = llvm_type(gep->get_source_elt_ty()->get_scalar_ty(), *ctx_); Value *ret = builder_->Insert(GetElementPtrInst::CreateInBounds(source_ty, ptr, idx_vals)); set_value(gep, idx, ret); }); @@ -277,7 +262,7 @@ void generator::visit_fcmp_inst(ir::fcmp_inst* fcmp) { void generator::visit_cast_inst(ir::cast_inst* cast) { for_each(cast, [&](indices_t idx){ Value *arg = get_value(cast->get_operand(0), idx); - Type *dst_ty = type(cast->get_type()->get_scalar_ty(), *ctx_); + Type *dst_ty = llvm_type(cast->get_type()->get_scalar_ty(), *ctx_); Value *ret = builder_->Insert(CastInst::Create(llvm_op(cast->get_op()), arg, dst_ty)); set_value(cast, idx, ret); }); @@ -726,7 +711,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) { ir::value *D = dot->get_operand(2); distributed_tile *TD = (distributed_tile*)tmap_.at(D); - Type *c_ty = type(D->get_type()->get_scalar_ty(), *ctx_); + Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), *ctx_); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {c_ty}); auto A_shapes = A->get_type()->get_tile_shapes(); size_t red_axis = 1; @@ -851,22 +836,22 @@ void generator::visit_make_range(ir::make_range* x) { void generator::visit_undef_value(ir::undef_value *ud) { - vmap_[ud] = llvm::UndefValue::get(type(ud->get_type(), *ctx_)); + vmap_[ud] = llvm::UndefValue::get(llvm_type(ud->get_type(), *ctx_)); } void generator::visit_constant_int(ir::constant_int *cst){ - Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_); + Type *ty = llvm_type(cst->get_type()->get_scalar_ty(), *ctx_); vmap_[cst] = ConstantInt::get(ty, cst->get_value()); } void generator::visit_constant_fp(ir::constant_fp *cst){ - Type *ty = type(cst->get_type()->get_scalar_ty(), *ctx_); + Type *ty = llvm_type(cst->get_type()->get_scalar_ty(), *ctx_); vmap_[cst] = ConstantFP::get(ty, cst->get_value()); } void generator::visit_alloc_const(ir::alloc_const *alloc) { unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value(); - Type *element_ty = type(alloc->get_type()->get_pointer_element_ty(), *ctx_); + Type *element_ty = llvm_type(alloc->get_type()->get_pointer_element_ty(), *ctx_); Type *array_ty = llvm::ArrayType::get(element_ty, size); Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, llvm::GlobalVariable::ExternalLinkage, nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4); @@ -876,7 +861,7 @@ void generator::visit_alloc_const(ir::alloc_const *alloc) { void generator::visit_function(ir::function* fn) { LLVMContext &ctx = builder_->getContext(); - FunctionType *fn_ty = (FunctionType*)type(fn->get_fn_type(), *ctx_); + FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), *ctx_); if(!tgt_->is_gpu()){ Type *fn_ret_ty = fn_ty->getReturnType(); std::vector fn_args_ty; @@ -925,11 +910,11 @@ void generator::visit_function(ir::function* fn) { void generator::visit_layout_hmma_884(analysis::layout_hmma_884_t* layout) { - machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); + machine_layouts_[layout] = new machine_layout_hmma_884_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); } void generator::visit_layout_scanline(analysis::layout_scanline_t* layout) { - machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); + machine_layouts_[layout] = new machine_layout_scanline_t(mod_, &*builder_, tgt_, llvm_type(layout->ty->get_scalar_ty(), *ctx_), a_axes_, axes_, layout); } void generator::visit_layout_shared(analysis::layout_shared_t* layout) { @@ -1026,5 +1011,29 @@ void generator::finalize_phi_node(ir::phi_node *phi) { } } +void generator::visit(ir::module &src, llvm::Module &dst) { + mod_ = &dst; + ctx_ = &dst.getContext(); + builder_ = new Builder(*ctx_); + // allocate shared memory + if(tgt_->is_gpu()) + if(unsigned alloc_size = alloc_->allocated_size()){ + Type *int_8_ty = Type::getInt8Ty(*ctx_); + ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size); + Type *ptr_ty = PointerType::get(int_8_ty, 3); + GlobalVariable *sh_mem_array = + new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage, + nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); + sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty); + } + // allocate constant memory + for(ir::alloc_const *x: src.allocs()) + visit_alloc_const(x); + // visit functions + for(ir::function *fn: src.get_function_list()) + visit_function(fn); +} + + } } diff --git a/lib/codegen/selection/machine_layout.cc b/lib/codegen/selection/machine_layout.cc index 1e6f0d5da..ac242c815 100644 --- a/lib/codegen/selection/machine_layout.cc +++ b/lib/codegen/selection/machine_layout.cc @@ -1,6 +1,7 @@ #include #include "triton/codegen/selection/machine_layout.h" #include "triton/codegen/selection/machine_value.h" +#include "triton/codegen/selection/generator.h" #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/axes.h" #include "triton/codegen/target.h" @@ -13,18 +14,18 @@ namespace codegen{ using namespace llvm; -inline Type *type(ir::type *ty, LLVMContext &ctx) { +inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) { // function if(auto* tt = dynamic_cast(ty)){ - Type *return_ty = type(tt->get_return_ty(), ctx); + Type *return_ty = llvm_type(tt->get_return_ty(), ctx); std::vector param_tys; std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), - [&ctx](ir::type* t){ return type(t, ctx);}); + [&ctx](ir::type* t){ return llvm_type(t, ctx);}); return FunctionType::get(return_ty, param_tys, false); } // pointer if(ty->is_pointer_ty()){ - Type *elt_ty = type(ty->get_pointer_element_ty(), ctx); + Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx); unsigned addr_space = ty->get_pointer_address_space(); return PointerType::get(elt_ty, addr_space); } @@ -80,7 +81,7 @@ machine_layout_shared_t::machine_layout_shared_t(Module *mod, Builder *builder, auto shapes = layout_->shapes; shapes[order[0]] += layout_->pad; - Type* ty = type(layout_->ty, builder_->getContext()); + Type* ty = llvm_type(layout_->ty, builder_->getContext()); PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace()); // double-buffered @@ -113,7 +114,7 @@ tile* machine_layout_shared_t::create(ir::value *v) { auto order = layout_->order; auto shapes = layout_->shapes; shapes[order[0]] += layout_->pad; - Type* ty = type(layout_->ty, builder_->getContext()); + Type* ty = llvm_type(layout_->ty, builder_->getContext()); // double-buffered if(layout_->double_buffer) { if(v == layout_->double_buffer->phi) @@ -135,7 +136,7 @@ machine_layout_distributed_t::machine_layout_distributed_t(Module *mod, Builder } tile *machine_layout_distributed_t::create(ir::value *v) { - Type *ty = type(v->get_type()->get_scalar_ty(), builder_->getContext()); + Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext()); const auto &shapes = v->get_type()->get_tile_shapes(); std::vector axes(shapes.size()); for(size_t d = 0; d < shapes.size(); d++){ diff --git a/lib/codegen/selection/selection.cc b/lib/codegen/selection/selection.cc deleted file mode 100644 index 49fa1b714..000000000 --- a/lib/codegen/selection/selection.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include -#include "triton/codegen/selection/selection.h" -#include "triton/codegen/selection/generator.h" -#include "triton/ir/module.h" - -namespace triton{ -namespace codegen{ - -using namespace llvm; - -void selection::run(ir::module &src, Module &dst) { - generator gen(&dst, a_axes_, tgt_, layouts_, alignment_, alloc_, num_warps_ ); - for(ir::alloc_const *x: src.allocs()) - gen.visit_alloc_const(x); - for(ir::function *fn: src.get_function_list()) - gen.visit_function(fn); -} - -} -} diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index b83ea8442..115b739d9 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -13,7 +13,7 @@ #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/reassociate.h" #include "triton/codegen/transform/cts.h" -#include "triton/codegen/selection.h" +#include "triton/codegen/selection/generator.h" #include "triton/runtime/function.h" #include "triton/lang/cpp.h" #include "triton/lang/parser.h" @@ -217,7 +217,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c codegen::transform::reassociate reassociate(&align); codegen::transform::coalesce coalesce(&align, &layouts); codegen::transform::cts cts; - codegen::selection selection(&liveness, &allocation, &align, &axes, &layouts, target.get(), opt.num_warps); + codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps); // run passes // ir::print(module, std::cout); peephole.run(module); @@ -243,7 +243,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c return std::unique_ptr(); barriers.run(module); // ir::print(module, std::cout); - selection.run(module, *llvm); + isel.visit(module, *llvm); // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm))); // done