// AST node factories, visitor dispatch and type checking (implementation of triton/lang/ast.h).
#include "triton/lang/ast.h"
|
|
#include "triton/lang/error.h"
|
|
#include "triton/lang/evaluator.h"
|
|
#include "triton/lang/mem_pool.h"
|
|
#include "triton/lang/parser.h"
|
|
#include "triton/lang/token.h"
|
|
|
|
|
|
// One memory pool per AST node type. Each New() factory in this file
// placement-constructs its node in the matching pool and stores the pool's
// address in the node's pool_ member.

// Expression nodes
static MemPoolImp<BinaryOp>      binaryOpPool;
static MemPoolImp<TransOp>       transOpPool;
static MemPoolImp<ConditionalOp> conditionalOpPool;
static MemPoolImp<FuncCall>      funcCallPool;
// Note: despite its name, this pool holds Declaration nodes.
static MemPoolImp<Declaration>   initializationPool;
static MemPoolImp<Object>        objectPool;
static MemPoolImp<Identifier>    identifierPool;
static MemPoolImp<Enumerator>    enumeratorPool;
static MemPoolImp<Constant>      constantPool;
static MemPoolImp<TempVar>       tempVarPool;
static MemPoolImp<UnaryOp>       unaryOpPool;

// Statement nodes
static MemPoolImp<EmptyStmt>     emptyStmtPool;
static MemPoolImp<IfStmt>        ifStmtPool;
static MemPoolImp<ForStmt>       forStmtPool;
static MemPoolImp<JumpStmt>      jumpStmtPool;
static MemPoolImp<ReturnStmt>    returnStmtPool;
static MemPoolImp<LabelStmt>     labelStmtPool;
static MemPoolImp<CompoundStmt>  compoundStmtPool;

