[lang] more progress on parser

This commit is contained in:
Philippe Tillet
2019-08-19 20:56:39 -07:00
parent 0970fe12dd
commit bc11e31419
30 changed files with 10862 additions and 158 deletions

View File

@@ -62,15 +62,15 @@ endif()
# Triton # Triton
file(GLOB_RECURSE LIBTRITON_SRC lib/*.cpp) file(GLOB_RECURSE LIBTRITON_SRC lib/*.cpp lib/*.cc)
add_library(triton SHARED ${LIBTRITON_SRC} ${EIGHTCC_SRC} ${PYTHON_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS}) add_library(triton SHARED ${LIBTRITON_SRC} ${EIGHTCC_SRC} ${PYTHON_SRC} ${BISON_Parser_OUTPUTS} ${FLEX_Lexer_OUTPUTS})
target_link_libraries(triton LLVM) target_link_libraries(triton LLVM)
# Warning level # Warning level
if(MSVC) #if(MSVC)
target_compile_options(triton PRIVATE /W4) # target_compile_options(triton PRIVATE /W4)
else() #else()
target_compile_options(triton PRIVATE -Wno-unused-parameter -Wall -Wextra -pedantic) # target_compile_options(triton PRIVATE -Wno-unused-parameter -Wall -Wextra -pedantic)
endif() #endif()

View File

@@ -78,19 +78,23 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")"; std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb) + ")";
std::string res = std::string res =
R"( R"(
const tunable int TM = {128}; #define TM 128
const tunable int TN = {128}; #define TN 128
const tunable int TK = {32}; #define TK 32
void matmul(restrict read_only align(16) )" + a_ty + R"( *A, extern int get_program_id(int);
restrict read_only align(16) )" + b_ty + R"( *B,
restrict read_only align(16) )" + c_ty + R"( *C, void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
restrict )" + b_ty + R"( * B __attribute__((readonly, aligned(16))),
restrict )" + c_ty + R"( * C __attribute__((aligned(16))),
int M, int N, int K, int M, int N, int K,
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc) { int lda __attribute__((multiple_of(8))),
int ldb __attribute__((multiple_of(8))),
int ldc) {
int ridx = get_program_id(0); int ridx = get_program_id(0);
int ridy = get_program_id(1); int ridy = get_program_id(1);
int rxa[TM] = ridx * TM + (0 ... TM); int rxa[{TM, TN}] = ridx * TM + 0 ... TM;
int ryb[TN] = ridy * TN + (0 ... TN); int ryb[TN] = ridy * TN + 0 ... TN;
int rka[TK] = 0 ... TK; int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK; int rkb[TK] = 0 ... TK;
float xc[)" + XCS + R"(] = 0; float xc[)" + XCS + R"(] = 0;
@@ -112,7 +116,7 @@ void matmul(restrict read_only align(16) )" + a_ty + R"( *A,
bool checkc0[TM] = rxc < M; bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N; bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = c; *pc = c;
} }
)"; )";
return res; return res;

View File

@@ -0,0 +1,743 @@
#ifndef _WGTCC_AST_H_
#define _WGTCC_AST_H_
#include "error.h"
#include "token.h"
#include "type.h"
#include <cassert>
#include <list>
#include <memory>
#include <string>
class Visitor;
template<typename T> class Evaluator;
class AddrEvaluator;
class Generator;
class Scope;
class Parser;
class ASTNode;
class Token;
class TokenSequence;
// Expressions
class Expr;
class BinaryOp;
class UnaryOp;
class ConditionalOp;
class FuncCall;
class TempVar;
class Constant;
class Identifier;
class Object;
struct Initializer;
class Declaration;
class Enumerator;
// Statements
class Stmt;
class IfStmt;
class JumpStmt;
class LabelStmt;
class EmptyStmt;
class CompoundStmt;
class FuncDef;
class TranslationUnit;
/*
* AST Node
*/
class ASTNode {
public:
  virtual ~ASTNode() = default;

  // Visitor entry point; pure virtual, so every concrete node supplies it.
  virtual void Accept(Visitor* v) = 0;

protected:
  // Construction is restricted to subclasses.
  ASTNode() {}

  // Starts out null. NOTE(review): presumably the pool this node was
  // allocated from -- confirm against the New() factories.
  MemPool* pool_ {nullptr};
};
using ExtDecl = ASTNode;
/*
* Statements
*/
class Stmt : public ASTNode {
public:
  virtual ~Stmt() = default;

protected:
  // Common base of all statements; only subclasses may construct it.
  Stmt() {}
};
// A statement that carries no data of its own.
class EmptyStmt : public Stmt {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static EmptyStmt* New();

  virtual ~EmptyStmt() = default;
  virtual void Accept(Visitor* v);

protected:
  EmptyStmt() {}
};
// A label; each instance receives its own integer tag on construction.
class LabelStmt : public Stmt {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static LabelStmt* New();

  ~LabelStmt() = default;
  virtual void Accept(Visitor* v);

  // Textual form of the label: ".L" followed by the numeric tag.
  std::string Repr() const {
    std::string repr(".L");
    repr += std::to_string(tag_);
    return repr;
  }

protected:
  LabelStmt(): tag_(GenTag()) {}

private:
  // Returns 1 on the first call and one more on each subsequent call.
  static int GenTag() {
    static int counter = 0;
    counter += 1;
    return counter;
  }

  // An integer tag is stored instead of using a string directly
  int tag_;
};
// if (cond) then [else els]; the else branch may be null.
class IfStmt : public Stmt {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static IfStmt* New(Expr* cond, Stmt* then, Stmt* els=nullptr);

  virtual ~IfStmt() = default;
  virtual void Accept(Visitor* v);

protected:
  IfStmt(Expr* cond, Stmt* then, Stmt* els = nullptr)
      : cond_(cond),
        then_(then),
        else_(els) {}

private:
  Expr* cond_;   // condition expression
  Stmt* then_;   // statement for the true branch
  Stmt* else_;   // statement for the false branch (nullptr by default)
};
// A jump that refers to a LabelStmt; the target can be replaced later.
class JumpStmt : public Stmt {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static JumpStmt* New(LabelStmt* label);

  virtual ~JumpStmt() = default;
  virtual void Accept(Visitor* v);

  // Replaces the jump target.
  void SetLabel(LabelStmt* label) { label_ = label; }

protected:
  JumpStmt(LabelStmt* label): label_(label) {}

private:
  LabelStmt* label_;
};
// A return statement holding the expression passed to the factory.
class ReturnStmt: public Stmt {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static ReturnStmt* New(Expr* expr);

  virtual ~ReturnStmt() = default;
  virtual void Accept(Visitor* v);

protected:
  ReturnStmt(::Expr* expr): expr_(expr) {}

private:
  ::Expr* expr_;
};
using StmtList = std::list<Stmt*>;
// A list of statements together with an optional scope pointer.
class CompoundStmt : public Stmt {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static CompoundStmt* New(StmtList& stmts, ::Scope* scope=nullptr);

  virtual ~CompoundStmt() = default;
  virtual void Accept(Visitor* v);

  // Mutable access to the contained statements.
  StmtList& Stmts() { return stmts_; }

  // The scope passed at construction (nullptr by default).
  ::Scope* Scope() { return scope_; }

protected:
  // The statement list is copied into the node.
  CompoundStmt(const StmtList& stmts, ::Scope* scope=nullptr)
      : stmts_(stmts),
        scope_(scope) {}

private:
  StmtList stmts_;
  ::Scope* scope_;
};
// One initializer entry: a type, an offset, an optional bit-field range and
// the initializing expression.
struct Initializer {
  Initializer(Type* type,
              int offset,
              Expr* expr,
              unsigned char bitFieldBegin=0,
              unsigned char bitFieldWidth=0)
      : type_(type),
        offset_(offset),
        bitFieldBegin_(bitFieldBegin),
        bitFieldWidth_(bitFieldWidth),
        expr_(expr) {}

  // Ordering used by InitList (a std::set); defined outside this header.
  bool operator<(const Initializer& rhs) const;

  // Either the type of the object itself or the type of the member
  // being initialized
  Type* type_;
  int offset_;
  unsigned char bitFieldBegin_;
  unsigned char bitFieldWidth_;
  Expr* expr_;
};
using InitList = std::set<Initializer>;
// Declaration of an Object together with its set of initializers.
class Declaration: public Stmt {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static Declaration* New(Object* obj);

  virtual ~Declaration() = default;
  virtual void Accept(Visitor* v);

  // Mutable access to the initializer set.
  InitList& Inits() { return inits_; }

  // The declared object.
  Object* Obj() { return obj_; }

  // Defined outside this header.
  void AddInit(Initializer init);

protected:
  Declaration(Object* obj): obj_(obj) {}

  Object* obj_;
  InitList inits_;
};
/*
* Expr
* BinaryOp
* UnaryOp
* ConditionalOp
* FuncCall
* Constant
* Identifier
* Object
* TempVar
*/
// Base class of all expressions: stores the originating token and the
// (possibly not yet determined) qualified type.
class Expr : public Stmt {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;
  friend class LValGenerator;

public:
  virtual ~Expr() = default;

  // Unqualified type pointer extracted from the stored QualType.
  ::Type* Type() { return type_.GetPtr(); }

  virtual bool IsLVal() = 0;
  virtual void TypeChecking() = 0;

  // Compatibility checks; defined outside this header.
  void EnsureCompatible(const QualType lhs, const QualType rhs) const;
  void EnsureCompatibleOrVoidPointer(const QualType lhs,
                                     const QualType rhs) const;

  const Token* Tok() const { return tok_; }
  void SetTok(const Token* tok) { tok_ = tok; }

  // Defined outside this header.
  static Expr* MayCast(Expr* expr);
  static Expr* MayCast(Expr* expr, QualType desType);

  // False unless a subclass overrides it.
  virtual bool IsNullPointerConstant() const { return false; }

  // Qualifier queries forwarded to the stored QualType.
  bool IsConstQualified() const { return type_.IsConstQualified(); }
  bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
  bool IsVolatileQualified() const { return type_.IsVolatileQualified(); }

protected:
  // An expression may be constructed without specifying a type;
  // in that case the type should be evaluated in TypeChecking()
  Expr(const Token* tok, QualType type): tok_(tok), type_(type) {}

  const Token* tok_;
  QualType type_;
};
/*
 * '+', '-', '*', '/', '%', '<', '>', '<<', '>>', '|', '&', '^'
 * '=' (compound assignment operators are split into two operations)
 * '==', '!=', '<=', '>=',
 * '&&', '||'
 * '[' (subscript operator), '.' (member access operator)
 * ',' (comma operator),
 */
class BinaryOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValGenerator;
friend class Declaration;
public:
static BinaryOp* New(const Token* tok, Expr* lhs, Expr* rhs);
static BinaryOp* New(const Token* tok, int op, Expr* lhs, Expr* rhs);
virtual ~BinaryOp() {}
virtual void Accept(Visitor* v);
// Member ref operator is a lvalue
virtual bool IsLVal() {
switch (op_) {
case '.': return !Type()->ToArray() && lhs_->IsLVal();
case ']': return !Type()->ToArray();
default: return false;
}
}
ArithmType* Convert();
void Broadcast();
virtual void TypeChecking();
void SubScriptingOpTypeChecking();
void MemberRefOpTypeChecking();
void MultiOpTypeChecking();
void AdditiveOpTypeChecking();
void ShiftOpTypeChecking();
void RangeOpTypeChecking();
void RelationalOpTypeChecking();
void EqualityOpTypeChecking();
void BitwiseOpTypeChecking();
void LogicalOpTypeChecking();
void AssignOpTypeChecking();
void CommaOpTypeChecking();
protected:
BinaryOp(const Token* tok, int op, Expr* lhs, Expr* rhs)
: Expr(tok, nullptr), op_(op) {
lhs_ = lhs, rhs_ = rhs;
if (op != '.') {
lhs_ = MayCast(lhs);
rhs_ = MayCast(rhs);
}
}
int op_;
Expr* lhs_;
Expr* rhs_;
};
/*
* Unary Operator:
* '++' (prefix/postfix)
* '--' (prefix/postfix)
* '&' (ADDR)
* '*' (DEREF)
* '+' (PLUS)
* '-' (MINUS)
* '~'
* '!'
* CAST // like (int)3
*/
class UnaryOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValGenerator;
public:
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr);
virtual ~UnaryOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal();
ArithmType* Convert();
void TypeChecking();
void IncDecOpTypeChecking();
void AddrOpTypeChecking();
void DerefOpTypeChecking();
void UnaryArithmOpTypeChecking();
void CastOpTypeChecking();
protected:
UnaryOp(int op, Expr* operand, QualType type=nullptr)
: Expr(operand->Tok(), type), op_(op) {
operand_ = operand;
if (op_ != Token::CAST && op_ != Token::ADDR) {
operand_ = MayCast(operand);
}
}
int op_;
Expr* operand_;
};
// cond ? true false
// Ternary conditional expression; never an lvalue.
class ConditionalOp : public Expr {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static ConditionalOp* New(const Token* tok,
                            Expr* cond,
                            Expr* exprTrue,
                            Expr* exprFalse);

  virtual ~ConditionalOp() = default;
  virtual void Accept(Visitor* v);
  virtual bool IsLVal() { return false; }
  ArithmType* Convert();
  virtual void TypeChecking();

protected:
  // Reuses the condition's token, leaves the type unset and passes all
  // three operands through MayCast.
  ConditionalOp(Expr* cond, Expr* exprTrue, Expr* exprFalse)
      : Expr(cond->Tok(), nullptr),
        cond_(MayCast(cond)),
        exprTrue_(MayCast(exprTrue)),
        exprFalse_(MayCast(exprFalse)) {}

private:
  Expr* cond_;
  Expr* exprTrue_;
  Expr* exprFalse_;
};
// Call expression: a designator expression plus its argument list.
class FuncCall : public Expr {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  using ArgList = std::vector<Expr*>;

  // Factory; defined outside this header.
  static FuncCall* New(Expr* designator, const ArgList& args);

  ~FuncCall() = default;
  virtual void Accept(Visitor* v);

  // A function call is of course not an lvalue
  virtual bool IsLVal() { return false; }

  // Pointer to the stored argument list.
  ArgList* Args() { return &args_; }

  Expr* Designator() { return designator_; }

  // Spelling of this node's token (the designator's token).
  const std::string& Name() const { return tok_->str_; }

  // The designator's type viewed as a function type.
  ::FuncType* FuncType() { return designator_->Type()->ToFunc(); }

  virtual void TypeChecking();

protected:
  // Reuses the designator's token, leaves the type unset and copies
  // the argument list.
  FuncCall(Expr* designator, const ArgList& args)
      : Expr(designator->Tok(), nullptr),
        designator_(designator),
        args_(args) {}

  Expr* designator_;
  ArgList args_;
};
// Literal constant holding an integer, a floating-point value or a
// pointer to a string. All three share storage in one union.
class Constant: public Expr {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factories; defined outside this header.
  static Constant* New(const Token* tok, int tag, long val);
  static Constant* New(const Token* tok, int tag, double val);
  static Constant* New(const Token* tok, int tag, const std::string* val);

  ~Constant() {}
  virtual void Accept(Visitor* v);
  virtual bool IsLVal() { return false; }

  // Nothing to check for a constant.
  virtual void TypeChecking() {}

  // Raw accessors: each returns its union member without checking
  // which one is active.
  long IVal() const { return ival_; }
  double FVal() const { return fval_; }
  const std::string* SVal() const { return sval_; }

  // Defined outside this header.
  std::string SValRepr() const;

  // ".LC" followed by id_.
  std::string Repr() const {
    std::string repr(".LC");
    repr += std::to_string(id_);
    return repr;
  }

protected:
  Constant(const Token* tok, QualType type, long val)
      : Expr(tok, type), ival_(val) {}
  Constant(const Token* tok, QualType type, double val)
      : Expr(tok, type), fval_(val) {}
  // Only sval_ is set here; id_ is not assigned by this constructor.
  Constant(const Token* tok, QualType type, const std::string* val)
      : Expr(tok, type), sval_(val) {}

  union {
    long ival_;
    double fval_;
    struct {
      long id_;
      const std::string* sval_;
    };
  };
};
// Temporary variable: an lvalue with no token and its own integer tag.
class TempVar : public Expr {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static TempVar* New(QualType type);

  virtual ~TempVar() = default;
  virtual void Accept(Visitor* v);
  virtual bool IsLVal() { return true; }

  // Nothing to check: the type is supplied at construction.
  virtual void TypeChecking() {}

protected:
  TempVar(QualType type): Expr(nullptr, type), tag_(GenTag()) {}

private:
  // Returns 1 on the first call and one more on each subsequent call.
  static int GenTag() {
    static int counter = 0;
    counter += 1;
    return counter;
  }

  int tag_;
};
// Linkage property of an identifier.
enum Linkage {
  L_NONE,      // no linkage
  L_EXTERNAL,  // external linkage
  L_INTERNAL,  // internal linkage
};
// A named entity: a token, a type and a linkage.
class Identifier: public Expr {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;
  friend class LValGenerator;

public:
  // Factory; defined outside this header.
  static Identifier* New(const Token* tok, QualType type, Linkage linkage);

  virtual ~Identifier() = default;
  virtual void Accept(Visitor* v);
  virtual bool IsLVal() { return false; }

  // Downcast helpers; the base implementation returns nullptr.
  virtual Object* ToObject() { return nullptr; }
  virtual Enumerator* ToEnumerator() { return nullptr; }

  // An identifier can be:
  //   object, struct/union/enum tag, typedef name, function, label.
  // Returns this when the identifier is neither an object nor an
  // enumerator and has no linkage; otherwise nullptr.
  Identifier* ToTypeName() {
    // A typename has no linkage
    // And a function has external or internal linkage
    const bool isTypeName =
        !ToObject() && !ToEnumerator() && linkage_ == L_NONE;
    return isTypeName ? this : nullptr;
  }

  // Spelling of the identifier's token.
  virtual const std::string Name() const { return tok_->str_; }

  enum Linkage Linkage() const { return linkage_; }
  void SetLinkage(enum Linkage linkage) { linkage_ = linkage; }

  // Nothing to check for a plain identifier.
  virtual void TypeChecking() {}

protected:
  Identifier(const Token* tok, QualType type, enum Linkage linkage)
      : Expr(tok, type), linkage_(linkage) {}

  // An identifier has property linkage
  enum Linkage linkage_;
};
// Enumeration constant: an int-typed identifier without linkage whose
// value is kept in an owned Constant node.
class Enumerator: public Identifier {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Factory; defined outside this header.
  static Enumerator* New(const Token* tok, int val);

  virtual ~Enumerator() = default;
  virtual void Accept(Visitor* v);
  virtual Enumerator* ToEnumerator() { return this; }

  // Integer value read back from the underlying constant.
  int Val() const { return cons_->IVal(); }

protected:
  Enumerator(const Token* tok, int val)
      : Identifier(tok, ArithmType::New(T_INT), L_NONE),
        cons_(Constant::New(tok, T_INT, static_cast<long>(val))) {}

  Constant* cons_;
};
// An identifier that denotes storage: carries storage class, offset,
// alignment, an optional declaration and optional bit-field placement.
class Object : public Identifier {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;
  friend class LValGenerator;

public:
  // Factories; defined outside this header.
  static Object* New(const Token* tok,
                     QualType type,
                     int storage=0,
                     enum Linkage linkage=L_NONE,
                     unsigned char bitFieldBegin=0,
                     unsigned char bitFieldWidth=0);
  static Object* NewAnony(const Token* tok,
                          QualType type,
                          int storage=0,
                          enum Linkage linkage=L_NONE,
                          unsigned char bitFieldBegin=0,
                          unsigned char bitFieldWidth=0);

  ~Object() {}
  virtual void Accept(Visitor* v);
  virtual Object* ToObject() { return this; }
  virtual bool IsLVal() {
    // TODO(wgtdkp): not all object is lval?
    return true;
  }

  // True when the storage has S_STATIC set or the object has any linkage.
  bool IsStatic() const {
    return (Storage() & S_STATIC) || (Linkage() != L_NONE);
  }

  int Storage() const { return storage_; }
  void SetStorage(int storage) { storage_ = storage; }

  int Align() const { return align_; }
  void SetAlign(int align) {
    assert(align > 0);
    // Allowing reduce alignment to implement __attribute__((packed))
    //if (align < align_)
    //  Error(this, "alignment specifier cannot reduce alignment");
    align_ = align;
  }

  int Offset() const { return offset_; }
  void SetOffset(int offset) { offset_ = offset; }

  Declaration* Decl() { return decl_; }
  void SetDecl(Declaration* decl) { decl_ = decl; }

  unsigned char BitFieldBegin() const { return bitFieldBegin_; }
  unsigned char BitFieldEnd() const { return bitFieldBegin_ + bitFieldWidth_; }
  unsigned char BitFieldWidth() const { return bitFieldWidth_; }

  // Mask covering the bits occupied by the given bit-field object.
  static unsigned long BitFieldMask(Object* bitField) {
    return BitFieldMask(bitField->bitFieldBegin_, bitField->bitFieldWidth_);
  }

  // Mask with `width` consecutive bits set, starting at bit `begin`.
  // Returns 0 for a zero width (the value used for "not a bitfield").
  static unsigned long BitFieldMask(unsigned char begin, unsigned char width) {
    // With width == 0 the expression below would shift by 64 bits,
    // which is undefined behaviour; an empty field has an empty mask.
    if (width == 0)
      return 0;
    auto end = begin + width;
    return ((0xFFFFFFFFFFFFFFFFUL << (64 - end)) >> (64 - width)) << begin;
  }

  // True when a declaration is attached and it has at least one initializer.
  bool HasInit() const { return decl_ && decl_->Inits().size(); }
  bool Anonymous() const { return anonymous_; }
  virtual const std::string Name() const { return Identifier::Name(); }

  // Name used to refer to the object: "anonymous.<id>" for anonymous
  // objects, "<name>.<id>" for objects without linkage, else the plain name.
  // Only valid for static or anonymous objects (asserted).
  std::string Repr() const {
    assert(IsStatic() || anonymous_);
    if (anonymous_)
      return "anonymous." + std::to_string(id_);
    if (linkage_ == L_NONE)
      return Name() + "." + std::to_string(id_);
    return Name();
  }

protected:
  // Offset starts at 0, alignment is taken from the type, and the object
  // starts out non-anonymous with no declaration attached.
  Object(const Token* tok,
         QualType type,
         int storage=0,
         enum Linkage linkage=L_NONE,
         unsigned char bitFieldBegin=0,
         unsigned char bitFieldWidth=0)
      : Identifier(tok, type, linkage),
        storage_(storage),
        offset_(0),
        align_(type->Align()),
        decl_(nullptr),
        bitFieldBegin_(bitFieldBegin),
        bitFieldWidth_(bitFieldWidth),
        anonymous_(false) {}

private:
  int storage_;
  int offset_;
  int align_;
  Declaration* decl_;
  unsigned char bitFieldBegin_;
  // 0 means it's not a bitfield
  unsigned char bitFieldWidth_;
  bool anonymous_;
  long id_ {0};
};
/*
* Declaration
*/
// A function definition: the function's identifier, a return label and
// the body, which is attached later through SetBody().
class FuncDef : public ExtDecl {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  using ParamList = std::vector<Object*>;

public:
  // Factory; defined outside this header.
  static FuncDef* New(Identifier* ident, LabelStmt* retLabel);

  virtual ~FuncDef() {}
  virtual void Accept(Visitor* v);

  // The identifier's type viewed as a function type.
  ::FuncType* FuncType() { return ident_->Type()->ToFunc(); }

  // The body, or nullptr while none has been set.
  CompoundStmt* Body() { return body_; }
  void SetBody(CompoundStmt* body) { body_ = body; }

  std::string Name() const { return ident_->Name(); }
  enum Linkage Linkage() { return ident_->Linkage(); }

protected:
  // body_ is explicitly nulled so Body() is well-defined before SetBody().
  FuncDef(Identifier* ident, LabelStmt* retLabel)
      : ident_(ident), retLabel_(retLabel), body_(nullptr) {}

private:
  Identifier* ident_;
  LabelStmt* retLabel_;
  CompoundStmt* body_;
};
using ExtDeclList = std::list<ExtDecl*>;
// Root node: an ordered list of external declarations.
class TranslationUnit : public ASTNode {
  template<typename T> friend class Evaluator;
  friend class AddrEvaluator;
  friend class Generator;

public:
  // Allocates a new, empty unit with plain `new`.
  static TranslationUnit* New() {
    TranslationUnit* unit = new TranslationUnit();
    return unit;
  }

  virtual ~TranslationUnit() = default;
  virtual void Accept(Visitor* v);

  // Appends a declaration at the end of the list.
  void Add(ExtDecl* extDecl) { extDecls_.push_back(extDecl); }

  ExtDeclList& ExtDecls() { return extDecls_; }
  const ExtDeclList& ExtDecls() const { return extDecls_; }

private:
  TranslationUnit() {}

  ExtDeclList extDecls_;
};
#endif

View File

@@ -0,0 +1,274 @@
#ifndef _WGTCC_CODE_GEN_H_
#define _WGTCC_CODE_GEN_H_
#include "ast.h"
#include "visitor.h"
class Parser;
struct Addr;
struct ROData;
template<> class Evaluator<Addr>;
struct StaticInitializer;
using TypeList = std::vector<Type*>;
using LocationList = std::vector<std::string>;
using RODataList = std::vector<ROData>;
using StaticInitList = std::vector<StaticInitializer>;
// Classification assigned to a parameter.
// NOTE(review): the enumerator names match the System V x86-64 ABI
// argument classes -- confirm against GetParamLocations().
enum class ParamClass {
  INTEGER,
  SSE,
  SSEUP,
  X87,
  X87_UP,
  COMPLEX_X87,
  NO_CLASS,
  MEMORY
};
// A list of location strings plus two counters.
// NOTE(review): regCnt_/xregCnt_ presumably count general-purpose and
// SSE registers used -- confirm in GetParamLocations().
struct ParamLocations {
  LocationList locs_;
  size_t regCnt_;
  size_t xregCnt_;
};
// A read-only data item: either an integer with an alignment or a string
// (alignment 1). Every instance gets a fresh ".LC<n>" label.
struct ROData {
  // Integer item; sval_ stays empty.
  ROData(long ival, int align): ival_(ival), align_(align) {
    label_ = ".LC" + std::to_string(GenTag());
  }

  // String item; ival_ is zeroed so it never holds an indeterminate value.
  explicit ROData(const std::string& sval): sval_(sval), ival_(0), align_(1) {
    label_ = ".LC" + std::to_string(GenTag());
  }

  ~ROData() {}

  std::string sval_;
  long ival_;
  int align_;
  std::string label_;

private:
  // Returns 0 on the first call and one more on each subsequent call.
  static long GenTag() {
    static long tag = 0;
    return tag++;
  }
};
// Address description: label, base register name and offset, plus an
// optional bit-field range (both zero by default).
struct ObjectAddr {
  // Address relative to "%rbp" with an empty label.
  explicit ObjectAddr(int offset)
      : label_(""), base_("%rbp"), offset_(offset) {}

  ObjectAddr(const std::string& label, const std::string& base, int offset)
      : label_(label),
        base_(base),
        offset_(offset) {}

  // Textual form of the address; defined outside this header.
  std::string Repr() const;

  std::string label_;
  std::string base_;
  int offset_;
  unsigned char bitFieldBegin_ {0};
  unsigned char bitFieldWidth_ {0};
};
// One piece of a static initialization: offset, width, integer value and
// a label string.
struct StaticInitializer {
  int offset_;
  int width_;
  long val_;
  std::string label_;
};
// Visitor that writes text to a FILE* through the Emit helpers.
// Parser, output file and bookkeeping state are static members, so they
// are shared by every Generator instance.
class Generator: public Visitor {
  friend class Evaluator<Addr>;

public:
  Generator() {}

  // Dispatch helpers: forward to the node's Accept with this visitor.
  virtual void Visit(ASTNode* node) { node->Accept(this); }
  void VisitExpr(Expr* expr) { expr->Accept(this); }
  void VisitStmt(Stmt* stmt) { stmt->Accept(this); }

  // Expression
  virtual void VisitBinaryOp(BinaryOp* binaryOp);
  virtual void VisitUnaryOp(UnaryOp* unaryOp);
  virtual void VisitConditionalOp(ConditionalOp* condOp);
  virtual void VisitFuncCall(FuncCall* funcCall);
  virtual void VisitObject(Object* obj);
  virtual void VisitEnumerator(Enumerator* enumer);
  virtual void VisitIdentifier(Identifier* ident);
  virtual void VisitConstant(Constant* cons);
  virtual void VisitTempVar(TempVar* tempVar);

  // Statement
  virtual void VisitDeclaration(Declaration* init);
  virtual void VisitEmptyStmt(EmptyStmt* emptyStmt);
  virtual void VisitIfStmt(IfStmt* ifStmt);
  virtual void VisitJumpStmt(JumpStmt* jumpStmt);
  virtual void VisitReturnStmt(ReturnStmt* returnStmt);
  virtual void VisitLabelStmt(LabelStmt* labelStmt);
  virtual void VisitCompoundStmt(CompoundStmt* compoundStmt);
  virtual void VisitFuncDef(FuncDef* funcDef);
  virtual void VisitTranslationUnit(TranslationUnit* unit);

  // Stores the parser and the output file in the shared static state.
  static void SetInOut(Parser* parser, FILE* outFile) {
    parser_ = parser;
    outFile_ = outFile;
  }

  void Gen();

protected:
  // Binary
  void GenCommaOp(BinaryOp* comma);
  void GenMemberRefOp(BinaryOp* binaryOp);
  void GenAndOp(BinaryOp* binaryOp);
  void GenOrOp(BinaryOp* binaryOp);
  void GenAddOp(BinaryOp* binaryOp);
  void GenSubOp(BinaryOp* binaryOp);
  void GenAssignOp(BinaryOp* assign);
  void GenCastOp(UnaryOp* cast);
  void GenDerefOp(UnaryOp* deref);
  void GenMinusOp(UnaryOp* minus);
  void GenPointerArithm(BinaryOp* binary);
  void GenDivOp(bool flt, bool sign, int width, int op);
  void GenMulOp(int width, bool flt, bool sign);
  void GenCompOp(int width, bool flt, const char* set);
  void GenCompZero(Type* type);

