more debugging
This commit is contained in:
@@ -138,6 +138,25 @@ private:
|
|||||||
Stmt* else_;
|
Stmt* else_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ForStmt: public Stmt {
|
||||||
|
template<typename T> friend class Evaluator;
|
||||||
|
friend class AddrEvaluator;
|
||||||
|
friend class Generator;
|
||||||
|
public:
|
||||||
|
static ForStmt* New(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr);
|
||||||
|
virtual ~ForStmt() {}
|
||||||
|
virtual void Accept(Visitor* v);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
ForStmt(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr)
|
||||||
|
: body_(body), init_(init), cond_(cond), step_(step) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Stmt* body_;
|
||||||
|
Stmt* init_;
|
||||||
|
Expr* cond_;
|
||||||
|
Expr* step_;
|
||||||
|
};
|
||||||
|
|
||||||
class JumpStmt : public Stmt {
|
class JumpStmt : public Stmt {
|
||||||
template<typename T> friend class Evaluator;
|
template<typename T> friend class Evaluator;
|
||||||
|
@@ -31,7 +31,7 @@ using StaticInitList = std::vector<StaticInitializer>;
|
|||||||
|
|
||||||
// Error
|
// Error
|
||||||
inline void should_not_happen() { assert(false); }
|
inline void should_not_happen() { assert(false); }
|
||||||
inline void error_not_implemented() { assert(false); }
|
inline void error_not_implemented() { throw std::runtime_error("not implemented"); }
|
||||||
|
|
||||||
class Generator: public Visitor {
|
class Generator: public Visitor {
|
||||||
friend class Evaluator<Addr>;
|
friend class Evaluator<Addr>;
|
||||||
@@ -48,32 +48,33 @@ protected:
|
|||||||
public:
|
public:
|
||||||
Generator(Parser* parser) : parser_(parser) {}
|
Generator(Parser* parser) : parser_(parser) {}
|
||||||
|
|
||||||
virtual void Visit(ASTNode* node) { node->Accept(this); }
|
void Visit(ASTNode* node) { node->Accept(this); }
|
||||||
void VisitExpr(Expr* expr) { expr->Accept(this); }
|
void VisitExpr(Expr* expr) { expr->Accept(this); }
|
||||||
void VisitStmt(Stmt* stmt) { stmt->Accept(this); }
|
void VisitStmt(Stmt* stmt) { stmt->Accept(this); }
|
||||||
|
|
||||||
// Expression
|
// Expression
|
||||||
virtual void VisitBinaryOp(BinaryOp* binaryOp);
|
void VisitBinaryOp(BinaryOp* binaryOp);
|
||||||
virtual void VisitUnaryOp(UnaryOp* unaryOp);
|
void VisitUnaryOp(UnaryOp* unaryOp);
|
||||||
virtual void VisitConditionalOp(ConditionalOp* condOp);
|
void VisitConditionalOp(ConditionalOp* condOp);
|
||||||
virtual void VisitFuncCall(FuncCall* funcCall);
|
void VisitFuncCall(FuncCall* funcCall);
|
||||||
virtual void VisitObject(Object* obj);
|
void VisitObject(Object* obj);
|
||||||
virtual void VisitEnumerator(Enumerator* enumer);
|
void VisitEnumerator(Enumerator* enumer);
|
||||||
virtual void VisitIdentifier(Identifier* ident);
|
void VisitIdentifier(Identifier* ident);
|
||||||
virtual void VisitConstant(Constant* cons);
|
void VisitConstant(Constant* cons);
|
||||||
virtual void VisitTempVar(TempVar* tempVar);
|
void VisitTempVar(TempVar* tempVar);
|
||||||
|
|
||||||
// Statement
|
// Statement
|
||||||
virtual void VisitDeclaration(Declaration* init);
|
void VisitDeclaration(Declaration* init);
|
||||||
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt);
|
void VisitEmptyStmt(EmptyStmt* emptyStmt);
|
||||||
virtual void VisitIfStmt(IfStmt* ifStmt);
|
void VisitIfStmt(IfStmt* ifStmt);
|
||||||
virtual void VisitJumpStmt(JumpStmt* jumpStmt);
|
void VisitForStmt(ForStmt* ifStmt);
|
||||||
virtual void VisitReturnStmt(ReturnStmt* returnStmt);
|
void VisitJumpStmt(JumpStmt* jumpStmt);
|
||||||
virtual void VisitLabelStmt(LabelStmt* labelStmt);
|
void VisitReturnStmt(ReturnStmt* returnStmt);
|
||||||
virtual void VisitCompoundStmt(CompoundStmt* compoundStmt);
|
void VisitLabelStmt(LabelStmt* labelStmt);
|
||||||
|
void VisitCompoundStmt(CompoundStmt* compoundStmt);
|
||||||
|
|
||||||
virtual void VisitFuncDef(FuncDef* funcDef);
|
void VisitFuncDef(FuncDef* funcDef);
|
||||||
virtual void VisitTranslationUnit(TranslationUnit* unit);
|
void VisitTranslationUnit(TranslationUnit* unit);
|
||||||
|
|
||||||
void Gen(ir::module *mod);
|
void Gen(ir::module *mod);
|
||||||
|
|
||||||
@@ -127,6 +128,7 @@ public:
|
|||||||
void VisitDeclaration(Declaration*) { should_not_happen(); }
|
void VisitDeclaration(Declaration*) { should_not_happen(); }
|
||||||
void VisitEmptyStmt(EmptyStmt*) { should_not_happen(); }
|
void VisitEmptyStmt(EmptyStmt*) { should_not_happen(); }
|
||||||
void VisitIfStmt(IfStmt*) { should_not_happen(); }
|
void VisitIfStmt(IfStmt*) { should_not_happen(); }
|
||||||
|
void VisitForStmt(ForStmt*) { should_not_happen(); }
|
||||||
void VisitJumpStmt(JumpStmt*) { should_not_happen(); }
|
void VisitJumpStmt(JumpStmt*) { should_not_happen(); }
|
||||||
void VisitReturnStmt(ReturnStmt*) { should_not_happen(); }
|
void VisitReturnStmt(ReturnStmt*) { should_not_happen(); }
|
||||||
void VisitLabelStmt(LabelStmt*) { should_not_happen(); }
|
void VisitLabelStmt(LabelStmt*) { should_not_happen(); }
|
||||||
|
@@ -45,6 +45,7 @@ public:
|
|||||||
// We may should assert here
|
// We may should assert here
|
||||||
virtual void VisitDeclaration(Declaration* init) {}
|
virtual void VisitDeclaration(Declaration* init) {}
|
||||||
virtual void VisitIfStmt(IfStmt* ifStmt) {}
|
virtual void VisitIfStmt(IfStmt* ifStmt) {}
|
||||||
|
virtual void VisitForStmt(ForStmt* forStmt) {}
|
||||||
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
|
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
|
||||||
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
|
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
|
||||||
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
|
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
|
||||||
@@ -100,6 +101,7 @@ public:
|
|||||||
// We may should assert here
|
// We may should assert here
|
||||||
virtual void VisitDeclaration(Declaration* init) {}
|
virtual void VisitDeclaration(Declaration* init) {}
|
||||||
virtual void VisitIfStmt(IfStmt* ifStmt) {}
|
virtual void VisitIfStmt(IfStmt* ifStmt) {}
|
||||||
|
virtual void VisitForStmt(ForStmt* forStmt) {}
|
||||||
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
|
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
|
||||||
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
|
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
|
||||||
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
|
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
|
||||||
|
@@ -146,7 +146,7 @@ public:
|
|||||||
CompoundStmt* ParseSwitchStmt();
|
CompoundStmt* ParseSwitchStmt();
|
||||||
CompoundStmt* ParseWhileStmt();
|
CompoundStmt* ParseWhileStmt();
|
||||||
CompoundStmt* ParseDoStmt();
|
CompoundStmt* ParseDoStmt();
|
||||||
CompoundStmt* ParseForStmt();
|
ForStmt *ParseForStmt();
|
||||||
JumpStmt* ParseGotoStmt();
|
JumpStmt* ParseGotoStmt();
|
||||||
JumpStmt* ParseContinueStmt();
|
JumpStmt* ParseContinueStmt();
|
||||||
JumpStmt* ParseBreakStmt();
|
JumpStmt* ParseBreakStmt();
|
||||||
|
@@ -14,6 +14,7 @@ class TempVar;
|
|||||||
|
|
||||||
class Declaration;
|
class Declaration;
|
||||||
class IfStmt;
|
class IfStmt;
|
||||||
|
class ForStmt;
|
||||||
class JumpStmt;
|
class JumpStmt;
|
||||||
class ReturnStmt;
|
class ReturnStmt;
|
||||||
class LabelStmt;
|
class LabelStmt;
|
||||||
@@ -38,6 +39,7 @@ public:
|
|||||||
|
|
||||||
virtual void VisitDeclaration(Declaration* init) = 0;
|
virtual void VisitDeclaration(Declaration* init) = 0;
|
||||||
virtual void VisitIfStmt(IfStmt* ifStmt) = 0;
|
virtual void VisitIfStmt(IfStmt* ifStmt) = 0;
|
||||||
|
virtual void VisitForStmt(ForStmt* ifStmt) = 0;
|
||||||
virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0;
|
virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0;
|
||||||
virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0;
|
virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0;
|
||||||
virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0;
|
virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0;
|
||||||
|
@@ -18,6 +18,7 @@ static MemPoolImp<TempVar> tempVarPool;
|
|||||||
static MemPoolImp<UnaryOp> unaryOpPool;
|
static MemPoolImp<UnaryOp> unaryOpPool;
|
||||||
static MemPoolImp<EmptyStmt> emptyStmtPool;
|
static MemPoolImp<EmptyStmt> emptyStmtPool;
|
||||||
static MemPoolImp<IfStmt> ifStmtPool;
|
static MemPoolImp<IfStmt> ifStmtPool;
|
||||||
|
static MemPoolImp<ForStmt> forStmtPool;
|
||||||
static MemPoolImp<JumpStmt> jumpStmtPool;
|
static MemPoolImp<JumpStmt> jumpStmtPool;
|
||||||
static MemPoolImp<ReturnStmt> returnStmtPool;
|
static MemPoolImp<ReturnStmt> returnStmtPool;
|
||||||
static MemPoolImp<LabelStmt> labelStmtPool;
|
static MemPoolImp<LabelStmt> labelStmtPool;
|
||||||
@@ -48,6 +49,10 @@ void IfStmt::Accept(Visitor* v) {
|
|||||||
v->VisitIfStmt(this);
|
v->VisitIfStmt(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ForStmt::Accept(Visitor* v) {
|
||||||
|
v->VisitForStmt(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void JumpStmt::Accept(Visitor* v) {
|
void JumpStmt::Accept(Visitor* v) {
|
||||||
v->VisitJumpStmt(this);
|
v->VisitJumpStmt(this);
|
||||||
@@ -396,6 +401,7 @@ void BinaryOp::AdditiveOpTypeChecking() {
|
|||||||
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
|
||||||
auto lhsPtrType = lhsScalType->ToPointer();
|
auto lhsPtrType = lhsScalType->ToPointer();
|
||||||
auto rhsPtrType = rhsScalType->ToPointer();
|
auto rhsPtrType = rhsScalType->ToPointer();
|
||||||
|
std::cout << "adding" << std::endl;
|
||||||
if (lhsPtrType) {
|
if (lhsPtrType) {
|
||||||
if (op_ == '-') {
|
if (op_ == '-') {
|
||||||
if (rhsPtrType) {
|
if (rhsPtrType) {
|
||||||
@@ -430,6 +436,7 @@ void BinaryOp::AdditiveOpTypeChecking() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void BinaryOp::RangeOpTypeChecking() {
|
void BinaryOp::RangeOpTypeChecking() {
|
||||||
|
std::cout << "range" << std::endl;
|
||||||
auto lhsType = lhs_->Type()->ToArithm();
|
auto lhsType = lhs_->Type()->ToArithm();
|
||||||
auto rhsType = rhs_->Type()->ToArithm();
|
auto rhsType = rhs_->Type()->ToArithm();
|
||||||
if(!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
|
if(!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
|
||||||
@@ -546,6 +553,7 @@ void BinaryOp::AssignOpTypeChecking() {
|
|||||||
// The other constraints are lefted to cast operator
|
// The other constraints are lefted to cast operator
|
||||||
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
|
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
|
||||||
type_ = lhs_->Type();
|
type_ = lhs_->Type();
|
||||||
|
Broadcast();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -969,6 +977,11 @@ CompoundStmt* CompoundStmt::New(std::list<Stmt*>& stmts, ::Scope* scope) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ForStmt* ForStmt::New(Stmt* body, Stmt* init, Expr* cond, Expr* step) {
|
||||||
|
auto ret = new (forStmtPool.Alloc()) ForStmt(body, init, cond, step);
|
||||||
|
ret->pool_ = &forStmtPool;
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
JumpStmt* JumpStmt::New(LabelStmt* label) {
|
JumpStmt* JumpStmt::New(LabelStmt* label) {
|
||||||
auto ret = new (jumpStmtPool.Alloc()) JumpStmt(label);
|
auto ret = new (jumpStmtPool.Alloc()) JumpStmt(label);
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||
#include "triton/lang/wgtcc/parser.h"
|
#include "triton/lang/wgtcc/parser.h"
|
||||||
#include "triton/lang/wgtcc/token.h"
|
#include "triton/lang/wgtcc/token.h"
|
||||||
#include "triton/ir/module.h"
|
#include "triton/ir/module.h"
|
||||||
|
#include "triton/ir/function.h"
|
||||||
|
|
||||||
// Helpers
|
// Helpers
|
||||||
void Generator::set_ret(ir::value* value) {
|
void Generator::set_ret(ir::value* value) {
|
||||||
@@ -25,10 +26,12 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
|
|
||||||
Visit(binary->lhs_);
|
Visit(binary->lhs_);
|
||||||
ir::value* lhs = ret_;
|
ir::value* lhs = ret_;
|
||||||
|
|
||||||
// op info
|
// op info
|
||||||
auto type = binary->lhs_->Type();
|
auto type = binary->lhs_->Type();
|
||||||
auto flt = type->IsFloat();
|
auto flt = type->IsFloat();
|
||||||
auto sign = !type->IsUnsigned();
|
auto sign = !type->IsUnsigned();
|
||||||
|
|
||||||
// return
|
// return
|
||||||
switch(binary->op_){
|
switch(binary->op_){
|
||||||
case Token::LOGICAL_AND: return set_ret(bld_->create_and(lhs, rhs));
|
case Token::LOGICAL_AND: return set_ret(bld_->create_and(lhs, rhs));
|
||||||
@@ -40,6 +43,13 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
|
|||||||
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
|
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
|
||||||
case '.': return error_not_implemented();
|
case '.': return error_not_implemented();
|
||||||
case ',': return error_not_implemented();
|
case ',': return error_not_implemented();
|
||||||
|
case Token::ELLIPSIS: {
|
||||||
|
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
|
||||||
|
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
|
||||||
|
if(!clhs || !crhs)
|
||||||
|
should_not_happen();
|
||||||
|
return set_ret(ir::constant_range::get(clhs, crhs));
|
||||||
|
}
|
||||||
case '+':
|
case '+':
|
||||||
if(binary->lhs_->Type()->ToPointer())
|
if(binary->lhs_->Type()->ToPointer())
|
||||||
return set_ret(bld_->create_gep(lhs, {rhs}));
|
return set_ret(bld_->create_gep(lhs, {rhs}));
|
||||||
@@ -210,6 +220,14 @@ void Generator::VisitDeclaration(Declaration* decl) {
|
|||||||
if(inits.size() > 1)
|
if(inits.size() > 1)
|
||||||
assert(false);
|
assert(false);
|
||||||
val = inits[0];
|
val = inits[0];
|
||||||
|
std::cout << obj->Name() << " " << val->get_type()->get_type_id() << " " << ty->get_type_id() << std::endl;
|
||||||
|
if(val->get_type()->is_tile_ty() && ty->is_tile_ty()) {
|
||||||
|
for(auto s: val->get_type()->get_tile_shapes())
|
||||||
|
std::cout << s->get_value() << std::endl;
|
||||||
|
std::cout << "---" << std::endl;
|
||||||
|
for(auto s: ty->get_tile_shapes())
|
||||||
|
std::cout << s->get_value() << std::endl;
|
||||||
|
}
|
||||||
assert(val->get_type() == ty);
|
assert(val->get_type() == ty);
|
||||||
// update scope symbols table
|
// update scope symbols table
|
||||||
const std::string &name = obj->Name();
|
const std::string &name = obj->Name();
|
||||||
@@ -258,6 +276,38 @@ void Generator::VisitIfStmt(IfStmt* ifStmt) {
|
|||||||
bld_->set_insert_point(endif_bb);
|
bld_->set_insert_point(endif_bb);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Generator::VisitForStmt(ForStmt *forStmt) {
|
||||||
|
Stmt *init_ = forStmt->init_;
|
||||||
|
Expr *cond_ = forStmt->cond_;
|
||||||
|
Expr *step_ = forStmt->step_;
|
||||||
|
Stmt *body_ = forStmt->body_;
|
||||||
|
ir::basic_block *current_bb = bld_->get_insert_block();
|
||||||
|
ir::function *fn = current_bb->get_parent();
|
||||||
|
ir::basic_block *loop_bb = ir::basic_block::create(*ctx_, "loop", fn);
|
||||||
|
ir::basic_block *next_bb = ir::basic_block::create(*ctx_, "postloop", fn);
|
||||||
|
mod_->set_continue_fn([&](){
|
||||||
|
if(step_)
|
||||||
|
VisitExpr(step_);
|
||||||
|
VisitExpr(cond_);
|
||||||
|
ir::value *cond = ret_;
|
||||||
|
return bld_->create_cond_br(cond, loop_bb, next_bb);
|
||||||
|
});
|
||||||
|
VisitStmt(init_);
|
||||||
|
VisitExpr(cond_);
|
||||||
|
ir::value *cond = ret_;
|
||||||
|
bld_->create_cond_br(cond, loop_bb, next_bb);
|
||||||
|
bld_->set_insert_point(loop_bb);
|
||||||
|
VisitStmt(body_);
|
||||||
|
if(!is_terminator(ret_))
|
||||||
|
mod_->get_continue_fn()();
|
||||||
|
ir::basic_block *stop_bb = bld_->get_insert_block();
|
||||||
|
mod_->seal_block(stop_bb);
|
||||||
|
mod_->seal_block(loop_bb);
|
||||||
|
mod_->seal_block(bld_->get_insert_block());
|
||||||
|
mod_->seal_block(next_bb);
|
||||||
|
bld_->set_insert_point(next_bb);
|
||||||
|
}
|
||||||
|
|
||||||
void Generator::VisitJumpStmt(JumpStmt* jumpStmt) {
|
void Generator::VisitJumpStmt(JumpStmt* jumpStmt) {
|
||||||
return error_not_implemented();
|
return error_not_implemented();
|
||||||
}
|
}
|
||||||
@@ -277,7 +327,7 @@ void Generator::VisitLabelStmt(LabelStmt* labelStmt) {
|
|||||||
|
|
||||||
void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
|
void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
|
||||||
if (compoundStmt->scope_){
|
if (compoundStmt->scope_){
|
||||||
AllocObjects(compoundStmt->scope_);
|
// AllocObjects(compoundStmt->scope_);
|
||||||
pushScope();
|
pushScope();
|
||||||
}
|
}
|
||||||
for (auto stmt: compoundStmt->stmts_)
|
for (auto stmt: compoundStmt->stmts_)
|
||||||
@@ -287,32 +337,99 @@ void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitFuncDef(FuncDef* funcDef) {
|
void Generator::VisitFuncDef(FuncDef* funcDef) {
|
||||||
return error_not_implemented();
|
Stmt *body = funcDef->body_;
|
||||||
|
const std::string& name = funcDef->Name();
|
||||||
|
FuncType* type = funcDef->FuncType();
|
||||||
|
auto prototype = dynamic_cast<ir::function_type*>(GenIRType(type, *ctx_));
|
||||||
|
if(!prototype)
|
||||||
|
should_not_happen();
|
||||||
|
ir::function *fn = mod_->get_or_insert_function(name, prototype);
|
||||||
|
std::vector<ir::argument*> args = fn->args();
|
||||||
|
size_t i = 0;
|
||||||
|
for(Object* obj: type->Params()){
|
||||||
|
std::string name = obj->Name();
|
||||||
|
args[i]->set_name(name);
|
||||||
|
mod_->set_value(name, nullptr, args[i]);
|
||||||
|
mod_->get_scope().types[name] = args[i]->get_type();
|
||||||
|
}
|
||||||
|
ir::basic_block *entry = ir::basic_block::create(mod_->get_context(), "entry", fn);
|
||||||
|
mod_->seal_block(entry);
|
||||||
|
mod_->get_builder().set_insert_point(entry);
|
||||||
|
VisitStmt(body);
|
||||||
|
if(!dynamic_cast<ir::return_inst*>(ret_))
|
||||||
|
mod_->get_builder().create_ret_void();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::VisitTranslationUnit(TranslationUnit* unit) {
|
void Generator::VisitTranslationUnit(TranslationUnit* unit) {
|
||||||
|
pushScope();
|
||||||
for (auto extDecl: unit->ExtDecls())
|
for (auto extDecl: unit->ExtDecls())
|
||||||
Visit(extDecl);
|
Visit(extDecl);
|
||||||
|
popScope();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Generator::Gen(ir::module *mod) {
|
void Generator::Gen(ir::module *mod) {
|
||||||
pushScope();
|
|
||||||
mod_ = mod;
|
mod_ = mod;
|
||||||
ctx_ = &mod_->get_context();
|
ctx_ = &mod_->get_context();
|
||||||
bld_ = &mod_->get_builder();
|
bld_ = &mod_->get_builder();
|
||||||
std::unique_ptr<LValAssigner> assign(new LValAssigner(this));
|
assign_ = new LValAssigner(this);
|
||||||
assign_ = assign.get();
|
|
||||||
VisitTranslationUnit(parser_->Unit());
|
VisitTranslationUnit(parser_->Unit());
|
||||||
|
delete assign_;
|
||||||
assign_ = nullptr;
|
assign_ = nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Triton-IR Values
|
// Triton-IR Values
|
||||||
|
|
||||||
ir::value* Generator::GenCastOp(ir::value* op, ir::type* type) {
|
ir::value* Generator::GenCastOp(ir::value* src, ir::type* dst_ty) {
|
||||||
//TODO
|
if(dst_ty->is_tile_ty()) {
|
||||||
assert(false);
|
auto dst_shapes = dst_ty->get_tile_shapes();
|
||||||
return nullptr;
|
if(!src->get_type()->is_tile_ty())
|
||||||
|
return bld_->create_splat(src, dst_shapes);
|
||||||
|
auto src_shapes = src->get_type()->get_tile_shapes();
|
||||||
|
if(src_shapes.size() != dst_shapes.size())
|
||||||
|
return bld_->create_reshape(src, dst_shapes);
|
||||||
|
else
|
||||||
|
return bld_->create_broadcast(src, dst_shapes);
|
||||||
|
}
|
||||||
|
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
|
||||||
|
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
|
||||||
|
bool src_signed = false;
|
||||||
|
bool dst_signed = false;
|
||||||
|
|
||||||
|
if(src->get_type()->is_tile_ty())
|
||||||
|
dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type());
|
||||||
|
|
||||||
|
if(src_scalar_ty == dst_scalar_ty)
|
||||||
|
return src;
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||||
|
return bld_->create_si_to_fp(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||||
|
return bld_->create_ui_to_fp(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
|
||||||
|
return bld_->create_fp_to_si(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
|
||||||
|
return bld_->create_fp_to_ui(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
||||||
|
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
|
||||||
|
return bld_->create_fp_ext(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
||||||
|
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
|
||||||
|
return bld_->create_fp_trunc(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
|
||||||
|
src_scalar_ty->get_integer_bitwidth())
|
||||||
|
return bld_->create_int_cast(src, dst_ty, dst_signed);
|
||||||
|
|
||||||
|
else{
|
||||||
|
should_not_happen();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Triton-IR Types
|
// Triton-IR Types
|
||||||
|
@@ -450,8 +450,9 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
|||||||
// create ret shape
|
// create ret shape
|
||||||
TileType::ShapeInt shape;
|
TileType::ShapeInt shape;
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
|
const Token* tok;
|
||||||
do {
|
do {
|
||||||
auto tok = ts_.Next();
|
tok = ts_.Next();
|
||||||
if(tok->tag_ == ':')
|
if(tok->tag_ == ':')
|
||||||
shape.push_back(lhsShape[i++]);
|
shape.push_back(lhsShape[i++]);
|
||||||
else if(tok->tag_ == Token::NEWAXIS)
|
else if(tok->tag_ == Token::NEWAXIS)
|
||||||
@@ -460,6 +461,8 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
|||||||
Error(tok, "only ':' and newaxis are supported in subscripts");
|
Error(tok, "only ':' and newaxis are supported in subscripts");
|
||||||
}while(ts_.Try(','));
|
}while(ts_.Try(','));
|
||||||
ts_.Expect(']');
|
ts_.Expect(']');
|
||||||
|
// if(lhsShape.size() > i)
|
||||||
|
// Error(tok, "broadcasting not using all operand axes");
|
||||||
// create ret tile
|
// create ret tile
|
||||||
TileType *retType = TileType::New(shape, lhsQual);
|
TileType *retType = TileType::New(shape, lhsQual);
|
||||||
return UnaryOp::New(Token::CAST, lhs, retType);
|
return UnaryOp::New(Token::CAST, lhs, retType);
|
||||||
@@ -2298,61 +2301,33 @@ IfStmt* Parser::ParseIfStmt() {
|
|||||||
continueDest_ = continueDestBackup; \
|
continueDest_ = continueDestBackup; \
|
||||||
}
|
}
|
||||||
|
|
||||||
CompoundStmt* Parser::ParseForStmt() {
|
ForStmt* Parser::ParseForStmt() {
|
||||||
EnterBlock();
|
EnterBlock();
|
||||||
ts_.Expect('(');
|
ts_.Expect('(');
|
||||||
|
// init
|
||||||
std::list<Stmt*> stmts;
|
Stmt* init = nullptr;
|
||||||
|
|
||||||
if (IsType(ts_.Peek())) {
|
if (IsType(ts_.Peek())) {
|
||||||
stmts.push_back(ParseDecl());
|
init = ParseDecl();
|
||||||
} else if (!ts_.Try(';')) {
|
} else if (!ts_.Try(';')) {
|
||||||
stmts.push_back(ParseExpr());
|
init = ParseExpr();
|
||||||
ts_.Expect(';');
|
ts_.Expect(';');
|
||||||
}
|
}
|
||||||
|
// cond
|
||||||
Expr* condExpr = nullptr;
|
Expr* cond = nullptr;
|
||||||
if (!ts_.Try(';')) {
|
if (!ts_.Try(';')) {
|
||||||
condExpr = ParseExpr();
|
cond = ParseExpr();
|
||||||
ts_.Expect(';');
|
ts_.Expect(';');
|
||||||
}
|
}
|
||||||
|
// step
|
||||||
Expr* stepExpr = nullptr;
|
Expr* step = nullptr;
|
||||||
if (!ts_.Try(')')) {
|
if (!ts_.Try(')')) {
|
||||||
stepExpr = ParseExpr();
|
step = ParseExpr();
|
||||||
ts_.Expect(')');
|
ts_.Expect(')');
|
||||||
}
|
}
|
||||||
|
// body
|
||||||
auto condLabel = LabelStmt::New();
|
Stmt* body = ParseStmt();
|
||||||
auto stepLabel = LabelStmt::New();
|
|
||||||
auto endLabel = LabelStmt::New();
|
|
||||||
stmts.push_back(condLabel);
|
|
||||||
if (condExpr) {
|
|
||||||
auto gotoEndStmt = JumpStmt::New(endLabel);
|
|
||||||
auto ifStmt = IfStmt::New(condExpr, EmptyStmt::New(), gotoEndStmt);
|
|
||||||
stmts.push_back(ifStmt);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 我们需要给break和continue语句提供相应的标号,不然不知往哪里跳
|
|
||||||
Stmt* bodyStmt;
|
|
||||||
ENTER_LOOP_BODY(endLabel, stepLabel);
|
|
||||||
bodyStmt = ParseStmt();
|
|
||||||
// 因为for的嵌套结构,在这里需要回复break和continue的目标标号
|
|
||||||
EXIT_LOOP_BODY()
|
|
||||||
|
|
||||||
stmts.push_back(bodyStmt);
|
|
||||||
stmts.push_back(stepLabel);
|
|
||||||
if (stepExpr)
|
|
||||||
stmts.push_back(stepExpr);
|
|
||||||
else
|
|
||||||
stmts.push_back(EmptyStmt::New());
|
|
||||||
stmts.push_back(JumpStmt::New(condLabel));
|
|
||||||
stmts.push_back(endLabel);
|
|
||||||
|
|
||||||
auto scope = curScope_;
|
|
||||||
ExitBlock();
|
ExitBlock();
|
||||||
|
return ForStmt::New(body, init, cond, step);
|
||||||
return CompoundStmt::New(stmts, scope);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -317,11 +317,13 @@ bool ArrayType::Compatible(const Type& other) const {
|
|||||||
|
|
||||||
bool TileType::Compatible(const Type& other) const {
|
bool TileType::Compatible(const Type& other) const {
|
||||||
// For two tile type to be compatible,
|
// For two tile type to be compatible,
|
||||||
// the element types must be compatible, and have same shape
|
// the element types must be compatible
|
||||||
// if both specified
|
// and they must have compatible shapes
|
||||||
auto otherTile = other.ToTile();
|
auto otherTile = other.ToTile();
|
||||||
if(!otherTile) return false;
|
if(!otherTile)
|
||||||
if (!derived_->Compatible(*otherTile->derived_)) return false;
|
return false;
|
||||||
|
if (!derived_->Compatible(*otherTile->derived_))
|
||||||
|
return false;
|
||||||
// The shapes should be equal if both specified
|
// The shapes should be equal if both specified
|
||||||
if(complete_ && otherTile->complete_)
|
if(complete_ && otherTile->complete_)
|
||||||
return shape_ == otherTile->shape_;
|
return shape_ == otherTile->shape_;
|
||||||
|
@@ -8,6 +8,7 @@
|
|||||||
#include "triton/lang/lang.h"
|
#include "triton/lang/lang.h"
|
||||||
#include "triton/lang/wgtcc/cpp.h"
|
#include "triton/lang/wgtcc/cpp.h"
|
||||||
#include "triton/lang/wgtcc/parser.h"
|
#include "triton/lang/wgtcc/parser.h"
|
||||||
|
#include "triton/lang/wgtcc/code_gen.h"
|
||||||
#include "triton/driver/device.h"
|
#include "triton/driver/device.h"
|
||||||
#include "triton/driver/stream.h"
|
#include "triton/driver/stream.h"
|
||||||
#include "triton/driver/kernel.h"
|
#include "triton/driver/kernel.h"
|
||||||
@@ -133,6 +134,9 @@ triton::lang::translation_unit *function::make_ast(const char *csrc) {
|
|||||||
cpp.Process(ts);
|
cpp.Process(ts);
|
||||||
Parser parser(ts);
|
Parser parser(ts);
|
||||||
parser.Parse();
|
parser.Parse();
|
||||||
|
Generator gen(&parser);
|
||||||
|
ir::module out("", ctx_);
|
||||||
|
gen.Gen(&out);
|
||||||
exit(EXIT_FAILURE);
|
exit(EXIT_FAILURE);
|
||||||
|
|
||||||
// if (only_preprocess) {
|
// if (only_preprocess) {
|
||||||
|
Reference in New Issue
Block a user