[triton/ast]: cleaned the ast module
This commit is contained in:
@@ -1,691 +1,12 @@
|
|||||||
#ifndef TDL_INCLUDE_AST_H
|
#ifndef TRITON_INCLUDE_AST_AST_H
|
||||||
#define TDL_INCLUDE_AST_H
|
#define TRITON_INCLUDE_AST_AST_H
|
||||||
|
|
||||||
|
#include "ops.h"
|
||||||
#include "parser.hpp"
|
#include "parser.hpp"
|
||||||
#include <cassert>
|
#include "declaration.h"
|
||||||
#include <vector>
|
#include "error.h"
|
||||||
#include <string>
|
#include "expression.h"
|
||||||
#include <iostream>
|
#include "node.h"
|
||||||
|
#include "ops.h"
|
||||||
|
|
||||||
namespace triton{
|
|
||||||
|
|
||||||
|
|
||||||
namespace ir{
|
|
||||||
class function;
|
|
||||||
class value;
|
|
||||||
class type;
|
|
||||||
class builder;
|
|
||||||
class module;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace ast{
|
|
||||||
|
|
||||||
// Enumerations
|
|
||||||
enum ASSIGN_OP_T{
|
|
||||||
ASSIGN,
|
|
||||||
INPLACE_MUL, INPLACE_DIV, INPLACE_MOD,
|
|
||||||
INPLACE_ADD, INPLACE_SUB,
|
|
||||||
INPLACE_LSHIFT, INPLACE_RSHIFT,
|
|
||||||
INPLACE_AND, INPLACE_XOR,
|
|
||||||
INPLACE_OR
|
|
||||||
};
|
|
||||||
|
|
||||||
enum BIN_OP_T{
|
|
||||||
MUL, DIV, MOD,
|
|
||||||
ADD, SUB,
|
|
||||||
LEFT_SHIFT, RIGHT_SHIFT,
|
|
||||||
LT, GT,
|
|
||||||
LE, GE,
|
|
||||||
EQ, NE,
|
|
||||||
AND, XOR, OR,
|
|
||||||
LAND, LOR
|
|
||||||
};
|
|
||||||
|
|
||||||
enum UNARY_OP_T{
|
|
||||||
INC, DEC,
|
|
||||||
PLUS, MINUS,
|
|
||||||
ADDR, DEREF,
|
|
||||||
COMPL, NOT
|
|
||||||
};
|
|
||||||
|
|
||||||
enum TYPE_T{
|
|
||||||
VOID_T,
|
|
||||||
UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T,
|
|
||||||
INT1_T, INT8_T, INT16_T, INT32_T, INT64_T,
|
|
||||||
FLOAT32_T, FLOAT64_T
|
|
||||||
};
|
|
||||||
|
|
||||||
enum STORAGE_SPEC_T{
|
|
||||||
CONST_T,
|
|
||||||
TUNABLE_T,
|
|
||||||
KERNEL_T,
|
|
||||||
RESTRICT_T,
|
|
||||||
READONLY_T,
|
|
||||||
CONSTANT_SPACE_T,
|
|
||||||
WRITEONLY_T
|
|
||||||
};
|
|
||||||
|
|
||||||
class pointer;
|
|
||||||
class identifier;
|
|
||||||
class constant;
|
|
||||||
|
|
||||||
// AST
|
|
||||||
class node {
|
|
||||||
protected:
|
|
||||||
static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty);
|
|
||||||
static void implicit_broadcast(ir::module *mod, ir::type *dst_ty, ir::value *&src);
|
|
||||||
static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs);
|
|
||||||
static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
|
||||||
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed);
|
|
||||||
public:
|
|
||||||
virtual ir::value* codegen(ir::module *) const { return nullptr; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template<class T>
|
|
||||||
class list: public node {
|
|
||||||
public:
|
|
||||||
list(const T& x): values_(1, x) {}
|
|
||||||
|
|
||||||
node* append(const T& x){
|
|
||||||
values_.push_back(x);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module * mod) const{
|
|
||||||
for(T x: values_){
|
|
||||||
x->codegen(mod);
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<T> &values() const
|
|
||||||
{ return values_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::vector<T> values_;
|
|
||||||
};
|
|
||||||
|
|
||||||
enum slice_enum_t{
|
|
||||||
ALL,
|
|
||||||
NEWAXIS
|
|
||||||
};
|
|
||||||
|
|
||||||
class slice: public node{
|
|
||||||
public:
|
|
||||||
slice(slice_enum_t type)
|
|
||||||
: type_(type){}
|
|
||||||
|
|
||||||
slice_enum_t type() const{
|
|
||||||
return type_;
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
const slice_enum_t type_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class named_expression;
|
|
||||||
|
|
||||||
class expression: public node{
|
|
||||||
public:
|
|
||||||
virtual ir::value* codegen(ir::module *) const = 0;
|
|
||||||
named_expression *lvalue() const { return lvalue_; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
named_expression *lvalue_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class postfix_expression: public expression{
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
class builtin_expression: public node{
|
|
||||||
|
|
||||||
};
|
|
||||||
|
|
||||||
class typed_declaration_specifier;
|
|
||||||
class alloc_const: public builtin_expression{
|
|
||||||
public:
|
|
||||||
alloc_const(node *spec, node *size): spec_((typed_declaration_specifier*)spec), size_((constant*)size) { }
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const typed_declaration_specifier* spec_;
|
|
||||||
const constant* size_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class get_global_range: public builtin_expression{
|
|
||||||
public:
|
|
||||||
get_global_range(node *size, node *axis): size_((constant*)size), axis_((constant*)axis) { }
|
|
||||||
ir::value* codegen(ir::module *) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const constant* size_;
|
|
||||||
const constant* axis_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class get_range_id: public builtin_expression{
|
|
||||||
public:
|
|
||||||
get_range_id(node *axis): axis_((constant*)axis) { }
|
|
||||||
ir::value* codegen(ir::module *) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const constant* axis_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class atomic_cas: public builtin_expression{
|
|
||||||
public:
|
|
||||||
atomic_cas(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { }
|
|
||||||
ir::value* codegen(ir::module *) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const node *ptr_;
|
|
||||||
const node *cmp_;
|
|
||||||
const node *val_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class matmul_expression: public builtin_expression{
|
|
||||||
public:
|
|
||||||
matmul_expression(node* A, node *B, node *C):
|
|
||||||
A_((expression*)A), B_((expression*)B), C_((expression*)C) { }
|
|
||||||
ir::value* codegen(ir::module *) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const expression *A_;
|
|
||||||
const expression *B_;
|
|
||||||
const expression *C_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class max_expression: public builtin_expression{
|
|
||||||
public:
|
|
||||||
max_expression(node* x, node* y)
|
|
||||||
: x_((expression*)x), y_((expression*)y){ }
|
|
||||||
ir::value* codegen(ir::module *) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const expression *x_;
|
|
||||||
const expression *y_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class min_expression: public builtin_expression{
|
|
||||||
public:
|
|
||||||
min_expression(node* x, node* y)
|
|
||||||
: x_((expression*)x), y_((expression*)y){ }
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const expression *x_;
|
|
||||||
const expression *y_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class select_expression: public builtin_expression{
|
|
||||||
public:
|
|
||||||
select_expression(node* pred, node* if_value, node* else_value)
|
|
||||||
: pred_((expression*)pred), if_value_((expression*)if_value), else_value_((expression*)else_value) { }
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const expression *pred_;
|
|
||||||
const expression *if_value_;
|
|
||||||
const expression *else_value_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class trans_expression: public builtin_expression{
|
|
||||||
public:
|
|
||||||
trans_expression(node *arg): arg_(arg) {}
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
node* arg_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class indexing_expression: public postfix_expression{
|
|
||||||
public:
|
|
||||||
indexing_expression(node *id, node *slices)
|
|
||||||
: id_((const identifier*)id), slices_((const list<slice*>*)slices) {}
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module *) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const identifier* id_;
|
|
||||||
const list<slice*>* slices_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class named_expression: public expression {
|
|
||||||
public:
|
|
||||||
named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; }
|
|
||||||
const identifier *id() const { return id_; }
|
|
||||||
ir::value* codegen(ir::module * mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const identifier *id_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class binary_operator: public expression{
|
|
||||||
private:
|
|
||||||
ir::value* llvm_op(ir::module *mod, ir::builder &bld, ir::value *lhs, ir::value *rhs, const std::string &name) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
binary_operator(BIN_OP_T op, node *lhs, node *rhs)
|
|
||||||
: op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) {
|
|
||||||
}
|
|
||||||
ir::value* codegen(ir::module *) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const BIN_OP_T op_;
|
|
||||||
const expression *lhs_;
|
|
||||||
const expression *rhs_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class constant: public expression{
|
|
||||||
public:
|
|
||||||
constant(int value): value_(value) { }
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
int value() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const int value_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class constant_range: public expression {
|
|
||||||
public:
|
|
||||||
constant_range(node *first, node *last)
|
|
||||||
: first_((constant*)first), last_((constant*)last) { }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
constant *first_;
|
|
||||||
constant *last_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class string_literal: public expression{
|
|
||||||
public:
|
|
||||||
string_literal(char *&value): value_(value) { }
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
std::string value_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class unary_operator: public expression{
|
|
||||||
private:
|
|
||||||
ir::value *llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
unary_operator(UNARY_OP_T op, node *arg)
|
|
||||||
: op_(op),
|
|
||||||
arg_((expression*)arg) {
|
|
||||||
if(op == DEREF)
|
|
||||||
this->lvalue_ = arg_->lvalue();
|
|
||||||
}
|
|
||||||
|
|
||||||
UNARY_OP_T get_op() const { return op_; }
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const UNARY_OP_T op_;
|
|
||||||
const expression *arg_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class type_name;
|
|
||||||
class cast_operator: public expression{
|
|
||||||
private:
|
|
||||||
ir::value *llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
cast_operator(node *T, node *arg):
|
|
||||||
T_((type_name*)T),
|
|
||||||
arg_((expression*)arg) { }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const type_name *T_;
|
|
||||||
const expression *arg_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class conditional_expression: public expression{
|
|
||||||
private:
|
|
||||||
ir::value *llvm_op(ir::builder &builder,
|
|
||||||
ir::value *cond, ir::value *true_value, ir::value *false_value,
|
|
||||||
const std::string &name) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
conditional_expression(node *cond, node *true_value, node *false_value)
|
|
||||||
: cond_((expression*)cond),
|
|
||||||
true_value_((expression*)true_value),
|
|
||||||
false_value_((expression*)false_value) { }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const expression *cond_;
|
|
||||||
const expression *true_value_;
|
|
||||||
const expression *false_value_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class assignment_expression: public expression{
|
|
||||||
public:
|
|
||||||
assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue)
|
|
||||||
: lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
const expression *lvalue() const { return lvalue_; }
|
|
||||||
const expression *rvalue() const { return rvalue_; }
|
|
||||||
|
|
||||||
public:
|
|
||||||
ASSIGN_OP_T op_;
|
|
||||||
const expression *lvalue_;
|
|
||||||
const expression *rvalue_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class initializer;
|
|
||||||
class declaration_specifier;
|
|
||||||
|
|
||||||
class block_item: public node{
|
|
||||||
};
|
|
||||||
|
|
||||||
class declaration: public block_item{
|
|
||||||
public:
|
|
||||||
declaration(node *spec, node *init)
|
|
||||||
: spec_((declaration_specifier*)spec), init_((list<initializer*>*)init) { }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module * mod) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const declaration_specifier *spec_;
|
|
||||||
const list<initializer*> *init_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class statement: public block_item{
|
|
||||||
};
|
|
||||||
|
|
||||||
class expression_statement: public statement{
|
|
||||||
public:
|
|
||||||
expression_statement(node *expr, node *mask = nullptr)
|
|
||||||
: expr_((expression*)expr), pred_((expression*)mask){ }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module * mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
expression *expr_;
|
|
||||||
expression *pred_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class compound_statement: public statement{
|
|
||||||
typedef list<declaration*>* declarations_t;
|
|
||||||
typedef list<statement*>* statements_t;
|
|
||||||
|
|
||||||
public:
|
|
||||||
compound_statement(node* items)
|
|
||||||
: items_((list<block_item*>*)items){}
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module * mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
list<block_item*>* items_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class selection_statement: public statement{
|
|
||||||
public:
|
|
||||||
selection_statement(node *cond, node *if_value, node *else_value = nullptr)
|
|
||||||
: cond_(cond), then_value_(if_value), else_value_(else_value) { }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const node *cond_;
|
|
||||||
const node *then_value_;
|
|
||||||
const node *else_value_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class iteration_statement: public statement{
|
|
||||||
public:
|
|
||||||
iteration_statement(node *init, node *stop, node *exec, node *statements)
|
|
||||||
: init_(init), stop_(stop), exec_(exec), statements_(statements)
|
|
||||||
{ }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const node *init_;
|
|
||||||
const node *stop_;
|
|
||||||
const node *exec_;
|
|
||||||
const node *statements_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class while_statement: public statement{
|
|
||||||
public:
|
|
||||||
while_statement(node *cond, node *statements)
|
|
||||||
: cond_(cond), statements_(statements)
|
|
||||||
{ }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module *) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const node *cond_;
|
|
||||||
const node *statements_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Jump
|
|
||||||
|
|
||||||
class jump_statement: public statement{
|
|
||||||
public:
|
|
||||||
using statement::statement;
|
|
||||||
};
|
|
||||||
|
|
||||||
class continue_statement: public jump_statement{
|
|
||||||
public:
|
|
||||||
ir::value* codegen(ir::module *mod) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
class no_op: public statement { };
|
|
||||||
|
|
||||||
// Types
|
|
||||||
class declaration_specifier: public node{
|
|
||||||
public:
|
|
||||||
virtual ir::type* type(ir::module *mod) const = 0;
|
|
||||||
virtual std::vector<STORAGE_SPEC_T> storage() const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class typed_declaration_specifier: public declaration_specifier {
|
|
||||||
public:
|
|
||||||
typed_declaration_specifier(TYPE_T ty): ty_(ty){ }
|
|
||||||
ir::type* type(ir::module *mod) const;
|
|
||||||
std::vector<STORAGE_SPEC_T> storage() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const TYPE_T ty_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class storage_declaration_specifier: public declaration_specifier {
|
|
||||||
public:
|
|
||||||
storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec)
|
|
||||||
: storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {}
|
|
||||||
ir::type* type(ir::module *mod) const;
|
|
||||||
std::vector<STORAGE_SPEC_T> storage() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
const STORAGE_SPEC_T storage_spec_;
|
|
||||||
const declaration_specifier* decl_spec_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class declarator;
|
|
||||||
class parameter: public node {
|
|
||||||
public:
|
|
||||||
parameter(node *spec, node *decl)
|
|
||||||
: spec_((declaration_specifier*)spec),
|
|
||||||
decl_((declarator*)decl) { }
|
|
||||||
|
|
||||||
ir::type* type(ir::module *mod) const;
|
|
||||||
std::vector<STORAGE_SPEC_T> storage() const;
|
|
||||||
const identifier* id() const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const declaration_specifier *spec_;
|
|
||||||
const declarator *decl_;
|
|
||||||
};
|
|
||||||
|
|
||||||
/* Declarators */
|
|
||||||
class declarator: public node{
|
|
||||||
protected:
|
|
||||||
typedef std::vector<STORAGE_SPEC_T> storage_spec_vec_t;
|
|
||||||
typedef const storage_spec_vec_t& storage_spec_vec_const_ref_t;
|
|
||||||
|
|
||||||
public:
|
|
||||||
virtual ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const = 0;
|
|
||||||
|
|
||||||
public:
|
|
||||||
declarator(node *lhs)
|
|
||||||
: lhs_((declarator*)lhs), ptr_(nullptr){ }
|
|
||||||
|
|
||||||
ir::type* type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
|
||||||
|
|
||||||
const identifier* id() const {
|
|
||||||
return (const identifier*)lhs_;
|
|
||||||
}
|
|
||||||
|
|
||||||
declarator *set_ptr(node *ptr){
|
|
||||||
ptr_ = (pointer*)ptr;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_addr_space(unsigned addr_space){
|
|
||||||
addr_space_ = addr_space;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
declarator *lhs_;
|
|
||||||
pointer *ptr_;
|
|
||||||
unsigned addr_space_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class identifier: public declarator {
|
|
||||||
ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
identifier(char *&name): declarator(this), name_(name) { }
|
|
||||||
const std::string &name() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::string name_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class pointer: public declarator{
|
|
||||||
private:
|
|
||||||
ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
pointer(node *id): declarator(id) { }
|
|
||||||
};
|
|
||||||
|
|
||||||
class tile: public declarator{
|
|
||||||
private:
|
|
||||||
ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
tile(node *id, node *shapes)
|
|
||||||
: declarator(id), shapes_((list<expression*>*)(shapes)) { }
|
|
||||||
|
|
||||||
public:
|
|
||||||
const list<expression*>* shapes_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class function: public declarator{
|
|
||||||
private:
|
|
||||||
ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
function(node *id, node *args)
|
|
||||||
: declarator(id), args_((list<parameter*>*)args) { }
|
|
||||||
|
|
||||||
void bind_parameters(ir::module *mod, ir::function *fn) const;
|
|
||||||
unsigned get_num_args() const { return args_->values().size(); }
|
|
||||||
parameter* get_arg(unsigned i) const { return args_->values().at(i); }
|
|
||||||
|
|
||||||
public:
|
|
||||||
const list<parameter*>* args_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class initializer : public declarator{
|
|
||||||
private:
|
|
||||||
ir::type* type_impl(ir::module * mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
initializer(node *decl, node *init)
|
|
||||||
: declarator((node*)((declarator*)decl)->id()),
|
|
||||||
decl_((declarator*)decl), expr_((expression*)init){ }
|
|
||||||
|
|
||||||
void set_specifier(const declaration_specifier *spec);
|
|
||||||
ir::value* codegen(ir::module *) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const declaration_specifier *spec_;
|
|
||||||
declarator *decl_;
|
|
||||||
const expression *expr_;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class type_name: public node{
|
|
||||||
public:
|
|
||||||
type_name(node *spec, node * decl)
|
|
||||||
: spec_((declaration_specifier*)spec), decl_((declarator*)decl) { }
|
|
||||||
|
|
||||||
ir::type *type(ir::module *mod) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const declaration_specifier *spec_;
|
|
||||||
const declarator *decl_;
|
|
||||||
};
|
|
||||||
|
|
||||||
/* Function definition */
|
|
||||||
class function_definition: public node{
|
|
||||||
public:
|
|
||||||
function_definition(node *spec, node *header, node *body)
|
|
||||||
: spec_((declaration_specifier*)spec), header_((function *)header), body_((compound_statement*)body) { }
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module * mod) const;
|
|
||||||
|
|
||||||
public:
|
|
||||||
const declaration_specifier *spec_;
|
|
||||||
const function *header_;
|
|
||||||
const compound_statement *body_;
|
|
||||||
};
|
|
||||||
|
|
||||||
/* Translation Unit */
|
|
||||||
class translation_unit: public node{
|
|
||||||
public:
|
|
||||||
translation_unit(node *item)
|
|
||||||
: decls_(item) { }
|
|
||||||
|
|
||||||
translation_unit *add(node *item) {
|
|
||||||
decls_.append(item);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::value* codegen(ir::module * mod) const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
list<node*> decls_;
|
|
||||||
};
|
|
||||||
|
|
||||||
void update_location(const char *t);
|
|
||||||
void print_error(const char *error);
|
|
||||||
char return_impl(char t, const char * yytext);
|
|
||||||
yytokentype return_impl(yytokentype t, const char * yytext);
|
|
||||||
void return_void(const char * yytext);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
222
include/triton/ast/declaration.h
Normal file
222
include/triton/ast/declaration.h
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
#ifndef TRITON_INCLUDE_AST_DECLARATION_H
|
||||||
|
#define TRITON_INCLUDE_AST_DECLARATION_H
|
||||||
|
|
||||||
|
#include "node.h"
|
||||||
|
#include "parser.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
|
||||||
|
namespace ir{
|
||||||
|
class function;
|
||||||
|
class value;
|
||||||
|
class type;
|
||||||
|
class builder;
|
||||||
|
class module;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
class expression;
|
||||||
|
class pointer;
|
||||||
|
class identifier;
|
||||||
|
class constant;
|
||||||
|
class compound_statement;
|
||||||
|
class initializer;
|
||||||
|
class declaration_specifier;
|
||||||
|
|
||||||
|
|
||||||
|
class declaration: public block_item{
|
||||||
|
public:
|
||||||
|
declaration(node *spec, node *init)
|
||||||
|
: spec_((declaration_specifier*)spec), init_((list<initializer*>*)init) { }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module * mod) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const declaration_specifier *spec_;
|
||||||
|
const list<initializer*> *init_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Types
|
||||||
|
class declaration_specifier: public node{
|
||||||
|
public:
|
||||||
|
virtual ir::type* type(ir::module *mod) const = 0;
|
||||||
|
virtual std::vector<STORAGE_SPEC_T> storage() const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class typed_declaration_specifier: public declaration_specifier {
|
||||||
|
public:
|
||||||
|
typed_declaration_specifier(TYPE_T ty): ty_(ty){ }
|
||||||
|
ir::type* type(ir::module *mod) const;
|
||||||
|
std::vector<STORAGE_SPEC_T> storage() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const TYPE_T ty_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class storage_declaration_specifier: public declaration_specifier {
|
||||||
|
public:
|
||||||
|
storage_declaration_specifier(STORAGE_SPEC_T storage_spec, node *decl_spec)
|
||||||
|
: storage_spec_(storage_spec), decl_spec_((declaration_specifier*)decl_spec) {}
|
||||||
|
ir::type* type(ir::module *mod) const;
|
||||||
|
std::vector<STORAGE_SPEC_T> storage() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const STORAGE_SPEC_T storage_spec_;
|
||||||
|
const declaration_specifier* decl_spec_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class declarator;
|
||||||
|
class parameter: public node {
|
||||||
|
public:
|
||||||
|
parameter(node *spec, node *decl)
|
||||||
|
: spec_((declaration_specifier*)spec),
|
||||||
|
decl_((declarator*)decl) { }
|
||||||
|
|
||||||
|
ir::type* type(ir::module *mod) const;
|
||||||
|
std::vector<STORAGE_SPEC_T> storage() const;
|
||||||
|
const identifier* id() const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const declaration_specifier *spec_;
|
||||||
|
const declarator *decl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Declarators */
|
||||||
|
class declarator: public node{
|
||||||
|
protected:
|
||||||
|
typedef std::vector<STORAGE_SPEC_T> storage_spec_vec_t;
|
||||||
|
typedef const storage_spec_vec_t& storage_spec_vec_const_ref_t;
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const = 0;
|
||||||
|
|
||||||
|
public:
|
||||||
|
declarator(node *lhs)
|
||||||
|
: lhs_((declarator*)lhs), ptr_(nullptr){ }
|
||||||
|
|
||||||
|
ir::type* type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
||||||
|
|
||||||
|
const identifier* id() const {
|
||||||
|
return (const identifier*)lhs_;
|
||||||
|
}
|
||||||
|
|
||||||
|
declarator *set_ptr(node *ptr){
|
||||||
|
ptr_ = (pointer*)ptr;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_addr_space(unsigned addr_space){
|
||||||
|
addr_space_ = addr_space;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
declarator *lhs_;
|
||||||
|
pointer *ptr_;
|
||||||
|
unsigned addr_space_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class identifier: public declarator {
|
||||||
|
ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
identifier(char *&name): declarator(this), name_(name) { }
|
||||||
|
const std::string &name() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class pointer: public declarator{
|
||||||
|
private:
|
||||||
|
ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
pointer(node *id): declarator(id) { }
|
||||||
|
};
|
||||||
|
|
||||||
|
class tile: public declarator{
|
||||||
|
private:
|
||||||
|
ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
tile(node *id, node *shapes)
|
||||||
|
: declarator(id), shapes_((list<expression*>*)(shapes)) { }
|
||||||
|
|
||||||
|
public:
|
||||||
|
const list<expression*>* shapes_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class function: public declarator{
|
||||||
|
private:
|
||||||
|
ir::type* type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
function(node *id, node *args)
|
||||||
|
: declarator(id), args_((list<parameter*>*)args) { }
|
||||||
|
|
||||||
|
void bind_parameters(ir::module *mod, ir::function *fn) const;
|
||||||
|
unsigned get_num_args() const { return args_->values().size(); }
|
||||||
|
parameter* get_arg(unsigned i) const { return args_->values().at(i); }
|
||||||
|
|
||||||
|
public:
|
||||||
|
const list<parameter*>* args_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class initializer : public declarator{
|
||||||
|
private:
|
||||||
|
ir::type* type_impl(ir::module * mod, ir::type *type, storage_spec_vec_const_ref_t storage) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
initializer(node *decl, node *init)
|
||||||
|
: declarator((node*)((declarator*)decl)->id()),
|
||||||
|
decl_((declarator*)decl), expr_((expression*)init){ }
|
||||||
|
|
||||||
|
void set_specifier(const declaration_specifier *spec);
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const declaration_specifier *spec_;
|
||||||
|
declarator *decl_;
|
||||||
|
const expression *expr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class type_name: public node{
|
||||||
|
public:
|
||||||
|
type_name(node *spec, node * decl)
|
||||||
|
: spec_((declaration_specifier*)spec), decl_((declarator*)decl) { }
|
||||||
|
|
||||||
|
ir::type *type(ir::module *mod) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const declaration_specifier *spec_;
|
||||||
|
const declarator *decl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Function definition */
|
||||||
|
class function_definition: public node{
|
||||||
|
public:
|
||||||
|
function_definition(node *spec, node *header, node *body)
|
||||||
|
: spec_((declaration_specifier*)spec), header_((function *)header), body_((compound_statement*)body) { }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module * mod) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const declaration_specifier *spec_;
|
||||||
|
const function *header_;
|
||||||
|
const compound_statement *body_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
62
include/triton/ast/error.h
Normal file
62
include/triton/ast/error.h
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
#ifndef TRITON_INCLUDE_AST_ERROR_H
|
||||||
|
#define TRITON_INCLUDE_AST_ERROR_H
|
||||||
|
|
||||||
|
#include "ops.h"
|
||||||
|
#include "parser.hpp"
|
||||||
|
#include "node.h"
|
||||||
|
#include <cassert>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
|
||||||
|
namespace ir{
|
||||||
|
class function;
|
||||||
|
class value;
|
||||||
|
class type;
|
||||||
|
class builder;
|
||||||
|
class module;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
class expression;
|
||||||
|
class pointer;
|
||||||
|
class identifier;
|
||||||
|
class constant;
|
||||||
|
class compound_statement;
|
||||||
|
class initializer;
|
||||||
|
class declaration_specifier;
|
||||||
|
class function;
|
||||||
|
|
||||||
|
/* Translation Unit */
|
||||||
|
class translation_unit: public node{
|
||||||
|
public:
|
||||||
|
translation_unit(node *item)
|
||||||
|
: decls_(item) { }
|
||||||
|
|
||||||
|
translation_unit *add(node *item) {
|
||||||
|
decls_.append(item);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module * mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
list<node*> decls_;
|
||||||
|
};
|
||||||
|
|
||||||
|
void update_location(const char *t);
|
||||||
|
void print_error(const char *error);
|
||||||
|
char return_impl(char t, const char * yytext);
|
||||||
|
yytokentype return_impl(yytokentype t, const char * yytext);
|
||||||
|
void return_void(const char * yytext);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
311
include/triton/ast/expression.h
Normal file
311
include/triton/ast/expression.h
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
#ifndef TDL_INCLUDE_AST_EXPRESSION_H
|
||||||
|
#define TDL_INCLUDE_AST_EXPRESSION_H
|
||||||
|
|
||||||
|
#include "parser.hpp"
|
||||||
|
#include "ast.h"
|
||||||
|
#include <cassert>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
|
||||||
|
namespace ir{
|
||||||
|
class function;
|
||||||
|
class value;
|
||||||
|
class type;
|
||||||
|
class builder;
|
||||||
|
class module;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
|
||||||
|
enum slice_enum_t{
|
||||||
|
ALL,
|
||||||
|
NEWAXIS
|
||||||
|
};
|
||||||
|
|
||||||
|
class slice: public node{
|
||||||
|
public:
|
||||||
|
slice(slice_enum_t type)
|
||||||
|
: type_(type){}
|
||||||
|
|
||||||
|
slice_enum_t type() const{
|
||||||
|
return type_;
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
const slice_enum_t type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class named_expression;
|
||||||
|
|
||||||
|
class expression: public node{
|
||||||
|
public:
|
||||||
|
virtual ir::value* codegen(ir::module *) const = 0;
|
||||||
|
named_expression *lvalue() const { return lvalue_; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
named_expression *lvalue_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class postfix_expression: public expression{
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
class builtin_expression: public node{
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
class typed_declaration_specifier;
|
||||||
|
class alloc_const_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
alloc_const_expression(node *spec, node *size): spec_((typed_declaration_specifier*)spec), size_((constant*)size) { }
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const typed_declaration_specifier* spec_;
|
||||||
|
const constant* size_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class get_global_range_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
get_global_range_expression(node *size, node *axis): size_((constant*)size), axis_((constant*)axis) { }
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const constant* size_;
|
||||||
|
const constant* axis_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class get_range_id_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
get_range_id_expression(node *axis): axis_((constant*)axis) { }
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const constant* axis_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class atomic_cas_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
atomic_cas_expression(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { }
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const node *ptr_;
|
||||||
|
const node *cmp_;
|
||||||
|
const node *val_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class matmul_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
matmul_expression(node* A, node *B, node *C):
|
||||||
|
A_((expression*)A), B_((expression*)B), C_((expression*)C) { }
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const expression *A_;
|
||||||
|
const expression *B_;
|
||||||
|
const expression *C_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class max_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
max_expression(node* x, node* y)
|
||||||
|
: x_((expression*)x), y_((expression*)y){ }
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const expression *x_;
|
||||||
|
const expression *y_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class min_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
min_expression(node* x, node* y)
|
||||||
|
: x_((expression*)x), y_((expression*)y){ }
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const expression *x_;
|
||||||
|
const expression *y_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class select_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
select_expression(node* pred, node* if_value, node* else_value)
|
||||||
|
: pred_((expression*)pred), if_value_((expression*)if_value), else_value_((expression*)else_value) { }
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const expression *pred_;
|
||||||
|
const expression *if_value_;
|
||||||
|
const expression *else_value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class trans_expression: public builtin_expression{
|
||||||
|
public:
|
||||||
|
trans_expression(node *arg): arg_(arg) {}
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
node* arg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class indexing_expression: public postfix_expression{
|
||||||
|
public:
|
||||||
|
indexing_expression(node *id, node *slices)
|
||||||
|
: id_((const identifier*)id), slices_((const list<slice*>*)slices) {}
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const identifier* id_;
|
||||||
|
const list<slice*>* slices_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class named_expression: public expression {
|
||||||
|
public:
|
||||||
|
named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; }
|
||||||
|
const identifier *id() const { return id_; }
|
||||||
|
ir::value* codegen(ir::module * mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const identifier *id_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class binary_expression: public expression{
|
||||||
|
private:
|
||||||
|
ir::value* llvm_op(ir::module *mod, ir::builder &bld, ir::value *lhs, ir::value *rhs, const std::string &name) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
binary_expression(BIN_OP_T op, node *lhs, node *rhs)
|
||||||
|
: op_(op), lhs_((expression*)lhs), rhs_((expression*)rhs) {
|
||||||
|
}
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const BIN_OP_T op_;
|
||||||
|
const expression *lhs_;
|
||||||
|
const expression *rhs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class constant: public expression{
|
||||||
|
public:
|
||||||
|
constant(int value): value_(value) { }
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
int value() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const int value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class constant_range: public expression {
|
||||||
|
public:
|
||||||
|
constant_range(node *first, node *last)
|
||||||
|
: first_((constant*)first), last_((constant*)last) { }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
constant *first_;
|
||||||
|
constant *last_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class string_literal: public expression{
|
||||||
|
public:
|
||||||
|
string_literal(char *&value): value_(value) { }
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
std::string value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class unary_expression: public expression{
|
||||||
|
private:
|
||||||
|
ir::value *llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
unary_expression(UNARY_OP_T op, node *arg)
|
||||||
|
: op_(op),
|
||||||
|
arg_((expression*)arg) {
|
||||||
|
if(op == DEREF)
|
||||||
|
this->lvalue_ = arg_->lvalue();
|
||||||
|
}
|
||||||
|
|
||||||
|
UNARY_OP_T get_op() const { return op_; }
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const UNARY_OP_T op_;
|
||||||
|
const expression *arg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class type_name;
|
||||||
|
class cast_expression: public expression{
|
||||||
|
private:
|
||||||
|
ir::value *llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
cast_expression(node *T, node *arg):
|
||||||
|
T_((type_name*)T),
|
||||||
|
arg_((expression*)arg) { }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const type_name *T_;
|
||||||
|
const expression *arg_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class conditional_expression: public expression{
|
||||||
|
private:
|
||||||
|
ir::value *llvm_op(ir::builder &builder,
|
||||||
|
ir::value *cond, ir::value *true_value, ir::value *false_value,
|
||||||
|
const std::string &name) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
conditional_expression(node *cond, node *true_value, node *false_value)
|
||||||
|
: cond_((expression*)cond),
|
||||||
|
true_value_((expression*)true_value),
|
||||||
|
false_value_((expression*)false_value) { }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const expression *cond_;
|
||||||
|
const expression *true_value_;
|
||||||
|
const expression *false_value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class assignment_expression: public expression{
|
||||||
|
public:
|
||||||
|
assignment_expression(node *lvalue, ASSIGN_OP_T op, node *rvalue)
|
||||||
|
: lvalue_((named_expression*)lvalue), op_(op), rvalue_((expression*)rvalue) { }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
const expression *lvalue() const { return lvalue_; }
|
||||||
|
const expression *rvalue() const { return rvalue_; }
|
||||||
|
|
||||||
|
public:
|
||||||
|
ASSIGN_OP_T op_;
|
||||||
|
const expression *lvalue_;
|
||||||
|
const expression *rvalue_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
37
include/triton/ast/module.h
Normal file
37
include/triton/ast/module.h
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#ifndef TRITON_INCLUDE_AST_MODULE_H
|
||||||
|
#define TRITON_INCLUDE_AST_MODULE_H
|
||||||
|
|
||||||
|
#include "ops.h"
|
||||||
|
#include "parser.hpp"
|
||||||
|
#include "node.h"
|
||||||
|
#include <cassert>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
/* Translation Unit */
|
||||||
|
class translation_unit: public node{
|
||||||
|
public:
|
||||||
|
translation_unit(node *item)
|
||||||
|
: decls_(item) { }
|
||||||
|
|
||||||
|
translation_unit *add(node *item) {
|
||||||
|
decls_.append(item);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module * mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
list<node*> decls_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
77
include/triton/ast/node.h
Normal file
77
include/triton/ast/node.h
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
#ifndef TRITON_INCLUDE_AST_NODE_H
|
||||||
|
#define TRITON_INCLUDE_AST_NODE_H
|
||||||
|
|
||||||
|
#include "ops.h"
|
||||||
|
#include "parser.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
|
||||||
|
namespace ir{
|
||||||
|
class function;
|
||||||
|
class value;
|
||||||
|
class type;
|
||||||
|
class builder;
|
||||||
|
class module;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
class expression;
|
||||||
|
class pointer;
|
||||||
|
class identifier;
|
||||||
|
class constant;
|
||||||
|
class compound_statement;
|
||||||
|
class initializer;
|
||||||
|
class declaration_specifier;
|
||||||
|
class function;
|
||||||
|
|
||||||
|
// Node
|
||||||
|
class node {
|
||||||
|
protected:
|
||||||
|
static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty);
|
||||||
|
static void implicit_broadcast(ir::module *mod, ir::type *dst_ty, ir::value *&src);
|
||||||
|
static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs);
|
||||||
|
static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
||||||
|
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed);
|
||||||
|
public:
|
||||||
|
virtual ir::value* codegen(ir::module *) const { return nullptr; }
|
||||||
|
};
|
||||||
|
|
||||||
|
class block_item: public node{
|
||||||
|
};
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
class list: public node {
|
||||||
|
public:
|
||||||
|
list(const T& x): values_(1, x) {}
|
||||||
|
|
||||||
|
node* append(const T& x){
|
||||||
|
values_.push_back(x);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module * mod) const{
|
||||||
|
for(T x: values_){
|
||||||
|
x->codegen(mod);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<T> &values() const
|
||||||
|
{ return values_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<T> values_;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
60
include/triton/ast/ops.h
Normal file
60
include/triton/ast/ops.h
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
#ifndef TRITON_INCLUDE_AST_OPS_H
|
||||||
|
#define TRITON_INCLUDE_AST_OPS_H
|
||||||
|
|
||||||
|
#include "parser.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
enum ASSIGN_OP_T{
|
||||||
|
ASSIGN,
|
||||||
|
INPLACE_MUL, INPLACE_DIV, INPLACE_MOD,
|
||||||
|
INPLACE_ADD, INPLACE_SUB,
|
||||||
|
INPLACE_LSHIFT, INPLACE_RSHIFT,
|
||||||
|
INPLACE_AND, INPLACE_XOR,
|
||||||
|
INPLACE_OR
|
||||||
|
};
|
||||||
|
|
||||||
|
enum BIN_OP_T{
|
||||||
|
MUL, DIV, MOD,
|
||||||
|
ADD, SUB,
|
||||||
|
LEFT_SHIFT, RIGHT_SHIFT,
|
||||||
|
LT, GT,
|
||||||
|
LE, GE,
|
||||||
|
EQ, NE,
|
||||||
|
AND, XOR, OR,
|
||||||
|
LAND, LOR
|
||||||
|
};
|
||||||
|
|
||||||
|
enum UNARY_OP_T{
|
||||||
|
INC, DEC,
|
||||||
|
PLUS, MINUS,
|
||||||
|
ADDR, DEREF,
|
||||||
|
COMPL, NOT
|
||||||
|
};
|
||||||
|
|
||||||
|
enum TYPE_T{
|
||||||
|
VOID_T,
|
||||||
|
UINT1_T, UINT8_T, UINT16_T, UINT32_T, UINT64_T,
|
||||||
|
INT1_T, INT8_T, INT16_T, INT32_T, INT64_T,
|
||||||
|
FLOAT32_T, FLOAT64_T
|
||||||
|
};
|
||||||
|
|
||||||
|
enum STORAGE_SPEC_T{
|
||||||
|
CONST_T,
|
||||||
|
TUNABLE_T,
|
||||||
|
KERNEL_T,
|
||||||
|
RESTRICT_T,
|
||||||
|
READONLY_T,
|
||||||
|
CONSTANT_SPACE_T,
|
||||||
|
WRITEONLY_T
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
@@ -9,6 +9,9 @@ class node;
|
|||||||
using namespace triton::ast;
|
using namespace triton::ast;
|
||||||
#define YYSTYPE node*
|
#define YYSTYPE node*
|
||||||
#include "../include/triton/ast/ast.h"
|
#include "../include/triton/ast/ast.h"
|
||||||
|
#include "../include/triton/ast/expression.h"
|
||||||
|
#include "../include/triton/ast/statement.h"
|
||||||
|
#include "../include/triton/ast/declaration.h"
|
||||||
|
|
||||||
extern char* yytext;
|
extern char* yytext;
|
||||||
void yyerror(const char *s);
|
void yyerror(const char *s);
|
||||||
@@ -94,15 +97,6 @@ abstract_declarator
|
|||||||
direct_abstract_declarator
|
direct_abstract_declarator
|
||||||
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
|
: '[' primary_expression_list ']' { $$ = new tile(nullptr, $1); }
|
||||||
|
|
||||||
constant:
|
|
||||||
CONSTANT { $$ = new constant(atoi(yytext)); }
|
|
||||||
;
|
|
||||||
|
|
||||||
constant_list:
|
|
||||||
constant { $$ = new list<constant*>((constant*)$1); }
|
|
||||||
| constant_list ',' constant { $$ = append_ptr_list<constant>($1, $3); }
|
|
||||||
;
|
|
||||||
|
|
||||||
type_name
|
type_name
|
||||||
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
: declaration_specifiers { $$ = new type_name($1, nullptr); }
|
||||||
| declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
|
| declaration_specifiers abstract_declarator { $$ = new type_name($1, $2); }
|
||||||
@@ -112,27 +106,39 @@ type_name
|
|||||||
/* Expressions */
|
/* Expressions */
|
||||||
/* -------------------------- */
|
/* -------------------------- */
|
||||||
|
|
||||||
|
/* Constants */
|
||||||
|
constant
|
||||||
|
: CONSTANT { $$ = new constant(atoi(yytext)); }
|
||||||
|
;
|
||||||
|
|
||||||
|
constant_list
|
||||||
|
: constant { $$ = new list<constant*>((constant*)$1); }
|
||||||
|
| constant_list ',' constant { $$ = append_ptr_list<constant>($1, $3); }
|
||||||
|
;
|
||||||
|
|
||||||
identifier
|
identifier
|
||||||
: IDENTIFIER { $$ = new identifier(yytext); }
|
: IDENTIFIER { $$ = new identifier(yytext); }
|
||||||
;
|
;
|
||||||
|
|
||||||
builtin
|
/* Built-in */
|
||||||
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); }
|
builtin_expression
|
||||||
| GET_RANGE_ID '(' constant ')' { $$ = new get_range_id($3); }
|
: GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range_expression($3, $6); }
|
||||||
|
| GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
|
||||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||||
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
| ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const_expression(new typed_declaration_specifier(get_type_spec($2)), $4); }
|
||||||
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
|
| TRANS '(' expression ')' { $$ = new trans_expression($3); }
|
||||||
| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
|
| MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); }
|
||||||
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
|
| MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); }
|
||||||
| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
|
| SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); }
|
||||||
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas($3, $5, $7); }
|
| ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas_expression($3, $5, $7); }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
/* Primary */
|
||||||
primary_expression
|
primary_expression
|
||||||
: identifier { $$ = new named_expression($1); }
|
: identifier { $$ = new named_expression($1); }
|
||||||
| constant { $$ = $1; }
|
| constant { $$ = $1; }
|
||||||
| primary_expression ELLIPSIS primary_expression { $$ = new constant_range($1, $3); }
|
| primary_expression ELLIPSIS primary_expression { $$ = new constant_range($1, $3); }
|
||||||
| builtin { $$ = $1; }
|
| builtin_expression { $$ = $1; }
|
||||||
| STRING_LITERAL { $$ = new string_literal(yytext); }
|
| STRING_LITERAL { $$ = new string_literal(yytext); }
|
||||||
| '(' expression ')' { $$ = $2; }
|
| '(' expression ')' { $$ = $2; }
|
||||||
;
|
;
|
||||||
@@ -142,6 +148,7 @@ primary_expression_list
|
|||||||
| primary_expression_list ',' primary_expression { $$ = append_ptr_list<expression>($1, $3); }
|
| primary_expression_list ',' primary_expression { $$ = append_ptr_list<expression>($1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
/* Postfix */
|
||||||
slice
|
slice
|
||||||
: ':' { $$ = new slice(triton::ast::ALL); }
|
: ':' { $$ = new slice(triton::ast::ALL); }
|
||||||
| NEWAXIS { $$ = new slice(triton::ast::NEWAXIS); }
|
| NEWAXIS { $$ = new slice(triton::ast::NEWAXIS); }
|
||||||
@@ -155,13 +162,7 @@ postfix_expression
|
|||||||
| identifier '[' slice_list ']' { $$ = new indexing_expression($1, $3);}
|
| identifier '[' slice_list ']' { $$ = new indexing_expression($1, $3);}
|
||||||
;
|
;
|
||||||
|
|
||||||
unary_expression
|
/* Unary */
|
||||||
: postfix_expression { $$ = $1; }
|
|
||||||
| INC_OP unary_expression { $$ = new unary_operator(INC, $2); }
|
|
||||||
| DEC_OP unary_expression { $$ = new unary_operator(DEC, $2); }
|
|
||||||
| unary_operator cast_expression { $$ = new unary_operator(get_unary_op($1), $2); }
|
|
||||||
;
|
|
||||||
|
|
||||||
unary_operator
|
unary_operator
|
||||||
: '&' { $$ = new token(ADDR); }
|
: '&' { $$ = new token(ADDR); }
|
||||||
| '*' { $$ = new token(DEREF); }
|
| '*' { $$ = new token(DEREF); }
|
||||||
@@ -171,70 +172,77 @@ unary_operator
|
|||||||
| '!' { $$ = new token(NOT); }
|
| '!' { $$ = new token(NOT); }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
unary_expression
|
||||||
|
: postfix_expression { $$ = $1; }
|
||||||
|
| INC_OP unary_expression { $$ = new unary_expression(INC, $2); }
|
||||||
|
| DEC_OP unary_expression { $$ = new unary_expression(DEC, $2); }
|
||||||
|
| unary_operator cast_expression { $$ = new unary_expression(get_unary_op($1), $2); }
|
||||||
|
;
|
||||||
|
|
||||||
cast_expression
|
cast_expression
|
||||||
: unary_expression { $$ = $1; }
|
: unary_expression { $$ = $1; }
|
||||||
| '(' type_name ')' cast_expression { $$ = new cast_operator($2, $4); }
|
| '(' type_name ')' cast_expression { $$ = new cast_expression($2, $4); }
|
||||||
;
|
;
|
||||||
|
|
||||||
multiplicative_expression
|
multiplicative_expression
|
||||||
: cast_expression { $$ = $1; }
|
: cast_expression { $$ = $1; }
|
||||||
| multiplicative_expression '*' cast_expression { $$ = new binary_operator(MUL, $1, $3); }
|
| multiplicative_expression '*' cast_expression { $$ = new binary_expression(MUL, $1, $3); }
|
||||||
| multiplicative_expression '/' cast_expression { $$ = new binary_operator(DIV, $1, $3); }
|
| multiplicative_expression '/' cast_expression { $$ = new binary_expression(DIV, $1, $3); }
|
||||||
| multiplicative_expression '%' cast_expression { $$ = new binary_operator(MOD, $1, $3); }
|
| multiplicative_expression '%' cast_expression { $$ = new binary_expression(MOD, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
additive_expression
|
additive_expression
|
||||||
: multiplicative_expression { $$ = $1; }
|
: multiplicative_expression { $$ = $1; }
|
||||||
| additive_expression '+' multiplicative_expression { $$ = new binary_operator(ADD, $1, $3); }
|
| additive_expression '+' multiplicative_expression { $$ = new binary_expression(ADD, $1, $3); }
|
||||||
| additive_expression '-' multiplicative_expression { $$ = new binary_operator(SUB, $1, $3); }
|
| additive_expression '-' multiplicative_expression { $$ = new binary_expression(SUB, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
shift_expression
|
shift_expression
|
||||||
: additive_expression { $$ = $1; }
|
: additive_expression { $$ = $1; }
|
||||||
| shift_expression LEFT_OP additive_expression { $$ = new binary_operator(LEFT_SHIFT, $1, $3); }
|
| shift_expression LEFT_OP additive_expression { $$ = new binary_expression(LEFT_SHIFT, $1, $3); }
|
||||||
| shift_expression RIGHT_OP additive_expression { $$ = new binary_operator(RIGHT_SHIFT, $1, $3); }
|
| shift_expression RIGHT_OP additive_expression { $$ = new binary_expression(RIGHT_SHIFT, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
/* Comparison */
|
/* Comparison */
|
||||||
relational_expression
|
relational_expression
|
||||||
: shift_expression { $$ = $1; }
|
: shift_expression { $$ = $1; }
|
||||||
| relational_expression '<' shift_expression { $$ = new binary_operator(LT, $1, $3); }
|
| relational_expression '<' shift_expression { $$ = new binary_expression(LT, $1, $3); }
|
||||||
| relational_expression '>' shift_expression { $$ = new binary_operator(GT, $1, $3); }
|
| relational_expression '>' shift_expression { $$ = new binary_expression(GT, $1, $3); }
|
||||||
| relational_expression LE_OP shift_expression { $$ = new binary_operator(LE, $1, $3); }
|
| relational_expression LE_OP shift_expression { $$ = new binary_expression(LE, $1, $3); }
|
||||||
| relational_expression GE_OP shift_expression { $$ = new binary_operator(GE, $1, $3); }
|
| relational_expression GE_OP shift_expression { $$ = new binary_expression(GE, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
equality_expression
|
equality_expression
|
||||||
: relational_expression { $$ = $1; }
|
: relational_expression { $$ = $1; }
|
||||||
| equality_expression EQ_OP relational_expression { $$ = new binary_operator(EQ, $1, $3); }
|
| equality_expression EQ_OP relational_expression { $$ = new binary_expression(EQ, $1, $3); }
|
||||||
| equality_expression NE_OP relational_expression { $$ = new binary_operator(NE, $1, $3); }
|
| equality_expression NE_OP relational_expression { $$ = new binary_expression(NE, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
/* Binary */
|
/* Binary */
|
||||||
and_expression
|
and_expression
|
||||||
: equality_expression { $$ = $1; }
|
: equality_expression { $$ = $1; }
|
||||||
| and_expression '&' equality_expression { $$ = new binary_operator(AND, $1, $3); }
|
| and_expression '&' equality_expression { $$ = new binary_expression(AND, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
exclusive_or_expression
|
exclusive_or_expression
|
||||||
: and_expression { $$ = $1; }
|
: and_expression { $$ = $1; }
|
||||||
| exclusive_or_expression '^' and_expression { $$ = new binary_operator(XOR, $1, $3); }
|
| exclusive_or_expression '^' and_expression { $$ = new binary_expression(XOR, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
inclusive_or_expression
|
inclusive_or_expression
|
||||||
: exclusive_or_expression { $$ = $1; }
|
: exclusive_or_expression { $$ = $1; }
|
||||||
| inclusive_or_expression '|' exclusive_or_expression { $$ = new binary_operator(OR, $1, $3); }
|
| inclusive_or_expression '|' exclusive_or_expression { $$ = new binary_expression(OR, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
/* Logical */
|
/* Logical */
|
||||||
logical_and_expression
|
logical_and_expression
|
||||||
: inclusive_or_expression { $$ = $1; }
|
: inclusive_or_expression { $$ = $1; }
|
||||||
| logical_and_expression AND_OP inclusive_or_expression { $$ = new binary_operator(LAND, $1, $3); }
|
| logical_and_expression AND_OP inclusive_or_expression { $$ = new binary_expression(LAND, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
logical_or_expression
|
logical_or_expression
|
||||||
: logical_and_expression { $$ = $1; }
|
: logical_and_expression { $$ = $1; }
|
||||||
| logical_or_expression OR_OP logical_and_expression { $$ = new binary_operator(LOR, $1, $3); }
|
| logical_or_expression OR_OP logical_and_expression { $$ = new binary_expression(LOR, $1, $3); }
|
||||||
;
|
;
|
||||||
|
|
||||||
/* Conditional */
|
/* Conditional */
|
||||||
@@ -364,7 +372,6 @@ declarator
|
|||||||
| direct_declarator { $$ = $1; }
|
| direct_declarator { $$ = $1; }
|
||||||
;
|
;
|
||||||
|
|
||||||
|
|
||||||
init_declarator
|
init_declarator
|
||||||
: declarator { $$ = new initializer($1, nullptr); }
|
: declarator { $$ = new initializer($1, nullptr); }
|
||||||
| declarator '=' initialization_expression { $$ = new initializer($1, $3); }
|
| declarator '=' initialization_expression { $$ = new initializer($1, $3); }
|
||||||
|
121
include/triton/ast/statement.h
Normal file
121
include/triton/ast/statement.h
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
#ifndef TRITON_INCLUDE_AST_STATEMENT_H
|
||||||
|
#define TRITON_INCLUDE_AST_STATEMENT_H
|
||||||
|
|
||||||
|
#include "parser.hpp"
|
||||||
|
#include "triton/ast/ast.h"
|
||||||
|
#include <cassert>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
|
||||||
|
namespace ir{
|
||||||
|
class function;
|
||||||
|
class value;
|
||||||
|
class type;
|
||||||
|
class builder;
|
||||||
|
class module;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
class declaration;
|
||||||
|
|
||||||
|
class statement: public block_item{
|
||||||
|
};
|
||||||
|
|
||||||
|
// Expression
|
||||||
|
class expression_statement: public statement{
|
||||||
|
public:
|
||||||
|
expression_statement(node *expr, node *mask = nullptr)
|
||||||
|
: expr_((expression*)expr), pred_((expression*)mask){ }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module * mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
expression *expr_;
|
||||||
|
expression *pred_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Compound
|
||||||
|
class compound_statement: public statement{
|
||||||
|
typedef list<declaration*>* declarations_t;
|
||||||
|
typedef list<statement*>* statements_t;
|
||||||
|
|
||||||
|
public:
|
||||||
|
compound_statement(node* items)
|
||||||
|
: items_((list<block_item*>*)items){}
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module * mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
list<block_item*>* items_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Selection
|
||||||
|
class selection_statement: public statement{
|
||||||
|
public:
|
||||||
|
selection_statement(node *cond, node *if_value, node *else_value = nullptr)
|
||||||
|
: cond_(cond), then_value_(if_value), else_value_(else_value) { }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
const node *cond_;
|
||||||
|
const node *then_value_;
|
||||||
|
const node *else_value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Iteration
|
||||||
|
class iteration_statement: public statement{
|
||||||
|
public:
|
||||||
|
iteration_statement(node *init, node *stop, node *exec, node *statements)
|
||||||
|
: init_(init), stop_(stop), exec_(exec), statements_(statements)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const node *init_;
|
||||||
|
const node *stop_;
|
||||||
|
const node *exec_;
|
||||||
|
const node *statements_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// While
|
||||||
|
class while_statement: public statement{
|
||||||
|
public:
|
||||||
|
while_statement(node *cond, node *statements)
|
||||||
|
: cond_(cond), statements_(statements)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
ir::value* codegen(ir::module *) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const node *cond_;
|
||||||
|
const node *statements_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Jump
|
||||||
|
class jump_statement: public statement{
|
||||||
|
public:
|
||||||
|
using statement::statement;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Continue
|
||||||
|
class continue_statement: public jump_statement{
|
||||||
|
public:
|
||||||
|
ir::value* codegen(ir::module *mod) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
// No op
|
||||||
|
class no_op: public statement { };
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
199
lib/ast/declaration.cpp
Normal file
199
lib/ast/declaration.cpp
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
#include "triton/ast/statement.h"
|
||||||
|
#include "triton/ast/declaration.h"
|
||||||
|
#include "triton/ir/function.h"
|
||||||
|
#include "triton/ir/module.h"
|
||||||
|
#include "triton/ir/basic_block.h"
|
||||||
|
#include "triton/ir/builder.h"
|
||||||
|
#include "triton/ir/type.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
/* Declaration specifier */
|
||||||
|
ir::type* typed_declaration_specifier::type(ir::module *mod) const {
|
||||||
|
ir::context &ctx = mod->get_context();
|
||||||
|
switch (ty_) {
|
||||||
|
case VOID_T: return ir::type::get_void_ty(ctx);
|
||||||
|
case INT1_T: return ir::type::get_int1_ty(ctx);
|
||||||
|
case INT8_T: return ir::type::get_int8_ty(ctx);
|
||||||
|
case INT16_T: return ir::type::get_int16_ty(ctx);
|
||||||
|
case INT32_T: return ir::type::get_int32_ty(ctx);
|
||||||
|
case INT64_T: return ir::type::get_int64_ty(ctx);
|
||||||
|
case FLOAT32_T: return ir::type::get_float_ty(ctx);
|
||||||
|
case FLOAT64_T: return ir::type::get_double_ty(ctx);
|
||||||
|
default: throw std::runtime_error("unreachable");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<STORAGE_SPEC_T> typed_declaration_specifier::storage() const {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ir::type* storage_declaration_specifier::type(ir::module *mod) const {
|
||||||
|
return decl_spec_->type(mod);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<STORAGE_SPEC_T> storage_declaration_specifier::storage() const {
|
||||||
|
auto result = decl_spec_->storage();
|
||||||
|
result.push_back(storage_spec_);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Parameter */
|
||||||
|
ir::type* parameter::type(ir::module *mod) const {
|
||||||
|
return decl_->type(mod, spec_->type(mod), {});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<STORAGE_SPEC_T> parameter::storage() const {
|
||||||
|
return spec_->storage();
|
||||||
|
}
|
||||||
|
|
||||||
|
const identifier *parameter::id() const {
|
||||||
|
return decl_->id();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Declarators */
|
||||||
|
ir::type* declarator::type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{
|
||||||
|
if(ptr_)
|
||||||
|
return type_impl(mod, ptr_->type(mod, type, storage), storage);
|
||||||
|
return type_impl(mod, type, storage);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identifier
|
||||||
|
ir::type* identifier::type_impl(ir::module *, ir::type *type, storage_spec_vec_const_ref_t) const{
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string &identifier::name() const{
|
||||||
|
return name_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tile
|
||||||
|
ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t) const{
|
||||||
|
ir::type::tile_shapes_t shapes;
|
||||||
|
for(expression *expr: shapes_->values()){
|
||||||
|
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
||||||
|
assert(shape);
|
||||||
|
shapes.push_back(shape);
|
||||||
|
}
|
||||||
|
return ir::tile_type::get(type, shapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Pointer
|
||||||
|
ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{
|
||||||
|
bool is_ptr_to_const = std::find(storage.begin(), storage.end(), CONSTANT_SPACE_T) != storage.end();
|
||||||
|
return ir::pointer_type::get(type, is_ptr_to_const?4:1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function
|
||||||
|
void function::bind_parameters(ir::module *mod, ir::function *fn) const{
|
||||||
|
std::vector<ir::argument*> args = fn->args();
|
||||||
|
assert(args.size() == args_->values().size());
|
||||||
|
for(size_t i = 0; i < args.size(); i++){
|
||||||
|
parameter *param_i = args_->values().at(i);
|
||||||
|
const identifier *id_i = param_i->id();
|
||||||
|
if(id_i){
|
||||||
|
args[i]->set_name(id_i->name());
|
||||||
|
mod->set_value(id_i->name(), nullptr, args[i]);
|
||||||
|
mod->get_scope().types[id_i->name()] = args[i]->get_type();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::type* function::type_impl(ir::module* mod, ir::type *type, storage_spec_vec_const_ref_t) const{
|
||||||
|
std::vector<ir::type*> types;
|
||||||
|
for(parameter* param: args_->values())
|
||||||
|
types.push_back(param->type(mod));
|
||||||
|
return ir::function_type::get(type, types);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Declaration */
|
||||||
|
ir::value* declaration::codegen(ir::module* mod) const{
|
||||||
|
for(initializer *init: init_->values())
|
||||||
|
init->set_specifier(spec_);
|
||||||
|
init_->codegen(mod);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Initializer */
|
||||||
|
ir::type* initializer::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{
|
||||||
|
return decl_->type(mod, type, storage);
|
||||||
|
}
|
||||||
|
|
||||||
|
void initializer::set_specifier(const declaration_specifier *spec) {
|
||||||
|
spec_ = spec;
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value* initializer::codegen(ir::module * mod) const{
|
||||||
|
std::vector<STORAGE_SPEC_T> storage = spec_->storage();
|
||||||
|
ir::type *ty = decl_->type(mod, spec_->type(mod), storage);
|
||||||
|
std::string name = decl_->id()->name();
|
||||||
|
ir::value *value = ir::undef_value::get(ty);
|
||||||
|
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
|
||||||
|
auto csts = dynamic_cast<list<constant*>*>((node*)expr_);
|
||||||
|
if(csts == nullptr)
|
||||||
|
throw std::runtime_error("must specify constant list for metaparameters");
|
||||||
|
std::vector<unsigned> values;
|
||||||
|
for(constant* cst: csts->values())
|
||||||
|
values.push_back(cst->value());
|
||||||
|
value = ir::metaparameter::create(mod->get_context(), ty, values);
|
||||||
|
mod->register_global(name, value);
|
||||||
|
}
|
||||||
|
else if(expr_){
|
||||||
|
value = expr_->codegen(mod);
|
||||||
|
value = explicit_cast(mod->get_builder(), value, ty);
|
||||||
|
implicit_broadcast(mod, ty, value);
|
||||||
|
}
|
||||||
|
value->set_name(name);
|
||||||
|
mod->set_value(name, value);
|
||||||
|
mod->get_scope().types[name] = ty;
|
||||||
|
if(auto *x = dynamic_cast<ir::alloc_const*>(value))
|
||||||
|
mod->add_alloc(x);
|
||||||
|
if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end())
|
||||||
|
mod->set_const(name);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Type name */
|
||||||
|
ir::type *type_name::type(ir::module *mod) const{
|
||||||
|
return decl_->type(mod, spec_->type(mod), {});
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Function definition */
|
||||||
|
ir::attribute_t get_ir_attr(STORAGE_SPEC_T spec){
|
||||||
|
switch(spec){
|
||||||
|
case RESTRICT_T: return ir::noalias;
|
||||||
|
case READONLY_T: return ir::readonly;
|
||||||
|
case WRITEONLY_T: return ir::writeonly;
|
||||||
|
default: throw std::runtime_error("cannot convert storage specifier to IR function attribute");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value* function_definition::codegen(ir::module *mod) const{
|
||||||
|
ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->storage());
|
||||||
|
const std::string &name = header_->id()->name();
|
||||||
|
ir::function *fn = mod->get_or_insert_function(name, prototype);
|
||||||
|
for(unsigned i = 0; i < header_->get_num_args(); i++){
|
||||||
|
parameter *param = header_->get_arg(i);
|
||||||
|
std::vector<STORAGE_SPEC_T> storage = param->storage();
|
||||||
|
for(STORAGE_SPEC_T spec: storage)
|
||||||
|
fn->add_attr(1 + i, get_ir_attr(spec));
|
||||||
|
}
|
||||||
|
header_->bind_parameters(mod, fn);
|
||||||
|
ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn);
|
||||||
|
mod->seal_block(entry);
|
||||||
|
mod->get_builder().set_insert_point(entry);
|
||||||
|
body_->codegen(mod);
|
||||||
|
mod->get_builder().create_ret_void();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
49
lib/ast/error.cpp
Normal file
49
lib/ast/error.cpp
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
#include "triton/ast/error.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
static int current_line = 0;
|
||||||
|
static int current_column = 0;
|
||||||
|
|
||||||
|
// begin token
|
||||||
|
void update_location(const char *text) {
|
||||||
|
for (int i = 0; text[i] != '\0'; i++){
|
||||||
|
if (text[i] == '\n'){
|
||||||
|
current_column = 0;
|
||||||
|
current_line++;
|
||||||
|
}
|
||||||
|
else if (text[i] == '\t')
|
||||||
|
current_column += 8 - (current_column % 8);
|
||||||
|
else
|
||||||
|
current_column++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_error(const char *cerror) {
|
||||||
|
std::string error(cerror);
|
||||||
|
auto it = error.find("syntax error,");
|
||||||
|
error.replace(it, 13, "");
|
||||||
|
std::cerr << "error at line " << current_line << " (column " << current_column << "): " << error << std::endl;
|
||||||
|
throw std::runtime_error("compilation failed");
|
||||||
|
}
|
||||||
|
|
||||||
|
char return_impl(char t, const char * yytext) {
|
||||||
|
update_location(yytext);
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
yytokentype return_impl(yytokentype t, const char * yytext){
|
||||||
|
update_location(yytext);
|
||||||
|
return t;
|
||||||
|
}
|
||||||
|
|
||||||
|
void return_void(const char * yytext){
|
||||||
|
update_location(yytext);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
329
lib/ast/expression.cpp
Normal file
329
lib/ast/expression.cpp
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
#include "triton/ast/expression.h"
|
||||||
|
#include "triton/ast/declaration.h"
|
||||||
|
#include "triton/ir/constant.h"
|
||||||
|
#include "triton/ir/module.h"
|
||||||
|
#include "triton/ir/builder.h"
|
||||||
|
#include "triton/ir/type.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
|
||||||
|
/* Binary operator */
|
||||||
|
ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
|
||||||
|
{
|
||||||
|
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
|
||||||
|
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
|
||||||
|
implicit_broadcast(mod, lhs, rhs);
|
||||||
|
if(op_==MUL && is_float)
|
||||||
|
return builder.create_fmul(lhs, rhs, name);
|
||||||
|
if(op_==MUL && is_int)
|
||||||
|
return builder.create_mul(lhs, rhs, name);
|
||||||
|
if(op_==DIV && is_float)
|
||||||
|
return builder.create_fdiv(lhs, rhs, name);
|
||||||
|
if(op_==DIV && is_int && is_signed)
|
||||||
|
return builder.create_sdiv(lhs, rhs, name);
|
||||||
|
if(op_==DIV && is_int && !is_signed)
|
||||||
|
return builder.create_udiv(lhs, rhs, name);
|
||||||
|
if(op_==MOD && is_float)
|
||||||
|
return builder.create_frem(lhs, rhs, name);
|
||||||
|
if(op_==MOD && is_int && is_signed)
|
||||||
|
return builder.create_srem(lhs, rhs, name);
|
||||||
|
if(op_==MOD && is_int && !is_signed)
|
||||||
|
return builder.create_urem(lhs, rhs, name);
|
||||||
|
if(op_==ADD && is_float)
|
||||||
|
return builder.create_fadd(lhs, rhs, name);
|
||||||
|
if(op_==ADD && is_int)
|
||||||
|
return builder.create_add(lhs, rhs);
|
||||||
|
if(op_==ADD && is_ptr)
|
||||||
|
return builder.create_gep(lhs, {rhs});
|
||||||
|
if(op_==SUB && is_float)
|
||||||
|
return builder.create_fsub(lhs, rhs, name);
|
||||||
|
if(op_==SUB && is_int)
|
||||||
|
return builder.create_sub(lhs, rhs, name);
|
||||||
|
if(op_==SUB && is_ptr)
|
||||||
|
return builder.create_gep(lhs, {builder.create_neg(rhs)});
|
||||||
|
if(op_==LEFT_SHIFT)
|
||||||
|
return builder.create_shl(lhs, rhs, name);
|
||||||
|
if(op_==RIGHT_SHIFT)
|
||||||
|
return builder.create_ashr(lhs, rhs, name);
|
||||||
|
if(op_ == LT && is_float)
|
||||||
|
return builder.create_fcmpOLT(lhs, rhs, name);
|
||||||
|
if(op_ == LT && is_int && is_signed)
|
||||||
|
return builder.create_icmpSLT(lhs, rhs, name);
|
||||||
|
if(op_ == LT && is_int && !is_signed)
|
||||||
|
return builder.create_icmpULT(lhs, rhs, name);
|
||||||
|
if(op_ == GT && is_float)
|
||||||
|
return builder.create_fcmpOGT(lhs, rhs, name);
|
||||||
|
if(op_ == GT && is_int && is_signed)
|
||||||
|
return builder.create_icmpSGT(lhs, rhs, name);
|
||||||
|
if(op_ == GT && is_int && !is_signed)
|
||||||
|
return builder.create_icmpUGT(lhs, rhs, name);
|
||||||
|
if(op_ == LE && is_float)
|
||||||
|
return builder.create_fcmpOLE(lhs, rhs, name);
|
||||||
|
if(op_ == LE && is_int && is_signed)
|
||||||
|
return builder.create_icmpSLE(lhs, rhs, name);
|
||||||
|
if(op_ == LE && is_int && !is_signed)
|
||||||
|
return builder.create_icmpULE(lhs, rhs, name);
|
||||||
|
if(op_ == GE && is_float)
|
||||||
|
return builder.create_fcmpOGE(lhs, rhs, name);
|
||||||
|
if(op_ == GE && is_int && is_signed)
|
||||||
|
return builder.create_icmpSGE(lhs, rhs, name);
|
||||||
|
if(op_ == GE && is_int && !is_signed)
|
||||||
|
return builder.create_icmpUGE(lhs, rhs, name);
|
||||||
|
if(op_ == EQ && is_float)
|
||||||
|
return builder.create_fcmpOEQ(lhs, rhs, name);
|
||||||
|
if(op_ == EQ && is_int)
|
||||||
|
return builder.create_icmpEQ(lhs, rhs, name);
|
||||||
|
if(op_ == NE && is_float)
|
||||||
|
return builder.create_fcmpONE(lhs, rhs, name);
|
||||||
|
if(op_ == NE && is_int)
|
||||||
|
return builder.create_icmpNE(lhs, rhs, name);
|
||||||
|
if(op_ == AND)
|
||||||
|
return builder.create_and(lhs, rhs, name);
|
||||||
|
if(op_ == XOR)
|
||||||
|
return builder.create_xor(lhs, rhs, name);
|
||||||
|
if(op_ == OR)
|
||||||
|
return builder.create_or(lhs, rhs, name);
|
||||||
|
if(op_ == LAND)
|
||||||
|
return builder.create_and(lhs, rhs, name);
|
||||||
|
if(op_ == LOR)
|
||||||
|
return builder.create_or(lhs, rhs, name);
|
||||||
|
throw std::runtime_error("unreachable");
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value* binary_expression::codegen(ir::module *mod) const{
|
||||||
|
ir::value *lhs = lhs_->codegen(mod);
|
||||||
|
ir::value *rhs = rhs_->codegen(mod);
|
||||||
|
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Builtin expression */
|
||||||
|
|
||||||
|
// alloc constant
|
||||||
|
ir::value* alloc_const_expression::codegen(ir::module *mod) const {
|
||||||
|
ir::type *ty = spec_->type(mod);
|
||||||
|
ir::constant_int *size = (ir::constant_int*)size_->codegen(mod);
|
||||||
|
ir::alloc_const *res = new ir::alloc_const(ty, size);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// get_global_range
|
||||||
|
ir::value* get_global_range_expression::codegen(ir::module *mod) const {
|
||||||
|
ir::builder &builder = mod->get_builder();
|
||||||
|
return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
|
||||||
|
}
|
||||||
|
|
||||||
|
// get_range_id
|
||||||
|
ir::value* get_range_id_expression::codegen(ir::module *mod) const {
|
||||||
|
return mod->get_builder().create_get_range_id(axis_->value());
|
||||||
|
}
|
||||||
|
|
||||||
|
// atomic cas
|
||||||
|
ir::value* atomic_cas_expression::codegen(ir::module *mod) const {
|
||||||
|
ir::value *ptr = ptr_->codegen(mod);
|
||||||
|
ir::value *cmp = cmp_->codegen(mod);
|
||||||
|
ir::value *val = val_->codegen(mod);
|
||||||
|
return mod->get_builder().create_atomic_cas(ptr, cmp, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// matmul
|
||||||
|
ir::value* matmul_expression::codegen(ir::module *mod) const {
|
||||||
|
ir::value *A = A_->codegen(mod);
|
||||||
|
ir::value *B = B_->codegen(mod);
|
||||||
|
ir::value *C = C_->codegen(mod);
|
||||||
|
// unsigned M = A->get_type()->get_tile_shapes()[0];
|
||||||
|
// unsigned N = B->get_type()->get_tile_shapes()[1];
|
||||||
|
// ir::type *scalar_ty = A->get_type()->get_scalar_ty();
|
||||||
|
// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
|
||||||
|
// ir::value *tmp = ir::undef_value::get(tile_ty);
|
||||||
|
// implicit_broadcast(mod, tmp, C);
|
||||||
|
return mod->get_builder().create_dot(A, B, C);
|
||||||
|
}
|
||||||
|
|
||||||
|
// min
|
||||||
|
ir::value* min_expression::codegen(ir::module *mod) const {
|
||||||
|
ir::value* cmp = binary_expression(LT, (node*)x_, (node*)y_).codegen(mod);
|
||||||
|
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
|
||||||
|
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
|
||||||
|
return mod->get_builder().create_select(cmp, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
// max
|
||||||
|
ir::value* max_expression::codegen(ir::module *mod) const {
|
||||||
|
ir::value* cmp = binary_expression(GT, (node*)x_, (node*)y_).codegen(mod);
|
||||||
|
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
|
||||||
|
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
|
||||||
|
return mod->get_builder().create_select(cmp, x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
// select
|
||||||
|
ir::value* select_expression::codegen(ir::module *mod) const {
|
||||||
|
ir::value* pred = pred_->codegen(mod);
|
||||||
|
ir::value* if_value = if_value_->codegen(mod);
|
||||||
|
ir::value* else_value = else_value_->codegen(mod);
|
||||||
|
return mod->get_builder().create_select(pred, if_value, else_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trans
|
||||||
|
ir::value* trans_expression::codegen(ir::module *mod) const {
|
||||||
|
return mod->get_builder().create_trans(arg_->codegen(mod));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Postfix expression */
|
||||||
|
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
||||||
|
ir::value *in = mod->get_value(id_->name());
|
||||||
|
const std::vector<slice*> &slices = slices_->values();
|
||||||
|
auto in_shapes = in->get_type()->get_tile_shapes();
|
||||||
|
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||||
|
ir::type::tile_shapes_t out_shapes(slices.size());
|
||||||
|
// create shapes
|
||||||
|
size_t current = 0;
|
||||||
|
for(size_t i = 0; i < out_shapes.size(); i++)
|
||||||
|
out_shapes[i] = (slices[i]->type()==NEWAXIS)?one:in_shapes[current++];
|
||||||
|
return mod->get_builder().create_reshape(in, out_shapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Unary operator */
|
||||||
|
ir::value *unary_expression::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{
|
||||||
|
ir::type *atype = arg->get_type();
|
||||||
|
bool is_float = atype->is_floating_point_ty();
|
||||||
|
bool is_int = atype->is_integer_ty();
|
||||||
|
if(op_ == INC)
|
||||||
|
return builder.create_add(arg, builder.get_int32(1), name);
|
||||||
|
if(op_ == DEC)
|
||||||
|
return builder.create_sub(arg, builder.get_int32(1), name);
|
||||||
|
if(op_ == PLUS)
|
||||||
|
return arg;
|
||||||
|
if(op_ == MINUS && is_float)
|
||||||
|
return builder.create_fneg(arg, name);
|
||||||
|
if(op_ == MINUS && is_int)
|
||||||
|
return builder.create_neg(arg, name);
|
||||||
|
if(op_ == ADDR)
|
||||||
|
throw std::runtime_error("not supported");
|
||||||
|
if(op_ == DEREF)
|
||||||
|
return builder.create_load(arg, name);
|
||||||
|
if(op_ == COMPL)
|
||||||
|
throw std::runtime_error("not supported");
|
||||||
|
if(op_ == NOT)
|
||||||
|
return builder.create_not(arg, name);
|
||||||
|
throw std::runtime_error("unreachable");
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value* unary_expression::codegen(ir::module *mod) const{
|
||||||
|
ir::value *arg = arg_->codegen(mod);
|
||||||
|
ir::value *result = llvm_op(mod->get_builder(), arg, "");
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Cast operator */
|
||||||
|
ir::value *cast_expression::llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const{
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
ir::value* cast_expression::codegen(ir::module *mod) const{
|
||||||
|
ir::value *arg = arg_->codegen(mod);
|
||||||
|
ir::type *T = T_->type(mod);
|
||||||
|
return llvm_op(mod->get_builder(), T, arg, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Conditional expression */
|
||||||
|
ir::value *conditional_expression::codegen(ir::module *mod) const{
|
||||||
|
ir::builder &builder = mod->get_builder();
|
||||||
|
ir::value *pred = cond_->codegen(mod);
|
||||||
|
ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
|
||||||
|
ir::value *true_mask = mask->get_result(0);
|
||||||
|
ir::value *false_mask = mask->get_result(1);
|
||||||
|
ir::value *true_value = true_value_->codegen(mod);
|
||||||
|
ir::value *false_value = false_value_->codegen(mod);
|
||||||
|
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
|
||||||
|
itn->set_mask_pred(true_mask);
|
||||||
|
if(auto *itn = dynamic_cast<ir::instruction*>(false_value))
|
||||||
|
itn->set_mask_pred(false_mask);
|
||||||
|
bool is_float, is_ptr, is_int, is_signed;
|
||||||
|
ir::value *uncasted_true_value = true_value;
|
||||||
|
ir::value *uncasted_false_value = false_value;
|
||||||
|
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
||||||
|
implicit_broadcast(mod, true_value, false_value);
|
||||||
|
{
|
||||||
|
ir::value *current = true_value;
|
||||||
|
while(current != uncasted_true_value) {
|
||||||
|
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
|
||||||
|
itn->set_mask_pred(true_mask);
|
||||||
|
current = itn->get_operand(0);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ir::value *current = false_value;
|
||||||
|
while(current != uncasted_false_value) {
|
||||||
|
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
|
||||||
|
itn->set_mask_pred(false_mask);
|
||||||
|
current = itn->get_operand(0);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Assignment expression */
|
||||||
|
ir::value *assignment_expression::codegen(ir::module *mod) const{
|
||||||
|
ir::value *rvalue = rvalue_->codegen(mod);
|
||||||
|
if(auto *x = dynamic_cast<const named_expression*>(lvalue_)){
|
||||||
|
ir::type *ty = mod->get_scope().types.at(x->id()->name());
|
||||||
|
rvalue = explicit_cast(mod->get_builder(), rvalue, ty);
|
||||||
|
implicit_broadcast(mod, ty, rvalue);
|
||||||
|
mod->set_value(x->id()->name(), rvalue);
|
||||||
|
}
|
||||||
|
else if(auto* x = dynamic_cast<const unary_expression*>(lvalue_)){
|
||||||
|
assert(x->get_op()==DEREF);
|
||||||
|
assert(x->lvalue());
|
||||||
|
ir::value *ptr = x->lvalue()->codegen(mod);
|
||||||
|
rvalue = mod->get_builder().create_store(ptr, rvalue);
|
||||||
|
}
|
||||||
|
return rvalue;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* String literal */
|
||||||
|
ir::value* string_literal::codegen(ir::module *) const{
|
||||||
|
throw std::runtime_error("not supported");
|
||||||
|
// return ir::constant_data_array::get_string(mod->get_context(), value_);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Constant */
|
||||||
|
ir::value* constant::codegen(ir::module *mod) const{
|
||||||
|
return mod->get_builder().get_int32(value_);
|
||||||
|
}
|
||||||
|
|
||||||
|
int constant::value() const{
|
||||||
|
return value_;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Constant range */
|
||||||
|
ir::value* constant_range::codegen(ir::module *mod) const{
|
||||||
|
return ir::constant_range::get((ir::constant_int*)first_->codegen(mod),
|
||||||
|
(ir::constant_int*)last_->codegen(mod));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Named */
|
||||||
|
ir::value* named_expression::codegen(ir::module *mod) const{
|
||||||
|
const std::string &name = id()->name();
|
||||||
|
const auto& declarations = mod->get_scope().types;
|
||||||
|
if(declarations.find(name) == declarations.end())
|
||||||
|
throw std::runtime_error("variable " + name + " not declared");
|
||||||
|
return mod->get_value(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@@ -1,855 +0,0 @@
|
|||||||
#include <functional>
|
|
||||||
#include <algorithm>
|
|
||||||
#include "triton/ast/ast.h"
|
|
||||||
#include "triton/ir/constant.h"
|
|
||||||
#include "triton/ir/function.h"
|
|
||||||
#include "triton/ir/module.h"
|
|
||||||
#include "triton/ir/basic_block.h"
|
|
||||||
#include "triton/ir/builder.h"
|
|
||||||
#include "triton/ir/type.h"
|
|
||||||
#include <iostream>
|
|
||||||
#include <stdarg.h>
|
|
||||||
|
|
||||||
|
|
||||||
namespace triton{
|
|
||||||
|
|
||||||
namespace ast{
|
|
||||||
|
|
||||||
static int current_line = 0;
|
|
||||||
static int current_column = 0;
|
|
||||||
|
|
||||||
/* node */
|
|
||||||
ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
|
|
||||||
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
|
|
||||||
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
|
|
||||||
bool src_signed = false;
|
|
||||||
bool dst_signed = false;
|
|
||||||
if(src_scalar_ty == dst_scalar_ty)
|
|
||||||
return src;
|
|
||||||
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
|
|
||||||
return builder.create_si_to_fp(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
|
|
||||||
return builder.create_ui_to_fp(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
|
|
||||||
return builder.create_fp_to_si(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
|
|
||||||
return builder.create_fp_to_ui(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
|
||||||
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
|
|
||||||
return builder.create_fp_ext(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
|
||||||
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
|
|
||||||
return builder.create_fp_trunc(src, dst_ty);
|
|
||||||
|
|
||||||
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
|
|
||||||
src_scalar_ty->get_integer_bitwidth())
|
|
||||||
return builder.create_int_cast(src, dst_ty, dst_signed);
|
|
||||||
|
|
||||||
else
|
|
||||||
throw std::runtime_error("unreachable");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
|
||||||
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
|
|
||||||
// Input types
|
|
||||||
ir::type *left_ty = lhs->get_type()->get_scalar_ty();
|
|
||||||
ir::type *right_ty = rhs->get_type()->get_scalar_ty();
|
|
||||||
// One operand is pointer
|
|
||||||
if(left_ty->is_pointer_ty() || right_ty->is_pointer_ty()){
|
|
||||||
if(left_ty->is_pointer_ty() && right_ty->is_pointer_ty())
|
|
||||||
throw std::runtime_error("invalid operands");
|
|
||||||
if(right_ty->is_pointer_ty())
|
|
||||||
std::swap(lhs, rhs);
|
|
||||||
is_ptr = true;
|
|
||||||
}
|
|
||||||
// One operand is double
|
|
||||||
else if(left_ty->is_double_ty() || right_ty->is_double_ty()){
|
|
||||||
ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs;
|
|
||||||
to_convert = explicit_cast(builder, to_convert, builder.get_double_ty());
|
|
||||||
is_float = true;
|
|
||||||
}
|
|
||||||
// One operand is float
|
|
||||||
else if(left_ty->is_float_ty() || right_ty->is_float_ty()){
|
|
||||||
ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs;
|
|
||||||
to_convert = explicit_cast(builder, to_convert, builder.get_float_ty());
|
|
||||||
is_float = true;
|
|
||||||
}
|
|
||||||
// Both operands are integers
|
|
||||||
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
|
|
||||||
is_int = true;
|
|
||||||
is_signed = true; // always signed for now
|
|
||||||
if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){
|
|
||||||
ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs;
|
|
||||||
ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
|
|
||||||
to_convert = explicit_cast(builder, to_convert, dst_ty);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Not reachable
|
|
||||||
else
|
|
||||||
throw std::runtime_error("unreachable");
|
|
||||||
}
|
|
||||||
|
|
||||||
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) {
|
|
||||||
ir::type *lhs_ty = lhs->get_type();
|
|
||||||
ir::type *rhs_ty = rhs->get_type();
|
|
||||||
ir::type *res_ty = nullptr;
|
|
||||||
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
|
||||||
return;
|
|
||||||
else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
|
||||||
res_ty = lhs_ty;
|
|
||||||
else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty())
|
|
||||||
res_ty = rhs_ty;
|
|
||||||
else{
|
|
||||||
auto lhs_shapes = lhs_ty->get_tile_shapes();
|
|
||||||
auto rhs_shapes = rhs_ty->get_tile_shapes();
|
|
||||||
size_t lhs_size = lhs_shapes.size();
|
|
||||||
size_t rhs_size = rhs_shapes.size();
|
|
||||||
size_t res_size = std::max(lhs_size, rhs_size);
|
|
||||||
ir::type::tile_shapes_t res_shapes(res_size);
|
|
||||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
|
||||||
for(int i = 0; i < res_size; i++){
|
|
||||||
if(i >= res_size - lhs_size && i >= res_size - rhs_size)
|
|
||||||
res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i];
|
|
||||||
else if(i >= res_size - lhs_size)
|
|
||||||
res_shapes[i] = lhs_shapes[i];
|
|
||||||
else if(i >= res_size - rhs_size)
|
|
||||||
res_shapes[i] = rhs_shapes[i];
|
|
||||||
}
|
|
||||||
res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes);
|
|
||||||
}
|
|
||||||
implicit_broadcast(mod, res_ty, rhs);
|
|
||||||
implicit_broadcast(mod, res_ty, lhs);
|
|
||||||
}
|
|
||||||
|
|
||||||
void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){
|
|
||||||
ir::builder &builder = mod->get_builder();
|
|
||||||
ir::type *src_ty = src->get_type();
|
|
||||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
|
||||||
// Both are scalar
|
|
||||||
if(!ty->is_tile_ty() && !src_ty->is_tile_ty())
|
|
||||||
return;
|
|
||||||
// Broadcast scalar
|
|
||||||
if(ty->is_tile_ty() && !src_ty->is_tile_ty()){
|
|
||||||
src = builder.create_splat(src, ty->get_tile_shapes());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// Downcast tile
|
|
||||||
if(!ty->is_tile_ty() && src_ty->is_tile_ty()){
|
|
||||||
for(ir::constant *shape: src_ty->get_tile_shapes())
|
|
||||||
if(shape != one)
|
|
||||||
throw std::runtime_error("cannot downcast");
|
|
||||||
src = builder.create_downcast(src);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// Both are arrays
|
|
||||||
auto dst_shapes = ty->get_tile_shapes();
|
|
||||||
auto src_shapes = src_ty->get_tile_shapes();
|
|
||||||
int dst_dim = dst_shapes.size();
|
|
||||||
int src_dim = src_shapes.size();
|
|
||||||
// Pad
|
|
||||||
int off = dst_dim - src_dim;
|
|
||||||
for(size_t i = 0; i < off; i++)
|
|
||||||
src_shapes.insert(src_shapes.begin(), one);
|
|
||||||
if(off > 0)
|
|
||||||
src = builder.create_reshape(src, src_shapes);
|
|
||||||
// Broadcast
|
|
||||||
for(int i = dst_dim - 1; i>= 0; i--)
|
|
||||||
if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one)
|
|
||||||
throw std::runtime_error("cannot broadcast");
|
|
||||||
if(dst_shapes != src_shapes)
|
|
||||||
src = builder.create_broadcast(src, dst_shapes);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Helper */
|
|
||||||
inline bool is_terminator(ir::value* x) {
|
|
||||||
return x && dynamic_cast<ir::terminator_inst*>(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Translation unit */
|
|
||||||
ir::value* translation_unit::codegen(ir::module *mod) const{
|
|
||||||
mod->add_new_scope();
|
|
||||||
decls_.codegen(mod);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Declaration specifier */
|
|
||||||
ir::type* typed_declaration_specifier::type(ir::module *mod) const {
|
|
||||||
ir::context &ctx = mod->get_context();
|
|
||||||
switch (ty_) {
|
|
||||||
case VOID_T: return ir::type::get_void_ty(ctx);
|
|
||||||
case INT1_T: return ir::type::get_int1_ty(ctx);
|
|
||||||
case INT8_T: return ir::type::get_int8_ty(ctx);
|
|
||||||
case INT16_T: return ir::type::get_int16_ty(ctx);
|
|
||||||
case INT32_T: return ir::type::get_int32_ty(ctx);
|
|
||||||
case INT64_T: return ir::type::get_int64_ty(ctx);
|
|
||||||
case FLOAT32_T: return ir::type::get_float_ty(ctx);
|
|
||||||
case FLOAT64_T: return ir::type::get_double_ty(ctx);
|
|
||||||
default: throw std::runtime_error("unreachable");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<STORAGE_SPEC_T> typed_declaration_specifier::storage() const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
ir::type* storage_declaration_specifier::type(ir::module *mod) const {
|
|
||||||
return decl_spec_->type(mod);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<STORAGE_SPEC_T> storage_declaration_specifier::storage() const {
|
|
||||||
auto result = decl_spec_->storage();
|
|
||||||
result.push_back(storage_spec_);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/* Parameter */
|
|
||||||
ir::type* parameter::type(ir::module *mod) const {
|
|
||||||
return decl_->type(mod, spec_->type(mod), {});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<STORAGE_SPEC_T> parameter::storage() const {
|
|
||||||
return spec_->storage();
|
|
||||||
}
|
|
||||||
|
|
||||||
const identifier *parameter::id() const {
|
|
||||||
return decl_->id();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Declarators */
|
|
||||||
ir::type* declarator::type(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{
|
|
||||||
if(ptr_)
|
|
||||||
return type_impl(mod, ptr_->type(mod, type, storage), storage);
|
|
||||||
return type_impl(mod, type, storage);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Identifier
|
|
||||||
ir::type* identifier::type_impl(ir::module *, ir::type *type, storage_spec_vec_const_ref_t) const{
|
|
||||||
return type;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::string &identifier::name() const{
|
|
||||||
return name_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tile
|
|
||||||
ir::type* tile::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t) const{
|
|
||||||
ir::type::tile_shapes_t shapes;
|
|
||||||
for(expression *expr: shapes_->values()){
|
|
||||||
ir::constant_int *shape = dynamic_cast<ir::constant_int*>(expr->codegen(mod));
|
|
||||||
assert(shape);
|
|
||||||
shapes.push_back(shape);
|
|
||||||
}
|
|
||||||
return ir::tile_type::get(type, shapes);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// Pointer
|
|
||||||
ir::type* pointer::type_impl(ir::module*, ir::type *type, storage_spec_vec_const_ref_t storage) const{
|
|
||||||
bool is_ptr_to_const = std::find(storage.begin(), storage.end(), CONSTANT_SPACE_T) != storage.end();
|
|
||||||
return ir::pointer_type::get(type, is_ptr_to_const?4:1);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Function
|
|
||||||
void function::bind_parameters(ir::module *mod, ir::function *fn) const{
|
|
||||||
std::vector<ir::argument*> args = fn->args();
|
|
||||||
assert(args.size() == args_->values().size());
|
|
||||||
for(size_t i = 0; i < args.size(); i++){
|
|
||||||
parameter *param_i = args_->values().at(i);
|
|
||||||
const identifier *id_i = param_i->id();
|
|
||||||
if(id_i){
|
|
||||||
args[i]->set_name(id_i->name());
|
|
||||||
mod->set_value(id_i->name(), nullptr, args[i]);
|
|
||||||
mod->get_scope().types[id_i->name()] = args[i]->get_type();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::type* function::type_impl(ir::module* mod, ir::type *type, storage_spec_vec_const_ref_t) const{
|
|
||||||
std::vector<ir::type*> types;
|
|
||||||
for(parameter* param: args_->values())
|
|
||||||
types.push_back(param->type(mod));
|
|
||||||
return ir::function_type::get(type, types);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Function definition */
|
|
||||||
ir::attribute_t get_ir_attr(STORAGE_SPEC_T spec){
|
|
||||||
switch(spec){
|
|
||||||
case RESTRICT_T: return ir::noalias;
|
|
||||||
case READONLY_T: return ir::readonly;
|
|
||||||
case WRITEONLY_T: return ir::writeonly;
|
|
||||||
default: throw std::runtime_error("cannot convert storage specifier to IR function attribute");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::value* function_definition::codegen(ir::module *mod) const{
|
|
||||||
ir::function_type *prototype = (ir::function_type*)header_->type(mod, spec_->type(mod), spec_->storage());
|
|
||||||
const std::string &name = header_->id()->name();
|
|
||||||
ir::function *fn = mod->get_or_insert_function(name, prototype);
|
|
||||||
for(unsigned i = 0; i < header_->get_num_args(); i++){
|
|
||||||
parameter *param = header_->get_arg(i);
|
|
||||||
std::vector<STORAGE_SPEC_T> storage = param->storage();
|
|
||||||
for(STORAGE_SPEC_T spec: storage)
|
|
||||||
fn->add_attr(1 + i, get_ir_attr(spec));
|
|
||||||
}
|
|
||||||
header_->bind_parameters(mod, fn);
|
|
||||||
ir::basic_block *entry = ir::basic_block::create(mod->get_context(), "entry", fn);
|
|
||||||
mod->seal_block(entry);
|
|
||||||
mod->get_builder().set_insert_point(entry);
|
|
||||||
body_->codegen(mod);
|
|
||||||
mod->get_builder().create_ret_void();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Statements */
|
|
||||||
ir::value* compound_statement::codegen(ir::module* mod) const{
|
|
||||||
mod->add_new_scope();
|
|
||||||
if(items_)
|
|
||||||
items_->codegen(mod);
|
|
||||||
mod->pop_scope();
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* expression statement */
|
|
||||||
ir::value* expression_statement::codegen(ir::module *mod) const{
|
|
||||||
ir::builder &builder = mod->get_builder();
|
|
||||||
ir::basic_block *block = builder.get_insert_block();
|
|
||||||
if(pred_) {
|
|
||||||
// check that it is an assignment
|
|
||||||
assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_);
|
|
||||||
assert(assignment);
|
|
||||||
// generate mask
|
|
||||||
ir::value *pred = pred_->codegen(mod);
|
|
||||||
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
|
|
||||||
// generate expression
|
|
||||||
unsigned szbegin = block->get_inst_list().size();
|
|
||||||
ir::value *expr = expr_->codegen(mod);
|
|
||||||
ir::basic_block::iterator begin = block->begin();
|
|
||||||
std::advance(begin, szbegin);
|
|
||||||
// set mask
|
|
||||||
ir::type *ty = expr->get_type();
|
|
||||||
for(auto it = begin; it != builder.get_insert_point(); it++)
|
|
||||||
(*it)->set_mask_pred(mask->get_result(0));
|
|
||||||
// if(auto *itn = dynamic_cast<ir::instruction*>(expr))
|
|
||||||
// itn->set_mask_pred(mask->get_result(0));
|
|
||||||
if(ty->is_void_ty())
|
|
||||||
return expr;
|
|
||||||
// merge with psi
|
|
||||||
ir::psi_inst *psi = (ir::psi_inst*)builder.create_merge(mask->get_result(0), expr,
|
|
||||||
mask->get_result(1), ir::undef_value::get(ty));
|
|
||||||
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
|
|
||||||
mod->set_value(name, psi);
|
|
||||||
return psi;
|
|
||||||
}
|
|
||||||
return expr_->codegen(mod);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* For statement */
|
|
||||||
ir::value* iteration_statement::codegen(ir::module *mod) const{
|
|
||||||
ir::builder &builder = mod->get_builder();
|
|
||||||
ir::context &ctx = mod->get_context();
|
|
||||||
ir::basic_block *current_bb = builder.get_insert_block();
|
|
||||||
ir::function *fn = current_bb->get_parent();
|
|
||||||
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
|
|
||||||
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
|
|
||||||
mod->set_continue_fn([&](){
|
|
||||||
if(exec_)
|
|
||||||
exec_->codegen(mod);
|
|
||||||
ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx));
|
|
||||||
return builder.create_cond_br(cond, loop_bb, next_bb);
|
|
||||||
});
|
|
||||||
init_->codegen(mod);
|
|
||||||
ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx));
|
|
||||||
builder.create_cond_br(cond, loop_bb, next_bb);
|
|
||||||
// builder.create_br(loop_bb);
|
|
||||||
builder.set_insert_point(loop_bb);
|
|
||||||
if(!is_terminator(statements_->codegen(mod)))
|
|
||||||
mod->get_continue_fn()();
|
|
||||||
ir::basic_block *stop_bb = builder.get_insert_block();
|
|
||||||
mod->seal_block(stop_bb);
|
|
||||||
mod->seal_block(loop_bb);
|
|
||||||
mod->seal_block(builder.get_insert_block());
|
|
||||||
mod->seal_block(next_bb);
|
|
||||||
builder.set_insert_point(next_bb);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* While statement */
|
|
||||||
ir::value* while_statement::codegen(ir::module* mod) const{
|
|
||||||
ir::builder &builder = mod->get_builder();
|
|
||||||
ir::context &ctx = mod->get_context();
|
|
||||||
ir::basic_block *current_bb = builder.get_insert_block();
|
|
||||||
ir::function *fn = current_bb->get_parent();
|
|
||||||
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
|
|
||||||
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
|
|
||||||
mod->set_continue_fn([&](){
|
|
||||||
ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx));
|
|
||||||
return builder.create_cond_br(cond, loop_bb, next_bb);
|
|
||||||
});
|
|
||||||
ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx));
|
|
||||||
builder.create_cond_br(cond, loop_bb, next_bb);
|
|
||||||
builder.set_insert_point(loop_bb);
|
|
||||||
if(!is_terminator(statements_->codegen(mod)))
|
|
||||||
mod->get_continue_fn()();
|
|
||||||
ir::basic_block *stop_bb = builder.get_insert_block();
|
|
||||||
mod->seal_block(stop_bb);
|
|
||||||
mod->seal_block(loop_bb);
|
|
||||||
mod->seal_block(builder.get_insert_block());
|
|
||||||
mod->seal_block(next_bb);
|
|
||||||
builder.set_insert_point(next_bb);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Selection statement */
|
|
||||||
ir::value* selection_statement::codegen(ir::module* mod) const{
|
|
||||||
ir::builder &builder = mod->get_builder();
|
|
||||||
ir::context &ctx = mod->get_context();
|
|
||||||
ir::function *fn = builder.get_insert_block()->get_parent();
|
|
||||||
ir::value *cond = cond_->codegen(mod);
|
|
||||||
ir::basic_block *then_bb = ir::basic_block::create(ctx, "then", fn);
|
|
||||||
ir::basic_block *else_bb = else_value_?ir::basic_block::create(ctx, "else", fn):nullptr;
|
|
||||||
ir::basic_block *endif_bb = ir::basic_block::create(ctx, "endif", fn);
|
|
||||||
mod->seal_block(then_bb);
|
|
||||||
if(else_value_)
|
|
||||||
mod->seal_block(else_bb);
|
|
||||||
|
|
||||||
// Branch
|
|
||||||
if(else_value_)
|
|
||||||
builder.create_cond_br(cond, then_bb, else_bb);
|
|
||||||
else
|
|
||||||
builder.create_cond_br(cond, then_bb, endif_bb);
|
|
||||||
// Then
|
|
||||||
builder.set_insert_point(then_bb);
|
|
||||||
if(!is_terminator(then_value_->codegen(mod)))
|
|
||||||
builder.create_br(endif_bb);
|
|
||||||
// Else
|
|
||||||
if(else_value_){
|
|
||||||
builder.set_insert_point(else_bb);
|
|
||||||
if(!is_terminator(else_value_->codegen(mod)))
|
|
||||||
builder.create_br(endif_bb);
|
|
||||||
}
|
|
||||||
// Endif
|
|
||||||
mod->seal_block(endif_bb);
|
|
||||||
builder.set_insert_point(endif_bb);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Continue statement */
|
|
||||||
ir::value* continue_statement::codegen(ir::module *mod) const{
|
|
||||||
return mod->get_continue_fn()();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Declaration */
|
|
||||||
ir::value* declaration::codegen(ir::module* mod) const{
|
|
||||||
for(initializer *init: init_->values())
|
|
||||||
init->set_specifier(spec_);
|
|
||||||
init_->codegen(mod);
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Initializer */
|
|
||||||
ir::type* initializer::type_impl(ir::module *mod, ir::type *type, storage_spec_vec_const_ref_t storage) const{
|
|
||||||
return decl_->type(mod, type, storage);
|
|
||||||
}
|
|
||||||
|
|
||||||
void initializer::set_specifier(const declaration_specifier *spec) {
|
|
||||||
spec_ = spec;
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::value* initializer::codegen(ir::module * mod) const{
|
|
||||||
std::vector<STORAGE_SPEC_T> storage = spec_->storage();
|
|
||||||
ir::type *ty = decl_->type(mod, spec_->type(mod), storage);
|
|
||||||
std::string name = decl_->id()->name();
|
|
||||||
ir::value *value = ir::undef_value::get(ty);
|
|
||||||
if(std::find(storage.begin(), storage.end(), TUNABLE_T) != storage.end()){
|
|
||||||
auto csts = dynamic_cast<list<constant*>*>((node*)expr_);
|
|
||||||
if(csts == nullptr)
|
|
||||||
throw std::runtime_error("must specify constant list for metaparameters");
|
|
||||||
std::vector<unsigned> values;
|
|
||||||
for(constant* cst: csts->values())
|
|
||||||
values.push_back(cst->value());
|
|
||||||
value = ir::metaparameter::create(mod->get_context(), ty, values);
|
|
||||||
mod->register_global(name, value);
|
|
||||||
}
|
|
||||||
else if(expr_){
|
|
||||||
value = expr_->codegen(mod);
|
|
||||||
value = explicit_cast(mod->get_builder(), value, ty);
|
|
||||||
implicit_broadcast(mod, ty, value);
|
|
||||||
}
|
|
||||||
value->set_name(name);
|
|
||||||
mod->set_value(name, value);
|
|
||||||
mod->get_scope().types[name] = ty;
|
|
||||||
if(auto *x = dynamic_cast<ir::alloc_const*>(value))
|
|
||||||
mod->add_alloc(x);
|
|
||||||
if(std::find(storage.begin(), storage.end(), CONST_T) != storage.end())
|
|
||||||
mod->set_const(name);
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*------------------*/
|
|
||||||
/* Expression */
|
|
||||||
/*------------------*/
|
|
||||||
/* Binary operator */
|
|
||||||
ir::value *binary_operator::llvm_op(ir::module *mod, ir::builder &builder, ir::value *lhs, ir::value *rhs, const std::string &name) const
|
|
||||||
{
|
|
||||||
bool is_float = false, is_ptr = false, is_int = false, is_signed = false;
|
|
||||||
implicit_cast(builder, lhs, rhs, is_float, is_ptr, is_int, is_signed);
|
|
||||||
implicit_broadcast(mod, lhs, rhs);
|
|
||||||
if(op_==MUL && is_float)
|
|
||||||
return builder.create_fmul(lhs, rhs, name);
|
|
||||||
if(op_==MUL && is_int)
|
|
||||||
return builder.create_mul(lhs, rhs, name);
|
|
||||||
if(op_==DIV && is_float)
|
|
||||||
return builder.create_fdiv(lhs, rhs, name);
|
|
||||||
if(op_==DIV && is_int && is_signed)
|
|
||||||
return builder.create_sdiv(lhs, rhs, name);
|
|
||||||
if(op_==DIV && is_int && !is_signed)
|
|
||||||
return builder.create_udiv(lhs, rhs, name);
|
|
||||||
if(op_==MOD && is_float)
|
|
||||||
return builder.create_frem(lhs, rhs, name);
|
|
||||||
if(op_==MOD && is_int && is_signed)
|
|
||||||
return builder.create_srem(lhs, rhs, name);
|
|
||||||
if(op_==MOD && is_int && !is_signed)
|
|
||||||
return builder.create_urem(lhs, rhs, name);
|
|
||||||
if(op_==ADD && is_float)
|
|
||||||
return builder.create_fadd(lhs, rhs, name);
|
|
||||||
if(op_==ADD && is_int)
|
|
||||||
return builder.create_add(lhs, rhs);
|
|
||||||
if(op_==ADD && is_ptr)
|
|
||||||
return builder.create_gep(lhs, {rhs});
|
|
||||||
if(op_==SUB && is_float)
|
|
||||||
return builder.create_fsub(lhs, rhs, name);
|
|
||||||
if(op_==SUB && is_int)
|
|
||||||
return builder.create_sub(lhs, rhs, name);
|
|
||||||
if(op_==SUB && is_ptr)
|
|
||||||
return builder.create_gep(lhs, {builder.create_neg(rhs)});
|
|
||||||
if(op_==LEFT_SHIFT)
|
|
||||||
return builder.create_shl(lhs, rhs, name);
|
|
||||||
if(op_==RIGHT_SHIFT)
|
|
||||||
return builder.create_ashr(lhs, rhs, name);
|
|
||||||
if(op_ == LT && is_float)
|
|
||||||
return builder.create_fcmpOLT(lhs, rhs, name);
|
|
||||||
if(op_ == LT && is_int && is_signed)
|
|
||||||
return builder.create_icmpSLT(lhs, rhs, name);
|
|
||||||
if(op_ == LT && is_int && !is_signed)
|
|
||||||
return builder.create_icmpULT(lhs, rhs, name);
|
|
||||||
if(op_ == GT && is_float)
|
|
||||||
return builder.create_fcmpOGT(lhs, rhs, name);
|
|
||||||
if(op_ == GT && is_int && is_signed)
|
|
||||||
return builder.create_icmpSGT(lhs, rhs, name);
|
|
||||||
if(op_ == GT && is_int && !is_signed)
|
|
||||||
return builder.create_icmpUGT(lhs, rhs, name);
|
|
||||||
if(op_ == LE && is_float)
|
|
||||||
return builder.create_fcmpOLE(lhs, rhs, name);
|
|
||||||
if(op_ == LE && is_int && is_signed)
|
|
||||||
return builder.create_icmpSLE(lhs, rhs, name);
|
|
||||||
if(op_ == LE && is_int && !is_signed)
|
|
||||||
return builder.create_icmpULE(lhs, rhs, name);
|
|
||||||
if(op_ == GE && is_float)
|
|
||||||
return builder.create_fcmpOGE(lhs, rhs, name);
|
|
||||||
if(op_ == GE && is_int && is_signed)
|
|
||||||
return builder.create_icmpSGE(lhs, rhs, name);
|
|
||||||
if(op_ == GE && is_int && !is_signed)
|
|
||||||
return builder.create_icmpUGE(lhs, rhs, name);
|
|
||||||
if(op_ == EQ && is_float)
|
|
||||||
return builder.create_fcmpOEQ(lhs, rhs, name);
|
|
||||||
if(op_ == EQ && is_int)
|
|
||||||
return builder.create_icmpEQ(lhs, rhs, name);
|
|
||||||
if(op_ == NE && is_float)
|
|
||||||
return builder.create_fcmpONE(lhs, rhs, name);
|
|
||||||
if(op_ == NE && is_int)
|
|
||||||
return builder.create_icmpNE(lhs, rhs, name);
|
|
||||||
if(op_ == AND)
|
|
||||||
return builder.create_and(lhs, rhs, name);
|
|
||||||
if(op_ == XOR)
|
|
||||||
return builder.create_xor(lhs, rhs, name);
|
|
||||||
if(op_ == OR)
|
|
||||||
return builder.create_or(lhs, rhs, name);
|
|
||||||
if(op_ == LAND)
|
|
||||||
return builder.create_and(lhs, rhs, name);
|
|
||||||
if(op_ == LOR)
|
|
||||||
return builder.create_or(lhs, rhs, name);
|
|
||||||
throw std::runtime_error("unreachable");
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::value* binary_operator::codegen(ir::module *mod) const{
|
|
||||||
ir::value *lhs = lhs_->codegen(mod);
|
|
||||||
ir::value *rhs = rhs_->codegen(mod);
|
|
||||||
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Builtin expression */
|
|
||||||
|
|
||||||
// alloc constant
|
|
||||||
ir::value* alloc_const::codegen(ir::module *mod) const {
|
|
||||||
ir::type *ty = spec_->type(mod);
|
|
||||||
ir::constant_int *size = (ir::constant_int*)size_->codegen(mod);
|
|
||||||
ir::alloc_const *res = new ir::alloc_const(ty, size);
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
// get_global_range
|
|
||||||
ir::value* get_global_range::codegen(ir::module *mod) const {
|
|
||||||
ir::builder &builder = mod->get_builder();
|
|
||||||
return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
|
|
||||||
}
|
|
||||||
|
|
||||||
// get_range_id
|
|
||||||
ir::value* get_range_id::codegen(ir::module *mod) const {
|
|
||||||
return mod->get_builder().create_get_range_id(axis_->value());
|
|
||||||
}
|
|
||||||
|
|
||||||
// atomic cas
|
|
||||||
ir::value* atomic_cas::codegen(ir::module *mod) const {
|
|
||||||
ir::value *ptr = ptr_->codegen(mod);
|
|
||||||
ir::value *cmp = cmp_->codegen(mod);
|
|
||||||
ir::value *val = val_->codegen(mod);
|
|
||||||
return mod->get_builder().create_atomic_cas(ptr, cmp, val);
|
|
||||||
}
|
|
||||||
|
|
||||||
// matmul
|
|
||||||
ir::value* matmul_expression::codegen(ir::module *mod) const {
|
|
||||||
ir::value *A = A_->codegen(mod);
|
|
||||||
ir::value *B = B_->codegen(mod);
|
|
||||||
ir::value *C = C_->codegen(mod);
|
|
||||||
// unsigned M = A->get_type()->get_tile_shapes()[0];
|
|
||||||
// unsigned N = B->get_type()->get_tile_shapes()[1];
|
|
||||||
// ir::type *scalar_ty = A->get_type()->get_scalar_ty();
|
|
||||||
// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
|
|
||||||
// ir::value *tmp = ir::undef_value::get(tile_ty);
|
|
||||||
// implicit_broadcast(mod, tmp, C);
|
|
||||||
return mod->get_builder().create_dot(A, B, C);
|
|
||||||
}
|
|
||||||
|
|
||||||
// min
|
|
||||||
ir::value* min_expression::codegen(ir::module *mod) const {
|
|
||||||
ir::value* cmp = binary_operator(LT, (node*)x_, (node*)y_).codegen(mod);
|
|
||||||
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
|
|
||||||
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
|
|
||||||
return mod->get_builder().create_select(cmp, x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
// max
|
|
||||||
ir::value* max_expression::codegen(ir::module *mod) const {
|
|
||||||
ir::value* cmp = binary_operator(GT, (node*)x_, (node*)y_).codegen(mod);
|
|
||||||
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
|
|
||||||
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
|
|
||||||
return mod->get_builder().create_select(cmp, x, y);
|
|
||||||
}
|
|
||||||
|
|
||||||
// select
|
|
||||||
ir::value* select_expression::codegen(ir::module *mod) const {
|
|
||||||
ir::value* pred = pred_->codegen(mod);
|
|
||||||
ir::value* if_value = if_value_->codegen(mod);
|
|
||||||
ir::value* else_value = else_value_->codegen(mod);
|
|
||||||
return mod->get_builder().create_select(pred, if_value, else_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Trans
|
|
||||||
ir::value* trans_expression::codegen(ir::module *mod) const {
|
|
||||||
return mod->get_builder().create_trans(arg_->codegen(mod));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Postfix expression */
|
|
||||||
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
|
||||||
ir::value *in = mod->get_value(id_->name());
|
|
||||||
const std::vector<slice*> &slices = slices_->values();
|
|
||||||
auto in_shapes = in->get_type()->get_tile_shapes();
|
|
||||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
|
||||||
ir::type::tile_shapes_t out_shapes(slices.size());
|
|
||||||
// create shapes
|
|
||||||
size_t current = 0;
|
|
||||||
for(size_t i = 0; i < out_shapes.size(); i++)
|
|
||||||
out_shapes[i] = (slices[i]->type()==NEWAXIS)?one:in_shapes[current++];
|
|
||||||
return mod->get_builder().create_reshape(in, out_shapes);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/* Unary operator */
|
|
||||||
ir::value *unary_operator::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{
|
|
||||||
ir::type *atype = arg->get_type();
|
|
||||||
bool is_float = atype->is_floating_point_ty();
|
|
||||||
bool is_int = atype->is_integer_ty();
|
|
||||||
if(op_ == INC)
|
|
||||||
return builder.create_add(arg, builder.get_int32(1), name);
|
|
||||||
if(op_ == DEC)
|
|
||||||
return builder.create_sub(arg, builder.get_int32(1), name);
|
|
||||||
if(op_ == PLUS)
|
|
||||||
return arg;
|
|
||||||
if(op_ == MINUS && is_float)
|
|
||||||
return builder.create_fneg(arg, name);
|
|
||||||
if(op_ == MINUS && is_int)
|
|
||||||
return builder.create_neg(arg, name);
|
|
||||||
if(op_ == ADDR)
|
|
||||||
throw std::runtime_error("not supported");
|
|
||||||
if(op_ == DEREF)
|
|
||||||
return builder.create_load(arg, name);
|
|
||||||
if(op_ == COMPL)
|
|
||||||
throw std::runtime_error("not supported");
|
|
||||||
if(op_ == NOT)
|
|
||||||
return builder.create_not(arg, name);
|
|
||||||
throw std::runtime_error("unreachable");
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::value* unary_operator::codegen(ir::module *mod) const{
|
|
||||||
ir::value *arg = arg_->codegen(mod);
|
|
||||||
ir::value *result = llvm_op(mod->get_builder(), arg, "");
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Cast operator */
|
|
||||||
ir::value *cast_operator::llvm_op(ir::builder &builder, ir::type *T, ir::value *arg, const std::string &name) const{
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
ir::value* cast_operator::codegen(ir::module *mod) const{
|
|
||||||
ir::value *arg = arg_->codegen(mod);
|
|
||||||
ir::type *T = T_->type(mod);
|
|
||||||
return llvm_op(mod->get_builder(), T, arg, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Conditional expression */
|
|
||||||
ir::value *conditional_expression::codegen(ir::module *mod) const{
|
|
||||||
ir::builder &builder = mod->get_builder();
|
|
||||||
ir::value *pred = cond_->codegen(mod);
|
|
||||||
ir::instruction *mask = (ir::instruction*)builder.create_mask(pred);
|
|
||||||
ir::value *true_mask = mask->get_result(0);
|
|
||||||
ir::value *false_mask = mask->get_result(1);
|
|
||||||
ir::value *true_value = true_value_->codegen(mod);
|
|
||||||
ir::value *false_value = false_value_->codegen(mod);
|
|
||||||
if(auto *itn = dynamic_cast<ir::instruction*>(true_value))
|
|
||||||
itn->set_mask_pred(true_mask);
|
|
||||||
if(auto *itn = dynamic_cast<ir::instruction*>(false_value))
|
|
||||||
itn->set_mask_pred(false_mask);
|
|
||||||
bool is_float, is_ptr, is_int, is_signed;
|
|
||||||
ir::value *uncasted_true_value = true_value;
|
|
||||||
ir::value *uncasted_false_value = false_value;
|
|
||||||
implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed);
|
|
||||||
implicit_broadcast(mod, true_value, false_value);
|
|
||||||
{
|
|
||||||
ir::value *current = true_value;
|
|
||||||
while(current != uncasted_true_value) {
|
|
||||||
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
|
|
||||||
itn->set_mask_pred(true_mask);
|
|
||||||
current = itn->get_operand(0);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
{
|
|
||||||
ir::value *current = false_value;
|
|
||||||
while(current != uncasted_false_value) {
|
|
||||||
if(auto *itn = dynamic_cast<ir::instruction*>(current)){
|
|
||||||
itn->set_mask_pred(false_mask);
|
|
||||||
current = itn->get_operand(0);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Assignment expression */
|
|
||||||
ir::value *assignment_expression::codegen(ir::module *mod) const{
|
|
||||||
ir::value *rvalue = rvalue_->codegen(mod);
|
|
||||||
if(auto *x = dynamic_cast<const named_expression*>(lvalue_)){
|
|
||||||
ir::type *ty = mod->get_scope().types.at(x->id()->name());
|
|
||||||
rvalue = explicit_cast(mod->get_builder(), rvalue, ty);
|
|
||||||
implicit_broadcast(mod, ty, rvalue);
|
|
||||||
mod->set_value(x->id()->name(), rvalue);
|
|
||||||
}
|
|
||||||
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){
|
|
||||||
assert(x->get_op()==DEREF);
|
|
||||||
assert(x->lvalue());
|
|
||||||
ir::value *ptr = x->lvalue()->codegen(mod);
|
|
||||||
rvalue = mod->get_builder().create_store(ptr, rvalue);
|
|
||||||
}
|
|
||||||
return rvalue;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Type name */
|
|
||||||
ir::type *type_name::type(ir::module *mod) const{
|
|
||||||
return decl_->type(mod, spec_->type(mod), {});
|
|
||||||
}
|
|
||||||
|
|
||||||
/* String literal */
|
|
||||||
ir::value* string_literal::codegen(ir::module *) const{
|
|
||||||
throw std::runtime_error("not supported");
|
|
||||||
// return ir::constant_data_array::get_string(mod->get_context(), value_);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Constant */
|
|
||||||
ir::value* constant::codegen(ir::module *mod) const{
|
|
||||||
return mod->get_builder().get_int32(value_);
|
|
||||||
}
|
|
||||||
|
|
||||||
int constant::value() const{
|
|
||||||
return value_;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Constant range */
|
|
||||||
ir::value* constant_range::codegen(ir::module *mod) const{
|
|
||||||
return ir::constant_range::get((ir::constant_int*)first_->codegen(mod),
|
|
||||||
(ir::constant_int*)last_->codegen(mod));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Named */
|
|
||||||
ir::value* named_expression::codegen(ir::module *mod) const{
|
|
||||||
const std::string &name = id()->name();
|
|
||||||
const auto& declarations = mod->get_scope().types;
|
|
||||||
if(declarations.find(name) == declarations.end())
|
|
||||||
throw std::runtime_error("variable " + name + " not declared");
|
|
||||||
return mod->get_value(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// begin token
|
|
||||||
void update_location(const char *text) {
|
|
||||||
for (int i = 0; text[i] != '\0'; i++){
|
|
||||||
if (text[i] == '\n'){
|
|
||||||
current_column = 0;
|
|
||||||
current_line++;
|
|
||||||
}
|
|
||||||
else if (text[i] == '\t')
|
|
||||||
current_column += 8 - (current_column % 8);
|
|
||||||
else
|
|
||||||
current_column++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_error(const char *cerror) {
|
|
||||||
std::string error(cerror);
|
|
||||||
auto it = error.find("syntax error,");
|
|
||||||
error.replace(it, 13, "");
|
|
||||||
std::cerr << "error at line " << current_line << " (column " << current_column << "): " << error << std::endl;
|
|
||||||
throw std::runtime_error("compilation failed");
|
|
||||||
}
|
|
||||||
|
|
||||||
char return_impl(char t, const char * yytext) {
|
|
||||||
update_location(yytext);
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
yytokentype return_impl(yytokentype t, const char * yytext){
|
|
||||||
update_location(yytext);
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
void return_void(const char * yytext){
|
|
||||||
update_location(yytext);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
18
lib/ast/module.cpp
Normal file
18
lib/ast/module.cpp
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
#include "triton/ast/module.h"
|
||||||
|
#include "triton/ir/module.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
/* Translation unit */
|
||||||
|
ir::value* translation_unit::codegen(ir::module *mod) const{
|
||||||
|
mod->add_new_scope();
|
||||||
|
decls_.codegen(mod);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
160
lib/ast/node.cpp
Normal file
160
lib/ast/node.cpp
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
#include "triton/ast/node.h"
|
||||||
|
#include "triton/ir/builder.h"
|
||||||
|
#include "triton/ir/module.h"
|
||||||
|
#include "triton/ir/constant.h"
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
/* node */
|
||||||
|
ir::value *node::explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty){
|
||||||
|
ir::type *src_scalar_ty = src->get_type()->get_scalar_ty();
|
||||||
|
ir::type *dst_scalar_ty = dst_ty->get_scalar_ty();
|
||||||
|
bool src_signed = false;
|
||||||
|
bool dst_signed = false;
|
||||||
|
if(src_scalar_ty == dst_scalar_ty)
|
||||||
|
return src;
|
||||||
|
else if(src_scalar_ty->is_integer_ty() && src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||||
|
return builder.create_si_to_fp(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_integer_ty() && !src_signed && dst_scalar_ty->is_floating_point_ty())
|
||||||
|
return builder.create_ui_to_fp(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && dst_signed)
|
||||||
|
return builder.create_fp_to_si(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_integer_ty() && !dst_signed)
|
||||||
|
return builder.create_fp_to_ui(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
||||||
|
src_scalar_ty->get_fp_mantissa_width() < dst_scalar_ty->get_fp_mantissa_width())
|
||||||
|
return builder.create_fp_ext(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_floating_point_ty() && dst_scalar_ty->is_floating_point_ty() &&
|
||||||
|
src_scalar_ty->get_fp_mantissa_width() > dst_scalar_ty->get_fp_mantissa_width())
|
||||||
|
return builder.create_fp_trunc(src, dst_ty);
|
||||||
|
|
||||||
|
else if(src_scalar_ty->is_integer_ty() && dst_scalar_ty->is_integer_ty() &&
|
||||||
|
src_scalar_ty->get_integer_bitwidth())
|
||||||
|
return builder.create_int_cast(src, dst_ty, dst_signed);
|
||||||
|
|
||||||
|
else
|
||||||
|
throw std::runtime_error("unreachable");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
||||||
|
bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed){
|
||||||
|
// Input types
|
||||||
|
ir::type *left_ty = lhs->get_type()->get_scalar_ty();
|
||||||
|
ir::type *right_ty = rhs->get_type()->get_scalar_ty();
|
||||||
|
// One operand is pointer
|
||||||
|
if(left_ty->is_pointer_ty() || right_ty->is_pointer_ty()){
|
||||||
|
if(left_ty->is_pointer_ty() && right_ty->is_pointer_ty())
|
||||||
|
throw std::runtime_error("invalid operands");
|
||||||
|
if(right_ty->is_pointer_ty())
|
||||||
|
std::swap(lhs, rhs);
|
||||||
|
is_ptr = true;
|
||||||
|
}
|
||||||
|
// One operand is double
|
||||||
|
else if(left_ty->is_double_ty() || right_ty->is_double_ty()){
|
||||||
|
ir::value *&to_convert = left_ty->is_double_ty()?rhs:lhs;
|
||||||
|
to_convert = explicit_cast(builder, to_convert, builder.get_double_ty());
|
||||||
|
is_float = true;
|
||||||
|
}
|
||||||
|
// One operand is float
|
||||||
|
else if(left_ty->is_float_ty() || right_ty->is_float_ty()){
|
||||||
|
ir::value *&to_convert = left_ty->is_float_ty()?rhs:lhs;
|
||||||
|
to_convert = explicit_cast(builder, to_convert, builder.get_float_ty());
|
||||||
|
is_float = true;
|
||||||
|
}
|
||||||
|
// Both operands are integers
|
||||||
|
else if(left_ty->is_integer_ty() && right_ty->is_integer_ty()){
|
||||||
|
is_int = true;
|
||||||
|
is_signed = true; // always signed for now
|
||||||
|
if(left_ty->get_integer_bitwidth() != right_ty->get_integer_bitwidth()){
|
||||||
|
ir::value *&to_convert = (left_ty->get_integer_bitwidth() > right_ty->get_integer_bitwidth())?rhs:lhs;
|
||||||
|
ir::type *dst_ty = (to_convert==lhs)?right_ty:left_ty;
|
||||||
|
to_convert = explicit_cast(builder, to_convert, dst_ty);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Not reachable
|
||||||
|
else
|
||||||
|
throw std::runtime_error("unreachable");
|
||||||
|
}
|
||||||
|
|
||||||
|
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) {
|
||||||
|
ir::type *lhs_ty = lhs->get_type();
|
||||||
|
ir::type *rhs_ty = rhs->get_type();
|
||||||
|
ir::type *res_ty = nullptr;
|
||||||
|
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
||||||
|
return;
|
||||||
|
else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
||||||
|
res_ty = lhs_ty;
|
||||||
|
else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty())
|
||||||
|
res_ty = rhs_ty;
|
||||||
|
else{
|
||||||
|
auto lhs_shapes = lhs_ty->get_tile_shapes();
|
||||||
|
auto rhs_shapes = rhs_ty->get_tile_shapes();
|
||||||
|
size_t lhs_size = lhs_shapes.size();
|
||||||
|
size_t rhs_size = rhs_shapes.size();
|
||||||
|
size_t res_size = std::max(lhs_size, rhs_size);
|
||||||
|
ir::type::tile_shapes_t res_shapes(res_size);
|
||||||
|
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||||
|
for(int i = 0; i < res_size; i++){
|
||||||
|
if(i >= res_size - lhs_size && i >= res_size - rhs_size)
|
||||||
|
res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i];
|
||||||
|
else if(i >= res_size - lhs_size)
|
||||||
|
res_shapes[i] = lhs_shapes[i];
|
||||||
|
else if(i >= res_size - rhs_size)
|
||||||
|
res_shapes[i] = rhs_shapes[i];
|
||||||
|
}
|
||||||
|
res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes);
|
||||||
|
}
|
||||||
|
implicit_broadcast(mod, res_ty, rhs);
|
||||||
|
implicit_broadcast(mod, res_ty, lhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){
|
||||||
|
ir::builder &builder = mod->get_builder();
|
||||||
|
ir::type *src_ty = src->get_type();
|
||||||
|
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||||
|
// Both are scalar
|
||||||
|
if(!ty->is_tile_ty() && !src_ty->is_tile_ty())
|
||||||
|
return;
|
||||||
|
// Broadcast scalar
|
||||||
|
if(ty->is_tile_ty() && !src_ty->is_tile_ty()){
|
||||||
|
src = builder.create_splat(src, ty->get_tile_shapes());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Downcast tile
|
||||||
|
if(!ty->is_tile_ty() && src_ty->is_tile_ty()){
|
||||||
|
for(ir::constant *shape: src_ty->get_tile_shapes())
|
||||||
|
if(shape != one)
|
||||||
|
throw std::runtime_error("cannot downcast");
|
||||||
|
src = builder.create_downcast(src);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Both are arrays
|
||||||
|
auto dst_shapes = ty->get_tile_shapes();
|
||||||
|
auto src_shapes = src_ty->get_tile_shapes();
|
||||||
|
int dst_dim = dst_shapes.size();
|
||||||
|
int src_dim = src_shapes.size();
|
||||||
|
// Pad
|
||||||
|
int off = dst_dim - src_dim;
|
||||||
|
for(size_t i = 0; i < off; i++)
|
||||||
|
src_shapes.insert(src_shapes.begin(), one);
|
||||||
|
if(off > 0)
|
||||||
|
src = builder.create_reshape(src, src_shapes);
|
||||||
|
// Broadcast
|
||||||
|
for(int i = dst_dim - 1; i>= 0; i--)
|
||||||
|
if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one)
|
||||||
|
throw std::runtime_error("cannot broadcast");
|
||||||
|
if(dst_shapes != src_shapes)
|
||||||
|
src = builder.create_broadcast(src, dst_shapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
160
lib/ast/statement.cpp
Normal file
160
lib/ast/statement.cpp
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
#include "triton/ast/expression.h"
|
||||||
|
#include "triton/ast/statement.h"
|
||||||
|
#include "triton/ast/declaration.h"
|
||||||
|
#include "triton/ir/constant.h"
|
||||||
|
#include "triton/ir/module.h"
|
||||||
|
#include "triton/ir/basic_block.h"
|
||||||
|
#include "triton/ir/builder.h"
|
||||||
|
#include "triton/ir/type.h"
|
||||||
|
|
||||||
|
namespace triton{
|
||||||
|
|
||||||
|
namespace ast{
|
||||||
|
|
||||||
|
/* Helpers */
|
||||||
|
inline bool is_terminator(ir::value* x) {
|
||||||
|
return x && dynamic_cast<ir::terminator_inst*>(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* Statements */
|
||||||
|
ir::value* compound_statement::codegen(ir::module* mod) const{
|
||||||
|
mod->add_new_scope();
|
||||||
|
if(items_)
|
||||||
|
items_->codegen(mod);
|
||||||
|
mod->pop_scope();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Expression statement */
|
||||||
|
ir::value* expression_statement::codegen(ir::module *mod) const{
|
||||||
|
ir::builder &builder = mod->get_builder();
|
||||||
|
ir::basic_block *block = builder.get_insert_block();
|
||||||
|
if(pred_) {
|
||||||
|
// check that it is an assignment
|
||||||
|
assignment_expression *assignment = dynamic_cast<assignment_expression*>(expr_);
|
||||||
|
assert(assignment);
|
||||||
|
// generate mask
|
||||||
|
ir::value *pred = pred_->codegen(mod);
|
||||||
|
ir::mask_inst *mask = (ir::mask_inst*)builder.create_mask(pred);
|
||||||
|
// generate expression
|
||||||
|
unsigned szbegin = block->get_inst_list().size();
|
||||||
|
ir::value *expr = expr_->codegen(mod);
|
||||||
|
ir::basic_block::iterator begin = block->begin();
|
||||||
|
std::advance(begin, szbegin);
|
||||||
|
// set mask
|
||||||
|
ir::type *ty = expr->get_type();
|
||||||
|
for(auto it = begin; it != builder.get_insert_point(); it++)
|
||||||
|
(*it)->set_mask_pred(mask->get_result(0));
|
||||||
|
// if(auto *itn = dynamic_cast<ir::instruction*>(expr))
|
||||||
|
// itn->set_mask_pred(mask->get_result(0));
|
||||||
|
if(ty->is_void_ty())
|
||||||
|
return expr;
|
||||||
|
// merge with psi
|
||||||
|
ir::psi_inst *psi = (ir::psi_inst*)builder.create_merge(mask->get_result(0), expr,
|
||||||
|
mask->get_result(1), ir::undef_value::get(ty));
|
||||||
|
std::string name = ((named_expression*)assignment->lvalue())->id()->name();
|
||||||
|
mod->set_value(name, psi);
|
||||||
|
return psi;
|
||||||
|
}
|
||||||
|
return expr_->codegen(mod);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* For statement */
|
||||||
|
ir::value* iteration_statement::codegen(ir::module *mod) const{
|
||||||
|
ir::builder &builder = mod->get_builder();
|
||||||
|
ir::context &ctx = mod->get_context();
|
||||||
|
ir::basic_block *current_bb = builder.get_insert_block();
|
||||||
|
ir::function *fn = current_bb->get_parent();
|
||||||
|
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
|
||||||
|
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
|
||||||
|
mod->set_continue_fn([&](){
|
||||||
|
if(exec_)
|
||||||
|
exec_->codegen(mod);
|
||||||
|
ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx));
|
||||||
|
return builder.create_cond_br(cond, loop_bb, next_bb);
|
||||||
|
});
|
||||||
|
init_->codegen(mod);
|
||||||
|
ir::value *cond = explicit_cast(builder, stop_->codegen(mod), ir::type::get_int1_ty(ctx));
|
||||||
|
builder.create_cond_br(cond, loop_bb, next_bb);
|
||||||
|
// builder.create_br(loop_bb);
|
||||||
|
builder.set_insert_point(loop_bb);
|
||||||
|
if(!is_terminator(statements_->codegen(mod)))
|
||||||
|
mod->get_continue_fn()();
|
||||||
|
ir::basic_block *stop_bb = builder.get_insert_block();
|
||||||
|
mod->seal_block(stop_bb);
|
||||||
|
mod->seal_block(loop_bb);
|
||||||
|
mod->seal_block(builder.get_insert_block());
|
||||||
|
mod->seal_block(next_bb);
|
||||||
|
builder.set_insert_point(next_bb);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* While statement */
|
||||||
|
ir::value* while_statement::codegen(ir::module* mod) const{
|
||||||
|
ir::builder &builder = mod->get_builder();
|
||||||
|
ir::context &ctx = mod->get_context();
|
||||||
|
ir::basic_block *current_bb = builder.get_insert_block();
|
||||||
|
ir::function *fn = current_bb->get_parent();
|
||||||
|
ir::basic_block *loop_bb = ir::basic_block::create(ctx, "loop", fn);
|
||||||
|
ir::basic_block *next_bb = ir::basic_block::create(ctx, "postloop", fn);
|
||||||
|
mod->set_continue_fn([&](){
|
||||||
|
ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx));
|
||||||
|
return builder.create_cond_br(cond, loop_bb, next_bb);
|
||||||
|
});
|
||||||
|
ir::value *cond = explicit_cast(builder, cond_->codegen(mod), ir::type::get_int1_ty(ctx));
|
||||||
|
builder.create_cond_br(cond, loop_bb, next_bb);
|
||||||
|
builder.set_insert_point(loop_bb);
|
||||||
|
if(!is_terminator(statements_->codegen(mod)))
|
||||||
|
mod->get_continue_fn()();
|
||||||
|
ir::basic_block *stop_bb = builder.get_insert_block();
|
||||||
|
mod->seal_block(stop_bb);
|
||||||
|
mod->seal_block(loop_bb);
|
||||||
|
mod->seal_block(builder.get_insert_block());
|
||||||
|
mod->seal_block(next_bb);
|
||||||
|
builder.set_insert_point(next_bb);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Selection statement */
|
||||||
|
ir::value* selection_statement::codegen(ir::module* mod) const{
|
||||||
|
ir::builder &builder = mod->get_builder();
|
||||||
|
ir::context &ctx = mod->get_context();
|
||||||
|
ir::function *fn = builder.get_insert_block()->get_parent();
|
||||||
|
ir::value *cond = cond_->codegen(mod);
|
||||||
|
ir::basic_block *then_bb = ir::basic_block::create(ctx, "then", fn);
|
||||||
|
ir::basic_block *else_bb = else_value_?ir::basic_block::create(ctx, "else", fn):nullptr;
|
||||||
|
ir::basic_block *endif_bb = ir::basic_block::create(ctx, "endif", fn);
|
||||||
|
mod->seal_block(then_bb);
|
||||||
|
if(else_value_)
|
||||||
|
mod->seal_block(else_bb);
|
||||||
|
|
||||||
|
// Branch
|
||||||
|
if(else_value_)
|
||||||
|
builder.create_cond_br(cond, then_bb, else_bb);
|
||||||
|
else
|
||||||
|
builder.create_cond_br(cond, then_bb, endif_bb);
|
||||||
|
// Then
|
||||||
|
builder.set_insert_point(then_bb);
|
||||||
|
if(!is_terminator(then_value_->codegen(mod)))
|
||||||
|
builder.create_br(endif_bb);
|
||||||
|
// Else
|
||||||
|
if(else_value_){
|
||||||
|
builder.set_insert_point(else_bb);
|
||||||
|
if(!is_terminator(else_value_->codegen(mod)))
|
||||||
|
builder.create_br(endif_bb);
|
||||||
|
}
|
||||||
|
// Endif
|
||||||
|
mod->seal_block(endif_bb);
|
||||||
|
builder.set_insert_point(endif_bb);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Continue statement */
|
||||||
|
ir::value* continue_statement::codegen(ir::module *mod) const{
|
||||||
|
return mod->get_continue_fn()();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Reference in New Issue
Block a user