  // Unary
  void GenIncDec(Expr* operand, bool postfix, const std::string& inst);

  StaticInitializer GetStaticInit(InitList::iterator& iter,
                                  InitList::iterator end,
                                  int offset);
  void GenStaticDecl(Declaration* decl);
  void GenSaveArea();
  void GenBuiltin(FuncCall* funcCall);
  void AllocObjects(Scope* scope,
                    const FuncDef::ParamList& params=FuncDef::ParamList());
  void CopyStruct(ObjectAddr desAddr, int width);
  std::string ConsLabel(Constant* cons);
  ParamLocations GetParamLocations(const TypeList& types, bool retStruct);
  void GetParamRegOffsets(int& gpOffset,
                          int& fpOffset,
                          int& overflow,
                          FuncType* funcType);

  // Writes one line to the output file: a tab, the text, a newline.
  void Emit(const std::string& str) {
    fprintf(outFile_, "\t%s\n", str.c_str());
  }

  // "<inst>\t<src>, <des>"
  void Emit(const std::string& inst,
            const std::string& src,
            const std::string& des) {
    std::string line(inst);
    line += "\t";
    line += src;
    line += ", ";
    line += des;
    Emit(line);
  }

  // "<inst>\t$<imm>, <reg>"
  void Emit(const std::string& inst,
            int imm,
            const std::string& reg) {
    std::string line(inst);
    line += "\t$";
    line += std::to_string(imm);
    line += ", ";
    line += reg;
    Emit(line);
  }

  // "<inst>\t<des>"
  void Emit(const std::string& inst,
            const std::string& des) {
    std::string line(inst);
    line += "\t";
    line += des;
    Emit(line);
  }

  // "<inst>\t<label text>"
  void Emit(const std::string& inst,
            const LabelStmt* label) {
    std::string line(inst);
    line += "\t";
    line += label->Repr();
    Emit(line);
  }

  // The ObjectAddr overloads convert the address with Repr() and then
  // use the three-string form.
  void Emit(const std::string& inst,
            const ObjectAddr& src,
            const ObjectAddr& des) {
    const std::string from = src.Repr();
    const std::string to = des.Repr();
    Emit(inst, from, to);
  }

  void Emit(const std::string& inst,
            const std::string& src,
            const ObjectAddr& des) {
    const std::string to = des.Repr();
    Emit(inst, src, to);
  }

  void Emit(const std::string& inst,
            const ObjectAddr& src,
            const std::string& des) {
    const std::string from = src.Repr();
    Emit(inst, from, des);
  }

  void EmitLabel(const std::string& label);
  void EmitZero(ObjectAddr addr, int width);
  void EmitLoad(const std::string& addr, Type* type);
  void EmitLoad(const std::string& addr, int width, bool flt);
  void EmitStore(const ObjectAddr& addr, Type* type);
  void EmitStore(const std::string& addr, Type* type);
  void EmitStore(const std::string& addr, int width, bool flt);
  void EmitLoadBitField(const std::string& addr, Object* bitField);
  void EmitStoreBitField(const ObjectAddr& addr, Type* type);
  void EmitLoc(Expr* expr);
  int Push(Type* type);
  int Push(const std::string& reg);
  int Pop(const std::string& reg);
  void Spill(bool flt);
  void Restore(bool flt);
  void Save(bool flt);
  void Exchange(bool flt);

protected:
  static const std::string* last_file;
  static Parser* parser_;
  static FILE* outFile_;
  static RODataList rodatas_;
  static int offset_;
  // The address that store the register %rdi,
  //   when the return value is a struct/union
  static int retAddrOffset_;
  static FuncDef* curFunc_;
  static std::vector<Declaration*> staticDecls_;
};
// Generator variant whose result is an ObjectAddr, obtained via GenExpr().
// Conditional, call, enumerator and constant nodes are rejected with an
// assertion.
class LValGenerator: public Generator {
public:
  LValGenerator() {}

  // Expression
  virtual void VisitBinaryOp(BinaryOp* binaryOp);
  virtual void VisitUnaryOp(UnaryOp* unaryOp);
  virtual void VisitObject(Object* obj);
  virtual void VisitIdentifier(Identifier* ident);

  // These node kinds must never reach this visitor.
  virtual void VisitConditionalOp(ConditionalOp* condOp) { assert(false); }
  virtual void VisitFuncCall(FuncCall* funcCall) { assert(false); }
  virtual void VisitEnumerator(Enumerator* enumer) { assert(false); }
  virtual void VisitConstant(Constant* cons) { assert(false); }

  virtual void VisitTempVar(TempVar* tempVar);

  // Visits the expression and returns a copy of the address left in addr_.
  ObjectAddr GenExpr(Expr* expr) {
    expr->Accept(this);
    return addr_;
  }

private:
  // Starts as an empty label, empty base and zero offset.
  ObjectAddr addr_ {"", "", 0};
};
#endif

View File

@@ -0,0 +1,162 @@
#ifndef _WGTCC_CPP_H_
#define _WGTCC_CPP_H_
#include "scanner.h"
#include <cstdio>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <string>
class Macro;
struct CondDirective;
using MacroMap = std::map<std::string, Macro>;
using ParamList = std::list<std::string>;
using ParamMap = std::map<std::string, TokenSequence>;
using PPCondStack = std::stack<CondDirective>;
using PathList = std::list<std::string>;
// A preprocessor macro: its replacement sequence plus flags telling
// whether it is function-like, variadic and/or predefined.
class Macro {
public:
  // Object-like macro (no parameter list).
  Macro(const TokenSequence& repSeq, bool preDef=false)
      : funcLike_(false),
        variadic_(false),
        preDef_(preDef),
        repSeq_(repSeq) {}

  // Function-like macro with the given parameter names.
  Macro(bool variadic, ParamList& params,
        TokenSequence& repSeq, bool preDef=false)
      : funcLike_(true),
        variadic_(variadic),
        preDef_(preDef),
        params_(params),
        repSeq_(repSeq) {}

  ~Macro() = default;

  bool FuncLike() { return funcLike_; }
  // An object-like macro is simply one that is not function-like.
  bool ObjLike() { return !FuncLike(); }
  bool Variadic() { return variadic_; }
  bool PreDef() { return preDef_; }
  ParamList& Params() { return params_; }

  // Defined outside this header.
  TokenSequence RepSeq(const std::string* filename, unsigned line);

private:
  bool funcLike_;
  bool variadic_;
  bool preDef_;
  ParamList params_;
  TokenSequence repSeq_;
};
// One entry of the conditional-directive stack.
// NeedExpand() requires both enabled_ and cond_ of the top entry to be true.
struct CondDirective {
  int tag_;
  bool enabled_;
  bool cond_;
};
class Preprocessor {
public:
Preprocessor(const std::string* str, bool isSrc = true)
: curLine_(1), lineLine_(0), curCond_(true), fName_(nullptr), fSrc_(nullptr) {
if(isSrc)
fSrc_ = str;
else
fName_ = str;
// Add predefined
Init();
}
~Preprocessor() {}
void Finalize(TokenSequence os);
void Process(TokenSequence& os);
void Expand(TokenSequence& os, TokenSequence is, bool inCond=false);
void Subst(TokenSequence& os, TokenSequence is,
bool leadingWS, const HideSet& hs, ParamMap& params);
void Glue(TokenSequence& os, TokenSequence is);
void Glue(TokenSequence& os, const Token* tok);
const Token* Stringize(TokenSequence is);
void Stringize(std::string& str, TokenSequence is);
const Token* ParseActualParam(TokenSequence& is, Macro* macro, ParamMap& paramMap);
int GetDirective(TokenSequence& is);
const Token* EvalDefOp(TokenSequence& is);
void ReplaceIdent(TokenSequence& is);
void ParseDirective(TokenSequence& os, TokenSequence& is, int directive);
void ParseIf(TokenSequence ls);
void ParseIfdef(TokenSequence ls);
void ParseIfndef(TokenSequence ls);
void ParseElif(TokenSequence ls);
void ParseElse(TokenSequence ls);
void ParseEndif(TokenSequence ls);
void ParseInclude(TokenSequence& is, TokenSequence ls);
void ParseDef(TokenSequence ls);
void ParseUndef(TokenSequence ls);
void ParseLine(TokenSequence ls);
void ParseError(TokenSequence ls);
void ParsePragma(TokenSequence ls);
void IncludeSrc(TokenSequence& is, const std::string* text, const std::string* filename);
void IncludeFile(TokenSequence& is, const std::string* filename);
bool ParseIdentList(ParamList& params, TokenSequence& is);
Macro* FindMacro(const std::string& name) {
auto res = macroMap_.find(name);
if (res == macroMap_.end())
return nullptr;
return &res->second;
}
void AddMacro(const std::string& name,
std::string* text, bool preDef=false);
void AddMacro(const std::string& name, const Macro& macro) {
auto res = macroMap_.find(name);
if (res != macroMap_.end()) {
// TODO(wgtdkp): give warning
macroMap_.erase(res);
}
macroMap_.insert(std::make_pair(name, macro));
}
void RemoveMacro(const std::string& name) {
auto res = macroMap_.find(name);
if (res == macroMap_.end())
return;
if(res->second.PreDef()) // Cannot undef predefined macro
return;
macroMap_.erase(res);
}
std::string* SearchFile(const std::string& name,
const bool libHeader,
bool next,
const std::string& curPath);
void AddSearchPath(std::string path);
void HandleTheFileMacro(TokenSequence& os, const Token* macro);
void HandleTheLineMacro(TokenSequence& os, const Token* macro);
void UpdateFirstTokenLine(TokenSequence ts);
bool NeedExpand() const {
if (ppCondStack_.empty())
return true;
auto top = ppCondStack_.top();
return top.enabled_ && top.cond_;
}
private:
void Init();
PPCondStack ppCondStack_;
unsigned curLine_;
unsigned lineLine_;
bool curCond_;
MacroMap macroMap_;
PathList searchPaths_;
const std::string* fName_;
const std::string* fSrc_;
};
#endif

View File

@@ -0,0 +1,20 @@
#ifndef _WGTCC_ENCODING_H_
#define _WGTCC_ENCODING_H_
#include <string>
// Encoding tag. NOTE(review): presumably the encoding prefix of a
// character or string literal -- confirm against the scanner.
enum class Encoding {
  NONE,
  CHAR16,
  CHAR32,
  UTF8,
  WCHAR
};
// String encoding helpers; all three take the string by non-const
// reference and are defined outside this header.
// NOTE(review): judging by the name, re-encodes `str` in place as UTF-16 -- confirm.
void ConvertToUTF16(std::string& str);
// NOTE(review): judging by the name, re-encodes `str` in place as UTF-32 -- confirm.
void ConvertToUTF32(std::string& str);
// NOTE(review): presumably appends the encoding of the universal
// character name with code point `c` to `str` -- confirm.
void AppendUCN(std::string& str, int c);
#endif

View File

@@ -0,0 +1,15 @@
#ifndef _WGTCC_ERROR_H_
#define _WGTCC_ERROR_H_
struct SourceLocation;
class Token;
class Expr;
// Error reporting overloads. Each one is declared [[noreturn]], so control
// never comes back to the caller. The overloads differ only in the context
// attached to the message: none, a source location, a token, or an expression.
// NOTE(review): `format` and the variadic arguments are presumably
// printf-style -- confirm in the definitions.
[[noreturn]] void Error(const char* format, ...);
[[noreturn]] void Error(const SourceLocation& loc, const char* format, ...);
[[noreturn]] void Error(const Token* tok, const char* format, ...);
[[noreturn]] void Error(const Expr* expr, const char* format, ...);
#endif

View File

@@ -0,0 +1,120 @@
#ifndef _WGTCC_EVALUATOR_H_
#define _WGTCC_EVALUATOR_H_
#include "ast.h"
#include "error.h"
#include "visitor.h"
class Expr;
// Visitor that folds a constant expression into a value of type T.
// Calls, plain identifiers and objects are reported through Error();
// statement nodes are ignored.
template<typename T>
class Evaluator: public Visitor {
public:
  Evaluator() {}
  virtual ~Evaluator() {}

  // Defined outside this header.
  virtual void VisitBinaryOp(BinaryOp* binary);
  virtual void VisitUnaryOp(UnaryOp* unary);
  virtual void VisitConditionalOp(ConditionalOp* cond);

  virtual void VisitFuncCall(FuncCall* funcCall) {
    Error(funcCall, "expect constant expression");
  }
  virtual void VisitEnumerator(Enumerator* enumer) {
    val_ = static_cast<T>(enumer->Val());
  }
  virtual void VisitIdentifier(Identifier* ident) {
    Error(ident, "expect constant expression");
  }
  virtual void VisitObject(Object* obj) {
    Error(obj, "expect constant expression");
  }
  // Floating-point and integer constants are converted to T; any other
  // constant type trips the assertion.
  virtual void VisitConstant(Constant* cons) {
    if (cons->Type()->IsFloat()) {
      val_ = static_cast<T>(cons->FVal());
    } else if (cons->Type()->IsInteger()) {
      val_ = static_cast<T>(cons->IVal());
    } else {
      assert(false);
    }
  }
  virtual void VisitTempVar(TempVar* tempVar) { assert(false); }

  // We may should assert here
  virtual void VisitDeclaration(Declaration* init) {}
  virtual void VisitIfStmt(IfStmt* ifStmt) {}
  virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
  virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
  virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
  virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
  virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
  virtual void VisitFuncDef(FuncDef* funcDef) {}
  virtual void VisitTranslationUnit(TranslationUnit* unit) {}

  // Visits `expr` and returns the value left in val_.
  T Eval(Expr* expr) {
    expr->Accept(this);
    return val_;
  }

private:
  // Value-initialized so that Eval() never returns an indeterminate value
  // when the visited node is handled by one of the no-op visitors above.
  T val_ {};
};
// Result type of Evaluator<Addr>: a label string paired with an integer
// offset. Evaluator<Addr> fills label_ from an identifier name or an object's
// Repr(), and offset_ from an enumerator value or 0.
struct Addr {
std::string label_;
int offset_;
};
template<>
class Evaluator<Addr>: public Visitor {
public:
Evaluator<Addr>() {}
virtual ~Evaluator<Addr>() {}
virtual void VisitBinaryOp(BinaryOp* binary);
virtual void VisitUnaryOp(UnaryOp* unary);
virtual void VisitConditionalOp(ConditionalOp* cond);
virtual void VisitFuncCall(FuncCall* funcCall) {
Error(funcCall, "expect constant expression");
}
virtual void VisitEnumerator(Enumerator* enumer) {
addr_.offset_ = enumer->Val();
}
virtual void VisitIdentifier(Identifier* ident) {
addr_.label_ = ident->Name();
addr_.offset_ = 0;
}
virtual void VisitObject(Object* obj) {
if (!obj->IsStatic()) {
Error(obj, "expect static object");
}
addr_.label_ = obj->Repr();
addr_.offset_ = 0;
}
virtual void VisitConstant(Constant* cons);
virtual void VisitTempVar(TempVar* tempVar) { assert(false); }
// We may should assert here
virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
virtual void VisitFuncDef(FuncDef* funcDef) {}
virtual void VisitTranslationUnit(TranslationUnit* unit) {}
Addr Eval(Expr* expr) {
expr->Accept(this);
return addr_;
}
private:
Addr addr_;
};
#endif

View File

@@ -0,0 +1,101 @@
#ifndef _WGTCC_MEM_POOL_H_
#define _WGTCC_MEM_POOL_H_
#include <cstddef>
#include <vector>
// Abstract interface of a fixed-size object pool. Concrete pools hand out raw
// storage through Alloc(), take it back through Free(), and drop everything
// in Clear(). The pool is non-copyable.
class MemPool {
public:
    MemPool() : allocated_{0} {}
    virtual ~MemPool() = default;

    MemPool(const MemPool& other) = delete;
    MemPool& operator=(const MemPool& other) = delete;

    // Returns storage for one object.
    virtual void* Alloc() = 0;
    // Gives storage previously obtained from Alloc() back to the pool.
    virtual void Free(void* addr) = 0;
    // Releases everything the pool holds.
    virtual void Clear() = 0;

protected:
    // Bookkeeping counter maintained by the concrete pool.
    size_t allocated_;
};
// Pool of sizeof(T)-sized chunks. Storage is obtained in blocks of COUNT
// chunks (about 4 KiB per block); unused chunks are threaded into a singly
// linked free list headed by root_.
template <class T>
class MemPoolImp: public MemPool {
public:
    MemPoolImp() : root_(nullptr) {}
    virtual ~MemPoolImp() {}

    MemPoolImp(const MemPool& other) = delete;
    MemPoolImp& operator=(MemPool& other) = delete;

    virtual void* Alloc();
    virtual void Free(void* addr);
    virtual void Clear();

private:
    enum {
        // Number of chunks carved out of each block.
        COUNT = (4 * 1024) / sizeof(T)
    };

    // A chunk is either a free-list link or raw storage for one T.
    union Chunk {
        Chunk* next_;
        char mem_[sizeof(T)];
    };

    struct Block {
        // Links every chunk to its successor; the last chunk ends the list.
        Block() {
            Chunk* const last = chunks_ + (COUNT - 1);
            for (Chunk* cur = chunks_; cur != last; ++cur)
                cur->next_ = cur + 1;
            last->next_ = nullptr;
        }
        Chunk chunks_[COUNT];
    };

    // Every block allocated so far; deleted in Clear().
    std::vector<Block*> blocks_;
    // Head of the free list, or nullptr when no free chunk is left.
    Chunk* root_;
};
// Pops one chunk off the free list, growing the pool by one block first when
// the list is empty. Returns the chunk as untyped storage.
template <class T>
void* MemPoolImp<T>::Alloc() {
    if (root_ == nullptr) {
        // Out of space: allocate another block.
        Block* fresh = new Block();
        root_ = fresh->chunks_;
        // If blocks_ were a std::list, push_back would actually cost more.
        // So even though random access is not needed here (which makes
        // std::vector's copying on growth pure overhead), std::vector is
        // still the preferred container. Its exponential capacity growth
        // does waste some memory, though.
        blocks_.push_back(fresh);
    }
    Chunk* head = root_;
    root_ = head->next_;
    ++allocated_;
    return head;
}
// Pushes the chunk at `addr` back onto the front of the free list.
// A null `addr` is ignored.
template <class T>
void MemPoolImp<T>::Free(void* addr) {
    if (addr != nullptr) {
        Chunk* released = static_cast<Chunk*>(addr);
        released->next_ = root_;
        root_ = released;
        --allocated_;
    }
}
// Deletes every block owned by the pool and resets it to the empty state.
// All storage previously handed out by Alloc() becomes invalid.
template <class T>
void MemPoolImp<T>::Clear() {
    for (Block* owned : blocks_) {
        delete owned;
    }
    blocks_.clear();
    root_ = nullptr;
    allocated_ = 0;
}
#endif

View File

@@ -0,0 +1,244 @@
#ifndef _PARSER_H_
#define _PARSER_H_
#include "ast.h"
#include "encoding.h"
#include "error.h"
#include "mem_pool.h"
#include "scope.h"
#include "token.h"
#include <cassert>
#include <memory>
#include <stack>
class Preprocessor;
using TokenTypePair = std::pair<const Token*, QualType>;
// Parser over a TokenSequence that builds a TranslationUnit AST.
// The class keeps the current scope, the function being parsed, label
// bookkeeping and the break/continue/case targets of the enclosing statements.
// Generator is a friend and therefore sees the private state directly.
class Parser {
// Container aliases used by the parser's bookkeeping.
using LiteralList = std::vector<Constant*>;
using StaticObjectList = std::vector<Object*>;
using CaseLabelList = std::vector<std::pair<Constant*, LabelStmt*>>;
using LabelJumpList = std::list<std::pair<const Token*, JumpStmt*>>;
using LabelMap = std::map<std::string, LabelStmt*>;
friend class Generator;
public:
// Copies the token sequence, creates a fresh translation unit plus two root
// scopes (one S_BLOCK scope for external symbols, one S_FILE scope as the
// current scope), and registers this parser with its copy of the sequence.
explicit Parser(const TokenSequence& ts)
: unit_(TranslationUnit::New()),
ts_(ts),
externalSymbols_(new Scope(nullptr, S_BLOCK)),
errTok_(nullptr),
curScope_(new Scope(nullptr, S_FILE)),
curFunc_(nullptr),
breakDest_(nullptr),
continueDest_(nullptr),
caseLabels_(nullptr),
defaultLabel_(nullptr) {
ts_.SetParser(this);
}
// NOTE(review): the scopes allocated in the constructor are not released here.
~Parser() {}
// Constants and literals
Constant* ParseConstant(const Token* tok);
Constant* ParseFloat(const Token* tok);
Constant* ParseInteger(const Token* tok);
Constant* ParseCharacter(const Token* tok);
Encoding ParseLiteral(std::string& str, const Token* tok);
Constant* ConcatLiterals(const Token* tok);
Expr* ParseGeneric();
// Top-level entry points
void Parse();
void ParseTranslationUnit();
FuncDef* ParseFuncDef(Identifier* ident);
// Expressions
Expr* ParseExpr();
Expr* ParsePrimaryExpr();
QualType TryCompoundLiteral();
Object* ParseCompoundLiteral(QualType type);
Expr* ParsePostfixExpr();
Expr* ParsePostfixExprTail(Expr* primExpr);
Expr* ParseSubScripting(Expr* pointer);
BinaryOp* ParseMemberRef(const Token* tok, int op, Expr* lhs);
UnaryOp* ParsePostfixIncDec(const Token* tok, Expr* operand);
FuncCall* ParseFuncCall(Expr* caller);
Expr* ParseUnaryExpr();
Constant* ParseSizeof();
Constant* ParseAlignof();
UnaryOp* ParsePrefixIncDec(const Token* tok);
UnaryOp* ParseUnaryOp(const Token* tok, int op);
QualType ParseTypeName();
Expr* ParseCastExpr();
Expr* ParseRangeExpr();
Expr* ParseMultiplicativeExpr();
Expr* ParseAdditiveExpr();
Expr* ParseShiftExpr();
Expr* ParseRelationalExpr();
Expr* ParseEqualityExpr();
// NOTE(review): "Bitiwise" is a typo in the method name; it is part of the
// interface, so it is left unchanged here.
Expr* ParseBitiwiseAndExpr();
Expr* ParseBitwiseXorExpr();
Expr* ParseBitwiseOrExpr();
Expr* ParseLogicalAndExpr();
Expr* ParseLogicalOrExpr();
Expr* ParseConditionalExpr();
Expr* ParseCommaExpr();
Expr* ParseAssignExpr();
// Declarations
CompoundStmt* ParseDecl();
void ParseStaticAssert();
QualType ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec);
QualType ParseSpecQual();
int ParseAlignas();
Type* ParseStructUnionSpec(bool isStruct);
StructType* ParseStructUnionDecl(StructType* type);
void ParseBitField(StructType* structType, const Token* tok, QualType type);
Type* ParseEnumSpec();
Type* ParseEnumerator(ArithmType* type);
int ParseQual();
QualType ParsePointer(QualType typePointedTo);
TokenTypePair ParseDeclarator(QualType type);
QualType ParseArrayFuncDeclarator(const Token* ident, QualType base);
int ParseArrayLength();
TileType::ShapeInt ParseTileShape();
bool ParseParamList(FuncType::ParamList& params);
Object* ParseParamDecl();
QualType ParseAbstractDeclarator(QualType type);
Identifier* ParseDirectDeclarator(QualType type,
int storageSpec,
int funcSpec,
int align);
// Initializer
void ParseInitializer(Declaration* decl,
QualType type,
int offset,
bool designated=false,
bool forceBrace=false,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0);
void ParseArrayInitializer(Declaration* decl,
ArrayType* type,
int offset,
bool designated);
StructType::Iterator ParseStructDesignator(StructType* type,
const std::string& name);
void ParseStructInitializer(Declaration* decl,
StructType* type,
int offset,
bool designated);
bool ParseLiteralInitializer(Declaration* init,
ArrayType* type,
int offset);
Declaration* ParseInitDeclarator(Identifier* ident);
Declaration* ParseInitDeclaratorSub(Object* obj);
// Statements
Stmt* ParseStmt();
CompoundStmt* ParseCompoundStmt(FuncType* funcType=nullptr);
IfStmt* ParseIfStmt();
CompoundStmt* ParseSwitchStmt();
CompoundStmt* ParseWhileStmt();
CompoundStmt* ParseDoStmt();
CompoundStmt* ParseForStmt();
JumpStmt* ParseGotoStmt();
JumpStmt* ParseContinueStmt();
JumpStmt* ParseBreakStmt();
ReturnStmt* ParseReturnStmt();
CompoundStmt* ParseLabelStmt(const Token* label);
CompoundStmt* ParseCaseStmt();
CompoundStmt* ParseDefaultStmt();
Identifier* ProcessDeclarator(const Token* tok,
QualType type,
int storageSpec,
int funcSpec,
int align);
// GNU extensions
void TryAttributeSpecList();
void ParseAttributeSpec();
void ParseAttribute();
// True when `tok` is a type specifier/qualifier keyword, or an identifier
// that the current scope chain resolves to a type name.
bool IsTypeName(const Token* tok) const{
if (tok->IsTypeSpecQual())
return true;
if (tok->IsIdentifier()) {
auto ident = curScope_->Find(tok);
if (ident && ident->ToTypeName())
return true;
}
return false;
}
// Like IsTypeName(), but uses Token::IsDecl() for the keyword test, which
// covers a wider tag range than IsTypeSpecQual().
bool IsType(const Token* tok) const{
if (tok->IsDecl())
return true;
if (tok->IsIdentifier()) {
auto ident = curScope_->Find(tok);
return (ident && ident->ToTypeName());
}
return false;
}
// Reports an error through Error() unless `expr` has integer type.
void EnsureInteger(Expr* expr) {
if (!expr->Type()->IsInteger()) {
Error(expr, "expect integer expression");
}
}
// Scope management. The Exit* helpers step back to the parent scope; the
// scope being left is not deleted here.
void EnterBlock(FuncType* funcType=nullptr);
void ExitBlock() { curScope_ = curScope_->Parent(); }
// Opens a new S_PROTO scope whose parent is the current scope.
void EnterProto() { curScope_ = new Scope(curScope_, S_PROTO); }
void ExitProto() { curScope_ = curScope_->Parent(); }
FuncDef* EnterFunc(Identifier* ident);
void ExitFunc();
// Returns the statement registered for `label` in curLabels_, or nullptr
// when no such label is known.
LabelStmt* FindLabel(const std::string& label) {
auto ret = curLabels_.find(label);
if (curLabels_.end() == ret)
return nullptr;
return ret->second;
}
// Registers `labelStmt` under `label`; the label must not already exist
// (checked with assert only).
void AddLabel(const std::string& label, LabelStmt* labelStmt) {
assert(nullptr == FindLabel(label));
curLabels_[label] = labelStmt;
}
// Accessors
TranslationUnit* Unit() { return unit_; }
FuncDef* CurFunc() { return curFunc_; }
const TokenSequence& ts() const { return ts_; }
private:
// Builtin handling (defined out of line)
static bool IsBuiltin(FuncType* type);
static bool IsBuiltin(const std::string& name);
static Identifier* GetBuiltin(const Token* tok);
static void DefineBuiltins();
static FuncType* vaStartType_;
static FuncType* vaArgType_;
// The root of the AST
TranslationUnit* unit_;
// The parser's own copy of the token sequence passed to the constructor
TokenSequence ts_;
// This is not a real scope:
// it contains all external symbols (resolved and not resolved)
Scope* externalSymbols_;
const Token* errTok_;
// Innermost scope and function currently being parsed
Scope* curScope_;
FuncDef* curFunc_;
// Labels known so far, and jumps recorded as unresolved
LabelMap curLabels_;
LabelJumpList unresolvedJumps_;
// Targets of the enclosing statements; all start out as nullptr
LabelStmt* breakDest_;
LabelStmt* continueDest_;
CaseLabelList* caseLabels_;
LabelStmt* defaultLabel_;
};
#endif

View File

@@ -0,0 +1,84 @@
#ifndef _WGTCC_SCANNER_H_
#define _WGTCC_SCANNER_H_
#include "error.h"
#include "encoding.h"
#include "token.h"
#include <string>
#include <cassert>
class Scanner {
public:
explicit Scanner(const Token* tok)
: Scanner(&tok->str_, tok->loc_) {}
Scanner(const std::string* text, const SourceLocation& loc)
: Scanner(text, loc.filename_, loc.line_, loc.column_) {}
explicit Scanner(const std::string* text,
const std::string* filename=nullptr,
unsigned line=1, unsigned column=1)
: text_(text), tok_(Token::END) {
// TODO(wgtdkp): initialization
p_ = &(*text_)[0];
loc_ = {filename, p_, line, 1};
}
virtual ~Scanner() {}
Scanner(const Scanner& other) = delete;
Scanner& operator=(const Scanner& other) = delete;
// Scan plain text and generate tokens in ts.
// The param 'ts' need not be empty, if so, the tokens
// are inserted at the *header* of 'ts'.
// The param 'ws' tells if there is leading white space
// before this token, it is only SkipComment() that will
// set this param.
Token* Scan(bool ws=false);
void Tokenize(TokenSequence& ts);
static std::string ScanHeadName(const Token* lhs, const Token* rhs);
Encoding ScanCharacter(int& val);
Encoding ScanLiteral(std::string& val);
std::string ScanIdentifier();
private:
Token* SkipIdentifier();
Token* SkipNumber();
Token* SkipLiteral();
Token* SkipCharacter();
Token* MakeToken(int tag);
Token* MakeNewLine();
Encoding ScanEncoding(int c);
int ScanEscaped();
int ScanHexEscaped();
int ScanOctEscaped(int c);
int ScanUCN(int len);
void SkipWhiteSpace();
void SkipComment();
bool IsUCN(int c) { return c == '\\' && (Test('u') || Test('U')); }
bool IsOctal(int c) { return '0' <= c && c <= '7'; }
int XDigit(int c);
bool Empty() const { return *p_ == 0; }
int Peek();
bool Test(int c) { return Peek() == c; };
int Next();
void PutBack();
bool Try(int c) {
if (Peek() == c) {
Next();
return true;
}
return false;
};
void Mark() { tok_.loc_ = loc_; };
const std::string* text_;
SourceLocation loc_;
Token tok_;
const char* p_;
};
// Declared here, defined elsewhere: takes a file name and returns a pointer
// to a std::string.
// NOTE(review): presumably the file's contents, with ownership passing to the
// caller — confirm against the definition.
std::string* ReadFile(const std::string& filename);
#endif

