diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h
index 0b6c859b1..5af20edbe 100644
--- a/include/triton/ir/builder.h
+++ b/include/triton/ir/builder.h
@@ -136,7 +136,7 @@ public:
   value *create_dot(value *A, value *B, value *C, const std::string &name = "");
   value *create_trans(value *A, const std::vector<constant_int*> &perm = {}, const std::string &name = "");
   value *create_sqrt(value *A, const std::string &name = "");
-  value *create_reduce(value *A, unsigned axis, const std::string &name = "");
+  value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = "");
   value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
   // Intrinsics
   value *create_copy_to_shared(value *arg, const std::string &name = "");
diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h
index a4fbc3710..d961790ab 100644
--- a/include/triton/ir/instructions.h
+++ b/include/triton/ir/instructions.h
@@ -611,19 +611,28 @@ public:
 };
 
 class reduce_inst: public builtin_inst {
+public:
+  enum op_t{
+    ADD, SUB, MAX, MIN,
+    FADD, FSUB, FMAX, FMIN
+  };
+
 private:
   static type* get_res_type(value *arg, unsigned axis);
+  static std::string to_str(op_t op);
 
 private:
-  reduce_inst(value* arg, unsigned axis, const std::string& name, instruction* next);
-  std::string repr_impl() const { return "reduce"; }
+  reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
+  std::string repr_impl() const { return "red<" + std::to_string(axis_) + ">"; }
 
 public:
-  static instruction* create(value *arg, unsigned axis, const std::string &name = "", instruction *next = nullptr);
+  static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
   unsigned get_axis() const { return axis_; }
+  op_t get_op() const { return op_; }
 
 private:
   unsigned axis_;
+  op_t op_;
 };
 
 class select_inst: public builtin_inst {
diff --git a/include/triton/lang/ast.h b/include/triton/lang/ast.h
index 8bf96a96b..43cfc485f 100644
--- a/include/triton/lang/ast.h
+++ b/include/triton/lang/ast.h
@@ -418,22 +418,25 @@ class UnaryOp : public Expr {
   friend class LValAssigner;
 
 public:
-  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr);
+  static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
   virtual ~UnaryOp() {}
   virtual void Accept(Visitor* v);
   virtual bool IsLVal();
   ::Type *Convert();
+  static int encodeRed(int ax, int tag);
+  static void decodeRed(int info, int& ax, int& tag);
   void TypeChecking();
   void IncDecOpTypeChecking();
   void AddrOpTypeChecking();
   void DerefOpTypeChecking();
+  void ReduceOpTypeChecking();
   void TransOpTypeChecking();
   void UnaryArithmOpTypeChecking();
   void CastOpTypeChecking();
 
 protected:
-  UnaryOp(int op, Expr* operand, QualType type=nullptr)
-    : Expr(operand->Tok(), type), op_(op) {
+  UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
+    : Expr(operand->Tok(), type), op_(op), info_(info) {
     operand_ = operand;
     if (op_ != Token::CAST && op_ != Token::ADDR) {
       operand_ = MayCast(operand);
@@ -441,6 +444,7 @@ protected:
   }
 
   int op_;
+  int info_;
   Expr* operand_;
 };
diff --git a/include/triton/lang/token.h b/include/triton/lang/token.h
index 1690ba246..f11d08fc8 100644
--- a/include/triton/lang/token.h
+++ b/include/triton/lang/token.h
@@ -131,6 +131,8 @@ public:
     // TILE ARITHMETICS BEGIN
     NEWAXIS,
+    MAX,
+    MIN,
     // TILE ARITHMETICS END
 
     ALIGNAS, // _Alignas
@@ -180,6 +182,7 @@ public:
     PLUS,
     MINUS,
     CAST,
+    REDUCE,
 
     // For preprocessor
    PP_IF,
diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h
index 96ec35ef7..42ecd69f9 100644
--- a/include/triton/runtime/function.h
+++ b/include/triton/runtime/function.h
@@ -70,7 +70,7 @@ public:
   struct options_space_t {
     typedef std::pair<std::string, std::vector<std::string>> define_t;
     std::vector<define_t> defines;
-    std::vector<size_t> num_warps;
+    std::vector<int> num_warps;
   };
 
   struct options_t {
diff --git a/lib/codegen/analysis/grid.cc b/lib/codegen/analysis/grid.cc
index f90ab8822..da8516daa 100644
--- a/lib/codegen/analysis/grid.cc
+++ b/lib/codegen/analysis/grid.cc
@@ -59,16 +59,7 @@ void grids::init_c_graph(ir::instruction *v) {
     shapes = atom->get_operand(0)->get_type()->get_tile_shapes();
   else if(dynamic_cast(v))
     return;
-  else if(auto *reduce = dynamic_cast<ir::reduce_inst*>(v)) {
-    unsigned axis = reduce->get_axis();
-    ir::value *arg = reduce->get_operand(0);
-    auto in_shapes = arg->get_type()->get_tile_shapes();
-    unsigned current = 0;
-    for(unsigned i = 0; i < in_shapes.size(); i++){
-      if(i == axis)
-        continue;
-      add_constraint({reduce, current++}, {arg, i});
-    }
+  else if(dynamic_cast<ir::reduce_inst*>(v)) {
     return;
   }
   else
@@ -244,7 +235,6 @@ void grids::run(ir::module &mod) {
     unsigned size = i->get_type()->get_tile_num_elements();
     /* HMMA parameters*/
     if(fragments_.at({i, 0}) == HMMA_FRAGMENT_C){
-      /* fragments per warp */
       // try to make things as square as possible to maximize data re-use
       std::vector<unsigned> fpw = {1, 1, 1};
@@ -285,7 +275,6 @@ void grids::run(ir::module &mod) {
     if(num_warps_ != effective_num_warps)
       throw std::runtime_error("cannot create a kernel with this amount of warps");
-
   }
 
   /* Scan-line */
diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc
index ff246f4f5..0b1568354 100644
--- a/lib/codegen/selection.cc
+++ b/lib/codegen/selection.cc
@@ -923,52 +923,74 @@ void selection::lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function
 }
 
 void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
-  ir::instruction *ins = (ir::instruction*)x;
   Module *module = fn->getParent();
   std::map<indices_t, Value*> partial;
-  ir::value *op = x->get_operand(0);
-  distributed_tile* op_tile = (distributed_tile*)tmap_.at(op);
+  ir::value *arg = x->get_operand(0);
+  distributed_tile* arg_tile = (distributed_tile*)tmap_.at(arg);
+  ir::reduce_inst::op_t op = x->get_op();
+  auto accumulate = [&](Value* x, Value *y) -> Value* {
+    switch(op) {
+      case ir::reduce_inst::ADD: return builder.CreateAdd(x, y);
+      case ir::reduce_inst::SUB: return builder.CreateSub(x, y);
+      case ir::reduce_inst::MAX: return builder.CreateMaximum(x, y);
+      case ir::reduce_inst::MIN: return builder.CreateMinimum(x, y);
+      case ir::reduce_inst::FADD: return builder.CreateFAdd(x, y);
+      case ir::reduce_inst::FSUB: return builder.CreateFSub(x, y);
+      case ir::reduce_inst::FMAX: return builder.CreateSelect(builder.CreateFCmpOGT(x, y), x, y);
+      case ir::reduce_inst::FMIN: return builder.CreateSelect(builder.CreateFCmpOLT(x, y), x, y);
+      default: break;
+    }
+    assert(false);
+    return nullptr;
+  };
+
   unsigned axis = x->get_axis();
 
   // reduce within thread
-  op_tile->for_each([&](indices_t idx) {
+  arg_tile->for_each([&](indices_t idx) {
     indices_t pidx = idx;
-    pidx.erase(pidx.begin() + axis);
-    Value *current = op_tile->get_value(idx);
+    pidx[axis] = builder.getInt32(0);
+    Value *current = arg_tile->get_value(idx);
     // current partial result is not initialized -- create
     if(partial.find(pidx) == partial.end())
       partial[pidx] = current;
     // current partial result is initialized -- accumulate
     else
-      partial[pidx] = builder.CreateFAdd(partial[pidx], current);
+      partial[pidx] = accumulate(partial[pidx], current);
   });
 
+  // depth
+  unsigned shape_ax = arg->get_type()->get_tile_shapes()[axis];
+  unsigned per_thread = arg_tile->axis(axis).values.size();
+  unsigned depth = shape_ax / per_thread;
+
+  // shapes
+  auto shared_shapes = arg_tile->get_shapes();
+  shared_shapes[axis] = depth;
+
   // reduce within blocks
   unsigned addr_space = sh_mem_ptr_->getType()->getPointerAddressSpace();
   Type *res_ty = builder.getFloatTy();
   Value *base_ptr = builder.CreateBitCast(sh_mem_ptr_, PointerType::get(res_ty, addr_space));
   for(auto& x: partial) {
     // current element being computed
-    Value *lane = axes_.at(params_->get_param_group(op, axis)).thread_id;
+    Value *lane = axes_.at(params_->get_param_group(arg, axis)).thread_id;
     Value *&result = x.second;
     indices_t write_idx = x.first;
-    write_idx.insert(write_idx.begin() + axis, lane);
-
+    write_idx[axis] = lane;
     // shared memory write pointer
-    Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
+    Value *write_offset = shared_tile::shared_offset(builder, shared_shapes, write_idx);
     Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
-
     // initialize shared memory
     tgt_->add_barrier(module, builder);
     builder.CreateStore(result, write_ptr);
     // build result
-    unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value();
     for(unsigned i = depth/2; i > 0; i >>= 1){
       // current indices
       indices_t current(write_idx.size(), builder.getInt32(0));
       current[axis] = builder.getInt32(i);
       // shared memory offset
-      Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), current);
+      Value *read_offset = shared_tile::shared_offset(builder, shared_shapes, current);
       Value *is_active = builder.CreateICmpULT(lane, builder.getInt32(i));
       read_offset = builder.CreateSelect(is_active, read_offset, builder.getInt32(0));
       // shared memory read pointer
@@ -976,25 +998,21 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
       tgt_->add_barrier(module, builder);
       Value *next = builder.CreateLoad(read_ptr);
       // accumulate
-      result = builder.CreateFAdd(result, next);
+      result = accumulate(result, next);
       // write back
       builder.CreateStore(result, write_ptr);
     }
-
-    // result is on the first lane of shared memory
-    indices_t final = write_idx;
-    final[axis] = builder.getInt32(0);
-    Value *read_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), final);
-    Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
-    tgt_->add_barrier(module, builder);
-    result = builder.CreateLoad(read_ptr);
-    if(tmap_.find(ins) == tmap_.end())
-      vmap_[ins] = result;
-    else{
-      distributed_tile *ti = (distributed_tile*)tmap_[ins];
-      ti->set_value(x.first, result);
-    }
   }
+  tgt_->add_barrier(module, builder);
+
+  distributed_tile* x_tile = (distributed_tile*)tmap_.at(x);
+  x_tile->for_each([&](indices_t idx) {
+    indices_t red_idx = idx;
+    red_idx.insert(red_idx.begin() + axis, builder.getInt32(0));
+    Value *read_offset = shared_tile::shared_offset(builder, shared_shapes, red_idx);
+    Value *read_ptr = builder.CreateGEP(base_ptr, read_offset);
+    x_tile->set_value(idx, builder.CreateLoad(read_ptr));
+  });
 }
 
 void selection::lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
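The lower_reduce hunk above works in two phases: each thread first folds the values it owns with accumulate(), then the per-lane partials are combined through shared memory in log2(depth) steps, with the active lane count halved after every barrier. A minimal CPU sketch of that schedule, not part of the patch (names are illustrative, and addition stands in for any accumulate op):

#include <cstdio>
#include <vector>

// Models the loop "for(unsigned i = depth/2; i > 0; i >>= 1)" above:
// lane l stays active while l < i and folds in the element i slots away.
float tree_reduce(std::vector<float> scratch) {
  unsigned depth = scratch.size();              // assumed a power of two
  for(unsigned i = depth / 2; i > 0; i >>= 1) {
    for(unsigned lane = 0; lane < i; lane++)    // is_active: lane < i
      scratch[lane] += scratch[lane + i];       // accumulate(result, next)
    // on the GPU, tgt_->add_barrier(...) separates the steps
  }
  return scratch[0];                            // lane 0 holds the result
}

int main() {
  std::printf("%f\n", tree_reduce({1, 2, 3, 4, 5, 6, 7, 8})); // 36.000000
}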
diff --git a/lib/ir/builder.cc b/lib/ir/builder.cc
index 458365a60..8f42e263c 100644
--- a/lib/ir/builder.cc
+++ b/lib/ir/builder.cc
@@ -323,8 +323,8 @@ value *builder::create_sqrt(value *A, const std::string &name) {
   return insert(sqrt_inst::create(A, name));
 }
 
-value *builder::create_reduce(value *A, unsigned axis, const std::string &name) {
-  return insert(reduce_inst::create(A, axis, name));
+value *builder::create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name) {
+  return insert(reduce_inst::create(A, op, axis, name));
 }
 
 value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){
diff --git a/lib/ir/instructions.cc b/lib/ir/instructions.cc
index e7e5de1f2..9df26dc1a 100644
--- a/lib/ir/instructions.cc
+++ b/lib/ir/instructions.cc
@@ -615,6 +615,23 @@ instruction* sqrt_inst::create(value *arg, const std::string &name, instruction
 //===----------------------------------------------------------------------===//
 //                               reduce instructions
 //===----------------------------------------------------------------------===//
+
+std::string reduce_inst::to_str(op_t op) {
+  switch (op) {
+    case ADD: return "+";
+    case SUB: return "-";
+    case MAX: return "imax";
+    case MIN: return "imin";
+    case FADD: return "+";
+    case FSUB: return "-";
+    case FMAX: return "fmax";
+    case FMIN: return "fmin";
+    default: break;
+  }
+  assert(false);
+  return "";
+}
+
 type* reduce_inst::get_res_type(value *arg, unsigned axis) {
   ir::tile_type::tile_shapes_t shapes = arg->get_type()->get_tile_shapes();
   shapes.erase(shapes.begin() + axis);
@@ -625,14 +642,15 @@ type* reduce_inst::get_res_type(value *arg, unsigned axis) {
   return tile_type::get(scalar_ty, shapes);
 }
 
-reduce_inst::reduce_inst(value *arg, unsigned axis, const std::string &name, instruction *next)
+reduce_inst::reduce_inst(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next)
   : builtin_inst(get_res_type(arg, axis), 1, 1, name, next),
+    op_(op),
     axis_(axis){
   set_operand(0, arg);
 }
 
-instruction* reduce_inst::create(value *arg, unsigned axis, const std::string &name, instruction *next) {
-  return new reduce_inst(arg, axis, name, next);
+instruction* reduce_inst::create(value *arg, op_t op, unsigned axis, const std::string &name, instruction *next) {
+  return new reduce_inst(arg, op, axis, name, next);
 }
 
diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc
index e1f008c36..b0a50adc3 100644
--- a/lib/lang/ast.cc
+++ b/lib/lang/ast.cc
@@ -448,6 +448,8 @@ void BinaryOp::RangeOpTypeChecking() {
 }
 
 void BinaryOp::MaskedDerefOpTypeChecking() {
+//  auto lhsTileType = lhs_->Type()->ToTile();
+//  auto rhsTileType = rhs_->Type()->ToTile();
   ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
   ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
   auto lhsType = lhsScalType->ToArithm();
@@ -572,8 +574,8 @@ void BinaryOp::AssignOpTypeChecking() {
  * Unary Operators
  */
 
-UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
-  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type);
+UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) {
+  auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info);
   ret->pool_ = &unaryOpPool;
 
   ret->TypeChecking();
@@ -581,6 +583,18 @@ UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
 }
 
+int UnaryOp::encodeRed(int ax, int tag) {
+  int result = 0;
+  result |= ax;
+  result |= tag << 16;
+  return result;
+}
+
+void UnaryOp::decodeRed(int info, int& ax, int& tag) {
+  ax = info & 0x0000FFFF;
+  tag = (info & 0xFFFF0000) >> 16;
+}
+
 bool UnaryOp::IsLVal() {
   // Only deref('*') could be lvalue;
   return op_ == Token::DEREF;
 }
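encodeRed/decodeRed above pack a reduction descriptor into the single int info_ slot of UnaryOp: the axis goes in the low 16 bits and the operator token tag in the high 16 bits. A standalone round-trip check of the packing (the tag value 42 is a hypothetical stand-in for a real Token tag):

#include <cassert>

int encodeRed(int ax, int tag) { return ax | (tag << 16); }

void decodeRed(int info, int& ax, int& tag) {
  ax = info & 0x0000FFFF;
  tag = (info & 0xFFFF0000) >> 16;
}

int main() {
  int ax, tag;
  decodeRed(encodeRed(1, 42), ax, tag); // axis 1, hypothetical tag 42
  assert(ax == 1 && tag == 42);         // round-trips while both fit in 16 bits
}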
@@ -626,6 +640,9 @@ void UnaryOp::TypeChecking() {
   case '^':
     return TransOpTypeChecking();
 
+  case Token::REDUCE:
+    return ReduceOpTypeChecking();
+
   default:
     assert(false);
   }
@@ -663,6 +680,16 @@ void UnaryOp::DerefOpTypeChecking() {
   type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
 }
 
+void UnaryOp::ReduceOpTypeChecking() {
+  int ax, tag;
+  decodeRed(info_, ax, tag);
+  auto tileType = operand_->Type()->ToTile();
+  if(!tileType)
+    Error(this, "array expected for reduction operation");
+  auto shape = tileType->Shape();
+  shape.erase(shape.begin() + ax);
+  type_ = TileType::New(shape, tileType->Derived());
+}
 
 void UnaryOp::TransOpTypeChecking() {
   auto tileType = operand_->Type()->ToTile();
diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc
index 228bd69dd..8384dd710 100644
--- a/lib/lang/code_gen.cc
+++ b/lib/lang/code_gen.cc
@@ -154,12 +154,24 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
   error_not_implemented();
 }
 
+ir::reduce_inst::op_t reduce_op(int tag, bool is_float) {
+  using ir::reduce_inst;
+  switch(tag){
+    case Token::ADD: return is_float ? reduce_inst::FADD : reduce_inst::ADD;
+    case Token::SUB: return is_float ? reduce_inst::FSUB : reduce_inst::SUB;
+    case Token::MAX: return is_float ? reduce_inst::FMAX : reduce_inst::MAX;
+    case Token::MIN: return is_float ? reduce_inst::FMIN : reduce_inst::MIN;
+    default: break;
+  }
+  should_not_happen();
+  return reduce_inst::op_t();
+}
+
 void Generator::VisitUnaryOp(UnaryOp* unary) {
-  // recursion
   Visit(unary->operand_);
-  ir::value* op = ret_;
-
+  ir::value* arg = ret_;
+  ir::type *arg_ty = arg->get_type();
+  ir::type *arg_scal_ty = arg_ty->get_scalar_ty();
   // return
   switch (unary->op_) {
   case Token::PREFIX_INC: return error_not_implemented();
@@ -167,13 +179,20 @@
   case Token::POSTFIX_INC: return error_not_implemented();
   case Token::POSTFIX_DEC: return error_not_implemented();
   case Token::ADDR: return error_not_implemented();
-  case Token::DEREF: return set_ret(bld_->create_load(op));
+  case Token::DEREF: return set_ret(bld_->create_load(arg));
   case Token::PLUS: return error_not_implemented();
   case Token::MINUS: return error_not_implemented();
-  case '~': return set_ret(bld_->create_neg(op));
-  case '!': return set_ret(bld_->create_not(op));
-  case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_)));
-  case '^': return set_ret(bld_->create_trans(op));
+  case '~': return set_ret(bld_->create_neg(arg));
+  case '!': return set_ret(bld_->create_not(arg));
+  case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
+  case '^': return set_ret(bld_->create_trans(arg));
+  case Token::REDUCE: {
+    int ax, tag;
+    UnaryOp::decodeRed(unary->info_, ax, tag);
+    bool is_float = arg_scal_ty->is_floating_point_ty();
+    ir::reduce_inst::op_t op = reduce_op(tag, is_float);
+    return set_ret(bld_->create_reduce(arg, op, ax));
+  }
   default: error_not_implemented();
   }
   return error_not_implemented();
@@ -412,16 +431,41 @@ void Generator::Gen(ir::module *mod) {
 
 ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) {
+  if(src->get_type() == dst_ty)
+    return src;
   if(dst_ty->is_tile_ty()) {
     ir::type *src_ty = src->get_type();
     auto dst_shapes = dst_ty->get_tile_shapes();
     if(!src_ty->is_tile_ty())
       return bld_->create_splat(src, dst_shapes);
     auto src_shapes = src_ty->get_tile_shapes();
-    if(src_shapes.size() != dst_shapes.size())
-      return bld_->create_reshape(src, dst_shapes);
-    else
+    if(src_shapes.size() != dst_shapes.size()){
+      unsigned src_numel = 1;
+      for(unsigned s: src_shapes)
+        src_numel *= s;
+      unsigned dst_numel = 1;
+      for(unsigned s: dst_shapes)
+        dst_numel *= s;
+      if(src_numel == dst_numel)
+        return bld_->create_reshape(src, dst_shapes);
+      else {
+        auto padded_shapes = src_shapes;
+        while(padded_shapes.size() != dst_shapes.size())
+          padded_shapes.insert(padded_shapes.begin(), 1);
+        // check that broadcast is legal
+        for(size_t d = 0; d < padded_shapes.size(); d++){
+          if(dst_shapes[d] != padded_shapes[d] &&
+             padded_shapes[d] != 1)
+            should_not_happen();
+        }
+        // pad and broadcast
+        ir::value *padded = bld_->create_reshape(src, padded_shapes);
+        return bld_->create_broadcast(padded, dst_shapes);
+      }
+    }
+    else{
       return bld_->create_broadcast(src, dst_shapes);
+    }
   }
   return src;
 }
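When the ranks differ and the element counts differ too, GenBroadcastOp now left-pads the source shape with 1s to the destination rank and requires every padded dimension to match the destination's or be 1 -- numpy-style broadcasting. A sketch of just that legality rule (illustrative helper, not the patch's API):

#include <cassert>
#include <vector>

bool broadcastable(std::vector<unsigned> src, const std::vector<unsigned>& dst) {
  while(src.size() < dst.size())
    src.insert(src.begin(), 1);   // pad leading dimensions with 1
  for(size_t d = 0; d < src.size(); d++)
    if(src[d] != dst[d] && src[d] != 1)
      return false;               // the case that triggers should_not_happen()
  return true;
}

int main() {
  assert(broadcastable({16}, {8, 16}));  // {16} -> {1, 16} -> broadcast to {8, 16}
  assert(!broadcastable({8}, {8, 16})); // {8} -> {1, 8}: 8 is neither 16 nor 1
}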
diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc
index fed1422fc..a30258c3d 100644
--- a/lib/lang/parser.cc
+++ b/lib/lang/parser.cc
@@ -453,21 +453,52 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
   TileType::ShapeInt shape;
   size_t i = 0;
   const Token* tok;
+  std::vector<std::pair<int, int>> redInfo;
   do {
     tok = ts_.Next();
-    if(tok->tag_ == ':')
-      shape.push_back(lhsShape[i++]);
-    else if(tok->tag_ == Token::NEWAXIS)
-      shape.push_back(1);
-    else
-      Error(tok, "only ':' and newaxis are supported in subscripts");
+    switch(tok->tag_) {
+      case ':':
+        shape.push_back(lhsShape[i++]);
+        break;
+
+      case Token::NEWAXIS:
+        shape.push_back(1);
+        break;
+
+      case Token::ADD:
+      case Token::SUB:
+      case Token::MAX:
+      case Token::MIN:{
+        int info = UnaryOp::encodeRed(i, tok->tag_);
+        redInfo.push_back({i, info});
+        shape.push_back(lhsShape[i++]);
+        break;
+      }
+
+      default:
+        Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
+        break;
+    }
   }while(ts_.Try(','));
   ts_.Expect(']');
   if(lhsShape.size() > i)
     Error(tok, "broadcasting not using all operand axes");
   // create ret tile
-  TileType *retType = TileType::New(shape, lhsQual);
-  return UnaryOp::New(Token::CAST, lhs, retType);
+  Expr* res = lhs;
+  for(auto r: redInfo){
+    shape.erase(shape.begin() + r.first);
+    Type *retType;
+    if(shape.empty())
+      retType = lhsQual.GetPtr();
+    else
+      retType = TileType::New(shape, lhsQual);
+    res = UnaryOp::New(Token::REDUCE, res, retType, r.second);
+  }
+  if(!shape.empty()){
+    TileType *retType = TileType::New(shape, lhsQual);
+    res = UnaryOp::New(Token::CAST, res, retType);
+  }
+  return res;
 }
diff --git a/lib/lang/token.cc b/lib/lang/token.cc
index b9f3c8467..8b61aa098 100644
--- a/lib/lang/token.cc
+++ b/lib/lang/token.cc
@@ -54,6 +54,8 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
   { "_Noreturn", Token::NORETURN },
   { "_Static_assert", Token::STATIC_ASSERT },
   { "_Thread_local", Token::THREAD },
+  { "max", Token::MAX },
+  { "min", Token::MIN },
 };
 
 const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
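Taken together, the parser and lexer changes let Triton-C spell a reduction directly inside a subscript: the reduced axis carries the operator (+, -, max, min) where ':' would normally appear, and ReduceOpTypeChecking erases that axis from the result shape. For a [TM, TN] tile pointed to by px, the RED macro in the tests below expands to exactly this kind of list; a sketch of the surface syntax only:

  *py = (*px)[+, :];    // add-reduce axis 0: result shape [TN]
  *py = (*px)[:, max];  // max-reduce axis 1: result shape [TM]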
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index 114626dce..ea84eac00 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -157,6 +157,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
     for(auto it: opt_space_.defines)
       cpp.AddMacro(it.first, &opt.defines.at(it.first));
     cpp.Process(tokens);
+//    tokens.Print(stdout);
     // parse
     Parser parser(tokens);
     parser.Parse();
@@ -164,11 +165,7 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
     auto ir = make_ir(parser);
     // binary code-gen
     std::unique_ptr<driver::module> bin;
-    try{
-      bin = make_bin(*ir, stream->context(), opt);
-    }catch(const std::runtime_error& e) {
-      return;
-    }
+    bin = make_bin(*ir, stream->context(), opt);
     // kernel uses too much resources
     if(!bin)
       return;
@@ -204,6 +201,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   codegen::transform::peephole peephole;
   codegen::transform::reassociate reassociate(&alignment_info, &grids);
   codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get());
+//  ir::print(module, std::cout);
   // run passes
   peephole.run(module);
   dce.run(module);
diff --git a/tests/common/src/reduce.h b/tests/common/src/reduce.h
new file mode 100644
index 000000000..3a77e960e
--- /dev/null
+++ b/tests/common/src/reduce.h
@@ -0,0 +1,27 @@
+namespace src {
+
+    const char *reduce1d =
+R"(
+void reduce1d(TYPE * X __noalias __readonly __aligned(16),
+              TYPE * Y __noalias __readonly __aligned(16),
+              int N) {
+}
+)";
+
+
+    const char *reduce2d =
+R"(
+void reduce2d(TYPE * X __noalias __readonly __aligned(16),
+              TYPE * Y __noalias __writeonly __aligned(16),
+              int M, int N, int ldx) {
+  int ridm = get_program_id(0);
+  int ridn = get_program_id(1);
+  int rm[TM] = ridm * TM + 0 ... TM;
+  int rn[TN] = ridn * TN + 0 ... TN;
+  TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx;
+  TYPE* py[TY] = Y + RY;
+  *py = (*px)[RED];
+}
+)";
+
+}
diff --git a/tests/common/util.h b/tests/common/util.h
index d8ffef090..0d06b47a8 100644
--- a/tests/common/util.h
+++ b/tests/common/util.h
@@ -9,6 +9,10 @@ namespace drv = triton::driver;
 namespace rt = triton::runtime;
 
+/* ------------------------
+ *      Launch Grid
+ * ------------------------ */
+
 inline size_t ceil(size_t x, size_t y) {
   return (x + y - 1) / y;
 };
@@ -26,12 +30,116 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
   };
 }
 
+
+/* ------------------------
+ *  Tensor Initialization
+ * ------------------------ */
+
+template<class T>
+void init_rand(std::vector<T>& x) {
+  for(size_t i = 0; i < x.size(); i++)
+    x[i] = static_cast<T>((double)rand()/RAND_MAX);
+}
+
+template<class T>
+void init_zeros(std::vector<T>& x) {
+  for(size_t i = 0; i < x.size(); i++)
+    x[i] = 0;
+}
+
+/* ------------------------
+ *       Loop Nests
+ * ------------------------ */
+
+void _loop_nest(std::vector<int> const & ranges,
+                std::function<void(std::vector<int> const &)> const & f){
+  int D = ranges.size();
+  std::vector<int> values(D, 0);
+  // Start with innermost loop
+  int i = D - 1;
+  while(true){
+    // Execute function
+    f(values);
+    while(values[i]++ == ranges[i] - 1){
+      if(i == 0)
+        return;
+      values[i--] = 0;
+    }
+    i = D - 1;
+  }
+}
+
+/* -----------------------
+ *    TENSOR INDEXING
+ * ----------------------- */
+
 enum order_t {
   ROWMAJOR,
   COLMAJOR
 };
 
+int offset(const std::vector<int>& idx, const std::vector<int>& shapes) {
+  int result = idx[0];
+  for(int i = 1; i < idx.size(); i++)
+    result += idx[i]*shapes[i-1];
+  return result;
+}
+
+/* -----------------------
+ *   REDUCTION HELPERS
+ * ----------------------- */
+
+enum reduce_op_t {
+  ADD,
+  MAX,
+  MIN
+};
+
+std::string to_str(reduce_op_t op) {
+  switch (op) {
+    case ADD: return "+";
+    case MAX: return "max";
+    case MIN: return "min";
+    default: break;
+  }
+  assert(false);
+  return "";
+}
+
+template<class T>
+std::function<T(T,T)> get_accumulator(reduce_op_t op) {
+  switch (op) {
+    case ADD: return [](T x, T y) { return x + y; };
+    case MAX: return [](T x, T y) { return std::max(x, y); };
+    case MIN: return [](T x, T y) { return std::min(x, y); };
+    default: break;
+  }
+  assert(false);
+  return std::function<T(T,T)>();
+}
+
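A quick illustration of how the new helpers compose, assuming this header is included: _loop_nest acts as an odometer over ranges, visiting multi-indices with the last index varying fastest, and get_accumulator returns the matching fold function. A hypothetical driver, not part of the tests:

#include <cstdio>

int main() {
  auto acc = get_accumulator<float>(MAX);
  float m = 0;
  _loop_nest({2, 3}, [&](const std::vector<int>& idx) {
    // visits {0,0} {0,1} {0,2} {1,0} {1,1} {1,2}
    m = acc(m, float(idx[0] + idx[1]));
  });
  std::printf("%f\n", m); // 3.000000
}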
+/* -----------------------
+ *   TENSOR COMPARISON
+ * ----------------------- */
+
+template<class T>
+bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
+if(hc.size() != rc.size())
+  return false;
+for(size_t i = 0; i < hc.size(); i++)
+  if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
+    std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
+    return false;
+  }
+return true;
+}
+
+/* -----------------------
+ *    PRETTY PRINTING
+ * ----------------------- */
+
 namespace aux{
 template<std::size_t...> struct seq{};
@@ -57,22 +165,23 @@ auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
   return os << ")";
 }
 
-
-namespace testing {
-
-    template<class T>
-    bool diff(const std::vector<T>& hc, const std::vector<T>& rc) {
-    if(hc.size() != rc.size())
-      return false;
-    for(size_t i = 0; i < hc.size(); i++)
-      if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
-        std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
-
-        return false;
-      }
-    return true;
-    }
-
+template<class Ch, class Tr, class T>
+std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, const std::vector<T>& vec) {
+  os << "{";
+  for(size_t i = 0; i < vec.size(); i++){
+    if(i > 0)
+      os << ", ";
+    os << vec[i];
+  }
+  os << "}";
+  return os;
 }
 
+template<class Ch, class Tr>
+std::basic_ostream<Ch, Tr>& operator<<(std::basic_ostream<Ch, Tr>& os, reduce_op_t op) {
+  return os << to_str(op);
+}
+
+
 
 #endif
diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt
index 78fbc79d1..3efbdd71f 100644
--- a/tests/unit/CMakeLists.txt
+++ b/tests/unit/CMakeLists.txt
@@ -1,4 +1,4 @@
-foreach(PROG dot)
+foreach(PROG dot reduce)
   set(TARGET unit_${PROG})
   add_executable(${TARGET} ${PROG}.cc)
   set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
diff --git a/tests/unit/dot.cc b/tests/unit/dot.cc
index 69b8cf2d7..e1b0a8bb5 100644
--- a/tests/unit/dot.cc
+++ b/tests/unit/dot.cc
@@ -50,7 +50,7 @@ void cpu_ref(bool AT_, bool BT_, size_t M, size_t N, size_t K,
 }
 
-bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, size_t nwarp){
+bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K, int32_t TM, int32_t TN, int32_t TK, int nwarp){
   typedef float NumericT;
   std::string ty = "float";
   size_t dt_nbytes = sizeof(NumericT);
@@ -62,12 +62,9 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_
   int32_t ldb = BT ? N : K;
   int32_t ldc = M;
   srand(0);
-  for(size_t i = 0; i < ha.size(); i++)
-    ha[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
-  for(size_t i = 0; i < hb.size(); i++)
-    hb[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
-  for(size_t i = 0; i < hc.size(); i++)
-    hc[i] = static_cast<NumericT>((double)0);
+  init_rand(ha);
+  init_rand(hb);
+  init_rand(hc);
   auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hc.size()*dt_nbytes));
   auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, ha.size()*dt_nbytes));
   auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, hb.size()*dt_nbytes));
@@ -94,7 +91,7 @@ bool do_test(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_
   stream->read(&*dc, true, 0, hc);
   std::vector<NumericT> rc(hc.size());
   cpu_ref(AT, BT, M, N, K, rc, ha, hb);
-  return testing::diff(hc, rc);
+  return diff(hc, rc);
 }
 
 int main() {
diff --git a/tests/unit/reduce.cc b/tests/unit/reduce.cc
new file mode 100644
index 000000000..63b870fe5
--- /dev/null
+++ b/tests/unit/reduce.cc
@@ -0,0 +1,106 @@
+#include
+#include
+#include
+#include
+#include
+#include "triton/driver/backend.h"
+#include "triton/driver/stream.h"
+#include "triton/tools/bench.hpp"
+#include "triton/external/half.hpp"
+#include "triton/runtime/function.h"
+#include "src/reduce.h"
+#include "cuda/cublas.h"
+#include "util.h"
+
+namespace drv = triton::driver;
+namespace rt = triton::runtime;
+
+template<class T>
+void reduce_nd(std::vector<T> &y, const std::vector<T> &x, reduce_op_t op, size_t axis, const std::vector<int>& shapes) {
+  assert(axis <= shapes.size() - 1);
+  // remove shape at index axis to get outer dimensions
+  std::vector<int> outer = shapes;
+  outer.erase(outer.begin() + axis);
+  // retrieve shape at index axis to get inner dimension
+  int inner = shapes[axis];
+  // accumulation function
+  auto acc = get_accumulator<T>(op);
+  // iterate over outer dimensions
+  _loop_nest(outer, [&](const std::vector<int>& y_idx) {
+    T ret = 0;
+    auto x_idx = y_idx;
+    x_idx.insert(x_idx.begin() + axis, 0);
+    // accumulate over inner dimensions
+    for(int z = 0; z < inner; z++){
+      x_idx[axis] = z;
+      ret = acc(ret, x[offset(x_idx, shapes)]);
+    }
+    y[offset(y_idx, outer)] = ret;
+  });
+}
+
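Note that offset() in util.h linearizes indices column-major (idx[0] is the contiguous one), matching the rm[:, newaxis] + rn[newaxis, :] * ldx layout in the reduce2d kernel. A worked check of reduce_nd under that layout, with values chosen by hand:

  // shapes = {2, 3}: x(i,j) lives at i + j*2, so
  // x = {x00, x10, x01, x11, x02, x12}
  std::vector<float> x = {1, 2, 3, 4, 5, 6};
  std::vector<float> y(2);
  reduce_nd(y, x, ADD, 1, {2, 3}); // y = {1+3+5, 2+4+6} = {9, 12}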
to_str(op) : ":"; + } + opt.defines.push_back({"RED", {RED}}); + opt.num_warps = {nwarp}; + rt::function function(src::reduce2d, opt); + function({&*dx, &*dy, shape[0], shape[1], shape[0]}, grid2d(shape[0], shape[1]), stream); + stream->synchronize(); + stream->read(&*dy, true, 0, hy); + reduce_nd(ry, hx, op, axis, shape); + return diff(hy, ry); +} + +int main() { + // initialize default compute device + auto context = triton::driver::backend::contexts::get_default(); + triton::driver::stream* stream = triton::driver::stream::create(context); + // shapes to benchmark + typedef std::tuple, int, reduce_op_t> config_t; + std::vector configs = { + config_t{{32, 32}, 0, MAX}, + config_t{{32, 32}, 1, ADD}, + config_t{{32, 64}, 0, ADD}, + config_t{{64, 32}, 1, ADD} + }; + // does the work + int axis; + std::vector shape; + reduce_op_t op; + for(const auto& c: configs){ + std::tie(shape, axis, op) = c; + std::cout << "Testing " << c << " ... " << std::flush; + if(do_test(stream, shape, axis, op, 1)) + std::cout << " Pass! " << std::endl; + else + std::cout << " Fail! " << std::endl; + } +}