basic parsing doesn't throw error

This commit is contained in:
Philippe Tillet
2019-08-20 16:22:43 -07:00
parent bc11e31419
commit 61f25f90eb
9 changed files with 287 additions and 94 deletions

View File

@@ -55,8 +55,8 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = ""; std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb"; std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT ? "trans(a)" : "a"; std::string usea = AT ? "^a" : "a";
std::string useb = BT ? "trans(b)" : "b"; std::string useb = BT ? "^b" : "b";
if(AT){ if(AT){
std::swap(AS0, AS1); std::swap(AS0, AS1);
std::swap(XAS0, XAS1); std::swap(XAS0, XAS1);
@@ -82,6 +82,11 @@ R"(
#define TN 128 #define TN 128
#define TK 32 #define TK 32
#define bool _Bool
#define true 1
#define false 0
#define __bool_true_false_are_defined 1
extern int get_program_id(int); extern int get_program_id(int);
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))), void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
@@ -94,28 +99,28 @@ void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
int ridx = get_program_id(0); int ridx = get_program_id(0);
int ridy = get_program_id(1); int ridy = get_program_id(1);
int rxa[{TM, TN}] = ridx * TM + 0 ... TM; int rxa[{TM, TN}] = ridx * TM + 0 ... TM;
int ryb[TN] = ridy * TN + 0 ... TN; int ryb[{TN}] = ridy * TN + 0 ... TN;
int rka[TK] = 0 ... TK; int rka[{TK}] = 0 ... TK;
int rkb[TK] = 0 ... TK; int rkb[{TK}] = 0 ... TK;
float xc[)" + XCS + R"(] = 0; float xc[{)" + XCS + R"(}] = 0;
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(; )" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; )" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[)" + AS + R"(] = *pa; )" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
)" + b_ty + R"( b[)" + BS + R"(] = *pb; )" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
for(int k = K; k > 0; k = k - TK){ for(int k = K; k > 0; k = k - TK){
xc = dot()" + usea + ", " + useb + R"(, xc); xc = )" + usea + " @ " + useb + R"( + xc;
pa = pa + TK)" + lda0 + R"(; pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(; pb = pb + TK)" + ldb0 + R"(;
a = *pa; a = *pa;
b = *pb; b = *pb;
} }
int rxc[TM] = ridx * TM + (0 ... TM); int rxc[{TM}] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN); int ryc[{TN}] = ridy * TN + (0 ... TN);
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; )" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty + R"( c[TM, TN] = xc; )" + c_ty + R"( c[{TM, TN}] = xc;
bool checkc0[TM] = rxc < M; bool checkc0[{TM}] = rxc < M;
bool checkc1[TN] = ryc < N; bool checkc1[{TN}] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
*pc = c; *pc = c;
} }
)"; )";

View File

@@ -278,6 +278,9 @@ public:
static Expr* MayCast(Expr* expr); static Expr* MayCast(Expr* expr);
static Expr* MayCast(Expr* expr, QualType desType); static Expr* MayCast(Expr* expr, QualType desType);
static ::Type* TryExtractScalarType(Expr* loc, Expr *operand);
static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty);
virtual bool IsNullPointerConstant() const { return false; } virtual bool IsNullPointerConstant() const { return false; }
bool IsConstQualified() const { return type_.IsConstQualified(); } bool IsConstQualified() const { return type_.IsConstQualified(); }
bool IsRestrictQualified() const { return type_.IsRestrictQualified(); } bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
@@ -332,6 +335,7 @@ public:
void AdditiveOpTypeChecking(); void AdditiveOpTypeChecking();
void ShiftOpTypeChecking(); void ShiftOpTypeChecking();
void RangeOpTypeChecking(); void RangeOpTypeChecking();
void MatmulOpTypeChecking();
void RelationalOpTypeChecking(); void RelationalOpTypeChecking();
void EqualityOpTypeChecking(); void EqualityOpTypeChecking();
void BitwiseOpTypeChecking(); void BitwiseOpTypeChecking();
@@ -378,11 +382,12 @@ public:
virtual ~UnaryOp() {} virtual ~UnaryOp() {}
virtual void Accept(Visitor* v); virtual void Accept(Visitor* v);
virtual bool IsLVal(); virtual bool IsLVal();
ArithmType* Convert(); ::Type *Convert();
void TypeChecking(); void TypeChecking();
void IncDecOpTypeChecking(); void IncDecOpTypeChecking();
void AddrOpTypeChecking(); void AddrOpTypeChecking();
void DerefOpTypeChecking(); void DerefOpTypeChecking();
void TransOpTypeChecking();
void UnaryArithmOpTypeChecking(); void UnaryArithmOpTypeChecking();
void CastOpTypeChecking(); void CastOpTypeChecking();