View File

@@ -0,0 +1,70 @@
#ifndef _WGTCC_SCOPE_H_
#define _WGTCC_SCOPE_H_
#include <iostream>
#include <map>
#include <string>
#include <vector>
class Identifier;
class Token;
// Kind tag stored in every Scope (see Scope::Type()).
// NOTE(review): the names suggest file, function-prototype, block and
// function scope respectively — confirm against the scope implementation.
enum ScopeType {
S_FILE,
S_PROTO,
S_BLOCK,
S_FUNC,
};
// One level of the symbol table: a name -> Identifier map plus a link to the
// parent scope. Tags are kept in the same map under a decorated name (see
// TagName()). StructType is a friend and can use the private helpers.
class Scope {
friend class StructType;
using TagList = std::vector<Identifier*>;
using IdentMap = std::map<std::string, Identifier*>;
public:
// Creates an empty scope of the given kind nested in `parent`
// (the parser passes nullptr for its root scopes).
explicit Scope(Scope* parent, enum ScopeType type)
: parent_(parent), type_(type) {}
~Scope() {}
Scope* Parent() { return parent_; }
void SetParent(Scope* parent) { parent_ = parent; }
enum ScopeType Type() const { return type_; }
// Lookup by token. NOTE(review): judging by the names, the *InCurScope
// variants search only this scope and the others also search parents —
// the definitions are not in this header; confirm.
Identifier* Find(const Token* tok);
Identifier* FindInCurScope(const Token* tok);
Identifier* FindTag(const Token* tok);
Identifier* FindTagInCurScope(const Token* tok);
TagList AllTagsInCurScope() const;
// Insertion
void Insert(Identifier* ident);
void Insert(const std::string& name, Identifier* ident);
void InsertTag(Identifier* ident);
void Print();
// Two scopes compare equal when they have the same ScopeType; parents and
// contents are not compared.
bool operator==(const Scope& other) const { return type_ == other.type_; }
// Iteration over the identifier map of this scope only.
IdentMap::iterator begin() { return identMap_.begin(); }
IdentMap::iterator end() { return identMap_.end(); }
size_t size() const { return identMap_.size(); }
private:
// Lookup by name; string counterparts of the public token-based lookups.
Identifier* Find(const std::string& name);
Identifier* FindInCurScope(const std::string& name);
Identifier* FindTag(const std::string& name);
Identifier* FindTagInCurScope(const std::string& name);
// Decorates `name` with the 5-character suffix "@:tag" used as the map key
// for tags.
std::string TagName(const std::string& name) {
return name + "@:tag";
}
// True when `name` is longer than 5 characters and has '@' five characters
// from the end, i.e. where the "@:tag" suffix begins.
static bool IsTagName(const std::string& name) {
return name.size() > 5 && name[name.size() - 5] == '@';
}
// Copying is disabled: declared private here, with no definition in this
// header.
const Scope& operator=(const Scope& other);
Scope(const Scope& scope);
Scope* parent_;
enum ScopeType type_;
IdentMap identMap_;
};
#endif

View File

@@ -0,0 +1,418 @@
#ifndef _WGTCC_TOKEN_H_
#define _WGTCC_TOKEN_H_
#include "error.h"
#include <cassert>
#include <cstring>
#include <iostream>
#include <list>
#include <set>
#include <string>
#include <unordered_map>
class Generator;
class Parser;
class Scanner;
class Token;
class TokenSequence;
using HideSet = std::set<std::string>;
using TokenList = std::list<const Token*>;
// Position of a token in its source text: the file name, a pointer to the
// first character of the line, and the line/column numbers.
struct SourceLocation {
    const std::string* filename_;
    const char* lineBegin_;
    unsigned line_;
    unsigned column_;

    // Pointer to the character at column_; columns are counted from 1, so
    // column 1 maps to lineBegin_ itself.
    const char* Begin() const {
        const char* pastColumn = lineBegin_ + column_;
        return pastColumn - 1;
    }
};
// A lexical token: tag, source location, spelling, leading-whitespace flag and
// an optional hide set used during macro expansion. Tokens are created through
// the static New() factories; the constructors are private (Scanner is a
// friend and may construct tokens directly).
class Token {
friend class Scanner;
public:
// Token tags. Single-character punctuators use their character code;
// everything else is numbered consecutively from 128. The relative ORDER of
// the enumerators matters: the Is*() predicates below test tag ranges
// (CONST..ENUM, CONST..REGISTER, CONST..IDENTIFIER, CONSTANT..F_CONSTANT,
// 0..ELLIPSIS).
enum {
// Punctuators
LPAR = '(',
RPAR = ')',
LSQB = '[',
RSQB = ']',
COLON = ':',
COMMA = ',',
SEMI = ';',
ADD = '+',
SUB = '-',
MUL = '*',
DIV = '/',
OR = '|',
AND = '&',
XOR = '^',
LESS = '<',
GREATER = '>',
EQUAL = '=',
DOT = '.',
MOD = '%',
LBRACE = '{',
RBRACE = '}',
TILDE = '~',
NOT = '!',
COND = '?',
SHARP = '#',
AT = '@',
NEW_LINE = '\n',
DSHARP = 128, // '##'
PTR,
INC,
DEC,
LEFT,
RIGHT,
LE,
GE,
EQ,
NE,
LOGICAL_AND,
LOGICAL_OR,
MUL_ASSIGN,
DIV_ASSIGN,
MOD_ASSIGN,
ADD_ASSIGN,
SUB_ASSIGN,
LEFT_ASSIGN,
RIGHT_ASSIGN,
AND_ASSIGN,
XOR_ASSIGN,
OR_ASSIGN,
ELLIPSIS,
// Punctuators end
// KEYWORD BEGIN
// TYPE QUALIFIER BEGIN
CONST,
RESTRICT,
VOLATILE,
ATOMIC,
// TYPE QUALIFIER END
// TYPE SPECIFIER BEGIN
VOID,
CHAR,
SHORT,
INT,
LONG,
HALF,
FLOAT,
DOUBLE,
SIGNED,
UNSIGNED,
BOOL, // _Bool
COMPLEX, // _Complex
STRUCT,
UNION,
ENUM,
// TYPE SPECIFIER END
ATTRIBUTE, // GNU extension __attribute__
// FUNCTION SPECIFIER BEGIN
INLINE,
NORETURN, // _Noreturn
// FUNCTION SPECIFIER END
ALIGNAS, // _Alignas
// For syntactic convenience
STATIC_ASSERT, // _Static_assert
// STORAGE CLASS SPECIFIER BEGIN
TYPEDEF,
EXTERN,
STATIC,
THREAD, // _Thread_local
AUTO,
REGISTER,
// STORAGE CLASS SPECIFIER END
BREAK,
CASE,
CONTINUE,
DEFAULT,
DO,
ELSE,
FOR,
GOTO,
IF,
RETURN,
SIZEOF,
SWITCH,
WHILE,
ALIGNOF, // _Alignof
GENERIC, // _Generic
IMAGINARY, // _Imaginary
// KEYWORD END
IDENTIFIER,
CONSTANT,
I_CONSTANT,
C_CONSTANT,
F_CONSTANT,
LITERAL,
// For the parser, an identifier is a typedef name or user-defined type
POSTFIX_INC,
POSTFIX_DEC,
PREFIX_INC,
PREFIX_DEC,
ADDR, // '&'
DEREF, // '*'
PLUS,
MINUS,
CAST,
// For preprocessor
PP_IF,
PP_IFDEF,
PP_IFNDEF,
PP_ELIF,
PP_ELSE,
PP_ENDIF,
PP_INCLUDE,
PP_DEFINE,
PP_UNDEF,
PP_LINE,
PP_ERROR,
PP_PRAGMA,
PP_NONE,
PP_EMPTY,
IGNORE,
INVALID,
END,
NOTOK = -1,
};
// Factory functions (defined out of line).
static Token* New(int tag);
static Token* New(const Token& other);
static Token* New(int tag,
const SourceLocation& loc,
const std::string& str,
bool ws=false);
// Copies every field; the hide set is deep-copied when `other` has one.
// NOTE(review): a hide set already held by *this is overwritten without
// being deleted.
Token& operator=(const Token& other) {
tag_ = other.tag_;
ws_ = other.ws_;
loc_ = other.loc_;
str_ = other.str_;
hs_ = other.hs_ ? new HideSet(*other.hs_): nullptr;
return *this;
}
virtual ~Token() {}
// Returns the keyword tag for `key`, or Token::NOTOK when `key` is not a
// keyword.
static int KeyWordTag(const std::string& key) {
auto kwIter = kwTypeMap_.find(key);
if (kwTypeMap_.end() == kwIter)
return Token::NOTOK; // Not a key word type
return kwIter->second;
}
// Tag classification; each predicate is a range test over the enum above.
static bool IsKeyWord(const std::string& name);
static bool IsKeyWord(int tag) { return CONST <= tag && tag < IDENTIFIER; }
bool IsKeyWord() const { return IsKeyWord(tag_); }
bool IsPunctuator() const { return 0 <= tag_ && tag_ <= ELLIPSIS; }
bool IsLiteral() const { return tag_ == LITERAL; }
bool IsConstant() const { return CONSTANT <= tag_ && tag_ <= F_CONSTANT; }
bool IsIdentifier() const { return IDENTIFIER == tag_; }
bool IsEOF() const { return tag_ == Token::END; }
bool IsTypeSpecQual() const { return CONST <= tag_ && tag_ <= ENUM; }
bool IsDecl() const { return CONST <= tag_ && tag_ <= REGISTER; }
// Returns the spelling registered for `tag` in tagLexemeMap_, or nullptr
// when the tag has no entry.
static const char* Lexeme(int tag) {
auto iter = tagLexemeMap_.find(tag);
if (iter == tagLexemeMap_.end())
return nullptr;
return iter->second;
}
int tag_;
// 'ws_' stands for whether there is preceding white space.
// This is to simplify the '#' operator (stringize) in macro expansion.
bool ws_ { false };
SourceLocation loc_;
std::string str_;
// Optional hide set; nullptr when the token has none.
HideSet* hs_ { nullptr };
private:
explicit Token(int tag): tag_(tag) {}
Token(int tag, const SourceLocation& loc,
const std::string& str, bool ws=false)
: tag_(tag), ws_(ws), loc_(loc), str_(str) {}
// Copy construction is implemented in terms of operator=.
Token(const Token& other) {
*this = other;
}
// Lookup tables (defined out of line): keyword spelling -> tag, and
// tag -> spelling.
static const std::unordered_map<std::string, int> kwTypeMap_;
static const std::unordered_map<int, const char*> tagLexemeMap_;
};
class TokenSequence {
friend class Preprocessor;
public:
TokenSequence(): tokList_(new TokenList()),
begin_(tokList_->begin()), end_(tokList_->end()) {}
explicit TokenSequence(Token* tok) {
TokenSequence();
InsertBack(tok);
}
explicit TokenSequence(TokenList* tokList)
: tokList_(tokList),
begin_(tokList->begin()),
end_(tokList->end()) {}
TokenSequence(TokenList* tokList,
TokenList::iterator begin,
TokenList::iterator end)
: tokList_(tokList), begin_(begin), end_(end) {}
~TokenSequence() {}
TokenSequence(const TokenSequence& other) { *this = other; }
const TokenSequence& operator=(const TokenSequence& other) {
tokList_ = other.tokList_;
begin_ = other.begin_;
end_ = other.end_;
return *this;
}
void Copy(const TokenSequence& other) {
tokList_ = new TokenList(other.begin_, other.end_);
begin_ = tokList_->begin();
end_ = tokList_->end();
for (auto iter = begin_; iter != end_; ++iter)
*iter = Token::New(**iter);
}
void UpdateHeadLocation(const SourceLocation& loc) {
assert(!Empty());
auto tok = const_cast<Token*>(Peek());
tok->loc_ = loc;
}
void FinalizeSubst(bool leadingWS, const HideSet& hs) {
auto ts = *this;
while (!ts.Empty()) {
auto tok = const_cast<Token*>(ts.Next());
if (!tok->hs_)
tok->hs_ = new HideSet(hs);
else
tok->hs_->insert(hs.begin(), hs.end());
}
// Even if the token sequence is empty
const_cast<Token*>(Peek())->ws_ = leadingWS;
}
const Token* Expect(int expect);
bool Try(int tag) {
if (Peek()->tag_ == tag) {
Next();
return true;
}
return false;
}
bool Test(int tag) { return Peek()->tag_ == tag; }
const Token* Next() {
auto ret = Peek();
if (!ret->IsEOF()) {
++begin_;
Peek(); // May skip newline token, but why ?
} else {
++exceed_end;
}
return ret;
}
void PutBack() {
assert(begin_ != tokList_->begin());
if (exceed_end > 0) {
--exceed_end;
} else {
--begin_;
if ((*begin_)->tag_ == Token::NEW_LINE)
PutBack();
}
}
const Token* Peek() const;
const Token* Peek2() {
if (Empty())
return Peek(); // Return the Token::END
Next();
auto ret = Peek();
PutBack();
return ret;
}
const Token* Back() const {
auto back = end_;
return *--back;
}
void PopBack() {
assert(!Empty());
assert(end_ == tokList_->end());
auto size_eq1 = tokList_->back() == *begin_;
tokList_->pop_back();
end_ = tokList_->end();
if (size_eq1)
begin_ = end_;
}
TokenList::iterator Mark() { return begin_; }
void ResetTo(TokenList::iterator mark) { begin_ = mark; }
bool Empty() const { return Peek()->tag_ == Token::END; }
void InsertBack(TokenSequence& ts) {
auto pos = tokList_->insert(end_, ts.begin_, ts.end_);
if (begin_ == end_) {
begin_ = pos;
}
}
void InsertBack(const Token* tok) {
auto pos = tokList_->insert(end_, tok);
if (begin_ == end_) {
begin_ = pos;
}
}
// If there is preceding newline
void InsertFront(TokenSequence& ts) {
auto pos = GetInsertFrontPos();
begin_ = tokList_->insert(pos, ts.begin_, ts.end_);
}
void InsertFront(const Token* tok) {
auto pos = GetInsertFrontPos();
begin_ = tokList_->insert(pos, tok);
}
bool IsBeginOfLine() const;
TokenSequence GetLine();
void SetParser(Parser* parser) { parser_ = parser; }
void Print(FILE* fp=stdout) const;
void Print(std::string *str) const;
private:
// Find a insert position with no preceding newline
TokenList::iterator GetInsertFrontPos() {
auto pos = begin_;
if (pos == tokList_->begin())
return pos;
--pos;
while (pos != tokList_->begin() && (*pos)->tag_ == Token::NEW_LINE)
--pos;
return ++pos;
}
TokenList* tokList_;
mutable TokenList::iterator begin_;
TokenList::iterator end_;
Parser* parser_ {nullptr};
int exceed_end {0};
};
#endif

View File

@@ -0,0 +1,450 @@
#ifndef _WGTCC_TYPE_H_
#define _WGTCC_TYPE_H_
#include "mem_pool.h"
#include "scope.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <list>
class Scope;
class Token;
class Expr;
class Type;
class QualType;
class VoidType;
class Identifier;
class Object;
class Constant;
class ArithmType;
class DerivedType;
class ArrayType;
class TileType;
class FuncType;
class PointerType;
class StructType;
class EnumType;
// Bit flags for declaration specifiers. Each enumerator is a distinct bit so
// that several specifiers can be OR-ed into a single int.
// Note that the values are not contiguous: 0x40000 is reserved by the
// commented-out T_ATOMIC, and T_LLONG jumps to 0x4000000.
enum {
// Storage class specifiers
S_TYPEDEF = 0x01,
S_EXTERN = 0x02,
S_STATIC = 0x04,
S_THREAD = 0x08,
S_AUTO = 0x10,
S_REGISTER = 0x20,
// Type specifier
T_SIGNED = 0x40,
T_UNSIGNED = 0x80,
T_CHAR = 0x100,
T_SHORT = 0x200,
T_INT = 0x400,
T_LONG = 0x800,
T_VOID = 0x1000,
T_HALF = 0x2000,
T_FLOAT = 0x4000,
T_DOUBLE = 0x8000,
T_BOOL = 0x10000,
T_COMPLEX = 0x20000,
// T_ATOMIC = 0x40000,
T_STRUCT_UNION = 0x80000,
T_ENUM = 0x100000,
T_TYPEDEF_NAME = 0x200000,
T_LLONG = 0x4000000,
// Function specifier
F_INLINE = 0x8000000,
F_NORETURN = 0x10000000,
};
// Type-qualifier bit flags. MASK is the union of all three (0x07); QualType
// stores these bits in the low bits of its pointer value.
struct Qualifier {
enum {
CONST = 0x01,
RESTRICT = 0x02,
VOLATILE = 0x04,
MASK = CONST | RESTRICT | VOLATILE
};
};
class QualType {
public:
QualType(Type* ptr, int quals=0x00)
: ptr_(reinterpret_cast<intptr_t>(ptr)) {
assert((quals & ~Qualifier::MASK) == 0);
ptr_ |= quals;
}
operator bool() const { return !IsNull(); }
bool IsNull() const { return GetPtr() == nullptr; }
const Type* GetPtr() const {
return reinterpret_cast<const Type*>(ptr_ & ~Qualifier::MASK);
}
Type* GetPtr() {
return reinterpret_cast<Type*>(ptr_ & ~Qualifier::MASK);
}
Type& operator*() { return *GetPtr(); }
const Type& operator*() const { return *GetPtr(); }
Type* operator->() { return GetPtr(); }
const Type* operator->() const { return GetPtr(); }
// Indicate whether the specified types are identical(exclude qualifiers).
friend bool operator==(QualType lhs, QualType rhs) {
return lhs.operator->() == rhs.operator->();
}
friend bool operator!=(QualType lhs, QualType rhs) {
return !(lhs == rhs);
}
int Qual() const { return ptr_ & 0x07; }
bool IsConstQualified() const { return ptr_ & Qualifier::CONST; }
bool IsRestrictQualified() const { return ptr_ & Qualifier::RESTRICT; }
bool IsVolatileQualified() const { return ptr_ & Qualifier::VOLATILE; }
private:
intptr_t ptr_;
};
// Root of the type hierarchy. Concrete types override Width()/Str(), the
// classification predicates, and the To*() downcasts (which return nullptr
// unless the object really is of the requested kind).
class Type {
public:
    static const int intWidth_ = 4;
    static const int machineWidth_ = 8;

    // Types are never compared by value; use Compatible() instead.
    bool operator!=(const Type& other) const = delete;
    bool operator==(const Type& other) const = delete;

    // Base rule: two types are compatible when they agree on completeness.
    virtual bool Compatible(const Type& other) const {
        return other.complete_ == complete_;
    }

    virtual ~Type() {}

    // For Debugging
    virtual std::string Str() const = 0;
    virtual int Width() const = 0;
    // By default a type is aligned to its own width.
    virtual int Align() const { return Width(); }

    // Rounds `offset` to a multiple of `align`, moving away from zero:
    // upwards for non-negative offsets, downwards for negative ones.
    static int MakeAlign(int offset, int align) {
        const int rem = offset % align;
        if (rem == 0)
            return offset;
        return (offset >= 0) ? (offset + align - rem)
                             : (offset - align - rem);
    }

    static QualType MayCast(QualType type, bool inProtoScope=false);

    // Completeness can be changed even on a const Type (complete_ is mutable).
    bool Complete() const { return complete_; }
    void SetComplete(bool complete) const { complete_ = complete; }

    // Classification predicates; everything defaults to false.
    bool IsReal() const { return IsInteger() || IsFloat(); }
    virtual bool IsScalar() const { return false; }
    virtual bool IsFloat() const { return false; }
    virtual bool IsInteger() const { return false; }
    virtual bool IsBool() const { return false; }
    virtual bool IsVoidPointer() const { return false; }
    virtual bool IsUnsigned() const { return false; }

    // Checked downcasts; every one defaults to nullptr.
    virtual VoidType* ToVoid() { return nullptr; }
    virtual const VoidType* ToVoid() const { return nullptr; }
    virtual ArithmType* ToArithm() { return nullptr; }
    virtual const ArithmType* ToArithm() const { return nullptr; }
    virtual ArrayType* ToArray() { return nullptr; }
    virtual const ArrayType* ToArray() const { return nullptr; }
    virtual TileType* ToTile() { return nullptr; }
    virtual const TileType* ToTile() const { return nullptr; }
    virtual FuncType* ToFunc() { return nullptr; }
    virtual const FuncType* ToFunc() const { return nullptr; }
    virtual PointerType* ToPointer() { return nullptr; }
    virtual const PointerType* ToPointer() const { return nullptr; }
    virtual DerivedType* ToDerived() { return nullptr; }
    virtual const DerivedType* ToDerived() const { return nullptr; }
    virtual StructType* ToStruct() { return nullptr; }
    virtual const StructType* ToStruct() const { return nullptr; }

protected:
    Type(MemPool* pool, bool complete)
        : complete_(complete), pool_(pool) {}

    mutable bool complete_;
    // The pool handed to the constructor.
    MemPool* pool_;
};
// The `void` type. It is always incomplete and is compatible only with
// another void type.
class VoidType : public Type {
public:
    static VoidType* New();
    virtual ~VoidType() {}

    virtual VoidType* ToVoid() { return this; }
    virtual const VoidType* ToVoid() const { return this; }

    virtual bool Compatible(const Type& other) const {
        return other.ToVoid() != nullptr;
    }

    // Non-standard GNU extension: void has a width of 1.
    virtual int Width() const { return 1; }

    virtual std::string Str() const { return "void:1"; }

protected:
    explicit VoidType(MemPool* pool): Type(pool, false) {}
};
class ArithmType : public Type {
public:
static ArithmType* New(int typeSpec);
virtual ~ArithmType() {}
virtual ArithmType* ToArithm() { return this; }
virtual const ArithmType* ToArithm() const { return this; }
virtual bool Compatible(const Type& other) const {
// C11 6.2.7 [1]: Two types have compatible type if their types are the same
// But I would to loose this constraints: integer and pointer are compatible
// if (IsInteger() && other.ToPointer())
// return other.Compatible(*this);
return this == &other;
}
virtual int Width() const;
virtual std::string Str() const;
virtual bool IsScalar() const { return true; }
virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); }
virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; }
virtual bool IsFloat() const {
return (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
}
virtual bool IsBool() const { return tag_ & T_BOOL; }
bool IsComplex() const { return tag_ & T_COMPLEX; }
int Tag() const { return tag_; }
int Rank() const;
static ArithmType* IntegerPromote(ArithmType* type) {
assert(type->IsInteger());
if (type->Rank() < ArithmType::New(T_INT)->Rank())
return ArithmType::New(T_INT);
return type;
}
static ArithmType* MaxType(ArithmType* lhsType,
ArithmType* rhsType);
protected:
explicit ArithmType(MemPool* pool, int spec)
: Type(pool, true), tag_(Spec2Tag(spec)) {}
private:
static int Spec2Tag(int spec);
int tag_;
};
// Common base of the types built on top of another type (pointer, array,
// tile and function types derive from it below). The underlying type is kept
// in derived_. Instances start out complete.
class DerivedType : public Type {
public:
// The underlying type and its setter.
QualType Derived() const { return derived_; }
void SetDerived(QualType derived) { derived_ = derived; }
virtual DerivedType* ToDerived() { return this; }
virtual const DerivedType* ToDerived() const { return this; }
protected:
DerivedType(MemPool* pool, QualType derived)
: Type(pool, true), derived_(derived) {}
QualType derived_;
};
// Pointer to derived_. Pointers are scalar and always 8 bytes wide.
class PointerType : public DerivedType {
public:
    static PointerType* New(QualType derived);
    virtual ~PointerType() {}

    virtual PointerType* ToPointer() { return this; }
    virtual const PointerType* ToPointer() const { return this; }

    virtual bool Compatible(const Type& other) const;
    virtual int Width() const { return 8; }
    virtual bool IsScalar() const { return true; }

    // True when the pointee is void.
    virtual bool IsVoidPointer() const {
        return derived_->ToVoid() != nullptr;
    }

    // "<pointee>*:<width>"
    virtual std::string Str() const {
        std::string text = derived_->Str();
        text += "*:";
        text += std::to_string(Width());
        return text;
    }

protected:
    PointerType(MemPool* pool, QualType derived): DerivedType(pool, derived) {}
};
// Array of derived_. The length is either a known integer (len_) or given by
// an expression (lenExpr_), in which case the type starts out incomplete.
class ArrayType : public DerivedType {
public:
    static ArrayType* New(int len, QualType eleType);
    static ArrayType* New(Expr* expr, QualType eleType);
    virtual ~ArrayType() { /*delete derived_;*/ }

    virtual ArrayType* ToArray() { return this; }
    virtual const ArrayType* ToArray() const { return this; }

    virtual bool Compatible(const Type& other) const;

    // Total size in bytes; an incomplete array has width 0.
    virtual int Width() const {
        if (!Complete())
            return 0;
        return derived_->Width() * len_;
    }

    // An array is aligned like its element type.
    virtual int Align() const { return derived_->Align(); }

    // "<element>[]:<width>"
    virtual std::string Str() const {
        std::string text = derived_->Str();
        text += "[]:";
        text += std::to_string(Width());
        return text;
    }

    // Byte offset of element `idx` from the start of the array.
    int GetElementOffset(int idx) const { return derived_->Width() * idx; }

    int Len() const { return len_; }
    void SetLen(int len) { len_ = len; }

    // True when the length was given as an expression.
    bool Variadic() const { return nullptr != lenExpr_; }

protected:
    // Length given by an expression: len_ is 0 and the type is incomplete.
    ArrayType(MemPool* pool, Expr* lenExpr, QualType derived)
        : DerivedType(pool, derived),
          lenExpr_(lenExpr), len_(0) {
        SetComplete(false);
    }

    // Length given as an integer: a negative length marks the array as
    // incomplete.
    ArrayType(MemPool* pool, int len, QualType derived)
        : DerivedType(pool, derived),
          lenExpr_(nullptr), len_(len) {
        SetComplete(len_ >= 0);
    }

    const Expr* lenExpr_;
    int len_;
};
// Tile of derived_ elements. The shape is held either as integers (shape_)
// or as expressions (shapeExpr_), depending on which constructor was used.
class TileType : public DerivedType {
public:
    using ShapeExpr = std::vector<Expr*>;
    using ShapeInt = std::vector<int>;

public:
    static TileType* New(const ShapeExpr& expr, QualType eleType);
    static TileType* New(const ShapeInt& shape, QualType eleType);
    virtual ~TileType() { }

    // Overrides of Type::ToTile(). Without these, ToTile() on a TileType
    // fell through to the base implementation and returned nullptr, because
    // the methods below are spelled `toTile` and therefore override nothing.
    virtual TileType* ToTile() { return this; }
    virtual const TileType* ToTile() const { return this; }

    // Original lower-case spellings, kept so that existing callers still
    // compile. Prefer ToTile().
    virtual TileType* toTile() { return this; }
    virtual const TileType* toTile() const { return this; }

    virtual bool Compatible(const Type& other) const;

    // A tile reports a width of 0.
    virtual int Width() const { return 0; }
    // A tile is aligned like its element type.
    virtual int Align() const { return derived_->Align(); }

    // "<element>[{}]:<width>"
    virtual std::string Str() const {
        return derived_->Str() + "[{}]:" + std::to_string(Width());
    }

    // Returns a copy of the integer shape.
    ShapeInt Shape() { return shape_; }

protected:
    // Shape given as expressions; shape_ stays empty.
    // NOTE(review): the type is marked complete only when every shape
    // expression is null — confirm that this is the intended condition.
    TileType(MemPool* pool, const ShapeExpr& expr, QualType derived)
        : DerivedType(pool, derived),
          shapeExpr_(expr) {
        bool isComplete = true;
        for(Expr* s: shapeExpr_)
            isComplete = isComplete && !s;
        SetComplete(isComplete);
    }

    // Shape given as integers; complete when no dimension is negative.
    TileType(MemPool* pool, const ShapeInt& shape, QualType derived)
        : DerivedType(pool, derived),
          shape_(shape) {
        bool isComplete = true;
        for(int s: shape_)
            isComplete = isComplete && (s>=0);
        SetComplete(isComplete);
    }

protected:
    ShapeExpr shapeExpr_;
    ShapeInt shape_;
};
// Function type: derived_ is the return type, params_ the parameter objects.
// Function types are created incomplete.
class FuncType : public DerivedType {
public:
    using ParamList = std::vector<Object*>;

public:
    static FuncType* New(QualType derived,
                         int funcSpec,
                         bool variadic,
                         const ParamList& params);
    virtual ~FuncType() {}

    virtual FuncType* ToFunc() { return this; }
    virtual const FuncType* ToFunc() const { return this; }

    virtual bool Compatible(const Type& other) const;
    virtual int Width() const { return 1; }
    virtual std::string Str() const;

    const ParamList& Params() const { return params_; }
    void SetParams(const ParamList& params) { params_ = params; }

