more debugging

This commit is contained in:
Philippe Tillet
2019-08-21 21:53:41 -07:00
parent a23225ad37
commit a6ec807223
10 changed files with 213 additions and 77 deletions

View File

@@ -138,6 +138,25 @@ private:
Stmt* else_; Stmt* else_;
}; };
class ForStmt: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ForStmt* New(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr);
virtual ~ForStmt() {}
virtual void Accept(Visitor* v);
protected:
ForStmt(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr)
: body_(body), init_(init), cond_(cond), step_(step) {}
private:
Stmt* body_;
Stmt* init_;
Expr* cond_;
Expr* step_;
};
class JumpStmt : public Stmt { class JumpStmt : public Stmt {
template<typename T> friend class Evaluator; template<typename T> friend class Evaluator;

View File

@@ -31,7 +31,7 @@ using StaticInitList = std::vector<StaticInitializer>;
// Error // Error
inline void should_not_happen() { assert(false); } inline void should_not_happen() { assert(false); }
inline void error_not_implemented() { assert(false); } inline void error_not_implemented() { throw std::runtime_error("not implemented"); }
class Generator: public Visitor { class Generator: public Visitor {
friend class Evaluator<Addr>; friend class Evaluator<Addr>;
@@ -48,32 +48,33 @@ protected:
public: public:
Generator(Parser* parser) : parser_(parser) {} Generator(Parser* parser) : parser_(parser) {}
virtual void Visit(ASTNode* node) { node->Accept(this); } void Visit(ASTNode* node) { node->Accept(this); }
void VisitExpr(Expr* expr) { expr->Accept(this); } void VisitExpr(Expr* expr) { expr->Accept(this); }
void VisitStmt(Stmt* stmt) { stmt->Accept(this); } void VisitStmt(Stmt* stmt) { stmt->Accept(this); }
// Expression // Expression
virtual void VisitBinaryOp(BinaryOp* binaryOp); void VisitBinaryOp(BinaryOp* binaryOp);
virtual void VisitUnaryOp(UnaryOp* unaryOp); void VisitUnaryOp(UnaryOp* unaryOp);
virtual void VisitConditionalOp(ConditionalOp* condOp); void VisitConditionalOp(ConditionalOp* condOp);
virtual void VisitFuncCall(FuncCall* funcCall); void VisitFuncCall(FuncCall* funcCall);
virtual void VisitObject(Object* obj); void VisitObject(Object* obj);
virtual void VisitEnumerator(Enumerator* enumer); void VisitEnumerator(Enumerator* enumer);
virtual void VisitIdentifier(Identifier* ident); void VisitIdentifier(Identifier* ident);
virtual void VisitConstant(Constant* cons); void VisitConstant(Constant* cons);
virtual void VisitTempVar(TempVar* tempVar); void VisitTempVar(TempVar* tempVar);
// Statement // Statement
virtual void VisitDeclaration(Declaration* init); void VisitDeclaration(Declaration* init);
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt); void VisitEmptyStmt(EmptyStmt* emptyStmt);
virtual void VisitIfStmt(IfStmt* ifStmt); void VisitIfStmt(IfStmt* ifStmt);
virtual void VisitJumpStmt(JumpStmt* jumpStmt); void VisitForStmt(ForStmt* ifStmt);
virtual void VisitReturnStmt(ReturnStmt* returnStmt); void VisitJumpStmt(JumpStmt* jumpStmt);
virtual void VisitLabelStmt(LabelStmt* labelStmt); void VisitReturnStmt(ReturnStmt* returnStmt);
virtual void VisitCompoundStmt(CompoundStmt* compoundStmt); void VisitLabelStmt(LabelStmt* labelStmt);
void VisitCompoundStmt(CompoundStmt* compoundStmt);
virtual void VisitFuncDef(FuncDef* funcDef); void VisitFuncDef(FuncDef* funcDef);
virtual void VisitTranslationUnit(TranslationUnit* unit); void VisitTranslationUnit(TranslationUnit* unit);
void Gen(ir::module *mod); void Gen(ir::module *mod);
@@ -127,6 +128,7 @@ public:
void VisitDeclaration(Declaration*) { should_not_happen(); } void VisitDeclaration(Declaration*) { should_not_happen(); }
void VisitEmptyStmt(EmptyStmt*) { should_not_happen(); } void VisitEmptyStmt(EmptyStmt*) { should_not_happen(); }
void VisitIfStmt(IfStmt*) { should_not_happen(); } void VisitIfStmt(IfStmt*) { should_not_happen(); }
void VisitForStmt(ForStmt*) { should_not_happen(); }
void VisitJumpStmt(JumpStmt*) { should_not_happen(); } void VisitJumpStmt(JumpStmt*) { should_not_happen(); }
void VisitReturnStmt(ReturnStmt*) { should_not_happen(); } void VisitReturnStmt(ReturnStmt*) { should_not_happen(); }
void VisitLabelStmt(LabelStmt*) { should_not_happen(); } void VisitLabelStmt(LabelStmt*) { should_not_happen(); }

View File

@@ -45,6 +45,7 @@ public:
// We may should assert here // We may should assert here
virtual void VisitDeclaration(Declaration* init) {} virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {} virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {} virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {} virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {} virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
@@ -100,6 +101,7 @@ public:
// We may should assert here // We may should assert here
virtual void VisitDeclaration(Declaration* init) {} virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {} virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {} virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {} virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {} virtual void VisitLabelStmt(LabelStmt* labelStmt) {}

View File

@@ -146,7 +146,7 @@ public:
CompoundStmt* ParseSwitchStmt(); CompoundStmt* ParseSwitchStmt();
CompoundStmt* ParseWhileStmt(); CompoundStmt* ParseWhileStmt();
CompoundStmt* ParseDoStmt(); CompoundStmt* ParseDoStmt();
CompoundStmt* ParseForStmt(); ForStmt *ParseForStmt();
JumpStmt* ParseGotoStmt(); JumpStmt* ParseGotoStmt();
JumpStmt* ParseContinueStmt(); JumpStmt* ParseContinueStmt();
JumpStmt* ParseBreakStmt(); JumpStmt* ParseBreakStmt();

View File

@@ -14,6 +14,7 @@ class TempVar;
class Declaration; class Declaration;
class IfStmt; class IfStmt;
class ForStmt;
class JumpStmt; class JumpStmt;
class ReturnStmt; class ReturnStmt;
class LabelStmt; class LabelStmt;
@@ -38,6 +39,7 @@ public:
virtual void VisitDeclaration(Declaration* init) = 0; virtual void VisitDeclaration(Declaration* init) = 0;
virtual void VisitIfStmt(IfStmt* ifStmt) = 0; virtual void VisitIfStmt(IfStmt* ifStmt) = 0;
virtual void VisitForStmt(ForStmt* ifStmt) = 0;
virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0; virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0;
virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0; virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0;
virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0; virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0;

View File

@@ -18,6 +18,7 @@ static MemPoolImp<TempVar> tempVarPool;
static MemPoolImp<UnaryOp> unaryOpPool; static MemPoolImp<UnaryOp> unaryOpPool;
static MemPoolImp<EmptyStmt> emptyStmtPool; static MemPoolImp<EmptyStmt> emptyStmtPool;
static MemPoolImp<IfStmt> ifStmtPool; static MemPoolImp<IfStmt> ifStmtPool;
static MemPoolImp<ForStmt> forStmtPool;
static MemPoolImp<JumpStmt> jumpStmtPool; static MemPoolImp<JumpStmt> jumpStmtPool;
static MemPoolImp<ReturnStmt> returnStmtPool; static MemPoolImp<ReturnStmt> returnStmtPool;
static MemPoolImp<LabelStmt> labelStmtPool; static MemPoolImp<LabelStmt> labelStmtPool;
@@ -48,6 +49,10 @@ void IfStmt::Accept(Visitor* v) {
v->VisitIfStmt(this); v->VisitIfStmt(this);
} }
void ForStmt::Accept(Visitor* v) {
v->VisitForStmt(this);
}
void JumpStmt::Accept(Visitor* v) { void JumpStmt::Accept(Visitor* v) {
v->VisitJumpStmt(this); v->VisitJumpStmt(this);
@@ -396,6 +401,7 @@ void BinaryOp::AdditiveOpTypeChecking() {
::Type* rhsScalType = TryExtractScalarType(this, rhs_); ::Type* rhsScalType = TryExtractScalarType(this, rhs_);
auto lhsPtrType = lhsScalType->ToPointer(); auto lhsPtrType = lhsScalType->ToPointer();
auto rhsPtrType = rhsScalType->ToPointer(); auto rhsPtrType = rhsScalType->ToPointer();
std::cout << "adding" << std::endl;
if (lhsPtrType) { if (lhsPtrType) {
if (op_ == '-') { if (op_ == '-') {
if (rhsPtrType) { if (rhsPtrType) {
@@ -430,6 +436,7 @@ void BinaryOp::AdditiveOpTypeChecking() {
} }
void BinaryOp::RangeOpTypeChecking() { void BinaryOp::RangeOpTypeChecking() {
std::cout << "range" << std::endl;
auto lhsType = lhs_->Type()->ToArithm(); auto lhsType = lhs_->Type()->ToArithm();
auto rhsType = rhs_->Type()->ToArithm(); auto rhsType = rhs_->Type()->ToArithm();
if(!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger()) if(!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
@@ -546,6 +553,7 @@ void BinaryOp::AssignOpTypeChecking() {
// The other constraints are lefted to cast operator // The other constraints are lefted to cast operator
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType)); rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
type_ = lhs_->Type(); type_ = lhs_->Type();
Broadcast();
} }
@@ -969,6 +977,11 @@ CompoundStmt* CompoundStmt::New(std::list<Stmt*>& stmts, ::Scope* scope) {
return ret; return ret;
} }
ForStmt* ForStmt::New(Stmt* body, Stmt* init, Expr* cond, Expr* step) {
auto ret = new (forStmtPool.Alloc()) ForStmt(body, init, cond, step);
ret->pool_ = &forStmtPool;
return ret;
}
JumpStmt* JumpStmt::New(LabelStmt* label) { JumpStmt* JumpStmt::New(LabelStmt* label) {
auto ret = new (jumpStmtPool.Alloc()) JumpStmt(label); auto ret = new (jumpStmtPool.Alloc()) JumpStmt(label);

View File

@@ -3,6 +3,7 @@
#include "triton/lang/wgtcc/parser.h" #include "triton/lang/wgtcc/parser.h"
#include "triton/lang/wgtcc/token.h" #include "triton/lang/wgtcc/token.h"
#include "triton/ir/module.h" #include "triton/ir/module.h"
#include "triton/ir/function.h"
// Helpers // Helpers
void Generator::set_ret(ir::value* value) { void Generator::set_ret(ir::value* value) {
@@ -25,10 +26,12 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
Visit(binary->lhs_); Visit(binary->lhs_);
ir::value* lhs = ret_; ir::value* lhs = ret_;
// op info // op info
auto type = binary->lhs_->Type(); auto type = binary->lhs_->Type();
auto flt = type->IsFloat(); auto flt = type->IsFloat();
auto sign = !type->IsUnsigned(); auto sign = !type->IsUnsigned();
// return // return
switch(binary->op_){ switch(binary->op_){
case Token::LOGICAL_AND: return set_ret(bld_->create_and(lhs, rhs)); case Token::LOGICAL_AND: return set_ret(bld_->create_and(lhs, rhs));
@@ -40,6 +43,13 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs)); case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
case '.': return error_not_implemented(); case '.': return error_not_implemented();
case ',': return error_not_implemented(); case ',': return error_not_implemented();
case Token::ELLIPSIS: {
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
if(!clhs || !crhs)
should_not_happen();
return set_ret(ir::constant_range::get(clhs, crhs));
}
case '+': case '+':
if(binary->lhs_->Type()->ToPointer()) if(binary->lhs_->Type()->ToPointer())
return set_ret(bld_->create_gep(lhs, {rhs})); return set_ret(bld_->create_gep(lhs, {rhs}));
@@ -210,6 +220,14 @@ void Generator::VisitDeclaration(Declaration* decl) {
if(inits.size() > 1) if(inits.size() > 1)
assert(false); assert(false);
val = inits[0]; val = inits[0];
std::cout << obj->Name() << " " << val->get_type()->get_type_id() << " " << ty->get_type_id() << std::endl;
if(val->get_type()->is_tile_ty() && ty->is_tile_ty()) {
for(auto s: val->get_type()->get_tile_shapes())
std::cout << s->get_value() << std::endl;
std::cout << "---" << std::endl;
for(auto s: ty->get_tile_shapes())
std::cout << s->get_value() << std::endl;
}
assert(val->get_type() == ty); assert(val->get_type() == ty);
// update scope symbols table // update scope symbols table
const std::string &name = obj->Name(); const std::string &name = obj->Name();
@@ -258,6 +276,38 @@ void Generator::VisitIfStmt(IfStmt* ifStmt) {
bld_->set_insert_point(endif_bb); bld_->set_insert_point(endif_bb);
} }
void Generator::VisitForStmt(ForStmt *forStmt) {
Stmt *init_ = forStmt->init_;
Expr *cond_ = forStmt->cond_;
Expr *step_ = forStmt->step_;
Stmt *body_ = forStmt->body_;
ir::basic_block *current_bb = bld_->get_insert_block();
ir::function *fn = current_bb->get_parent();
ir::basic_block *loop_bb = ir::basic_block::create(*ctx_, "loop", fn);
ir::basic_block *next_bb = ir::basic_block::create(*ctx_, "postloop", fn);
mod_->set_continue_fn([&](){
if(step_)
VisitExpr(step_);
VisitExpr(cond_);
ir::value *cond = ret_;
return bld_->create_cond_br(cond, loop_bb, next_bb);
});
VisitStmt(init_);
VisitExpr(cond_);
ir::value *cond = ret_;
bld_->create_cond_br(cond, loop_bb, next_bb);
bld_->set_insert_point(loop_bb);
VisitStmt(body_);
if(!is_terminator(ret_))
mod_->get_continue_fn()();
ir::basic_block *stop_bb = bld_->get_insert_block();
mod_->seal_block(stop_bb);
mod_->seal_block(loop_bb);
mod_->seal_block(bld_->get_insert_block());
mod_->seal_block(next_bb);
bld_->set_insert_point(next_bb);
}
void Generator::VisitJumpStmt(JumpStmt* jumpStmt) { void Generator::VisitJumpStmt(JumpStmt* jumpStmt) {
return error_not_implemented(); return error_not_implemented();
} }
@@ -277,7 +327,7 @@ void Generator::VisitLabelStmt(LabelStmt* labelStmt) {
void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) { void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
if (compoundStmt->scope_){ if (compoundStmt->scope_){
AllocObjects(compoundStmt->scope_); // AllocObjects(compoundStmt->scope_);
pushScope(); pushScope();
} }
for (auto stmt: compoundStmt->stmts_) for (auto stmt: compoundStmt->stmts_)
@@ -287,32 +337,99 @@ void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
} }
void Generator::VisitFuncDef(FuncDef* funcDef) { void Generator::VisitFuncDef(FuncDef* funcDef) {
return error_not_implemented(); Stmt *body = funcDef->body_;
const std::string& name = funcDef->Name();
FuncType* type = funcDef->FuncType();
auto prototype = dynamic_cast<ir::function_type*>(GenIRType(type, *ctx_));
if(!prototype)
should_not_happen();
ir::function *fn = mod_->get_or_insert_function(name, prototype);
std::vector<ir::argument*> args = fn->args();
size_t i = 0;
for(Object* obj: type->Params()){
std::string name = obj->Name();
args[i]->set_name(name);
mod_->set_value(name, nullptr, args[i]);
mod_->get_scope().types[name] = args[i]->get_type();
}
ir::basic_block *entry = ir::basic_block::create(mod_->get_context(), "entry", fn);
mod_->seal_block(entry);
mod_->get_builder().set_insert_point(entry);
VisitStmt(body);
if(!dynamic_cast<ir::return_inst*>(ret_))
mod_->get_builder().create_ret_void();
} }
void Generator::VisitTranslationUnit(TranslationUnit* unit) { void Generator::VisitTranslationUnit(TranslationUnit* unit) {
pushScope();
for (auto extDecl: unit->ExtDecls()) for (auto extDecl: unit->ExtDecls())
Visit(extDecl); Visit(extDecl);
popScope();
} }
void Generator::Gen(ir::module *mod) { void Generator::Gen(ir::module *mod) {
pushScope();
mod_ = mod; mod_ = mod;
ctx_ = &mod_->get_context(); ctx_ = &mod_->get_context();
bld_ = &mod_->get_builder(); bld_ = &mod_->get_builder();
std::unique_ptr<LValAssigner> assign(new LValAssigner(this)); assign_ = new LValAssigner(this);
assign_ = assign.get();
VisitTranslationUnit(parser_->Unit()); VisitTranslationUnit(parser_->Unit());
delete assign_;
assign_ = nullptr; assign_ = nullptr;
} }
// Triton-IR Values // Triton-IR Values
ir::value* Generator::GenCastOp(ir::value* op, ir::type* type) { ir::value* Generator::GenCastOp(ir::value* src, ir::type* dst_ty) {
//TODO if(dst_ty->is_tile_ty()) {
assert(false); auto dst_shapes = dst_ty->get_tile_shapes();
if(!src->get_type()->is_tile_ty())
return bld_->create_splat(src, dst_shapes);
auto src_shapes = src->get_type()->get_tile_shapes();
if(src_shapes.size() != dst_shapes.size())
return bld_->create_reshape(src, dst_shapes);
else
return bld_->create_broadcast(src, dst_shapes);
}
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
bool src_signed = false;
bool dst_signed = false;
if(src->get_type()->is_tile_ty())
dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type());
if(src_scalar_ty == dst_scalar_ty)
return src;
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
return bld_->create_si_to_fp(src, dst_ty);
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
return bld_->create_ui_to_fp(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
return bld_->create_fp_to_si(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
return bld_->create_fp_to_ui(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
return bld_->create_fp_ext(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
return bld_->create_fp_trunc(src, dst_ty);
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
src_scalar_ty->get_integer_bitwidth())
return bld_->create_int_cast(src, dst_ty, dst_signed);
else{
should_not_happen();
return nullptr; return nullptr;
}
} }
// Triton-IR Types // Triton-IR Types

View File

@@ -450,8 +450,9 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
// create ret shape // create ret shape
TileType::ShapeInt shape; TileType::ShapeInt shape;
size_t i = 0; size_t i = 0;
const Token* tok;
do { do {
auto tok = ts_.Next(); tok = ts_.Next();
if(tok->tag_ == ':') if(tok->tag_ == ':')
shape.push_back(lhsShape[i++]); shape.push_back(lhsShape[i++]);
else if(tok->tag_ == Token::NEWAXIS) else if(tok->tag_ == Token::NEWAXIS)
@@ -460,6 +461,8 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
Error(tok, "only ':' and newaxis are supported in subscripts"); Error(tok, "only ':' and newaxis are supported in subscripts");
}while(ts_.Try(',')); }while(ts_.Try(','));
ts_.Expect(']'); ts_.Expect(']');
// if(lhsShape.size() > i)
// Error(tok, "broadcasting not using all operand axes");
// create ret tile // create ret tile
TileType *retType = TileType::New(shape, lhsQual); TileType *retType = TileType::New(shape, lhsQual);
return UnaryOp::New(Token::CAST, lhs, retType); return UnaryOp::New(Token::CAST, lhs, retType);
@@ -2298,61 +2301,33 @@ IfStmt* Parser::ParseIfStmt() {
continueDest_ = continueDestBackup; \ continueDest_ = continueDestBackup; \
} }
CompoundStmt* Parser::ParseForStmt() { ForStmt* Parser::ParseForStmt() {
EnterBlock(); EnterBlock();
ts_.Expect('('); ts_.Expect('(');
// init
std::list<Stmt*> stmts; Stmt* init = nullptr;
if (IsType(ts_.Peek())) { if (IsType(ts_.Peek())) {
stmts.push_back(ParseDecl()); init = ParseDecl();
} else if (!ts_.Try(';')) { } else if (!ts_.Try(';')) {
stmts.push_back(ParseExpr()); init = ParseExpr();
ts_.Expect(';'); ts_.Expect(';');
} }
// cond
Expr* condExpr = nullptr; Expr* cond = nullptr;
if (!ts_.Try(';')) { if (!ts_.Try(';')) {
condExpr = ParseExpr(); cond = ParseExpr();
ts_.Expect(';'); ts_.Expect(';');
} }
// step
Expr* stepExpr = nullptr; Expr* step = nullptr;
if (!ts_.Try(')')) { if (!ts_.Try(')')) {
stepExpr = ParseExpr(); step = ParseExpr();
ts_.Expect(')'); ts_.Expect(')');
} }
// body
auto condLabel = LabelStmt::New(); Stmt* body = ParseStmt();
auto stepLabel = LabelStmt::New();
auto endLabel = LabelStmt::New();
stmts.push_back(condLabel);
if (condExpr) {
auto gotoEndStmt = JumpStmt::New(endLabel);
auto ifStmt = IfStmt::New(condExpr, EmptyStmt::New(), gotoEndStmt);
stmts.push_back(ifStmt);
}
// 我们需要给break和continue语句提供相应的标号不然不知往哪里跳
Stmt* bodyStmt;
ENTER_LOOP_BODY(endLabel, stepLabel);
bodyStmt = ParseStmt();
// 因为for的嵌套结构在这里需要回复break和continue的目标标号
EXIT_LOOP_BODY()
stmts.push_back(bodyStmt);
stmts.push_back(stepLabel);
if (stepExpr)
stmts.push_back(stepExpr);
else
stmts.push_back(EmptyStmt::New());
stmts.push_back(JumpStmt::New(condLabel));
stmts.push_back(endLabel);
auto scope = curScope_;
ExitBlock(); ExitBlock();
return ForStmt::New(body, init, cond, step);
return CompoundStmt::New(stmts, scope);
} }

View File

@@ -317,11 +317,13 @@ bool ArrayType::Compatible(const Type& other) const {
bool TileType::Compatible(const Type& other) const { bool TileType::Compatible(const Type& other) const {
// For two tile type to be compatible, // For two tile type to be compatible,
// the element types must be compatible, and have same shape // the element types must be compatible
// if both specified // and they must have compatible shapes
auto otherTile = other.ToTile(); auto otherTile = other.ToTile();
if(!otherTile) return false; if(!otherTile)
if (!derived_->Compatible(*otherTile->derived_)) return false; return false;
if (!derived_->Compatible(*otherTile->derived_))
return false;
// The shapes should be equal if both specified // The shapes should be equal if both specified
if(complete_ && otherTile->complete_) if(complete_ && otherTile->complete_)
return shape_ == otherTile->shape_; return shape_ == otherTile->shape_;

View File

@@ -8,6 +8,7 @@
#include "triton/lang/lang.h" #include "triton/lang/lang.h"
#include "triton/lang/wgtcc/cpp.h" #include "triton/lang/wgtcc/cpp.h"
#include "triton/lang/wgtcc/parser.h" #include "triton/lang/wgtcc/parser.h"
#include "triton/lang/wgtcc/code_gen.h"
#include "triton/driver/device.h" #include "triton/driver/device.h"
#include "triton/driver/stream.h" #include "triton/driver/stream.h"
#include "triton/driver/kernel.h" #include "triton/driver/kernel.h"
@@ -133,6 +134,9 @@ triton::lang::translation_unit *function::make_ast(const char *csrc) {
cpp.Process(ts); cpp.Process(ts);
Parser parser(ts); Parser parser(ts);
parser.Parse(); parser.Parse();
Generator gen(&parser);
ir::module out("", ctx_);
gen.Gen(&out);
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
// if (only_preprocess) { // if (only_preprocess) {