[Code generation] bugfixes in type logic

This commit is contained in:
Philippe Tillet
2018-12-18 23:02:28 -05:00
parent 97acf52dca
commit 176a437b21
4 changed files with 236 additions and 99 deletions

View File

@@ -1,6 +1,7 @@
#include <cstring>
#include <cstdio>
#include "ast.h"
#include "codegen.h"
typedef struct yy_buffer_state * YY_BUFFER_STATE;
extern int yyparse();
@@ -11,7 +12,7 @@ extern translation_unit *ast_root;
const char src[] =
"\
void test(int32 id){\
void test(fp32 *A, fp32 *B, fp32 *C){\
fp32 c[16, 16] = {0};\
int32 i = 0;\
i += 1;\
@@ -23,5 +24,8 @@ int main() {
yyparse();
yy_delete_buffer(buffer);
translation_unit *program = ast_root;
tdl::context context;
tdl::module module("matrix", &context);
program->codegen(&module);
return 0;
}

View File

@@ -8,7 +8,9 @@
namespace llvm{
class LLVMType;
class Function;
class Value;
class Type;
}
@@ -63,9 +65,17 @@ template<class T>
class list: public node {
public:
list(const T& x): values_{x} {}
node* append(const T& x) { values_.push_back(x); return this;}
void codegen(module* mod) { for(T x: values_){ x->codegen(mod); } }
const std::list<T> &values() const { return values_; }
node* append(const T& x){
values_.push_back(x);
return this;
}
void codegen(module* mod)
{ for(T x: values_){ x->codegen(mod); } }
const std::list<T> &values() const
{ return values_; }
private:
std::list<T> values_;
@@ -91,13 +101,6 @@ private:
const int value_;
};
class identifier: public node{
public:
identifier(char *&name): name_(name) { }
private:
std::string name_;
};
class string_literal: public node{
public:
@@ -202,106 +205,138 @@ private:
class no_op: public statement { };
// Types
class declarator: public node{
class declaration_specifier: public node{
public:
virtual llvm::LLVMType llvm_type(TYPE_T spec) const = 0;
declaration_specifier(TYPE_T spec)
: spec_(spec) { }
llvm::Type* type(module *mod) const;
private:
const TYPE_T spec_;
};
class pointer_declarator: public declarator{
class declarator;
class parameter: public node {
public:
pointer_declarator(unsigned order)
: order_(order) { }
parameter(node *spec, node *decl)
: spec_((declaration_specifier*)spec),
decl_((declarator*)decl) { }
pointer_declarator *inc(){
order_ += 1;
llvm::Type* type(module *mod) const;
public:
const declaration_specifier *spec_;
const declarator *decl_;
};
/* Declarators */
class pointer;
class identifier;
class declarator: public node{
virtual llvm::Type* type_impl(module*mod, llvm::Type *type) const = 0;
public:
declarator(node *lhs)
: lhs_((declarator*)lhs), ptr_(nullptr){ }
llvm::Type* type(module*mod, llvm::Type *type) const;
const identifier* id() const {
return (const identifier*)lhs_;
}
declarator *set_ptr(node *ptr){
ptr_ = (pointer*)ptr;
return this;
}
llvm::LLVMType llvm_type(TYPE_T spec) const;
private:
unsigned order_;
protected:
declarator *lhs_;
pointer *ptr_;
};
class tile_declarator: public declarator{
public:
tile_declarator(node *decl, node *shapes)
: decl_(decl), shapes_((list<constant*>*)(shapes)) { }
llvm::LLVMType llvm_type(TYPE_T spec) const;
class identifier: public declarator{
llvm::Type* type_impl(module*mod, llvm::Type *type) const;
public:
identifier(char *&name): declarator(nullptr), name_(name) { }
const std::string &name() const;
private:
std::string name_;
};
class pointer: public declarator{
private:
llvm::Type* type_impl(module *mod, llvm::Type *type) const;
public:
pointer(node *id): declarator(id) { }
};
class tile: public declarator{
private:
llvm::Type* type_impl(module *mod, llvm::Type *type) const;
public:
tile(node *id, node *shapes)
: declarator(id), shapes_((list<constant*>*)(shapes)) { }
public:
const node* decl_;
const list<constant*>* shapes_;
};
class function_declarator: public declarator{
public:
function_declarator(node *decl, node *args)
: decl_(decl), args_((list<node*>*)args) { }
llvm::LLVMType llvm_type(TYPE_T spec) const;
class function: public declarator{
private:
llvm::Type* type_impl(module *mod, llvm::Type *type) const;
public:
const node* decl_;
const list<node*>* args_;
function(node *id, node *args)
: declarator(id), args_((list<parameter*>*)args) { }
public:
const list<parameter*>* args_;
};
class compound_declarator: public declarator{
public:
compound_declarator(node *ptr, node *tile)
: ptr_(ptr), tile_(tile) { }
llvm::LLVMType llvm_type(TYPE_T spec) const;
class initializer : public declarator{
private:
llvm::Type* type_impl(module* mod, llvm::Type *type) const;
public:
const node *ptr_;
const node *tile_;
};
initializer(node *id, node *initializer)
: declarator(id), initializer_(initializer){ }
class init_declarator : public declarator{
public:
init_declarator(node *decl, node *initializer)
: decl_(decl), initializer_(initializer){ }
llvm::LLVMType llvm_type(TYPE_T spec) const;
public:
const node *decl_;
const node *initializer_;
};
class parameter: public node {
public:
parameter(TYPE_T spec, node *decl)
: spec_(spec), decl_(decl) { }
llvm::LLVMType* llvm_type() const;
public:
const TYPE_T spec_;
const node *decl_;
};
class type: public node{
public:
type(TYPE_T spec, node * decl)
: spec_(spec), decl_(decl) { }
: spec_(spec), decl_((declarator*)decl) { }
public:
const TYPE_T spec_;
const node *decl_;
const declarator *decl_;
};
/* Function definition */
class function_definition: public node{
public:
function_definition(TYPE_T spec, node *header, node *body)
: spec_(spec), header_((function_declarator *)header), body_((compound_statement*)body) { }
function_definition(node *spec, node *header, node *body)
: spec_((declaration_specifier*)spec), header_((function *)header), body_((compound_statement*)body) { }
void codegen(module* mod);
public:
const TYPE_T spec_;
const function_declarator *header_;
const declaration_specifier *spec_;
const function *header_;
const compound_statement *body_;
};

View File

@@ -74,17 +74,17 @@ type_specifier
;
pointer
: '*' { $$ = new pointer_declarator(1); }
| '*' pointer { $$ = ((pointer_declarator*)$1)->inc(); }
: '*' { $$ = new pointer(nullptr); }
| '*' pointer { $$ = new pointer($1); }
abstract_declarator
: pointer { $$ = $1; }
| pointer direct_abstract_declarator { $$ = ((declarator*)$2)->set_ptr($1); }
| direct_abstract_declarator { $$ = $1; }
| pointer direct_abstract_declarator { $$ = new compound_declarator($1, $2); }
;
direct_abstract_declarator
: '[' constant_list ']' { $$ = new tile_declarator(nullptr, $1); }
: '[' constant_list ']' { $$ = new tile(nullptr, $1); }
constant :
CONSTANT { $$ = new constant(atoi(yytext)); }
@@ -241,9 +241,9 @@ statement
compound_statement
: '{' '}' { $$ = new compound_statement(nullptr, nullptr); }
| '{' statement_list '}' { $$ = new compound_statement(nullptr, $1); }
| '{' declaration_list '}' { $$ = new compound_statement($1, nullptr); }
| '{' declaration_list statement_list '}' { $$ = new compound_statement($1, $2);}
| '{' statement_list '}' { $$ = new compound_statement(nullptr, $2); }
| '{' declaration_list '}' { $$ = new compound_statement($2, nullptr); }
| '{' declaration_list statement_list '}' { $$ = new compound_statement($2, $3);}
;
@@ -262,13 +262,13 @@ expression_statement
;
selection_statement
: IF '(' expression ')' statement { $$ = new selection_statement($1, $2); }
| IF '(' expression ')' statement ELSE statement { $$ = new selection_statement($1, $2, $3); }
: IF '(' expression ')' statement { $$ = new selection_statement($1, $3); }
| IF '(' expression ')' statement ELSE statement { $$ = new selection_statement($1, $3, $5); }
;
iteration_statement
: FOR '(' expression_statement expression_statement ')' statement { $$ = new iteration_statement($1, $2, NULL, $3); }
| FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($1, $2, $3, $3); }
: FOR '(' expression_statement expression_statement ')' statement { $$ = new iteration_statement($1, $3, NULL, $4); }
| FOR '(' expression_statement expression_statement expression ')' statement { $$ = new iteration_statement($1, $3, $4, $5); }
;
@@ -279,30 +279,30 @@ iteration_statement
direct_declarator
: identifier { $$ = $1; }
| identifier '[' constant_list ']' { $$ = new tile_declarator($1, $2); }
| identifier '(' parameter_list ')' { $$ = new function_declarator($1, $2); }
| identifier '(' ')' { $$ = new function_declarator($1, nullptr); }
;
| identifier '[' constant_list ']' { $$ = new tile($1, $3); }
| identifier '(' parameter_list ')' { $$ = new function($1, $3); }
| identifier '(' ')' { $$ = new function($1, nullptr); }
;
parameter_list
: parameter_declaration { $$ = new list<parameter*>((parameter*)$1); }
| parameter_list ',' parameter_declaration { $$ = append_ptr_list<parameter>($1, $2); }
| parameter_list ',' parameter_declaration { $$ = append_ptr_list<parameter>($1, $3); }
;
parameter_declaration
: declaration_specifiers declarator { $$ = new parameter(get_type_spec($1), $2); }
| declaration_specifiers abstract_declarator { $$ = new parameter(get_type_spec($1), $2); }
: declaration_specifiers declarator { $$ = new parameter($1, $2); }
| declaration_specifiers abstract_declarator { $$ = new parameter($1, $2); }
;
declaration_specifiers
: type_specifier { $$ = $1; }
: type_specifier { $$ = new declaration_specifier(get_type_spec($1)); }
;
init_declarator_list
: init_declarator { $$ = new list<init_declarator*>((init_declarator*)$1); }
| init_declarator_list ',' init_declarator { $$ = append_ptr_list<init_declarator>($1, $2); }
: init_declarator { $$ = new list<initializer*>((initializer*)$1); }
| init_declarator_list ',' init_declarator { $$ = append_ptr_list<initializer>($1, $3); }
;
declaration
@@ -311,18 +311,18 @@ declaration
;
declarator
: pointer direct_declarator { $$ = new compound_declarator($1, $2); }
: pointer direct_declarator { $$ = ((declarator*)$2)->set_ptr($1); }
| direct_declarator { $$ = $1; }
;
initializer
: assignment_expression { $$ = $1; }
| '{' constant '}' { $$ = $1; }
| '{' constant '}' { $$ = $2; }
;
init_declarator
: declarator { $$ = new init_declarator($1, nullptr); }
| declarator '=' initializer { $$ = new init_declarator($1, $2); }
: declarator { $$ = new initializer($1, nullptr); }
| declarator '=' initializer { $$ = new initializer($1, $3); }
;
/* -------------------------- */
@@ -330,7 +330,7 @@ init_declarator
/* -------------------------- */
translation_unit
: external_declaration { $$ = new translation_unit($1); }
: external_declaration { ast_root = new translation_unit($1); $$ = ast_root; }
| translation_unit external_declaration { $$ = ((translation_unit*)($1))->add($2); }
;
@@ -340,6 +340,6 @@ external_declaration
;
function_definition
: type_specifier declarator compound_statement { $$ = new function_definition(get_type_spec($1), $2, $3); }
: declaration_specifiers declarator compound_statement { $$ = new function_definition($1, $2, $3); }
;

View File

@@ -1,12 +1,110 @@
#include "ast.h"
#include "codegen.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
using namespace llvm;
namespace tdl{
/* Context */
context::context() { }
LLVMContext *context::handle() {
return &handle_;
}
/* Module */
module::module(const std::string &name, context *ctx)
: handle_(name.c_str(), *ctx->handle()), builder_(*ctx->handle()) {
}
llvm::Module* module::handle() {
return &handle_;
}
llvm::IRBuilder<>& module::builder() {
return builder_;
}
namespace ast{
void translation_unit::codegen(module *mod)
{ decls_->codegen(mod); }
/* Translation unit */
void translation_unit::codegen(module *mod){
decls_->codegen(mod);
}
/* Declaration specifier */
Type* declaration_specifier::type(module *mod) const {
LLVMContext &ctx = mod->handle()->getContext();
switch (spec_) {
case VOID_T: return Type::getVoidTy(ctx);
case INT8_T: return IntegerType::get(ctx, 8);
case INT16_T: return IntegerType::get(ctx, 16);
case INT32_T: return IntegerType::get(ctx, 32);
case INT64_T: return IntegerType::get(ctx, 64);
case FLOAT32_T: return Type::getFloatTy(ctx);
case FLOAT64_T: return Type::getDoubleTy(ctx);
default: assert(false && "unreachable"); throw;
}
}
/* Parameter */
Type* parameter::type(module *mod) const {
return decl_->type(mod, spec_->type(mod));
}
/* Declarators */
Type* declarator::type(module *mod, Type *type) const{
if(ptr_)
return type_impl(mod, ptr_->type(mod, type));
return type_impl(mod, type);
}
// Identifier
Type* identifier::type_impl(module *, Type *type) const{
return type;
}
const std::string &identifier::name() const{
return name_;
}
// Tile
Type* tile::type_impl(module*, Type *type) const{
return TileType::get(type, shapes_->values().size());
}
// Initializer
Type* initializer::type_impl(module *, Type *type) const{
return type;
}
// Pointer
Type* pointer::type_impl(module*, Type *type) const{
return PointerType::get(type, 1);
}
// Function
Type* function::type_impl(module*mod, Type *type) const{
SmallVector<Type*, 8> types;
for(parameter* param: args_->values()){
types.push_back(param->type(mod));
}
return FunctionType::get(type, types, false);
}
/* Function definition */
void function_definition::codegen(module *mod){
llvm::FunctionType *prototype = (llvm::FunctionType *)header_->type(mod, spec_->type(mod));
const std::string &name = header_->id()->name();
llvm::Function *fn = llvm::Function::Create(prototype, llvm::Function::ExternalLinkage, name, mod->handle());
llvm::BasicBlock::Create(mod->handle()->getContext(), "entry", fn);
mod->builder().SetInsertPoint();
}
}