    bool Variadic() const { return variadic_; }
    // Function-specifier queries on the stored F_* flags.
    bool IsInline() const { return (inlineNoReturn_ & F_INLINE) != 0; }
    bool IsNoReturn() const { return (inlineNoReturn_ & F_NORETURN) != 0; }

protected:
    // `inlineReturn` carries the F_INLINE / F_NORETURN flags.
    FuncType(MemPool* pool, QualType derived, int inlineReturn,
             bool variadic, const ParamList& params)
        : DerivedType(pool, derived), inlineNoReturn_(inlineReturn),
          variadic_(variadic), params_(params) {
        SetComplete(false);
    }

private:
    int inlineNoReturn_;
    bool variadic_;
    ParamList params_;
};
// struct / union type. Members are kept both in declaration order (members_)
// and in a Scope (memberMap_). Width and alignment are stored rather than
// computed on demand.
class StructType : public Type {
public:
using MemberList = std::list<Object*>;
using Iterator = std::list<Object*>::iterator;
public:
static StructType* New(bool isStruct,
bool hasTag,
Scope* parent);
virtual ~StructType() {}
virtual StructType* ToStruct() { return this; }
virtual const StructType* ToStruct() const { return this; }
virtual bool Compatible(const Type& other) const;
// Return the stored width_ / align_ fields.
virtual int Width() const { return width_; }
virtual int Align() const { return align_; }
virtual std::string Str() const;
// struct/union
void AddMember(Object* member);
void AddBitField(Object* member, int offset);
// True for a struct, false for a union.
bool IsStruct() const { return isStruct_; }
Object* GetMember(const std::string& member);
Scope* MemberMap() { return memberMap_; }
MemberList& Members() { return members_; }
int Offset() const { return offset_; }
bool HasTag() const { return hasTag_; }
void MergeAnony(Object* anony);
void Finalize();
protected:
// Default is incomplete
StructType(MemPool* pool, bool isStruct, bool hasTag, Scope* parent);
StructType(const StructType& other);
private:
void CalcWidth();
bool isStruct_;
bool hasTag_;
Scope* memberMap_;
MemberList members_;
// Layout state; maintained by the out-of-line member functions.
int offset_;
int width_;
int align_;
int bitFieldAlign_;
};
#endif

View File

@@ -0,0 +1,50 @@
#ifndef _WGTCC_VISITOR_H_
#define _WGTCC_VISITOR_H_
// Forward declarations of the AST node classes the visitor can be handed.
class BinaryOp;
class UnaryOp;
class ConditionalOp;
class FuncCall;
class Identifier;
class Object;
class Enumerator;
class Constant;
class TempVar;
class Declaration;
class IfStmt;
class JumpStmt;
class ReturnStmt;
class LabelStmt;
class EmptyStmt;
class CompoundStmt;
class FuncDef;
class TranslationUnit;
// Abstract visitor over the AST: one pure-virtual hook per concrete node
// class. Implementations must override every hook.
class Visitor {
public:
virtual ~Visitor() {}
// Expressions
virtual void VisitBinaryOp(BinaryOp* binary) = 0;
virtual void VisitUnaryOp(UnaryOp* unary) = 0;
virtual void VisitConditionalOp(ConditionalOp* cond) = 0;
virtual void VisitFuncCall(FuncCall* funcCall) = 0;
virtual void VisitEnumerator(Enumerator* enumer) = 0;
virtual void VisitIdentifier(Identifier* ident) = 0;
virtual void VisitObject(Object* obj) = 0;
virtual void VisitConstant(Constant* cons) = 0;
virtual void VisitTempVar(TempVar* tempVar) = 0;
// Declarations and statements
virtual void VisitDeclaration(Declaration* init) = 0;
virtual void VisitIfStmt(IfStmt* ifStmt) = 0;
virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0;
virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0;
virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0;
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) = 0;
virtual void VisitCompoundStmt(CompoundStmt* compStmt) = 0;
// Top level
virtual void VisitFuncDef(FuncDef* funcDef) = 0;
virtual void VisitTranslationUnit(TranslationUnit* unit) = 0;
};
#endif

885
lib/lang/wgtcc/ast.cc Normal file
View File

@@ -0,0 +1,885 @@
#include "triton/lang/wgtcc/ast.h"
#include "triton/lang/wgtcc/code_gen.h"
#include "triton/lang/wgtcc/error.h"
#include "triton/lang/wgtcc/evaluator.h"
#include "triton/lang/wgtcc/mem_pool.h"
#include "triton/lang/wgtcc/parser.h"
#include "triton/lang/wgtcc/token.h"
// File-local memory pools, one per AST node class. The New() factories in
// this file placement-construct nodes in storage obtained from Alloc() and
// record the owning pool in the node's pool_ member.
static MemPoolImp<BinaryOp> binaryOpPool;
static MemPoolImp<ConditionalOp> conditionalOpPool;
static MemPoolImp<FuncCall> funcCallPool;
// Despite its name, this pool holds Declaration nodes.
static MemPoolImp<Declaration> initializationPool;
static MemPoolImp<Object> objectPool;
static MemPoolImp<Identifier> identifierPool;
static MemPoolImp<Enumerator> enumeratorPool;
static MemPoolImp<Constant> constantPool;
static MemPoolImp<TempVar> tempVarPool;
static MemPoolImp<UnaryOp> unaryOpPool;
static MemPoolImp<EmptyStmt> emptyStmtPool;
static MemPoolImp<IfStmt> ifStmtPool;
static MemPoolImp<JumpStmt> jumpStmtPool;
static MemPoolImp<ReturnStmt> returnStmtPool;
static MemPoolImp<LabelStmt> labelStmtPool;
static MemPoolImp<CompoundStmt> compoundStmtPool;
static MemPoolImp<FuncDef> funcDefPool;
/*
 * Accept
 *
 * Visitor dispatch: each node forwards itself to the Visitor hook that
 * matches its own class. EmptyStmt is the one exception and calls nothing.
 */
void Declaration::Accept(Visitor* v) {
v->VisitDeclaration(this);
}
void EmptyStmt::Accept(Visitor* v) {
// Nothing to do: the visitor is deliberately not invoked for empty
// statements, so Visitor::VisitEmptyStmt is never reached from here.
}
void LabelStmt::Accept(Visitor* v) {
v->VisitLabelStmt(this);
}
void IfStmt::Accept(Visitor* v) {
v->VisitIfStmt(this);
}
void JumpStmt::Accept(Visitor* v) {
v->VisitJumpStmt(this);
}
void ReturnStmt::Accept(Visitor* v) {
v->VisitReturnStmt(this);
}
void CompoundStmt::Accept(Visitor* v) {
v->VisitCompoundStmt(this);
}
void BinaryOp::Accept(Visitor* v) {
v->VisitBinaryOp(this);
}
void UnaryOp::Accept(Visitor* v) {
v->VisitUnaryOp(this);
}
void ConditionalOp::Accept(Visitor* v) {
v->VisitConditionalOp(this);
}
void FuncCall::Accept(Visitor* v) {
v->VisitFuncCall(this);
}
void Identifier::Accept(Visitor* v) {
v->VisitIdentifier(this);
}
void Object::Accept(Visitor* v) {
v->VisitObject(this);
}
void Constant::Accept(Visitor* v) {
v->VisitConstant(this);
}
void Enumerator::Accept(Visitor* v)
{
v->VisitEnumerator(this);
}
void TempVar::Accept(Visitor* v) {
v->VisitTempVar(this);
}
void FuncDef::Accept(Visitor* v) {
v->VisitFuncDef(this);
}
void TranslationUnit::Accept(Visitor* v) {
v->VisitTranslationUnit(this);
}
// Apply the implicit conversion chosen by Type::MayCast (array to pointer,
// function to pointer-to-function). The expression is returned unchanged
// when the conversion leaves its type as it was.
Expr* Expr::MayCast(Expr* expr) {
  auto converted = Type::MayCast(expr->Type());
  // Comparing the type pointers is sufficient here
  const bool needsCast = (converted != expr->Type());
  if (needsCast)
    return UnaryOp::New(Token::CAST, expr, converted);
  return expr;
}
// Convert `expr` to `desType` if needed. After the usual implicit conversion,
// no cast is inserted when both types are pointers and either one is a void
// pointer, or when the destination is compatible with the expression type.
Expr* Expr::MayCast(Expr* expr, QualType desType) {
  Expr* converted = MayCast(expr);
  auto srcType = converted->Type();
  const bool bothPointers = desType->ToPointer() && srcType->ToPointer();
  if (bothPointers &&
      (desType->IsVoidPointer() || srcType->IsVoidPointer())) {
    return converted;
  }
  if (desType->Compatible(*converted->Type()))
    return converted;
  return UnaryOp::New(Token::CAST, converted, desType);
}
// Convenience overload: the operator is taken from the token's tag.
BinaryOp* BinaryOp::New(const Token* tok, Expr* lhs, Expr* rhs) {
  const int op = tok->tag_;
  return New(tok, op, lhs, rhs);
}
// Build a binary expression node for operator `op` and type-check it.
// Asserts when `op` is not one of the supported binary operators.
BinaryOp* BinaryOp::New(const Token* tok, int op, Expr* lhs, Expr* rhs) {
  switch (op) {
  // Single-character operators
  case ',':
  case '.':
  case '=':
  case '*':
  case '/':
  case '%':
  case '+':
  case '-':
  case '&':
  case '^':
  case '|':
  case '<':
  case '>':
  // Multi-character operators
  case Token::LEFT:
  case Token::RIGHT:
  case Token::LE:
  case Token::GE:
  case Token::EQ:
  case Token::NE:
  case Token::LOGICAL_AND:
  case Token::LOGICAL_OR:
  case Token::ELLIPSIS:
    break;
  default:
    assert(0);
  }
  auto storage = binaryOpPool.Alloc();
  auto node = new (storage) BinaryOp(tok, op, lhs, rhs);
  node->pool_ = &binaryOpPool;
  node->TypeChecking();
  return node;
}
// Bring both operands to their common arithmetic type
// (ArithmType::MaxType), inserting casts where an operand differs, and
// return that type. Both operands must already have arithmetic type.
ArithmType* BinaryOp::Convert() {
  auto leftArithm = lhs_->Type()->ToArithm();
  auto rightArithm = rhs_->Type()->ToArithm();
  assert(leftArithm && rightArithm);
  auto common = ArithmType::MaxType(leftArithm, rightArithm);
  // Pointer comparison is enough to detect a differing type
  if (leftArithm != common)
    lhs_ = UnaryOp::New(Token::CAST, lhs_, common);
  if (rightArithm != common)
    rhs_ = UnaryOp::New(Token::CAST, rhs_, common);
  return common;
}
/*
 * Broadcast the operands of a binary operator to a common tile type.
 *
 *  - Neither operand is a tile: nothing is changed.
 *  - Exactly one operand is a tile: the result type is that tile type and
 *    the OTHER (non-tile) operand is cast to it.
 *  - Both are tiles: the shorter shape is left-padded with 1s, then each
 *    dimension must either match a 1 on one side or an error is reported.
 *    Both operands are cast to the resulting tile type, whose element type
 *    is taken from the left operand.
 */
void BinaryOp::Broadcast() {
  auto lhsType = lhs_->Type()->ToTile();
  auto rhsType = rhs_->Type()->ToTile();
  if (!lhsType && !rhsType)
    return;
  if (lhsType && !rhsType) {
    type_ = lhsType;
    // Cast the non-tile rhs itself; casting lhs_ here would silently replace
    // the right operand with a copy of the left one.
    rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
    return;
  }
  if (!lhsType && rhsType) {
    type_ = rhsType;
    // Symmetric case: the non-tile lhs is the operand that gets cast.
    lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
    return;
  }
  auto lhsShape = lhsType->Shape();
  auto rhsShape = rhsType->Shape();
  auto lhsRank = lhsShape.size();
  auto rhsRank = rhsShape.size();
  auto retRank = std::max(lhsRank, rhsRank);
  // pad to the left until shapes have the same rank
  while (lhsShape.size() < retRank)
    lhsShape.insert(lhsShape.begin(), 1);
  while (rhsShape.size() < retRank)
    rhsShape.insert(rhsShape.begin(), 1);
  // broadcast if possible
  TileType::ShapeInt retShape(retRank);
  for (size_t i = 0; i < retRank; i++) {
    if (lhsShape[i] == 1)
      retShape[i] = rhsShape[i];
    else if (rhsShape[i] == 1)
      retShape[i] = lhsShape[i];
    else
      // The index is a size_t; convert it so it matches the %d specifier.
      Error(this, "cannot broadcast dimension %d "
                  "for operands of shape %d and %d",
            static_cast<int>(i), lhsShape[i], rhsShape[i]);
  }
  auto eleType = lhsType->Derived();
  type_ = TileType::New(retShape, eleType);
  lhs_ = UnaryOp::New(Token::CAST, lhs_, type_);
  rhs_ = UnaryOp::New(Token::CAST, rhs_, type_);
}
/*
* Type checking
*/
// Accept a pair of pointer types when either one is a void pointer;
// every other pair must pass EnsureCompatible.
void Expr::EnsureCompatibleOrVoidPointer(const QualType lhs,
                                         const QualType rhs) const {
  const bool bothPointers = lhs->ToPointer() && rhs->ToPointer();
  if (!bothPointers || !(lhs->IsVoidPointer() || rhs->IsVoidPointer()))
    EnsureCompatible(lhs, rhs);
}
// Report an error at this expression unless `lhs` is compatible with `rhs`.
void Expr::EnsureCompatible(const QualType lhs, const QualType rhs) const {
  if (lhs->Compatible(*rhs))
    return;
  Error(this, "incompatible types");
}
// Dispatch to the type-checking routine for this node's operator.
// Asserts on an operator that has no routine.
void BinaryOp::TypeChecking() {
  switch (op_) {
  case '.':
    MemberRefOpTypeChecking();
    break;

  case '*': case '/': case '%':
    MultiOpTypeChecking();
    break;

  case '+': case '-':
    AdditiveOpTypeChecking();
    break;

  case Token::LEFT: case Token::RIGHT:
    ShiftOpTypeChecking();
    break;

  case '<': case '>': case Token::LE: case Token::GE:
    RelationalOpTypeChecking();
    break;

  case Token::EQ: case Token::NE:
    EqualityOpTypeChecking();
    break;

  case '&': case '^': case '|':
    BitwiseOpTypeChecking();
    break;

  case Token::LOGICAL_AND: case Token::LOGICAL_OR:
    LogicalOpTypeChecking();
    break;

  case '=':
    AssignOpTypeChecking();
    break;

  case ',':
    CommaOpTypeChecking();
    break;

  case Token::ELLIPSIS:
    RangeOpTypeChecking();
    break;

  default:
    assert(0);
  }
}
// A comma expression takes the type of its right operand.
void BinaryOp::CommaOpTypeChecking() {
  auto rightType = rhs_->Type();
  type_ = rightType;
}
// Subscripting: the left operand must be a pointer and the right operand an
// integer. The result has the pointer's derived type.
void BinaryOp::SubScriptingOpTypeChecking() {
  auto pointerType = lhs_->Type()->ToPointer();
  if (pointerType == nullptr)
    Error(this, "an pointer expected");

  const bool indexIsInteger = rhs_->Type()->IsInteger();
  if (!indexIsInteger)
    Error(this, "the operand of [] should be intger");

  type_ = pointerType->Derived();
}
// A member reference takes the type of its right operand (the member).
void BinaryOp::MemberRefOpTypeChecking() {
  auto memberType = rhs_->Type();
  type_ = memberType;
}
// '*', '/' and '%': both operands must be arithmetic, and '%' additionally
// requires integers. The result is the common arithmetic type.
void BinaryOp::MultiOpTypeChecking() {
  const bool bothArithm =
      lhs_->Type()->ToArithm() && rhs_->Type()->ToArithm();
  if (!bothArithm)
    Error(this, "operands should have arithmetic type");

  if ('%' == op_) {
    const bool bothInteger =
        lhs_->Type()->IsInteger() && rhs_->Type()->IsInteger();
    if (!bothInteger)
      Error(this, "operands of '%%' should be integers");
  }
  type_ = Convert();
}
/*
 * Additive operator is only allowed between:
 *  1. arithmetic types (bool, integer, floating)
 *  2. pointer can be used:
 *    1. lhs of MINUS operator, and rhs must be integer or pointer;
 *    2. lhs/rhs of ADD operator, and the other operand must be integer;
 */
void BinaryOp::AdditiveOpTypeChecking() {
auto lhsType = lhs_->Type()->ToPointer();
auto rhsType = rhs_->Type()->ToPointer();
if (lhsType) {
// Left operand is a pointer
if (op_ == '-') {
if (rhsType) {
// pointer - pointer: the pointer types must be compatible
if (!lhsType->Compatible(*rhsType))
Error(this, "invalid operands to binary -");
type_ = ArithmType::New(T_LONG); // ptrdiff_t
} else if (!rhs_->Type()->IsInteger()) {
Error(this, "invalid operands to binary -");
} else {
// pointer - integer keeps the pointer type
type_ = lhsType;
}
} else if (!rhs_->Type()->IsInteger()) {
Error(this, "invalid operands to binary +");
} else {
// pointer + integer keeps the pointer type
type_ = lhsType;
}
} else if (rhsType) {
// Only the right operand is a pointer (lhsType is null in this branch)
if (op_ == '+' && !lhs_->Type()->IsInteger()) {
Error(this, "invalid operands to binary '+'");
} else if (op_ == '-' && !lhsType) {
// Always taken for '-', since lhsType is null here: a pointer cannot be
// subtracted from a non-pointer
Error(this, "invalid operands to binary '-'");
}
type_ = op_ == '-' ? ArithmType::New(T_LONG): rhs_->Type();
std::swap(lhs_, rhs_); // To simplify code gen
} else {
// No pointer involved: ordinary arithmetic
if (!lhs_->Type()->ToArithm() || !rhs_->Type()->ToArithm()) {
Error(this, "invalid operands to binary %s", tok_->str_.c_str());
}
type_ = Convert();
}
}
// Range operator: both bounds must be integers. They are integer-promoted
// and evaluated as constants; the result is a one-dimensional tile of
// length (end - begin) whose element type is the left operand's type.
void BinaryOp::RangeOpTypeChecking() {
  auto lowArithm = lhs_->Type()->ToArithm();
  auto highArithm = rhs_->Type()->ToArithm();
  const bool bothInteger = lowArithm && lowArithm->IsInteger() &&
                           highArithm && highArithm->IsInteger();
  if (!bothInteger)
    Error(this, "expect integers for range operator");

  lhs_ = Expr::MayCast(lhs_, ArithmType::IntegerPromote(lowArithm));
  rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(highArithm));

  const long first = Evaluator<long>().Eval(lhs_);
  const long last = Evaluator<long>().Eval(rhs_);
  int len = static_cast<int>(last - first);
  if (len < 0)
    Error(this, "range cannot be negative");

  type_ = TileType::New(TileType::ShapeInt{len}, lhs_->Type());
}
// Shift operators: both operands must be integers. Each is integer-promoted
// on its own and the result takes the promoted left operand's type.
void BinaryOp::ShiftOpTypeChecking() {
  auto valueArithm = lhs_->Type()->ToArithm();
  auto countArithm = rhs_->Type()->ToArithm();
  const bool bothInteger = valueArithm && valueArithm->IsInteger() &&
                           countArithm && countArithm->IsInteger();
  if (!bothInteger)
    Error(this, "expect integers for shift operator");

  lhs_ = Expr::MayCast(lhs_, ArithmType::IntegerPromote(valueArithm));
  rhs_ = Expr::MayCast(rhs_, ArithmType::IntegerPromote(countArithm));
  type_ = lhs_->Type();
}
// '<', '>', '<=', '>=': with a pointer operand the two types must be
// compatible; otherwise both operands must be real and are converted to
// their common type. The result is always int.
void BinaryOp::RelationalOpTypeChecking() {
  const bool anyPointer =
      lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer();
  if (anyPointer) {
    EnsureCompatible(lhs_->Type(), rhs_->Type());
  } else {
    const bool bothReal = lhs_->Type()->IsReal() && rhs_->Type()->IsReal();
    if (!bothReal)
      Error(this, "expect real type of operands");
    Convert();
  }
  type_ = ArithmType::New(T_INT);
}
// '==' and '!=': with a pointer operand the types must be compatible or one
// must be a void pointer; otherwise both operands must be arithmetic and are
// converted to their common type. The result is always int.
void BinaryOp::EqualityOpTypeChecking() {
  const bool anyPointer =
      lhs_->Type()->ToPointer() || rhs_->Type()->ToPointer();
  if (anyPointer) {
    EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type());
  } else {
    const bool bothArithm =
        lhs_->Type()->ToArithm() && rhs_->Type()->ToArithm();
    if (!bothArithm)
      Error(this, "invalid operands to binary %s", tok_->str_.c_str());
    Convert();
  }
  type_ = ArithmType::New(T_INT);
}
void BinaryOp::BitwiseOpTypeChecking() {
if (!lhs_->Type()->IsInteger() || !rhs_->Type()->IsInteger())
Error(this, "operands of '&' should be integer");
type_ = Convert();
}
// '&&' and '||': both operands must be scalar; the result is int.
void BinaryOp::LogicalOpTypeChecking() {
  const bool bothScalar =
      lhs_->Type()->IsScalar() && rhs_->Type()->IsScalar();
  if (!bothScalar)
    Error(this, "the operand should be arithmetic type or pointer");
  type_ = ArithmType::New(T_INT);
}
// '=': the left operand must be a non-const lvalue. Unless both sides are
// arithmetic, their types must be compatible (or involve a void pointer).
// The right operand is converted to the left operand's type, which is also
// the type of the whole expression.
void BinaryOp::AssignOpTypeChecking() {
  if (lhs_->IsConstQualified())
    Error(lhs_, "left operand of '=' is const qualified");
  else if (!lhs_->IsLVal())
    Error(lhs_, "lvalue expression expected");

  const bool bothArithm =
      lhs_->Type()->ToArithm() && rhs_->Type()->ToArithm();
  if (!bothArithm)
    EnsureCompatibleOrVoidPointer(lhs_->Type(), rhs_->Type());

  // Remaining constraints are checked by the cast operator
  rhs_ = Expr::MayCast(rhs_, lhs_->Type());
  type_ = lhs_->Type();
}
/*
* Unary Operators
*/
// Build a unary expression node from the unary-op pool and type-check it.
UnaryOp* UnaryOp::New(int op, Expr* operand, QualType type) {
  auto storage = unaryOpPool.Alloc();
  auto node = new (storage) UnaryOp(op, operand, type);
  node->pool_ = &unaryOpPool;
  node->TypeChecking();
  return node;
}
// Among unary expressions only a dereference ('*') is an lvalue.
bool UnaryOp::IsLVal() {
  const bool isDeref = (op_ == Token::DEREF);
  return isDeref;
}
// Convert the (arithmetic) operand for a unary arithmetic operator: integer
// types are integer-promoted, other arithmetic types are kept. Returns the
// resulting type.
ArithmType* UnaryOp::Convert() {
  auto target = operand_->Type()->ToArithm();
  assert(target);
  const bool promote = target->IsInteger();
  if (promote)
    target = ArithmType::IntegerPromote(target);
  operand_ = Expr::MayCast(operand_, target);
  return target;
}
// Dispatch to the type-checking routine for this node's operator.
// Asserts on an operator that has no routine.
void UnaryOp::TypeChecking() {
  switch (op_) {
  case Token::POSTFIX_INC: case Token::POSTFIX_DEC:
  case Token::PREFIX_INC:  case Token::PREFIX_DEC:
    IncDecOpTypeChecking();
    break;

  case Token::ADDR:
    AddrOpTypeChecking();
    break;

  case Token::DEREF:
    DerefOpTypeChecking();
    break;

  case Token::PLUS: case Token::MINUS:
  case '~': case '!':
    UnaryArithmOpTypeChecking();
    break;

  case Token::CAST:
    CastOpTypeChecking();
    break;

  default:
    assert(false);
  }
}
// '++' / '--': the operand must be a non-const lvalue of real or pointer
// type. The expression keeps the operand's type.
void UnaryOp::IncDecOpTypeChecking() {
  if (operand_->IsConstQualified())
    Error(this, "increment/decrement of const qualified expression");
  else if (!operand_->IsLVal())
    Error(this, "lvalue expression expected");

  const bool realOrPointer =
      operand_->Type()->IsReal() || operand_->Type()->ToPointer();
  if (!realOrPointer)
    Error(this, "expect operand of real type or pointer");

  type_ = operand_->Type();
}
// '&': the operand must be a function designator or an lvalue. The result
// is a pointer to the operand's type.
void UnaryOp::AddrOpTypeChecking() {
  auto asFunc = operand_->Type()->ToFunc();
  const bool addressable = (asFunc != nullptr) || operand_->IsLVal();
  if (!addressable)
    Error(this, "expression must be an lvalue or function designator");
  type_ = PointerType::New(operand_->Type());
}
// Unary '*': the operand must be a pointer; the result has the pointer's
// derived type.
void UnaryOp::DerefOpTypeChecking() {
  auto asPointer = operand_->Type()->ToPointer();
  if (asPointer == nullptr)
    Error(this, "pointer expected for deref operator '*'");
  type_ = asPointer->Derived();
}
// Unary '+', '-', '~' and '!'.
//  - '+' / '-' need an arithmetic operand, '~' an integer operand; the
//    operand is converted and the result takes the converted operand's type.
//  - '!' (every remaining operator) needs a scalar operand and yields int.
void UnaryOp::UnaryArithmOpTypeChecking() {
  switch (op_) {
  case Token::PLUS:
  case Token::MINUS:
    if (!operand_->Type()->ToArithm())
      Error(this, "Arithmetic type expected");
    Convert();
    type_ = operand_->Type();
    break;

  case '~':
    if (!operand_->Type()->IsInteger())
      Error(this, "integer expected for operator '~'");
    Convert();
    type_ = operand_->Type();
    break;

  default:
    if (operand_->Type()->IsScalar())
      type_ = ArithmType::New(T_INT);
    else
      Error(this, "arithmetic type or pointer expected for operator '!'");
    break;
  }
}
void UnaryOp::CastOpTypeChecking() {
auto operandType = Type::MayCast(operand_->Type());
// The type_ has been initiated to dest type
if (type_->ToVoid()) {
// The expression becomes a void expression
} else if (!type_->IsScalar() || !operandType->IsScalar()) {
if (!type_->Compatible(*operandType))
Error(this, "the cast type should be arithemetic type or pointer");
} else if (type_->IsFloat() && operandType->ToPointer()) {
Error(this, "cannot cast a pointer to floating");
} else if (type_->ToPointer() && operandType->IsFloat()) {
Error(this, "cannot cast a floating to pointer");
}
}
/*
* Conditional Operator
*/
// Build a conditional ('?:') node from its pool and type-check it.
// `tok` is accepted for interface symmetry and is not used here.
ConditionalOp* ConditionalOp::New(const Token* tok,
                                  Expr* cond,
                                  Expr* exprTrue,
                                  Expr* exprFalse) {
  auto storage = conditionalOpPool.Alloc();
  auto node = new (storage) ConditionalOp(cond, exprTrue, exprFalse);
  node->pool_ = &conditionalOpPool;
  node->TypeChecking();
  return node;
}
// Bring both branches to their common arithmetic type
// (ArithmType::MaxType), casting a branch when it differs, and return that
// type. Both branches must already have arithmetic type.
ArithmType* ConditionalOp::Convert() {
  auto trueArithm = exprTrue_->Type()->ToArithm();
  auto falseArithm = exprFalse_->Type()->ToArithm();
  assert(trueArithm && falseArithm);
  auto common = ArithmType::MaxType(trueArithm, falseArithm);
  // Pointer comparison is enough to detect a differing type
  if (trueArithm != common)
    exprTrue_ = UnaryOp::New(Token::CAST, exprTrue_, common);
  if (falseArithm != common)
    exprFalse_ = UnaryOp::New(Token::CAST, exprFalse_, common);
  return common;
}
// The condition must be scalar. Arithmetic branches are converted to their
// common type; any other pair must be compatible (or involve a void
// pointer) and the result takes the type of the true branch.
void ConditionalOp::TypeChecking() {
  const bool condIsScalar = cond_->Type()->IsScalar();
  if (!condIsScalar)
    Error(cond_->Tok(), "scalar is required");

  auto trueType = exprTrue_->Type();
  auto falseType = exprFalse_->Type();
  const bool bothArithm = trueType->ToArithm() && falseType->ToArithm();
  if (bothArithm) {
    type_ = Convert();
    return;
  }
  EnsureCompatibleOrVoidPointer(trueType, falseType);
  type_ = trueType;
}
/*
* Function Call
*/
// Build a function-call node from its pool and type-check it.
FuncCall* FuncCall::New(Expr* designator, const ArgList& args) {
  auto storage = funcCallPool.Alloc();
  auto node = new (storage) FuncCall(designator, args);
  node->pool_ = &funcCallPool;
  node->TypeChecking();
  return node;
}
// Type-check a call: resolve the designator to a function type, convert the
// arguments to the parameter types, check the argument count, and set the
// call's type to the function's derived (return) type.
// NOTE(review): the code after each Error(...) dereferences values that were
// just found invalid, so it appears to rely on Error not returning -- confirm.
void FuncCall::TypeChecking() {
auto pointerType = designator_->Type()->ToPointer();
if (pointerType) {
if (!pointerType->Derived()->ToFunc())
Error(designator_, "called object is not a function or function pointer");
// Convert function pointer to function type
designator_ = UnaryOp::New(Token::DEREF, designator_);
}
auto funcType = designator_->Type()->ToFunc();
if (!funcType) {
Error(designator_, "called object is not a function or function pointer");
} else if (!funcType->Derived()->ToVoid() &&
!funcType->Derived()->Complete()) {
// A non-void return type must be complete
Error(designator_, "invalid use of incomplete return type");
}
// Convert each supplied argument to the type of its parameter
auto arg = args_.begin();
for (auto param: funcType->Params()) {
if (arg == args_.end())
Error(this, "too few arguments for function call");
*arg = Expr::MayCast(*arg, param->Type());
++arg;
}
// Arguments beyond the parameter list are only allowed for variadic functions
if (arg != args_.end() && !funcType->Variadic())
Error(this, "too many arguments for function call");
// C11 6.5.2.2 [6]: promote float to double if it has no prototype
// (applied here to every remaining argument: floats of width 4 become double)
while (arg != args_.end()) {
if ((*arg)->Type()->IsFloat() && (*arg)->Type()->Width() == 4) {
auto type = ArithmType::New(T_DOUBLE);
*arg = UnaryOp::New(Token::CAST, *arg, type);
}
++arg;
}
type_ = funcType->Derived();
}
/*
* Identifier
*/
// Build an Identifier node from its pool.
Identifier* Identifier::New(const Token* tok,
                            QualType type,
                            enum Linkage linkage) {
  auto storage = identifierPool.Alloc();
  auto node = new (storage) Identifier(tok, type, linkage);
  node->pool_ = &identifierPool;
  return node;
}
// Build an Enumerator node with value `val` from its pool.
Enumerator* Enumerator::New(const Token* tok, int val) {
  auto storage = enumeratorPool.Alloc();
  auto node = new (storage) Enumerator(tok, val);
  node->pool_ = &enumeratorPool;
  return node;
}
// Build a Declaration node for `obj`. Declarations are allocated from the
// pool named initializationPool.
Declaration* Declaration::New(Object* obj) {
  auto storage = initializationPool.Alloc();
  auto node = new (storage) Declaration(obj);
  node->pool_ = &initializationPool;
  return node;
}
// Record an initializer after converting its expression to the initializer's
// type. An initializer already stored at an equivalent position is replaced
// by the new one.
void Declaration::AddInit(Initializer init) {
  init.expr_ = Expr::MayCast(init.expr_, init.type_);
  auto outcome = inits_.insert(init);
  if (outcome.second)
    return;
  // An equivalent entry exists: drop it and store the new initializer
  inits_.erase(outcome.first);
  inits_.insert(init);
}
/*
* Object
*/
// Build an Object node from its pool. Static and anonymous objects are
// additionally given an id drawn from a counter local to this function.
Object* Object::New(const Token* tok,
                    QualType type,
                    int storage,
                    enum Linkage linkage,
                    unsigned char bitFieldBegin,
                    unsigned char bitFieldWidth) {
  auto slot = objectPool.Alloc();
  auto obj = new (slot)
      Object(tok, type, storage, linkage, bitFieldBegin, bitFieldWidth);
  obj->pool_ = &objectPool;

  static long id = 0;
  const bool needsId = obj->IsStatic() || obj->Anonymous();
  if (needsId)
    obj->id_ = ++id;
  return obj;
}
// Build an Object node marked anonymous. The id comes from a counter local
// to this function (separate from the one in Object::New).
Object* Object::NewAnony(const Token* tok,
                         QualType type,
                         int storage,
                         enum Linkage linkage,
                         unsigned char bitFieldBegin,
                         unsigned char bitFieldWidth) {
  auto slot = objectPool.Alloc();
  auto obj = new (slot)
      Object(tok, type, storage, linkage, bitFieldBegin, bitFieldWidth);
  obj->pool_ = &objectPool;
  obj->anonymous_ = true;

  static long id = 0;
  const bool needsId = obj->IsStatic() || obj->anonymous_;
  if (needsId)
    obj->id_ = ++id;
  return obj;
}
/*
* Constant
*/
// Build an integer constant whose arithmetic type is selected by `tag`.
Constant* Constant::New(const Token* tok, int tag, long val) {
  auto arithm = ArithmType::New(tag);
  auto storage = constantPool.Alloc();
  auto node = new (storage) Constant(tok, arithm, val);
  node->pool_ = &constantPool;
  return node;
}
// Build a floating constant whose arithmetic type is selected by `tag`.
Constant* Constant::New(const Token* tok, int tag, double val) {
  auto arithm = ArithmType::New(tag);
  auto storage = constantPool.Alloc();
  auto node = new (storage) Constant(tok, arithm, val);
  node->pool_ = &constantPool;
  return node;
}
// Build a string constant. Its type is an array of the element type selected
// by `tag`, with length val->size() / element width. Every string constant
// gets an id from a counter local to this function.
Constant* Constant::New(const Token* tok, int tag, const std::string* val) {
  auto element = ArithmType::New(tag);
  auto arrayType = ArrayType::New(val->size() / element->Width(), element);
  auto storage = constantPool.Alloc();
  auto node = new (storage) Constant(tok, arrayType, val);
  node->pool_ = &constantPool;

  static long id = 0;
  node->id_ = ++id;
  return node;
}
// Render the string value as a sequence of "\xHH" escapes, one per byte
// (two hex digits taken from the low 8 bits of each character).
std::string Constant::SValRepr() const {
  const size_t count = sval_->size();
  // Four characters per byte, plus room for sprintf's terminating NUL
  std::vector<char> buf(4 * count + 1);
  char* cursor = buf.data();
  for (size_t i = 0; i < count; ++i, cursor += 4) {
    const int c = (*sval_)[i];
    const int high = (c >> 4) & 0xf;
    const int low = c & 0xf;
    sprintf(cursor, "\\x%1x%1x", high, low);
  }
  // Drop the trailing NUL slot
  return std::string(buf.begin(), buf.end() - 1);
}
/*
* TempVar
*/
// Build a temporary-variable node of the given type from its pool.
TempVar* TempVar::New(QualType type) {
  auto storage = tempVarPool.Alloc();
  auto node = new (storage) TempVar(type);
  node->pool_ = &tempVarPool;
  return node;
}
/*
* Statement
*/
// Build an empty statement node from its pool.
EmptyStmt* EmptyStmt::New() {
  auto storage = emptyStmtPool.Alloc();
  auto node = new (storage) EmptyStmt();
  node->pool_ = &emptyStmtPool;
  return node;
}
// Build an if statement node from its pool. `els` may be null when the
// statement has no else branch.
IfStmt* IfStmt::New(Expr* cond, Stmt* then, Stmt* els) {
  auto storage = ifStmtPool.Alloc();
  auto node = new (storage) IfStmt(cond, then, els);
  node->pool_ = &ifStmtPool;
  return node;
}
// Build a compound statement node for `stmts` and its scope from its pool.
CompoundStmt* CompoundStmt::New(std::list<Stmt*>& stmts, ::Scope* scope) {
  auto storage = compoundStmtPool.Alloc();
  auto node = new (storage) CompoundStmt(stmts, scope);
  node->pool_ = &compoundStmtPool;
  return node;
}
// Build a jump statement targeting `label` from its pool.
JumpStmt* JumpStmt::New(LabelStmt* label) {
  auto storage = jumpStmtPool.Alloc();
  auto node = new (storage) JumpStmt(label);
  node->pool_ = &jumpStmtPool;
  return node;
}
// Build a return statement node carrying `expr` from its pool.
ReturnStmt* ReturnStmt::New(Expr* expr) {
  auto storage = returnStmtPool.Alloc();
  auto node = new (storage) ReturnStmt(expr);
  node->pool_ = &returnStmtPool;
  return node;
}
// Build a label statement node from its pool.
LabelStmt* LabelStmt::New() {
  auto storage = labelStmtPool.Alloc();
  auto node = new (storage) LabelStmt();
  node->pool_ = &labelStmtPool;
  return node;
}
// Build a function-definition node for `ident` with its return label.
FuncDef* FuncDef::New(Identifier* ident, LabelStmt* retLabel) {
  auto storage = funcDefPool.Alloc();
  auto node = new (storage) FuncDef(ident, retLabel);
  node->pool_ = &funcDefPool;
  return node;
}
// Order initializers by offset first, then by the start of the bit field.
bool Initializer::operator<(const Initializer& rhs) const {
  const bool sameOffset = (offset_ == rhs.offset_);
  return offset_ < rhs.offset_ ||
         (sameOffset && bitFieldBegin_ < rhs.bitFieldBegin_);
}

