diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 0b6c859b1..5af20edbe 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -136,7 +136,7 @@ public: value *create_dot(value *A, value *B, value *C, const std::string &name = ""); value *create_trans(value *A, const std::vector &perm = {}, const std::string &name = ""); value *create_sqrt(value *A, const std::string &name = ""); - value *create_reduce(value *A, unsigned axis, const std::string &name = ""); + value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = ""); value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = ""); // Intrinsics value *create_copy_to_shared(value *arg, const std::string &name = ""); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 5a07f79b2..d961790ab 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -611,19 +611,28 @@ public: }; class reduce_inst: public builtin_inst { -private: - static type* get_res_type(value *arg, unsigned axis); +public: + enum op_t{ + ADD, SUB, MAX, MIN, + FADD, FSUB, FMAX, FMIN + }; private: - reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next); + static type* get_res_type(value *arg, unsigned axis); + static std::string to_str(op_t op); + +private: + reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next); std::string repr_impl() const { return "red<" + std::to_string(axis_) + ">"; } public: - static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr); + static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr); unsigned get_axis() const { return axis_; } + op_t get_op() const { return op_; } private: unsigned axis_; + op_t op_; }; class select_inst: public builtin_inst { diff --git a/include/triton/lang/token.h b/include/triton/lang/token.h index 602113f93..f11d08fc8 100644 --- a/include/triton/lang/token.h +++ b/include/triton/lang/token.h @@ -131,6 +131,8 @@ public: // TILE ARITHMETICS BEGIN NEWAXIS, + MAX, + MIN, // TILE ARITHMETICS END ALIGNAS, // _Alignas diff --git a/lib/codegen/analysis/grid.cc b/lib/codegen/analysis/grid.cc index 4ce4116e3..da8516daa 100644 --- a/lib/codegen/analysis/grid.cc +++ b/lib/codegen/analysis/grid.cc @@ -60,15 +60,6 @@ void grids::init_c_graph(ir::instruction *v) { else if(dynamic_cast(v)) return; else if(dynamic_cast(v)) { -// unsigned axis = reduce->get_axis(); -// ir::value *arg = reduce->get_operand(0); -// auto in_shapes = arg->get_type()->get_tile_shapes(); -// unsigned current = 0; -// for(unsigned i = 0; i < in_shapes.size(); i++){ -// if(i == axis) -// continue; -// add_constraint({reduce, current++}, {arg, i}); -// } return; } else @@ -305,7 +296,6 @@ void grids::run(ir::module &mod) { for(size_t d = 0; d < shapes.size(); d++){ std::string str_d = std::to_string(d); effective_num_threads *= params_.at(i).at("mts.d" + str_d)->get_value(); - std::cout << shapes[d] << " " << params_.at(i).at("mts.d" + str_d)->get_value() << " " << params_.at(i).at("nts.d" + str_d)->get_value() << std::endl; } if(num_threads != effective_num_threads) throw std::runtime_error("cannot create a kernel with this amount of warps"); diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index 243eb2bb2..0b1568354 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -925,30 +925,47 @@ void selection::lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { Module *module = fn->getParent(); std::map partial; - ir::value *op = x->get_operand(0); - distributed_tile* op_tile = (distributed_tile*)tmap_.at(op); + ir::value *arg = x->get_operand(0); + distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg); + ir::reduce_inst::op_t op = x->get_op(); + auto accumulate = [&](Value* x, Value *y) -> Value* { + switch(op) { + case ir::reduce_inst::ADD: return builder.CreateAdd(x, y); + case ir::reduce_inst::SUB: return builder.CreateSub(x, y); + case ir::reduce_inst::MAX: return builder.CreateMaximum(x, y); + case ir::reduce_inst::MIN: return builder.CreateMinimum(x, y); + case ir::reduce_inst::FADD: return builder.CreateFAdd(x, y); + case ir::reduce_inst::FSUB: return builder.CreateFSub(x, y); + case ir::reduce_inst::FMAX: return builder.CreateSelect(builder.CreateFCmpOGT(x, y), x, y); + case ir::reduce_inst::FMIN: return builder.CreateSelect(builder.CreateFCmpOLT(x, y), x, y); + default: break; + } + assert(false); + return nullptr; + }; + unsigned axis = x->get_axis(); // reduce within thread - op_tile->for_each([&](indices_t idx) { + arg_tile->for_each([&](indices_t idx) { indices_t pidx = idx; pidx[axis] = builder.getInt32(0); - Value *current = op_tile->get_value(idx); + Value *current = arg_tile->get_value(idx); // current partial result is not initialized -- create if(partial.find(pidx) == partial.end()) partial[pidx] = current; // current partial result is initialized -- accumulate else - partial[pidx] = builder.CreateFAdd(partial[pidx], current); + partial[pidx] = accumulate(partial[pidx], current); }); // depth - unsigned shape_ax = op->get_type()->get_tile_shapes()[axis]; - unsigned per_thread = op_tile->axis(axis).values.size(); + unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis]; + unsigned per_thread = arg_tile->axis(axis).values.size(); unsigned depth = shape_ax / per_thread; // shapes - auto shared_shapes = op_tile->get_shapes(); + auto shared_shapes = arg_tile->get_shapes(); shared_shapes[axis] = depth; // reduce within blocks @@ -957,7 +974,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space)); for(auto& x: partial) { // current element being computed - Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id; + Value *lane = axes_.at(params_->get_param_group(arg, axis)).thread_id; Value *&result = x.second; indices_t write_idx = x.first; write_idx[axis] = lane; @@ -981,7 +998,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, tgt_->add_barrier(module, builder); Value *next = builder.CreateLoad(read_ptr); // accumulate - result = builder.CreateFAdd(result, next); + result = accumulate(result, next); // write back builder.CreateStore(result, write_ptr); } diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc index 458365a60..8f42e263c 100644 --- a/lib/ir/builder.cc +++ b/lib/ir/builder.cc @@ -323,8 +323,8 @@ value *builder::create_sqrt(value *A, const std::string &name) { return insert(sqrt_inst::create(A, name)); } -value *builder::create_reduce(value *A, unsigned axis, const std::string &name) { - return insert(reduce_inst::create(A, axis, name)); +value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name) { + return insert(reduce_inst::create(A, op, axis, name)); } value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){ diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc index e7e5de1f2..9df26dc1a 100644 --- a/lib/ir/instructions.cc +++ b/lib/ir/instructions.cc @@ -615,6 +615,23 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction //===----------------------------------------------------------------------===// // reduce instructions //===----------------------------------------------------------------------===// + +std::string reduce_inst::to_str(op_t op) { + switch (op) { + case ADD: return "+"; + case SUB: return "-"; + case MAX: return "imax"; + case MIN: return "imin"; + case FADD: return "+"; + case FSUB: return "-"; + case FMAX: return "fmax"; + case FMIN: return "fmin"; + default: break; + } + assert(false); + return ""; +} + type* reduce_inst::get_res_type(value *arg, unsigned axis) { ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes(); shapes.erase(shapes.begin() + axis); @@ -625,14 +642,15 @@ type* reduce_inst::get_res_type(value *arg, unsigned axis) { return tile_type::get(scalar_ty, shapes); } -reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next) +reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) : builtin_inst(get_res_type(arg, axis), 1, 1, name, next), + op_(op), axis_(axis){ set_operand(0, arg); } -instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &name, instruction *next) { - return new reduce_inst(arg, axis, name, next); +instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) { + return new reduce_inst(arg, op, axis, name, next); } diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index 56acb1c03..8384dd710 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -154,12 +154,24 @@ void Generator::VisitBinaryOp(BinaryOp* binary) { error_not_implemented(); } +ir::reduce_inst::op_t reduce_op(int tag, bool is_float) { + using ir::reduce_inst; + switch(tag){ + case Token::ADD: return is_float ? reduce_inst::FADD : reduce_inst::ADD; + case Token::SUB: return is_float ? reduce_inst::FSUB : reduce_inst::SUB; + case Token::MAX: return is_float ? reduce_inst::FMAX : reduce_inst::MAX; + case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN; + default: break; + } + should_not_happen(); + return reduce_inst::op_t(); +} void Generator::VisitUnaryOp(UnaryOp* unary) { - // recursion Visit(unary->operand_); - ir::value* op = ret_; - + ir::value* arg = ret_; + ir::type *arg_ty = arg->get_type(); + ir::type *arg_scal_ty = arg_ty->get_scalar_ty(); // return switch (unary->op_) { case Token::PREFIX_INC: return error_not_implemented(); @@ -167,17 +179,19 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { case Token::POSTFIX_INC: return error_not_implemented(); case Token::POSTFIX_DEC: return error_not_implemented(); case Token::ADDR: return error_not_implemented(); - case Token::DEREF: return set_ret(bld_->create_load(op)); + case Token::DEREF: return set_ret(bld_->create_load(arg)); case Token::PLUS: return error_not_implemented(); case Token::MINUS: return error_not_implemented(); - case '~': return set_ret(bld_->create_neg(op)); - case '!': return set_ret(bld_->create_not(op)); - case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_))); - case '^': return set_ret(bld_->create_trans(op)); + case '~': return set_ret(bld_->create_neg(arg)); + case '!': return set_ret(bld_->create_not(arg)); + case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_))); + case '^': return set_ret(bld_->create_trans(arg)); case Token::REDUCE: { int ax, tag; UnaryOp::decodeRed(unary->info_, ax, tag); - return set_ret(bld_->create_reduce(op, ax)); + bool is_float = arg_scal_ty->is_floating_point_ty(); + ir::reduce_inst::op_t op = reduce_op(tag, is_float); + return set_ret(bld_->create_reduce(arg, op, ax)); } default: error_not_implemented(); } diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index f69337ced..a30258c3d 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -466,7 +466,9 @@ Expr* Parser::ParseSubScripting(Expr* lhs) { break; case Token::ADD: - case Token::SUB:{ + case Token::SUB: + case Token::MAX: + case Token::MIN:{ int info = UnaryOp::encodeRed(i, tok->tag_); redInfo.push_back({i, info}); shape.push_back(lhsShape[i++]); diff --git a/lib/lang/token.cc b/lib/lang/token.cc index b9f3c8467..8b61aa098 100644 --- a/lib/lang/token.cc +++ b/lib/lang/token.cc @@ -54,6 +54,8 @@ const std::unordered_map Token::kwTypeMap_ { { "_Noreturn", Token::NORETURN }, { "_Static_assert", Token::STATIC_ASSERT }, { "_Thread_local", Token::THREAD }, + { "max", Token::MAX }, + { "min", Token::MIN }, }; const std::unordered_map Token::tagLexemeMap_ { diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 79ee61a51..ea84eac00 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -157,6 +157,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr for(auto it: opt_space_.defines) cpp.AddMacro(it.first, &opt.defines.at(it.first)); cpp.Process(tokens); +// tokens.Print(stdout); // parse Parser parser(tokens); parser.Parse(); @@ -200,7 +201,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c codegen::transform::peephole peephole; codegen::transform::reassociate reassociate(&alignment_info, &grids); codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get()); - ir::print(module, std::cout); +// ir::print(module, std::cout); // run passes peephole.run(module); dce.run(module); diff --git a/tests/common/src/reduce.h b/tests/common/src/reduce.h index 02cc3fbe7..3a77e960e 100644 --- a/tests/common/src/reduce.h +++ b/tests/common/src/reduce.h @@ -19,7 +19,7 @@ void reduce2d(TYPE * X __noalias __readonly __aligned(16), int rm[TM] = ridm * TM + 0 ... TM; int rn[TN] = ridn * TN + 0 ... TN; TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; - TYPE* py[TY] = Y + rm; + TYPE* py[TY] = Y + RY; *py = (*px)[RED]; } )"; diff --git a/tests/common/util.h b/tests/common/util.h index 800e2c5ae..6de7f340f 100644 --- a/tests/common/util.h +++ b/tests/common/util.h @@ -43,6 +43,34 @@ void init_zeros(std::vector& x) { x[i] = 0; } +enum reduce_op_t { + ADD, + MAX, + MIN +}; + +std::string to_str(reduce_op_t op) { + switch (op) { + case ADD: return "+"; + case MAX: return "max"; + case MIN: return "min"; + default: break; + } + assert(false); + return ""; +} + +template +std::function get_accumulator(reduce_op_t op) { + switch (op) { + case ADD: return [](T x, T y) { return x + y; }; + case MAX: return [](T x, T y) { return std::max(x, y); }; + case MIN: return [](T x, T y) { return std::min(x, y); }; + default: break; + } + assert(false); + return std::function(); +} namespace aux{ @@ -70,6 +98,23 @@ auto operator<<(std::basic_ostream& os, std::tuple const& t) return os << ")"; } +template +std::basic_ostream& operator<<(std::basic_ostream& os, const std::vector& vec) { + os << "{"; + for(size_t i = 0; i < vec.size(); i++){ + if(i > 0) + os << ", "; + os << vec[i]; + } + os << "}"; + return os; +} + +template +std::basic_ostream& operator<<(std::basic_ostream& os, reduce_op_t op) { + return os << to_str(op); +} + namespace testing { diff --git a/tests/unit/reduce.cc b/tests/unit/reduce.cc index 2317b76d2..5951f3e50 100644 --- a/tests/unit/reduce.cc +++ b/tests/unit/reduce.cc @@ -2,6 +2,7 @@ #include #include #include +#include #include "triton/driver/backend.h" #include "triton/driver/stream.h" #include "triton/tools/bench.hpp" @@ -40,58 +41,66 @@ int offset(const std::vector& idx, const std::vector& shapes) { } template -void reduce_nd(std::vector &y, const std::vector &x, size_t axis, const std::vector& shapes) { +void reduce_nd(std::vector &y, const std::vector &x, reduce_op_t op, size_t axis, const std::vector& shapes) { assert(axis <= shapes.size() - 1); // remove shape at index axis to get outer dimensions std::vector outer = shapes; outer.erase(outer.begin() + axis); // retrieve shape at index axis to get inner dimension int inner = shapes[axis]; + // accumualtion function + auto acc = get_accumulator(op); // iterate over outer dimensions _loop_nest(outer, [&](const std::vector& y_idx) { - T acc = 0; + T ret = 0; auto x_idx = y_idx; x_idx.insert(x_idx.begin() + axis, 0); // accumulate over inner dimensions for(int z = 0; z < inner; z++){ x_idx[axis] = z; - acc = acc + x[offset(x_idx, shapes)]; + ret = acc(ret, x[offset(x_idx, shapes)]); } - y[offset(y_idx, outer)] = acc; + y[offset(y_idx, outer)] = ret; }); } -bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){ +bool do_test(drv::stream* stream, std::vector shape, int axis, reduce_op_t op, int nwarp){ typedef float NumericT; std::string ty = "float"; size_t dt_nbytes = sizeof(NumericT); drv::context* context = stream->context(); - std::vector hy(M); - std::vector ry(M); - std::vector hx(M*N); + size_t axy = (axis == 0) ? 1 : 0; + std::string RY = (axis == 0) ? "rn" : "rm"; + std::vector hy(shape[axy]); + std::vector ry(shape[axy]); + std::vector hx(shape[0]*shape[1]); srand(0); init_zeros(hy); init_rand(hx); - for(int i = 0; i < M; i++) - for(int j = 0; j < N; j++) - hx[i + j*M] = i+j; auto dy = std::shared_ptr(drv::buffer::create(context, hy.size()*dt_nbytes)); auto dx = std::shared_ptr(drv::buffer::create(context, hx.size()*dt_nbytes)); stream->write(&*dy, true, 0, hy); stream->write(&*dx, true, 0, hx); rt::function::options_space_t opt; opt.defines.push_back({"TYPE", {ty}}); - opt.defines.push_back({"TM", {std::to_string(M)}}); - opt.defines.push_back({"TN", {std::to_string(N)}}); - opt.defines.push_back({"TY", {std::to_string(M)}}); - opt.defines.push_back({"RED", {"+, :"}}); + opt.defines.push_back({"TM", {std::to_string(shape[0])}}); + opt.defines.push_back({"TN", {std::to_string(shape[1])}}); + opt.defines.push_back({"TY", {std::to_string(shape[axy])}}); + opt.defines.push_back({"RY", {RY}}); + std::string RED = ""; + for(int n = 0; n < 2; n++){ + if(n > 0) + RED += ", "; + RED += (n==axis) ? to_str(op) : ":"; + } + opt.defines.push_back({"RED", {RED}}); opt.num_warps = {nwarp}; rt::function function(src::reduce2d, opt); - function({&*dx, &*dy, M, N, M}, grid2d(M, N), stream); + function({&*dx, &*dy, shape[0], shape[1], shape[0]}, grid2d(shape[0], shape[1]), stream); stream->synchronize(); stream->read(&*dy, true, 0, hy); - reduce_nd(ry, hx, 0, {M, N}); + reduce_nd(ry, hx, op, axis, shape); return testing::diff(hy, ry); } @@ -100,17 +109,21 @@ int main() { auto context = triton::driver::backend::contexts::get_default(); triton::driver::stream* stream = triton::driver::stream::create(context); // shapes to benchmark - typedef std::tuple config_t; + typedef std::tuple, int, reduce_op_t> config_t; std::vector configs = { - config_t{32, 32, "+"} + config_t{{32, 32}, 0, MAX}, + config_t{{32, 32}, 1, ADD}, + config_t{{32, 64}, 0, ADD}, + config_t{{64, 32}, 1, ADD} }; // does the work - int M, N; - std::string op; + int axis; + std::vector shape; + reduce_op_t op; for(const auto& c: configs){ - std::tie(M, N, op) = c; + std::tie(shape, axis, op) = c; std::cout << "Testing " << c << " ... " << std::flush; - if(do_test(stream, M, N, op, 1)) + if(do_test(stream, shape, axis, op, 1)) std::cout << " Pass! " << std::endl; else std::cout << " Fail! " << std::endl;