basic parsing doesn't throw error
This commit is contained in:
@@ -55,8 +55,8 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT ? "trans(a)" : "a";
|
||||
std::string useb = BT ? "trans(b)" : "b";
|
||||
std::string usea = AT ? "^a" : "a";
|
||||
std::string useb = BT ? "^b" : "b";
|
||||
if(AT){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(XAS0, XAS1);
|
||||
@@ -82,6 +82,11 @@ R"(
|
||||
#define TN 128
|
||||
#define TK 32
|
||||
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
|
||||
extern int get_program_id(int);
|
||||
|
||||
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||
@@ -94,28 +99,28 @@ void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[{TM, TN}] = ridx * TM + 0 ... TM;
|
||||
int ryb[TN] = ridy * TN + 0 ... TN;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float xc[)" + XCS + R"(] = 0;
|
||||
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[)" + AS + R"(] = *pa;
|
||||
)" + b_ty + R"( b[)" + BS + R"(] = *pb;
|
||||
int ryb[{TN}] = ridy * TN + 0 ... TN;
|
||||
int rka[{TK}] = 0 ... TK;
|
||||
int rkb[{TK}] = 0 ... TK;
|
||||
float xc[{)" + XCS + R"(}] = 0;
|
||||
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
|
||||
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = dot()" + usea + ", " + useb + R"(, xc);
|
||||
xc = )" + usea + " @ " + useb + R"( + xc;
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
int rxc[{TM}] = ridx * TM + (0 ... TM);
|
||||
int ryc[{TN}] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[{TM, TN}] = xc;
|
||||
bool checkc0[{TM}] = rxc < M;
|
||||
bool checkc1[{TN}] = ryc < N;
|
||||
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*pc = c;
|
||||
}
|
||||
)";
|
||||
|
@@ -278,6 +278,9 @@ public:
|
||||
|
||||
static Expr* MayCast(Expr* expr);
|
||||
static Expr* MayCast(Expr* expr, QualType desType);
|
||||
static ::Type* TryExtractScalarType(Expr* loc, Expr *operand);
|
||||
static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty);
|
||||
|
||||
virtual bool IsNullPointerConstant() const { return false; }
|
||||
bool IsConstQualified() const { return type_.IsConstQualified(); }
|
||||
bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
|
||||
@@ -332,6 +335,7 @@ public:
|
||||
void AdditiveOpTypeChecking();
|
||||
void ShiftOpTypeChecking();
|
||||
void RangeOpTypeChecking();
|
||||
void MatmulOpTypeChecking();
|
||||
void RelationalOpTypeChecking();
|
||||
void EqualityOpTypeChecking();
|
||||
void BitwiseOpTypeChecking();
|
||||
@@ -378,11 +382,12 @@ public:
|
||||
virtual ~UnaryOp() {}
|
||||
virtual void Accept(Visitor* v);
|
||||
virtual bool IsLVal();
|
||||
ArithmType* Convert();
|
||||
::Type *Convert();
|
||||
void TypeChecking();
|
||||
void IncDecOpTypeChecking();
|
||||
void AddrOpTypeChecking();
|
||||
void DerefOpTypeChecking();
|
||||
void TransOpTypeChecking();
|
||||
void UnaryArithmOpTypeChecking();
|
||||
void CastOpTypeChecking();
|
||||
|
||||
|
@@ -75,6 +75,7 @@ public:
|
||||
QualType ParseTypeName();
|
||||
Expr* ParseCastExpr();
|
||||
Expr* ParseRangeExpr();
|
||||
Expr* ParseMatmulExpr();
|
||||
Expr* ParseMultiplicativeExpr();
|
||||
Expr* ParseAdditiveExpr();
|
||||
Expr* ParseShiftExpr();
|
||||
|
@@ -64,7 +64,7 @@ public:
|
||||
NOT = '!',
|
||||
COND = '?',
|
||||
SHARP = '#',
|
||||
AT = '@',
|
||||
MATMUL = '@',
|
||||
NEW_LINE = '\n',
|
||||
|
||||
DSHARP = 128, // '##'
|
||||
@@ -126,6 +126,10 @@ public:
|
||||
NORETURN, // _Noreturn
|
||||
// FUNCTION SPECIFIER END
|
||||
|
||||
// TILE ARITHMETICS BEGIN
|
||||
NEWAXIS,
|
||||
// TILE ARITHMETICS END
|
||||
|
||||
ALIGNAS, // _Alignas
|
||||
// For syntactic convenience
|
||||
STATIC_ASSERT, // _Static_assert
|
||||
|
@@ -153,6 +153,10 @@ public:
|
||||
virtual bool IsBool() const { return false; }
|
||||
virtual bool IsVoidPointer() const { return false; }
|
||||
virtual bool IsUnsigned() const { return false; }
|
||||
virtual bool IsTile() const { return ToTile() != nullptr; }
|
||||
|
||||
const Type* ScalarType() const;
|
||||
Type* ScalarType();
|
||||
|
||||
virtual VoidType* ToVoid() { return nullptr; }
|
||||
virtual const VoidType* ToVoid() const { return nullptr; }
|
||||
@@ -327,16 +331,22 @@ public:
|
||||
static TileType* New(const ShapeInt& shape, QualType eleType);
|
||||
virtual ~TileType() { }
|
||||
|
||||
virtual TileType* toTile() { return this; }
|
||||
virtual const TileType* toTile() const { return this; }
|
||||
virtual TileType* ToTile() { return this; }
|
||||
virtual const TileType* ToTile() const { return this; }
|
||||
virtual bool Compatible(const Type& other) const;
|
||||
virtual int Width() const { return 0; }
|
||||
virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; }
|
||||
virtual int Align() const { return derived_->Align(); }
|
||||
virtual std::string Str() const {
|
||||
return derived_->Str() + "[{}]:" + std::to_string(Width());
|
||||
}
|
||||
|
||||
ShapeInt Shape() { return shape_; }
|
||||
int NumEle() const {
|
||||
int ret = 1;
|
||||
for(int s: shape_)
|
||||
ret *= s;
|
||||
return ret;
|
||||
}
|
||||
|
||||
protected:
|
||||
TileType(MemPool* pool, const ShapeExpr& expr, QualType derived)
|
||||
|
@@ -144,6 +144,26 @@ Expr* Expr::MayCast(Expr* expr, QualType desType) {
|
||||
return expr;
|
||||
}
|
||||
|
||||
// Extract the operand's scalar type if possible
|
||||
// and emit an error otherwise
|
||||
::Type* Expr::TryExtractScalarType(Expr* loc, Expr *operand) {
|
||||
auto scalType = operand->Type()->ScalarType();
|
||||
if(!scalType)
|
||||
Error(loc, "expect tile or scalar operand");
|
||||
return scalType;
|
||||
}
|
||||
|
||||
// If operand is a tile, return a tile of the same shape and
|
||||
// provided element type
|
||||
// If operand is a scalar, return provided element type
|
||||
// directly
|
||||
::Type* Expr::ScalarOrLikeTile(Expr* operand, ::Type* ty) {
|
||||
assert(ty->IsScalar());
|
||||
::Type *retTy = ty;
|
||||
if(TileType *T = operand->Type()->ToTile())
|
||||
retTy = TileType::New(T->Shape(), retTy);
|
||||
return retTy;
|
||||
}
|
||||
|
||||
BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) {
|
||||
return New(tok, tok->tag_, lhs, rhs);
|
||||
@@ -166,6 +186,7 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
|
||||
case Token::LOGICAL_AND:
|
||||
case Token::LOGICAL_OR:
|
||||
case Token::ELLIPSIS:
|
||||
case Token::MATMUL:
|
||||
break;
|
||||
default:
|
||||
assert(0);
|
||||
@@ -180,18 +201,18 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
|
||||
|
||||
|
||||
ArithmType* BinaryOp::Convert() {
|
||||
// Both lhs and rhs are ensured to be have arithmetic type
|
||||
auto lhsType = lhs_->Type()->ToArithm();
|
||||
auto rhsType = rhs_->Type()->ToArithm();
|
||||
// Both lhs and rhs are ensured to be have arithmetic scalar type
|
||||
auto lhsType = lhs_->Type()->ScalarType()->ToArithm();
|
||||
auto rhsType = rhs_->Type()->ScalarType()->ToArithm();
|
||||
assert(lhsType && rhsType);
|
||||
auto type = ArithmType::MaxType(lhsType, rhsType);
|
||||
if (lhsType != type) { // Pointer comparation is enough!
|
||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, type);
|
||||
auto maxType = ArithmType::MaxType(lhsType, rhsType);
|
||||
if (lhsType != maxType) { // Pointer comparation is enough!
|
||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, ScalarOrLikeTile(lhs_, maxType));
|
||||
}
|
||||
if (rhsType != type) {
|
||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, type);
|
||||
if (rhsType != maxType) {
|
||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, ScalarOrLikeTile(rhs_, maxType));
|
||||
}
|
||||
return type;
|
||||
return maxType;
|
||||
}
|
||||
|
||||
void BinaryOp::Broadcast() {
|
||||
@@ -225,6 +246,8 @@ void BinaryOp::Broadcast() {
|
||||
retShape[i] = rhsShape[i];
|
||||
else if(rhsShape[i] == 1)
|
||||
retShape[i] = lhsShape[i];
|
||||
else if(lhsShape[i] == rhsShape[i])
|
||||
retShape[i] = lhsShape[i];
|
||||
else
|
||||
Error(this, "cannot broadcast dimension %d "
|
||||
"for operands of shape %d and %d",
|
||||
@@ -232,8 +255,10 @@ void BinaryOp::Broadcast() {
|
||||
}
|
||||
auto eleType = lhsType->Derived();
|
||||
type_ = TileType::New(retShape, eleType);
|
||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
|
||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
|
||||
if(retShape != lhsShape)
|
||||
lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
|
||||
if(retShape != rhsShape)
|
||||
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,6 +328,9 @@ void BinaryOp::TypeChecking() {
|
||||
case Token::ELLIPSIS:
|
||||
return RangeOpTypeChecking();
|
||||
|
||||
case Token::MATMUL:
|
||||
return MatmulOpTypeChecking();
|
||||
|
||||
default:
|
||||
assert(0);
|
||||
}
|
||||
@@ -315,12 +343,15 @@ void BinaryOp::CommaOpTypeChecking() {
|
||||
|
||||
|
||||
void BinaryOp::SubScriptingOpTypeChecking() {
|
||||
auto lhsType = lhs_->Type()->ToPointer();
|
||||
assert(false);
|
||||
auto lhsType = lhs_->Type()->ToTile();
|
||||
|
||||
if (!lhsType) {
|
||||
Error(this, "an pointer expected");
|
||||
Error(this, "operator [] can only be used on tiles");
|
||||
}
|
||||
|
||||
if (!rhs_->Type()->IsInteger()) {
|
||||
Error(this, "the operand of [] should be intger");
|
||||
Error(this, "the operand of [] should be integer");
|
||||
}
|
||||
|
||||
// The type of [] operator is the derived type
|
||||
@@ -334,14 +365,20 @@ void BinaryOp::MemberRefOpTypeChecking() {
|
||||
|
||||
|
||||
void BinaryOp::MultiOpTypeChecking() {
|
||||
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
|
||||
::Type* lhsScalType = lhs_->Type()->ScalarType();
|
||||
::Type* rhsScalType = rhs_->Type()->ScalarType();
|
||||
if(!lhsScalType || !rhsScalType) {
|
||||
Error(this, "operands should have type or scalar type");
|
||||
}
|
||||
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
|
||||
Error(this, "operands should have arithmetic type");
|
||||
}
|
||||
if ('%' == op_ &&
|
||||
!(lhs_->Type()->IsInteger() && rhs_->Type()->IsInteger())) {
|
||||
!(lhsScalType->IsInteger() && rhsScalType->IsInteger())) {
|
||||
Error(this, "operands of '%%' should be integers");
|
||||
}
|
||||
type_ = Convert();
|
||||
Broadcast();
|
||||
}
|
||||
|
||||
|
||||
@@ -351,40 +388,47 @@ void BinaryOp::MultiOpTypeChecking() {
|
||||
* 2. pointer can be used:
|
||||
* 1. lhs of MINUS operator, and rhs must be integer or pointer;
|
||||
* 2. lhs/rhs of ADD operator, and the other operand must be integer;
|
||||
* 3. tiles can be used:
|
||||
* 1. the scalar type of lhs/rhs satisfy the above requirements
|
||||
* 2. lhs/rhs that have identical shape
|
||||
* 3. lhs/rhs that can be broadcast as per numpy-like semantics
|
||||
*/
|
||||
void BinaryOp::AdditiveOpTypeChecking() {
|
||||
auto lhsType = lhs_->Type()->ToPointer();
|
||||
auto rhsType = rhs_->Type()->ToPointer();
|
||||
if (lhsType) {
|
||||
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
auto lhsPtrType = lhsScalType->ToPointer();
|
||||
auto rhsPtrType = rhsScalType->ToPointer();
|
||||
if (lhsPtrType) {
|
||||
if (op_ == '-') {
|
||||
if (rhsType) {
|
||||
if (!lhsType->Compatible(*rhsType))
|
||||
if (rhsPtrType) {
|
||||
if (!lhsPtrType->Compatible(*rhsPtrType))
|
||||
Error(this, "invalid operands to binary -");
|
||||
type_ = ArithmType::New(T_LONG); // ptrdiff_t
|
||||
} else if (!rhs_->Type()->IsInteger()) {
|
||||
} else if (!rhsScalType->IsInteger()) {
|
||||
Error(this, "invalid operands to binary -");
|
||||
} else {
|
||||
type_ = lhsType;
|
||||
type_ = lhsPtrType;
|
||||
}
|
||||
} else if (!rhs_->Type()->IsInteger()) {
|
||||
} else if (!rhsScalType->IsInteger()) {
|
||||
Error(this, "invalid operands to binary +");
|
||||
} else {
|
||||
type_ = lhsType;
|
||||
type_ = lhsPtrType;
|
||||
}
|
||||
} else if (rhsType) {
|
||||
if (op_ == '+' && !lhs_->Type()->IsInteger()) {
|
||||
} else if (rhsPtrType) {
|
||||
if (op_ == '+' && !lhsScalType->IsInteger()) {
|
||||
Error(this, "invalid operands to binary '+'");
|
||||
} else if (op_ == '-' && !lhsType) {
|
||||
} else if (op_ == '-' && !lhsPtrType) {
|
||||
Error(this, "invalid operands to binary '-'");
|
||||
}
|
||||
type_ = op_ == '-' ? ArithmType::New(T_LONG): rhs_->Type();
|
||||
type_ = op_ == '-' ? ArithmType::New(T_LONG): rhsScalType;
|
||||
std::swap(lhs_, rhs_); // To simplify code gen
|
||||
} else {
|
||||
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
|
||||
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
|
||||
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
|
||||
}
|
||||
type_ = Convert();
|
||||
}
|
||||
Broadcast();
|
||||
}
|
||||
|
||||
void BinaryOp::RangeOpTypeChecking() {
|
||||
@@ -396,59 +440,95 @@ void BinaryOp::RangeOpTypeChecking() {
|
||||
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType));
|
||||
long begin = Evaluator<long>().Eval(lhs_);
|
||||
long end = Evaluator<long>().Eval(rhs_);
|
||||
int len = end - begin;
|
||||
int len = static_cast<int>(end - begin);
|
||||
if(len < 0)
|
||||
Error(this, "range cannot be negative");
|
||||
type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type());
|
||||
}
|
||||
|
||||
void BinaryOp::MatmulOpTypeChecking() {
|
||||
auto lhsType = lhs_->Type()->ToTile();
|
||||
auto rhsType = rhs_->Type()->ToTile();
|
||||
if(!lhsType || !rhsType)
|
||||
Error(this, "expect tile operands for matrix multiplication");
|
||||
auto lhsShape = lhsType->Shape();
|
||||
auto rhsShape = rhsType->Shape();
|
||||
size_t lhsRank = lhsShape.size();
|
||||
size_t rhsRank = rhsShape.size();
|
||||
if(lhsRank != 2 || rhsRank != 2)
|
||||
Error(this, "matrix multiplication operands must have rank 2");
|
||||
if(lhsShape[1] != rhsShape[0])
|
||||
Error(this, "matrix multiplication operands have incompatible inner dimension"
|
||||
" %d and %d", lhsShape[1], rhsShape[0]);
|
||||
TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]};
|
||||
QualType retType = lhsType->Derived();
|
||||
if(retType != rhsType->Derived())
|
||||
Error(this, "matrix multiplication operands have incompatible data types");
|
||||
type_ = TileType::New(retShape, lhsType->Derived());
|
||||
}
|
||||
|
||||
void BinaryOp::ShiftOpTypeChecking() {
|
||||
auto lhsType = lhs_->Type()->ToArithm();
|
||||
auto rhsType = rhs_->Type()->ToArithm();
|
||||
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
auto lhsType = lhsScalType->ToArithm();
|
||||
auto rhsType = rhsScalType->ToArithm();
|
||||
if (!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
|
||||
Error(this, "expect integers for shift operator");
|
||||
lhs_ = Expr::MayCast(lhs_, ArithmType::IntegerPromote(lhsType));
|
||||
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType));
|
||||
lhs_ = Expr::MayCast(lhs_, ScalarOrLikeTile(lhs_, ArithmType::IntegerPromote(lhsType)));
|
||||
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, ArithmType::IntegerPromote(rhsType)));
|
||||
type_ = lhs_->Type();
|
||||
Broadcast();
|
||||
}
|
||||
|
||||
|
||||
void BinaryOp::RelationalOpTypeChecking() {
|
||||
if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) {
|
||||
EnsureCompatible(lhs_->Type(), rhs_->Type());
|
||||
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
|
||||
EnsureCompatible(lhsScalType, rhsScalType);
|
||||
} else {
|
||||
if (!lhs_->Type()->IsReal() || !rhs_->Type()->IsReal()) {
|
||||
if (!lhsScalType->IsReal() || !rhsScalType->IsReal()) {
|
||||
Error(this, "expect real type of operands");
|
||||
}
|
||||
Convert();
|
||||
}
|
||||
type_ = ArithmType::New(T_INT);
|
||||
Broadcast();
|
||||
}
|
||||
|
||||
|
||||
void BinaryOp::EqualityOpTypeChecking() {
|
||||
if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) {
|
||||
EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type());
|
||||
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
|
||||
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
|
||||
} else {
|
||||
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm())
|
||||
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm())
|
||||
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
|
||||
Convert();
|
||||
}
|
||||
type_ = ArithmType::New(T_INT);
|
||||
Broadcast();
|
||||
}
|
||||
|
||||
|
||||
void BinaryOp::BitwiseOpTypeChecking() {
|
||||
if (!lhs_->Type()->IsInteger() || !rhs_->Type()->IsInteger())
|
||||
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
if (!lhsScalType->IsInteger() || !rhsScalType->IsInteger())
|
||||
Error(this, "operands of '&' should be integer");
|
||||
type_ = Convert();
|
||||
Broadcast();
|
||||
}
|
||||
|
||||
|
||||
void BinaryOp::LogicalOpTypeChecking() {
|
||||
if (!lhs_->Type()->IsScalar() || !rhs_->Type()->IsScalar())
|
||||
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
|
||||
Error(this, "the operand should be arithmetic type or pointer");
|
||||
type_ = ArithmType::New(T_INT);
|
||||
Broadcast();
|
||||
}
|
||||
|
||||
|
||||
@@ -459,12 +539,14 @@ void BinaryOp::AssignOpTypeChecking() {
|
||||
Error(lhs_, "lvalue expression expected");
|
||||
}
|
||||
|
||||
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
|
||||
EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type());
|
||||
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
|
||||
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
|
||||
}
|
||||
|
||||
// The other constraints are lefted to cast operator
|
||||
rhs_ = Expr::MayCast(rhs_, lhs_->Type());
|
||||
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
|
||||
type_ = lhs_->Type();
|
||||
}
|
||||
|
||||
@@ -488,13 +570,16 @@ bool UnaryOp::IsLVal() {
|
||||
}
|
||||
|
||||
|
||||
ArithmType* UnaryOp::Convert() {
|
||||
auto arithmType = operand_->Type()->ToArithm();
|
||||
::Type* UnaryOp::Convert() {
|
||||
auto scalType = operand_->Type()->ScalarType();
|
||||
assert(scalType);
|
||||
auto arithmType = scalType->ToArithm();
|
||||
assert(arithmType);
|
||||
if (arithmType->IsInteger())
|
||||
arithmType = ArithmType::IntegerPromote(arithmType);
|
||||
operand_ = Expr::MayCast(operand_, arithmType);
|
||||
return arithmType;
|
||||
::Type* retType = ScalarOrLikeTile(operand_, arithmType);
|
||||
operand_ = Expr::MayCast(operand_, retType);
|
||||
return retType;
|
||||
}
|
||||
|
||||
|
||||
@@ -521,20 +606,22 @@ void UnaryOp::TypeChecking() {
|
||||
case Token::CAST:
|
||||
return CastOpTypeChecking();
|
||||
|
||||
case '^':
|
||||
return TransOpTypeChecking();
|
||||
|
||||
default:
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void UnaryOp::IncDecOpTypeChecking() {
|
||||
if (operand_->IsConstQualified()) {
|
||||
Error(this, "increment/decrement of const qualified expression");
|
||||
} else if (!operand_->IsLVal()) {
|
||||
Error(this, "lvalue expression expected");
|
||||
}
|
||||
|
||||
if (!operand_->Type()->IsReal() && !operand_->Type()->ToPointer()) {
|
||||
auto scalType = TryExtractScalarType(this, operand_);
|
||||
if (!scalType->IsReal() && !scalType->ToPointer()) {
|
||||
Error(this, "expect operand of real type or pointer");
|
||||
}
|
||||
type_ = operand_->Type();
|
||||
@@ -545,43 +632,78 @@ void UnaryOp::AddrOpTypeChecking() {
|
||||
auto funcType = operand_->Type()->ToFunc();
|
||||
if (funcType == nullptr && !operand_->IsLVal())
|
||||
Error(this, "expression must be an lvalue or function designator");
|
||||
if(operand_->Type()->IsTile())
|
||||
Error(this, "cannot take the address of a tile");
|
||||
type_ = PointerType::New(operand_->Type());
|
||||
}
|
||||
|
||||
|
||||
void UnaryOp::DerefOpTypeChecking() {
|
||||
auto pointerType = operand_->Type()->ToPointer();
|
||||
auto scalType = TryExtractScalarType(this, operand_);
|
||||
auto pointerType = scalType->ToPointer();
|
||||
if (!pointerType)
|
||||
Error(this, "pointer expected for deref operator '*'");
|
||||
type_ = pointerType->Derived();
|
||||
type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
|
||||
}
|
||||
|
||||
|
||||
void UnaryOp::TransOpTypeChecking() {
|
||||
auto tileType = operand_->Type()->ToTile();
|
||||
if(!tileType)
|
||||
Error(this, "tile expected for transposition operator '^'");
|
||||
auto shape = tileType->Shape();
|
||||
std::rotate(shape.begin(), shape.begin() + 1, shape.end());
|
||||
type_ = TileType::New(shape, tileType->Derived());
|
||||
}
|
||||
|
||||
void UnaryOp::UnaryArithmOpTypeChecking() {
|
||||
auto scalType = TryExtractScalarType(this, operand_);
|
||||
if (Token::PLUS == op_ || Token::MINUS == op_) {
|
||||
if (!operand_->Type()->ToArithm())
|
||||
if (!scalType->ToArithm())
|
||||
Error(this, "Arithmetic type expected");
|
||||
Convert();
|
||||
type_ = operand_->Type();
|
||||
} else if ('~' == op_) {
|
||||
if (!operand_->Type()->IsInteger())
|
||||
if (!scalType->IsInteger())
|
||||
Error(this, "integer expected for operator '~'");
|
||||
Convert();
|
||||
type_ = operand_->Type();
|
||||
} else if (!operand_->Type()->IsScalar()) {
|
||||
} else if (!scalType->IsScalar()) {
|
||||
Error(this, "arithmetic type or pointer expected for operator '!'");
|
||||
} else {
|
||||
type_ = ArithmType::New(T_INT);
|
||||
type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_INT));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void UnaryOp::CastOpTypeChecking() {
|
||||
auto operandType = Type::MayCast(operand_->Type());
|
||||
|
||||
// The type_ has been initiated to dest type
|
||||
if (type_->ToVoid()) {
|
||||
// The expression becomes a void expression
|
||||
} else if(type_->IsTile() || operandType->IsTile()) {
|
||||
/* Broadcasting rules:
|
||||
* 1. Tiles with 1 element can be converted to scalar
|
||||
* 2. Scalar can be converted to tiles of any shapes
|
||||
* 3. Tiles can be converted to another tile only if the
|
||||
* mismatching dimensions are unitary
|
||||
*/
|
||||
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
|
||||
Error(this, "tile with more than one element cannot be casted to scalar");
|
||||
if(type_->IsTile() && operandType->IsTile()){
|
||||
auto shape = type_->ToTile()->Shape();
|
||||
auto operandShape = operandType->ToTile()->Shape();
|
||||
if(operandShape.size() > shape.size())
|
||||
Error(this, "cast cannot reduce operand rank");
|
||||
while(operandShape.size() < shape.size())
|
||||
operandShape.insert(operandShape.begin(), 1);
|
||||
for(size_t i = 0; i < shape.size(); i++) {
|
||||
if(shape[i] != 1 && operandShape[i] != 1 && shape[i] != operandShape[i])
|
||||
Error(this, "cannot broadcast dimension %d "
|
||||
"for operands of shape %d and %d",
|
||||
i, shape[i], operandShape[i]);
|
||||
}
|
||||
}
|
||||
} else if (!type_->IsScalar() || !operandType->IsScalar()) {
|
||||
if (!type_->Compatible(*operandType))
|
||||
Error(this, "the cast type should be arithemetic type or pointer");
|
||||
|
@@ -442,11 +442,27 @@ Expr* Parser::ParsePostfixExprTail(Expr* lhs) {
|
||||
|
||||
|
||||
Expr* Parser::ParseSubScripting(Expr* lhs) {
|
||||
auto rhs = ParseExpr();
|
||||
auto tok = ts_.Peek();
|
||||
auto lhsTile = lhs->Type()->ToTile();
|
||||
if(lhsTile == nullptr)
|
||||
Error(lhs, "tile expected");
|
||||
TileType::ShapeInt lhsShape = lhsTile->Shape();
|
||||
QualType lhsQual = lhsTile->Derived();
|
||||
// create ret shape
|
||||
TileType::ShapeInt shape;
|
||||
size_t i = 0;
|
||||
do {
|
||||
auto tok = ts_.Next();
|
||||
if(tok->tag_ == ':')
|
||||
shape.push_back(lhsShape[i++]);
|
||||
else if(tok->tag_ == Token::NEWAXIS)
|
||||
shape.push_back(1);
|
||||
else
|
||||
Error(tok, "only ':' and newaxis are supported in subscripts");
|
||||
}while(ts_.Try(','));
|
||||
ts_.Expect(']');
|
||||
auto operand = BinaryOp::New(tok, '+', lhs, rhs);
|
||||
return UnaryOp::New(Token::DEREF, operand);
|
||||
// create ret tile
|
||||
TileType *retType = TileType::New(shape, lhsQual);
|
||||
return UnaryOp::New(Token::CAST, lhs, retType);
|
||||
}
|
||||
|
||||
|
||||
@@ -501,6 +517,7 @@ Expr* Parser::ParseUnaryExpr() {
|
||||
case '-': return ParseUnaryOp(tok, Token::MINUS);
|
||||
case '~': return ParseUnaryOp(tok, '~');
|
||||
case '!': return ParseUnaryOp(tok, '!');
|
||||
case '^': return ParseUnaryOp(tok, Token::XOR);
|
||||
default:
|
||||
ts_.PutBack();
|
||||
return ParsePostfixExpr();
|
||||
@@ -584,7 +601,7 @@ Expr* Parser::ParseCastExpr() {
|
||||
Expr* Parser::ParseRangeExpr() {
|
||||
auto lhs = ParseCastExpr();
|
||||
auto tok = ts_.Next();
|
||||
while (tok->tag_ == Token::ELLIPSIS){
|
||||
while (tok->tag_ == Token::ELLIPSIS) {
|
||||
auto rhs = ParseCastExpr();
|
||||
lhs = BinaryOp::New(tok, lhs, rhs);
|
||||
tok = ts_.Next();
|
||||
@@ -593,16 +610,26 @@ Expr* Parser::ParseRangeExpr() {
|
||||
return lhs;
|
||||
}
|
||||
|
||||
Expr* Parser::ParseMultiplicativeExpr() {
|
||||
Expr* Parser::ParseMatmulExpr() {
|
||||
auto lhs = ParseRangeExpr();
|
||||
auto tok = ts_.Next();
|
||||
while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') {
|
||||
while (tok->tag_ == Token::MATMUL) {
|
||||
auto rhs = ParseRangeExpr();
|
||||
lhs = BinaryOp::New(tok, lhs, rhs);
|
||||
|
||||
tok = ts_.Next();
|
||||
}
|
||||
ts_.PutBack();
|
||||
return lhs;
|
||||
}
|
||||
|
||||
Expr* Parser::ParseMultiplicativeExpr() {
|
||||
auto lhs = ParseMatmulExpr();
|
||||
auto tok = ts_.Next();
|
||||
while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') {
|
||||
auto rhs = ParseMatmulExpr();
|
||||
lhs = BinaryOp::New(tok, lhs, rhs);
|
||||
tok = ts_.Next();
|
||||
}
|
||||
ts_.PutBack();
|
||||
return lhs;
|
||||
}
|
||||
|
@@ -27,6 +27,7 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
|
||||
{ "inline", Token::INLINE },
|
||||
{ "int", Token::INT },
|
||||
{ "long", Token::LONG },
|
||||
{ "newaxis", Token::NEWAXIS },
|
||||
{ "signed", Token::SIGNED },
|
||||
{ "unsigned", Token::UNSIGNED },
|
||||
{ "register", Token::REGISTER },
|
||||
@@ -126,6 +127,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
|
||||
{ Token::INLINE, "inline" },
|
||||
{ Token::INT, "int" },
|
||||
{ Token::LONG, "long" },
|
||||
{ Token::NEWAXIS, "newaxis" },
|
||||
{ Token::SIGNED, "signed" },
|
||||
{ Token::UNSIGNED, "unsigned" },
|
||||
{ Token::REGISTER, "register" },
|
||||
|
@@ -32,6 +32,18 @@ QualType Type::MayCast(QualType type, bool inProtoScope) {
|
||||
return type;
|
||||
}
|
||||
|
||||
const Type* Type::ScalarType() const {
|
||||
if(IsScalar())
|
||||
return this;
|
||||
if(const TileType* p = ToTile())
|
||||
return p->Derived().GetPtr();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type* Type::ScalarType() {
|
||||
auto cthis = const_cast<const Type*>(this);
|
||||
return const_cast<Type*>(cthis->ScalarType());
|
||||
}
|
||||
|
||||
VoidType* VoidType::New() {
|
||||
static auto ret = new (voidTypePool.Alloc()) VoidType(&voidTypePool);
|
||||
@@ -143,12 +155,16 @@ int ArithmType::Width() const {
|
||||
return intWidth_ << 1;
|
||||
case T_LLONG: case T_UNSIGNED | T_LLONG:
|
||||
return intWidth_ << 1;
|
||||
case T_HALF:
|
||||
return intWidth_ >> 1;
|
||||
case T_FLOAT:
|
||||
return intWidth_;
|
||||
case T_DOUBLE:
|
||||
return intWidth_ << 1;
|
||||
case T_LONG | T_DOUBLE:
|
||||
return intWidth_ << 1;
|
||||
case T_HALF | T_COMPLEX:
|
||||
return intWidth_;
|
||||
case T_FLOAT | T_COMPLEX:
|
||||
return intWidth_ << 1;
|
||||
case T_DOUBLE | T_COMPLEX:
|
||||
@@ -171,9 +187,10 @@ int ArithmType::Rank() const {
|
||||
case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3;
|
||||
case T_LONG: case T_UNSIGNED | T_LONG: return 4;
|
||||
case T_LLONG: case T_UNSIGNED | T_LLONG: return 5;
|
||||
case T_FLOAT: return 6;
|
||||
case T_DOUBLE: return 7;
|
||||
case T_LONG | T_DOUBLE: return 8;
|
||||
case T_HALF: return 6;
|
||||
case T_FLOAT: return 7;
|
||||
case T_DOUBLE: return 8;
|
||||
case T_LONG | T_DOUBLE: return 9;
|
||||
default:
|
||||
assert(tag_ & T_COMPLEX);
|
||||
Error("complex not supported yet");
|
||||
|
Reference in New Issue
Block a user