1561
lib/lang/wgtcc/code_gen.cc Normal file

File diff suppressed because it is too large Load Diff

886
lib/lang/wgtcc/cpp.cc Normal file
View File

@@ -0,0 +1,886 @@
#include "triton/lang/wgtcc/cpp.h"
#include "triton/lang/wgtcc/evaluator.h"
#include "triton/lang/wgtcc/parser.h"
#include <ctime>
#include <fcntl.h>
#include <unistd.h>
#include <unordered_map>
extern std::string filename_in;
extern std::string filename_out;
// Maps the name that follows '#' to its Token::PP_* directive tag.
// Used by Preprocessor::GetDirective; names missing from this table are
// reported there as Token::PP_NONE.
using DirectiveMap = std::unordered_map<std::string, int>;
static const DirectiveMap directiveMap {
{"if", Token::PP_IF},
{"ifdef", Token::PP_IFDEF},
{"ifndef", Token::PP_IFNDEF},
{"elif", Token::PP_ELIF},
{"else", Token::PP_ELSE},
{"endif", Token::PP_ENDIF},
{"include", Token::PP_INCLUDE},
// Non-standard GNU extension; handled exactly like a plain "include"
{"include_next", Token::PP_INCLUDE},
{"define", Token::PP_DEFINE},
{"undef", Token::PP_UNDEF},
{"line", Token::PP_LINE},
{"error", Token::PP_ERROR},
{"pragma", Token::PP_PRAGMA}
};
/*
 * Expand macros and process directives in a token sequence.
 * params:
 *  is: input token sequence (taken by value and consumed)
 *  os: output token sequence; expanded tokens are appended to it
 *  inCond: true while expanding the expression of a conditional directive;
 *          enables the 'defined' operator and disables token discarding
 */
void Preprocessor::Expand(TokenSequence& os, TokenSequence is, bool inCond) {
Macro* macro = nullptr;
int direcitve;
while (!is.Empty()) {
UpdateFirstTokenLine(is);
auto tok = is.Peek();
const auto& name = tok->str_;
if ((direcitve = GetDirective(is)) != Token::INVALID) {
// A '#' at the beginning of a line: handle the directive
ParseDirective(os, is, direcitve);
} else if (!inCond && !NeedExpand()) {
// Discards the token
is.Next();
} else if (inCond && name == "defined") {
// 'defined X' / 'defined(X)' becomes a constant token
is.Next();
os.InsertBack(EvalDefOp(is));
} else if (tok->hs_ && tok->hs_->find(name) != tok->hs_->end()) {
// The name is in the token's own hide set: emit it without expansion
os.InsertBack(is.Next());
} else if ((macro = FindMacro(name))) {
is.Next();
if (name == "__FILE__") {
HandleTheFileMacro(os, tok);
} else if (name == "__LINE__") {
HandleTheLineMacro(os, tok);
} else if (macro->ObjLike()) {
// Object-like macro: substitute and push the result back onto the input
// so it is scanned again
// Make a copy, as subst will change repSeq
auto repSeq = macro->RepSeq(tok->loc_.filename_, tok->loc_.line_);
TokenList tokList;
TokenSequence repSeqSubsted(&tokList);
ParamMap paramMap;
// TODO(wgtdkp): hideset is not right
// Make a copy of hideset
// HS U {name}
auto hs = tok->hs_ ? *tok->hs_: HideSet();
hs.insert(name);
Subst(repSeqSubsted, repSeq, tok->ws_, hs, paramMap);
is.InsertFront(repSeqSubsted);
} else if (is.Try('(')) {
// Function-like macro followed by '(': collect the arguments, substitute,
// and push the result back onto the input
ParamMap paramMap;
auto rpar = ParseActualParam(is, macro, paramMap);
auto repSeq = macro->RepSeq(tok->loc_.filename_, tok->loc_.line_);
TokenList tokList;
TokenSequence repSeqSubsted(&tokList);
// (HS ^ HS') U {name}
// Use HS' U {name} directly
auto hs = rpar->hs_ ? *rpar->hs_: HideSet();
hs.insert(name);
Subst(repSeqSubsted, repSeq, tok->ws_, hs, paramMap);
is.InsertFront(repSeqSubsted);
} else {
// Function-like macro name not followed by '(': emit the name as is
os.InsertBack(tok);
}
} else {
// Ordinary token
os.InsertBack(is.Next());
}
}
}
// Look up the actual argument bound to formal parameter `fp`. When found it
// is copied into `ap` and true is returned; otherwise `ap` is left untouched
// and false is returned.
static bool FindActualParam(TokenSequence& ap,
                            ParamMap& params,
                            const std::string& fp) {
  auto entry = params.find(fp);
  if (entry != params.end()) {
    ap.Copy(entry->second);
    return true;
  }
  return false;
}
// Substitute macro arguments into a replacement list.
//  os:        receives the substituted tokens
//  is:        the macro's replacement sequence (consumed)
//  leadingWS: passed to FinalizeSubst together with the hide set
//  hs:        hide set handed to FinalizeSubst for the produced tokens
//  params:    formal parameter name -> actual argument tokens
void Preprocessor::Subst(TokenSequence& os,
TokenSequence is,
bool leadingWS,
const HideSet& hs,
ParamMap& params) {
TokenSequence ap;
while (!is.Empty()) {
if (is.Test('#') && FindActualParam(ap, params, is.Peek2()->str_)) {
// '#' param: stringize the actual argument
is.Next(); is.Next();
auto tok = Stringize(ap);
os.InsertBack(tok);
} else if (is.Test(Token::DSHARP) &&
FindActualParam(ap, params, is.Peek2()->str_)) {
// '##' param: glue the argument onto the last output token
// (nothing is glued when the argument is empty)
is.Next(); is.Next();
if (!ap.Empty())
Glue(os, ap);
} else if (is.Test(Token::DSHARP)) {
// '##' followed by an ordinary token
is.Next();
auto tok = is.Next();
Glue(os, tok);
} else if (is.Peek2()->tag_ == Token::DSHARP &&
FindActualParam(ap, params, is.Peek()->str_)) {
// param '##' ...: the argument is emitted unexpanded
is.Next();
if (ap.Empty()) {
// Empty left operand: skip the '##' and emit the right operand if it is
// a parameter
is.Next();
if (FindActualParam(ap, params, is.Peek()->str_)) {
is.Next();
os.InsertBack(ap);
}
} else {
os.InsertBack(ap);
}
} else if (FindActualParam(ap, params, is.Peek()->str_)) {
// Plain parameter: the argument's first token takes over the parameter
// token's whitespace flag, then the argument is macro-expanded
auto tok = is.Next();
const_cast<Token*>(ap.Peek())->ws_ = tok->ws_;
Expand(os, ap);
} else {
// Any other token is copied through
os.InsertBack(is.Peek());
is.Next();
}
}
os.FinalizeSubst(leadingWS, hs);
}
// Single-token convenience overload: wrap `tok` in a one-element sequence
// and glue that onto the last token of `os`.
void Preprocessor::Glue(TokenSequence& os, const Token* tok) {
  TokenList single {tok};
  TokenSequence wrapped(&single);
  Glue(os, wrapped);
}
// Implements '##': concatenate the spelling of the last token of `os` with
// the first token of `is`, re-scan the result, and replace the last token of
// `os` with the newly scanned token. The remaining tokens of `is` are then
// appended to `os`.
void Preprocessor::Glue(TokenSequence& os, TokenSequence is) {
auto lhs = os.Back();
auto rhs = is.Peek();
// NOTE(review): this string is never deleted in this function; presumably
// the scanned tokens keep referring to it -- confirm ownership.
auto str = new std::string(lhs->str_ + rhs->str_);
TokenSequence ts;
Scanner scanner(str, lhs->loc_);
scanner.Tokenize(ts);
is.Next();
if (ts.Empty()) {
// TODO(wgtdkp):
// No new Token generated
// How to handle it???
} else {
// Replace the left operand with the glued token, keeping the left
// operand's whitespace flag and hide set
os.PopBack();
auto newTok = const_cast<Token*>(ts.Next());
newTok->ws_ = lhs->ws_;
newTok->hs_ = lhs->hs_;
os.InsertBack(newTok);
}
// The concatenation must scan as exactly one token
if (!ts.Empty()) {
Error(lhs, "macro expansion failed: cannot concatenate");
}
os.InsertBack(is);
}
/*
 * Implements the '#' operator of function-like macros: build one string
 * literal token from a token sequence. Tokens are separated by a single
 * space where the source had whitespace, and '"' and '\' inside literals
 * and character constants are escaped.
 */
const Token* Preprocessor::Stringize(TokenSequence is) {
  std::string text = "\"";
  while (!is.Empty()) {
    auto tok = is.Next();
    // One space for preceding whitespace, except before the first token
    if (tok->ws_ && text.size() > 1)
      text.push_back(' ');

    const bool needsEscaping =
        tok->tag_ == Token::LITERAL || tok->tag_ == Token::C_CONSTANT;
    if (!needsEscaping) {
      text += tok->str_;
      continue;
    }
    for (auto ch: tok->str_) {
      if (ch == '"' || ch == '\\')
        text.push_back('\\');
      text.push_back(ch);
    }
  }
  text.push_back('\"');

  auto literal = Token::New(*is.Peek());
  literal->tag_ = Token::LITERAL;
  literal->str_ = text;
  return literal;
}
// Final pass over the preprocessed tokens: invalid tokens are reported,
// identifiers that are keywords get the keyword tag, and every other
// identifier has its spelling re-scanned with Scanner::ScanIdentifier.
void Preprocessor::Finalize(TokenSequence os) {
  while (!os.Empty()) {
    auto tok = os.Next();
    auto mutableTok = const_cast<Token*>(tok);

    if (tok->tag_ == Token::INVALID) {
      Error(tok, "stray token in program");
    } else if (tok->tag_ == Token::IDENTIFIER) {
      auto keywordTag = Token::KeyWordTag(tok->str_);
      if (Token::IsKeyWord(keywordTag))
        mutableTok->tag_ = keywordTag;
      else
        mutableTok->str_ = Scanner(tok).ScanIdentifier();
    }

    // When preprocessing a file, every token must carry a file name
    if (fName_ && !tok->loc_.filename_)
      assert(false);
  }
}
// TODO(wgtdkp): add predefined macros
// Preprocess the input into `os`: load the source (from the file name when
// one is set, otherwise from the in-memory source), expand it, then run the
// final token fix-up pass.
void Preprocessor::Process(TokenSequence& os) {
  TokenSequence input;
  if (!fName_)
    IncludeSrc(input, fSrc_, nullptr);
  else
    IncludeFile(input, fName_);

  Expand(os, input);
  Finalize(os);
}
// Collect the actual arguments of a function-like macro invocation. The
// opening '(' has already been consumed by the caller. Arguments are bound
// to the macro's formal parameters in `paramMap`; surplus arguments of a
// variadic macro are stored under "__VA_ARGS__". Returns the closing ')'.
const Token* Preprocessor::ParseActualParam(TokenSequence& is,
Macro* macro,
ParamMap& paramMap) {
const Token* ret;
// A macro without parameters must be followed directly by ')'
if (macro->Params().size() == 0 && !macro->Variadic()) {
ret = is.Next();
if (ret->tag_ != ')')
Error(ret, "too many arguments");
return ret;
}
auto fp = macro->Params().begin();
TokenSequence ap;
// Parenthesis nesting depth; starts at 1 for the consumed '('
int cnt = 1;
while (cnt > 0) {
if (is.Empty())
Error(is.Peek(), "premature end of input");
else if (is.Test('('))
++cnt;
else if (is.Test(')'))
--cnt;
// A top-level ',' or the closing ')' ends the current argument
if ((is.Test(',') && cnt == 1) || cnt == 0) {
if (fp == macro->Params().end()) {
// All named parameters are bound already
if (!macro->Variadic())
Error(is.Peek(), "too many arguments");
if (cnt == 0)
paramMap.insert(std::make_pair("__VA_ARGS__", ap));
else
// Commas between variadic arguments stay part of __VA_ARGS__
ap.InsertBack(is.Peek());
} else {
paramMap.insert(std::make_pair(*fp, ap));
ap = TokenSequence();
++fp;
}
} else {
ap.InsertBack(is.Peek());
}
ret = is.Next();
}
if (fp != macro->Params().end())
Error(is.Peek(), "too few params");
return ret;
}
// Evaluate 'defined NAME' or 'defined(NAME)' (the keyword itself has
// already been consumed). Returns an integer-constant token spelled "1" when
// NAME is a macro and "0" otherwise.
const Token* Preprocessor::EvalDefOp(TokenSequence& is) {
  const auto parenthesized = is.Try('(');
  auto nameTok = is.Expect(Token::IDENTIFIER);
  auto result = Token::New(*nameTok);
  if (parenthesized)
    is.Expect(')');

  result->tag_ = Token::I_CONSTANT;
  if (FindMacro(nameTok->str_))
    result->str_ = "1";
  else
    result->str_ = "0";
  return result;
}
// Replace every identifier left in `is` by the integer constant 0. Used on
// the expanded expression of '#if' before it is parsed.
void Preprocessor::ReplaceIdent(TokenSequence& is) {
  TokenSequence replaced;
  while (!is.Empty()) {
    auto tok = is.Next();
    if (tok->tag_ != Token::IDENTIFIER) {
      replaced.InsertBack(tok);
      continue;
    }
    auto zero = Token::New(*tok);
    zero->tag_ = Token::I_CONSTANT;
    zero->str_ = "0";
    replaced.InsertBack(zero);
  }
  is = replaced;
}
// Classify a possible directive at the front of `is`.
//  - Token::INVALID:  not a '#' at the beginning of a line (nothing consumed)
//  - Token::PP_EMPTY: a '#' alone on its line
//  - Token::PP_NONE:  '#' followed by something that is not a known directive
//  - otherwise the Token::PP_* tag found in directiveMap
// In every case but the first the '#' is consumed; the directive name is not.
int Preprocessor::GetDirective(TokenSequence& is) {
  const bool startsDirective = is.Test('#') && is.IsBeginOfLine();
  if (!startsDirective)
    return Token::INVALID;

  is.Next();
  if (is.IsBeginOfLine())
    return Token::PP_EMPTY;

  auto tag = is.Peek()->tag_;
  const bool isName = (tag == Token::IDENTIFIER || Token::IsKeyWord(tag));
  if (!isName)
    return Token::PP_NONE;

  auto name = is.Peek()->str_;
  auto entry = directiveMap.find(name);
  if (entry != directiveMap.end())
    return entry->second;
  return Token::PP_NONE;
}
// Handle one directive line. Conditional directives are always processed so
// the conditional stack stays balanced; all other directives only take
// effect while NeedExpand() is true. Unknown directives are ignored.
void Preprocessor::ParseDirective(TokenSequence& os,
                                  TokenSequence& is,
                                  int directive) {
  if (directive == Token::PP_EMPTY)
    return;

  // The rest of the directive's line
  auto line = is.GetLine();
  switch (directive) {
  // Conditionals
  case Token::PP_IF:     ParseIf(line);     break;
  case Token::PP_IFDEF:  ParseIfdef(line);  break;
  case Token::PP_IFNDEF: ParseIfndef(line); break;
  case Token::PP_ELIF:   ParseElif(line);   break;
  case Token::PP_ELSE:   ParseElse(line);   break;
  case Token::PP_ENDIF:  ParseEndif(line);  break;

  // Directives that are skipped inside a disabled region
  case Token::PP_INCLUDE:
    if (NeedExpand()) ParseInclude(is, line);
    break;
  case Token::PP_DEFINE:
    if (NeedExpand()) ParseDef(line);
    break;
  case Token::PP_UNDEF:
    if (NeedExpand()) ParseUndef(line);
    break;
  case Token::PP_LINE:
    if (NeedExpand()) ParseLine(line);
    break;
  case Token::PP_ERROR:
    if (NeedExpand()) ParseError(line);
    break;
  case Token::PP_PRAGMA:
    if (NeedExpand()) ParsePragma(line);
    break;

  case Token::PP_NONE:
    break;
  default:
    assert(false);
  }
}
// '#pragma' is not implemented: only the directive token of the local copy
// of the line is consumed.
void Preprocessor::ParsePragma(TokenSequence ls) {
  // TODO(wgtdkp):
  static_cast<void>(ls.Next());
}
// '#error': stringize the rest of the line, scan that literal back into
// plain text, and report it as an error.
void Preprocessor::ParseError(TokenSequence ls) {
  ls.Next();  // Skip the directive
  const Token* literal = Stringize(ls);
  std::string message;
  Scanner(literal).ScanLiteral(message);
  Error(ls.Peek(), "%s", message.c_str());
}
// '#line N ["file"]': after macro expansion the first token must be a
// non-zero decimal integer constant; it becomes the current line number. An
// optional following token must be a double-quoted string literal (it is
// validated only).
void Preprocessor::ParseLine(TokenSequence ls) {
  auto directive = ls.Next();  // Skip directive 'line'
  TokenSequence expanded;
  Expand(expanded, ls);

  auto tok = expanded.Expect(Token::I_CONSTANT);
  int lineNo = 0;
  size_t consumed = 0;
  try {
    lineNo = stoi(tok->str_, &consumed, 10);
  } catch (const std::out_of_range& oor) {
    Error(tok, "line number out of range");
  }
  // Zero, or trailing characters that are not decimal digits, is illegal
  const bool wholeTokenParsed = (consumed == tok->str_.size());
  if (lineNo == 0 || !wholeTokenParsed)
    Error(tok, "illegal line number");

  curLine_ = lineNo;
  lineLine_ = directive->loc_.line_;
  if (expanded.Empty())
    return;

  tok = expanded.Expect(Token::LITERAL);
  // Ensure "s-char-sequence"
  const bool quoted = tok->str_.front() == '"' && tok->str_.back() == '"';
  if (!quoted)
    Error(tok, "expect s-char-sequence");
}
// Handle '#if <expr>'.
// Inside a disabled group only a placeholder entry is pushed. Otherwise the
// expression is macro-expanded, parsed and evaluated (as double when its
// type is floating, as long otherwise) and the result is pushed on the
// conditional stack.
void Preprocessor::ParseIf(TokenSequence ls) {
  if (!NeedExpand()) {
    ppCondStack_.push({Token::PP_IF, false, false});
    return;
  }

  auto directive = ls.Next(); // Skip the directive
  if (ls.Empty()) {
    Error(directive, "expect expression in 'if' directive");
  }

  TokenSequence expanded;
  Expand(expanded, ls, true);
  ReplaceIdent(expanded);

  Parser parser(expanded);
  auto expr = parser.ParseExpr();
  if (!parser.ts().Empty()) {
    Error(parser.ts().Peek(), "unexpected extra expression");
  }

  const bool cond = expr->Type()->IsFloat()
      ? static_cast<bool>(Evaluator<double>().Eval(expr))
      : static_cast<bool>(Evaluator<long>().Eval(expr));
  ppCondStack_.push({Token::PP_IF, NeedExpand(), cond});
}
// Handle '#ifdef IDENT': push whether IDENT currently names a macro.
// Inside a disabled group only a placeholder entry is pushed.
void Preprocessor::ParseIfdef(TokenSequence ls) {
  if (!NeedExpand()) {
    ppCondStack_.push({Token::PP_IFDEF, false, false});
    return;
  }

  ls.Next();  // Skip the directive
  auto ident = ls.Expect(Token::IDENTIFIER);
  if (!ls.Empty()) {
    Error(ls.Peek(), "expect new line");
  }

  const bool isDefined = FindMacro(ident->str_) != nullptr;
  ppCondStack_.push({Token::PP_IFDEF, NeedExpand(), isDefined});
}
// Handle '#ifndef IDENT' by reusing ParseIfdef and then retagging the entry
// it pushed and inverting its condition.
void Preprocessor::ParseIfndef(TokenSequence ls) {
  ParseIfdef(ls);
  auto& entry = ppCondStack_.top();
  entry.tag_ = Token::PP_IFNDEF;
  entry.cond_ = !entry.cond_;
}
// Handle '#elif <expr>'.
// Entries that neither open the group (#if/#ifdef/#ifndef) nor carry a true
// condition are popped, so that `top` describes either the opening directive
// or an earlier branch that was already taken. When the enclosing group is
// disabled only a placeholder is pushed; otherwise the expression is
// evaluated and the branch is taken only if `top` had a false condition.
void Preprocessor::ParseElif(TokenSequence ls) {
  auto directive = ls.Next(); // Skip the directive

  if (ppCondStack_.empty())
    Error(directive, "unexpected 'elif' directive");
  auto top = ppCondStack_.top();
  if (top.tag_ == Token::PP_ELSE)
    Error(directive, "unexpected 'elif' directive");

  while (!ppCondStack_.empty()) {
    top = ppCondStack_.top();
    if (top.tag_ == Token::PP_IF ||
        top.tag_ == Token::PP_IFDEF ||
        top.tag_ == Token::PP_IFNDEF ||
        top.cond_) {
      break;
    }
    ppCondStack_.pop();
  }
  if (ppCondStack_.empty())
    Error(directive, "unexpected 'elif' directive");

  auto enabled = top.enabled_;
  if (!enabled) {
    ppCondStack_.push({Token::PP_ELIF, false, false});
    return;
  }

  if (ls.Empty()) {
    Error(ls.Peek(), "expect expression in 'elif' directive");
  }

  TokenSequence ts;
  Expand(ts, ls, true);
  ReplaceIdent(ts);

  Parser parser(ts);
  auto expr = parser.ParseExpr();
  if (!parser.ts().Empty()) {
    Error(parser.ts().Peek(), "unexpected extra expression");
  }

  // Evaluate exactly once; the value is not echoed to stdout.
  bool cond;
  if (expr->Type()->IsFloat()) {
    cond = static_cast<bool>(Evaluator<double>().Eval(expr));
  } else {
    cond = static_cast<bool>(Evaluator<long>().Eval(expr));
  }
  cond = cond && !top.cond_;
  ppCondStack_.push({Token::PP_ELIF, true, cond});
}
// Handle '#else'.
// Pops entries until one is found that either opens the group
// (#if/#ifdef/#ifndef) or has a true condition, then pushes an #else entry
// whose condition is the negation of that entry's condition.
void Preprocessor::ParseElse(TokenSequence ls) {
  auto directive = ls.Next();
  if (!ls.Empty())
    Error(ls.Peek(), "expect new line");

  if (ppCondStack_.empty())
    Error(directive, "unexpected 'else' directive");
  auto top = ppCondStack_.top();
  if (top.tag_ == Token::PP_ELSE)
    Error(directive, "unexpected 'else' directive");

  for (; !ppCondStack_.empty(); ppCondStack_.pop()) {
    top = ppCondStack_.top();
    const bool opensGroup = top.tag_ == Token::PP_IF
                         || top.tag_ == Token::PP_IFDEF
                         || top.tag_ == Token::PP_IFNDEF;
    if (opensGroup || top.cond_)
      break;
  }
  if (ppCondStack_.empty())
    Error(directive, "unexpected 'else' directive");

  const auto enabled = top.enabled_;
  const auto cond = !top.cond_;
  ppCondStack_.push({Token::PP_ELSE, enabled, cond});
}
// Handle '#endif': pop entries up to and including the directive that
// opened the group. Reaching an empty stack without finding one is an error.
void Preprocessor::ParseEndif(TokenSequence ls) {
  auto directive = ls.Next();
  if (!ls.Empty())
    Error(ls.Peek(), "expect new line");

  while (!ppCondStack_.empty()) {
    const auto tag = ppCondStack_.top().tag_;
    ppCondStack_.pop();
    const bool opensGroup = tag == Token::PP_IF
                         || tag == Token::PP_IFDEF
                         || tag == Token::PP_IFNDEF;
    if (opensGroup)
      return;
  }

  if (ppCondStack_.empty())
    Error(directive, "unexpected 'endif' directive");
}
// Have read the '#'.
// Handle '#include' / '#include_next'. The operand is either a string
// literal or a '<...>' header name; anything else is macro-expanded first.
// The resolved file is tokenized into `is` by IncludeFile().
void Preprocessor::ParseInclude(TokenSequence& is, TokenSequence ls) {
  const bool next = ls.Next()->str_ == "include_next"; // Skip 'include'
  if (!ls.Test(Token::LITERAL) && !ls.Test('<')) {
    TokenSequence expanded;
    Expand(expanded, ls, true);
    ls = expanded;
  }

  auto tok = ls.Next();
  if (tok->tag_ == Token::LITERAL) {
    // "file" form
    if (!ls.Empty()) {
      Error(ls.Peek(), "expect new line");
    }
    std::string filename;
    Scanner(tok).ScanLiteral(filename);
    auto fullPath = SearchFile(filename, false, next, *tok->loc_.filename_);
    if (fullPath == nullptr)
      Error(tok, "%s: No such file or directory", filename.c_str());
    IncludeFile(is, fullPath);
  } else if (tok->tag_ == '<') {
    // <file> form: find the matching '>' while tracking nesting depth.
    auto lAngle = tok;
    auto rAngle = tok;
    int depth = 1;
    for (rAngle = ls.Next(); !rAngle->IsEOF(); rAngle = ls.Next()) {
      if (rAngle->tag_ == '<')
        ++depth;
      else if (rAngle->tag_ == '>')
        --depth;
      if (depth == 0)
        break;
    }
    if (depth != 0)
      Error(rAngle, "expect '>'");
    if (!ls.Empty())
      Error(ls.Peek(), "expect new line");

    const auto& filename = Scanner::ScanHeadName(lAngle, rAngle);
    auto fullPath = SearchFile(filename, true, next, *tok->loc_.filename_);
    if (fullPath == nullptr) {
      Error(tok, "%s: No such file or directory", filename.c_str());
    }
    IncludeFile(is, fullPath);
  } else {
    Error(tok, "expect filename(string or in '<>')");
  }
}
// Handle '#undef IDENT': remove the macro named IDENT. Nothing else may
// follow the identifier on the line.
void Preprocessor::ParseUndef(TokenSequence ls) {
  ls.Next(); // Skip directive
  auto name = ls.Expect(Token::IDENTIFIER);
  if (!ls.Empty()) {
    Error(ls.Peek(), "expect new line");
  }
  RemoveMacro(name->str_);
}
// Handle '#define'. A '(' that follows the name with no white space makes
// the macro function-like; otherwise the rest of the line is the
// replacement list of an object-like macro.
void Preprocessor::ParseDef(TokenSequence ls) {
  ls.Next();  // Skip the directive
  auto ident = ls.Expect(Token::IDENTIFIER);
  if (ident->str_ == "defined") {
    Error(ident, "'defined' cannot be used as a macro name");
  }

  auto tok = ls.Peek();
  const bool functionLike = tok->tag_ == '(' && !tok->ws_;
  if (!functionLike) {
    AddMacro(ident->str_, Macro(ls));
    return;
  }

  // Function-like macro: parse the identifier list first.
  ls.Next(); // Skip '('
  ParamList params;
  auto variadic = ParseIdentList(params, ls);
  const auto& macro = Macro(variadic, params, ls);
  AddMacro(ident->str_, macro);
}
// Parse the parameter list of a function-like macro; the opening '(' has
// already been consumed. Parameter names are appended to `params`.
// Returns true when the list ends with '...', false otherwise.
bool Preprocessor::ParseIdentList(ParamList& params, TokenSequence& is) {
  const Token* tok = is.Peek();
  while (!is.Empty()) {
    tok = is.Next();
    if (tok->tag_ == ')')
      return false;
    if (tok->tag_ == Token::ELLIPSIS) {
      is.Expect(')');
      return true;
    }
    if (tok->tag_ != Token::IDENTIFIER)
      Error(tok, "expect identifier");

    // Reject a name that was already declared.
    for (const auto& known: params) {
      if (known == tok->str_)
        Error(tok, "duplicated param");
    }
    params.push_back(tok->str_);

    if (!is.Try(',')) {
      is.Expect(')');
      return false;
    }
  }
  Error(tok, "unexpected end of line");
}
// Tokenize `text` (attributed to `filename`) into a sequence that shares
// is's token list and starts out empty at is.begin_, then move is.begin_ to
// that sequence's begin.
void Preprocessor::IncludeSrc(TokenSequence& is,
                              const std::string* text,
                              const std::string* filename) {
  TokenSequence included {is.tokList_, is.begin_, is.begin_};
  Scanner(text, filename).Tokenize(included);
  // We are done including the header file
  is.begin_ = included.begin_;
}
// Read the file named by `filename` and include its contents into `is`.
void Preprocessor::IncludeFile(TokenSequence& is,
                               const std::string* filename) {
  auto text = ReadFile(*filename);
  IncludeSrc(is, text, filename);
}
// Return the directory part of `path`, including the trailing '/'.
// A path without any '/' yields "./".
static std::string GetDir(const std::string& path) {
  const auto slash = path.find_last_of('/');
  return slash == std::string::npos ? std::string("./")
                                    : path.substr(0, slash + 1);
}
// Locate `name` in the search paths.
//   libHeader: true for the '<...>' form; the including file's directory is
//              then tried last instead of first (unless `next` is set).
//   next:      '#include_next' behaviour: candidates are skipped until the
//              one equal to `curPath` has been seen.
//   curPath:   path of the including file; it is never returned itself.
// Returns a newly allocated path string, or nullptr when nothing matches.
// The including file's directory is added to searchPaths_ only for the
// duration of the call.
std::string* Preprocessor::SearchFile(const std::string& name,
                                      const bool libHeader,
                                      bool next,
                                      const std::string& curPath) {
  // Remember which end received the temporary entry: `next` is modified
  // below, so re-deriving the side later could pop the wrong end.
  const bool pushedBack = libHeader && !next;
  if (pushedBack) {
    searchPaths_.push_back(GetDir(curPath));
  } else {
    searchPaths_.push_front(GetDir(curPath));
  }

  std::string* found = nullptr;
  for (auto iter = searchPaths_.begin(); iter != searchPaths_.end(); ++iter) {
    auto dd = open(iter->c_str(), O_RDONLY);
    if (dd == -1) // TODO(wgtdkp): or ensure it before preprocessing
      continue;
    auto fd = openat(dd, name.c_str(), O_RDONLY);
    const int openErr = errno;  // capture before close() can touch errno
    close(dd);
    if (fd != -1) {
      // Intentional, so that recursive include
      // will result in running out of file descriptor
      //close(fd);
      auto path = *iter + name;
      if (next) {
        if (path != curPath)
          continue;
        next = false;
      } else {
        if (path == curPath)
          continue;
        found = new std::string(path);
        break;
      }
    } else if (openErr == EMFILE) {
      Error("may recursive include");
    }
  }

  // Drop the temporary entry on every exit path, found or not.
  if (pushedBack)
    searchPaths_.pop_back();
  else
    searchPaths_.pop_front();
  return found;
}
void Preprocessor::AddMacro(const std::string& name,
std::string* text,
bool preDef) {
TokenSequence ts;
Scanner scanner(text);
scanner.Tokenize(ts);
Macro macro(ts, preDef);
AddMacro(name, macro);
}
// Build the replacement text for __DATE__: a quoted string of the form
// "Mmm dd yyyy" (abbreviated month name, space-padded day, year).
// Returns a newly allocated string. If the local time cannot be obtained
// or formatted, the placeholder "??? ?? ????" is used instead.
static std::string* Date() {
  time_t now = time(nullptr);
  struct tm* tm = localtime(&now);
  // Generous buffer: strftime() returns 0 and leaves the contents
  // unspecified when the result does not fit.
  char buf[32];
  if (tm == nullptr ||
      strftime(buf, sizeof buf, "\"%b %e %Y\"", tm) == 0) {
    return new std::string("\"??? ?? ????\"");
  }
  return new std::string(buf);
}
// Install the default include search paths and the predefined macros.
void Preprocessor::Init() {
  // Preinclude search paths
  AddSearchPath("/usr/local/include/");
  AddSearchPath("/usr/include/x86_64-linux-gnu/");
  AddSearchPath("/usr/include/linux/");
  AddSearchPath("/usr/include/");
  AddSearchPath("/usr/local/wgtcc/include/");

  // The __FILE__ and __LINE__ macros are empty here;
  // they are handled separately
  AddMacro("__FILE__", Macro(TokenSequence(), true));
  AddMacro("__LINE__", Macro(TokenSequence(), true));

  AddMacro("__DATE__", Date(), true);
  AddMacro("__STDC__", new std::string("1"), true);
  // The standard spelling is __STDC_HOSTED__ (single underscore after STDC).
  AddMacro("__STDC_HOSTED__", new std::string("0"), true);
  // 201112L is the value C11 specifies for __STDC_VERSION__.
  AddMacro("__STDC_VERSION__", new std::string("201112L"), true);
}
// Expand __FILE__: append to `os` a copy of the macro token turned into a
// string literal holding the token's file name in double quotes.
void Preprocessor::HandleTheFileMacro(TokenSequence& os, const Token* macro) {
  auto literal = Token::New(*macro);
  literal->tag_ = Token::LITERAL;
  literal->str_ = "\"" + *macro->loc_.filename_ + "\"";
  os.InsertBack(literal);
}
// Expand __LINE__: append to `os` a copy of the macro token turned into an
// integer constant spelling the token's line number.
void Preprocessor::HandleTheLineMacro(TokenSequence& os, const Token* macro) {
  auto number = Token::New(*macro);
  number->tag_ = Token::I_CONSTANT;
  number->str_ = std::to_string(macro->loc_.line_);
  os.InsertBack(number);
}
// Rewrite the line number of the first token of `ts` from curLine_ and
// lineLine_, the two values recorded by ParseLine().
void Preprocessor::UpdateFirstTokenLine(TokenSequence ts) {
  auto adjusted = ts.Peek()->loc_;
  adjusted.line_ = curLine_ + adjusted.line_ - lineLine_ - 1;
  ts.UpdateHeadLocation(adjusted);
}
// Return a copy of the macro's replacement sequence in which every token's
// location is rewritten to the given file name and line.
TokenSequence Macro::RepSeq(const std::string* filename, unsigned line) {
  TokenList tl;
  TokenSequence ret(&tl);
  ret.Copy(repSeq_);

  // Walk a second cursor over the copy, patching each head token in turn.
  for (auto cursor = ret; !cursor.Empty(); cursor.Next()) {
    auto loc = cursor.Peek()->loc_;
    loc.filename_ = filename;
    loc.line_ = line;
    cursor.UpdateHeadLocation(loc);
  }
  return ret;
}
// Register an include directory at the front of the search path list.
// The path is given a trailing '/', and a "./" prefix when it does not
// start with '/'.
void Preprocessor::AddSearchPath(std::string path) {
  if (path.back() != '/') {
    path.append("/");
  }
  if (path[0] != '/') {
    path.insert(0, "./");
  }
  searchPaths_.push_front(path);
}

