diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 55c25b575..83a38be4e 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -55,8 +55,8 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; std::string lda0 = "*lda", lda1 = ""; std::string ldb0 = "", ldb1 = "*ldb"; - std::string usea = AT ? "trans(a)" : "a"; - std::string useb = BT ? "trans(b)" : "b"; + std::string usea = AT ? "^a" : "a"; + std::string useb = BT ? "^b" : "b"; if(AT){ std::swap(AS0, AS1); std::swap(XAS0, XAS1); @@ -82,6 +82,11 @@ R"( #define TN 128 #define TK 32 +#define bool _Bool +#define true 1 +#define false 0 +#define __bool_true_false_are_defined 1 + extern int get_program_id(int); void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))), @@ -94,28 +99,28 @@ void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))), int ridx = get_program_id(0); int ridy = get_program_id(1); int rxa[{TM, TN}] = ridx * TM + 0 ... TM; - int ryb[TN] = ridy * TN + 0 ... TN; - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... TK; - float xc[)" + XCS + R"(] = 0; - )" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; - )" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; - )" + a_ty + R"( a[)" + AS + R"(] = *pa; - )" + b_ty + R"( b[)" + BS + R"(] = *pb; + int ryb[{TN}] = ridy * TN + 0 ... TN; + int rka[{TK}] = 0 ... TK; + int rkb[{TK}] = 0 ... TK; + float xc[{)" + XCS + R"(}] = 0; + )" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; + )" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; + )" + a_ty + R"( a[{)" + AS + R"(}] = *pa; + )" + b_ty + R"( b[{)" + BS + R"(}] = *pb; for(int k = K; k > 0; k = k - TK){ - xc = dot()" + usea + ", " + useb + R"(, xc); + xc = )" + usea + " @ " + useb + R"( + xc; pa = pa + TK)" + lda0 + R"(; pb = pb + TK)" + ldb0 + R"(; a = *pa; b = *pb; } - int rxc[TM] = ridx * TM + (0 ... TM); - int ryc[TN] = ridy * TN + (0 ... TN); - )" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; - )" + c_ty + R"( c[TM, TN] = xc; - bool checkc0[TM] = rxc < M; - bool checkc1[TN] = ryc < N; - bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + int rxc[{TM}] = ridx * TM + (0 ... TM); + int ryc[{TN}] = ridy * TN + (0 ... TN); + )" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; + )" + c_ty + R"( c[{TM, TN}] = xc; + bool checkc0[{TM}] = rxc < M; + bool checkc1[{TN}] = ryc < N; + bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :]; *pc = c; } )"; diff --git a/include/triton/lang/wgtcc/ast.h b/include/triton/lang/wgtcc/ast.h index 3cb3257f7..1181fb63a 100644 --- a/include/triton/lang/wgtcc/ast.h +++ b/include/triton/lang/wgtcc/ast.h @@ -278,6 +278,9 @@ public: static Expr* MayCast(Expr* expr); static Expr* MayCast(Expr* expr, QualType desType); + static ::Type* TryExtractScalarType(Expr* loc, Expr *operand); + static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty); + virtual bool IsNullPointerConstant() const { return false; } bool IsConstQualified() const { return type_.IsConstQualified(); } bool IsRestrictQualified() const { return type_.IsRestrictQualified(); } @@ -332,6 +335,7 @@ public: void AdditiveOpTypeChecking(); void ShiftOpTypeChecking(); void RangeOpTypeChecking(); + void MatmulOpTypeChecking(); void RelationalOpTypeChecking(); void EqualityOpTypeChecking(); void BitwiseOpTypeChecking(); @@ -378,11 +382,12 @@ public: virtual ~UnaryOp() {} virtual void Accept(Visitor* v); virtual bool IsLVal(); - ArithmType* Convert(); + ::Type *Convert(); void TypeChecking(); void IncDecOpTypeChecking(); void AddrOpTypeChecking(); void DerefOpTypeChecking(); + void TransOpTypeChecking(); void UnaryArithmOpTypeChecking(); void CastOpTypeChecking(); diff --git a/include/triton/lang/wgtcc/parser.h b/include/triton/lang/wgtcc/parser.h index c1de92491..8c21af727 100644 --- a/include/triton/lang/wgtcc/parser.h +++ b/include/triton/lang/wgtcc/parser.h @@ -75,6 +75,7 @@ public: QualType ParseTypeName(); Expr* ParseCastExpr(); Expr* ParseRangeExpr(); + Expr* ParseMatmulExpr(); Expr* ParseMultiplicativeExpr(); Expr* ParseAdditiveExpr(); Expr* ParseShiftExpr(); diff --git a/include/triton/lang/wgtcc/token.h b/include/triton/lang/wgtcc/token.h index 391507f80..e982ec803 100644 --- a/include/triton/lang/wgtcc/token.h +++ b/include/triton/lang/wgtcc/token.h @@ -64,7 +64,7 @@ public: NOT = '!', COND = '?', SHARP = '#', - AT = '@', + MATMUL = '@', NEW_LINE = '\n', DSHARP = 128, // '##' @@ -126,6 +126,10 @@ public: NORETURN, // _Noreturn // FUNCTION SPECIFIER END + // TILE ARITHMETICS BEGIN + NEWAXIS, + // TILE ARITHMETICS END + ALIGNAS, // _Alignas // For syntactic convenience STATIC_ASSERT, // _Static_assert diff --git a/include/triton/lang/wgtcc/type.h b/include/triton/lang/wgtcc/type.h index 20c2fa898..b43b74339 100644 --- a/include/triton/lang/wgtcc/type.h +++ b/include/triton/lang/wgtcc/type.h @@ -153,6 +153,10 @@ public: virtual bool IsBool() const { return false; } virtual bool IsVoidPointer() const { return false; } virtual bool IsUnsigned() const { return false; } + virtual bool IsTile() const { return ToTile() != nullptr; } + + const Type* ScalarType() const; + Type* ScalarType(); virtual VoidType* ToVoid() { return nullptr; } virtual const VoidType* ToVoid() const { return nullptr; } @@ -327,16 +331,22 @@ public: static TileType* New(const ShapeInt& shape, QualType eleType); virtual ~TileType() { } - virtual TileType* toTile() { return this; } - virtual const TileType* toTile() const { return this; } + virtual TileType* ToTile() { return this; } + virtual const TileType* ToTile() const { return this; } virtual bool Compatible(const Type& other) const; - virtual int Width() const { return 0; } + virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; } virtual int Align() const { return derived_->Align(); } virtual std::string Str() const { return derived_->Str() + "[{}]:" + std::to_string(Width()); } ShapeInt Shape() { return shape_; } + int NumEle() const { + int ret = 1; + for(int s: shape_) + ret *= s; + return ret; + } protected: TileType(MemPool* pool, const ShapeExpr& expr, QualType derived) diff --git a/lib/lang/wgtcc/ast.cc b/lib/lang/wgtcc/ast.cc index eb673584f..5646fbb4c 100644 --- a/lib/lang/wgtcc/ast.cc +++ b/lib/lang/wgtcc/ast.cc @@ -144,6 +144,26 @@ Expr* Expr::MayCast(Expr* expr, QualType desType) { return expr; } +// Extract the operand's scalar type if possible +// and emit an error otherwise +::Type* Expr::TryExtractScalarType(Expr* loc, Expr *operand) { + auto scalType = operand->Type()->ScalarType(); + if(!scalType) + Error(loc, "expect tile or scalar operand"); + return scalType; +} + +// If operand is a tile, return a tile of the same shape and +// provided element type +// If operand is a scalar, return provided element type +// directly +::Type* Expr::ScalarOrLikeTile(Expr* operand, ::Type* ty) { + assert(ty->IsScalar()); + ::Type *retTy = ty; + if(TileType *T = operand->Type()->ToTile()) + retTy = TileType::New(T->Shape(), retTy); + return retTy; +} BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) { return New(tok, tok->tag_, lhs, rhs); @@ -166,6 +186,7 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) { case Token::LOGICAL_AND: case Token::LOGICAL_OR: case Token::ELLIPSIS: + case Token::MATMUL: break; default: assert(0); @@ -180,18 +201,18 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) { ArithmType* BinaryOp::Convert() { - // Both lhs and rhs are ensured to be have arithmetic type - auto lhsType = lhs_->Type()->ToArithm(); - auto rhsType = rhs_->Type()->ToArithm(); + // Both lhs and rhs are ensured to be have arithmetic scalar type + auto lhsType = lhs_->Type()->ScalarType()->ToArithm(); + auto rhsType = rhs_->Type()->ScalarType()->ToArithm(); assert(lhsType && rhsType); - auto type = ArithmType::MaxType(lhsType, rhsType); - if (lhsType != type) { // Pointer comparation is enough! - lhs_ = UnaryOp::New(Token::CAST, lhs_, type); + auto maxType = ArithmType::MaxType(lhsType, rhsType); + if (lhsType != maxType) { // Pointer comparation is enough! + lhs_ = UnaryOp::New(Token::CAST, lhs_, ScalarOrLikeTile(lhs_, maxType)); } - if (rhsType != type) { - rhs_ = UnaryOp::New(Token::CAST, rhs_, type); + if (rhsType != maxType) { + rhs_ = UnaryOp::New(Token::CAST, rhs_, ScalarOrLikeTile(rhs_, maxType)); } - return type; + return maxType; } void BinaryOp::Broadcast() { @@ -225,6 +246,8 @@ void BinaryOp::Broadcast() { retShape[i] = rhsShape[i]; else if(rhsShape[i] == 1) retShape[i] = lhsShape[i]; + else if(lhsShape[i] == rhsShape[i]) + retShape[i] = lhsShape[i]; else Error(this, "cannot broadcast dimension %d " "for operands of shape %d and %d", @@ -232,8 +255,10 @@ void BinaryOp::Broadcast() { } auto eleType = lhsType->Derived(); type_ = TileType::New(retShape, eleType); - lhs_ = UnaryOp::New(Token::CAST, lhs_, type_); - rhs_ = UnaryOp::New(Token::CAST, rhs_, type_); + if(retShape != lhsShape) + lhs_ = UnaryOp::New(Token::CAST, lhs_, type_); + if(retShape != rhsShape) + rhs_ = UnaryOp::New(Token::CAST, rhs_, type_); } } @@ -303,6 +328,9 @@ void BinaryOp::TypeChecking() { case Token::ELLIPSIS: return RangeOpTypeChecking(); + case Token::MATMUL: + return MatmulOpTypeChecking(); + default: assert(0); } @@ -315,12 +343,15 @@ void BinaryOp::CommaOpTypeChecking() { void BinaryOp::SubScriptingOpTypeChecking() { - auto lhsType = lhs_->Type()->ToPointer(); + assert(false); + auto lhsType = lhs_->Type()->ToTile(); + if (!lhsType) { - Error(this, "an pointer expected"); + Error(this, "operator [] can only be used on tiles"); } + if (!rhs_->Type()->IsInteger()) { - Error(this, "the operand of [] should be intger"); + Error(this, "the operand of [] should be integer"); } // The type of [] operator is the derived type @@ -334,14 +365,20 @@ void BinaryOp::MemberRefOpTypeChecking() { void BinaryOp::MultiOpTypeChecking() { - if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) { + ::Type* lhsScalType = lhs_->Type()->ScalarType(); + ::Type* rhsScalType = rhs_->Type()->ScalarType(); + if(!lhsScalType || !rhsScalType) { + Error(this, "operands should have type or scalar type"); + } + if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) { Error(this, "operands should have arithmetic type"); } if ('%' == op_ && - !(lhs_->Type()->IsInteger() && rhs_->Type()->IsInteger())) { + !(lhsScalType->IsInteger() && rhsScalType->IsInteger())) { Error(this, "operands of '%%' should be integers"); } type_ = Convert(); + Broadcast(); } @@ -351,40 +388,47 @@ void BinaryOp::MultiOpTypeChecking() { * 2. pointer can be used: * 1. lhs of MINUS operator, and rhs must be integer or pointer; * 2. lhs/rhs of ADD operator, and the other operand must be integer; + * 3. tiles can be used: + * 1. the scalar type of lhs/rhs satisfy the above requirements + * 2. lhs/rhs that have identical shape + * 3. lhs/rhs that can be broadcast as per numpy-like semantics */ void BinaryOp::AdditiveOpTypeChecking() { - auto lhsType = lhs_->Type()->ToPointer(); - auto rhsType = rhs_->Type()->ToPointer(); - if (lhsType) { + ::Type* lhsScalType = TryExtractScalarType(this, lhs_); + ::Type* rhsScalType = TryExtractScalarType(this, rhs_); + auto lhsPtrType = lhsScalType->ToPointer(); + auto rhsPtrType = rhsScalType->ToPointer(); + if (lhsPtrType) { if (op_ == '-') { - if (rhsType) { - if (!lhsType->Compatible(*rhsType)) + if (rhsPtrType) { + if (!lhsPtrType->Compatible(*rhsPtrType)) Error(this, "invalid operands to binary -"); type_ = ArithmType::New(T_LONG); // ptrdiff_t - } else if (!rhs_->Type()->IsInteger()) { + } else if (!rhsScalType->IsInteger()) { Error(this, "invalid operands to binary -"); } else { - type_ = lhsType; + type_ = lhsPtrType; } - } else if (!rhs_->Type()->IsInteger()) { + } else if (!rhsScalType->IsInteger()) { Error(this, "invalid operands to binary +"); } else { - type_ = lhsType; + type_ = lhsPtrType; } - } else if (rhsType) { - if (op_ == '+' && !lhs_->Type()->IsInteger()) { + } else if (rhsPtrType) { + if (op_ == '+' && !lhsScalType->IsInteger()) { Error(this, "invalid operands to binary '+'"); - } else if (op_ == '-' && !lhsType) { + } else if (op_ == '-' && !lhsPtrType) { Error(this, "invalid operands to binary '-'"); } - type_ = op_ == '-' ? ArithmType::New(T_LONG): rhs_->Type(); + type_ = op_ == '-' ? ArithmType::New(T_LONG): rhsScalType; std::swap(lhs_, rhs_); // To simplify code gen } else { - if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) { + if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) { Error(this, "invalid operands to binary %s", tok_->str_.c_str()); } type_ = Convert(); } + Broadcast(); } void BinaryOp::RangeOpTypeChecking() { @@ -396,59 +440,95 @@ void BinaryOp::RangeOpTypeChecking() { rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType)); long begin = Evaluator().Eval(lhs_); long end = Evaluator().Eval(rhs_); - int len = end - begin; + int len = static_cast(end - begin); if(len < 0) Error(this, "range cannot be negative"); type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type()); } +void BinaryOp::MatmulOpTypeChecking() { + auto lhsType = lhs_->Type()->ToTile(); + auto rhsType = rhs_->Type()->ToTile(); + if(!lhsType || !rhsType) + Error(this, "expect tile operands for matrix multiplication"); + auto lhsShape = lhsType->Shape(); + auto rhsShape = rhsType->Shape(); + size_t lhsRank = lhsShape.size(); + size_t rhsRank = rhsShape.size(); + if(lhsRank != 2 || rhsRank != 2) + Error(this, "matrix multiplication operands must have rank 2"); + if(lhsShape[1] != rhsShape[0]) + Error(this, "matrix multiplication operands have incompatible inner dimension" + " %d and %d", lhsShape[1], rhsShape[0]); + TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]}; + QualType retType = lhsType->Derived(); + if(retType != rhsType->Derived()) + Error(this, "matrix multiplication operands have incompatible data types"); + type_ = TileType::New(retShape, lhsType->Derived()); +} + void BinaryOp::ShiftOpTypeChecking() { - auto lhsType = lhs_->Type()->ToArithm(); - auto rhsType = rhs_->Type()->ToArithm(); + ::Type* lhsScalType = TryExtractScalarType(this, lhs_); + ::Type* rhsScalType = TryExtractScalarType(this, rhs_); + auto lhsType = lhsScalType->ToArithm(); + auto rhsType = rhsScalType->ToArithm(); if (!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger()) Error(this, "expect integers for shift operator"); - lhs_ = Expr::MayCast(lhs_, ArithmType::IntegerPromote(lhsType)); - rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType)); + lhs_ = Expr::MayCast(lhs_, ScalarOrLikeTile(lhs_, ArithmType::IntegerPromote(lhsType))); + rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, ArithmType::IntegerPromote(rhsType))); type_ = lhs_->Type(); + Broadcast(); } void BinaryOp::RelationalOpTypeChecking() { - if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) { - EnsureCompatible(lhs_->Type(), rhs_->Type()); + ::Type* lhsScalType = TryExtractScalarType(this, lhs_); + ::Type* rhsScalType = TryExtractScalarType(this, rhs_); + if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) { + EnsureCompatible(lhsScalType, rhsScalType); } else { - if (!lhs_->Type()->IsReal() || !rhs_->Type()->IsReal()) { + if (!lhsScalType->IsReal() || !rhsScalType->IsReal()) { Error(this, "expect real type of operands"); } Convert(); } type_ = ArithmType::New(T_INT); + Broadcast(); } void BinaryOp::EqualityOpTypeChecking() { - if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) { - EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type()); + ::Type* lhsScalType = TryExtractScalarType(this, lhs_); + ::Type* rhsScalType = TryExtractScalarType(this, rhs_); + if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) { + EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType); } else { - if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) + if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) Error(this, "invalid operands to binary %s", tok_->str_.c_str()); Convert(); } type_ = ArithmType::New(T_INT); + Broadcast(); } void BinaryOp::BitwiseOpTypeChecking() { - if (!lhs_->Type()->IsInteger() || !rhs_->Type()->IsInteger()) + ::Type* lhsScalType = TryExtractScalarType(this, lhs_); + ::Type* rhsScalType = TryExtractScalarType(this, rhs_); + if (!lhsScalType->IsInteger() || !rhsScalType->IsInteger()) Error(this, "operands of '&' should be integer"); type_ = Convert(); + Broadcast(); } void BinaryOp::LogicalOpTypeChecking() { - if (!lhs_->Type()->IsScalar() || !rhs_->Type()->IsScalar()) + ::Type* lhsScalType = TryExtractScalarType(this, lhs_); + ::Type* rhsScalType = TryExtractScalarType(this, rhs_); + if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar()) Error(this, "the operand should be arithmetic type or pointer"); type_ = ArithmType::New(T_INT); + Broadcast(); } @@ -459,12 +539,14 @@ void BinaryOp::AssignOpTypeChecking() { Error(lhs_, "lvalue expression expected"); } - if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) { - EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type()); + ::Type* lhsScalType = TryExtractScalarType(this, lhs_); + ::Type* rhsScalType = TryExtractScalarType(this, rhs_); + if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) { + EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType); } // The other constraints are lefted to cast operator - rhs_ = Expr::MayCast(rhs_, lhs_->Type()); + rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType)); type_ = lhs_->Type(); } @@ -488,13 +570,16 @@ bool UnaryOp::IsLVal() { } -ArithmType* UnaryOp::Convert() { - auto arithmType = operand_->Type()->ToArithm(); +::Type* UnaryOp::Convert() { + auto scalType = operand_->Type()->ScalarType(); + assert(scalType); + auto arithmType = scalType->ToArithm(); assert(arithmType); if (arithmType->IsInteger()) arithmType = ArithmType::IntegerPromote(arithmType); - operand_ = Expr::MayCast(operand_, arithmType); - return arithmType; + ::Type* retType = ScalarOrLikeTile(operand_, arithmType); + operand_ = Expr::MayCast(operand_, retType); + return retType; } @@ -521,20 +606,22 @@ void UnaryOp::TypeChecking() { case Token::CAST: return CastOpTypeChecking(); + case '^': + return TransOpTypeChecking(); + default: assert(false); } } - void UnaryOp::IncDecOpTypeChecking() { if (operand_->IsConstQualified()) { Error(this, "increment/decrement of const qualified expression"); } else if (!operand_->IsLVal()) { Error(this, "lvalue expression expected"); } - - if (!operand_->Type()->IsReal() && !operand_->Type()->ToPointer()) { + auto scalType = TryExtractScalarType(this, operand_); + if (!scalType->IsReal() && !scalType->ToPointer()) { Error(this, "expect operand of real type or pointer"); } type_ = operand_->Type(); @@ -545,43 +632,78 @@ void UnaryOp::AddrOpTypeChecking() { auto funcType = operand_->Type()->ToFunc(); if (funcType == nullptr && !operand_->IsLVal()) Error(this, "expression must be an lvalue or function designator"); + if(operand_->Type()->IsTile()) + Error(this, "cannot take the address of a tile"); type_ = PointerType::New(operand_->Type()); } void UnaryOp::DerefOpTypeChecking() { - auto pointerType = operand_->Type()->ToPointer(); + auto scalType = TryExtractScalarType(this, operand_); + auto pointerType = scalType->ToPointer(); if (!pointerType) Error(this, "pointer expected for deref operator '*'"); - type_ = pointerType->Derived(); + type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr()); } +void UnaryOp::TransOpTypeChecking() { + auto tileType = operand_->Type()->ToTile(); + if(!tileType) + Error(this, "tile expected for transposition operator '^'"); + auto shape = tileType->Shape(); + std::rotate(shape.begin(), shape.begin() + 1, shape.end()); + type_ = TileType::New(shape, tileType->Derived()); +} + void UnaryOp::UnaryArithmOpTypeChecking() { + auto scalType = TryExtractScalarType(this, operand_); if (Token::PLUS == op_ || Token::MINUS == op_) { - if (!operand_->Type()->ToArithm()) + if (!scalType->ToArithm()) Error(this, "Arithmetic type expected"); Convert(); type_ = operand_->Type(); } else if ('~' == op_) { - if (!operand_->Type()->IsInteger()) + if (!scalType->IsInteger()) Error(this, "integer expected for operator '~'"); Convert(); type_ = operand_->Type(); - } else if (!operand_->Type()->IsScalar()) { + } else if (!scalType->IsScalar()) { Error(this, "arithmetic type or pointer expected for operator '!'"); } else { - type_ = ArithmType::New(T_INT); + type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_INT)); } } void UnaryOp::CastOpTypeChecking() { auto operandType = Type::MayCast(operand_->Type()); - // The type_ has been initiated to dest type if (type_->ToVoid()) { // The expression becomes a void expression + } else if(type_->IsTile() || operandType->IsTile()) { + /* Broadcasting rules: + * 1. Tiles with 1 element can be converted to scalar + * 2. Scalar can be converted to tiles of any shapes + * 3. Tiles can be converted to another tile only if the + * mismatching dimensions are unitary + */ + if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1) + Error(this, "tile with more than one element cannot be casted to scalar"); + if(type_->IsTile() && operandType->IsTile()){ + auto shape = type_->ToTile()->Shape(); + auto operandShape = operandType->ToTile()->Shape(); + if(operandShape.size() > shape.size()) + Error(this, "cast cannot reduce operand rank"); + while(operandShape.size() < shape.size()) + operandShape.insert(operandShape.begin(), 1); + for(size_t i = 0; i < shape.size(); i++) { + if(shape[i] != 1 && operandShape[i] != 1 && shape[i] != operandShape[i]) + Error(this, "cannot broadcast dimension %d " + "for operands of shape %d and %d", + i, shape[i], operandShape[i]); + } + } } else if (!type_->IsScalar() || !operandType->IsScalar()) { if (!type_->Compatible(*operandType)) Error(this, "the cast type should be arithemetic type or pointer"); diff --git a/lib/lang/wgtcc/parser.cc b/lib/lang/wgtcc/parser.cc index 8ec16ee51..cf1e582fc 100644 --- a/lib/lang/wgtcc/parser.cc +++ b/lib/lang/wgtcc/parser.cc @@ -442,11 +442,27 @@ Expr* Parser::ParsePostfixExprTail(Expr* lhs) { Expr* Parser::ParseSubScripting(Expr* lhs) { - auto rhs = ParseExpr(); - auto tok = ts_.Peek(); + auto lhsTile = lhs->Type()->ToTile(); + if(lhsTile == nullptr) + Error(lhs, "tile expected"); + TileType::ShapeInt lhsShape = lhsTile->Shape(); + QualType lhsQual = lhsTile->Derived(); + // create ret shape + TileType::ShapeInt shape; + size_t i = 0; + do { + auto tok = ts_.Next(); + if(tok->tag_ == ':') + shape.push_back(lhsShape[i++]); + else if(tok->tag_ == Token::NEWAXIS) + shape.push_back(1); + else + Error(tok, "only ':' and newaxis are supported in subscripts"); + }while(ts_.Try(',')); ts_.Expect(']'); - auto operand = BinaryOp::New(tok, '+', lhs, rhs); - return UnaryOp::New(Token::DEREF, operand); + // create ret tile + TileType *retType = TileType::New(shape, lhsQual); + return UnaryOp::New(Token::CAST, lhs, retType); } @@ -501,6 +517,7 @@ Expr* Parser::ParseUnaryExpr() { case '-': return ParseUnaryOp(tok, Token::MINUS); case '~': return ParseUnaryOp(tok, '~'); case '!': return ParseUnaryOp(tok, '!'); + case '^': return ParseUnaryOp(tok, Token::XOR); default: ts_.PutBack(); return ParsePostfixExpr(); @@ -584,7 +601,7 @@ Expr* Parser::ParseCastExpr() { Expr* Parser::ParseRangeExpr() { auto lhs = ParseCastExpr(); auto tok = ts_.Next(); - while (tok->tag_ == Token::ELLIPSIS){ + while (tok->tag_ == Token::ELLIPSIS) { auto rhs = ParseCastExpr(); lhs = BinaryOp::New(tok, lhs, rhs); tok = ts_.Next(); @@ -593,16 +610,26 @@ Expr* Parser::ParseRangeExpr() { return lhs; } -Expr* Parser::ParseMultiplicativeExpr() { +Expr* Parser::ParseMatmulExpr() { auto lhs = ParseRangeExpr(); auto tok = ts_.Next(); - while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') { + while (tok->tag_ == Token::MATMUL) { auto rhs = ParseRangeExpr(); lhs = BinaryOp::New(tok, lhs, rhs); - tok = ts_.Next(); } + ts_.PutBack(); + return lhs; +} +Expr* Parser::ParseMultiplicativeExpr() { + auto lhs = ParseMatmulExpr(); + auto tok = ts_.Next(); + while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') { + auto rhs = ParseMatmulExpr(); + lhs = BinaryOp::New(tok, lhs, rhs); + tok = ts_.Next(); + } ts_.PutBack(); return lhs; } diff --git a/lib/lang/wgtcc/token.cc b/lib/lang/wgtcc/token.cc index 62c9b41f6..ba588588e 100644 --- a/lib/lang/wgtcc/token.cc +++ b/lib/lang/wgtcc/token.cc @@ -27,6 +27,7 @@ const std::unordered_map Token::kwTypeMap_ { { "inline", Token::INLINE }, { "int", Token::INT }, { "long", Token::LONG }, + { "newaxis", Token::NEWAXIS }, { "signed", Token::SIGNED }, { "unsigned", Token::UNSIGNED }, { "register", Token::REGISTER }, @@ -126,6 +127,7 @@ const std::unordered_map Token::tagLexemeMap_ { { Token::INLINE, "inline" }, { Token::INT, "int" }, { Token::LONG, "long" }, + { Token::NEWAXIS, "newaxis" }, { Token::SIGNED, "signed" }, { Token::UNSIGNED, "unsigned" }, { Token::REGISTER, "register" }, diff --git a/lib/lang/wgtcc/type.cc b/lib/lang/wgtcc/type.cc index 369e8ed05..94f17b985 100644 --- a/lib/lang/wgtcc/type.cc +++ b/lib/lang/wgtcc/type.cc @@ -32,6 +32,18 @@ QualType Type::MayCast(QualType type, bool inProtoScope) { return type; } +const Type* Type::ScalarType() const { + if(IsScalar()) + return this; + if(const TileType* p = ToTile()) + return p->Derived().GetPtr(); + return nullptr; +} + +Type* Type::ScalarType() { + auto cthis = const_cast(this); + return const_cast(cthis->ScalarType()); +} VoidType* VoidType::New() { static auto ret = new (voidTypePool.Alloc()) VoidType(&voidTypePool); @@ -143,12 +155,16 @@ int ArithmType::Width() const { return intWidth_ << 1; case T_LLONG: case T_UNSIGNED | T_LLONG: return intWidth_ << 1; + case T_HALF: + return intWidth_ >> 1; case T_FLOAT: return intWidth_; case T_DOUBLE: return intWidth_ << 1; case T_LONG | T_DOUBLE: return intWidth_ << 1; + case T_HALF | T_COMPLEX: + return intWidth_; case T_FLOAT | T_COMPLEX: return intWidth_ << 1; case T_DOUBLE | T_COMPLEX: @@ -171,9 +187,10 @@ int ArithmType::Rank() const { case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3; case T_LONG: case T_UNSIGNED | T_LONG: return 4; case T_LLONG: case T_UNSIGNED | T_LLONG: return 5; - case T_FLOAT: return 6; - case T_DOUBLE: return 7; - case T_LONG | T_DOUBLE: return 8; + case T_HALF: return 6; + case T_FLOAT: return 7; + case T_DOUBLE: return 8; + case T_LONG | T_DOUBLE: return 9; default: assert(tag_ & T_COMPLEX); Error("complex not supported yet");