From 5ce1b726dcbf288154253e59dc326dff5bde7b60 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 24 Oct 2021 02:30:46 -0700 Subject: [PATCH] [CODEGEN] Various bugfixes that make it possible to fuse RNG in a matmul epilogue (#356) --- include/triton/codegen/selection/generator.h | 2 + include/triton/ir/builder.h | 3 ++ include/triton/ir/dispatch.h | 2 + include/triton/ir/enums.h | 2 + include/triton/ir/instructions.h | 56 ++++++++++---------- include/triton/ir/visitor.h | 7 +-- lib/codegen/analysis/axes.cc | 3 +- lib/codegen/analysis/layout.cc | 1 + lib/codegen/selection/generator.cc | 30 +++++++++++ lib/codegen/transform/disassociate.cc | 2 + lib/driver/llvm.cc | 2 +- lib/ir/builder.cc | 7 +++ lib/ir/dispatch.cc | 9 ++++ lib/ir/instructions.cc | 32 ++++++++++- python/src/triton.cc | 2 + python/triton/language/core.py | 17 +++++- python/triton/language/random.py | 32 +++-------- 17 files changed, 149 insertions(+), 60 deletions(-) diff --git a/include/triton/codegen/selection/generator.h b/include/triton/codegen/selection/generator.h index a02e98d56..ad7d01a55 100644 --- a/include/triton/codegen/selection/generator.h +++ b/include/triton/codegen/selection/generator.h @@ -147,12 +147,14 @@ public: void visit_store_inst(ir::store_inst*); void visit_unmasked_store_inst(ir::unmasked_store_inst*); void visit_masked_store_inst(ir::masked_store_inst*); + void visit_cat_inst(ir::cat_inst*); void visit_reshape_inst(ir::reshape_inst*); void visit_splat_inst(ir::splat_inst*); void visit_broadcast_inst(ir::broadcast_inst*); void visit_downcast_inst(ir::downcast_inst*); void visit_exp_inst(ir::exp_inst*); void visit_cos_inst(ir::cos_inst*); + void visit_umulhi_inst(ir::umulhi_inst* x); void visit_sin_inst(ir::sin_inst*); void visit_log_inst(ir::log_inst*); void visit_get_program_id_inst(ir::get_program_id_inst*); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 40ced2bd1..a80bc471f 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -137,6 +137,7 @@ public: // Block instruction value *create_splat(value *arg, const type::block_shapes_t &shapes); value *create_reshape(value *arg, const type::block_shapes_t &shapes); + value *create_cat(value *lhs, value *rhs); value *create_broadcast(value *arg, const type::block_shapes_t &shapes); // Built-in instruction value *create_get_program_id(unsigned axis); @@ -153,6 +154,8 @@ public: value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis); value *create_select(value *pred, value *if_value, value *else_value); // Intrinsics + // These have no place in the IR, and hopefully they can be removed at some point + value *create_umulhi(value* lhs, value* rhs); value *create_copy_to_shared(value *arg); value *create_masked_load_async(value *arg, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache); value *create_copy_from_shared(value *arg); diff --git a/include/triton/ir/dispatch.h b/include/triton/ir/dispatch.h index b53c89d36..0c8295948 100644 --- a/include/triton/ir/dispatch.h +++ b/include/triton/ir/dispatch.h @@ -61,6 +61,7 @@ struct dispatch{ // casting ops static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder); + static ir::value *cat(ir::value *lhs, ir::value *rhs, ir::builder *builder); static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder); static std::tuple broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder); static ir::value *bitcast(ir::value *input, ir::type *type, ir::builder *builder); @@ -90,6 +91,7 @@ struct dispatch{ static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder); // math + static ir::value *umulhi(ir::value *x, ir::value *y, ir::builder *builder); static ir::value *exp(ir::value *x, ir::builder *builder); static ir::value *log(ir::value *x, ir::builder *builder); static ir::value *cos(ir::value *x, ir::builder *builder); diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h index 5be63d4d2..8cb7835f0 100644 --- a/include/triton/ir/enums.h +++ b/include/triton/ir/enums.h @@ -132,6 +132,7 @@ enum value_id_t: unsigned { // retile INST_RESHAPE, INST_SPLAT, + INST_CAT, INST_BROADCAST, INST_DOWNCAST, // builtin @@ -142,6 +143,7 @@ enum value_id_t: unsigned { INST_ATOMIC_EXCH, INST_ATOMIC_RMW, // math + INST_UMULHI, INST_EXP, INST_COS, INST_SIN, diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 9b1ffbb79..fdb2fd411 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -520,6 +520,21 @@ public: // retile_inst classes //===----------------------------------------------------------------------===// +// cat + +class cat_inst: public instruction { +private: + std::string repr_impl() const { return "cat"; } + cat_inst(value *x, value *y, const std::string &name, instruction *next); + +public: + static instruction* create(value *lhs, value *rhs, + const std::string &name = "", + instruction *next = nullptr); + _TRITON_DEFINE_CLONE(cat_inst) + _TRITON_DEFINE_ACCEPT(cat_inst) +}; + // retile class retile_inst: public unary_inst { @@ -654,6 +669,17 @@ public: static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr); }; +class umulhi_inst: public builtin_inst { +private: + umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr); + std::string repr_impl() const { return "umulhi"; } + _TRITON_DEFINE_CLONE(umulhi_inst) + _TRITON_DEFINE_ACCEPT(umulhi_inst) + +public: + static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr); +}; + class exp_inst: public builtin_inst { private: exp_inst(value *val, const std::string &name = "", instruction *next = nullptr); @@ -803,6 +829,7 @@ public: // intrinsics classes //===----------------------------------------------------------------------===// + class copy_to_shared_inst: public unary_inst{ private: using unary_inst::unary_inst; @@ -884,35 +911,6 @@ public: instruction *next=nullptr); }; -//// On NVIDIA, implementation is such that -//// constant_range = nv_dynamic_program_idx + nv_static_program_idx -//// so as to enable re-association on nv_static_program_idx which is constant -//class make_range_dyn: public instruction { -//private: -// make_range_dyn(type *ty, const std::string &name, instruction *next); -// std::string repr_impl() const { return "nv_dynamic_program_idx"; } -// _TRITON_DEFINE_CLONE(make_range_dyn) -// _TRITON_DEFINE_ACCEPT(make_range_dyn) - -//public: -// static make_range_dyn* create(type *ty, const std::string &name = "", instruction *next = nullptr); -//}; - -//class make_range_sta: public constant { -//private: -// make_range_sta(make_range *range); - -//public: -// static make_range_sta *get(make_range* range); -// make_range* get_range() const; -// std::string repr() const { return "nv_static_program_idx"; } -// _TRITON_DEFINE_ACCEPT(make_range_sta) - -//private: -// make_range *range_; -//}; - - /* constant range */ class make_range: public instruction{ make_range(type *ty, constant_int* first, constant_int* last); diff --git a/include/triton/ir/visitor.h b/include/triton/ir/visitor.h index 8073a6b66..4979b0b52 100644 --- a/include/triton/ir/visitor.h +++ b/include/triton/ir/visitor.h @@ -45,9 +45,11 @@ class masked_store_inst; class retile_inst; class reshape_inst; class splat_inst; +class cat_inst; class broadcast_inst; class downcast_inst; +class umulhi_inst; class exp_inst; class cos_inst; class sin_inst; @@ -122,6 +124,7 @@ public: virtual void visit_unmasked_store_inst(unmasked_store_inst*) = 0; virtual void visit_masked_store_inst(masked_store_inst*) = 0; + virtual void visit_umulhi_inst(umulhi_inst*) = 0; virtual void visit_exp_inst(exp_inst*) = 0; virtual void visit_cos_inst(cos_inst*) = 0; virtual void visit_sin_inst(sin_inst*) = 0; @@ -129,6 +132,7 @@ public: virtual void visit_reshape_inst(reshape_inst*) = 0; virtual void visit_splat_inst(splat_inst*) = 0; + virtual void visit_cat_inst(cat_inst*) = 0; virtual void visit_broadcast_inst(broadcast_inst*) = 0; virtual void visit_downcast_inst(downcast_inst*) = 0; @@ -150,13 +154,10 @@ public: virtual void visit_masked_load_async_inst(masked_load_async_inst*)= 0; virtual void visit_barrier_inst(barrier_inst*) = 0; virtual void visit_async_wait_inst(async_wait_inst*) = 0; -// virtual void visit_make_range_dyn(make_range_dyn*) = 0; virtual void visit_make_range(make_range*) = 0; virtual void visit_prefetch_s_inst(prefetch_s_inst*) = 0; - virtual void visit_function(function*) = 0; -// virtual void visit_make_range_sta(make_range_sta*) = 0; virtual void visit_undef_value(undef_value*) = 0; virtual void visit_constant_int(constant_int*) = 0; virtual void visit_constant_fp(constant_fp*) = 0; diff --git a/lib/codegen/analysis/axes.cc b/lib/codegen/analysis/axes.cc index 37b95eaa3..f079d2580 100644 --- a/lib/codegen/analysis/axes.cc +++ b/lib/codegen/analysis/axes.cc @@ -116,7 +116,8 @@ void axes::update_graph(ir::instruction *i) { switch (i->get_id()) { case ir::INST_REDUCE: return update_graph_reduce(i); case ir::INST_RESHAPE: return update_graph_reshape(i); - case ir::INST_SPLAT: return update_graph_no_edge(i);; + case ir::INST_SPLAT: return update_graph_no_edge(i); + case ir::INST_CAT: return update_graph_elementwise(i, true); case ir::INST_TRANS: return update_graph_trans(i); case ir::INST_BROADCAST: return update_graph_broadcast(i); case ir::INST_DOT: return update_graph_dot(i); diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc index 843265937..6ea0dd219 100644 --- a/lib/codegen/analysis/layout.cc +++ b/lib/codegen/analysis/layout.cc @@ -499,6 +499,7 @@ void layouts::run(ir::module &mod) { make_graph(i); }); + // connected components graph_.connected_components(&values_, &groups_); diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc index 6e31e5c7e..7316e047a 100644 --- a/lib/codegen/selection/generator.cc +++ b/lib/codegen/selection/generator.cc @@ -774,6 +774,22 @@ void generator::visit_masked_store_inst(ir::masked_store_inst* x) { visit_store_inst(x); } +/** + * \brief Code Generation for `cat` + */ +void generator::visit_cat_inst(ir::cat_inst* x) { + auto idxs = idxs_.at(x); + ir::value* lhs = x->get_operand(0); + ir::value* rhs = x->get_operand(1); + int i = 0; + for(size_t j = 0; j < idxs_.at(lhs).size(); j ++) + vals_[x][idxs_[x][i++]] = vals_[lhs][idxs_[lhs][j]]; + for(size_t j = 0; j < idxs_.at(rhs).size(); j ++){ + vals_[x][idxs_[x][i++]] = vals_[rhs][idxs_[rhs][j]]; + } +} + + /** * \brief Code Generation for `reshape` @@ -861,6 +877,20 @@ void generator::visit_cos_inst(ir::cos_inst* x){ } } +/** + * \brief Code Generation for `umulhi` + */ +void generator::visit_umulhi_inst(ir::umulhi_inst* x){ + std::vector tys = {i32_ty, i32_ty}; + FunctionType *fn_ty = FunctionType::get(i32_ty, tys, false); + InlineAsm *umulhi = InlineAsm::get(fn_ty, "mul.hi.u32 $0, $1, $2;", "=r,r,r", false); + for(auto idx: idxs_.at(x)){ + Value* lhs = vals_[x->get_operand(0)][idx]; + Value* rhs = vals_[x->get_operand(1)][idx]; + vals_[x][idx] = call(umulhi, std::vector{lhs, rhs}); + } + } + /** * \brief Code Generation for `sin` */ diff --git a/lib/codegen/transform/disassociate.cc b/lib/codegen/transform/disassociate.cc index 0d9e1b8ef..2709125f8 100644 --- a/lib/codegen/transform/disassociate.cc +++ b/lib/codegen/transform/disassociate.cc @@ -11,6 +11,8 @@ namespace transform{ ir::instruction* rematerialize(ir::builder& bld, ir::instruction *root, std::set& seen) { + if (dynamic_cast(root)) + return root; if(!seen.insert(root).second) return root; if(!root->get_type()->is_block_ty()) diff --git a/lib/driver/llvm.cc b/lib/driver/llvm.cc index f3c76ce77..3c11fbf35 100644 --- a/lib/driver/llvm.cc +++ b/lib/driver/llvm.cc @@ -178,7 +178,7 @@ std::string ptx_to_cubin(const std::string& ptx, int cc) { ofs.close(); std::string cmd; int err; - cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o 2> " + flog; + cmd = ptxas + " -v --gpu-name=sm_" + std::to_string(cc) + " " + fsrc + " -o " + fsrc + ".o"; err = system(cmd.c_str()); CUmodule ret; std::ifstream _cubin(_fbin, std::ios::binary ); diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index f4eadcb84..cc1d354ee 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -297,6 +297,10 @@ value *builder::create_reshape(value *arg, const type::block_shapes_t &shapes) { return insert(reshape_inst::create(arg, shapes)); } +value *builder::create_cat(value *lhs, value *rhs) { + return insert(cat_inst::create(lhs, rhs)); +} + value *builder::create_splat(value *arg, const type::block_shapes_t &shapes) { return insert(splat_inst::create(arg, shapes)); } @@ -369,6 +373,9 @@ value *builder::create_select(value *pred, value *if_value, value *else_value){ // intrinsic instructions //===----------------------------------------------------------------------===// +value *builder::create_umulhi(value *lhs, value *rhs) { + return insert(umulhi_inst::create(lhs, rhs)); +} value *builder::create_copy_to_shared(value *arg) { return insert(copy_to_shared_inst::create(arg)); diff --git a/lib/ir/dispatch.cc b/lib/ir/dispatch.cc index cc0626fbd..811e5c819 100644 --- a/lib/ir/dispatch.cc +++ b/lib/ir/dispatch.cc @@ -368,6 +368,10 @@ ir::value *dispatch::reshape(ir::value *input, shape_t dst_shape, ir::builder *b return builder->create_reshape(input, dst_shape); } +ir::value *dispatch::cat(ir::value *lhs, ir::value *rhs, ir::builder *builder) { + return builder->create_cat(lhs, rhs); +} + ir::value *dispatch::broadcast(ir::value *input, shape_t shape, ir::builder *builder) { if (!input->get_type()->is_block_ty()) return builder->create_splat(input, shape); @@ -715,6 +719,11 @@ ir::value *dispatch::sum(ir::value *input, unsigned int axis, ir::builder *build // Math //===----------------------------------------------------------------------===// +ir::value *dispatch::umulhi(ir::value *x, ir::value* y, ir::builder *builder) { + binary_op_type_checking(x, y, builder); + return builder->insert(umulhi_inst::create(x, y)); +} + ir::value *dispatch::exp(ir::value *x, ir::builder *builder) { return builder->create_exp(x); } diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index b3e52e94d..32e7674c6 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -522,11 +522,28 @@ masked_store_inst* masked_store_inst::create(value *ptr, value *val, value *mask // retile_inst classes //===----------------------------------------------------------------------===// +// cat + +cat_inst::cat_inst(value *x, value *y, const std::string &name, instruction *next) + : instruction(block_type::get(x->get_type()->get_scalar_ty(), + {x->get_type()->get_block_shapes()[0] + + y->get_type()->get_block_shapes()[0] }), INST_CAT, 2, name, next) { + set_operand(0, x); + set_operand(1, y); +} + +instruction* cat_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) { + return new cat_inst(lhs, rhs, name, next); +} + +// retile + retile_inst::retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next) : unary_inst(block_type::get(arg->get_type()->get_scalar_ty(), shapes), id, arg, name, next) { } + // reshape instruction* reshape_inst::create(value *arg, const type::block_shapes_t &shapes, @@ -761,6 +778,19 @@ instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const s } +// umulhi + +umulhi_inst::umulhi_inst(value *lhs, value *rhs, const std::string &name, instruction *next) + : builtin_inst(lhs->get_type(), INST_UMULHI, 2, name, next) { + set_operand(0, lhs); + set_operand(1, rhs); +} + +instruction* umulhi_inst::create(value *lhs, value *rhs, const std::string &name, instruction *next) { + return new umulhi_inst(lhs, rhs, name, next); +} + + // exp exp_inst::exp_inst(value *val, const std::string &name, instruction *next) @@ -877,7 +907,7 @@ make_range::make_range(type *ty, constant_int *first, constant_int *last) make_range *make_range::create(constant_int *first, constant_int *last) { assert(first->get_type()->is_integer_ty()); assert(first->get_type() == last->get_type()); - assert(((constant_int*)first)->get_value() == 0); +// assert(((constant_int*)first)->get_value() == 0); type *ty = block_type::get(first->get_type(), {(unsigned)last->get_value() - (unsigned)first->get_value()}); return new make_range(ty, first, last); } diff --git a/python/src/triton.cc b/python/src/triton.cc index 5f40f48f4..9298f9db4 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -313,6 +313,7 @@ void init_triton_frontend(py::module &&m) { m.def("arange", &ir::dispatch::arange, ret::reference); m.def("zeros", &ir::dispatch::zeros, ret::reference); // type manipuatation + m.def("cat", &ir::dispatch::cat, ret::reference); m.def("reshape", &ir::dispatch::reshape, ret::reference); typedef std::tuple (*broadcast_ty)(ir::value *, ir::value *, ir::builder *); typedef ir::value *(*broadcast_to_ty)(ir::value *, ir::type::block_shapes_t, ir::builder *); @@ -340,6 +341,7 @@ void init_triton_frontend(py::module &&m) { m.def("max", &ir::dispatch::max, ret::reference); m.def("sum", &ir::dispatch::sum, ret::reference); // math + m.def("umulhi", &ir::dispatch::umulhi, ret::reference); m.def("exp", &ir::dispatch::exp, ret::reference); m.def("log", &ir::dispatch::log, ret::reference); m.def("cos", &ir::dispatch::cos, ret::reference); diff --git a/python/triton/language/core.py b/python/triton/language/core.py index f201a9591..e584aecb1 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -346,6 +346,18 @@ def broadcast_to(input, shape, _builder=None): """ return frontend.broadcast_to(input, shape, _builder) +@builtin +def cat(input, other, _builder=None): + """ + Concatenate the given blocks + + :param input: The first input block. + :type input: + :param other: The second input block. + :type other: + """ + return frontend.cat(input, other, _builder) + @builtin def reshape(input, shape, _builder=None): @@ -524,6 +536,10 @@ def where(condition, x, y, _builder=None): # Math # ----------------------- +@builtin +def umulhi(x, y, _builder=None): + return frontend.umulhi(x, y, _builder) + def _add_math_1arg_docstr(name): def _decorator(func): @@ -543,7 +559,6 @@ def _add_math_1arg_docstr(name): def exp(x, _builder=None): return frontend.exp(x, _builder) - @builtin @_add_math_1arg_docstr("natural logarithm") def log(x, _builder=None): diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 1a1ecbc37..3a3d7f9e1 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -31,42 +31,26 @@ def PHILOX_ROUND_B(): # 0xCD9E8D57 return -845247145 - @triton.jit def hacky_to_uint64(x): return ((x >> 1).to(tl.int64) << 1) + (x & 1).to(tl.int64) -@triton.jit -def multiply_low_high(a, b): - return ( - a * b, - ((hacky_to_uint64(a) * hacky_to_uint64(b)) >> 32).to(tl.int32) - ) - - @triton.jit def single_round(c0, c1, c2, c3, k0, k1): A = PHILOX_ROUND_A() B = PHILOX_ROUND_B() - lo0, hi0 = multiply_low_high(A, c0) - lo1, hi1 = multiply_low_high(B, c2) - - return ( - hi1 ^ c1 ^ k0, - lo1, - hi0 ^ c3 ^ k1, - lo0, - ) + _c0, _c2 = c0, c2 + c0 = tl.umulhi(B, _c2) ^ c1 ^ k0 + c2 = tl.umulhi(A, _c0) ^ c3 ^ k1 + c1 = B * _c2 + c3 = A * _c0 + return c0, c1, c2, c3 @triton.jit def raise_key(k0, k1): - return ( - k0 + PHILOX_KEY_A(), - k1 + PHILOX_KEY_B(), - ) - + return (k0 + PHILOX_KEY_A(), k1 + PHILOX_KEY_B()) @triton.jit def philox_f(c0, c1, c2, c3, k0, k1): @@ -125,7 +109,7 @@ def randint4x(seed, offset): :param seed: The seed for generating random numbers. :param offsets: The offsets to generate random numbers for. """ - z = 0 + z = offset*0 #FIXME: just 0 doesn't work. Likelye some error with broadcasting seed = hacky_to_uint64(seed) # uint will solve this seed_hi = ((seed >> 32) & 0xffffffff).to(tl.int32) seed_lo = (seed & 0xffffffff).to(tl.int32)