View File

@@ -0,0 +1,42 @@
#include "triton/lang/wgtcc/encoding.h"
#include <climits>
#include <codecvt>
#include <locale>
#include <iostream>
// Append the 16-bit code unit `c` to `str` as two bytes, low byte first.
static void Append16LE(std::string& str, char16_t c) {
  const int low = c & UCHAR_MAX;
  const int high = (c >> 8) & UCHAR_MAX;
  str.push_back(low);
  str.push_back(high);
}
// Append the 32-bit code unit `c` to `str` as four bytes, least significant
// byte first, by emitting its low and then its high 16-bit half.
static void Append32LE(std::string& str, char32_t c) {
  const auto lowHalf = c & USHRT_MAX;
  const auto highHalf = (c >> 16) & USHRT_MAX;
  Append16LE(str, lowHalf);
  Append16LE(str, highHalf);
}
// Replace the UTF-8 contents of `str`, in place, with the bytes of the
// converted 16-bit code units, each written low byte first.
void ConvertToUTF16(std::string& str) {
  std::wstring_convert<std::codecvt_utf8<char16_t>, char16_t> converter;
  const auto units = converter.from_bytes(str);
  str.clear();
  for (const auto unit: units) {
    Append16LE(str, unit);
  }
}
// Replace the UTF-8 contents of `str`, in place, with the bytes of the
// converted 32-bit code units, each written least significant byte first.
void ConvertToUTF32(std::string& str) {
  std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
  const auto units = converter.from_bytes(str);
  str.clear();
  for (const auto unit: units) {
    Append32LE(str, unit);
  }
}
// Append the UTF-8 encoding of the code point `c` to `str`.
void AppendUCN(std::string& str, int c) {
  std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
  const char32_t codePoint = static_cast<char32_t>(c);
  str.append(converter.to_bytes(codePoint));
}

95
lib/lang/wgtcc/error.cc Normal file
View File

@@ -0,0 +1,95 @@
#include "triton/lang/wgtcc/error.h"
#include "triton/lang/wgtcc/ast.h"
#include "triton/lang/wgtcc/token.h"
#include <cstdarg>
#include <cstdio>
#include <cstring>
#include <string>
// ANSI terminal escape sequences used to colour diagnostics written to
// stderr. ANSI_COLOR_RESET restores the terminal's default attributes.
#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_GREEN "\x1b[32m"
#define ANSI_COLOR_YELLOW "\x1b[33m"
#define ANSI_COLOR_BLUE "\x1b[34m"
#define ANSI_COLOR_MAGENTA "\x1b[35m"
#define ANSI_COLOR_CYAN "\x1b[36m"
#define ANSI_COLOR_RESET "\x1b[0m"
// Program name printed as the prefix of location-less errors; defined in
// the driver (main.cc), which sets it from argv[0].
extern std::string program;
// Report a fatal error that has no source location: prints
// "<program>: error: <message>" to stderr and terminates with status -1.
// `format` and the variadic arguments follow printf conventions.
void Error(const char* format, ...) {
  fprintf(stderr,
          "%s: " ANSI_COLOR_RED "error: " ANSI_COLOR_RESET,
          program.c_str());

  va_list ap;
  va_start(ap, format);
  vfprintf(stderr, format, ap);
  va_end(ap);

  fprintf(stderr, "\n");
  exit(-1);
}
// Print a located diagnostic and terminate the process with status -1.
// Output is "<file>:<line>:<column>: error: <message>", followed by the
// offending source line with its leading blanks removed and a caret under
// the reported column.
//   loc:    location of the error; a null filename_ is printed as
//           "<unknown>" and a null lineBegin_ suppresses the source line.
//   format: printf-style format string.
//   args:   arguments for `format`, already started by the caller.
[[noreturn]]
static void VError(const SourceLocation& loc,
                   const char* format,
                   va_list args) {
  // Passing a null pointer to %s is undefined, so substitute a placeholder.
  const char* filename = "<unknown>";
  if (loc.filename_)
    filename = loc.filename_->c_str();
  fprintf(stderr,
          "%s:%d:%d: " ANSI_COLOR_RED "error: " ANSI_COLOR_RESET,
          filename,
          loc.line_,
          loc.column_);
  vfprintf(stderr, format, args);
  fprintf(stderr, "\n ");

  // Echo the source line, counting (and dropping) its leading blanks so
  // that the caret below can be aligned with what was printed.
  bool sawNoSpace = false;
  int nspaces = 0;
  if (loc.lineBegin_) {
    for (auto p = loc.lineBegin_; *p != '\n' && *p != 0; p++) {
      if (!sawNoSpace && (*p == ' ' || *p == '\t')) {
        ++nspaces;
      } else {
        sawNoSpace = true;
        fputc(*p, stderr);
      }
    }
  }

  fprintf(stderr, "\n ");
  for (unsigned i = 1; i + nspaces < loc.column_; ++i)
    fputc(' ', stderr);
  // Reset the colour so the terminal is not left green after the caret.
  fprintf(stderr, ANSI_COLOR_GREEN "^" ANSI_COLOR_RESET "\n");
  exit(-1);
}
// Report a fatal error at source location `loc`; `format` and the variadic
// arguments follow printf conventions. Forwards to VError().
void Error(const SourceLocation& loc, const char* format, ...) {
  va_list ap;
  va_start(ap, format);
  VError(loc, format, ap);
  va_end(ap);
}
// Report a fatal error at the location of token `tok`; `format` and the
// variadic arguments follow printf conventions. Forwards to VError().
void Error(const Token* tok, const char* format, ...) {
  va_list ap;
  va_start(ap, format);
  VError(tok->loc_, format, ap);
  va_end(ap);
}
// Report a fatal error at the location of the token attached to `expr`;
// `format` and the variadic arguments follow printf conventions.
void Error(const Expr* expr, const char* format, ...) {
  va_list ap;
  va_start(ap, format);
  VError(expr->Tok()->loc_, format, ap);
  va_end(ap);
}

210
lib/lang/wgtcc/evaluator.cc Normal file
View File

@@ -0,0 +1,210 @@
#include "triton/lang/wgtcc/evaluator.h"
#include "triton/lang/wgtcc/ast.h"
#include "triton/lang/wgtcc/code_gen.h"
#include "triton/lang/wgtcc/token.h"
// Constant-fold a binary expression into val_.
// Pointer-typed expressions are delegated to Evaluator<Addr> and must
// resolve to a label-free address. L/R evaluate the operands as T; LL/LR
// evaluate them as long for the operators that only accept integers.
template<typename T>
void Evaluator<T>::VisitBinaryOp(BinaryOp* binary) {
#define L   Evaluator<T>().Eval(binary->lhs_)
#define R   Evaluator<T>().Eval(binary->rhs_)
#define LL  Evaluator<long>().Eval(binary->lhs_)
#define LR  Evaluator<long>().Eval(binary->rhs_)

  if (binary->Type()->ToPointer()) {
    auto val = Evaluator<Addr>().Eval(binary);
    if (val.label_.size()) {
      Error(binary, "expect constant integer expression");
    }
    val_ = static_cast<T>(val.offset_);
    return;
  }

  switch (binary->op_) {
  case '+': val_ = L + R; break;
  case '-': val_ = L - R; break;
  case '*': val_ = L * R; break;
  case '/': {
    auto l = L, r = R;
    if (r == 0)
      Error(binary, "division by zero");
    val_ = l / r;
  } break;
  case '%': {
    auto l = LL, r = LR;
    if (r == 0)
      Error(binary, "division by zero");
    val_ = l % r;
  } break;
  // Bitwise operators that do not accept float
  case '|': val_ = LL | LR; break;
  case '&': val_ = LL & LR; break;
  case '^': val_ = LL ^ LR; break;
  case Token::LEFT: val_ = LL << LR; break;
  case Token::RIGHT: val_ = LL >> LR; break;

  case '<': val_ = L < R; break;
  case '>': val_ = L > R; break;
  case Token::LOGICAL_AND: val_ = L && R; break;
  case Token::LOGICAL_OR: val_ = L || R; break;
  case Token::EQ: val_ = L == R; break;
  case Token::NE: val_ = L != R; break;
  case Token::LE: val_ = L <= R; break;
  case Token::GE: val_ = L >= R; break;
  case '=': case ',': val_ = R; break;
  case '.': {
    auto addr = Evaluator<Addr>().Eval(binary);
    if (addr.label_.size())
      Error(binary, "expect constant expression");
    val_ = addr.offset_;
  } break;  // must not fall through into the assert below
  default: assert(false);
  }

#undef L
#undef R
#undef LL
#undef LR
}
// Constant-fold a unary expression into val_. Any operator that is not
// handled here is reported as a non-constant expression.
template<typename T>
void Evaluator<T>::VisitUnaryOp(UnaryOp* unary) {
  // The operand is evaluated on demand: as T, or as long for '~'.
  const auto asT = [unary] { return Evaluator<T>().Eval(unary->operand_); };
  const auto asLong = [unary] {
    return Evaluator<long>().Eval(unary->operand_);
  };

  switch (unary->op_) {
  case Token::PLUS:
    val_ = asT();
    break;
  case Token::MINUS:
    val_ = -asT();
    break;
  case '~':
    val_ = ~asLong();
    break;
  case '!':
    val_ = !asT();
    break;
  case Token::CAST:
    // A cast to an integer type goes through long.
    if (unary->Type()->IsInteger())
      val_ = static_cast<long>(asT());
    else
      val_ = asT();
    break;
  case Token::ADDR: {
    auto addr = Evaluator<Addr>().Eval(unary->operand_);
    if (addr.label_.size())
      Error(unary, "expect constant expression");
    val_ = addr.offset_;
  } break;
  default:
    Error(unary, "expect constant expression");
  }
}
// Constant-fold 'cond ? a : b' into val_. The condition is evaluated
// according to its own type (integer, floating or pointer); only the
// selected branch is evaluated.
template<typename T>
void Evaluator<T>::VisitConditionalOp(ConditionalOp* condOp) {
  bool cond;
  auto condType = condOp->cond_->Type();
  if (condType->IsInteger()) {
    cond = Evaluator<long>().Eval(condOp->cond_) != 0;
  } else if (condType->IsFloat()) {
    cond = Evaluator<double>().Eval(condOp->cond_) != 0.0;
  } else if (condType->ToPointer()) {
    // An address is true when it has a label or a non-zero offset.
    auto addr = Evaluator<Addr>().Eval(condOp->cond_);
    cond = addr.label_.size() || addr.offset_;
  } else {
    assert(false);
  }

  if (cond)
    val_ = Evaluator<T>().Eval(condOp->exprTrue_);
  else
    val_ = Evaluator<T>().Eval(condOp->exprFalse_);
}
// Fold an address-valued binary expression into addr_ (label + offset).
// '+' and '-' apply an integer right operand scaled by the width of the
// pointed-to type; '.' adds the offset of the named struct member.
void Evaluator<Addr>::VisitBinaryOp(BinaryOp* binary) {
#define LR   Evaluator<long>().Eval(binary->rhs_)

  auto l = Evaluator<Addr>().Eval(binary->lhs_);

  int width = 1;
  auto pointerType = binary->Type()->ToPointer();
  if (pointerType)
    width = pointerType->Derived()->Width();

  switch (binary->op_) {
  case '+':
    assert(pointerType);
    addr_.label_ = l.label_;
    addr_.offset_ = l.offset_ + LR * width;
    break;
  case '-':
    assert(pointerType);
    addr_.label_ = l.label_;
    // Subtracting an integer moves the address backwards.
    addr_.offset_ = l.offset_ - LR * width;
    break;
  case '.': {
    addr_.label_ = l.label_;
    auto type = binary->lhs_->Type()->ToStruct();
    auto offset = type->GetMember(binary->rhs_->tok_->str_)->Offset();
    addr_.offset_ = l.offset_ + offset;
    break;
  }
  default: assert(false);
  }

#undef LR
}
// Fold an address-valued unary expression: cast, address-of and dereference
// all yield the operand's address unchanged; any other operator asserts.
void Evaluator<Addr>::VisitUnaryOp(UnaryOp* unary) {
  auto operandAddr = Evaluator<Addr>().Eval(unary->operand_);

  const auto op = unary->op_;
  if (op == Token::CAST || op == Token::ADDR || op == Token::DEREF) {
    addr_ = operandAddr;
  } else {
    assert(false);
  }
}
// Fold 'cond ? a : b' into addr_. The condition is evaluated according to
// its own type (integer, floating or pointer); only the selected branch is
// evaluated.
void Evaluator<Addr>::VisitConditionalOp(ConditionalOp* condOp) {
  bool cond;
  auto condType = condOp->cond_->Type();
  if (condType->IsInteger()) {
    cond = Evaluator<long>().Eval(condOp->cond_) != 0;
  } else if (condType->IsFloat()) {
    cond = Evaluator<double>().Eval(condOp->cond_) != 0.0;
  } else if (condType->ToPointer()) {
    // An address is true when it has a label or a non-zero offset.
    auto addr = Evaluator<Addr>().Eval(condOp->cond_);
    cond = addr.label_.size() || addr.offset_;
  } else {
    assert(false);
  }

  if (cond)
    addr_ = Evaluator<Addr>().Eval(condOp->exprTrue_);
  else
    addr_ = Evaluator<Addr>().Eval(condOp->exprFalse_);
}
// Fold a constant into addr_. An integer becomes a label-free offset; an
// array-typed constant is registered through Generator::ConsLabel() and
// addressed by the label of the newest rodatas_ entry at offset 0.
void Evaluator<Addr>::VisitConstant(Constant* cons) {
  if (cons->Type()->IsInteger()) {
    addr_ = {"", static_cast<int>(cons->IVal())};
    return;
  }
  if (cons->Type()->ToArray()) {
    Generator().ConsLabel(cons); // Add the literal to rodatas_.
    addr_.label_ = Generator::rodatas_.back().label_;
    addr_.offset_ = 0;
    return;
  }
  assert(false);
}

253
lib/lang/wgtcc/main.cc Normal file
View File

