Deprecation of Triton-C and Replacement by decorated Python functions (#86)

This PR implements a major overhaul of the frontend for Triton, and replaces Triton-C by a pure Python API in which kernels are defined as @triton.jit decorated functions. The documentation and tutorials have also been updated to accommodate these changes.

See documentations for more information on the new API
This commit is contained in:
Philippe Tillet
2021-04-20 22:29:40 -04:00
committed by Philippe Tillet
parent 1fdb465b71
commit 39f4730305
91 changed files with 4500 additions and 13008 deletions

View File

@@ -1,30 +1,31 @@
#ifndef _TRITON_CODEGEN_PASS_H_
#define _TRITON_CODEGEN_PASS_H_
#include <list>
#include <memory>
namespace triton{
namespace ir{
class module;
}
namespace driver{
class device;
class module;
class kernel;
}
}
namespace triton{
namespace codegen{
class pass {
public:
virtual void run(ir::module& m);
};
// TODO:
// There should be a proper pass manager there!
void add_passes_to_emit_bin(ir::module &ir, driver::device* dev, int num_warps,
driver::module*& mod, driver::kernel*& ker, size_t& shared_mem);
class pass_manager {
public:
void add(pass* p);
void run(ir::module& m);
private:
std::list<pass*> passes;
};
}
}
#endif

View File

@@ -119,7 +119,7 @@ public:
void visit_exp_inst(ir::exp_inst*);
void visit_log_inst(ir::log_inst*);
void visit_get_program_id_inst(ir::get_program_id_inst*);
void visit_get_num_program_inst(ir::get_num_program_inst*);
void visit_get_num_programs_inst(ir::get_num_programs_inst*);
void visit_atomic_cas_inst(ir::atomic_cas_inst*);
void visit_atomic_exch_inst(ir::atomic_exch_inst*);
void visit_atomic_add_inst(ir::atomic_add_inst*);

View File

@@ -59,7 +59,7 @@ public:
// CUDA
class cu_module: public module {
std::string compile_llvm_module(std::unique_ptr<llvm::Module> module, driver::device* device);
std::string compile_llvm_module(llvm::Module* module, driver::device* device);
void init_from_ptx(const std::string& ptx);
public:

View File

@@ -1,4 +1,4 @@
#pragma once
#pragma once
#ifndef _TRITON_IR_BUILDER_H_
#define _TRITON_IR_BUILDER_H_
@@ -27,6 +27,8 @@ class builder{
public:
// Constructor
builder(context &ctx);
// Getters
const context& get_context() { return ctx_; }
// Setters
void set_insert_point(iterator instr);
void set_insert_point(instruction* i);
@@ -38,6 +40,9 @@ public:
value *get_int1(bool val);
value *get_int32(int32_t val);
value *get_int64(int64_t val);
value *get_float16(float val);
value *get_float32(float val);
value *get_range(int32_t lo, int32_t hi);
// Types
type *get_void_ty();
type *get_int1_ty();
@@ -50,11 +55,10 @@ public:
type *get_double_ty();
// Insert
template<typename InstTy>
InstTy* insert(InstTy *inst, const std::string &name = ""){
InstTy* insert(InstTy *inst){
assert(block_);
block_->get_inst_list().insert(insert_point_, inst);
inst->set_parent(block_);
inst->set_name(name);
// for(ir::value* op: inst->ops())
// op->add_use(inst);
return inst;
@@ -64,91 +68,87 @@ public:
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
// Cast instructions
value *create_cast(cast_op_t op, value *v, type *dst_ty, const std::string &name = "");
value* create_ptr_to_int(value *src, type *dst_ty, const std::string &name = "");
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
value* create_ui_to_fp(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_to_si(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_to_ui(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_ext(value *src, type *dst_ty, const std::string &name = "");
value* create_fp_trunc(value *src, type *dst_ty, const std::string &name = "");
value* create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name = "");
value *create_downcast(value *arg, const std::string &name = "");
value *create_cast(cast_op_t op, value *v, type *dst_ty);
value* create_ptr_to_int(value *src, type *dst_ty);
value* create_si_to_fp(value *src, type *dst_ty);
value* create_ui_to_fp(value *src, type *dst_ty);
value* create_fp_to_si(value *src, type *dst_ty);
value* create_fp_to_ui(value *src, type *dst_ty);
value* create_fp_ext(value *src, type *dst_ty);
value* create_fp_trunc(value *src, type *dst_ty);
value* create_int_cast(value *src, type *dst_ty, bool is_signed);
value *create_downcast(value *arg);
// Phi instruction
phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = "");
phi_node* create_phi(type *ty, unsigned num_reserved);
// Binary instructions
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, const std::string &name, bool has_nuw, bool has_nsw);
value *create_fmul(value *lhs, value *rhs, const std::string &name = "");
value *create_fdiv(value *lhs, value *rhs, const std::string &name = "");
value *create_frem(value *lhs, value *rhs, const std::string &name = "");
value *create_fadd(value *lhs, value *rhs, const std::string &name = "");
value *create_fsub(value *lhs, value *rhs, const std::string &name = "");
value *create_mul(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs, const std::string &name = "");
value *create_udiv(value *lhs, value *rhs, const std::string &name = "");
value *create_srem(value *lhs, value *rhs, const std::string &name = "");
value *create_urem(value *lhs, value *rhs, const std::string &name = "");
value *create_add(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_lshr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_ashr(value *lhs, value *rhs, const std::string &name = "", bool has_nuw = false, bool has_nsw = false);
value *create_insert_nuwnswb_binop(binary_op_t op, value *lhs, value *rhs, bool has_nuw, bool has_nsw);
value *create_fmul(value *lhs, value *rhs);
value *create_fdiv(value *lhs, value *rhs);
value *create_frem(value *lhs, value *rhs);
value *create_fadd(value *lhs, value *rhs);
value *create_fsub(value *lhs, value *rhs);
value *create_mul(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sdiv(value *lhs, value *rhs);
value *create_udiv(value *lhs, value *rhs);
value *create_srem(value *lhs, value *rhs);
value *create_urem(value *lhs, value *rhs);
value *create_add(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_sub(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_shl(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_lshr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
value *create_ashr(value *lhs, value *rhs, bool has_nuw = false, bool has_nsw = false);
// GEP
value *create_gep(value *ptr, const std::vector<value*>& idx_list, const std::string &name = "");
value *create_gep(value *ptr, const std::vector<value*>& idx_list);
// Comparison (int)
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSLE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSLT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSGE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpSGT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpULE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpULT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpUGE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpUGT(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpEQ(value *lhs, value *rhs, const std::string &name = "");
value *create_icmpNE(value *lhs, value *rhs, const std::string &name = "");
value *create_icmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_icmpSLE(value *lhs, value *rhs);
value *create_icmpSLT(value *lhs, value *rhs);
value *create_icmpSGE(value *lhs, value *rhs);
value *create_icmpSGT(value *lhs, value *rhs);
value *create_icmpULE(value *lhs, value *rhs);
value *create_icmpULT(value *lhs, value *rhs);
value *create_icmpUGE(value *lhs, value *rhs);
value *create_icmpUGT(value *lhs, value *rhs);
value *create_icmpEQ(value *lhs, value *rhs);
value *create_icmpNE(value *lhs, value *rhs);
// Comparison (float)
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOLT(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOGT(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOLE(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOGE(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpOEQ(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmpONE(value *lhs, value *rhs, const std::string &name = "");
value *create_fcmp(cmp_pred_t pred, value *lhs, value *rhs);
value *create_fcmpOLT(value *lhs, value *rhs);
value *create_fcmpOGT(value *lhs, value *rhs);
value *create_fcmpOLE(value *lhs, value *rhs);
value *create_fcmpOGE(value *lhs, value *rhs);
value *create_fcmpOEQ(value *lhs, value *rhs);
value *create_fcmpONE(value *lhs, value *rhs);
// Logical
value *create_and(value *lhs, value *rhs, const std::string &name = "");
value *create_xor(value *lhs, value *rhs, const std::string &name = "");
value *create_or(value *lhs, value *rhs, const std::string &name = "");
// Unary
// value *create_fneg(value *arg, const std::string &name = "");
// value *create_neg(value *arg, const std::string &name = "");
// value *create_not(value *arg, const std::string &name = "");
value *create_and(value *lhs, value *rhs);
value *create_xor(value *lhs, value *rhs);
value *create_or(value *lhs, value *rhs);
// Input/Output
value *create_load(value *arg, const std::string &name = "");
value *create_store(value *ptr, value *val, const std::string &name = "");
value *create_masked_load(value *arg, value *mask, value *false_value, const std::string &name = "");
value *create_masked_store(value *ptr, value *val, value *mask, const std::string &name = "");
// Tile instruction
value *create_splat(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_load(value *arg);
value *create_store(value *ptr, value *val);
value *create_masked_load(value *arg, value *mask, value *false_value);
value *create_masked_store(value *ptr, value *val, value *mask);
// Block instruction
value *create_splat(value *arg, const type::block_shapes_t &shapes);
value *create_reshape(value *arg, const type::block_shapes_t &shapes);
value *create_broadcast(value *arg, const type::block_shapes_t &shapes);
// Built-in instruction
value *create_get_program_id(unsigned axis, const std::string &name = "");
value *create_get_num_program(unsigned axis, const std::string &name = "");
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
value *create_atomic_add(value *ptr, value *val, value *msk, const std::string &name = "");
value *create_exp(value* arg, const std::string &name = "");
value *create_log(value* arg, const std::string &name = "");
value *create_dot(value *A, value *B, value *C, const std::string &name = "");
value *create_trans(value *A, const std::vector<int> &perm = {}, const std::string &name = "");
value *create_sqrt(value *A, const std::string &name = "");
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis, const std::string &name = "");
value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = "");
value *create_get_program_id(unsigned axis);
value *create_get_num_programs(unsigned axis);
value *create_atomic_cas(value *ptr, value *cmp, value *val);
value *create_atomic_exch(value *ptr, value *val);
value *create_atomic_add(value *ptr, value *val, value *msk);
value *create_exp(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);
value *create_select(value *pred, value *if_value, value *else_value);
// Intrinsics
value *create_copy_to_shared(value *arg, const std::string &name = "");
value *create_masked_load_async(value *arg, value *mask, value *false_value, const std::string &name = "");
value *create_copy_from_shared(value *arg, const std::string &name = "");
value *create_copy_to_shared(value *arg);
value *create_masked_load_async(value *arg, value *mask, value *false_value);
value *create_copy_from_shared(value *arg);
value *create_barrier(const std::string &name = "");
value *create_async_wait(int N);
@@ -158,6 +158,7 @@ private:
iterator insert_point_;
};
}
}

View File

@@ -9,6 +9,7 @@
namespace triton{
namespace ir{
class builder;
class type;
class context_impl;
@@ -16,8 +17,11 @@ class context_impl;
class context {
public:
context();
context(const context&) = delete;
context& operator=(const context&) = delete;
public:
ir::builder* builder = nullptr;
std::shared_ptr<context_impl> p_impl;
};

View File

@@ -28,13 +28,16 @@ public:
integer_type int1_ty, int8_ty, int16_ty, int32_ty, int64_ty, int128_ty;
// Pointer types
std::map<std::pair<type*, unsigned>, pointer_type*> ptr_tys;
std::map<std::pair<type*, type::tile_shapes_t>, tile_type*> tile_tys;
// Block types
std::map<std::pair<type*, type::block_shapes_t>, block_type*> block_tys;
// Int constants
std::map<std::pair<type*, uint64_t>, constant_int*> int_constants_;
// Float constants
std::map<std::pair<type*, double>, constant_fp*> fp_constants_;
// undef values
std::map<type*, undef_value*> uv_constants_;
};
}

View File

@@ -0,0 +1,97 @@
#pragma once
#ifndef _TRITON_IR_DISPATCH_H_
#define _TRITON_IR_DISPATCH_H_
#include "triton/ir/builder.h"
namespace triton{
namespace ir{
/*----------------------------------------------
higher level functions that follow the likely
semantics of most expected frontends
----------------------------------------------*/
struct semantic_error: public std::runtime_error {
semantic_error(const std::string& msg):
std::runtime_error(msg) { }
};
struct dispatch{
typedef ir::type::block_shapes_t shape_t;
// programming model
static ir::value *program_id(int axis, ir::builder *builder);
static ir::value *num_programs(int axis, ir::builder *builder);
// binary operators
static ir::value *add(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *sub(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mul(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *truediv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *floordiv(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *mod(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *and_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *or_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *xor_(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *lshr(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *shl(ir::value *input, ir::value *other, ir::builder *builder);
// unary operators
static ir::value *plus(ir::value *input, ir::builder *builder);
static ir::value *minus(ir::value *input, ir::builder *builder);
static ir::value *invert(ir::value *input, ir::builder *builder);
// comparison operators
static ir::value *greater_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *greater_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_than(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *less_equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *equal(ir::value *input, ir::value *other, ir::builder *builder);
static ir::value *not_equal(ir::value *input, ir::value *other, ir::builder *builder);
// block creation
static ir::value* arange(int start, int end, ir::builder *builder);
static ir::value* zeros(shape_t shape, ir::type *dtype, ir::builder *builder);
// casting ops
static ir::value *reshape(ir::value *input, shape_t shape, ir::builder *builder);
static ir::value *broadcast(ir::value *input, shape_t shape, ir::builder *builder);
static std::tuple<ir::value*, ir::value*> broadcast(ir::value *lhs, ir::value* rhs, ir::builder *builder);
static ir::value *cast(ir::value *input, ir::type *type, ir::builder *builder);
// memory operators
static ir::value *load(ir::value* ptr, ir::value* mask, ir::value* other, ir::builder *builder);
static ir::value *store(ir::value* ptr, ir::value *value, ir::value *mask, ir::builder *builder);
static ir::value *atomic_cas(ir::value* ptr, ir::value *cmp, ir::value *val, ir::builder *builder);
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
// indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);
// reduction
static ir::value *min(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *max(ir::value *input, unsigned int axis, ir::builder *builder);
static ir::value *sum(ir::value *input, unsigned int axis, ir::builder *builder);
// math
static ir::value *exp(ir::value *x, ir::builder *builder);
static ir::value *log(ir::value *x, ir::builder *builder);
static ir::value *sqrt(ir::value *x, ir::builder *builder);
// internal (debug/optimization)
static ir::value *multiple_of(ir::value *x, int value, ir::builder *builder);
static ir::value *debug_barrier(ir::builder *builder);
};
}
}
#endif

View File

@@ -35,7 +35,7 @@ private:
/* Attribute */
enum attribute_kind_t {
readonly,
readonly = 0,
writeonly,
noalias,
aligned,
@@ -71,7 +71,7 @@ public:
case writeonly: return ".writeonly";
case noalias: return ".noalias";
case aligned: return ".aligned(" + std::to_string(value_) + ")";
case multiple_of: return ".readonly";
case multiple_of: return ".multipleof(" + std::to_string(value_) + ")";
case retune: return ".retunr";
default: break;
}
@@ -102,7 +102,7 @@ private:
public:
// accessors
const args_t &args() { return args_; }
const args_t &args() const { return args_; }
function_type* get_fn_type() { return fn_ty_; }
// factory methods

View File

@@ -514,7 +514,7 @@ public:
class retile_inst: public unary_inst {
protected:
retile_inst(value *arg, value_id_t id, const type::tile_shapes_t &shapes, const std::string &name, instruction *next);
retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next);
};
// reshape
@@ -525,7 +525,7 @@ private:
std::string repr_impl() const { return "reshape"; }
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(reshape_inst)
_TRITON_DEFINE_ACCEPT(reshape_inst)
@@ -539,7 +539,7 @@ private:
std::string repr_impl() const { return "splat"; }
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(splat_inst)
_TRITON_DEFINE_ACCEPT(splat_inst)
@@ -553,7 +553,7 @@ private:
std::string repr_impl() const { return "broadcast"; }
public:
static instruction* create(value *arg, const type::tile_shapes_t &shape_suffix,
static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(broadcast_inst)
_TRITON_DEFINE_ACCEPT(broadcast_inst)
@@ -597,16 +597,16 @@ private:
unsigned axis_;
};
class get_num_program_inst: public builtin_inst {
class get_num_programs_inst: public builtin_inst {
private:
get_num_program_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_num_program(" + std::to_string(axis_) + ")"; }
get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
unsigned get_axis() const { return axis_; }
_TRITON_DEFINE_CLONE(get_num_program_inst)
_TRITON_DEFINE_ACCEPT(get_num_program_inst)
_TRITON_DEFINE_CLONE(get_num_programs_inst)
_TRITON_DEFINE_ACCEPT(get_num_programs_inst)
private:
unsigned axis_;

View File

@@ -36,6 +36,11 @@ class alloc_const;
/* Module */
struct scope {
public:
const std::map<std::string, ir::value*>& get_values() { return values; }
void set_type(const std::string& name, ir::type* ty) { types[name] = ty; }
ir::type* get_type(const std::string& name) { return types.at(name); }
private:
std::map<std::string, ir::type*> types;
std::map<std::string, ir::value*> values;
};
@@ -61,8 +66,7 @@ private:
void push_function(function *fn) { functions_.push_back(fn); }
public:
module(const std::string &name);
context& get_context();
module(const std::string &name, builder& builder);
builder& get_builder();
// Setters
void set_value(const std::string& name, basic_block* block, value *x);
@@ -95,8 +99,7 @@ public:
private:
std::string name_;
context context_;
builder builder_;
builder& builder_;
std::map<val_key_t, value*> values_;
std::map<val_key_t, type*> types_;
std::set<std::string> const_;

View File

@@ -18,7 +18,7 @@ class constant_int;
/* Type */
class type {
public:
typedef std::vector<unsigned> tile_shapes_t;
typedef std::vector<unsigned> block_shapes_t;
protected:
typedef std::vector<type*> contained_tys_vec_t;
@@ -43,7 +43,7 @@ public:
FunctionTyID, ///< 11: Functions
PointerTyID, ///< 12: Pointers
StructTyID, ///< 13: Struct
TileTyID, ///< 14: Tile
BlockTyID, ///< 14: Tile
};
public:
@@ -62,7 +62,7 @@ public:
unsigned get_tile_bitwidth() const;
unsigned get_primitive_size_in_bits() const;
type *get_scalar_ty() const;
const tile_shapes_t& get_tile_shapes() const;
block_shapes_t get_block_shapes() const;
const size_t get_tile_rank() const;
const size_t get_tile_ranks1() const;
unsigned get_tile_num_elements() const;
@@ -83,7 +83,7 @@ public:
get_integer_bitwidth() == bitwidth;}
bool is_bool_ty() const { return is_integer_ty(1); }
bool is_pointer_ty() const { return id_ == PointerTyID; }
bool is_tile_ty() const { return id_ == TileTyID; }
bool is_block_ty() const { return id_ == BlockTyID; }
// Composite predicates
bool is_int_or_tileint_ty();
@@ -110,7 +110,7 @@ public:
// repr
std::string tile_repr() const {
std::string res = get_tile_element_ty()->repr();
auto shapes = get_tile_shapes();
auto shapes = get_block_shapes();
res += "<";
for(size_t i = 0; i < shapes.size(); i++){
if(i > 0)
@@ -137,7 +137,7 @@ public:
case FunctionTyID: return "fn";
case PointerTyID: return get_pointer_element_ty()->repr() + "*";
case StructTyID: return "struct";
case TileTyID: return tile_repr();
case BlockTyID: return tile_repr();
default: break;
}
assert(false);
@@ -180,23 +180,23 @@ public:
type* get_type_at_index(value *idx) const;
};
class tile_type: public composite_type {
class block_type: public composite_type {
private:
tile_type(type *ty, const tile_shapes_t &shapes);
block_type(type *ty, const block_shapes_t &shapes);
static bool is_valid_elt_ty(type *ty);
public:
// accessors
const tile_shapes_t& get_shapes() const { return shapes_; }
const block_shapes_t& get_shapes() const { return shapes_; }
unsigned get_num_elements() const;
unsigned get_bitwidth() const;
// factory methods
static tile_type* get(type *ty, const tile_shapes_t &shapes);
static tile_type* get_same_shapes(type *ty, type *ref);
static block_type* get(type *ty, const block_shapes_t &shapes);
static block_type* get_same_shapes(type *ty, type *ref);
private:
tile_shapes_t shapes_;
block_shapes_t shapes_;
};
class pointer_type: public type {

View File

@@ -52,7 +52,7 @@ class exp_inst;
class log_inst;
class get_program_id_inst;
class get_num_program_inst;
class get_num_programs_inst;
class atomic_cas_inst;
class atomic_exch_inst;
class atomic_add_inst;
@@ -128,7 +128,7 @@ public:
virtual void visit_downcast_inst(downcast_inst*) = 0;
virtual void visit_get_program_id_inst(get_program_id_inst*) = 0;
virtual void visit_get_num_program_inst(get_num_program_inst*) = 0;
virtual void visit_get_num_programs_inst(get_num_programs_inst*) = 0;
virtual void visit_atomic_cas_inst(atomic_cas_inst*) = 0;
virtual void visit_atomic_exch_inst(atomic_exch_inst*) = 0;
virtual void visit_atomic_add_inst(atomic_add_inst*) = 0;

View File

@@ -1,823 +0,0 @@
#pragma once
#ifndef _WGTCC_AST_H_
#define _WGTCC_AST_H_
#include "error.h"
#include "token.h"
#include "type.h"
#include <cassert>
#include <list>
#include <memory>
#include <string>
class Visitor;
template<typename T> class Evaluator;
class AddrEvaluator;
class Generator;
class Scope;
class Parser;
class ASTNode;
class Token;
class TokenSequence;
// Expressions
class Expr;
class BinaryOp;
class UnaryOp;
class ConditionalOp;
class FuncCall;
class TempVar;
class Constant;
class Identifier;
class Object;
struct Initializer;
class Declaration;
class Enumerator;
// Statements
class Stmt;
class IfStmt;
class ForStmt;
class JumpStmt;
class LabelStmt;
class EmptyStmt;
class CompoundStmt;
class FuncDef;
class TranslationUnit;
/*
* AST Node
*/
class ASTNode {
public:
struct Attr{
enum KindT{
MULTIPLEOF,
ALIGNED,
NOALIAS,
READONLY,
WRITEONLY,
RETUNE,
};
KindT kind;
std::vector<Expr*> vals;
};
using AttrList = std::vector<Attr>;
public:
virtual ~ASTNode() {}
virtual void Accept(Visitor* v) = 0;
protected:
ASTNode() {}
MemPool* pool_ {nullptr};
};
using ExtDecl = ASTNode;
/*
* Statements
*/
class Stmt : public ASTNode {
public:
virtual ~Stmt() {}
protected:
Stmt() {}
};
class EmptyStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static EmptyStmt* New();
virtual ~EmptyStmt() {}
virtual void Accept(Visitor* v);
protected:
EmptyStmt() {}
};
class LabelStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static LabelStmt* New();
~LabelStmt() {}
virtual void Accept(Visitor* v);
std::string Repr() const { return ".L" + std::to_string(tag_); }
protected:
LabelStmt(): tag_(GenTag()) {}
private:
static int GenTag() {
static int tag = 0;
return ++tag;
}
int tag_; // 使用整型的tag值而不直接用字符串
};
class IfStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static IfStmt* New(Expr* cond, Stmt* then, Stmt* els=nullptr);
virtual ~IfStmt() {}
virtual void Accept(Visitor* v);
protected:
IfStmt(Expr* cond, Stmt* then, Stmt* els = nullptr)
: cond_(cond), then_(then), else_(els) {}
private:
Expr* cond_;
Stmt* then_;
Stmt* else_;
};
class ForStmt: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ForStmt* New(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr);
virtual ~ForStmt() {}
virtual void Accept(Visitor* v);
protected:
ForStmt(Stmt* body, Stmt* init = nullptr, Expr* cond = nullptr, Expr* step = nullptr)
: body_(body), init_(init), cond_(cond), step_(step) {}
private:
Stmt* body_;
Stmt* init_;
Expr* cond_;
Expr* step_;
};
class JumpStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static JumpStmt* New(LabelStmt* label);
virtual ~JumpStmt() {}
virtual void Accept(Visitor* v);
void SetLabel(LabelStmt* label) { label_ = label; }
protected:
JumpStmt(LabelStmt* label): label_(label) {}
private:
LabelStmt* label_;
};
class ReturnStmt: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ReturnStmt* New(Expr* expr);
virtual ~ReturnStmt() {}
virtual void Accept(Visitor* v);
protected:
ReturnStmt(::Expr* expr): expr_(expr) {}
private:
::Expr* expr_;
};
using StmtList = std::list<Stmt*>;
class CompoundStmt : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static CompoundStmt* New(StmtList& stmts, ::Scope* scope=nullptr);
virtual ~CompoundStmt() {}
virtual void Accept(Visitor* v);
StmtList& Stmts() { return stmts_; }
::Scope* Scope() { return scope_; }
protected:
CompoundStmt(const StmtList& stmts, ::Scope* scope=nullptr)
: stmts_(stmts), scope_(scope) {}
private:
StmtList stmts_;
::Scope* scope_;
};
struct Initializer {
Initializer(Type* type,
int offset,
Expr* expr,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0)
: type_(type),
offset_(offset),
bitFieldBegin_(bitFieldBegin),
bitFieldWidth_(bitFieldWidth),
expr_(expr) {}
bool operator<(const Initializer& rhs) const;
// It could be the object it self or, it will be the member
// that was initialized
Type* type_;
int offset_;
unsigned char bitFieldBegin_;
unsigned char bitFieldWidth_;
Expr* expr_;
};
using InitList = std::set<Initializer>;
class Declaration: public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static Declaration* New(Object* obj);
virtual ~Declaration() {}
virtual void Accept(Visitor* v);
InitList& Inits() { return inits_; }
Object* Obj() { return obj_; }
void AddInit(Initializer init);
protected:
Declaration(Object* obj): obj_(obj) {}
Object* obj_;
InitList inits_;
};
/*
* Expr
* BinaryOp
* UnaryOp
* ConditionalOp
* FuncCall
* Constant
* Identifier
* Object
* TempVar
*/
class Expr : public Stmt {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
virtual ~Expr() {}
::Type* Type() { return type_.GetPtr(); }
virtual bool IsLVal() = 0;
virtual void TypeChecking() = 0;
void EnsureCompatible(const QualType lhs, const QualType rhs) const;
void EnsureCompatibleOrVoidPointer(const QualType lhs,
const QualType rhs) const;
const Token* Tok() const { return tok_; }
void SetTok(const Token* tok) { tok_ = tok; }
static Expr* MayCast(Expr* expr);
static Expr* MayCast(Expr* expr, QualType desType);
static ::Type* TryExtractScalarType(Expr* loc, Expr *operand);
static ::Type* ScalarOrLikeTile(Expr* operand, ::Type* ty);
virtual bool IsNullPointerConstant() const { return false; }
bool IsConstQualified() const { return type_.IsConstQualified(); }
bool IsRestrictQualified() const { return type_.IsRestrictQualified(); }
bool IsVolatileQualified() const { return type_.IsVolatileQualified(); }
protected:
// You can construct a expression without specifying a type,
// then the type should be evaluated in TypeChecking()
Expr(const Token* tok, QualType type): tok_(tok), type_(type) {}
const Token* tok_;
QualType type_;
};
/*
* '+', '-', '*', '/', '%', '<', '>', '<<', '>>', '|', '&', '^'
* '=',(复合赋值运算符被拆分为两个运算)
* '==', '!=', '<=', '>=',
* '&&', '||'
* '['(下标运算符), '.'(成员运算符)
* ','(逗号运算符),
*/
class BinaryOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
friend class Declaration;
public:
static BinaryOp* New(const Token* tok, Expr* lhs, Expr* rhs);
static BinaryOp* New(const Token* tok, int op, Expr* lhs, Expr* rhs);
virtual ~BinaryOp() {}
virtual void Accept(Visitor* v);
// Member ref operator is a lvalue
virtual bool IsLVal() {
switch (op_) {
case '.': return !Type()->ToArray() && lhs_->IsLVal();
case ']': return !Type()->ToArray();
case Token::MASKED_DEREF: return true;
default: return false;
}
}
ArithmType* Convert();
static void Broadcast(Expr* loc, Expr*& lhs, Expr*& rhs, QualType &type);
virtual void TypeChecking();
void SubScriptingOpTypeChecking();
void MemberRefOpTypeChecking();
void MultiOpTypeChecking();
void AdditiveOpTypeChecking();
void ShiftOpTypeChecking();
void RangeOpTypeChecking();
void MatmulOpTypeChecking();
void MaskedDerefOpTypeChecking();
void RelationalOpTypeChecking();
void EqualityOpTypeChecking();
void BitwiseOpTypeChecking();
void LogicalOpTypeChecking();
void AssignOpTypeChecking();
void CommaOpTypeChecking();
protected:
BinaryOp(const Token* tok, int op, Expr* lhs, Expr* rhs)
: Expr(tok, nullptr), op_(op) {
lhs_ = lhs, rhs_ = rhs;
if (op != '.') {
lhs_ = MayCast(lhs);
rhs_ = MayCast(rhs);
}
}
int op_;
Expr* lhs_;
Expr* rhs_;
};
/*
* Unary Operator:
* '++' (prefix/postfix)
* '--' (prefix/postfix)
* '&' (ADDR)
* '*' (DEREF)
* '+' (PLUS)
* '-' (MINUS)
* '~'
* '!'
* CAST // like (int)3
*/
class UnaryOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
static UnaryOp* New(int op, Expr* operand, QualType type=nullptr, int info=0);
virtual ~UnaryOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal();
::Type *Convert();
static int encodeRed(int ax, int tag);
static void decodeRed(int info, int& ax, int& tag);
void TypeChecking();
void IncDecOpTypeChecking();
void AddrOpTypeChecking();
void DerefOpTypeChecking();
void ReduceOpTypeChecking();
void UnaryArithmOpTypeChecking();
void BitcastOpTypeChecking();
void CastOpTypeChecking();
void IntrinsicOpTypeChecking();
protected:
UnaryOp(int op, Expr* operand, QualType type=nullptr, int info=0)
: Expr(operand->Tok(), type), op_(op), info_(info) {
operand_ = operand;
if (op_ != Token::CAST && op_ != Token::ADDR) {
operand_ = MayCast(operand);
}
}
int op_;
int info_;
Expr* operand_;
};
class TransOp: public Expr {
friend class Generator;
public:
using PermInt = std::vector<int>;
public:
static TransOp* New(const PermInt& perm, Expr* operand);
const PermInt& getPerm() const { return perm_; }
void Accept(Visitor* v);
bool IsLVal() { return false; }
void TypeChecking();
protected:
TransOp(const PermInt& perm, Expr* operand)
: Expr(operand->Tok(), nullptr), operand_(operand), perm_(perm) {}
private:
Expr* operand_;
PermInt perm_;
};
// cond ? true false
class ConditionalOp : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static ConditionalOp* New(const Token* tok,
Expr* cond, Expr* exprTrue, Expr* exprFalse);
virtual ~ConditionalOp() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return false; }
ArithmType* Convert();
virtual void TypeChecking();
protected:
ConditionalOp(Expr* cond, Expr* exprTrue, Expr* exprFalse)
: Expr(cond->Tok(), nullptr), cond_(MayCast(cond)),
exprTrue_(MayCast(exprTrue)), exprFalse_(MayCast(exprFalse)) {}
private:
Expr* cond_;
Expr* exprTrue_;
Expr* exprFalse_;
};
class FuncCall : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
using ArgList = std::vector<Expr*>;
public:
static FuncCall* New(Expr* designator, const ArgList& args);
~FuncCall() {}
virtual void Accept(Visitor* v);
// A function call is ofcourse not lvalue
virtual bool IsLVal() { return false; }
ArgList* Args() { return &args_; }
Expr* Designator() { return designator_; }
const std::string& Name() const { return tok_->str_; }
::FuncType* FuncType() { return designator_->Type()->ToFunc(); }
virtual void TypeChecking();
protected:
FuncCall(Expr* designator, const ArgList& args)
: Expr(designator->Tok(), nullptr),
designator_(designator), args_(args) {}
Expr* designator_;
ArgList args_;
};
class Constant: public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static Constant* New(const Token* tok, int tag, long val);
static Constant* New(const Token* tok, int tag, double val);
static Constant* New(const Token* tok, int tag, const std::string* val);
~Constant() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return false; }
virtual void TypeChecking() {}
long IVal() const { return ival_; }
double FVal() const { return fval_; }
const std::string* SVal() const { return sval_; }
std::string SValRepr() const;
std::string Repr() const { return std::string(".LC") + std::to_string(id_); }
protected:
Constant(const Token* tok, QualType type, long val)
: Expr(tok, type), ival_(val) {}
Constant(const Token* tok, QualType type, double val)
: Expr(tok, type), fval_(val) {}
Constant(const Token* tok, QualType type, const std::string* val)
: Expr(tok, type), sval_(val) {}
union {
long ival_;
double fval_;
struct {
long id_;
const std::string* sval_;
};
};
};
class TempVar : public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static TempVar* New(QualType type);
virtual ~TempVar() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return true; }
virtual void TypeChecking() {}
protected:
TempVar(QualType type): Expr(nullptr, type), tag_(GenTag()) {}
private:
static int GenTag() {
static int tag = 0;
return ++tag;
}
int tag_;
};
enum Linkage {
L_NONE,
L_EXTERNAL,
L_INTERNAL,
};
class Identifier: public Expr {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
static Identifier* New(const Token* tok, QualType type, Linkage linkage, const AttrList& attrList={});
virtual ~Identifier() {}
virtual void Accept(Visitor* v);
virtual bool IsLVal() { return false; }
virtual Object* ToObject() { return nullptr; }
virtual Enumerator* ToEnumerator() { return nullptr; }
// An identifer can be:
// object, sturct/union/enum tag, typedef name, function, label.
Identifier* ToTypeName() {
// A typename has no linkage
// And a function has external or internal linkage
if (ToObject() || ToEnumerator() || linkage_ != L_NONE)
return nullptr;
return this;
}
virtual const std::string Name() const { return tok_->str_; }
enum Linkage Linkage() const { return linkage_; }
void SetLinkage(enum Linkage linkage) { linkage_ = linkage; }
virtual void TypeChecking() {}
protected:
Identifier(const Token* tok, QualType type, enum Linkage linkage, const AttrList& attrList={})
: Expr(tok, type), linkage_(linkage), attrList_(attrList) {}
// An identifier has property linkage
enum Linkage linkage_;
AttrList attrList_;
};
class Enumerator: public Identifier {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static Enumerator* New(const Token* tok, int val);
virtual ~Enumerator() {}
virtual void Accept(Visitor* v);
virtual Enumerator* ToEnumerator() { return this; }
int Val() const { return cons_->IVal(); }
protected:
Enumerator(const Token* tok, int val)
: Identifier(tok, ArithmType::New(T_INT), L_NONE),
cons_(Constant::New(tok, T_INT, (long)val)) {}
Constant* cons_;
};
class Object : public Identifier {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
friend class LValAssigner;
public:
static Object* New(const Token* tok,
QualType type,
int storage=0,
enum Linkage linkage=L_NONE,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0,
const AttrList& attrList={});
static Object* NewAnony(const Token* tok,
QualType type,
int storage=0,
enum Linkage linkage=L_NONE,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0,
const AttrList& attrList={});
~Object() {}
virtual void Accept(Visitor* v);
virtual Object* ToObject() { return this; }
virtual bool IsLVal() {
// TODO(wgtdkp): not all object is lval?
return true;
}
bool IsStatic() const {
return (Storage() & S_STATIC) || (Linkage() != L_NONE);
}
int Storage() const { return storage_; }
void SetStorage(int storage) { storage_ = storage; }
int Align() const { return align_; }
void SetAlign(int align) {
assert(align > 0);
// Allowing reduce alignment to implement __attribute__((packed))
//if (align < align_)
// Error(this, "alignment specifier cannot reduce alignment");
align_ = align;
}
int Offset() const { return offset_; }
void SetOffset(int offset) { offset_ = offset; }
Declaration* Decl() { return decl_; }
void SetDecl(Declaration* decl) { decl_ = decl; }
const AttrList& GetAttrList() const { return attrList_; }
unsigned char BitFieldBegin() const { return bitFieldBegin_; }
unsigned char BitFieldEnd() const { return bitFieldBegin_ + bitFieldWidth_; }
unsigned char BitFieldWidth() const { return bitFieldWidth_; }
static unsigned long BitFieldMask(Object* bitField) {
return BitFieldMask(bitField->bitFieldBegin_, bitField->bitFieldWidth_);
}
static unsigned long BitFieldMask(unsigned char begin, unsigned char width) {
auto end = begin + width;
return ((0xFFFFFFFFFFFFFFFFUL << (64 - end)) >> (64 - width)) << begin;
}
bool HasInit() const { return decl_ && decl_->Inits().size(); }
bool Anonymous() const { return anonymous_; }
virtual const std::string Name() const { return Identifier::Name(); }
std::string Repr() const {
assert(IsStatic() || anonymous_);
if (anonymous_)
return "anonymous." + std::to_string(id_);
if (linkage_ == L_NONE)
return Name() + "." + std::to_string(id_);
return Name();
}
protected:
Object(const Token* tok,
QualType type,
int storage=0,
enum Linkage linkage=L_NONE,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0,
const AttrList& attrList={})
: Identifier(tok, type, linkage),
storage_(storage),
offset_(0),
align_(type->Align()),
decl_(nullptr),
bitFieldBegin_(bitFieldBegin),
bitFieldWidth_(bitFieldWidth),
anonymous_(false),
attrList_(attrList){}
private:
int storage_;
int offset_;
int align_;
Declaration* decl_;
unsigned char bitFieldBegin_;
// 0 means it's not a bitfield
unsigned char bitFieldWidth_;
bool anonymous_;
long id_ {0};
AttrList attrList_;
};
/*
* Declaration
*/
class FuncDef : public ExtDecl {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
using ParamList = std::vector<Object*>;
public:
static FuncDef* New(Identifier* ident, LabelStmt* retLabel);
virtual ~FuncDef() {}
virtual void Accept(Visitor* v);
::FuncType* FuncType() { return ident_->Type()->ToFunc(); }
CompoundStmt* Body() { return body_; }
void SetBody(CompoundStmt* body) { body_ = body; }
std::string Name() const { return ident_->Name(); }
enum Linkage Linkage() { return ident_->Linkage(); }
protected:
FuncDef(Identifier* ident, LabelStmt* retLabel)
: ident_(ident), retLabel_(retLabel) {}
private:
Identifier* ident_;
LabelStmt* retLabel_;
CompoundStmt* body_;
};
using ExtDeclList = std::list<ExtDecl*>;
class TranslationUnit : public ASTNode {
template<typename T> friend class Evaluator;
friend class AddrEvaluator;
friend class Generator;
public:
static TranslationUnit* New() { return new TranslationUnit();}
virtual ~TranslationUnit() {}
virtual void Accept(Visitor* v);
void Add(ExtDecl* extDecl) { extDecls_.push_back(extDecl); }
ExtDeclList& ExtDecls() { return extDecls_; }
const ExtDeclList& ExtDecls() const { return extDecls_; }
private:
TranslationUnit() {}
ExtDeclList extDecls_;
};
#endif

View File

@@ -1,167 +0,0 @@
#pragma once
#ifndef _WGTCC_CODE_GEN_H_
#define _WGTCC_CODE_GEN_H_
#include "ast.h"
#include "visitor.h"
#include <stack>
namespace triton{
namespace ir{
class value;
class module;
class type;
class context;
class builder;
class attribute;
}
}
using namespace triton;
class Parser;
struct Addr;
template<> class Evaluator<Addr>;
struct StaticInitializer;
class LValAssigner;
using TypeList = std::vector<Type*>;
using LocationList = std::vector<std::string>;
using StaticInitList = std::vector<StaticInitializer>;
// Error
inline void should_not_happen(const std::string& suffix) { throw std::runtime_error("internal compiler error: " + suffix); }
inline void error_not_implemented(const std::string& msg) { throw std::runtime_error(msg); }
class Generator: public Visitor {
friend class Evaluator<Addr>;
friend class LValAssigner;
protected:
struct scope {
std::map<std::string, ir::type*> types;
std::map<std::string, ir::value*> values;
};
void set_ret(ir::value* value);
ir::value *GenUnaryMinus(ir::value* arg);
ir::value *GenUnaryInc(UnaryOp* arg, bool is_postfix, bool is_inc);
public:
Generator(Parser* parser) : parser_(parser) {}
void Visit(ASTNode* node) { node->Accept(this); }
void VisitExpr(Expr* expr) { expr->Accept(this); }
void VisitStmt(Stmt* stmt) { stmt->Accept(this); }
// Expression
void VisitBinaryOp(BinaryOp* binaryOp);
void VisitUnaryOp(UnaryOp* unaryOp);
void VisitTransOp(TransOp* transOp);
void VisitConditionalOp(ConditionalOp* condOp);
void VisitFuncCall(FuncCall* funcCall);
void VisitObject(Object* obj);
void VisitEnumerator(Enumerator* enumer);
void VisitIdentifier(Identifier* ident);
void VisitConstant(Constant* cons);
void VisitTempVar(TempVar* tempVar);
// Statement
void VisitDeclaration(Declaration* init);
void VisitEmptyStmt(EmptyStmt* emptyStmt);
void VisitIfStmt(IfStmt* ifStmt);
void VisitForStmt(ForStmt* ifStmt);
void VisitJumpStmt(JumpStmt* jumpStmt);
void VisitReturnStmt(ReturnStmt* returnStmt);
void VisitLabelStmt(LabelStmt* labelStmt);
void VisitCompoundStmt(CompoundStmt* compoundStmt);
void VisitFuncDef(FuncDef* funcDef);
void VisitTranslationUnit(TranslationUnit* unit);
void Gen(ir::module *mod);
protected:
// Triton-IR attributes
ir::attribute GenIRAttr(ASTNode::Attr attr);
// Triton-IR metadata
void SetIRMetadata(ASTNode::Attr attr, ir::value *rhs);
// Triton-IR values
ir::value* GenAssignOp(Expr* lvalue, ir::value* rhs);
ir::value* GenBroadcastOp(ir::value* src, ir::type* dst_ty);
ir::value* GenNumcastOp(ir::value*src, ir::type* dst_ty);
ir::value* GenSemCastOp(ir::value* op, ir::type* type);
ir::value* GenBitCastOp(ir::value* src, ir::type* dst_ty);
// Triton-IR types
static ir::type* GenIRType(::Type* type, ir::context &ctx);
static ir::type* GenIRArithmType(ArithmType* type, ir::context& ctx);
static ir::type* GenIRArrayType(ArrayType* type, ir::context& ctx);
static ir::type* GenIRTileType(TileType* type, ir::context& ctx);
static ir::type* GenIRFuncType(FuncType* type, ir::context& ctx);
static ir::type* GenIRPointerType(PointerType* type, ir::context& ctx);
static ir::type* GenIRStructType(StructType* type, ir::context& ctx);
void AllocObjects(Scope* scope, const FuncDef::ParamList& params=FuncDef::ParamList());
// SSA
void pushScope();
void popScope();
private:
Parser* parser_;
ir::value* ret_;
ir::builder* bld_;
ir::context* ctx_;
ir::module* mod_;
private:
// std::stack<scope> scopes_;
LValAssigner* assign_;
};
class LValAssigner: public Visitor {
public:
LValAssigner(Generator* gen): gen_(gen) {}
// Expression
void VisitBinaryOp(BinaryOp* binaryOp);
void VisitUnaryOp(UnaryOp* unaryOp);
void VisitObject(Object* obj);
void VisitIdentifier(Identifier* ident);
void VisitConditionalOp(ConditionalOp*) { should_not_happen("conditional cannot be lvalue"); }
void VisitFuncCall(FuncCall*) { should_not_happen("funccall cannot be lvalue"); }
void VisitTransOp(TransOp*) { should_not_happen("transop cannot be lvalue"); }
void VisitEnumerator(Enumerator*) { should_not_happen("enumerator cannot be lvalue"); }
void VisitConstant(Constant*) { should_not_happen("constant cannot be lvalue"); }
void VisitTempVar(TempVar*) { should_not_happen("tempvar cannot be lvalue"); }
void VisitDeclaration(Declaration*) { should_not_happen("declaration cannot be lvalue"); }
void VisitEmptyStmt(EmptyStmt*) { should_not_happen("empty statement cannot be lvalue"); }
void VisitIfStmt(IfStmt*) { should_not_happen("if statement cannot be lvalue"); }
void VisitForStmt(ForStmt*) { should_not_happen("for statement cannot be lvalue"); }
void VisitJumpStmt(JumpStmt*) { should_not_happen("jump statement cannot be lvalue"); }
void VisitReturnStmt(ReturnStmt*) { should_not_happen("return statement cannot be lvalue"); }
void VisitLabelStmt(LabelStmt*) { should_not_happen("label statement cannot be lvalue"); }
void VisitCompoundStmt(CompoundStmt*) { should_not_happen("compound statement cannot be lvalue"); }
void VisitFuncDef(FuncDef*) { should_not_happen("function definition cannot be lvalue"); }
void VisitTranslationUnit(TranslationUnit*) { should_not_happen("translation unit cannot be lvalue"); }
ir::value* GenExpr(Expr* expr, ir::value* rhs) {
rhs_ = rhs;
expr->Accept(this);
return ret_;
}
private:
ir::value* ret_;
ir::value* rhs_;
Generator* gen_;
};
#endif

View File

@@ -1,164 +0,0 @@
#pragma once
#ifndef _WGTCC_CPP_H_
#define _WGTCC_CPP_H_
#include "scanner.h"
#include <cstdio>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <string>
class Macro;
struct CondDirective;
using MacroMap = std::map<std::string, Macro>;
using ParamList = std::list<std::string>;
using ParamMap = std::map<std::string, TokenSequence>;
using PPCondStack = std::stack<CondDirective>;
using PathList = std::list<std::string>;
class Macro {
public:
Macro(const TokenSequence& repSeq, bool preDef=false)
: funcLike_(false), variadic_(false),
preDef_(preDef), repSeq_(repSeq) {}
Macro(bool variadic, ParamList& params,
TokenSequence& repSeq, bool preDef=false)
: funcLike_(true), variadic_(variadic), preDef_(preDef),
params_(params), repSeq_(repSeq) {}
~Macro() {}
bool FuncLike() { return funcLike_; }
bool ObjLike() { return !FuncLike(); }
bool Variadic() { return variadic_; }
bool PreDef() { return preDef_; }
ParamList& Params() { return params_; }
TokenSequence RepSeq(const std::string* filename, unsigned line);
private:
bool funcLike_;
bool variadic_;
bool preDef_;
ParamList params_;
TokenSequence repSeq_;
};
struct CondDirective {
int tag_;
bool enabled_;
bool cond_;
};
class Preprocessor {
public:
Preprocessor(const std::string* str, bool isSrc = true)
: curLine_(1), lineLine_(0), curCond_(true), fName_(nullptr), fSrc_(nullptr) {
if(isSrc)
fSrc_ = str;
else
fName_ = str;
// Add predefined
Init();
}
~Preprocessor() {}
void Finalize(TokenSequence os);
void Process(TokenSequence& os);
void Expand(TokenSequence& os, TokenSequence is, bool inCond=false);
void Subst(TokenSequence& os, TokenSequence is,
bool leadingWS, const HideSet& hs, ParamMap& params);
void Glue(TokenSequence& os, TokenSequence is);
void Glue(TokenSequence& os, const Token* tok);
const Token* Stringize(TokenSequence is);
void Stringize(std::string& str, TokenSequence is);
const Token* ParseActualParam(TokenSequence& is, Macro* macro, ParamMap& paramMap);
int GetDirective(TokenSequence& is);
const Token* EvalDefOp(TokenSequence& is);
void ReplaceIdent(TokenSequence& is);
void ParseDirective(TokenSequence& os, TokenSequence& is, int directive);
void ParseIf(TokenSequence ls);
void ParseIfdef(TokenSequence ls);
void ParseIfndef(TokenSequence ls);
void ParseElif(TokenSequence ls);
void ParseElse(TokenSequence ls);
void ParseEndif(TokenSequence ls);
void ParseInclude(TokenSequence& is, TokenSequence ls);
void ParseDef(TokenSequence ls);
void ParseUndef(TokenSequence ls);
void ParseLine(TokenSequence ls);
void ParseError(TokenSequence ls);
void ParsePragma(TokenSequence ls);
void IncludeSrc(TokenSequence& is, const std::string* text, const std::string* filename);
void IncludeFile(TokenSequence& is, const std::string* filename);
bool ParseIdentList(ParamList& params, TokenSequence& is);
Macro* FindMacro(const std::string& name) {
auto res = macroMap_.find(name);
if (res == macroMap_.end())
return nullptr;
return &res->second;
}
void AddMacro(const std::string& name,
std::string* text, bool preDef=false);
void AddMacro(const std::string& name, const Macro& macro) {
auto res = macroMap_.find(name);
if (res != macroMap_.end()) {
// TODO(wgtdkp): give warning
macroMap_.erase(res);
}
macroMap_.insert(std::make_pair(name, macro));
}
void RemoveMacro(const std::string& name) {
auto res = macroMap_.find(name);
if (res == macroMap_.end())
return;
if(res->second.PreDef()) // Cannot undef predefined macro
return;
macroMap_.erase(res);
}
std::string* SearchFile(const std::string& name,
const bool libHeader,
bool next,
const std::string& curPath);
void AddSearchPath(std::string path);
void HandleTheFileMacro(TokenSequence& os, const Token* macro);
void HandleTheLineMacro(TokenSequence& os, const Token* macro);
void UpdateFirstTokenLine(TokenSequence ts);
bool NeedExpand() const {
if (ppCondStack_.empty())
return true;
auto top = ppCondStack_.top();
return top.enabled_ && top.cond_;
}
private:
void Init();
PPCondStack ppCondStack_;
unsigned curLine_;
unsigned lineLine_;
bool curCond_;
MacroMap macroMap_;
PathList searchPaths_;
const std::string* fName_;
const std::string* fSrc_;
};
#endif

View File

@@ -1,22 +0,0 @@
#pragma once
#ifndef _WGTCC_ENCODING_H_
#define _WGTCC_ENCODING_H_
#include <string>
enum class Encoding {
NONE,
CHAR16,
CHAR32,
UTF8,
WCHAR
};
void ConvertToUTF16(std::string& str);
void ConvertToUTF32(std::string& str);
void AppendUCN(std::string& str, int c);
#endif

View File

@@ -1,17 +0,0 @@
#pragma once
#ifndef _WGTCC_ERROR_H_
#define _WGTCC_ERROR_H_
struct SourceLocation;
class Token;
class Expr;
[[noreturn]] void Error(const char* format, ...);
[[noreturn]] void Error(const SourceLocation& loc, const char* format, ...);
[[noreturn]] void Error(const Token* tok, const char* format, ...);
[[noreturn]] void Error(const Expr* expr, const char* format, ...);
#endif

View File

@@ -1,130 +0,0 @@
#pragma once
#ifndef _WGTCC_EVALUATOR_H_
#define _WGTCC_EVALUATOR_H_
#include "ast.h"
#include "error.h"
#include "visitor.h"
class Expr;
template<typename T>
class Evaluator: public Visitor {
public:
Evaluator() {}
virtual ~Evaluator() {}
virtual void VisitBinaryOp(BinaryOp* binary);
virtual void VisitUnaryOp(UnaryOp* unary);
virtual void VisitConditionalOp(ConditionalOp* cond);
virtual void VisitFuncCall(FuncCall* funcCall) {
Error(funcCall, "expect constant expression");
}
virtual void VisitEnumerator(Enumerator* enumer) {
val_ = static_cast<T>(enumer->Val());
}
virtual void VisitIdentifier(Identifier* ident) {
Error(ident, "expect constant expression");
}
virtual void VisitTransOp(TransOp* trans) {
Error(trans, "expect constant expression");
}
virtual void VisitObject(Object* obj) {
Error(obj, "expect constant expression");
}
virtual void VisitConstant(Constant* cons) {
if (cons->Type()->IsFloat()) {
val_ = static_cast<T>(cons->FVal());
} else if (cons->Type()->IsInteger()) {
val_ = static_cast<T>(cons->IVal());
} else {
assert(false);
}
}
virtual void VisitTempVar(TempVar* tempVar) { assert(false); }
// We may should assert here
virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
virtual void VisitFuncDef(FuncDef* funcDef) {}
virtual void VisitTranslationUnit(TranslationUnit* unit) {}
T Eval(Expr* expr) {
expr->Accept(this);
return val_;
}
private:
T val_;
};
struct Addr {
std::string label_;
int offset_;
};
template<>
class Evaluator<Addr>: public Visitor {
public:
Evaluator<Addr>() {}
virtual ~Evaluator<Addr>() {}
virtual void VisitBinaryOp(BinaryOp* binary);
virtual void VisitUnaryOp(UnaryOp* unary);
virtual void VisitConditionalOp(ConditionalOp* cond);
virtual void VisitFuncCall(FuncCall* funcCall) {
Error(funcCall, "expect constant expression");
}
virtual void VisitTransOp(TransOp* trans) {
Error(trans, "expect constant expression");
}
virtual void VisitEnumerator(Enumerator* enumer) {
addr_.offset_ = enumer->Val();
}
virtual void VisitIdentifier(Identifier* ident) {
addr_.label_ = ident->Name();
addr_.offset_ = 0;
}
virtual void VisitObject(Object* obj) {
if (!obj->IsStatic()) {
Error(obj, "expect static object");
}
addr_.label_ = obj->Repr();
addr_.offset_ = 0;
}
virtual void VisitConstant(Constant* cons);
virtual void VisitTempVar(TempVar* tempVar) { assert(false); }
// We may should assert here
virtual void VisitDeclaration(Declaration* init) {}
virtual void VisitIfStmt(IfStmt* ifStmt) {}
virtual void VisitForStmt(ForStmt* forStmt) {}
virtual void VisitJumpStmt(JumpStmt* jumpStmt) {}
virtual void VisitReturnStmt(ReturnStmt* returnStmt) {}
virtual void VisitLabelStmt(LabelStmt* labelStmt) {}
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) {}
virtual void VisitCompoundStmt(CompoundStmt* compStmt) {}
virtual void VisitFuncDef(FuncDef* funcDef) {}
virtual void VisitTranslationUnit(TranslationUnit* unit) {}
Addr Eval(Expr* expr) {
expr->Accept(this);
return addr_;
}
private:
Addr addr_;
};
#endif

View File

@@ -1,103 +0,0 @@
#pragma once
#ifndef _WGTCC_MEM_POOL_H_
#define _WGTCC_MEM_POOL_H_
#include <cstddef>
#include <vector>
class MemPool {
public:
MemPool(): allocated_(0) {}
virtual ~MemPool() {}
MemPool(const MemPool& other) = delete;
MemPool& operator=(const MemPool& other) = delete;
virtual void* Alloc() = 0;
virtual void Free(void* addr) = 0;
virtual void Clear() = 0;
protected:
size_t allocated_;
};
template <class T>
class MemPoolImp: public MemPool {
public:
MemPoolImp() : root_(nullptr) {}
virtual ~MemPoolImp() {}
MemPoolImp(const MemPool& other) = delete;
MemPoolImp& operator=(MemPool& other) = delete;
virtual void* Alloc();
virtual void Free(void* addr);
virtual void Clear();
private:
enum {
COUNT = (4 * 1024) / sizeof(T)
};
union Chunk {
Chunk* next_;
char mem_[sizeof(T)];
};
struct Block {
Block() {
for (size_t i = 0; i < COUNT - 1; ++i)
chunks_[i].next_ = &chunks_[i+1];
chunks_[COUNT-1].next_ = nullptr;
}
Chunk chunks_[COUNT];
};
std::vector<Block*> blocks_;
Chunk* root_;
};
template <class T>
void* MemPoolImp<T>::Alloc() {
if (nullptr == root_) { // 空间不够,需要分配空间
auto block = new Block();
root_ = block->chunks_;
// 如果blocks实现为std::list, 那么push_back实际的overhead更大
// 这也表明,即使我们不需要随机访问功能(那么std::vector的拷贝是一种overhead)
// 仍然倾向于使用std::vector
// 当然std::vector的指数级capacity增长会造成内存浪费。
blocks_.push_back(block);
}
auto ret = root_;
root_ = root_->next_;
++allocated_;
return ret;
}
template <class T>
void MemPoolImp<T>::Free(void* addr) {
if (nullptr == addr)
return;
auto chunk = static_cast<Chunk*>(addr);
chunk->next_ = root_;
root_ = chunk;
--allocated_;
}
template <class T>
void MemPoolImp<T>::Clear() {
for (auto block: blocks_)
delete block;
blocks_.resize(0);
root_ = nullptr;
allocated_ = 0;
}
#endif

View File

@@ -1,260 +0,0 @@
#pragma once
#ifndef _PARSER_H_
#define _PARSER_H_
#include "ast.h"
#include "encoding.h"
#include "error.h"
#include "mem_pool.h"
#include "scope.h"
#include "token.h"
#include <cassert>
#include <memory>
#include <stack>
class Preprocessor;
struct DeclInfo {
DeclInfo(const Token* _tok,
QualType _type,
ASTNode::AttrList _attrs = {})
: tok(_tok), type(_type), attrs(_attrs) {}
const Token* tok;
QualType type;
ASTNode::AttrList attrs;
};
class Parser {
using LiteralList = std::vector<Constant*>;
using StaticObjectList = std::vector<Object*>;
using CaseLabelList = std::vector<std::pair<Constant*, LabelStmt*>>;
using LabelJumpList = std::list<std::pair<const Token*, JumpStmt*>>;
using LabelMap = std::map<std::string, LabelStmt*>;
friend class Generator;
public:
explicit Parser(TokenSequence& ts)
: unit_(TranslationUnit::New()),
ts_(ts),
externalSymbols_(new Scope(nullptr, S_BLOCK)),
errTok_(nullptr),
curScope_(new Scope(nullptr, S_FILE)),
curFunc_(nullptr),
breakDest_(nullptr),
continueDest_(nullptr),
caseLabels_(nullptr),
defaultLabel_(nullptr) {
ts_.SetParser(this);
}
~Parser() {}
Constant* ParseConstant(const Token* tok);
Constant* ParseFloat(const Token* tok);
Constant* ParseInteger(const Token* tok);
Constant* ParseCharacter(const Token* tok);
Encoding ParseLiteral(std::string& str, const Token* tok);
Constant* ConcatLiterals(const Token* tok);
Expr* ParseGeneric();
void Parse();
void ParseTranslationUnit();
FuncDef* ParseFuncDef(Identifier* ident);
// Expressions
Expr* ParseExpr();
Expr* ParsePrimaryExpr();
QualType TryCompoundLiteral();
Object* ParseCompoundLiteral(QualType type);
Expr* ParsePostfixExpr();
Expr* ParsePostfixExprTail(Expr* primExpr);
Expr* ParseSubScripting(Expr* pointer);
BinaryOp* ParseMemberRef(const Token* tok, int op, Expr* lhs);
UnaryOp* ParsePostfixIncDec(const Token* tok, Expr* operand);
FuncCall* ParseFuncCall(Expr* caller);
Expr* ParseUnaryExpr();
Constant* ParseSizeof();
Constant* ParseAlignof();
UnaryOp* ParsePrefixIncDec(const Token* tok);
UnaryOp* ParseUnaryIntrinsicOp(int op);
UnaryOp* ParseUnaryOp(const Token* tok, int op);
Expr* ParseDerefOp(const Token* tok);
QualType ParseTypeName();
Expr* ParseCastExpr();
Expr* ParseRangeExpr();
Expr* ParseMatmulExpr();
Expr* ParseMultiplicativeExpr();
Expr* ParseAdditiveExpr();
Expr* ParseShiftExpr();
Expr* ParseRelationalExpr();
Expr* ParseEqualityExpr();
Expr* ParseBitiwiseAndExpr();
Expr* ParseBitwiseXorExpr();
Expr* ParseBitwiseOrExpr();
Expr* ParseLogicalAndExpr();
Expr* ParseLogicalOrExpr();
Expr* ParseConditionalExpr();
Expr* ParseCommaExpr();
Expr* ParseAssignExpr();
// Declarations
CompoundStmt* ParseDecl();
void ParseStaticAssert();
QualType ParseDeclSpec(int* storageSpec, int* funcSpec, int* alignSpec);
QualType ParseSpecQual();
int ParseAlignas();
Type* ParseStructUnionSpec(bool isStruct);
StructType* ParseStructUnionDecl(StructType* type);
void ParseBitField(StructType* structType, const Token* tok, QualType type);
Type* ParseEnumSpec();
Type* ParseEnumerator(ArithmType* type);
int ParseQual();
QualType ParsePointer(QualType typePointedTo);
DeclInfo ParseDeclarator(QualType type);
QualType ParseArrayFuncDeclarator(const Token* ident, QualType base);
int ParseArrayLength();
TileType::ShapeInt ParseTileShape();
bool ParseParamList(FuncType::ParamList& params);
Object* ParseParamDecl();
QualType ParseAbstractDeclarator(QualType type);
Identifier* ParseDirectDeclarator(QualType type,
int storageSpec,
int funcSpec,
int align);
// Initializer
void ParseInitializer(Declaration* decl,
QualType type,
int offset,
bool designated=false,
bool forceBrace=false,
unsigned char bitFieldBegin=0,
unsigned char bitFieldWidth=0);
void ParseArrayInitializer(Declaration* decl,
ArrayType* type,
int offset,
bool designated);
StructType::Iterator ParseStructDesignator(StructType* type,
const std::string& name);
void ParseStructInitializer(Declaration* decl,
StructType* type,
int offset,
bool designated);
bool ParseLiteralInitializer(Declaration* init,
ArrayType* type,
int offset);
Declaration* ParseInitDeclarator(Identifier* ident);
Declaration* ParseInitDeclaratorSub(Object* obj);
// Statements
Stmt* ParseStmt();
CompoundStmt* ParseCompoundStmt(FuncType* funcType=nullptr);
IfStmt* ParseIfStmt();
CompoundStmt* ParseSwitchStmt();
CompoundStmt* ParseWhileStmt();
CompoundStmt* ParseDoStmt();
ForStmt *ParseForStmt();
JumpStmt* ParseGotoStmt();
JumpStmt* ParseContinueStmt();
JumpStmt* ParseBreakStmt();
ReturnStmt* ParseReturnStmt();
CompoundStmt* ParseLabelStmt(const Token* label);
CompoundStmt* ParseCaseStmt();
CompoundStmt* ParseDefaultStmt();
Identifier* ProcessDeclarator(const Token* tok,
QualType type, const ASTNode::AttrList &attrs,
int storageSpec,
int funcSpec,
int align);
// GNU extensions
ASTNode::AttrList TryAttributeSpecList();
void ParseAttributeSpec(ASTNode::AttrList &attrList);
ASTNode::Attr ParseAttribute();
bool IsTypeName(const Token* tok) const{
if (tok->IsTypeSpecQual())
return true;
if (tok->IsIdentifier()) {
auto ident = curScope_->Find(tok);
if (ident && ident->ToTypeName())
return true;
}
return false;
}
bool IsType(const Token* tok) const{
if (tok->IsDecl())
return true;
if (tok->IsIdentifier()) {
auto ident = curScope_->Find(tok);
return (ident && ident->ToTypeName());
}
return false;
}
void EnsureInteger(Expr* expr) {
if (!expr->Type()->IsInteger()) {
Error(expr, "expect integer expression");
}
}
void EnterBlock(FuncType* funcType=nullptr);
void ExitBlock() { curScope_ = curScope_->Parent(); }
void EnterProto() { curScope_ = new Scope(curScope_, S_PROTO); }
void ExitProto() { curScope_ = curScope_->Parent(); }
FuncDef* EnterFunc(Identifier* ident);
void ExitFunc();
LabelStmt* FindLabel(const std::string& label) {
auto ret = curLabels_.find(label);
if (curLabels_.end() == ret)
return nullptr;
return ret->second;
}
void AddLabel(const std::string& label, LabelStmt* labelStmt) {
assert(nullptr == FindLabel(label));
curLabels_[label] = labelStmt;
}
TranslationUnit* Unit() { return unit_; }
FuncDef* CurFunc() { return curFunc_; }
const TokenSequence& ts() const { return ts_; }
protected:
static bool IsBuiltin(FuncType* type);
static bool IsBuiltin(const std::string& name);
static Identifier* GetBuiltin(const Token* tok);
static void DefineBuiltins();
static FuncType* vaStartType_;
static FuncType* vaArgType_;
// The root of the AST
TranslationUnit* unit_;
TokenSequence& ts_;
// It is not the real scope,
// It contains all external symbols(resolved and not resolved)
Scope* externalSymbols_;
const Token* errTok_;
Scope* curScope_;
FuncDef* curFunc_;
LabelMap curLabels_;
LabelJumpList unresolvedJumps_;
LabelStmt* breakDest_;
LabelStmt* continueDest_;
CaseLabelList* caseLabels_;
LabelStmt* defaultLabel_;
};
#endif

View File

@@ -1,86 +0,0 @@
#pragma once
#ifndef _WGTCC_SCANNER_H_
#define _WGTCC_SCANNER_H_
#include "error.h"
#include "encoding.h"
#include "token.h"
#include <string>
#include <cassert>
class Scanner {
public:
explicit Scanner(const Token* tok)
: Scanner(&tok->str_, tok->loc_) {}
Scanner(const std::string* text, const SourceLocation& loc)
: Scanner(text, loc.filename_, loc.line_, loc.column_) {}
explicit Scanner(const std::string* text,
const std::string* filename=nullptr,
unsigned line=1, unsigned column=1)
: text_(text), tok_(Token::END) {
// TODO(wgtdkp): initialization
p_ = &(*text_)[0];
loc_ = {filename, p_, line, 1};
}
virtual ~Scanner() {}
Scanner(const Scanner& other) = delete;
Scanner& operator=(const Scanner& other) = delete;
// Scan plain text and generate tokens in ts.
// The param 'ts' need not be empty, if so, the tokens
// are inserted at the *header* of 'ts'.
// The param 'ws' tells if there is leading white space
// before this token, it is only SkipComment() that will
// set this param.
Token* Scan(bool ws=false);
void Tokenize(TokenSequence& ts);
static std::string ScanHeadName(const Token* lhs, const Token* rhs);
Encoding ScanCharacter(int& val);
Encoding ScanLiteral(std::string& val);
std::string ScanIdentifier();
private:
Token* SkipIdentifier();
Token* SkipNumber();
Token* SkipLiteral();
Token* SkipCharacter();
Token* MakeToken(int tag);
Token* MakeNewLine();
Encoding ScanEncoding(int c);
int ScanEscaped();
int ScanHexEscaped();
int ScanOctEscaped(int c);
int ScanUCN(int len);
void SkipWhiteSpace();
void SkipComment();
bool IsUCN(int c) { return c == '\\' && (Test('u') || Test('U')); }
bool IsOctal(int c) { return '0' <= c && c <= '7'; }
int XDigit(int c);
bool Empty() const { return *p_ == 0; }
int Peek();
bool Test(int c) { return Peek() == c; };
int Next();
void PutBack();
bool Try(int c) {
if (Peek() == c) {
Next();
return true;
}
return false;
};
void Mark() { tok_.loc_ = loc_; };
const std::string* text_;
SourceLocation loc_;
Token tok_;
const char* p_;
};
std::string* ReadFile(const std::string& filename);
#endif

View File

@@ -1,72 +0,0 @@
#pragma once
#ifndef _WGTCC_SCOPE_H_
#define _WGTCC_SCOPE_H_
#include <iostream>
#include <map>
#include <string>
#include <vector>
class Identifier;
class Token;
enum ScopeType {
S_FILE,
S_PROTO,
S_BLOCK,
S_FUNC,
};
class Scope {
friend class StructType;
using TagList = std::vector<Identifier*>;
using IdentMap = std::map<std::string, Identifier*>;
public:
explicit Scope(Scope* parent, enum ScopeType type)
: parent_(parent), type_(type) {}
~Scope() {}
Scope* Parent() { return parent_; }
void SetParent(Scope* parent) { parent_ = parent; }
enum ScopeType Type() const { return type_; }
Identifier* Find(const Token* tok);
Identifier* FindInCurScope(const Token* tok);
Identifier* FindTag(const Token* tok);
Identifier* FindTagInCurScope(const Token* tok);
TagList AllTagsInCurScope() const;
void Insert(Identifier* ident);
void Insert(const std::string& name, Identifier* ident);
void InsertTag(Identifier* ident);
void Print();
bool operator==(const Scope& other) const { return type_ == other.type_; }
IdentMap::iterator begin() { return identMap_.begin(); }
IdentMap::iterator end() { return identMap_.end(); }
size_t size() const { return identMap_.size(); }
private:
Identifier* Find(const std::string& name);
Identifier* FindInCurScope(const std::string& name);
Identifier* FindTag(const std::string& name);
Identifier* FindTagInCurScope(const std::string& name);
std::string TagName(const std::string& name) {
return name + "@:tag";
}
static bool IsTagName(const std::string& name) {
return name.size() > 5 && name[name.size() - 5] == '@';
}
const Scope& operator=(const Scope& other);
Scope(const Scope& scope);
Scope* parent_;
enum ScopeType type_;
IdentMap identMap_;
};
#endif

View File

@@ -1,434 +0,0 @@
#pragma once
#ifndef _WGTCC_TOKEN_H_
#define _WGTCC_TOKEN_H_
#include "error.h"
#include <cassert>
#include <cstring>
#include <iostream>
#include <list>
#include <set>
#include <string>
#include <unordered_map>
class Generator;
class Parser;
class Scanner;
class Token;
class TokenSequence;
using HideSet = std::set<std::string>;
using TokenList = std::list<const Token*>;
struct SourceLocation {
const std::string* filename_;
const char* lineBegin_;
unsigned line_;
unsigned column_;
const char* Begin() const {
return lineBegin_ + column_ - 1;
}
};
class Token {
friend class Scanner;
public:
enum {
// Punctuators
LPAR = '(',
RPAR = ')',
LSQB = '[',
RSQB = ']',
COLON = ':',
COMMA = ',',
SEMI = ';',
ADD = '+',
SUB = '-',
MUL = '*',
DIV = '/',
OR = '|',
AND = '&',
XOR = '^',
LESS = '<',
GREATER = '>',
EQUAL = '=',
DOT = '.',
MOD = '%',
LBRACE = '{',
RBRACE = '}',
TILDE = '~',
NOT = '!',
COND = '?',
SHARP = '#',
MATMUL = '@',
NEW_LINE = '\n',
DSHARP = 128, // '##'
PTR,
INC,
DEC,
LEFT,
RIGHT,
LE,
GE,
EQ,
NE,
LOGICAL_AND,
LOGICAL_OR,
MUL_ASSIGN,
DIV_ASSIGN,
MOD_ASSIGN,
ADD_ASSIGN,
SUB_ASSIGN,
LEFT_ASSIGN,
RIGHT_ASSIGN,
AND_ASSIGN,
XOR_ASSIGN,
OR_ASSIGN,
ELLIPSIS,
MASKED_DEREF,
// Punctuators end
// KEYWORD BEGIN
// TYPE QUALIFIER BEGIN
CONST,
RESTRICT,
VOLATILE,
ATOMIC,
// TYPE QUALIFIER END
// TYPE SPECIFIER BEGIN
VOID,
CHAR,
SHORT,
INT,
LONG,
HALF,
FLOAT,
DOUBLE,
SIGNED,
UNSIGNED,
BOOL, // _Bool
COMPLEX, // _Complex
STRUCT,
UNION,
ENUM,
// TYPE SPECIFIER END
ATTRIBUTE, // GNU extension __attribute__
// FUNCTION SPECIFIER BEGIN
INLINE,
NORETURN, // _Noreturn
// FUNCTION SPECIFIER END
// TILE ARITHMETICS BEGIN
NEWAXIS,
MAX,
MIN,
// TILE ARITHMETICS END
ALIGNAS, // _Alignas
// For syntactic convenience
STATIC_ASSERT, // _Static_assert
// STORAGE CLASS SPECIFIER BEGIN
TYPEDEF,
EXTERN,
STATIC,
THREAD, // _Thread_local
AUTO,
GLOBAL,
CMEM, // constant memory
// STORAGE CLASS SPECIFIER END
BREAK,
CASE,
CONTINUE,
DEFAULT,
DO,
ELSE,
FOR,
GOTO,
IF,
RETURN,
SIZEOF,
SWITCH,
WHILE,
ALIGNOF, // _Alignof
GENERIC, // _Generic
IMAGINARY, // _Imaginary
// function keywords
BITCAST,
EXP,
LOG,
SQRTF,
// KEYWORD END
IDENTIFIER,
CONSTANT,
I_CONSTANT,
C_CONSTANT,
F_CONSTANT,
LITERAL,
// For the parser, a identifier is a typedef name or user defined type
POSTFIX_INC,
POSTFIX_DEC,
PREFIX_INC,
PREFIX_DEC,
ADDR, // '&'
DEREF, // '*'
PLUS,
MINUS,
CAST,
REDUCE,
// For preprocessor
PP_IF,
PP_IFDEF,
PP_IFNDEF,
PP_ELIF,
PP_ELSE,
PP_ENDIF,
PP_INCLUDE,
PP_DEFINE,
PP_UNDEF,
PP_LINE,
PP_ERROR,
PP_PRAGMA,
PP_NONE,
PP_EMPTY,
IGNORE,
INVALID,
END,
NOTOK = -1,
};
static Token* New(int tag);
static Token* New(const Token& other);
static Token* New(int tag,
const SourceLocation& loc,
const std::string& str,
bool ws=false);
Token& operator=(const Token& other) {
tag_ = other.tag_;
ws_ = other.ws_;
loc_ = other.loc_;
str_ = other.str_;
hs_ = other.hs_ ? new HideSet(*other.hs_): nullptr;
return *this;
}
virtual ~Token() {}
// Token::NOTOK represents not a kw.
static int KeyWordTag(const std::string& key) {
auto kwIter = kwTypeMap_.find(key);
if (kwTypeMap_.end() == kwIter)
return Token::NOTOK; // Not a key word type
return kwIter->second;
}
static bool IsKeyWord(const std::string& name);
static bool IsKeyWord(int tag) { return CONST <= tag && tag < IDENTIFIER; }
bool IsKeyWord() const { return IsKeyWord(tag_); }
bool IsPunctuator() const { return 0 <= tag_ && tag_ <= ELLIPSIS; }
bool IsLiteral() const { return tag_ == LITERAL; }
bool IsConstant() const { return CONSTANT <= tag_ && tag_ <= F_CONSTANT; }
bool IsIdentifier() const { return IDENTIFIER == tag_; }
bool IsEOF() const { return tag_ == Token::END; }
bool IsTypeSpecQual() const { return CONST <= tag_ && tag_ <= ENUM; }
bool IsDecl() const { return CONST <= tag_ && tag_ <= GLOBAL; }
static const char* Lexeme(int tag) {
auto iter = tagLexemeMap_.find(tag);
if (iter == tagLexemeMap_.end())
return nullptr;
return iter->second;
}
int tag_;
// 'ws_' standards for weither there is preceding white space
// This is to simplify the '#' operator(stringize) in macro expansion
bool ws_ { false };
SourceLocation loc_;
std::string str_;
HideSet* hs_ { nullptr };
private:
explicit Token(int tag): tag_(tag) {}
Token(int tag, const SourceLocation& loc,
const std::string& str, bool ws=false)
: tag_(tag), ws_(ws), loc_(loc), str_(str) {}
Token(const Token& other) {
*this = other;
}
static const std::unordered_map<std::string, int> kwTypeMap_;
static const std::unordered_map<int, const char*> tagLexemeMap_;
};
class TokenSequence {
friend class Preprocessor;
public:
TokenSequence(): tokList_(new TokenList()),
begin_(tokList_->begin()), end_(tokList_->end()) {}
explicit TokenSequence(Token* tok) {
TokenSequence();
InsertBack(tok);
}
explicit TokenSequence(TokenList* tokList)
: tokList_(tokList),
begin_(tokList->begin()),
end_(tokList->end()) {}
TokenSequence(TokenList* tokList,
TokenList::iterator begin,
TokenList::iterator end)
: tokList_(tokList), begin_(begin), end_(end) {}
~TokenSequence() {}
TokenSequence(const TokenSequence& other) { *this = other; }
const TokenSequence& operator=(const TokenSequence& other) {
tokList_ = other.tokList_;
begin_ = other.begin_;
end_ = other.end_;
return *this;
}
void Copy(const TokenSequence& other) {
tokList_ = new TokenList(other.begin_, other.end_);
begin_ = tokList_->begin();
end_ = tokList_->end();
for (auto iter = begin_; iter != end_; ++iter)
*iter = Token::New(**iter);
}
void UpdateHeadLocation(const SourceLocation& loc) {
assert(!Empty());
auto tok = const_cast<Token*>(Peek());
tok->loc_ = loc;
}
void FinalizeSubst(bool leadingWS, const HideSet& hs) {
auto ts = *this;
while (!ts.Empty()) {
auto tok = const_cast<Token*>(ts.Next());
if (!tok->hs_)
tok->hs_ = new HideSet(hs);
else
tok->hs_->insert(hs.begin(), hs.end());
}
// Even if the token sequence is empty
const_cast<Token*>(Peek())->ws_ = leadingWS;
}
const Token* Expect(int expect);
bool Try(int tag) {
if (Peek()->tag_ == tag) {
Next();
return true;
}
return false;
}
bool Test(int tag) { return Peek()->tag_ == tag; }
const Token* Next() {
auto ret = Peek();
if (!ret->IsEOF()) {
++begin_;
Peek(); // May skip newline token, but why ?
} else {
++exceed_end;
}
return ret;
}
void PutBack() {
assert(begin_ != tokList_->begin());
if (exceed_end > 0) {
--exceed_end;
} else {
--begin_;
if ((*begin_)->tag_ == Token::NEW_LINE)
PutBack();
}
}
const Token* Peek() const;
const Token* Peek2() {
if (Empty())
return Peek(); // Return the Token::END
Next();
auto ret = Peek();
PutBack();
return ret;
}
const Token* Back() const {
auto back = end_;
return *--back;
}
void PopBack() {
assert(!Empty());
assert(end_ == tokList_->end());
auto size_eq1 = tokList_->back() == *begin_;
tokList_->pop_back();
end_ = tokList_->end();
if (size_eq1)
begin_ = end_;
}
TokenList::iterator Mark() { return begin_; }
void ResetTo(TokenList::iterator mark) { begin_ = mark; }
bool Empty() const { return Peek()->tag_ == Token::END; }
void InsertBack(TokenSequence& ts) {
auto pos = tokList_->insert(end_, ts.begin_, ts.end_);
if (begin_ == end_) {
begin_ = pos;
}
}
void InsertBack(const Token* tok) {
auto pos = tokList_->insert(end_, tok);
if (begin_ == end_) {
begin_ = pos;
}
}
// If there is preceding newline
void InsertFront(TokenSequence& ts) {
auto pos = GetInsertFrontPos();
begin_ = tokList_->insert(pos, ts.begin_, ts.end_);
}
void InsertFront(const Token* tok) {
auto pos = GetInsertFrontPos();
begin_ = tokList_->insert(pos, tok);
}
bool IsBeginOfLine() const;
TokenSequence GetLine();
void SetParser(Parser* parser) { parser_ = parser; }
void Print(FILE* fp=stdout) const;
void Print(std::string *str) const;
private:
// Find a insert position with no preceding newline
TokenList::iterator GetInsertFrontPos() {
auto pos = begin_;
if (pos == tokList_->begin())
return pos;
--pos;
while (pos != tokList_->begin() && (*pos)->tag_ == Token::NEW_LINE)
--pos;
return ++pos;
}
TokenList* tokList_;
mutable TokenList::iterator begin_;
TokenList::iterator end_;
Parser* parser_ {nullptr};
int exceed_end {0};
};
#endif

View File

@@ -1,453 +0,0 @@
#pragma once
#ifndef _WGTCC_TYPE_H_
#define _WGTCC_TYPE_H_
#include "mem_pool.h"
#include "scope.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <list>
class Scope;
class Token;
class Expr;
class Type;
class QualType;
class VoidType;
class Identifier;
class Object;
class Constant;
class ArithmType;
class DerivedType;
class ArrayType;
class TileType;
class FuncType;
class PointerType;
class StructType;
class EnumType;
enum {
// Storage class specifiers
S_TYPEDEF = 0x01,
S_EXTERN = 0x02,
S_STATIC = 0x04,
S_THREAD = 0x08,
S_CONSTANT = 0x10,
S_GLOBAL = 0x20,
// Type specifier
T_SIGNED = 0x40,
T_UNSIGNED = 0x80,
T_CHAR = 0x100,
T_SHORT = 0x200,
T_INT = 0x400,
T_LONG = 0x800,
T_VOID = 0x1000,
T_HALF = 0x2000,
T_FLOAT = 0x4000,
T_DOUBLE = 0x8000,
T_BOOL = 0x10000,
T_COMPLEX = 0x20000,
// T_ATOMIC = 0x40000,
T_STRUCT_UNION = 0x80000,
T_ENUM = 0x100000,
T_TYPEDEF_NAME = 0x200000,
T_LLONG = 0x4000000,
// Function specifier
F_INLINE = 0x8000000,
F_NORETURN = 0x10000000,
};
struct Qualifier {
enum {
CONST = 0x01,
RESTRICT = 0x02,
VOLATILE = 0x04,
CMEM = 0x08,
MASK = CONST | RESTRICT | VOLATILE | CMEM
};
};
class QualType {
public:
QualType(Type* ptr, int quals=0x00)
: ptr_(reinterpret_cast<intptr_t>(ptr)) {
assert((quals & ~Qualifier::MASK) == 0);
ptr_ |= quals;
}
operator bool() const { return !IsNull(); }
bool IsNull() const { return GetPtr() == nullptr; }
const Type* GetPtr() const {
return reinterpret_cast<const Type*>(ptr_ & ~Qualifier::MASK);
}
Type* GetPtr() {
return reinterpret_cast<Type*>(ptr_ & ~Qualifier::MASK);
}
Type& operator*() { return *GetPtr(); }
const Type& operator*() const { return *GetPtr(); }
Type* operator->() { return GetPtr(); }
const Type* operator->() const { return GetPtr(); }
// Indicate whether the specified types are identical(exclude qualifiers).
friend bool operator==(QualType lhs, QualType rhs) {
return lhs.operator->() == rhs.operator->();
}
friend bool operator!=(QualType lhs, QualType rhs) {
return !(lhs == rhs);
}
int Qual() const { return ptr_ & 0x07; }
bool IsConstQualified() const { return ptr_ & Qualifier::CONST; }
bool IsRestrictQualified() const { return ptr_ & Qualifier::RESTRICT; }
bool IsVolatileQualified() const { return ptr_ & Qualifier::VOLATILE; }
bool IsConstantQualified() const { return ptr_ & Qualifier::CMEM; }
private:
intptr_t ptr_;
};
class Type {
public:
static const int intWidth_ = 4;
static const int machineWidth_ = 8;
bool operator!=(const Type& other) const = delete;
bool operator==(const Type& other) const = delete;
virtual bool Compatible(const Type& other) const {
return complete_ == other.complete_;
}
virtual ~Type() {}
// For Debugging
virtual std::string Str() const = 0;
virtual int Width() const = 0;
virtual int Align() const { return Width(); }
static int MakeAlign(int offset, int align) {
if ((offset % align) == 0)
return offset;
if (offset >= 0)
return offset + align - (offset % align);
else
return offset - align - (offset % align);
}
static QualType MayCast(QualType type, bool inProtoScope=false);
bool Complete() const { return complete_; }
void SetComplete(bool complete) const { complete_ = complete; }
bool IsReal() const { return IsInteger() || IsFloat(); };
virtual bool IsScalar() const { return false; }
virtual bool IsFloat() const { return false; }
virtual bool IsInteger() const { return false; }
virtual bool IsBool() const { return false; }
virtual bool IsVoidPointer() const { return false; }
virtual bool IsUnsigned() const { return false; }
virtual bool IsTile() const { return ToTile() != nullptr; }
const Type* ScalarType() const;
Type* ScalarType();
virtual VoidType* ToVoid() { return nullptr; }
virtual const VoidType* ToVoid() const { return nullptr; }
virtual ArithmType* ToArithm() { return nullptr; }
virtual const ArithmType* ToArithm() const { return nullptr; }
virtual ArrayType* ToArray() { return nullptr; }
virtual const ArrayType* ToArray() const { return nullptr; }
virtual TileType* ToTile() { return nullptr; }
virtual const TileType* ToTile() const { return nullptr; }
virtual FuncType* ToFunc() { return nullptr; }
virtual const FuncType* ToFunc() const { return nullptr; }
virtual PointerType* ToPointer() { return nullptr; }
virtual const PointerType* ToPointer() const { return nullptr; }
virtual DerivedType* ToDerived() { return nullptr; }
virtual const DerivedType* ToDerived() const { return nullptr; }
virtual StructType* ToStruct() { return nullptr; }
virtual const StructType* ToStruct() const { return nullptr; }
protected:
Type(MemPool* pool, bool complete)
: complete_(complete), pool_(pool) {}
mutable bool complete_;
MemPool* pool_;
};
class VoidType : public Type {
public:
static VoidType* New();
virtual ~VoidType() {}
virtual VoidType* ToVoid() { return this; }
virtual const VoidType* ToVoid() const { return this; }
virtual bool Compatible(const Type& other) const { return other.ToVoid(); }
virtual int Width() const {
// Non-standard GNU extension
return 1;
}
virtual std::string Str() const { return "void:1"; }
protected:
explicit VoidType(MemPool* pool): Type(pool, false) {}
};
class ArithmType : public Type {
public:
static ArithmType* New(int typeSpec);
virtual ~ArithmType() {}
virtual ArithmType* ToArithm() { return this; }
virtual const ArithmType* ToArithm() const { return this; }
virtual bool Compatible(const Type& other) const {
// C11 6.2.7 [1]: Two types have compatible type if their types are the same
// But I would to loose this constraints: integer and pointer are compatible
// if (IsInteger() && other.ToPointer())
// return other.Compatible(*this);
return this == &other;
}
virtual int Width() const;
virtual std::string Str() const;
virtual bool IsScalar() const { return true; }
virtual bool IsInteger() const { return !IsFloat() && !IsComplex(); }
virtual bool IsUnsigned() const { return tag_ & T_UNSIGNED; }
virtual bool IsFloat() const {
return (tag_ & T_HALF) || (tag_ & T_FLOAT) || (tag_ & T_DOUBLE);
}
virtual bool IsBool() const { return tag_ & T_BOOL; }
bool IsComplex() const { return tag_ & T_COMPLEX; }
int Tag() const { return tag_; }
int Rank() const;
static ArithmType* IntegerPromote(ArithmType* type) {
assert(type->IsInteger());
if (type->Rank() < ArithmType::New(T_INT)->Rank())
return ArithmType::New(T_INT);
return type;
}
static ArithmType* MaxType(ArithmType* lhsType,
ArithmType* rhsType);
protected:
explicit ArithmType(MemPool* pool, int spec)
: Type(pool, true), tag_(Spec2Tag(spec)) {}
private:
static int Spec2Tag(int spec);
int tag_;
};
class DerivedType : public Type {
public:
QualType Derived() const { return derived_; }
void SetDerived(QualType derived) { derived_ = derived; }
virtual DerivedType* ToDerived() { return this; }
virtual const DerivedType* ToDerived() const { return this; }
protected:
DerivedType(MemPool* pool, QualType derived)
: Type(pool, true), derived_(derived) {}
QualType derived_;
};
class PointerType : public DerivedType {
public:
static PointerType* New(QualType derived);
virtual ~PointerType() {}
virtual PointerType* ToPointer() { return this; }
virtual const PointerType* ToPointer() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return 8; }
virtual bool IsScalar() const { return true; }
virtual bool IsVoidPointer() const { return derived_->ToVoid(); }
virtual std::string Str() const {
return derived_->Str() + "*:" + std::to_string(Width());
}
protected:
PointerType(MemPool* pool, QualType derived): DerivedType(pool, derived) {}
};
class ArrayType : public DerivedType {
public:
static ArrayType* New(int len, QualType eleType);
static ArrayType* New(Expr* expr, QualType eleType);
virtual ~ArrayType() { /*delete derived_;*/ }
virtual ArrayType* ToArray() { return this; }
virtual const ArrayType* ToArray() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const {
return Complete() ? (derived_->Width() * len_): 0;
}
virtual int Align() const { return derived_->Align(); }
virtual std::string Str() const {
return derived_->Str() + "[]:" + std::to_string(Width());
}
int GetElementOffset(int idx) const { return derived_->Width() * idx; }
int Len() const { return len_; }
void SetLen(int len) { len_ = len; }
bool Variadic() const { return lenExpr_ != nullptr; }
protected:
ArrayType(MemPool* pool, Expr* lenExpr, QualType derived)
: DerivedType(pool, derived),
lenExpr_(lenExpr), len_(0) {
SetComplete(false);
}
ArrayType(MemPool* pool, int len, QualType derived)
: DerivedType(pool, derived),
lenExpr_(nullptr), len_(len) {
SetComplete(len_ >= 0);
}
const Expr* lenExpr_;
int len_;
};
class TileType : public DerivedType {
public:
using ShapeExpr = std::vector<Expr*>;
using ShapeInt = std::vector<int>;
public:
static TileType* New(const ShapeInt& shape, QualType eleType);
virtual ~TileType() { }
virtual TileType* ToTile() { return this; }
virtual const TileType* ToTile() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return Complete() ? derived_->Width()*NumEle() : 0; }
virtual int Align() const { return derived_->Align(); }
virtual std::string Str() const {
return derived_->Str() + "[{}]:" + std::to_string(Width());
}
ShapeInt Shape() { return shape_; }
int NumEle() const {
int ret = 1;
for(int s: shape_)
ret *= s;
return ret;
}
bool CheckPow2NumEl() const {
int n = NumEle();
return n && !(n & (n - 1));
}
protected:
TileType(MemPool* pool, const ShapeInt& shape, QualType derived);
protected:
ShapeExpr shapeExpr_;
ShapeInt shape_;
};
class FuncType : public DerivedType {
public:
using ParamList = std::vector<Object*>;
public:
static FuncType* New(QualType derived,
int funcSpec,
bool variadic,
const ParamList& params);
virtual ~FuncType() {}
virtual FuncType* ToFunc() { return this; }
virtual const FuncType* ToFunc() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return 1; }
virtual std::string Str() const;
const ParamList& Params() const { return params_; }
void SetParams(const ParamList& params) { params_ = params; }
bool Variadic() const { return variadic_; }
bool IsInline() const { return inlineNoReturn_ & F_INLINE; }
bool IsNoReturn() const { return inlineNoReturn_ & F_NORETURN; }
protected:
FuncType(MemPool* pool, QualType derived, int inlineReturn,
bool variadic, const ParamList& params)
: DerivedType(pool, derived), inlineNoReturn_(inlineReturn),
variadic_(variadic), params_(params) {
SetComplete(false);
}
private:
int inlineNoReturn_;
bool variadic_;
ParamList params_;
};
class StructType : public Type {
public:
using MemberList = std::list<Object*>;
using Iterator = std::list<Object*>::iterator;
public:
static StructType* New(bool isStruct,
bool hasTag,
Scope* parent);
virtual ~StructType() {}
virtual StructType* ToStruct() { return this; }
virtual const StructType* ToStruct() const { return this; }
virtual bool Compatible(const Type& other) const;
virtual int Width() const { return width_; }
virtual int Align() const { return align_; }
virtual std::string Str() const;
// struct/union
void AddMember(Object* member);
void AddBitField(Object* member, int offset);
bool IsStruct() const { return isStruct_; }
Object* GetMember(const std::string& member);
Scope* MemberMap() { return memberMap_; }
MemberList& Members() { return members_; }
int Offset() const { return offset_; }
bool HasTag() const { return hasTag_; }
void MergeAnony(Object* anony);
void Finalize();
protected:
// Default is incomplete
StructType(MemPool* pool, bool isStruct, bool hasTag, Scope* parent);
StructType(const StructType& other);
private:
void CalcWidth();
bool isStruct_;
bool hasTag_;
Scope* memberMap_;
MemberList members_;
int offset_;
int width_;
int align_;
int bitFieldAlign_;
};
#endif

View File

@@ -1,56 +0,0 @@
#pragma once
#ifndef _WGTCC_VISITOR_H_
#define _WGTCC_VISITOR_H_
class BinaryOp;
class UnaryOp;
class TransOp;
class ConditionalOp;
class FuncCall;
class Identifier;
class Object;
class Enumerator;
class Constant;
class TempVar;
class Declaration;
class IfStmt;
class ForStmt;
class JumpStmt;
class ReturnStmt;
class LabelStmt;
class EmptyStmt;
class CompoundStmt;
class FuncDef;
class TranslationUnit;
class Visitor {
public:
virtual ~Visitor() {}
virtual void VisitBinaryOp(BinaryOp* binary) = 0;
virtual void VisitUnaryOp(UnaryOp* unary) = 0;
virtual void VisitTransOp(TransOp* trans) = 0;
virtual void VisitConditionalOp(ConditionalOp* cond) = 0;
virtual void VisitFuncCall(FuncCall* funcCall) = 0;
virtual void VisitEnumerator(Enumerator* enumer) = 0;
virtual void VisitIdentifier(Identifier* ident) = 0;
virtual void VisitObject(Object* obj) = 0;
virtual void VisitConstant(Constant* cons) = 0;
virtual void VisitTempVar(TempVar* tempVar) = 0;
virtual void VisitDeclaration(Declaration* init) = 0;
virtual void VisitIfStmt(IfStmt* ifStmt) = 0;
virtual void VisitForStmt(ForStmt* ifStmt) = 0;
virtual void VisitJumpStmt(JumpStmt* jumpStmt) = 0;
virtual void VisitReturnStmt(ReturnStmt* returnStmt) = 0;
virtual void VisitLabelStmt(LabelStmt* labelStmt) = 0;
virtual void VisitEmptyStmt(EmptyStmt* emptyStmt) = 0;
virtual void VisitCompoundStmt(CompoundStmt* compStmt) = 0;
virtual void VisitFuncDef(FuncDef* funcDef) = 0;
virtual void VisitTranslationUnit(TranslationUnit* unit) = 0;
};
#endif

View File

@@ -1,27 +0,0 @@
#pragma once
#ifndef _TRITON_RUNTIME_ARG_H_
#define _TRITON_RUNTIME_ARG_H_
#include <string>
#include <stdexcept>
#include <sstream>
namespace triton{
namespace ir{
class type;
}
namespace driver{
class buffer;
}
namespace runtime {
}
}
#endif

View File

@@ -1,34 +0,0 @@
#pragma once
#ifndef _TRITON_RUNTIME_ERROR_H_
#define _TRITON_RUNTIME_ERROR_H_
#include <exception>
#include <string>
namespace triton {
namespace runtime{
namespace exception {
class base: public std::exception {};
#define TRITON_CREATE_RUNTIME_EXCEPTION(name, msg) class name: public base { public: const char * what() const throw(){ return "Triton: Error - Runtime: " msg; } };
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_shared_memory, "out of shared memory")
TRITON_CREATE_RUNTIME_EXCEPTION(out_of_registers, "out of registers")
class no_valid_configuration: public exception::base {
public:
no_valid_configuration(const std::string& err): err_(err) { }
const char * what() const throw(){ return err_.c_str(); }
private:
std::string err_;
};
#undef TRITON_CREATE_RUNTIME_EXCEPTION
}
}
}
#endif

View File

@@ -1,159 +0,0 @@
#pragma once
#ifndef _TRITON_RUNTIME_FUNCTION_H_
#define _TRITON_RUNTIME_FUNCTION_H_
#include <map>
#include <unordered_map>
#include <vector>
#include <string>
#include <sstream>
#include <memory>
#include <functional>
// codegen
#include "triton/ir/function.h"
#include "triton/ir/context.h"
#include "triton/runtime/arg.h"
#include "triton/runtime/error.h"
// driver forward declaration
namespace triton {
namespace driver{
class module;
class stream;
class kernel;
class context;
class device;
}
}
// ir forward declaration
namespace triton{
namespace ir {
class module;
class function;
class context;
}
}
namespace triton{
namespace runtime{
/* ------------------------- */
/* Compilation options */
/* ------------------------- */
struct options_t {
template<class T>
T D(const std::string& name) const {
return std::stoi(defines.at(name));
}
std::unordered_map<std::string, std::string> defines;
int num_warps;
};
/* ------------------------- */
/* Runtime arguments */
/* ------------------------- */
enum arg_type {
INT1_T,
INT8_T,
INT16_T,
INT32_T,
INT64_T,
HALF_T,
FLOAT_T,
DOUBLE_T,
BUFFER_T
};
inline size_t size_of(arg_type ty){
switch(ty){
case INT1_T : return 1;
case INT8_T : return 1;
case INT16_T : return 2;
case INT32_T : return 4;
case INT64_T : return 8;
case HALF_T : return 2;
case FLOAT_T : return 4;
case DOUBLE_T: return 8;
case BUFFER_T: return 8;
default: throw std::runtime_error("unknown type");
}
}
template<class T>
void add_arg(std::stringstream& ss, T arg) {
ss.write((char*)&arg, sizeof(T));
}
/* ------------------------- */
/* ------------------------- */
class kernel{
public:
typedef std::vector<size_t> grid_t;
public:
static std::shared_ptr<ir::module> src_to_ir(const std::string& src, const options_t& opt);
static std::tuple<std::shared_ptr<driver::module>,
std::shared_ptr<driver::kernel>,
size_t> ir_to_bin(ir::module& ir, driver::device *dev, const options_t &opt);
public:
kernel(const std::string& src, const options_t& opt, driver::device *device, const std::map<int, triton::ir::attribute> &attrs = {});
void operator()(const std::string& args, driver::stream *stream, const grid_t& grid) const;
std::string get_asm(const std::string &mode);
public:
const options_t opt;
private:
driver::device* dev_;
// handles
std::shared_ptr<ir::module> ir_;
std::shared_ptr<driver::module> mod_;
std::shared_ptr<driver::kernel> ker_;
// shared mem
size_t shared_mem_;
};
struct config {
std::map<std::string, std::string> defines;
int num_warps;
};
class function {
public:
typedef std::function<kernel::grid_t(const options_t&)> grid_fn_ty;
typedef std::pair<options_t, std::shared_ptr<kernel>> kernel_pair_t;
typedef std::map<std::vector<uint64_t>, kernel*> cache_t;
typedef std::vector<config> autotune_confs_t;
public:
function(const std::string& src, const options_t& opt, driver::device *device,
const std::vector<config>& tune_confs = {}, const std::vector<std::string> &tune_key = {});
kernel* autotune(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
void operator()(const std::string& args, const grid_fn_ty& grid, driver::stream *stream);
const std::vector<arg_type> get_signature() { return sig_; }
private:
std::map<std::vector<uint64_t>, std::vector<std::shared_ptr<kernel>>> kernels_;
std::map<std::vector<uint64_t>, kernel*> cache_;
std::vector<arg_type> sig_;
std::vector<int> align_idxs_;
std::vector<int> int_idxs_;
std::vector<int> key_idxs_;
std::vector<int> arg_size_;
std::vector<int> arg_off_;
std::vector<options_t> opts_;
std::string src_;
driver::device* device_;
};
}
}
#endif