View File

@@ -75,6 +75,7 @@ public:
QualType ParseTypeName(); QualType ParseTypeName();
Expr* ParseCastExpr(); Expr* ParseCastExpr();
Expr* ParseRangeExpr(); Expr* ParseRangeExpr();
Expr* ParseMatmulExpr();
Expr* ParseMultiplicativeExpr(); Expr* ParseMultiplicativeExpr();
Expr* ParseAdditiveExpr(); Expr* ParseAdditiveExpr();
Expr* ParseShiftExpr(); Expr* ParseShiftExpr();

View File

@@ -64,7 +64,7 @@ public:
NOT = '!', NOT = '!',
COND = '?', COND = '?',
SHARP = '#', SHARP = '#',
AT = '@', MATMUL = '@',
NEW_LINE = '\n', NEW_LINE = '\n',
DSHARP = 128, // '##' DSHARP = 128, // '##'
@@ -126,6 +126,10 @@ public:
NORETURN, // _Noreturn NORETURN, // _Noreturn
// FUNCTION SPECIFIER END // FUNCTION SPECIFIER END
// TILE ARITHMETICS BEGIN
NEWAXIS,
// TILE ARITHMETICS END
ALIGNAS, // _Alignas ALIGNAS, // _Alignas
// For syntactic convenience // For syntactic convenience
STATIC_ASSERT, // _Static_assert STATIC_ASSERT, // _Static_assert

View File

@@ -153,6 +153,10 @@ public:
virtual bool IsBool() const { return false; } virtual bool IsBool() const { return false; }
virtual bool IsVoidPointer() const { return false; } virtual bool IsVoidPointer() const { return false; }
virtual bool IsUnsigned() const { return false; } virtual bool IsUnsigned() const { return false; }
virtual bool IsTile() const { return ToTile() != nullptr; }
const Type* ScalarType() const;
Type* ScalarType();
virtual VoidType* ToVoid() { return nullptr; } virtual VoidType* ToVoid() { return nullptr; }
virtual const VoidType* ToVoid() const { return nullptr; } virtual const VoidType* ToVoid() const { return nullptr; }
@@ -327,16 +331,22 @@ public:
static TileType* New(const ShapeInt& shape, QualType eleType); static TileType* New(const ShapeInt& shape, QualType eleType);
virtual ~TileType() { } virtual ~TileType() { }
virtual TileType* toTile() { return this; } virtual TileType* ToTile() { return this; }
virtual const TileType* toTile() const { return this; } virtual const TileType* ToTile() const { return this; }
virtual bool Compatible(const Type& other) const; virtual bool Compatible(const Type& other) const;
virtual int Width() const { return 0; } virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; }
virtual int Align() const { return derived_->Align(); } virtual int Align() const { return derived_->Align(); }
virtual std::string Str() const { virtual std::string Str() const {
return derived_->Str() + "[{}]:" + std::to_string(Width()); return derived_->Str() + "[{}]:" + std::to_string(Width());
} }
ShapeInt Shape() { return shape_; } ShapeInt Shape() { return shape_; }
int NumEle() const {
int ret = 1;
for(int s: shape_)
ret *= s;
return ret;
}
protected: protected:
TileType(MemPool* pool, const ShapeExpr& expr, QualType derived) TileType(MemPool* pool, const ShapeExpr& expr, QualType derived)

View File