@@ -0,0 +1,253 @@
#include "triton/lang/wgtcc/code_gen.h"
#include "triton/lang/wgtcc/cpp.h"
#include "triton/lang/wgtcc/error.h"
#include "triton/lang/wgtcc/parser.h"
#include "triton/lang/wgtcc/scanner.h"
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <list>
#include <string>
#include <vector>
#include <fcntl.h>
#include <unistd.h>
#include <sys/wait.h>
// Driver state shared across this file (and, for `program`, with error.cc).
// Name of the running executable, set from argv[0] in main().
std::string program;
// Input file currently being processed.
std::string filename_in;
// Output file name: taken from '-o', or derived from the input name.
std::string filename_out;
// Set by the '-g' option.
bool debug = false;
// '-E': stop after preprocessing.
static bool only_preprocess = false;
// '-S': stop after generating assembly.
static bool only_compile = false;
// True once '-o' has been seen.
static bool specified_out_name = false;
// Every non-option argument, in command-line order.
static std::list<std::string> filenames_in;
// NOTE(review): not referenced in the visible part of this file.
static std::list<std::string> gcc_filenames_in;
// Arguments forwarded to gcc by RunGcc().
static std::list<std::string> gcc_args;
// Macro definitions collected from '-D'.
static std::list<std::string> defines;
// Include directories collected from '-I'.
static std::list<std::string> include_paths;
// Print the command-line help to stdout and exit with status 0.
static void Usage() {
  static const char kHelp[] =
      "Usage: wgtcc [options] file...\n"
      "Options: \n"
      " -h Display this information\n"
      " -D Define object like macro\n"
      " -I Add search path\n"
      " -E Preprocess only; do not compile, assemble or link\n"
      " -S Compile only; do not assemble or link\n"
      " -o specify output file\n";
  fputs(kHelp, stdout);
  exit(0);
}
// Return the last two characters of `filename` (for example ".c"), or the
// whole string when it is shorter than two characters.
static std::string GetExtension(const std::string& filename) {
  const std::size_t length = filename.size();
  const std::size_t start = length < 2 ? 0 : length - 2;
  return filename.substr(start);
}
// Reject input files whose extension is not one of .c, .s, .o or .a.
static void ValidateFileName(const std::string& filename) {
  const auto ext = GetExtension(filename);
  const bool accepted =
      ext == ".c" || ext == ".s" || ext == ".o" || ext == ".a";
  if (!accepted)
    Error("bad file name format:'%s'", filename.c_str());
}
// Register a '-D' definition with the preprocessor. `def` is either "NAME"
// (empty replacement text) or "NAME=TEXT".
static void DefineMacro(Preprocessor& cpp, const std::string& def) {
  const auto eq = def.find('=');
  const bool hasValue = eq != std::string::npos;

  std::string macro = hasValue ? def.substr(0, eq) : def;
  std::string* replace = hasValue ? new std::string(def.substr(eq + 1))
                                  : new std::string();
  cpp.AddMacro(macro, replace);
}
// Return the part of `path` after its last '/', or the whole path when it
// contains no '/'.
static std::string GetName(const std::string& path) {
  const auto slash = path.find_last_of('/');
  return slash == std::string::npos ? path : path.substr(slash + 1);
}
static int RunWgtcc() {
if (GetExtension(filename_in) != ".c")
return -3;
Preprocessor cpp(&filename_in);
for (auto& def: defines)
DefineMacro(cpp, def);
for (auto& path: include_paths)
cpp.AddSearchPath(path);
FILE* fp = stdout;
if (specified_out_name) {
fp = fopen(filename_out.c_str(), "w");
}
TokenSequence ts;
cpp.Process(ts);
if (only_preprocess) {
ts.Print(fp);
return 0;
}
if (!only_compile || !specified_out_name) {
filename_out = GetName(filename_in);
filename_out.back() = 's';
}
fp = fopen(filename_out.c_str(), "w");
Parser parser(ts);
parser.Parse();
Generator::SetInOut(&parser, fp);
Generator().Gen();
fclose(fp);
return 0;
}
// Run gcc with the collected arguments, forcing '-std=c11', and return the
// value of system().
static int RunGcc() {
  // Force C11: rewrite every '-std...' argument, or add one if none exists.
  bool sawStd = false;
  for (auto& arg: gcc_args) {
    if (arg.compare(0, 4, "-std") != 0)
      continue;
    arg = "-std=c11";
    sawStd = true;
  }
  if (!sawStd) {
    gcc_args.push_front("-std=c11");
  }

  std::string command = "gcc";
  for (const auto& arg: gcc_args) {
    command += " " + arg;
  }
  return system(command.c_str());
}
// Handle '-I<dir>' or '-I <dir>'. In the detached form the next argument is
// consumed (advancing `i`) and also forwarded to gcc.
static void ParseInclude(int argc, char* argv[], int& i) {
  const char* attached = &argv[i][2];
  if (*attached) {
    include_paths.push_front(attached);
    return;
  }

  if (i == argc - 1) {
    Error("missing argument to '%s'", argv[i]);
  }
  ++i;
  include_paths.push_front(argv[i]);
  gcc_args.push_back(argv[i]);
}
// Handle '-D<def>' or '-D <def>'. In the detached form the next argument is
// consumed (advancing `i`) and also forwarded to gcc.
static void ParseDefine(int argc, char* argv[], int& i) {
  const char* attached = &argv[i][2];
  if (*attached) {
    defines.push_back(attached);
    return;
  }

  if (i == argc - 1) {
    Error("missing argument to '%s'", argv[i]);
  }
  ++i;
  defines.push_back(argv[i]);
  gcc_args.push_back(argv[i]);
}
// Handle '-o <file>': the next argument becomes `filename_out` (advancing
// `i`) and is also forwarded to gcc.
static void ParseOut(int argc, char* argv[], int& i) {
  if (i == argc - 1) {
    Error("missing argument to '%s'", argv[i]);
  }
  ++i;
  filename_out = argv[i];
  gcc_args.push_back(argv[i]);
}
/* Use:
* wgtcc: compile
* gcc: assemble and link
* Allowing multi file may not be a good idea...
*/
// Driver entry point.
// 1. Parse the command line: non-option arguments are input files; every
//    option is forwarded to gcc (except '-g') and the ones wgtcc understands
//    are recorded as well.
// 2. Compile each input with RunWgtcc(): in-process when DEBUG is defined,
//    otherwise in one forked child per input.
// 3. Unless '-E' or '-S' was given, run gcc on the results.
int main(int argc, char* argv[]) {
if (argc < 2)
Usage();
program = std::string(argv[0]);
for (auto i = 1; i < argc; ++i) {
// Anything not starting with '-' is an input file.
if (argv[i][0] != '-') {
filename_in = std::string(argv[i]);
ValidateFileName(filename_in);
filenames_in.push_back(filename_in);
continue;
}
// Options are forwarded to gcc by default; '-g' removes itself again.
gcc_args.push_back(argv[i]);
switch (argv[i][1]) {
case 'h': Usage(); break;
case 'E': only_preprocess = true; break;
case 'S': only_compile = true; break;
case 'I': ParseInclude(argc, argv, i); break;
case 'D': ParseDefine(argc, argv, i); break;
case 'o':
specified_out_name = true;
ParseOut(argc, argv, i); break;
case 'g': gcc_args.pop_back(); debug = true; break;
default:;
}
}
#ifdef DEBUG
// Debug build: compile in-process. Only the last input file named on the
// command line is still in filename_in at this point.
RunWgtcc();
#else
// One child process per input file; the child's exit status is the return
// value of RunWgtcc().
for (const auto& filename: filenames_in) {
filename_in = filename;
pid_t pid = fork();
if (pid < 0) {
Error("fork error");
} else if (pid == 0) {
// Do work in child process
return RunWgtcc();
}
}
for (size_t i = 0; i < filenames_in.size(); ++i) {
int stat;
wait(&stat);
// Child process terminate normaly if :
// 1. terminate with `exit()`, that is, WIFEXITED(stat) if true.
// 2. the status code is 0, that is, WEXITSTATUS(stat) == 0
// NOTE(review): a failed child makes the driver stop here but still
// return 0. RunWgtcc() returns -3 for non-'.c' inputs, so such inputs
// also take this path - confirm this is intended.
if (!WIFEXITED(stat) || WEXITSTATUS(stat))
return 0;
}
#endif
if (only_preprocess || only_compile) {
// NOTE(review): this check runs after the inputs were already processed.
if (specified_out_name && filenames_in.size() > 1)
Error("cannot specifier output filename with multiple input file");
return 0;
}
std::list<std::string> filenames_out;
// Hand gcc the generated '.s' file for every '.c' input. As soon as an
// input is not a '.c' file, gcc_args is rebuilt from the original command
// line instead.
for (auto& filename: filenames_in) {
if (GetExtension(filename) == ".c") {
gcc_args.push_back(GetName(filename));
gcc_args.back().back() = 's';
} else {
gcc_args.clear();
for (int i = 1; i < argc; ++i)
gcc_args.push_back(argv[i]);
break;
}
}
auto ret = RunGcc();
// NOTE(review): in the forking build the parent's filename_out is only ever
// set by '-o', so this appears to remove the file gcc just produced rather
// than the intermediate '.s' file - confirm.
remove(filename_out.c_str());
return ret;
}

2688
lib/lang/wgtcc/parser.cc Normal file

File diff suppressed because it is too large Load Diff

452
lib/lang/wgtcc/scanner.cc Normal file
View File

@@ -0,0 +1,452 @@
#include "triton/lang/wgtcc/scanner.h"
#include <cctype>
#include <climits>
// Scan the whole input and append its tokens to `ts`.
// A token that directly follows a NEW_LINE token in `ts` is flagged as
// preceded by white space. At end of input a NEW_LINE token is appended
// unless `ts` already ends with one.
void Scanner::Tokenize(TokenSequence& ts) {
  for (;;) {
    auto tok = Scan();
    if (tok->tag_ != Token::END) {
      if (!ts.Empty() && ts.Back()->tag_ == Token::NEW_LINE)
        tok->ws_ = true;
      ts.InsertBack(tok);
      continue;
    }

    // End of input: make sure the sequence is newline-terminated.
    const bool endsWithNewLine =
        !ts.Empty() && ts.Back()->tag_ == Token::NEW_LINE;
    if (!endsWithNewLine) {
      auto newline = Token::New(*tok);
      newline->tag_ = Token::NEW_LINE;
      newline->str_ = "\n";
      ts.InsertBack(newline);
    }
    return;
  }
}
// Extract the header name written between '<' (token `lhs`) and '>' (token
// `rhs`) straight from the source buffer, removing backslash-newline line
// continuations.
std::string Scanner::ScanHeadName(const Token* lhs, const Token* rhs) {
  std::string str;
  const char* begin = lhs->loc_.Begin() + 1;
  const char* end = rhs->loc_.Begin();
  for (; begin != end; ++begin) {
    // Check for emptiness first: calling back() on an empty string is
    // undefined behaviour.
    if (*begin == '\n' && !str.empty() && str.back() == '\\')
      str.pop_back();
    else
      str.push_back(*begin);
  }
  return str;
}
// Scan and return the next token.
//   ws: initial value of the token's "preceded by white space" flag; it is
//       passed as true when re-scanning after a comment.
// Leading white space is skipped (a newline is not: it becomes its own
// token). Punctuators are matched longest-first using Try(), digraphs are
// mapped to the punctuators they stand for, and numbers, identifiers,
// character constants and string literals are delegated to the Skip*
// helpers. Unrecognised input yields a Token::INVALID token.
Token* Scanner::Scan(bool ws) {
tok_.ws_ = ws;
SkipWhiteSpace();
Mark();
// The newline token is created before its character is consumed.
if (Test('\n')) {
auto ret = MakeNewLine();
Next();
return ret;
}
auto c = Next();
switch (c) {
case '#': return MakeToken(Try('#') ? Token::DSHARP: c);
// ':>' is the digraph for ']'
case ':': return MakeToken(Try('>') ? ']': c);
case '(': case ')': case '[': case ']':
case '?': case ',': case '{': case '}':
case '~': case ';': case '@':
return MakeToken(c);
case '-':
if (Try('>')) return MakeToken(Token::PTR);
if (Try('-')) return MakeToken(Token::DEC);
if (Try('=')) return MakeToken(Token::SUB_ASSIGN);
return MakeToken(c);
case '+':
if (Try('+')) return MakeToken(Token::INC);
if (Try('=')) return MakeToken(Token::ADD_ASSIGN);
return MakeToken(c);
case '<':
if (Try('<')) return MakeToken(Try('=') ? Token::LEFT_ASSIGN: Token::LEFT);
if (Try('=')) return MakeToken(Token::LE);
// '<:' and '<%' are the digraphs for '[' and '{'
if (Try(':')) return MakeToken('[');
if (Try('%')) return MakeToken('{');
return MakeToken(c);
case '%':
if (Try('=')) return MakeToken(Token::MOD_ASSIGN);
// '%>' is the digraph for '}', '%:' for '#' and '%:%:' for '##'
if (Try('>')) return MakeToken('}');
if (Try(':')) {
if (Try('%')) {
if (Try(':')) return MakeToken(Token::DSHARP);
// Only '%:%' was seen: give the second '%' back.
PutBack();
}
return MakeToken('#');
}
return MakeToken(c);
case '>':
if (Try('>')) return MakeToken(Try('=') ? Token::RIGHT_ASSIGN: Token::RIGHT);
if (Try('=')) return MakeToken(Token::GE);
return MakeToken(c);
case '=': return MakeToken(Try('=') ? Token::EQ: c);
case '!': return MakeToken(Try('=') ? Token::NE: c);
case '&':
if (Try('&')) return MakeToken(Token::LOGICAL_AND);
if (Try('=')) return MakeToken(Token::AND_ASSIGN);
return MakeToken(c);
case '|':
if (Try('|')) return MakeToken(Token::LOGICAL_OR);
if (Try('=')) return MakeToken(Token::OR_ASSIGN);
return MakeToken(c);
case '*': return MakeToken(Try('=') ? Token::MUL_ASSIGN: c);
case '/':
// A comment is skipped and scanning restarts with the white-space flag set.
if (Test('/') || Test('*')) {
SkipComment();
return Scan(true);
}
return MakeToken(Try('=') ? Token::DIV_ASSIGN: c);
case '^': return MakeToken(Try('=') ? Token::XOR_ASSIGN: c);
case '.':
// '.' followed by a digit starts a number
if (isdigit(Peek())) return SkipNumber();
if (Try('.')) {
if (Try('.')) return MakeToken(Token::ELLIPSIS);
// Only '..' was seen: give the second '.' back.
PutBack();
return MakeToken('.');
}
return MakeToken(c);
case '0' ... '9': return SkipNumber();
// 'u', 'U' and 'L' may start an encoding prefix of a character constant
// or string literal; otherwise they start an identifier.
case 'u': case 'U': case 'L': {
/*auto enc = */ScanEncoding(c);
if (Try('\'')) return SkipCharacter();
if (Try('\"')) return SkipLiteral();
return SkipIdentifier();
}
case '\'': return SkipCharacter();
case '\"': return SkipLiteral();
case 'a' ... 't': case 'v' ... 'z': case 'A' ... 'K':
case 'M' ... 'T': case 'V' ... 'Z': case '_': case '$':
case 0x80 ... 0xfd:
return SkipIdentifier();
case '\\':
// Universal character name is allowed in identifier
if (Test('u') || Test('U'))
return SkipIdentifier();
return MakeToken(Token::INVALID);
case '\0': return MakeToken(Token::END);
default: return MakeToken(Token::INVALID);
}
}
// Consume white space up to, but not including, the next newline, setting
// the current token's white-space flag for every character skipped.
void Scanner::SkipWhiteSpace() {
  for (int c = Peek(); isspace(c) && c != '\n'; c = Peek()) {
    tok_.ws_ = true;
    Next();
  }
}
// Skip a comment whose leading '/' has already been consumed.
// A line comment ends before the newline (or at end of input); a block
// comment must be closed by '*/', otherwise an error is reported.
void Scanner::SkipComment() {
  if (Try('/')) {
    // Line comment: stop in front of the newline or at end of input.
    while (!Empty() && Peek() != '\n')
      Next();
    return;
  }

  if (Try('*')) {
    while (!Empty()) {
      const auto c = Next();
      const bool closing = c == '*' && Peek() == '/';
      if (closing) {
        Next();
        return;
      }
    }
    Error(loc_, "unterminated block comment");
  }
  assert(false);
}
// Decode the scanner's remaining input as an identifier: universal
// character names are converted with AppendUCN(), every other character is
// copied unchanged.
std::string Scanner::ScanIdentifier() {
  std::string name;
  while (!Empty()) {
    auto c = Next();
    if (!IsUCN(c)) {
      name.push_back(c);
      continue;
    }
    c = ScanEscaped(); // Call ScanUCN()
    AppendUCN(name, c);
  }
  return name;
}
// Consume an identifier, starting again from the character Scan() already
// read, and return it as an IDENTIFIER token. Letters, digits, '_', '$',
// bytes 0x80..0xfd and universal character names are accepted.
Token* Scanner::SkipIdentifier() {
  const auto isIdentChar = [&](int ch) {
    return isalnum(ch)
        || (0x80 <= ch && ch <= 0xfd)
        || ch == '_'
        || ch == '$'
        || IsUCN(ch);
  };

  PutBack();
  for (auto c = Next(); isIdentChar(c); c = Next()) {
    if (IsUCN(c))
      ScanEscaped(); // Just read it
  }
  PutBack();  // Give back the character that ended the identifier
  return MakeToken(Token::IDENTIFIER);
}
// Scan a pp-number and return it as an I_CONSTANT or F_CONSTANT token.
// The token becomes floating when it contains '.', a 'p'/'P' exponent, or
// an 'e'/'E' exponent that was not preceded by 'x'/'X'.
Token* Scanner::SkipNumber() {
  PutBack();
  int tag = Token::I_CONSTANT;
  bool sawHexPrefix = false;

  for (auto c = Next();
       c == '.' || isdigit(c) || isalpha(c) || c == '_' || IsUCN(c);
       c = Next()) {
    const bool exponent = c == 'e' || c == 'E' || c == 'p' || c == 'P';
    if (exponent) {
      // Optional sign after the exponent marker.
      if (!Try('-')) Try('+');
      const bool hexDigitE = (c == 'e' || c == 'E') && sawHexPrefix;
      if (!hexDigitE)
        tag = Token::F_CONSTANT;
    } else if (IsUCN(c)) {
      ScanEscaped();
    } else if (c == '.') {
      tag = Token::F_CONSTANT;
    } else if (c == 'x' || c == 'X') {
      sawHexPrefix = true;
    }
  }
  PutBack();
  return MakeToken(tag);
}
// Decode a string literal into `val` (escape sequences resolved, universal
// character names converted with AppendUCN()) and return the encoding given
// by its prefix, or Encoding::NONE when there is none.
Encoding Scanner::ScanLiteral(std::string& val) {
  auto enc = Test('\"') ? Encoding::NONE: ScanEncoding(Next());
  Next();  // Skip the opening quote
  val.clear();

  while (!Test('\"')) {
    auto c = Next();
    const bool isucn = IsUCN(c);
    if (c == '\\')
      c = ScanEscaped();

    if (isucn) {
      AppendUCN(val, c);
    } else {
      val.push_back(c);
    }
  }
  return enc;
}
// Consume a string literal whose opening quote was already read and return
// it as a LITERAL token. The literal must be closed before the end of the
// line or of the input.
Token* Scanner::SkipLiteral() {
  auto c = Next();
  for (; c != '\"' && c != '\n' && c != '\0'; c = Next()) {
    // Skip the character following a backslash.
    if (c == '\\')
      Next();
  }
  if (c != '\"') {
    Error(loc_, "unterminated string literal");
  }
  return MakeToken(Token::LITERAL);
}
// Decode a character constant into `val` and return the encoding given by
// its prefix, or Encoding::NONE when there is none. Without a prefix the
// characters are packed eight bits each; with a prefix the last character
// wins.
Encoding Scanner::ScanCharacter(int& val) {
  auto enc = Test('\'') ? Encoding::NONE: ScanEncoding(Next());
  Next();  // Skip the opening quote
  val = 0;

  while (!Test('\'')) {
    auto c = Next();
    if (c == '\\')
      c = ScanEscaped();
    val = (enc == Encoding::NONE) ? (val << 8) + c : c;
  }
  return enc;
}
// Consume a character constant whose opening quote was already read and
// return it as a C_CONSTANT token. The constant must be closed before the
// end of the line or of the input.
Token* Scanner::SkipCharacter() {
  auto c = Next();
  for (; c != '\'' && c != '\n' && c != '\0'; c = Next()) {
    // Skip the character following a backslash.
    if (c == '\\')
      Next();
  }
  if (c != '\'') {
    Error(loc_, "unterminated character constant");
  }
  return MakeToken(Token::C_CONSTANT);
}
int Scanner::ScanEscaped() {
auto c = Next();
switch (c) {
case '\\': case '\'': case '\"': case '\?':
return c;
case 'a': return '\a';
case 'b': return '\b';
case 'f': return '\f';
case 'n': return '\n';
case 'r': return '\r';
case 't': return '\t';
case 'v': return '\v';
// Non-standard GCC extention
case 'e': return '\033';
case 'x': return ScanHexEscaped();
case '0' ... '7': return ScanOctEscaped(c);
case 'u': return ScanUCN(4);
case 'U': return ScanUCN(8);
default: Error(loc_, "unrecognized escape character '%c'", c);
}
return c; // Make compiler happy
}
// Decode the hexadecimal digits of a '\x' escape; at least one digit is
// required. Returns the accumulated value.
int Scanner::ScanHexEscaped() {
  int c = Peek();
  if (!isxdigit(c)) {
    Error(loc_, "expect xdigit, but got '%c'", c);
  }

  int val = 0;
  for (; isxdigit(c); c = Peek()) {
    val = (val << 4) + XDigit(c);
    Next();
  }
  return val;
}
// Decode an octal escape whose first digit `c` was already consumed; up to
// two further octal digits are read. Returns the accumulated value.
int Scanner::ScanOctEscaped(int c) {
  int val = XDigit(c);
  for (int extra = 0; extra < 2; ++extra) {
    c = Peek();
    if (!IsOctal(c))
      return val;
    val = (val << 3) + XDigit(c);
    Next();
  }
  return val;
}
// Decode exactly `len` (4 or 8) hexadecimal digits of a universal character
// name and return the code point.
int Scanner::ScanUCN(int len) {
  assert(len == 4 || len == 8);
  int codePoint = 0;
  for (int remaining = len; remaining > 0; --remaining) {
    const auto c = Next();
    if (!isxdigit(c)) {
      Error(loc_, "expect xdigit, but got '%c'", c);
    }
    codePoint = (codePoint << 4) + XDigit(c);
  }
  return codePoint;
}
// Return the numeric value of the digit character `c`: '0'..'9' map to
// 0..9, letters map to 10 upwards regardless of case. Any other character
// trips an assertion.
int Scanner::XDigit(int c) {
  if ('0' <= c && c <= '9')
    return c - '0';
  if ('a' <= c && c <= 'z')
    return c - 'a' + 10;
  if ('A' <= c && c <= 'Z')
    return c - 'A' + 10;
  assert(false);
  return c;
}
// Map an encoding prefix character to its Encoding. For 'u' a following
// '8' is consumed and selects UTF8; otherwise 'u' means CHAR16.
Encoding Scanner::ScanEncoding(int c) {
  if (c == 'u')
    return Try('8') ? Encoding::UTF8: Encoding::CHAR16;
  if (c == 'U')
    return Encoding::CHAR32;
  if (c == 'L')
    return Encoding::WCHAR;
  assert(false);
  return Encoding::NONE;
}
std::string* ReadFile(const std::string& filename) {
FILE* f = fopen(filename.c_str(), "r");
if (!f) Error("%s: No such file or directory", filename.c_str());
auto text = new std::string;
int c;
while (EOF != (c = fgetc(f)))
text->push_back(c);
fclose(f);
return text;
}
// Consume and return the current character, updating the location: a
// newline starts a new line at column 1, anything else advances the column.
int Scanner::Next() {
  const int c = Peek();
  ++p_;
  if (c != '\n') {
    ++loc_.column_;
    return c;
  }
  ++loc_.line_;
  loc_.column_ = 1;
  loc_.lineBegin_ = p_;
  return c;
}
// Return the current character without consuming it. Backslash-newline
// line continuations in front of it are skipped, with the location moved
// to the start of the following line.
int Scanner::Peek() {
  while (*p_ == '\\' && p_[1] == '\n') {
    p_ += 2;
    ++loc_.line_;
    loc_.column_ = 1;
    loc_.lineBegin_ = p_;
  }
  return (uint8_t)(*p_);
}
// Step back one character, skipping over backslash-newline continuations.
// There can't be more than one PutBack() crossing a line boundary, so
// lineBegin_ is left alone: the position of a newline token never matters.
void Scanner::PutBack() {
  for (;;) {
    const int c = *--p_;
    if (c != '\n') {
      --loc_.column_;
      return;
    }
    --loc_.line_;
    if (p_[-1] != '\\')
      return;
    // Line continuation: also step over the backslash and keep going.
    --p_;
  }
}
// Finishes the token in tok_: sets its tag and rebuilds its text from the
// source bytes between the token's start (lineBegin_ + column_ - 1) and the
// current position. Backslash-newline pairs inside that range are dropped.
// Returns a copy allocated through Token::New.
Token* Scanner::MakeToken(int tag) {
  tok_.tag_ = tag;
  auto& text = tok_.str_;
  text.clear();
  const char* cur = tok_.loc_.lineBegin_ + tok_.loc_.column_ - 1;
  while (cur < p_) {
    const bool splicedNewline = (*cur == '\n' && cur[-1] == '\\');
    if (splicedNewline)
      text.pop_back();  // Remove the backslash appended on the previous step
    else
      text.push_back(*cur);
    ++cur;
  }
  return Token::New(tok_);
}
/*
 * The newline token is special: it is created before the '\n' character
 * itself has been read, so its text is taken from the current position.
 */
Token* Scanner::MakeNewLine() {
  tok_.tag_ = '\n';
  tok_.str_.assign(p_, p_ + 1);
  return Token::New(tok_);
}

111
lib/lang/wgtcc/scope.cc Normal file
View File

@@ -0,0 +1,111 @@
#include "triton/lang/wgtcc/scope.h"
#include "triton/lang/wgtcc/ast.h"
#include <cassert>
#include <iostream>
// Looks up the identifier named by `tok`; when found, the identifier is
// re-pointed at `tok` before being returned. Returns nullptr otherwise.
Identifier* Scope::Find(const Token* tok) {
  Identifier* found = Find(tok->str_);
  if (found != nullptr)
    found->SetTok(tok);
  return found;
}
// Same as Find(tok), but only this scope is searched.
Identifier* Scope::FindInCurScope(const Token* tok) {
  Identifier* found = FindInCurScope(tok->str_);
  if (found != nullptr)
    found->SetTok(tok);
  return found;
}
// Looks up the tag named by `tok`; when found, the identifier is re-pointed
// at `tok` before being returned. Returns nullptr otherwise.
Identifier* Scope::FindTag(const Token* tok) {
  Identifier* found = FindTag(tok->str_);
  if (found != nullptr)
    found->SetTok(tok);
  return found;
}
// Same as FindTag(tok), but only this scope is searched.
Identifier* Scope::FindTagInCurScope(const Token* tok) {
  Identifier* found = FindTagInCurScope(tok->str_);
  if (found != nullptr)
    found->SetTok(tok);
  return found;
}
// Registers `ident` in this scope under its own name.
void Scope::Insert(Identifier* ident) {
  const auto& key = ident->Name();
  Insert(key, ident);
}
// Registers `ident` in this scope under the tag form of its name.
void Scope::InsertTag(Identifier* ident) {
  const auto& key = TagName(ident->Name());
  Insert(key, ident);
}
// Searches this scope for `name` and, on a miss, continues in the parent
// scope. The search stops with nullptr at a file scope or at a scope that
// has no parent.
Identifier* Scope::Find(const std::string& name) {
  auto hit = identMap_.find(name);
  if (hit != identMap_.end())
    return hit->second;
  const bool isOutermost = (type_ == S_FILE || parent_ == nullptr);
  return isOutermost ? nullptr : parent_->Find(name);
}
// Searches only this scope for `name`; returns nullptr when it is absent.
Identifier* Scope::FindInCurScope(const std::string& name) {
  auto hit = identMap_.find(name);
  if (hit != identMap_.end())
    return hit->second;
  return nullptr;
}
// Binds `name` to `ident` in this scope. The name must not be bound here yet
// (checked by assert only).
void Scope::Insert(const std::string& name, Identifier* ident) {
  assert(FindInCurScope(name) == nullptr);
  identMap_[name] = ident;
}
// Looks up the tag `name` in this scope and its ancestors. A hit is expected
// to be a type name (assert).
Identifier* Scope::FindTag(const std::string& name) {
  auto tag = Find(TagName(name));
  assert(tag == nullptr || tag->ToTypeName());
  return tag;
}
// Looks up the tag `name` in this scope only. A hit is expected to be a type
// name (assert).
Identifier* Scope::FindTagInCurScope(const std::string& name) {
  auto tag = FindInCurScope(TagName(name));
  if (tag)
    assert(tag->ToTypeName());
  return tag;
}
// Collects every identifier of this scope whose key is a tag name.
Scope::TagList Scope::AllTagsInCurScope() const {
  TagList result;
  for (auto it = identMap_.begin(); it != identMap_.end(); ++it) {
    if (!IsTagName(it->first))
      continue;
    result.push_back(it->second);
  }
  return result;
}
// Debug helper: dumps this scope's address followed by one line per entry,
// labelled "type" or "object" depending on ToTypeName(), with the entry's
// type string.
void Scope::Print() {
  std::cout << "scope: " << this << std::endl;
  for (const auto& entry : identMap_) {
    const auto& name = entry.first;
    auto ident = entry.second;
    const char* label = ident->ToTypeName() ? "\t[type:\t" : "\t[object:\t";
    std::cout << name << label
              << ident->Type()->Str() << "]" << std::endl;
  }
  std::cout << std::endl;
}

259
lib/lang/wgtcc/token.cc Normal file
View File

