diff --git a/cmake/FindLLVM.cmake b/cmake/FindLLVM.cmake index dcc0488b5..3de161c64 100644 --- a/cmake/FindLLVM.cmake +++ b/cmake/FindLLVM.cmake @@ -28,7 +28,8 @@ # We also want an user-specified LLVM_ROOT_DIR to take precedence over the # system default locations such as /usr/local/bin. Executing find_program() # multiples times is the approach recommended in the docs. -set(llvm_config_names llvm-config-10 llvm-config-10.0 llvm-config100 +set(llvm_config_names llvm-config-11 llvm-config-11.0 + llvm-config-10 llvm-config-10.0 llvm-config100 llvm-config-9 llvm-config-9.0 llvm-config90 llvm-config-8 llvm-config-8.0 llvm-config80 llvm-config) diff --git a/include/triton/codegen/analysis/axes.h b/include/triton/codegen/analysis/axes.h index dc50223fb..1806ff725 100644 --- a/include/triton/codegen/analysis/axes.h +++ b/include/triton/codegen/analysis/axes.h @@ -27,7 +27,7 @@ private: void update_graph_trans(ir::instruction *i); void update_graph_broadcast(ir::instruction *i); void update_graph_dot(ir::instruction *i); - void update_graph_elementwise(ir::instruction *i); + void update_graph_elementwise(ir::instruction *i, bool connect_ret=true); void update_graph_no_edge(ir::instruction *i); void update_graph(ir::instruction *i); diff --git a/include/triton/codegen/analysis/layout.h b/include/triton/codegen/analysis/layout.h index 42bf34e9a..d11372a2e 100644 --- a/include/triton/codegen/analysis/layout.h +++ b/include/triton/codegen/analysis/layout.h @@ -25,7 +25,7 @@ class axes; class align; class layout_visitor; class data_layout; -class mma884_layout; +class mma_layout; class scanline_layout; class shared_layout; @@ -33,7 +33,7 @@ class shared_layout; class layout_visitor { public: virtual void visit_layout(data_layout *); - virtual void visit_layout_hmma_884(mma884_layout*) = 0; + virtual void visit_layout_mma(mma_layout*) = 0; virtual void visit_layout_scanline(scanline_layout*) = 0; virtual void visit_layout_shared(shared_layout*) = 0; }; @@ -41,7 +41,7 @@ public: class data_layout { protected: enum id_t { - HMMA_884, + MMA, SCANLINE, SHARED }; @@ -68,7 +68,7 @@ public: // visitor virtual void accept(layout_visitor* vst) = 0; // downcast - mma884_layout* to_mma884() { return downcast(HMMA_884); } + mma_layout* to_mma() { return downcast(MMA); } scanline_layout* to_scanline() { return downcast(SCANLINE); } shared_layout* to_shared() { return downcast(SHARED); } // accessors @@ -77,9 +77,10 @@ public: const order_t& get_order() const { return order_; } const values_t& get_values() const { return values_;} int get_axis(size_t k) const { return axes_.at(k); } + std::vector get_axes() const { return axes_; } const int get_order(size_t k) const { return order_.at(k); } // find the position of given axis - size_t find_axis(int to_find) const; + int find_axis(int to_find) const; private: @@ -92,21 +93,29 @@ protected: shape_t shape_; }; -class mma884_layout: public data_layout { +class mma_layout: public data_layout { public: - mma884_layout(size_t num_warps, + mma_layout(size_t num_warps, const std::vector& axes, const std::vector& shapes, const std::vector &values, - analysis::align* align); - void accept(layout_visitor* vst) { vst->visit_layout_hmma_884(this); } + analysis::align* align, target *tgt, + shared_layout* layout_a, + shared_layout* layout_b); + void accept(layout_visitor* vst) { vst->visit_layout_mma(this); } // accessor int fpw(size_t k) { return fpw_.at(k); } int wpt(size_t k) { return wpt_.at(k); } + int spw(size_t k) { return spw_.at(k); } + int spt(size_t k) { return 
spt_.at(k); } + int rep(size_t k) { return rep_.at(k); } private: std::vector fpw_; + std::vector spw_; std::vector wpt_; + std::vector spt_; + std::vector rep_; }; struct scanline_layout: public data_layout { @@ -138,7 +147,7 @@ private: static void extract_double_bufferable(ir::value *v, std::shared_ptr& res); public: - shared_layout(const data_layout *arg, + shared_layout(data_layout *arg, const std::vector& axes, const std::vector& shapes, const std::vector &values_, @@ -149,11 +158,22 @@ public: size_t get_size() { return size_; } ir::type* get_type() { return ty_; } double_buffer_info_t* get_double_buffer() { return double_buffer_.get(); } + size_t get_num_per_phase() { return num_per_phase_; } + ir::value* hmma_dot_a() { return hmma_dot_a_; } + ir::value* hmma_dot_b() { return hmma_dot_b_; } + void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; } + int get_mma_vec() { return mma_vec_;} + data_layout* get_arg_layout() { return arg_layout_; } private: size_t size_; ir::type *ty_; std::shared_ptr double_buffer_; + size_t num_per_phase_; + ir::value* hmma_dot_a_; + ir::value* hmma_dot_b_; + data_layout* arg_layout_; + int mma_vec_; }; diff --git a/include/triton/codegen/analysis/swizzle.h b/include/triton/codegen/analysis/swizzle.h new file mode 100644 index 000000000..6f2833a68 --- /dev/null +++ b/include/triton/codegen/analysis/swizzle.h @@ -0,0 +1,43 @@ +#ifndef TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H +#define TRITON_INCLUDE_IR_CODEGEN_SWIZZLE_H + +#include + +namespace triton{ + +namespace ir{ + class module; +} + +namespace codegen{ +class target; + +namespace analysis{ + +class layouts; +class data_layout; + +class swizzle { +public: + // constructor + swizzle(layouts *l, target* tgt): layouts_(l), tgt_(tgt){ } + // accessors + int get_per_phase(data_layout* layout) { return per_phase_.at(layout); } + int get_max_phase(data_layout* layout) { return max_phase_.at(layout); } + int get_vec (data_layout* layout) { return vec_.at(layout); } + // run + void run(ir::module &mod); +private: + layouts* layouts_; + target* tgt_; + std::map per_phase_; + std::map max_phase_; + std::map vec_; +}; + +} +} +} + + +#endif diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index f13a3305b..1524f53e4 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -5,13 +5,14 @@ #include "triton/ir/visitor.h" #include "triton/codegen/analysis/layout.h" -#include "triton/codegen/selection/machine_value.h" #include // forward namespace llvm{ class Type; class Value; + class BasicBlock; + class Attribute; class Instruction; class Constant; class LLVMContext; @@ -25,6 +26,13 @@ namespace llvm{ } namespace triton{ + +namespace ir{ +class attribute; +class load_inst; +class store_inst; +} + namespace codegen{ // forward @@ -36,6 +44,7 @@ class allocation; class cts; class axes; class layouts; +class swizzle; } // typedef typedef llvm::IRBuilder indices_t; -// forward -class machine_data_layout; -class tile; -class shared_tile; -class distributed_tile; class target; } @@ -62,110 +68,129 @@ class target; namespace triton{ namespace codegen{ +struct distributed_axis { + int contiguous; + std::vector values; + Value* thread_id; +}; class generator: public ir::visitor, public analysis::layout_visitor { private: - void for_each(ir::value *x, const std::function& fn); - Value* get_value(ir::value *x, const indices_t& idx); - void set_value(ir::value *x, const indices_t& idx, Value* v); - - void 
visit_hmma_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK); - void visit_scanline_dot(ir::dot_inst*, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, Type *c_ty, Function *f_mul_add); - void visit_outer_dot(ir::dot_inst*, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, - Type *c_ty, Function *f_mul_add); - + void init_idx(ir::value *x); + Instruction* add_barrier(); + Value* shared_off(const std::vector& shapes, const std::vector& order, indices_t idx); void finalize_shared_layout(analysis::shared_layout*); void finalize_function(ir::function*); void finalize_phi_node(ir::phi_node*); +private: + Type *cvt(ir::type *ty); + llvm::Attribute cvt(ir::attribute attr); + public: generator(analysis::axes *a_axes, analysis::layouts *layouts, analysis::align *alignment, analysis::allocation *alloc, + analysis::swizzle *swizzle, target *tgt, unsigned num_warps); void visit_value(ir::value* v); - void visit_phi_node(ir::phi_node*); void visit_binary_operator(ir::binary_operator*); void visit_getelementptr_inst(ir::getelementptr_inst*); - void visit_icmp_inst(ir::icmp_inst*); void visit_fcmp_inst(ir::fcmp_inst*); void visit_cast_inst(ir::cast_inst*); - void visit_return_inst(ir::return_inst*); void visit_cond_branch_inst(ir::cond_branch_inst*); void visit_uncond_branch_inst(ir::uncond_branch_inst*); - - + void visit_load_inst(ir::load_inst*); void visit_unmasked_load_inst(ir::unmasked_load_inst*); void visit_masked_load_inst(ir::masked_load_inst*); + void visit_store_inst(ir::store_inst*); void visit_unmasked_store_inst(ir::unmasked_store_inst*); void visit_masked_store_inst(ir::masked_store_inst*); - void visit_reshape_inst(ir::reshape_inst*); void visit_splat_inst(ir::splat_inst*); void visit_broadcast_inst(ir::broadcast_inst*); void visit_downcast_inst(ir::downcast_inst*); - void visit_exp_inst(ir::exp_inst*); void visit_log_inst(ir::log_inst*); - void visit_get_program_id_inst(ir::get_program_id_inst*); void visit_get_num_program_inst(ir::get_num_program_inst*); void visit_atomic_cas_inst(ir::atomic_cas_inst*); void visit_atomic_exch_inst(ir::atomic_exch_inst*); void visit_atomic_add_inst(ir::atomic_add_inst*); + void visit_mma884(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK); + void visit_mma16816(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK); + void visit_fmadot(ir::dot_inst*, ir::value *A, ir::value *B, ir::value *D, unsigned NK, Type *c_ty, Function *f_mul_add); void visit_dot_inst(ir::dot_inst*); void visit_trans_inst(ir::trans_inst*); void visit_sqrt_inst(ir::sqrt_inst*); + void visit_reduce1d_inst(ir::reduce_inst*, std::function, Value*); + void visit_reducend_inst(ir::reduce_inst*, std::function, Value*); void visit_reduce_inst(ir::reduce_inst*); void visit_select_inst(ir::select_inst*); - void visit_recoalesce_inst(ir::recoalesce_inst*); + void visit_masked_load_async_inst(ir::masked_load_async_inst*); void visit_copy_to_shared_inst(ir::copy_to_shared_inst*); void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); void visit_barrier_inst(ir::barrier_inst*); + void visit_async_wait_inst(ir::async_wait_inst*); void visit_make_range_dyn(ir::make_range_dyn*); void visit_make_range(ir::make_range*); - void visit_make_range_sta(ir::make_range_sta*); void visit_undef_value(ir::undef_value*); void visit_constant_int(ir::constant_int*); void visit_constant_fp(ir::constant_fp*); void visit_alloc_const(ir::alloc_const*); - void visit_function(ir::function*); 
void visit_basic_block(ir::basic_block*); void visit_argument(ir::argument*); + void visit(ir::module &, llvm::Module &); - void visit_layout_hmma_884(analysis::mma884_layout*); + // layouts + void visit_layout_mma(analysis::mma_layout*); void visit_layout_scanline(analysis::scanline_layout*); void visit_layout_shared(analysis::shared_layout*); - void visit(ir::module &, llvm::Module &); private: LLVMContext *ctx_; Builder* builder_; Module *mod_; - std::map machine_layouts_; analysis::axes *a_axes_; + analysis::swizzle *swizzle_; std::map axes_; - std::map vmap_; - std::map tmap_; target *tgt_; analysis::layouts *layouts_; analysis::align *alignment_; analysis::allocation *alloc_; - Value *sh_mem_ptr_; + Value *shmem_; unsigned num_warps_; - std::set seen_; + + std::map offset_a_m_; + std::map offset_a_k_; + std::map offset_b_k_; + std::map offset_b_n_; + + std::map shared_ptr_; + std::map shared_pre_ptr_; + std::map shared_next_ptr_; + std::map shared_off_; + + + std::map shmems_; + std::map shoffs_; + std::map> idxs_; + std::map> vals_; + std::map bbs_; + std::map> ords_; + }; } diff --git a/include/triton/codegen/selection/machine_layout.h b/include/triton/codegen/selection/machine_layout.h deleted file mode 100644 index 5458f15d3..000000000 --- a/include/triton/codegen/selection/machine_layout.h +++ /dev/null @@ -1,138 +0,0 @@ -#pragma once - -#ifndef _TRITON_SELECTION_MACHINE_LAYOUT_H_ -#define _TRITON_SELECTION_MACHINE_LAYOUT_H_ - -#include -#include "triton/codegen/analysis/layout.h" - -namespace llvm{ - class Type; - class Value; - class Instruction; - class Constant; - class LLVMContext; - class Module; - class ConstantFolder; - class IRBuilderDefaultInserter; - template - class IRBuilder; - class ArrayType; - class Function; -} - -namespace triton{ - -namespace ir{ -class value; -} - -namespace codegen{ - -namespace analysis{ -class liveness; -class tiles; -class align; -class allocation; -class cts; -class axes; -class layouts; -} - -typedef llvm::IRBuilder Builder; -typedef llvm::LLVMContext LLVMContext; -typedef llvm::Type Type; -typedef llvm::Value Value; -typedef llvm::Module Module; -typedef llvm::Instruction Instruction; -typedef llvm::Constant Constant; -typedef llvm::ArrayType ArrayType; -typedef llvm::Function Function; - -class distributed_axis; -class machine_data_layout; -class tile; -class shared_tile; -class distributed_tile; -class target; - -} -} - -namespace triton{ -namespace codegen{ - - -class machine_data_layout { -public: - virtual tile* create(ir::value *v) = 0; -}; - -class machine_shared_layout: public machine_data_layout { -public: - machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, Value *&sh_mem_ptr, - analysis::shared_layout* layout, - std::map& vmap, - std::map& tmap); - - tile* create(ir::value *v); - - Module *mod_; - Builder *builder_; - target *tgt_; - analysis::allocation* alloc_; - Value *&sh_mem_ptr_; - analysis::shared_layout* layout_; - std::map& vmap_; - std::map& tmap_; - - Value *offset_; - Value *ptr_; - Value *pre_ptr_; - Value *next_ptr_; - -}; - -class machine_distributed_layout: public machine_data_layout { -public: - machine_distributed_layout(Module *mod, Builder *builder, target *tgt, - analysis::axes *a_axes, std::map& axes, - analysis::data_layout* layout); - - tile* create(ir::value *v); - Module *mod_; - Builder *builder_; - target *tgt_; - analysis::axes *a_axes_; - std::map& axes_; - analysis::data_layout* layout_; -}; - - -class machine_mma884_layout: public 
machine_distributed_layout { -public: - machine_mma884_layout(Module *mod, Builder *builder, - target *tgt, - analysis::axes *a_axes, std::map& axes, - analysis::mma884_layout* layout); - Value *offset_a_i_, *offset_a_k_; - Value *offset_b_j_, *offset_b_k_; - unsigned pack_size_0_; - unsigned pack_size_1_; - unsigned num_packs_0_; - unsigned num_packs_1_; -}; - -class machine_scanline_layout: public machine_distributed_layout { -public: - machine_scanline_layout(Module *mod, Builder *builder, - target *tgt, - analysis::axes *a_axes, std::map& axes, - analysis::scanline_layout* layout); -}; - -} -} - -#endif diff --git a/include/triton/codegen/selection/machine_value.h b/include/triton/codegen/selection/machine_value.h deleted file mode 100644 index 67c2ed394..000000000 --- a/include/triton/codegen/selection/machine_value.h +++ /dev/null @@ -1,152 +0,0 @@ -#pragma once - -#ifndef _TRITON_SELECTION_MACHINE_VALUE_H_ -#define _TRITON_SELECTION_MACHINE_VALUE_H_ - -#include -#include -#include - -namespace llvm{ - class Type; - class Value; - class Instruction; - class Constant; - class LLVMContext; - class Module; - class ConstantFolder; - class IRBuilderDefaultInserter; - template - class IRBuilder; - class ArrayType; - class Function; -} - -namespace triton{ -namespace codegen{ - typedef llvm::IRBuilder Builder; - typedef llvm::LLVMContext LLVMContext; - typedef llvm::Type Type; - typedef llvm::Value Value; - typedef llvm::Module Module; - typedef llvm::Instruction Instruction; - typedef llvm::Constant Constant; - typedef llvm::ArrayType ArrayType; - typedef llvm::Function Function; -} -} - -namespace triton{ -namespace codegen{ - -namespace analysis{ -class liveness; -class tiles; -class align; -class allocation; -class cts; -class axes; -class layouts; -} - -class distributed_axis; -class machine_data_layout; -class tile; -class shared_tile; -class distributed_tile; -class target; -typedef std::vector indices_t; - -} -} - -namespace triton{ -namespace codegen{ - -struct distributed_axis { - int contiguous; - std::vector values; - Value* thread_id; -}; - -class tile { -protected: - typedef std::vector shapes_t; - -public: - tile(Type *ty, const shapes_t &shapes): ty_(ty), shapes_(shapes){ } - virtual void set_value(indices_t idx, Value *v) = 0; - virtual Value* get_value(indices_t idx) = 0; - Type *get_ty() const { return ty_; } - shapes_t get_shapes() const { return shapes_; } - -protected: - Type *ty_; - shapes_t shapes_; -}; - -class shared_tile: public tile { -private: - void extract_constant(Value *arg, Value *&non_cst, Value *&cst); - void extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx); - - -public: - shared_tile(Type* ty, const shapes_t &shapes, const std::vector &order, Value* ptr, Builder &builder, Value* offset = nullptr, const std::vector& perm = {}); - void set_vector_size(unsigned vector_size); - void set_return_mode(bool return_vector); - void set_value(indices_t, Value *); - Value* get_ptr_to(indices_t idx); - Value* get_value(indices_t idx); - Value* get_pointer() { return ptr_; } - Value* get_offset() { return offset_; } - const std::vector& get_perm() { return perm_; } - const std::vector& get_order() { return order_; } - static Value* shared_offset(Builder& builder, const shapes_t& shapes, const std::vector& perm, const std::vector& order, indices_t idx); - -private: - Value *ptr_; - bool return_vector_; - Builder &builder_; - Value *offset_; - std::map ptr_cache_; - unsigned vector_size_; - std::vector order_; - std::vector perm_; 
-}; - -// Distribtued tile -class distributed_tile: public tile{ - typedef std::vector axes_t; - typedef std::vector ordered_indices_vec_t; - typedef std::map indices_map_t; - typedef std::map values_map_t; - -private: - void init_indices(); - -public: - distributed_tile(Type *ty, const shapes_t& shapes, const std::vector& order, const axes_t &axes, Builder &builder); - void set_value(indices_t idx, Value *v); - Value* get_value(indices_t idx); - const std::vector& get_order() { return order_; } - unsigned get_linear_index(indices_t idx); - indices_t get_ordered_indices(unsigned id); - void for_each(std::function fn, int start = 0, int end = -1); - void for_each(std::function fn, std::vector start, std::vector size); - - const distributed_axis &axis(unsigned dim) { return axes_.at(dim); } -private: - axes_t axes_; - std::vector order_; - indices_map_t indices_; - values_map_t values_; - ordered_indices_vec_t ordered_indices_; - Builder &builder_; -}; - -} -} - -#endif diff --git a/include/triton/codegen/target.h b/include/triton/codegen/target.h index dc379bd0c..96e4d5c31 100644 --- a/include/triton/codegen/target.h +++ b/include/triton/codegen/target.h @@ -35,6 +35,8 @@ namespace codegen{ namespace triton{ namespace codegen{ +class nvidia_cu_target; + class target { public: target(bool is_gpu): is_gpu_(is_gpu){} @@ -47,6 +49,7 @@ public: virtual Value* get_block_id(Module *module, Builder& builder, unsigned ax) = 0; virtual Value* get_num_blocks(Module *module, Builder& builder, unsigned ax) = 0; virtual unsigned guaranteed_alignment() = 0; + nvidia_cu_target* as_nvidia(); bool is_gpu() const; private: @@ -68,7 +71,7 @@ public: class nvidia_cu_target: public target { public: - nvidia_cu_target(): target(true){} + nvidia_cu_target(int sm): target(true), sm_(sm){} void set_kernel(Builder& builder, LLVMContext &ctx, Module *module, Function* fn); Instruction* add_barrier(Module *module, Builder& builder); Instruction* add_memfence(Module *module, Builder& builder); @@ -76,7 +79,11 @@ public: Value* get_local_id(Module *module, Builder& builder, unsigned ax); Value* get_block_id(Module *module, Builder& builder, unsigned ax); Value* get_num_blocks(Module *module, Builder& builder, unsigned ax); + int sm() { return sm_; } unsigned guaranteed_alignment() { return 16; } + +private: + int sm_; }; class cpu_target: public target { diff --git a/include/triton/codegen/transform/cts.h b/include/triton/codegen/transform/cts.h index b4289305b..dcc5f36c2 100644 --- a/include/triton/codegen/transform/cts.h +++ b/include/triton/codegen/transform/cts.h @@ -11,14 +11,22 @@ namespace ir { class value; class phi_node; class instruction; + class builder; } namespace codegen{ namespace transform{ class cts { +private: + void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared); + public: + cts(bool use_async = false): use_async_(use_async) {} void run(ir::module &mod); + +private: + bool use_async_; }; } diff --git a/include/triton/codegen/transform/membar.h b/include/triton/codegen/transform/membar.h index 015f44f3d..1c2878d64 100644 --- a/include/triton/codegen/transform/membar.h +++ b/include/triton/codegen/transform/membar.h @@ -1,6 +1,8 @@ #ifndef TDL_INCLUDE_CODEGEN_BARRIERS_H #define TDL_INCLUDE_CODEGEN_BARRIERS_H +#include + namespace triton { namespace ir { @@ -31,14 +33,14 @@ private: private: interval_vec_t join(const std::vector& intervals); - void insert_barrier(ir::instruction *instr, ir::builder &builder); + void insert_barrier(ir::instruction *instr, std::pair 
type, ir::builder &builder); bool intersect(const interval_vec_t &X, interval_t x); bool intersect(const interval_vec_t &X, const interval_vec_t &Y); void add_reference(ir::value *v, interval_vec_t &res); void get_read_intervals(ir::instruction *i, interval_vec_t &res); void get_written_intervals(ir::instruction *i, interval_vec_t &res); std::pair transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, - std::set &insert_loc, std::set &safe_war); + std::map > &insert_loc, std::set &safe_war, std::vector &to_sync); public: membar(analysis::liveness *liveness, analysis::layouts *layouts, analysis::allocation *alloc): diff --git a/include/triton/codegen/transform/peephole.h b/include/triton/codegen/transform/peephole.h index 2ea27937c..19dba70a9 100644 --- a/include/triton/codegen/transform/peephole.h +++ b/include/triton/codegen/transform/peephole.h @@ -1,6 +1,7 @@ #ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H #define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H +#include "triton/codegen/target.h" namespace triton { @@ -27,12 +28,16 @@ private: bool rewrite_mult(ir::instruction *value, ir::builder& builder); bool rewrite_unit_red(ir::instruction *value, ir::builder& builder); bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder); + bool rewrite_load_to_shared(ir::instruction *value, ir::builder& builder); private: public: - peephole() {} + peephole(target* tgt): tgt_(tgt) {} void run(ir::module &mod); + +private: + target* tgt_; }; diff --git a/include/triton/codegen/transform/reorder.h b/include/triton/codegen/transform/reorder.h new file mode 100644 index 000000000..3b48a330f --- /dev/null +++ b/include/triton/codegen/transform/reorder.h @@ -0,0 +1,26 @@ +#ifndef TRITON_INCLUDE_IR_CODEGEN_REORDER_H +#define TRITON_INCLUDE_IR_CODEGEN_REORDER_H + +namespace triton { + +// forward declaration +namespace ir { +class module; +} + +namespace codegen{ + +namespace transform{ + +class reorder { +public: + void run(ir::module& module); +}; + +} + +} + +} + +#endif diff --git a/include/triton/driver/device.h b/include/triton/driver/device.h index 8110c0bc7..c39c768ca 100755 --- a/include/triton/driver/device.h +++ b/include/triton/driver/device.h @@ -39,43 +39,23 @@ public: // CUDA device class cu_device: public device { -public: - //Supported architectures - enum class Architecture{ - //NVidia - SM_2_0, - SM_2_1, - SM_3_0, - SM_3_5, - SM_3_7, - SM_5_0, - SM_5_2, - SM_6_0, - SM_6_1, - SM_7_0, - UNKNOWN - }; - private: //Metaprogramming elper to get cuda info from attribute template int cuGetInfo() const; - inline Architecture nv_arch(std::pair sm) const; inline nvmlDevice_t nvml_device() const; public: cu_device(CUdevice cu = CUdevice(), bool take_ownership = true): device(cu, take_ownership){} - // Accessors - Architecture architecture() const; // Informations std::string infos() const; size_t address_bits() const; std::vector max_block_dim() const; size_t warp_size() const; // Compute Capability - void interpret_as(std::pair cc); - std::pair compute_capability() const; + void interpret_as(int cc); + int compute_capability() const; // Identifier std::string name() const; std::string pci_bus_id() const; @@ -91,7 +71,7 @@ public: std::unique_ptr make_target() const; private: - std::shared_ptr> interpreted_as_; + std::shared_ptr interpreted_as_; }; } diff --git a/include/triton/driver/error.h b/include/triton/driver/error.h index affbae94a..0e634abe6 100755 --- a/include/triton/driver/error.h +++ b/include/triton/driver/error.h @@ -19,18 
+19,18 @@ namespace triton namespace nvrtc { -#define ISAAC_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } } +#define TRITON_CREATE_NVRTC_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "NVRTC: Error- " msg; } } - ISAAC_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory"); - ISAAC_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure"); - ISAAC_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input"); - ISAAC_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program"); - ISAAC_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option"); - ISAAC_CREATE_NVRTC_EXCEPTION(compilation ,"compilation"); - ISAAC_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure"); - ISAAC_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error"); + TRITON_CREATE_NVRTC_EXCEPTION(out_of_memory ,"out of memory"); + TRITON_CREATE_NVRTC_EXCEPTION(program_creation_failure ,"program creation failure"); + TRITON_CREATE_NVRTC_EXCEPTION(invalid_input ,"invalid input"); + TRITON_CREATE_NVRTC_EXCEPTION(invalid_program ,"invalid program"); + TRITON_CREATE_NVRTC_EXCEPTION(invalid_option ,"invalid option"); + TRITON_CREATE_NVRTC_EXCEPTION(compilation ,"compilation"); + TRITON_CREATE_NVRTC_EXCEPTION(builtin_operation_failure ,"builtin operation failure"); + TRITON_CREATE_NVRTC_EXCEPTION(unknown_error ,"unknown error"); -#undef ISAAC_CREATE_NVRTC_EXCEPTION +#undef TRITON_CREATE_NVRTC_EXCEPTION } @@ -38,107 +38,107 @@ namespace triton { class base: public std::exception{}; -#define ISAAC_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } } +#define TRITON_CREATE_CUDA_EXCEPTION(name, msg) class name: public base { public:const char * what() const throw(){ return "CUDA: Error- " msg; } } - ISAAC_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value"); - ISAAC_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory"); - ISAAC_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized"); - ISAAC_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized"); - ISAAC_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled"); - ISAAC_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized"); - ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started"); - ISAAC_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped"); - ISAAC_CREATE_CUDA_EXCEPTION(no_device ,"no device"); - ISAAC_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device"); - ISAAC_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image"); - ISAAC_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context"); - ISAAC_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current"); - ISAAC_CREATE_CUDA_EXCEPTION(map_failed ,"map failed"); - ISAAC_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed"); - ISAAC_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped"); - ISAAC_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped"); - ISAAC_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu"); - ISAAC_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired"); - ISAAC_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped"); - ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array"); - ISAAC_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer"); - ISAAC_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable"); - 
ISAAC_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit"); - ISAAC_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use"); - ISAAC_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported"); - ISAAC_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx"); - ISAAC_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context"); - ISAAC_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source"); - ISAAC_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found"); - ISAAC_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found"); - ISAAC_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed"); - ISAAC_CREATE_CUDA_EXCEPTION(operating_system ,"operating system"); - ISAAC_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle"); - ISAAC_CREATE_CUDA_EXCEPTION(not_found ,"not found"); - ISAAC_CREATE_CUDA_EXCEPTION(not_ready ,"not ready"); - ISAAC_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address"); - ISAAC_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources"); - ISAAC_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout"); - ISAAC_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing"); - ISAAC_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled"); - ISAAC_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled"); - ISAAC_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active"); - ISAAC_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed"); - ISAAC_CREATE_CUDA_EXCEPTION(assert_error ,"assert"); - ISAAC_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers"); - ISAAC_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered"); - ISAAC_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered"); - ISAAC_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error"); - ISAAC_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction"); - ISAAC_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address"); - ISAAC_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space"); - ISAAC_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc"); - ISAAC_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed"); - ISAAC_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted"); - ISAAC_CREATE_CUDA_EXCEPTION(not_supported ,"not supported"); - ISAAC_CREATE_CUDA_EXCEPTION(unknown ,"unknown"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_value ,"invalid value"); + TRITON_CREATE_CUDA_EXCEPTION(out_of_memory ,"out of memory"); + TRITON_CREATE_CUDA_EXCEPTION(not_initialized ,"not initialized"); + TRITON_CREATE_CUDA_EXCEPTION(deinitialized ,"deinitialized"); + TRITON_CREATE_CUDA_EXCEPTION(profiler_disabled ,"profiler disabled"); + TRITON_CREATE_CUDA_EXCEPTION(profiler_not_initialized ,"profiler not initialized"); + TRITON_CREATE_CUDA_EXCEPTION(profiler_already_started ,"profiler already started"); + TRITON_CREATE_CUDA_EXCEPTION(profiler_already_stopped ,"profiler already stopped"); + TRITON_CREATE_CUDA_EXCEPTION(no_device ,"no device"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_device ,"invalid device"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_image ,"invalid image"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_context ,"invalid context"); + TRITON_CREATE_CUDA_EXCEPTION(context_already_current ,"context already current"); + TRITON_CREATE_CUDA_EXCEPTION(map_failed ,"map failed"); + TRITON_CREATE_CUDA_EXCEPTION(unmap_failed ,"unmap failed"); + 
TRITON_CREATE_CUDA_EXCEPTION(array_is_mapped ,"array is mapped"); + TRITON_CREATE_CUDA_EXCEPTION(already_mapped ,"already mapped"); + TRITON_CREATE_CUDA_EXCEPTION(no_binary_for_gpu ,"no binary for gpu"); + TRITON_CREATE_CUDA_EXCEPTION(already_acquired ,"already acquired"); + TRITON_CREATE_CUDA_EXCEPTION(not_mapped ,"not mapped"); + TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_array ,"not mapped as array"); + TRITON_CREATE_CUDA_EXCEPTION(not_mapped_as_pointer ,"not mapped as pointer"); + TRITON_CREATE_CUDA_EXCEPTION(ecc_uncorrectable ,"ecc uncorrectable"); + TRITON_CREATE_CUDA_EXCEPTION(unsupported_limit ,"unsupported limit"); + TRITON_CREATE_CUDA_EXCEPTION(context_already_in_use ,"context already in use"); + TRITON_CREATE_CUDA_EXCEPTION(peer_access_unsupported ,"peer access unsupported"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_ptx ,"invalid ptx"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_graphics_context ,"invalid graphics context"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_source ,"invalid source"); + TRITON_CREATE_CUDA_EXCEPTION(file_not_found ,"file not found"); + TRITON_CREATE_CUDA_EXCEPTION(shared_object_symbol_not_found ,"shared object symbol not found"); + TRITON_CREATE_CUDA_EXCEPTION(shared_object_init_failed ,"shared object init failed"); + TRITON_CREATE_CUDA_EXCEPTION(operating_system ,"operating system"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_handle ,"invalid handle"); + TRITON_CREATE_CUDA_EXCEPTION(not_found ,"not found"); + TRITON_CREATE_CUDA_EXCEPTION(not_ready ,"not ready"); + TRITON_CREATE_CUDA_EXCEPTION(illegal_address ,"illegal address"); + TRITON_CREATE_CUDA_EXCEPTION(launch_out_of_resources ,"launch out of resources"); + TRITON_CREATE_CUDA_EXCEPTION(launch_timeout ,"launch timeout"); + TRITON_CREATE_CUDA_EXCEPTION(launch_incompatible_texturing ,"launch incompatible texturing"); + TRITON_CREATE_CUDA_EXCEPTION(peer_access_already_enabled ,"peer access already enabled"); + TRITON_CREATE_CUDA_EXCEPTION(peer_access_not_enabled ,"peer access not enabled"); + TRITON_CREATE_CUDA_EXCEPTION(primary_context_active ,"primary context active"); + TRITON_CREATE_CUDA_EXCEPTION(context_is_destroyed ,"context is destroyed"); + TRITON_CREATE_CUDA_EXCEPTION(assert_error ,"assert"); + TRITON_CREATE_CUDA_EXCEPTION(too_many_peers ,"too many peers"); + TRITON_CREATE_CUDA_EXCEPTION(host_memory_already_registered ,"host memory already registered"); + TRITON_CREATE_CUDA_EXCEPTION(host_memory_not_registered ,"hot memory not registered"); + TRITON_CREATE_CUDA_EXCEPTION(hardware_stack_error ,"hardware stack error"); + TRITON_CREATE_CUDA_EXCEPTION(illegal_instruction ,"illegal instruction"); + TRITON_CREATE_CUDA_EXCEPTION(misaligned_address ,"misaligned address"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_address_space ,"invalid address space"); + TRITON_CREATE_CUDA_EXCEPTION(invalid_pc ,"invalid pc"); + TRITON_CREATE_CUDA_EXCEPTION(launch_failed ,"launch failed"); + TRITON_CREATE_CUDA_EXCEPTION(not_permitted ,"not permitted"); + TRITON_CREATE_CUDA_EXCEPTION(not_supported ,"not supported"); + TRITON_CREATE_CUDA_EXCEPTION(unknown ,"unknown"); -#undef ISAAC_CREATE_CUDA_EXCEPTION +#undef TRITON_CREATE_CUDA_EXCEPTION } namespace cublas { class base: public std::exception{}; -#define ISAAC_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } } +#define TRITON_CREATE_CUBLAS_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "CUBLAS: Error- " msg; } } - 
ISAAC_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized"); - ISAAC_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed"); - ISAAC_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value"); - ISAAC_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch"); - ISAAC_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error"); - ISAAC_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed"); - ISAAC_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error"); - ISAAC_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported"); - ISAAC_CREATE_CUBLAS_EXCEPTION(license_error ,"license error"); - ISAAC_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown"); + TRITON_CREATE_CUBLAS_EXCEPTION(not_initialized ,"not initialized"); + TRITON_CREATE_CUBLAS_EXCEPTION(alloc_failed ,"alloc failed"); + TRITON_CREATE_CUBLAS_EXCEPTION(invalid_value ,"invalid value"); + TRITON_CREATE_CUBLAS_EXCEPTION(arch_mismatch ,"arch mismatch"); + TRITON_CREATE_CUBLAS_EXCEPTION(mapping_error ,"mapping error"); + TRITON_CREATE_CUBLAS_EXCEPTION(execution_failed ,"execution failed"); + TRITON_CREATE_CUBLAS_EXCEPTION(internal_error ,"internal error"); + TRITON_CREATE_CUBLAS_EXCEPTION(not_supported ,"not supported"); + TRITON_CREATE_CUBLAS_EXCEPTION(license_error ,"license error"); + TRITON_CREATE_CUBLAS_EXCEPTION(unknown ,"unknown"); -#undef ISAAC_CREATE_CUBLAS_EXCEPTION +#undef TRITON_CREATE_CUBLAS_EXCEPTION } namespace cudnn { -#define ISAAC_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } } +#define TRITON_CREATE_CUDNN_EXCEPTION(name, msg) class name: public std::exception { public: const char * what() const throw(){ return "CUDNN: Error- " msg; } } - ISAAC_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized"); - ISAAC_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed"); - ISAAC_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param"); - ISAAC_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error"); - ISAAC_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value"); - ISAAC_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch"); - ISAAC_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error"); - ISAAC_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed"); - ISAAC_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported"); - ISAAC_CREATE_CUDNN_EXCEPTION(license_error ,"license error"); - ISAAC_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing"); - ISAAC_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress"); - ISAAC_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow"); + TRITON_CREATE_CUDNN_EXCEPTION(not_initialized ,"not initialized"); + TRITON_CREATE_CUDNN_EXCEPTION(alloc_failed ,"allocation failed"); + TRITON_CREATE_CUDNN_EXCEPTION(bad_param ,"bad param"); + TRITON_CREATE_CUDNN_EXCEPTION(internal_error ,"internal error"); + TRITON_CREATE_CUDNN_EXCEPTION(invalid_value ,"invalid value"); + TRITON_CREATE_CUDNN_EXCEPTION(arch_mismatch ,"arch mismatch"); + TRITON_CREATE_CUDNN_EXCEPTION(mapping_error ,"mapping error"); + TRITON_CREATE_CUDNN_EXCEPTION(execution_failed ,"execution failed"); + TRITON_CREATE_CUDNN_EXCEPTION(not_supported ,"not supported"); + TRITON_CREATE_CUDNN_EXCEPTION(license_error ,"license error"); + TRITON_CREATE_CUDNN_EXCEPTION(runtime_prerequisite_missing ,"prerequisite missing"); + TRITON_CREATE_CUDNN_EXCEPTION(runtime_in_progress ,"runtime in progress"); + TRITON_CREATE_CUDNN_EXCEPTION(runtime_fp_overflow ,"runtime fp overflow"); } } diff --git a/include/triton/driver/module.h 
b/include/triton/driver/module.h index 760c54575..0cdfbb84c 100755 --- a/include/triton/driver/module.h +++ b/include/triton/driver/module.h @@ -44,6 +44,13 @@ public: const std::string &features, file_type_t file_type); virtual std::unique_ptr symbol(const char * name) const = 0; + std::string llir() const { return llir_; } + int spilled() const { return spilled_; } + +private: + std::string llir_; +protected: + int spilled_; }; // CPU @@ -59,12 +66,12 @@ class cu_module: public module { public: cu_module(driver::device* device, std::unique_ptr module); - cu_module(const std::string& source); + cu_module(driver::device* device, const std::string& source); std::unique_ptr symbol(const char * name) const; - const std::string& source() const { return source_; } + const std::string& ptx() const { return ptx_; } private: - std::string source_; + std::string ptx_; }; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 7a7ef80bc..0a5c2c8d6 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -146,8 +146,10 @@ public: value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = ""); // Intrinsics value *create_copy_to_shared(value *arg, const std::string &name = ""); + value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = ""); value *create_copy_from_shared(value *arg, const std::string &name = ""); value *create_barrier(const std::string &name = ""); + value *create_async_wait(); private: context &ctx_; diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 86e4b925f..e34a3a1ae 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -7,7 +7,7 @@ namespace triton{ namespace ir{ -enum binary_op_t { +enum binary_op_t: unsigned int{ Add, FAdd, Sub, @@ -28,7 +28,7 @@ enum binary_op_t { Xor }; -enum cast_op_t { +enum cast_op_t: unsigned int { Trunc, ZExt, SExt, @@ -44,7 +44,7 @@ enum cast_op_t { AddrSpaceCast }; -enum cmp_pred_t { +enum cmp_pred_t: unsigned int { FIRST_FCMP_PREDICATE, FCMP_FALSE, FCMP_OEQ, @@ -113,6 +113,7 @@ enum value_id_t: unsigned { // io INST_UNMASKED_LOAD, INST_MASKED_LOAD, + INST_MASKED_LOAD_ASYNC, INST_UNMASKED_STORE, INST_MASKED_STORE, // retile @@ -139,6 +140,7 @@ enum value_id_t: unsigned { INST_COPY_FROM_SHARED, INST_RECOALESCE, INST_BARRIER, + INST_ASYNC_WAIT, INST_MAKE_RANGE_DYN, INST_MAKE_RANGE_STA, INST_MAKE_RANGE diff --git a/include/triton/ir/function.h b/include/triton/ir/function.h index 189fdd88b..2df493d01 100644 --- a/include/triton/ir/function.h +++ b/include/triton/ir/function.h @@ -72,6 +72,7 @@ public: case noalias: return ".noalias"; case aligned: return ".aligned(" + std::to_string(value_) + ")"; case multiple_of: return ".readonly"; + case retune: return ".retunr"; default: break; } assert(false); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 23219b6c9..7eebe8755 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -64,9 +64,10 @@ public: // cloning ir::instruction* clone() { ir::instruction* res = clone_impl(); - for(auto it = op_begin(); it != op_end(); it++) - (*it)->add_use(res); +// for(auto it = op_begin(); it != op_end(); it++) +// (*it)->add_use(res); res->parent_ = nullptr; + res->users_.clear(); return res; } // instruction id @@ -431,6 +432,25 @@ public: _TRITON_DEFINE_ACCEPT(masked_load_inst) }; +// masked load async +class masked_load_async_inst: public load_inst { +private: + std::string repr_impl() 
const { return "masked_load_async_async"; } + masked_load_async_inst(value *ptr, value *mask, value *false_value, + const std::string &name, instruction *next); + +public: + // accessors + value *get_mask_operand() { return get_operand(1); } + value *get_false_value_operand() { return get_operand(2); } + // factory method + static masked_load_async_inst* create(value *ptr, value *mask, value *false_value, + const std::string &name = "", + instruction *next = nullptr); + _TRITON_DEFINE_CLONE(masked_load_async_inst) + _TRITON_DEFINE_ACCEPT(masked_load_async_inst) +}; + class atomic_add_inst: public io_inst { private: atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr); @@ -757,6 +777,7 @@ public: _TRITON_DEFINE_ACCEPT(copy_from_shared_inst) }; + class recoalesce_inst: public unary_inst{ private: using unary_inst::unary_inst; @@ -780,6 +801,18 @@ public: instruction *next = nullptr); }; +class async_wait_inst: public instruction{ +private: + async_wait_inst(context &ctx, const std::string &name, instruction *next); + std::string repr_impl() const { return "async_wait"; } + _TRITON_DEFINE_CLONE(async_wait_inst) + _TRITON_DEFINE_ACCEPT(async_wait_inst) + +public: + static async_wait_inst* create(context &ctx, const std::string &name = "", + instruction *next = nullptr); +}; + // On NVIDIA, implementation is such that // constant_range = nv_dynamic_program_idx + nv_static_program_idx // so as to enable re-association on nv_static_program_idx which is constant diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index e612ee889..a54e5edfc 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -65,7 +65,9 @@ class select_inst; class recoalesce_inst; class copy_to_shared_inst; class copy_from_shared_inst; +class masked_load_async_inst; class barrier_inst; +class async_wait_inst; class make_range_dyn; class make_range; @@ -139,7 +141,9 @@ public: virtual void visit_recoalesce_inst(recoalesce_inst*) = 0; virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0; virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0; + virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0; virtual void visit_barrier_inst(barrier_inst*) = 0; + virtual void visit_async_wait_inst(async_wait_inst*) = 0; virtual void visit_make_range_dyn(make_range_dyn*) = 0; virtual void visit_make_range(make_range*) = 0; diff --git a/include/triton/runtime/error.h b/include/triton/runtime/error.h new file mode 100644 index 000000000..d03c96e35 --- /dev/null +++ b/include/triton/runtime/error.h @@ -0,0 +1,34 @@ +#pragma once + +#ifndef _TRITON_RUNTIME_ERROR_H_ +#define _TRITON_RUNTIME_ERROR_H_ + +#include +#include + +namespace triton { +namespace runtime{ +namespace exception { + +class base: public std::exception {}; +#define TRITON_CREATE_RUNTIME_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "Triton: Error - Runtime: " msg; } }; + +TRITON_CREATE_RUNTIME_EXCEPTION(out_of_shared_memory, "out of shared memory") +TRITON_CREATE_RUNTIME_EXCEPTION(out_of_registers, "out of registers") + +class no_valid_configuration: public exception::base { +public: + no_valid_configuration(const std::string& err): err_(err) { } + const char * what() const throw(){ return err_.c_str(); } +private: + std::string err_; +}; + + +#undef TRITON_CREATE_RUNTIME_EXCEPTION + +} +} +} + +#endif diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index 
778b47d7f..b5b96f1bb 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -13,6 +14,7 @@ #include "triton/ir/context.h" #include "triton/codegen/target.h" #include "triton/runtime/arg.h" +#include "triton/runtime/error.h" namespace llvm { class Module; @@ -56,33 +58,43 @@ template inline T convert(const std::string& name); template<> inline long convert(const std::string& name) { return std::stol(name); } template<> inline int convert(const std::string& name) { return std::stoi(name); } +template +void add_arg(std::stringstream& ss, T arg) { + ss.write((char*)&arg, sizeof(T)); +} + +enum asm_mode_t { + ASM_LLIR, + ASM_NV_PTX, + ASM_NV_SASS +}; + +struct options_space_t { + typedef std::pair> define_t; + std::vector defines; + std::vector num_warps; + std::vector recompile_key; +}; + +struct options_t { + template + T D(const std::string& name) const { + return convert(defines.at(name)); + } + bool operator<(const options_t& other) const { + return std::make_pair(defines, num_warps) < + std::make_pair(other.defines, other.num_warps); + } + std::string to_str() const; + + std::map defines; + size_t num_warps; +}; + class function { public: - struct options_space_t { - typedef std::pair> define_t; - std::vector defines; - std::vector num_warps; - std::vector recompile_key; - }; - - struct options_t { - template - T D(const std::string& name) const { - return convert(defines.at(name)); - } - bool operator<(const options_t& other) const { - return std::make_pair(defines, num_warps) < - std::make_pair(other.defines, other.num_warps); - } - std::string to_str() const; - - std::map defines; - size_t num_warps; - }; - typedef std::function grid_fn_ty; - private: class caller { public: @@ -135,7 +147,7 @@ public: void operator()(void** args, size_t args_size, const grid_t& grid, driver::stream* stream, driver::device* device); void operator()(void** args, size_t args_size, const grid_fn_ty& grid, driver::stream *stream, driver::device* device); void set_cst(const char* name, void* data, size_t n_bytes); - std::string ptx(driver::device *device, const options_t& opt); + std::string get_asm(asm_mode_t mode, driver::device *device, const options_t& opt); private: std::map> cst_; diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp index fdf17eaa2..58f8acecd 100644 --- a/include/triton/tools/bench.hpp +++ b/include/triton/tools/bench.hpp @@ -33,25 +33,20 @@ private: inline double bench(std::function const & op, driver::stream * stream, bool normalize = false) { // const driver::device * device = stream->context()->device(); + size_t warmup = 10; + size_t repeat = 50; timer tmr; std::vector times; double total_time = 0; - op(); + for(size_t i = 0; i < warmup; i++) + op(); stream->synchronize(); tmr.start(); - for(size_t i = 0; i < 10; i++){ -// while(total_time*1e-9 < 1e-2){ -// float norm = 1; - // normalize clock if possible to reduce noise in auto-tuning -// if(normalize) -// if(auto cu_device = dynamic_cast(stream->context()->device())) -// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); + for(size_t i = 0; i < repeat; i++){ op(); -// times.push_back(norm*tmr.get().count()); -// total_time+=times.back(); } stream->synchronize(); - return (float)tmr.get().count() / 10; + return (float)tmr.get().count() / repeat; // return *std::min_element(times.begin(), times.end()); } diff --git a/lib/codegen/analysis/axes.cc 
b/lib/codegen/analysis/axes.cc index a83471bdb..1ec198787 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -79,7 +79,7 @@ void axes::update_graph_dot(ir::instruction *i) { graph_.add_edge({dot, d}, {D, d}); } -void axes::update_graph_elementwise(ir::instruction *i) { +void axes::update_graph_elementwise(ir::instruction *i, bool connect_ret) { if(i->get_num_operands() == 0) return; ir::value *op = i->get_operand(0); @@ -89,7 +89,7 @@ void axes::update_graph_elementwise(ir::instruction *i) { for(unsigned d = 0; d < rank; d++) for(ir::value* opx: i->ops()) for(ir::value* opy: i->ops()){ - if(!i->get_type()->is_void_ty()) + if(connect_ret && !i->get_type()->is_void_ty()) graph_.add_edge({i, d}, {opx, d}); graph_.add_edge({opx, d}, {opy, d}); } @@ -111,7 +111,8 @@ void axes::update_graph(ir::instruction *i) { case ir::INST_TRANS: return update_graph_trans(i); case ir::INST_BROADCAST: return update_graph_broadcast(i); case ir::INST_DOT: return update_graph_dot(i); - case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);; + case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i); + case ir::INST_MASKED_LOAD_ASYNC:return update_graph_elementwise(i, false); case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i); case ir::INST_RECOALESCE: return update_graph_no_edge(i); default: return update_graph_elementwise(i); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 9c9929d01..989963206 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -55,7 +55,7 @@ inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) { for(ir::user* u: v->get_users()){ auto i = dynamic_cast(u); if(i && is_hmma_c(i) && i->get_operand(n) == v) - result = v; + result = i; } } @@ -115,8 +115,10 @@ data_layout::data_layout(id_t id, } } -size_t data_layout::find_axis(int to_find) const { +int data_layout::find_axis(int to_find) const { auto it = std::find(axes_.begin(), axes_.end(), to_find); + if(it == axes_.end()) + return -1; return std::distance(axes_.begin(), it); } @@ -125,23 +127,41 @@ size_t data_layout::find_axis(int to_find) const { * MMA Layout * * -------------------------------- */ -mma884_layout::mma884_layout(size_t num_warps, - const std::vector& axes, - const std::vector& shape, - const std::vector &values, - analysis::align* align): data_layout(HMMA_884, axes, shape, values, align) { +mma_layout::mma_layout(size_t num_warps, + const std::vector& axes, + const std::vector& shape, + const std::vector &values, + analysis::align* align, target* tgt, + shared_layout *layout_a, shared_layout *layout_b): data_layout(MMA, axes, shape, values, align) { /* fragments per warp */ // try to make things as square as possible to maximize data re-use - fpw_ = {1, 1, 1}; - std::vector fpw_nm1; - unsigned num_fragments = std::min((shape_[0]/8)*(shape_[1]/8), 4); - do { - fpw_nm1 = fpw_; - if(fpw_[0]*fpw_[1] < num_fragments) - fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8); - if(fpw_[0]*fpw_[1] < num_fragments) - fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8); - }while(fpw_nm1 != fpw_); + if(tgt->as_nvidia()->sm() < 80){ + fpw_ = {1, 1, 1}; + std::vector fpw_nm1; + unsigned num_fragments = std::min((shape_[0]/8)*(shape_[1]/8), 4); + do { + fpw_nm1 = fpw_; + if(fpw_[0]*fpw_[1] < num_fragments) + fpw_[0] = clamp(fpw_[0]*2, 1, shape_[0] / 8); + if(fpw_[0]*fpw_[1] < num_fragments) + fpw_[1] = clamp(fpw_[1]*2, 1, shape_[1] / 8); + }while(fpw_nm1 != fpw_); + auto ord_a = layout_a->get_order(); + auto ord_b = 
layout_b->get_order(); + bool is_a_row = ord_a[0] != 0; + bool is_b_row = ord_b[0] != 0; + bool is_a_vec4 = !is_a_row && (layout_a->get_shape()[ord_a[0]] <= 16); + bool is_b_vec4 = is_b_row && (layout_b->get_shape()[ord_b[0]] <= 16); + int pack_size_0 = (is_a_row || is_a_vec4) ? 1 : 2; + int pack_size_1 = (is_b_row && !is_b_vec4) ? 2 : 1; + rep_ = {2*pack_size_0, 2*pack_size_1, 1}; + spw_ = {fpw_[0]*8*pack_size_0, fpw_[1]*8*pack_size_1, 1}; + } + else{ + fpw_ = {1, 1, 1}; + spw_ = {16, 8, 1}; + rep_ = {2, 2, 1}; + } /* warps per tile */ // try to make things as square as possible to maximize data re-use @@ -150,17 +170,13 @@ mma884_layout::mma884_layout(size_t num_warps, do{ wpt_nm1 = wpt_; if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) - wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / (fpw_[0]*8)); + wpt_[0] = clamp(wpt_[0]*2, 1, shape_[0] / spw_[0]); if(wpt_[0] * wpt_[1] * wpt_[2] < num_warps) - wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / (fpw_[1]*8)); + wpt_[1] = clamp(wpt_[1]*2, 1, shape_[1] / spw_[1]); }while(wpt_nm1 != wpt_); - /* sanity check */ - unsigned effective_num_warps = 1; - for(size_t d = 0; d < shape.size(); d++) - effective_num_warps *= wpt_[d]; -// if(num_warps != effective_num_warps) -// throw std::runtime_error("cannot create a kernel with this amount of warps"); + /* shape per block */ + spt_ = {spw_[0]*wpt_[0], spw_[1]*wpt_[1], 1}; } @@ -183,13 +199,15 @@ scanline_layout::scanline_layout(size_t num_warps, ir::value *ptr = nullptr; for(ir::value *v: values) for(ir::user *usr: v->get_users()) - if(auto *st = dynamic_cast(usr)) + if(auto *st = dynamic_cast(usr)) ptr = st->get_pointer_operand(); unsigned i = order_[0]; - int contiguous = 4; - if(ptr) - contiguous = std::min(align->contiguous(ptr)[i], 4); + int contiguous = 1; + if(ptr){ + int nbits = ptr->get_type()->get_pointer_element_ty()->get_scalar_ty()->get_primitive_size_in_bits(); + contiguous = std::min(align->contiguous(ptr)[i], 128 / nbits); + } nts_[i] = clamp(size / num_threads, 1, std::min(contiguous, shape_[i])); mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); @@ -204,14 +222,6 @@ scanline_layout::scanline_layout(size_t num_warps, mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]); num_threads = num_threads / mts_[i]; } - /* sanity check */ - unsigned effective_num_threads = 1; - for(size_t d = 0; d < shape_.size(); d++) - effective_num_threads *= mts_[d]; - -// std::cout <get_incoming_value(1); ir::instruction *i_0 = dynamic_cast(value_0); ir::instruction *i_1 = dynamic_cast(value_1); - if(!i_0 || !i_1 || - !dynamic_cast(i_0) || - !dynamic_cast(i_1) ) + if(!(i_0 && !i_1) && + !(dynamic_cast(i_0) && dynamic_cast(i_1)) && + !(dynamic_cast(i_0) && dynamic_cast(i_1))) return; if(is_latch_1) res.reset(new double_buffer_info_t{value_0, value_1, phi}); @@ -257,7 +267,7 @@ void shared_layout::extract_double_bufferable(ir::value *v, std::shared_ptr& axes, const std::vector& shape, const std::vector &values, @@ -265,6 +275,7 @@ shared_layout::shared_layout(const data_layout *arg, analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) { size_ = 0; + arg_layout_ = arg; // double-buffering for(ir::value *v: values) @@ -284,36 +295,8 @@ shared_layout::shared_layout(const data_layout *arg, extract_hmma_dot_use(v, hmma_dot_a, 0); extract_hmma_dot_use(v, hmma_dot_b, 1); } - - - // non-mma ordering - std::vector col = {0, 1}; - std::vector row = {1, 0}; - for(size_t s = 2; s < get_rank(); s++){ - col.push_back(s); - row.push_back(s); - } - bool is_nonhmma_dot_a = dot_a && !hmma_dot_a; - bool 
is_nonhmma_dot_b = dot_b && !hmma_dot_b; - if(is_nonhmma_dot_a) - order_ = is_trans(dot_a) ? row : col; - else if(is_nonhmma_dot_b) - order_ = is_trans(dot_b) ? col : row; - - // padding - size_t pad = 0; - if(hmma_dot_a){ - bool row = is_trans(hmma_dot_a) ^ order_[0] != 0; - pad = 24 - shape_[row ? 0 : 1] % 32; - } - else if(hmma_dot_b){ - bool row = is_trans(hmma_dot_b) ^ order_[0] != 0; - pad = 24 - shape_[row ? 1 : 0] % 32; - } - else if(order_ != arg_order) { - pad = 4; - } - shape_[order_[0]] += pad; + hmma_dot_a_ = hmma_dot_a; + hmma_dot_b_ = hmma_dot_b; // size size_ = ty_->get_primitive_size_in_bits() / 8; @@ -362,6 +345,8 @@ void layouts::make_graph(ir::instruction *i) { } void layouts::create(size_t id, const std::vector& values) { +// if(layouts_.find(id) != layouts_.end()) +// return; auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c); auto cmp = [](ir::value* x, ir::value *y) { std::pair xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()}; @@ -374,19 +359,27 @@ void layouts::create(size_t id, const std::vector& values) { const auto& axes = axes_->get(largest); const auto& shapes = largest->get_type()->get_tile_shapes(); auto it_cts = std::find_if(values.begin(), values.end(), [](ir::value* v) { - return dynamic_cast(v); + return dynamic_cast(v) || + dynamic_cast(v); }); // type - if(it_hmma_c != values.end()) - layouts_[id] = new mma884_layout(num_warps_, axes, shapes, values, align_); + if(it_hmma_c != values.end()){ + ir::instruction *dot = (ir::instruction*)*it_hmma_c; + ir::value *a = dot->get_operand(0); + ir::value *b = dot->get_operand(1); + create(groups_.at(a), values_.at(groups_.at(a))); + create(groups_.at(b), values_.at(groups_.at(b))); + layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b))); + } else if(it_cts != values.end()){ - ir::copy_to_shared_inst *cts = (ir::copy_to_shared_inst*)*it_cts; + ir::instruction *cts = (ir::instruction*)*it_cts; ir::value *arg = cts->get_operand(0); create(groups_.at(arg), values_.at(groups_.at(arg))); layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_); } - else + else{ layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_); + } } void layouts::run(ir::module &mod) { @@ -420,7 +413,7 @@ void layouts::run(ir::module &mod) { } if(auto *recoalasce = dynamic_cast(i)){ ir::value *val = recoalasce->get_operand(0); - mma884_layout* in_layout = get(val)->to_mma884(); + mma_layout* in_layout = get(val)->to_mma(); scanline_layout* out_layout = get(i)->to_scanline(); if(!in_layout || !out_layout) return; @@ -431,7 +424,7 @@ void layouts::run(ir::module &mod) { shape[ld] = in_shape[ld]; for(size_t k = 0; k < in_shape.size(); k++) if(k != ld) - shape[k] = 4*in_layout->to_mma884()->fpw(k)*in_layout->to_mma884()->wpt(k); + shape[k] = in_layout->to_mma()->spt(k); // create layout layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), align_); tmp_[recoalasce] = id; diff --git a/lib/codegen/analysis/swizzle.cc b/lib/codegen/analysis/swizzle.cc new file mode 100644 index 000000000..304aac95a --- /dev/null +++ b/lib/codegen/analysis/swizzle.cc @@ -0,0 +1,54 @@ +#include "triton/codegen/analysis/swizzle.h" +#include "triton/codegen/analysis/layout.h" +#include "triton/codegen/target.h" +#include "triton/ir/type.h" +#include + +namespace triton{ 
+namespace codegen{
+namespace analysis{
+
+
+void swizzle::run(ir::module &) {
+  per_phase_.clear();
+  max_phase_.clear();
+
+  for(auto &x: layouts_->get_all()){
+    shared_layout* layout = dynamic_cast<shared_layout*>(x.second);
+    if(!layout)
+      continue;
+    ir::value* mma_dot_a = layout->hmma_dot_a();
+    ir::value* mma_dot_b = layout->hmma_dot_b();
+    if(!mma_dot_a && !mma_dot_b){
+      per_phase_[layout] = 1;
+      max_phase_[layout] = 1;
+      vec_[layout] = 1;
+      continue;
+    }
+    auto ord = layout->get_order();
+    scanline_layout* in_layout = dynamic_cast<scanline_layout*>(layout->get_arg_layout());
+    if(!in_layout)
+      continue;
+    int dtsize = layout->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
+    if(tgt_->as_nvidia()->sm() < 80){
+      int inner = mma_dot_a ? 0 : 1;
+      per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+      max_phase_[layout] = (ord[inner] == 1 ? 8 : 4) / per_phase_[layout];
+      if(mma_dot_a)
+        vec_[layout] = 2*layouts_->get(mma_dot_a)->to_mma()->rep(0);
+      else
+        vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
+    }
+    else{
+      per_phase_[layout] = std::max(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
+      max_phase_[layout] = 8 / per_phase_[layout];
+      vec_[layout] = 8;
+    }
+  }
+}
+
+}
+}
+}
+
+
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 577b6c9c7..cf764c857 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -1,11 +1,10 @@
-#include
+#include
 #include "triton/codegen/selection/generator.h"
-#include "triton/codegen/selection/machine_layout.h"
-#include "triton/codegen/selection/machine_value.h"
 #include "triton/codegen/target.h"
 #include "triton/codegen/analysis/axes.h"
 #include "triton/codegen/analysis/allocation.h"
 #include "triton/codegen/analysis/align.h"
+#include "triton/codegen/analysis/swizzle.h"
 #include "triton/codegen/transform/coalesce.h"
 #include "triton/ir/context.h"
 #include "triton/ir/module.h"
@@ -13,209 +12,160 @@
 #include "triton/ir/type.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/InlineAsm.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"

 namespace triton{
 namespace codegen{

 using namespace llvm;

-// Function for extended Euclidean Algorithm
-inline int gcd_impl(int a, int b, int *x, int *y)
-{
-  // Base Case
-  if (a == 0)
-  {
-    *x = 0;
-    *y = 1;
-    return b;
-  }
-  int x1, y1; // To store results of recursive call
-  int gcd = gcd_impl(b%a, a, &x1, &y1);
-  // Update x and y using results of
-  // recursive call
-  *x = y1 - (b/a) * x1;
-  *y = x1;
-  return gcd;
-}
-
-inline int gcd(int a, int b) {
-  int x, y;
-  return gcd_impl(a, b, &x, &y);
-}
+// types
+#define void_ty builder_->getVoidTy()
+#define f16_ty builder_->getHalfTy()
+#define f32_ty builder_->getFloatTy()
+#define i32_ty builder_->getInt32Ty()
+#define vec_ty(...) VectorType::get(__VA_ARGS__)
+#define ptr_ty(...) PointerType::get(__VA_ARGS__)
+// constants
+#define i32(...) builder_->getInt32(__VA_ARGS__)
+// ops
+#define add(...) builder_->CreateAdd(__VA_ARGS__)
+#define and_(...) builder_->CreateAnd(__VA_ARGS__)
+#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
+#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
+#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
+#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
+#define br(...) builder_->CreateBr(__VA_ARGS__)
+#define call(...) 
builder_->CreateCall(__VA_ARGS__) +#define cast(...) builder_->CreateCast(__VA_ARGS__) +#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__) +#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__) +#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__) +#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__) +#define fadd(...) builder_->CreateFAdd(__VA_ARGS__) +#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__) +#define fmul(...) builder_->CreateFMul(__VA_ARGS__) +#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__) +#define fsub(...) builder_->CreateFSub(__VA_ARGS__) +#define gep(...) builder_->CreateGEP(__VA_ARGS__) +#define icmp(...) builder_->CreateICmp(__VA_ARGS__) +#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__) +#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__) +#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__) +#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__) +#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__) +#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__) +#define load(...) builder_->CreateLoad(__VA_ARGS__) +#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__) +#define min_num(...) builder_->CreateMinNum(__VA_ARGS__) +#define mul(...) builder_->CreateMul(__VA_ARGS__) +#define neg(...) builder_->CreateNeg(__VA_ARGS__) +#define phi(...) builder_->CreatePHI(__VA_ARGS__) +#define ret(...) builder_->CreateRet(__VA_ARGS__) +#define select(...) builder_->CreateSelect(__VA_ARGS__) +#define store(...) builder_->CreateStore(__VA_ARGS__) +#define sub(...) builder_->CreateSub(__VA_ARGS__) +#define udiv(...) builder_->CreateUDiv(__VA_ARGS__) +#define urem(...) builder_->CreateURem(__VA_ARGS__) +#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__) +#define xor_(...) 
builder_->CreateXor(__VA_ARGS__) -llvm::Instruction::BinaryOps llvm_op(ir::binary_op_t op) { - using llop = llvm::Instruction::BinaryOps; - using ttop = ir::binary_op_t; - switch(op) { - case ttop::Add: return llop::Add; - case ttop::FAdd: return llop::FAdd; - case ttop::Sub: return llop::Sub; - case ttop::FSub: return llop::FSub; - case ttop::Mul: return llop::Mul; - case ttop::FMul: return llop::FMul; - case ttop::UDiv: return llop::UDiv; - case ttop::SDiv: return llop::SDiv; - case ttop::FDiv: return llop::FDiv; - case ttop::URem: return llop::URem; - case ttop::SRem: return llop::SRem; - case ttop::FRem: return llop::FRem; - case ttop::Shl: return llop::Shl; - case ttop::LShr: return llop::LShr; - case ttop::AShr: return llop::AShr; - case ttop::And: return llop::And; - case ttop::Or: return llop::Or; - case ttop::Xor: return llop::Xor; - } - throw std::runtime_error("unknown operator"); -} - -llvm::Instruction::CastOps llvm_op(ir::cast_op_t op) { - using llop = llvm::Instruction::CastOps; - using ttop = ir::cast_op_t; - switch(op){ - case ttop::Trunc: return llop::Trunc; - case ttop::ZExt: return llop::ZExt; - case ttop::SExt: return llop::SExt; - case ttop::FPTrunc: return llop::FPTrunc; - case ttop::FPExt: return llop::FPExt; - case ttop::UIToFP: return llop::UIToFP; - case ttop::SIToFP: return llop::SIToFP; - case ttop::FPToUI: return llop::FPToUI; - case ttop::FPToSI: return llop::FPToSI; - case ttop::PtrToInt: return llop::PtrToInt; - case ttop::IntToPtr: return llop::IntToPtr; - case ttop::BitCast: return llop::BitCast; - case ttop::AddrSpaceCast: return llop::AddrSpaceCast; - } - throw std::runtime_error("unknown operator"); -} - -llvm::CmpInst::Predicate llvm_pred(ir::cmp_pred_t pred) { - using llop = llvm::CmpInst::Predicate; - using ttop = ir::cmp_pred_t; - switch(pred){ - case ttop::FIRST_FCMP_PREDICATE: return llop::FIRST_FCMP_PREDICATE; - case ttop::FCMP_FALSE: return llop::FCMP_FALSE; - case ttop::FCMP_OEQ: return llop::FCMP_OEQ; - case ttop::FCMP_OGT: return llop::FCMP_OGT; - case ttop::FCMP_OGE: return llop::FCMP_OGE; - case ttop::FCMP_OLT: return llop::FCMP_OLT; - case ttop::FCMP_OLE: return llop::FCMP_OLE; - case ttop::FCMP_ONE: return llop::FCMP_ONE; - case ttop::FCMP_ORD: return llop::FCMP_ORD; - case ttop::FCMP_UNO: return llop::FCMP_UNO; - case ttop::FCMP_UEQ: return llop::FCMP_UEQ; - case ttop::FCMP_UGT: return llop::FCMP_UGT; - case ttop::FCMP_UGE: return llop::FCMP_UGE; - case ttop::FCMP_ULT: return llop::FCMP_ULT; - case ttop::FCMP_ULE: return llop::FCMP_ULE; - case ttop::FCMP_UNE: return llop::FCMP_UNE; - case ttop::FCMP_TRUE: return llop::FCMP_TRUE; - case ttop::LAST_FCMP_PREDICATE: return llop::LAST_FCMP_PREDICATE; - case ttop::FIRST_ICMP_PREDICATE: return llop::FIRST_ICMP_PREDICATE; - case ttop::ICMP_EQ: return llop::ICMP_EQ; - case ttop::ICMP_NE: return llop::ICMP_NE; - case ttop::ICMP_UGT: return llop::ICMP_UGT; - case ttop::ICMP_UGE: return llop::ICMP_UGE; - case ttop::ICMP_ULT: return llop::ICMP_ULT; - case ttop::ICMP_ULE: return llop::ICMP_ULE; - case ttop::ICMP_SGT: return llop::ICMP_SGT; - case ttop::ICMP_SGE: return llop::ICMP_SGE; - case ttop::ICMP_SLT: return llop::ICMP_SLT; - case ttop::ICMP_SLE: return llop::ICMP_SLE; - case ttop::LAST_ICMP_PREDICATE: return llop::LAST_ICMP_PREDICATE; - } - throw std::runtime_error("unknown operator"); -} - - -inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) { +/** + * \brief Convert Triton-IR Type to LLVM-IR Type + */ +Type *generator::cvt(ir::type *ty) { // function if(auto* tt = dynamic_cast(ty)){ - 
Type *return_ty = llvm_type(tt->get_return_ty(), ctx); - std::vector param_tys; - std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), - [&ctx](ir::type* t){ return llvm_type(t, ctx);}); - return FunctionType::get(return_ty, param_tys, false); + Type *ret_ty = cvt(tt->get_return_ty()); + std::vector arg_tys(tt->get_num_params()); + for(size_t i = 0; i < arg_tys.size(); i++) + arg_tys[i] = cvt(tt->get_param_ty(i)); + return FunctionType::get(ret_ty, arg_tys, false); } // pointer if(ty->is_pointer_ty()){ - Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx); + Type *elt_ty = cvt(ty->get_pointer_element_ty()); unsigned addr_space = ty->get_pointer_address_space(); - return PointerType::get(elt_ty, addr_space); + return ptr_ty(elt_ty, addr_space); } // integer if(ty->is_integer_ty()){ unsigned bitwidth = ty->get_integer_bitwidth(); - return IntegerType::get(ctx, bitwidth); + return IntegerType::get(*ctx_, bitwidth); } // primitive types switch(ty->get_type_id()){ - case ir::type::VoidTyID: return Type::getVoidTy(ctx); - case ir::type::HalfTyID: return Type::getHalfTy(ctx); - case ir::type::FloatTyID: return Type::getFloatTy(ctx); - case ir::type::DoubleTyID: return Type::getDoubleTy(ctx); - case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx); - case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx); - case ir::type::LabelTyID: return Type::getLabelTy(ctx); - case ir::type::MetadataTyID: return Type::getMetadataTy(ctx); - case ir::type::TokenTyID: return Type::getTokenTy(ctx); + case ir::type::VoidTyID: return Type::getVoidTy(*ctx_); + case ir::type::HalfTyID: return Type::getHalfTy(*ctx_); + case ir::type::FloatTyID: return Type::getFloatTy(*ctx_); + case ir::type::DoubleTyID: return Type::getDoubleTy(*ctx_); + case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(*ctx_); + case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(*ctx_); + case ir::type::LabelTyID: return Type::getLabelTy(*ctx_); + case ir::type::MetadataTyID: return Type::getMetadataTy(*ctx_); + case ir::type::TokenTyID: return Type::getTokenTy(*ctx_); default: break; } // unknown type throw std::runtime_error("unknown conversion from ir::type to Type"); } - -inline llvm::Attribute llvm_attr(llvm::LLVMContext& ctx, ir::attribute attr) { +/** + * \brief Convert Triton-IR Attribute to LLVM-IR Attribute + */ +llvm::Attribute generator::cvt(ir::attribute attr) { switch(attr.get_kind()){ - case ir::noalias: return llvm::Attribute::get(ctx, llvm::Attribute::NoAlias); - case ir::readonly: return llvm::Attribute::get(ctx, llvm::Attribute::ReadOnly); - case ir::writeonly: return llvm::Attribute::get(ctx, llvm::Attribute::WriteOnly); - case ir::aligned: return llvm::Attribute::get(ctx, llvm::Attribute::Alignment, attr.get_value()); - case ir::retune: return llvm::Attribute::get(ctx, llvm::Attribute::None); + case ir::noalias: return llvm::Attribute::get(*ctx_, llvm::Attribute::NoAlias); + case ir::readonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::ReadOnly); + case ir::writeonly: return llvm::Attribute::get(*ctx_, llvm::Attribute::WriteOnly); + case ir::aligned: return llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, attr.get_value()); + case ir::retune: return llvm::Attribute::get(*ctx_, llvm::Attribute::None); default: throw std::runtime_error("cannot convert ir::attribute_t to llvm::Attribute"); } } -inline bool is_trans(ir::value *v) { - if(dynamic_cast(v)) { - return true; - } - if(auto *phi = dynamic_cast(v)) { - bool result = true; - for(ir::value *op: 
phi->ops()) - result = result && is_trans(op); - return result; - } - return false; -} - - - - +/** + * \brief Constructor of LLVM code generator + */ generator::generator(analysis::axes *a_axes, analysis::layouts *layouts, analysis::align *alignment, analysis::allocation *alloc, - target *tgt, + analysis::swizzle *swizzle, + target *tgt, unsigned num_warps) - : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), + : a_axes_(a_axes), layouts_(layouts), alignment_(alignment), alloc_(alloc), swizzle_(swizzle), tgt_(tgt), num_warps_(num_warps) { } - +/** + * \brief Code Generation for `value` + */ void generator::visit_value(ir::value* v) { if(!seen_.insert(v).second) return; - // create machine tile if(v->get_type()->is_tile_ty()){ - tmap_[v] = machine_layouts_.at(layouts_->get(v))->create(v); + if(analysis::shared_layout* layout = layouts_->get(v)->to_shared()){ + auto double_buffer = layout->get_double_buffer(); + // offset + Value *offset = nullptr; + if(double_buffer && v == double_buffer->phi) + offset = shared_off_[layout]; + // base pointer + Value *ptr = shared_ptr_[layout]; + if(double_buffer && v == double_buffer->latch) + ptr = shared_next_ptr_[layout]; + else if(double_buffer && v == double_buffer->first) + ptr = shared_pre_ptr_[layout]; + shmems_[v] = ptr; + shoffs_[v] = offset; + } } // visit operands BasicBlock *current = builder_->GetInsertBlock(); @@ -225,520 +175,473 @@ void generator::visit_value(ir::value* v) { if(dynamic_cast(op) || !dynamic_cast(v)) visit_value(op); } + init_idx(v); // change insert point for phi node builder_->SetInsertPoint(current); auto *phi = dynamic_cast(v); if(phi && !current->empty() && current->getFirstNonPHI()) builder_->SetInsertPoint(&*current->getFirstNonPHI()); // visit user - if(auto *usr = dynamic_cast(v)) + if(auto *usr = dynamic_cast(v)){ usr->accept(this); + } // revert insert point if(phi && !current->empty() && current->getFirstNonPHI()) builder_->SetInsertPoint(current); } -void generator::visit_phi_node(ir::phi_node* phi) { - Type *ty = llvm_type(phi->get_type()->get_scalar_ty(), *ctx_); - unsigned num_ops = phi->get_num_operands(); - for_each(phi, [&](indices_t idx){ - set_value(phi, idx, builder_->CreatePHI(ty, num_ops)); - }); +/** + * \brief Code Generation for `phi` + */ +void generator::visit_phi_node(ir::phi_node* x) { + Type *ty = cvt(x->get_type()->get_scalar_ty()); + for(indices_t idx: idxs_.at(x)) + vals_[x][idx] = phi(ty, x->get_num_operands()); } -void generator::visit_binary_operator(ir::binary_operator*binop) { - for_each(binop, [&](indices_t idx){ - Value *lhs = get_value(binop->get_operand(0), idx); - Value *rhs = get_value(binop->get_operand(1), idx); - Value *ret = builder_->CreateBinOp(llvm_op(binop->get_op()), lhs, rhs); - set_value(binop, idx, ret); - }); +/** + * \brief Code Generation for `binary_operator` + */ +void generator::visit_binary_operator(ir::binary_operator*x) { + auto cvt = [](ir::binary_op_t op){ + using ll = llvm::Instruction::BinaryOps; + using tt = ir::binary_op_t; + switch(op) { + case tt::Add: return ll::Add; + case tt::FAdd: return ll::FAdd; + case tt::Sub: return ll::Sub; + case tt::FSub: return ll::FSub; + case tt::Mul: return ll::Mul; + case tt::FMul: return ll::FMul; + case tt::UDiv: return ll::UDiv; + case tt::SDiv: return ll::SDiv; + case tt::FDiv: return ll::FDiv; + case tt::URem: return ll::URem; + case tt::SRem: return ll::SRem; + case tt::FRem: return ll::FRem; + case tt::Shl: return ll::Shl; + case tt::LShr: return ll::LShr; + case tt::AShr: return 
ll::AShr; + case tt::And: return ll::And; + case tt::Or: return ll::Or; + case tt::Xor: return ll::Xor; + default: throw std::runtime_error("unreachable switch"); + } + }; + for(indices_t idx: idxs_.at(x)){ + Value *lhs = vals_[x->get_operand(0)][idx]; + Value *rhs = vals_[x->get_operand(1)][idx]; + vals_[x][idx] = bin_op(cvt(x->get_op()), lhs, rhs); + } } -void generator::visit_getelementptr_inst(ir::getelementptr_inst* gep) { - for_each(gep, [&](indices_t idx){ - Value *ptr = get_value(gep->get_operand(0), idx); - std::vector idx_vals; - std::transform(gep->idx_begin(), gep->idx_end(), std::back_inserter(idx_vals), - [&](ir::value* x){ return get_value(x, idx);}); - Type *source_ty = llvm_type(gep->get_source_elt_ty()->get_scalar_ty(), *ctx_); - Value *ret = builder_->CreateGEP(source_ty, ptr, idx_vals); - set_value(gep, idx, ret); - }); +/** + * \brief Code Generation for `getelementptr` + */ +void generator::visit_getelementptr_inst(ir::getelementptr_inst* x) { + for(indices_t idx: idxs_.at(x)){ + Value *ptr = vals_[x->get_pointer_operand()][idx]; + std::vector vals; + for(auto it= x->idx_begin(); it != x->idx_end(); it++) + vals.push_back(vals_[*it][idx]); + Type *ty = cvt(x->get_source_elt_ty()->get_scalar_ty()); + vals_[x][idx] = gep(ty, ptr, vals); + } } -void generator::visit_icmp_inst(ir::icmp_inst* icmp) { - for_each(icmp, [&](indices_t idx){ - ir::cmp_pred_t pred = icmp->get_pred(); - Value *lhs = get_value(icmp->get_operand(0), idx); - Value *rhs = get_value(icmp->get_operand(1), idx); - Value *ret = builder_->CreateICmp(llvm_pred(pred), lhs, rhs); - set_value(icmp, idx, ret); - }); +/** + * \brief Code Generation for `icmp` + */ +void generator::visit_icmp_inst(ir::icmp_inst* x) { + auto cvt = [](ir::cmp_pred_t pred) { + using ll = llvm::CmpInst::Predicate; + using tt = ir::cmp_pred_t; + switch(pred){ + case tt::FIRST_ICMP_PREDICATE: return ll::FIRST_ICMP_PREDICATE; + case tt::ICMP_EQ: return ll::ICMP_EQ; + case tt::ICMP_NE: return ll::ICMP_NE; + case tt::ICMP_UGT: return ll::ICMP_UGT; + case tt::ICMP_UGE: return ll::ICMP_UGE; + case tt::ICMP_ULT: return ll::ICMP_ULT; + case tt::ICMP_ULE: return ll::ICMP_ULE; + case tt::ICMP_SGT: return ll::ICMP_SGT; + case tt::ICMP_SGE: return ll::ICMP_SGE; + case tt::ICMP_SLT: return ll::ICMP_SLT; + case tt::ICMP_SLE: return ll::ICMP_SLE; + case tt::LAST_ICMP_PREDICATE: return ll::LAST_ICMP_PREDICATE; + default: throw std::runtime_error("unreachable switch"); + } + }; + + for(indices_t idx: idxs_.at(x)){ + Value *lhs = vals_[x->get_operand(0)][idx]; + Value *rhs = vals_[x->get_operand(1)][idx]; + vals_[x][idx] = icmp(cvt(x->get_pred()), lhs, rhs); + } } -void generator::visit_fcmp_inst(ir::fcmp_inst* fcmp) { - for_each(fcmp, [&](indices_t idx){ - ir::cmp_pred_t pred = fcmp->get_pred(); - Value *lhs = get_value(fcmp->get_operand(0), idx); - Value *rhs = get_value(fcmp->get_operand(1), idx); - Value *ret = builder_->CreateFCmp(llvm_pred(pred), lhs, rhs); - set_value(fcmp, idx, ret); - }); +/** + * \brief Code Generation for `fcmp` + */ +void generator::visit_fcmp_inst(ir::fcmp_inst* x) { + auto cvt = [](ir::cmp_pred_t pred) { + using ll = llvm::CmpInst::Predicate; + using tt = ir::cmp_pred_t; + switch(pred){ + case tt::FIRST_FCMP_PREDICATE: return ll::FIRST_FCMP_PREDICATE; + case tt::FCMP_FALSE: return ll::FCMP_FALSE; + case tt::FCMP_OEQ: return ll::FCMP_OEQ; + case tt::FCMP_OGT: return ll::FCMP_OGT; + case tt::FCMP_OGE: return ll::FCMP_OGE; + case tt::FCMP_OLT: return ll::FCMP_OLT; + case tt::FCMP_OLE: return ll::FCMP_OLE; + case 
tt::FCMP_ONE: return ll::FCMP_ONE; + case tt::FCMP_ORD: return ll::FCMP_ORD; + case tt::FCMP_UNO: return ll::FCMP_UNO; + case tt::FCMP_UEQ: return ll::FCMP_UEQ; + case tt::FCMP_UGT: return ll::FCMP_UGT; + case tt::FCMP_UGE: return ll::FCMP_UGE; + case tt::FCMP_ULT: return ll::FCMP_ULT; + case tt::FCMP_ULE: return ll::FCMP_ULE; + case tt::FCMP_UNE: return ll::FCMP_UNE; + case tt::FCMP_TRUE: return ll::FCMP_TRUE; + case tt::LAST_FCMP_PREDICATE: return ll::LAST_FCMP_PREDICATE; + default: throw std::runtime_error("unreachable switch"); + } + }; + for(indices_t idx: idxs_.at(x)){ + Value *lhs = vals_[x->get_operand(0)][idx]; + Value *rhs = vals_[x->get_operand(1)][idx]; + vals_[x][idx] = fcmp(cvt(x->get_pred()), lhs, rhs); + } } -void generator::visit_cast_inst(ir::cast_inst* cast) { - for_each(cast, [&](indices_t idx){ - Value *arg = get_value(cast->get_operand(0), idx); - Type *dst_ty = llvm_type(cast->get_type()->get_scalar_ty(), *ctx_); - Value *ret = builder_->CreateCast(llvm_op(cast->get_op()), arg, dst_ty); - set_value(cast, idx, ret); - }); +/** + * \brief Code Generation for `cast` + */ +void generator::visit_cast_inst(ir::cast_inst* x) { + Type *ty = cvt(x->get_type()->get_scalar_ty()); + auto cvt = [](ir::cast_op_t op){ + using ll = llvm::Instruction::CastOps; + using tt = ir::cast_op_t; + switch(op){ + case tt::Trunc: return ll::Trunc; + case tt::ZExt: return ll::ZExt; + case tt::SExt: return ll::SExt; + case tt::FPTrunc: return ll::FPTrunc; + case tt::FPExt: return ll::FPExt; + case tt::UIToFP: return ll::UIToFP; + case tt::SIToFP: return ll::SIToFP; + case tt::FPToUI: return ll::FPToUI; + case tt::FPToSI: return ll::FPToSI; + case tt::PtrToInt: return ll::PtrToInt; + case tt::IntToPtr: return ll::IntToPtr; + case tt::BitCast: return ll::BitCast; + case tt::AddrSpaceCast: return ll::AddrSpaceCast; + default: throw std::runtime_error("unreachable switch"); + } + }; + for(indices_t idx: idxs_.at(x)){ + Value *arg = vals_[x->get_operand(0)][idx]; + vals_[x][idx] = cast(cvt(x->get_op()), arg, ty); + } } +/** + * \brief Code Generation for `return` + */ void generator::visit_return_inst(ir::return_inst* rr) { ir::value *ret_val = rr->get_return_value(); - builder_->CreateRet(ret_val ? vmap_.at(ret_val) : nullptr); + ret(ret_val ? 
vals_[ret_val][{}] : nullptr); } +/** + * \brief Code Generation for `cond_branch` + */ void generator::visit_cond_branch_inst(ir::cond_branch_inst* br) { - BasicBlock *true_dest = (BasicBlock*)vmap_.at(br->get_true_dest()); - BasicBlock *false_dest = (BasicBlock*)vmap_.at(br->get_false_dest()); - Value *cond = vmap_.at(br->get_cond()); - builder_->CreateCondBr(cond, true_dest, false_dest); + BasicBlock *true_dest = bbs_.at(br->get_true_dest()); + BasicBlock *false_dest = bbs_.at(br->get_false_dest()); + Value *cond = vals_[br->get_cond()][{}]; + cond_br(cond, true_dest, false_dest); } +/** + * \brief Code Generation for `uncond_branch` + */ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) { - BasicBlock *dest = (BasicBlock*)vmap_.at(br->get_dest()); - builder_->CreateBr(dest); + BasicBlock *dest = bbs_.at(br->get_dest()); + br(dest); } +/** + * \brief Code Generation for a (synchronous) `load` + */ +void generator::visit_load_inst(ir::load_inst* x){ + ir::value *op = x->get_pointer_operand(); + ir::masked_load_inst *mx = dynamic_cast(x); + Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty()); + int space = op->get_type()->get_scalar_ty()->get_pointer_address_space(); -void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) { - if(!x->get_type()->is_tile_ty()){ - Value *ptr = get_value(x->get_pointer_operand(), {}); - set_value(x, {}, builder_->CreateLoad(ptr)); - return; + // compute vector width + size_t vec = 1; + if(op->get_type()->is_tile_ty()){ + auto ord = ords_.at(op); + size_t aln = alignment_->get(op, ord[0]); + size_t nts = layouts_->get(x)->to_scanline()->nts(ord[0]); + vec = std::min(nts, aln); } - // find vector size - ir::value *ptr = x->get_pointer_operand(); - size_t ld = layouts_->get(ptr)->get_order(0); - unsigned alignment = std::max(alignment_->get(ptr, ld), 1); - - // vector loads - std::map packets; - for_each(x, [&](indices_t idx){ - distributed_tile* result = (distributed_tile*)tmap_.at(x); - // vector size - unsigned contiguous = 1; - if(ld < x->get_type()->get_tile_rank()) - contiguous = result->axis(ld).contiguous; - unsigned vector_size = gcd(contiguous, alignment); - - unsigned linear = result->get_linear_index(idx); - unsigned id = linear / vector_size; - if(linear % vector_size == 0) { - distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); - Value *ptr = pointers->get_value(idx); - ptr = builder_->CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size), - ptr->getType()->getPointerAddressSpace())); - packets[id] = builder_->CreateLoad(ptr); - } - }); - - // extract result element - for_each(x, [&](indices_t idx){ - distributed_tile* result = (distributed_tile*)tmap_.at(x); - // vector size - unsigned contiguous = 1; - if(ld < x->get_type()->get_tile_rank()) - contiguous = result->axis(ld).contiguous; - unsigned vector_size = gcd(contiguous, alignment); - unsigned linear = result->get_linear_index(idx); - unsigned id = linear / vector_size; - set_value(x, idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size)); - }); -} - -void generator::visit_masked_load_inst(ir::masked_load_inst* x) { - if(!x->get_type()->is_tile_ty()){ - Value *ptr = vmap_.at(x->get_pointer_operand()); - Value *mask = vmap_.at(x->get_mask_operand()); - BasicBlock *current_bb = builder_->GetInsertBlock(); - Function *parent = builder_->GetInsertBlock()->getParent(); - BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent); - BasicBlock *mask_done_bb = 
BasicBlock::Create(*ctx_, "mask_done", parent); - builder_->CreateCondBr(mask, mask_then_bb, mask_done_bb); - builder_->SetInsertPoint(mask_then_bb); - Value *result_then = builder_->CreateLoad(ptr); - builder_->CreateBr(mask_done_bb); - builder_->SetInsertPoint(mask_done_bb); - Value *result = nullptr; - if(x->get_false_value_operand()){ - Value *result_false = vmap_.at(x->get_false_value_operand()); - result = builder_->CreatePHI(result_then->getType(), 2); - ((PHINode*)result)->addIncoming(result_then, mask_then_bb); - ((PHINode*)result)->addIncoming(result_false, current_bb); + // code generation + auto idxs = idxs_.at(x); + for(size_t i = 0; i < idxs.size(); i += vec){ + indices_t idx = idxs[i]; + // pointer value + Value *ptr = bit_cast(vals_[op][idx], ptr_ty(vec_ty(ty, vec), space)); + // masked load + Value *ret = nullptr; + if(mx){ + // if mask: + // ret = load(ptr) + // else: + // ret = false_value + PHINode *_ret = phi(ptr->getType()->getPointerElementType(), 2); + Instruction *then_term; + Instruction *else_term; + llvm::SplitBlockAndInsertIfThenElse(vals_[mx->get_mask_operand()][idx], _ret, &then_term, &else_term); + builder_->SetInsertPoint(then_term); + Value* then_ret = load(ptr); + builder_->SetInsertPoint(else_term); + Value* else_ret = splat(vec, vals_[mx->get_false_value_operand()][idx]); + builder_->SetInsertPoint(_ret->getParent()); + _ret->addIncoming(then_ret, then_term->getParent()); + _ret->addIncoming(else_ret, else_term->getParent()); + ret = (Value*)_ret; } else - result = result_then; - vmap_[x] = result; - return; + ret = load(ptr); + // write back + for(size_t ii = 0; ii < vec; ii++) + vals_[x][idxs[i+ii]] = extract_elt(ret, ii); } - // find vector size - ir::value *ptr = x->get_pointer_operand(); - auto order = layouts_->get(ptr)->get_order(); - size_t ld; - for(size_t i = 0; i < order.size(); i++){ - ld = order[i]; - if(ld < x->get_type()->get_tile_rank()) - break; - } - //size_t ld = layouts_->get(ptr)->get_order(0); - unsigned alignment = alignment_->get(ptr, ld); - distributed_tile *pointers = (distributed_tile*)tmap_.at(ptr); - distributed_tile *masks = (distributed_tile*)tmap_.at(x->get_mask_operand()); - distributed_tile *false_values = (distributed_tile*)tmap_.at(x->get_false_value_operand()); - std::map packets; - for_each(x, [&](indices_t idx){ - distributed_tile* result = (distributed_tile*)tmap_.at(x); - unsigned vector_size = gcd(result->axis(ld).contiguous, alignment); - unsigned linear = result->get_linear_index(idx); - unsigned id = linear / vector_size; - if(linear % vector_size == 0) { - Value *ptr = pointers->get_value(idx); - ptr = builder_->CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size), - ptr->getType()->getPointerAddressSpace())); - - Value *mask = masks->get_value(idx); - BasicBlock *current_bb = builder_->GetInsertBlock(); - Function *parent = builder_->GetInsertBlock()->getParent(); - BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent); - BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent); - builder_->CreateCondBr(mask, mask_then_bb, mask_done_bb); - builder_->SetInsertPoint(mask_then_bb); - Value *result_then = builder_->CreateLoad(ptr); - builder_->CreateBr(mask_done_bb); - builder_->SetInsertPoint(mask_done_bb); - Value *current_result = nullptr; - if(false_values){ - current_result = builder_->CreatePHI(result_then->getType(), 2); - ((PHINode*)current_result)->addIncoming(result_then, mask_then_bb); - Value *result_false = 
false_values->get_value(idx); - if(result_then->getType()->isVectorTy()) - result_false = builder_->CreateVectorSplat(vector_size, result_false); - ((PHINode*)current_result)->addIncoming(result_false, current_bb); - } - else - current_result = result_then; - -// ConstantInt *cst = nullptr; -// if(GetElementPtrInst *gep = dyn_cast(ptr)) -// if(gep->getNumIndices() == 1) -// cst = dyn_cast(gep->idx_begin()); -// llvm::Value* mask = masks->get_value(idx); -// std::string offset = ""; -// if(cst) -// offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size); -// Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2); -// Type *fp16x2_pack4_ty = StructType::get(*ctx_, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); -// FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false); -// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];"; -// if(false_values) -// asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};"; -// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true); -// Value *current_result = builder_->CreateCall(iasm, {mask, ptr}); - - packets[id] = current_result; - } - }); - // extract result element - for_each(x, [&](indices_t idx){ - distributed_tile* result = (distributed_tile*)tmap_.at(x); - unsigned vector_size = gcd(result->axis(ld).contiguous, alignment); - unsigned linear = result->get_linear_index(idx); - unsigned id = linear / vector_size; -// Value *tmp = builder_->CreateExtractValue(packets.at(id), {(linear % vector_size) / 2}); -// Value *res = builder_->CreateExtractElement(tmp, (linear % vector_size) % 2); -// result->set_value(idx, res); - result->set_value(idx, builder_->CreateExtractElement(packets.at(id), linear % vector_size)); - }); +} +void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) { + visit_load_inst(x); +} +void generator::visit_masked_load_inst(ir::masked_load_inst* x) { + visit_load_inst(x); } -void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* st) { - for_each(st->get_pointer_operand(), [&](indices_t idx){ - Value *ptr = get_value(st->get_pointer_operand(), idx); - Value *val = get_value(st->get_value_operand(), idx); - builder_->CreateStore(val, ptr); - }); -} - - - -void generator::visit_masked_store_inst(ir::masked_store_inst* st) { - distributed_tile* ptrs = (distributed_tile*)tmap_.at(st->get_pointer_operand()); - distributed_tile* masks = (distributed_tile*)tmap_.at(st->get_mask_operand()); +/** + * \brief Code Generation for a (synchronous) `store` + */ +void generator::visit_store_inst(ir::store_inst * x){ + ir::masked_store_inst *mx = dynamic_cast(x); + // operands + ir::value *ptr_op = x->get_pointer_operand(); + ir::value *val_op = x->get_value_operand(); // vector size - int vector_size = 1; - int ld = ptrs->get_order()[0]; - unsigned alignment = alignment_->get(st->get_pointer_operand(), ld); - vector_size = gcd(ptrs->axis(ld).contiguous, alignment); - // create packets - std::map packets; - ir::value *arg = st->get_value_operand(); - for_each(arg, [&](indices_t idx){ - distributed_tile* in = (distributed_tile*)tmap_.at(arg); - unsigned linear = in->get_linear_index(idx); - unsigned id = linear / vector_size; - Value *in_value = in->get_value(idx); - if(linear % vector_size == 0) - packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size)); - packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size); - }); - // write-back packets - 
for_each(arg, [&](indices_t idx){ - distributed_tile* in = (distributed_tile*)tmap_.at(arg); - unsigned linear = in->get_linear_index(idx); - unsigned id = linear / vector_size; - if(linear % vector_size == 0){ - // fetch tile elements - Value *elt = packets[id]; - Value *ptr = ptrs->get_value(idx); - Value *pred = masks->get_value(idx); - // type information - Type *ty = elt->getType(); - unsigned nbits = ty->getScalarSizeInBits(); - unsigned nbytes = nbits / 8; - // extract pointer offset - std::string offset = ""; - if(GetElementPtrInst *gep = dyn_cast(ptr)) - if(gep->getNumIndices() == 1) - if(ConstantInt *cst = dyn_cast(gep->idx_begin())){ - offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbytes); - ptr = gep->getPointerOperand(); - } - ptr = builder_->CreateBitCast(ptr, ty->getPointerTo(1)); - if(tgt_->is_gpu()){ - // asm argument type - std::vector arg_ty = {pred->getType(), ptr->getType()}; - for(int v = 0; v < vector_size; v++) - arg_ty.push_back(ty->getScalarType()); - // asm function type - FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false); - // asm string - std::string asm_str; - asm_str += "@$0 st.global"; - if(vector_size > 1) - asm_str += ".v" + std::to_string(vector_size); - asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],"; - if(vector_size > 1) - asm_str += "{"; - for(int v = 0; v < vector_size; v++){ - if(v > 0) - asm_str += ", "; - asm_str += "$" + std::to_string(2 + v); - } - if(vector_size > 1) - asm_str += "}"; - asm_str += ";"; - // asm constraint - std::string constraint = "b,l"; - for(int v = 0; v < vector_size; v++){ - constraint += ","; - constraint += (nbits == 32 ? "r" : "h"); - } - // create inline asm - InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); - // call asm - std::vector args = {pred, ptr}; - for(int v = 0; v < vector_size; v++) - args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v))); - builder_->CreateCall(iasm, args); - } - else{ - builder_->CreateMaskedStore(elt, ptr, alignment, builder_->CreateVectorSplat(vector_size, pred)); - } - + size_t vec = 1; + if(val_op->get_type()->is_tile_ty()){ + auto ord = layouts_->get(x->get_pointer_operand())->get_order(); + size_t aln = alignment_->get(ptr_op, ord[0]); + size_t nts = axes_.at(a_axes_->get(x->get_pointer_operand(), ord[0])).contiguous; + vec = std::min(nts, aln); + } + auto idxs = idxs_.at(val_op); + Type *ty = cvt(val_op->get_type()->get_scalar_ty()); + for(size_t i = 0; i < idxs.size(); i += vec){ + auto idx = idxs[i]; + // pointer + Value *ptr = vals_[ptr_op][idx]; + ptr = bit_cast(ptr, vec_ty(ty, vec)->getPointerTo(1)); + // value + Value* val = UndefValue::get(vec_ty(ty, vec)); + for(size_t ii = 0; ii < vec; ii++) + val = insert_elt(val, vals_.at(val_op)[idxs[i + ii]], ii); + if(mx){ + Value *msk = vals_[mx->get_mask_operand()][idx]; + Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {}); + Instruction *term = llvm::SplitBlockAndInsertIfThen(msk, no_op, false); + builder_->SetInsertPoint(term); + store(val, ptr); + builder_->SetInsertPoint(no_op); } - }); + else + store(val, ptr); + } +} +void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* x) { + visit_store_inst(x); +} +void generator::visit_masked_store_inst(ir::masked_store_inst* x) { + visit_store_inst(x); } - -void generator::visit_reshape_inst(ir::reshape_inst* reshape) { - for_each(reshape, [&](indices_t out_idx){ - distributed_tile* result = (distributed_tile*)tmap_.at(reshape); - unsigned pos = 
result->get_linear_index(out_idx); - ir::value* in = reshape->get_operand(0); - distributed_tile *in_tile = (distributed_tile*)tmap_.at(in); - indices_t in_idx = in_tile->get_ordered_indices(pos); - set_value(reshape, out_idx, get_value(in, in_idx)); - }); +/** + * \brief Code Generation for `reshape` + */ +void generator::visit_reshape_inst(ir::reshape_inst* x) { + auto idxs = idxs_.at(x); + for(size_t i = 0; i < idxs_.at(x).size(); i ++){ + ir::value* op = x->get_operand(0); + vals_[x][idxs_[x][i]] = vals_[op][idxs_[op][i]]; + }; } -void generator::visit_splat_inst(ir::splat_inst* splat) { - Value *in = get_value(splat->get_operand(0), {}); - for_each(splat, [&](indices_t idx){ - set_value(splat, idx, in); - }); +/** + * \brief Code Generation for `splat` + */ +void generator::visit_splat_inst(ir::splat_inst* x) { + for(auto idx: idxs_.at(x)) + vals_[x][idx] = vals_[x->get_operand(0)][{}]; } -void generator::visit_broadcast_inst(ir::broadcast_inst* bcast) { - ir::value* in = bcast->get_operand(0); - const auto& in_shapes = in->get_type()->get_tile_shapes(); - distributed_tile *in_tile = (distributed_tile*)tmap_.at(in); - for_each(bcast, [&](indices_t out_idx){ +/** + * \brief Code Generation for `broadcast` + */ +void generator::visit_broadcast_inst(ir::broadcast_inst* x) { + ir::value* op = x->get_operand(0); + const auto& shape = op->get_type()->get_tile_shapes(); + for(auto out_idx: idxs_.at(x)){ indices_t in_idx = out_idx; - for(size_t k = 0; k < in_idx.size(); k++){ - if(in_shapes[k] == 1) - in_idx[k] = builder_->getInt32(0); - } - set_value(bcast, out_idx, in_tile->get_value(in_idx)); - }); + for(size_t k = 0; k < in_idx.size(); k++) + in_idx[k] = shape[k] == 1 ? i32(0) : in_idx[k]; + vals_[x][out_idx] = vals_[op][in_idx]; + } } +/** + * \brief Code Generation for `downcast` + */ void generator::visit_downcast_inst(ir::downcast_inst* x) { - vmap_[x] = tmap_[x->get_operand(0)]->get_value({builder_->getInt32(0)}); + vals_[x][{}] = vals_[x->get_operand(0)][{i32(0)}]; } +/** + * \brief Code Generation for `get_program_id` + */ void generator::visit_get_program_id_inst(ir::get_program_id_inst* pid) { Module *module = builder_->GetInsertBlock()->getModule(); Value *ret = tgt_->get_block_id(module, *builder_, pid->get_axis()); - vmap_[pid] = ret; + vals_[pid][{}] = ret; } +/** + * \brief Code Generation for `get_num_program` + */ void generator::visit_get_num_program_inst(ir::get_num_program_inst* np) { Module *module = builder_->GetInsertBlock()->getModule(); Value *ret = tgt_->get_num_blocks(module, *builder_, np->get_axis()); - vmap_[np] = ret; + vals_[np][{}] = ret; } +/** + * \brief Code Generation for `exp` + */ void generator::visit_exp_inst(ir::exp_inst* x){ - distributed_tile *arg = (distributed_tile*)tmap_.at(x->get_operand(0)); -// Function *fn = builder_->GetInsertBlock()->getParent(); -// Module *module = fn->getParent(); -// Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_); -// Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty}); - Constant *log2e = ConstantFP::get(builder_->getFloatTy(), 1.4426950408889634); - std::vector tys = {builder_->getFloatTy()}; - FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false); + Constant *log2e = ConstantFP::get(f32_ty, 1.4426950408889634); + std::vector tys = {f32_ty}; + FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); InlineAsm *ex2 = InlineAsm::get(fn_ty, "ex2.approx.f32 $0, $1;", "=f,f", false); - - - for_each(x, [&](indices_t idx){ - Value *ex2arg = 
builder_->CreateFMul(arg->get_value(idx), log2e); - set_value(x, idx, builder_->CreateCall(ex2, std::vector{ex2arg})); - }); + for(auto idx: idxs_.at(x)){ + Value *ex2arg = fmul(vals_[x->get_operand(0)][idx], log2e); + vals_[x][idx] = call(ex2, std::vector{ex2arg}); + } } +/** + * \brief Code Generation for `log` + */ void generator::visit_log_inst(ir::log_inst* x){ - distributed_tile *arg = (distributed_tile*)tmap_.at(x->get_operand(0)); -// Function *fn = builder_->GetInsertBlock()->getParent(); -// Module *module = fn->getParent(); -// Type *ty = llvm_type(x->get_type()->get_scalar_ty(), *ctx_); -// Function *ex2 = Intrinsic::getDeclaration(module, Intrinsic::nvvm_ex2_approx_ftz_f, {ty}); - Constant *rcplog2e = ConstantFP::get(builder_->getFloatTy(), 0.6931471805599453); - std::vector tys = {builder_->getFloatTy()}; - FunctionType *fn_ty = FunctionType::get(builder_->getFloatTy(), tys, false); + Constant *rcplog2e = ConstantFP::get(f32_ty, 0.6931471805599453); + std::vector tys = {f32_ty}; + FunctionType *fn_ty = FunctionType::get(f32_ty, tys, false); InlineAsm *lg2 = InlineAsm::get(fn_ty, "lg2.approx.f32 $0, $1;", "=f,f", false); - - - for_each(x, [&](indices_t idx){ - Value *lg2arg = builder_->CreateCall(lg2, std::vector{arg->get_value(idx)}); - set_value(x, idx, builder_->CreateFMul(lg2arg, rcplog2e)); - }); + for(auto idx: idxs_.at(x)){ + Value *lg2arg = call(lg2, std::vector{vals_[x->get_operand(0)][idx]}); + vals_[x][idx] = fmul(lg2arg, rcplog2e); + } } +/** + * \brief Code Generation for `atomic_cas` + */ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); Value *tid = tgt_->get_local_id(module, *builder_, 0); - Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0)); + Value *pred = icmp_eq(tid, i32(0)); BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); - tgt_->add_barrier(module, *builder_); + add_barrier(); tgt_->add_memfence(module, *builder_); - builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb); + cond_br(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); - Value *cas_ptr = vmap_.at(cas->get_operand(0)); - Value *cas_cmp = vmap_.at(cas->get_operand(1)); - Value *cas_val = vmap_.at(cas->get_operand(2)); - Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, - AtomicOrdering::Monotonic, - AtomicOrdering::Monotonic); - old = builder_->CreateExtractValue(old, std::vector{0}); + Value *cas_ptr = vals_[cas->get_operand(0)][{}]; + Value *cas_cmp = vals_[cas->get_operand(1)][{}]; + Value *cas_val = vals_[cas->get_operand(2)][{}]; + Value *old = atomic_cmp_xchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic); + old = extract_val(old, std::vector{0}); Value *atom_ptr; - atom_ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(cas))))); - atom_ptr = builder_->CreateBitCast(atom_ptr, PointerType::get(old->getType(), 3)); - - builder_->CreateStore(old, atom_ptr); - builder_->CreateBr(tid_0_done_bb); + atom_ptr = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(cas)))), ""); + atom_ptr = bit_cast(atom_ptr, ptr_ty(old->getType(), 3)); + store(old, atom_ptr); + br(tid_0_done_bb); builder_->SetInsertPoint(tid_0_done_bb); tgt_->add_memfence(module, *builder_); - tgt_->add_barrier(module, *builder_); - vmap_[cas] = 
builder_->CreateLoad(atom_ptr); + add_barrier(); + vals_[cas][{}] = load(atom_ptr); } +/** + * \brief Code Generation for `atomic_exch` + */ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) { BasicBlock *current = builder_->GetInsertBlock(); Module *module = current->getModule(); - Value *rmw_ptr = vmap_.at(xchg->get_operand(0)); - Value *rmw_val = vmap_.at(xchg->get_operand(1)); + Value *rmw_ptr = vals_[xchg->get_operand(0)][{}]; + Value *rmw_val = vals_[xchg->get_operand(1)][{}]; Value *tid = tgt_->get_local_id(module, *builder_, 0); - Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0)); + Value *pred = icmp_eq(tid, i32(0)); BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); tgt_->add_memfence(module, *builder_); - tgt_->add_barrier(module, *builder_); - builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb); + add_barrier(); + cond_br(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); - builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, - AtomicOrdering::Monotonic, - SyncScope::System); - builder_->CreateBr(tid_0_done_bb); + atomic_rmw(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System); + br(tid_0_done_bb); builder_->SetInsertPoint(tid_0_done_bb); tgt_->add_memfence(module, *builder_); } +/** + * \brief Code Generation for `atomic_add` + */ +//TODO: clean-up void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { - if(add->get_type()->is_tile_ty()){ ir::value* ptr = add->get_operand(0); ir::value* val = add->get_operand(1); ir::value* msk = add->get_operand(2); - distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr); - distributed_tile* vals = (distributed_tile*)tmap_.at(val); - distributed_tile* msks = (distributed_tile*)tmap_.at(msk); - + // vector size - int vector_size = 1; - int ld = ptrs->get_order()[0]; + int vec = 1; + int ld = layouts_->get(ptr)->get_order()[0]; unsigned alignment = alignment_->get(ptr, ld); - vector_size = gcd(ptrs->axis(ld).contiguous, alignment); - vector_size = std::min(vector_size, val->get_type()->get_tile_element_ty()->is_half_ty() ? 2 : 1); + vec = std::min(layouts_->get(ptr)->to_scanline()->nts(ld), alignment); + vec = std::min(vec, val->get_type()->get_tile_element_ty()->is_half_ty() ? 
2 : 1); - std::map packets; - for_each(val, [&](indices_t idx){ - unsigned linear = vals->get_linear_index(idx); - unsigned id = linear / vector_size; - Value *in_value = vals->get_value(idx); - if(linear % vector_size == 0) - packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size)); - packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size); - }); - - for_each(ptr, [&](indices_t idx){ - unsigned linear = vals->get_linear_index(idx); - unsigned id = linear / vector_size; - if(linear % vector_size != 0) - return; - // num bytes - Value *rmw_ptr = ptrs->get_value(idx); - Value *rmw_msk = msks->get_value(idx); - Value *rmw_val = packets[id]; - if(vector_size == 1) - rmw_val = builder_->CreateExtractElement(rmw_val, builder_->getInt32(0)); + for(int i = 0; i < idxs_.at(val).size(); i += vec){ + auto idx = idxs_[val][i]; + Value *rmw_val = UndefValue::get(vec_ty(vals_[val][idx]->getType(), vec)); + for(int ii = 0; ii < vec; ii++) + rmw_val = insert_elt(rmw_val, vals_[val][idxs_[val][i+ii]], ii); + Value *rmw_ptr = vals_[ptr][idx]; + Value *rmw_msk = vals_[msk][idx]; + if(vec == 1) + rmw_val = extract_elt(rmw_val, i32(0)); Type* ty = rmw_val->getType(); size_t nbits = ty->getScalarSizeInBits(); // extract pointer offset @@ -749,27 +652,27 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8); rmw_ptr = gep->getPointerOperand(); } - rmw_ptr = builder_->CreateBitCast(rmw_ptr, ty->getPointerTo(1)); + rmw_ptr = bit_cast(rmw_ptr, ty->getPointerTo(1)); // asm argument type std::vector arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()}; // asm function type FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false); // asm string - std::string suffix = vector_size == 2 ? "x2" : ""; + std::string suffix = vec == 2 ? "x2" : ""; std::string mod = nbits == 32 ? "" : ".noftz"; std::string asm_str = "@$0 atom.global.gpu.add" + mod + ".f" + std::to_string(nbits) + suffix + " $1, [$2" + offset + "], $3;"; - std::string ty_id = nbits == 32 ? "f" : (vector_size == 1 ? "h" : "r"); + std::string ty_id = nbits == 32 ? "f" : (vec == 1 ? 
"h" : "r"); std::string constraint = "b,=" + ty_id + ",l," + ty_id; // create inline asm InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); // call asm - builder_->CreateCall(iasm, {rmw_msk, rmw_ptr, rmw_val}); - }); + call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); + } } else{ - Value *rmw_ptr = vmap_.at(add->get_operand(0)); - Value *rmw_val = vmap_.at(add->get_operand(1)); - Value *rmw_msk = vmap_.at(add->get_operand(2)); + Value *rmw_ptr = vals_[add->get_operand(0)][{}]; + Value *rmw_val = vals_[add->get_operand(1)][{}]; + Value *rmw_msk = vals_[add->get_operand(2)][{}]; Type* ty = rmw_val->getType(); size_t nbits = ty->getScalarSizeInBits(); std::vector arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()}; @@ -783,538 +686,793 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) { Module *module = current->getModule(); Value *tid = tgt_->get_local_id(module, *builder_, 0); - Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0)); + Value *pred = icmp_eq(tid, i32(0)); BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); tgt_->add_memfence(module, *builder_); - tgt_->add_barrier(module, *builder_); - builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb); + add_barrier(); + cond_br(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); - builder_->CreateCall(iasm, {rmw_msk, rmw_ptr, rmw_val}); - builder_->CreateBr(tid_0_done_bb); + call(iasm, (ArrayRef{rmw_msk, rmw_ptr, rmw_val})); + br(tid_0_done_bb); builder_->SetInsertPoint(tid_0_done_bb); tgt_->add_memfence(module, *builder_); } } -void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK) { - const auto& shapes = dot->get_type()->get_tile_shapes(); - machine_mma884_layout* hmma = (machine_mma884_layout*)machine_layouts_.at(layouts_->get(dot)); - TA->set_vector_size(4*hmma->pack_size_0_); - TB->set_vector_size(4*hmma->pack_size_1_); - TA->set_return_mode(true); - TB->set_return_mode(true); +/** + * \brief Code Generation for `mma.884` (V100) + */ +//TODO: clean-up +void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { + // shapes + auto shape_c = C->get_type()->get_tile_shapes(); + auto shape_a = A->get_type()->get_tile_shapes(); + auto shape_b = B->get_type()->get_tile_shapes(); + // order + auto ord_a = layouts_->get(A)->get_order(); + auto ord_b = layouts_->get(B)->get_order(); + // layouts + analysis::mma_layout* layout_c = layouts_->get(C)->to_mma(); + analysis::shared_layout* layout_a = layouts_->get(A)->to_shared(); + analysis::shared_layout* layout_b = layouts_->get(B)->to_shared(); + // vectorization + int vec_a = swizzle_->get_vec(layout_a); + int vec_b = swizzle_->get_vec(layout_b); + // strides + bool is_a_row = ord_a[0] != 0; + bool is_b_row = ord_b[0] != 0; + int stride_am = is_a_row ? shape_a[1] : 1; + int stride_ak = is_a_row ? 1 : shape_a[0]; + int stride_a0 = is_a_row ? stride_ak : stride_am; + int stride_a1 = is_a_row ? stride_am : stride_ak; + int stride_bn = is_b_row ? 1 : shape_b[0]; + int stride_bk = is_b_row ? shape_b[1] : 1; + int stride_b0 = is_b_row ? stride_bn : stride_bk; + int stride_b1 = is_b_row ? 
stride_bk : stride_bn; + int stride_rep_m = layout_c->wpt(0) * layout_c->fpw(0) * 8; + int stride_rep_n = layout_c->wpt(1) * layout_c->fpw(1) * 8; + int stride_rep_k = 1; + // swizzling + int per_phase_a = swizzle_->get_per_phase(layout_a); + int max_phase_a = swizzle_->get_max_phase(layout_a); + int step_a0 = is_a_row ? stride_rep_k : stride_rep_m; + int num_ptr_a = std::max(2 * per_phase_a * max_phase_a / step_a0, 1); + int per_phase_b = swizzle_->get_per_phase(layout_b); + int max_phase_b = swizzle_->get_max_phase(layout_b); + int step_b0 = is_b_row ? stride_rep_n : stride_rep_k; + int num_ptr_b = std::max(2 * per_phase_b * max_phase_b / step_b0, 1); - std::map, std::vector> fcs; - - for_each(dot, [&](indices_t idx){ - std::vector key(idx.size() - 2); - std::copy(idx.begin() + 2, idx.end(), key.begin()); - fcs[key].push_back(TD->get_value(idx)); - }); - - Type *fp32_ty = builder_->getFloatTy(); - Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2); - Type *fp32_pack8_ty = StructType::get(*ctx_, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}); - FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); - - - Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0); - - auto ord_a = layouts_->get(dot->get_operand(0))->get_order(); - auto ord_b = layouts_->get(dot->get_operand(1))->get_order(); - - bool is_a_trans = is_trans(dot->get_operand(0)); - bool is_b_trans = is_trans(dot->get_operand(1)); - bool is_a_row = is_a_trans ^ (ord_a[0] != 0); - bool is_b_row = is_b_trans ^ (ord_b[0] != 0); - - Value *offset_a_i = hmma->offset_a_i_; - Value *offset_a_k = hmma->offset_a_k_; - if(is_a_row){ - offset_a_i = builder_->CreateAdd(offset_a_i, builder_->CreateURem(u_thread_id, builder_->getInt32(4))); - offset_a_k = builder_->getInt32(0); + /* --------------------------------- */ + /* --- pre-compute pointer lanes --- */ + /* --------------------------------- */ + BasicBlock* curr_bb = builder_->GetInsertBlock(); + BasicBlock* entry = &curr_bb->getParent()->getEntryBlock(); + builder_->SetInsertPoint(entry->getTerminator()); + Value* off_a0 = is_a_row ? offset_a_k_[layout_c] : offset_a_m_[layout_c]; + Value* off_a1 = is_a_row ? offset_a_m_[layout_c] : offset_a_k_[layout_c]; + Value* phase_a = urem(udiv(off_a1, i32(per_phase_a)), i32(max_phase_a)); + std::vector off_a(num_ptr_a); + for(int i = 0; i < num_ptr_a; i++){ + Value* off_a0i = add(off_a0, i32(i*(is_a_row?4:stride_rep_m))); + off_a0i = exact_udiv(off_a0i, i32(vec_a)); + off_a0i = xor_(off_a0i, phase_a); + off_a0i = mul(off_a0i, i32(vec_a)); + off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1))); } - - Value *offset_b_j = hmma->offset_b_j_; - Value *offset_b_k = hmma->offset_b_k_; - if(!is_b_row){ - offset_b_j = builder_->CreateAdd(offset_b_j, builder_->CreateURem(u_thread_id, builder_->getInt32(4))); - offset_b_k = builder_->getInt32(0); + Value* off_b0 = is_b_row ? offset_b_n_[layout_c] : offset_b_k_[layout_c]; + Value* off_b1 = is_b_row ? 
offset_b_k_[layout_c] : offset_b_n_[layout_c]; + Value* phase_b = urem(udiv(off_b1, i32(per_phase_b)), i32(max_phase_b)); + std::vector off_b(num_ptr_b); + for(int i = 0; i < num_ptr_b; i++){ + Value* off_b0i = add(off_b0, i32(i*(is_b_row?stride_rep_n:4))); + off_b0i = udiv(off_b0i, i32(vec_b)); + off_b0i = xor_(off_b0i, phase_b); + off_b0i = mul(off_b0i, i32(vec_b)); + off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1))); } + builder_->SetInsertPoint(curr_bb); - std::string op_a = is_a_row ? "row" : "col"; - std::string op_b = is_b_row ? "row" : "col"; - - InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4." + op_a + "." + op_b + ".f32.f16.f16.f32 " + /* --------------------------------- */ + /* --- MMA intrinsic --- */ + /* --------------------------------- */ + Type *f16x2_ty = vec_ty(f16_ty, 2); + Type *ret_ty = StructType::get(*ctx_, {f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty}); + std::vector arg_ty = {f16x2_ty, f16x2_ty, f16x2_ty, f16x2_ty, + f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty, f32_ty}; + InlineAsm *mma = InlineAsm::get(FunctionType::get(ret_ty, arg_ty, false), + " mma.sync.aligned.m8n8k4." + + std::string(is_a_row ? "row" : "col") + + "." + + std::string(is_b_row ? "row" : "col") + + ".f32.f16.f16.f32 " "{$0, $1, $2, $3, $4, $5, $6, $7}, " "{$8, $9}, " "{$10, $11}, " "{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false); - analysis::mma884_layout* layout = layouts_->get(dot)->to_mma884(); - - unsigned fpw_0 = layout->fpw(0); - unsigned fpw_1 = layout->fpw(1); - unsigned wts_0 = fpw_0 * 8; - unsigned wts_1 = fpw_1 * 8; - unsigned wpt_0 = layout->wpt(0); - unsigned wpt_1 = layout->wpt(1); - unsigned stride_rep_i = wpt_0 * wts_0; - unsigned stride_rep_j = wpt_1 * wts_1; - unsigned num_rep_i = shapes[0] / stride_rep_i; - unsigned ld_fc = num_rep_i * 2; - for(auto& x: fcs){ - std::vector& fc = x.second; - for(unsigned pack_i = 0; pack_i < hmma->num_packs_0_; pack_i++) - for(unsigned pack_j = 0; pack_j < hmma->num_packs_1_; pack_j++){ - for(unsigned K = 0; K < NK; K += 4){ - Value *_K = builder_->getInt32(K); - Value *current_offset_a_i = builder_->CreateAdd(offset_a_i, builder_->getInt32(pack_i*stride_rep_i*hmma->pack_size_0_)); - Value *current_offset_b_i = builder_->CreateAdd(offset_b_j, builder_->getInt32(pack_j*stride_rep_j*hmma->pack_size_1_)); - indices_t idx_a = {current_offset_a_i, builder_->CreateAdd(offset_a_k, _K)}; - indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i}; - idx_a.insert(idx_a.end(), x.first.begin(), x.first.end()); - idx_b.insert(idx_b.end(), x.first.begin(), x.first.end()); - - Value *ha = TA->get_value(idx_a); - Value *hb = TB->get_value(idx_b); - for(unsigned ii = 0; ii < hmma->pack_size_0_; ii++) - for(unsigned jj = 0; jj < hmma->pack_size_1_; jj++){ - Value *ha0 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*hmma->pack_size_0_ + 0)), fp16x2_ty); - Value *ha1 = builder_->CreateBitCast(builder_->CreateExtractElement(ha, builder_->getInt32(ii*hmma->pack_size_0_ + 1)), fp16x2_ty); - Value *hb0 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*hmma->pack_size_0_ + 0)), fp16x2_ty); - Value *hb1 = builder_->CreateBitCast(builder_->CreateExtractElement(hb, builder_->getInt32(jj*hmma->pack_size_0_ + 1)), fp16x2_ty); - std::vector idx = { - (pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 0)*ld_fc, - 
(pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 1)*ld_fc, - (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 0)*ld_fc, - (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 1)*ld_fc, - (pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc, - (pack_i*2*hmma->pack_size_0_ + ii*2 + 0) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc, - (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 2)*ld_fc, - (pack_i*2*hmma->pack_size_0_ + ii*2 + 1) + (pack_j*4*hmma->pack_size_1_ + jj*4 + 3)*ld_fc - }; - Value *nc = builder_->CreateCall(mma_fn, std::vector{ha0, ha1, hb0, hb1, fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]], fc[idx[4]], fc[idx[5]], fc[idx[6]], fc[idx[7]]}); - fc[idx[0]] = builder_->CreateExtractValue(nc, std::vector{0}); - fc[idx[1]] = builder_->CreateExtractValue(nc, std::vector{1}); - fc[idx[2]] = builder_->CreateExtractValue(nc, std::vector{2}); - fc[idx[3]] = builder_->CreateExtractValue(nc, std::vector{3}); - fc[idx[4]] = builder_->CreateExtractValue(nc, std::vector{4}); - fc[idx[5]] = builder_->CreateExtractValue(nc, std::vector{5}); - fc[idx[6]] = builder_->CreateExtractValue(nc, std::vector{6}); - fc[idx[7]] = builder_->CreateExtractValue(nc, std::vector{7}); + std::vector ptr_a(num_ptr_a); + std::vector ptr_b(num_ptr_b); + std::map, std::pair> has, hbs; + for(int i = 0; i < num_ptr_a; i++) + ptr_a[i] = gep(shmems_[A], off_a[i]); + for(int i = 0; i < num_ptr_b; i++) + ptr_b[i] = gep(shmems_[B], off_b[i]); + + + // initialize accumulators + std::vector acc; + for(indices_t idx: idxs_.at(C)) + acc.push_back(vals_[D][idx]); + + // update accumulators + unsigned num_m = layout_c->rep(0) * shape_c[0] / layout_c->spt(0); + unsigned num_n = layout_c->rep(1) * shape_c[1] / layout_c->spt(1); + for(unsigned m = 0; m < num_m/2; m++) + for(unsigned n = 0; n < num_n/2; n++) + for(unsigned K = 0; K < NK; K += 4){ + if(has.find({m, K}) == has.end()){ + Value* ptra = ptr_a[(is_a_row ? K/4 : m) % num_ptr_a]; + int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); + int step_ak = is_a_row ? K / (num_ptr_a*vec_a)*(num_ptr_a*vec_a) : K; + Value* pa = gep(ptra, i32(step_am*stride_rep_m*stride_am + step_ak*stride_ak)); + Value* ha = load(bit_cast(pa, ptr_ty(vec_ty(i32_ty, vec_a/2), 3))); + Value *ha00 = bit_cast(extract_elt(ha, i32(0)), f16x2_ty); + Value *ha01 = bit_cast(extract_elt(ha, i32(1)), f16x2_ty); + has[{m, K}] = {ha00, ha01}; + if(vec_a > 4){ + Value *ha10 = bit_cast(extract_elt(ha, i32(2)), f16x2_ty); + Value *ha11 = bit_cast(extract_elt(ha, i32(3)), f16x2_ty); + if(is_a_row) + has[{m, K+4}] = {ha10, ha11}; + else + has[{m+1, K}] = {ha10, ha11}; } } + if(hbs.find({n, K}) == hbs.end()){ + Value* ptrb = ptr_b[(is_b_row? n : K/4) % num_ptr_b]; + int stepbn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; + int stepbk = is_b_row ? 
K : K / (num_ptr_b*vec_b)*(num_ptr_b*vec_b); + Value* pb = gep(ptrb, i32(stepbn*stride_rep_n*stride_bn + stepbk*stride_bk)); + Value* hb = load(bit_cast(pb, ptr_ty(vec_ty(i32_ty, vec_b/2), 3))); + Value *hb00 = bit_cast(extract_elt(hb, i32(0)), f16x2_ty); + Value *hb01 = bit_cast(extract_elt(hb, i32(1)), f16x2_ty); + hbs[{n, K}] = {hb00, hb01}; + if(vec_b > 4){ + Value *hb10 = bit_cast(extract_elt(hb, i32(2)), f16x2_ty); + Value *hb11 = bit_cast(extract_elt(hb, i32(3)), f16x2_ty); + if(is_b_row) + hbs[{n+1, K}] = {hb10, hb11}; + else + hbs[{n, K+4}] = {hb10, hb11}; + } } + auto ha = has[{m, K}]; + auto hb = hbs[{n, K}]; + // arguments + std::vector idx = { + (m*2 + 0) + (n*4 + 0)*num_m, (m*2 + 0) + (n*4 + 1)*num_m, + (m*2 + 1) + (n*4 + 0)*num_m, (m*2 + 1) + (n*4 + 1)*num_m, + (m*2 + 0) + (n*4 + 2)*num_m, (m*2 + 0) + (n*4 + 3)*num_m, + (m*2 + 1) + (n*4 + 2)*num_m, (m*2 + 1) + (n*4 + 3)*num_m + }; + std::vector args = {ha.first, ha.second, hb.first, hb.second}; + for(unsigned i = 0; i < 8; i++) + args.push_back(acc[idx[i]]); + // execute mma + Value *nc = call(mma, args); + // unpack + for(unsigned i = 0; i < 8; i++) + acc[idx[i]] = extract_val(nc, {i}); + } + + // write back accumulators + for(size_t i = 0; i < idxs_.at(C).size(); i++) + vals_[C][idxs_[C][i]] = acc[i]; +} + +/** + * \brief Code Generation for `mma.16816` (A100) + */ +//TODO: clean-up +void generator::visit_mma16816(ir::dot_inst* dot, ir::value *A, ir::value *B, ir::value *D, unsigned NK) { + const auto& shapes = dot->get_type()->get_tile_shapes(); + + std::map, std::vector> fcs; + + for(indices_t idx: idxs_.at(dot)){ + std::vector key(idx.size() - 2); + std::copy(idx.begin() + 2, idx.end(), key.begin()); + fcs[key].push_back(vals_[D][idx]); + }; + + auto shape_a = A->get_type()->get_tile_shapes(); + auto shape_b = B->get_type()->get_tile_shapes(); + auto ord_a = layouts_->get(A)->get_order(); + auto ord_b = layouts_->get(B)->get_order(); + analysis::mma_layout* layout = layouts_->get(dot)->to_mma(); + analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(dot->get_operand(0)); + analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(dot->get_operand(1)); + bool is_a_row = ord_a[0] == 1; + bool is_b_row = ord_b[0] == 1; + std::string a_trans = is_a_row ? "" : ".trans"; + std::string b_trans = is_b_row ? ".trans" : ""; + int stride_a_m = is_a_row ? shape_a[1] : 1; + int stride_a_k = is_a_row ? 1 : shape_a[0]; + int stride_b_n = is_b_row ? 1 : shape_b[0]; + int stride_b_k = is_b_row ? shape_b[1] : 1; + int stride_a0 = is_a_row ? stride_a_k : stride_a_m; + int stride_a1 = is_a_row ? stride_a_m : stride_a_k; + int stride_b0 = is_b_row ? stride_b_n : stride_b_k; + int stride_b1 = is_b_row ? stride_b_k : stride_b_n; + int lda = is_a_row ? stride_a_m : stride_a_k; + int ldb = is_b_row ? 
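// Sketch (illustrative only): the mapping from an (m, n) tile position to the eight
// f32 accumulator slots fed to one m8n8k4 mma above; num_m is the per-thread leading
// dimension of the accumulator array. Function and parameter names are hypothetical.
#include <array>

std::array<int, 8> mma884_acc_indices(unsigned m, unsigned n, unsigned num_m) {
  auto at = [&](unsigned row, unsigned col) {
    return static_cast<int>((m * 2 + row) + (n * 4 + col) * num_m);
  };
  // Same ordering as the idx vector built above: two column pairs of a 2x4 patch.
  return { at(0, 0), at(0, 1), at(1, 0), at(1, 1),
           at(0, 2), at(0, 3), at(1, 2), at(1, 3) };
}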
stride_b_k : stride_b_n; + int per_phase_a = swizzle_->get_per_phase(layout_a); + int max_phase_a = swizzle_->get_max_phase(layout_a); + int per_phase_b = swizzle_->get_per_phase(layout_b); + int max_phase_b = swizzle_->get_max_phase(layout_b); + int num_ptr_a = 8; + int num_ptr_b = 8; + int vec_a = 8; + int vec_b = 8; + + + + Type *fp32_ty = f32_ty; + Type *fp16x2_ty = vec_ty(f16_ty, 2); + Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); + Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector{fp32_ty, fp32_ty, fp32_ty, fp32_ty}); + FunctionType *ld_x4_ty = FunctionType::get(fp16x2_pack4_ty, std::vector{ptr_ty(f16_ty, 3)}, false); + + // left-hand-side values + std::map, std::pair> ha; + std::map, Value*> hb; + + + BasicBlock* CurrBB = builder_->GetInsertBlock(); + BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); + builder_->SetInsertPoint(FirstBB->getTerminator()); + + Value* thread = tgt_->get_local_id(mod_, *builder_, 0); + Value *lane = urem(thread, i32(32)); + Value *warp = udiv(thread, i32(32)); + Value *warp12 = udiv(warp, i32(layout->wpt(0))); + Value *warp0 = urem(warp, i32(layout->wpt(0))); + Value *warp1 = urem(warp12, i32(layout->wpt(1))); + std::vector& fc = fcs.begin()->second; + + Value *tidr8 = urem(lane, i32(8)); + Value *phase_a = urem(udiv(tidr8, i32(per_phase_a)), i32(max_phase_a)); + Value* off_a0 = mul(tidr8, i32(lda)); + Value *off_am = mul(add(urem(udiv(lane, i32(8)), i32(2)), mul(warp0, i32(2))), i32(8)); + Value *off_ak = mul(udiv(lane, i32(16)), i32(8)); + off_am = urem(off_am, i32(shape_a[0])); + off_ak = urem(off_ak, i32(shape_a[1])); + off_a0 = add(off_a0, is_a_row ? off_ak : off_am); + Value* off_a1 = is_a_row ? off_am : off_ak; + std::vector off_a(num_ptr_a); + for(int i = 0; i < num_ptr_a; i++){ + Value* off_a0i = add(off_a0, i32(i*16*(is_a_row?1:layout->wpt(0)))); + off_a0i = exact_udiv(off_a0i, i32(vec_a)); + off_a0i = xor_(off_a0i, phase_a); + off_a0i = mul(off_a0i, i32(vec_a)); + off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1))); + } + + Value *phase_b = urem(udiv(tidr8, i32(per_phase_b)), i32(max_phase_b)); + Value* off_b0 = mul(tidr8, i32(ldb)); + Value *off_bn = mul(add(mul(udiv(lane, i32(16)), i32(layout->wpt(1))), mul(warp1, i32(1))), i32(8)); + Value *off_bk = mul(urem(udiv(lane, i32(8)), i32(2)), i32(8)); + off_bn = urem(off_bn, i32(shape_b[1])); + off_bk = urem(off_bk, i32(shape_b[0])); + off_b0 = add(off_b0, is_b_row ? off_bn : off_bk); + Value* off_b1 = is_b_row ? 
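// Sketch (illustrative only): the thread-id decomposition used above to place each
// warp inside the (wpt0 x wpt1) warp tile; wpt0/wpt1 stand in for layout->wpt(0/1).
struct WarpCoord { int lane, warp0, warp1; };

WarpCoord decompose_thread(int thread, int wpt0, int wpt1) {
  int lane  = thread % 32;           // position inside the warp
  int warp  = thread / 32;           // linear warp index inside the CTA
  int warp0 = warp % wpt0;           // warp coordinate along M
  int warp1 = (warp / wpt0) % wpt1;  // warp coordinate along N
  return { lane, warp0, warp1 };
}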
off_bk : off_bn; + std::vector off_b(num_ptr_b); + for(int i = 0; i < num_ptr_b; i++){ + Value* off_b0i = add(off_b0, i32(i*(is_b_row?8*layout->wpt(1):16))); + off_b0i = exact_udiv(off_b0i, i32(vec_b)); + off_b0i = xor_(off_b0i, phase_b); + off_b0i = mul(off_b0i, i32(vec_b)); + off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1))); + } + + builder_->SetInsertPoint(CurrBB); + // A pointer + std::vector ptrs_a(num_ptr_a); + for(int i = 0; i < num_ptr_a; i++) + ptrs_a[i] = gep(shmems_[A], {off_a[i]}); + // B pointer + std::vector ptrs_b(num_ptr_b); + for(int i = 0; i < num_ptr_b; i++) + ptrs_b[i] = gep(shmems_[B], {off_b[i]}); + + FunctionType *mma_ty = FunctionType::get(fp32_pack4_ty, std::vector{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false); + InlineAsm *mma_fn = InlineAsm::get(mma_ty, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{$0, $1, $2, $3}, " + "{$4, $5, $6, $7}, " + "{$8, $9}, " + "{$10, $11, $12, $13};", "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", false); + unsigned num_rep_0 = shapes[0] / layout->spt(0); + unsigned num_rep_1 = shapes[1] / layout->spt(1); + for(unsigned K = 0; K < NK; K += 16) + for(unsigned m = 0; m < num_rep_0; m++) + for(unsigned n = 0; n < num_rep_1; n++){ + if(ha.find({m, K}) == ha.end()){ + Value* ptra = ptrs_a[(is_a_row ? K/16 : m) % num_ptr_a]; + int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a); + int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K; + InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 " + "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];", "=r,=r,=r,=r,r", false); + Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra}); + Value *ha0 = extract_val(haa, std::vector{0}); + Value *ha1 = extract_val(haa, std::vector{1}); + Value *ha2 = extract_val(haa, std::vector{2}); + Value *ha3 = extract_val(haa, std::vector{3}); + ha[{m, K}] = std::make_pair(ha0, ha1); + ha[{m, K+8}] = std::make_pair(ha2, ha3); + } + if(hb.find({n, K})==hb.end()){ + Value* ptrb = ptrs_b[(is_b_row ? n : K/16) % num_ptr_b]; + int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n; + int step_bk = is_b_row ? 
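// Sketch (illustrative only): the byte offset folded into the ldmatrix immediate above
// for operand A; the factor of 2 converts fp16 element counts into bytes. wpt0 stands
// in for layout->wpt(0); all names are local to this example.
int ldmatrix_a_byte_offset(int step_am, int step_ak,
                           int wpt0, int stride_a_m, int stride_a_k) {
  return 2 * step_am * 16 * wpt0 * stride_a_m  // advance along M: 16 rows per warp tile
       + 2 * step_ak * stride_a_k;             // advance along K
}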
K : K / (num_ptr_b*8)*(num_ptr_b*8); + InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 " + "{$0, $1, $2, $3}, [$4 + " + std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];", "=r,=r,=r,=r,r", false); + Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb}); + Value *hb0 = extract_val(hbb, std::vector{0}); + Value *hb1 = extract_val(hbb, std::vector{1}); + Value *hb2 = extract_val(hbb, std::vector{2}); + Value *hb3 = extract_val(hbb, std::vector{3}); + hb[{n, K}] = hb0; + hb[{n+1, K}] = hb2; + hb[{n, K+8}] = hb1; + hb[{n+1, K+8}] = hb3; + } + unsigned cols_per_thread = num_rep_0 * 2; + std::vector idx = { + (m*2 + 0) + (n*2 + 0)*cols_per_thread, + (m*2 + 0) + (n*2 + 1)*cols_per_thread, + (m*2 + 1) + (n*2 + 0)*cols_per_thread, + (m*2 + 1) + (n*2 + 1)*cols_per_thread + }; + Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second, + hb[{n, K}], hb[{n, K+8}], + fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]}); + fc[idx[0]] = extract_val(nc, std::vector{0}); + fc[idx[1]] = extract_val(nc, std::vector{1}); + fc[idx[2]] = extract_val(nc, std::vector{2}); + fc[idx[3]] = extract_val(nc, std::vector{3}); } // write back unsigned i = 0; - for_each(dot, [&](indices_t idx){ + for(indices_t idx: idxs_.at(dot)){ std::vector key(idx.size() - 2); std::copy(idx.begin() + 2, idx.end(), key.begin()); if(i >= fcs.at(key).size()) i = 0; - set_value(dot, idx, fcs.at(key)[i++]); - }); - - TA->set_return_mode(false); - TB->set_return_mode(false); - + vals_[dot][idx] = fcs.at(key)[i++]; + }; } -void generator::visit_scanline_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *TB, distributed_tile *TD, unsigned NK, - Type *c_ty, Function *f_mul_add) { - TA->set_vector_size(TD->axis(0).contiguous); - TB->set_vector_size(TD->axis(1).contiguous); - for_each(dot, [&](indices_t idx){ - Value *res = TD->get_value(idx); - for(unsigned K = 0; K < NK; ++K){ - // input indices - indices_t a_idx = {idx[0], builder_->getInt32(K)}; - indices_t b_idx = {builder_->getInt32(K), idx[1]}; - // add batching dimension - for(size_t i = 2; i < idx.size(); i++){ - a_idx.insert(a_idx.end(), idx[i]); - b_idx.insert(b_idx.end(), idx[i]); + +/** + * \brief Code Generation for FMA-based `dot` (FP32, FP64, Default) + */ +void generator::visit_fmadot(ir::dot_inst* C, ir::value* A, ir::value* B, ir::value* D, unsigned NK, Type *c_ty, Function *f_mul_add) { + auto shape_c = C->get_type()->get_tile_shapes(); + auto shape_a = A->get_type()->get_tile_shapes(); + auto shape_b = B->get_type()->get_tile_shapes(); + auto ord_a = layouts_->get(A)->get_order(); + auto ord_b = layouts_->get(B)->get_order(); + analysis::scanline_layout* layout_c = layouts_->get(C)->to_scanline(); + analysis::shared_layout* layout_a = (analysis::shared_layout*)layouts_->get(C->get_operand(0)); + analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1)); + bool is_a_row = ord_a[0] == 1; + bool is_b_row = ord_b[0] == 1; + std::string a_trans = is_a_row ? "" : ".trans"; + std::string b_trans = is_b_row ? ".trans" : ""; + int stride_a_m = is_a_row ? shape_a[1] : 1; + int stride_a_k = is_a_row ? 1 : shape_a[0]; + int stride_b_n = is_b_row ? 1 : shape_b[0]; + int stride_b_k = is_b_row ? shape_b[1] : 1; + int stride_a0 = is_a_row ? stride_a_k : stride_a_m; + int stride_a1 = is_a_row ? stride_a_m : stride_a_k; + int stride_b0 = is_b_row ? stride_b_n : stride_b_k; + int stride_b1 = is_b_row ? 
stride_b_k : stride_b_n; + int lda = is_a_row ? stride_a_m : stride_a_k; + int ldb = is_b_row ? stride_b_k : stride_b_n; + int per_phase_a = swizzle_->get_per_phase(layout_a); + int max_phase_a = swizzle_->get_max_phase(layout_a); + int per_phase_b = swizzle_->get_per_phase(layout_b); + int max_phase_b = swizzle_->get_max_phase(layout_b); + int num_ptr_a = 8; + int num_ptr_b = 8; + int vec_a = 2; + int vec_b = 4; + distributed_axis ax_m = axes_.at(a_axes_->get(C, 0)); + distributed_axis ax_n = axes_.at(a_axes_->get(C, 1)); +// Value* thread = tgt_->get_local_id(mod_, *builder_, 0); + + Value* off_a0 = is_a_row ? i32(0) : mul(ax_m.thread_id, i32(ax_m.contiguous)); + Value* off_a1 = is_a_row ? mul(ax_m.thread_id, i32(ax_m.contiguous)): i32(0); + std::vector off_a(num_ptr_a); + for(int i = 0; i < num_ptr_a; i++){ +// Value* off_a0i = add(off_a0, i32(is_a_row ? vec_a : layout_c->mts(0)*vec_a)); +// off_a0i = exact_udiv(off_a0i, i32(vec_a)); +// off_a0i = xor_(off_a0i, phase_a); +// off_a0i = mul(off_a0i, i32(vec_a)); + off_a[i] = add(mul(off_a0, i32(stride_a0)), mul(off_a1, i32(stride_a1))); + } + Value* off_b0 = is_b_row ? mul(ax_n.thread_id, i32(ax_n.contiguous)): i32(0); + Value* off_b1 = is_b_row ? i32(0) : mul(ax_n.thread_id, i32(ax_n.contiguous)); + std::vector off_b(num_ptr_b); + for(int i = 0; i < num_ptr_b; i++){ +// Value* off_b0i = add(off_b0, i32(is_b_row ? layout_c->mts(1)*vec_b : vec_b)); +// off_b0i = exact_udiv(off_b0i, i32(vec_b)); +// off_b0i = xor_(off_b0i, phase_b); +// off_b0i = mul(off_b0i, i32(vec_b)); + off_b[i] = add(mul(off_b0, i32(stride_b0)), mul(off_b1, i32(stride_b1))); + } + std::vector ptrs_a(num_ptr_a); + for(int i = 0; i < num_ptr_a; i++) + ptrs_a[i] = gep(shmems_[A], off_a[i]); + std::vector ptrs_b(num_ptr_b); + for(int i = 0; i < num_ptr_b; i++) + ptrs_b[i] = gep(shmems_[B], off_b[i]); + + std::map ret = vals_[D]; + std::map, Value*> has, hbs; + for(unsigned k = 0; k < NK; k++){ + int z = 0; + for(unsigned m = 0; m < shape_c[0]; m+=layout_c->mts(0)*layout_c->nts(0)) + for(unsigned n = 0; n < shape_c[1]; n+=layout_c->mts(1)*layout_c->nts(1)) + for(unsigned mm = 0; mm < layout_c->nts(0); mm++) + for(unsigned nn = 0; nn < layout_c->nts(1); nn++) + { + if(has.find({m + mm, k}) == has.end()){ + Value* pa = gep(ptrs_a[0], i32((m + mm)*stride_a_m + k*stride_a_k)); + Value* va = load(pa); + has[{m + mm, k}] = va; } - // load value - Value *a = TA->get_value(a_idx); - Value *b = TB->get_value(b_idx); - if(a->getType() != c_ty) - a = builder_->CreateFPCast(a, c_ty); - if(b->getType() != c_ty) - b = builder_->CreateFPCast(b, c_ty); - res = builder_->CreateCall(f_mul_add, std::vector{a, b, res}); + if(hbs.find({n + nn, k}) == hbs.end()){ + Value* pb = gep(ptrs_b[0], i32((n + nn)*stride_b_n + k*stride_b_k)); + Value* vb = load(pb); + hbs[{n + nn, k}] = vb; + } + ret[idxs_[C].at(z)] = call(f_mul_add, {has[{m+mm,k}], hbs[{n+nn, k}], ret[idxs_[C].at(z)]}); + z++; } - set_value(dot, idx, res); - }); -} - -void generator::visit_outer_dot(ir::dot_inst* dot, distributed_tile *TA, distributed_tile *TB, distributed_tile *TD, unsigned NK, - Type *c_ty, Function *f_mul_add) { - for_each(dot, [&](indices_t idx){ - Value *res = TD->get_value(idx); - indices_t a_idx = {idx[0], builder_->getInt32(0)}; - indices_t b_idx = {builder_->getInt32(0), idx[1]}; - std::swap(a_idx[0], a_idx[1]); - std::swap(b_idx[0], b_idx[1]); - Value *a = TA->get_value(a_idx); - Value *b = TB->get_value(b_idx); - if(a->getType() != c_ty) - a = builder_->CreateFPCast(a, c_ty); - if(b->getType() != c_ty) - b = 
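// Sketch (illustrative only): a scalar reference for the FMA-based dot product emitted
// above. It omits the scanline tiling and per-thread distribution and just shows the
// arithmetic that the generated llvm.fmuladd calls perform, one output element at a time.
#include <vector>

void fma_dot_reference(const std::vector<float>& A, const std::vector<float>& B,
                       std::vector<float>& C, int M, int N, int K) {
  // C[m, n] += A[m, k] * B[k, n], with the K loop outermost as in the generated code.
  for (int k = 0; k < K; k++)
    for (int m = 0; m < M; m++)
      for (int n = 0; n < N; n++)
        C[m * N + n] = A[m * K + k] * B[k * N + n] + C[m * N + n];
}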
builder_->CreateFPCast(b, c_ty); - res = builder_->CreateCall(f_mul_add, std::vector{a, b, res}); - set_value(dot, idx, res); - }); + } + + for(indices_t idx: idxs_.at(C)){ + vals_[C][idx] = ret[idx]; + } } +/** + * \brief Code Generation for `dot` + * Dispatches to appropriate specialized function + */ void generator::visit_dot_inst(ir::dot_inst* dot) { Function *fn = builder_->GetInsertBlock()->getParent(); - Module *module = fn->getParent(); ir::value *A = dot->get_operand(0); ir::value *B = dot->get_operand(1); ir::value *D = dot->get_operand(2); - - distributed_tile *TD = (distributed_tile*)tmap_.at(D); - Type *c_ty = llvm_type(D->get_type()->get_scalar_ty(), *ctx_); + Type *c_ty = cvt(D->get_type()->get_scalar_ty()); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, std::vector{c_ty}); auto A_shapes = A->get_type()->get_tile_shapes(); size_t red_axis = 1; unsigned NK = A_shapes[red_axis]; - - if(NK != 1) { - shared_tile *TA = (shared_tile*)tmap_.at(A); - shared_tile *TB = (shared_tile*)tmap_.at(B); - if(layouts_->get(dot)->to_mma884()) - visit_hmma_dot(dot, TA, TB, TD, NK); - else - visit_scanline_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add); - } - else { - distributed_tile *TA = (distributed_tile*)tmap_.at(A); - distributed_tile *TB = (distributed_tile*)tmap_.at(B); - visit_outer_dot(dot, TA, TB, TD, NK, c_ty, f_mul_add); - } + bool is_outer = NK == 1; + bool is_mma = layouts_->get(dot)->to_mma(); + if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80) + return visit_mma884(dot, A, B, D, NK); + if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80) + return visit_mma16816(dot, A, B, D, NK); + return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add); } void generator::visit_trans_inst(ir::trans_inst* trans) { - shared_tile* in = (shared_tile*)tmap_.at(trans->get_operand(0)); - shared_tile* out = new shared_tile(in->get_ty(), in->get_shapes(), in->get_order(), in->get_pointer(), *builder_, in->get_offset(), trans->get_perm()); - tmap_[trans] = out; + throw std::runtime_error("not supported"); } -void generator::visit_sqrt_inst(ir::sqrt_inst* sqt) { - for_each(sqt, [&](indices_t idx){ - Value *val = get_value(sqt->get_operand(0), idx); - Module* module = builder_->GetInsertBlock()->getModule(); - Value *sqrt = Intrinsic::getDeclaration(module, Intrinsic::sqrt, std::vector{val->getType()}); - Value *ret = builder_->CreateCall(sqrt, std::vector{val}); - set_value(sqt, idx, ret); - }); +/** + * \brief Code Generation for `sqrt` + */ +void generator::visit_sqrt_inst(ir::sqrt_inst* x) { + for(indices_t idx: idxs_.at(x)){ + Value *val = vals_[x->get_operand(0)][idx]; + Value *ret = intrinsic(Intrinsic::sqrt, {val->getType()}, {val}); + vals_[x][idx] = ret; + } } -void generator::visit_reduce_inst(ir::reduce_inst* x) { +Value* generator::shared_off(const std::vector& shapes, const std::vector& order, indices_t idx){ + // strides + std::vector strides(shapes.size(), builder_->getInt32(0)); + strides[order[0]] = builder_->getInt32(1); + for(size_t i = 1; i < idx.size(); i++) + strides[order[i]] = builder_->CreateMul(strides[order[i-1]], builder_->getInt32(shapes[order[i-1]])); + // result + Value *result = builder_->getInt32(0); + for(size_t i = 0; i < idx.size(); i++) + result = builder_->CreateAdd(result, builder_->CreateMul(idx[i], strides[i])); + return result; +} + +/** + * \brief Code Generation for `reduce` (1D case) + */ +void generator::visit_reduce1d_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { std::map partial; ir::value *arg = 
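// Sketch (illustrative only): the linearization performed by shared_off above, on plain
// integers. Strides follow the layout order, with order[0] the contiguous axis; all
// names are local to this example.
#include <cstddef>
#include <vector>

int shared_offset(const std::vector<int>& shapes, const std::vector<int>& order,
                  const std::vector<int>& idx) {
  std::vector<int> strides(shapes.size(), 0);
  strides[order[0]] = 1;
  for (std::size_t i = 1; i < idx.size(); i++)
    strides[order[i]] = strides[order[i - 1]] * shapes[order[i - 1]];
  int result = 0;
  for (std::size_t i = 0; i < idx.size(); i++)
    result += idx[i] * strides[i];
  return result;
}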
x->get_operand(0); - distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg); - ir::reduce_inst::op_t op = x->get_op(); - unsigned axis = x->get_axis(); - - Type *fp32_ty = builder_->getFloatTy(); - FunctionType *fmaxmin_ty = FunctionType::get(fp32_ty, std::vector{fp32_ty, fp32_ty}, false); - InlineAsm *fmin = InlineAsm::get(fmaxmin_ty, "min.ftz.f32 $0, $1, $2;", "=f,f,f", false); - InlineAsm *fmax = InlineAsm::get(fmaxmin_ty, "max.ftz.f32 $0, $1, $2;", "=f,f,f", false); - - auto accumulate = [&](Value* x, Value *y) -> Value* { - switch(op) { - case ir::reduce_inst::ADD: return builder_->CreateAdd(x, y); - case ir::reduce_inst::SUB: return builder_->CreateSub(x, y); - case ir::reduce_inst::MAX:{ - if(x->getType()->isIntegerTy()) - return builder_->CreateSelect(builder_->CreateICmpSGE(x, y), x, y); - else - return builder_->CreateMaxNum(x, y); - } - case ir::reduce_inst::MIN:{ - if(x->getType()->isIntegerTy()) - return builder_->CreateSelect(builder_->CreateICmpSLE(x, y), x, y); - else - return builder_->CreateMinNum(x, y); - } - case ir::reduce_inst::FADD: return builder_->CreateFAdd(x, y); - case ir::reduce_inst::FSUB: return builder_->CreateFSub(x, y); - case ir::reduce_inst::FMAX: return builder_->CreateCall(fmax, std::vector{x, y}); - case ir::reduce_inst::FMIN: return builder_->CreateCall(fmin, std::vector{x, y}); - default: assert(false); return nullptr; - } - }; - - Value *neutral; - switch(op) { - case ir::reduce_inst::ADD: neutral = builder_->getInt32(0); break; - case ir::reduce_inst::SUB: neutral = builder_->getInt32(0); break; - case ir::reduce_inst::MAX: neutral = builder_->getInt32(INT32_MIN); break; - case ir::reduce_inst::MIN: neutral = builder_->getInt32(INT32_MAX); break; - case ir::reduce_inst::FADD: neutral = ConstantFP::get(arg_tile->get_ty(), 0); break; - case ir::reduce_inst::FSUB: neutral = ConstantFP::get(arg_tile->get_ty(), 0); break; - case ir::reduce_inst::FMAX: neutral = ConstantFP::get(arg_tile->get_ty(), -INFINITY); break; - case ir::reduce_inst::FMIN: neutral = ConstantFP::get(arg_tile->get_ty(), INFINITY); break; - default: assert(false); break; - } - - - - analysis::data_layout* arg_layout = layouts_->get(arg); - if(auto* L = dynamic_cast(arg_layout)){ - bool can_optimize = L->get_rank() == 1; - /* - for(size_t r = 0; r < L->get_rank(); r++){ - if(r != axis) - can_optimize = can_optimize && (L->mts(r) == L->get_shape()[r]); - } - */ - if(can_optimize){ - Value *thread_acc = nullptr; - // reduce within thread - arg_tile->for_each([&](indices_t idx) { - Value *current = arg_tile->get_value(idx); - if(thread_acc == nullptr) - thread_acc = current; - else - thread_acc = accumulate(thread_acc, current); - }); - // reduce within wrap - FunctionType *fn_ty = FunctionType::get(thread_acc->getType(), std::vector{thread_acc->getType(), builder_->getInt32Ty()}, false); - InlineAsm *shfl_xor = InlineAsm::get(fn_ty, "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false); - Value *warp_acc = thread_acc; - for(int i = 16; i > 0; i >>= 1) - warp_acc = accumulate(warp_acc, builder_->CreateCall(shfl_xor, std::vector{warp_acc, builder_->getInt32(i)})); - // shared memory pointer - unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace(); - Type *res_ty = arg_tile->get_ty(); - Value *sh_mem_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space)); - Value* u_thread_id = tgt_->get_local_id(builder_->GetInsertBlock()->getModule(), *builder_, 0); - Value* warp_id = builder_->CreateUDiv(u_thread_id, 
builder_->getInt32(32)); - Value *write_ptr = builder_->CreateGEP(sh_mem_ptr, warp_id); - // store warp result in shared memory - tgt_->add_barrier(mod_, *builder_); - builder_->CreateStore(warp_acc, write_ptr); - tgt_->add_barrier(mod_, *builder_); - // accumulate all warps - Value *load_ptr = builder_->CreateGEP(sh_mem_ptr, u_thread_id); - Value* is_first_warp = builder_->CreateICmpEQ(warp_id, builder_->getInt32(0)); - BasicBlock* bb_final_acc = BasicBlock::Create(*ctx_, "bb_final_acc", builder_->GetInsertBlock()->getParent()); - BasicBlock* bb_final_acc_done = BasicBlock::Create(*ctx_, "bb_final_acc_done", builder_->GetInsertBlock()->getParent()); - builder_->CreateCondBr(is_first_warp, bb_final_acc, bb_final_acc_done); - builder_->SetInsertPoint(bb_final_acc); - Value* final_val = builder_->CreateLoad(load_ptr); - for(int i = (num_warps_+1)/2; i > 0; i >>= 1) - final_val = accumulate(final_val, builder_->CreateCall(shfl_xor, std::vector{final_val, builder_->getInt32(i)})); - builder_->CreateStore(final_val, load_ptr); - builder_->CreateBr(bb_final_acc_done); -// // store first warp done - builder_->SetInsertPoint(bb_final_acc_done); - // write back - tgt_->add_barrier(mod_, *builder_); - final_val = builder_->CreateLoad(sh_mem_ptr); - for_each(x, [&](indices_t idx) { - set_value(x, idx, final_val); - }); - return; - } - } + Type *ty = cvt(x->get_type()->get_scalar_ty()); + Value *acc = nullptr; // reduce within thread - arg_tile->for_each([&](indices_t idx) { + for(indices_t idx: idxs_.at(arg)){ + Value *val = vals_[arg][idx]; + acc = !acc ? val : do_acc(acc, val); + } + // reduce within wrap + InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), + "shfl.sync.bfly.b32 $0, $1, $2, 0x1f, 0xffffffff;", "=f,f,r", false); + for(int i = 16; i > 0; i >>= 1) + acc = do_acc(acc, call(shfl, {acc, i32(i)})); + // pointers + unsigned addr_space = shmem_->getType()->getPointerAddressSpace(); + Value *base = bit_cast(shmem_, ptr_ty(ty, addr_space)); + Value* thread = tgt_->get_local_id(mod_, *builder_, 0); + Value* warp = udiv(thread, i32(32)); + Value* lane = urem(thread, i32(32)); + // store warp result in shared memory + add_barrier(); + store(neutral, gep(base, lane)); + add_barrier(); + store(acc, gep(base, warp)); + add_barrier(); + + // reduce across warps + Value *cond = icmp_eq(warp, i32(0)); + Instruction *barrier = add_barrier(); + Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false); + builder_->SetInsertPoint(term); + Value* ret = load(gep(base, thread)); + for(int i = (num_warps_+1)/2; i > 0; i >>= 1){ + Value *current = call(shfl, {ret, i32(i)}); + ret = do_acc(ret, current); + } + store(ret, gep(base, thread)); + + // store first warp done + builder_->SetInsertPoint(barrier->getParent()); + ret = load(base); + for(indices_t idx: idxs_.at(x)) + vals_[x][idx] = ret; +} + +/** + * \brief Code Generation for `reduce` (ND case) + */ +void generator::visit_reducend_inst(ir::reduce_inst* x, std::function do_acc, Value *neutral) { + ir::value *arg = x->get_operand(0); + Type *ty = cvt(x->get_type()->get_scalar_ty()); + unsigned axis = x->get_axis(); + + // reduce within thread + std::map accs; + for(indices_t idx: idxs_.at(arg)){ indices_t pidx = idx; - pidx[axis] = builder_->getInt32(0); - Value *current = arg_tile->get_value(idx); - // current partial result is not initialized -- create - if(partial.find(pidx) == partial.end()) - partial[pidx] = current; - // current partial result is initialized -- accumulate - else - partial[pidx] = 
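// Sketch (illustrative only): the butterfly (xor-shuffle) reduction emitted above with
// shfl.sync.bfly.b32, simulated on the host over the 32 values held by one warp. After
// log2(32) = 5 rounds every lane holds the full reduction, which is the property the
// generated sequence relies on before staging one value per warp in shared memory.
#include <array>
#include <functional>

float warp_butterfly_reduce(std::array<float, 32> lanes,
                            const std::function<float(float, float)>& acc) {
  for (int i = 16; i > 0; i >>= 1) {
    std::array<float, 32> next = lanes;
    for (int lane = 0; lane < 32; lane++)
      next[lane] = acc(lanes[lane], lanes[lane ^ i]);  // exchange with lane ^ i
    lanes = next;
  }
  return lanes[0];
}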
accumulate(partial[pidx], current); - }); + pidx[axis] = i32(0); + Value *current = vals_[arg][idx]; + bool is_first = accs.find(pidx) == accs.end(); + accs[pidx] = is_first ? current : do_acc(accs[pidx], current); + }; // reduce within blocks - machine_data_layout *slayout = machine_layouts_.at(layouts_->get(layouts_->tmp(x))); - shared_tile *stile = (shared_tile*)slayout->create(x); - unsigned depth = stile->get_shapes()[axis]; - - unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace(); - Type *res_ty = arg_tile->get_ty(); - Value *base_ptr = builder_->CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space)); - for(auto& x: partial) { + analysis::data_layout* layout = layouts_->get(layouts_->tmp(x)); + Value *base = shared_ptr_.at(layout); + auto shape = layout->get_shape(); + auto order = layout->get_order(); + int space = base->getType()->getPointerAddressSpace(); + Value *ptr = bit_cast(base, ptr_ty(ty, space)); + Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id; + for(auto& x: accs) { // current element being computed - Value *lane = axes_.at(a_axes_->get(arg, axis)).thread_id; - Value *&result = x.second; + Value *&acc = x.second; indices_t write_idx = x.first; write_idx[axis] = lane; // shared memory write pointer - Value *write_offset = shared_tile::shared_offset(*builder_, stile->get_shapes(), stile->get_perm(), stile->get_order(), write_idx); - Value *write_ptr = builder_->CreateGEP(base_ptr, write_offset); + Value *write_off = shared_off(shape, order, write_idx); + Value *write_ptr = gep(ptr, write_off); // initialize shared memory - tgt_->add_barrier(mod_, *builder_); - builder_->CreateStore(result, write_ptr); + add_barrier(); + store(acc, write_ptr); // build result - for(unsigned i = depth/2; i > 0; i >>= 1){ - // current indices - indices_t current(write_idx.size(), builder_->getInt32(0)); - current[axis] = builder_->getInt32(i); - // shared memory offset - Value *read_offset = shared_tile::shared_offset(*builder_, stile->get_shapes(), stile->get_perm(), stile->get_order(), current); - Value *is_active = builder_->CreateICmpULT(lane, builder_->getInt32(i)); - read_offset = builder_->CreateSelect(is_active, read_offset, builder_->getInt32(0)); - // shared memory read pointer - Value *read_ptr = builder_->CreateGEP(write_ptr, read_offset); - tgt_->add_barrier(mod_, *builder_); - Value *next = builder_->CreateLoad(read_ptr); - // accumulate - result = accumulate(result, next); - // write back - tgt_->add_barrier(mod_, *builder_); - builder_->CreateStore(result, write_ptr); + indices_t idx(write_idx.size(), i32(0)); + for(size_t i = shape[axis]/2; i > 0; i >>= 1){ + idx[axis] = i32(i); + // read pointer + Value *read_msk = icmp_ult(lane, i32(i)); + Value *read_off = select(read_msk, shared_off(shape, order, idx), i32(0)); + Value *read_ptr = gep(write_ptr, read_off); + add_barrier(); + // update accumulator + acc = do_acc(acc, load(read_ptr)); + store(acc, write_ptr); } } - tgt_->add_barrier(mod_, *builder_); + add_barrier(); // write back - for_each(x, [&](indices_t idx) { - indices_t red_idx = idx; - red_idx.insert(red_idx.begin() + axis, builder_->getInt32(0)); - Value *read_offset = shared_tile::shared_offset(*builder_, stile->get_shapes(), stile->get_perm(), stile->get_order(), red_idx); - Value *read_ptr = builder_->CreateGEP(base_ptr, read_offset); - set_value(x, idx, builder_->CreateLoad(read_ptr)); - }); + for(indices_t idx: idxs_.at(x)){ + indices_t read_idx = idx; + read_idx.insert(read_idx.begin() + axis, i32(0)); + Value *read_off 
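// Sketch (illustrative only): the log-step shared-memory reduction along one axis, as
// emitted above for the ND case, simulated on a plain array with one slot per lane.
// The real code masks inactive lanes with a select instead of skipping them.
#include <cstddef>
#include <functional>
#include <vector>

float block_tree_reduce(std::vector<float>& smem,
                        const std::function<float(float, float)>& acc) {
  for (std::size_t i = smem.size() / 2; i > 0; i >>= 1)
    for (std::size_t lane = 0; lane < i; lane++)
      smem[lane] = acc(smem[lane], smem[lane + i]);  // combine with the slot i away
  return smem[0];
}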
= shared_off(shape, order, read_idx); + Value *read_ptr = gep(ptr, read_off); + vals_[x][idx] = load(read_ptr); + }; } -void generator::visit_select_inst(ir::select_inst* select) { - for_each(select, [&](indices_t idx){ - Value *pred = get_value(select->get_operand(0), idx); - Value *if_value = get_value(select->get_operand(1), idx); - Value *else_value = get_value(select->get_operand(2), idx); - Value *ret = builder_->CreateSelect(pred, if_value, else_value); - set_value(select, idx, ret); - }); - +/** + * \brief Code Generation for `reduce` (generic case) + */ +void generator::visit_reduce_inst(ir::reduce_inst* x) { + Type *ty = cvt(x->get_type()->get_scalar_ty()); + // accumulation function + ir::reduce_inst::op_t op = x->get_op(); + auto do_acc = [&](Value *x, Value *y) -> Value* { + switch(op){ + case ir::reduce_inst::ADD: return add(x, y); + case ir::reduce_inst::SUB: return sub(x, y); + case ir::reduce_inst::MAX: return select(icmp_sge(x, y), x, y); + case ir::reduce_inst::MIN: return select(icmp_sle(x, y), x, y); + case ir::reduce_inst::FADD: return fadd(x, y); + case ir::reduce_inst::FSUB: return fsub(x, y); + case ir::reduce_inst::FMAX: return max_num(x, y); + case ir::reduce_inst::FMIN: return min_num(x, y); + default: throw std::runtime_error("unreachable"); + } + }; + // neutral element + Value *neutral; + switch(op) { + case ir::reduce_inst::ADD: neutral = i32(0); break; + case ir::reduce_inst::SUB: neutral = i32(0); break; + case ir::reduce_inst::MAX: neutral = i32(INT32_MIN); break; + case ir::reduce_inst::MIN: neutral = i32(INT32_MAX); break; + case ir::reduce_inst::FADD: neutral = ConstantFP::get(ty, 0); break; + case ir::reduce_inst::FSUB: neutral = ConstantFP::get(ty, 0); break; + case ir::reduce_inst::FMAX: neutral = ConstantFP::get(ty, -INFINITY); break; + case ir::reduce_inst::FMIN: neutral = ConstantFP::get(ty, INFINITY); break; + default: throw std::runtime_error("unreachable"); + } + ir::value *arg = x->get_operand(0); + if(arg->get_type()->get_tile_rank() == 1) + visit_reduce1d_inst(x, do_acc, neutral); + else + visit_reducend_inst(x, do_acc, neutral); } +/** + * \brief Code Generation for `select` + */ +void generator::visit_select_inst(ir::select_inst* x) { + for(indices_t idx: idxs_.at(x)) + vals_[x][idx] = select(vals_[x->get_operand(0)][idx], + vals_[x->get_operand(1)][idx], + vals_[x->get_operand(2)][idx]); +} + +/** + * \brief Code Generation for `recoalesce` + */ void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) { ir::value *op = rc->get_operand(0); ir::tile_type::tile_shapes_t shape = rc->get_type()->get_tile_shapes(); - size_t rank = shape.size(); - // temporary layout - shared_tile *tmp = (shared_tile*)machine_layouts_.at(layouts_->get(layouts_->tmp(rc))) - ->create(rc); // pointer to temporary shared memory - Type *ty = llvm_type(rc->get_type()->get_scalar_ty(), *ctx_); - // layouts - analysis::mma884_layout* in_layout = layouts_->get(op)->to_mma884(); + Type *ty = cvt(rc->get_type()->get_scalar_ty()); + // layout + analysis::mma_layout* in_layout = layouts_->get(op)->to_mma(); analysis::scanline_layout* out_layout = layouts_->get(rc)->to_scanline(); - // machine tiles - distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op)); - distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc)); - // WMMA configuration - long wmma_pt[3] = { 2, 4, 1}; - long wmma[3] = { 8*in_layout->wpt(0)*in_layout->fpw(0), - 8*in_layout->wpt(1)*in_layout->fpw(1), - 1}; - // Work per thread for input layout - long in_pt[3] = { shape[0] / wmma[0], - shape[1] 
/ wmma[1], - 1 }; - // Work per thread for output layout - long out_pt[3] = { shape[0] / out_layout->mts(0), - shape[1] / out_layout->mts(1), - 1 }; - if(rank > 2){ - wmma[2] = in_layout->wpt(2)*in_layout->fpw(2); - in_pt[2] = shape[2] / wmma[2]; - out_pt[2] = shape[2] / out_layout->mts(2); - } // Orders - auto ord = out_layout->get_order(); - if(ord.size() < 3) - ord.push_back(2); - // pointer lanes - std::vector> ptrs; - for(int in_zz = 0; in_zz < wmma_pt[ord[2]]; in_zz++) { - std::vector current; - for(int in_cc = 0; in_cc < wmma_pt[ord[1]]; in_cc++) { - Value *base; - base = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(rc))))); - base = builder_->CreateBitCast(base, PointerType::get(ty, 3)); + auto ord = layouts_->get(rc)->to_scanline()->get_order(); + Value *base; + base = gep(shmem_, i32(alloc_->offset(layouts_->get(layouts_->tmp(rc))))); + base = bit_cast(base, ptr_ty(ty, 3)); + Value *ld = i32(shape[ord[0]]); + auto in_ord0 = axes_.at(a_axes_->get(op, ord[0])).values; + auto in_ord1 = axes_.at(a_axes_->get(op, ord[1])).values; + auto out_ord0 = axes_.at(a_axes_->get(rc, ord[0])).values; + auto out_ord1 = axes_.at(a_axes_->get(rc, ord[1])).values; + int in_outer = in_layout->spt(ord[1]); + int in_rep = in_layout->rep(ord[1]); + int out_outer = out_layout->mts(ord[1]) * out_layout->nts(ord[1]); + int max_outer = std::max(in_outer, out_outer); + int out_ratio = std::max(out_outer/in_outer, 1); + int in_ratio = std::max(in_outer/out_outer, 1); + indices_t idx(2); + for(size_t j = 0; j < shape[ord[1]]/max_outer; j++){ + add_barrier(); + for(size_t k = 0; k < in_rep*out_ratio; k++) + for(size_t i = 0; i < in_ord0.size(); i++){ + idx[ord[0]] = in_ord0[i]; + idx[ord[1]] = in_ord1[j*in_rep*out_ratio + k]; + Value *off = add(idx[ord[0]], mul(in_ord1[k], ld)); + Value *ptr = gep(base, off); + store(vals_[op][idx], ptr); + } + add_barrier(); + for(size_t k = 0; k < in_ratio; k++) + for(size_t i = 0; i < out_ord0.size(); i++){ + idx[ord[0]] = out_ord0[i]; + idx[ord[1]] = out_ord1[j*in_ratio + k]; + Value *off = add(out_ord0[i], mul(out_ord1[k], ld)); + Value *ptr = gep(base, off); + vals_[rc][idx] = load(ptr); + } + } +} - // shared memory stride - Value *stride_0 = builder_->getInt32(tmp->get_shapes()[ord[0]]); - // indices - Value *idx_cc = axes_.at(a_axes_->get(op, ord[1])).values[in_cc]; - // offset - Value *off = builder_->CreateMul(stride_0, idx_cc); - if(rank > 2){ - Value *stride_1 = builder_->CreateMul(stride_0, - builder_->getInt32(tmp->get_shapes()[ord[1]])); - Value *idx_zz = axes_.at(a_axes_->get(op, ord[2])).values[in_zz]; - off = builder_->CreateAdd(off, builder_->CreateMul(stride_1, idx_zz)); - } - current.push_back(builder_->CreateGEP(base, off)); - } - ptrs.push_back(current); +void generator::visit_masked_load_async_inst(ir::masked_load_async_inst* x){ + unsigned vector = 1; + ir::value *ptrs = x->get_pointer_operand(); + ir::value *msks = x->get_mask_operand(); + analysis::shared_layout* out_layout = layouts_->get(x)->to_shared(); + analysis::scanline_layout* in_layout = layouts_->get(ptrs)->to_scanline(); + auto out_order = out_layout->get_order(); + auto in_order = in_layout->get_order(); + // tiles + if(out_order == in_order) + vector = in_layout->nts(in_order[0]); + // + int dtsize = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8; + int num_per_phase = std::max(128 / (in_layout->mts(in_order[0])*vector*dtsize), 1); + Value *max_phase = i32(8 / num_per_phase); + // + auto shapes = 
x->get_type()->get_tile_shapes(); + // + int per_thread_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); + int n_shared = std::max(8 / in_layout->mts(in_order[1]), 1); + std::vector shared; + for(size_t i = 0; i < n_shared; i++){ + indices_t idx = idxs_.at(ptrs).at(i*per_thread_ld); + // phase + Value* phase = udiv(idx[in_order[1]], i32(num_per_phase)); + phase = urem(phase, max_phase); + // off + Value* off_0 = idx[in_order[0]]; + off_0 = udiv(off_0, i32(vector)); + off_0 = xor_(off_0, phase); + off_0 = mul(off_0 , i32(vector)); + Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); + Value* off = add(off_0, off_1); + // + shared.push_back(gep(shmems_[x], {off})); } - // Re-coalesce loops - for(int in_z = 0; in_z < in_pt[ord[2]]; in_z++) - for(int in_c = 0; in_c < in_pt[ord[1]]; in_c++){ - // write to shared - tgt_->add_barrier(mod_, *builder_); - for(int in_zz = 0; in_zz < wmma_pt[ord[2]]; in_zz++) - for(int in_cc = 0; in_cc < wmma_pt[ord[1]]; in_cc++){ - std::vector starts(rank), len(rank); - starts[ord[0]] = 0; - starts[ord[1]] = in_c*wmma_pt[ord[1]] + in_cc; - len[ord[0]] = wmma_pt[ord[0]]*in_pt[ord[0]]; - len[ord[1]] = 1; - if(rank > 2){ - starts[ord[2]] = in_z*wmma_pt[ord[2]] + in_zz; - len[ord[2]] = 1; - } - in_dt->for_each([&](indices_t idx){ - Value *write_ptr = builder_->CreateGEP(ptrs[in_zz][in_cc], idx[ord[0]]); - builder_->CreateStore(in_dt->get_value(idx), write_ptr); - }, starts, len); - } - tgt_->add_barrier(mod_, *builder_); - // load from shared - for(int out_zz = 0; out_zz < out_pt[ord[2]] / in_pt[ord[2]]; out_zz++) - for(int out_cc = 0; out_cc < out_pt[ord[1]] / in_pt[ord[1]]; out_cc++){ - std::vector starts(rank), len(rank); - starts[ord[0]] = 0; - starts[ord[1]] = in_c*(out_pt[ord[1]] / in_pt[ord[1]]) + out_cc; - len[ord[0]] = out_pt[ord[0]]; - len[ord[1]] = 1; - if(rank > 2){ - starts[ord[2]] = in_z*(out_pt[ord[2]] / in_pt[ord[2]]) + out_zz; - len[ord[2]] = 1; - } - out_dt->for_each([&](indices_t idx){ - indices_t read_idx(rank); - read_idx[ord[0]] = idx[ord[0]]; - read_idx[ord[1]] = axes_.at(a_axes_->get(rc, ord[1])).values[out_cc]; - if(rank > 2) - read_idx[ord[2]] = axes_.at(a_axes_->get(rc, ord[2])).values[out_zz]; - out_dt->set_value(idx, tmp->get_value(read_idx)); - }, starts, len); - } + // + for(size_t i = 0; i < idxs_.at(ptrs).size(); i += vector){ + auto idx = idxs_[ptrs][i]; + // input ptr info + GetElementPtrInst *in_gep = dyn_cast(vals_[ptrs][idx]); + Value *in_base = in_gep->getPointerOperand(); + size_t in_off = dyn_cast(in_gep->idx_begin())->getValue().getSExtValue()*2*vector; + Value* out_base = shared[(i / per_thread_ld) % n_shared]; + int out_off_0 = (i / per_thread_ld) / n_shared * n_shared * in_layout->mts(in_order[1]); + int out_off_1 = i % per_thread_ld; + int out_off = (out_off_0*shapes[in_order[0]] + out_off_1)*2; + // asm + FunctionType *ty = FunctionType::get(void_ty, {out_base->getType(), in_base->getType()}, false); + std::string mod = (vector*2 == 16) ? 
".cg" : ".ca"; + std::string asm_str = "@$0 cp.async" + mod + ".shared.global [$1 + " + std::to_string(out_off) + "], [$2 + " + std::to_string(in_off) + "], " + std::to_string(vector*2) + ";"; + InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,r,l", true); + call(iasm, {vals_[msks][idx], out_base, in_base}); } - tgt_->add_barrier(mod_, *builder_); } void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { - unsigned vector_size = 1; + unsigned in_vec = 1; ir::value *arg = cts->get_operand(0); analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared(); analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline(); @@ -1322,113 +1480,154 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { auto in_order = in_layout->get_order(); // tiles if(out_order == in_order) - vector_size = in_layout->nts(in_order[0]); + in_vec = in_layout->nts(in_order[0]); + int out_vec = swizzle_->get_vec(out_layout); + int min_vec = std::min(out_vec, in_vec); + int s = std::max(out_vec / in_vec, 1); + // + int per_phase = swizzle_->get_per_phase(out_layout); + int max_phase = swizzle_->get_max_phase(out_layout); + // + int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]); + int n_shared_1 = std::max(per_phase*max_phase / in_layout->mts(in_order[1]), 1); + int n_shared_0 = std::max(in_vec / out_vec, 1); - std::map packets; - for_each(arg, [&](indices_t idx){ - distributed_tile* in = (distributed_tile*)tmap_.at(arg); - unsigned linear = in->get_linear_index(idx); - unsigned id = linear / vector_size; - Value *in_value = in->get_value(idx); - if(linear % vector_size == 0) - packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size)); - packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size); - }); + BasicBlock* CurrBB = builder_->GetInsertBlock(); + BasicBlock* FirstBB = &CurrBB->getParent()->getEntryBlock(); + auto shapes = cts->get_type()->get_tile_shapes(); - for_each(arg, [&](indices_t idx){ - distributed_tile* in = (distributed_tile*)tmap_.at(arg); - shared_tile* result = (shared_tile*)tmap_.at(cts); - unsigned linear = in->get_linear_index(idx); - unsigned id = linear / vector_size; - if(linear % vector_size == 0) - result->set_value(idx, packets[id]); - }); + // default implementation + Value *current = nullptr; + std::map, Value*> ptrs; + for(int i = 0; i < idxs_.at(arg).size(); i++){ + auto idx = idxs_[arg][i]; + Value *in_value = vals_[arg][idx]; + if(i % min_vec == 0) + current = UndefValue::get(vec_ty(in_value->getType(), min_vec)); + current = insert_elt(current, in_value, i % min_vec); + if(i % min_vec == min_vec - 1){ + unsigned id = i / min_vec; + // input ptr info + int id_0 = id % (in_ld/min_vec); + int id_1 = id / (in_ld/min_vec); + int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]); + int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]); + int off = (off_1*shapes[in_order[0]] + off_0); + std::pair key = {id_1 % n_shared_1, id_0 % n_shared_0}; + if(ptrs.find(key) == ptrs.end()){ + builder_->SetInsertPoint(FirstBB->getTerminator()); + indices_t idx = idxs_.at(arg).at(key.first*in_ld); + Value* phase = udiv(idx[in_order[1]], i32(per_phase)); + phase = urem(phase, i32(max_phase)); + Value* off_1 = mul(idx[in_order[1]], i32(shapes[in_order[0]])); + Value* off_0 = add(idx[in_order[0]], i32(key.second*out_vec)); + off_0 = udiv(off_0, i32(min_vec)); + off_0 = add(mul(xor_(udiv(off_0, i32(s)), phase),i32(s)), urem(off_0, 
i32(s))); + off_0 = mul(off_0 , i32(min_vec)); + Value* off = add(off_0, off_1); + builder_->SetInsertPoint(CurrBB); + ptrs[key] = gep(shmems_.at(cts), {off}); + } + Value* ptr = gep(ptrs[key], {i32(off)}); + ptr = bit_cast(ptr, current->getType()->getPointerTo(3)); + // asm + store(current, ptr); + } + }; } -void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst* cfs) { - for_each(cfs, [&](indices_t idx){ - set_value(cfs, idx, get_value(cfs->get_operand(0), idx)); - }); + +void generator::visit_copy_from_shared_inst(ir::copy_from_shared_inst*) { + throw std::runtime_error("TODO"); +} + +Instruction* generator::add_barrier() { + Module *module = builder_->GetInsertBlock()->getModule(); + return tgt_->add_barrier(module, *builder_); } void generator::visit_barrier_inst(ir::barrier_inst*) { - Module *module = builder_->GetInsertBlock()->getModule(); - tgt_->add_barrier(module, *builder_); + add_barrier(); +} + +void generator::visit_async_wait_inst(ir::async_wait_inst*) { + std::string asm_str = "cp.async.wait_all;"; + InlineAsm *iasm = InlineAsm::get(FunctionType::get(void_ty, {}), asm_str, "", true); + call(iasm); + add_barrier(); } void generator::visit_make_range_dyn(ir::make_range_dyn* x) { - for_each(x, [&](indices_t idx){ + for(indices_t idx: idxs_.at(x)){ assert(idx.size() == 1); - if(idx[0] == builder_->getInt32(0)) - set_value(x, idx, idx[0]); + if(idx[0] == i32(0)) + vals_[x][idx] = idx[0]; else{ BinaryOperator *bin_add = dyn_cast(idx[0]); assert(bin_add); - Value *res = bin_add->getOperand(0); - set_value(x, idx, res); + vals_[x][idx] = bin_add->getOperand(0); } - }); + } } void generator::visit_make_range_sta(ir::make_range_sta* x) { - for_each(x, [&](indices_t idx){ + for(indices_t idx: idxs_.at(x)){ assert(idx.size() == 1); - if(idx[0] == builder_->getInt32(0)){ - set_value(x, idx, idx[0]); + if(idx[0] == i32(0)){ + vals_[x][idx] = idx[0]; } else{ BinaryOperator *bin_add = dyn_cast(idx[0]); assert(bin_add); - Value *res = bin_add->getOperand(1); - assert(isa(res)); - set_value(x, idx, res); + Value *cst = bin_add->getOperand(1); + assert(isa(cst)); + vals_[x][idx] = cst; } - }); + }; } void generator::visit_make_range(ir::make_range* x) { - for_each(x, [&](indices_t idx){ - assert(idx.size() == 1); - set_value(x, idx, idx[0]); - }); + for(indices_t idx: idxs_.at(x)){ + vals_[x][idx] = idx[0]; + } } void generator::visit_undef_value(ir::undef_value *ud) { - vmap_[ud] = llvm::UndefValue::get(llvm_type(ud->get_type(), *ctx_)); + vals_[ud][{}] = llvm::UndefValue::get(cvt(ud->get_type())); } void generator::visit_constant_int(ir::constant_int *cst){ - Type *ty = llvm_type(cst->get_type()->get_scalar_ty(), *ctx_); - vmap_[cst] = ConstantInt::get(ty, cst->get_value()); + Type *ty = cvt(cst->get_type()->get_scalar_ty()); + vals_[cst][{}] = ConstantInt::get(ty, cst->get_value()); } void generator::visit_constant_fp(ir::constant_fp *cst){ - Type *ty = llvm_type(cst->get_type()->get_scalar_ty(), *ctx_); - vmap_[cst] = ConstantFP::get(ty, cst->get_value()); + Type *ty = cvt(cst->get_type()->get_scalar_ty()); + vals_[cst][{}] = ConstantFP::get(ty, cst->get_value()); } void generator::visit_alloc_const(ir::alloc_const *alloc) { unsigned size = ((ir::constant_int*)alloc->get_operand(0))->get_value(); - Type *element_ty = llvm_type(alloc->get_type()->get_pointer_element_ty(), *ctx_); + Type *element_ty = cvt(alloc->get_type()->get_pointer_element_ty()); Type *array_ty = llvm::ArrayType::get(element_ty, size); Value *array = new llvm::GlobalVariable(*mod_, array_ty, false, 
llvm::GlobalVariable::ExternalLinkage, nullptr, alloc->get_name(), nullptr, llvm::GlobalVariable::NotThreadLocal, 4); - vmap_[alloc] = builder_->CreateBitCast(array, element_ty->getPointerTo(4)); + vals_[alloc][{}] = bit_cast(array, element_ty->getPointerTo(4)); } void generator::visit_function(ir::function* fn) { LLVMContext &ctx = builder_->getContext(); - FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), *ctx_); + FunctionType *fn_ty = (FunctionType*)cvt(fn->get_fn_type()); if(!tgt_->is_gpu()){ Type *fn_ret_ty = fn_ty->getReturnType(); std::vector fn_args_ty; for(unsigned i = 0; i < fn_ty->getNumParams(); i++) fn_args_ty.push_back(fn_ty->getParamType(i)); - fn_args_ty.push_back(builder_->getInt32Ty()); - fn_args_ty.push_back(builder_->getInt32Ty()); - fn_args_ty.push_back(builder_->getInt32Ty()); + fn_args_ty.push_back(i32_ty); + fn_args_ty.push_back(i32_ty); + fn_args_ty.push_back(i32_ty); fn_ty = FunctionType::get(fn_ret_ty, fn_args_ty, false); } Function *ret = Function::Create(fn_ty, Function::ExternalLinkage, fn->get_name(), mod_); @@ -1437,9 +1636,9 @@ void generator::visit_function(ir::function* fn) { unsigned id = attr_pair.first; for(ir::attribute attr: attr_pair.second) if(attr.is_llvm_attr()){ - llvm::Attribute llattr = llvm_attr(ctx, attr); + llvm::Attribute llattr = cvt(attr); if(llattr.getKindAsEnum() != llvm::Attribute::None) - ret->addAttribute(id, llvm_attr(ctx, attr)); + ret->addAttribute(id, cvt(attr)); } } // set metadata @@ -1448,19 +1647,19 @@ void generator::visit_function(ir::function* fn) { Metadata *md_args[] = { ValueAsMetadata::get(ret), MDString::get(ctx, "maxntidx"), - ValueAsMetadata::get(builder_->getInt32(num_warps_*32)) + ValueAsMetadata::get(i32(num_warps_*32)) }; mod_->getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args)); } // set arguments for(unsigned i = 0; i < fn->args().size(); i++) - vmap_[fn->args()[i]] = &*(ret->arg_begin() + i); + vals_[fn->args()[i]][{}] = &*(ret->arg_begin() + i); // create blocks for(ir::basic_block *block: fn->blocks()) { BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret); - vmap_[block] = dst_block; + bbs_[block] = dst_block; } - builder_->SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]); + builder_->SetInsertPoint(bbs_[fn->blocks()[0]]); // initialize layouts for(auto x: layouts_->get_all()){ visit_layout(x.second); @@ -1474,79 +1673,287 @@ void generator::visit_function(ir::function* fn) { -void generator::visit_layout_hmma_884(analysis::mma884_layout* layout) { - machine_layouts_[layout] = new machine_mma884_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout); +void generator::visit_layout_mma(analysis::mma_layout* layout) { + ir::value *a = nullptr; + ir::value *b = nullptr; + for(ir::value* v: layout->get_values()) + if(ir::dot_inst* dot = dynamic_cast(v)){ + a = dot->get_operand(0); + b = dot->get_operand(1); + } + analysis::data_layout* layout_a = layouts_->get(a); + analysis::data_layout* layout_b = layouts_->get(b); + + const auto& shape = layout->get_shape(); + Value *_1 = i32(1); + Value *_2 = i32(2); + Value *_3 = i32(3); + Value *_4 = i32(4); + Value *_8 = i32(8); + Value *_16 = i32(16); + Value *_32 = i32(32); + int cc = tgt_->as_nvidia()->sm(); + std::vector idx_m; + std::vector idx_n; + std::vector idx_z; + // + Value* thread = tgt_->get_local_id(mod_, *builder_, 0); + Value *lane = urem(thread, _32); + Value *warp = udiv(thread, _32); + /* lane offset */ + if(cc < 80){ + auto ord_a = layout_a->get_order(); + auto ord_b = 
layout_b->get_order(); + bool is_a_row = ord_a[0] != 0; + bool is_b_row = ord_b[0] != 0; + /* warp offset */ + Value *warp_0 = urem(warp, i32(layout->wpt(0))); + Value *warp_12 = udiv(warp, i32(layout->wpt(0))); + Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); + Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); + Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); + // Quad offset + Value *off_quad_m = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(0))); + Value *off_quad_n = mul(udiv(and_(lane, _16), _4), i32(layout->fpw(1))); + // Pair offset + Value *off_pair_m = udiv(urem(lane, _16), _4); + off_pair_m = urem(off_pair_m, i32(layout->fpw(0))); + off_pair_m = mul(off_pair_m, i32(4)); + Value *off_pair_n = udiv(urem(lane, _16), _4); + off_pair_n = udiv(off_pair_n, i32(layout->fpw(0))); + off_pair_n = urem(off_pair_n, i32(layout->fpw(1))); + off_pair_n = mul(off_pair_n, i32(4)); + // scale + off_pair_m = mul(off_pair_m, i32(layout->rep(0)/2)); + off_quad_m = mul(off_quad_m, i32(layout->rep(0)/2)); + off_pair_n = mul(off_pair_n, i32(layout->rep(1)/2)); + off_quad_n = mul(off_quad_n, i32(layout->rep(1)/2)); + // Quad pair offset + Value *off_lane_m = add(off_pair_m, off_quad_m); + Value *off_lane_n = add(off_pair_n, off_quad_n); + // a offset + offset_a_m_[layout] = add(off_warp_m, off_lane_m); + offset_a_k_[layout] = and_(lane, _3); + // b offsets + offset_b_n_[layout] = add(off_warp_n, off_lane_n); + offset_b_k_[layout] = and_(lane, _3); + // i indices + Value *offset_c_m = add(and_(lane, _1), offset_a_m_[layout]); + for(unsigned m = 0; m < shape[0]; m+=layout->spt(0)) + for(unsigned mm = 0; mm < layout->rep(0); mm++) + idx_m.push_back(add(offset_c_m, i32(m + mm*2))); + // j indices + Value *offset_c_n = add(and_(lane, _2), add(off_warp_n, off_pair_n)); + for(unsigned n = 0; n < shape[1]; n+=layout->spt(1)) + for(unsigned nn = 0; nn < layout->rep(1); nn++){ + idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1)))); + idx_n.push_back(add(offset_c_n, i32(n + nn/2*4 + (nn%2)*2*layout->fpw(1)*layout->rep(1) + 1))); + } + if(is_a_row){ + offset_a_m_[layout] = add(offset_a_m_[layout], urem(thread, i32(4))); + offset_a_k_[layout] = i32(0); + } + if(!is_b_row){ + offset_b_n_[layout] = add(offset_b_n_[layout], urem(thread, i32(4))); + offset_b_k_[layout] = i32(0); + } + /* axes */ + axes_[layout->get_axis(0)] = distributed_axis{1, idx_m, warp_0}; + axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1}; + } + else{ + /* warp offset */ + Value *warp_0 = urem(warp, i32(layout->wpt(0))); + Value *warp_12 = udiv(warp, i32(layout->wpt(0))); + Value *warp_1 = urem(warp_12, i32(layout->wpt(1))); + Value *off_warp_m = mul(warp_0, i32(layout->spw(0))); + Value *off_warp_n = mul(warp_1, i32(layout->spw(1))); + Value *off_lane_m = urem(lane, _16); + Value *off_lane_n = urem(lane, _8); + /* offsets */ + // a offset + offset_a_m_[layout] = add(off_warp_m, off_lane_m); + offset_a_k_[layout] = i32(0); + // b offsets + offset_b_n_[layout] = add(off_warp_n, off_lane_n); + offset_b_k_[layout] = i32(0); + // c offset + Value *off_c_m = add(udiv(lane, _4), off_warp_m); + Value *off_c_n = add(mul(_2, urem(lane, _4)), off_warp_n); + for(unsigned m = 0; m < shape[0]; m+=layout->spt(0)){ + idx_m.push_back(add(off_c_m, i32(m))); + idx_m.push_back(add(off_c_m, i32(m + 8))); + } + for(unsigned n = 0; n < shape[1]; n+=layout->spt(1)){ + idx_n.push_back(add(off_c_n, i32(n))); + idx_n.push_back(add(off_c_n, i32(n + 1))); + } + /* axes */ + axes_[layout->get_axis(0)] = 
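// Sketch (illustrative only): the per-lane accumulator coordinates for the sm >= 80
// (mma.m16n8k16) path above; spw_m/spw_n and spt_m/spt_n stand in for the layout's
// spw()/spt() accessors, and warp_m/warp_n for the warp coordinates.
#include <utility>
#include <vector>

std::pair<std::vector<int>, std::vector<int>>
mma16816_c_indices(int lane, int warp_m, int warp_n, int spw_m, int spw_n,
                   int spt_m, int spt_n, int shape_m, int shape_n) {
  int off_c_m = lane / 4 + warp_m * spw_m;        // row owned by this lane
  int off_c_n = 2 * (lane % 4) + warp_n * spw_n;  // column pair owned by this lane
  std::vector<int> idx_m, idx_n;
  for (int m = 0; m < shape_m; m += spt_m) {      // each rep covers rows m and m + 8
    idx_m.push_back(off_c_m + m);
    idx_m.push_back(off_c_m + m + 8);
  }
  for (int n = 0; n < shape_n; n += spt_n) {      // and the adjacent columns n, n + 1
    idx_n.push_back(off_c_n + n);
    idx_n.push_back(off_c_n + n + 1);
  }
  return { idx_m, idx_n };
}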
distributed_axis{1, idx_m, warp_0}; + axes_[layout->get_axis(1)] = distributed_axis{1, idx_n, warp_1}; + } } void generator::visit_layout_scanline(analysis::scanline_layout* layout) { - machine_layouts_[layout] = new machine_scanline_layout(mod_, &*builder_, tgt_, a_axes_, axes_, layout); + Value *warp_size = i32(32); + Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); + Value *u_thread_id = urem(u_thread_id_0, warp_size); + Value *u_warp_id = udiv(u_thread_id_0, warp_size); + + auto order = layout->get_order(); + const auto& shape = layout->get_shape(); + Value* full_thread_id = add(mul(u_warp_id, i32(32)), u_thread_id); + // Delinearize + size_t dim = shape.size(); + std::vector thread_id(dim); + for(unsigned k = 0; k < dim - 1; k++){ + Constant *dim_k = i32(layout->mts(order[k])); + Value *rem = urem(full_thread_id, dim_k); + full_thread_id = udiv(full_thread_id, dim_k); + thread_id[order[k]] = rem; + } + thread_id[order[dim - 1]] = full_thread_id; + // Create axes + for(unsigned k = 0; k < dim; k++) { + int nts = layout->nts(k); + int mts = layout->mts(k); + std::string str_k = std::to_string(k); + Value *contiguous_k = i32(nts); + Value *scaled_thread_id = mul(thread_id[k], contiguous_k); + unsigned per_block = nts * mts; + unsigned per_thread = nts * shape[k] / per_block; + std::vector idx_list(per_thread); + for(unsigned n = 0 ; n < per_thread; n++){ + unsigned offset = n / nts * per_block + n % nts; + idx_list[n] = add(scaled_thread_id, i32(offset), "idx_" + str_k + "_" + std::to_string(n)); + } + axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]}; + } } void generator::visit_layout_shared(analysis::shared_layout* layout) { - - machine_layouts_[layout] = new machine_shared_layout(mod_, &*builder_, tgt_, alloc_, sh_mem_ptr_, layout, vmap_, tmap_); + Type* ty = cvt(layout->get_type()); + PointerType *ptr_ty = ty->getPointerTo(shmem_->getType()->getPointerAddressSpace()); + // double-buffered + if(layout->get_double_buffer()) { + BasicBlock *current = builder_->GetInsertBlock(); + auto info = *layout->get_double_buffer(); + ir::phi_node *phi = info.phi; + BasicBlock *parent = bbs_.at(phi->get_parent()); + if(parent->empty()) + builder_->SetInsertPoint(parent); + else + builder_->SetInsertPoint(&*parent->getFirstNonPHI()); + // create pointers + shared_ptr_[layout] = phi(ptr_ty, 2); + shared_pre_ptr_[layout] = gep(shmem_, i32(alloc_->offset(layout))); + shared_pre_ptr_[layout] = bit_cast(shared_pre_ptr_[layout], shared_ptr_[layout]->getType()); + shared_off_[layout] = phi(i32_ty, 2); + shared_next_ptr_[layout] = gep(shared_ptr_[layout], shared_off_[layout], "next_ptr"); + builder_->SetInsertPoint(current); + } + else{ + size_t offset = alloc_->offset(layout); + shared_ptr_[layout] = gep(shmem_, i32(offset)); + shared_ptr_[layout] = bit_cast(shared_ptr_[layout], ptr_ty); + } } void generator::visit_basic_block(ir::basic_block * block) { - BasicBlock *parent = (BasicBlock*)vmap_[block]; + BasicBlock *parent = bbs_[block]; builder_->SetInsertPoint(parent); for(ir::instruction *i: block->get_inst_list()){ - // std::cout << typeid(*i).name() << std::endl; visit_value(i); } - vmap_[block] = builder_->GetInsertBlock(); + bbs_[block] = builder_->GetInsertBlock(); } void generator::visit_argument(ir::argument* arg) { } -void generator::for_each(ir::value *x, const std::function& fn) { - if(!x->get_type()->is_tile_ty()) - return fn({}); - else { -// if(tmap_.find(x) == tmap_.end()) -// tmap_[x] = machine_layouts_.at(layouts_->get(x))->create(x); - if(auto *dt 
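// Sketch (illustrative only): the delinearization of a flat thread id into per-axis
// coordinates performed above for scanline layouts; mts holds the threads-per-axis
// counts and order lists axes from fastest- to slowest-varying.
#include <cstddef>
#include <vector>

std::vector<int> delinearize_thread_id(int thread, const std::vector<int>& mts,
                                       const std::vector<int>& order) {
  std::size_t dim = mts.size();
  std::vector<int> coord(dim);
  for (std::size_t k = 0; k + 1 < dim; k++) {  // peel one axis at a time, fastest first
    coord[order[k]] = thread % mts[order[k]];
    thread /= mts[order[k]];
  }
  coord[order[dim - 1]] = thread;              // the remainder lands on the slowest axis
  return coord;
}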
= dynamic_cast(tmap_.at(x))) - dt->for_each(fn); +void generator::init_idx(ir::value *v) { + idxs_[v].clear(); + if(!v->get_type()->is_tile_ty()){ + idxs_[v].push_back({}); + return; } + if(layouts_->get(v)->to_shared()) + return; + const auto &shapes = v->get_type()->get_tile_shapes(); + size_t rank = shapes.size(); + std::vector axes(rank); + std::vector ord(rank); + // compute axes + for(size_t d = 0; d < shapes.size(); d++){ + if(shapes[d] > 1){ + unsigned x = a_axes_->get(v, d); + axes[d] = axes_.at(x); + } + else{ + axes[d].contiguous = 1; + axes[d].values = {i32(0)}; + } + } + // compute order + analysis::data_layout* layout = layouts_->get(v); + std::iota(ord.begin(), ord.end(), 0); + auto cmp = [&](int x, int y) { + unsigned axx = a_axes_->get(v, x); + unsigned axy = a_axes_->get(v, y); + size_t posx = layout->find_axis(axx); + size_t posy = layout->find_axis(axy); + if(posx < rank && posy < rank) + return layout->get_order(posx) < layout->get_order(posy); + return false; + }; + std::sort(ord.begin(), ord.end(), cmp); + ords_[v] = ord; + // indices + if(axes.size() == 1) + for(Value* x0: axes[ord[0]].values){ + idxs_[v].push_back({x0}); + } + if(axes.size() == 2) + for(Value* x1: axes[ord[1]].values) + for(Value* x0: axes[ord[0]].values){ + indices_t idx(2); + idx[ord[0]] = x0; + idx[ord[1]] = x1; + idxs_[v].push_back(idx); + } + if(axes.size() == 3) + for(Value* x2: axes[ord[2]].values) + for(Value* x1: axes[ord[1]].values) + for(Value* x0: axes[ord[0]].values){ + indices_t idx(3); + idx[ord[0]] = x0; + idx[ord[1]] = x1; + idx[ord[2]] = x2; + idxs_[v].push_back(idx); + } } -Value* generator::get_value(ir::value *x, const indices_t& idx) { - if(x->get_type()->is_tile_ty()) - return tmap_.at(x)->get_value(idx); - return vmap_.at(x); -} - -void generator::set_value(ir::value *x, const indices_t& idx, Value* v) { - if(x->get_type()->is_tile_ty()) - tmap_.at(x)->set_value(idx, v); - else - vmap_[x] = v; -} - - void generator::finalize_shared_layout(analysis::shared_layout *shared) { if(shared->get_double_buffer()) { auto info = *shared->get_double_buffer(); ir::phi_node *phi = info.phi; - PHINode *ptr = (PHINode*)((shared_tile*)tmap_.at(phi))->get_pointer(); - PHINode *offset = (PHINode*)((shared_tile*)tmap_.at(phi))->get_offset(); + PHINode *ptr = (PHINode*)shmems_[phi]; + PHINode *offset = (PHINode*)shoffs_[phi]; for(unsigned n = 0; n < phi->get_num_incoming(); n++){ ir::basic_block* inc_block = phi->get_incoming_block(n); ir::value* inc_val = phi->get_incoming_value(n); - BasicBlock *llvm_inc_block = (BasicBlock*)vmap_.at(inc_block); - shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val); + BasicBlock *llvm_inc_block = bbs_.at(inc_block); if(inc_val == info.latch){ builder_->SetInsertPoint(llvm_inc_block->getTerminator()); - Value *next_offset = builder_->CreateNeg(offset); + Value *next_offset = neg(offset); offset->addIncoming(next_offset, llvm_inc_block); } else { unsigned num_bytes = shared->get_type()->get_primitive_size_in_bits() / 8; - offset->addIncoming(builder_->getInt32(shared->get_size() / (2*num_bytes)), llvm_inc_block); + offset->addIncoming(i32(shared->get_size() / (2*num_bytes)), llvm_inc_block); } - ptr->addIncoming(inc_shared->get_pointer(), llvm_inc_block); + ptr->addIncoming(shmems_[inc_val], llvm_inc_block); } } } @@ -1563,18 +1970,17 @@ void generator::finalize_function(ir::function *fn) { finalize_phi_node(phi); } -void generator::finalize_phi_node(ir::phi_node *phi) { - auto it = tmap_.find(phi); - if(it != tmap_.end() && dynamic_cast(it->second)) +void 
generator::finalize_phi_node(ir::phi_node *x) { + if(shmems_.find(x) != shmems_.end()) return; - for(unsigned n = 0; n < phi->get_num_incoming(); n++){ - ir::basic_block *inc_block = phi->get_incoming_block(n); - BasicBlock *llvm_inc_block = (BasicBlock*)vmap_.at(inc_block); - for_each(phi, [&](indices_t idx){ - PHINode *llvm_phi = (PHINode*)get_value(phi, idx); - Value *llvm_inc_val = get_value(phi->get_incoming_value(n), idx); - llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); - }); + for(unsigned n = 0; n < x->get_num_incoming(); n++){ + ir::basic_block *_block = x->get_incoming_block(n); + BasicBlock *block = bbs_.at(_block); + for(indices_t idx: idxs_.at(x)){ + PHINode *phi = (PHINode*)vals_[x][idx]; + Value *inc = vals_[x->get_incoming_value(n)][idx]; + phi->addIncoming(inc, block); + } } } @@ -1588,11 +1994,11 @@ void generator::visit(ir::module &src, llvm::Module &dst) { Type *int_8_ty = Type::getInt8Ty(*ctx_); Type *int_32_ty = Type::getInt32Ty(*ctx_); ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4); - Type *ptr_ty = PointerType::get(int_8_ty, 3); + Type *ptr_ty = ptr_ty(int_8_ty, 3); GlobalVariable *sh_mem_array = - new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage, + new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalWeakLinkage, nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); - sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty); + shmem_ = bit_cast(sh_mem_array, ptr_ty); } // visit functions for(ir::function *fn: src.get_function_list()) diff --git a/lib/codegen/selection/machine_layout.cc b/lib/codegen/selection/machine_layout.cc deleted file mode 100644 index 9ee45db1c..000000000 --- a/lib/codegen/selection/machine_layout.cc +++ /dev/null @@ -1,325 +0,0 @@ -#include -#include "triton/codegen/selection/machine_layout.h" -#include "triton/codegen/selection/machine_value.h" -#include "triton/codegen/selection/generator.h" -#include "triton/codegen/analysis/allocation.h" -#include "triton/codegen/analysis/axes.h" -#include "triton/codegen/target.h" -#include "triton/ir/instructions.h" -#include "triton/ir/type.h" -#include "llvm/IR/IRBuilder.h" - -namespace triton{ -namespace codegen{ - -using namespace llvm; - -inline Type *llvm_type(ir::type *ty, LLVMContext &ctx) { - // function - if(auto* tt = dynamic_cast(ty)){ - Type *return_ty = llvm_type(tt->get_return_ty(), ctx); - std::vector param_tys; - std::transform(tt->params_begin(), tt->params_end(), std::back_inserter(param_tys), - [&ctx](ir::type* t){ return llvm_type(t, ctx);}); - return FunctionType::get(return_ty, param_tys, false); - } - // pointer - if(ty->is_pointer_ty()){ - Type *elt_ty = llvm_type(ty->get_pointer_element_ty(), ctx); - unsigned addr_space = ty->get_pointer_address_space(); - return PointerType::get(elt_ty, addr_space); - } - // integer - if(ty->is_integer_ty()){ - unsigned bitwidth = ty->get_integer_bitwidth(); - return IntegerType::get(ctx, bitwidth); - } - // primitive types - switch(ty->get_type_id()){ - case ir::type::VoidTyID: return Type::getVoidTy(ctx); - case ir::type::HalfTyID: return Type::getHalfTy(ctx); - case ir::type::FloatTyID: return Type::getFloatTy(ctx); - case ir::type::DoubleTyID: return Type::getDoubleTy(ctx); - case ir::type::X86_FP80TyID: return Type::getX86_FP80Ty(ctx); - case ir::type::PPC_FP128TyID: return Type::getPPC_FP128Ty(ctx); - case ir::type::LabelTyID: return Type::getLabelTy(ctx); - case ir::type::MetadataTyID: return Type::getMetadataTy(ctx); - case ir::type::TokenTyID: return 
Type::getTokenTy(ctx); - default: break; - } - // unknown type - throw std::runtime_error("unknown conversion from ir::type to Type"); -} - -// Grid construction -inline std::vector delinearize(Value *trailing, const std::vector& order, std::vector &shapes, IRBuilder<> &builder){ - size_t dim = shapes.size(); - std::vector result(dim); - for(unsigned k = 0; k < dim - 1; k++){ - Constant *dim_k = builder.getInt32(shapes[order[k]]); - Value *rem = builder.CreateURem(trailing, dim_k); - trailing = builder.CreateUDiv(trailing, dim_k); - result[order[k]] = rem; - } - result[order[dim - 1]] = trailing; - return result; -} - -inline int32_t ceil(int32_t num, int32_t div){ - return (num + div - 1)/div; -} - - - -machine_shared_layout::machine_shared_layout(Module *mod, Builder *builder, target *tgt, analysis::allocation* alloc, - Value *&sh_mem_ptr, analysis::shared_layout *layout, - std::map& vmap, - std::map& tmap) - : mod_(mod), builder_(builder), tgt_(tgt), alloc_(alloc), sh_mem_ptr_(sh_mem_ptr), layout_(layout), vmap_(vmap), tmap_(tmap) { - - Type* ty = llvm_type(layout_->get_type(), builder_->getContext()); - PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr_->getType()->getPointerAddressSpace()); - // double-buffered - if(layout_->get_double_buffer()) { - BasicBlock *current = builder_->GetInsertBlock(); - auto info = *layout_->get_double_buffer(); - ir::phi_node *phi = info.phi; - BasicBlock *parent = (BasicBlock*)vmap_.at((ir::value*)(phi->get_parent())); - if(parent->empty()) - builder_->SetInsertPoint(parent); - else - builder_->SetInsertPoint(&*parent->getFirstNonPHI()); - // create pointers - ptr_ = builder_->CreatePHI(ptr_ty, 2); - pre_ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layout_))); - pre_ptr_ = builder_->CreateBitCast(pre_ptr_, ptr_->getType()); - offset_ = builder_->CreatePHI(builder_->getInt32Ty(), 2); - next_ptr_ = builder_->CreateGEP(ptr_, offset_, "next_ptr"); - builder_->SetInsertPoint(current); - } - else{ - size_t offset = alloc_->offset(layout_); - ptr_ = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(offset)); - ptr_ = builder_->CreateBitCast(ptr_, ptr_ty); - } -} - - -tile* machine_shared_layout::create(ir::value *v) { - Type* ty = llvm_type(layout_->get_type(), builder_->getContext()); - auto double_buffer = layout_->get_double_buffer(); - // offset - Value *offset = nullptr; - if(double_buffer && v == double_buffer->phi) - offset = offset_; - // base pointer - Value *ptr = ptr_; - if(double_buffer && v == double_buffer->latch) - ptr = next_ptr_; - else if(double_buffer && v == double_buffer->first) - ptr = pre_ptr_; - // create tile - return new shared_tile(ty, layout_->get_shape(), layout_->get_order(), ptr, *builder_, offset); -} - -machine_distributed_layout::machine_distributed_layout(Module *mod, Builder *builder, target *tgt, - analysis::axes *a_axes, std::map& axes, - analysis::data_layout *layout) - : mod_(mod), builder_(builder), tgt_(tgt), a_axes_(a_axes), axes_(axes), layout_(layout) { - -} - -tile *machine_distributed_layout::create(ir::value *v) { - Type *ty = llvm_type(v->get_type()->get_scalar_ty(), builder_->getContext()); - const auto &shapes = v->get_type()->get_tile_shapes(); - size_t rank = shapes.size(); - std::vector axes(rank); - std::vector order(rank); - // compute axes - for(size_t d = 0; d < shapes.size(); d++){ - if(shapes[d] > 1){ - unsigned x = a_axes_->get(v, d); - axes[d] = axes_.at(x); - } - else{ - axes[d].contiguous = 1; - axes[d].values = {builder_->getInt32(0)}; - } - } - // compute order - 
std::iota(order.begin(), order.end(), 0); - auto cmp = [&](int x, int y) { - unsigned axx = a_axes_->get(v, x); - unsigned axy = a_axes_->get(v, y); - size_t posx = layout_->find_axis(axx); - size_t posy = layout_->find_axis(axy); - if(posx < rank && posy < rank) - return layout_->get_order(posx) < layout_->get_order(posy); - return false; - }; - std::sort(order.begin(), order.end(), cmp); - return new distributed_tile(ty, shapes, order, axes, *builder_); -} - -machine_mma884_layout::machine_mma884_layout(Module *mod, Builder *builder, - target *tgt, analysis::axes *a_axes, - std::map& axes, - analysis::mma884_layout* layout) - : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) { - - Value *warp_size = builder_->getInt32(32); - Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); - Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); - Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); - - const auto& shape = layout->get_shape(); - if(shape.size() > 3) - throw std::runtime_error("unsupported"); - bool is_batched = shape.size() >= 3; - - Value *_1 = builder_->getInt32(1); - Value *_2 = builder_->getInt32(2); - Value *_3 = builder_->getInt32(3); - Value *_4 = builder_->getInt32(4); - Value *_16 = builder_->getInt32(16); - - // fragments per warp - unsigned fpw_0 = layout->fpw(0); - unsigned fpw_1 = layout->fpw(1); - unsigned fpw_2 = is_batched ? layout->fpw(2) : 1; - // warps per tile - unsigned wpt_0 = layout->wpt(0); - unsigned wpt_1 = layout->wpt(1); - unsigned wpt_2 = is_batched ? layout->wpt(2) : 1; - // mma warp tile size - unsigned hmma_wts_0 = fpw_0 * 8; - unsigned hmma_wts_1 = fpw_1 * 8; - unsigned hmma_wts_2 = is_batched ? fpw_2 : 1; - // mma block tile size - unsigned hmma_bts_0 = hmma_wts_0 * wpt_0; - unsigned hmma_bts_1 = hmma_wts_1 * wpt_1; - unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1; - // number of repetition - unsigned num_rep_0 = shape[0] / hmma_bts_0; - unsigned num_rep_1 = shape[1] / hmma_bts_1; - unsigned num_rep_2 = is_batched ? 
shape[2] / hmma_bts_2 : 1; - // size of each pack (interleaving) - pack_size_0_ = std::min(num_rep_0, 1); - pack_size_1_ = std::min(num_rep_1, 1); - // number of packs (interleaving) - num_packs_0_ = num_rep_0 / pack_size_0_; - num_packs_1_ = num_rep_1 / pack_size_1_; - - /* intra warp offset */ - // offset of quad in pair - Value *in_pair_off_a = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)), - builder_->getInt32(fpw_0 * pack_size_0_)); - Value *in_pair_off_b = builder_->CreateMul(builder_->CreateUDiv(builder_->CreateAnd(u_thread_id, _16), builder_->getInt32(4)), - builder_->getInt32(fpw_1 * pack_size_1_)); - - // Quad pair id - Value *pair_a_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4); - Value *pair_b_id = builder_->CreateUDiv(builder_->CreateURem(u_thread_id, _16), _4); - pair_a_id = builder_->CreateURem(pair_a_id, builder_->getInt32(fpw_0)); - pair_b_id = builder_->CreateUDiv(pair_b_id, builder_->getInt32(fpw_0)); - pair_b_id = builder_->CreateURem(pair_b_id, builder_->getInt32(fpw_1)); - // Quad pair offset - Value *pair_a_off = builder_->CreateMul(pair_a_id, builder_->getInt32(4 * pack_size_0_)); - Value *pair_b_off = builder_->CreateMul(pair_b_id, builder_->getInt32(4 * pack_size_1_)); - - /* inter warp offset */ - Value *warp_id_0 = builder_->CreateURem(u_warp_id, builder_->getInt32(wpt_0)); - Value *warp_id_12 = builder_->CreateUDiv(u_warp_id, builder_->getInt32(wpt_0)); - Value *warp_id_1 = builder_->CreateURem(warp_id_12, builder_->getInt32(wpt_1)); - Value *warp_id_2 = builder_->CreateUDiv(warp_id_12, builder_->getInt32(wpt_1)); - Value *warp_offset_i = builder_->CreateMul(warp_id_0, builder_->getInt32(hmma_wts_0 * pack_size_0_)); - Value *warp_offset_j = builder_->CreateMul(warp_id_1, builder_->getInt32(hmma_wts_1 * pack_size_1_)); - - /* offsets */ - // a offset - offset_a_i_ = builder_->CreateAdd(warp_offset_i, builder_->CreateAdd(pair_a_off, in_pair_off_a)); - offset_a_k_ = builder_->CreateAnd(u_thread_id, _3); - // b offsets - offset_b_j_ = builder_->CreateAdd(warp_offset_j, builder_->CreateAdd(pair_b_off, in_pair_off_b)); - offset_b_k_ = builder_->CreateAnd(u_thread_id, _3); - - // c offsets - Value *offset_c_i = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _1), offset_a_i_); - Value *offset_c_j = builder_->CreateAdd(builder_->CreateAnd(u_thread_id, _2), - builder_->CreateAdd(warp_offset_j, pair_b_off)); - - /* indices */ - // i indices - std::vector idx_i; - for(unsigned pack = 0; pack < num_packs_0_; pack++) - for(unsigned ii = 0; ii < pack_size_0_; ii++) - for(unsigned i = 0; i < 2; i++){ - idx_i.push_back(builder_->CreateAdd(offset_c_i, builder_->getInt32(pack*hmma_bts_0*pack_size_0_ + ii*4 + i*2))); - } - // j indices - std::vector idx_j; - for(unsigned pack = 0; pack < num_packs_1_; pack++) - for(unsigned jj = 0; jj < pack_size_1_; jj++) - for(unsigned j = 0; j < 2; j++){ - idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_))); - idx_j.push_back(builder_->CreateAdd(offset_c_j, builder_->getInt32(pack*hmma_bts_1*pack_size_1_ + jj*4 + j*4*fpw_1*pack_size_1_ + 1))); - } - // z indices - std::vector idx_z; - for(unsigned pack = 0; pack < num_rep_2; pack++) - idx_z.push_back(builder_->CreateAdd(warp_id_2, builder_->getInt32(pack*hmma_bts_2))); - - - /* axes */ - axes_[layout->get_axis(0)] = distributed_axis{1, idx_i, warp_id_0}; - axes_[layout->get_axis(1)] = distributed_axis{1, idx_j, warp_id_1}; - 
if(is_batched) - axes_[layout->get_axis(2)] = distributed_axis{1, idx_z, warp_id_2}; -} - - -machine_scanline_layout::machine_scanline_layout(Module *mod, Builder *builder, - target *tgt, - analysis::axes *a_axes, std::map &axes, - analysis::scanline_layout* layout) - : machine_distributed_layout(mod, builder, tgt, a_axes, axes, layout) { - - Value *warp_size = builder_->getInt32(32); - Value* u_thread_id_0 = tgt_->get_local_id(mod_, *builder_, 0); - Value *u_thread_id = builder_->CreateURem(u_thread_id_0, warp_size); - Value *u_warp_id = builder_->CreateUDiv(u_thread_id_0, warp_size); - - auto order = layout->get_order(); - const auto& shape = layout->get_shape(); - Value* full_thread_id = builder_->CreateAdd(builder_->CreateMul(u_warp_id, builder_->getInt32(32)), u_thread_id); - // Delinearize - size_t dim = shape.size(); - std::vector thread_id(dim); - for(unsigned k = 0; k < dim - 1; k++){ - Constant *dim_k = builder_->getInt32(layout->mts(order[k])); - Value *rem = builder_->CreateURem(full_thread_id, dim_k); - full_thread_id = builder_->CreateUDiv(full_thread_id, dim_k); - thread_id[order[k]] = rem; - } - thread_id[order[dim - 1]] = full_thread_id; - // Create axes - for(unsigned k = 0; k < dim; k++) { - int nts = layout->nts(k); - int mts = layout->mts(k); - std::string str_k = std::to_string(k); - Value *contiguous_k = builder_->getInt32(nts); - Value *scaled_thread_id = builder_->CreateMul(thread_id[k], contiguous_k); - unsigned per_block = nts * mts; - unsigned per_thread = nts * shape[k] / per_block; - std::vector idx_list(per_thread); - for(unsigned n = 0 ; n < per_thread; n++){ - unsigned offset = n / nts * per_block + n % nts; - idx_list[n] = builder_->CreateAdd(scaled_thread_id, builder_->getInt32(offset), "idx_" + str_k + "_" + std::to_string(n)); - } - axes_[layout->get_axis(k)] = distributed_axis{nts, idx_list, thread_id[k]}; - } -} - - -} -} diff --git a/lib/codegen/selection/machine_value.cc b/lib/codegen/selection/machine_value.cc deleted file mode 100644 index dbff237d1..000000000 --- a/lib/codegen/selection/machine_value.cc +++ /dev/null @@ -1,214 +0,0 @@ -#include -#include -#include "llvm/IR/IRBuilder.h" -#include "triton/codegen/selection/machine_value.h" - -namespace triton{ -namespace codegen{ - -using namespace llvm; - -/* Distributed Tile */ -void distributed_tile::init_indices() { - std::vector id(axes_.size(), 0); - // build - size_t k = 0; - while(true) { - indices_t current; - for(size_t d = 0; d < id.size(); d++) - current.push_back(axes_[d].values[id[d]]); - size_t sz = indices_.size(); - indices_[current] = sz; - values_[current] = nullptr; - ordered_indices_.push_back(current); - id[order_[0]]++; - while(id[order_[k]] == axes_[order_[k]].values.size()){ - if(k == id.size() - 1) - return; - id[order_[k++]] = 0; - id[order_[k]]++; - } - k = 0; - } -} - - -distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector& order, const axes_t &axes, llvm::IRBuilder<> &builder) - : tile(ty, shapes), axes_(axes), order_(order), builder_(builder) { - init_indices(); -} - -void distributed_tile::set_value(indices_t idx, Value *x) { - assert(x->getType() == ty_ && "cannot set a value of different type"); - Value *&result = values_[idx]; - assert(!result && "value cannot be set twice"); - result = x; -} - -Value* distributed_tile::get_value(indices_t idx) { - Value *result = values_.at(idx); - assert(result && "value has not been set"); - return result; -} - -unsigned distributed_tile::get_linear_index(indices_t idx) { - return 
indices_[idx]; -} - -indices_t distributed_tile::get_ordered_indices(unsigned id) { - return ordered_indices_.at(id); -} - - -void distributed_tile::for_each(std::function fn, int start, int end) { - if(end < 0) - end = ordered_indices_.size() + end + 1; - for(unsigned i = start; i < end; i++) - fn(ordered_indices_[i]); -} - -void distributed_tile::for_each(std::function fn, std::vector starts, std::vector sizes){ - int rank = sizes.size(); - int len = 1; - for(int s: sizes) - len *= s; - - for(int i = 0; i < len; i++){ - indices_t idx(rank); - int current = i; - for(int k = 0; k < rank; k++){ - idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]); - current = current / sizes[k]; - } - fn(idx); - } -} - - -/* Shared Tile */ -void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) { - BinaryOperator *bin_op = dyn_cast(arg); - Constant *_0 = ConstantInt::get(Type::getInt32Ty(arg->getContext()), 0); - if(dyn_cast(arg)){ - cst = arg; - non_cst = _0; - return; - } - if(!bin_op || bin_op->getOpcode() != llvm::BinaryOperator::Add){ - non_cst = arg; - cst = _0; - return; - } - Constant *cst_lhs = dyn_cast(bin_op->getOperand(0)); - Constant *cst_rhs = dyn_cast(bin_op->getOperand(1)); - if(cst_lhs && cst_rhs){ - cst = arg; - non_cst = _0; - } - else if(cst_lhs){ - cst = cst_lhs; - non_cst = bin_op->getOperand(1); - } - else if(cst_rhs){ - cst = cst_rhs; - non_cst = bin_op->getOperand(0); - } - else{ - non_cst = arg; - cst = _0; - } -} - -void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_idx, indices_t &cst_idx) { - non_cst_idx.clear(); - cst_idx.clear(); - for(Value *idx: arg_idx){ - Value *non_cst, *cst; - extract_constant(idx, non_cst, cst); - non_cst_idx.push_back(non_cst); - cst_idx.push_back(cst); - } -} - - -Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, - const std::vector& perm, const std::vector& order, - indices_t idx) { - // strides - std::vector strides(shapes.size(), builder.getInt32(0)); - strides[order[0]] = builder.getInt32(1); - for(size_t i = 1; i < idx.size(); i++) - strides[order[i]] = builder.CreateMul(strides[order[i-1]], builder.getInt32(shapes[order[i-1]])); - // result - Value *result = builder.getInt32(0); - for(size_t i = 0; i < idx.size(); i++) - result = builder.CreateAdd(result, builder.CreateMul(idx[perm[i]], strides[i])); - return result; -} - -shared_tile::shared_tile(Type *ty, const shapes_t &shapes, const std::vector& order, Value *ptr, llvm::IRBuilder<> &builder, Value *offset, const std::vector& perm): - tile(ty, shapes), order_(order), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1), perm_(perm){ - return_vector_ = false; - if(perm_.empty()){ - perm_.resize(shapes.size()); - std::iota(perm_.begin(), perm_.end(), 0); - } -} - -void shared_tile::set_value(indices_t idx, Value *value) { - Value *ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, idx)); - unsigned addr_space = ptr->getType()->getPointerAddressSpace(); - ptr = builder_.CreateBitCast(ptr, value->getType()->getPointerTo(addr_space)); - builder_.CreateStore(value, ptr); -} - -void shared_tile::set_vector_size(unsigned vector_size) { - vector_size_ = vector_size; -} - -void shared_tile::set_return_mode(bool return_vector){ - return_vector_ = return_vector; -} - - -Value* shared_tile::get_value(indices_t idx) { - indices_t non_cst_idx, cst_idx; - extract_constant(idx, non_cst_idx, cst_idx); - Value *&base_ptr = ptr_cache_[non_cst_idx]; - unsigned vector_size = 
vector_size_; - Type *ty = ty_; - if(ty->isHalfTy() && (vector_size % 2 == 0)){ - ty = IntegerType::get(ty->getContext(), 32); - vector_size = vector_size / 2; - } - if(base_ptr == nullptr){ -// BasicBlock* store = builder_.GetInsertBlock(); -// if(!non_cst_idx.empty()) -// if(isa(non_cst_idx.front())){ -// builder_.SetInsertPoint((Instruction*)non_cst_idx.front()); -// } - base_ptr = builder_.CreateGEP(ptr_, shared_offset(builder_, shapes_, perm_, order_, non_cst_idx)); - if(vector_size_ > 1){ - Type *vec_ty = VectorType::get(ty, vector_size); - Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace()); - base_ptr = builder_.CreateBitCast(base_ptr, vec_ptr_ty); - } -// builder_.SetInsertPoint(store); - } - Value *offset = shared_offset(builder_, shapes_, perm_, order_, cst_idx); - Value *div = offset; - if(vector_size_ > 1) - div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_)); - Value *ptr = builder_.CreateGEP(base_ptr, div); - Value *result = builder_.CreateLoad(ptr); - if(return_vector_ == false && vector_size_ > 1) { - Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_)); - result = builder_.CreateExtractElement(result, rem); - } - return result; -} - - - -} -} diff --git a/lib/codegen/target.cc b/lib/codegen/target.cc index 253dfc709..82ebbe649 100644 --- a/lib/codegen/target.cc +++ b/lib/codegen/target.cc @@ -14,6 +14,12 @@ namespace triton{ namespace codegen{ // base + + +nvidia_cu_target* target::as_nvidia() { + return dynamic_cast(this); +} + bool target::is_gpu() const { return is_gpu_; } @@ -25,7 +31,7 @@ void amd_cl_target::set_kernel(IRBuilder<>& builder, LLVMContext &ctx, Module *m Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) { Function *barrier = Intrinsic::getDeclaration(module, Intrinsic::amdgcn_s_barrier); - return builder.CreateCall(barrier, {}); + return builder.CreateIntrinsic(Intrinsic::amdgcn_s_barrier, {}, {}); } Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { @@ -45,8 +51,7 @@ Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigne Intrinsic::amdgcn_workgroup_id_y, Intrinsic::amdgcn_workgroup_id_z }; - Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]); - Value* group_id = builder.CreateCall(get_group_id, {}); + Value* group_id = builder.CreateIntrinsic(ids[ax], {}, {}); return group_id; } @@ -99,8 +104,7 @@ Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsi Intrinsic::nvvm_read_ptx_sreg_ctaid_y, Intrinsic::nvvm_read_ptx_sreg_ctaid_z }; - Value* get_cta_id = Intrinsic::getDeclaration(module, cta_ids[ax]); - Value* cta_id = builder.CreateCall(get_cta_id, {}); + Value* cta_id = builder.CreateIntrinsic(cta_ids[ax], {}, {}); return cta_id; } @@ -120,8 +124,7 @@ Value* nvidia_cu_target::get_num_blocks(Module *module, IRBuilder<>& builder, un Intrinsic::nvvm_read_ptx_sreg_nctaid_y, Intrinsic::nvvm_read_ptx_sreg_nctaid_z }; - Value* get_nctaid = Intrinsic::getDeclaration(module, ids[ax]); - return builder.CreateCall(get_nctaid, {}); + return builder.CreateIntrinsic(ids[ax], {}, {}); } // CPU diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 64af407e5..78316c0df 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -66,7 +66,7 @@ void coalesce::run(ir::module &mod) { for(size_t id = 0; id < num_groups; id++) { - if(!layout_->get(id)->to_mma884()) + 
if(!layout_->get(id)->to_mma()) continue; // extract memory stores const auto& values = layout_->values_of(id); diff --git a/lib/codegen/transform/cts.cc b/lib/codegen/transform/cts.cc index 4b2aadb99..2641dad53 100644 --- a/lib/codegen/transform/cts.cc +++ b/lib/codegen/transform/cts.cc @@ -28,12 +28,14 @@ inline bool is_shmem_res(ir::value* v){ return true; if(i->get_id() == ir::INST_COPY_TO_SHARED) return true; + if(i->get_id() == ir::INST_MASKED_LOAD_ASYNC) + return true; return false; } // run pass on module -void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) { +void cts::add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) { auto *i = dynamic_cast(x); // not an instruction if(!i) { @@ -58,8 +60,9 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool // copy builder.set_insert_point_after(i); ir::value *copy; - if(to_shared) + if(to_shared){ copy = builder.create_copy_to_shared(x); + } else copy = builder.create_copy_from_shared(x); parent->replace_uses_of_with(x, copy); diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index 450e98315..2972ed6ca 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -54,7 +54,7 @@ void membar::get_written_intervals(ir::instruction *i, interval_vec_t &res){ add_reference(i, res); } -void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) { +void membar::insert_barrier(ir::instruction *instr, std::pair type, ir::builder &builder) { if(auto *phi = dynamic_cast(instr)) { std::set incoming; for(unsigned n = 0; n < phi->get_num_incoming(); n++){ @@ -63,7 +63,10 @@ void membar::insert_barrier(ir::instruction *instr, ir::builder &builder) { if(incoming.insert(inc_val).second){ ir::basic_block *block = inc_val->get_parent(); builder.set_insert_point(block->get_inst_list().back()); - builder.create_barrier(); + if(type.first) + builder.create_async_wait(); + if(type.second) + builder.create_barrier(); } } } @@ -85,8 +88,9 @@ std::pair membar::transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, - std::set& insert_loc, - std::set& safe_war) { + std::map>& insert_loc, + std::set& safe_war, + std::vector& to_sync) { ir::basic_block::inst_list_t instructions = block->get_inst_list(); interval_vec_t new_written_to = written_to; interval_vec_t new_read_from = read_from; @@ -95,6 +99,8 @@ std::pair(i);}; + auto is_copy_to_shared = [&](ir::instruction *i){ return dynamic_cast(i);}; + bool copy_async_wait = std::any_of(to_sync.begin(), to_sync.end(), is_load_async); + bool barrier = std::any_of(to_sync.begin(), to_sync.end(), is_copy_to_shared); + insert_loc.insert({i, {copy_async_wait, barrier}}); new_written_to.clear(); new_read_from.clear(); + to_sync.clear(); } std::copy(written.begin(), written.end(), std::back_inserter(new_written_to)); std::copy(read.begin(), read.end(), std::back_inserter(new_read_from)); @@ -125,17 +136,17 @@ void membar::run(ir::module &mod) { if(!layout || !layout->get_double_buffer()) continue; for(ir::value *v: layout->get_values()) - if(v != layout->get_double_buffer()->phi) + if(v != layout->get_double_buffer()->phi){ safe_war.insert(v); + } } - - for(ir::function *fn: mod.get_function_list()){ std::vector rpo = ir::cfg::reverse_post_order(fn); std::map written_to; std::map read_from; - std::set insert_locs; + std::vector to_sync; + std::map> insert_locs; size_t n_inserted_im1 = 0; bool done = false; do{ @@ -150,7 
+161,7 @@ void membar::run(ir::module &mod) { for(ir::basic_block* pred: block->get_predecessors()) pred_read_from.push_back(read_from[pred]); // apply transfer function - auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war); + auto result = transfer(block, join(pred_written_to), join(pred_read_from), insert_locs, safe_war, to_sync); written_to[block] = result.first; read_from[block] = result.second; } @@ -158,8 +169,9 @@ void membar::run(ir::module &mod) { done = (n_inserted_im1 == n_inserted_i); n_inserted_im1 = n_inserted_i; }while(!done); - for(ir::instruction* i: insert_locs) - insert_barrier(i, builder); + for(auto x: insert_locs){ + insert_barrier(x.first, x.second, builder); + } } } diff --git a/lib/codegen/transform/peephole.cc b/lib/codegen/transform/peephole.cc index d8062be6b..8caa1b0bc 100644 --- a/lib/codegen/transform/peephole.cc +++ b/lib/codegen/transform/peephole.cc @@ -97,6 +97,24 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){ //} +bool peephole::rewrite_load_to_shared(ir::instruction *value, ir::builder& builder){ + auto copy_to_shared = dynamic_cast(value); + if(!copy_to_shared) + return false; + ir::value *arg = copy_to_shared->get_operand(0); + ir::masked_load_inst* ld = dynamic_cast(arg); + if(!ld) + return false; + builder.set_insert_point(copy_to_shared); + ir::value *ptr = ld->get_pointer_operand(); + ir::value *msk = ld->get_mask_operand(); + ir::value *val = ld->get_false_value_operand(); + ir::value* new_load = builder.create_masked_load_async(ptr, msk, val); + copy_to_shared->replace_all_uses_with(new_load); + return true; + +} + bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){ auto x = dynamic_cast(value); if(!x) @@ -197,10 +215,12 @@ void peephole::run(ir::module &mod) { continue; bool was_modified = false; was_modified = was_modified || rewrite_mult(i, builder); -// was_modified = was_modified || rewrite_cts_cfs(i, builder); + // was_modified = was_modified || rewrite_cts_cfs(i, builder); was_modified = was_modified || rewrite_trans_phi(i, builder); was_modified = was_modified || rewrite_unit_red(i, builder); was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder); +// if(tgt_->as_nvidia()->sm() >= 80) +// was_modified = was_modified || rewrite_load_to_shared(i, builder); if(was_modified) seen.insert(i); } diff --git a/lib/codegen/transform/reorder.cc b/lib/codegen/transform/reorder.cc new file mode 100644 index 000000000..2949e427d --- /dev/null +++ b/lib/codegen/transform/reorder.cc @@ -0,0 +1,51 @@ +#include +#include +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/instructions.h" +#include "triton/codegen/transform/reorder.h" + +namespace triton { +namespace codegen{ +namespace transform{ + +void reorder::run(ir::module& mod){ + ir::builder &builder = mod.get_builder(); + std::vector> to_replace; + + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction* i: block->get_inst_list()){ + if(auto* ld = dynamic_cast(i)){ + ir::value* _ptr = ld->get_pointer_operand(); + ir::value* _msk = ld->get_mask_operand(); + ir::value* _val = ld->get_false_value_operand(); + auto ptr = std::find(block->begin(), block->end(), _ptr); + auto msk = std::find(block->begin(), block->end(), _msk); + auto val = std::find(block->begin(), block->end(), _val); + if(ptr == block->end() || msk == block->end() || val == 
block->end()) + continue; + auto it = std::find(block->begin(), block->end(), i); + int dist_ptr = std::distance(ptr, it); + int dist_msk = std::distance(msk, it); + int dist_val = std::distance(val, it); + if(dist_ptr < dist_msk && dist_ptr < dist_val) + builder.set_insert_point(++ptr); + if(dist_msk < dist_ptr && dist_msk < dist_val) + builder.set_insert_point(++msk); + if(dist_val < dist_ptr && dist_val < dist_msk) + builder.set_insert_point(++val); + ir::value* new_ld = builder.create_masked_load(_ptr, _msk, _val); + to_replace.push_back(std::make_pair(ld, new_ld)); + } + } + + for(auto& x: to_replace) + x.first->replace_all_uses_with(x.second); + +} + +} +} +} diff --git a/lib/driver/device.cc b/lib/driver/device.cc index 53ed3007d..ff65bb5fa 100755 --- a/lib/driver/device.cc +++ b/lib/driver/device.cc @@ -48,46 +48,6 @@ std::unique_ptr host_device::make_target() const { // CUDA // /* ------------------------ */ -// architecture -cu_device::Architecture cu_device::nv_arch(std::pair sm) const { - switch(sm.first) { - case 7: - switch(sm.second){ - case 0: return Architecture::SM_7_0; - } - - case 6: - switch(sm.second){ - case 0: return Architecture::SM_6_0; - case 1: return Architecture::SM_6_1; - } - - case 5: - switch(sm.second){ - case 0: return Architecture::SM_5_0; - case 2: return Architecture::SM_5_2; - default: return Architecture::UNKNOWN; - } - - case 3: - switch(sm.second){ - case 0: return Architecture::SM_3_0; - case 5: return Architecture::SM_3_5; - case 7: return Architecture::SM_3_7; - default: return Architecture::UNKNOWN; - } - - case 2: - switch(sm.second){ - case 0: return Architecture::SM_2_0; - case 1: return Architecture::SM_2_1; - default: return Architecture::UNKNOWN; - } - - default: return Architecture::UNKNOWN; - } -} - // information query template int cu_device::cuGetInfo() const{ @@ -108,11 +68,6 @@ nvmlDevice_t cu_device::nvml_device() const{ return map.at(key); } -// architecture -cu_device::Architecture cu_device::architecture() const{ - return nv_arch(compute_capability()); -} - // number of address bits size_t cu_device::address_bits() const{ return sizeof(size_t)*8; @@ -133,17 +88,17 @@ std::string cu_device::pci_bus_id() const{ } // force the device to be interpreted as a particular cc -void cu_device::interpret_as(std::pair cc){ - interpreted_as_ = std::make_shared>(cc); +void cu_device::interpret_as(int cc){ + interpreted_as_ = std::make_shared(cc); } // compute capability -std::pair cu_device::compute_capability() const { +int cu_device::compute_capability() const { if(interpreted_as_) return *interpreted_as_; - size_t _major = cuGetInfo(); - size_t _minor = cuGetInfo(); - return std::make_pair(_major, _minor); + size_t major = cuGetInfo(); + size_t minor = cuGetInfo(); + return major*10 + minor; } // maximum number of threads per block @@ -218,7 +173,7 @@ std::string cu_device::infos() const{ // target std::unique_ptr cu_device::make_target() const { - return std::unique_ptr(new codegen::nvidia_cu_target()); + return std::unique_ptr(new codegen::nvidia_cu_target(compute_capability())); } diff --git a/lib/driver/dispatch.cc b/lib/driver/dispatch.cc index 3b3af5596..df6f14ddb 100755 --- a/lib/driver/dispatch.cc +++ b/lib/driver/dispatch.cc @@ -93,6 +93,7 @@ namespace driver bool dispatch::cuinit(){ if(cuda_==nullptr){ + putenv((char*)"CUDA_CACHE_DISABLE=1"); std::string libcuda = tools::getenv("TRITON_LIBCUDA"); if(libcuda.empty()) cuda_ = dlopen("libcuda.so", RTLD_LAZY); diff --git a/lib/driver/module.cc b/lib/driver/module.cc index 
bdac2798e..90339c418 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -20,7 +20,9 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include +#include #include +#include #include "triton/driver/module.h" #include "triton/driver/context.h" #include "triton/driver/error.h" @@ -41,6 +43,19 @@ #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/Transforms/Utils/Cloning.h" +std::string exec(const char* cmd) { + std::array buffer; + std::string result; + std::unique_ptr pipe(popen(cmd, "r"), pclose); + if (!pipe) { + throw std::runtime_error("popen() failed!"); + } + while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) { + result += buffer.data(); + } + return result; +} + namespace triton { namespace driver @@ -63,11 +78,11 @@ void module::init_llvm() { } module::module(CUmodule mod, bool has_ownership) - : polymorphic_resource(mod, has_ownership) { + : polymorphic_resource(mod, has_ownership), spilled_(0) { } module::module(host_module_t mod, bool has_ownership) - : polymorphic_resource(mod, has_ownership) { + : polymorphic_resource(mod, has_ownership), spilled_(0) { } @@ -86,10 +101,12 @@ void module::compile_llvm_module(std::unique_ptr module, const std file_type_t ft) { init_llvm(); // // debug -// llvm::legacy::PassManager pm; + llvm::legacy::PassManager pm; + std::string tmp; +// llvm::raw_string_ostream oss(llir_); // pm.add(llvm::createPrintModulePass(llvm::outs())); -// pm.add(llvm::createVerifierPass()); -// pm.run(*module); + pm.add(llvm::createVerifierPass()); + pm.run(*module); // create machine module->setTargetTriple(triple); std::string error; @@ -176,7 +193,7 @@ host_module::host_module(std::unique_ptr src): module(host_module_ // create execution engine for(llvm::Function& fn: src->functions()) - hst_->functions[fn.getName()] = &fn; + hst_->functions[fn.getName().str()] = &fn; // llvm::orc::JITTargetMachineBuilder JTMB = *llvm::orc::JITTargetMachineBuilder::detectHost(); // auto DL = JTMB.getDefaultDataLayoutForTarget(); @@ -225,7 +242,8 @@ static std::map vptx = { {10010, 64}, {10020, 65}, {11000, 70}, - {11010, 71} + {11010, 71}, + {11020, 72} }; std::string cu_module::compile_llvm_module(std::unique_ptr module, driver::device* device) { @@ -238,9 +256,7 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, assert(short_ptr); short_ptr->setValue(true); // compute capability - auto _cc = ((driver::cu_device*)device)->compute_capability(); - int cc = _cc.first*10 + _cc.second; - cc = std::min(cc, max_nvvm_cc); + int cc = ((driver::cu_device*)device)->compute_capability(); std::string sm = "sm_" + std::to_string(cc); // driver version int version; @@ -251,12 +267,11 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, throw std::runtime_error("Triton requires CUDA 10+"); // PTX version int ptx = vptx.at(version); - ptx = std::min(ptx, max_nvvm_ptx); int ptx_major = ptx / 10; int ptx_minor = ptx % 10; // create llvm::SmallVector buffer; - module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "+ptx" + std::to_string(ptx), Assembly); + module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", "sm_" + std::to_string(std::min(cc, max_nvvm_cc)), "", buffer, "+ptx" + std::to_string(std::min(ptx, max_nvvm_ptx)), Assembly); std::string result(buffer.begin(), buffer.end()); find_and_replace(result, ".version", "\n", ".version " + std::to_string(ptx_major) + "." 
+ std::to_string(ptx_minor) + "\n"); find_and_replace(result, ".target", "\n", ".target " + sm + "\n"); @@ -266,21 +281,69 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, } -cu_module::cu_module(driver::device* device, std::unique_ptr ll_module): cu_module(compile_llvm_module(std::move(ll_module), device)) { } +cu_module::cu_module(driver::device* device, std::unique_ptr ll_module): cu_module(device, compile_llvm_module(std::move(ll_module), device)) { } -cu_module::cu_module(std::string const & source) : module(CUmodule(), true), source_(source){ +cu_module::cu_module(driver::device* device, std::string const & source) : module(CUmodule(), true), ptx_(source){ // JIT compile source-code - CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; - unsigned int errbufsize = 8096; - std::string errbuf(errbufsize, 0); - void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)errbuf.data()}; + try{ - dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval); - }catch(exception::cuda::invalid_ptx const &){ +// // compile ptx with ptxas +// char _fsrc[] = "/tmp/triton_k_XXXXXX"; +// char _flog[] = "/tmp/triton_l_XXXXXX"; +// int fdsrc = mkstemp(_fsrc); +// int fdlog = mkstemp(_flog); +// std::string fsrc = _fsrc; +// std::string flog = _flog; +// std::ofstream ofs(fsrc); +// ofs << source; +// ofs.close(); +// std::string cmd; +// int err; +// driver::cu_device* cu_device = (driver::cu_device*)device; +// cmd = "ptxas -v --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; +// err = system(cmd.c_str()); +// dispatch::cuModuleLoad(&*cu_, (fsrc + ".o").c_str()); +// std::ifstream file(flog); +// std::string log; +// if(file) +// while (!file.eof()) log.push_back(file.get()); +// unlink(_fsrc); +// unlink(_flog); + +// std::smatch match; +// std::regex expr ("\\b([0-9]+) bytes spill"); +// spilled_ = 0; +// while (std::regex_search (log,match,expr)){ +// spilled_ += std::stoi(match[1]); +// log = match.suffix(); +// } +// std::cout << log << std::endl; + + CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, CU_JIT_INFO_LOG_BUFFER, + CU_JIT_LOG_VERBOSE}; + unsigned int errbufsize = 8192; + unsigned int logbufsize = 8192; + char _err[errbufsize]; + char _log[logbufsize]; + void* optval[] = {(void*)(uintptr_t)errbufsize, (void*)_err, (void*)(uintptr_t)logbufsize, (void*)_log, (void*)1}; + dispatch::cuModuleLoadDataEx(&*cu_, ptx_.data(), 5, opt, optval); + std::string err(_err); + std::string log(_log); + +// std::cout << log << std::endl; + std::smatch match; + std::regex expr ("\\b([0-9]+) bytes spill"); + spilled_ = 0; + while (std::regex_search(log,match,expr)){ + spilled_ += std::stoi(match[1]); + log = match.suffix(); + } + } + catch(exception::cuda::invalid_ptx const &){ //#ifdef TRITON_LOG_PTX_ERROR std::cout << source << std::endl; std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl; - std::cerr << errbuf << std::endl; // exit(1); //#endif throw; diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index eceb694d2..2122b57bd 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -1,5 +1,6 @@ #include #include +#include #include "triton/ir/basic_block.h" #include "triton/ir/builder.h" #include "triton/ir/constant.h" @@ -253,6 +254,15 @@ DEFINE_FCMP_INSTR(ONE, cmp_pred_t::FCMP_ONE) value *builder::create_load(value *ptr, const std::string &name){ return 
insert(unmasked_load_inst::create(ptr, name)); +// type *ty = ptr->get_type()->get_pointer_element_ty(); +// value *mask = constant_int::get(get_int1_ty(), 1); +// value *undef = undef_value::get(ty); +// if(ptr->get_type()->is_tile_ty()){ +// auto shapes = ptr->get_type()->get_tile_shapes(); +// return insert(masked_load_inst::create(ptr, create_splat(mask, shapes), create_splat(undef, shapes), name)); +// } +// return insert(masked_load_inst::create(ptr, mask, undef, name)); + } value *builder::create_store(value *ptr, value *val, const std::string &name){ @@ -263,6 +273,7 @@ value *builder::create_masked_load(value *ptr, value *mask, value *false_value, return insert(masked_load_inst::create(ptr, mask, false_value, name)); } + value *builder::create_masked_store(value *ptr, value *val, value *mask, const std::string &name){ return insert(masked_store_inst::create(ptr, val, mask, name)); } @@ -348,13 +359,22 @@ value *builder::create_copy_to_shared(value *arg, const std::string &name) { return insert(copy_to_shared_inst::create(arg, name)); } + value *builder::create_copy_from_shared(value *arg, const std::string &name) { return insert(copy_from_shared_inst::create(arg, name)); } +value *builder::create_masked_load_async(value *ptr, value *mask, value *false_value, const std::string &name) { + return insert(masked_load_async_inst::create(ptr, mask, false_value, name)); +} + value *builder::create_barrier(const std::string &name) { return insert(barrier_inst::create(ctx_, name)); } +value *builder::create_async_wait() { + return insert(async_wait_inst::create(ctx_)); +} + } } diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 684c0d7d9..5e15e83c7 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -463,6 +463,20 @@ masked_load_inst* masked_load_inst::create(value *ptr, value *mask, value *false return new masked_load_inst(ptr, mask, false_value, name, next); } +// masked load async +masked_load_async_inst::masked_load_async_inst(value *ptr, value *mask, value *false_value, + const std::string &name, instruction *next) + : load_inst(ptr, INST_MASKED_LOAD_ASYNC, 3, name, next) { + set_operand(0, ptr); + set_operand(1, mask); + set_operand(2, false_value); +} + +masked_load_async_inst* masked_load_async_inst::create(value *ptr, value *mask, value *false_value, + const std::string &name, instruction *next) { + return new masked_load_async_inst(ptr, mask, false_value, name, next); +} + // atomic add atomic_add_inst::atomic_add_inst(value *ptr, value *val, value *msk, const std::string &name, instruction *next) @@ -804,6 +818,14 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru return new barrier_inst(ctx, name, next); } +async_wait_inst::async_wait_inst(context &ctx, const std::string &name, + instruction *next) + : instruction(type::get_void_ty(ctx), INST_ASYNC_WAIT, 0, name, next) { } + +async_wait_inst* async_wait_inst::create(context &ctx, const std::string &name, instruction *next) { + return new async_wait_inst(ctx, name, next); +} + // nv_dynamic_program_idx make_range_dyn::make_range_dyn(type *ty, const std::string &name, instruction *next) diff --git a/lib/ir/print.cc b/lib/ir/print.cc index f88ba1f6f..47e1d9b59 100644 --- a/lib/ir/print.cc +++ b/lib/ir/print.cc @@ -65,7 +65,12 @@ void print(module &mod, std::ostream& os) { os << get_name(ops[i], cnt++); os << (i < num_ops - 1?", ":""); } - os << ";" << std::endl; + os << ";"; +// os << " ("; +// for(ir::user* usr: inst->get_users()) +// os << get_name(usr, cnt++) << ", 
" ; +// os << " )"; + os << std::endl; } } os << "}" << std::endl; diff --git a/lib/ir/value.cc b/lib/ir/value.cc index 81f803df6..b970e07d7 100644 --- a/lib/ir/value.cc +++ b/lib/ir/value.cc @@ -68,9 +68,10 @@ unsigned user::get_num_hidden() const { value::users_t::iterator user::replace_uses_of_with(value *before, value *after) { for(size_t i = 0; i < ops_.size(); i++) - if(ops_[i] == before) + if(ops_[i] == before){ ops_[i] = after; - after->add_use(this); + after->add_use(this); + } return before->erase_use(this); } diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index 6de85ea25..a23d4125c 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -56,10 +56,13 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { return set_ret(bld_->create_dot(lhs, rhs, _0)); } case Token::MASKED_DEREF: { + // TODO: FIXME ir::type* ret_ty = GenIRType(binary->Type(), *ctx_); ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty()); + auto it = bld_->get_insert_block(); if(ret_ty->is_tile_ty()) false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes()); + bld_->set_insert_point(it); return set_ret(bld_->create_masked_load(rhs, lhs, false_value)); } case Token::ELLIPSIS: { @@ -274,9 +277,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) { if(ir::unmasked_load_inst* ld = dynamic_cast(true_val)) { if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty()) false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes()); - ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), - cond, - false_val); + ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(), cond, false_val); ld->replace_all_uses_with(new_ld); ld->erase_from_parent(); return set_ret(new_ld); @@ -468,10 +469,10 @@ void Generator::VisitForStmt(ForStmt *forStmt) { }); if(init_) VisitStmt(init_); -// VisitExpr(cond_); -// ir::value *cond = ret_; -// bld_->create_cond_br(cond, loop_bb, next_bb); - bld_->create_br(loop_bb); + VisitExpr(cond_); + ir::value *cond = ret_; + bld_->create_cond_br(cond, loop_bb, next_bb); +// bld_->create_br(loop_bb); bld_->set_insert_point(loop_bb); if(body_) VisitStmt(body_); diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index d73903c07..2b7062012 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -9,11 +9,13 @@ #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/liveness.h" #include "triton/codegen/analysis/align.h" +#include "triton/codegen/analysis/swizzle.h" #include "triton/codegen/transform/coalesce.h" #include "triton/codegen/transform/dce.h" #include "triton/codegen/transform/peephole.h" #include "triton/codegen/transform/membar.h" #include "triton/codegen/transform/reassociate.h" +#include "triton/codegen/transform/reorder.h" #include "triton/codegen/transform/cts.h" #include "triton/codegen/transform/disassociate.h" #include "triton/codegen/selection/generator.h" @@ -29,6 +31,7 @@ #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/print.h" +#include "triton/runtime/error.h" #include "triton/tools/bench.hpp" #include "triton/tools/sha1.hpp" #include "triton/tools/sys/getenv.hpp" @@ -67,7 +70,7 @@ void _loop_nest(std::vector const & ranges, /* OPTIONS */ /* --------------------- */ -std::string function::options_t::to_str() const{ +std::string options_t::to_str() const{ std::string ret = "nw-" + std::to_string(num_warps); for(const 
auto& x : defines){ ret += '-'; @@ -110,41 +113,41 @@ arg_type convert(ir::type *ty) { throw std::runtime_error("unknown type"); } -void function::caller::write(std::ofstream &ofs) { - // write name - ofs << name_ << std::endl; - // write signature - for(size_t i = 0; i < param_tys_.size(); i++) - ofs << param_tys_[i] << " "; - ofs << std::endl; - // write module - std::string source = ((driver::cu_module*)(&*parent_))->source(); - ofs << source; -} +//void function::caller::write(std::ofstream &ofs) { +// // write name +// ofs << name_ << std::endl; +// // write signature +// for(size_t i = 0; i < param_tys_.size(); i++) +// ofs << param_tys_[i] << " "; +// ofs << std::endl; +// // write module +// std::string source = ((driver::cu_module*)(&*parent_))->ptx(); +// ofs << source; +//} -void function::caller::read(std::ifstream &ifs) { - // read name - std::getline(ifs, name_); - // read signature - std::string line; - std::getline(ifs, line); - std::istringstream current(line); - int param; - param_tys_.clear(); - while(current >> param) - param_tys_.push_back((arg_type)param); - // read module - std::string src((std::istreambuf_iterator(ifs)), - std::istreambuf_iterator()); - parent_.reset(new driver::cu_module(src)); - bin_.reset(driver::kernel::create(&*parent_, name_.c_str())); +//void function::caller::read(driver::context* ctx, std::ifstream &ifs) { +// // read name +// std::getline(ifs, name_); +// // read signature +// std::string line; +// std::getline(ifs, line); +// std::istringstream current(line); +// int param; +// param_tys_.clear(); +// while(current >> param) +// param_tys_.push_back((arg_type)param); +// // read module +// std::string src((std::istreambuf_iterator(ifs)), +// std::istreambuf_iterator()); +// parent_.reset(new driver::cu_module(ctx, src)); +// bin_.reset(driver::kernel::create(&*parent_, name_.c_str())); -} +//} -function::caller::caller(std::ifstream &ifs, const options_t& opt) - : opt_(opt) { - read(ifs); -} +//function::caller::caller(driver::context* ctx, std::ifstream &ifs, const options_t& opt) +// : opt_(opt) { +// read(ctx, ifs); +//} function::caller::caller(ir::function *ir, std::shared_ptr parent, const options_t& opt) @@ -198,20 +201,23 @@ std::unique_ptr function::make_bin(ir::module &module, driver::d // generate llvm code llvm::LLVMContext ctx; std::unique_ptr llvm(new llvm::Module(module.get_name(), ctx)); + // optimizations + bool cts_use_async = target->as_nvidia()->sm() >= 80; // create passes codegen::analysis::align align; codegen::analysis::axes axes; + codegen::transform::cts cts(cts_use_async); codegen::transform::disassociate disassociate; codegen::analysis::layouts layouts(&axes, &align, opt.num_warps, target.get()); codegen::analysis::liveness liveness(&layouts); + codegen::analysis::swizzle swizzle(&layouts, target.get()); codegen::analysis::allocation allocation(&liveness); codegen::transform::membar barriers(&liveness, &layouts, &allocation); codegen::transform::dce dce; - codegen::transform::peephole peephole; + codegen::transform::peephole peephole(target.get()); codegen::transform::reassociate reassociate; codegen::transform::coalesce coalesce(&align, &layouts); - codegen::transform::cts cts; - codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps); + codegen::generator isel(&axes, &layouts, &align, &allocation, &swizzle, target.get(), opt.num_warps); // run passes dce.run(module); disassociate.run(module); @@ -233,17 +239,20 @@ std::unique_ptr function::make_bin(ir::module &module, 
driver::d } peephole.run(module); dce.run(module); -// ir::print(module, std::cout); align.run(module); axes.run(module); layouts.run(module); + swizzle.run(module); liveness.run(module); allocation.run(module); if(allocation.allocated_size() > device->max_shared_memory()) - throw std::runtime_error("using too much shared memory"); + throw exception::out_of_shared_memory(); barriers.run(module); +// ir::print(module, std::cout); isel.visit(module, *llvm); std::unique_ptr res(driver::module::create(device, std::move(llvm))); + if(res->spilled() > 256) + throw exception::out_of_registers(); return res; } @@ -265,11 +274,11 @@ void function::make(driver::device *device, options_t opt) { auto ir = make_ir(parser); // triton-ir -> binary std::unique_ptr bin; -// try{ + try{ bin = make_bin(*ir, device, opt); -// }catch(const std::runtime_error&){ -// return nullptr; -// } + }catch(const exception::base&){ + throw; + } // create callable ir::function *tmp = ir->get_function_list()[0]; callers_[opt].reset(new caller(tmp, std::move(bin), opt)); @@ -283,6 +292,7 @@ void function::precompile(driver::device* device, const options_space_t& space) for(const auto& x: space.defines) ranges.push_back(x.second.size()); // functor for source with given option + std::map err; auto do_make = [&](std::vector params) { // compilation options unsigned i = 0; @@ -291,20 +301,73 @@ void function::precompile(driver::device* device, const options_space_t& space) for(auto D: space.defines) opt.defines[D.first] = D.second[params[i++]]; // compile - make(device, opt); + try{ + make(device, opt); + }catch(const exception::base& e){ + err[opt] = e.what(); + } }; // multi-threaded compilation _loop_nest(ranges, do_make); - if(callers_.empty()) - throw std::runtime_error("could not compile kernel"); + if(callers_.empty()){ + std::ostringstream dbg; + dbg << "Auto-Tuner could not find any valid configuration:" << std::endl; + for(auto x: err){ + dbg << "[ "; + dbg << x.first.num_warps << ", "; + dbg << "{ "; + for(const auto& y: x.first.defines) + dbg << '"' << y.first << "\"= \"" << y.second << "\", "; + dbg << " } ] -> " << x.second << std::endl; + } + throw exception::no_valid_configuration(dbg.str()); + } } -std::string function::ptx(driver::device* device, const options_t& opt) { +std::string function::get_asm(asm_mode_t mode, driver::device* device, const options_t& opt) { make(device, opt); const auto& fn = callers_.at(opt); if(!fn) return ""; - return ((driver::cu_module*)fn->parent())->source(); + switch(mode){ + case ASM_LLIR:{ + return fn->parent()->llir(); + } + case ASM_NV_PTX: + case ASM_NV_SASS:{ + std::string ptx = ((driver::cu_module*)fn->parent())->ptx(); + // SASS + std::string input = std::tmpnam(nullptr); + std::string output = std::tmpnam(nullptr); + std::ofstream ofs(input); + ofs << ptx; + ofs.close(); + if(mode == ASM_NV_PTX) + return ptx; + std::string cmd; + int err; + // compile ptx + driver::cu_device* cu_device = (driver::cu_device*)device; + cmd = "ptxas --gpu-name=sm_" + std::to_string(cu_device->compute_capability()) + " " + input + " -o " + input + ".o"; + err = system(cmd.c_str()); + // disassemble + cmd = "cuobjdump --dump-sass " + input + ".o >> " + output; + err = system(cmd.c_str()); + std::regex comment(" *\\/\\* 0x[0-9a-f]+ \\*\\/"); + std::string to_delete = " /*"; + std::ifstream ifs(output); + std::string line; + std::string sass; + while(std::getline(ifs, line)) + if(!std::regex_match(line, comment)) + sass += line + "\n"; + return sass; + } + default: + return ""; + } + + } // 
returns program with best compilation options for given parameter diff --git a/python/examples/batchnorm.py b/python/examples/batchnorm.py deleted file mode 100644 index a69d127c4..000000000 --- a/python/examples/batchnorm.py +++ /dev/null @@ -1,56 +0,0 @@ -import triton -import numpy as np -from enum import Enum - -class MODE(Enum): - TF = 1 - TORCH = 2 - -try: - import tensorflow as tf - mode = MODE.TF -except ModuleNotFoundError: - pass - -try: - import torch - mode = MODE.TORCH -except ModuleNotFoundError: - pass - - -C, H, W, B = 32, 1, 1, 128 - -x = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32) -gamma = np.random.uniform(-1, 1, C).astype(np.float32) -beta = np.random.uniform(-1, 1, C).astype(np.float32) -dy = np.random.uniform(-1, 1, (C, H, W, B)).astype(np.float32) - -if mode == MODE.TORCH: - fw_x = torch.from_numpy(x).cuda() - fw_gamma = torch.from_numpy(gamma).cuda() - fw_beta = torch.from_numpy(beta).cuda() - fw_dy = torch.from_numpy(dy).cuda() - # register gradients - fw_x.requires_grad_(True) - fw_gamma.requires_grad_(True) - fw_beta.requires_grad_(True) - # execute - fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4) - fw_y.backward(fw_dy) - -if mode == MODE.TF: - fw_x = tf.placeholder(shape=x.shape, dtype=x.dtype) - fw_gamma = tf.placeholder(shape=gamma.shape, dtype=gamma.dtype) - fw_beta = tf.placeholder(shape=beta.shape, dtype=beta.dtype) - fw_dy = tf.placeholder(shape=dy.shape, dtype=dy.dtype) - # execute - fw_y = triton.ops.batchnorm(fw_x, fw_gamma, fw_beta, 1e-4) - fw_mean, fw_var = tf.nn.moments(fw_x, [1, 2, 3]) - fw_dx, fw_dgamma, fw_dbeta = tf.gradients(fw_y, [fw_x, fw_gamma, fw_beta], fw_dy) - # run - sess = tf.InteractiveSession() - feed_dict = {fw_x: x, fw_gamma: gamma, fw_beta: beta, fw_dy: dy} - sess.run(tf.global_variables_initializer()) - result = sess.run([fw_dx, fw_dgamma, fw_dbeta], feed_dict=feed_dict) - print(result) \ No newline at end of file diff --git a/python/examples/einsum.py b/python/examples/einsum.py deleted file mode 100644 index 1c6e078d1..000000000 --- a/python/examples/einsum.py +++ /dev/null @@ -1,213 +0,0 @@ -import triton -import torch -from torch.utils.cpp_extension import load -import numpy as np -#import utils -from time import time - -torch.manual_seed(0) - -#torch.backends.cudnn.benchmark = True - -configs = [] - -# Matrix multiplication -MNK = [ - (512, 512 ,512), - (2048, 2048, 2048), - #(8192, 8192, 8192), - - (64, 64, 64000), - (64, 64, 128000), - (256, 256, 64000), - (256, 256, 128000), - - (1536, 16, 1536), - (1536, 32, 1536), - (1536, 64, 1536), - # (1536, 128, 1536), - # (4096, 16, 4096), - # (4096, 32, 4096), - # (4096, 64, 4096), - # (4096, 128, 4096), - - # (127008, 768, 576) - ] -for M, N, K in MNK: - matmul = lambda a, b: torch.matmul(a, b) - configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict(), None, None, None)] -#for M, N, K in MNK: -# matmul = lambda a, b: torch.matmul(a.t(), b) -# configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict(), None, None, None)] -#for M, N, K in MNK: -# matmul = lambda a, b: torch.matmul(a, b.t()) -# configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict(), None, None, None)] - -# Relative attention -NTHSE = [ - (16, 512, 1, 64, 64), - # (16, 512, 1, 128, 128), - # (16, 512, 1, 256, 256), - # (16, 512, 1, 256, 512), - (16, 512, 8, 64, 64), - # (16, 512, 8, 128, 128), - # (16, 512, 8, 256, 256), - # (16, 512, 8, 256, 512), - - # (64, 1024, 1, 64, 64), - (64, 1024, 1, 128, 128), - # (64, 1024, 1, 256, 256), - # (64, 1024, 1, 256, 512), - # (64, 
1024, 8, 64, 64), - (64, 1024, 8, 128, 128), - # (64, 1024, 8, 256, 256), - # (64, 1024, 8, 256, 512), - - # (128, 1024, 1, 64, 64), - # (128, 1024, 1, 128, 128), - # (128, 1024, 1, 256, 256), - (128, 1024, 1, 256, 512), - # (128, 1024, 8, 64, 64), - # (128, 1024, 8, 128, 128), - # (128, 1024, 8, 256, 256), - #(128, 1024, 8, 256, 512) - ] -#for N, T, H, S, E in NTHSE: -# configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict(), None, None, None)] -#for N, T, H, S, E in NTHSE: -# configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict(), None, None, None)] -#for N, T, H, S, E in NTHSE: -# configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict(), None, None, None)] - -# 1D Dense convolution -NCHKR = [ - #(1, 1152, 12602, 512, 3) - ] -for N, C, H, K, R in NCHKR: - torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1)) - configs += [([N, C, H], - [C, R, K], - [N, K, H - R + 1], - torch_fn, - 'nc(h+r),crk->nkh', - dict(), None, None, None)] - -# 2D Dense convolution -NCHWKRS = [ - #(8, 64, 128, 128, 768, 3, 3), - #(128, 3, 32, 32, 64, 3, 3), - #(1, 1024, 32, 112, 112, 1024, 3, 3), - #(8, 512, 32, 32, 1024, 3, 3) - ] -for N, C, G, H, W, K, R, S in NCHWKRS: - stride = 2 - torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2), stride=stride, groups=G) - P = (H - R + 1) // stride - Q = (W - S + 1) // stride - transform_a = lambda a: a.view(N, G, C // G, H, W) - transform_b = lambda b: b.view(C // G, R, S, G, K // G) - transform_c = lambda c: c.view(N, K, P, Q) - configs += [([N, C, H, W], - [C // G, R, S, K], - [N, G, K // G, P, Q], - torch_fn, - 'ngc(h*2+r)(w*2+s),crsgk->ngkhw', - dict(), transform_a, transform_b, transform_c)] - - -# 3D Dense Convolution -NCDHWKTRS = [ - #(8, 32, 27, 100, 100, 64, 3, 3, 3), - #(8, 64, 23, 48, 48, 256, 3, 3, 3), - #(8, 256, 19, 22, 22, 640, 3, 3, 3), - #(8, 640, 15, 36, 36, 384, 3, 3, 3) - ] -for N, C, D, H, W, K, T, R, S in NCDHWKTRS: - torch_fn = lambda a, b: torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3)) - configs += [([N, C, D, H, W], - [C, T, R, S, K], - [N, K, D - T + 1, H - R + 1, W - R + 1], - torch_fn, - 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw', - dict(), None, None, None)] - - -# Shift convolution -shift_cuda = torch.utils.cpp_extension.load( - 'shift_cuda', ['kernels/shift_cuda.cpp', - 'kernels/shift_cuda_kernel.cu'], - extra_cflags=['-O3']) -class shift(torch.autograd.Function): - @staticmethod - def forward(ctx, x, shift): - ctx.save_for_backward(shift) - return shift_cuda.forward(x, shift) - - @staticmethod - def backward(ctx, grad_output): - shift, = ctx.saved_tensors - grad_output = shift_cuda.backward(grad_output, shift) - - return grad_output, None - -NCHWKRS = [ - #(8, 64, 128, 128, 128, 3, 3), - #(8, 128, 64, 64, 256, 3, 3), - #(8, 256, 32, 32, 512, 3, 3), - #(8, 512, 32, 32, 1024, 3, 3) - ] -for N, C, H, W, K, R, S in NCHWKRS: - shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2 - shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2 - def shift_conv(a, b, **kwargs): - shift_h, shift_w = kwargs['sh'], kwargs['sw'] - shift_torch = np.column_stack((shift_w*-1, shift_h*-1)) - shift_torch = torch.from_numpy(shift_torch).cuda() - a = shift.apply(a, shift_torch) - b = b.permute(1, 0) - b = b.reshape(b.shape[0], b.shape[1], 1, 1) - return torch.nn.functional.conv2d(a, b) - configs += [([N, C, H, W], - [C, K], - [N, K, H, W], - shift_conv, - 'nc(h + sh[c])(w + sw[c]),ck->nkhw', - {'sh': shift_h, 'sw': shift_w}, - 
None, None, None)] - -# Benchmark -torch.set_num_threads(1) -for a_shape, b_shape, c_shape, torch_fn, expr, arrays, \ - transform_a, transform_b, transform_c in configs: - dtype = torch.cuda.FloatTensor - # initialize input tensors - a = torch.rand(*a_shape).type(dtype).cuda() - b = torch.rand(*b_shape).type(dtype).cuda() - # reference output - if torch_fn: - rc = torch_fn(a, b, **arrays) - else: - rc = torch.einsum(expr, a, b) - # triton output - ta = a if transform_a is None else transform_a(a) - tb = b if transform_b is None else transform_b(b) - tc = torch.empty(c_shape, device=a.device) - triton.ops.einsum(expr, ta, tb, tc, arrays = arrays, bench = True) - ctx = triton.ops._einsum.registry[tc] - tc = tc if transform_c is None else transform_c(tc) - # performance relative to equivalent matrix multiplication - B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K - cmp_eqbmm = True - if cmp_eqbmm: - a = torch.rand(B, M, K).type(dtype).cuda() - b = torch.rand(B, K, N).type(dtype).cuda() - c = torch.empty((B, M, N), device=a.device).cuda() - tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, c, bench = True) - ratio = triton.ops._einsum.registry[tmmc].forward_ms / ctx.forward_ms - cmp_str = f'({ratio:4.2f})' - else: - cmp_str = '' - # test and benchmark - bench = 2. * B * M * N * K / ctx.forward_ms * 1e-3 - diff = (tc - rc).abs().max() / rc.abs().max() - print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}') diff --git a/python/examples/kernels/shift_cuda.cpp b/python/examples/kernels/shift_cuda.cpp deleted file mode 100644 index b7a769feb..000000000 --- a/python/examples/kernels/shift_cuda.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include - -#include - -// CUDA forward declarations - -at::Tensor shift_cuda_forward( - const at::Tensor input, - const at::Tensor shift); - -at::Tensor shift_cuda_backward( - const at::Tensor grad_input, - const at::Tensor shift); - -// C++ interface - -// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -at::Tensor shift_forward( - const at::Tensor input, - const at::Tensor shift) { - CHECK_INPUT(input); - CHECK_INPUT(shift); - - return shift_cuda_forward(input, shift); -} - -at::Tensor shift_backward( - const at::Tensor grad_input, - const at::Tensor shift) { - CHECK_INPUT(grad_input); - CHECK_INPUT(shift); - return shift_cuda_backward(grad_input, shift); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &shift_forward, "Shift forward (CUDA)"); - m.def("backward", &shift_backward, "Shift backward (CUDA)"); -} diff --git a/python/examples/kernels/shift_cuda_kernel.cu b/python/examples/kernels/shift_cuda_kernel.cu deleted file mode 100644 index ca56b6b0f..000000000 --- a/python/examples/kernels/shift_cuda_kernel.cu +++ /dev/null @@ -1,111 +0,0 @@ -#include - -#include -#include - -#include - -namespace { -template -__global__ void shift_cuda_forward_kernel( - const scalar_t* __restrict__ input, - const int32_t* __restrict__ shift, - scalar_t* __restrict__ output, - const int32_t B, - const int32_t C, - const int32_t H, - const int32_t W) { - const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - const int32_t size = B*C*H*W; - - const int32_t CHW = C*H*W; - const int32_t HW = H*W; - const int32_t b = idx / CHW; - const int32_t c = (idx - b*CHW) / HW; - const int32_t h = (idx - b*CHW - c*HW) / W; - const int32_t w = idx - b*CHW - c*HW - h*W; - const int32_t target_w = w + shift[2*c]; - const int32_t target_h = h + shift[2*c + 1]; - const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w; - if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) { - output[target_idx] = input[idx]; - } -} - -template -__global__ void shift_cuda_backward_kernel( - const scalar_t* __restrict__ grad_input, - scalar_t* __restrict__ grad_output, - const int32_t* __restrict__ shift, - const int32_t B, - const int32_t C, - const int32_t W, - const int32_t H) { - const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; - const int32_t size = B*C*W*H; - const int32_t CWH = C*W*H; - const int32_t WH = W*H; - const int32_t b = idx / CWH; - const int32_t c = (idx - b*CWH) / WH; - const int32_t w = (idx - b*CWH - c*WH) / W; - const int32_t h = idx - b*CWH - c*WH - w*H; - const int32_t target_w = w - shift[2*c]; - const int32_t target_h = h - shift[2*c + 1]; - const int32_t target_idx = b*CWH + c*WH + target_w*W + target_h; - if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) { - grad_output[target_idx] = grad_input[idx]; - } -} -} // namespace - -at::Tensor shift_cuda_forward( - const at::Tensor input, - const at::Tensor shift) { - const auto B = input.size(0); - const auto C = input.size(1); - const auto H = input.size(2); - const auto W = input.size(3); - const auto size = B*C*W*H; - const int threads = 1024; - const int blocks = (size + threads - 1) / threads; - auto output = at::zeros_like(input); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] { - shift_cuda_forward_kernel<<>>( - input.data(), - shift.data(), - output.data(), - B, - C, - H, - W); - })); - - return output; -} - -at::Tensor shift_cuda_backward( - const at::Tensor grad_input, - const at::Tensor shift) { - const auto B = grad_input.size(0); - const auto C = grad_input.size(1); - const auto H = 
grad_input.size(2); - const auto W = grad_input.size(3); - const auto size = B*C*W*H; - const int threads = 1024; - const int blocks = (size + threads - 1) / threads; - auto grad_output = at::zeros_like(grad_input); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] { - shift_cuda_backward_kernel<<>>( - grad_input.data(), - grad_output.data(), - shift.data(), - B, - C, - H, - W); - })); - - return grad_output; -} diff --git a/python/examples/test.py b/python/examples/test.py deleted file mode 100644 index c2ff7d473..000000000 --- a/python/examples/test.py +++ /dev/null @@ -1,109 +0,0 @@ -import triton -import numpy -import torch -import itertools - -torch.manual_seed(0) -numpy.random.seed(0) - -def to_sparse(expr, data, layout, shape, block): - # shape of result - sparse = None - shape_ret = [] - for i, d in enumerate(expr): - if d.isupper() and sparse is None: - sparse = i - shape_ret.append(int(layout.sum())) - if d.isupper(): - shape_ret.append(block[d]) - else: - shape_ret.append(shape[i]) - # iterator - steps = [block[d] if d.isupper() else 1 for d in expr] - it = [range(0, shape[i], steps[i]) for i in range(len(expr))] - # create result - ret = torch.empty(*shape_ret, dtype=data.dtype, device=data.device) - blockid = 0 - nzblockid = 0 - for curr in itertools.product(*it): - if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]): - blockid = 0 - nzblockid = 0 - data_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))] - ret_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))] - ret_slice.insert(sparse, nzblockid) - if int(layout.view(-1)[blockid]): - ret[ret_slice] = data[data_slice] - nzblockid += 1 - blockid += 1 - return ret - -def to_dense(expr, data, layout, shape, block): - sparse = None - for i, d in enumerate(expr): - if d.isupper() and sparse is None: - sparse = i - - ret = torch.zeros(*shape, dtype=data.dtype, device=data.device) - steps = [block[d] if d.isupper() else 1 for d in expr] - it = [range(0, shape[i], steps[i]) for i in range(len(expr))] - blockid = 0 - nzblockid = 0 - for curr in itertools.product(*it): - if all([curr[i] == it[i][0] for i in range(len(curr)) if expr[i].isupper()]): - blockid = 0 - nzblockid = 0 - ret_slice = [slice(curr[i], curr[i] + steps[i], 1) for i in range(len(curr))] - data_slice = [slice(0, block[expr[i]], 1) if expr[i].isupper() else slice(curr[i], curr[i] + 1) for i in range(len(curr))] - data_slice.insert(sparse, nzblockid) - if int(layout.view(-1)[blockid]): - ret[ret_slice] = data[data_slice] - nzblockid += 1 - blockid += 1 - return ret - -def test_expr(expr, shape, blocks): - # decompose expr - expr_a, expr_bc = expr.split(",") - expr_b, expr_c = expr_bc.split("->") - # check with argument is sparse - sparse_a = any(x.isupper() for x in expr_a) - sparse_b = any(x.isupper() for x in expr_b) - sparse_c = any(x.isupper() for x in expr_c) - # allocate data - shape_a = [shape[d.lower()] for d in expr_a] - shape_b = [shape[d.lower()] for d in expr_b] - shape_c = [shape[d.lower()] for d in expr_c] - ref_a = torch.rand(*shape_a, device='cuda') - ref_b = torch.rand(*shape_b, device='cuda') - ref_c = torch.zeros(*shape_c, device='cuda') - # layouts - layout_a = [shape[d.lower()]//blocks[d] for d in expr_a if d.isupper()] - layout_b = [shape[d.lower()]//blocks[d] for d in expr_b if d.isupper()] - layout_c = [shape[d.lower()]//blocks[d] for d in expr_c if d.isupper()] - layout_a = torch.randint(0, 2, 
layout_a, device='cuda') - layout_b = torch.randint(0, 2, layout_b, device='cuda') - layout_c = torch.randint(0, 2, layout_c, device='cuda') - # triton computation - triton_a = to_sparse(expr_a, ref_a, layout_a, shape_a, blocks) if sparse_a else ref_a - triton_b = to_sparse(expr_b, ref_b, layout_b, shape_b, blocks) if sparse_b else ref_b - layouts = {expr_a: layout_a, expr_b: layout_b, expr_c: layout_c} - triton_c = triton.ops.einsum(expr, triton_a, triton_b, layouts, blocks) - torch.cuda.synchronize() - # reference computation - ref_a = to_dense(expr_a, triton_a, layout_a, shape_a, blocks) if sparse_a else ref_a - ref_b = to_dense(expr_b, triton_b, layout_b, shape_b, blocks) if sparse_b else ref_b - ref_c = torch.einsum(expr.lower(), ref_a, ref_b) - if sparse_c: - ref_c = to_sparse(expr_c, ref_c, layout_c, shape_c, blocks) - torch.cuda.synchronize() - print((ref_c - triton_c).abs().max()) - - - - -# shape characteristics -test_expr('bHMK,bhkn->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'K': 32}) -test_expr('bhmk,bHKN->bhmn', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'K': 32, 'N': 32}) -test_expr('bhmk,bhkn->bHMN', {'b': 2, 'h': 2, 'm': 256, 'k': 256, 'n': 256}, {'H': 1, 'M': 32, 'N': 32}) - diff --git a/python/examples/tutorials/vec_add.py b/python/examples/tutorials/add.py similarity index 100% rename from python/examples/tutorials/vec_add.py rename to python/examples/tutorials/add.py diff --git a/python/examples/tutorials/conv2d.py b/python/examples/tutorials/conv2d.py index 8e8f38491..c30009cdf 100644 --- a/python/examples/tutorials/conv2d.py +++ b/python/examples/tutorials/conv2d.py @@ -171,7 +171,7 @@ class _conv(torch.autograd.Function): _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines)) delta, kernel = _conv.kernel[dtype] # allocate output - c = triton.empty([Z, CO, P, Q], dtype=dtype) + c = torch.empty([Z, CO, P, Q], dtype=dtype) # enqueue grid = lambda opt: [triton.cdiv(Z*P*Q, opt.d('TM')), triton.cdiv(CO, opt.d('TN'))] diff --git a/python/examples/tutorials/mat_copy.py b/python/examples/tutorials/copy.py similarity index 100% rename from python/examples/tutorials/mat_copy.py rename to python/examples/tutorials/copy.py diff --git a/python/examples/tutorials/mat_mul.py b/python/examples/tutorials/matmul.py similarity index 77% rename from python/examples/tutorials/mat_mul.py rename to python/examples/tutorials/matmul.py index 4acbebb11..9b6904260 100644 --- a/python/examples/tutorials/mat_mul.py +++ b/python/examples/tutorials/matmul.py @@ -3,6 +3,9 @@ import triton class _dot(torch.autograd.Function): src = """ +#define STM 4 +#define STN 4 + __global__ void dot(TYPE * A __noalias __readonly __aligned(16), TYPE * B __noalias __readonly __aligned(16), TYPE * C __noalias __aligned(16), @@ -14,20 +17,26 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), int ldb __multipleof(8), int ldc __multipleof(8)) { // prologue - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int ridz = get_program_id(2); - int gridx = M / TM; - int gridy = N / TN; - int rid = ridx + ridy * gridx; - ridx = rid / gridy; - ridy = rid % gridy; - int rm[TM] = ridx * TM + 0 ... TM; - int rn[TN] = ridy * TN + 0 ... 
TN; + int pid = get_program_id(0); + int pidz = get_program_id(2); + int gridm = M / TM; + int gridn = N / TN; + int stgridm = (gridm + STM - 1) / STM; + int stgridn = (gridn + STN - 1) / STN; + int stid = pid / (STM * STN); + int laneid = pid % (STM * STN); + int stm = stid / stgridn; + int stn = stid % stgridn; + int lanem = laneid / STN; + int lanen = laneid % STN; + int pidm = stm*STM + lanem; + int pidn = stn*STN + lanen; + int rm[TM] = pidm * TM + 0 ... TM; + int rn[TN] = pidn * TN + 0 ... TN; // reduction splitting K = K / TZ; - int rk[TK] = ridz * K + 0 ... TK; + int rk[TK] = pidz * K + 0 ... TK; // pointers to operands int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; @@ -44,11 +53,11 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), // reduction loop float acc[TM, TN] = 0; for(int k = K; k > 0; k -= TK){ - acc += a @ b; bool checka[TM, TK] = k > TK; bool checkb[TK, TN] = k > TK; pa += TK * STRIDE_AK; pb += TK * STRIDE_BK; + acc += a @ b; a = *?(checka)pa; b = *?(checkb)pb; } @@ -56,8 +65,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), TYPE c[TM, TN] = acc; // epilogue - int rxm[TM] = ridx * TM + 0 ... TM; - int rxn[TN] = ridy * TN + 0 ... TN; + int rxm[TM] = pidm * TM + 0 ... TM; + int rxn[TN] = pidn * TN + 0 ... TN; int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :]; TYPE* pc[TM, TN] = C + offc; bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N); @@ -66,7 +75,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), *?(checkc) pc = c; #else // accumulate partial result using spin-locks - int *plock = locks + rid; + int *plock = locks + pid; int *pcount = plock + get_num_programs(0) * get_num_programs(1); for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); int count = *pcount; @@ -100,7 +109,7 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), 'STRIDE_BN': '1', 'STRIDE_BK': 'ldb', 'TM' : [128], 'TN' : [128], - 'TK' : [16], + 'TK' : [32], 'TZ' : [1] } _dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines) @@ -109,9 +118,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16), M, K = a.shape K, N = b.shape c = torch.empty([M,N], dtype=dtype, device=a.device) + print(kernel.asm('sass', c.device)) + print(kernel.asm('ptx', c.device)) # enqueue - grid = lambda opt: [triton.cdiv(M, opt.d('TM')), - triton.cdiv(N, opt.d('TN'))] + grid = lambda opt: [triton.cdiv(M, opt.d('TM'))*triton.cdiv(N, opt.d('TN'))] time = kernel(a, b, c, 1., M, N, K, a.stride(0), b.stride(0), c.stride(0), grid=grid) return c @@ -130,6 +140,4 @@ b = torch.rand((K, N)).cuda().half() zc = torch.matmul(a,b) zc_ = dot(a,b) - - print(torch.allclose(zc, zc_)) diff --git a/python/examples/tutorials/mat_transpose.py b/python/examples/tutorials/trans.py similarity index 100% rename from python/examples/tutorials/mat_transpose.py rename to python/examples/tutorials/trans.py diff --git a/python/setup.py b/python/setup.py index d12e45dba..898b7f6c5 100644 --- a/python/setup.py +++ b/python/setup.py @@ -111,7 +111,7 @@ setup( author_email='ptillet@g.harvard.edu', description='A language and compiler for custom Deep Learning operations', long_description='', - packages=['triton', 'triton/_C', 'triton/ops'], + packages=['triton', 'triton/_C'], install_requires=['numpy', 'torch', 'sympy'], package_data={'': data}, ext_modules=[CMakeExtension('triton', 'triton/_C/')], diff --git a/python/src/bindings.cc b/python/src/bindings.cc index d67ddc533..f6f34836d 100644 
--- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -38,7 +38,7 @@ void delete_grid(const map_key_t& key) { void register_fn(const map_key_t& key, const std::string& src, - const rt::function::options_space_t& opt) { + const rt::options_space_t& opt) { if(id_fn_map.find(key) == id_fn_map.end()) id_fn_map[key].reset(new rt::function(src, opt, "")); } @@ -47,9 +47,9 @@ void delete_fn(const map_key_t& key) { id_fn_map.erase(key); } -std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt) { - triton::driver::cu_device device(torch_get_cuda_device(key.second), false); - return id_fn_map[key]->ptx(&device, opt); +std::string get_fn_asm(const map_key_t& key, rt::asm_mode_t mode, const rt::options_t& opt) { + triton::driver::cu_device device(key.second, false); + return id_fn_map[key]->get_asm(mode, &device, opt); } void cleanup() { @@ -63,7 +63,7 @@ size_t make_op_id() { /* Function signature */ void make_module(const std::string& src, ir::module* ir, - const runtime::function::options_space_t& opt) { + const runtime::options_space_t& opt) { std::string copy = triton::runtime::function::preheader() + src; // pre-process TokenSequence tokens; @@ -80,7 +80,7 @@ void make_module(const std::string& src, ir::module* ir, } std::vector get_fn_signature(const std::string& src, - const runtime::function::options_space_t& opt) { + const runtime::options_space_t& opt) { // triton-ir code-gen ir::context ctx; auto ir = std::shared_ptr(new ir::module("", ctx)); @@ -95,8 +95,8 @@ std::vector get_fn_signature(const std::string& src, return ret; } -typedef triton::runtime::function::options_t options_t; -typedef triton::runtime::function::options_space_t options_space_t; +typedef triton::runtime::options_t options_t; +typedef triton::runtime::options_space_t options_space_t; PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; @@ -112,6 +112,10 @@ PYBIND11_MODULE(libtriton, m) { .value("float", rt::FLOAT_T) .value("double", rt::DOUBLE_T) .value("buffer", rt::BUFFER_T); + + pybind11::enum_(m, "asm_mode") + .value("ptx", rt::ASM_NV_PTX) + .value("sass", rt::ASM_NV_SASS); pybind11::class_(m, "options") .def(pybind11::init<>()) @@ -126,7 +130,7 @@ PYBIND11_MODULE(libtriton, m) { // hooks into triton constructs since frameworks may not use pybind11 m.def("get_fn_signature", &get_fn_signature); - m.def("get_fn_ptx", &get_fn_ptx); + m.def("get_fn_asm", &get_fn_asm); m.def("register_grid", ®ister_grid); m.def("delete_grid", &delete_grid); m.def("register_fn", ®ister_fn); diff --git a/python/src/launch.cc b/python/src/launch.cc index ad0cac7e9..999d9c595 100644 --- a/python/src/launch.cc +++ b/python/src/launch.cc @@ -59,16 +59,12 @@ void synchronize(int64_t dev_id) { } } -torch::Tensor raw_like(torch::Tensor x){ +torch::Tensor cuda_empty_like(torch::Tensor x){ if(x.nbytes() == 0) return torch::empty_like(x); - C10_CUDA_CHECK(cudaSetDevice(x.device().index())); - auto shape = x.sizes(); - CUdeviceptr data; - triton::driver::dispatch::cuMemAlloc(&data, x.nbytes()); - auto deleter = [data](void* ptr) { triton::driver::dispatch::cuMemFree_v2(data); }; - auto ret = torch::from_blob((void*)data, shape, deleter, x.options()); - ret.copy_(x); + void* data; + cudaMalloc(&data, x.nbytes()); + auto ret = torch::from_blob((void*)data, x.sizes(), x.strides(), [data](void* ptr) { cudaFree(data); }, x.options()); return ret; } @@ -94,6 +90,6 @@ void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args, static auto registry = torch::RegisterOperators() 
.op("triton::launch_kernel", &launch_kernel) - .op("triton::raw_like", &raw_like) + .op("triton::cuda_empty_like", &cuda_empty_like) .op("triton::cdiv_sum", &cdiv_sum) .op("triton::synchronize", &synchronize); diff --git a/python/triton/__init__.py b/python/triton/__init__.py index b1e9e8bc6..a3c74c5f3 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -1,7 +1,4 @@ from .kernel import * -import triton.ops -#import triton.nn - # clean-up libtriton resources import atexit diff --git a/python/triton/kernel.py b/python/triton/kernel.py index c002c803e..83ea60a58 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -68,8 +68,17 @@ class kernel: size = sum([sizes[x] for x in arg_types]) self.tys = ''.join([codes[x] for x in arg_types]) - def ptx(self, device, **kwargs): + def asm(self, mode, device, **kwargs): dev_id = device.index + # assembly mode + supported = { + 'ptx': libtriton.asm_mode.ptx, + 'sass': libtriton.asm_mode.sass, + } + if mode not in supported: + raise('ASM mode must be in ', supported.keys()) + mode = supported[mode] + # disambiguates #defines libtriton.register_fn((self.op_id, dev_id), self.src, self.opt) def _single_value_or_err(x, key): if isinstance(x, list) and len(x) == 1: @@ -86,15 +95,18 @@ class kernel: opt = libtriton.options() opt.num_warps = _single_value_or_err(self.opt.num_warps, 'num_warps') opt.defines = defines - return libtriton.get_fn_ptx((self.op_id, dev_id), opt) + # run + return libtriton.get_fn_asm((self.op_id, dev_id), mode, opt) def __call__(self, *args, **kwargs): if 'TRITON_DEBUG_MODE' in os.environ: _args = args - args = [x for x in args] + args = [x.clone() if isinstance(x, torch.Tensor) else x for x in _args] for i in range(len(args)): if isinstance(args[i], torch.Tensor): - args[i] = torch.ops.triton.raw_like(args[i]) + args[i] = torch.ops.triton.cuda_empty_like(args[i]) + args[i].copy_(_args[i]) + torch.cuda.synchronize() for x in args: if isinstance(x, torch.Tensor): device = x.device.index @@ -116,6 +128,8 @@ class kernel: constants = list(kwargs['constants'].values()) if 'constants' in kwargs else [] torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants) if 'TRITON_DEBUG_MODE' in os.environ: + torch.cuda.synchronize() for i in range(len(args)): if isinstance(args[i], torch.Tensor): - _args[i].copy_(args[i]) \ No newline at end of file + _args[i].copy_(args[i].clone()) + args = _args \ No newline at end of file diff --git a/python/triton/nn/__init__.py b/python/triton/nn/__init__.py deleted file mode 100644 index 84c1a0a78..000000000 --- a/python/triton/nn/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .conv import replace_conv2d -from .attention import replace_mah \ No newline at end of file diff --git a/python/triton/nn/attention.py b/python/triton/nn/attention.py deleted file mode 100644 index 84d198b2e..000000000 --- a/python/triton/nn/attention.py +++ /dev/null @@ -1,312 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import triton - -def bmm(x, w, mask = None): - b, m, k = x.size() - b, k, n = w.size() - out = torch.empty([b, m, n], device=x.device) - triton.ops.einsum('bmk,bkn->bmn', x, w, out, mask=mask, bench=False) - return out - -def multi_head_attention_forward(query, # type: Tensor - key, # type: Tensor - value, # type: Tensor - embed_dim_to_check, # type: int - num_heads, # type: int - in_proj_weight, # type: Tensor - in_proj_bias, # type: Tensor - bias_k, # type: Optional[Tensor] - bias_v, # type: Optional[Tensor] - add_zero_attn, # 
type: bool - dropout_p, # type: float - out_proj_weight, # type: Tensor - out_proj_bias, # type: Tensor - training=True, # type: bool - key_padding_mask=None, # type: Optional[Tensor] - need_weights=True, # type: bool - attn_mask=None, # type: Optional[Tensor] - use_separate_proj_weight=False, # type: bool - q_proj_weight=None, # type: Optional[Tensor] - k_proj_weight=None, # type: Optional[Tensor] - v_proj_weight=None, # type: Optional[Tensor] - static_k=None, # type: Optional[Tensor] - static_v=None, # type: Optional[Tensor] - acc_bitmask=None - ): - # type: (...) -> Tuple[Tensor, Optional[Tensor]] - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - See "Attention Is All You Need" for more details. - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - bias_k, bias_v: bias of the key and value sequences to be added at dim=0. - add_zero_attn: add a new batch of zeros to the key and - value sequences at dim=1. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: mask that prevents attention to certain positions. This is an additive mask - (i.e. the values will be added to the attention layer). - use_separate_proj_weight: the function accept the proj. weights for query, key, - and value in differnt forms. If false, in_proj_weight will be used, which is - a combination of q_proj_weight, k_proj_weight, v_proj_weight. - q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. - static_k, static_v: static key and value used for attention operators. - - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. - - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, - N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. - - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, - N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. 
- """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size() == value.size() - - head_dim = embed_dim // num_heads - assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 - - - if not use_separate_proj_weight: - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = F.linear(query, _w, _b) - - if key is None: - assert value is None - k = None - v = None - else: - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = F.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = F.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = F.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = F.linear(value, _w, _b) - else: - q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) - len1, len2 = q_proj_weight_non_opt.size() - assert len1 == embed_dim and len2 == query.size(-1) - - k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) - len1, len2 = k_proj_weight_non_opt.size() - assert len1 == embed_dim and len2 == key.size(-1) - - v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) - len1, len2 = v_proj_weight_non_opt.size() - assert len1 == embed_dim and len2 == value.size(-1) - - if in_proj_bias is not None: - q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) - k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) - v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) - else: - q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) - k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) - v = F.linear(value, v_proj_weight_non_opt, in_proj_bias) - q = q * scaling - - if bias_k is not None and bias_v is not None: - if static_k is None and static_v is None: - k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) - v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) - if attn_mask is not None: - attn_mask = torch.cat([attn_mask, - torch.zeros((attn_mask.size(0), 1), - dtype=attn_mask.dtype, - device=attn_mask.device)], dim=1) - if key_padding_mask is not None: - key_padding_mask = torch.cat( - [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), - dtype=key_padding_mask.dtype, - device=key_padding_mask.device)], dim=1) - else: - assert static_k is None, "bias cannot be added to static key." - assert static_v is None, "bias cannot be added to static value." 
- else: - assert bias_k is None - assert bias_v is None - - q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) - if k is not None: - k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - if v is not None: - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - if static_k is not None: - assert static_k.size(0) == bsz * num_heads - assert static_k.size(2) == head_dim - k = static_k - - if static_v is not None: - assert static_v.size(0) == bsz * num_heads - assert static_v.size(2) == head_dim - v = static_v - - src_len = k.size(1) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz - assert key_padding_mask.size(1) == src_len - - if add_zero_attn: - src_len += 1 - k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) - v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) - if attn_mask is not None: - attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), - dtype=attn_mask.dtype, - device=attn_mask.device)], dim=1) - if key_padding_mask is not None: - key_padding_mask = torch.cat( - [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), - dtype=key_padding_mask.dtype, - device=key_padding_mask.device)], dim=1) - - - attn_output_weights = bmm(q, k.transpose(1, 2), mask=acc_bitmask) - assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] - - if attn_mask is not None: - attn_mask = attn_mask.unsqueeze(0) - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float('-inf'), - ) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) - - attn_output_weights = F.softmax( - attn_output_weights, dim=-1) - attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training) - - attn_output = bmm(attn_output_weights, v, mask=acc_bitmask) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - -class MultiheadAttention(nn.modules.activation.MultiheadAttention): - - def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, acc_bitmask=None): - super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim) - self.acc_bitmask = acc_bitmask - - - def forward(self, query, key, value, key_padding_mask=None, - need_weights=True, attn_mask=None): - if not self._qkv_same_embed_dim: - return multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, use_separate_proj_weight=True, - q_proj_weight=self.q_proj_weight, 
k_proj_weight=self.k_proj_weight, - v_proj_weight=self.v_proj_weight, - acc_bitmask=self.acc_bitmask) - else: - return multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, - acc_bitmask=self.acc_bitmask) - - -def replace_mah(model, mask = None): - for child_name, child in model.named_children(): - if isinstance(child, nn.modules.activation.MultiheadAttention): - add_bias_kv = child.bias_k is not None - device = child.in_proj_weight.device - mah = MultiheadAttention(child.embed_dim, child.num_heads, - dropout=child.dropout, add_bias_kv=add_bias_kv, - add_zero_attn=child.add_zero_attn, kdim=child.kdim, - vdim=child.vdim, acc_bitmask=mask).to(device) - for yparam, xparam in zip(mah.parameters(), child.parameters()): - yparam.data.copy_(xparam.data) - setattr(model, child_name, mah) - else: - replace_mah(child, mask) \ No newline at end of file diff --git a/python/triton/nn/conv.py b/python/triton/nn/conv.py deleted file mode 100644 index a6966f6ae..000000000 --- a/python/triton/nn/conv.py +++ /dev/null @@ -1,166 +0,0 @@ -import triton -import torch.nn as nn -import torch -import torch.nn.functional as F - -class _conv2d(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, weight, bias, - stride, padding, dilation, groups, - acc_bitmask): - assert dilation == (1, 1) - assert groups == 1 - assert bias == None - pad_h, pad_w = padding - stride_h, stride_w = stride - n, c, h, w = x.size() - k, c, r, s = weight.size() - # allocate output - p = (h + 2*padding[0] - r)//stride[0] + 1 - q = (w + 2*padding[1] - s)//stride[1] + 1 - output = torch.empty((n, k, p, q), dtype=x.dtype, device=x.device) - # padding - if pad_h or pad_w: - x = triton.ops._einsum.pad(x, [pad_w, pad_w, pad_h, pad_h]) - # convolution - triton.ops.einsum(f'nc(h*stride_h + r - pad_h)(w*stride_w + s - pad_w),kcrs->nkhw', - x, weight, mask=acc_bitmask, - output=output, - values = {'pad_h': pad_h, - 'stride_h': stride_h, - 'pad_w': pad_w, - 'stride_w': stride_w}) - # prepare backprop - ctx.save_for_backward(x, weight) - ctx.stride = stride - ctx.padding = padding - ctx.acc_bitmask = acc_bitmask - # return - return output - - @staticmethod - def backward(ctx, dy): - # retrieve contextual information - x, weight = ctx.saved_tensors - stride = ctx.stride - padding = ctx.padding - acc_bitmask = ctx.acc_bitmask - # gradient of the input - dx = None - if ctx.needs_input_grad[0]: - # dy must be padded - n, k, p, q = dy.size() - n, c, h, w = x.size() - k, c, r, s = weight.size() - dypad = triton.ops._einsum.pad(dy, [4, 4, 4, 4]) - # have to be careful here - # the gradient of strided conv is a conv over a sparse image - # which can be decomposed as a set of smaller convs - dx = torch.empty_like(x) - for offh in range(stride[0]): - for offw in range(stride[1]): - poffh = (offh + padding[0]) % stride[0] - poffw = (offw + padding[1]) % stride[1] - pad_h = int((padding[0] + (stride[0] - 1)*offh) / stride[0]) - pad_w = int((padding[1] + (stride[1] - 1)*offw) / stride[1]) - if poffh >= r or poffw >= s: - dx[:, :, offh::stride[0], offw::stride[1]] = 0 - else: - triton.ops.einsum(f'nk(h - r + pad_h)(w - s + pad_w),kcrs->nchw', - dypad[:, :, :, :], - weight[:, :, poffh::stride[0], poffw::stride[1]], - output = dx[:, :, offh::stride[0], offw::stride[1]], 
- mask = acc_bitmask, - values = {'pad_h': pad_h, - 'pad_w': pad_w}) - - # gradient for the weight - dw = None - if ctx.needs_input_grad[1]: - dw = torch.empty_like(weight) - triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs', - x, dy, output = dw, mask = acc_bitmask) - #print('dw: ', dw.view(-1)[0]) - return dx, dw, None, None, None, None, None, None -conv2d = _conv2d.apply - -class Conv2d(nn.Conv2d): - - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, - bias=True, padding_mode='zeros', - acc_bitmask = None): - super(Conv2d, self).__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, - groups, bias, padding_mode) - self.acc_bitmask = acc_bitmask - - def forward(self, input): - #if self.kernel_size[0] == 3 and self.stride[0] != 1: - #print(self.padding, self.stride, input.size(), self.weight.size()) - # return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - return conv2d(input, self.weight, self.bias, self.stride, - self.padding, self.dilation, self.groups, - self.acc_bitmask) - - -def replace_conv2d(model, acc_bitmask = None): - for child_name, child in model.named_children(): - if isinstance(child, nn.Conv2d): - conv2d = Conv2d(child.in_channels, child.out_channels, child.kernel_size, - child.stride, child.padding, child.dilation, child.groups, - child.bias, child.padding_mode, - acc_bitmask=acc_bitmask) - for yparam, xparam in zip(conv2d.parameters(), child.parameters()): - yparam.data.copy_(xparam.data) - setattr(model, child_name, conv2d) - else: - replace_conv2d(child, acc_bitmask) - -# initialize input -#N, C, H, W, K, RS = 16, 32, 24, 24, 64, 3 -#torch.Size([128, 64, 30, 30]) torch.Size([128, 64, 3, 3]) -#torch.Size([128, 128, 15, 15]) torch.Size([256, 128, 3, 3]) -#torch.Size([128, 256, 8, 8]) torch.Size([512, 256, 3, 3]) - -if __name__ == '__main__': - N, C, H, W, K, RS = 128, 64, 30, 30, 128, 3 - #N, C, H, W, K, RS = 128, 128, 15, 15, 256, 3 - #N, C, H, W, K, RS = 128, 256, 8, 8, 512, 3 - pad, stride = 0, 1 - torch.manual_seed(0) - x = torch.randn((N, C, H, W)).cuda() - x.requires_grad_(True) - #x.data[:] = 1 - # initialize layers - torch.manual_seed(0) - rconv2d = nn.Conv2d(C, K, RS, stride, pad, bias=False).cuda() - torch.manual_seed(0) - tconv2d = Conv2d(C, K, RS, stride, pad, bias=False).cuda() - #rconv2d.weight.data[:] = 1 - #tconv2d.weight.data[:] = 1 - ry = rconv2d(x) - ty = tconv2d(x) - # reference - dy = torch.randn(ry.size()).cuda() - #dy.data[:] = 1 - ry.backward(dy) - rdx = x.grad.clone() - rdw = rconv2d.weight.grad.clone() - x.grad.zero_() - # triton - ty.backward(dy) - tdx = x.grad.clone() - tdw = tconv2d.weight.grad.clone() - x.grad.zero_() - # print error - diff = lambda x, y: (x - y).abs().max() - print(diff(ry, ty)) - print(diff(rdx, tdx)) - print(diff(rdw, tdw)) - #print((rdx - tdx).abs()) - - #print((rdx[0,0,:,:] - tdx[0,0,:,:])) - #print(rdx[0,0,:,:]) - #print(tdx[0,0,:,:]) \ No newline at end of file diff --git a/python/triton/nn/linear.py b/python/triton/nn/linear.py deleted file mode 100644 index a62fee600..000000000 --- a/python/triton/nn/linear.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch -import triton - - -def linear(x, w, bias = None): - print(x.size(), w.size()) - m, k = x.size() - k, n = w.size() - out = torch.empty([m, n], device=x.device) - triton.ops.einsum('mk,nk->mn', x, w, bias) - if bias is not None: - out += bias - return out \ No newline at end of file diff --git 
a/python/triton/ops/__init__.py b/python/triton/ops/__init__.py deleted file mode 100644 index ea638f76c..000000000 --- a/python/triton/ops/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .einsum import _einsum, einsum -from .batchnorm import _batchnorm, batchnorm \ No newline at end of file diff --git a/python/triton/ops/batchnorm.py b/python/triton/ops/batchnorm.py deleted file mode 100644 index cc4c23ee2..000000000 --- a/python/triton/ops/batchnorm.py +++ /dev/null @@ -1,136 +0,0 @@ -import triton -import torch -import math - -class _batchnorm(torch.autograd.Function): - - fwd_src = """ -void fwdbatchnorm(float *Y, float *M, float *V, - float *X, float *G, float *B, - int N, float eps) { - // pointers - int c = get_program_id(1); - int rm[TM] = 0 ... TM; - float *px[TM] = X + rm + c*N; - float* py[TM] = Y + rm + c*N; - - // compute mean - float accm[TM] = 0; - for(int i = 0; i < N; i = i + TM) - accm = accm + *(px + i); - float mean = (float)accm[+] / N; - *(M + c) = mean; - - // compute variance - float accv[TM] = 0; - for(int i = 0; i < N; i = i + TM){ - float x[TM] = *(px + i); - x = x - mean; - accv = accv + x*x; - } - float var = (float)accv[+] / N; - *(V + c) = var; - - // Normalize batch - float gamma = *(G + c); - float beta = *(B + c); - float rstdg = 1 / sqrtf(var + eps) * gamma; - for(int i = 0; i < N; i = i + TM){ - float x[TM] = *(px + i); - float y[TM] = (x - mean)*rstdg + beta; - *(py + i) = y; - } -} -""" - - bwd_src = """ -void bwdbatchnorm(float *DX, float *DG, float *DB, - float *DY, float *X, float *G, - float *M, float *V, - int N, float epsilon) { - - // pointers - int c = get_program_id(1); - int rx[TM] = 0 ... TM; - int offset = c*N; - float* px[TM] = X + rx + offset; - float* pdy[TM] = DY + rx + offset; - float* pdx[TM] = DX + rx + offset; - - // fetch statistics - float gamma = *(G + c); - float mean = *(M + c); - float var = *(V + c); - float rstd = 1 / sqrtf(var + epsilon); - - // compute dgamma and dbeta - float acc_dg[TM] = 0; - float acc_db[TM] = 0; - for(int i = 0; i < N; i = i + TM){ - float x[TM] = *(px + i); - float dy[TM] = *(pdy + i); - acc_dg += dy*(x - mean)*rstd; - acc_db += dy; - } - float dg = acc_dg[+]; - float db = acc_db[+]; - *(DG + c) = dg; - *(DB + c) = db; - - // compute dx - for(int i = 0; i < N; i = i + TM){ - float x[TM] = *(px + i); - float dy[TM] = *(pdy + i); - float xhat[TM] = (x - mean) * rstd; - float xtmp[TM] = (xhat * dg + db) / N; - float dx[TM] = (dy - xtmp) * rstd * gamma; - *(pdx + i) = dx; - } -} -""" - - fwd_kernel = None - bwd_kernel = None - - @staticmethod - def forward(ctx, x, gamma, beta, eps): - # lazy compilation of kernel - if _batchnorm.fwd_kernel is None: - _batchnorm.fwd_kernel = triton.kernel(fwd_src, defines = {'TM': 128}) - # shapes - shape = triton.shape(x) - dtype = x.dtype - # allocate outputs - C, H, W, B = shape[0], shape[1], shape[2], shape[3] - y = triton.empty(shape, dtype=dtype) - mean = triton.empty([C], dtype=dtype) - var = triton.empty([C], dtype=dtype) - # execute kernels - _batchnorm.fwd_kernel(y, mean, var, x, gamma, beta, H*W*B, eps, - grid = lambda opt: [1, C]) - # save - ctx.save_for_backward(x, gamma, beta, mean, var) - ctx.eps = eps - return y - - @staticmethod - def backward(ctx, dy): - # lazy compilation of kernel - if _batchnorm.bwd_kernel is None: - _batchnorm.bwd_kernel = triton.kernel(bwd_src, defines = {'TN': 128}) - # retrieve info - x, gamma, beta, mean, var = ctx.saved_tensors - eps = ctx.eps - # allocate result - dx = triton.empty(triton.shape(x), dtype=x.dtype) - dgamma = 
triton.empty(triton.shape(gamma), dtype=gamma.dtype) - dbeta = triton.empty(triton.shape(beta), dtype=beta.dtype) - # execute - C, H, W, B = triton.shape(x) - _batchnorm.bwd_kernel(dx, dgamma, dbeta, dy, - x, gamma, mean, var, - H*W*B, eps, - grid = lambda opt: [1, C]) - return dx, dgamma, dbeta, None - -batchnorm = _batchnorm.apply \ No newline at end of file diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py deleted file mode 100644 index f1fb719cf..000000000 --- a/python/triton/ops/einsum.py +++ /dev/null @@ -1,794 +0,0 @@ -from math import ceil, log2 -from enum import IntEnum -from functools import reduce -from operator import mul -from collections import OrderedDict -from collections import namedtuple -import re -import string -import triton -import torch -# numpy -- ideally removed in a future release -import numpy as np -# sympy -- ideally removed in a future release -import sympy as sp -from sympy.parsing.sympy_parser import parse_expr -from sympy.printing.ccode import C89CodePrinter - - -class _einsum(torch.autograd.Function): - - - ############################# - ## Triton-C code generation - ############################# - def print_cc(expr, axes_0, axes_1, axes_2, prefix): - if expr in axes_0: - return f'{prefix}r{expr}[:, newaxis, newaxis]' - if expr in axes_1: - return f'{prefix}r{expr}[newaxis, :, newaxis]' - if expr in axes_2: - return f'{prefix}r{expr}[newaxis, newaxis, :]' - return expr - - def unpack_cc(tile, axes, prefix, remat): - ret = '' - axes = list(map(str, axes)) - for i, d in enumerate(reversed(axes)): - if i == len(axes) - 1: - break - currs = ''.join(axes[: len(axes) - i]) - nexts = ''.join(axes[: len(axes) - (i + 1)]) - ty = '' if remat else 'int ' - sz = '' if remat or tile is None else f'[{tile}]' - ret += f' {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n' - ret += f' {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n' - return ret - - def strides_cc(name, expr): - ret = [f'stride_{name}_{d}' for d in expr[:-1]] + ['1'] - ret = dict(zip(expr, ret)) - return ret - - def make_kernel(name, dtype, - expr_a, expr_b, expr_c, - sparse_a, sparse_b, sparse_c, - axes_m, axes_n, axes_k, axes_b, - multipleof_a, multipleof_b, multipleof_c, - stride_a_last, stride_b_last, stride_c_last, - lut_mode_a, lut_mode_b, - delta_a, delta_b, - blocks): - - use_lut_a = True - use_lut_b = True - - outer_sparse_a = [x for x in expr_a if x in sparse_a and x not in axes_k] - outer_dense_a = [x for x in expr_a if x not in sparse_a and x not in axes_k] - outer_sparse_b = [x for x in expr_b if x in sparse_b and x not in axes_k] - outer_dense_b = [x for x in expr_b if x not in sparse_b and x not in axes_k] - outer_dense_c = [x for x in expr_c if x not in sparse_c and x not in axes_k] - - - src = "" - - if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT: - src += f""" -char __constant__* AD = calloc({4*len(delta_a)});""" - if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT: - src += f""" -char __constant__* BD = calloc({4*len(delta_b)});""" - - - src += f""" -__global__ void {name}( - TYPE * A __noalias __readonly __aligned(16) - , TYPE * B __noalias __readonly __aligned(16) - , TYPE * C - , int * locks - , float alpha - , int matmul_m, int matmul_n, int matmul_k __multipleof(16) - , int div_m - """ - for dim in [axes_m, axes_n, axes_k, axes_b]: - for d in dim: - src += f", int dim_{d}" - src += "\n " - for dim, name, mult, sparse in zip([expr_a, expr_b, expr_c], - ['a', 'b', 'c'], - [multipleof_a, multipleof_b, multipleof_c], - [sparse_a, sparse_b, sparse_c]): - 
for d in range(len(dim) - 1): - if sparse and dim[d] == sparse[0]: - src += f', int stride_{name}_block __multipleof({mult})' - src += f", int stride_{name}_{d} __multipleof({mult})" - src += "\n " - if lut_mode_a == _einsum.LUT_MODE.SCALAR: - src += f", int stride_a_inner __multipleof({multipleof_a})" - src += f", int rem_delta_a __multipleof({multipleof_a})" - elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM: - src += ", int* AD __noalias __readonly __aligned(16)" - src += "\n " - if lut_mode_b == _einsum.LUT_MODE.SCALAR: - src += f", int stride_b_inner __multipleof({multipleof_b})" - src += f", int rem_delta_b __multipleof({multipleof_b})" - elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM: - src += ", int* BD" - src += "\n " - if sparse_c: - src += ", int* CD" - if sparse_a or sparse_b: - src += ", int width" - src += """) { - - - // program identifiers - int pid_0 = get_program_id(0); - int pid_1 = get_program_id(1); - -""" - if sparse_a: - src += f""" - int off_n = pid_0 / width; - int off_header = pid_0 % width; - int* header = AD + off_header * {2 + len(outer_sparse_a)}; - int* pdelta = AD + *(header + 0); - matmul_k = *(header + 1);""" - for i, d in enumerate(outer_sparse_a): - src += f""" - int off_{d} = *(header + {2 + i});""" - src += f""" - int inca = *(pdelta + 0); - int incb = *(pdelta + 1); - int off_{''.join(map(str, outer_dense_a))} = pid_1; -""" - _einsum.unpack_cc(None, outer_dense_a, "off_", False) - elif sparse_b: - src += f""" - int off_m = pid_0 / width; - int off_header = pid_0 % width; - int* header = BD + off_header * {2 + len(outer_sparse_b)}; - int* pdelta = BD + *(header + 0); - matmul_k = *(header + 1);""" - for i, d in enumerate(outer_sparse_b): - src += f""" - int off_{d} = *(header + {2 + i});""" - src += f""" - int incb = *(pdelta + 0); - int inca = *(pdelta + 1); - int off_{''.join(map(str, outer_dense_b))} = pid_1; -""" - _einsum.unpack_cc(None, outer_dense_b, "off_", False) - elif sparse_c: - src += f""" - // load LUT header - int *header = CD + pid_0 * {len(sparse_c)};""" - for i, d in enumerate(sparse_c): - src += f""" - int off_{d} = *(header + {i});""" - src += f""" - int off_{''.join(map(str, outer_dense_c))} = pid_1;""" - else: - src += """ - // re-order outer program ids - int grid_m = (matmul_m + TM - 1) / TM; - int grid_n = (matmul_n + TN - 1) / TN; - int off_mn = pid_0 / div_m; - int off_n = off_mn % grid_n; - int off_m = (off_mn / grid_n)*div_m + (pid_0 % div_m); - int off_b = get_program_id(1);""" - - src += """ -#if TZ == 1 - int off_k = 0; -#else - // get reduction sub-group program id - int pid_z = get_program_id(2); - int grid_z = get_num_programs(2); - int div_z = matmul_k / TZ; - int rem_z = matmul_k % TZ; - int off_k = pid_z * div_z; - matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z); -#endif - int rem_k = matmul_k % TK; - - // create ranges -""" - - sparse = sparse_a + sparse_b + sparse_c - for axes, tile, off, prefixes in zip([axes_m, axes_n, axes_b, axes_k], - ['TM', 'TN', 'TB', 'TK'], - ['off_m*TM', 'off_n*TN', 'off_b*TB', 'off_k'], - [['a', 'c'], ['b', 'c'], ['a', 'b', 'c'], ['a', 'b']]): - if not axes: - continue - currs = ''.join(map(str,axes)) - has_sparse_component = set(axes) & set(sparse) - if has_sparse_component: - src += f" int r{currs}[{tile}] = 0 ... {tile};\n" - src += _einsum.unpack_cc(tile, axes, f'r', False) - else: - src += f" int r{currs}[{tile}] = {off} + 0 ... 
{tile};\n" - src += _einsum.unpack_cc(tile, axes, f'r', False) - for pfx in prefixes: - for d in axes: - is_dense_dim = d not in sparse - is_dense_storage = (pfx == 'a' and not sparse_a) or\ - (pfx == 'b' and not sparse_b) or\ - (pfx == 'c' and not sparse_c) - if not is_dense_dim and is_dense_storage: - src += f" int {pfx}r{d}[{tile}] = off_{d} * BLOCK{d.upper()} + r{d};\n" - elif is_dense_dim and has_sparse_component: - src += f" int {pfx}r{d}[{tile}] = off_{d};\n" - else: - src += f" int {pfx}r{d}[{tile}] = r{d};\n" - - src += f""" - // initialize pointers to A - int offa[TM, TK, TB] = {'inca' if sparse_a or sparse_b else '0'} """ - for i, sym in enumerate(expr_a): - ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b, 'a') - stride = f'stride_a_{i}' if i < len(expr_a) - 1 else f'{stride_a_last}' - src += f" + ({ccode}) * {stride}\n " - src += ';' - - src += """ - TYPE *pa[TM, TK, TB] = A + offa;""" - - - if not sparse_a and not sparse_b and use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR: - spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else '' - cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else '' - src += f""" - int offadelta[TK] = off_k + 0 ... TK; - int {spec} *padelta[TK] = {cast}AD + offadelta; - int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];""" - - src += f""" - - // initialize pointers to B - int offb[TK, TN, TB] = {'incb' if sparse_a or sparse_b else '0'}""" - for i, sym in enumerate(expr_b): - ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b, 'b') - stride = f'stride_b_{i}' if i < len(expr_b) - 1 else f'{stride_b_last}' - src += f" + ({ccode}) * {stride}\n " - src += ';' - - src += """ - TYPE *pb[TK, TN, TB] = B + offb;""" - - - if not sparse_a and not sparse_b and use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR: - spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else '' - cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else '' - src += f""" - // initialize pointers to B look-up table - int offbdelta[TK] = off_k + 0 ... TK; - int *pbdelta[TK] = BD + offbdelta;""" - - - rk = 'r{}'.format(''.join(map(str,axes_k))) - src += f""" - - // prefetch - int prefetch_k = select(rem_k > 0, rem_k, TK); - bool checkam[TM] = ar""" + ''.join(map(str,axes_m)) + f""" < matmul_m; - bool checkbn[TN] = br""" + ''.join(map(str,axes_n)) + f""" < matmul_n; - bool checkk[TK] = r{''.join(map(str, axes_k))} < prefetch_k; - bool checka[TM, TK, TB] = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis]; - bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis]; - TYPE a[TM, TK, TB] = checka ? *pa : 0; - TYPE b[TK, TN, TB] = checkb ? 
*pb : 0;""" - - if sparse_a: - src += f""" - // update pointers to look-up tables - pdelta += 2; - int incda = *(pdelta + 0); - int incdb = *(pdelta + 1); - pa += incda; - pb += incdb;""" - if sparse_b: - src += f""" - // update pointers to look-up tables - pdelta += 2; - int incdb = *(pdelta + 0); - int incda = *(pdelta + 1); - pa += incda; - pb += incdb;""" - - if not sparse_a and not sparse_b and lut_mode_a == _einsum.LUT_MODE.SCALAR: - src += """ - pa += rem_delta_a;""" - elif not sparse_a and not sparse_b: - src += """ - pa += incda; - padelta += TK; - incda = (*padelta)[newaxis, :, newaxis];""" - - if not sparse_a and not sparse_b and lut_mode_b == _einsum.LUT_MODE.SCALAR: - src += """ - pb += rem_delta_b;""" - elif not sparse_a and not sparse_b: - src += """ - pb += (*pbdelta)[:, newaxis, newaxis]; - pbdelta += TK;""" - - src += f""" - - // accumulate - float acc[TM, TN, TB] = 0; - for(int k = matmul_k; k > 0; k -= TK) {{ - acc += a @ b; - - // load inputs - checkk = k > TK; - checka = checkam[:, newaxis, newaxis] && checkk[newaxis, :, newaxis]; - checkb = checkk[:, newaxis, newaxis] && checkbn[newaxis, :, newaxis]; - a = *?(checka)pa; - b = *?(checkb)pb; - - // update pointers""" - if sparse_a: - src += """ - pdelta += 2; - incda = *(pdelta + 0); - incdb = *(pdelta + 1); - pa += incda; - pb += incdb; - """ - elif sparse_b: - src += """ - pdelta += 2; - incdb = *(pdelta + 0); - incda = *(pdelta + 1); - pa += incda; - pb += incdb; - """ - else: - if lut_mode_a == _einsum.LUT_MODE.SCALAR: - src += """ - pa += stride_a_inner;""" - else: - src += """ - pa += incda; - padelta += TK; - incda = (*padelta)[newaxis, :, newaxis];""" - if lut_mode_b == _einsum.LUT_MODE.SCALAR: - src += """ - pb += stride_b_inner;""" - else: - src += """ - pb += (*pbdelta)[:, newaxis, newaxis]; - pbdelta += TK;""" - - src += f""" - }} - TYPE c[TM, TN, TB] = acc; - - // initialize pointers to C - int offc[TM, TN, TB] = {'pid_0*TN*TN' if sparse_c else 0}""" - for i, sym in enumerate(expr_c): - stride = f'stride_c_{i}' if i < len(expr_c) - 1 else f'{stride_c_last}' - ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b, 'c') - src += f"\n + ({ccode}) * {stride}" - src += ';' - - src += """ - TYPE *pc[TM, TN, TB] = C + offc; - - // bounds-checking - bool checkcm[TM] = cr""" + ''.join(map(str,axes_m)) + """ < matmul_m; - bool checkcn[TN] = cr""" + ''.join(map(str,axes_n)) + """ < matmul_n; - bool checkc[TM, TN, TB] = checkcm[:, newaxis, newaxis] && - checkcn[newaxis, :, newaxis]; - - // write back -#if TZ == 1 - *?(checkc)pc = c; -#else - int *plock = locks + pid_mn + pid_b * get_num_programs(0); - int *pcount = plock + 1024*1024; - // spin - for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); - int count = *pcount; - if(count == 0) - *?(checkc)pc = c; - else - *?(checkc)pc = c + *?(checkc)pc; - atomic_xchg(pcount, (count + 1) % (grid_z)); - atomic_xchg(plock, 0); -#endif -} -""" - # compilation options - TM, TN, TB, TZ = [32], [32], 1, [1] - TK = 16 if dtype==torch.float16 else 8 - defines = {'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype} - for d, B in blocks.items(): - defines[f'BLOCK{d}'] = B - # create kernel - ret = triton.kernel(src, defines=defines) - # set constant - if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT: - ret.set_constant('AD', delta_a) - if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT: - ret.set_constant('BD', delta_b) - return ret - - ############################ - ## Look-up Table - ############################ - - class LUT_MODE(IntEnum): - 
SCALAR = 1 - CONSTANT = 2 - DRAM = 3 - - def lut_mode(delta): - if delta.size == 0 or np.min(delta) == np.max(delta): - return _einsum.LUT_MODE.SCALAR - #if delta.size < 4096: - # return _einsum.LUT_MODE.CONSTANT - return _einsum.LUT_MODE.DRAM - - def symbolic_delta(symbols, axes): - rank = len(symbols) - strides = [sp.symbols(f'stride{d}') for d in range(rank)] - nexts = {s: sp.symbols(f'next{s}') for s in axes} - delta = 0 - for i in range(rank): - delta += strides[i] * (symbols[i].subs(nexts) - symbols[i]) - return delta - - def unpack_offset(k, axes, dims): - ret = dict() - for d in reversed(axes): - ret[d] = k % dims[d] - k = k // dims[d] - return ret - - - def make_dsd_delta(axes, step, stride, dims, symbols, sparse, layout, blocks): - # depth of reductions - depth = layout.sum(*[i for i, d in enumerate(sparse) if d in axes]) - # outer dimension indices - outer = torch.nonzero(depth, as_tuple=False) - outer = [outer[:,i] for i in range(outer.shape[1])] - # find offset of outer dimensions - depth = depth.view(-1) - offsets = torch.zeros_like(depth) - offsets[1:] = torch.cumsum(depth[:-1], 0) - # compute delta for b - # TODO: support multiple sparse red indices - col = next((i for i, d in enumerate(sparse) if d in axes), None) - block = blocks[sparse[-1].upper()] - div = block // step - delta_b = torch.nonzero(layout.transpose(-1, col), as_tuple=False)[:, -1].reshape(-1).contiguous() - delta_b *= block - delta_b = [delta_b + step*i for i in range(div)] - delta_b = torch.stack(delta_b, dim=1) - delta_b = delta_b.view(-1) - # compute delta for a - bstride = 1 - for d in sparse[::-1]: - if d in axes: - break - bstride *= blocks[d.upper()] - order = [d for d in sparse if d not in axes] +\ - [d for d in sparse if d in axes] - idx = [sparse.index(d) for d in order] - layout[layout > 0] = 1 + torch.arange(layout.sum(), device=layout.device) - layout = layout.permute(*idx) - delta_a = layout[layout > 0] - 1 - delta_a *= np.prod(list(blocks.values())) - saved = delta_a[offsets] - delta_a[1:] = delta_a[1:] - delta_a[:-1] - delta_a = delta_a.view(-1, 1).repeat(1, div) - delta_a[:, 1:] = step*bstride - delta_a[:, 0] -= (div - 1)*step*bstride - delta_a[offsets, 0] = saved - delta_a = delta_a.view(-1) - delta = torch.stack((delta_a, delta_b), dim=1).view(-1).contiguous() - # form look-up table - depth *= blocks[symbols[-1].upper()] - offsets *= div - header = torch.stack((offsets, depth, *outer), dim=1).view(-1).contiguous() - nouter = 2 + len(outer) - header[::nouter] = header[::nouter]*2 + header.shape[0] - lut = torch.cat((header, delta)).int().int().cpu().numpy() - return lut, nouter, _einsum.LUT_MODE.DRAM - - def make_delta(axes, step, stride, dims, symbols, sparse, layout, lut = None, nouter = None): - # symbolic pointer increments - symbols = [sp.symbols(x) for x in symbols] - delta = _einsum.symbolic_delta(symbols, axes) - args = [f'stride{d}' for d in range(len(stride))] - args += [f'{sk}' for sk in axes] - args += [f'next{sk}' for sk in axes] - fn = sp.lambdify(args, delta, 'numpy') - if lut is None: - # inner axes values - inner = [dims[d] for d in axes] - inner = np.prod(inner) - rem = inner % step - rem = rem if rem > 0 else step - # k = [0, 1, ..., step, rem, rem + 1, ... 
rem + inner] - # nextk = [rem, 1 + rem, ..., step + rem, rem + step, rem + 1 + step, ..., inner + step] - k = np.concatenate((np.arange(step), np.arange(rem, inner))).astype(np.int32) - nextk = np.concatenate((k[:step] + rem, k[step:] + step)) - else: - idx = (lut[:lut[0]:nouter] - lut[0])//2 - k = lut[lut[0]+1::2] - k = np.insert(k, idx, 0) - nextk = k[1:] - k = k[:-1] - # offsets - off = _einsum.unpack_offset(k, axes, dims) - nextoff = _einsum.unpack_offset(nextk, axes, dims) - # evaluate deltas - args = [s for s in stride] - args += [off[sk] for sk in axes] - args += [nextoff[sk] for sk in axes] - delta = fn(*args) - delta = np.maximum(delta, 0) - if lut is not None: - idx = idx[1:] + np.arange(idx.shape[0] - 1) - delta = np.delete(delta, idx) - lut[lut[0]+1::2] = delta - return None, None - return delta, _einsum.lut_mode(delta[step:-step]) - - @staticmethod - def make_sdd_lut(layout_c, sparse_c, blocks): - nnz = torch.nonzero(layout_c, as_tuple=False) - lut = nnz.reshape(-1).int().cuda() - return lut - - ############################ - ## Compilation - ############################ - - class instance: - - locks = None - kernel_cache = dict() - - def __init__(self, einsum, dtype, stride_a, stride_b, shape_a, shape_b, layouts, blocks): - # parse symbols - expr_a, expr_bc = einsum.split(",") - expr_b, expr_c = expr_bc.split("->") - sym_a = expr_a.lower() - sym_b = expr_b.lower() - sym_c = expr_c.lower() - sparse_a = [x.lower() for x in expr_a if x.isupper()] - sparse_b = [x.lower() for x in expr_b if x.isupper()] - sparse_c = [x.lower() for x in expr_c if x.isupper()] - layout_a = layouts.get(expr_a) - layout_b = layouts.get(expr_b) - layout_c = layouts.get(expr_c) - # parse axes - axes_b = [d for d in sym_a if d in sym_b and d in sym_c] - axes_m = [d for d in sym_a if d not in sym_b and d in sym_c] - axes_k = [d for d in sym_a if d in sym_b and d not in sym_c] - axes_n = [d for d in sym_b if d not in sym_a and d in sym_c] - axes = axes_b + axes_m + axes_n + axes_k - # check block sizes - for d in sparse_a + sparse_b + sparse_c: - if d.upper() not in blocks: - raise ValueError(f'unspecified block size for dimension: {d.upper()}') - # check layout is present - if sparse_a and layout_a is None: - raise ValueError('A is sparse but not layout provided') - if sparse_b and layout_b is None: - raise ValueError('B is sparse but not layout provided') - if sparse_c and layout_c is None: - raise ValueError('C is sparse but not layout provided') - # check dimensions - dims_a = dict([(x, y) for x,y in zip(sym_a, shape_a) if x not in sparse_a]) - dims_b = dict([(x, y) for x,y in zip(sym_b, shape_b) if x not in sparse_b]) - dims_La = None if layout_a is None else dict(zip([x for x in expr_a if x.isupper()], layout_a.shape)) - dims_Lb = None if layout_b is None else dict(zip([x for x in expr_b if x.isupper()], layout_b.shape)) - # TODO: could be cleaner - read_shape = lambda d, dimsT, dimsL, sparse: dimsL[d.upper()] * blocks[d.upper()] if d in sparse else dimsT[d] - for d in axes_b + axes_m + axes_n + axes_k: - dim_a = read_shape(d, dims_a, dims_La, sparse_a) if d in sym_a else None - dim_b = read_shape(d, dims_b, dims_Lb, sparse_b) if d in sym_b else None - if d in axes_b and dim_a and dim_b and dim_a != dim_b: - raise ValueError(f'incomparible batch dimension {d} (A: {dim_a}, B: {dim_b})') - if d in axes_k and dim_a and dim_b and dim_a != dim_b: - raise ValueError(f'incompatible inner dimension {d} (A: {dim_a}, B: {dim_b})') - dims = dict() - dims.update(dims_a) - dims.update(dims_b) - for i, d in 
enumerate(sparse_a): - dims[d] = layout_a.shape[i] * blocks[d.upper()] - for i, d in enumerate(sparse_b): - dims[d] = layout_b.shape[i] * blocks[d.upper()] - # allocate output - shape_c = [dims[d] if d.islower() else blocks[d] for d in expr_c] - if sparse_c: - shape_c.insert(expr_c.index(sparse_c[0].upper()), int(layout_c.sum())) - stride_c = [None] * len(shape_c) - stride_c[-1] = 1 - for i in reversed(range(len(shape_c) - 1)): - stride_c[i] = stride_c[i+1] * shape_c[i+1] - # look-up tables - TK = 16 if dtype == torch.float16 else 8 - if sparse_a and not sparse_b: - delta_a, nouter, lut_mode_a = _einsum.make_dsd_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, blocks) - delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, delta_a, nouter) - if sparse_b and not sparse_a: - delta_b, nouter, lut_mode_b = _einsum.make_dsd_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b, blocks) - delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a, delta_b, nouter) - if not sparse_a and not sparse_b: - delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, sparse_a, layout_a) - delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, sparse_b, layout_b) - if sparse_c: - delta_c = _einsum.make_sdd_lut(layout_c, sparse_c, blocks) - # hash for recompilation - stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0]) - stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0]) - stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0]) - stride_a_last = stride_a[-1] - stride_b_last = stride_b[-1] - stride_c_last = stride_c[-1] - name = f'{dtype}_{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\ - f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}'\ - f'_{stride_a_last}_{stride_b_last}_{stride_c_last}' - # recompile if necessary - cache = _einsum.instance.kernel_cache - if name not in cache: - cachesize = len(cache) - cache[name] = _einsum.make_kernel(f'__einsum{cachesize}', - dtype, - sym_a, sym_b, sym_c, - sparse_a, sparse_b, sparse_c, - axes_m, axes_n, axes_k, axes_b, - stride_a_multiple, stride_b_multiple, stride_c_multiple, - stride_a_last, stride_b_last, stride_c_last, - lut_mode_a, lut_mode_b, - delta_a, delta_b, - blocks) - self.kernel = cache[name] - # Initialize locks - if _einsum.instance.locks is None: - _einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda() - # Kernel arguments - dim_m = [dims[d] for d in axes_m] - dim_n = [dims[d] for d in axes_n] - dim_k = [dims[d] for d in axes_k] - dim_b = [dims[d] for d in axes_b] - M = reduce(mul, dim_m, 1) - N = reduce(mul, dim_n, 1) - K = reduce(mul, dim_k, 1) - B = reduce(mul, [dims[d] for d in axes_b if d.upper() not in einsum], 1) - stride_a = list(stride_a[:-1]) - stride_b = list(stride_b[:-1]) - stride_c = list(stride_c[:-1]) - alpha = 1. 
- div_m = 1 - self.args = [None, None, None] - self.args += [_einsum.instance.locks] - self.args += [alpha, M, N, K, div_m] - self.args += dim_m - self.args += dim_n - self.args += dim_k - self.args += dim_b - self.args += stride_a - self.args += stride_b - self.args += stride_c - # LUT for A - if lut_mode_a == _einsum.LUT_MODE.SCALAR: - self.args += [delta_a[TK], delta_a[0]] - elif sparse_a or lut_mode_a == _einsum.LUT_MODE.DRAM: - self.args += [torch.from_numpy(delta_a).cuda()] - # LUT for B - if lut_mode_b == _einsum.LUT_MODE.SCALAR: - self.args += [delta_b[TK], delta_b[0]] - elif sparse_b or lut_mode_b == _einsum.LUT_MODE.DRAM: - self.args += [torch.from_numpy(delta_b).cuda()] - # LUT for C - if sparse_c: - self.args += [delta_c] - if sparse_a or sparse_b: - width = delta_a[0] // nouter if sparse_a else delta_b[0] // nouter - self.args += [width] - # Grid - if sparse_a: - self.grid = lambda opt: [width*triton.cdiv(N, opt.d('TN')), B, opt.d('TZ')] - elif sparse_b: - self.grid = lambda opt: [width*triton.cdiv(M, opt.d('TM')), B, opt.d('TZ')] - elif sparse_c: - width = int(layout_c.sum()) - self.grid = lambda opt: [width, B, opt.d('TZ')] - else: - self.grid = lambda opt: [triton.cdiv(M, opt.d('TM')) * - triton.cdiv(N, opt.d('TN')), - triton.cdiv(B, opt.d('TB')), - opt.d('TZ')] - # position of dynamic arguments - self.pos_a = 0 - self.pos_b = 1 - self.pos_c = 2 - # save information on the operation - self.expr_a = expr_a - self.expr_b = expr_b - self.expr_c = expr_c - self.matmul_B = B - self.matmul_M = M - self.matmul_N = N - self.matmul_K = K - # output shape - self.shape_c = shape_c - - - def run(self, a, b): - c = torch.empty(*self.shape_c, dtype=a.dtype, device=a.device) - self.args[self.pos_a] = a - self.args[self.pos_b] = b - self.args[self.pos_c] = c - self.kernel(*self.args, grid=self.grid) - return c - - - - - ############################ - ## Forward - ############################ - - instance_cache = dict() - registry = dict() - @staticmethod - def forward(ctx, expr, a, b, layouts, blocks): - # compile einsum instance - cache = _einsum.instance_cache - key = (expr, a.dtype, - a.stride(), b.stride(), - a.shape , b.shape) - if key not in cache: - cache[key] = _einsum.instance(expr, a.dtype, - a.stride(), b.stride(), - a.shape , b.shape , - layouts, blocks) - instance = cache[key] - # run and mark as dirty c modified in-place - c = instance.run(a, b) - # save information in context - ctx.expr_a = instance.expr_a - ctx.expr_b = instance.expr_b - ctx.expr_c = instance.expr_c - ctx.matmul_B = instance.matmul_B - ctx.matmul_M = instance.matmul_M - ctx.matmul_N = instance.matmul_N - ctx.matmul_K = instance.matmul_K - ctx.save_for_backward(a, b) - return c - - - ############################ - ## Backward - ############################ - - @staticmethod - def backward(ctx, dy): - a, b = ctx.saved_tensors - expr_a = ctx.expr_a - expr_b = ctx.expr_b - expr_c = ctx.expr_c - # gradient of first argument - da = None - if ctx.needs_input_grad[1]: - da = torch.empty_like(a) - einsum(f'{expr_c},{expr_b}->{expr_a}', dy, b, da) - # gradient of second argument - db = None - if ctx.needs_input_grad[2]: - db = torch.empty_like(b) - einsum(f'{expr_a},{expr_c}->{expr_b}', a, dy, db) - return None, da, db, None, None, None, None, None, None, None - - -def einsum(expr, a, b, layouts = None, blocks = dict()): - return _einsum.apply(expr, a, b, layouts, blocks) \ No newline at end of file diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 9996310ab..483fde4df 100644 --- a/tests/bench/dot.cc +++ 
b/tests/bench/dot.cc
@@ -21,7 +21,7 @@ int main() {
 //  config_t{ord, x[0], x[1], 1280, 1280, 1280},
 //  config_t{ord, x[0], x[1], 1536, 1536, 1536},
 //  config_t{ord, x[0], x[1], 2048, 2048, 2048},
-    config_t{ord, x[0], x[1], 4096, 4096, 4096},
+    config_t{ord, x[0], x[1], 8192, 8192, 8192},
 //  config_t{ord, x[0], x[1], 256, 16, 256},
 //  config_t{ord, x[0], x[1], 512, 16, 512},
diff --git a/tests/common/conv.h b/tests/common/conv.h
index 4e194e154..3101f396f 100644
--- a/tests/common/conv.h
+++ b/tests/common/conv.h
@@ -102,7 +102,7 @@ void triton_conv(drv::context* context, drv::stream* stream,
   stream->write(&*ddelta, true, 0, hdelta);
 
   // macros
-  rt::function::options_space_t opt;
+  rt::options_space_t opt;
   opt.defines.push_back({"TYPE", {ty}});
   opt.defines.push_back({"TM", {"128"}});
   opt.defines.push_back({"TN", {"128"}});
@@ -125,7 +125,7 @@ void triton_conv(drv::context* context, drv::stream* stream,
            W*H*CI, W*H, W, 1,
            CO*S*R , CO*S, CO, 1,
            Q*P*CO, Q*P, Q, 1};
-  auto grid = [Z,P,Q,CO](const rt::function::options_t& x) {
+  auto grid = [Z,P,Q,CO](const rt::options_t& x) {
     return rt::grid_t{ceil(Z*P*Q, x.D("TM")),
                       ceil(CO  , x.D("TN")),
                       (size_t)x.D("TZ")};
diff --git a/tests/common/copy.h b/tests/common/copy.h
index 09c30a952..60dcd8233 100644
--- a/tests/common/copy.h
+++ b/tests/common/copy.h
@@ -107,7 +107,7 @@ void triton_copy_nd(drv::context* context, drv::stream* stream, const std::vecto
   auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
   auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
   // create options
-  rt::function::options_space_t opt;
+  rt::options_space_t opt;
 
   // macros
diff --git a/tests/common/dot.h b/tests/common/dot.h
index 28424764e..2fe46d9cc 100644
--- a/tests/common/dot.h
+++ b/tests/common/dot.h
@@ -101,46 +101,38 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
   // ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
 
   // macros
-  rt::function::options_space_t opt;
+  rt::options_space_t opts;
   // A access patterns
-  opt.defines.push_back({"USEA", {AT? "a" : "a" }});
-  opt.defines.push_back({"BROADCAST_AK", {AT? "newaxis, :" : "newaxis, :" }});
-  opt.defines.push_back({"BROADCAST_AM", {AT? ":, newaxis" : ":, newaxis" }});
-  opt.defines.push_back({"SHAPE_A", {AT? "TM, TK" : "TM, TK" }});
-  opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
-  opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
+  opts.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
+  opts.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
   // B access patterns
-  opt.defines.push_back({"USEB", {BT? "b" : "b" }});
-  opt.defines.push_back({"BROADCAST_BK", {BT? ":, newaxis" : ":, newaxis" }});
-  opt.defines.push_back({"BROADCAST_BN", {BT? "newaxis, :" : "newaxis, :" }});
-  opt.defines.push_back({"SHAPE_B", {BT? "TK, TN" : "TK, TN" }});
-  opt.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
-  opt.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
+  opts.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }});
+  opts.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }});
   // data-type
-  opt.defines.push_back({"TYPE", {ty}});
+  opts.defines.push_back({"TYPE", {ty}});
   // tile sizes
   if(mode == TEST) {
-    opt.defines.push_back({"TM", {std::to_string(TM)}});
-    opt.defines.push_back({"TN", {std::to_string(TN)}});
-    opt.defines.push_back({"TK", {std::to_string(TK)}});
-    opt.defines.push_back({"TZ", {"1"}});
-    opt.num_warps = {nwarp};
+    opts.defines.push_back({"TM", {std::to_string(TM)}});
+    opts.defines.push_back({"TN", {std::to_string(TN)}});
+    opts.defines.push_back({"TK", {std::to_string(TK)}});
+    opts.defines.push_back({"TZ", {"1"}});
+    opts.num_warps = {nwarp};
   }
   if(mode == BENCH) {
-    opt.defines.push_back({"TM", {"64", "128"}});
-    opt.defines.push_back({"TN", {"64", "128"}});
-    opt.defines.push_back({"TK", {"16"}});
-    opt.defines.push_back({"TZ", {"1"}});
-    opt.num_warps = {4};
+    opts.defines.push_back({"TM", {"128"}});
+    opts.defines.push_back({"TN", {"128"}});
+    opts.defines.push_back({"TK", {"32"}});
+    opts.defines.push_back({"TZ", {"1"}});
+    opts.num_warps = {4};
   }
 
   // kernels
-  rt::function function(src::dot, opt);
+  rt::function function(src::dot, opts);
   dot_arg_t args = {da->addr_as_uintptr_t(), db->addr_as_uintptr_t(), dc->addr_as_uintptr_t(),
                     1, M, N, K, lda, ldb, ldc, dlocks->addr_as_uintptr_t()};
-  auto grid = [M, N](const rt::function::options_t& x) {
-    return rt::grid_t{ceil(M, x.D("TM")),
+  auto grid = [M, N](const rt::options_t& x) {
+    return rt::grid_t{ceil(M, x.D("TM"))*
                       ceil(N, x.D("TN")),
                       (size_t)x.D("TZ")};
   };
@@ -151,19 +143,25 @@ void triton_dot(drv::context* context, drv::stream* stream, bool AT, bool BT,
     double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
     bench.push_back(tflops(triton_ns));
 
-//    // cublas
-//    if(cublas::cublasinit()){
-//      T alpha(static_cast<T>(1));
-//      T beta(static_cast<T>(0));
-//      cublasGemmAlgo_t fastest;
+    // cublas
+    if(cublas::cublasinit()){
+      T alpha(static_cast<T>(1));
+      T beta(static_cast<T>(0));
+      cublasGemmAlgo_t fastest;
 //      cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
-//      double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
-//                                                                 &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
-//                                                                 ldc, nullptr, fastest); }, stream);
-//      bench.push_back(tflops(cublas_ms));
-//    }
+      double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, !AT, !BT, M, N, K,
+                                                                 &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
+                                                                 ldc); }, stream);
+      bench.push_back(tflops(cublas_ms));
+    }
   }
+//  rt::options_t opt;
+//  for(auto &x: opts.defines)
+//    opt.defines[x.first] = x.second[0];
+//  opt.num_warps = 1;
+//  std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl;
+
   // test triton
   if(mode == TEST){
     srand(0);
diff --git a/tests/common/reduce.h b/tests/common/reduce.h
index 504676ec8..34369e6e7 100644
--- a/tests/common/reduce.h
+++ b/tests/common/reduce.h
@@ -95,25 +95,25 @@ void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vec
     y_strides.push_back(y_strides[d] + " * " + y_shapename[y_order[d]]);
 
   // options
-  rt::function::options_space_t opt;
-  opt.defines.push_back({"TYPE", {ty}});
+  rt::options_space_t opts;
+  opts.defines.push_back({"TYPE", {ty}});
   for(int d = 0; d < rank_x; d++)
-    opt.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
+    opts.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
   for(int d = 0; d < rank_y; d++)
-    opt.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
+    opts.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
   if(TS.empty())
     TS = tile_nd(rank_x);
   for(int d = 0; d < rank_x; d++)
-    opt.defines.push_back({"TS" + std::to_string(d), TS[d]});
+    opts.defines.push_back({"TS" + std::to_string(d), TS[d]});
   std::vector<int> axy;
   for(int d = 0; d < rank_x; d++)
     if(d != axis) axy.push_back(d);
   for(int d = 0; d < rank_y; d++)
-    opt.defines.push_back({"TY" + std::to_string(d), {std::to_string(shape_x[axy[d]])}});
+    opts.defines.push_back({"TY" + std::to_string(d), {std::to_string(shape_x[axy[d]])}});
   for(int d = 0; d < rank_y; d++)
-    opt.defines.push_back({"RY" + std::to_string(d), {"rs" + std::to_string(axy[d])}});
+    opts.defines.push_back({"RY" + std::to_string(d), {"rs" + std::to_string(axy[d])}});
   std::string RED = "";
   for(int n = 0; n < rank_x; n++){
@@ -121,30 +121,40 @@ void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vec
       RED += ", ";
     RED += (n==axis) ? to_str(op) : ":";
   }
-  opt.defines.push_back({"RED", {RED}});
-  opt.num_warps = {1};
+  opts.defines.push_back({"RED", {RED}});
+  opts.num_warps = {2};
   // kernel
-  rt::function function(src::reduce_nd[rank_x - 1], opt);
+  rt::function function(src::reduce_nd[rank_x - 1], opts);
   // input buffers
   auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_x*dtsize));
   auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size_y*dtsize));
   // grid
-  reduce_arg_t args = {*dx->cu(), *dy->cu(), shape_x[0]};
-  if(shape_x.size() > 1) args.S1 = shape_x[1];
-  if(shape_x.size() > 2) args.S2 = shape_x[2];
+  std::stringstream oss;
+  rt::add_arg(oss, *dx->cu());
+  rt::add_arg(oss, *dy->cu());
+  rt::add_arg(oss, (uint32_t)shape_x[0]);
+  if(shape_x.size() > 1) rt::add_arg(oss, (uint32_t)shape_x[1]);
+  if(shape_x.size() > 2) rt::add_arg(oss, (uint32_t)shape_x[2]);
   std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
   auto grid = grid_nd(shape_x, ts);
 
   // metrics
   if(mode == BENCH){
     auto gbps = [&](double ns) { return 2 * size_x * dtsize / (ns * 1e-9) * 1e-9; };
-    double triton_ns = triton::tools::bench([&]() { function((void**)&args, sizeof(args), grid, stream, device);}, stream);
+    double triton_ns = triton::tools::bench([&]() { function((void**)oss.str().data(), oss.str().size(), grid, stream, device);}, stream);
     bench.push_back(gbps(triton_ns));
   }
+//  rt::options_t opt;
+//  for(auto &x: opts.defines)
+//    opt.defines[x.first] = x.second[0];
+//  opt.num_warps = 1;
+//  std::cout << function.get_asm(rt::ASM_NV_PTX, device, opt) << std::endl;
+
+
   // test triton
   if(mode == TEST){
     std::vector<NumericT> hy(size_y);
@@ -153,7 +163,7 @@ void triton_reduce_nd(drv::context* context, drv::stream* stream, const std::vec
     init_zeros(hy);
     init_rand(hx);
     stream->write(&*dx, true, 0, hx);
-    function((void**)&args, sizeof(args), grid, stream, device);
+    function((void**)oss.str().data(), oss.str().size(), grid, stream, device);
     stream->synchronize();
     stream->read(&*dy, true, 0, hy);
     cc_reduce_nd(ry, hx, op, axis, shape_x);
diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h
index 54e58e13e..5973fe762 100644
--- a/tests/common/src/dot.h
+++ b/tests/common/src/dot.h
@@ -2,6 +2,9 @@ namespace src {
 
     const char *dot =
 R"(
+#define STM 8
+#define STN 8
+
 __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                     TYPE * B __noalias __readonly __aligned(16),
                     TYPE * C __noalias __aligned(16),
@@ -14,54 +17,59 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                     int ldc __multipleof(8),
                     int* locks) {
   // 
prologue - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int ridz = get_program_id(2); - int gridx = M / TM; - int gridy = N / TN; - int rid = ridx + ridy * gridx; - ridx = rid / gridy; - ridy = rid % gridy; - int rm[TM] = ridx * TM + 0 ... TM; - int rn[TN] = ridy * TN + 0 ... TN; + int pid = get_program_id(0); + int pidz = get_program_id(2); + int gridm = (M + TM - 1) / TM; + int gridn = (N + TN - 1) / TN; + int width = STM*gridn; + int stm = pid / width; + int RSTM = min(gridm - stm*STM, STM); + int stn = (pid % width) / (RSTM*STN); + int RSTN = min(gridn - stn*STN, STN); + int laneid = pid % (RSTM * RSTN); + int lanem = laneid / RSTN; + int lanen = laneid % RSTN; + int pidm = stm*STM + lanem; + int pidn = stn*STN + lanen; + int rm[TM] = pidm * TM + 0 ... TM; + int rn[TN] = pidn * TN + 0 ... TN; // reduction splitting K = K / TZ; - int rk[TK] = ridz * K + 0 ... TK; - + int rk[TK] = pidz * K + 0 ... TK; // pointers to operands - int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM; - int offb[SHAPE_B] = rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN; - TYPE* pa[SHAPE_A] = A + offa; - TYPE* pb[SHAPE_B] = B + offb; - + int offa[TM, TK] = rk[newaxis, :] * STRIDE_AK + rm[:, newaxis] * STRIDE_AM; + int offb[TK, TN] = rk[:, newaxis] * STRIDE_BK + rn[newaxis, :] * STRIDE_BN; + TYPE* pa[TM, TK] = A + offa; + TYPE* pb[TK, TN] = B + offb; // prefetches operands - bool checka[SHAPE_A] = rk[BROADCAST_AK] < K; - bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K; - TYPE a[SHAPE_A] = checka ? *pa : 0; - TYPE b[SHAPE_B] = checkb ? *pb : 0; - + bool checka[TM, TK] = rk[newaxis, :] < K; + bool checkb[TK, TN] = rk[:, newaxis] < K; + TYPE a[TM, TK] = checka ? *pa : 0; + TYPE b[TK, TN] = checkb ? *pb : 0; + pa += TK * STRIDE_AK; + pb += TK * STRIDE_BK; // reduction loop float acc[TM, TN] = 0; for(int k = K; k > 0; k -= TK){ - acc += USEA @ USEB; - bool checka[SHAPE_A] = k > TK; - bool checkb[SHAPE_B] = k > TK; - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; + bool checka[TM, TK] = k > TK; + bool checkb[TK, TN] = k > TK; + acc += a @ b; a = *?(checka)pa; b = *?(checkb)pb; + pa += TK * STRIDE_AK; + pb += TK * STRIDE_BK; } acc = acc * alpha; TYPE c[TM, TN] = acc; // epilogue - int rxm[TM] = get_program_id(0) * TM + 0 ... TM; - int rxn[TN] = get_program_id(1) * TN + 0 ... TN; - int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :]; + int rcm[TM] = pidm * TM + 0 ... TM; + int rcn[TN] = pidn * TN + 0 ... 
TN;
+  int offc[TM, TN] = rcm[:, newaxis] * ldc + rcn[newaxis, :];
   TYPE* pc[TM, TN] = C + offc;
-  bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
-
+  bool checkc[TM, TN] = rcm[:, newaxis] < M &&
+                        rcn[newaxis, :] < N;
 #if (TZ==1)
   *?(checkc) pc = c;
 #else
diff --git a/tests/common/util.h b/tests/common/util.h
index 56f98acde..80774a3b5 100644
--- a/tests/common/util.h
+++ b/tests/common/util.h
@@ -19,13 +19,13 @@ inline size_t ceil(size_t x, size_t y) {
 }
 
 inline rt::function::grid_fn_ty grid1d(size_t N) {
-  return [N](const rt::function::options_t& x) {
+  return [N](const rt::options_t& x) {
     return rt::grid_t{ceil(N, x.D("TN"))};
   };
 }
 
 inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
-  return [M, N](const rt::function::options_t& x) {
+  return [M, N](const rt::options_t& x) {
     return rt::grid_t{ceil(M, x.D("TM")),
                       ceil(N, x.D("TN"))};
   };
@@ -33,7 +33,7 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
 
 inline rt::function::grid_fn_ty grid_nd(const std::vector<int32_t> &shape,
                                         const std::vector<std::string>& ts) {
-  return [&shape, &ts](const rt::function::options_t& x) {
+  return [&shape, &ts](const rt::options_t& x) {
     rt::grid_t ret;
     for(size_t d = 0; d < shape.size(); d++)
       ret.push_back(ceil(shape[d], x.D(ts[d])));
@@ -71,6 +71,12 @@ void init_zeros(std::vector<T>& x) {
     x[i] = 0;
 }
 
+template<class T>
+void init_ones(std::vector<T>& x) {
+  for(size_t i = 0; i < x.size(); i++)
+    x[i] = 1;
+}
+
 /* ------------------------
  *  Loop Nests
  * ------------------------ */
@@ -163,7 +169,7 @@ for(size_t i = 0; i < hc.size(); i++)
     std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
     return false;
   }
-return true;
+  return true;
 }
 
 }
diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc
index 896906da1..b7931e248 100644
--- a/tests/unit/dot.cc
+++ b/tests/unit/dot.cc
@@ -8,22 +8,58 @@ int main() {
   auto context = triton::driver::backend::contexts::get_default();
   triton::driver::stream* stream = triton::driver::stream::create(context->backend());
   // shapes to test
-  typedef std::tuple<dtype_t, bool, bool, int, int, int, int, int, int, int> config_t;
+  typedef std::tuple<dtype_t, bool, bool, int, int, int, int> config_t;
   std::vector<config_t> configs;
-  for(int TM: std::vector<int>{32, 64, 128})
-  for(int TN: std::vector<int>{32, 64, 128})
-  for(int TK: std::vector<int>{16})
-  for(int nwarps: std::vector<int>{4})
-  for(bool AT: std::array<bool, 2>{false, true})
-  for(bool BT: std::array<bool, 2>{false, true}){
-    configs.push_back(config_t{FLOAT, AT, BT, TM, TN, TK, TM, TN, TK, nwarps});
-  }
+  for(dtype_t dtype: std::vector<dtype_t>{FLOAT, HALF})
+  for(bool AT: std::vector<bool>{false, true})
+  for(bool BT: std::vector<bool>{false, true}){
+    // 1 warp
+    configs.push_back({dtype, AT, BT, 16, 16, 16, 1});
+    configs.push_back({dtype, AT, BT, 32, 16, 16, 1});
+    configs.push_back({dtype, AT, BT, 16, 32, 16, 1});
+    configs.push_back({dtype, AT, BT, 16, 16, 32, 1});
+    configs.push_back({dtype, AT, BT, 32, 16, 32, 1});
+    configs.push_back({dtype, AT, BT, 16, 32, 32, 1});
+    if(dtype == HALF){
+      configs.push_back({dtype, AT, BT, 16, 64, 16, 1});
+      configs.push_back({dtype, AT, BT, 16, 16, 64, 1});
+      configs.push_back({dtype, AT, BT, 64, 16, 64, 1});
+      configs.push_back({dtype, AT, BT, 16, 64, 64, 1});
+    }
+    // 2 warps
+    configs.push_back({dtype, AT, BT, 64, 32, 64, 2});
+    configs.push_back({dtype, AT, BT, 32, 64, 64, 2});
+    configs.push_back({dtype, AT, BT, 64, 32, 16, 2});
+    configs.push_back({dtype, AT, BT, 32, 64, 16, 2});
+    configs.push_back({dtype, AT, BT, 128, 32, 32, 2});
+    configs.push_back({dtype, AT, BT, 32, 128, 32, 2});
+    // 4 warps
+    configs.push_back({dtype, AT, BT, 128, 64, 16, 4});
+    configs.push_back({dtype, AT, BT, 64, 128, 16, 4});
+    configs.push_back({dtype, AT, BT, 128, 32, 32, 4});
+    configs.push_back({dtype, AT, BT, 32, 128, 32, 4});
+    if(dtype == HALF){
+      configs.push_back({dtype, AT, BT, 128, 32, 64, 4});
+      configs.push_back({dtype, AT, BT, 32, 128, 64, 4});
+    }
+    // 8 warps
+    configs.push_back({dtype, AT, BT, 128, 256, 16, 8});
+    configs.push_back({dtype, AT, BT, 256, 128, 16, 8});
+    if(dtype == HALF){
+      configs.push_back({dtype, AT, BT, 256, 128, 32, 8});
+      configs.push_back({dtype, AT, BT, 256, 128, 32, 8});
+    }
+
+  };
   // test
   dtype_t dtype;
   bool AT, BT;
   int M, N, K, TM, TN, TK, nwarp;
   for(const auto& c: configs){
-    std::tie(dtype, AT, BT, M, N, K, TM, TN, TK, nwarp) = c;
+    std::tie(dtype, AT, BT, TM, TN, TK, nwarp) = c;
+    M = TM;
+    N = TN;
+    K = TK;
     std::cout << "Testing " << c << " ... " << std::flush;
     if(test_dot(context, stream, dtype, AT, BT, M, N, K, {0, 1}, {0, 1}, TM, TN, TK, (size_t)nwarp))
       std::cout << " Pass! " << std::endl;
diff --git a/tests/unit/reduce.cc b/tests/unit/reduce.cc
index 24f760c1f..757582712 100644
--- a/tests/unit/reduce.cc
+++ b/tests/unit/reduce.cc
@@ -20,12 +20,15 @@ int main() {
   // shapes to benchmark
   typedef std::tuple<std::vector<int>, int, reduce_op_t> config_t;
   std::vector<config_t> configs = {
-    config_t{{8, 8, 4}, 2, ADD},
-    config_t{{32}, 0, ADD},
+    config_t{{64}, 0, ADD},
+    config_t{{128}, 0, MIN},
     config_t{{32, 32}, 0, MAX},
     config_t{{32, 32}, 1, ADD},
     config_t{{32, 64}, 0, ADD},
-    config_t{{64, 32}, 1, ADD}
+    config_t{{64, 32}, 1, ADD},
+    config_t{{8, 8, 4}, 2, ADD},
+    config_t{{8, 8, 4}, 0, ADD},
+    config_t{{8, 8, 4}, 1, ADD}
   };
   // does the work
   int axis;