// Patch overview. This text is a git unified diff whose original line breaks were lost in this
// copy; the payload below is left untouched. Files changed, as visible in the diff:
//   include/triton/lang/ast.h   - UnaryOp::New and the UnaryOp constructor take an extra
//                                 `int info = 0`, stored in the new member `info_`. Declares the
//                                 static helpers encodeRed/decodeRed and ReduceOpTypeChecking().
//   include/triton/lang/token.h - REDUCE_ADD / REDUCE_MAX / REDUCE_MIN are replaced by a single
//                                 REDUCE token.
//   lib/codegen/selection.cc    - lower_reduce no longer reads the "wpt.d<axis>" parameter;
//                                 `depth` is the tile shape along `axis` divided by
//                                 op_tile->axis(axis).values.size().
//   lib/lang/ast.cc             - encodeRed ORs `ax` into the result and `tag << 16` above it;
//                                 decodeRed recovers them with the masks 0x0000FFFF and
//                                 0xFFFF0000 (shifted right by 16). TypeChecking dispatches
//                                 Token::REDUCE to ReduceOpTypeChecking, which reports an error
//                                 for a non-tile operand and otherwise removes axis `ax` from the
//                                 operand's tile shape to build the result type.
//   lib/lang/code_gen.cc        - VisitUnaryOp lowers Token::REDUCE to create_reduce(op, ax).
//                                 GenBroadcastOp returns `src` unchanged when its type already
//                                 equals dst_ty; when ranks differ it reshapes if the element
//                                 counts match, otherwise it left-pads the source shape with 1s,
//                                 checks each padded dim equals the destination dim or 1, then
//                                 reshapes and broadcasts.
//   lib/lang/parser.cc          - ParseSubScripting accepts Token::ADD / Token::SUB as subscript
//                                 symbols, records (axis, encoded info) pairs in redInfo, and
//                                 wraps `lhs` in one REDUCE UnaryOp per pair before the final
//                                 CAST (the CAST is skipped when the remaining shape is empty).
//   lib/runtime/function.cc     - adds a commented-out ir::print(module, std::cout) call.
//   tests/common/src/reduce.h   - `py` becomes a 1-D tile `TYPE* py[TM] = Y + rm`.
//   tests/common/util.h         - adds init_zeros, which assigns 0 to every element.
//   tests/unit/reduce.cc        - adds cpu_ref (sums x[m + n*M] over n into y[m]); do_test now
//                                 zero-fills hy, passes {&*dx, &*dy, M, N, M} to the kernel,
//                                 reads dy back into hy and returns testing::diff(hy, ry).
//
// NOTE(review): ParseSubScripting erases `shape.begin() + r.first` once per redInfo entry, using
//   the axis index recorded while parsing. With more than one reduction the later indices are
//   applied to an already-shortened `shape` - confirm this is intended.
// NOTE(review): VisitUnaryOp decodes `tag` but passes only `ax` to create_reduce, so the ADD/SUB
//   symbol recorded by the parser is not used at that point - confirm.
// NOTE(review): ReduceOpTypeChecking does not range-check `ax` before
//   `shape.erase(shape.begin() + ax)` and always assigns `type_ = TileType::New(shape, ...)`,
//   including when the erase leaves `shape` empty; ParseSubScripting passes lhsQual.GetPtr()
//   instead of a TileType in that case. How TileType::New treats an empty shape is not visible
//   here - verify.
diff --git a/include/triton/lang/ast.h b/include/triton/lang/ast.h index 8bf96a96b..43cfc485f 100644 --- a/include/triton/lang/ast.h +++ b/include/triton/lang/ast.h @@ -418,22 +418,25 @@ class UnaryOp : public Expr { friend class LValAssigner; public: - static UnaryOp* New(int op, Expr* operand, QualType type=nullptr); + static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0); virtual ~UnaryOp() {} virtual void Accept(Visitor* v); virtual bool IsLVal(); ::Type *Convert(); + static int encodeRed(int ax, int tag); + static void decodeRed(int info, int& ax, int& tag); void TypeChecking(); void IncDecOpTypeChecking(); void AddrOpTypeChecking(); void DerefOpTypeChecking(); + void ReduceOpTypeChecking(); void TransOpTypeChecking(); void UnaryArithmOpTypeChecking(); void CastOpTypeChecking(); protected: - UnaryOp(int op, Expr* operand, QualType type=nullptr) - : Expr(operand->Tok(), type), op_(op) { + UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0) + : Expr(operand->Tok(), type), op_(op), info_(info) { operand_ = operand; if (op_ != Token::CAST && op_ != Token::ADDR) { operand_ = MayCast(operand); @@ -441,6 +444,7 @@ protected: } int op_; + int info_; Expr* operand_; }; diff --git a/include/triton/lang/token.h b/include/triton/lang/token.h index 5724c50e3..602113f93 100644 --- a/include/triton/lang/token.h +++ b/include/triton/lang/token.h @@ -180,9 +180,7 @@ public: PLUS, MINUS, CAST, - REDUCE_ADD, - REDUCE_MAX, - REDUCE_MIN, + REDUCE, // For preprocessor PP_IF, diff --git a/lib/codegen/selection.cc b/lib/codegen/selection.cc index ff246f4f5..02611444f 100644 --- a/lib/codegen/selection.cc +++ b/lib/codegen/selection.cc @@ -962,7 +962,9 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, tgt_->add_barrier(module, builder); builder.CreateStore(result, write_ptr); // build result - unsigned depth = params_->get_param(op, "wpt.d" + std::to_string(axis))->get_value(); + unsigned shape_ax = 
op->get_type()->get_tile_shapes()[axis]; + unsigned per_thread = op_tile->axis(axis).values.size(); + unsigned depth = shape_ax / per_thread; for(unsigned i = depth/2; i > 0; i >>= 1){ // current indices indices_t current(write_idx.size(), builder.getInt32(0)); diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc index e1f008c36..b0a50adc3 100644 --- a/lib/lang/ast.cc +++ b/lib/lang/ast.cc @@ -448,6 +448,8 @@ void BinaryOp::RangeOpTypeChecking() { } void BinaryOp::MaskedDerefOpTypeChecking() { +// auto lhsTileType = lhs_->Type()->ToTile(); +// auto rhsTileType = rhs_->Type()->ToTile(); ::Type* lhsScalType = TryExtractScalarType(this, lhs_); ::Type* rhsScalType = TryExtractScalarType(this, rhs_); auto lhsType = lhsScalType->ToArithm(); @@ -572,8 +574,8 @@ void BinaryOp::AssignOpTypeChecking() { * Unary Operators */ -UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) { - auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type); +UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) { + auto ret = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info); ret->pool_ = &unaryOpPool; ret->TypeChecking(); @@ -581,6 +583,18 @@ UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) { } +int UnaryOp::encodeRed(int ax, int tag) { + int result = 0; + result |= ax; + result |= tag << 16; + return result; +} + +void UnaryOp::decodeRed(int info, int& ax, int& tag) { + ax = info & 0x0000FFFF; + tag = (info & 0xFFFF0000) >> 16; +} + bool UnaryOp::IsLVal() { // Only deref('*') could be lvalue; return op_ == Token::DEREF; @@ -626,6 +640,9 @@ void UnaryOp::TypeChecking() { case '^': return TransOpTypeChecking(); + case Token::REDUCE: + return ReduceOpTypeChecking(); + default: assert(false); } @@ -663,6 +680,16 @@ void UnaryOp::DerefOpTypeChecking() { type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr()); } +void UnaryOp::ReduceOpTypeChecking() { + int ax, tag; + decodeRed(info_, ax, tag); + auto tileType = operand_->Type()->ToTile(); + 
if(!tileType) + Error(this, "array expected for reduction operation"); + auto shape = tileType->Shape(); + shape.erase(shape.begin() + ax); + type_ = TileType::New(shape, tileType->Derived()); +} void UnaryOp::TransOpTypeChecking() { auto tileType = operand_->Type()->ToTile(); diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index 228bd69dd..56acb1c03 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -174,6 +174,11 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { case '!': return set_ret(bld_->create_not(op)); case Token::CAST: return set_ret(GenCastOp(op, GenIRType(unary->Type(), *ctx_))); case '^': return set_ret(bld_->create_trans(op)); + case Token::REDUCE: { + int ax, tag; + UnaryOp::decodeRed(unary->info_, ax, tag); + return set_ret(bld_->create_reduce(op, ax)); + } default: error_not_implemented(); } return error_not_implemented(); @@ -412,16 +417,41 @@ void Generator::Gen(ir::module *mod) { ir::value* Generator::GenBroadcastOp(ir::value* src, ir::type* dst_ty) { + if(src->get_type() == dst_ty) + return src; if(dst_ty->is_tile_ty()) { ir::type *src_ty = src->get_type(); auto dst_shapes = dst_ty->get_tile_shapes(); if(!src_ty->is_tile_ty()) return bld_->create_splat(src, dst_shapes); auto src_shapes = src_ty->get_tile_shapes(); - if(src_shapes.size() != dst_shapes.size()) - return bld_->create_reshape(src, dst_shapes); - else + if(src_shapes.size() != dst_shapes.size()){ + unsigned src_numel = 1; + for(unsigned s: src_shapes) + src_numel *= s; + unsigned dst_numel = 1; + for(unsigned s: dst_shapes) + dst_numel *= s; + if(src_numel == dst_numel) + return bld_->create_reshape(src, dst_shapes); + else { + auto padded_shapes = src_shapes; + while(padded_shapes.size() != dst_shapes.size()) + padded_shapes.insert(padded_shapes.begin(), 1); + // check that broadcast is legal + for(size_t d = 0; d < padded_shapes.size(); d++){ + if(dst_shapes[d] != padded_shapes[d] && + padded_shapes[d] != 1) + should_not_happen(); + } + // pad and broadcast + 
ir::value *padded = bld_->create_reshape(src, padded_shapes); + return bld_->create_broadcast(padded, dst_shapes); + } + } + else{ return bld_->create_broadcast(src, dst_shapes); + } } return src; } diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index 6c669208f..f69337ced 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -453,7 +453,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) { TileType::ShapeInt shape; size_t i = 0; const Token* tok; - std::vector> redList; + std::vector> redInfo; do { tok = ts_.Next(); switch(tok->tag_) { @@ -465,10 +465,13 @@ Expr* Parser::ParseSubScripting(Expr* lhs) { shape.push_back(1); break; -// case Token::ADD: -// case Token::SUB: -// redList.push_back({i, tok->tag_}); -// break; + case Token::ADD: + case Token::SUB:{ + int info = UnaryOp::encodeRed(i, tok->tag_); + redInfo.push_back({i, info}); + shape.push_back(lhsShape[i++]); + break; + } default: Error(tok, "Unexpected subscript symbol encountered at dimension %d", i); @@ -479,8 +482,21 @@ Expr* Parser::ParseSubScripting(Expr* lhs) { if(lhsShape.size() > i) Error(tok, "broadcasting not using all operand axes"); // create ret tile - TileType *retType = TileType::New(shape, lhsQual); - return UnaryOp::New(Token::CAST, lhs, retType); + Expr* res = lhs; + for(auto r: redInfo){ + shape.erase(shape.begin() + r.first); + Type *retType; + if(shape.empty()) + retType = lhsQual.GetPtr(); + else + retType = TileType::New(shape, lhsQual); + res = UnaryOp::New(Token::REDUCE, res, retType, r.second); + } + if(!shape.empty()){ + TileType *retType = TileType::New(shape, lhsQual); + res = UnaryOp::New(Token::CAST, res, retType); + } + return res; } diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 114626dce..ae21128f6 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -204,6 +204,7 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c codegen::transform::peephole peephole; codegen::transform::reassociate 
reassociate(&alignment_info, &grids); codegen::selection selection(&shmem_allocation, &grids, &shmem_info, &alignment_info, target.get()); +// ir::print(module, std::cout); // run passes peephole.run(module); dce.run(module); diff --git a/tests/common/src/reduce.h b/tests/common/src/reduce.h index a9788f340..1f2be7461 100644 --- a/tests/common/src/reduce.h +++ b/tests/common/src/reduce.h @@ -19,7 +19,7 @@ void reduce2d(TYPE * X __noalias __readonly __aligned(16), int rm[TM] = ridm * TM + 0 ... TM; int rn[TN] = ridn * TN + 0 ... TN; TYPE* px[TM, TN] = X + rm[:, newaxis] + rn[newaxis, :] * ldx; - TYPE* py[TM, TN] = Y + rm[:, newaxis]; + TYPE* py[TM] = Y + rm; *py = (*px)[:, +]; } )"; diff --git a/tests/common/util.h b/tests/common/util.h index e5cfef7b8..874b33e84 100644 --- a/tests/common/util.h +++ b/tests/common/util.h @@ -37,6 +37,12 @@ void init_rand(std::vector& x) { x[i] = static_cast((double)rand()/RAND_MAX); } +template +void init_zeros(std::vector& x) { + for(size_t i = 0; i < x.size(); i++) + x[i] = 0; +} + namespace aux{ diff --git a/tests/unit/reduce.cc b/tests/unit/reduce.cc index 59b574c4d..c513d9cb5 100644 --- a/tests/unit/reduce.cc +++ b/tests/unit/reduce.cc @@ -15,15 +15,26 @@ namespace drv = triton::driver; namespace rt = triton::runtime; +template +void cpu_ref(std::vector &y, const std::vector &x, int M, int N) { + for(int m = 0; m < M; m++){ + T acc = 0; + for(int n = 0; n < N; n++) + acc = acc + x[m + n*M]; + y[m] = acc; + } +} + bool do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){ typedef float NumericT; std::string ty = "float"; size_t dt_nbytes = sizeof(NumericT); drv::context* context = stream->context(); std::vector hy(M); + std::vector ry(M); std::vector hx(M*N); srand(0); - init_rand(hy); + init_zeros(hy); init_rand(hx); auto dy = std::shared_ptr(drv::buffer::create(context, hy.size()*dt_nbytes)); auto dx = std::shared_ptr(drv::buffer::create(context, hx.size()*dt_nbytes)); @@ -35,8 +46,11 @@ bool 
do_test(drv::stream* stream, int M, int N, std::string op, int nwarp){ opt.defines.push_back({"TN", {std::to_string(N)}}); opt.num_warps = {nwarp}; rt::function function(src::reduce2d, opt); - function({&*dy, &*dx, M, N, M}, grid2d(M, N), stream); + function({&*dx, &*dy, M, N, M}, grid2d(M, N), stream); stream->synchronize(); + stream->read(&*dy, true, 0, hy); + cpu_ref(ry, hx, M, N); + return testing::diff(hy, ry); } int main() {