more debugging

This commit is contained in:
Philippe Tillet
2019-08-21 21:53:41 -07:00
parent a23225ad37
commit a6ec807223
10 changed files with 213 additions and 77 deletions

View File

@@ -138,6 +138,25 @@ private:
Stmt* else_;
};
class ForStmt: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ForStmt* New(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr);
virtual ~ForStmt() {}
virtual void Accept(Visitor* v);
protected:
ForStmt(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr)
: body_(body), init_(init), cond_(cond), step_(step) {}
private:
Stmt* body_;
Stmt* init_;
Expr* cond_;
Expr* step_;
};
class JumpStmt : public Stmt {
template<typename T> friend class Evaluator;

View File

@@ -31,7 +31,7 @@ using StaticInitList = std::vector<StaticInitializer>;
// Error
inline void should_not_happen() { assert(false); }
inline void error_not_implemented() { assert(false); }
inline void error_not_implemented() { throw std::runtime_error("not implemented"); }
class Generator: public Visitor {
friend class Evaluator<Addr>;
@@ -48,32 +48,33 @@ protected:
public:
Generator(Parser* parser) : parser_(parser) {}
virtual void Visit(ASTNode* node) { node->Accept(this); }
void Visit(ASTNode* node) { node->Accept(this); }
void VisitExpr(Expr* expr) { expr->Accept(this); }
void VisitStmt(Stmt* stmt) { stmt->Accept(this); }
// Expression
virtual void VisitBinaryOp(BinaryOp* binaryOp);
virtual void VisitUnaryOp(UnaryOp* unaryOp);
virtual void VisitConditionalOp(ConditionalOp* condOp);
virtual void VisitFuncCall(FuncCall* funcCall);
virtual void VisitObject(Object* obj);
virtual void VisitEnumerator(Enumerator* enumer);
virtual void VisitIdentifier(Identifier* ident);
virtual void VisitConstant(Constant* cons);
virtual void VisitTempVar(TempVar* tempVar);
void VisitBinaryOp(BinaryOp* binaryOp);
void VisitUnaryOp(UnaryOp* unaryOp);
void VisitConditionalOp(ConditionalOp* condOp);
void VisitFuncCall(FuncCall* funcCall);
void VisitObject(Object* obj);
void VisitEnumerator(Enumerator* enumer);
void VisitIdentifier(Identifier* ident);
void VisitConstant(Constant* cons);
void VisitTempVar(TempVar* tempVar);
// Statement
virtual void VisitDeclaration(Declaration* init);
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt);
virtual void VisitIfStmt(IfStmt* ifStmt);
virtual void VisitJumpStmt(JumpStmt* jumpStmt);
virtual void VisitReturnStmt(ReturnStmt* returnStmt);
virtual void VisitLabelStmt(LabelStmt* labelStmt);
virtual void VisitCompoundStmt(CompoundStmt* compoundStmt);
void VisitDeclaration(Declaration* init);
void VisitEmptyStmt(EmptyStmt* emptyStmt);
void VisitIfStmt(IfStmt* ifStmt);
void VisitForStmt(ForStmt* ifStmt);
void VisitJumpStmt(JumpStmt* jumpStmt);
void VisitReturnStmt(ReturnStmt* returnStmt);
void VisitLabelStmt(LabelStmt* labelStmt);
void VisitCompoundStmt(CompoundStmt* compoundStmt);
virtual void VisitFuncDef(FuncDef* funcDef);
virtual void VisitTranslationUnit(TranslationUnit* unit);
void VisitFuncDef(FuncDef* funcDef);
void VisitTranslationUnit(TranslationUnit* unit);
void Gen(ir::module *mod);
@@ -127,6 +128,7 @@ public:
void VisitDeclaration(Declaration*) { should_not_happen(); }
void VisitEmptyStmt(EmptyStmt*) { should_not_happen(); }
void VisitIfStmt(IfStmt*) { should_not_happen(); }
void VisitForStmt(ForStmt*) { should_not_happen(); }
void VisitJumpStmt(JumpStmt*) { should_not_happen(); }
void VisitReturnStmt(ReturnStmt*) { should_not_happen(); }
void VisitLabelStmt(LabelStmt*) { should_not_happen(); }

View File

@@ -45,6 +45,7 @@ public:
// We may should assert here
virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
@@ -100,6 +101,7 @@ public:
// We may should assert here
virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}

