basic parsing doesn't throw error
This commit is contained in:
@@ -55,8 +55,8 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
|
|||||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||||
std::string lda0 = "*lda", lda1 = "";
|
std::string lda0 = "*lda", lda1 = "";
|
||||||
std::string ldb0 = "", ldb1 = "*ldb";
|
std::string ldb0 = "", ldb1 = "*ldb";
|
||||||
std::string usea = AT ? "trans(a)" : "a";
|
std::string usea = AT ? "^a" : "a";
|
||||||
std::string useb = BT ? "trans(b)" : "b";
|
std::string useb = BT ? "^b" : "b";
|
||||||
if(AT){
|
if(AT){
|
||||||
std::swap(AS0, AS1);
|
std::swap(AS0, AS1);
|
||||||
std::swap(XAS0, XAS1);
|
std::swap(XAS0, XAS1);
|
||||||
@@ -82,6 +82,11 @@ R"(
|
|||||||
#define TN 128
|
#define TN 128
|
||||||
#define TK 32
|
#define TK 32
|
||||||
|
|
||||||
|
#define bool _Bool
|
||||||
|
#define true 1
|
||||||
|
#define false 0
|
||||||
|
#define __bool_true_false_are_defined 1
|
||||||
|
|
||||||
extern int get_program_id(int);
|
extern int get_program_id(int);
|
||||||
|
|
||||||
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||||
@@ -94,28 +99,28 @@ void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
|||||||
int ridx = get_program_id(0);
|
int ridx = get_program_id(0);
|
||||||
int ridy = get_program_id(1);
|
int ridy = get_program_id(1);
|
||||||
int rxa[{TM, TN}] = ridx * TM + 0 ... TM;
|
int rxa[{TM, TN}] = ridx * TM + 0 ... TM;
|
||||||
int ryb[TN] = ridy * TN + 0 ... TN;
|
int ryb[{TN}] = ridy * TN + 0 ... TN;
|
||||||
int rka[TK] = 0 ... TK;
|
int rka[{TK}] = 0 ... TK;
|
||||||
int rkb[TK] = 0 ... TK;
|
int rkb[{TK}] = 0 ... TK;
|
||||||
float xc[)" + XCS + R"(] = 0;
|
float xc[{)" + XCS + R"(}] = 0;
|
||||||
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||||
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||||
)" + a_ty + R"( a[)" + AS + R"(] = *pa;
|
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
|
||||||
)" + b_ty + R"( b[)" + BS + R"(] = *pb;
|
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
|
||||||
for(int k = K; k > 0; k = k - TK){
|
for(int k = K; k > 0; k = k - TK){
|
||||||
xc = dot()" + usea + ", " + useb + R"(, xc);
|
xc = )" + usea + " @ " + useb + R"( + xc;
|
||||||
pa = pa + TK)" + lda0 + R"(;
|
pa = pa + TK)" + lda0 + R"(;
|
||||||
pb = pb + TK)" + ldb0 + R"(;
|
pb = pb + TK)" + ldb0 + R"(;
|
||||||
a = *pa;
|
a = *pa;
|
||||||
b = *pb;
|
b = *pb;
|
||||||
}
|
}
|
||||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
int rxc[{TM}] = ridx * TM + (0 ... TM);
|
||||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
int ryc[{TN}] = ridy * TN + (0 ... TN);
|
||||||
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||||
)" + c_ty + R"( c[TM, TN] = xc;
|
)" + c_ty + R"( c[{TM, TN}] = xc;
|
||||||
bool checkc0[TM] = rxc < M;
|
bool checkc0[{TM}] = rxc < M;
|
||||||
bool checkc1[TN] = ryc < N;
|
bool checkc1[{TN}] = ryc < N;
|
||||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||||
*pc = c;
|
*pc = c;
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
@@ -278,6 +278,9 @@ public:
|
|||||||
|
|
||||||
static Expr* MayCast(Expr* expr);
|
static Expr* MayCast(Expr* expr);
|
||||||
static Expr* MayCast(Expr* expr, QualType desType);
|
static Expr* MayCast(Expr* expr, QualType desType);
|
||||||
|
static ::Type* TryExtractScalarType(Expr* loc, Expr *operand);
|
||||||
|
static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty);
|
||||||
|
|
||||||
virtual bool IsNullPointerConstant() const { return false; }
|
virtual bool IsNullPointerConstant() const { return false; }
|
||||||
bool IsConstQualified() const { return type_.IsConstQualified(); }
|
bool IsConstQualified() const { return type_.IsConstQualified(); }
|
||||||
bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
|
bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
|
||||||
@@ -332,6 +335,7 @@ public:
|
|||||||
void AdditiveOpTypeChecking();
|
void AdditiveOpTypeChecking();
|
||||||
void ShiftOpTypeChecking();
|
void ShiftOpTypeChecking();
|
||||||
void RangeOpTypeChecking();
|
void RangeOpTypeChecking();
|
||||||
|
void MatmulOpTypeChecking();
|
||||||
void RelationalOpTypeChecking();
|
void RelationalOpTypeChecking();
|
||||||
void EqualityOpTypeChecking();
|
void EqualityOpTypeChecking();
|
||||||
void BitwiseOpTypeChecking();
|
void BitwiseOpTypeChecking();
|
||||||
@@ -378,11 +382,12 @@ public:
|
|||||||
virtual ~UnaryOp() {}
|
virtual ~UnaryOp() {}
|
||||||
virtual void Accept(Visitor* v);
|
virtual void Accept(Visitor* v);
|
||||||
virtual bool IsLVal();
|
virtual bool IsLVal();
|
||||||
ArithmType* Convert();
|
::Type *Convert();
|
||||||
void TypeChecking();
|
void TypeChecking();
|
||||||
void IncDecOpTypeChecking();
|
void IncDecOpTypeChecking();
|
||||||
void AddrOpTypeChecking();
|
void AddrOpTypeChecking();
|
||||||
void DerefOpTypeChecking();
|
void DerefOpTypeChecking();
|
||||||
|
void TransOpTypeChecking();
|
||||||
void UnaryArithmOpTypeChecking();
|
void UnaryArithmOpTypeChecking();
|
||||||
void CastOpTypeChecking();
|
void CastOpTypeChecking();
|
||||||
|
|
||||||
|
@@ -75,6 +75,7 @@ public:
|
|||||||
QualType ParseTypeName();
|
QualType ParseTypeName();
|
||||||
Expr* ParseCastExpr();
|
Expr* ParseCastExpr();
|
||||||
Expr* ParseRangeExpr();
|
Expr* ParseRangeExpr();
|
||||||
|
Expr* ParseMatmulExpr();
|
||||||
Expr* ParseMultiplicativeExpr();
|
Expr* ParseMultiplicativeExpr();
|
||||||
Expr* ParseAdditiveExpr();
|
Expr* ParseAdditiveExpr();
|
||||||
Expr* ParseShiftExpr();
|
Expr* ParseShiftExpr();
|
||||||
|
@@ -64,7 +64,7 @@ public:
|
|||||||
NOT = '!',
|
NOT = '!',
|
||||||
COND = '?',
|
COND = '?',
|
||||||
SHARP = '#',
|
SHARP = '#',
|
||||||
AT = '@',
|
MATMUL = '@',
|
||||||
NEW_LINE = '\n',
|
NEW_LINE = '\n',
|
||||||
|
|
||||||
DSHARP = 128, // '##'
|
DSHARP = 128, // '##'
|
||||||
@@ -126,6 +126,10 @@ public:
|
|||||||
NORETURN, // _Noreturn
|
NORETURN, // _Noreturn
|
||||||
// FUNCTION SPECIFIER END
|
// FUNCTION SPECIFIER END
|
||||||
|
|
||||||
|
// TILE ARITHMETICS BEGIN
|
||||||
|
NEWAXIS,
|
||||||
|
// TILE ARITHMETICS END
|
||||||
|
|
||||||
ALIGNAS, // _Alignas
|
ALIGNAS, // _Alignas
|
||||||
// For syntactic convenience
|
// For syntactic convenience
|
||||||
STATIC_ASSERT, // _Static_assert
|
STATIC_ASSERT, // _Static_assert
|
||||||
|
@@ -153,6 +153,10 @@ public:
|
|||||||
virtual bool IsBool() const { return false; }
|
virtual bool IsBool() const { return false; }
|
||||||
virtual bool IsVoidPointer() const { return false; }
|
virtual bool IsVoidPointer() const { return false; }
|
||||||
virtual bool IsUnsigned() const { return false; }
|
virtual bool IsUnsigned() const { return false; }
|
||||||
|
virtual bool IsTile() const { return ToTile() != nullptr; }
|
||||||
|
|
||||||
|
const Type* ScalarType() const;
|
||||||
|
Type* ScalarType();
|
||||||
|
|
||||||
virtual VoidType* ToVoid() { return nullptr; }
|
virtual VoidType* ToVoid() { return nullptr; }
|
||||||
virtual const VoidType* ToVoid() const { return nullptr; }
|
virtual const VoidType* ToVoid() const { return nullptr; }
|
||||||
@@ -327,16 +331,22 @@ public:
|
|||||||
static TileType* New(const ShapeInt& shape, QualType eleType);
|
static TileType* New(const ShapeInt& shape, QualType eleType);
|
||||||
virtual ~TileType() { }
|
virtual ~TileType() { }
|
||||||
|
|
||||||
virtual TileType* toTile() { return this; }
|
virtual TileType* ToTile() { return this; }
|
||||||
virtual const TileType* toTile() const { return this; }
|
virtual const TileType* ToTile() const { return this; }
|
||||||
virtual bool Compatible(const Type& other) const;
|
virtual bool Compatible(const Type& other) const;
|
||||||
virtual int Width() const { return 0; }
|
virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; }
|
||||||
virtual int Align() const { return derived_->Align(); }
|
virtual int Align() const { return derived_->Align(); }
|
||||||
virtual std::string Str() const {
|
virtual std::string Str() const {
|
||||||
return derived_->Str() + "[{}]:" + std::to_string(Width());
|
return derived_->Str() + "[{}]:" + std::to_string(Width());
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapeInt Shape() { return shape_; }
|
ShapeInt Shape() { return shape_; }
|
||||||
|
int NumEle() const {
|
||||||
|
int ret = 1;
|
||||||
|
for(int s: shape_)
|
||||||
|
ret *= s;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
TileType(MemPool* pool, const ShapeExpr& expr, QualType derived)
|
TileType(MemPool* pool, const ShapeExpr& expr, QualType derived)
|
||||||
|
@@ -144,6 +144,26 @@ Expr* Expr::MayCast(Expr* expr, QualType desType) {
|
|||||||
return expr;
|
return expr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract the operand's scalar type if possible
|
||||||
|
// and emit an error otherwise
|
||||||
|
::Type* Expr::TryExtractScalarType(Expr* loc, Expr *operand) {
|
||||||
|
auto scalType = operand->Type()->ScalarType();
|
||||||
|
if(!scalType)
|
||||||
|
Error(loc, "expect tile or scalar operand");
|
||||||
|
return scalType;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If operand is a tile, return a tile of the same shape and
|
||||||
|
// provided element type
|
||||||
|
// If operand is a scalar, return provided element type
|
||||||
|
// directly
|
||||||
|
::Type* Expr::ScalarOrLikeTile(Expr* operand, ::Type* ty) {
|
||||||
|
assert(ty->IsScalar());
|
||||||
|
::Type *retTy = ty;
|
||||||
|
if(TileType *T = operand->Type()->ToTile())
|
||||||
|
retTy = TileType::New(T->Shape(), retTy);
|
||||||
|
return retTy;
|
||||||
|
}
|
||||||
|
|
||||||
BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) {
|
BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) {
|
||||||
return New(tok, tok->tag_, lhs, rhs);
|
return New(tok, tok->tag_, lhs, rhs);
|
||||||
@@ -166,6 +186,7 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
|
|||||||
case Token::LOGICAL_AND:
|
case Token::LOGICAL_AND:
|
||||||
case Token::LOGICAL_OR:
|
case Token::LOGICAL_OR:
|
||||||
case Token::ELLIPSIS:
|
case Token::ELLIPSIS:
|
||||||
|
case Token::MATMUL:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
assert(0);
|
assert(0);
|
||||||
@@ -180,18 +201,18 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
|
|||||||
|
|
||||||
|
|
||||||
ArithmType* BinaryOp::Convert() {
|
ArithmType* BinaryOp::Convert() {
|
||||||
// Both lhs and rhs are ensured to be have arithmetic type
|
// Both lhs and rhs are ensured to be have arithmetic scalar type
|
||||||
auto lhsType = lhs_->Type()->ToArithm();
|
auto lhsType = lhs_->Type()->ScalarType()->ToArithm();
|
||||||
auto rhsType = rhs_->Type()->ToArithm();
|
auto rhsType = rhs_->Type()->ScalarType()->ToArithm();
|
||||||
assert(lhsType && rhsType);
|
assert(lhsType && rhsType);
|
||||||
auto type = ArithmType::MaxType(lhsType, rhsType);
|
auto maxType = ArithmType::MaxType(lhsType, rhsType);
|
||||||
if (lhsType != type) { // Pointer comparation is enough!
|
if (lhsType != maxType) { // Pointer comparation is enough!
|
||||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, type);
|
lhs_ = UnaryOp::New(Token::CAST, lhs_, ScalarOrLikeTile(lhs_, maxType));
|
||||||
}
|
}
|
||||||
if (rhsType != type) {
|
if (rhsType != maxType) {
|
||||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, type);
|
rhs_ = UnaryOp::New(Token::CAST, rhs_, ScalarOrLikeTile(rhs_, maxType));
|
||||||
}
|
}
|
||||||
return type;
|
return maxType;
|
||||||
}
|
}
|
||||||
|
|
||||||
void BinaryOp::Broadcast() {
|
void BinaryOp::Broadcast() {
|
||||||
@@ -225,6 +246,8 @@ void BinaryOp::Broadcast() {
|
|||||||
retShape[i] = rhsShape[i];
|
retShape[i] = rhsShape[i];
|
||||||
else if(rhsShape[i] == 1)
|
else if(rhsShape[i] == 1)
|
||||||
retShape[i] = lhsShape[i];
|
retShape[i] = lhsShape[i];
|
||||||
|
else if(lhsShape[i] == rhsShape[i])
|
||||||
|
retShape[i] = lhsShape[i];
|
||||||
else
|
else
|
||||||
Error(this, "cannot broadcast dimension %d "
|
Error(this, "cannot broadcast dimension %d "
|
||||||
"for operands of shape %d and %d",
|
"for operands of shape %d and %d",
|
||||||
@@ -232,8 +255,10 @@ void BinaryOp::Broadcast() {
|
|||||||
}
|
}
|
||||||
auto eleType = lhsType->Derived();
|
auto eleType = lhsType->Derived();
|
||||||
type_ = TileType::New(retShape, eleType);
|
type_ = TileType::New(retShape, eleType);
|
||||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
|
if(retShape != lhsShape)
|
||||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
|
lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
|
||||||
|
if(retShape != rhsShape)
|
||||||
|
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -303,6 +328,9 @@ void BinaryOp::TypeChecking() {
|
|||||||
case Token::ELLIPSIS:
|
case Token::ELLIPSIS:
|
||||||
return RangeOpTypeChecking();
|
return RangeOpTypeChecking();
|
||||||
|
|
||||||
|
case Token::MATMUL:
|
||||||
|
return MatmulOpTypeChecking();
|
||||||
|
|
||||||
default:
|
default:
|
||||||
assert(0);
|
assert(0);
|
||||||
}
|
}
|
||||||
@@ -315,12 +343,15 @@ void BinaryOp::CommaOpTypeChecking() {
|
|||||||
|
|
||||||
|
|
||||||
void BinaryOp::SubScriptingOpTypeChecking() {
|
void BinaryOp::SubScriptingOpTypeChecking() {
|
||||||
auto lhsType = lhs_->Type()->ToPointer();
|
assert(false);
|
||||||
|
auto lhsType = lhs_->Type()->ToTile();
|
||||||
|
|
||||||
if (!lhsType) {
|
if (!lhsType) {
|
||||||
Error(this, "an pointer expected");
|
Error(this, "operator [] can only be used on tiles");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!rhs_->Type()->IsInteger()) {
|
if (!rhs_->Type()->IsInteger()) {
|
||||||
Error(this, "the operand of [] should be intger");
|
Error(this, "the operand of [] should be integer");
|
||||||
}
|
}
|
||||||
|
|
||||||
// The type of [] operator is the derived type
|
// The type of [] operator is the derived type
|
||||||
@@ -334,14 +365,20 @@ void BinaryOp::MemberRefOpTypeChecking() {
|
|||||||
|
|
||||||
|
|
||||||
void BinaryOp::MultiOpTypeChecking() {
|
void BinaryOp::MultiOpTypeChecking() {
|
||||||
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
|
::Type* lhsScalType = lhs_->Type()->ScalarType();
|
||||||
|
::Type* rhsScalType = rhs_->Type()->ScalarType();
|
||||||
|
if(!lhsScalType || !rhsScalType) {
|
||||||
|
Error(this, "operands should have type or scalar type");
|
||||||
|
}
|
||||||
|
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
|
||||||
Error(this, "operands should have arithmetic type");
|
Error(this, "operands should have arithmetic type");
|
||||||
}
|
}
|
||||||
if ('%' == op_ &&
|
if ('%' == op_ &&
|
||||||
!(lhs_->Type()->IsInteger() && rhs_->Type()->IsInteger())) {
|
!(lhsScalType->IsInteger() && rhsScalType->IsInteger())) {
|
||||||
Error(this, "operands of '%%' should be integers");
|
Error(this, "operands of '%%' should be integers");
|
||||||
}
|
}
|
||||||
type_ = Convert();
|
type_ = Convert();
|
||||||
|
Broadcast();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -351,40 +388,47 @@ void BinaryOp::MultiOpTypeChecking() {
|
|||||||
* 2. pointer can be used:
|
* 2. pointer can be used:
|
||||||
* 1. lhs of MINUS operator, and rhs must be integer or pointer;
|
* 1. lhs of MINUS operator, and rhs must be integer or pointer;
|
||||||
* 2. lhs/rhs of ADD operator, and the other operand must be integer;
|
* 2. lhs/rhs of ADD operator, and the other operand must be integer;
|
||||||
|
* 3. tiles can be used:
|
||||||
|
* 1. the scalar type of lhs/rhs satisfy the above requirements
|
||||||
|
* 2. lhs/rhs that have identical shape
|
||||||
|
* 3. lhs/rhs that can be broadcast as per numpy-like semantics
|
||||||
*/
|
*/
|
||||||
void BinaryOp::AdditiveOpTypeChecking() {
|
void BinaryOp::AdditiveOpTypeChecking() {
|
||||||
auto lhsType = lhs_->Type()->ToPointer();
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||||
auto rhsType = rhs_->Type()->ToPointer();
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||||
if (lhsType) {
|
auto lhsPtrType = lhsScalType->ToPointer();
|
||||||
|
auto rhsPtrType = rhsScalType->ToPointer();
|
||||||
|
if (lhsPtrType) {
|
||||||
if (op_ == '-') {
|
if (op_ == '-') {
|
||||||
if (rhsType) {
|
if (rhsPtrType) {
|
||||||
if (!lhsType->Compatible(*rhsType))
|
if (!lhsPtrType->Compatible(*rhsPtrType))
|
||||||
Error(this, "invalid operands to binary -");
|
Error(this, "invalid operands to binary -");
|
||||||
type_ = ArithmType::New(T_LONG); // ptrdiff_t
|
type_ = ArithmType::New(T_LONG); // ptrdiff_t
|
||||||
} else if (!rhs_->Type()->IsInteger()) {
|
} else if (!rhsScalType->IsInteger()) {
|
||||||
Error(this, "invalid operands to binary -");
|
Error(this, "invalid operands to binary -");
|
||||||
} else {
|
} else {
|
||||||
type_ = lhsType;
|
type_ = lhsPtrType;
|
||||||
}
|
}
|
||||||
} else if (!rhs_->Type()->IsInteger()) {
|
} else if (!rhsScalType->IsInteger()) {
|
||||||
Error(this, "invalid operands to binary +");
|
Error(this, "invalid operands to binary +");
|
||||||
} else {
|
} else {
|
||||||
type_ = lhsType;
|
type_ = lhsPtrType;
|
||||||
}
|
}
|
||||||
} else if (rhsType) {
|
} else if (rhsPtrType) {
|
||||||
if (op_ == '+' && !lhs_->Type()->IsInteger()) {
|
if (op_ == '+' && !lhsScalType->IsInteger()) {
|
||||||
Error(this, "invalid operands to binary '+'");
|
Error(this, "invalid operands to binary '+'");
|
||||||
} else if (op_ == '-' && !lhsType) {
|
} else if (op_ == '-' && !lhsPtrType) {
|
||||||
Error(this, "invalid operands to binary '-'");
|
Error(this, "invalid operands to binary '-'");
|
||||||
}
|
}
|
||||||
type_ = op_ == '-' ? ArithmType::New(T_LONG): rhs_->Type();
|
type_ = op_ == '-' ? ArithmType::New(T_LONG): rhsScalType;
|
||||||
std::swap(lhs_, rhs_); // To simplify code gen
|
std::swap(lhs_, rhs_); // To simplify code gen
|
||||||
} else {
|
} else {
|
||||||
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
|
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
|
||||||
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
|
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
|
||||||
}
|
}
|
||||||
type_ = Convert();
|
type_ = Convert();
|
||||||
}
|
}
|
||||||
|
Broadcast();
|
||||||
}
|
}
|
||||||
|
|
||||||
void BinaryOp::RangeOpTypeChecking() {
|
void BinaryOp::RangeOpTypeChecking() {
|
||||||
@@ -396,59 +440,95 @@ void BinaryOp::RangeOpTypeChecking() {
|
|||||||
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType));
|
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType));
|
||||||
long begin = Evaluator<long>().Eval(lhs_);
|
long begin = Evaluator<long>().Eval(lhs_);
|
||||||
long end = Evaluator<long>().Eval(rhs_);
|
long end = Evaluator<long>().Eval(rhs_);
|
||||||
int len = end - begin;
|
int len = static_cast<int>(end - begin);
|
||||||
if(len < 0)
|
if(len < 0)
|
||||||
Error(this, "range cannot be negative");
|
Error(this, "range cannot be negative");
|
||||||
type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type());
|
type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void BinaryOp::MatmulOpTypeChecking() {
|
||||||
|
auto lhsType = lhs_->Type()->ToTile();
|
||||||
|
auto rhsType = rhs_->Type()->ToTile();
|
||||||
|
if(!lhsType || !rhsType)
|
||||||
|
Error(this, "expect tile operands for matrix multiplication");
|
||||||
|
auto lhsShape = lhsType->Shape();
|
||||||
|
auto rhsShape = rhsType->Shape();
|
||||||
|
size_t lhsRank = lhsShape.size();
|
||||||
|
size_t rhsRank = rhsShape.size();
|
||||||
|
if(lhsRank != 2 || rhsRank != 2)
|
||||||
|
Error(this, "matrix multiplication operands must have rank 2");
|
||||||
|
if(lhsShape[1] != rhsShape[0])
|
||||||
|
Error(this, "matrix multiplication operands have incompatible inner dimension"
|
||||||
|
" %d and %d", lhsShape[1], rhsShape[0]);
|
||||||
|
TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]};
|
||||||
|
QualType retType = lhsType->Derived();
|
||||||
|
if(retType != rhsType->Derived())
|
||||||
|
Error(this, "matrix multiplication operands have incompatible data types");
|
||||||
|
type_ = TileType::New(retShape, lhsType->Derived());
|
||||||
|
}
|
||||||
|
|
||||||
void BinaryOp::ShiftOpTypeChecking() {
|
void BinaryOp::ShiftOpTypeChecking() {
|
||||||
auto lhsType = lhs_->Type()->ToArithm();
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||||
auto rhsType = rhs_->Type()->ToArithm();
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||||
|
auto lhsType = lhsScalType->ToArithm();
|
||||||
|
auto rhsType = rhsScalType->ToArithm();
|
||||||
if (!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
|
if (!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
|
||||||
Error(this, "expect integers for shift operator");
|
Error(this, "expect integers for shift operator");
|
||||||
lhs_ = Expr::MayCast(lhs_, ArithmType::IntegerPromote(lhsType));
|
lhs_ = Expr::MayCast(lhs_, ScalarOrLikeTile(lhs_, ArithmType::IntegerPromote(lhsType)));
|
||||||
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType));
|
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, ArithmType::IntegerPromote(rhsType)));
|
||||||
type_ = lhs_->Type();
|
type_ = lhs_->Type();
|
||||||
|
Broadcast();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void BinaryOp::RelationalOpTypeChecking() {
|
void BinaryOp::RelationalOpTypeChecking() {
|
||||||
if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) {
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||||
EnsureCompatible(lhs_->Type(), rhs_->Type());
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||||
|
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
|
||||||
|
EnsureCompatible(lhsScalType, rhsScalType);
|
||||||
} else {
|
} else {
|
||||||
if (!lhs_->Type()->IsReal() || !rhs_->Type()->IsReal()) {
|
if (!lhsScalType->IsReal() || !rhsScalType->IsReal()) {
|
||||||
Error(this, "expect real type of operands");
|
Error(this, "expect real type of operands");
|
||||||
}
|
}
|
||||||
Convert();
|
Convert();
|
||||||
}
|
}
|
||||||
type_ = ArithmType::New(T_INT);
|
type_ = ArithmType::New(T_INT);
|
||||||
|
Broadcast();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void BinaryOp::EqualityOpTypeChecking() {
|
void BinaryOp::EqualityOpTypeChecking() {
|
||||||
if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) {
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||||
EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type());
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||||
|
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
|
||||||
|
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
|
||||||
} else {
|
} else {
|
||||||
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm())
|
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm())
|
||||||
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
|
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
|
||||||
Convert();
|
Convert();
|
||||||
}
|
}
|
||||||
type_ = ArithmType::New(T_INT);
|
type_ = ArithmType::New(T_INT);
|
||||||
|
Broadcast();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void BinaryOp::BitwiseOpTypeChecking() {
|
void BinaryOp::BitwiseOpTypeChecking() {
|
||||||
if (!lhs_->Type()->IsInteger() || !rhs_->Type()->IsInteger())
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||||
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||||
|
if (!lhsScalType->IsInteger() || !rhsScalType->IsInteger())
|
||||||
Error(this, "operands of '&' should be integer");
|
Error(this, "operands of '&' should be integer");
|
||||||
type_ = Convert();
|
type_ = Convert();
|
||||||
|
Broadcast();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void BinaryOp::LogicalOpTypeChecking() {
|
void BinaryOp::LogicalOpTypeChecking() {
|
||||||
if (!lhs_->Type()->IsScalar() || !rhs_->Type()->IsScalar())
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||||
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||||
|
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
|
||||||
Error(this, "the operand should be arithmetic type or pointer");
|
Error(this, "the operand should be arithmetic type or pointer");
|
||||||
type_ = ArithmType::New(T_INT);
|
type_ = ArithmType::New(T_INT);
|
||||||
|
Broadcast();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -459,12 +539,14 @@ void BinaryOp::AssignOpTypeChecking() {
|
|||||||
Error(lhs_, "lvalue expression expected");
|
Error(lhs_, "lvalue expression expected");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||||
EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type());
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||||
|
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
|
||||||
|
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The other constraints are lefted to cast operator
|
// The other constraints are lefted to cast operator
|
||||||
rhs_ = Expr::MayCast(rhs_, lhs_->Type());
|
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
|
||||||
type_ = lhs_->Type();
|
type_ = lhs_->Type();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -488,13 +570,16 @@ bool UnaryOp::IsLVal() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
ArithmType* UnaryOp::Convert() {
|
::Type* UnaryOp::Convert() {
|
||||||
auto arithmType = operand_->Type()->ToArithm();
|
auto scalType = operand_->Type()->ScalarType();
|
||||||
|
assert(scalType);
|
||||||
|
auto arithmType = scalType->ToArithm();
|
||||||
assert(arithmType);
|
assert(arithmType);
|
||||||
if (arithmType->IsInteger())
|
if (arithmType->IsInteger())
|
||||||
arithmType = ArithmType::IntegerPromote(arithmType);
|
arithmType = ArithmType::IntegerPromote(arithmType);
|
||||||
operand_ = Expr::MayCast(operand_, arithmType);
|
::Type* retType = ScalarOrLikeTile(operand_, arithmType);
|
||||||
return arithmType;
|
operand_ = Expr::MayCast(operand_, retType);
|
||||||
|
return retType;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -521,20 +606,22 @@ void UnaryOp::TypeChecking() {
|
|||||||
case Token::CAST:
|
case Token::CAST:
|
||||||
return CastOpTypeChecking();
|
return CastOpTypeChecking();
|
||||||
|
|
||||||
|
case '^':
|
||||||
|
return TransOpTypeChecking();
|
||||||
|
|
||||||
default:
|
default:
|
||||||
assert(false);
|
assert(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void UnaryOp::IncDecOpTypeChecking() {
|
void UnaryOp::IncDecOpTypeChecking() {
|
||||||
if (operand_->IsConstQualified()) {
|
if (operand_->IsConstQualified()) {
|
||||||
Error(this, "increment/decrement of const qualified expression");
|
Error(this, "increment/decrement of const qualified expression");
|
||||||
} else if (!operand_->IsLVal()) {
|
} else if (!operand_->IsLVal()) {
|
||||||
Error(this, "lvalue expression expected");
|
Error(this, "lvalue expression expected");
|
||||||
}
|
}
|
||||||
|
auto scalType = TryExtractScalarType(this, operand_);
|
||||||
if (!operand_->Type()->IsReal() && !operand_->Type()->ToPointer()) {
|
if (!scalType->IsReal() && !scalType->ToPointer()) {
|
||||||
Error(this, "expect operand of real type or pointer");
|
Error(this, "expect operand of real type or pointer");
|
||||||
}
|
}
|
||||||
type_ = operand_->Type();
|
type_ = operand_->Type();
|
||||||
@@ -545,43 +632,78 @@ void UnaryOp::AddrOpTypeChecking() {
|
|||||||
auto funcType = operand_->Type()->ToFunc();
|
auto funcType = operand_->Type()->ToFunc();
|
||||||
if (funcType == nullptr && !operand_->IsLVal())
|
if (funcType == nullptr && !operand_->IsLVal())
|
||||||
Error(this, "expression must be an lvalue or function designator");
|
Error(this, "expression must be an lvalue or function designator");
|
||||||
|
if(operand_->Type()->IsTile())
|
||||||
|
Error(this, "cannot take the address of a tile");
|
||||||
type_ = PointerType::New(operand_->Type());
|
type_ = PointerType::New(operand_->Type());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void UnaryOp::DerefOpTypeChecking() {
|
void UnaryOp::DerefOpTypeChecking() {
|
||||||
auto pointerType = operand_->Type()->ToPointer();
|
auto scalType = TryExtractScalarType(this, operand_);
|
||||||
|
auto pointerType = scalType->ToPointer();
|
||||||
if (!pointerType)
|
if (!pointerType)
|
||||||
Error(this, "pointer expected for deref operator '*'");
|
Error(this, "pointer expected for deref operator '*'");
|
||||||
type_ = pointerType->Derived();
|
type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void UnaryOp::TransOpTypeChecking() {
|
||||||
|
auto tileType = operand_->Type()->ToTile();
|
||||||
|
if(!tileType)
|
||||||
|
Error(this, "tile expected for transposition operator '^'");
|
||||||
|
auto shape = tileType->Shape();
|
||||||
|
std::rotate(shape.begin(), shape.begin() + 1, shape.end());
|
||||||
|
type_ = TileType::New(shape, tileType->Derived());
|
||||||
|
}
|
||||||
|
|
||||||
void UnaryOp::UnaryArithmOpTypeChecking() {
|
void UnaryOp::UnaryArithmOpTypeChecking() {
|
||||||
|
auto scalType = TryExtractScalarType(this, operand_);
|
||||||
if (Token::PLUS == op_ || Token::MINUS == op_) {
|
if (Token::PLUS == op_ || Token::MINUS == op_) {
|
||||||
if (!operand_->Type()->ToArithm())
|
if (!scalType->ToArithm())
|
||||||
Error(this, "Arithmetic type expected");
|
Error(this, "Arithmetic type expected");
|
||||||
Convert();
|
Convert();
|
||||||
type_ = operand_->Type();
|
type_ = operand_->Type();
|
||||||
} else if ('~' == op_) {
|
} else if ('~' == op_) {
|
||||||
if (!operand_->Type()->IsInteger())
|
if (!scalType->IsInteger())
|
||||||
Error(this, "integer expected for operator '~'");
|
Error(this, "integer expected for operator '~'");
|
||||||
Convert();
|
Convert();
|
||||||
type_ = operand_->Type();
|
type_ = operand_->Type();
|
||||||
} else if (!operand_->Type()->IsScalar()) {
|
} else if (!scalType->IsScalar()) {
|
||||||
Error(this, "arithmetic type or pointer expected for operator '!'");
|
Error(this, "arithmetic type or pointer expected for operator '!'");
|
||||||
} else {
|
} else {
|
||||||
type_ = ArithmType::New(T_INT);
|
type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_INT));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void UnaryOp::CastOpTypeChecking() {
|
void UnaryOp::CastOpTypeChecking() {
|
||||||
auto operandType = Type::MayCast(operand_->Type());
|
auto operandType = Type::MayCast(operand_->Type());
|
||||||
|
|
||||||
// The type_ has been initiated to dest type
|
// The type_ has been initiated to dest type
|
||||||
if (type_->ToVoid()) {
|
if (type_->ToVoid()) {
|
||||||
// The expression becomes a void expression
|
// The expression becomes a void expression
|
||||||
|
} else if(type_->IsTile() || operandType->IsTile()) {
|
||||||
|
/* Broadcasting rules:
|
||||||
|
* 1. Tiles with 1 element can be converted to scalar
|
||||||
|
* 2. Scalar can be converted to tiles of any shapes
|
||||||
|
* 3. Tiles can be converted to another tile only if the
|
||||||
|
* mismatching dimensions are unitary
|
||||||
|
*/
|
||||||
|
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
|
||||||
|
Error(this, "tile with more than one element cannot be casted to scalar");
|
||||||
|
if(type_->IsTile() && operandType->IsTile()){
|
||||||
|
auto shape = type_->ToTile()->Shape();
|
||||||
|
auto operandShape = operandType->ToTile()->Shape();
|
||||||
|
if(operandShape.size() > shape.size())
|
||||||
|
Error(this, "cast cannot reduce operand rank");
|
||||||
|
while(operandShape.size() < shape.size())
|
||||||
|
operandShape.insert(operandShape.begin(), 1);
|
||||||
|
for(size_t i = 0; i < shape.size(); i++) {
|
||||||
|
if(shape[i] != 1 && operandShape[i] != 1 && shape[i] != operandShape[i])
|
||||||
|
Error(this, "cannot broadcast dimension %d "
|
||||||
|
"for operands of shape %d and %d",
|
||||||
|
i, shape[i], operandShape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
} else if (!type_->IsScalar() || !operandType->IsScalar()) {
|
} else if (!type_->IsScalar() || !operandType->IsScalar()) {
|
||||||
if (!type_->Compatible(*operandType))
|
if (!type_->Compatible(*operandType))
|
||||||
Error(this, "the cast type should be arithemetic type or pointer");
|
Error(this, "the cast type should be arithemetic type or pointer");
|
||||||
|
@@ -442,11 +442,27 @@ Expr* Parser::ParsePostfixExprTail(Expr* lhs) {
|
|||||||
|
|
||||||
|
|
||||||
Expr* Parser::ParseSubScripting(Expr* lhs) {
|
Expr* Parser::ParseSubScripting(Expr* lhs) {
|
||||||
auto rhs = ParseExpr();
|
auto lhsTile = lhs->Type()->ToTile();
|
||||||
auto tok = ts_.Peek();
|
if(lhsTile == nullptr)
|
||||||
|
Error(lhs, "tile expected");
|
||||||
|
TileType::ShapeInt lhsShape = lhsTile->Shape();
|
||||||
|
QualType lhsQual = lhsTile->Derived();
|
||||||
|
// create ret shape
|
||||||
|
TileType::ShapeInt shape;
|
||||||
|
size_t i = 0;
|
||||||
|
do {
|
||||||
|
auto tok = ts_.Next();
|
||||||
|
if(tok->tag_ == ':')
|
||||||
|
shape.push_back(lhsShape[i++]);
|
||||||
|
else if(tok->tag_ == Token::NEWAXIS)
|
||||||
|
shape.push_back(1);
|
||||||
|
else
|
||||||
|
Error(tok, "only ':' and newaxis are supported in subscripts");
|
||||||
|
}while(ts_.Try(','));
|
||||||
ts_.Expect(']');
|
ts_.Expect(']');
|
||||||
auto operand = BinaryOp::New(tok, '+', lhs, rhs);
|
// create ret tile
|
||||||
return UnaryOp::New(Token::DEREF, operand);
|
TileType *retType = TileType::New(shape, lhsQual);
|
||||||
|
return UnaryOp::New(Token::CAST, lhs, retType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -501,6 +517,7 @@ Expr* Parser::ParseUnaryExpr() {
|
|||||||
case '-': return ParseUnaryOp(tok, Token::MINUS);
|
case '-': return ParseUnaryOp(tok, Token::MINUS);
|
||||||
case '~': return ParseUnaryOp(tok, '~');
|
case '~': return ParseUnaryOp(tok, '~');
|
||||||
case '!': return ParseUnaryOp(tok, '!');
|
case '!': return ParseUnaryOp(tok, '!');
|
||||||
|
case '^': return ParseUnaryOp(tok, Token::XOR);
|
||||||
default:
|
default:
|
||||||
ts_.PutBack();
|
ts_.PutBack();
|
||||||
return ParsePostfixExpr();
|
return ParsePostfixExpr();
|
||||||
@@ -584,7 +601,7 @@ Expr* Parser::ParseCastExpr() {
|
|||||||
Expr* Parser::ParseRangeExpr() {
|
Expr* Parser::ParseRangeExpr() {
|
||||||
auto lhs = ParseCastExpr();
|
auto lhs = ParseCastExpr();
|
||||||
auto tok = ts_.Next();
|
auto tok = ts_.Next();
|
||||||
while (tok->tag_ == Token::ELLIPSIS){
|
while (tok->tag_ == Token::ELLIPSIS) {
|
||||||
auto rhs = ParseCastExpr();
|
auto rhs = ParseCastExpr();
|
||||||
lhs = BinaryOp::New(tok, lhs, rhs);
|
lhs = BinaryOp::New(tok, lhs, rhs);
|
||||||
tok = ts_.Next();
|
tok = ts_.Next();
|
||||||
@@ -593,16 +610,26 @@ Expr* Parser::ParseRangeExpr() {
|
|||||||
return lhs;
|
return lhs;
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr* Parser::ParseMultiplicativeExpr() {
|
Expr* Parser::ParseMatmulExpr() {
|
||||||
auto lhs = ParseRangeExpr();
|
auto lhs = ParseRangeExpr();
|
||||||
auto tok = ts_.Next();
|
auto tok = ts_.Next();
|
||||||
while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') {
|
while (tok->tag_ == Token::MATMUL) {
|
||||||
auto rhs = ParseRangeExpr();
|
auto rhs = ParseRangeExpr();
|
||||||
lhs = BinaryOp::New(tok, lhs, rhs);
|
lhs = BinaryOp::New(tok, lhs, rhs);
|
||||||
|
|
||||||
tok = ts_.Next();
|
tok = ts_.Next();
|
||||||
}
|
}
|
||||||
|
ts_.PutBack();
|
||||||
|
return lhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
Expr* Parser::ParseMultiplicativeExpr() {
|
||||||
|
auto lhs = ParseMatmulExpr();
|
||||||
|
auto tok = ts_.Next();
|
||||||
|
while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') {
|
||||||
|
auto rhs = ParseMatmulExpr();
|
||||||
|
lhs = BinaryOp::New(tok, lhs, rhs);
|
||||||
|
tok = ts_.Next();
|
||||||
|
}
|
||||||
ts_.PutBack();
|
ts_.PutBack();
|
||||||
return lhs;
|
return lhs;
|
||||||
}
|
}
|
||||||
|
@@ -27,6 +27,7 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
|
|||||||
{ "inline", Token::INLINE },
|
{ "inline", Token::INLINE },
|
||||||
{ "int", Token::INT },
|
{ "int", Token::INT },
|
||||||
{ "long", Token::LONG },
|
{ "long", Token::LONG },
|
||||||
|
{ "newaxis", Token::NEWAXIS },
|
||||||
{ "signed", Token::SIGNED },
|
{ "signed", Token::SIGNED },
|
||||||
{ "unsigned", Token::UNSIGNED },
|
{ "unsigned", Token::UNSIGNED },
|
||||||
{ "register", Token::REGISTER },
|
{ "register", Token::REGISTER },
|
||||||
@@ -126,6 +127,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
|
|||||||
{ Token::INLINE, "inline" },
|
{ Token::INLINE, "inline" },
|
||||||
{ Token::INT, "int" },
|
{ Token::INT, "int" },
|
||||||
{ Token::LONG, "long" },
|
{ Token::LONG, "long" },
|
||||||
|
{ Token::NEWAXIS, "newaxis" },
|
||||||
{ Token::SIGNED, "signed" },
|
{ Token::SIGNED, "signed" },
|
||||||
{ Token::UNSIGNED, "unsigned" },
|
{ Token::UNSIGNED, "unsigned" },
|
||||||
{ Token::REGISTER, "register" },
|
{ Token::REGISTER, "register" },
|
||||||
|
@@ -32,6 +32,18 @@ QualType Type::MayCast(QualType type, bool inProtoScope) {
|
|||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const Type* Type::ScalarType() const {
|
||||||
|
if(IsScalar())
|
||||||
|
return this;
|
||||||
|
if(const TileType* p = ToTile())
|
||||||
|
return p->Derived().GetPtr();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
Type* Type::ScalarType() {
|
||||||
|
auto cthis = const_cast<const Type*>(this);
|
||||||
|
return const_cast<Type*>(cthis->ScalarType());
|
||||||
|
}
|
||||||
|
|
||||||
VoidType* VoidType::New() {
|
VoidType* VoidType::New() {
|
||||||
static auto ret = new (voidTypePool.Alloc()) VoidType(&voidTypePool);
|
static auto ret = new (voidTypePool.Alloc()) VoidType(&voidTypePool);
|
||||||
@@ -143,12 +155,16 @@ int ArithmType::Width() const {
|
|||||||
return intWidth_ << 1;
|
return intWidth_ << 1;
|
||||||
case T_LLONG: case T_UNSIGNED | T_LLONG:
|
case T_LLONG: case T_UNSIGNED | T_LLONG:
|
||||||
return intWidth_ << 1;
|
return intWidth_ << 1;
|
||||||
|
case T_HALF:
|
||||||
|
return intWidth_ >> 1;
|
||||||
case T_FLOAT:
|
case T_FLOAT:
|
||||||
return intWidth_;
|
return intWidth_;
|
||||||
case T_DOUBLE:
|
case T_DOUBLE:
|
||||||
return intWidth_ << 1;
|
return intWidth_ << 1;
|
||||||
case T_LONG | T_DOUBLE:
|
case T_LONG | T_DOUBLE:
|
||||||
return intWidth_ << 1;
|
return intWidth_ << 1;
|
||||||
|
case T_HALF | T_COMPLEX:
|
||||||
|
return intWidth_;
|
||||||
case T_FLOAT | T_COMPLEX:
|
case T_FLOAT | T_COMPLEX:
|
||||||
return intWidth_ << 1;
|
return intWidth_ << 1;
|
||||||
case T_DOUBLE | T_COMPLEX:
|
case T_DOUBLE | T_COMPLEX:
|
||||||
@@ -171,9 +187,10 @@ int ArithmType::Rank() const {
|
|||||||
case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3;
|
case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3;
|
||||||
case T_LONG: case T_UNSIGNED | T_LONG: return 4;
|
case T_LONG: case T_UNSIGNED | T_LONG: return 4;
|
||||||
case T_LLONG: case T_UNSIGNED | T_LLONG: return 5;
|
case T_LLONG: case T_UNSIGNED | T_LLONG: return 5;
|
||||||
case T_FLOAT: return 6;
|
case T_HALF: return 6;
|
||||||
case T_DOUBLE: return 7;
|
case T_FLOAT: return 7;
|
||||||
case T_LONG | T_DOUBLE: return 8;
|
case T_DOUBLE: return 8;
|
||||||
|
case T_LONG | T_DOUBLE: return 9;
|
||||||
default:
|
default:
|
||||||
assert(tag_ & T_COMPLEX);
|
assert(tag_ & T_COMPLEX);
|
||||||
Error("complex not supported yet");
|
Error("complex not supported yet");
|
||||||
|
Reference in New Issue
Block a user