// Function definitions
static MemPoolImp<FuncDef>       funcDefPool;
|
|
|
|
|
|
/*
|
|
* Accept
|
|
*/
|
|
|
|
// Forwards this declaration to the visitor's VisitDeclaration handler.
void Declaration::Accept(Visitor* v) {
  v->VisitDeclaration(this);
}
|
|
|
|
|
|
// An empty statement is never handed to the visitor; the call is a no-op.
void EmptyStmt::Accept(Visitor* v) {
  (void)v;  // intentionally unused
}
|
|
|
|
|
|
// Forwards this label statement to the visitor's VisitLabelStmt handler.
void LabelStmt::Accept(Visitor* v) {
  v->VisitLabelStmt(this);
}
|
|
|
|
|
|
// Forwards this if statement to the visitor's VisitIfStmt handler.
void IfStmt::Accept(Visitor* v) {
  v->VisitIfStmt(this);
}
|
|
|
|
// Forwards this for statement to the visitor's VisitForStmt handler.
void ForStmt::Accept(Visitor* v) {
  v->VisitForStmt(this);
}
|
|
|
|
|
|
// Forwards this jump statement to the visitor's VisitJumpStmt handler.
void JumpStmt::Accept(Visitor* v) {
  v->VisitJumpStmt(this);
}
|
|
|
|
|
|
// Forwards this return statement to the visitor's VisitReturnStmt handler.
void ReturnStmt::Accept(Visitor* v) {
  v->VisitReturnStmt(this);
}
|
|
|
|
|
|
// Forwards this compound statement to the visitor's VisitCompoundStmt handler.
void CompoundStmt::Accept(Visitor* v) {
  v->VisitCompoundStmt(this);
}
|
|
|
|
|
|
// Forwards this binary operation to the visitor's VisitBinaryOp handler.
void BinaryOp::Accept(Visitor* v) {
  v->VisitBinaryOp(this);
}
|
|
|
|
|
|
// Forwards this unary operation to the visitor's VisitUnaryOp handler.
void UnaryOp::Accept(Visitor* v) {
  v->VisitUnaryOp(this);
}
|
|
|
|
// Forwards this transposition to the visitor's VisitTransOp handler.
void TransOp::Accept(Visitor* v) {
  v->VisitTransOp(this);
}
|
|
|
|
// Forwards this conditional expression to the visitor's VisitConditionalOp
// handler.
void ConditionalOp::Accept(Visitor* v) {
  v->VisitConditionalOp(this);
}
|
|
|
|
|
|
// Forwards this function call to the visitor's VisitFuncCall handler.
void FuncCall::Accept(Visitor* v) {
  v->VisitFuncCall(this);
}
|
|
|
|
|
|
// Forwards this identifier to the visitor's VisitIdentifier handler.
void Identifier::Accept(Visitor* v) {
  v->VisitIdentifier(this);
}
|
|
|
|
|
|
// Forwards this object to the visitor's VisitObject handler.
void Object::Accept(Visitor* v) {
  v->VisitObject(this);
}
|
|
|
|
|
|
// Forwards this constant to the visitor's VisitConstant handler.
void Constant::Accept(Visitor* v) {
  v->VisitConstant(this);
}
|
|
|
|
|
|
// Forwards this enumerator to the visitor's VisitEnumerator handler.
void Enumerator::Accept(Visitor* v) {
  v->VisitEnumerator(this);
}
|
|
|
|
|
|
// Forwards this temporary variable to the visitor's VisitTempVar handler.
void TempVar::Accept(Visitor* v) {
  v->VisitTempVar(this);
}
|
|
|
|
|
|
// Forwards this function definition to the visitor's VisitFuncDef handler.
void FuncDef::Accept(Visitor* v) {
  v->VisitFuncDef(this);
}
|
|
|
|
|
|
// Forwards this translation unit to the visitor's VisitTranslationUnit
// handler.
void TranslationUnit::Accept(Visitor* v) {
  v->VisitTranslationUnit(this);
}
|
|
|
|
|
|
// Casting array to pointer, function to pointer to function
|
|
Expr* Expr::MayCast(Expr* expr) {
|
|
auto type = Type::MayCast(expr->Type());
|
|
// If the types are equal, no need cast
|
|
if (type != expr->Type()) { // Pointer comparison is enough
|
|
return UnaryOp::New(Token::CAST, expr, type);
|
|
}
|
|
return expr;
|
|
}
|
|
|
|
|
|
// Converts `expr` towards `desType`.
// After the implicit conversion of MayCast(expr), no cast is inserted when
//   - both types are pointers and at least one of them is a void pointer, or
//   - `desType` is compatible with the expression's type.
// In every other case the expression is wrapped in a CAST to `desType`.
Expr* Expr::MayCast(Expr* expr, QualType desType) {
  expr = MayCast(expr);
  auto srcType = expr->Type();

  const bool bothPointers = desType->ToPointer() && srcType->ToPointer();
  if (bothPointers && (desType->IsVoidPointer() || srcType->IsVoidPointer()))
    return expr;

  if (desType->Compatible(*srcType))
    return expr;

  return UnaryOp::New(Token::CAST, expr, desType);
}
|
|
|
|
// Extract the operand's scalar type if possible
|
|
// and emit an error otherwise
|
|
::Type* Expr::TryExtractScalarType(Expr* loc, Expr *operand) {
|
|
auto scalType = operand->Type()->ScalarType();
|
|
if(!scalType)
|
|
Error(loc, "expect tile or scalar operand");
|
|
return scalType;
|
|
}
|
|
|
|
// If operand is a tile, return a tile of the same shape and
|
|
// provided element type
|
|
// If operand is a scalar, return provided element type
|
|
// directly
|
|
::Type* Expr::ScalarOrLikeTile(Expr* operand, ::Type* ty) {
|
|
assert(ty->IsScalar());
|
|
::Type *retTy = ty;
|
|
if(TileType *T = operand->Type()->ToTile())
|
|
retTy = TileType::New(T->Shape(), retTy);
|
|
return retTy;
|
|
}
|
|
|
|
// Convenience overload: the operator code is taken from the token's tag.
BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) {
  return BinaryOp::New(tok, tok->tag_, lhs, rhs);
}
|
|
|
|
|
|
// Creates a binary operation node in binaryOpPool and type-checks it.
//   tok      - token the node is created from
//   op       - operator code; must be one of the binary operators listed in
//              the switch below (anything else trips the assert)
//   lhs, rhs - operands
BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
  switch (op) {
    case ',':
    case '.':
    case '=':
    case '*':
    case '/':
    case '%':
    case '+':
    case '-':
    case '&':
    case '^':
    case '|':
    case '<':
    case '>':
    case Token::LEFT:
    case Token::RIGHT:
    case Token::LE:
    case Token::GE:
    case Token::EQ:
    case Token::NE:
    case Token::LOGICAL_AND:
    case Token::LOGICAL_OR:
    case Token::ELLIPSIS:
    case Token::MATMUL:
    case Token::MASKED_DEREF:
      break;
    default:
      assert(0);
  }

  BinaryOp* node = new (binaryOpPool.Alloc()) BinaryOp(tok, op, lhs, rhs);
  node->pool_ = &binaryOpPool;

  // Type checking may rewrite the operands (casts, broadcasting) and sets the
  // node's result type.
  node->TypeChecking();
  return node;
}
|
|
|
|
|
|
// Usual arithmetic conversion of both operands.
// Precondition (asserted): both operands have an arithmetic scalar type.
// The common type is ArithmType::MaxType of the two scalar types; an operand
// whose scalar type differs from it is wrapped in a CAST to the common type,
// keeping its tile shape if it is a tile. Returns the common scalar type.
ArithmType* BinaryOp::Convert() {
  auto lhsArithm = lhs_->Type()->ScalarType()->ToArithm();
  auto rhsArithm = rhs_->Type()->ScalarType()->ToArithm();
  assert(lhsArithm && rhsArithm);

  auto common = ArithmType::MaxType(lhsArithm, rhsArithm);
  // Pointer comparison is enough to tell the types apart.
  if (lhsArithm != common)
    lhs_ = UnaryOp::New(Token::CAST, lhs_, ScalarOrLikeTile(lhs_, common));
  if (rhsArithm != common)
    rhs_ = UnaryOp::New(Token::CAST, rhs_, ScalarOrLikeTile(rhs_, common));
  return common;
}
|
|
|
|
// Broadcasts `lhs` and `rhs` to a common tile shape.
//   loc  - expression used as the location of any reported error
//   lhs  - in/out: replaced by a CAST node when it has to be broadcast
//   rhs  - in/out: replaced by a CAST node when it has to be broadcast
//   type - in/out: on entry its scalar type is the element type of the
//          result (asserted non-null); on exit it is a tile of the broadcast
//          shape with that element type, unless both operands are scalars,
//          in which case it is left untouched
// Rules:
//   - scalar/scalar: nothing to do
//   - tile/scalar:   the scalar is cast to a tile of the other operand's shape
//   - tile/tile:     the shorter shape is left-padded with 1s; per dimension
//                    the sizes must be equal or one of them must be 1
void BinaryOp::Broadcast(Expr* loc, Expr *&lhs, Expr *&rhs, QualType& type) {
  auto lhsType = lhs->Type()->ToTile();
  auto rhsType = rhs->Type()->ToTile();
  auto eleType = type->ScalarType();
  assert(eleType);
  if(!lhsType && !rhsType)
    return ;
  else if(lhsType && !rhsType){
    type = TileType::New(lhsType->Shape(), eleType);
    ::Type* rtype = TileType::New(lhsType->Shape(), rhs->Type()->ScalarType());
    rhs = UnaryOp::New(Token::CAST, rhs, rtype);
  }
  else if(!lhsType && rhsType){
    type = TileType::New(rhsType->Shape(), eleType);
    ::Type* ltype = TileType::New(rhsType->Shape(), lhs->Type()->ScalarType());
    lhs = UnaryOp::New(Token::CAST, lhs, ltype);
  }
  else {
    // Copies on purpose: the shapes are padded below.
    auto lhsShape = lhsType->Shape();
    auto rhsShape = rhsType->Shape();
    auto lhsRank = lhsShape.size();
    auto rhsRank = rhsShape.size();
    auto retRank = std::max(lhsRank, rhsRank);
    // pad to the left until shapes have the same rank
    while(lhsShape.size() < retRank)
      lhsShape.insert(lhsShape.begin(), 1);
    while(rhsShape.size() < retRank)
      rhsShape.insert(rhsShape.begin(), 1);
    // broadcast if possible
    TileType::ShapeInt retShape(retRank);
    for(size_t i = 0; i < retRank; i++) {
      if(lhsShape[i] == 1)
        retShape[i] = rhsShape[i];
      else if(rhsShape[i] == 1)
        retShape[i] = lhsShape[i];
      else if(lhsShape[i] == rhsShape[i])
        retShape[i] = lhsShape[i];
      else
        // The index is a size_t; convert it so that it matches the %d
        // conversion of the printf-style message.
        Error(loc, "cannot broadcast dimension %d "
                   "for operands of shape %d and %d",
                   static_cast<int>(i), lhsShape[i], rhsShape[i]);
    }
    ::Type* ltype = TileType::New(retShape, lhsType->ScalarType());
    ::Type* rtype = TileType::New(retShape, rhsType->ScalarType());
    type = TileType::New(retShape, eleType);
    // Only operands whose (padded) shape differs from the result are cast.
    if(retShape != lhsShape)
      lhs = UnaryOp::New(Token::CAST, lhs, ltype);
    if(retShape != rhsShape)
      rhs = UnaryOp::New(Token::CAST, rhs, rtype);
  }
}
|
|
|
|
/*
|
|
* Type checking
|
|
*/
|
|
|
|
void Expr::EnsureCompatibleOrVoidPointer(const QualType lhs,
|
|
const QualType rhs) const {
|
|
if (lhs->ToPointer() && rhs->ToPointer() &&
|
|
(lhs->IsVoidPointer() || rhs->IsVoidPointer())) {
|
|
return;
|
|
}
|
|
EnsureCompatible(lhs, rhs);
|
|
}
|
|
|
|
|
|
void Expr::EnsureCompatible(const QualType lhs, const QualType rhs) const {
|
|
if (!lhs->Compatible(*rhs))
|
|
Error(this, "incompatible types");
|
|
}
|
|
|
|
|
|
void BinaryOp::TypeChecking() {
|
|
switch (op_) {
|
|
case '.':
|
|
return MemberRefOpTypeChecking();
|
|
|
|
case '*':
|
|
case '/':
|
|
case '%':
|
|
return MultiOpTypeChecking();
|
|
|
|
case '+':
|
|
case '-':
|
|
return AdditiveOpTypeChecking();
|
|
|
|
case Token::LEFT:
|
|
case Token::RIGHT:
|
|
return ShiftOpTypeChecking();
|
|
|
|
case '<':
|
|
case '>':
|
|
case Token::LE:
|
|
case Token::GE:
|
|
return RelationalOpTypeChecking();
|
|
|
|
case Token::EQ:
|
|
case Token::NE:
|
|
return EqualityOpTypeChecking();
|
|
|
|
case '&':
|
|
case '^':
|
|
case '|':
|
|
return BitwiseOpTypeChecking();
|
|
|
|
case Token::LOGICAL_AND:
|
|
case Token::LOGICAL_OR:
|
|
return LogicalOpTypeChecking();
|
|
|
|
case '=':
|
|
return AssignOpTypeChecking();
|
|
|
|
case ',':
|
|
return CommaOpTypeChecking();
|
|
|
|
case Token::ELLIPSIS:
|
|
return RangeOpTypeChecking();
|
|
|
|
case Token::MATMUL:
|
|
return MatmulOpTypeChecking();
|
|
|
|
case Token::MASKED_DEREF:
|
|
return MaskedDerefOpTypeChecking();
|
|
|
|
default:
|
|
assert(0);
|
|
}
|
|
}
|
|
|
|
|
|
// The comma operator takes the type of its right-hand operand.
void BinaryOp::CommaOpTypeChecking() {
  type_ = rhs_->Type();
}
|
|
|
|
|
|
// Subscripting is not handled by BinaryOp: reaching this function is a
// programming error.
void BinaryOp::SubScriptingOpTypeChecking() {
  assert(false);
}
|
|
|
|
|
|
// A member reference takes the type of its right-hand operand (the member).
void BinaryOp::MemberRefOpTypeChecking() {
  type_ = rhs_->Type();
}
|
|
|
|
|
|
void BinaryOp::MultiOpTypeChecking() {
|
|
::Type* lhsScalType = lhs_->Type()->ScalarType();
|
|
::Type* rhsScalType = rhs_->Type()->ScalarType();
|
|
if(!lhsScalType || !rhsScalType) {
|
|
Error(this, "operands should have type or scalar type");
|
|
}
|
|
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
|
|
Error(this, "operands should have arithmetic type");
|
|
}
|
|
if ('%' == op_ &&
|
|
!(lhsScalType->IsInteger() && rhsScalType->IsInteger())) {
|
|
Error(this, "operands of '%%' should be integers");
|
|
}
|
|
type_ = Convert();
|
|
Broadcast(this, lhs_, rhs_, type_);
|
|
}
|
|
|
|
|
|
/*
|
|
* Additive operator is only allowed between:
|
|
* 1. arithmetic types (bool, interger, floating)
|
|
* 2. pointer can be used:
|
|
* 1. lhs of MINUS operator, and rhs must be integer or pointer;
|
|
* 2. lhs/rhs of ADD operator, and the other operand must be integer;
|
|
* 3. tiles can be used:
|
|
* 1. the scalar type of lhs/rhs satisfy the above requirements
|
|
* 2. lhs/rhs that have identical shape
|
|
* 3. lhs/rhs that can be broadcast as per numpy-like semantics
|
|
*/
|
|
void BinaryOp::AdditiveOpTypeChecking() {
|
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
|
auto lhsPtrType = lhsScalType->ToPointer();
|
|
auto rhsPtrType = rhsScalType->ToPointer();
|
|
if (lhsPtrType) {
|
|
if (op_ == '-') {
|
|
if (rhsPtrType) {
|
|
if (!lhsPtrType->Compatible(*rhsPtrType))
|
|
Error(this, "invalid operands to binary -");
|
|
type_ = ArithmType::New(T_LONG); // ptrdiff_t
|
|
} else if (!rhsScalType->IsInteger()) {
|
|
Error(this, "invalid operands to binary -");
|
|
} else {
|
|
type_ = lhsPtrType;
|
|
}
|
|
} else if (!rhsScalType->IsInteger()) {
|
|
Error(this, "invalid operands to binary +");
|
|
} else {
|
|
type_ = lhsPtrType;
|
|
}
|
|
} else if (rhsPtrType) {
|
|
if (op_ == '+' && !lhsScalType->IsInteger()) {
|
|
Error(this, "invalid operands to binary '+'");
|
|
} else if (op_ == '-' && !lhsPtrType) {
|
|
Error(this, "invalid operands to binary '-'");
|
|
}
|
|
type_ = op_ == '-' ? ArithmType::New(T_LONG): rhsScalType;
|
|
std::swap(lhs_, rhs_); // To simplify code gen
|
|
} else {
|
|
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
|
|
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
|
|
}
|
|
type_ = Convert();
|
|
}
|
|
Broadcast(this, lhs_, rhs_, type_);
|
|
}
|
|
|
|
void BinaryOp::RangeOpTypeChecking() {
|
|
auto lhsType = lhs_->Type()->ToArithm();
|
|
auto rhsType = rhs_->Type()->ToArithm();
|
|
if(!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
|
|
Error(this, "expect integers for range operator");
|
|
lhs_ = Expr::MayCast(lhs_, ArithmType::IntegerPromote(lhsType));
|
|
rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(rhsType));
|
|
long begin = Evaluator<long>().Eval(lhs_);
|
|
long end = Evaluator<long>().Eval(rhs_);
|
|
int len = static_cast<int>(end - begin);
|
|
if(len < 0)
|
|
Error(this, "range cannot be negative");
|
|
type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type());
|
|
}
|
|
|
|
void BinaryOp::MaskedDerefOpTypeChecking() {
|
|
// auto lhsTileType = lhs_->Type()->ToTile();
|
|
// auto rhsTileType = rhs_->Type()->ToTile();
|
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
|
auto lhsType = lhsScalType->ToArithm();
|
|
auto rhsType = rhsScalType->ToPointer();
|
|
if (!rhsType)
|
|
Error(this, "pointer expected for deref pointer in operator '*?'");
|
|
if (!lhsType || (lhsType && !lhsType->IsBool()))
|
|
Error(this, "bool expected for deref mask in operator '*?'");
|
|
type_ = ScalarOrLikeTile(rhs_, rhsType->Derived().GetPtr());
|
|
Broadcast(this, lhs_, rhs_, type_);
|
|
}
|
|
|
|
void BinaryOp::MatmulOpTypeChecking() {
|
|
auto lhsType = lhs_->Type()->ToTile();
|
|
auto rhsType = rhs_->Type()->ToTile();
|
|
if(!lhsType || !rhsType)
|
|
Error(this, "expect tile operands for matrix multiplication");
|
|
auto lhsShape = lhsType->Shape();
|
|
auto rhsShape = rhsType->Shape();
|
|
size_t lhsRank = lhsShape.size();
|
|
size_t rhsRank = rhsShape.size();
|
|
if(lhsRank != rhsRank)
|
|
Error(this, "matrix multiplication operands have incompatible rank"
|
|
"%d and %d", lhsRank, rhsRank);
|
|
for(int d = 2; d < lhsRank; d++)
|
|
if(lhsShape[d] != rhsShape[d])
|
|
Error(this, "matrix multiplication operands have incompatible batch dimension"
|
|
"%d and %d for axis %d", lhsShape[d], rhsShape[d], d);
|
|
if(lhsShape[1] != rhsShape[0])
|
|
Error(this, "matrix multiplication operands have incompatible inner dimension"
|
|
" %d and %d", lhsShape[1], rhsShape[0]);
|
|
// ret shape
|
|
TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]};
|
|
for(int d = 2; d < lhsRank; d++)
|
|
retShape.push_back(lhsShape[d]);
|
|
QualType retType = lhsType->Derived();
|
|
if(retType != rhsType->Derived())
|
|
Error(this, "matrix multiplication operands have incompatible data types");
|
|
ArithmType* ScalType = lhsType->ScalarType()->ToArithm();
|
|
if(ScalType->Tag() & T_HALF)
|
|
ScalType = ArithmType::New(T_FLOAT);
|
|
type_ = TileType::New(retShape, ScalType);
|
|
}
|
|
|
|
void BinaryOp::ShiftOpTypeChecking() {
|
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
|
auto lhsType = lhsScalType->ToArithm();
|
|
auto rhsType = rhsScalType->ToArithm();
|
|
if (!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
|
|
Error(this, "expect integers for shift operator");
|
|
lhs_ = Expr::MayCast(lhs_, ScalarOrLikeTile(lhs_, ArithmType::IntegerPromote(lhsType)));
|
|
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, ArithmType::IntegerPromote(rhsType)));
|
|
type_ = lhs_->Type();
|
|
Broadcast(this, lhs_, rhs_, type_);
|
|
}
|
|
|
|
|
|
void BinaryOp::RelationalOpTypeChecking() {
|
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
|
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
|
|
EnsureCompatible(lhsScalType, rhsScalType);
|
|
} else {
|
|
if (!lhsScalType->IsReal() || !rhsScalType->IsReal()) {
|
|
Error(this, "expect real type of operands");
|
|
}
|
|
Convert();
|
|
}
|
|
type_ = ArithmType::New(T_INT);
|
|
Broadcast(this, lhs_, rhs_, type_);
|
|
}
|
|
|
|
|
|
void BinaryOp::EqualityOpTypeChecking() {
|
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
|
if (lhsScalType->ToPointer() || rhsScalType->ToPointer()) {
|
|
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
|
|
} else {
|
|
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm())
|
|
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
|
|
Convert();
|
|
}
|
|
type_ = ArithmType::New(T_INT);
|
|
Broadcast(this, lhs_, rhs_, type_);
|
|
}
|
|
|
|
|
|
void BinaryOp::BitwiseOpTypeChecking() {
|
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
|
if (!lhsScalType->IsInteger() || !rhsScalType->IsInteger())
|
|
Error(this, "operands of '&' should be integer");
|
|
type_ = Convert();
|
|
Broadcast(this, lhs_, rhs_, type_);
|
|
}
|
|
|
|
|
|
void BinaryOp::LogicalOpTypeChecking() {
|
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
|
if (!lhsScalType->IsScalar() || !rhsScalType->IsScalar())
|
|
Error(this, "the operand should be arithmetic type or pointer");
|
|
type_ = ArithmType::New(T_INT);
|
|
Broadcast(this, lhs_, rhs_, type_);
|
|
}
|
|
|
|
|
|
void BinaryOp::AssignOpTypeChecking() {
|
|
if (lhs_->IsConstQualified()) {
|
|
Error(lhs_, "left operand of '=' is const qualified");
|
|
} else if (!lhs_->IsLVal()) {
|
|
Error(lhs_, "lvalue expression expected");
|
|
}
|
|
|
|
::Type* lhsScalType = TryExtractScalarType(this, lhs_);
|
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
|
if (!lhsScalType->ToArithm() || !rhsScalType->ToArithm()) {
|
|
EnsureCompatibleOrVoidPointer(lhsScalType, rhsScalType);
|
|
}
|
|
|
|
// The other constraints are lefted to cast operator
|
|
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
|
|
type_ = lhs_->Type();
|
|
rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
|
|
}
|
|
|
|
/*
|
|
* Unary Operators
|
|
*/
|
|
|
|
// Creates a unary operation node in unaryOpPool and type-checks it.
//   op      - operator code (see UnaryOp::TypeChecking for the accepted set)
//   operand - the operand expression
//   type    - initial type of the node; for casts this is the target type
//   info    - extra operator data; reductions decode it with decodeRed
UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type, int info) {
  UnaryOp* node = new (unaryOpPool.Alloc()) UnaryOp(op, operand, type, info);
  node->pool_ = &unaryOpPool;

  node->TypeChecking();
  return node;
}
|
|
|
|
|
|
// Packs a reduction axis and tag into one int: `ax` is OR-ed into the low
// bits and `tag` is OR-ed in shifted left by 16. decodeRed reverses this.
int UnaryOp::encodeRed(int ax, int tag) {
  return 0 | ax | (tag << 16);
}
|
|
|
|
// Unpacks a value produced by encodeRed.
//   info - packed value
//   ax   - out: low 16 bits of `info`
//   tag  - out: high 16 bits of `info`, shifted down by 16
void UnaryOp::decodeRed(int info, int& ax, int& tag) {
  const int axisBits = info & 0x0000FFFF;
  const int tagBits = (info & 0xFFFF0000) >> 16;
  ax = axisBits;
  tag = tagBits;
}
|
|
|
|
bool UnaryOp::IsLVal() {
|
|
// Only deref('*') could be lvalue;
|
|
return op_ == Token::DEREF;
|
|
}
|
|
|
|
|
|
// Converts the operand for an arithmetic unary operator.
// Precondition (asserted): the operand has an arithmetic scalar type.
// Integer scalar types are integer-promoted; the operand is then converted
// (via Expr::MayCast) to that type, keeping its tile shape if it is a tile.
// Returns the type the operand was converted to.
::Type* UnaryOp::Convert() {
  auto scalar = operand_->Type()->ScalarType();
  assert(scalar);
  auto arithm = scalar->ToArithm();
  assert(arithm);

  if (arithm->IsInteger()) {
    arithm = ArithmType::IntegerPromote(arithm);
  }

  ::Type* converted = ScalarOrLikeTile(operand_, arithm);
  operand_ = Expr::MayCast(operand_, converted);
  return converted;
}
|
|
|
|
|
|
void UnaryOp::TypeChecking() {
|
|
switch (op_) {
|
|
case Token::POSTFIX_INC:
|
|
case Token::POSTFIX_DEC:
|
|
case Token::PREFIX_INC:
|
|
case Token::PREFIX_DEC:
|
|
return IncDecOpTypeChecking();
|
|
|
|
case Token::ADDR:
|
|
return AddrOpTypeChecking();
|
|
|
|
case Token::DEREF:
|
|
return DerefOpTypeChecking();
|
|
|
|
case Token::PLUS:
|
|
case Token::MINUS:
|
|
case '~':
|
|
case '!':
|
|
return UnaryArithmOpTypeChecking();
|
|
|
|
case Token::BITCAST:
|
|
return BitcastOpTypeChecking();
|
|
|
|
case Token::CAST:
|
|
return CastOpTypeChecking();
|
|
|
|
case Token::REDUCE:
|
|
return ReduceOpTypeChecking();
|
|
|
|
default:
|
|
assert(false);
|
|
}
|
|
}
|
|
|
|
void UnaryOp::IncDecOpTypeChecking() {
|
|
if (operand_->IsConstQualified()) {
|
|
Error(this, "increment/decrement of const qualified expression");
|
|
} else if (!operand_->IsLVal()) {
|
|
Error(this, "lvalue expression expected");
|
|
}
|
|
auto scalType = TryExtractScalarType(this, operand_);
|
|
if (!scalType->IsReal() && !scalType->ToPointer()) {
|
|
Error(this, "expect operand of real type or pointer");
|
|
}
|
|
type_ = operand_->Type();
|
|
}
|
|
|
|
|
|
void UnaryOp::AddrOpTypeChecking() {
|
|
auto funcType = operand_->Type()->ToFunc();
|
|
if (funcType == nullptr && !operand_->IsLVal())
|
|
Error(this, "expression must be an lvalue or function designator");
|
|
if(operand_->Type()->IsTile())
|
|
Error(this, "cannot take the address of a tile");
|
|
type_ = PointerType::New(operand_->Type());
|
|
}
|
|
|
|
|
|
void UnaryOp::DerefOpTypeChecking() {
|
|
auto scalType = TryExtractScalarType(this, operand_);
|
|
auto pointerType = scalType->ToPointer();
|
|
if (!pointerType)
|
|
Error(this, "pointer expected for deref operator '*'");
|
|
type_ = ScalarOrLikeTile(operand_, pointerType->Derived().GetPtr());
|
|
}
|
|
|
|
void UnaryOp::ReduceOpTypeChecking() {
|
|
int ax, tag;
|
|
decodeRed(info_, ax, tag);
|
|
auto tileType = operand_->Type()->ToTile();
|
|
if(!tileType)
|
|
Error(this, "array expected for reduction operation");
|
|
auto shape = tileType->Shape();
|
|
shape.erase(shape.begin() + ax);
|
|
if(shape.empty())
|
|
type_ = tileType->Derived();
|
|
else
|
|
type_ = TileType::New(shape, tileType->Derived());
|
|
}
|
|
|
|
void UnaryOp::UnaryArithmOpTypeChecking() {
|
|
auto scalType = TryExtractScalarType(this, operand_);
|
|
if (Token::PLUS == op_ || Token::MINUS == op_) {
|
|
if (!scalType->ToArithm())
|
|
Error(this, "Arithmetic type expected");
|
|
Convert();
|
|
type_ = operand_->Type();
|
|
} else if ('~' == op_) {
|
|
if (!scalType->IsInteger())
|
|
Error(this, "integer expected for operator '~'");
|
|
Convert();
|
|
type_ = operand_->Type();
|
|
} else if (!scalType->IsScalar()) {
|
|
Error(this, "arithmetic type or pointer expected for operator '!'");
|
|
} else {
|
|
type_ = ScalarOrLikeTile(operand_, ArithmType::New(T_INT));
|
|
}
|
|
}
|
|
|
|
void UnaryOp::BitcastOpTypeChecking() {
|
|
auto operandType = Type::MayCast(operand_->Type());
|
|
if(type_->Width() != operandType->Width())
|
|
Error(this, "cannot bitcast to type of different width");
|
|
}
|
|
|
|
void UnaryOp::CastOpTypeChecking() {
|
|
auto operandType = Type::MayCast(operand_->Type());
|
|
// The type_ has been initiated to dest type
|
|
if (type_->ToVoid()) {
|
|
// The expression becomes a void expression
|
|
} else if(type_->IsTile() || operandType->IsTile()) {
|
|
/* Broadcasting rules:
|
|
* 1. Tiles with 1 element can be converted to scalar
|
|
* 2. Scalar can be converted to tiles of any shapes
|
|
* 3. Tiles can be converted to another tile only if the
|
|
* mismatching dimensions are unitary
|
|
*/
|
|
if(type_->IsScalar() && operandType->ToTile()->NumEle() != 1)
|
|
Error(this, "tile with more than one element cannot be casted to scalar");
|
|
if(type_->IsTile() && operandType->IsTile()){
|
|
auto shape = type_->ToTile()->Shape();
|
|
auto operandShape = operandType->ToTile()->Shape();
|
|
if(operandShape.size() > shape.size())
|
|
Error(this, "cast cannot reduce operand rank");
|
|
while(operandShape.size() < shape.size())
|
|
operandShape.insert(operandShape.begin(), 1);
|
|
for(size_t i = 0; i < shape.size(); i++) {
|
|
if(shape[i] != 1 && operandShape[i] != 1 && shape[i] != operandShape[i])
|
|
Error(this, "cannot broadcast dimension %d "
|
|
"for operands of shape %d and %d",
|
|
i, shape[i], operandShape[i]);
|
|
}
|
|
}
|
|
} else if (!type_->IsScalar() || !operandType->IsScalar()) {
|
|
if (!type_->Compatible(*operandType))
|
|
Error(this, "the cast type should be arithemetic type or pointer");
|
|
} else if (type_->IsFloat() && operandType->ToPointer()) {
|
|
Error(this, "cannot cast a pointer to floating");
|
|
} else if (type_->ToPointer() && operandType->IsFloat()) {
|
|
Error(this, "cannot cast a floating to pointer");
|
|
}
|
|
}
|
|
|
|
/*
|
|
* Transposition Operator
|
|
*/
|
|
void TransOp::TypeChecking() {
|
|
auto tileType = operand_->Type()->ToTile();
|
|
if(!tileType)
|
|
Error(this, "tile expected for transposition operator '^'");
|
|
auto opShape = tileType->Shape();
|
|
if(perm_.size() != opShape.size())
|
|
Error(this, "invalid permutations");
|
|
// permutate input shape
|
|
TileType::ShapeInt resShape(opShape.size());
|
|
for(int d = 0; d < opShape.size(); d++)
|
|
resShape[d] = opShape[perm_[d]];
|
|
type_ = TileType::New(resShape, tileType->Derived());
|
|
}
|
|
|
|
// Creates a transposition node in transOpPool and type-checks it.
//   perm    - permutation applied to the operand's dimensions
//   operand - the tile expression being transposed
TransOp* TransOp::New(const PermInt& perm, Expr* operand) {
  TransOp* node = new (transOpPool.Alloc()) TransOp(perm, operand);
  node->pool_ = &transOpPool;
  node->TypeChecking();
  return node;
}
|
|
|
|
/*
|
|
* Conditional Operator
|
|
*/
|
|
|
|
// Creates a conditional expression node (cond ? exprTrue : exprFalse) in
// conditionalOpPool and type-checks it. `tok` is accepted for interface
// symmetry with the other factories but is not used here.
ConditionalOp* ConditionalOp::New(const Token* tok,
                                  Expr* cond,
                                  Expr* exprTrue,
                                  Expr* exprFalse) {
  ConditionalOp* node =
      new (conditionalOpPool.Alloc()) ConditionalOp(cond, exprTrue, exprFalse);
  node->pool_ = &conditionalOpPool;

  node->TypeChecking();
  return node;
}
|
|
|
|
|
|
// Usual arithmetic conversion of the two result branches.
// Precondition (asserted): both branches have an arithmetic scalar type.
// The common type is ArithmType::MaxType of the two scalar types. A branch
// whose scalar type differs is cast to the common type while keeping its
// tile shape, as BinaryOp::Convert does; casting a tile branch directly to
// the scalar type would be rejected by the cast's own type checking for any
// tile with more than one element. Returns the common scalar type.
ArithmType* ConditionalOp::Convert() {
  auto lhsType = exprTrue_->Type()->ScalarType()->ToArithm();
  auto rhsType = exprFalse_->Type()->ScalarType()->ToArithm();
  assert(lhsType && rhsType);
  auto type = ArithmType::MaxType(lhsType, rhsType);
  if (lhsType != type) { // Pointer comparison is enough!
    exprTrue_ = UnaryOp::New(Token::CAST, exprTrue_,
                             ScalarOrLikeTile(exprTrue_, type));
  }
  if (rhsType != type) {
    exprFalse_ = UnaryOp::New(Token::CAST, exprFalse_,
                              ScalarOrLikeTile(exprFalse_, type));
  }

  return type;
}
|
|
|
|
|
|
void ConditionalOp::TypeChecking() {
|
|
auto condScalarType = TryExtractScalarType(this, cond_);
|
|
|
|
if (!condScalarType) {
|
|
Error(cond_->Tok(), "condition must be tile or scalar");
|
|
}
|
|
|
|
auto lhsType = TryExtractScalarType(this, exprTrue_);
|
|
auto rhsType = TryExtractScalarType(this, exprFalse_);
|
|
if (lhsType->ToArithm() && rhsType->ToArithm()) {
|
|
type_ = Convert();
|
|
} else {
|
|
EnsureCompatibleOrVoidPointer(lhsType, rhsType);
|
|
type_ = lhsType;
|
|
}
|
|
BinaryOp::Broadcast(this, exprFalse_, exprTrue_, type_);
|
|
}
|
|
|
|
|
|
/*
|
|
* Function Call
|
|
*/
|
|
|
|
// Creates a function call node in funcCallPool and type-checks it.
//   designator - expression designating the called function
//   args       - call arguments
FuncCall* FuncCall::New(Expr* designator, const ArgList& args) {
  FuncCall* node = new (funcCallPool.Alloc()) FuncCall(designator, args);
  node->pool_ = &funcCallPool;

  node->TypeChecking();
  return node;
}
|
|
|
|
|
|
void FuncCall::TypeChecking() {
|
|
auto pointerType = designator_->Type()->ToPointer();
|
|
if (pointerType) {
|
|
if (!pointerType->Derived()->ToFunc())
|
|
Error(designator_, "called object is not a function or function pointer");
|
|
// Convert function pointer to function type
|
|
designator_ = UnaryOp::New(Token::DEREF, designator_);
|
|
}
|
|
auto funcType = designator_->Type()->ToFunc();
|
|
if (!funcType) {
|
|
Error(designator_, "called object is not a function or function pointer");
|
|
} else if (!funcType->Derived()->ToVoid() &&
|
|
!funcType->Derived()->Complete()) {
|
|
Error(designator_, "invalid use of incomplete return type");
|
|
}
|
|
|
|
auto arg = args_.begin();
|
|
for (auto param: funcType->Params()) {
|
|
if (arg == args_.end())
|
|
Error(this, "too few arguments for function call");
|
|
*arg = Expr::MayCast(*arg, param->Type());
|
|
++arg;
|
|
}
|
|
if (arg != args_.end() && !funcType->Variadic())
|
|
Error(this, "too many arguments for function call");
|
|
|
|
// C11 6.5.2.2 [6]: promote float to double if it has no prototype
|
|
while (arg != args_.end()) {
|
|
if ((*arg)->Type()->IsFloat() && (*arg)->Type()->Width() == 4) {
|
|
auto type = ArithmType::New(T_DOUBLE);
|
|
*arg = UnaryOp::New(Token::CAST, *arg, type);
|
|
}
|
|
++arg;
|
|
}
|
|
|
|
type_ = funcType->Derived();
|
|
}
|
|
|
|
|
|
/*
|
|
* Identifier
|
|
*/
|
|
|
|
// Creates an identifier node in identifierPool from the given token, type,
// linkage and attribute list.
Identifier* Identifier::New(const Token* tok,
                            QualType type,
                            enum Linkage linkage,
                            const AttrList &attrList) {
  Identifier* node =
      new (identifierPool.Alloc()) Identifier(tok, type, linkage, attrList);
  node->pool_ = &identifierPool;
  return node;
}
|
|
|
|
|
|
// Creates an enumerator node with value `val` in enumeratorPool.
Enumerator* Enumerator::New(const Token* tok, int val) {
  Enumerator* node = new (enumeratorPool.Alloc()) Enumerator(tok, val);
  node->pool_ = &enumeratorPool;
  return node;
}
|
|
|
|
|
|
// Creates a declaration node for `obj` in initializationPool.
Declaration* Declaration::New(Object* obj) {
  Declaration* node = new (initializationPool.Alloc()) Declaration(obj);
  node->pool_ = &initializationPool;
  return node;
}
|
|
|
|
// Records an initializer for this declaration.
// The initializer expression is first converted to the initializer's type.
// If inits_ already holds an equivalent entry (per the container's
// ordering), that entry is replaced by the new one.
void Declaration::AddInit(Initializer init) {
  init.expr_ = Expr::MayCast(init.expr_, init.type_);

  auto inserted = inits_.insert(init);
  if (inserted.second)
    return;

  // An equivalent initializer already exists: the later one wins.
  inits_.erase(inserted.first);
  inits_.insert(init);
}
|
|
|
|
|
|
/*
|
|
* Object
|
|
*/
|
|
|
|
// Creates an object node in objectPool.
// Static or anonymous objects additionally receive a fresh id from a counter
// local to this function (independent of the counter in Object::NewAnony).
Object* Object::New(const Token* tok,
                    QualType type,
                    int storage,
                    enum Linkage linkage,
                    unsigned char bitFieldBegin,
                    unsigned char bitFieldWidth,
                    const AttrList& attrList) {
  Object* obj = new (objectPool.Alloc())
      Object(tok, type, storage, linkage, bitFieldBegin, bitFieldWidth, attrList);
  obj->pool_ = &objectPool;

  static long id = 0;
  if (obj->IsStatic() || obj->Anonymous()) {
    obj->id_ = ++id;
  }
  return obj;
}
|
|
|
|
|
|
// Creates an anonymous object node in objectPool.
// The object is flagged anonymous and receives a fresh id from a counter
// local to this function (independent of the counter in Object::New).
Object* Object::NewAnony(const Token* tok,
                         QualType type,
                         int storage,
                         enum Linkage linkage,
                         unsigned char bitFieldBegin,
                         unsigned char bitFieldWidth,
                         const AttrList& attrList) {
  Object* obj = new (objectPool.Alloc())
      Object(tok, type, storage, linkage, bitFieldBegin, bitFieldWidth, attrList);
  obj->pool_ = &objectPool;
  obj->anonymous_ = true;

  static long id = 0;
  // anonymous_ was just set, so this condition always holds here.
  if (obj->IsStatic() || obj->anonymous_) {
    obj->id_ = ++id;
  }
  return obj;
}
|
|
|
|
|
|
/*
|
|
* Constant
|
|
*/
|
|
|
|
// Creates an integer constant in constantPool. `tag` selects the arithmetic
// type of the constant and `val` is its value.
Constant* Constant::New(const Token* tok, int tag, long val) {
  auto arithmType = ArithmType::New(tag);
  Constant* node = new (constantPool.Alloc()) Constant(tok, arithmType, val);
  node->pool_ = &constantPool;
  return node;
}
|
|
|
|
|
|
// Creates a floating-point constant in constantPool. `tag` selects the
// arithmetic type of the constant and `val` is its value.
Constant* Constant::New(const Token* tok, int tag, double val) {
  auto arithmType = ArithmType::New(tag);
  Constant* node = new (constantPool.Alloc()) Constant(tok, arithmType, val);
  node->pool_ = &constantPool;
  return node;
}
|
|
|
|
|
|
// Creates a string constant in constantPool.
// `tag` selects the element type; the constant's type is an array of that
// element type whose length is val->size() divided by the element width.
// Every string constant receives a fresh id from a counter local to this
// function.
Constant* Constant::New(const Token* tok, int tag, const std::string* val) {
  auto elemType = ArithmType::New(tag);
  auto arrayType = ArrayType::New(val->size() / elemType->Width(), elemType);

  Constant* node = new (constantPool.Alloc()) Constant(tok, arrayType, val);
  node->pool_ = &constantPool;

  static long id = 0;
  node->id_ = ++id;
  return node;
}
|
|
|
|
|
|
// Returns the string value with every byte written as a "\xHH" escape, where
// HH are two lowercase hexadecimal digits (high nibble first). The result is
// four characters per input byte.
std::string Constant::SValRepr() const {
  static const char kHexDigits[] = "0123456789abcdef";

  std::string repr;
  repr.reserve(4 * sval_->size());
  for (const char ch : *sval_) {
    const int c = ch;
    repr += '\\';
    repr += 'x';
    repr += kHexDigits[(c >> 4) & 0xf];
    repr += kHexDigits[c & 0xf];
  }
  return repr;
}
|
|
|
|
|
|
/*
|
|
* TempVar
|
|
*/
|
|
|
|
// Creates a temporary variable of the given type in tempVarPool.
TempVar* TempVar::New(QualType type) {
  TempVar* node = new (tempVarPool.Alloc()) TempVar(type);
  node->pool_ = &tempVarPool;
  return node;
}
|
|
|
|
|
|
/*
|
|
* Statement
|
|
*/
|
|
|
|
// Creates an empty statement in emptyStmtPool.
EmptyStmt* EmptyStmt::New() {
  EmptyStmt* node = new (emptyStmtPool.Alloc()) EmptyStmt();
  node->pool_ = &emptyStmtPool;
  return node;
}
|
|
|
|
|
|
// The else stmt could be null
|
|
IfStmt* IfStmt::New(Expr* cond, Stmt* then, Stmt* els) {
|
|
auto ret = new (ifStmtPool.Alloc()) IfStmt(cond, then, els);
|
|
ret->pool_ = &ifStmtPool;
|
|
return ret;
|
|
}
|
|
|
|
|
|
// Creates a compound statement in compoundStmtPool from a list of
// statements and the scope associated with the block.
CompoundStmt* CompoundStmt::New(std::list<Stmt*>& stmts, ::Scope* scope) {
  CompoundStmt* node = new (compoundStmtPool.Alloc()) CompoundStmt(stmts, scope);
  node->pool_ = &compoundStmtPool;
  return node;
}
|
|
|
|
// Creates a for statement in forStmtPool from its body, init statement,
// condition and step expression.
ForStmt* ForStmt::New(Stmt* body, Stmt* init, Expr* cond, Expr* step) {
  ForStmt* node = new (forStmtPool.Alloc()) ForStmt(body, init, cond, step);
  node->pool_ = &forStmtPool;
  return node;
}
|
|
|
|
// Creates a jump statement targeting `label` in jumpStmtPool.
JumpStmt* JumpStmt::New(LabelStmt* label) {
  JumpStmt* node = new (jumpStmtPool.Alloc()) JumpStmt(label);
  node->pool_ = &jumpStmtPool;
  return node;
}
|
|
|
|
|
|
// Creates a return statement for `expr` in returnStmtPool.
ReturnStmt* ReturnStmt::New(Expr* expr) {
  ReturnStmt* node = new (returnStmtPool.Alloc()) ReturnStmt(expr);
  node->pool_ = &returnStmtPool;
  return node;
}
|
|
|
|
|
|
// Creates a label statement in labelStmtPool.
LabelStmt* LabelStmt::New() {
  LabelStmt* node = new (labelStmtPool.Alloc()) LabelStmt();
  node->pool_ = &labelStmtPool;
  return node;
}
|
|
|
|
|
|
// Creates a function definition in funcDefPool from the function's
// identifier and its return label.
FuncDef* FuncDef::New(Identifier* ident, LabelStmt* retLabel) {
  FuncDef* node = new (funcDefPool.Alloc()) FuncDef(ident, retLabel);
  node->pool_ = &funcDefPool;
  return node;
}
|
|
|
|
|
|
// Orders initializers by offset first and, for equal offsets, by the start
// of their bit field.
bool Initializer::operator<(const Initializer& rhs) const {
  if (offset_ != rhs.offset_)
    return offset_ < rhs.offset_;
  return bitFieldBegin_ < rhs.bitFieldBegin_;
}
|