@@ -144,6 +144,26 @@ Expr* Expr::MayCast(Expr* expr, QualType desType) {
return expr; return expr;
} }
// Extract the operand's scalar type if possible
// and emit an error otherwise
::Type* Expr::TryExtractScalarType(Expr* loc, Expr *operand) {
auto scalType = operand->Type()->ScalarType();
if(!scalType)
Error(loc, "expect tile or scalar operand");
return scalType;
}
// If operand is a tile, return a tile of the same shape and
// provided element type
// If operand is a scalar, return provided element type
// directly
::Type* Expr::ScalarOrLikeTile(Expr* operand, ::Type* ty) {
assert(ty->IsScalar());
::Type *retTy = ty;
if(TileType *T = operand->Type()->ToTile())
retTy = TileType::New(T->Shape(), retTy);
return retTy;
}
BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) { BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) {
return New(tok, tok->tag_, lhs, rhs); return New(tok, tok->tag_, lhs, rhs);
@@ -166,6 +186,7 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
case Token::LOGICAL_AND: case Token::LOGICAL_AND:
case Token::LOGICAL_OR: case Token::LOGICAL_OR:
case Token::ELLIPSIS: case Token::ELLIPSIS:
case Token::MATMUL:
break; break;
default: default:
assert(0); assert(0);
@@ -180,18 +201,18 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
ArithmType* BinaryOp::Convert() { ArithmType* BinaryOp::Convert() {
// Both lhs and rhs are ensured to be have arithmetic type // Both lhs and rhs are ensured to be have arithmetic scalar type
auto lhsType = lhs_->Type()->ToArithm(); auto lhsType = lhs_->Type()->ScalarType()->ToArithm();
auto rhsType = rhs_->Type()->ToArithm(); auto rhsType = rhs_->Type()->ScalarType()->ToArithm();
assert(lhsType && rhsType); assert(lhsType && rhsType);
auto type = ArithmType::MaxType(lhsType, rhsType); auto maxType = ArithmType::MaxType(lhsType, rhsType);
if (lhsType != type) { // Pointer comparation is enough! if (lhsType != maxType) { // Pointer comparation is enough!
lhs_ = UnaryOp::New(Token::CAST, lhs_, type); lhs_ = UnaryOp::New(Token::CAST, lhs_, ScalarOrLikeTile(lhs_, maxType));
} }
if (rhsType != type) { if (rhsType != maxType) {
rhs_ = UnaryOp::New(Token::CAST, rhs_, type); rhs_ = UnaryOp::New(Token::CAST, rhs_, ScalarOrLikeTile(rhs_, maxType));
} }
return type; return maxType;
} }
void BinaryOp::Broadcast() { void BinaryOp::Broadcast() {
@@ -225,6 +246,8 @@ void BinaryOp::Broadcast() {
retShape[i] = rhsShape[i]; retShape[i] = rhsShape[i];
else if(rhsShape[i] == 1) else if(rhsShape[i] == 1)
retShape[i] = lhsShape[i]; retShape[i] = lhsShape[i];
else if(lhsShape[i] == rhsShape[i])
retShape[i] = lhsShape[i];
else else
Error(this, "cannot broadcast dimension %d " Error(this, "cannot broadcast dimension %d "
"for operands of shape %d and %d", "for operands of shape %d and %d",
@@ -232,8 +255,10 @@ void BinaryOp::Broadcast() {
} }
auto eleType = lhsType->Derived(); auto eleType = lhsType->Derived();
type_ = TileType::New(retShape, eleType); type_ = TileType::New(retShape, eleType);
lhs_ = UnaryOp::New(Token::CAST, lhs_, type_); if(retShape != lhsShape)
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_); lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
if(retShape != rhsShape)
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
} }
} }
@@ -303,6 +328,9 @@ void BinaryOp::TypeChecking() {
case Token::ELLIPSIS: case Token::ELLIPSIS:
return RangeOpTypeChecking(); return RangeOpTypeChecking();
case Token::MATMUL:
return MatmulOpTypeChecking();
default: default:
assert(0); assert(0);
} }
@@ -315,12 +343,15 @@ void BinaryOp::CommaOpTypeChecking() {
void BinaryOp::SubScriptingOpTypeChecking() { void BinaryOp::SubScriptingOpTypeChecking() {
auto lhsType = lhs_->Type()->ToPointer(); assert(false);
auto lhsType = lhs_->Type()->ToTile();
if (!lhsType) { if (!lhsType) {
Error(this, "an pointer expected"); Error(this, "operator [] can only be used on tiles");
} }
if (!rhs_->Type()->IsInteger()) { if (!rhs_->Type()->IsInteger()) {
Error(this, "the operand of [] should be intger"); Error(this, "the operand of [] should be integer");
} }
// The type of [] operator is the derived type // The type of [] operator is the derived type
@@ -334,14 +365,20 @@ void BinaryOp::MemberRefOpTypeChecking() {
void BinaryOp::MultiOpTypeChecking() { void BinaryOp::MultiOpTypeChecking() {
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) { ::Type* lhsScalType = lhs_->Type()->ScalarType();
::Type* rhsScalType = rhs_->Type()->ScalarType();
if(!lhsScalType || !rhsScalType) {
Error(this, "operands should have type or scalar type");
}
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
Error(this, "operands should have arithmetic type"); Error(this, "operands should have arithmetic type");
} }
if ('%' == op_ && if ('%' == op_ &&
!(lhs_->Type()->IsInteger() && rhs_->Type()->IsInteger())) { !(lhsScalType->IsInteger() && rhsScalType->IsInteger())) {
Error(this, "operands of '%%' should be integers"); Error(this, "operands of '%%' should be integers");
} }
type_ = Convert(); type_ = Convert();
Broadcast();
} }
@@ -351,40 +388,47 @@ void BinaryOp::MultiOpTypeChecking() {
* 2. pointer can be used: * 2. pointer can be used:
* 1. lhs of MINUS operator, and rhs must be integer or pointer; * 1. lhs of MINUS operator, and rhs must be integer or pointer;
* 2. lhs/rhs of ADD operator, and the other operand must be integer; * 2. lhs/rhs of ADD operator, and the other operand must be integer;
* 3. tiles can be used:
* 1. the scalar type of lhs/rhs satisfy the above requirements
* 2. lhs/rhs that have identical shape
* 3. lhs/rhs that can be broadcast as per numpy-like semantics
*/ */
void BinaryOp::AdditiveOpTypeChecking() { void BinaryOp::AdditiveOpTypeChecking() {
auto lhsType = lhs_->Type()->ToPointer(); ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
auto rhsType = rhs_->Type()->ToPointer(); ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (lhsType) { auto lhsPtrType = lhsScalType->ToPointer();
auto rhsPtrType = rhsScalType->ToPointer();
if (lhsPtrType) {
if (op_ == '-') { if (op_ == '-') {
if (rhsType) { if (rhsPtrType) {
if (!lhsType->Compatible(*rhsType)) if (!lhsPtrType->Compatible(*rhsPtrType))
Error(this, "invalid operands to binary -"); Error(this, "invalid operands to binary -");
type_ = ArithmType::New(T_LONG); // ptrdiff_t type_ = ArithmType::New(T_LONG); // ptrdiff_t
} else if (!rhs_->Type()->IsInteger()) { } else if (!rhsScalType->IsInteger()) {
Error(this, "invalid operands to binary -"); Error(this, "invalid operands to binary -");
} else { } else {
type_ = lhsType; type_ = lhsPtrType;
} }
} else if (!rhs_->Type()->IsInteger()) { } else if (!rhsScalType->IsInteger()) {
Error(this, "invalid operands to binary +"); Error(this, "invalid operands to binary +");
} else { } else {
type_ = lhsType; type_ = lhsPtrType;
} }
} else if (rhsType) { } else if (rhsPtrType) {
if (op_ == '+' && !lhs_->Type()->IsInteger()) { if (op_ == '+' && !lhsScalType->IsInteger()) {
Error(this, "invalid operands to binary '+'"); Error(this, "invalid operands to binary '+'");
} else if (op_ == '-' && !lhsType) { } else if (op_ == '-' && !lhsPtrType) {
Error(this, "invalid operands to binary '-'"); Error(this, "invalid operands to binary '-'");
} }
type_ = op_ == '-' ? ArithmType::New(T_LONG): rhs_->Type(); type_ = op_ == '-' ? ArithmType::New(T_LONG): rhsScalType;
std::swap(lhs_, rhs_); // To simplify code gen std::swap(lhs_, rhs_); // To simplify code gen
} else { } else {
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) { if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
Error(this, "invalid operands to binary %s", tok_->str_.c_str()); Error(this, "invalid operands to binary %s", tok_->str_.c_str());
} }
type_ = Convert(); type_ = Convert();
} }
Broadcast();
} }
void BinaryOp::RangeOpTypeChecking() { void BinaryOp::RangeOpTypeChecking() {
@@ -396,59 +440,95 @@ void BinaryOp::RangeOpTypeChecking() {
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType)); rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType));
long begin = Evaluator<long>().Eval(lhs_); long begin = Evaluator<long>().Eval(lhs_);
long end = Evaluator<long>().Eval(rhs_); long end = Evaluator<long>().Eval(rhs_);
int len = end - begin; int len = static_cast<int>(end - begin);
if(len < 0) if(len < 0)
Error(this, "range cannot be negative"); Error(this, "range cannot be negative");
type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type()); type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type());
} }
void BinaryOp::MatmulOpTypeChecking() {
auto lhsType = lhs_->Type()->ToTile();
auto rhsType = rhs_->Type()->ToTile();
if(!lhsType || !rhsType)
Error(this, "expect tile operands for matrix multiplication");
auto lhsShape = lhsType->Shape();
auto rhsShape = rhsType->Shape();
size_t lhsRank = lhsShape.size();
size_t rhsRank = rhsShape.size();
if(lhsRank != 2 || rhsRank != 2)
Error(this, "matrix multiplication operands must have rank 2");
if(lhsShape[1] != rhsShape[0])
Error(this, "matrix multiplication operands have incompatible inner dimension"
" %d and %d", lhsShape[1], rhsShape[0]);
TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]};
QualType retType = lhsType->Derived();
if(retType != rhsType->Derived())
Error(this, "matrix multiplication operands have incompatible data types");
type_ = TileType::New(retShape, lhsType->Derived());
}
void BinaryOp::ShiftOpTypeChecking() { void BinaryOp::ShiftOpTypeChecking() {
auto lhsType = lhs_->Type()->ToArithm(); ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
auto rhsType = rhs_->Type()->ToArithm(); ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
auto lhsType = lhsScalType->ToArithm();
auto rhsType = rhsScalType->ToArithm();
if (!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger()) if (!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
Error(this, "expect integers for shift operator"); Error(this, "expect integers for shift operator");
lhs_ = Expr::MayCast(lhs_, ArithmType::IntegerPromote(lhsType)); lhs_ = Expr::MayCast(lhs_, ScalarOrLikeTile(lhs_, ArithmType::IntegerPromote(lhsType)));
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType)); rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, ArithmType::IntegerPromote(rhsType)));
type_ = lhs_->Type(); type_ = lhs_->Type();
Broadcast();
} }
void BinaryOp::RelationalOpTypeChecking() { void BinaryOp::RelationalOpTypeChecking() {
if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) { ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
EnsureCompatible(lhs_->Type(), rhs_->Type()); ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
EnsureCompatible(lhsScalType, rhsScalType);
} else { } else {
if (!lhs_->Type()->IsReal() || !rhs_->Type()->IsReal()) { if (!lhsScalType->IsReal() || !rhsScalType->IsReal()) {
Error(this, "expect real type of operands"); Error(this, "expect real type of operands");
} }
Convert(); Convert();
} }
type_ = ArithmType::New(T_INT); type_ = ArithmType::New(T_INT);
Broadcast();
} }
void BinaryOp::EqualityOpTypeChecking() { void BinaryOp::EqualityOpTypeChecking() {
if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) { ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type()); ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
} else { } else {
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm())
Error(this, "invalid operands to binary %s", tok_->str_.c_str()); Error(this, "invalid operands to binary %s", tok_->str_.c_str());
Convert(); Convert();
} }
type_ = ArithmType::New(T_INT); type_ = ArithmType::New(T_INT);
Broadcast();
} }
void BinaryOp::BitwiseOpTypeChecking() { void BinaryOp::BitwiseOpTypeChecking() {
if (!lhs_->Type()->IsInteger() || !rhs_->Type()->IsInteger()) ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (!lhsScalType->IsInteger() || !rhsScalType->IsInteger())
Error(this, "operands of '&' should be integer"); Error(this, "operands of '&' should be integer");
type_ = Convert(); type_ = Convert();
Broadcast();
} }
void BinaryOp::LogicalOpTypeChecking() { void BinaryOp::LogicalOpTypeChecking() {
if (!lhs_->Type()->IsScalar() || !rhs_->Type()->IsScalar()) ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
Error(this, "the operand should be arithmetic type or pointer"); Error(this, "the operand should be arithmetic type or pointer");
type_ = ArithmType::New(T_INT); type_ = ArithmType::New(T_INT);
Broadcast();
} }
@@ -459,12 +539,14 @@ void BinaryOp::AssignOpTypeChecking() {
Error(lhs_, "lvalue expression expected"); Error(lhs_, "lvalue expression expected");
} }
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) { ::Type* lhsScalType = TryExtractScalarType(this, lhs_);
EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type()); ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
} }
// The other constraints are lefted to cast operator // The other constraints are lefted to cast operator
rhs_ = Expr::MayCast(rhs_, lhs_->Type()); rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
type_ = lhs_->Type(); type_ = lhs_->Type();
} }
@@ -488,13 +570,16 @@ bool UnaryOp::IsLVal() {
} }
ArithmType* UnaryOp::Convert() { ::Type* UnaryOp::Convert() {
auto arithmType = operand_->Type()->ToArithm(); auto scalType = operand_->Type()->ScalarType();
assert(scalType);
auto arithmType = scalType->ToArithm();
assert(arithmType); assert(arithmType);
if (arithmType->IsInteger()) if (arithmType->IsInteger())
arithmType = ArithmType::IntegerPromote(arithmType); arithmType = ArithmType::IntegerPromote(arithmType);
operand_ = Expr::MayCast(operand_, arithmType); ::Type* retType = ScalarOrLikeTile(operand_, arithmType);
return arithmType; operand_ = Expr::MayCast(operand_, retType);
return retType;
} }
@@ -521,20 +606,22 @@ void UnaryOp::TypeChecking() {
case Token::CAST: case Token::CAST:
return CastOpTypeChecking(); return CastOpTypeChecking();
case '^':
return TransOpTypeChecking();
default: default:
assert(false); assert(false);
} }
} }
void UnaryOp::IncDecOpTypeChecking() { void UnaryOp::IncDecOpTypeChecking() {
if (operand_->IsConstQualified()) { if (operand_->IsConstQualified()) {
Error(this, "increment/decrement of const qualified expression"); Error(this, "increment/decrement of const qualified expression");
} else if (!operand_->IsLVal()) { } else if (!operand_->IsLVal()) {
Error(this, "lvalue expression expected"); Error(this, "lvalue expression expected");
} }
auto scalType = TryExtractScalarType(this, operand_);
if (!operand_->Type()->IsReal() && !operand_->Type()->ToPointer()) { if (!scalType->IsReal() && !scalType->ToPointer()) {
Error(this, "expect operand of real type or pointer"); Error(this, "expect operand of real type or pointer");
} }
type_ = operand_->Type(); type_ = operand_->Type();
@@ -545,43 +632,78 @@ void UnaryOp::AddrOpTypeChecking() {
auto funcType = operand_->Type()->ToFunc(); auto funcType = operand_->Type()->ToFunc();
if (funcType == nullptr && !operand_->IsLVal()) if (funcType == nullptr && !operand_->IsLVal())
Error(this, "expression must be an lvalue or function designator"); Error(this, "expression must be an lvalue or function designator");
if(operand_->Type()->IsTile())
Error(this, "cannot take the address of a tile");
type_ = PointerType::New(operand_->Type()); type_ = PointerType::New(operand_->Type());
} }
void UnaryOp::DerefOpTypeChecking() { void UnaryOp::DerefOpTypeChecking() {
auto pointerType = operand_->Type()->ToPointer(); auto scalType = TryExtractScalarType(this, operand_);
auto pointerType = scalType->ToPointer();
if (!pointerType) if (!pointerType)
Error(this, "pointer expected for deref operator '*'"); Error(this, "pointer expected for deref operator '*'");
type_ = pointerType->Derived(); type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
} }
void UnaryOp::TransOpTypeChecking() {
auto tileType = operand_->Type()->ToTile();
if(!tileType)
Error(this, "tile expected for transposition operator '^'");
auto shape = tileType->Shape();
std::rotate(shape.begin(), shape.begin() + 1, shape.end());
type_ = TileType::New(shape, tileType->Derived());
}
void UnaryOp::UnaryArithmOpTypeChecking() { void UnaryOp::UnaryArithmOpTypeChecking() {
auto scalType = TryExtractScalarType(this, operand_);
if (Token::PLUS == op_ || Token::MINUS == op_) { if (Token::PLUS == op_ || Token::MINUS == op_) {
if (!operand_->Type()->ToArithm()) if (!scalType->ToArithm())
Error(this, "Arithmetic type expected"); Error(this, "Arithmetic type expected");
Convert(); Convert();
type_ = operand_->Type(); type_ = operand_->Type();
} else if ('~' == op_) { } else if ('~' == op_) {
if (!operand_->Type()->IsInteger()) if (!scalType->IsInteger())
Error(this, "integer expected for operator '~'"); Error(this, "integer expected for operator '~'");
Convert(); Convert();
type_ = operand_->Type(); type_ = operand_->Type();
} else if (!operand_->Type()->IsScalar()) { } else if (!scalType->IsScalar()) {
Error(this, "arithmetic type or pointer expected for operator '!'"); Error(this, "arithmetic type or pointer expected for operator '!'");
} else { } else {
type_ = ArithmType::New(T_INT); type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_INT));
} }
} }
void UnaryOp::CastOpTypeChecking() { void UnaryOp::CastOpTypeChecking() {
auto operandType = Type::MayCast(operand_->Type()); auto operandType = Type::MayCast(operand_->Type());
// The type_ has been initiated to dest type // The type_ has been initiated to dest type
if (type_->ToVoid()) { if (type_->ToVoid()) {
// The expression becomes a void expression // The expression becomes a void expression
} else if(type_->IsTile() || operandType->IsTile()) {
/* Broadcasting rules:
* 1. Tiles with 1 element can be converted to scalar
* 2. Scalar can be converted to tiles of any shapes
* 3. Tiles can be converted to another tile only if the
* mismatching dimensions are unitary
*/
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
Error(this, "tile with more than one element cannot be casted to scalar");
if(type_->IsTile() && operandType->IsTile()){
auto shape = type_->ToTile()->Shape();
auto operandShape = operandType->ToTile()->Shape();
if(operandShape.size() > shape.size())
Error(this, "cast cannot reduce operand rank");
while(operandShape.size() < shape.size())
operandShape.insert(operandShape.begin(), 1);
for(size_t i = 0; i < shape.size(); i++) {
if(shape[i] != 1 && operandShape[i] != 1 && shape[i] != operandShape[i])
Error(this, "cannot broadcast dimension %d "
"for operands of shape %d and %d",
i, shape[i], operandShape[i]);
}
}
} else if (!type_->IsScalar() || !operandType->IsScalar()) { } else if (!type_->IsScalar() || !operandType->IsScalar()) {
if (!type_->Compatible(*operandType)) if (!type_->Compatible(*operandType))
Error(this, "the cast type should be arithemetic type or pointer"); Error(this, "the cast type should be arithemetic type or pointer");

View File

@@ -442,11 +442,27 @@ Expr* Parser::ParsePostfixExprTail(Expr* lhs) {
Expr* Parser::ParseSubScripting(Expr* lhs) { Expr* Parser::ParseSubScripting(Expr* lhs) {
auto rhs = ParseExpr(); auto lhsTile = lhs->Type()->ToTile();
auto tok = ts_.Peek(); if(lhsTile == nullptr)
Error(lhs, "tile expected");
TileType::ShapeInt lhsShape = lhsTile->Shape();
QualType lhsQual = lhsTile->Derived();
// create ret shape
TileType::ShapeInt shape;
size_t i = 0;
do {
auto tok = ts_.Next();
if(tok->tag_ == ':')
shape.push_back(lhsShape[i++]);
else if(tok->tag_ == Token::NEWAXIS)
shape.push_back(1);
else
Error(tok, "only ':' and newaxis are supported in subscripts");
}while(ts_.Try(','));
ts_.Expect(']'); ts_.Expect(']');
auto operand = BinaryOp::New(tok, '+', lhs, rhs); // create ret tile
return UnaryOp::New(Token::DEREF, operand); TileType *retType = TileType::New(shape, lhsQual);
return UnaryOp::New(Token::CAST, lhs, retType);
} }
@@ -501,6 +517,7 @@ Expr* Parser::ParseUnaryExpr() {
case '-': return ParseUnaryOp(tok, Token::MINUS); case '-': return ParseUnaryOp(tok, Token::MINUS);
case '~': return ParseUnaryOp(tok, '~'); case '~': return ParseUnaryOp(tok, '~');
case '!': return ParseUnaryOp(tok, '!'); case '!': return ParseUnaryOp(tok, '!');
case '^': return ParseUnaryOp(tok, Token::XOR);
default: default:
ts_.PutBack(); ts_.PutBack();
return ParsePostfixExpr(); return ParsePostfixExpr();
@@ -584,7 +601,7 @@ Expr* Parser::ParseCastExpr() {
Expr* Parser::ParseRangeExpr() { Expr* Parser::ParseRangeExpr() {
auto lhs = ParseCastExpr(); auto lhs = ParseCastExpr();
auto tok = ts_.Next(); auto tok = ts_.Next();
while (tok->tag_ == Token::ELLIPSIS){ while (tok->tag_ == Token::ELLIPSIS) {
auto rhs = ParseCastExpr(); auto rhs = ParseCastExpr();
lhs = BinaryOp::New(tok, lhs, rhs); lhs = BinaryOp::New(tok, lhs, rhs);
tok = ts_.Next(); tok = ts_.Next();
@@ -593,16 +610,26 @@ Expr* Parser::ParseRangeExpr() {
return lhs; return lhs;
} }
Expr* Parser::ParseMultiplicativeExpr() { Expr* Parser::ParseMatmulExpr() {
auto lhs = ParseRangeExpr(); auto lhs = ParseRangeExpr();
auto tok = ts_.Next(); auto tok = ts_.Next();
while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') { while (tok->tag_ == Token::MATMUL) {
auto rhs = ParseRangeExpr(); auto rhs = ParseRangeExpr();
lhs = BinaryOp::New(tok, lhs, rhs); lhs = BinaryOp::New(tok, lhs, rhs);
tok = ts_.Next(); tok = ts_.Next();
} }
ts_.PutBack();
return lhs;
}
Expr* Parser::ParseMultiplicativeExpr() {
auto lhs = ParseMatmulExpr();
auto tok = ts_.Next();
while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') {
auto rhs = ParseMatmulExpr();
lhs = BinaryOp::New(tok, lhs, rhs);
tok = ts_.Next();
}
ts_.PutBack(); ts_.PutBack();
return lhs; return lhs;
} }

View File

@@ -27,6 +27,7 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
{ "inline", Token::INLINE }, { "inline", Token::INLINE },
{ "int", Token::INT }, { "int", Token::INT },
{ "long", Token::LONG }, { "long", Token::LONG },
{ "newaxis", Token::NEWAXIS },
{ "signed", Token::SIGNED }, { "signed", Token::SIGNED },
{ "unsigned", Token::UNSIGNED }, { "unsigned", Token::UNSIGNED },
{ "register", Token::REGISTER }, { "register", Token::REGISTER },
@@ -126,6 +127,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
{ Token::INLINE, "inline" }, { Token::INLINE, "inline" },
{ Token::INT, "int" }, { Token::INT, "int" },
{ Token::LONG, "long" }, { Token::LONG, "long" },
{ Token::NEWAXIS, "newaxis" },
{ Token::SIGNED, "signed" }, { Token::SIGNED, "signed" },
{ Token::UNSIGNED, "unsigned" }, { Token::UNSIGNED, "unsigned" },
{ Token::REGISTER, "register" }, { Token::REGISTER, "register" },

View File

@@ -32,6 +32,18 @@ QualType Type::MayCast(QualType type, bool inProtoScope) {
return type; return type;
} }
const Type* Type::ScalarType() const {
if(IsScalar())
return this;
if(const TileType* p = ToTile())
return p->Derived().GetPtr();
return nullptr;
}
Type* Type::ScalarType() {
auto cthis = const_cast<const Type*>(this);
return const_cast<Type*>(cthis->ScalarType());
}
VoidType* VoidType::New() { VoidType* VoidType::New() {
static auto ret = new (voidTypePool.Alloc()) VoidType(&voidTypePool); static auto ret = new (voidTypePool.Alloc()) VoidType(&voidTypePool);
@@ -143,12 +155,16 @@ int ArithmType::Width() const {
return intWidth_ << 1; return intWidth_ << 1;
case T_LLONG: case T_UNSIGNED | T_LLONG: case T_LLONG: case T_UNSIGNED | T_LLONG:
return intWidth_ << 1; return intWidth_ << 1;
case T_HALF:
return intWidth_ >> 1;
case T_FLOAT: case T_FLOAT:
return intWidth_; return intWidth_;
case T_DOUBLE: case T_DOUBLE:
return intWidth_ << 1; return intWidth_ << 1;
case T_LONG | T_DOUBLE: case T_LONG | T_DOUBLE:
return intWidth_ << 1; return intWidth_ << 1;
case T_HALF | T_COMPLEX:
return intWidth_;
case T_FLOAT | T_COMPLEX: case T_FLOAT | T_COMPLEX:
return intWidth_ << 1; return intWidth_ << 1;
case T_DOUBLE | T_COMPLEX: case T_DOUBLE | T_COMPLEX:
@@ -171,9 +187,10 @@ int ArithmType::Rank() const {
case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3; case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3;
case T_LONG: case T_UNSIGNED | T_LONG: return 4; case T_LONG: case T_UNSIGNED | T_LONG: return 4;
case T_LLONG: case T_UNSIGNED | T_LLONG: return 5; case T_LLONG: case T_UNSIGNED | T_LLONG: return 5;
case T_FLOAT: return 6; case T_HALF: return 6;
case T_DOUBLE: return 7; case T_FLOAT: return 7;
case T_LONG | T_DOUBLE: return 8; case T_DOUBLE: return 8;
case T_LONG | T_DOUBLE: return 9;
default: default:
assert(tag_ & T_COMPLEX); assert(tag_ & T_COMPLEX);
Error("complex not supported yet"); Error("complex not supported yet");