diff --git a/include/triton/codegen/instructions.h b/include/triton/codegen/instructions.h
deleted file mode 100644
index c42abee4a..000000000
--- a/include/triton/codegen/instructions.h
+++ /dev/null
@@ -1,83 +0,0 @@
-#ifndef _TRITON_CODEGEN_INSTRUCTIONS_H_
-#define _TRITON_CODEGEN_INSTRUCTIONS_H_
-
-#include "triton/ir/enums.h"
-#include <map>
-#include <vector>
-
-namespace triton{
-
-namespace ir{
-class instruction;
-}
-
-namespace codegen{
-
-enum storage_info_t {
- NONE,
- ANY,
- SHARED,
- DISTRIBUTED,
- REPLICATED
-};
-
-typedef std::pair<storage_info_t, std::vector<storage_info_t>> inst_storage_info_t;
-static const std::map<ir::value_id_t, inst_storage_info_t> storage_info = {
- // scalars
- { ir::INST_GET_PROGRAM_ID, {REPLICATED, {}}},
- { ir::INST_GET_NUM_PROGRAMS, {REPLICATED, {}}},
- // scalar/array
- { ir::INST_PHI, {ANY, {ANY, ANY}}},
- { ir::INST_BINOP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
- { ir::INST_GETELEMENTPTR, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
- { ir::INST_SELECT, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}},
- { ir::INST_SQRT, {DISTRIBUTED, {DISTRIBUTED}}},
- // cmp
- { ir::INST_ICMP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
- { ir::INST_FCMP, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
- // cast
- { ir::INST_CAST_TRUNC, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_ZEXT, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_SEXT, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_FP_TRUNC, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_FP_EXT, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_UI_TO_FP, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_SI_TO_FP, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_FP_TO_UI, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_FP_TO_SI, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_PTR_TO_INT, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_INT_TO_PTR, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_BIT_CAST, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_CAST_ADDR_SPACE_CAST, {DISTRIBUTED, {DISTRIBUTED}}},
- // io
- { ir::INST_UNMASKED_LOAD, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_MASKED_LOAD, {DISTRIBUTED, {DISTRIBUTED, DISTRIBUTED}}},
- { ir::INST_UNMASKED_STORE, {NONE , {DISTRIBUTED, DISTRIBUTED}}},
- { ir::INST_MASKED_STORE, {NONE , {DISTRIBUTED, DISTRIBUTED, DISTRIBUTED}}},
- // retile
- { ir::INST_RESHAPE, {DISTRIBUTED, {DISTRIBUTED}}},
- { ir::INST_SPLAT, {DISTRIBUTED, {REPLICATED}}},
- { ir::INST_BROADCAST, {DISTRIBUTED, {REPLICATED}}},
- { ir::INST_DOWNCAST, {DISTRIBUTED, {REPLICATED}}},
- // array arithmetic
- { ir::INST_TRANS, {SHARED, {SHARED}}},
- { ir::INST_REDUCE, {SHARED, {DISTRIBUTED}}},
- { ir::INST_DOT, {DISTRIBUTED, {SHARED, SHARED, DISTRIBUTED}}},
- // terminator
- { ir::INST_RETURN, {NONE, {}}},
- { ir::INST_UNCOND_BRANCH, {NONE, {}}},
- { ir::INST_COND_BRANCH, {NONE, {REPLICATED}}},
- // intrinsics
- { ir::INST_COPY_TO_SHARED, {SHARED, {DISTRIBUTED}}},
- { ir::INST_COPY_FROM_SHARED, {DISTRIBUTED, {SHARED}}},
- { ir::INST_BARRIER, {NONE, {}}},
- { ir::INST_MAKE_RANGE_DYN, {DISTRIBUTED, {}}},
- { ir::INST_MAKE_RANGE_STA, {DISTRIBUTED, {}}},
- { ir::INST_MAKE_RANGE, {DISTRIBUTED, {}}}
-};
-
-
-}
-}
-
-#endif
diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h
index 3e6c0bacb..1f18bc6e1 100644
--- a/include/triton/codegen/selection/generator.h
+++ b/include/triton/codegen/selection/generator.h
@@ -122,6 +122,7 @@ public:
   void visit_reduce_inst(ir::reduce_inst*);
   void visit_select_inst(ir::select_inst*);
+  void visit_recoalesce_inst(ir::recoalesce_inst*);
   void visit_copy_to_shared_inst(ir::copy_to_shared_inst*);
void visit_copy_from_shared_inst(ir::copy_from_shared_inst*); void visit_barrier_inst(ir::barrier_inst*); diff --git a/include/triton/codegen/selection/machine_value.h b/include/triton/codegen/selection/machine_value.h index 508881fd3..917151971 100644 --- a/include/triton/codegen/selection/machine_value.h +++ b/include/triton/codegen/selection/machine_value.h @@ -128,22 +128,22 @@ private: Type *make_vector_ty(Type *ty, size_t vector_size); public: - distributed_tile(Type *ty, const shapes_t& shapes, const std::vector& order, const axes_t &axes, Builder &builder, bool vectorize); + distributed_tile(Type *ty, const shapes_t& shapes, const std::vector& order, const axes_t &axes, Builder &builder); void set_value(indices_t idx, Value *v); Value* get_value(indices_t idx); const std::vector& get_order() { return order_; } unsigned get_linear_index(indices_t idx); indices_t get_ordered_indices(unsigned id); - void for_each(std::function fn); - const distributed_axis &axis(unsigned dim) { return axes_.at(dim); } + void for_each(std::function fn, int start = 0, int end = -1); + void for_each(std::function fn, std::vector start, std::vector size); + const distributed_axis &axis(unsigned dim) { return axes_.at(dim); } private: axes_t axes_; std::vector order_; indices_map_t indices_; values_map_t values_; ordered_indices_vec_t ordered_indices_; - size_t vector_size_; Builder &builder_; }; diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h index 4f277d99b..d62142c7d 100755 --- a/include/triton/driver/module.h +++ b/include/triton/driver/module.h @@ -45,6 +45,8 @@ public: llvm::SmallVectorImpl &buffer, const std::string &features, file_type_t file_type); + virtual std::unique_ptr symbol(const char * name) const = 0; + protected: driver::context* ctx_; @@ -54,13 +56,14 @@ protected: class host_module: public module{ public: host_module(driver::context* context, std::unique_ptr module); + std::unique_ptr symbol(const char * name) const; }; // OpenCL class ocl_module: public module{ - public: ocl_module(driver::context* context, std::unique_ptr module); + std::unique_ptr symbol(const char * name) const; }; // CUDA @@ -70,7 +73,7 @@ class cu_module: public module { public: cu_module(driver::context* context, std::unique_ptr module); cu_module(driver::context* context, const std::string& source); - cu_buffer* symbol(const char * name) const; + std::unique_ptr symbol(const char * name) const; private: std::string source_; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 62690b11e..d5782e3a1 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -37,6 +37,7 @@ public: // Constants value *get_int32(unsigned val); // Types + type *get_void_ty(); type *get_int1_ty(); type *get_int8_ty(); type *get_int16_ty(); @@ -115,10 +116,10 @@ public: value *create_and(value *lhs, value *rhs, const std::string &name = ""); value *create_xor(value *lhs, value *rhs, const std::string &name = ""); value *create_or(value *lhs, value *rhs, const std::string &name = ""); - // Side effects - value *create_fneg(value *arg, const std::string &name = ""); - value *create_neg(value *arg, const std::string &name = ""); - value *create_not(value *arg, const std::string &name = ""); + // Unary +// value *create_fneg(value *arg, const std::string &name = ""); +// value *create_neg(value *arg, const std::string &name = ""); +// value *create_not(value *arg, const std::string &name = ""); // Input/Output value *create_load(value *arg, const std::string &name = 
""); value *create_store(value *ptr, value *val, const std::string &name = ""); diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 94c74c085..491d37edf 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -134,6 +134,7 @@ enum value_id_t: unsigned { // intrinsics INST_COPY_TO_SHARED, INST_COPY_FROM_SHARED, + INST_RECOALESCE, INST_BARRIER, INST_MAKE_RANGE_DYN, INST_MAKE_RANGE_STA, diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 7c6b0465d..41eb98eb3 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -148,9 +148,9 @@ public: // Factory methods static binary_operator *create(binary_op_t op, value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr); - static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr); - static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr); - static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr); +// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr); +// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr); +// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr); _TRITON_DEFINE_CLONE(binary_operator) _TRITON_DEFINE_ACCEPT(binary_operator) @@ -732,6 +732,17 @@ public: _TRITON_DEFINE_ACCEPT(copy_from_shared_inst) }; +class recoalesce_inst: public unary_inst{ +private: + using unary_inst::unary_inst; + std::string repr_impl() const { return "recoalesce_inst"; } + +public: + static recoalesce_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr); + _TRITON_DEFINE_CLONE(recoalesce_inst) + _TRITON_DEFINE_ACCEPT(recoalesce_inst) +}; + class barrier_inst: public instruction{ private: barrier_inst(context &ctx, const std::string &name, instruction *next); diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 62e63e6c4..b5941b88f 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -59,6 +59,7 @@ class sqrt_inst; class reduce_inst; class select_inst; +class recoalesce_inst; class copy_to_shared_inst; class copy_from_shared_inst; class barrier_inst; @@ -129,6 +130,7 @@ public: virtual void visit_reduce_inst(reduce_inst*) = 0; virtual void visit_select_inst(select_inst*) = 0; + virtual void visit_recoalesce_inst(recoalesce_inst*) = 0; virtual void visit_copy_to_shared_inst(copy_to_shared_inst*) = 0; virtual void visit_copy_from_shared_inst(copy_from_shared_inst*) = 0; virtual void visit_barrier_inst(barrier_inst*) = 0; diff --git a/include/triton/lang/code_gen.h b/include/triton/lang/code_gen.h index 96a02ce9a..a29cf268b 100644 --- a/include/triton/lang/code_gen.h +++ b/include/triton/lang/code_gen.h @@ -47,6 +47,7 @@ protected: }; void set_ret(ir::value* value); + ir::value *GenUnaryMinus(ir::value* arg); public: Generator(Parser* parser) : parser_(parser) {} diff --git a/include/triton/lang/token.h b/include/triton/lang/token.h index f11d08fc8..1b2868849 100644 --- a/include/triton/lang/token.h +++ b/include/triton/lang/token.h @@ -145,6 +145,7 @@ public: THREAD, // _Thread_local AUTO, GLOBAL, + CMEM, // constant memory // STORAGE CLASS SPECIFIER END BREAK, diff --git a/include/triton/lang/type.h b/include/triton/lang/type.h index 0985ba5e1..59ea8eb3f 
100644 --- a/include/triton/lang/type.h +++ b/include/triton/lang/type.h @@ -39,7 +39,7 @@ enum { S_EXTERN = 0x02, S_STATIC = 0x04, S_THREAD = 0x08, - S_AUTO = 0x10, + S_CONSTANT = 0x10, S_GLOBAL = 0x20, // Type specifier @@ -73,7 +73,8 @@ struct Qualifier { CONST = 0x01, RESTRICT = 0x02, VOLATILE = 0x04, - MASK = CONST | RESTRICT | VOLATILE + CMEM = 0x08, + MASK = CONST | RESTRICT | VOLATILE | CMEM }; }; @@ -111,6 +112,7 @@ public: bool IsConstQualified() const { return ptr_ & Qualifier::CONST; } bool IsRestrictQualified() const { return ptr_ & Qualifier::RESTRICT; } bool IsVolatileQualified() const { return ptr_ & Qualifier::VOLATILE; } + bool IsConstantQualified() const { return ptr_ & Qualifier::CMEM; } private: intptr_t ptr_; diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index a6ab851a9..26253b7ee 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -8,11 +8,9 @@ #include #include #include -#include // codegen #include "triton/ir/context.h" #include "triton/codegen/target.h" -#include "triton/lang/parser.h" #include "triton/runtime/arg.h" namespace llvm { @@ -20,6 +18,8 @@ namespace llvm { class LLVMContext; } +class Parser; + namespace triton { namespace driver{ @@ -106,14 +106,14 @@ public: function(const std::string& src, const options_space_t& opt = options_space_t()); void operator()(const std::vector& args, const grid_t& grid, driver::stream* stream); void operator()(const std::vector& args, const grid_fn_ty& grid, driver::stream *stream); - std::string make_tensorflow_src(const std::vector &outputs, const std::string ¯o); + void set_cst(const std::string& name, void* data, size_t n_bytes); private: ir::context ctx_; std::string src_; options_space_t opt_space_; std::map cache_; - std::mutex src_mutex_; + std::map> cst_; }; } diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp index 48a4ab972..430418b27 100644 --- a/include/triton/tools/bench.hpp +++ b/include/triton/tools/bench.hpp @@ -30,7 +30,7 @@ private: high_resolution_clock::time_point _start; }; -inline double bench(std::function const & op, driver::stream * stream) +inline double bench(std::function const & op, driver::stream * stream, bool normalize = false) { // const driver::device * device = stream->context()->device(); timer tmr; @@ -38,9 +38,10 @@ inline double bench(std::function const & op, driver::stream * stream) double total_time = 0; op(); stream->synchronize(); - while(total_time*1e-9 < 1e-2){ + while(total_time*1e-9 < 1e-1){ float norm = 1; // normalize clock if possible to reduce noise in auto-tuning + if(normalize) if(auto cu_device = dynamic_cast(stream->context()->device())) norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); tmr.start(); diff --git a/lib/codegen/analysis/align.cc b/lib/codegen/analysis/align.cc index ec692d6f6..8e9ba699c 100644 --- a/lib/codegen/analysis/align.cc +++ b/lib/codegen/analysis/align.cc @@ -5,24 +5,41 @@ #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" +#include namespace triton { namespace codegen{ namespace analysis{ -inline int gcd(int a, int b) { - if (a == 0) - return b; - if (b == 0) - return a; - if (a == b) - return a; - if (a > b) - return gcd(a - b, b); - return gcd(a, b - a); +// Function for extended Euclidean Algorithm +int gcd_impl(int a, int b, int *x, int *y) +{ + // Base Case + if (a == 0) + { + *x = 0; + *y = 1; + return b; + } + + int x1, y1; // To store results of recursive call + int gcd 
= gcd_impl(b%a, a, &x1, &y1); + + // Update x and y using results of + // recursive call + *x = y1 - (b/a) * x1; + *y = x1; + + return gcd; } +int gcd(int a, int b) { + int x, y; + return gcd_impl(a, b, &x, &y); +} + + inline int lcm(int a, int b) { return (a * b) / gcd(a, b); } @@ -156,7 +173,7 @@ std::vector align::populate_is_constant(ir::value *v) { if(is_constant_.find(v) != is_constant_.end()) return is_constant_.at(v); if(auto *x = dynamic_cast(v)) - return add_to_cache(v, {cst_info{true, (unsigned)x->get_value()}}, is_constant_); + return add_to_cache(v, {cst_info{true, std::min(x->get_value(), 128)}}, is_constant_); if(dynamic_cast(v)) return add_to_cache(v, {cst_info{true, 0}}, is_constant_); if(auto *x = dynamic_cast(v)) @@ -448,7 +465,7 @@ std::vector align::populate_starting_multiple(ir::value *v){ if(auto *x = dynamic_cast(v)) return populate_starting_multiple_binop(x); if(auto *x = dynamic_cast(v)) - return add_to_cache(x, {(unsigned)x->get_value()}, starting_multiple_); + return add_to_cache(x, {std::min(x->get_value(), 128)}, starting_multiple_); if(auto *x = dynamic_cast(v)) return add_to_cache(x, {(unsigned)x->get_first()->get_value()}, starting_multiple_); if(auto *x = dynamic_cast(v)) @@ -484,6 +501,7 @@ void align::populate(ir::value *v) { populate_is_constant(v); populate_starting_multiple(v); populate_max_contiguous(v); + } void align::run(ir::module &mod) { diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index 0e67877b9..a01ef9aa1 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -113,6 +113,7 @@ void axes::update_graph(ir::instruction *i) { case ir::INST_DOT: return update_graph_dot(i); case ir::INST_COPY_TO_SHARED: return update_graph_no_edge(i);; case ir::INST_COPY_FROM_SHARED: return update_graph_no_edge(i); + case ir::INST_RECOALESCE: return update_graph_no_edge(i); default: return update_graph_elementwise(i); } return; diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 70ca9e3b2..6d7c2dc9c 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -1,9 +1,9 @@ #include #include +#include #include "triton/codegen/analysis/axes.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/analysis/layout.h" -#include "triton/codegen/instructions.h" #include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/ir/utils.h" @@ -148,8 +148,11 @@ layout_t::layout_t(layout_type_t _type, extract_io_use(v, ptr); order.resize(axes.size()); std::iota(order.begin(), order.end(), 0); - for(ir::value *v: ptr){ - auto max_contiguous = align->contiguous(v); + auto largest = std::max_element(ptr.begin(), ptr.end(), [&](ir::value *x, ir::value *y){ + return x->get_type()->get_tile_rank() < y->get_type()->get_tile_rank(); + }); + if(*largest){ + auto max_contiguous = align->contiguous(*largest); std::sort(order.begin(), order.end(), [&](unsigned a, unsigned b) { return max_contiguous[a] > max_contiguous[b]; }); @@ -166,9 +169,8 @@ layout_hmma_884_t::layout_hmma_884_t(size_t num_warps, const std::vector& _shapes, const std::vector &values, ir::type *_ty, size_t _id, analysis::align* align): layout_t(HMMA_884, _axes, _shapes, values, _ty, _id, align) { - - unsigned shape_0 = shapes[order[0]]; - unsigned shape_1 = shapes[order[1]]; + unsigned shape_0 = shapes[0]; + unsigned shape_1 = shapes[1]; /* fragments per warp */ // try to make things as square as possible to maximize data re-use fpw = {1, 1, 1}; @@ -196,6 +198,7 @@ 
layout_hmma_884_t::layout_hmma_884_t(size_t num_warps, unsigned effective_num_warps = 1; for(size_t d = 0; d < shapes.size(); d++) effective_num_warps *= wpt[d]; + if(num_warps != effective_num_warps) throw std::runtime_error("cannot create a kernel with this amount of warps"); } @@ -213,20 +216,38 @@ layout_scanline_t::layout_scanline_t(size_t num_warps, unsigned num_threads = num_warps * 32; nts.resize(shapes.size()); mts.resize(shapes.size()); + bool is_dot = std::any_of(values.begin(), values.end(), + [&](ir::value* v) { return dynamic_cast(v); }); + + ir::value *ptr = nullptr; + for(ir::value *v: values) + for(ir::user *usr: v->get_users()) + if(auto *st = dynamic_cast(usr)) + ptr = st->get_pointer_operand(); + unsigned i = order[0]; - nts[i] = clamp(size / num_threads, 1, 4); + int contiguous = 4; + if(ptr) + contiguous = std::min(align->contiguous(ptr)[i], 4); + + nts[i] = clamp(size / num_threads, 1, std::min(contiguous, shapes[i])); mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]); - num_threads = num_threads / mts[i]; + size /= shapes[i]; + num_threads /= mts[i]; + if(is_dot) + nts[order[1]] = clamp(size / num_threads, 1, std::min(4, shapes[order[1]])); for(size_t d = 1; d < shapes.size(); d++){ i = order[d]; - nts[i] = 1; - mts[i] = clamp(num_threads, 1, shapes[i]); + if(d > 1 || !is_dot) + nts[i] = 1; + mts[i] = clamp(num_threads, 1, shapes[i] / nts[i]); num_threads = num_threads / mts[i]; } /* sanity check */ unsigned effective_num_threads = 1; for(size_t d = 0; d < shapes.size(); d++) effective_num_threads *= mts[d]; + if(num_warps * 32 != effective_num_threads) throw std::runtime_error("cannot create a kernel with this amount of warps"); } @@ -259,8 +280,8 @@ void extract_double_bufferable(ir::value *v, std::shared_ptr(value_0); ir::instruction *i_1 = dynamic_cast(value_1); if(!i_0 || !i_1 || - storage_info.at(i_0->get_id()).first != codegen::SHARED || - storage_info.at(i_1->get_id()).first != codegen::SHARED) + !dynamic_cast(i_0) || + !dynamic_cast(i_1) ) return; if(is_latch_1) res.reset(new double_buffer_info_t{value_0, value_1, phi}); @@ -284,10 +305,9 @@ layout_shared_t::layout_shared_t(const layout_t *arg, extract_double_bufferable(v, double_buffer); // order - if(arg->type == SCANLINE) - order = arg->order; - else - order = arg->order; + std::vector arg_order = arg ? arg->order : std::vector{0}; + order = arg_order; + ir::value* dot_a = nullptr; ir::value* dot_b = nullptr; ir::value* hmma_dot_a = nullptr; @@ -304,24 +324,27 @@ layout_shared_t::layout_shared_t(const layout_t *arg, col.push_back(s); row.push_back(s); } + + bool is_nonhmma_dot_a = dot_a && !hmma_dot_a; bool is_nonhmma_dot_b = dot_b && !hmma_dot_b; if(is_nonhmma_dot_a) order = is_trans(dot_a) ? row : col; - if(is_nonhmma_dot_b) + else if(is_nonhmma_dot_b) order = is_trans(dot_b) ? col : row; - +// else +// order = row; // padding pad = 0; if(hmma_dot_a){ bool row = is_trans(hmma_dot_a) ^ order[0] != 0; - pad = 24 - shapes[row ? order[0] : order[1]] % 32; + pad = 24 - shapes[row ? 0 : 1] % 32; } else if(hmma_dot_b){ bool row = is_trans(hmma_dot_b) ^ order[0] != 0; - pad = 24 - shapes[row ? order[1] : order[0]] % 32; + pad = 24 - shapes[row ? 
1 : 0] % 32; } - else if(order != arg->order) { + else if(order != arg_order) { pad = 4; } shapes[order[0]] += pad; @@ -395,6 +418,29 @@ void layout::run(ir::module &mod) { layouts_[id] = new layout_shared_t(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), id, align_); tmp_[red] = id; } + if(auto *recoalasce = dynamic_cast(i)){ + ir::value *val = recoalasce->get_operand(0); + const layout_t* in_layout = get(val); + const layout_t* out_layout = get(i); + if(in_layout->type != HMMA_884) + return; + id++; + ir::type::tile_shapes_t in_shape = val->get_type()->get_tile_shapes(); + ir::type::tile_shapes_t shape(in_shape.size()); + size_t ld = out_layout->order[0]; + shape[ld] = in_shape[ld]; + for(size_t k = 0; k < in_shape.size(); k++) + if(k != ld) + shape[k] = 4*in_layout->fpw[k]*in_layout->wpt[k]; + // create layout + layouts_[id] = new layout_shared_t(out_layout, axes_->get(val), shape, {recoalasce}, val->get_type()->get_scalar_ty(), id, align_); + tmp_[recoalasce] = id; + } + if(auto *atom = dynamic_cast(i)){ + id++; + layouts_[id] = new layout_shared_t(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), id, align_); + tmp_[atom] = id; + } }); } diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 6585af9be..4d4fe0b11 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -7,7 +7,6 @@ #include "triton/codegen/analysis/allocation.h" #include "triton/codegen/analysis/align.h" #include "triton/codegen/transform/coalesce.h" -#include "triton/codegen/instructions.h" #include "triton/ir/context.h" #include "triton/ir/module.h" #include "triton/ir/function.h" @@ -351,10 +350,9 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) { unsigned id = linear / vector_size; if(linear % vector_size == 0) { Value *ptr = pointers->get_value(idx); - - ptr = builder_->CreateBitCast(ptr, PointerType::get(VectorType::get(result->get_ty(), vector_size), ptr->getType()->getPointerAddressSpace())); + Value *mask = masks->get_value(idx); BasicBlock *current_bb = builder_->GetInsertBlock(); Function *parent = builder_->GetInsertBlock()->getParent(); @@ -386,9 +384,9 @@ void generator::visit_masked_load_inst(ir::masked_load_inst* x) { // if(cst) // offset = " + " + std::to_string(cst->getValue().getSExtValue()*2*vector_size); // Type *fp16x2_ty = VectorType::get(builder_->getHalfTy(), 2); -// Type *fp16x2_pack4_ty = StructType::get(ctx, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); +// Type *fp16x2_pack4_ty = StructType::get(*ctx_, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty}); // FunctionType *ty = FunctionType::get(fp16x2_pack4_ty, {mask->getType(), ptr->getType()}, false); -// std::string asm_str = "@$0 ld.global.nc.b32 {$1, $2, $3, $4}, [$5" + offset + "];"; +// std::string asm_str = "@$0 ld.global.nc.v4.b32 {$1, $2, $3, $4}, [$5" + offset + "];"; // if(false_values) // asm_str += "\n\t@!$0 mov.v4.b32 {$1, $2, $3, $4}, {0, 0, 0, 0};"; // InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,=r,=r,=r,=r,l", true); @@ -420,31 +418,83 @@ void generator::visit_unmasked_store_inst(ir::unmasked_store_inst* st) { void generator::visit_masked_store_inst(ir::masked_store_inst* st) { distributed_tile* ptrs = (distributed_tile*)tmap_.at(st->get_pointer_operand()); - distributed_tile* scalars = (distributed_tile*)tmap_.at(st->get_value_operand()); - ir::value *mask = st->get_mask_operand(); - distributed_tile* preds = (distributed_tile*)tmap_.at(mask); - ptrs->for_each([&](indices_t idx){ - Value *scalar = 
scalars->get_value(idx); - Value *ptr = ptrs->get_value(idx); - Value *pred = preds->get_value(idx); - Function *parent = builder_->GetInsertBlock()->getParent(); - BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent); - BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent); - builder_->CreateCondBr(pred, mask_then_bb, mask_done_bb); - builder_->SetInsertPoint(mask_then_bb); - builder_->CreateStore(scalar, ptr); - builder_->CreateBr(mask_done_bb); - builder_->SetInsertPoint(mask_done_bb); -// std::string offset = ""; -// if(GetElementPtrInst *gep = dyn_cast(ptr)) -// if(gep->getNumIndices() == 1) -// if(ConstantInt *cst = dyn_cast(gep->idx_begin())){ -// offset = " + " + std::to_string(cst->getValue().getSExtValue()*4); -// } -// FunctionType *ty = FunctionType::get(Type::getVoidTy(ctx), {pred->getType(), ptr->getType(), scalar->getType()}, false); -// std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;"; -// InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true); -// builder.CreateCall(iasm, {pred, ptr, scalar}); + distributed_tile* masks = (distributed_tile*)tmap_.at(st->get_mask_operand()); + // vector size + int vector_size = 1; + int ld = ptrs->get_order()[0]; + unsigned alignment = alignment_->get(st->get_pointer_operand(), ld); + vector_size = std::min(ptrs->axis(ld).contiguous, alignment); + // create packets + std::map packets; + ir::value *arg = st->get_value_operand(); + for_each(arg, [&](indices_t idx){ + distributed_tile* in = (distributed_tile*)tmap_.at(arg); + unsigned linear = in->get_linear_index(idx); + unsigned id = linear / vector_size; + Value *in_value = in->get_value(idx); + if(linear % vector_size == 0) + packets[id] = UndefValue::get(VectorType::get(in_value->getType(), vector_size)); + packets[id] = builder_->CreateInsertElement(packets.at(id), in_value, linear % vector_size); + }); + // write-back packets + for_each(arg, [&](indices_t idx){ + distributed_tile* in = (distributed_tile*)tmap_.at(arg); + unsigned linear = in->get_linear_index(idx); + unsigned id = linear / vector_size; + if(linear % vector_size == 0){ + // fetch tile elements + Value *elt = packets[id]; + Value *ptr = ptrs->get_value(idx); + Value *pred = masks->get_value(idx); + // type information + Type *ty = elt->getType(); + unsigned nbits = ty->getScalarSizeInBits(); + unsigned nbytes = nbits / 8; + // extract pointer offset + std::string offset = ""; + if(GetElementPtrInst *gep = dyn_cast(ptr)) + if(gep->getNumIndices() == 1) + if(ConstantInt *cst = dyn_cast(gep->idx_begin())){ + offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbytes); + ptr = gep->getPointerOperand(); + } + ptr = builder_->CreateBitCast(ptr, ty->getPointerTo(1)); + // asm argument type + std::vector arg_ty = {pred->getType(), ptr->getType()}; + for(int v = 0; v < vector_size; v++) + arg_ty.push_back(ty->getScalarType()); + // asm function type + FunctionType *fn_ty = FunctionType::get(builder_->getVoidTy(), arg_ty, false); + // asm string + std::string asm_str; + asm_str += "@$0 st.global"; + if(vector_size > 1) + asm_str += ".v" + std::to_string(vector_size); + asm_str += ".b" + std::to_string(nbits) + " [$1" + offset + "],"; + if(vector_size > 1) + asm_str += "{"; + for(int v = 0; v < vector_size; v++){ + if(v > 0) + asm_str += ", "; + asm_str += "$" + std::to_string(2 + v); + } + if(vector_size > 1) + asm_str += "}"; + asm_str += ";"; + // asm constraint + std::string constraint = "b,l"; + for(int v = 0; v < vector_size; v++){ + constraint += ","; + 
constraint += (nbits == 32 ? "r" : "h"); + } + // create inline asm + InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true); + // call asm + std::vector args = {pred, ptr}; + for(int v = 0; v < vector_size; v++) + args.push_back(builder_->CreateExtractElement(elt, builder_->getInt32(v))); + builder_->CreateCall(iasm, args); + } }); } @@ -504,23 +554,27 @@ void generator::visit_atomic_cas_inst(ir::atomic_cas_inst* cas) { Value *pred = builder_->CreateICmpEQ(tid, builder_->getInt32(0)); BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); - Value *ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(cas)))); - ptr = builder_->CreateBitCast(ptr, PointerType::get(builder_->getInt32Ty(), ptr->getType()->getPointerAddressSpace())); - tgt_->add_memfence(module, *builder_); tgt_->add_barrier(module, *builder_); + tgt_->add_memfence(module, *builder_); builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); Value *cas_ptr = vmap_.at(cas->get_operand(0)); Value *cas_cmp = vmap_.at(cas->get_operand(1)); Value *cas_val = vmap_.at(cas->get_operand(2)); - Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic); + Value *old = builder_->CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, + AtomicOrdering::Monotonic, + AtomicOrdering::Monotonic); old = builder_->CreateExtractValue(old, {0}); - builder_->CreateStore(old, ptr); + Value *atom_ptr; + atom_ptr = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(cas))))); + atom_ptr = builder_->CreateBitCast(atom_ptr, PointerType::get(old->getType(), 3)); + + builder_->CreateStore(old, atom_ptr); builder_->CreateBr(tid_0_done_bb); builder_->SetInsertPoint(tid_0_done_bb); tgt_->add_memfence(module, *builder_); tgt_->add_barrier(module, *builder_); - vmap_[cas] = builder_->CreateLoad(ptr); + vmap_[cas] = builder_->CreateLoad(atom_ptr); } void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) { @@ -533,14 +587,14 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) { BasicBlock *tid_0_bb = BasicBlock::Create(*ctx_, "tid_0", current->getParent()); BasicBlock *tid_0_done_bb = BasicBlock::Create(*ctx_, "tid_0_done", current->getParent()); tgt_->add_memfence(module, *builder_); - tgt_->add_barrier(module, *builder_); builder_->CreateCondBr(pred, tid_0_bb, tid_0_done_bb); builder_->SetInsertPoint(tid_0_bb); - vmap_[xchg] = builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, AtomicOrdering::Monotonic, SyncScope::System); + builder_->CreateAtomicRMW(AtomicRMWInst::Xchg, rmw_ptr, rmw_val, + AtomicOrdering::Monotonic, + SyncScope::System); builder_->CreateBr(tid_0_done_bb); builder_->SetInsertPoint(tid_0_done_bb); tgt_->add_memfence(module, *builder_); - tgt_->add_barrier(module, *builder_); } void generator::visit_atomic_add_inst(ir::atomic_add_inst*) { @@ -861,6 +915,115 @@ void generator::visit_select_inst(ir::select_inst* select) { } +void generator::visit_recoalesce_inst(ir::recoalesce_inst* rc) { + ir::value *op = rc->get_operand(0); + ir::tile_type::tile_shapes_t shape = rc->get_type()->get_tile_shapes(); + size_t rank = shape.size(); + // temporary layout + shared_tile *tmp = (shared_tile*)machine_layouts_.at(layouts_->get(layouts_->tmp(rc))) + ->create(rc); + // pointer to temporary shared memory + Type *ty = 
llvm_type(rc->get_type()->get_scalar_ty(), *ctx_); + // layouts + const analysis::layout_t* in_layout = layouts_->get(op); + const analysis::layout_t* out_layout = layouts_->get(rc); + // machine tiles + distributed_tile *in_dt = (distributed_tile*)(tmap_.at(op)); + distributed_tile *out_dt = (distributed_tile*)(tmap_.at(rc)); + // WMMA configuration + long wmma_pt[3] = { 2, 4, 1}; + long wmma[3] = { 8*in_layout->wpt[0]*in_layout->fpw[0], + 8*in_layout->wpt[1]*in_layout->fpw[1], + 1}; + // Work per thread for input layout + long in_pt[3] = { shape[0] / wmma[0], + shape[1] / wmma[1], + 1 }; + // Work per thread for output layout + long out_pt[3] = { shape[0] / out_layout->mts[0], + shape[1] / out_layout->mts[1], + 1 }; + if(rank > 2){ + wmma[2] = in_layout->wpt[2]*in_layout->fpw[2]; + in_pt[2] = shape[2] / wmma[2]; + out_pt[2] = shape[2] / out_layout->mts[2]; + } + // Orders + auto ord = out_layout->order; + if(ord.size() < 3) + ord.push_back(2); + // pointer lanes + std::vector> ptrs; + for(int in_zz = 0; in_zz < wmma_pt[ord[2]]; in_zz++) { + std::vector current; + for(int in_cc = 0; in_cc < wmma_pt[ord[1]]; in_cc++) { + Value *base; + base = builder_->CreateGEP(sh_mem_ptr_, builder_->getInt32(alloc_->offset(layouts_->get(layouts_->tmp(rc))))); + base = builder_->CreateBitCast(base, PointerType::get(ty, 3)); + + // shared memory stride + Value *stride_0 = builder_->getInt32(tmp->get_shapes()[ord[0]]); + // indices + Value *idx_cc = axes_.at(a_axes_->get(op, ord[1])).values[in_cc]; + // offset + Value *off = builder_->CreateMul(stride_0, idx_cc); + if(rank > 2){ + Value *stride_1 = builder_->CreateMul(stride_0, + builder_->getInt32(tmp->get_shapes()[ord[1]])); + Value *idx_zz = axes_.at(a_axes_->get(op, ord[2])).values[in_zz]; + off = builder_->CreateAdd(off, builder_->CreateMul(stride_1, idx_zz)); + } + current.push_back(builder_->CreateGEP(base, off)); + } + ptrs.push_back(current); + } + // Re-coalesce loops + for(int in_z = 0; in_z < in_pt[ord[2]]; in_z++) + for(int in_c = 0; in_c < in_pt[ord[1]]; in_c++){ + // write to shared + tgt_->add_barrier(mod_, *builder_); + for(int in_zz = 0; in_zz < wmma_pt[ord[2]]; in_zz++) + for(int in_cc = 0; in_cc < wmma_pt[ord[1]]; in_cc++){ + std::vector starts(rank), len(rank); + starts[ord[0]] = 0; + starts[ord[1]] = in_c*wmma_pt[ord[1]] + in_cc; + len[ord[0]] = wmma_pt[ord[0]]*in_pt[ord[0]]; + len[ord[1]] = 1; + if(rank > 2){ + starts[ord[2]] = in_z*wmma_pt[ord[2]] + in_zz; + len[ord[2]] = 1; + } + in_dt->for_each([&](indices_t idx){ + Value *write_ptr = builder_->CreateGEP(ptrs[in_zz][in_cc], idx[ord[0]]); + builder_->CreateStore(in_dt->get_value(idx), write_ptr); + }, starts, len); + } + tgt_->add_barrier(mod_, *builder_); + // load from shared + for(int out_zz = 0; out_zz < out_pt[ord[2]] / in_pt[ord[2]]; out_zz++) + for(int out_cc = 0; out_cc < out_pt[ord[1]] / in_pt[ord[1]]; out_cc++){ + std::vector starts(rank), len(rank); + starts[ord[0]] = 0; + starts[ord[1]] = in_c*(out_pt[ord[1]] / in_pt[ord[1]]) + out_cc; + len[ord[0]] = out_pt[ord[0]]; + len[ord[1]] = 1; + if(rank > 2){ + starts[ord[2]] = in_z*(out_pt[ord[2]] / in_pt[ord[2]]) + out_zz; + len[ord[2]] = 1; + } + out_dt->for_each([&](indices_t idx){ + indices_t read_idx(rank); + read_idx[ord[0]] = idx[ord[0]]; + read_idx[ord[1]] = axes_.at(a_axes_->get(rc, ord[1])).values[out_cc]; + if(rank > 2) + read_idx[ord[2]] = axes_.at(a_axes_->get(rc, ord[2])).values[out_zz]; + out_dt->set_value(idx, tmp->get_value(read_idx)); + }, starts, len); + } + } + tgt_->add_barrier(mod_, *builder_); +} + void 
generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) { unsigned vector_size = 1; auto x_order = layouts_->get(cts)->order; @@ -1126,16 +1289,14 @@ void generator::visit(ir::module &src, llvm::Module &dst) { if(tgt_->is_gpu()) if(unsigned alloc_size = alloc_->allocated_size()){ Type *int_8_ty = Type::getInt8Ty(*ctx_); - ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size); + Type *int_32_ty = Type::getInt32Ty(*ctx_); + ArrayType *array_ty = ArrayType::get(int_32_ty, alloc_size/4); Type *ptr_ty = PointerType::get(int_8_ty, 3); GlobalVariable *sh_mem_array = new GlobalVariable(*mod_, array_ty, false, GlobalVariable::ExternalLinkage, nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); sh_mem_ptr_ = builder_->CreateBitCast(sh_mem_array, ptr_ty); } - // allocate constant memory - for(ir::alloc_const *x: src.allocs()) - visit_alloc_const(x); // visit functions for(ir::function *fn: src.get_function_list()) visit_function(fn); diff --git a/lib/codegen/selection/machine_layout.cc b/lib/codegen/selection/machine_layout.cc index 1c026bfc8..2d02e7b1f 100644 --- a/lib/codegen/selection/machine_layout.cc +++ b/lib/codegen/selection/machine_layout.cc @@ -143,7 +143,7 @@ tile *machine_layout_distributed_t::create(ir::value *v) { axes[d].values = {builder_->getInt32(0)}; } } - return new distributed_tile(ty, shapes, layout_->order, axes, *builder_, false); + return new distributed_tile(ty, shapes, layout_->order, axes, *builder_); } machine_layout_hmma_884_t::machine_layout_hmma_884_t(Module *mod, Builder *builder, diff --git a/lib/codegen/selection/machine_value.cc b/lib/codegen/selection/machine_value.cc index a7cd73a8e..72aace4b2 100644 --- a/lib/codegen/selection/machine_value.cc +++ b/lib/codegen/selection/machine_value.cc @@ -45,9 +45,8 @@ llvm::Type *distributed_tile::make_vector_ty(llvm::Type *ty, size_t vector_size) return VectorType::get(ty, vector_size); } -distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector& order, const axes_t &axes, llvm::IRBuilder<> &builder, bool vectorize) - : tile(make_vector_ty(ty, vectorize?axes[0].contiguous:1), shapes), axes_(axes), order_(order), builder_(builder) { - vector_size_ = vectorize?ty_->getVectorNumElements():1; +distributed_tile::distributed_tile(Type *ty, const shapes_t &shapes, const std::vector& order, const axes_t &axes, llvm::IRBuilder<> &builder) + : tile(ty, shapes), axes_(axes), order_(order), builder_(builder) { init_indices(); } @@ -73,13 +72,31 @@ indices_t distributed_tile::get_ordered_indices(unsigned id) { } -void distributed_tile::for_each(std::function fn) { - for(unsigned i = 0; i < ordered_indices_.size(); i++){ - if(i % vector_size_ == 0) - fn(ordered_indices_[i]); +void distributed_tile::for_each(std::function fn, int start, int end) { + if(end < 0) + end = ordered_indices_.size() + end + 1; + for(unsigned i = start; i < end; i++) + fn(ordered_indices_[i]); +} + +void distributed_tile::for_each(std::function fn, std::vector starts, std::vector sizes){ + int rank = sizes.size(); + int len = 1; + for(int s: sizes) + len *= s; + + for(int i = 0; i < len; i++){ + indices_t idx(rank); + int current = i; + for(int k = 0; k < rank; k++){ + idx[k] = axes_[k].values.at(starts[k] + current % sizes[k]); + current = current / sizes[k]; + } + fn(idx); } } + /* Shared Tile */ void shared_tile::extract_constant(Value *arg, Value *&non_cst, Value *&cst) { BinaryOperator *bin_op = dyn_cast(arg); @@ -126,7 +143,9 @@ void shared_tile::extract_constant(const indices_t &arg_idx, 
indices_t &non_cst_ } -Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, const std::vector& perm, const std::vector& order, indices_t idx) { +Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, + const std::vector& perm, const std::vector& order, + indices_t idx) { // strides std::vector strides(order.size()); strides[order[0]] = builder.getInt32(1); diff --git a/lib/codegen/transform/coalesce.cc b/lib/codegen/transform/coalesce.cc index 764e2138a..78c03396f 100644 --- a/lib/codegen/transform/coalesce.cc +++ b/lib/codegen/transform/coalesce.cc @@ -1,6 +1,8 @@ #include +#include #include "triton/ir/utils.h" #include "triton/ir/instructions.h" +#include "triton/ir/function.h" #include "triton/ir/module.h" #include "triton/codegen/transform/coalesce.h" #include "triton/codegen/analysis/align.h" @@ -60,8 +62,43 @@ ir::value* coalesce::rematerialize(ir::value *x, ir::builder &builder, } void coalesce::run(ir::module &mod) { - // find values to rematerialize size_t num_groups = layout_->num_layouts(); + + for(size_t id = 0; id < num_groups; id++) { + if(layout_->get(id)->type != analysis::HMMA_884) + continue; + // extract memory stores + const auto& values = layout_->values_of(id); + ir::value* dot = nullptr; + for(ir::value *v: values) + if(auto x = dynamic_cast(v)) + dot = x; + + ir::builder& builder = mod.get_builder(); + std::vector worklist = {dot}; + std::set seen; + while(!worklist.empty()) { + ir::value *current = worklist.back(); + seen.insert(current); + worklist.pop_back(); + // stop if trunc + if(auto x = dynamic_cast(current)){ + builder.set_insert_point_after(x); + ir::recoalesce_inst* rc = ir::recoalesce_inst::create(x); + builder.insert(rc); + x->replace_all_uses_with(rc); + rc->replace_uses_of_with(rc, x); + break; + } + // recurse + for(ir::user *u: current->get_users()) + if(seen.find(u) == seen.end()) + worklist.push_back(u); + } + } + + + // find values to rematerialize std::vector remat; for(size_t id = 0; id < num_groups; id++) { const auto& values = layout_->values_of(id); @@ -71,8 +108,10 @@ void coalesce::run(ir::module &mod) { extract_io_use(v, io); // extract leading axes std::map> axes; - for(ir::io_inst *i: io) - extract_ld(i, axes); + for(ir::io_inst *i: io){ + if(i->get_pointer_operand()->get_type()->get_tile_ranks1() == layout_->get(id)->axes.size()) + extract_ld(i, axes); + } // update list of values to rematerialize if(axes.empty()) continue; diff --git a/lib/codegen/transform/cts.cc b/lib/codegen/transform/cts.cc index 47a1e13a8..f98b685e1 100644 --- a/lib/codegen/transform/cts.cc +++ b/lib/codegen/transform/cts.cc @@ -1,21 +1,37 @@ #include "triton/codegen/transform/cts.h" -#include "triton/codegen/instructions.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" #include "triton/ir/instructions.h" +#include namespace triton { namespace codegen{ namespace transform{ -inline bool is_shared(ir::value *v) { - auto *i = dynamic_cast(v); + +inline bool is_shmem_op(ir::instruction* i, int op) { + if(i->get_id() == ir::INST_DOT) + return op==0 || op==1; + if(i->get_id() == ir::INST_COPY_FROM_SHARED) + return op==0; + return false; +} + +inline bool is_shmem_res(ir::value* v){ + ir::instruction* i = dynamic_cast(v); if(!i) return false; - return storage_info.at(i->get_id()).first == codegen::SHARED; + if(i->get_id() == ir::INST_TRANS) + return true; + if(i->get_id() == ir::INST_REDUCE) + return true; + if(i->get_id() == ir::INST_COPY_TO_SHARED) + 
return true; + return false; } + // run pass on module void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool to_shared) { auto *i = dynamic_cast(x); @@ -36,9 +52,8 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool add_copy(phi, phi->get_incoming_value(i), builder, to_shared); return; } - ir::value_id_t id = i->get_id(); // already in shared memory - if(to_shared && storage_info.at(id).first == SHARED) + if(to_shared && is_shmem_res(i)) return; // copy builder.set_insert_point_after(i); @@ -53,18 +68,19 @@ void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder, bool void cts::run(ir::module &mod) { // Add shared copies ir::builder &builder = mod.get_builder(); - for(ir::function *fn: mod.get_function_list()){ - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()){ - auto storage = storage_info.at(i->get_id()); + for(ir::function* fn: mod.get_function_list()){ + for(ir::basic_block* block: fn->blocks()) + for(ir::instruction* i: block->get_inst_list()){ + size_t num_op = i->get_num_operands(); // copy to shared operands - for(size_t k = 0; k < storage.second.size(); k++) - if(storage.second[k] == SHARED) + for(size_t k = 0; k < num_op; k++) + if(is_shmem_op(i, k)) add_copy(i, i->get_operand(k), builder, true); // copy from shared operands - for(size_t k = 0; k < storage.second.size(); k++) - if(storage.second[k] == DISTRIBUTED && - is_shared(i->get_operand(k))){ + for(size_t k = 0; k < num_op; k++) + if(!dynamic_cast(i) && + !is_shmem_op(i,k) && + is_shmem_res(i->get_operand(k))){ add_copy(i, i->get_operand(k), builder, false); } } diff --git a/lib/codegen/transform/membar.cc b/lib/codegen/transform/membar.cc index 44316d504..8cb48f7df 100644 --- a/lib/codegen/transform/membar.cc +++ b/lib/codegen/transform/membar.cc @@ -3,7 +3,6 @@ #include #include "triton/codegen/analysis/layout.h" #include "triton/codegen/analysis/allocation.h" -#include "triton/codegen/instructions.h" #include "triton/codegen/transform/membar.h" #include "triton/ir/module.h" #include "triton/ir/function.h" diff --git a/lib/driver/module.cc b/lib/driver/module.cc index ddeb20bfc..28940f563 100755 --- a/lib/driver/module.cc +++ b/lib/driver/module.cc @@ -180,6 +180,11 @@ host_module::host_module(driver::context * context, std::unique_ptrengine = builder.create(); } +std::unique_ptr host_module::symbol(const char *name) const { + throw std::runtime_error("not implemented"); +} + + /* ------------------------ */ // OpenCL // /* ------------------------ */ @@ -211,10 +216,21 @@ ocl_module::ocl_module(driver::context * context, std::unique_ptr // } } +std::unique_ptr ocl_module::symbol(const char *name) const { + throw std::runtime_error("not implemented"); +} /* ------------------------ */ // CUDA // /* ------------------------ */ +static bool find_and_replace(std::string& str, const std::string& begin, const std::string& end, const std::string& target){ + size_t start_replace = str.find(begin); + size_t end_replace = str.find(end, start_replace); + if(start_replace == std::string::npos) + return false; + str.replace(start_replace, end_replace + 1 - start_replace, target); + return true; +} std::string cu_module::compile_llvm_module(std::unique_ptr module, driver::device* device) { // options @@ -231,19 +247,17 @@ std::string cu_module::compile_llvm_module(std::unique_ptr module, llvm::SmallVector buffer; module::compile_llvm_module(std::move(module), "nvptx64-nvidia-cuda", sm, "", buffer, "ptx63", Assembly); 
std::string result(buffer.begin(), buffer.end()); - size_t start_replace = result.find(".version"); - size_t end_replace = result.find('\n', start_replace); - assert(start_replace != std::string::npos); - result.replace(start_replace, end_replace - start_replace, ".version 6.4"); + find_and_replace(result, ".version", "\n", ".version 6.4\n"); + while(find_and_replace(result, "\t// begin inline asm", "\n", "")); + while(find_and_replace(result, "\t// end inline asm", "\n", "")); return result; } cu_module::cu_module(driver::context * context, std::unique_ptr ll_module): cu_module(context, compile_llvm_module(std::move(ll_module), context->device())) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ -// exit(EXIT_FAILURE); -// std::cout << source << std::endl; cu_context::context_switcher ctx(*context); +// std::cout << source << std::endl; // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; unsigned int errbufsize = 8096; @@ -260,11 +274,12 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo } } -cu_buffer* cu_module::symbol(const char *name) const{ +std::unique_ptr cu_module::symbol(const char *name) const{ CUdeviceptr handle; size_t size; dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name); - return new cu_buffer(ctx_, size, handle, false); + std::unique_ptr res(new cu_buffer(ctx_, size, handle, false)); + return std::move(res); } diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index a2ea9d30b..b1e417e5f 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -48,6 +48,9 @@ value *builder::get_int32(unsigned val) { return constant_int::get(type::get_int32_ty(ctx_), val); } +type *builder::get_void_ty() +{ return type::get_void_ty(ctx_); } + type *builder::get_int1_ty() { return type::get_int1_ty(ctx_); } @@ -132,19 +135,12 @@ phi_node* builder::create_phi(type *ty, unsigned num_reserved, const std::string return insert(binary_operator::create(OPCODE, lhs, rhs), name);\ } -#define DEFINE_UNARY_FLOAT(SUFFIX)\ - value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\ - return insert(binary_operator::create_ ## SUFFIX(arg), name);\ - } - // Binary DEFINE_BINARY_FLOAT(fmul, binary_op_t::FMul) DEFINE_BINARY_FLOAT(fdiv, binary_op_t::FDiv) DEFINE_BINARY_FLOAT(frem, binary_op_t::FRem) DEFINE_BINARY_FLOAT(fadd, binary_op_t::FAdd) DEFINE_BINARY_FLOAT(fsub, binary_op_t::FSub) -// Unary -DEFINE_UNARY_FLOAT(fneg) //===----------------------------------------------------------------------===// @@ -171,10 +167,7 @@ value* builder::create_insert_nuwnswb_binop(binary_op_t op, value *lhs, return create_insert_nuwnswb_binop(OPCODE, lhs, rhs, name, false, false);\ } -#define DEFINE_UNARY_INT(SUFFIX)\ - value *builder::create_ ## SUFFIX(value *arg, const std::string &name){\ - return insert(binary_operator::create_ ## SUFFIX(arg), name);\ - } + // Binary DEFINE_NOWRAP_BINARY(mul, binary_op_t::Mul) @@ -190,9 +183,6 @@ DEFINE_BINARY_INT(urem, binary_op_t::URem) DEFINE_BINARY_INT(and, binary_op_t::And) DEFINE_BINARY_INT(or, binary_op_t::Or) DEFINE_BINARY_INT(xor, binary_op_t::Xor) -// Unary -DEFINE_UNARY_INT(neg) -DEFINE_UNARY_INT(not) //===----------------------------------------------------------------------===// diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index 0be815a51..930c4a116 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -138,23 +138,23 @@ binary_operator 
*binary_operator::create(binary_op_t op, value *lhs, value *rhs, return new binary_operator(op, lhs, rhs, lhs->get_type(), name, next); } -binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){ - assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty()); - value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()); - return binary_operator::create(binary_op_t::FSub, zero, arg, name, next); -} +//binary_operator *binary_operator::create_fneg(value *arg, const std::string &name, instruction *next){ +// assert(arg->get_type()->get_scalar_ty()->is_floating_point_ty()); +// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()); +// return binary_operator::create(binary_op_t::FSub, zero, arg, name, next); +//} -binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){ - assert(arg->get_type()->get_scalar_ty()->is_integer_ty()); - value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()); - return binary_operator::create(binary_op_t::Sub, zero, arg, name, next); -} +//binary_operator *binary_operator::create_neg(value *arg, const std::string &name, instruction *next){ +// assert(arg->get_type()->get_scalar_ty()->is_integer_ty()); +// value *zero = constant_fp::get_zero_value_for_negation(arg->get_type()->get_scalar_ty()); +// return binary_operator::create(binary_op_t::Sub, zero, arg, name, next); +//} -binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){ - assert(arg->get_type()->is_integer_ty()); - constant *mask = constant::get_all_ones_value(arg->get_type()); - return binary_operator::create(binary_op_t::Xor, arg, mask, name, next); -} +//binary_operator *binary_operator::create_not(value *arg, const std::string &name, instruction *next){ +// assert(arg->get_type()->is_integer_ty()); +// constant *mask = constant::get_all_ones_value(arg->get_type()); +// return binary_operator::create(binary_op_t::Xor, arg, mask, name, next); +//} //===----------------------------------------------------------------------===// // cmp_inst classes @@ -762,6 +762,12 @@ copy_from_shared_inst* copy_from_shared_inst::create(value *arg, const std::stri return new copy_from_shared_inst(arg->get_type(), INST_COPY_FROM_SHARED, arg, name, next); } +// recoalesce +recoalesce_inst* recoalesce_inst::create(value *arg, const std::string &name, instruction *next) { + return new recoalesce_inst(arg->get_type(), INST_RECOALESCE, arg, name, next); +} + + // barrier barrier_inst::barrier_inst(context &ctx, const std::string &name, diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index d13f68856..8bbf39081 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -57,7 +57,10 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { } case Token::MASKED_DEREF: { ir::type* ret_ty = GenIRType(binary->Type(), *ctx_); - return set_ret(bld_->create_masked_load(rhs, lhs, ir::undef_value::get(ret_ty))); + ir::value* false_value = ir::undef_value::get(ret_ty->get_scalar_ty()); + if(ret_ty->is_tile_ty()) + false_value = bld_->create_splat(false_value, ret_ty->get_tile_shapes()); + return set_ret(bld_->create_masked_load(rhs, lhs, false_value)); } case Token::ELLIPSIS: { auto clhs = dynamic_cast(lhs); @@ -76,7 +79,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { return set_ret(bld_->create_add(lhs, rhs)); case '-': if(binary->lhs_->Type()->ToPointer()) - return set_ret(bld_->create_gep(lhs, {bld_->create_neg(rhs)})); + return 
set_ret(bld_->create_gep(lhs, {GenUnaryMinus(rhs)})); else if(flt) return set_ret(bld_->create_fsub(lhs, rhs)); else @@ -147,7 +150,7 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { if(flt) return set_ret(bld_->create_fcmpONE(lhs, rhs)); else - return set_ret(bld_->create_icmpEQ(lhs, rhs)); + return set_ret(bld_->create_icmpNE(lhs, rhs)); default: error_not_implemented(); } @@ -166,6 +169,16 @@ ir::reduce_inst::op_t reduce_op(int tag, bool is_float) { should_not_happen(); return reduce_inst::op_t(); } + +ir::value* Generator::GenUnaryMinus(ir::value* arg) { + ir::type *ty = arg->get_type(); + ir::type *sca_ty = ty->get_scalar_ty(); + ir::value *_0 = ir::constant_fp::get_zero_value_for_negation(sca_ty); + if(ty->is_tile_ty()) + _0 = bld_->create_splat(_0, ty->get_tile_shapes()); + return bld_->create_sub(_0, arg); +} + void Generator::VisitUnaryOp(UnaryOp* unary) { // recursion Visit(unary->operand_); @@ -174,17 +187,17 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { ir::type *arg_scal_ty = arg_ty->get_scalar_ty(); // return switch (unary->op_) { - case Token::PREFIX_INC: return error_not_implemented(); - case Token::PREFIX_DEC: return error_not_implemented(); + case Token::PREFIX_INC: return error_not_implemented(); + case Token::PREFIX_DEC: return error_not_implemented(); case Token::POSTFIX_INC: return error_not_implemented(); case Token::POSTFIX_DEC: return error_not_implemented(); - case Token::ADDR: return error_not_implemented(); - case Token::DEREF: return set_ret(bld_->create_load(arg)); - case Token::PLUS: return error_not_implemented(); - case Token::MINUS: return error_not_implemented(); - case '~': return set_ret(bld_->create_neg(arg)); - case '!': return set_ret(bld_->create_not(arg)); - case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_))); + case Token::ADDR: return error_not_implemented(); + case Token::DEREF: return set_ret(bld_->create_load(arg)); + case Token::PLUS: return error_not_implemented(); + case Token::MINUS: return set_ret(GenUnaryMinus(arg)); + case '~': return error_not_implemented(); + case '!': return error_not_implemented(); + case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_))); case Token::REDUCE: { int ax, tag; UnaryOp::decodeRed(unary->info_, ax, tag); @@ -232,11 +245,54 @@ void Generator::VisitFuncCall(FuncCall* funcCall) { else return should_not_happen(); } + if(name == "get_num_programs"){ + VisitExpr(funcCall->Args()->at(0)); + ir::value* ret = ret_; + if(auto axis = dynamic_cast(ret)) + return set_ret(bld_->create_get_num_program(axis->get_value())); + else + return should_not_happen(); + } + if(name == "atomic_cas"){ + VisitExpr(funcCall->Args()->at(0)); + ir::value* ptr = ret_; + VisitExpr(funcCall->Args()->at(1)); + ir::value* cmp = ret_; + VisitExpr(funcCall->Args()->at(2)); + ir::value* val = ret_; + return set_ret(bld_->create_atomic_cas(ptr, cmp, val)); + } + if(name == "atomic_xchg"){ + VisitExpr(funcCall->Args()->at(0)); + ir::value* ptr = ret_; + VisitExpr(funcCall->Args()->at(1)); + ir::value* val = ret_; + return set_ret(bld_->create_atomic_exch(ptr, val)); + } if(name == "sqrtf"){ VisitExpr(funcCall->Args()->at(0)); ir::value* ret = ret_; return set_ret(bld_->create_sqrt(ret)); } + if(name == "calloc"){ + VisitExpr(funcCall->Args()->at(0)); + ir::value* ret = ret_; + ir::constant_int *size = dynamic_cast(ret); + assert(size); + ir::alloc_const* alloc = new ir::alloc_const(bld_->get_int8_ty(), size); + mod_->add_alloc(alloc); + return set_ret(alloc); + } + //TODO: 
integrate this into conditionalop + if(name == "select"){ + VisitExpr(funcCall->Args()->at(0)); + ir::value* cond = ret_; + VisitExpr(funcCall->Args()->at(1)); + ir::value* true_val = ret_; + VisitExpr(funcCall->Args()->at(2)); + ir::value* false_val = ret_; + return set_ret(bld_->create_select(cond, true_val, false_val)); + } return error_not_implemented(); } @@ -350,12 +406,15 @@ void Generator::VisitForStmt(ForStmt *forStmt) { ir::value *cond = ret_; return bld_->create_cond_br(cond, loop_bb, next_bb); }); - VisitStmt(init_); - VisitExpr(cond_); - ir::value *cond = ret_; - bld_->create_cond_br(cond, loop_bb, next_bb); + if(init_) + VisitStmt(init_); +// VisitExpr(cond_); +// ir::value *cond = ret_; +// bld_->create_cond_br(cond, loop_bb, next_bb); + bld_->create_br(loop_bb); bld_->set_insert_point(loop_bb); - VisitStmt(body_); + if(body_) + VisitStmt(body_); if(!is_terminator(ret_)) mod_->get_continue_fn()(); ir::basic_block *stop_bb = bld_->get_insert_block(); @@ -512,6 +571,8 @@ ir::value* Generator::GenNumcastOp(ir::value*src, ir::type* dst_ty) { else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() && src_scalar_ty->get_integer_bitwidth()) return bld_->create_int_cast(src, dst_ty, dst_signed); + else if(src_scalar_ty->is_pointer_ty() && dst_scalar_ty->is_pointer_ty()) + return bld_->create_cast(ir::BitCast, src, dst_ty); else{ should_not_happen(); return nullptr; @@ -611,6 +672,8 @@ ir::type* Generator::GenIRFuncType(FuncType* type, ir::context& ctx) { ir::type* Generator::GenIRPointerType(PointerType* type, ir::context& ctx) { ir::type* ele_ty = GenIRType(type->Derived().GetPtr(), ctx); unsigned addr_space = 1; + if(type->Derived().IsConstantQualified()) + addr_space = 4; return ir::pointer_type::get(ele_ty, addr_space); } diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index 960c983cf..ae37d9567 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -1083,14 +1083,12 @@ QualType Parser::ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec) *storageSpec |= S_THREAD; break; - case Token::AUTO: - EnsureAndSetStorageSpec(tok, storageSpec, S_AUTO); - break; // Type qualifier case Token::CONST: qualSpec |= Qualifier::CONST; break; case Token::RESTRICT: qualSpec |= Qualifier::RESTRICT; break; case Token::VOLATILE: qualSpec |= Qualifier::VOLATILE; break; + case Token::CMEM: qualSpec |= Qualifier::CMEM; break; // Type specifier case Token::SIGNED: @@ -1551,6 +1549,7 @@ int Parser::ParseQual() { case Token::CONST: qualSpec |= Qualifier::CONST; break; case Token::RESTRICT: qualSpec |= Qualifier::RESTRICT; break; case Token::VOLATILE: qualSpec |= Qualifier::VOLATILE; break; + case Token::CMEM: qualSpec |= Qualifier::CMEM; break; case Token::ATOMIC: Error(tok, "do not support 'atomic'"); break; default: ts_.PutBack(); return qualSpec; } @@ -1769,6 +1768,7 @@ QualType Parser::ParseArrayFuncDeclarator(const Token* ident, QualType base) { if (!base->Complete()) { Error(ident, "'%s' has incomplete element type", ident->str_.c_str()); } + // return a pointer for tiles in constant memory: return TileType::New(shape, base); } else if (ts_.Try('(')) { // Function declaration diff --git a/lib/lang/token.cc b/lib/lang/token.cc index 8b61aa098..c4a95c0c4 100644 --- a/lib/lang/token.cc +++ b/lib/lang/token.cc @@ -7,6 +7,7 @@ static MemPoolImp tokenPool; const std::unordered_map Token::kwTypeMap_ { + { "__constant__", Token::CMEM }, { "__global__", Token::GLOBAL }, { "auto", Token::AUTO }, { "break", Token::BREAK }, diff --git a/lib/lang/type.cc b/lib/lang/type.cc 
index a1564ad97..13d09cf89 100644 --- a/lib/lang/type.cc +++ b/lib/lang/type.cc @@ -294,7 +294,8 @@ std::string ArithmType::Str() const { bool PointerType::Compatible(const Type& other) const { // C11 6.7.6.1 [2]: pointer compatibility auto otherPointer = other.ToPointer(); - return otherPointer && derived_->Compatible(*otherPointer->derived_); + return otherPointer && + derived_->Compatible(*otherPointer->derived_); // FIXME(wgtdkp): cannot loose compatible constraints //return other.IsInteger() || diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 7c6005a56..fe1d77b66 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -184,10 +184,20 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr // kernel uses too much resources if(!bin) return; + // copy constants + std::unique_ptr buffer; + for(ir::alloc_const* alloc: ir->allocs()){ + std::string name = alloc->get_name(); + auto it = cst_.find(name); + if(it == cst_.end()) + throw std::runtime_error("constant not set before execution"); + buffer = bin->symbol(name.c_str()); + stream->write(&*buffer, true, 0, it->second); + } // benchmark ir::function *tmp = ir->get_function_list()[0]; caller call(tmp, std::move(bin), opt); - double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream); + double ts = tools::bench([&]() { call(stream, grid_fn(opt), args); }, stream, true); // save best if(ts < best_ts) { best_ts = ts; @@ -222,20 +232,14 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps); // run passes dce.run(module); -// ir::print(module, std::cout); - disassociate.run(module); - -// ir::print(module, std::cout); - dce.run(module); -// ir::print(module, std::cout); - peephole.run(module); dce.run(module); align.run(module); cts.run(module); axes.run(module); +// ir::print(module, std::cout); layouts.run(module); coalesce.run(module); dce.run(module); @@ -246,17 +250,19 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c dce.run(module); align.run(module); axes.run(module); +// ir::print(module, std::cout); layouts.run(module); liveness.run(module); allocation.run(module); if(allocation.allocated_size() > context->device()->max_shared_memory()) return std::unique_ptr(); barriers.run(module); -// std::cout << "isel" << std::endl; +// ir::print(module, std::cout); isel.visit(module, *llvm); // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm))); // done +// exit(EXIT_FAILURE); return res; } @@ -273,8 +279,13 @@ R"( #define __aligned(A) __attribute__((aligned(A))) #define __multipleof(A) __attribute__((multipleof(A))) +extern int atomic_cas(int*, int, int); +extern int atomic_xchg(int*, int); extern int get_program_id(int); +extern int get_num_programs(int); extern float sqrtf(float); +extern int select(bool, int, int); +extern char __constant__ * calloc(int); )"; } @@ -316,5 +327,9 @@ void function::operator()(const std::vector& args, const grid_t& grid, driv return this->operator()(args, [&grid](const options_t&){ return grid; }, stream); } +void function::set_cst(const std::string& name, void* data, size_t n_bytes) { + cst_[name] = std::vector((char*)data, (char*)data + n_bytes); +} + } } diff --git a/python/examples/blocksparse.py b/python/examples/blocksparse.py deleted file mode 100644 index 7d15fc4f4..000000000 --- a/python/examples/blocksparse.py +++ /dev/null @@ -1,157 +0,0 @@ -import tensorflow as tf 
-import triton -import numpy as np - -src = ''' - #if AT == 1 - #define USE_A ^a - #define STRIDE_AK lda - #define STRIDE_AM 1 - #define BROADCAST_AK :, newaxis - #define BROADCAST_AM newaxis, : - #define SHAPE_A TK, TM - #else - #define USE_A a - #define STRIDE_AK 1 - #define STRIDE_AM lda - #define BROADCAST_AK newaxis, : - #define BROADCAST_AM :, newaxis - #define SHAPE_A TM, TK - #endif - - #if BT == 1 - #define USE_B ^b - #define STRIDE_BK 1 - #define STRIDE_BM ldb - #define BROADCAST_BN newaxis, : - #define BROADCAST_BK :, newaxis - #define SHAPE_B TN, TK - #else - #define USE_B b - #define STRIDE_BK ldb - #define STRIDE_BM 1 - #define BROADCAST_BN :, newaxis - #define BROADCAST_BK newaxis, : - #define SHAPE_B TK, TN - #endif - - void dot (TYPE* A __readonly __noalias __align(16), - TYPE* B __readonly __noalias __align(16), - TYPE* C __writeonly __noalias __align(16), - int lda, int ldb, int ldc, - int N, int* lut, - int* locks, int nlocks) { - int ridx = get_program_id(0); - float c[TM, TN] = 0; - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... TK; - // load LUT header - int *header = lut + get_program_id(1) * 4; - int offset = *(header + 0); - int K = *(header + 1); - int column = *(header + 2); - int lockid = *(header + 3); - int *plut = lut + offset * 2; - int offx = ridx; - int offy = 0; - // compute x, y offsets - int rxa[TM] = offx * TM + (0 ... TM); - int ryb[TN] = offy * TN + (0 ... TN); - // bounds checking - bool checka[SHAPE_A] = (rxa < N)[:, newaxis]; - bool checkb[SHAPE_B] = 1; - // base offset - int offa[SHAPE_A] = rxa[BROADCAST_AM] * STRIDE_AM + rka[BROADCAST_AK] * STRIDE_AK; - int offb[SHAPE_B] = ryb[BROADCAST_BN] * STRIDE_BN + rkb[BROADCAST_BK] * STRIDE_BK; - for(int k = K; k > 0; k -= 1) { - // fetch block indices - int ak = *(plut + 0); - int bk = *(plut + 1); - lut += 2; - // compute pointers to blocks - TYPE* pa[SHAPE_A] = A + offa + ak * TK * lda; - TYPE* pb[SHAPE_B] = B + offb + bk * TK * TN; - // load blocks - TYPE a[SHAPE_A] = checka ? *pa : 0; - TYPE b[SHAPE_B] = *pb; - // multiply blocks - c += USE_A @ USE_B; - } - int rxc[TM] = ridx * TM + (0 ... TM); - int ryc[TN] = column * TN + (0 ... TN); - TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc; - bool checkc[TM, TN] = (rxc < N)[:, newaxis]; - if(lockid == 0) { - *?(checkc) pc = c; - } - else { - int *plock = locks + ridx*nlocks + lockid - 1; - int *pcount = plock + get_num_program(0)*nlocks; - while(atomic_cas(plock, 0, 1)); - int count = *pcount; - if(count == 0) - *?(checkc) pc = c; - else - *?(checkc) pc = c + *pc; - atomic_exch(pcount, 1); - atomic_exch(plock, 0); - } - } -''' - -# std::string dot::triton_c_src_dw() const { -# bool AT = (op_ == WGRAD); -# bool BT = (op_ == FPROP); -# std::string usea = AT ? "trans(a)" : "a"; -# std::string useb = BT ? "trans(b)" : "b"; -# std::string sizea = AT ? "TK, TM" : "TM, TK"; -# std::string sizeb = BT ? "TN, TK" : "TK, TN"; -# std::string bca0 = AT ? "newaxis, :" : ":, newaxis"; -# std::string bca1 = AT ? ":, newaxis" : "newaxis, :"; -# std::string bcb0 = BT ? ":, newaxis" : "newaxis, :"; -# std::string bcb1 = BT ? "newaxis, :" : ":, newaxis"; -# std::string lda0 = AT ? "*lda" : ""; -# std::string lda1 = AT ? "" : "*lda"; -# std::string ldb0 = BT ? "" : "*ldb"; -# std::string ldb1 = BT ? 
"*ldb" : "" ; -# std::string result = -# R"( -# const tunable int TM = {)" + std::to_string(BS_) + R"(}; -# const tunable int TN = {)" + std::to_string(BS_) + R"(}; -# const tunable int TK = {32}; -# void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A, -# restrict read_only align(16) )" + ab_ty_ + R"( *B, -# )" + c_ty_ + R"(* C, -# int lda, int ldb, int ldc, -# int N, int* lut, -# int* locks, int nlocks) { -# int ridx = get_range_id(0); -# float acc[TM, TN] = 0; -# int rka[TK] = 0 ... TK; -# int rkb[TK] = 0 ... TK; -# int *header = lut + ridx * 2; -# int offx = *(header + 0); -# int offy = *(header + 1); -# int rxa[TM] = offx*TM + (0 ... TM); -# int ryb[TN] = offy*TN + (0 ... TN); -# bool checka[TK, TM] = (rka < N)[:, newaxis]; -# bool checkb[TK, TN] = (rkb < N)[:, newaxis]; -# int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(; -# int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(; -# )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa; -# )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb; -# )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0; -# )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0; -# for(int k = N; k > 0; k = k - TK) { -# acc = dot()" + usea + ", " + useb + R"(, acc); -# pa = pa + TK)" + lda1 + R"(; -# pb = pb + TK)" + ldb1 + R"(; -# a = checka ? *pa : 0; -# b = checkb ? *pb : 0; -# } -# int rxc[TM] = (0 ... TM); -# int ryc[TN] = (0 ... TN); -# )" + c_ty_ + R"( c[TM, TN] = acc; -# )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN; -# *pc = c; -# })"; \ No newline at end of file diff --git a/python/examples/conv.py b/python/examples/conv.py deleted file mode 100644 index 43f0f5d91..000000000 --- a/python/examples/conv.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -import triton - -N, C, K = 32, 8, 32 -H, W = 16, 16 -R, S = 3, 3 -torch.manual_seed(0) -a = torch.randn(N, C, H, W).cuda() -b = torch.ones(C, R, S, K).cuda() - -rc = torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2)) -tc = triton.ops.conv(a, b) -print((rc - tc).abs().max()) -#print((rc[:30,:30,:,:] - tc[:30, :30, :, :]).abs().max()) -#print(tc[31, 31,:,:]) \ No newline at end of file diff --git a/python/examples/dot.py b/python/examples/dot.py deleted file mode 100644 index dfc2587f2..000000000 --- a/python/examples/dot.py +++ /dev/null @@ -1,71 +0,0 @@ -import numpy as np -import triton - -def run_tf(): - M, N, K = 2048, 2048, 2048 - a = tf.placeholder(tf.float32, shape=[M, K]) - b = tf.placeholder(tf.float32, shape=[N, K]) - triton_c = triton.ops.dot(a, b, False, True, 1) - triton_d = triton.ops.dot(triton_c, b, True, False, 1) - triton_y = tf.math.reduce_mean(triton_d) - fw_c = tf.matmul(a, b, False, True) - fw_d = tf.matmul(fw_c, b, True, False) - fw_y = tf.math.reduce_mean(fw_d) - # Gradient - triton_da, triton_db = tf.gradients(triton_y, [a, b]) - fw_da, fw_db = tf.gradients(fw_y, [a, b]) - # Reference - feed_dict = {a: np.random.rand(M, K).astype(np.float32), - b: np.random.rand(K, N).astype(np.float32)} - sess = tf.InteractiveSession() - sess.run(tf.global_variables_initializer()) - result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict = feed_dict) - triton_da, fw_da = result[0][0], result[1][0] - triton_db, fw_db = result[2][0], result[3][0] - # Benchmark - nanosec = triton.bench_registry[triton_d] - print('TFLOPS:', 2. 
* M * N * K / nanosec * 1e-3) - print('Diff DA:', (triton_da - fw_da).max()) - print('Diff DB:', (triton_db - fw_db).max()) - - -def run_torch(): - torch.manual_seed(0) - M, N, K = 2048, 2048, 2048 - a = torch.randn(M, K).cuda() - b = torch.randn(K, N).cuda() - a.requires_grad_(True) - b.requires_grad_(True) - torch_c = torch.matmul(a, torch.t(b)) - torch_d = torch.matmul(torch.t(torch_c), b) - torch_y = torch.mean(torch_d) - triton_c = triton.ops.dot(a, b, False, True, 1) - triton_d = triton.ops.dot(triton_c, b, True, False, 1) - triton_y = torch.mean(triton_d) - # torch gradient - torch_y.backward() - torch_da = a.grad.clone() - torch_db = b.grad.clone() - # triton gradient - a.grad.zero_() - b.grad.zero_() - triton_y.backward() - triton_da = a.grad.clone() - triton_db = b.grad.clone() - - #nanosec = triton.bench_registry[triton_d] - #print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3) - print('Diff DA:', (torch_da - triton_da).max()) - print('Diff DB:', (torch_db - triton_db).max()) - -try: - import tensorflow as tf - run_tf() -except ModuleNotFoundError: - pass - -try: - import torch - run_torch() -except ModuleNotFoundError: - pass diff --git a/python/examples/einsum.py b/python/examples/einsum.py index 8c3327e5a..2cbf2ca10 100644 --- a/python/examples/einsum.py +++ b/python/examples/einsum.py @@ -1,92 +1,194 @@ -#!/usr/bin/env python - -import numpy as np -from enum import Enum import triton +import torch +from torch.utils.cpp_extension import load +import numpy as np +#import utils +from time import time -class MODE(Enum): - TF = 1 - TORCH = 2 +#torch.backends.cudnn.benchmark = True -try: - import tensorflow as tf - mode = MODE.TF -except ModuleNotFoundError: - pass +configs = [] -try: - import torch - mode = MODE.TORCH -except ModuleNotFoundError: - pass +# Matrix multiplication +MNK = [ + (512, 512 ,512), + (2048, 2048, 2048), + (8192, 8192, 8192), + + (64, 64, 64000), + (64, 64, 128000), + (256, 256, 64000), + (256, 256, 128000), -cases = [] -# Matmul -cases += [[[4, 1024, 1024], [1024, 1024], [4, 1024, 1024], "btc,ck->btk"]] -# Attention -# cases += [[[4, 256, 8, 2, 64], [8, 2, 512, 64], [4, 256, 8, 2, 512], "bchak,hank->bchan"]] + (1536, 16, 1536), + (1536, 32, 1536), + (1536, 64, 1536), + (1536, 128, 1536), + (4096, 16, 4096), + (4096, 32, 4096), + (4096, 64, 4096), + (4096, 128, 4096), + + #(127008, 768, 576) + ] +for M, N, K in MNK: + matmul = lambda a, b: torch.matmul(a, b) + configs += [([M, K], [K, N], [M, N], matmul, 'mk,kn->mn', dict())] +for M, N, K in MNK: + matmul = lambda a, b: torch.matmul(a.t(), b) + configs += [([M, K], [M, N], [K, N], None, 'mk,mn->kn', dict())] +for M, N, K in MNK: + matmul = lambda a, b: torch.matmul(a, b.t()) + configs += [([M, N], [K, N], [M, K], None, 'mn,kn->mk', dict())] -if mode == MODE.TF: - sess = tf.InteractiveSession() +# Relative attention +NTHSE = [ + #(16, 512, 1, 64, 64), + # (16, 512, 1, 128, 128), + # (16, 512, 1, 256, 256), + # (16, 512, 1, 256, 512), + #(16, 512, 8, 64, 64), + # (16, 512, 8, 128, 128), + # (16, 512, 8, 256, 256), + # (16, 512, 8, 256, 512), -for a_shape, b_shape, c_shape, einsum in cases: + # (64, 1024, 1, 64, 64), + #(64, 1024, 1, 128, 128), + # (64, 1024, 1, 256, 256), + # (64, 1024, 1, 256, 512), + # (64, 1024, 8, 64, 64), + #(64, 1024, 8, 128, 128), + # (64, 1024, 8, 256, 256), + # (64, 1024, 8, 256, 512), - A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32) - B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32) - E = np.random.uniform(-1.0, 1.0, 
c_shape).astype(np.float16).astype(np.float32) + # (128, 1024, 1, 64, 64), + # (128, 1024, 1, 128, 128), + # (128, 1024, 1, 256, 256), + #(128, 1024, 1, 256, 512), + # (128, 1024, 8, 64, 64), + # (128, 1024, 8, 128, 128), + # (128, 1024, 8, 256, 256), + #(128, 1024, 8, 256, 512) + ] +for N, T, H, S, E in NTHSE: + configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())] +for N, T, H, S, E in NTHSE: + configs += [([N, H, T, E], [N, T, H, S], [H, E, S], None, 'nhte,nths->hes', dict())] +for N, T, H, S, E in NTHSE: + configs += [([N, H, T, E], [H, E, S], [N, T, H, S], None, 'nhte,hes->nths', dict())] - # Execute (tensorflow) - if mode == MODE.TF: - a = tf.placeholder(tf.float32, a_shape, name="a") - b = tf.placeholder(tf.float32, b_shape, name="b") - e = tf.placeholder(tf.float32, c_shape, name="e") - c = triton.ops.einsum(einsum, a, b, 1) - da, db = tf.gradients(c, [a, b], e) - feed_dict = { a: A.astype(np.float32), - b: B.astype(np.float32), - e: E } - sess.run(tf.global_variables_initializer()) - result = sess.run([c, da, db], feed_dict = feed_dict) - # Execute (torch) - if mode == MODE.TORCH: - a = torch.from_numpy(A).cuda() - b = torch.from_numpy(B).cuda() - e = torch.from_numpy(E).cuda() - a.requires_grad_(True) - b.requires_grad_(True) - c = triton.ops.einsum(einsum, a, b, 1) - torch.autograd.backward(c, e) - da = a.grad - db = b.grad - result = [c.cpu().detach().numpy(), da.cpu().detach().numpy(), db.cpu().detach().numpy()] - - # benchmark - nanosec = triton.bench_registry[c] - ctx = triton.ctx_registry[c] - b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4))) - ops = 2.*b*m*n*k - print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3) - #print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3) - #print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3) +# 1D Dense convolution +NCHKR = [ + # (1, 1152, 12602, 512, 3) + ] +for N, C, H, K, R in NCHKR: + torch_fn = lambda a, b: torch.nn.functional.conv1d(a, b.permute(2, 0, 1)) + configs += [([N, C, H], + [C, R, K], + [N, K, H - R + 1], + torch_fn, + 'nc(h+r),crk->nkh', + dict())] - # test - ctx = triton.ctx_registry[c] - t_a = ctx.trans_a - t_b = ctx.trans_b - e_a = ctx.einsum_a - e_b = ctx.einsum_b - e_c = ctx.einsum_c - C = np.einsum(einsum, A, B) - if not t_a and not t_b: # NN - DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B) - DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E) - elif not t_a and t_b: # NT - DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B) - DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A) - elif t_a and not t_b: # TN - DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E) - DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E) - c, da, db = result[0], result[1], result[2] - print('C diff:', np.abs((C - c)).max()) - print('DA diff:', np.abs((DA - da)).max()) - print('DB diff:', np.abs((DB - db)).max()) \ No newline at end of file +# 2D Dense convolution +NCHWKRS = [ + #(8, 64, 128, 128, 768, 3, 3), + #(8, 128, 64, 64, 256, 3, 3), + #(8, 256, 32, 32, 512, 3, 3), + #(8, 512, 32, 32, 1024, 3, 3) + ] +for N, C, H, W, K, R, S in NCHWKRS: + torch_fn = lambda a, b: torch.nn.functional.conv2d(a, b.permute(3, 0, 1, 2)) + configs += [([N, C, H, W], + [C, R, S, K], + [N, K, H - R + 1, W - R + 1], + torch_fn, + 'nc(h+r)(w+s),crsk->nkhw', + dict())] + +# 3D Dense Convolution +NCDHWKTRS = [ + #(8, 32, 27, 100, 100, 64, 3, 3, 3), + #(8, 64, 23, 48, 48, 256, 3, 3, 3), + #(8, 256, 19, 22, 22, 640, 3, 3, 3), + #(8, 640, 15, 36, 36, 384, 3, 3, 3) + ] +for N, C, D, H, W, K, T, R, S in NCDHWKTRS: + torch_fn = lambda a, b: 
torch.nn.functional.conv3d(a, b.permute(4, 0, 1, 2, 3)) + configs += [([N, C, D, H, W], + [C, T, R, S, K], + [N, K, D - T + 1, H - R + 1, W - R + 1], + torch_fn, + 'nc(d+t)(h+r)(w+s),ctrsk->nkdhw', + dict())] + + +# Shift convolution +shift_cuda = torch.utils.cpp_extension.load( + 'shift_cuda', ['kernels/shift_cuda.cpp', + 'kernels/shift_cuda_kernel.cu'], + extra_cflags=['-O3']) +class shift(torch.autograd.Function): + @staticmethod + def forward(ctx, x, shift): + ctx.save_for_backward(shift) + return shift_cuda.forward(x, shift) + + @staticmethod + def backward(ctx, grad_output): + shift, = ctx.saved_tensors + grad_output = shift_cuda.backward(grad_output, shift) + + return grad_output, None + +NCHWKRS = [ + #(8, 64, 128, 128, 128, 3, 3), + #(8, 128, 64, 64, 256, 3, 3), + #(8, 256, 32, 32, 512, 3, 3), + #(8, 512, 32, 32, 1024, 3, 3) + ] +for N, C, H, W, K, R, S in NCHWKRS: + shift_h = np.random.randint(R, size=C, dtype=np.int32) - R//2 + shift_w = np.random.randint(S, size=C, dtype=np.int32) - S//2 + def shift_conv(a, b, **kwargs): + shift_h, shift_w = kwargs['sh'], kwargs['sw'] + shift_torch = np.column_stack((shift_w*-1, shift_h*-1)) + shift_torch = torch.from_numpy(shift_torch).cuda() + a = shift.apply(a, shift_torch) + b = b.permute(1, 0) + b = b.reshape(b.shape[0], b.shape[1], 1, 1) + return torch.nn.functional.conv2d(a, b) + configs += [([N, C, H, W], + [C, K], + [N, K, H, W], + shift_conv, + 'nc(h + sh[c])(w + sw[c]),ck->nkhw', + {'sh': shift_h, 'sw': shift_w})] + +# Benchmark +torch.set_num_threads(1) +for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs: + dtype = torch.cuda.HalfTensor + # initialize input tensors + a = torch.rand(*a_shape).type(dtype).cuda() + b = torch.rand(*b_shape).type(dtype).cuda() + # triton output + #ta = triton.ops._einsum.pad(a, [4,4,4,4]) + tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True) + # reference output + if torch_fn: + rc = torch_fn(a, b, **arrays) + else: + rc = torch.einsum(expr, a, b) + # performance relative to equivalent matrix multiplication + ctx = triton.ctx_registry[tc] + B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K + # a = torch.rand(B, M, K).type(dtype).cuda() + # b = torch.rand(B, K, N).type(dtype).cuda() + # tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True) + # ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc] + ratio = 0 + # test and benchmark + bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3 + diff = (tc - rc).abs().max() / rc.abs().max() + print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} ({ratio:4.2f}); {diff:4.2f}') diff --git a/python/examples/kernels/shift_cuda.cpp b/python/examples/kernels/shift_cuda.cpp new file mode 100644 index 000000000..b7a769feb --- /dev/null +++ b/python/examples/kernels/shift_cuda.cpp @@ -0,0 +1,42 @@ +#include + +#include + +// CUDA forward declarations + +at::Tensor shift_cuda_forward( + const at::Tensor input, + const at::Tensor shift); + +at::Tensor shift_cuda_backward( + const at::Tensor grad_input, + const at::Tensor shift); + +// C++ interface + +// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
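+// The three macros below are the standard PyTorch C++ extension input checks:
+// CHECK_CUDA asserts that a tensor lives on the GPU, CHECK_CONTIGUOUS asserts a
+// contiguous (dense) memory layout, and CHECK_INPUT applies both before the
+// tensors are handed to the CUDA kernels in shift_cuda_kernel.cu.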
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +at::Tensor shift_forward( + const at::Tensor input, + const at::Tensor shift) { + CHECK_INPUT(input); + CHECK_INPUT(shift); + + return shift_cuda_forward(input, shift); +} + +at::Tensor shift_backward( + const at::Tensor grad_input, + const at::Tensor shift) { + CHECK_INPUT(grad_input); + CHECK_INPUT(shift); + return shift_cuda_backward(grad_input, shift); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &shift_forward, "Shift forward (CUDA)"); + m.def("backward", &shift_backward, "Shift backward (CUDA)"); +} diff --git a/python/examples/kernels/shift_cuda_kernel.cu b/python/examples/kernels/shift_cuda_kernel.cu new file mode 100644 index 000000000..ca56b6b0f --- /dev/null +++ b/python/examples/kernels/shift_cuda_kernel.cu @@ -0,0 +1,111 @@ +#include + +#include +#include + +#include + +namespace { +template +__global__ void shift_cuda_forward_kernel( + const scalar_t* __restrict__ input, + const int32_t* __restrict__ shift, + scalar_t* __restrict__ output, + const int32_t B, + const int32_t C, + const int32_t H, + const int32_t W) { + const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int32_t size = B*C*H*W; + + const int32_t CHW = C*H*W; + const int32_t HW = H*W; + const int32_t b = idx / CHW; + const int32_t c = (idx - b*CHW) / HW; + const int32_t h = (idx - b*CHW - c*HW) / W; + const int32_t w = idx - b*CHW - c*HW - h*W; + const int32_t target_w = w + shift[2*c]; + const int32_t target_h = h + shift[2*c + 1]; + const int32_t target_idx = b*CHW + c*HW + target_h*W + target_w; + if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) { + output[target_idx] = input[idx]; + } +} + +template +__global__ void shift_cuda_backward_kernel( + const scalar_t* __restrict__ grad_input, + scalar_t* __restrict__ grad_output, + const int32_t* __restrict__ shift, + const int32_t B, + const int32_t C, + const int32_t W, + const int32_t H) { + const int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int32_t size = B*C*W*H; + const int32_t CWH = C*W*H; + const int32_t WH = W*H; + const int32_t b = idx / CWH; + const int32_t c = (idx - b*CWH) / WH; + const int32_t w = (idx - b*CWH - c*WH) / W; + const int32_t h = idx - b*CWH - c*WH - w*H; + const int32_t target_w = w - shift[2*c]; + const int32_t target_h = h - shift[2*c + 1]; + const int32_t target_idx = b*CWH + c*WH + target_w*W + target_h; + if (idx < size && target_w >= 0 && target_w < W && target_h >= 0 && target_h < H) { + grad_output[target_idx] = grad_input[idx]; + } +} +} // namespace + +at::Tensor shift_cuda_forward( + const at::Tensor input, + const at::Tensor shift) { + const auto B = input.size(0); + const auto C = input.size(1); + const auto H = input.size(2); + const auto W = input.size(3); + const auto size = B*C*W*H; + const int threads = 1024; + const int blocks = (size + threads - 1) / threads; + auto output = at::zeros_like(input); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "shift_forward_cuda", ([&] { + shift_cuda_forward_kernel<<>>( + input.data(), + shift.data(), + output.data(), + B, + C, + H, + W); + })); + + return output; +} + +at::Tensor shift_cuda_backward( + const at::Tensor grad_input, + const at::Tensor shift) { + const auto B = grad_input.size(0); + const auto C = grad_input.size(1); + const auto H = 
grad_input.size(2); + const auto W = grad_input.size(3); + const auto size = B*C*W*H; + const int threads = 1024; + const int blocks = (size + threads - 1) / threads; + auto grad_output = at::zeros_like(grad_input); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_input.type(), "shift_backward_cuda", ([&] { + shift_cuda_backward_kernel<<>>( + grad_input.data(), + grad_output.data(), + shift.data(), + B, + C, + H, + W); + })); + + return grad_output; +} diff --git a/python/setup.py b/python/setup.py index 4c8d38259..060a1c450 100644 --- a/python/setup.py +++ b/python/setup.py @@ -77,7 +77,7 @@ class CMakeBuild(build_ext): pass cfg = 'Debug' if self.debug else 'Release' - #cfg = 'Release' + cfg = 'Release' build_args = ['--config', cfg] if platform.system() == "Windows": diff --git a/python/src/bindings.cc b/python/src/bindings.cc index 7fb4a29f0..8b3ee2971 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -48,6 +49,11 @@ void delete_fn(size_t id) { id_fn_map.erase(id); } +void register_cst(size_t id, const std::string& name, pybind11::buffer& data) { + pybind11::buffer_info info = data.request(); + id_fn_map[id]->set_cst(name, info.ptr, info.size*info.itemsize); +} + void cleanup() { id_grid_map.clear(); id_fn_map.clear(); @@ -508,7 +514,8 @@ void gen_torch_make_handles(std::ostream &os, os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl; else{ os << " CHECK_INPUT(" << name << ");" << std::endl; - os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage().data(), false);" << std::endl; + os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), " + " (CUdeviceptr)((char*)" + name + ".storage().data() + " + name + ".storage_offset() * " + name + ".itemsize()), false);" << std::endl; } } } @@ -526,8 +533,8 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector 0)\n "; os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n "; } @@ -562,18 +569,15 @@ std::tuple 0: - bench_registry[ret] = libtriton.retrieve_scalar(op_id) + bench_registry[ret] = libtriton.retrieve_scalar(bench_id) else: assert False \ No newline at end of file diff --git a/python/triton/ops/conv.py b/python/triton/ops/conv.py index 8a2678f2a..8bd0acbd3 100644 --- a/python/triton/ops/conv.py +++ b/python/triton/ops/conv.py @@ -38,25 +38,19 @@ void convnd(A_TYPE *A, int rah[TM] = rabh % CH; rah = rah * UPAW - off_uah; raw = raw * UPAH - off_uaw; - int racr[TK] = rk / BW; - int ras[TK] = rk % BW; - int rac[TK] = racr / BH; - int rar[TK] = racr % BH; - rar = UPAR * rar; - ras = UPAS * ras; int ram[TM] = rab*lda_n + rah*lda_h + raw*lda_w; - int rak[TK] = rac*lda_c + rar*lda_h + ras*lda_w; + int rak[TK] = *(ADELTA + rk); A_TYPE* pa[TM, TK] = A + ram[:, newaxis] + rak[newaxis, :]; // pointers for B int rbk[TK] = rk; int rbn[TN] = ryb; - B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_s; + B_TYPE* pb[TK, TN] = B + rbn[newaxis, :] * ldb_k + rbk[:, newaxis] * ldb_c; // pointers for A look-up table int rklut[TK] = rk % LUT_SIZE; int* padiff[TK] = ADIFF + rklut; - int* padelta[TK] = ADELTA + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w; + int* padelta[TK] = ADELTA + TK + rklut + off_uw * LUT_SIZE + off_uh * LUT_SIZE * upsample_w; int adiff[TK] = *padiff; int adelta[TK] = *padelta; @@ -66,7 +60,7 @@ void convnd(A_TYPE *A, for(int k = K; k > 0; k = k - TK){ c += a @ b; pa += adelta[newaxis, :]; 
- pb += TK * ldb_s; + pb += TK * ldb_c; // increment A look-up table padelta = padelta + adiff; adelta = *padelta; @@ -99,29 +93,54 @@ void convnd(A_TYPE *A, kernel = triton.kernel(src, ['C']) @staticmethod - def _unpack(idx, D, H, W): - cdh = idx // W - w = idx % W - cd = cdh // H - h = cdh % H - c = cd // D - d = cd % D - return c, d, h, w + def _unpack(idx, order, shape_b): + _123 = idx // shape_b[order[0]] + _0 = idx % shape_b[order[0]] + _23 = _123 // shape_b[order[1]] + _1 = _123 % shape_b[order[1]] + _3 = _23 // shape_b[order[2]] + _2 = _23 % shape_b[order[2]] + return _0, _1, _2, _3 @staticmethod - def _delta_a(upsample_d, upsample_h, upsample_w, depth, TK, - T, R, S, stride_a): + def _roundup(x, div): + return (x + div - 1) // div * div + + @staticmethod + def _delta_a(upsample_d, upsample_h, upsample_w, + bc, bd, bh, bw, + ac, ad, ah, aw, + stride_a, shape_b, + TK): + # Parse the axes so that the reduction is done + # from the innermost dimension outward + order = sorted([bc, bd, bh, bw], reverse = True) + c, d, h, w = [order.index(x) for x in [bc, bd, bh, bw]] + # Size of the lookup table is the product of the 3 innermost dimensions + K = _conv._roundup(TK, shape_b[order[0]] * shape_b[order[1]] * shape_b[order[2]]) + # Allocate temporary arrays ud = np.arange(upsample_d, dtype=np.int32)[:, np.newaxis, np.newaxis, np.newaxis] uh = np.arange(upsample_h, dtype=np.int32)[np.newaxis, :, np.newaxis, np.newaxis] uw = np.arange(upsample_w, dtype=np.int32)[np.newaxis, np.newaxis, :, np.newaxis] - ctrs = np.arange(depth, dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :] - c, t, r, s = _conv._unpack(ctrs, T, R, S) - nextc, nextt, nextr, nexts = _conv._unpack(ctrs + TK, T, R, S) - cdiff = nextc - c - tdiff = nextt - t - rdiff = nextr - r - sdiff = nexts - s - return cdiff*stride_a[1] + tdiff*stride_a[2] + rdiff*stride_a[3] + sdiff*stride_a[4] + k = np.arange(K , dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :] + # Find reduction indices at the current and next reduction indices + currentk = _conv._unpack(k , order, shape_b) + nextk = _conv._unpack(k + TK, order, shape_b) + # Compute memory stride + result = 0 + result += (nextk[c] - currentk[c]) * stride_a[ac] + result += (nextk[d] - currentk[d]) * stride_a[ad] + result += (nextk[h] - currentk[h]) * stride_a[ah] + result += (nextk[w] - currentk[w]) * stride_a[aw] + # Initial k + ki = np.arange(TK , dtype=np.int32)[np.newaxis, np.newaxis, np.newaxis, :] + currentk = _conv._unpack(ki, order, shape_b) + resulti = 0 + resulti += currentk[c] * stride_a[ac] + resulti += currentk[d] * stride_a[ad] + resulti += currentk[h] * stride_a[ah] + resulti += currentk[w] * stride_a[aw] + return np.concatenate((resulti, result), axis=-1) @staticmethod def _extract_strides(shape): @@ -134,38 +153,56 @@ void convnd(A_TYPE *A, @staticmethod def _call(a, b, - upsample_d, upsample_h, upsample_w, pad_d, pad_h, pad_w, - stride_d, stride_h, stride_w, - mode): + stride_d, stride_h, stride_w, + upsample_d, upsample_h, upsample_w, + a_layout, b_layout, c_layout): # input shapes shape_a = list(triton.shape(a)) shape_b = list(triton.shape(b)) - # add depth - shape_a.insert(2, 1) - shape_b.insert(1, 1) - NB, NC, AD, AH, AW = shape_a - NC, BD, BH, BW, NF = shape_b + dim = len(shape_a) - 2 + # indices + an, ac, ad, ah, aw = [a_layout.find(x) for x in 'ncdhw'] + bk, bc, bd, bh, bw = [b_layout.find(x) for x in 'kctrs'] + cn, ck, cd, ch, cw = [c_layout.find(x) for x in 'nkdhw'] + # extract shapes + if dim == 2: + shape_a.insert(ad, 1) + if dim == 2: + 
shape_b.insert(bd, 1) # output shape - CD = (AD*upsample_d - BD + 1 + 2*pad_d + stride_d - 1) // stride_d - CH = (AH*upsample_h - BH + 1 + 2*pad_h + stride_h - 1) // stride_h - CW = (AW*upsample_w - BW + 1 + 2*pad_w + stride_w - 1) // stride_w - shape_c = [NB, NF, CD, CH, CW] + shape_c = [0] * 5 + shape_c[cn] = shape_a[an] + shape_c[ck] = shape_b[bk] + shape_c[cd] = (shape_a[ad]*upsample_d - shape_b[bd] + 1 + 2*pad_d + stride_d - 1) // stride_d + shape_c[ch] = (shape_a[ah]*upsample_h - shape_b[bh] + 1 + 2*pad_h + stride_h - 1) // stride_h + shape_c[cw] = (shape_a[aw]*upsample_w - shape_b[bw] + 1 + 2*pad_w + stride_w - 1) // stride_w # strides stride_a = _conv._extract_strides(shape_a) stride_b = _conv._extract_strides(shape_b) stride_c = _conv._extract_strides(shape_c) - # look-up tables + # tiling parameters + TM = [32] + TN = [32] TK = 8 - FS = BD * BH * BW - depth = (TK + FS - 1)//FS * FS + # pointer deltas for a delta_a = _conv._delta_a(upsample_d, upsample_h, upsample_w, - depth, TK, BD, BH, BW, stride_a) + bc, bd, bh, bw, + ac, ad, ah, aw, + stride_a, shape_b, + TK) delta_a = triton.fw.torch.from_numpy(delta_a).cuda() - inc_a = np.arange(depth, dtype=np.int32) - inc_a = ((inc_a + TK) % depth) - inc_a + # delta increments for a + inc_a = np.arange(delta_a.shape[-1] - TK, dtype=np.int32) + inc_a = ((inc_a + TK) % inc_a.size) - inc_a inc_a = triton.fw.torch.from_numpy(inc_a).cuda() - + # allocate output + if dim == 2: + shape_c.pop(cd) + c = triton.empty(shape_c, dtype=a.dtype) + if dim == 2: + shape_c.insert(cd, 1) + # execute kernel trans_b = False is_wgrad = False is_blut = False @@ -174,31 +211,99 @@ void convnd(A_TYPE *A, 'UPAS': 'stride_w' if is_wgrad else '1', 'UPAH': '' if is_wgrad else 'stride_h', 'UPAW': '' if is_wgrad else 'stride_w', - 'LUT_SIZE': depth, - 'TM': [32], - 'TN': [32], - 'TK': TK, - 'A_TYPE': 'float', - 'B_TYPE': 'float' + 'LUT_SIZE': delta_a.shape[-1], + 'TM': TM, 'TN': TN, 'TK': TK, + 'A_TYPE': 'float', 'B_TYPE': 'float' } - - shape_c.pop(2) - c = triton.empty(shape_c, dtype=a.dtype) - grid = lambda opt: [triton.cdiv(NB*CD*CH*CW, opt.d('TM')), triton.cdiv(NF, opt.d('TN'))] - print(stride_c) - print(stride_b) - _conv.kernel(a, b, c, NB*CD*CH*CW, NF, NC*BD*BH*BW, AH, AW, BH, BW, CH, CW, NC, - stride_a[0], stride_a[1], stride_a[2], stride_a[3], stride_a[4], - stride_b[0], stride_b[1], stride_b[2], stride_b[3], stride_b[4], - stride_c[0], stride_c[1], stride_c[2], stride_c[3], stride_c[4], - pad_h, pad_w, stride_h, stride_w, upsample_h, upsample_w, + MATMUL_M = shape_c[cn] * shape_c[cd] * shape_c[ch] * shape_c[cw] + MATMUL_N = shape_c[ck] + MATMUL_K = shape_b[bc] * shape_b[bd] * shape_b[bh] * shape_b[bw] + _conv.kernel(a, b, c, + # matrix multiplication shapes + MATMUL_M, MATMUL_N, MATMUL_K, + # shapes for a + shape_a[ah], shape_a[aw], + # shapes for b + shape_b[bh], shape_b[bw], + # shapes for c + shape_c[ch], shape_c[cw], shape_c[cn], + # strides for a + stride_a[an], stride_a[ac], stride_a[ad + 0], stride_a[ad + 1], stride_a[ad + 2], + # strides for b + stride_b[bc], stride_b[bd + 0], stride_b[bd + 1], stride_b[bd + 2], stride_b[bk], + # strides for c + stride_c[cn], stride_c[ck], stride_c[cd], stride_c[cd + 1], stride_c[cd + 2], + # padding + pad_h, pad_w, + # striding + stride_h, stride_w, + # upsampling + upsample_h, upsample_w, 0, 0, 0, 0, 0, 0, + # look-up table delta_a, inc_a, - grid, **macros) + lambda opt: [triton.cdiv(MATMUL_M, opt.d('TM')), triton.cdiv(MATMUL_N, opt.d('TN'))], + **macros) return c @staticmethod - def forward(ctx, input, weight): - 
return _conv._call(input, weight, 1, 1, 1, 0, 0, 0, 1, 1, 1, '') + def forward(ctx, x, w, + pad_d = 0, pad_h = 0, pad_w = 0, + stride_d = 1, stride_h = 1, stride_w = 1, + upsample_d = 1, upsample_h = 1, upsample_w = 1, + layout_a = 'ncdhw', layout_b = 'ktrsc', layout_c = 'nkdhw'): + # save for backward + ctx.save_for_backward(x, w) + ctx.pad_d = pad_d + ctx.pad_h = pad_h + ctx.pad_w = pad_w + ctx.stride_d = stride_d + ctx.stride_h = stride_h + ctx.stride_w = stride_w + ctx.upsample_d = upsample_d + ctx.upsample_h = upsample_h + ctx.upsample_w = upsample_w + ctx.layout_a = layout_a + ctx.layout_b = layout_b + ctx.layout_c = layout_c + # return + return _conv._call(x, w, + pad_d, pad_h, pad_w, + stride_d, stride_h, stride_w, + upsample_d, upsample_h, upsample_w, + layout_a, layout_b, layout_c) + + @staticmethod + def backward(ctx, dy): + x, w = ctx.saved_tensors + pad_d = ctx.pad_d + pad_h = ctx.pad_h + pad_w = ctx.pad_w + stride_d = ctx.stride_d + stride_h = ctx.stride_h + stride_w = ctx.stride_w + upsample_d = ctx.upsample_d + upsample_h = ctx.upsample_h + upsample_w = ctx.upsample_w + layout_a = ctx.layout_a + layout_b = ctx.layout_b + layout_c = ctx.layout_c + + # TODO: Deal with this + dx_pad_d = 1 + dx_pad_h = 1 + dx_pad_w = 1 + dx = _conv._call(dy, w, + dx_pad_d, dx_pad_h, dx_pad_w, + upsample_w, upsample_h, upsample_w, + stride_d, stride_h, stride_w, + 'ncdhw', 'cktrs', 'nkdhw') + + + + ret = [None] * 14 + ret[0] = dx + # TODO: gradient w.r.t. the weights (dw) is not implemented yet + return tuple(ret) conv = _conv.apply \ No newline at end of file diff --git a/python/triton/ops/dot.py b/python/triton/ops/dot.py index ae568c642..89b28d20e 100644 --- a/python/triton/ops/dot.py +++ b/python/triton/ops/dot.py @@ -3,37 +3,50 @@ import triton class _dot(triton.function): src = """ -void dot(TYPE * A, TYPE * B, TYPE * C, +void dot(TYPE * A __noalias __readonly __aligned(16), + TYPE * B __noalias __readonly __aligned(16), + TYPE * C, + float alpha, int M, int N, int K, int lda __multipleof(8), int ldb __multipleof(8), int ldc) { - // prologue - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int rm[TM] = ridx * TM + 0 ... TM; - int rn[TN] = ridy * TN + 0 ... TN; - int rk[TK] = 0 ... TK; - float c[TM, TN] = 0; - // pointers to operands - TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM; - TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN; - // prefetches operands - TYPE a[SHAPE_A] = *pa; - TYPE b[SHAPE_B] = *pb; - // reduction loop - for(int k = K; k > 0; k-= TK){ - c += USE_A @ USE_B; - pa = pa + TK * STRIDE_AK; - pb = pb + TK * STRIDE_BK; - bool checka[SHAPE_A] = k > TK; - bool checkb[SHAPE_B] = k > TK; - a = checka ? *pa : 0; - b = checkb ? *pb : 0; - } - // epilogue - TYPE* pc[TM, TN] = C + rm[:, newaxis] * ldc + rn[newaxis, :]; - *pc = c; + // prologue + int ridx = get_program_id(0); + int ridy = get_program_id(1); + int rm[TM] = ridx * TM + 0 ... TM; + int rn[TN] = ridy * TN + 0 ... TN; + int rk[TK] = 0 ... TK; + + // pointers to operands + TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM; + TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN; + + // prefetches operands + bool checka[SHAPE_A] = rk[BROADCAST_AK] < K; + bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K; + TYPE a[SHAPE_A] = checka ? *pa : 0; + TYPE b[SHAPE_B] = checkb ? 
*pb : 0; + + // reduction loop + float c[TM, TN] = 0; + for(int k = K; k > 0; k -= TK){ + c += USE_A @ USE_B; + bool checka[SHAPE_A] = k > TK; + bool checkb[SHAPE_B] = k > TK; + pa += TK * STRIDE_AK; + pb += TK * STRIDE_BK; + a = *?(checka)pa; + b = *?(checkb)pb; + } + //c = c * alpha; + + // epilogue + int rxm[TM] = get_program_id(0) * TM + 0 ... TM; + int rxn[TN] = get_program_id(1) * TN + 0 ... TN; + TYPE* pc[TM, TN] = C + rxm[:, newaxis] * ldc + rxn[newaxis, :]; + bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N); + *?(checkc)pc = (TYPE[TM, TN])c; } """ kernel = triton.kernel(src, ['C']) @@ -75,10 +88,10 @@ void dot(TYPE * A, TYPE * B, TYPE * C, 'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis', 'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :', 'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'} - _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, + _dot.kernel(a, b, c, 1., M, N, Ka, lda, ldb, ldc, grid, bench=bench, AT = transpose_a, BT = transpose_b, TYPE = dtype, - TM = [64, 128], TN = [64, 128], TK = [8], **macros) + TM = [64], TN = [128], TK = [8], **macros) return c @staticmethod diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 4c3409885..ff29432e5 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -1,234 +1,651 @@ -# Special thanks to Scott Gray from OpenAI for writing the einsum parsing function - - +import numpy as np +import torch +from math import ceil, log2 +from enum import IntEnum import triton -import math +from functools import reduce +from operator import mul +from sympy.parsing.sympy_parser import parse_expr +import sympy as sp +from collections import OrderedDict +from collections import namedtuple +import re +from sympy.printing.ccode import C89CodePrinter + class _einsum(triton.function): - src = """ -void einsumk(TYPE * A, TYPE * B, TYPE * C, - int dim_M, int dim_N, int dim_K, - int std_A0 __multipleof(8), - int std_B0 __multipleof(8), - int std_C0 __multipleof(8), - int std_A1 __multipleof(8), - int std_B1 __multipleof(8), - int std_C1 __multipleof(8)) { - // program id - int pgm = get_program_id(0); - int pgn = get_program_id(1); - int pgb = get_program_id(2); - // range - int rm[TM] = pgm * TM + 0 ... TM; - int rn[TN] = pgn * TN + 0 ... TN; - int rb[TB] = pgb * TB + 0 ... TB; - int rk[TK] = 0 ... TK; - // accumulator - float c[TM, TN, TB] = 0; - // pointers to a - TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK - + rm[BROADCAST_AM] * STRIDE_AM - + rb[newaxis, newaxis, :] * std_A0; - // pointers to b - TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK - + rn[BROADCAST_BN] * STRIDE_BN - + rb[newaxis, newaxis, :] * std_B0; - // prefetch - TYPE a[SHAPE_A] = *pa; - TYPE b[SHAPE_B] = *pb; - // accumulation - for(int k = dim_K; k > 0; k -= TK) { - c += USE_A @ USE_B; - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; - bool checka[SHAPE_A] = k > TK; - bool checkb[SHAPE_B] = k > TK; - a = checka ? *pa : 0; - b = checkb ? 
*pb : 0; - } - // write-back - TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1 - + rn[newaxis, :, newaxis] * 1 - + rb[newaxis, newaxis, :] * std_C0; - bool checkm[TM] = rm < dim_M; - bool checkn[TN] = rn < dim_N; + + ############################# + ## Triton-C code generation + ############################# + def print_cc(expr, axes_0, axes_1, axes_2): + + class TritonCodePrinter(C89CodePrinter): + + def __init__(self, axes_0, axes_1, axes_2): + super(TritonCodePrinter, self).__init__() + self.axes_0 = axes_0 + self.axes_1 = axes_1 + self.axes_2 = axes_2 + + def _print_Symbol(self, expr): + name = super(C89CodePrinter, self)._print_Symbol(expr) + if expr in self.axes_0: + return f'r{name}[:, newaxis, newaxis]' + if expr in self.axes_1: + return f'r{name}[newaxis, :, newaxis]' + if expr in self.axes_2: + return f'r{name}[newaxis, newaxis, :]' + return name + + def _print_Indexed(self, expr): + assert len(expr.indices) == 1 + return "*(%s + %s)" % (self._print(expr.base.label), + self._print(expr.indices[0])) + + return TritonCodePrinter(axes_0, axes_1, axes_2).doprint(expr) + + + def unpack_cc(tile, axes, prefix, remat): + ret = '' + axes = list(map(str, axes)) + for i, d in enumerate(reversed(axes)): + if i == len(axes) - 1: + break + currs = ''.join(axes[: len(axes) - i]) + nexts = ''.join(axes[: len(axes) - (i + 1)]) + ty = '' if remat else 'int ' + sz = '' if remat else f'[{tile}]' + ret += f' {ty}{prefix}{nexts}{sz} = r{currs} / dim_{d};\n' + ret += f' {ty}{prefix}{d}{sz} = r{currs} % dim_{d};\n' + return ret + + def strides_cc(name, expr): + ret = [f'stride_{name}_{d}' for d in expr[:-1]] + ['1'] + ret = dict(zip(expr, ret)) + return ret + + def make_kernel(name, + expr_a, expr_b, expr_c, + axes_m, axes_n, axes_k, axes_b, + multipleof_a, multipleof_b, multipleof_c, + lut_mode_a, lut_mode_b, + delta_a, delta_b, + subscripted): + + use_lut_a = True + use_lut_b = True + + src = "" + + if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT: + src += f""" +char __constant__* AD = calloc({4*len(delta_a)});""" + if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT: + src += f""" +char __constant__* BD = calloc({4*len(delta_b)});""" + + + src += f""" +__global__ void {name}( + TYPE * A __noalias __readonly __aligned(16) + , TYPE * B __noalias __readonly __aligned(16) + , TYPE * C + , int * locks + , float alpha + , int matmul_m, int matmul_n, int matmul_k __multipleof(16) + , int div_m + """ + for dim in [axes_m, axes_n, axes_k, axes_b]: + for d in dim: + src += f", int dim_{d}" + src += "\n " + for dim, name, mult in zip([expr_a, expr_b, expr_c], + ['a', 'b', 'c'], + [multipleof_a, multipleof_b, multipleof_c]): + for d in range(len(dim) - 1): + attr = f'__multipleof({mult})' + src += f", int stride_{name}_{d} {attr}" + src += "\n " + if lut_mode_a == _einsum.LUT_MODE.SCALAR: + src += f", int stride_a_inner __multipleof({multipleof_a})" + elif lut_mode_a == _einsum.LUT_MODE.DRAM: + src += ", int* AD __noalias __readonly __aligned(16)" + src += "\n " + if lut_mode_b == _einsum.LUT_MODE.SCALAR: + src += f", int stride_b_inner __multipleof({multipleof_b})" + elif lut_mode_b == _einsum.LUT_MODE.DRAM: + src += ", int* BD" + for ptr in subscripted: + src += f", int* {ptr}" + src += """) { + + // re-order outer program ids + int grid_m = (matmul_m + TM - 1) / TM; + int grid_n = (matmul_n + TN - 1) / TN; + int pid_mn = get_program_id(0) / div_m; + int pid_n = pid_mn % grid_n; + int pid_m = (pid_mn / grid_n)*div_m + (get_program_id(0) % div_m); + + // get batch program id + int pid_b 
= get_program_id(1); + +#if TZ == 1 + int off_k = 0; +#else + // get reduction sub-group program id + int pid_z = get_program_id(2); + int grid_z = get_num_programs(2); + int div_z = matmul_k / TZ; + int rem_z = matmul_k % TZ; + int off_k = pid_z * div_z; + matmul_k = select(pid_z < rem_z, div_z, div_z + rem_z); +#endif + + // create ranges +""" + rk = 'r{}'.format(''.join(map(str,axes_k))) + for axes, tile, off in zip([axes_m, axes_n, axes_b, axes_k], + ['TM', 'TN', 'TB', 'TK'], + ['pid_m*TM', 'pid_n*TN', 'pid_b*TB', 'off_k']): + currs = ''.join(map(str,axes)) + if axes: + src += f" int r{currs}[{tile}] = {off} + 0 ... {tile};\n" + src += _einsum.unpack_cc(tile, axes, 'r', False) + + src += """ + // initialize pointers to A + int offa[TM, TK, TB] = """ + for i, sym in enumerate(expr_a): + ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b) + stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1' + if i > 0: + src += ' + ' + src += f"({ccode}) * {stride}\n " + src += ';' + + src += """ + TYPE *pa[TM, TK, TB] = A + offa;""" + + if use_lut_a and not lut_mode_a == _einsum.LUT_MODE.SCALAR: + spec = '__constant__' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else '' + cast = '(int __constant__*)' if lut_mode_a == _einsum.LUT_MODE.CONSTANT else '' + src += f""" + // initialize pointers to A look-up table + int offadelta[TK] = off_k + 0 ... TK; + int {spec} *padelta[TK] = {cast}AD + offadelta; + int incda[TM, TK, TB] = (*padelta)[newaxis, :, newaxis];""" + + src += """ + + // initialize pointers to B + int offb[TK, TN, TB] = """ + for i, sym in enumerate(expr_b): + ccode = _einsum.print_cc(sym, axes_k, axes_n, axes_b) + stride = f'stride_b_{i}' if i < len(expr_b) - 1 else '1' + if i > 0: + src += ' + ' + src += f"({ccode}) * {stride}\n " + src += ';' + + src += """ + TYPE *pb[TK, TN, TB] = B + offb;""" + + + if use_lut_b and not lut_mode_b == _einsum.LUT_MODE.SCALAR: + spec = '__constant__' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else '' + cast = '(int __constant__*)' if lut_mode_b == _einsum.LUT_MODE.CONSTANT else '' + src += f""" + // initialize pointers to B look-up table + int offbdelta[TK] = off_k + 0 ... TK; + int *pbdelta[TK] = BD + offbdelta;""" + + src += f""" + + // prefetch + bool checkm[TM] = r""" + ''.join(map(str,axes_m)) + f""" < matmul_m; + bool checkn[TN] = r""" + ''.join(map(str,axes_n)) + f""" < matmul_n; + bool checkk[TK] = {rk} < matmul_k + off_k; + bool checka[TM, TK, TB] = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis]; + bool checkb[TK, TN, TB] = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis]; + TYPE a[TM, TK, TB] = checka ? *pa : 0; + TYPE b[TK, TN, TB] = checkb ? 
*pb : 0; + // accumulate + float acc[TM, TN, TB] = 0; + for(int k = matmul_k; k > 0; k -= TK) {{ + acc += a @ b;""" + + if not use_lut_a or not use_lut_b: + src += f""" + {rk} += TK; +""" + src += _einsum.unpack_cc(tile, axes_k, 'r', True) + + + if use_lut_a: + if lut_mode_a == _einsum.LUT_MODE.SCALAR: + src += """ + pa += stride_a_inner;""" + else: + src += """ + pa += incda; + padelta += TK; + incda = (*padelta)[newaxis, :, newaxis];""" + else: + src += """ + offa = """ + for i, sym in enumerate(expr_a): + ccode = _einsum.print_cc(sym, axes_m, axes_k, axes_b) + stride = f'stride_a_{i}' if i < len(expr_a) - 1 else '1' + if i > 0: + src += ' + ' + src += f"({ccode}) * {stride}\n " + src += """; + TYPE *pa[TM, TK, TB] = A + offa;""" + + + + if lut_mode_b == _einsum.LUT_MODE.SCALAR: + src += """ + pb += stride_b_inner;""" + else: + src += """ + pb += (*pbdelta)[:, newaxis, newaxis]; + pbdelta += TK;""" + + src += f""" + checkk = k > TK; + checka = checkm[:, newaxis, newaxis] && checkk[newaxis, :, newaxis]; + checkb = checkk[:, newaxis, newaxis] && checkn[newaxis, :, newaxis]; + a = *?(checka)pa; + b = *?(checkb)pb; + }} + TYPE c[TM, TN, TB] = acc; + + // re-materialize ranges +""" + for axes, tile, off in zip([axes_m, axes_n, axes_b], + ['TM', 'TN', 'TB'], + ['pid_m*TM', 'pid_n*TN', 'pid_b*TB']): + currs = ''.join(map(str,axes)) + if axes: + src += f" r{currs} = {off} + 0 ... {tile};\n" + src += _einsum.unpack_cc(tile, axes, 'r', True) + + src += """ + // initialize pointers to C + int offc[TM, TN, TB] = """ + for i, sym in enumerate(expr_c): + stride = f'stride_c_{i}' if i < len(expr_c) - 1 else '1' + ccode = _einsum.print_cc(sym, axes_m, axes_n, axes_b) + if i > 0: + src += ' + ' + src += f"({ccode}) * {stride}\n " + src += ';' + + src += """ + TYPE *pc[TM, TN, TB] = C + offc; + + // bounds-checking + checkm = r""" + ''.join(map(str,axes_m)) + """ < matmul_m; + checkn = r""" + ''.join(map(str,axes_n)) + """ < matmul_n; bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] && - checkn[newaxis, :, newaxis]; - *?(checkc)pc = (TYPE[TM, TN, TB])c; + checkn[newaxis, :, newaxis]; + + // write back +#if TZ == 1 + *?(checkc)pc = c; +#else + int *plock = locks + pid_mn + pid_b * get_num_programs(0); + int *pcount = plock + 1024*1024; + // spin + for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1)); + int count = *pcount; + if(count == 0) + *?(checkc)pc = c; + else + *?(checkc)pc = c + *?(checkc)pc; + atomic_xchg(pcount, (count + 1) % (grid_z)); + atomic_xchg(plock, 0); +#endif } """ - kernel = triton.kernel(src, ['C']) + #print(src) + ret = triton.kernel(src, ['C']) + if use_lut_a and lut_mode_a == _einsum.LUT_MODE.CONSTANT: + ret.set_constant('AD', delta_a) + if use_lut_b and lut_mode_b == _einsum.LUT_MODE.CONSTANT: + ret.set_constant('BD', delta_b) + return ret + + ############################ + ## Look-up Table + ############################ + + class LUT_MODE(IntEnum): + SCALAR = 1 + CONSTANT = 2 + DRAM = 3 + + def lut_mode(delta): + if delta.size == 0 or np.min(delta) == np.max(delta): + return _einsum.LUT_MODE.SCALAR + #if delta.size < 4096: + # return _einsum.LUT_MODE.CONSTANT + return _einsum.LUT_MODE.DRAM + + def symbolic_delta(symbols, axes): + rank = len(symbols) + strides = [sp.symbols(f'stride{d}') for d in range(rank)] + nexts = {s: sp.symbols(f'next{s}') for s in axes} + delta = 0 + for i in range(rank): + delta += strides[i] * (symbols[i].subs(nexts) - symbols[i]) + return delta + + def unpack_offset(k, axes, dims): + ret = dict() + for d in reversed(axes): + ret[d] = k 
% dims[d] + k = k // dims[d] + return ret + + def make_delta(axes, step, stride, dims, symbols, arrays): + # symbolic pointer increments + delta = _einsum.symbolic_delta(symbols, axes) + args = [f'stride{d}' for d in range(len(stride))] + args += [f'{sk}' for sk in axes] + args += [f'next{sk}' for sk in axes] + args += [f'{sk}' for sk, _ in arrays] + fn = sp.lambdify(args, delta, 'numpy') + # inner axes values + inner = [dims[d] for d in axes] + k = np.arange(np.prod(inner), dtype=np.int32) + off = _einsum.unpack_offset(k, axes, dims) + nextoff = _einsum.unpack_offset(k + step, axes, dims) + # evaluate deltas + args = [s for s in stride] + args += [off[sk] for sk in axes] + args += [nextoff[sk] for sk in axes] + args += [x for _, x in arrays] + delta = fn(*args) + return delta, _einsum.lut_mode(delta[:-step]) + + ############################ + ## Einsum parsing + ############################ + + def uniq(seq): + seen = set() + seen_add = seen.add + return [x for x in seq if not (x in seen or seen_add(x))] + + def parse_axes(expr_a, expr_b, expr_c, subscripted): + is_index = lambda x: type(x) == sp.indexed.Indexed or str(x) in subscripted + sym_a = [x for s in expr_a for x in s.free_symbols if not is_index(x)] + sym_b = [x for s in expr_b for x in s.free_symbols if not is_index(x)] + sym_c = [x for s in expr_c for x in s.free_symbols] + batch = [d for d in sym_a if d in sym_b and d in sym_c] + outer = [d for d in sym_a if d not in sym_b and d in sym_c] + inner = [d for d in sym_a if d in sym_b and d not in sym_c] + illegal = [d for d in sym_a if d not in sym_b and d not in sym_c] + if illegal: + raise ValueError(f"einsum labels {illegal} ({expr_a}) "\ + f"not present in {expr_b} or {expr_c}") + return _einsum.uniq(batch), _einsum.uniq(outer), _einsum.uniq(inner) + + + def replace_subscript(expr, arrays): + # replace array indexing by Indexed() + indexed = re.findall('([_a-zA-Z][_a-zA-Z0-9]*)\[([_a-z]*)\]', expr) + for x in indexed: + arrays.append(x[0]) + expr = expr.replace(f'{x[0]}[{x[1]}]', f'Indexed({x[0]},{x[1]})') + return expr + + + def parse_expr(expr, arrays): + # extract symbols + sym = [] + i = 0 + while i < len(expr): + d = expr[i] + if d == '(': + size = expr[i:].find(')') + d = expr[i : i + size + 1] + d = _einsum.replace_subscript(d, arrays) + sym.append(parse_expr(d)) + i += size + 1 + else: + sym.append(parse_expr(d)) + i += 1 + return sym - @staticmethod - def _append_dim(dim_data, dim_type, idx, label, dim, stride): - if dim_type in dim_data: - data = dim_data[dim_type] - if idx != data["idx"] + 1: - raise ValueError("aggregate inner, outer and batch dims must be adjacent to each other.") - data["dim"] *= dim - data["lab"] = label + data["lab"] - else: - dim_data[dim_type] = dict(idx=idx, lab=label, dim=dim, std=stride) - return dim_type + ############################ + ## Preprocessing + ############################ @staticmethod - def _parse_abc(labels_a, labels_b, labels_c, shape_a, is_a=False): + def pad(tensor, pad): + pad = pad + [0] * (2*len(tensor.shape) - len(pad)) + begin = [ x if x > 0 else None for x in pad[-1::-2]] + end = [-x if x > 0 else None for x in pad[-2::-2]] + slices = [slice(b, e) for b, e in zip(begin, end)] + tensor = torch.nn.functional.pad(tensor, pad, 'constant', 0) + tensor = tensor[slices] + return tensor - if len(labels_a) != len(shape_a): - raise ValueError(f"einsum notation dims do not match shape: {labels_a} {shape_a}") - trans = False - stride = 1 - std1 = None - data = dict() - for idx, (lab, dim) in 
enumerate(reversed(list(zip(labels_a, shape_a)))): - #print(idx, lab, dim) - if dim is None: - raise ValueError("einsum doens't currently work on shapes with placeholder dims.") - if idx == 0 and dim % 8 != 0: - raise ValueError("contiguous dim must be multiple of 8") + ############################ + ## Compilation + ############################ - if lab in labels_c: - # batch dim - if lab in labels_b: - _einsum._append_dim(data, "B", idx, lab, dim, stride) - if idx == 0: - raise ValueError(f"batch dim can not be contiguous dim: {lab} {labels_a} {shape_a}") - # outer dim - else: - std1 = _einsum._append_dim(data, "O", idx, lab, dim, stride) - if idx == 0: - trans = is_a - # inner dim - elif lab in labels_b: - std1 = _einsum._append_dim(data, "I", idx, lab, dim, stride) - if idx == 0: - trans = not is_a - else: - raise ValueError(f"einsum def for output: {lab} ({labels_a}), not present in either other def") + class instance: - stride *= dim + locks = None + kernel_cache = dict() - if "B" not in data: - data["B"] = dict(dim=1, std=1) + def __init__(self, einsum, dtype, stride_a, stride_b, stride_c, shape_a, shape_b, shape_c, arrays): + # parse symbols + expr_a, expr_bc = einsum.split(",") + expr_b, expr_c = expr_bc.split("->") + subscripted = [] + sym_a = _einsum.parse_expr(expr_a, subscripted) + sym_b = _einsum.parse_expr(expr_b, subscripted) + sym_c = _einsum.parse_expr(expr_c, subscripted) + # parse axes + axes_b, axes_m, axes_k = _einsum.parse_axes(sym_a, sym_b, sym_c, subscripted) + _, axes_n, _ = _einsum.parse_axes(sym_b, sym_a, sym_c, subscripted) + axes = axes_b + axes_m + axes_n + axes_k + # check dimensions + dims_a = dict(zip(sym_a, shape_a)) + dims_b = dict(zip(sym_b, shape_b)) + dims_c = dict(zip(sym_c, shape_c)) + for axes in [axes_b, axes_k]: + for d in axes: + dim_a = dims_a[d] if d in sym_a else None + dim_b = dims_b[d] if d in sym_b else None + if dim_a and dim_b and dim_a != dim_b: + raise ValueError(f'incompatible dimension {d}' + f' (a: {dim_a}; b: {dim_b})') + dims = dict() + dims.update(dims_a) + dims.update(dims_b) + dims.update(dims_c) + # look-up tables + TK = 16 if dtype == triton.fw.torch.float16 else 8 + arrays = [(x, arrays[x]) for x in subscripted] + delta_a, lut_mode_a = _einsum.make_delta(axes_k, TK, stride_a, dims, sym_a, arrays) + delta_b, lut_mode_b = _einsum.make_delta(axes_k, TK, stride_b, dims, sym_b, arrays) + # hash for recompilation + stride_a_multiple = max([x for x in [1, 2, 4, 8] if shape_a[-1] % x == 0]) + stride_b_multiple = max([x for x in [1, 2, 4, 8] if shape_b[-1] % x == 0]) + stride_c_multiple = max([x for x in [1, 2, 4, 8] if shape_c[-1] % x == 0]) + name = f'{expr_a}_{expr_b}_{expr_c}_{lut_mode_a}_{lut_mode_b}'\ + f'_{stride_a_multiple}_{stride_b_multiple}_{stride_c_multiple}' + # recompile if necessary + cache = _einsum.instance.kernel_cache + if name not in cache: + cachesize = len(cache) + cache[name] = _einsum.make_kernel(f'__einsum{cachesize}', + sym_a, sym_b, sym_c, + axes_m, axes_n, axes_k, axes_b, + stride_a_multiple, stride_b_multiple, stride_c_multiple, + lut_mode_a, lut_mode_b, + delta_a, delta_b, + subscripted) + self.kernel = cache[name] + # Initialize locks + if _einsum.instance.locks is None: + _einsum.instance.locks = torch.zeros(2*1024*1024, dtype=torch.int32).cuda() + # Kernel arguments + dim_m = [dims[d] for d in axes_m] + dim_n = [dims[d] for d in axes_n] + dim_k = [dims[d] for d in axes_k] + dim_b = [dims[d] for d in axes_b] + M = reduce(mul, dim_m, 1) + N = reduce(mul, dim_n, 1) + K = reduce(mul, dim_k, 1) + B = 
+            stride_a = list(stride_a[:-1])
+            stride_b = list(stride_b[:-1])
+            stride_c = list(stride_c[:-1])
+            arrays = [torch.from_numpy(x).cuda() for _, x in arrays]
+            alpha = 1.
+            div_m = 1
+            self.args = [None, None, None,
+                         _einsum.instance.locks,
+                         alpha, M, N, K, div_m] +\
+                         dim_m + dim_n + dim_k + dim_b +\
+                         stride_a + stride_b + stride_c
+            if lut_mode_a != _einsum.LUT_MODE.CONSTANT:
+                delta_a = delta_a[0] if lut_mode_a == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_a).cuda()
+                self.args += [delta_a]
+            if lut_mode_b != _einsum.LUT_MODE.CONSTANT:
+                delta_b = delta_b[0] if lut_mode_b == _einsum.LUT_MODE.SCALAR else torch.from_numpy(delta_b).cuda()
+                self.args += [delta_b]
+            self.args += arrays
+            self.args += [lambda opt: [triton.cdiv(M, opt.d('TM')) *
+                                       triton.cdiv(N, opt.d('TN')),
+                                       triton.cdiv(B, opt.d('TB')),
+                                       opt.d('TZ')]]
+            # position of dynamic arguments
+            self.pos_a = 0
+            self.pos_b = 1
+            self.pos_c = 2
+            # pre-processor macros
+            TM = [x for x in [16, 32, 64, 128] if x <= M]
+            TN = [x for x in [16, 32, 64, 128] if x <= N]
+            TB = [x for x in [1, 2, 4] if x <= B]
+            MAX_GZ = K // 2048
+            MIN_GM = M // max(TM)
+            MIN_GN = N // max(TN)
+            MIN_GB = B // max(TB)
+            TZ = [x for x in [1, 2, 4, 8, 16, 32] \
+                    if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
+            TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
+            #TB, TZ = [1], [1]
+            #TM, TN, TB, TZ = [128], [128], [1], [1]
+            self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
+            self.dtype = dtype
+            self.flops = 2 * B * M * N * K
+            self.sym_a = sym_a
+            self.sym_b = sym_b
+            self.sym_c = sym_c
+            # save equivalent mat-mul dimensions
+            self.matmul_B = B
+            self.matmul_M = M
+            self.matmul_N = N
+            self.matmul_K = K
+
+        def run(self, a, b, c, bench):
+            self.args[self.pos_a] = a
+            self.args[self.pos_b] = b
+            self.args[self.pos_c] = c
+            self.kernel(*self.args, bench=bench, **self.macros)
-        # batch, outer, inner, std0, std1, trans
-        return data["B"]["dim"], data["O"]["dim"], data["I"]["dim"], data["B"]["std"], data[std1]["std"], trans
+
+
+
+    ############################
+    ## Forward
+    ############################
+
+    instance_cache = dict()
     @staticmethod
-    def _parse_einsum(labels_a, labels_b, labels_c, shape_a, shape_b):
-
-        dims_a = dict(zip(labels_a, shape_a))
-        dims_b = dict(zip(labels_b, shape_b))
-        shape_c = list()
-        for lab in labels_c:
-            if lab in dims_a:
-                shape_c.append(dims_a[lab])
-            elif lab in dims_b:
-                shape_c.append(dims_b[lab])
-            else:
-                raise ValueError(f"einsum def for output: {lab} ({labels_c}), not present in either input def ({labels_a}, {labels_b})")
-
-        BA, M, KA, std_a0, std_a1, ta = _einsum._parse_abc(labels_a, labels_b, labels_c, shape_a, True)
-        BB, N, KB, std_b0, std_b1, tb = _einsum._parse_abc(labels_b, labels_a, labels_c, shape_b, False)
-        BC, _, _, std_c0, std_c1, _ = _einsum._parse_abc(labels_c, labels_b, labels_a, shape_c)
-
-        if not (BA == BB == BC):
-            raise ValueError("mismatched batch dims")
-        if KA != KB:
-            raise ValueError("mismatched reduction dims")
-
-        return shape_c, (BA, M, N, KA), (std_a0, std_b0, std_c0), (std_a1, std_b1, std_c1), ta, tb
-
-    @staticmethod
-    def call(a, b, trans_a, trans_b, shape_c, bmnk,
-             std0, std1, einsum_a, einsum_b, einsum_c,
-             bench):
+    def forward(ctx, einsum, a, b, shape_c, **kwargs):
+        bench = kwargs['bench'] if 'bench' in kwargs else False
+        arrays = kwargs['arrays'] if 'arrays' in kwargs else dict()
+        # allocate output
         dtype = a.dtype
-        c = triton.empty(shape_c, dtype)
-        grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
-                            triton.cdiv(bmnk[2], opt.d('TN')),
-                            triton.cdiv(bmnk[0], opt.d('TB'))]
-        macros = {# handle A transposition
-                  'USE_A'       : 'a[^1, ^0, ^2]' if trans_a else 'a',
-                  'STRIDE_AK'   : 'std_A1' if trans_a else '1',
-                  'STRIDE_AM'   : '1' if trans_a else 'std_A1',
-                  'BROADCAST_AK': ':, newaxis, newaxis' if trans_a else 'newaxis, :, newaxis',
-                  'BROADCAST_AM': 'newaxis, :, newaxis' if trans_a else ':, newaxis, newaxis',
-                  'SHAPE_A'     : 'TK, TM, TB' if trans_a else 'TM, TK, TB',
-                  # handle B transposition
-                  'USE_B'       : 'b' if not trans_b else 'b[^1, ^0, ^2]',
-                  'STRIDE_BK'   : 'std_B1' if not trans_b else '1',
-                  'STRIDE_BN'   : '1' if not trans_b else 'std_B1',
-                  'BROADCAST_BK': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
-                  'BROADCAST_BN': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
-                  'SHAPE_B'     : 'TK, TN, TB' if not trans_b else 'TN, TK, TB'}
-        TM = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[1]) + 1 ))))]
-        TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1 ))))]
-        TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1 ))))]
-        TK = [bmnk[2]] if bmnk[2] < 16 else [8, 16]
-        _einsum.kernel(a, b, c,
-                       bmnk[1], bmnk[2], bmnk[3],
-                       std0[0], std0[1], std0[2],
-                       std1[0], std1[1], std1[2],
-                       grid, bench=bench,
-                       **macros,
-                       TYPE=dtype, TM=TM, TN=TN, TK=TK, TB=TB)
+        c = triton.empty(shape_c, dtype=dtype)
+        key = (einsum, dtype,
+               a.stride(), b.stride(), c.stride(),
+               a.shape, b.shape, c.shape)
+        # compile einsum instance
+        cache = _einsum.instance_cache
+        #if key not in cache:
+        cache[key] = _einsum.instance(einsum, dtype,
+                                      a.stride(), b.stride(), c.stride(),
+                                      a.shape, b.shape, c.shape, arrays)
+        instance = cache[key]
+        instance.run(a, b, c, bench)
+        # save information in context
+        ctx.flops = instance.flops
+        ctx.sym_a = instance.sym_a
+        ctx.sym_b = instance.sym_b
+        ctx.sym_c = instance.sym_c
+        ctx.matmul_B = instance.matmul_B
+        ctx.matmul_M = instance.matmul_M
+        ctx.matmul_N = instance.matmul_N
+        ctx.matmul_K = instance.matmul_K
+        ctx.bench = bench
+        ctx.save_for_backward(a, b)
         return c
+    ############################
+    ## Backward
+    ############################
     @staticmethod
-    def forward(ctx, subscripts, a, b, bench = 0):
-        ctx.save_for_backward(a, b)
-        # parse
-        if type(subscripts) is str:
-            einsum_a, einsum_bc = subscripts.split(",")
-            einsum_b, einsum_c = einsum_bc.split("->")
-        else:
-            einsum_a, einsum_b, einsum_c = subscripts
-        shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
-            einsum_a, einsum_b, einsum_c,
-            triton.shape(a), triton.shape(b))
-        # save for backward
-        ctx.trans_a = ta
-        ctx.trans_b = tb
-        ctx.einsum_a = einsum_a
-        ctx.einsum_b = einsum_b
-        ctx.einsum_c = einsum_c
-        ctx.bench = bench
-        ctx.bmnk = bmnk
-        # run
-        return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)
-
+    def sym_invert(sym_c, sym_x, prefix, renamed, inverse):
+        for i, expr in enumerate(sym_x):
+            if expr.is_symbol:
+                continue
+            sc = [x for x in expr.free_symbols if x in sym_c][0]
+            sx = sp.symbols(f'{prefix}{i}')
+            renamed[expr] = sx
+            inverse[sc] = sp.solve(sp.Eq(expr, sx), sc)[0]
     @staticmethod
-    def backward(ctx, dc):
+    def sym_to_expr(sym):
+        res = [f'({x})' for x in sym]
+        res = ''.join(res)
+        return res
+
+    @staticmethod
+    def backward(ctx, dy):
         a, b = ctx.saved_tensors
-        trans_a = ctx.trans_a
-        trans_b = ctx.trans_b
-        einsum_a = ctx.einsum_a
-        einsum_b = ctx.einsum_b
-        einsum_c = ctx.einsum_c
-        bench = ctx.bench
+        sym_a = ctx.sym_a
+        sym_b = ctx.sym_b
+        sym_c = ctx.sym_c
+        inverse = dict()
+        renamed = dict()
+        _einsum.sym_invert(sym_c, sym_a, 'a', renamed, inverse)
+        _einsum.sym_invert(sym_c, sym_b, 'b', renamed, inverse)
+        sym_a = [renamed[x] if x in renamed else x for x in sym_a]
+        sym_b = [renamed[x] if x in renamed else x for x in sym_b]
+        sym_c = [inverse[x] if x in inverse else x for x in sym_c]
+        expr_a = _einsum.sym_to_expr(sym_a)
+        expr_b = _einsum.sym_to_expr(sym_b)
+        expr_c = _einsum.sym_to_expr(sym_c)
+        expr = f'{expr_c},{expr_b}->{expr_a}'
+        da = einsum(expr, dy, b, a.shape, False)
+        return None, da, None, None, None
-        if not trans_a and not trans_b: # NN
-            da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench)
-            db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench)
-        elif not trans_a and trans_b: # NT
-            da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench)
-            db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
-
-        elif trans_a and not trans_b: # TN
-            da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
-            db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench)
-
-        elif trans_a and trans_b: # TT (not used)
-            da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
-            db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
-
-        return None, da, db, None
 einsum = _einsum.apply
\ No newline at end of file
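To make the new front-end concrete, this is roughly how the op defined above would be invoked. The snippet is illustrative rather than part of the diff: the shapes are made up, and the import path triton.ops.einsum is an assumption about how this module is re-exported.

# illustration only -- not part of the diff
import torch
import triton

B, M, N, K = 4, 256, 512, 64
a = torch.randn((B, M, K), dtype=torch.float16, device='cuda')
b = torch.randn((B, K, N), dtype=torch.float16, device='cuda')
# forward(ctx, einsum, a, b, shape_c, **kwargs): the output shape is passed explicitly
c = triton.ops.einsum('bmk,bkn->bmn', a, b, (B, M, N))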
std::cout << "// " << c << std::flush; + std::cout << "// " << c ; for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord)) std::cout << ", " << perf << std::flush; std::cout << std::endl; diff --git a/tests/common/dot.h b/tests/common/dot.h index a157d7994..427e7ca04 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -20,7 +20,7 @@ static void cc_dot(std::vector &c, const std::vector &a, const std::vector float acc = 0; for(size_t k = 0; k < K; k++) acc = acc + (!AT ? a[k*M + m] : a[m*K + k]) * (!BT ? b[n*K + k] : b[k*N + n]); - c[m + n*M] = static_cast(acc); + c[m*N + n] = static_cast(acc); } } @@ -72,9 +72,9 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, std::string ty = to_string::value; size_t dt_nbytes = sizeof(T); drv::context* context = stream->context(); - int32_t lda = AT ? K : M; - int32_t ldb = BT ? N : K; - int32_t ldc = M; + int32_t lda = (AT ^ a_order[0]==1) ? K : M; + int32_t ldb = (BT ^ b_order[0]==1) ? N : K; + int32_t ldc = N; std::vector sa = { "1", "lda" }; std::vector sb = { "1", "ldb" }; @@ -86,17 +86,17 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, // macros rt::function::options_space_t opt; // A access patterns - opt.defines.push_back({"USEA", {AT? "a[^1, ^0]" : "a" }}); - opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }}); - opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }}); - opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }}); + opt.defines.push_back({"USEA", {AT? "a" : "a" }}); + opt.defines.push_back({"BROADCAST_AK", {AT? "newaxis, :" : "newaxis, :" }}); + opt.defines.push_back({"BROADCAST_AM", {AT? ":, newaxis" : ":, newaxis" }}); + opt.defines.push_back({"SHAPE_A", {AT? "TM, TK" : "TM, TK" }}); opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }}); opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }}); // B access patterns - opt.defines.push_back({"USEB", {BT? "b[^1, ^0]" : "b" }}); - opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }}); - opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }}); - opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }}); + opt.defines.push_back({"USEB", {BT? "b" : "b" }}); + opt.defines.push_back({"BROADCAST_BK", {BT? ":, newaxis" : ":, newaxis" }}); + opt.defines.push_back({"BROADCAST_BN", {BT? "newaxis, :" : "newaxis, :" }}); + opt.defines.push_back({"SHAPE_B", {BT? "TK, TN" : "TK, TN" }}); opt.defines.push_back({"STRIDE_BK", {BT? sb[b_order[1]] : sb[b_order[0]] }}); opt.defines.push_back({"STRIDE_BN", {BT? sb[b_order[0]] : sb[b_order[1]] }}); // data-type @@ -109,15 +109,15 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, opt.num_warps = {nwarp}; } if(mode == BENCH) { - opt.defines.push_back({"TM", {"128"}}); - opt.defines.push_back({"TN", {"128"}}); - opt.defines.push_back({"TK", {"16"}}); - opt.num_warps = {4}; + opt.defines.push_back({"TM", {"32", "64", "128"}}); + opt.defines.push_back({"TN", {"32", "64", "128"}}); + opt.defines.push_back({"TK", {to_string::value == "half" ? 
"16" : "8"}}); + opt.num_warps = {2, 4, 8}; } // kernels rt::function function(src::dot, opt); - std::vector args = {&*da, &*db, &*dc, M, N, K, lda, ldb, ldc}; + std::vector args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc}; auto grid = grid2d(M, N); // metrics @@ -126,17 +126,17 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream); bench.push_back(tflops(triton_ns)); - // // cublas - // if(cublas::cublasinit()){ - // NumericT alpha(static_cast(1)); - // NumericT beta(static_cast(0)); - // cublasGemmAlgo_t fastest; - // cublasGemm(cuty, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest); - // double cublas_ms = triton::tools::bench([&]() { cublasGemm(cuty, stream, AT, BT, M, N, K, - // &alpha, &*da, lda, &*db, ldb, &beta, &*dc, - // ldc, nullptr, fastest); }, stream); - // result.push_back(tflops(cublas_ms)); - // } +// // cublas +// if(cublas::cublasinit()){ +// T alpha(static_cast(1)); +// T beta(static_cast(0)); +// cublasGemmAlgo_t fastest; +// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest); +// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, +// &alpha, &*da, lda, &*db, ldb, &beta, &*dc, +// ldc, nullptr, fastest); }, stream); +// bench.push_back(tflops(cublas_ms)); +// } } // test triton @@ -147,9 +147,9 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, std::vector ha(M*K); std::vector hb(K*N); for(size_t i = 0; i < ha.size(); i++) - ha[i] = 1; + ha[i] = (float)rand()/RAND_MAX; for(size_t i = 0; i < hb.size(); i++) - hb[i] = 1; + hb[i] = (float)rand()/RAND_MAX; // copy buffer stream->write(&*da, true, 0, ha); stream->write(&*db, true, 0, hb); diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h index 7c368e593..4dcab1efc 100644 --- a/tests/common/src/dot.h +++ b/tests/common/src/dot.h @@ -2,37 +2,58 @@ namespace src { const char *dot = R"( -void dot(TYPE * A, TYPE * B, TYPE * C, - int M, int N, int K, - int lda __multipleof(8), - int ldb __multipleof(8), - int ldc) { - // prologue - int ridx = get_program_id(0); - int ridy = get_program_id(1); - int rm[TM] = ridx * TM + 0 ... TM; - int rn[TN] = ridy * TN + 0 ... TN; - int rk[TK] = 0 ... TK; - float c[TM, TN] = 0; - // pointers to operands - TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM; - TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN; - // prefetches operands - TYPE a[SHAPE_A] = *pa; - TYPE b[SHAPE_B] = *pb; - // reduction loop - for(int k = K; k > 0; k-= TK){ - c += USEA @ USEB; - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; - bool checka[SHAPE_A] = k > TK; - bool checkb[SHAPE_B] = k > TK; - a = checka ? *pa : 0; - b = checkb ? *pb : 0; - } - // epilogue - TYPE* pc[TM, TN] = C + rm[:, newaxis] + rn[newaxis, :] * ldc; - *pc = c; +__global__ void dot(TYPE * A __noalias __readonly __aligned(16), + TYPE * B __noalias __readonly __aligned(16), + TYPE * C __noalias __aligned(16), + float alpha, + int M, int N, int K, + int lda __multipleof(8), + int ldb __multipleof(8), + int ldc __multipleof(8)) { + // prologue + int ridx = get_program_id(0); + int ridy = get_program_id(1); + int gridx = M / TM; + int gridy = N / TN; + int rid = ridx + ridy * gridx; + ridx = rid / gridy; + ridy = rid % gridy; + int rm[TM] = ridx * TM + 0 ... TM; + int rn[TN] = ridy * TN + 0 ... TN; + int rk[TK] = 0 ... 
+
+  // pointers to operands
+  int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
+  int offb[SHAPE_B] = rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN;
+  TYPE* pa[SHAPE_A] = A + offa;
+  TYPE* pb[SHAPE_B] = B + offb;
+
+  // prefetches operands
+  bool checka[SHAPE_A] = rk[BROADCAST_AK] < K;
+  bool checkb[SHAPE_B] = rk[BROADCAST_BK] < K;
+  TYPE a[SHAPE_A] = checka ? *pa : 0;
+  TYPE b[SHAPE_B] = checkb ? *pb : 0;
+
+  // reduction loop
+  float c[TM, TN] = 0;
+  for(int k = K; k > 0; k -= TK){
+    c += USEA @ USEB;
+    bool checka[SHAPE_A] = k > TK;
+    bool checkb[SHAPE_B] = k > TK;
+    pa += TK * STRIDE_AK;
+    pb += TK * STRIDE_BK;
+    a = *?(checka)pa;
+    b = *?(checkb)pb;
+  }
+  //c = c * alpha;
+
+  // epilogue
+  int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
+  int rxn[TN] = get_program_id(1) * TN + 0 ... TN;
+  int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
+  TYPE* pc[TM, TN] = C + offc;
+  bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
+  *?(checkc)pc = (TYPE[TM, TN])c;
 }
 )";
diff --git a/tests/common/util.h b/tests/common/util.h
index 0a7788195..89489f889 100644
--- a/tests/common/util.h
+++ b/tests/common/util.h
@@ -159,7 +159,7 @@ bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
   if(hc.size() != rc.size())
     return false;
   for(size_t i = 0; i < hc.size(); i++)
-    if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
+    if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
      std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
      return false;
    }
diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc
index dec01dc21..283951377 100644
--- a/tests/unit/dot.cc
+++ b/tests/unit/dot.cc
@@ -10,8 +10,8 @@ int main() {
   // shapes to test
   typedef std::tuple config_t;
   std::vector<config_t> configs;
-  for(int TM: std::vector<int>{32, 64})
-  for(int TN: std::vector<int>{32, 64})
+  for(int TM: std::vector<int>{32, 64, 128})
+  for(int TN: std::vector<int>{32, 64, 128})
   for(int TK: std::vector<int>{16})
   for(int nwarps: std::vector<int>{4})
   for(bool AT: std::array<bool, 2>{false, true})