@@ -0,0 +1,259 @@
#include "triton/lang/wgtcc/token.h"
#include "triton/lang/wgtcc/mem_pool.h"
#include "triton/lang/wgtcc/parser.h"
// File-local pool that provides the storage for every Token created through
// the Token::New overloads in this file.
static MemPoolImp<Token> tokenPool;
// Maps the spelling of each keyword to its token tag.
const std::unordered_map<std::string, int> Token::kwTypeMap_ {
{ "auto", Token::AUTO },
{ "break", Token::BREAK },
{ "case", Token::CASE },
{ "char", Token::CHAR },
{ "const", Token::CONST },
{ "continue", Token::CONTINUE },
{ "default", Token::DEFAULT },
{ "do", Token::DO },
{ "double", Token::DOUBLE },
{ "else", Token::ELSE },
{ "enum", Token::ENUM },
{ "extern", Token::EXTERN },
{ "float", Token::FLOAT },
{ "for", Token::FOR },
{ "goto", Token::GOTO },
{ "half", Token::HALF },
{ "if", Token::IF },
{ "inline", Token::INLINE },
{ "int", Token::INT },
{ "long", Token::LONG },
{ "signed", Token::SIGNED },
{ "unsigned", Token::UNSIGNED },
{ "register", Token::REGISTER },
{ "restrict", Token::RESTRICT },
{ "return", Token::RETURN },
{ "short", Token::SHORT },
{ "sizeof", Token::SIZEOF },
{ "static", Token::STATIC },
{ "struct", Token::STRUCT },
{ "switch", Token::SWITCH },
{ "typedef", Token::TYPEDEF },
{ "union", Token::UNION },
{ "void", Token::VOID },
{ "volatile", Token::VOLATILE },
{ "while", Token::WHILE },
// Underscore-prefixed keywords and the __attribute__ extension keyword
{ "_Alignas", Token::ALIGNAS },
{ "_Alignof", Token::ALIGNOF },
{ "_Atomic", Token::ATOMIC },
{ "__attribute__", Token::ATTRIBUTE },
{ "_Bool", Token::BOOL },
{ "_Complex", Token::COMPLEX },
{ "_Generic", Token::GENERIC },
{ "_Imaginary", Token::IMAGINARY },
{ "_Noreturn", Token::NORETURN },
{ "_Static_assert", Token::STATIC_ASSERT },
{ "_Thread_local", Token::THREAD },
};
// Maps a token tag back to a printable lexeme: punctuators, keywords, and
// placeholder texts for the token classes that have no fixed spelling.
// Every keyword listed in kwTypeMap_ needs an entry here as well; "half" was
// missing, so a lookup for Token::HALF found nothing.
const std::unordered_map<int, const char*> Token::tagLexemeMap_ {
  // Single-character punctuators
  { '(', "(" },
  { ')', ")" },
  { '[', "[" },
  { ']', "]" },
  { ':', ":" },
  { ',', "," },
  { ';', ";" },
  { '+', "+" },
  { '-', "-" },
  { '*', "*" },
  { '/', "/" },
  { '|', "|" },
  { '&', "&" },
  { '<', "<" },
  { '>', ">" },
  { '=', "=" },
  { '.', "." },
  { '%', "%" },
  { '{', "{" },
  { '}', "}" },
  { '^', "^" },
  { '~', "~" },
  { '!', "!" },
  { '?', "?" },
  { '#', "#" },
  { '@', "@" },

  // Multi-character punctuators
  { Token::DSHARP, "##" },
  { Token::PTR, "->" },
  { Token::INC, "++" },
  { Token::DEC, "--" },
  { Token::LEFT, "<<" },
  { Token::RIGHT, ">>" },
  { Token::LE, "<=" },
  { Token::GE, ">=" },
  { Token::EQ, "==" },
  { Token::NE, "!=" },
  { Token::LOGICAL_AND, "&&" },
  { Token::LOGICAL_OR, "||" },
  { Token::MUL_ASSIGN, "*=" },
  { Token::DIV_ASSIGN, "/=" },
  { Token::MOD_ASSIGN, "%=" },
  { Token::ADD_ASSIGN, "+=" },
  { Token::SUB_ASSIGN, "-=" },
  { Token::LEFT_ASSIGN, "<<=" },
  { Token::RIGHT_ASSIGN, ">>=" },
  { Token::AND_ASSIGN, "&=" },
  { Token::XOR_ASSIGN, "^=" },
  { Token::OR_ASSIGN, "|=" },
  { Token::ELLIPSIS, "..." },

  // Keywords
  { Token::AUTO, "auto" },
  { Token::BREAK, "break" },
  { Token::CASE, "case" },
  { Token::CHAR, "char" },
  { Token::CONST, "const" },
  { Token::CONTINUE, "continue" },
  { Token::DEFAULT, "default" },
  { Token::DO, "do" },
  { Token::DOUBLE, "double" },
  { Token::ELSE, "else" },
  { Token::ENUM, "enum" },
  { Token::EXTERN, "extern" },
  { Token::FLOAT, "float" },
  { Token::FOR, "for" },
  { Token::GOTO, "goto" },
  { Token::HALF, "half" },
  { Token::IF, "if" },
  { Token::INLINE, "inline" },
  { Token::INT, "int" },
  { Token::LONG, "long" },
  { Token::SIGNED, "signed" },
  { Token::UNSIGNED, "unsigned" },
  { Token::REGISTER, "register" },
  { Token::RESTRICT, "restrict" },
  { Token::RETURN, "return" },
  { Token::SHORT, "short" },
  { Token::SIZEOF, "sizeof" },
  { Token::STATIC, "static" },
  { Token::STRUCT, "struct" },
  { Token::SWITCH, "switch" },
  { Token::TYPEDEF, "typedef" },
  { Token::UNION, "union" },
  { Token::VOID, "void" },
  { Token::VOLATILE, "volatile" },
  { Token::WHILE, "while" },
  { Token::ALIGNAS, "_Alignas" },
  { Token::ALIGNOF, "_Alignof" },
  { Token::ATOMIC, "_Atomic" },
  { Token::ATTRIBUTE, "__attribute__" },
  { Token::BOOL, "_Bool" },
  { Token::COMPLEX, "_Complex" },
  { Token::GENERIC, "_Generic" },
  { Token::IMAGINARY, "_Imaginary" },
  { Token::NORETURN, "_Noreturn" },
  { Token::STATIC_ASSERT, "_Static_assert" },
  { Token::THREAD, "_Thread_local" },

  // Token classes without a fixed spelling
  { Token::END, "(eof)" },
  { Token::IDENTIFIER, "(identifier)" },
  { Token::CONSTANT, "(constant)" },
  { Token::LITERAL, "(string literal)" },
};
// Creates a token carrying only `tag`, stored in the token pool.
Token* Token::New(int tag) {
  auto* slot = tokenPool.Alloc();
  return new (slot) Token(tag);
}
// Creates a pool-allocated copy of `other`.
Token* Token::New(const Token& other) {
  auto* slot = tokenPool.Alloc();
  return new (slot) Token(other);
}
// Creates a pool-allocated token from its tag, source location, text and
// the `ws` flag.
Token* Token::New(int tag,
                  const SourceLocation& loc,
                  const std::string& str,
                  bool ws) {
  auto* slot = tokenPool.Alloc();
  return new (slot) Token(tag, loc, str, ws);
}
// Splits off the tokens up to (not including) the next NEW_LINE token or the
// end of the sequence. This sequence is advanced to that position and the
// skipped range is returned as a new sequence over the same token list.
TokenSequence TokenSequence::GetLine() {
  const auto lineStart = begin_;
  for (; begin_ != end_; ++begin_) {
    if ((*begin_)->tag_ == Token::NEW_LINE)
      break;
  }
  const auto lineEnd = begin_;
  return {tokList_, lineStart, lineEnd};
}
/*
 * Tells whether this sequence starts at the beginning of a line.
 * Called only after a '#' has been seen in the token sequence.
 */
bool TokenSequence::IsBeginOfLine() const {
  if (begin_ == tokList_->begin())
    return true;

  auto before = begin_;
  --before;

  if ((*before)->tag_ == Token::NEW_LINE)
    return true;
  // No newline is inserted at the end of a source file, so when two adjacent
  // tokens come from different files the second one begins a line.
  return (*before)->loc_.filename_ != (*begin_)->loc_.filename_;
}
// Returns the next significant token without consuming it.
//  - NEW_LINE tokens in front are skipped (this advances begin_).
//  - At the end of the sequence a shared END token is returned; it is first
//    overwritten with a copy of the last token unless end_ is the very
//    beginning of the token list.
//  - When a parser is attached, an identifier spelled __func__ is replaced
//    in place by a string literal holding the current function's name.
const Token* TokenSequence::Peek() const {
  static auto eof = Token::New(Token::END);

  while (begin_ != end_ && (*begin_)->tag_ == Token::NEW_LINE)
    ++begin_;

  if (begin_ == end_) {
    if (end_ != tokList_->begin())
      *eof = *Back();
    eof->tag_ = Token::END;
    return eof;
  }

  const bool isFuncName = parser_ &&
                          (*begin_)->tag_ == Token::IDENTIFIER &&
                          (*begin_)->str_ == "__func__";
  if (isFuncName) {
    auto literal = Token::New(*(*begin_));
    literal->tag_ = Token::LITERAL;
    literal->str_ = "\"" + parser_->CurFunc()->Name() + "\"";
    *begin_ = literal;
  }
  return *begin_;
}
// Requires the next token to have tag `expect` (checked via Try); otherwise
// an error naming the expected lexeme and the actual token text is reported.
// Returns the token that was peeked before the check.
const Token* TokenSequence::Expect(int expect) {
  const Token* actual = Peek();
  if (Try(expect))
    return actual;
  Error(actual, "'%s' expected, but got '%s'",
        Token::Lexeme(expect), actual->str_.c_str());
  return actual;
}
// Writes the token texts to `fp`, working on a copy of the sequence. A token
// on a different line than the previous one starts a new output line padded
// with `column_` spaces; otherwise a single space is emitted when the token's
// ws_ flag is set. The stream is flushed after every token.
void TokenSequence::Print(FILE* fp) const {
  auto cursor = *this;
  unsigned prevLine = 0;
  while (!cursor.Empty()) {
    const auto tok = cursor.Next();
    const bool startsNewLine = (prevLine != tok->loc_.line_);
    if (startsNewLine) {
      fputs("\n", fp);
      for (unsigned pad = 0; pad < tok->loc_.column_; ++pad)
        fputc(' ', fp);
    } else if (tok->ws_) {
      fputc(' ', fp);
    }
    fputs(tok->str_.c_str(), fp);
    fflush(fp);
    prevLine = tok->loc_.line_;
  }
  fputs("\n", fp);
}
//void TokenSequence::Print(std::string *str) const {
//}

484
lib/lang/wgtcc/type.cc Normal file
View File

@@ -0,0 +1,484 @@
#include "triton/lang/wgtcc/type.h"
#include "triton/lang/wgtcc/ast.h"
#include "triton/lang/wgtcc/scope.h"
#include "triton/lang/wgtcc/token.h"
#include <cassert>
#include <algorithm>
#include <iostream>
// File-local memory pools, one per concrete type class. The New() factories
// in this file allocate from them and hand the pool's address to the
// constructed object.
static MemPoolImp<VoidType> voidTypePool;
static MemPoolImp<ArrayType> arrayTypePool;
static MemPoolImp<TileType> tileTypePool;
static MemPoolImp<FuncType> funcTypePool;
static MemPoolImp<PointerType> pointerTypePool;
// Used by StructType::New (StructType is built with an isStruct flag)
static MemPoolImp<StructType> structUnionTypePool;
static MemPoolImp<ArithmType> arithmTypePool;
// Applies function-to-pointer and array-to-pointer conversion: a function
// type becomes a pointer to it, an array type becomes a pointer to its
// element type; every other type is returned unchanged.
QualType Type::MayCast(QualType type, bool inProtoScope) {
  auto asFunc = type->ToFunc();
  auto asArray = type->ToArray();

  if (asFunc)
    return PointerType::New(asFunc);
  if (!asArray)
    return type;

  auto elemPtr = PointerType::New(asArray->Derived());
  // C11 6.7.6.3 [7]: qualifiers would be specified inside '[]'. Those are
  // not supported here, so in prototype scope the pointer gets no qualifier;
  // outside of it the pointer is made const.
  return QualType(elemPtr, inProtoScope? 0: Qualifier::CONST);
}
// Returns the single shared VoidType instance, created on first use.
VoidType* VoidType::New() {
  static VoidType* const instance =
      new (voidTypePool.Alloc()) VoidType(&voidTypePool);
  return instance;
}
// Returns the shared ArithmType instance for the type specifier `typeSpec`.
// One instance per arithmetic type is created on first use; the specifier is
// normalised with Spec2Tag() before the lookup. Complex types are reported as
// unsupported.
ArithmType* ArithmType::New(int typeSpec) {
// Expands to a plain expression: the terminating ';' belongs to the statement
// using the macro, not to the macro itself.
#define NEW_TYPE(tag) \
  new (arithmTypePool.Alloc()) ArithmType(&arithmTypePool, tag)

  static auto boolType    = NEW_TYPE(T_BOOL);
  static auto charType    = NEW_TYPE(T_CHAR);
  static auto ucharType   = NEW_TYPE(T_UNSIGNED | T_CHAR);
  static auto shortType   = NEW_TYPE(T_SHORT);
  static auto ushortType  = NEW_TYPE(T_UNSIGNED | T_SHORT);
  static auto intType     = NEW_TYPE(T_INT);
  static auto uintType    = NEW_TYPE(T_UNSIGNED | T_INT);
  static auto longType    = NEW_TYPE(T_LONG);
  static auto ulongType   = NEW_TYPE(T_UNSIGNED | T_LONG);
  static auto llongType   = NEW_TYPE(T_LLONG);
  static auto ullongType  = NEW_TYPE(T_UNSIGNED | T_LLONG);
  static auto halfType    = NEW_TYPE(T_HALF);
  static auto floatType   = NEW_TYPE(T_FLOAT);
  static auto doubleType  = NEW_TYPE(T_DOUBLE);
  static auto ldoubleType = NEW_TYPE(T_LONG | T_DOUBLE);

  auto tag = ArithmType::Spec2Tag(typeSpec);
  switch (tag) {
  case T_BOOL:              return boolType;
  case T_CHAR:              return charType;
  case T_UNSIGNED | T_CHAR: return ucharType;
  case T_SHORT:             return shortType;
  case T_UNSIGNED | T_SHORT:return ushortType;
  case T_INT:               return intType;
  case T_UNSIGNED:          // plain "unsigned" means "unsigned int"
  case T_UNSIGNED | T_INT:  return uintType;
  case T_LONG:              return longType;
  case T_UNSIGNED | T_LONG: return ulongType;
  case T_LLONG:             return llongType;
  case T_UNSIGNED | T_LLONG:return ullongType;
  case T_HALF:              return halfType;
  case T_FLOAT:             return floatType;
  case T_DOUBLE:            return doubleType;
  case T_LONG | T_DOUBLE:   return ldoubleType;
  default:
    assert(tag & T_COMPLEX);
    Error("complex not supported yet");
  }
  return nullptr; // Make compiler happy

#undef NEW_TYPE
}
// Creates a pool-allocated array type from a length and an element type.
ArrayType* ArrayType::New(int len, QualType eleType) {
  auto* slot = arrayTypePool.Alloc();
  return new (slot) ArrayType(&arrayTypePool, len, eleType);
}
// Creates a pool-allocated array type from an expression and an element type.
ArrayType* ArrayType::New(Expr* expr, QualType eleType) {
  auto* slot = arrayTypePool.Alloc();
  return new (slot) ArrayType(&arrayTypePool, expr, eleType);
}
// Creates a pool-allocated tile type from a ShapeExpr and an element type.
TileType* TileType::New(const ShapeExpr &expr, QualType eleType) {
  auto* slot = tileTypePool.Alloc();
  return new (slot) TileType(&tileTypePool, expr, eleType);
}
// Creates a pool-allocated tile type from a ShapeInt and an element type.
TileType* TileType::New(const ShapeInt &shape, QualType eleType) {
  auto* slot = tileTypePool.Alloc();
  return new (slot) TileType(&tileTypePool, shape, eleType);
}
// Creates a pool-allocated function type from its derived type, function
// specifiers, variadic flag and parameter list.
FuncType* FuncType::New(QualType derived,
                        int funcSpec,
                        bool variadic,
                        const ParamList& params) {
  auto* slot = funcTypePool.Alloc();
  return new (slot)
      FuncType(&funcTypePool, derived, funcSpec, variadic, params);
}
// Creates a pool-allocated pointer type pointing to `derived`.
PointerType* PointerType::New(QualType derived) {
  auto* slot = pointerTypePool.Alloc();
  return new (slot) PointerType(&pointerTypePool, derived);
}
// Creates a pool-allocated struct (isStruct == true) or union type whose
// member scope is nested in `parent`.
StructType* StructType::New(bool isStruct,
                            bool hasTag,
                            Scope* parent) {
  auto* slot = structUnionTypePool.Alloc();
  return new (slot)
      StructType(&structUnionTypePool, isStruct, hasTag, parent);
}
int ArithmType::Width() const {
switch (tag_) {
case T_BOOL: case T_CHAR: case T_UNSIGNED | T_CHAR:
return 1;
case T_SHORT: case T_UNSIGNED | T_SHORT:
return intWidth_ >> 1;
case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT:
return intWidth_;
case T_LONG: case T_UNSIGNED | T_LONG:
return intWidth_ << 1;
case T_LLONG: case T_UNSIGNED | T_LLONG:
return intWidth_ << 1;
case T_FLOAT:
return intWidth_;
case T_DOUBLE:
return intWidth_ << 1;
case T_LONG | T_DOUBLE:
return intWidth_ << 1;
case T_FLOAT | T_COMPLEX:
return intWidth_ << 1;
case T_DOUBLE | T_COMPLEX:
return intWidth_ << 2;
case T_LONG | T_DOUBLE | T_COMPLEX:
return intWidth_ << 2;
default:
assert(false);
}
return intWidth_; // Make compiler happy
}
int ArithmType::Rank() const {
switch (tag_) {
case T_BOOL: return 0;
case T_CHAR: case T_UNSIGNED | T_CHAR: return 1;
case T_SHORT: case T_UNSIGNED | T_SHORT: return 2;
case T_INT: case T_UNSIGNED: case T_UNSIGNED | T_INT: return 3;
case T_LONG: case T_UNSIGNED | T_LONG: return 4;
case T_LLONG: case T_UNSIGNED | T_LLONG: return 5;
case T_FLOAT: return 6;
case T_DOUBLE: return 7;
case T_LONG | T_DOUBLE: return 8;
default:
assert(tag_ & T_COMPLEX);
Error("complex not supported yet");
}
return 0;
}
// Determines the common type of two arithmetic operands: integer operands are
// promoted first, then the operand with the higher rank wins (rhs on a tie).
// If both have the same width and either one is unsigned, the unsigned
// variant of the winner is returned.
ArithmType* ArithmType::MaxType(ArithmType* lhs,
                                ArithmType* rhs) {
  if (lhs->IsInteger())
    lhs = ArithmType::IntegerPromote(lhs);
  if (rhs->IsInteger())
    rhs = ArithmType::IntegerPromote(rhs);

  ArithmType* winner = rhs;
  if (lhs->Rank() > rhs->Rank())
    winner = lhs;

  const bool sameWidth = (lhs->Width() == rhs->Width());
  if (sameWidth && (lhs->IsUnsigned() || rhs->IsUnsigned()))
    return ArithmType::New(T_UNSIGNED | winner->Tag());
  return winner;
}
/*
 * Converts a type specifier into a type tag: a lone "signed" means int,
 * T_SIGNED is otherwise dropped, and T_INT is dropped whenever one of
 * short / long / long long is present.
 */
int ArithmType::Spec2Tag(int spec) {
  if (spec == T_SIGNED)
    return T_INT;

  int tag = spec & ~T_SIGNED;
  const int sizeSpecs = T_SHORT | T_LONG | T_LLONG;
  if (tag & sizeSpecs)
    tag &= ~T_INT;
  return tag;
}
std::string ArithmType::Str() const {
std::string width = ":" + std::to_string(Width());
switch (tag_) {
case T_BOOL:
return "bool" + width;
case T_CHAR:
return "char" + width;
case T_UNSIGNED | T_CHAR:
return "unsigned char" + width;
case T_SHORT:
return "short" + width;
case T_UNSIGNED | T_SHORT:
return "unsigned short" + width;
case T_INT:
return "int" + width;
case T_UNSIGNED:
return "unsigned int" + width;
case T_LONG:
return "long" + width;
case T_UNSIGNED | T_LONG:
return "unsigned long" + width;
case T_LLONG:
return "long long" + width;
case T_UNSIGNED | T_LLONG:
return "unsigned long long" + width;
case T_FLOAT:
return "float" + width;
case T_DOUBLE:
return "double" + width;
case T_LONG | T_DOUBLE:
return "long double" + width;
case T_FLOAT | T_COMPLEX:
return "float complex" + width;
case T_DOUBLE | T_COMPLEX:
return "double complex" + width;
case T_LONG | T_DOUBLE | T_COMPLEX:
return "long double complex" + width;
default:
assert(false);
}
return "error"; // Make compiler happy
}
// C11 6.7.6.1 [2]: two pointer types are compatible when the other type is a
// pointer as well and the pointed-to types are compatible.
bool PointerType::Compatible(const Type& other) const {
  auto otherPointer = other.ToPointer();
  if (!otherPointer)
    return false;
  return derived_->Compatible(*otherPointer->derived_);
  // FIXME(wgtdkp): the compatibility constraints cannot be loosened like this:
  //return other.IsInteger() ||
  //       (otherPointer && derived_->Compatible(*otherPointer->derived_));
}
// C11 6.7.6.2 [6]: two array types are compatible when their element types
// are compatible and, if both lengths are specified, the lengths are equal.
bool ArrayType::Compatible(const Type& other) const {
  auto otherArray = other.ToArray();
  if (!otherArray || !derived_->Compatible(*otherArray->derived_))
    return false;
  // Lengths only matter when both arrays are complete
  const bool bothComplete = complete_ && otherArray->complete_;
  if (!bothComplete)
    return true;
  return len_ == otherArray->len_;
}
// Two tile types are compatible when their element types are compatible and,
// if both shapes are specified, the shapes are equal.
bool TileType::Compatible(const Type& other) const {
  auto otherTile = other.ToTile();
  if (!otherTile || !derived_->Compatible(*otherTile->derived_))
    return false;
  // Shapes only matter when both tiles are complete
  const bool bothComplete = complete_ && otherTile->complete_;
  if (!bothComplete)
    return true;
  return shape_ == otherTile->shape_;
}
// Two function types are compatible when the other type is a function type,
// the derived (return) types are compatible, the parameter counts match and
// the parameter types are pairwise compatible.
bool FuncType::Compatible(const Type& other) const {
  auto otherFunc = other.ToFunc();
  if (!otherFunc)
    return false;  // Not a function type at all
  // TODO(wgtdkp): does the return type need to take part in deciding the
  // compatibility of two function types?
  if (!derived_->Compatible(*otherFunc->derived_))
    return false;
  if (params_.size() != otherFunc->params_.size())
    return false;

  auto mine = params_.begin();
  auto theirs = otherFunc->params_.begin();
  for (; mine != params_.end(); ++mine, ++theirs) {
    if (!(*mine)->Type()->Compatible(*(*theirs)->Type()))
      return false;
  }
  return true;
}
std::string FuncType::Str() const {
auto str = derived_->Str() + "(";
auto iter = params_.begin();
for (; iter != params_.end(); ++iter) {
str += (*iter)->Type()->Str() + ", ";
}
if (variadic_)
str += "...";
else if (params_.size())
str.resize(str.size() - 2);
return str + ")";
}
// Builds an empty struct (isStruct == true) or union type. The members live
// in a fresh block scope nested in `parent`; offset and width start at 0.
StructType::StructType(MemPool* pool, bool isStruct, bool hasTag, Scope* parent)
    : Type(pool, false),
      isStruct_(isStruct),
      hasTag_(hasTag),
      memberMap_(new Scope(parent, S_BLOCK)),
      offset_(0),
      width_(0),
      align_(1),  // A struct type without members gets an alignment of 1
      bitFieldAlign_(1) {
}
// Looks up `member` in the member scope and returns it as an Object;
// returns nullptr when no such member exists.
Object* StructType::GetMember(const std::string& member) {
  auto ident = memberMap_->FindInCurScope(member);
  if (ident != nullptr)
    return ident->ToObject();
  return nullptr;
}
// Recomputes width_ as the plain sum of the widths of all entries in the
// member scope.
void StructType::CalcWidth() {
  width_ = 0;
  for (auto& entry : memberMap_->identMap_)
    width_ += entry.second->Type()->Width();
}
// A struct/union type is compatible only with itself (identity comparison).
bool StructType::Compatible(const Type& other) const {
  return &other == this;
}
// Returns "struct:<width>" or "union:<width>".
// TODO(wgtdkp): more detailed representation
std::string StructType::Str() const {
  const char* keyword = isStruct_ ? "struct": "union";
  return std::string(keyword) + ":" + std::to_string(width_);
}
// Drops anonymous bit-field members with a non-zero bit-field width from the
// member list: they are only needed while parsing.
void StructType::Finalize() {
  auto iter = members_.begin();
  while (iter != members_.end()) {
    const bool unnamedBitField =
        (*iter)->BitFieldWidth() && (*iter)->Anonymous();
    if (unnamedBitField)
      iter = members_.erase(iter);
    else
      ++iter;
  }
}
// Appends `member`: places it at the next offset aligned to the member's
// alignment, records it in the list and the member scope, and updates the
// alignment and width. In a struct the running offset advances past the
// member; in a union every member sits at offset 0 and the width is the
// (aligned) maximum of the member widths.
void StructType::AddMember(Object* member) {
  auto placedAt = MakeAlign(offset_, member->Align());
  member->SetOffset(placedAt);

  members_.push_back(member);
  memberMap_->Insert(member->Name(), member);

  align_ = std::max(align_, member->Align());
  bitFieldAlign_ = std::max(bitFieldAlign_, align_);

  if (!isStruct_) {
    assert(offset_ == 0);
    width_ = std::max(width_, member->Type()->Width());
    width_ = MakeAlign(width_, align_);
    return;
  }
  offset_ = placedAt + member->Type()->Width();
  width_ = MakeAlign(offset_, align_);
}
// Appends the bit-field `bitField` at the given `offset`. Anonymous
// bit-fields go into the member list only, named ones also into the member
// scope. In a struct the running offset moves past the bytes covered by the
// bit-field; in a union only the width is raised if needed.
void StructType::AddBitField(Object* bitField, int offset) {
  bitField->SetOffset(offset);
  members_.push_back(bitField);
  if (!bitField->Anonymous())
    memberMap_->Insert(bitField->Name(), bitField);

  // Number of whole bytes needed up to the end of the bit-field
  auto coveredBytes = MakeAlign(bitField->BitFieldEnd(), 8) / 8;
  bitFieldAlign_ = std::max(bitFieldAlign_, bitField->Align());

  if (!isStruct_) {
    assert(offset_ == 0);
    width_ = std::max(width_, bitField->Type()->Width());
    return;
  }
  offset_ = offset + coveredBytes;
  width_ = MakeAlign(offset_, std::max(bitFieldAlign_, bitField->Align()));
}
// Merges an anonymous struct/union member into this (enclosing) type: the
// inner named members become directly reachable through this type's member
// scope, with their offsets rebased onto the enclosing type.
void StructType::MergeAnony(Object* anony) {
  auto innerType = anony->Type()->ToStruct();
  auto base = MakeAlign(offset_, anony->Align());

  // Entries of the member scope are never anonymous
  for (auto& entry: *innerType->memberMap_) {
    auto& name = entry.first;
    auto inner = entry.second->ToObject();
    if (inner != nullptr) {
      // Inner offsets become relative to the enclosing struct/union
      inner->SetOffset(base + inner->Offset());

      if (GetMember(name)) {
        Error(inner, "duplicated member '%s'", name.c_str());
      }
      // Lets later lookups find the inner member directly
      memberMap_->Insert(name, inner);
    }
  }

  anony->SetOffset(base);
  members_.push_back(anony);

  align_ = std::max(align_, anony->Align());
  if (!isStruct_) {
    assert(offset_ == 0);
    width_ = std::max(width_, innerType->Width());
    return;
  }
  offset_ = base + innerType->Width();
  width_ = MakeAlign(offset_, align_);
}

View File

@@ -6,6 +6,8 @@
#include "triton/codegen/selection/selection.h" #include "triton/codegen/selection/selection.h"
#include "triton/runtime/function.h" #include "triton/runtime/function.h"
#include "triton/lang/lang.h" #include "triton/lang/lang.h"
#include "triton/lang/wgtcc/cpp.h"
#include "triton/lang/wgtcc/parser.h"
#include "triton/driver/device.h" #include "triton/driver/device.h"
#include "triton/driver/stream.h" #include "triton/driver/stream.h"
#include "triton/driver/kernel.h" #include "triton/driver/kernel.h"
@@ -115,8 +117,30 @@ void function::caller::operator ()(driver::stream *stream, const std::array<size
// module // module
triton::lang::translation_unit *function::make_ast(const char *src) { triton::lang::translation_unit *function::make_ast(const char *csrc) {
YY_BUFFER_STATE buffer = yy_scan_string(src); std::string src(csrc);
Preprocessor cpp(&src, true);
// for (auto& def: defines)
// DefineMacro(cpp, def);
// for (auto& path: include_paths)
// cpp.AddSearchPath(path);
FILE* fp = stdout;
// if (specified_out_name) {
// fp = fopen(filename_out.c_str(), "w");
// }
TokenSequence ts;
cpp.Process(ts);
Parser parser(ts);
parser.Parse();
exit(EXIT_FAILURE);
// if (only_preprocess) {
// ts.Print(fp);
// return 0;
// }
YY_BUFFER_STATE buffer = yy_scan_string(csrc);
yyparse(); yyparse();
yy_delete_buffer(buffer); yy_delete_buffer(buffer);
triton::lang::translation_unit *program = ast_root; triton::lang::translation_unit *program = ast_root;

View File

@@ -133,7 +133,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector<ir::argument*>
} }
void gen_register_kernel_builder(std::ostream &os, const std::string &name, void gen_register_kernel_builder(std::ostream &os, const std::string &name,
const std::string &classname, const std::string &opname,
const std::vector<ir::argument*>& args){ const std::vector<ir::argument*>& args){
os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)"; os << "REGISTER_KERNEL_BUILDER(Name(\"" + name + "\").Device(DEVICE_GPU)";
for(size_t i = 0; i < args.size(); i++){ for(size_t i = 0; i < args.size(); i++){
@@ -144,7 +144,7 @@ void gen_register_kernel_builder(std::ostream &os, const std::string &name,
if(!arg->get_type()->is_pointer_ty()) if(!arg->get_type()->is_pointer_ty())
os << ".HostMemory(\"" + name + "\")"; os << ".HostMemory(\"" + name + "\")";
} }
os << ", " + classname << ");\n"; os << ", " + opname << ");\n";
} }
void gen_register_op(std::ostream &os, const std::string &name, void gen_register_op(std::ostream &os, const std::string &name,
@@ -181,10 +181,9 @@ std::string make_tensorflow_src(const std::string src,
ir::function* fn = ir->get_function_list().front(); ir::function* fn = ir->get_function_list().front();
std::string name = fn->get_name(); std::string name = fn->get_name();
name[0] = static_cast<char>(std::toupper(name[0])); name[0] = static_cast<char>(std::toupper(name[0]));
std::string classname = name + "Op"; std::string opname = name + "Op";
std::ostringstream oss; std::ostringstream oss;
oss << R"( oss << R"(
#include "triton/driver/buffer.h" #include "triton/driver/buffer.h"
#include "triton/driver/backend.h" #include "triton/driver/backend.h"
@@ -207,9 +206,9 @@ namespace drv = triton::driver;
std::string src = R"TTKERNSRC( )" + src + ")TTKERNSRC\";" + R"( std::string src = R"TTKERNSRC( )" + src + ")TTKERNSRC\";" + R"(
class )" << classname << R"(: public OpKernel { class )" << opname << R"(: public OpKernel {
public: public:
explicit )" << classname << R"((OpKernelConstruction* context) explicit )" << opname << R"((OpKernelConstruction* context)
: OpKernel(context), fn_(src) { } : OpKernel(context), fn_(src) { }
void Compute(OpKernelContext* context){ void Compute(OpKernelContext* context){
@@ -246,7 +245,7 @@ private:
// register kernel builder // register kernel builder
)"; )";
gen_register_kernel_builder(oss, name, classname, fn->args()); gen_register_kernel_builder(oss, name, opname, fn->args());
oss << R"( oss << R"(
// register op // register op
)"; )";