basic parsing doesn't throw error

This commit is contained in:
Philippe Tillet
2019-08-20 16:22:43 -07:00
parent bc11e31419
commit 61f25f90eb
9 changed files with 287 additions and 94 deletions

View File

@@ -55,8 +55,8 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT ? "trans(a)" : "a";
std::string useb = BT ? "trans(b)" : "b";
std::string usea = AT ? "^a" : "a";
std::string useb = BT ? "^b" : "b";
if(AT){
std::swap(AS0, AS1);
std::swap(XAS0, XAS1);
@@ -82,6 +82,11 @@ R"(
#define TN 128
#define TK 32
#define bool _Bool
#define true 1
#define false 0
#define __bool_true_false_are_defined 1
extern int get_program_id(int);
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
@@ -94,28 +99,28 @@ void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[{TM, TN}] = ridx * TM + 0 ... TM;
int ryb[TN] = ridy * TN + 0 ... TN;
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
float xc[)" + XCS + R"(] = 0;
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[)" + AS + R"(] = *pa;
)" + b_ty + R"( b[)" + BS + R"(] = *pb;
int ryb[{TN}] = ridy * TN + 0 ... TN;
int rka[{TK}] = 0 ... TK;
int rkb[{TK}] = 0 ... TK;
float xc[{)" + XCS + R"(}] = 0;
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
for(int k = K; k > 0; k = k - TK){
xc = dot()" + usea + ", " + useb + R"(, xc);
xc = )" + usea + " @ " + useb + R"( + xc;
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty + R"( c[TM, TN] = xc;
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
int rxc[{TM}] = ridx * TM + (0 ... TM);
int ryc[{TN}] = ridy * TN + (0 ... TN);
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty + R"( c[{TM, TN}] = xc;
bool checkc0[{TM}] = rxc < M;
bool checkc1[{TN}] = ryc < N;
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
*pc = c;
}
)";

View File

@@ -278,6 +278,9 @@ public:
static Expr* MayCast(Expr* expr);
static Expr* MayCast(Expr* expr, QualType desType);
static ::Type* TryExtractScalarType(Expr* loc, Expr *operand);
static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty);
virtual bool IsNullPointerConstant() const { return false; }
bool IsConstQualified() const { return type_.IsConstQualified(); }
bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
@@ -332,6 +335,7 @@ public:
void AdditiveOpTypeChecking();
void ShiftOpTypeChecking();
void RangeOpTypeChecking();
void MatmulOpTypeChecking();
void RelationalOpTypeChecking();
void EqualityOpTypeChecking();
void BitwiseOpTypeChecking();
@@ -378,11 +382,12 @@ public:
virtual ~UnaryOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal();
ArithmType* Convert();
::Type *Convert();
void TypeChecking();
void IncDecOpTypeChecking();
void AddrOpTypeChecking();
void DerefOpTypeChecking();
void TransOpTypeChecking();
void UnaryArithmOpTypeChecking();
void CastOpTypeChecking();

View File

@@ -75,6 +75,7 @@ public:
QualType ParseTypeName();
Expr* ParseCastExpr();
Expr* ParseRangeExpr();
Expr* ParseMatmulExpr();
Expr* ParseMultiplicativeExpr();
Expr* ParseAdditiveExpr();
Expr* ParseShiftExpr();

View File

@@ -64,7 +64,7 @@ public:
NOT = '!',
COND = '?',
SHARP = '#',
AT = '@',
MATMUL = '@',
NEW_LINE = '\n',
DSHARP = 128, // '##'
@@ -126,6 +126,10 @@ public:
NORETURN, // _Noreturn
// FUNCTION SPECIFIER END
// TILE ARITHMETICS BEGIN
NEWAXIS,
// TILE ARITHMETICS END
ALIGNAS, // _Alignas
// For syntactic convenience
STATIC_ASSERT, // _Static_assert

View File