View File

@@ -146,7 +146,7 @@ public:
CompoundStmt* ParseSwitchStmt();
CompoundStmt* ParseWhileStmt();
CompoundStmt* ParseDoStmt();
CompoundStmt* ParseForStmt();
ForStmt *ParseForStmt();
JumpStmt* ParseGotoStmt();
JumpStmt* ParseContinueStmt();
JumpStmt* ParseBreakStmt();

View File

@@ -14,6 +14,7 @@ class TempVar;
class Declaration;
class IfStmt;
class ForStmt;
class JumpStmt;
class ReturnStmt;
class LabelStmt;
@@ -38,6 +39,7 @@ public:
virtual void VisitDeclaration(Declaration* init) = 0;
virtual void VisitIfStmt(IfStmt* ifStmt) = 0;
virtual void VisitForStmt(ForStmt* ifStmt) = 0;
virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0;
virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0;
virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0;

View File

@@ -18,6 +18,7 @@ static MemPoolImp<TempVar> tempVarPool;
static MemPoolImp<UnaryOp> unaryOpPool;
static MemPoolImp<EmptyStmt> emptyStmtPool;
static MemPoolImp<IfStmt> ifStmtPool;
static MemPoolImp<ForStmt> forStmtPool;
static MemPoolImp<JumpStmt> jumpStmtPool;
static MemPoolImp<ReturnStmt> returnStmtPool;
static MemPoolImp<LabelStmt> labelStmtPool;
@@ -48,6 +49,10 @@ void IfStmt::Accept(Visitor* v) {
v->VisitIfStmt(this);
}
void ForStmt::Accept(Visitor* v) {
v->VisitForStmt(this);
}
void JumpStmt::Accept(Visitor* v) {
v->VisitJumpStmt(this);
@@ -396,6 +401,7 @@ void BinaryOp::AdditiveOpTypeChecking() {
::Type* rhsScalType = TryExtractScalarType(this, rhs_);
auto lhsPtrType = lhsScalType->ToPointer();
auto rhsPtrType = rhsScalType->ToPointer();
std::cout << "adding" << std::endl;
if (lhsPtrType) {
if (op_ == '-') {
if (rhsPtrType) {
@@ -430,6 +436,7 @@ void BinaryOp::AdditiveOpTypeChecking() {
}
void BinaryOp::RangeOpTypeChecking() {
std::cout << "range" << std::endl;
auto lhsType = lhs_->Type()->ToArithm();
auto rhsType = rhs_->Type()->ToArithm();
if(!lhsType || !lhsType->IsInteger() || !rhsType || !rhsType->IsInteger())
@@ -546,6 +553,7 @@ void BinaryOp::AssignOpTypeChecking() {
// The other constraints are lefted to cast operator
rhs_ = Expr::MayCast(rhs_, ScalarOrLikeTile(rhs_, lhsScalType));
type_ = lhs_->Type();
Broadcast();
}
@@ -969,6 +977,11 @@ CompoundStmt* CompoundStmt::New(std::list<Stmt*>& stmts, ::Scope* scope) {
return ret;
}
ForStmt* ForStmt::New(Stmt* body, Stmt* init, Expr* cond, Expr* step) {
auto ret = new (forStmtPool.Alloc()) ForStmt(body, init, cond, step);
ret->pool_ = &forStmtPool;
return ret;
}
JumpStmt* JumpStmt::New(LabelStmt* label) {
auto ret = new (jumpStmtPool.Alloc()) JumpStmt(label);

View File

@@ -3,6 +3,7 @@
#include "triton/lang/wgtcc/parser.h"
#include "triton/lang/wgtcc/token.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
// Helpers
void Generator::set_ret(ir::value* value) {
@@ -25,10 +26,12 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
Visit(binary->lhs_);
ir::value* lhs = ret_;
// op info
auto type = binary->lhs_->Type();
auto flt = type->IsFloat();
auto sign = !type->IsUnsigned();
// return
switch(binary->op_){
case Token::LOGICAL_AND: return set_ret(bld_->create_and(lhs, rhs));
@@ -40,6 +43,13 @@ void Generator::VisitBinaryOp(BinaryOp* binary) {
case Token::RIGHT: return set_ret(bld_->create_lshr(lhs, rhs));
case '.': return error_not_implemented();
case ',': return error_not_implemented();
case Token::ELLIPSIS: {
auto clhs = dynamic_cast<ir::constant_int*>(lhs);
auto crhs = dynamic_cast<ir::constant_int*>(rhs);
if(!clhs || !crhs)
should_not_happen();
return set_ret(ir::constant_range::get(clhs, crhs));
}
case '+':
if(binary->lhs_->Type()->ToPointer())
return set_ret(bld_->create_gep(lhs, {rhs}));
@@ -210,6 +220,14 @@ void Generator::VisitDeclaration(Declaration* decl) {
if(inits.size() > 1)
assert(false);
val = inits[0];
std::cout << obj->Name() << " " << val->get_type()->get_type_id() << " " << ty->get_type_id() << std::endl;
if(val->get_type()->is_tile_ty() && ty->is_tile_ty()) {
for(auto s: val->get_type()->get_tile_shapes())
std::cout << s->get_value() << std::endl;
std::cout << "---" << std::endl;
for(auto s: ty->get_tile_shapes())
std::cout << s->get_value() << std::endl;
}
assert(val->get_type() == ty);
// update scope symbols table
const std::string &name = obj->Name();
@@ -258,6 +276,38 @@ void Generator::VisitIfStmt(IfStmt* ifStmt) {
bld_->set_insert_point(endif_bb);
}
void Generator::VisitForStmt(ForStmt *forStmt) {
Stmt *init_ = forStmt->init_;
Expr *cond_ = forStmt->cond_;
Expr *step_ = forStmt->step_;
Stmt *body_ = forStmt->body_;
ir::basic_block *current_bb = bld_->get_insert_block();
ir::function *fn = current_bb->get_parent();
ir::basic_block *loop_bb = ir::basic_block::create(*ctx_, "loop", fn);
ir::basic_block *next_bb = ir::basic_block::create(*ctx_, "postloop", fn);
mod_->set_continue_fn([&](){
if(step_)
VisitExpr(step_);
VisitExpr(cond_);
ir::value *cond = ret_;
return bld_->create_cond_br(cond, loop_bb, next_bb);
});
VisitStmt(init_);
VisitExpr(cond_);
ir::value *cond = ret_;
bld_->create_cond_br(cond, loop_bb, next_bb);
bld_->set_insert_point(loop_bb);
VisitStmt(body_);
if(!is_terminator(ret_))
mod_->get_continue_fn()();
ir::basic_block *stop_bb = bld_->get_insert_block();
mod_->seal_block(stop_bb);
mod_->seal_block(loop_bb);
mod_->seal_block(bld_->get_insert_block());
mod_->seal_block(next_bb);
bld_->set_insert_point(next_bb);
}
void Generator::VisitJumpStmt(JumpStmt* jumpStmt) {
return error_not_implemented();
}
@@ -277,7 +327,7 @@ void Generator::VisitLabelStmt(LabelStmt* labelStmt) {
void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
if (compoundStmt->scope_){
AllocObjects(compoundStmt->scope_);
// AllocObjects(compoundStmt->scope_);
pushScope();
}
for (auto stmt: compoundStmt->stmts_)
@@ -287,32 +337,99 @@ void Generator::VisitCompoundStmt(CompoundStmt* compoundStmt) {
}
void Generator::VisitFuncDef(FuncDef* funcDef) {
return error_not_implemented();
Stmt *body = funcDef->body_;
const std::string& name = funcDef->Name();
FuncType* type = funcDef->FuncType();
auto prototype = dynamic_cast<ir::function_type*>(GenIRType(type, *ctx_));
if(!prototype)
should_not_happen();
ir::function *fn = mod_->get_or_insert_function(name, prototype);
std::vector<ir::argument*> args = fn->args();
size_t i = 0;
for(Object* obj: type->Params()){
std::string name = obj->Name();
args[i]->set_name(name);
mod_->set_value(name, nullptr, args[i]);
mod_->get_scope().types[name] = args[i]->get_type();
}
ir::basic_block *entry = ir::basic_block::create(mod_->get_context(), "entry", fn);
mod_->seal_block(entry);
mod_->get_builder().set_insert_point(entry);
VisitStmt(body);
if(!dynamic_cast<ir::return_inst*>(ret_))
mod_->get_builder().create_ret_void();
}
void Generator::VisitTranslationUnit(TranslationUnit* unit) {
pushScope();
for (auto extDecl: unit->ExtDecls())
Visit(extDecl);
popScope();
}
void Generator::Gen(ir::module *mod) {
pushScope();
mod_ = mod;
ctx_ = &mod_->get_context();
bld_ = &mod_->get_builder();
std::unique_ptr<LValAssigner> assign(new LValAssigner(this));
assign_ = assign.get();
assign_ = new LValAssigner(this);
VisitTranslationUnit(parser_->Unit());
delete assign_;
assign_ = nullptr;
}
// Triton-IR Values
ir::value* Generator::GenCastOp(ir::value* op, ir::type* type) {
//TODO
assert(false);
ir::value* Generator::GenCastOp(ir::value* src, ir::type* dst_ty) {
if(dst_ty->is_tile_ty()) {
auto dst_shapes = dst_ty->get_tile_shapes();
if(!src->get_type()->is_tile_ty())
return bld_->create_splat(src, dst_shapes);
auto src_shapes = src->get_type()->get_tile_shapes();
if(src_shapes.size() != dst_shapes.size())
return bld_->create_reshape(src, dst_shapes);
else
return bld_->create_broadcast(src, dst_shapes);
}
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
bool src_signed = false;
bool dst_signed = false;
if(src->get_type()->is_tile_ty())
dst_ty = ir::tile_type::get_same_shapes(dst_scalar_ty, src->get_type());
if(src_scalar_ty == dst_scalar_ty)
return src;
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
return bld_->create_si_to_fp(src, dst_ty);
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
return bld_->create_ui_to_fp(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
return bld_->create_fp_to_si(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
return bld_->create_fp_to_ui(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
return bld_->create_fp_ext(src, dst_ty);
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
return bld_->create_fp_trunc(src, dst_ty);
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
src_scalar_ty->get_integer_bitwidth())
return bld_->create_int_cast(src, dst_ty, dst_signed);
else{
should_not_happen();
return nullptr;
}
}
// Triton-IR Types

View File

@@ -450,8 +450,9 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
// create ret shape
TileType::ShapeInt shape;
size_t i = 0;
const Token* tok;
do {
auto tok = ts_.Next();
tok = ts_.Next();
if(tok->tag_ == ':')
shape.push_back(lhsShape[i++]);
else if(tok->tag_ == Token::NEWAXIS)
@@ -460,6 +461,8 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
Error(tok, "only ':' and newaxis are supported in subscripts");
}while(ts_.Try(','));
ts_.Expect(']');
// if(lhsShape.size() > i)
// Error(tok, "broadcasting not using all operand axes");
// create ret tile
TileType *retType = TileType::New(shape, lhsQual);
return UnaryOp::New(Token::CAST, lhs, retType);
@@ -2298,61 +2301,33 @@ IfStmt* Parser::ParseIfStmt() {
continueDest_ = continueDestBackup; \
}
CompoundStmt* Parser::ParseForStmt() {
ForStmt* Parser::ParseForStmt() {
EnterBlock();
ts_.Expect('(');
std::list<Stmt*> stmts;
// init
Stmt* init = nullptr;
if (IsType(ts_.Peek())) {
stmts.push_back(ParseDecl());
init = ParseDecl();
} else if (!ts_.Try(';')) {
stmts.push_back(ParseExpr());
init = ParseExpr();
ts_.Expect(';');
}
Expr* condExpr = nullptr;
// cond
Expr* cond = nullptr;
if (!ts_.Try(';')) {
condExpr = ParseExpr();
cond = ParseExpr();
ts_.Expect(';');
}
Expr* stepExpr = nullptr;
// step
Expr* step = nullptr;
if (!ts_.Try(')')) {
stepExpr = ParseExpr();
step = ParseExpr();
ts_.Expect(')');
}
auto condLabel = LabelStmt::New();
auto stepLabel = LabelStmt::New();
auto endLabel = LabelStmt::New();
stmts.push_back(condLabel);
if (condExpr) {
auto gotoEndStmt = JumpStmt::New(endLabel);
auto ifStmt = IfStmt::New(condExpr, EmptyStmt::New(), gotoEndStmt);
stmts.push_back(ifStmt);
}
// 我们需要给break和continue语句提供相应的标号不然不知往哪里跳
Stmt* bodyStmt;
ENTER_LOOP_BODY(endLabel, stepLabel);
bodyStmt = ParseStmt();
// 因为for的嵌套结构在这里需要回复break和continue的目标标号
EXIT_LOOP_BODY()
stmts.push_back(bodyStmt);
stmts.push_back(stepLabel);
if (stepExpr)
stmts.push_back(stepExpr);
else
stmts.push_back(EmptyStmt::New());
stmts.push_back(JumpStmt::New(condLabel));
stmts.push_back(endLabel);
auto scope = curScope_;
// body
Stmt* body = ParseStmt();
ExitBlock();
return CompoundStmt::New(stmts, scope);
return ForStmt::New(body, init, cond, step);
}

View File

@@ -317,11 +317,13 @@ bool ArrayType::Compatible(const Type& other) const {
bool TileType::Compatible(const Type& other) const {
// For two tile type to be compatible,
// the element types must be compatible, and have same shape
// if both specified
// the element types must be compatible
// and they must have compatible shapes
auto otherTile = other.ToTile();
if(!otherTile) return false;
if (!derived_->Compatible(*otherTile->derived_)) return false;
if(!otherTile)
return false;
if (!derived_->Compatible(*otherTile->derived_))
return false;
// The shapes should be equal if both specified
if(complete_ && otherTile->complete_)
return shape_ == otherTile->shape_;

View File

@@ -8,6 +8,7 @@
#include "triton/lang/lang.h"
#include "triton/lang/wgtcc/cpp.h"
#include "triton/lang/wgtcc/parser.h"
#include "triton/lang/wgtcc/code_gen.h"
#include "triton/driver/device.h"
#include "triton/driver/stream.h"
#include "triton/driver/kernel.h"
@@ -133,6 +134,9 @@ triton::lang::translation_unit *function::make_ast(const char *csrc) {
cpp.Process(ts);
Parser parser(ts);
parser.Parse();
Generator gen(&parser);
ir::module out("", ctx_);
gen.Gen(&out);
exit(EXIT_FAILURE);
// if (only_preprocess) {