diff --git a/include/triton/lang/ast.h b/include/triton/lang/ast.h index 43cfc485f..f523cb835 100644 --- a/include/triton/lang/ast.h +++ b/include/triton/lang/ast.h @@ -430,7 +430,6 @@ public: void AddrOpTypeChecking(); void DerefOpTypeChecking(); void ReduceOpTypeChecking(); - void TransOpTypeChecking(); void UnaryArithmOpTypeChecking(); void CastOpTypeChecking(); @@ -448,6 +447,28 @@ protected: Expr* operand_; }; +class TransOp: public Expr { + friend class Generator; + +public: + using PermInt = std::vector; + +public: + static TransOp* New(const PermInt& perm, Expr* operand); + const PermInt& getPerm() const { return perm_; } + void Accept(Visitor* v); + bool IsLVal() { return false; } + void TypeChecking(); + +protected: + TransOp(const PermInt& perm, Expr* operand) + : Expr(operand->Tok(), nullptr), operand_(operand), perm_(perm) {} + +private: + Expr* operand_; + PermInt perm_; +}; + // cond ? true : false class ConditionalOp : public Expr { diff --git a/include/triton/lang/code_gen.h b/include/triton/lang/code_gen.h index 69a1a7514..96a02ce9a 100644 --- a/include/triton/lang/code_gen.h +++ b/include/triton/lang/code_gen.h @@ -58,6 +58,7 @@ public: // Expression void VisitBinaryOp(BinaryOp* binaryOp); void VisitUnaryOp(UnaryOp* unaryOp); + void VisitTransOp(TransOp* transOp); void VisitConditionalOp(ConditionalOp* condOp); void VisitFuncCall(FuncCall* funcCall); void VisitObject(Object* obj); @@ -130,6 +131,7 @@ public: void VisitConditionalOp(ConditionalOp*) { should_not_happen(); } void VisitFuncCall(FuncCall*) { should_not_happen(); } + void VisitTransOp(TransOp*) { should_not_happen(); } void VisitEnumerator(Enumerator*) { should_not_happen(); } void VisitConstant(Constant*) { should_not_happen(); } void VisitTempVar(TempVar*) { should_not_happen(); } diff --git a/include/triton/lang/evaluator.h b/include/triton/lang/evaluator.h index 589739b45..ac8404550 100644 --- a/include/triton/lang/evaluator.h +++ b/include/triton/lang/evaluator.h @@ -30,6 +30,9 @@ public: virtual void VisitIdentifier(Identifier* ident) { Error(ident, "expect constant expression"); } + virtual void VisitTransOp(TransOp* trans) { + Error(trans, "expect constant expression"); + } virtual void VisitObject(Object* obj) { Error(obj, "expect constant expression"); } @@ -83,6 +86,9 @@ public: virtual void VisitFuncCall(FuncCall* funcCall) { Error(funcCall, "expect constant expression"); } + virtual void VisitTransOp(TransOp* trans) { + Error(trans, "expect constant expression"); + } virtual void VisitEnumerator(Enumerator* enumer) { addr_.offset_ = enumer->Val(); } diff --git a/include/triton/lang/visitor.h b/include/triton/lang/visitor.h index 16398f57b..239071edf 100644 --- a/include/triton/lang/visitor.h +++ b/include/triton/lang/visitor.h @@ -6,6 +6,7 @@ class BinaryOp; class UnaryOp; +class TransOp; class ConditionalOp; class FuncCall; class Identifier; @@ -31,6 +32,7 @@ public: virtual ~Visitor() {} virtual void VisitBinaryOp(BinaryOp* binary) = 0; virtual void VisitUnaryOp(UnaryOp* unary) = 0; + virtual void VisitTransOp(TransOp* trans) = 0; virtual void VisitConditionalOp(ConditionalOp* cond) = 0; virtual void VisitFuncCall(FuncCall* funcCall) = 0; virtual void VisitEnumerator(Enumerator* enumer) = 0; diff --git a/lib/lang/ast.cc b/lib/lang/ast.cc index b6cd99633..1fd8b2dcb 100644 --- a/lib/lang/ast.cc +++ b/lib/lang/ast.cc @@ -7,6 +7,7 @@ static MemPoolImp binaryOpPool; +static MemPoolImp transOpPool; static MemPoolImp conditionalOpPool; static MemPoolImp funcCallPool; static MemPoolImp initializationPool; @@ -78,6 +79,9 @@ void UnaryOp::Accept(Visitor* v) { v->VisitUnaryOp(this); } +void TransOp::Accept(Visitor* v) { + v->VisitTransOp(this); +} void ConditionalOp::Accept(Visitor* v) { v->VisitConditionalOp(this); @@ -645,9 +649,6 @@ void UnaryOp::TypeChecking() { case Token::CAST: return CastOpTypeChecking(); - case '^': - return TransOpTypeChecking(); - case Token::REDUCE: return ReduceOpTypeChecking(); @@ -702,15 +703,6 @@ void UnaryOp::ReduceOpTypeChecking() { type_ = TileType::New(shape, tileType->Derived()); } -void UnaryOp::TransOpTypeChecking() { - auto tileType = operand_->Type()->ToTile(); - if(!tileType) - Error(this, "tile expected for transposition operator '^'"); - auto shape = tileType->Shape(); - std::rotate(shape.begin(), shape.begin() + 1, shape.end()); - type_ = TileType::New(shape, tileType->Derived()); -} - void UnaryOp::UnaryArithmOpTypeChecking() { auto scalType = TryExtractScalarType(this, operand_); if (Token::PLUS == op_ || Token::MINUS == op_) { @@ -769,6 +761,29 @@ void UnaryOp::CastOpTypeChecking() { } } +/* + * Transposition Operator + */ +void TransOp::TypeChecking() { + auto tileType = operand_->Type()->ToTile(); + if(!tileType) + Error(this, "tile expected for transposition operator '^'"); + auto opShape = tileType->Shape(); + if(perm_.size() != opShape.size()) + Error(this, "invalid permutations"); + // permutate input shape + TileType::ShapeInt resShape(opShape.size()); + for(int d = 0; d < opShape.size(); d++) + resShape[d] = opShape[perm_[d]]; + type_ = TileType::New(resShape, tileType->Derived()); +} + +TransOp* TransOp::New(const PermInt& perm, Expr* operand) { + auto ret = new (transOpPool.Alloc()) TransOp(perm, operand); + ret->pool_ = &transOpPool; + ret->TypeChecking(); + return ret; +} /* * Conditional Operator diff --git a/lib/lang/code_gen.cc b/lib/lang/code_gen.cc index aee604b4a..fdc754048 100644 --- a/lib/lang/code_gen.cc +++ b/lib/lang/code_gen.cc @@ -185,7 +185,6 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { case '~': return set_ret(bld_->create_neg(arg)); case '!': return set_ret(bld_->create_not(arg)); case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_))); - case '^': return set_ret(bld_->create_trans(arg)); case Token::REDUCE: { int ax, tag; UnaryOp::decodeRed(unary->info_, ax, tag); @@ -198,6 +197,12 @@ void Generator::VisitUnaryOp(UnaryOp* unary) { return error_not_implemented(); } +void Generator::VisitTransOp(TransOp *trans) { + Visit(trans->operand_); + ir::value* arg = ret_; + return set_ret(bld_->create_trans(arg, trans->getPerm())); +} + void Generator::VisitConditionalOp(ConditionalOp* condOp) { // auto &instructions = bld_->get_insert_block()->get_inst_list(); VisitExpr(condOp->cond_); diff --git a/lib/lang/parser.cc b/lib/lang/parser.cc index a30258c3d..960c983cf 100644 --- a/lib/lang/parser.cc +++ b/lib/lang/parser.cc @@ -451,6 +451,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) { QualType lhsQual = lhsTile->Derived(); // create ret shape TileType::ShapeInt shape; + TileType::ShapeInt axVec; size_t i = 0; const Token* tok; std::vector> redInfo; @@ -469,10 +470,22 @@ Expr* Parser::ParseSubScripting(Expr* lhs) { case Token::SUB: case Token::MAX: case Token::MIN:{ - int info = UnaryOp::encodeRed(i, tok->tag_); - redInfo.push_back({i, info}); - shape.push_back(lhsShape[i++]); - break; + int info = UnaryOp::encodeRed(i, tok->tag_); + redInfo.push_back({i, info}); + shape.push_back(lhsShape[i++]); + break; + } + + case '^':{ + Expr* expr = ParseConditionalExpr(); + EnsureInteger(expr); + int ax = Evaluator().Eval(expr); + axVec.push_back(ax); + if(ax < 0 || ax >= lhsShape.size()) + Error(tok, "unknown axis %d in transposition", ax); + shape.push_back(lhsShape[ax]); + i++; + break; } default: @@ -481,8 +494,19 @@ Expr* Parser::ParseSubScripting(Expr* lhs) { } }while(ts_.Try(',')); ts_.Expect(']'); + + // transposition mode + std::set axSet(axVec.begin(), axVec.end()); + if(!axSet.empty()){ + if(axSet.size()!=lhsShape.size()) + Error(tok, "transposition must address all axes of input array"); + return TransOp::New(axVec, lhs); + } + + // broadcasting mode if(lhsShape.size() > i) Error(tok, "broadcasting not using all operand axes"); + // create ret tile Expr* res = lhs; for(auto r: redInfo){ @@ -553,7 +577,15 @@ Expr* Parser::ParseUnaryExpr() { case '-': return ParseUnaryOp(tok, Token::MINUS); case '~': return ParseUnaryOp(tok, '~'); case '!': return ParseUnaryOp(tok, '!'); - case '^': return ParseUnaryOp(tok, Token::XOR); + case '^': { + auto operand = ParseCastExpr(); + TileType::ShapeInt shape = operand->Type()->ToTile()->Shape(); + TransOp::PermInt perm(shape.size()); + for(int d = 0; d < shape.size(); d++) + perm[d] = d; + std::rotate(perm.begin(), perm.begin() + 1, perm.end()); + return TransOp::New(perm, operand); + } default: ts_.PutBack(); return ParsePostfixExpr(); diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 571c8f1ba..43efa7db1 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -8,41 +8,36 @@ class _einsum(triton.function): int std_A0, int std_B0, int std_C0, int std_A1, int std_B1, int std_C1) { // program id - int pid0 = get_program_id(0); - int pid1 = get_program_id(1); - int pid2 = get_program_id(2); + int pgm = get_program_id(0); + int pgn = get_program_id(1); + int pgb = get_program_id(2); // range - int rma[TM] = pid0 * TM + 0 ... TM; - int rnb[TN] = pid1 * TN + 0 ... TN; - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... TK; - int rba[TB] = pid2 * TB + 0 ... TB; - int rbb[TB] = pid2 * TB + 0 ... TB; + int rm[TM] = pgm * TM + 0 ... TM; + int rn[TN] = pgn * TN + 0 ... TN; + int rb[TB] = pgb * TB + 0 ... TB; + int rk[TK] = 0 ... TK; // accumulator TYPE c[TM, TN, TB] = 0; // pointers to a - TYPE *pa[TM, TK, TB] = A + rka[newaxis, :, newaxis] * 1 - + rma[:, newaxis, newaxis] * std_A1 - + rba[newaxis, newaxis, :] * std_A0; + TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + + rm[BROADCAST_AM] * STRIDE_AM + + rb[newaxis, newaxis, :] * std_A0; // pointers to b - TYPE *pb[TK, TN, TB] = B + rkb[:, newaxis, newaxis] * 1 - + rnb[newaxis, :, newaxis] * std_B1 - + rbb[newaxis, newaxis, :] * std_B0; + TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + + rn[BROADCAST_BN] * STRIDE_BN + + rb[newaxis, newaxis, :] * std_B0; // accumulation for(int k = dim_K; k > 0; k -= TK) { - TYPE a[TM, TK, TB] = *pa; - TYPE b[TK, TN, TB] = *pb; + TYPE a[SHAPE_A] = *pa; + TYPE b[SHAPE_B] = *pb; c += a @ b; pa += TK; pb += TK; } // write-back - int rmc[TM] = pid0 * TM + 0 ... TM; - int rnc[TN] = pid1 * TN + 0 ... TN; - int rbc[TB] = pid2 * TB + 0 ... TB; - TYPE *pc[TM, TN, TB] = C + rmc[:, newaxis, newaxis] * std_C1 - + rnc[newaxis, :, newaxis] * 1 - + rbc[newaxis, newaxis, :] * std_C0; + TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1 + + rn[newaxis, :, newaxis] * 1 + + rb[newaxis, newaxis, :] * std_C0; *pc = c; } """ @@ -138,12 +133,25 @@ class _einsum(triton.function): grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')), triton.cdiv(bmnk[2], opt.d('TN')), triton.cdiv(bmnk[0], opt.d('TB'))] - #print(std0, std1) + macros = {# handle A transposition + 'USE_A' : 'a[^1, ^0, ^2]' if trans_a else 'a', + 'STRIDE_AK' : 'std_A1' if trans_a else '1', + 'STRIDE_AM' : '1' if trans_a else 'std_A1', + 'BROADCAST_AK': ':, newaxis, newaxis' if trans_a else 'newaxis, :, newaxis', + 'BROADCAST_AM': 'newaxis, :, newaxis' if trans_a else ':, newaxis, newaxis', + 'SHAPE_A' : 'TK, TM, TB' if trans_a else 'TM, TK, TB', + # handle B transposition + 'USE_B' : 'b[^1, ^0, ^2]' if not trans_b else 'b', + 'STRIDE_BK' : 'std_B1' if not trans_b else '1', + 'STRIDE_BN' : '1' if not trans_b else 'std_B1', + 'BROADCAST_BK': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis', + 'BROADCAST_BN': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis', + 'SHAPE_B' : 'TN, TK, TB' if not trans_b else 'TK, TN, TB'} return _einsum.kernel(a, b, c, bmnk[1], bmnk[2], bmnk[3], std0[0], std0[1], std0[2], std1[0], std1[1], std1[2], - grid, + grid, **macros, TYPE='float', TM=32, TN=32, TK=8, TB=1) diff --git a/tests/common/dot.h b/tests/common/dot.h index 3ee1e9f68..23bb46c72 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -86,14 +86,14 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, // macros rt::function::options_space_t opt; // A access patterns - opt.defines.push_back({"USEA", {AT? "^a" : "a" }}); + opt.defines.push_back({"USEA", {AT? "a[^1, ^0]" : "a" }}); opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }}); opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }}); opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }}); opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }}); opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }}); // B access patterns - opt.defines.push_back({"USEB", {BT? "^b" : "b" }}); + opt.defines.push_back({"USEB", {BT? "b[^1, ^0]" : "b" }}); opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }}); opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }}); opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});