@@ -153,6 +153,10 @@ public:
virtual bool IsBool() const { return false; }
virtual bool IsVoidPointer() const { return false; }
virtual bool IsUnsigned() const { return false; }
virtual bool IsTile() const { return ToTile() != nullptr; }
const Type* ScalarType() const;
Type* ScalarType();
virtual VoidType* ToVoid() { return nullptr; }
virtual const VoidType* ToVoid() const { return nullptr; }
@@ -327,16 +331,22 @@ public:
static TileType* New(const ShapeInt& shape, QualType eleType);
virtual ~TileType() { }
virtual TileType* toTile() { return this; }
virtual const TileType* toTile() const { return this; }
virtual TileType* ToTile() { return this; }
virtual const TileType* ToTile() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return 0; }
virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; }
virtual int Align() const { return derived_->Align(); }
virtual std::string Str() const {
return derived_->Str() + "[{}]:" + std::to_string(Width());
}
ShapeInt Shape() { return shape_; }
int NumEle() const {
int ret = 1;
for(int s: shape_)
ret *= s;
return ret;
}
protected:
TileType(MemPool* pool, const ShapeExpr& expr, QualType derived)

View File

@@ -144,6 +144,26 @@ Expr* Expr::MayCast(Expr* expr, QualType desType) {
return expr;
}
// Extract the operand's scalar type if possible
// and emit an error otherwise
::Type* Expr::TryExtractScalarType(Expr* loc, Expr *operand) {
auto scalType = operand->Type()->ScalarType();
if(!scalType)
Error(loc, "expect tile or scalar operand");
return scalType;
}
// If operand is a tile, return a tile of the same shape and
// provided element type
// If operand is a scalar, return provided element type
// directly
::Type* Expr::ScalarOrLikeTile(Expr* operand, ::Type* ty) {
assert(ty->IsScalar());
::Type *retTy = ty;
if(TileType *T = operand->Type()->ToTile())
retTy = TileType::New(T->Shape(), retTy);
return retTy;
}
BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) {
return New(tok, tok->tag_, lhs, rhs);
@@ -166,6 +186,7 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
case Token::LOGICAL_AND:
case Token::LOGICAL_OR:
case Token::ELLIPSIS:
case Token::MATMUL:
break;
default:
assert(0);
@@ -180,18 +201,18 @@ BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
ArithmType* BinaryOp::Convert() {
// Both lhs and rhs are ensured to be have arithmetic type
auto lhsType = lhs_->Type()->ToArithm();
auto rhsType = rhs_->Type()->ToArithm();
// Both lhs and rhs are ensured to be have arithmetic scalar type
auto lhsType = lhs_->Type()->ScalarType()->ToArithm();
auto rhsType = rhs_->Type()->ScalarType()->ToArithm();
assert(lhsType && rhsType);
auto type = ArithmType::MaxType(lhsType, rhsType);
if (lhsType != type) { // Pointer comparation is enough!
lhs_ = UnaryOp::New(Token::CAST, lhs_, type);
auto maxType = ArithmType::MaxType(lhsType, rhsType);
if (lhsType != maxType) { // Pointer comparation is enough!
lhs_ = UnaryOp::New(Token::CAST, lhs_, ScalarOrLikeTile(lhs_, maxType));
}
if (rhsType != type) {
rhs_ = UnaryOp::New(Token::CAST, rhs_, type);
if (rhsType != maxType) {
rhs_ = UnaryOp::New(Token::CAST, rhs_, ScalarOrLikeTile(rhs_, maxType));
}
return type;
return maxType;
}
void BinaryOp::Broadcast() {
@@ -225,6 +246,8 @@ void BinaryOp::Broadcast() {
retShape[i] = rhsShape[i];
else if(rhsShape[i] == 1)
retShape[i] = lhsShape[i];
else if(lhsShape[i] == rhsShape[i])
retShape[i] = lhsShape[i];
else
Error(this, "cannot broadcast dimension %d "
"for operands of shape %d and %d",
@@ -232,8 +255,10 @@ void BinaryOp::Broadcast() {
}
auto eleType = lhsType->Derived();
type_ = TileType::New(retShape, eleType);
lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
if(retShape != lhsShape)
lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
if(retShape != rhsShape)
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
}
}
@@ -303,6 +328,9 @@ void BinaryOp::TypeChecking() {
case Token::ELLIPSIS:
return RangeOpTypeChecking();
case Token::MATMUL:
return MatmulOpTypeChecking();
default:
assert(0);
}
@@ -315,12 +343,15 @@ void BinaryOp::CommaOpTypeChecking() {
void BinaryOp::SubScriptingOpTypeChecking() {
auto lhsType = lhs_->Type()->ToPointer();
assert(false);
auto lhsType = lhs_->Type()->ToTile();
if (!lhsType) {
Error(this, "an pointer expected");
Error(this, "operator [] can only be used on tiles");
}
if (!rhs_->Type()->IsInteger()) {
Error(this, "the operand of [] should be intger");
Error(this, "the operand of [] should be integer");
}
// The type of [] operator is the derived type
@@ -334,14 +365,20 @@ void BinaryOp::MemberRefOpTypeChecking() {
void BinaryOp::MultiOpTypeChecking() {
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
::Type* lhsScalType = lhs_->Type()->ScalarType();
::Type* rhsScalType = rhs_->Type()->ScalarType();
if(!lhsScalType || !rhsScalType) {
Error(this, "operands should have type or scalar type");
}
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
Error(this, "operands should have arithmetic type");
}
if ('%' == op_ &&
!(lhs_->Type()->IsInteger() && rhs_->Type()->IsInteger())) {
!(lhsScalType->IsInteger() && rhsScalType->IsInteger())) {
Error(this, "operands of '%%' should be integers");
}
type_ = Convert();
Broadcast();
}
@@ -351,40 +388,47 @@ void BinaryOp::MultiOpTypeChecking() {
* 2. pointer can be used:
* 1. lhs of MINUS operator, and rhs must be integer or pointer;
* 2. lhs/rhs of ADD operator, and the other operand must be integer;
* 3. tiles can be used:
* 1. the scalar type of lhs/rhs satisfy the above requirements
* 2. lhs/rhs that have identical shape
* 3. lhs/rhs that can be broadcast as per numpy-like semantics
*/
void BinaryOp::AdditiveOpTypeChecking() {
auto lhsType = lhs_->Type()->ToPointer();
auto rhsType = rhs_->Type()->ToPointer();
if (lhsType) {
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
auto lhsPtrType = lhsScalType->ToPointer();
auto rhsPtrType = rhsScalType->ToPointer();
if (lhsPtrType) {
if (op_ == '-') {
if (rhsType) {
if (!lhsType->Compatible(*rhsType))
if (rhsPtrType) {
if (!lhsPtrType->Compatible(*rhsPtrType))
Error(this, "invalid operands to binary -");
type_ = ArithmType::New(T_LONG); // ptrdiff_t
} else if (!rhs_->Type()->IsInteger()) {
} else if (!rhsScalType->IsInteger()) {
Error(this, "invalid operands to binary -");
} else {
type_ = lhsType;
type_ = lhsPtrType;
}
} else if (!rhs_->Type()->IsInteger()) {
} else if (!rhsScalType->IsInteger()) {
Error(this, "invalid operands to binary +");
} else {
type_ = lhsType;
type_ = lhsPtrType;
}
} else if (rhsType) {
if (op_ == '+' && !lhs_->Type()->IsInteger()) {
} else if (rhsPtrType) {
if (op_ == '+' && !lhsScalType->IsInteger()) {
Error(this, "invalid operands to binary '+'");
} else if (op_ == '-' && !lhsType) {
} else if (op_ == '-' && !lhsPtrType) {
Error(this, "invalid operands to binary '-'");
}
type_ = op_ == '-' ? ArithmType::New(T_LONG): rhs_->Type();
type_ = op_ == '-' ? ArithmType::New(T_LONG): rhsScalType;
std::swap(lhs_, rhs_); // To simplify code gen
} else {
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
}
type_ = Convert();
}
Broadcast();
}
void BinaryOp::RangeOpTypeChecking() {
@@ -396,59 +440,95 @@ void BinaryOp::RangeOpTypeChecking() {
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType));
long begin = Evaluator<long>().Eval(lhs_);
long end = Evaluator<long>().Eval(rhs_);
int len = end - begin;
int len = static_cast<int>(end - begin);
if(len < 0)
Error(this, "range cannot be negative");
type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type());
}
void BinaryOp::MatmulOpTypeChecking() {
auto lhsType = lhs_->Type()->ToTile();
auto rhsType = rhs_->Type()->ToTile();
if(!lhsType || !rhsType)
Error(this, "expect tile operands for matrix multiplication");
auto lhsShape = lhsType->Shape();
auto rhsShape = rhsType->Shape();
size_t lhsRank = lhsShape.size();
size_t rhsRank = rhsShape.size();
if(lhsRank != 2 || rhsRank != 2)
Error(this, "matrix multiplication operands must have rank 2");
if(lhsShape[1] != rhsShape[0])
Error(this, "matrix multiplication operands have incompatible inner dimension"
" %d and %d", lhsShape[1], rhsShape[0]);
TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]};
QualType retType = lhsType->Derived();
if(retType != rhsType->Derived())
Error(this, "matrix multiplication operands have incompatible data types");
type_ = TileType::New(retShape, lhsType->Derived());
}
void BinaryOp::ShiftOpTypeChecking() {
auto lhsType = lhs_->Type()->ToArithm();
auto rhsType = rhs_->Type()->ToArithm();
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
auto lhsType = lhsScalType->ToArithm();
auto rhsType = rhsScalType->ToArithm();
if (!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
Error(this, "expect integers for shift operator");
lhs_ = Expr::MayCast(lhs_, ArithmType::IntegerPromote(lhsType));
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType));
lhs_ = Expr::MayCast(lhs_, ScalarOrLikeTile(lhs_, ArithmType::IntegerPromote(lhsType)));
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, ArithmType::IntegerPromote(rhsType)));
type_ = lhs_->Type();
Broadcast();
}
void BinaryOp::RelationalOpTypeChecking() {
if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) {
EnsureCompatible(lhs_->Type(), rhs_->Type());
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
EnsureCompatible(lhsScalType, rhsScalType);
} else {
if (!lhs_->Type()->IsReal() || !rhs_->Type()->IsReal()) {
if (!lhsScalType->IsReal() || !rhsScalType->IsReal()) {
Error(this, "expect real type of operands");
}
Convert();
}
type_ = ArithmType::New(T_INT);
Broadcast();
}
void BinaryOp::EqualityOpTypeChecking() {
if (lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer()) {
EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type());
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
} else {
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm())
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm())
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
Convert();
}
type_ = ArithmType::New(T_INT);
Broadcast();
}
void BinaryOp::BitwiseOpTypeChecking() {
if (!lhs_->Type()->IsInteger() || !rhs_->Type()->IsInteger())
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (!lhsScalType->IsInteger() || !rhsScalType->IsInteger())
Error(this, "operands of '&' should be integer");
type_ = Convert();
Broadcast();
}
void BinaryOp::LogicalOpTypeChecking() {
if (!lhs_->Type()->IsScalar() || !rhs_->Type()->IsScalar())
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
Error(this, "the operand should be arithmetic type or pointer");
type_ = ArithmType::New(T_INT);
Broadcast();
}
@@ -459,12 +539,14 @@ void BinaryOp::AssignOpTypeChecking() {
Error(lhs_, "lvalue expression expected");
}
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type());
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
}
// The other constraints are lefted to cast operator
rhs_ = Expr::MayCast(rhs_, lhs_->Type());
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
type_ = lhs_->Type();
}
@@ -488,13 +570,16 @@ bool UnaryOp::IsLVal() {
}
ArithmType* UnaryOp::Convert() {
auto arithmType = operand_->Type()->ToArithm();
::Type* UnaryOp::Convert() {
auto scalType = operand_->Type()->ScalarType();
assert(scalType);
auto arithmType = scalType->ToArithm();
assert(arithmType);
if (arithmType->IsInteger())
arithmType = ArithmType::IntegerPromote(arithmType);
operand_ = Expr::MayCast(operand_, arithmType);
return arithmType;
::Type* retType = ScalarOrLikeTile(operand_, arithmType);
operand_ = Expr::MayCast(operand_, retType);
return retType;
}
@@ -521,20 +606,22 @@ void UnaryOp::TypeChecking() {
case Token::CAST:
return CastOpTypeChecking();
case '^':
return TransOpTypeChecking();
default:
assert(false);
}
}
void UnaryOp::IncDecOpTypeChecking() {
if (operand_->IsConstQualified()) {
Error(this, "increment/decrement of const qualified expression");
} else if (!operand_->IsLVal()) {
Error(this, "lvalue expression expected");
}
if (!operand_->Type()->IsReal() && !operand_->Type()->ToPointer()) {
auto scalType = TryExtractScalarType(this, operand_);
if (!scalType->IsReal() && !scalType->ToPointer()) {
Error(this, "expect operand of real type or pointer");
}
type_ = operand_->Type();
@@ -545,43 +632,78 @@ void UnaryOp::AddrOpTypeChecking() {
auto funcType = operand_->Type()->ToFunc();
if (funcType == nullptr && !operand_->IsLVal())
Error(this, "expression must be an lvalue or function designator");
if(operand_->Type()->IsTile())
Error(this, "cannot take the address of a tile");
type_ = PointerType::New(operand_->Type());
}
void UnaryOp::DerefOpTypeChecking() {
auto pointerType = operand_->Type()->ToPointer();
auto scalType = TryExtractScalarType(this, operand_);
auto pointerType = scalType->ToPointer();
if (!pointerType)
Error(this, "pointer expected for deref operator '*'");
type_ = pointerType->Derived();
type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
}
void UnaryOp::TransOpTypeChecking() {
auto tileType = operand_->Type()->ToTile();
if(!tileType)
Error(this, "tile expected for transposition operator '^'");
auto shape = tileType->Shape();
std::rotate(shape.begin(), shape.begin() + 1, shape.end());
type_ = TileType::New(shape, tileType->Derived());
}
void UnaryOp::UnaryArithmOpTypeChecking() {
auto scalType = TryExtractScalarType(this, operand_);
if (Token::PLUS == op_ || Token::MINUS == op_) {
if (!operand_->Type()->ToArithm())
if (!scalType->ToArithm())
Error(this, "Arithmetic type expected");
Convert();
type_ = operand_->Type();
} else if ('~' == op_) {
if (!operand_->Type()->IsInteger())
if (!scalType->IsInteger())
Error(this, "integer expected for operator '~'");
Convert();
type_ = operand_->Type();
} else if (!operand_->Type()->IsScalar()) {
} else if (!scalType->IsScalar()) {
Error(this, "arithmetic type or pointer expected for operator '!'");
} else {
type_ = ArithmType::New(T_INT);
type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_INT));
}
}
void UnaryOp::CastOpTypeChecking() {
auto operandType = Type::MayCast(operand_->Type());
// The type_ has been initiated to dest type
if (type_->ToVoid()) {
// The expression becomes a void expression
} else if(type_->IsTile() || operandType->IsTile()) {
/* Broadcasting rules:
* 1. Tiles with 1 element can be converted to scalar
* 2. Scalar can be converted to tiles of any shapes
* 3. Tiles can be converted to another tile only if the
* mismatching dimensions are unitary
*/
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
Error(this, "tile with more than one element cannot be casted to scalar");
if(type_->IsTile() && operandType->IsTile()){
auto shape = type_->ToTile()->Shape();
auto operandShape = operandType->ToTile()->Shape();
if(operandShape.size() > shape.size())
Error(this, "cast cannot reduce operand rank");
while(operandShape.size() < shape.size())
operandShape.insert(operandShape.begin(), 1);
for(size_t i = 0; i < shape.size(); i++) {
if(shape[i] != 1 && operandShape[i] != 1 && shape[i] != operandShape[i])
Error(this, "cannot broadcast dimension %d "
"for operands of shape %d and %d",
i, shape[i], operandShape[i]);
}
}
} else if (!type_->IsScalar() || !operandType->IsScalar()) {
if (!type_->Compatible(*operandType))
Error(this, "the cast type should be arithemetic type or pointer");

View File

@@ -442,11 +442,27 @@ Expr* Parser::ParsePostfixExprTail(Expr* lhs) {
Expr* Parser::ParseSubScripting(Expr* lhs) {
auto rhs = ParseExpr();
auto tok = ts_.Peek();
auto lhsTile = lhs->Type()->ToTile();
if(lhsTile == nullptr)
Error(lhs, "tile expected");
TileType::ShapeInt lhsShape = lhsTile->Shape();
QualType lhsQual = lhsTile->Derived();
// create ret shape
TileType::ShapeInt shape;
size_t i = 0;
do {
auto tok = ts_.Next();
if(tok->tag_ == ':')
shape.push_back(lhsShape[i++]);
else if(tok->tag_ == Token::NEWAXIS)
shape.push_back(1);
else
Error(tok, "only ':' and newaxis are supported in subscripts");
}while(ts_.Try(','));
ts_.Expect(']');
auto operand = BinaryOp::New(tok, '+', lhs, rhs);
return UnaryOp::New(Token::DEREF, operand);
// create ret tile
TileType *retType = TileType::New(shape, lhsQual);
return UnaryOp::New(Token::CAST, lhs, retType);
}
@@ -501,6 +517,7 @@ Expr* Parser::ParseUnaryExpr() {
case '-': return ParseUnaryOp(tok, Token::MINUS);
case '~': return ParseUnaryOp(tok, '~');
case '!': return ParseUnaryOp(tok, '!');
case '^': return ParseUnaryOp(tok, Token::XOR);
default:
ts_.PutBack();
return ParsePostfixExpr();
@@ -584,7 +601,7 @@ Expr* Parser::ParseCastExpr() {
Expr* Parser::ParseRangeExpr() {
auto lhs = ParseCastExpr();
auto tok = ts_.Next();
while (tok->tag_ == Token::ELLIPSIS){
while (tok->tag_ == Token::ELLIPSIS) {
auto rhs = ParseCastExpr();
lhs = BinaryOp::New(tok, lhs, rhs);
tok = ts_.Next();
@@ -593,16 +610,26 @@ Expr* Parser::ParseRangeExpr() {
return lhs;
}
Expr* Parser::ParseMultiplicativeExpr() {
Expr* Parser::ParseMatmulExpr() {
auto lhs = ParseRangeExpr();
auto tok = ts_.Next();
while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') {
while (tok->tag_ == Token::MATMUL) {
auto rhs = ParseRangeExpr();
lhs = BinaryOp::New(tok, lhs, rhs);
tok = ts_.Next();
}
ts_.PutBack();
return lhs;
}
Expr* Parser::ParseMultiplicativeExpr() {
auto lhs = ParseMatmulExpr();
auto tok = ts_.Next();
while (tok->tag_ == '*' || tok->tag_ == '/' || tok->tag_ == '%') {
auto rhs = ParseMatmulExpr();
lhs = BinaryOp::New(tok, lhs, rhs);
tok = ts_.Next();
}
ts_.PutBack();
return lhs;
}

View File

@@ -27,6 +27,7 @@ const std::unordered_map<std::string, int> Token::kwTypeMap_ {
{ "inline", Token::INLINE },
{ "int", Token::INT },
{ "long", Token::LONG },
{ "newaxis", Token::NEWAXIS },
{ "signed", Token::SIGNED },
{ "unsigned", Token::UNSIGNED },
{ "register", Token::REGISTER },
@@ -126,6 +127,7 @@ const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
{ Token::INLINE, "inline" },
{ Token::INT, "int" },
{ Token::LONG, "long" },
{ Token::NEWAXIS, "newaxis" },
{ Token::SIGNED, "signed" },
{ Token::UNSIGNED, "unsigned" },
{ Token::REGISTER, "register" },

View File

@@ -32,6 +32,18 @@ QualType Type::MayCast(QualType type, bool inProtoScope) {
return type;
}
const Type* Type::ScalarType() const {
if(IsScalar())
return this;
if(const TileType* p = ToTile())
return p->Derived().GetPtr();
return nullptr;
}
Type* Type::ScalarType() {
auto cthis = const_cast<const Type*>(this);
return const_cast<Type*>(cthis->ScalarType());
}
VoidType* VoidType::New() {
static auto ret = new (voidTypePool.Alloc()) VoidType(&voidTypePool);
@@ -143,12 +155,16 @@ int ArithmType::Width() const {
return intWidth_ << 1;
case T_LLONG: case T_UNSIGNED | T_LLONG:
return intWidth_ << 1;
case T_HALF:
return intWidth_ >> 1;
case T_FLOAT:
return intWidth_;
case T_DOUBLE:
return intWidth_ << 1;
case T_LONG | T_DOUBLE:
return intWidth_ << 1;
case T_HALF | T_COMPLEX:
return intWidth_;
case T_FLOAT | T_COMPLEX:
return intWidth_ << 1;
case T_DOUBLE | T_COMPLEX:
@@ -171,9 +187,10 @@ int ArithmType::Rank() const {
case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3;
case T_LONG: case T_UNSIGNED | T_LONG: return 4;
case T_LLONG: case T_UNSIGNED | T_LLONG: return 5;
case T_FLOAT: return 6;
case T_DOUBLE: return 7;
case T_LONG | T_DOUBLE: return 8;
case T_HALF: return 6;
case T_FLOAT: return 7;
case T_DOUBLE: return 8;
case T_LONG | T_DOUBLE: return 9;
default:
assert(tag_ & T_COMPLEX);
Error("complex not supported yet");